# Source code for physicsnemo.nn.module.conv_layers

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Literal

import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor

from physicsnemo.core.module import Module
from physicsnemo.nn.module.utils.utils import _validate_amp
from physicsnemo.nn.module.utils.weight_init import _weight_init


class CubeEmbedding(nn.Module):
    """
    3D Image Cube Embedding

    Args:
        img_size (tuple[int]): Image size [T, Lat, Lon].
        patch_size (tuple[int]): Patch token size [T, Lat, Lon].
        in_chans (int): Number of input image channels.
        embed_dim (int): Number of projection output channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: torch.nn.LayerNorm
    """

    def __init__(
        self, img_size, patch_size, in_chans, embed_dim, norm_layer=nn.LayerNorm
    ):
        super().__init__()
        # Token grid resolution along each of the T / Lat / Lon axes.
        patches_resolution = [
            dim // patch for dim, patch in zip(img_size, patch_size)
        ]
        self.img_size = img_size
        self.patches_resolution = patches_resolution
        self.embed_dim = embed_dim
        # Non-overlapping patch projection: kernel size equals stride.
        self.proj = nn.Conv3d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x: torch.Tensor):
        B, C, T, Lat, Lon = x.shape
        # Project and flatten the patch grid into a token sequence:
        # (B, C, T, Lat, Lon) -> (B, T'*Lat'*Lon', embed_dim).
        tokens = self.proj(x).reshape(B, self.embed_dim, -1).transpose(1, 2)
        if self.norm is not None:
            tokens = self.norm(tokens)
        # Restore the spatial layout: (B, embed_dim, T', Lat', Lon').
        return tokens.transpose(1, 2).reshape(
            B, self.embed_dim, *self.patches_resolution
        )
class ConvBlock(nn.Module):
    """
    Conv2d block

    Args:
        in_chans (int): Number of input channels.
        out_chans (int): Number of output channels.
        num_groups (int): Number of groups to separate the channels into for
            group normalization.
        num_residuals (int, optional): Number of Conv2d operators. Default: 2
        upsample (int, optional): 1: Upsample, 0: Conv, -1: Downsample. Default: 0

    Raises:
        ValueError: If ``upsample`` is not one of 1, 0 or -1.
    """

    def __init__(self, in_chans, out_chans, num_groups, num_residuals=2, upsample=0):
        super().__init__()
        if upsample == 1:
            # 2x spatial upsampling.
            self.conv = nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2)
        elif upsample == -1:
            # 2x spatial downsampling.
            self.conv = nn.Conv2d(
                in_chans, out_chans, kernel_size=(3, 3), stride=2, padding=1
            )
        elif upsample == 0:
            # Resolution-preserving convolution.
            self.conv = nn.Conv2d(
                in_chans, out_chans, kernel_size=(3, 3), stride=1, padding=1
            )
        else:
            # Previously an invalid value silently left ``self.conv`` undefined,
            # deferring the failure to the first forward call.
            raise ValueError(f"upsample must be 1, 0 or -1, got {upsample}")

        # Residual stack: num_residuals repetitions of conv -> group norm -> SiLU.
        blk = []
        for _ in range(num_residuals):
            blk.append(
                nn.Conv2d(out_chans, out_chans, kernel_size=3, stride=1, padding=1)
            )
            blk.append(nn.GroupNorm(num_groups, out_chans))
            blk.append(nn.SiLU())
        self.b = nn.Sequential(*blk)

    def forward(self, x):
        """Resample (or convolve), then add the residual stack's output."""
        x = self.conv(x)
        x_skip = x
        x = self.b(x)
        return x + x_skip
def _get_same_padding(x: int, k: int, s: int) -> int: r""" Function to compute "same" padding. Inspired from: `timm padding <https://github.com/huggingface/pytorch-image-models/blob/0.5.x/timm/models/layers/padding.py>`_ Parameters ---------- x : int Input dimension size. k : int Kernel size. s : int Stride. Returns ------- int Padding value to achieve "same" padding. """ return max(s * math.ceil(x / s) - s - x + k, 0)
class Conv2d(torch.nn.Module):
    """
    A custom 2D convolutional layer implementation with support for up-sampling,
    down-sampling, and custom weight and bias initializations. The layer's
    weights and biases can be initialized using custom initialization strategies
    like "kaiming_normal", and can be further scaled by factors `init_weight`
    and `init_bias`.

    Parameters
    ----------
    in_channels : int
        Number of channels in the input image.
    out_channels : int
        Number of channels produced by the convolution.
    kernel : int
        Size of the convolving kernel. A value of 0 disables the convolution
        (no weight/bias parameters are created; only resampling is applied).
    bias : bool, optional
        The biases of the layer. If set to `None`, the layer will not learn an
        additive bias. By default True.
    up : bool, optional
        Whether to perform up-sampling. By default False.
    down : bool, optional
        Whether to perform down-sampling. By default False.
    resample_filter : List[int], optional
        Filter to be used for resampling. By default [1, 1].
    fused_resample : bool, optional
        If True, performs fused up-sampling and convolution or fused
        down-sampling and convolution. By default False.
    init_mode : str, optional (default="kaiming_normal")
        The mode/type of initialization to use for weights and biases. Supported
        modes are:

        - "xavier_uniform": Xavier (Glorot) uniform initialization.
        - "xavier_normal": Xavier (Glorot) normal initialization.
        - "kaiming_uniform": Kaiming (He) uniform initialization.
        - "kaiming_normal": Kaiming (He) normal initialization.

        By default "kaiming_normal".
    init_weight : float, optional
        A scaling factor to multiply with the initialized weights. By default 1.0.
    init_bias : float, optional
        A scaling factor to multiply with the initialized biases. By default 0.0.
    fused_conv_bias : bool, optional
        A boolean flag indicating whether bias will be passed as a parameter of
        conv2d. By default False.
    amp_mode : bool, optional
        A boolean flag indicating whether mixed-precision (AMP) training is
        enabled. Defaults to False.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel: int,
        bias: bool = True,
        up: bool = False,
        down: bool = False,
        resample_filter: List[int] = [1, 1],
        fused_resample: bool = False,
        init_mode: Literal[
            "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal"
        ] = "kaiming_normal",
        init_weight: float = 1.0,
        init_bias: float = 0.0,
        fused_conv_bias: bool = False,
        amp_mode: bool = False,
    ):
        if up and down:
            raise ValueError("Both 'up' and 'down' cannot be true at the same time.")
        if not kernel and fused_conv_bias:
            # A fused conv bias needs an actual conv weight; downgrade gracefully.
            print(
                "Warning: Kernel is required when fused_conv_bias is enabled. Setting fused_conv_bias to False."
            )
            fused_conv_bias = False
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.up = up
        self.down = down
        self.fused_resample = fused_resample
        self.fused_conv_bias = fused_conv_bias
        self.amp_mode = amp_mode
        # Fan-in/fan-out for the chosen init scheme follow the conv weight shape.
        init_kwargs = dict(
            mode=init_mode,
            fan_in=in_channels * kernel * kernel,
            fan_out=out_channels * kernel * kernel,
        )
        # kernel == 0 means "no convolution": weight and bias stay None.
        self.weight = (
            torch.nn.Parameter(
                _weight_init([out_channels, in_channels, kernel, kernel], **init_kwargs)
                * init_weight
            )
            if kernel
            else None
        )
        self.bias = (
            torch.nn.Parameter(_weight_init([out_channels], **init_kwargs) * init_bias)
            if kernel and bias
            else None
        )
        # Separable 2D resampling filter: outer product of the 1D filter,
        # normalized by the squared sum of taps. Registered as a buffer only
        # when resampling is requested.
        f = torch.as_tensor(resample_filter, dtype=torch.float32)
        f = f.ger(f).unsqueeze(0).unsqueeze(1) / f.sum().square()
        self.register_buffer("resample_filter", f if up or down else None)

    def forward(self, x):
        weight, bias, resample_filter = self.weight, self.bias, self.resample_filter
        _validate_amp(self.amp_mode)
        if not self.amp_mode:
            # Outside AMP, cast parameters/buffers to the input dtype manually.
            if self.weight is not None and self.weight.dtype != x.dtype:
                weight = self.weight.to(x.dtype)
            if self.bias is not None and self.bias.dtype != x.dtype:
                bias = self.bias.to(x.dtype)
            if (
                self.resample_filter is not None
                and self.resample_filter.dtype != x.dtype
            ):
                resample_filter = self.resample_filter.to(x.dtype)

        w = weight if weight is not None else None
        b = bias if bias is not None else None
        f = resample_filter if resample_filter is not None else None
        # "Same"-style half-paddings for the conv weight and resample filter.
        w_pad = w.shape[-1] // 2 if w is not None else 0
        f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0

        if self.fused_resample and self.up and w is not None:
            # Fused path: 2x upsample via depthwise transposed conv (filter
            # scaled by 4 to preserve magnitude), then the main convolution.
            x = torch.nn.functional.conv_transpose2d(
                x,
                f.mul(4).tile([self.in_channels, 1, 1, 1]),
                groups=self.in_channels,
                stride=2,
                padding=max(f_pad - w_pad, 0),
            )
            if self.fused_conv_bias:
                x = torch.nn.functional.conv2d(
                    x, w, padding=max(w_pad - f_pad, 0), bias=b
                )
            else:
                x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0))
        elif self.fused_resample and self.down and w is not None:
            # Fused path: main convolution first, then 2x depthwise downsample.
            x = torch.nn.functional.conv2d(x, w, padding=w_pad + f_pad)
            if self.fused_conv_bias:
                x = torch.nn.functional.conv2d(
                    x,
                    f.tile([self.out_channels, 1, 1, 1]),
                    groups=self.out_channels,
                    stride=2,
                    bias=b,
                )
            else:
                x = torch.nn.functional.conv2d(
                    x,
                    f.tile([self.out_channels, 1, 1, 1]),
                    groups=self.out_channels,
                    stride=2,
                )
        else:
            # Unfused path: resample (if requested) and convolve separately.
            if self.up:
                x = torch.nn.functional.conv_transpose2d(
                    x,
                    f.mul(4).tile([self.in_channels, 1, 1, 1]),
                    groups=self.in_channels,
                    stride=2,
                    padding=f_pad,
                )
            if self.down:
                x = torch.nn.functional.conv2d(
                    x,
                    f.tile([self.in_channels, 1, 1, 1]),
                    groups=self.in_channels,
                    stride=2,
                    padding=f_pad,
                )
            if w is not None:  # NOTE(review): unclear whether w can ever be None here
                if self.fused_conv_bias:
                    x = torch.nn.functional.conv2d(x, w, padding=w_pad, bias=b)
                else:
                    x = torch.nn.functional.conv2d(x, w, padding=w_pad)
        # Bias is added in-place here only when it was not fused into a conv.
        if b is not None and not self.fused_conv_bias:
            x = x.add_(b.reshape(1, -1, 1, 1))
        return x
class ConvLayer(Module):
    r"""
    Generalized Convolution Block.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    dimension : int
        Dimensionality of the input (1, 2, or 3).
    kernel_size : int
        Kernel size for the convolution.
    stride : int, optional, default=1
        Stride for the convolution.
    activation_fn : nn.Module, optional, default=nn.Identity()
        Activation function to use.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, C_{in}, *)` where :math:`*` represents
        spatial dimensions matching ``dimension``.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(B, C_{out}, *)` where spatial dimensions
        depend on stride and padding.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dimension: int,
        kernel_size: int,
        stride: int = 1,
        activation_fn: nn.Module = nn.Identity(),
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dimension = dimension
        self.activation_fn = activation_fn

        # Dispatch on dimensionality instead of an if/elif chain.
        conv_types = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
        if self.dimension not in conv_types:
            raise ValueError("Only 1D, 2D and 3D dimensions are supported")
        self.conv = conv_types[self.dimension](
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            bias=True,
        )
        self._reset_parameters()

    def _exec_activation_fn(
        self,
        x: Float[Tensor, "batch channels ..."],  # noqa: F722
    ) -> Float[Tensor, "batch channels ..."]:  # noqa: F722
        r"""
        Executes activation function on the input.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape :math:`(B, C, *)`.

        Returns
        -------
        torch.Tensor
            Output tensor of shape :math:`(B, C, *)`.
        """
        return self.activation_fn(x)

    def _reset_parameters(self) -> None:
        r"""
        Initialization for network parameters: zero bias, Xavier-uniform weight.
        """
        nn.init.constant_(self.conv.bias, 0)
        nn.init.xavier_uniform_(self.conv.weight)

    def forward(
        self,
        x: Float[Tensor, "batch in_channels ..."],  # noqa: F722
    ) -> Float[Tensor, "batch out_channels ..."]:  # noqa: F722
        r"""Forward pass with same padding."""
        input_length = len(x.size()) - 2  # exclude channel and batch dims

        ### Input validation (skipped while torch.compile is tracing)
        if not torch.compiler.is_compiling():
            if input_length != self.dimension:
                raise ValueError(
                    f"Expected {self.dimension}D input tensor (excluding batch and channel dims), "
                    f"got {input_length}D tensor with shape {tuple(x.shape)}"
                )

        # Build the "same" padding list generically for 1D/2D/3D instead of
        # three copy-pasted branches. F.pad expects pads in reverse dimension
        # order (innermost/W first), split asymmetrically when odd.
        pad: list[int] = []
        for size in reversed(x.shape[2:]):
            total = _get_same_padding(size, self.kernel_size, self.stride)
            pad.extend([total // 2, total - total // 2])
        x = F.pad(x, pad, mode="constant", value=0.0)

        # Apply convolution
        x = self.conv(x)

        # Fix: the original compared against a *fresh* nn.Identity() instance
        # with `is not`, which is always True; use isinstance so an Identity
        # activation is genuinely skipped (output is unchanged either way).
        if not isinstance(self.activation_fn, nn.Identity):
            x = self._exec_activation_fn(x)
        return x
class TransposeConvLayer(Module):
    r"""
    Generalized Transposed Convolution Block.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    dimension : int
        Dimensionality of the input (1, 2, or 3).
    kernel_size : int
        Kernel size for the convolution.
    stride : int, optional, default=1
        Stride for the convolution.
    activation_fn : nn.Module, optional, default=nn.Identity()
        Activation function to use.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, C_{in}, *)` where :math:`*` represents
        spatial dimensions matching ``dimension``.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(B, C_{out}, *)` where spatial dimensions
        are upsampled based on stride.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dimension: int,
        kernel_size: int,
        stride: int = 1,
        activation_fn=nn.Identity(),
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dimension = dimension
        self.activation_fn = activation_fn

        # Dispatch on dimensionality instead of an if/elif chain.
        trans_conv_types = {
            1: nn.ConvTranspose1d,
            2: nn.ConvTranspose2d,
            3: nn.ConvTranspose3d,
        }
        if dimension not in trans_conv_types:
            raise ValueError("Only 1D, 2D and 3D dimensions are supported")
        self.trans_conv = trans_conv_types[dimension](
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            bias=True,
        )
        self._reset_parameters()

    def _exec_activation_fn(
        self,
        x: Float[Tensor, "batch channels ..."],  # noqa: F722
    ) -> Float[Tensor, "batch channels ..."]:  # noqa: F722
        r"""
        Executes activation function on the input.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape :math:`(B, C, *)`.

        Returns
        -------
        torch.Tensor
            Output tensor of shape :math:`(B, C, *)`.
        """
        return self.activation_fn(x)

    def _reset_parameters(self) -> None:
        r"""
        Initialization for network parameters: zero bias, Xavier-uniform weight.
        """
        nn.init.constant_(self.trans_conv.bias, 0)
        nn.init.xavier_uniform_(self.trans_conv.weight)

    def forward(
        self,
        x: Float[Tensor, "batch in_channels ..."],  # noqa: F722
    ) -> Float[Tensor, "batch out_channels ..."]:  # noqa: F722
        r"""Forward pass with transposed convolution and cropping."""
        ### Input validation (skipped while torch.compile is tracing)
        if not torch.compiler.is_compiling():
            input_length = len(x.size()) - 2  # exclude channel and batch dims
            if input_length != self.dimension:
                raise ValueError(
                    f"Expected {self.dimension}D input tensor (excluding batch and channel dims), "
                    f"got {input_length}D tensor with shape {tuple(x.shape)}"
                )

        orig_x = x

        # Apply transposed convolution
        x = self.trans_conv(x)

        # Crop the output back to the size implied by "same" padding, computed
        # from the ORIGINAL input's spatial sizes. Built generically for
        # 1D/2D/3D instead of three copy-pasted branches; an odd total crop is
        # split with the extra element taken from the trailing side.
        slices = [slice(None), slice(None)]  # keep batch and channel dims
        for axis, size in enumerate(orig_x.shape[2:], start=2):
            total = _get_same_padding(size, self.kernel_size, self.stride)
            slices.append(slice(total // 2, x.size(axis) - (total - total // 2)))
        x = x[tuple(slices)]

        # Fix: the original compared against a *fresh* nn.Identity() instance
        # with `is not`, which is always True; use isinstance so an Identity
        # activation is genuinely skipped (output is unchanged either way).
        if not isinstance(self.activation_fn, nn.Identity):
            x = self._exec_activation_fn(x)
        return x
class ConvGRULayer(Module):
    r"""
    Convolutional GRU layer.

    Parameters
    ----------
    in_features : int
        Input features/channels.
    hidden_size : int
        Hidden layer features/channels.
    dimension : int
        Spatial dimension of the input.
    activation_fn : nn.Module, optional, default=nn.ReLU()
        Activation Function to use.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, C_{in}, *)` where :math:`*` represents
        spatial dimensions.
    hidden : torch.Tensor
        Hidden state tensor of shape :math:`(B, H, *)` where :math:`H` is
        ``hidden_size``.

    Outputs
    -------
    torch.Tensor
        Next hidden state of shape :math:`(B, H, *)`.
    """

    def __init__(
        self,
        in_features: int,
        hidden_size: int,
        dimension: int,
        activation_fn: nn.Module = nn.ReLU(),
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.activation_fn = activation_fn
        # One convolution produces both gate pre-activations (reset + update).
        self.conv_1 = ConvLayer(
            in_channels=in_features + hidden_size,
            out_channels=2 * hidden_size,
            kernel_size=3,
            stride=1,
            dimension=dimension,
        )
        # Second convolution produces the candidate hidden state.
        self.conv_2 = ConvLayer(
            in_channels=in_features + hidden_size,
            out_channels=hidden_size,
            kernel_size=3,
            stride=1,
            dimension=dimension,
        )

    def _exec_activation_fn(
        self,
        x: Float[Tensor, "batch channels ..."],  # noqa: F722
    ) -> Float[Tensor, "batch channels ..."]:  # noqa: F722
        r"""
        Executes activation function on the input.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape :math:`(B, C, *)`.

        Returns
        -------
        torch.Tensor
            Output tensor of shape :math:`(B, C, *)`.
        """
        return self.activation_fn(x)

    def forward(
        self,
        x: Float[Tensor, "batch in_features ..."],  # noqa: F722
        hidden: Float[Tensor, "batch hidden_size ..."],  # noqa: F722
    ) -> Float[Tensor, "batch hidden_size ..."]:  # noqa: F722
        r"""Forward pass implementing GRU update."""
        ### Input validation (skipped while torch.compile is tracing)
        if not torch.compiler.is_compiling():
            if x.shape[1] != self.in_features:
                raise ValueError(
                    f"Expected input with {self.in_features} features, "
                    f"got {x.shape[1]} features in tensor with shape {tuple(x.shape)}"
                )
            if hidden.shape[1] != self.hidden_size:
                raise ValueError(
                    f"Expected hidden state with {self.hidden_size} features, "
                    f"got {hidden.shape[1]} features in tensor with shape {tuple(hidden.shape)}"
                )
            if x.shape[0] != hidden.shape[0] or x.shape[2:] != hidden.shape[2:]:
                raise ValueError(
                    f"Input and hidden state must have matching batch size and spatial dims. "
                    f"Got input shape {tuple(x.shape)} and hidden shape {tuple(hidden.shape)}"
                )

        # Gates: a single convolution over [x, hidden] yields both
        # pre-activations, split channel-wise into reset and update halves.
        gate_pre = self.conv_1(torch.cat((x, hidden), dim=1))
        reset_pre, update_pre = torch.split(gate_pre, self.hidden_size, 1)
        reset_gate = torch.special.expit(reset_pre)  # (B, hidden_size, *)
        update_gate = torch.special.expit(update_pre)  # (B, hidden_size, *)

        # Candidate state from the input and the reset-modulated hidden state.
        candidate = self._exec_activation_fn(
            self.conv_2(torch.cat((x, hidden * reset_gate), dim=1))
        )

        # Next hidden state: convex combination of candidate and previous state.
        return (1 - update_gate) * candidate + update_gate * hidden
class ConvResidualBlock(Module):
    r"""
    Convolutional ResNet Block.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    dimension : int
        Dimensionality of the input.
    stride : int, optional, default=1
        Stride of the convolutions.
    gated : bool, optional, default=False
        Residual Gate activation.
    layer_normalization : bool, optional, default=False
        Whether to apply layer normalization.
    begin_activation_fn : bool, optional, default=True
        Whether to use activation function in the beginning.
    activation_fn : nn.Module, optional, default=nn.ReLU()
        Activation function to use.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, C_{in}, *)` where :math:`*` represents
        spatial dimensions matching ``dimension``.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(B, C_{out}, *)` with residual connection.

    Raises
    ------
    ValueError
        If stride > 2 (not supported).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dimension: int,
        stride: int = 1,
        gated: bool = False,
        layer_normalization: bool = False,
        begin_activation_fn: bool = True,
        activation_fn: nn.Module = nn.ReLU(),
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.dimension = dimension
        self.gated = gated
        self.layer_normalization = layer_normalization
        self.begin_activation_fn = begin_activation_fn
        self.activation_fn = activation_fn
        # Kernel 3 for stride 1, kernel 4 for stride 2 (keeps "same"-padding
        # arithmetic even when downsampling).
        if self.stride == 1:
            self.conv_1 = ConvLayer(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=3,
                stride=self.stride,
                dimension=self.dimension,
            )
        elif self.stride == 2:
            self.conv_1 = ConvLayer(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=4,
                stride=self.stride,
                dimension=self.dimension,
            )
        else:
            # NOTE(review): also rejects stride < 1, despite the message text.
            raise ValueError("stride > 2 is not supported")
        # Gated variant doubles conv_2's output channels: one half is the value,
        # the other half the sigmoid gate (split in forward).
        if not self.gated:
            self.conv_2 = ConvLayer(
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                kernel_size=3,
                stride=1,
                dimension=self.dimension,
            )
        else:
            self.conv_2 = ConvLayer(
                in_channels=self.out_channels,
                out_channels=2 * self.out_channels,
                kernel_size=3,
                stride=1,
                dimension=self.dimension,
            )

    def _exec_activation_fn(
        self,
        x: Float[Tensor, "batch channels ..."],  # noqa: F722
    ) -> Float[Tensor, "batch channels ..."]:  # noqa: F722
        r"""
        Executes activation function on the input.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape :math:`(B, C, *)`.

        Returns
        -------
        torch.Tensor
            Output tensor of shape :math:`(B, C, *)`.
        """
        return self.activation_fn(x)

    def forward(
        self,
        x: Float[Tensor, "batch in_channels ..."],  # noqa: F722
    ) -> Float[Tensor, "batch out_channels ..."]:  # noqa: F722
        r"""Forward pass with residual connection."""
        ### Input validation (skipped while torch.compile is tracing)
        if not torch.compiler.is_compiling():
            input_length = len(x.size()) - 2  # exclude channel and batch dims
            if input_length != self.dimension:
                raise ValueError(
                    f"Expected {self.dimension}D input tensor (excluding batch and channel dims), "
                    f"got {input_length}D tensor with shape {tuple(x.shape)}"
                )

        orig_x = x

        # Apply layer normalization and activation at the beginning if specified.
        # LayerNorm is constructed per call with elementwise_affine=False, so it
        # carries no learnable parameters (normalization only).
        if self.begin_activation_fn:
            if self.layer_normalization:
                layer_norm = nn.LayerNorm(x.size()[1:], elementwise_affine=False)
                x = layer_norm(x)
            x = self._exec_activation_fn(x)

        # First convolutional layer (may downsample when stride == 2)
        x = self.conv_1(x)

        # Apply layer normalization after first convolution
        if self.layer_normalization:
            layer_norm = nn.LayerNorm(x.size()[1:], elementwise_affine=False)
            x = layer_norm(x)

        # Second activation and convolution
        x = self._exec_activation_fn(x)
        x = self.conv_2(x)

        # Apply gating if specified: first half of channels is the value,
        # second half the sigmoid gate.
        if self.gated:
            x_1, x_2 = torch.split(x, x.size(1) // 2, 1)
            x = x_1 * torch.special.expit(x_2)

        # Adjust skip connection if spatial dimensions differ (due to stride):
        # average-pool the original input down to the branch's resolution.
        if orig_x.size(-1) > x.size(-1):  # Check if widths are different
            if len(orig_x.size()) - 2 == 1:
                iw = orig_x.size()[-1:][0]
                pad_w = _get_same_padding(iw, 2, 2)
                pool = torch.nn.AvgPool1d(
                    2, 2, padding=pad_w // 2, count_include_pad=False
                )
            elif len(orig_x.size()) - 2 == 2:
                ih, iw = orig_x.size()[-2:]
                pad_h, pad_w = (
                    _get_same_padding(
                        ih,
                        2,
                        2,
                    ),
                    _get_same_padding(iw, 2, 2),
                )
                pool = torch.nn.AvgPool2d(
                    2, 2, padding=(pad_h // 2, pad_w // 2), count_include_pad=False
                )
            elif len(orig_x.size()) - 2 == 3:
                _id, ih, iw = orig_x.size()[-3:]
                pad_d, pad_h, pad_w = (
                    _get_same_padding(_id, 2, 2),
                    _get_same_padding(ih, 2, 2),
                    _get_same_padding(iw, 2, 2),
                )
                pool = torch.nn.AvgPool3d(
                    2,
                    2,
                    padding=(pad_d // 2, pad_h // 2, pad_w // 2),
                    count_include_pad=False,
                )
            else:
                raise ValueError("Only 1D, 2D and 3D dimensions are supported")
            orig_x = pool(orig_x)

        # Adjust skip connection channels if needed: zero-pad new channels at
        # the FRONT of the channel dimension.
        in_channels = int(orig_x.size(1))
        if self.out_channels > in_channels:
            # NOTE(review): pad amount uses self.in_channels rather than the
            # local in_channels computed above; these agree whenever the skip
            # path keeps the input channel count — confirm that is always true.
            orig_x = F.pad(
                orig_x,
                (len(orig_x.size()) - 2) * (0, 0)
                + (self.out_channels - self.in_channels, 0),
            )
        elif self.out_channels < in_channels:
            # NOTE(review): intentionally a no-op — when the branch has fewer
            # channels than the skip, the addition below would broadcast-fail;
            # presumably this configuration is never used. Verify with callers.
            pass

        return orig_x + x