# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional

import torch
from torch import Tensor

from physicsnemo.utils.patching import GridPatching2D


def stochastic_sampler(
    net: torch.nn.Module,
    latents: Tensor,
    img_lr: Tensor,
    class_labels: Optional[Tensor] = None,
    randn_like: Callable[[Tensor], Tensor] = torch.randn_like,
    patching: Optional[GridPatching2D] = None,
    mean_hr: Optional[Tensor] = None,
    lead_time_label: Optional[Tensor] = None,
    num_steps: int = 18,
    sigma_min: float = 0.002,
    sigma_max: float = 800,
    rho: float = 7,
    S_churn: float = 0,
    S_min: float = 0,
    S_max: float = float("inf"),
    S_noise: float = 1,
) -> Tensor:
    """
    Proposed EDM sampler (Algorithm 2) with minor changes to enable
    super-resolution and patch-based diffusion.

    Parameters
    ----------
    net : torch.nn.Module
        The neural network model that generates denoised images from noisy
        inputs.
        Expected signature: `net(x, x_lr, t_hat, class_labels,
        lead_time_label=lead_time_label, embedding_selector=embedding_selector)`,
        where:
            x (torch.Tensor): Noisy input of shape (batch_size, C_out, H, W)
            x_lr (torch.Tensor): Conditioning input of shape
                (batch_size, C_cond, H, W)
            t_hat (torch.Tensor): Noise level of shape (batch_size, 1, 1, 1)
                or scalar
            class_labels (torch.Tensor, optional): Optional class labels
            lead_time_label (torch.Tensor, optional): Optional lead time
                labels. Only passed to `net` when it is not None.
            embedding_selector (callable, optional): Function to select
                positional embeddings. Used for patch-based diffusion.
        When patching is used, the batch and spatial dimensions of `x` and
        `x_lr` are those of the batch of patches.
        Returns:
            torch.Tensor: Denoised prediction with the same shape as `x`.
        Required attributes:
            sigma_min (float): Minimum supported noise level for the model
            sigma_max (float): Maximum supported noise level for the model
            round_sigma (callable): Method to convert sigma values to tensor
                representation
    latents : Tensor
        The latent variables (e.g., noise) used as the initial input for the
        sampler. Has shape (batch_size, C_out, img_shape_y, img_shape_x).
    img_lr : Tensor
        Low-resolution input image for conditioning the super-resolution
        process. Must have shape (batch_size, C_lr, img_lr_shape_y,
        img_lr_shape_x).
    class_labels : Optional[Tensor], optional
        Class labels for conditional generation, if required by the model.
        By default None.
    randn_like : Callable[[Tensor], Tensor]
        Function to generate random noise with the same shape as the input
        tensor. By default torch.randn_like.
    patching : Optional[GridPatching2D], optional
        A patching utility for patch-based diffusion. Implements methods to
        extract patches from an image and batch the patches along `dim=0`.
        Should also implement a `fuse` method to reconstruct the original
        image from a batch of patches. See
        :class:`physicsnemo.utils.patching.GridPatching2D` for details. By
        default None, in which case non-patched diffusion is used.
    mean_hr : Optional[Tensor], optional
        Optional tensor containing mean high-resolution images for
        conditioning. Must have the same height and width as `img_lr`, with
        shape (B_hr, C_hr, img_lr_shape_y, img_lr_shape_x), where the batch
        dimension B_hr can be 1, equal to batch_size, or omitted. If
        B_hr = 1 or is omitted, `mean_hr` is expanded to match the batch size
        of `img_lr`. By default None.
    lead_time_label : Optional[Tensor], optional
        Optional lead time labels. By default None.
    num_steps : int
        Number of time steps for the sampler. Must be at least 1.
        By default 18.
    sigma_min : float
        Minimum noise level. By default 0.002.
    sigma_max : float
        Maximum noise level. By default 800.
    rho : float
        Exponent used in the time step discretization. By default 7.
    S_churn : float
        Churn parameter controlling the level of noise added in each step.
        As in EDM Algorithm 2, the per-step factor S_churn / num_steps is
        capped at sqrt(2) - 1. By default 0.
    S_min : float
        Minimum time step for applying churn. By default 0.
    S_max : float
        Maximum time step for applying churn. By default float("inf").
    S_noise : float
        Noise scaling factor applied during the churn step. By default 1.

    Returns
    -------
    Tensor
        The final denoised image produced by the sampler. Same shape as
        `latents`: (batch_size, C_out, img_shape_y, img_shape_x), in float64.

    Raises
    ------
    ValueError
        If `patching` is not a `GridPatching2D`; if `img_lr` and `latents`
        differ in batch size (or in height/width when patching is used); if
        `num_steps` is smaller than 1; or if `mean_hr` and `img_lr` differ in
        height/width.

    See Also
    --------
    :class:`physicsnemo.models.diffusion.EDMPrecondSuperResolution`: A model
        wrapper that provides preconditioning for super-resolution diffusion
        models and implements the required interface for this sampler.
    """
    # Restrict the requested noise range to what the network supports.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Safety check on type of patching
    if patching is not None and not isinstance(patching, GridPatching2D):
        raise ValueError("patching must be an instance of GridPatching2D.")
    use_patching = patching is not None

    # With patching, img_lr and latents must share height and width, otherwise
    # a different number of patches would be extracted from each of them and
    # the patch batches would not line up.
    if use_patching and img_lr.shape[-2:] != latents.shape[-2:]:
        raise ValueError(
            f"img_lr and latents must have the same height and width, "
            f"but found {img_lr.shape[-2:]} vs {latents.shape[-2:]}. "
        )
    # img_lr and latents must also share the batch size, otherwise the network
    # would receive mismatched inputs.
    if img_lr.shape[0] != latents.shape[0]:
        raise ValueError(
            f"img_lr and latents must have the same batch size, but found "
            f"{img_lr.shape[0]} vs {latents.shape[0]}."
        )
    if num_steps < 1:
        raise ValueError(f"num_steps must be at least 1, but got {num_steps}.")

    # Time step discretization: interpolate linearly between
    # sigma_max**(1/rho) and sigma_min**(1/rho), then raise to the power rho.
    # max(..., 1) keeps num_steps == 1 well defined (a single step from
    # sigma_max directly to 0) instead of dividing by zero.
    step_indices = torch.arange(
        num_steps, dtype=torch.float64, device=latents.device
    )
    t_steps = (
        sigma_max ** (1 / rho)
        + step_indices
        / max(num_steps - 1, 1)
        * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho
    t_steps = torch.cat(
        [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]
    )  # t_N = 0

    batch_size = img_lr.shape[0]

    # Conditioning channels: [mean_hr (optional), img_lr] concatenated on dim 1.
    x_lr = img_lr
    if mean_hr is not None:
        if mean_hr.shape[-2:] != img_lr.shape[-2:]:
            raise ValueError(
                f"mean_hr and img_lr must have the same height and width, "
                f"but found {mean_hr.shape[-2:]} vs {img_lr.shape[-2:]}."
            )
        # expand() broadcasts a batch dimension of 1, or prepends a missing one.
        x_lr = torch.cat((mean_hr.expand(x_lr.shape[0], -1, -1, -1), x_lr), dim=1)

    if use_patching:
        # Patch the conditioning [mean_hr, img_lr]; img_lr is additionally
        # handed to patching.apply as `additional_input` (presumably a global
        # low-res context per patch -- see GridPatching2D.apply).
        # Assumed result: (batch_size * patch_num, C_cond, patch_y, patch_x).
        x_lr = patching.apply(input=x_lr, additional_input=img_lr)

        # Function to select the correct positional embedding for each patch
        def patch_embedding_selector(emb):
            # emb: (N_pe, image_shape_y, image_shape_x)
            # return: (batch_size * patch_num, N_pe, patch_shape_y, patch_shape_x)
            return patching.apply(emb[None].expand(batch_size, -1, -1, -1))

    else:
        patch_embedding_selector = None

    # The conditioning is constant across steps, so move it to the sampling
    # device once instead of on every iteration.
    x_lr = x_lr.to(latents.device)

    # Keyword arguments forwarded to `net`. lead_time_label is only passed when
    # provided (presumably so networks without that parameter keep working).
    net_kwargs = {"embedding_selector": patch_embedding_selector}
    if lead_time_label is not None:
        net_kwargs["lead_time_label"] = lead_time_label

    def _denoise(x: Tensor, t: Tensor) -> Tensor:
        # Evaluate the network on the full-resolution state `x` at noise level
        # `t`: patch `x` if needed, run `net`, and fuse the patches back into
        # (batch_size, C_out, img_shape_y, img_shape_x).
        x_batch = (patching.apply(input=x) if use_patching else x).to(
            latents.device
        )
        out = net(x_batch, x_lr, t, class_labels, **net_kwargs).to(torch.float64)
        if use_patching:
            out = patching.fuse(input=out, batch_size=batch_size)
        return out

    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily. As in EDM Algorithm 2, gamma is capped at
        # sqrt(2) - 1 so that (before round_sigma) t_hat <= sqrt(2) * t_cur,
        # i.e. the churn at most doubles the noise variance.
        if S_min <= t_cur <= S_max:
            gamma = min(S_churn / num_steps, 2**0.5 - 1)
        else:
            gamma = 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        # randn_like is called even when gamma == 0, keeping the RNG stream
        # identical regardless of the churn settings.
        x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step.
        denoised = _denoise(x_hat, t_hat)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction (skipped on the last step, where t_next = 0).
        if i < num_steps - 1:
            denoised = _denoise(x_next, t_next)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next