# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torch.cuda import amp
from dataclasses import dataclass
from typing import Any, Tuple

# import contractions
from modulus.models.sfno.factorizations import get_contract_fun, _contract_dense

# helpers
from modulus.models.sfno.layers import (
    trunc_normal_,
    DropPath,
    MLP,
    EncoderDecoder,
)

# import global convolution and non-linear spectral layers
from modulus.models.sfno.layers import (
    SpectralConv2d,
    SpectralAttention2d,
    SpectralAttentionS2,
)

from modulus.models.sfno.s2convolutions import SpectralConvS2

# get spectral transforms from torch_harmonics
import torch_harmonics as th
import torch_harmonics.distributed as thd

# wrap fft, to unify interface to spectral transforms
from modulus.models.sfno.layers import RealFFT2, InverseRealFFT2
from modulus.utils.sfno.distributed.layers import (
    DistributedRealFFT2,
    DistributedInverseRealFFT2,
    DistributedMLP,
    DistributedEncoderDecoder,
)

# more distributed stuff
from modulus.utils.sfno.distributed import comm

# layer normalization
from apex.normalization import FusedLayerNorm
from modulus.utils.sfno.distributed.layer_norm import DistributedInstanceNorm2d

from modulus.models.module import Module
from modulus.models.meta import ModelMetaData



@dataclass
class MetaData(ModelMetaData):
    """Model metadata for the SFNO network.

    An instance is passed as ``meta`` to the base-class constructor of
    ``SphericalFourierNeuralOperatorNet``. The fields below only override
    defaults; their interpretation is up to ``ModelMetaData`` consumers.
    """

    name: str = "SFNO"
    # Optimization
    jit: bool = False
    cuda_graphs: bool = True
    amp_cpu: bool = True
    amp_gpu: bool = True
    torch_fx: bool = False
    # Inference
    onnx: bool = False
    # Physics informed
    func_torch: bool = False
    auto_grad: bool = False




class SpectralFilterLayer(nn.Module):
    """Spectral filter layer.

    Builds one spectral operator, stores it as ``self.filter`` and delegates
    ``forward`` to it. The operator is selected as follows:

    * ``filter_type == "non-linear"`` with an SHT forward transform
      (``th.RealSHT`` / ``thd.DistributedRealSHT``): ``SpectralAttentionS2``.
    * ``filter_type == "non-linear"`` with an FFT forward transform
      (``RealFFT2`` / ``DistributedRealFFT2``): ``SpectralAttention2d``.
    * ``filter_type == "linear"``: ``SpectralConvS2`` (whatever the type of
      the transform).

    Parameters
    ----------
    forward_transform, inverse_transform
        Spectral transform modules handed to the operator.
    embed_dim : int
        Channel count; used for both input and output of the linear operator.
    filter_type : str, optional
        ``"linear"`` or ``"non-linear"``, by default "linear"
    operator_type, rank, factorization, separable
        Forwarded to ``SpectralConvS2`` (linear filter only).
    sparsity_threshold, hidden_size_factor, complex_activation,
    spectral_layers, drop_rate
        Forwarded to the attention operator (non-linear filter only).
    use_complex_kernels, complex_network
        Accepted for interface compatibility; not used by this layer.

    Raises
    ------
    NotImplementedError
        If ``filter_type`` is unknown, or if it is ``"non-linear"`` and the
        forward transform is neither an SHT nor an FFT module.
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        embed_dim,
        filter_type="linear",
        operator_type="diagonal",
        sparsity_threshold=0.0,
        use_complex_kernels=True,
        hidden_size_factor=1,
        rank=1.0,
        factorization=None,
        separable=False,
        complex_network=True,
        complex_activation="real",
        spectral_layers=1,
        drop_rate=0.0,
    ):  # pragma: no cover
        super().__init__()

        if filter_type == "non-linear":
            # the attention flavour follows the kind of spectral transform
            if isinstance(forward_transform, (th.RealSHT, thd.DistributedRealSHT)):
                attention_cls = SpectralAttentionS2
            elif isinstance(forward_transform, (RealFFT2, DistributedRealFFT2)):
                attention_cls = SpectralAttention2d
            else:
                raise NotImplementedError(
                    "non-linear filter requires an SHT or FFT forward transform"
                )

            self.filter = attention_cls(
                forward_transform,
                inverse_transform,
                embed_dim,
                sparsity_threshold=sparsity_threshold,
                hidden_size_factor=hidden_size_factor,
                complex_activation=complex_activation,
                spectral_layers=spectral_layers,
                drop_rate=drop_rate,
                bias=False,
            )

        elif filter_type == "linear":
            # spectral transform is passed to the module
            self.filter = SpectralConvS2(
                forward_transform,
                inverse_transform,
                embed_dim,
                embed_dim,
                operator_type=operator_type,
                rank=rank,
                factorization=factorization,
                separable=separable,
                bias=False,
                use_tensorly=factorization is not None,
            )

        else:
            raise NotImplementedError(f"Unknown filter type {filter_type}")

    def forward(self, x):  # pragma: no cover
        """Apply the wrapped spectral operator to ``x`` and return its result."""
        return self.filter(x)




class FourierNeuralOperatorBlock(nn.Module):
    """Fourier Neural Operator Block.

    Pipeline of ``forward``: norm -> spectral filter -> optional inner skip
    -> optional activation -> norm -> optional MLP -> drop path -> optional
    outer skip.

    Parameters
    ----------
    forward_transform, inverse_transform
        Spectral transform modules. Their ``nlat``/``nlon`` (or
        ``nlat_local``/``nlon_local`` when the "h" or "w" communicator has
        more than one rank) give the region that gets normalized.
    embed_dim : int
        Number of channels.
    filter_type, operator_type, sparsity_threshold, use_complex_kernels,
    rank, factorization, separable, complex_network, complex_activation,
    spectral_layers
        Forwarded to ``SpectralFilterLayer``.
    mlp_ratio : float, optional
        Hidden width of the MLP relative to ``embed_dim``; also forwarded to
        the filter as ``hidden_size_factor``. By default 2.0
    drop_rate : float, optional
        Dropout rate handed to the filter and the MLP, by default 0.0
    drop_path : float, optional
        ``DropPath`` rate; values <= 0 disable it. By default 0.0
    act_layer : callable, optional
        Activation factory, by default ``nn.GELU``
    norm_layer : tuple of callables, optional
        Two zero-argument factories for the pre-filter and post-filter
        norms. The default classes cannot be built without arguments, so
        callers are expected to pass ``functools.partial`` objects.
    inner_skip, outer_skip : str or None, optional
        ``"linear"`` (1x1 convolution), ``"identity"``, or anything else for
        no skip connection.
    use_mlp : bool, optional
        Whether to append an MLP, by default False
    comm_feature_inp_name, comm_feature_hidden_name : str, optional
        Communicator names forwarded to the MLP.
    checkpointing : int, optional
        Forwarded to the MLP, by default 0
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        embed_dim,
        filter_type="linear",
        operator_type="diagonal",
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=(nn.LayerNorm, nn.LayerNorm),
        sparsity_threshold=0.0,
        use_complex_kernels=True,
        rank=1.0,
        factorization=None,
        separable=False,
        inner_skip="linear",
        outer_skip=None,  # None, nn.linear or nn.Identity
        use_mlp=False,
        comm_feature_inp_name=None,
        comm_feature_hidden_name=None,
        complex_network=True,
        complex_activation="real",
        spectral_layers=1,
        checkpointing=0,
    ):  # pragma: no cover
        super().__init__()

        # region of the (possibly padded) tensors that carries valid data
        if (comm.get_size("h") > 1) or (comm.get_size("w") > 1):
            self.input_shape_loc = (
                forward_transform.nlat_local,
                forward_transform.nlon_local,
            )
            self.output_shape_loc = (
                inverse_transform.nlat_local,
                inverse_transform.nlon_local,
            )
        else:
            self.input_shape_loc = (forward_transform.nlat, forward_transform.nlon)
            self.output_shape_loc = (inverse_transform.nlat, inverse_transform.nlon)

        # norm layer applied before the filter
        self.norm0 = norm_layer[0]()

        # convolution layer
        self.filter = SpectralFilterLayer(
            forward_transform,
            inverse_transform,
            embed_dim,
            filter_type=filter_type,
            operator_type=operator_type,
            sparsity_threshold=sparsity_threshold,
            use_complex_kernels=use_complex_kernels,
            hidden_size_factor=mlp_ratio,
            rank=rank,
            factorization=factorization,
            separable=separable,
            complex_network=complex_network,
            complex_activation=complex_activation,
            spectral_layers=spectral_layers,
            drop_rate=drop_rate,
        )

        # optional attributes are left unset when disabled; forward() probes
        # for them with hasattr
        if inner_skip == "linear":
            self.inner_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
        elif inner_skip == "identity":
            self.inner_skip = nn.Identity()

        if filter_type in ("linear", "real linear"):
            self.act_layer = act_layer()

        # dropout
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # norm layer applied after the filter
        self.norm1 = norm_layer[1]()

        if use_mlp:
            mlp_cls = DistributedMLP if (comm.get_size("matmul") > 1) else MLP
            self.mlp = mlp_cls(
                in_features=embed_dim,
                hidden_features=int(embed_dim * mlp_ratio),
                act_layer=act_layer,
                drop_rate=drop_rate,
                comm_inp_name=comm_feature_inp_name,
                comm_hidden_name=comm_feature_hidden_name,
                checkpointing=checkpointing,
            )

        if outer_skip == "linear":
            self.outer_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
        elif outer_skip == "identity":
            self.outer_skip = nn.Identity()

    def forward(self, x):  # pragma: no cover
        """Run the block on ``x``.

        The filter is expected to return a ``(output, residual)`` pair; the
        residual feeds both skip connections.
        """
        in_h, in_w = self.input_shape_loc
        out_h, out_w = self.output_shape_loc

        # normalize only the valid region; everything outside stays zero
        normed = torch.zeros_like(x)
        normed[..., :in_h, :in_w] = self.norm0(x[..., :in_h, :in_w])
        x, residual = self.filter(normed)

        if hasattr(self, "inner_skip"):
            x = x + self.inner_skip(residual)

        if hasattr(self, "act_layer"):
            x = self.act_layer(x)

        normed = torch.zeros_like(x)
        normed[..., :out_h, :out_w] = self.norm1(x[..., :out_h, :out_w])
        x = normed

        if hasattr(self, "mlp"):
            x = self.mlp(x)

        x = self.drop_path(x)

        if hasattr(self, "outer_skip"):
            x = x + self.outer_skip(residual)

        return x




class SphericalFourierNeuralOperatorNet(Module):
    """
    Spherical Fourier Neural Operator Network

    Parameters
    ----------
    params : object
        Configuration object. Any of the settings below that is present as
        an *attribute* on ``params`` overrides the corresponding keyword
        argument (attribute names differ for a few of them, e.g.
        ``img_shape_x``/``img_shape_y``, ``N_in_channels``,
        ``N_out_channels``, ``lmax``/``mmax``). Lookups use attribute
        access, so the entries of a plain ``dict`` are not picked up.
    spectral_transform : str, optional
        Type of spectral transformation to use ('sht', 'fft'), by default "sht"
    grid : str, optional
        Type of grid to use, by default "legendre-gauss"
    filter_type : str, optional
        Type of filter to use ('linear', 'non-linear'), by default "non-linear"
    operator_type : str, optional
        Type of operator to use ('diagonal', 'dhconv'), by default "diagonal"
    inp_shape : tuple, optional
        Spatial shape of the input, by default (721, 1440)
    scale_factor : int, optional
        Scale factor to use, by default 16
    in_chans : int, optional
        Number of input channels, by default 2
    out_chans : int, optional
        Number of output channels, by default 2
    embed_dim : int, optional
        Dimension of the embeddings, by default 256
    num_layers : int, optional
        Number of layers in the network, by default 12
    repeat_layers : int, optional
        Number of times to repeat the layers, by default 1
    use_mlp : bool, optional
        Whether to use MLP, by default True
    mlp_ratio : float, optional
        Ratio of MLP to use, by default 2.0
    activation_function : str, optional
        Activation function to use ('relu', 'gelu', 'silu'), by default "gelu"
    encoder_layers : int, optional
        Number of layers in the encoder, by default 1
    pos_embed : str, optional
        Type of positional embedding to use ('direct', 'frequency', 'none'),
        by default "direct"
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path_rate : float, optional
        Dropout path rate, by default 0.0
    sparsity_threshold : float, optional
        Threshold for sparsity, by default 0.0
    normalization_layer : str, optional
        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
    max_modes : Any, optional
        Maximum modes to use, by default None
    hard_thresholding_fraction : float, optional
        Fraction of hard thresholding to apply, by default 1.0
    use_complex_kernels : bool, optional
        Whether to use complex kernels, by default True
    big_skip : bool, optional
        Whether to use big skip connections, by default True
    rank : float, optional
        Rank of the approximation, by default 1.0
    factorization : Any, optional
        Type of factorization to use, by default None
    separable : bool, optional
        Whether to use separable convolutions, by default False
    complex_network : bool, optional
        Whether to use a complex network architecture, by default True
    complex_activation : str, optional
        Type of complex activation function to use, by default "real"
    spectral_layers : int, optional
        Number of spectral layers, by default 3
    output_transform : bool, optional
        Whether to use an output transform, by default False
    checkpointing : int, optional
        Activation checkpointing level: values >= 1 checkpoint the encoder
        and decoder, values >= 3 additionally checkpoint every block.
        By default 0

    Example:
    --------
    >>> from modulus.models.sfno.sfnonet import SphericalFourierNeuralOperatorNet as SFNO
    >>> model = SFNO(
    ...         params={},
    ...         inp_shape=(8, 16),
    ...         scale_factor=4,
    ...         in_chans=2,
    ...         out_chans=2,
    ...         embed_dim=16,
    ...         num_layers=2,
    ...         encoder_layers=1,
    ...         spectral_layers=2,
    ...         use_mlp=True,)
    >>> model(torch.randn(1, 2, 8, 16)).shape
    torch.Size([1, 2, 8, 16])
    """

    def __init__(
        self,
        params: dict,
        spectral_transform: str = "sht",
        grid="legendre-gauss",
        filter_type: str = "non-linear",
        operator_type: str = "diagonal",
        inp_shape: Tuple[int, int] = (721, 1440),
        scale_factor: int = 16,
        in_chans: int = 2,
        out_chans: int = 2,
        embed_dim: int = 256,
        num_layers: int = 12,
        repeat_layers=1,
        use_mlp: bool = True,
        mlp_ratio: float = 2.0,
        activation_function: str = "gelu",
        encoder_layers: int = 1,
        pos_embed: str = "direct",
        drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        sparsity_threshold: float = 0.0,
        normalization_layer: str = "instance_norm",
        max_modes: Any = None,
        hard_thresholding_fraction: float = 1.0,
        use_complex_kernels: bool = True,
        big_skip: bool = True,
        rank: float = 1.0,
        factorization: Any = None,
        separable: bool = False,
        complex_network: bool = True,
        complex_activation: str = "real",
        spectral_layers: int = 3,
        output_transform: bool = False,
        checkpointing: int = 0,
    ):  # pragma: no cover
        super().__init__(meta=MetaData())

        def _resolve(attr, default):
            # an attribute on ``params`` wins over the keyword argument
            return getattr(params, attr, default)

        def _resolve_pair(attr_a, attr_b, default):
            # two-component settings are only taken when both parts exist
            if hasattr(params, attr_a) and hasattr(params, attr_b):
                return (getattr(params, attr_a), getattr(params, attr_b))
            return default

        self.params = params
        self.spectral_transform = _resolve("spectral_transform", spectral_transform)
        self.grid = _resolve("grid", grid)
        self.filter_type = _resolve("filter_type", filter_type)
        self.operator_type = _resolve("operator_type", operator_type)
        self.inp_shape = _resolve_pair("img_shape_x", "img_shape_y", inp_shape)
        # the output shape falls back to the (resolved) input shape
        self.out_shape = _resolve_pair("out_shape_x", "out_shape_y", self.inp_shape)
        self.scale_factor = _resolve("scale_factor", scale_factor)
        self.in_chans = _resolve("N_in_channels", in_chans)
        self.out_chans = _resolve("N_out_channels", out_chans)
        self.embed_dim = self.num_features = _resolve("embed_dim", embed_dim)
        self.num_layers = _resolve("num_layers", num_layers)
        self.repeat_layers = _resolve("repeat_layers", repeat_layers)
        self.max_modes = _resolve_pair("lmax", "mmax", max_modes)
        self.hard_thresholding_fraction = _resolve(
            "hard_thresholding_fraction", hard_thresholding_fraction
        )
        self.normalization_layer = _resolve("normalization_layer", normalization_layer)
        self.use_mlp = _resolve("use_mlp", use_mlp)
        self.mlp_ratio = _resolve("mlp_ratio", mlp_ratio)
        self.activation_function = _resolve("activation_function", activation_function)
        self.encoder_layers = _resolve("encoder_layers", encoder_layers)
        self.pos_embed = _resolve("pos_embed", pos_embed)
        self.big_skip = _resolve("big_skip", big_skip)
        self.rank = _resolve("rank", rank)
        self.factorization = _resolve("factorization", factorization)
        self.separable = _resolve("separable", separable)
        self.complex_network = _resolve("complex_network", complex_network)
        self.complex_activation = _resolve("complex_activation", complex_activation)
        self.spectral_layers = _resolve("spectral_layers", spectral_layers)
        self.output_transform = _resolve("output_transform", output_transform)
        self.checkpointing = _resolve("checkpointing", checkpointing)

        # compute the downscaled image size
        self.h = int(self.inp_shape[0] // self.scale_factor)
        self.w = int(self.inp_shape[1] // self.scale_factor)

        # Compute the maximum frequencies in h and in w
        if self.max_modes is not None:
            modes_lat, modes_lon = self.max_modes
        else:
            modes_lat = int(self.h * self.hard_thresholding_fraction)
            modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)

        size_h = comm.get_size("h")
        size_w = comm.get_size("w")
        spatially_distributed = (size_h > 1) or (size_w > 1)

        # prepare the spectral transforms
        if self.spectral_transform == "sht":
            sht_handle = th.RealSHT
            isht_handle = th.InverseRealSHT

            # parallelism
            if spatially_distributed:
                polar_group = None if (size_h == 1) else comm.get_group("h")
                azimuth_group = None if (size_w == 1) else comm.get_group("w")
                thd.init(polar_group, azimuth_group)
                sht_handle = thd.DistributedRealSHT
                isht_handle = thd.DistributedInverseRealSHT

            # set up
            self.trans_down = sht_handle(
                *self.inp_shape, lmax=modes_lat, mmax=modes_lon, grid="equiangular"
            ).float()
            self.itrans_up = isht_handle(
                *self.out_shape, lmax=modes_lat, mmax=modes_lon, grid="equiangular"
            ).float()
            self.trans = sht_handle(
                self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=self.grid
            ).float()
            self.itrans = isht_handle(
                self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=self.grid
            ).float()

        elif self.spectral_transform == "fft":
            # use the FFT wrappers imported by this file; they are also the
            # classes SpectralFilterLayer checks against
            fft_handle = RealFFT2
            ifft_handle = InverseRealFFT2

            # determine the global padding (round each extent up to a
            # multiple of the communicator size)
            inp_dist_h = (self.inp_shape[0] + size_h - 1) // size_h
            inp_dist_w = (self.inp_shape[1] + size_w - 1) // size_w
            self.inp_padding = (
                inp_dist_h * size_h - self.inp_shape[0],
                inp_dist_w * size_w - self.inp_shape[1],
            )
            out_dist_h = (self.out_shape[0] + size_h - 1) // size_h
            out_dist_w = (self.out_shape[1] + size_w - 1) // size_w
            self.out_padding = (
                out_dist_h * size_h - self.out_shape[0],
                out_dist_w * size_w - self.out_shape[1],
            )
            # effective image size:
            self.inp_shape_eff = [
                self.inp_shape[0] + self.inp_padding[0],
                self.inp_shape[1] + self.inp_padding[1],
            ]
            self.inp_shape_loc = [
                self.inp_shape_eff[0] // size_h,
                self.inp_shape_eff[1] // size_w,
            ]
            self.out_shape_eff = [
                self.out_shape[0] + self.out_padding[0],
                self.out_shape[1] + self.out_padding[1],
            ]
            self.out_shape_loc = [
                self.out_shape_eff[0] // size_h,
                self.out_shape_eff[1] // size_w,
            ]

            if spatially_distributed:
                fft_handle = DistributedRealFFT2
                ifft_handle = DistributedInverseRealFFT2

            self.trans_down = fft_handle(
                *self.inp_shape_eff, lmax=modes_lat, mmax=modes_lon
            ).float()
            self.itrans_up = ifft_handle(
                *self.out_shape_eff, lmax=modes_lat, mmax=modes_lon
            ).float()
            self.trans = fft_handle(
                self.h, self.w, lmax=modes_lat, mmax=modes_lon
            ).float()
            self.itrans = ifft_handle(
                self.h, self.w, lmax=modes_lat, mmax=modes_lon
            ).float()
        else:
            raise ValueError("Unknown spectral transform")

        # use the SHT/FFT to compute the local, downscaled grid dimensions.
        # out_shape_loc is derived here as well so that it exists for the
        # "sht" transform too (the layer_norm setup below reads it) and
        # matches the region FourierNeuralOperatorBlock normalizes.
        if spatially_distributed:
            self.inp_shape_loc = (
                self.trans_down.nlat_local,
                self.trans_down.nlon_local,
            )
            self.inp_shape_eff = [
                self.trans_down.nlat_local + self.trans_down.nlatpad_local,
                self.trans_down.nlon_local + self.trans_down.nlonpad_local,
            ]
            self.out_shape_loc = (
                self.itrans_up.nlat_local,
                self.itrans_up.nlon_local,
            )
            self.h_loc = self.itrans.nlat_local
            self.w_loc = self.itrans.nlon_local
        else:
            self.inp_shape_loc = (self.trans_down.nlat, self.trans_down.nlon)
            self.inp_shape_eff = [self.trans_down.nlat, self.trans_down.nlon]
            self.out_shape_loc = (self.itrans_up.nlat, self.itrans_up.nlon)
            self.h_loc = self.itrans.nlat
            self.w_loc = self.itrans.nlon

        # determine activation function
        activations = {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU}
        if self.activation_function not in activations:
            raise ValueError(f"Unknown activation function {self.activation_function}")
        self.activation_function = activations[self.activation_function]

        # encoder
        if comm.get_size("matmul") > 1:
            self.encoder = DistributedEncoderDecoder(
                num_layers=self.encoder_layers,
                input_dim=self.in_chans,
                output_dim=self.embed_dim,
                hidden_dim=int(1 * self.embed_dim),
                act=self.activation_function,
                comm_inp_name="fin",
                comm_out_name="fout",
            )
            fblock_mlp_inp_name = self.encoder.comm_out_name
            fblock_mlp_hidden_name = (
                "fout" if (self.encoder.comm_out_name == "fin") else "fin"
            )
        else:
            self.encoder = EncoderDecoder(
                num_layers=self.encoder_layers,
                input_dim=self.in_chans,
                output_dim=self.embed_dim,
                hidden_dim=int(1 * self.embed_dim),
                act=self.activation_function,
            )
            fblock_mlp_inp_name = "fin"
            fblock_mlp_hidden_name = "fout"

        # dropout
        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
        # per-block drop-path rates, ramping linearly from 0 to drop_path_rate
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]

        # pick norm layer
        if self.normalization_layer == "layer_norm":
            norm_layer_inp = partial(
                nn.LayerNorm,
                normalized_shape=(self.inp_shape_loc[0], self.inp_shape_loc[1]),
                eps=1e-6,
            )
            norm_layer_mid = partial(
                nn.LayerNorm, normalized_shape=(self.h_loc, self.w_loc), eps=1e-6
            )
            norm_layer_out = partial(
                nn.LayerNorm,
                normalized_shape=(self.out_shape_loc[0], self.out_shape_loc[1]),
                eps=1e-6,
            )
        elif self.normalization_layer == "instance_norm":
            if comm.get_size("spatial") > 1:
                norm_layer_inp = partial(
                    DistributedInstanceNorm2d,
                    num_features=self.embed_dim,
                    eps=1e-6,
                    affine=True,
                )
            else:
                norm_layer_inp = partial(
                    nn.InstanceNorm2d,
                    num_features=self.embed_dim,
                    eps=1e-6,
                    affine=True,
                    track_running_stats=False,
                )
            norm_layer_out = norm_layer_mid = norm_layer_inp
        elif self.normalization_layer == "none":
            norm_layer_out = norm_layer_mid = norm_layer_inp = nn.Identity
        else:
            raise NotImplementedError(
                f"Error, normalization {self.normalization_layer} not implemented."
            )

        # FNO blocks
        self.blocks = nn.ModuleList([])
        for i in range(self.num_layers):
            first_layer = i == 0
            last_layer = i == self.num_layers - 1

            # the first block reads the full-resolution grid and the last one
            # writes the output grid; all others work on the downscaled grid
            forward_transform = self.trans_down if first_layer else self.trans
            inverse_transform = self.itrans_up if last_layer else self.itrans
            norm_layer = (
                norm_layer_inp if first_layer else norm_layer_mid,
                norm_layer_out if last_layer else norm_layer_mid,
            )

            block = FourierNeuralOperatorBlock(
                forward_transform,
                inverse_transform,
                self.embed_dim,
                filter_type=self.filter_type,
                operator_type=self.operator_type,
                mlp_ratio=self.mlp_ratio,
                drop_rate=drop_rate,
                drop_path=dpr[i],
                act_layer=self.activation_function,
                norm_layer=norm_layer,
                sparsity_threshold=sparsity_threshold,
                use_complex_kernels=use_complex_kernels,
                inner_skip="linear",
                outer_skip="identity",
                use_mlp=self.use_mlp,
                comm_feature_inp_name=fblock_mlp_inp_name,
                comm_feature_hidden_name=fblock_mlp_hidden_name,
                rank=self.rank,
                factorization=self.factorization,
                separable=self.separable,
                complex_network=self.complex_network,
                complex_activation=self.complex_activation,
                spectral_layers=self.spectral_layers,
                checkpointing=self.checkpointing,
            )

            self.blocks.append(block)

        # decoder
        if comm.get_size("matmul") > 1:
            # NOTE(review): forward() concatenates the big-skip residual in
            # front of the decoder, but this branch sizes the decoder input
            # as embed_dim only (the non-distributed branch adds out_chans)
            # -- confirm this is intended.
            self.decoder = DistributedEncoderDecoder(
                num_layers=self.encoder_layers,
                input_dim=self.embed_dim,
                output_dim=self.out_chans,
                hidden_dim=int(1 * self.embed_dim),
                act=self.activation_function,
                comm_inp_name=fblock_mlp_inp_name,
                comm_out_name=fblock_mlp_hidden_name,
            )
        else:
            self.decoder = EncoderDecoder(
                num_layers=self.encoder_layers,
                # the residual channels are concatenated when big_skip is on
                input_dim=self.embed_dim + self.big_skip * self.out_chans,
                output_dim=self.out_chans,
                hidden_dim=int(1 * self.embed_dim),
                act=self.activation_function,
            )

        # output transform
        if self.big_skip:
            self.residual_transform = nn.Conv2d(
                self.out_chans, self.out_chans, 1, bias=False
            )

        # learned position embedding
        if self.pos_embed == "direct":
            # currently using deliberately a differently shape position embedding
            self.pos_embed = nn.Parameter(
                torch.zeros(
                    1, self.embed_dim, self.inp_shape_loc[0], self.inp_shape_loc[1]
                )
            )
            self.pos_embed.is_shared_mp = ["matmul"]
            self.pos_embed.sharded_dims_mp = [None, None, "h", "w"]
            self.pos_embed.type = "direct"
            trunc_normal_(self.pos_embed, std=0.02)
        elif self.pos_embed == "frequency":
            if spatially_distributed:
                lmax_loc = self.itrans_up.lmax_local
                mmax_loc = self.itrans_up.mmax_local
            else:
                lmax_loc = self.itrans_up.lmax
                mmax_loc = self.itrans_up.mmax

            # real and imaginary spectral coefficients; the imaginary part
            # has one column less and is zero-padded in forward()
            rcoeffs = nn.Parameter(
                torch.tril(
                    torch.randn(1, self.embed_dim, lmax_loc, mmax_loc), diagonal=0
                )
            )
            ccoeffs = nn.Parameter(
                torch.tril(
                    torch.randn(1, self.embed_dim, lmax_loc, mmax_loc - 1), diagonal=-1
                )
            )
            trunc_normal_(rcoeffs, std=0.02)
            trunc_normal_(ccoeffs, std=0.02)
            self.pos_embed = nn.ParameterList([rcoeffs, ccoeffs])
            self.pos_embed.type = "frequency"

        elif self.pos_embed == "none" or self.pos_embed == "None":
            # forward() checks for the attribute, so remove it entirely
            delattr(self, "pos_embed")
        else:
            raise ValueError("Unknown position embedding type")

        if self.output_transform:
            # output positions whose channel name starts with "r" are
            # squashed with a sigmoid in forward()
            minmax_channels = [
                o
                for o, c in enumerate(params.out_channels)
                if params.channel_names[c][0] == "r"
            ]
            self.register_buffer(
                "minmax_channels",
                torch.tensor(minmax_channels, dtype=torch.long),
                persistent=False,
            )

        self.apply(self._init_weights)

    def _init_weights(self, m):  # pragma: no cover
        """Helper routine for weight initialization"""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, FusedLayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):  # pragma: no cover
        """Return the names of parameters to exclude from weight decay."""
        return {"pos_embed", "cls_token"}

    def _forward_features(self, x):  # pragma: no cover
        """Run ``x`` through all blocks, ``repeat_layers`` times over."""
        for _ in range(self.repeat_layers):
            for blk in self.blocks:
                if self.checkpointing >= 3:
                    x = checkpoint(blk, x)
                else:
                    x = blk(x)

        return x

    def forward(self, x):  # pragma: no cover
        """Encode ``x``, add the position embedding, apply the FNO blocks and
        decode, with an optional big skip connection around the network."""
        if comm.get_size("fin") > 1:
            # NOTE(review): this name is not imported at file level in the
            # visible source; the module path below is assumed -- confirm.
            from modulus.utils.sfno.distributed.mappings import (
                scatter_to_parallel_region,
            )

            x = scatter_to_parallel_region(x, "fin", 1)

        # save big skip
        if self.big_skip:
            # if output shape differs, use the spectral transforms to change resolution
            if self.out_shape != self.inp_shape:
                xtype = x.dtype
                # only take the predicted channels as residual
                residual = x[..., : self.out_chans, :, :].to(torch.float32)
                with amp.autocast(enabled=False):
                    residual = self.trans_down(residual)
                    residual = residual.contiguous()
                    residual = self.itrans_up(residual)
                    residual = residual.to(dtype=xtype)
            else:
                # only take the predicted channels
                residual = x[..., : self.out_chans, :, :]

        if self.checkpointing >= 1:
            x = checkpoint(self.encoder, x)
        else:
            x = self.encoder(x)

        if hasattr(self, "pos_embed"):
            if self.pos_embed.type == "frequency":
                # assemble complex coefficients and transform them to the grid
                pos_embed = torch.stack(
                    [
                        self.pos_embed[0],
                        nn.functional.pad(self.pos_embed[1], (1, 0), "constant", 0),
                    ],
                    dim=-1,
                )
                with amp.autocast(enabled=False):
                    pos_embed = self.itrans_up(torch.view_as_complex(pos_embed))
            else:
                pos_embed = self.pos_embed

            # old way of treating unequally shaped weights; compare as tuples
            # because one side is stored as a tuple and the other as a list
            if self.pos_embed.type == "direct" and tuple(self.inp_shape_loc) != tuple(
                self.inp_shape_eff
            ):
                xp = torch.zeros_like(x)
                xp[..., : self.inp_shape_loc[0], : self.inp_shape_loc[1]] = (
                    x[..., : self.inp_shape_loc[0], : self.inp_shape_loc[1]] + pos_embed
                )
                x = xp
            else:
                x = x + pos_embed

        # maybe clean the padding just in case
        x = self.pos_drop(x)

        # do the feature extraction
        x = self._forward_features(x)

        if self.big_skip:
            x = torch.cat((x, residual), dim=1)

        if self.checkpointing >= 1:
            x = checkpoint(self.decoder, x)
        else:
            x = self.decoder(x)

        if hasattr(self.decoder, "comm_out_name") and (
            comm.get_size(self.decoder.comm_out_name) > 1
        ):
            # NOTE(review): same assumption about the module path as for
            # scatter_to_parallel_region above -- confirm.
            from modulus.utils.sfno.distributed.mappings import (
                gather_from_parallel_region,
            )

            x = gather_from_parallel_region(x, self.decoder.comm_out_name, 1)

        if self.big_skip:
            x = x + self.residual_transform(residual)

        if self.output_transform:
            x[:, self.minmax_channels] = torch.sigmoid(x[:, self.minmax_channels])

        return x
