What can I help you with?
NVIDIA PhysicsNeMo Core (Latest Release)

deeplearning/physicsnemo/physicsnemo-core/_modules/physicsnemo/utils/corrdiff/utils.html

Source code for physicsnemo.utils.corrdiff.utils

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
from typing import Optional

import cftime
import nvtx
import torch
import tqdm

from physicsnemo.utils.generative import StackedRandomGenerator, time_range

############################################################################
#                     CorrDiff Generation Utilities                        #
############################################################################


def regression_step(
    net: torch.nn.Module,
    img_lr: torch.Tensor,
    latents_shape: torch.Size,
    lead_time_label: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the regression U-Net once and broadcast the result over the batch.

    A single low-resolution instance is pushed through ``net`` to obtain the
    ensemble-mean prediction; if ``latents_shape`` requests a batch larger
    than one, the single prediction is replicated along the batch axis.

    Parameters
    ----------
    net : torch.nn.Module
        U-Net model for regression.
    img_lr : torch.Tensor
        Low-resolution conditioning input of shape (1, channels, height,
        width). The batch dimension must be exactly 1.
    latents_shape : torch.Size
        Target shape (batch_size, out_channels, image_shape_y, image_shape_x).
    lead_time_label : Optional[torch.Tensor], optional
        Lead-time conditioning label of shape (1, lead_time_dims), by default
        None.

    Returns
    -------
    torch.Tensor
        Ensemble-mean prediction with shape ``latents_shape``.

    Raises
    ------
    ValueError
        If ``img_lr`` has a batch size greater than 1.
    """
    # Reject multi-sample inputs up front; only one instance is processed.
    if img_lr.shape[0] > 1:
        raise ValueError(
            f"Expected img_lr to have a batch size of 1, "
            f"but found {img_lr.shape[0]}."
        )

    # Zero-valued latents serve as the dummy noise input for the regression net.
    zeros = torch.zeros(latents_shape, dtype=torch.float64, device=net.device)

    # Single-instance forward pass; gradients are never needed here.
    with torch.inference_mode():
        call_kwargs = {"x": zeros[0:1], "img_lr": img_lr}
        if lead_time_label is not None:
            call_kwargs["lead_time_label"] = lead_time_label
        prediction = net(**call_kwargs)

    # Replicate the single prediction across the requested batch dimension.
    if zeros.shape[0] > 1:
        prediction = prediction.repeat([zeros.shape[0]] + [1] * (zeros.ndim - 1))
    return prediction
def diffusion_step(
    net: torch.nn.Module,
    sampler_fn: callable,
    img_shape: tuple,
    img_out_channels: int,
    rank_batches: list,
    img_lr: torch.Tensor,
    rank: int,
    device: torch.device,
    mean_hr: torch.Tensor = None,
    lead_time_label: torch.Tensor = None,
) -> torch.Tensor:
    """Generate high-resolution samples with the diffusion model.

    For every seed batch in ``rank_batches``, fresh latents are drawn from a
    per-seed random generator and ``sampler_fn`` is run conditioned on
    ``img_lr`` (and optionally ``mean_hr`` / ``lead_time_label``). All
    generated sample batches are concatenated along the batch axis.

    Parameters
    ----------
    net : torch.nn.Module
        The diffusion model network.
    sampler_fn : callable
        Sampling routine invoked as
        ``sampler_fn(net, latents, img_lr, randn_like=..., **conditioning)``.
    img_shape : tuple
        Spatial size of the images, (height, width).
    img_out_channels : int
        Number of output channels for the generated images.
    rank_batches : list
        Batches of random seeds assigned to this rank.
    img_lr : torch.Tensor
        Low-resolution input of shape (seed_batch_size, channels_lr, height,
        width).
    rank : int
        Rank of the current process; the progress bar is shown only on rank 0.
    device : torch.device
        Device used for latent generation.
    mean_hr : torch.Tensor, optional
        High-resolution mean conditioning of shape (1, channels_hr, height,
        width). Default is None.
    lead_time_label : torch.Tensor, optional
        Lead-time conditioning labels of shape (batch_size, lead_time_dims).
        Default is None.

    Returns
    -------
    torch.Tensor
        Generated images concatenated over all seed batches, shape
        (seed_batch_size * len(rank_batches), out_channels, height, width).

    Raises
    ------
    ValueError
        If the spatial shape of ``img_lr`` or ``mean_hr`` does not match
        ``img_shape``, or if ``mean_hr`` has a batch size other than 1.
    """
    # Validate spatial dimensions of the conditioning inputs.
    if img_lr.shape[2:] != img_shape:
        raise ValueError(
            f"img_lr shape {img_lr.shape[2:]} does not match expected shape img_shape {img_shape}"
        )
    if mean_hr is not None:
        if mean_hr.shape[2:] != img_shape:
            raise ValueError(
                f"mean_hr shape {mean_hr.shape[2:]} does not match expected shape img_shape {img_shape}"
            )
        if mean_hr.shape[0] != 1:
            raise ValueError(f"mean_hr must have batch size 1, got {mean_hr.shape[0]}")

    # channels_last memory layout tends to be faster for conv-heavy networks.
    img_lr = img_lr.to(memory_format=torch.channels_last)

    # Optional conditioning is forwarded verbatim to the sampler.
    conditioning = {}
    if mean_hr is not None:
        conditioning["mean_hr"] = mean_hr
    if lead_time_label is not None:
        conditioning["lead_time_label"] = lead_time_label

    generated = []
    for seeds in tqdm.tqdm(rank_batches, unit="batch", disable=(rank != 0)):
        with nvtx.annotate(f"generate {len(generated)}", color="rapids"):
            if len(seeds) == 0:
                continue
            # One independent RNG stream per seed keeps samples reproducible.
            rng = StackedRandomGenerator(device, seeds)
            latents = rng.randn(
                [
                    img_lr.shape[0],
                    img_out_channels,
                    img_shape[0],
                    img_shape[1],
                ],
                device=device,
            ).to(memory_format=torch.channels_last)
            with torch.inference_mode():
                images = sampler_fn(
                    net, latents, img_lr, randn_like=rng.randn_like, **conditioning
                )
            generated.append(images)
    return torch.cat(generated)

############################################################################ # CorrDiff writer utilities # ############################################################################

[docs]class NetCDFWriter: """NetCDF Writer""" def __init__( self, f, lat, lon, input_channels, output_channels, has_lead_time=False ): self._f = f self.has_lead_time = has_lead_time # create unlimited dimensions f.createDimension("time") f.createDimension("ensemble") if lat.shape != lon.shape: raise ValueError("lat and lon must have the same shape") ny, nx = lat.shape # create lat/lon grid f.createDimension("x", nx) f.createDimension("y", ny) v = f.createVariable("lat", "f", dimensions=("y", "x")) # NOTE rethink this for datasets whose samples don't have constant lat-lon. v[:] = lat v.standard_name = "latitude" v.units = "degrees_north" v = f.createVariable("lon", "f", dimensions=("y", "x")) v[:] = lon v.standard_name = "longitude" v.units = "degrees_east" # create time dimension if has_lead_time: v = f.createVariable("time", "str", ("time")) else: v = f.createVariable("time", "i8", ("time")) v.calendar = "standard" v.units = "hours since 1990-01-01 00:00:00" self.truth_group = f.createGroup("truth") self.prediction_group = f.createGroup("prediction") self.input_group = f.createGroup("input") for variable in output_channels: name = variable.name + variable.level self.truth_group.createVariable(name, "f", dimensions=("time", "y", "x")) self.prediction_group.createVariable( name, "f", dimensions=("ensemble", "time", "y", "x") ) # setup input data in netCDF for variable in input_channels: name = variable.name + variable.level self.input_group.createVariable(name, "f", dimensions=("time", "y", "x"))
[docs] def write_input(self, channel_name, time_index, val): """Write input data to NetCDF file.""" self.input_group[channel_name][time_index] = val
[docs] def write_truth(self, channel_name, time_index, val): """Write ground truth data to NetCDF file.""" self.truth_group[channel_name][time_index] = val
[docs] def write_prediction(self, channel_name, time_index, ensemble_index, val): """Write prediction data to NetCDF file.""" self.prediction_group[channel_name][ensemble_index, time_index] = val
[docs] def write_time(self, time_index, time): """Write time information to NetCDF file.""" if self.has_lead_time: self._f["time"][time_index] = time else: time_v = self._f["time"] self._f["time"][time_index] = cftime.date2num( time, time_v.units, time_v.calendar )

############################################################################ # CorrDiff time utilities # ############################################################################

def get_time_from_range(times_range, time_format="%Y-%m-%dT%H:%M:%S"):
    """Expand a [start, end, interval?] range spec into formatted time strings.

    Args:
        times_range: A list of [start, end] or [start, end, interval_hours];
            when the third element is omitted the interval defaults to 1 hour.
        time_format: strptime/strftime format used both to parse the range
            bounds and to format the output (default: "%Y-%m-%dT%H:%M:%S").

    Returns:
        A list of formatted timestamps from start to end, inclusive.
    """
    begin = datetime.datetime.strptime(times_range[0], time_format)
    finish = datetime.datetime.strptime(times_range[1], time_format)
    # Optional third element overrides the default one-hour step.
    hours = times_range[2] if len(times_range) > 2 else 1
    step = datetime.timedelta(hours=hours)
    return [
        moment.strftime(time_format)
        for moment in time_range(begin, finish, step, inclusive=True)
    ]
© Copyright 2023, NVIDIA PhysicsNeMo Team. Last updated on Jun 11, 2025.