

Source code for physicsnemo.utils.generative.stochastic_sampler

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Callable, Optional

import torch
from torch import Tensor

from physicsnemo.utils.patching import GridPatching2D


def stochastic_sampler(
    net: torch.nn.Module,
    latents: Tensor,
    img_lr: Tensor,
    class_labels: Optional[Tensor] = None,
    randn_like: Callable[[Tensor], Tensor] = torch.randn_like,
    patching: Optional[GridPatching2D] = None,
    mean_hr: Optional[Tensor] = None,
    lead_time_label: Optional[Tensor] = None,
    num_steps: int = 18,
    sigma_min: float = 0.002,
    sigma_max: float = 800,
    rho: float = 7,
    S_churn: float = 0,
    S_min: float = 0,
    S_max: float = float("inf"),
    S_noise: float = 1,
) -> Tensor:
    """
    Proposed EDM sampler (Algorithm 2) with minor changes to enable
    super-resolution and patch-based diffusion.

    Parameters
    ----------
    net : torch.nn.Module
        The neural network model that generates denoised images from noisy
        inputs. Expected signature:
        `net(x, x_lr, t_hat, class_labels, lead_time_label=lead_time_label,
        embedding_selector=embedding_selector)`, where:

            x (torch.Tensor): Noisy input of shape (batch_size, C_out, H, W)
            x_lr (torch.Tensor): Conditioning input of shape
                (batch_size, C_cond, H, W)
            t_hat (torch.Tensor): Noise level of shape (batch_size, 1, 1, 1)
                or scalar
            class_labels (torch.Tensor, optional): Optional class labels
            lead_time_label (torch.Tensor, optional): Optional lead time labels
            embedding_selector (callable, optional): Function to select
                positional embeddings. Used for patch-based diffusion.

        Returns:
            torch.Tensor: Denoised prediction of shape
            (batch_size, C_out, H, W)

        Required attributes:
            sigma_min (float): Minimum supported noise level for the model
            sigma_max (float): Maximum supported noise level for the model
            round_sigma (callable): Method to convert sigma values to tensor
                representation

    latents : Tensor
        The latent variables (e.g., noise) used as the initial input for the
        sampler. Has shape (batch_size, C_out, img_shape_y, img_shape_x).
    img_lr : Tensor
        Low-resolution input image for conditioning the super-resolution
        process. Must have shape
        (batch_size, C_lr, img_lr_shape_y, img_lr_shape_x).
    class_labels : Optional[Tensor], optional
        Class labels for conditional generation, if required by the model.
        By default None.
    randn_like : Callable[[Tensor], Tensor]
        Function to generate random noise with the same shape as the input
        tensor. By default torch.randn_like.
    patching : Optional[GridPatching2D], optional
        A patching utility for patch-based diffusion. Implements methods to
        extract patches from an image and batch the patches along `dim=0`.
        Should also implement a `fuse` method to reconstruct the original
        image from a batch of patches.
        See :class:`physicsnemo.utils.patching.GridPatching2D` for details.
        By default None, in which case non-patched diffusion is used.
    mean_hr : Optional[Tensor], optional
        Optional tensor containing mean high-resolution images for
        conditioning. Must have the same height and width as `img_lr`, with
        shape (B_hr, C_hr, img_lr_shape_y, img_lr_shape_x), where the batch
        dimension B_hr can be 1, can be equal to batch_size, or can be
        omitted. If B_hr is 1 or omitted, `mean_hr` is expanded to match the
        shape of `img_lr`. By default None.
    lead_time_label : Optional[Tensor], optional
        Optional lead time labels. By default None.
    num_steps : int
        Number of time steps for the sampler. By default 18.
    sigma_min : float
        Minimum noise level. By default 0.002.
    sigma_max : float
        Maximum noise level. By default 800.
    rho : float
        Exponent used in the time step discretization. By default 7.
    S_churn : float
        Churn parameter controlling the level of noise added in each step.
        By default 0.
    S_min : float
        Minimum time step for applying churn. By default 0.
    S_max : float
        Maximum time step for applying churn. By default float("inf").
    S_noise : float
        Noise scaling factor applied during the churn step. By default 1.

    Returns
    -------
    Tensor
        The final denoised image produced by the sampler. Same shape as
        `latents`: (batch_size, C_out, img_shape_y, img_shape_x).

    See Also
    --------
    :class:`physicsnemo.models.diffusion.EDMPrecondSuperResolution`: A model
    wrapper that provides preconditioning for super-resolution diffusion
    models and implements the required interface for this sampler.
    """
    # Proposed EDM sampler (Algorithm 2) with minor changes to enable
    # super-resolution and patch-based diffusion.
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Safety check on the type of patching
    if patching is not None and not isinstance(patching, GridPatching2D):
        raise ValueError("patching must be an instance of GridPatching2D.")

    # Safety checks: if patching is used, img_lr and latents must have the
    # same height and width, otherwise there is a mismatch in the number of
    # patches extracted to form the final batch_size.
    if patching:
        if img_lr.shape[-2:] != latents.shape[-2:]:
            raise ValueError(
                f"img_lr and latents must have the same height and width, "
                f"but found {img_lr.shape[-2:]} vs {latents.shape[-2:]}."
            )
        # img_lr and latents must also have the same batch size, otherwise
        # there is a mismatch when they are processed by the network.
        if img_lr.shape[0] != latents.shape[0]:
            raise ValueError(
                f"img_lr and latents must have the same batch size, but found "
                f"{img_lr.shape[0]} vs {latents.shape[0]}."
            )

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_steps = (
        sigma_max ** (1 / rho)
        + step_indices
        / (num_steps - 1)
        * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho
    t_steps = torch.cat(
        [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]
    )  # t_N = 0

    batch_size = img_lr.shape[0]
    # conditioning = [mean_hr, img_lr, global_lr, pos_embd]
    x_lr = img_lr
    if mean_hr is not None:
        if mean_hr.shape[-2:] != img_lr.shape[-2:]:
            raise ValueError(
                f"mean_hr and img_lr must have the same height and width, "
                f"but found {mean_hr.shape[-2:]} vs {img_lr.shape[-2:]}."
            )
        x_lr = torch.cat((mean_hr.expand(x_lr.shape[0], -1, -1, -1), x_lr), dim=1)

    # Input and position padding + patching
    if patching:
        # Patched conditioning [x_lr, mean_hr]
        # (batch_size * patch_num, C_in + C_out, patch_shape_y, patch_shape_x)
        x_lr = patching.apply(input=x_lr, additional_input=img_lr)

        # Function to select the correct positional embedding for each patch
        def patch_embedding_selector(emb):
            # emb: (N_pe, image_shape_y, image_shape_x)
            # return: (batch_size * patch_num, N_pe, patch_shape_y, patch_shape_x)
            return patching.apply(emb[None].expand(batch_size, -1, -1, -1))

    else:
        patch_embedding_selector = None

    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = S_churn / num_steps if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur)
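        # With gamma > 0 the noise level is temporarily raised from t_cur to
        # t_hat = (1 + gamma) * t_cur, and fresh noise with standard deviation
        # sqrt(t_hat**2 - t_cur**2) * S_noise is added so that x_hat sits
        # (approximately) at noise level t_hat. The Euler step and the 2nd
        # order correction below then integrate the EDM probability-flow ODE,
        # whose slope is (x - denoised) / sigma, from t_hat down to t_next
        # (Algorithm 2 in the EDM paper).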
        # Euler step. If patch-based generation is used, apply patching to the
        # noisy input before evaluating the network.
        x_hat_batch = (patching.apply(input=x_hat) if patching else x_hat).to(
            latents.device
        )
        x_lr = x_lr.to(latents.device)

        if lead_time_label is not None:
            denoised = net(
                x_hat_batch,
                x_lr,
                t_hat,
                class_labels,
                lead_time_label=lead_time_label,
                embedding_selector=patch_embedding_selector,
            ).to(torch.float64)
        else:
            denoised = net(
                x_hat_batch,
                x_lr,
                t_hat,
                class_labels,
                embedding_selector=patch_embedding_selector,
            ).to(torch.float64)
        if patching:
            # Un-patch the denoised image
            # (batch_size, C_out, img_shape_y, img_shape_x)
            denoised = patching.fuse(input=denoised, batch_size=batch_size)

        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction.
        if i < num_steps - 1:
            # Patched input
            # (batch_size * patch_num, C_out, patch_shape_y, patch_shape_x)
            x_next_batch = (patching.apply(input=x_next) if patching else x_next).to(
                latents.device
            )

            if lead_time_label is not None:
                denoised = net(
                    x_next_batch,
                    x_lr,
                    t_next,
                    class_labels,
                    lead_time_label=lead_time_label,
                    embedding_selector=patch_embedding_selector,
                ).to(torch.float64)
            else:
                denoised = net(
                    x_next_batch,
                    x_lr,
                    t_next,
                    class_labels,
                    embedding_selector=patch_embedding_selector,
                ).to(torch.float64)
            if patching:
                # Un-patch the denoised image
                # (batch_size, C_out, img_shape_y, img_shape_x)
                denoised = patching.fuse(input=denoised, batch_size=batch_size)

            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next
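
The following is a minimal usage sketch, not part of the PhysicsNeMo source. `ToyDenoiser` is a hypothetical stand-in that only implements the interface documented above (the call signature plus the `sigma_min`, `sigma_max`, and `round_sigma` attributes); in practice you would pass a trained preconditioned model such as `EDMPrecondSuperResolution`, and optionally a `GridPatching2D` instance for patch-based generation.

import torch

# Import path follows this module's location shown above.
from physicsnemo.utils.generative.stochastic_sampler import stochastic_sampler


class ToyDenoiser(torch.nn.Module):
    """Hypothetical stand-in implementing the interface expected by the sampler."""

    sigma_min = 0.002
    sigma_max = 800.0

    def round_sigma(self, sigma):
        # Convert sigma values to a tensor representation.
        return torch.as_tensor(sigma)

    def forward(self, x, x_lr, t_hat, class_labels=None, **kwargs):
        # Trivial "denoiser": predicts zeros, so the sampler anneals the
        # latents toward zero. A real model returns a denoised estimate.
        return torch.zeros_like(x)


net = ToyDenoiser()
latents = torch.randn(2, 3, 64, 64)  # (batch_size, C_out, H, W)
img_lr = torch.randn(2, 5, 64, 64)   # (batch_size, C_lr, H, W)
samples = stochastic_sampler(net=net, latents=latents, img_lr=img_lr, num_steps=8)
print(samples.shape)  # torch.Size([2, 3, 64, 64])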