Source code for enchanter.callbacks.early_stopping

# ***************************************************
#  _____            _                 _
# | ____|_ __   ___| |__   __ _ _ __ | |_ ___ _ __
# |  _| | '_ \ / __| '_ \ / _` | '_ \| __/ _ \ '__|
# | |___| | | | (__| | | | (_| | | | | ||  __/ |
# |_____|_| |_|\___|_| |_|\__,_|_| |_|\__\___|_|
#
# ***************************************************

import re
from typing import Any, Dict, Optional, Union

import torch
import numpy as np
from sklearn.svm import SVC
from sklearn.base import BaseEstimator
from sklearn.model_selection import cross_validate, StratifiedKFold, BaseCrossValidator, GridSearchCV, train_test_split
from .base import Callback
from ..engine.modules import fetch_state_dict


# Public names of this module, i.e. what ``from <module> import *`` exports.
__all__ = ["EarlyStopping", "EarlyStoppingForTSUS"]


class EarlyStopping(Callback):
    """
    The training ends when the value to be monitored stops changing.

    Examples:
        >>> from enchanter.tasks import ClassificationRunner
        >>> runner = ClassificationRunner(callbacks=[EarlyStopping()])
    """

    def __init__(
        self,
        monitor: str = "val_avg_loss",
        min_delta: float = 0.0,
        patience: int = 0,
        mode: str = "auto",
    ) -> None:
        """
        Initializer

        Args:
            monitor: You can choose the return values of ``train_end()``, ``val_end()``, or ``test_end()``.
                    To specify, in the case of the return value of ``train_end()``, ``train_XXX``,
                    in the case of the return value of ``val_end()``, ``val_XXX``,
                    in the case of the return value of ``test_end()``, ``test_XXX``.
                    A leading ``validate_`` is accepted as an alias of ``val_``.
            min_delta: Minimum value of change determined as improvement for the monitored value.
            patience: If there is no improvement in the monitored value for more than
                    this number of consecutive epochs, the training stops.
            mode: One of ``{'auto', 'min', 'max'}`` is selected. Any other value falls back to ``auto``.
                    - ``min`` mode ends the training when the decrease in the monitored value stops.
                    - ``max`` mode, the training is terminated when the monitored values stop increasing.
                    - ``auto`` mode, ``max`` is used when ``monitor`` contains ``"acc"``, ``min`` otherwise.

        """
        super(EarlyStopping, self).__init__()

        # Rewrite only the leading prefix; a later "validate_" inside the metric
        # name belongs to the metric itself and must be left alone.
        alias_prefix = "validate_"
        if monitor.startswith(alias_prefix):
            monitor = "val_" + monitor[len(alias_prefix):]

        self.monitor: str = monitor
        self.patience: int = patience
        self.min_delta: float = min_delta
        self.wait: int = 0  # consecutive epochs without improvement
        self.stopped_epoch: int = 0  # epoch at which stopping was requested

        mode_dict = {
            "min": np.less,
            "max": np.greater,
            "auto": np.greater if "acc" in self.monitor else np.less,
        }
        if mode not in mode_dict:
            mode = "auto"

        self.monitor_op: Any = mode_dict[mode]
        # Sign the margin so that ``current - min_delta`` always moves the
        # candidate *away* from an improvement, whichever direction is "better".
        self.min_delta *= 1 if self.monitor_op == np.greater else -1
        # ``np.inf`` (lowercase) is used because the ``np.Inf`` alias was removed in NumPy 2.0.
        self.best: Any = np.inf if self.monitor_op == np.less else -np.inf

    def check_metrics(self, logs: Dict) -> bool:
        """
        Tell whether the monitored metric is available in ``logs``.

        Args:
            logs: Flat mapping of metric name to value.

        Returns:
            ``True`` if ``logs`` holds a non-``None`` value under ``self.monitor``.
        """
        return logs.get(self.monitor) is not None

    def on_epoch_end(self, epoch, logs: Optional[Dict] = None) -> None:
        """
        Update the early-stopping state from the metrics of the finished epoch.

        Sets ``self.stop_runner`` to ``True`` once the monitored value has failed
        to improve for more than ``patience`` consecutive epochs. On every
        improvement the result of ``fetch_state_dict(self.model)`` is stored in
        ``self.params["best_weight"]``.

        Args:
            epoch: Index of the epoch that just ended.
            logs: Nested mapping ``{phase: {metric: value}}``; it is flattened to
                ``"{phase}_{metric}"`` keys before the monitored value is looked up.

        Raises:
            TypeError: If ``logs`` is not a ``dict``.
            KeyError: If the monitored value passed ``check_metrics`` but cannot be
                read (only possible when a subclass loosens ``check_metrics``).
        """
        if not isinstance(logs, dict):
            raise TypeError("The argument `logs` is not the expected data type.")

        # Flatten {"val": {"avg_loss": x}} into {"val_avg_loss": x}.
        cat_logs: Dict[str, Union[float, int]] = {
            "{}_{}".format(phase, name): value
            for phase, metrics in logs.items()
            for name, value in metrics.items()
        }

        if not self.check_metrics(cat_logs):
            # The monitored metric was not reported this epoch: never stop on it.
            self.stop_runner = False
            return

        try:
            current = cat_logs[self.monitor]
        except KeyError as err:
            raise KeyError("Can't find the value specified in the argument `monitor`.") from err

        stop = False
        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
            self.params["best_weight"] = fetch_state_dict(self.model)
        else:
            self.wait += 1
            if self.wait > self.patience:
                self.stopped_epoch = epoch
                stop = True

        self.stop_runner = stop
class EarlyStoppingForTSUS(Callback):
    """
    Early Stopping for Time Series Unsupervised Runner.

    After every epoch the encoder output for ``data`` is scored by
    cross-validating a classifier against ``target``; training stops when that
    score stops improving.

    Examples:
        >>> import torch
        >>> from enchanter.tasks import TimeSeriesUnsupervisedRunner
        >>> x_train = torch.randn(32, 3, 128)
        >>> y_train = torch.randint(0, high=4, size=(32, ))
        >>> runner = TimeSeriesUnsupervisedRunner(
        >>>     callbacks=[EarlyStoppingForTSUS(x_train, y_train)]
        >>> )

    """

    def __init__(
        self,
        data: torch.Tensor,
        target: torch.Tensor,
        classifier: Optional[BaseEstimator] = None,
        monitor: str = "accuracy",
        min_delta: float = 0.0,
        patience: int = 0,
        kfold: Optional[BaseCrossValidator] = None,
        mode: str = "auto",
        grid_search: Optional[Dict[str, Any]] = None,
    ):
        """
        Initializer

        Args:
            data: Data for training a classifier to evaluate the quality of the encoder's output representation.
            target: Targets for training a classifier to evaluate the quality of the encoder's output representation.
            classifier: A classifier for evaluating the quality of the output representation of the encoder.
                    A fresh ``SVC()`` is created when omitted.
            monitor: Name of the scikit-learn scoring metric used to evaluate the classifier.
            min_delta: Minimum value of change determined as improvement for the monitored score.
            patience: If there is no improvement in the monitored score for more than
                    this number of consecutive epochs, the training stops.
            kfold: Cross-validation splitter. Anything that is not a ``BaseCrossValidator``
                    is replaced by ``StratifiedKFold(n_splits=5, shuffle=True, random_state=0)``.
            mode: One of ``{'auto', 'min', 'max'}``. Any other value falls back to ``auto``.
                    ``auto`` behaves like ``max`` because scikit-learn scorers are higher-is-better.
            grid_search: Parameter grid passed to ``GridSearchCV``. When given, the
                    best grid-search score is monitored instead of a plain cross-validation score.

        """
        super(EarlyStoppingForTSUS, self).__init__()

        # Build the default per instance instead of sharing one estimator
        # object, created at definition time, between all callbacks.
        self.classifier = SVC() if classifier is None else classifier
        self.data: torch.Tensor = data
        self.target: torch.Tensor = target
        self.monitor: str = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.wait: int = 0  # consecutive epochs without improvement
        self.stopped_epoch: int = 0  # epoch at which stopping was requested
        self.grid_search = grid_search

        if isinstance(kfold, BaseCrossValidator):
            self.kfold = kfold
        else:
            self.kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

        mode_dict = {
            "min": np.less,
            "max": np.greater,
            # The monitored value is a scikit-learn score, and those follow the
            # "greater is better" convention (losses are exposed as ``neg_*``).
            "auto": np.greater,
        }
        if mode not in mode_dict:
            mode = "auto"

        self.monitor_op: Any = mode_dict[mode]
        # Sign the margin so that ``current - min_delta`` always moves the
        # candidate *away* from an improvement, whichever direction is "better".
        self.min_delta *= 1 if self.monitor_op == np.greater else -1
        # Start from the worst possible value so the first epoch always counts
        # as an improvement, in both directions.
        self.best: float = -np.inf if self.monitor_op == np.greater else np.inf

    def encode(self):
        """
        Run the model on ``self.data`` without gradients and return the output.

        ``self.data`` is moved to ``self.device`` and kept there. The model is put
        in eval mode for the forward pass and switched back to training mode
        afterwards if it was training before.

        Returns:
            The model output as a NumPy array with at least two dimensions.
        """
        was_training = getattr(self.model, "training", False)
        self.model.eval()
        try:
            with torch.no_grad():
                self.data = self.data.to(self.device)
                encoded = self.model(self.data).cpu().numpy()
        finally:
            # Do not leak eval mode into the next training epoch.
            if was_training:
                self.model.train()

        return np.vstack([encoded])

    def cross_val(self):
        """
        Score the current encoder output with the classifier.

        Returns:
            A ``(score, params)`` tuple. With ``grid_search`` set these are the
            best score and best parameters found by ``GridSearchCV``; otherwise
            the mean cross-validated test score and the classifier's parameters.
        """
        features: np.ndarray = self.encode()
        targets = self.target.cpu().numpy()
        # Not every splitter has ``n_splits`` (e.g. ``LeaveOneOut``); ``None``
        # lets scikit-learn fall back to its default parallelism.
        n_jobs = getattr(self.kfold, "n_splits", None)

        if self.grid_search is not None:
            # Cap the grid-search cost on large datasets. ``train_size`` must be
            # strictly smaller than the number of samples, hence ``>``.
            if len(targets) > 10000:
                features, _, targets, _ = train_test_split(
                    features, targets, train_size=10000, random_state=0, stratify=targets
                )

            evaluator = GridSearchCV(
                self.classifier,
                self.grid_search,
                cv=self.kfold,
                scoring=self.monitor,
                n_jobs=n_jobs,
            )
            evaluator.fit(features, targets)
            return evaluator.best_score_, evaluator.best_params_

        results = cross_validate(
            self.classifier,
            features,
            targets,
            cv=self.kfold,
            scoring=[self.monitor],
            n_jobs=n_jobs,
        )
        score = results["test_{}".format(self.monitor)].mean()
        return score, self.classifier.get_params()

    def on_epoch_end(self, epoch, logs: Optional[Dict] = None):
        """
        Re-evaluate the encoder and update the early-stopping state.

        Sets ``self.stop_runner`` to ``True`` once the score has failed to improve
        for more than ``patience`` consecutive epochs. On every improvement the
        classifier parameters are stored in ``self.params["grid_search"]`` and the
        result of ``fetch_state_dict(self.model)`` in ``self.params["best_weight"]``.

        Args:
            epoch: Index of the epoch that just ended.
            logs: Unused; the score is computed by ``cross_val()``.
        """
        stop = False
        current_score, current_params = self.cross_val()

        if self.experiment is not None:
            self.experiment.log_metric("Early Stopping metric", current_score, step=epoch)

        if self.monitor_op(current_score - self.min_delta, self.best):
            self.params["grid_search"] = current_params
            self.best = current_score
            self.wait = 0
            self.params["best_weight"] = fetch_state_dict(self.model)
        else:
            # Count this epoch first, then compare, as EarlyStopping does.
            self.wait += 1
            if self.wait > self.patience:
                self.stopped_epoch = epoch
                stop = True

        self.stop_runner = stop