# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Tuple, Union

import torch
import torch.nn as nn

import modulus  # noqa: F401 (used in the DLWP docstring example)

from modulus.models.meta import ModelMetaData
from modulus.models.module import Module

Tensor = torch.Tensor


def _get_same_padding(x: int, k: int, s: int) -> int:
"""Function to compute "same" padding. Inspired from:
https://github.com/huggingface/pytorch-image-models/blob/0.5.x/timm/models/layers/padding.py
"""
    return max(s * math.ceil(x / s) - s - x + k, 0)


def _pad_periodically_equatorial(
    main_face, left_face, right_face, top_face, bottom_face, nr_rot, size=2
):
    """Pad an equatorial cube face with `size` pixels from its four neighbouring
    faces, rotating the polar (top/bottom) neighbours by `nr_rot` quarter turns
    (in opposite senses) so their pixels line up with the local orientation of
    `main_face`.
    """
    if nr_rot != 0:
top_face = torch.rot90(top_face, k=nr_rot, dims=(-2, -1))
bottom_face = torch.rot90(bottom_face, k=nr_rot, dims=(-1, -2))
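    # Shapes (illustrative): for [..., H, W] faces and size=s, the result is
    # [..., H + 2s, W + 2s]; the extra columns come from the east/west
    # neighbours and the extra rows from the (rotated) polar neighbours.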
padded_data_temp = torch.cat(
(left_face[..., :, -size:], main_face, right_face[..., :, :size]), dim=-1
)
    top_pad = torch.cat(
        (top_face[..., :, :size], top_face, top_face[..., :, -size:]), dim=-1
    )  # hack: extend left/right so the corner regions are filled with edge values
    bottom_pad = torch.cat(
        (bottom_face[..., :, :size], bottom_face, bottom_face[..., :, -size:]), dim=-1
    )  # hack: extend left/right so the corner regions are filled with edge values
padded_data = torch.cat(
(bottom_pad[..., -size:, :], padded_data_temp, top_pad[..., :size, :]), dim=-2
)
    return padded_data


def _pad_periodically_polar(
    main_face,
    left_face,
    right_face,
    top_face,
    bottom_face,
    rot_axis_left,
    rot_axis_right,
    size=2,
):
    """Pad a polar cube face with `size` pixels from its four neighbouring
    faces, rotating the left/right neighbours a quarter turn (in the sense
    given by `rot_axis_left` / `rot_axis_right`) to match the orientation of
    `main_face`.
    """
    left_face = torch.rot90(left_face, dims=rot_axis_left)
right_face = torch.rot90(right_face, dims=rot_axis_right)
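    # Shapes: as in the equatorial case, [..., H, W] -> [..., H + 2s, W + 2s],
    # but here the extra rows come from the top/bottom neighbours and the
    # extra columns from the (rotated) left/right neighbours.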
padded_data_temp = torch.cat(
(bottom_face[..., -size:, :], main_face, top_face[..., :size, :]), dim=-2
)
    left_pad = torch.cat(
        (left_face[..., :size, :], left_face, left_face[..., -size:, :]), dim=-2
    )  # hack: extend top/bottom so the corner regions are filled with edge values
    right_pad = torch.cat(
        (right_face[..., :size, :], right_face, right_face[..., -size:, :]), dim=-2
    )  # hack: extend top/bottom so the corner regions are filled with edge values
padded_data = torch.cat(
(left_pad[..., :, -size:], padded_data_temp, right_pad[..., :, :size]), dim=-1
)
    return padded_data


def _cubed_conv_wrapper(faces, equator_conv, polar_conv):
    """Apply `equator_conv` to the four equatorial faces and `polar_conv` to
    the two polar faces, padding each face from its neighbours so that the
    convolutions act as "same" convolutions on the cubed sphere.
    """
    # compute the required total padding, then halve it to get per-side padding
    padding_size = _get_same_padding(
        x=faces[0].size(-1), k=equator_conv.kernel_size[0], s=equator_conv.stride[0]
    )
    padding_size = padding_size // 2
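    # Face ordering assumed by the neighbour tables below (inferred from the
    # code, not documented in the original source): faces 0-3 form the
    # equatorial ring, faces 4 and 5 are the polar caps, and face 5 is
    # mirrored so both poles can share the same polar_conv weights.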
output = []
if padding_size != 0:
for i in range(6):
if i == 0:
x = _pad_periodically_equatorial(
faces[0],
faces[3],
faces[1],
faces[5],
faces[4],
nr_rot=0,
size=padding_size,
)
output.append(equator_conv(x))
elif i == 1:
x = _pad_periodically_equatorial(
faces[1],
faces[0],
faces[2],
faces[5],
faces[4],
nr_rot=1,
size=padding_size,
)
output.append(equator_conv(x))
elif i == 2:
x = _pad_periodically_equatorial(
faces[2],
faces[1],
faces[3],
faces[5],
faces[4],
nr_rot=2,
size=padding_size,
)
output.append(equator_conv(x))
elif i == 3:
x = _pad_periodically_equatorial(
faces[3],
faces[2],
faces[0],
faces[5],
faces[4],
nr_rot=3,
size=padding_size,
)
output.append(equator_conv(x))
elif i == 4:
x = _pad_periodically_polar(
faces[4],
faces[3],
faces[1],
faces[0],
faces[5],
rot_axis_left=(-1, -2),
rot_axis_right=(-2, -1),
size=padding_size,
)
output.append(polar_conv(x))
            else:  # i == 5: flip, convolve, flip back so both poles share weights
x = _pad_periodically_polar(
faces[5],
faces[3],
faces[1],
faces[4],
faces[0],
rot_axis_left=(-2, -1),
rot_axis_right=(-1, -2),
size=padding_size,
)
x = torch.flip(x, [-1])
x = polar_conv(x)
output.append(torch.flip(x, [-1]))
else:
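        # a per-side padding of zero (e.g. a 1x1 kernel) needs no halo
        # exchange between faces; convolve face-by-face, still mirroring face 5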
for i in range(6):
if i in [0, 1, 2, 3]:
output.append(equator_conv(faces[i]))
elif i == 4:
output.append(polar_conv(faces[i]))
            else:  # i == 5: flip, convolve, flip back so both poles share weights
x = torch.flip(faces[i], [-1])
x = polar_conv(x)
output.append(torch.flip(x, [-1]))
    return output


def _cubed_non_conv_wrapper(faces, layer):
    """Apply a non-convolutional layer (pooling, upsampling, activation) to
    each of the six faces independently; no cross-face padding is needed.
    """
    return [layer(face) for face in faces]
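
# `DLWP.__init__` below calls `MetaData()`, and `dataclass` / `ModelMetaData`
# are imported but otherwise unused, so the metadata dataclass was evidently
# dropped from this listing. This is a minimal reconstruction; the flag values
# follow common Modulus conventions and are assumptions, not verbatim source.
@dataclass
class MetaData(ModelMetaData):
    name: str = "DLWP"
    # Optimization (assumed defaults)
    jit: bool = False
    cuda_graphs: bool = True
    amp_cpu: bool = False
    amp_gpu: bool = True
    # Inference
    onnx: bool = False
    # Physics informed
    func_torch: bool = False
    auto_grad: bool = False
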
class DLWP(Module):
"""A Convolutional model for Deep Learning Weather Prediction that
works on Cubed-sphere grids.
This model expects the input to be of shape [N, C, 6, Res, Res]
Parameters
----------
nr_input_channels : int
Number of channels in the input
nr_output_channels : int
Number of channels in the output
nr_initial_channels : int
Number of channels in the initial convolution. This governs the overall channels
in the model.
activation_fn : nn.Module
Activation function for the convolutions
depth : int
Depth for the U-Net
clamp_activation : Tuple of ints, floats or None
The min and max value used for torch.clamp()
Example
-------
>>> model = modulus.models.dlwp.DLWP(
... nr_input_channels=2,
... nr_output_channels=4,
... )
>>> input = torch.randn(4, 2, 6, 64, 64) # [N, C, F, Res, Res]
>>> output = model(input)
>>> output.size()
torch.Size([4, 4, 6, 64, 64])
Note
----
Reference: Weyn, Jonathan A., et al. "Sub‐seasonal forecasting with a large ensemble
of deep‐learning weather prediction models." Journal of Advances in Modeling Earth
Systems 13.7 (2021): e2021MS002502.
"""
def __init__(
self,
nr_input_channels: int,
nr_output_channels: int,
nr_initial_channels: int = 64,
activation_fn: nn.Module = nn.LeakyReLU(0.1),
depth: int = 2,
clamp_activation: Tuple[Union[float, int, None], Union[float, int, None]] = (
None,
10.0,
),
):
super().__init__(meta=MetaData())
self.nr_input_channels = nr_input_channels
self.nr_output_channels = nr_output_channels
self.nr_initial_channels = nr_initial_channels
self.activation_fn = activation_fn
self.depth = depth
self.clamp_activation = clamp_activation
        # define non-convolutional layers
        self.avg_pool = nn.AvgPool2d(2)
        self.upsample_layer = nn.Upsample(scale_factor=2)

        # define convolutional layers
self.equatorial_downsample = []
self.equatorial_upsample = []
self.equatorial_mid_layers = []
self.polar_downsample = []
self.polar_upsample = []
self.polar_mid_layers = []
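        # U-Net channel schedule: the encoder doubles the channel count at
        # each of the `depth` levels starting from nr_initial_channels, the
        # two mid layers double once more and halve back, and the decoder
        # mirrors the encoder, with skip concatenations doubling its inputs.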
for i in range(depth):
if i == 0:
ins = self.nr_input_channels
else:
ins = self.nr_initial_channels * (2 ** (i - 1))
outs = self.nr_initial_channels * (2 ** (i))
self.equatorial_downsample.append(nn.Conv2d(ins, outs, kernel_size=3))
self.polar_downsample.append(nn.Conv2d(ins, outs, kernel_size=3))
self.equatorial_downsample.append(nn.Conv2d(outs, outs, kernel_size=3))
self.polar_downsample.append(nn.Conv2d(outs, outs, kernel_size=3))
for i in range(2):
if i == 0:
ins = outs
outs = ins * 2
else:
ins = outs
outs = ins // 2
self.equatorial_mid_layers.append(nn.Conv2d(ins, outs, kernel_size=3))
self.polar_mid_layers.append(nn.Conv2d(ins, outs, kernel_size=3))
for i in range(depth - 1, -1, -1):
if i == 0:
outs = self.nr_initial_channels
outs_final = outs
else:
outs = self.nr_initial_channels * (2 ** (i))
outs_final = outs // 2
ins = outs * 2
self.equatorial_upsample.append(nn.Conv2d(ins, outs, kernel_size=3))
self.polar_upsample.append(nn.Conv2d(ins, outs, kernel_size=3))
self.equatorial_upsample.append(nn.Conv2d(outs, outs_final, kernel_size=3))
self.polar_upsample.append(nn.Conv2d(outs, outs_final, kernel_size=3))
self.equatorial_downsample = nn.ModuleList(self.equatorial_downsample)
self.polar_downsample = nn.ModuleList(self.polar_downsample)
self.equatorial_mid_layers = nn.ModuleList(self.equatorial_mid_layers)
self.polar_mid_layers = nn.ModuleList(self.polar_mid_layers)
self.equatorial_upsample = nn.ModuleList(self.equatorial_upsample)
self.polar_upsample = nn.ModuleList(self.polar_upsample)
self.equatorial_last = nn.Conv2d(outs, self.nr_output_channels, kernel_size=1)
        self.polar_last = nn.Conv2d(outs, self.nr_output_channels, kernel_size=1)

    def activation(self, x: Tensor) -> Tensor:
        """Apply the activation function, then clamp if any bound is set."""
        x = self.activation_fn(x)
        if any(isinstance(c, (float, int)) for c in self.clamp_activation):
x = torch.clamp(
x, min=self.clamp_activation[0], max=self.clamp_activation[1]
)
        return x

    def forward(self, cubed_sphere_input: Tensor) -> Tensor:
        # validate the input shape: [N, C, 6, Res, Res]
        assert cubed_sphere_input.size(2) == 6, "The input must have 6 faces."
assert cubed_sphere_input.size(3) == cubed_sphere_input.size(
4
), "The input must have equal height and width"
# split the cubed_sphere_input into individual faces
faces = torch.split(
cubed_sphere_input, split_size_or_sections=1, dim=2
) # split along face dim
faces = [torch.squeeze(face, dim=2) for face in faces]
encoder_states = []
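        # Encoder: each level runs two (conv -> activation) steps; after the
        # second, the result is stored as a skip connection and then pooled.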
for i, (equatorial_layer, polar_layer) in enumerate(
zip(self.equatorial_downsample, self.polar_downsample)
):
faces = _cubed_conv_wrapper(faces, equatorial_layer, polar_layer)
faces = _cubed_non_conv_wrapper(faces, self.activation)
if i % 2 != 0:
encoder_states.append(faces)
faces = _cubed_non_conv_wrapper(faces, self.avg_pool)
for i, (equatorial_layer, polar_layer) in enumerate(
zip(self.equatorial_mid_layers, self.polar_mid_layers)
):
faces = _cubed_conv_wrapper(faces, equatorial_layer, polar_layer)
faces = _cubed_non_conv_wrapper(faces, self.activation)
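        # Decoder: before the first conv of each level, upsample and
        # concatenate the matching skip connection along the channel dimension.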
j = 0
for i, (equatorial_layer, polar_layer) in enumerate(
zip(self.equatorial_upsample, self.polar_upsample)
):
if i % 2 == 0:
encoder_faces = encoder_states[len(encoder_states) - j - 1]
faces = _cubed_non_conv_wrapper(faces, self.upsample_layer)
faces = [
torch.cat((face_1, face_2), dim=1)
for face_1, face_2 in zip(faces, encoder_faces)
]
j += 1
faces = _cubed_conv_wrapper(faces, equatorial_layer, polar_layer)
faces = _cubed_non_conv_wrapper(faces, self.activation)
faces = _cubed_conv_wrapper(faces, self.equatorial_last, self.polar_last)
output = torch.stack(faces, dim=2)
return output
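

if __name__ == "__main__":
    # Minimal smoke test (illustrative addition, not part of the original
    # module): build a small model and check the shape round-trip promised
    # by the class docstring.
    _model = DLWP(nr_input_channels=2, nr_output_channels=4)
    _x = torch.randn(1, 2, 6, 64, 64)  # [N, C, faces, Res, Res]
    _y = _model(_x)
    assert _y.shape == (1, 4, 6, 64, 64)
    print(_y.shape)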