Diffusion Models

The PhysicsNeMo diffusion library provides three categories of models that serve different purposes. All models are based on the Module class.

  • Model backbones:

    These are highly configurable architectures that can be used as building blocks for more complex models.

  • Specialized architectures:

    These are models that usually inherit from the model backbones and add some application-specific functionality.

  • Application-specific interfaces:

    These Modules are not true architectures, but rather wrappers around the model backbones or specialized architectures. Their purpose is to provide a more user-friendly interface for specific applications.

In addition to these model architectures, PhysicsNeMo provides diffusion preconditioners, which are wrappers around model architectures that rescale the inputs and outputs of diffusion models to improve their performance.
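The preconditioner classes themselves are not documented in this section. As a rough illustration of what rescaling the inputs and outputs means, the sketch below implements the preconditioning formulas from Karras et al. (EDM) for a scalar input. The function edm_precondition, the placeholder zero_network, and the default sigma_data=0.5 are illustrative assumptions, not the PhysicsNeMo preconditioner API.

```python
import math
from typing import Callable


def edm_precondition(
    network: Callable[[float, float], float],
    x: float,
    sigma: float,
    sigma_data: float = 0.5,
) -> float:
    """Return D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise).

    Scalars keep the sketch short; the same formulas apply element-wise to tensors.
    """
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)  # weight of the skip path
    c_out = sigma * sigma_data / math.sqrt(sigma**2 + sigma_data**2)  # output scaling
    c_in = 1.0 / math.sqrt(sigma**2 + sigma_data**2)  # input scaling toward unit variance
    c_noise = math.log(sigma) / 4.0  # transformed noise level fed to the network
    return c_skip * x + c_out * network(c_in * x, c_noise)


def zero_network(x_in: float, c_noise: float) -> float:
    """Placeholder network that always predicts zero."""
    return 0.0


# With a zero network, the output reduces to the skip path c_skip * x.
print(edm_precondition(zero_network, x=2.0, sigma=0.5))  # 1.0
```

The raw network F only ever sees inputs and noise levels in a well-scaled range, while the skip path lets the wrapper pass most of the input through unchanged at low noise levels.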

Architecture Backbones

Diffusion model backbones are highly configurable architectures that can be used as building blocks for more complex models. Backbones support both conditional and unconditional modeling. Currently, two backbones are provided: the SongUNet, implemented in the SongUNet class, and the DhariwalUNet, implemented in the DhariwalUNet class. These models were introduced in the papers Score-based generative modeling through stochastic differential equations, Song et al., and Diffusion models beat GANs on image synthesis, Dhariwal et al. The PhysicsNeMo implementation of these models closely follows the one used in the paper Elucidating the Design Space of Diffusion-Based Generative Models, Karras et al. The original implementation of these models can be found in the EDM repository.

Model backbones can be used as is, such as in the StormCast example, but they can also be used as base classes for more complex models.

One of the most common diffusion backbones for image generation is the SongUNet class. Its latent state \(\mathbf{x}\) is a tensor of shape \((B, C, H, W)\), where \(B\) is the batch size, \(C\) is the number of channels, and \(H\) and \(W\) are the height and width of the feature map. The model is conditional on the noise level, and can additionally be conditioned on vector-valued class labels and/or images. The model is organized into levels, whose number is determined by len(channel_mult), and each level operates at half the resolution of the previous level (odd resolutions are rounded down). Each level is composed of a sequence of UNet blocks, which optionally contain self-attention layers, as controlled by the attn_resolutions parameter. The feature map resolution is halved at the first block of each level and then remains constant within the level.
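The level structure described above can be tabulated with a few lines of plain Python. This sketch assumes, as the description suggests, that the resolution at level i is img_resolution halved i times (rounded down), that the number of channels is model_channels * channel_mult[i], and that self-attention is enabled where that resolution appears in attn_resolutions. songunet_levels is a hypothetical helper, not part of PhysicsNeMo.

```python
from typing import List, Tuple


def songunet_levels(
    img_resolution: int,
    model_channels: int,
    channel_mult: List[int],
    attn_resolutions: List[int],
) -> List[Tuple[int, int, bool]]:
    """Return (resolution, channels, uses_attention) for each encoder level."""
    levels = []
    for level, mult in enumerate(channel_mult):
        res = img_resolution >> level  # halved once per level; the shift floors odd sizes
        channels = model_channels * mult
        levels.append((res, channels, res in attn_resolutions))
    return levels


for res, ch, attn in songunet_levels(40, 64, [1, 2, 3], [20, 10]):
    print(f"{res}x{res}: {ch} channels, attention={attn}")
# 40x40: 64 channels, attention=False
# 20x20: 128 channels, attention=True
# 10x10: 192 channels, attention=True
```

These values match the configuration used in the first example below.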

Here we start by creating a SongUNet model with 3 levels that applies self-attention at levels 1 and 2. The model is unconditional, i.e., it is not conditioned on any class labels or images (but it is still conditioned on the noise level, as is standard practice for diffusion models).

import torch
from physicsnemo.models.diffusion import SongUNet

B, C_x, res = 3, 6, 40   # Batch size, channels, and resolution of the latent state

model = SongUNet(
    img_resolution=res,
    in_channels=C_x,
    out_channels=C_x,  # No conditioning on image: number of output channels is the same as the input channels
    label_dim=0,  # No conditioning on vector-valued class labels
    augment_dim=0,
    model_channels=64,
    channel_mult=[1, 2, 3],  # 3-levels UNet with 64, 128, and 192 channels at each level, respectively
    num_blocks=4,  # 4 UNet blocks at each level
    attn_resolutions=[20, 10],  # Attention is applied at level 1 (resolution 20x20) and level 2 (resolution 10x10)
)

x = torch.randn(B, C_x, res, res)  # Latent state
noise_labels = torch.randn(B)  # Noise level for each sample

# The feature map resolution is 40 at level 0, 20 at level 1, and 10 at level 2
out = model(x, noise_labels, None)
print(out.shape)  # Shape: (B, C_x, res, res), same as the latent state

# The same model can be used on images of different resolution
# Note: the attention is still applied at levels 1 and 2
x_32 = torch.randn(B, C_x, 32, 32)  # Lower resolution latent state
out_32 = model(x_32, noise_labels, None)  # None means no conditioning on class labels
print(out_32.shape)  # Shape: (B, C_x, 32, 32), same as the latent state

The unconditional SongUNet can be extended to be conditional on class labels and/or images. Conditioning on images is performed by channel-wise concatenation of the image to the latent state \(\mathbf{x}\) before passing it to the model. The model does not perform conditioning on images internally, and this operation is left to the user. For conditioning on class labels (or any vector-valued quantity whose dimension is label_dim), the model internally generates embeddings for the class labels and adds them to intermediate activations within the UNet blocks. Here we extend the previous example to be conditional on a 16-dimensional vector-valued class label and a 3-channel image.

import torch
from physicsnemo.models.diffusion import SongUNet

B, C_x, res = 3, 10, 40
C_cond = 3

model = SongUNet(
    img_resolution=res,
    in_channels=C_x + C_cond,  # Conditioning on an image with C_cond channels
    out_channels=C_x,  # Output channels: only those of the latent state
    label_dim=16,  # Conditioning on 16-dimensional vector-valued class labels
    augment_dim=0,
    model_channels=64,
    channel_mult=[1, 2, 2],
    num_blocks=4,
    attn_resolutions=[20, 10],
)

x = torch.randn(B, C_x, res, res)  # Latent state
cond = torch.randn(B, C_cond, res, res)  # Conditioning image
x_cond = torch.cat([x, cond], dim=1)  # Channel-wise concatenation of the conditioning image before passing to the model
noise_labels = torch.randn(B)
class_labels = torch.randn(B, 16)  # Conditioning on vector-valued class labels

out = model(x_cond, noise_labels, class_labels)
print(out.shape)  # Shape: (B, C_x, res, res), same as the latent state

Specialized Architectures

Note that even though backbones can be used as is, some of the PhysicsNeMo examples use specialized architectures. These specialized architectures typically inherit from the backbones and implement additional functionality for specific applications. For example, the CorrDiff example uses the specialized architectures SongUNetPosEmbd and SongUNetPosLtEmbd to implement the diffusion model.

Positional embeddings

Multi-diffusion (also called patch-based diffusion) is a technique for scaling diffusion models to large domains. The idea is to split the full domain into patches, run a diffusion model on each patch in parallel, and then fuse the generated patches back into the final image. This technique is particularly useful for domains that are too large to fit into the memory of a single GPU. The CorrDiff example uses patch-based diffusion for weather downscaling on large domains. A key ingredient in the implementation of patch-based diffusion is a global spatial grid, which informs each patch of its position in the full domain. The SongUNetPosEmbd class implements this functionality by providing multiple methods to encode the global spatial coordinates of the pixels into a global positional embedding grid. Beyond multi-diffusion, spatial positional embeddings have also been observed to improve the quality of generated images, even for diffusion models that operate on the full domain.
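The split-and-fuse idea can be sketched in NumPy. The helpers extract_patches and fuse_patches below are hypothetical and are not the PhysicsNeMo patching utilities; averaging over overlapping regions is just one simple fusion strategy, assumed here for illustration. An identity function stands in for the per-patch diffusion model.

```python
import numpy as np


def extract_patches(image, patch, stride):
    """Split a (C, H, W) array into square patches stacked along a new batch axis."""
    _, H, W = image.shape
    corners = [
        (i, j)
        for i in range(0, H - patch + 1, stride)
        for j in range(0, W - patch + 1, stride)
    ]
    patches = np.stack([image[:, i:i + patch, j:j + patch] for i, j in corners])
    return patches, corners  # patches: (P, C, patch, patch)


def fuse_patches(patches, corners, shape):
    """Average overlapping patches back into a (C, H, W) array."""
    out = np.zeros(shape)
    count = np.zeros(shape[1:])
    p = patches.shape[-1]
    for patch_arr, (i, j) in zip(patches, corners):
        out[:, i:i + p, j:j + p] += patch_arr
        count[i:i + p, j:j + p] += 1
    return out / count  # assumes every pixel is covered by at least one patch


image = np.random.default_rng(0).standard_normal((3, 40, 40))
patches, corners = extract_patches(image, patch=16, stride=12)  # corners at 0, 12, 24
print(patches.shape)  # (9, 3, 16, 16)

denoised = patches  # identity stand-in for running the diffusion model on each patch
fused = fuse_patches(denoised, corners, image.shape)
print(np.allclose(fused, image))  # True
```

Because patches are stacked along the batch axis, they can be processed by the model as independent samples, which is the same layout used in the examples below.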

The following example shows how to use the specialized architecture SongUNetPosEmbd to implement a multi-diffusion model. First, we create a SongUNetPosEmbd model similar to the one in the conditional SongUNet example, with a global positional embedding grid of shape (C_PE, res, res). We show that the model can be used with the entire latent state (full domain).

import torch
from physicsnemo.models.diffusion import SongUNetPosEmbd

B, C_x, res = 3, 10, 40
C_cond = 3
C_PE = 8  # Number of channels in the positional embedding grid

# Create a SongUNet with a global positional embedding grid of shape (C_PE, res, res)
model = SongUNetPosEmbd(
    img_resolution=res,  # Define the resolution of the global positional embedding grid
    in_channels=C_x + C_cond + C_PE,  # in_channels must include the number of channels in the positional embedding grid
    out_channels=C_x,
    label_dim=16,
    augment_dim=0,
    model_channels=64,
    channel_mult=[1, 2, 2],
    num_blocks=4,
    attn_resolutions=[20, 10],
    gridtype="learnable",  # Use a learnable grid of positional embeddings
    N_grid_channels=C_PE  # Number of channels in the positional embedding grid
)

# Can pass the entire latent state to the model
x_global = torch.randn(B, C_x, res, res)  # Entire latent state
cond = torch.randn(B, C_cond, res, res)  # Conditioning image
x_cond = torch.cat([x_global, cond], dim=1)  # Latent state with conditioning image
noise_labels = torch.randn(B)
class_labels = torch.randn(B, 16)

# The model internally concatenates the global positional embedding grid to the
# input x_cond before the first UNet block.
# Note: global_index=None means use the entire positional embedding grid
out = model(x_cond, noise_labels, class_labels, global_index=None)
print(out.shape)  # Shape: (B, C_x, res, res), same as the latent state

Now we show that the model can be used on local patches of the latent state (multi-diffusion approach). We manually extract 3 patches from the latent state. Patches are treated as individual samples, so they are concatenated along the batch dimension. We also create a grid, grid, that contains the global indices of the pixels in the full domain; we extract the same 3 patches from it and pass them to the global_index parameter. The model internally uses global_index to extract the corresponding patches from the positional embedding grid and concatenates them to the input x_cond_patches before the first UNet block. Note that conditional multi-diffusion still requires each patch to be conditioned on the entire conditioning image cond, which is why we interpolate the conditioning image to the patch resolution and concatenate it to each individual patch. In practice, it is not necessary to manually extract the patches from the latent state and the global grid, as PhysicsNeMo provides utilities to help with the patching operations, in patching. For an example of how to use these utilities, see the CorrDiff example.

# Can pass local patches to the model
# Create batch of 3 patches from `x_global` with resolution 16x16
pres = 16  # Patch resolution
p1 = x_global[0:1, :, :pres, :pres]  # Patch 1
p2 = x_global[2:3, :, pres:2*pres, pres:2*pres]  # Patch 2 (from sample 2; valid batch indices are 0 to B-1)
p3 = x_global[1:2, :, -pres:, pres:2*pres]  # Patch 3
patches = torch.cat([p1, p2, p3], dim=0)  # Batch of 3 patches

# Note: the conditioning image needs interpolation (or other operations) to
# match the patch resolution
cond1 = torch.nn.functional.interpolate(cond[0:1], size=(pres, pres), mode="bilinear")
cond2 = torch.nn.functional.interpolate(cond[2:3], size=(pres, pres), mode="bilinear")
cond3 = torch.nn.functional.interpolate(cond[1:2], size=(pres, pres), mode="bilinear")
cond_patches = torch.cat([cond1, cond2, cond3], dim=0)

# Concatenate the patches and the conditioning image
x_cond_patches = torch.cat([patches, cond_patches], dim=1)

# Create corresponding global indices for the patches
Ny, Nx = torch.arange(res).int(), torch.arange(res).int()
grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0)
idx_patch1 = grid[:, :pres, :pres]  # Global indices for patch 1
idx_patch2 = grid[:, pres:2*pres, pres:2*pres]  # Global indices for patch 2
idx_patch3 = grid[:, -pres:, pres:2*pres]  # Global indices for patch 3
global_index = torch.stack([idx_patch1, idx_patch2, idx_patch3], dim=0)

# The model internally extracts the corresponding patches from the global
# positional embedding grid and concatenates them to the input x_cond_patches
# before the first UNet block.
out = model(x_cond_patches, noise_labels, class_labels, global_index=global_index)
print(out.shape)  # Shape: (3, C_x, pres, pres), same as the patches extracted from the latent state
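Conceptually, selecting positional embeddings with global_index is a gather operation: for each patch pixel, the embedding vector at its global (row, column) position is picked from the full grid. The NumPy sketch below illustrates this under the assumption that global_index has shape (P, 2, pres, pres), as in the example above; it is not the model's internal implementation, and the random pos_emb array merely stands in for a learnable grid.

```python
import numpy as np

C_PE, res, pres = 8, 40, 16
pos_emb = np.random.default_rng(0).standard_normal((C_PE, res, res))  # stand-in for a learnable grid

# Global (row, col) indices of every pixel, shape (2, res, res)
rows, cols = np.meshgrid(np.arange(res), np.arange(res), indexing="ij")
grid = np.stack([rows, cols])

# Indices for two patches, stacked to shape (P, 2, pres, pres)
global_index = np.stack([grid[:, :pres, :pres], grid[:, -pres:, pres:2 * pres]])

# For each patch, pick pos_emb[:, row, col] at every patch pixel (advanced indexing)
selected = np.stack([pos_emb[:, idx[0], idx[1]] for idx in global_index])
print(selected.shape)  # (2, 8, 16, 16)
print(np.array_equal(selected[1], pos_emb[:, -pres:, pres:2 * pres]))  # True
```

Indexing by explicit coordinates, rather than by slice offsets, also works for patches that are not axis-aligned rectangles, which is one reason to pass a full index grid.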

Lead-time aware models

In many diffusion applications, the latent state is time-dependent, and the diffusion process should account for the time-dependence of the latent state. For instance, a forecast model could provide latent states \(\mathbf{x}(T)\) (current time), \(\mathbf{x}(T + \Delta t)\) (one time step forward), …, up to \(\mathbf{x}(T + K \Delta t)\) (K time steps forward). Such prediction horizons are called lead-times (a term adopted from the weather and climate forecasting community) and we want to apply diffusion to each of these latent states while accounting for their associated lead-time information.

PhysicsNeMo provides a specialized architecture SongUNetPosLtEmbd that implements lead-time aware models. This is an extension of the SongUNetPosEmbd class that additionally supports lead-time information. In its forward pass, the model uses the lead_time_label parameter to internally retrieve the associated lead-time embeddings; it then conditions the diffusion process on those embeddings through a channel-wise concatenation to the latent state before the first UNet block.

Here we show an example extending the previous ones with lead-time information. We assume that we have a batch of 3 latent states at times \(T + 2 \Delta t\) (2 time intervals forward), \(T + 0 \Delta t\) (current time), and \(T + \Delta t\) (1 time interval forward). The associated lead-time labels are [2, 0, 1]. In addition, the SongUNetPosLtEmbd model has the ability to predict probabilities for some channels of the latent state, specified by the prob_channels parameter. Here we assume that channels 1 and 3 are probability (i.e. classification) outputs, while other channels are regression outputs.

import torch
from physicsnemo.models.diffusion import SongUNetPosLtEmbd

B, C_x, res = 3, 10, 40
C_cond = 3
C_PE = 8
lead_time_steps = 3  # Maximum supported lead-time is 2 * dt
C_LT = 6  # 6 channels for each lead-time embedding

# Create a SongUNet with a lead-time embedding grid of shape
# (lead_time_steps, C_LT, res, res)
model = SongUNetPosLtEmbd(
    img_resolution=res,
    in_channels=C_x + C_cond + C_PE + C_LT,  # in_channels must include the number of channels in lead-time grid
    out_channels=C_x,
    label_dim=16,
    augment_dim=0,
    model_channels=64,
    channel_mult=[1, 2, 2],
    num_blocks=4,
    attn_resolutions=[20, 10],  # Feature maps are 40, 20, and 10: attention at levels 1 and 2
    gridtype="learnable",
    N_grid_channels=C_PE,
    lead_time_channels=C_LT,
    lead_time_steps=lead_time_steps,  # Maximum supported lead-time horizon
    prob_channels=[1, 3],  # Channels 1 and 3 from the latent state are probability outputs
)

x = torch.randn(B, C_x, res, res)  # Latent state at times T+2*dt, T+0*dt, and T + 1*dt
cond = torch.randn(B, C_cond, res, res)
x_cond = torch.cat([x, cond], dim=1)
noise_labels = torch.randn(B)
class_labels = torch.randn(B, 16)
lead_time_label = torch.tensor([2, 0, 1])  # Lead-time labels for each sample

# The model internally extracts the lead-time embeddings corresponding to the
# lead-time labels 2, 0, 1 and concatenates them to the input x_cond before the first
# UNet block. In training mode, the model outputs logits for channels 1 and 3.
out = model(x_cond, noise_labels, class_labels, lead_time_label=lead_time_label)
print(out.shape)  # Shape: (B, C_x, res, res), same as the latent state

# In eval mode, the model outputs probabilities for channels 1 and 3
model.eval()
out = model(x_cond, noise_labels, class_labels, lead_time_label=lead_time_label)
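The lead-time conditioning described above can be pictured as a table lookup followed by a channel-wise concatenation. In this NumPy sketch, lt_grid is a random stand-in for the learnable lead-time embedding grid, and the lookup-then-concatenate logic is an assumption based on the description of lead_time_label, not the SongUNetPosLtEmbd source.

```python
import numpy as np

B, C_x, C_LT, res = 3, 10, 6, 40
lead_time_steps = 3
rng = np.random.default_rng(0)

# Stand-in for the learnable lead-time grid: one (C_LT, res, res) map per lead time
lt_grid = rng.standard_normal((lead_time_steps, C_LT, res, res))

x = rng.standard_normal((B, C_x, res, res))
lead_time_label = np.array([2, 0, 1])

lt_emb = lt_grid[lead_time_label]  # (B, C_LT, res, res): sample b gets lt_grid[label[b]]
x_with_lt = np.concatenate([x, lt_emb], axis=1)  # channel-wise concatenation
print(x_with_lt.shape)  # (3, 16, 40, 40)
```

Because the lookup is a plain per-sample index, labels in any order, or repeated labels such as [2, 2, 1], work the same way.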

Note

The SongUNetPosLtEmbd is not an autoregressive model that performs a rollout to produce future predictions. From the point of view of the SongUNetPosLtEmbd, the lead-time information is frozen. However, the lead-time-dependent latent state \(\mathbf{x}\) might be produced by such an autoregressive/rollout model.

Note

The SongUNetPosLtEmbd model cannot be scaled to very long lead-time horizons (controlled by the lead_time_steps parameter), because the lead-time embeddings are represented by a grid of learnable parameters of shape (lead_time_steps, C_LT, res, res). For very long lead-time horizons, this grid of embeddings becomes prohibitively large.
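The size of this grid is simply the product of its dimensions. The helper below is hypothetical; the first call uses the values from the lead-time example, and the second uses made-up values (100 lead-time steps on a 448x448 domain) to show how quickly the count grows. Other parameters of the model are not included.

```python
def lead_time_grid_params(lead_time_steps: int, c_lt: int, height: int, width: int) -> int:
    """Number of learnable parameters in a (lead_time_steps, C_LT, H, W) embedding grid."""
    return lead_time_steps * c_lt * height * width


print(lead_time_grid_params(3, 6, 40, 40))  # 28800
print(lead_time_grid_params(100, 6, 448, 448))  # 120422400
```

The count grows linearly with lead_time_steps but quadratically with the spatial resolution, so long horizons on large domains quickly dominate the model size.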

Note

In a given input batch x, the associated lead-times are not necessarily consecutive or in order, and they do not even need to originate from the same forecast trajectory. For example, the lead-time labels might be [0, 1, 2] instead of [2, 0, 1], or even [2, 2, 1].

Application-specific Interfaces

Application-specific interfaces are not true architectures, but rather wrappers around the model backbones or specialized architectures that provide a more user-friendly interface for specific applications. Note that not all of these classes are diffusion models; some are meant to be used in conjunction with diffusion models. For instance, the CorrDiff example uses the UNet class to implement a regression model.

SongUNet

class physicsnemo.models.diffusion.song_unet.SongUNet(*args, **kwargs)

Bases: Module

This architecture is a diffusion backbone for 2D image generation. It is a reimplementation of the DDPM++ and NCSN++ architectures, which are U-Net variants with optional self-attention, embeddings, and encoder-decoder components.

This model supports conditional and unconditional setups, as well as several options for various internal architectural choices such as encoder and decoder type, embedding type, etc., making it flexible and adaptable to different tasks and configurations.

This architecture supports conditioning on the noise level (called noise labels), as well as on additional vector-valued labels (called class labels) and (optional) vector-valued augmentation labels. The conditioning mechanism relies on addition of the conditioning embeddings in the U-Net blocks of the encoder. To condition on images, the simplest mechanism is to concatenate the image to the input before passing it to the SongUNet.

The model first applies a mapping operation to generate embeddings for all the conditioning inputs (the noise level, the class labels, and the optional augmentation labels).

Then, at each level in the U-Net encoder, a sequence of blocks is applied:

  • The first block downsamples the feature map resolution by a factor of 2 (odd resolutions are floored). This block does not change the number of channels.

  • A sequence of num_blocks U-Net blocks is then applied. These blocks do not change the feature map resolution; the first of them sets the number of channels to channel_mult[i] * model_channels at level i, and the following blocks keep that number. If required, the U-Net blocks also apply self-attention at the specified resolutions.

  • At the end of the level, the feature map is cached to be used in a skip connection in the decoder.

The decoder is a mirror of the encoder, with the same number of levels and the same number of blocks per level. It multiplies the feature map resolution by a factor of 2 at each level.
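The mapping step mentioned above turns each scalar noise label into an embedding vector. For embedding_type="positional", this is a transformer-style sinusoidal embedding; the NumPy sketch below shows the general idea. The function name, the max_positions=10000 base, and the cos-then-sin channel ordering are assumptions and may differ from the PhysicsNeMo implementation.

```python
import numpy as np


def positional_noise_embedding(noise_labels, num_channels, max_positions=10000):
    """Transformer-style sinusoidal embedding of shape (B, num_channels)."""
    half = num_channels // 2
    freqs = (1.0 / max_positions) ** (np.arange(half) / half)  # geometric frequency ladder
    angles = np.outer(noise_labels, freqs)  # (B, half)
    return np.concatenate([np.cos(angles), np.sin(angles)], axis=1)


emb = positional_noise_embedding(np.array([0.0, 0.5, 1.0]), num_channels=8)
print(emb.shape)  # (3, 8)
```

Each channel oscillates at a different frequency, so nearby noise levels get similar embeddings while distant ones remain distinguishable; the resulting vector is then processed further and added inside the U-Net blocks.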

Parameters:
  • img_resolution (Union[List[int, int], int]) –

    The resolution of the input/output image. Can be a single int \(H\) for square images or a list \([H, W]\) for rectangular images.

    Note: This parameter is only used as a convenience to build the network. In practice, the model can still be used with images of different resolutions. The only exception to this rule is when additive_pos_embed is True, in which case the resolution of the latent state \(\mathbf{x}\) must match img_resolution.

  • in_channels (int) – Number of channels \(C_{in}\) in the input image. May include channels from both the latent state and additional channels when conditioning on images. For an unconditional model, this should be equal to out_channels.

  • out_channels (int) – Number of channels \(C_{out}\) in the output image. Should be equal to the number of channels \(C_{\mathbf{x}}\) in the latent state.

  • label_dim (int, optional, default=0) – Dimension of the vector-valued class_labels conditioning; 0 indicates no conditioning on class labels.

  • augment_dim (int, optional, default=0) – Dimension of the vector-valued augment_labels conditioning; 0 means no conditioning on augmentation labels.

  • model_channels (int, optional, default=128) – Base multiplier for the number of channels across the entire network.

  • channel_mult (List[int], optional, default=[1, 2, 2, 2]) – Multipliers for the number of channels at every level in the encoder and decoder. The length of channel_mult determines the number of levels in the U-Net. At level i, the number of channels in the feature map is channel_mult[i] * model_channels.

  • channel_mult_emb (int, optional, default=4) – Multiplier for the number of channels in the embedding vector. The embedding vector has model_channels * channel_mult_emb channels.

  • num_blocks (int, optional, default=4) – Number of U-Net blocks at each level.

  • attn_resolutions (List[int], optional, default=[16]) – Resolutions of the levels at which self-attention layers are applied. Note that the feature map resolution must match exactly the value provided in attn_resolutions for the self-attention layers to be applied.

  • dropout (float, optional, default=0.10) – Dropout probability applied to intermediate activations within the U-Net blocks.

  • label_dropout (float, optional, default=0.0) – Dropout probability applied to the class_labels. Typically used for classifier-free guidance.

  • embedding_type (Literal["fourier", "positional", "zero"], optional, default="positional") – Diffusion timestep embedding type: ‘positional’ for DDPM++, ‘fourier’ for NCSN++, ‘zero’ for none.

  • channel_mult_noise (int, optional, default=1) – Multiplier for the number of channels in the noise level embedding. The noise level embedding vector has model_channels * channel_mult_noise channels.

  • encoder_type (Literal["standard", "skip", "residual"], optional, default="standard") – Encoder architecture: ‘standard’ for DDPM++, ‘residual’ for NCSN++, ‘skip’ for skip connections.

  • decoder_type (Literal["standard", "skip"], optional, default="standard") – Decoder architecture: ‘standard’ or ‘skip’ for skip connections.

  • resample_filter (List[int], optional, default=[1, 1]) – Resampling filter coefficients applied in the convolutions of the U-Net blocks: [1,1] for DDPM++, [1,3,3,1] for NCSN++.

  • checkpoint_level (int, optional, default=0) – Number of levels that should use gradient checkpointing. Only levels at which the feature map resolution is large enough will be checkpointed (0 disables checkpointing; higher values mean more levels are checkpointed). Higher values reduce memory usage at the cost of additional computation.

  • additive_pos_embed (bool, optional, default=False) –

    If True, adds a learnable positional embedding after the first convolution layer. Used in the StormCast model.

    Note: These positional embeddings encode spatial position information of the image pixels, unlike the embedding_type parameter, which encodes temporal information about the diffusion process. In that sense, they are a simpler version of the positional embedding used in SongUNetPosEmbd.

  • use_apex_gn (bool, optional, default=False) – Whether to use Apex GroupNorm for the NHWC layout. Apex must be installed for this to work. Must be set to False on CPU.

  • act (str, optional, default=None) – The activation function to use when fusing activation with GroupNorm. Required when use_apex_gn is True.

  • profile_mode (bool, optional, default=False) – A flag indicating whether to enable all NVTX annotations during profiling.

  • amp_mode (bool, optional, default=False) – A flag indicating whether mixed-precision (AMP) training is enabled.
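label_dropout is applied to class_labels during training so that the same model also learns an unconditional mode, which classifier-free guidance relies on at sampling time. One common way to implement this, assumed here rather than taken from the PhysicsNeMo source, is to zero each sample's whole label vector with probability label_dropout. drop_labels is a hypothetical helper.

```python
import numpy as np


def drop_labels(class_labels, label_dropout, rng):
    """Zero out whole label vectors, each with probability label_dropout (training only)."""
    keep = rng.random((class_labels.shape[0], 1)) >= label_dropout  # one draw per sample
    return class_labels * keep  # broadcasts the (B, 1) mask over label_dim


rng = np.random.default_rng(0)
labels = np.ones((4, 16))  # B=4 samples with label_dim=16
print(drop_labels(labels, 0.0, rng).sum())  # 64.0: nothing is dropped
print(drop_labels(labels, 1.0, rng).sum())  # 0.0: every label vector is dropped
```

Dropping the whole vector, rather than individual entries, is what makes a zeroed label act as a consistent "no condition" signal.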

Forward:
  • x (torch.Tensor) – The input image of shape \((B, C_{in}, H_{in}, W_{in})\). In general x is the channel-wise concatenation of the latent state \(\mathbf{x}\) and additional images used for conditioning. For an unconditional model, x is simply the latent state \(\mathbf{x}\).

    Note: \(H_{in}\) and \(W_{in}\) do not need to match \(H\) and \(W\) defined in img_resolution, except when additive_pos_embed is True. In that case, the resolution of x must match img_resolution.

  • noise_labels (torch.Tensor) – The noise labels of shape \((B,)\). Used for conditioning on the diffusion noise level.

  • class_labels (torch.Tensor) – The class labels of shape \((B, \text{label\_dim})\). Used for conditioning on any vector-valued quantity. Can pass None when label_dim is 0.

  • augment_labels (torch.Tensor, optional, default=None) – The augmentation labels of shape \((B, \text{augment\_dim})\). Used for conditioning on any additional vector-valued quantity. Can pass None when augment_dim is 0.

Outputs:

torch.Tensor – The denoised latent state of shape \((B, C_{out}, H_{in}, W_{in})\).

Important

  • The terms noise levels (or noise labels) are used to refer to the diffusion time-step, as these are conceptually equivalent.

  • The terms labels and classes originate from the original paper and EDM repository, where this architecture was used for class-conditional image generation. While these terms suggest class-based conditioning, the architecture can be conditioned on any vector-valued quantity.

  • The term positional embedding used in the embedding_type parameter also comes from the original paper and EDM repository. Here, positional refers to the diffusion time-step, similar to how position is used in transformer architectures. Despite the name, these embeddings encode temporal information about the diffusion process rather than spatial position information.

  • Limitations on input image resolution: for a model that has \(N\) levels, the latent state \(\mathbf{x}\) must have a resolution that is a multiple of \(2^N\) in each dimension. This is due to a limitation in the decoder that does not support shape mismatches in the residual connections from the encoder to the decoder. For images that do not meet this requirement, it is recommended to interpolate the data onto a grid of the required resolution beforehand.
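To pick a target resolution that satisfies the rule above, each spatial dimension can be rounded up to the next multiple of \(2^N\), where \(N\) is len(channel_mult). next_valid_resolution is a hypothetical helper that applies the \(2^N\) rule exactly as stated here.

```python
def next_valid_resolution(size: int, num_levels: int) -> int:
    """Smallest resolution >= size that is a multiple of 2**num_levels."""
    multiple = 2 ** num_levels
    return -(-size // multiple) * multiple  # ceiling division, then scale back up


print(next_valid_resolution(40, 3))  # 40 (already valid)
print(next_valid_resolution(721, 4))  # 736
```

The data can then be resampled to the new size, for example with torch.nn.functional.interpolate(x, size=(H_new, W_new), mode="bilinear"), where H_new and W_new are the values returned for each dimension.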

Example

>>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2)
>>> noise_labels = torch.randn([1])
>>> class_labels = torch.randint(0, 1, (1, 1))
>>> input_image = torch.ones([1, 2, 16, 16])
>>> output_image = model(input_image, noise_labels, class_labels)
>>> output_image.shape
torch.Size([1, 2, 16, 16])
property amp_mode

Should be set to True to enable automatic mixed precision.

property profile_mode

Should be set to True to enable profiling.

DhariwalUNet

class physicsnemo.models.diffusion.dhariwal_unet.DhariwalUNet(*args, **kwargs)

Bases: Module

This architecture is a diffusion backbone for 2D image generation. It reimplements the ADM architecture, a U-Net variant, with optional self-attention.

It is very similar to the U-Net backbone defined in SongUNet and differs only in a few aspects:

  • The embedding conditioning mechanism relies on adaptive scaling of the group normalization layers within the U-Net blocks.

  • Parameter initialization follows Kaiming uniform initialization.
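The adaptive scaling mechanism can be sketched as follows: the conditioning embedding is mapped to a per-channel scale and shift, which modulate the group-normalized features as shift + GroupNorm(x) * (1 + scale). This formulation follows the EDM reference code as we understand it and is an assumption here, not a transcription of DhariwalUNet; the helpers group_norm and adaptive_group_norm are hypothetical, and the affine bias is omitted for brevity.

```python
import numpy as np


def group_norm(x, num_groups, eps=1e-5):
    """Normalize a (B, C, H, W) array over each group of channels and all spatial positions."""
    B, C, H, W = x.shape
    g = x.reshape(B, num_groups, C // num_groups, H, W)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(B, C, H, W)


def adaptive_group_norm(x, emb, weight, num_groups=4):
    """Modulate group-normalized features with a scale and shift predicted from emb.

    x: (B, C, H, W) features; emb: (B, E) embedding; weight: (E, 2 * C) affine map.
    """
    scale, shift = np.split(emb @ weight, 2, axis=1)  # each (B, C)
    scale = scale[:, :, None, None]  # broadcast over H and W
    shift = shift[:, :, None, None]
    return shift + group_norm(x, num_groups) * (scale + 1.0)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 4, 4))
emb = rng.standard_normal((2, 16))

# With zero weights, scale = shift = 0 and the block reduces to plain group norm.
out = adaptive_group_norm(x, emb, np.zeros((16, 16)))
print(np.allclose(out, group_norm(x, 4)))  # True
```

Compared with the additive conditioning of SongUNet, the embedding here also rescales the features multiplicatively, which gives the conditioning a more direct effect on activation magnitudes.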

Parameters:
  • img_resolution (int) –

    The resolution \(H = W\) of the input/output image. Assumes square images.

    Note: This parameter is only used as a convenience to build the network. In practice, the model can still be used with images of different resolutions.

  • in_channels (int) – Number of channels \(C_{in}\) in the input image. May include channels from both the latent state \(\mathbf{x}\) and additional channels when conditioning on images. For an unconditional model, this should be equal to out_channels.

  • out_channels (int) – Number of channels \(C_{out}\) in the output image. Should be equal to the number of channels \(C_{\mathbf{x}}\) in the latent state.

  • label_dim (int, optional, default=0) – Dimension of the vector-valued class_labels conditioning; 0 indicates no conditioning on class labels.

  • augment_dim (int, optional, default=0) – Dimension of the vector-valued augment_labels conditioning; 0 means no conditioning on augmentation labels.

  • model_channels (int, optional, default=128) – Base multiplier for the number of channels across the entire network.

  • channel_mult (List[int], optional, default=[1,2,2,2]) – Multipliers for the number of channels at every level in the encoder and decoder. The length of channel_mult determines the number of levels in the U-Net. At level i, the number of channels in the feature map is channel_mult[i] * model_channels.

  • channel_mult_emb (int, optional, default=4) – Multiplier for the number of channels in the embedding vector. The embedding vector has model_channels * channel_mult_emb channels.

  • num_blocks (int, optional, default=3) – Number of U-Net blocks at each level.

  • attn_resolutions (List[int], optional, default=[16]) – Resolutions of the levels at which self-attention layers are applied. Note that the feature map resolution must match exactly the value provided in attn_resolutions for the self-attention layers to be applied.

  • dropout (float, optional, default=0.10) – Dropout probability applied to intermediate activations within the U-Net blocks.

  • label_dropout (float, optional, default=0.0) – Dropout probability applied to the class_labels. Typically used for classifier-free guidance.

Forward:
  • x (torch.Tensor) – The input tensor of shape \((B, C_{in}, H_{in}, W_{in})\). In general x is the channel-wise concatenation of the latent state \(\mathbf{x}\) and additional images used for conditioning. For an unconditional model, x is simply the latent state \(\mathbf{x}\).

  • noise_labels (torch.Tensor) – The noise labels of shape \((B,)\). Used for conditioning on the noise level.

  • class_labels (torch.Tensor) – The class labels of shape \((B, \text{label\_dim})\). Used for conditioning on any vector-valued quantity. Can pass None when label_dim is 0.

  • augment_labels (torch.Tensor, optional, default=None) – The augmentation labels of shape \((B, \text{augment\_dim})\). Used for conditioning on any additional vector-valued quantity. Can pass None when augment_dim is 0.

Outputs:

torch.Tensor – The denoised latent state of shape \((B, C_{out}, H_{in}, W_{in})\).

Examples

>>> model = DhariwalUNet(img_resolution=16, in_channels=2, out_channels=2)
>>> noise_labels = torch.randn([1])
>>> class_labels = torch.randint(0, 1, (1, 1))
>>> input_image = torch.ones([1, 2, 16, 16])
>>> output_image = model(input_image, noise_labels, class_labels)
>>> output_image.shape
torch.Size([1, 2, 16, 16])
property amp_mode

Should be set to True to enable automatic mixed precision.

property profile_mode

Should be set to True to enable profiling.

SongUNetPosEmbd

class physicsnemo.models.diffusion.song_unet.SongUNetPosEmbd(*args, **kwargs)

Bases: SongUNet

This specialized architecture extends SongUNet with positional embeddings that encode global spatial coordinates of the pixels.

This model supports the same types of conditioning as the base SongUNet, and can additionally be conditioned on the positional embeddings. Conditioning on the positional embeddings is performed with a channel-wise concatenation to the input image before the first layer of the U-Net. Multiple types of positional embeddings are supported. Positional embeddings are represented by a 2D grid of shape \((C_{PE}, H, W)\), where \(H\) and \(W\) correspond to the img_resolution parameter.

The following types of positional embeddings are supported:

  • learnable: uses a 2D grid of learnable parameters.

  • linear: uses a 2D rectilinear grid over the domain \([-1, 1] \times [-1, 1]\).

  • sinusoidal: uses sinusoidal functions of the spatial coordinates, with possibly multiple frequency bands.

  • test: uses a 2D grid of integer indices, only used for testing.
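The linear and sinusoidal grid types can be sketched in a few lines of PyTorch. This is a minimal illustration, not the PhysicsNeMo implementation: the helper names linear_grid and sinusoidal_grid, the channel ordering, the \([0, 2\pi]\) coordinate range and the power-of-two frequency schedule are all assumptions. It shows one natural reason a sinusoidal grid comes in multiples of 4 channels (sine and cosine of each of the two coordinates per frequency band), and how the full grid is concatenated to the input when no patch selection is used.

```python
import math

import torch


def linear_grid(h: int, w: int) -> torch.Tensor:
    """Rectilinear grid over [-1, 1] x [-1, 1] with shape (2, H, W)."""
    yy, xx = torch.meshgrid(
        torch.linspace(-1.0, 1.0, h), torch.linspace(-1.0, 1.0, w), indexing="ij"
    )
    # Channel order (x, y) is an assumption of this sketch.
    return torch.stack([xx, yy], dim=0)


def sinusoidal_grid(h: int, w: int, n_channels: int) -> torch.Tensor:
    """Sinusoidal grid with shape (n_channels, H, W).

    Each frequency band contributes 4 channels: sin(x), sin(y), cos(x), cos(y).
    """
    if n_channels <= 0 or n_channels % 4 != 0:
        raise ValueError("n_channels must be a positive multiple of 4")
    # Coordinate range [0, 2*pi] and the 2**k frequencies are assumptions.
    yy, xx = torch.meshgrid(
        torch.linspace(0.0, 2 * math.pi, h),
        torch.linspace(0.0, 2 * math.pi, w),
        indexing="ij",
    )
    channels = []
    for k in range(n_channels // 4):
        freq = 2.0**k
        for fn in (torch.sin, torch.cos):
            channels.append(fn(freq * xx))
            channels.append(fn(freq * yy))
    return torch.stack(channels, dim=0)


# No patching: broadcast the full grid over the batch and concatenate it
# channel-wise to the input.
x = torch.ones(3, 2, 16, 16)  # (B, C_in, H, W)
pe = sinusoidal_grid(16, 16, n_channels=4)  # (C_PE, H, W)
x_in = torch.cat([x, pe.unsqueeze(0).expand(x.shape[0], -1, -1, -1)], dim=1)
print(x_in.shape)  # torch.Size([3, 6, 16, 16])
```

The "learnable" type would replace the fixed grid with a trainable (C_PE, H, W) tensor; the concatenation step stays the same.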

When the spatial resolution of the input image is smaller than that of the global positional embedding grid, a subset (or patch) of the embedding grid that corresponds to the spatial locations of the input image pixels must be selected. The model provides two methods for selecting this subset of positional embeddings:

  1. Using a selector function. See positional_embedding_selector() for details.

  2. Using global indices. See positional_embedding_indexing() for details.

If neither is provided, the entire grid of positional embeddings is used and concatenated channel-wise to the input image.

Most parameters are the same as in the parent class SongUNet. Only the ones that differ are listed below.

Parameters:
  • img_resolution (Union[List[int, int], int]) – The resolution of the input/output image. Can be a single int for square images or a list \([H, W]\) for rectangular images. Used to set the resolution of the positional embedding grid. It must correspond to the spatial resolution of the global domain/image.

  • in_channels (int) –

    Number of channels \(C_{in} + C_{PE}\), where \(C_{in}\) is the number of channels in the image passed to the U-Net and \(C_{PE}\) is the number of channels in the positional embedding grid.

    Important: in comparison to the base SongUNet, this parameter should also include the number of channels in the positional embedding grid \(C_{PE}\).

  • gridtype (Literal["sinusoidal", "learnable", "linear", "test"], optional, default="sinusoidal") – Type of positional embedding to use. Controls how spatial pixel locations are encoded.

  • N_grid_channels (int, optional, default=4) – Number of channels \(C_{PE}\) in the positional embedding grid. For ‘sinusoidal’, it must be a multiple of 4. For ‘linear’ and ‘test’, it must be 2. For ‘learnable’, it can be any value. If 0, positional embedding is disabled (but lead_time_mode may still be used).

  • lead_time_mode (bool, optional, default=False) – Provided for convenience. It is recommended to use the architecture SongUNetPosLtEmbd for a lead-time aware model.

  • lead_time_channels (int, optional, default=None) – Provided for convenience. Refer to SongUNetPosLtEmbd.

  • lead_time_steps (int, optional, default=9) – Provided for convenience. Refer to SongUNetPosLtEmbd.

  • prob_channels (List[int], optional, default=[]) – Provided for convenience. Refer to SongUNetPosLtEmbd.
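Because in_channels must already account for the embedding channels, it can help to compute it explicitly. The hypothetical helper below (not part of PhysicsNeMo) encodes the N_grid_channels constraints listed above and returns the value to pass as in_channels; whether the library validates exactly these conditions, and with which error types, is an assumption.

```python
def required_in_channels(
    c_in: int,
    gridtype: str = "sinusoidal",
    n_grid_channels: int = 4,
    lead_time_channels: int = 0,
) -> int:
    """Hypothetical helper: value to pass as ``in_channels``.

    Encodes the N_grid_channels rules and returns
    C_in + C_PE (+ C_LT when lead-time embeddings are used).
    """
    if n_grid_channels == 0:
        # Positional embedding disabled; only lead-time channels (if any) are added.
        return c_in + lead_time_channels
    if gridtype == "sinusoidal" and n_grid_channels % 4 != 0:
        raise ValueError("'sinusoidal' requires N_grid_channels to be a multiple of 4")
    if gridtype in ("linear", "test") and n_grid_channels != 2:
        raise ValueError(f"'{gridtype}' requires N_grid_channels == 2")
    return c_in + n_grid_channels + lead_time_channels


print(required_in_channels(2))  # 6, matching in_channels=2+4 for SongUNetPosEmbd
print(required_in_channels(2, lead_time_channels=4))  # 10, matching 2+4+4
print(required_in_channels(3, gridtype="linear", n_grid_channels=2))  # 5
```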

Forward:
  • x (torch.Tensor) – The input image of shape \((B, C_{in}, H_{in}, W_{in})\), where \(H_{in}\) and \(W_{in}\) are the spatial dimensions of the input image (does not need to be the full image). In general x is the channel-wise concatenation of the latent state \(\mathbf{x}\) and additional images used for conditioning. For an unconditional model, x is simply the latent state \(\mathbf{x}\).

    Note: \(H_{in}\) and \(W_{in}\) do not need to match the img_resolution parameter, except when additive_pos_embed is True. In all other cases, the resolution of x must be smaller than or equal to img_resolution.

  • noise_labels (torch.Tensor) – The noise labels of shape \((B,)\). Used for conditioning on the diffusion noise level.

  • class_labels (torch.Tensor) – The class labels of shape \((B, \text{label_dim})\). Used for conditioning on any vector-valued quantity. Can pass None when label_dim is 0.

  • global_index (torch.Tensor, optional, default=None) – The global indices of the positional embeddings to use. If neither global_index nor embedding_selector is provided, the entire positional embedding grid of shape \((C_{PE}, H, W)\) is used. In this case x must have the same spatial resolution as the positional embedding grid. See positional_embedding_indexing() for details.

  • embedding_selector (Callable, optional, default=None) – A function that selects the positional embeddings to use. See positional_embedding_selector() for details.

  • augment_labels (torch.Tensor, optional, default=None) – The augmentation labels of shape \((B, \text{augment_dim})\). Used for conditioning on any additional vector-valued quantity. Can pass None when augment_dim is 0.

Outputs:

torch.Tensor – The output tensor of shape \((B, C_{out}, H_{in}, W_{in})\).

Important

Unlike positional embeddings defined by embedding_type in the parent class SongUNet that encode the diffusion time-step (or noise level), the positional embeddings in this specialized architecture encode global spatial coordinates of the pixels.

Examples

>>> import torch
>>> from physicsnemo.models.diffusion.song_unet import SongUNetPosEmbd
>>> from physicsnemo.utils.patching import GridPatching2D
>>>
>>> # Model initialization - in_channels must include both original input channels (2)
>>> # and the positional embedding channels (N_grid_channels=4 by default)
>>> model = SongUNetPosEmbd(img_resolution=16, in_channels=2+4, out_channels=2)
>>> noise_labels = torch.randn([1])
>>> class_labels = torch.randint(0, 1, (1, 1))
>>> # The input has only the original 2 channels - positional embeddings are
>>> # added automatically inside the forward method
>>> input_image = torch.ones([1, 2, 16, 16])
>>> output_image = model(input_image, noise_labels, class_labels)
>>> output_image.shape
torch.Size([1, 2, 16, 16])
>>>
>>> # Using a global index to select all positional embeddings
>>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(16, 16))
>>> global_index = patching.global_index(batch_size=1)
>>> output_image = model(
...     input_image, noise_labels, class_labels,
...     global_index=global_index
... )
>>> output_image.shape
torch.Size([1, 2, 16, 16])
>>>
>>> # Using a custom embedding selector to select all positional embeddings
>>> def patch_embedding_selector(emb):
...     return patching.apply(emb[None].expand(1, -1, -1, -1))
>>> output_image = model(
...     input_image, noise_labels, class_labels,
...     embedding_selector=patch_embedding_selector
... )
>>> output_image.shape
torch.Size([1, 2, 16, 16])

positional_embedding_indexing(x: Tensor, global_index: Tensor | None = None, lead_time_label: Tensor | None = None) → Tensor

Select positional embeddings using global indices.

This method uses global indices to select specific subsets (called patches) of the positional embedding grid and/or the lead-time embedding grid. If no indices are provided, the entire embedding grid is returned. The positional embedding grid is returned if N_grid_channels > 0, while the lead-time embedding grid is returned if lead_time_mode == True. If both positional and lead-time embeddings are enabled, both are returned (concatenated). If neither is enabled, this function should not be called; doing so will raise a ValueError.

Parameters:
  • x (torch.Tensor) – Input tensor of shape \((P \times B, C, H_{in}, W_{in})\). Only used to determine batch size \(B\) and device.

  • global_index (Optional[torch.Tensor], default=None) – Tensor of shape \((P, 2, H_{in}, W_{in})\) that specifies the patches to extract from the positional embedding grid. \(P\) is the number of distinct patches in the input tensor x. The two channels contain the \(j\), \(i\) indices of the pixels to extract from the embedding grid.

  • lead_time_label (Optional[torch.Tensor], default=None) – Tensor of shape \((B,)\) that corresponds to the lead-time label for each batch element. Only used if lead_time_mode is True.

Returns:

Selected embeddings with shape \((P \times B, C_{PE} [+ C_{LT}], H_{in}, W_{in})\). \(C_{PE}\) is the number of embedding channels in the positional embedding grid, and \(C_{LT}\) is the number of embedding channels in the lead-time embedding grid. If lead_time_label is provided, the lead-time embedding channels are included. If global_index is None, \(P = 1\) is assumed, and the positional embedding grid is duplicated \(B\) times and returned with shape \((B, C_{PE} [+ C_{LT}], H, W)\).

Return type:

torch.Tensor

Example

>>> # Create global indices using patching utility:
>>> from physicsnemo.utils.patching import GridPatching2D
>>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8))
>>> global_index = patching.global_index(batch_size=3)
>>> print(global_index.shape)
torch.Size([4, 2, 8, 8])

Notes

  • This method is typically used in patch-based diffusion (or multi-diffusion), where a large input image is split into multiple patches. The batch dimension of the input tensor contains the patches. Patches are processed independently by the model, and the global_index parameter is used to select the grid of positional embeddings corresponding to each patch.

  • See the method global_index() of physicsnemo.utils.patching.BasePatching2D for generating the global_index parameter.
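Conceptually, index-based selection is a gather on the embedding grid. The sketch below illustrates it with plain PyTorch indexing; make_global_index and select_embeddings are hypothetical helpers, not PhysicsNeMo functions, and the (row, column) channel order and the patch-major layout of the \(P \times B\) batch dimension are assumptions. For a 16 × 16 image split into 8 × 8 patches, the index tensor it builds has the same (4, 2, 8, 8) shape as in the example above.

```python
import torch


def make_global_index(h: int, w: int, ph: int, pw: int) -> torch.Tensor:
    """Hypothetical stand-in for a grid patching utility.

    Returns non-overlapping patches of pixel indices with shape (P, 2, ph, pw);
    channel 0 holds row indices (j) and channel 1 column indices (i), assumed order.
    """
    rows, cols = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    grid = torch.stack([rows, cols], dim=0)  # (2, H, W)
    patches = grid.unfold(1, ph, ph).unfold(2, pw, pw)  # (2, nH, nW, ph, pw)
    return patches.permute(1, 2, 0, 3, 4).reshape(-1, 2, ph, pw)


def select_embeddings(pos_embd, global_index, batch_size):
    """Gather patches of a (C_PE, H, W) grid into shape (P * B, C_PE, ph, pw)."""
    j, i = global_index[:, 0], global_index[:, 1]  # each (P, ph, pw)
    selected = pos_embd[:, j, i]  # advanced indexing -> (C_PE, P, ph, pw)
    selected = selected.permute(1, 0, 2, 3)  # (P, C_PE, ph, pw)
    # Patch-major ordering (all batch elements of patch 0 first) is assumed.
    return selected.repeat_interleave(batch_size, dim=0)


pos_embd = torch.arange(16 * 16, dtype=torch.float32).reshape(1, 16, 16)
global_index = make_global_index(16, 16, 8, 8)
print(global_index.shape)  # torch.Size([4, 2, 8, 8])
emb = select_embeddings(pos_embd, global_index, batch_size=3)
print(emb.shape)  # torch.Size([12, 1, 8, 8])
```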

positional_embedding_selector(x: Tensor, embedding_selector: Callable[[Tensor], Tensor], lead_time_label=None) → Tensor

Select positional embeddings using a selector function.

Similar to positional_embedding_indexing(), but instead uses a selector function to select the embeddings.

Parameters:
  • x (torch.Tensor) – Input tensor of shape \((P \times B, C, H_{in}, W_{in})\). Only used to determine the dtype.

  • embedding_selector (Callable[[torch.Tensor], torch.Tensor]) – Function that takes as input the entire embedding grid of shape \((C_{PE}, H, W)\) (or \((B, C_{LT}, H, W)\) when lead_time_label is provided) and returns selected embeddings with shape \((P \times B, C_{PE}, H_{in}, W_{in})\) (or \((P \times B, C_{LT}, H_{in}, W_{in})\) when lead_time_label is provided). Each selected embedding should be the portion of the embedding grid that corresponds to the matching batch element in x. Typically, the selector should be based on the physicsnemo.utils.patching.BasePatching2D.apply() method to stay consistent with patch extraction.

  • lead_time_label (Optional[torch.Tensor], default=None) – Tensor of shape \((B,)\) that corresponds to the lead-time label for each batch element. Only used if lead_time_mode is True.

Returns:

A tensor of shape \((P \times B, C_{PE} [+ C_{LT}], H_{in}, W_{in})\). \(C_{PE}\) is the number of embedding channels in the positional embedding grid, and \(C_{LT}\) is the number of embedding channels in the lead-time embedding grid. If lead_time_label is provided, the lead-time embedding channels are included.

Return type:

torch.Tensor

Notes

  • This method is typically used in patch-based diffusion (or multi-diffusion), where a large input image is split into multiple patches. The batch dimension of the input tensor contains the patches. Patches are processed independently by the model, and the embedding_selector function is used to select the grid of positional embeddings corresponding to each patch.

  • See the method apply() from physicsnemo.utils.patching.BasePatching2D for generating the embedding_selector parameter, as well as the example below.

Example

>>> # Define a selector function with a patching utility:
>>> from physicsnemo.utils.patching import GridPatching2D
>>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8))
>>> B = 4
>>> def embedding_selector(emb):
...     return patching.apply(emb.expand(B, -1, -1, -1))
>>>
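The same selector can be sketched without PhysicsNeMo by replacing patching.apply with a hand-written patch extractor based on torch.Tensor.unfold. This is only an illustration: extract_patches is hypothetical, and its patch ordering is assumed rather than taken from BasePatching2D. The key design point is that the selector must cut the embedding grid in exactly the same way as the input image, so that each patch is paired with the embeddings of its own pixels.

```python
import torch


def extract_patches(x: torch.Tensor, ph: int, pw: int) -> torch.Tensor:
    """Split (B, C, H, W) into non-overlapping patches of shape (P * B, C, ph, pw).

    Hypothetical stand-in for BasePatching2D.apply(); the patch-major output
    ordering (batch index varying fastest) is an assumption.
    """
    _, c, _, _ = x.shape
    p = x.unfold(2, ph, ph).unfold(3, pw, pw)  # (B, C, nH, nW, ph, pw)
    p = p.permute(2, 3, 0, 1, 4, 5)  # (nH, nW, B, C, ph, pw)
    return p.reshape(-1, c, ph, pw)


B = 4


def embedding_selector(emb: torch.Tensor) -> torch.Tensor:
    # emb is the full grid (C_PE, H, W): broadcast it over the batch, then cut it
    # into patches exactly as the input image would be.
    return extract_patches(emb.expand(B, -1, -1, -1), 8, 8)


grid = torch.randn(4, 16, 16)  # (C_PE, H, W)
selected = embedding_selector(grid)
print(selected.shape)  # torch.Size([16, 4, 8, 8])
```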

SongUNetPosLtEmbd

class physicsnemo.models.diffusion.song_unet.SongUNetPosLtEmbd(*args, **kwargs)

Bases: SongUNetPosEmbd

This specialized architecture extends SongUNetPosEmbd with two additional capabilities:

  1. The model can be conditioned on lead-time labels. These labels encode physical time information, such as a forecasting horizon.

  2. Like the parent SongUNetPosEmbd, this model predicts regression targets, but it can also produce classification predictions. More precisely, some of the output channels are probability outputs, which are passed through a softmax activation function. This is useful for multi-task applications, where the objective combines regression and classification losses.

The mechanism to condition on lead-time labels is implemented by:

  • First generating a grid of learnable lead-time embeddings of shape \((\text{lead_time_steps}, C_{LT}, H, W)\). The spatial resolution of the lead-time embeddings is the same as the input/output image.

  • Then, given an input x, select the lead-time embeddings that correspond to the lead times associated with the samples in x.

  • Finally, concatenate the selected lead-time embeddings and the positional embeddings channel-wise to the input x and pass the result to the U-Net network.
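The steps above amount to an index lookup followed by a concatenation. A minimal PyTorch sketch, assuming the full (unpatched) grid, an (input, positional, lead-time) concatenation order, and variable names chosen for illustration only; it is not the model's forward method.

```python
import torch
from torch import nn

lead_time_steps, c_lt, h, w = 9, 4, 16, 16
# One learnable (C_LT, H, W) embedding per discrete lead-time step.
lt_embd = nn.Parameter(torch.randn(lead_time_steps, c_lt, h, w))

x = torch.ones(2, 2, h, w)  # (B, C_in, H, W)
pos_embd = torch.randn(4, h, w)  # (C_PE, H, W), e.g. a sinusoidal grid
lead_time_label = torch.tensor([3, 0])  # one lead-time index per sample

selected_lt = lt_embd[lead_time_label]  # (B, C_LT, H, W)
pos = pos_embd.unsqueeze(0).expand(x.shape[0], -1, -1, -1)  # (B, C_PE, H, W)
# Concatenation order (input, positional, lead-time) is assumed.
unet_input = torch.cat([x, pos, selected_lt], dim=1)
print(unet_input.shape)  # torch.Size([2, 10, 16, 16])
```

The resulting 2 + 4 + 4 = 10 channels are what in_channels must account for, as in the example at the end of this class.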

Most parameters are the same as in the parent SongUNetPosEmbd, with the exception of the ones listed below.

Parameters:
  • in_channels (int) –

    Number of channels \(C_{in} + C_{PE} + C_{LT}\) in the image passed to the U-Net.

    Important: in comparison to the base SongUNet, this parameter should also include the number of channels in the positional embedding grid \(C_{PE}\) and the number of channels in the lead-time embedding grid \(C_{LT}\).

  • lead_time_channels (int, optional, default=None) – Number of channels \(C_{LT}\) in the lead time embedding. These are learned embeddings that encode physical time information.

  • lead_time_steps (int, optional, default=9) – Number of discrete lead time steps to support. Each step gets its own learned embedding vector of shape \((C_{LT}, H, W)\).

  • prob_channels (List[int], optional, default=[]) – Indices of channels that are probability outputs (or classification predictions). In training mode, the model outputs logits for these probability channels; in eval mode, the model applies a softmax to output probabilities.
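The train/eval behavior of prob_channels can be illustrated with a small standalone function. This is a sketch under stated assumptions, not the model's forward pass: apply_prob_channels is hypothetical, and it assumes the softmax is computed jointly over the listed channels. Returning logits during training typically lets the loss use a numerically stable cross-entropy, while eval mode yields probabilities that sum to one at every pixel.

```python
import torch


def apply_prob_channels(out, prob_channels, training):
    """Return raw logits in training mode; softmax the probability channels in eval mode.

    The softmax is taken jointly across the listed channels (dim=1), which is an
    assumption of this sketch; the regression channels are left untouched.
    """
    if training or not prob_channels:
        return out
    out = out.clone()
    idx = torch.as_tensor(prob_channels, device=out.device)
    out[:, idx] = torch.softmax(out[:, idx], dim=1)
    return out


out = torch.randn(2, 5, 8, 8)  # 2 regression channels + 3 class channels
probs = apply_prob_channels(out, prob_channels=[2, 3, 4], training=False)
print(torch.allclose(probs[:, 2:].sum(dim=1), torch.ones(2, 8, 8)))  # True
```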

Forward:
  • x (torch.Tensor) – The input image of shape \((B, C_{in}, H_{in}, W_{in})\), where \(H_{in}\) and \(W_{in}\) are the spatial dimensions of the input image (does not need to be the full image).

  • noise_labels (torch.Tensor) – The noise labels of shape \((B,)\). Used for conditioning on the diffusion noise level.

  • class_labels (torch.Tensor) – The class labels of shape \((B, \text{label_dim})\). Used for conditioning on any vector-valued quantity. Can pass None when label_dim is 0.

  • global_index (torch.Tensor, optional, default=None) – The global indices of the positional embeddings to use. See positional_embedding_indexing() for details. If neither global_index nor embedding_selector is provided, the entire positional embedding grid is used.

  • embedding_selector (Callable, optional, default=None) – A function that selects the positional embeddings to use. See positional_embedding_selector() for details.

  • augment_labels (torch.Tensor, optional, default=None) – The augmentation labels of shape \((B, \text{augment_dim})\). Used for conditioning on any additional vector-valued quantity.

  • lead_time_label (torch.Tensor, optional, default=None) – The lead-time labels of shape \((B,)\). Used for selecting lead-time embeddings. It should contain the indices of the lead-time embeddings that correspond to the lead-time of each sample in the batch.

Outputs:

torch.Tensor – The output tensor of shape \((B, C_{out}, H_{in}, W_{in})\).

Notes

  • The lead-time embeddings differ from the diffusion time embeddings used in the SongUNet class, as they encode physical forecast time rather than the diffusion time-step.

Example

>>> import torch
>>> from physicsnemo.models.diffusion.song_unet import SongUNetPosLtEmbd
>>> from physicsnemo.utils.patching import GridPatching2D
>>>
>>> # Model initialization - in_channels must include original input channels (2),
>>> # positional embedding channels (N_grid_channels=4 by default) and
>>> # lead time embedding channels (4)
>>> model = SongUNetPosLtEmbd(
...     img_resolution=16, in_channels=2+4+4, out_channels=2,
...     lead_time_channels=4, lead_time_steps=9
... )
>>> noise_labels = torch.randn([1])
>>> class_labels = torch.randint(0, 1, (1, 1))
>>> # The input has only the original 2 channels - positional embeddings and
>>> # lead time embeddings are added automatically inside the forward method
>>> input_image = torch.ones([1, 2, 16, 16])
>>> lead_time_label = torch.tensor([3])
>>> output_image = model(
...     input_image, noise_labels, class_labels,
...     lead_time_label=lead_time_label
... )
>>> output_image.shape
torch.Size([1, 2, 16, 16])
>>>
>>> # Using global_index to select all the positional and lead time embeddings
>>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(16, 16))
>>> global_index = patching.global_index(batch_size=1)
>>> output_image = model(
...     input_image, noise_labels, class_labels,
...     lead_time_label=lead_time_label,
...     global_index=global_index
... )
>>> output_image.shape
torch.Size([1, 2, 16, 16])

UNet

class physicsnemo.models.diffusion.unet.UNet(*args, **kwargs)

Bases: Module

This interface provides a U-Net wrapper for the CorrDiff deterministic regression model (and other deterministic downscaling models). It supports the following architectures: SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd, and DhariwalUNet.

It uses the same architecture as a conditional diffusion model, repurposed for deterministic regression by concatenating a conditioning image to a zero-filled latent state and setting the noise level and the class labels to zero.
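The idea of reusing a diffusion backbone for regression can be sketched as follows. TinyBackbone and regression_forward are hypothetical names used only for illustration; the sketch assumes the latent state has as many channels as the output and that the backbone is called as (x, noise_labels, class_labels). It is not the UNet wrapper's actual forward method.

```python
import torch
from torch import nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for a backbone called as (x, noise_labels, class_labels)."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x, noise_labels, class_labels):
        return self.conv(x)  # the (zero) labels are ignored in this toy model


def regression_forward(backbone, x, img_lr, label_dim=0):
    """Use a conditional diffusion backbone as a deterministic regressor (sketch)."""
    b = x.shape[0]
    inp = torch.cat([x, img_lr], dim=1)  # zero-filled latent + conditioning image
    noise_labels = torch.zeros(b, device=x.device)  # noise level fixed to zero
    class_labels = torch.zeros(b, label_dim, device=x.device)  # zero class labels
    return backbone(inp, noise_labels, class_labels)


c_out, c_lr = 2, 3
backbone = TinyBackbone(in_channels=c_out + c_lr, out_channels=c_out)
img_lr = torch.randn(4, c_lr, 16, 16)
x = torch.zeros(4, c_out, 16, 16)  # zero-filled latent state
y = regression_forward(backbone, x, img_lr)
print(y.shape)  # torch.Size([4, 2, 16, 16])
```

Because the noise level is always zero, the backbone sees a fixed conditioning signal and behaves as a plain image-to-image regressor.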

Parameters:
  • img_resolution (Union[int, Tuple[int, int]]) – The resolution of the input/output image. If a single int is provided, then the image is assumed to be square.

  • img_in_channels (int) – Number of channels in the input image.

  • img_out_channels (int) – Number of channels in the output image.

  • use_fp16 (bool, optional, default=False) – Execute the underlying model at FP16 precision.

  • model_type (Literal['SongUNet', 'SongUNetPosEmbd', 'SongUNetPosLtEmbd', 'DhariwalUNet'], optional, default='SongUNetPosEmbd') – Class name of the underlying architecture. Must be one of the following: ‘SongUNet’, ‘SongUNetPosEmbd’, ‘SongUNetPosLtEmbd’, ‘DhariwalUNet’.

  • **model_kwargs (dict) – Keyword arguments passed to the underlying architecture __init__ method. Please refer to the documentation of these classes for details on how to call and use these models directly.

Forward:
  • x (torch.Tensor) – The input tensor, typically zero-filled, of shape \((B, C_{in}, H_{in}, W_{in})\).

  • img_lr (torch.Tensor) – Conditioning image of shape \((B, C_{lr}, H_{in}, W_{in})\).

  • **model_kwargs (dict) – Additional keyword arguments to pass to the underlying architecture forward method.

Outputs:

torch.Tensor – Output tensor of shape \((B, C_{out}, H_{in}, W_{in})\) (same spatial dimensions as the input).

property amp_mode

Set to True when using automatic mixed precision.

property profile_mode

Set to True to enable profiling of the wrapped model.

round_sigma(sigma: float | List | Tensor) → Tensor

Convert a given sigma value(s) to a tensor representation.

Parameters:

sigma (Union[float, List, torch.Tensor]) – The sigma value(s) to convert.

Returns:

The tensor representation of the provided sigma value(s).

Return type:

torch.Tensor

property use_fp16

Whether the model uses float16 precision.

Returns:

True if the model is in float16 mode, False otherwise.

Return type:

bool

Type:

bool