nemo_automodel.components.moe.state_dict_mixin

Module Contents

Classes

MoESplitExpertsStateDictMixin

Mixin class providing MoE state dict conversion utilities.

API

class nemo_automodel.components.moe.state_dict_mixin.MoESplitExpertsStateDictMixin

Mixin class providing MoE state dict conversion utilities.

This mixin provides methods for:

  • Expert parallelism calculations (ranges, assignment)

  • Format conversion between HuggingFace and native formats

  • Both GroupedExperts and DeepEP format support

  • DTensor-aware expert loading and conversion

Can be used by any MoE model that needs expert parallelism and format conversion.
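The expert-parallel range and assignment calculations listed above can be pictured with the following minimal sketch. It assumes experts are divided into equal, contiguous blocks across EP ranks; the exact split policy is not documented on this page, and the helpers `expert_range` and `expert_owner` are hypothetical, not part of the mixin's API.

```python
def expert_range(n_experts: int, ep_size: int, ep_rank: int) -> tuple[int, int]:
    """Return the half-open range [start, end) of experts owned by ep_rank.

    Hypothetical helper: assumes n_experts divides evenly into contiguous blocks.
    """
    if n_experts % ep_size != 0:
        raise ValueError(f"n_experts={n_experts} is not divisible by ep_size={ep_size}")
    per_rank = n_experts // ep_size
    start = ep_rank * per_rank
    return start, start + per_rank


def expert_owner(expert_id: int, n_experts: int, ep_size: int) -> int:
    """Return the EP rank holding expert_id under the same contiguous layout."""
    return expert_id // (n_experts // ep_size)


# 64 experts over 4 EP ranks: each rank owns 16 consecutive experts.
print([expert_range(64, 4, r) for r in range(4)])
# [(0, 16), (16, 32), (32, 48), (48, 64)]
print(expert_owner(17, 64, 4))  # 1
```

Under such a layout, a rank only needs checkpoint entries for the experts inside its own range.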

property _is_gated_moe: bool

Check if the MoE uses gated activation (e.g., SwiGLU) or non-gated (e.g., ReLU²).
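In practice, the two cases differ in the width of the fused input projection and in the activation applied to it. The minimal sketch below assumes the gated output holds [gate | up] halves along the last dimension and uses SiLU for the gated case (Quick-GEGLU would use a different activation). `expert_activation` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def expert_activation(projected: torch.Tensor, gated: bool) -> torch.Tensor:
    """Apply the expert activation to the output of the fused input projection.

    gated=True:  SwiGLU-style; the last dim is assumed to hold [gate | up] halves.
    gated=False: ReLU²; the last dim is the single up projection.
    """
    if gated:
        gate, up = projected.chunk(2, dim=-1)
        return F.silu(gate) * up
    return F.relu(projected).square()


dim, inter_dim = 8, 4
x = torch.randn(3, dim)                    # 3 tokens
w_gated = torch.randn(dim, 2 * inter_dim)  # gated: projection width 2*inter_dim
w_plain = torch.randn(dim, inter_dim)      # non-gated: projection width inter_dim
print(expert_activation(x @ w_gated, gated=True).shape)   # torch.Size([3, 4])
print(expert_activation(x @ w_plain, gated=False).shape)  # torch.Size([3, 4])
```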

property _hf_prefix: str

Prefix for HuggingFace format keys. Override in subclass.

property _expert_path_segment: str

Path segment for experts (e.g., ‘mlp.experts’ or ‘mixer.experts’). Override in subclass.
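A model adopts the mixin by overriding these two properties. The sketch below uses a hypothetical stand-in base class (not the real mixin) to show how the properties could combine into per-expert HuggingFace keys. The prefixes `model.` and `backbone.` and the `hf_expert_key` helper are illustrative assumptions.

```python
class HypotheticalExpertKeys:
    """Stand-in base class (NOT the real mixin) showing how the two properties combine."""

    @property
    def _hf_prefix(self) -> str:
        raise NotImplementedError("override in subclass")

    @property
    def _expert_path_segment(self) -> str:
        raise NotImplementedError("override in subclass")

    def hf_expert_key(self, layer: int, expert: int, proj: str) -> str:
        # Hypothetical helper: <prefix>layers.<layer>.<segment>.<expert>.<proj>.weight
        return f"{self._hf_prefix}layers.{layer}.{self._expert_path_segment}.{expert}.{proj}.weight"


class MyMoEModel(HypotheticalExpertKeys):
    @property
    def _hf_prefix(self) -> str:
        return "model."  # assumed prefix

    @property
    def _expert_path_segment(self) -> str:
        return "mlp.experts"


class MyHybridModel(HypotheticalExpertKeys):
    @property
    def _hf_prefix(self) -> str:
        return "backbone."  # assumed prefix

    @property
    def _expert_path_segment(self) -> str:
        return "mixer.experts"


print(MyMoEModel().hf_expert_key(0, 3, "gate_proj"))
# model.layers.0.mlp.experts.3.gate_proj.weight
```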

_validate_expert_availability(
hf_state_dict: dict[str, Any],
n_experts: int,
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
) → None

Validate that all required experts are available in the HF state dict before loading. Only the experts owned by the current rank are checked, and only for layers present in the state dict.

Parameters:
  • hf_state_dict – HuggingFace format state dict

  • n_experts – Total number of experts

  • device_mesh – Optional device mesh for expert parallelism

Raises:

RuntimeError – If required expert weights are missing from the checkpoint
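The check described above can be sketched as follows. This is an illustration under assumptions: keys follow a `<prefix>layers.<i>.<expert segment>.<e>.<proj>.weight` pattern, the caller passes the rank's local expert range directly instead of a device mesh, and layers are detected only from keys that already contain expert weights. `validate_expert_availability` is a hypothetical stand-in, not the mixin's method.

```python
import re
from typing import Any


def validate_expert_availability(
    hf_state_dict: dict[str, Any],
    local_experts: range,
    expert_segment: str = "mlp.experts",
    projections: tuple[str, ...] = ("gate_proj", "up_proj", "down_proj"),
) -> None:
    """Raise RuntimeError if a weight needed by this rank is missing.

    Only layers that already have expert keys in hf_state_dict are checked.
    """
    pattern = re.compile(rf"^(.*layers\.(\d+)\.){re.escape(expert_segment)}\.\d+\.")
    layer_prefix: dict[int, str] = {}
    for key in hf_state_dict:
        m = pattern.match(key)
        if m:
            layer_prefix[int(m.group(2))] = m.group(1)  # e.g. "model.layers.0."

    missing = []
    for layer in sorted(layer_prefix):
        for e in local_experts:
            for proj in projections:
                key = f"{layer_prefix[layer]}{expert_segment}.{e}.{proj}.weight"
                if key not in hf_state_dict:
                    missing.append(key)
    if missing:
        raise RuntimeError(
            f"{len(missing)} expert weight(s) missing from checkpoint, e.g. {missing[:3]}"
        )
```

Restricting the check to the rank's own range means a rank is not blocked by experts it never loads.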

_split_experts_weights(
weight: torch.Tensor,
n_experts: int,
) → list[torch.Tensor]

Split grouped expert weights into individual expert weights. For grouped expert weights with shape [n_experts, …], split into n_experts tensors each with shape […]. Supports both regular tensors and DTensors.
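For a regular tensor, splitting along the leading expert dimension is a single `torch.unbind` call. This minimal sketch covers plain tensors only. The DTensor handling mentioned above is not reproduced, and `split_experts_weights` is an illustrative stand-in.

```python
import torch


def split_experts_weights(weight: torch.Tensor, n_experts: int) -> list[torch.Tensor]:
    """Split a grouped [n_experts, ...] tensor into n_experts tensors of shape [...].

    Plain-tensor sketch only; DTensor inputs are not handled here.
    """
    if weight.shape[0] != n_experts:
        raise ValueError(f"expected leading dim {n_experts}, got {weight.shape[0]}")
    # unbind returns views that share storage with `weight`
    return list(torch.unbind(weight, dim=0))


grouped = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
parts = split_experts_weights(grouped, n_experts=2)
print(len(parts), parts[0].shape)  # 2 torch.Size([3, 4])
```

The pieces are views into the grouped tensor. Call `.clone()` on each one if independent storage is required.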

_concatenate_expert_weights(
expert_weights_by_layer: dict[str, Any],
n_experts: int,
) → Optional[torch.Tensor]

Concatenate the weights of separate experts into GroupedExperts weights.

Parameters:
  • expert_weights_by_layer – Nested dict structure containing expert weights

  • n_experts – Total number of experts expected

Returns:

Stacked tensor if all experts are available for a layer, None otherwise
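The all-or-nothing behavior can be sketched with `torch.stack`. For brevity, this version takes a flat `{expert_index: tensor}` mapping for one layer and projection rather than the nested structure the method receives. That simplification and the function name are assumptions.

```python
from typing import Optional

import torch


def concatenate_expert_weights(
    expert_weights: dict[int, torch.Tensor], n_experts: int
) -> Optional[torch.Tensor]:
    """Stack per-expert tensors into [n_experts, ...], or return None if any are missing."""
    if any(e not in expert_weights for e in range(n_experts)):
        return None
    # Stack in expert-index order so row e of the result is expert e.
    return torch.stack([expert_weights[e] for e in range(n_experts)], dim=0)
```

Returning None instead of raising lets a caller try again once the remaining experts have been collected. This is one plausible reason for the documented contract.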

_convert_lora_expert_to_hf(
fqn: str,
tensor: torch.Tensor,
n_experts: int,
inter_dim: int,
expert_segment: str,
) → list[tuple[str, torch.Tensor]]

Convert a grouped MoE expert LoRA tensor to per-expert HF PEFT format.

Handles the four LoRA parameter types produced by GroupedExpertsLoRA / GroupedExpertsDeepEPLoRA and converts them to per-expert lora_A.weight / lora_B.weight keys that HF PEFT understands.

The prefix (e.g. base_model.model.model.) is preserved from the incoming fqn so that both PEFT and FFT save paths work correctly.
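The sketch below illustrates the kind of reshaping involved. Only `lora_gate_and_up_A` is a name taken from this page. The other three grouped parameter names and all shapes (A as `[n_experts, in, r]`, B as `[n_experts, r, out]`, with one A shared by gate and up) are assumptions. The per-expert outputs follow the HF PEFT convention that `lora_A.weight` is `[r, in_features]` and `lora_B.weight` is `[out_features, r]`.

```python
import torch

# Hypothetical grouped LoRA layout (assumption, see lead-in):
#   lora_gate_and_up_A [n, dim, r]        lora_gate_and_up_B [n, r, 2*inter_dim]
#   lora_down_A        [n, inter_dim, r]  lora_down_B        [n, r, dim]
GROUPED_LORA_PARAMS = ("lora_gate_and_up_A", "lora_gate_and_up_B", "lora_down_A", "lora_down_B")


def convert_lora_expert_to_hf(
    fqn: str, tensor: torch.Tensor, n_experts: int, inter_dim: int, expert_segment: str
) -> list[tuple[str, torch.Tensor]]:
    # Keep everything before the expert segment (e.g. "base_model.model.model.layers.0.").
    prefix, sep, param = fqn.rpartition(f"{expert_segment}.")
    if not sep or param not in GROUPED_LORA_PARAMS:
        raise ValueError(f"not a grouped expert LoRA parameter: {fqn}")
    base = f"{prefix}{expert_segment}"
    out: list[tuple[str, torch.Tensor]] = []
    for e in range(n_experts):
        t = tensor[e]
        if param == "lora_gate_and_up_A":
            a = t.T.contiguous()  # [r, dim], shared by gate_proj and up_proj
            out.append((f"{base}.{e}.gate_proj.lora_A.weight", a))
            out.append((f"{base}.{e}.up_proj.lora_A.weight", a.clone()))
        elif param == "lora_gate_and_up_B":
            out.append((f"{base}.{e}.gate_proj.lora_B.weight", t[:, :inter_dim].T.contiguous()))
            out.append((f"{base}.{e}.up_proj.lora_B.weight", t[:, inter_dim:].T.contiguous()))
        elif param == "lora_down_A":
            out.append((f"{base}.{e}.down_proj.lora_A.weight", t.T.contiguous()))  # [r, inter_dim]
        else:  # lora_down_B
            out.append((f"{base}.{e}.down_proj.lora_B.weight", t.T.contiguous()))  # [dim, r]
    return out
```

With these shapes, `x @ A[e] @ B[e][:, :inter_dim]` equals PEFT's `x @ lora_A.weight.T @ lora_B.weight.T` for `gate_proj`, which is why each slice is transposed.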

_recombine_lora_expert_keys(
state_dict: dict[str, Any],
) → dict[str, Any]

Recombine per-expert HF LoRA keys back to grouped MoE LoRA format.

This is the reverse of _convert_lora_expert_to_hf. It detects per-expert LoRA keys (e.g. layers.0.mlp.experts.0.gate_proj.lora_A.weight) and recombines them into the grouped tensors expected by GroupedExpertsLoRA / GroupedExpertsDeepEPLoRA (e.g. layers.0.mlp.experts.lora_gate_and_up_A).
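Detecting per-expert keys is essentially a pattern match followed by grouping. The sketch below shows only that first step, assuming the key pattern quoted above. Stacking each group back into grouped tensors would then invert whatever layout the forward conversion used. The function and regex names are hypothetical.

```python
import re
from collections import defaultdict
from typing import Any

# Hypothetical pattern for keys like "layers.0.mlp.experts.0.gate_proj.lora_A.weight"
PER_EXPERT_LORA = re.compile(
    r"^(?P<base>.*\.experts)\.(?P<expert>\d+)\."
    r"(?P<proj>gate_proj|up_proj|down_proj)\.lora_(?P<ab>[AB])\.weight$"
)


def group_per_expert_lora_keys(
    state_dict: dict[str, Any],
) -> dict[tuple[str, str, str], dict[int, Any]]:
    """Group per-expert LoRA entries by (experts path, projection, "A" or "B").

    Keys that do not match the per-expert pattern are ignored.
    """
    groups: dict[tuple[str, str, str], dict[int, Any]] = defaultdict(dict)
    for key, value in state_dict.items():
        m = PER_EXPERT_LORA.match(key)
        if m:
            groups[(m["base"], m["proj"], m["ab"])][int(m["expert"])] = value
    return dict(groups)
```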

_to_hf_w_split_experts(
state_dict: dict[str, Any],
) → dict[str, Any]

Convert DeepEP format to HuggingFace format by splitting the grouped gate_and_up_projs and down_projs tensors into individual per-expert weights.
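For the gated case, the split amounts to slicing each expert's fused projection and transposing it into the `[out_features, in_features]` layout of `nn.Linear`. This minimal sketch makes three assumptions: the shapes are those given for `_from_hf_w_merged_experts`, gate and up are concatenated in gate-then-up order (not interleaved), and HF key names end in `gate_proj.weight`, `up_proj.weight`, and `down_proj.weight`. `split_grouped_to_hf` is hypothetical.

```python
import torch


def split_grouped_to_hf(
    gate_and_up_projs: torch.Tensor,  # [n_experts, dim, 2*inter_dim] (gated layout)
    down_projs: torch.Tensor,         # [n_experts, inter_dim, dim]
    base: str,                        # e.g. "model.layers.0.mlp.experts"
) -> dict[str, torch.Tensor]:
    n_experts, _, width = gate_and_up_projs.shape
    inter_dim = width // 2
    hf: dict[str, torch.Tensor] = {}
    for e in range(n_experts):
        w = gate_and_up_projs[e]  # [dim, 2*inter_dim]
        # nn.Linear stores weights as [out_features, in_features], hence the .T
        hf[f"{base}.{e}.gate_proj.weight"] = w[:, :inter_dim].T.contiguous()  # [inter_dim, dim]
        hf[f"{base}.{e}.up_proj.weight"] = w[:, inter_dim:].T.contiguous()    # [inter_dim, dim]
        hf[f"{base}.{e}.down_proj.weight"] = down_projs[e].T.contiguous()     # [dim, inter_dim]
    return hf
```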

_from_hf_w_merged_experts(
hf_state_dict: dict[str, Any],
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
) → dict[str, Any]

Convert an HF checkpoint to native format.

For gated activations (SwiGLU, Quick-GEGLU): Creates combined gate_and_up_projs [n_experts, dim, 2*inter_dim] and transposed down_projs tensors.

For non-gated activations (ReLU²): Creates gate_and_up_projs [n_experts, dim, inter_dim] and transposed down_projs tensors.
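The shape bookkeeping above can be sketched in a few lines. The gate-then-up order, the HF key names, and the assumption that non-gated experts only have `up_proj` and `down_proj` weights are illustrative rather than confirmed. `merge_hf_experts` is hypothetical.

```python
import torch


def merge_hf_experts(
    hf: dict[str, torch.Tensor], base: str, n_experts: int, gated: bool
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build (gate_and_up_projs, down_projs) from per-expert HF weights."""
    gate_and_up, down = [], []
    for e in range(n_experts):
        up = hf[f"{base}.{e}.up_proj.weight"].T  # [dim, inter_dim]
        if gated:
            gate = hf[f"{base}.{e}.gate_proj.weight"].T  # [dim, inter_dim]
            gate_and_up.append(torch.cat([gate, up], dim=-1))  # [dim, 2*inter_dim]
        else:
            gate_and_up.append(up)  # [dim, inter_dim]
        down.append(hf[f"{base}.{e}.down_proj.weight"].T)  # [inter_dim, dim]
    return torch.stack(gate_and_up), torch.stack(down)
```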

_convert_single_merged_expert_to_hf_split_experts(
fqn: str,
tensor: torch.Tensor,
**kwargs,
) → list[tuple[str, torch.Tensor]]

Convert a single merged expert tensor from native format to split HuggingFace format.

Parameters:
  • fqn – Fully qualified name of the tensor in native format

  • tensor – The tensor to convert

Returns:

List of (fqn, tensor) tuples in HuggingFace format, or None if not an expert tensor
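Because this method works on one tensor at a time, it fits a per-tensor conversion loop. Each fqn is offered to it, and names that are not grouped expert tensors come back as None and pass through unchanged. The sketch assumes the grouped names `gate_and_up_projs` and `down_projs` sit directly under a module path ending in `experts`, and that the gated layout described for `_from_hf_w_merged_experts` applies. The function and regex are hypothetical.

```python
import re
from typing import Optional

import torch

# Hypothetical pattern: "<path>.experts.gate_and_up_projs" or "<path>.experts.down_projs"
GROUPED_EXPERT_FQN = re.compile(r"^(?P<base>.*\.experts)\.(?P<name>gate_and_up_projs|down_projs)$")


def convert_single_merged_expert(
    fqn: str, tensor: torch.Tensor
) -> Optional[list[tuple[str, torch.Tensor]]]:
    m = GROUPED_EXPERT_FQN.match(fqn)
    if m is None:
        return None  # not a grouped expert tensor; the caller keeps (fqn, tensor) as-is
    base, name = m["base"], m["name"]
    pairs: list[tuple[str, torch.Tensor]] = []
    for e in range(tensor.shape[0]):
        if name == "down_projs":
            # [inter_dim, dim] -> HF [dim, inter_dim]
            pairs.append((f"{base}.{e}.down_proj.weight", tensor[e].T.contiguous()))
        else:
            # Assumed gated layout [dim, 2*inter_dim] with gate first, then up
            gate, up = tensor[e].chunk(2, dim=-1)
            pairs.append((f"{base}.{e}.gate_proj.weight", gate.T.contiguous()))
            pairs.append((f"{base}.{e}.up_proj.weight", up.T.contiguous()))
    return pairs


state = {
    "model.layers.0.self_attn.q_proj.weight": torch.zeros(4, 4),
    "model.layers.0.mlp.experts.down_projs": torch.randn(2, 3, 4),
}
hf_items = []
for fqn, t in state.items():
    converted = convert_single_merged_expert(fqn, t)
    hf_items.extend(converted if converted is not None else [(fqn, t)])
print([name for name, _ in hf_items])
```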