# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import torch
from jaxtyping import Float
from torch import nn
from physicsnemo.core.meta import ModelMetaData
from physicsnemo.core.module import Module
from physicsnemo.nn import (
ConvBlock,
CubeEmbedding,
SwinTransformer,
)
@dataclass
class MetaData(ModelMetaData):
    """Capability/optimization flags for SwinRNN, consumed by the physicsnemo
    ``Module`` machinery (field meanings inherited from ``ModelMetaData``)."""

    # Optimization
    jit: bool = False  # ONNX Ops Conflict
    cuda_graphs: bool = True
    amp: bool = True
    # Inference
    onnx_cpu: bool = False  # No FFT op on CPU
    onnx_gpu: bool = True
    onnx_runtime: bool = True
    # Physics informed
    var_dim: int = 1
    func_torch: bool = False
    auto_grad: bool = False
class SwinRNN(Module):
    r"""
    SwinRNN weather forecasting model.

    This implementation follows `SwinRNN
    <https://arxiv.org/abs/2205.13158>`_.

    Parameters
    ----------
    img_size : tuple[int, int, int], optional, default=(2, 721, 1440)
        Input size as :math:`(T, H, W)`, where :math:`T` is the number of
        input timesteps.
    patch_size : tuple[int, int, int], optional, default=(2, 4, 4)
        Patch size as :math:`(p_t, p_h, p_w)` for cube embedding.
    in_chans : int, optional, default=70
        Number of input channels.
    out_chans : int, optional, default=70
        Number of output channels.
    embed_dim : int, optional, default=1536
        Embedding channel size used by Swin blocks.
    num_groups : int, optional, default=32
        Number of channel groups for convolutional blocks.
    num_heads : int, optional, default=8
        Number of attention heads.
    window_size : int, optional, default=7
        Local window size of Swin transformer blocks.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, C_{in}, T, H, W)`.

    Outputs
    -------
    torch.Tensor
        Predicted tensor of shape :math:`(B, C_{out}, H, W)`.
    """

    def __init__(
        self,
        img_size: tuple[int, int, int] = (2, 721, 1440),
        patch_size: tuple[int, int, int] = (2, 4, 4),
        in_chans: int = 70,
        out_chans: int = 70,
        embed_dim: int = 1536,
        num_groups: int = 32,
        num_heads: int = 8,
        window_size: int = 7,
    ) -> None:
        super().__init__(meta=MetaData())
        self.img_size = tuple(img_size)
        self.patch_size = tuple(patch_size)
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.embed_dim = embed_dim

        # Full-resolution spatial grid (H, W); halved at each encoder level.
        input_resolution = img_size[1:]

        # Embed the (C, T, H, W) cube into a single (embed_dim, H, W) feature map.
        self.cube_embedding = CubeEmbedding(img_size, patch_size, in_chans, embed_dim)

        # --- Level 1 (full resolution) ---
        self.swin_block1 = SwinTransformer(
            embed_dim, input_resolution, num_heads, window_size, depth=2
        )
        self.down1 = ConvBlock(embed_dim, embed_dim, num_groups, upsample=-1)
        # Downsampling path for the raw last-timestep input (in_chans channels).
        self.down1x = ConvBlock(in_chans, in_chans, in_chans, upsample=-1)
        # Projects concatenated [input, hidden] channels back to embed_dim.
        self.lin_proj1 = nn.Linear(embed_dim + in_chans, embed_dim)
        self.swin_decoder1 = SwinTransformer(
            embed_dim, input_resolution, num_heads, window_size, depth=12
        )

        # --- Level 2 (1/2 resolution) ---
        input_resolution = (input_resolution[0] // 2, input_resolution[1] // 2)
        self.swin_block2 = SwinTransformer(
            embed_dim, input_resolution, num_heads, window_size, depth=2
        )
        self.down2 = ConvBlock(embed_dim, embed_dim, num_groups, upsample=-1)
        self.down2x = ConvBlock(in_chans, in_chans, in_chans, upsample=-1)
        self.lin_proj2 = nn.Linear(embed_dim + in_chans, embed_dim)
        self.swin_decoder2 = SwinTransformer(
            embed_dim, input_resolution, num_heads, window_size, depth=12
        )

        # --- Level 3 (1/4 resolution) ---
        input_resolution = (input_resolution[0] // 2, input_resolution[1] // 2)
        self.swin_block3 = SwinTransformer(
            embed_dim, input_resolution, num_heads, window_size, depth=2
        )
        self.down3 = ConvBlock(embed_dim, embed_dim, num_groups, upsample=-1)
        self.down3x = ConvBlock(in_chans, in_chans, in_chans, upsample=-1)
        self.lin_proj3 = nn.Linear(embed_dim + in_chans, embed_dim)
        self.swin_decoder3 = SwinTransformer(
            embed_dim, input_resolution, num_heads, window_size, depth=12
        )

        # --- Level 4 (1/8 resolution, bottom of the pyramid) ---
        input_resolution = (input_resolution[0] // 2, input_resolution[1] // 2)
        self.swin_block4 = SwinTransformer(
            embed_dim, input_resolution, num_heads, window_size, depth=2
        )
        self.lin_proj4 = nn.Linear(embed_dim + in_chans, embed_dim)
        self.swin_decoder4 = SwinTransformer(
            embed_dim, input_resolution, num_heads, window_size, depth=12
        )

        # Decoder upsampling path: fuse skip connections back to full resolution.
        self.up3x = ConvBlock(embed_dim, embed_dim, num_groups, upsample=1)
        self.up2x = ConvBlock(embed_dim * 2, embed_dim, num_groups, upsample=1)
        self.up1x = ConvBlock(embed_dim * 2, embed_dim, num_groups, upsample=1)
        self.pred = ConvBlock(embed_dim * 2, out_chans, out_chans, upsample=0)
        self.input_resolution = input_resolution

    def _decode_stage(
        self,
        xs: Float[torch.Tensor, "batch in_chans h w"],
        h: Float[torch.Tensor, "batch embed_dim h w"],
        lin_proj: nn.Linear,
        decoder: nn.Module,
    ) -> Float[torch.Tensor, "batch embed_dim h w"]:
        """Fuse an input-state map with a hidden state and apply one decoder.

        Concatenates ``xs`` and ``h`` along channels, projects the combined
        channels back to ``embed_dim`` with ``lin_proj`` (applied per spatial
        location via a transpose to channels-last), runs the Swin ``decoder``,
        and returns ``h`` plus the decoded residual update.
        """
        B, Cin, H, W = xs.shape
        fused = torch.cat(
            [xs.reshape(B, Cin, -1), h.reshape(B, self.embed_dim, -1)], dim=1
        ).transpose(1, 2)
        fused = lin_proj(fused).transpose(1, 2).reshape(B, self.embed_dim, H, W)
        return h + decoder(fused)

    def forward(
        self,
        x: Float[torch.Tensor, "batch in_chans time lat lon"],
    ) -> Float[torch.Tensor, "batch out_chans lat lon"]:
        r"""
        Run SwinRNN forward prediction.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape :math:`(B, C_{in}, T, H, W)`.

        Returns
        -------
        torch.Tensor
            Prediction tensor of shape :math:`(B, C_{out}, H, W)`.
        """
        # Shape validation is skipped under torch.compile, where data-dependent
        # raises would break graph capture.
        if not torch.compiler.is_compiling():
            if x.ndim != 5:
                raise ValueError(
                    f"Expected 'x' to be a 5D tensor, got {x.ndim}D tensor with shape {tuple(x.shape)}"
                )
            if x.shape[1] != self.in_chans:
                raise ValueError(
                    f"Expected 'x' to have {self.in_chans} channels, got tensor with shape {tuple(x.shape)}"
                )
            if x.shape[2] != self.img_size[0]:
                raise ValueError(
                    f"Expected 'x' time dimension {self.img_size[0]}, got tensor with shape {tuple(x.shape)}"
                )
            if x.shape[3:] != self.img_size[1:]:
                raise ValueError(
                    f"Expected 'x' spatial shape {self.img_size[1:]}, got tensor with shape {tuple(x.shape)}"
                )

        # Last input timestep; used for decoder fusion and the final residual.
        xT = x[:, :, -1, :, :]

        # Encoder: embed, then alternate Swin blocks and 2x downsampling,
        # keeping hidden states h1..h4 at each resolution level.
        x = self.cube_embedding(x).squeeze(2)  # B C Lat Lon
        h1 = self.swin_block1(x)
        x = self.down1(h1)
        h2 = self.swin_block2(x)
        x = self.down2(h2)
        h3 = self.swin_block3(x)
        x = self.down3(h3)
        h4 = self.swin_block4(x)

        # Decoder: at each level, fuse the (progressively downsampled) last
        # input frame with the hidden state and apply the residual update.
        h1 = self._decode_stage(xT, h1, self.lin_proj1, self.swin_decoder1)
        x2T = self.down1x(xT)
        h2 = self._decode_stage(x2T, h2, self.lin_proj2, self.swin_decoder2)
        x3T = self.down2x(x2T)
        h3 = self._decode_stage(x3T, h3, self.lin_proj3, self.swin_decoder3)
        x4T = self.down3x(x3T)
        h4 = self._decode_stage(x4T, h4, self.lin_proj4, self.swin_decoder4)

        # Upsampling path: merge hidden states from coarse to fine via skip
        # connections, then predict the residual over the last input frame.
        h4_up = self.up3x(h4)
        h3_up = self.up2x(torch.cat([h3, h4_up], dim=1))
        h2_up = self.up1x(torch.cat([h2, h3_up], dim=1))
        h1_up = self.pred(torch.cat([h1, h2_up], dim=1))
        return xT + h1_up