# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from functools import partial
from typing import Callable, List, Optional, Sequence

import torch
import torch.nn as nn
from jaxtyping import Float

import physicsnemo  # noqa: F401 for docs
from physicsnemo.core.meta import ModelMetaData
from physicsnemo.core.module import Module

# Import AFNO layers from physicsnemo.nn
from physicsnemo.nn import AFNO2DLayer, AFNOMlp, AFNOPatchEmbed
# Backward compatibility alias: keeps ``PatchEmbed`` importable from this
# module as a name for the patch-embedding layer ``AFNOPatchEmbed``.
PatchEmbed = AFNOPatchEmbed

# Short name for the tensor type used in this module's jaxtyping annotations.
Tensor = torch.Tensor
class Block(Module):
    r"""AFNO block consisting of spectral convolution and MLP.

    Parameters
    ----------
    embed_dim : int
        Embedded feature dimensionality.
    num_blocks : int, optional, default=8
        Number of blocks used in the block diagonal weight matrix.
    mlp_ratio : float, optional, default=4.0
        Ratio of MLP latent variable size to input feature size.
    drop : float, optional, default=0.0
        Drop out rate in MLP.
    activation_fn : nn.Module, optional, default=None
        Activation function used in MLP. When ``None``, a new ``nn.GELU()``
        is created for this block (previously a single ``nn.GELU()`` instance
        created at definition time was shared by every block).
    norm_layer : Callable[..., nn.Module], optional, default=nn.LayerNorm
        Factory for the normalization layers; it is called as
        ``norm_layer(embed_dim)`` twice (``norm1`` and ``norm2``).
    double_skip : bool, optional, default=True
        Whether to use double skip connections.
    sparsity_threshold : float, optional, default=0.01
        Sparsity threshold (softshrink) of spectral features.
    hard_thresholding_fraction : float, optional, default=1.0
        Threshold for limiting number of modes used, in range ``[0, 1]``.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, H, W, C)` where :math:`B` is batch size,
        :math:`H, W` are spatial dimensions, and :math:`C` is ``embed_dim``.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(B, H, W, C)`.

    Examples
    --------
    >>> import torch
    >>> from physicsnemo.models.afno.afno import Block
    >>> block = Block(embed_dim=64, num_blocks=8)
    >>> x = torch.randn(2, 8, 8, 64) # (B, H, W, C)
    >>> out = block(x)
    >>> out.shape
    torch.Size([2, 8, 8, 64])
    """

    def __init__(
        self,
        embed_dim: int,
        num_blocks: int = 8,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
        activation_fn: Optional[nn.Module] = None,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        double_skip: bool = True,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1.0,
    ):
        super().__init__()
        if activation_fn is None:
            # Build a fresh module per block instead of sharing one instance
            # that was created once when the function was defined.
            activation_fn = nn.GELU()

        self.norm1 = norm_layer(embed_dim)
        self.filter = AFNO2DLayer(
            embed_dim, num_blocks, sparsity_threshold, hard_thresholding_fraction
        )
        self.norm2 = norm_layer(embed_dim)

        mlp_latent_dim = int(embed_dim * mlp_ratio)
        self.mlp = AFNOMlp(
            in_features=embed_dim,
            latent_features=mlp_latent_dim,
            out_features=embed_dim,
            activation_fn=activation_fn,
            drop=drop,
        )
        self.double_skip = double_skip

    def forward(self, x: Float[Tensor, "B H W C"]) -> Float[Tensor, "B H W C"]:
        r"""Forward pass of the AFNO block.

        Computes ``norm1 -> filter`` and, when ``double_skip`` is set, adds the
        block input as a residual before ``norm2 -> mlp``. The MLP output is
        then added to the most recent residual (the post-filter sum when
        ``double_skip`` is set, otherwise the original block input).
        """
        residual = x
        x = self.norm1(x)
        x = self.filter(x)

        if self.double_skip:
            x = x + residual
            residual = x

        x = self.norm2(x)
        x = self.mlp(x)
        x = x + residual
        return x
@dataclass
class MetaData(ModelMetaData):
    r"""Capability flags for the AFNO model.

    An instance is passed to ``Module.__init__`` as ``meta=MetaData()`` in
    ``AFNO.__init__``. How each flag is consumed is defined by
    ``ModelMetaData`` and physicsnemo tooling, which are not visible here.
    NOTE(review): field order defines the positional constructor signature;
    do not reorder.
    """

    # Optimization
    jit: bool = False  # ONNX Ops Conflict
    cuda_graphs: bool = True
    amp: bool = True
    # Inference
    onnx_cpu: bool = False  # No FFT op on CPU
    onnx_gpu: bool = True
    onnx_runtime: bool = True
    # Physics informed
    var_dim: int = 1
    func_torch: bool = False
    auto_grad: bool = False
class AFNO(Module):
    r"""Adaptive Fourier neural operator (AFNO) model.

    AFNO is a model that is designed for 2D images only. It combines patch
    embedding with spectral convolution blocks in the Fourier domain.

    Parameters
    ----------
    inp_shape : List[int]
        Input image dimensions as ``[height, width]``.
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    patch_size : Sequence[int], optional, default=(16, 16)
        Size of image patches as ``[patch_height, patch_width]``. Both entries
        must be positive and divide the matching ``inp_shape`` entry.
    embed_dim : int, optional, default=256
        Embedded channel size.
    depth : int, optional, default=4
        Number of AFNO layers.
    mlp_ratio : float, optional, default=4.0
        Ratio of layer MLP latent variable size to input feature size.
    drop_rate : float, optional, default=0.0
        Drop out rate in layer MLPs.
    num_blocks : int, optional, default=16
        Number of blocks in the block-diag frequency weight matrices.
    sparsity_threshold : float, optional, default=0.01
        Sparsity threshold (softshrink) of spectral features.
    hard_thresholding_fraction : float, optional, default=1.0
        Threshold for limiting number of modes used, in range ``[0, 1]``.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, C_{in}, H, W)` where :math:`B` is batch
        size, :math:`C_{in}` is the number of input channels, and :math:`H, W` are
        spatial dimensions matching ``inp_shape``.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(B, C_{out}, H, W)` where :math:`C_{out}` is
        ``out_channels``.

    Raises
    ------
    ValueError
        At construction, raised if ``inp_shape`` or ``patch_size`` does not have
        length 2, if a patch dimension is not positive, or if ``inp_shape`` is not
        divisible by ``patch_size``. In ``forward`` (outside ``torch.compile``),
        raised if the input shape is not ``(B, in_channels, *inp_shape)``.

    Examples
    --------
    >>> import torch
    >>> import physicsnemo
    >>> model = physicsnemo.models.afno.AFNO(
    ...     inp_shape=[32, 32],
    ...     in_channels=2,
    ...     out_channels=1,
    ...     patch_size=(8, 8),
    ...     embed_dim=16,
    ...     depth=2,
    ...     num_blocks=2,
    ... )
    >>> input = torch.randn(32, 2, 32, 32) # (N, C, H, W)
    >>> output = model(input)
    >>> output.size()
    torch.Size([32, 1, 32, 32])

    See Also
    --------
    :class:`~physicsnemo.models.afno.distributed.DistributedAFNO` :
        Distributed (model-parallel) AFNO for multi-GPU training.
    `Adaptive Fourier Neural Operator (AFNO) <https://arxiv.org/abs/2111.13587>`_ :
        Original AFNO paper.
    """

    def __init__(
        self,
        inp_shape: List[int],
        in_channels: int,
        out_channels: int,
        patch_size: Sequence[int] = (16, 16),
        embed_dim: int = 256,
        depth: int = 4,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        num_blocks: int = 16,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1.0,
    ) -> None:
        super().__init__(meta=MetaData())
        if len(inp_shape) != 2:
            raise ValueError("inp_shape should be a list of length 2")
        if len(patch_size) != 2:
            raise ValueError("patch_size should be a list of length 2")

        # Work on private list copies so the model never aliases (and can never
        # mutate) the caller's sequence or a shared default.
        inp_shape = list(inp_shape)
        patch_size = list(patch_size)

        if any(p <= 0 for p in patch_size):
            # Without this guard a zero patch size fails below with an opaque
            # ZeroDivisionError from the modulo check.
            raise ValueError(
                f"patch_size {patch_size} should contain positive integers"
            )
        if inp_shape[0] % patch_size[0] != 0 or inp_shape[1] % patch_size[1] != 0:
            raise ValueError(
                f"input shape {inp_shape} should be divisible by patch_size {patch_size}"
            )

        self.in_chans = in_channels
        self.out_chans = out_channels
        self.inp_shape = inp_shape
        self.patch_size = patch_size
        self.num_features = self.embed_dim = embed_dim
        self.num_blocks = num_blocks
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = AFNOPatchEmbed(
            inp_shape=inp_shape,
            in_channels=self.in_chans,
            patch_size=self.patch_size,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Patch-grid size; exact thanks to the divisibility check above.
        self.h = inp_shape[0] // self.patch_size[0]
        self.w = inp_shape[1] // self.patch_size[1]

        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim=embed_dim,
                    num_blocks=self.num_blocks,
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    norm_layer=norm_layer,
                    sparsity_threshold=sparsity_threshold,
                    hard_thresholding_fraction=hard_thresholding_fraction,
                )
                for _ in range(depth)
            ]
        )

        # Each patch token is projected to out_channels * patch_h * patch_w
        # values, which forward() unpacks back into a full-resolution image.
        self.head = nn.Linear(
            embed_dim,
            self.out_chans * self.patch_size[0] * self.patch_size[1],
            bias=False,
        )

        torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        r"""Initialize model weights.

        ``nn.Linear`` weights get a truncated normal (std 0.02) and zero bias.
        ``nn.LayerNorm`` layers get unit weight and zero bias. Affine parameters
        that are absent (``None``) are skipped.

        Parameters
        ----------
        m : nn.Module
            Module to initialize.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # LayerNorm can be built without affine parameters
            # (elementwise_affine=False), in which case these are None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def _forward_features(
        self, x: Float[Tensor, "B C H W"]
    ) -> Float[Tensor, "B H W D"]:
        r"""Forward pass of core AFNO feature extraction.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape :math:`(B, C_{in}, H, W)`.

        Returns
        -------
        torch.Tensor
            Features of shape :math:`(B, h, w, D)` where :math:`h, w` are patch
            grid dimensions and :math:`D` is ``embed_dim``.
        """
        B = x.shape[0]

        # Embed patches and add positional encoding
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Reshape to 2D grid and apply blocks
        x = x.reshape(B, self.h, self.w, self.embed_dim)
        for blk in self.blocks:
            x = blk(x)
        return x

    def forward(self, x: Float[Tensor, "B C_in H W"]) -> Float[Tensor, "B C_out H W"]:
        r"""Forward pass of the AFNO model."""
        # Input validation: single check against expected shape (B, in_chans, H, W).
        # Skipped while torch.compile traces the graph.
        if not torch.compiler.is_compiling():
            expected = (
                self.in_chans,
                self.inp_shape[0],
                self.inp_shape[1],
            )
            if x.ndim != 4 or (x.shape[1], x.shape[2], x.shape[3]) != expected:
                raise ValueError(
                    f"Expected input shape (B, {expected[0]}, {expected[1]}, {expected[2]}), "
                    f"got {tuple(x.shape)}"
                )

        # Extract features through AFNO blocks: (B, h, w, embed_dim)
        x = self._forward_features(x)

        # Project to output channels: (B, h, w, patch_h * patch_w * C_out)
        x = self.head(x)

        # Unpatchify: (B, h, w, patch_h, patch_w, C_out)
        #   -> (B, C_out, h, patch_h, w, patch_w) -> (B, C_out, H, W)
        out = x.view(list(x.shape[:-1]) + [self.patch_size[0], self.patch_size[1], -1])
        out = torch.permute(out, (0, 5, 1, 3, 2, 4))
        out = out.reshape(list(out.shape[:2]) + [self.inp_shape[0], self.inp_shape[1]])
        return out