Attention and Transformers#

class physicsnemo.nn.module.attention_layers.AttentionOp(*args, **kwargs)[source]#

Bases: Function

Attention weight computation, i.e., softmax(Q^T * K). Performs all computation using FP32, but uses the original datatype for inputs/outputs/gradients to conserve memory.

static backward(ctx, dw)[source]#

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor, or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, q, k)[source]#

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass


@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
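The mixed-precision pattern that AttentionOp implements (compute in FP32, keep inputs/outputs in the original dtype) can be illustrated outside of autograd. The following is a minimal NumPy sketch, not the actual PyTorch/CUDA implementation; the 1/sqrt(d) scaling is a conventional assumption:

```python
import numpy as np

def attention_weights_fp32(q, k):
    """Compute softmax(q^T k) in float32, return in the input dtype.

    q, k: arrays of shape (d, n) in a low-precision dtype (e.g. float16).
    """
    orig_dtype = q.dtype
    scale = 1.0 / np.sqrt(q.shape[0])
    # upcast to FP32 before the matmul and softmax
    scores = (q.astype(np.float32).T @ k.astype(np.float32)) * scale  # (n, n)
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability
    w = np.exp(scores)
    w /= w.sum(axis=1, keepdims=True)            # softmax rows sum to 1
    return w.astype(orig_dtype)                  # cast back to conserve memory

q = np.random.randn(8, 4).astype(np.float16)
k = np.random.randn(8, 4).astype(np.float16)
w = attention_weights_fp32(q, k)
```

The intermediate math runs in FP32 while only the low-precision result is kept, which is the memory/precision trade-off described above.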

class physicsnemo.nn.module.attention_layers.EarthAttention2D(
dim,
input_resolution,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
)[source]#

Bases: Module

Revised from WeatherLearn (https://github.com/lizhuoq/WeatherLearn). 2D window attention with earth position bias. Supports both shifted and non-shifted windows.

Parameters:
  • dim (int) – Number of input channels.

  • input_resolution (tuple[int]) – [latitude, longitude]

  • window_size (tuple[int]) – [latitude, longitude]

  • num_heads (int) – Number of attention heads.

  • qkv_bias (bool, optional) – If True, add a learnable bias to query, key, value. Default: True

  • qk_scale (float | None, optional) – Override default qk scale of head_dim ** -0.5 if set

  • attn_drop (float, optional) – Dropout ratio of attention weight. Default: 0.0

  • proj_drop (float, optional) – Dropout ratio of output. Default: 0.0

forward(x: Tensor, mask=None)[source]#
Parameters:
  • x – input features with shape of (B * num_lon, num_lat, N, C)

  • mask – (0/-inf) mask with shape of (num_lon, num_lat, Wlat*Wlon, Wlat*Wlon)
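The (0/-inf) mask above is additive: it is added to the attention scores before the softmax, so -inf entries receive zero weight. A minimal NumPy sketch with simplified shapes (the class's actual mask layout is per-window, as documented above):

```python
import numpy as np

def masked_softmax(scores, mask):
    """Apply an additive (0 / -inf) mask, then softmax over the last axis."""
    s = scores + mask                      # -inf entries drop out of the softmax
    s -= s.max(axis=-1, keepdims=True)
    w = np.exp(s)                          # exp(-inf) == 0
    return w / w.sum(axis=-1, keepdims=True)

scores = np.zeros((4, 4))                  # uniform scores for illustration
mask = np.zeros((4, 4))
mask[:, 2:] = -np.inf                      # forbid attending to the last two keys
w = masked_softmax(scores, mask)
```

Masked positions end up with exactly zero attention weight while the remaining weights renormalize among the allowed keys.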

class physicsnemo.nn.module.attention_layers.EarthAttention3D(
dim,
input_resolution,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
)[source]#

Bases: Module

Revised from WeatherLearn (https://github.com/lizhuoq/WeatherLearn). 3D window attention with earth position bias. Supports both shifted and non-shifted windows.

Parameters:
  • dim (int) – Number of input channels.

  • input_resolution (tuple[int]) – [pressure levels, latitude, longitude]

  • window_size (tuple[int]) – [pressure levels, latitude, longitude]

  • num_heads (int) – Number of attention heads.

  • qkv_bias (bool, optional) – If True, add a learnable bias to query, key, value. Default: True

  • qk_scale (float | None, optional) – Override default qk scale of head_dim ** -0.5 if set

  • attn_drop (float, optional) – Dropout ratio of attention weight. Default: 0.0

  • proj_drop (float, optional) – Dropout ratio of output. Default: 0.0

forward(x: Tensor, mask=None)[source]#
Parameters:
  • x – input features with shape of (B * num_lon, num_pl*num_lat, N, C)

  • mask – (0/-inf) mask with shape of (num_lon, num_pl*num_lat, Wpl*Wlat*Wlon, Wpl*Wlat*Wlon)

class physicsnemo.nn.module.attention_layers.UNetAttention(
*,
out_channels: int,
num_heads: int,
eps: float = 1e-05,
init_zero: Dict[str, Any] = {'init_weight': 0},
init_attn: Any = None,
init: Dict[str, Any] = {},
use_apex_gn: bool = False,
amp_mode: bool = False,
fused_conv_bias: bool = False,
)[source]#

Bases: Module

Self-attention block used in U-Net-style architectures, such as DDPM++, NCSN++, and ADM. Applies GroupNorm followed by multi-head self-attention and a projection layer.

Parameters:
  • out_channels (int) – Number of channels \(C\) in the input and output feature maps.

  • num_heads (int) – Number of attention heads. Must be a positive integer.

  • eps (float, optional, default=1e-5) – Epsilon value for numerical stability in GroupNorm.

  • init_zero (dict, optional, default={'init_weight': 0}) – Initialization parameters with zero weights for certain layers.

  • init_attn (dict, optional, default=None) – Initialization parameters specific to attention mechanism layers. Defaults to ‘init’ if not provided.

  • init (dict, optional, default={}) – Initialization parameters for convolutional and linear layers.

  • use_apex_gn (bool, optional, default=False) – Whether to use Apex GroupNorm for the NHWC layout. Must be set to False on CPU.

  • amp_mode (bool, optional, default=False) – A boolean flag indicating whether mixed-precision (AMP) training is enabled.

  • fused_conv_bias (bool, optional, default=False) – A boolean flag indicating whether bias will be passed as a parameter of conv2d.

Forward:

x (torch.Tensor) – Input tensor of shape \((B, C, H, W)\), where \(B\) is batch size, \(C\) is out_channels, and \(H, W\) are spatial dimensions.

Outputs:

torch.Tensor – Output tensor of the same shape as input: \((B, C, H, W)\).

forward(x: Tensor) Tensor[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
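The block structure described above (normalize, attend over the flattened H*W positions, project, add the residual) can be sketched in NumPy. This is an illustrative single-layer sketch with identity q/k/v projections and a per-sample normalization standing in for GroupNorm, not the actual module:

```python
import numpy as np

def unet_self_attention(x, num_heads=2):
    """Sketch: normalize, multi-head attention over H*W positions, residual.

    x: (B, C, H, W); C must be divisible by num_heads.
    """
    B, C, H, W = x.shape
    d = C // num_heads
    # GroupNorm stand-in: normalize each sample over (C, H, W)
    h = (x - x.mean(axis=(1, 2, 3), keepdims=True)) / (
        x.std(axis=(1, 2, 3), keepdims=True) + 1e-5)
    tokens = h.reshape(B, num_heads, d, H * W)           # (B, heads, d, N)
    # identity q = k = v projections for brevity
    scores = np.einsum('bhdn,bhdm->bhnm', tokens, tokens) / np.sqrt(d)
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)                   # (B, heads, N, N)
    attended = np.einsum('bhnm,bhdm->bhdn', w, tokens).reshape(B, C, H, W)
    return x + attended                                  # residual connection

x = np.random.randn(2, 8, 4, 4)
y = unet_self_attention(x)
```

As the Outputs section states, the result has the same (B, C, H, W) shape as the input.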

class physicsnemo.nn.module.transformer_layers.DecoderLayer(
img_size,
patch_size,
out_chans,
dim,
output_resolution,
middle_resolution,
depth,
depth_middle,
num_heads,
window_size,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>,
)[source]#

Bases: Module

A 2D Transformer Decoder Module for one stage

Parameters:
  • img_size (tuple[int]) – Image size (Lat, Lon).

  • patch_size (tuple[int]) – Patch token size of Patch Recovery.

  • out_chans (int) – Number of output channels of Patch Recovery.

  • dim (int) – Number of input channels of transformer.

  • output_resolution (tuple[int]) – Input resolution for transformer after upsampling.

  • middle_resolution (tuple[int]) – Input resolution for transformer before upsampling.

  • depth (int) – Number of blocks for transformer after upsampling.

  • depth_middle (int) – Number of blocks for transformer before upsampling.

  • num_heads (int) – Number of attention heads.

  • window_size (tuple[int]) – Local window size.

  • mlp_ratio (float) – Ratio of mlp hidden dim to embedding dim.

  • qkv_bias (bool, optional) – If True, add a learnable bias to query, key, value. Default: True

  • qk_scale (float | None, optional) – Override default qk scale of head_dim ** -0.5 if set.

  • drop (float, optional) – Dropout rate. Default: 0.0

  • attn_drop (float, optional) – Attention dropout rate. Default: 0.0

  • drop_path (float | tuple[float], optional) – Stochastic depth rate. Default: 0.0

  • norm_layer (nn.Module, optional) – Normalization layer. Default: nn.LayerNorm

forward(x, skip)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class physicsnemo.nn.module.transformer_layers.EncoderLayer(
img_size,
patch_size,
in_chans,
dim,
input_resolution,
middle_resolution,
depth,
depth_middle,
num_heads,
window_size,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>,
)[source]#

Bases: Module

A 2D Transformer Encoder Module for one stage

Parameters:
  • img_size (tuple[int]) – Image size (Lat, Lon).

  • patch_size (tuple[int]) – Patch token size of Patch Embedding.

  • in_chans (int) – Number of input channels of Patch Embedding.

  • dim (int) – Number of input channels of transformer.

  • input_resolution (tuple[int]) – Input resolution for transformer before downsampling.

  • middle_resolution (tuple[int]) – Input resolution for transformer after downsampling.

  • depth (int) – Number of blocks for transformer before downsampling.

  • depth_middle (int) – Number of blocks for transformer after downsampling.

  • num_heads (int) – Number of attention heads.

  • window_size (tuple[int]) – Local window size.

  • mlp_ratio (float) – Ratio of mlp hidden dim to embedding dim.

  • qkv_bias (bool, optional) – If True, add a learnable bias to query, key, value. Default: True

  • qk_scale (float | None, optional) – Override default qk scale of head_dim ** -0.5 if set.

  • drop (float, optional) – Dropout rate. Default: 0.0

  • attn_drop (float, optional) – Attention dropout rate. Default: 0.0

  • drop_path (float | tuple[float], optional) – Stochastic depth rate. Default: 0.0

  • norm_layer (nn.Module, optional) – Normalization layer. Default: nn.LayerNorm

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class physicsnemo.nn.module.transformer_layers.FuserLayer(
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>,
)[source]#

Bases: Module

Revised from WeatherLearn (https://github.com/lizhuoq/WeatherLearn). A basic 3D Transformer layer for one stage.

Parameters:
  • dim (int) – Number of input channels.

  • input_resolution (tuple[int]) – Input resolution.

  • depth (int) – Number of blocks.

  • num_heads (int) – Number of attention heads.

  • window_size (tuple[int]) – Local window size.

  • mlp_ratio (float) – Ratio of mlp hidden dim to embedding dim.

  • qkv_bias (bool, optional) – If True, add a learnable bias to query, key, value. Default: True

  • qk_scale (float | None, optional) – Override default qk scale of head_dim ** -0.5 if set.

  • drop (float, optional) – Dropout rate. Default: 0.0

  • attn_drop (float, optional) – Attention dropout rate. Default: 0.0

  • drop_path (float | tuple[float], optional) – Stochastic depth rate. Default: 0.0

  • norm_layer (nn.Module, optional) – Normalization layer. Default: nn.LayerNorm

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class physicsnemo.nn.module.transformer_layers.SwinTransformer(
embed_dim,
input_resolution,
num_heads,
window_size,
depth,
)[source]#

Bases: Module

Swin Transformer

Parameters:
  • embed_dim (int) – Patch embedding dimension.

  • input_resolution (tuple[int]) – (Lat, Lon).

  • num_heads (int) – Number of attention heads in different layers.

  • window_size (int | tuple[int]) – Window size.

  • depth (int) – Number of blocks.

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
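Window attention restricts self-attention to non-overlapping local windows, which requires partitioning the feature map into windows and merging them back. A NumPy sketch of this partition/reverse pair (2D case, window sizes assumed to divide the spatial dimensions):

```python
import numpy as np

def window_partition(x, wh, ww):
    """(B, H, W, C) -> (B * num_windows, wh*ww, C); requires H % wh == W % ww == 0."""
    B, H, W, C = x.shape
    x = x.reshape(B, H // wh, wh, W // ww, ww, C)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, wh * ww, C)

def window_reverse(windows, wh, ww, B, H, W):
    """Inverse of window_partition: merge windows back into (B, H, W, C)."""
    C = windows.shape[-1]
    x = windows.reshape(B, H // wh, W // ww, wh, ww, C)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(B, H, W, C)

x = np.random.randn(2, 8, 8, 16)
wins = window_partition(x, 4, 4)            # 4 windows per image, 16 tokens each
x_back = window_reverse(wins, 4, 4, 2, 8, 8)
```

Attention then runs independently within each window of wh*ww tokens, which is what makes the cost linear in the number of windows.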

class physicsnemo.nn.module.transformer_layers.Transformer2DBlock(
dim,
input_resolution,
num_heads,
window_size=None,
shift_size=None,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=<class 'torch.nn.modules.activation.GELU'>,
norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>,
)[source]#

Bases: Module

Revised from WeatherLearn (https://github.com/lizhuoq/WeatherLearn). 2D Transformer Block

Parameters:
  • dim (int) – Number of input channels.

  • input_resolution (tuple[int]) – Input resolution.

  • num_heads (int) – Number of attention heads.

  • window_size (tuple[int]) – Window size [latitude, longitude].

  • shift_size (tuple[int]) – Shift size for SW-MSA [latitude, longitude].

  • mlp_ratio (float) – Ratio of mlp hidden dim to embedding dim.

  • qkv_bias (bool, optional) – If True, add a learnable bias to query, key, value. Default: True

  • qk_scale (float | None, optional) – Override default qk scale of head_dim ** -0.5 if set.

  • drop (float, optional) – Dropout rate. Default: 0.0

  • attn_drop (float, optional) – Attention dropout rate. Default: 0.0

  • drop_path (float, optional) – Stochastic depth rate. Default: 0.0

  • act_layer (nn.Module, optional) – Activation layer. Default: nn.GELU

  • norm_layer (nn.Module, optional) – Normalization layer. Default: nn.LayerNorm

forward(x: Tensor)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class physicsnemo.nn.module.transformer_layers.Transformer3DBlock(
dim,
input_resolution,
num_heads,
window_size=None,
shift_size=None,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=<class 'torch.nn.modules.activation.GELU'>,
norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>,
)[source]#

Bases: Module

Revised from WeatherLearn (https://github.com/lizhuoq/WeatherLearn). 3D Transformer Block

Parameters:
  • dim (int) – Number of input channels.

  • input_resolution (tuple[int]) – Input resolution.

  • num_heads (int) – Number of attention heads.

  • window_size (tuple[int]) – Window size [pressure levels, latitude, longitude].

  • shift_size (tuple[int]) – Shift size for SW-MSA [pressure levels, latitude, longitude].

  • mlp_ratio (float) – Ratio of mlp hidden dim to embedding dim.

  • qkv_bias (bool, optional) – If True, add a learnable bias to query, key, value. Default: True

  • qk_scale (float | None, optional) – Override default qk scale of head_dim ** -0.5 if set.

  • drop (float, optional) – Dropout rate. Default: 0.0

  • attn_drop (float, optional) – Attention dropout rate. Default: 0.0

  • drop_path (float, optional) – Stochastic depth rate. Default: 0.0

  • act_layer (nn.Module, optional) – Activation layer. Default: nn.GELU

  • norm_layer (nn.Module, optional) – Normalization layer. Default: nn.LayerNorm

forward(x: Tensor)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class physicsnemo.nn.module.transformer_decoder.DecoderOnlyLayer(d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation: str | Callable[[Tensor], Tensor] = <function relu>, layer_norm_eps: float = 1e-05, batch_first: bool = False, norm_first: bool = False, bias: bool = True, device=None, dtype=None)[source]#

Bases: Module

Parameters:
  • d_model (int) – Number of expected features in the input.

  • nhead (int) – Number of heads in the multiheadattention models.

  • dim_feedforward (int) – Dimension of the feedforward network model, by default 2048.

  • dropout (float) – The dropout value, by default 0.1.

  • activation (str) – The activation function of the intermediate layer, by default ‘relu’.

  • layer_norm_eps (float) – The eps value in layer normalization components, by default 1e-5.

  • batch_first (bool) – If True, then the input and output tensors are provided as (batch, seq, feature), by default False (seq, batch, feature).

  • norm_first (bool) – If True, layer norm is done prior to the self attention, multihead attention and feedforward operations, respectively. Otherwise it is done after, by default False (after).

  • bias (bool) – If set to False, Linear and LayerNorm layers will not learn an additive bias. Default: True.

forward(
tgt: Tensor,
tgt_mask: Tensor | None = None,
tgt_key_padding_mask: Tensor | None = None,
tgt_is_causal: bool = False,
) Tensor[source]#

Pass the inputs (and mask) through the decoder layer.
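When tgt_is_causal is True, position i may attend only to positions j <= i. The corresponding additive tgt_mask (following the usual (0 / -inf) square-subsequent-mask convention) can be built as:

```python
import numpy as np

def causal_mask(n):
    """Additive (0 / -inf) causal mask: row i blocks all keys j > i."""
    mask = np.zeros((n, n))
    mask[np.triu_indices(n, k=1)] = -np.inf  # strictly upper triangle
    return mask

m = causal_mask(4)
```

Added to the attention scores before the softmax, the -inf entries zero out all future positions, leaving each query a distribution over the past and present only.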

class physicsnemo.nn.module.transformer_decoder.TransformerDecoder(decoder_layer, num_layers, norm=None)[source]#

Bases: Module

TransformerDecoder is a stack of N decoder layers

Parameters:
  • decoder_layer (torch.nn.Module) – Layer used for the decoder.

  • num_layers (int) – Number of sub-decoder-layers in the decoder.

  • norm (torch.nn.Module, optional) – Layer normalization component, by default None.

forward(
tgt: Tensor,
tgt_mask: Tensor | None = None,
tgt_key_padding_mask: Tensor | None = None,
tgt_is_causal: bool | None = None,
) Tensor[source]#

Pass the inputs (and mask) through each decoder layer in turn.
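The "stack of N decoder layers" structure amounts to applying the layer N times and then the optional final norm. A toy sketch of that control flow (not the actual module; the callable layer stands in for a DecoderOnlyLayer):

```python
import copy

class StackSketch:
    """Minimal sketch of an N-layer stack with an optional final norm."""

    def __init__(self, layer, num_layers, norm=None):
        # the real module deep-copies the prototype layer num_layers times
        self.layers = [copy.deepcopy(layer) for _ in range(num_layers)]
        self.norm = norm

    def __call__(self, tgt):
        out = tgt
        for layer in self.layers:
            out = layer(out)          # each layer consumes the previous output
        if self.norm is not None:
            out = self.norm(out)
        return out

stack = StackSketch(lambda t: t + 1, num_layers=3)  # toy "layer": add 1
y = stack(0)
```

Each layer sees the previous layer's output, so masks supplied to forward are re-applied at every layer.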

Physics attention modules for the Transolver model.

This module provides physics-informed attention mechanisms that project inputs onto learned physics slices before applying attention. These attention variants support irregular meshes, 2D structured grids, and 3D volumetric data.

This code was modified from https://github.com/thuml/Transolver

The following license is provided from their source,

MIT License

Copyright (c) 2024 THUML @ Tsinghua University

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

class physicsnemo.nn.module.physics_attention.PhysicsAttentionBase(
dim: int,
heads: int,
dim_head: int,
dropout: float,
slice_num: int,
use_te: bool,
plus: bool,
)[source]#

Bases: Module, ABC

Base class for physics attention modules.

This class implements the core physics attention mechanism that projects inputs onto learned physics-informed slices before applying attention. Subclasses implement domain-specific input projections.

The physics attention mechanism consists of:

  1. Project inputs onto learned slice space

  2. Compute slice weights via temperature-scaled softmax

  3. Aggregate features for each slice

  4. Apply attention among slices

  5. Project attended features back to original space
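The five steps above can be sketched end to end in NumPy. This is an illustrative single-head sketch with a random matrix standing in for the learned slice projection; the actual module uses learned layers, multiple heads, and a learned temperature:

```python
import numpy as np

def physics_attention_sketch(x, slice_num=4, temperature=0.5,
                             rng=np.random.default_rng(0)):
    """Single-head sketch of the slice / attend / de-slice pipeline.

    x: (B, N, C) tokens. W_slice is a random stand-in for a learned layer.
    """
    B, N, C = x.shape
    W_slice = rng.standard_normal((C, slice_num))
    # steps 1-2: slice weights via temperature-scaled softmax over slices
    logits = (x @ W_slice) / temperature                  # (B, N, S)
    logits -= logits.max(axis=-1, keepdims=True)
    w = np.exp(logits)
    w /= w.sum(axis=-1, keepdims=True)
    # step 3: aggregate token features into slice tokens (weighted average)
    denom = w.sum(axis=1)[..., None] + 1e-8               # (B, S, 1)
    slices = np.einsum('bns,bnc->bsc', w, x) / denom      # (B, S, C)
    # step 4: attention among the S slice tokens
    scores = np.einsum('bsc,btc->bst', slices, slices) / np.sqrt(C)
    scores -= scores.max(axis=-1, keepdims=True)
    a = np.exp(scores)
    a /= a.sum(axis=-1, keepdims=True)
    attended = np.einsum('bst,btc->bsc', a, slices)       # (B, S, C)
    # step 5: project attended slices back to tokens via the slice weights
    return np.einsum('bns,bsc->bnc', w, attended)         # (B, N, C)

x = np.random.randn(2, 100, 8)
out = physics_attention_sketch(x)
```

Because attention runs over slice_num slices rather than N tokens, the attention cost is independent of mesh size, which is what makes the mechanism practical for large irregular meshes.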

Parameters:
  • dim (int) – Input feature dimension.

  • heads (int) – Number of attention heads.

  • dim_head (int) – Dimension per attention head.

  • dropout (float) – Dropout rate.

  • slice_num (int) – Number of physics slices.

  • use_te (bool) – Whether to use transformer engine.

  • plus (bool) – Whether to use Transolver++ variant.

Forward:

x (torch.Tensor) – Input tensor of shape \((B, N, C)\) where \(B\) is batch size, \(N\) is number of tokens, and \(C\) is feature dimension.

Outputs:

torch.Tensor – Output tensor of shape \((B, N, C)\).

See also

  • PhysicsAttentionIrregularMesh for unstructured mesh data

  • PhysicsAttentionStructuredMesh2D for 2D image-like data

  • PhysicsAttentionStructuredMesh3D for 3D volumetric data

forward(
x: Float[Tensor, 'B N C'],
) Float[Tensor, 'B N C'][source]#

Forward pass of physics attention.

Parameters:

x (torch.Tensor) – Input tensor of shape \((B, N, C)\).

Returns:

Output tensor of shape \((B, N, C)\).

Return type:

torch.Tensor

abstractmethod project_input_onto_slices(
x: Float[Tensor, 'B N C'],
) Float[Tensor, 'B N H D'] | tuple[Float[Tensor, 'B N H D'], Float[Tensor, 'B N H D']][source]#

Project input tensor onto the slice space.

Parameters:

x (torch.Tensor) – Input tensor of shape \((B, N, C)\).

Returns:

For Transolver++: single projected tensor of shape \((B, N, H, D)\) where \(H\) is number of attention heads and \(D\) is dimension per head. For standard Transolver: tuple of (x_mid, fx_mid) both of shape \((B, N, H, D)\).

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor]

class physicsnemo.nn.module.physics_attention.PhysicsAttentionIrregularMesh(
dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
slice_num: int = 64,
use_te: bool = True,
plus: bool = False,
)[source]#

Bases: PhysicsAttentionBase

Physics attention for irregular/unstructured mesh data.

Uses linear projections to map input tokens to the slice space, suitable for meshes with arbitrary connectivity.

Parameters:
  • dim (int) – Input feature dimension.

  • heads (int, optional, default=8) – Number of attention heads.

  • dim_head (int, optional, default=64) – Dimension per attention head.

  • dropout (float, optional, default=0.0) – Dropout rate.

  • slice_num (int, optional, default=64) – Number of physics slices.

  • use_te (bool, optional, default=True) – Whether to use transformer engine.

  • plus (bool, optional, default=False) – Whether to use Transolver++ variant.

Forward:

x (torch.Tensor) – Input tensor of shape \((B, N, C)\) where \(B\) is batch size, \(N\) is number of tokens, and \(C\) is feature dimension.

Outputs:

torch.Tensor – Output tensor of shape \((B, N, C)\).

Examples

>>> import torch
>>> attn = PhysicsAttentionIrregularMesh(dim=128, heads=4, dim_head=32, dropout=0.0, slice_num=16, use_te=False)
>>> x = torch.randn(2, 1000, 128)
>>> out = attn(x)
>>> out.shape
torch.Size([2, 1000, 128])
project_input_onto_slices(
x: Float[Tensor, 'B N C'],
) Float[Tensor, 'B N H D'] | tuple[Float[Tensor, 'B N H D'], Float[Tensor, 'B N H D']][source]#

Project input onto slice space using linear layers.

Parameters:

x (torch.Tensor) – Input tensor of shape \((B, N, C)\).

Returns:

Projected tensors of shape \((B, N, H, D)\) where \(H\) is number of attention heads and \(D\) is dimension per head.

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor]

class physicsnemo.nn.module.physics_attention.PhysicsAttentionStructuredMesh2D(
dim: int,
spatial_shape: tuple[int, int],
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
slice_num: int = 64,
kernel: int = 3,
use_te: bool = True,
plus: bool = False,
)[source]#

Bases: PhysicsAttentionBase

Physics attention for 2D structured/image-like data.

Uses 2D convolutions to project inputs, leveraging spatial locality in structured grids.

Parameters:
  • dim (int) – Input feature dimension.

  • spatial_shape (tuple[int, int]) – Spatial dimensions (height, width) of the input.

  • heads (int, optional, default=8) – Number of attention heads.

  • dim_head (int, optional, default=64) – Dimension per attention head.

  • dropout (float, optional, default=0.0) – Dropout rate.

  • slice_num (int, optional, default=64) – Number of physics slices.

  • kernel (int, optional, default=3) – Convolution kernel size.

  • use_te (bool, optional, default=True) – Whether to use transformer engine.

  • plus (bool, optional, default=False) – Whether to use Transolver++ variant.

Forward:

x (torch.Tensor) – Input tensor of shape \((B, N, C)\) where \(B\) is batch size, \(N\) is number of tokens (flattened spatial: height times width), and \(C\) is feature dimension.

Outputs:

torch.Tensor – Output tensor of shape \((B, N, C)\).

Examples

>>> import torch
>>> attn = PhysicsAttentionStructuredMesh2D(
...     dim=128,
...     spatial_shape=(32, 32),
...     heads=4,
...     dim_head=32,
...     dropout=0.0,
...     slice_num=16,
...     use_te=False,
... )
>>> x = torch.randn(2, 32*32, 128)
>>> out = attn(x)
>>> out.shape
torch.Size([2, 1024, 128])
project_input_onto_slices(
x: Float[Tensor, 'B N C'],
) Float[Tensor, 'B N H D'] | tuple[Float[Tensor, 'B N H D'], Float[Tensor, 'B N H D']][source]#

Project input onto slice space using 2D convolutions.

Parameters:

x (torch.Tensor) – Input tensor of shape \((B, N, C)\) where \(N\) is the flattened spatial dimension (height times width).

Returns:

Projected tensors of shape \((B, N, H, D)\) where \(H\) is number of attention heads and \(D\) is dimension per head.

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor]

class physicsnemo.nn.module.physics_attention.PhysicsAttentionStructuredMesh3D(
dim: int,
spatial_shape: tuple[int, int, int],
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
slice_num: int = 32,
kernel: int = 3,
use_te: bool = True,
plus: bool = False,
)[source]#

Bases: PhysicsAttentionBase

Physics attention for 3D structured/volumetric data.

Uses 3D convolutions to project inputs, suitable for voxel-based representations.

Parameters:
  • dim (int) – Input feature dimension.

  • spatial_shape (tuple[int, int, int]) – Spatial dimensions (height, width, depth) of the input.

  • heads (int, optional, default=8) – Number of attention heads.

  • dim_head (int, optional, default=64) – Dimension per attention head.

  • dropout (float, optional, default=0.0) – Dropout rate.

  • slice_num (int, optional, default=32) – Number of physics slices.

  • kernel (int, optional, default=3) – Convolution kernel size.

  • use_te (bool, optional, default=True) – Whether to use transformer engine.

  • plus (bool, optional, default=False) – Whether to use Transolver++ variant.

Forward:

x (torch.Tensor) – Input tensor of shape \((B, N, C)\) where \(B\) is batch size, \(N\) is number of tokens (flattened spatial: height times width times depth), and \(C\) is feature dimension.

Outputs:

torch.Tensor – Output tensor of shape \((B, N, C)\).

Examples

>>> import torch
>>> attn = PhysicsAttentionStructuredMesh3D(
...     dim=64,
...     spatial_shape=(16, 16, 16),
...     heads=4,
...     dim_head=16,
...     dropout=0.0,
...     slice_num=8,
...     use_te=False,
... )
>>> x = torch.randn(2, 16*16*16, 64)
>>> out = attn(x)
>>> out.shape
torch.Size([2, 4096, 64])
project_input_onto_slices(
x: Float[Tensor, 'B N C'],
) Float[Tensor, 'B N H D'] | tuple[Float[Tensor, 'B N H D'], Float[Tensor, 'B N H D']][source]#

Project input onto slice space using 3D convolutions.

Parameters:

x (torch.Tensor) – Input tensor of shape \((B, N, C)\) where \(N\) is the flattened spatial dimension (height times width times depth).

Returns:

Projected tensors of shape \((B, N, H, D)\) where \(H\) is number of attention heads and \(D\) is dimension per head.

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor]