# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

from modulus.utils.sfno.distributed import comm
from modulus.utils.sfno.distributed.mappings import (
    reduce_from_parallel_region,
    copy_to_parallel_region,
)


class Preprocessor2D(nn.Module):
    """
    Preprocessing methods to flatten image history, add static features and
    handle history normalization and unpredicted input channels.

    Tensors are handled either in a flattened 4D layout
    ``(batch, (n_history + 1) * channels, height, width)`` or an expanded 5D
    layout ``(batch, n_history + 1, channels, height, width)``; see
    :meth:`flatten_history` and :meth:`expand_history`.

    Note: ``params.enable_nhwc`` is stored as ``self.transform_to_nhwc`` but
    is not read by any method of this class.
    """

    def __init__(self, params):  # pragma: no cover
        """
        Parameters
        ----------
        params : object
            Configuration namespace. Attributes read here: ``n_history``,
            ``enable_nhwc``, ``history_normalization_mode`` (and
            ``history_normalization_decay`` for the "exponential" mode),
            ``img_shape_x/y``, ``img_local_offset_x/y``,
            ``img_local_shape_x/y``, ``add_grid`` (+ ``gridtype`` and the
            optional ``grid_num_frequencies``), ``add_orography``
            (+ ``orography_path``) and ``add_landmask`` (+ ``landmask_path``).
        """
        super().__init__()

        self.n_history = params.n_history
        self.transform_to_nhwc = params.enable_nhwc
        self.history_normalization_mode = params.history_normalization_mode
        n_steps = self.n_history + 1

        if self.history_normalization_mode == "exponential":
            self.history_normalization_decay = params.history_normalization_decay
            # inverse ordering, since first element is oldest: the oldest step
            # gets exp(-decay * n_history), the newest exp(0)
            history_normalization_weights = torch.exp(
                (-self.history_normalization_decay)
                * torch.arange(
                    start=self.n_history, end=-1, step=-1, dtype=torch.float32
                )
            )
            history_normalization_weights = history_normalization_weights / torch.sum(
                history_normalization_weights
            )
        elif self.history_normalization_mode == "mean":
            # uniform weights 1 / (n_history + 1) over all time steps
            # (torch.Tensor(<float>, dtype=...) is not a valid constructor call)
            history_normalization_weights = torch.full(
                (n_steps,), 1.0 / float(n_steps), dtype=torch.float32
            )
        else:
            # "none" / "timediff" do not read these weights in history_compute_stats
            history_normalization_weights = torch.ones(n_steps, dtype=torch.float32)

        # shape (1, time, 1, 1, 1) so the weights broadcast against the expanded
        # 5D history tensors in history_compute_stats
        history_normalization_weights = torch.reshape(
            history_normalization_weights, (1, -1, 1, 1, 1)
        )
        self.register_buffer(
            "history_normalization_weights",
            history_normalization_weights,
            persistent=False,
        )

        # statistics filled in by history_compute_stats
        self.history_mean = None
        self.history_std = None
        self.history_diff_mean = None
        self.history_diff_var = None
        # not read by any method of this class
        self.history_eps = 1e-6

        # global (unsharded) image shape, used to normalize spatial sums
        self.img_shape = [params.img_shape_x, params.img_shape_y]

        # unpredicted input channels, filled in by cache_unpredicted_features:
        self.unpredicted_inp_train = None
        self.unpredicted_tar_train = None
        self.unpredicted_inp_eval = None
        self.unpredicted_tar_eval = None

        # local spatial window of this rank, needed for sharding
        start_x = params.img_local_offset_x
        end_x = min(start_x + params.img_local_shape_x, params.img_shape_x)
        pad_x = params.img_local_shape_x - (end_x - start_x)
        start_y = params.img_local_offset_y
        end_y = min(start_y + params.img_local_shape_y, params.img_shape_y)
        pad_y = params.img_local_shape_y - (end_y - start_y)

        def _shard_and_pad(feature):
            # crop an NCHW global field to the local window, then zero-pad the
            # trailing edges up to the local shape
            feature = feature[:, :, start_x:end_x, start_y:end_y]
            return F.pad(feature, [0, pad_y, 0, pad_x])

        def _append_feature(features, new):
            # concatenate along the channel dim, starting from "no features yet"
            return new if features is None else torch.cat([features, new], dim=1)

        static_features = None

        # set up grid
        if params.add_grid:
            with torch.no_grad():
                # drop the endpoint 1.0 so coordinates lie in [0, 1)
                tx = torch.linspace(
                    0, 1, params.img_shape_x + 1, dtype=torch.float32
                )[0:-1]
                ty = torch.linspace(
                    0, 1, params.img_shape_y + 1, dtype=torch.float32
                )[0:-1]

                x_grid, y_grid = torch.meshgrid(tx, ty, indexing="ij")
                grid = torch.stack([x_grid, y_grid], dim=0).unsqueeze(0)
                grid = _shard_and_pad(grid)

                # transform if requested
                if params.gridtype == "sinusoidal":
                    num_freq = int(getattr(params, "grid_num_frequencies", 1))
                    # sin(f * grid) for f = 1..num_freq, stacked along channels;
                    # no grid features at all when num_freq < 1
                    if num_freq >= 1:
                        static_features = torch.cat(
                            [torch.sin(freq * grid) for freq in range(1, num_freq + 1)],
                            dim=1,
                        )
                else:
                    static_features = grid

        if params.add_orography:
            from utils.conditioning_inputs import get_orography

            oro = torch.tensor(get_orography(params.orography_path), dtype=torch.float32)
            oro = torch.reshape(oro, (1, 1, oro.shape[0], oro.shape[1]))
            static_features = _append_feature(static_features, _shard_and_pad(oro))

        if params.add_landmask:
            from utils.conditioning_inputs import get_land_mask

            lsm = torch.tensor(get_land_mask(params.landmask_path), dtype=torch.long)
            # one hot encode and move channels to front:
            lsm = torch.permute(F.one_hot(lsm), (2, 0, 1)).to(torch.float32)
            lsm = torch.reshape(lsm, (1, lsm.shape[0], lsm.shape[1], lsm.shape[2]))
            static_features = _append_feature(static_features, _shard_and_pad(lsm))

        self.do_add_static_features = static_features is not None
        if self.do_add_static_features:
            self.register_buffer("static_features", static_features, persistent=False)

    def flatten_history(self, x):  # pragma: no cover
        """Flatten input so that history is included as part of channels.

        A 5D tensor ``(b, t, c, h, w)`` is reshaped to ``(b, t * c, h, w)``;
        any other rank is returned unchanged.
        """
        if x.dim() == 5:
            b_, t_, c_, h_, w_ = x.shape
            x = torch.reshape(x, (b_, t_ * c_, h_, w_))
        return x

    def expand_history(self, x, nhist):  # pragma: no cover
        """Expand history from flattened data.

        A 4D tensor ``(b, nhist * c, h, w)`` is reshaped to
        ``(b, nhist, c, h, w)``; any other rank is returned unchanged.
        """
        if x.dim() == 4:
            b_, ct_, h_, w_ = x.shape
            x = torch.reshape(x, (b_, nhist, ct_ // nhist, h_, w_))
        return x

    def add_static_features(self, x):  # pragma: no cover
        """Adds static features to the input.

        The static features are tiled along the batch dim and concatenated
        after the existing channels (dim 1) of the 4D input.
        """
        if self.do_add_static_features:
            # we need to replicate the static features for each batch entry:
            static = torch.tile(self.static_features, dims=(x.shape[0], 1, 1, 1))
            x = torch.cat([x, static], dim=1)
        return x

    def remove_static_features(self, x):  # pragma: no cover
        """
        Removes static features from the input.

        Only removes the trailing static channels if something was added in
        the first place.
        """
        if self.do_add_static_features:
            nfeat = self.static_features.shape[1]
            x = x[:, : x.shape[1] - nfeat, :, :]
        return x

    def _current_unpredicted_inp(self):  # pragma: no cover
        """Cached unpredicted input for the current (train/eval) mode."""
        if self.training:
            return self.unpredicted_inp_train
        return self.unpredicted_inp_eval

    def append_history(self, x1, x2, step):  # pragma: no cover
        """
        Appends history to the main input.

        Also shifts the cached unpredicted target features for ``step`` (such
        as zenith angle) into the cached unpredicted inputs, so that they can
        be forwarded in the next autoregressive step.

        Without history, just returns the second tensor (x2). With history,
        drops the oldest step of ``x1``, appends ``x2`` as the newest step and
        returns the result in flattened 4D layout.
        """
        if self.training:
            uinp = self.unpredicted_inp_train
            utar_all = self.unpredicted_tar_train
        else:
            uinp = self.unpredicted_inp_eval
            utar_all = self.unpredicted_tar_eval

        # update the unpredicted input in place
        if (utar_all is not None) and (step < utar_all.shape[1]):
            utar = utar_all[:, step : (step + 1), :, :, :]
            if self.n_history == 0:
                uinp.copy_(utar)
            else:
                # torch.cat materializes a new tensor before copy_ overwrites uinp
                uinp.copy_(torch.cat([uinp[:, 1:, :, :, :], utar], dim=1))

        # without history, just return the second tensor
        if self.n_history == 0:
            return x2

        x1 = self.expand_history(x1, nhist=self.n_history + 1)
        x2 = self.expand_history(x2, nhist=1)
        # drop the oldest step, append the new one and flatten again
        res = torch.cat([x1[:, 1:, :, :, :], x2], dim=1)
        return self.flatten_history(res)

    def append_channels(self, x, xc):  # pragma: no cover
        """Appends channels of ``xc`` to each history step of ``x``.

        The result is flattened back to 4D only if ``x`` was given as 4D.
        """
        xdim = x.dim()
        x = self.expand_history(x, self.n_history + 1)
        xc = self.expand_history(xc, self.n_history + 1)

        # concatenate along the per-step channel dim
        xo = torch.cat([x, xc], dim=2)

        # flatten if requested
        if xdim == 4:
            xo = self.flatten_history(xo)

        return xo

    def history_compute_stats(self, x):  # pragma: no cover
        """Compute stats from history timesteps.

        - "none": zero mean / unit std.
        - "timediff": mean and variance of consecutive-step differences,
          summed over space (normalized by the global image size) and averaged
          over time and channels; stored as shape ``(b, 1, 1, 1, 1)``.
        - otherwise: per-channel weighted mean/std over time and space, stored
          as shape ``(b, c, 1, 1)``.

        Spatial sums are all-reduced over the "spatial" group when it has
        more than one rank.
        """
        if self.history_normalization_mode == "none":
            self.history_mean = torch.zeros(
                (1, 1, 1, 1), dtype=torch.float32, device=x.device
            )
            self.history_std = torch.ones(
                (1, 1, 1, 1), dtype=torch.float32, device=x.device
            )
            return

        xr = self.expand_history(x, nhist=self.n_history + 1)
        spatial_size = float(self.img_shape[0] * self.img_shape[1])
        spatially_parallel = comm.get_size("spatial") > 1

        if self.history_normalization_mode == "timediff":
            diff = xr[:, 1:, ...] - xr[:, 0:-1, ...]

            # time difference mean: dims 3, 4 are (h, w) of the 5D tensor;
            # keepdim so the mean broadcasts against diff below
            self.history_diff_mean = torch.mean(
                torch.sum(diff, dim=(3, 4), keepdim=True), dim=(1, 2), keepdim=True
            )
            if spatially_parallel:
                self.history_diff_mean = reduce_from_parallel_region(
                    self.history_diff_mean, "spatial"
                )
            self.history_diff_mean = self.history_diff_mean / spatial_size

            # time difference variance
            self.history_diff_var = torch.mean(
                torch.sum(
                    torch.square(diff - self.history_diff_mean),
                    dim=(3, 4),
                    keepdim=True,
                ),
                dim=(1, 2),
                keepdim=True,
            )
            if spatially_parallel:
                self.history_diff_var = reduce_from_parallel_region(
                    self.history_diff_var, "spatial"
                )
            self.history_diff_var = self.history_diff_var / spatial_size

            self.history_diff_mean = copy_to_parallel_region(
                self.history_diff_mean, "spatial"
            )
            self.history_diff_var = copy_to_parallel_region(
                self.history_diff_var, "spatial"
            )
            return

        # weighted mean over time (dim 1), summed over space (dims 3, 4)
        self.history_mean = torch.sum(
            xr * self.history_normalization_weights, dim=(1, 3, 4), keepdim=True
        )
        if spatially_parallel:
            self.history_mean = reduce_from_parallel_region(
                self.history_mean, "spatial"
            )
        self.history_mean = self.history_mean / spatial_size

        # weighted std, same reduction pattern
        self.history_std = torch.sum(
            torch.square(xr - self.history_mean) * self.history_normalization_weights,
            dim=(1, 3, 4),
            keepdim=True,
        )
        if spatially_parallel:
            self.history_std = reduce_from_parallel_region(self.history_std, "spatial")
        self.history_std = torch.sqrt(self.history_std / spatial_size)

        # drop the singleton time dim: (b, 1, c, 1, 1) -> (b, c, 1, 1)
        self.history_mean = torch.squeeze(self.history_mean, dim=1)
        self.history_std = torch.squeeze(self.history_std, dim=1)

        self.history_mean = copy_to_parallel_region(self.history_mean, "spatial")
        self.history_std = copy_to_parallel_region(self.history_std, "spatial")

    def history_normalize(self, x, target=False):  # pragma: no cover
        """Normalize history with the stats from history_compute_stats.

        No-op for the "none" and "timediff" modes. With ``target=True`` only
        the first ``x.shape[1]`` stat channels are used (stripping the
        unpredicted channels); otherwise the stats are tiled over all history
        steps. A 5D input is returned in its original 5D shape.
        """
        if self.history_normalization_mode in ["none", "timediff"]:
            return x

        assert self.history_mean is not None
        assert self.history_std is not None

        xdim = x.dim()
        xshape = x.shape
        x = self.flatten_history(x)

        if target:
            # strip off the unpredicted channels
            nchan = x.shape[1]
            xn = (x - self.history_mean[:, :nchan, :, :]) / self.history_std[
                :, :nchan, :, :
            ]
        else:
            # tile to include history
            hm = torch.tile(self.history_mean, (1, self.n_history + 1, 1, 1))
            hs = torch.tile(self.history_std, (1, self.n_history + 1, 1, 1))
            xn = (x - hm) / hs

        if xdim == 5:
            xn = torch.reshape(xn, xshape)

        return xn

    def history_denormalize(self, xn, target=False):  # pragma: no cover
        """Denormalize history; inverse of history_normalize.

        No-op for the "none" and "timediff" modes.
        """
        if self.history_normalization_mode in ["none", "timediff"]:
            return xn

        assert self.history_mean is not None
        assert self.history_std is not None

        xndim = xn.dim()
        if xndim == 5:
            xnshape = xn.shape
            xn = self.flatten_history(xn)

        if target:
            # strip off the unpredicted channels
            nchan = xn.shape[1]
            x = (
                xn * self.history_std[:, :nchan, :, :]
                + self.history_mean[:, :nchan, :, :]
            )
        else:
            # tile to include history
            hm = torch.tile(self.history_mean, (1, self.n_history + 1, 1, 1))
            hs = torch.tile(self.history_std, (1, self.n_history + 1, 1, 1))
            x = xn * hs + hm

        if xndim == 5:
            x = torch.reshape(x, xnshape)

        return x

    @staticmethod
    def _update_cached(cached, new):  # pragma: no cover
        """Copy ``new`` into ``cached`` in place when both exist, else replace.

        Returns the tensor that should be stored (the same ``cached`` object
        after an in-place copy, or ``new`` itself, which may be None).
        """
        if (cached is not None) and (new is not None):
            cached.copy_(new)
            return cached
        return new

    def cache_unpredicted_features(
        self, x, y=None, xz=None, yz=None
    ):  # pragma: no cover
        """Caches features not predicted by the model (such as zenith angle).

        ``xz`` / ``yz`` are stored as the unpredicted input / target for the
        current (train or eval) mode. ``x`` and ``y`` are returned unchanged.
        """
        if self.training:
            self.unpredicted_inp_train = self._update_cached(
                self.unpredicted_inp_train, xz
            )
            self.unpredicted_tar_train = self._update_cached(
                self.unpredicted_tar_train, yz
            )
        else:
            self.unpredicted_inp_eval = self._update_cached(
                self.unpredicted_inp_eval, xz
            )
            self.unpredicted_tar_eval = self._update_cached(
                self.unpredicted_tar_eval, yz
            )

        return x, y

    def append_unpredicted_features(self, inp):  # pragma: no cover
        """Appends features not predicted by the model (such as zenith angle) to the input"""
        uinp = self._current_unpredicted_inp()
        if uinp is not None:
            inp = self.append_channels(inp, uinp)
        return inp

    def remove_unpredicted_features(self, inp):  # pragma: no cover
        """Removes features not predicted by the model (such as zenith angle) from the input"""
        uinp = self._current_unpredicted_inp()
        if uinp is not None:
            inpf = self.expand_history(inp, nhist=self.n_history + 1)
            # unpredicted channels are the trailing ones of every history step
            inpc = inpf[:, :, : inpf.shape[2] - uinp.shape[2], :, :]
            inp = self.flatten_history(inpc)
        return inp


def get_preprocessor(params):  # pragma: no cover
    """Returns the preprocessor module"""
    return Preprocessor2D(params)