Transolver#

The Transolver model adapts the transformer architecture with a physics-attention mechanism for solving partial differential equations on structured and unstructured meshes. It projects inputs onto physics-informed slices before applying attention, enabling efficient learning of physical systems.
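The mechanism can be sketched in a few lines of PyTorch. This is a conceptual illustration only, not the library's implementation: each token is softly assigned to a small number of learnable slices, standard attention runs over the slices (cost \(O(M^2)\) instead of \(O(N^2)\)), and the result is scattered back to the tokens.

```python
import torch
import torch.nn as nn


class PhysicsAttentionSketch(nn.Module):
    """Conceptual sketch of physics attention (single head, no dropout)."""

    def __init__(self, dim: int, slice_num: int = 32):
        super().__init__()
        self.assign = nn.Linear(dim, slice_num)  # token -> slice logits
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, N, C); w: (B, N, M) soft assignment of tokens to slices
        w = self.assign(x).softmax(dim=-1)
        # Aggregate the N tokens into M slice tokens (weighted average)
        slices = torch.einsum("bnm,bnc->bmc", w, x) / (
            w.sum(dim=1).unsqueeze(-1) + 1e-6
        )
        # Standard scaled dot-product attention, but only over the M slices
        q, k, v = self.qkv(slices).chunk(3, dim=-1)
        attn = (q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5).softmax(dim=-1)
        slices = attn @ v
        # Scatter the attended slice tokens back to the N input tokens
        return self.proj(torch.einsum("bnm,bmc->bnc", w, slices))


x = torch.randn(2, 1000, 64)  # a batch of 1000 mesh points
out = PhysicsAttentionSketch(64)(x)
print(out.shape)  # torch.Size([2, 1000, 64])
```

Because attention is computed over the slices rather than the tokens, the cost is linear in the number of mesh points, which is what makes the approach practical on large unstructured meshes.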

class physicsnemo.models.transolver.transolver.Transolver(*args, **kwargs)[source]#

Bases: Module

Transolver model for physics-informed neural operator learning.

Transolver adapts the transformer architecture with a physics-attention mechanism replacing standard attention. It can work on both structured (2D/3D grids) and unstructured (mesh) data.

For architecture details, see Transolver: A Fast Transformer Solver for PDEs on General Geometries (Wu et al., ICML 2024).

Note

When using structured data, pass structured_shape as a tuple: length-2 tuples are treated as 2D image-like data, length-3 tuples as 3D volumetric data.

Parameters:
  • functional_dim (int) – Dimension of input values, not including embeddings.

  • out_dim (int) – Dimension of model output.

  • embedding_dim (int | None, optional, default=None) – Dimension of input embeddings. Required if unified_pos=False.

  • n_layers (int, optional, default=4) – Number of Transolver blocks.

  • n_hidden (int, optional, default=256) – Hidden dimension of the transformer.

  • dropout (float, optional, default=0.0) – Dropout rate.

  • n_head (int, optional, default=8) – Number of attention heads. Must evenly divide n_hidden.

  • act (str, optional, default="gelu") – Activation function name.

  • mlp_ratio (int, optional, default=4) – Ratio of MLP hidden dimension to n_hidden.

  • slice_num (int, optional, default=32) – Number of physics slices in attention layers.

  • unified_pos (bool, optional, default=False) – Whether to use unified positional embeddings (structured data only).

  • ref (int, optional, default=8) – Reference grid size for unified position encoding.

  • structured_shape (None | tuple[int, ...], optional, default=None) – Shape of structured data. None for unstructured mesh data.

  • use_te (bool, optional, default=True) – Whether to use NVIDIA Transformer Engine layers (requires the transformer_engine package).

  • time_input (bool, optional, default=False) – Whether to include time embeddings.

  • plus (bool, optional, default=False) – Whether to use Transolver++ variant.

Forward:
  • fx (torch.Tensor) – Functional input tensor of shape \((B, N, C_{in})\) for flattened data where \(B\) is batch size, \(N\) is number of tokens, and \(C_{in}\) is functional dimension. For structured data, shape is \((B, H_s, W_s, C_{in})\) for 2D or \((B, H_s, W_s, D_s, C_{in})\) for 3D, where \(H_s, W_s, D_s\) are spatial dimensions.

  • embedding (torch.Tensor | None, optional) – Embedding tensor. Required if unified_pos=False. Shape should match fx spatial dimensions.

  • time (torch.Tensor | None, optional) – Time tensor of shape \((B,)\) for time-dependent models.

Outputs:

torch.Tensor – Output tensor with same spatial shape as input and \(C_{out}\) features (equal to out_dim).

Examples

Structured 2D data with unified position:

>>> import torch
>>> from physicsnemo.models.transolver import Transolver
>>> model = Transolver(
...     functional_dim=3,
...     out_dim=1,
...     structured_shape=(64, 64),
...     unified_pos=True,
...     n_hidden=128,
...     n_head=4,
...     use_te=False,
... )
>>> x = torch.randn(2, 64, 64, 3)
>>> out = model(x)
>>> out.shape
torch.Size([2, 64, 64, 1])

Unstructured mesh data:

>>> model = Transolver(
...     functional_dim=2,
...     embedding_dim=3,
...     out_dim=1,
...     structured_shape=None,
...     unified_pos=False,
...     n_hidden=128,
...     n_head=4,
...     use_te=False,
... )
>>> fx = torch.randn(2, 1000, 2)
>>> emb = torch.randn(2, 1000, 3)
>>> out = model(fx, embedding=emb)
>>> out.shape
torch.Size([2, 1000, 1])

get_grid(ref: int, batchsize: int = 1) → Tensor[source]#

Generate unified positional encoding grid for structured 2D data.

Parameters:
  • ref (int) – Reference grid size for unified position encoding.

  • batchsize (int, optional, default=1) – Batch size for the generated grid.

Returns:

Positional encoding tensor of shape \((B, H \times W, \text{ref}^2)\).

Return type:

torch.Tensor

initialize_weights() → None[source]#

Initialize model weights using truncated normal distribution.

Building blocks#

class physicsnemo.models.transolver.transolver.TransolverBlock(
num_heads: int,
hidden_dim: int,
dropout: float,
act: str = 'gelu',
mlp_ratio: int = 4,
last_layer: bool = False,
out_dim: int = 1,
slice_num: int = 32,
spatial_shape: tuple[int, ...] | None = None,
use_te: bool = True,
plus: bool = False,
)[source]#

Bases: Module

Transformer encoder block with physics attention mechanism.

This block replaces standard attention with physics attention, which learns to project inputs onto physics-informed slices before applying attention.

Parameters:
  • num_heads (int) – Number of attention heads.

  • hidden_dim (int) – Hidden dimension of the block.

  • dropout (float) – Dropout rate.

  • act (str, optional, default="gelu") – Activation function name.

  • mlp_ratio (int, optional, default=4) – Ratio of MLP hidden dimension to hidden_dim.

  • last_layer (bool, optional, default=False) – Whether this is the last layer (applies output projection).

  • out_dim (int, optional, default=1) – Output dimension (only used if last_layer=True).

  • slice_num (int, optional, default=32) – Number of physics slices.

  • spatial_shape (tuple[int, ...] | None, optional, default=None) – Spatial shape for structured data. None for irregular meshes.

  • use_te (bool, optional, default=True) – Whether to use NVIDIA Transformer Engine layers (requires the transformer_engine package).

  • plus (bool, optional, default=False) – Whether to use Transolver++ variant.

Forward:

  • fx (torch.Tensor) – Input tensor of shape \((B, N, C)\) where \(B\) is batch size, \(N\) is number of tokens, \(C\) is hidden dimension.

Outputs:

torch.Tensor – Output tensor of shape \((B, N, C)\), or \((B, N, C_{out})\) if last_layer=True.