Weather / Climate Models#

class physicsnemo.models.dlwp.dlwp.DLWP(*args, **kwargs)[source]#

Bases: Module

Convolutional U-Net for Deep Learning Weather Prediction on cubed-sphere grids.

This model operates on cubed-sphere data with six faces and applies face-aware padding so that convolutions respect cubed-sphere connectivity.

Based on Weyn et al. (2021).

Parameters:
  • nr_input_channels (int) – Number of input channels \(C_{in}\).

  • nr_output_channels (int) – Number of output channels \(C_{out}\).

  • nr_initial_channels (int, optional, default=64) – Number of channels in the first convolution block \(C_{init}\). Defaults to 64.

  • activation_fn (str, optional, default="leaky_relu") – Activation name resolved with get_activation(). Defaults to “leaky_relu”.

  • depth (int, optional, default=2) – Depth of the U-Net encoder/decoder stacks. Defaults to 2.

  • clamp_activation (Tuple[float | int | None, float | int | None], optional, default=(None, 10.0)) – Minimum and maximum bounds applied via torch.clamp after activation. Defaults to (None, 10.0).

Forward:

cubed_sphere_input (torch.Tensor) – Input tensor of shape \((B, C_{in}, F, H, W)\) with \(F=6\) faces.

Outputs:

torch.Tensor – Output tensor of shape \((B, C_{out}, F, H, W)\).

Examples

>>> import torch
>>> from physicsnemo.models import DLWP
>>> model = DLWP(nr_input_channels=2, nr_output_channels=4)
>>> inputs = torch.randn(4, 2, 6, 64, 64)
>>> outputs = model(inputs)
>>> outputs.shape
torch.Size([4, 4, 6, 64, 64])
activation(
x: Float[Tensor, 'batch channels height width'],
) Float[Tensor, 'batch channels height width'][source]#

Apply activation and optional clamping to a face tensor.

Parameters:

x (torch.Tensor) – Input face tensor of shape \((B, C, H, W)\).

Returns:

Activated face tensor of shape \((B, C, H, W)\).

Return type:

torch.Tensor

forward(
cubed_sphere_input: Float[Tensor, 'batch channels faces height width'],
) Float[Tensor, 'batch channels_out faces height width'][source]#

Apply the DLWP forward pass to cubed-sphere input data.

class physicsnemo.models.dlwp_healpix.HEALPixRecUNet.HEALPixRecUNet(*args, **kwargs)[source]#

Bases: Module

Deep Learning Weather Prediction (DLWP) recurrent UNet on the HEALPix mesh.

Parameters:
  • encoder (DictConfig) – Instantiable configuration for the U-Net encoder block.

  • decoder (DictConfig) – Instantiable configuration for the U-Net decoder block.

  • input_channels (int) – Number of prognostic input channels per time step.

  • output_channels (int) – Number of prognostic output channels per time step.

  • n_constants (int) – Number of constant channels provided for all faces.

  • decoder_input_channels (int) – Number of prescribed decoder input channels per time step.

  • input_time_dim (int) – Number of input time steps \(T_{in}\).

  • output_time_dim (int) – Number of output time steps \(T_{out}\).

  • delta_time (str, optional) – Time difference between samples, e.g., "6h". Defaults to "6h".

  • reset_cycle (str, optional) – Period for recurrent state reset, e.g., "24h". Defaults to "24h".

  • presteps (int, optional) – Number of warm-up steps used to initialize recurrent states.

  • enable_nhwc (bool, optional) – If True, use channels-last tensors.

  • enable_healpixpad (bool, optional) – Enable CUDA HEALPix padding when available.

  • couplings (list, optional) – Optional coupling specifications appended to the input feature channels.

Forward:
  • inputs (Sequence[torch.Tensor]) – Inputs shaped \((B, F, T_{in}, C_{in}, H, W)\) plus decoder inputs, constants, and optional coupling tensors.

  • output_only_last (bool, optional) – If True, return only the final forecast step.

Outputs:

torch.Tensor – Predictions shaped \((B, F, T_{out}, C_{out}, H, W)\).

forward(
inputs: Sequence,
output_only_last: bool = False,
) Tensor[source]#

Forward pass of the recurrent HEALPix UNet.

Parameters:
  • inputs (Sequence) – List [prognostics, decoder_inputs, constants] or [prognostics, decoder_inputs, constants, couplings] with shapes consistent with \((B, F, T, C, H, W)\).

  • output_only_last (bool, optional) – If True, return only the final forecast step.

Returns:

Model outputs shaped \((B, F, T_{out}, C_{out}, H, W)\).

Return type:

torch.Tensor

property integration_steps#

Number of implicit forward integration steps.

Returns:

Integration horizon \(T_{out} / T_{in}\) (minimum 1).

Return type:

int
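The property above can be sketched in plain Python. This is a minimal illustration, assuming the horizon is the integer ratio of output to input time steps, floored at 1; `integration_steps` here is a hypothetical standalone helper, not the library's implementation:

```python
def integration_steps(input_time_dim: int, output_time_dim: int) -> int:
    # Number of implicit forward integration steps: the model is unrolled
    # until output_time_dim steps are produced, input_time_dim at a time,
    # with a minimum of one step.
    return max(output_time_dim // input_time_dim, 1)

print(integration_steps(2, 8))  # -> 4
print(integration_steps(4, 2))  # -> 1
```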

reset()[source]#

Reset the state of the encoder and decoder recurrent blocks.

class physicsnemo.models.graphcast.graph_cast_net.GraphCastNet(*args, **kwargs)[source]#

Bases: Module

GraphCast network architecture for global weather forecasting on an icosahedral mesh graph.

Parameters:
  • mesh_level (int, optional, default=6) – Level of the latent mesh used to build the graph.

  • multimesh (bool, optional, default=True) – If True, the latent mesh includes nodes from all mesh levels up to and including mesh_level.

  • input_res (Tuple[int, int], optional, default=(721, 1440)) – Resolution of the latitude-longitude grid (H, W).

  • input_dim_grid_nodes (int, optional, default=474) – Input dimensionality of the grid node features. Defaults to 474

  • input_dim_mesh_nodes (int, optional, default=3) – Input dimensionality of the mesh node features. Defaults to 3

  • input_dim_edges (int, optional, default=4) – Input dimensionality of the edge features. Defaults to 4

  • output_dim_grid_nodes (int, optional, default=227) – Output dimensionality of the grid node features. Defaults to 227

  • processor_type (Literal["MessagePassing", "GraphTransformer"], optional, default="MessagePassing") – Processor type used for the latent mesh. "GraphTransformer" uses GraphCastProcessorGraphTransformer.

  • khop_neighbors (int, optional, default=32) – Number of k-hop neighbors used in the graph transformer processor. Ignored when processor_type="MessagePassing". Defaults to 32

  • num_attention_heads (int, optional, default=4) – Number of attention heads for the graph transformer processor. Defaults to 4

  • processor_layers (int, optional, default=16) – Number of processor layers. Defaults to 16

  • hidden_layers (int, optional, default=1) – Number of hidden layers in MLP blocks. Defaults to 1

  • hidden_dim (int, optional, default=512) – Hidden dimension for node and edge embeddings. Defaults to 512

  • aggregation (Literal["sum", "mean"], optional, default="sum") – Message passing aggregation method. Defaults to “sum”

  • activation_fn (str, optional, default="silu") – Activation function name passed to get_activation(). Defaults to “silu”

  • norm_type (Literal["TELayerNorm", "LayerNorm"], optional, default="LayerNorm") – Normalization type. "TELayerNorm" is recommended when supported. Defaults to “LayerNorm”

  • use_cugraphops_encoder (bool, optional, default=False) – Deprecated flag for cugraphops encoder kernels (not supported). Defaults to False

  • use_cugraphops_processor (bool, optional, default=False) – Deprecated flag for cugraphops processor kernels (not supported). Defaults to False

  • use_cugraphops_decoder (bool, optional, default=False) – Deprecated flag for cugraphops decoder kernels (not supported). Defaults to False

  • do_concat_trick (bool, optional, default=False) – Whether to replace concat+MLP with MLP+index+sum. Defaults to False

  • recompute_activation (bool, optional, default=False) – Whether to recompute activations during backward to save memory. Defaults to False

  • partition_size (int, optional, default=1) – Number of process groups across which graphs are distributed. If 1, the model runs in single-GPU mode. Defaults to 1

  • partition_group_name (str | None, optional, default=None) – Name of the process group across which graphs are distributed. Defaults to None

  • use_lat_lon_partitioning (bool, optional, default=False) – If True, graph partitions are based on lat-lon coordinates instead of IDs. Defaults to False

  • expect_partitioned_input (bool, optional, default=False) – If True, the input is already partitioned. Defaults to False

  • global_features_on_rank_0 (bool, optional, default=False) – If True, global input features are only provided on rank 0 and are scattered. Defaults to False

  • produce_aggregated_output (bool, optional, default=True) – Whether to gather outputs to a global tensor. Defaults to True

  • produce_aggregated_output_on_all_ranks (bool, optional, default=True) – If produce_aggregated_output is True, gather on all ranks or only rank 0. Defaults to True

  • graph_backend (Literal["dgl", "pyg"], optional, default="pyg") – Legacy argument to select the backend used to build the graphs. Defaults to “pyg”; “dgl” option is deprecated.

Forward:

grid_nfeat (torch.Tensor) – Input grid features of shape \((B, C_{in}, H, W)\) where \(B=1\), \(C_{in} =\) input_dim_grid_nodes, and \((H, W) =\) input_res.

Outputs:

torch.Tensor – Output grid features of shape \((B, C_{out}, H, W)\) where \(C_{out} =\) output_dim_grid_nodes.

Notes

This implementation follows the GraphCast and GenCast architectures; see the respective papers for details. The graph transformer processor requires transformer-engine to be installed.

Examples

>>> import torch
>>> from physicsnemo.models.graphcast.graph_cast_net import GraphCastNet
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> model = GraphCastNet(
...     mesh_level=1,
...     input_res=(10, 20),
...     input_dim_grid_nodes=2,
...     input_dim_mesh_nodes=3,
...     input_dim_edges=4,
...     output_dim_grid_nodes=2,
...     processor_layers=3,
...     hidden_dim=4,
...     do_concat_trick=True,
... ).to(device)
>>> x = torch.randn(1, 2, 10, 20, device=device)
>>> y = model(x)
>>> y.shape
torch.Size([1, 2, 10, 20])
custom_forward(
grid_nfeat: Float[Tensor, 'grid_nodes grid_features'],
) Float[Tensor, 'grid_nodes out_features'][source]#

GraphCast forward method with gradient checkpointing support.

Parameters:

grid_nfeat (torch.Tensor) – Grid node features of shape \((N_{grid}, C_{in})\).

Returns:

Output grid node features of shape \((N_{grid}, C_{out})\).

Return type:

torch.Tensor

decoder_forward(
mesh_efeat_processed: Float[Tensor, 'mesh_edges hidden_dim'] | None,
mesh_nfeat_processed: Float[Tensor, 'mesh_nodes hidden_dim'],
grid_nfeat_encoded: Float[Tensor, 'grid_nodes hidden_dim'],
) Float[Tensor, 'grid_nodes out_features'][source]#

Run the final processor stage, decoder, and output MLP.

Parameters:
  • mesh_efeat_processed (torch.Tensor | None) – Processed mesh edge features of shape \((N_{mesh\_edge}, C_{hid})\), or None when using the graph transformer processor.

  • mesh_nfeat_processed (torch.Tensor) – Processed mesh node features of shape \((N_{mesh}, C_{hid})\).

  • grid_nfeat_encoded (torch.Tensor) – Encoded grid node features of shape \((N_{grid}, C_{hid})\).

Returns:

Final grid node features of shape \((N_{grid}, C_{out})\).

Return type:

torch.Tensor

encoder_forward(
grid_nfeat: Float[Tensor, 'grid_nodes grid_features'],
) Tuple[Float[Tensor, 'mesh_edges hidden_dim'] | None, Float[Tensor, 'mesh_nodes hidden_dim'], Float[Tensor, 'grid_nodes hidden_dim']][source]#

Run the embedder, encoder, and the first processor stage.

Parameters:

grid_nfeat (torch.Tensor) – Grid node features of shape \((N_{grid}, C_{in})\).

Returns:

  • mesh_efeat_processed (torch.Tensor | None) – Processed mesh edge features of shape \((N_{mesh\_edge}, C_{hid})\), or None when using the graph transformer processor.

  • mesh_nfeat_processed (torch.Tensor) – Processed mesh node features of shape \((N_{mesh}, C_{hid})\).

  • grid_nfeat_encoded (torch.Tensor) – Encoded grid node features of shape \((N_{grid}, C_{hid})\).

forward(
grid_nfeat: Float[Tensor, 'batch grid_features height width'],
) Float[Tensor, 'batch out_features height width'][source]#

Run the GraphCast forward pass.

Parameters:

grid_nfeat (torch.Tensor) – Input grid features of shape \((B, C_{in}, H, W)\) with \(B=1\) when expect_partitioned_input is False.

Returns:

Output grid features of shape \((B, C_{out}, H, W)\).

Return type:

torch.Tensor

prepare_input(
invar: Float[Tensor, 'batch grid_features height width'] | Float[Tensor, 'grid_nodes grid_features'],
expect_partitioned_input: bool,
global_features_on_rank_0: bool,
) Float[Tensor, 'grid_nodes grid_features'][source]#

Prepare model input in the required grid-node layout.

Parameters:
  • invar (torch.Tensor) – Input grid features of shape \((B, C_{in}, H, W)\) or partitioned features of shape \((N_{grid}, C_{in})\).

  • expect_partitioned_input (bool) – Whether invar is already partitioned.

  • global_features_on_rank_0 (bool) – Whether global features are only provided on rank 0 and should be scattered.

Returns:

Grid-node features of shape \((N_{grid}, C_{in})\).

Return type:

torch.Tensor

prepare_output(
outvar: Float[Tensor, 'grid_nodes out_features'],
produce_aggregated_output: bool,
produce_aggregated_output_on_all_ranks: bool = True,
) Float[Tensor, 'batch out_features height width'] | Float[Tensor, 'grid_nodes out_features'][source]#

Prepare model output in the required layout.

Parameters:
  • outvar (torch.Tensor) – Output node features of shape \((N_{grid}, C_{out})\).

  • produce_aggregated_output (bool) – Whether to gather outputs to a global tensor.

  • produce_aggregated_output_on_all_ranks (bool, optional, default=True) – Whether to gather outputs on all ranks or only rank 0. Defaults to True

Returns:

Output features in either global grid format \((B, C_{out}, H, W)\) or distributed node format \((N_{grid}, C_{out})\).

Return type:

torch.Tensor

set_checkpoint_decoder(checkpoint_flag: bool)[source]#

Set checkpointing for the decoder path.

Parameters:

checkpoint_flag (bool) – Whether to enable checkpointing for the decoder path.

Returns:

This method updates decoder checkpointing settings in-place.

Return type:

None

set_checkpoint_encoder(checkpoint_flag: bool)[source]#

Set checkpointing for the encoder path.

Parameters:

checkpoint_flag (bool) – Whether to enable checkpointing for the encoder path.

Returns:

This method updates encoder checkpointing settings in-place.

Return type:

None

set_checkpoint_model(checkpoint_flag: bool)[source]#

Set checkpointing for the entire model.

Parameters:

checkpoint_flag (bool) – Whether to enable checkpointing using torch.utils.checkpoint.

Returns:

This method updates internal checkpointing settings in-place.

Return type:

None

set_checkpoint_processor(checkpoint_segments: int)[source]#

Set checkpointing for the processor interior layers.

Parameters:

checkpoint_segments (int) – Number of checkpoint segments. A positive value enables checkpointing.

Returns:

This method updates processor checkpointing settings in-place.

Return type:

None

to(
*args: Any,
**kwargs: Any,
) Self[source]#

Move the model and its graph buffers to a device or dtype.

Parameters:
  • *args (Any) – Positional arguments passed to torch._C._nn._parse_to.

  • **kwargs (Any) – Keyword arguments passed to torch._C._nn._parse_to.

Returns:

The updated model instance.

Return type:

GraphCastNet

physicsnemo.models.graphcast.graph_cast_net.get_lat_lon_partition_separators(
partition_size: int,
) Tuple[list[list[float | None]], list[list[float | None]]][source]#

Compute separation intervals for lat-lon grid partitioning.

Parameters:

partition_size (int) – Size of the graph partition.

Returns:

The (min_seps, max_seps) coordinate separators for each partition.

Return type:

Tuple[list[list[float | None]], list[list[float | None]]]
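For intuition, a partitioning scheme with the same return shape can be sketched as follows. This is an illustration only (even latitude bands with unconstrained longitudes), not necessarily the separators this function computes; `even_latitude_bands` is a hypothetical helper:

```python
def even_latitude_bands(partition_size: int):
    # Illustrative scheme: split the sphere into partition_size equal
    # latitude bands, leaving the longitude bounds unconstrained (None).
    min_seps, max_seps = [], []
    for i in range(partition_size):
        lat_lo = -90.0 + 180.0 * i / partition_size
        lat_hi = -90.0 + 180.0 * (i + 1) / partition_size
        min_seps.append([lat_lo, None])  # [min_lat, min_lon]
        max_seps.append([lat_hi, None])  # [max_lat, max_lon]
    return min_seps, max_seps

mins, maxs = even_latitude_bands(4)
print(mins[0])  # -> [-90.0, None]
print(maxs[3])  # -> [90.0, None]
```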

GraphCast Utils#

Utilities for working with the GraphCast model, including graph construction and mesh handling. These are used when implementing and training GraphCast-based weather prediction models.

class physicsnemo.models.graphcast.utils.graph.Graph(
lat_lon_grid: Tensor,
mesh_level: int = 6,
multimesh: bool = True,
khop_neighbors: int = 0,
dtype=torch.float32,
backend: str = 'pyg',
)[source]#

Bases: object

Graph class for creating the graph2mesh, latent mesh, and mesh2graph graphs.

Parameters:
  • lat_lon_grid (Tensor) – Tensor of shape (lat, lon, 2) containing the latitude-longitude meshgrid.

  • mesh_level (int, optional) – Level of the latent mesh, by default 6

  • multimesh (bool, optional) – If True, the latent mesh includes the nodes corresponding to the specified mesh_level and incorporates the edges from all mesh levels from 0 up to and including mesh_level, by default True

  • khop_neighbors (int, optional) – Number of k-hop neighbors to retrieve for all mesh nodes, used when a graph transformer is the processor. If set to 0, this list is not computed; it is forced to 0 when a message passing processor is used. By default 0.

  • dtype (torch.dtype, optional) – Data type of the graph, by default torch.float32

  • backend (str, optional) – Graph backend used to build the graphs, by default "pyg"

create_g2m_graph(verbose: bool = True) None[source]#

Create the graph2mesh graph.

Parameters:

verbose (bool, optional) – Whether to print progress information, by default True

Returns:

Graph2mesh graph.

Return type:

GraphType

create_m2g_graph(verbose: bool = True) None[source]#

Create the mesh2grid graph.

Parameters:

verbose (bool, optional) – Whether to print progress information, by default True

Returns:

Mesh2grid graph.

Return type:

GraphType

create_mesh_graph(verbose: bool = True) None[source]#

Create the multimesh graph.

Parameters:

verbose (bool, optional) – Whether to print progress information, by default True

Returns:

Multimesh graph

Return type:

GraphType

physicsnemo.models.graphcast.utils.graph_utils.azimuthal_angle(lon: Tensor) Tensor[source]#

Computes the azimuthal angle of a point on the sphere.

Parameters:

lon (Tensor) – Tensor of shape (N, ) containing the longitude of the point

Returns:

Tensor of shape (N, ) containing the azimuthal angle

Return type:

Tensor

physicsnemo.models.graphcast.utils.graph_utils.cell_to_adj(cells: List[List[int]])[source]#

Creates an adjacency matrix in COO format from mesh cells.

Parameters:

cells (List[List[int]]) – List of cells, each cell is a list of 3 vertices

Returns:

src, dst – List of source and destination vertices

Return type:

List[int], List[int]
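The mapping from cells to COO adjacency can be sketched in pure Python. This is one plausible construction (each triangle contributes its three directed edges), not necessarily the library's exact edge ordering:

```python
from typing import List, Tuple

def cell_to_adj_sketch(cells: List[List[int]]) -> Tuple[List[int], List[int]]:
    # Each cell [a, b, c] contributes the directed edges a->b, b->c, c->a.
    src: List[int] = []
    dst: List[int] = []
    for a, b, c in cells:
        src += [a, b, c]
        dst += [b, c, a]
    return src, dst

src, dst = cell_to_adj_sketch([[0, 1, 2], [1, 3, 2]])
print(src)  # -> [0, 1, 2, 1, 3, 2]
print(dst)  # -> [1, 2, 0, 3, 2, 1]
```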

physicsnemo.models.graphcast.utils.graph_utils.deg2rad(deg: Tensor) Tensor[source]#

Converts degrees to radians

Parameters:

deg (Tensor) – Tensor of shape (N, ) containing angles in degrees

Returns:

Tensor of shape (N, ) containing the radians

Return type:

Tensor

physicsnemo.models.graphcast.utils.graph_utils.geospatial_rotation(
invar: Tensor,
theta: Tensor,
axis: str,
unit: str = 'rad',
) Tensor[source]#

Rotate coordinates about the given axis using the right-hand rule.

Parameters:
  • invar (Tensor) – Tensor of shape (N, 3) containing x, y, z coordinates

  • theta (Tensor) – Tensor of shape (N, ) containing the rotation angle

  • axis (str) – Axis of rotation

  • unit (str, optional) – Unit of the theta, by default “rad”

Returns:

Tensor of shape (N, 3) containing the rotated x, y, z coordinates

Return type:

Tensor

physicsnemo.models.graphcast.utils.graph_utils.get_face_centroids(
vertices: List[Tuple[float, float, float]],
faces: List[List[int]],
) List[Tuple[float, float, float]][source]#

Compute the centroids of triangular faces in a graph.

Parameters:
  • vertices (List[Tuple[float, float, float]]) – A list of tuples representing the coordinates of the vertices.

  • faces (List[List[int]]) – A list of lists, where each inner list contains three indices representing a triangular face.

Returns:

A list of tuples representing the centroids of the faces.

Return type:

List[Tuple[float, float, float]]
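The computation can be sketched in pure Python, since the centroid of a triangle is simply the mean of its three vertex coordinates (`face_centroids` below is an illustrative stand-alone helper):

```python
from typing import List, Tuple

def face_centroids(vertices: List[Tuple[float, float, float]],
                   faces: List[List[int]]) -> List[Tuple[float, float, float]]:
    # Average the three vertex coordinates of each triangular face.
    out = []
    for i, j, k in faces:
        (x0, y0, z0), (x1, y1, z1), (x2, y2, z2) = vertices[i], vertices[j], vertices[k]
        out.append(((x0 + x1 + x2) / 3, (y0 + y1 + y2) / 3, (z0 + z1 + z2) / 3))
    return out

print(face_centroids([(0, 0, 0), (3, 0, 0), (0, 3, 0)], [[0, 1, 2]]))
# -> [(1.0, 1.0, 0.0)]
```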

physicsnemo.models.graphcast.utils.graph_utils.latlon2xyz(
latlon: Tensor,
radius: float = 1,
unit: str = 'deg',
) Tensor[source]#

Converts lat-lon coordinates (in degrees) to xyz. Based on: https://stackoverflow.com/questions/1185408. Conventions: the x-axis goes through (lat, lon) = (0, 0); the y-axis goes through (0, 90); the z-axis goes through the poles.

Parameters:
  • latlon (Tensor) – Tensor of shape (N, 2) containing latitudes and longitudes

  • radius (float, optional) – Radius of the sphere, by default 1

  • unit (str, optional) – Unit of the latlon, by default “deg”

Returns:

Tensor of shape (N, 3) containing x, y, z coordinates

Return type:

Tensor
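Under the axis conventions above, the conversion reduces to standard spherical formulas. A minimal per-point sketch (a hypothetical `latlon_to_xyz` helper taking scalars rather than tensors):

```python
import math

def latlon_to_xyz(lat_deg: float, lon_deg: float, radius: float = 1.0):
    # x-axis through (lat, lon) = (0, 0), y-axis through (0, 90),
    # z-axis through the poles.
    lat, lon = math.radians(lat_deg), math.radians(lon_deg)
    return (radius * math.cos(lat) * math.cos(lon),
            radius * math.cos(lat) * math.sin(lon),
            radius * math.sin(lat))

x, y, z = latlon_to_xyz(0.0, 0.0)  # lies on the x-axis: (1, 0, 0)
```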

physicsnemo.models.graphcast.utils.graph_utils.max_edge_length(
vertices: List[List[float]],
source_nodes: List[int],
destination_nodes: List[int],
) float[source]#

Compute the maximum edge length in a graph.

Parameters:
  • vertices (List[List[float]]) – A list of lists representing the coordinates of the vertices.

  • source_nodes (List[int]) – A list of indices representing the source nodes of the edges.

  • destination_nodes (List[int]) – A list of indices representing the destination nodes of the edges.

Returns:

The maximum edge length in the graph.

Return type:

float
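A pure-Python sketch of the computation (scalar lists instead of tensors; `math.dist` gives the Euclidean distance between two points):

```python
import math
from typing import List

def max_edge_length_sketch(vertices: List[List[float]],
                           source_nodes: List[int],
                           destination_nodes: List[int]) -> float:
    # Longest Euclidean distance over all (source, destination) edge pairs.
    return max(
        math.dist(vertices[s], vertices[d])
        for s, d in zip(source_nodes, destination_nodes)
    )

print(max_edge_length_sketch([[0, 0, 0], [1, 0, 0], [0, 2, 0]], [0, 0], [1, 2]))
# -> 2.0
```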

physicsnemo.models.graphcast.utils.graph_utils.polar_angle(lat: Tensor) Tensor[source]#

Computes the polar angle of a point on the sphere.

Parameters:

lat (Tensor) – Tensor of shape (N, ) containing the latitude of the point

Returns:

Tensor of shape (N, ) containing the polar angle

Return type:

Tensor

physicsnemo.models.graphcast.utils.graph_utils.rad2deg(rad)[source]#

Converts radians to degrees

Parameters:

rad (Tensor) – Tensor of shape (N, ) containing angles in radians

Returns:

Tensor of shape (N, ) containing the degrees

Return type:

Tensor

physicsnemo.models.graphcast.utils.graph_utils.xyz2latlon(
xyz: Tensor,
radius: float = 1,
unit: str = 'deg',
) Tensor[source]#

Converts xyz coordinates to lat-lon (in degrees). Based on: https://stackoverflow.com/questions/1185408. Conventions: the x-axis goes through (lat, lon) = (0, 0); the y-axis goes through (0, 90); the z-axis goes through the poles.

Parameters:
  • xyz (Tensor) – Tensor of shape (N, 3) containing x, y, z coordinates

  • radius (float, optional) – Radius of the sphere, by default 1

  • unit (str, optional) – Unit of the latlon, by default “deg”

Returns:

Tensor of shape (N, 2) containing latitudes and longitudes

Return type:

Tensor
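The inverse conversion follows directly from the same conventions: latitude comes from the z component and longitude from atan2(y, x). A minimal per-point sketch (a hypothetical `xyz_to_latlon` helper taking scalars rather than tensors):

```python
import math

def xyz_to_latlon(x: float, y: float, z: float):
    # Latitude from the z component, longitude from atan2(y, x),
    # under the axis conventions stated above.
    r = math.sqrt(x * x + y * y + z * z)
    return math.degrees(math.asin(z / r)), math.degrees(math.atan2(y, x))

lat, lon = xyz_to_latlon(0.0, 1.0, 0.0)  # the y-axis: (lat, lon) = (0, 90)
```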

class physicsnemo.models.fengwu.fengwu.Fengwu(*args, **kwargs)[source]#

Bases: Module

FengWu weather forecasting model.

This implementation follows FengWu: Pushing the Skillful Global Medium-range Weather Forecast beyond 10 Days Lead.

Parameters:
  • img_size (tuple[int, int], optional, default=(721, 1440)) – Spatial resolution \((H, W)\) of all input and output fields.

  • pressure_level (int, optional, default=37) – Number of pressure levels \(L\).

  • embed_dim (int, optional, default=192) – Embedding channel size used in encoder/decoder/fuser blocks.

  • patch_size (tuple[int, int], optional, default=(4, 4)) – Patch size \((p_h, p_w)\) used by the hierarchical encoder/decoder.

  • num_heads (tuple[int, int, int, int], optional, default=(6, 12, 12, 6)) – Number of attention heads used at each stage.

  • window_size (tuple[int, int, int], optional, default=(2, 6, 12)) – Window size used by the transformer blocks.

Forward:

x (torch.Tensor) – Input tensor of shape \((B, C_{in}, H, W)\) with \(C_{in} = 4 + 5L\).

Outputs:

tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] – Tuple (surface, z, r, u, v, t) where:

  • surface has shape \((B, 4, H, W)\).

  • z, r, u, v, t each have shape \((B, L, H, W)\).
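The channel bookkeeping implied by the concatenation order in prepare_input() (surface, z, r, u, v, t) can be sketched as follows. `fengwu_channel_slices` is an illustrative helper showing the 4 + 5L layout, not part of the library API:

```python
def fengwu_channel_slices(pressure_level: int = 37):
    # Channels are concatenated as surface (4) followed by z, r, u, v, t
    # (pressure_level channels each), giving 4 + 5 * pressure_level in total.
    names = ["surface", "z", "r", "u", "v", "t"]
    sizes = [4] + [pressure_level] * 5
    slices, start = {}, 0
    for name, size in zip(names, sizes):
        slices[name] = (start, start + size)
        start += size
    return slices, start  # start == total channel count

slices, total = fengwu_channel_slices()
print(total)         # -> 189 for the default 37 pressure levels
print(slices["z"])   # -> (4, 41)
```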

forward(
x: Float[Tensor, 'batch channels lat lon'],
) tuple[Float[Tensor, 'batch c_surface lat lon'], Float[Tensor, 'batch c_pressure lat lon'], Float[Tensor, 'batch c_pressure lat lon'], Float[Tensor, 'batch c_pressure lat lon'], Float[Tensor, 'batch c_pressure lat lon'], Float[Tensor, 'batch c_pressure lat lon']][source]#

Run Fengwu forward prediction.

Parameters:

x (torch.Tensor) – Concatenated input tensor of shape \((B, 4 + 5L, H, W)\).

Returns:

Output tuple (surface, z, r, u, v, t) where surface has shape \((B, 4, H, W)\) and the other outputs have shape \((B, L, H, W)\).

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

prepare_input(
surface: Float[Tensor, 'batch c_surface lat lon'],
z: Float[Tensor, 'batch c_pressure lat lon'],
r: Float[Tensor, 'batch c_pressure lat lon'],
u: Float[Tensor, 'batch c_pressure lat lon'],
v: Float[Tensor, 'batch c_pressure lat lon'],
t: Float[Tensor, 'batch c_pressure lat lon'],
) Float[Tensor, 'batch channels lat lon'][source]#

Prepare input fields by concatenating all variables along channels.

Parameters:
  • surface (torch.Tensor) – Surface tensor of shape \((B, 4, H, W)\).

  • z (torch.Tensor) – Geopotential tensor of shape \((B, L, H, W)\).

  • r (torch.Tensor) – Relative humidity tensor of shape \((B, L, H, W)\).

  • u (torch.Tensor) – U-wind tensor of shape \((B, L, H, W)\).

  • v (torch.Tensor) – V-wind tensor of shape \((B, L, H, W)\).

  • t (torch.Tensor) – Temperature tensor of shape \((B, L, H, W)\).

Returns:

Concatenated tensor of shape \((B, 4 + 5L, H, W)\).

Return type:

torch.Tensor

class physicsnemo.models.pangu.pangu.Pangu(*args, **kwargs)[source]#

Bases: Module

Pangu weather forecasting model.

This implementation follows Pangu-Weather: A 3D High-Resolution Model for Fast and Accurate Global Weather Forecast.

Parameters:
  • img_size (tuple[int, int], optional, default=(721, 1440)) – Spatial resolution \((H, W)\) of the latitude-longitude grid.

  • patch_size (tuple[int, int, int], optional, default=(2, 4, 4)) – Patch size \((p_l, p_h, p_w)\) for pressure-level and spatial axes.

  • embed_dim (int, optional, default=192) – Embedding channel size used throughout the transformer hierarchy.

  • num_heads (tuple[int, int, int, int], optional, default=(6, 12, 12, 6)) – Number of attention heads used at each stage.

  • window_size (tuple[int, int, int], optional, default=(2, 6, 12)) – Window size used by the transformer blocks.

Forward:

x (torch.Tensor) – Input tensor of shape \((B, 72, H, W)\) where channels are arranged as surface variables (4), static masks (3), and upper-air variables (5 × 13).

Outputs:

tuple[torch.Tensor, torch.Tensor] – Tuple (surface, upper_air) where surface has shape \((B, 4, H, W)\) and upper_air has shape \((B, 5, 13, H, W)\).

forward(
x: Float[Tensor, 'batch channels lat lon'],
) tuple[Float[Tensor, 'batch c_surface lat lon'], Float[Tensor, 'batch c_upper levels lat lon']][source]#

Run Pangu forward prediction.

Parameters:

x (torch.Tensor) – Concatenated input tensor of shape \((B, 72, H, W)\).

Returns:

Output tuple (surface, upper_air) with shapes \((B, 4, H, W)\) and \((B, 5, 13, H, W)\).

Return type:

tuple[torch.Tensor, torch.Tensor]

prepare_input(
surface: Float[Tensor, 'batch c_surface lat lon'],
surface_mask: Float[Tensor, 'c_mask lat lon'] | Float[Tensor, 'batch c_mask lat lon'],
upper_air: Float[Tensor, 'batch c_upper levels lat lon'],
) Float[Tensor, 'batch channels lat lon'][source]#

Prepare input by combining surface, static masks, and upper-air fields.

Parameters:
  • surface (torch.Tensor) – Surface tensor of shape \((B, 4, H, W)\).

  • surface_mask (torch.Tensor) – Static mask tensor of shape \((3, H, W)\) or \((B, 3, H, W)\).

  • upper_air (torch.Tensor) – Upper-air tensor of shape \((B, 5, 13, H, W)\).

Returns:

Concatenated tensor of shape \((B, 72, H, W)\).

Return type:

torch.Tensor
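The 72-channel layout can be checked with simple arithmetic (an illustrative helper, not part of the library API):

```python
def pangu_input_channels(n_surface: int = 4, n_masks: int = 3,
                         n_upper_vars: int = 5, n_levels: int = 13) -> int:
    # prepare_input concatenates surface variables (4), static masks (3),
    # and upper-air variables flattened over pressure levels (5 * 13).
    return n_surface + n_masks + n_upper_vars * n_levels

print(pangu_input_channels())  # -> 72
```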

class physicsnemo.models.swinvrnn.swinvrnn.SwinRNN(*args, **kwargs)[source]#

Bases: Module

SwinRNN weather forecasting model.

This implementation follows the SwinRNN architecture.

Parameters:
  • img_size (tuple[int, int, int], optional, default=(2, 721, 1440)) – Input size as \((T, H, W)\), where \(T\) is the number of input timesteps.

  • patch_size (tuple[int, int, int], optional, default=(2, 4, 4)) – Patch size as \((p_t, p_h, p_w)\) for cube embedding.

  • in_chans (int, optional, default=70) – Number of input channels.

  • out_chans (int, optional, default=70) – Number of output channels.

  • embed_dim (int, optional, default=1536) – Embedding channel size used by Swin blocks.

  • num_groups (int, optional, default=32) – Number of channel groups for convolutional blocks.

  • num_heads (int, optional, default=8) – Number of attention heads.

  • window_size (int, optional, default=7) – Local window size of Swin transformer blocks.

Forward:

x (torch.Tensor) – Input tensor of shape \((B, C_{in}, T, H, W)\).

Outputs:

torch.Tensor – Predicted tensor of shape \((B, C_{out}, H, W)\).

forward(
x: Float[Tensor, 'batch in_chans time lat lon'],
) Float[Tensor, 'batch out_chans lat lon'][source]#

Run SwinRNN forward prediction.

Parameters:

x (torch.Tensor) – Input tensor of shape \((B, C_{in}, T, H, W)\).

Returns:

Prediction tensor of shape \((B, C_{out}, H, W)\).

Return type:

torch.Tensor