nemo_automodel.components.models.diffusion_gemma.layers


Diffusion-specific layers for diffusion_gemma.

The stateless leaf layers (RMSNorm, the per-layer-type rotary embedding, the dense SwiGLU MLP, the self-conditioning gated MLP, and the RoPE/GQA helpers) are imported directly from the released transformers diffusion_gemma implementation so the model tracks Google’s release. This module keeps only the pieces the reference implementation cannot provide:

  • DiffusionGemmaAttention: a single mask-driven attention module used by both the causal (encoder) and bidirectional (decoder) passes of AM’s shared stack. Unlike the reference’s two Cache-coupled attention classes, it returns the freshly computed (K, V) as plain tensors and accepts encoder_kv as plain tensors, so the backbone can thread KV between the two passes without an HF Cache object. scaling = 1.0 because the per-head scale is folded into q_norm/k_norm. Full-attention layers have no v_proj (values reuse the keys); sliding-window layers do.
  • DiffusionGemmaMoEDecoderLayer: composes the reference’s attention, norms, and MLP with NeMo’s Gemma4MoE (GroupedExperts + Gemma4Gate) in place of the reference’s dense-matmul DiffusionGemmaTextExperts, which does not shard under FSDP. The dense MLP and the MoE branch run in parallel and their outputs are summed; the router sees the unnormalized post-attention residual, as in gemma4_moe.
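The KV-threading contract in the first bullet can be illustrated with a toy PyTorch module. This is a minimal sketch, not the NeMo implementation: ToyMaskAttention is a hypothetical name, and it omits RoPE, the q/k/v norms, GQA, and dropout. It shows the three behaviors the bullet describes: an additive mask decides the attention pattern, the layer returns its freshly computed (K, V) as plain tensors, and a full-attention layer without v_proj reuses the keys as values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyMaskAttention(nn.Module):
    """Hypothetical mask-driven attention in the spirit of DiffusionGemmaAttention.

    Not the NeMo code: no RoPE, no q/k/v norms, no GQA, no dropout.
    """

    def __init__(self, hidden: int, heads: int, head_dim: int, sliding: bool):
        super().__init__()
        self.heads, self.head_dim = heads, head_dim
        self.q_proj = nn.Linear(hidden, heads * head_dim, bias=False)
        self.k_proj = nn.Linear(hidden, heads * head_dim, bias=False)
        # Full-attention layers have no v_proj; their values reuse the keys.
        self.v_proj = nn.Linear(hidden, heads * head_dim, bias=False) if sliding else None
        self.o_proj = nn.Linear(heads * head_dim, hidden, bias=False)
        self.scaling = 1.0  # in the real layer the scale lives in q_norm/k_norm

    def _split(self, x: torch.Tensor) -> torch.Tensor:  # [B, L, H*D] -> [B, H, L, D]
        b, l, _ = x.shape
        return x.view(b, l, self.heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states, attention_mask, encoder_kv=None):
        q = self._split(self.q_proj(hidden_states))
        k = self._split(self.k_proj(hidden_states))
        v = self._split(self.v_proj(hidden_states)) if self.v_proj is not None else k
        fresh_kv = (k, v)  # plain tensors for the caller; no Cache object
        if encoder_kv is not None:
            # Bidirectional canvas pass: attend over [encoder ; canvas] keys/values.
            k = torch.cat([encoder_kv[0], k], dim=2)
            v = torch.cat([encoder_kv[1], v], dim=2)
        scores = q @ k.transpose(-1, -2) * self.scaling
        if attention_mask is not None:
            scores = scores + attention_mask  # additive: 0 keeps, -inf masks
        probs = F.softmax(scores.float(), dim=-1).to(q.dtype)
        out = (probs @ v).transpose(1, 2).reshape(hidden_states.shape[0], -1, self.heads * self.head_dim)
        return self.o_proj(out), fresh_kv


torch.manual_seed(0)
attn = ToyMaskAttention(hidden=16, heads=2, head_dim=8, sliding=False)
prompt = torch.randn(1, 5, 16)
canvas = torch.randn(1, 3, 16)

# Causal (encoder) pass: -inf above the diagonal; keep the returned K/V as the cache.
causal = torch.full((5, 5), float("-inf")).triu(1)[None, None]
enc_out, enc_kv = attn(prompt, causal)

# Bidirectional (decoder) pass: each canvas token sees every prompt and canvas key.
bidir = torch.zeros(1, 1, 3, 5 + 3)
dec_out, canvas_kv = attn(canvas, bidir, encoder_kv=enc_kv)
print(enc_out.shape, dec_out.shape, canvas_kv[0].shape)
# torch.Size([1, 5, 16]) torch.Size([1, 3, 16]) torch.Size([1, 2, 3, 8])
```

Returning K/V instead of writing into a cache keeps the layer stateless, so the same module can serve both passes and the backbone decides what to keep.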

Module Contents

Classes

| Name | Description |
|---|---|
| DiffusionGemmaAttention | Diffusion attention shared by the causal (encoder) and bidirectional (decoder) passes. |
| DiffusionGemmaMoEDecoderLayer | Single shared decoder layer used by both the causal and bidirectional passes. |

Functions

| Name | Description |
|---|---|
| _build_moe_config | Build a NeMo MoEConfig from the DiffusionGemma text config. |
| _make_missing | - |
| eager_attention_forward | Eager scaled-dot-product attention with an additive 4-D mask. |

Data

_FORK_AVAILABLE

API

class nemo_automodel.components.models.diffusion_gemma.layers.DiffusionGemmaAttention(
config: transformers.models.diffusion_gemma.configuration_diffusion_gemma.DiffusionGemmaTextConfig,
layer_idx: int
)

Bases: Module

Diffusion attention shared by the causal (encoder) and bidirectional (decoder) passes.

is_causal is informational only: the actual causal, bidirectional, or block-diagonal structure comes from the additive attention_mask the caller passes. On every pass the layer returns the freshly computed K/V for its own input tokens, which is how the caller builds the encoder KV cache during the causal pass. When encoder_kv is supplied (the bidirectional canvas pass), the layer concatenates [encoder_K ; canvas_K] (and likewise the values) on the key axis before attending.
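Because the attention pattern lives entirely in the additive mask, the caller is responsible for building it. The helpers below are a hypothetical sketch of the structures mentioned above plus a sliding-window variant; the function names and the window convention (a window of w includes the query token itself) are assumptions, not NeMo's API.

```python
import torch

NEG_INF = float("-inf")


def causal_mask(length: int) -> torch.Tensor:
    """[1, 1, L, L] additive mask: query i sees keys 0..i."""
    return torch.full((length, length), NEG_INF).triu(1)[None, None]


def sliding_causal_mask(length: int, window: int) -> torch.Tensor:
    """Causal mask that also hides keys `window` or more positions back (assumed convention)."""
    i = torch.arange(length)[:, None]
    j = torch.arange(length)[None, :]
    keep = (j <= i) & (i - j < window)
    return torch.zeros(length, length).masked_fill(~keep, NEG_INF)[None, None]


def canvas_mask(prefix_len: int, canvas_len: int) -> torch.Tensor:
    """Bidirectional canvas pass over the [encoder_K ; canvas_K] key layout:
    every canvas query sees all encoder keys and all canvas keys."""
    return torch.zeros(1, 1, canvas_len, prefix_len + canvas_len)


def block_diagonal_mask(segment_ids: torch.Tensor) -> torch.Tensor:
    """Packed sequences: query i sees key j only if both are in the same segment."""
    same = segment_ids[:, None] == segment_ids[None, :]
    return torch.zeros(same.shape).masked_fill(~same, NEG_INF)[None, None]


# Additive masks compose by summation (-inf + anything stays -inf).
packed = causal_mask(4) + block_diagonal_mask(torch.tensor([0, 0, 1, 1]))
print((packed[0, 0] == 0).int().tolist())
# [[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 1]]
```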

attention_dropout = config.attention_dropout
head_dim
is_sliding = self.layer_type == 'sliding_attention'
k_norm
k_proj
layer_type = config.layer_types[layer_idx]
num_attention_heads = config.num_attention_heads
num_key_value_groups = config.num_attention_heads // num_key_value_heads
o_proj
q_norm
q_proj
scaling = 1.0
sliding_window = config.sliding_window if self.is_sliding else None
v_norm
v_proj
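num_key_value_groups is the number of query heads that share each key/value head under grouped-query attention. The sketch below shows how K/V heads are typically expanded to match the query heads before an eager matmul. This repeat_kv is a local illustration, comparable in shape behavior to the GQA helper the module imports from transformers; the head counts are made up.

```python
import torch


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """[B, H_kv, L, D] -> [B, H_kv * n_rep, L, D]; each KV head serves n_rep query heads."""
    if n_rep == 1:
        return x
    b, h_kv, l, d = x.shape
    return x[:, :, None].expand(b, h_kv, n_rep, l, d).reshape(b, h_kv * n_rep, l, d)


num_attention_heads, num_key_value_heads = 8, 2  # illustrative values
num_key_value_groups = num_attention_heads // num_key_value_heads  # 4
k = torch.randn(1, num_key_value_heads, 5, 16)
k_rep = repeat_kv(k, num_key_value_groups)
print(k_rep.shape)  # torch.Size([1, 8, 5, 16])
# Query heads 0-3 share KV head 0; heads 4-7 share KV head 1.
assert torch.equal(k_rep[:, 3], k[:, 0]) and torch.equal(k_rep[:, 4], k[:, 1])
```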
nemo_automodel.components.models.diffusion_gemma.layers.DiffusionGemmaAttention.forward(
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: torch.Tensor | None,
encoder_kv: tuple[torch.Tensor, torch.Tensor] | None = None,
padding_mask: torch.Tensor | None = None
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]
class nemo_automodel.components.models.diffusion_gemma.layers.DiffusionGemmaMoEDecoderLayer(
config: transformers.models.diffusion_gemma.configuration_diffusion_gemma.DiffusionGemmaTextConfig,
layer_idx: int,
moe_config: nemo_automodel.components.moe.layers.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig
)

Bases: Module

Single shared decoder layer used by both the causal and bidirectional passes.

Reuses NeMo’s Gemma4MoE (GroupedExperts + Gemma4Gate) for the MoE branch; the dense MLP runs in parallel and the two are summed. layer_scalar is a per-layer output scale (identity unless present in the checkpoint).
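The parallel dense + MoE structure can be sketched as the feed-forward half of a layer. Everything here is hypothetical: ToyMoE is a small top-k softmax router standing in for Gemma4MoE, the MLP is a plain GELU MLP rather than DiffusionGemmaMLP, and the exact placement of the pre/post feed-forward norms is an assumption inferred from the attribute names listed below. What does follow from the description is that both branches read the same post-attention residual, the router sees that residual unnormalized, the branch outputs are summed, and layer_scalar defaults to 1.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


class ToyMoE(nn.Module):
    """Hypothetical top-k softmax MoE standing in for Gemma4MoE."""

    def __init__(self, dim: int, n_experts: int = 4, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(dim, n_experts, bias=False)
        self.experts = nn.ModuleList([nn.Linear(dim, dim, bias=False) for _ in range(n_experts)])

    def forward(self, x, router_input):
        probs = F.softmax(self.gate(router_input), dim=-1)  # [..., E]
        weights, idx = probs.topk(self.top_k, dim=-1)  # [..., k]
        out = torch.zeros_like(x)
        # Every expert is evaluated densely here for clarity; unselected ones get weight 0.
        for e, expert in enumerate(self.experts):
            w = (weights * (idx == e)).sum(-1, keepdim=True)
            out = out + w * expert(x)
        return out


class ToyParallelFFN(nn.Module):
    """Feed-forward half of a decoder layer: dense MLP and MoE side by side, summed.

    Norm placement is an illustrative assumption.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.pre_feedforward_layernorm = RMSNorm(dim)
        self.pre_feedforward_layernorm_2 = RMSNorm(dim)
        self.post_feedforward_layernorm_1 = RMSNorm(dim)
        self.post_feedforward_layernorm_2 = RMSNorm(dim)
        self.post_feedforward_layernorm = RMSNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 2 * dim), nn.GELU(), nn.Linear(2 * dim, dim))
        self.moe = ToyMoE(dim)
        self.register_buffer("layer_scalar", torch.ones(1))  # identity unless loaded

    def forward(self, residual):
        # `residual` is the unnormalized post-attention residual stream.
        dense = self.post_feedforward_layernorm_1(self.mlp(self.pre_feedforward_layernorm(residual)))
        sparse = self.post_feedforward_layernorm_2(
            self.moe(self.pre_feedforward_layernorm_2(residual), router_input=residual)
        )
        out = residual + self.post_feedforward_layernorm(dense + sparse)
        return out * self.layer_scalar


torch.manual_seed(0)
ffn = ToyParallelFFN(dim=16)
h = torch.randn(2, 5, 16)
print(ffn(h).shape)  # torch.Size([2, 5, 16])
```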

attention_type = config.layer_types[layer_idx]
hidden_size = config.hidden_size
input_layernorm
mlp = DiffusionGemmaMLP(config, layer_idx)
moe = Gemma4MoE(moe_config, backend, config)
post_attention_layernorm
post_feedforward_layernorm
post_feedforward_layernorm_1
post_feedforward_layernorm_2
pre_feedforward_layernorm
pre_feedforward_layernorm_2
self_attn
nemo_automodel.components.models.diffusion_gemma.layers.DiffusionGemmaMoEDecoderLayer.forward(
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: torch.Tensor | None,
encoder_kv: tuple[torch.Tensor, torch.Tensor] | None = None,
padding_mask: torch.Tensor | None = None
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]
nemo_automodel.components.models.diffusion_gemma.layers._build_moe_config(
config: transformers.models.diffusion_gemma.configuration_diffusion_gemma.DiffusionGemmaTextConfig,
moe_config: nemo_automodel.components.moe.layers.MoEConfig | None
) -> nemo_automodel.components.moe.layers.MoEConfig

Build a NeMo MoEConfig from the DiffusionGemma text config.

Matches gemma4_moe’s defaults: geglu experts, softmax routing, train_gate=True (the recipe freezes the gate separately), no aux loss.
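The defaults named above can be pictured as a config-mapping function. This is a sketch only: ToyMoEConfig, build_moe_config, and every field and attribute name (num_experts, top_k_experts, moe_intermediate_size, and so on) are hypothetical and are not taken from NeMo's MoEConfig or the transformers config. Returning a caller-supplied config unchanged is likewise an assumption about how the moe_config argument is used.

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass
class ToyMoEConfig:
    """Hypothetical stand-in for NeMo's MoEConfig; field names are illustrative only."""
    n_routed_experts: int
    n_activated_experts: int
    dim: int
    moe_inter_dim: int
    expert_activation: str = "geglu"  # geglu experts
    score_func: str = "softmax"  # softmax routing
    train_gate: bool = True  # gate freezing is left to the training recipe
    aux_loss_coeff: float = 0.0  # no auxiliary load-balancing loss


def build_moe_config(text_config, override=None) -> ToyMoEConfig:
    """Return `override` if given (assumed), else derive a config with the defaults above."""
    if override is not None:
        return override
    return ToyMoEConfig(
        n_routed_experts=text_config.num_experts,
        n_activated_experts=text_config.top_k_experts,
        dim=text_config.hidden_size,
        moe_inter_dim=text_config.moe_intermediate_size,
    )


cfg = SimpleNamespace(num_experts=8, top_k_experts=2, hidden_size=64, moe_intermediate_size=128)
moe_cfg = build_moe_config(cfg)
print(moe_cfg.expert_activation, moe_cfg.score_func, moe_cfg.train_gate)  # geglu softmax True
```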

nemo_automodel.components.models.diffusion_gemma.layers._make_missing(
name: str
)
nemo_automodel.components.models.diffusion_gemma.layers.eager_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
scaling: float,
dropout: float = 0.0
) -> torch.Tensor

Eager scaled-dot-product attention with an additive 4-D mask.

The mask is expected to be additive (0 to keep a key, -inf to mask it) and already sliced to the layer’s key axis ([B, 1, Lq, Lkv]). No softcap is applied to attention scores (Gemma4 only softcaps the final logits).
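A minimal sketch of eager attention with an additive mask, assuming key and value have already been expanded to the query's head count and that the output stays in [B, H, Lq, D] layout (the real function may transpose it). eager_attention is a hypothetical name, not the NeMo function. As the docstring states, the scores are not softcapped.

```python
import torch
import torch.nn.functional as F


def eager_attention(query, key, value, attention_mask, scaling, dropout=0.0, training=False):
    """query: [B, H, Lq, D]; key/value: [B, H, Lkv, D]; mask: additive [B, 1, Lq, Lkv] or None."""
    scores = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask  # 0 keeps a key, -inf removes it
    # No softcap on the scores; softmax runs in float32 for stability.
    probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = F.dropout(probs, p=dropout, training=training)
    return torch.matmul(probs, value)  # [B, H, Lq, D]


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))
causal = torch.full((4, 4), float("-inf")).triu(1)[None, None]
out = eager_attention(q, k, v, causal, scaling=1.0)
# Query 0 can only see key 0, so its output is exactly value 0.
assert torch.allclose(out[:, :, 0], v[:, :, 0])
```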

nemo_automodel.components.models.diffusion_gemma.layers._FORK_AVAILABLE = True