nemo_automodel.components.models.gemma4_moe.state_dict_adapter


State-dict adapter for Gemma4 MoE.

HF Gemma4 MoE (eevee-4 26B-A4B) stores expert weights as 3-D tensors:

    layers.{L}.moe.gate_up_proj       # [n_experts, 2*expert_inter_size, hidden_size]
    layers.{L}.moe.down_proj          # [n_experts, hidden_size, expert_inter_size]
    layers.{L}.moe.per_expert_scale   # [n_experts]

NeMo uses a transposed layout with concatenated gate+up:

    layers.{L}.moe.experts.gate_and_up_projs   # [n_experts, hidden_size, 2*expert_inter_size]
    layers.{L}.moe.experts.down_projs          # [n_experts, expert_inter_size, hidden_size]
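The layout change between the two formats can be illustrated with a minimal sketch. It assumes the conversion is a plain swap of the last two dimensions, as the shapes above suggest; the sizes and variable names are illustrative and are not taken from the adapter's source. The `per_expert_scale` handling is left out here.

```python
import torch

# Small illustrative sizes, not the real model dimensions.
n_experts, hidden_size, expert_inter_size = 4, 8, 16

# HF-style tensors (random stand-ins for real checkpoint weights).
hf_gate_up_proj = torch.randn(n_experts, 2 * expert_inter_size, hidden_size)
hf_down_proj = torch.randn(n_experts, hidden_size, expert_inter_size)

# Swap the last two dimensions to obtain the NeMo-style layout.
gate_and_up_projs = hf_gate_up_proj.transpose(1, 2).contiguous()
down_projs = hf_down_proj.transpose(1, 2).contiguous()

print(gate_and_up_projs.shape)  # torch.Size([4, 8, 32])
print(down_projs.shape)         # torch.Size([4, 16, 8])
```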

Additionally, the Gemma4 router is mapped to the NeMo Gemma4Gate:

    HF:   .router.proj.weight / .router.scale
    NeMo: .moe.gate.proj.weight / .moe.gate.scale
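As a rough illustration of the router key mapping, the sketch below renames the two suffixes listed above and leaves every other key unchanged. `rename_router_key` and the `layers.0` prefix are hypothetical and only serve the example; the adapter's actual renaming logic is not shown in this page.

```python
# Suffix mapping copied from the HF / NeMo pair listed above.
ROUTER_SUFFIXES = {
    ".router.proj.weight": ".moe.gate.proj.weight",
    ".router.scale": ".moe.gate.scale",
}


def rename_router_key(hf_key: str) -> str:
    """Map an HF router key to its NeMo gate key (hypothetical helper)."""
    for hf_suffix, nemo_suffix in ROUTER_SUFFIXES.items():
        if hf_key.endswith(hf_suffix):
            return hf_key[: -len(hf_suffix)] + nemo_suffix
    return hf_key  # not a router key: leave untouched


print(rename_router_key("layers.0.router.proj.weight"))  # layers.0.moe.gate.proj.weight
print(rename_router_key("layers.0.router.scale"))        # layers.0.moe.gate.scale
```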

The per_expert_scale is absorbed into down_projs during from_hf. When saving back to HF, per_expert_scale is emitted as ones (scale already baked into the weights).
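The following sketch shows why emitting ones on export is consistent. It assumes that `per_expert_scale` multiplies each expert's down-projection output linearly, so that it can be folded into the weights; that assumption, the sizes, and all variable names are illustrative rather than taken from the adapter.

```python
import torch

torch.manual_seed(0)
n_experts, hidden_size, expert_inter_size = 4, 8, 16

hf_down_proj = torch.randn(n_experts, hidden_size, expert_inter_size)
per_expert_scale = torch.rand(n_experts) + 0.5

# Import direction: fold the scale into the weights, then move to the NeMo layout.
down_projs = (hf_down_proj * per_expert_scale[:, None, None]).transpose(1, 2).contiguous()

# For one expert, the scaled HF output equals the output of the absorbed weights.
e = 2
h = torch.randn(expert_inter_size)
scaled_hf_output = per_expert_scale[e] * (hf_down_proj[e] @ h)  # [hidden_size]
absorbed_output = h @ down_projs[e]                             # [hidden_size]
assert torch.allclose(scaled_hf_output, absorbed_output, atol=1e-4)

# Export direction: the weights keep the baked-in scale, so the scale is all ones.
exported_down_proj = down_projs.transpose(1, 2).contiguous()
exported_scale = torch.ones(n_experts)
exported_output = exported_scale[e] * (exported_down_proj[e] @ h)
assert torch.allclose(exported_output, scaled_hf_output, atol=1e-4)
```

Note that under this scheme a round trip does not reproduce the original `down_proj` and `per_expert_scale` tensors individually, only their product.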

Module Contents

Classes

| Name | Description |
| --- | --- |
| Gemma4MoEStateDictAdapter | Converts between HF Gemma4 MoE checkpoints and the NeMo format. |

API

class nemo_automodel.components.models.gemma4_moe.state_dict_adapter.Gemma4MoEStateDictAdapter(
config: typing.Any,
moe_config: nemo_automodel.components.moe.layers.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
dtype: torch.dtype = torch.float32
)

Bases: StateDictAdapter

Converts between HF Gemma4 MoE checkpoints and the NeMo format.

nemo_automodel.components.models.gemma4_moe.state_dict_adapter.Gemma4MoEStateDictAdapter._gather_expert_tensor(
tensor: torch.Tensor,
device_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh],
n_experts: int
) -> torch.Tensor

Gather EP-sharded expert tensor across ranks into a full tensor.
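The idea of gathering can be simulated in a single process. This sketch assumes that expert parallelism splits the expert dimension (dim 0) evenly across ranks; `gather_expert_shards` is a hypothetical stand-in, and a real run would obtain the shards through a collective such as `torch.distributed.all_gather` rather than from a local list.

```python
import torch


def gather_expert_shards(shards: list[torch.Tensor], n_experts: int) -> torch.Tensor:
    """Concatenate per-rank expert shards along dim 0 (hypothetical helper)."""
    full = torch.cat(shards, dim=0)
    if full.shape[0] != n_experts:
        raise ValueError(f"expected {n_experts} experts, got {full.shape[0]}")
    return full


n_experts, hidden_size, expert_inter_size = 8, 4, 6
full_tensor = torch.randn(n_experts, hidden_size, 2 * expert_inter_size)

# Simulate 4 EP ranks, each holding 2 experts.
shards = list(torch.chunk(full_tensor, 4, dim=0))
gathered = gather_expert_shards(shards, n_experts)

print(gathered.shape)  # torch.Size([8, 4, 12])
```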

nemo_automodel.components.models.gemma4_moe.state_dict_adapter.Gemma4MoEStateDictAdapter.convert_single_tensor_to_hf(
fqn: str,
tensor: typing.Any,
kwargs = {}
) -> list[tuple[str, typing.Any]]

Convert a single native tensor back to HF format.

Handles per-tensor conversion for weight streaming (IPC refit) required in RL training:

  • Router keys: moe.gate.{proj.weight,scale} -> router.{proj.weight,scale}
  • Expert gate_and_up_projs: transpose [E, hidden, 2*inter] -> [E, 2*inter, hidden] and rename to experts.gate_up_proj
  • Expert down_projs: transpose [E, inter, hidden] -> [E, hidden, inter], rename to experts.down_proj, and emit router.per_expert_scale as ones
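The three cases above can be sketched as a standalone function. This is not the adapter's implementation: `convert_single_tensor_to_hf_sketch` is a hypothetical name, and the exact key prefixes on the HF side are assumptions based on the renames listed in the bullets.

```python
import torch

GATE_UP_SUFFIX = ".moe.experts.gate_and_up_projs"
DOWN_SUFFIX = ".moe.experts.down_projs"


def convert_single_tensor_to_hf_sketch(fqn: str, tensor: torch.Tensor) -> list[tuple[str, torch.Tensor]]:
    """Hypothetical stand-in for the per-tensor conversion described above."""
    # Router keys: rename only, no change to the tensor.
    if ".moe.gate." in fqn:
        return [(fqn.replace(".moe.gate.", ".router."), tensor)]

    # gate_and_up_projs: [E, hidden, 2*inter] -> [E, 2*inter, hidden].
    if fqn.endswith(GATE_UP_SUFFIX):
        prefix = fqn[: -len(GATE_UP_SUFFIX)]
        return [(prefix + ".experts.gate_up_proj", tensor.transpose(1, 2).contiguous())]

    # down_projs: [E, inter, hidden] -> [E, hidden, inter], plus a scale of ones.
    if fqn.endswith(DOWN_SUFFIX):
        prefix = fqn[: -len(DOWN_SUFFIX)]
        ones = torch.ones(tensor.shape[0], dtype=tensor.dtype)
        return [
            (prefix + ".experts.down_proj", tensor.transpose(1, 2).contiguous()),
            (prefix + ".router.per_expert_scale", ones),
        ]

    # Anything else passes through unchanged.
    return [(fqn, tensor)]


down_result = convert_single_tensor_to_hf_sketch("layers.0.moe.experts.down_projs", torch.randn(4, 16, 8))
for name, value in down_result:
    print(name, tuple(value.shape))
```

Returning a list of pairs lets a single native tensor expand into several HF entries, which is what the down_projs case needs.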
nemo_automodel.components.models.gemma4_moe.state_dict_adapter.Gemma4MoEStateDictAdapter.from_hf(
hf_state_dict: dict[str, typing.Any],
device_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None,
kwargs = {}
) -> dict[str, typing.Any]

nemo_automodel.components.models.gemma4_moe.state_dict_adapter.Gemma4MoEStateDictAdapter.to_hf(
state_dict: dict[str, typing.Any],
exclude_key_regex: typing.Optional[str] = None,
quantization: bool = False,
kwargs = {}
) -> dict[str, typing.Any]