deeplearning/modulus/modulus-core-v030/_modules/modulus/models/sfno/sfnonet.html

Source code for modulus.models.sfno.sfnonet

# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torch.cuda import amp
from dataclasses import dataclass
from typing import Any, Tuple

# import contractions
from modulus.models.sfno.factorizations import get_contract_fun, _contract_dense

# helpers
from modulus.models.sfno.layers import (
    trunc_normal_,
    DropPath,
    MLP,
    EncoderDecoder,
)

# import global convolution and non-linear spectral layers
from modulus.models.sfno.layers import (
    SpectralConv2d,
    SpectralAttention2d,
    SpectralAttentionS2,
)

from modulus.models.sfno.s2convolutions import SpectralConvS2

# get spectral transforms from torch_harmonics
import torch_harmonics as th
import torch_harmonics.distributed as thd

# wrap fft, to unify interface to spectral transforms
from modulus.models.sfno.layers import RealFFT2, InverseRealFFT2
from modulus.utils.sfno.distributed.layers import (
    DistributedRealFFT2,
    DistributedInverseRealFFT2,
    DistributedMLP,
    DistributedEncoderDecoder,
)

# more distributed stuff
from modulus.utils.sfno.distributed import comm

# layer normalization
from apex.normalization import FusedLayerNorm
from modulus.utils.sfno.distributed.layer_norm import DistributedInstanceNorm2d

from modulus.models.module import Module
from modulus.models.meta import ModelMetaData
from modulus.models.layers import get_activation


# Model metadata record for SFNO: a dataclass extending ModelMetaData whose
# fields are the model name plus boolean flags grouped (by the section
# comments below) into optimization, inference and physics-informed settings.
# An instance is passed as ``meta=`` to the base ``Module`` constructor by
# SphericalFourierNeuralOperatorNet.
@dataclass
class MetaData(ModelMetaData):
    name: str = "SFNO"
    # Optimization
    jit: bool = False
    cuda_graphs: bool = True
    amp_cpu: bool = True
    amp_gpu: bool = True
    torch_fx: bool = False
    # Inference
    onnx: bool = False
    # Physics informed
    func_torch: bool = False
    auto_grad: bool = False
class SpectralFilterLayer(nn.Module):
    """Spectral filter layer.

    Selects and wraps the concrete spectral filter module based on
    ``filter_type`` and on the class of ``forward_transform``:

    * ``"non-linear"`` with a (distributed) real SHT -> ``SpectralAttentionS2``
    * ``"non-linear"`` with a (distributed) real FFT -> ``SpectralAttention2d``
    * ``"linear"`` (any transform)                   -> ``SpectralConvS2``

    Parameters
    ----------
    forward_transform : nn.Module
        Forward spectral transform handed to the filter.
    inverse_transform : nn.Module
        Inverse spectral transform handed to the filter.
    embed_dim : int
        Embedding dimension; used as both input and output channel count of
        the linear filter.
    filter_type : str, optional
        ``"linear"`` or ``"non-linear"``, by default "linear"
    operator_type : str, optional
        Forwarded to ``SpectralConvS2``, by default "diagonal"
    sparsity_threshold : float, optional
        Forwarded to the non-linear filters, by default 0.0
    use_complex_kernels : bool, optional
        Accepted for interface compatibility; not read by this constructor.
    hidden_size_factor : int, optional
        Forwarded to the non-linear filters, by default 1
    rank : float, optional
        Forwarded to ``SpectralConvS2``, by default 1.0
    factorization : Any, optional
        Forwarded to ``SpectralConvS2``; tensorly is enabled iff this is not
        ``None``. By default None
    separable : bool, optional
        Forwarded to ``SpectralConvS2``, by default False
    complex_network : bool, optional
        Accepted for interface compatibility; not read by this constructor.
    complex_activation : str, optional
        Forwarded to the non-linear filters, by default "real"
    spectral_layers : int, optional
        Forwarded to the non-linear filters, by default 1
    drop_rate : float, optional
        Forwarded to the non-linear filters, by default 0.0

    Raises
    ------
    NotImplementedError
        If no filter exists for the given ``filter_type`` / transform pair.
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        embed_dim,
        filter_type="linear",
        operator_type="diagonal",
        sparsity_threshold=0.0,
        use_complex_kernels=True,
        hidden_size_factor=1,
        rank=1.0,
        factorization=None,
        separable=False,
        complex_network=True,
        complex_activation="real",
        spectral_layers=1,
        drop_rate=0.0,
    ):  # pragma: no cover
        super().__init__()

        # keyword arguments shared by both non-linear (attention) filters
        attention_kwargs = dict(
            sparsity_threshold=sparsity_threshold,
            hidden_size_factor=hidden_size_factor,
            complex_activation=complex_activation,
            spectral_layers=spectral_layers,
            drop_rate=drop_rate,
            bias=False,
        )

        # NOTE: filter_type is tested first so that the transform classes are
        # only consulted for the non-linear filters.
        if filter_type == "non-linear" and isinstance(
            forward_transform, (th.RealSHT, thd.DistributedRealSHT)
        ):
            self.filter = SpectralAttentionS2(
                forward_transform,
                inverse_transform,
                embed_dim,
                **attention_kwargs,
            )
        elif filter_type == "non-linear" and isinstance(
            forward_transform, (RealFFT2, DistributedRealFFT2)
        ):
            self.filter = SpectralAttention2d(
                forward_transform,
                inverse_transform,
                embed_dim,
                **attention_kwargs,
            )
        # spectral transform is passed to the module
        elif filter_type == "linear":
            self.filter = SpectralConvS2(
                forward_transform,
                inverse_transform,
                embed_dim,
                embed_dim,
                operator_type=operator_type,
                rank=rank,
                factorization=factorization,
                separable=separable,
                bias=False,
                use_tensorly=factorization is not None,
            )
        else:
            raise NotImplementedError(
                f"No spectral filter implemented for filter_type={filter_type!r} "
                f"with forward transform {type(forward_transform).__name__}"
            )

    def forward(self, x):  # pragma: no cover
        """Apply the wrapped filter and return its result unchanged."""
        return self.filter(x)
class FourierNeuralOperatorBlock(nn.Module):
    """Fourier Neural Operator Block.

    ``forward`` runs, in order: ``norm0`` on the valid input region, the
    spectral filter (which returns ``(output, residual)``), an optional inner
    skip (``+ inner_skip(residual)``), an optional activation, ``norm1`` on
    the valid output region, an optional MLP, drop-path, and an optional
    outer skip (``+ outer_skip(residual)``).

    Parameters
    ----------
    forward_transform, inverse_transform : nn.Module
        Spectral transforms; their ``nlat``/``nlon`` (or ``nlat_local``/
        ``nlon_local`` when the "h" or "w" communicator has more than one
        rank) define the valid input/output regions.
    embed_dim : int
        Number of channels.
    filter_type, operator_type : str, optional
        Forwarded to ``SpectralFilterLayer``. The activation is only created
        when ``filter_type`` is "linear" or "real linear".
    mlp_ratio : float, optional
        Hidden-size factor for the spectral filter and the MLP, by default 2.0
    drop_rate : float, optional
        Dropout rate forwarded to the filter and the MLP, by default 0.0
    drop_path : float, optional
        Drop-path rate; ``nn.Identity`` is used when it is not > 0.
    act_layer : str, optional
        Activation name resolved with ``get_activation``, by default "gelu"
    norm_layer : tuple, optional
        Pair of zero-argument factories for ``norm0`` and ``norm1``.
    sparsity_threshold, use_complex_kernels, rank, factorization, separable,
    complex_network, complex_activation, spectral_layers
        Forwarded to ``SpectralFilterLayer``.
    inner_skip : str, optional
        "linear" (1x1 conv) or "identity"; any other value disables the
        inner skip. By default "linear"
    outer_skip : str, optional
        "linear" (1x1 conv) or "identity"; any other value (including
        ``None``) disables the outer skip. By default None
    use_mlp : bool, optional
        Whether to append an MLP, by default False
    comm_feature_inp_name, comm_feature_hidden_name : str, optional
        Communicator names forwarded to the MLP.
    checkpointing : int, optional
        Forwarded to the MLP, by default 0
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        embed_dim,
        filter_type="linear",
        operator_type="diagonal",
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path=0.0,
        act_layer="gelu",
        norm_layer=(nn.LayerNorm, nn.LayerNorm),
        sparsity_threshold=0.0,
        use_complex_kernels=True,
        rank=1.0,
        factorization=None,
        separable=False,
        inner_skip="linear",
        outer_skip=None,  # None, nn.linear or nn.Identity
        use_mlp=False,
        comm_feature_inp_name=None,
        comm_feature_hidden_name=None,
        complex_network=True,
        complex_activation="real",
        spectral_layers=1,
        checkpointing=0,
    ):  # pragma: no cover
        super().__init__()

        # valid (un-padded) region handled by this rank
        if (comm.get_size("h") > 1) or (comm.get_size("w") > 1):
            self.input_shape_loc = (
                forward_transform.nlat_local,
                forward_transform.nlon_local,
            )
            self.output_shape_loc = (
                inverse_transform.nlat_local,
                inverse_transform.nlon_local,
            )
        else:
            self.input_shape_loc = (forward_transform.nlat, forward_transform.nlon)
            self.output_shape_loc = (inverse_transform.nlat, inverse_transform.nlon)

        # NOTE: sub-modules are registered in the same order as before so that
        # parameter ordering (and hence state_dict layout) is unchanged.

        # norm layer
        self.norm0 = norm_layer[0]()

        # convolution layer
        self.filter = SpectralFilterLayer(
            forward_transform,
            inverse_transform,
            embed_dim,
            filter_type,
            operator_type,
            sparsity_threshold,
            use_complex_kernels=use_complex_kernels,
            hidden_size_factor=mlp_ratio,
            rank=rank,
            factorization=factorization,
            separable=separable,
            complex_network=complex_network,
            complex_activation=complex_activation,
            spectral_layers=spectral_layers,
            drop_rate=drop_rate,
        )

        # the skip attributes are only created when requested; forward() uses
        # hasattr() to detect them
        if inner_skip == "linear":
            self.inner_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
        elif inner_skip == "identity":
            self.inner_skip = nn.Identity()

        if filter_type in ("linear", "real linear"):
            self.act_layer = get_activation(act_layer)

        # dropout
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # norm layer
        self.norm1 = norm_layer[1]()

        if use_mlp:
            mlp_cls = DistributedMLP if (comm.get_size("matmul") > 1) else MLP
            mlp_hidden_dim = int(embed_dim * mlp_ratio)
            self.mlp = mlp_cls(
                in_features=embed_dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                drop_rate=drop_rate,
                comm_inp_name=comm_feature_inp_name,
                comm_hidden_name=comm_feature_hidden_name,
                checkpointing=checkpointing,
            )

        if outer_skip == "linear":
            self.outer_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
        elif outer_skip == "identity":
            self.outer_skip = nn.Identity()

    @staticmethod
    def _norm_valid_region(norm, x, shape_loc):  # pragma: no cover
        """Normalize ``x[..., :shape_loc[0], :shape_loc[1]]``; zero elsewhere."""
        rows, cols = shape_loc[0], shape_loc[1]
        out = torch.zeros_like(x)
        out[..., :rows, :cols] = norm(x[..., :rows, :cols])
        return out

    def forward(self, x):  # pragma: no cover
        """Apply the block to ``x`` and return the result."""
        x_norm = self._norm_valid_region(self.norm0, x, self.input_shape_loc)

        x, residual = self.filter(x_norm)

        if hasattr(self, "inner_skip"):
            x = x + self.inner_skip(residual)

        if hasattr(self, "act_layer"):
            x = self.act_layer(x)

        x = self._norm_valid_region(self.norm1, x, self.output_shape_loc)

        if hasattr(self, "mlp"):
            x = self.mlp(x)

        x = self.drop_path(x)

        if hasattr(self, "outer_skip"):
            x = x + self.outer_skip(residual)

        return x
class SphericalFourierNeuralOperatorNet(Module):
    """
    Spherical Fourier Neural Operator Network

    Parameters
    ----------
    params : dict
        Dictionary of parameters. Overrides are looked up as *attributes* of
        this object (``getattr``); an attribute that is present takes
        precedence over the matching keyword argument. A plain ``dict``
        therefore contributes no overrides.
    spectral_transform : str, optional
        Type of spectral transformation to use ("sht" or "fft"), by default "sht"
    grid : str, optional
        Type of grid to use, by default "legendre-gauss"
    filter_type : str, optional
        Type of filter to use ('linear', 'non-linear'), by default "non-linear"
    operator_type : str, optional
        Type of operator to use ('diagonal', 'dhconv'), by default "diagonal"
    inp_shape : tuple, optional
        Shape of the input channels, by default (721, 1440)
    scale_factor : int, optional
        Scale factor to use, by default 16
    in_chans : int, optional
        Number of input channels, by default 2
    out_chans : int, optional
        Number of output channels, by default 2
    embed_dim : int, optional
        Dimension of the embeddings, by default 256
    num_layers : int, optional
        Number of layers in the network, by default 12
    repeat_layers : int, optional
        Number of times to repeat the layers, by default 1
    use_mlp : int, optional
        Whether to use MLP, by default True
    mlp_ratio : int, optional
        Ratio of MLP to use, by default 2.0
    activation_function : str, optional
        Activation function to use, by default "gelu"
    encoder_layers : int, optional
        Number of layers in the encoder, by default 1
    pos_embed : str, optional
        Type of positional embedding to use ("direct", "frequency", "none"),
        by default "direct"
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path_rate : float, optional
        Dropout path rate, by default 0.0
    sparsity_threshold : float, optional
        Threshold for sparsity, by default 0.0
    normalization_layer : str, optional
        Type of normalization layer to use ("layer_norm", "instance_norm", "none"),
        by default "instance_norm"
    max_modes : Any, optional
        Maximum modes to use, by default None
    hard_thresholding_fraction : float, optional
        Fraction of hard thresholding to apply, by default 1.0
    use_complex_kernels : bool, optional
        Whether to use complex kernels, by default True
    big_skip : bool, optional
        Whether to use big skip connections, by default True
    rank : float, optional
        Rank of the approximation, by default 1.0
    factorization : Any, optional
        Type of factorization to use, by default None
    separable : bool, optional
        Whether to use separable convolutions, by default False
    complex_network : bool, optional
        Whether to use a complex network architecture, by default True
    complex_activation : str, optional
        Type of complex activation function to use, by default "real"
    spectral_layers : int, optional
        Number of spectral layers, by default 3
    output_transform : bool, optional
        Whether to use an output transform, by default False
    checkpointing : int, optional
        Number of checkpointing segments, by default 0

    Raises
    ------
    ValueError
        For an unknown ``spectral_transform`` or ``pos_embed`` value.
    NotImplementedError
        For an unknown ``normalization_layer`` value.

    Example:
    --------
    >>> from modulus.models.sfno.sfnonet import SphericalFourierNeuralOperatorNet as SFNO
    >>> model = SFNO(
    ...     params={},
    ...     inp_shape=(8, 16),
    ...     scale_factor=4,
    ...     in_chans=2,
    ...     out_chans=2,
    ...     embed_dim=16,
    ...     num_layers=2,
    ...     encoder_layers=1,
    ...     spectral_layers=2,
    ...     use_mlp=True,)
    >>> model(torch.randn(1, 2, 8, 16)).shape
    torch.Size([1, 2, 8, 16])
    """

    def __init__(
        self,
        params: dict,
        spectral_transform: str = "sht",
        grid="legendre-gauss",
        filter_type: str = "non-linear",
        operator_type: str = "diagonal",
        inp_shape: Tuple[int] = (721, 1440),
        scale_factor: int = 16,
        in_chans: int = 2,
        out_chans: int = 2,
        embed_dim: int = 256,
        num_layers: int = 12,
        repeat_layers=1,
        use_mlp: int = True,
        mlp_ratio: int = 2.0,
        activation_function: str = "gelu",
        encoder_layers: int = 1,
        pos_embed: str = "direct",
        drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        sparsity_threshold: float = 0.0,
        normalization_layer: str = "instance_norm",
        max_modes: Any = None,
        hard_thresholding_fraction: float = 1.0,
        use_complex_kernels: bool = True,
        big_skip: bool = True,
        rank: float = 1.0,
        factorization: Any = None,
        separable: bool = False,
        complex_network: bool = True,
        complex_activation: str = "real",
        spectral_layers: int = 3,
        output_transform: bool = False,
        checkpointing: int = 0,
    ):  # pragma: no cover
        super().__init__(meta=MetaData())

        self.params = params

        def _cfg(name, default):
            # an attribute on ``params`` wins over the keyword argument
            return getattr(params, name, default)

        def _cfg_pair(first, second, default):
            # paired settings are only taken from ``params`` when both exist
            if hasattr(params, first) and hasattr(params, second):
                return (getattr(params, first), getattr(params, second))
            return default

        self.spectral_transform = _cfg("spectral_transform", spectral_transform)
        self.grid = _cfg("grid", grid)
        self.filter_type = _cfg("filter_type", filter_type)
        self.operator_type = _cfg("operator_type", operator_type)
        self.inp_shape = _cfg_pair("img_shape_x", "img_shape_y", inp_shape)
        self.out_shape = _cfg_pair("out_shape_x", "out_shape_y", self.inp_shape)
        self.scale_factor = _cfg("scale_factor", scale_factor)
        self.in_chans = _cfg("N_in_channels", in_chans)
        self.out_chans = _cfg("N_out_channels", out_chans)
        self.embed_dim = self.num_features = _cfg("embed_dim", embed_dim)
        self.num_layers = _cfg("num_layers", num_layers)
        self.repeat_layers = _cfg("repeat_layers", repeat_layers)
        self.max_modes = _cfg_pair("lmax", "mmax", max_modes)
        self.hard_thresholding_fraction = _cfg(
            "hard_thresholding_fraction", hard_thresholding_fraction
        )
        self.normalization_layer = _cfg("normalization_layer", normalization_layer)
        self.use_mlp = _cfg("use_mlp", use_mlp)
        self.mlp_ratio = _cfg("mlp_ratio", mlp_ratio)
        self.activation_function = _cfg("activation_function", activation_function)
        self.encoder_layers = _cfg("encoder_layers", encoder_layers)
        self.pos_embed = _cfg("pos_embed", pos_embed)
        self.big_skip = _cfg("big_skip", big_skip)
        self.rank = _cfg("rank", rank)
        self.factorization = _cfg("factorization", factorization)
        self.separable = _cfg("separable", separable)
        self.complex_network = _cfg("complex_network", complex_network)
        self.complex_activation = _cfg("complex_activation", complex_activation)
        self.spectral_layers = _cfg("spectral_layers", spectral_layers)
        self.output_transform = _cfg("output_transform", output_transform)
        self.checkpointing = _cfg("checkpointing", checkpointing)

        # compute the downscaled image size
        self.h = int(self.inp_shape[0] // self.scale_factor)
        self.w = int(self.inp_shape[1] // self.scale_factor)

        # Compute the maximum frequencies in h and in w
        if self.max_modes is not None:
            modes_lat, modes_lon = self.max_modes
        else:
            modes_lat = int(self.h * self.hard_thresholding_fraction)
            modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)

        # communicator sizes used throughout the constructor
        h_size = comm.get_size("h")
        w_size = comm.get_size("w")
        spatial_parallel = (h_size > 1) or (w_size > 1)

        # prepare the spectral transforms
        if self.spectral_transform == "sht":
            sht_handle = th.RealSHT
            isht_handle = th.InverseRealSHT

            # parallelism
            if spatial_parallel:
                polar_group = None if (h_size == 1) else comm.get_group("h")
                azimuth_group = None if (w_size == 1) else comm.get_group("w")
                thd.init(polar_group, azimuth_group)
                sht_handle = thd.DistributedRealSHT
                isht_handle = thd.DistributedInverseRealSHT

            # set up
            self.trans_down = sht_handle(
                *self.inp_shape, lmax=modes_lat, mmax=modes_lon, grid="equiangular"
            ).float()
            self.itrans_up = isht_handle(
                *self.out_shape, lmax=modes_lat, mmax=modes_lon, grid="equiangular"
            ).float()
            self.trans = sht_handle(
                self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=self.grid
            ).float()
            self.itrans = isht_handle(
                self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=self.grid
            ).float()

        elif self.spectral_transform == "fft":
            # Use the FFT wrappers imported at the top of this file. These are
            # the classes SpectralFilterLayer tests for with isinstance();
            # torch_harmonics itself does not provide RealFFT2.
            fft_handle = RealFFT2
            ifft_handle = InverseRealFFT2

            # determine the global padding
            inp_dist_h = ((self.inp_shape[0] + h_size) - 1) // h_size
            inp_dist_w = ((self.inp_shape[1] + w_size) - 1) // w_size
            self.inp_padding = (
                inp_dist_h * h_size - self.inp_shape[0],
                inp_dist_w * w_size - self.inp_shape[1],
            )
            out_dist_h = ((self.out_shape[0] + h_size) - 1) // h_size
            out_dist_w = ((self.out_shape[1] + w_size) - 1) // w_size
            self.out_padding = (
                out_dist_h * h_size - self.out_shape[0],
                out_dist_w * w_size - self.out_shape[1],
            )

            # effective image size:
            self.inp_shape_eff = [
                self.inp_shape[0] + self.inp_padding[0],
                self.inp_shape[1] + self.inp_padding[1],
            ]
            self.inp_shape_loc = [
                self.inp_shape_eff[0] // h_size,
                self.inp_shape_eff[1] // w_size,
            ]
            self.out_shape_eff = [
                self.out_shape[0] + self.out_padding[0],
                self.out_shape[1] + self.out_padding[1],
            ]
            self.out_shape_loc = [
                self.out_shape_eff[0] // h_size,
                self.out_shape_eff[1] // w_size,
            ]

            if spatial_parallel:
                fft_handle = DistributedRealFFT2
                ifft_handle = DistributedInverseRealFFT2

            self.trans_down = fft_handle(
                *self.inp_shape_eff, lmax=modes_lat, mmax=modes_lon
            ).float()
            self.itrans_up = ifft_handle(
                *self.out_shape_eff, lmax=modes_lat, mmax=modes_lon
            ).float()
            self.trans = fft_handle(
                self.h, self.w, lmax=modes_lat, mmax=modes_lon
            ).float()
            self.itrans = ifft_handle(
                self.h, self.w, lmax=modes_lat, mmax=modes_lon
            ).float()
        else:
            raise ValueError("Unknown spectral transform")

        # use the SHT/FFT to compute the local, downscaled grid dimensions
        # out_shape_loc is derived here as well (mirroring inp_shape_loc) so
        # that it exists for the "sht" transform too; the "layer_norm" branch
        # below reads it.
        if spatial_parallel:
            self.inp_shape_loc = (
                self.trans_down.nlat_local,
                self.trans_down.nlon_local,
            )
            self.inp_shape_eff = [
                self.trans_down.nlat_local + self.trans_down.nlatpad_local,
                self.trans_down.nlon_local + self.trans_down.nlonpad_local,
            ]
            self.out_shape_loc = (
                self.itrans_up.nlat_local,
                self.itrans_up.nlon_local,
            )
            self.h_loc = self.itrans.nlat_local
            self.w_loc = self.itrans.nlon_local
        else:
            self.inp_shape_loc = (self.trans_down.nlat, self.trans_down.nlon)
            self.inp_shape_eff = [self.trans_down.nlat, self.trans_down.nlon]
            self.out_shape_loc = (self.itrans_up.nlat, self.itrans_up.nlon)
            self.h_loc = self.itrans.nlat
            self.w_loc = self.itrans.nlon

        # encoder
        if comm.get_size("matmul") > 1:
            self.encoder = DistributedEncoderDecoder(
                num_layers=self.encoder_layers,
                input_dim=self.in_chans,
                output_dim=self.embed_dim,
                hidden_dim=int(1 * self.embed_dim),
                act=self.activation_function,
                comm_inp_name="fin",
                comm_out_name="fout",
            )
            fblock_mlp_inp_name = self.encoder.comm_out_name
            fblock_mlp_hidden_name = (
                "fout" if (self.encoder.comm_out_name == "fin") else "fin"
            )
        else:
            self.encoder = EncoderDecoder(
                num_layers=self.encoder_layers,
                input_dim=self.in_chans,
                output_dim=self.embed_dim,
                hidden_dim=int(1 * self.embed_dim),
                act=self.activation_function,
            )
            fblock_mlp_inp_name = "fin"
            fblock_mlp_hidden_name = "fout"

        # dropout
        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
        # per-block drop-path rate, linearly increasing from 0 to drop_path_rate
        dpr = torch.linspace(0, drop_path_rate, self.num_layers).tolist()

        # pick norm layer
        if self.normalization_layer == "layer_norm":
            norm_layer_inp = partial(
                nn.LayerNorm,
                normalized_shape=(self.inp_shape_loc[0], self.inp_shape_loc[1]),
                eps=1e-6,
            )
            norm_layer_mid = partial(
                nn.LayerNorm, normalized_shape=(self.h_loc, self.w_loc), eps=1e-6
            )
            norm_layer_out = partial(
                nn.LayerNorm,
                normalized_shape=(self.out_shape_loc[0], self.out_shape_loc[1]),
                eps=1e-6,
            )
        elif self.normalization_layer == "instance_norm":
            if comm.get_size("spatial") > 1:
                norm_layer_inp = partial(
                    DistributedInstanceNorm2d,
                    num_features=self.embed_dim,
                    eps=1e-6,
                    affine=True,
                )
            else:
                norm_layer_inp = partial(
                    nn.InstanceNorm2d,
                    num_features=self.embed_dim,
                    eps=1e-6,
                    affine=True,
                    track_running_stats=False,
                )
            norm_layer_out = norm_layer_mid = norm_layer_inp
        elif self.normalization_layer == "none":
            norm_layer_out = norm_layer_mid = norm_layer_inp = nn.Identity
        else:
            raise NotImplementedError(
                f"Error, normalization {self.normalization_layer} not implemented."
            )

        # FNO blocks
        self.blocks = nn.ModuleList([])
        for i in range(self.num_layers):
            first_layer = i == 0
            last_layer = i == self.num_layers - 1

            # the first block consumes the full-resolution grid, the last one
            # produces the output grid; everything in between is downscaled
            forward_transform = self.trans_down if first_layer else self.trans
            inverse_transform = self.itrans_up if last_layer else self.itrans

            inner_skip = "linear"
            outer_skip = "identity"

            norm_layer = (
                norm_layer_inp if first_layer else norm_layer_mid,
                norm_layer_out if last_layer else norm_layer_mid,
            )

            block = FourierNeuralOperatorBlock(
                forward_transform,
                inverse_transform,
                self.embed_dim,
                filter_type=self.filter_type,
                operator_type=self.operator_type,
                mlp_ratio=self.mlp_ratio,
                drop_rate=drop_rate,
                drop_path=dpr[i],
                act_layer=self.activation_function,
                norm_layer=norm_layer,
                sparsity_threshold=sparsity_threshold,
                use_complex_kernels=use_complex_kernels,
                inner_skip=inner_skip,
                outer_skip=outer_skip,
                use_mlp=self.use_mlp,
                comm_feature_inp_name=fblock_mlp_inp_name,
                comm_feature_hidden_name=fblock_mlp_hidden_name,
                rank=self.rank,
                factorization=self.factorization,
                separable=self.separable,
                complex_network=self.complex_network,
                complex_activation=self.complex_activation,
                spectral_layers=self.spectral_layers,
                checkpointing=self.checkpointing,
            )

            self.blocks.append(block)

        # decoder
        if comm.get_size("matmul") > 1:
            comm_inp_name = fblock_mlp_inp_name
            comm_out_name = fblock_mlp_hidden_name
            # NOTE(review): unlike the non-distributed branch, input_dim here
            # does not add the big-skip channels although forward()
            # concatenates the residual before decoding - confirm intended.
            self.decoder = DistributedEncoderDecoder(
                num_layers=self.encoder_layers,
                input_dim=self.embed_dim,
                output_dim=self.out_chans,
                hidden_dim=int(1 * self.embed_dim),
                act=self.activation_function,
                comm_inp_name=comm_inp_name,
                comm_out_name=comm_out_name,
            )
        else:
            self.decoder = EncoderDecoder(
                num_layers=self.encoder_layers,
                input_dim=self.embed_dim + self.big_skip * self.out_chans,
                output_dim=self.out_chans,
                hidden_dim=int(1 * self.embed_dim),
                act=self.activation_function,
            )

        # output transform
        if self.big_skip:
            self.residual_transform = nn.Conv2d(
                self.out_chans, self.out_chans, 1, bias=False
            )

        # learned position embedding
        # (self.pos_embed holds the requested type string up to this point and
        # is replaced by the parameters, or deleted for "none")
        if self.pos_embed == "direct":
            # currently using deliberately a differently shape position embedding
            self.pos_embed = nn.Parameter(
                torch.zeros(
                    1, self.embed_dim, self.inp_shape_loc[0], self.inp_shape_loc[1]
                )
            )
            self.pos_embed.is_shared_mp = ["matmul"]
            self.pos_embed.sharded_dims_mp = [None, None, "h", "w"]
            self.pos_embed.type = "direct"
            trunc_normal_(self.pos_embed, std=0.02)
        elif self.pos_embed == "frequency":
            if spatial_parallel:
                lmax_loc = self.itrans_up.lmax_local
                mmax_loc = self.itrans_up.mmax_local
            else:
                lmax_loc = self.itrans_up.lmax
                mmax_loc = self.itrans_up.mmax

            rcoeffs = nn.Parameter(
                torch.tril(
                    torch.randn(1, self.embed_dim, lmax_loc, mmax_loc), diagonal=0
                )
            )
            ccoeffs = nn.Parameter(
                torch.tril(
                    torch.randn(1, self.embed_dim, lmax_loc, mmax_loc - 1), diagonal=-1
                )
            )
            trunc_normal_(rcoeffs, std=0.02)
            trunc_normal_(ccoeffs, std=0.02)
            self.pos_embed = nn.ParameterList([rcoeffs, ccoeffs])
            self.pos_embed.type = "frequency"
        elif self.pos_embed == "none" or self.pos_embed == "None":
            delattr(self, "pos_embed")
        else:
            raise ValueError("Unknown position embedding type")

        if self.output_transform:
            # output positions whose channel name starts with "r" get a
            # sigmoid applied in forward()
            minmax_channels = [
                o
                for o, c in enumerate(params.out_channels)
                if params.channel_names[c][0] == "r"
            ]
            self.register_buffer(
                "minmax_channels",
                torch.tensor(minmax_channels, dtype=torch.long),
                persistent=False,
            )

        self.apply(self._init_weights)

    def _init_weights(self, m):  # pragma: no cover
        """Helper routine for weight initialization"""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, FusedLayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):  # pragma: no cover
        """Return the set of parameter names listed as exempt from weight decay."""
        return {"pos_embed", "cls_token"}

    def _forward_features(self, x):  # pragma: no cover
        """Run ``x`` through all blocks, ``repeat_layers`` times over."""
        for _ in range(self.repeat_layers):
            for blk in self.blocks:
                if self.checkpointing >= 3:
                    x = checkpoint(blk, x)
                else:
                    x = blk(x)

        return x

    def forward(self, x):  # pragma: no cover
        """Encode ``x``, add the position embedding, apply the blocks, decode."""
        if comm.get_size("fin") > 1:
            # Imported lazily because the name is not imported at module level.
            # NOTE(review): module path assumed to be the package's
            # distributed ``mappings`` module - confirm.
            from modulus.utils.sfno.distributed.mappings import (
                scatter_to_parallel_region,
            )

            x = scatter_to_parallel_region(x, "fin", 1)

        # save big skip
        if self.big_skip:
            # if output shape differs, use the spectral transforms to change resolution
            if self.out_shape != self.inp_shape:
                xtype = x.dtype
                # only take the predicted channels as residual
                residual = x[..., : self.out_chans, :, :].to(torch.float32)
                with amp.autocast(enabled=False):
                    residual = self.trans_down(residual)
                    residual = residual.contiguous()
                    residual = self.itrans_up(residual)
                    residual = residual.to(dtype=xtype)
            else:
                # only take the predicted channels
                residual = x[..., : self.out_chans, :, :]

        if self.checkpointing >= 1:
            x = checkpoint(self.encoder, x)
        else:
            x = self.encoder(x)

        if hasattr(self, "pos_embed"):
            if self.pos_embed.type == "frequency":
                # assemble (real, imag) coefficient pairs; the second
                # coefficient set is left-padded by one column of zeros
                pos_embed = torch.stack(
                    [
                        self.pos_embed[0],
                        nn.functional.pad(self.pos_embed[1], (1, 0), "constant", 0),
                    ],
                    dim=-1,
                )
                with amp.autocast(enabled=False):
                    pos_embed = self.itrans_up(torch.view_as_complex(pos_embed))
            else:
                pos_embed = self.pos_embed

            # old way of treating unequally shaped weights
            # (compared as tuples: inp_shape_loc is a tuple and inp_shape_eff
            # a list, so a direct != would always be True)
            if self.pos_embed.type == "direct" and tuple(self.inp_shape_loc) != tuple(
                self.inp_shape_eff
            ):
                xp = torch.zeros_like(x)
                xp[..., : self.inp_shape_loc[0], : self.inp_shape_loc[1]] = (
                    x[..., : self.inp_shape_loc[0], : self.inp_shape_loc[1]]
                    + pos_embed
                )
                x = xp
            else:
                x = x + pos_embed

        # maybe clean the padding just in case
        x = self.pos_drop(x)

        # do the feature extraction
        x = self._forward_features(x)

        if self.big_skip:
            x = torch.cat((x, residual), dim=1)

        if self.checkpointing >= 1:
            x = checkpoint(self.decoder, x)
        else:
            x = self.decoder(x)

        if hasattr(self.decoder, "comm_out_name") and (
            comm.get_size(self.decoder.comm_out_name) > 1
        ):
            # see the note on the scatter import above
            from modulus.utils.sfno.distributed.mappings import (
                gather_from_parallel_region,
            )

            x = gather_from_parallel_region(x, self.decoder.comm_out_name, 1)

        if self.big_skip:
            x = x + self.residual_transform(residual)

        if self.output_transform:
            x[:, self.minmax_channels] = torch.sigmoid(x[:, self.minmax_channels])

        return x
© Copyright 2023, NVIDIA Modulus Team. Last updated on Jan 25, 2024.