NVIDIA Modulus Core (Latest Release)


Source code for modulus.models.diffusion.dhariwal_unet

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Model architectures used in the paper "Elucidating the Design Space of 
Diffusion-Based Generative Models".
"""

from dataclasses import dataclass
from typing import List

import numpy as np
import torch
from torch.nn.functional import silu

from modulus.models.diffusion import (
    Conv2d,
    GroupNorm,
    Linear,
    PositionalEmbedding,
    UNetBlock,
)
from modulus.models.meta import ModelMetaData
from modulus.models.module import Module


@dataclass
class MetaData(ModelMetaData):
    name: str = "DhariwalUNet"
    # Optimization
    jit: bool = False
    cuda_graphs: bool = False
    amp_cpu: bool = False
    amp_gpu: bool = True
    torch_fx: bool = False
    # Data type
    bf16: bool = True
    # Inference
    onnx: bool = False
    # Physics informed
    func_torch: bool = False
    auto_grad: bool = False

class DhariwalUNet(Module):
    """
    Reimplementation of the ADM architecture, a U-Net variant with optional
    self-attention.

    This model supports conditional and unconditional setups, and exposes several
    internal architectural choices, such as channel multipliers, attention
    resolutions, and embedding dimensionality, making it adaptable to different
    tasks and configurations.

    Parameters
    ----------
    img_resolution : int
        The resolution of the input/output image.
    in_channels : int
        Number of channels in the input image.
    out_channels : int
        Number of channels in the output image.
    label_dim : int, optional
        Number of class labels; 0 indicates an unconditional model. By default 0.
    augment_dim : int, optional
        Dimensionality of augmentation labels; 0 means no augmentation. By default 0.
    model_channels : int, optional
        Base multiplier for the number of channels across the network, by default 192.
    channel_mult : List[int], optional
        Per-resolution multipliers for the number of channels. By default [1, 2, 3, 4].
    channel_mult_emb : int, optional
        Multiplier for the dimensionality of the embedding vector. By default 4.
    num_blocks : int, optional
        Number of residual blocks per resolution. By default 3.
    attn_resolutions : List[int], optional
        Resolutions at which self-attention layers are applied. By default [32, 16, 8].
    dropout : float, optional
        Dropout probability applied to intermediate activations. By default 0.10.
    label_dropout : float, optional
        Dropout probability of class labels for classifier-free guidance. By default 0.0.

    Reference
    ---------
    Dhariwal, P. and Nichol, A., 2021. Diffusion models beat GANs on image
    synthesis. Advances in Neural Information Processing Systems, 34, pp. 8780-8794.

    Note
    ----
    Equivalent to the original implementation by Dhariwal and Nichol, available at
    https://github.com/openai/guided-diffusion

    Example
    -------
    >>> model = DhariwalUNet(img_resolution=16, in_channels=2, out_channels=2)
    >>> noise_labels = torch.randn([1])
    >>> class_labels = torch.randint(0, 1, (1, 1))
    >>> input_image = torch.ones([1, 2, 16, 16])
    >>> output_image = model(input_image, noise_labels, class_labels)
    """

    def __init__(
        self,
        img_resolution: int,
        in_channels: int,
        out_channels: int,
        label_dim: int = 0,
        augment_dim: int = 0,
        model_channels: int = 192,
        channel_mult: List[int] = [1, 2, 3, 4],
        channel_mult_emb: int = 4,
        num_blocks: int = 3,
        attn_resolutions: List[int] = [32, 16, 8],
        dropout: float = 0.10,
        label_dropout: float = 0.0,
    ):
        super().__init__(meta=MetaData())
        self.label_dropout = label_dropout
        emb_channels = model_channels * channel_mult_emb
        init = dict(
            init_mode="kaiming_uniform",
            init_weight=np.sqrt(1 / 3),
            init_bias=np.sqrt(1 / 3),
        )
        init_zero = dict(init_mode="kaiming_uniform", init_weight=0, init_bias=0)
        block_kwargs = dict(
            emb_channels=emb_channels,
            channels_per_head=64,
            dropout=dropout,
            init=init,
            init_zero=init_zero,
        )

        # Mapping.
        self.map_noise = PositionalEmbedding(num_channels=model_channels)
        self.map_augment = (
            Linear(
                in_features=augment_dim,
                out_features=model_channels,
                bias=False,
                **init_zero,
            )
            if augment_dim
            else None
        )
        self.map_layer0 = Linear(
            in_features=model_channels, out_features=emb_channels, **init
        )
        self.map_layer1 = Linear(
            in_features=emb_channels, out_features=emb_channels, **init
        )
        self.map_label = (
            Linear(
                in_features=label_dim,
                out_features=emb_channels,
                bias=False,
                init_mode="kaiming_normal",
                init_weight=np.sqrt(label_dim),
            )
            if label_dim
            else None
        )

        # Encoder.
        self.enc = torch.nn.ModuleDict()
        cout = in_channels
        for level, mult in enumerate(channel_mult):
            res = img_resolution >> level
            if level == 0:
                cin = cout
                cout = model_channels * mult
                self.enc[f"{res}x{res}_conv"] = Conv2d(
                    in_channels=cin, out_channels=cout, kernel=3, **init
                )
            else:
                self.enc[f"{res}x{res}_down"] = UNetBlock(
                    in_channels=cout, out_channels=cout, down=True, **block_kwargs
                )
            for idx in range(num_blocks):
                cin = cout
                cout = model_channels * mult
                self.enc[f"{res}x{res}_block{idx}"] = UNetBlock(
                    in_channels=cin,
                    out_channels=cout,
                    attention=(res in attn_resolutions),
                    **block_kwargs,
                )
        skips = [block.out_channels for block in self.enc.values()]

        # Decoder.
        self.dec = torch.nn.ModuleDict()
        for level, mult in reversed(list(enumerate(channel_mult))):
            res = img_resolution >> level
            if level == len(channel_mult) - 1:
                self.dec[f"{res}x{res}_in0"] = UNetBlock(
                    in_channels=cout, out_channels=cout, attention=True, **block_kwargs
                )
                self.dec[f"{res}x{res}_in1"] = UNetBlock(
                    in_channels=cout, out_channels=cout, **block_kwargs
                )
            else:
                self.dec[f"{res}x{res}_up"] = UNetBlock(
                    in_channels=cout, out_channels=cout, up=True, **block_kwargs
                )
            for idx in range(num_blocks + 1):
                cin = cout + skips.pop()
                cout = model_channels * mult
                self.dec[f"{res}x{res}_block{idx}"] = UNetBlock(
                    in_channels=cin,
                    out_channels=cout,
                    attention=(res in attn_resolutions),
                    **block_kwargs,
                )
        self.out_norm = GroupNorm(num_channels=cout)
        self.out_conv = Conv2d(
            in_channels=cout, out_channels=out_channels, kernel=3, **init_zero
        )
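
    # Note on skip bookkeeping: decoder blocks built with cin = cout + skips.pop()
    # expect a skip tensor concatenated onto their input. forward() below detects
    # these blocks by comparing x.shape[1] against each block's in_channels and
    # concatenates the next saved encoder activation only when they differ.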

    def forward(self, x, noise_labels, class_labels, augment_labels=None):
        # Mapping.
        emb = self.map_noise(noise_labels)
        if self.map_augment is not None and augment_labels is not None:
            emb = emb + self.map_augment(augment_labels)
        emb = silu(self.map_layer0(emb))
        emb = self.map_layer1(emb)
        if self.map_label is not None:
            tmp = class_labels
            if self.training and self.label_dropout:
                tmp = tmp * (
                    torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout
                ).to(tmp.dtype)
            emb = emb + self.map_label(tmp)
        emb = silu(emb)

        # Encoder.
        skips = []
        for block in self.enc.values():
            x = block(x, emb) if isinstance(block, UNetBlock) else block(x)
            skips.append(x)

        # Decoder.
        for block in self.dec.values():
            if x.shape[1] != block.in_channels:
                x = torch.cat([x, skips.pop()], dim=1)
            x = block(x, emb)
        x = self.out_conv(silu(self.out_norm(x)))
        return x
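
Example usage (a minimal sketch, not part of the module source): constructing a conditional variant and running a forward pass. It assumes DhariwalUNet is exported from modulus.models.diffusion and that class labels are supplied as one-hot vectors of width label_dim, consistent with map_label being a Linear layer with in_features=label_dim; the hyperparameters are illustrative only.

import torch

from modulus.models.diffusion import DhariwalUNet

# Conditional model: label_dim > 0 enables the class-label embedding, and
# label_dropout randomly zeroes labels during training (classifier-free guidance).
model = DhariwalUNet(
    img_resolution=16,
    in_channels=2,
    out_channels=2,
    label_dim=10,
    label_dropout=0.1,
)

noise_labels = torch.randn([4])  # one noise level per batch element
class_labels = torch.nn.functional.one_hot(
    torch.randint(0, 10, (4,)), num_classes=10
).float()  # assumed one-hot conditioning format, shape [4, label_dim]
x = torch.randn([4, 2, 16, 16])

y = model(x, noise_labels, class_labels)
print(y.shape)  # torch.Size([4, 2, 16, 16])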