Source code for enchanter.engine.saving
from typing import Union, Optional, Dict
from collections import OrderedDict
from time import ctime
from pathlib import Path
from copy import deepcopy
import torch
import torch.nn as nn
__all__ = ["RunnerIO"]
[docs]class RunnerIO:
"""
A class responsible for loading and saving parameters such as PyTorch model weights and Optimizer state.
"""
def __init__(self):
self.model = NotImplemented
self.optimizer = NotImplemented
self.experiment = NotImplemented
self.save_dir: Optional[str] = None
[docs] def model_name(self) -> str:
"""
fetch model name
Returns: model name
"""
if isinstance(self.model, nn.DataParallel) or isinstance(self.model, nn.parallel.DistributedDataParallel):
model_name = self.model.module.__class__.__name__
else:
model_name = self.model.__class__.__name__
return model_name
[docs] def save_checkpoint(self) -> Dict[str, Union[Dict[str, torch.Tensor], dict]]:
"""
A method to output model weights and Optimizer state as a dictionary.
Returns:
Returns a dictionary with the following keys and values.
- ``model_state_dict``: model weights
- ``optimizer_state_dict``: Optimizer state
"""
if isinstance(self.model, nn.DataParallel) or isinstance(self.model, nn.parallel.DistributedDataParallel):
model = self.model.module.state_dict()
else:
model = self.model.state_dict()
checkpoint = {
"model_state_dict": deepcopy(model),
"optimizer_state_dict": deepcopy(self.optimizer.state_dict()),
}
return checkpoint
[docs] def load_checkpoint(self, checkpoint: Dict[str, OrderedDict]):
"""
Takes a dictionary with keys ``model_state_dict`` and ``optimizer_state_dict``
and uses them to restore the state of the model and the Optimizer.
Args:
checkpoint:
Takes a dictionary with the following keys and values.
- ``model_state_dict``: model weights
- ``optimizer_state_dict``: Optimizer state
"""
self.model.load_state_dict(checkpoint["model_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
return self
[docs] def save(self, directory: Optional[str] = None, epoch: Optional[int] = None, filename: Optional[str] = None):
"""
Save the model and the Optimizer state file in the specified directory.
Notes:
``enchanter_checkpoints_epoch_{}.pth`` file contains ``model_state_dict`` & ``optimizer_state_dict``.
Args:
directory (Optional[str]):
epoch (Optional[int]):
filename (Optional[str]):
"""
if directory is None and self.save_dir:
directory = self.save_dir
if directory is None:
if filename is None:
raise ValueError("The argument `directory` or `filename` must be specified.")
else:
path = filename
else:
directory_path = Path(directory)
if not directory_path.exists():
directory_path.mkdir(parents=True)
if epoch is None:
epoch_str = ctime().replace(" ", "_")
else:
epoch_str = str(epoch)
if not filename:
filename = "enchanter_checkpoints_epoch_{}.pth".format(epoch_str)
path = str(directory_path / filename)
checkpoint = self.save_checkpoint()
torch.save(checkpoint, path)
model_name = self.model_name()
self.experiment.log_model(model_name, str(path))
[docs] def load(self, filename: str, map_location: str = "cpu"):
"""
Restores the model and Optimizer state based on the specified file.
Args:
filename (str):
map_location (str): default: 'cpu'
"""
checkpoint = torch.load(filename, map_location=map_location)
self.load_checkpoint(checkpoint)
return self