# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from functools import partial
from typing import Any, Tuple

import torch
import torch.nn as nn
from torch.cuda import amp
from torch.utils.checkpoint import checkpoint

# layer normalization
from apex.normalization import FusedLayerNorm

# get spectral transforms from torch_harmonics
import torch_harmonics as th
import torch_harmonics.distributed as thd

# import contractions
from modulus.models.sfno.factorizations import get_contract_fun, _contract_dense

# helpers
from modulus.models.sfno.layers import (
    trunc_normal_,
    DropPath,
    MLP,
    EncoderDecoder,
)

# import global convolution and non-linear spectral layers
from modulus.models.sfno.layers import (
    SpectralConv2d,
    SpectralAttention2d,
    SpectralAttentionS2,
)
from modulus.models.sfno.s2convolutions import SpectralConvS2

# wrap fft, to unify interface to spectral transforms
from modulus.models.sfno.layers import RealFFT2, InverseRealFFT2
from modulus.utils.sfno.distributed.layers import (
    DistributedRealFFT2,
    DistributedInverseRealFFT2,
    DistributedMLP,
    DistributedEncoderDecoder,
)

# more distributed stuff
from modulus.utils.sfno.distributed import comm
from modulus.utils.sfno.distributed.layer_norm import DistributedInstanceNorm2d

from modulus.models.module import Module
from modulus.models.meta import ModelMetaData
from modulus.models.layers import get_activation
@dataclass
class MetaData(ModelMetaData):
    """Model metadata (capability flags) for the SFNO network."""

    name: str = "SFNO"
    # Optimization
    jit: bool = False
    cuda_graphs: bool = True
    amp_cpu: bool = True
    amp_gpu: bool = True
    torch_fx: bool = False
    # Inference
    onnx: bool = False
    # Physics informed
    func_torch: bool = False
    auto_grad: bool = False


class SpectralFilterLayer(nn.Module):
    """Spectral filter layer

    Selects the concrete spectral operator from ``filter_type`` and the type
    of ``forward_transform`` and stores it as ``self.filter``:

    * ``"non-linear"`` + (distributed) SHT  -> ``SpectralAttentionS2``
    * ``"non-linear"`` + (distributed) FFT  -> ``SpectralAttention2d``
    * ``"linear"``                          -> ``SpectralConvS2``

    Any other combination raises ``NotImplementedError``.

    ``use_complex_kernels`` and ``complex_network`` are accepted for interface
    compatibility but are not used by this layer.
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        embed_dim,
        filter_type="linear",
        operator_type="diagonal",
        sparsity_threshold=0.0,
        use_complex_kernels=True,
        hidden_size_factor=1,
        rank=1.0,
        factorization=None,
        separable=False,
        complex_network=True,
        complex_activation="real",
        spectral_layers=1,
        drop_rate=0.0,
    ):  # pragma: no cover
        super(SpectralFilterLayer, self).__init__()

        # keyword arguments shared by both spectral attention variants
        attention_kwargs = dict(
            sparsity_threshold=sparsity_threshold,
            hidden_size_factor=hidden_size_factor,
            complex_activation=complex_activation,
            spectral_layers=spectral_layers,
            drop_rate=drop_rate,
            bias=False,
        )

        if filter_type == "non-linear" and isinstance(
            forward_transform, (th.RealSHT, thd.DistributedRealSHT)
        ):
            self.filter = SpectralAttentionS2(
                forward_transform, inverse_transform, embed_dim, **attention_kwargs
            )
        elif filter_type == "non-linear" and isinstance(
            forward_transform, (RealFFT2, DistributedRealFFT2)
        ):
            self.filter = SpectralAttention2d(
                forward_transform, inverse_transform, embed_dim, **attention_kwargs
            )
        # spectral transform is passed to the module
        elif filter_type == "linear":
            self.filter = SpectralConvS2(
                forward_transform,
                inverse_transform,
                embed_dim,
                embed_dim,
                operator_type=operator_type,
                rank=rank,
                factorization=factorization,
                separable=separable,
                bias=False,
                use_tensorly=factorization is not None,
            )
        else:
            raise NotImplementedError(
                f"Unsupported combination of filter_type={filter_type!r} and "
                f"forward transform {type(forward_transform).__name__}"
            )

    def forward(self, x):  # pragma: no cover
        """Apply the selected spectral filter and return its output unchanged."""
        return self.filter(x)


class FourierNeuralOperatorBlock(nn.Module):
    """Fourier Neural Operator Block

    One block applies, in order: ``norm0`` -> spectral filter -> optional
    inner skip -> optional activation -> ``norm1`` -> optional MLP ->
    drop path -> optional outer skip.

    ``norm_layer`` is a pair of zero-argument factories; the first builds the
    norm applied before the filter, the second the norm applied after it.
    The skip connections are only created for the values ``"linear"`` (1x1
    convolution) and ``"identity"``; any other value disables the skip.
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        embed_dim,
        filter_type="linear",
        operator_type="diagonal",
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path=0.0,
        act_layer="gelu",
        norm_layer=(nn.LayerNorm, nn.LayerNorm),
        sparsity_threshold=0.0,
        use_complex_kernels=True,
        rank=1.0,
        factorization=None,
        separable=False,
        inner_skip="linear",
        outer_skip=None,  # None, nn.linear or nn.Identity
        use_mlp=False,
        comm_feature_inp_name=None,
        comm_feature_hidden_name=None,
        complex_network=True,
        complex_activation="real",
        spectral_layers=1,
        checkpointing=0,
    ):  # pragma: no cover
        super(FourierNeuralOperatorBlock, self).__init__()

        # spatial extent of the (local) data on the input and output side;
        # used in forward() to normalize only that region of the tensor
        if (comm.get_size("h") > 1) or (comm.get_size("w") > 1):
            self.input_shape_loc = (
                forward_transform.nlat_local,
                forward_transform.nlon_local,
            )
            self.output_shape_loc = (
                inverse_transform.nlat_local,
                inverse_transform.nlon_local,
            )
        else:
            self.input_shape_loc = (forward_transform.nlat, forward_transform.nlon)
            self.output_shape_loc = (inverse_transform.nlat, inverse_transform.nlon)

        # norm layer
        self.norm0 = norm_layer[0]()

        # convolution layer
        self.filter = SpectralFilterLayer(
            forward_transform,
            inverse_transform,
            embed_dim,
            filter_type,
            operator_type,
            sparsity_threshold,
            use_complex_kernels=use_complex_kernels,
            hidden_size_factor=mlp_ratio,
            rank=rank,
            factorization=factorization,
            separable=separable,
            complex_network=complex_network,
            complex_activation=complex_activation,
            spectral_layers=spectral_layers,
            drop_rate=drop_rate,
        )

        if inner_skip == "linear":
            self.inner_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
        elif inner_skip == "identity":
            self.inner_skip = nn.Identity()

        # the pointwise activation is only attached for the linear filters
        if filter_type in ("linear", "real linear"):
            self.act_layer = get_activation(act_layer)

        # dropout
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # norm layer
        self.norm1 = norm_layer[1]()

        # Explicit comparison kept on purpose: only True / 1 enables the MLP,
        # other truthy values (e.g. non-empty strings) do not.
        if use_mlp == True:  # noqa: E712
            mlp_cls = DistributedMLP if (comm.get_size("matmul") > 1) else MLP
            mlp_hidden_dim = int(embed_dim * mlp_ratio)
            self.mlp = mlp_cls(
                in_features=embed_dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                drop_rate=drop_rate,
                comm_inp_name=comm_feature_inp_name,
                comm_hidden_name=comm_feature_hidden_name,
                checkpointing=checkpointing,
            )

        if outer_skip == "linear":
            self.outer_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
        elif outer_skip == "identity":
            self.outer_skip = nn.Identity()

    def forward(self, x):  # pragma: no cover
        """Run the block on ``x`` and return the transformed tensor.

        Normalization is applied only to the leading
        ``input_shape_loc`` / ``output_shape_loc`` region of the two trailing
        dimensions; everything outside that region is set to zero.
        """
        in_h, in_w = self.input_shape_loc
        out_h, out_w = self.output_shape_loc

        x_norm = torch.zeros_like(x)
        x_norm[..., :in_h, :in_w] = self.norm0(x[..., :in_h, :in_w])

        # the filter returns both its output and the tensor used for the skips
        x, residual = self.filter(x_norm)

        if hasattr(self, "inner_skip"):
            x = x + self.inner_skip(residual)

        if hasattr(self, "act_layer"):
            x = self.act_layer(x)

        x_norm = torch.zeros_like(x)
        x_norm[..., :out_h, :out_w] = self.norm1(x[..., :out_h, :out_w])
        x = x_norm

        if hasattr(self, "mlp"):
            x = self.mlp(x)

        x = self.drop_path(x)

        if hasattr(self, "outer_skip"):
            x = x + self.outer_skip(residual)

        return x


class SphericalFourierNeuralOperatorNet(Module):
    """
    Spherical Fourier Neural Operator Network

    Parameters
    ----------
    params : dict
        Dictionary of parameters. Overrides are looked up as *attributes* of
        this object (``params.embed_dim`` etc.), so a plain ``dict`` never
        overrides a keyword argument. ``drop_rate``, ``drop_path_rate``,
        ``sparsity_threshold`` and ``use_complex_kernels`` are always taken
        from the keyword arguments.
    spectral_transform : str, optional
        Type of spectral transformation to use ('sht', 'fft'), by default "sht"
    grid : str, optional
        Type of grid to use, by default "legendre-gauss"
    filter_type : str, optional
        Type of filter to use ('linear', 'non-linear'), by default "non-linear"
    operator_type : str, optional
        Type of operator to use ('diagonal', 'dhconv'), by default "diagonal"
    inp_shape : tuple, optional
        Shape of the input channels, by default (721, 1440)
    scale_factor : int, optional
        Scale factor to use, by default 16
    in_chans : int, optional
        Number of input channels, by default 2
    out_chans : int, optional
        Number of output channels, by default 2
    embed_dim : int, optional
        Dimension of the embeddings, by default 256
    num_layers : int, optional
        Number of layers in the network, by default 12
    repeat_layers : int, optional
        Number of times to repeat the layers, by default 1
    use_mlp : bool, optional
        Whether to use MLP, by default True
    mlp_ratio : float, optional
        Ratio of MLP to use, by default 2.0
    activation_function : str, optional
        Activation function to use, by default "gelu"
    encoder_layers : int, optional
        Number of layers in the encoder, by default 1
    pos_embed : str, optional
        Type of positional embedding to use ('direct', 'frequency', 'none'),
        by default "direct"
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path_rate : float, optional
        Dropout path rate, by default 0.0
    sparsity_threshold : float, optional
        Threshold for sparsity, by default 0.0
    normalization_layer : str, optional
        Type of normalization layer to use ("layer_norm", "instance_norm", "none"),
        by default "instance_norm"
    max_modes : Any, optional
        Maximum modes to use, by default None
    hard_thresholding_fraction : float, optional
        Fraction of hard thresholding to apply, by default 1.0
    use_complex_kernels : bool, optional
        Whether to use complex kernels, by default True
    big_skip : bool, optional
        Whether to use big skip connections, by default True
    rank : float, optional
        Rank of the approximation, by default 1.0
    factorization : Any, optional
        Type of factorization to use, by default None
    separable : bool, optional
        Whether to use separable convolutions, by default False
    complex_network : bool, optional
        Whether to use a complex network architecture, by default True
    complex_activation : str, optional
        Type of complex activation function to use, by default "real"
    spectral_layers : int, optional
        Number of spectral layers, by default 3
    output_transform : bool, optional
        Whether to use an output transform, by default False
    checkpointing : int, optional
        Number of checkpointing segments, by default 0

    Raises
    ------
    ValueError
        If ``spectral_transform`` or ``pos_embed`` has an unknown value.
    NotImplementedError
        If ``normalization_layer`` has an unknown value.

    Example:
    --------
    >>> import torch
    >>> from modulus.models.sfno.sfnonet import SphericalFourierNeuralOperatorNet as SFNO
    >>> model = SFNO(
    ...     params={},
    ...     inp_shape=(8, 16),
    ...     scale_factor=4,
    ...     in_chans=2,
    ...     out_chans=2,
    ...     embed_dim=16,
    ...     num_layers=2,
    ...     encoder_layers=1,
    ...     spectral_layers=2,
    ...     use_mlp=True,)
    >>> model(torch.randn(1, 2, 8, 16)).shape
    torch.Size([1, 2, 8, 16])
    """

    def __init__(
        self,
        params: dict,
        spectral_transform: str = "sht",
        grid: str = "legendre-gauss",
        filter_type: str = "non-linear",
        operator_type: str = "diagonal",
        inp_shape: Tuple[int, int] = (721, 1440),
        scale_factor: int = 16,
        in_chans: int = 2,
        out_chans: int = 2,
        embed_dim: int = 256,
        num_layers: int = 12,
        repeat_layers: int = 1,
        use_mlp: bool = True,
        mlp_ratio: float = 2.0,
        activation_function: str = "gelu",
        encoder_layers: int = 1,
        pos_embed: str = "direct",
        drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        sparsity_threshold: float = 0.0,
        normalization_layer: str = "instance_norm",
        max_modes: Any = None,
        hard_thresholding_fraction: float = 1.0,
        use_complex_kernels: bool = True,
        big_skip: bool = True,
        rank: float = 1.0,
        factorization: Any = None,
        separable: bool = False,
        complex_network: bool = True,
        complex_activation: str = "real",
        spectral_layers: int = 3,
        output_transform: bool = False,
        checkpointing: int = 0,
    ):  # pragma: no cover
        super(SphericalFourierNeuralOperatorNet, self).__init__(meta=MetaData())

        # ------------------------------------------------------------------
        # resolve the configuration: an attribute on `params` wins over the
        # corresponding keyword argument
        # ------------------------------------------------------------------
        self.params = params
        self.spectral_transform = getattr(
            params, "spectral_transform", spectral_transform
        )
        self.grid = getattr(params, "grid", grid)
        self.filter_type = getattr(params, "filter_type", filter_type)
        self.operator_type = getattr(params, "operator_type", operator_type)
        self.inp_shape = (
            (params.img_shape_x, params.img_shape_y)
            if hasattr(params, "img_shape_x") and hasattr(params, "img_shape_y")
            else inp_shape
        )
        self.out_shape = (
            (params.out_shape_x, params.out_shape_y)
            if hasattr(params, "out_shape_x") and hasattr(params, "out_shape_y")
            else self.inp_shape
        )
        self.scale_factor = getattr(params, "scale_factor", scale_factor)
        self.in_chans = getattr(params, "N_in_channels", in_chans)
        self.out_chans = getattr(params, "N_out_channels", out_chans)
        self.embed_dim = self.num_features = getattr(params, "embed_dim", embed_dim)
        self.num_layers = getattr(params, "num_layers", num_layers)
        self.repeat_layers = getattr(params, "repeat_layers", repeat_layers)
        self.max_modes = (
            (params.lmax, params.mmax)
            if hasattr(params, "lmax") and hasattr(params, "mmax")
            else max_modes
        )
        self.hard_thresholding_fraction = getattr(
            params, "hard_thresholding_fraction", hard_thresholding_fraction
        )
        self.normalization_layer = getattr(
            params, "normalization_layer", normalization_layer
        )
        self.use_mlp = getattr(params, "use_mlp", use_mlp)
        self.mlp_ratio = getattr(params, "mlp_ratio", mlp_ratio)
        self.activation_function = getattr(
            params, "activation_function", activation_function
        )
        self.encoder_layers = getattr(params, "encoder_layers", encoder_layers)
        self.pos_embed = getattr(params, "pos_embed", pos_embed)
        self.big_skip = getattr(params, "big_skip", big_skip)
        self.rank = getattr(params, "rank", rank)
        self.factorization = getattr(params, "factorization", factorization)
        self.separable = getattr(params, "separable", separable)
        self.complex_network = getattr(params, "complex_network", complex_network)
        self.complex_activation = getattr(
            params, "complex_activation", complex_activation
        )
        self.spectral_layers = getattr(params, "spectral_layers", spectral_layers)
        self.output_transform = getattr(params, "output_transform", output_transform)
        self.checkpointing = getattr(params, "checkpointing", checkpointing)

        # compute the downscaled image size
        self.h = int(self.inp_shape[0] // self.scale_factor)
        self.w = int(self.inp_shape[1] // self.scale_factor)

        # Compute the maximum frequencies in h and in w
        if self.max_modes is not None:
            modes_lat, modes_lon = self.max_modes
        else:
            modes_lat = int(self.h * self.hard_thresholding_fraction)
            modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)

        # sizes of the spatial process groups
        size_h = comm.get_size("h")
        size_w = comm.get_size("w")
        spatial_parallel = (size_h > 1) or (size_w > 1)

        # ------------------------------------------------------------------
        # prepare the spectral transforms
        # ------------------------------------------------------------------
        if self.spectral_transform == "sht":
            sht_handle = th.RealSHT
            isht_handle = th.InverseRealSHT

            # parallelism
            if spatial_parallel:
                polar_group = None if (size_h == 1) else comm.get_group("h")
                azimuth_group = None if (size_w == 1) else comm.get_group("w")
                thd.init(polar_group, azimuth_group)
                sht_handle = thd.DistributedRealSHT
                isht_handle = thd.DistributedInverseRealSHT

            # set up
            self.trans_down = sht_handle(
                *self.inp_shape, lmax=modes_lat, mmax=modes_lon, grid="equiangular"
            ).float()
            self.itrans_up = isht_handle(
                *self.out_shape, lmax=modes_lat, mmax=modes_lon, grid="equiangular"
            ).float()
            self.trans = sht_handle(
                self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=self.grid
            ).float()
            self.itrans = isht_handle(
                self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=self.grid
            ).float()

            # The layer-norm setup below reads `out_shape_loc`, which was
            # previously only defined on the FFT path. Derive it from the
            # output transform, mirroring FourierNeuralOperatorBlock.
            if spatial_parallel:
                self.out_shape_loc = (
                    self.itrans_up.nlat_local,
                    self.itrans_up.nlon_local,
                )
            else:
                self.out_shape_loc = (self.itrans_up.nlat, self.itrans_up.nlon)

        elif self.spectral_transform == "fft":
            # Use the FFT wrappers imported from modulus.models.sfno.layers
            # (the same classes SpectralFilterLayer checks against); they are
            # not taken from the torch_harmonics namespace.
            fft_handle = RealFFT2
            ifft_handle = InverseRealFFT2

            # determine the global padding (round each dim up to a multiple
            # of the corresponding process-group size)
            inp_dist_h = (self.inp_shape[0] + size_h - 1) // size_h
            inp_dist_w = (self.inp_shape[1] + size_w - 1) // size_w
            self.inp_padding = (
                inp_dist_h * size_h - self.inp_shape[0],
                inp_dist_w * size_w - self.inp_shape[1],
            )
            out_dist_h = (self.out_shape[0] + size_h - 1) // size_h
            out_dist_w = (self.out_shape[1] + size_w - 1) // size_w
            self.out_padding = (
                out_dist_h * size_h - self.out_shape[0],
                out_dist_w * size_w - self.out_shape[1],
            )

            # effective image size:
            self.inp_shape_eff = [
                self.inp_shape[0] + self.inp_padding[0],
                self.inp_shape[1] + self.inp_padding[1],
            ]
            self.inp_shape_loc = [
                self.inp_shape_eff[0] // size_h,
                self.inp_shape_eff[1] // size_w,
            ]
            self.out_shape_eff = [
                self.out_shape[0] + self.out_padding[0],
                self.out_shape[1] + self.out_padding[1],
            ]
            self.out_shape_loc = [
                self.out_shape_eff[0] // size_h,
                self.out_shape_eff[1] // size_w,
            ]

            if spatial_parallel:
                fft_handle = DistributedRealFFT2
                ifft_handle = DistributedInverseRealFFT2

            self.trans_down = fft_handle(
                *self.inp_shape_eff, lmax=modes_lat, mmax=modes_lon
            ).float()
            self.itrans_up = ifft_handle(
                *self.out_shape_eff, lmax=modes_lat, mmax=modes_lon
            ).float()
            self.trans = fft_handle(
                self.h, self.w, lmax=modes_lat, mmax=modes_lon
            ).float()
            self.itrans = ifft_handle(
                self.h, self.w, lmax=modes_lat, mmax=modes_lon
            ).float()
        else:
            raise ValueError("Unknown spectral transform")

        # use the SHT/FFT to compute the local, downscaled grid dimensions
        if spatial_parallel:
            self.inp_shape_loc = (
                self.trans_down.nlat_local,
                self.trans_down.nlon_local,
            )
            self.inp_shape_eff = [
                self.trans_down.nlat_local + self.trans_down.nlatpad_local,
                self.trans_down.nlon_local + self.trans_down.nlonpad_local,
            ]
            self.h_loc = self.itrans.nlat_local
            self.w_loc = self.itrans.nlon_local
        else:
            self.inp_shape_loc = (self.trans_down.nlat, self.trans_down.nlon)
            self.inp_shape_eff = [self.trans_down.nlat, self.trans_down.nlon]
            self.h_loc = self.itrans.nlat
            self.w_loc = self.itrans.nlon

        # ------------------------------------------------------------------
        # encoder
        # ------------------------------------------------------------------
        if comm.get_size("matmul") > 1:
            self.encoder = DistributedEncoderDecoder(
                num_layers=self.encoder_layers,
                input_dim=self.in_chans,
                output_dim=self.embed_dim,
                hidden_dim=int(1 * self.embed_dim),
                act=self.activation_function,
                comm_inp_name="fin",
                comm_out_name="fout",
            )
            fblock_mlp_inp_name = self.encoder.comm_out_name
            fblock_mlp_hidden_name = (
                "fout" if (self.encoder.comm_out_name == "fin") else "fin"
            )
        else:
            self.encoder = EncoderDecoder(
                num_layers=self.encoder_layers,
                input_dim=self.in_chans,
                output_dim=self.embed_dim,
                hidden_dim=int(1 * self.embed_dim),
                act=self.activation_function,
            )
            fblock_mlp_inp_name = "fin"
            fblock_mlp_hidden_name = "fout"

        # dropout
        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
        # per-layer drop-path rate, increasing linearly from 0 to drop_path_rate
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]

        # ------------------------------------------------------------------
        # pick norm layer
        # ------------------------------------------------------------------
        if self.normalization_layer == "layer_norm":
            norm_layer_inp = partial(
                nn.LayerNorm,
                normalized_shape=(self.inp_shape_loc[0], self.inp_shape_loc[1]),
                eps=1e-6,
            )
            norm_layer_mid = partial(
                nn.LayerNorm, normalized_shape=(self.h_loc, self.w_loc), eps=1e-6
            )
            norm_layer_out = partial(
                nn.LayerNorm,
                normalized_shape=(self.out_shape_loc[0], self.out_shape_loc[1]),
                eps=1e-6,
            )
        elif self.normalization_layer == "instance_norm":
            if comm.get_size("spatial") > 1:
                norm_layer_inp = partial(
                    DistributedInstanceNorm2d,
                    num_features=self.embed_dim,
                    eps=1e-6,
                    affine=True,
                )
            else:
                norm_layer_inp = partial(
                    nn.InstanceNorm2d,
                    num_features=self.embed_dim,
                    eps=1e-6,
                    affine=True,
                    track_running_stats=False,
                )
            norm_layer_out = norm_layer_mid = norm_layer_inp
        elif self.normalization_layer == "none":
            norm_layer_out = norm_layer_mid = norm_layer_inp = nn.Identity
        else:
            raise NotImplementedError(
                f"Error, normalization {self.normalization_layer} not implemented."
            )

        # ------------------------------------------------------------------
        # FNO blocks
        # ------------------------------------------------------------------
        self.blocks = nn.ModuleList([])
        for i in range(self.num_layers):
            first_layer = i == 0
            last_layer = i == self.num_layers - 1

            # the first block consumes the full-resolution input, the last
            # block produces the full-resolution output
            forward_transform = self.trans_down if first_layer else self.trans
            inverse_transform = self.itrans_up if last_layer else self.itrans

            norm_layer = (
                norm_layer_inp if first_layer else norm_layer_mid,
                norm_layer_out if last_layer else norm_layer_mid,
            )

            block = FourierNeuralOperatorBlock(
                forward_transform,
                inverse_transform,
                self.embed_dim,
                filter_type=self.filter_type,
                operator_type=self.operator_type,
                mlp_ratio=self.mlp_ratio,
                drop_rate=drop_rate,
                drop_path=dpr[i],
                act_layer=self.activation_function,
                norm_layer=norm_layer,
                sparsity_threshold=sparsity_threshold,
                use_complex_kernels=use_complex_kernels,
                inner_skip="linear",
                outer_skip="identity",
                use_mlp=self.use_mlp,
                comm_feature_inp_name=fblock_mlp_inp_name,
                comm_feature_hidden_name=fblock_mlp_hidden_name,
                rank=self.rank,
                factorization=self.factorization,
                separable=self.separable,
                complex_network=self.complex_network,
                complex_activation=self.complex_activation,
                spectral_layers=self.spectral_layers,
                checkpointing=self.checkpointing,
            )
            self.blocks.append(block)

        # ------------------------------------------------------------------
        # decoder
        # ------------------------------------------------------------------
        if comm.get_size("matmul") > 1:
            # NOTE(review): unlike the non-distributed decoder below, the
            # input width here does not include the big-skip channels that
            # forward() concatenates -- confirm this is intended.
            self.decoder = DistributedEncoderDecoder(
                num_layers=self.encoder_layers,
                input_dim=self.embed_dim,
                output_dim=self.out_chans,
                hidden_dim=int(1 * self.embed_dim),
                act=self.activation_function,
                comm_inp_name=fblock_mlp_inp_name,
                comm_out_name=fblock_mlp_hidden_name,
            )
        else:
            self.decoder = EncoderDecoder(
                num_layers=self.encoder_layers,
                # big skip concatenates `out_chans` residual channels
                input_dim=self.embed_dim + self.big_skip * self.out_chans,
                output_dim=self.out_chans,
                hidden_dim=int(1 * self.embed_dim),
                act=self.activation_function,
            )

        # output transform
        if self.big_skip:
            self.residual_transform = nn.Conv2d(
                self.out_chans, self.out_chans, 1, bias=False
            )

        # ------------------------------------------------------------------
        # learned position embedding; `self.pos_embed` holds the requested
        # type string up to here and is replaced by the parameter(s) below
        # ------------------------------------------------------------------
        if self.pos_embed == "direct":
            # currently using deliberately a differently shape position embedding
            self.pos_embed = nn.Parameter(
                torch.zeros(
                    1, self.embed_dim, self.inp_shape_loc[0], self.inp_shape_loc[1]
                )
            )
            self.pos_embed.is_shared_mp = ["matmul"]
            self.pos_embed.sharded_dims_mp = [None, None, "h", "w"]
            self.pos_embed.type = "direct"
            trunc_normal_(self.pos_embed, std=0.02)
        elif self.pos_embed == "frequency":
            if spatial_parallel:
                lmax_loc = self.itrans_up.lmax_local
                mmax_loc = self.itrans_up.mmax_local
            else:
                lmax_loc = self.itrans_up.lmax
                mmax_loc = self.itrans_up.mmax

            # real and imaginary coefficients are stored separately and
            # combined into a complex tensor in forward()
            rcoeffs = nn.Parameter(
                torch.tril(
                    torch.randn(1, self.embed_dim, lmax_loc, mmax_loc), diagonal=0
                )
            )
            ccoeffs = nn.Parameter(
                torch.tril(
                    torch.randn(1, self.embed_dim, lmax_loc, mmax_loc - 1),
                    diagonal=-1,
                )
            )
            trunc_normal_(rcoeffs, std=0.02)
            trunc_normal_(ccoeffs, std=0.02)
            self.pos_embed = nn.ParameterList([rcoeffs, ccoeffs])
            self.pos_embed.type = "frequency"
        elif self.pos_embed == "none" or self.pos_embed == "None":
            # forward() tests for the attribute, so remove it entirely
            delattr(self, "pos_embed")
        else:
            raise ValueError("Unknown position embedding type")

        if self.output_transform:
            # output indices whose channel name starts with "r" are squashed
            # with a sigmoid in forward()
            minmax_channels = [
                o
                for o, c in enumerate(params.out_channels)
                if params.channel_names[c][0] == "r"
            ]
            self.register_buffer(
                "minmax_channels",
                torch.tensor(minmax_channels, dtype=torch.long),
                persistent=False,
            )

        self.apply(self._init_weights)

    def _init_weights(self, m):  # pragma: no cover
        """Helper routine for weight initialization"""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, FusedLayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):  # pragma: no cover
        """Helper"""
        return {"pos_embed", "cls_token"}

    def _forward_features(self, x):  # pragma: no cover
        """Run all blocks ``repeat_layers`` times, optionally checkpointed."""
        for _ in range(self.repeat_layers):
            for blk in self.blocks:
                if self.checkpointing >= 3:
                    x = checkpoint(blk, x)
                else:
                    x = blk(x)
        return x

    def forward(self, x):  # pragma: no cover
        """Encode ``x``, run the FNO blocks, decode and return the result."""
        if comm.get_size("fin") > 1:
            # NOTE(review): this helper was referenced without being imported;
            # the module path below is assumed from the package layout --
            # confirm before merging.
            from modulus.utils.sfno.distributed.mappings import (
                scatter_to_parallel_region,
            )

            x = scatter_to_parallel_region(x, "fin", 1)

        # save big skip
        if self.big_skip:
            # if output shape differs, use the spectral transforms to change resolution
            if self.out_shape != self.inp_shape:
                xtype = x.dtype
                # only take the predicted channels as residual
                residual = x[..., : self.out_chans, :, :].to(torch.float32)
                # the transforms are run in float32 with autocast disabled
                with amp.autocast(enabled=False):
                    residual = self.trans_down(residual)
                    residual = residual.contiguous()
                    residual = self.itrans_up(residual)
                    residual = residual.to(dtype=xtype)
            else:
                # only take the predicted channels
                residual = x[..., : self.out_chans, :, :]

        if self.checkpointing >= 1:
            x = checkpoint(self.encoder, x)
        else:
            x = self.encoder(x)

        if hasattr(self, "pos_embed"):
            if self.pos_embed.type == "frequency":
                # stack real / imaginary coefficients into a trailing axis of
                # size 2 (imaginary part left-padded by one column) and
                # transform the complex coefficients back to the grid
                pos_embed = torch.stack(
                    [
                        self.pos_embed[0],
                        nn.functional.pad(self.pos_embed[1], (1, 0), "constant", 0),
                    ],
                    dim=-1,
                )
                with amp.autocast(enabled=False):
                    pos_embed = self.itrans_up(torch.view_as_complex(pos_embed))
            else:
                pos_embed = self.pos_embed

            # old way of treating unequally shaped weights
            # `inp_shape_loc` is a tuple and `inp_shape_eff` a list, so they
            # must be normalized before comparing -- a raw `!=` between a
            # tuple and a list is always True.
            if self.pos_embed.type == "direct" and tuple(self.inp_shape_loc) != tuple(
                self.inp_shape_eff
            ):
                loc_h, loc_w = self.inp_shape_loc
                xp = torch.zeros_like(x)
                xp[..., :loc_h, :loc_w] = x[..., :loc_h, :loc_w] + pos_embed
                x = xp
            else:
                x = x + pos_embed

        # maybe clean the padding just in case
        x = self.pos_drop(x)

        # do the feature extraction
        x = self._forward_features(x)

        if self.big_skip:
            x = torch.cat((x, residual), dim=1)

        if self.checkpointing >= 1:
            x = checkpoint(self.decoder, x)
        else:
            x = self.decoder(x)

        if hasattr(self.decoder, "comm_out_name") and (
            comm.get_size(self.decoder.comm_out_name) > 1
        ):
            # NOTE(review): same assumption about the module path as for
            # scatter_to_parallel_region above.
            from modulus.utils.sfno.distributed.mappings import (
                gather_from_parallel_region,
            )

            x = gather_from_parallel_region(x, self.decoder.comm_out_name, 1)

        if self.big_skip:
            x = x + self.residual_transform(residual)

        if self.output_transform:
            x[:, self.minmax_channels] = torch.sigmoid(x[:, self.minmax_channels])

        return x