NVIDIA PhysicsNeMo Core (Latest Release)

deeplearning/physicsnemo/physicsnemo-core/_modules/physicsnemo/models/diffusion/unet.html

Source code for physicsnemo.models.diffusion.unet

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Tuple, Union

import torch

from physicsnemo.models.meta import ModelMetaData
from physicsnemo.models.module import Module

# Package that holds the backbone network classes. The wrappers below resolve
# their underlying model class by name via ``getattr(network_module, model_type)``.
network_module = importlib.import_module("physicsnemo.models.diffusion")


@dataclass
class MetaData(ModelMetaData):
    """Metadata flags for :class:`UNet`.

    :class:`StormCastUNet` reuses this dataclass with a different ``name``.
    """

    name: str = "UNet"

    # ---- Optimization ----
    jit: bool = False
    cuda_graphs: bool = False
    amp_cpu: bool = False
    amp_gpu: bool = True
    torch_fx: bool = False

    # ---- Data type ----
    bf16: bool = True

    # ---- Inference ----
    onnx: bool = False

    # ---- Physics informed ----
    func_torch: bool = False
    auto_grad: bool = False
class UNet(Module):  # TODO a lot of redundancy, need to clean up
    """
    U-Net Wrapper for CorrDiff deterministic regression model.

    Parameters
    ----------
    img_resolution : Union[int, Tuple[int, int]]
        The resolution of the input/output image. If a single int is provided,
        then the image is assumed to be square. Otherwise the first element is
        stored as ``img_shape_y`` and the second as ``img_shape_x``.
    img_in_channels : int
        Number of channels in the conditioning input ``img_lr``. The wrapped
        model is built with ``img_in_channels + img_out_channels`` input
        channels because :meth:`forward` concatenates ``x`` and ``img_lr``.
    img_out_channels : int
        Number of channels in the output image.
    use_fp16: bool, optional
        Execute the underlying model at FP16 precision, by default False.
    model_type: str, optional
        Class name of the underlying model. Must be one of the following:
        'SongUNet', 'SongUNetPosEmbd', 'SongUNetPosLtEmbd', 'DhariwalUNet'.
        Defaults to 'SongUNetPosEmbd'.
    **model_kwargs : dict
        Keyword arguments passed to the underlying model `__init__` method.

    See Also
    --------
    For information on model types and their usage:
    :class:`~physicsnemo.models.diffusion.SongUNet`: Basic U-Net for diffusion models
    :class:`~physicsnemo.models.diffusion.SongUNetPosEmbd`: U-Net with positional embeddings
    :class:`~physicsnemo.models.diffusion.SongUNetPosLtEmbd`: U-Net with positional and lead-time embeddings

    Please refer to the documentation of these classes for details on how to call
    and use these models directly.

    References
    ----------
    Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., Liu, C.C.,
    Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. Generative
    Residual Diffusion Modeling for Km-scale Atmospheric Downscaling.
    arXiv preprint arXiv:2309.15214.
    """

    __model_checkpoint_version__ = "0.2.0"
    __supported_model_checkpoint_version__ = {
        "0.1.0": (
            "Loading UNet checkpoint from older version 0.1.0 "
            "(current version is 0.2.0). This version is still supported, "
            "but consider re-saving the model to upgrade to version 0.2.0 "
            "and remove this warning."
        )
    }

    @classmethod
    def _backward_compat_arg_mapper(
        cls, version: str, args: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Map arguments from older versions to current version format.

        Parameters
        ----------
        version : str
            Version of the checkpoint being loaded
        args : Dict[str, Any]
            Arguments dictionary from the checkpoint

        Returns
        -------
        Dict[str, Any]
            Updated arguments dictionary compatible with current version
        """
        # Call parent class method first
        args = super()._backward_compat_arg_mapper(version, args)

        if version == "0.1.0":
            # Version 0.1.0 accepted these arguments but never used them. The
            # current __init__ does not take them, so if left in place they
            # would fall into **model_kwargs and reach the underlying model.
            for unused_arg in ("img_channels", "sigma_min", "sigma_max", "sigma_data"):
                args.pop(unused_arg, None)

        return args

    def __init__(
        self,
        img_resolution: Union[int, Tuple[int, int]],
        img_in_channels: int,
        img_out_channels: int,
        use_fp16: bool = False,
        model_type: Literal[
            "SongUNetPosEmbd", "SongUNetPosLtEmbd", "SongUNet", "DhariwalUNet"
        ] = "SongUNetPosEmbd",
        **model_kwargs: Any,
    ):
        # Pass an *instance* of the metadata (as StormCastUNet does). Passing the
        # dataclass itself would share one class object as every model's
        # metadata.
        super().__init__(meta=MetaData())

        # for compatibility with older versions that took only 1 dimension
        if isinstance(img_resolution, int):
            self.img_shape_x = self.img_shape_y = img_resolution
        else:
            self.img_shape_y = img_resolution[0]
            self.img_shape_x = img_resolution[1]

        self.img_in_channels = img_in_channels
        self.img_out_channels = img_out_channels

        self.use_fp16 = use_fp16
        # Raises AttributeError if model_type is not defined in network_module.
        model_class = getattr(network_module, model_type)
        self.model = model_class(
            img_resolution=img_resolution,
            # forward() concatenates x (output channels) with img_lr.
            in_channels=img_in_channels + img_out_channels,
            out_channels=img_out_channels,
            **model_kwargs,
        )

    def forward(
        self,
        x: torch.Tensor,
        img_lr: Union[torch.Tensor, None],
        force_fp32: bool = False,
        **model_kwargs: Any,
    ) -> torch.Tensor:
        """
        Forward pass of the UNet wrapper model.

        This method concatenates the input tensor with the low-resolution
        conditioning tensor and passes the result through the underlying model.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor, typically zero-filled, of shape (B, C_hr, H, W).
        img_lr : torch.Tensor or None
            Low-resolution conditioning image of shape (B, C_lr, H, W). If None,
            ``x`` is passed to the underlying model without concatenation.
        force_fp32 : bool, optional
            Whether to force FP32 precision regardless of the `use_fp16`
            attribute, by default False.
        **model_kwargs : dict
            Additional keyword arguments to pass to the underlying model
            `self.model` forward method.

        Returns
        -------
        torch.Tensor
            Output tensor (prediction) of shape (B, C_hr, H, W), always FP32.

        Raises
        ------
        ValueError
            If the model output dtype doesn't match the expected dtype while
            autocast is not enabled.
        """
        # SR: concatenate input and conditioning along the channel dimension.
        if img_lr is not None:
            x = torch.cat((x, img_lr), dim=1)

        # FP16 only when requested, not overridden, and running on CUDA.
        dtype = (
            torch.float16
            if (self.use_fp16 and not force_fp32 and x.device.type == "cuda")
            else torch.float32
        )

        # This deterministic wrapper always feeds zero noise labels to the
        # diffusion backbone.
        F_x = self.model(
            x.to(dtype),
            torch.zeros(x.shape[0], dtype=dtype, device=x.device),
            class_labels=None,
            **model_kwargs,
        )

        # Under autocast the backbone may legitimately return another dtype.
        if (F_x.dtype != dtype) and not torch.is_autocast_enabled():
            raise ValueError(
                f"Expected the dtype to be {dtype}, "
                f"but got {F_x.dtype} instead."
            )

        # Return FP32 regardless of the precision used internally.
        D_x = F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma: Union[float, List, torch.Tensor]) -> torch.Tensor:
        """
        Convert a given sigma value(s) to a tensor representation.

        Parameters
        ----------
        sigma : Union[float, List, torch.Tensor]
            The sigma value(s) to convert.

        Returns
        -------
        torch.Tensor
            The tensor representation of the provided sigma value(s).
        """
        return torch.as_tensor(sigma)

    @property
    def amp_mode(self):
        """
        Return the *amp_mode* flag of the underlying model if present,
        otherwise None.
        """
        return getattr(self.model, "amp_mode", None)

    @amp_mode.setter
    def amp_mode(self, value: bool):
        """
        Update *amp_mode* on the wrapped model and its sub-modules.

        Raises
        ------
        TypeError
            If ``value`` is not a bool.
        """
        if not isinstance(value, bool):
            raise TypeError("amp_mode must be a boolean value.")

        if hasattr(self.model, "amp_mode"):
            self.model.amp_mode = value

        # Recursively update sub-modules that define *amp_mode*
        # (``modules()`` also yields ``self.model`` itself).
        for sub_module in self.model.modules():
            if hasattr(sub_module, "amp_mode"):
                sub_module.amp_mode = value

# TODO: implement amp_mode property for StormCastUNet (same as UNet)

class StormCastUNet(Module):
    """
    U-Net wrapper for StormCast; used so the same Song U-Net network can be
    re-used for this model.

    Parameters
    ----------
    img_resolution : int or List[int]
        The resolution of the input/output image. If a single int is provided,
        the image is assumed to be square. Otherwise the first element is
        stored as ``img_shape_x`` and the second as ``img_shape_y``.
    img_in_channels : int
        Number of input channels.
    img_out_channels : int
        Number of output channels.
    use_fp16: bool, optional
        Execute the underlying model at FP16 precision, by default False.
    sigma_min: float, optional
        Minimum supported noise level, by default 0.
    sigma_max: float, optional
        Maximum supported noise level, by default float('inf').
    sigma_data: float, optional
        Expected standard deviation of the training data, by default 0.5.
    model_type: str, optional
        Class name of the underlying model, by default 'SongUNet'.
    **model_kwargs : dict
        Keyword arguments for the underlying model.

    Notes
    -----
    ``sigma_min``, ``sigma_max`` and ``sigma_data`` are stored as attributes
    but are not used by :meth:`forward`.
    """

    def __init__(
        self,
        img_resolution: Union[int, List[int]],
        img_in_channels: int,
        img_out_channels: int,
        use_fp16: bool = False,
        sigma_min: float = 0,
        sigma_max: float = float("inf"),
        sigma_data: float = 0.5,
        model_type: str = "SongUNet",
        **model_kwargs: Any,
    ):
        # Keyword form: does not depend on the position of ``name`` among the
        # inherited dataclass fields.
        super().__init__(meta=MetaData(name="StormCastUNet"))

        if isinstance(img_resolution, int):
            self.img_shape_x = self.img_shape_y = img_resolution
        else:
            # NOTE(review): UNet in this file maps img_resolution[0] to
            # img_shape_y; the order here is the opposite. Confirm which
            # convention callers rely on before unifying.
            self.img_shape_x = img_resolution[0]
            self.img_shape_y = img_resolution[1]

        self.img_in_channels = img_in_channels
        self.img_out_channels = img_out_channels

        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        # Raises AttributeError if model_type is not defined in network_module.
        model_class = getattr(network_module, model_type)
        self.model = model_class(
            img_resolution=img_resolution,
            in_channels=img_in_channels,
            out_channels=img_out_channels,
            **model_kwargs,
        )

    def forward(
        self, x: torch.Tensor, force_fp32: bool = False, **model_kwargs: Any
    ) -> torch.Tensor:
        """
        Run a forward pass of the StormCast regression U-Net.

        Parameters
        ----------
        x : torch.Tensor
            Input to the U-Net. It is cast to FP32 before the execution
            precision for the underlying model is selected.
        force_fp32 : bool, optional
            Force FP32 execution even if ``use_fp16`` is set, by default False.
        **model_kwargs : dict
            Additional keyword arguments forwarded to ``self.model``.

        Returns
        -------
        torch.Tensor
            Output (prediction) of the U-Net, always in FP32.

        Raises
        ------
        ValueError
            If the underlying model's output dtype does not match the expected
            dtype while autocast is not enabled.
        """
        x = x.to(torch.float32)
        # FP16 only when requested, not overridden, and running on CUDA.
        dtype = (
            torch.float16
            if (self.use_fp16 and not force_fp32 and x.device.type == "cuda")
            else torch.float32
        )

        # Zero noise labels are created in x.dtype, which is FP32 here because
        # x was just cast.
        # NOTE(review): UNet.forward builds them in ``dtype`` instead; confirm
        # the FP32 choice here is intentional.
        F_x = self.model(
            x.to(dtype),
            torch.zeros(x.shape[0], dtype=x.dtype, device=x.device),
            class_labels=None,
            **model_kwargs,
        )

        # Under autocast the backbone may legitimately return another dtype.
        if (F_x.dtype != dtype) and not torch.is_autocast_enabled():
            raise ValueError(
                f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead."
            )

        # Return FP32 regardless of the precision used internally.
        D_x = F_x.to(torch.float32)
        return D_x
© Copyright 2023, NVIDIA PhysicsNeMo Team. Last updated on Jun 11, 2025.