
Source code for modulus.models.sfno.s2convolutions

# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.cuda import amp

# import FactorizedTensor from tltorch for tensorized operations
import tensorly as tl

tl.set_backend("pytorch")
# from tensorly.plugins import use_opt_einsum
# use_opt_einsum('optimal')
from tltorch.factorized_tensors.core import FactorizedTensor

# import convenience functions for factorized tensors
from modulus.models.sfno.activations import ComplexReLU
from modulus.models.sfno.contractions import compl_muladd2d_fwd, compl_mul2d_fwd
from modulus.models.sfno.contractions import _contract_localconv_fwd
from modulus.models.sfno.contractions import (
    _contract_blockconv_fwd,
    _contractadd_blockconv_fwd,
)
from modulus.models.sfno.factorizations import get_contract_fun

# for the experimental module
from modulus.models.sfno.contractions import (
    compl_exp_muladd2d_fwd,
    compl_exp_mul2d_fwd,
    real_mul2d_fwd,
    real_muladd2d_fwd,
)

import torch_harmonics as th
import torch_harmonics.distributed as thd


class SpectralConvS2(nn.Module):
    """
    Spectral Convolution according to Driscoll & Healy. Designed for convolutions on
    the two-sphere S2 using the Spherical Harmonic Transforms in torch-harmonics, but
    supports convolutions on the periodic domain via the RealFFT2 and InverseRealFFT2
    wrappers.
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        in_channels,
        out_channels,
        scale="auto",
        operator_type="diagonal",
        rank=0.2,
        factorization=None,
        separable=False,
        decomposition_kwargs=dict(),
        bias=False,
        use_tensorly=True,
    ):  # pragma: no cover
        super(SpectralConvS2, self).__init__()

        if scale == "auto":
            # heuristic
            scale = 2 / (in_channels + out_channels)

        self.forward_transform = forward_transform
        self.inverse_transform = inverse_transform

        self.modes_lat = self.inverse_transform.lmax
        self.modes_lon = self.inverse_transform.mmax

        self.scale_residual = (
            self.forward_transform.nlat != self.inverse_transform.nlat
        ) or (self.forward_transform.nlon != self.inverse_transform.nlon)
        if hasattr(self.forward_transform, "grid"):
            self.scale_residual = self.scale_residual or (
                self.forward_transform.grid != self.inverse_transform.grid
            )

        # Make sure we are using a Complex Factorized Tensor
        if factorization is None:
            factorization = "ComplexDense"  # No factorization
        complex_weight = factorization[:7].lower() == "complex"

        # remember factorization details
        self.operator_type = operator_type
        self.rank = rank
        self.factorization = factorization
        self.separable = separable

        assert self.inverse_transform.lmax == self.modes_lat
        assert self.inverse_transform.mmax == self.modes_lon

        weight_shape = [in_channels]

        if not self.separable:
            weight_shape += [out_channels]

        if isinstance(self.inverse_transform, thd.DistributedInverseRealSHT):
            self.modes_lat_local = self.inverse_transform.lmax_local
            self.modes_lon_local = self.inverse_transform.mmax_local
            self.lpad_local = self.inverse_transform.lpad_local
            self.mpad_local = self.inverse_transform.mpad_local
        else:
            self.modes_lat_local = self.modes_lat
            self.modes_lon_local = self.modes_lon
            self.lpad = 0
            self.mpad = 0

        # unpadded weights
        if self.operator_type == "diagonal":
            weight_shape += [self.modes_lat_local, self.modes_lon_local]
        elif self.operator_type == "dhconv":
            weight_shape += [self.modes_lat_local]
        else:
            raise ValueError(f"Unsupported operator type {self.operator_type}")

        if use_tensorly:
            # form weight tensors
            self.weight = FactorizedTensor.new(
                weight_shape,
                rank=self.rank,
                factorization=factorization,
                fixed_rank_modes=False,
                **decomposition_kwargs,
            )
            # initialization of weights
            self.weight.normal_(0, scale)
        else:
            if complex_weight:
                init = scale * torch.randn(*weight_shape, 2)
                self.weight = nn.Parameter(init)
            else:
                init = scale * torch.randn(*weight_shape)
                self.weight = nn.Parameter(init)

            if self.operator_type == "dhconv":
                self.weight.is_shared_mp = ["matmul", "w"]
                self.weight.sharded_dims_mp = [None for _ in weight_shape]
                self.weight.sharded_dims_mp[-1] = "h"
            else:
                self.weight.is_shared_mp = ["matmul"]
                self.weight.sharded_dims_mp = [None for _ in weight_shape]
                self.weight.sharded_dims_mp[-1] = "w"
                self.weight.sharded_dims_mp[-2] = "h"

        # get the contraction handle
        self._contract = get_contract_fun(
            self.weight,
            implementation="factorized",
            separable=separable,
            complex=complex_weight,
            operator_type=operator_type,
        )

        if bias:
            self.bias = nn.Parameter(scale * torch.zeros(1, out_channels, 1, 1))
    def forward(self, x):  # pragma: no cover
        dtype = x.dtype
        residual = x
        x = x.float()
        B, C, H, W = x.shape

        with amp.autocast(enabled=False):
            x = self.forward_transform(x)
            if self.scale_residual:
                x = x.contiguous()
                residual = self.inverse_transform(x)
                residual = residual.to(dtype)

        # approach with unpadded weights
        xp = torch.zeros_like(x)
        xp[..., : self.modes_lat_local, : self.modes_lon_local] = self._contract(
            x[..., : self.modes_lat_local, : self.modes_lon_local],
            self.weight,
            separable=self.separable,
            operator_type=self.operator_type,
        )
        x = xp.contiguous()

        with amp.autocast(enabled=False):
            x = self.inverse_transform(x)

        if hasattr(self, "bias"):
            x = x + self.bias

        x = x.type(dtype)

        return x, residual
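
A minimal usage sketch (added for illustration, not part of the Modulus source above): it wires SpectralConvS2 to torch-harmonics transform handles on an equiangular grid. The grid size, channel counts and batch shape are arbitrary assumptions; the constructor only requires transforms that expose nlat, nlon, lmax, mmax and, optionally, grid.

import torch
import torch_harmonics as th

from modulus.models.sfno.s2convolutions import SpectralConvS2

nlat, nlon = 32, 64  # illustrative equiangular grid
sht = th.RealSHT(nlat, nlon, grid="equiangular")
isht = th.InverseRealSHT(nlat, nlon, grid="equiangular")

# default operator_type="diagonal" with a dense complex weight
# (factorization=None falls back to "ComplexDense")
conv = SpectralConvS2(sht, isht, in_channels=8, out_channels=8, bias=True)

x = torch.randn(2, 8, nlat, nlon)  # (batch, channels, nlat, nlon)
y, residual = conv(x)              # forward returns the filtered signal and a residual branch
print(y.shape, residual.shape)     # both torch.Size([2, 8, 32, 64]) when the grids match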
class SpectralAttentionS2(nn.Module):
    """
    Spherical non-linear FNO layer
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        embed_dim,
        operator_type="diagonal",
        sparsity_threshold=0.0,
        hidden_size_factor=2,
        complex_activation="real",
        scale="auto",
        bias=False,
        spectral_layers=1,
        drop_rate=0.0,
    ):  # pragma: no cover
        super(SpectralAttentionS2, self).__init__()

        self.embed_dim = embed_dim
        self.sparsity_threshold = sparsity_threshold
        self.operator_type = operator_type
        self.spectral_layers = spectral_layers

        if scale == "auto":
            self.scale = 1 / (embed_dim * embed_dim)

        self.modes_lat = forward_transform.lmax
        self.modes_lon = forward_transform.mmax

        # only storing the forward handle to be able to call it
        self.forward_transform = forward_transform
        self.inverse_transform = inverse_transform

        self.scale_residual = (
            (self.forward_transform.nlat != self.inverse_transform.nlat)
            or (self.forward_transform.nlon != self.inverse_transform.nlon)
            or (self.forward_transform.grid != self.inverse_transform.grid)
        )

        assert inverse_transform.lmax == self.modes_lat
        assert inverse_transform.mmax == self.modes_lon

        hidden_size = int(hidden_size_factor * self.embed_dim)

        if operator_type == "diagonal":
            self.mul_add_handle = compl_muladd2d_fwd
            self.mul_handle = compl_mul2d_fwd

            # weights
            w = [self.scale * torch.randn(self.embed_dim, hidden_size, 2)]
            for l in range(1, self.spectral_layers):
                w.append(self.scale * torch.randn(hidden_size, hidden_size, 2))
            self.w = nn.ParameterList(w)

            self.wout = nn.Parameter(
                self.scale * torch.randn(hidden_size, self.embed_dim, 2)
            )

            if bias:
                self.b = nn.ParameterList(
                    [
                        self.scale * torch.randn(hidden_size, 1, 1, 2)
                        for _ in range(self.spectral_layers)
                    ]
                )

            self.activations = nn.ModuleList([])
            for l in range(0, self.spectral_layers):
                self.activations.append(
                    ComplexReLU(
                        mode=complex_activation,
                        bias_shape=(hidden_size, 1, 1),
                        scale=self.scale,
                    )
                )

        elif operator_type == "l-dependant":
            self.mul_add_handle = compl_exp_muladd2d_fwd
            self.mul_handle = compl_exp_mul2d_fwd

            # weights
            w = [
                self.scale
                * torch.randn(self.modes_lat, self.embed_dim, hidden_size, 2)
            ]
            for l in range(1, self.spectral_layers):
                w.append(
                    self.scale
                    * torch.randn(self.modes_lat, hidden_size, hidden_size, 2)
                )
            self.w = nn.ParameterList(w)

            if bias:
                self.b = nn.ParameterList(
                    [
                        self.scale * torch.randn(hidden_size, 1, 1, 2)
                        for _ in range(self.spectral_layers)
                    ]
                )

            self.wout = nn.Parameter(
                self.scale
                * torch.randn(self.modes_lat, hidden_size, self.embed_dim, 2)
            )

            self.activations = nn.ModuleList([])
            for l in range(0, self.spectral_layers):
                self.activations.append(
                    ComplexReLU(
                        mode=complex_activation,
                        bias_shape=(hidden_size, 1, 1),
                        scale=self.scale,
                    )
                )

        else:
            raise ValueError("Unknown operator type")

        self.drop = nn.Dropout(drop_rate) if drop_rate > 0.0 else nn.Identity()
    def forward_mlp(self, x):  # pragma: no cover
        """forward pass of the MLP"""
        B, C, H, W = x.shape

        xr = torch.view_as_real(x)

        for l in range(self.spectral_layers):
            if hasattr(self, "b"):
                xr = self.mul_add_handle(xr, self.w[l], self.b[l])
            else:
                xr = self.mul_handle(xr, self.w[l])
            xr = torch.view_as_complex(xr)
            xr = self.activations[l](xr)
            xr = self.drop(xr)
            xr = torch.view_as_real(xr)

        # final MLP
        x = self.mul_handle(xr, self.wout)

        x = torch.view_as_complex(x)

        return x
    def forward(self, x):  # pragma: no cover
        dtype = x.dtype
        residual = x
        x = x.to(torch.float32)

        # FWD transform
        with amp.autocast(enabled=False):
            x = self.forward_transform(x)
            if self.scale_residual:
                x = x.contiguous()
                residual = self.inverse_transform(x)
                residual = residual.to(dtype)

        # MLP
        x = self.forward_mlp(x)

        # BWD transform
        x = x.contiguous()
        with amp.autocast(enabled=False):
            x = self.inverse_transform(x)

        # cast back to initial precision
        x = x.to(dtype)

        return x, residual
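
A similarly hedged sketch for SpectralAttentionS2 (again not part of the source above), assuming the same torch-harmonics handles; embed_dim, spectral_layers and the input shape are illustrative values, not defaults taken from Modulus.

import torch
import torch_harmonics as th

from modulus.models.sfno.s2convolutions import SpectralAttentionS2

nlat, nlon = 32, 64
sht = th.RealSHT(nlat, nlon, grid="equiangular")
isht = th.InverseRealSHT(nlat, nlon, grid="equiangular")

# "diagonal" operator: a complex MLP with ComplexReLU activations applied
# pointwise over the spectral modes
attn = SpectralAttentionS2(sht, isht, embed_dim=8, spectral_layers=2, bias=True)

x = torch.randn(2, 8, nlat, nlon)  # channel count must equal embed_dim
y, residual = attn(x)              # same (batch, embed_dim, nlat, nlon) layout on the way out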