Graph Neural Networks

class physicsnemo.models.meshgraphnet.meshgraphnet.MeshGraphNet(*args, **kwargs)

Bases: Module

MeshGraphNet network architecture.

Parameters:
  • input_dim_nodes (int) – Number of node features.

  • input_dim_edges (int) – Number of edge features.

  • output_dim (int) – Number of outputs.

  • processor_size (int, optional, default=15) – Number of message passing blocks.

  • mlp_activation_fn (str, optional, default="relu") – Activation function to use.

  • num_layers_node_processor (int, optional, default=2) – Number of MLP layers for processing nodes in each message passing block.

  • num_layers_edge_processor (int, optional, default=2) – Number of MLP layers for processing edge features in each message passing block.

  • hidden_dim_processor (int, optional, default=128) – Hidden layer size for the message passing blocks.

  • hidden_dim_node_encoder (int, optional, default=128) – Hidden layer size for the node feature encoder.

  • num_layers_node_encoder (int, optional, default=2) – Number of MLP layers for the node feature encoder.

  • hidden_dim_edge_encoder (int, optional, default=128) – Hidden layer size for the edge feature encoder.

  • num_layers_edge_encoder (int, optional, default=2) – Number of MLP layers for the edge feature encoder.

  • hidden_dim_node_decoder (int, optional, default=128) – Hidden layer size for the node feature decoder.

  • num_layers_node_decoder (int, optional, default=2) – Number of MLP layers for the node feature decoder.

  • aggregation (Literal["sum", "mean"], optional, default="sum") – Message aggregation type. Allowed values are "sum" and "mean".

  • do_concat_trick (bool, optional, default=False) – Whether to replace concat+MLP with MLP+idx+sum.

  • num_processor_checkpoint_segments (int, optional, default=0) – Number of processor segments for gradient checkpointing (0 disables checkpointing).

  • checkpoint_offloading (bool, optional, default=False) – Whether to offload checkpointed activations to the CPU.

  • norm_type (Literal["LayerNorm", "TELayerNorm"], optional, default="LayerNorm") – Normalization type. Allowed values are "LayerNorm" and "TELayerNorm". "TELayerNorm" refers to the Transformer Engine implementation of LayerNorm and requires NVIDIA Transformer Engine to be installed (optional dependency).
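The do_concat_trick option relies on the first layer of an edge MLP being linear. Applying one weight matrix to the concatenation [e, x_src, x_dst] gives the same result as applying separate row blocks of that matrix to each piece and summing, so the node-feature blocks can be applied once per node and then gathered per edge. The NumPy sketch below checks that identity. The feature order inside the concatenation and the single-layer setup are assumptions for illustration, not PhysicsNeMo's actual layer code.

```python
import numpy as np

rng = np.random.default_rng(0)

num_nodes, num_edges = 5, 7
d_node, d_edge, d_hidden = 4, 3, 6

x = rng.standard_normal((num_nodes, d_node))   # node features
e = rng.standard_normal((num_edges, d_edge))   # edge features
src = rng.integers(0, num_nodes, num_edges)    # source node of each edge
dst = rng.integers(0, num_nodes, num_edges)    # destination node of each edge

# First linear layer of an edge MLP acting on [e, x_src, x_dst] (assumed order).
W = rng.standard_normal((d_edge + 2 * d_node, d_hidden))
b = rng.standard_normal(d_hidden)

# concat + MLP: build the wide per-edge input, then apply the layer.
h_concat = np.concatenate([e, x[src], x[dst]], axis=1) @ W + b

# MLP + idx + sum: split W into row blocks, project node features once per node,
# gather the projections by edge index, and sum the partial results.
W_e, W_src, W_dst = np.split(W, [d_edge, d_edge + d_node], axis=0)
proj_src = x @ W_src                            # (num_nodes, d_hidden)
proj_dst = x @ W_dst                            # (num_nodes, d_hidden)
h_trick = e @ W_e + proj_src[src] + proj_dst[dst] + b

print(np.allclose(h_concat, h_trick))           # True
```

Since a mesh graph typically has more edges than nodes, projecting node features per node before gathering avoids materializing the wide concatenated per-edge input.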

Forward:
  • node_features (torch.Tensor) – Input node features of shape \((N_{nodes}, D_{in}^{node})\).

  • edge_features (torch.Tensor) – Input edge features of shape \((N_{edges}, D_{in}^{edge})\).

  • graph (GraphType) – Graph connectivity/topology container (PyG). The graph carries connectivity only: do not duplicate node or edge features on it, and pass them via node_features and edge_features instead; any features stored on the graph are ignored by the model. node_features.shape[0] must equal graph.num_nodes, and edge_features.shape[0] must equal graph.num_edges. The current GraphType resolves to PyTorch Geometric objects (torch_geometric.data.Data or torch_geometric.data.HeteroData). See physicsnemo.nn.module.gnn_layers.graph_types for the exact alias and requirements.

Outputs:

torch.Tensor – Output node features of shape \((N_{nodes}, D_{out})\).
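These shape rules are easy to violate when a PyG Data object is built without num_nodes. In that case PyG falls back to inferring the node count from edge_index (largest node index + 1), which undercounts nodes that have no edges. The pure-Python sketch below mirrors that inference and the documented checks; infer_num_nodes and check_graph_inputs are hypothetical helpers, not part of PhysicsNeMo or PyG.

```python
def infer_num_nodes(edge_index):
    """Node count implied by connectivity alone: largest node index + 1."""
    src, dst = edge_index
    return max(max(src), max(dst)) + 1


def check_graph_inputs(n_node_rows, n_edge_rows, edge_index, num_nodes=None):
    """Hypothetical pre-flight check for the documented shape requirements.

    n_node_rows / n_edge_rows stand in for node_features.shape[0] and
    edge_features.shape[0]; num_nodes stands in for an explicit graph.num_nodes.
    """
    n_nodes = num_nodes if num_nodes is not None else infer_num_nodes(edge_index)
    n_edges = len(edge_index[0])
    if n_node_rows != n_nodes:
        raise ValueError(f"node_features has {n_node_rows} rows but the graph has {n_nodes} nodes")
    if n_edge_rows != n_edges:
        raise ValueError(f"edge_features has {n_edge_rows} rows but the graph has {n_edges} edges")


# 10 nodes carry features, but the highest node index used by any edge is 6.
edge_index = [[0, 2, 4, 6, 1], [1, 3, 5, 0, 6]]

try:
    check_graph_inputs(10, 5, edge_index)            # node count inferred from edges
except ValueError as err:
    print(err)  # prints: node_features has 10 rows but the graph has 7 nodes

check_graph_inputs(10, 5, edge_index, num_nodes=10)  # explicit node count: passes
```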

Examples

>>> # ``norm_type`` in MeshGraphNet is deprecated: Transformer Engine (TE)
>>> # is used automatically when available unless disabled. You do not need
>>> # to set this variable (TE is faster); this only shows how to disable it.
>>> # Set it before importing PhysicsNeMo.
>>> import os
>>> os.environ['PHYSICSNEMO_FORCE_TE'] = 'False'
>>>
>>> import torch
>>> from torch_geometric.data import Data
>>> from physicsnemo.models.meshgraphnet import MeshGraphNet
>>>
>>> model = MeshGraphNet(
...     input_dim_nodes=4,
...     input_dim_edges=3,
...     output_dim=2,
... )
>>> num_nodes, num_edges = 10, 5
>>> edge_index = torch.randint(0, num_nodes, (2, num_edges))
>>> # Set num_nodes explicitly; otherwise PyG infers it from edge_index,
>>> # which undercounts nodes that have no edges
>>> graph = Data(edge_index=edge_index, num_nodes=num_nodes)
>>> node_features = torch.randn(num_nodes, 4)
>>> edge_features = torch.randn(num_edges, 3)
>>> output = model(node_features, edge_features, graph)
>>> output.size()
torch.Size([10, 2])

Note

Reference: Learning Mesh-Based Simulation with Graph Networks <https://arxiv.org/pdf/2010.03409>.

See also MeshGraphMLP, MeshEdgeBlock, and MeshNodeBlock.

class physicsnemo.models.meshgraphnet.meshgraphnet.MeshGraphNetProcessor(*args, **kwargs)

Bases: Module

MeshGraphNet processor block.

Parameters:
  • processor_size (int, optional, default=15) – Number of alternating edge/node update layers in the processor.

  • input_dim_node (int, optional, default=128) – Dimensionality of per-node hidden features provided to the processor.

  • input_dim_edge (int, optional, default=128) – Dimensionality of per-edge hidden features provided to the processor.

  • num_layers_node (int, optional, default=2) – Number of MLP layers within each node update block.

  • num_layers_edge (int, optional, default=2) – Number of MLP layers within each edge update block.

  • aggregation (Literal["sum", "mean"], optional, default="sum") – Message aggregation type. Allowed values are "sum" and "mean".

  • norm_type (Literal["LayerNorm", "TELayerNorm"], optional, default="LayerNorm") – Normalization type. Allowed values are "LayerNorm" and "TELayerNorm". "TELayerNorm" uses the Transformer Engine LayerNorm and requires NVIDIA Transformer Engine to be installed.

  • activation_fn (torch.nn.Module, optional, default=nn.ReLU()) – Activation function module used inside the MLPs.

  • do_concat_trick (bool, optional, default=False) – Whether to replace concat+MLP with MLP+idx+sum.

  • num_processor_checkpoint_segments (int, optional, default=0) – Number of checkpoint segments across processor layers (0 disables checkpointing).

  • checkpoint_offloading (bool, optional, default=False) – Whether to offload checkpoint activations to CPU.

Forward:
  • node_features (torch.Tensor) – Node features of shape \((N_{nodes}, D_{node})\).

  • edge_features (torch.Tensor) – Edge features of shape \((N_{edges}, D_{edge})\).

  • graph (GraphType) – Graph connectivity/topology container (PyG). The graph carries connectivity only: do not duplicate node or edge features on it, and pass them via node_features and edge_features instead; any features stored on the graph are ignored by the model. node_features.shape[0] must equal graph.num_nodes, and edge_features.shape[0] must equal graph.num_edges. The current GraphType resolves to PyTorch Geometric objects (torch_geometric.data.Data or torch_geometric.data.HeteroData). See physicsnemo.nn.module.gnn_layers.graph_types for the exact alias and requirements.

Outputs:

torch.Tensor – Updated node features of shape \((N_{nodes}, D_{node})\).
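Each processor layer pairs an edge update with a node update. The NumPy sketch below shows such pairs with "sum" or "mean" aggregation. It is a simplified illustration under stated assumptions: each "MLP" is a single random linear layer with ReLU (make_mlp is hypothetical), both updates use residual connections, and messages are aggregated at each edge's destination node. The real MeshEdgeBlock and MeshNodeBlock use deeper MLPs with normalization (see norm_type), and their exact inputs may differ.

```python
import numpy as np

rng = np.random.default_rng(0)


def make_mlp(d_in, d_out):
    """Stand-in for an MLP (hypothetical): one random linear layer followed by ReLU."""
    W = rng.standard_normal((d_in, d_out)) / np.sqrt(d_in)
    return lambda h: np.maximum(h @ W, 0.0)


def processor_layer(x, e, src, dst, edge_mlp, node_mlp, aggregation="sum"):
    """One edge update followed by one node update, each with a residual connection."""
    # Edge update: every edge sees its own feature plus both endpoint features.
    e_new = e + edge_mlp(np.concatenate([e, x[src], x[dst]], axis=1))
    # Aggregate the updated edge features at their destination nodes.
    agg = np.zeros((x.shape[0], e_new.shape[1]))
    np.add.at(agg, dst, e_new)
    if aggregation == "mean":
        counts = np.bincount(dst, minlength=x.shape[0]).reshape(-1, 1)
        agg = agg / np.maximum(counts, 1)  # nodes without incoming edges keep zeros
    # Node update: every node sees its own feature plus its aggregated messages.
    x_new = x + node_mlp(np.concatenate([x, agg], axis=1))
    return x_new, e_new


num_nodes, num_edges, dim = 6, 10, 8
x = rng.standard_normal((num_nodes, dim))
e = rng.standard_normal((num_edges, dim))
src = rng.integers(0, num_nodes, num_edges)
dst = rng.integers(0, num_nodes, num_edges)

processor_size = 3  # separate weights for every edge/node update pair
layers = [(make_mlp(3 * dim, dim), make_mlp(2 * dim, dim)) for _ in range(processor_size)]
for edge_mlp, node_mlp in layers:
    x, e = processor_layer(x, e, src, dst, edge_mlp, node_mlp, aggregation="mean")
print(x.shape, e.shape)  # (6, 8) (10, 8)
```

Feature widths stay fixed across layers, which is why input_dim_node and input_dim_edge describe both the processor's inputs and its outputs.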

class physicsnemo.models.mesh_reduced.mesh_reduced.Mesh_Reduced(
input_dim_nodes: int,
input_dim_edges: int,
output_decode_dim: int,
output_encode_dim: int = 3,
processor_size: int = 15,
num_layers_node_processor: int = 2,
num_layers_edge_processor: int = 2,
hidden_dim_processor: int = 128,
hidden_dim_node_encoder: int = 128,
num_layers_node_encoder: int = 2,
hidden_dim_edge_encoder: int = 128,
num_layers_edge_encoder: int = 2,
hidden_dim_node_decoder: int = 128,
num_layers_node_decoder: int = 2,
k: int = 3,
aggregation: Literal['sum', 'mean'] = 'mean',
)

Bases: Module

PbGMR-GMUS architecture: a mesh-reduced model that combines encoding and decoding processors to predict physics in a reduced mesh space.

Parameters:
  • input_dim_nodes (int) – Number of node features.

  • input_dim_edges (int) – Number of edge features.

  • output_decode_dim (int) – Number of decoding outputs (per node).

  • output_encode_dim (int, optional, default=3) – Number of encoding outputs (per pivotal position).

  • processor_size (int, optional, default=15) – Number of message passing blocks.

  • num_layers_node_processor (int, optional, default=2) – Number of MLP layers for processing nodes in each message passing block.

  • num_layers_edge_processor (int, optional, default=2) – Number of MLP layers for processing edge features in each message passing block.

  • hidden_dim_processor (int, optional, default=128) – Hidden layer size for the message passing blocks.

  • hidden_dim_node_encoder (int, optional, default=128) – Hidden layer size for the node feature encoder.

  • num_layers_node_encoder (int, optional, default=2) – Number of MLP layers for the node feature encoder.

  • hidden_dim_edge_encoder (int, optional, default=128) – Hidden layer size for the edge feature encoder.

  • num_layers_edge_encoder (int, optional, default=2) – Number of MLP layers for the edge feature encoder.

  • hidden_dim_node_decoder (int, optional, default=128) – Hidden layer size for the node feature decoder.

  • num_layers_node_decoder (int, optional, default=2) – Number of MLP layers for the node feature decoder.

  • k (int, optional, default=3) – Number of nearest neighbors for interpolation.

  • aggregation (Literal["sum", "mean"], optional, default="mean") – Message aggregation type. Allowed values are "sum" and "mean".

Forward:
  • node_features (torch.Tensor) – Input node features of shape \((N_{nodes}^{batch}, D_{in}^{node})\).

  • edge_features (torch.Tensor) – Input edge features of shape \((N_{edges}^{batch}, D_{in}^{edge})\).

  • graph (GraphType) – Graph connectivity/topology container (PyG). The graph carries connectivity only: do not duplicate node or edge features on it, and pass them via node_features and edge_features instead; any features stored on the graph are ignored by the model. node_features.shape[0] must equal graph.num_nodes, and edge_features.shape[0] must equal graph.num_edges. The current GraphType resolves to PyTorch Geometric objects (torch_geometric.data.Data or torch_geometric.data.HeteroData). See physicsnemo.nn.gnn_layers.graph_types for the exact alias and requirements.

  • position_mesh (torch.Tensor) – Per-graph reference mesh positions of shape \((N_{mesh}, D_{pos})\). These positions are repeated internally across the batch.

  • position_pivotal (torch.Tensor) – Per-graph pivotal positions of shape \((N_{pivotal}, D_{pos})\). These positions are repeated internally across the batch.

Returns:

Decoded node features of shape \((N_{nodes}^{batch}, D_{out}^{decode})\).

Return type:

torch.Tensor
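In the shapes above, the superscript batch denotes all graphs of a PyG batch stacked along the first dimension, while position_mesh and position_pivotal describe a single graph. Assuming every graph in the batch shares the same reference mesh and pivotal positions, the internal repetition amounts to tiling those positions once per graph and building matching batch vectors. The NumPy sketch below shows the resulting shapes; tile_positions is a hypothetical helper, and the model's own implementation may differ.

```python
import numpy as np


def tile_positions(pos, num_graphs):
    """Repeat per-graph positions for each graph in a batch and build the batch vector."""
    n = pos.shape[0]
    pos_batched = np.tile(pos, (num_graphs, 1))   # (num_graphs * n, D_pos)
    batch = np.repeat(np.arange(num_graphs), n)   # graph id per row: 0,...,0,1,...,1,...
    return pos_batched, batch


num_graphs = 4                                    # e.g. graph.batch.max() + 1
position_mesh = np.random.randn(20, 3)            # (N_mesh, D_pos)
position_pivotal = np.random.randn(5, 3)          # (N_pivotal, D_pos)

mesh_pos, mesh_batch = tile_positions(position_mesh, num_graphs)
piv_pos, piv_batch = tile_positions(position_pivotal, num_graphs)
print(mesh_pos.shape, piv_pos.shape)              # (80, 3) (20, 3)
```

The batch vectors produced this way have the form that knn_interpolate documents for batch_x and batch_y, which restricts neighbor search to points of the same graph.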

Examples

>>> import torch
>>> from torch_geometric.data import Data
>>> from physicsnemo.models.mesh_reduced.mesh_reduced import Mesh_Reduced
>>>
>>> # Choose a consistent device
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>>
>>> # Instantiate model
>>> model = Mesh_Reduced(
...     input_dim_nodes=4,
...     input_dim_edges=3,
...     output_decode_dim=2,
... ).to(device)
>>>
>>> # Build a simple PyG graph
>>> # Note: num_nodes must match len(position_mesh) for batch alignment
>>> num_mesh = 20
>>> num_nodes, num_edges = num_mesh, 30
>>> edge_index = torch.randint(0, num_nodes, (2, num_edges))
>>> graph = Data(edge_index=edge_index, num_nodes=num_nodes).to(device)
>>> # For a single graph, set a batch vector of zeros
>>> graph.batch = torch.zeros(num_nodes, dtype=torch.long, device=device)
>>>
>>> # Node/edge features
>>> node_features = torch.randn(num_nodes, 4, device=device)
>>> edge_features = torch.randn(num_edges, 3, device=device)
>>>
>>> # Per-graph positions (repeated internally across the batch)
>>> position_mesh = torch.randn(num_mesh, 3, device=device)       # (N_mesh, D_pos)
>>> position_pivotal = torch.randn(5, 3, device=device)     # (N_pivotal, D_pos)
>>>
>>> # Encode to pivotal space, then decode back to mesh space
>>> enc = model.encode(node_features, edge_features, graph, position_mesh, position_pivotal)
>>> out = model.decode(enc, edge_features, graph, position_mesh, position_pivotal)
>>> out.size()
torch.Size([20, 2])

Notes

Reference: Predicting physics in mesh-reduced space with temporal attention <https://arxiv.org/pdf/2201.09113>.

decode(
x: Tensor,
edge_features: Tensor,
graph: GraphType,
position_mesh: Tensor,
position_pivotal: Tensor,
) → Tensor

Decode pivotal features back to mesh space.

Parameters:
  • x (torch.Tensor) – Input features in pivotal space of shape \((N_{pivotal}^{batch}, D_{enc})\).

  • edge_features (torch.Tensor) – Edge features of shape \((N_{edges}^{batch}, D_{in}^{edge})\).

  • graph (GraphType) – Graph connectivity/topology container (PyG) for the mesh. The graph carries connectivity only: do not store node or edge features on it (any that are present are ignored by the model), and pass edge features via edge_features. edge_features.shape[0] must equal graph.num_edges. The current GraphType resolves to PyTorch Geometric objects (torch_geometric.data.Data or torch_geometric.data.HeteroData). See physicsnemo.nn.gnn_layers.graph_types for the exact alias and requirements.

  • position_mesh (torch.Tensor) – Per-graph mesh positions of shape \((N_{mesh}, D_{pos})\).

  • position_pivotal (torch.Tensor) – Per-graph pivotal positions of shape \((N_{pivotal}, D_{pos})\).

Returns:

Decoded features in mesh space of shape \((N_{nodes}^{batch}, D_{out}^{decode})\).

Return type:

torch.Tensor

encode(
x: Tensor,
edge_features: Tensor,
graph: GraphType,
position_mesh: Tensor,
position_pivotal: Tensor,
) → Tensor

Encode mesh features to pivotal space.

Parameters:
  • x (torch.Tensor) – Input node features of shape \((N_{nodes}^{batch}, D_{in}^{node})\).

  • edge_features (torch.Tensor) – Edge features of shape \((N_{edges}^{batch}, D_{in}^{edge})\).

  • graph (GraphType) – PyG graph container with batch information.

  • position_mesh (torch.Tensor) – Per-graph reference mesh positions of shape \((N_{mesh}, D_{pos})\).

  • position_pivotal (torch.Tensor) – Per-graph pivotal positions of shape \((N_{pivotal}, D_{pos})\).

Returns:

Encoded pivotal features of shape \((N_{pivotal}^{batch}, D_{enc})\).

Return type:

torch.Tensor

knn_interpolate(
x: Tensor,
pos_x: Tensor,
pos_y: Tensor,
batch_x: Tensor | None = None,
batch_y: Tensor | None = None,
k: int = 3,
num_workers: int = 1,
) → Tuple[Tensor, Tensor, Tensor, Tensor]

Perform k-nearest neighbor interpolation from pos_x to pos_y.

Parameters:
  • x (torch.Tensor) – Source features of shape \((N_x, D_x)\).

  • pos_x (torch.Tensor) – Source positions of shape \((N_x, D_{pos})\).

  • pos_y (torch.Tensor) – Target positions of shape \((N_y, D_{pos})\).

  • batch_x (torch.Tensor, optional) – Batch indices for pos_x of shape \((N_x,)\). If provided, neighbors are computed per-graph. Default is None.

  • batch_y (torch.Tensor, optional) – Batch indices for pos_y of shape \((N_y,)\). If provided, neighbors are computed per-graph. Default is None.

  • k (int, optional, default=3) – Number of nearest neighbors.

  • num_workers (int, optional, default=1) – Number of workers for the KNN search.

Returns:

A tuple (y, col, row, weights) where:

  • y: interpolated features of shape \((N_y, D_x)\)

  • col: indices into pos_x (source) of shape \((k \cdot N_y,)\)

  • row: indices into pos_y (target) of shape \((k \cdot N_y,)\)

  • weights: interpolation weights of shape \((k \cdot N_y, 1)\)

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
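The weighting used by Mesh_Reduced.knn_interpolate is not stated above. The NumPy sketch below assumes inverse squared-distance weights, as in PyTorch Geometric's knn_interpolate. It also uses a brute-force neighbor search in place of a KNN library, so num_workers is omitted, and returns (y, col, row, weights) with the shapes listed above. It illustrates the technique and is not the method's actual implementation.

```python
import numpy as np


def knn_interpolate(x, pos_x, pos_y, batch_x=None, batch_y=None, k=3):
    """Inverse squared-distance kNN interpolation from pos_x to pos_y (brute force)."""
    if batch_x is None:
        batch_x = np.zeros(len(pos_x), dtype=int)
    if batch_y is None:
        batch_y = np.zeros(len(pos_y), dtype=int)
    # Squared distance from every target point to every source point: (N_y, N_x).
    d2 = ((pos_y[:, None, :] - pos_x[None, :, :]) ** 2).sum(axis=-1)
    # Only allow neighbors from the same graph of the batch.
    d2 = np.where(batch_y[:, None] == batch_x[None, :], d2, np.inf)
    # k nearest sources per target, flattened target-major: (k * N_y,).
    col = np.argsort(d2, axis=1)[:, :k].reshape(-1)
    row = np.repeat(np.arange(len(pos_y)), k)
    # Inverse squared-distance weights; neighbors at infinite distance get weight 0.
    weights = (1.0 / np.maximum(d2[row, col], 1e-16))[:, None]  # (k * N_y, 1)
    # Weighted average of neighbor features for each target point.
    y = np.zeros((len(pos_y), x.shape[1]))
    np.add.at(y, row, weights * x[col])
    norm = np.zeros((len(pos_y), 1))
    np.add.at(norm, row, weights)
    return y / norm, col, row, weights


# Source points on a line carry the feature f(p) = p.
pos_x = np.array([[0.0], [1.0], [2.0], [3.0]])
x = pos_x.copy()
pos_y = np.array([[1.5]])
y, col, row, weights = knn_interpolate(x, pos_x, pos_y, k=2)
print(y)  # [[1.5]]
```

The 1e-16 clamp keeps the weight finite when a target coincides with a source point, in which case the interpolated value is essentially that source's feature.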

class physicsnemo.models.meshgraphnet.bsms_mgn.BiStrideMeshGraphNet(*args, **kwargs)

Bases: MeshGraphNet

Bi-stride MeshGraphNet network architecture.

Bi-stride MGN augments vanilla MGN with U-Net-like multi-scale message passing that alternates between coarsening and refining the mesh. Because messages on coarser levels cover larger distances per layer, this improves the modeling of long-range interactions.
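A simplified reading of the bi-stride coarsening described in the referenced paper: run a breadth-first search over the mesh graph, keep the nodes on every other BFS frontier, and connect kept nodes that were within two hops of each other in the fine graph. The pure-Python sketch below illustrates that idea on a path graph. It assumes a connected graph and a fixed seed node, bistride_select and coarse_edges are hypothetical helpers, and it is not PhysicsNeMo's preprocessing code. The kept node ids and coarse edges of a level play roughly the roles of the ms_ids and ms_edges forward inputs.

```python
from collections import deque


def bistride_select(adjacency, seed=0):
    """Keep nodes whose BFS depth from the seed is even (one coarsening step)."""
    depth = {seed: 0}
    queue = deque([seed])
    while queue:
        u = queue.popleft()
        for v in adjacency[u]:
            if v not in depth:
                depth[v] = depth[u] + 1
                queue.append(v)
    # Nodes unreachable from the seed are never visited (connected graph assumed).
    return sorted(n for n, d in depth.items() if d % 2 == 0)


def coarse_edges(adjacency, kept):
    """Connect kept nodes that are at most two hops apart in the fine graph."""
    kept_set = set(kept)
    edges = set()
    for u in kept:
        one_hop = set(adjacency[u])
        two_hop = {w for v in adjacency[u] for w in adjacency[v]}
        for w in one_hop | two_hop:
            if w != u and w in kept_set:
                edges.add((min(u, w), max(u, w)))
    return sorted(edges)


# Path graph 0-1-2-3-4-5-6 as an adjacency list.
adjacency = {i: [j for j in (i - 1, i + 1) if 0 <= j < 7] for i in range(7)}
kept = bistride_select(adjacency)
print(kept)                           # [0, 2, 4, 6]
print(coarse_edges(adjacency, kept))  # [(0, 2), (2, 4), (4, 6)]
```

Roughly half of the nodes survive each step, so stacking num_mesh_levels such steps gives the progressively coarser meshes the U-Net processor walks down and back up.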

Parameters:
  • input_dim_nodes (int) – Number of node features.

  • input_dim_edges (int) – Number of edge features.

  • output_dim (int) – Number of outputs.

  • processor_size (int, optional, default=15) – Number of message passing blocks.

  • mlp_activation_fn (str, optional, default="relu") – Activation function to use.

  • num_layers_node_processor (int, optional, default=2) – Number of MLP layers for processing nodes in each message passing block.

  • num_layers_edge_processor (int, optional, default=2) – Number of MLP layers for processing edge features in each message passing block.

  • num_mesh_levels (int, optional, default=2) – Number of mesh levels used by the bi-stride U-Net (multi-scale) processor.

  • bistride_pos_dim (int, optional, default=3) – Dimensionality of node positions stored in graph.pos (required by bi-stride).

  • num_layers_bistride (int, optional, default=2) – Number of layers within each bi-stride message passing block.

  • bistride_unet_levels (int, optional, default=1) – Number of times the bi-stride U-Net is applied (repeat depth).

  • hidden_dim_processor (int, optional, default=128) – Hidden layer size for the message passing blocks.

  • hidden_dim_node_encoder (int, optional, default=128) – Hidden layer size for the node feature encoder.

  • num_layers_node_encoder (Union[int, None], optional, default=2) – Number of MLP layers for the node feature encoder. If None is provided, the MLP collapses to an identity function (no node encoder).

  • hidden_dim_edge_encoder (int, optional, default=128) – Hidden layer size for the edge feature encoder.

  • num_layers_edge_encoder (Union[int, None], optional, default=2) – Number of MLP layers for the edge feature encoder. If None is provided, the MLP collapses to an identity function (no edge encoder).

  • hidden_dim_node_decoder (int, optional, default=128) – Hidden layer size for the node feature decoder.

  • num_layers_node_decoder (Union[int, None], optional, default=2) – Number of MLP layers for the node feature decoder. If None is provided, the MLP collapses to an identity function (no decoder).

  • aggregation (Literal["sum", "mean"], optional, default="sum") – Message aggregation type. Allowed values are "sum" and "mean".

  • do_concat_trick (bool, optional, default=False) – Whether to replace concat+MLP with MLP+idx+sum.

  • num_processor_checkpoint_segments (int, optional, default=0) – Number of processor segments for gradient checkpointing (checkpointing is disabled if 0). The number of segments should be a factor of \(2\times\text{processor\_size}\); for example, if processor_size is 15, num_processor_checkpoint_segments can be 10, since 10 is a factor of \(15 \times 2 = 30\). Start with a small number of segments and increase it only until the model fits in memory, since each segment affects training speed.

  • recompute_activation (bool, optional, default=False) – Whether to recompute activations during backward to reduce memory usage.
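Valid values for num_processor_checkpoint_segments are the divisors of \(2\times\text{processor\_size}\). The hypothetical helper below lists them. Attributing the factor 2 to one edge block plus one node block per message passing step follows the MeshGraphNetProcessor description rather than an explicitly stated rule.

```python
def valid_checkpoint_segments(processor_size: int) -> list[int]:
    """Return segment counts that evenly divide the 2 * processor_size processor layers."""
    num_layers = 2 * processor_size  # assumed: one edge block + one node block per step
    return [s for s in range(1, num_layers + 1) if num_layers % s == 0]


print(valid_checkpoint_segments(15))  # [1, 2, 3, 5, 6, 10, 15, 30]
```

With processor_size=15 and 10 segments, each checkpointed segment spans 30 / 10 = 3 layers.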

Forward:
  • node_features (torch.Tensor) – Input node features of shape \((N_{nodes}, D_{in}^{node})\).

  • edge_features (torch.Tensor) – Input edge features of shape \((N_{edges}, D_{in}^{edge})\).

  • graph (GraphType) – Graph connectivity/topology container (PyG). The graph carries connectivity only: do not duplicate node or edge features on it, and pass them via node_features and edge_features instead; any features stored on the graph are ignored by the model. node_features.shape[0] must equal graph.num_nodes, and edge_features.shape[0] must equal graph.num_edges. The current GraphType resolves to PyTorch Geometric objects (torch_geometric.data.Data or torch_geometric.data.HeteroData). See physicsnemo.nn.module.gnn_layers.graph_types for the exact alias and requirements. For bi-stride, the graph must also provide graph.pos with shape \((N_{nodes}, \text{bistride\_pos\_dim})\).

  • ms_edges (Iterable[torch.Tensor], optional) – Multi-scale edge lists; each is typically an integer tensor of shape \((2, E_l)\).

  • ms_ids (Iterable[torch.Tensor], optional) – Multi-scale node id tensors per level; typically shape \((N_l,)\).

Outputs:

torch.Tensor – Output node features of shape \((N_{nodes}, D_{out})\).

Examples

>>> import torch
>>> from torch_geometric.data import Data
>>> from physicsnemo.models.meshgraphnet.bsms_mgn import BiStrideMeshGraphNet
>>>
>>> # Choose a device and create the model on it
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>>
>>> # Create a simple graph
>>> num_nodes = 8
>>> src = torch.arange(num_nodes, device=device)
>>> dst = (src + 1) % num_nodes
>>> edge_index = torch.stack([src, dst], dim=0)  # (2, E)
>>> graph = Data(edge_index=edge_index, num_nodes=num_nodes).to(device)
>>> graph.pos = torch.randn(num_nodes, 3, device=device)  # position needed by bi-stride
>>>
>>> # Features
>>> node_features = torch.randn(num_nodes, 10, device=device)
>>> edge_features = torch.randn(edge_index.shape[1], 4, device=device)
>>>
>>> # Multi-scale inputs (for simplicity, the fine graph is reused at every level)
>>> ms_edges = [edge_index, edge_index]  # list of (2, E_l) tensors
>>> ms_ids = [torch.arange(num_nodes, device=device), torch.arange(num_nodes, device=device)]  # list of (N_l,) tensors
>>>
>>> # Model
>>> model = BiStrideMeshGraphNet(
...     input_dim_nodes=10,
...     input_dim_edges=4,
...     output_dim=4,
...     processor_size=2,
...     hidden_dim_processor=32,
...     hidden_dim_node_encoder=16,
...     hidden_dim_edge_encoder=16,
...     num_layers_bistride=1,
...     num_mesh_levels=1,
... ).to(device)
>>>
>>> out = model(node_features, edge_features, graph, ms_edges, ms_ids)
>>> out.size()
torch.Size([8, 4])

Note

Reference: Efficient Learning of Mesh-Based Physical Simulation with Bi-Stride Multi-Scale Graph Neural Network <https://arxiv.org/pdf/2210.02573>.