Diffusion UNets

This page documents the UNet family of backbone architectures for diffusion models. These are built-in architectures specialized for diffusion on structured 2D domains (images, spatial fields). For other domains or architectures, see the DiT backbone, or use any model from the PhysicsNeMo model zoo or external libraries as described in Model Backbones.

All models on this page are based on the Module class.

Important

These UNet backbones do not implement the DiffusionModel protocol directly. Their forward signatures differ (e.g., forward(x, noise_labels, class_labels, augment_labels)). To use them with preconditioners, losses, and samplers, wrap them with a thin adapter, as shown in the adapter example below.

SongUNet: The Primary Backbone

The SongUNet is the primary UNet backbone. It is a highly configurable multi-resolution architecture that supports both conditional and unconditional modeling.

Its latent state \(\mathbf{x}\) is a tensor of shape \((B, C, H, W)\), where \(B\) is the batch size, \(C\) is the number of channels, and \(H\) and \(W\) are the height and width. The model is always conditional on the noise level, and can additionally be conditioned on vector-valued class labels and/or images.

The model is organized into levels, whose number is determined by len(channel_mult). Each level operates at half the resolution of the previous one and is composed of a sequence of UNet blocks, which optionally contain self-attention layers, as controlled by the attn_resolutions parameter.
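The level structure can be made concrete with a few lines of plain Python. The helper below, level_layout, is hypothetical (it is not part of PhysicsNeMo): it assumes a square input, halves the resolution at every level after the first (flooring odd sizes), and marks a level for self-attention when its resolution appears exactly in attn_resolutions, following the parameter description in the API reference.

```python
# Hypothetical helper (not part of PhysicsNeMo): lists the feature-map
# resolution of each level and whether self-attention is built at that level.
def level_layout(img_resolution, channel_mult, attn_resolutions):
    layout = []
    res = img_resolution
    for level in range(len(channel_mult)):
        if level > 0:
            res //= 2  # each level halves the resolution (odd sizes are floored)
        # Attention requires an exact match between resolution and attn_resolutions
        layout.append((level, res, res in attn_resolutions))
    return layout

print(level_layout(40, [1, 2, 3], [20, 10]))
# [(0, 40, False), (1, 20, True), (2, 10, True)]
```

With the configuration used in the next example, attention is built at levels 1 and 2 only.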

Here we create a three-level SongUNet that applies self-attention at levels 1 and 2. The model is unconditional (that is, it is not conditioned on any class labels or images, but is still conditional on the noise level, as is standard practice for diffusion models).

import torch
from physicsnemo.models.diffusion_unets import SongUNet

B, C_x, res = 3, 6, 40   # Batch size, channels, and resolution of the latent state

model = SongUNet(
    img_resolution=res,
    in_channels=C_x,
    out_channels=C_x,  # No conditioning on image: number of output channels is the same as the input channels
    label_dim=0,  # No conditioning on vector-valued class labels
    augment_dim=0,
    model_channels=64,
    channel_mult=[1, 2, 3],  # 3-level UNet with 64, 128, and 192 channels at each level, respectively
    num_blocks=4,  # 4 UNet blocks at each level
    attn_resolutions=[20, 10],  # Attention is applied at level 1 (resolution 20x20) and level 2 (resolution 10x10)
)

x = torch.randn(B, C_x, res, res)  # Latent state
noise_labels = torch.randn(B)  # Noise level for each sample

# The feature map resolution is 40 at level 0, 20 at level 1, and 10 at level 2
out = model(x, noise_labels, None)
print(out.shape)  # Shape: (B, C_x, res, res), same as the latent state

# The same model can be used on images of different resolution
# Note: the attention is still applied at levels 1 and 2
x_32 = torch.randn(B, C_x, 32, 32)  # Lower resolution latent state
out_32 = model(x_32, noise_labels, None)  # None means no conditioning on class labels
print(out_32.shape)  # Shape: (B, C_x, 32, 32), same as the latent state

The unconditional SongUNet can be extended to be conditional on class labels and/or images. Conditioning on images is performed by channel-wise concatenation of the image to the latent state \(\mathbf{x}\) before passing it to the model. The model does not perform conditioning on images internally, and this operation is left to the user. For conditioning on class labels (or any vector-valued quantity whose dimension is label_dim), the model internally generates embeddings for the class labels and adds them to intermediate activations within the UNet blocks. Here we extend the previous example to be conditional on a 16-dimensional vector-valued class label and a 3-channel image.

import torch
from physicsnemo.models.diffusion_unets import SongUNet

B, C_x, res = 3, 10, 40
C_cond = 3

model = SongUNet(
    img_resolution=res,
    in_channels=C_x + C_cond,  # Conditioning on an image with C_cond channels
    out_channels=C_x,  # Output channels: only those of the latent state
    label_dim=16,  # Conditioning on 16-dimensional vector-valued class labels
    augment_dim=0,
    model_channels=64,
    channel_mult=[1, 2, 2],
    num_blocks=4,
    attn_resolutions=[20, 10],
)

x = torch.randn(B, C_x, res, res)  # Latent state
cond = torch.randn(B, C_cond, res, res)  # Conditioning image
x_cond = torch.cat([x, cond], dim=1)  # Channel-wise concatenation of the conditioning image before passing to the model
noise_labels = torch.randn(B)
class_labels = torch.randn(B, 16)  # Conditioning on vector-valued class labels

out = model(x_cond, noise_labels, class_labels)
print(out.shape)  # Shape: (B, C_x, res, res), same as the latent state
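Class-label conditioning is additive. The pure-Python sketch below is a simplified illustration under stated assumptions, not the SongUNet implementation: a single linear map (the hypothetical embed_label) turns the label vector into one bias per feature channel, and add_embedding adds that bias to every pixel of the corresponding channel. We assume that in the actual model the label embedding is combined with the noise-level embedding and passed through further learned layers before each U-Net block projects it onto its own channels.

```python
# Simplified sketch of additive label conditioning (hypothetical helpers).
def embed_label(label, weight):
    # weight has one row per feature channel; each row has len(label) entries
    return [sum(w * l for w, l in zip(row, label)) for row in weight]

def add_embedding(feature_map, bias):
    # feature_map: list of channels, each an H x W list of lists.
    # The per-channel bias is broadcast over all pixels of that channel.
    return [[[v + b for v in row] for row in channel]
            for channel, b in zip(feature_map, bias)]

label = [1.0, 0.5]                      # 2-dimensional "class label"
weight = [[1.0, 0.0], [0.0, 2.0]]       # maps the label onto 2 feature channels
features = [[[0.0, 0.0], [0.0, 0.0]],   # channel 0, 2x2
            [[1.0, 1.0], [1.0, 1.0]]]   # channel 1, 2x2
bias = embed_label(label, weight)       # [1.0, 1.0]
out = add_embedding(features, bias)
print(out)  # [[[1.0, 1.0], [1.0, 1.0]], [[2.0, 2.0], [2.0, 2.0]]]
```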

Using UNet Backbones with the Diffusion Framework

Because these UNet backbones have their own forward signature, they need to be adapted to the DiffusionModel protocol before they can be used with preconditioners, losses, and samplers.

The simplest approach is a thin adapter class:

import torch
from physicsnemo.core import Module
from physicsnemo.models.diffusion_unets import SongUNet
from physicsnemo.diffusion.preconditioners import EDMPreconditioner

class SongUNetAdapter(Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.net = SongUNet(**kwargs)

    def forward(self, x, t, condition=None):
        # Map the DiffusionModel call convention (x, t, condition) onto SongUNet's signature
        return self.net(x, noise_labels=t, class_labels=condition)

backbone = SongUNetAdapter(
    img_resolution=64, in_channels=3, out_channels=3,
    model_channels=64, channel_mult=[1, 2, 2], num_blocks=2,
)

# The adapter satisfies DiffusionModel, so it can be used with preconditioners
precond = EDMPreconditioner(backbone, sigma_data=0.5)

DhariwalUNet

The DhariwalUNet is an alternative UNet backbone that can be used interchangeably with SongUNet: it has the same forward signature (x, noise_labels, class_labels, augment_labels), so it can be wrapped with the same adapter pattern described above.
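Because DhariwalUNet shares the forward signature of SongUNet, an alternative to the SongUNetAdapter above is an adapter that receives an already constructed backbone, so that one class can wrap either network. The sketch below is hypothetical and uses a plain-Python stand-in (fake_unet) so that it runs without PhysicsNeMo; in real code the class would subclass Module and implement forward, as in the example above, and net would be a SongUNet or DhariwalUNet instance.

```python
# Hypothetical adapter sketch: maps the (x, t, condition) call convention onto a
# backbone that expects (x, noise_labels, class_labels, augment_labels).
class UNetAdapter:
    def __init__(self, net):
        self.net = net  # any callable with the UNet forward signature

    def __call__(self, x, t, condition=None):
        return self.net(x, noise_labels=t, class_labels=condition,
                        augment_labels=None)

# Stand-in for SongUNet / DhariwalUNet that simply echoes its arguments.
def fake_unet(x, noise_labels, class_labels, augment_labels=None):
    return {"x": x, "noise_labels": noise_labels,
            "class_labels": class_labels, "augment_labels": augment_labels}

adapter = UNetAdapter(fake_unet)
print(adapter("x0", 0.3))
# {'x': 'x0', 'noise_labels': 0.3, 'class_labels': None, 'augment_labels': None}
```

Taking the backbone as an argument rather than building it inside the adapter keeps the adapter independent of the backbone's constructor arguments.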

Lead-Time Aware Models

In many diffusion applications, the latent state is time-dependent, and the diffusion process should account for the time-dependence of the latent state. For instance, a forecast model could provide latent states \(\mathbf{x}(T)\) (current time), \(\mathbf{x}(T + \Delta t)\) (one time step forward), …, up to \(\mathbf{x}(T + K \Delta t)\) (K time steps forward). Such prediction horizons are called lead-times (a term adopted from the weather and climate forecasting community) and we want to apply diffusion to each of these latent states while accounting for their associated lead-time information.
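Lead-times are typically stored as integer multiples of the time step \(\Delta t\). The hypothetical helper below (not a PhysicsNeMo function) converts absolute valid times into the integer labels used later in this section and rejects lead-times outside the supported range [0, lead_time_steps - 1]. Rounding to the nearest integer is an assumption made to absorb floating-point error.

```python
# Hypothetical helper: converts absolute valid times into integer lead-time
# labels relative to the forecast initialization time T, with time step dt.
def lead_time_labels(times, T, dt, lead_time_steps):
    labels = []
    for t in times:
        k = round((t - T) / dt)  # number of time steps after T
        if not 0 <= k < lead_time_steps:
            raise ValueError(f"lead time {k} outside [0, {lead_time_steps - 1}]")
        labels.append(k)
    return labels

T, dt = 0.0, 6.0  # e.g. hours
print(lead_time_labels([T + 2 * dt, T, T + dt], T, dt, lead_time_steps=3))  # [2, 0, 1]
```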

PhysicsNeMo provides a specialized architecture, SongUNetPosLtEmbd, that implements lead-time aware models. It extends SongUNet with learnable positional embeddings and lead-time embeddings. In its forward pass, the model uses the lead_time_label parameter to retrieve the associated lead-time embeddings, and conditions on them by concatenating them channel-wise to the latent state before the first UNet block.

This is the recommended architecture for lead-time aware diffusion problems such as weather forecasting.
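The lookup-and-concatenate step can be sketched without a tensor library by storing each sample as a nested list indexed [channel][row][col], so that channel-wise concatenation becomes plain list concatenation. This is an illustration under assumptions, not the SongUNetPosLtEmbd code: make_grid is a hypothetical stand-in for the learnable grid of shape (lead_time_steps, C_LT, H, W), filled here with the lead-time index so the lookup is easy to inspect.

```python
# Tensors are nested lists indexed [channel][row][col]; a batch is a list of them.
def make_grid(lead_time_steps, c_lt, h, w):
    # Hypothetical stand-in for the learnable grid (lead_time_steps, C_LT, H, W):
    # every entry of the embedding for lead-time k is set to k.
    return [[[[float(k)] * w for _ in range(h)] for _ in range(c_lt)]
            for k in range(lead_time_steps)]

def condition_on_lead_time(batch, labels, grid):
    # Look up each sample's lead-time embedding and append its channels after
    # the sample's own channels (channel-wise concatenation).
    return [sample + grid[k] for sample, k in zip(batch, labels)]

grid = make_grid(lead_time_steps=3, c_lt=2, h=2, w=2)
batch = [[[[0.5] * 2 for _ in range(2)]] for _ in range(3)]  # 3 samples, 1 channel, 2x2
out = condition_on_lead_time(batch, [2, 0, 1], grid)
print(len(out[0]), out[0][1][0][0], out[1][1][0][0])  # 3 2.0 0.0
```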

Here we show an example with lead-time information. We assume that we have a batch of three latent states at times \(T + 2 \Delta t\) (two time intervals forward), \(T + 0 \Delta t\) (current time), and \(T + \Delta t\) (one time interval forward). The associated lead-time labels are [2, 0, 1]. In addition, the SongUNetPosLtEmbd model has the ability to predict probabilities for some channels of the latent state, specified by the prob_channels parameter. Here we assume that channels one and three are probability (that is, classification) outputs, while other channels are regression outputs.

import torch
from physicsnemo.models.diffusion_unets import SongUNetPosLtEmbd

B, C_x, res = 3, 10, 40
C_cond = 3
C_PE = 8
lead_time_steps = 3  # Maximum supported lead-time is 2 * dt
C_LT = 6  # Number of channels in each lead-time embedding

# Create a SongUNetPosLtEmbd with a lead-time embedding grid of shape
# (lead_time_steps, C_LT, res, res)
model = SongUNetPosLtEmbd(
    img_resolution=res,
    in_channels=C_x + C_cond + C_PE + C_LT,  # in_channels must also count the positional (C_PE) and lead-time (C_LT) embedding channels
    out_channels=C_x,
    label_dim=16,
    augment_dim=0,
    model_channels=64,
    channel_mult=[1, 2, 2],
    num_blocks=4,
    attn_resolutions=[10, 5],
    gridtype="learnable",
    N_grid_channels=C_PE,
    lead_time_channels=C_LT,
    lead_time_steps=lead_time_steps,  # Maximum supported lead-time horizon
    prob_channels=[1, 3],  # Channels 1 and 3 of the latent state are probability outputs
)

x = torch.randn(B, C_x, res, res)  # Latent state at times T+2*dt, T+0*dt, and T + 1*dt
cond = torch.randn(B, C_cond, res, res)
x_cond = torch.cat([x, cond], dim=1)
noise_labels = torch.randn(B)
class_labels = torch.randn(B, 16)
lead_time_label = torch.tensor([2, 0, 1])  # Lead-time labels for each sample

# The model internally extracts the lead-time embeddings corresponding to the
# lead-time labels 2, 0, 1 and concatenates them to the input x_cond before the first
# UNet block. In training mode, the model outputs logits for channels 1 and 3.
out = model(x_cond, noise_labels, class_labels, lead_time_label=lead_time_label)
print(out.shape)  # Shape: (B, C_x, res, res), same as the latent state

# In eval mode, the model outputs probabilities for channels 1 and 3
model.eval()
out = model(x_cond, noise_labels, class_labels, lead_time_label=lead_time_label)
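The difference between training and eval mode for prob_channels can be illustrated at a single pixel. The sketch below assumes that, in eval mode, the logits of the probability channels are converted with a softmax taken across those channels, while regression channels pass through unchanged. apply_prob_channels is hypothetical, and the library's exact transform may differ in its details.

```python
import math

# Hypothetical sketch: in eval mode, convert the logits of the probability
# channels into probabilities with a softmax across those channels (at one
# pixel), leaving the regression channels untouched.
def apply_prob_channels(pixel, prob_channels, training):
    # pixel: list of per-channel values at one spatial location
    if training:
        return list(pixel)  # training mode: logits are returned as-is
    logits = [pixel[c] for c in prob_channels]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    out = list(pixel)
    for c, e in zip(prob_channels, exps):
        out[c] = e / total
    return out

pixel = [0.7, 2.0, -1.3, 2.0]  # 4 channels at one pixel
print(apply_prob_channels(pixel, [1, 3], training=False))  # [0.7, 0.5, -1.3, 0.5]
```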

Note

  • The SongUNetPosLtEmbd is not an autoregressive model that performs a rollout to produce future predictions. From the point of view of the SongUNetPosLtEmbd, the lead-time information is fixed. The lead-time dependent latent state \(\mathbf{x}\) may, however, be produced by such an autoregressive or rollout model.

  • The SongUNetPosLtEmbd model cannot be scaled to very long lead-time horizons (controlled by the lead_time_steps parameter). This is because the lead-time embeddings are represented by a grid of learnable parameters of shape (lead_time_steps, C_LT, res, res), whose size grows linearly with lead_time_steps. For very long lead-time horizons, this grid of embeddings becomes prohibitively large.

  • In a given input batch x, the associated lead-times are not necessarily consecutive or ordered. They do not even need to originate from the same forecast trajectory. For example, the lead-time labels might be [0, 1, 2] instead of [2, 0, 1], or even [2, 2, 1].
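The size of the lead-time grid is the product of its four dimensions. The arithmetic below (a hypothetical helper, float32 parameters assumed) compares the grid from the example above with a hypothetical configuration of 100 lead-times on a 721 × 1440 grid, which corresponds to a 0.25° global latitude-longitude resolution.

```python
# Number of parameters in a learnable lead-time embedding grid of shape
# (lead_time_steps, C_LT, H, W), and its memory footprint (float32 by default).
def lead_time_grid_size(lead_time_steps, c_lt, h, w, bytes_per_param=4):
    n_params = lead_time_steps * c_lt * h * w
    return n_params, n_params * bytes_per_param

print(lead_time_grid_size(3, 6, 40, 40))       # (28800, 115200)
print(lead_time_grid_size(100, 6, 721, 1440))  # (622944000, 2491776000)
```

The second configuration needs about 623 million parameters (roughly 2.5 GB in float32) for the grid alone, before gradients and optimizer state are counted.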

Positional Embeddings (SongUNetPosEmbd)

The SongUNetPosEmbd extends SongUNet with learnable positional embeddings. It was originally designed for multi-diffusion (patch-based diffusion), where each patch needs to be informed of its position in the global domain.

Note

SongUNetPosEmbd bakes multi-diffusion logic directly into the architecture. For new projects, the recommended approach is to use the multi-diffusion APIs, which decouple patching from the backbone and can be combined with any architecture (UNet, DiT, or custom). SongUNetPosEmbd remains available for backward compatibility and for use cases where integrated positional embeddings are specifically desired.

API Reference

SongUNet

class physicsnemo.models.diffusion_unets.SongUNet(*args, **kwargs)

Bases: Module

This architecture is a diffusion backbone for 2D image generation. It is a reimplementation of the DDPM++ and NCSN++ architectures, which are U-Net variants with optional self-attention, embeddings, and encoder-decoder components.

This model supports conditional and unconditional setups, as well as several options for various internal architectural choices such as encoder and decoder type, embedding type, etc., making it flexible and adaptable to different tasks and configurations.

This architecture supports conditioning on the noise level (called noise labels), as well as on additional vector-valued labels (called class labels) and (optional) vector-valued augmentation labels. The conditioning mechanism relies on adding the conditioning embeddings inside the U-Net blocks of both the encoder and the decoder. To condition on images, the simplest mechanism is to concatenate the image to the input before passing it to the SongUNet.

The model first applies a mapping operation to generate embeddings for all the conditioning inputs (the noise level, the class labels, and the optional augmentation labels).

Then, at each level in the U-Net encoder, a sequence of blocks is applied:

  • At every level except the first (which instead starts with an input convolution), a downsampling block first halves the feature map resolution (odd resolutions are floored). This block does not change the number of channels.

  • A sequence of num_blocks U-Net blocks is then applied. These blocks do not change the feature map resolution; the first of them sets the number of channels at level i to channel_mult[i] * model_channels, and the others keep it unchanged. If required, the U-Net blocks also apply self-attention at the specified resolutions.

  • At the end of the level, the feature map is cached to be used in a skip connection in the decoder.

The decoder is a mirror of the encoder, with the same number of levels and the same number of blocks per level. It multiplies the feature map resolution by a factor of 2 at each level.
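The mapping step described above turns each scalar noise label into a vector. For embedding_type="positional", this is a sinusoidal embedding in the style of transformer position encodings. The sketch below is an illustration under assumptions (cosines followed by sines, geometric frequencies with max_positions=10000); the exact frequency schedule used by SongUNet may differ, and positional_embedding is a hypothetical name.

```python
import math

# Sketch of a transformer-style sinusoidal embedding of a scalar noise label.
# Assumed layout: num_channels // 2 cosines followed by the same number of sines.
def positional_embedding(t, num_channels, max_positions=10000):
    half = num_channels // 2
    # Geometric sequence of frequencies from 1 down to about 1 / max_positions
    freqs = [(1.0 / max_positions) ** (k / half) for k in range(half)]
    return [math.cos(t * f) for f in freqs] + [math.sin(t * f) for f in freqs]

emb = positional_embedding(0.0, 8)
print(len(emb), emb[:4], emb[4:])  # 8 [1.0, 1.0, 1.0, 1.0] [0.0, 0.0, 0.0, 0.0]
```

Despite the name, such an embedding encodes the diffusion noise level, not a spatial position.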

Parameters:
  • img_resolution (Union[List[int, int], int]) –

    The resolution of the input/output image. Can be a single int \(H\) for square images or a list \([H, W]\) for rectangular images.

    Note: This parameter is only used as a convenience to build the network. In practice, the model can still be used with images of different resolutions. The only exception to this rule is when additive_pos_embed is True, in which case the resolution of the latent state \(\mathbf{x}\) must match img_resolution.

  • in_channels (int) – Number of channels \(C_{in}\) in the input image. May include channels from both the latent state and additional channels when conditioning on images. For an unconditional model, this should be equal to out_channels.

  • out_channels (int) – Number of channels \(C_{out}\) in the output image. Should be equal to the number of channels \(C_{\mathbf{x}}\) in the latent state.

  • label_dim (int, optional, default=0) – Dimension of the vector-valued class_labels conditioning; 0 indicates no conditioning on class labels.

  • augment_dim (int, optional, default=0) – Dimension of the vector-valued augment_labels conditioning; 0 means no conditioning on augmentation labels.

  • model_channels (int, optional, default=128) – Base multiplier for the number of channels across the entire network.

  • channel_mult (List[int], optional, default=[1, 2, 2, 2]) – Multipliers for the number of channels at every level in the encoder and decoder. The length of channel_mult determines the number of levels in the U-Net. At level i, the number of channels in the feature map is channel_mult[i] * model_channels.

  • channel_mult_emb (int, optional, default=4) – Multiplier for the number of channels in the embedding vector. The embedding vector has model_channels * channel_mult_emb channels.

  • num_blocks (int, optional, default=4) – Number of U-Net blocks at each level.

  • attn_resolutions (List[int], optional, default=[16]) – Resolutions of the levels at which self-attention layers are applied. Note that the feature map resolution must exactly match a value provided in attn_resolutions for the self-attention layers to be applied.

  • dropout (float, optional, default=0.10) – Dropout probability applied to intermediate activations within the U-Net blocks.

  • label_dropout (float, optional, default=0.0) – Dropout probability applied to the class_labels. Typically used for classifier-free guidance.

  • embedding_type (Literal["fourier", "positional", "zero"], optional, default="positional") – Diffusion timestep embedding type: ‘positional’ for DDPM++, ‘fourier’ for NCSN++, ‘zero’ for none.

  • channel_mult_noise (int, optional, default=1) – Multiplier for the number of channels in the noise level embedding. The noise level embedding vector has model_channels * channel_mult_noise channels.

  • encoder_type (Literal["standard", "skip", "residual"], optional, default="standard") – Encoder architecture: ‘standard’ for DDPM++, ‘residual’ for NCSN++, ‘skip’ for skip connections.

  • decoder_type (Literal["standard", "skip"], optional, default="standard") – Decoder architecture: ‘standard’ or ‘skip’ for skip connections.

  • resample_filter (List[int], optional, default=[1, 1]) – Resampling filter coefficients applied in the U-Net blocks convolutions: [1,1] for DDPM++, [1,3,3,1] for NCSN++.

  • checkpoint_level (int, optional, default=0) – Number of levels that should use gradient checkpointing. Only levels at which the feature map resolution is large enough are checkpointed (0 disables checkpointing; higher values mean more layers are checkpointed). Higher values reduce memory usage at the cost of extra computation, since checkpointed activations are recomputed during the backward pass.

  • additive_pos_embed (bool, optional, default=False) –

    If True, adds a learnable positional embedding after the first convolution layer. Used in the StormCast model.

    Note: These positional embeddings encode spatial position information of the image pixels, unlike the embedding_type parameter, which encodes temporal information about the diffusion process. In that sense, they are a simpler version of the positional embedding used in SongUNetPosEmbd.

  • bottleneck_attention (bool, optional, default=True) – If True, applies self-attention at the bottleneck (innermost decoder block). Set to False to disable bottleneck attention for faster inference.

  • use_apex_gn (bool, optional, default=False) – Whether to use Apex GroupNorm with the NHWC (channels-last) memory layout. Requires Apex to be installed, and must be set to False when running on CPU.

  • act (str, optional, default=None) – The activation function to use when fusing activation with GroupNorm. Required when use_apex_gn is True.

  • profile_mode (bool, optional, default=False) – A flag indicating whether to enable all NVTX annotations during profiling.

  • amp_mode (bool, optional, default=False) – A flag indicating whether mixed-precision (AMP) training is enabled.
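The channel-related arguments above all scale with model_channels. The hypothetical helper below (not part of PhysicsNeMo) evaluates those documented formulas, which can be handy for sanity-checking a configuration before building the model.

```python
# Hypothetical helper: channel widths implied by the constructor arguments
# (the defaults match the documented defaults of SongUNet).
def channel_widths(model_channels=128, channel_mult=(1, 2, 2, 2),
                   channel_mult_emb=4, channel_mult_noise=1):
    return {
        "levels": [m * model_channels for m in channel_mult],   # per-level feature channels
        "embedding": model_channels * channel_mult_emb,         # conditioning embedding vector
        "noise_embedding": model_channels * channel_mult_noise, # noise-level embedding vector
    }

print(channel_widths())
# {'levels': [128, 256, 256, 256], 'embedding': 512, 'noise_embedding': 128}
print(channel_widths(model_channels=64, channel_mult=[1, 2, 3])["levels"])  # [64, 128, 192]
```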

Forward:
  • x (torch.Tensor) – The input image of shape \((B, C_{in}, H_{in}, W_{in})\). In general x is the channel-wise concatenation of the latent state \(\mathbf{x}\) and additional images used for conditioning. For an unconditional model, x is simply the latent state \(\mathbf{x}\).

    Note: \(H_{in}\) and \(W_{in}\) do not need to match \(H\) and \(W\) defined in img_resolution, except when additive_pos_embed is True. In that case, the resolution of x must match img_resolution.

  • noise_labels (torch.Tensor) – The noise labels of shape \((B,)\). Used for conditioning on the diffusion noise level.

  • class_labels (torch.Tensor) – The class labels of shape \((B, \text{label\_dim})\). Used for conditioning on any vector-valued quantity. Can pass None when label_dim is 0.

  • augment_labels (torch.Tensor, optional, default=None) – The augmentation labels of shape \((B, \text{augment\_dim})\). Used for conditioning on any additional vector-valued quantity. Can pass None when augment_dim is 0.

Outputs:

torch.Tensor – The denoised latent state of shape \((B, C_{out}, H_{in}, W_{in})\).

Important

  • The terms noise levels (or noise labels) are used to refer to the diffusion time-step, as these are conceptually equivalent.

  • The terms labels and classes originate from the original paper and EDM repository, where this architecture was used for class-conditional image generation. While these terms suggest class-based conditioning, the architecture can actually be conditioned on any vector-valued conditioning.

  • The term positional embedding used in the embedding_type parameter also comes from the original paper and EDM repository. Here, positional refers to the diffusion time-step, similar to how position is used in transformer architectures. Despite the name, these embeddings encode temporal information about the diffusion process rather than spatial position information.

  • Limitations on input image resolution: for a model that has \(N\) levels, the latent state \(\mathbf{x}\) must have a resolution that is a multiple of \(2^{N-1}\) in each dimension. This is due to a limitation in the decoder, which does not support shape mismatches in the skip connections from the encoder to the decoder. For images that do not meet this requirement, it is recommended to interpolate the data onto a grid of the required resolution beforehand.
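The resolution constraint is easy to check ahead of time. The helpers below are hypothetical: is_valid_resolution tests the \(2^{N-1}\) divisibility rule, and next_valid_size returns the smallest size that satisfies it, a natural target for the interpolation recommended above.

```python
# Hypothetical helpers for the resolution constraint: with N levels, each
# spatial dimension must be a multiple of 2**(N - 1).
def is_valid_resolution(h, w, num_levels):
    factor = 2 ** (num_levels - 1)
    return h % factor == 0 and w % factor == 0

def next_valid_size(n, num_levels):
    # Smallest size >= n that satisfies the constraint (ceiling division)
    factor = 2 ** (num_levels - 1)
    return -(-n // factor) * factor

print(is_valid_resolution(40, 32, num_levels=3))  # True
print(is_valid_resolution(30, 30, num_levels=3))  # False: 30 -> 15 -> 7, and 7 * 2 = 14 != 15
print(next_valid_size(30, num_levels=3))          # 32
```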

Example

>>> import torch
>>> from physicsnemo.models.diffusion_unets import SongUNet
>>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2)
>>> noise_labels = torch.randn([1])
>>> class_labels = torch.randint(0, 1, (1, 1))
>>> input_image = torch.ones([1, 2, 16, 16])
>>> output_image = model(input_image, noise_labels, class_labels)
>>> output_image.shape
torch.Size([1, 2, 16, 16])

property amp_mode

Should be set to True to enable automatic mixed precision.

property profile_mode

Should be set to True to enable profiling.

DhariwalUNet

class physicsnemo.models.diffusion_unets.DhariwalUNet(*args, **kwargs)

Bases: Module

This architecture is a diffusion backbone for 2D image generation. It reimplements the ADM architecture, a U-Net variant, with optional self-attention.

It is highly similar to the U-Net backbone defined in SongUNet, and only differs in a few aspects:

  • The embedding conditioning mechanism relies on adaptive scaling of the group normalization layers within the U-Net blocks.

  • Parameter initialization follows Kaiming uniform initialization.
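The adaptive scaling mechanism can be illustrated on a single group of values. The sketch below is simplified and hypothetical: it normalizes the values, then modulates them with (1 + scale) and shift. In the real model, we assume these are predicted from the embedding vector by a learned layer, one pair per channel.

```python
import math

# Simplified sketch: normalization of one group of values (no learned affine).
def normalize(values, eps=1e-5):
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]

def adaptive_group_norm(values, scale, shift):
    # Normalize, then modulate with (1 + scale) and shift from the embedding
    return [n * (1.0 + scale) + shift for n in normalize(values)]

x = [1.0, 3.0]
out = adaptive_group_norm(x, scale=1.0, shift=0.5)
print([round(v, 3) for v in out])  # [-1.5, 2.5]
```

With scale = 0 and shift = 0 the modulation is the identity, so an embedding projection initialized near zero leaves the plain normalization unchanged.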

Parameters:
  • img_resolution (int) –

    The resolution \(H = W\) of the input/output image. Assumes square images.

    Note: This parameter is only used as a convenience to build the network. In practice, the model can still be used with images of different resolutions.

  • in_channels (int) – Number of channels \(C_{in}\) in the input image. May include channels from both the latent state \(\mathbf{x}\) and additional channels when conditioning on images. For an unconditional model, this should be equal to out_channels.

  • out_channels (int) – Number of channels \(C_{out}\) in the output image. Should be equal to the number of channels \(C_{\mathbf{x}}\) in the latent state.

  • label_dim (int, optional, default=0) – Dimension of the vector-valued class_labels conditioning; 0 indicates no conditioning on class labels.

  • augment_dim (int, optional, default=0) – Dimension of the vector-valued augment_labels conditioning; 0 means no conditioning on augmentation labels.

  • model_channels (int, optional, default=128) – Base multiplier for the number of channels across the entire network.

  • channel_mult (List[int], optional, default=[1, 2, 2, 2]) – Multipliers for the number of channels at every level in the encoder and decoder. The length of channel_mult determines the number of levels in the U-Net. At level i, the number of channels in the feature map is channel_mult[i] * model_channels.

  • channel_mult_emb (int, optional, default=4) – Multiplier for the number of channels in the embedding vector. The embedding vector has model_channels * channel_mult_emb channels.

  • num_blocks (int, optional, default=3) – Number of U-Net blocks at each level.

  • attn_resolutions (List[int], optional, default=[16]) – Resolutions of the levels at which self-attention layers are applied. Note that the feature map resolution must exactly match a value provided in attn_resolutions for the self-attention layers to be applied.

  • dropout (float, optional, default=0.10) – Dropout probability applied to intermediate activations within the U-Net blocks.

  • label_dropout (float, optional, default=0.0) – Dropout probability applied to the class_labels. Typically used for classifier-free guidance.

Forward:
  • x (torch.Tensor) – The input tensor of shape \((B, C_{in}, H_{in}, W_{in})\). In general x is the channel-wise concatenation of the latent state \(\mathbf{x}\) and additional images used for conditioning. For an unconditional model, x is simply the latent state \(\mathbf{x}\).

  • noise_labels (torch.Tensor) – The noise labels of shape \((B,)\). Used for conditioning on the noise level.

  • class_labels (torch.Tensor) – The class labels of shape \((B, \text{label\_dim})\). Used for conditioning on any vector-valued quantity. Can pass None when label_dim is 0.

  • augment_labels (torch.Tensor, optional, default=None) – The augmentation labels of shape \((B, \text{augment\_dim})\). Used for conditioning on any additional vector-valued quantity. Can pass None when augment_dim is 0.

Outputs:

torch.Tensor – The denoised latent state of shape \((B, C_{out}, H_{in}, W_{in})\).

Examples

>>> import torch
>>> from physicsnemo.models.diffusion_unets import DhariwalUNet
>>> model = DhariwalUNet(img_resolution=16, in_channels=2, out_channels=2)
>>> noise_labels = torch.randn([1])
>>> class_labels = torch.randint(0, 1, (1, 1))
>>> input_image = torch.ones([1, 2, 16, 16])
>>> output_image = model(input_image, noise_labels, class_labels)
>>> output_image.shape
torch.Size([1, 2, 16, 16])

property amp_mode

Should be set to True to enable automatic mixed precision.

property profile_mode

Should be set to True to enable profiling.

SongUNetPosEmbd

class physicsnemo.models.diffusion_unets.SongUNetPosEmbd(*args, **kwargs)

Bases: SongUNet

This specialized architecture extends SongUNet with positional embeddings that encode global spatial coordinates of the pixels.

This model supports the same types of conditioning as the base SongUNet, and can additionally be conditioned on the positional embeddings. Conditioning on the positional embeddings is performed with a channel-wise concatenation to the input image before the first layer of the U-Net. Multiple types of positional embeddings are supported. Positional embeddings are represented by a 2D grid of shape \((C_{PE}, H, W)\), where \(H\) and \(W\) correspond to the img_resolution parameter.

The following types of positional embeddings are supported:

  • learnable: uses a 2D grid of learnable parameters.

  • linear: uses a 2D rectilinear grid over the domain \([-1, 1] \times [-1, 1]\).

  • sinusoidal: uses sinusoidal functions of the spatial coordinates, with possibly multiple frequency bands.

  • test: uses a 2D grid of integer indices, only used for testing.
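Two of the fixed grid types can be sketched in plain Python. The code below is hypothetical and makes assumptions about details the documentation does not specify: the channel ordering (horizontal coordinate first) and the frequency schedule (\(\pi \cdot 2^k\) for band \(k\)). It does reproduce the channel counts documented for N_grid_channels: 2 for the linear grid, and 4 per frequency band (sine and cosine of both coordinates) for the sinusoidal grid.

```python
import math

# Each grid is a list of channels, each channel an H x W list of lists.
def linspace(a, b, n):
    return [a + (b - a) * k / (n - 1) for k in range(n)] if n > 1 else [a]

def linear_grid(h, w):
    ys, xs = linspace(-1.0, 1.0, h), linspace(-1.0, 1.0, w)
    return [[list(xs) for _ in ys],   # channel 0: horizontal coordinate
            [[y] * w for y in ys]]    # channel 1: vertical coordinate

def sinusoidal_grid(h, w, num_bands):
    # Each band contributes sin and cos of both coordinates -> 4 channels per band
    channels = []
    for k in range(num_bands):
        f = math.pi * 2 ** k  # assumed frequency schedule
        for coord in linear_grid(h, w):
            channels.append([[math.sin(f * v) for v in row] for row in coord])
            channels.append([[math.cos(f * v) for v in row] for row in coord])
    return channels

print(len(linear_grid(4, 6)), len(sinusoidal_grid(4, 6, num_bands=2)))  # 2 8
print(linear_grid(3, 3)[0][0])  # [-1.0, 0.0, 1.0]
```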

When the spatial resolution of the input image is smaller than that of the global positional embeddings, it is necessary to select a subset (or patch) of the embedding grid that corresponds to the spatial locations of the input image pixels. The model provides two methods for selecting this subset of positional embeddings:

  1. Using a selector function. See positional_embedding_selector() for details.

  2. Using global indices. See positional_embedding_indexing() for details.

If neither is provided, the entire grid of positional embeddings is channel-wise concatenated to the input image.
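Selection with global indices amounts to a gather: each pixel of a patch reads the embedding stored at its (j, i) location in the global grid. The pure-Python sketch below is hypothetical (it is not positional_embedding_indexing()) and assumes that channel 0 of the index holds the row index j and channel 1 the column index i.

```python
# Hypothetical sketch of patch extraction with global indices. The index has
# shape (2, h, w): channel 0 holds row indices j, channel 1 column indices i
# into the global embedding grid of shape (C, H, W).
def select_patch(grid, global_index):
    rows_j, cols_i = global_index
    return [[[channel[j][i] for j, i in zip(rj, ci)]
             for rj, ci in zip(rows_j, cols_i)]
            for channel in grid]

# 1-channel 4x4 global grid whose value encodes its position (10 * j + i)
grid = [[[10 * j + i for i in range(4)] for j in range(4)]]
# Indices of the 2x2 patch whose top-left corner is at (j=1, i=2)
global_index = [
    [[1, 1], [2, 2]],  # j (row) indices
    [[2, 3], [2, 3]],  # i (column) indices
]
print(select_patch(grid, global_index))  # [[[12, 13], [22, 23]]]
```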

Most parameters are the same as in the parent class SongUNet. Only the ones that differ are listed below.

Parameters:
  • img_resolution (Union[List[int, int], int]) – The resolution of the input/output image. Can be a single int for square images or a list \([H, W]\) for rectangular images. Used to set the resolution of the positional embedding grid. It must correspond to the spatial resolution of the global domain/image.

  • in_channels (int) –

    Number of channels \(C_{in} + C_{PE}\), where \(C_{in}\) is the number of channels in the image passed to the U-Net and \(C_{PE}\) is the number of channels in the positional embedding grid.

    Important: in comparison to the base SongUNet, this parameter should also include the number of channels in the positional embedding grid \(C_{PE}\).

  • gridtype (Literal["sinusoidal", "learnable", "linear", "test"], optional, default="sinusoidal") – Type of positional embedding to use. Controls how spatial pixel locations are encoded.

  • N_grid_channels (int, optional, default=4) – Number of channels \(C_{PE}\) in the positional embedding grid. For ‘sinusoidal’ it must be 4 or a multiple of 4; for ‘linear’ and ‘test’ it must be 2; for ‘learnable’ it can be any value. If 0, positional embedding is disabled (but lead_time_mode may still be used).

  • lead_time_mode (bool, optional, default=False) – Provided for convenience. It is recommended to use the architecture SongUNetPosLtEmbd for a lead-time aware model.

  • lead_time_channels (int, optional, default=None) – Provided for convenience. Refer to SongUNetPosLtEmbd.

  • lead_time_steps (int, optional, default=9) – Provided for convenience. Refer to SongUNetPosLtEmbd.

  • prob_channels (List[int], optional, default=[]) – Provided for convenience. Refer to SongUNetPosLtEmbd.

Forward:
  • x (torch.Tensor) – The input image of shape \((B, C_{in}, H_{in}, W_{in})\), where \(H_{in}\) and \(W_{in}\) are the spatial dimensions of the input image (does not need to be the full image). In general x is the channel-wise concatenation of the latent state \(\mathbf{x}\) and additional images used for conditioning. For an unconditional model, x is simply the latent state \(\mathbf{x}\).

    Note: \(H_{in}\) and \(W_{in}\) do not need to match the img_resolution parameter, except when additive_pos_embed is True. In all other cases, the resolution of x must not exceed img_resolution.

  • noise_labels (torch.Tensor) – The noise labels of shape \((B,)\). Used for conditioning on the diffusion noise level.

  • class_labels (torch.Tensor) – The class labels of shape \((B, \text{label\_dim})\). Used for conditioning on any vector-valued quantity. Can pass None when label_dim is 0.

  • global_index (torch.Tensor, optional, default=None) – The global indices of the positional embeddings to use. If neither global_index nor embedding_selector are provided, the entire positional embedding grid of shape \((C_{PE}, H, W)\) is used. In this case x must have the same spatial resolution as the positional embedding grid. See positional_embedding_indexing() for details.

  • embedding_selector (Callable, optional, default=None) – A function that selects the positional embeddings to use. See positional_embedding_selector() for details.

  • augment_labels (torch.Tensor, optional, default=None) – The augmentation labels of shape \((B, \text{augment\_dim})\). Used for conditioning on any additional vector-valued quantity. Can pass None when augment_dim is 0.

Outputs:

torch.Tensor – The output tensor of shape \((B, C_{out}, H_{in}, W_{in})\).

Important

Unlike positional embeddings defined by embedding_type in the parent class SongUNet that encode the diffusion time-step (or noise level), the positional embeddings in this specialized architecture encode global spatial coordinates of the pixels.

Examples

>>> import torch
>>> from physicsnemo.models.diffusion_unets import SongUNetPosEmbd
>>> from physicsnemo.diffusion.multi_diffusion import GridPatching2D
>>>
>>> # Model initialization - in_channels must include both original input channels (2)
>>> # and the positional embedding channels (N_grid_channels=4 by default)
>>> model = SongUNetPosEmbd(img_resolution=16, in_channels=2+4, out_channels=2)
>>> noise_labels = torch.randn([1])
>>> class_labels = torch.randint(0, 1, (1, 1))
>>> # The input has only the original 2 channels - positional embeddings are
>>> # added automatically inside the forward method
>>> input_image = torch.ones([1, 2, 16, 16])
>>> output_image = model(input_image, noise_labels, class_labels)
>>> output_image.shape
torch.Size([1, 2, 16, 16])
>>>
>>> # Using a global index to select all positional embeddings
>>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(16, 16))
>>> global_index = patching.global_index(batch_size=1)
>>> output_image = model(
...     input_image, noise_labels, class_labels,
...     global_index=global_index
... )
>>> output_image.shape
torch.Size([1, 2, 16, 16])
>>>
>>> # Using a custom embedding selector to select all positional embeddings
>>> def patch_embedding_selector(emb):
...     return patching.apply(emb[None].expand(1, -1, -1, -1))
>>> output_image = model(
...     input_image, noise_labels, class_labels,
...     embedding_selector=patch_embedding_selector
... )
>>> output_image.shape
torch.Size([1, 2, 16, 16])

positional_embedding_indexing(
x: Float[Tensor, 'PB C H_in W_in'],
global_index: Float[Tensor, 'P 2 H_in W_in'] | None = None,
lead_time_label: Float[Tensor, 'B'] | None = None,
) → Float[Tensor, 'PB C_emb H_in W_in']

Select positional embeddings using global indices.

This method uses global indices to select specific subsets (called patches) of the positional embedding grid and/or the lead-time embedding grid. If no indices are provided, the entire embedding grid is returned. The positional embedding grid is returned if N_grid_channels > 0, and the lead-time embedding grid is returned if lead_time_mode == True. If both positional and lead-time embeddings are enabled, both are returned, concatenated along the channel dimension. If neither is enabled, this method should not be called; doing so raises a ValueError.

Parameters:
  • x (torch.Tensor) – Input tensor of shape \((P \times B, C, H_{in}, W_{in})\). Only used to determine batch size \(B\) and device.

  • global_index (Optional[torch.Tensor], default=None) – Tensor of shape \((P, 2, H_{in}, W_{in})\) that corresponds to the patches to extract from the positional embedding grid. \(P\) is the number of distinct patches in the input tensor x. The channel dimension contains the \(j\), \(i\) indices of the pixels to extract from the embedding grid.

  • lead_time_label (Optional[torch.Tensor], default=None) – Tensor of shape \((B,)\) that corresponds to the lead-time label for each batch element. Only used if lead_time_mode is True.

Returns:

Selected embeddings with shape \((P \times B, C_{PE} [+ C_{LT}], H_{in}, W_{in})\). \(C_{PE}\) is the number of embedding channels in the positional embedding grid, and \(C_{LT}\) is the number of embedding channels in the lead-time embedding grid. If lead_time_label is provided, the lead-time embedding channels are included. If global_index is None, \(P = 1\) is assumed, and the positional embedding grid is duplicated \(B\) times and returned with shape \((B, C_{PE} [+ C_{LT}], H, W)\).

Return type:

torch.Tensor

Example

>>> # Create global indices using a patching utility:
>>> from physicsnemo.diffusion.multi_diffusion import GridPatching2D
>>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8))
>>> global_index = patching.global_index(batch_size=3)
>>> print(global_index.shape)
torch.Size([4, 2, 8, 8])

Notes

  • This method is typically used in patch-based diffusion (or multi-diffusion), where a large input image is split into multiple patches. The batch dimension of the input tensor contains the patches. Patches are processed independently by the model, and the global_index parameter is used to select the grid of positional embeddings corresponding to each patch.

  • See the method global_index() of physicsnemo.diffusion.multi_diffusion.BasePatching2D for generating the global_index parameter.
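The indexing described above can be sketched with plain PyTorch advanced indexing. select_by_global_index is a hypothetical helper, not the library's implementation. It assumes that channel 0 of global_index holds row indices and channel 1 holds column indices, and it tiles the \(P\) patch embeddings \(B\) times with the batch index outermost; check the library source before relying on either convention.

```python
import torch

def select_by_global_index(pos_embd, global_index, batch_size):
    """Hypothetical gather of per-patch embeddings from a (C, H, W) grid.

    pos_embd:     (C, H, W) full embedding grid
    global_index: (P, 2, h, w) pixel indices (assumed: channel 0 = row,
                  channel 1 = column), or None for the full grid
    returns:      (P * batch_size, C, h, w)
    """
    if global_index is None:
        # P = 1: duplicate the full grid once per batch element
        return pos_embd.unsqueeze(0).repeat(batch_size, 1, 1, 1)
    global_index = global_index.long()
    rows, cols = global_index[:, 0], global_index[:, 1]  # each (P, h, w)
    # Advanced indexing on dims 1 and 2 yields (C, P, h, w)
    selected = pos_embd[:, rows, cols].permute(1, 0, 2, 3)  # (P, C, h, w)
    # Repeat the block of P patches once per batch element
    return selected.repeat(batch_size, 1, 1, 1)

# Index four non-overlapping 2x2 patches of a 4x4 grid, in row-major order
grid = torch.arange(3 * 4 * 4, dtype=torch.float32).reshape(3, 4, 4)
patches = []
for r0 in (0, 2):
    for c0 in (0, 2):
        rr, cc = torch.meshgrid(
            torch.arange(r0, r0 + 2), torch.arange(c0, c0 + 2), indexing="ij"
        )
        patches.append(torch.stack([rr, cc]))
global_index = torch.stack(patches)  # (P=4, 2, 2, 2)
out = select_by_global_index(grid, global_index, batch_size=3)
print(out.shape)  # torch.Size([12, 3, 2, 2])
```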

positional_embedding_selector(
x: Float[Tensor, 'PB C H_in W_in'],
embedding_selector: Callable[[Tensor], Tensor],
lead_time_label: Float[Tensor, 'B'] | None = None,
) → Float[Tensor, 'PB C_emb H_in W_in']

Select positional embeddings using a selector function.

Similar to positional_embedding_indexing(), but instead uses a selector function to select the embeddings.

Parameters:
  • x (torch.Tensor) – Input tensor of shape \((P \times B, C, H_{in}, W_{in})\). Only used to determine the dtype.

  • embedding_selector (Callable[[torch.Tensor], torch.Tensor]) – Function that takes as input the entire embedding grid of shape \((C_{PE}, H, W)\) (or \((B, C_{LT}, H, W)\) when lead_time_label is provided) and returns the selected embeddings with shape \((P \times B, C_{PE}, H_{in}, W_{in})\) (or \((P \times B, C_{LT}, H_{in}, W_{in})\) when lead_time_label is provided). Each selected embedding should be the portion of the embedding grid that matches the corresponding batch element in x. Typically the selector is built on the physicsnemo.diffusion.multi_diffusion.BasePatching2D.apply() method, which keeps it consistent with patch extraction.

  • lead_time_label (Optional[torch.Tensor], default=None) – Tensor of shape \((B,)\) that corresponds to the lead-time label for each batch element. Only used if lead_time_mode is True.

Returns:

A tensor of shape \((P \times B, C_{PE} [+ C_{LT}], H_{in}, W_{in})\). \(C_{PE}\) is the number of embedding channels in the positional embedding grid, and \(C_{LT}\) is the number of embedding channels in the lead-time embedding grid. If lead_time_label is provided, the lead-time embedding channels are included.

Return type:

torch.Tensor

Notes

  • This method is typically used in patch-based diffusion (or multi-diffusion), where a large input image is split into multiple patches. The batch dimension of the input tensor contains the patches. Patches are processed independently by the model, and the embedding_selector function is used to select the grid of positional embeddings corresponding to each patch.

  • See the method apply() from physicsnemo.diffusion.multi_diffusion.BasePatching2D for generating the embedding_selector parameter, as well as the example below.

Example

>>> # Define a selector function with a patching utility:
>>> from physicsnemo.diffusion.multi_diffusion import GridPatching2D
>>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8))
>>> B = 4
>>> def embedding_selector(emb):
...     return patching.apply(emb.expand(B, -1, -1, -1))
>>>
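The selector contract above can also be met without the patching utilities. make_unfold_selector below is a hypothetical alternative that splits a \((C_{PE}, H, W)\) grid into non-overlapping tiles with Tensor.unfold. It assumes H and W are divisible by the patch size, row-major patch order, and a batch-outermost layout, and it only handles the 3-D positional grid (not the \((B, C_{LT}, H, W)\) lead-time case). The documentation recommends basing selectors on BasePatching2D.apply() so they stay consistent with how the input itself is patched.

```python
import torch

def make_unfold_selector(patch_h, patch_w, batch_size):
    """Hypothetical embedding_selector built on Tensor.unfold.

    Maps a (C, H, W) grid to (P * batch_size, C, patch_h, patch_w) using
    non-overlapping tiles in row-major order, batch index outermost.
    Assumes H % patch_h == 0 and W % patch_w == 0.
    """

    def selector(emb):
        c = emb.shape[0]
        # (C, H, W) -> (C, nH, nW, patch_h, patch_w)
        tiles = emb.unfold(1, patch_h, patch_h).unfold(2, patch_w, patch_w)
        # -> (nH, nW, C, patch_h, patch_w) -> (P, C, patch_h, patch_w)
        tiles = tiles.permute(1, 2, 0, 3, 4).reshape(-1, c, patch_h, patch_w)
        return tiles.repeat(batch_size, 1, 1, 1)

    return selector

emb = torch.arange(2 * 8 * 8, dtype=torch.float32).reshape(2, 8, 8)
selector = make_unfold_selector(4, 4, batch_size=2)
selected = selector(emb)
print(selected.shape)  # torch.Size([8, 2, 4, 4])
```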

SongUNetPosLtEmbd

class physicsnemo.models.diffusion_unets.SongUNetPosLtEmbd(*args, **kwargs)

Bases: SongUNetPosEmbd

This specialized architecture extends SongUNetPosEmbd with two additional capabilities:

  1. The model can be conditioned on lead-time labels. These labels encode physical time information, such as a forecasting horizon.

  2. Like the parent SongUNetPosEmbd, this model predicts regression targets, but it can also produce classification predictions. More precisely, some of the output channels are probability outputs, which are passed through a softmax activation function. This is useful for multi-task applications, where the objective combines regression and classification losses.

The mechanism to condition on lead-time labels is implemented by:

  • First generating a grid of learnable lead-time embeddings of shape \((\text{lead\_time\_steps}, C_{LT}, H, W)\). The spatial resolution of the lead-time embeddings is the same as the input/output image.

  • Then, given an input x, selecting the lead-time embeddings that correspond to the lead times associated with the samples in x.

  • Finally, concatenating the selected lead-time embeddings and the positional embeddings to the input x along the channel dimension, and passing the result to the U-Net.
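The three steps above can be sketched as a small module. LeadTimeConcatSketch is hypothetical and is not the SongUNetPosLtEmbd implementation: the initialization scale and the channel order [x, positional, lead-time] are assumptions of this sketch, and it skips patch selection by using the full grid.

```python
import torch
import torch.nn as nn

class LeadTimeConcatSketch(nn.Module):
    """Hypothetical lead-time conditioning by channel-wise concatenation."""

    def __init__(self, lead_time_steps, lead_time_channels, height, width):
        super().__init__()
        # Step 1: one learnable (C_LT, H, W) map per discrete lead-time step
        self.lt_embd = nn.Parameter(
            0.02 * torch.randn(lead_time_steps, lead_time_channels, height, width)
        )

    def forward(self, x, pos_embd, lead_time_label):
        # x: (B, C_in, H, W); pos_embd: (C_PE, H, W); lead_time_label: (B,) int64
        batch = x.shape[0]
        # Step 2: pick the embedding matching each sample's lead time
        lt = self.lt_embd[lead_time_label]  # (B, C_LT, H, W)
        pe = pos_embd.unsqueeze(0).expand(batch, -1, -1, -1)  # (B, C_PE, H, W)
        # Step 3: concatenate along channels (order is an assumption here)
        return torch.cat([x, pe, lt], dim=1)  # (B, C_in + C_PE + C_LT, H, W)

sketch = LeadTimeConcatSketch(lead_time_steps=9, lead_time_channels=4, height=16, width=16)
x = torch.ones(2, 2, 16, 16)
pos_embd = torch.zeros(4, 16, 16)
unet_input = sketch(x, pos_embd, torch.tensor([3, 0]))
print(unet_input.shape)  # torch.Size([2, 10, 16, 16])
```

The 2 + 4 + 4 = 10 channels mirror why in_channels must include \(C_{PE}\) and \(C_{LT}\).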

Most parameters are similar to those of the parent SongUNetPosEmbd, with the exception of the ones listed below.

Parameters:
  • in_channels (int) –

    Number of channels \(C_{in} + C_{PE} + C_{LT}\) in the image passed to the U-Net.

    Important: in comparison to the base SongUNet, this parameter should also include the number of channels in the positional embedding grid \(C_{PE}\) and the number of channels in the lead-time embedding grid \(C_{LT}\).

  • lead_time_channels (int, optional, default=None) – Number of channels \(C_{LT}\) in the lead time embedding. These are learned embeddings that encode physical time information.

  • lead_time_steps (int, optional, default=9) – Number of discrete lead time steps to support. Each step gets its own learned embedding vector of shape \((C_{LT}, H, W)\).

  • prob_channels (List[int], optional, default=[]) – Indices of channels that are probability outputs (or classification predictions). In training mode, the model outputs logits for these probability channels; in eval mode, the model applies a softmax to output probabilities.
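The training/eval switch for prob_channels can be sketched as a post-processing step. apply_prob_channels is a hypothetical helper; it assumes the softmax is taken jointly across the listed channels (so they form one categorical distribution per pixel), which this page does not state explicitly.

```python
import torch

def apply_prob_channels(out, prob_channels, training):
    """Hypothetical output post-processing for prob_channels.

    training=True : return raw logits for every channel.
    training=False: softmax jointly over the listed channels (dim=1);
                    the remaining regression channels pass through unchanged.
    """
    if training or not prob_channels:
        return out
    idx = torch.as_tensor(prob_channels, dtype=torch.long, device=out.device)
    out = out.clone()  # avoid modifying the caller's tensor in place
    out[:, idx] = torch.softmax(out[:, idx], dim=1)
    return out

logits = torch.randn(2, 5, 4, 4)  # (B, C_out, H, W); channels 1, 3, 4 are probabilities
probs = apply_prob_channels(logits, prob_channels=[1, 3, 4], training=False)
print(probs[:, [1, 3, 4]].sum(dim=1).allclose(torch.ones(2, 4, 4)))  # True
```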

Forward:
  • x (torch.Tensor) – The input image of shape \((B, C_{in}, H_{in}, W_{in})\), where \(H_{in}\) and \(W_{in}\) are the spatial dimensions of the input image (does not need to be the full image).

  • noise_labels (torch.Tensor) – The noise labels of shape \((B,)\). Used for conditioning on the diffusion noise level.

  • class_labels (torch.Tensor) – The class labels of shape \((B, \text{label\_dim})\). Used for conditioning on any vector-valued quantity. Can pass None when label_dim is 0.

  • global_index (torch.Tensor, optional, default=None) – The global indices of the positional embeddings to use. See positional_embedding_indexing() for details. If neither global_index nor embedding_selector are provided, the entire positional embedding grid is used.

  • embedding_selector (Callable, optional, default=None) – A function that selects the positional embeddings to use. See positional_embedding_selector() for details.

  • augment_labels (torch.Tensor, optional, default=None) – The augmentation labels of shape \((B, \text{augment\_dim})\). Used for conditioning on any additional vector-valued quantity.

  • lead_time_label (torch.Tensor, optional, default=None) – The lead-time labels of shape \((B,)\). Used for selecting lead-time embeddings. It should contain the indices of the lead-time embeddings that correspond to the lead-time of each sample in the batch.

Outputs:

torch.Tensor – The output tensor of shape \((B, C_{out}, H_{in}, W_{in})\).

Notes

  • The lead-time embeddings differ from the diffusion time embeddings used in the SongUNet class: they encode physical forecast time rather than the diffusion time-step.

Example

>>> import torch
>>> from physicsnemo.models.diffusion_unets import SongUNetPosLtEmbd
>>> from physicsnemo.diffusion.multi_diffusion import GridPatching2D
>>>
>>> # Model initialization - in_channels must include original input channels (2),
>>> # positional embedding channels (N_grid_channels=4 by default) and
>>> # lead time embedding channels (4)
>>> model = SongUNetPosLtEmbd(
...     img_resolution=16, in_channels=2+4+4, out_channels=2,
...     lead_time_channels=4, lead_time_steps=9
... )
>>> noise_labels = torch.randn([1])
>>> class_labels = torch.randint(0, 1, (1, 1))
>>> # The input has only the original 2 channels - positional embeddings and
>>> # lead time embeddings are added automatically inside the forward method
>>> input_image = torch.ones([1, 2, 16, 16])
>>> lead_time_label = torch.tensor([3])
>>> output_image = model(
...     input_image, noise_labels, class_labels,
...     lead_time_label=lead_time_label
... )
>>> output_image.shape
torch.Size([1, 2, 16, 16])
>>>
>>> # Using global_index to select all the positional and lead time embeddings
>>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(16, 16))
>>> global_index = patching.global_index(batch_size=1)
>>> output_image = model(
...     input_image, noise_labels, class_labels,
...     lead_time_label=lead_time_label,
...     global_index=global_index
... )
>>> output_image.shape
torch.Size([1, 2, 16, 16])