Source code for physicsnemo.models.dlwp.dlwp

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import Callable, Sequence, Tuple, Union

import torch
import torch.nn as nn
from jaxtyping import Float

import physicsnemo  # noqa: F401 for docs
from physicsnemo.core.meta import ModelMetaData
from physicsnemo.core.module import Module
from physicsnemo.nn import get_activation


def _get_same_padding(x: int, k: int, s: int) -> int:
    r"""
    Compute the "same" padding size for a 1D convolution dimension.

    Parameters
    ----------
    x : int
        Input size.
    k : int
        Kernel size.
    s : int
        Stride size.

    Returns
    -------
    int
        Padding size for "same" output resolution.
    """
    return max(s * math.ceil(x / s) - s - x + k, 0)


def _pad_periodically_equatorial(
    main_face: Float[torch.Tensor, "batch channels height width"],
    left_face: Float[torch.Tensor, "batch channels height width"],
    right_face: Float[torch.Tensor, "batch channels height width"],
    top_face: Float[torch.Tensor, "batch channels height width"],
    bottom_face: Float[torch.Tensor, "batch channels height width"],
    nr_rot: int,
    size: int = 2,
) -> Float[torch.Tensor, "batch channels height_out width_out"]:
    r"""
    Periodically pad a cubed-sphere equatorial face using adjacent faces.

    Parameters
    ----------
    main_face : torch.Tensor
        Equatorial face tensor of shape :math:`(B, C, H, W)`.
    left_face : torch.Tensor
        Left neighbor face tensor of shape :math:`(B, C, H, W)`.
    right_face : torch.Tensor
        Right neighbor face tensor of shape :math:`(B, C, H, W)`.
    top_face : torch.Tensor
        Top neighbor face tensor of shape :math:`(B, C, H, W)`.
    bottom_face : torch.Tensor
        Bottom neighbor face tensor of shape :math:`(B, C, H, W)`.
    nr_rot : int
        Number of 90-degree rotations applied to the polar faces.
    size : int, optional, default=2
        Padding size applied along spatial dimensions.

    Returns
    -------
    torch.Tensor
        Padded face tensor of shape :math:`(B, C, H + 2p, W + 2p)`.
    """
    if nr_rot != 0:
        top_face = torch.rot90(top_face, k=nr_rot, dims=(-2, -1))
        bottom_face = torch.rot90(bottom_face, k=nr_rot, dims=(-1, -2))
    padded_data_temp = torch.cat(
        (left_face[..., :, -size:], main_face, right_face[..., :, :size]), dim=-1
    )
    top_pad = torch.cat(
        (top_face[..., :, :size], top_face, top_face[..., :, -size:]), dim=-1
    )  # hacky - extend on the left and right side
    bottom_pad = torch.cat(
        (bottom_face[..., :, :size], bottom_face, bottom_face[..., :, -size:]), dim=-1
    )  # hacky - extend on the left and right side
    padded_data = torch.cat(
        (bottom_pad[..., -size:, :], padded_data_temp, top_pad[..., :size, :]), dim=-2
    )
    return padded_data


def _pad_periodically_polar(
    main_face: Float[torch.Tensor, "batch channels height width"],
    left_face: Float[torch.Tensor, "batch channels height width"],
    right_face: Float[torch.Tensor, "batch channels height width"],
    top_face: Float[torch.Tensor, "batch channels height width"],
    bottom_face: Float[torch.Tensor, "batch channels height width"],
    rot_axis_left: tuple[int, int],
    rot_axis_right: tuple[int, int],
    size: int = 2,
) -> Float[torch.Tensor, "batch channels height_out width_out"]:
    r"""
    Periodically pad a cubed-sphere polar face using adjacent faces.

    Parameters
    ----------
    main_face : torch.Tensor
        Polar face tensor of shape :math:`(B, C, H, W)`.
    left_face : torch.Tensor
        Left neighbor face tensor of shape :math:`(B, C, H, W)`.
    right_face : torch.Tensor
        Right neighbor face tensor of shape :math:`(B, C, H, W)`.
    top_face : torch.Tensor
        Top neighbor face tensor of shape :math:`(B, C, H, W)`.
    bottom_face : torch.Tensor
        Bottom neighbor face tensor of shape :math:`(B, C, H, W)`.
    rot_axis_left : tuple[int, int]
        Rotation axes for the left neighbor face.
    rot_axis_right : tuple[int, int]
        Rotation axes for the right neighbor face.
    size : int, optional, default=2
        Padding size applied along spatial dimensions.

    Returns
    -------
    torch.Tensor
        Padded face tensor of shape :math:`(B, C, H + 2p, W + 2p)`.
    """
    left_face = torch.rot90(left_face, dims=rot_axis_left)
    right_face = torch.rot90(right_face, dims=rot_axis_right)
    padded_data_temp = torch.cat(
        (bottom_face[..., -size:, :], main_face, top_face[..., :size, :]), dim=-2
    )
    left_pad = torch.cat(
        (left_face[..., :size, :], left_face, left_face[..., -size:, :]), dim=-2
    )  # hacky - extend the left and right
    right_pad = torch.cat(
        (right_face[..., :size, :], right_face, right_face[..., -size:, :]), dim=-2
    )  # hacky - extend the left and right
    padded_data = torch.cat(
        (left_pad[..., :, -size:], padded_data_temp, right_pad[..., :, :size]), dim=-1
    )
    return padded_data


def _cubed_conv_wrapper(
    faces: Sequence[Float[torch.Tensor, "batch channels height width"]],
    equator_conv: nn.Conv2d,
    polar_conv: nn.Conv2d,
) -> list[Float[torch.Tensor, "batch channels height_out width_out"]]:
    r"""
    Apply face-wise convolution with cubed-sphere padding.

    Parameters
    ----------
    faces : Sequence[torch.Tensor]
        Sequence of six faces, each of shape :math:`(B, C, H, W)`.
    equator_conv : torch.nn.Conv2d
        Convolution applied to equatorial faces (indices 0-3).
    polar_conv : torch.nn.Conv2d
        Convolution applied to polar faces (indices 4-5).

    Returns
    -------
    list[torch.Tensor]
        List of six convolved faces, each of shape :math:`(B, C', H', W')`.
    """
    # compute the required padding
    padding_size = _get_same_padding(
        x=faces[0].size(-1), k=equator_conv.kernel_size[0], s=equator_conv.stride[0]
    )
    padding_size = padding_size // 2
    output = []
    if padding_size != 0:
        for i in range(6):
            if i == 0:
                x = _pad_periodically_equatorial(
                    faces[0],
                    faces[3],
                    faces[1],
                    faces[5],
                    faces[4],
                    nr_rot=0,
                    size=padding_size,
                )
                output.append(equator_conv(x))
            elif i == 1:
                x = _pad_periodically_equatorial(
                    faces[1],
                    faces[0],
                    faces[2],
                    faces[5],
                    faces[4],
                    nr_rot=1,
                    size=padding_size,
                )
                output.append(equator_conv(x))
            elif i == 2:
                x = _pad_periodically_equatorial(
                    faces[2],
                    faces[1],
                    faces[3],
                    faces[5],
                    faces[4],
                    nr_rot=2,
                    size=padding_size,
                )
                output.append(equator_conv(x))
            elif i == 3:
                x = _pad_periodically_equatorial(
                    faces[3],
                    faces[2],
                    faces[0],
                    faces[5],
                    faces[4],
                    nr_rot=3,
                    size=padding_size,
                )
                output.append(equator_conv(x))
            elif i == 4:
                x = _pad_periodically_polar(
                    faces[4],
                    faces[3],
                    faces[1],
                    faces[0],
                    faces[5],
                    rot_axis_left=(-1, -2),
                    rot_axis_right=(-2, -1),
                    size=padding_size,
                )
                output.append(polar_conv(x))
            else:  # i=5
                x = _pad_periodically_polar(
                    faces[5],
                    faces[3],
                    faces[1],
                    faces[4],
                    faces[0],
                    rot_axis_left=(-2, -1),
                    rot_axis_right=(-1, -2),
                    size=padding_size,
                )
                x = torch.flip(x, [-1])
                x = polar_conv(x)
                output.append(torch.flip(x, [-1]))
    else:
        for i in range(6):
            if i in [0, 1, 2, 3]:
                output.append(equator_conv(faces[i]))
            elif i == 4:
                output.append(polar_conv(faces[i]))
            else:  # i=5
                x = torch.flip(faces[i], [-1])
                x = polar_conv(x)
                output.append(torch.flip(x, [-1]))

    return output


def _cubed_non_conv_wrapper(
    faces: Sequence[Float[torch.Tensor, "batch channels height width"]],
    layer: Callable[
        [Float[torch.Tensor, "batch channels height width"]],
        Float[torch.Tensor, "batch channels height width"],
    ],
) -> list[Float[torch.Tensor, "batch channels height width"]]:
    r"""
    Apply a non-convolutional layer to each cubed-sphere face.

    Parameters
    ----------
    faces : Sequence[torch.Tensor]
        Sequence of six faces, each of shape :math:`(B, C, H, W)`.
    layer : Callable[[torch.Tensor], torch.Tensor]
        Callable applied independently to each face tensor.

    Returns
    -------
    list[torch.Tensor]
        List of transformed faces, each of shape :math:`(B, C', H', W')`.
    """
    return [layer(face) for face in faces]


@dataclass
class MetaData(ModelMetaData):
    # Optimization
    jit: bool = False
    cuda_graphs: bool = True
    amp_cpu: bool = True
    amp_gpu: bool = True
    # Inference
    onnx: bool = False
    # Physics informed
    var_dim: int = 1
    func_torch: bool = False
    auto_grad: bool = False


[docs] class DLWP(Module): r""" Convolutional U-Net for Deep Learning Weather Prediction on cubed-sphere grids. This model operates on cubed-sphere data with six faces and applies face-aware padding so that convolutions respect cubed-sphere connectivity. Based on `Weyn et al. (2021) <https://agupubs.onlinelibrary.wiley.com/doi/10.1029/2021MS002502>`_. Parameters ---------- nr_input_channels : int Number of input channels :math:`C_{in}`. nr_output_channels : int Number of output channels :math:`C_{out}`. nr_initial_channels : int, optional, default=64 Number of channels in the first convolution block :math:`C_{init}`. Defaults to 64. activation_fn : str, optional, default="leaky_relu" Activation name resolved with :func:`~physicsnemo.nn.get_activation`. Defaults to "leaky_relu". depth : int, optional, default=2 Depth of the U-Net encoder/decoder stacks. Defaults to 2. clamp_activation : Tuple[float | int | None, float | int | None], optional, default=(None, 10.0) Minimum and maximum bounds applied via ``torch.clamp`` after activation. Defaults to (None, 10.0). Forward ------- cubed_sphere_input : torch.Tensor Input tensor of shape :math:`(B, C_{in}, F, H, W)` with :math:`F=6` faces. Outputs ------- torch.Tensor Output tensor of shape :math:`(B, C_{out}, F, H, W)`. Examples -------- >>> import torch >>> from physicsnemo.models import DLWP >>> model = DLWP(nr_input_channels=2, nr_output_channels=4) >>> inputs = torch.randn(4, 2, 6, 64, 64) >>> outputs = model(inputs) >>> outputs.shape torch.Size([4, 4, 6, 64, 64]) """ def __init__( self, nr_input_channels: int, nr_output_channels: int, nr_initial_channels: int = 64, activation_fn: str = "leaky_relu", depth: int = 2, clamp_activation: Tuple[Union[float, int, None], Union[float, int, None]] = ( None, 10.0, ), ) -> None: r"""Initialize the DLWP model.""" super().__init__(meta=MetaData()) self.nr_input_channels = nr_input_channels self.nr_output_channels = nr_output_channels self.nr_initial_channels = nr_initial_channels self.activation_fn = get_activation(activation_fn) self.depth = depth self.clamp_activation = clamp_activation # define layers # define non-convolutional layers self.avg_pool = nn.AvgPool2d(2) self.upsample_layer = nn.Upsample(scale_factor=2) # define layers self.equatorial_downsample = [] self.equatorial_upsample = [] self.equatorial_mid_layers = [] self.polar_downsample = [] self.polar_upsample = [] self.polar_mid_layers = [] for i in range(depth): if i == 0: ins = self.nr_input_channels else: ins = self.nr_initial_channels * (2 ** (i - 1)) outs = self.nr_initial_channels * (2 ** (i)) self.equatorial_downsample.append(nn.Conv2d(ins, outs, kernel_size=3)) self.polar_downsample.append(nn.Conv2d(ins, outs, kernel_size=3)) self.equatorial_downsample.append(nn.Conv2d(outs, outs, kernel_size=3)) self.polar_downsample.append(nn.Conv2d(outs, outs, kernel_size=3)) for i in range(2): if i == 0: ins = outs outs = ins * 2 else: ins = outs outs = ins // 2 self.equatorial_mid_layers.append(nn.Conv2d(ins, outs, kernel_size=3)) self.polar_mid_layers.append(nn.Conv2d(ins, outs, kernel_size=3)) for i in range(depth - 1, -1, -1): if i == 0: outs = self.nr_initial_channels outs_final = outs else: outs = self.nr_initial_channels * (2 ** (i)) outs_final = outs // 2 ins = outs * 2 self.equatorial_upsample.append(nn.Conv2d(ins, outs, kernel_size=3)) self.polar_upsample.append(nn.Conv2d(ins, outs, kernel_size=3)) self.equatorial_upsample.append(nn.Conv2d(outs, outs_final, kernel_size=3)) self.polar_upsample.append(nn.Conv2d(outs, outs_final, kernel_size=3)) self.equatorial_downsample = nn.ModuleList(self.equatorial_downsample) self.polar_downsample = nn.ModuleList(self.polar_downsample) self.equatorial_mid_layers = nn.ModuleList(self.equatorial_mid_layers) self.polar_mid_layers = nn.ModuleList(self.polar_mid_layers) self.equatorial_upsample = nn.ModuleList(self.equatorial_upsample) self.polar_upsample = nn.ModuleList(self.polar_upsample) self.equatorial_last = nn.Conv2d(outs, self.nr_output_channels, kernel_size=1) self.polar_last = nn.Conv2d(outs, self.nr_output_channels, kernel_size=1) # define activation layers
[docs] def activation( self, x: Float[torch.Tensor, "batch channels height width"] ) -> Float[torch.Tensor, "batch channels height width"]: r""" Apply activation and optional clamping to a face tensor. Parameters ---------- x : torch.Tensor Input face tensor of shape :math:`(B, C, H, W)`. Returns ------- torch.Tensor Activated face tensor of shape :math:`(B, C, H, W)`. """ x = self.activation_fn(x) if any(isinstance(c, (float, int)) for c in self.clamp_activation): x = torch.clamp( x, min=self.clamp_activation[0], max=self.clamp_activation[1] ) return x
[docs] def forward( self, cubed_sphere_input: Float[torch.Tensor, "batch channels faces height width"], ) -> Float[torch.Tensor, "batch channels_out faces height width"]: r"""Apply the DLWP forward pass to cubed-sphere input data.""" # Input validation (skip under torch.compile) if not torch.compiler.is_compiling(): if cubed_sphere_input.ndim != 5: raise ValueError( "Expected input tensor of shape (B, C, F, H, W) but got tensor of " f"shape {tuple(cubed_sphere_input.shape)}" ) batch, channels, faces_count, height, width = cubed_sphere_input.shape if channels != self.nr_input_channels: raise ValueError( f"Expected input tensor with {self.nr_input_channels} channels but " f"got {channels} channels" ) if faces_count != 6: raise ValueError( "Expected input tensor of shape (B, C, 6, H, W) but got tensor of " f"shape {tuple(cubed_sphere_input.shape)}" ) if height != width: raise ValueError( "Expected input tensor of shape (B, C, F, H, H) but got tensor of " f"shape {tuple(cubed_sphere_input.shape)}" ) # Split cubed-sphere input into individual faces faces = torch.split( cubed_sphere_input, split_size_or_sections=1, dim=2 ) # (B, C, 1, H, W) faces = [torch.squeeze(face, dim=2) for face in faces] # (B, C, H, W) encoder_states = [] # Encoder: per-face convolutions with downsampling for i, (equatorial_layer, polar_layer) in enumerate( zip(self.equatorial_downsample, self.polar_downsample) ): faces = _cubed_conv_wrapper(faces, equatorial_layer, polar_layer) faces = _cubed_non_conv_wrapper(faces, self.activation) if i % 2 != 0: encoder_states.append(faces) faces = _cubed_non_conv_wrapper(faces, self.avg_pool) # Bottleneck convolutions for i, (equatorial_layer, polar_layer) in enumerate( zip(self.equatorial_mid_layers, self.polar_mid_layers) ): faces = _cubed_conv_wrapper(faces, equatorial_layer, polar_layer) faces = _cubed_non_conv_wrapper(faces, self.activation) j = 0 # Decoder: upsample, concatenate skip connections, and convolve for i, (equatorial_layer, polar_layer) in enumerate( zip(self.equatorial_upsample, self.polar_upsample) ): if i % 2 == 0: encoder_faces = encoder_states[len(encoder_states) - j - 1] faces = _cubed_non_conv_wrapper(faces, self.upsample_layer) faces = [ torch.cat((face_1, face_2), dim=1) # (B, 2*C, H, W) for face_1, face_2 in zip(faces, encoder_faces) ] j += 1 faces = _cubed_conv_wrapper(faces, equatorial_layer, polar_layer) faces = _cubed_non_conv_wrapper(faces, self.activation) # Final face-wise projection and reassembly faces = _cubed_conv_wrapper(faces, self.equatorial_last, self.polar_last) output = torch.stack(faces, dim=2) return output