# Source code for physicsnemo.nn.module.afno_layers

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Adaptive Fourier Neural Operator (AFNO) layers.

This module contains reusable AFNO building blocks that can be used
in various AFNO-based architectures.
"""

from typing import List, Literal, Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float

import physicsnemo.nn.module.fft as fft
from physicsnemo.core.module import Module
from physicsnemo.nn.module.mlp_layers import Mlp

Tensor = torch.Tensor


class AFNOMlp(Module):
    r"""Fully-connected multi-layer perceptron used inside AFNO.

    The network is ``fc1 -> activation -> dropout -> fc2 -> dropout`` and
    acts on the last dimension of its input. A single dropout module is
    shared by both dropout steps.

    Parameters
    ----------
    in_features : int
        Input feature size.
    latent_features : int
        Latent (hidden) feature size.
    out_features : int
        Output feature size.
    activation_fn : nn.Module, optional, default=nn.GELU()
        Activation module applied after the first linear layer.
    drop : float, optional, default=0.0
        Dropout rate.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(*, D_{in})` where :math:`D_{in}` is
        ``in_features``.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(*, D_{out})` where :math:`D_{out}` is
        ``out_features``.

    Examples
    --------
    >>> import torch
    >>> mlp = AFNOMlp(in_features=64, latent_features=128, out_features=64)
    >>> x = torch.randn(4, 32, 32, 64)
    >>> output = mlp(x)
    >>> output.shape
    torch.Size([4, 32, 32, 64])
    """

    def __init__(
        self,
        in_features: int,
        latent_features: int,
        out_features: int,
        activation_fn: nn.Module = nn.GELU(),
        drop: float = 0.0,
    ):
        super().__init__()
        # Submodules are created in the order fc1, act, fc2, drop so that
        # registration order and weight-initialisation order stay stable.
        self.fc1 = nn.Linear(in_features=in_features, out_features=latent_features)
        self.act = activation_fn
        self.fc2 = nn.Linear(in_features=latent_features, out_features=out_features)
        self.drop = nn.Dropout(p=drop)

    def forward(self, x: Float[Tensor, "*dims D_in"]) -> Float[Tensor, "*dims D_out"]:
        r"""Forward pass of the MLP."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class AFNO2DLayer(Module):
    r"""AFNO spectral convolution layer.

    The input is transformed with a real 2D FFT over its two spatial axes,
    mixed channel-wise by a two-stage complex multiplication with
    block-diagonal weights (ReLU in between), soft-shrunk for sparsity,
    transformed back and added to the original input.

    Parameters
    ----------
    hidden_size : int
        Feature dimensionality.
    num_blocks : int, optional, default=8
        Number of blocks used in the block diagonal weight matrix.
    sparsity_threshold : float, optional, default=0.01
        Sparsity threshold (softshrink) of spectral features.
    hard_thresholding_fraction : float, optional, default=1
        Threshold for limiting number of modes used, in range ``[0, 1]``.
    hidden_size_factor : int, optional, default=1
        Factor to increase spectral features by after weight multiplication.

    Raises
    ------
    ValueError
        If ``hidden_size`` is not divisible by ``num_blocks``.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, H, W, C)` where :math:`B` is batch
        size, :math:`H, W` are spatial dimensions, and :math:`C` is
        ``hidden_size``.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(B, H, W, C)`.

    Examples
    --------
    >>> import torch
    >>> layer = AFNO2DLayer(hidden_size=64, num_blocks=8)
    >>> x = torch.randn(4, 32, 32, 64)
    >>> output = layer(x)
    >>> output.shape
    torch.Size([4, 32, 32, 64])
    """

    def __init__(
        self,
        hidden_size: int,
        num_blocks: int = 8,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1,
        hidden_size_factor: int = 1,
    ):
        super().__init__()
        if hidden_size % num_blocks != 0:
            raise ValueError(
                f"hidden_size {hidden_size} should be divisible by num_blocks {num_blocks}"
            )

        self.hidden_size = hidden_size
        self.sparsity_threshold = sparsity_threshold
        self.num_blocks = num_blocks
        self.block_size = self.hidden_size // self.num_blocks
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.hidden_size_factor = hidden_size_factor
        self.scale = 0.02

        # Per-block width of the intermediate spectral features.
        wide = self.block_size * self.hidden_size_factor

        # Leading dimension of size 2 holds the (real, imaginary) parts.
        # Parameters are drawn in the order w1, b1, w2, b2.
        self.w1 = nn.Parameter(
            self.scale * torch.randn(2, self.num_blocks, self.block_size, wide)
        )
        self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, wide))
        self.w2 = nn.Parameter(
            self.scale * torch.randn(2, self.num_blocks, wide, self.block_size)
        )
        self.b2 = nn.Parameter(
            self.scale * torch.randn(2, self.num_blocks, self.block_size)
        )

    def forward(self, x: Float[Tensor, "B H W C"]) -> Float[Tensor, "B H W C"]:
        r"""Forward pass of the AFNO spectral layer."""
        residual = x
        in_dtype = x.dtype
        # The spectral computation is carried out in float32.
        x = x.float()
        B, H, W, C = x.shape
        half_w = W // 2 + 1
        wide = self.block_size * self.hidden_size_factor

        # Apply 2D FFT in the spatial dimensions
        x = fft.rfft2(x, dim=(1, 2), norm="ortho")
        x_real, x_imag = fft.real(x), fft.imag(x)
        x_real = x_real.reshape(B, H, half_w, self.num_blocks, self.block_size)
        x_imag = x_imag.reshape(B, H, half_w, self.num_blocks, self.block_size)

        # Zero-initialised buffers: entries outside the kept-mode window
        # below are never written and therefore stay zero.
        o1_real = torch.zeros([B, H, half_w, self.num_blocks, wide], device=x.device)
        o1_imag = torch.zeros([B, H, half_w, self.num_blocks, wide], device=x.device)
        o2 = torch.zeros(x_real.shape + (2,), device=x.device)

        total_modes = H // 2 + 1
        kept_modes = int(total_modes * self.hard_thresholding_fraction)

        # Index tuple equivalent to
        # ``[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes]``.
        window = (
            slice(None),
            slice(total_modes - kept_modes, total_modes + kept_modes),
            slice(None, kept_modes),
        )
        eq = "nyxbi,bio->nyxbo"
        in_real = x_real[window]
        in_imag = x_imag[window]

        # First complex block-diagonal multiplication followed by ReLU:
        # (a + ib)(c + id) = (ac - bd) + i(bc + ad).
        o1_real[window] = F.relu(
            torch.einsum(eq, in_real, self.w1[0])
            - torch.einsum(eq, in_imag, self.w1[1])
            + self.b1[0]
        )
        o1_imag[window] = F.relu(
            torch.einsum(eq, in_imag, self.w1[0])
            + torch.einsum(eq, in_real, self.w1[1])
            + self.b1[1]
        )

        # Second complex multiplication; the trailing axis of ``o2`` stores
        # the real part at index 0 and the imaginary part at index 1.
        hid_real = o1_real[window]
        hid_imag = o1_imag[window]
        o2[window + (Ellipsis, 0)] = (
            torch.einsum(eq, hid_real, self.w2[0])
            - torch.einsum(eq, hid_imag, self.w2[1])
            + self.b2[0]
        )
        o2[window + (Ellipsis, 1)] = (
            torch.einsum(eq, hid_imag, self.w2[0])
            + torch.einsum(eq, hid_real, self.w2[1])
            + self.b2[1]
        )

        x = F.softshrink(o2, lambd=self.sparsity_threshold)
        x = fft.view_as_complex(x)
        # TODO(akamenev): replace the following branching with
        # a one-liner, something like x.reshape(..., -1).squeeze(-1),
        # but this currently fails during ONNX export.
        if torch.onnx.is_in_onnx_export():
            merged_shape = (B, H, half_w, C, 2)
        else:
            merged_shape = (B, H, half_w, C)
        x = x.reshape(merged_shape)

        # Using ONNX friendly FFT functions
        x = fft.irfft2(x, s=(H, W), dim=(1, 2), norm="ortho")
        x = x.type(in_dtype)
        return x + residual
class AFNOPatchEmbed(Module):
    r"""Patch embedding layer for AFNO.

    Converts 2D patches into a 1D vector sequence for input to AFNO. This
    differs from :class:`~physicsnemo.nn.module.utils.patch_embed.PatchEmbed2D`
    as it flattens the output to a sequence format.

    Parameters
    ----------
    inp_shape : List[int]
        Input image dimensions as ``[height, width]``.
    in_channels : int
        Number of input channels.
    patch_size : List[int], optional, default=None
        Size of image patches as ``[patch_height, patch_width]``. ``None``
        selects ``[16, 16]``; a fresh list is created for every instance so
        instances never share (or mutate) a common default object.
    embed_dim : int, optional, default=256
        Embedded channel size.

    Raises
    ------
    ValueError
        If ``inp_shape`` or ``patch_size`` does not have exactly two entries.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, C_{in}, H, W)` where :math:`B` is
        batch size, :math:`C_{in}` is the number of input channels, and
        :math:`H, W` are spatial dimensions matching ``inp_shape``.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(B, N, D)` where :math:`N` is the
        number of patches and :math:`D` is ``embed_dim``.

    Examples
    --------
    >>> import torch
    >>> patch_embed = AFNOPatchEmbed(
    ...     inp_shape=[32, 32], in_channels=3, patch_size=[8, 8], embed_dim=64
    ... )
    >>> x = torch.randn(4, 3, 32, 32)
    >>> output = patch_embed(x)
    >>> output.shape
    torch.Size([4, 16, 64])
    """

    def __init__(
        self,
        inp_shape: List[int],
        in_channels: int,
        patch_size: Union[List[int], None] = None,
        embed_dim: int = 256,
    ):
        super().__init__()
        # A list literal in the signature would be one object shared by all
        # instances (and stored on ``self.patch_size``); build it per call.
        if patch_size is None:
            patch_size = [16, 16]

        if len(inp_shape) != 2:
            raise ValueError("inp_shape should be a list of length 2")
        if len(patch_size) != 2:
            raise ValueError("patch_size should be a list of length 2")

        # Floor division: a trailing partial patch does not count.
        num_patches = (inp_shape[1] // patch_size[1]) * (inp_shape[0] // patch_size[0])
        self.inp_shape = inp_shape
        self.patch_size = patch_size
        self.num_patches = num_patches
        # Non-overlapping patches: the stride equals the kernel size.
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x: Float[Tensor, "B C H W"]) -> Float[Tensor, "B N D"]:
        r"""Forward pass of patch embedding.

        Raises
        ------
        ValueError
            If, outside of ``torch.compile`` tracing, ``x`` is not 4D or its
            channel/height/width do not match the configured values.
        """
        # Input validation: single check for shape (B, C, H, W)
        if not torch.compiler.is_compiling():
            expected_c = self.proj.in_channels
            expected_h, expected_w = self.inp_shape[0], self.inp_shape[1]
            if (
                x.ndim != 4
                or x.shape[1] != expected_c
                or x.shape[2] != expected_h
                or x.shape[3] != expected_w
            ):
                raise ValueError(
                    f"Expected input shape (B, {expected_c}, {expected_h}, {expected_w}), "
                    f"got {tuple(x.shape)}"
                )

        # (B, D, H', W') -> (B, D, H'*W') -> (B, H'*W', D)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
# Alias for backward compatibility: ``PatchEmbed`` refers to the same class
# object as ``AFNOPatchEmbed``.
PatchEmbed = AFNOPatchEmbed
class ScaleShiftMlp(Module):
    r"""MLP used to compute the scale and shift parameters of the ModAFNO block.

    A single internal MLP produces a vector of size ``2 * out_features``
    which is split in half along dimension 1 into scale and shift.

    Parameters
    ----------
    in_features : int
        Input feature size.
    out_features : int
        Output feature size (of scale and of shift individually).
    hidden_features : int, optional
        Hidden feature size. Defaults to ``2 * out_features``.
    hidden_layers : int, optional, default=0
        Number of hidden layers in addition to the first one; the internal
        MLP is given ``hidden_layers + 1`` hidden widths.
    activation_fn : Type[nn.Module], optional, default=nn.GELU
        Activation function class.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, D_{in})`.

    Outputs
    -------
    Tuple[torch.Tensor, torch.Tensor]
        Tuple of (scale, shift) tensors, each of shape :math:`(B, D_{out})`.
        Scale is offset by 1, i.e., ``(1 + scale, shift)``.

    Examples
    --------
    >>> import torch
    >>> mlp = ScaleShiftMlp(in_features=64, out_features=128)
    >>> x = torch.randn(4, 64)
    >>> scale, shift = mlp(x)
    >>> scale.shape, shift.shape
    (torch.Size([4, 128]), torch.Size([4, 128]))

    See Also
    --------
    :class:`~physicsnemo.nn.module.mlp_layers.Mlp` :
        The MLP used internally to produce the concatenated (scale, shift)
        vector.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: Union[int, None] = None,
        hidden_layers: int = 0,
        activation_fn: Type[nn.Module] = nn.GELU,
    ):
        super().__init__()
        width = 2 * out_features if hidden_features is None else hidden_features
        # One hidden layer is always present; ``hidden_layers`` adds more.
        hidden_dims = [width for _ in range(hidden_layers + 1)]
        self.net = Mlp(
            in_features=in_features,
            hidden_features=hidden_dims,
            # Scale and shift are produced together and split in forward().
            out_features=2 * out_features,
            act_layer=activation_fn,
            drop=0.0,
            final_dropout=False,
        )

    def forward(
        self, x: Float[Tensor, "B D_in"]
    ) -> tuple[Float[Tensor, "B D_out"], Float[Tensor, "B D_out"]]:
        r"""Forward pass computing scale and shift parameters."""
        packed = self.net(x)
        scale, shift = torch.chunk(packed, 2, dim=1)
        # The scale is returned as an offset from the identity.
        return (1 + scale, shift)
class ModAFNOMlp(AFNOMlp):
    r"""Modulated MLP used inside ModAFNO.

    Extends :class:`AFNOMlp` by applying a scale-shift modulation, computed
    from a conditioning embedding, to the output of the first linear layer
    (before the activation).

    Parameters
    ----------
    in_features : int
        Input feature size.
    latent_features : int
        Latent feature size.
    out_features : int
        Output feature size.
    mod_features : int
        Modulation embedding feature size.
    activation_fn : nn.Module, optional, default=nn.GELU()
        Activation function.
    drop : float, optional, default=0.0
        Drop out rate.
    scale_shift_kwargs : dict, optional
        Options to the MLP that computes the scale-shift parameters.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(*, D_{in})`.
    mod_embed : torch.Tensor
        Modulation embedding of shape :math:`(B, D_{mod})`.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(*, D_{out})`.

    Examples
    --------
    >>> import torch
    >>> mlp = ModAFNOMlp(
    ...     in_features=64, latent_features=128, out_features=64, mod_features=32
    ... )
    >>> x = torch.randn(4, 16, 16, 64)
    >>> mod_embed = torch.randn(4, 32)
    >>> output = mlp(x, mod_embed)
    >>> output.shape
    torch.Size([4, 16, 16, 64])
    """

    def __init__(
        self,
        in_features: int,
        latent_features: int,
        out_features: int,
        mod_features: int,
        activation_fn: nn.Module = nn.GELU(),
        drop: float = 0.0,
        scale_shift_kwargs: Union[dict, None] = None,
    ):
        super().__init__(
            in_features=in_features,
            latent_features=latent_features,
            out_features=out_features,
            activation_fn=activation_fn,
            drop=drop,
        )
        extra_kwargs = {} if scale_shift_kwargs is None else scale_shift_kwargs
        # The modulation acts on the latent features, hence their width.
        self.scale_shift = ScaleShiftMlp(mod_features, latent_features, **extra_kwargs)

    def forward(
        self,
        x: Float[Tensor, "*dims D_in"],
        mod_embed: Float[Tensor, "B D_mod"],
    ) -> Float[Tensor, "*dims D_out"]:
        r"""Forward pass with modulation."""
        # Compute scale and shift from modulation embedding
        scale, shift = self.scale_shift(mod_embed)

        # Insert singleton axes between the first and last dimension so the
        # modulation broadcasts against an input of rank ``x.ndim``.
        singleton_dims = (1,) * (x.ndim - 2)
        broadcast_shape = (scale.shape[0],) + singleton_dims + (scale.shape[1],)
        scale = scale.view(*broadcast_shape)
        shift = shift.view(*broadcast_shape)

        # Apply modulated MLP
        hidden = self.fc1(x)
        hidden = hidden * scale + shift
        hidden = self.drop(self.act(hidden))
        return self.drop(self.fc2(hidden))
class ModAFNO2DLayer(AFNO2DLayer):
    r"""Modulated AFNO spectral convolution layer.

    Extends :class:`AFNO2DLayer` with a scale-shift modulation in the
    spectral domain, applied to the result of the first block-diagonal
    multiplication before the ReLU.

    Parameters
    ----------
    hidden_size : int
        Feature dimensionality.
    mod_features : int
        Number of modulation features.
    num_blocks : int, optional, default=8
        Number of blocks used in the block diagonal weight matrix.
    sparsity_threshold : float, optional, default=0.01
        Sparsity threshold (softshrink) of spectral features.
    hard_thresholding_fraction : float, optional, default=1
        Threshold for limiting number of modes used, in range ``[0, 1]``.
    hidden_size_factor : int, optional, default=1
        Factor to increase spectral features by after weight multiplication.
    scale_shift_kwargs : dict, optional
        Options to the MLP that computes the scale-shift parameters.
    scale_shift_mode : Literal["complex", "real"], optional, default="complex"
        If ``"complex"``, compute the scale-shift operation using complex
        operations. If ``"real"``, use real operations.

    Raises
    ------
    ValueError
        If ``scale_shift_mode`` is neither ``"complex"`` nor ``"real"``.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, H, W, C)`.
    mod_embed : torch.Tensor
        Modulation embedding of shape :math:`(B, D_{mod})`.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(B, H, W, C)`.

    Examples
    --------
    >>> import torch
    >>> layer = ModAFNO2DLayer(hidden_size=64, mod_features=32, num_blocks=8)
    >>> x = torch.randn(4, 16, 16, 64)
    >>> mod_embed = torch.randn(4, 32)
    >>> output = layer(x, mod_embed)
    >>> output.shape
    torch.Size([4, 16, 16, 64])
    """

    def __init__(
        self,
        hidden_size: int,
        mod_features: int,
        num_blocks: int = 8,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1,
        hidden_size_factor: int = 1,
        scale_shift_kwargs: Union[dict, None] = None,
        scale_shift_mode: Literal["complex", "real"] = "complex",
    ):
        super().__init__(
            hidden_size=hidden_size,
            num_blocks=num_blocks,
            sparsity_threshold=sparsity_threshold,
            hard_thresholding_fraction=hard_thresholding_fraction,
            hidden_size_factor=hidden_size_factor,
        )

        if scale_shift_mode not in ("complex", "real"):
            raise ValueError("scale_shift_mode must be 'real' or 'complex'")
        self.scale_shift_mode = scale_shift_mode

        # Complex modulation needs separate real and imaginary coefficients,
        # i.e. twice as many outputs as real modulation.
        if scale_shift_mode == "real":
            self.channel_mul = 1
        else:
            self.channel_mul = 2

        extra_kwargs = {} if scale_shift_kwargs is None else scale_shift_kwargs
        self.scale_shift = ScaleShiftMlp(
            mod_features,
            self.num_blocks
            * self.block_size
            * self.hidden_size_factor
            * self.channel_mul,
            **extra_kwargs,
        )

    def forward(
        self,
        x: Float[Tensor, "B H W C"],
        mod_embed: Float[Tensor, "B D_mod"],
    ) -> Float[Tensor, "B H W C"]:
        r"""Forward pass with modulation."""
        residual = x
        in_dtype = x.dtype
        # The spectral computation is carried out in float32.
        x = x.float()
        B, H, W, C = x.shape
        half_w = W // 2 + 1

        # Apply 2D FFT in the spatial dimensions
        x = fft.rfft2(x, dim=(1, 2), norm="ortho")
        x_real, x_imag = fft.real(x), fft.imag(x)
        x_real = x_real.reshape(B, H, half_w, self.num_blocks, self.block_size)
        x_imag = x_imag.reshape(B, H, half_w, self.num_blocks, self.block_size)

        o1_shape = (
            B,
            H,
            half_w,
            self.num_blocks,
            self.block_size * self.hidden_size_factor,
        )
        # Modulation coefficients are viewed as
        # (B, channel_mul, 1, num_blocks, block_width).
        scale_shift_shape = (B, self.channel_mul, 1, o1_shape[3], o1_shape[4])

        # Zero-initialised buffers: entries outside the kept-mode window
        # below are never written and therefore stay zero.
        o1_real = torch.zeros(o1_shape, device=x.device)
        o1_imag = torch.zeros(o1_shape, device=x.device)
        o2 = torch.zeros(x_real.shape + (2,), device=x.device)

        # NOTE(review): AFNO2DLayer.forward uses ``H // 2 + 1`` here; the two
        # expressions differ whenever W < H - confirm this is intended.
        total_modes = min(H, W) // 2 + 1
        kept_modes = int(total_modes * self.hard_thresholding_fraction)

        # Index tuple equivalent to
        # ``[:, total_modes - kept_modes : total_modes + kept_modes, :kept_modes]``.
        window = (
            slice(None),
            slice(total_modes - kept_modes, total_modes + kept_modes),
            slice(None, kept_modes),
        )
        eq = "nyxbi,bio->nyxbo"
        in_real = x_real[window]
        in_imag = x_imag[window]

        # First complex block-diagonal multiplication (no ReLU yet).
        o1_re = (
            torch.einsum(eq, in_real, self.w1[0])
            - torch.einsum(eq, in_imag, self.w1[1])
            + self.b1[0]
        )
        o1_im = (
            torch.einsum(eq, in_imag, self.w1[0])
            + torch.einsum(eq, in_real, self.w1[1])
            + self.b1[1]
        )

        # scale-shift operation
        scale, shift = self.scale_shift(mod_embed)
        scale = scale.view(*scale_shift_shape)
        shift = shift.view(*scale_shift_shape)
        if self.scale_shift_mode == "real":
            # The same real coefficients act on both parts.
            o1_re = o1_re * scale + shift
            o1_im = o1_im * scale + shift
        elif self.scale_shift_mode == "complex":
            scale_re, scale_im = torch.chunk(scale, 2, dim=1)
            shift_re, shift_im = torch.chunk(shift, 2, dim=1)
            # Both results are computed from the unmodulated values before
            # either name is rebound.
            mod_re = o1_re * scale_re - o1_im * scale_im + shift_re
            mod_im = o1_im * scale_re + o1_re * scale_im + shift_im
            o1_re, o1_im = mod_re, mod_im

        o1_real[window] = F.relu(o1_re)
        o1_imag[window] = F.relu(o1_im)

        # Second complex multiplication; the trailing axis of ``o2`` stores
        # the real part at index 0 and the imaginary part at index 1.
        hid_real = o1_real[window]
        hid_imag = o1_imag[window]
        o2[window + (Ellipsis, 0)] = (
            torch.einsum(eq, hid_real, self.w2[0])
            - torch.einsum(eq, hid_imag, self.w2[1])
            + self.b2[0]
        )
        o2[window + (Ellipsis, 1)] = (
            torch.einsum(eq, hid_imag, self.w2[0])
            + torch.einsum(eq, hid_real, self.w2[1])
            + self.b2[1]
        )

        x = F.softshrink(o2, lambd=self.sparsity_threshold)
        x = fft.view_as_complex(x)
        # TODO(akamenev): replace the following branching with
        # a one-liner, something like x.reshape(..., -1).squeeze(-1),
        # but this currently fails during ONNX export.
        if torch.onnx.is_in_onnx_export():
            merged_shape = (B, H, half_w, C, 2)
        else:
            merged_shape = (B, H, half_w, C)
        x = x.reshape(merged_shape)

        # Using ONNX friendly FFT functions
        x = fft.irfft2(x, s=(H, W), dim=(1, 2), norm="ortho")
        x = x.type(in_dtype)
        return x + residual