Source code for physicsnemo.models.diffusion.unet
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Tuple, Union
import torch
from physicsnemo.models.meta import ModelMetaData
from physicsnemo.models.module import Module
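# Resolve the classes named by `model_type` (e.g. SongUNet) at runtime from the
# physicsnemo.models.diffusion namespace.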
network_module = importlib.import_module("physicsnemo.models.diffusion")
@dataclass
class MetaData(ModelMetaData):
name: str = "UNet"
# Optimization
jit: bool = False
cuda_graphs: bool = False
amp_cpu: bool = False
amp_gpu: bool = True
torch_fx: bool = False
# Data type
bf16: bool = True
# Inference
onnx: bool = False
# Physics informed
func_torch: bool = False
auto_grad: bool = False
class UNet(Module):  # TODO a lot of redundancy, need to clean up
"""
U-Net Wrapper for CorrDiff deterministic regression model.
    Parameters
    ----------
img_resolution : Union[int, Tuple[int, int]]
The resolution of the input/output image. If a single int is provided,
then the image is assumed to be square.
img_in_channels : int
Number of channels in the input image.
img_out_channels : int
Number of channels in the output image.
    use_fp16 : bool, optional
        Execute the underlying model at FP16 precision, by default False.
    model_type : str, optional
        Class name of the underlying model. Must be one of the following:
        'SongUNet', 'SongUNetPosEmbd', 'SongUNetPosLtEmbd', 'DhariwalUNet'.
        Defaults to 'SongUNetPosEmbd'.
**model_kwargs : dict
Keyword arguments passed to the underlying model `__init__` method.
See Also
--------
For information on model types and their usage:
:class:`~physicsnemo.models.diffusion.SongUNet`: Basic U-Net for diffusion models
:class:`~physicsnemo.models.diffusion.SongUNetPosEmbd`: U-Net with positional embeddings
:class:`~physicsnemo.models.diffusion.SongUNetPosLtEmbd`: U-Net with positional and lead-time embeddings
Please refer to the documentation of these classes for details on how to call
and use these models directly.
References
----------
Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y.,
    Liu, C.C., Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023.
Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling.
arXiv preprint arXiv:2309.15214.
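
    Examples
    --------
    A minimal usage sketch. The resolution, channel counts, and the
    ``model_channels`` keyword below are illustrative values rather than a
    recommended CorrDiff configuration; ``model_type='SongUNet'`` is chosen
    here to avoid the positional-embedding channel bookkeeping of the default.

    >>> import torch
    >>> model = UNet(
    ...     img_resolution=64,
    ...     img_in_channels=2,
    ...     img_out_channels=3,
    ...     model_type="SongUNet",
    ...     model_channels=32,
    ... )
    >>> x = torch.zeros(1, 3, 64, 64)       # zero-filled output channels
    >>> img_lr = torch.randn(1, 2, 64, 64)  # low-resolution conditioning
    >>> model(x, img_lr).shape
    torch.Size([1, 3, 64, 64])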
"""
__model_checkpoint_version__ = "0.2.0"
__supported_model_checkpoint_version__ = {
"0.1.0": "Loading UNet checkpoint from older version 0.1.0 (current version is 0.2.0). This version is still supported, but consider re-saving the model to upgrade to version 0.2.0 and remove this warning."
}
@classmethod
def _backward_compat_arg_mapper(
cls, version: str, args: Dict[str, Any]
) -> Dict[str, Any]:
"""Map arguments from older versions to current version format.
Parameters
----------
version : str
Version of the checkpoint being loaded
args : Dict[str, Any]
Arguments dictionary from the checkpoint
Returns
-------
Dict[str, Any]
Updated arguments dictionary compatible with current version
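
        Examples
        --------
        Illustrative only (assuming the base ``Module`` mapper returns the
        dictionary unchanged): for a version "0.1.0" checkpoint, an args dict
        such as ``{"img_resolution": 64, "img_channels": 3, "sigma_min": 0.0}``
        would map to ``{"img_resolution": 64}``, since ``img_channels`` and the
        sigma parameters were unused in 0.1.0.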
"""
# Call parent class method first
args = super()._backward_compat_arg_mapper(version, args)
if version == "0.1.0":
# In version 0.1.0, img_channels was unused
if "img_channels" in args:
_ = args.pop("img_channels")
# Sigma parameters are also unused
if "sigma_min" in args:
_ = args.pop("sigma_min")
if "sigma_max" in args:
_ = args.pop("sigma_max")
if "sigma_data" in args:
_ = args.pop("sigma_data")
return args
def __init__(
self,
img_resolution: Union[int, Tuple[int, int]],
img_in_channels: int,
img_out_channels: int,
use_fp16: bool = False,
model_type: Literal[
"SongUNetPosEmbd", "SongUNetPosLtEmbd", "SongUNet", "DhariwalUNet"
] = "SongUNetPosEmbd",
**model_kwargs: dict,
):
super().__init__(meta=MetaData)
# for compatibility with older versions that took only 1 dimension
if isinstance(img_resolution, int):
self.img_shape_x = self.img_shape_y = img_resolution
else:
self.img_shape_y = img_resolution[0]
self.img_shape_x = img_resolution[1]
self.img_in_channels = img_in_channels
self.img_out_channels = img_out_channels
self.use_fp16 = use_fp16
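        # Resolve the underlying architecture by name. Since `forward`
        # concatenates the conditioning image channel-wise, the wrapped model
        # consumes img_in_channels + img_out_channels input channels.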
model_class = getattr(network_module, model_type)
self.model = model_class(
img_resolution=img_resolution,
in_channels=img_in_channels + img_out_channels,
out_channels=img_out_channels,
**model_kwargs,
)
    def forward(
self,
x: torch.Tensor,
img_lr: torch.Tensor,
force_fp32: bool = False,
**model_kwargs: dict,
) -> torch.Tensor:
"""
Forward pass of the UNet wrapper model.
This method concatenates the input tensor with the low-resolution conditioning tensor
and passes the result through the underlying model.
Parameters
----------
x : torch.Tensor
The input tensor, typically zero-filled, of shape (B, C_hr, H, W).
img_lr : torch.Tensor
Low-resolution conditioning image of shape (B, C_lr, H, W).
force_fp32 : bool, optional
Whether to force FP32 precision regardless of the `use_fp16` attribute,
by default False.
**model_kwargs : dict
Additional keyword arguments to pass to the underlying model
`self.model` forward method.
Returns
-------
torch.Tensor
Output tensor (prediction) of shape (B, C_hr, H, W).
Raises
------
ValueError
If the model output dtype doesn't match the expected dtype.
"""
# SR: concatenate input channels
if img_lr is not None:
x = torch.cat((x, img_lr), dim=1)
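        # Run the wrapped model in FP16 only when requested and on a CUDA
        # device; CPU execution always falls back to FP32.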
dtype = (
torch.float16
if (self.use_fp16 and not force_fp32 and x.device.type == "cuda")
else torch.float32
)
        F_x = self.model(
            x.to(dtype),
            # Noise labels and class labels are unused by this deterministic
            # regression wrapper, so the noise labels are fixed to zero.
            torch.zeros(x.shape[0], dtype=dtype, device=x.device),
            class_labels=None,
            **model_kwargs,
        )
if (F_x.dtype != dtype) and not torch.is_autocast_enabled():
raise ValueError(
f"Expected the dtype to be {dtype}, " f"but got {F_x.dtype} instead."
)
# skip connection
D_x = F_x.to(torch.float32)
return D_x
    def round_sigma(self, sigma: Union[float, List, torch.Tensor]) -> torch.Tensor:
"""
Convert a given sigma value(s) to a tensor representation.
Parameters
----------
sigma : Union[float, List, torch.Tensor]
The sigma value(s) to convert.
Returns
-------
torch.Tensor
The tensor representation of the provided sigma value(s).
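
        Examples
        --------
        >>> unet.round_sigma([0.1, 0.5])  # `unet` is any UNet instance
        tensor([0.1000, 0.5000])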
"""
        return torch.as_tensor(sigma)

    @property
def amp_mode(self):
"""
Return the *amp_mode* flag of the underlying model if present.
"""
return getattr(self.model, "amp_mode", None)
@amp_mode.setter
def amp_mode(self, value: bool):
"""
Update *amp_mode* on the wrapped model and its sub-modules.
"""
if not isinstance(value, bool):
raise TypeError("amp_mode must be a boolean value.")
if hasattr(self.model, "amp_mode"):
self.model.amp_mode = value
# Recursively update sub-modules that define *amp_mode*.
for sub_module in self.model.modules():
if hasattr(sub_module, "amp_mode"):
sub_module.amp_mode = value
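    # Illustrative usage (hypothetical `unet`, `x`, and `img_lr`): setting
    # `unet.amp_mode = True` propagates to `self.model` and all sub-modules
    # that expose the flag, e.g. before running
    #     with torch.autocast("cuda"):
    #         out = unet(x, img_lr)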
# TODO: implement amp_mode property for StormCastUNet (same as UNet)
class StormCastUNet(Module):
"""
U-Net wrapper for StormCast; used so the same Song U-Net network can be re-used for this model.
    Parameters
    ----------
    img_resolution : int or List[int]
        The resolution of the input/output image.
    img_in_channels : int
        Number of channels in the input image.
    img_out_channels : int
        Number of channels in the output image.
    use_fp16 : bool, optional
        Execute the underlying model at FP16 precision, by default False.
sigma_min: float, optional
Minimum supported noise level, by default 0.
sigma_max: float, optional
Maximum supported noise level, by default float('inf').
sigma_data: float, optional
Expected standard deviation of the training data, by default 0.5.
    model_type : str, optional
Class name of the underlying model, by default 'SongUNet'.
**model_kwargs : dict
Keyword arguments for the underlying model.
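
    Examples
    --------
    A minimal usage sketch with illustrative shapes and keyword arguments
    (``model_channels=32`` is not a StormCast default):

    >>> import torch
    >>> model = StormCastUNet(
    ...     img_resolution=64,
    ...     img_in_channels=4,
    ...     img_out_channels=4,
    ...     model_channels=32,
    ... )
    >>> x = torch.randn(1, 4, 64, 64)
    >>> model(x).shape
    torch.Size([1, 4, 64, 64])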
"""
def __init__(
self,
img_resolution,
img_in_channels,
img_out_channels,
use_fp16=False,
sigma_min=0,
sigma_max=float("inf"),
sigma_data=0.5,
model_type="SongUNet",
**model_kwargs,
):
super().__init__(meta=MetaData("StormCastUNet"))
if isinstance(img_resolution, int):
self.img_shape_x = self.img_shape_y = img_resolution
else:
self.img_shape_x = img_resolution[0]
self.img_shape_y = img_resolution[1]
self.img_in_channels = img_in_channels
self.img_out_channels = img_out_channels
self.use_fp16 = use_fp16
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.sigma_data = sigma_data
model_class = getattr(network_module, model_type)
self.model = model_class(
img_resolution=img_resolution,
in_channels=img_in_channels,
out_channels=img_out_channels,
**model_kwargs,
)
    def forward(self, x, force_fp32=False, **model_kwargs):
"""Run a forward pass of the StormCast regression U-Net.
Args:
x (torch.Tensor): input to the U-Net
force_fp32 (bool, optional): force casting to fp_32 if True. Defaults to False.
Raises:
ValueError: If input data type is a mismatch with provided options
Returns:
D_x (torch.Tensor): Output (prediction) of the U-Net
"""
x = x.to(torch.float32)
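        # As in UNet.forward, run in FP16 only when requested and on a CUDA
        # device; note the noise labels below keep the dtype of `x` (FP32).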
dtype = (
torch.float16
if (self.use_fp16 and not force_fp32 and x.device.type == "cuda")
else torch.float32
)
F_x = self.model(
x.to(dtype),
torch.zeros(x.shape[0], dtype=x.dtype, device=x.device),
class_labels=None,
**model_kwargs,
)
if (F_x.dtype != dtype) and not torch.is_autocast_enabled():
raise ValueError(
f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead."
)
D_x = F_x.to(torch.float32)
return D_x