nemo_automodel.components.models.diffusion_gemma.model
NeMo Automodel support for diffusion_gemma (block diffusion).
Architecture (design v2 item 1) — ONE shared parameter stack run twice:
- Run the decoder layers once causally over the clean full sequence to build a per-layer read-only KV cache (the “encoder” KV). The text encoder is causal because use_bidirectional_attention == "vision" (not "all"); a single causal pass over the clean full sequence reproduces the per-position KV that block-by-block inference builds.
- Run the same layers once bidirectionally over the noised canvas (the response region), each layer concatenating [encoder_KV ; canvas_KV] on the key axis and using the block-causal training mask from attention_mask.build_block_diffusion_training_mask.
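A minimal sketch of the canvas-side attention in one layer, using one head and toy sizes. It assumes a BD3-LM-style block-causal rule in the v1 layout (canvas_len == S): a canvas query in block b sees clean encoder keys from blocks before b and canvas keys from its own block b. block_causal_mask and canvas_attention are hypothetical helpers; the mask actually produced by attention_mask.build_block_diffusion_training_mask may differ (for example in how the prompt region and sliding-window layers are handled).

```python
import torch
import torch.nn.functional as F


def block_causal_mask(seq_len: int, block_size: int,
                      dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Additive mask [seq_len, 2 * seq_len] over [encoder keys ; canvas keys].

    Hypothetical rule: canvas query i in block b = i // block_size keeps
    encoder (clean) keys in blocks < b and canvas (noised) keys in block b.
    """
    block = torch.arange(seq_len) // block_size
    q_block = block[:, None]            # [S, 1]
    k_block = block[None, :]            # [1, S]
    enc_keep = k_block < q_block        # clean blocks strictly before b
    can_keep = k_block == q_block       # own noised block, bidirectional
    keep = torch.cat([enc_keep, can_keep], dim=1)  # [S, 2S]
    mask = torch.zeros(keep.shape, dtype=dtype)
    return mask.masked_fill(~keep, torch.finfo(dtype).min)


def canvas_attention(q, canvas_k, canvas_v, enc_k, enc_v, mask):
    """q/k/v: [B, heads, S, D]; concatenates [encoder_KV ; canvas_KV] on the key axis."""
    k = torch.cat([enc_k, canvas_k], dim=2)  # [B, heads, 2S, D]
    v = torch.cat([enc_v, canvas_v], dim=2)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

Under this rule the first block's queries never read the encoder cache, so perturbing encoder values leaves their outputs unchanged while later blocks do change.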
A single shared stack (rather than tied-but-separate encoder/decoder modules)
keeps the model visible to AM’s MoE FSDP grad-sync (MoEFSDPSyncMixin /
_iter_fsdp_modules assume a single model.layers stack with
block.moe.experts) and avoids FSDP2 double-sharding tied storage. The
lm_head is tied to model.embed_tokens.
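Tying means the output projection and the input embedding share one Parameter, so the same [V, H] matrix maps tokens to embeddings and hidden states back to logits. A minimal PyTorch sketch with toy sizes (the standalone modules are illustrative, not how the model wires its own lm_head):

```python
import torch
from torch import nn

vocab_size, hidden_size = 32, 8
embed_tokens = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
lm_head.weight = embed_tokens.weight  # tie: one Parameter, used twice

tokens = torch.tensor([[1, 5, 7]])
hidden = embed_tokens(tokens)  # [1, 3, H]
logits = lm_head(hidden)       # [1, 3, V], computed with the same matrix

# parameters() de-duplicates shared tensors, so the tied matrix counts once.
both = nn.ModuleDict({"embed_tokens": embed_tokens, "lm_head": lm_head})
n_params = sum(p.numel() for p in both.parameters())
```

Because the matrix is registered once, the optimizer also updates it once per step.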
Self-conditioning (decoder-only, Analog-Bits two-pass) is encapsulated in the
training forward so the recipe still calls model(**batch) once.
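A minimal sketch of a two-pass self-conditioning step hidden behind a single call, assuming a toy denoiser, a stop-gradient on pass 1 (as in Analog Bits), and a probability-weighted soft embedding as the self-cond signal. ToyDenoiser and two_pass_forward are hypothetical; the encoder pass, attention masks, and MoE are omitted.

```python
import torch
from torch import nn


class ToyDenoiser(nn.Module):
    """Hypothetical stand-in for the canvas pass: token embedding plus a
    self-conditioning embedding, projected to vocabulary logits."""

    def __init__(self, vocab_size: int = 16, hidden_size: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.head = nn.Linear(hidden_size, vocab_size)

    def forward(self, canvas_ids, self_cond_embed):
        return self.head(self.embed(canvas_ids) + self_cond_embed)


def two_pass_forward(model: ToyDenoiser, canvas_ids: torch.Tensor,
                     self_conditioning_mask: torch.Tensor) -> torch.Tensor:
    batch, length = canvas_ids.shape
    zeros = torch.zeros(batch, length, model.embed.embedding_dim)
    # Pass 1 always runs, with no gradient (stop-gradient on the estimate).
    with torch.no_grad():
        probs = torch.softmax(model(canvas_ids, zeros), dim=-1)
        soft = probs @ model.embed.weight  # [B, L, H]
    # Per-example gate: False rows fall back to the zero (no self-cond) input.
    soft = torch.where(self_conditioning_mask[:, None, None], soft, zeros)
    # Pass 2 consumes the gated signal and is the only pass that gets gradients.
    return model(canvas_ids, soft)
```

The caller sees a single call; the extra pass and the per-example gating stay inside the function, which is the property the recipe relies on.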
API
Bases: Module
Single shared Gemma MoE transformer stack run causally then bidirectionally.
Exposes layers (a ModuleDict keyed by string layer index),
embed_tokens, norm, self_conditioning and rotary_emb. The
layers / embed_tokens names are what MoEFSDPSyncMixin and the
FSDP2 sharding path key on.
Bidirectional pass over the noised canvas with cross-attention to the encoder KV cache. Returns the final (normed) hidden states.
self_conditioning_mask ([B] bool, training only) gates the self-cond
branch PER EXAMPLE: examples with False get a zeroed soft-embedding
(identical to the no-self-cond path), so a single always-on pass-1 can serve
Google’s per-example conditioned / zero-conditioned mix.
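The per-example gate can be sketched as below, assuming the soft embedding is the pass-1 probability distribution multiplied by the embedding matrix (an assumption; the docstring only says "soft-embedding"). gated_soft_embedding is a hypothetical helper; it also accepts a scalar bool and broadcasts it to [B], as forward's self-conditioning coin parameter allows.

```python
import torch


def gated_soft_embedding(logits: torch.Tensor, embed_weight: torch.Tensor,
                         self_conditioning_mask) -> torch.Tensor:
    """logits: [B, L, V]; embed_weight: [V, H]; mask: [B] bool or a scalar bool."""
    probs = torch.softmax(logits.float(), dim=-1).to(embed_weight.dtype)
    soft = probs @ embed_weight  # [B, L, H]: expected embedding under pass 1
    mask = torch.as_tensor(self_conditioning_mask, dtype=torch.bool,
                           device=soft.device)
    mask = mask.expand(soft.shape[0])  # a scalar bool broadcasts to [B]
    # False rows get an all-zero soft embedding, i.e. the no-self-cond input.
    return torch.where(mask[:, None, None], soft, torch.zeros_like(soft))
```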
Causal pass over the clean full sequence -> per-layer (K, V) cache.
When return_hidden is True, also returns the final normed hidden
states [B, S, H] (so the caller can produce the encoder’s
autoregressive logits for the co-trained AR loss). Default False keeps
the KV-only contract used by inference and the parity/leakage tests.
Dispatch encode/decode through nn.Module.__call__ for FSDP hooks.
FSDP2 hooks are installed on module calls, not on arbitrary helper
methods. The block-diffusion top-level forward must therefore enter the
backbone via self.model(...) so root-owned parameters such as
self_conditioning and the final norm are gathered before use.
Bases: HFCheckpointingMixin, MoEFSDPSyncMixin, PreTrainedModel
Block-diffusion Gemma MoE model for SFT.
Inherits the AM checkpointing + MoE-FSDP machinery. The MoE backbone is
reused from gemma4_moe; the diffusion training forward and the two-pass
self-conditioning are new. See module docstring for the single-shared-stack
design.
forward is the SFT training forward. A generation/inference loop
(encode the prompt once, then iteratively denoise canvas blocks reusing the
KV cache, with the self-conditioning recycling loop) is deferred; the
model.encode / model.decode building blocks are the reusable pieces
for it, and forward already accepts an explicit self_conditioning_logits
for the per-step inference contract.
Training forward — single shared stack run twice + two-pass self-cond.
Parameters:
Clean full sequence (prompt + response), [B, S]. Run
causally to build the read-only encoder KV cache.
Noised response/canvas tokens, [B, canvas_len]. Run
bidirectionally with the block-causal mask.
If given (inference / external loop), used directly and the two-pass logic is skipped. During training the two-pass scheme generates the self-cond signal internally.
Position ids for the encoder pass ([B, S]).
Defaults to arange(S).
True at padded encoder positions ([B, S]).
Position ids for the canvas ([B, canvas_len]).
Must be the canvas tokens’ absolute positions so their query
RoPE aligns with the encoder key RoPE of the clean copies. In the
v1 full-sequence-canvas layout (canvas_len == S) this is
arange(S) (the default); a response-window canvas would use
prefix_length + arange(canvas_len) per example.
Dict keyed by "full_attention" and "sliding_attention"
of additive block-causal masks (from
build_block_diffusion_training_mask). Required for training;
built by the recipe’s _forward_backward_step override.
True at padded canvas positions
([B, canvas_len]). Used to keep padded rows out of MoE routing.
Per-example self-conditioning coins, a [B]
bool tensor (a scalar bool is broadcast). During training pass-1
always runs (constant FSDP collectives every step -> no rank
desync, and correct for local_batch_size > 1); this mask gates,
per example, whether pass-2 consumes the self-cond signal (False
-> zeroed soft-embed, i.e. no self-cond). The recipe supplies it via
_decide_self_conditioning. Required during training (None
would drop Google’s per-example mix -> ValueError); ignored
outside training (eval / single pass).
Returns: DiffusionGemmaOutput
DiffusionGemmaOutput with canvas-only logits
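The canvas position-id rule can be sketched as follows, assuming a hypothetical per-example prefix_length tensor; with prefix_length == 0 and canvas_len == S it reduces to the v1 default arange(S).

```python
import torch


def canvas_position_ids(prefix_length: torch.Tensor, canvas_len: int) -> torch.Tensor:
    """prefix_length: [B] long. Returns absolute positions [B, canvas_len]."""
    offsets = torch.arange(canvas_len, device=prefix_length.device)
    return prefix_length[:, None] + offsets[None, :]


# Response-window canvas: each example's canvas starts after its own prompt.
window = canvas_position_ids(torch.tensor([0, 3]), 4)
# v1 full-sequence canvas: prefix 0 and canvas_len == S gives arange(S).
full = canvas_position_ids(torch.zeros(2, dtype=torch.long), 6)
```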
Freeze the MoE router/gate (design v2 item 9).
Sets train_gate=False and requires_grad=False on the gate’s
proj.weight and scale for every layer. Routing indices are
already non-differentiable; per_expert_scale is folded into the
(trainable) expert down_proj by the state-dict adapter, so the
experts stay trainable. MoEFSDPSyncMixin keys on
set_requires_gradient_sync, never requires_grad, so freezing
the gate does not break grad-sync.
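A minimal sketch of the freeze loop over a toy ModuleDict of layers. The module layout (layer.moe.gate with proj and scale, layer.moe.experts) and all Toy* classes are assumptions that mirror the names in the docstring, not the real gemma4_moe classes.

```python
import torch
from torch import nn


class ToyGate(nn.Module):
    """Hypothetical router mirroring the docstring's proj.weight / scale names."""

    def __init__(self, hidden_size: int = 8, num_experts: int = 4):
        super().__init__()
        self.proj = nn.Linear(hidden_size, num_experts, bias=False)
        self.scale = nn.Parameter(torch.ones(hidden_size))
        self.train_gate = True


class ToyMoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.gate = ToyGate()
        self.experts = nn.Linear(8, 8)  # stands in for the grouped experts


class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.moe = ToyMoE()


def freeze_router(layers: nn.ModuleDict) -> None:
    for layer in layers.values():
        gate = layer.moe.gate
        gate.train_gate = False
        gate.proj.weight.requires_grad_(False)
        gate.scale.requires_grad_(False)


layers = nn.ModuleDict({str(i): ToyLayer() for i in range(2)})
freeze_router(layers)
```

Only the router parameters are touched; the expert weights keep requires_grad=True and stay trainable.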
Parallelism support for the DiffusionGemma block-diffusion MoE.
Single variant: FSDP2 + Expert Parallelism are supported (validated at EP=8). TP is unsupported for the custom MoE; CP/PP are not supported for this encoder-decoder block-diffusion path.
Initialize grouped-expert parameters (other params init via HF post_init).
dtype defaults to the model’s configured torch_dtype rather than a
hardcoded bfloat16. The meta/FSDP init path
(checkpoint.checkpointing.initialize_model_weights) calls this with no
dtype; a blanket self.to(torch.bfloat16) would materialize the whole model
in bf16 before the checkpoint loads, silently defeating the fp32 master weights
that model.torch_dtype: float32 configs request (leaving AdamW on bf16
params). Honor the requested dtype instead.
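The dtype rule can be sketched as a small resolver: an explicit dtype wins, otherwise the configured torch_dtype (which Hugging Face configs may store as a string such as "float32") is used. resolve_init_dtype is hypothetical, and falling back to torch.get_default_dtype() when neither is set is an assumption.

```python
from typing import Optional, Union

import torch
from torch import nn


def resolve_init_dtype(config_dtype: Union[str, torch.dtype, None],
                       dtype: Optional[torch.dtype] = None) -> torch.dtype:
    """Explicit dtype wins; otherwise honor the configured torch_dtype."""
    if dtype is not None:
        return dtype
    if isinstance(config_dtype, str):   # e.g. "float32" in a serialized config
        return getattr(torch, config_dtype)
    if isinstance(config_dtype, torch.dtype):
        return config_dtype
    return torch.get_default_dtype()    # assumption: no config value set


# Meta/FSDP init path: no dtype passed, so a float32 config keeps fp32 masters.
experts = nn.Linear(8, 16).to(resolve_init_dtype("float32"))
```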
Training forward output.
logits are the canvas-only (response) denoising logits [B, canvas_len, V].
encoder_logits are the causal encoder’s next-token logits over the clean
full sequence [B, S, V] for the co-trained AR loss — None outside
training (and when the AR loss is unused).
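Assuming the co-trained AR loss is standard next-token cross-entropy on encoder_logits over the clean sequence (shifted by one), it can be sketched as below. The boolean loss_mask (True where a target token should count) and the ignore_index=-100 convention are assumptions; the recipe's actual loss weighting and masking are not specified here.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def encoder_ar_loss(encoder_logits: torch.Tensor, input_ids: torch.Tensor,
                    loss_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """encoder_logits: [B, S, V]; input_ids: [B, S]; loss_mask: [B, S] bool."""
    logits = encoder_logits[:, :-1, :]        # position t predicts token t + 1
    targets = input_ids[:, 1:].clone()
    if loss_mask is not None:
        targets[~loss_mask[:, 1:]] = -100     # ignored by cross_entropy
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           targets.reshape(-1), ignore_index=-100)
```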
Build an additive causal (optionally sliding-window) mask for the encoder.
Shape [B, 1, seq_len, seq_len]; 0 keep, finfo.min masked.
padding_mask is [B, seq_len] with True at padding positions.
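A minimal sketch of such a mask, assuming the sliding window keeps the query position plus the sliding_window - 1 keys before it. build_encoder_mask is a hypothetical name, and the real function may treat the window boundary or fully padded rows differently.

```python
from typing import Optional

import torch


def build_encoder_mask(padding_mask: torch.Tensor,
                       sliding_window: Optional[int] = None,
                       dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """padding_mask: [B, S] bool, True at padding. Returns additive [B, 1, S, S]."""
    B, S = padding_mask.shape
    i = torch.arange(S, device=padding_mask.device)
    keep = i[None, :] <= i[:, None]                  # causal: key j <= query i
    if sliding_window is not None:
        keep &= (i[:, None] - i[None, :]) < sliding_window  # recent keys only
    keep = keep[None, None] & ~padding_mask[:, None, None, :]  # drop padded keys
    mask = torch.zeros(B, 1, S, S, dtype=dtype, device=padding_mask.device)
    return mask.masked_fill(~keep, torch.finfo(dtype).min)  # 0 keep, min masked
```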