Model Development Pipeline

Model Evaluation

A utility class for evaluating model performance on (multi-label) supervised learning tasks.

class dgllife.utils.Meter(mean=None, std=None)[source]

Track and summarize model performance on a dataset for (multi-label) prediction.

When dealing with multitask learning, we often normalize the labels so that they are roughly on the same scale. During evaluation, we need to undo this normalization on the predicted labels. If mean and std are not None, the normalization will be undone (see the sketch after the parameter list below).

Currently we support evaluation with 5 metrics:

  • pearson r2

  • mae

  • rmse

  • roc auc score

  • pr auc score

Parameters
  • mean (torch.float32 tensor of shape (T) or None) – Mean of existing training labels across tasks if not None, where T is the number of tasks. Default to None, in which case we assume no label normalization has been performed.

  • std (torch.float32 tensor of shape (T) or None) – Std of existing training labels across tasks if not None. Default to None, in which case we assume no label normalization has been performed.
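
If the labels were normalized for training, the statistics used for the normalization can be passed to the meter so that the predicted labels are mapped back to the original scale before metrics are computed. A minimal sketch, where train_labels is a hypothetical tensor of unnormalized training labels:

>>> import torch
>>> from dgllife.utils import Meter
>>> # Hypothetical training labels for 100 datapoints and 3 tasks
>>> train_labels = torch.randn(100, 3)
>>> mean = train_labels.mean(dim=0)   # shape (3,)
>>> std = train_labels.std(dim=0)     # shape (3,)
>>> # The meter will undo the label normalization before computing metrics
>>> meter = Meter(mean=mean, std=std)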

Examples

Below is a demo of a fake evaluation epoch.

>>> import torch
>>> from dgllife.utils import Meter
>>> meter = Meter()
>>> # Simulate 10 fake mini-batches
>>> for batch_id in range(10):
>>>     batch_label = torch.randn(3, 3)
>>>     batch_pred = torch.randn(3, 3)
>>>     meter.update(batch_pred, batch_label)
>>> # Get MAE for all tasks
>>> print(meter.compute_metric('mae'))
[1.1325558423995972, 1.0543707609176636, 1.094650149345398]
>>> # Get MAE averaged over all tasks
>>> print(meter.compute_metric('mae', reduction='mean'))
1.0938589175542195
>>> # Get the sum of MAE over all tasks
>>> print(meter.compute_metric('mae', reduction='sum'))
3.2815767526626587
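
The accumulated results can be summarized with any of the supported metrics, so the same meter can report several metrics from a single pass over the data. A sketch continuing the snippet above (the values depend on the random inputs, so they are not shown):

>>> # Reuse the accumulated predictions and labels with other metrics
>>> rmse_per_task = meter.compute_metric('rmse')            # list of per-task RMSE
>>> r2_mean = meter.compute_metric('r2', reduction='mean')  # mean squared Pearson r
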
compute_metric(metric_name, reduction='none')[source]

Compute metric based on metric name.

Parameters
  • metric_name (str) –

    • 'r2': compute squared Pearson correlation coefficient

    • 'mae': compute mean absolute error

    • 'rmse': compute root mean square error

    • 'roc_auc_score': compute roc-auc score

    • 'pr_auc_score': compute pr-auc score

  • reduction ('none' or 'mean' or 'sum') – Controls how the scores across tasks are aggregated

Returns

  • If reduction == 'none', return the list of scores for all tasks.

  • If reduction == 'mean', return the mean of scores for all tasks.

  • If reduction == 'sum', return the sum of scores for all tasks.

Return type

float or list of float
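
For classification tasks, roc_auc_score and pr_auc_score expect binary ground-truth labels, and since both scores are rank-based the predictions can be raw logits or probabilities. A minimal sketch with hypothetical binary labels:

>>> import torch
>>> from dgllife.utils import Meter
>>> meter = Meter()
>>> # Hypothetical binary labels for 6 graphs and 2 tasks
>>> batch_label = torch.tensor([[1., 0.], [0., 1.], [1., 1.],
>>>                             [0., 0.], [1., 0.], [0., 1.]])
>>> batch_pred = torch.randn(6, 2)  # raw logits
>>> meter.update(batch_pred, batch_label)
>>> auc_mean = meter.compute_metric('roc_auc_score', reduction='mean')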

update(y_pred, y_true, mask=None)[source]

Update the meter with the results of an iteration.

Parameters
  • y_pred (float32 tensor) – Predicted labels with shape (B, T), where B is the number of graphs in the batch and T is the number of tasks

  • y_true (float32 tensor) – Ground truth labels with shape (B, T)

  • mask (None or float32 tensor) – Binary mask indicating the existence of ground truth labels with shape (B, T). If None, we assume that all labels exist and use a tensor of ones as a placeholder (see the sketch below).
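
When some ground-truth labels are missing, as is common in multitask molecular datasets, a binary mask lets the meter ignore the missing entries. A minimal sketch with a hypothetical mask:

>>> import torch
>>> from dgllife.utils import Meter
>>> meter = Meter()
>>> batch_pred = torch.randn(4, 2)
>>> batch_label = torch.randn(4, 2)
>>> # 1 marks an observed label, 0 marks a missing one
>>> batch_mask = torch.tensor([[1., 0.], [1., 1.], [0., 1.], [1., 1.]])
>>> meter.update(batch_pred, batch_label, mask=batch_mask)
>>> mae_mean = meter.compute_metric('mae', reduction='mean')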

Early Stopping

Early stopping is a standard practice for preventing models from overfitting, and we provide a utility class for handling it.

class dgllife.utils.EarlyStopping(mode='higher', patience=10, filename=None, metric=None)[source]

Early stop tracker

Save a model checkpoint when a performance improvement is observed on the validation set and stop early if no improvement has been observed for a particular number of epochs.

Parameters
  • mode (str) –

    • 'higher': Higher metric suggests a better model

    • 'lower': Lower metric suggests a better model

    If metric is not None, then mode will be determined automatically from that.

  • patience (int) – Early stopping will be triggered if we do not observe performance improvement for patience consecutive epochs.

  • filename (str or None) – Filename for storing the model checkpoint. If not specified, we will automatically generate a file starting with early_stop based on the current time.

  • metric (str or None) – A metric name that can be used to identify if a higher value is better, or vice versa. Default to None. Valid options include: 'r2', 'mae', 'rmse', 'roc_auc_score'. See the sketch after this list.
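
When a metric name is provided, the direction of improvement is inferred from it, so mode does not need to be set manually. A sketch under that assumption (the filenames are hypothetical):

>>> from dgllife.utils import EarlyStopping
>>> # roc_auc_score: higher is better, so mode is inferred as 'higher'
>>> stopper = EarlyStopping(patience=20, filename='es_roc_auc.pth', metric='roc_auc_score')
>>> # rmse: lower is better, so mode is inferred as 'lower'
>>> stopper = EarlyStopping(patience=20, filename='es_rmse.pth', metric='rmse')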

Examples

Below is a demo of a fake training process.

>>> import torch
>>> import torch.nn as nn
>>> from torch.nn import MSELoss
>>> from torch.optim import Adam
>>> from dgllife.utils import EarlyStopping
>>> model = nn.Linear(1, 1)
>>> criterion = MSELoss()
>>> # For MSE, the lower, the better
>>> stopper = EarlyStopping(mode='lower', filename='test.pth')
>>> optimizer = Adam(params=model.parameters(), lr=1e-3)
>>> for epoch in range(1000):
>>>     x = torch.randn(1, 1) # Fake input
>>>     y = torch.randn(1, 1) # Fake label
>>>     pred = model(x)
>>>     loss = criterion(y, pred)
>>>     optimizer.zero_grad()
>>>     loss.backward()
>>>     optimizer.step()
>>>     early_stop = stopper.step(loss.detach().data, model)
>>>     if early_stop:
>>>         break
>>> # Load the checkpoint saved by the stopper
>>> stopper.load_checkpoint(model)
step(score, model)[source]

Update based on a new score.

The new score is typically the model performance on the validation set for a new epoch, as in the sketch at the end of this section.

Parameters
  • score (float) – New score.

  • model (nn.Module) – Model instance.

Returns

Whether an early stop should be performed.

Return type

bool
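
In practice, the score passed to step is usually a validation metric produced by a Meter rather than the training loss. A minimal sketch combining the two utilities, with a fake validation set:

>>> import torch
>>> import torch.nn as nn
>>> from dgllife.utils import EarlyStopping, Meter
>>> model = nn.Linear(16, 2)
>>> # r2: higher is better
>>> stopper = EarlyStopping(mode='higher', patience=5, filename='val_checkpoint.pth')
>>> for epoch in range(100):
>>>     # ... one epoch of training would go here ...
>>>     meter = Meter()
>>>     val_feats = torch.randn(8, 16)    # fake validation features
>>>     val_labels = torch.randn(8, 2)    # fake validation labels
>>>     with torch.no_grad():
>>>         meter.update(model(val_feats), val_labels)
>>>     val_score = meter.compute_metric('r2', reduction='mean')
>>>     if stopper.step(val_score, model):
>>>         break
>>> # Load the checkpoint with the best validation performance
>>> stopper.load_checkpoint(model)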