Source code for enchanter.addons.layers.dense_interpolation

# ***************************************************
#  _____            _                 _
# | ____|_ __   ___| |__   __ _ _ __ | |_ ___ _ __
# |  _| | '_ \ / __| '_ \ / _` | '_ \| __/ _ \ '__|
# | |___| | | | (__| | | | (_| | | | | ||  __/ |
# |_____|_| |_|\___|_| |_|\__,_|_| |_|\__\___|_|
#
# ***************************************************

import numpy as np
from torch import tensor, bmm
from torch import Tensor
from torch.nn import Module


__all__ = ["DenseInterpolation"]


[docs]class DenseInterpolation(Module): """ Args: seq_len: length of input sequence. factor: """ def __init__(self, seq_len: int, factor: int) -> None: super(DenseInterpolation, self).__init__() W = np.zeros((factor, seq_len), dtype=np.float32) for t in range(seq_len): s = np.array((factor * (t + 1)) / seq_len, dtype=np.float32) for m in range(factor): tmp = np.array(1 - (np.abs(s - (1 + m)) / factor), dtype=np.float32) w = np.power(tmp, 2, dtype=np.float32) W[m, t] = w W = tensor(W).float().unsqueeze(0) self.register_buffer("W", W)
[docs] def forward(self, x: Tensor) -> Tensor: """ Apply ``Dense Interpolation`` to the input. Args: x (torch.Tensor): The shape of the input array is assumed to be ``[N, seq_len, features]``. Returns: (torch.Tensor) """ w = self.W.repeat(x.shape[0], 1, 1).requires_grad_(False) # type: ignore u = bmm(w, x) return u.transpose_(1, 2)