Diffusion Models#
The PhysicsNeMo diffusion library provides three categories of models that serve
different purposes. All models are based on the
Module class.
- Model backbones:
These are highly configurable architectures that can be used as building blocks for more complex models.
- Specialized architectures:
These models usually inherit from the model backbones and add functionality for specific use cases.
- Application-specific interfaces:
These Modules are not truly architectures, but rather wrappers around the model backbones or specialized architectures. Their intent is to provide a more user-friendly interface for specific applications.
In addition to these model architectures, PhysicsNeMo provides diffusion preconditioners, which are essentially wrappers around model architectures that rescale the inputs and outputs of diffusion models to improve their performance.
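As a rough illustration of what a preconditioner does, the sketch below follows the EDM-style rescaling from Karras et al. The class name and interface here are hypothetical and simplified; they are not the actual PhysicsNeMo preconditioner API (refer to the preconditioning module for the provided wrappers).
import torch
class EDMPreconditionerSketch(torch.nn.Module):
    # Hypothetical, simplified wrapper for illustration only; PhysicsNeMo's
    # preconditioners implement this kind of rescaling with a different API.
    def __init__(self, model, sigma_data=0.5):
        super().__init__()
        self.model = model
        self.sigma_data = sigma_data
    def forward(self, x, sigma, class_labels=None):
        sigma = sigma.reshape(-1, 1, 1, 1)
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
        c_noise = sigma.flatten().log() / 4
        # Rescale the input, call the wrapped backbone, and rescale its output
        # into a denoised estimate of the latent state.
        F_x = self.model(c_in * x, c_noise, class_labels)
        return c_skip * x + c_out * F_x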
Architecture Backbones#
Diffusion model backbones are highly configurable architectures that can be used
as a building block for more complex models. Backbones support
both conditional and unconditional modeling. Currently, two backbones are provided:
the SongUNet, as implemented in the
SongUNet class, and the DhariwalUNet,
as implemented in the DhariwalUNet
class. These models were introduced in the papers Score-based generative modeling through stochastic
differential equations, Song et al., and
Diffusion models beat GANs on image synthesis, Dhariwal et al.
The PhysicsNeMo implementation of these models closely follows the one used in the paper
Elucidating the Design Space of Diffusion-Based Generative Models, Karras et al. The original implementation of these
models can be found in the EDM repository.
Model backbones can be used as is, as in the StormCast example, but they can also serve as base classes for more complex models.
One of the most common diffusion backbones for image generation is the
SongUNet
class. Its latent state \(\mathbf{x}\) is a tensor of shape \((B, C, H, W)\),
where \(B\) is the batch size, \(C\) is the number of channels,
and \(H\) and \(W\) are the height and width of the feature map. The
model is conditioned on the noise level, and can additionally be conditioned on
vector-valued class labels and/or images. The model is organized into levels,
whose number is determined by len(channel_mult); each level operates at half the resolution of the
previous one (odd resolutions are rounded down). Each level is composed of a sequence of UNet blocks that optionally contain
self-attention layers, as controlled by the attn_resolutions parameter. The feature map resolution
is halved when entering each level after the first and then remains constant within the level.
Here we start by creating a SongUNet model with 3 levels that applies self-attention
at levels 1 and 2. The model is unconditional, i.e. it is not conditioned on any
class labels or images (but is still conditioned on the noise level, as is
standard practice for diffusion models).
import torch
from physicsnemo.models.diffusion import SongUNet
B, C_x, res = 3, 6, 40 # Batch size, channels, and resolution of the latent state
model = SongUNet(
img_resolution=res,
in_channels=C_x,
out_channels=C_x, # No conditioning on image: number of output channels is the same as the input channels
label_dim=0, # No conditioning on vector-valued class labels
augment_dim=0,
model_channels=64,
channel_mult=[1, 2, 3], # 3-level UNet with 64, 128, and 192 channels at each level, respectively
num_blocks=4, # 4 UNet blocks at each level
attn_resolutions=[20, 10], # Attention is applied at level 1 (resolution 20x20) and level 2 (resolution 10x10)
)
x = torch.randn(B, C_x, res, res) # Latent state
noise_labels = torch.randn(B) # Noise level for each sample
# The feature map resolution is 40 at level 0, 20 at level 1, and 10 at level 2
out = model(x, noise_labels, None)
print(out.shape) # Shape: (B, C_x, res, res), same as the latent state
# The same model can be used on images of different resolution
# Note: the attention is still applied at levels 1 and 2
x_32 = torch.randn(B, C_x, 32, 32) # Lower resolution latent state
out_32 = model(x_32, noise_labels, None) # None means no conditioning on class labels
print(out_32.shape) # Shape: (B, C_x, 32, 32), same as the latent state
The unconditional SongUNet can be extended to be conditional on class labels and/or
images. Conditioning on images is performed by channel-wise concatenation of the image
to the latent state \(\mathbf{x}\) before passing it to the model. The model does not perform
conditioning on images internally, and this operation is left to the user. For
conditioning on class labels (or any vector-valued quantity whose dimension is label_dim),
the model internally generates embeddings for the class labels
and adds them to intermediate activations within the UNet blocks. Here we
extend the previous example to be conditional on a 16-dimensional vector-valued
class label and a 3-channel image.
import torch
from physicsnemo.models.diffusion import SongUNet
B, C_x, res = 3, 10, 40
C_cond = 3
model = SongUNet(
img_resolution=res,
in_channels=C_x + C_cond, # Conditioning on an image with C_cond channels
out_channels=C_x, # Output channels: only those of the latent state
label_dim=16, # Conditioning on 16-dimensional vector-valued class labels
augment_dim=0,
model_channels=64,
channel_mult=[1, 2, 2],
num_blocks=4,
attn_resolutions=[20, 10],
)
x = torch.randn(B, C_x, res, res) # Latent state
cond = torch.randn(B, C_cond, res, res) # Conditioning image
x_cond = torch.cat([x, cond], dim=1) # Channel-wise concatenation of the conditioning image before passing to the model
noise_labels = torch.randn(B)
class_labels = torch.randn(B, 16) # Conditioning on vector-valued class labels
out = model(x_cond, noise_labels, class_labels)
print(out.shape) # Shape: (B, C_x, res, res), same as the latent state
Specialized Architectures#
Note that even though backbones can be used as is, some of the PhysicsNeMo
examples use specialized architectures. These specialized architectures
typically inherit from the backbones and implement additional functionalities for specific
applications. For example, the CorrDiff example
uses the specialized architectures SongUNetPosEmbd
and SongUNetPosLtEmbd to implement
the diffusion model.
Positional embeddings#
Multi-diffusion (also called patch-based diffusion) is a technique to scale
diffusion models to large domains. The idea is to split the full domain into
patches, and run a diffusion model on each patch in parallel. The generated
patches are then fused back to form the final image. This technique is
particularly useful for domains that are too large to fit into the memory of
a single GPU. The CorrDiff example
uses patch-based diffusion for weather downscaling on large domains. A key
ingredient in the implementation of patch-based diffusion is the use of a
global spatial grid that informs each patch of its position in the full domain.
The SongUNetPosEmbd
class implements this functionality by providing multiple methods to encode
global spatial coordinates of the pixels into a global positional embedding grid.
In addition to multi-diffusion, spatial positional embeddings have also been
observed to improve the quality of the generated images, even for diffusion models
that operate on the full domain.
The following example shows how to use the specialized architecture
SongUNetPosEmbd to implement a
multi-diffusion model. First, we create a SongUNetPosEmbd model similar to
the one in the conditional SongUNet example,
with a global positional embedding grid of shape (C_PE, res, res). We
show that the model can be used with the entire latent state (full domain).
import torch
from physicsnemo.models.diffusion import SongUNetPosEmbd
B, C_x, res = 3, 10, 40
C_cond = 3
C_PE = 8 # Number of channels in the positional embedding grid
# Create a SongUNet with a global positional embedding grid of shape (C_PE, res, res)
model = SongUNetPosEmbd(
img_resolution=res, # Define the resolution of the global positional embedding grid
in_channels=C_x + C_cond + C_PE, # in_channels must include the number of channels in the positional embedding grid
out_channels=C_x,
label_dim=16,
augment_dim=0,
model_channels=64,
channel_mult=[1, 2, 2],
num_blocks=4,
attn_resolutions=[20, 10],
gridtype="learnable", # Use a learnable grid of positional embeddings
N_grid_channels=C_PE # Number of channels in the positional embedding grid
)
# Can pass the entire latent state to the model
x_global = torch.randn(B, C_x, res, res) # Entire latent state
cond = torch.randn(B, C_cond, res, res) # Conditioning image
x_cond = torch.cat([x_global, cond], dim=1) # Latent state with conditioning image
noise_labels = torch.randn(B)
class_labels = torch.randn(B, 16)
# The model internally concatenates the global positional embedding grid to the
# input x_cond before the first UNet block.
# Note: global_index=None means use the entire positional embedding grid
out = model(x_cond, noise_labels, class_labels, global_index=None)
print(out.shape) # Shape: (B, C_x, res, res), same as the latent state
Now we show that the model can be used on local patches of the latent state
(multi-diffusion approach). We manually extract 3 patches from the latent
state. Patches are treated as individual samples, so they are concatenated along
the batch dimension. We also create a global grid of indices grid that
contains the indices of the pixels in the full domain, and we extract the same
3 patches from the global grid and pass them to the global_index
parameter. The model internally uses global_index to extract the corresponding
patches from the positional embedding grid and concatenate them to the input
x_cond_patches before the first UNet block. Note that conditional
multi-diffusion still requires each patch to be conditioned on the entire
conditioning image cond, which is why we interpolate the conditioning image
to the patch resolution and concatenate it to each individual patch.
In practice it is not necessary to manually extract the patches from the latent
state and the global grid, as PhysicsNeMo provides utilities to help with the
patching operations, in patching. For an example of how
to use these utilities, see the CorrDiff example.
# Can pass local patches to the model
# Create batch of 3 patches from `x_global` with resolution 16x16
pres = 16 # Patch resolution
p1 = x_global[0:1, :, :pres, :pres] # Patch 1
p2 = x_global[2:3, :, pres:2*pres, pres:2*pres] # Patch 2
p3 = x_global[1:2, :, -pres:, pres:2*pres] # Patch 3
patches = torch.cat([p1, p2, p3], dim=0) # Batch of 3 patches
# Note: the conditioning image needs interpolation (or other operations) to
# match the patch resolution
cond1 = torch.nn.functional.interpolate(cond[0:1], size=(pres, pres), mode="bilinear")
cond2 = torch.nn.functional.interpolate(cond[2:3], size=(pres, pres), mode="bilinear")
cond3 = torch.nn.functional.interpolate(cond[1:2], size=(pres, pres), mode="bilinear")
cond_patches = torch.cat([cond1, cond2, cond3], dim=0)
# Concatenate the patches and the conditioning image
x_cond_patches = torch.cat([patches, cond_patches], dim=1)
# Create corresponding global indices for the patches
Ny, Nx = torch.arange(res).int(), torch.arange(res).int()
grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0)
idx_patch1 = grid[:, :pres, :pres] # Global indices for patch 1
idx_patch2 = grid[:, pres:2*pres, pres:2*pres] # Global indices for patch 2
idx_patch3 = grid[:, -pres:, pres:2*pres] # Global indices for patch 3
global_index = torch.stack([idx_patch1, idx_patch2, idx_patch3], dim=0)
# The model internally extracts the corresponding patches from the global
# positional embedding grid and concatenates them to the input x_cond_patches
# before the first UNet block.
out = model(x_cond_patches, noise_labels, class_labels, global_index=global_index)
print(out.shape) # Shape: (3, C_x, pres, pres), same as the patches extracted from the latent state
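As a rough sketch of how the patching utilities can replace the manual indexing above (reusing model, x_global, and cond from the previous snippets), the code below assumes an 8-pixel patch size. For brevity, it simply crops the conditioning channels together with the latent state instead of interpolating the full conditioning image, and the exact GridPatching2D options may differ from the ones used in the CorrDiff example.
from physicsnemo.utils.patching import GridPatching2D
# Assumed patch size of 8 so that the 40x40 domain splits evenly into patches
patching = GridPatching2D(img_shape=(res, res), patch_shape=(8, 8))
x_cond_full = torch.cat([x_global, cond], dim=1) # (B, C_x + C_cond, res, res)
x_cond_patches = patching.apply(x_cond_full) # (P * B, C_x + C_cond, 8, 8)
global_index = patching.global_index(batch_size=B) # (P, 2, 8, 8)
P = global_index.shape[0]
noise_labels_p = torch.randn(P * B) # One noise level per patch
class_labels_p = torch.randn(P * B, 16)
out = model(x_cond_patches, noise_labels_p, class_labels_p, global_index=global_index)
print(out.shape) # (P * B, C_x, 8, 8)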
Lead-time aware models#
In many diffusion applications, the latent state is time-dependent, and the diffusion process should account for the time-dependence of the latent state. For instance, a forecast model could provide latent states \(\mathbf{x}(T)\) (current time), \(\mathbf{x}(T + \Delta t)\) (one time step forward), …, up to \(\mathbf{x}(T + K \Delta t)\) (K time steps forward). Such prediction horizons are called lead-times (a term adopted from the weather and climate forecasting community) and we want to apply diffusion to each of these latent states while accounting for their associated lead-time information.
PhysicsNeMo provides a specialized architecture
SongUNetPosLtEmbd that implements
lead-time aware models. This is an extension of the
SongUNetPosEmbd class, and
additionally supports lead-time information. In its forward pass, the model
uses the lead_time_label parameter to internally retrieve the associated
lead-time embeddings; it then conditions the diffusion process on those with a
channel-wise concatenation to the latent state before the first UNet block.
Here we show an example extending the previous ones with lead-time information.
We assume that we have a batch of 3 latent states at times \(T + 2 \Delta t\)
(2 time intervals forward), \(T + 0 \Delta t\) (current time),
and \(T + \Delta t\) (1 time interval forward). The associated lead-time labels are
[2, 0, 1]. In addition, the SongUNetPosLtEmbd model has the ability to
predict probabilities for some channels of the latent state, specified by the
prob_channels parameter. Here we assume that channels 1 and 3 are
probability (i.e. classification) outputs, while other channels are regression
outputs.
import torch
from physicsnemo.models.diffusion import SongUNetPosLtEmbd
B, C_x, res = 3, 10, 40
C_cond = 3
C_PE = 8
lead_time_steps = 3 # Maximum supported lead-time is 2 * dt
C_LT = 6 # 6 channels for each lead-time embedding
# Create a SongUNet with a lead-time embedding grid of shape
# (lead_time_steps, C_LT, res, res)
model = SongUNetPosLtEmbd(
img_resolution=res,
in_channels=C_x + C_cond + C_PE + C_LT, # in_channels must include the positional embedding and lead-time embedding channels
out_channels=C_x,
label_dim=16,
augment_dim=0,
model_channels=64,
channel_mult=[1, 2, 2],
num_blocks=4,
attn_resolutions=[10, 5],
gridtype="learnable",
N_grid_channels=C_PE,
lead_time_channels=C_LT,
lead_time_steps=lead_time_steps, # Maximum supported lead-time horizon
prob_channels=[1, 3], # Channels 1 and 3 from the latent state are probability outputs
)
x = torch.randn(B, C_x, res, res) # Latent state at times T+2*dt, T+0*dt, and T + 1*dt
cond = torch.randn(B, C_cond, res, res)
x_cond = torch.cat([x, cond], dim=1)
noise_labels = torch.randn(B)
class_labels = torch.randn(B, 16)
lead_time_label = torch.tensor([2, 0, 1]) # Lead-time labels for each sample
# The model internally extracts the lead-time embeddings corresponding to the
# lead-time labels 2, 0, 1 and concatenates them to the input x_cond before the first
# UNet block. In training mode, the model outputs logits for channels 1 and 3.
out = model(x_cond, noise_labels, class_labels, lead_time_label=lead_time_label)
print(out.shape) # Shape: (B, C_x, res, res), same as the latent state
# In eval mode, the model outputs probabilities for channels 1 and 3
model.eval()
out = model(x_cond, noise_labels, class_labels, lead_time_label=lead_time_label)
Note
The SongUNetPosLtEmbd is not an autoregressive model that performs a rollout
to produce future predictions. From the point of view of the SongUNetPosLtEmbd,
the lead-time information is frozen. The lead-time dependent latent state \(\mathbf{x}\)
might however be produced by such an autoregressive/rollout model.
Note
The SongUNetPosLtEmbd model cannot be scaled to very long lead-time
horizons (controlled by the lead_time_steps parameter). This is because
the lead-time embeddings are represented by a grid of learnable parameters of
shape (lead_time_steps, C_LT, res, res). For very long lead-time horizons, the
size of this grid of embeddings becomes prohibitively large.
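As a rough illustration, with lead_time_steps=9, C_LT=6, and a 448 x 448 grid, the lead-time embedding grid alone holds 9 × 6 × 448 × 448 ≈ 10.8 million learnable parameters, and this count grows linearly with lead_time_steps.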
Note
In a given input batch x, the associated lead-times are not necessarily
consecutive or in order. They do not even need to originate from the same forecast
trajectory. For example, the lead-time labels might be [0, 1, 2] instead of [2, 0, 1],
or even [2, 2, 1].
Application-specific Interfaces#
Application-specific interfaces are not true architectures, but rather wrappers
around the model backbones or specialized architectures that provide a more
user-friendly interface for specific applications. Note that not all of these
classes are true diffusion models; some are meant to be used in conjunction with
diffusion models. For instance, the
CorrDiff example uses the UNet
class to implement a regression model.
SongUNet#
- class physicsnemo.models.diffusion.song_unet.SongUNet(*args, **kwargs)[source]#
Bases:
Module
This architecture is a diffusion backbone for 2D image generation. It is a reimplementation of the DDPM++ and NCSN++ architectures, which are U-Net variants with optional self-attention, embeddings, and encoder-decoder components.
This model supports conditional and unconditional setups, as well as several options for various internal architectural choices such as encoder and decoder type, embedding type, etc., making it flexible and adaptable to different tasks and configurations.
This architecture supports conditioning on the noise level (called noise labels), as well as on additional vector-valued labels (called class labels) and (optional) vector-valued augmentation labels. The conditioning mechanism relies on addition of the conditioning embeddings in the U-Net blocks of the encoder. To condition on images, the simplest mechanism is to concatenate the image to the input before passing it to the SongUNet.
The model first applies a mapping operation to generate embeddings for all the conditioning inputs (the noise level, the class labels, and the optional augmentation labels).
Then, at each level in the U-Net encoder, a sequence of blocks is applied:
A first block downsamples the feature map resolution by a factor of 2 (odd resolutions are floored). This block does not change the number of channels.
A sequence of num_blocks U-Net blocks is applied, each with a different number of channels. These blocks do not change the feature map resolution, but they multiply the number of channels by a factor specified in channel_mult. If required, the U-Net blocks also apply self-attention at the specified resolutions.
At the end of the level, the feature map is cached to be used in a skip connection in the decoder.
The decoder is a mirror of the encoder, with the same number of levels and the same number of blocks per level. It multiplies the feature map resolution by a factor of 2 at each level.
- Parameters:
img_resolution (Union[List[int, int], int]) –
The resolution of the input/output image. Can be a single int \(H\) for square images or a list \([H, W]\) for rectangular images.
Note: This parameter is only used as a convenience to build the network. In practice, the model can still be used with images of different resolutions. The only exception to this rule is when additive_pos_embed is True, in which case the resolution of the latent state \(\mathbf{x}\) must match img_resolution.
in_channels (int) – Number of channels \(C_{in}\) in the input image. May include channels from both the latent state and additional channels when conditioning on images. For an unconditional model, this should be equal to out_channels.
out_channels (int) – Number of channels \(C_{out}\) in the output image. Should be equal to the number of channels \(C_{\mathbf{x}}\) in the latent state.
label_dim (int, optional, default=0) – Dimension of the vector-valued class_labels conditioning; 0 indicates no conditioning on class labels.
augment_dim (int, optional, default=0) – Dimension of the vector-valued augment_labels conditioning; 0 means no conditioning on augmentation labels.
model_channels (int, optional, default=128) – Base multiplier for the number of channels across the entire network.
channel_mult (List[int], optional, default=[1, 2, 2, 2]) – Multipliers for the number of channels at every level in the encoder and decoder. The length of channel_mult determines the number of levels in the U-Net. At level i, the number of channels in the feature map is channel_mult[i] * model_channels.
channel_mult_emb (int, optional, default=4) – Multiplier for the number of channels in the embedding vector. The embedding vector has model_channels * channel_mult_emb channels.
num_blocks (int, optional, default=4) – Number of U-Net blocks at each level.
attn_resolutions (List[int], optional, default=[16]) – Resolutions of the levels at which self-attention layers are applied. Note that the feature map resolution must match exactly the value provided in attn_resolutions for the self-attention layers to be applied.
dropout (float, optional, default=0.10) – Dropout probability applied to intermediate activations within the U-Net blocks.
label_dropout (float, optional, default=0.0) – Dropout probability applied to the class_labels. Typically used for classifier-free guidance.
embedding_type (Literal["fourier", "positional", "zero"], optional, default="positional") – Diffusion timestep embedding type: ‘positional’ for DDPM++, ‘fourier’ for NCSN++, ‘zero’ for none.
channel_mult_noise (int, optional, default=1) – Multiplier for the number of channels in the noise level embedding. The noise level embedding vector has model_channels * channel_mult_noise channels.
encoder_type (Literal["standard", "skip", "residual"], optional, default="standard") – Encoder architecture: ‘standard’ for DDPM++, ‘residual’ for NCSN++, ‘skip’ for skip connections.
decoder_type (Literal["standard", "skip"], optional, default="standard") – Decoder architecture: ‘standard’ or ‘skip’ for skip connections.
resample_filter (List[int], optional, default=[1, 1]) – Resampling filter coefficients applied in the U-Net block convolutions: [1,1] for DDPM++, [1,3,3,1] for NCSN++.
checkpoint_level (int, optional, default=0) – Number of levels that should use gradient checkpointing. Only levels at which the feature map resolution is large enough will be checkpointed (0 disables checkpointing, higher values mean more layers are checkpointed). Higher values trade memory for computation.
additive_pos_embed (bool, optional, default=False) –
If True, adds a learnable positional embedding after the first convolution layer. Used in the StormCast model.
Note: These positional embeddings encode spatial position information of the image pixels, unlike the embedding_type parameter, which encodes temporal information about the diffusion process. In that sense it is a simpler version of the positional embedding used in SongUNetPosEmbd.
use_apex_gn (bool, optional, default=False) – A flag indicating whether to use Apex GroupNorm for NHWC layout. Apex needs to be installed for this to work. Must be set to False on CPU.
act (str, optional, default=None) – The activation function to use when fusing activation with GroupNorm. Required when use_apex_gn is True.
profile_mode (bool, optional, default=False) – A flag indicating whether to enable all nvtx annotations during profiling.
amp_mode (bool, optional, default=False) – A flag indicating whether mixed-precision (AMP) training is enabled.
- Forward:
x (torch.Tensor) – The input image of shape \((B, C_{in}, H_{in}, W_{in})\). In general x is the channel-wise concatenation of the latent state \(\mathbf{x}\) and additional images used for conditioning. For an unconditional model, x is simply the latent state \(\mathbf{x}\).
Note: \(H_{in}\) and \(W_{in}\) do not need to match \(H\) and \(W\) defined in img_resolution, except when additive_pos_embed is True. In that case, the resolution of x must match img_resolution.
noise_labels (torch.Tensor) – The noise labels of shape \((B,)\). Used for conditioning on the diffusion noise level.
class_labels (torch.Tensor) – The class labels of shape \((B, \text{label_dim})\). Used for conditioning on any vector-valued quantity. Can pass None when label_dim is 0.
augment_labels (torch.Tensor, optional, default=None) – The augmentation labels of shape \((B, \text{augment_dim})\). Used for conditioning on any additional vector-valued quantity. Can pass None when augment_dim is 0.
- Outputs:
torch.Tensor – The denoised latent state of shape \((B, C_{out}, H_{in}, W_{in})\).
Important
The terms noise levels (or noise labels) are used to refer to the diffusion time-step, as these are conceptually equivalent.
The terms labels and classes originate from the original paper and EDM repository, where this architecture was used for class-conditional image generation. While these terms suggest class-based conditioning, the architecture can actually be conditioned on any vector-valued conditioning.
The term positional embedding used in the embedding_type parameter also comes from the original paper and EDM repository. Here, positional refers to the diffusion time-step, similar to how position is used in transformer architectures. Despite the name, these embeddings encode temporal information about the diffusion process rather than spatial position information.
Limitations on input image resolution: for a model that has \(N\) levels, the latent state \(\mathbf{x}\) must have resolution that is a multiple of \(2^N\) in each dimension. This is due to a limitation in the decoder that does not support shape mismatch in the residual connections from the encoder to the decoder. For images that do not match this requirement, it is recommended to interpolate your data on a grid of the required resolution beforehand.
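As a brief sketch of this last point (an illustration only; the target size below assumes a 3-level model, for which the documented constraint requires resolutions that are multiples of \(2^3 = 8\)):
import torch
import torch.nn.functional as F
x = torch.randn(1, 2, 44, 44) # 44 is not a multiple of 8 for a 3-level model
x_resized = F.interpolate(x, size=(48, 48), mode="bilinear", align_corners=False)
# Run the model on x_resized, then interpolate the output back to 44x44 if needed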
Example
>>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2)
>>> noise_labels = torch.randn([1])
>>> class_labels = torch.randint(0, 1, (1, 1))
>>> input_image = torch.ones([1, 2, 16, 16])
>>> output_image = model(input_image, noise_labels, class_labels)
>>> output_image.shape
torch.Size([1, 2, 16, 16])
- property amp_mode#
Should be set to True to enable automatic mixed precision.
- property profile_mode#
Should be set to True to enable profiling.
DhariwalUNet#
- class physicsnemo.models.diffusion.dhariwal_unet.DhariwalUNet(*args, **kwargs)[source]#
Bases:
Module
This architecture is a diffusion backbone for 2D image generation. It reimplements the ADM architecture, a U-Net variant, with optional self-attention.
It is highly similar to the U-Net backbone defined in
SongUNet, and only differs in a few aspects:
The embedding conditioning mechanism relies on adaptive scaling of the group normalization layers within the U-Net blocks.
The parameter initialization follows Kaiming uniform initialization.
- Parameters:
img_resolution (int) –
The resolution \(H = W\) of the input/output image. Assumes square images.
Note: This parameter is only used as a convenience to build the network. In practice, the model can still be used with images of different resolutions.
in_channels (int) – Number of channels \(C_{in}\) in the input image. May include channels from both the latent state \(\mathbf{x}\) and additional channels when conditioning on images. For an unconditional model, this should be equal to out_channels.
out_channels (int) – Number of channels \(C_{out}\) in the output image. Should be equal to the number of channels \(C_{\mathbf{x}}\) in the latent state.
label_dim (int, optional, default=0) – Dimension of the vector-valued class_labels conditioning; 0 indicates no conditioning on class labels.
augment_dim (int, optional, default=0) – Dimension of the vector-valued augment_labels conditioning; 0 means no conditioning on augmentation labels.
model_channels (int, optional, default=128) – Base multiplier for the number of channels across the entire network.
channel_mult (List[int], optional, default=[1,2,2,2]) – Multipliers for the number of channels at every level in the encoder and decoder. The length of channel_mult determines the number of levels in the U-Net. At level i, the number of channels in the feature map is channel_mult[i] * model_channels.
channel_mult_emb (int, optional, default=4) – Multiplier for the number of channels in the embedding vector. The embedding vector has model_channels * channel_mult_emb channels.
num_blocks (int, optional, default=3) – Number of U-Net blocks at each level.
attn_resolutions (List[int], optional, default=[16]) – Resolutions of the levels at which self-attention layers are applied. Note that the feature map resolution must match exactly the value provided in attn_resolutions for the self-attention layers to be applied.
dropout (float, optional, default=0.10) – Dropout probability applied to intermediate activations within the U-Net blocks.
label_dropout (float, optional, default=0.0) – Dropout probability applied to the class_labels. Typically used for classifier-free guidance.
- Forward:
x (torch.Tensor) – The input tensor of shape \((B, C_{in}, H_{in}, W_{in})\). In general x is the channel-wise concatenation of the latent state \(\mathbf{x}\) and additional images used for conditioning. For an unconditional model, x is simply the latent state \(\mathbf{x}\).
noise_labels (torch.Tensor) – The noise labels of shape \((B,)\). Used for conditioning on the noise level.
class_labels (torch.Tensor) – The class labels of shape \((B, \text{label_dim})\). Used for conditioning on any vector-valued quantity. Can pass None when label_dim is 0.
augment_labels (torch.Tensor, optional, default=None) – The augmentation labels of shape \((B, \text{augment_dim})\). Used for conditioning on any additional vector-valued quantity. Can pass None when augment_dim is 0.
- Outputs:
torch.Tensor – The denoised latent state of shape \((B, C_{out}, H_{in}, W_{in})\).
Examples
>>> model = DhariwalUNet(img_resolution=16, in_channels=2, out_channels=2)
>>> noise_labels = torch.randn([1])
>>> class_labels = torch.randint(0, 1, (1, 1))  # noqa: N806
>>> input_image = torch.ones([1, 2, 16, 16])  # noqa: N806
>>> output_image = model(input_image, noise_labels, class_labels)  # noqa: N806
- property amp_mode#
Should be set to True to enable automatic mixed precision.
- property profile_mode#
Should be set to True to enable profiling.
SongUNetPosEmbd#
- class physicsnemo.models.diffusion.song_unet.SongUNetPosEmbd(*args, **kwargs)[source]#
Bases:
SongUNet
This specialized architecture extends SongUNet with positional embeddings that encode global spatial coordinates of the pixels.
This model supports the same types of conditioning as the base SongUNet, and can in addition be conditioned on the positional embeddings. Conditioning on the positional embeddings is performed with a channel-wise concatenation to the input image before the first layer of the U-Net. Multiple types of positional embeddings are supported. Positional embeddings are represented by a 2D grid of shape \((C_{PE}, H, W)\), where \(H\) and \(W\) correspond to the img_resolution parameter.
The following types of positional embeddings are supported:
learnable: uses a 2D grid of learnable parameters.
linear: uses a 2D rectilinear grid over the domain \([-1, 1] \times [-1, 1]\).
sinusoidal: uses sinusoidal functions of the spatial coordinates, with possibly multiple frequency bands.
test: uses a 2D grid of integer indices, only used for testing.
When the input image spatial resolution is smaller than the global positional embedding grid, it is necessary to select the subset (or patch) of the embedding grid that corresponds to the spatial locations of the input image pixels. The model provides two methods for selecting the subset of positional embeddings:
Using a selector function. See positional_embedding_selector() for details.
Using global indices. See positional_embedding_indexing() for details.
If none of these are provided, the entire grid of positional embeddings is used and channel-wise concatenated to the input image.
Most parameters are the same as in the parent class
SongUNet. Only the ones that differ are listed below.
- Parameters:
img_resolution (Union[List[int, int], int]) – The resolution of the input/output image. Can be a single int for square images or a list \([H, W]\) for rectangular images. Used to set the resolution of the positional embedding grid. It must correspond to the spatial resolution of the global domain/image.
in_channels (int) –
Number of channels \(C_{in} + C_{PE}\), where \(C_{in}\) is the number of channels in the image passed to the U-Net and \(C_{PE}\) is the number of channels in the positional embedding grid.
Important: in comparison to the base SongUNet, this parameter should also include the number of channels in the positional embedding grid \(C_{PE}\).
gridtype (Literal["sinusoidal", "learnable", "linear", "test"], optional, default="sinusoidal") – Type of positional embedding to use. Controls how spatial pixel locations are encoded.
N_grid_channels (int, optional, default=4) – Number of channels \(C_{PE}\) in the positional embedding grid. For ‘sinusoidal’ must be 4 or a multiple of 4. For ‘linear’ and ‘test’ must be 2. For ‘learnable’ can be any value. If 0, positional embedding is disabled (but lead_time_mode may still be used).
lead_time_mode (bool, optional, default=False) – Provided for convenience. It is recommended to use the architecture SongUNetPosLtEmbd for a lead-time aware model.
lead_time_channels (int, optional, default=None) – Provided for convenience. Refer to SongUNetPosLtEmbd.
lead_time_steps (int, optional, default=9) – Provided for convenience. Refer to SongUNetPosLtEmbd.
prob_channels (List[int], optional, default=[]) – Provided for convenience. Refer to SongUNetPosLtEmbd.
- Forward:
x (torch.Tensor) – The input image of shape \((B, C_{in}, H_{in}, W_{in})\), where \(H_{in}\) and \(W_{in}\) are the spatial dimensions of the input image (does not need to be the full image). In general x is the channel-wise concatenation of the latent state \(\mathbf{x}\) and additional images used for conditioning. For an unconditional model, x is simply the latent state \(\mathbf{x}\).
Note: \(H_{in}\) and \(W_{in}\) do not need to match the img_resolution parameter, except when additive_pos_embed is True. In all other cases, the resolution of x must be smaller than img_resolution.
noise_labels (torch.Tensor) – The noise labels of shape \((B,)\). Used for conditioning on the diffusion noise level.
class_labels (torch.Tensor) – The class labels of shape \((B, \text{label_dim})\). Used for conditioning on any vector-valued quantity. Can pass None when label_dim is 0.
global_index (torch.Tensor, optional, default=None) – The global indices of the positional embeddings to use. If neither global_index nor embedding_selector are provided, the entire positional embedding grid of shape \((C_{PE}, H, W)\) is used. In this case x must have the same spatial resolution as the positional embedding grid. See positional_embedding_indexing() for details.
embedding_selector (Callable, optional, default=None) – A function that selects the positional embeddings to use. See positional_embedding_selector() for details.
augment_labels (torch.Tensor, optional, default=None) – The augmentation labels of shape \((B, \text{augment_dim})\). Used for conditioning on any additional vector-valued quantity. Can pass None when augment_dim is 0.
- Outputs:
torch.Tensor – The output tensor of shape \((B, C_{out}, H_{in}, W_{in})\).
Important
Unlike positional embeddings defined by embedding_type in the parent class SongUNet, which encode the diffusion time-step (or noise level), the positional embeddings in this specialized architecture encode global spatial coordinates of the pixels.
Examples
>>> import torch
>>> from physicsnemo.models.diffusion.song_unet import SongUNetPosEmbd
>>> from physicsnemo.utils.patching import GridPatching2D
>>>
>>> # Model initialization - in_channels must include both original input channels (2)
>>> # and the positional embedding channels (N_grid_channels=4 by default)
>>> model = SongUNetPosEmbd(img_resolution=16, in_channels=2+4, out_channels=2)
>>> noise_labels = torch.randn([1])
>>> class_labels = torch.randint(0, 1, (1, 1))
>>> # The input has only the original 2 channels - positional embeddings are
>>> # added automatically inside the forward method
>>> input_image = torch.ones([1, 2, 16, 16])
>>> output_image = model(input_image, noise_labels, class_labels)
>>> output_image.shape
torch.Size([1, 2, 16, 16])
>>>
>>> # Using a global index to select all positional embeddings
>>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(16, 16))
>>> global_index = patching.global_index(batch_size=1)
>>> output_image = model(
...     input_image, noise_labels, class_labels,
...     global_index=global_index
... )
>>> output_image.shape
torch.Size([1, 2, 16, 16])
>>>
>>> # Using a custom embedding selector to select all positional embeddings
>>> def patch_embedding_selector(emb):
...     return patching.apply(emb[None].expand(1, -1, -1, -1))
>>> output_image = model(
...     input_image, noise_labels, class_labels,
...     embedding_selector=patch_embedding_selector
... )
>>> output_image.shape
torch.Size([1, 2, 16, 16])
- positional_embedding_indexing(x: Tensor, global_index: Tensor | None = None, lead_time_label: Tensor | None = None)
Select positional embeddings using global indices.
This method uses global indices to select a specific subset of the positional embedding grid and/or the lead-time embedding grid (called patches). If no indices are provided, the entire embedding grid is returned. The positional embedding grid is returned if N_grid_channels > 0, while the lead-time embedding grid is returned if lead_time_mode == True. If both positional and lead-time embeddings are enabled, both are returned (concatenated). If neither is enabled, this function should not be called; doing so will raise a ValueError.
- Parameters:
x (torch.Tensor) – Input tensor of shape \((P \times B, C, H_{in}, W_{in})\). Only used to determine batch size \(B\) and device.
global_index (Optional[torch.Tensor], default=None) – Tensor of shape \((P, 2, H_{in}, W_{in})\) that corresponds to the patches to extract from the positional embedding grid. \(P\) is the number of distinct patches in the input tensor x. The channel dimension should contain the \(j\), \(i\) indices of the pixels to extract from the embedding grid.
lead_time_label (Optional[torch.Tensor], default=None) – Tensor of shape \((B,)\) that corresponds to the lead-time label for each batch element. Only used if lead_time_mode is True.
- Returns:
Selected embeddings with shape \((P \times B, C_{PE} [+ C_{LT}], H_{in}, W_{in})\). \(C_{PE}\) is the number of embedding channels in the positional embedding grid, and \(C_{LT}\) is the number of embedding channels in the lead-time embedding grid. If lead_time_label is provided, the lead-time embedding channels are included. If global_index is None, \(P = 1\) is assumed, and the positional embedding grid is duplicated \(B\) times and returned with shape \((B, C_{PE} [+ C_{LT}], H, W)\).
- Return type:
torch.Tensor
Example
>>> # Create global indices using patching utility:
>>> from physicsnemo.utils.patching import GridPatching2D
>>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8))
>>> global_index = patching.global_index(batch_size=3)
>>> print(global_index.shape)
torch.Size([4, 2, 8, 8])
Notes
This method is typically used in patch-based diffusion (or multi-diffusion), where a large input image is split into multiple patches. The batch dimension of the input tensor contains the patches. Patches are processed independently by the model, and the
global_index parameter is used to select the grid of positional embeddings corresponding to each patch.
See the global_index() method from physicsnemo.utils.patching.BasePatching2D for generating the global_index parameter.
- positional_embedding_selector(x: Tensor, embedding_selector: Callable[[Tensor], Tensor], lead_time_label=None)
Select positional embeddings using a selector function.
Similar to
positional_embedding_indexing(), but instead uses a selector function to select the embeddings.
- Parameters:
x (torch.Tensor) – Input tensor of shape \((P \times B, C, H_{in}, W_{in})\). Only used to determine the dtype.
embedding_selector (Callable[[torch.Tensor], torch.Tensor]) – Function that takes as input the entire embedding grid of shape \((C_{PE}, H, W)\) (or \((B, C_{LT}, H, W)\) when lead_time_label is provided) and returns selected embeddings with shape \((P \times B, C_{PE}, H_{in}, W_{in})\) (or \((P \times B, C_{LT}, H_{in}, W_{in})\) when lead_time_label is provided). Each selected embedding should correspond to the portion of the embedding grid that corresponds to the batch element in x. Typically this should be based on the physicsnemo.utils.patching.BasePatching2D.apply() method to maintain consistency with patch extraction.
lead_time_label (Optional[torch.Tensor], default=None) – Tensor of shape \((B,)\) that corresponds to the lead-time label for each batch element. Only used if lead_time_mode is True.
A tensor of shape \((P \times B, C_{PE} [+ C_{LT}], H_{in}, W_{in})\). \(C_{PE}\) is the number of embedding channels in the positional embedding grid, and \(C_{LT}\) is the number of embedding channels in the lead-time embedding grid. If lead_time_label is provided, the lead-time embedding channels are included.
- Return type:
torch.Tensor
Notes
This method is typically used in patch-based diffusion (or multi-diffusion), where a large input image is split into multiple patches. The batch dimension of the input tensor contains the patches. Patches are processed independently by the model, and the
embedding_selector function is used to select the grid of positional embeddings corresponding to each patch.
See the apply() method from physicsnemo.utils.patching.BasePatching2D for generating the embedding_selector parameter, as well as the example below.
Example
>>> # Define a selector function with a patching utility:
>>> from physicsnemo.utils.patching import GridPatching2D
>>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8))
>>> B = 4
>>> def embedding_selector(emb):
...     return patching.apply(emb.expand(B, -1, -1, -1))
>>>
SongUNetPosLtEmbd#
- class physicsnemo.models.diffusion.song_unet.SongUNetPosLtEmbd(*args, **kwargs)[source]#
Bases:
SongUNetPosEmbd
This specialized architecture extends SongUNetPosEmbd with two additional capabilities:
The model can be conditioned on lead-time labels. These labels encode physical time information, such as a forecasting horizon.
Similarly to the parent SongUNetPosEmbd, this model predicts regression targets, but it can also produce classification predictions. More precisely, some of the output channels are probability outputs that are passed through a softmax activation function. This is useful for multi-task applications, where the objective is a combination of both regression and classification losses.
The mechanism to condition on lead-time labels is implemented by:
First generating a grid of learnable lead-time embeddings of shape \((\text{lead_time_steps}, C_{LT}, H, W)\). The spatial resolution of the lead-time embeddings is the same as the input/output image.
Then, given an input x, selecting the lead-time embeddings that correspond to the lead-times associated with the samples in the input x.
Finally, concatenating channel-wise the selected lead-time embeddings and positional embeddings to the input x and passing them to the U-Net network.
Most parameters are similar to the parent
SongUNetPosEmbd, with the exception of the ones listed below.
- Parameters:
in_channels (int) –
Number of channels \(C_{in} + C_{PE} + C_{LT}\) in the image passed to the U-Net.
Important: in comparison to the base
SongUNet, this parameter should also include the number of channels in the positional embedding grid \(C_{PE}\) and the number of channels in the lead-time embedding grid \(C_{LT}\).
lead_time_channels (int, optional, default=None) – Number of channels \(C_{LT}\) in the lead-time embedding. These are learned embeddings that encode physical time information.
lead_time_steps (int, optional, default=9) – Number of discrete lead time steps to support. Each step gets its own learned embedding vector of shape \((C_{LT}, H, W)\).
prob_channels (List[int], optional, default=[]) – Indices of channels that are probability outputs (or classification predictions). In training mode, the model outputs logits for these probability channels, and in eval mode, the model applies a softmax to output the probabilities.
- Forward:
x (torch.Tensor) – The input image of shape \((B, C_{in}, H_{in}, W_{in})\), where \(H_{in}\) and \(W_{in}\) are the spatial dimensions of the input image (does not need to be the full image).
noise_labels (torch.Tensor) – The noise labels of shape \((B,)\). Used for conditioning on the diffusion noise level.
class_labels (torch.Tensor) – The class labels of shape \((B, \text{label_dim})\). Used for conditioning on any vector-valued quantity. Can pass None when label_dim is 0.
global_index (torch.Tensor, optional, default=None) – The global indices of the positional embeddings to use. See positional_embedding_indexing() for details. If neither global_index nor embedding_selector are provided, the entire positional embedding grid is used.
embedding_selector (Callable, optional, default=None) – A function that selects the positional embeddings to use. See positional_embedding_selector() for details.
augment_labels (torch.Tensor, optional, default=None) – The augmentation labels of shape \((B, \text{augment_dim})\). Used for conditioning on any additional vector-valued quantity.
lead_time_label (torch.Tensor, optional, default=None) – The lead-time labels of shape \((B,)\). Used for selecting lead-time embeddings. It should contain the indices of the lead-time embeddings that correspond to the lead-time of each sample in the batch.
- Outputs:
torch.Tensor – The output tensor of shape \((B, C_{out}, H_{in}, W_{in})\).
Notes
The lead-time embeddings differ from the diffusion time embeddings used in the
SongUNet class, as they do not encode the diffusion time-step but physical forecast time.
Example
>>> import torch
>>> from physicsnemo.models.diffusion.song_unet import SongUNetPosLtEmbd
>>> from physicsnemo.utils.patching import GridPatching2D
>>>
>>> # Model initialization - in_channels must include original input channels (2),
>>> # positional embedding channels (N_grid_channels=4 by default) and
>>> # lead time embedding channels (4)
>>> model = SongUNetPosLtEmbd(
...     img_resolution=16, in_channels=2+4+4, out_channels=2,
...     lead_time_channels=4, lead_time_steps=9
... )
>>> noise_labels = torch.randn([1])
>>> class_labels = torch.randint(0, 1, (1, 1))
>>> # The input has only the original 2 channels - positional embeddings and
>>> # lead time embeddings are added automatically inside the forward method
>>> input_image = torch.ones([1, 2, 16, 16])
>>> lead_time_label = torch.tensor([3])
>>> output_image = model(
...     input_image, noise_labels, class_labels,
...     lead_time_label=lead_time_label
... )
>>> output_image.shape
torch.Size([1, 2, 16, 16])
>>>
>>> # Using global_index to select all the positional and lead time embeddings
>>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(16, 16))
>>> global_index = patching.global_index(batch_size=1)
>>> output_image = model(
...     input_image, noise_labels, class_labels,
...     lead_time_label=lead_time_label,
...     global_index=global_index
... )
>>> output_image.shape
torch.Size([1, 2, 16, 16])
UNet#
- class physicsnemo.models.diffusion.unet.UNet(*args, **kwargs)[source]#
Bases:
Module
This interface provides a U-Net wrapper for the CorrDiff deterministic regression model (and other deterministic downscaling models). It supports the following underlying architectures: SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd, and DhariwalUNet.
It shares the same architecture as a conditional diffusion model. It does so by concatenating a conditioning image to a zero-filled latent state, and by setting the noise level and the class labels to zero.
- Parameters:
img_resolution (Union[int, Tuple[int, int]]) – The resolution of the input/output image. If a single int is provided, then the image is assumed to be square.
img_in_channels (int) – Number of channels in the input image.
img_out_channels (int) – Number of channels in the output image.
use_fp16 (bool, optional, default=False) – Execute the underlying model at FP16 precision.
model_type (Literal['SongUNet', 'SongUNetPosEmbd', 'SongUNetPosLtEmbd', 'DhariwalUNet'], optional, default='SongUNetPosEmbd') – Class name of the underlying architecture. Must be one of the following: ‘SongUNet’, ‘SongUNetPosEmbd’, ‘SongUNetPosLtEmbd’, ‘DhariwalUNet’.
**model_kwargs (dict) – Keyword arguments passed to the underlying architecture __init__ method. Please refer to the documentation of these classes for details on how to call and use these models directly.
- Forward:
x (torch.Tensor) – The input tensor, typically zero-filled, of shape \((B, C_{in}, H_{in}, W_{in})\).
img_lr (torch.Tensor) – Conditioning image of shape \((B, C_{lr}, H_{in}, W_{in})\).
**model_kwargs (dict) – Additional keyword arguments to pass to the underlying architecture forward method.
- Outputs:
torch.Tensor – Output tensor of shape \((B, C_{out}, H_{in}, W_{in})\) (same spatial dimensions as the input).
- property amp_mode#
Set to True when using automatic mixed precision.
- property profile_mode#
Set to True to enable profiling of the wrapped model.
- round_sigma(sigma: float | List | Tensor)
Convert a given sigma value(s) to a tensor representation.
- Parameters:
sigma (Union[float, List, torch.Tensor]) – The sigma value(s) to convert.
- Returns:
The tensor representation of the provided sigma value(s).
- Return type:
torch.Tensor
- property use_fp16#
Whether the model uses float16 precision.
- Returns:
True if the model is in float16 mode, False otherwise.
- Return type:
bool
- Type:
bool