# nemo_automodel.components.models.diffusion_gemma.model

NeMo Automodel support for `diffusion_gemma` (block diffusion).

Architecture (design v2 item 1): a single shared parameter stack, run twice:

* Run the decoder layers once **causally** over the clean full sequence to
  build a per-layer read-only KV cache (the "encoder" KV). The text encoder
  is causal because `use_bidirectional_attention == "vision"` (not
  `"all"`); a single causal pass over the clean full sequence reproduces
  the per-position KV that block-by-block inference builds.
* Run the same layers once **bidirectionally** over the noised canvas (the
  response region), each layer concatenating `[encoder_KV ; canvas_KV]` on
  the key axis and using the block-causal training mask from
  `attention_mask.build_block_diffusion_training_mask`.
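A minimal toy sketch of the one-stack, two-pass idea, assuming a single-head attention layer with no RoPE, MoE, or normalization. `ToySharedLayer` and its `encode`/`decode` methods are hypothetical. They only mirror the contract described above: the causal pass caches per-layer K/V, and the bidirectional pass concatenates `[encoder_KV ; canvas_KV]` on the key axis. The all-zeros decoder mask is a placeholder for the block-causal mask from `build_block_diffusion_training_mask`.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ToySharedLayer(nn.Module):
    """Hypothetical stand-in for one shared decoder layer (single head, no RoPE)."""

    def __init__(self, hidden: int):
        super().__init__()
        self.q_proj = nn.Linear(hidden, hidden, bias=False)
        self.k_proj = nn.Linear(hidden, hidden, bias=False)
        self.v_proj = nn.Linear(hidden, hidden, bias=False)

    def encode(self, x: torch.Tensor):
        # Causal pass over the clean sequence; keep (K, V) as a read-only cache.
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return out, (k, v)

    def decode(self, x: torch.Tensor, encoder_kv, additive_mask: torch.Tensor):
        # Canvas queries attend over [encoder_KV ; canvas_KV] (concat on the key axis).
        enc_k, enc_v = encoder_kv
        q = self.q_proj(x)
        k = torch.cat([enc_k, self.k_proj(x)], dim=-2)
        v = torch.cat([enc_v, self.v_proj(x)], dim=-2)
        return F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask)


torch.manual_seed(0)
layer = ToySharedLayer(hidden=8)
clean = torch.randn(2, 6, 8)   # [B, S, H] clean full sequence
canvas = torch.randn(2, 6, 8)  # [B, canvas_len, H] noised canvas
enc_out, kv = layer.encode(clean)
# Placeholder additive mask of shape [canvas_len, S + canvas_len]; the real model
# uses the block-causal mask from build_block_diffusion_training_mask.
mask = torch.zeros(6, 12)
dec_out = layer.decode(canvas, kv, mask)
print(dec_out.shape)  # torch.Size([2, 6, 8])
```

Both passes reuse the same `q_proj`/`k_proj`/`v_proj` weights, which is the point of the single shared stack: there is only one set of parameters for FSDP2 to shard.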

A single shared stack (rather than tied-but-separate encoder/decoder modules)
keeps the model discoverable by NeMo Automodel's (AM) MoE FSDP grad-sync
(`MoEFSDPSyncMixin` / `_iter_fsdp_modules` assume a single `model.layers`
stack with `block.moe.experts`) and avoids FSDP2 double-sharding tied storage.
The `lm_head` is tied to `model.embed_tokens`.

Self-conditioning (decoder-only, Analog-Bits two-pass) is encapsulated in the
training forward so the recipe still calls `model(**batch)` once.
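The two-pass scheme can be sketched as follows, assuming (as in Analog Bits) that pass 1 runs with self-conditioning switched off and that its logits are treated as a constant estimate for pass 2. `denoise` is a hypothetical callable with the signature `(canvas_ids, sc_logits, sc_mask) -> logits`. Whether the real forward wraps pass 1 in `torch.no_grad()` is an assumption, not something this page states.

```python
import torch


def two_pass_self_conditioning(denoise, canvas_ids, coins):
    """Hypothetical sketch: pass 1 without self-cond, pass 2 consuming pass-1 logits."""
    batch = canvas_ids.shape[0]
    off = torch.zeros(batch, dtype=torch.bool, device=canvas_ids.device)
    # Pass 1 always runs (same work on every rank); no gradient flows through it.
    with torch.no_grad():
        first = denoise(canvas_ids, None, off)
    # Pass 2: per-example coins decide whether the pass-1 estimate is used.
    return denoise(canvas_ids, first, coins)


vocab = 5
calls = []


def toy_denoise(canvas_ids, sc_logits, sc_mask):
    # Toy denoiser: one-hot "logits", plus the gated self-cond logits if provided.
    calls.append(sc_logits is None)
    logits = torch.nn.functional.one_hot(canvas_ids, vocab).float()
    if sc_logits is not None:
        logits = logits + sc_logits * sc_mask[:, None, None].float()
    return logits


canvas = torch.tensor([[1, 2], [3, 4]])
out = two_pass_self_conditioning(toy_denoise, canvas, torch.tensor([True, False]))
```

Example 0 (coin `True`) sees its pass-1 estimate added back in, while example 1 (coin `False`) gets the same output as a pass without self-conditioning.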

## Module Contents

### Classes

| Name                                                                                                                         | Description                                                                  |
| ---------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------- |
| [`DiffusionGemmaBackbone`](#nemo_automodel-components-models-diffusion_gemma-model-DiffusionGemmaBackbone)                   | Single shared Gemma MoE transformer stack run causally then bidirectionally. |
| [`DiffusionGemmaForBlockDiffusion`](#nemo_automodel-components-models-diffusion_gemma-model-DiffusionGemmaForBlockDiffusion) | Block-diffusion Gemma MoE model for SFT.                                     |
| [`DiffusionGemmaOutput`](#nemo_automodel-components-models-diffusion_gemma-model-DiffusionGemmaOutput)                       | Training forward output.                                                     |

### Functions

| Name                                                                                                               | Description                                                                |
| ------------------------------------------------------------------------------------------------------------------ | -------------------------------------------------------------------------- |
| [`_make_causal_additive_mask`](#nemo_automodel-components-models-diffusion_gemma-model-_make_causal_additive_mask) | Build an additive causal (optionally sliding-window) mask for the encoder. |
| [`_make_missing`](#nemo_automodel-components-models-diffusion_gemma-model-_make_missing)                           | -                                                                          |

### Data

[`ModelClass`](#nemo_automodel-components-models-diffusion_gemma-model-ModelClass)

[`_TRANSFORMERS_AVAILABLE`](#nemo_automodel-components-models-diffusion_gemma-model-_TRANSFORMERS_AVAILABLE)

### API

```python
class nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaBackbone(
    config: transformers.models.diffusion_gemma.configuration_diffusion_gemma.DiffusionGemmaTextConfig,
    backend: nemo_automodel.components.models.common.BackendConfig,
    moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None
)
```

**Bases:** `Module`

Single shared Gemma MoE transformer stack run causally then bidirectionally.

Exposes `layers` (a `ModuleDict` keyed by string layer index),
`embed_tokens`, `norm`, `self_conditioning` and `rotary_emb`. The
`layers` / `embed_tokens` names are what `MoEFSDPSyncMixin` and the
FSDP2 sharding path key on.

```python
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaBackbone._position_embeddings(
    hidden_states: torch.Tensor,
    position_ids: torch.Tensor
) -> dict
```

```python
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaBackbone.decode(
    canvas_ids: torch.Tensor,
    encoder_kv: list[tuple[torch.Tensor, torch.Tensor]],
    decoder_position_ids: torch.Tensor,
    decoder_masks: dict,
    decoder_padding_mask: torch.Tensor | None = None,
    self_conditioning_logits: torch.Tensor | None = None,
    self_conditioning_mask: torch.Tensor | None = None
) -> torch.Tensor
```

Bidirectional pass over the noised canvas with cross-attention to the
encoder KV cache. Returns the final (normed) hidden states.

`self_conditioning_mask` (`[B]` bool, training only) gates the self-conditioning
branch **per example**: examples with `False` get a zeroed soft-embedding
(identical to the no-self-conditioning path), so a single always-on pass 1 can
serve Google's per-example mix of conditioned and zero-conditioned examples.
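A sketch of the per-example gate, assuming the soft-embedding is the softmax of the self-conditioning logits projected through an embedding matrix of shape `[V, H]`. `gated_soft_embedding` is a hypothetical helper, not part of this module's API. It shows how a `[B]` bool mask zeroes whole examples so they match the no-self-conditioning path.

```python
import torch


def gated_soft_embedding(sc_logits, embed_weight, sc_mask):
    """Hypothetical: softmax(logits) @ E, zeroed for examples whose mask is False."""
    probs = torch.softmax(sc_logits, dim=-1)       # [B, L, V]
    soft = probs @ embed_weight                    # [B, L, H]
    gate = sc_mask.to(soft.dtype)[:, None, None]   # [B] -> [B, 1, 1]
    return soft * gate


torch.manual_seed(0)
logits = torch.randn(2, 3, 10)  # [B, L, V]
embed = torch.randn(10, 4)      # [V, H]
soft = gated_soft_embedding(logits, embed, torch.tensor([True, False]))
print(soft[1].abs().sum().item())  # 0.0
```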

```python
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaBackbone.encode(
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    padding_mask: torch.Tensor | None,
    return_hidden: bool = False
)
```

Causal pass over the clean full sequence, producing a per-layer (K, V) cache.

When `return_hidden` is True, also returns the final **normed** hidden
states `[B, S, H]` (so the caller can produce the encoder's
autoregressive logits for the co-trained AR loss). Default False keeps
the KV-only contract used by inference and the parity/leakage tests.

```python
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaBackbone.forward(
    mode: str,
    input_ids: torch.Tensor | None = None,
    position_ids: torch.Tensor | None = None,
    padding_mask: torch.Tensor | None = None,
    return_hidden: bool = False,
    canvas_ids: torch.Tensor | None = None,
    encoder_kv: list[tuple[torch.Tensor, torch.Tensor]] | None = None,
    decoder_position_ids: torch.Tensor | None = None,
    decoder_masks: dict | None = None,
    decoder_padding_mask: torch.Tensor | None = None,
    self_conditioning_logits: torch.Tensor | None = None,
    self_conditioning_mask: torch.Tensor | None = None
) -> list[tuple[torch.Tensor, torch.Tensor]] | torch.Tensor
```

Dispatch encode/decode through `nn.Module.__call__` for FSDP hooks.

FSDP2 hooks are installed on module calls, not on arbitrary helper
methods. The block-diffusion top-level forward must therefore enter the
backbone via `self.model(...)` so root-owned parameters such as
`self_conditioning` and the final `norm` are gathered before use.
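Why the dispatch matters can be shown with a plain forward pre-hook, used here as a stand-in for the FSDP2 all-gather hooks (an illustration only; FSDP2 installs its own hooks). `ToyBackbone` is hypothetical. The hook fires when the module is called as `backbone(mode=...)`, but a direct helper call such as `backbone.encode(...)` bypasses it.

```python
import torch
from torch import nn


class ToyBackbone(nn.Module):
    """Hypothetical backbone whose forward dispatches on a mode string."""

    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(4)

    def encode(self, input_ids):
        return self.norm(input_ids)

    def decode(self, canvas_ids):
        return self.norm(canvas_ids) * 2

    def forward(self, mode, **kwargs):
        if mode == "encode":
            return self.encode(**kwargs)
        if mode == "decode":
            return self.decode(**kwargs)
        raise ValueError(f"unknown mode: {mode!r}")


hook_calls = []
backbone = ToyBackbone()
# Stand-in for an FSDP pre-forward hook registered on the backbone module.
backbone.register_forward_pre_hook(lambda module, args: hook_calls.append(type(module).__name__))

x = torch.randn(2, 4)
backbone(mode="encode", input_ids=x)  # goes through nn.Module.__call__: hook fires
backbone.encode(x)                    # direct helper call: hook is bypassed
print(len(hook_calls))  # 1
```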

```python
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaBackbone.get_input_embeddings() -> torch.nn.Module
```

```python
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaBackbone.set_input_embeddings(
    value: torch.nn.Module
) -> None
```

```python
class nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion(
    config: transformers.models.diffusion_gemma.configuration_diffusion_gemma.DiffusionGemmaConfig,
    moe_config: 'MoEConfig | None' = None,
    backend: 'BackendConfig | None' = None,
    canvas_length: int | None = None,
    self_conditioning: bool | None = None,
    freeze_router: bool | None = None,
    kwargs: typing.Any = {}
)
```

**Bases:** [HFCheckpointingMixin](/nemo-automodel/nemo_automodel/components/models/common/hf_checkpointing_mixin#nemo_automodel-components-models-common-hf_checkpointing_mixin-HFCheckpointingMixin), [MoEFSDPSyncMixin](/nemo-automodel/nemo_automodel/components/moe/fsdp_mixin#nemo_automodel-components-moe-fsdp_mixin-MoEFSDPSyncMixin), `PreTrainedModel`

Block-diffusion Gemma MoE model for SFT.

Inherits the AM checkpointing + MoE-FSDP machinery. The MoE backbone is
reused from `gemma4_moe`; the diffusion training forward and the two-pass
self-conditioning are new. See module docstring for the single-shared-stack
design.

`forward` is the SFT **training** forward. A generation/inference loop
(encode the prompt once, then iteratively denoise canvas blocks reusing the
KV cache, with the self-conditioning recycling loop) is deferred; the
`model.encode` / `model.decode` building blocks are the reusable pieces
for it, and `forward` already accepts an explicit `self_conditioning_logits`
for the per-step inference contract.

```python
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion._softcap_logits(
    hidden_states: torch.Tensor
) -> torch.Tensor
```
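`_softcap_logits` has no docstring on this page. The sketch below shows what such a helper typically computes for Gemma-family models. It assumes the tied `lm_head` (logits as `hidden @ embed_tokens.weight.T`) followed by Gemma-style final-logit soft-capping, `cap * tanh(logits / cap)`. The function name and the cap value `30.0` are illustrative assumptions, not values read from `DiffusionGemmaConfig`.

```python
import torch


def softcapped_tied_logits(hidden_states, embed_weight, softcap=30.0):
    """Hypothetical: tied LM head followed by tanh soft-capping."""
    logits = hidden_states @ embed_weight.T        # [B, L, H] @ [H, V] -> [B, L, V]
    if softcap is None:
        return logits
    return softcap * torch.tanh(logits / softcap)  # bounded to [-softcap, softcap]


torch.manual_seed(0)
hidden = torch.randn(2, 3, 8) * 100.0  # deliberately large activations
embed = torch.randn(16, 8)             # [V, H], shared with the input embedding
logits = softcapped_tied_logits(hidden, embed)
print(logits.shape)  # torch.Size([2, 3, 16])
```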

```python
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion.forward(
    input_ids: torch.Tensor | None = None,
    canvas_ids: torch.Tensor | None = None,
    self_conditioning_logits: torch.Tensor | None = None,
    encoder_position_ids: torch.Tensor | None = None,
    encoder_padding_mask: torch.Tensor | None = None,
    decoder_position_ids: torch.Tensor | None = None,
    decoder_attention_mask: dict | None = None,
    decoder_padding_mask: torch.Tensor | None = None,
    do_self_conditioning: torch.Tensor | bool | None = None,
    kwargs: typing.Any = {}
) -> 'DiffusionGemmaOutput'
```

Training forward: the single shared stack run twice, plus two-pass self-conditioning.

**Parameters:**

* `input_ids`: Clean full sequence (prompt + response), `[B, S]`. Run
  causally to build the read-only encoder KV cache.
* `canvas_ids`: Noised response/canvas tokens, `[B, canvas_len]`. Run
  bidirectionally with the block-causal mask.
* `self_conditioning_logits`: If given (inference or an external loop), used
  directly and the two-pass logic is skipped. During training, the two-pass
  scheme generates the self-conditioning signal internally.
* `encoder_position_ids`: Position ids for the encoder pass (`[B, S]`).
  Defaults to `arange(S)`.
* `encoder_padding_mask`: `True` at padded encoder positions (`[B, S]`).
* `decoder_position_ids`: Position ids for the canvas (`[B, canvas_len]`).
  These must be the canvas tokens' **absolute** positions so that their query
  RoPE aligns with the encoder key RoPE of the clean copies. In the v1
  full-sequence-canvas layout (`canvas_len == S`) this is `arange(S)` (the
  default); a response-window canvas would use
  `prefix_length + arange(canvas_len)` per example.
* `decoder_attention_mask`: Dict with keys
  `{"full_attention", "sliding_attention"}` holding additive block-causal
  masks (from `build_block_diffusion_training_mask`). Required for training;
  built by the recipe's `_forward_backward_step` override.
* `decoder_padding_mask`: `True` at padded canvas positions
  (`[B, canvas_len]`). Used to keep padded rows out of MoE routing.
* `do_self_conditioning`: Per-example self-conditioning coins, a `[B]` bool
  tensor (a scalar bool is broadcast). During training, pass 1 **always** runs
  (constant FSDP collectives every step, so no rank desync, and correct for
  `local_batch_size > 1`); this mask gates, per example, whether pass 2
  consumes the self-conditioning signal (`False` gives a zeroed soft-embedding,
  i.e. no self-conditioning). The recipe supplies it via
  `_decide_self_conditioning`. Required during training (`None` would drop
  Google's per-example mix and raises `ValueError`); ignored outside training
  (eval / single pass).
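For a response-window canvas, the absolute decoder positions and the broadcast of a scalar coin described above might be built like this. `prefix_length` (a `[B]` tensor of per-example prompt lengths) and both helper names are hypothetical.

```python
import torch


def response_window_position_ids(prefix_length, canvas_len):
    """Hypothetical: absolute positions prefix_length + arange(canvas_len), per example."""
    offsets = torch.arange(canvas_len, device=prefix_length.device)
    return prefix_length[:, None] + offsets[None, :]  # [B, canvas_len]


def broadcast_coins(do_self_conditioning, batch_size):
    """Hypothetical: accept a scalar bool or a [B] bool tensor, return a [B] bool tensor."""
    coins = torch.as_tensor(do_self_conditioning, dtype=torch.bool)
    if coins.dim() == 0:
        coins = coins.expand(batch_size)
    return coins


pos = response_window_position_ids(torch.tensor([3, 5]), canvas_len=4)
print(pos.tolist())                       # [[3, 4, 5, 6], [5, 6, 7, 8]]
print(broadcast_coins(True, 2).tolist())  # [True, True]
```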

**Returns:** `'DiffusionGemmaOutput'`

A `DiffusionGemmaOutput` with canvas-only `logits`; `encoder_logits` is populated only during training, and only when the AR loss is used.

```python
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion.freeze_router_params() -> None
```

Freeze the MoE router/gate (design v2 item 9).

Sets `train_gate=False` and `requires_grad=False` on the gate's
`proj.weight` and `scale` for every layer. Routing indices are
already non-differentiable; `per_expert_scale` is folded into the
(trainable) expert `down_proj` by the state-dict adapter, so the
experts stay trainable. `MoEFSDPSyncMixin` keys on
`set_requires_gradient_sync`, never `requires_grad`, so freezing
the gate does not break grad-sync.
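A minimal sketch of the freezing step on a toy module tree. The `layers[...].moe.gate` path with `proj.weight`, `scale`, and a `train_gate` flag mirrors the names in the docstring above, but the toy classes and the `freeze_router_params` function here are hypothetical stand-ins for the real Gemma MoE blocks.

```python
import torch
from torch import nn


class ToyGate(nn.Module):
    def __init__(self, hidden, n_experts):
        super().__init__()
        self.proj = nn.Linear(hidden, n_experts, bias=False)
        self.scale = nn.Parameter(torch.ones(hidden))
        self.train_gate = True


class ToyMoE(nn.Module):
    def __init__(self, hidden, n_experts):
        super().__init__()
        self.gate = ToyGate(hidden, n_experts)
        self.experts = nn.Linear(hidden, hidden)  # stands in for the grouped experts


class ToyBlock(nn.Module):
    def __init__(self, hidden=4, n_experts=2):
        super().__init__()
        self.moe = ToyMoE(hidden, n_experts)


def freeze_router_params(layers: nn.ModuleDict) -> None:
    """Hypothetical mirror of the documented behavior: freeze gate weights only."""
    for block in layers.values():
        gate = block.moe.gate
        gate.train_gate = False
        gate.proj.weight.requires_grad_(False)
        gate.scale.requires_grad_(False)


layers = nn.ModuleDict({str(i): ToyBlock() for i in range(2)})
freeze_router_params(layers)
trainable = [n for n, p in layers.named_parameters() if p.requires_grad]
print(trainable)  # ['0.moe.experts.weight', '0.moe.experts.bias', '1.moe.experts.weight', '1.moe.experts.bias']
```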

```python
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion.from_config(
    config: transformers.models.diffusion_gemma.configuration_diffusion_gemma.DiffusionGemmaConfig,
    moe_config: 'MoEConfig | None' = None,
    backend: 'BackendConfig | None' = None,
    kwargs: typing.Any = {}
) -> 'DiffusionGemmaForBlockDiffusion'
```

classmethod

```python
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion.get_capabilities(
    config: 'DiffusionGemmaConfig'
) -> 'ModelCapabilities'
```

classmethod

Parallelism support for the DiffusionGemma block-diffusion MoE.

Single variant: FSDP2 + Expert Parallelism are supported (validated at
EP=8). TP is unsupported for the custom MoE; CP/PP are not supported for
this encoder-decoder block-diffusion path.

```python
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion.get_input_embeddings() -> torch.nn.Module
```

```python
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion.get_output_embeddings() -> torch.nn.Module
```

```python
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion.initialize_weights(
    buffer_device: torch.device | None = None,
    dtype: torch.dtype | None = None
) -> None
```

Initialize grouped-expert parameters (other parameters are initialized by HF `post_init`).

`dtype` defaults to the model's **configured** `torch_dtype` rather than a
hardcoded `bfloat16`. The meta/FSDP init path
(`checkpoint.checkpointing.initialize_model_weights`) calls this with no
dtype; a blanket `self.to(torch.bfloat16)` would materialize the whole model
in bf16 before the checkpoint loads, silently defeating the fp32 master weights
that `model.torch_dtype: float32` configs request (leaving AdamW on bf16
params). Honor the requested dtype instead.
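The dtype-selection rule can be sketched as a small resolver. `resolve_init_dtype` is hypothetical, and falling back to `torch.get_default_dtype()` when neither an explicit dtype nor a configured `torch_dtype` is present is an assumption, not documented behavior. The sketch accepts `torch_dtype` either as a `torch.dtype` or as a string such as `"float32"`.

```python
import torch
from types import SimpleNamespace


def resolve_init_dtype(config, dtype=None):
    """Hypothetical: an explicit dtype wins, else the configured torch_dtype."""
    if dtype is not None:
        return dtype
    configured = getattr(config, "torch_dtype", None)
    if isinstance(configured, str):
        configured = getattr(torch, configured)  # "float32" -> torch.float32
    if isinstance(configured, torch.dtype):
        return configured
    return torch.get_default_dtype()             # assumed fallback


cfg = SimpleNamespace(torch_dtype="float32")
print(resolve_init_dtype(cfg))                  # torch.float32
print(resolve_init_dtype(cfg, torch.bfloat16))  # torch.bfloat16
```

Resolving the configured dtype instead of hardcoding `bfloat16` keeps fp32 master weights intact for `model.torch_dtype: float32` configs.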

```python
nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaForBlockDiffusion.set_input_embeddings(
    value: torch.nn.Module
) -> None
```

```python
class nemo_automodel.components.models.diffusion_gemma.model.DiffusionGemmaOutput(
    logits: 'torch.Tensor',
    encoder_logits: 'torch.Tensor | None' = None
)
```

Dataclass

Training forward output.

`logits` are the canvas-only (response) denoising logits `[B, canvas_len, V]`.
`encoder_logits` are the causal encoder's next-token logits over the clean
full sequence `[B, S, V]` for the co-trained AR loss; they are `None` outside
training (and when the AR loss is unused).
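A sketch of how a recipe might consume both fields. `ToyOutput` mirrors the two documented fields. The next-token shift for `encoder_logits` follows from their description as causal next-token logits over the clean sequence. The unweighted sum and the unmasked canvas loss are assumptions, since the real recipe's loss masking and weighting are not described here.

```python
import torch
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyOutput:
    logits: torch.Tensor                            # [B, canvas_len, V]
    encoder_logits: Optional[torch.Tensor] = None   # [B, S, V]


def toy_losses(out, canvas_targets, input_ids):
    """Hypothetical: canvas denoising loss plus an optional shifted AR loss."""
    vocab = out.logits.shape[-1]
    loss = F.cross_entropy(out.logits.reshape(-1, vocab), canvas_targets.reshape(-1))
    if out.encoder_logits is not None:
        # Position t predicts token t + 1: drop the last logit and the first label.
        ar_logits = out.encoder_logits[:, :-1].reshape(-1, vocab)
        ar_labels = input_ids[:, 1:].reshape(-1)
        loss = loss + F.cross_entropy(ar_logits, ar_labels)
    return loss


torch.manual_seed(0)
B, S, V = 2, 5, 7
input_ids = torch.randint(0, V, (B, S))
out = ToyOutput(logits=torch.randn(B, S, V), encoder_logits=torch.randn(B, S, V))
print(toy_losses(out, input_ids, input_ids).item() > 0)  # True
```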

```python
nemo_automodel.components.models.diffusion_gemma.model._make_causal_additive_mask(
    seq_len: int,
    padding_mask: torch.Tensor | None,
    sliding_window: int | None,
    batch_size: int,
    device: torch.device,
    dtype: torch.dtype
) -> torch.Tensor
```

Build an additive causal (optionally sliding-window) mask for the encoder.

Shape `[B, 1, seq_len, seq_len]`; `0` keep, `finfo.min` masked.
`padding_mask` is `[B, seq_len]` with `True` at padding positions.
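A sketch that reproduces the documented contract: shape `[B, 1, seq_len, seq_len]`, `0` to keep, `finfo.min` to mask, and `True` marking padding. `causal_additive_mask_sketch` is a hypothetical name. The sliding-window rule used here (query `i` sees key `j` only if `0 <= i - j < sliding_window`) is an assumption that matches the common Hugging Face convention but is not stated in this docstring. Padding is applied to key columns only.

```python
import torch


def causal_additive_mask_sketch(seq_len, padding_mask, sliding_window, batch_size, device, dtype):
    """Hypothetical re-implementation of the documented mask contract."""
    q_idx = torch.arange(seq_len, device=device)[:, None]  # [S, 1] query positions
    k_idx = torch.arange(seq_len, device=device)[None, :]  # [1, S] key positions
    keep = k_idx <= q_idx                                  # causal
    if sliding_window is not None:
        keep = keep & (q_idx - k_idx < sliding_window)     # assumed window rule
    keep = keep[None, None].expand(batch_size, 1, seq_len, seq_len)
    if padding_mask is not None:
        # Hide padded key columns from every query row.
        keep = keep & ~padding_mask.to(torch.bool)[:, None, None, :]
    mask = torch.zeros(batch_size, 1, seq_len, seq_len, device=device, dtype=dtype)
    return mask.masked_fill(~keep, torch.finfo(dtype).min)


pad = torch.tensor([[True, False, False, False], [False, False, False, False]])
m = causal_additive_mask_sketch(4, pad, sliding_window=2, batch_size=2,
                                device=torch.device("cpu"), dtype=torch.float32)
print(m.shape)  # torch.Size([2, 1, 4, 4])
```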

```python
nemo_automodel.components.models.diffusion_gemma.model._make_missing(
    name: str
)
```

```python
nemo_automodel.components.models.diffusion_gemma.model.ModelClass = DiffusionGemmaForBlockDiffusion
```

```python
nemo_automodel.components.models.diffusion_gemma.model._TRANSFORMERS_AVAILABLE = True
```