nemo_automodel.components.models.diffusion_gemma.layers
Diffusion-specific layers for diffusion_gemma.
The stateless leaf layers (RMSNorm, the per-layer-type rotary embedding, the
dense SwiGLU MLP, the self-conditioning gated MLP, and the RoPE/GQA helpers) are
imported directly from the released transformers diffusion_gemma
implementation so the model tracks Google’s release. This module keeps only
the pieces the reference implementation cannot provide:
- :class:DiffusionGemmaAttention: a single mask-driven attention used by both the causal (encoder) and bidirectional (decoder) passes of AM’s shared stack. Unlike the reference’s two Cache-coupled attention classes, it returns the freshly computed (K, V) as plain tensors and accepts encoder_kv as plain tensors, so the backbone can thread KV between the two passes without an HF Cache object. It uses scaling = 1.0 (the per-head scale is folded into q_norm/k_norm). Full-attention layers have no v_proj (values reuse the keys); sliding layers do.
- :class:DiffusionGemmaMoEDecoderLayer: composes the reference’s attention + norms + MLP with NeMo’s Gemma4MoE (GroupedExperts + Gemma4Gate) instead of the reference’s dense-matmul DiffusionGemmaTextExperts, which does not shard under FSDP. The dense MLP and the MoE branch run in parallel and are summed, routing on the unnormalized post-attention residual, the same as gemma4_moe.
Module Contents
Classes
Functions
Data
API
Bases: Module
Diffusion attention shared by the causal (encoder) and bidirectional (decoder) passes.
is_causal is informational only: the actual causal/bidirectional/
block-diagonal structure comes from the additive attention_mask the
caller passes. When encoder_kv is supplied (the bidirectional canvas
pass), the layer concatenates [encoder_K ; canvas_K] on the key axis. In
either pass it returns the freshly computed K/V as plain tensors, so the
caller can build the encoder KV cache during the causal pass.
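The KV threading described above can be sketched with plain tensors. This is an illustration only, not the module's code: the function name attend_with_encoder_kv, the [B, H, L, D] layout, and the assumption that K/V heads are already expanded to match the query heads are all hypothetical, and the projections, norms, and RoPE are omitted.

```python
import torch

def attend_with_encoder_kv(q, k, v, attention_mask, encoder_kv=None, scaling=1.0):
    """Hypothetical sketch. q, k, v: [B, H, L, D]; attention_mask: additive, [B, 1, Lq, Lkv].

    Returns (output, (k, v)), where (k, v) are the tensors computed for THIS pass,
    before any encoder K/V is prepended.
    """
    fresh_kv = (k, v)
    if encoder_kv is not None:
        enc_k, enc_v = encoder_kv
        # Prepend the encoder keys/values on the key axis: [encoder ; canvas].
        k = torch.cat([enc_k, k], dim=-2)
        v = torch.cat([enc_v, v], dim=-2)
    scores = torch.matmul(q, k.transpose(-1, -2)) * scaling + attention_mask
    out = torch.matmul(torch.softmax(scores, dim=-1), v)
    return out, fresh_kv

torch.manual_seed(0)
B, H, D = 1, 2, 4
L_enc, L_canvas = 3, 2

# Causal (encoder) pass: keep the returned K/V as the encoder KV.
q_e, k_e, v_e = (torch.randn(B, H, L_enc, D) for _ in range(3))
causal_mask = torch.full((L_enc, L_enc), float("-inf")).triu(diagonal=1)[None, None]
_, encoder_kv = attend_with_encoder_kv(q_e, k_e, v_e, causal_mask)

# Bidirectional (canvas) pass: every canvas query may see encoder and canvas keys.
q_c, k_c, v_c = (torch.randn(B, H, L_canvas, D) for _ in range(3))
full_mask = torch.zeros(B, 1, L_canvas, L_enc + L_canvas)
canvas_out, canvas_kv = attend_with_encoder_kv(
    q_c, k_c, v_c, full_mask, encoder_kv=encoder_kv
)
```

Because the mask alone decides which keys are visible, the same function serves both passes. The only difference between them is whether encoder_kv is passed in.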
Bases: Module
Single shared decoder layer used by both the causal and bidirectional passes.
Reuses NeMo’s Gemma4MoE (GroupedExperts + Gemma4Gate) for the MoE
branch; the dense MLP runs in parallel and the two are summed. layer_scalar
is a per-layer output scale (identity unless present in the checkpoint).
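A minimal sketch of the parallel dense + MoE composition follows, under stated assumptions. The class name ParallelMlpMoeSketch is hypothetical, both branches are stand-in nn.Linear modules rather than the real dense MLP and Gemma4MoE, the norms are omitted, and applying layer_scalar to the whole layer output is an assumption based only on the "per-layer output scale" wording above.

```python
import torch
from torch import nn

class ParallelMlpMoeSketch(nn.Module):
    """Hypothetical stand-in: two feed-forward branches run in parallel and are summed."""

    def __init__(self, dense_mlp: nn.Module, moe: nn.Module):
        super().__init__()
        self.dense_mlp = dense_mlp
        self.moe = moe
        # Identity scale by default; a checkpoint could overwrite this buffer.
        self.register_buffer("layer_scalar", torch.ones(1))

    def forward(self, residual: torch.Tensor) -> torch.Tensor:
        # Both branches read the same post-attention residual; outputs are summed.
        ffn_out = self.dense_mlp(residual) + self.moe(residual)
        # Assumption: the scalar multiplies the full layer output.
        return (residual + ffn_out) * self.layer_scalar

torch.manual_seed(0)
hidden = 8
layer = ParallelMlpMoeSketch(nn.Linear(hidden, hidden), nn.Linear(hidden, hidden))
x = torch.randn(2, 5, hidden)
y = layer(x)
```

With the default buffer of ones, the layer reduces to a plain residual sum of the two branches.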
Build a NeMo :class:MoEConfig from the DiffusionGemma text config.
Matches gemma4_moe’s defaults: geglu experts, softmax routing,
train_gate=True (the recipe freezes the gate separately), no aux loss.
Eager scaled-dot-product attention with an additive 4-D mask.
The mask is expected to be additive (0 keep, -inf mask) and already
sliced to the layer’s key axis ([B, 1, Lq, Lkv]). No softcap is applied
to attention scores (Gemma4 only softcaps the final logits).
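The eager path described above can be illustrated as follows. This is a sketch, not the module's actual function: the name eager_attention_sketch and its signature are hypothetical, K/V are assumed to be already expanded to the query head count (the GQA repeat step is omitted), and the float32 softmax is an assumption about numerical handling.

```python
import torch

def eager_attention_sketch(q, k, v, attention_mask=None, scaling=1.0):
    """Hypothetical sketch. q: [B, H, Lq, D]; k, v: [B, H, Lkv, D].

    attention_mask is additive (0 keeps, -inf masks) with shape [B, 1, Lq, Lkv].
    """
    scores = torch.matmul(q, k.transpose(-1, -2)) * scaling  # [B, H, Lq, Lkv]
    if attention_mask is not None:
        scores = scores + attention_mask  # broadcasts over the head axis
    # No softcap on the scores; softmax in float32, then cast back (assumption).
    probs = torch.softmax(scores.float(), dim=-1).to(q.dtype)
    return torch.matmul(probs, v)  # [B, H, Lq, D]

torch.manual_seed(0)
B, H, L, D = 1, 2, 4, 8
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

# Additive causal mask: 0 on and below the diagonal, -inf above it.
causal = torch.full((L, L), float("-inf")).triu(diagonal=1)[None, None]
out = eager_attention_sketch(q, k, v, attention_mask=causal)
```

Under the causal mask, the first query position can only attend to the first key, so its output equals the first value vector.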