Source code for enchanter.engine.modules

# ***************************************************
#  _____            _                 _
# | ____|_ __   ___| |__   __ _ _ __ | |_ ___ _ __
# |  _| | '_ \ / __| '_ \ / _` | '_ \| __/ _ \ '__|
# | |___| | | | (__| | | | (_| | | | | ||  __/ |
# |_____|_| |_|\___|_| |_|\__,_|_| |_|\__\___|_|
#
# ***************************************************

import warnings
from copy import deepcopy
from typing import Union, Tuple, Any, Dict
from os import environ as os_environ
from random import seed as std_seed

from numpy import ndarray
from numpy.random import seed as np_seed
import torch
import torch.nn as nn
from torch.backends import cudnn
from torch.utils.data import Dataset, TensorDataset
from torch.cuda import is_available as cuda_is_available

try:
    import tensorflow as tf

    IS_TF_DS_AVAILABLE = True

except ImportError:
    IS_TF_DS_AVAILABLE = False


__all__ = [
    "is_jupyter",
    "get_dataset",
    "fix_seed",
    "send",
    "is_tfds",
    "tfds_to_numpy",
    "fetch_state_dict",
    "restore_state_dict",
]


[docs]def is_jupyter() -> bool: """ Determines if the running environment is a Jupyter Notebook. Returns: True if run on jupyrer notebook else False """ try: from IPython import get_ipython except ImportError: return False env = get_ipython().__class__.__name__ # noqa if env == "TerminalInteractiveShell": return False return True
[docs]def get_dataset(x: Union[ndarray, torch.Tensor], y: Union[ndarray, torch.Tensor] = None) -> Dataset: """ Generates ``torch.utils.data.TensorDataset`` based on the values entered. Examples: >>> import torch >>> x = torch.randn(512, 6) >>> y = torch.randint(0, 9, size=[512]) >>> ds = get_dataset(x, y) Args: x (Union[np.ndarray, torch.Tensor]): y (Optional[Union[np.ndarray, torch.Tensor]]): Returns: `torch.utils.data.TensorDataset` """ x = torch.as_tensor(x) if y is not None: y = torch.as_tensor(y) ds = TensorDataset(x, y) else: ds = TensorDataset(x) return ds
[docs]def send(batch: Tuple[Any, ...], device: torch.device, non_blocking: bool = True) -> Tuple[Any, ...]: """ Send `variable` to `device` Args: batch: Tuple which contain variable device: torch.device non_blocking: bool Returns: new tuple """ def transfer(x): if isinstance(x, torch.Tensor): return x.to(device, non_blocking=non_blocking) elif isinstance(x, ndarray): return torch.tensor(x, device=device) else: return x return tuple(map(transfer, batch))
[docs]def fix_seed(seed: int, deterministic: bool = False, benchmark: bool = False) -> None: """ Fixed the ``Seed`` value of PyTorch, NumPy, Pure Python Random at once. Examples: >>> import torch >>> import numpy as np >>> fix_seed(0) >>> x = torch.randn(...) >>> y = np.random.randn(...) Args: seed (int): random state (sedd) deterministic (bool): Whether to ensure reproducibility as much as possible on CuDNN. benchmark (bool): Returns: None """ std_seed(seed) os_environ["PYTHONHASHSEED"] = str(seed) np_seed(seed) torch.manual_seed(seed) if cuda_is_available(): if deterministic: cudnn.deterministic = True if benchmark: cudnn.benchmark = False
[docs]def is_tfds(loader: Any) -> bool: """ Determines if the input is a TensorFlow Dataset. Returns False if TensorFlow is not installed. Args: loader: Returns: True if input is TensorFlow Dataset """ if IS_TF_DS_AVAILABLE: if isinstance(loader, tf.data.Dataset): warnings.warn( "TensorFlow Dataset detection. Experimental support at this stage.", UserWarning, ) return True else: return False else: return False
[docs]def tfds_to_numpy(loader): """ If TensorFlow Dataset is entered, it will be converted to numpy. Args: loader: Returns: TensorFlow Dataset Raises: If `tfds <https://www.tensorflow.org/datasets>`_ is not installed, this function will raise a ``RuntimeError``. """ if IS_TF_DS_AVAILABLE: return loader.as_numpy_iterator() else: raise RuntimeError
def fetch_state_dict(model: nn.Module) -> Dict[str, torch.Tensor]: if isinstance(model, nn.parallel.DataParallel) or isinstance(model, nn.parallel.DistributedDataParallel): weights = model.module.state_dict() else: weights = model.state_dict() return deepcopy(weights) def restore_state_dict(model: nn.Module, weights: Dict[str, torch.Tensor], strict: bool = True) -> nn.Module: if isinstance(model, nn.parallel.DataParallel) or isinstance(model, nn.parallel.DistributedDataParallel): model.module.load_state_dict(weights, strict=strict) else: model.load_state_dict(weights, strict=strict) return model