
# nemo_automodel.components.models.diffusion_gemma.state_dict_adapter

State-dict adapter for `diffusion_gemma`.

The HF checkpoint stores all transformer weights under `model.decoder.*` (the
"decoder" of the encoder-decoder framing). The encoder's text weights and the
`lm_head` are tied, so they are absent from the checkpoint and are
reconstructed at load time. The only encoder entries emitted are the
`model.encoder.language_model.layers.{L}.layer_scalar` buffers, which duplicate
the decoder `layer_scalar` and are dropped here. The native model uses a
**single shared stack** at `model.*` (no encoder/decoder split), so the mapping
is essentially `model.decoder.X -> model.X` plus the Gemma4 MoE expert/router
transforms shared with `gemma4_moe`:

* `decoder.layers.{L}.experts.gate_up_proj`  \[E, 2*inter, hidden]
  -> `model.layers.{L}.moe.experts.gate_and_up_projs`  \[E, hidden, 2*inter]
* `decoder.layers.{L}.experts.down_proj`     \[E, hidden, inter]
  -> `model.layers.{L}.moe.experts.down_projs`  \[E, inter, hidden]
  (with `router.per_expert_scale` folded in)
* `decoder.layers.{L}.router.{proj.weight,scale}` -> `model.layers.{L}.moe.gate.{proj.weight,scale}`
* `decoder.embed_tokens.weight` -> `model.embed_tokens.weight` (also tied to `lm_head.weight`)
* `decoder.norm.weight` -> `model.norm.weight`
* `decoder.self_conditioning.*` -> `model.self_conditioning.*`
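
The shape changes in the list above can be sketched at the tensor level. This is a minimal illustration, not the adapter's code: it assumes the layout change is a plain swap of the last two dimensions (the real adapter may do more, for example reorder the gate/up halves), and the sizes are toy values chosen only for the example.

```python
import torch

E, inter, hidden = 4, 3, 2  # toy sizes, for illustration only

# HF layouts as listed above.
hf_gate_up = torch.randn(E, 2 * inter, hidden)  # [E, 2*inter, hidden]
hf_down = torch.randn(E, hidden, inter)         # [E, hidden, inter]

# Assumption: the native layout is the HF layout with the last two dims swapped.
native_gate_and_up = hf_gate_up.transpose(1, 2).contiguous()  # [E, hidden, 2*inter]
native_down = hf_down.transpose(1, 2).contiguous()            # [E, inter, hidden]

print(native_gate_and_up.shape)  # torch.Size([4, 2, 6])
print(native_down.shape)         # torch.Size([4, 3, 2])
```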

Full-attention layers (5, 11, 17, 23, and 29) have no `v_proj` in the checkpoint.
Those keys are simply absent and the model has no `v_proj` parameter there, so
no special handling is needed (the pass-through preserves whatever is present).
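
The key-name side of the mapping described above might look like the following sketch. `remap_hf_key` and its regex table are hypothetical names introduced for illustration, not part of the module. The sketch handles names only; the tensor transposes and the `per_expert_scale` fold are not shown, and the `self_attn.q_proj` key in the usage lines is an assumed example of a pass-through key.

```python
import re
from typing import Optional

DECODER_PREFIX = "model.decoder."  # same values as the module constants
NATIVE_PREFIX = "model."

# Encoder buffers that duplicate the decoder's layer_scalar and are dropped.
_ENCODER_SCALAR = re.compile(
    r"^model\.encoder\.language_model\.layers\.\d+\.layer_scalar$"
)
# (pattern, replacement) pairs for the MoE expert/router renames.
_RENAMES = [
    (re.compile(r"\.experts\.gate_up_proj$"), ".moe.experts.gate_and_up_projs"),
    (re.compile(r"\.experts\.down_proj$"), ".moe.experts.down_projs"),
    (re.compile(r"\.router\.(proj\.weight|scale)$"), r".moe.gate.\1"),
]


def remap_hf_key(key: str) -> Optional[str]:
    """Return the native key for an HF key, or None if the key is dropped."""
    if _ENCODER_SCALAR.match(key):
        return None  # duplicate of the decoder layer_scalar
    if key.endswith(".router.per_expert_scale"):
        return None  # assumption: consumed by the fold into down_projs
    if not key.startswith(DECODER_PREFIX):
        return key  # anything else passes through unchanged
    new_key = NATIVE_PREFIX + key[len(DECODER_PREFIX):]
    for pattern, replacement in _RENAMES:
        new_key = pattern.sub(replacement, new_key)
    return new_key


print(remap_hf_key("model.decoder.layers.3.router.proj.weight"))
print(remap_hf_key("model.decoder.layers.0.self_attn.q_proj.weight"))
```

Because absent keys are never looked up, a layer without `v_proj` needs no branch of its own in a mapping written this way.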

## Module Contents

### Classes

| Name                                                                                                                                    | Description                                                            |
| --------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------- |
| [`DiffusionGemmaStateDictAdapter`](#nemo_automodel-components-models-diffusion_gemma-state_dict_adapter-DiffusionGemmaStateDictAdapter) | Converts between HF `diffusion_gemma` checkpoints and the NeMo layout. |

### Data

[`_DECODER_PREFIX`](#nemo_automodel-components-models-diffusion_gemma-state_dict_adapter-_DECODER_PREFIX)

[`_NATIVE_PREFIX`](#nemo_automodel-components-models-diffusion_gemma-state_dict_adapter-_NATIVE_PREFIX)

### API

```python
class nemo_automodel.components.models.diffusion_gemma.state_dict_adapter.DiffusionGemmaStateDictAdapter(
    config: typing.Any,
    moe_config: nemo_automodel.components.moe.layers.MoEConfig,
    backend: nemo_automodel.components.models.common.BackendConfig,
    dtype: torch.dtype = torch.float32
)
```

**Bases:** [StateDictAdapter](/nemo-automodel/nemo_automodel/components/checkpoint/state_dict_adapter#nemo_automodel-components-checkpoint-state_dict_adapter-StateDictAdapter)

Converts between HF `diffusion_gemma` checkpoints and the NeMo layout.

```python
nemo_automodel.components.models.diffusion_gemma.state_dict_adapter.DiffusionGemmaStateDictAdapter._fold_per_expert_scale(
    down: torch.Tensor,
    per_expert_scale: torch.Tensor
) -> torch.Tensor
```

staticmethod

Fold `per_expert_scale[E]` into `down[E, inter, hidden]` (scale by row).

Under pure FSDP the standard DCP load path leaves `down` as a `Shard(0)`
DTensor (each rank holds `[E/world, ...]`) while `per_expert_scale` is
loaded replicated as a plain `[E]` tensor. A plain `down * scale[:, None,
None]` then broadcasts a global-`[E]` factor against a local-`[E/world]`
shard and fails. Wrap the scale as a DTensor replicated on `down`'s mesh so
the multiply runs DTensor-vs-DTensor and the scale is sliced per shard.
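
The shape problem described above can be reproduced in a single process with plain tensors. This is a sketch under stated assumptions: the sharding is simulated by slicing (no DTensor or process group is involved), the sizes are toy values, and slicing the scale by hand stands in for what the DTensor-vs-DTensor multiply does per shard.

```python
import torch

E, inter, hidden, world = 4, 3, 2, 2  # toy sizes, for illustration only
down = torch.randn(E, inter, hidden)
per_expert_scale = torch.tensor([0.5, 1.0, 2.0, 4.0])  # one factor per expert

# Unsharded reference: scale each expert's [inter, hidden] slice.
folded = down * per_expert_scale[:, None, None]

# Simulate rank 0's Shard(0) slice: it holds only E/world experts.
rank = 0
per_rank = E // world
local_down = down[rank * per_rank:(rank + 1) * per_rank]

# A global [E] scale cannot broadcast against a local [E/world] shard.
try:
    local_down * per_expert_scale[:, None, None]
    mismatch = False
except RuntimeError:
    mismatch = True

# Slicing the scale to the same experts gives the per-shard result.
local_scale = per_expert_scale[rank * per_rank:(rank + 1) * per_rank]
local_folded = local_down * local_scale[:, None, None]
```

Here `local_folded` equals the first `E/world` experts of the unsharded `folded` tensor, which is the outcome the replicated-DTensor wrapping is meant to produce on each rank.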

```python
nemo_automodel.components.models.diffusion_gemma.state_dict_adapter.DiffusionGemmaStateDictAdapter._gather_expert_tensor(
    tensor: torch.Tensor,
    device_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh],
    n_experts: int
) -> torch.Tensor
```

Map a stacked expert tensor to its HF (global `[n_experts, ...]`) form.

`device_mesh` is the MoE/EP mesh and is `None` under pure FSDP
(`ep_size=1`). In that case the tensor is either a plain `[E, ...]`
tensor or an FSDP `Shard(0)` DTensor whose **global** first dimension
is already `n_experts`; both are returned unchanged so DCP sees the
checkpoint's global expert shape and reads/writes each rank's shard
itself. Calling `to_local()` here would expose the per-rank `[E/world]`
shard as the apparent global shape and break the DCP size match
(`saved [E, ...] vs current [E/world, ...]`). With an EP mesh the
tensor is gathered across EP ranks into the full `[n_experts, ...]`.
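
For the EP case, the gather amounts to reassembling per-rank expert slices into one global tensor. The sketch below simulates that in a single process; `gather_expert_shards` is a hypothetical helper, and it assumes each EP rank owns a contiguous block of experts in rank order, which the source does not state.

```python
import torch


def gather_expert_shards(local_shards, n_experts):
    """Concatenate per-rank expert slices into a [n_experts, ...] tensor."""
    full = torch.cat(local_shards, dim=0)
    if full.shape[0] != n_experts:
        raise ValueError(
            f"expected {n_experts} experts after gather, got {full.shape[0]}"
        )
    return full


n_experts, inter, hidden, ep_size = 4, 3, 2, 2  # toy sizes
global_tensor = torch.arange(
    n_experts * inter * hidden, dtype=torch.float32
).reshape(n_experts, inter, hidden)

# Each simulated EP rank holds n_experts / ep_size experts.
shards = list(global_tensor.chunk(ep_size, dim=0))
gathered = gather_expert_shards(shards, n_experts)
```

Under pure FSDP no such step is needed, since the tensor's global first dimension is already `n_experts`.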

```python
nemo_automodel.components.models.diffusion_gemma.state_dict_adapter.DiffusionGemmaStateDictAdapter.convert_single_tensor_to_hf(
    fqn: str,
    tensor: typing.Any,
    kwargs = {}
) -> list[tuple[str, typing.Any]]
```

Convert a single native tensor back to HF format (weight-streaming refit).

Mirrors `to_hf` per-tensor: router rename, expert transpose +
`per_expert_scale` emission, `model.X -> model.decoder.X` re-prefix,
and dropping the tied `lm_head.weight`.
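
The per-tensor steps listed above can be sketched for key names alone. `native_key_to_hf_keys` is a hypothetical helper, not the adapter's method: it returns only the HF key names a native key would produce, and it leaves out the tensor values (the transpose and the contents of the emitted `per_expert_scale`), which the source does not specify.

```python
import re

DECODER_PREFIX = "model.decoder."  # same values as the module constants
NATIVE_PREFIX = "model."
_DOWN_SUFFIX = ".moe.experts.down_projs"


def native_key_to_hf_keys(fqn: str) -> list[str]:
    """Return the HF key names produced for one native key (names only)."""
    if fqn == "lm_head.weight":
        return []  # tied to embed_tokens, so nothing is emitted
    if not fqn.startswith(NATIVE_PREFIX):
        return [fqn]
    key = DECODER_PREFIX + fqn[len(NATIVE_PREFIX):]
    # Router rename: moe.gate.* -> router.*
    key = re.sub(r"\.moe\.gate\.(proj\.weight|scale)$", r".router.\1", key)
    key = re.sub(r"\.moe\.experts\.gate_and_up_projs$", ".experts.gate_up_proj", key)
    if key.endswith(_DOWN_SUFFIX):
        layer = key[: -len(_DOWN_SUFFIX)]
        # One native tensor yields two HF entries.
        return [layer + ".experts.down_proj", layer + ".router.per_expert_scale"]
    return [key]


print(native_key_to_hf_keys("model.layers.2.moe.experts.down_projs"))
```

Returning a list mirrors the method's `list[tuple[str, Any]]` return type, which allows a single native tensor to map to zero, one, or several HF entries.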

```python
nemo_automodel.components.models.diffusion_gemma.state_dict_adapter.DiffusionGemmaStateDictAdapter.from_hf(
    hf_state_dict: dict[str, typing.Any],
    device_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None,
    kwargs = {}
) -> dict[str, typing.Any]
```

```python
nemo_automodel.components.models.diffusion_gemma.state_dict_adapter.DiffusionGemmaStateDictAdapter.to_hf(
    state_dict: dict[str, typing.Any],
    exclude_key_regex: typing.Optional[str] = None,
    quantization: bool = False,
    kwargs = {}
) -> dict[str, typing.Any]
```

```python
nemo_automodel.components.models.diffusion_gemma.state_dict_adapter._DECODER_PREFIX = 'model.decoder.'
```

```python
nemo_automodel.components.models.diffusion_gemma.state_dict_adapter._NATIVE_PREFIX = 'model.'
```