Source code for physicsnemo.nn.module.fully_connected_layers

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Literal, Union

import torch
import torch.nn as nn
from torch import Tensor

from physicsnemo.core import Module
from physicsnemo.nn.module.utils.utils import _validate_amp
from physicsnemo.nn.module.utils.weight_init import _weight_init

from .activations import Identity
from .weight_fact import WeightFactLinear
from .weight_norm import WeightNormLinear


class FCLayer(Module):
    r"""Densely connected neural network layer.

    A single fully connected layer followed by an optional activation. The
    underlying linear map may optionally be weight-normalized or
    weight-factorized (but not both).

    Parameters
    ----------
    in_features : int
        Size of input features :math:`D_{in}`.
    out_features : int
        Size of output features :math:`D_{out}`.
    activation_fn : Union[nn.Module, Callable[[Tensor], Tensor], None], optional, default=None
        Activation function to use. ``None`` is replaced by an identity
        activation.
    weight_norm : bool, optional, default=False
        Use :class:`WeightNormLinear` as the linear map.
    weight_fact : bool, optional, default=False
        Use :class:`WeightFactLinear` as the linear map.
    activation_par : Union[nn.Parameter, None], optional, default=None
        Learnable scaling parameter for adaptive activations. When given, the
        activation is evaluated on ``activation_par * linear(x)``.

    Raises
    ------
    ValueError
        If ``weight_norm`` and ``weight_fact`` are both ``True``.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(*, D_{in})` where :math:`*` denotes any
        number of leading batch dimensions.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(*, D_{out})`.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        activation_fn: Union[nn.Module, Callable[[Tensor], Tensor], None] = None,
        weight_norm: bool = False,
        weight_fact: bool = False,
        activation_par: Union[nn.Parameter, None] = None,
    ) -> None:
        super().__init__()
        self.activation_fn = Identity() if activation_fn is None else activation_fn
        self.weight_norm = weight_norm
        self.weight_fact = weight_fact
        self.activation_par = activation_par

        # The two reparameterizations of the linear map are mutually exclusive.
        if weight_norm and weight_fact:
            raise ValueError(
                "Cannot apply both weight normalization and weight factorization together, please select one."
            )

        # Pick the linear implementation; all three share the same constructor call.
        if weight_norm:
            linear_cls = WeightNormLinear
        elif weight_fact:
            linear_cls = WeightFactLinear
        else:
            linear_cls = nn.Linear
        self.linear = linear_cls(in_features, out_features, bias=True)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reinitialize a plain ``nn.Linear`` map: zero bias, Xavier-uniform weight.

        When weight normalization or weight factorization is active, this
        method leaves the linear map untouched.
        """
        uses_reparameterization = self.weight_norm or self.weight_fact
        if uses_reparameterization:
            return
        nn.init.constant_(self.linear.bias, 0)
        nn.init.xavier_uniform_(self.linear.weight)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the linear map, then the (optionally scaled) activation."""
        out = self.linear(x)
        if self.activation_par is not None:
            out = self.activation_par * out
        return self.activation_fn(out)
class ConvFCLayer(Module):
    r"""Base class for 1x1 convolutional layers acting on image channels.

    Holds the activation configuration shared by the convolutional layers that
    behave like fully connected layers over the channel dimension. Subclasses
    build the convolution and call :meth:`apply_activation` on its output.

    Parameters
    ----------
    activation_fn : Union[nn.Module, Callable[[Tensor], Tensor], None], optional, default=None
        Activation function to use. ``None`` is replaced by an identity
        activation.
    activation_par : Union[nn.Parameter, None], optional, default=None
        Learnable scaling parameter for adaptive activations.

    Forward
    -------
    x : torch.Tensor
        Input tensor (shape depends on subclass).

    Outputs
    -------
    torch.Tensor
        Output tensor with activation applied.
    """

    def __init__(
        self,
        activation_fn: Union[nn.Module, Callable[[Tensor], Tensor], None] = None,
        activation_par: Union[nn.Parameter, None] = None,
    ) -> None:
        super().__init__()
        self.activation_fn = Identity() if activation_fn is None else activation_fn
        self.activation_par = activation_par

    def apply_activation(self, x: Tensor) -> Tensor:
        r"""Apply the activation, pre-scaling by ``activation_par`` when it is set.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of arbitrary shape.

        Returns
        -------
        torch.Tensor
            ``activation_fn(x)`` or ``activation_fn(activation_par * x)``.
        """
        scale = self.activation_par
        pre_activation = x if scale is None else scale * x
        return self.activation_fn(pre_activation)
class Conv1dFCLayer(ConvFCLayer):
    r"""Channel-wise fully connected layer using 1D convolutions.

    Applies a kernel-size-1 convolution followed by an optional activation
    function. This is equivalent to a fully connected layer operating on the
    channel dimension of 1D signals.

    Parameters
    ----------
    in_features : int
        Number of input channels :math:`C_{in}`.
    out_features : int
        Number of output channels :math:`C_{out}`.
    activation_fn : Union[nn.Module, Callable[[Tensor], Tensor], None], optional, default=None
        Activation function to use. Can be ``None`` for no activation.
    activation_par : Union[nn.Parameter, None], optional, default=None
        Learnable scaling parameter for adaptive activations.
    weight_norm : bool, optional, default=False
        Weight normalization. Not currently supported.

    Raises
    ------
    NotImplementedError
        If ``weight_norm`` is ``True``.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, C_{in}, L)` where :math:`B` is batch
        size and :math:`L` is sequence length.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(B, C_{out}, L)`.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        activation_fn: Union[nn.Module, Callable[[Tensor], Tensor], None] = None,
        activation_par: Union[nn.Parameter, None] = None,
        weight_norm: bool = False,
    ) -> None:
        # Reject the unsupported option up front, before the convolution is
        # allocated and randomly initialized (which would consume RNG state
        # for an object that is about to be discarded).
        if weight_norm:
            raise NotImplementedError("Weight norm not supported for Conv FC layers")

        super().__init__(activation_fn, activation_par)
        self.in_channels = in_features
        self.out_channels = out_features
        self.conv = nn.Conv1d(in_features, out_features, kernel_size=1, bias=True)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset layer weights: zero bias, Xavier-uniform weight."""
        nn.init.constant_(self.conv.bias, 0)
        nn.init.xavier_uniform_(self.conv.weight)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass through the 1D convolutional layer."""
        x = self.conv(x)
        x = self.apply_activation(x)
        return x
class Conv2dFCLayer(ConvFCLayer):
    r"""Channel-wise fully connected layer using 2D convolutions.

    A 1x1 :class:`torch.nn.Conv2d` followed by an optional activation
    function, i.e. a fully connected layer over the channel dimension of 2D
    images.

    Parameters
    ----------
    in_channels : int
        Number of input channels :math:`C_{in}`.
    out_channels : int
        Number of output channels :math:`C_{out}`.
    activation_fn : Union[nn.Module, Callable[[Tensor], Tensor], None], optional, default=None
        Activation function to use. Can be ``None`` for no activation.
    activation_par : Union[nn.Parameter, None], optional, default=None
        Learnable scaling parameter for adaptive activations.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, C_{in}, H, W)` where :math:`B` is
        batch size, :math:`H` is height, and :math:`W` is width.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(B, C_{out}, H, W)`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation_fn: Union[nn.Module, Callable[[Tensor], Tensor], None] = None,
        activation_par: Union[nn.Parameter, None] = None,
    ) -> None:
        super().__init__(activation_fn, activation_par)
        self.in_channels, self.out_channels = in_channels, out_channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Zero and freeze the bias, then Xavier-uniform initialize the weight."""
        conv_bias = self.conv.bias
        nn.init.constant_(conv_bias, 0)
        # NOTE(review): unlike the 1D/3D variants, the bias is frozen here
        # (requires_grad=False), so it stays at zero during training.
        # Preserved for backward compatibility; confirm this is intended.
        conv_bias.requires_grad = False
        nn.init.xavier_uniform_(self.conv.weight)

    def forward(self, x: Tensor) -> Tensor:
        """Run the 1x1 convolution and apply the activation."""
        return self.apply_activation(self.conv(x))
class Conv3dFCLayer(ConvFCLayer):
    r"""Channel-wise fully connected layer using 3D convolutions.

    A 1x1x1 :class:`torch.nn.Conv3d` followed by an optional activation
    function, i.e. a fully connected layer over the channel dimension of 3D
    volumes.

    Parameters
    ----------
    in_channels : int
        Number of input channels :math:`C_{in}`.
    out_channels : int
        Number of output channels :math:`C_{out}`.
    activation_fn : Union[nn.Module, Callable[[Tensor], Tensor], None], optional, default=None
        Activation function to use. Can be ``None`` for no activation.
    activation_par : Union[nn.Parameter, None], optional, default=None
        Learnable scaling parameter for adaptive activations.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, C_{in}, D, H, W)` where :math:`B` is
        batch size, :math:`D` is depth, :math:`H` is height, and :math:`W` is
        width.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(B, C_{out}, D, H, W)`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation_fn: Union[nn.Module, Callable[[Tensor], Tensor], None] = None,
        activation_par: Union[nn.Parameter, None] = None,
    ) -> None:
        super().__init__(activation_fn, activation_par)
        self.in_channels, self.out_channels = in_channels, out_channels
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=True)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Zero the bias and Xavier-uniform initialize the weight."""
        conv = self.conv
        nn.init.constant_(conv.bias, 0)
        nn.init.xavier_uniform_(conv.weight)

    def forward(self, x: Tensor) -> Tensor:
        """Run the 1x1x1 convolution and apply the activation."""
        return self.apply_activation(self.conv(x))
class ConvNdFCLayer(ConvFCLayer):
    r"""Channel-wise fully connected layer with N-dimensional convolutions.

    Applies a kernel-1 convolution followed by an optional activation
    function. For dimensions 1, 2, or 3, use :class:`Conv1dFCLayer`,
    :class:`Conv2dFCLayer`, or :class:`Conv3dFCLayer` instead for better
    performance.

    Parameters
    ----------
    in_channels : int
        Number of input channels :math:`C_{in}`.
    out_channels : int
        Number of output channels :math:`C_{out}`.
    activation_fn : Union[nn.Module, Callable[[Tensor], Tensor], None], optional, default=None
        Activation function to use. Can be ``None`` for no activation.
    activation_par : Union[nn.Parameter, None], optional, default=None
        Learnable scaling parameter for adaptive activations.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, C_{in}, *spatial)` where :math:`B` is
        batch size and :math:`*spatial` represents arbitrary spatial
        dimensions.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(B, C_{out}, *spatial)`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        # Widened to match the sibling layers: ConvFCLayer accepts any callable.
        activation_fn: Union[nn.Module, Callable[[Tensor], Tensor], None] = None,
        activation_par: Union[nn.Parameter, None] = None,
    ) -> None:
        super().__init__(activation_fn, activation_par)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv = ConvNdKernel1Layer(in_channels, out_channels)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset layer weights by recursively applying :meth:`initialise_parameters`."""
        self.conv.apply(self.initialise_parameters)

    def initialise_parameters(self, model: nn.Module) -> None:
        """Zero the bias and Xavier-uniform initialize the weight of ``model``.

        Visited by :meth:`torch.nn.Module.apply`, so it is also called on
        container modules. Attributes that are missing or not tensors (e.g. a
        ``bias`` of ``None``) are skipped instead of raising.

        Parameters
        ----------
        model : nn.Module
            Module to initialize.
        """
        bias = getattr(model, "bias", None)
        if isinstance(bias, torch.Tensor):
            nn.init.constant_(bias, 0)
        weight = getattr(model, "weight", None)
        if isinstance(weight, torch.Tensor):
            nn.init.xavier_uniform_(weight)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass through the N-dimensional convolutional layer."""
        x = self.conv(x)
        x = self.apply_activation(x)
        return x
class ConvNdKernel1Layer(Module):
    r"""Kernel-1 convolution layer for N-dimensional inputs.

    Implements a 1x1 convolution by flattening all spatial dimensions into
    one, applying a 1D convolution, and restoring the spatial shape. For
    dimensions 1, 2, or 3, use the specialized layer classes for better
    performance.

    Parameters
    ----------
    in_channels : int
        Number of input channels :math:`C_{in}`.
    out_channels : int
        Number of output channels :math:`C_{out}`.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, C_{in}, *spatial)` where :math:`B` is
        batch size and :math:`*spatial` represents arbitrary spatial
        dimensions. Non-contiguous inputs are accepted.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(B, C_{out}, *spatial)`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass through the N-dimensional kernel-1 convolution."""
        # Output keeps batch and spatial sizes; only the channel count changes.
        out_shape = list(x.size())
        out_shape[1] = self.out_channels
        # ``reshape`` instead of ``view``: ``view`` raises on non-contiguous
        # inputs (e.g. after ``permute``), while ``reshape`` copies only when
        # needed and is otherwise identical.
        flat = x.reshape(out_shape[0], self.in_channels, -1)
        return self.conv(flat).reshape(out_shape)
class Linear(Module):
    r"""Fully connected (dense) layer with customizable initialization.

    The layer's weights and biases can be initialized using custom strategies
    like ``"kaiming_normal"``, and scaled by ``init_weight`` and
    ``init_bias``.

    Parameters
    ----------
    in_features : int
        Size of each input sample :math:`D_{in}`.
    out_features : int
        Size of each output sample :math:`D_{out}`.
    bias : bool, optional, default=True
        If ``True``, adds a learnable bias to the output. If ``False``, the
        layer will not learn an additive bias.
    init_mode : str, optional, default="kaiming_normal"
        The initialization mode for weights and biases. Supported modes are:
        ``"xavier_uniform"``, ``"xavier_normal"``, ``"kaiming_uniform"``,
        ``"kaiming_normal"``.
    init_weight : float, optional, default=1
        A scaling factor to multiply with the initialized weights.
    init_bias : float, optional, default=0
        A scaling factor to multiply with the initialized biases. With the
        default of ``0`` the bias starts at zero.
    amp_mode : bool, optional, default=False
        Whether mixed-precision (AMP) training is enabled. When ``False``,
        parameters are cast to the input dtype in :meth:`forward`.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(*, D_{in})` where :math:`*` denotes any
        number of leading batch dimensions.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(*, D_{out})`.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        init_mode: Literal[
            "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal"
        ] = "kaiming_normal",
        # Scale factors are floats (ints still accepted), matching the docstring.
        init_weight: float = 1,
        init_bias: float = 0,
        amp_mode: bool = False,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.amp_mode = amp_mode
        # Weight and bias are drawn with the same mode and fan sizes, then scaled.
        init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features)
        self.weight = torch.nn.Parameter(
            _weight_init([out_features, in_features], **init_kwargs) * init_weight
        )
        self.bias = (
            torch.nn.Parameter(_weight_init([out_features], **init_kwargs) * init_bias)
            if bias
            else None
        )

    def forward(self, x: Tensor) -> Tensor:
        """Compute ``x @ weight.T (+ bias)``.

        When ``amp_mode`` is ``False``, the weight and bias are cast to
        ``x.dtype`` first so the matmul and add operate in a single dtype.
        """
        weight, bias = self.weight, self.bias
        _validate_amp(self.amp_mode)
        if not self.amp_mode:
            # ``weight`` is always a Parameter (set in __init__); ``bias`` may be None.
            if weight.dtype != x.dtype:
                weight = weight.to(x.dtype)
            if bias is not None and bias.dtype != x.dtype:
                bias = bias.to(x.dtype)
        x = x @ weight.t()
        if bias is not None:
            # ``x`` is a freshly allocated matmul result, so in-place add is safe.
            x = x.add_(bias)
        return x