Source code for physicsnemo.models.diffusion_unets.song_unet

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import math
from dataclasses import dataclass
from typing import Callable, List, Literal, Optional, Set, Union

import numpy as np
import nvtx
import torch
from torch.nn.functional import silu
from torch.utils.checkpoint import checkpoint

from physicsnemo.core.meta import ModelMetaData
from physicsnemo.core.module import Module
from physicsnemo.nn import (
    Conv2d,
    FourierEmbedding,
    Linear,
    PositionalEmbedding,
    UNetBlock,
    get_group_norm,
)

from ._utils import _recursive_property

# ------------------------------------------------------------------------------
# Backbone architectures
# ------------------------------------------------------------------------------


@dataclass
class MetaData(ModelMetaData):
    # Optimization
    jit: bool = False
    cuda_graphs: bool = False
    amp_cpu: bool = False
    amp_gpu: bool = True
    torch_fx: bool = False
    # Data type
    bf16: bool = True
    # Inference
    onnx: bool = False
    # Physics informed
    func_torch: bool = False
    auto_grad: bool = False


class SongUNet(Module):
    r"""
    This architecture is a diffusion backbone for 2D image generation. It is a
    reimplementation of the
    `DDPM++ <https://proceedings.mlr.press/v139/nichol21a.html>`_ and
    `NCSN++ <https://arxiv.org/abs/2011.13456>`_ architectures, which are
    U-Net variants with optional self-attention, embeddings, and
    encoder-decoder components.

    This model supports conditional and unconditional setups, as well as
    several options for various internal architectural choices such as encoder
    and decoder type, embedding type, etc., making it flexible and adaptable
    to different tasks and configurations.

    This architecture supports conditioning on the noise level (called *noise
    labels*), as well as on additional vector-valued labels (called *class
    labels*) and (optional) vector-valued augmentation labels. The
    conditioning mechanism relies on addition of the conditioning embeddings
    in the U-Net blocks of the encoder. To condition on images, the simplest
    mechanism is to concatenate the image to the input before passing it to
    the SongUNet.

    The model first applies a mapping operation to generate embeddings for all
    the conditioning inputs (the noise level, the class labels, and the
    optional augmentation labels).

    Then, at each level in the U-Net encoder, a sequence of blocks is applied:

    - A first block downsamples the feature map resolution by a factor of 2
      (odd resolutions are floored). This block does not change the number of
      channels.
    - A sequence of ``num_blocks`` U-Net blocks is applied, each with a
      different number of channels. These blocks do not change the feature map
      resolution, but they multiply the number of channels by a factor
      specified in ``channel_mult``. If required, the U-Net blocks also apply
      self-attention at the specified resolutions.
    - At the end of the level, the feature map is cached to be used in a skip
      connection in the decoder.

    The decoder is a mirror of the encoder, with the same number of levels and
    the same number of blocks per level. It multiplies the feature map
    resolution by a factor of 2 at each level.

    Parameters
    ----------
    img_resolution : Union[List[int], int]
        The resolution of the input/output image. Can be a single int
        :math:`H` for square images or a list :math:`[H, W]` for rectangular
        images.

        *Note:* this parameter is only used as a convenience to build the
        network. In practice, the model can still be used with images of
        different resolutions. The only exception to this rule is when
        ``additive_pos_embed`` is ``True``, in which case the resolution of
        the latent state :math:`\mathbf{x}` must match ``img_resolution``.
    in_channels : int
        Number of channels :math:`C_{in}` in the input image. May include
        channels from both the latent state and additional channels when
        conditioning on images. For an unconditional model, this should be
        equal to ``out_channels``.
    out_channels : int
        Number of channels :math:`C_{out}` in the output image. Should be
        equal to the number of channels :math:`C_{\mathbf{x}}` in the latent
        state.
    label_dim : int, optional, default=0
        Dimension of the vector-valued ``class_labels`` conditioning; 0
        indicates no conditioning on class labels.
    augment_dim : int, optional, default=0
        Dimension of the vector-valued ``augment_labels`` conditioning; 0
        means no conditioning on augmentation labels.
    model_channels : int, optional, default=128
        Base multiplier for the number of channels across the entire network.
    channel_mult : List[int], optional, default=[1, 2, 2, 2]
        Multipliers for the number of channels at every level in the encoder
        and decoder. The length of ``channel_mult`` determines the number of
        levels in the U-Net. At level ``i``, the number of channels in the
        feature map is ``channel_mult[i] * model_channels``.
    channel_mult_emb : int, optional, default=4
        Multiplier for the number of channels in the embedding vector. The
        embedding vector has ``model_channels * channel_mult_emb`` channels.
    num_blocks : int, optional, default=4
        Number of U-Net blocks at each level.
    attn_resolutions : List[int], optional, default=[16]
        Resolutions of the levels at which self-attention layers are applied.
        Note that the feature map resolution must match exactly the value
        provided in ``attn_resolutions`` for the self-attention layers to be
        applied.
    dropout : float, optional, default=0.10
        Dropout probability applied to intermediate activations within the
        U-Net blocks.
    label_dropout : float, optional, default=0.0
        Dropout probability applied to the ``class_labels``. Typically used
        for classifier-free guidance.
    embedding_type : Literal["fourier", "positional", "zero"], optional, default="positional"
        Diffusion timestep embedding type: 'positional' for DDPM++, 'fourier'
        for NCSN++, 'zero' for none.
    channel_mult_noise : int, optional, default=1
        Multiplier for the number of channels in the noise level embedding.
        The noise level embedding vector has
        ``model_channels * channel_mult_noise`` channels.
    encoder_type : Literal["standard", "skip", "residual"], optional, default="standard"
        Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++,
        'skip' for skip connections.
    decoder_type : Literal["standard", "skip"], optional, default="standard"
        Decoder architecture: 'standard' or 'skip' for skip connections.
    resample_filter : List[int], optional, default=[1, 1]
        Resampling filter coefficients applied in the U-Net block
        convolutions: [1, 1] for DDPM++, [1, 3, 3, 1] for NCSN++.
    checkpoint_level : int, optional, default=0
        Number of levels that should use gradient checkpointing. Only levels
        at which the feature map resolution is large enough are checkpointed
        (0 disables checkpointing, higher values mean more layers are
        checkpointed). Higher values trade memory for computation.
    additive_pos_embed : bool, optional, default=False
        If ``True``, adds a learnable positional embedding after the first
        convolution layer. Used in the StormCast model.

        *Note:* these positional embeddings encode spatial position
        information of the image pixels, unlike the ``embedding_type``
        parameter which encodes temporal information about the diffusion
        process. In that sense it is a simpler version of the positional
        embedding used in
        :class:`~physicsnemo.models.diffusion_unets.SongUNetPosEmbd`.
    bottleneck_attention : bool, optional, default=True
        If ``True``, applies self-attention at the bottleneck (innermost
        decoder block). Set to ``False`` to disable bottleneck attention for
        faster inference.
    use_apex_gn : bool, optional, default=False
        Whether to use Apex GroupNorm for the NHWC layout. Apex needs to be
        installed for this to work. Must be set to ``False`` on CPU.
    act : str, optional, default=None
        The activation function to use when fusing activation with GroupNorm.
        Required when ``use_apex_gn`` is ``True``.
    profile_mode : bool, optional, default=False
        Whether to enable all nvtx annotations during profiling.
    amp_mode : bool, optional, default=False
        Whether mixed-precision (AMP) training is enabled.

    Forward
    -------
    x : torch.Tensor
        The input image of shape :math:`(B, C_{in}, H_{in}, W_{in})`. In
        general ``x`` is the channel-wise concatenation of the latent state
        :math:`\mathbf{x}` and additional images used for conditioning. For an
        unconditional model, ``x`` is simply the latent state
        :math:`\mathbf{x}`.

        *Note:* :math:`H_{in}` and :math:`W_{in}` do not need to match
        :math:`H` and :math:`W` defined in ``img_resolution``, except when
        ``additive_pos_embed`` is ``True``. In that case, the resolution of
        ``x`` must match ``img_resolution``.
    noise_labels : torch.Tensor
        The noise labels of shape :math:`(B,)`. Used for conditioning on the
        diffusion noise level.
    class_labels : torch.Tensor
        The class labels of shape :math:`(B, \text{label_dim})`. Used for
        conditioning on any vector-valued quantity. Can pass ``None`` when
        ``label_dim`` is 0.
    augment_labels : torch.Tensor, optional, default=None
        The augmentation labels of shape :math:`(B, \text{augment_dim})`.
        Used for conditioning on any additional vector-valued quantity. Can
        pass ``None`` when ``augment_dim`` is 0.

    Outputs
    -------
    torch.Tensor
        The denoised latent state of shape
        :math:`(B, C_{out}, H_{in}, W_{in})`.

    .. important::

        - The terms *noise levels* (or *noise labels*) are used to refer to
          the diffusion time-step, as these are conceptually equivalent.
        - The terms *labels* and *classes* originate from the original paper
          and EDM repository, where this architecture was used for
          class-conditional image generation. While these terms suggest
          class-based conditioning, the architecture can actually be
          conditioned on any vector-valued quantity.
        - The term *positional embedding* used in the ``embedding_type``
          parameter also comes from the original paper and EDM repository.
          Here, *positional* refers to the diffusion time-step, similar to how
          position is used in transformer architectures. Despite the name,
          these embeddings encode temporal information about the diffusion
          process rather than spatial position information.
        - Limitations on input image resolution: for a model that has
          :math:`N` levels, the latent state :math:`\mathbf{x}` must have a
          resolution that is a multiple of :math:`2^{N-1}` in each dimension.
          This is due to a limitation in the decoder, which does not support
          shape mismatches in the residual connections from the encoder to the
          decoder. For images that do not meet this requirement, it is
          recommended to interpolate your data onto a grid of the required
          resolution beforehand.

    Example
    -------
    >>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2)
    >>> noise_labels = torch.randn([1])
    >>> class_labels = torch.randint(0, 1, (1, 1))
    >>> input_image = torch.ones([1, 2, 16, 16])
    >>> output_image = model(input_image, noise_labels, class_labels)
    >>> output_image.shape
    torch.Size([1, 2, 16, 16])
    """

    # Arguments of the __init__ method that can be overridden with the
    # ``Module.from_checkpoint`` method.
    _overridable_args: Set[str] = {"use_apex_gn", "act"}

    def __init__(
        self,
        img_resolution: Union[List[int], int],
        in_channels: int,
        out_channels: int,
        label_dim: int = 0,
        augment_dim: int = 0,
        model_channels: int = 128,
        channel_mult: List[int] = [1, 2, 2, 2],
        channel_mult_emb: int = 4,
        num_blocks: int = 4,
        attn_resolutions: List[int] = [16],
        dropout: float = 0.10,
        label_dropout: float = 0.0,
        embedding_type: Literal["fourier", "positional", "zero"] = "positional",
        channel_mult_noise: int = 1,
        encoder_type: Literal["standard", "skip", "residual"] = "standard",
        decoder_type: Literal["standard", "skip"] = "standard",
        resample_filter: List[int] = [1, 1],
        checkpoint_level: int = 0,
        additive_pos_embed: bool = False,
        bottleneck_attention: bool = True,
        use_apex_gn: bool = False,
        act: str = "silu",
        profile_mode: bool = False,
        amp_mode: bool = False,
    ):
        valid_embedding_types = ["fourier", "positional", "zero"]
        if embedding_type not in valid_embedding_types:
            raise ValueError(
                f"Invalid embedding_type: {embedding_type}."
                f" Must be one of {valid_embedding_types}."
            )
        valid_encoder_types = ["standard", "skip", "residual"]
        if encoder_type not in valid_encoder_types:
            raise ValueError(
                f"Invalid encoder_type: {encoder_type}."
                f" Must be one of {valid_encoder_types}."
            )
        valid_decoder_types = ["standard", "skip"]
        if decoder_type not in valid_decoder_types:
            raise ValueError(
                f"Invalid decoder_type: {decoder_type}."
                f" Must be one of {valid_decoder_types}."
            )

        super().__init__(meta=MetaData())
        self.label_dropout = label_dropout
        self.embedding_type = embedding_type
        emb_channels = model_channels * channel_mult_emb
        self.emb_channels = emb_channels
        noise_channels = model_channels * channel_mult_noise
        init = dict(init_mode="xavier_uniform")
        init_zero = dict(init_mode="xavier_uniform", init_weight=1e-5)
        init_attn = dict(init_mode="xavier_uniform", init_weight=np.sqrt(0.2))
        block_kwargs = dict(
            emb_channels=emb_channels,
            num_heads=1,
            dropout=dropout,
            skip_scale=0.7071067811865476,  # 1 / sqrt(2)
            eps=1e-6,
            resample_filter=resample_filter,
            resample_proj=True,
            adaptive_scale=False,
            init=init,
            init_zero=init_zero,
            init_attn=init_attn,
            use_apex_gn=use_apex_gn,
            act=act,
            fused_conv_bias=True,
            profile_mode=profile_mode,
            amp_mode=amp_mode,
        )
        self.use_apex_gn = use_apex_gn

        # For compatibility with older versions that took only 1 dimension
        self.img_resolution = img_resolution
        if isinstance(img_resolution, int):
            self.img_shape_y = self.img_shape_x = img_resolution
        else:
            self.img_shape_y = img_resolution[0]
            self.img_shape_x = img_resolution[1]
        self._num_levels = len(channel_mult)
        self._input_shape_mult = 2 ** (self._num_levels - 1)

        # Set the threshold for checkpointing based on image resolution
        self.checkpoint_threshold = (
            math.floor(math.sqrt(self.img_shape_x * self.img_shape_y))
            >> checkpoint_level
        ) + 1

        # Optional additive learned position embedding after the first conv
        self.additive_pos_embed = additive_pos_embed
        if self.additive_pos_embed:
            self.spatial_emb = torch.nn.Parameter(
                torch.randn(1, model_channels, self.img_shape_y, self.img_shape_x)
            )
            torch.nn.init.trunc_normal_(self.spatial_emb, std=0.02)

        # Mapping.
        if self.embedding_type != "zero":
            self.map_noise = (
                PositionalEmbedding(
                    num_channels=noise_channels, endpoint=True, amp_mode=amp_mode
                )
                if embedding_type == "positional"
                else FourierEmbedding(num_channels=noise_channels, amp_mode=amp_mode)
            )
            self.map_label = (
                Linear(
                    in_features=label_dim,
                    out_features=noise_channels,
                    amp_mode=amp_mode,
                    **init,
                )
                if label_dim
                else None
            )
            self.map_augment = (
                Linear(
                    in_features=augment_dim,
                    out_features=noise_channels,
                    bias=False,
                    amp_mode=amp_mode,
                    **init,
                )
                if augment_dim
                else None
            )
            self.map_layer0 = Linear(
                in_features=noise_channels,
                out_features=emb_channels,
                amp_mode=amp_mode,
                **init,
            )
            self.map_layer1 = Linear(
                in_features=emb_channels,
                out_features=emb_channels,
                amp_mode=amp_mode,
                **init,
            )

        # Encoder.
        self.enc = torch.nn.ModuleDict()
        cout = in_channels
        caux = in_channels
        for level, mult in enumerate(channel_mult):
            res = self.img_shape_y >> level
            if level == 0:
                cin = cout
                cout = model_channels
                self.enc[f"{res}x{res}_conv"] = Conv2d(
                    in_channels=cin,
                    out_channels=cout,
                    kernel=3,
                    fused_conv_bias=True,
                    amp_mode=amp_mode,
                    **init,
                )
            else:
                self.enc[f"{res}x{res}_down"] = UNetBlock(
                    in_channels=cout, out_channels=cout, down=True, **block_kwargs
                )
                if encoder_type == "skip":
                    self.enc[f"{res}x{res}_aux_down"] = Conv2d(
                        in_channels=caux,
                        out_channels=caux,
                        kernel=0,
                        down=True,
                        resample_filter=resample_filter,
                        amp_mode=amp_mode,
                    )
                    self.enc[f"{res}x{res}_aux_skip"] = Conv2d(
                        in_channels=caux,
                        out_channels=cout,
                        kernel=1,
                        fused_conv_bias=True,
                        amp_mode=amp_mode,
                        **init,
                    )
                if encoder_type == "residual":
                    self.enc[f"{res}x{res}_aux_residual"] = Conv2d(
                        in_channels=caux,
                        out_channels=cout,
                        kernel=3,
                        down=True,
                        resample_filter=resample_filter,
                        fused_resample=True,
                        fused_conv_bias=True,
                        amp_mode=amp_mode,
                        **init,
                    )
                    caux = cout
            for idx in range(num_blocks):
                cin = cout
                cout = model_channels * mult
                attn = res in attn_resolutions
                self.enc[f"{res}x{res}_block{idx}"] = UNetBlock(
                    in_channels=cin, out_channels=cout, attention=attn, **block_kwargs
                )
        skips = [
            block.out_channels for name, block in self.enc.items() if "aux" not in name
        ]

        # Decoder.
        self.dec = torch.nn.ModuleDict()
        for level, mult in reversed(list(enumerate(channel_mult))):
            res = self.img_shape_y >> level
            if level == len(channel_mult) - 1:
                self.dec[f"{res}x{res}_in0"] = UNetBlock(
                    in_channels=cout,
                    out_channels=cout,
                    attention=bottleneck_attention,
                    **block_kwargs,
                )
                self.dec[f"{res}x{res}_in1"] = UNetBlock(
                    in_channels=cout, out_channels=cout, **block_kwargs
                )
            else:
                self.dec[f"{res}x{res}_up"] = UNetBlock(
                    in_channels=cout, out_channels=cout, up=True, **block_kwargs
                )
            for idx in range(num_blocks + 1):
                cin = cout + skips.pop()
                cout = model_channels * mult
                attn = idx == num_blocks and res in attn_resolutions
                self.dec[f"{res}x{res}_block{idx}"] = UNetBlock(
                    in_channels=cin, out_channels=cout, attention=attn, **block_kwargs
                )
            if decoder_type == "skip" or level == 0:
                if decoder_type == "skip" and level < len(channel_mult) - 1:
                    self.dec[f"{res}x{res}_aux_up"] = Conv2d(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        kernel=0,
                        up=True,
                        resample_filter=resample_filter,
                        amp_mode=amp_mode,
                    )
                self.dec[f"{res}x{res}_aux_norm"] = get_group_norm(
                    num_channels=cout,
                    eps=1e-6,
                    use_apex_gn=use_apex_gn,
                    amp_mode=amp_mode,
                )
                self.dec[f"{res}x{res}_aux_conv"] = Conv2d(
                    in_channels=cout,
                    out_channels=out_channels,
                    kernel=3,
                    fused_conv_bias=True,
                    amp_mode=amp_mode,
                    **init_zero,
                )

        # Set properties recursively on submodules
        self.profile_mode = profile_mode
        self.amp_mode = amp_mode

    # Properties that are recursively set on submodules
    profile_mode = _recursive_property(
        "profile_mode", bool, "Should be set to ``True`` to enable profiling."
    )
    amp_mode = _recursive_property(
        "amp_mode",
        bool,
        "Should be set to ``True`` to enable automatic mixed precision.",
    )

    def forward(self, x, noise_labels, class_labels, augment_labels=None):
        with (
            nvtx.annotate(message="SongUNet", color="blue")
            if self.profile_mode
            else contextlib.nullcontext()
        ):
            # Validate input shapes
            batch_size = x.shape[0]
            if x.ndim != 4:
                raise ValueError(
                    f"Expected 'x' to be a 4D tensor, "
                    f"got {x.ndim}D tensor with shape {tuple(x.shape)}"
                )
            # Check spatial dimensions are powers of 2 or multiples of 2^{N-1}
            for d in x.shape[-2:]:
                # Check if d is a power of 2
                is_power_of_2 = (d & (d - 1)) == 0 and d > 0
                # If not a power of 2, must be a multiple of
                # self._input_shape_mult
                if not (
                    (is_power_of_2 and d < self._input_shape_mult)
                    or (d % self._input_shape_mult == 0)
                ):
                    raise ValueError(
                        f"Input spatial dimensions ({x.shape[-2:]}) must be "
                        f"either powers of 2 or multiples of 2**(N-1) where "
                        f"N (={self._num_levels}) is the number of levels "
                        f"in the U-Net."
                    )
            # TODO: noise_labels of shape (1,) means that all inputs share the
            # same noise level. This should be removed in the future, though.
            if noise_labels.ndim != 1 or noise_labels.shape[0] not in (batch_size, 1):
                raise ValueError(
                    f"Expected 'noise_labels' shape ({batch_size},) or (1,), "
                    f"got {tuple(noise_labels.shape)}"
                )
            if class_labels is not None and (
                class_labels.ndim != 2 or class_labels.shape[0] != batch_size
            ):
                raise ValueError(
                    f"Expected 'class_labels' shape ({batch_size}, C), "
                    f"got {tuple(class_labels.shape)}"
                )
            if augment_labels is not None and (
                augment_labels.ndim != 2 or augment_labels.shape[0] != batch_size
            ):
                raise ValueError(
                    f"Expected 'augment_labels' shape ({batch_size}, C), "
                    f"got {tuple(augment_labels.shape)}"
                )

            if (
                self.use_apex_gn
                and (not x.is_contiguous(memory_format=torch.channels_last))
                and x.dim() == 4
            ):
                x = x.to(memory_format=torch.channels_last)

            if self.embedding_type != "zero":
                # Mapping.
                emb = self.map_noise(noise_labels)
                emb_shape = emb.shape
                emb = emb.reshape(emb.shape[0], 2, -1)
                # Swap sin/cos
                emb = torch.concat([emb[:, 1:], emb[:, :1]], dim=1).reshape(*emb_shape)
                if self.map_label is not None:
                    tmp = class_labels
                    if self.training and self.label_dropout:
                        tmp = tmp * (
                            torch.rand([x.shape[0], 1], device=x.device)
                            >= self.label_dropout
                        ).to(tmp.dtype)
                    emb = emb + self.map_label(
                        tmp * np.sqrt(self.map_label.in_features)
                    )
                if self.map_augment is not None and augment_labels is not None:
                    emb = emb + self.map_augment(augment_labels)
                emb = silu(self.map_layer0(emb))
                emb = silu(self.map_layer1(emb))
            else:
                emb = torch.zeros(
                    (noise_labels.shape[0], self.emb_channels),
                    device=x.device,
                    dtype=x.dtype,
                )

            # Encoder.
            skips = []
            aux = x
            for name, block in self.enc.items():
                with (
                    nvtx.annotate(f"SongUNet encoder: {name}", color="blue")
                    if self.profile_mode
                    else contextlib.nullcontext()
                ):
                    if "aux_down" in name:
                        aux = block(aux)
                    elif "aux_skip" in name:
                        x = skips[-1] = x + block(aux)
                    elif "aux_residual" in name:
                        x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2)
                    elif "_conv" in name:
                        x = block(x)
                        if self.additive_pos_embed:
                            x = x + self.spatial_emb.to(dtype=x.dtype)
                        skips.append(x)
                    else:
                        # For UNetBlocks, check whether to use gradient
                        # checkpointing based on the current feature map size
                        if isinstance(block, UNetBlock):
                            if (
                                math.floor(math.sqrt(x.shape[-2] * x.shape[-1]))
                                > self.checkpoint_threshold
                            ):
                                x = checkpoint(block, x, emb, use_reentrant=False)
                            else:
                                x = block(x, emb)
                        else:
                            x = block(x)
                        skips.append(x)

            # Decoder.
            aux = None
            tmp = None
            for name, block in self.dec.items():
                with (
                    nvtx.annotate(f"SongUNet decoder: {name}", color="blue")
                    if self.profile_mode
                    else contextlib.nullcontext()
                ):
                    if "aux_up" in name:
                        aux = block(aux)
                    elif "aux_norm" in name:
                        tmp = block(x)
                    elif "aux_conv" in name:
                        tmp = block(silu(tmp))
                        aux = tmp if aux is None else tmp + aux
                    else:
                        if x.shape[1] != block.in_channels:
                            x = torch.cat([x, skips.pop()], dim=1)
                        # Check for checkpointing on decoder blocks and
                        # upsampling blocks
                        if (
                            math.floor(math.sqrt(x.shape[-2] * x.shape[-1]))
                            > self.checkpoint_threshold
                            and "_block" in name
                        ) or (
                            math.floor(math.sqrt(x.shape[-2] * x.shape[-1]))
                            > (self.checkpoint_threshold / 2)
                            and "_up" in name
                        ):
                            x = checkpoint(block, x, emb, use_reentrant=False)
                        else:
                            x = block(x, emb)
            return aux
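
# Editor's note: a hedged, minimal usage sketch (not part of the original
# module). It illustrates the input-resolution rule documented above: with
# the default ``channel_mult=[1, 2, 2, 2]`` the U-Net has N=4 levels, so each
# spatial dimension must be a power of 2 or a multiple of 2**(N-1) = 8.
# Rectangular inputs are supported as long as both dimensions satisfy this;
# ``model_channels=16`` is only used here to keep the sketch lightweight.
if __name__ == "__main__":
    _model = SongUNet(
        img_resolution=32, in_channels=2, out_channels=2, model_channels=16
    )
    _x = torch.ones([1, 2, 24, 32])  # 24 and 32 are both multiples of 8
    _out = _model(_x, noise_labels=torch.randn([1]), class_labels=None)
    assert _out.shape == (1, 2, 24, 32)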
# ------------------------------------------------------------------------------
# Specialized architectures
# ------------------------------------------------------------------------------
class SongUNetPosEmbd(SongUNet):
    r"""
    This specialized architecture extends
    :class:`~physicsnemo.models.diffusion_unets.SongUNet` with positional
    embeddings that encode global spatial coordinates of the pixels. This
    model supports the same types of conditioning as the base SongUNet, and
    can in addition be conditioned on the positional embeddings. Conditioning
    on the positional embeddings is performed with a channel-wise
    concatenation to the input image before the first layer of the U-Net.

    Multiple types of positional embeddings are supported. Positional
    embeddings are represented by a 2D grid of shape :math:`(C_{PE}, H, W)`,
    where :math:`H` and :math:`W` correspond to the ``img_resolution``
    parameter. The following types of positional embeddings are supported:

    - learnable: uses a 2D grid of learnable parameters.
    - linear: uses a 2D rectilinear grid over the domain
      :math:`[-1, 1] \times [-1, 1]`.
    - sinusoidal: uses sinusoidal functions of the spatial coordinates, with
      possibly multiple frequency bands.
    - test: uses a 2D grid of integer indices, only used for testing.

    When the input image spatial resolution is smaller than the global
    positional embeddings, it is necessary to select a subset (or *patch*) of
    the embedding grid that corresponds to the spatial locations of the input
    image pixels. The model provides two methods for selecting the subset of
    positional embeddings:

    1. Using a selector function. See :meth:`positional_embedding_selector`
       for details.
    2. Using global indices. See :meth:`positional_embedding_indexing` for
       details.

    If none of these are provided, the entire grid of positional embeddings
    is used and channel-wise concatenated to the input image.

    Most parameters are the same as in the parent class
    :class:`~physicsnemo.models.diffusion_unets.SongUNet`. Only the ones that
    differ are listed below.

    Parameters
    ----------
    img_resolution : Union[List[int], int]
        The resolution of the input/output image. Can be a single int for
        square images or a list :math:`[H, W]` for rectangular images. Used
        to set the resolution of the positional embedding grid. It must
        correspond to the spatial resolution of the *global* domain/image.
    in_channels : int
        Number of channels :math:`C_{in} + C_{PE}`, where :math:`C_{in}` is
        the number of channels in the image passed to the U-Net and
        :math:`C_{PE}` is the number of channels in the positional embedding
        grid.

        **Important:** in comparison to the base
        :class:`~physicsnemo.models.diffusion_unets.SongUNet`, this parameter
        should also include the number of channels in the positional
        embedding grid :math:`C_{PE}`.
    gridtype : Literal["sinusoidal", "learnable", "linear", "test"], optional, default="sinusoidal"
        Type of positional embedding to use. Controls how spatial pixel
        locations are encoded.
    N_grid_channels : int, optional, default=4
        Number of channels :math:`C_{PE}` in the positional embedding grid.
        For 'sinusoidal' must be 4 or a multiple of 4. For 'linear' and
        'test' must be 2. For 'learnable' can be any value. If 0, positional
        embedding is disabled (but ``lead_time_mode`` may still be used).
    lead_time_mode : bool, optional, default=False
        Provided for convenience. It is recommended to use the architecture
        :class:`~physicsnemo.models.diffusion_unets.SongUNetPosLtEmbd` for a
        lead-time aware model.
    lead_time_channels : int, optional, default=None
        Provided for convenience. Refer to
        :class:`~physicsnemo.models.diffusion_unets.SongUNetPosLtEmbd`.
    lead_time_steps : int, optional, default=9
        Provided for convenience. Refer to
        :class:`~physicsnemo.models.diffusion_unets.SongUNetPosLtEmbd`.
    prob_channels : List[int], optional, default=[]
        Provided for convenience. Refer to
        :class:`~physicsnemo.models.diffusion_unets.SongUNetPosLtEmbd`.

    Forward
    -------
    x : torch.Tensor
        The input image of shape :math:`(B, C_{in}, H_{in}, W_{in})`, where
        :math:`H_{in}` and :math:`W_{in}` are the spatial dimensions of the
        input image (does not need to be the full image). In general ``x`` is
        the channel-wise concatenation of the latent state :math:`\mathbf{x}`
        and additional images used for conditioning. For an unconditional
        model, ``x`` is simply the latent state :math:`\mathbf{x}`.

        *Note:* :math:`H_{in}` and :math:`W_{in}` do not need to match the
        ``img_resolution`` parameter, except when ``additive_pos_embed`` is
        ``True``. In all other cases, the resolution of ``x`` must not exceed
        ``img_resolution``.
    noise_labels : torch.Tensor
        The noise labels of shape :math:`(B,)`. Used for conditioning on the
        diffusion noise level.
    class_labels : torch.Tensor
        The class labels of shape :math:`(B, \text{label_dim})`. Used for
        conditioning on any vector-valued quantity. Can pass ``None`` when
        ``label_dim`` is 0.
    global_index : torch.Tensor, optional, default=None
        The global indices of the positional embeddings to use. If neither
        ``global_index`` nor ``embedding_selector`` are provided, the entire
        positional embedding grid of shape :math:`(C_{PE}, H, W)` is used. In
        this case ``x`` must have the same spatial resolution as the
        positional embedding grid. See :meth:`positional_embedding_indexing`
        for details.
    embedding_selector : Callable, optional, default=None
        A function that selects the positional embeddings to use. See
        :meth:`positional_embedding_selector` for details.
    augment_labels : torch.Tensor, optional, default=None
        The augmentation labels of shape :math:`(B, \text{augment_dim})`.
        Used for conditioning on any additional vector-valued quantity. Can
        pass ``None`` when ``augment_dim`` is 0.

    Outputs
    -------
    torch.Tensor
        The output tensor of shape :math:`(B, C_{out}, H_{in}, W_{in})`.

    .. important::

        Unlike the positional embeddings defined by ``embedding_type`` in the
        parent class :class:`~physicsnemo.models.diffusion_unets.SongUNet`
        that encode the diffusion time-step (or noise level), the positional
        embeddings in this specialized architecture encode global spatial
        coordinates of the pixels.

    Examples
    --------
    >>> import torch
    >>> from physicsnemo.models.diffusion_unets import SongUNetPosEmbd
    >>> from physicsnemo.diffusion.multi_diffusion import GridPatching2D
    >>>
    >>> # Model initialization - in_channels must include both original input
    >>> # channels (2) and the positional embedding channels
    >>> # (N_grid_channels=4 by default)
    >>> model = SongUNetPosEmbd(img_resolution=16, in_channels=2+4, out_channels=2)
    >>> noise_labels = torch.randn([1])
    >>> class_labels = torch.randint(0, 1, (1, 1))
    >>> # The input has only the original 2 channels - positional embeddings
    >>> # are added automatically inside the forward method
    >>> input_image = torch.ones([1, 2, 16, 16])
    >>> output_image = model(input_image, noise_labels, class_labels)
    >>> output_image.shape
    torch.Size([1, 2, 16, 16])
    >>>
    >>> # Using a global index to select all positional embeddings
    >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(16, 16))
    >>> global_index = patching.global_index(batch_size=1)
    >>> output_image = model(
    ...     input_image, noise_labels, class_labels,
    ...     global_index=global_index
    ... )
    >>> output_image.shape
    torch.Size([1, 2, 16, 16])
    >>>
    >>> # Using a custom embedding selector to select all positional embeddings
    >>> def patch_embedding_selector(emb):
    ...     return patching.apply(emb[None].expand(1, -1, -1, -1))
    >>> output_image = model(
    ...     input_image, noise_labels, class_labels,
    ...     embedding_selector=patch_embedding_selector
    ... )
    >>> output_image.shape
    torch.Size([1, 2, 16, 16])
    """

    def __init__(
        self,
        img_resolution: Union[List[int], int],
        in_channels: int,
        out_channels: int,
        label_dim: int = 0,
        augment_dim: int = 0,
        model_channels: int = 128,
        channel_mult: List[int] = [1, 2, 2, 2, 2],
        channel_mult_emb: int = 4,
        num_blocks: int = 4,
        attn_resolutions: List[int] = [28],
        dropout: float = 0.13,
        label_dropout: float = 0.0,
        embedding_type: str = "positional",
        channel_mult_noise: int = 1,
        encoder_type: str = "standard",
        decoder_type: str = "standard",
        resample_filter: List[int] = [1, 1],
        gridtype: Literal["sinusoidal", "learnable", "linear", "test"] = "sinusoidal",
        N_grid_channels: int = 4,
        checkpoint_level: int = 0,
        additive_pos_embed: bool = False,
        bottleneck_attention: bool = True,
        use_apex_gn: bool = False,
        act: str = "silu",
        profile_mode: bool = False,
        amp_mode: bool = False,
        lead_time_mode: bool = False,
        lead_time_channels: int | None = None,
        lead_time_steps: int = 9,
        prob_channels: List[int] = [],
    ):
        # Force users to use the correct class for models with lead-time
        # embeddings
        if not getattr(self, "_is_song_unet_pos_lt_embd", False) and (
            lead_time_mode or lead_time_channels
        ):
            raise ValueError(
                "For a model with lead-time embeddings, the recommended class is "
                "`SongUNetPosLtEmbd` instead of `SongUNetPosEmbd`."
            )

        super().__init__(
            img_resolution=img_resolution,
            in_channels=in_channels,
            out_channels=out_channels,
            label_dim=label_dim,
            augment_dim=augment_dim,
            model_channels=model_channels,
            channel_mult=channel_mult,
            channel_mult_emb=channel_mult_emb,
            num_blocks=num_blocks,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            label_dropout=label_dropout,
            embedding_type=embedding_type,
            channel_mult_noise=channel_mult_noise,
            encoder_type=encoder_type,
            decoder_type=decoder_type,
            resample_filter=resample_filter,
            checkpoint_level=checkpoint_level,
            additive_pos_embed=additive_pos_embed,
            bottleneck_attention=bottleneck_attention,
            use_apex_gn=use_apex_gn,
            act=act,
            profile_mode=profile_mode,
            amp_mode=amp_mode,
        )

        self.gridtype = gridtype
        self.N_grid_channels = N_grid_channels
        if (self.gridtype == "learnable") or (self.N_grid_channels == 0):
            self.pos_embd = self._get_positional_embedding()
        else:
            self.register_buffer(
                "pos_embd", self._get_positional_embedding().float(), persistent=False
            )
        self.lead_time_mode = lead_time_mode
        if self.lead_time_mode:
            if (lead_time_channels is None) or (lead_time_channels <= 0):
                raise ValueError(
                    "`lead_time_channels` must be >= 1 if `lead_time_mode` is enabled."
                )
            self.lead_time_channels = lead_time_channels
            self.lead_time_steps = lead_time_steps
            self.lt_embd = self._get_lead_time_embedding()
            self.prob_channels = prob_channels
            if self.prob_channels:
                self.scalar = torch.nn.Parameter(
                    torch.ones((1, len(self.prob_channels), 1, 1))
                )
        else:
            if lead_time_channels:
                raise ValueError(
                    "When `lead_time_mode` is disabled, `lead_time_channels` "
                    "may not be set."
                )
            self.lt_embd = None

    def forward(
        self,
        x,
        noise_labels,
        class_labels,
        global_index: Optional[torch.Tensor] = None,
        embedding_selector: Optional[Callable] = None,
        augment_labels=None,
        lead_time_label=None,
    ):
        with (
            nvtx.annotate(message="SongUNetPosEmbd", color="blue")
            if self.profile_mode
            else contextlib.nullcontext()
        ):
            if embedding_selector is not None and global_index is not None:
                raise ValueError(
                    "Cannot provide both embedding_selector and global_index."
                )

            # Append positional embedding to input conditioning
            if (self.pos_embd is not None) or (self.lt_embd is not None):
                # Select positional embeddings with a selector function
                if embedding_selector is not None:
                    selected_pos_embd = self.positional_embedding_selector(
                        x, embedding_selector, lead_time_label=lead_time_label
                    )
                # Select positional embeddings using global indices (selects
                # all embeddings if global_index is None)
                else:
                    selected_pos_embd = self.positional_embedding_indexing(
                        x, global_index=global_index, lead_time_label=lead_time_label
                    )
                x = torch.cat((x, selected_pos_embd.to(x.dtype)), dim=1)

            out = super().forward(x, noise_labels, class_labels, augment_labels)

            if self.lead_time_mode and self.prob_channels:
                # In training mode, let CrossEntropyLoss apply the softmax:
                # the model outputs logits. In eval mode, the model outputs
                # probabilities.
                scalar = self.scalar
                if out.dtype != scalar.dtype:
                    scalar = scalar.to(out.dtype)
                if self.training:
                    out[:, self.prob_channels] = out[:, self.prob_channels] * scalar
                else:
                    out[:, self.prob_channels] = (
                        (out[:, self.prob_channels] * scalar)
                        .softmax(dim=1)
                        .to(out.dtype)
                    )
            return out
    def positional_embedding_indexing(
        self,
        x: torch.Tensor,
        global_index: Optional[torch.Tensor] = None,
        lead_time_label: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""Select positional embeddings using global indices.

        This method uses global indices to select a specific subset (called
        *patches*) of the positional embedding grid and/or the lead-time
        embedding grid. If no indices are provided, the entire embedding grid
        is returned.

        The positional embedding grid is returned if ``N_grid_channels > 0``,
        while the lead-time embedding grid is returned if ``lead_time_mode``
        is ``True``. If both positional and lead-time embeddings are enabled,
        both are returned (concatenated). If neither is enabled, this method
        should not be called; doing so raises a ``ValueError``.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape :math:`(P \times B, C, H_{in}, W_{in})`.
            Only used to determine the batch size :math:`B` and the device.
        global_index : Optional[torch.Tensor], default=None
            Tensor of shape :math:`(P, 2, H_{in}, W_{in})` that corresponds
            to the patches to extract from the positional embedding grid.
            :math:`P` is the number of distinct patches in the input tensor
            ``x``. The channel dimension contains the :math:`j` and :math:`i`
            indices of the pixels to extract from the embedding grid.
        lead_time_label : Optional[torch.Tensor], default=None
            Tensor of shape :math:`(B,)` that corresponds to the lead-time
            label for each batch element. Only used if ``lead_time_mode`` is
            ``True``.

        Returns
        -------
        torch.Tensor
            Selected embeddings with shape
            :math:`(P \times B, C_{PE} [+ C_{LT}], H_{in}, W_{in})`.
            :math:`C_{PE}` is the number of embedding channels in the
            positional embedding grid, and :math:`C_{LT}` is the number of
            embedding channels in the lead-time embedding grid. If
            ``lead_time_label`` is provided, the lead-time embedding channels
            are included. If ``global_index`` is ``None``, :math:`P = 1` is
            assumed, and the positional embedding grid is duplicated
            :math:`B` times and returned with shape
            :math:`(B, C_{PE} [+ C_{LT}], H, W)`.

        Example
        -------
        >>> # Create global indices using patching utility:
        >>> from physicsnemo.diffusion.multi_diffusion import GridPatching2D
        >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8))
        >>> global_index = patching.global_index(batch_size=3)
        >>> print(global_index.shape)
        torch.Size([4, 2, 8, 8])

        Notes
        -----
        - This method is typically used in patch-based diffusion (or
          multi-diffusion), where a large input image is split into multiple
          patches. The batch dimension of the input tensor contains the
          patches. Patches are processed independently by the model, and the
          ``global_index`` parameter is used to select the grid of positional
          embeddings corresponding to each patch.
        - See
          :meth:`~physicsnemo.diffusion.multi_diffusion.BasePatching2D.global_index`
          from :class:`physicsnemo.diffusion.multi_diffusion.BasePatching2D`
          for generating the ``global_index`` parameter.
        """
        # dtype casting of embeddings
        pos_embd = self.pos_embd
        if (pos_embd is not None) and (x.dtype != pos_embd.dtype):
            pos_embd = pos_embd.to(x.dtype)
        lt_embd = self.lt_embd
        if (lt_embd is not None) and (x.dtype != lt_embd.dtype):
            lt_embd = lt_embd.to(x.dtype)

        # If no global indices are provided, select all embeddings and expand
        # to match the batch size of the input
        if global_index is None:
            selected_embd = []
            # Select positional embedding
            if pos_embd is not None:
                selected_embd.append(pos_embd[None].expand((x.shape[0], -1, -1, -1)))
            # Select lead-time embedding
            if lt_embd is not None:
                if lead_time_label is None:
                    raise ValueError(
                        "`lead_time_label` must be provided when `lt_embd` is not None."
                    )
                selected_embd.append(
                    torch.reshape(
                        lt_embd[lead_time_label.int()],
                        (
                            x.shape[0],
                            self.lead_time_channels,
                            self.img_shape_y,
                            self.img_shape_x,
                        ),
                    )
                )
        # If global indices are provided, select the embeddings corresponding
        # to the patches
        else:
            P = global_index.shape[0]
            B = x.shape[0] // P
            H = global_index.shape[2]
            W = global_index.shape[3]
            global_index = torch.reshape(
                torch.permute(global_index, (1, 0, 2, 3)), (2, -1)
            )  # (P, 2, X, Y) to (2, P*X*Y)
            selected_embd = []
            # Select positional embedding
            if pos_embd is not None:
                selected_pos_embd = pos_embd[
                    :, global_index[0], global_index[1]
                ]  # (C_pe, P*X*Y)
                selected_pos_embd = torch.permute(
                    torch.reshape(selected_pos_embd, (pos_embd.shape[0], P, H, W)),
                    (1, 0, 2, 3),
                )  # (P, C_pe, X, Y)
                selected_pos_embd = selected_pos_embd.repeat(
                    B, 1, 1, 1
                )  # (B*P, C_pe, X, Y)
                selected_embd.append(selected_pos_embd)
            # Select lead-time embedding
            if lt_embd is not None:
                if lead_time_label is None:
                    raise ValueError(
                        "`lead_time_label` must be provided when `lt_embd` is not None."
                    )
                selected_lt_embd = lt_embd[
                    lead_time_label.int()
                ]  # (B, C_lt, img_shape_y, img_shape_x)
                selected_lt_embd = selected_lt_embd[
                    :, :, global_index[0], global_index[1]
                ]  # (B, C_lt, P*X*Y)
                selected_lt_embd = torch.reshape(
                    torch.permute(
                        torch.reshape(
                            selected_lt_embd,
                            (B, self.lead_time_channels, P, H, W),
                        ),
                        (0, 2, 1, 3, 4),
                    ).contiguous(),
                    (B * P, self.lead_time_channels, H, W),
                )  # (B*P, C_lt, X, Y)
                selected_embd.append(selected_lt_embd)

        # Concatenate all selected embeddings
        if len(selected_embd) > 0:
            selected_embd = torch.cat(selected_embd, dim=1)
        else:
            raise ValueError(
                "`positional_embedding_indexing` should not be called when neither "
                "lead-time nor positional embeddings are used."
            )
        return selected_embd
    def positional_embedding_selector(
        self,
        x: torch.Tensor,
        embedding_selector: Callable[[torch.Tensor], torch.Tensor],
        lead_time_label=None,
    ) -> torch.Tensor:
        r"""Select positional embeddings using a selector function.

        Similar to :meth:`positional_embedding_indexing`, but uses a selector
        function to select the embeddings.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape :math:`(P \times B, C, H_{in}, W_{in})`.
            Only used to determine the dtype.
        embedding_selector : Callable[[torch.Tensor], torch.Tensor]
            Function that takes as input the entire embedding grid of shape
            :math:`(C_{PE}, H, W)` (or :math:`(B, C_{LT}, H, W)` when
            ``lead_time_label`` is provided) and returns selected embeddings
            with shape :math:`(P \times B, C_{PE}, H_{in}, W_{in})` (or
            :math:`(P \times B, C_{LT}, H_{in}, W_{in})` when
            ``lead_time_label`` is provided). Each selected embedding should
            correspond to the portion of the embedding grid associated with
            the corresponding batch element of ``x``. Typically this should
            be based on the
            :meth:`physicsnemo.diffusion.multi_diffusion.BasePatching2D.apply`
            method to maintain consistency with patch extraction.
        lead_time_label : Optional[torch.Tensor], default=None
            Tensor of shape :math:`(B,)` that corresponds to the lead-time
            label for each batch element. Only used if ``lead_time_mode`` is
            ``True``.

        Returns
        -------
        torch.Tensor
            A tensor of shape
            :math:`(P \times B, C_{PE} [+ C_{LT}], H_{in}, W_{in})`.
            :math:`C_{PE}` is the number of embedding channels in the
            positional embedding grid, and :math:`C_{LT}` is the number of
            embedding channels in the lead-time embedding grid. If
            ``lead_time_label`` is provided, the lead-time embedding channels
            are included.

        Notes
        -----
        - This method is typically used in patch-based diffusion (or
          multi-diffusion), where a large input image is split into multiple
          patches. The batch dimension of the input tensor contains the
          patches. Patches are processed independently by the model, and the
          ``embedding_selector`` function is used to select the grid of
          positional embeddings corresponding to each patch.
        - See the method
          :meth:`~physicsnemo.diffusion.multi_diffusion.BasePatching2D.apply`
          from :class:`physicsnemo.diffusion.multi_diffusion.BasePatching2D`
          for generating the ``embedding_selector`` parameter, as well as the
          example below.

        Example
        -------
        >>> # Define a selector function with a patching utility:
        >>> from physicsnemo.diffusion.multi_diffusion import GridPatching2D
        >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8))
        >>> B = 4
        >>> def embedding_selector(emb):
        ...     return patching.apply(emb.expand(B, -1, -1, -1))
        >>>
        """
        # dtype casting of embeddings
        pos_embd = self.pos_embd
        if (pos_embd is not None) and (x.dtype != pos_embd.dtype):
            pos_embd = pos_embd.to(x.dtype)  # (C_PE, H, W)
        lt_embd = self.lt_embd
        if (lt_embd is not None) and (x.dtype != lt_embd.dtype):
            lt_embd = lt_embd.to(x.dtype)  # (lead_time_steps, C_LT, H, W)

        embeddings: list[torch.Tensor] = []
        # Select positional embedding
        if pos_embd is not None:
            selected_pos_embd = embedding_selector(pos_embd)  # (P * B, C_PE, H_p, W_p)
            embeddings.append(selected_pos_embd)
        # Select lead-time embedding
        if lt_embd is not None:
            if lead_time_label is None:
                raise ValueError(
                    "`lead_time_label` must be provided when `lt_embd` is not None."
                )
            selected_lt_embd: torch.Tensor = lt_embd[
                lead_time_label.int()
            ]  # (B, C_LT, H, W)
            selected_lt_embd = embedding_selector(
                selected_lt_embd
            )  # (P * B, C_LT, H_p, W_p)
            embeddings.append(selected_lt_embd)

        if len(embeddings) > 0:
            embeddings: torch.Tensor = torch.cat(embeddings, dim=1)
        else:
            raise ValueError(
                "`positional_embedding_selector` should not be called when neither "
                "lead-time nor positional embeddings are used."
            )
        return embeddings
    def _get_positional_embedding(self):
        if self.N_grid_channels == 0:
            return None
        elif self.gridtype == "learnable":
            grid = torch.nn.Parameter(
                torch.randn(self.N_grid_channels, self.img_shape_y, self.img_shape_x)
            )  # (N_grid_channels, img_shape_y, img_shape_x)
        elif self.gridtype == "linear":
            if self.N_grid_channels != 2:
                raise ValueError("N_grid_channels must be set to 2 for gridtype linear")
            y = np.meshgrid(np.linspace(-1, 1, self.img_shape_y))
            x = np.meshgrid(np.linspace(-1, 1, self.img_shape_x))
            grid_y, grid_x = np.meshgrid(x, y)
            grid = torch.from_numpy(
                np.stack((grid_y, grid_x), axis=0)
            )  # (2, img_shape_y, img_shape_x)
            grid.requires_grad = False
        elif self.gridtype == "sinusoidal" and self.N_grid_channels == 4:
            x1 = np.meshgrid(np.sin(np.linspace(0, 2 * np.pi, self.img_shape_x)))
            x2 = np.meshgrid(np.cos(np.linspace(0, 2 * np.pi, self.img_shape_x)))
            y1 = np.meshgrid(np.sin(np.linspace(0, 2 * np.pi, self.img_shape_y)))
            y2 = np.meshgrid(np.cos(np.linspace(0, 2 * np.pi, self.img_shape_y)))
            grid_x1, grid_y1 = np.meshgrid(x1, y1)
            grid_x2, grid_y2 = np.meshgrid(x2, y2)
            grid = torch.from_numpy(
                np.stack((grid_x1, grid_y1, grid_x2, grid_y2), axis=0)
            )
            grid.requires_grad = False
        elif self.gridtype == "sinusoidal" and self.N_grid_channels != 4:
            if self.N_grid_channels % 4 != 0:
                raise ValueError("N_grid_channels must be a multiple of 4")
            num_freq = self.N_grid_channels // 4
            freq_bands = 2.0 ** np.linspace(0.0, num_freq, num=num_freq)
            grid_list = []
            grid_x, grid_y = np.meshgrid(
                np.linspace(0, 2 * np.pi, self.img_shape_x),
                np.linspace(0, 2 * np.pi, self.img_shape_y),
            )
            for freq in freq_bands:
                for p_fn in [np.sin, np.cos]:
                    grid_list.append(p_fn(grid_x * freq))
                    grid_list.append(p_fn(grid_y * freq))
            grid = torch.from_numpy(
                np.stack(grid_list, axis=0)
            )  # (N_grid_channels, img_shape_y, img_shape_x)
            grid.requires_grad = False
        elif self.gridtype == "test" and self.N_grid_channels == 2:
            idx_x = torch.arange(self.img_shape_x)
            idx_y = torch.arange(self.img_shape_y)
            mesh_y, mesh_x = torch.meshgrid(idx_y, idx_x)
            grid = torch.stack((mesh_y, mesh_x), dim=0)  # (2, img_shape_y, img_shape_x)
        else:
            raise ValueError("Gridtype not supported.")
        return grid

    def _get_lead_time_embedding(self):
        if (self.lead_time_steps is None) or (self.lead_time_channels is None):
            return None
        grid = torch.nn.Parameter(
            torch.randn(
                self.lead_time_steps,
                self.lead_time_channels,
                self.img_shape_y,
                self.img_shape_x,
            )
        )  # (lead_time_steps, lead_time_channels, img_shape_y, img_shape_x)
        return grid
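
# Editor's note: a hedged sketch (not part of the original module) of the
# ``global_index`` convention expected by ``positional_embedding_indexing``:
# a (P, 2, H_p, W_p) tensor of (j, i) pixel indices into the embedding grid.
# Here a single "patch" covering the full 16x16 grid is built by hand with
# torch.meshgrid instead of the GridPatching2D utility used in the docstring
# examples; ``model_channels=16`` only keeps the sketch lightweight.
if __name__ == "__main__":
    _model = SongUNetPosEmbd(
        img_resolution=16, in_channels=2 + 4, out_channels=2, model_channels=16
    )
    _jj, _ii = torch.meshgrid(torch.arange(16), torch.arange(16), indexing="ij")
    _global_index = torch.stack([_jj, _ii], dim=0)[None]  # (P=1, 2, 16, 16)
    _x = torch.ones([1, 2, 16, 16])  # original channels only, P*B = 1
    _out = _model(_x, torch.randn([1]), None, global_index=_global_index)
    assert _out.shape == (1, 2, 16, 16)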
# TODO: the entire lead-time logic should be moved to SongUNetPosLtEmbd. We
# should use a subclass of SongUNetPosEmbd and specialize it for lead-time
# aware embeddings.
class SongUNetPosLtEmbd(SongUNetPosEmbd):
    r"""
    This specialized architecture extends
    :class:`~physicsnemo.models.diffusion_unets.SongUNetPosEmbd` with two
    additional capabilities:

    1. The model can be conditioned on lead-time labels. These labels encode
       *physical* time information, such as a forecasting horizon.
    2. Similarly to the parent ``SongUNetPosEmbd``, this model predicts
       regression targets, but it can also produce classification
       predictions. More precisely, some of the output channels are
       probability outputs that are passed through a softmax activation
       function. This is useful for multi-task applications, where the
       objective is a combination of both regression and classification
       losses.

    The mechanism to condition on lead-time labels is implemented by:

    - First, generating a grid of learnable lead-time embeddings of shape
      :math:`(\text{lead_time_steps}, C_{LT}, H, W)`. The spatial resolution
      of the lead-time embeddings is the same as that of the input/output
      image.
    - Then, given an input ``x``, selecting the lead-time embeddings that
      correspond to the lead-times associated with the samples in the input
      ``x``.
    - Finally, concatenating channel-wise the selected lead-time embeddings
      and positional embeddings to the input ``x`` and passing them to the
      U-Net network.

    Most parameters are similar to those of the parent
    :class:`~physicsnemo.models.diffusion_unets.SongUNetPosEmbd`, with the
    exception of the ones listed below.

    Parameters
    ----------
    in_channels : int
        Number of channels :math:`C_{in} + C_{PE} + C_{LT}` in the image
        passed to the U-Net.

        *Important:* in comparison to the base
        :class:`~physicsnemo.models.diffusion_unets.SongUNet`, this parameter
        should also include the number of channels in the positional
        embedding grid :math:`C_{PE}` and the number of channels in the
        lead-time embedding grid :math:`C_{LT}`.
    lead_time_channels : int, optional, default=None
        Number of channels :math:`C_{LT}` in the lead-time embedding. These
        are learned embeddings that encode *physical* time information.
    lead_time_steps : int, optional, default=9
        Number of discrete lead-time steps to support. Each step gets its own
        learned embedding of shape :math:`(C_{LT}, H, W)`.
    prob_channels : List[int], optional, default=[]
        Indices of channels that are probability outputs (or *classification*
        predictions). In training mode, the model outputs logits for these
        probability channels; in eval mode, the model applies a softmax to
        output the probabilities.

    Forward
    -------
    x : torch.Tensor
        The input image of shape :math:`(B, C_{in}, H_{in}, W_{in})`, where
        :math:`H_{in}` and :math:`W_{in}` are the spatial dimensions of the
        input image (does not need to be the full image).
    noise_labels : torch.Tensor
        The noise labels of shape :math:`(B,)`. Used for conditioning on the
        diffusion noise level.
    class_labels : torch.Tensor
        The class labels of shape :math:`(B, \text{label_dim})`. Used for
        conditioning on any vector-valued quantity. Can pass ``None`` when
        ``label_dim`` is 0.
    global_index : torch.Tensor, optional, default=None
        The global indices of the positional embeddings to use. See
        :meth:`positional_embedding_indexing` for details. If neither
        ``global_index`` nor ``embedding_selector`` are provided, the entire
        positional embedding grid is used.
    embedding_selector : Callable, optional, default=None
        A function that selects the positional embeddings to use. See
        :meth:`positional_embedding_selector` for details.
    augment_labels : torch.Tensor, optional, default=None
        The augmentation labels of shape :math:`(B, \text{augment_dim})`.
        Used for conditioning on any additional vector-valued quantity.
    lead_time_label : torch.Tensor, optional, default=None
        The lead-time labels of shape :math:`(B,)`. Used for selecting
        lead-time embeddings. It should contain the indices of the lead-time
        embeddings that correspond to the lead-time of each sample in the
        batch.

    Outputs
    -------
    torch.Tensor
        The output tensor of shape :math:`(B, C_{out}, H_{in}, W_{in})`.

    Notes
    -----
    - The lead-time embeddings differ from the diffusion time embeddings used
      in the :class:`~physicsnemo.models.diffusion_unets.SongUNet` class, as
      they do not encode the diffusion time-step but the *physical forecast
      time*.

    Example
    -------
    >>> import torch
    >>> from physicsnemo.models.diffusion_unets import SongUNetPosLtEmbd
    >>> from physicsnemo.diffusion.multi_diffusion import GridPatching2D
    >>>
    >>> # Model initialization - in_channels must include original input
    >>> # channels (2), positional embedding channels (N_grid_channels=4 by
    >>> # default) and lead-time embedding channels (4)
    >>> model = SongUNetPosLtEmbd(
    ...     img_resolution=16, in_channels=2+4+4, out_channels=2,
    ...     lead_time_channels=4, lead_time_steps=9
    ... )
    >>> noise_labels = torch.randn([1])
    >>> class_labels = torch.randint(0, 1, (1, 1))
    >>> # The input has only the original 2 channels - positional embeddings
    >>> # and lead-time embeddings are added automatically inside the forward
    >>> # method
    >>> input_image = torch.ones([1, 2, 16, 16])
    >>> lead_time_label = torch.tensor([3])
    >>> output_image = model(
    ...     input_image, noise_labels, class_labels,
    ...     lead_time_label=lead_time_label
    ... )
    >>> output_image.shape
    torch.Size([1, 2, 16, 16])
    >>>
    >>> # Using global_index to select all the positional and lead-time
    >>> # embeddings
    >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(16, 16))
    >>> global_index = patching.global_index(batch_size=1)
    >>> output_image = model(
    ...     input_image, noise_labels, class_labels,
    ...     lead_time_label=lead_time_label,
    ...     global_index=global_index
    ... )
    >>> output_image.shape
    torch.Size([1, 2, 16, 16])
    """

    def __init__(
        self,
        img_resolution: Union[List[int], int],
        in_channels: int,
        out_channels: int,
        label_dim: int = 0,
        augment_dim: int = 0,
        model_channels: int = 128,
        channel_mult: List[int] = [1, 2, 2, 2, 2],
        channel_mult_emb: int = 4,
        num_blocks: int = 4,
        attn_resolutions: List[int] = [28],
        dropout: float = 0.13,
        label_dropout: float = 0.0,
        embedding_type: str = "positional",
        channel_mult_noise: int = 1,
        encoder_type: str = "standard",
        decoder_type: str = "standard",
        resample_filter: List[int] = [1, 1],
        gridtype: str = "sinusoidal",
        N_grid_channels: int = 4,
        lead_time_channels: int | None = None,
        lead_time_steps: int = 9,
        prob_channels: List[int] = [],
        checkpoint_level: int = 0,
        additive_pos_embed: bool = False,
        bottleneck_attention: bool = True,
        use_apex_gn: bool = False,
        act: str = "silu",
        profile_mode: bool = False,
        amp_mode: bool = False,
    ):
        self._is_song_unet_pos_lt_embd = True
        super().__init__(
            img_resolution=img_resolution,
            in_channels=in_channels,
            out_channels=out_channels,
            label_dim=label_dim,
            augment_dim=augment_dim,
            model_channels=model_channels,
            channel_mult=channel_mult,
            channel_mult_emb=channel_mult_emb,
            num_blocks=num_blocks,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            label_dropout=label_dropout,
            embedding_type=embedding_type,
            channel_mult_noise=channel_mult_noise,
            encoder_type=encoder_type,
            decoder_type=decoder_type,
            resample_filter=resample_filter,
            gridtype=gridtype,
            N_grid_channels=N_grid_channels,
            checkpoint_level=checkpoint_level,
            additive_pos_embed=additive_pos_embed,
            bottleneck_attention=bottleneck_attention,
            use_apex_gn=use_apex_gn,
            act=act,
            profile_mode=profile_mode,
            amp_mode=amp_mode,
            lead_time_mode=True,  # Note: lead_time_mode=True is enforced here
            lead_time_channels=lead_time_channels,
            lead_time_steps=lead_time_steps,
            prob_channels=prob_channels,
        )

    def forward(
        self,
        x,
        noise_labels,
        class_labels,
        lead_time_label=None,
        global_index: Optional[torch.Tensor] = None,
        embedding_selector: Optional[Callable] = None,
        augment_labels=None,
    ):
        return super().forward(
            x=x,
            noise_labels=noise_labels,
            class_labels=class_labels,
            global_index=global_index,
            embedding_selector=embedding_selector,
            augment_labels=augment_labels,
            lead_time_label=lead_time_label,
        )
    # Nothing else is re-implemented, because everything is already in the
    # parent SongUNetPosEmbd.
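
# Editor's note: a hedged, minimal lead-time usage sketch (not part of the
# original module), mirroring the docstring example above: in_channels must
# account for the 2 data channels, the 4 default positional embedding
# channels, and the 4 lead-time embedding channels; ``lead_time_label`` holds
# the per-sample index into the ``lead_time_steps`` embedding grids.
# ``model_channels=16`` only keeps the sketch lightweight.
if __name__ == "__main__":
    _model = SongUNetPosLtEmbd(
        img_resolution=16,
        in_channels=2 + 4 + 4,
        out_channels=2,
        model_channels=16,
        lead_time_channels=4,
        lead_time_steps=9,
    )
    _x = torch.ones([1, 2, 16, 16])
    _out = _model(
        _x,
        noise_labels=torch.randn([1]),
        class_labels=None,
        lead_time_label=torch.tensor([3]),
    )
    assert _out.shape == (1, 2, 16, 16)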