# Source code for physicsnemo.models.dit.dit

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Dict, Literal, Optional, Tuple, Union

import torch
import torch.nn as nn
from jaxtyping import Float

from physicsnemo.core.meta import ModelMetaData
from physicsnemo.core.module import Module
from physicsnemo.nn import (
    ConditioningEmbedder,
    ConditioningEmbedderType,
    DetokenizerModuleBase,
    DiTBlock,
    TokenizerModuleBase,
    get_conditioning_embedder,
    get_detokenizer,
    get_tokenizer,
)


@dataclass
class MetaData(ModelMetaData):
    """Metadata flags attached to the DiT model.

    An instance is passed to the base ``Module`` constructor via
    ``meta=MetaData()``. Within this file only ``amp_gpu`` is read directly
    (it is forwarded as ``amp_mode`` when the conditioning embedder is built
    from a string identifier).

    NOTE(review): the exact meaning of each flag is defined by
    ``ModelMetaData`` and its consumers, which are not visible here; the
    grouping comments below are the only local description.
    """

    # Optimization
    jit: bool = False
    cuda_graphs: bool = False
    amp_cpu: bool = False
    amp_gpu: bool = True  # also used as the conditioning embedder's amp_mode
    torch_fx: bool = False
    # Data type
    bf16: bool = True
    # Inference
    onnx: bool = False
    # Physics informed
    func_torch: bool = False
    auto_grad: bool = False


class DiT(Module):
    r"""
    The Diffusion Transformer (DiT) model.

    Parameters
    ----------
    input_size : Union[int, Tuple[int]]
        Spatial dimensions of the input. If an integer is provided, the input
        is assumed to be on a square 2D domain. If a tuple is provided, the
        input is assumed to be on a multi-dimensional domain.
    in_channels : int
        The number of input channels.
    patch_size : Union[int, Tuple[int]], optional, default=(8, 8)
        The size of each image patch. If an integer is provided, a square 2D
        patch is assumed. If a tuple is provided, a multi-dimensional patch is
        assumed.
    tokenizer : Union[Literal["patch_embed_2d", "hpx_patch_embed"], Module], optional, default="patch_embed_2d"
        The tokenizer to use. Either a string in
        ``{"patch_embed_2d", "hpx_patch_embed"}`` or an instantiated
        PhysicsNeMo :class:`~physicsnemo.core.Module` implementing
        :class:`~physicsnemo.nn.TokenizerModuleBase`, with forward accepting
        input of shape :math:`(B, C, *\text{spatial\_dims})` and returning
        :math:`(B, L, D)`.
    detokenizer : Union[Literal["proj_reshape_2d", "hpx_patch_detokenizer"], Module], optional, default="proj_reshape_2d"
        The detokenizer to use. Either a string in
        ``{"proj_reshape_2d", "hpx_patch_detokenizer"}`` or an instantiated
        PhysicsNeMo :class:`~physicsnemo.core.Module` implementing
        :class:`~physicsnemo.nn.DetokenizerModuleBase`, with forward accepting
        :math:`(B, L, D)` and :math:`(B, D)` and returning
        :math:`(B, C, *\text{spatial\_dims})`.
    out_channels : Union[None, int], optional, default=None
        The number of output channels. If ``None``, set to ``in_channels``.
    hidden_size : int, optional, default=384
        The dimensionality of the transformer embeddings.
    depth : int, optional, default=12
        The number of transformer blocks.
    num_heads : int, optional, default=8
        The number of attention heads.
    mlp_ratio : float, optional, default=4.0
        The ratio of the MLP hidden dimension to the embedding dimension.
    attention_backend : Literal["timm", "transformer_engine", "natten2d"], optional, default="timm"
        The attention backend to use. See :class:`~physicsnemo.nn.DiTBlock`
        for a description of each built-in backend.
    layernorm_backend : Literal["apex", "torch"], optional, default="torch"
        If ``"apex"``, uses FusedLayerNorm from apex. If ``"torch"``, uses
        :class:`torch.nn.LayerNorm`. Also passed to
        :class:`~physicsnemo.nn.Natten2DSelfAttention` when ``qk_norm=True``.
    condition_dim : int, optional, default=None
        Dimensionality of conditioning. If ``None``, the model is
        unconditional.
    conditioning_embedder : Literal["dit", "edm", "zero"] or ConditioningEmbedder, optional, default="dit"
        The conditioning embedder type or an instantiated
        :class:`~physicsnemo.nn.ConditioningEmbedder`.
    dit_initialization : bool, optional, default=True
        If ``True``, applies DiT-specific initialization.
    conditioning_embedder_kwargs : Dict[str, Any], optional, default={}
        Additional keyword arguments for the conditioning embedder.
    tokenizer_kwargs : Dict[str, Any], optional, default={}
        Additional keyword arguments for the tokenizer module.
    detokenizer_kwargs : Dict[str, Any], optional, default={}
        Additional keyword arguments for the detokenizer module.
    block_kwargs : Dict[str, Any], optional, default={}
        Additional keyword arguments for the DiTBlock modules.
    attn_kwargs : Dict[str, Any], optional, default={}
        Additional keyword arguments for the attention module constructor
        (e.g. ``na2d_kwargs`` when using ``attention_backend="natten2d"``).
    drop_path_rates : list[float], optional, default=None
        DropPath (stochastic depth) rates, one per block. Must have length
        equal to ``depth``. If ``None``, no drop path is applied.
    force_tokenization_fp32 : bool, optional, default=False
        If ``True``, forces tokenization and de-tokenization to run in fp32.

    Raises
    ------
    ValueError
        If ``attention_backend``, ``layernorm_backend``, a string
        ``tokenizer`` or a string ``detokenizer`` is not one of the supported
        values, or if ``drop_path_rates`` does not have ``depth`` entries.
    TypeError
        If a non-string ``tokenizer``, ``detokenizer`` or
        ``conditioning_embedder`` is not an instance of the expected base
        class.

    Forward
    -------
    x : torch.Tensor
        Spatial inputs of shape :math:`(N, C, *\text{spatial\_dims})`.
        ``spatial_dims`` is determined by ``input_size``.
    t : torch.Tensor
        Diffusion timesteps of shape :math:`(N,)`.
    condition : Optional[torch.Tensor]
        Conditions of shape :math:`(N, d)`.
    p_dropout : Optional[Union[float, torch.Tensor]], optional
        Dropout probability for the intermediate dropout (pre-attention) in
        each DiTBlock. If ``None``, no dropout. If a scalar, same for all
        samples; if a tensor, shape :math:`(B,)` for per-sample dropout.
    attn_kwargs : Dict[str, Any], optional
        Additional keyword arguments passed to the attention module's forward
        method.
    tokenizer_kwargs : Dict[str, Any], optional
        Additional keyword arguments passed to the tokenizer's forward method.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape
        :math:`(N, \text{out\_channels}, *\text{spatial\_dims})`.

    Notes
    -----
    Reference: Peebles, W., & Xie, S. (2023). Scalable diffusion models with
    transformers. In Proceedings of the IEEE/CVF International Conference on
    Computer Vision (pp. 4195-4205).

    Examples
    --------
    >>> model = DiT(
    ...     input_size=(32, 64),
    ...     patch_size=4,
    ...     in_channels=3,
    ...     out_channels=3,
    ...     condition_dim=8,
    ... )
    >>> x = torch.randn(2, 3, 32, 64)
    >>> t = torch.randint(0, 1000, (2,))
    >>> condition = torch.randn(2, 8)
    >>> output = model(x, t, condition)
    >>> output.shape
    torch.Size([2, 3, 32, 64])
    """

    __model_checkpoint_version__ = "0.2.0"
    __supported_model_checkpoint_version__ = {
        "0.1.0": "Automatically converting legacy DiT checkpoint timestep / conditioning embedder arguments.",
    }

    @classmethod
    def _backward_compat_arg_mapper(
        cls, version: str, args: Dict[str, Any]
    ) -> Dict[str, Any]:
        r"""
        Map arguments from legacy checkpoints to the current format.

        Parameters
        ----------
        version : str
            Version of the checkpoint being loaded.
        args : Dict[str, Any]
            Arguments dictionary from the checkpoint.

        Returns
        -------
        Dict[str, Any]
            Updated arguments dictionary compatible with the current version.
        """
        args = super()._backward_compat_arg_mapper(version, args)
        if version != "0.1.0":
            return args

        # 0.1.0 checkpoints stored the embedder options under the old name.
        if "timestep_embed_kwargs" in args:
            args["conditioning_embedder_kwargs"] = args.pop("timestep_embed_kwargs")
        return args

    # NOTE(review): the ``{}`` defaults below are kept as-is so the constructor
    # signature stays unchanged. They are only unpacked (``**kwargs``) and are
    # never mutated in this class, so sharing them across calls is harmless.
    def __init__(
        self,
        input_size: Union[int, Tuple[int]],
        in_channels: int,
        patch_size: Union[int, Tuple[int]] = (8, 8),
        tokenizer: Union[
            Literal["patch_embed_2d", "hpx_patch_embed"], Module
        ] = "patch_embed_2d",
        detokenizer: Union[
            Literal["proj_reshape_2d", "hpx_patch_detokenizer"], Module
        ] = "proj_reshape_2d",
        out_channels: Optional[int] = None,
        hidden_size: int = 384,
        depth: int = 12,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        attention_backend: Literal["timm", "transformer_engine", "natten2d"] = "timm",
        layernorm_backend: Literal["apex", "torch"] = "torch",
        condition_dim: Optional[int] = None,
        conditioning_embedder: Literal["dit", "edm", "zero"]
        | ConditioningEmbedder = "dit",
        dit_initialization: bool = True,
        conditioning_embedder_kwargs: Dict[str, Any] = {},
        tokenizer_kwargs: Dict[str, Any] = {},
        detokenizer_kwargs: Dict[str, Any] = {},
        block_kwargs: Dict[str, Any] = {},
        attn_kwargs: Dict[str, Any] = {},
        drop_path_rates: list[float] | None = None,
        force_tokenization_fp32: bool = False,
    ):
        super().__init__(meta=MetaData())

        # Scalars are promoted to a square 2D (size, size) pair.
        self.input_size = (
            input_size
            if isinstance(input_size, (tuple, list))
            else (input_size, input_size)
        )
        self.in_channels = in_channels
        # A falsy ``out_channels`` (``None``) falls back to ``in_channels``.
        if out_channels:
            self.out_channels = out_channels
        else:
            self.out_channels = in_channels
        self.patch_size = (
            patch_size
            if isinstance(patch_size, (tuple, list))
            else (patch_size, patch_size)
        )
        self.num_heads = num_heads
        self.condition_dim = condition_dim

        # Input validation (done before any option-dependent setup).
        if attention_backend not in ["timm", "transformer_engine", "natten2d"]:
            raise ValueError(
                "attention_backend must be one of 'timm', 'transformer_engine', 'natten2d'"
            )
        if layernorm_backend not in ["apex", "torch"]:
            raise ValueError("layernorm_backend must be one of 'apex', 'torch'")
        if isinstance(tokenizer, str) and tokenizer not in [
            "patch_embed_2d",
            "hpx_patch_embed",
        ]:
            raise ValueError("tokenizer must be 'patch_embed_2d' or 'hpx_patch_embed'")
        if isinstance(detokenizer, str) and detokenizer not in [
            "proj_reshape_2d",
            "hpx_patch_detokenizer",
        ]:
            raise ValueError(
                "detokenizer must be 'proj_reshape_2d' or 'hpx_patch_detokenizer'"
            )

        # The natten2d backend is given the token-grid size (input size divided
        # by patch size along the first two dims) on every forward call.
        if attention_backend == "natten2d":
            latent_hw = (
                self.input_size[0] // self.patch_size[0],
                self.input_size[1] // self.patch_size[1],
            )
            self.attn_kwargs_forward = {"latent_hw": latent_hw}
        else:
            self.attn_kwargs_forward = {}

        # Tokenizer module: accept string or pre-instantiated PhysicsNeMo Module
        if isinstance(tokenizer, str):
            self.tokenizer = get_tokenizer(
                input_size=self.input_size,
                patch_size=self.patch_size,
                in_channels=in_channels,
                hidden_size=hidden_size,
                tokenizer=tokenizer,
                **tokenizer_kwargs,
            )
        else:
            if not isinstance(tokenizer, TokenizerModuleBase):
                raise TypeError(
                    "tokenizer must be a string or a physicsnemo.core.Module instance subclassing physicsnemo.nn.TokenizerModuleBase"
                )
            self.tokenizer = tokenizer

        # Conditioning embedder: accept a string identifier (looked up by its
        # upper-cased name in ConditioningEmbedderType) or a pre-instantiated
        # ConditioningEmbedder.
        if isinstance(conditioning_embedder, str):
            self.conditioning_embedder = get_conditioning_embedder(
                ConditioningEmbedderType[conditioning_embedder.upper()],
                hidden_size=hidden_size,
                condition_dim=condition_dim or 0,
                amp_mode=self.meta.amp_gpu,
                **conditioning_embedder_kwargs,
            )
        else:
            if not isinstance(conditioning_embedder, ConditioningEmbedder):
                raise TypeError(
                    "conditioning_embedder must be a ConditioningEmbedderType or a Module implementing the ConditioningEmbedder protocol"
                )
            self.conditioning_embedder = conditioning_embedder

        # Detokenizer module: accept string or pre-instantiated PhysicsNeMo Module
        if isinstance(detokenizer, str):
            self.detokenizer = get_detokenizer(
                input_size=self.input_size,
                patch_size=self.patch_size,
                out_channels=self.out_channels,
                hidden_size=hidden_size,
                layernorm_backend=layernorm_backend,
                detokenizer=detokenizer,
                **detokenizer_kwargs,
            )
        else:
            if not isinstance(detokenizer, DetokenizerModuleBase):
                raise TypeError(
                    "detokenizer must be a string or a physicsnemo.core.Module instance subclassing physicsnemo.nn.DetokenizerModuleBase"
                )
            self.detokenizer = detokenizer

        # Validate drop_path_rates
        if drop_path_rates is None:
            drop_path_rates = [0.0] * depth
        elif len(drop_path_rates) != depth:
            raise ValueError(
                f"drop_path_rates length ({len(drop_path_rates)}) must match DiT depth ({depth})"
            )

        self.blocks = nn.ModuleList(
            [
                DiTBlock(
                    hidden_size,
                    num_heads,
                    attention_backend=attention_backend,
                    layernorm_backend=layernorm_backend,
                    mlp_ratio=mlp_ratio,
                    drop_path=drop_path,
                    condition_embed_dim=self.conditioning_embedder.output_dim,
                    **block_kwargs,
                    **attn_kwargs,
                )
                for drop_path in drop_path_rates
            ]
        )

        if dit_initialization:
            self.initialize_weights()

        self.force_tokenization_fp32 = force_tokenization_fp32
        self.register_load_state_dict_pre_hook(self._migrate_legacy_checkpoint)

    @staticmethod
    def _migrate_legacy_checkpoint(
        module,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        r"""Remap legacy state_dict keys where timestep embedder was at root.

        Previous versions stored the timestep embedder at the model root (e.g.
        ``t_embedder.mlp.0.weight``). The current model nests it under
        ``conditioning_embedder`` (e.g.
        ``conditioning_embedder.t_embedder.mlp.0.weight``). This pre-hook
        rewrites those keys in-place so loading succeeds. It also drops the
        positional embedding ``freqs`` key, which is not part of the
        state_dict anymore due to the usage of ``persistent=False``.

        Keys are matched relative to ``prefix`` so the migration also works
        when this model is loaded as a submodule of a larger module, and so
        that unrelated ``t_embedder.*`` keys outside this module's namespace
        are left untouched.

        Parameters
        ----------
        module : torch.nn.Module
            The module being loaded (unused; required by
            ``register_load_state_dict_pre_hook``).
        state_dict : dict
            State dict being loaded; modified in-place.
        prefix : str
            Key prefix of this module inside ``state_dict`` (empty string when
            the model itself is the root being loaded).
        local_metadata : dict, optional
            Local metadata (unused).
        strict : bool
            Whether strict loading is requested (unused).
        missing_keys : list of str
            List of missing keys (unused).
        unexpected_keys : list of str
            List of unexpected keys (unused).
        error_msgs : list of str
            Error messages (unused).

        Returns
        -------
        None
            Modifies ``state_dict`` in-place; no return value.
        """
        legacy_prefix = f"{prefix}t_embedder."
        new_prefix = f"{prefix}conditioning_embedder.t_embedder."
        legacy_freqs_key = legacy_prefix + "freqs"

        # Iterate over a snapshot of keys to avoid mutating dict while iterating
        for old_key in list(state_dict.keys()):
            if not old_key.startswith(legacy_prefix):
                continue
            if old_key == legacy_freqs_key:
                del state_dict[old_key]
                continue
            new_key = new_prefix + old_key[len(legacy_prefix) :]
            # Never overwrite an entry that already uses the new layout.
            if new_key not in state_dict:
                state_dict[new_key] = state_dict.pop(old_key)

    def initialize_weights(self):
        r"""Apply DiT-specific weight initialization.

        Applies Xavier uniform to linear layers (zeroing their biases), then
        delegates to the tokenizer, the detokenizer, and each block's
        ``initialize_weights``.

        Returns
        -------
        None
            Modifies module parameters in-place.
        """

        # Apply a basic Xavier uniform initialization to all linear layers.
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Delegate custom weight initialization to the tokenizer, detokenizer,
        # and blocks. These run after the generic pass above.
        self.tokenizer.initialize_weights()
        self.detokenizer.initialize_weights()
        for block in self.blocks:
            block.initialize_weights()

    @staticmethod
    def _run_in_fp32(fn, x, *args, **kwargs):
        r"""Call ``fn(x, ...)`` in fp32 with CUDA autocast disabled.

        ``x`` is cast to ``torch.float32`` for the call and the result is cast
        back to the original dtype of ``x``. Extra arguments are forwarded
        unchanged (they are not cast).
        """
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        with torch.autocast(device_type="cuda", enabled=False):
            out = fn(x, *args, **kwargs)
        return out.to(orig_dtype)

    # NOTE(review): ``attn_kwargs`` / ``tokenizer_kwargs`` keep their ``{}``
    # defaults for signature compatibility; they are only read, never mutated.
    def forward(
        self,
        x: Float[torch.Tensor, "batch in_channels *spatial_dims"],
        t: Float[torch.Tensor, " batch"],
        condition: Optional[Float[torch.Tensor, "batch condition_dim"]] = None,
        p_dropout: Optional[float | Float[torch.Tensor, " batch"]] = None,
        attn_kwargs: Dict[str, Any] = {},
        tokenizer_kwargs: Dict[str, Any] = {},
    ) -> Float[torch.Tensor, "batch out_channels *spatial_dims"]:
        r"""Run the model; see the class docstring (``Forward`` / ``Outputs``)."""
        # Tokenize: (B, C, H, W) -> (B, L, D)
        if self.force_tokenization_fp32:
            x = self._run_in_fp32(self.tokenizer, x, **tokenizer_kwargs)
        else:
            x = self.tokenizer(x, **tokenizer_kwargs)

        # Compute conditioning embedding
        c = self.conditioning_embedder(t, condition=condition)  # (B, D)

        # Call-time attention kwargs override the ones fixed at construction.
        block_attn_kwargs = {**self.attn_kwargs_forward, **attn_kwargs}
        for block in self.blocks:
            x = block(
                x,
                c,
                p_dropout=p_dropout,
                attn_kwargs={**block_attn_kwargs},
            )  # (B, L, D)

        # De-tokenize: (B, L, D) -> (B, C, H, W)
        if self.force_tokenization_fp32:
            x = self._run_in_fp32(self.detokenizer, x, c)
        else:
            x = self.detokenizer(x, c)

        return x