Source code for physicsnemo.models.mesh_reduced.mesh_reduced

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Literal, Tuple

import torch
from torch import Tensor

from physicsnemo.core.version_check import OptionalImport, require_version_spec
from physicsnemo.models.meshgraphnet.meshgraphnet import MeshGraphNet
from physicsnemo.nn.module.gnn_layers.graph_types import GraphType  # noqa

# Optional imports for GNN dependencies (lazy, cached, with helpful errors)
_torch_geometric = OptionalImport("torch_geometric")
_torch_cluster = OptionalImport("torch_cluster")
_torch_scatter = OptionalImport("torch_scatter")


class Mesh_Reduced(torch.nn.Module):
    r"""PbGMR-GMUS architecture.

    A mesh-reduced architecture that combines encoding and decoding processors
    for physics prediction in reduced mesh space.

    Parameters
    ----------
    input_dim_nodes : int
        Number of node features.
    input_dim_edges : int
        Number of edge features.
    output_decode_dim : int
        Number of decoding outputs (per node).
    output_encode_dim : int, optional, default=3
        Number of encoding outputs (per pivotal position).
    processor_size : int, optional, default=15
        Number of message passing blocks.
    num_layers_node_processor : int, optional, default=2
        Number of MLP layers for processing nodes in each message passing block.
    num_layers_edge_processor : int, optional, default=2
        Number of MLP layers for processing edge features in each message
        passing block.
    hidden_dim_processor : int, optional, default=128
        Hidden layer size for the message passing blocks.
    hidden_dim_node_encoder : int, optional, default=128
        Hidden layer size for the node feature encoder.
    num_layers_node_encoder : int, optional, default=2
        Number of MLP layers for the node feature encoder.
    hidden_dim_edge_encoder : int, optional, default=128
        Hidden layer size for the edge feature encoder.
    num_layers_edge_encoder : int, optional, default=2
        Number of MLP layers for the edge feature encoder.
    hidden_dim_node_decoder : int, optional, default=128
        Hidden layer size for the node feature decoder.
    num_layers_node_decoder : int, optional, default=2
        Number of MLP layers for the node feature decoder.
    k : int, optional, default=3
        Number of nearest neighbors for interpolation.
    aggregation : Literal["sum", "mean"], optional, default="mean"
        Message aggregation type. Allowed values are ``"sum"`` and ``"mean"``.

    Forward
    -------
    The arguments below are those accepted by :meth:`encode`; :meth:`decode`
    takes encoded pivotal features in place of ``node_features``.

    node_features : torch.Tensor
        Input node features of shape :math:`(N_{nodes}^{batch}, D_{in}^{node})`.
    edge_features : torch.Tensor
        Input edge features of shape :math:`(N_{edges}^{batch}, D_{in}^{edge})`.
    graph : :class:`~physicsnemo.nn.module.gnn_layers.graph_types.GraphType`
        Graph connectivity/topology container (PyG).
        Connectivity/topology only. Do not duplicate node or edge features on
        the graph; pass them via ``node_features`` and ``edge_features``. If
        present on the graph, they will be ignored by the model.
        ``node_features.shape[0]`` must equal the number of nodes in the graph
        ``graph.num_nodes``.
        ``edge_features.shape[0]`` must equal the number of edges in the graph
        ``graph.num_edges``.
        The current
        :class:`~physicsnemo.nn.module.gnn_layers.graph_types.GraphType`
        resolves to PyTorch Geometric objects (``torch_geometric.data.Data`` or
        ``torch_geometric.data.HeteroData``), but :meth:`encode` and
        :meth:`decode` only accept ``torch_geometric.data.Data`` and raise
        ``ValueError`` for anything else. See
        :mod:`physicsnemo.nn.module.gnn_layers.graph_types` for the exact alias
        and requirements.
    position_mesh : torch.Tensor
        Per-graph reference mesh positions of shape :math:`(N_{mesh}, D_{pos})`.
        These positions are repeated internally across the batch.
    position_pivotal : torch.Tensor
        Per-graph pivotal positions of shape :math:`(N_{pivotal}, D_{pos})`.
        These positions are repeated internally across the batch.

    Returns
    -------
    torch.Tensor
        Decoded node features of shape :math:`(N_{nodes}^{batch}, D_{out}^{decode})`.

    Examples
    --------
    >>> import torch
    >>> from torch_geometric.data import Data
    >>> from physicsnemo.models.mesh_reduced.mesh_reduced import Mesh_Reduced
    >>>
    >>> # Choose a consistent device
    >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    >>>
    >>> # Instantiate model
    >>> model = Mesh_Reduced(
    ...     input_dim_nodes=4,
    ...     input_dim_edges=3,
    ...     output_decode_dim=2,
    ... ).to(device)
    >>>
    >>> # Build a simple PyG graph
    >>> # Note: num_nodes must match len(position_mesh) for batch alignment
    >>> num_mesh = 20
    >>> num_nodes, num_edges = num_mesh, 30
    >>> edge_index = torch.randint(0, num_nodes, (2, num_edges))
    >>> graph = Data(edge_index=edge_index, num_nodes=num_nodes).to(device)
    >>> # For a single graph, set a batch vector of zeros (a missing batch
    >>> # vector is treated the same way)
    >>> graph.batch = torch.zeros(num_nodes, dtype=torch.long, device=device)
    >>>
    >>> # Node/edge features
    >>> node_features = torch.randn(num_nodes, 4, device=device)
    >>> edge_features = torch.randn(num_edges, 3, device=device)
    >>>
    >>> # Per-graph positions (repeated internally across the batch)
    >>> position_mesh = torch.randn(num_mesh, 3, device=device)  # (N_mesh, D_pos)
    >>> position_pivotal = torch.randn(5, 3, device=device)  # (N_pivotal, D_pos)
    >>>
    >>> # Encode to pivotal space, then decode back to mesh space
    >>> enc = model.encode(node_features, edge_features, graph, position_mesh, position_pivotal)
    >>> out = model.decode(enc, edge_features, graph, position_mesh, position_pivotal)
    >>> out.size()
    torch.Size([20, 2])

    Notes
    -----
    Reference: `Predicting physics in mesh-reduced space with temporal attention
    <https://arxiv.org/pdf/2201.09113>`_.
    """

    def __init__(
        self,
        input_dim_nodes: int,
        input_dim_edges: int,
        output_decode_dim: int,
        output_encode_dim: int = 3,
        processor_size: int = 15,
        num_layers_node_processor: int = 2,
        num_layers_edge_processor: int = 2,
        hidden_dim_processor: int = 128,
        hidden_dim_node_encoder: int = 128,
        num_layers_node_encoder: int = 2,
        hidden_dim_edge_encoder: int = 128,
        num_layers_edge_encoder: int = 2,
        hidden_dim_node_decoder: int = 128,
        num_layers_node_decoder: int = 2,
        k: int = 3,
        aggregation: Literal["sum", "mean"] = "mean",
    ):
        super().__init__()
        # Kept as public attributes for backward compatibility; nothing in
        # this class reads them.
        self.knn_encoder_already = False
        self.knn_decoder_already = False

        # Mesh-space network producing ``output_encode_dim`` features per node.
        self.encoder_processor = MeshGraphNet(
            input_dim_nodes,
            input_dim_edges,
            output_encode_dim,
            processor_size,
            "relu",
            num_layers_node_processor,
            num_layers_edge_processor,
            hidden_dim_processor,
            hidden_dim_node_encoder,
            num_layers_node_encoder,
            hidden_dim_edge_encoder,
            num_layers_edge_encoder,
            hidden_dim_node_decoder,
            num_layers_node_decoder,
            aggregation,
        )
        # Mesh-space network mapping interpolated encodings to the outputs.
        self.decoder_processor = MeshGraphNet(
            output_encode_dim,
            input_dim_edges,
            output_decode_dim,
            processor_size,
            "relu",
            num_layers_node_processor,
            num_layers_edge_processor,
            hidden_dim_processor,
            hidden_dim_node_encoder,
            num_layers_node_encoder,
            hidden_dim_edge_encoder,
            num_layers_edge_encoder,
            hidden_dim_node_decoder,
            num_layers_node_decoder,
            aggregation,
        )
        self.k = k
        self.PivotalNorm = torch.nn.LayerNorm(output_encode_dim)

        # Public constructor attributes for validation/serialization
        self.input_dim_nodes = input_dim_nodes
        self.input_dim_edges = input_dim_edges
        self.output_encode_dim = output_encode_dim
        self.output_decode_dim = output_decode_dim

    @require_version_spec("torch_cluster")
    @require_version_spec("torch_scatter")
    def knn_interpolate(
        self,
        x: Tensor,
        pos_x: Tensor,
        pos_y: Tensor,
        batch_x: Tensor | None = None,
        batch_y: Tensor | None = None,
        k: int = 3,
        num_workers: int = 1,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        r"""Perform k-nearest neighbor interpolation from ``pos_x`` to ``pos_y``.

        Each target point receives the average of its neighbors' features,
        weighted by inverse squared distance (squared distances are clamped
        from below at ``1e-16`` to avoid division by zero).

        Parameters
        ----------
        x : torch.Tensor
            Source features of shape :math:`(N_x, D_x)`.
        pos_x : torch.Tensor
            Source positions of shape :math:`(N_x, D_{pos})`.
        pos_y : torch.Tensor
            Target positions of shape :math:`(N_y, D_{pos})`.
        batch_x : torch.Tensor, optional
            Batch indices for ``pos_x`` of shape :math:`(N_x,)`. If provided,
            neighbors are computed per-graph. Default is ``None``.
        batch_y : torch.Tensor, optional
            Batch indices for ``pos_y`` of shape :math:`(N_y,)`. If provided,
            neighbors are computed per-graph. Default is ``None``.
        k : int, optional, default=3
            Number of nearest neighbors.
        num_workers : int, optional, default=1
            Number of workers for the KNN search.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            A tuple ``(y, col, row, weights)`` where:

            - ``y``: interpolated features of shape :math:`(N_y, D_x)`,
              cast to ``torch.float32``
            - ``col``: indices into ``pos_x`` (source) of shape :math:`(k \cdot N_y,)`
            - ``row``: indices into ``pos_y`` (target) of shape :math:`(k \cdot N_y,)`
            - ``weights``: interpolation weights of shape :math:`(k \cdot N_y, 1)`
        """
        # The neighbour search and the weights are computed without gradient
        # tracking; the weighted average below stays differentiable in ``x``.
        with torch.no_grad():
            row, col = _torch_cluster.knn(
                pos_x,
                pos_y,
                k,
                batch_x=batch_x,
                batch_y=batch_y,
                num_workers=num_workers,
            )
            # row: indices in pos_y, col: indices in pos_x
            diff = pos_x[col] - pos_y[row]
            squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
            weights = 1.0 / torch.clamp(squared_distance, min=1e-16)

        num_targets = pos_y.size(0)
        weighted_sum = _torch_scatter.scatter(
            x[col] * weights, row, dim=0, dim_size=num_targets, reduce="sum"
        )
        weight_total = _torch_scatter.scatter(
            weights, row, dim=0, dim_size=num_targets, reduce="sum"
        )
        y = weighted_sum / weight_total

        return y.float(), col, row, weights

    def _batch_layout(
        self,
        graph: GraphType,
        position_mesh: Tensor,
        position_pivotal: Tensor,
        device: torch.device,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        r"""Tile the per-graph positions across the batch and build batch vectors.

        Parameters
        ----------
        graph : :class:`~physicsnemo.nn.module.gnn_layers.graph_types.GraphType`
            PyG graph; must be a ``torch_geometric.data.Data`` instance. Its
            ``batch`` vector determines the batch size. A graph without a
            ``batch`` vector is treated as a single graph.
        position_mesh : torch.Tensor
            Per-graph mesh positions of shape :math:`(N_{mesh}, D_{pos})`.
        position_pivotal : torch.Tensor
            Per-graph pivotal positions of shape :math:`(N_{pivotal}, D_{pos})`.
        device : torch.device
            Device on which newly created index tensors are allocated.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            ``(batch_mesh, batch_pivotal, position_mesh_batch,
            position_pivotal_batch)``.

        Raises
        ------
        ValueError
            If ``graph`` is not a ``torch_geometric.data.Data`` instance.
        """
        if not isinstance(graph, _torch_geometric.data.Data):
            raise ValueError(f"Unsupported graph type: {type(graph)}")

        batch_mesh = getattr(graph, "batch", None)
        if batch_mesh is None:
            # Un-batched single graph: every mesh node belongs to graph 0.
            batch_size = 1
            batch_mesh = torch.zeros(
                position_mesh.shape[0], dtype=torch.long, device=device
            )
        elif batch_mesh.numel() > 0:
            batch_size = int(batch_mesh.max().item()) + 1
        else:
            batch_size = 1

        position_mesh_batch = position_mesh.repeat(batch_size, 1)
        position_pivotal_batch = position_pivotal.repeat(batch_size, 1)
        # [0, ..., 0, 1, ..., 1, ...] with N_pivotal entries per graph.
        batch_pivotal = torch.arange(batch_size, device=device).repeat_interleave(
            len(position_pivotal)
        )
        return batch_mesh, batch_pivotal, position_mesh_batch, position_pivotal_batch

    @require_version_spec("torch_geometric")
    def encode(
        self,
        x: Tensor,
        edge_features: Tensor,
        graph: GraphType,
        position_mesh: Tensor,
        position_pivotal: Tensor,
    ) -> Tensor:
        r"""Encode mesh features to pivotal space.

        Runs the encoder network on the mesh, applies layer normalization, and
        interpolates the result from the mesh positions onto the pivotal
        positions with :meth:`knn_interpolate`.

        Parameters
        ----------
        x : torch.Tensor
            Input node features of shape :math:`(N_{nodes}^{batch}, D_{in}^{node})`.
        edge_features : torch.Tensor
            Edge features of shape :math:`(N_{edges}^{batch}, D_{in}^{edge})`.
        graph : :class:`~physicsnemo.nn.module.gnn_layers.graph_types.GraphType`
            PyG graph container with batch information. Must be a
            ``torch_geometric.data.Data`` instance; a missing ``batch`` vector
            is treated as a single graph.
        position_mesh : torch.Tensor
            Per-graph reference mesh positions of shape :math:`(N_{mesh}, D_{pos})`.
        position_pivotal : torch.Tensor
            Per-graph pivotal positions of shape :math:`(N_{pivotal}, D_{pos})`.

        Returns
        -------
        torch.Tensor
            Encoded pivotal features of shape :math:`(N_{pivotal}^{batch}, D_{enc})`.

        Raises
        ------
        ValueError
            If an input tensor is not 2D, or if ``graph`` is not a
            ``torch_geometric.data.Data`` instance.
        """
        # Shape validation is skipped while torch.compile is tracing.
        if not torch.compiler.is_compiling():
            if x.ndim != 2:
                raise ValueError(
                    f"Expected 2D node features (N_nodes, D_in) but got shape {tuple(x.shape)}"
                )
            if edge_features.ndim != 2:
                raise ValueError(
                    f"Expected 2D edge features (N_edges, D_in) but got shape {tuple(edge_features.shape)}"
                )
            if position_mesh.ndim != 2 or position_pivotal.ndim != 2:
                raise ValueError(
                    f"Expected position tensors to be 2D, got {tuple(position_mesh.shape)} and {tuple(position_pivotal.shape)}"
                )

        x = self.encoder_processor(x, edge_features, graph)
        x = self.PivotalNorm(x)

        (
            batch_mesh,
            batch_pivotal,
            position_mesh_batch,
            position_pivotal_batch,
        ) = self._batch_layout(graph, position_mesh, position_pivotal, x.device)

        # Mesh -> pivotal interpolation.
        x, _, _, _ = self.knn_interpolate(
            x=x,
            pos_x=position_mesh_batch,
            pos_y=position_pivotal_batch,
            batch_x=batch_mesh,
            batch_y=batch_pivotal,
            k=self.k,
        )
        return x

    @require_version_spec("torch_geometric")
    def decode(
        self,
        x: Tensor,
        edge_features: Tensor,
        graph: GraphType,
        position_mesh: Tensor,
        position_pivotal: Tensor,
    ) -> Tensor:
        r"""Decode pivotal features back to mesh space.

        Interpolates the pivotal features onto the mesh positions with
        :meth:`knn_interpolate`, then runs the decoder network on the mesh.

        Parameters
        ----------
        x : torch.Tensor
            Input features in pivotal space of shape
            :math:`(N_{pivotal}^{batch}, D_{enc})`.
        edge_features : torch.Tensor
            Edge features of shape :math:`(N_{edges}^{batch}, D_{in}^{edge})`.
        graph : :class:`~physicsnemo.nn.module.gnn_layers.graph_types.GraphType`
            Graph connectivity/topology container (PyG).
            Connectivity/topology only. Do not duplicate node or edge features
            on the graph; edge features are passed via ``edge_features``. If
            present on the graph, they will be ignored by the model.
            ``edge_features.shape[0]`` must equal the number of edges in the
            graph ``graph.num_edges``.
            Must be a ``torch_geometric.data.Data`` instance; a missing
            ``batch`` vector is treated as a single graph. See
            :mod:`physicsnemo.nn.module.gnn_layers.graph_types` for the exact
            alias and requirements.
        position_mesh : torch.Tensor
            Per-graph mesh positions of shape :math:`(N_{mesh}, D_{pos})`.
        position_pivotal : torch.Tensor
            Per-graph pivotal positions of shape :math:`(N_{pivotal}, D_{pos})`.

        Returns
        -------
        torch.Tensor
            Decoded features in mesh space of shape
            :math:`(N_{nodes}^{batch}, D_{out}^{decode})`.

        Raises
        ------
        ValueError
            If ``edge_features`` does not have shape
            ``(N_edges, input_dim_edges)``, if a position tensor is not 2D, or
            if ``graph`` is not a ``torch_geometric.data.Data`` instance.
        """
        # Shape validation is skipped while torch.compile is tracing.
        if not torch.compiler.is_compiling():
            if (
                edge_features.ndim != 2
                or edge_features.shape[1] != self.input_dim_edges
            ):
                raise ValueError(
                    f"Expected tensor of shape (N_edges, {self.input_dim_edges}) but got tensor of shape {tuple(edge_features.shape)}"
                )
            if position_mesh.ndim != 2 or position_pivotal.ndim != 2:
                raise ValueError(
                    f"Expected position tensors to be 2D, got shapes {tuple(position_mesh.shape)} and {tuple(position_pivotal.shape)}"
                )

        (
            batch_mesh,
            batch_pivotal,
            position_mesh_batch,
            position_pivotal_batch,
        ) = self._batch_layout(graph, position_mesh, position_pivotal, x.device)

        # Pivotal -> mesh interpolation.
        x, _, _, _ = self.knn_interpolate(
            x=x,
            pos_x=position_pivotal_batch,
            pos_y=position_mesh_batch,
            batch_x=batch_pivotal,
            batch_y=batch_mesh,
            k=self.k,
        )

        x = self.decoder_processor(x, edge_features, graph)
        return x