nemo_automodel.components.models.diffusion_gemma.model


NeMo Automodel support for diffusion_gemma (block diffusion).

Architecture (design v2 item 1) — ONE shared parameter stack run twice:

  • Run the decoder layers once causally over the clean full sequence to build a per-layer read-only KV cache (the “encoder” KV). The text encoder is causal because use_bidirectional_attention == "vision" (not "all"); a single causal pass over the clean full sequence reproduces the per-position KV that block-by-block inference builds.
  • Run the same layers once bidirectionally over the noised canvas (the response region), each layer concatenating [encoder_KV ; canvas_KV] on the key axis and using the block-causal training mask from attention_mask.build_block_diffusion_training_mask.
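The two passes can be sketched with a toy single-head attention layer. This is a minimal illustration, not the model's code: the projections `wq`/`wk`/`wv`, the `attend` helper and all sizes are hypothetical, and the all-zero canvas mask is a placeholder for the block-causal mask that build_block_diffusion_training_mask produces.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, S, C, H = 2, 6, 6, 8  # batch, clean length, canvas length, head dim (toy sizes)

# One shared set of projections, reused by both passes (stand-in for a decoder layer).
wq, wk, wv = (torch.nn.Linear(H, H, bias=False) for _ in range(3))


def attend(q, k, v, mask):
    # mask is additive: 0 keeps a key, a large negative value hides it.
    scores = q @ k.transpose(-1, -2) / H**0.5 + mask
    return F.softmax(scores, dim=-1) @ v


clean = torch.randn(B, S, H)   # embeddings of the clean full sequence
canvas = torch.randn(B, C, H)  # embeddings of the noised canvas

# Pass 1: causal over the clean sequence; its K and V become the "encoder" KV.
enc_k, enc_v = wk(clean), wv(clean)
causal = torch.full((S, S), torch.finfo(clean.dtype).min).triu(1)
enc_out = attend(wq(clean), enc_k, enc_v, causal)

# Pass 2: same weights over the canvas; keys/values are [encoder_KV ; canvas_KV].
k = torch.cat([enc_k, wk(canvas)], dim=1)
v = torch.cat([enc_v, wv(canvas)], dim=1)
canvas_mask = torch.zeros(C, S + C)  # placeholder: every canvas query sees every key
out = attend(wq(canvas), k, v, canvas_mask)
```

The concatenation happens on the key axis (`dim=1` here), so the canvas queries attend over `S + C` keys while the output keeps the canvas length.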

A single shared stack (rather than tied-but-separate encoder/decoder modules) keeps the model visible to Automodel’s MoE FSDP grad-sync (MoEFSDPSyncMixin / _iter_fsdp_modules assume a single model.layers stack with block.moe.experts) and avoids FSDP2 double-sharding tied storage. The lm_head is tied to model.embed_tokens.

Self-conditioning (decoder-only, Analog-Bits two-pass) is encapsulated in the training forward so the recipe still calls model(**batch) once.

Module Contents

Classes

| Name | Description |
| --- | --- |
| DiffusionGemmaBackbone | Single shared Gemma MoE transformer stack run causally then bidirectionally. |
| DiffusionGemmaForBlockDiffusion | Block-diffusion Gemma MoE model for SFT. |
| DiffusionGemmaOutput | Training forward output. |

Functions

| Name | Description |
| --- | --- |
| _make_causal_additive_mask | Build an additive causal (optionally sliding-window) mask for the encoder. |
| _make_missing | - |

Data

ModelClass

_TRANSFORMERS_AVAILABLE

API

class nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaBackbone(
config: transformers.models.diffusion_gemma.configuration_diffusion_gemma.DiffusionGemmaTextConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None
)

Bases: Module

Single shared Gemma MoE transformer stack run causally then bidirectionally.

Exposes layers (a ModuleDict keyed by string layer index), embed_tokens, norm, self_conditioning and rotary_emb. The layers / embed_tokens names are what MoEFSDPSyncMixin and the FSDP2 sharding path key on.

embed_tokens
layer_types
= config.layer_types
layers
moe_config
= _build_moe_config(config, moe_config)
norm
padding_idx
= getattr(config, 'pad_token_id', None)
rotary_emb
= DiffusionGemmaTextRotaryEmbedding(config)
self_conditioning
= DiffusionGemmaSelfConditioning(config)
vocab_size
= config.vocab_size
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaBackbone._position_embeddings(
hidden_states: torch.Tensor,
position_ids: torch.Tensor
) -> dict
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaBackbone.decode(
canvas_ids: torch.Tensor,
encoder_kv: list[tuple[torch.Tensor, torch.Tensor]],
decoder_position_ids: torch.Tensor,
decoder_masks: dict,
decoder_padding_mask: torch.Tensor | None = None,
self_conditioning_logits: torch.Tensor | None = None,
self_conditioning_mask: torch.Tensor | None = None
) -> torch.Tensor

Bidirectional pass over the noised canvas with cross-attention to the encoder KV cache. Returns the final (normed) hidden states.

self_conditioning_mask ([B] bool, training only) gates the self-cond branch PER EXAMPLE: examples with False get a zeroed soft-embedding (identical to the no-self-cond path), so a single always-on pass-1 can serve Google’s per-example conditioned / zero-conditioned mix.
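The per-example gate can be pictured as a broadcast multiply. The sketch below assumes the soft-embedding is a probability-weighted mix of token embeddings; that is an assumption for illustration, and `soft_embedding` and the sizes are hypothetical rather than taken from DiffusionGemmaSelfConditioning.

```python
import torch

torch.manual_seed(0)
B, C, V, H = 3, 4, 10, 8  # batch, canvas length, vocab, hidden (toy sizes)
embed = torch.nn.Embedding(V, H)


def soft_embedding(logits, mask):
    # Assumed form: probability-weighted mix of token embeddings, [B, C, H].
    soft = logits.softmax(dim=-1) @ embed.weight
    # [B] bool -> [B, 1, 1]; examples with False contribute an all-zero tensor.
    return soft * mask.view(-1, 1, 1).to(soft.dtype)


logits = torch.randn(B, C, V)               # pass-1 logits for every example
mask = torch.tensor([True, False, True])    # per-example coins
soft = soft_embedding(logits, mask)
```

Because the gate is applied after pass-1 has already run for the whole batch, every example follows the same compute path and only the conditioning signal differs.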

nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaBackbone.encode(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
padding_mask: torch.Tensor | None,
return_hidden: bool = False
)

Causal pass over the clean full sequence -> per-layer (K, V) cache.

When return_hidden is True, also returns the final normed hidden states [B, S, H] (so the caller can produce the encoder’s autoregressive logits for the co-trained AR loss). Default False keeps the KV-only contract used by inference and the parity/leakage tests.

nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaBackbone.forward(
mode: str,
input_ids: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
return_hidden: bool = False,
canvas_ids: torch.Tensor | None = None,
encoder_kv: list[tuple[torch.Tensor, torch.Tensor]] | None = None,
decoder_position_ids: torch.Tensor | None = None,
decoder_masks: dict | None = None,
decoder_padding_mask: torch.Tensor | None = None,
self_conditioning_logits: torch.Tensor | None = None,
self_conditioning_mask: torch.Tensor | None = None
) -> list[tuple[torch.Tensor, torch.Tensor]] | torch.Tensor

Dispatch encode/decode through nn.Module.__call__ for FSDP hooks.

FSDP2 hooks are installed on module calls, not on arbitrary helper methods. The block-diffusion top-level forward must therefore enter the backbone via self.model(...) so root-owned parameters such as self_conditioning and the final norm are gathered before use.
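The difference between calling a helper method and calling the module can be seen with an ordinary forward pre-hook. This is a toy stand-in (`ToyBackbone` is hypothetical, and a plain `register_forward_pre_hook` takes the place of the FSDP2 hooks), intended only to show why dispatching through `forward` matters.

```python
import torch
from torch import nn


class ToyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(4)  # stands in for a root-owned parameter

    def encode(self, x):
        return self.norm(x)

    def decode(self, x):
        return self.norm(x) * 2

    def forward(self, mode, x):
        # Single entry point so that module-call hooks run for both paths.
        if mode == "encode":
            return self.encode(x)
        if mode == "decode":
            return self.decode(x)
        raise ValueError(f"unknown mode: {mode!r}")


calls = []
backbone = ToyBackbone()
backbone.register_forward_pre_hook(lambda module, args: calls.append("hook"))

x = torch.randn(2, 4)
backbone.encode(x)          # direct helper call: the pre-hook does not fire
helper_calls = len(calls)
backbone("encode", x)       # module call: the pre-hook fires
module_calls = len(calls)
```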

nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaBackbone.get_input_embeddings() -> torch.nn.Module
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaBackbone.set_input_embeddings(
value: torch.nn.Module
) -> None
class nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion(
config: transformers.models.diffusion_gemma.configuration_diffusion_gemma.DiffusionGemmaConfig,
moe_config: 'MoEConfig | None' = None,
backend: 'BackendConfig | None' = None,
canvas_length: int | None = None,
self_conditioning: bool | None = None,
freeze_router: bool | None = None,
kwargs: typing.Any = {}
)

Bases: HFCheckpointingMixin, MoEFSDPSyncMixin, PreTrainedModel

Block-diffusion Gemma MoE model for SFT.

Inherits the Automodel checkpointing + MoE-FSDP machinery. The MoE backbone is reused from gemma4_moe; the diffusion training forward and the two-pass self-conditioning are new. See module docstring for the single-shared-stack design.

forward is the SFT training forward. A generation/inference loop (encode the prompt once, then iteratively denoise canvas blocks reusing the KV cache, with the self-conditioning recycling loop) is deferred; the model.encode / model.decode building blocks are the reusable pieces for it, and forward already accepts an explicit self_conditioning_logits for the per-step inference contract.

_keep_in_fp32_modules
= ['rotary_emb']
_no_split_modules
= ['DiffusionGemmaMoEDecoderLayer']
_tied_weights_keys
= ['lm_head.weight']
backend
= backend or BackendConfig()
base_model_prefix
= 'model'
canvas_length
= int(getattr(config, 'canvas_length', 256))
final_logit_softcapping
= text_config.final_logit_softcapping
freeze_router
lm_head
model
moe_config
= self.model.moe_config
self_conditioning
state_dict_adapter
vocab_size
= text_config.vocab_size
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion._softcap_logits(
hidden_states: torch.Tensor
) -> torch.Tensor
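
This page does not spell out the soft-capping formula. The sketch below assumes the tanh form used by other Gemma models, `cap * tanh(logits / cap)`; the `softcap` helper is hypothetical and only covers the capping step, not the projection from hidden states to logits.

```python
import torch


def softcap(logits, cap):
    # Assumed Gemma-style soft-capping: squashes logits smoothly into (-cap, cap).
    if cap is None:
        return logits
    return cap * torch.tanh(logits / cap)


logits = torch.tensor([-100.0, 0.0, 1.0, 100.0])
capped = softcap(logits, cap=30.0)
uncapped = softcap(logits, cap=None)
```

Small logits pass through almost unchanged, while large ones saturate just below the cap.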
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion.forward(
input_ids: torch.Tensor | None = None,
canvas_ids: torch.Tensor | None = None,
self_conditioning_logits: torch.Tensor | None = None,
encoder_position_ids: torch.Tensor | None = None,
encoder_padding_mask: torch.Tensor | None = None,
decoder_position_ids: torch.Tensor | None = None,
decoder_attention_mask: dict | None = None,
decoder_padding_mask: torch.Tensor | None = None,
do_self_conditioning: torch.Tensor | bool | None = None,
kwargs: typing.Any = {}
) -> 'DiffusionGemmaOutput'

Training forward — single shared stack run twice + two-pass self-cond.

Parameters:

input_ids
torch.Tensor | None. Defaults to None.

Clean full sequence (prompt + response), [B, S]. Run causally to build the read-only encoder KV cache.

canvas_ids
torch.Tensor | None. Defaults to None.

Noised response/canvas tokens, [B, canvas_len]. Run bidirectionally with the block-causal mask.

self_conditioning_logits
torch.Tensor | None. Defaults to None.

If given (inference / external loop), used directly and the two-pass logic is skipped. During training the two-pass scheme generates the self-cond signal internally.

encoder_position_ids
torch.Tensor | None. Defaults to None.

Position ids for the encoder pass ([B, S]). Defaults to arange(S).

encoder_padding_mask
torch.Tensor | None. Defaults to None.

True at padded encoder positions ([B, S]).

decoder_position_ids
torch.Tensor | None. Defaults to None.

Position ids for the canvas ([B, canvas_len]). Must be the canvas tokens’ absolute positions so their query RoPE aligns with the encoder key RoPE of the clean copies. In the v1 full-sequence-canvas layout (canvas_len == S) this is arange(S) (the default); a response-window canvas would use prefix_length + arange(canvas_len) per example.
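
The two layouts differ only in the per-example offset. A minimal sketch, assuming a hypothetical helper `canvas_position_ids` and a `prefix_lengths` tensor of shape [B] holding each example's prompt length:

```python
import torch


def canvas_position_ids(prefix_lengths, canvas_len):
    # prefix_lengths: [B] long tensor, number of tokens before each example's canvas.
    # Returns absolute positions [B, canvas_len] = prefix_length + arange(canvas_len).
    return prefix_lengths.unsqueeze(1) + torch.arange(canvas_len).unsqueeze(0)


# Full-sequence canvas (canvas_len == S): offset 0, i.e. arange(S) for every example.
full = canvas_position_ids(torch.zeros(2, dtype=torch.long), 5)

# Response-window canvas: each example starts at its own prefix length.
window = canvas_position_ids(torch.tensor([3, 7]), 4)
```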

decoder_attention_mask
dict | None. Defaults to None.

Dict {"full_attention", "sliding_attention"} of additive block-causal masks (from build_block_diffusion_training_mask). Required for training; built by the recipe’s _forward_backward_step override.

decoder_padding_mask
torch.Tensor | None. Defaults to None.

True at padded canvas positions ([B, canvas_len]). Used to keep padded rows out of MoE routing.

do_self_conditioning
torch.Tensor | bool | None. Defaults to None.

Per-example self-conditioning coins: a [B] bool tensor (a scalar bool is broadcast to the batch). During training, pass-1 always runs, which keeps the FSDP collectives identical on every step (no rank desync) and stays correct for local_batch_size > 1. This mask then gates, per example, whether pass-2 consumes the self-cond signal; False yields a zeroed soft-embedding, i.e. no self-conditioning. The recipe supplies it via _decide_self_conditioning. Required during training: passing None would drop Google’s per-example mix, so it raises ValueError. Ignored outside training (eval / single pass).
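
The accepted input forms can be summarized with a small normalization routine. This is an illustrative sketch only: `normalize_coins` is a hypothetical helper, and the real forward may validate its inputs differently.

```python
import torch


def normalize_coins(do_self_conditioning, batch_size, training):
    if not training:
        return None  # the argument is ignored outside training
    if do_self_conditioning is None:
        raise ValueError("do_self_conditioning is required during training")
    if isinstance(do_self_conditioning, bool):
        # A scalar bool is broadcast to one coin per example.
        return torch.full((batch_size,), do_self_conditioning, dtype=torch.bool)
    return do_self_conditioning.to(torch.bool).reshape(batch_size)


broadcast = normalize_coins(True, batch_size=3, training=True)
per_example = normalize_coins(torch.tensor([1, 0, 1]), batch_size=3, training=True)
ignored = normalize_coins(None, batch_size=3, training=False)
```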

Returns: 'DiffusionGemmaOutput'

DiffusionGemmaOutput with canvas-only logits

nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion.freeze_router_params() -> None

Freeze the MoE router/gate (design v2 item 9).

Sets train_gate=False and requires_grad=False on the gate’s proj.weight and scale for every layer. Routing indices are already non-differentiable; per_expert_scale is folded into the (trainable) expert down_proj by the state-dict adapter, so the experts stay trainable. MoEFSDPSyncMixin keys on set_requires_gradient_sync, never requires_grad, so freezing the gate does not break grad-sync.
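
A minimal sketch of this freeze on toy modules. `ToyGate`, `ToyMoELayer` and `freeze_router` are hypothetical; only the attribute names `train_gate`, `proj.weight` and `scale` come from the description above.

```python
import torch
from torch import nn


class ToyGate(nn.Module):
    def __init__(self, hidden, n_experts):
        super().__init__()
        self.proj = nn.Linear(hidden, n_experts, bias=False)
        self.scale = nn.Parameter(torch.ones(hidden))
        self.train_gate = True


class ToyMoELayer(nn.Module):
    def __init__(self, hidden=8, n_experts=4):
        super().__init__()
        self.gate = ToyGate(hidden, n_experts)
        self.experts = nn.Linear(hidden, hidden)  # stand-in for the grouped experts


def freeze_router(layers):
    for layer in layers:
        gate = layer.gate
        gate.train_gate = False
        gate.proj.weight.requires_grad_(False)
        gate.scale.requires_grad_(False)


layers = nn.ModuleList([ToyMoELayer() for _ in range(2)])
freeze_router(layers)
```

Only the gate parameters are touched, so the expert weights keep `requires_grad=True` and continue to receive gradients.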

nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion.from_config(
config: transformers.models.diffusion_gemma.configuration_diffusion_gemma.DiffusionGemmaConfig,
moe_config: 'MoEConfig | None' = None,
backend: 'BackendConfig | None' = None,
kwargs: typing.Any = {}
) -> 'DiffusionGemmaForBlockDiffusion'
classmethod
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion.get_capabilities(
config: 'DiffusionGemmaConfig'
) -> 'ModelCapabilities'
classmethod

Parallelism support for the DiffusionGemma block-diffusion MoE.

Single variant: FSDP2 + Expert Parallelism are supported (validated at EP=8). TP is unsupported for the custom MoE; CP/PP are not supported for this encoder-decoder block-diffusion path.

nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion.get_input_embeddings() -> torch.nn.Module
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion.get_output_embeddings() -> torch.nn.Module
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion.initialize_weights(
buffer_device: torch.device | None = None,
dtype: torch.dtype | None = None
) -> None

Initialize grouped-expert parameters (other params init via HF post_init).

dtype defaults to the model’s configured torch_dtype rather than a hardcoded bfloat16. The meta/FSDP init path (checkpoint.checkpointing.initialize_model_weights) calls this with no dtype; a blanket self.to(torch.bfloat16) would materialize the whole model in bf16 before the checkpoint loads, silently defeating the fp32 master weights that model.torch_dtype: float32 configs request (leaving AdamW on bf16 params). Honor the requested dtype instead.
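
The dtype resolution described here might look roughly like the following. `resolve_init_dtype` is a hypothetical helper, and falling back to `torch.get_default_dtype()` when the config has no `torch_dtype` is an assumption made for the sketch.

```python
import torch
from types import SimpleNamespace


def resolve_init_dtype(config, dtype=None):
    # An explicit argument wins; otherwise honor the configured torch_dtype.
    if dtype is not None:
        return dtype
    configured = getattr(config, "torch_dtype", None)
    if isinstance(configured, str):
        configured = getattr(torch, configured)  # e.g. "float32" -> torch.float32
    # Assumed fallback when nothing is configured.
    return configured if configured is not None else torch.get_default_dtype()


fp32_config = SimpleNamespace(torch_dtype=torch.float32)
from_config = resolve_init_dtype(fp32_config)                  # no dtype passed
explicit = resolve_init_dtype(fp32_config, torch.bfloat16)     # caller override
from_string = resolve_init_dtype(SimpleNamespace(torch_dtype="bfloat16"))
```

With a float32 config and no explicit dtype, the parameters stay in fp32 rather than being cast to bf16 before the checkpoint loads.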

nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion.set_input_embeddings(
value: torch.nn.Module
) -> None
class nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaOutput(
logits: 'torch.Tensor',
encoder_logits: 'torch.Tensor | None' = None
)
Dataclass

Training forward output.

logits are the canvas-only (response) denoising logits [B, canvas_len, V]. encoder_logits are the causal encoder’s next-token logits over the clean full sequence [B, S, V] for the co-trained AR loss — None outside training (and when the AR loss is unused).

encoder_logits
'torch.Tensor | None' = None
logits
'torch.Tensor'
nemo_automodel.components.models.diffusion_gemma.model._make_causal_additive_mask(
seq_len: int,
padding_mask: torch.Tensor | None,
sliding_window: int | None,
batch_size: int,
device: torch.device,
dtype: torch.dtype
) -> torch.Tensor

Build an additive causal (optionally sliding-window) mask for the encoder.

Shape [B, 1, seq_len, seq_len]; 0 keep, finfo.min masked. padding_mask is [B, seq_len] with True at padding positions.
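
A sketch of a mask with this shape and convention. It is an independent illustration rather than the function's source: the window boundary (a query sees itself and the `sliding_window - 1` keys before it) and masking only padded key columns are assumptions.

```python
import torch


def causal_additive_mask(seq_len, padding_mask, sliding_window, batch_size, device, dtype):
    q = torch.arange(seq_len, device=device).unsqueeze(1)  # query positions (rows)
    k = torch.arange(seq_len, device=device).unsqueeze(0)  # key positions (columns)
    keep = k <= q                                          # causal: no future keys
    if sliding_window is not None:
        keep = keep & (q - k < sliding_window)             # assumed window boundary
    keep = keep.expand(batch_size, 1, seq_len, seq_len)
    if padding_mask is not None:
        # padding_mask is [B, seq_len], True at padding; hide padded key columns.
        keep = keep & ~padding_mask[:, None, None, :]
    mask = torch.zeros(batch_size, 1, seq_len, seq_len, device=device, dtype=dtype)
    return mask.masked_fill(~keep, torch.finfo(dtype).min)


cpu = torch.device("cpu")
windowed = causal_additive_mask(4, None, 2, 1, cpu, torch.float32)
pad = torch.tensor([[False, True, False]])
padded = causal_additive_mask(3, pad, None, 1, cpu, torch.float32)
```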

nemo_automodel.components.models.diffusion_gemma.model._make_missing(
name: str
)
nemo_automodel.components.models.diffusion_gemma.model.ModelClass = DiffusionGemmaForBlockDiffusion
nemo_automodel.components.models.diffusion_gemma.model._TRANSFORMERS_AVAILABLE = True