"""Convolution layers ``CausalConv1d`` and ``TemporalConvBlock`` (``enchanter.addons.layers.conv``)."""

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _single

from ...utils.backend import slice_axis


# Public API of this module.
__all__ = [
    "CausalConv1d",
    "TemporalConvBlock",
]


class CausalConv1d(nn.Conv1d):
    """1-D convolution whose output at step ``t`` depends only on inputs at steps ``<= t``.

    Paper: `WaveNet: A Generative Model for Raw Audio <https://arxiv.org/abs/1609.03499>`_

    Causality comes from zero-padding only the *left* (past) end of the time axis by
    ``(kernel_size - 1) * dilation`` steps and then running an unpadded convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ) -> None:
        """
        Causal Conv1d

        Args:
            in_channels: the number of input channels
            out_channels: the number of output channels
            kernel_size: kernel size
            stride: stride
            dilation: rate of dilation
            groups: the number of groups
            bias: if true use bias (default: True)
        """
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            # Kept in ``self.padding`` so the attribute/repr stay as before;
            # ``forward`` applies it to the left side only.
            padding=_single((kernel_size - 1) * dilation),
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation

        Args:
            x: [N, in_channels, L]

        Returns:
            [N, out_channels, (L - 1) // stride + 1], i.e. [N, out_channels, L]
            when ``stride == 1``.
        """
        left_pad = self.padding[0]
        if left_pad > 0:
            # Pad the past only, so no output window reaches a future step and
            # nothing has to be trimmed afterwards. Trimming with ``[..., :-pad]``
            # would return an empty tensor when pad == 0 (kernel_size == 1).
            x = F.pad(x, (left_pad, 0))
        return F.conv1d(x, self.weight, self.bias, self.stride, 0, self.dilation, self.groups)
class TemporalConvBlock(nn.Module):
    r"""Residual block of two weight-normalised causal convolutions.

    Paper: `An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling \
    <https://arxiv.org/abs/1803.01271>`_
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        dropout: float = 0.5,
        activation: Optional[nn.Module] = None,
        final_activation: bool = False,
    ) -> None:
        r"""
        Temporal Convolutional Block

        Paper: `An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling \
        <https://arxiv.org/abs/1803.01271>`_

        Args:
            in_features: the number of input channels
            out_features: the number of output channels
            kernel_size: kernel size
            stride: stride. NOTE: the residual branch keeps the input length, so the
                residual addition only lines up when ``stride == 1``.
            dilation: dilation rate
            dropout: dropout rate
            activation: activation function (default: a new ``nn.ReLU()`` per block).
                The same module instance is applied after both convolutions and, if
                ``final_activation`` is true, after the residual sum, so any parameters
                it has are shared between those positions.
            final_activation: If true, apply the activation function after the residual connection
        """
        super().__init__()
        if activation is None:
            # Build the default per instance: a module object as a signature default
            # would be a single instance shared by every block ever constructed.
            activation = nn.ReLU()
        self.final_activation = final_activation
        self.conv = nn.Sequential(
            nn.utils.weight_norm(
                CausalConv1d(in_features, out_features, kernel_size, stride=stride, dilation=dilation)
            ),
            activation,
            nn.Dropout(dropout),
            nn.utils.weight_norm(
                CausalConv1d(out_features, out_features, kernel_size, stride=stride, dilation=dilation)
            ),
            activation,
            nn.Dropout(dropout),
        )
        # 1x1 convolution that maps the input to ``out_features`` channels so it can be
        # added to the convolution branch; unnecessary when channel counts already match.
        self.downsample = nn.Conv1d(in_features, out_features, 1) if in_features != out_features else None
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation

        Args:
            x: [N, in_features, L]

        Returns:
            conv branch + residual branch (optionally passed through the activation);
            [N, out_features, L] when ``stride == 1``.
        """
        out: torch.Tensor = self.conv(x)
        identity: torch.Tensor = x if self.downsample is None else self.downsample(x)
        out = out + identity
        if self.final_activation:
            out = self.activation(out)
        return out