Fourier, FFT, and Spectral Layers#

class physicsnemo.nn.module.fourier_layers.FourierFilter(
in_features: int,
layer_size: int,
nr_layers: int,
input_scale: float,
)[source]#

Bases: Module

Fourier filter used in the multiplicative filter network

forward(x: Tensor) Tensor[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance itself afterwards instead of this method, since the former takes care of running the registered hooks while the latter silently ignores them.

reset_parameters() None[source]#

Resets parameters
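
In a multiplicative filter network, a Fourier filter applies a sinusoid to a learned linear map of the input. A minimal self-contained sketch of that computation (the class name and internals below are assumptions for illustration, not the actual FourierFilter implementation):

```python
import torch
import torch.nn as nn

class FourierFilterSketch(nn.Module):
    """Illustrative sketch of an MFN-style Fourier filter: sin(W x + b).

    Internals are assumptions for illustration; see FourierFilter above
    for the actual PhysicsNeMo API.
    """

    def __init__(self, in_features: int, layer_size: int, input_scale: float):
        super().__init__()
        self.linear = nn.Linear(in_features, layer_size)
        with torch.no_grad():
            # Scale the frequencies so the sinusoids cover the input domain
            self.linear.weight.mul_(input_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sin(self.linear(x))


x = torch.randn(8, 2)                      # (batch, in_features)
out = FourierFilterSketch(2, 64, 10.0)(x)  # (batch, layer_size), values in [-1, 1]
```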

class physicsnemo.nn.module.fourier_layers.FourierLayer(in_features: int, frequencies)[source]#

Bases: Module

Fourier layer used in the Fourier feature network

forward(x: Tensor) Tensor[source]#

Define the computation performed at every call.


class physicsnemo.nn.module.fourier_layers.FourierMLP(
input_features: int,
base_layer: int,
fourier_features: bool,
num_modes: int,
activation: Module | str,
)[source]#

Bases: Module

An MLP that optionally Fourier encodes the input features.

The encoded features are concatenated with the original inputs and then processed by the MLP.

Parameters:
  • input_features – The number of input features to the MLP.

  • base_layer – The number of neurons in the hidden layer of the MLP.

  • fourier_features – Whether to Fourier encode the input features.

  • num_modes – The number of modes to use for the Fourier encoding.

  • activation – The activation function to use in the MLP.

forward(x: Tensor) Tensor[source]#

Define the computation performed at every call.


class physicsnemo.nn.module.fourier_layers.GaborFilter(
in_features: int,
layer_size: int,
nr_layers: int,
input_scale: float,
alpha: float,
beta: float,
)[source]#

Bases: Module

Gabor filter used in the multiplicative filter network

forward(x: Tensor) Tensor[source]#

Define the computation performed at every call.


reset_parameters() None[source]#

Resets parameters
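
A Gabor filter multiplies the same kind of sinusoid by a Gaussian envelope, so each unit is localized in space as well as frequency. A hedged sketch (the internals, including how alpha and beta parameterize the envelope widths, are assumptions for illustration, not the actual GaborFilter code):

```python
import torch
import torch.nn as nn

class GaborFilterSketch(nn.Module):
    """Illustrative sketch of an MFN-style Gabor filter: a Gaussian envelope
    centered at mu times a sinusoid. Internals are assumptions."""

    def __init__(self, in_features, layer_size, input_scale, alpha=6.0, beta=1.0):
        super().__init__()
        self.linear = nn.Linear(in_features, layer_size)
        self.mu = nn.Parameter(2 * torch.rand(layer_size, in_features) - 1)
        # Envelope widths drawn from a Gamma(alpha, beta) distribution (assumption)
        self.gamma = nn.Parameter(
            torch.distributions.Gamma(alpha, beta).sample((layer_size,))
        )
        with torch.no_grad():
            self.linear.weight.mul_(input_scale)

    def forward(self, x):
        dist2 = (x.unsqueeze(1) - self.mu).pow(2).sum(-1)  # (batch, layer_size)
        return torch.exp(-0.5 * self.gamma * dist2) * torch.sin(self.linear(x))


x = torch.randn(8, 2)
out = GaborFilterSketch(2, 64, 10.0)(x)  # (8, 64)
```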

physicsnemo.nn.module.fourier_layers.fourier_encode(
coords: Tensor,
freqs: Tensor,
) Tensor[source]#

Vectorized Fourier feature encoding

Parameters:
  • coords – Tensor containing coordinates, of shape (batch_size, D)

  • freqs – Tensor containing frequencies, of shape (F,) (num frequencies)

Returns:

Tensor containing Fourier features, of shape (batch_size, D * 2 * F)
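
The documented shape contract can be reproduced with a short sketch; the exact ordering of the sin/cos features within the last dimension is an assumption for illustration:

```python
import torch

def fourier_encode_sketch(coords: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Sketch reproducing the documented contract (B, D) x (F,) -> (B, D * 2 * F)."""
    scaled = coords.unsqueeze(-1) * freqs  # (B, D, F) via broadcasting
    feats = torch.cat([scaled.sin(), scaled.cos()], dim=-1)  # (B, D, 2*F)
    return feats.reshape(coords.shape[0], -1)  # (B, D * 2 * F)


coords = torch.randn(5, 3)                    # batch_size=5, D=3
freqs = torch.tensor([1.0, 2.0, 4.0, 8.0])    # F=4
feats = fourier_encode_sketch(coords, freqs)  # shape (5, 24)
```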

class physicsnemo.nn.module.spectral_layers.SpectralConv1d(in_channels: int, out_channels: int, modes1: int)[source]#

Bases: Module

1D Fourier layer. It applies an FFT, a linear transform on the retained modes, and an inverse FFT.

Parameters:
  • in_channels (int) – Number of input channels

  • out_channels (int) – Number of output channels

  • modes1 (int) – Number of Fourier modes to multiply, at most floor(N/2) + 1

compl_mul1d(
input: Complex[Tensor, 'batch in_channels modes'],
weights: Float[Tensor, 'in_channels out_channels modes 2'],
) Complex[Tensor, 'batch out_channels modes'][source]#

Complex multiplication

Parameters:
  • input (Tensor) – Input tensor

  • weights (Tensor) – Weights tensor

Returns:

Product of complex multiplication

Return type:

Tensor

forward(
x: Float[Tensor, 'batch in_channels x'],
) Float[Tensor, 'batch out_channels x'][source]#

Define the computation performed at every call.


reset_parameters()[source]#

Reset spectral weights with distribution scale*U(0,1)
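
The documented pipeline (FFT, linear transform on the retained modes, inverse FFT) can be sketched functionally. Here the weights are a complex tensor of shape (in_channels, out_channels, modes1) for clarity; the actual module stores an equivalent real-valued tensor with a trailing dimension of 2, as the compl_mul1d signature above shows:

```python
import torch

def spectral_conv1d_sketch(x: torch.Tensor, weights: torch.Tensor, modes1: int) -> torch.Tensor:
    """Sketch: rfft, complex multiply on the lowest `modes1` modes, inverse rfft."""
    batch, _, n = x.shape
    x_ft = torch.fft.rfft(x)  # (batch, in_channels, n//2 + 1)
    out_ft = torch.zeros(batch, weights.shape[1], n // 2 + 1, dtype=torch.cfloat)
    # "Complex multiplication": contract over input channels, mode by mode
    out_ft[:, :, :modes1] = torch.einsum("bix,iox->box", x_ft[:, :, :modes1], weights)
    return torch.fft.irfft(out_ft, n=n)  # (batch, out_channels, n)


x = torch.randn(2, 3, 64)
w = torch.randn(3, 5, 16, dtype=torch.cfloat)
y = spectral_conv1d_sketch(x, w, modes1=16)  # shape (2, 5, 64)
```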

class physicsnemo.nn.module.spectral_layers.SpectralConv2d(
in_channels: int,
out_channels: int,
modes1: int,
modes2: int,
)[source]#

Bases: Module

2D Fourier layer. It applies an FFT, a linear transform on the retained modes, and an inverse FFT.

Parameters:
  • in_channels (int) – Number of input channels

  • out_channels (int) – Number of output channels

  • modes1 (int) – Number of Fourier modes to multiply in first dimension, at most floor(N/2) + 1

  • modes2 (int) – Number of Fourier modes to multiply in second dimension, at most floor(N/2) + 1

compl_mul2d(
input: Complex[Tensor, 'batch in_channels modes1 modes2'],
weights: Float[Tensor, 'in_channels out_channels modes1 modes2 2'],
) Complex[Tensor, 'batch out_channels modes1 modes2'][source]#

Complex multiplication

Parameters:
  • input (Tensor) – Input tensor

  • weights (Tensor) – Weights tensor

Returns:

Product of complex multiplication

Return type:

Tensor

forward(
x: Float[Tensor, 'batch in_channels h w'],
) Float[Tensor, 'batch out_channels h w'][source]#

Define the computation performed at every call.


reset_parameters()[source]#

Reset spectral weights with distribution scale*U(0,1)

class physicsnemo.nn.module.spectral_layers.SpectralConv3d(
in_channels: int,
out_channels: int,
modes1: int,
modes2: int,
modes3: int,
)[source]#

Bases: Module

3D Fourier layer. It applies an FFT, a linear transform on the retained modes, and an inverse FFT.

Parameters:
  • in_channels (int) – Number of input channels

  • out_channels (int) – Number of output channels

  • modes1 (int) – Number of Fourier modes to multiply in first dimension, at most floor(N/2) + 1

  • modes2 (int) – Number of Fourier modes to multiply in second dimension, at most floor(N/2) + 1

  • modes3 (int) – Number of Fourier modes to multiply in third dimension, at most floor(N/2) + 1

compl_mul3d(
input: Complex[Tensor, 'batch in_channels modes1 modes2 modes3'],
weights: Float[Tensor, 'in_channels out_channels modes1 modes2 modes3 2'],
) Complex[Tensor, 'batch out_channels modes1 modes2 modes3'][source]#

Complex multiplication

Parameters:
  • input (Tensor) – Input tensor

  • weights (Tensor) – Weights tensor

Returns:

Product of complex multiplication

Return type:

Tensor

forward(
x: Float[Tensor, 'batch in_channels d1 d2 d3'],
) Float[Tensor, 'batch out_channels d1 d2 d3'][source]#

Define the computation performed at every call.


reset_parameters()[source]#

Reset spectral weights with distribution scale*U(0,1)

class physicsnemo.nn.module.spectral_layers.SpectralConv4d(
in_channels: int,
out_channels: int,
modes1: int,
modes2: int,
modes3: int,
modes4: int,
)[source]#

Bases: Module

4D Fourier layer. It applies an FFT, a linear transform on the retained modes, and an inverse FFT.

Parameters:
  • in_channels (int) – Number of input channels

  • out_channels (int) – Number of output channels

  • modes1 (int) – Number of Fourier modes to multiply in first dimension, at most floor(N/2) + 1

  • modes2 (int) – Number of Fourier modes to multiply in second dimension, at most floor(N/2) + 1

  • modes3 (int) – Number of Fourier modes to multiply in third dimension, at most floor(N/2) + 1

  • modes4 (int) – Number of Fourier modes to multiply in fourth dimension, at most floor(N/2) + 1

compl_mul4d(
input: Complex[Tensor, 'batch in_channels modes1 modes2 modes3 modes4'],
weights: Float[Tensor, 'in_channels out_channels modes1 modes2 modes3 modes4 2'],
) Complex[Tensor, 'batch out_channels modes1 modes2 modes3 modes4'][source]#

Complex multiplication

Parameters:
  • input (Tensor) – Input tensor

  • weights (Tensor) – Weights tensor

Returns:

Product of complex multiplication

Return type:

Tensor

forward(
x: Float[Tensor, 'batch in_channels d1 d2 d3 d4'],
) Float[Tensor, 'batch out_channels d1 d2 d3 d4'][source]#

Define the computation performed at every call.


reset_parameters()[source]#

Reset spectral weights with distribution scale*U(0,1)

physicsnemo.nn.module.spectral_layers.calc_latent_derivatives(
x: Tensor,
domain_length: List[int] = 2,
) Tuple[List[Tensor], List[Tensor]][source]#

Compute first and second order derivatives of latent variables

physicsnemo.nn.module.spectral_layers.first_order_pino_grads(
u: Tensor,
ux: List[Tensor],
weights_1: Tensor,
weights_2: Tensor,
bias_1: Tensor,
) Tuple[Tensor][source]#

Compute first order derivatives of output variables

physicsnemo.nn.module.spectral_layers.fourier_derivatives(
x: Tensor,
ell: List[float],
) Tuple[Tensor, Tensor][source]#

Fourier derivative function for PINO
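
Spectral differentiation, the core of this helper, multiplies each Fourier coefficient by i·k, where k is the angular wavenumber implied by the domain length ell. A 1D sketch (the function name and signature here are illustrative, not the PhysicsNeMo API):

```python
import torch

def fourier_derivative_1d(u: torch.Tensor, ell: float) -> torch.Tensor:
    """Sketch of spectral differentiation on a periodic domain of length `ell`:
    multiply each Fourier coefficient by i*k, then transform back."""
    n = u.shape[-1]
    k = 2 * torch.pi * torch.fft.rfftfreq(n, d=ell / n)  # angular wavenumbers
    return torch.fft.irfft(1j * k * torch.fft.rfft(u), n=n)


# d/dx sin(x) on [0, 2*pi) recovers cos(x) to near machine precision
x = torch.linspace(0, 2 * torch.pi, 65)[:-1]
du = fourier_derivative_1d(torch.sin(x), ell=2 * torch.pi)
```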

physicsnemo.nn.module.spectral_layers.second_order_pino_grads(
u: Tensor,
ux: Tensor,
uxx: Tensor,
weights_1: Tensor,
weights_2: Tensor,
bias_1: Tensor,
) Tuple[Tensor][source]#

Compute second order derivatives of output variables

class physicsnemo.nn.module.fft.OnnxIrfft(*args, **kwargs)[source]#

Bases: Function

Auto-grad function to mimic irfft for ONNX exporting

Note

Should only be called during an ONNX export

static forward(ctx, input: Tensor) Tensor[source]#

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass


@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.

static symbolic(
g: Graph,
input: Value,
) Value[source]#

Symbolic representation for onnx graph

class physicsnemo.nn.module.fft.OnnxIrfft2(*args, **kwargs)[source]#

Bases: Function

Auto-grad function to mimic irfft2 for ONNX exporting.

Note

Should only be called during an ONNX export

static forward(ctx, input: Tensor) Tensor[source]#

Define the forward of the custom autograd Function. See OnnxIrfft.forward() above for the two ways to define forward and for how the ctx object is used.

static symbolic(
g: Graph,
input: Value,
) Value[source]#

Symbolic representation for onnx graph

class physicsnemo.nn.module.fft.OnnxRfft(*args, **kwargs)[source]#

Bases: Function

Auto-grad function to mimic rfft for ONNX exporting

Note

Should only be called during an ONNX export

static forward(ctx, input: Tensor) Tensor[source]#

Define the forward of the custom autograd Function. See OnnxIrfft.forward() above for the two ways to define forward and for how the ctx object is used.

static symbolic(
g: Graph,
input: Value,
) Value[source]#

Symbolic representation for onnx graph

class physicsnemo.nn.module.fft.OnnxRfft2(*args, **kwargs)[source]#

Bases: Function

Auto-grad function to mimic rfft2 for ONNX exporting

Note

Should only be called during an ONNX export

static forward(ctx, input: Tensor) Tensor[source]#

Define the forward of the custom autograd Function. See OnnxIrfft.forward() above for the two ways to define forward and for how the ctx object is used.

static symbolic(
g: Graph,
input: Value,
) Value[source]#

Symbolic representation for onnx graph

physicsnemo.nn.module.fft.imag(input: Tensor) Tensor[source]#

ONNX-compatible method to view the imaginary part of the input tensor

Parameters:

input (Tensor) – The input Tensor

Note

The function is equivalent to input.imag when not running in ONNX export mode

Raises:

AssertionError – If the input tensor's shape is not […, 2] during ONNX export, where the last dimension holds the real and imaginary components

physicsnemo.nn.module.fft.irfft(
input: Tensor,
n: int | None = None,
dim: int = -1,
norm: str | None = None,
) Tensor[source]#

ONNX-compatible method to compute the inverse of rfft.

Parameters:
  • input (Tensor) – Real input tensor

  • n (Optional[int], optional) – Output signal length, by default None

  • dim (int, optional) – Dimension along which to take the real IFFT, by default -1

  • norm (Optional[str], optional) – Normalization mode with options “forward”, “backward” and “ortho”. When set to None, normalization will default to backward (normalize by 1/n), by default None

Note

The function is equivalent to torch.fft.irfft when not running in ONNX export mode

physicsnemo.nn.module.fft.irfft2(
input: Tensor,
s: Tuple[int] | None = None,
dim: Tuple[int] = (-2, -1),
norm: str | None = None,
) Tensor[source]#

ONNX-compatible method to compute the inverse of rfft2.

Parameters:
  • input (Tensor) – Real input tensor

  • s (Optional[Tuple[int]], optional) – Signal size in the transformed dimensions, by default None

  • dim (Tuple[int], optional) – Dimensions along which to take the real 2D IFFT, by default (-2, -1)

  • norm (Optional[str], optional) – Normalization mode with options “forward”, “backward” and “ortho”. When set to None, normalization will default to backward (normalize by 1/n), by default None

Note

The function is equivalent to torch.fft.irfft2 when not running in ONNX export mode

physicsnemo.nn.module.fft.real(input: Tensor) Tensor[source]#

ONNX-compatible method to view the real part of the input tensor

Parameters:

input (Tensor) – The input Tensor

Note

The function is equivalent to input.real when not running in ONNX export mode

Raises:

AssertionError – If the input tensor's shape is not […, 2] during ONNX export, where the last dimension holds the real and imaginary components

physicsnemo.nn.module.fft.rfft(
input: Tensor,
n: int | None = None,
dim: int = -1,
norm: str | None = None,
) Tensor[source]#

ONNX-compatible method to compute the 1D Fourier transform of real-valued input.

Parameters:
  • input (Tensor) – Real input tensor

  • n (Optional[int], optional) – Signal length, by default None

  • dim (int, optional) – Dimension along which to take the real FFT, by default -1

  • norm (Optional[str], optional) – Normalization mode with options “forward”, “backward” and “ortho”. When set to None, normalization will default to backward (no normalization), by default None

Note

The function is equivalent to torch.fft.rfft when not running in ONNX export mode

physicsnemo.nn.module.fft.rfft2(
input: Tensor,
s: Tuple[int] | None = None,
dim: Tuple[int] = (-2, -1),
norm: str | None = None,
) Tensor[source]#

ONNX-compatible method to compute the 2D Fourier transform of real-valued input.

Parameters:
  • input (Tensor) – Real input tensor

  • s (Optional[Tuple[int]], optional) – Signal size in the transformed dimensions, by default None

  • dim (Tuple[int], optional) – Dimensions along which to take the real 2D FFT, by default (-2, -1)

  • norm (Optional[str], optional) – Normalization mode with options “forward”, “backward” and “ortho”. When set to None, normalization will default to backward (no normalization), by default None

Note

The function is equivalent to torch.fft.rfft2 when not running in ONNX export mode

physicsnemo.nn.module.fft.view_as_complex(input: Tensor) Tensor[source]#

ONNX-compatible method to view the input as a complex tensor

Parameters:

input (Tensor) – The input Tensor

Note

The function is equivalent to torch.view_as_complex when not running in ONNX export mode

Raises:

AssertionError – If the input tensor's shape is not […, 2] during ONNX export, where the last dimension holds the real and imaginary components
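
The […, 2] convention these assertions refer to is the standard real-view encoding of complex tensors, which the ONNX path must use since ONNX has no complex dtype. A round-trip sketch in plain torch:

```python
import torch

# During ONNX export, complex values travel as real tensors with a trailing
# dimension of size 2 (real, imaginary) -- the [..., 2] shape asserted above.
z = torch.fft.rfft(torch.randn(4, 16))  # complex tensor, shape (4, 9)
z_ri = torch.view_as_real(z)            # real tensor,    shape (4, 9, 2)
real_part, imag_part = z_ri[..., 0], z_ri[..., 1]
assert torch.equal(torch.view_as_complex(z_ri), z)
```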

Adaptive Fourier Neural Operator (AFNO) layers.

This module contains reusable AFNO building blocks that can be used in various AFNO-based architectures.

class physicsnemo.nn.module.afno_layers.AFNO2DLayer(*args, **kwargs)[source]#

Bases: Module

AFNO spectral convolution layer.

This layer performs spectral mixing using block-diagonal weight matrices in the Fourier domain with soft shrinkage for sparsity.

Parameters:
  • hidden_size (int) – Feature dimensionality.

  • num_blocks (int, optional, default=8) – Number of blocks used in the block diagonal weight matrix.

  • sparsity_threshold (float, optional, default=0.01) – Sparsity threshold (softshrink) of spectral features.

  • hard_thresholding_fraction (float, optional, default=1) – Threshold for limiting number of modes used, in range [0, 1].

  • hidden_size_factor (int, optional, default=1) – Factor to increase spectral features by after weight multiplication.

Forward:

x (torch.Tensor) – Input tensor of shape \((B, H, W, C)\) where \(B\) is batch size, \(H, W\) are spatial dimensions, and \(C\) is hidden_size.

Outputs:

torch.Tensor – Output tensor of shape \((B, H, W, C)\).

Examples

>>> import torch
>>> layer = AFNO2DLayer(hidden_size=64, num_blocks=8)
>>> x = torch.randn(4, 32, 32, 64)
>>> output = layer(x)
>>> output.shape
torch.Size([4, 32, 32, 64])
forward(
x: Float[Tensor, 'B H W C'],
) Float[Tensor, 'B H W C'][source]#

Forward pass of the AFNO spectral layer.

class physicsnemo.nn.module.afno_layers.AFNOMlp(*args, **kwargs)[source]#

Bases: Module

Fully connected multi-layer perceptron used inside AFNO.

Parameters:
  • in_features (int) – Input feature size.

  • latent_features (int) – Latent feature size.

  • out_features (int) – Output feature size.

  • activation_fn (nn.Module, optional, default=nn.GELU()) – Activation function.

  • drop (float, optional, default=0.0) – Dropout rate.

Forward:

x (torch.Tensor) – Input tensor of shape \((*, D_{in})\) where \(D_{in}\) is in_features.

Outputs:

torch.Tensor – Output tensor of shape \((*, D_{out})\) where \(D_{out}\) is out_features.

Examples

>>> import torch
>>> mlp = AFNOMlp(in_features=64, latent_features=128, out_features=64)
>>> x = torch.randn(4, 32, 32, 64)
>>> output = mlp(x)
>>> output.shape
torch.Size([4, 32, 32, 64])
forward(
x: Float[Tensor, '*dims D_in'],
) Float[Tensor, '*dims D_out'][source]#

Forward pass of the MLP.

class physicsnemo.nn.module.afno_layers.AFNOPatchEmbed(*args, **kwargs)[source]#

Bases: Module

Patch embedding layer for AFNO.

Converts 2D patches into a 1D vector sequence for input to AFNO. This differs from PatchEmbed2D as it flattens the output to a sequence format.

Parameters:
  • inp_shape (List[int]) – Input image dimensions as [height, width].

  • in_channels (int) – Number of input channels.

  • patch_size (List[int], optional, default=[16, 16]) – Size of image patches as [patch_height, patch_width].

  • embed_dim (int, optional, default=256) – Embedded channel size.

Forward:

x (torch.Tensor) – Input tensor of shape \((B, C_{in}, H, W)\) where \(B\) is batch size, \(C_{in}\) is the number of input channels, and \(H, W\) are spatial dimensions matching inp_shape.

Outputs:

torch.Tensor – Output tensor of shape \((B, N, D)\) where \(N\) is the number of patches and \(D\) is embed_dim.

Examples

>>> import torch
>>> patch_embed = AFNOPatchEmbed(
...     inp_shape=[32, 32], in_channels=3, patch_size=[8, 8], embed_dim=64
... )
>>> x = torch.randn(4, 3, 32, 32)
>>> output = patch_embed(x)
>>> output.shape
torch.Size([4, 16, 64])
forward(
x: Float[Tensor, 'B C H W'],
) Float[Tensor, 'B N D'][source]#

Forward pass of patch embedding.

class physicsnemo.nn.module.afno_layers.ModAFNO2DLayer(*args, **kwargs)[source]#

Bases: AFNO2DLayer

Modulated AFNO spectral convolution layer.

Extends AFNO2DLayer with scale-shift modulation in the spectral domain.

Parameters:
  • hidden_size (int) – Feature dimensionality.

  • mod_features (int) – Number of modulation features.

  • num_blocks (int, optional, default=8) – Number of blocks used in the block diagonal weight matrix.

  • sparsity_threshold (float, optional, default=0.01) – Sparsity threshold (softshrink) of spectral features.

  • hard_thresholding_fraction (float, optional, default=1) – Threshold for limiting number of modes used, in range [0, 1].

  • hidden_size_factor (int, optional, default=1) – Factor to increase spectral features by after weight multiplication.

  • scale_shift_kwargs (dict, optional) – Options to the MLP that computes the scale-shift parameters.

  • scale_shift_mode (Literal["complex", "real"], optional, default="complex") – If "complex", compute the scale-shift operation using complex operations. If "real", use real operations.

Forward:
  • x (torch.Tensor) – Input tensor of shape \((B, H, W, C)\).

  • mod_embed (torch.Tensor) – Modulation embedding of shape \((B, D_{mod})\).

Outputs:

torch.Tensor – Output tensor of shape \((B, H, W, C)\).

Examples

>>> import torch
>>> layer = ModAFNO2DLayer(hidden_size=64, mod_features=32, num_blocks=8)
>>> x = torch.randn(4, 16, 16, 64)
>>> mod_embed = torch.randn(4, 32)
>>> output = layer(x, mod_embed)
>>> output.shape
torch.Size([4, 16, 16, 64])
forward(
x: Float[Tensor, 'B H W C'],
mod_embed: Float[Tensor, 'B D_mod'],
) Float[Tensor, 'B H W C'][source]#

Forward pass with modulation.

class physicsnemo.nn.module.afno_layers.ModAFNOMlp(*args, **kwargs)[source]#

Bases: AFNOMlp

Modulated MLP used inside ModAFNO.

Extends AFNOMlp with scale-shift modulation based on a conditioning embedding.

Parameters:
  • in_features (int) – Input feature size.

  • latent_features (int) – Latent feature size.

  • out_features (int) – Output feature size.

  • mod_features (int) – Modulation embedding feature size.

  • activation_fn (nn.Module, optional, default=nn.GELU()) – Activation function.

  • drop (float, optional, default=0.0) – Dropout rate.

  • scale_shift_kwargs (dict, optional) – Options to the MLP that computes the scale-shift parameters.

Forward:
  • x (torch.Tensor) – Input tensor of shape \((*, D_{in})\).

  • mod_embed (torch.Tensor) – Modulation embedding of shape \((B, D_{mod})\).

Outputs:

torch.Tensor – Output tensor of shape \((*, D_{out})\).

Examples

>>> import torch
>>> mlp = ModAFNOMlp(
...     in_features=64, latent_features=128, out_features=64, mod_features=32
... )
>>> x = torch.randn(4, 16, 16, 64)
>>> mod_embed = torch.randn(4, 32)
>>> output = mlp(x, mod_embed)
>>> output.shape
torch.Size([4, 16, 16, 64])
forward(
x: Float[Tensor, '*dims D_in'],
mod_embed: Float[Tensor, 'B D_mod'],
) Float[Tensor, '*dims D_out'][source]#

Forward pass with modulation.

physicsnemo.nn.module.afno_layers.PatchEmbed#

alias of AFNOPatchEmbed

class physicsnemo.nn.module.afno_layers.ScaleShiftMlp(*args, **kwargs)[source]#

Bases: Module

MLP used to compute the scale and shift parameters of the ModAFNO block.

Parameters:
  • in_features (int) – Input feature size.

  • out_features (int) – Output feature size.

  • hidden_features (int, optional) – Hidden feature size. Defaults to 2 * out_features.

  • hidden_layers (int, optional, default=0) – Number of hidden layers.

  • activation_fn (Type[nn.Module], optional, default=nn.GELU) – Activation function class.

Forward:

x (torch.Tensor) – Input tensor of shape \((B, D_{in})\).

Outputs:

Tuple[torch.Tensor, torch.Tensor] – Tuple of (scale, shift) tensors, each of shape \((B, D_{out})\). Scale is offset by 1, i.e., (1 + scale, shift).

Examples

>>> import torch
>>> mlp = ScaleShiftMlp(in_features=64, out_features=128)
>>> x = torch.randn(4, 64)
>>> scale, shift = mlp(x)
>>> scale.shape, shift.shape
(torch.Size([4, 128]), torch.Size([4, 128]))

See also

Mlp

The MLP used internally to produce the concatenated (scale, shift) vector.

forward(
x: Float[Tensor, 'B D_in'],
) tuple[Float[Tensor, 'B D_out'], Float[Tensor, 'B D_out']][source]#

Forward pass computing scale and shift parameters.

Fourier Neural Operator (FNO) encoder layers.

This module contains reusable FNO encoder building blocks that can be used in various FNO-based architectures.

class physicsnemo.nn.module.fno_layers.FNO1DEncoder(*args, **kwargs)[source]#

Bases: Module

1D Spectral encoder for FNO.

This encoder applies a lifting network followed by spectral convolution layers in the Fourier domain for 1D input data.

Parameters:
  • in_channels (int, optional, default=1) – Number of input channels.

  • num_fno_layers (int, optional, default=4) – Number of spectral convolutional layers.

  • fno_layer_size (int, optional, default=32) – Latent features size in spectral convolutions.

  • num_fno_modes (Union[int, List[int]], optional, default=16) – Number of Fourier modes kept in spectral convolutions.

  • padding (Union[int, List[int]], optional, default=8) – Domain padding for spectral convolutions.

  • padding_type (str, optional, default="constant") – Type of padding for spectral convolutions.

  • activation_fn (nn.Module, optional, default=nn.GELU()) – Activation function.

  • coord_features (bool, optional, default=True) – Use coordinate grid as additional feature map.

Forward:

x (torch.Tensor) – Input tensor of shape \((B, C_{in}, L)\) where \(B\) is batch size, \(C_{in}\) is the number of input channels, and \(L\) is the sequence length (spatial dimension).

Outputs:

torch.Tensor – Output tensor of shape \((B, C_{latent}, L)\) where \(C_{latent}\) is fno_layer_size.

Examples

>>> import torch
>>> encoder = FNO1DEncoder(in_channels=3, fno_layer_size=32, num_fno_modes=8)
>>> x = torch.randn(4, 3, 64)
>>> output = encoder(x)
>>> output.shape
torch.Size([4, 32, 64])
forward(
x: Float[Tensor, 'B C_in L'],
) Float[Tensor, 'B C_latent L'][source]#

Forward pass of the 1D FNO encoder.

grid_to_points(
value: Tensor,
) Tuple[Tensor, List[int]][source]#

Convert from grid-based (image) to point-based representation.

Parameters:

value (Tensor) – Grid tensor of shape \((B, C, L)\).

Returns:

Tuple of (flattened tensor, original shape).

Return type:

Tuple[Tensor, List[int]]

points_to_grid(
value: Tensor,
shape: List[int],
) Tensor[source]#

Convert from point-based to grid-based (image) representation.

Parameters:
  • value (Tensor) – Point tensor of shape \((B \times X, C)\).

  • shape (List[int]) – Original grid shape as [B, C, L].

Returns:

Grid tensor of shape \((B, C, L)\).

Return type:

Tensor
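
The conversion pair above can be sketched as a permute/reshape round trip; the exact flattening order used by PhysicsNeMo is an assumption here:

```python
import torch

def grid_to_points_sketch(value: torch.Tensor):
    """Sketch: flatten a (B, C, L) grid into (B*L, C) points plus the saved shape."""
    b, c, l = value.shape
    return value.permute(0, 2, 1).reshape(b * l, c), [b, c, l]

def points_to_grid_sketch(points: torch.Tensor, shape):
    """Sketch: invert grid_to_points_sketch using the saved shape."""
    b, c, l = shape
    return points.reshape(b, l, c).permute(0, 2, 1)


grid = torch.randn(2, 4, 8)
points, shape = grid_to_points_sketch(grid)  # points: (16, 4)
assert torch.equal(points_to_grid_sketch(points, shape), grid)
```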

class physicsnemo.nn.module.fno_layers.FNO2DEncoder(*args, **kwargs)[source]#

Bases: Module

2D Spectral encoder for FNO.

This encoder applies a lifting network followed by spectral convolution layers in the Fourier domain for 2D input data.

Parameters:
  • in_channels (int, optional, default=1) – Number of input channels.

  • num_fno_layers (int, optional, default=4) – Number of spectral convolutional layers.

  • fno_layer_size (int, optional, default=32) – Latent features size in spectral convolutions.

  • num_fno_modes (Union[int, List[int]], optional, default=16) – Number of Fourier modes kept in spectral convolutions.

  • padding (Union[int, List[int]], optional, default=8) – Domain padding for spectral convolutions.

  • padding_type (str, optional, default="constant") – Type of padding for spectral convolutions.

  • activation_fn (nn.Module, optional, default=nn.GELU()) – Activation function.

  • coord_features (bool, optional, default=True) – If True, concatenate a coordinate grid as an additional feature map.

Forward:

x (torch.Tensor) – Input tensor of shape \((B, C_{in}, H, W)\) where \(B\) is batch size, \(C_{in}\) is the number of input channels, and \(H, W\) are spatial dimensions.

Outputs:

torch.Tensor – Output tensor of shape \((B, C_{latent}, H, W)\) where \(C_{latent}\) is fno_layer_size.

Examples

>>> import torch
>>> encoder = FNO2DEncoder(in_channels=3, fno_layer_size=32, num_fno_modes=8)
>>> x = torch.randn(4, 3, 32, 32)
>>> output = encoder(x)
>>> output.shape
torch.Size([4, 32, 32, 32])
forward(
x: Float[Tensor, 'B C_in H W'],
) Float[Tensor, 'B C_latent H W'][source]#

Forward pass of the 2D FNO encoder.

grid_to_points(
value: Tensor,
) Tuple[Tensor, List[int]][source]#

Convert from grid-based (image) to point-based representation.

Parameters:

value (Tensor) – Grid tensor of shape \((B, C, H, W)\).

Returns:

Tuple of (point tensor of shape \((B \times H \times W, C)\), original grid shape as [B, C, H, W]).

Return type:

Tuple[Tensor, List[int]]

points_to_grid(
value: Tensor,
shape: List[int],
) Tensor[source]#

Convert from point-based to grid-based (image) representation.

Parameters:
  • value (Tensor) – Point tensor of shape \((B \times H \times W, C)\).

  • shape (List[int]) – Original grid shape as [B, C, H, W].

Returns:

Grid tensor of shape \((B, C, H, W)\).

Return type:

Tensor

class physicsnemo.nn.module.fno_layers.FNO3DEncoder(*args, **kwargs)[source]#

Bases: Module

3D Spectral encoder for FNO.

This encoder applies a lifting network followed by spectral convolution layers in the Fourier domain for 3D input data.

Parameters:
  • in_channels (int, optional, default=1) – Number of input channels.

  • num_fno_layers (int, optional, default=4) – Number of spectral convolutional layers.

  • fno_layer_size (int, optional, default=32) – Latent feature size in spectral convolutions.

  • num_fno_modes (Union[int, List[int]], optional, default=16) – Number of Fourier modes kept in spectral convolutions.

  • padding (Union[int, List[int]], optional, default=8) – Domain padding for spectral convolutions.

  • padding_type (str, optional, default="constant") – Type of padding for spectral convolutions.

  • activation_fn (nn.Module, optional, default=nn.GELU()) – Activation function.

  • coord_features (bool, optional, default=True) – If True, concatenate a coordinate grid as an additional feature map.

Forward:

x (torch.Tensor) – Input tensor of shape \((B, C_{in}, D, H, W)\) where \(B\) is batch size, \(C_{in}\) is the number of input channels, and \(D, H, W\) are spatial dimensions.

Outputs:

torch.Tensor – Output tensor of shape \((B, C_{latent}, D, H, W)\) where \(C_{latent}\) is fno_layer_size.

Examples

>>> import torch
>>> encoder = FNO3DEncoder(in_channels=3, fno_layer_size=32, num_fno_modes=8)
>>> x = torch.randn(4, 3, 16, 16, 16)
>>> output = encoder(x)
>>> output.shape
torch.Size([4, 32, 16, 16, 16])
forward(
x: Float[Tensor, 'B C_in D H W'],
) Float[Tensor, 'B C_latent D H W'][source]#

Forward pass of the 3D FNO encoder.

grid_to_points(
value: Tensor,
) Tuple[Tensor, List[int]][source]#

Convert from grid-based (image) to point-based representation.

Parameters:

value (Tensor) – Grid tensor of shape \((B, C, D, H, W)\).

Returns:

Tuple of (point tensor of shape \((B \times D \times H \times W, C)\), original grid shape as [B, C, D, H, W]).

Return type:

Tuple[Tensor, List[int]]

points_to_grid(
value: Tensor,
shape: List[int],
) Tensor[source]#

Convert from point-based to grid-based (image) representation.

Parameters:
  • value (Tensor) – Point tensor of shape \((B \times D \times H \times W, C)\).

  • shape (List[int]) – Original grid shape as [B, C, D, H, W].

Returns:

Grid tensor of shape \((B, C, D, H, W)\).

Return type:

Tensor

class physicsnemo.nn.module.fno_layers.FNO4DEncoder(*args, **kwargs)[source]#

Bases: Module

4D Spectral encoder for FNO.

This encoder applies a lifting network followed by spectral convolution layers in the Fourier domain for 4D input data (3D spatial + time).

Parameters:
  • in_channels (int, optional, default=1) – Number of input channels.

  • num_fno_layers (int, optional, default=4) – Number of spectral convolutional layers.

  • fno_layer_size (int, optional, default=32) – Latent feature size in spectral convolutions.

  • num_fno_modes (Union[int, List[int]], optional, default=16) – Number of Fourier modes kept in spectral convolutions.

  • padding (Union[int, List[int]], optional, default=8) – Domain padding for spectral convolutions.

  • padding_type (str, optional, default="constant") – Type of padding for spectral convolutions.

  • activation_fn (nn.Module, optional, default=nn.GELU()) – Activation function.

  • coord_features (bool, optional, default=True) – If True, concatenate a coordinate grid as an additional feature map.

Forward:

x (torch.Tensor) – Input tensor of shape \((B, C_{in}, X, Y, Z, T)\) where \(B\) is batch size, \(C_{in}\) is the number of input channels, and \(X, Y, Z, T\) are spatial and temporal dimensions.

Outputs:

torch.Tensor – Output tensor of shape \((B, C_{latent}, X, Y, Z, T)\) where \(C_{latent}\) is fno_layer_size.

Examples

>>> import torch
>>> encoder = FNO4DEncoder(in_channels=3, fno_layer_size=32, num_fno_modes=4)
>>> x = torch.randn(2, 3, 8, 8, 8, 8)
>>> output = encoder(x)
>>> output.shape
torch.Size([2, 32, 8, 8, 8, 8])
forward(
x: Float[Tensor, 'B C_in X Y Z T'],
) Float[Tensor, 'B C_latent X Y Z T'][source]#

Forward pass of the 4D FNO encoder.

grid_to_points(
value: Tensor,
) Tuple[Tensor, List[int]][source]#

Convert from grid-based (image) to point-based representation.

Parameters:

value (Tensor) – Grid tensor of shape \((B, C, X, Y, Z, T)\).

Returns:

Tuple of (point tensor of shape \((B \times X \times Y \times Z \times T, C)\), original grid shape as [B, C, X, Y, Z, T]).

Return type:

Tuple[Tensor, List[int]]

points_to_grid(
value: Tensor,
shape: List[int],
) Tensor[source]#

Convert from point-based to grid-based (image) representation.

Parameters:
  • value (Tensor) – Point tensor of shape \((B \times X \times Y \times Z \times T, C)\).

  • shape (List[int]) – Original grid shape as [B, C, X, Y, Z, T].

Returns:

Grid tensor of shape \((B, C, X, Y, Z, T)\).

Return type:

Tensor
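The 2D, 3D, and 4D `grid_to_points` / `points_to_grid` variants above all repeat the same permute-and-flatten pattern, just with more spatial axes. A dimension-agnostic sketch in plain PyTorch, illustrating the documented shape contract (not the library code):

```python
import torch

def grid_to_points(value: torch.Tensor):
    """Flatten a (B, C, *spatial) grid into (B * prod(spatial), C).

    Works for any number of spatial axes: (B, C, H, W),
    (B, C, D, H, W), (B, C, X, Y, Z, T), etc.
    """
    shape = list(value.shape)
    ndim = value.dim()
    # Move channels last: (B, C, *S) -> (B, *S, C), then flatten.
    perm = [0] + list(range(2, ndim)) + [1]
    points = value.permute(*perm).reshape(-1, shape[1])
    return points, shape

def points_to_grid(points: torch.Tensor, shape):
    """Invert grid_to_points given the saved shape [B, C, *spatial]."""
    b, c, *spatial = shape
    grid = points.reshape(b, *spatial, c)
    ndim = grid.dim()
    # Move channels back to axis 1: (B, *S, C) -> (B, C, *S).
    perm = [0, ndim - 1] + list(range(1, ndim - 1))
    return grid.permute(*perm)

# Roundtrip on a 4D (3D spatial + time) grid, matching the shapes above.
x = torch.randn(2, 3, 8, 8, 8, 8)       # (B, C, X, Y, Z, T)
pts, shape = grid_to_points(x)          # (B*X*Y*Z*T, C)
x_back = points_to_grid(pts, shape)     # (B, C, X, Y, Z, T)
```

Because only `permute` and `reshape` are involved, the roundtrip is exact, which is what lets the encoders move freely between grid and point-cloud views of the same data.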