deeplearning/physicsnemo/physicsnemo-core/_modules/physicsnemo/utils/corrdiff/utils.html
Source code for physicsnemo.utils.corrdiff.utils
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
from typing import Optional
import cftime
import nvtx
import torch
import tqdm
from physicsnemo.utils.generative import StackedRandomGenerator, time_range
############################################################################
# CorrDiff Generation Utilities #
############################################################################
def regression_step(
    net: torch.nn.Module,
    img_lr: torch.Tensor,
    latents_shape: torch.Size,
    lead_time_label: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Perform a regression step to produce ensemble mean prediction.

    This function takes a low-resolution input and performs a regression step to
    produce an ensemble mean prediction. The network is evaluated on a single
    instance only; the result is then replicated along the batch dimension when
    ``latents_shape`` requests more than one sample.

    Parameters
    ----------
    net : torch.nn.Module
        U-Net model for regression. It must expose a ``device`` attribute, which
        is used to place the zero-valued placeholder input.
    img_lr : torch.Tensor
        Low-resolution input to the network with shape (1, channels, height, width).
        Must have a batch dimension of 1.
    latents_shape : torch.Size
        Shape of the latent representation with format
        (batch_size, out_channels, image_shape_y, image_shape_x).
    lead_time_label : Optional[torch.Tensor], optional
        Lead time label tensor for lead time conditioning,
        with shape (1, lead_time_dims). Default is None, in which case the
        network is called without a ``lead_time_label`` keyword.

    Returns
    -------
    torch.Tensor
        Predicted ensemble mean at the next time step with shape matching latents_shape.

    Raises
    ------
    ValueError
        If img_lr has a batch size greater than 1.
    """
    # Safety check: avoid silently ignoring batch elements in img_lr. This runs
    # before any allocation so that invalid input fails without touching memory.
    if img_lr.shape[0] > 1:
        raise ValueError(
            f"Expected img_lr to have a batch size of 1, "
            f"but found {img_lr.shape[0]}."
        )

    batch_size = latents_shape[0]

    # Only one sample is ever passed to the network, so the zero-valued
    # placeholder is allocated with a batch dimension of 1 instead of the full
    # requested batch size.
    x_hat = torch.zeros(
        (1, *latents_shape[1:]), dtype=torch.float64, device=net.device
    )

    # Perform regression on a single batch element
    with torch.inference_mode():
        if lead_time_label is not None:
            x = net(x=x_hat, img_lr=img_lr, lead_time_label=lead_time_label)
        else:
            x = net(x=x_hat, img_lr=img_lr)

    # If the requested batch size is greater than 1, tile the single prediction
    # along the batch dimension and leave every other dimension untouched.
    if batch_size > 1:
        x = x.repeat([batch_size] + [1] * (len(latents_shape) - 1))
    return x
def diffusion_step(
    net: torch.nn.Module,
    sampler_fn: callable,
    img_shape: tuple,
    img_out_channels: int,
    rank_batches: list,
    img_lr: torch.Tensor,
    rank: int,
    device: torch.device,
    mean_hr: Optional[torch.Tensor] = None,
    lead_time_label: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Generate images using diffusion techniques as described in the relevant paper.

    This function applies a diffusion model to generate high-resolution images based on
    low-resolution inputs. It supports optional conditioning on high-resolution mean
    predictions and lead time labels.

    For each low-resolution sample in `img_lr`, the function generates multiple
    high-resolution samples, with different random seeds, specified in `rank_batches`.
    The function then concatenates these high-resolution samples across the batch dimension.

    Parameters
    ----------
    net : torch.nn.Module
        The diffusion model network.
    sampler_fn : callable
        Function used to sample images from the diffusion model. It is called as
        ``sampler_fn(net, latents, img_lr, randn_like=..., **extras)`` where
        ``extras`` holds ``mean_hr`` and/or ``lead_time_label`` when provided.
    img_shape : tuple
        Shape of the images, (height, width). Any sequence of two integers is
        accepted; it is converted to a tuple before being compared.
    img_out_channels : int
        Number of output channels for the image.
    rank_batches : list
        List of batches of seeds to process. Empty batches are skipped.
    img_lr : torch.Tensor
        Low-resolution input image with shape (seed_batch_size, channels_lr, height, width).
    rank : int
        Rank of the current process for distributed processing. The progress
        bar is only shown when ``rank`` is 0.
    device : torch.device
        Device to perform computations.
    mean_hr : torch.Tensor, optional
        High-resolution mean tensor to be used as an additional input,
        with shape (1, channels_hr, height, width). Default is None.
    lead_time_label : torch.Tensor, optional
        Lead time label tensor for temporal conditioning,
        with shape (batch_size, lead_time_dims). Default is None.

    Returns
    -------
    torch.Tensor
        Generated images concatenated across batches with shape
        (seed_batch_size * len(rank_batches), out_channels, height, width).

    Raises
    ------
    ValueError
        If the spatial dimensions of ``img_lr`` or ``mean_hr`` differ from
        ``img_shape``, or if ``mean_hr`` has a batch size other than 1.
    """
    # torch.Size is a tuple subclass, so comparing it with a list is always
    # "not equal". Normalise both sides to plain tuples so a list-valued
    # img_shape does not trigger a spurious shape error.
    img_shape = tuple(img_shape)

    # Check img_lr dimensions match expected shape
    if tuple(img_lr.shape[2:]) != img_shape:
        raise ValueError(
            f"img_lr shape {img_lr.shape[2:]} does not match expected shape img_shape {img_shape}"
        )

    # Check mean_hr dimensions if provided
    if mean_hr is not None:
        if tuple(mean_hr.shape[2:]) != img_shape:
            raise ValueError(
                f"mean_hr shape {mean_hr.shape[2:]} does not match expected shape img_shape {img_shape}"
            )
        if mean_hr.shape[0] != 1:
            raise ValueError(f"mean_hr must have batch size 1, got {mean_hr.shape[0]}")

    img_lr = img_lr.to(memory_format=torch.channels_last)

    # Optional conditioning inputs are only forwarded to the sampler when given
    additional_args = {}
    if mean_hr is not None:
        additional_args["mean_hr"] = mean_hr
    if lead_time_label is not None:
        additional_args["lead_time_label"] = lead_time_label

    # Loop over batches
    all_images = []
    for batch_seeds in tqdm.tqdm(rank_batches, unit="batch", disable=(rank != 0)):
        with nvtx.annotate(f"generate {len(all_images)}", color="rapids"):
            batch_size = len(batch_seeds)
            if batch_size == 0:
                continue

            # Initialize random generator, and generate latents
            rnd = StackedRandomGenerator(device, batch_seeds)
            latents = rnd.randn(
                [
                    img_lr.shape[0],
                    img_out_channels,
                    img_shape[0],
                    img_shape[1],
                ],
                device=device,
            ).to(memory_format=torch.channels_last)

            with torch.inference_mode():
                images = sampler_fn(
                    net, latents, img_lr, randn_like=rnd.randn_like, **additional_args
                )
            all_images.append(images)
    return torch.cat(all_images)


############################################################################
# CorrDiff writer utilities #
############################################################################
class NetCDFWriter:
    """Lay out a NetCDF file for CorrDiff output and write samples into it.

    On construction the writer creates the dimensions, the ``lat``/``lon``
    coordinate variables, the ``time`` variable and three groups (``truth``,
    ``prediction`` and ``input``) on the file handle ``f``. One variable is
    created per channel, named ``channel.name + channel.level``.

    Parameters
    ----------
    f
        File-like handle providing ``createDimension``, ``createVariable``,
        ``createGroup`` and item access by variable name.
    lat, lon
        Two-dimensional coordinate grids; their ``shape`` attributes must match.
    input_channels, output_channels
        Iterables of channel descriptors exposing ``name`` and ``level``.
    has_lead_time : bool, optional
        When true the ``time`` variable is created with type ``"str"`` and
        time values are stored as given; otherwise it is an ``"i8"`` variable
        and values are encoded with ``cftime.date2num``.

    Raises
    ------
    ValueError
        If ``lat`` and ``lon`` have different shapes.
    """

    def __init__(
        self, f, lat, lon, input_channels, output_channels, has_lead_time=False
    ):
        self._f = f
        self.has_lead_time = has_lead_time

        # create unlimited dimensions (no size is given)
        for unlimited_dim in ("time", "ensemble"):
            f.createDimension(unlimited_dim)

        if lat.shape != lon.shape:
            raise ValueError("lat and lon must have the same shape")
        ny, nx = lat.shape

        # create lat/lon grid
        f.createDimension("x", nx)
        f.createDimension("y", ny)

        # NOTE rethink this for datasets whose samples don't have constant lat-lon.
        grid_dims = ("y", "x")
        coordinates = (
            ("lat", lat, "latitude", "degrees_north"),
            ("lon", lon, "longitude", "degrees_east"),
        )
        for coord_name, coord_values, standard_name, units in coordinates:
            coord_var = f.createVariable(coord_name, "f", dimensions=grid_dims)
            coord_var[:] = coord_values
            coord_var.standard_name = standard_name
            coord_var.units = units

        # create time variable; only the numeric form carries calendar metadata
        if not has_lead_time:
            time_var = f.createVariable("time", "i8", "time")
            time_var.calendar = "standard"
            time_var.units = "hours since 1990-01-01 00:00:00"
        else:
            f.createVariable("time", "str", "time")

        self.truth_group = f.createGroup("truth")
        self.prediction_group = f.createGroup("prediction")
        self.input_group = f.createGroup("input")

        # output channels get a truth field and an ensemble of predictions
        sample_dims = ("time", "y", "x")
        ensemble_dims = ("ensemble",) + sample_dims
        for channel in output_channels:
            key = channel.name + channel.level
            self.truth_group.createVariable(key, "f", dimensions=sample_dims)
            self.prediction_group.createVariable(key, "f", dimensions=ensemble_dims)

        # setup input data in netCDF
        for channel in input_channels:
            key = channel.name + channel.level
            self.input_group.createVariable(key, "f", dimensions=sample_dims)

    def write_input(self, channel_name, time_index, val):
        """Store ``val`` at ``time_index`` of the input variable ``channel_name``."""
        target = self.input_group[channel_name]
        target[time_index] = val

    def write_truth(self, channel_name, time_index, val):
        """Store ``val`` at ``time_index`` of the truth variable ``channel_name``."""
        target = self.truth_group[channel_name]
        target[time_index] = val

    def write_prediction(self, channel_name, time_index, ensemble_index, val):
        """Store ``val`` at ``(ensemble_index, time_index)`` of the prediction variable."""
        target = self.prediction_group[channel_name]
        target[ensemble_index, time_index] = val

    def write_time(self, time_index, time):
        """Store ``time`` at ``time_index``, encoding it numerically unless lead times are used."""
        time_var = self._f["time"]
        if self.has_lead_time:
            time_var[time_index] = time
            return
        time_var[time_index] = cftime.date2num(
            time, time_var.units, time_var.calendar
        )


############################################################################
# CorrDiff time utilities #
############################################################################
def get_time_from_range(times_range, time_format="%Y-%m-%dT%H:%M:%S"):
    """
    Generate a list of formatted time strings spanning a given range.

    Parameters
    ----------
    times_range : sequence
        ``[start, end]`` or ``[start, end, interval]``. ``start`` and ``end``
        are strings parsed with ``time_format``; ``interval`` is a step in
        hours and defaults to 1 when omitted.
    time_format : str, optional
        ``strptime``/``strftime`` format used both to parse the inputs and to
        render the outputs. Default is ``"%Y-%m-%dT%H:%M:%S"``.

    Returns
    -------
    list of str
        The times yielded by ``time_range`` (called with ``inclusive=True``),
        each rendered with ``time_format``.
    """
    parse = datetime.datetime.strptime
    first = parse(times_range[0], time_format)
    last = parse(times_range[1], time_format)

    # A third entry, when present, overrides the default one-hour step.
    step_hours = times_range[2] if len(times_range) > 2 else 1
    step = datetime.timedelta(hours=step_hours)

    return [
        stamp.strftime(time_format)
        for stamp in time_range(first, last, step, inclusive=True)
    ]