Source code for physicsnemo.models.diffusion.dhariwal_unet

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import List

import numpy as np
import torch
from torch.nn.functional import silu

from physicsnemo.models.diffusion import (
    Conv2d,
    GroupNorm,
    Linear,
    PositionalEmbedding,
    UNetBlock,
)
from physicsnemo.models.diffusion.utils import _recursive_property
from physicsnemo.models.meta import ModelMetaData
from physicsnemo.models.module import Module

# ------------------------------------------------------------------------------
# Backbone architectures
# ------------------------------------------------------------------------------


@dataclass
class MetaData(ModelMetaData):
    name: str = "DhariwalUNet"
    # Optimization
    jit: bool = False
    cuda_graphs: bool = False
    amp_cpu: bool = False
    amp_gpu: bool = True
    torch_fx: bool = False
    # Data type
    bf16: bool = True
    # Inference
    onnx: bool = False
    # Physics informed
    func_torch: bool = False
    auto_grad: bool = False


# NOTE: this module can actually be replicated as a special case of the
# SongUNet class (with a very minor extension of the SongUNet class). We should
# consider inheriting the more general SongUNet class here.
class DhariwalUNet(Module):
    r"""
    This architecture is a diffusion backbone for 2D image generation. It
    reimplements the `ADM architecture <https://arxiv.org/abs/2105.05233>`_,
    a U-Net variant, with optional self-attention. It is highly similar to
    the U-Net backbone defined in
    :class:`~physicsnemo.models.diffusion.song_unet.SongUNet`, and only
    differs in a few aspects:

    - The embedding conditioning mechanism relies on adaptive scaling of the
      group normalization layers within the U-Net blocks.
    - The parameter initialization follows Kaiming uniform initialization.

    Parameters
    ----------
    img_resolution : int
        The resolution :math:`H = W` of the input/output image. Assumes
        square images.

        *Note:* This parameter is only used as a convenience to build the
        network. In practice, the model can still be used with images of
        different resolutions.
    in_channels : int
        Number of channels :math:`C_{in}` in the input image. May include
        channels from both the latent state :math:`\mathbf{x}` and additional
        channels when conditioning on images. For an unconditional model,
        this should be equal to ``out_channels``.
    out_channels : int
        Number of channels :math:`C_{out}` in the output image. Should be
        equal to the number of channels :math:`C_{\mathbf{x}}` in the latent
        state.
    label_dim : int, optional, default=0
        Dimension of the vector-valued ``class_labels`` conditioning; 0
        indicates no conditioning on class labels.
    augment_dim : int, optional, default=0
        Dimension of the vector-valued ``augment_labels`` conditioning; 0
        means no conditioning on augmentation labels.
    model_channels : int, optional, default=192
        Base multiplier for the number of channels across the entire network.
    channel_mult : List[int], optional, default=[1, 2, 3, 4]
        Multipliers for the number of channels at every level in the encoder
        and decoder. The length of ``channel_mult`` determines the number of
        levels in the U-Net. At level ``i``, the number of channels in the
        feature map is ``channel_mult[i] * model_channels``.
    channel_mult_emb : int, optional, default=4
        Multiplier for the number of channels in the embedding vector. The
        embedding vector has ``model_channels * channel_mult_emb`` channels.
    num_blocks : int, optional, default=3
        Number of U-Net blocks at each level.
    attn_resolutions : List[int], optional, default=[32, 16, 8]
        Resolutions of the levels at which self-attention layers are applied.
        Note that the feature map resolution must match exactly the value
        provided in ``attn_resolutions`` for the self-attention layers to be
        applied.
    dropout : float, optional, default=0.10
        Dropout probability applied to intermediate activations within the
        U-Net blocks.
    label_dropout : float, optional, default=0.0
        Dropout probability applied to the ``class_labels``. Typically used
        for classifier-free guidance.

    Forward
    -------
    x : torch.Tensor
        The input tensor of shape :math:`(B, C_{in}, H_{in}, W_{in})`. In
        general ``x`` is the channel-wise concatenation of the latent state
        :math:`\mathbf{x}` and additional images used for conditioning. For
        an unconditional model, ``x`` is simply the latent state
        :math:`\mathbf{x}`.
    noise_labels : torch.Tensor
        The noise labels of shape :math:`(B,)`. Used for conditioning on the
        noise level.
    class_labels : torch.Tensor
        The class labels of shape :math:`(B, \text{label_dim})`. Used for
        conditioning on any vector-valued quantity. Can pass ``None`` when
        ``label_dim`` is 0.
    augment_labels : torch.Tensor, optional, default=None
        The augmentation labels of shape :math:`(B, \text{augment_dim})`.
        Used for conditioning on any additional vector-valued quantity. Can
        pass ``None`` when ``augment_dim`` is 0.

    Outputs
    -------
    torch.Tensor
        The denoised latent state of shape :math:`(B, C_{out}, H_{in}, W_{in})`.

    Examples
    --------
    >>> model = DhariwalUNet(img_resolution=16, in_channels=2, out_channels=2)
    >>> noise_labels = torch.randn([1])
    >>> class_labels = torch.randint(0, 1, (1, 1))  # noqa: N806
    >>> input_image = torch.ones([1, 2, 16, 16])  # noqa: N806
    >>> output_image = model(input_image, noise_labels, class_labels)  # noqa: N806
    """

    def __init__(
        self,
        img_resolution: int,
        in_channels: int,
        out_channels: int,
        label_dim: int = 0,
        augment_dim: int = 0,
        model_channels: int = 192,
        channel_mult: List[int] = [1, 2, 3, 4],
        channel_mult_emb: int = 4,
        num_blocks: int = 3,
        attn_resolutions: List[int] = [32, 16, 8],
        dropout: float = 0.10,
        label_dropout: float = 0.0,
    ):
        super().__init__(meta=MetaData())
        self.label_dropout = label_dropout
        emb_channels = model_channels * channel_mult_emb
        init = dict(
            init_mode="kaiming_uniform",
            init_weight=np.sqrt(1 / 3),
            init_bias=np.sqrt(1 / 3),
        )
        init_zero = dict(init_mode="kaiming_uniform", init_weight=0, init_bias=0)
        block_kwargs = dict(
            emb_channels=emb_channels,
            channels_per_head=64,
            dropout=dropout,
            init=init,
            init_zero=init_zero,
        )

        # Mapping.
        self.map_noise = PositionalEmbedding(num_channels=model_channels)
        self.map_augment = (
            Linear(
                in_features=augment_dim,
                out_features=model_channels,
                bias=False,
                **init_zero,
            )
            if augment_dim
            else None
        )
        self.map_layer0 = Linear(
            in_features=model_channels, out_features=emb_channels, **init
        )
        self.map_layer1 = Linear(
            in_features=emb_channels, out_features=emb_channels, **init
        )
        self.map_label = (
            Linear(
                in_features=label_dim,
                out_features=emb_channels,
                bias=False,
                init_mode="kaiming_normal",
                init_weight=np.sqrt(label_dim),
            )
            if label_dim
            else None
        )

        # Encoder.
        self.enc = torch.nn.ModuleDict()
        cout = in_channels
        for level, mult in enumerate(channel_mult):
            res = img_resolution >> level
            if level == 0:
                cin = cout
                cout = model_channels * mult
                self.enc[f"{res}x{res}_conv"] = Conv2d(
                    in_channels=cin, out_channels=cout, kernel=3, **init
                )
            else:
                self.enc[f"{res}x{res}_down"] = UNetBlock(
                    in_channels=cout, out_channels=cout, down=True, **block_kwargs
                )
            for idx in range(num_blocks):
                cin = cout
                cout = model_channels * mult
                self.enc[f"{res}x{res}_block{idx}"] = UNetBlock(
                    in_channels=cin,
                    out_channels=cout,
                    attention=(res in attn_resolutions),
                    **block_kwargs,
                )
        skips = [block.out_channels for block in self.enc.values()]

        # Decoder.
        self.dec = torch.nn.ModuleDict()
        for level, mult in reversed(list(enumerate(channel_mult))):
            res = img_resolution >> level
            if level == len(channel_mult) - 1:
                self.dec[f"{res}x{res}_in0"] = UNetBlock(
                    in_channels=cout, out_channels=cout, attention=True, **block_kwargs
                )
                self.dec[f"{res}x{res}_in1"] = UNetBlock(
                    in_channels=cout, out_channels=cout, **block_kwargs
                )
            else:
                self.dec[f"{res}x{res}_up"] = UNetBlock(
                    in_channels=cout, out_channels=cout, up=True, **block_kwargs
                )
            for idx in range(num_blocks + 1):
                cin = cout + skips.pop()
                cout = model_channels * mult
                self.dec[f"{res}x{res}_block{idx}"] = UNetBlock(
                    in_channels=cin,
                    out_channels=cout,
                    attention=(res in attn_resolutions),
                    **block_kwargs,
                )
        self.out_norm = GroupNorm(num_channels=cout)
        self.out_conv = Conv2d(
            in_channels=cout, out_channels=out_channels, kernel=3, **init_zero
        )

    # Properties that are recursively set on submodules
    profile_mode = _recursive_property(
        "profile_mode", bool, "Should be set to ``True`` to enable profiling."
    )
    amp_mode = _recursive_property(
        "amp_mode",
        bool,
        "Should be set to ``True`` to enable automatic mixed precision.",
    )

    def forward(self, x, noise_labels, class_labels, augment_labels=None):
        # Mapping.
        emb = self.map_noise(noise_labels)
        if self.map_augment is not None and augment_labels is not None:
            emb = emb + self.map_augment(augment_labels)
        emb = silu(self.map_layer0(emb))
        emb = self.map_layer1(emb)
        if self.map_label is not None:
            tmp = class_labels
            if self.training and self.label_dropout:
                tmp = tmp * (
                    torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout
                ).to(tmp.dtype)
            emb = emb + self.map_label(tmp)
        emb = silu(emb)

        # Encoder.
        skips = []
        for block in self.enc.values():
            x = block(x, emb) if isinstance(block, UNetBlock) else block(x)
            skips.append(x)

        # Decoder.
        for block in self.dec.values():
            if x.shape[1] != block.in_channels:
                x = torch.cat([x, skips.pop()], dim=1)
            x = block(x, emb)
        x = self.out_conv(silu(self.out_norm(x)))
        return x
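
# ------------------------------------------------------------------------------
# Usage sketch
# ------------------------------------------------------------------------------
# A minimal smoke test of the forward signature documented above: this is a
# sketch, not part of the model definition. The image size, batch size, and
# one-hot class encoding below are illustrative assumptions; ``amp_mode`` is
# one of the recursive properties defined on the class.

if __name__ == "__main__":
    model = DhariwalUNet(
        img_resolution=32,
        in_channels=3,
        out_channels=3,
        label_dim=10,  # condition on a 10-dimensional class vector
        augment_dim=9,  # condition on a 9-dimensional augmentation vector
    )
    model.amp_mode = False  # recursively propagated to all submodules

    x = torch.randn(2, 3, 32, 32)  # latent state of shape (B, C_in, H, W)
    noise_labels = torch.randn(2)  # one noise level per batch element
    class_labels = torch.nn.functional.one_hot(
        torch.tensor([1, 4]), num_classes=10
    ).float()  # shape (B, label_dim)
    augment_labels = torch.randn(2, 9)  # shape (B, augment_dim)

    out = model(x, noise_labels, class_labels, augment_labels)
    assert out.shape == (2, 3, 32, 32)  # (B, C_out, H, W)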