Source code for enchanter.callbacks.manager
from typing import List, Optional, Dict
from .base import Callback
__all__ = ["CallbackManager"]
class CallbackManager(Callback):
    """
    A class for managing Callback.

    The manager is itself a ``Callback``: every hook it receives is forwarded,
    in list order, to each managed callback. After each forwarded call the
    callback's ``stop_runner`` flag is inspected and, if truthy, copied onto
    the manager so the caller only has to check ``manager.stop_runner``.
    Nothing in this class clears the flag again once it has been raised.

    ``callbacks`` may be ``None``; in that case every hook is a no-op and the
    ``set_*`` methods only store the value on the manager itself.

    NOTE(review): ``self.params`` and ``self.stop_runner`` are presumably
    initialised by ``Callback.__init__`` -- confirm against ``.base``.

    Examples:
        >>> from enchanter.callbacks import EarlyStopping
        >>> manager = CallbackManager([
        ...     EarlyStopping()
        ... ])
        >>> for epoch in range(5):
        ...     logs = ...
        ...     manager.on_epoch_end(epoch, logs)
        ...     if manager.stop_runner:
        ...         break

    """

    def __init__(self, callbacks: Optional[List[Callback]] = None):
        """
        Args:
            callbacks: Callbacks to manage, invoked in the given order.
                ``None`` means "no callbacks".
        """
        super().__init__()
        # Stored as given (including ``None``) so existing code that inspects
        # ``manager.callbacks`` keeps seeing the same value.
        self.callbacks = callbacks

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #
    def _propagate(self, attribute: str, value) -> None:
        """Store ``value`` as ``attribute`` on the manager and on every callback."""
        setattr(self, attribute, value)
        if self.callbacks is None:
            return
        for callback in self.callbacks:
            # Read the value back from the manager, as the original code did.
            setattr(callback, attribute, getattr(self, attribute))

    def _notify(self, hook: str, *args, shared_params: Optional[List[str]] = None) -> None:
        """Call ``hook`` on every callback, then collect its stop flag and params.

        Args:
            hook: Name of the ``Callback`` method to invoke.
            *args: Positional arguments forwarded unchanged to the hook.
            shared_params: Keys that, when present in a callback's ``params``,
                are copied into the manager's ``params``. If several callbacks
                provide the same key, the last one in the list wins.
        """
        if self.callbacks is None:
            return
        for callback in self.callbacks:
            getattr(callback, hook)(*args)
            self.flag_check(callback.stop_runner)
            for key in shared_params or ():
                if key in callback.params:
                    self.params[key] = callback.params[key]

    # ------------------------------------------------------------------ #
    # Shared state setters
    # ------------------------------------------------------------------ #
    def set_experiment(self, experiment):
        """Set ``experiment`` on the manager and on every managed callback."""
        self._propagate("experiment", experiment)

    def set_device(self, device):
        """Set ``device`` on the manager and on every managed callback."""
        self._propagate("device", device)

    def set_model(self, model):
        """Set ``model`` on the manager and on every managed callback."""
        self._propagate("model", model)

    def set_optimizer(self, optimizer):
        """Set ``optimizer`` on the manager and on every managed callback."""
        self._propagate("optimizer", optimizer)

    def flag_check(self, stop_runner: bool):
        """Raise the manager's ``stop_runner`` flag if ``stop_runner`` is truthy.

        A falsy value never lowers the flag, and nothing happens when the
        manager has no callbacks (``self.callbacks is None``).
        """
        if stop_runner and self.callbacks is not None:
            self.stop_runner = stop_runner

    # ------------------------------------------------------------------ #
    # Epoch hooks
    # ------------------------------------------------------------------ #
    def on_epoch_start(self, epoch, logs=None):
        """Forward ``on_epoch_start(epoch, logs)`` to every callback."""
        self._notify("on_epoch_start", epoch, logs)

    def on_epoch_end(self, epoch, logs=None):
        """Forward ``on_epoch_end(epoch, logs)`` to every callback.

        Additionally copies a callback's ``params["best_weight"]``, when
        present, into the manager's ``params``.
        """
        self._notify("on_epoch_end", epoch, logs, shared_params=["best_weight"])

    # ------------------------------------------------------------------ #
    # Step hooks
    # ------------------------------------------------------------------ #
    def on_train_step_start(self, logs: Optional[Dict] = None):
        """Forward ``on_train_step_start(logs)`` to every callback."""
        self._notify("on_train_step_start", logs)

    def on_train_step_end(self, logs: Optional[Dict] = None):
        """Forward ``on_train_step_end(logs)`` to every callback."""
        self._notify("on_train_step_end", logs)

    def on_validation_step_start(self, logs: Optional[Dict] = None):
        """Forward ``on_validation_step_start(logs)`` to every callback."""
        self._notify("on_validation_step_start", logs)

    def on_validation_step_end(self, logs: Optional[Dict] = None):
        """Forward ``on_validation_step_end(logs)`` to every callback."""
        self._notify("on_validation_step_end", logs)

    def on_test_step_start(self, logs: Optional[Dict] = None):
        """Forward ``on_test_step_start(logs)`` to every callback."""
        self._notify("on_test_step_start", logs)

    def on_test_step_end(self, logs: Optional[Dict] = None):
        """Forward ``on_test_step_end(logs)`` to every callback."""
        self._notify("on_test_step_end", logs)

    # ------------------------------------------------------------------ #
    # Phase hooks
    # ------------------------------------------------------------------ #
    def on_train_start(self, logs=None):
        """Forward ``on_train_start(logs)`` to every callback."""
        self._notify("on_train_start", logs)

    def on_train_end(self, logs=None):
        """Forward ``on_train_end(logs)`` to every callback."""
        self._notify("on_train_end", logs)

    def on_validation_start(self, logs=None):
        """Forward ``on_validation_start(logs)`` to every callback."""
        self._notify("on_validation_start", logs)

    def on_validation_end(self, logs=None):
        """Forward ``on_validation_end(logs)`` to every callback."""
        self._notify("on_validation_end", logs)

    def on_test_start(self, logs=None):
        """Forward ``on_test_start(logs)`` to every callback.

        Additionally copies a callback's ``params["grid_search"]``, when
        present, into the manager's ``params``.
        """
        self._notify("on_test_start", logs, shared_params=["grid_search"])

    def on_test_end(self, logs=None):
        """Forward ``on_test_end(logs)`` to every callback."""
        self._notify("on_test_end", logs)