# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass
from typing import Any, Dict, Sequence
import pandas as pd
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig
from physicsnemo.core.meta import ModelMetaData
from physicsnemo.core.module import Module
from physicsnemo.nn.module.hpx import HEALPixFoldFaces, HEALPixUnfoldFaces
from .layers import _legacy_hydra_targets_warning, _remap_obj
logger = logging.getLogger(__name__)
@dataclass
class MetaData(ModelMetaData):
    r"""Metadata record for the recurrent DLWP HEALPix UNet.

    Only the default values of fields inherited from ``ModelMetaData`` are
    set here; how each flag is consumed is defined outside this file.
    """

    # Optimization-related flags
    jit: bool = False
    cuda_graphs: bool = True
    amp_cpu: bool = True
    amp_gpu: bool = True
    # Inference / export-related flag
    onnx: bool = False
    # Physics-informed flags
    var_dim: int = 1
    func_torch: bool = False
    auto_grad: bool = False
class HEALPixRecUNet(Module):
    r"""
    Deep Learning Weather Prediction (DLWP) recurrent UNet on the HEALPix mesh.

    Parameters
    ----------
    encoder : DictConfig
        Instantiable configuration for the U-Net encoder block.
    decoder : DictConfig
        Instantiable configuration for the U-Net decoder block.
    input_channels : int
        Number of prognostic input channels per time step.
    output_channels : int
        Number of prognostic output channels per time step.
    n_constants : int
        Number of constant channels provided for all faces.
    decoder_input_channels : int
        Number of prescribed decoder input channels per time step.
    input_time_dim : int
        Number of input time steps :math:`T_{in}`.
    output_time_dim : int
        Number of output time steps :math:`T_{out}`.
    delta_time : str, optional
        Time difference between samples, e.g., ``"6h"``. Defaults to ``"6h"``.
    reset_cycle : str, optional
        Period for recurrent state reset, e.g., ``"24h"``. Defaults to ``"24h"``.
    presteps : int, optional
        Number of warm-up steps used to initialize recurrent states.
        Defaults to ``1``.
    enable_nhwc : bool, optional
        If ``True``, use channels-last tensors. Defaults to ``False``.
    enable_healpixpad : bool, optional
        Enable CUDA HEALPix padding when available. Defaults to ``False``.
    couplings : list, optional
        Optional coupling specifications appended to the input feature channels.
        Defaults to an empty list (no couplings).

    Forward
    -------
    inputs : Sequence[torch.Tensor]
        Inputs shaped :math:`(B, F, T_{in}, C_{in}, H, W)` plus decoder inputs,
        constants, and optional coupling tensors.
    output_only_last : bool, optional
        If ``True``, return only the final forecast step.

    Outputs
    -------
    torch.Tensor
        Predictions shaped :math:`(B, F, T_{out}, C_{out}, H, W)`.
    """

    # Checkpoint version tag for this class (presumably read by the ``Module``
    # base class when saving/loading checkpoints -- confirm against Module).
    __model_checkpoint_version__ = "0.2.0"
    # Older checkpoint versions paired with ``_legacy_hydra_targets_warning``;
    # arguments of "0.1.0" checkpoints are remapped in
    # ``_backward_compat_arg_mapper``.
    __supported_model_checkpoint_version__ = {
        "0.1.0": _legacy_hydra_targets_warning,
    }
@classmethod
def _backward_compat_arg_mapper(
    cls, version: str, args: Dict[str, Any]
) -> Dict[str, Any]:
    r"""
    Translate constructor arguments stored in an older checkpoint.

    The parent class mapper always runs first; arguments from version
    ``"0.1.0"`` checkpoints are additionally passed through ``_remap_obj``.

    Parameters
    ----------
    version : str
        Version string of the checkpoint being loaded.
    args : Dict[str, Any]
        Constructor arguments read from the checkpoint.

    Returns
    -------
    Dict[str, Any]
        Arguments usable with the current version of the class.
    """
    mapped = super()._backward_compat_arg_mapper(version, args)
    # Only 0.1.0 checkpoints need the extra remapping step.
    return _remap_obj(mapped) if version == "0.1.0" else mapped
def __init__(
    self,
    encoder: DictConfig,
    decoder: DictConfig,
    input_channels: int,
    output_channels: int,
    n_constants: int,
    decoder_input_channels: int,
    input_time_dim: int,
    output_time_dim: int,
    delta_time: str = "6h",
    reset_cycle: str = "24h",
    presteps: int = 1,
    enable_nhwc: bool = False,
    enable_healpixpad: bool = False,
    couplings: list = [],
):
    r"""
    Initialize the recurrent DLWP HEALPix UNet.

    Parameters
    ----------
    encoder : DictConfig
        Configuration passed to ``hydra.utils.instantiate`` for the encoder.
    decoder : DictConfig
        Configuration passed to ``hydra.utils.instantiate`` for the decoder.
    input_channels : int
        Prognostic input channels per time step.
    output_channels : int
        Prognostic output channels per time step.
    n_constants : int
        Number of constant channels.
    decoder_input_channels : int
        Prescribed decoder input channels per time step.
    input_time_dim : int
        Number of input time steps.
    output_time_dim : int
        Number of output time steps.
    delta_time : str, optional
        Sample spacing, parsed with ``pandas.Timedelta`` and stored in whole
        hours.
    reset_cycle : str, optional
        Recurrent-state reset period, parsed with ``pandas.Timedelta`` and
        stored in whole hours. Must be at least one hour.
    presteps : int, optional
        Number of warm-up passes used to initialize the recurrent state.
    enable_nhwc : bool, optional
        Forwarded to the encoder and decoder.
    enable_healpixpad : bool, optional
        Forwarded to the encoder and decoder.
    couplings : list, optional
        Coupling specifications; ``None`` is treated as no couplings.

    Raises
    ------
    NotImplementedError
        If neither constants nor decoder inputs are configured, or if
        couplings are requested without constants or without decoder inputs.
    ValueError
        If ``reset_cycle`` is shorter than one hour, or if ``output_time_dim``
        is not a multiple of ``input_time_dim`` for a non-diagnostic model.
    """
    super().__init__(meta=MetaData())

    # Take a private copy: the signature default is a single shared list
    # object, and storing it directly would alias it across instances.
    couplings = [] if couplings is None else list(couplings)

    self.channel_dim = 2  # Now 2 with [B, F, T*C, H, W]. Was 1 in old data format with [B, T*C, F, H, W]
    self.input_channels = input_channels

    if n_constants == 0 and decoder_input_channels == 0:
        raise NotImplementedError(
            "support for models with no constant fields and no decoder inputs (TOA insolation) is not available at this time."
        )
    if len(couplings) > 0:
        if n_constants == 0:
            raise NotImplementedError(
                "support for coupled models with no constant fields is not available at this time."
            )
        if decoder_input_channels == 0:
            raise NotImplementedError(
                "support for coupled models with no decoder inputs (TOA insolation) is not available at this time."
            )

    # add coupled fields to input channels for model initialization
    self.coupled_channels = self._compute_coupled_channels(couplings)
    self.couplings = couplings
    self.train_couplers = None
    self.output_channels = output_channels
    self.n_constants = n_constants
    self.decoder_input_channels = decoder_input_channels
    self.input_time_dim = input_time_dim
    self.output_time_dim = output_time_dim
    # Both periods are stored as whole hours (floor division by 3600 s).
    self.delta_t = int(pd.Timedelta(delta_time).total_seconds() // 3600)
    self.reset_cycle = int(pd.Timedelta(reset_cycle).total_seconds() // 3600)
    if self.reset_cycle <= 0:
        # forward() evaluates ``... % self.reset_cycle``; a zero period would
        # only surface later as an opaque ZeroDivisionError.
        raise ValueError(
            f"'reset_cycle' must be at least one hour (got {reset_cycle!r})"
        )
    self.presteps = presteps
    self.enable_nhwc = enable_nhwc
    self.enable_healpixpad = enable_healpixpad

    # Number of passes through the model, or a diagnostic model with only one output time
    self.is_diagnostic = self.output_time_dim == 1 and self.input_time_dim > 1
    if not self.is_diagnostic and (self.output_time_dim % self.input_time_dim != 0):
        raise ValueError(
            f"'output_time_dim' must be a multiple of 'input_time_dim' (got "
            f"{self.output_time_dim} and {self.input_time_dim})"
        )

    # Build the model layers
    self.fold = HEALPixFoldFaces()
    self.unfold = HEALPixUnfoldFaces(num_faces=12)
    self.encoder = instantiate(
        config=encoder,
        input_channels=self._compute_input_channels(),
        enable_nhwc=self.enable_nhwc,
        enable_healpixpad=self.enable_healpixpad,
    )
    self.encoder_depth = len(self.encoder.n_channels)
    self.decoder = instantiate(
        config=decoder,
        output_channels=self._compute_output_channels(),
        enable_nhwc=self.enable_nhwc,
        enable_healpixpad=self.enable_healpixpad,
    )
@property
def integration_steps(self):
    r"""
    Number of implicit forward integration steps.

    Returns
    -------
    int
        ``output_time_dim // input_time_dim``, but never less than 1.
    """
    steps = self.output_time_dim // self.input_time_dim
    return steps if steps > 1 else 1
def _compute_input_channels(self) -> int:
r"""
Calculate total number of input channels.
Returns
-------
int
Total channel count including couplings and constants.
"""
return (
self.input_time_dim * (self.input_channels + self.decoder_input_channels)
+ self.n_constants
+ self.coupled_channels
)
def _compute_coupled_channels(self, couplings):
r"""
Get the number of coupled channels.
Parameters
----------
couplings : list
Coupling configuration dictionaries.
Returns
-------
int
The number of coupled channels.
"""
return sum(
len(c["params"]["variables"]) * len(c["params"]["input_times"])
for c in couplings
)
def _compute_output_channels(self) -> int:
r"""
Compute the total number of output channels in the model.
Returns
-------
int
Output channel count for each integration step.
"""
return (1 if self.is_diagnostic else self.input_time_dim) * self.output_channels
def _reshape_inputs(self, inputs: Sequence, step: int = 0) -> torch.Tensor:
r"""
Concatenate prognostic, decoder, constant, and coupling inputs for the encoder.
Parameters
----------
inputs : Sequence
Tensors arranged as ``[prognostics, decoder_inputs, constants]`` with
optional couplings.
step : int, optional
Integration step index.
Returns
-------
torch.Tensor
Folded encoder input shaped :math:`(B \cdot F, C, H, W)`.
"""
if len(self.couplings) > 0:
result = [
inputs[0].flatten(
start_dim=self.channel_dim, end_dim=self.channel_dim + 1
),
inputs[1][
:,
:,
slice(step * self.input_time_dim, (step + 1) * self.input_time_dim),
...,
].flatten(
start_dim=self.channel_dim, end_dim=self.channel_dim + 1
), # DI
inputs[2].expand(
*tuple([inputs[0].shape[0]] + len(inputs[2].shape) * [-1])
), # constants
inputs[3].permute(0, 2, 1, 3, 4), # coupled inputs
]
res = torch.cat(result, dim=self.channel_dim)
else:
if self.n_constants == 0:
result = [
inputs[0].flatten(
start_dim=self.channel_dim, end_dim=self.channel_dim + 1
),
inputs[1][
:,
:,
slice(
step * self.input_time_dim, (step + 1) * self.input_time_dim
),
...,
].flatten(
start_dim=self.channel_dim, end_dim=self.channel_dim + 1
), # DI
]
res = torch.cat(result, dim=self.channel_dim)
# fold faces into batch dim
res = self.fold(res)
return res
if self.decoder_input_channels == 0:
result = [
inputs[0].flatten(
start_dim=self.channel_dim, end_dim=self.channel_dim + 1
),
inputs[1].expand(
*tuple([inputs[0].shape[0]] + len(inputs[1].shape) * [-1])
), # constants
]
res = torch.cat(result, dim=self.channel_dim)
# fold faces into batch dim
res = self.fold(res)
return res
result = [
inputs[0].flatten(
start_dim=self.channel_dim, end_dim=self.channel_dim + 1
),
inputs[1][
:,
:,
slice(step * self.input_time_dim, (step + 1) * self.input_time_dim),
...,
].flatten(
start_dim=self.channel_dim, end_dim=self.channel_dim + 1
), # DI
inputs[2].expand(
*tuple([inputs[0].shape[0]] + len(inputs[2].shape) * [-1])
), # constants
]
res = torch.cat(result, dim=self.channel_dim)
# fold faces into batch dim
res = self.fold(res)
return res
def _reshape_outputs(self, outputs: torch.Tensor) -> torch.Tensor:
r"""
Reshape decoder output back to explicit time and channel dimensions.
Parameters
----------
outputs : torch.Tensor
Decoder output shaped :math:`(B \cdot F, C, H, W)`.
Returns
-------
torch.Tensor
Unfolded tensor shaped :math:`(B, F, T_{out}, C_{out}, H, W)`.
"""
# unfold:
outputs = self.unfold(outputs)
# extract shape and reshape
shape = tuple(outputs.shape)
res = torch.reshape(
outputs,
shape=(
shape[0],
shape[1],
1 if self.is_diagnostic else self.input_time_dim,
-1,
*shape[3:],
),
)
return res
def _initialize_hidden(
self, inputs: Sequence, outputs: Sequence, step: int
) -> None:
r"""
Initialize the recurrent hidden states.
Parameters
----------
inputs : Sequence
Input tensors used for warm-up.
outputs : Sequence
Outputs accumulated so far.
step : int
Current integration step index.
Returns
-------
None
"""
self.reset()
for prestep in range(self.presteps):
if step < self.presteps:
s = step + prestep
if len(self.couplings) > 0:
input_tensor = self._reshape_inputs(
inputs=[
inputs[0][
:,
:,
s * self.input_time_dim : (s + 1) * self.input_time_dim,
]
]
+ list(inputs[1:3])
+ [inputs[3][prestep]],
step=step + prestep,
)
else:
input_tensor = self._reshape_inputs(
inputs=[
inputs[0][
:,
:,
s * self.input_time_dim : (s + 1) * self.input_time_dim,
]
]
+ list(inputs[1:]),
step=step + prestep,
)
else:
s = step - self.presteps + prestep
if len(self.couplings) > 0:
input_tensor = self._reshape_inputs(
inputs=[outputs[s - 1]]
+ list(inputs[1:3])
+ [inputs[3][step - (prestep - self.presteps)]],
step=s + 1,
)
else:
input_tensor = self._reshape_inputs(
inputs=[outputs[s - 1]] + list(inputs[1:]), step=s + 1
)
# Forward the data through the model to initialize hidden states
self.decoder(self.encoder(input_tensor))
def forward(self, inputs: Sequence, output_only_last: bool = False) -> torch.Tensor:
r"""
Forward pass of the recurrent HEALPix UNet.
Parameters
----------
inputs : Sequence
List ``[prognostics, decoder_inputs, constants]`` or
``[prognostics, decoder_inputs, constants, couplings]`` with shapes
consistent with :math:`(B, F, T, C, H, W)`.
output_only_last : bool, optional
If ``True``, return only the final forecast step.
Returns
-------
torch.Tensor
Model outputs shaped :math:`(B, F, T_{out}, C_{out}, H, W)`.
"""
if not torch.compiler.is_compiling():
if inputs[0].ndim != 6:
raise ValueError(
"HEALPixRecUNet.forward expects prognostics shaped "
"(B, F, T, C, H, W)"
)
self.reset()
outputs = []
for step in range(self.integration_steps):
# (Re-)initialize recurrent hidden states
if (step * (self.delta_t * self.input_time_dim)) % self.reset_cycle == 0:
self._initialize_hidden(inputs=inputs, outputs=outputs, step=step)
# Construct concatenated input: [prognostics|TISR|constants]
if step == 0:
s = self.presteps
if len(self.couplings) > 0:
input_tensor = self._reshape_inputs(
inputs=[
inputs[0][
:,
:,
s * self.input_time_dim : (s + 1) * self.input_time_dim,
]
]
+ list(inputs[1:3])
+ [inputs[3][s]],
step=s,
)
else:
input_tensor = self._reshape_inputs(
inputs=[
inputs[0][
:,
:,
s * self.input_time_dim : (s + 1) * self.input_time_dim,
]
]
+ list(inputs[1:]),
step=s,
)
else:
if len(self.couplings) > 0:
input_tensor = self._reshape_inputs(
inputs=[outputs[-1]]
+ list(inputs[1:3])
+ [inputs[3][self.presteps + step]],
step=step + self.presteps,
)
else:
input_tensor = self._reshape_inputs(
inputs=[outputs[-1]] + list(inputs[1:]),
step=step + self.presteps,
)
# Forward through model
encodings = self.encoder(input_tensor)
decodings = self.decoder(encodings)
# Residual prediction
reshaped = self._reshape_outputs(
input_tensor[:, : self.input_channels * self.input_time_dim] + decodings
)
outputs.append(reshaped)
if output_only_last:
return outputs[-1]
return torch.cat(outputs, dim=self.channel_dim)
def reset(self):
    r"""Reset the recurrent state of the encoder, then of the decoder."""
    for block in (self.encoder, self.decoder):
        block.reset()