# date 2024-01-25
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor
import torch.nn as nn
# DGL is imported inside the guard so that a missing or broken installation
# surfaces the installation hint below instead of a plain ImportError raised
# by an unguarded ``import dgl``.
try:
    import dgl
    from dgl import DGLGraph
except Exception as err:
    # ``Exception`` (rather than a bare ``except``) still converts every
    # ordinary import-time failure into ImportError, but lets
    # KeyboardInterrupt / SystemExit propagate untouched. The original
    # failure is chained as the cause so it stays visible in the traceback.
    raise ImportError(
        "Mesh Graph Net requires the DGL library. Install the "
        + "desired CUDA version at: \n https://www.dgl.ai/pages/start.html"
    ) from err
from typing import Callable, Tuple, List, Union
from dataclasses import dataclass
import modulus
from modulus.models.meta import ModelMetaData
from modulus.models.module import Module
from modulus.models.gnn_layers.utils import set_checkpoint_fn, CuGraphCSC
from modulus.models.gnn_layers.mesh_graph_mlp import MeshGraphMLP
from modulus.models.gnn_layers.mesh_edge_block import MeshEdgeBlock
from modulus.models.gnn_layers.mesh_node_block import MeshNodeBlock
@dataclass
class MetaData(ModelMetaData):
    """Metadata record passed to ``Module.__init__`` by :class:`MeshGraphNet`."""

    name: str = "MeshGraphNet"
    # NOTE(review): the flag values below mirror the upstream Modulus
    # MeshGraphNet metadata; confirm them against the installed
    # ``ModelMetaData`` before relying on any of them.
    # Optimization
    jit: bool = False
    cuda_graphs: bool = False
    amp_cpu: bool = False
    amp_gpu: bool = True
    torch_fx: bool = False
    # Inference
    onnx: bool = False
    # Physics informed
    func_torch: bool = True
    auto_grad: bool = True


class MeshGraphNet(Module):
    """MeshGraphNet network architecture

    Parameters
    ----------
    input_dim_nodes : int
        Number of node features
    input_dim_edges : int
        Number of edge features
    output_dim : int
        Number of outputs
    processor_size : int, optional
        Number of message passing blocks, by default 15
    num_layers_node_processor : int, optional
        Number of MLP layers for processing nodes in each message passing block, by default 2
    num_layers_edge_processor : int, optional
        Number of MLP layers for processing edge features in each message passing block, by default 2
    hidden_dim_processor : int, optional
        Hidden layer size for the message passing blocks, by default 128
    hidden_dim_node_encoder : int, optional
        Hidden layer size for the node feature encoder, by default 128
    num_layers_node_encoder : int, optional
        Number of MLP layers for the node feature encoder, by default 2
    hidden_dim_edge_encoder : int, optional
        Hidden layer size for the edge feature encoder, by default 128
    num_layers_edge_encoder : int, optional
        Number of MLP layers for the edge feature encoder, by default 2
    hidden_dim_node_decoder : int, optional
        Hidden layer size for the node feature decoder, by default 128
    num_layers_node_decoder : int, optional
        Number of MLP layers for the node feature decoder, by default 2
    aggregation : str, optional
        Message aggregation type, by default "sum"
    do_concat_trick : bool, default=False
        Whether to replace concat+MLP with MLP+idx+sum
    num_processor_checkpoint_segments : int, optional
        Number of processor segments for gradient checkpointing, by default 0 (checkpointing disabled)

    Example
    -------
    >>> model = modulus.models.meshgraphnet.MeshGraphNet(
    ...     input_dim_nodes=4,
    ...     input_dim_edges=3,
    ...     output_dim=2,
    ... )
    >>> graph = dgl.rand_graph(10, 5)
    >>> node_features = torch.randn(10, 4)
    >>> edge_features = torch.randn(5, 3)
    >>> output = model(node_features, edge_features, graph)
    >>> output.size()
    torch.Size([10, 2])

    Note
    ----
    Reference: Pfaff, Tobias, et al. "Learning mesh-based simulation with graph networks."
    arXiv preprint arXiv:2010.03409 (2020).
    """

    def __init__(
        self,
        input_dim_nodes: int,
        input_dim_edges: int,
        output_dim: int,
        processor_size: int = 15,
        num_layers_node_processor: int = 2,
        num_layers_edge_processor: int = 2,
        hidden_dim_processor: int = 128,
        hidden_dim_node_encoder: int = 128,
        num_layers_node_encoder: int = 2,
        hidden_dim_edge_encoder: int = 128,
        num_layers_edge_encoder: int = 2,
        hidden_dim_node_decoder: int = 128,
        num_layers_node_decoder: int = 2,
        aggregation: str = "sum",
        do_concat_trick: bool = False,
        num_processor_checkpoint_segments: int = 0,
    ):
        super().__init__(meta=MetaData())

        # Sub-modules are created in a fixed order (edge encoder, node encoder,
        # node decoder, processor); keep it, since it fixes the order in which
        # they are registered on this module.

        # Both encoders emit ``hidden_dim_processor`` features, which is the
        # width the processor is configured with below.
        self.edge_encoder = MeshGraphMLP(
            input_dim_edges,
            output_dim=hidden_dim_processor,
            hidden_dim=hidden_dim_edge_encoder,
            hidden_layers=num_layers_edge_encoder,
            activation_fn=nn.ReLU(),
            norm_type="LayerNorm",
            recompute_activation=False,
        )
        self.node_encoder = MeshGraphMLP(
            input_dim_nodes,
            output_dim=hidden_dim_processor,
            hidden_dim=hidden_dim_node_encoder,
            hidden_layers=num_layers_node_encoder,
            activation_fn=nn.ReLU(),
            norm_type="LayerNorm",
            recompute_activation=False,
        )
        # The decoder is the only MLP built with ``norm_type=None``.
        self.node_decoder = MeshGraphMLP(
            hidden_dim_processor,
            output_dim=output_dim,
            hidden_dim=hidden_dim_node_decoder,
            hidden_layers=num_layers_node_decoder,
            activation_fn=nn.ReLU(),
            norm_type=None,
            recompute_activation=False,
        )
        self.processor = MeshGraphNetProcessor(
            processor_size=processor_size,
            input_dim_node=hidden_dim_processor,
            input_dim_edge=hidden_dim_processor,
            num_layers_node=num_layers_node_processor,
            num_layers_edge=num_layers_edge_processor,
            aggregation=aggregation,
            norm_type="LayerNorm",
            activation_fn=nn.ReLU(),
            do_concat_trick=do_concat_trick,
            num_processor_checkpoint_segments=num_processor_checkpoint_segments,
        )

    def forward(
        self,
        node_features: Tensor,
        edge_features: Tensor,
        graph: Union[DGLGraph, List[DGLGraph], CuGraphCSC],
    ) -> Tensor:
        """Encode the inputs, run the processor and decode the node features.

        Parameters
        ----------
        node_features : Tensor
            Input node features, fed to the node encoder
        edge_features : Tensor
            Input edge features, fed to the edge encoder
        graph : Union[DGLGraph, List[DGLGraph], CuGraphCSC]
            Graph structure, handed unchanged to the processor

        Returns
        -------
        Tensor
            Output of the node decoder applied to the processed node features
        """
        edge_features = self.edge_encoder(edge_features)
        node_features = self.node_encoder(node_features)
        x = self.processor(node_features, edge_features, graph)
        x = self.node_decoder(x)
        return x


class MeshGraphNetProcessor(nn.Module):
    """MeshGraphNet processor block

    Holds ``processor_size`` pairs of (edge block, node block) in one
    ``nn.ModuleList`` and applies them in order, optionally split into
    segments for gradient checkpointing.
    """

    def __init__(
        self,
        processor_size: int = 15,
        input_dim_node: int = 128,
        input_dim_edge: int = 128,
        num_layers_node: int = 2,
        num_layers_edge: int = 2,
        aggregation: str = "sum",
        norm_type: str = "LayerNorm",
        activation_fn: nn.Module = nn.ReLU(),
        do_concat_trick: bool = False,
        num_processor_checkpoint_segments: int = 0,
    ):
        """Build the interleaved edge/node blocks of the processor.

        Parameters
        ----------
        processor_size : int, optional
            Number of (edge block, node block) pairs, by default 15
        input_dim_node : int, optional
            Node feature size handed to the blocks, by default 128
        input_dim_edge : int, optional
            Edge feature size handed to the blocks, by default 128
        num_layers_node : int, optional
            Layer count handed to each node block, by default 2
        num_layers_edge : int, optional
            Layer count handed to each edge block, by default 2
        aggregation : str, optional
            Aggregation type handed to each node block, by default "sum"
        norm_type : str, optional
            Normalization type handed to every block, by default "LayerNorm"
        activation_fn : nn.Module, optional
            Activation module handed to every block, by default nn.ReLU()
        do_concat_trick : bool, optional
            Handed to each edge block, by default False
        num_processor_checkpoint_segments : int, optional
            Number of gradient checkpointing segments, by default 0 (disabled)

        Raises
        ------
        ValueError
            If the number of processor layers is not a multiple of
            ``num_processor_checkpoint_segments`` (when that value is positive)
        """
        super().__init__()
        self.processor_size = processor_size
        self.num_processor_checkpoint_segments = num_processor_checkpoint_segments

        # Positional arguments for MeshEdgeBlock.
        # NOTE(review): presumably (input_dim_nodes, input_dim_edges,
        # output_dim, hidden_dim, hidden_layers, activation_fn, norm_type,
        # do_concat_trick, recompute_activation) - confirm against MeshEdgeBlock.
        edge_block_invars = (
            input_dim_node,
            input_dim_edge,
            input_dim_edge,
            input_dim_edge,
            num_layers_edge,
            activation_fn,
            norm_type,
            do_concat_trick,
            False,
        )
        # Positional arguments for MeshNodeBlock.
        # NOTE(review): presumably (aggregation, input_dim_nodes,
        # input_dim_edges, output_dim, hidden_dim, hidden_layers,
        # activation_fn, norm_type, recompute_activation) - confirm against
        # MeshNodeBlock.
        node_block_invars = (
            aggregation,
            input_dim_node,
            input_dim_edge,
            input_dim_edge,
            input_dim_edge,
            num_layers_node,
            activation_fn,
            norm_type,
            False,
        )

        # All edge blocks are constructed before any node block, so the order
        # in which sub-modules are created is unchanged; only afterwards are
        # they interleaved as edge, node, edge, node, ...
        edge_blocks = [
            MeshEdgeBlock(*edge_block_invars) for _ in range(self.processor_size)
        ]
        node_blocks = [
            MeshNodeBlock(*node_block_invars) for _ in range(self.processor_size)
        ]
        layers = []
        for edge_block, node_block in zip(edge_blocks, node_blocks):
            layers.append(edge_block)
            layers.append(node_block)

        self.processor_layers = nn.ModuleList(layers)
        # Two layers per message passing block (one edge block, one node block).
        self.num_processor_layers = len(self.processor_layers)
        self.set_checkpoint_segments(self.num_processor_checkpoint_segments)

    def set_checkpoint_segments(self, checkpoint_segments: int):
        """
        Set the number of checkpoint segments

        A value of zero or less disables checkpointing; all processor layers
        are then treated as one segment.

        Parameters
        ----------
        checkpoint_segments : int
            number of checkpoint segments

        Raises
        ------
        ValueError
            if the number of processor layers is not a multiple of the number of
            checkpoint segments
        """
        if checkpoint_segments > 0:
            if self.num_processor_layers % checkpoint_segments != 0:
                raise ValueError(
                    "Processor layers must be a multiple of checkpoint_segments"
                )
            segment_size = self.num_processor_layers // checkpoint_segments
            # Half-open (start, end) layer index ranges, one per segment.
            self.checkpoint_segments = []
            for i in range(0, self.num_processor_layers, segment_size):
                self.checkpoint_segments.append((i, i + segment_size))
            self.checkpoint_fn = set_checkpoint_fn(True)
        else:
            self.checkpoint_fn = set_checkpoint_fn(False)
            self.checkpoint_segments = [(0, self.num_processor_layers)]

    def run_function(
        self, segment_start: int, segment_end: int
    ) -> Callable[
        [Tensor, Tensor, Union[DGLGraph, List[DGLGraph], CuGraphCSC]],
        Tuple[Tensor, Tensor],
    ]:
        """Custom forward for gradient checkpointing

        Parameters
        ----------
        segment_start : int
            Layer index as start of the segment
        segment_end : int
            Layer index as end of the segment

        Returns
        -------
        Callable
            Custom forward function
        """
        segment = self.processor_layers[segment_start:segment_end]

        def custom_forward(
            node_features: Tensor,
            edge_features: Tensor,
            graph: Union[DGLGraph, List[DGLGraph], CuGraphCSC],
        ) -> Tuple[Tensor, Tensor]:
            """Custom forward function"""
            # The closure takes (node, edge, graph), but every layer is called
            # with - and returns - the features in (edge, node) order.
            for module in segment:
                edge_features, node_features = module(
                    edge_features, node_features, graph
                )
            return edge_features, node_features

        return custom_forward

    @torch.jit.unused
    def forward(
        self,
        node_features: Tensor,
        edge_features: Tensor,
        graph: Union[DGLGraph, List[DGLGraph], CuGraphCSC],
    ) -> Tensor:
        """Run every checkpoint segment in order and return the node features.

        Parameters
        ----------
        node_features : Tensor
            Node features entering the processor
        edge_features : Tensor
            Edge features entering the processor
        graph : Union[DGLGraph, List[DGLGraph], CuGraphCSC]
            Graph structure, handed unchanged to every layer

        Returns
        -------
        Tensor
            Node features after the last segment; the final edge features are
            discarded
        """
        for segment_start, segment_end in self.checkpoint_segments:
            # ``checkpoint_fn`` is whatever ``set_checkpoint_fn`` returned; it
            # is called with the segment closure, its inputs and the two
            # checkpointing keyword arguments below.
            edge_features, node_features = self.checkpoint_fn(
                self.run_function(segment_start, segment_end),
                node_features,
                edge_features,
                graph,
                use_reentrant=False,
                preserve_rng_state=False,
            )
        return node_features