enchanter.callbacks
A set of functions executed during training, validation, and testing.
Callback
- class enchanter.callbacks.Callback
Bases: abc.ABC
- on_test_step_end(logs: Optional[Dict] = None)
Called when the test batch ends. You can access the output of test_step.
- on_train_step_end(logs: Optional[Dict] = None)
Called when the training batch ends. You can access the output of train_step.
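As a minimal sketch, a custom callback overrides the step hooks and reads the logs argument. The "loss" key is hypothetical (the actual keys depend on what train_step and test_step return), and Callback below is a local stand-in declaring only the two hooks documented here, so the example runs without enchanter; a real subclass of enchanter.callbacks.Callback may need to handle other hooks as well.

```python
from abc import ABC
from typing import Dict, List, Optional


class Callback(ABC):
    """Stand-in for enchanter.callbacks.Callback (only the two documented hooks)."""

    def on_train_step_end(self, logs: Optional[Dict] = None) -> None:
        pass

    def on_test_step_end(self, logs: Optional[Dict] = None) -> None:
        pass


class LossRecorder(Callback):
    """Hypothetical callback that collects the "loss" entry of each step's logs."""

    def __init__(self) -> None:
        self.train_losses: List[float] = []
        self.test_losses: List[float] = []

    def on_train_step_end(self, logs: Optional[Dict] = None) -> None:
        # Assumption: train_step returns a dict containing a "loss" key.
        if logs and "loss" in logs:
            self.train_losses.append(float(logs["loss"]))

    def on_test_step_end(self, logs: Optional[Dict] = None) -> None:
        # Assumption: test_step returns a dict containing a "loss" key.
        if logs and "loss" in logs:
            self.test_losses.append(float(logs["loss"]))


recorder = LossRecorder()
for loss in (0.9, 0.7, 0.5):
    recorder.on_train_step_end({"loss": loss})
recorder.on_test_step_end({"loss": 0.6})
recorder.on_test_step_end(None)  # missing logs are ignored
print(recorder.train_losses, recorder.test_losses)  # [0.9, 0.7, 0.5] [0.6]
```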
CallbackManager
- class enchanter.callbacks.CallbackManager(callbacks: Optional[List[enchanter.callbacks.base.Callback]] = None)
A class that manages a list of Callback objects.
Examples
>>> from enchanter.callbacks import CallbackManager, EarlyStopping
>>> manager = CallbackManager([
...     EarlyStopping()
... ])
>>> for epoch in range(5):
...     logs = ...
...     manager.on_epoch_end(epoch, logs)
...     if manager.stop_runner:
...         break
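The loop above relies on the manager forwarding each event to its callbacks and exposing a stop_runner flag. The sketch below shows one way such a manager could work; it is not enchanter's implementation. SimpleCallbackManager and StopAfter are hypothetical names, and the convention that a callback requests a stop by setting its own stop_runner attribute is an assumption.

```python
from typing import Any, Dict, List, Optional


class SimpleCallbackManager:
    """Sketch of a manager that fans epoch events out to a list of callbacks."""

    def __init__(self, callbacks: Optional[List[Any]] = None) -> None:
        self.callbacks = list(callbacks) if callbacks else []
        self.stop_runner = False

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
        for callback in self.callbacks:
            callback.on_epoch_end(epoch, logs)
            # Assumption: a callback requests a stop via its stop_runner attribute.
            if getattr(callback, "stop_runner", False):
                self.stop_runner = True


class StopAfter:
    """Hypothetical callback that asks the runner to stop once a given epoch ends."""

    def __init__(self, last_epoch: int) -> None:
        self.last_epoch = last_epoch
        self.stop_runner = False

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
        self.stop_runner = epoch >= self.last_epoch


manager = SimpleCallbackManager([StopAfter(2)])
completed = []
for epoch in range(5):
    manager.on_epoch_end(epoch, logs={})
    completed.append(epoch)
    if manager.stop_runner:
        break
print(completed)  # [0, 1, 2]
```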
EarlyStopping
- class enchanter.callbacks.EarlyStopping(monitor: str = 'val_avg_loss', min_delta: float = 0.0, patience: int = 0, mode: str = 'auto')
Bases: enchanter.callbacks.base.Callback
Stops training when the monitored value stops improving.
Examples
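>>> from enchanter.callbacks import EarlyStopping
>>> from enchanter.tasks import ClassificationRunner
>>> model, optimizer, criterion = ...
>>> runner = ClassificationRunner(
...     model, optimizer, criterion, callbacks=[EarlyStopping()]
... )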
Initializer
- Parameters
monitor – Name of the value to monitor. Any value returned by train_end(), val_end(), or test_end() can be used, with a prefix indicating its source: train_XXX for a value returned by train_end(), val_XXX for val_end(), and test_XXX for test_end().
min_delta – Minimum change in the monitored value that counts as an improvement.
patience – Number of epochs without improvement in the monitored value after which training stops.
mode – One of {'auto', 'min', 'max'}. In min mode, training stops when the monitored value stops decreasing. In max mode, training stops when the monitored value stops increasing. In auto mode, the direction is inferred automatically from the monitored value.
Methods
check_metrics(logs)
on_epoch_end(epoch[, logs]) – Called when the epoch ends.
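The interplay of monitor, min_delta, patience, and mode can be illustrated with a small, framework-free sketch. It follows the common early-stopping rule (keep the best value, count epochs whose change does not beat min_delta, stop once that count reaches patience) and is not taken from enchanter's source. EarlyStoppingSketch is a hypothetical name, and the rule used here for mode='auto' (maximize when the monitored name contains "acc", otherwise minimize) is an assumption.

```python
import math
from typing import Dict, Optional


class EarlyStoppingSketch:
    """Framework-free illustration of patience-based early stopping."""

    def __init__(self, monitor: str = "val_avg_loss", min_delta: float = 0.0,
                 patience: int = 0, mode: str = "auto") -> None:
        if mode not in ("auto", "min", "max"):
            raise ValueError("mode must be 'auto', 'min' or 'max'")
        if mode == "auto":
            # Assumption: infer the direction from the monitored name.
            mode = "max" if "acc" in monitor else "min"
        self.monitor = monitor
        self.min_delta = abs(min_delta)
        self.patience = patience
        self.mode = mode
        self.best = math.inf if mode == "min" else -math.inf
        self.wait = 0  # consecutive epochs without improvement
        self.stop_runner = False

    def _improved(self, current: float) -> bool:
        if self.mode == "min":
            return current < self.best - self.min_delta
        return current > self.best + self.min_delta

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
        current = (logs or {}).get(self.monitor)
        if current is None:
            return  # the monitored key is absent from this epoch's logs
        if self._improved(current):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stop_runner = True


# Keys follow the train_XXX / val_XXX / test_XXX naming described above.
history = [0.90, 0.70, 0.69, 0.695, 0.71, 0.60]
stopper = EarlyStoppingSketch(monitor="val_avg_loss", min_delta=0.02, patience=2)
for epoch, loss in enumerate(history):
    stopper.on_epoch_end(epoch, {"val_avg_loss": loss})
    if stopper.stop_runner:
        break
print(epoch, stopper.best)  # 3 0.7
```

Here 0.69 and 0.695 each fail to beat the best value 0.70 by more than min_delta=0.02, so two non-improving epochs exhaust patience=2 and the later 0.60 is never seen.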
EarlyStoppingForTSUS
- class enchanter.callbacks.EarlyStoppingForTSUS(data: torch.Tensor, target: torch.Tensor, classifier: sklearn.base.BaseEstimator = SVC(), monitor: str = 'accuracy', min_delta: float = 0.0, patience: int = 0, kfold: Optional[sklearn.model_selection._split.BaseCrossValidator] = None, mode: str = 'auto', grid_search: Optional[Dict[str, Any]] = None)
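Bases: enchanter.callbacks.base.Callback
Early stopping for TimeSeriesUnsupervisedRunner. The quality of the encoder's output representation is measured with a classifier trained on it, and training stops when that measure stops improving.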
Examples
>>> import torch
>>> from enchanter.callbacks import EarlyStoppingForTSUS
>>> from enchanter.tasks import TimeSeriesUnsupervisedRunner
>>> x_train = torch.randn(32, 3, 128)
>>> y_train = torch.randint(0, high=4, size=(32,))
>>> runner = TimeSeriesUnsupervisedRunner(
...     callbacks=[EarlyStoppingForTSUS(x_train, y_train)]
... )
Initializer
- Parameters
data – Data for training a classifier to evaluate the quality of the encoder’s output representation.
target – Targets for training a classifier to evaluate the quality of the encoder’s output representation.
classifier – A classifier for evaluating the quality of the output representation of the encoder.
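monitor – Name of the classifier's evaluation metric to monitor (default: 'accuracy').
min_delta – Same as in EarlyStopping.
patience – Same as in EarlyStopping.
kfold – Optional scikit-learn cross-validator used when evaluating the classifier (see cross_val()).
mode – Same as in EarlyStopping.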
Methods
cross_val()
encode()
on_epoch_end(epoch[, logs]) – Called when the epoch ends.
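EarlyStoppingForTSUS judges an unsupervised encoder by how well a classifier trained on its outputs predicts target. The sketch below illustrates that idea without PyTorch or scikit-learn, so it is a swapped-in simplification rather than the callback's actual procedure: a nearest-centroid classifier replaces the default SVC, interleaved folds replace kfold, and informative and collapsed are hypothetical encoders standing in for encode(). The resulting mean accuracy is the kind of score that would then be tracked with min_delta, patience, and mode as in EarlyStopping.

```python
from typing import Dict, List, Sequence


def nearest_centroid_accuracy(train_x, train_y, test_x, test_y) -> float:
    """Fit one centroid per class on the training fold, then score the test fold."""
    sums: Dict[int, List[float]] = {}
    counts: Dict[int, int] = {}
    for vec, label in zip(train_x, train_y):
        total = sums.setdefault(label, [0.0] * len(vec))
        for i, v in enumerate(vec):
            total[i] += v
        counts[label] = counts.get(label, 0) + 1
    centroids = {lbl: [s / counts[lbl] for s in total] for lbl, total in sums.items()}

    def predict(vec: Sequence[float]) -> int:
        # Squared Euclidean distance to each centroid; ties go to the first class seen.
        return min(centroids,
                   key=lambda lbl: sum((a - b) ** 2 for a, b in zip(vec, centroids[lbl])))

    correct = sum(predict(vec) == label for vec, label in zip(test_x, test_y))
    return correct / len(test_y)


def cross_val_accuracy(features, targets, n_splits: int = 2) -> float:
    """Mean accuracy over interleaved folds: sample i is tested in fold i % n_splits."""
    scores = []
    pairs = list(enumerate(zip(features, targets)))
    for k in range(n_splits):
        train = [ft for i, ft in pairs if i % n_splits != k]
        test = [ft for i, ft in pairs if i % n_splits == k]
        train_x, train_y = zip(*train)
        test_x, test_y = zip(*test)
        scores.append(nearest_centroid_accuracy(train_x, train_y, test_x, test_y))
    return sum(scores) / len(scores)


# Hypothetical encoders standing in for the runner's encode() step.
def informative(sample: Sequence[float]) -> List[float]:
    return [sum(sample) / len(sample)]


def collapsed(sample: Sequence[float]) -> List[float]:
    return [0.0]  # every sample maps to the same point


data = [[0.1, 0.2], [0.0, 0.1], [0.2, 0.0], [0.1, 0.1],
        [1.0, 0.9], [0.9, 1.1], [1.1, 1.0], [1.0, 1.0]]
target = [0, 0, 0, 0, 1, 1, 1, 1]
print(cross_val_accuracy([informative(x) for x in data], target))  # 1.0
print(cross_val_accuracy([collapsed(x) for x in data], target))    # 0.5
```

A representation that separates the classes scores higher than one that collapses every sample to a single point, which is why such a score can serve as a stopping signal for an unsupervised encoder.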
Logging
Provides alternatives to comet_ml.Experiment for logging.
BaseLogger
- class enchanter.callbacks.BaseLogger
Bases: abc.ABC
Provides the minimal subset of the comet_ml.Experiment interface that a Runner requires.
- abstract end()
Call this method to indicate that the experiment is complete.
Warning
If you create your own Logger, you will need to implement this method.
- abstract log_metric(name: str, value: Any, step: Optional[int] = None, epoch: Optional[int] = None, include_context: bool = True) → None
Logs a metric.
- Parameters
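name – Name of the metric.
value – Value of the metric.
step – Optional step number at which the metric is logged.
epoch – Optional epoch number at which the metric is logged.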
include_context –
- Returns
None
Warning
If you create your own Logger, you will need to implement this method.
- abstract log_metrics(dic: Dict, prefix: Optional[str] = None, step: Optional[int] = None, epoch: Optional[int] = None) → None
Logs a dictionary that maps metric names to values.
See also
log_metric
Warning
If you create your own Logger, you will need to implement this method.
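Because each Warning requires a custom logger to implement end, log_metric, and log_metrics, a minimal in-memory logger is a useful template. In this sketch, BaseLogger is a local stand-in that declares only the three abstract methods documented here, so the example runs without enchanter; with enchanter installed you would subclass enchanter.callbacks.BaseLogger instead. InMemoryLogger is a hypothetical name, and combining prefix and key as "prefix_key" is an assumption.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple


class BaseLogger(ABC):
    """Stand-in mirroring the abstract methods documented for BaseLogger."""

    @abstractmethod
    def end(self) -> None: ...

    @abstractmethod
    def log_metric(self, name: str, value: Any, step: Optional[int] = None,
                   epoch: Optional[int] = None, include_context: bool = True) -> None: ...

    @abstractmethod
    def log_metrics(self, dic: Dict, prefix: Optional[str] = None,
                    step: Optional[int] = None, epoch: Optional[int] = None) -> None: ...


class InMemoryLogger(BaseLogger):
    """Hypothetical logger that keeps (name, value, step, epoch) records in a list."""

    def __init__(self) -> None:
        self.records: List[Tuple[str, Any, Optional[int], Optional[int]]] = []
        self.finished = False

    def end(self) -> None:
        self.finished = True

    def log_metric(self, name: str, value: Any, step: Optional[int] = None,
                   epoch: Optional[int] = None, include_context: bool = True) -> None:
        # include_context is accepted for signature compatibility but ignored here.
        self.records.append((name, value, step, epoch))

    def log_metrics(self, dic: Dict, prefix: Optional[str] = None,
                    step: Optional[int] = None, epoch: Optional[int] = None) -> None:
        for key, value in dic.items():
            name = f"{prefix}_{key}" if prefix else key  # assumed prefix format
            self.log_metric(name, value, step=step, epoch=epoch)


logger = InMemoryLogger()
logger.log_metric("train_avg_loss", 0.42, epoch=1)
logger.log_metrics({"loss": 0.40, "accuracy": 0.88}, prefix="val", epoch=1)
logger.end()
print(logger.records)
```

Leaving out any of the three methods makes the subclass abstract, so instantiating it raises TypeError, which is how the stand-in enforces the Warning notes.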
TensorBoardLogger
- class enchanter.callbacks.TensorBoardLogger(*args, **kwargs)
Bases: enchanter.callbacks.loggers.BaseLogger
TensorBoardLogger provides minimal logging to TensorBoard in environments where comet.ml cannot be used.
Examples
>>> from enchanter.callbacks import TensorBoardLogger
>>> from enchanter.tasks import ClassificationRunner
>>> model, optimizer, criterion = ...
>>> runner = ClassificationRunner(
...     model, optimizer, criterion, experiment=TensorBoardLogger()
... )
For the initializer arguments, see torch.utils.tensorboard.writer.SummaryWriter.
- log_metric(name: str, value: Any, step: Optional[int] = None, epoch: Optional[int] = None, include_context: bool = True) → None
Logs a metric.
- Parameters
name – Name of the metric.
value – Value of the metric.
step – Used as the X axis when plotting on TensorBoard.
epoch – Used as the X axis when plotting on TensorBoard.
include_context –
- Returns
None
- log_metrics(dic: Dict, prefix: Optional[str] = None, step: Optional[int] = None, epoch: Optional[int] = None) → None
Logs a dictionary that maps metric names to values.
See also
log_metric
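The step and epoch parameters determine where a metric lands on TensorBoard's X axis. The sketch below shows one possible mapping, written against any object that offers SummaryWriter's add_scalar(tag, scalar_value, global_step=...) and close() methods, so it runs without PyTorch. It is not TensorBoardLogger's source: preferring step over epoch as the global step and the "prefix/key" tag format are assumptions, and ScalarRecorder and TensorBoardStyleLogger are hypothetical names. With PyTorch installed, a torch.utils.tensorboard.SummaryWriter() instance can be passed in place of the recording stub.

```python
from typing import Any, Dict, List, Optional, Tuple


class ScalarRecorder:
    """Test stub exposing the add_scalar/close subset of SummaryWriter."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, Any, Optional[int]]] = []
        self.closed = False

    def add_scalar(self, tag: str, scalar_value: Any,
                   global_step: Optional[int] = None) -> None:
        self.calls.append((tag, scalar_value, global_step))

    def close(self) -> None:
        self.closed = True


class TensorBoardStyleLogger:
    """Sketch: forward metrics to a SummaryWriter-like object."""

    def __init__(self, writer: Any) -> None:
        self.writer = writer

    def log_metric(self, name: str, value: Any, step: Optional[int] = None,
                   epoch: Optional[int] = None, include_context: bool = True) -> None:
        # Assumption: use step as the X axis when given, otherwise epoch.
        global_step = step if step is not None else epoch
        self.writer.add_scalar(name, value, global_step=global_step)

    def log_metrics(self, dic: Dict, prefix: Optional[str] = None,
                    step: Optional[int] = None, epoch: Optional[int] = None) -> None:
        for key, value in dic.items():
            tag = f"{prefix}/{key}" if prefix else key  # assumed tag format
            self.log_metric(tag, value, step=step, epoch=epoch)

    def end(self) -> None:
        self.writer.close()


writer = ScalarRecorder()  # or torch.utils.tensorboard.SummaryWriter()
logger = TensorBoardStyleLogger(writer)
logger.log_metric("train_avg_loss", 0.5, epoch=3)
logger.log_metrics({"loss": 0.6, "accuracy": 0.8}, prefix="val", step=120)
logger.end()
print(writer.calls)
```

Using a slash in the tag groups related scalars under one section in the TensorBoard UI, which is one reason to prefer it over an underscore here.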