Graph Neural Networks#
- class physicsnemo.models.meshgraphnet.meshgraphnet.MeshGraphNet(*args, **kwargs)[source]#
Bases: Module
MeshGraphNet network architecture.
- Parameters:
input_dim_nodes (int) – Number of node features.
input_dim_edges (int) – Number of edge features.
output_dim (int) – Number of outputs.
processor_size (int, optional, default=15) – Number of message passing blocks.
mlp_activation_fn (str, optional, default="relu") – Activation function to use.
num_layers_node_processor (int, optional, default=2) – Number of MLP layers for processing nodes in each message passing block.
num_layers_edge_processor (int, optional, default=2) – Number of MLP layers for processing edge features in each message passing block.
hidden_dim_processor (int, optional, default=128) – Hidden layer size for the message passing blocks.
hidden_dim_node_encoder (int, optional, default=128) – Hidden layer size for the node feature encoder.
num_layers_node_encoder (int, optional, default=2) – Number of MLP layers for the node feature encoder.
hidden_dim_edge_encoder (int, optional, default=128) – Hidden layer size for the edge feature encoder.
num_layers_edge_encoder (int, optional, default=2) – Number of MLP layers for the edge feature encoder.
hidden_dim_node_decoder (int, optional, default=128) – Hidden layer size for the node feature decoder.
num_layers_node_decoder (int, optional, default=2) – Number of MLP layers for the node feature decoder.
aggregation (Literal["sum", "mean"], optional, default="sum") – Message aggregation type. Allowed values are "sum" and "mean".
do_concat_trick (bool, optional, default=False) – Whether to replace concat+MLP with MLP+idx+sum.
num_processor_checkpoint_segments (int, optional, default=0) – Number of processor segments for gradient checkpointing (0 disables checkpointing).
checkpoint_offloading (bool, optional, default=False) – Whether to offload the checkpointing to the CPU.
norm_type (Literal["LayerNorm", "TELayerNorm"], optional, default="LayerNorm") – Normalization type. Allowed values are "LayerNorm" and "TELayerNorm". "TELayerNorm" refers to the Transformer Engine implementation of LayerNorm and requires NVIDIA Transformer Engine to be installed (optional dependency).
- Forward:
node_features (torch.Tensor) – Input node features of shape \((N_{nodes}, D_{in}^{node})\).
edge_features (torch.Tensor) – Input edge features of shape \((N_{edges}, D_{in}^{edge})\).
graph (GraphType) – Graph connectivity/topology container (PyG). Connectivity/topology only. Do not duplicate node or edge features on the graph; pass them via node_features and edge_features. If present on the graph, they will be ignored by the model. node_features.shape[0] must equal the number of nodes in the graph, graph.num_nodes; edge_features.shape[0] must equal the number of edges in the graph, graph.num_edges. The current GraphType resolves to PyTorch Geometric objects (torch_geometric.data.Data or torch_geometric.data.HeteroData). See physicsnemo.nn.module.gnn_layers.graph_types for the exact alias and requirements.
- Outputs:
torch.Tensor – Output node features of shape \((N_{nodes}, D_{out})\).
Examples
>>> # ``norm_type`` in MeshGraphNet is deprecated;
>>> # TE will be used automatically if possible unless told otherwise.
>>> # (You don't have to set this variable; it's faster to use TE!)
>>> # Example of how to disable:
>>> import os
>>> os.environ['PHYSICSNEMO_FORCE_TE'] = 'False'
>>>
>>> import torch
>>> import physicsnemo.models.meshgraphnet
>>> model = physicsnemo.models.meshgraphnet.MeshGraphNet(
...     input_dim_nodes=4,
...     input_dim_edges=3,
...     output_dim=2,
... )
>>> from torch_geometric.data import Data
>>> edge_index = torch.randint(0, 10, (2, 5))
>>> graph = Data(edge_index=edge_index)
>>> node_features = torch.randn(10, 4)
>>> edge_features = torch.randn(5, 3)
>>> output = model(node_features, edge_features, graph)
>>> output.size()
torch.Size([10, 2])
Note
Reference: Learning Mesh-Based Simulation with Graph Networks <https://arxiv.org/pdf/2010.03409>.
See also
MeshGraphMLP, MeshEdgeBlock, and MeshNodeBlock.
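The do_concat_trick option above replaces concat+MLP with MLP+idx+sum. The identity it relies on is the block decomposition of a linear layer: applying a weight matrix to a concatenated input equals the sum of per-block partial products, which lets the per-edge concatenation be avoided. A minimal plain-Python sketch of that identity (illustrative shapes and helper names, not the PhysicsNeMo implementation):

```python
# Sketch: W @ concat(src, dst, edge) == W_src @ src + W_dst @ dst + W_e @ edge

def matvec(W, x):
    """Dense matrix-vector product over nested lists."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

# A 2x6 weight acting on a concatenated (src, dst, edge) vector
W = [[1, 2, 3, 4, 5, 6],
     [6, 5, 4, 3, 2, 1]]
src, dst, edge = [1.0, 2.0], [3.0, 4.0], [5.0, 6.0]

# Concat + linear
full = matvec(W, src + dst + edge)

# Split the weight into per-input blocks and sum the partial products
W_src = [row[0:2] for row in W]
W_dst = [row[2:4] for row in W]
W_e = [row[4:6] for row in W]
split = [a + b + c for a, b, c in zip(matvec(W_src, src),
                                      matvec(W_dst, dst),
                                      matvec(W_e, edge))]

assert full == split  # the two formulations agree exactly
```

In the real trick, the per-node partial products can be computed once per node and gathered by index, instead of materializing a concatenated feature per edge.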
- class physicsnemo.models.meshgraphnet.meshgraphnet.MeshGraphNetProcessor(*args, **kwargs)[source]#
Bases: Module
MeshGraphNet processor block.
- Parameters:
processor_size (int, optional, default=15) – Number of alternating edge/node update layers in the processor.
input_dim_node (int, optional, default=128) – Dimensionality of per-node hidden features provided to the processor.
input_dim_edge (int, optional, default=128) – Dimensionality of per-edge hidden features provided to the processor.
num_layers_node (int, optional, default=2) – Number of MLP layers within each node update block.
num_layers_edge (int, optional, default=2) – Number of MLP layers within each edge update block.
aggregation (Literal["sum", "mean"], optional, default="sum") – Message aggregation type. Allowed values are "sum" and "mean".
norm_type (Literal["LayerNorm", "TELayerNorm"], optional, default="LayerNorm") – Normalization type. Allowed values are "LayerNorm" and "TELayerNorm". "TELayerNorm" uses the Transformer Engine LayerNorm and requires NVIDIA Transformer Engine to be installed.
activation_fn (torch.nn.Module, optional, default=nn.ReLU()) – Activation function module used inside the MLPs.
do_concat_trick (bool, optional, default=False) – Whether to replace concat+MLP with MLP+idx+sum.
num_processor_checkpoint_segments (int, optional, default=0) – Number of checkpoint segments across processor layers (0 disables checkpointing).
checkpoint_offloading (bool, optional, default=False) – Whether to offload checkpoint activations to CPU.
- Forward:
node_features (torch.Tensor) – Node features of shape \((N_{nodes}, D_{node})\).
edge_features (torch.Tensor) – Edge features of shape \((N_{edges}, D_{edge})\).
graph (GraphType) – Graph connectivity/topology container (PyG). Connectivity/topology only. Do not duplicate node or edge features on the graph; pass them via node_features and edge_features. If present on the graph, they will be ignored by the model. node_features.shape[0] must equal the number of nodes in the graph, graph.num_nodes; edge_features.shape[0] must equal the number of edges in the graph, graph.num_edges. The current GraphType resolves to PyTorch Geometric objects (torch_geometric.data.Data or torch_geometric.data.HeteroData). See physicsnemo.nn.module.gnn_layers.graph_types for the exact alias and requirements.
- Outputs:
torch.Tensor – Updated node features of shape \((N_{nodes}, D_{node})\).
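Each processor layer alternates an edge update with an aggregate-then-update step on nodes. A structural sketch in plain Python with scalar features (toy update rules standing in for the MLPs; not the PhysicsNeMo implementation):

```python
def processor_layer(node_feats, edge_feats, edges, aggregation="sum"):
    """Toy message-passing layer.

    node_feats: list of floats (one per node); edge_feats: one per edge;
    edges: list of (src, dst) index pairs. Simple sums stand in for MLPs.
    """
    # Edge update: each message depends on both endpoints and the old edge value
    messages = [node_feats[s] + node_feats[d] + e
                for (s, d), e in zip(edges, edge_feats)]
    # Aggregate incoming messages at each destination node
    agg = [0.0] * len(node_feats)
    count = [0] * len(node_feats)
    for (s, d), m in zip(edges, messages):
        agg[d] += m
        count[d] += 1
    if aggregation == "mean":
        agg = [a / c if c else 0.0 for a, c in zip(agg, count)]
    # Node update with a residual connection
    new_nodes = [n + a for n, a in zip(node_feats, agg)]
    return new_nodes, messages

nodes = [1.0, 2.0, 3.0]
edges = [(0, 1), (2, 1)]
edge_feats = [0.5, 0.5]
print(processor_layer(nodes, edge_feats, edges, aggregation="mean"))
```

With `aggregation="sum"`, node 1 accumulates both messages; with `"mean"` the accumulated value is divided by its in-degree, which is why the two settings can behave differently on graphs with uneven degree.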
- class physicsnemo.models.mesh_reduced.mesh_reduced.Mesh_Reduced(
- input_dim_nodes: int,
- input_dim_edges: int,
- output_decode_dim: int,
- output_encode_dim: int = 3,
- processor_size: int = 15,
- num_layers_node_processor: int = 2,
- num_layers_edge_processor: int = 2,
- hidden_dim_processor: int = 128,
- hidden_dim_node_encoder: int = 128,
- num_layers_node_encoder: int = 2,
- hidden_dim_edge_encoder: int = 128,
- num_layers_edge_encoder: int = 2,
- hidden_dim_node_decoder: int = 128,
- num_layers_node_decoder: int = 2,
- k: int = 3,
- aggregation: Literal['sum', 'mean'] = 'mean',
- )[source]#
Bases: Module
PbGMR-GMUS architecture. A mesh-reduced architecture that combines encoding and decoding processors for physics prediction in reduced mesh space.
- Parameters:
input_dim_nodes (int) – Number of node features.
input_dim_edges (int) – Number of edge features.
output_decode_dim (int) – Number of decoding outputs (per node).
output_encode_dim (int, optional, default=3) – Number of encoding outputs (per pivotal position).
processor_size (int, optional, default=15) – Number of message passing blocks.
num_layers_node_processor (int, optional, default=2) – Number of MLP layers for processing nodes in each message passing block.
num_layers_edge_processor (int, optional, default=2) – Number of MLP layers for processing edge features in each message passing block.
hidden_dim_processor (int, optional, default=128) – Hidden layer size for the message passing blocks.
hidden_dim_node_encoder (int, optional, default=128) – Hidden layer size for the node feature encoder.
num_layers_node_encoder (int, optional, default=2) – Number of MLP layers for the node feature encoder.
hidden_dim_edge_encoder (int, optional, default=128) – Hidden layer size for the edge feature encoder.
num_layers_edge_encoder (int, optional, default=2) – Number of MLP layers for the edge feature encoder.
hidden_dim_node_decoder (int, optional, default=128) – Hidden layer size for the node feature decoder.
num_layers_node_decoder (int, optional, default=2) – Number of MLP layers for the node feature decoder.
k (int, optional, default=3) – Number of nearest neighbors for interpolation.
aggregation (Literal["sum", "mean"], optional, default="mean") – Message aggregation type. Allowed values are "sum" and "mean".
- Forward:
node_features (torch.Tensor) – Input node features of shape \((N_{nodes}^{batch}, D_{in}^{node})\).
edge_features (torch.Tensor) – Input edge features of shape \((N_{edges}^{batch}, D_{in}^{edge})\).
graph (GraphType) – Graph connectivity/topology container (PyG). Connectivity/topology only. Do not duplicate node or edge features on the graph; pass them via node_features and edge_features. If present on the graph, they will be ignored by the model. node_features.shape[0] must equal the number of nodes in the graph, graph.num_nodes; edge_features.shape[0] must equal the number of edges in the graph, graph.num_edges. The current GraphType resolves to PyTorch Geometric objects (torch_geometric.data.Data or torch_geometric.data.HeteroData). See physicsnemo.nn.gnn_layers.graph_types for the exact alias and requirements.
position_mesh (torch.Tensor) – Per-graph reference mesh positions of shape \((N_{mesh}, D_{pos})\). These positions are repeated internally across the batch.
position_pivotal (torch.Tensor) – Per-graph pivotal positions of shape \((N_{pivotal}, D_{pos})\). These positions are repeated internally across the batch.
- Returns:
Decoded node features of shape \((N_{nodes}^{batch}, D_{out}^{decode})\).
- Return type:
torch.Tensor
Examples
>>> import torch
>>> from torch_geometric.data import Data
>>> from physicsnemo.models.mesh_reduced.mesh_reduced import Mesh_Reduced
>>>
>>> # Choose a consistent device
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>>
>>> # Instantiate model
>>> model = Mesh_Reduced(
...     input_dim_nodes=4,
...     input_dim_edges=3,
...     output_decode_dim=2,
... ).to(device)
>>>
>>> # Build a simple PyG graph
>>> # Note: num_nodes must match len(position_mesh) for batch alignment
>>> num_mesh = 20
>>> num_nodes, num_edges = num_mesh, 30
>>> edge_index = torch.randint(0, num_nodes, (2, num_edges))
>>> graph = Data(edge_index=edge_index, num_nodes=num_nodes).to(device)
>>> # For a single graph, set a batch vector of zeros
>>> graph.batch = torch.zeros(num_nodes, dtype=torch.long, device=device)
>>>
>>> # Node/edge features
>>> node_features = torch.randn(num_nodes, 4, device=device)
>>> edge_features = torch.randn(num_edges, 3, device=device)
>>>
>>> # Per-graph positions (repeated internally across the batch)
>>> position_mesh = torch.randn(num_mesh, 3, device=device)  # (N_mesh, D_pos)
>>> position_pivotal = torch.randn(5, 3, device=device)  # (N_pivotal, D_pos)
>>>
>>> # Encode to pivotal space, then decode back to mesh space
>>> enc = model.encode(node_features, edge_features, graph, position_mesh, position_pivotal)
>>> out = model.decode(enc, edge_features, graph, position_mesh, position_pivotal)
>>> out.size()
torch.Size([20, 2])
Notes
Reference: Predicting physics in mesh-reduced space with temporal attention <https://arxiv.org/pdf/2201.09113>.
- decode(
- x: Tensor,
- edge_features: Tensor,
- graph: None,
- position_mesh: Tensor,
- position_pivotal: Tensor,
- )
Decode pivotal features back to mesh space.
- Parameters:
x (torch.Tensor) – Input features in pivotal space of shape \((N_{pivotal}^{batch}, D_{enc})\).
edge_features (torch.Tensor) – Edge features of shape \((N_{edges}^{batch}, D_{in}^{edge})\).
graph (GraphType) – Graph connectivity/topology container (PyG). Connectivity/topology only. Do not duplicate node or edge features on the graph; pass them via node_features and edge_features. If present on the graph, they will be ignored by the model. node_features.shape[0] must equal the number of nodes in the graph, graph.num_nodes; edge_features.shape[0] must equal the number of edges in the graph, graph.num_edges. The current GraphType resolves to PyTorch Geometric objects (torch_geometric.data.Data or torch_geometric.data.HeteroData). See physicsnemo.nn.gnn_layers.graph_types for the exact alias and requirements.
position_mesh (torch.Tensor) – Per-graph mesh positions of shape \((N_{mesh}, D_{pos})\).
position_pivotal (torch.Tensor) – Per-graph pivotal positions of shape \((N_{pivotal}, D_{pos})\).
- Returns:
Decoded features in mesh space of shape \((N_{nodes}^{batch}, D_{out}^{decode})\).
- Return type:
torch.Tensor
- encode(
- x: Tensor,
- edge_features: Tensor,
- graph: None,
- position_mesh: Tensor,
- position_pivotal: Tensor,
- )
Encode mesh features to pivotal space.
- Parameters:
x (torch.Tensor) – Input node features of shape \((N_{nodes}^{batch}, D_{in}^{node})\).
edge_features (torch.Tensor) – Edge features of shape \((N_{edges}^{batch}, D_{in}^{edge})\).
graph (GraphType) – PyG graph container with batch information.
position_mesh (torch.Tensor) – Per-graph reference mesh positions of shape \((N_{mesh}, D_{pos})\).
position_pivotal (torch.Tensor) – Per-graph pivotal positions of shape \((N_{pivotal}, D_{pos})\).
- Returns:
Encoded pivotal features of shape \((N_{pivotal}^{batch}, D_{enc})\).
- Return type:
torch.Tensor
- knn_interpolate(
- x: Tensor,
- pos_x: Tensor,
- pos_y: Tensor,
- batch_x: Tensor | None = None,
- batch_y: Tensor | None = None,
- k: int = 3,
- num_workers: int = 1,
- )
Perform k-nearest neighbor interpolation from pos_x to pos_y.
- Parameters:
x (torch.Tensor) – Source features of shape \((N_x, D_x)\).
pos_x (torch.Tensor) – Source positions of shape \((N_x, D_{pos})\).
pos_y (torch.Tensor) – Target positions of shape \((N_y, D_{pos})\).
batch_x (torch.Tensor, optional) – Batch indices for pos_x of shape \((N_x,)\). If provided, neighbors are computed per-graph. Default is None.
batch_y (torch.Tensor, optional) – Batch indices for pos_y of shape \((N_y,)\). If provided, neighbors are computed per-graph. Default is None.
k (int, optional, default=3) – Number of nearest neighbors.
num_workers (int, optional, default=1) – Number of workers for the KNN search.
- Returns:
A tuple (y, col, row, weights) where:
- y: interpolated features of shape \((N_y, D_x)\)
- col: indices into pos_x (source) of shape \((k \cdot N_y,)\)
- row: indices into pos_y (target) of shape \((k \cdot N_y,)\)
- weights: interpolation weights of shape \((k \cdot N_y, 1)\)
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
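The interpolation above can be sketched in plain Python: each target feature is an inverse-squared-distance weighted average of its k nearest source features. `knn_interpolate_sketch` below is a hypothetical helper for illustration, not the PyG-backed implementation:

```python
def knn_interpolate_sketch(x, pos_x, pos_y, k=3, eps=1e-9):
    """x: scalar features per source point; pos_x/pos_y: position tuples."""
    out = []
    for py in pos_y:
        # Squared distances from this target to every source point
        d2 = [sum((a - b) ** 2 for a, b in zip(px, py)) for px in pos_x]
        # Indices of the k nearest sources
        nearest = sorted(range(len(pos_x)), key=lambda i: d2[i])[:k]
        # Inverse-squared-distance weights (eps guards coincident points)
        w = [1.0 / (d2[i] + eps) for i in nearest]
        total = sum(w)
        out.append(sum(wi * x[i] for wi, i in zip(w, nearest)) / total)
    return out

# Two source points with scalar features; a target midway between them
# receives (approximately) the average of their features.
x = [0.0, 10.0]
pos_x = [(0.0,), (2.0,)]
pos_y = [(1.0,)]
print(knn_interpolate_sketch(x, pos_x, pos_y, k=2))
```

The actual method additionally returns the neighbor index pairs and weights, which is useful when the same source/target geometry is reused across many calls.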
- class physicsnemo.models.meshgraphnet.bsms_mgn.BiStrideMeshGraphNet(*args, **kwargs)[source]#
Bases: MeshGraphNet
Bi-stride MeshGraphNet network architecture.
Bi-stride MGN augments vanilla MGN with a U-Net-like multi-scale message passing that alternates between coarsening and refining the mesh. This improves modeling of long-range interactions.
- Parameters:
input_dim_nodes (int) – Number of node features.
input_dim_edges (int) – Number of edge features.
output_dim (int) – Number of outputs.
processor_size (int, optional, default=15) – Number of message passing blocks.
mlp_activation_fn (str, optional, default="relu") – Activation function to use.
num_layers_node_processor (int, optional, default=2) – Number of MLP layers for processing nodes in each message passing block.
num_layers_edge_processor (int, optional, default=2) – Number of MLP layers for processing edge features in each message passing block.
num_mesh_levels (int, optional, default=2) – Number of mesh levels used by the bi-stride U-Net (multi-scale) processor.
bistride_pos_dim (int, optional, default=3) – Dimensionality of node positions stored in graph.pos (required by bi-stride).
num_layers_bistride (int, optional, default=2) – Number of layers within each bi-stride message passing block.
bistride_unet_levels (int, optional, default=1) – Number of times to apply the bi-stride U-Net (depth of repeat).
hidden_dim_processor (int, optional, default=128) – Hidden layer size for the message passing blocks.
hidden_dim_node_encoder (int, optional, default=128) – Hidden layer size for the node feature encoder.
num_layers_node_encoder (Union[int, None], optional, default=2) – Number of MLP layers for the node feature encoder. If None is provided, the MLP collapses to an identity function (no node encoder).
hidden_dim_edge_encoder (int, optional, default=128) – Hidden layer size for the edge feature encoder.
num_layers_edge_encoder (Union[int, None], optional, default=2) – Number of MLP layers for the edge feature encoder. If None is provided, the MLP collapses to an identity function (no edge encoder).
hidden_dim_node_decoder (int, optional, default=128) – Hidden layer size for the node feature decoder.
num_layers_node_decoder (Union[int, None], optional, default=2) – Number of MLP layers for the node feature decoder. If None is provided, the MLP collapses to an identity function (no decoder).
aggregation (Literal["sum", "mean"], optional, default="sum") – Message aggregation type. Allowed values are "sum" and "mean".
do_concat_trick (bool, optional, default=False) – Whether to replace concat+MLP with MLP+idx+sum.
num_processor_checkpoint_segments (int, optional, default=0) – Number of processor segments for gradient checkpointing (checkpointing disabled if 0). The number of segments should be a factor of \(2\times\text{processor\_size}\). For example, if processor_size is 15, then num_processor_checkpoint_segments can be 10 since it is a factor of \(15 \times 2 = 30\). Start with fewer segments if memory is tight, as each segment affects training speed.
recompute_activation (bool, optional, default=False) – Whether to recompute activations during backward to reduce memory usage.
- Forward:
node_features (torch.Tensor) – Input node features of shape \((N_{nodes}, D_{in}^{node})\).
edge_features (torch.Tensor) – Input edge features of shape \((N_{edges}, D_{in}^{edge})\).
graph (GraphType) – Graph connectivity/topology container (PyG). Connectivity/topology only. Do not duplicate node or edge features on the graph; pass them via node_features and edge_features. If present on the graph, they will be ignored by the model. node_features.shape[0] must equal the number of nodes in the graph, graph.num_nodes; edge_features.shape[0] must equal the number of edges in the graph, graph.num_edges. The current GraphType resolves to PyTorch Geometric objects (torch_geometric.data.Data or torch_geometric.data.HeteroData). See physicsnemo.nn.module.gnn_layers.graph_types for the exact alias and requirements. Requires graph.pos with shape \((N_{nodes}, \text{bistride\_pos\_dim})\) for bi-stride.
ms_edges (Iterable[torch.Tensor], optional) – Multi-scale edge lists; each is typically an integer tensor of shape \((2, E_l)\).
ms_ids (Iterable[torch.Tensor], optional) – Multi-scale node id tensors per level; typically shape \((N_l,)\).
- Outputs:
torch.Tensor – Output node features of shape \((N_{nodes}, D_{out})\).
Examples
>>> import torch
>>> from torch_geometric.data import Data
>>> from physicsnemo.models.meshgraphnet.bsms_mgn import BiStrideMeshGraphNet
>>>
>>> # Choose a device and create the model on it
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>>
>>> # Create a simple graph
>>> num_nodes = 8
>>> src = torch.arange(num_nodes, device=device)
>>> dst = (src + 1) % num_nodes
>>> edge_index = torch.stack([src, dst], dim=0)  # (2, E)
>>> graph = Data(edge_index=edge_index, num_nodes=num_nodes).to(device)
>>> graph.pos = torch.randn(num_nodes, 3, device=device)  # positions needed by bi-stride
>>>
>>> # Features
>>> node_features = torch.randn(num_nodes, 10, device=device)
>>> edge_features = torch.randn(edge_index.shape[1], 4, device=device)
>>>
>>> # Multi-scale inputs (one level for simplicity)
>>> ms_edges = [edge_index, edge_index]  # list of (2, E_l) tensors
>>> ms_ids = [torch.arange(num_nodes, device=device),
...           torch.arange(num_nodes, device=device)]  # list of (N_l,) tensors
>>>
>>> # Model
>>> model = BiStrideMeshGraphNet(
...     input_dim_nodes=10,
...     input_dim_edges=4,
...     output_dim=4,
...     processor_size=2,
...     hidden_dim_processor=32,
...     hidden_dim_node_encoder=16,
...     hidden_dim_edge_encoder=16,
...     num_layers_bistride=1,
...     num_mesh_levels=1,
... ).to(device)
>>>
>>> out = model(node_features, edge_features, graph, ms_edges, ms_ids)
>>> out.size()
torch.Size([8, 4])
Note
Reference: Efficient Learning of Mesh-Based Physical Simulation with Bi-Stride Multi-Scale Graph Neural Network <https://arxiv.org/pdf/2210.02573>.
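The num_processor_checkpoint_segments constraint noted above (the segment count should be a factor of \(2\times\text{processor\_size}\)) can be enumerated with a small helper (hypothetical name, plain Python):

```python
def valid_checkpoint_segments(processor_size):
    """List the segment counts that evenly divide 2 * processor_size."""
    total = 2 * processor_size
    return [n for n in range(1, total + 1) if total % n == 0]

# For the default processor_size=15, the valid counts are the factors of 30.
print(valid_checkpoint_segments(15))
# Factors of 30: [1, 2, 3, 5, 6, 10, 15, 30]
```

Note that 0 (checkpointing disabled) is always allowed in addition to these values; pick a small factor first and increase it only if memory remains tight.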