nemo_automodel.components.moe.state_dict_mixin
Module Contents
Classes
API
Mixin class providing MoE state dict conversion utilities.
This mixin provides methods for:
- Expert parallelism calculations (ranges, assignment)
- Format conversion between HuggingFace and native formats
- Both GroupedExperts and DeepEP format support
- DTensor-aware expert loading and conversion
Can be used by any MoE model that needs expert parallelism and format conversion.
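The expert-parallelism range calculation mentioned above can be sketched as follows. This is a minimal illustration, not the mixin's actual code: the function name `expert_range_for_rank` and its parameters are hypothetical, and it assumes experts are divided evenly into contiguous blocks, one block per expert-parallel rank.

```python
def expert_range_for_rank(n_experts: int, ep_size: int, ep_rank: int) -> range:
    """Return the contiguous block of expert ids owned by one EP rank.

    Hypothetical helper: assumes an even, contiguous split of experts.
    """
    if n_experts % ep_size != 0:
        raise ValueError("n_experts must be divisible by ep_size")
    per_rank = n_experts // ep_size
    start = ep_rank * per_rank
    return range(start, start + per_rank)


# With 8 experts over 4 ranks, each rank owns 2 consecutive experts.
ranges = [expert_range_for_rank(8, 4, rank) for rank in range(4)]
```

Under this assumption a rank only needs to load, convert, and validate the experts inside its own range.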
Path segment for experts (e.g., "mlp.experts" or "mixer.experts"). Override in subclass.
Prefix for HuggingFace format keys. Override in subclass.
Check if the MoE uses gated activation (e.g., SwiGLU) or non-gated (e.g., ReLU²).
Native keys loaded in-place via strided views during the most recent from_hf.
MoE experts with a plain local split are loaded by DCP writing the checkpoint tensors
straight through non-contiguous strided views into the model’s grouped expert storage.
Such keys are intentionally absent from the dict from_hf returns (the data is already
in the model) but are NOT missing. _from_hf_w_merged_experts records them here so the
checkpoint loader can exclude them from false “missing” key-diff warnings. The record is
reset at the start of each load by _from_hf_w_merged_experts(reset_view_loaded_keys=True).
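To illustrate why the record matters, here is a rough sketch of how a loader might use it when computing a key diff. The function `unexpected_missing` and the example keys are hypothetical; the real checkpoint loader may structure this differently.

```python
def unexpected_missing(expected_keys, returned_keys, view_loaded_keys):
    """Keys that are truly absent: neither returned nor loaded through views."""
    return sorted(set(expected_keys) - set(returned_keys) - set(view_loaded_keys))


expected = [
    "layers.0.mlp.experts.gate_and_up_projs",
    "layers.0.mlp.experts.down_projs",
    "layers.0.mlp.gate.weight",
    "layers.0.input_layernorm.weight",
]
returned = ["layers.0.mlp.gate.weight"]  # what the conversion returned
view_loaded = {  # keys already written into the model through strided views
    "layers.0.mlp.experts.gate_and_up_projs",
    "layers.0.mlp.experts.down_projs",
}

missing = unexpected_missing(expected, returned, view_loaded)
```

Only the layer-norm key is reported here; the two grouped expert keys are excluded because their data is already in the model.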
Concatenate the weights of separate experts into GroupedExpert weights.
Parameters:
Nested dict structure containing expert weights
Total number of experts expected
Returns: Optional[torch.Tensor]
Stacked tensor if all experts are available for a layer, None otherwise
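The all-or-nothing behavior described above can be sketched like this. It is a simplified illustration: `stack_experts` is a hypothetical name, and it takes a flat dict for one layer and one projection rather than the nested dict structure the real method receives.

```python
from typing import Optional

import torch


def stack_experts(
    per_expert: dict[int, torch.Tensor], n_experts: int
) -> Optional[torch.Tensor]:
    """Stack per-expert weights along a new leading expert dimension.

    Returns None while any expert is still missing (hypothetical sketch).
    """
    if any(i not in per_expert for i in range(n_experts)):
        return None
    return torch.stack([per_expert[i] for i in range(n_experts)], dim=0)


weights = {i: torch.full((2, 3), float(i)) for i in range(4)}
stacked = stack_experts(weights, n_experts=4)  # shape [4, 2, 3]
partial = stack_experts({0: weights[0]}, n_experts=4)  # None: experts 1..3 missing
```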
Convert a grouped MoE expert LoRA tensor to per-expert HF PEFT format.
Handles the four LoRA parameter types produced by GroupedExpertsLoRA /
GroupedExpertsDeepEPLoRA and converts them to per-expert lora_A.weight
/ lora_B.weight keys that HF PEFT understands.
The prefix (e.g. base_model.model.model.) is preserved from the
incoming fqn so that both PEFT and FFT save paths work correctly.
Convert a single merged expert tensor from native format to split HuggingFace format.
When tensor is a model DTensor with a plain (non-DTensor) local
split — i.e. ep_shard == 1 — the per-expert outputs are returned
as non-contiguous strided views into the local storage of the
model’s grouped DTensor instead of newly-allocated contiguous copies.
DCP’s target.copy_(source) then writes safetensors data directly
through the views into the model’s storage, and
_from_hf_w_merged_experts skips the rebuild for the corresponding
native key (tracked in _inplace_loaded_native_keys). For loads of
large MoE checkpoints this avoids tens of GB of per-expert
scratch on top of the already-materialized model.
Save callers must materialize the views before serializing —
safetensors.torch.save rejects non-contiguous tensors. See
_materialize_to_hf_views_for_save in checkpointing.py.
Parameters:
Fully qualified name of the tensor in native format.
The tensor to convert.
When provided, replaces self._hf_prefix in
emitted HF keys. Used to route conversions through namespaces
outside the main backbone, e.g. "mtp." for the MTP head.
Absorbed for forward-compatibility with base callers
that forward arbitrary state-dict kwargs (e.g. exclude_key_regex).
Returns: list[tuple[str, torch.Tensor]]
List of (fqn, tensor) tuples in HuggingFace format, or None if not an expert tensor.
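The strided-view mechanism can be demonstrated with plain tensors. This sketch assumes the per-expert HF weight is the transpose of one grouped slice; the shapes and variable names are illustrative, and a plain tensor stands in for the local storage of the model's DTensor.

```python
import torch

n_experts, dim, inter_dim = 3, 4, 2
# Stand-in for the local grouped expert storage held by the model.
grouped = torch.zeros(n_experts, dim, inter_dim)

# Each per-expert output is a transposed slice: it shares storage with
# `grouped` and is non-contiguous, so no per-expert copy is allocated.
views = [grouped[e].transpose(0, 1) for e in range(n_experts)]

# Simulate a checkpoint loader writing one expert's tensor via copy_.
source = torch.arange(inter_dim * dim, dtype=torch.float32).reshape(inter_dim, dim)
views[1].copy_(source)

# The write landed in the grouped storage; saving such a view would first
# require views[1].contiguous() to materialize it.
written = grouped[1]
```

Because the copy goes through the view, `grouped[1]` now holds the transposed source data while the other experts are untouched.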
Convert HF checkpoint to native format.
For gated activations (SwiGLU, Quick-GEGLU): Creates combined gate_and_up_projs [n_experts, dim, 2*inter_dim] and transposed down_projs tensors.
For non-gated activations (ReLU²): Creates gate_and_up_projs [n_experts, dim, inter_dim] and transposed down_projs tensors.
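The shapes above can be checked with a small sketch. Several details here are assumptions rather than facts from this page: that HF stores each projection as [out_features, in_features], that gate and up are concatenated (not interleaved) along the last dimension, that the non-gated case uses the up projection, and that the transposed down_projs end up as [n_experts, inter_dim, dim].

```python
import torch

n_experts, dim, inter_dim = 2, 4, 3
# Assumed HF layout: gate/up are [inter_dim, dim], down is [dim, inter_dim].
gate = [torch.randn(inter_dim, dim) for _ in range(n_experts)]
up = [torch.randn(inter_dim, dim) for _ in range(n_experts)]
down = [torch.randn(dim, inter_dim) for _ in range(n_experts)]

# Gated case: transpose each projection, then join gate and up on the last dim.
gate_and_up_projs = torch.stack(
    [torch.cat([g.t(), u.t()], dim=-1) for g, u in zip(gate, up)]
)  # [n_experts, dim, 2 * inter_dim]

# Non-gated case: a single projection per expert, no concatenation.
up_only_projs = torch.stack([u.t() for u in up])  # [n_experts, dim, inter_dim]

# Down projection, transposed per expert.
down_projs = torch.stack([d.t() for d in down])  # [n_experts, inter_dim, dim]
```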
Parameters:
Clear the in-place (strided-view) loaded-key record at the
start of this call. A single from_hf may invoke this method more than once
(e.g. backbone then MTP merge); the later call(s) pass False so the view-loaded
keys accumulate across one logical load. Resetting here (rather than in the loader)
keeps the whole view-key lifecycle inside the adapter and ensures each load starts
clean (no leak from a prior load such as an init-time partial load).
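The reset-then-accumulate lifecycle can be sketched with a small stand-in class. `ViewKeyRecorder`, its `merge` method, and the example keys are hypothetical; only the `reset_view_loaded_keys` flag name comes from the description above.

```python
class ViewKeyRecorder:
    """Hypothetical stand-in for the adapter's view-loaded key record."""

    def __init__(self) -> None:
        self.view_loaded_keys: set[str] = set()

    def merge(self, view_keys, *, reset_view_loaded_keys: bool) -> None:
        # The first merge of a logical load clears leftovers from prior loads.
        if reset_view_loaded_keys:
            self.view_loaded_keys.clear()
        self.view_loaded_keys.update(view_keys)


rec = ViewKeyRecorder()
# An earlier, unrelated load (e.g. an init-time partial load).
rec.merge(["stale.key"], reset_view_loaded_keys=True)

# One logical load made of two merge calls: backbone first, then MTP.
rec.merge(["layers.0.mlp.experts.down_projs"], reset_view_loaded_keys=True)
rec.merge(["layers.1.mlp.experts.down_projs"], reset_view_loaded_keys=False)
```

After the second load the stale key is gone, and both keys from the backbone and MTP calls are present.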
Recombine per-expert HF LoRA keys back to grouped MoE LoRA format.
This is the reverse of _convert_lora_expert_to_hf. It detects
per-expert LoRA keys (e.g.
layers.0.mlp.experts.0.gate_proj.lora_A.weight) and recombines
them into the grouped tensors expected by GroupedExpertsLoRA /
GroupedExpertsDeepEPLoRA (e.g. layers.0.mlp.experts.lora_gate_and_up_A).
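The key-detection step can be sketched with a regular expression. This covers only the key mapping, not the tensor recombination. The pattern, the function `group_lora_keys`, and the grouped name `lora_down_B` are assumptions extrapolated from the `lora_gate_and_up_A` example above.

```python
import re
from collections import defaultdict

# Matches keys like "layers.0.mlp.experts.0.gate_proj.lora_A.weight".
PER_EXPERT_LORA = re.compile(
    r"^(?P<prefix>.*\.experts)\.(?P<expert>\d+)"
    r"\.(?P<proj>gate_proj|up_proj|down_proj)\.lora_(?P<ab>[AB])\.weight$"
)


def group_lora_keys(keys):
    """Map each grouped key to its sorted (expert, projection, hf_key) members."""
    groups = defaultdict(list)
    for key in keys:
        m = PER_EXPERT_LORA.match(key)
        if m is None:
            continue  # not a per-expert LoRA key, leave it alone
        group = "down" if m["proj"] == "down_proj" else "gate_and_up"
        grouped_key = f"{m['prefix']}.lora_{group}_{m['ab']}"
        groups[grouped_key].append((int(m["expert"]), m["proj"], key))
    return {name: sorted(members) for name, members in groups.items()}


grouped = group_lora_keys(
    [
        "layers.0.mlp.experts.1.gate_proj.lora_A.weight",
        "layers.0.mlp.experts.0.gate_proj.lora_A.weight",
        "layers.0.mlp.experts.0.down_proj.lora_B.weight",
        "layers.0.self_attn.q_proj.lora_A.weight",
    ]
)
```

The attention key is ignored, and the two gate_proj keys are collected under one grouped name in expert order.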
Mark fqn as loaded via in-place views so _from_hf_w_merged_experts skips its rebuild.
The tracked key must match the native_key that the from_hf merge loop
reconstructs from the HF per-expert keys. For backbone tensors the
native_key equals fqn; for MTP tensors (prefix_override="mtp.")
the HF keys live under the mtp. namespace and from_hf processes
them with that prefix stripped, so the tracked key is also the
mtp.-less form. The user of this set (_from_hf_w_merged_experts)
receives the matching stripped key when called via the adapter’s
per-namespace dispatch.
Split grouped expert weights into individual expert weights. For grouped expert weights with shape [n_experts, …], split into n_experts tensors each with shape […]. Supports both regular tensors and DTensors.
Convert DeepEP format to HuggingFace format.
Handles gate_and_up_projs / down_projs -> individual expert
weights. Forwards **kwargs to
_convert_single_merged_expert_to_hf_split_experts for adapter
compatibility (e.g. exclude_key_regex).
Validate that all required experts are available in the HF state dict before loading. Only validates experts needed for the current rank and layers present in the state dict.
Parameters:
HuggingFace format state dict
Total number of experts
Optional device mesh for expert parallelism
Raises:
RuntimeError: If required expert weights are missing from the checkpoint
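A rough sketch of this validation, working on key names only. The function `validate_experts`, its parameters, and the assumed key layout (`layers.<layer>.<expert segment>.<expert>.`) are hypothetical; the real method takes a state dict and an optional device mesh instead of an explicit list of required experts.

```python
import re


def validate_experts(hf_keys, required_experts, expert_segment="mlp.experts"):
    """Raise RuntimeError if a layer present in hf_keys lacks a required expert."""
    pattern = re.compile(
        r"layers\.(\d+)\." + re.escape(expert_segment) + r"\.(\d+)\."
    )
    present = {}  # layer index -> set of expert ids seen in the checkpoint keys
    for key in hf_keys:
        m = pattern.search(key)
        if m:
            present.setdefault(int(m.group(1)), set()).add(int(m.group(2)))

    required = set(required_experts)
    # Only layers that appear in the keys are checked.
    for layer, experts in sorted(present.items()):
        missing = sorted(required - experts)
        if missing:
            raise RuntimeError(
                f"layer {layer}: missing expert weights for experts {missing}"
            )


complete = [
    "model.layers.0.mlp.experts.0.gate_proj.weight",
    "model.layers.0.mlp.experts.1.gate_proj.weight",
]
validate_experts(complete, required_experts=range(0, 2))  # passes silently
```

A rank that owns experts 2 and 3 would fail on the same keys, since neither expert appears for layer 0.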