# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
import json
import numpy as np
from torch import Tensor
from sklearn.neighbors import NearestNeighbors
import logging

from .graph_utils import (
    cell_to_adj,
    create_graph,
    create_heterograph,
    add_edge_features,
    add_node_features,
    latlon2xyz,
    get_edge_len,
)

logger = logging.getLogger(__name__)


class Graph:
    """Graph class for creating the graph2mesh, multimesh, and mesh2graph graphs.

    The icospheres are held as a dictionary keyed by
    ``"order_<k>_vertices"``, ``"order_<k>_faces"`` and
    ``"order_<k>_face_centroid"``; list values read from the json file are
    converted to numpy arrays.

    Parameters
    ----------
    icospheres_path : str
        Path to the icospheres json file. If the file cannot be read, it will
        try to generate it using PyMesh and then read it again.
    lat_lon_grid : Tensor
        Tensor with shape (lat, lon, 2) that includes the latitudes and
        longitudes meshgrid.
    dtype : torch.dtype, optional
        Data type of the graph, by default torch.float

    Raises
    ------
    RuntimeError
        If the icospheres file is still unreadable after the generation
        fallback has run.
    """

    def __init__(
        self, icospheres_path: str, lat_lon_grid: Tensor, dtype=torch.float
    ) -> None:
        self.dtype = dtype

        # Get or generate the icospheres. Only I/O and parsing failures
        # trigger the fallback; anything else is a real bug and propagates.
        try:
            icospheres = self._load_icospheres(icospheres_path)
        except (OSError, ValueError):
            icospheres = self._generate_icospheres(icospheres_path)
        else:
            logger.info("Opened pre-computed graph at %s.", icospheres_path)

        self.icospheres = icospheres
        # NOTE(review): the offset of 2 is inherited from the original code;
        # confirm it against the key layout written by the icospheres generator.
        self.max_order = sum(1 for key in self.icospheres if "faces" in key) - 2

        # flatten lat/lon grid to shape (lat * lon, 2); reshape (rather than
        # view) also accepts a non-contiguous lat_lon_grid
        self.lat_lon_grid_flat = (
            lat_lon_grid.permute(2, 0, 1).reshape(2, -1).permute(1, 0)
        )

    @staticmethod
    def _load_icospheres(icospheres_path: str) -> dict:
        """Read the icospheres json file, converting list values to numpy arrays."""
        with open(icospheres_path, "r") as f:
            loaded_dict = json.load(f)
        return {
            key: (np.array(value) if isinstance(value, list) else value)
            for key, value in loaded_dict.items()
        }

    @classmethod
    def _generate_icospheres(cls, icospheres_path: str) -> dict:
        """Generate the icospheres with PyMesh, then load them from ``icospheres_path``."""
        from modulus.utils.graphcast.icospheres import (
            generate_and_save_icospheres,
        )  # requires PyMesh

        logger.info(
            "Could not open %s...generating mesh from scratch.", icospheres_path
        )
        # NOTE(review): the generator is called without arguments, as before;
        # where it writes its output is not visible from this file. Confirm
        # that it matches icospheres_path.
        generate_and_save_icospheres()
        try:
            return cls._load_icospheres(icospheres_path)
        except (OSError, ValueError) as err:
            raise RuntimeError(
                f"Icospheres were generated, but {icospheres_path} "
                "still could not be loaded."
            ) from err

    def _finest_level(self, field: str):
        """Return the icosphere entry ``field`` stored for ``self.max_order``."""
        return self.icospheres[f"order_{self.max_order}_{field}"]

    def create_mesh_graph(self, verbose: bool = True) -> Tensor:
        """Create the multimesh graph.

        Parameters
        ----------
        verbose : bool, optional
            verbosity, by default True

        Returns
        -------
        DGLGraph
            Multimesh graph (the object built by ``create_graph``; the
            ``Tensor`` annotation is kept unchanged from the original code).
        """
        # create the bi-directional mesh graph from the faces of every order
        # from 0 up to and including max_order
        multimesh_faces = np.concatenate(
            [
                self.icospheres[f"order_{order}_faces"]
                for order in range(self.max_order + 1)
            ]
        )

        src, dst = cell_to_adj(multimesh_faces)
        mesh_graph = create_graph(
            src, dst, to_bidirected=True, add_self_loop=False, dtype=torch.int32
        )
        mesh_pos = torch.tensor(
            self._finest_level("vertices"),
            dtype=torch.float32,
        )
        mesh_graph = add_edge_features(mesh_graph, mesh_pos)
        mesh_graph = add_node_features(mesh_graph, mesh_pos)

        # ensure fields set to dtype to avoid later conversions
        mesh_graph.ndata["x"] = mesh_graph.ndata["x"].to(dtype=self.dtype)
        mesh_graph.edata["x"] = mesh_graph.edata["x"].to(dtype=self.dtype)

        if verbose:
            print("mesh graph:", mesh_graph)
        return mesh_graph

    def create_g2m_graph(self, verbose: bool = True) -> Tensor:
        """Create the graph2mesh graph.

        Every grid point is connected to those of its 4 nearest mesh vertices
        that lie within 0.6 times the longest edge of the max-order icosphere.

        Parameters
        ----------
        verbose : bool, optional
            verbosity, by default True

        Returns
        -------
        DGLGraph
            Graph2mesh graph (the object built by ``create_heterograph``; the
            ``Tensor`` annotation is kept unchanged from the original code).
        """
        vertices = self._finest_level("vertices")
        faces = self._finest_level("faces")

        # get the max edge length of icosphere with max order, looking at all
        # three edges of every face
        edge_len = max(
            np.max(get_edge_len(vertices[faces[:, a]], vertices[faces[:, b]]))
            for a, b in ((0, 1), (0, 2), (1, 2))
        )

        # create the grid2mesh bipartite graph
        cartesian_grid = latlon2xyz(self.lat_lon_grid_flat)
        n_nbrs = 4
        neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(vertices)
        distances, indices = neighbors.kneighbors(cartesian_grid)

        # Keep only the neighbors inside the cut-off radius. Both the nonzero
        # scan and the boolean mask walk the (grid point, neighbor) table in
        # row-major order, so src and dst stay aligned.
        within_radius = distances <= 0.6 * edge_len
        grid_ids, _ = np.nonzero(within_radius)
        src = grid_ids.tolist()
        dst = indices[within_radius].tolist()

        # NOTE this gives 1,624,344 edges, in the paper it is 1,618,746
        # this number is very sensitive to the chosen edge_len, not clear
        # in the paper what they use.
        g2m_graph = create_heterograph(
            src, dst, ("grid", "g2m", "mesh"), dtype=torch.int32
        )
        g2m_graph.srcdata["pos"] = cartesian_grid.to(torch.float32)
        g2m_graph.dstdata["pos"] = torch.tensor(
            vertices,
            dtype=torch.float32,
        )
        g2m_graph = add_edge_features(
            g2m_graph, (g2m_graph.srcdata["pos"], g2m_graph.dstdata["pos"])
        )

        # avoid potential conversions at later points
        g2m_graph.srcdata["pos"] = g2m_graph.srcdata["pos"].to(dtype=self.dtype)
        g2m_graph.dstdata["pos"] = g2m_graph.dstdata["pos"].to(dtype=self.dtype)
        g2m_graph.ndata["pos"]["grid"] = g2m_graph.ndata["pos"]["grid"].to(
            dtype=self.dtype
        )
        g2m_graph.ndata["pos"]["mesh"] = g2m_graph.ndata["pos"]["mesh"].to(
            dtype=self.dtype
        )
        g2m_graph.edata["x"] = g2m_graph.edata["x"].to(dtype=self.dtype)

        if verbose:
            print("g2m graph:", g2m_graph)
        return g2m_graph

    def create_m2g_graph(self, verbose: bool = True) -> Tensor:
        """Create the mesh2grid graph.

        Every grid point is connected to the three vertices of the max-order
        face whose centroid is nearest to it.

        Parameters
        ----------
        verbose : bool, optional
            verbosity, by default True

        Returns
        -------
        DGLGraph
            Mesh2grid graph (the object built by ``create_heterograph``; the
            ``Tensor`` annotation is kept unchanged from the original code).
        """
        # create the mesh2grid bipartite graph
        cartesian_grid = latlon2xyz(self.lat_lon_grid_flat)
        n_nbrs = 1
        neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(
            self._finest_level("face_centroid")
        )
        _, indices = neighbors.kneighbors(cartesian_grid)
        indices = indices.flatten()

        # one edge per vertex of the nearest face; faces are triangles, so
        # each grid point index is repeated three times on the dst side
        faces = self._finest_level("faces")
        src = faces[indices].reshape(-1).tolist()
        dst = np.repeat(np.arange(len(cartesian_grid)), 3).tolist()

        m2g_graph = create_heterograph(
            src, dst, ("mesh", "m2g", "grid"), dtype=torch.int32
        )  # number of edges is 3,114,720, exactly matches with the paper
        m2g_graph.srcdata["pos"] = torch.tensor(
            self._finest_level("vertices"),
            dtype=torch.float32,
        )
        m2g_graph.dstdata["pos"] = cartesian_grid.to(dtype=torch.float32)
        m2g_graph = add_edge_features(
            m2g_graph, (m2g_graph.srcdata["pos"], m2g_graph.dstdata["pos"])
        )

        # avoid potential conversions at later points
        m2g_graph.srcdata["pos"] = m2g_graph.srcdata["pos"].to(dtype=self.dtype)
        m2g_graph.dstdata["pos"] = m2g_graph.dstdata["pos"].to(dtype=self.dtype)
        m2g_graph.ndata["pos"]["grid"] = m2g_graph.ndata["pos"]["grid"].to(
            dtype=self.dtype
        )
        m2g_graph.ndata["pos"]["mesh"] = m2g_graph.ndata["pos"]["mesh"].to(
            dtype=self.dtype
        )
        m2g_graph.edata["x"] = m2g_graph.edata["x"].to(dtype=self.dtype)

        if verbose:
            print("m2g graph:", m2g_graph)
        return m2g_graph