
Source code for physicsnemo.models.diffusion.song_unet

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Model architectures used in the paper "Elucidating the Design Space of 
Diffusion-Based Generative Models".
"""

import contextlib
from dataclasses import dataclass
from typing import Callable, List, Optional, Union

import numpy as np
import nvtx
import torch
from torch.nn.functional import silu
from torch.utils.checkpoint import checkpoint

from physicsnemo.models.diffusion import (
    Conv2d,
    FourierEmbedding,
    GroupNorm,
    Linear,
    PositionalEmbedding,
    UNetBlock,
)
from physicsnemo.models.meta import ModelMetaData
from physicsnemo.models.module import Module


@dataclass
class MetaData(ModelMetaData):
    name: str = "SongUNet"
    # Optimization
    jit: bool = False
    cuda_graphs: bool = False
    amp_cpu: bool = False
    amp_gpu: bool = True
    torch_fx: bool = False
    # Data type
    bf16: bool = True
    # Inference
    onnx: bool = False
    # Physics informed
    func_torch: bool = False
    auto_grad: bool = False
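
The MetaData flags above advertise which optimizations this architecture supports (for example, GPU autocast but not TorchScript). A minimal, purely illustrative sketch of inspecting them (not part of the original source):

# Illustrative sketch: MetaData is a plain dataclass, so its capability
# flags can be read directly after instantiation.
# meta = MetaData()
# print(meta.name)     # "SongUNet"
# print(meta.amp_gpu)  # True  -> mixed precision on GPU is supported
# print(meta.jit)      # False -> TorchScript compilation is not supported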
class SongUNet(Module):
    """
    Reimplementation of the DDPM++ and NCSN++ architectures, U-Net variants with
    optional self-attention, embeddings, and encoder-decoder components.

    This model supports conditional and unconditional setups, as well as several
    options for various internal architectural choices such as encoder and decoder
    type, embedding type, etc., making it flexible and adaptable to different tasks
    and configurations.

    Parameters
    ----------
    img_resolution : Union[List[int], int]
        The resolution of the input/output image. Can be a single int for square
        images or a list [height, width] for rectangular images.
    in_channels : int
        Number of channels in the input image.
    out_channels : int
        Number of channels in the output image.
    label_dim : int, optional
        Number of class labels; 0 indicates an unconditional model. By default 0.
    augment_dim : int, optional
        Dimensionality of augmentation labels; 0 means no augmentation. By default 0.
    model_channels : int, optional
        Base multiplier for the number of channels across the network. By default 128.
    channel_mult : List[int], optional
        Per-resolution multipliers for the number of channels. By default [1,2,2,2].
    channel_mult_emb : int, optional
        Multiplier for the dimensionality of the embedding vector. By default 4.
    num_blocks : int, optional
        Number of residual blocks per resolution. By default 4.
    attn_resolutions : List[int], optional
        Resolutions at which self-attention layers are applied. By default [16].
    dropout : float, optional
        Dropout probability applied to intermediate activations. By default 0.10.
    label_dropout : float, optional
        Dropout probability of class labels for classifier-free guidance.
        By default 0.0.
    embedding_type : str, optional
        Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++,
        'zero' for none. By default 'positional'.
    channel_mult_noise : int, optional
        Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1.
    encoder_type : str, optional
        Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++,
        'skip' for skip connections. By default 'standard'.
    decoder_type : str, optional
        Decoder architecture: 'standard' or 'skip' for skip connections.
        By default 'standard'.
    resample_filter : List[int], optional
        Resampling filter coefficients: [1,1] for DDPM++, [1,3,3,1] for NCSN++.
        By default [1,1].
    checkpoint_level : int, optional
        Number of layers that should use gradient checkpointing (0 disables
        checkpointing). Higher values trade memory for computation. By default 0.
    additive_pos_embed : bool, optional
        If True, adds a learned positional embedding after the first convolution
        layer. Used in StormCast model. By default False.
    use_apex_gn : bool, optional
        Whether to use Apex GroupNorm for NHWC layout. Must be set to False on CPU.
        Defaults to False.
    act : str, optional
        The activation function to use when fusing activation with GroupNorm.
        Defaults to "silu".
    profile_mode : bool, optional
        Whether to enable all nvtx annotations during profiling. Defaults to False.
    amp_mode : bool, optional
        Whether mixed-precision (AMP) training is enabled. Defaults to False.

    Reference
    ---------
    Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and Poole, B.,
    2020. Score-based generative modeling through stochastic differential equations.
    arXiv preprint arXiv:2011.13456.

    Note
    ----
    Equivalent to the original implementation by Song et al., available at
    https://github.com/yang-song/score_sde_pytorch

    Example
    -------
    >>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2)
    >>> noise_labels = torch.randn([1])
    >>> class_labels = torch.randint(0, 1, (1, 1))
    >>> input_image = torch.ones([1, 2, 16, 16])
    >>> output_image = model(input_image, noise_labels, class_labels)
    >>> output_image.shape
    torch.Size([1, 2, 16, 16])
    """

    def __init__(
        self,
        img_resolution: Union[List[int], int],
        in_channels: int,
        out_channels: int,
        label_dim: int = 0,
        augment_dim: int = 0,
        model_channels: int = 128,
        channel_mult: List[int] = [1, 2, 2, 2],
        channel_mult_emb: int = 4,
        num_blocks: int = 4,
        attn_resolutions: List[int] = [16],
        dropout: float = 0.10,
        label_dropout: float = 0.0,
        embedding_type: str = "positional",
        channel_mult_noise: int = 1,
        encoder_type: str = "standard",
        decoder_type: str = "standard",
        resample_filter: List[int] = [1, 1],
        checkpoint_level: int = 0,
        additive_pos_embed: bool = False,
        use_apex_gn: bool = False,
        act: str = "silu",
        profile_mode: bool = False,
        amp_mode: bool = False,
    ):
        valid_embedding_types = ["fourier", "positional", "zero"]
        if embedding_type not in valid_embedding_types:
            raise ValueError(
                f"Invalid embedding_type: {embedding_type}. Must be one of {valid_embedding_types}."
            )

        valid_encoder_types = ["standard", "skip", "residual"]
        if encoder_type not in valid_encoder_types:
            raise ValueError(
                f"Invalid encoder_type: {encoder_type}. Must be one of {valid_encoder_types}."
            )

        valid_decoder_types = ["standard", "skip"]
        if decoder_type not in valid_decoder_types:
            raise ValueError(
                f"Invalid decoder_type: {decoder_type}. Must be one of {valid_decoder_types}."
            )

        super().__init__(meta=MetaData())
        self.label_dropout = label_dropout
        self.embedding_type = embedding_type
        emb_channels = model_channels * channel_mult_emb
        self.emb_channels = emb_channels
        noise_channels = model_channels * channel_mult_noise
        init = dict(init_mode="xavier_uniform")
        init_zero = dict(init_mode="xavier_uniform", init_weight=1e-5)
        init_attn = dict(init_mode="xavier_uniform", init_weight=np.sqrt(0.2))
        block_kwargs = dict(
            emb_channels=emb_channels,
            num_heads=1,
            dropout=dropout,
            skip_scale=np.sqrt(0.5),
            eps=1e-6,
            resample_filter=resample_filter,
            resample_proj=True,
            adaptive_scale=False,
            init=init,
            init_zero=init_zero,
            init_attn=init_attn,
            use_apex_gn=use_apex_gn,
            act=act,
            fused_conv_bias=True,
            profile_mode=profile_mode,
            amp_mode=amp_mode,
        )
        self.profile_mode = profile_mode
        self.amp_mode = amp_mode

        # For compatibility with older versions that took only one dimension
        self.img_resolution = img_resolution
        if isinstance(img_resolution, int):
            self.img_shape_y = self.img_shape_x = img_resolution
        else:
            self.img_shape_y = img_resolution[0]
            self.img_shape_x = img_resolution[1]

        # Set the threshold for checkpointing based on image resolution
        self.checkpoint_threshold = (self.img_shape_y >> checkpoint_level) + 1

        # Optional additive learned position embed after the first conv
        self.additive_pos_embed = additive_pos_embed
        if self.additive_pos_embed:
            self.spatial_emb = torch.nn.Parameter(
                torch.randn(1, model_channels, self.img_shape_y, self.img_shape_x)
            )
            torch.nn.init.trunc_normal_(self.spatial_emb, std=0.02)

        # Mapping.
        if self.embedding_type != "zero":
            self.map_noise = (
                PositionalEmbedding(
                    num_channels=noise_channels, endpoint=True, amp_mode=amp_mode
                )
                if embedding_type == "positional"
                else FourierEmbedding(num_channels=noise_channels, amp_mode=amp_mode)
            )
            self.map_label = (
                Linear(
                    in_features=label_dim,
                    out_features=noise_channels,
                    amp_mode=amp_mode,
                    **init,
                )
                if label_dim
                else None
            )
            self.map_augment = (
                Linear(
                    in_features=augment_dim,
                    out_features=noise_channels,
                    bias=False,
                    amp_mode=amp_mode,
                    **init,
                )
                if augment_dim
                else None
            )
            self.map_layer0 = Linear(
                in_features=noise_channels,
                out_features=emb_channels,
                amp_mode=amp_mode,
                **init,
            )
            self.map_layer1 = Linear(
                in_features=emb_channels,
                out_features=emb_channels,
                amp_mode=amp_mode,
                **init,
            )

        # Encoder.
        self.enc = torch.nn.ModuleDict()
        cout = in_channels
        caux = in_channels
        for level, mult in enumerate(channel_mult):
            res = self.img_shape_y >> level
            if level == 0:
                cin = cout
                cout = model_channels
                self.enc[f"{res}x{res}_conv"] = Conv2d(
                    in_channels=cin,
                    out_channels=cout,
                    kernel=3,
                    fused_conv_bias=True,
                    amp_mode=amp_mode,
                    **init,
                )
            else:
                self.enc[f"{res}x{res}_down"] = UNetBlock(
                    in_channels=cout, out_channels=cout, down=True, **block_kwargs
                )
                if encoder_type == "skip":
                    self.enc[f"{res}x{res}_aux_down"] = Conv2d(
                        in_channels=caux,
                        out_channels=caux,
                        kernel=0,
                        down=True,
                        resample_filter=resample_filter,
                        amp_mode=amp_mode,
                    )
                    self.enc[f"{res}x{res}_aux_skip"] = Conv2d(
                        in_channels=caux,
                        out_channels=cout,
                        kernel=1,
                        fused_conv_bias=True,
                        amp_mode=amp_mode,
                        **init,
                    )
                if encoder_type == "residual":
                    self.enc[f"{res}x{res}_aux_residual"] = Conv2d(
                        in_channels=caux,
                        out_channels=cout,
                        kernel=3,
                        down=True,
                        resample_filter=resample_filter,
                        fused_resample=True,
                        fused_conv_bias=True,
                        amp_mode=amp_mode,
                        **init,
                    )
                    caux = cout
            for idx in range(num_blocks):
                cin = cout
                cout = model_channels * mult
                attn = res in attn_resolutions
                self.enc[f"{res}x{res}_block{idx}"] = UNetBlock(
                    in_channels=cin, out_channels=cout, attention=attn, **block_kwargs
                )
        skips = [
            block.out_channels for name, block in self.enc.items() if "aux" not in name
        ]

        # Decoder.
        self.dec = torch.nn.ModuleDict()
        for level, mult in reversed(list(enumerate(channel_mult))):
            res = self.img_shape_y >> level
            if level == len(channel_mult) - 1:
                self.dec[f"{res}x{res}_in0"] = UNetBlock(
                    in_channels=cout, out_channels=cout, attention=True, **block_kwargs
                )
                self.dec[f"{res}x{res}_in1"] = UNetBlock(
                    in_channels=cout, out_channels=cout, **block_kwargs
                )
            else:
                self.dec[f"{res}x{res}_up"] = UNetBlock(
                    in_channels=cout, out_channels=cout, up=True, **block_kwargs
                )
            for idx in range(num_blocks + 1):
                cin = cout + skips.pop()
                cout = model_channels * mult
                attn = idx == num_blocks and res in attn_resolutions
                self.dec[f"{res}x{res}_block{idx}"] = UNetBlock(
                    in_channels=cin, out_channels=cout, attention=attn, **block_kwargs
                )
            if decoder_type == "skip" or level == 0:
                if decoder_type == "skip" and level < len(channel_mult) - 1:
                    self.dec[f"{res}x{res}_aux_up"] = Conv2d(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        kernel=0,
                        up=True,
                        resample_filter=resample_filter,
                        amp_mode=amp_mode,
                    )
                self.dec[f"{res}x{res}_aux_norm"] = GroupNorm(
                    num_channels=cout,
                    eps=1e-6,
                    use_apex_gn=use_apex_gn,
                    amp_mode=amp_mode,
                )
                self.dec[f"{res}x{res}_aux_conv"] = Conv2d(
                    in_channels=cout,
                    out_channels=out_channels,
                    kernel=3,
                    fused_conv_bias=True,
                    amp_mode=amp_mode,
                    **init_zero,
                )
    def forward(self, x, noise_labels, class_labels, augment_labels=None):
        with nvtx.annotate(
            message="SongUNet", color="blue"
        ) if self.profile_mode else contextlib.nullcontext():
            if self.embedding_type != "zero":
                # Mapping.
                emb = self.map_noise(noise_labels)
                emb = (
                    emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape)
                )  # swap sin/cos
                if self.map_label is not None:
                    tmp = class_labels
                    if self.training and self.label_dropout:
                        tmp = tmp * (
                            torch.rand([x.shape[0], 1], device=x.device)
                            >= self.label_dropout
                        ).to(tmp.dtype)
                    emb = emb + self.map_label(
                        tmp * np.sqrt(self.map_label.in_features)
                    )
                if self.map_augment is not None and augment_labels is not None:
                    emb = emb + self.map_augment(augment_labels)
                emb = silu(self.map_layer0(emb))
                emb = silu(self.map_layer1(emb))
            else:
                emb = torch.zeros(
                    (noise_labels.shape[0], self.emb_channels), device=x.device
                )

            # Encoder.
            skips = []
            aux = x
            for name, block in self.enc.items():
                with nvtx.annotate(
                    f"SongUNet encoder: {name}", color="blue"
                ) if self.profile_mode else contextlib.nullcontext():
                    if "aux_down" in name:
                        aux = block(aux)
                    elif "aux_skip" in name:
                        x = skips[-1] = x + block(aux)
                    elif "aux_residual" in name:
                        x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2)
                    elif "_conv" in name:
                        x = block(x)
                        if self.additive_pos_embed:
                            x = x + self.spatial_emb.to(dtype=x.dtype)
                        skips.append(x)
                    else:
                        # For UNetBlocks, checkpoint large feature maps to trade
                        # recomputation for activation memory
                        if isinstance(block, UNetBlock):
                            if x.shape[-1] > self.checkpoint_threshold:
                                x = checkpoint(block, x, emb)
                            else:
                                x = block(x, emb)
                        else:
                            x = block(x)
                        skips.append(x)

            # Decoder.
            aux = None
            tmp = None
            for name, block in self.dec.items():
                with nvtx.annotate(
                    f"SongUNet decoder: {name}", color="blue"
                ) if self.profile_mode else contextlib.nullcontext():
                    if "aux_up" in name:
                        aux = block(aux)
                    elif "aux_norm" in name:
                        tmp = block(x)
                    elif "aux_conv" in name:
                        tmp = block(silu(tmp))
                        aux = tmp if aux is None else tmp + aux
                    else:
                        if x.shape[1] != block.in_channels:
                            x = torch.cat([x, skips.pop()], dim=1)
                        # Check for checkpointing on decoder blocks and
                        # upsampling blocks
                        if (
                            x.shape[-1] > self.checkpoint_threshold
                            and "_block" in name
                        ) or (
                            x.shape[-1] > (self.checkpoint_threshold / 2)
                            and "_up" in name
                        ):
                            x = checkpoint(block, x, emb)
                        else:
                            x = block(x, emb)
            return aux
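
The checkpoint threshold above is computed as (img_shape_y >> checkpoint_level) + 1, so each increment of checkpoint_level extends checkpointing one resolution level deeper. A hypothetical usage sketch (values illustrative, not from the original source):

# Sketch: gradient checkpointing applies to feature maps wider than
# (img_shape_y >> checkpoint_level) + 1.
# model = SongUNet(
#     img_resolution=64,
#     in_channels=3,
#     out_channels=3,
#     checkpoint_level=1,  # threshold = (64 >> 1) + 1 = 33
# )
# # Blocks at 64x64 (> 33) are checkpointed to save activation memory;
# # blocks at 32x32 and below run normally.
# x = torch.ones([2, 3, 64, 64])
# noise_labels = torch.randn([2])
# out = model(x, noise_labels, class_labels=None)  # label_dim=0 -> unconditional
# assert out.shape == (2, 3, 64, 64)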
class SongUNetPosEmbd(SongUNet):
    """Extends SongUNet with positional embeddings.

    This model supports conditional and unconditional setups, as well as several
    options for various internal architectural choices such as encoder and decoder
    type, embedding type, etc., making it flexible and adaptable to different tasks
    and configurations.

    This model adds positional embeddings to the base SongUNet architecture.
    The embeddings can be selected using either a selector function or global
    indices, with the selector approach being more computationally efficient.

    The model provides two methods for selecting positional embeddings:

    1. Using a selector function (preferred method). See
       :meth:`positional_embedding_selector` for details.
    2. Using global indices. See :meth:`positional_embedding_indexing` for details.

    Parameters
    ----------
    img_resolution : Union[List[int], int]
        The resolution of the input/output image. Can be a single int for square
        images or a list [height, width] for rectangular images.
    in_channels : int
        Number of channels in the input image.
    out_channels : int
        Number of channels in the output image.
    label_dim : int, optional
        Number of class labels; 0 indicates an unconditional model. By default 0.
    augment_dim : int, optional
        Dimensionality of augmentation labels; 0 means no augmentation. By default 0.
    model_channels : int, optional
        Base multiplier for the number of channels across the network. By default 128.
    channel_mult : List[int], optional
        Per-resolution multipliers for the number of channels. By default [1,2,2,2,2].
    channel_mult_emb : int, optional
        Multiplier for the dimensionality of the embedding vector. By default 4.
    num_blocks : int, optional
        Number of residual blocks per resolution. By default 4.
    attn_resolutions : List[int], optional
        Resolutions at which self-attention layers are applied. By default [28].
    dropout : float, optional
        Dropout probability applied to intermediate activations. By default 0.13.
    label_dropout : float, optional
        Dropout probability of class labels for classifier-free guidance.
        By default 0.0.
    embedding_type : str, optional
        Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++.
        By default 'positional'.
    channel_mult_noise : int, optional
        Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1.
    encoder_type : str, optional
        Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++,
        'skip' for skip connections. By default 'standard'.
    decoder_type : str, optional
        Decoder architecture: 'standard' or 'skip' for skip connections.
        By default 'standard'.
    resample_filter : List[int], optional
        Resampling filter coefficients: [1,1] for DDPM++, [1,3,3,1] for NCSN++.
        By default [1,1].
    gridtype : str, optional
        Type of positional grid to use: 'sinusoidal', 'learnable', 'linear', or
        'test'. Controls how positional information is encoded. By default
        'sinusoidal'.
    N_grid_channels : int, optional
        Number of channels in the positional embedding grid. For 'sinusoidal' must
        be 4 or a multiple of 4. For 'linear' must be 2. By default 4.
    checkpoint_level : int, optional
        Number of layers that should use gradient checkpointing (0 disables
        checkpointing). Higher values trade memory for computation. By default 0.
    additive_pos_embed : bool, optional
        If True, adds a learned positional embedding after the first convolution
        layer. Used in StormCast model. By default False.
    use_apex_gn : bool, optional
        Whether to use Apex GroupNorm for NHWC layout. Must be set to False on CPU.
        Defaults to False.
    act : str, optional
        The activation function to use when fusing activation with GroupNorm.
        Defaults to "silu".
    profile_mode : bool, optional
        Whether to enable all nvtx annotations during profiling. Defaults to False.
    amp_mode : bool, optional
        Whether mixed-precision (AMP) training is enabled. Defaults to False.
    lead_time_mode : bool, optional
        Whether to run SongUNet with lead time embedding. Defaults to False.
    lead_time_channels : int, optional
        Number of channels in the lead time embedding. These are learned embeddings
        that encode temporal forecast information. By default None.
    lead_time_steps : int, optional
        Number of discrete lead time steps to support. Each step gets its own
        learned embedding vector. By default 9.
    prob_channels : List[int], optional
        Indices of probability output channels that should use softmax activation.
        Used for classification outputs. By default an empty list.

    Note
    ----
    Equivalent to the original implementation by Song et al., available at
    https://github.com/yang-song/score_sde_pytorch

    Example
    -------
    >>> import torch
    >>> from physicsnemo.models.diffusion.song_unet import SongUNetPosEmbd
    >>> from physicsnemo.utils.patching import GridPatching2D
    >>>
    >>> # Model initialization - in_channels must include both original input channels (2)
    >>> # and the positional embedding channels (N_grid_channels=4 by default)
    >>> model = SongUNetPosEmbd(img_resolution=16, in_channels=2+4, out_channels=2)
    >>> noise_labels = torch.randn([1])
    >>> class_labels = torch.randint(0, 1, (1, 1))
    >>> # The input has only the original 2 channels - positional embeddings are
    >>> # added automatically inside the forward method
    >>> input_image = torch.ones([1, 2, 16, 16])
    >>> output_image = model(input_image, noise_labels, class_labels)
    >>> output_image.shape
    torch.Size([1, 2, 16, 16])
    >>>
    >>> # Using a global index to select all positional embeddings
    >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(16, 16))
    >>> global_index = patching.global_index(batch_size=1)
    >>> output_image = model(
    ...     input_image, noise_labels, class_labels,
    ...     global_index=global_index
    ... )
    >>> output_image.shape
    torch.Size([1, 2, 16, 16])
    >>>
    >>> # Using a custom embedding selector to select all positional embeddings
    >>> def patch_embedding_selector(emb):
    ...     return patching.apply(emb[None].expand(1, -1, -1, -1))
    >>> output_image = model(
    ...     input_image, noise_labels, class_labels,
    ...     embedding_selector=patch_embedding_selector
    ... )
    >>> output_image.shape
    torch.Size([1, 2, 16, 16])
    """

    def __init__(
        self,
        img_resolution: Union[List[int], int],
        in_channels: int,
        out_channels: int,
        label_dim: int = 0,
        augment_dim: int = 0,
        model_channels: int = 128,
        channel_mult: List[int] = [1, 2, 2, 2, 2],
        channel_mult_emb: int = 4,
        num_blocks: int = 4,
        attn_resolutions: List[int] = [28],
        dropout: float = 0.13,
        label_dropout: float = 0.0,
        embedding_type: str = "positional",
        channel_mult_noise: int = 1,
        encoder_type: str = "standard",
        decoder_type: str = "standard",
        resample_filter: List[int] = [1, 1],
        gridtype: str = "sinusoidal",
        N_grid_channels: int = 4,
        checkpoint_level: int = 0,
        additive_pos_embed: bool = False,
        use_apex_gn: bool = False,
        act: str = "silu",
        profile_mode: bool = False,
        amp_mode: bool = False,
        lead_time_mode: bool = False,
        lead_time_channels: int = None,
        lead_time_steps: int = 9,
        prob_channels: List[int] = [],
    ):
        super().__init__(
            img_resolution,
            in_channels,
            out_channels,
            label_dim,
            augment_dim,
            model_channels,
            channel_mult,
            channel_mult_emb,
            num_blocks,
            attn_resolutions,
            dropout,
            label_dropout,
            embedding_type,
            channel_mult_noise,
            encoder_type,
            decoder_type,
            resample_filter,
            checkpoint_level,
            additive_pos_embed,
            use_apex_gn,
            act,
            profile_mode,
            amp_mode,
        )

        self.gridtype = gridtype
        self.N_grid_channels = N_grid_channels
        if self.gridtype == "learnable":
            self.pos_embd = self._get_positional_embedding()
        else:
            self.register_buffer("pos_embd", self._get_positional_embedding().float())
        self.lead_time_mode = lead_time_mode
        if self.lead_time_mode:
            self.lead_time_channels = lead_time_channels
            self.lead_time_steps = lead_time_steps
            self.lt_embd = self._get_lead_time_embedding()
            self.prob_channels = prob_channels
            if self.prob_channels:
                self.scalar = torch.nn.Parameter(
                    torch.ones((1, len(self.prob_channels), 1, 1))
                )
    def forward(
        self,
        x,
        noise_labels,
        class_labels,
        global_index: Optional[torch.Tensor] = None,
        embedding_selector: Optional[Callable] = None,
        augment_labels=None,
        lead_time_label=None,
    ):
        with nvtx.annotate(
            message="SongUNetPosEmbd", color="blue"
        ) if self.profile_mode else contextlib.nullcontext():
            if embedding_selector is not None and global_index is not None:
                raise ValueError(
                    "Cannot provide both embedding_selector and global_index. "
                    "embedding_selector is the preferred approach for better efficiency."
                )
            if self.pos_embd is not None and x.dtype != self.pos_embd.dtype:
                self.pos_embd = self.pos_embd.to(x.dtype)

            # Append positional embedding to input conditioning
            if self.pos_embd is not None:
                # Select positional embeddings with a selector function
                if embedding_selector is not None:
                    selected_pos_embd = self.positional_embedding_selector(
                        x, embedding_selector
                    )
                # Select positional embeddings using global indices (selects all
                # embeddings if global_index is None)
                else:
                    selected_pos_embd = self.positional_embedding_indexing(
                        x, global_index=global_index, lead_time_label=lead_time_label
                    )
                x = torch.cat((x, selected_pos_embd), dim=1)

            out = super().forward(x, noise_labels, class_labels, augment_labels)

            if self.lead_time_mode:
                # In training mode the model outputs logits and CrossEntropyLoss
                # applies the softmax; in eval mode the model outputs probabilities.
                all_channels = list(range(out.shape[1]))
                scalar_channels = [
                    item for item in all_channels if item not in self.prob_channels
                ]
                if self.prob_channels and (not self.training):
                    out_final = torch.cat(
                        (
                            out[:, scalar_channels],
                            (out[:, self.prob_channels] * self.scalar).softmax(dim=1),
                        ),
                        dim=1,
                    )
                elif self.prob_channels and self.training:
                    out_final = torch.cat(
                        (
                            out[:, scalar_channels],
                            (out[:, self.prob_channels] * self.scalar),
                        ),
                        dim=1,
                    )
                else:
                    out_final = out
                return out_final
            return out
    def positional_embedding_indexing(
        self,
        x: torch.Tensor,
        global_index: Optional[torch.Tensor] = None,
        lead_time_label=None,
    ) -> torch.Tensor:
        """Select positional embeddings using global indices.

        This method either uses global indices to select specific embeddings or
        expands the embeddings for the full input when no indices are provided.
        Typically used in patch-based training, where the batch dimension contains
        multiple patches extracted from a larger image.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of shape (B, C, H, W), used to determine batch size and
            device.
        global_index : Optional[torch.Tensor]
            Optional tensor of indices for selecting embeddings. These should
            correspond to the spatial indices of the batch elements in the input
            tensor x. When provided, should have shape (P, 2, H, W) where the
            second dimension contains y,x coordinates (indices of the positional
            embedding grid).

        Returns
        -------
        torch.Tensor
            Selected positional embeddings with shape:
            - If global_index is provided: (B, N_pe, H, W)
            - If global_index is None: (B, N_pe, H_pe, W_pe)
            where N_pe is the number of positional embedding channels, and H_pe
            and W_pe are the height and width of the positional embedding grid.

        Example
        -------
        >>> # Create global indices using patching utility:
        >>> from physicsnemo.utils.patching import GridPatching2D
        >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8))
        >>> global_index = patching.global_index(batch_size=3)
        >>> print(global_index.shape)
        torch.Size([4, 2, 8, 8])

        See Also
        --------
        :meth:`physicsnemo.utils.patching.RandomPatching2D.global_index`
            For generating random patch indices.
        :meth:`physicsnemo.utils.patching.GridPatching2D.global_index`
            For generating deterministic grid-based patch indices.
        See these methods for possible ways to generate the global_index parameter.
        """
        # If no global indices are provided, select all embeddings and expand
        # to match the batch size of the input
        if self.pos_embd is not None and x.dtype != self.pos_embd.dtype:
            self.pos_embd = self.pos_embd.to(x.dtype)

        if global_index is None:
            if self.lead_time_mode:
                selected_pos_embd = []
                if self.pos_embd is not None:
                    selected_pos_embd.append(
                        self.pos_embd[None].expand((x.shape[0], -1, -1, -1))
                    )
                if self.lt_embd is not None:
                    selected_pos_embd.append(
                        torch.reshape(
                            self.lt_embd[lead_time_label.int()],
                            (
                                x.shape[0],
                                self.lead_time_channels,
                                self.img_shape_y,
                                self.img_shape_x,
                            ),
                        )
                    )
                if len(selected_pos_embd) > 0:
                    selected_pos_embd = torch.cat(selected_pos_embd, dim=1)
            else:
                selected_pos_embd = self.pos_embd[None].expand(
                    (x.shape[0], -1, -1, -1)
                )  # (B, N_pe, H, W)
        else:
            P = global_index.shape[0]
            B = x.shape[0] // P
            H = global_index.shape[2]
            W = global_index.shape[3]

            global_index = torch.reshape(
                torch.permute(global_index, (1, 0, 2, 3)), (2, -1)
            )  # (P, 2, H, W) to (2, P*H*W)
            selected_pos_embd = self.pos_embd[
                :, global_index[0], global_index[1]
            ]  # (N_pe, P*H*W)
            selected_pos_embd = torch.permute(
                torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], P, H, W)),
                (1, 0, 2, 3),
            )  # (P, N_pe, H, W)
            selected_pos_embd = selected_pos_embd.repeat(
                B, 1, 1, 1
            )  # (B*P, N_pe, H, W)

            # Append positional and lead-time embeddings to input conditioning
            if self.lead_time_mode:
                embeds = []
                if self.pos_embd is not None:
                    embeds.append(selected_pos_embd)
                if self.lt_embd is not None:
                    lt_embds = self.lt_embd[
                        lead_time_label.int()
                    ]  # (B, lead_time_channels, img_shape_y, img_shape_x)
                    selected_lt_pos_embd = lt_embds[
                        :, :, global_index[0], global_index[1]
                    ]  # (B, N_lt, P*H*W)
                    selected_lt_pos_embd = torch.reshape(
                        torch.permute(
                            torch.reshape(
                                selected_lt_pos_embd,
                                (B, self.lead_time_channels, P, H, W),
                            ),
                            (0, 2, 1, 3, 4),
                        ).contiguous(),
                        (B * P, self.lead_time_channels, H, W),
                    )  # (B*P, N_lt, H, W)
                    embeds.append(selected_lt_pos_embd)

                if len(embeds) > 0:
                    selected_pos_embd = torch.cat(embeds, dim=1)

        return selected_pos_embd
    def positional_embedding_selector(
        self,
        x: torch.Tensor,
        embedding_selector: Callable[[torch.Tensor], torch.Tensor],
    ) -> torch.Tensor:
        """Select positional embeddings using a selector function.

        Similar to positional_embedding_indexing, but uses a selector function to
        select the embeddings. This method provides a more efficient way to select
        embeddings for batches of data. Typically used with patch-based processing,
        where the batch dimension contains multiple patches extracted from a larger
        image.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of shape (B, C, H, W), only used to determine dtype and
            device.
        embedding_selector : Callable
            Function that takes as input an embedding tensor of shape
            (N_pe, H_pe, W_pe) and returns selected embeddings with shape
            (batch_size, N_pe, H, W). Each selected embedding should correspond to
            the positional information of each batch element in x. For patch-based
            processing, typically this should be based on
            :meth:`physicsnemo.utils.patching.BasePatching2D.apply` to maintain
            consistency with patch extraction.

        Returns
        -------
        torch.Tensor
            Selected positional embeddings with shape (B, N_pe, H, W), where N_pe
            is the number of positional embedding channels.

        Example
        -------
        >>> # Define a selector function with a patching utility:
        >>> from physicsnemo.utils.patching import GridPatching2D
        >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8))
        >>> batch_size = 4
        >>> def embedding_selector(emb):
        ...     return patching.apply(emb[None].expand(batch_size, -1, -1, -1))

        See Also
        --------
        :meth:`physicsnemo.utils.patching.BasePatching2D.apply`
            For the base patching method typically used in embedding_selector.
        """
        if self.pos_embd is not None and x.dtype != self.pos_embd.dtype:
            self.pos_embd = self.pos_embd.to(x.dtype)
        return embedding_selector(self.pos_embd)  # (B, N_pe, H, W)

    def _get_positional_embedding(self):
        if self.N_grid_channels == 0:
            return None
        elif self.gridtype == "learnable":
            grid = torch.nn.Parameter(
                torch.randn(self.N_grid_channels, self.img_shape_y, self.img_shape_x)
            )  # (N_grid_channels, img_shape_y, img_shape_x)
        elif self.gridtype == "linear":
            if self.N_grid_channels != 2:
                raise ValueError("N_grid_channels must be set to 2 for gridtype linear")
            x = np.meshgrid(np.linspace(-1, 1, self.img_shape_y))
            y = np.meshgrid(np.linspace(-1, 1, self.img_shape_x))
            grid_x, grid_y = np.meshgrid(y, x)
            grid = torch.from_numpy(
                np.stack((grid_x, grid_y), axis=0)
            )  # (2, img_shape_y, img_shape_x)
            grid.requires_grad = False
        elif self.gridtype == "sinusoidal" and self.N_grid_channels == 4:
            x1 = np.meshgrid(np.sin(np.linspace(0, 2 * np.pi, self.img_shape_y)))
            x2 = np.meshgrid(np.cos(np.linspace(0, 2 * np.pi, self.img_shape_y)))
            y1 = np.meshgrid(np.sin(np.linspace(0, 2 * np.pi, self.img_shape_x)))
            y2 = np.meshgrid(np.cos(np.linspace(0, 2 * np.pi, self.img_shape_x)))
            grid_x1, grid_y1 = np.meshgrid(y1, x1)
            grid_x2, grid_y2 = np.meshgrid(y2, x2)
            grid = torch.squeeze(
                torch.from_numpy(
                    np.expand_dims(
                        np.stack((grid_x1, grid_y1, grid_x2, grid_y2), axis=0), axis=0
                    )
                )
            )  # (4, img_shape_y, img_shape_x)
            grid.requires_grad = False
        elif self.gridtype == "sinusoidal" and self.N_grid_channels != 4:
            if self.N_grid_channels % 4 != 0:
                raise ValueError("N_grid_channels must be a multiple of 4")
            num_freq = self.N_grid_channels // 4
            freq_bands = 2.0 ** np.linspace(0.0, num_freq, num=num_freq)
            grid_list = []
            grid_x, grid_y = np.meshgrid(
                np.linspace(0, 2 * np.pi, self.img_shape_x),
                np.linspace(0, 2 * np.pi, self.img_shape_y),
            )
            for freq in freq_bands:
                for p_fn in [np.sin, np.cos]:
                    grid_list.append(p_fn(grid_x * freq))
                    grid_list.append(p_fn(grid_y * freq))
            grid = torch.from_numpy(
                np.stack(grid_list, axis=0)
            )  # (N_grid_channels, img_shape_y, img_shape_x)
            grid.requires_grad = False
        elif self.gridtype == "test" and self.N_grid_channels == 2:
            idx_x = torch.arange(self.img_shape_y)
            idx_y = torch.arange(self.img_shape_x)
            mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y)
            grid = torch.stack(
                (mesh_x, mesh_y), dim=0
            )  # (2, img_shape_y, img_shape_x)
        else:
            raise ValueError("Gridtype not supported.")
        return grid

    def _get_lead_time_embedding(self):
        if (self.lead_time_steps is None) or (self.lead_time_channels is None):
            return None
        grid = torch.nn.Parameter(
            torch.randn(
                self.lead_time_steps,
                self.lead_time_channels,
                self.img_shape_y,
                self.img_shape_x,
            )
        )  # (lead_time_steps, lead_time_channels, img_shape_y, img_shape_x)
        return grid
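
The shapes flowing through positional_embedding_indexing can be verified directly. A sketch that mirrors the doctest above (GridPatching2D is assumed to behave as documented there; everything else follows the method's docstring):

# Sketch: manual inspection of the embedding-indexing path for 4 patches
# of an image batch of size 1 (B=1, P=4).
# from physicsnemo.utils.patching import GridPatching2D
#
# model = SongUNetPosEmbd(img_resolution=16, in_channels=2 + 4, out_channels=2)
# patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8))
# global_index = patching.global_index(batch_size=1)  # (4, 2, 8, 8): 4 patches
#
# x_patches = torch.ones([4, 2, 8, 8])  # batch dimension holds B*P = 4 patches
# emb = model.positional_embedding_indexing(x_patches, global_index=global_index)
# print(emb.shape)  # torch.Size([4, 4, 8, 8]): (B*P, N_grid_channels, H, W)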

class SongUNetPosLtEmbd(SongUNetPosEmbd):
    """
    This model is adapted from SongUNetPosEmbd, with the incorporation of lead-time
    aware embeddings. The lead-time embedding is activated by setting the
    `lead_time_channels` and `lead_time_steps` parameters.

    Like SongUNetPosEmbd, this model provides two methods for selecting positional
    embeddings:

    1. Using a selector function (preferred method). See
       :meth:`positional_embedding_selector` for details.
    2. Using global indices. See :meth:`positional_embedding_indexing` for details.

    Parameters
    ----------
    img_resolution : Union[List[int], int]
        The resolution of the input/output image. Can be a single int for square
        images or a list [height, width] for rectangular images.
    in_channels : int
        Number of channels in the input image.
    out_channels : int
        Number of channels in the output image.
    label_dim : int, optional
        Number of class labels; 0 indicates an unconditional model. By default 0.
    augment_dim : int, optional
        Dimensionality of augmentation labels; 0 means no augmentation. By default 0.
    model_channels : int, optional
        Base multiplier for the number of channels across the network. By default 128.
    channel_mult : List[int], optional
        Per-resolution multipliers for the number of channels. By default [1,2,2,2,2].
    channel_mult_emb : int, optional
        Multiplier for the dimensionality of the embedding vector. By default 4.
    num_blocks : int, optional
        Number of residual blocks per resolution. By default 4.
    attn_resolutions : List[int], optional
        Resolutions at which self-attention layers are applied. By default [28].
    dropout : float, optional
        Dropout probability applied to intermediate activations. By default 0.13.
    label_dropout : float, optional
        Dropout probability of class labels for classifier-free guidance.
        By default 0.0.
    embedding_type : str, optional
        Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++.
        By default 'positional'.
    channel_mult_noise : int, optional
        Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1.
    encoder_type : str, optional
        Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++,
        'skip' for skip connections. By default 'standard'.
    decoder_type : str, optional
        Decoder architecture: 'standard' or 'skip' for skip connections.
        By default 'standard'.
    resample_filter : List[int], optional
        Resampling filter coefficients: [1,1] for DDPM++, [1,3,3,1] for NCSN++.
        By default [1,1].
    gridtype : str, optional
        Type of positional grid to use: 'sinusoidal', 'learnable', 'linear', or
        'test'. Controls how positional information is encoded. By default
        'sinusoidal'.
    N_grid_channels : int, optional
        Number of channels in the positional embedding grid. For 'sinusoidal' must
        be 4 or a multiple of 4. For 'linear' must be 2. By default 4.
    lead_time_channels : int, optional
        Number of channels in the lead time embedding. These are learned embeddings
        that encode temporal forecast information. By default None.
    lead_time_steps : int, optional
        Number of discrete lead time steps to support. Each step gets its own
        learned embedding vector. By default 9.
    prob_channels : List[int], optional
        Indices of probability output channels that should use softmax activation.
        Used for classification outputs. By default an empty list.
    checkpoint_level : int, optional
        Number of layers that should use gradient checkpointing (0 disables
        checkpointing). Higher values trade memory for computation. By default 0.
    additive_pos_embed : bool, optional
        If True, adds a learned positional embedding after the first convolution
        layer. Used in StormCast model. By default False.
    use_apex_gn : bool, optional
        Whether to use Apex GroupNorm for NHWC layout. Must be set to False on CPU.
        Defaults to False.
    act : str, optional
        The activation function to use when fusing activation with GroupNorm.
        Defaults to "silu".
    profile_mode : bool, optional
        Whether to enable all nvtx annotations during profiling. Defaults to False.
    amp_mode : bool, optional
        Whether mixed-precision (AMP) training is enabled. Defaults to False.

    Note
    ----
    Equivalent to the original implementation by Song et al., available at
    https://github.com/yang-song/score_sde_pytorch

    Example
    -------
    >>> import torch
    >>> from physicsnemo.models.diffusion.song_unet import SongUNetPosLtEmbd
    >>> from physicsnemo.utils.patching import GridPatching2D
    >>>
    >>> # Model initialization - in_channels must include original input channels (2),
    >>> # positional embedding channels (N_grid_channels=4 by default) and
    >>> # lead time embedding channels (4)
    >>> model = SongUNetPosLtEmbd(
    ...     img_resolution=16, in_channels=2+4+4, out_channels=2,
    ...     lead_time_channels=4, lead_time_steps=9
    ... )
    >>> noise_labels = torch.randn([1])
    >>> class_labels = torch.randint(0, 1, (1, 1))
    >>> # The input has only the original 2 channels - positional embeddings and
    >>> # lead time embeddings are added automatically inside the forward method
    >>> input_image = torch.ones([1, 2, 16, 16])
    >>> lead_time_label = torch.tensor([3])
    >>> output_image = model(
    ...     input_image, noise_labels, class_labels,
    ...     lead_time_label=lead_time_label
    ... )
    >>> output_image.shape
    torch.Size([1, 2, 16, 16])
    >>>
    >>> # Using global_index to select all the positional and lead time embeddings
    >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(16, 16))
    >>> global_index = patching.global_index(batch_size=1)
    >>> output_image = model(
    ...     input_image, noise_labels, class_labels,
    ...     lead_time_label=lead_time_label,
    ...     global_index=global_index
    ... )
    >>> output_image.shape
    torch.Size([1, 2, 16, 16])

    # NOTE: commented out doctest for embedding_selector due to compatibility issue
    # >>>
    # >>> # Using custom embedding selector to select all the positional and lead time embeddings
    # >>> def patch_embedding_selector(emb):
    # ...     return patching.apply(emb[None].expand(1, -1, -1, -1))
    # >>> output_image = model(
    # ...     input_image, noise_labels, class_labels,
    # ...     lead_time_label=lead_time_label,
    # ...     embedding_selector=patch_embedding_selector
    # ... )
    # >>> output_image.shape
    # torch.Size([1, 2, 16, 16])
    """

    def __init__(
        self,
        img_resolution: Union[List[int], int],
        in_channels: int,
        out_channels: int,
        label_dim: int = 0,
        augment_dim: int = 0,
        model_channels: int = 128,
        channel_mult: List[int] = [1, 2, 2, 2, 2],
        channel_mult_emb: int = 4,
        num_blocks: int = 4,
        attn_resolutions: List[int] = [28],
        dropout: float = 0.13,
        label_dropout: float = 0.0,
        embedding_type: str = "positional",
        channel_mult_noise: int = 1,
        encoder_type: str = "standard",
        decoder_type: str = "standard",
        resample_filter: List[int] = [1, 1],
        gridtype: str = "sinusoidal",
        N_grid_channels: int = 4,
        lead_time_channels: int = None,
        lead_time_steps: int = 9,
        prob_channels: List[int] = [],
        checkpoint_level: int = 0,
        additive_pos_embed: bool = False,
        use_apex_gn: bool = False,
        act: str = "silu",
        profile_mode: bool = False,
        amp_mode: bool = False,
    ):
        super().__init__(
            img_resolution,
            in_channels,
            out_channels,
            label_dim,
            augment_dim,
            model_channels,
            channel_mult,
            channel_mult_emb,
            num_blocks,
            attn_resolutions,
            dropout,
            label_dropout,
            embedding_type,
            channel_mult_noise,
            encoder_type,
            decoder_type,
            resample_filter,
            gridtype,
            N_grid_channels,
            checkpoint_level,
            additive_pos_embed,
            use_apex_gn,
            act,
            profile_mode,
            amp_mode,
            True,  # Note: lead_time_mode=True is enforced here
            lead_time_channels,
            lead_time_steps,
            prob_channels,
        )
    def forward(
        self,
        x,
        noise_labels,
        class_labels,
        lead_time_label=None,
        global_index: Optional[torch.Tensor] = None,
        embedding_selector: Optional[Callable] = None,
        augment_labels=None,
    ):
        return super().forward(
            x=x,
            noise_labels=noise_labels,
            class_labels=class_labels,
            global_index=global_index,
            embedding_selector=embedding_selector,
            augment_labels=augment_labels,
            lead_time_label=lead_time_label,
        )

# Nothing else is re-implemented, because everything is already in the parent SongUNetPosEmbd
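
The subtle part of the lead-time model is the prob_channels handling inherited from SongUNetPosEmbd: the listed output channels pass through a scaled softmax in eval mode but stay as logits in training mode (so CrossEntropyLoss can apply the softmax itself). A sketch with a hypothetical configuration (channel indices illustrative):

# Sketch: channels 1 and 2 are treated as a jointly-normalized
# classification output; channel 0 remains a scalar regression output.
# model = SongUNetPosLtEmbd(
#     img_resolution=16,
#     in_channels=2 + 4 + 4,  # data + positional (4) + lead-time (4) channels
#     out_channels=3,
#     lead_time_channels=4,
#     lead_time_steps=9,
#     prob_channels=[1, 2],
# )
# x = torch.ones([1, 2, 16, 16])
# noise_labels = torch.randn([1])
# class_labels = torch.randint(0, 1, (1, 1))
# lead_time_label = torch.tensor([3])
#
# model.eval()  # eval mode -> prob_channels are softmaxed
# with torch.no_grad():
#     out = model(x, noise_labels, class_labels, lead_time_label=lead_time_label)
# print(out.shape)  # torch.Size([1, 3, 16, 16])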
