Source code for enchanter.addons.activation

# ***************************************************
#  _____            _                 _
# | ____|_ __   ___| |__   __ _ _ __ | |_ ___ _ __
# |  _| | '_ \ / __| '_ \ / _` | '_ \| __/ _ \ '__|
# | |___| | | | (__| | | | (_| | | | | ||  __/ |
# |_____|_| |_|\___|_| |_|\__,_|_| |_|\__\___|_|
#
# ***************************************************
from typing import Union

import torch
import torch.nn as nn


# Public names exported by a star-import of this module.
__all__ = ["Swish", "mish", "Mish", "FReLU1d", "FReLU2d"]


class Swish(nn.Module):
    """
    Swish activation: ``x * sigmoid(weight * x)``.

    Args:
        beta (bool): When ``True`` the scale applied inside the sigmoid is a
            learnable one-element ``nn.Parameter`` initialised to ``1.0``.
            When ``False`` (default) it is the plain float ``1.0``.

    """

    def __init__(self, beta: bool = False):
        super().__init__()
        # Either a trainable scalar or a fixed constant, chosen once here.
        scale: Union[nn.Parameter, float] = (
            nn.Parameter(torch.tensor([1.0]), requires_grad=True) if beta else 1.0
        )
        self.weight = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the Swish function element-wise.

        Examples:
            >>> import torch
            >>> act = Swish()
            >>> x = torch.randn(2)
            >>> y = act(x)

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: ``x * sigmoid(self.weight * x)``.

        """
        gate = torch.sigmoid(self.weight * x)
        return x * gate
@torch.jit.script
def mish(x: torch.Tensor) -> torch.Tensor:
    """
    Functional Mish activation: ``x * tanh(softplus(x))``.

    Examples:
        >>> import torch
        >>> inputs = torch.randn(2)
        >>> outputs = mish(inputs)

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        torch.Tensor: Mish applied element-wise to ``x``.

    """
    smoothed = nn.functional.softplus(x)
    return x * torch.tanh(smoothed)
class Mish(nn.Module):
    """
    Mish activation as a module: ``inputs * tanh(softplus(inputs))``.

    The ``tanh`` and ``softplus`` steps are held as sub-modules
    (``nn.Tanh`` and ``nn.Softplus`` with their default arguments).
    """

    def __init__(self):
        super().__init__()
        self.tanh: nn.Module = nn.Tanh()
        self.softplus: nn.Module = nn.Softplus()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Apply Mish to the input.

        Examples:
            >>> import torch
            >>> act = Mish()
            >>> x = torch.randn(2)
            >>> y = act(x)

        Args:
            inputs (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: ``inputs * tanh(softplus(inputs))``.

        """
        return inputs * self.tanh(self.softplus(inputs))
class FReLU1d(nn.Module):
    """
    Funnel Activation (FReLU) for 1d inputs such as sensor signals.

    A depthwise ``nn.Conv1d`` (``groups=in_features``) followed by
    ``nn.BatchNorm1d`` produces a per-position threshold; the output is the
    element-wise maximum of the input and that threshold.

    Examples:
        >>> inputs = torch.randn(1, 3, 128)   # [N, features, seq_len]
        >>> frelu = FReLU1d(3)
        >>> outputs = frelu(inputs)

    Args:
        in_features (int): Channel count of the input; also used as the
            output channel count and the number of convolution groups.
        kernel_size (int): Kernel size passed to ``nn.Conv1d``.
        stride (int): Stride passed to ``nn.Conv1d``.
        padding (int): Padding passed to ``nn.Conv1d``.

    """

    def __init__(self, in_features: int, kernel_size: int = 3, stride: int = 1, padding: int = 1):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, groups=in_features)
        self.conv = nn.Conv1d(in_features, in_features, **conv_kwargs)
        self.norm = nn.BatchNorm1d(in_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply FReLU to a 1d input.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: ``torch.max(x, norm(conv(x)))``. The convolution
            output must be broadcastable against ``x`` for the maximum to be
            computed.

        """
        funnel = self.norm(self.conv(x))
        return torch.max(x, funnel)
class FReLU2d(nn.Module):
    """
    Funnel Activation (FReLU) for 2d inputs such as images.

    A depthwise ``nn.Conv2d`` (``groups=in_features``) followed by
    ``nn.BatchNorm2d`` produces a per-position threshold; the output is the
    element-wise maximum of the input and that threshold.

    Examples:
        >>> inputs = torch.randn(1, 3, 128, 128)   # [N, channels, heights, widths]
        >>> frelu = FReLU2d(3)
        >>> outputs = frelu(inputs)

    Args:
        in_features (int): Channel count of the input; also used as the
            output channel count and the number of convolution groups.
        kernel_size (int): Kernel size passed to ``nn.Conv2d``.
        stride (int): Stride passed to ``nn.Conv2d``.
        padding (int): Padding passed to ``nn.Conv2d``.

    """

    def __init__(self, in_features: int, kernel_size: int = 3, stride: int = 1, padding: int = 1):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, groups=in_features)
        self.conv = nn.Conv2d(in_features, in_features, **conv_kwargs)
        self.norm = nn.BatchNorm2d(in_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply FReLU to a 2d input.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: ``torch.max(x, norm(conv(x)))``. The convolution
            output must be broadcastable against ``x`` for the maximum to be
            computed.

        """
        funnel = self.norm(self.conv(x))
        return torch.max(x, funnel)