Source code for enchanter.callbacks.base
# ***************************************************
# _____ _ _
# | ____|_ __ ___| |__ __ _ _ __ | |_ ___ _ __
# | _| | '_ \ / __| '_ \ / _` | '_ \| __/ _ \ '__|
# | |___| | | | (__| | | | (_| | | | | || __/ |
# |_____|_| |_|\___|_| |_|\__,_|_| |_|\__\___|_|
#
# ***************************************************
from abc import ABC
from typing import Dict, Optional
__all__ = ["Callback"]
[docs]class Callback(ABC):
def __init__(self):
self.stop_runner: bool = False
self.model = None
self.optimizer = None
self.device = None
self.experiment = None
self.params = {}
[docs] def set_device(self, device):
self.device = device
[docs] def set_model(self, model):
self.model = model
[docs] def set_optimizer(self, optimizer):
self.optimizer = optimizer
[docs] def on_epoch_start(self, epoch, logs: Optional[Dict] = None):
"""Called when the epoch begins."""
pass
[docs] def on_epoch_end(self, epoch, logs: Optional[Dict] = None):
"""Called when the epoch ends."""
pass
[docs] def on_train_step_start(self, logs: Optional[Dict] = None):
"""Called when the training batch begins."""
pass
[docs] def on_train_step_end(self, logs: Optional[Dict] = None):
"""Called when the training batch ends. You can access the output of train_step."""
pass
[docs] def on_validation_step_start(self, logs: Optional[Dict] = None):
pass
[docs] def on_validation_step_end(self, logs: Optional[Dict] = None):
"""Called when the validation batch ends. You can access the output of val_step."""
pass
[docs] def on_test_step_start(self, logs: Optional[Dict] = None):
pass
[docs] def on_test_step_end(self, logs: Optional[Dict] = None):
"""Called when the test batch ends. You can access the output of test_step."""
pass
[docs] def on_train_start(self, logs: Optional[Dict] = None):
"""Called when the train begins."""
pass
[docs] def on_train_end(self, logs: Optional[Dict] = None):
"""Called when the train ends."""
pass
[docs] def on_validation_start(self, logs: Optional[Dict] = None):
"""Called when the validation loop begins."""
pass
[docs] def on_validation_end(self, logs: Optional[Dict] = None):
"""Called when the validation loop ends."""
pass
[docs] def on_test_start(self, logs: Optional[Dict] = None):
"""Called when the test begins."""
pass
[docs] def on_test_end(self, logs: Optional[Dict] = None):
"""Called when the test ends."""
pass