# Source code for enchanter.utils.backend

from typing import Union
from numbers import Number

import torch
import numpy as np

# Public names exported by ``from <this module> import *``.
__all__ = ["slice_axis", "is_scalar"]


@torch.jit.script
def slice_axis(data: torch.Tensor, axis: int, begin: int, end: int) -> torch.Tensor:
    """
    Slice ``data`` along a single axis, equivalent to ``data[..., begin:end, ...]``.

    Negative ``begin`` / ``end`` count from the end of ``axis``. After that,
    both indices are clamped to the bounds of the axis exactly as Python
    slicing does, so an out-of-range index or ``end <= begin`` yields a
    (possibly empty) slice instead of an error or a silently wrapped result.
    The slice is taken with :func:`torch.narrow`.

    Examples:
        >>> import torch
        >>> x = torch.tensor([[1., 2., 3., 4.],
        ...                   [5., 6., 7., 8.],
        ...                   [9., 10., 11., 12.]])
        >>> slice_axis(x, axis=0, begin=1, end=3)
        tensor([[ 5.,  6.,  7.,  8.],
                [ 9., 10., 11., 12.]])
        >>> slice_axis(x, axis=1, begin=0, end=2)
        tensor([[ 1.,  2.],
                [ 5.,  6.],
                [ 9., 10.]])
        >>> slice_axis(x, axis=1, begin=-3, end=-1)
        tensor([[ 2.,  3.],
                [ 6.,  7.],
                [10., 11.]])

    References:
        - `Deep Graph Library \
        <https://github.com/dmlc/dgl/blob/f25bc176d0365234ebb051d5069edff24ad2de4d/python/dgl/backend/pytorch/tensor.py#L159-L160>`_

        - `mxnet.ndarray.slice_axis \
        <https://beta.mxnet.io/api/ndarray/_autogen/mxnet.ndarray.slice_axis.html#mxnet-ndarray-slice-axis>`_

    Args:
        data: Source input
        axis: Axis along which to be sliced (negative values count from the last axis)
        begin: The beginning index (inclusive) along the axis to be sliced
        end: The ending index (exclusive) along the axis to be sliced

    Returns:
        The elements of ``data`` whose index along ``axis`` lies in
        ``[begin, end)``; every other dimension is unchanged.

    Raises:
        An error if ``axis`` is not a valid dimension of ``data``
        (raised while reading ``data.shape[axis]``).

    """
    size = data.shape[axis]

    # Resolve negative indices relative to the end of the axis.
    if begin < 0:
        begin += size
    if end < 0:
        end += size

    # Clamp like Python slicing. Without this, an index still negative after
    # one normalization would be wrapped a second time by torch.narrow.
    begin = min(max(begin, 0), size)
    end = min(max(end, begin), size)
    return torch.narrow(data, axis, begin, end - begin)


def is_scalar(data: Union[Number, Union[np.ndarray, torch.Tensor]]) -> bool:
    """
    Returns True if ``data`` holds exactly one numerical value.

    Any instance of :class:`numbers.Number` (e.g. ``int``, ``float``, ``bool``)
    is a scalar. A ``np.ndarray`` or ``torch.Tensor`` is a scalar when it
    contains exactly one element, regardless of its rank, so both
    ``torch.tensor(1.0)`` and ``torch.tensor([[1.0]])`` qualify.

    Args:
        data (Union[Number, Union[np.ndarray, torch.Tensor]]): Numerical value

    Returns:
        True if ``data`` is a scalar type, False if it is not.

    Examples:
        >>> a = torch.tensor([1.0])
        >>> is_scalar(a)
        True
        >>> a = torch.tensor(1.0)
        >>> is_scalar(a)
        True
        >>> a = torch.tensor([1, 2, 3])
        >>> is_scalar(a)
        False
        >>> a = 1.0
        >>> is_scalar(a)
        True

    """
    if isinstance(data, Number):
        return True

    # Count elements instead of probing with ``.item()``: a multi-element
    # torch.Tensor raises RuntimeError (not ValueError) from ``.item()``, and
    # on CUDA tensors ``.item()`` also requires a device-to-host copy.
    if isinstance(data, torch.Tensor):
        return data.numel() == 1
    if isinstance(data, np.ndarray):
        return data.size == 1

    # Other array-likes exposing ``shape`` / ``item()``: keep the duck-typed check.
    if len(data.shape) == 0:
        return True
    try:
        data.item()
    except (ValueError, RuntimeError):
        return False
    return True