# Source code for enchanter.metrics.classification
# ***************************************************
# _____ _ _
# | ____|_ __ ___| |__ __ _ _ __ | |_ ___ _ __
# | _| | '_ \ / __| '_ \ / _` | '_ \| __/ _ \ '__|
# | |___| | | | (__| | | | (_| | | | | || __/ |
# |_____|_| |_|\___|_| |_|\__,_|_| |_|\__\___|_|
#
# ***************************************************
import torch
# Names exported by ``from <this module> import *``.
__all__ = ["calculate_accuracy"]
def calculate_accuracy(inputs: torch.Tensor, targets: torch.Tensor) -> float:
    """
    Calculate the classification accuracy of a single batch.

    The predicted class of each sample is the index of the largest value
    along dimension 1 of ``inputs``. Accuracy is the fraction of samples
    whose predicted class equals the corresponding entry of ``targets``.

    Args:
        inputs (torch.Tensor): per-class scores, shape == [N, n_class]
        targets (torch.Tensor): target class indices, shape == [N]

    Returns:
        accuracy (float): ``correct / N`` as a Python float. An empty batch
        (N == 0) yields ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    total = targets.shape[0]
    if total == 0:
        # Nothing to score; avoid dividing by zero below.
        return 0.0

    with torch.no_grad():
        predicted = inputs.argmax(dim=1)
        # Reduce on the tensors' own device and transfer only the scalar;
        # keeping the count as an integer avoids float32 rounding for
        # very large batches.
        correct = (predicted == targets).sum().item()

    return correct / total