nemo_automodel.components.models.diffusion_gemma.state_dict_adapter


State-dict adapter for diffusion_gemma.

The HF checkpoint stores all transformer weights under model.decoder.* (the “decoder” of the encoder-decoder framing). The encoder’s text weights and the lm_head are tied (and therefore absent from the checkpoint) and are reconstructed at load time; only model.encoder.language_model.layers.{L}.layer_scalar buffers are emitted (duplicates of the decoder layer_scalar) and are dropped here. The native model uses a single shared stack at model.* (no encoder/decoder split), so the mapping is essentially model.decoder.X -> model.X plus the Gemma4 MoE expert/router transforms shared with gemma4_moe:

  • decoder.layers.{L}.experts.gate_up_proj [E, 2*inter, hidden] -> model.layers.{L}.moe.experts.gate_and_up_projs [E, hidden, 2*inter]
  • decoder.layers.{L}.experts.down_proj [E, hidden, inter] -> model.layers.{L}.moe.experts.down_projs [E, inter, hidden] (with router.per_expert_scale folded in)
  • decoder.layers.{L}.router.{proj.weight,scale} -> model.layers.{L}.moe.gate.{proj.weight,scale}
  • decoder.embed_tokens.weight -> model.embed_tokens.weight (also tied to lm_head.weight)
  • decoder.norm.weight -> model.norm.weight
  • decoder.self_conditioning.* -> model.self_conditioning.*

Full-attention layers ({5, 11, 17, 23, 29}) have no v_proj in the checkpoint; those keys are simply absent and the model has no v_proj parameter there, so no special handling is needed (the pass-through preserves whatever is present).
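The name-only part of the mapping above can be sketched in a few lines. This is a minimal illustration, not the adapter's actual code: rename_hf_key is a hypothetical helper, the constants simply mirror _DECODER_PREFIX and _NATIVE_PREFIX, and passing unrelated keys through unchanged is an assumption. Expert tensors and router.per_expert_scale also need the tensor transforms listed above, which this sketch leaves out.

```python
import re

DECODER_PREFIX = "model.decoder."  # mirrors _DECODER_PREFIX
NATIVE_PREFIX = "model."           # mirrors _NATIVE_PREFIX

# Encoder layer_scalar buffers duplicate the decoder ones and are dropped.
_ENCODER_SCALAR = re.compile(
    r"^model\.encoder\.language_model\.layers\.\d+\.layer_scalar$"
)
# Router projection and scale move under moe.gate in the native layout.
_ROUTER = re.compile(r"^(model\.layers\.\d+)\.router\.(proj\.weight|scale)$")


def rename_hf_key(key):
    """Return the native name for an HF key, or None if the key is dropped.

    Hypothetical helper: covers renames only, not the expert transposes
    or the per_expert_scale fold.
    """
    if _ENCODER_SCALAR.match(key):
        return None
    if not key.startswith(DECODER_PREFIX):
        return key  # assumption: anything else passes through unchanged
    name = NATIVE_PREFIX + key[len(DECODER_PREFIX):]
    return _ROUTER.sub(r"\1.moe.gate.\2", name)


print(rename_hf_key("model.decoder.norm.weight"))
print(rename_hf_key("model.decoder.layers.3.router.scale"))
```

Because absent keys (such as v_proj on full-attention layers) never enter the loop, a pure rename like this needs no special case for them.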

Module Contents

Classes

Name | Description
DiffusionGemmaStateDictAdapter | Converts between HF diffusion_gemma checkpoints and the NeMo layout.

Data

_DECODER_PREFIX

_NATIVE_PREFIX

API

class nemo_automodel.components.models.diffusion_gemma.state_dict_adapter.DiffusionGemmaStateDictAdapter(
config: typing.Any,
moe_config: nemo_automodel.components.moe.layers.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
dtype: torch.dtype = torch.float32
)

Bases: StateDictAdapter

Converts between HF diffusion_gemma checkpoints and the NeMo layout.

nemo_automodel.components.models.diffusion_gemma.state_dict_adapter.DiffusionGemmaStateDictAdapter._fold_per_expert_scale(
down: torch.Tensor,
per_expert_scale: torch.Tensor
) -> torch.Tensor
staticmethod

Fold per_expert_scale[E] into down[E, inter, hidden]: each expert's [inter, hidden] slice is multiplied by that expert's scalar.

Under pure FSDP the standard DCP load path leaves down as a Shard(0) DTensor (each rank holds [E/world, ...]) while per_expert_scale is loaded replicated as a plain [E] tensor. A plain down * scale[:, None, None] then broadcasts a global-[E] factor against a local-[E/world] shard and fails. Wrap the scale as a DTensor replicated on down’s mesh so the multiply runs DTensor-vs-DTensor and the scale is sliced per shard.
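The shape mismatch described above can be illustrated without any distributed setup. The sketch below uses nested Python lists as a stand-in for tensors; fold_per_expert_scale and local_scale are hypothetical names, and the explicit slicing only imitates what a replicated DTensor scale achieves per shard. It assumes the expert count divides evenly across ranks.

```python
def fold_per_expert_scale(down, scale):
    """Multiply every element of down[e] by scale[e].

    down is a nested list shaped [E][inter][hidden]; scale has length E.
    """
    if len(down) != len(scale):
        raise ValueError(
            f"expert count mismatch: down has {len(down)}, scale has {len(scale)}"
        )
    return [
        [[value * s for value in row] for row in expert]
        for expert, s in zip(down, scale)
    ]


def local_scale(scale, rank, world_size):
    """Slice the global [E] scale down to the experts held by one rank."""
    per_rank = len(scale) // world_size  # assumes E % world_size == 0
    return scale[rank * per_rank:(rank + 1) * per_rank]


global_scale = [1.0, 2.0, 3.0, 4.0]          # E = 4, replicated on every rank
rank1_down = [[[1.0, 1.0]], [[1.0, 1.0]]]    # rank 1 of 2 holds experts 2 and 3

# Folding the global scale into the local shard fails on the expert dimension;
# folding the matching slice works.
folded = fold_per_expert_scale(rank1_down, local_scale(global_scale, 1, 2))
print(folded)
```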

nemo_automodel.components.models.diffusion_gemma.state_dict_adapter.DiffusionGemmaStateDictAdapter._gather_expert_tensor(
tensor: torch.Tensor,
device_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh],
n_experts: int
) -> torch.Tensor

Map a stacked expert tensor to its HF (global [n_experts, ...]) form.

device_mesh is the MoE/EP mesh and is None under pure FSDP (ep_size=1). In that case the tensor is either a plain [E, ...] tensor or an FSDP Shard(0) DTensor whose global first dimension is already n_experts; both are returned unchanged so DCP sees the checkpoint’s global expert shape and reads/writes each rank’s shard itself. Calling to_local() here would expose the per-rank [E/world] shard as the apparent global shape and break the DCP size match (saved [E, ...] vs current [E/world, ...]). With an EP mesh the tensor is gathered across EP ranks into the full [n_experts, ...].
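The size mismatch that to_local() would cause comes down to simple shape arithmetic. The following sketch is illustrative only: local_shard_shape is a hypothetical helper, the sizes are made up, and even sharding along dimension 0 is assumed.

```python
def local_shard_shape(global_shape, world_size):
    """Shape of one rank's slice when dimension 0 is split evenly."""
    n_experts = global_shape[0]
    if n_experts % world_size != 0:
        raise ValueError("this sketch assumes experts divide evenly across ranks")
    return (n_experts // world_size,) + tuple(global_shape[1:])


saved_shape = (8, 16, 32)   # hypothetical [E, inter, hidden] stored in the checkpoint
per_rank = local_shard_shape(saved_shape, world_size=4)

# A size check against the checkpoint passes only when the tensor still
# reports its global shape, not the per-rank slice.
print(per_rank, per_rank == saved_shape)
```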

nemo_automodel.components.models.diffusion_gemma.state_dict_adapter.DiffusionGemmaStateDictAdapter.convert_single_tensor_to_hf(
fqn: str,
tensor: typing.Any,
kwargs = {}
) -> list[tuple[str, typing.Any]]

Convert a single native tensor back to HF format (weight-streaming refit).

Mirrors to_hf per tensor: router rename, expert transpose plus per_expert_scale emission, the model.X -> model.decoder.X re-prefix, and dropping the tied lm_head.weight.
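For tensors that need only a rename, the reverse direction might look like the sketch below. native_key_to_hf is a hypothetical helper rather than the adapter's method, and it assumes the tied head is stored under the exact key lm_head.weight. Expert tensors are omitted because they also require the transpose and per_expert_scale emission.

```python
import re

# moe.gate projection and scale go back to the HF router names.
_GATE = re.compile(r"^(model\.layers\.\d+)\.moe\.gate\.(proj\.weight|scale)$")


def native_key_to_hf(fqn):
    """Return a list of HF names for one native key (empty if dropped).

    Hypothetical helper covering renames only.
    """
    if fqn == "lm_head.weight":
        return []  # tied to embed_tokens, so not written to the checkpoint
    name = _GATE.sub(r"\1.router.\2", fqn)
    if name.startswith("model."):
        name = "model.decoder." + name[len("model."):]
    return [name]


print(native_key_to_hf("model.layers.0.moe.gate.proj.weight"))
print(native_key_to_hf("lm_head.weight"))
```

Returning a list mirrors the documented return type of convert_single_tensor_to_hf, where one native tensor may map to zero, one, or several HF entries.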

nemo_automodel.components.models.diffusion_gemma.state_dict_adapter.DiffusionGemmaStateDictAdapter.from_hf(
hf_state_dict: dict[str, typing.Any],
device_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None,
kwargs = {}
) -> dict[str, typing.Any]

nemo_automodel.components.models.diffusion_gemma.state_dict_adapter.DiffusionGemmaStateDictAdapter.to_hf(
state_dict: dict[str, typing.Any],
exclude_key_regex: typing.Optional[str] = None,
quantization: bool = False,
kwargs = {}
) -> dict[str, typing.Any]

nemo_automodel.components.models.diffusion_gemma.state_dict_adapter._DECODER_PREFIX = 'model.decoder.'

nemo_automodel.components.models.diffusion_gemma.state_dict_adapter._NATIVE_PREFIX = 'model.'