Diffusion Transformer (DiT)#

The Diffusion Transformer (DiT) is a Vision Transformer backbone for diffusion models. It operates on image patches via a patchify embedding, processes tokens with a sequence of transformer blocks conditioned through adaptive layer normalization (adaLN-Zero), and reconstructs the output via an unpatchify step.

DiT was introduced in Scalable Diffusion Models with Transformers (Peebles & Xie, 2023).
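Since the patchify tokenizer maps each non-overlapping patch to one token, the sequence length and raw per-patch dimension follow directly from input_size, patch_size, and in_channels. A minimal sketch of that arithmetic (the helper name is illustrative, not part of the PhysicsNeMo API):

```python
from math import prod

def dit_token_shape(input_size, patch_size, in_channels):
    """Token count and raw patch dimension for a patchify tokenizer.

    Each non-overlapping patch becomes one token, so the sequence
    length is the product over spatial axes of (input dim // patch dim),
    and each raw token flattens in_channels * prod(patch_size) values.
    """
    n_tokens = prod(s // p for s, p in zip(input_size, patch_size))
    patch_dim = in_channels * prod(patch_size)
    return n_tokens, patch_dim
```

For the default patch_size=(8, 8) on a 3-channel 32×64 input, this gives 32 tokens of raw dimension 192, which the tokenizer then projects to hidden_size.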

DiT#

class physicsnemo.models.dit.DiT(*args, **kwargs)[source]#

Bases: Module

The Diffusion Transformer (DiT) model.

Parameters:
  • input_size (Union[int, Tuple[int]]) – Spatial dimensions of the input. If an integer is provided, the input is assumed to be on a square 2D domain. If a tuple is provided, the input is assumed to be on a multi-dimensional domain.

  • in_channels (int) – The number of input channels.

  • patch_size (Union[int, Tuple[int]], optional, default=(8, 8)) – The size of each image patch. If an integer is provided, a square 2D patch is assumed. If a tuple is provided, a multi-dimensional patch is assumed.

  • tokenizer (Union[Literal["patch_embed_2d", "hpx_patch_embed"], Module], optional, default="patch_embed_2d") – The tokenizer to use. Either a string in {"patch_embed_2d", "hpx_patch_embed"} or an instantiated PhysicsNeMo Module implementing TokenizerModuleBase, with forward accepting input of shape \((B, C, *\text{spatial\_dims})\) and returning \((B, L, D)\).

  • detokenizer (Union[Literal["proj_reshape_2d", "hpx_patch_detokenizer"], Module], optional, default="proj_reshape_2d") – The detokenizer to use. Either a string in {"proj_reshape_2d", "hpx_patch_detokenizer"} or an instantiated PhysicsNeMo Module implementing DetokenizerModuleBase, with forward accepting tokens of shape \((B, L, D)\) and a conditioning embedding of shape \((B, D)\), and returning \((B, C, *\text{spatial\_dims})\).

  • out_channels (Union[None, int], optional, default=None) – The number of output channels. If None, set to in_channels.

  • hidden_size (int, optional, default=384) – The dimensionality of the transformer embeddings.

  • depth (int, optional, default=12) – The number of transformer blocks.

  • num_heads (int, optional, default=8) – The number of attention heads.

  • mlp_ratio (float, optional, default=4.0) – The ratio of the MLP hidden dimension to the embedding dimension.

  • attention_backend (Literal["timm", "transformer_engine", "natten2d"], optional, default="timm") – The attention backend to use. See DiTBlock for a description of each built-in backend.

  • layernorm_backend (Literal["apex", "torch"], optional, default="torch") – If "apex", uses FusedLayerNorm from apex. If "torch", uses torch.nn.LayerNorm. Also passed to Natten2DSelfAttention when qk_norm=True.

  • condition_dim (Optional[int], optional, default=None) – Dimensionality of the conditioning vector. If None, the model is unconditional.

  • dit_initialization (bool, optional, default=True) – If True, applies DiT-specific initialization.

  • conditioning_embedder (Literal["dit", "edm", "zero"] or ConditioningEmbedder, optional, default="dit") – The conditioning embedder type or an instantiated ConditioningEmbedder.

  • conditioning_embedder_kwargs (Dict[str, Any], optional, default={}) – Additional keyword arguments for the conditioning embedder.

  • tokenizer_kwargs (Dict[str, Any], optional, default={}) – Additional keyword arguments for the tokenizer module.

  • detokenizer_kwargs (Dict[str, Any], optional, default={}) – Additional keyword arguments for the detokenizer module.

  • block_kwargs (Dict[str, Any], optional, default={}) – Additional keyword arguments for the DiTBlock modules.

  • attn_kwargs (Dict[str, Any], optional, default={}) – Additional keyword arguments for the attention module constructor (e.g. na2d_kwargs when using attention_backend="natten2d").

  • drop_path_rates (list[float], optional, default=None) – DropPath (stochastic depth) rates, one per block. Must have length equal to depth. If None, no drop path is applied.

  • force_tokenization_fp32 (bool, optional, default=False) – If True, forces tokenization and de-tokenization to run in fp32.
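drop_path_rates must contain exactly depth entries. A linearly increasing stochastic-depth schedule, as is common for ViT-style models, is one way to build it (an assumption here; DiT only requires that the list length equal depth):

```python
def linear_drop_path_rates(depth, max_rate):
    """Linearly increasing stochastic-depth rates, one per transformer
    block: 0.0 for the first block, ramping up to max_rate for the last.
    """
    if depth == 1:
        return [0.0]
    return [i * max_rate / (depth - 1) for i in range(depth)]
```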

Forward:
  • x (torch.Tensor) – Spatial inputs of shape \((N, C, *\text{spatial\_dims})\). spatial_dims is determined by input_size.

  • t (torch.Tensor) – Diffusion timesteps of shape \((N,)\).

  • condition (Optional[torch.Tensor]) – Conditions of shape \((N, d)\).

  • p_dropout (Optional[Union[float, torch.Tensor]], optional) – Dropout probability for the intermediate dropout (pre-attention) in each DiTBlock. If None, no dropout is applied. If a scalar, the same rate is used for all samples; if a tensor, it must have shape \((N,)\) for per-sample dropout.

  • attn_kwargs (Dict[str, Any], optional) – Additional keyword arguments passed to the attention module’s forward method.

  • tokenizer_kwargs (Dict[str, Any], optional) – Additional keyword arguments passed to the tokenizer’s forward method.
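The p_dropout argument accepts either a shared scalar rate or a per-sample tensor. A pure-Python sketch of the implied inverted-dropout semantics, in which a kept activation is rescaled by 1/(1 - p) so the expectation is preserved (illustrative only; the model constructs its masks internally):

```python
import random

def per_sample_dropout_mask(batch, n_tokens, p):
    """Inverted-dropout keep-masks. `p` may be a scalar rate shared by
    all samples, or a per-sample sequence of length `batch` (mirroring
    DiT's scalar vs. shape-(N,) p_dropout).
    """
    ps = [p] * batch if isinstance(p, (int, float)) else list(p)
    masks = []
    for pb in ps:
        keep = 1.0 - pb
        # Kept entries are scaled by 1/keep so the expectation is unchanged.
        masks.append([1.0 / keep if random.random() < keep else 0.0
                      for _ in range(n_tokens)])
    return masks
```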

Outputs:

torch.Tensor – Output tensor of shape \((N, \text{out\_channels}, *\text{spatial\_dims})\).

Notes

Reference: Peebles, W., & Xie, S. (2023). Scalable diffusion models with transformers. In Proceedings of the IEEE/CVF International Conference on Computer Vision (pp. 4195-4205).

Examples

>>> import torch
>>> from physicsnemo.models.dit import DiT
>>> model = DiT(
...     input_size=(32, 64),
...     patch_size=4,
...     in_channels=3,
...     out_channels=3,
...     condition_dim=8,
... )
>>> x = torch.randn(2, 3, 32, 64)
>>> t = torch.randint(0, 1000, (2,))
>>> condition = torch.randn(2, 8)
>>> output = model(x, t, condition)
>>> output.shape
torch.Size([2, 3, 32, 64])

initialize_weights()[source]#

Apply DiT-specific weight initialization.

Applies Xavier uniform initialization to linear layers, then delegates to the tokenizer, detokenizer, and each block’s initialize_weights.

Parameters:

None – Operates on the module’s own state (self).

Returns:

Modifies module parameters in-place.

Return type:

None
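For reference, Xavier (Glorot) uniform initialization draws each weight from \(U(-a, a)\) with \(a = \text{gain} \cdot \sqrt{6 / (\text{fan\_in} + \text{fan\_out})}\). A small helper computing that bound (illustrative, not part of the API):

```python
import math

def xavier_uniform_bound(fan_in, fan_out, gain=1.0):
    """Half-width `a` of the Xavier/Glorot uniform range U(-a, a):
    a = gain * sqrt(6 / (fan_in + fan_out)).
    """
    return gain * math.sqrt(6.0 / (fan_in + fan_out))
```

Larger fan-in/fan-out sums shrink the bound, keeping activation variance roughly constant across layers of different widths.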