nemo_automodel.components.moe.state_dict_mixin

Module Contents

Classes

Name | Description
MoESplitExpertsStateDictMixin | Mixin class providing MoE state dict conversion utilities.

API

class nemo_automodel.components.moe.state_dict_mixin.MoESplitExpertsStateDictMixin()

Mixin class providing MoE state dict conversion utilities.

This mixin provides methods for:

  • Expert parallelism calculations (ranges, assignment)
  • Format conversion between HuggingFace and native formats
  • Both GroupedExperts and DeepEP format support
  • DTensor-aware expert loading and conversion

Can be used by any MoE model that needs expert parallelism and format conversion.
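The expert-parallelism range calculation can be pictured with a small sketch. It assumes experts are divided into equal contiguous blocks across expert-parallel ranks; `expert_range` and `owner_rank` are hypothetical names used for illustration and are not methods of the mixin.

```python
def expert_range(n_experts: int, ep_size: int, ep_rank: int) -> tuple[int, int]:
    """Return the half-open [start, end) range of experts owned by one EP rank.

    Assumption: equal contiguous blocks, so n_experts must divide by ep_size.
    """
    if n_experts % ep_size != 0:
        raise ValueError("n_experts must be divisible by ep_size")
    if not 0 <= ep_rank < ep_size:
        raise ValueError("ep_rank must be in [0, ep_size)")
    per_rank = n_experts // ep_size
    start = ep_rank * per_rank
    return start, start + per_rank


def owner_rank(expert_id: int, n_experts: int, ep_size: int) -> int:
    """Return the EP rank that owns expert_id under the same block layout."""
    return expert_id // (n_experts // ep_size)
```

With 8 experts and 4 ranks, rank 1 would own experts 2 and 3 under this layout.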

_expert_path_segment
str

Path segment for experts (e.g., ‘mlp.experts’ or ‘mixer.experts’). Override in subclass.

_hf_prefix
str

Prefix for HuggingFace format keys. Override in subclass.
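As a rough illustration of how these two attributes could shape the emitted HF keys, the sketch below uses a stand-in class rather than the real mixin. The `"model."` prefix, the `.weight` suffix, and the `hf_expert_key` helper are assumptions made for the example.

```python
class ExampleKeyBuilder:
    """Stand-in for a class using the mixin; not the real implementation."""

    _expert_path_segment = "mlp.experts"  # default path segment (assumed)
    _hf_prefix = "model."                 # assumed HF key prefix

    def hf_expert_key(self, layer: int, expert: int, proj: str) -> str:
        # Compose a per-expert HF-style key from the two class attributes.
        return (
            f"{self._hf_prefix}layers.{layer}."
            f"{self._expert_path_segment}.{expert}.{proj}.weight"
        )


class MixerKeyBuilder(ExampleKeyBuilder):
    # A subclass overrides only the path segment.
    _expert_path_segment = "mixer.experts"
```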

_is_gated_moe
bool

Whether the MoE uses a gated activation (e.g., SwiGLU) rather than a non-gated one (e.g., ReLU²).

view_loaded_native_keys
set[str]

Native keys loaded in-place via strided views during the most recent from_hf.

MoE experts with a plain local split are loaded by DCP writing the checkpoint tensors straight through non-contiguous strided views into the model’s grouped expert storage. Such keys are intentionally absent from the dict from_hf returns (the data is already in the model) but are NOT missing. _from_hf_w_merged_experts records them here so the checkpoint loader can exclude them from false “missing” key-diff warnings. The record is reset at the start of each load by _from_hf_w_merged_experts(reset_view_loaded_keys=True).
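A minimal sketch of how a loader might use this record when computing its key diff. The function name `unexpected_missing` and the example keys are hypothetical; only the set arithmetic is the point.

```python
def unexpected_missing(
    expected: set[str], returned: set[str], view_loaded: set[str]
) -> set[str]:
    """Keys the model expects that were neither returned nor loaded via views."""
    return expected - returned - view_loaded


expected = {
    "layers.0.mlp.experts.gate_and_up_projs",
    "layers.0.mlp.experts.down_projs",
    "layers.0.mlp.gate.weight",
}
returned = {"layers.0.mlp.gate.weight"}  # what the conversion handed back
view_loaded = {                          # written in place through views
    "layers.0.mlp.experts.gate_and_up_projs",
    "layers.0.mlp.experts.down_projs",
}
```

Without subtracting `view_loaded`, both expert keys would be reported as missing even though their data is already in the model.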

nemo_automodel.components.moe.state_dict_mixin.MoESplitExpertsStateDictMixin._concatenate_expert_weights(
expert_weights_by_layer: dict[str, typing.Any],
n_experts: int
) -> typing.Optional[torch.Tensor]

Concatenate the weights of separate experts into GroupedExpert weights.

Parameters:

expert_weights_by_layer
dict[str, Any]

Nested dict structure containing expert weights

n_experts
int

Total number of experts expected

Returns: Optional[torch.Tensor]

Stacked tensor if all experts are available for a layer, None otherwise
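The "stack only when complete" behavior can be sketched as follows. The real method takes a nested dict whose exact layout is not documented here, so this example assumes a flat `{expert_id: tensor}` mapping for a single layer and projection; `stack_if_complete` is a hypothetical name.

```python
from typing import Optional

import torch


def stack_if_complete(
    weights_by_expert: dict[int, torch.Tensor], n_experts: int
) -> Optional[torch.Tensor]:
    """Stack per-expert tensors along a new leading dim, or return None."""
    if any(e not in weights_by_expert for e in range(n_experts)):
        return None  # at least one expert has not been seen yet
    # Stack in expert-id order so index e of the result is expert e.
    return torch.stack([weights_by_expert[e] for e in range(n_experts)], dim=0)
```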

nemo_automodel.components.moe.state_dict_mixin.MoESplitExpertsStateDictMixin._convert_lora_expert_to_hf(
fqn: str,
tensor: torch.Tensor,
n_experts: int,
inter_dim: int,
expert_segment: str
) -> list[tuple[str, torch.Tensor]]

Convert a grouped MoE expert LoRA tensor to per-expert HF PEFT format.

Handles the four LoRA parameter types produced by GroupedExpertsLoRA / GroupedExpertsDeepEPLoRA and converts them to per-expert lora_A.weight / lora_B.weight keys that HF PEFT understands.

The prefix (e.g. base_model.model.model.) is preserved from the incoming fqn so that both PEFT and FFT save paths work correctly.

nemo_automodel.components.moe.state_dict_mixin.MoESplitExpertsStateDictMixin._convert_single_merged_expert_to_hf_split_experts(
fqn: str,
tensor: torch.Tensor,
prefix_override: str | None = None,
**kwargs
) -> list[tuple[str, torch.Tensor]]

Convert a single merged expert tensor from native format to split HuggingFace format.

When tensor is a model DTensor with a plain (non-DTensor) local split (that is, ep_shard == 1), the per-expert outputs are returned as non-contiguous strided views into the local storage of the model’s grouped DTensor instead of newly allocated contiguous copies. DCP’s target.copy_(source) then writes safetensors data directly through the views into the model’s storage, and _from_hf_w_merged_experts skips the rebuild for the corresponding native key (tracked in _inplace_loaded_native_keys). When loading large MoE checkpoints, this avoids tens of GB of per-expert scratch memory on top of the already-materialized model.

Save callers must materialize the views before serializing, because safetensors.torch.save rejects non-contiguous tensors. See _materialize_to_hf_views_for_save in checkpointing.py.
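The view mechanism can be illustrated with plain tensors. This is a sketch only: it uses a regular `torch.Tensor` as a stand-in for the DTensor's local storage, and it assumes each HF per-expert weight is the transpose of one grouped slice.

```python
import torch

n_experts, dim, inter_dim = 4, 3, 2
# Stand-in for the model's grouped expert storage.
grouped = torch.zeros(n_experts, dim, inter_dim)

# One transposed view per expert: shape [inter_dim, dim], no new allocation.
views = [grouped[e].transpose(0, 1) for e in range(n_experts)]

# A checkpoint tensor in the per-expert layout.
source = torch.arange(inter_dim * dim, dtype=grouped.dtype).reshape(inter_dim, dim)

# Copying into the view writes straight into the grouped storage.
views[1].copy_(source)

# Before saving, a non-contiguous view has to be materialized as a copy.
to_save = views[1].contiguous()
```

Here `views[1]` shares memory with `grouped`, while `to_save` is an independent contiguous copy suitable for serialization.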

Parameters:

fqn
str

Fully qualified name of the tensor in native format.

tensor
torch.Tensor

The tensor to convert.

prefix_override
str | None
Defaults to None

When provided, replaces self._hf_prefix in emitted HF keys. Used to route conversions through namespaces outside the main backbone, e.g. "mtp." for the MTP head.

**kwargs
Defaults to {}

Absorbed for forward-compatibility with base callers that forward arbitrary state-dict kwargs (e.g. exclude_key_regex).

Returns: list[tuple[str, torch.Tensor]]

List of (fqn, tensor) tuples in HuggingFace format, or None if not an expert tensor.

nemo_automodel.components.moe.state_dict_mixin.MoESplitExpertsStateDictMixin._from_hf_w_merged_experts(
hf_state_dict: dict[str, typing.Any],
device_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None,
reset_view_loaded_keys: bool = True
) -> dict[str, typing.Any]

Convert HF checkpoint to native format.

For gated activations (SwiGLU, Quick-GEGLU): Creates combined gate_and_up_projs [n_experts, dim, 2*inter_dim] and transposed down_projs tensors.

For non-gated activations (ReLU²): Creates gate_and_up_projs [n_experts, dim, inter_dim] and transposed down_projs tensors.
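The shapes above can be checked with a small sketch. It assumes HF per-expert weights follow the usual linear-layer layout (gate_proj and up_proj as [inter_dim, dim], down_proj as [dim, inter_dim]) and that gate comes before up in the combined tensor; the actual ordering inside gate_and_up_projs is not stated here, and the function names are hypothetical.

```python
import torch


def merge_gated(gate_w, up_w, down_w):
    """Build [n_experts, dim, 2*inter_dim] and [n_experts, inter_dim, dim]."""
    gate_and_up = torch.stack(
        # Transpose each to [dim, inter_dim], then join along the last dim.
        [torch.cat([g.t(), u.t()], dim=-1) for g, u in zip(gate_w, up_w)]
    )
    down = torch.stack([d.t() for d in down_w])
    return gate_and_up, down


def merge_non_gated(up_w, down_w):
    """Build [n_experts, dim, inter_dim] and [n_experts, inter_dim, dim]."""
    up = torch.stack([u.t() for u in up_w])
    down = torch.stack([d.t() for d in down_w])
    return up, down


n_experts, dim, inter_dim = 4, 3, 2
gate_w = [torch.randn(inter_dim, dim) for _ in range(n_experts)]
up_w = [torch.randn(inter_dim, dim) for _ in range(n_experts)]
down_w = [torch.randn(dim, inter_dim) for _ in range(n_experts)]
```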

Parameters:

reset_view_loaded_keys
bool
Defaults to True

Clear the in-place (strided-view) loaded-key record at the start of this call. A single from_hf may invoke this method more than once (e.g. backbone then MTP merge); the later call(s) pass False so the view-loaded keys accumulate across one logical load. Resetting here (rather than in the loader) keeps the whole view-key lifecycle inside the adapter and ensures each load starts clean (no leak from a prior load such as an init-time partial load).
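The reset-then-accumulate lifecycle can be sketched with a stand-in class. `ViewKeyRecord` and its `merge` method are hypothetical; they only mimic how the record would behave across a backbone call followed by an MTP call within one logical load.

```python
class ViewKeyRecord:
    """Stand-in that mimics the view-loaded key lifecycle."""

    def __init__(self) -> None:
        self.view_loaded_native_keys: set[str] = set()

    def merge(
        self, newly_view_loaded: set[str], reset_view_loaded_keys: bool = True
    ) -> None:
        if reset_view_loaded_keys:
            # First call of a load: drop anything left from a prior load.
            self.view_loaded_native_keys = set()
        # Later calls in the same load pass False and accumulate.
        self.view_loaded_native_keys |= newly_view_loaded


record = ViewKeyRecord()
record.merge({"backbone_key"})                           # backbone pass
record.merge({"mtp_key"}, reset_view_loaded_keys=False)  # MTP pass, same load
keys_after_first_load = set(record.view_loaded_native_keys)
record.merge({"other_key"})                              # a new load starts clean
```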

nemo_automodel.components.moe.state_dict_mixin.MoESplitExpertsStateDictMixin._recombine_lora_expert_keys(
state_dict: dict[str, typing.Any]
) -> dict[str, typing.Any]

Recombine per-expert HF LoRA keys back to grouped MoE LoRA format.

This is the reverse of _convert_lora_expert_to_hf. It detects per-expert LoRA keys (e.g. layers.0.mlp.experts.0.gate_proj.lora_A.weight) and recombines them into the grouped tensors expected by GroupedExpertsLoRA / GroupedExpertsDeepEPLoRA (e.g. layers.0.mlp.experts.lora_gate_and_up_A).
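The detection step can be sketched with a regular expression that groups per-expert keys by layer, projection, and LoRA matrix. This covers only key grouping, not the tensor recombination, and it hard-codes the `mlp.experts` segment for brevity; `group_lora_keys` is a hypothetical helper.

```python
import re
from collections import defaultdict

# Matches keys such as layers.0.mlp.experts.0.gate_proj.lora_A.weight
_LORA_EXPERT_KEY = re.compile(
    r"^(?P<prefix>.*\.mlp\.experts)\.(?P<expert>\d+)"
    r"\.(?P<proj>\w+)\.(?P<lora>lora_[AB])\.weight$"
)


def group_lora_keys(keys):
    """Map (experts prefix, projection, lora_A/lora_B) -> {expert_id: key}."""
    groups = defaultdict(dict)
    for key in keys:
        match = _LORA_EXPERT_KEY.match(key)
        if match is None:
            continue  # not a per-expert LoRA key; leave it alone
        group = (match["prefix"], match["proj"], match["lora"])
        groups[group][int(match["expert"])] = key
    return dict(groups)
```

Each resulting group holds the per-expert keys that would feed one grouped tensor.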

nemo_automodel.components.moe.state_dict_mixin.MoESplitExpertsStateDictMixin._register_inplace_loaded_key(
fqn: str,
prefix_override: str | None
) -> None

Mark fqn as loaded via in-place views so _from_hf_w_merged_experts skips its rebuild.

The tracked key must match the native_key that the from_hf merge loop reconstructs from the HF per-expert keys. For backbone tensors, the native_key equals fqn. For MTP tensors (prefix_override="mtp."), the HF keys live under the mtp. namespace and from_hf processes them with that prefix stripped, so the tracked key is also the mtp.-less form. The consumer of this set (_from_hf_w_merged_experts) receives the matching stripped key when called via the adapter’s per-namespace dispatch.

nemo_automodel.components.moe.state_dict_mixin.MoESplitExpertsStateDictMixin._split_experts_weights(
weight: torch.Tensor,
n_experts: int
) -> list[torch.Tensor]

Split grouped expert weights into individual expert weights.

For grouped expert weights with shape [n_experts, …], split into n_experts tensors, each with shape […]. Supports both regular tensors and DTensors.
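For regular tensors, the split amounts to unbinding the leading dimension. The sketch below covers only that case and does not attempt the DTensor path; `split_experts` is a hypothetical name.

```python
import torch


def split_experts(weight: torch.Tensor, n_experts: int) -> list[torch.Tensor]:
    """Split a [n_experts, ...] tensor into n_experts tensors of shape [...]."""
    if weight.shape[0] != n_experts:
        raise ValueError(
            f"expected leading dim {n_experts}, got {weight.shape[0]}"
        )
    # unbind returns views along dim 0, so no data is copied here.
    return list(torch.unbind(weight, dim=0))
```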

nemo_automodel.components.moe.state_dict_mixin.MoESplitExpertsStateDictMixin._to_hf_w_split_experts(
state_dict: dict[str, typing.Any],
**kwargs: typing.Any
) -> dict[str, typing.Any]

Convert DeepEP format to HuggingFace format.

Handles gate_and_up_projs / down_projs -> individual expert weights. Forwards **kwargs to _convert_single_merged_expert_to_hf_split_experts for adapter compatibility (e.g. exclude_key_regex).

nemo_automodel.components.moe.state_dict_mixin.MoESplitExpertsStateDictMixin._validate_expert_availability(
hf_state_dict: dict[str, typing.Any],
n_experts: int,
device_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None
) -> None

Validate that all required experts are available in the HF state dict before loading. Only validates experts needed for the current rank and layers present in the state dict.

Parameters:

hf_state_dict
dict[str, Any]

HuggingFace format state dict

n_experts
int

Total number of experts

device_mesh
Optional[DeviceMesh]
Defaults to None

Optional device mesh for expert parallelism

Raises:

  • RuntimeError: If required expert weights are missing from the checkpoint
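A simplified picture of this validation, under stated assumptions: plain `ep_size` and `ep_rank` integers stand in for the device mesh, experts are assigned in equal contiguous blocks, and the projection names and `.weight` suffix follow common HF conventions. `validate_experts` is a hypothetical helper, not the method itself.

```python
import re

_EXPERT_KEY = re.compile(r"^(.*\.mlp\.experts)\.(\d+)\.(\w+)\.weight$")


def validate_experts(
    hf_keys,
    n_experts: int,
    ep_size: int = 1,
    ep_rank: int = 0,
    projections=("gate_proj", "up_proj", "down_proj"),
) -> None:
    """Raise RuntimeError if this rank's experts are missing for any layer."""
    per_rank = n_experts // ep_size
    required = range(ep_rank * per_rank, (ep_rank + 1) * per_rank)

    layers, present = set(), set()
    for key in hf_keys:
        match = _EXPERT_KEY.match(key)
        if match:
            layers.add(match.group(1))
            present.add((match.group(1), int(match.group(2)), match.group(3)))

    # Only layers that appear in the state dict are checked.
    missing = [
        f"{layer}.{expert}.{proj}.weight"
        for layer in sorted(layers)
        for expert in required
        for proj in projections
        if (layer, expert, proj) not in present
    ]
    if missing:
        raise RuntimeError(f"Missing expert weights: {missing}")
```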