# Source code for physicsnemo.models.transolver.transolver

# ignore_header_test
# ruff: noqa: E402

r"""
Transolver model and building blocks for physics-informed neural operator learning.

This module provides the main Transolver model class along with its internal
building blocks (``_TransolverMlp``, ``TransolverBlock``) for solving PDEs on
structured and unstructured meshes.

This code was modified from https://github.com/thuml/Transolver

The following license is provided from their source,

MIT License

Copyright (c) 2024 THUML @ Tsinghua University

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import functools
import importlib
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from jaxtyping import Float

import physicsnemo  # noqa: F401 for docs
from physicsnemo.core.meta import ModelMetaData
from physicsnemo.core.module import Module
from physicsnemo.core.version_check import check_version_spec
from physicsnemo.nn import Mlp, PositionalEmbedding
from physicsnemo.nn.module.physics_attention import (
    PhysicsAttentionIrregularMesh,
    PhysicsAttentionStructuredMesh2D,
    PhysicsAttentionStructuredMesh3D,
)

# Transformer Engine is optional: probe for it without failing (hard_fail=False)
# and expose ``te`` as the ``transformer_engine.pytorch`` module when present,
# or ``None`` otherwise. Callers must check ``TE_AVAILABLE`` before using ``te``.
TE_AVAILABLE = check_version_spec("transformer_engine", hard_fail=False)
te = importlib.import_module("transformer_engine.pytorch") if TE_AVAILABLE else None


# Map of activation names to factories: calling a value with no arguments
# builds a fresh ``nn.Module``. ``leaky_relu`` is wrapped in ``functools.partial``
# (negative slope 0.1) so it is a factory like every other entry. A pre-built
# module instance would run ``forward()`` when called and would be shared
# between all users.
ACTIVATION = {
    "gelu": nn.GELU,
    "tanh": nn.Tanh,
    "sigmoid": nn.Sigmoid,
    "relu": nn.ReLU,
    "leaky_relu": functools.partial(nn.LeakyReLU, 0.1),
    "softplus": nn.Softplus,
    "ELU": nn.ELU,
    "silu": nn.SiLU,
}


class _TransolverMlp(Mlp):
    """Mlp subclass with state dict compatibility for legacy Transolver checkpoints.

    This class provides backward compatibility for loading checkpoints saved with
    the old Transolver MLP class, which used different attribute names:
    - Old: `linear_pre`, `linear_post`, `linears`
    - New: `layers` (Sequential)

    The mapping handles the common case where `n_layers=0` (no hidden layers with
    residual connections), which was the typical usage pattern in Transolver.
    Legacy `linears.*` keys are not remapped and will surface as unexpected keys.
    """

    # Mapping from old checkpoint keys to new Mlp keys.
    # This assumes the typical usage: n_layers=0, i.e. linear_pre -> act -> linear_post.
    # Old structure: linear_pre (input->hidden), linear_post (hidden->output)
    # New structure: layers.0 (input->hidden), layers.1 (activation), layers.2 (hidden->output)
    _OLD_TO_NEW_KEYS = {
        "linear_pre.weight": "layers.0.weight",
        "linear_pre.bias": "layers.0.bias",
        "linear_post.weight": "layers.2.weight",
        "linear_post.bias": "layers.2.bias",
    }

    # Inverse of _OLD_TO_NEW_KEYS (not used within this class).
    _NEW_TO_OLD_KEYS = {v: k for k, v in _OLD_TO_NEW_KEYS.items()}

    def _load_from_state_dict(
        self,
        state_dict: dict,
        prefix: str,
        local_metadata: dict,
        strict: bool,
        missing_keys: list,
        unexpected_keys: list,
        error_msgs: list | None = None,
    ):
        """Load state dict with automatic key remapping for legacy checkpoints.

        PyTorch calls this hook for each module during ``load_state_dict()``,
        passing its arguments positionally. Old-style keys (``linear_pre``,
        ``linear_post``) are renamed in place to new-style keys (``layers.0``,
        ``layers.2``) before delegating to the base implementation. An old key is
        only remapped when the matching new key is absent, so new-style
        checkpoints are untouched.

        Parameters
        ----------
        state_dict : dict
            State dict being loaded. Remapped entries are popped and reinserted
            under their new names.
        prefix : str
            Key prefix of this module, e.g. ``"preprocess."``.
        local_metadata, strict, missing_keys, unexpected_keys
            Forwarded unchanged to ``nn.Module._load_from_state_dict``.
        error_msgs : list | None, optional
            List that collects load errors. It is forwarded to the base class,
            and a fresh list is used when ``None``. The seventh argument of the
            PyTorch hook is this error list. The parameter was previously
            misdeclared as ``assign: bool = False``.
        """
        if error_msgs is None:
            error_msgs = []

        # e.g., prefix="preprocess." -> look for "preprocess.linear_pre.weight"
        for old_suffix, new_suffix in self._OLD_TO_NEW_KEYS.items():
            old_key = prefix + old_suffix
            new_key = prefix + new_suffix
            if old_key in state_dict and new_key not in state_dict:
                state_dict[new_key] = state_dict.pop(old_key)

        return super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class TransolverBlock(nn.Module):
    r"""
    Pre-norm transformer block that uses physics attention instead of standard attention.

    Physics attention learns to project inputs onto physics-informed slices
    before applying attention. The attention variant is selected from
    ``spatial_shape``:

    - ``None``: :class:`PhysicsAttentionIrregularMesh`
    - length 2: :class:`PhysicsAttentionStructuredMesh2D`
    - length 3: :class:`PhysicsAttentionStructuredMesh3D`

    Parameters
    ----------
    num_heads : int
        Number of attention heads. Each head has ``hidden_dim // num_heads`` channels.
    hidden_dim : int
        Hidden dimension of the block.
    dropout : float
        Dropout rate, forwarded to the physics attention layer.
    act : str, optional, default="gelu"
        Activation function name. Only used by the non-Transformer-Engine MLP.
    mlp_ratio : int, optional, default=4
        Ratio of MLP hidden dimension to ``hidden_dim``.
    last_layer : bool, optional, default=False
        Whether this is the last layer. If so, a layer norm plus linear
        projection to ``out_dim`` is applied at the end.
    out_dim : int, optional, default=1
        Output dimension (only used if ``last_layer=True``).
    slice_num : int, optional, default=32
        Number of physics slices.
    spatial_shape : tuple[int, ...] | None, optional, default=None
        Spatial shape for structured data. ``None`` for irregular meshes.
    use_te : bool, optional, default=True
        Whether to build layers from Transformer Engine.
    plus : bool, optional, default=False
        Whether to use the Transolver++ variant.

    Raises
    ------
    ImportError
        If ``use_te=True`` but Transformer Engine is not installed.
    ValueError
        If ``spatial_shape`` is given with a length other than 2 or 3.

    Forward
    -------
    fx : torch.Tensor
        Input tensor of shape :math:`(B, N, C)` where :math:`B` is batch size,
        :math:`N` is number of tokens, :math:`C` is hidden dimension.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(B, N, C)`, or :math:`(B, N, C_{out})`
        if ``last_layer=True``.
    """

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        dropout: float,
        act: str = "gelu",
        mlp_ratio: int = 4,
        last_layer: bool = False,
        out_dim: int = 1,
        slice_num: int = 32,
        spatial_shape: tuple[int, ...] | None = None,
        use_te: bool = True,
        plus: bool = False,
    ):
        super().__init__()

        if use_te and not TE_AVAILABLE:
            raise ImportError(
                "Transformer Engine is not installed. Please install it with "
                "`pip install transformer-engine`."
            )

        self.last_layer = last_layer

        # Pre-attention layer norm. Submodules are created in a fixed order
        # (ln_1, Attn, ln_mlp1, ln_mlp2) so parameter init and state_dict
        # ordering stay stable.
        norm_cls = te.LayerNorm if use_te else nn.LayerNorm
        self.ln_1 = norm_cls(hidden_dim)

        shared_attn_kwargs = dict(
            heads=num_heads,
            dim_head=hidden_dim // num_heads,
            dropout=dropout,
            slice_num=slice_num,
            use_te=use_te,
            plus=plus,
        )
        if spatial_shape is None:
            self.Attn = PhysicsAttentionIrregularMesh(hidden_dim, **shared_attn_kwargs)
        else:
            structured_attn_cls = {
                2: PhysicsAttentionStructuredMesh2D,
                3: PhysicsAttentionStructuredMesh3D,
            }.get(len(spatial_shape))
            if structured_attn_cls is None:
                raise ValueError(
                    f"Unexpected length of spatial shape encountered in Transolver_block: "
                    f"{len(spatial_shape)}. Expected 2 or 3."
                )
            self.Attn = structured_attn_cls(
                hidden_dim, spatial_shape=spatial_shape, **shared_attn_kwargs
            )

        # Feed-forward network fused with its own layer norm.
        if not use_te:
            self.ln_mlp1 = nn.Sequential(
                nn.LayerNorm(hidden_dim),
                _TransolverMlp(
                    in_features=hidden_dim,
                    hidden_features=hidden_dim * mlp_ratio,
                    out_features=hidden_dim,
                    act_layer=act,
                    use_te=False,
                ),
            )
        else:
            self.ln_mlp1 = te.LayerNormMLP(
                hidden_size=hidden_dim,
                ffn_hidden_size=hidden_dim * mlp_ratio,
            )

        # Output projection, only present on the final block.
        if last_layer:
            if not use_te:
                self.ln_mlp2 = nn.Sequential(
                    nn.LayerNorm(hidden_dim),
                    nn.Linear(hidden_dim, out_dim),
                )
            else:
                self.ln_mlp2 = te.LayerNormLinear(
                    in_features=hidden_dim, out_features=out_dim
                )

    def forward(
        self, fx: Float[torch.Tensor, "B N C"]
    ) -> Float[torch.Tensor, "B N C_out"]:
        r"""
        Run attention and feed-forward sub-layers, each with a residual connection.

        Parameters
        ----------
        fx : torch.Tensor
            Input tensor of shape :math:`(B, N, C)`.

        Returns
        -------
        torch.Tensor
            Output tensor of shape :math:`(B, N, C)`, or :math:`(B, N, C_{out})`
            if ``last_layer=True``.
        """
        hidden = self.Attn(self.ln_1(fx)) + fx
        hidden = self.ln_mlp1(hidden) + hidden
        return self.ln_mlp2(hidden) if self.last_layer else hidden
@dataclass
class MetaData(ModelMetaData):
    r"""Metadata for the Transolver model.

    An instance is passed to ``Module.__init__`` by :class:`Transolver`. The
    precise meaning of each flag is defined by ``ModelMetaData``.
    NOTE(review): flag semantics are not visible in this file; confirm against
    ``physicsnemo.core.meta``.
    """

    # Optimization
    jit: bool = False
    cuda_graphs: bool = False
    amp: bool = True
    # Inference
    onnx_cpu: bool = False  # No FFT op on CPU (NOTE(review): no FFT is visible in this file; confirm the reason)
    onnx_gpu: bool = True
    onnx_runtime: bool = True
    # Physics informed
    var_dim: int = 1
    func_torch: bool = False
    auto_grad: bool = False
class Transolver(Module):
    r"""
    Transolver model for physics-informed neural operator learning.

    Transolver adapts the transformer architecture with a physics-attention
    mechanism replacing standard attention. It can work on both structured
    (2D/3D grids) and unstructured (mesh) data.

    For architecture details, see:

    - `Transolver paper <https://arxiv.org/pdf/2402.02366>`_
    - `Transolver++ paper <https://arxiv.org/pdf/2502.02414>`_

    .. note::

        When using structured data, pass the ``structured_shape`` as a tuple.
        Length-2 tuples are treated as 2D image-like data, length-3 tuples as
        3D volumetric data.

    Parameters
    ----------
    functional_dim : int
        Dimension of input values, not including embeddings.
    out_dim : int
        Dimension of model output.
    embedding_dim : int | None, optional, default=None
        Dimension of input embeddings. Required if ``unified_pos=False``.
        Ignored if ``unified_pos=True`` (the embedding then has ``ref * ref``
        channels).
    n_layers : int, optional, default=4
        Number of Transolver blocks.
    n_hidden : int, optional, default=256
        Hidden dimension of the transformer.
    dropout : float, optional, default=0.0
        Dropout rate.
    n_head : int, optional, default=8
        Number of attention heads. Must evenly divide ``n_hidden``.
    act : str, optional, default="gelu"
        Activation function name.
    mlp_ratio : int, optional, default=4
        Ratio of MLP hidden dimension to ``n_hidden``.
    slice_num : int, optional, default=32
        Number of physics slices in attention layers.
    unified_pos : bool, optional, default=False
        Whether to use unified positional embeddings. Only supported for 2D
        structured data.
    ref : int, optional, default=8
        Reference grid size for unified position encoding.
    structured_shape : None | tuple[int, ...], optional, default=None
        Shape of structured data. ``None`` for unstructured mesh data. Entries
        must be positive integers; they are stored as a tuple of ``int``.
    use_te : bool, optional, default=True
        Whether to use Transformer Engine (must be installed).
    time_input : bool, optional, default=False
        Whether to include time embeddings. Must be ``True`` to pass ``time``
        to the forward pass.
    plus : bool, optional, default=False
        Whether to use Transolver++ variant.

    Raises
    ------
    ImportError
        If ``use_te=True`` but Transformer Engine is not installed.
    ValueError
        If ``n_hidden`` is not divisible by ``n_head``, if ``structured_shape``
        is invalid, if ``unified_pos=True`` without a 2D ``structured_shape``,
        or if ``embedding_dim`` is missing when ``unified_pos=False``.

    Forward
    -------
    fx : torch.Tensor
        Functional input tensor of shape :math:`(B, N, C_{in})` for flattened
        data where :math:`B` is batch size, :math:`N` is number of tokens, and
        :math:`C_{in}` is functional dimension. For structured data, shape is
        :math:`(B, H_s, W_s, C_{in})` for 2D or :math:`(B, H_s, W_s, D_s, C_{in})`
        for 3D, where :math:`H_s, W_s, D_s` are spatial dimensions.
    embedding : torch.Tensor | None, optional
        Embedding tensor. Required if ``unified_pos=False``. Shape should match
        ``fx`` spatial dimensions.
    time : torch.Tensor | None, optional
        Time tensor of shape :math:`(B,)` for time-dependent models.

    Outputs
    -------
    torch.Tensor
        Output tensor with same spatial shape as input and :math:`C_{out}`
        features (equal to ``out_dim``).

    Examples
    --------
    Structured 2D data with unified position:

    >>> import torch
    >>> from physicsnemo.models.transolver import Transolver
    >>> model = Transolver(
    ...     functional_dim=3,
    ...     out_dim=1,
    ...     structured_shape=(64, 64),
    ...     unified_pos=True,
    ...     n_hidden=128,
    ...     n_head=4,
    ...     use_te=False,
    ... )
    >>> x = torch.randn(2, 64, 64, 3)
    >>> out = model(x)
    >>> out.shape
    torch.Size([2, 64, 64, 1])

    Unstructured mesh data:

    >>> model = Transolver(
    ...     functional_dim=2,
    ...     embedding_dim=3,
    ...     out_dim=1,
    ...     structured_shape=None,
    ...     unified_pos=False,
    ...     n_hidden=128,
    ...     n_head=4,
    ...     use_te=False,
    ... )
    >>> fx = torch.randn(2, 1000, 2)
    >>> emb = torch.randn(2, 1000, 3)
    >>> out = model(fx, embedding=emb)
    >>> out.shape
    torch.Size([2, 1000, 1])
    """

    def __init__(
        self,
        functional_dim: int,
        out_dim: int,
        embedding_dim: int | None = None,
        n_layers: int = 4,
        n_hidden: int = 256,
        dropout: float = 0.0,
        n_head: int = 8,
        act: str = "gelu",
        mlp_ratio: int = 4,
        slice_num: int = 32,
        unified_pos: bool = False,
        ref: int = 8,
        structured_shape: None | tuple[int, ...] = None,
        use_te: bool = True,
        time_input: bool = False,
        plus: bool = False,
    ) -> None:
        super().__init__(meta=MetaData())
        self.use_te = use_te

        # Fail early with a clear message; otherwise the failure would surface
        # later from a submodule or from ``te`` being None in _init_weights.
        if use_te and not TE_AVAILABLE:
            raise ImportError(
                "Transformer Engine is not installed. Please install it with "
                "`pip install transformer-engine`."
            )

        # Validate hidden dimension and head compatibility
        if n_hidden % n_head != 0:
            raise ValueError(
                f"Transolver requires n_hidden % n_head == 0, "
                f"but got n_hidden={n_hidden}, n_head={n_head} "
                f"(remainder={n_hidden % n_head})"
            )

        # Validate structured shape if provided
        if structured_shape is not None:
            if len(structured_shape) not in (2, 3):
                raise ValueError(
                    f"Transolver only supports 2D or 3D structured data, "
                    f"got shape with {len(structured_shape)} dimensions"
                )
            if not all(s > 0 and s == int(s) for s in structured_shape):
                raise ValueError(
                    f"Transolver requires positive integer shapes, "
                    f"got {structured_shape}"
                )
            if unified_pos and len(structured_shape) != 2:
                # get_grid only builds a 2D reference grid.
                raise ValueError(
                    f"Transolver unified_pos=True only supports 2D structured data, "
                    f"got shape with {len(structured_shape)} dimensions"
                )
            # Integer-valued floats (e.g. 64.0) pass validation above but would
            # break np.linspace / reshape later, so normalize to ints.
            structured_shape = tuple(int(s) for s in structured_shape)
        elif unified_pos:
            raise ValueError(
                "Transolver requires structured_shape when using unified_pos=True"
            )

        self.structured_shape = structured_shape
        self.unified_pos = unified_pos

        # Set up positional embeddings
        if unified_pos:
            # Fixed (1, H*W, ref*ref) distance encoding, registered as a buffer
            # so it follows the module's device/dtype and lands in state_dict.
            self.register_buffer("embedding", self.get_grid(ref))
            self.embedding_dim = ref * ref
        else:
            if embedding_dim is None:
                raise ValueError(
                    "Transolver requires embedding_dim when unified_pos=False"
                )
            self.embedding_dim = embedding_dim
        mlp_input_dimension = functional_dim + self.embedding_dim

        # Initial projection MLP
        self.preprocess = _TransolverMlp(
            in_features=mlp_input_dimension,
            hidden_features=n_hidden * 2,
            out_features=n_hidden,
            act_layer=act,
            use_te=use_te,
        )

        self.time_input = time_input
        self.n_hidden = n_hidden

        # Time embedding projection
        if time_input:
            self.time_embed = PositionalEmbedding(
                num_channels=n_hidden,
                max_positions=10000,
                endpoint=False,
                learnable=False,
                embed_fn="cos_sin",
            )
            self.time_fc = nn.Sequential(
                nn.Linear(n_hidden, n_hidden),
                nn.SiLU(),
                nn.Linear(n_hidden, n_hidden),
            )

        # Build transformer blocks; only the last one projects to out_dim.
        self.blocks = nn.ModuleList(
            [
                TransolverBlock(
                    num_heads=n_head,
                    hidden_dim=n_hidden,
                    dropout=dropout,
                    act=act,
                    mlp_ratio=mlp_ratio,
                    out_dim=out_dim,
                    slice_num=slice_num,
                    spatial_shape=structured_shape,
                    last_layer=(layer_idx == n_layers - 1),
                    use_te=use_te,
                    plus=plus,
                )
                for layer_idx in range(n_layers)
            ]
        )
        self.initialize_weights()

    def initialize_weights(self) -> None:
        r"""Initialize model weights using truncated normal distribution."""
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        r"""
        Initialize weights for a single module.

        Linear layers get truncated-normal weights (std 0.02) and zero bias.
        Norm layers get unit weight and zero bias. Parameters that are ``None``
        (e.g. non-affine norms) are skipped.

        Parameters
        ----------
        m : nn.Module
            Module to initialize.
        """
        linear_layers: tuple[type, ...] = (nn.Linear,)
        norm_layers: tuple[type, ...] = (nn.LayerNorm, nn.BatchNorm1d)
        if self.use_te:
            linear_layers += (te.Linear,)
            norm_layers += (te.LayerNorm,)

        if isinstance(m, linear_layers):
            nn.init.trunc_normal_(m.weight, std=0.02)  # type: ignore[arg-type]
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)  # type: ignore[arg-type]

        if isinstance(m, norm_layers):
            if getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias, 0)  # type: ignore[arg-type]
            if getattr(m, "weight", None) is not None:
                nn.init.constant_(m.weight, 1.0)  # type: ignore[arg-type]

    @staticmethod
    def _make_grid_2d(size_x: int, size_y: int, batchsize: int) -> torch.Tensor:
        r"""Return a ``(batchsize, size_x, size_y, 2)`` grid of (x, y) coordinates in [0, 1]."""
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        return torch.cat((gridx, gridy), dim=-1)

    def get_grid(self, ref: int, batchsize: int = 1) -> torch.Tensor:
        r"""
        Generate unified positional encoding grid for structured 2D data.

        Each grid point is encoded by its Euclidean distance to every point of
        a ``ref x ref`` reference grid on the unit square.

        Parameters
        ----------
        ref : int
            Reference grid size for unified position encoding.
        batchsize : int, optional, default=1
            Batch size for the generated grid.

        Returns
        -------
        torch.Tensor
            Positional encoding tensor of shape :math:`(B, H \times W, \text{ref}^2)`.

        Raises
        ------
        ValueError
            If ``structured_shape`` is ``None`` or is not 2D.
        """
        if self.structured_shape is None:
            raise ValueError(
                "Cannot generate positional encoding grid: structured_shape is None. "
                "This method requires structured_shape to be set."
            )
        if len(self.structured_shape) != 2:
            raise ValueError(
                "Unified positional encoding is only implemented for 2D structured "
                f"data, got structured_shape={tuple(self.structured_shape)}"
            )
        size_x, size_y = (int(s) for s in self.structured_shape)

        grid = self._make_grid_2d(size_x, size_y, batchsize)  # (B, H, W, 2)
        grid_ref = self._make_grid_2d(ref, ref, batchsize)  # (B, ref, ref, 2)

        # Pairwise distances: (B, H, W, ref, ref), then flatten spatial/ref dims.
        diff = grid[:, :, :, None, None, :] - grid_ref[:, None, None, :, :, :]
        pos = torch.sqrt(torch.sum(diff**2, dim=-1))
        return pos.reshape(batchsize, -1, ref * ref).contiguous()

    def forward(
        self,
        fx: Float[torch.Tensor, "B *spatial C_in"],
        embedding: Float[torch.Tensor, "B *spatial C_emb"] | None = None,
        time: Float[torch.Tensor, " B"] | None = None,
    ) -> Float[torch.Tensor, "B *spatial C_out"]:
        r"""
        Forward pass of the Transolver model.

        Parameters
        ----------
        fx : torch.Tensor
            Functional input tensor. Shape :math:`(B, N, C_{in})` for flattened
            data or :math:`(B, H_s, W_s, C_{in})` / :math:`(B, H_s, W_s, D_s, C_{in})`
            for structured data, where :math:`B` is batch size, :math:`N` is
            number of tokens, and :math:`C_{in}` is functional dimension.
        embedding : torch.Tensor | None, optional
            Embedding tensor. Required if ``unified_pos=False``. It may be
            flattened :math:`(B, N, C_{emb})` or, for structured data, match
            the spatial shape of ``fx``. Ignored if ``unified_pos=True``.
        time : torch.Tensor | None, optional
            Time tensor of shape :math:`(B,)`. Requires ``time_input=True``.

        Returns
        -------
        torch.Tensor
            Output tensor with same spatial shape as input and :math:`C_{out}`
            features.

        Raises
        ------
        ValueError
            If ``fx`` has fewer than 2 dimensions, if a required ``embedding``
            is missing, or if ``time`` is given but ``time_input=False``.
        """
        # Input validation (skip during torch.compile for performance)
        if not torch.compiler.is_compiling():
            if fx.ndim < 2:
                raise ValueError(
                    f"Expected input tensor with at least 2 dimensions, "
                    f"got {fx.ndim}D tensor with shape {tuple(fx.shape)}"
                )
            if not self.unified_pos and embedding is None:
                raise ValueError("Embedding is required when unified_pos=False")
            if time is not None and not self.time_input:
                raise ValueError(
                    "Received a time tensor, but the model was built with "
                    "time_input=False"
                )

        unflatten_output = False

        if self.unified_pos:
            # Broadcast the (1, N, ref*ref) buffer over the batch; the concat
            # below materializes it, so no extra copy is needed here.
            emb_buffer: torch.Tensor = self.embedding  # type: ignore[assignment]
            embedding = emb_buffer.expand(fx.shape[0], -1, -1)

        if self.structured_shape is not None:
            # Flatten structured inputs (B, *spatial, C) -> (B, N, C).
            if fx.ndim != 3:
                unflatten_output = True
                fx = fx.reshape(fx.shape[0], -1, fx.shape[-1])
            if embedding is not None and embedding.ndim != 3:
                embedding = embedding.reshape(
                    embedding.shape[0], -1, embedding.shape[-1]
                )
        elif embedding is None:
            raise ValueError("Embedding is required for unstructured data")

        # Concatenate embedding with functional input along channels
        if embedding is not None:
            fx = torch.cat((embedding, fx), -1)

        # Project to hidden dimension
        fx = self.preprocess(fx)

        # Add time embedding: (B, n_hidden) -> (B, 1, n_hidden), broadcast over tokens.
        if time is not None:
            time_emb = self.time_fc(self.time_embed(time)).unsqueeze(1)
            fx = fx + time_emb

        for block in self.blocks:
            fx = block(fx)

        # Reshape back to structured format if the input was structured
        if unflatten_output:
            fx = fx.reshape(fx.shape[0], *self.structured_shape, -1)
        return fx