# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from functools import partial
from typing import List, Literal, Union
import torch
import torch.nn as nn
from jaxtyping import Float
import physicsnemo # noqa: F401 for docs
from physicsnemo.core.meta import ModelMetaData
from physicsnemo.core.module import Module
# Import AFNO layers from physicsnemo.nn
from physicsnemo.nn import (
AFNO2DLayer,
AFNOMlp,
AFNOPatchEmbed,
ModAFNO2DLayer,
ModAFNOMlp,
)
from .modembed import ModEmbedNet
# Short alias for ``torch.Tensor``; used inside the jaxtyping ``Float[Tensor, ...]``
# shape annotations in this module.
Tensor = torch.Tensor
# Backward compatibility alias: presumably kept so existing code importing
# ``PatchEmbed`` from this module keeps working. It is the same class object
# as ``AFNOPatchEmbed`` from ``physicsnemo.nn``.
PatchEmbed = AFNOPatchEmbed
class Block(Module):
    r"""Modulated AFNO block with spectral convolution and MLP.

    Parameters
    ----------
    embed_dim : int
        Embedded feature dimensionality.
    mod_dim : int
        Modulation input dimensionality.
    num_blocks : int, optional, default=8
        Number of blocks used in the block diagonal weight matrix.
    mlp_ratio : float, optional, default=4.0
        Ratio of MLP latent variable size to input feature size.
    drop : float, optional, default=0.0
        Drop out rate in MLP.
    activation_fn : nn.Module, optional, default=nn.GELU()
        Activation function used in MLP. The default is a single module
        instance created once when this signature is defined, so all blocks
        built with the default share that instance.
    norm_layer : Callable[[int], nn.Module], optional, default=nn.LayerNorm
        Normalization layer factory. It is called with ``embed_dim`` to build
        the normalization layers applied before the filter and before the MLP.
    double_skip : bool, optional, default=True
        Whether to use double skip connections. If ``True``, a residual
        connection is also added around the spectral filter. The MLP is always
        wrapped in a residual connection.
    sparsity_threshold : float, optional, default=0.01
        Sparsity threshold (softshrink) of spectral features.
    hard_thresholding_fraction : float, optional, default=1.0
        Threshold for limiting number of modes used, in range ``[0, 1]``.
    modulate_filter : bool, optional, default=True
        Whether to compute the modulation for the FFT filter.
    modulate_mlp : bool, optional, default=True
        Whether to compute the modulation for the MLP.
    scale_shift_mode : Literal["complex", "real"], optional, default="real"
        If ``"complex"``, compute the scale-shift operation using complex
        operations. If ``"real"``, use real operations.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, H, W, C)`.
    mod_embed : torch.Tensor
        Modulation embedding of shape :math:`(B, D_{mod})`.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(B, H, W, C)`.

    Examples
    --------
    >>> import torch
    >>> from physicsnemo.models.afno.modafno import Block
    >>> block = Block(embed_dim=64, mod_dim=32, num_blocks=8)
    >>> x = torch.randn(2, 8, 8, 64)
    >>> mod_embed = torch.randn(2, 32)
    >>> out = block(x, mod_embed)
    >>> out.shape
    torch.Size([2, 8, 8, 64])
    """

    def __init__(
        self,
        embed_dim: int,
        mod_dim: int,
        num_blocks: int = 8,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
        activation_fn: nn.Module = nn.GELU(),
        norm_layer: nn.Module = nn.LayerNorm,
        double_skip: bool = True,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1.0,
        modulate_filter: bool = True,
        modulate_mlp: bool = True,
        scale_shift_mode: Literal["complex", "real"] = "real",
    ):
        super().__init__()

        # These flags select the modulated / unmodulated call path in
        # ``apply_filter`` and ``apply_mlp``. Lambdas stored on the instance
        # are deliberately not used: a lambda closing over ``self`` cannot be
        # pickled, and ``copy.deepcopy`` copies functions by reference, so a
        # deep-copied block would keep calling the ORIGINAL block's submodules.
        self.double_skip = double_skip
        self.modulate_filter = modulate_filter
        self.modulate_mlp = modulate_mlp

        self.norm1 = norm_layer(embed_dim)
        if modulate_filter:
            self.filter = ModAFNO2DLayer(
                embed_dim,
                mod_dim,
                num_blocks,
                sparsity_threshold,
                hard_thresholding_fraction,
                scale_shift_mode=scale_shift_mode,
            )
        else:
            self.filter = AFNO2DLayer(
                embed_dim, num_blocks, sparsity_threshold, hard_thresholding_fraction
            )

        self.norm2 = norm_layer(embed_dim)
        mlp_latent_dim = int(embed_dim * mlp_ratio)
        if modulate_mlp:
            self.mlp = ModAFNOMlp(
                in_features=embed_dim,
                latent_features=mlp_latent_dim,
                out_features=embed_dim,
                mod_features=mod_dim,
                activation_fn=activation_fn,
                drop=drop,
            )
        else:
            self.mlp = AFNOMlp(
                in_features=embed_dim,
                latent_features=mlp_latent_dim,
                out_features=embed_dim,
                activation_fn=activation_fn,
                drop=drop,
            )

    def apply_filter(
        self,
        x: Float[Tensor, "B H W C"],
        mod_embed: Float[Tensor, "B D_mod"],
    ) -> Float[Tensor, "B H W C"]:
        r"""Apply the spectral filter.

        ``mod_embed`` is forwarded only when the filter is modulated; otherwise
        it is ignored.
        """
        if self.modulate_filter:
            return self.filter(x, mod_embed)
        return self.filter(x)

    def apply_mlp(
        self,
        x: Float[Tensor, "B H W C"],
        mod_embed: Float[Tensor, "B D_mod"],
    ) -> Float[Tensor, "B H W C"]:
        r"""Apply the channel MLP.

        ``mod_embed`` is forwarded only when the MLP is modulated; otherwise it
        is ignored.
        """
        if self.modulate_mlp:
            return self.mlp(x, mod_embed)
        return self.mlp(x)

    def forward(
        self,
        x: Float[Tensor, "B H W C"],
        mod_embed: Float[Tensor, "B D_mod"],
    ) -> Float[Tensor, "B H W C"]:
        r"""Forward pass of the modulated AFNO block."""
        residual = x
        x = self.norm1(x)
        x = self.apply_filter(x, mod_embed)

        if self.double_skip:
            # Optional residual around the filter; the MLP residual below then
            # starts from this filtered-plus-skip tensor.
            x = x + residual
            residual = x

        x = self.norm2(x)
        x = self.apply_mlp(x, mod_embed)
        x = x + residual
        return x
@dataclass
class MetaData(ModelMetaData):
    r"""Capability metadata for :class:`ModAFNO`.

    An instance is handed to the PhysicsNeMo ``Module`` base class via
    ``super().__init__(meta=MetaData())`` in ``ModAFNO.__init__``.
    """

    # Optimization
    jit: bool = False  # ONNX Ops Conflict
    cuda_graphs: bool = True
    amp: bool = True
    # Inference
    onnx_cpu: bool = False  # No FFT op on CPU
    onnx_gpu: bool = True
    onnx_runtime: bool = True
    # Physics informed
    # NOTE(review): meaning of ``var_dim`` is defined by ``ModelMetaData``;
    # confirm against the base class documentation.
    var_dim: int = 1
    func_torch: bool = False
    auto_grad: bool = False
class ModAFNO(Module):
    r"""Modulated Adaptive Fourier neural operator (ModAFNO) model.

    ModAFNO extends AFNO with modulation capabilities for conditioning on
    auxiliary inputs (e.g., time, parameters).

    Parameters
    ----------
    inp_shape : List[int]
        Input image dimensions as ``[height, width]``.
    in_channels : int, optional, default=155
        Number of input channels.
    out_channels : int, optional, default=73
        Number of output channels.
    embed_model : dict, optional
        Dictionary of arguments to pass to the :class:`ModEmbedNet` embedding
        model. Ignored when both ``modulate_filter`` and ``modulate_mlp`` are
        ``False``. In that case the modulation input is used unchanged, so its
        last dimension must equal ``mod_dim``.
    patch_size : List[int], optional, default=[2, 2]
        Size of image patches as ``[patch_height, patch_width]``.
    embed_dim : int, optional, default=512
        Embedded channel size.
    mod_dim : int, optional, default=64
        Modulation input dimensionality.
    modulate_filter : bool, optional, default=True
        Whether to compute the modulation for the FFT filter.
    modulate_mlp : bool, optional, default=True
        Whether to compute the modulation for the MLP.
    scale_shift_mode : Literal["complex", "real"], optional, default="complex"
        If ``"complex"``, compute the scale-shift operation using complex
        operations. If ``"real"``, use real operations.
    depth : int, optional, default=12
        Number of AFNO layers.
    mlp_ratio : float, optional, default=2.0
        Ratio of layer MLP latent variable size to input feature size.
    drop_rate : float, optional, default=0.0
        Drop out rate in layer MLPs.
    num_blocks : int, optional, default=1
        Number of blocks in the block-diag frequency weight matrices.
    sparsity_threshold : float, optional, default=0.01
        Sparsity threshold (softshrink) of spectral features.
    hard_thresholding_fraction : float, optional, default=1.0
        Threshold for limiting number of modes used, in range ``[0, 1]``.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, C_{in}, H, W)`.
    mod : torch.Tensor
        Modulation input of shape :math:`(B, 1)` or :math:`(B,)`.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(B, C_{out}, H, W)`.

    Examples
    --------
    >>> import torch
    >>> from physicsnemo.models.afno import ModAFNO
    >>> model = ModAFNO(
    ...     inp_shape=[32, 32],
    ...     in_channels=2,
    ...     out_channels=1,
    ...     patch_size=(8, 8),
    ...     embed_dim=16,
    ...     depth=2,
    ...     num_blocks=2,
    ... )
    >>> input = torch.randn(32, 2, 32, 32) # (N, C, H, W)
    >>> time = torch.full((32, 1), 0.5)
    >>> output = model(input, time)
    >>> output.size()
    torch.Size([32, 1, 32, 32])

    See Also
    --------
    `Modulated Adaptive Fourier Neural Operators for Temporal Interpolation of Weather Forecasts <https://arxiv.org/abs/2410.18904>`_ :
        Leinonen et al., arXiv:2410.18904 (2024).
    """

    def __init__(
        self,
        inp_shape: List[int],
        in_channels: int = 155,
        out_channels: int = 73,
        embed_model: Union[dict, None] = None,
        patch_size: List[int] = [2, 2],
        embed_dim: int = 512,
        mod_dim: int = 64,
        modulate_filter: bool = True,
        modulate_mlp: bool = True,
        scale_shift_mode: Literal["complex", "real"] = "complex",
        depth: int = 12,
        mlp_ratio: float = 2.0,
        drop_rate: float = 0.0,
        num_blocks: int = 1,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1.0,
    ) -> None:
        super().__init__(meta=MetaData())

        if len(inp_shape) != 2:
            raise ValueError("inp_shape should be a list of length 2")
        if len(patch_size) != 2:
            raise ValueError("patch_size should be a list of length 2")
        if not (
            inp_shape[0] % patch_size[0] == 0 and inp_shape[1] % patch_size[1] == 0
        ):
            raise ValueError(
                f"input shape {inp_shape} should be divisible by patch_size {patch_size}"
            )

        self.in_chans = in_channels
        self.out_chans = out_channels
        self.inp_shape = inp_shape
        # NOTE: ``patch_size`` (including the shared default list) is stored by
        # reference and only ever read, never mutated.
        self.patch_size = patch_size
        self.num_features = self.embed_dim = embed_dim
        self.num_blocks = num_blocks
        self.modulate_filter = modulate_filter
        self.modulate_mlp = modulate_mlp
        self.scale_shift_mode = scale_shift_mode
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = AFNOPatchEmbed(
            inp_shape=inp_shape,
            in_channels=self.in_chans,
            patch_size=self.patch_size,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Size of the patch grid that the blocks operate on.
        self.h = inp_shape[0] // self.patch_size[0]
        self.w = inp_shape[1] // self.patch_size[1]

        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim=embed_dim,
                    mod_dim=mod_dim,
                    num_blocks=self.num_blocks,
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    norm_layer=norm_layer,
                    sparsity_threshold=sparsity_threshold,
                    hard_thresholding_fraction=hard_thresholding_fraction,
                    modulate_filter=modulate_filter,
                    modulate_mlp=modulate_mlp,
                    scale_shift_mode=scale_shift_mode,
                )
                for _ in range(depth)
            ]
        )

        # Each token is projected to one patch worth of output pixels.
        self.head = nn.Linear(
            embed_dim,
            self.out_chans * self.patch_size[0] * self.patch_size[1],
            bias=False,
        )

        torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_weights)

        # NOTE: the modulation layers are created AFTER ``self.apply`` above,
        # so ``_init_weights`` is not applied to them; they keep their own
        # default initialization.
        self.mod_additive_proj = nn.Linear(mod_dim, embed_dim)
        if not (modulate_mlp or modulate_filter):
            self.mod_embed_net = nn.Identity()
        else:
            embed_model = {} if embed_model is None else embed_model
            self.mod_embed_net = ModEmbedNet(**embed_model)

        self.register_load_state_dict_pre_hook(self._migrate_legacy_checkpoint)

    @staticmethod
    def _migrate_legacy_checkpoint(
        module,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        r"""Remap legacy scale-shift keys to the new ModuleList layout.

        Previous checkpoints stored scale-shift parameters under
        ``scale_shift.net.{idx}.{param}``. The current implementation nests the
        layers under ``scale_shift.net.layers.{idx}.{param}``. This pre-hook
        rewrites legacy keys in-place to maintain compatibility.

        Only keys that start with ``prefix`` (i.e. that belong to this module)
        are rewritten. The ``state_dict`` passed to load hooks is the full dict
        of the outermost model, so without this restriction sibling modules'
        keys would also be renamed when ModAFNO is nested inside another model.
        A legacy key whose migrated name already exists is left untouched.

        Parameters
        ----------
        module : torch.nn.Module
            The module being loaded (unused; required by ``register_load_state_dict_pre_hook``).
        state_dict : dict
            State dict being loaded; modified in-place.
        prefix : str
            Key prefix of this module inside ``state_dict``; restricts which
            keys are migrated.
        local_metadata : dict, optional
            Local metadata (unused).
        strict : bool
            Whether strict loading is requested (unused).
        missing_keys : list of str
            List of missing keys (unused).
        unexpected_keys : list of str
            List of unexpected keys (unused).
        error_msgs : list of str
            Error messages (unused).

        Returns
        -------
        None
            Modifies ``state_dict`` in-place; no return value.
        """
        legacy_token = ".scale_shift.net."  # noqa: S105 - state_dict key token
        new_token = ".scale_shift.net.layers."  # noqa: S105 - state_dict key token

        # Snapshot the keys because the dict is mutated while iterating.
        for old_key in list(state_dict.keys()):
            if not old_key.startswith(prefix):
                continue
            if legacy_token not in old_key or new_token in old_key:
                continue
            new_key = old_key.replace(legacy_token, new_token)
            if new_key not in state_dict:
                state_dict[new_key] = state_dict.pop(old_key)

    def _init_weights(self, m: nn.Module) -> None:
        r"""Initialize model weights.

        ``nn.Linear`` weights get a truncated normal (std 0.02) and zero bias.
        ``nn.LayerNorm`` gets unit weight and zero bias. Affine parameters that
        are absent (``None``) are skipped. All other module types are left
        unchanged.

        Parameters
        ----------
        m : nn.Module
            Module to initialize.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # LayerNorm may be built without affine parameters
            # (elementwise_affine=False, or bias=False), in which case the
            # corresponding attribute is None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def _forward_features(
        self,
        x: Float[Tensor, "B C H W"],
        mod: Float[Tensor, "B mod_in"],
    ) -> Float[Tensor, "B H W D"]:
        r"""Forward pass of core ModAFNO feature extraction.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape :math:`(B, C_{in}, H, W)`.
        mod : torch.Tensor
            Modulation input of shape :math:`(B, 1)` or :math:`(B,)`.

        Returns
        -------
        torch.Tensor
            Features of shape :math:`(B, h, w, D)` where :math:`h, w` are patch
            grid dimensions and :math:`D` is ``embed_dim``.
        """
        B = x.shape[0]

        # Embed patches and add positional encoding
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Compute modulation embedding and add it to every token
        # (the unsqueeze inserts a token axis for broadcasting).
        mod_embed = self.mod_embed_net(mod)
        mod_additive = self.mod_additive_proj(mod_embed).unsqueeze(1)
        x = x + mod_additive

        # Reshape to 2D grid and apply modulated blocks
        x = x.reshape(B, self.h, self.w, self.embed_dim)
        for blk in self.blocks:
            x = blk(x, mod_embed=mod_embed)

        return x

    def forward(
        self,
        x: Float[Tensor, "B C_in H W"],
        mod: Float[Tensor, "B mod_in"],
    ) -> Float[Tensor, "B C_out H W"]:
        r"""Forward pass of the ModAFNO model.

        Raises
        ------
        ValueError
            If ``x`` is not of shape ``(B, in_channels, H, W)`` with ``(H, W)``
            equal to ``inp_shape``. The check is skipped while
            ``torch.compile`` is tracing.
        """
        # Input validation: single check for shape (B, in_channels, H, W)
        if not torch.compiler.is_compiling():
            expected = (
                self.in_chans,
                self.inp_shape[0],
                self.inp_shape[1],
            )
            if x.ndim != 4 or (x.shape[1], x.shape[2], x.shape[3]) != expected:
                raise ValueError(
                    f"Expected input shape (B, {expected[0]}, {expected[1]}, {expected[2]}), "
                    f"got {tuple(x.shape)}"
                )

        # Extract features through modulated AFNO blocks
        x = self._forward_features(x, mod)

        # Project to output channels
        x = self.head(x)

        # Un-patchify: (B, h, w, p0*p1*C) -> (B, h, w, p0, p1, C)
        # -> (B, C, h, p0, w, p1) -> (B, C, H, W)
        out = x.view(list(x.shape[:-1]) + [self.patch_size[0], self.patch_size[1], -1])
        out = torch.permute(out, (0, 5, 1, 3, 2, 4))
        out = out.reshape(list(out.shape[:2]) + [self.inp_shape[0], self.inp_shape[1]])

        return out