Weather / Climate Models#
- class physicsnemo.models.dlwp.dlwp.DLWP(*args, **kwargs)[source]#
Bases:
Module
Convolutional U-Net for Deep Learning Weather Prediction on cubed-sphere grids.
This model operates on cubed-sphere data with six faces and applies face-aware padding so that convolutions respect cubed-sphere connectivity.
Based on Weyn et al. (2021).
- Parameters:
nr_input_channels (int) – Number of input channels \(C_{in}\).
nr_output_channels (int) – Number of output channels \(C_{out}\).
nr_initial_channels (int, optional, default=64) – Number of channels in the first convolution block \(C_{init}\). Defaults to 64.
activation_fn (str, optional, default="leaky_relu") – Activation name resolved with get_activation(). Defaults to "leaky_relu".
depth (int, optional, default=2) – Depth of the U-Net encoder/decoder stacks. Defaults to 2.
clamp_activation (Tuple[float | int | None, float | int | None], optional, default=(None, 10.0)) – Minimum and maximum bounds applied via torch.clamp after activation. Defaults to (None, 10.0).
- Forward:
cubed_sphere_input (torch.Tensor) – Input tensor of shape \((B, C_{in}, F, H, W)\) with \(F=6\) faces.
- Outputs:
torch.Tensor – Output tensor of shape \((B, C_{out}, F, H, W)\).
Examples
>>> import torch
>>> from physicsnemo.models import DLWP
>>> model = DLWP(nr_input_channels=2, nr_output_channels=4)
>>> inputs = torch.randn(4, 2, 6, 64, 64)
>>> outputs = model(inputs)
>>> outputs.shape
torch.Size([4, 4, 6, 64, 64])
- activation(
- x: Float[Tensor, 'batch channels height width'],
Apply activation and optional clamping to a face tensor.
- Parameters:
x (torch.Tensor) – Input face tensor of shape \((B, C, H, W)\).
- Returns:
Activated face tensor of shape \((B, C, H, W)\).
- Return type:
torch.Tensor
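The clamp_activation semantics above (a None bound leaves that side unclamped) can be illustrated with a minimal numpy sketch. This is a hypothetical reimplementation for clarity, not the model's actual code, and it assumes the default leaky_relu activation with slope 0.01:

```python
import numpy as np

def clamped_leaky_relu(x, clamp_min=None, clamp_max=10.0, negative_slope=0.01):
    """Leaky ReLU followed by optional one-sided clamping, mirroring
    clamp_activation=(None, 10.0): a None bound leaves that side unclamped."""
    y = np.where(x >= 0, x, negative_slope * x)
    if clamp_min is not None:
        y = np.maximum(y, clamp_min)
    if clamp_max is not None:
        y = np.minimum(y, clamp_max)
    return y

x = np.array([-2.0, 3.0, 50.0])
y = clamped_leaky_relu(x)  # negative values pass through scaled; 50.0 is clamped to 10.0
```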
- class physicsnemo.models.dlwp_healpix.HEALPixRecUNet.HEALPixRecUNet(*args, **kwargs)[source]#
Bases:
Module
Deep Learning Weather Prediction (DLWP) recurrent U-Net on the HEALPix mesh.
- Parameters:
encoder (DictConfig) – Instantiable configuration for the U-Net encoder block.
decoder (DictConfig) – Instantiable configuration for the U-Net decoder block.
input_channels (int) – Number of prognostic input channels per time step.
output_channels (int) – Number of prognostic output channels per time step.
n_constants (int) – Number of constant channels provided for all faces.
decoder_input_channels (int) – Number of prescribed decoder input channels per time step.
input_time_dim (int) – Number of input time steps \(T_{in}\).
output_time_dim (int) – Number of output time steps \(T_{out}\).
delta_time (str, optional) – Time difference between samples, e.g., "6h". Defaults to "6h".
reset_cycle (str, optional) – Period for recurrent state reset, e.g., "24h". Defaults to "24h".
presteps (int, optional) – Number of warm-up steps used to initialize recurrent states.
enable_nhwc (bool, optional) – If True, use channels-last tensors.
enable_healpixpad (bool, optional) – Enable CUDA HEALPix padding when available.
couplings (list, optional) – Optional coupling specifications appended to the input feature channels.
- Forward:
inputs (Sequence[torch.Tensor]) – Inputs shaped \((B, F, T_{in}, C_{in}, H, W)\) plus decoder inputs, constants, and optional coupling tensors.
output_only_last (bool, optional) – If True, return only the final forecast step.
- Outputs:
torch.Tensor – Predictions shaped \((B, F, T_{out}, C_{out}, H, W)\).
- forward(
- inputs: Sequence,
- output_only_last: bool = False,
Forward pass of the recurrent HEALPix UNet.
- Parameters:
inputs (Sequence) – List [prognostics, decoder_inputs, constants] or [prognostics, decoder_inputs, constants, couplings] with shapes consistent with \((B, F, T, C, H, W)\).
output_only_last (bool, optional) – If True, return only the final forecast step.
- Returns:
Model outputs shaped \((B, F, T_{out}, C_{out}, H, W)\).
- Return type:
torch.Tensor
- property integration_steps#
Number of implicit forward integration steps.
- Returns:
Integration horizon \(T_{out} / T_{in}\) (minimum 1).
- Return type:
int
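The integration horizon described above is a simple ratio of the output and input time dimensions, floored at 1. A minimal sketch (assuming integer division, which matches the "\(T_{out} / T_{in}\) (minimum 1)" description):

```python
def integration_steps(output_time_dim: int, input_time_dim: int) -> int:
    # Integration horizon T_out / T_in, floored at 1.
    return max(output_time_dim // input_time_dim, 1)

# E.g., forecasting 4 steps from 2 input steps requires 2 recurrent integrations.
steps = integration_steps(4, 2)
```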
- class physicsnemo.models.graphcast.graph_cast_net.GraphCastNet(*args, **kwargs)[source]#
Bases:
Module
GraphCast network architecture for global weather forecasting on an icosahedral mesh graph.
- Parameters:
mesh_level (int, optional, default=6) – Level of the latent mesh used to build the graph.
multimesh (bool, optional, default=True) – If True, the latent mesh includes nodes from all mesh levels up to and including mesh_level.
input_res (Tuple[int, int], optional, default=(721, 1440)) – Resolution (H, W) of the latitude-longitude grid.
input_dim_grid_nodes (int, optional, default=474) – Input dimensionality of the grid node features. Defaults to 474.
input_dim_mesh_nodes (int, optional, default=3) – Input dimensionality of the mesh node features. Defaults to 3.
input_dim_edges (int, optional, default=4) – Input dimensionality of the edge features. Defaults to 4.
output_dim_grid_nodes (int, optional, default=227) – Output dimensionality of the grid node features. Defaults to 227.
processor_type (Literal["MessagePassing", "GraphTransformer"], optional, default="MessagePassing") – Processor type used for the latent mesh. "GraphTransformer" uses GraphCastProcessorGraphTransformer.
khop_neighbors (int, optional, default=32) – Number of k-hop neighbors used in the graph transformer processor. Ignored when processor_type="MessagePassing". Defaults to 32.
num_attention_heads (int, optional, default=4) – Number of attention heads for the graph transformer processor. Defaults to 4.
processor_layers (int, optional, default=16) – Number of processor layers. Defaults to 16.
hidden_layers (int, optional, default=1) – Number of hidden layers in MLP blocks. Defaults to 1.
hidden_dim (int, optional, default=512) – Hidden dimension for node and edge embeddings. Defaults to 512.
aggregation (Literal["sum", "mean"], optional, default="sum") – Message passing aggregation method. Defaults to "sum".
activation_fn (str, optional, default="silu") – Activation function name passed to get_activation(). Defaults to "silu".
norm_type (Literal["TELayerNorm", "LayerNorm"], optional, default="LayerNorm") – Normalization type. "TELayerNorm" is recommended when supported. Defaults to "LayerNorm".
use_cugraphops_encoder (bool, optional, default=False) – Deprecated flag for cugraphops encoder kernels (not supported). Defaults to False.
use_cugraphops_processor (bool, optional, default=False) – Deprecated flag for cugraphops processor kernels (not supported). Defaults to False.
use_cugraphops_decoder (bool, optional, default=False) – Deprecated flag for cugraphops decoder kernels (not supported). Defaults to False.
do_concat_trick (bool, optional, default=False) – Whether to replace concat+MLP with MLP+index+sum. Defaults to False.
recompute_activation (bool, optional, default=False) – Whether to recompute activations during backward to save memory. Defaults to False.
partition_size (int, optional, default=1) – Number of process groups across which graphs are distributed. If 1, the model runs in single-GPU mode. Defaults to 1.
partition_group_name (str | None, optional, default=None) – Name of the process group across which graphs are distributed. Defaults to None.
use_lat_lon_partitioning (bool, optional, default=False) – If True, graph partitions are based on lat-lon coordinates instead of IDs. Defaults to False.
expect_partitioned_input (bool, optional, default=False) – If True, the input is already partitioned. Defaults to False.
global_features_on_rank_0 (bool, optional, default=False) – If True, global input features are provided only on rank 0 and are scattered. Defaults to False.
produce_aggregated_output (bool, optional, default=True) – Whether to gather outputs into a global tensor. Defaults to True.
produce_aggregated_output_on_all_ranks (bool, optional, default=True) – If produce_aggregated_output is True, whether to gather on all ranks or only on rank 0. Defaults to True.
graph_backend (Literal["dgl", "pyg"], optional, default="pyg") – Legacy argument selecting the backend used to build the graphs. Defaults to "pyg"; the "dgl" option is deprecated.
- Forward:
grid_nfeat (torch.Tensor) – Input grid features of shape \((B, C_{in}, H, W)\) where \(B=1\), \(C_{in} =\) input_dim_grid_nodes, and \((H, W) =\) input_res.
- Outputs:
torch.Tensor – Output grid features of shape \((B, C_{out}, H, W)\) where \(C_{out} =\) output_dim_grid_nodes.
Notes
This implementation follows the GraphCast and GenCast architectures; see GraphCast and GenCast. The graph transformer processor requires transformer-engine to be installed.
Examples
>>> import torch
>>> from physicsnemo.models.graphcast.graph_cast_net import GraphCastNet
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> model = GraphCastNet(
...     mesh_level=1,
...     input_res=(10, 20),
...     input_dim_grid_nodes=2,
...     input_dim_mesh_nodes=3,
...     input_dim_edges=4,
...     output_dim_grid_nodes=2,
...     processor_layers=3,
...     hidden_dim=4,
...     do_concat_trick=True,
... ).to(device)
>>> x = torch.randn(1, 2, 10, 20, device=device)
>>> y = model(x)
>>> y.shape
torch.Size([1, 2, 10, 20])
- custom_forward(
- grid_nfeat: Float[Tensor, 'grid_nodes grid_features'],
GraphCast forward method with gradient checkpointing support.
- Parameters:
grid_nfeat (torch.Tensor) – Grid node features of shape \((N_{grid}, C_{in})\).
- Returns:
Output grid node features of shape \((N_{grid}, C_{out})\).
- Return type:
torch.Tensor
- decoder_forward(
- mesh_efeat_processed: Float[Tensor, 'mesh_edges hidden_dim'] | None,
- mesh_nfeat_processed: Float[Tensor, 'mesh_nodes hidden_dim'],
- grid_nfeat_encoded: Float[Tensor, 'grid_nodes hidden_dim'],
Run the final processor stage, decoder, and output MLP.
- Parameters:
mesh_efeat_processed (torch.Tensor | None) – Processed mesh edge features of shape \((N_{mesh\_edge}, C_{hid})\), or None when using the graph transformer processor.
mesh_nfeat_processed (torch.Tensor) – Processed mesh node features of shape \((N_{mesh}, C_{hid})\).
grid_nfeat_encoded (torch.Tensor) – Encoded grid node features of shape \((N_{grid}, C_{hid})\).
- Returns:
Final grid node features of shape \((N_{grid}, C_{out})\).
- Return type:
torch.Tensor
- encoder_forward(
- grid_nfeat: Float[Tensor, 'grid_nodes grid_features'],
Run the embedder, encoder, and the first processor stage.
- Parameters:
grid_nfeat (torch.Tensor) – Grid node features of shape \((N_{grid}, C_{in})\).
- Returns:
mesh_efeat_processed (torch.Tensor | None) – Processed mesh edge features of shape \((N_{mesh\_edge}, C_{hid})\), or None when using the graph transformer processor.
mesh_nfeat_processed (torch.Tensor) – Processed mesh node features of shape \((N_{mesh}, C_{hid})\).
grid_nfeat_encoded (torch.Tensor) – Encoded grid node features of shape \((N_{grid}, C_{hid})\).
- forward(
- grid_nfeat: Float[Tensor, 'batch grid_features height width'],
Run the GraphCast forward pass.
- Parameters:
grid_nfeat (torch.Tensor) – Input grid features of shape \((B, C_{in}, H, W)\) with \(B=1\) when expect_partitioned_input is False.
- Returns:
Output grid features of shape \((B, C_{out}, H, W)\).
- Return type:
torch.Tensor
- prepare_input(
- invar: Float[Tensor, 'batch grid_features height width'] | Float[Tensor, 'grid_nodes grid_features'],
- expect_partitioned_input: bool,
- global_features_on_rank_0: bool,
Prepare model input in the required grid-node layout.
- Parameters:
invar (torch.Tensor) – Input grid features of shape \((B, C_{in}, H, W)\) or partitioned features of shape \((N_{grid}, C_{in})\).
expect_partitioned_input (bool) – Whether invar is already partitioned.
global_features_on_rank_0 (bool) – Whether global features are provided only on rank 0 and should be scattered.
- Returns:
Grid-node features of shape \((N_{grid}, C_{in})\).
- Return type:
torch.Tensor
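The single-GPU case of this layout change (from a \((B, C_{in}, H, W)\) grid to \((N_{grid}, C_{in})\) node features with \(N_{grid} = H \cdot W\)) can be sketched with numpy. The function name and the permute-then-flatten order are assumptions for illustration, not the library's actual implementation:

```python
import numpy as np

def grid_to_nodes(invar):
    """Flatten a (B, C_in, H, W) grid tensor into (N_grid, C_in) node
    features, with N_grid = H * W (single-sample layout, B = 1 assumed)."""
    b, c, h, w = invar.shape
    assert b == 1, "single-GPU layout assumes batch size 1"
    return invar[0].reshape(c, h * w).T  # -> (H * W, C_in)

x = np.arange(2 * 3 * 4, dtype=np.float32).reshape(1, 2, 3, 4)
nodes = grid_to_nodes(x)  # each row holds all channels of one grid point
```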
- prepare_output(
- outvar: Float[Tensor, 'grid_nodes out_features'],
- produce_aggregated_output: bool,
- produce_aggregated_output_on_all_ranks: bool = True,
Prepare model output in the required layout.
- Parameters:
outvar (torch.Tensor) – Output node features of shape \((N_{grid}, C_{out})\).
produce_aggregated_output (bool) – Whether to gather outputs to a global tensor.
produce_aggregated_output_on_all_ranks (bool, optional, default=True) – Whether to gather outputs on all ranks or only rank 0. Defaults to True
- Returns:
Output features in either global grid format \((B, C_{out}, H, W)\) or distributed node format \((N_{grid}, C_{out})\).
- Return type:
torch.Tensor
- set_checkpoint_decoder(checkpoint_flag: bool)[source]#
Set checkpointing for the decoder path.
- Parameters:
checkpoint_flag (bool) – Whether to enable checkpointing for the decoder path.
- Returns:
This method updates decoder checkpointing settings in-place.
- Return type:
None
- set_checkpoint_encoder(checkpoint_flag: bool)[source]#
Set checkpointing for the encoder path.
- Parameters:
checkpoint_flag (bool) – Whether to enable checkpointing for the encoder path.
- Returns:
This method updates encoder checkpointing settings in-place.
- Return type:
None
- set_checkpoint_model(checkpoint_flag: bool)[source]#
Set checkpointing for the entire model.
- Parameters:
checkpoint_flag (bool) – Whether to enable checkpointing using torch.utils.checkpoint.
- Returns:
This method updates internal checkpointing settings in-place.
- Return type:
None
- set_checkpoint_processor(checkpoint_segments: int)[source]#
Set checkpointing for the processor interior layers.
- Parameters:
checkpoint_segments (int) – Number of checkpoint segments. A positive value enables checkpointing.
- Returns:
This method updates processor checkpointing settings in-place.
- Return type:
None
- to(
- *args: Any,
- **kwargs: Any,
Move the model and its graph buffers to a device or dtype.
- Parameters:
*args (Any) – Positional arguments passed to torch._C._nn._parse_to.
**kwargs (Any) – Keyword arguments passed to torch._C._nn._parse_to.
- Returns:
The updated model instance.
- Return type:
GraphCastNet
- physicsnemo.models.graphcast.graph_cast_net.get_lat_lon_partition_separators(
- partition_size: int,
Compute separation intervals for lat-lon grid partitioning.
- Parameters:
partition_size (int) – Size of the graph partition.
- Returns:
The (min_seps, max_seps) coordinate separators for each partition.
- Return type:
Tuple[list[list[float | None]], list[list[float | None]]]
GraphCast Utils#
Utilities for working with the GraphCast model, including graph construction and mesh handling. These are used when implementing and training GraphCast-based weather prediction models.
- class physicsnemo.models.graphcast.utils.graph.Graph(
- lat_lon_grid: Tensor,
- mesh_level: int = 6,
- multimesh: bool = True,
- khop_neighbors: int = 0,
- dtype=torch.float32,
- backend: str = 'pyg',
Bases:
object
Graph class for creating the graph2mesh, latent mesh, and mesh2graph graphs.
- Parameters:
lat_lon_grid (Tensor) – Tensor of shape (lat, lon, 2) containing the latitude-longitude meshgrid.
mesh_level (int, optional) – Level of the latent mesh, by default 6
multimesh (bool, optional) – Whether the latent mesh is a multimesh, by default True. If True, the latent mesh includes the nodes corresponding to the specified mesh_level and incorporates the edges from all mesh levels ranging from level 0 up to and including mesh_level.
khop_neighbors (int, optional) – This option is used to retrieve a list of indices for the k-hop neighbors of all mesh nodes. It is applicable when a graph transformer is used as the processor. If set to 0, this list is not computed. If a message passing processor is used, it is forced to 0. By default 0.
dtype (torch.dtype, optional) – Data type of the graph, by default torch.float32
backend (str, optional) – Backend used to build the graphs, by default "pyg"
- create_g2m_graph(verbose: bool = True) GraphType[source]#
Create the graph2mesh graph.
- Parameters:
verbose (bool, optional) – verbosity, by default True
- Returns:
Graph2mesh graph.
- Return type:
GraphType
- physicsnemo.models.graphcast.utils.graph_utils.azimuthal_angle(lon: Tensor) Tensor[source]#
Gives the azimuthal angle of a point on the sphere
- Parameters:
lon (Tensor) – Tensor of shape (N, ) containing the longitude of the point
- Returns:
Tensor of shape (N, ) containing the azimuthal angle
- Return type:
Tensor
- physicsnemo.models.graphcast.utils.graph_utils.cell_to_adj(cells: List[List[int]])[source]#
Creates an adjacency matrix in COO format from mesh cells
- Parameters:
cells (List[List[int]]) – List of cells, each cell is a list of 3 vertices
- Returns:
src, dst – List of source and destination vertices
- Return type:
List[int], List[int]
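A minimal sketch of this conversion, assuming each triangular cell (v0, v1, v2) contributes the directed edges (v0, v1), (v1, v2), (v2, v0) to the COO lists (the actual edge orientation used by the library may differ):

```python
def cell_to_adj(cells):
    """Build COO adjacency lists (src, dst) from triangular mesh cells.
    Each cell contributes one directed edge per side of the triangle."""
    src, dst = [], []
    for v0, v1, v2 in cells:
        src.extend([v0, v1, v2])
        dst.extend([v1, v2, v0])
    return src, dst

src, dst = cell_to_adj([[0, 1, 2], [1, 2, 3]])
```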
- physicsnemo.models.graphcast.utils.graph_utils.deg2rad(deg: Tensor) Tensor[source]#
Converts degrees to radians
- Parameters:
deg – Tensor of shape (N, ) containing the degrees
- Returns:
Tensor of shape (N, ) containing the radians
- Return type:
Tensor
- physicsnemo.models.graphcast.utils.graph_utils.geospatial_rotation(
- invar: Tensor,
- theta: Tensor,
- axis: str,
- unit: str = 'rad',
Rotation using right hand rule
- Parameters:
invar (Tensor) – Tensor of shape (N, 3) containing x, y, z coordinates
theta (Tensor) – Tensor of shape (N, ) containing the rotation angle
axis (str) – Axis of rotation
unit (str, optional) – Unit of the theta, by default “rad”
- Returns:
Tensor of shape (N, 3) containing the rotated x, y, z coordinates
- Return type:
Tensor
- physicsnemo.models.graphcast.utils.graph_utils.get_face_centroids(
- vertices: List[Tuple[float, float, float]],
- faces: List[List[int]],
Compute the centroids of triangular faces in a graph.
- Parameters:
vertices (List[Tuple[float, float, float]]) – A list of tuples representing the coordinates of the vertices.
faces (List[List[int]]) – A list of lists, where each inner list contains three indices representing a triangular face.
- Returns:
A list of tuples representing the centroids of the faces.
- Return type:
List[Tuple[float, float, float]]
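The centroid of a triangle is the component-wise mean of its three vertices; a minimal sketch of the computation (hypothetical reimplementation, matching the documented signature):

```python
def get_face_centroids(vertices, faces):
    """Centroid of each triangular face = mean of its three vertex
    coordinates, computed per axis (x, y, z)."""
    return [
        tuple(sum(vertices[i][k] for i in face) / 3.0 for k in range(3))
        for face in faces
    ]

centroids = get_face_centroids([(0, 0, 0), (3, 0, 0), (0, 3, 0)], [[0, 1, 2]])
```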
- physicsnemo.models.graphcast.utils.graph_utils.latlon2xyz(
- latlon: Tensor,
- radius: float = 1,
- unit: str = 'deg',
Converts latlon in degrees to xyz. Based on: https://stackoverflow.com/questions/1185408. Conventions: the x-axis goes through (lat, lon) = (0, 0); the y-axis goes through (0, 90); the z-axis goes through the poles.
- Parameters:
latlon (Tensor) – Tensor of shape (N, 2) containing latitudes and longitudes
radius (float, optional) – Radius of the sphere, by default 1
unit (str, optional) – Unit of the latlon, by default “deg”
- Returns:
Tensor of shape (N, 3) containing x, y, z coordinates
- Return type:
Tensor
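Under the axis conventions stated above, the conversion is the standard spherical-to-Cartesian mapping. A pure-Python sketch (the real function operates on tensors; this list-based version is for illustration only):

```python
import math

def latlon2xyz(latlon, radius=1.0, unit="deg"):
    """Convert (lat, lon) pairs to xyz on a sphere: x through (0, 0),
    y through (0, 90), z through the poles."""
    out = []
    for lat, lon in latlon:
        if unit == "deg":
            lat, lon = math.radians(lat), math.radians(lon)
        out.append((
            radius * math.cos(lat) * math.cos(lon),
            radius * math.cos(lat) * math.sin(lon),
            radius * math.sin(lat),
        ))
    return out

points = latlon2xyz([(0.0, 0.0), (90.0, 0.0)])
```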
- physicsnemo.models.graphcast.utils.graph_utils.max_edge_length(
- vertices: List[List[float]],
- source_nodes: List[int],
- destination_nodes: List[int],
Compute the maximum edge length in a graph.
- Parameters:
vertices (List[List[float]]) – A list of lists representing the coordinates of the vertices.
source_nodes (List[int]) – A list of indices representing the source nodes of the edges.
destination_nodes (List[int]) – A list of indices representing the destination nodes of the edges.
- Returns:
The maximum edge length in the graph.
- Return type:
float
- physicsnemo.models.graphcast.utils.graph_utils.polar_angle(lat: Tensor) Tensor[source]#
Gives the polar angle of a point on the sphere
- Parameters:
lat (Tensor) – Tensor of shape (N, ) containing the latitude of the point
- Returns:
Tensor of shape (N, ) containing the polar angle
- Return type:
Tensor
- physicsnemo.models.graphcast.utils.graph_utils.rad2deg(rad)[source]#
Converts radians to degrees
- Parameters:
rad – Tensor of shape (N, ) containing the radians
- Returns:
Tensor of shape (N, ) containing the degrees
- Return type:
Tensor
- physicsnemo.models.graphcast.utils.graph_utils.xyz2latlon(
- xyz: Tensor,
- radius: float = 1,
- unit: str = 'deg',
Converts xyz to latlon in degrees. Based on: https://stackoverflow.com/questions/1185408. Conventions: the x-axis goes through (lat, lon) = (0, 0); the y-axis goes through (0, 90); the z-axis goes through the poles.
- Parameters:
xyz (Tensor) – Tensor of shape (N, 3) containing x, y, z coordinates
radius (float, optional) – Radius of the sphere, by default 1
unit (str, optional) – Unit of the latlon, by default “deg”
- Returns:
Tensor of shape (N, 2) containing latitudes and longitudes
- Return type:
Tensor
- class physicsnemo.models.fengwu.fengwu.Fengwu(*args, **kwargs)[source]#
Bases:
Module
FengWu weather forecasting model.
This implementation follows FengWu: Pushing the Skillful Global Medium-range Weather Forecast beyond 10 Days Lead.
- Parameters:
img_size (tuple[int, int], optional, default=(721, 1440)) – Spatial resolution \((H, W)\) of all input and output fields.
pressure_level (int, optional, default=37) – Number of pressure levels \(L\).
embed_dim (int, optional, default=192) – Embedding channel size used in encoder/decoder/fuser blocks.
patch_size (tuple[int, int], optional, default=(4, 4)) – Patch size \((p_h, p_w)\) used by the hierarchical encoder/decoder.
num_heads (tuple[int, int, int, int], optional, default=(6, 12, 12, 6)) – Number of attention heads used at each stage.
window_size (tuple[int, int, int], optional, default=(2, 6, 12)) – Window size used by the transformer blocks.
- Forward:
x (torch.Tensor) – Input tensor of shape \((B, C_{in}, H, W)\) with \(C_{in} = 4 + 5L\).
- Outputs:
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] – Tuple (surface, z, r, u, v, t) where surface has shape \((B, 4, H, W)\) and z, r, u, v, t each have shape \((B, L, H, W)\).
- forward(
- x: Float[Tensor, 'batch channels lat lon'],
Run Fengwu forward prediction.
- Parameters:
x (torch.Tensor) – Concatenated input tensor of shape \((B, 4 + 5L, H, W)\).
- Returns:
Output tuple
(surface, z, r, u, v, t)wheresurfacehas shape \((B, 4, H, W)\) and the other outputs have shape \((B, L, H, W)\).- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
- prepare_input(
- surface: Float[Tensor, 'batch c_surface lat lon'],
- z: Float[Tensor, 'batch c_pressure lat lon'],
- r: Float[Tensor, 'batch c_pressure lat lon'],
- u: Float[Tensor, 'batch c_pressure lat lon'],
- v: Float[Tensor, 'batch c_pressure lat lon'],
- t: Float[Tensor, 'batch c_pressure lat lon'],
Prepare input fields by concatenating all variables along channels.
- Parameters:
surface (torch.Tensor) – Surface tensor of shape \((B, 4, H, W)\).
z (torch.Tensor) – Geopotential tensor of shape \((B, L, H, W)\).
r (torch.Tensor) – Relative humidity tensor of shape \((B, L, H, W)\).
u (torch.Tensor) – U-wind tensor of shape \((B, L, H, W)\).
v (torch.Tensor) – V-wind tensor of shape \((B, L, H, W)\).
t (torch.Tensor) – Temperature tensor of shape \((B, L, H, W)\).
- Returns:
Concatenated tensor of shape \((B, 4 + 5L, H, W)\).
- Return type:
torch.Tensor
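The channel layout produced by prepare_input (4 surface channels followed by the five L-channel pressure-level fields, giving \(4 + 5L\) total) can be sketched with numpy. The concatenation order shown matches the documented argument order; this is an illustrative sketch, not the model's actual code:

```python
import numpy as np

def prepare_input(surface, z, r, u, v, t):
    """Concatenate all fields along the channel axis to form the
    (B, 4 + 5L, H, W) input tensor."""
    return np.concatenate([surface, z, r, u, v, t], axis=1)

B, L, H, W = 1, 3, 8, 16
surface = np.zeros((B, 4, H, W))
fields = [np.zeros((B, L, H, W)) for _ in range(5)]  # z, r, u, v, t
x = prepare_input(surface, *fields)
```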
- class physicsnemo.models.pangu.pangu.Pangu(*args, **kwargs)[source]#
Bases:
Module
Pangu weather forecasting model.
This implementation follows Pangu-Weather: A 3D High-Resolution Model for Fast and Accurate Global Weather Forecast.
- Parameters:
img_size (tuple[int, int], optional, default=(721, 1440)) – Spatial resolution \((H, W)\) of the latitude-longitude grid.
patch_size (tuple[int, int, int], optional, default=(2, 4, 4)) – Patch size \((p_l, p_h, p_w)\) for pressure-level and spatial axes.
embed_dim (int, optional, default=192) – Embedding channel size used throughout the transformer hierarchy.
num_heads (tuple[int, int, int, int], optional, default=(6, 12, 12, 6)) – Number of attention heads used at each stage.
window_size (tuple[int, int, int], optional, default=(2, 6, 12)) – Window size used by the transformer blocks.
- Forward:
x (torch.Tensor) – Input tensor of shape \((B, 72, H, W)\) where channels are arranged as surface (7) + upper_air (5*13).
- Outputs:
tuple[torch.Tensor, torch.Tensor] – Tuple (surface, upper_air) where surface has shape \((B, 4, H, W)\) and upper_air has shape \((B, 5, 13, H, W)\).
- forward(
- x: Float[Tensor, 'batch channels lat lon'],
Run Pangu forward prediction.
- Parameters:
x (torch.Tensor) – Concatenated input tensor of shape \((B, 72, H, W)\).
- Returns:
Output tuple
(surface, upper_air)with shapes \((B, 4, H, W)\) and \((B, 5, 13, H, W)\).- Return type:
tuple[torch.Tensor, torch.Tensor]
- prepare_input(
- surface: Float[Tensor, 'batch c_surface lat lon'],
- surface_mask: Float[Tensor, 'c_mask lat lon'] | Float[Tensor, 'batch c_mask lat lon'],
- upper_air: Float[Tensor, 'batch c_upper levels lat lon'],
Prepare input by combining surface, static masks, and upper-air fields.
- Parameters:
surface (torch.Tensor) – Surface tensor of shape \((B, 4, H, W)\).
surface_mask (torch.Tensor) – Static mask tensor of shape \((3, H, W)\) or \((B, 3, H, W)\).
upper_air (torch.Tensor) – Upper-air tensor of shape \((B, 5, 13, H, W)\).
- Returns:
Concatenated tensor of shape \((B, 72, H, W)\).
- Return type:
torch.Tensor
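The 72-channel layout (4 surface + 3 static mask + 5*13 = 65 flattened upper-air channels) can be sketched with numpy. The exact concatenation order is an assumption for illustration; only the channel counts come from the documentation above:

```python
import numpy as np

def prepare_input(surface, surface_mask, upper_air):
    """Combine surface (B, 4, H, W), static masks (3, H, W) or
    (B, 3, H, W), and upper-air (B, 5, 13, H, W) into (B, 72, H, W)."""
    b = surface.shape[0]
    if surface_mask.ndim == 3:  # (3, H, W) -> (B, 3, H, W)
        surface_mask = np.broadcast_to(surface_mask, (b,) + surface_mask.shape)
    # Flatten the (5, 13) variable/level axes into 65 channels.
    upper = upper_air.reshape(b, -1, *upper_air.shape[-2:])
    return np.concatenate([surface, surface_mask, upper], axis=1)

B, H, W = 1, 8, 16
x = prepare_input(np.zeros((B, 4, H, W)), np.zeros((3, H, W)),
                  np.zeros((B, 5, 13, H, W)))
```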
- class physicsnemo.models.swinvrnn.swinvrnn.SwinRNN(*args, **kwargs)[source]#
Bases:
Module
SwinRNN weather forecasting model.
This implementation follows SwinRNN.
- Parameters:
img_size (tuple[int, int, int], optional, default=(2, 721, 1440)) – Input size as \((T, H, W)\), where \(T\) is the number of input timesteps.
patch_size (tuple[int, int, int], optional, default=(2, 4, 4)) – Patch size as \((p_t, p_h, p_w)\) for cube embedding.
in_chans (int, optional, default=70) – Number of input channels.
out_chans (int, optional, default=70) – Number of output channels.
embed_dim (int, optional, default=1536) – Embedding channel size used by Swin blocks.
num_groups (int, optional, default=32) – Number of channel groups for convolutional blocks.
num_heads (int, optional, default=8) – Number of attention heads.
window_size (int, optional, default=7) – Local window size of Swin transformer blocks.
- Forward:
x (torch.Tensor) – Input tensor of shape \((B, C_{in}, T, H, W)\).
- Outputs:
torch.Tensor – Predicted tensor of shape \((B, C_{out}, H, W)\).
- forward(
- x: Float[Tensor, 'batch in_chans time lat lon'],
Run SwinRNN forward prediction.
- Parameters:
x (torch.Tensor) – Input tensor of shape \((B, C_{in}, T, H, W)\).
- Returns:
Prediction tensor of shape \((B, C_{out}, H, W)\).
- Return type:
torch.Tensor