# Source code for dgllife.utils.early_stop

# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Early stopping
# pylint: disable= no-member, arguments-differ, invalid-name

import datetime
import torch

__all__ = ['EarlyStopping']

# pylint: disable=C0103
class EarlyStopping:
    """Early stop tracker

    Save model checkpoint when observing a performance improvement on
    the validation set and early stop if improvement has not been
    observed for a particular number of epochs.

    Parameters
    ----------
    mode : str
        * 'higher': Higher metric suggests a better model
        * 'lower': Lower metric suggests a better model
        If ``metric`` is not None, then mode will be determined
        automatically from that.
    patience : int
        The early stopping will happen if we do not observe performance
        improvement for ``patience`` consecutive epochs.
    filename : str or None
        Filename for storing the model checkpoint. If not specified,
        we will automatically generate a file starting with ``early_stop``
        based on the current time.
    metric : str or None
        A metric name that can be used to identify if a higher value is
        better, or vice versa. Default to None. Valid options include:
        ``'r2'``, ``'mae'``, ``'rmse'``, ``'roc_auc_score'``,
        ``'pr_auc_score'``.

    Raises
    ------
    AssertionError
        If ``metric`` is not None and not one of the valid options, or if
        the resulting ``mode`` is neither ``'higher'`` nor ``'lower'``.

    Examples
    --------
    Below gives a demo for a fake training process.

    >>> import torch
    >>> import torch.nn as nn
    >>> from torch.nn import MSELoss
    >>> from torch.optim import Adam
    >>> from dgllife.utils import EarlyStopping

    >>> model = nn.Linear(1, 1)
    >>> criterion = MSELoss()
    >>> # For MSE, the lower, the better
    >>> stopper = EarlyStopping(mode='lower', filename='test.pth')
    >>> optimizer = Adam(params=model.parameters(), lr=1e-3)

    >>> for epoch in range(1000):
    ...     x = torch.randn(1, 1) # Fake input
    ...     y = torch.randn(1, 1) # Fake label
    ...     pred = model(x)
    ...     loss = criterion(y, pred)
    ...     optimizer.zero_grad()
    ...     loss.backward()
    ...     optimizer.step()
    ...     early_stop = stopper.step(loss.detach().data, model)
    ...     if early_stop:
    ...         break

    >>> # Load the final parameters saved by the model
    >>> stopper.load_checkpoint(model)
    """

    # Metric names grouped by the direction in which they improve.
    _HIGHER_IS_BETTER = ('r2', 'roc_auc_score', 'pr_auc_score')
    _LOWER_IS_BETTER = ('mae', 'rmse')

    def __init__(self, mode='higher', patience=10, filename=None, metric=None):
        if filename is None:
            dt = datetime.datetime.now()
            filename = 'early_stop_{}_{:02d}-{:02d}-{:02d}.pth'.format(
                dt.date(), dt.hour, dt.minute, dt.second)

        # Validation raises explicitly instead of using ``assert`` so that it
        # is not stripped when Python runs with -O. AssertionError is kept as
        # the exception type for backward compatibility with existing callers.
        if metric is not None:
            if metric not in self._HIGHER_IS_BETTER + self._LOWER_IS_BETTER:
                raise AssertionError(
                    "Expect metric to be 'r2' or 'mae' or 'rmse' or "
                    "'roc_auc_score' or 'pr_auc_score', got {}".format(metric))
            # A recognised metric overrides whatever ``mode`` was passed in.
            if metric in self._HIGHER_IS_BETTER:
                print('For metric {}, the higher the better'.format(metric))
                mode = 'higher'
            else:
                print('For metric {}, the lower the better'.format(metric))
                mode = 'lower'

        if mode not in ('higher', 'lower'):
            raise AssertionError(
                "Expect mode to be 'higher' or 'lower', got {}".format(mode))
        self.mode = mode

        if self.mode == 'higher':
            self._check = self._check_higher
        else:
            self._check = self._check_lower

        self.patience = patience
        # Number of consecutive calls to ``step`` without an improvement.
        self.counter = 0
        # Total number of calls to ``step`` so far.
        self.timestep = 0
        self.filename = filename
        self.best_score = None
        self.early_stop = False

    def _check_higher(self, score, prev_best_score):
        """Check if the new score is higher than the previous best score.

        Parameters
        ----------
        score : float
            New score.
        prev_best_score : float
            Previous best score.

        Returns
        -------
        bool
            Whether the new score is higher than the previous best score.
        """
        return score > prev_best_score

    def _check_lower(self, score, prev_best_score):
        """Check if the new score is lower than the previous best score.

        Parameters
        ----------
        score : float
            New score.
        prev_best_score : float
            Previous best score.

        Returns
        -------
        bool
            Whether the new score is lower than the previous best score.
        """
        return score < prev_best_score

    def step(self, score, model):
        """Update based on a new score.

        The new score is typically model performance on the validation set
        for a new epoch. The first score ever seen is always recorded as the
        best one and checkpointed. Afterwards a checkpoint is written only
        on a strict improvement, which also resets the patience counter.

        Parameters
        ----------
        score : float
            New score.
        model : nn.Module
            Model instance.

        Returns
        -------
        bool
            Whether an early stop should be performed. Once this becomes
            True it stays True for all later calls.
        """
        self.timestep += 1
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(model)
        elif self._check(score, self.best_score):
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0
        else:
            self.counter += 1
            print(
                f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

    def save_checkpoint(self, model):
        '''Saves model when the metric on the validation set gets improved.

        The checkpoint is a dict holding the model state dict under
        ``'model_state_dict'`` and the current step count under
        ``'timestep'``, written to ``self.filename``.

        Parameters
        ----------
        model : nn.Module
            Model instance.
        '''
        torch.save({'model_state_dict': model.state_dict(),
                    'timestep': self.timestep}, self.filename)

    def load_checkpoint(self, model):
        '''Load the latest checkpoint

        Reads ``self.filename`` and restores the entry stored under
        ``'model_state_dict'`` into ``model``.

        Parameters
        ----------
        model : nn.Module
            Model instance.
        '''
        # NOTE(review): torch.load unpickles the file; only point
        # ``filename`` at checkpoints from a trusted source.
        model.load_state_dict(torch.load(self.filename)['model_state_dict'])