# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dgl
from dgl import DGLGraph
import torch
from torch import Tensor, testing
import numpy as np
from torch.nn import functional as F
from typing import List, Tuple, Union


def create_graph(
    src: List,
    dst: List,
    to_bidirected: bool = True,
    add_self_loop: bool = False,
    dtype: torch.dtype = torch.int32,
) -> DGLGraph:
    """
    Creates a DGL graph from an adjacency matrix in COO format.

    Parameters
    ----------
    src : List
        List of source nodes
    dst : List
        List of destination nodes
    to_bidirected : bool, optional
        Whether to make the graph bidirectional, by default True
    add_self_loop : bool, optional
        Whether to add self loop to the graph, by default False
    dtype : torch.dtype, optional
        Graph index data type, by default torch.int32

    Returns
    -------
    DGLGraph
        The dgl Graph.
    """
    graph = dgl.graph((src, dst), idtype=dtype)
    if to_bidirected:
        graph = dgl.to_bidirected(graph)
    if add_self_loop:
        graph = dgl.add_self_loop(graph)
    return graph


def create_heterograph(
    src: List,
    dst: List,
    labels: Union[str, Tuple[str, str, str]],
    dtype: torch.dtype = torch.int32,
) -> DGLGraph:
    """Creates a heterogeneous DGL graph from an adjacency matrix in COO format.

    Parameters
    ----------
    src : List
        List of source nodes
    dst : List
        List of destination nodes
    labels : str or Tuple[str, str, str]
        Edge type used as the key of the ``dgl.heterograph`` data dict; DGL
        documents this key as a canonical ``(src_type, edge_type, dst_type)``
        triplet.
    dtype : torch.dtype, optional
        Graph index data type, by default torch.int32

    Returns
    -------
    DGLGraph
        The dgl Graph.
    """
    graph = dgl.heterograph({labels: ("coo", (src, dst))}, idtype=dtype)
    return graph


def _assert_projection(actual: Tensor, expected: Tensor) -> None:
    """Raise ValueError if ``actual`` is not close to ``expected``."""
    try:
        testing.assert_close(actual, expected)
    except AssertionError as err:
        # torch.testing.assert_close signals a mismatch with AssertionError;
        # re-raise as ValueError (the type callers already see) and keep the
        # original mismatch report as the cause.
        raise ValueError(
            "Invalid projection of edge nodes to local coordinate system"
        ) from err


def add_edge_features(
    graph: DGLGraph,
    pos: Union[Tensor, Tuple[Tensor, Tensor]],
    normalize: bool = True,
) -> DGLGraph:
    """Adds edge features to the graph.

    For every edge, both endpoint positions are rotated (first about the
    z-axis, then about the y-axis) so that the destination node lands on
    (1, 0, 0). The feature stored in ``graph.edata["x"]`` is the rotated
    displacement ``src - dst`` concatenated with its Euclidean norm.

    Parameters
    ----------
    graph : DGLGraph
        The graph to add edge features to. It is modified in place.
    pos : Tensor or Tuple[Tensor, Tensor]
        The node positions as xyz coordinates, indexed by node id. A tuple is
        interpreted as ``(src_pos, dst_pos)``, i.e. separate position tables
        for source and destination nodes.
    normalize : bool, optional
        Whether to divide all edge features by the longest edge length,
        by default True

    Returns
    -------
    DGLGraph
        The graph with edge features.

    Raises
    ------
    ValueError
        If the rotated destination nodes do not land on (1, 0, 0) within the
        default tolerances of ``torch.testing.assert_close``.
    """
    if isinstance(pos, tuple):
        src_pos, dst_pos = pos
    else:
        src_pos = dst_pos = pos
    src, dst = graph.edges()

    # gather per-edge endpoint positions
    src_pos, dst_pos = src_pos[src.long()], dst_pos[dst.long()]
    dst_latlon = xyz2latlon(dst_pos, unit="rad")
    dst_lat, dst_lon = dst_latlon[:, 0], dst_latlon[:, 1]

    # azimuthal & polar rotation
    theta_azimuthal = azimuthal_angle(dst_lon)
    theta_polar = polar_angle(dst_lat)

    src_pos = geospatial_rotation(src_pos, theta=theta_azimuthal, axis="z", unit="rad")
    dst_pos = geospatial_rotation(dst_pos, theta=theta_azimuthal, axis="z", unit="rad")
    # y values should be zero
    _assert_projection(dst_pos[:, 1], torch.zeros_like(dst_pos[:, 1]))

    src_pos = geospatial_rotation(src_pos, theta=theta_polar, axis="y", unit="rad")
    dst_pos = geospatial_rotation(dst_pos, theta=theta_polar, axis="y", unit="rad")
    # x values should be one, y & z values should be zero
    _assert_projection(dst_pos[:, 0], torch.ones_like(dst_pos[:, 0]))
    _assert_projection(dst_pos[:, 1], torch.zeros_like(dst_pos[:, 1]))
    _assert_projection(dst_pos[:, 2], torch.zeros_like(dst_pos[:, 2]))

    # prepare edge features
    disp = src_pos - dst_pos
    disp_norm = torch.linalg.norm(disp, dim=-1, keepdim=True)

    # normalize using the longest edge
    if normalize:
        max_disp_norm = torch.max(disp_norm)
        graph.edata["x"] = torch.cat(
            (disp / max_disp_norm, disp_norm / max_disp_norm), dim=-1
        )
    else:
        graph.edata["x"] = torch.cat((disp, disp_norm), dim=-1)
    return graph


def add_node_features(graph: DGLGraph, pos: Tensor) -> DGLGraph:
    """Adds cosine of latitude, sine and cosine of longitude as the node
    features to the graph.

    Parameters
    ----------
    graph : DGLGraph
        The graph to add node features to. It is modified in place.
    pos : Tensor
        The node positions as xyz coordinates.

    Returns
    -------
    graph : DGLGraph
        The graph with node features.
    """
    # NOTE(review): xyz2latlon defaults to unit="deg", so the trig functions
    # below receive degrees, not radians. That contradicts "cosine of
    # latitude" above; it is kept as-is because models trained with these
    # features depend on it. Confirm before switching to unit="rad".
    latlon = xyz2latlon(pos)
    lat, lon = latlon[:, 0], latlon[:, 1]
    graph.ndata["x"] = torch.stack(
        (torch.cos(lat), torch.sin(lon), torch.cos(lon)), dim=-1
    )
    return graph


def latlon2xyz(latlon: Tensor, radius: float = 1, unit: str = "deg") -> Tensor:
    """
    Converts latlon (in degrees or radians) to xyz.

    Based on: https://stackoverflow.com/questions/1185408
    - The x-axis goes through (lat, lon) = (0, 0);
    - The y-axis goes through (lat, lon) = (0, 90);
    - The z-axis goes through the poles.

    Parameters
    ----------
    latlon : Tensor
        Tensor of shape (N, 2) containing latitudes and longitudes
    radius : float, optional
        Radius of the sphere, by default 1
    unit : str, optional
        Unit of the latlon, "deg" or "rad", by default "deg"

    Returns
    -------
    Tensor
        Tensor of shape (N, 3) containing x, y, z coordinates

    Raises
    ------
    ValueError
        If ``unit`` is neither "deg" nor "rad".
    """
    if unit == "deg":
        latlon = deg2rad(latlon)
    elif unit != "rad":
        raise ValueError("Not a valid unit")
    lat, lon = latlon[:, 0], latlon[:, 1]
    x = radius * torch.cos(lat) * torch.cos(lon)
    y = radius * torch.cos(lat) * torch.sin(lon)
    z = radius * torch.sin(lat)
    return torch.stack((x, y, z), dim=1)


def xyz2latlon(xyz: Tensor, radius: float = 1, unit: str = "deg") -> Tensor:
    """
    Converts xyz to latlon (in degrees or radians).

    Based on: https://stackoverflow.com/questions/1185408
    - The x-axis goes through (lat, lon) = (0, 0);
    - The y-axis goes through (lat, lon) = (0, 90);
    - The z-axis goes through the poles.

    Parameters
    ----------
    xyz : Tensor
        Tensor of shape (N, 3) containing x, y, z coordinates
    radius : float, optional
        Radius of the sphere, by default 1
    unit : str, optional
        Unit of the returned latlon, "deg" or "rad", by default "deg"

    Returns
    -------
    Tensor
        Tensor of shape (N, 2) containing latitudes and longitudes

    Raises
    ------
    ValueError
        If ``unit`` is neither "deg" nor "rad".
    """
    # Clamp so points that drift marginally off the sphere through rounding
    # map to the poles instead of producing NaN from arcsin.
    lat = torch.arcsin(torch.clamp(xyz[:, 2] / radius, -1.0, 1.0))
    lon = torch.arctan2(xyz[:, 1], xyz[:, 0])
    if unit == "deg":
        return torch.stack((rad2deg(lat), rad2deg(lon)), dim=1)
    elif unit == "rad":
        return torch.stack((lat, lon), dim=1)
    else:
        raise ValueError("Not a valid unit")


def geospatial_rotation(
    invar: Tensor, theta: Tensor, axis: str, unit: str = "rad"
) -> Tensor:
    """Rotation using right hand rule

    Parameters
    ----------
    invar : Tensor
        Tensor of shape (N, 3) containing x, y, z coordinates
    theta : Tensor
        Tensor of shape (N, ) containing the rotation angle
    axis : str
        Axis of rotation: "x", "y" or "z"
    unit : str, optional
        Unit of the theta, "deg" or "rad", by default "rad"

    Returns
    -------
    Tensor
        Tensor of shape (N, 3) containing the rotated x, y, z coordinates

    Raises
    ------
    ValueError
        If ``unit`` or ``axis`` is not one of the supported values.
    """
    # the unit applies to the angle, not the coordinates
    if unit == "deg":
        theta = deg2rad(theta)
    elif unit != "rad":
        raise ValueError("Not a valid unit")

    invar = torch.unsqueeze(invar, -1)
    # match the input's dtype/device so matmul works for float64 / GPU input
    rotation = torch.zeros(
        (theta.size(0), 3, 3), dtype=invar.dtype, device=invar.device
    )
    cos = torch.cos(theta)
    sin = torch.sin(theta)

    if axis == "x":
        rotation[:, 0, 0] += 1.0
        rotation[:, 1, 1] += cos
        rotation[:, 1, 2] -= sin
        rotation[:, 2, 1] += sin
        rotation[:, 2, 2] += cos
    elif axis == "y":
        rotation[:, 0, 0] += cos
        rotation[:, 0, 2] += sin
        rotation[:, 1, 1] += 1.0
        rotation[:, 2, 0] -= sin
        rotation[:, 2, 2] += cos
    elif axis == "z":
        rotation[:, 0, 0] += cos
        rotation[:, 0, 1] -= sin
        rotation[:, 1, 0] += sin
        rotation[:, 1, 1] += cos
        rotation[:, 2, 2] += 1.0
    else:
        raise ValueError("Invalid axis")

    outvar = torch.matmul(rotation, invar)
    # drop only the trailing vector dim so N == 1 still yields shape (1, 3)
    outvar = outvar.squeeze(-1)
    return outvar


def azimuthal_angle(lon: Tensor) -> Tensor:
    """
    Gives the azimuthal angle of a point on the sphere

    The result is ``-lon`` mapped into the non-negative range: ``2*pi - lon``
    for ``lon >= 0`` and ``-lon`` otherwise.

    Parameters
    ----------
    lon : Tensor
        Tensor of shape (N, ) containing the longitude of the point (radians)

    Returns
    -------
    Tensor
        Tensor of shape (N, ) containing the azimuthal angle
    """
    angle = torch.where(lon >= 0.0, 2 * np.pi - lon, -lon)
    return angle


def polar_angle(lat: Tensor) -> Tensor:
    """
    Gives the polar angle of a point on the sphere

    The result is ``lat`` for ``lat >= 0`` and ``2*pi + lat`` otherwise.

    Parameters
    ----------
    lat : Tensor
        Tensor of shape (N, ) containing the latitude of the point (radians)

    Returns
    -------
    Tensor
        Tensor of shape (N, ) containing the polar angle
    """
    angle = torch.where(lat >= 0.0, lat, 2 * np.pi + lat)
    return angle


def deg2rad(deg: Tensor) -> Tensor:
    """Converts degrees to radians

    Parameters
    ----------
    deg :
        Tensor of shape (N, ) containing the degrees

    Returns
    -------
    Tensor
        Tensor of shape (N, ) containing the radians
    """
    return deg * np.pi / 180


def rad2deg(rad: Tensor) -> Tensor:
    """Converts radians to degrees

    Parameters
    ----------
    rad :
        Tensor of shape (N, ) containing the radians

    Returns
    -------
    Tensor
        Tensor of shape (N, ) containing the degrees
    """
    return rad * 180 / np.pi


def get_edge_len(edge_src: Tensor, edge_dst: Tensor, axis: int = 1):
    """returns the length of the edge

    Computed with ``np.linalg.norm``, so the inputs must be convertible with
    ``np.asarray`` (e.g. NumPy arrays or CPU tensors that do not require grad).

    Parameters
    ----------
    edge_src : Tensor
        Tensor of shape (N, 3) containing the source of the edge
    edge_dst : Tensor
        Tensor of shape (N, 3) containing the destination of the edge
    axis : int, optional
        Axis along which the norm is computed, by default 1

    Returns
    -------
    np.ndarray
        Array of shape (N, ) containing the length of the edge
    """
    return np.linalg.norm(edge_src - edge_dst, axis=axis)


def cell_to_adj(cells: List[List[int]]):
    """creates adjacency matrix in COO format from mesh cells

    Each cell (v0, v1, v2) contributes the directed edges v0->v1, v1->v2
    and v2->v0.

    Parameters
    ----------
    cells : List[List[int]]
        List of cells, each cell is a list of 3 vertices

    Returns
    -------
    src, dst : List[int], List[int]
        List of source and destination vertices
    """
    src = [cell[idx] for cell in cells for idx in (0, 1, 2)]
    dst = [cell[idx] for cell in cells for idx in (1, 2, 0)]
    return src, dst