NVIDIA Modulus Core (Latest Release)

Source code for modulus.utils.graphcast.graph

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from torch import Tensor

from .graph_utils import (
    add_edge_features,
    add_node_features,
    create_graph,
    create_heterograph,
    get_face_centroids,
    latlon2xyz,
    max_edge_length,
    xyz2latlon,
)
from .icosahedral_mesh import (
    faces_to_edges,
    get_hierarchy_of_triangular_meshes_for_sphere,
    merge_meshes,
)

logger = logging.getLogger(__name__)


class Graph:
    """Graph class for creating the graph2mesh, latent mesh, and mesh2graph graphs.

    Parameters
    ----------
    lat_lon_grid : Tensor
        Tensor with shape (lat, lon, 2) that includes the latitudes and
        longitudes meshgrid.
    mesh_level : int, optional
        Level of the latent mesh, by default 6.
    multimesh : bool, optional
        If the latent mesh is a multimesh, by default True.
        If True, the latent mesh includes the nodes corresponding to the
        specified `mesh_level` and incorporates the edges from all mesh levels
        ranging from level 0 up to and including `mesh_level`.
    khop_neighbors : int, optional
        This option is used to retrieve a list of indices for the k-hop
        neighbors of all mesh nodes. It is applicable when a graph transformer
        is used as the processor. If set to 0, this list is not computed. If a
        message passing processor is used, it is forced to 0. By default 0.
    dtype : torch.dtype, optional
        Data type of the graph, by default torch.float.
    """

    def __init__(
        self,
        lat_lon_grid: Tensor,
        mesh_level: int = 6,
        multimesh: bool = True,
        khop_neighbors: int = 0,
        dtype=torch.float,
    ) -> None:
        self.khop_neighbors = khop_neighbors
        self.dtype = dtype

        # flatten lat/lon grid
        self.lat_lon_grid_flat = lat_lon_grid.permute(2, 0, 1).view(2, -1).permute(1, 0)

        # create the multi-mesh
        _meshes = get_hierarchy_of_triangular_meshes_for_sphere(splits=mesh_level)
        finest_mesh = _meshes[-1]  # the last mesh in the hierarchy is the finest
        self.finest_mesh_src, self.finest_mesh_dst = faces_to_edges(finest_mesh.faces)
        self.finest_mesh_vertices = np.array(finest_mesh.vertices)
        if multimesh:
            mesh = merge_meshes(_meshes)
            self.mesh_src, self.mesh_dst = faces_to_edges(mesh.faces)
            self.mesh_vertices = np.array(mesh.vertices)
        else:
            mesh = finest_mesh
            self.mesh_src, self.mesh_dst = self.finest_mesh_src, self.finest_mesh_dst
            self.mesh_vertices = self.finest_mesh_vertices
        self.mesh_faces = mesh.faces

    @staticmethod
    def khop_adj_all_k(g, kmax):
        """Return a boolean (N, N) matrix whose (i, j) entry is True iff node j
        is within `kmax` hops of node i."""
        if not g.is_homogeneous:
            raise NotImplementedError("only homogeneous graph is supported")
        min_degree = g.in_degrees().min()
        with torch.no_grad():
            adj = g.adj_external(transpose=True, scipy_fmt=None)
            adj_k = adj
            adj_all = adj.clone()
            for _ in range(2, kmax + 1):
                # scale with min-degree to avoid too large values, but >= 1.0
                adj_k = (adj @ adj_k) / min_degree
                adj_all += adj_k
        return adj_all.to_dense().bool()

    def create_mesh_graph(self, verbose: bool = True):
        """Create the multimesh graph.

        Parameters
        ----------
        verbose : bool, optional
            verbosity, by default True

        Returns
        -------
        DGLGraph
            Multimesh graph.
        Tensor | None
            Boolean mask that is True for node pairs that are not within
            `khop_neighbors` hops of each other; None if `khop_neighbors` is 0.
        """
        mesh_graph = create_graph(
            self.mesh_src,
            self.mesh_dst,
            to_bidirected=True,
            add_self_loop=False,
            dtype=torch.int32,
        )
        mesh_pos = torch.tensor(
            self.mesh_vertices,
            dtype=torch.float32,
        )
        mesh_graph = add_edge_features(mesh_graph, mesh_pos)
        mesh_graph = add_node_features(mesh_graph, mesh_pos)
        mesh_graph.ndata["lat_lon"] = xyz2latlon(mesh_pos)
        # ensure fields are set to dtype to avoid later conversions
        mesh_graph.ndata["x"] = mesh_graph.ndata["x"].to(dtype=self.dtype)
        mesh_graph.edata["x"] = mesh_graph.edata["x"].to(dtype=self.dtype)

        if self.khop_neighbors > 0:
            # attention mask: True marks node pairs that are NOT k-hop neighbors
            khop_adj_bool = self.khop_adj_all_k(g=mesh_graph, kmax=self.khop_neighbors)
            mask = ~khop_adj_bool
        else:
            mask = None

        if verbose:
            print("mesh graph:", mesh_graph)
        return mesh_graph, mask

    def create_g2m_graph(self, verbose: bool = True):
        """Create the graph2mesh graph.

        Parameters
        ----------
        verbose : bool, optional
            verbosity, by default True

        Returns
        -------
        DGLGraph
            Graph2mesh graph.
        """
        # get the max edge length of the icosphere at the finest level
        max_edge_len = max_edge_length(
            self.finest_mesh_vertices, self.finest_mesh_src, self.finest_mesh_dst
        )

        # create the grid2mesh bipartite graph
        cartesian_grid = latlon2xyz(self.lat_lon_grid_flat)
        n_nbrs = 4
        neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(self.mesh_vertices)
        distances, indices = neighbors.kneighbors(cartesian_grid)

        # keep a (grid, mesh) pair only if its distance is at most 0.6 times
        # the longest mesh edge
        src, dst = [], []
        for i in range(len(cartesian_grid)):
            for j in range(n_nbrs):
                if distances[i][j] <= 0.6 * max_edge_len:
                    src.append(i)
                    dst.append(indices[i][j])
        # NOTE this gives 1,618,820 edges; in the paper it is 1,618,746

        g2m_graph = create_heterograph(
            src, dst, ("grid", "g2m", "mesh"), dtype=torch.int32
        )
        g2m_graph.srcdata["pos"] = cartesian_grid.to(torch.float32)
        g2m_graph.dstdata["pos"] = torch.tensor(
            self.mesh_vertices,
            dtype=torch.float32,
        )
        g2m_graph.srcdata["lat_lon"] = self.lat_lon_grid_flat
        g2m_graph.dstdata["lat_lon"] = xyz2latlon(g2m_graph.dstdata["pos"])
        g2m_graph = add_edge_features(
            g2m_graph, (g2m_graph.srcdata["pos"], g2m_graph.dstdata["pos"])
        )
        # avoid potential conversions at later points
        g2m_graph.srcdata["pos"] = g2m_graph.srcdata["pos"].to(dtype=self.dtype)
        g2m_graph.dstdata["pos"] = g2m_graph.dstdata["pos"].to(dtype=self.dtype)
        g2m_graph.ndata["pos"]["grid"] = g2m_graph.ndata["pos"]["grid"].to(
            dtype=self.dtype
        )
        g2m_graph.ndata["pos"]["mesh"] = g2m_graph.ndata["pos"]["mesh"].to(
            dtype=self.dtype
        )
        g2m_graph.edata["x"] = g2m_graph.edata["x"].to(dtype=self.dtype)

        if verbose:
            print("g2m graph:", g2m_graph)
        return g2m_graph

    def create_m2g_graph(self, verbose: bool = True):
        """Create the mesh2grid graph.

        Parameters
        ----------
        verbose : bool, optional
            verbosity, by default True

        Returns
        -------
        DGLGraph
            Mesh2grid graph.
        """
        # create the mesh2grid bipartite graph
        cartesian_grid = latlon2xyz(self.lat_lon_grid_flat)
        face_centroids = get_face_centroids(self.mesh_vertices, self.mesh_faces)
        n_nbrs = 1
        neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(face_centroids)
        _, indices = neighbors.kneighbors(cartesian_grid)
        indices = indices.flatten()

        # each grid node connects to the three vertices of its nearest face
        # (by centroid distance)
        src = [p for i in indices for p in self.mesh_faces[i]]
        dst = [i for i in range(len(cartesian_grid)) for _ in range(3)]
        m2g_graph = create_heterograph(
            src, dst, ("mesh", "m2g", "grid"), dtype=torch.int32
        )  # number of edges is 3,114,720, which exactly matches the paper
        m2g_graph.srcdata["pos"] = torch.tensor(
            self.mesh_vertices,
            dtype=torch.float32,
        )
        m2g_graph.dstdata["pos"] = cartesian_grid.to(dtype=torch.float32)
        m2g_graph.srcdata["lat_lon"] = xyz2latlon(m2g_graph.srcdata["pos"])
        m2g_graph.dstdata["lat_lon"] = self.lat_lon_grid_flat
        m2g_graph = add_edge_features(
            m2g_graph, (m2g_graph.srcdata["pos"], m2g_graph.dstdata["pos"])
        )
        # avoid potential conversions at later points
        m2g_graph.srcdata["pos"] = m2g_graph.srcdata["pos"].to(dtype=self.dtype)
        m2g_graph.dstdata["pos"] = m2g_graph.dstdata["pos"].to(dtype=self.dtype)
        m2g_graph.ndata["pos"]["grid"] = m2g_graph.ndata["pos"]["grid"].to(
            dtype=self.dtype
        )
        m2g_graph.ndata["pos"]["mesh"] = m2g_graph.ndata["pos"]["mesh"].to(
            dtype=self.dtype
        )
        m2g_graph.edata["x"] = m2g_graph.edata["x"].to(dtype=self.dtype)

        if verbose:
            print("m2g graph:", m2g_graph)
        return m2g_graph
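
For orientation, here is a minimal end-to-end usage sketch. Only the Graph constructor and the three create_* methods come from the source above; the 0.25-degree lat/lon meshgrid (in degrees, which latlon2xyz is assumed to expect) and the surrounding wiring are illustrative assumptions, not Modulus API.

import torch

# Hypothetical 0.25-degree global grid: shape (721, 1440, 2), values in degrees.
lats = torch.linspace(-90.0, 90.0, steps=721)
lons = torch.linspace(0.0, 360.0, steps=1441)[:-1]
lat_lon_grid = torch.stack(torch.meshgrid(lats, lons, indexing="ij"), dim=-1)

graph = Graph(lat_lon_grid, mesh_level=6, multimesh=True, khop_neighbors=2)

# Latent mesh plus the k-hop attention mask (None when khop_neighbors == 0).
mesh_graph, attn_mask = graph.create_mesh_graph(verbose=False)
g2m_graph = graph.create_g2m_graph(verbose=False)  # encoder: grid -> mesh
m2g_graph = graph.create_m2g_graph(verbose=False)  # decoder: mesh -> grid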
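
The mask construction relies on khop_adj_all_k, which accumulates powers of the adjacency matrix: the (i, j) entry of A^k counts walks of length k, so A + A^2 + ... + A^kmax is nonzero exactly for pairs within kmax hops (the division by the minimum degree only rescales values and cannot zero out a nonzero entry). A self-contained toy illustration with a dense tensor, no DGL required:

import torch

# Toy bidirected path graph 0-1-2-3 as a dense adjacency matrix.
A = torch.tensor(
    [[0, 1, 0, 0],
     [1, 0, 1, 0],
     [0, 1, 0, 1],
     [0, 0, 1, 0]],
    dtype=torch.float32,
)
kmax = 2
A_k, A_all = A.clone(), A.clone()
for _ in range(2, kmax + 1):
    A_k = A @ A_k   # entries count walks of length k
    A_all += A_k    # accumulate: nonzero iff within k hops
print(A_all.bool())
# Nodes 0 and 2 become True (2 hops apart); 0 and 3 stay False (3 hops apart).
# The diagonal is also True, via length-2 round trips.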
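
Similarly, the grid2mesh construction keeps a candidate (grid, mesh) pair only when its distance is at most 0.6 times the longest edge of the finest mesh, as in the GraphCast paper. A toy sketch of that cutoff filter with made-up 2D points (the real code works on 3D unit-sphere coordinates):

import numpy as np
from sklearn.neighbors import NearestNeighbors

mesh_pts = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])  # stand-in mesh nodes
grid_pts = np.array([[0.1, 0.1], [2.0, 2.0]])              # stand-in grid nodes
max_edge_len = 1.0                                         # pretend longest mesh edge

dists, idx = NearestNeighbors(n_neighbors=2).fit(mesh_pts).kneighbors(grid_pts)
src, dst = [], []
for i in range(len(grid_pts)):
    for j in range(2):
        if dists[i][j] <= 0.6 * max_edge_len:
            src.append(i)
            dst.append(int(idx[i][j]))

print(list(zip(src, dst)))  # [(0, 0)]: only grid point 0 is close enough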