enchanter.callbacks

A set of functions to be executed during training, validation, and testing.

Callback

class enchanter.callbacks.Callback

Bases: abc.ABC

on_epoch_end(epoch, logs: Optional[Dict] = None)

Called when the epoch ends.

on_epoch_start(epoch, logs: Optional[Dict] = None)

Called when the epoch begins.

on_test_end(logs: Optional[Dict] = None)

Called when the test ends.

on_test_start(logs: Optional[Dict] = None)

Called when the test begins.

on_test_step_end(logs: Optional[Dict] = None)

Called when the test batch ends. You can access the output of test_step.

on_test_step_start(logs: Optional[Dict] = None)

Called when the test batch begins.

on_train_end(logs: Optional[Dict] = None)

Called when the train ends.

on_train_start(logs: Optional[Dict] = None)

Called when the train begins.

on_train_step_end(logs: Optional[Dict] = None)

Called when the training batch ends. You can access the output of train_step.

on_train_step_start(logs: Optional[Dict] = None)

Called when the training batch begins.

on_validation_end(logs: Optional[Dict] = None)

Called when the validation loop ends.

on_validation_start(logs: Optional[Dict] = None)

Called when the validation loop begins.

on_validation_step_end(logs: Optional[Dict] = None)

Called when the validation batch ends. You can access the output of val_step.

on_validation_step_start(logs: Optional[Dict] = None)

Called when the validation batch begins.

set_device(device)
set_model(model)
set_optimizer(optimizer)
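A minimal sketch of a custom callback that records one value from `logs` at the end of each epoch. It assumes that `logs` is a dict mapping metric names to values, as the `Optional[Dict]` annotation suggests; the key `"train_avg_loss"` and the class `LossHistory` are hypothetical. So that the snippet runs without enchanter installed, it defines a stand-in `Callback` base with no-op hooks; in real code you would subclass `enchanter.callbacks.Callback` instead.

```python
from abc import ABC
from typing import Dict, List, Optional


class Callback(ABC):
    """Stand-in for enchanter.callbacks.Callback: hooks are no-ops by default."""

    def on_epoch_start(self, epoch, logs: Optional[Dict] = None):
        pass

    def on_epoch_end(self, epoch, logs: Optional[Dict] = None):
        pass


class LossHistory(Callback):
    """Hypothetical callback that collects one metric per epoch."""

    def __init__(self, key: str = "train_avg_loss"):
        self.key = key
        self.history: List[float] = []

    def on_epoch_end(self, epoch, logs: Optional[Dict] = None):
        # Skip epochs whose logs do not contain the tracked key.
        if logs and self.key in logs:
            self.history.append(float(logs[self.key]))


cb = LossHistory()
for epoch, loss in enumerate([0.9, 0.7, 0.6]):
    cb.on_epoch_end(epoch, {"train_avg_loss": loss})
print(cb.history)  # [0.9, 0.7, 0.6]
```

Only the hooks you need have to be overridden; the others keep their default behaviour.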

CallbackManager

class enchanter.callbacks.CallbackManager(callbacks: Optional[List[enchanter.callbacks.base.Callback]] = None)

A class for managing multiple Callback objects.

Examples

>>> from enchanter.callbacks import CallbackManager, EarlyStopping
>>> manager = CallbackManager([
...     EarlyStopping()
... ])
>>> for epoch in range(5):
...     logs = ...
...     manager.on_epoch_end(epoch, logs)
...     if manager.stop_runner:
...         break
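One way a manager like this could work, shown as a sketch rather than enchanter's implementation: it forwards each hook call to every registered callback and exposes a `stop_runner` flag. The rule that the flag turns true as soon as any callback sets its own `stop_runner` attribute is an assumption, and `StopAfter` and `SimpleCallbackManager` are hypothetical names.

```python
from typing import Dict, List, Optional


class StopAfter:
    """Hypothetical callback that requests a stop once `last_epoch` is reached."""

    def __init__(self, last_epoch: int):
        self.last_epoch = last_epoch
        self.stop_runner = False

    def on_epoch_end(self, epoch, logs: Optional[Dict] = None):
        if epoch >= self.last_epoch:
            self.stop_runner = True


class SimpleCallbackManager:
    """Sketch of a manager that fans hook calls out to its callbacks."""

    def __init__(self, callbacks: Optional[List] = None):
        self.callbacks = list(callbacks) if callbacks else []

    @property
    def stop_runner(self) -> bool:
        # Assumed rule: stop as soon as any managed callback asks for it.
        return any(getattr(cb, "stop_runner", False) for cb in self.callbacks)

    def on_epoch_end(self, epoch, logs: Optional[Dict] = None):
        # Forward the hook to every callback, in registration order.
        for cb in self.callbacks:
            cb.on_epoch_end(epoch, logs)


manager = SimpleCallbackManager([StopAfter(last_epoch=2)])
seen = []
for epoch in range(5):
    seen.append(epoch)
    manager.on_epoch_end(epoch, {})
    if manager.stop_runner:
        break
print(seen)  # [0, 1, 2]
```

The training loop itself stays unaware of individual callbacks; it only checks the single aggregated flag.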

EarlyStopping

class enchanter.callbacks.EarlyStopping(monitor: str = 'val_avg_loss', min_delta: float = 0.0, patience: int = 0, mode: str = 'auto')

Bases: enchanter.callbacks.base.Callback

Stops training when the monitored value stops improving.

Examples

>>> from enchanter.callbacks import EarlyStopping
>>> from enchanter.tasks import ClassificationRunner
>>> runner = ClassificationRunner(callbacks=[EarlyStopping()])

Initializer

Parameters
  • monitor – Name of the value to monitor. It can be any value returned by train_end(), val_end(), or test_end(), written with the matching prefix: train_XXX for a value returned by train_end(), val_XXX for a value returned by val_end(), and test_XXX for a value returned by test_end().

  • min_delta – Minimum change in the monitored value that counts as an improvement.

  • patience – Number of epochs without improvement in the monitored value after which training stops.

  • mode

    One of {'auto', 'min', 'max'}.

    • In min mode, training stops when the monitored value stops decreasing.

    • In max mode, training stops when the monitored value stops increasing.

    • In auto mode, the direction is inferred automatically from the monitored value.
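The interaction of `min_delta`, `patience`, and `mode` can be sketched as below. This is a minimal Keras-style reimplementation for illustration, not enchanter's code: the rule used for `'auto'` (names containing "acc" are treated as `'max'`, everything else as `'min'`), the exact boundary conditions, and the class name `EarlyStopSketch` are assumptions.

```python
import math


class EarlyStopSketch:
    """Keras-style early-stopping rule (illustrative, not enchanter's code)."""

    def __init__(self, monitor: str = "val_avg_loss", min_delta: float = 0.0,
                 patience: int = 0, mode: str = "auto"):
        if mode == "auto":
            # Assumed heuristic: accuracy-like metrics should increase.
            mode = "max" if "acc" in monitor else "min"
        if mode not in ("min", "max"):
            raise ValueError(f"unknown mode: {mode}")
        self.monitor = monitor
        self.mode = mode
        self.min_delta = abs(min_delta)
        self.patience = patience
        self.best = math.inf if mode == "min" else -math.inf
        self.wait = 0  # epochs since the last improvement
        self.stop_runner = False

    def _improved(self, value: float) -> bool:
        # An improvement must beat the best value by more than min_delta.
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def on_epoch_end(self, epoch, logs=None):
        value = (logs or {}).get(self.monitor)
        if value is None:
            return  # monitored key not logged this epoch
        if self._improved(value):
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stop_runner = True


stopper = EarlyStopSketch(patience=2)
losses = [1.0, 0.8, 0.81, 0.79, 0.85, 0.9]
for epoch, loss in enumerate(losses):
    stopper.on_epoch_end(epoch, {"val_avg_loss": loss})
    if stopper.stop_runner:
        break
print(epoch, stopper.best)  # 5 0.79
```

With `patience=2`, the single worse epoch at 0.81 does not stop training because 0.79 improves again; only the two consecutive non-improving epochs at the end trigger the stop.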

Methods

check_metrics(logs)

on_epoch_end(epoch[, logs])

Called when the epoch ends.

on_epoch_end(epoch, logs: Optional[Dict] = None) -> None

Called when the epoch ends.

EarlyStoppingForTSUS

class enchanter.callbacks.EarlyStoppingForTSUS(data: torch.Tensor, target: torch.Tensor, classifier: sklearn.base.BaseEstimator = SVC(), monitor: str = 'accuracy', min_delta: float = 0.0, patience: int = 0, kfold: Optional[sklearn.model_selection._split.BaseCrossValidator] = None, mode: str = 'auto', grid_search: Optional[Dict[str, Any]] = None)

Bases: enchanter.callbacks.base.Callback

Early stopping for TimeSeriesUnsupervisedRunner.

Examples

>>> import torch
>>> from enchanter.callbacks import EarlyStoppingForTSUS
>>> from enchanter.tasks import TimeSeriesUnsupervisedRunner
>>> x_train = torch.randn(32, 3, 128)
>>> y_train = torch.randint(0, high=4, size=(32, ))
>>> runner = TimeSeriesUnsupervisedRunner(
...     callbacks=[EarlyStoppingForTSUS(x_train, y_train)]
... )

Initializer

Parameters
  • data – Data for training a classifier to evaluate the quality of the encoder’s output representation.

  • target – Targets for training a classifier to evaluate the quality of the encoder’s output representation.

  • classifier – A classifier for evaluating the quality of the output representation of the encoder.

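  • monitor – Name of the classifier score to monitor.

  • min_delta – Same as in EarlyStopping.

  • patience – Same as in EarlyStopping.

  • kfold – Cross-validator used to evaluate the classifier.

  • mode – Same as in EarlyStopping.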
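The evaluation this callback is described as performing, training `classifier` on encoder outputs for `data` and `target` and scoring it, might look roughly like the sketch below. It assumes the encoder output can be flattened into a 2-D feature matrix and that a `None` value for `kfold` falls back to shuffled 5-fold cross-validation; both are assumptions. `encode` here is a hypothetical stand-in (a fixed random projection) for the runner's encoder, and `representation_score` is a hypothetical helper; only standard NumPy and scikit-learn calls are used.

```python
from typing import Optional

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import KFold, cross_val_score
from sklearn.svm import SVC


def encode(x: np.ndarray, projection: np.ndarray) -> np.ndarray:
    """Hypothetical encoder: flatten each series and apply a linear projection."""
    return x.reshape(len(x), -1) @ projection


def representation_score(features: np.ndarray, target: np.ndarray,
                         classifier: BaseEstimator = SVC(),
                         kfold=None) -> float:
    # Assumed default when no cross-validator is given: shuffled 5-fold CV.
    cv = kfold if kfold is not None else KFold(n_splits=5, shuffle=True, random_state=0)
    # cross_val_score clones the classifier for each fold, so it is never refit in place.
    scores = cross_val_score(classifier, features, target, cv=cv, scoring="accuracy")
    return float(scores.mean())


rng = np.random.default_rng(0)
x_train = rng.standard_normal((32, 3, 128))
y_train = rng.integers(0, 4, size=32)
projection = rng.standard_normal((3 * 128, 16))
score = representation_score(encode(x_train, projection), y_train)
print(0.0 <= score <= 1.0)  # True
```

Tracking this mean score across epochs, with the same `min_delta`/`patience` logic as EarlyStopping, is one plausible way to decide when the encoder's representation has stopped improving.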

Methods

cross_val()

encode()

on_epoch_end(epoch[, logs])

Called when the epoch ends.

on_epoch_end(epoch, logs: Optional[Dict] = None)

Called when the epoch ends.


Logging

Provides alternatives to comet_ml.Experiment for logging.

BaseLogger

class enchanter.callbacks.BaseLogger

Bases: abc.ABC

Provides the minimal subset of the comet_ml.Experiment interface that is required to run a Runner.

abstract end()

Use to indicate that the experiment is complete.

Warning

If you create your own Logger, you will need to implement this method.

abstract log_metric(name: str, value: Any, step: Optional[int] = None, epoch: Optional[int] = None, include_context: bool = True) -> None

Logs a metric.

Parameters
  • name – name of metric

  • value

  • step

  • epoch

  • include_context

Returns

None

Warning

If you create your own Logger, you will need to implement this method.

abstract log_metrics(dic: Dict, prefix: Optional[str] = None, step: Optional[int] = None, epoch: Optional[int] = None) -> None

Logs a key, value dictionary of metrics.

See also

log_metric

Warning

If you create your own Logger, you will need to implement this method.

abstract log_parameter(name: str, value: Any, step: Optional[int] = None) -> None

Logs a single hyper-parameter.

Parameters
  • name – name of hyper-parameter

  • value – value

  • step

Warning

If you create your own Logger, you will need to implement this method.

abstract log_parameters(dic: Dict, prefix: Optional[str] = None, step: Optional[int] = None) -> None

Logs a key, value dictionary of hyper-parameters.

See also

log_parameter

Warning

If you create your own Logger, you will need to implement this method.
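A minimal sketch of a custom logger that keeps everything in memory, for example for unit tests. To stay runnable without enchanter, it defines a stand-in `BaseLogger` that mirrors the five abstract methods documented above; in real code you would subclass `enchanter.callbacks.BaseLogger` instead. `InMemoryLogger` is a hypothetical name, and joining `prefix` and key as `"<prefix>_<key>"` is an assumed convention, not documented behaviour.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple


class BaseLogger(ABC):
    """Stand-in mirroring the abstract methods of enchanter.callbacks.BaseLogger."""

    @abstractmethod
    def end(self) -> None: ...

    @abstractmethod
    def log_metric(self, name: str, value: Any, step: Optional[int] = None,
                   epoch: Optional[int] = None, include_context: bool = True) -> None: ...

    @abstractmethod
    def log_metrics(self, dic: Dict, prefix: Optional[str] = None,
                    step: Optional[int] = None, epoch: Optional[int] = None) -> None: ...

    @abstractmethod
    def log_parameter(self, name: str, value: Any, step: Optional[int] = None) -> None: ...

    @abstractmethod
    def log_parameters(self, dic: Dict, prefix: Optional[str] = None,
                       step: Optional[int] = None) -> None: ...


def _key(name: str, prefix: Optional[str]) -> str:
    # Assumed naming convention for prefixed keys.
    return f"{prefix}_{name}" if prefix else name


class InMemoryLogger(BaseLogger):
    """Hypothetical logger that stores metrics and parameters in Python objects."""

    def __init__(self):
        self.metrics: List[Tuple[str, Any, Optional[int], Optional[int]]] = []
        self.parameters: Dict[str, Any] = {}
        self.ended = False

    def end(self) -> None:
        self.ended = True

    def log_metric(self, name, value, step=None, epoch=None, include_context=True):
        self.metrics.append((name, value, step, epoch))

    def log_metrics(self, dic, prefix=None, step=None, epoch=None):
        # Delegate to log_metric, one entry per key.
        for name, value in dic.items():
            self.log_metric(_key(name, prefix), value, step=step, epoch=epoch)

    def log_parameter(self, name, value, step=None):
        self.parameters[name] = value

    def log_parameters(self, dic, prefix=None, step=None):
        for name, value in dic.items():
            self.log_parameter(_key(name, prefix), value, step=step)


logger = InMemoryLogger()
logger.log_parameters({"lr": 1e-3, "batch_size": 32})
logger.log_metrics({"avg_loss": 0.42}, prefix="val", epoch=0)
logger.end()
print(logger.metrics)  # [('val_avg_loss', 0.42, None, 0)]
```

Implementing the dictionary methods by delegating to the single-value methods keeps the two code paths consistent.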

TensorBoardLogger

class enchanter.callbacks.TensorBoardLogger(*args, **kwargs)

Bases: enchanter.callbacks.loggers.BaseLogger

TensorBoardLogger provides minimal logging for environments where comet.ml cannot be used.

Examples

>>> from enchanter.callbacks import TensorBoardLogger
>>> from enchanter.tasks import ClassificationRunner
>>> model, optimizer, criterion = ...
>>> runner = ClassificationRunner(
...     model, optimizer, criterion, experiment=TensorBoardLogger()
... )

For the initializer arguments, see torch.utils.tensorboard.writer.SummaryWriter.

end()

Use to indicate that the experiment is complete.

log_metric(name: str, value: Any, step: Optional[int] = None, epoch: Optional[int] = None, include_context: bool = True) -> None

Logs a metric.

Parameters
  • name – name of metric

  • value – value

  • step – Used as the X axis when plotting on TensorBoard.

  • epoch – Used as the X axis when plotting on TensorBoard.

  • include_context

Returns

None

log_metrics(dic: Dict, prefix: Optional[str] = None, step: Optional[int] = None, epoch: Optional[int] = None) -> None

Logs a key, value dictionary of metrics.

See also

log_metric

log_parameter(name: Optional[str], value: Any, step: Optional[int] = None) -> None

Logs a single hyper-parameter.

Parameters
  • name – name of hyper-parameter

  • value – value

  • step – Used as the X axis when plotting on TensorBoard.

Returns

None

log_parameters(dic: Dict, prefix: Optional[str] = None, step: Optional[int] = None) -> Any

Logs a key, value dictionary of hyper-parameters.

See also

log_parameter