# Source code for physicsnemo.models.meshgraphnet.bsms_mgn

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Iterable, Literal, Optional

import torch
from torch import Tensor

from physicsnemo.core.meta import ModelMetaData
from physicsnemo.models.meshgraphnet import MeshGraphNet
from physicsnemo.nn.module.gnn_layers.bsms import BistrideGraphMessagePassing
from physicsnemo.nn.module.gnn_layers.utils import GraphType


@dataclass
class MetaData(ModelMetaData):
    """Capability flags for the bi-stride MeshGraphNet model.

    Overrides selected ``ModelMetaData`` fields; field order is preserved.
    """

    # ---- Optimization -------------------------------------------------
    # JIT stays off. The original note blamed DGLGraph; presumably the graph
    # containers used now are likewise not scriptable -- TODO confirm.
    jit: bool = False
    cuda_graphs: bool = False
    amp_cpu: bool = False
    amp_gpu: bool = True
    torch_fx: bool = False

    # ---- Inference ----------------------------------------------------
    onnx: bool = False

    # ---- Physics informed ---------------------------------------------
    func_torch: bool = True
    auto_grad: bool = True


class BiStrideMeshGraphNet(MeshGraphNet):
    r"""Bi-stride MeshGraphNet network architecture.

    Bi-stride MGN augments vanilla MGN with a U-Net-like multi-scale message
    passing that alternates between coarsening and refining the mesh. This
    improves modeling of long-range interactions.

    Parameters
    ----------
    input_dim_nodes : int
        Number of node features.
    input_dim_edges : int
        Number of edge features.
    output_dim : int
        Number of outputs.
    processor_size : int, optional, default=15
        Number of message passing blocks.
    mlp_activation_fn : str, optional, default="relu"
        Activation function to use.
    num_layers_node_processor : int, optional, default=2
        Number of MLP layers for processing nodes in each message passing
        block.
    num_layers_edge_processor : int, optional, default=2
        Number of MLP layers for processing edge features in each message
        passing block.
    num_mesh_levels : int, optional, default=2
        Number of mesh levels used by the bi-stride U-Net (multi-scale)
        processor.
    bistride_pos_dim : int, optional, default=3
        Dimensionality of node positions stored in ``graph.pos`` (required by
        bi-stride).
    num_layers_bistride : int, optional, default=2
        Number of layers within each bi-stride message passing block.
    bistride_unet_levels : int, optional, default=1
        Number of times to apply the bi-stride U-Net (depth of repeat).
    hidden_dim_processor : int, optional, default=128
        Hidden layer size for the message passing blocks.
    hidden_dim_node_encoder : int, optional, default=128
        Hidden layer size for the node feature encoder.
    num_layers_node_encoder : Union[int, None], optional, default=2
        Number of MLP layers for the node feature encoder. If ``None`` is
        provided, the MLP collapses to an identity function (no node encoder).
    hidden_dim_edge_encoder : int, optional, default=128
        Hidden layer size for the edge feature encoder.
    num_layers_edge_encoder : Union[int, None], optional, default=2
        Number of MLP layers for the edge feature encoder. If ``None`` is
        provided, the MLP collapses to an identity function (no edge encoder).
    hidden_dim_node_decoder : int, optional, default=128
        Hidden layer size for the node feature decoder.
    num_layers_node_decoder : Union[int, None], optional, default=2
        Number of MLP layers for the node feature decoder. If ``None`` is
        provided, the MLP collapses to an identity function (no decoder).
    aggregation : Literal["sum", "mean"], optional, default="sum"
        Message aggregation type. Allowed values are ``"sum"`` and ``"mean"``.
    do_concat_trick : bool, optional, default=False
        Whether to replace concat+MLP with MLP+idx+sum.
    num_processor_checkpoint_segments : int, optional, default=0
        Number of processor segments for gradient checkpointing (checkpointing
        disabled if 0). The number of segments should be a factor of
        :math:`2\times\text{processor\_size}`. For example, if
        ``processor_size`` is 15, then ``num_processor_checkpoint_segments``
        can be 10 since it's a factor of :math:`15 \times 2 = 30`. Start with
        fewer segments if memory is tight, as each segment affects training
        speed.
    recompute_activation : bool, optional, default=False
        Whether to recompute activations during backward to reduce memory
        usage.

    Forward
    -------
    node_features : torch.Tensor
        Input node features of shape :math:`(N_{nodes}, D_{in}^{node})`.
    edge_features : torch.Tensor
        Input edge features of shape :math:`(N_{edges}, D_{in}^{edge})`.
    graph : :class:`~physicsnemo.nn.module.gnn_layers.utils.GraphType`
        Graph connectivity/topology container (PyG). Connectivity/topology
        only. Do not duplicate node or edge features on the graph; pass them
        via ``node_features`` and ``edge_features``. If present on the graph,
        they will be ignored by the model.
        ``node_features.shape[0]`` must equal the number of nodes in the
        graph ``graph.num_nodes``.
        ``edge_features.shape[0]`` must equal the number of edges in the
        graph ``graph.num_edges``.
        The current
        :class:`~physicsnemo.nn.module.gnn_layers.graph_types.GraphType`
        resolves to PyTorch Geometric objects
        (``torch_geometric.data.Data`` or ``torch_geometric.data.HeteroData``).
        See :mod:`physicsnemo.nn.module.gnn_layers.graph_types` for the exact
        alias and requirements.
        Requires ``graph.pos`` with shape
        :math:`(N_{nodes}, \text{bistride\_pos\_dim})` for bi-stride.
    ms_edges : Iterable[torch.Tensor], optional
        Multi-scale edge lists; each is typically an integer tensor of shape
        :math:`(2, E_l)`. A leading size-1 dimension, if present, is
        squeezed away. Each tensor is moved to the device of ``graph.pos``.
    ms_ids : Iterable[torch.Tensor], optional
        Multi-scale node id tensors per level; typically shape
        :math:`(N_l,)`. A leading size-1 dimension, if present, is squeezed
        away. Each tensor is moved to the device of ``graph.pos``.
    **kwargs
        Accepted for interface compatibility and ignored.

    Outputs
    -------
    torch.Tensor
        Output node features of shape :math:`(N_{nodes}, D_{out})`.

    Raises
    ------
    ValueError
        If ``node_features`` / ``edge_features`` are not 2-D with the
        configured feature widths, or if ``graph.pos`` is missing. These
        checks are skipped while tracing under ``torch.compile``.

    Examples
    --------
    >>> import torch
    >>> from torch_geometric.data import Data
    >>> from physicsnemo.models.meshgraphnet.bsms_mgn import BiStrideMeshGraphNet
    >>>
    >>> # Choose a device and create the model on it
    >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    >>>
    >>> # Create a simple graph
    >>> num_nodes = 8
    >>> src = torch.arange(num_nodes, device=device)
    >>> dst = (src + 1) % num_nodes
    >>> edge_index = torch.stack([src, dst], dim=0)  # (2, E)
    >>> graph = Data(edge_index=edge_index, num_nodes=num_nodes).to(device)
    >>> graph.pos = torch.randn(num_nodes, 3, device=device)  # position needed by bi-stride
    >>>
    >>> # Features
    >>> node_features = torch.randn(num_nodes, 10, device=device)
    >>> edge_features = torch.randn(edge_index.shape[1], 4, device=device)
    >>>
    >>> # Multi-scale inputs (one level for simplicity)
    >>> ms_edges = [edge_index, edge_index]  # list of (2, E_l) tensors
    >>> ms_ids = [torch.arange(num_nodes, device=device), torch.arange(num_nodes, device=device)]  # list of (N_l,) tensors
    >>>
    >>> # Model
    >>> model = BiStrideMeshGraphNet(
    ...     input_dim_nodes=10,
    ...     input_dim_edges=4,
    ...     output_dim=4,
    ...     processor_size=2,
    ...     hidden_dim_processor=32,
    ...     hidden_dim_node_encoder=16,
    ...     hidden_dim_edge_encoder=16,
    ...     num_layers_bistride=1,
    ...     num_mesh_levels=1,
    ... ).to(device)
    >>>
    >>> out = model(node_features, edge_features, graph, ms_edges, ms_ids)
    >>> out.size()
    torch.Size([8, 4])

    Note
    ----
    Reference: `Efficient Learning of Mesh-Based Physical Simulation with
    Bi-Stride Multi-Scale Graph Neural Network
    <https://arxiv.org/pdf/2210.02573>`_.
    """

    def __init__(
        self,
        input_dim_nodes: int,
        input_dim_edges: int,
        output_dim: int,
        processor_size: int = 15,
        mlp_activation_fn: str = "relu",
        num_layers_node_processor: int = 2,
        num_layers_edge_processor: int = 2,
        num_mesh_levels: int = 2,
        bistride_pos_dim: int = 3,
        num_layers_bistride: int = 2,
        bistride_unet_levels: int = 1,
        hidden_dim_processor: int = 128,
        hidden_dim_node_encoder: int = 128,
        num_layers_node_encoder: Optional[int] = 2,
        hidden_dim_edge_encoder: int = 128,
        num_layers_edge_encoder: Optional[int] = 2,
        hidden_dim_node_decoder: int = 128,
        num_layers_node_decoder: Optional[int] = 2,
        aggregation: Literal["sum", "mean"] = "sum",
        do_concat_trick: bool = False,
        num_processor_checkpoint_segments: int = 0,
        recompute_activation: bool = False,
    ):
        # The vanilla MGN encoder / processor / decoder stack is built by the
        # parent class; only the bi-stride-specific pieces are added below.
        super().__init__(
            input_dim_nodes,
            input_dim_edges,
            output_dim,
            processor_size=processor_size,
            mlp_activation_fn=mlp_activation_fn,
            num_layers_node_processor=num_layers_node_processor,
            num_layers_edge_processor=num_layers_edge_processor,
            hidden_dim_processor=hidden_dim_processor,
            hidden_dim_node_encoder=hidden_dim_node_encoder,
            num_layers_node_encoder=num_layers_node_encoder,
            hidden_dim_edge_encoder=hidden_dim_edge_encoder,
            num_layers_edge_encoder=num_layers_edge_encoder,
            hidden_dim_node_decoder=hidden_dim_node_decoder,
            num_layers_node_decoder=num_layers_node_decoder,
            aggregation=aggregation,
            do_concat_trick=do_concat_trick,
            num_processor_checkpoint_segments=num_processor_checkpoint_segments,
            recompute_activation=recompute_activation,
        )
        self.meta = MetaData()

        # Number of times the bi-stride U-Net is applied in ``forward``.
        self.bistride_unet_levels = bistride_unet_levels
        # Operates on the processor's latent node features, hence
        # ``latent_dim=hidden_dim_processor``.
        self.bistride_processor = BistrideGraphMessagePassing(
            unet_depth=num_mesh_levels,
            latent_dim=hidden_dim_processor,
            hidden_layer=num_layers_bistride,
            pos_dim=bistride_pos_dim,
        )

    def forward(
        self,
        node_features: Tensor,
        edge_features: Tensor,
        graph: GraphType,
        ms_edges: Iterable[Tensor] = (),
        ms_ids: Iterable[Tensor] = (),
        **kwargs,
    ) -> Tensor:
        # Python-level input validation is skipped while torch.compile traces
        # the graph, so it does not introduce graph breaks.
        if not torch.compiler.is_compiling():
            if (
                node_features.ndim != 2
                or node_features.shape[1] != self.input_dim_nodes
            ):
                raise ValueError(
                    f"Expected tensor of shape (N_nodes, {self.input_dim_nodes}) but got tensor of shape {tuple(node_features.shape)}"
                )
            if (
                edge_features.ndim != 2
                or edge_features.shape[1] != self.input_dim_edges
            ):
                raise ValueError(
                    f"Expected tensor of shape (N_edges, {self.input_dim_edges}) but got tensor of shape {tuple(edge_features.shape)}"
                )
            # Bi-stride needs node positions; fail early with a clear message
            # instead of deep inside the bi-stride processor.
            if getattr(graph, "pos", None) is None:
                raise ValueError(
                    "BiStrideMeshGraphNet requires node positions in `graph.pos`, "
                    "but none were found on the graph"
                )

        # Standard MGN path: encode edges and nodes, then run the processor.
        edge_features = self.edge_encoder(edge_features)
        node_features = self.node_encoder(node_features)
        x = self.processor(node_features, edge_features, graph)

        node_pos = graph.pos
        device = node_pos.device
        # Drop an optional leading batch dimension of size 1 and put every
        # multi-scale tensor (edges AND ids) on the same device as positions.
        ms_edges = [es.to(device).squeeze(0) for es in ms_edges]
        ms_ids = [ids.to(device).squeeze(0) for ids in ms_ids]

        for _ in range(self.bistride_unet_levels):
            x = self.bistride_processor(x, ms_ids, ms_edges, node_pos)

        x = self.node_decoder(x)
        return x