Source code for enchanter.tasks.classification

# ***************************************************
#  _____            _                 _
# | ____|_ __   ___| |__   __ _ _ __ | |_ ___ _ __
# |  _| | '_ \ / __| '_ \ / _` | '_ \| __/ _ \ '__|
# | |___| | | | (__| | | | (_| | | | | ||  __/ |
# |_____|_| |_|\___|_| |_|\__,_|_| |_|\__\___|_|
#
# ***************************************************

from typing import Tuple, List, Union, Optional, Dict

import numpy as np
import torch
from torch.nn.modules import Module
from torch.nn.modules.loss import _Loss
from torch.optim.optimizer import Optimizer
from torch.cuda.amp import GradScaler, autocast

from enchanter.engine import BaseRunner
from enchanter.callbacks import Callback
from enchanter.metrics import calculate_accuracy


__all__ = ["ClassificationRunner"]


class ClassificationRunner(BaseRunner):
    """
    Runner for classification tasks.

    Examples:
        >>> from comet_ml import Experiment
        >>> import torch
        >>> model = torch.nn.Sequential(...)
        >>> optimizer = torch.optim.Adam(model.parameters())
        >>> criterion = torch.nn.CrossEntropyLoss()
        >>> runner = ClassificationRunner(
        ...     model,
        ...     optimizer,
        ...     criterion,
        ...     Experiment()
        ... )
        >>> runner.fit(...)
        >>> # or
        >>> runner.add_loader(...)
        >>> runner.train_config(epochs=10)
        >>> runner.run()

    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion: _Loss,
        experiment,
        scheduler: Optional[List] = None,
        callbacks: Optional[List[Callback]] = None,
    ) -> None:
        """
        Store the training components on the runner.

        Args:
            model: Network called on the inputs in ``general_step`` and ``predict``.
            optimizer: Optimizer stored as ``self.optimizer``.
            criterion: Loss module called as ``criterion(out, y)``.
            experiment: Experiment object stored as ``self.experiment``
                (the class example passes a ``comet_ml.Experiment``).
            scheduler: Optional list of schedulers; an empty list is stored
                when ``None`` is given.
            callbacks: Optional list of callbacks, stored unchanged
                (``None`` is kept as ``None``).

        """
        super().__init__()
        self.model: Module = model
        self.optimizer: Optimizer = optimizer
        self.experiment = experiment
        self.criterion: _Loss = criterion
        # Build a fresh list per instance so runners never share a default.
        self.scheduler: List = [] if scheduler is None else scheduler
        self.callbacks = callbacks

    def general_step(self, batch: Tuple) -> Dict:
        """
        Compute loss and accuracy for one batch.

        This method is executed by train_step, val_step, test_step.

        Args:
            batch: A pair ``(x, y)`` of model inputs and targets.

        Returns:
            A dict with the keys ``"loss"`` (output of the criterion) and
            ``"accuracy"`` (output of ``calculate_accuracy``).

        """
        x, y = batch
        # Mixed precision is switched on only when a GradScaler is attached.
        # NOTE(review): ``self.scaler`` is not set in this class; presumably
        # it is provided by BaseRunner -- confirm.
        with autocast(enabled=isinstance(self.scaler, GradScaler)):
            out = self.model(x)
            loss = self.criterion(out, y)
            accuracy = calculate_accuracy(out, y)

        return {"loss": loss, "accuracy": accuracy}

    @staticmethod
    def general_end(outputs: List) -> Dict:
        """
        Average the per-step results.

        This method is executed by train_end, val_end, test_end.

        Args:
            outputs: List of dicts as returned by ``general_step``; each must
                hold a tensor under ``"loss"`` and a value accepted by
                ``torch.tensor`` under ``"accuracy"``.

        Returns:
            A dict with the keys ``"avg_loss"`` and ``"avg_acc"``, each the
            mean over all entries of ``outputs``.

        """
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        avg_acc = torch.stack([torch.tensor(x["accuracy"]) for x in outputs]).mean()
        return {"avg_loss": avg_loss, "avg_acc": avg_acc}

    def train_step(self, batch: Tuple) -> Dict:
        """Run ``general_step`` on a training batch."""
        return self.general_step(batch)

    def train_end(self, outputs: List) -> Dict:
        """Run ``general_end`` on the collected training step outputs."""
        return self.general_end(outputs)

    def val_step(self, batch: Tuple) -> Dict:
        """Run ``general_step`` on a validation batch."""
        return self.general_step(batch)

    def val_end(self, outputs: List) -> Dict:
        """Run ``general_end`` on the collected validation step outputs."""
        return self.general_end(outputs)

    def test_step(self, batch: Tuple) -> Dict:
        """Run ``general_step`` on a test batch."""
        return self.general_step(batch)

    def test_end(self, outputs: List) -> Dict:
        """Run ``general_end`` on the collected test step outputs."""
        return self.general_end(outputs)

    def predict(self, x: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        """
        Predict class indices for ``x``.

        Args:
            x: Input data; converted to a tensor on ``self.device``.

        Returns:
            NumPy array holding, for each sample, the index of the largest
            model output along dimension 1.

        """
        # The model is left in eval mode after this call.
        self.model.eval()
        with torch.no_grad():
            x = torch.as_tensor(x, device=self.device)
            out = self.model(x)
            _, predicted = torch.max(out, 1)

        return predicted.cpu().numpy()