nemo_automodel.components.models.diffusion_gemma.state_dict_adapter
State-dict adapter for diffusion_gemma.
The HF checkpoint stores all transformer weights under model.decoder.* (the
“decoder” of the encoder-decoder framing). The encoder’s text weights and the
lm_head are tied, so they are absent from the checkpoint and are reconstructed
at load time. The only encoder keys the checkpoint does contain are the
model.encoder.language_model.layers.{L}.layer_scalar buffers, which duplicate
the decoder layer_scalar and are dropped here. The native model uses a single
shared stack at model.* (no encoder/decoder split), so the mapping is
essentially model.decoder.X -> model.X, plus the Gemma4 MoE expert/router
transforms shared with gemma4_moe:
- decoder.layers.{L}.experts.gate_up_proj [E, 2inter, hidden] -> model.layers.{L}.moe.experts.gate_and_up_projs [E, hidden, 2inter]
- decoder.layers.{L}.experts.down_proj [E, hidden, inter] -> model.layers.{L}.moe.experts.down_projs [E, inter, hidden] (with router.per_expert_scale folded in)
- decoder.layers.{L}.router.{proj.weight,scale} -> model.layers.{L}.moe.gate.{proj.weight,scale}
- decoder.embed_tokens.weight -> model.embed_tokens.weight (also tied to lm_head.weight)
- decoder.norm.weight -> model.norm.weight
- decoder.self_conditioning.* -> model.self_conditioning.*
Full-attention layers ({5, 11, 17, 23, 29}) have no v_proj in the checkpoint;
those keys are simply absent and the model has no v_proj parameter there, so
no special handling is needed (the pass-through preserves whatever is present).
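A minimal sketch of the key renaming above, assuming fully prefixed HF keys (model.decoder.*, as implied by the model.decoder.X -> model.X rule). The helper name hf_key_to_native and the regex-based approach are hypothetical, not the adapter's actual code. It renames keys only; the expert transposes and the per_expert_scale fold are left out. Keys with no special rule, including any v_proj that happens to be present, fall through to the plain prefix rewrite.

```python
import re
from typing import Optional

# Hypothetical patterns; key spellings follow the mapping list above.
_ENCODER_SCALAR = re.compile(r"^model\.encoder\.language_model\.layers\.\d+\.layer_scalar$")
_PER_EXPERT_SCALE = re.compile(r"^model\.decoder\.layers\.\d+\.router\.per_expert_scale$")
_ROUTER = re.compile(r"^model\.decoder\.layers\.(\d+)\.router\.(proj\.weight|scale)$")
_EXPERTS = re.compile(r"^model\.decoder\.layers\.(\d+)\.experts\.(gate_up_proj|down_proj)$")
_EXPERT_NAMES = {"gate_up_proj": "gate_and_up_projs", "down_proj": "down_projs"}


def hf_key_to_native(key: str) -> Optional[str]:
    """Return the native name for an HF key, or None if the key is dropped."""
    if _ENCODER_SCALAR.match(key):
        return None  # duplicate of the decoder layer_scalar
    if _PER_EXPERT_SCALE.match(key):
        return None  # consumed by the fold into down_projs
    m = _ROUTER.match(key)
    if m:
        return f"model.layers.{m.group(1)}.moe.gate.{m.group(2)}"
    m = _EXPERTS.match(key)
    if m:
        return f"model.layers.{m.group(1)}.moe.experts.{_EXPERT_NAMES[m.group(2)]}"
    if key.startswith("model.decoder."):
        # Generic rule model.decoder.X -> model.X: covers embed_tokens, norm,
        # self_conditioning and any attention projections that are present.
        return "model." + key[len("model.decoder."):]
    return key  # anything else passes through unchanged
```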
API
Bases: StateDictAdapter
Converts between HF diffusion_gemma checkpoints and the NeMo layout.
Fold per_expert_scale[E] into down[E, inter, hidden]: each expert slice down[e] is multiplied by per_expert_scale[e].
Under pure FSDP the standard DCP load path leaves down as a Shard(0)
DTensor (each rank holds [E/world, ...]) while per_expert_scale is
loaded replicated as a plain [E] tensor. A plain down * scale[:, None, None] then broadcasts a global-[E] factor against a local-[E/world]
shard and fails. Wrap the scale as a DTensor replicated on down’s mesh so
the multiply runs DTensor-vs-DTensor and the scale is sliced per shard.
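The fold and the sharding problem can be illustrated with NumPy standing in for torch. This is a sketch under stated assumptions: fold_per_expert_scale is a hypothetical name, and the Shard(0) layout is simulated by slicing the expert dimension by hand rather than with DTensor. The per-shard slice of the scale, which a replicated DTensor would provide automatically, is therefore done explicitly here.

```python
import numpy as np


def fold_per_expert_scale(down: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Multiply every element of expert e's [inter, hidden] slice by scale[e]."""
    # down: [E, inter, hidden] (native layout); scale: [E]
    return down * scale[:, None, None]


E, inter, hidden, world = 4, 3, 2, 2
rng = np.random.default_rng(0)
down = rng.standard_normal((E, inter, hidden))
scale = rng.standard_normal(E)

# Global fold on the full (unsharded) tensor.
folded_global = fold_per_expert_scale(down, scale)

# Folding into the weight equals scaling that expert's output.
e = 1
h = rng.standard_normal(inter)
expert_out_folded = h @ folded_global[e]
expert_out_scaled = scale[e] * (h @ down[e])

# Simulated Shard(0): rank r owns experts [r * per_rank, (r + 1) * per_rank).
per_rank = E // world
shards = [down[r * per_rank:(r + 1) * per_rank] for r in range(world)]

# A local [E/world, ...] shard times the global [E] scale cannot broadcast.
try:
    shards[0] * scale[:, None, None]
    broadcast_ok = True
except ValueError:
    broadcast_ok = False

# Slicing the replicated scale to each rank's experts reproduces the global fold.
folded_shards = [
    fold_per_expert_scale(s, scale[r * per_rank:(r + 1) * per_rank])
    for r, s in enumerate(shards)
]
folded_from_shards = np.concatenate(folded_shards, axis=0)
```

The check on h also shows why the fold is lossless: down is the last linear map in the expert, so scaling its weights by scale[e] scales that expert's output by the same factor.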
Map a stacked expert tensor to its HF (global [n_experts, ...]) form.
device_mesh is the MoE/EP mesh and is None under pure FSDP
(ep_size=1). In that case the tensor is either a plain [E, ...]
tensor or an FSDP Shard(0) DTensor whose global first dimension
is already n_experts; both are returned unchanged so DCP sees the
checkpoint’s global expert shape and reads/writes each rank’s shard
itself. Calling to_local() here would expose the per-rank [E/world]
shard as the apparent global shape and break the DCP size match
(saved [E, ...] vs current [E/world, ...]). With an EP mesh the
tensor is gathered across EP ranks into the full [n_experts, ...].
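A NumPy sketch of the two branches, assuming the EP all-gather can be modeled as concatenating every rank's expert shard along dimension 0. expert_tensor_to_hf and its ep_shards argument are hypothetical; passing ep_shards=None plays the role of device_mesh=None. NumPy cannot express an FSDP DTensor whose global shape differs from its local one, so the no-mesh branch simply returns its input untouched, which is the property the docstring relies on.

```python
from typing import List, Optional

import numpy as np


def expert_tensor_to_hf(
    tensor: np.ndarray,
    ep_shards: Optional[List[np.ndarray]],
    n_experts: int,
) -> np.ndarray:
    # No EP mesh (pure FSDP, ep_size=1): return the tensor as-is so the
    # checkpoint sees its global [n_experts, ...] shape; never unwrap it.
    if ep_shards is None:
        return tensor
    # EP mesh: stand-in for the all-gather across EP ranks.
    full = np.concatenate(ep_shards, axis=0)
    if full.shape[0] != n_experts:
        raise ValueError(f"gathered {full.shape[0]} experts, expected {n_experts}")
    return full


n_experts, ep_size = 8, 4
full_weights = np.arange(n_experts * 3 * 2, dtype=np.float32).reshape(n_experts, 3, 2)
per_rank = n_experts // ep_size
rank_shards = [full_weights[r * per_rank:(r + 1) * per_rank] for r in range(ep_size)]

gathered = expert_tensor_to_hf(rank_shards[0], rank_shards, n_experts)
unchanged = expert_tensor_to_hf(full_weights, None, n_experts)
```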
Convert a single native tensor back to HF format (weight-streaming refit).
Mirrors to_hf on a per-tensor basis: router rename, expert transpose +
per_expert_scale emission, model.X -> model.decoder.X re-prefix,
and dropping the tied lm_head.weight.
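A per-tensor sketch of the reverse mapping, assuming the same fully prefixed key names as the forward direction and NumPy arrays in place of torch tensors. native_tensor_to_hf is a hypothetical name. The value emitted for router.per_expert_scale is an assumption: because the scale has already been folded into down_projs, the sketch writes a neutral all-ones vector so that it is not applied twice; the real adapter may do something else.

```python
import re
from typing import List, Tuple

import numpy as np

_NATIVE_GATE = re.compile(r"^model\.layers\.(\d+)\.moe\.gate\.(proj\.weight|scale)$")
_NATIVE_EXPERTS = re.compile(r"^model\.layers\.(\d+)\.moe\.experts\.(gate_and_up_projs|down_projs)$")
_HF_EXPERT = {"gate_and_up_projs": "gate_up_proj", "down_projs": "down_proj"}


def native_tensor_to_hf(name: str, tensor: np.ndarray) -> List[Tuple[str, np.ndarray]]:
    """Return the (hf_name, hf_tensor) pairs for one native tensor."""
    if name == "lm_head.weight":
        return []  # tied to model.embed_tokens.weight; not stored in the HF checkpoint
    m = _NATIVE_GATE.match(name)
    if m:
        return [(f"model.decoder.layers.{m.group(1)}.router.{m.group(2)}", tensor)]
    m = _NATIVE_EXPERTS.match(name)
    if m:
        layer, kind = m.groups()
        # Native [E, in, out] -> HF [E, out, in]: swap the last two dims.
        out = [(f"model.decoder.layers.{layer}.experts.{_HF_EXPERT[kind]}",
                np.swapaxes(tensor, 1, 2))]
        if kind == "down_projs":
            # ASSUMPTION: the scale already lives inside down, so emit all ones.
            out.append((f"model.decoder.layers.{layer}.router.per_expert_scale",
                        np.ones(tensor.shape[0], dtype=tensor.dtype)))
        return out
    if name.startswith("model."):
        # Generic re-prefix: model.X -> model.decoder.X
        return [("model.decoder." + name[len("model."):], tensor)]
    return [(name, tensor)]
```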