Source code for physicsnemo.nn.module.activations

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn

import physicsnemo  # noqa: F401 for docs

Tensor = torch.Tensor


class Identity(nn.Module):
    """Identity activation function

    Dummy function for removing activations from a model

    Example
    -------
    >>> idnt_func = physicsnemo.nn.Identity()
    >>> input = torch.randn(2, 2)
    >>> output = idnt_func(input)
    >>> torch.allclose(input, output)
    True
    """

    def forward(self, x: Tensor) -> Tensor:
        return x

class Stan(nn.Module):
    """Self-scalable Tanh (Stan) for 1D Tensors

    Parameters
    ----------
    out_features : int, optional
        Number of features, by default 1

    Note
    ----
    References: Gnanasambandam, Raghav and Shen, Bo and Chung, Jihoon and
    Yue, Xubo and others. Self-scalable Tanh (Stan): Faster Convergence and
    Better Generalization in Physics-informed Neural Networks. arXiv preprint
    arXiv:2204.12589, 2022.

    Example
    -------
    >>> stan_func = physicsnemo.nn.Stan(out_features=1)
    >>> input = torch.Tensor([[0], [1], [2]])
    >>> stan_func(input)
    tensor([[0.0000],
            [1.5232],
            [2.8921]], grad_fn=<MulBackward0>)
    """

    def __init__(self, out_features: int = 1):
        super().__init__()
        self.beta = nn.Parameter(torch.ones(out_features))

    def forward(self, x: Tensor) -> Tensor:
        if x.shape[-1] != self.beta.shape[-1]:
            raise ValueError(
                f"The last dimension of the input must be equal to the "
                f"dimension of Stan parameters. Got inputs: {x.shape}, "
                f"params: {self.beta.shape}"
            )
        return torch.tanh(x) * (1.0 + self.beta * x)

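# Stan computes tanh(x) * (1 + beta * x): with beta = 0 it reduces to a plain
# tanh, and the learnable per-feature beta lets each feature rescale its output
# range during training. A hypothetical pairing with a linear layer (the width
# of 64 below is an arbitrary, assumed size) would look like:
#
#     width = 64
#     block = nn.Sequential(nn.Linear(width, width), Stan(out_features=width))
#
# out_features must equal the layer width, since beta is matched against the
# last input dimension in forward().
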
class SquarePlus(nn.Module):
    """Squareplus activation

    Note
    ----
    Reference: arXiv preprint arXiv:2112.11687

    Example
    -------
    >>> sqr_func = physicsnemo.nn.SquarePlus()
    >>> input = torch.Tensor([[1, 2], [3, 4]])
    >>> sqr_func(input)
    tensor([[1.6180, 2.4142],
            [3.3028, 4.2361]])
    """

    def __init__(self):
        super().__init__()
        self.b = 4

    def forward(self, x: Tensor) -> Tensor:
        return 0.5 * (x + torch.sqrt(x * x + self.b))

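# SquarePlus computes 0.5 * (x + sqrt(x**2 + b)), a smooth approximation of
# ReLU: since sqrt(x**2 + b) -> |x| as b -> 0, the expression tends to
# 0.5 * (x + |x|) = max(x, 0). A larger b softens the corner around zero; here
# b is fixed at 4 rather than exposed as a constructor argument.
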
class CappedLeakyReLU(torch.nn.Module):
    """Implements a LeakyReLU with capped maximum value.

    Example
    -------
    >>> capped_leakyReLU_func = physicsnemo.nn.CappedLeakyReLU()
    >>> input = torch.Tensor([[-2, -1], [0, 1], [2, 3]])
    >>> capped_leakyReLU_func(input)
    tensor([[-0.0200, -0.0100],
            [ 0.0000,  1.0000],
            [ 1.0000,  1.0000]])
    """

    def __init__(self, cap_value=1.0, **kwargs):
        """
        Parameters
        ----------
        cap_value : float, optional
            Maximum value that outputs will be capped at, by default 1.0
        **kwargs :
            Keyword arguments passed to ``torch.nn.LeakyReLU``
        """
        super().__init__()
        self.add_module("leaky_relu", torch.nn.LeakyReLU(**kwargs))
        self.register_buffer("cap", torch.tensor(cap_value, dtype=torch.float32))

    def forward(self, inputs):
        x = self.leaky_relu(inputs)
        x = torch.clamp(x, max=self.cap)
        return x

class CappedGELU(torch.nn.Module):
    """Implements a GELU with capped maximum value.

    Example
    -------
    >>> capped_gelu_func = physicsnemo.nn.CappedGELU()
    >>> input = torch.Tensor([[-2, -1], [0, 1], [2, 3]])
    >>> capped_gelu_func(input)
    tensor([[-0.0455, -0.1587],
            [ 0.0000,  0.8413],
            [ 1.0000,  1.0000]])
    """

    def __init__(self, cap_value=1.0, **kwargs):
        """
        Parameters
        ----------
        cap_value : float, optional
            Maximum value that outputs will be capped at, by default 1.0
        **kwargs :
            Keyword arguments passed to ``torch.nn.GELU``
        """
        super().__init__()
        self.add_module("gelu", torch.nn.GELU(**kwargs))
        self.register_buffer("cap", torch.tensor(cap_value, dtype=torch.float32))

    def forward(self, inputs):
        x = self.gelu(inputs)
        x = torch.clamp(x, max=self.cap)
        return x

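# Both capped activations follow the same pattern: apply the base activation,
# then clamp the result from above at cap_value. The cap is stored via
# register_buffer, so it moves with the module under .to()/.cuda() and is saved
# in the state dict, but it is not a learnable parameter and receives no
# gradient updates.
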
# Dictionary of activation functions. Entries are either an nn.Module class or
# a (class, default_kwargs) tuple; get_activation instantiates them on lookup.
ACT2FN = {
    "relu": nn.ReLU,
    "leaky_relu": (nn.LeakyReLU, {"negative_slope": 0.1}),
    "prelu": nn.PReLU,
    "relu6": nn.ReLU6,
    "elu": nn.ELU,
    "celu": (nn.CELU, {"alpha": 1.0}),
    "selu": nn.SELU,
    "silu": nn.SiLU,
    "gelu": nn.GELU,
    "sigmoid": nn.Sigmoid,
    "logsigmoid": nn.LogSigmoid,
    "softplus": nn.Softplus,
    "softshrink": nn.Softshrink,
    "softsign": nn.Softsign,
    "tanh": nn.Tanh,
    "tanhshrink": nn.Tanhshrink,
    "threshold": (nn.Threshold, {"threshold": 1.0, "value": 1.0}),
    "hardtanh": nn.Hardtanh,
    "identity": Identity,
    "stan": Stan,
    "squareplus": SquarePlus,
    "capped_leaky_relu": CappedLeakyReLU,
    "capped_gelu": CappedGELU,
}

def get_activation(activation: str) -> nn.Module:
    """Returns an activation function given a string

    Parameters
    ----------
    activation : str
        String identifier for the desired activation function

    Returns
    -------
    nn.Module
        Instantiated activation function

    Raises
    ------
    KeyError
        If the specified activation function is not found in the dictionary
    """
    try:
        activation = activation.lower()
        module = ACT2FN[activation]
        if isinstance(module, tuple):
            # Tuple entries carry default keyword arguments for the class
            return module[0](**module[1])
        else:
            return module()
    except KeyError:
        raise KeyError(
            f"Activation function {activation} not found. Available options "
            f"are: {list(ACT2FN.keys())}"
        )

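if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the module API):
    # look up a few activations by name and apply them to random data.
    x = torch.randn(4, 8)
    for name in ("relu", "leaky_relu", "squareplus", "capped_gelu"):
        act = get_activation(name)
        print(name, act(x).shape)
    # Stan is stateful: its beta must match the last input dimension, so it is
    # constructed directly with out_features equal to that dimension.
    stan = Stan(out_features=8)
    print("stan", stan(x).shape)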