nemo_automodel.components.moe.state_dict_mixin
Module Contents
Classes
MoESplitExpertsStateDictMixin | Mixin class providing MoE state dict conversion utilities.
API
- class nemo_automodel.components.moe.state_dict_mixin.MoESplitExpertsStateDictMixin
Mixin class providing MoE state dict conversion utilities.
This mixin provides methods for:
  - Expert parallelism calculations (ranges, assignment)
  - Format conversion between HuggingFace and native formats
  - Both GroupedExperts and DeepEP format support
  - DTensor-aware expert loading and conversion
Can be used by any MoE model that needs expert parallelism and format conversion.
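The expert-parallelism range calculation listed above often amounts to giving each expert-parallel (EP) rank a contiguous block of expert ids. A minimal sketch, assuming the expert count divides evenly by the EP world size; `expert_range` is a hypothetical helper, not part of this mixin, and the mixin's own assignment rule may differ.

```python
def expert_range(n_experts: int, ep_size: int, ep_rank: int) -> tuple[int, int]:
    """Hypothetical helper: half-open [start, end) expert ids owned by ep_rank.

    Assumes experts are split into equal, contiguous blocks across EP ranks.
    """
    if n_experts % ep_size != 0:
        raise ValueError(f"n_experts={n_experts} is not divisible by ep_size={ep_size}")
    if not 0 <= ep_rank < ep_size:
        raise ValueError(f"ep_rank={ep_rank} is out of range for ep_size={ep_size}")
    per_rank = n_experts // ep_size
    start = ep_rank * per_rank
    return start, start + per_rank


# 8 experts over 4 EP ranks: each rank owns 2 consecutive experts.
for rank in range(4):
    print(rank, expert_range(8, 4, rank))
```

Contiguous blocks keep each rank's experts adjacent in a grouped `[n_experts, ...]` tensor, so a rank's shard is a single slice along dimension 0.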
- _validate_expert_availability(hf_state_dict: dict[str, Any], n_experts: int, device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None)
Validate that all required experts are present in the HF state dict before loading. Only the experts assigned to the current rank are checked, and only for the layers that appear in the state dict.
- Parameters:
hf_state_dict – HuggingFace format state dict
n_experts – Total number of experts
device_mesh – Optional device mesh for expert parallelism
- Raises:
RuntimeError – If required expert weights are missing from the checkpoint
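A minimal sketch of the check described above, assuming Mixtral-style HF keys (`model.layers.{i}.mlp.experts.{j}.{gate,up,down}_proj.weight`). The hypothetical `validate_expert_availability` takes the rank's expert ids directly (`required_experts`) instead of deriving them from `n_experts` and `device_mesh`.

```python
import re
from typing import Any, Iterable

# Assumed HF key layout (Mixtral-style); real checkpoints may use other names.
EXPERT_KEY = re.compile(
    r"^model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight$"
)
PROJS = ("gate_proj", "up_proj", "down_proj")


def validate_expert_availability(hf_state_dict: dict[str, Any], required_experts: Iterable[int]) -> None:
    """Hypothetical helper: raise RuntimeError if a required expert weight is missing."""
    # Only layers that already have at least one expert key in the dict are checked.
    layers = set()
    for key in hf_state_dict:
        match = EXPERT_KEY.match(key)
        if match:
            layers.add(int(match.group(1)))

    required = sorted(set(required_experts))  # materialize once; may be a generator
    missing = []
    for layer in sorted(layers):
        for expert in required:
            for proj in PROJS:
                key = f"model.layers.{layer}.mlp.experts.{expert}.{proj}.weight"
                if key not in hf_state_dict:
                    missing.append(key)
    if missing:
        raise RuntimeError(f"Missing {len(missing)} expert weights, e.g. {missing[:3]}")


# Usage: a rank that owns experts 2 and 3 of a 4-expert layer.
sd = {f"model.layers.0.mlp.experts.{e}.{p}.weight": None for e in range(4) for p in PROJS}
validate_expert_availability(sd, range(2, 4))  # passes
del sd["model.layers.0.mlp.experts.3.down_proj.weight"]
try:
    validate_expert_availability(sd, range(2, 4))
except RuntimeError as err:
    print(err)
```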
- _split_experts_weights(weight: torch.Tensor, n_experts: int)
Split grouped expert weights into individual expert weights. For grouped expert weights with shape [n_experts, …], split into n_experts tensors each with shape […]. Supports both regular tensors and DTensors.
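For a plain tensor, this split corresponds to unbinding dimension 0. The sketch below is a hypothetical `split_experts_weights` that covers regular tensors only; the DTensor path mentioned above is not shown.

```python
import torch


def split_experts_weights(weight: torch.Tensor, n_experts: int) -> list[torch.Tensor]:
    """Hypothetical helper: split a [n_experts, ...] tensor into n_experts tensors of shape [...]."""
    if weight.shape[0] != n_experts:
        raise ValueError(f"expected leading dim {n_experts}, got {weight.shape[0]}")
    # torch.unbind removes dim 0 and returns one view per expert.
    return list(torch.unbind(weight, dim=0))


grouped = torch.arange(24.0).reshape(4, 2, 3)  # 4 experts, each [2, 3]
experts = split_experts_weights(grouped, n_experts=4)
print(len(experts), experts[0].shape)  # 4 torch.Size([2, 3])
```

`torch.unbind` returns views into the grouped tensor, so call `.clone()` on a piece if it must own its storage.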
- _concatenate_expert_weights(expert_weights_by_layer: dict[str, Any], n_experts: int)
Concatenate the weights of separate experts into GroupedExperts weights.
- Parameters:
expert_weights_by_layer – Nested dict structure containing expert weights
n_experts – Total number of experts expected
- Returns:
Stacked tensor if all experts are available for a layer, None otherwise
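One plausible reading of the `None` return is that a layer's experts may arrive incrementally (for example, from different checkpoint shards), so stacking is deferred until all `n_experts` are present. The sketch below assumes a flat `{expert_id: tensor}` mapping for a single weight; the real nested structure of `expert_weights_by_layer` is not documented here, and `stack_expert_weights` is hypothetical.

```python
from typing import Optional

import torch


def stack_expert_weights(weights_by_expert: dict[int, torch.Tensor], n_experts: int) -> Optional[torch.Tensor]:
    """Hypothetical helper: stack experts 0..n_experts-1 into one [n_experts, ...] tensor.

    Returns None while any expert is still missing, so the caller can keep
    collecting weights and retry later.
    """
    if any(e not in weights_by_expert for e in range(n_experts)):
        return None
    # Order by expert id so that row i of the result is expert i.
    return torch.stack([weights_by_expert[e] for e in range(n_experts)], dim=0)


# One weight of one layer; expert 1 has not been loaded yet.
collected = {0: torch.zeros(2, 3)}
print(stack_expert_weights(collected, n_experts=2))  # None
collected[1] = torch.ones(2, 3)
print(stack_expert_weights(collected, n_experts=2).shape)  # torch.Size([2, 2, 3])
```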
- _to_hf_w_split_experts(state_dict: dict[str, Any])
Convert DeepEP format to HuggingFace format by splitting the grouped gate_and_up_projs and down_projs tensors into individual per-expert weights.
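A sketch of the split direction, under assumed layouts that this page does not specify: `gate_and_up_projs` of shape `[n_experts, dim, 2 * inter]` with the gate columns first, `down_projs` of shape `[n_experts, inter, dim]`, and HF `nn.Linear` weights stored as `[out_features, in_features]`. The function name `to_hf_split_experts` and the key format are hypothetical.

```python
import torch


def to_hf_split_experts(
    gate_and_up_projs: torch.Tensor, down_projs: torch.Tensor, layer: int
) -> dict[str, torch.Tensor]:
    """Hypothetical helper: grouped expert tensors -> per-expert HF-style weights."""
    n_experts, dim, two_inter = gate_and_up_projs.shape
    inter = two_inter // 2
    out = {}
    for e in range(n_experts):
        prefix = f"model.layers.{layer}.mlp.experts.{e}"
        # Split the last dim into gate and up halves, each [dim, inter].
        gate, up = gate_and_up_projs[e].split(inter, dim=-1)
        out[f"{prefix}.gate_proj.weight"] = gate.T.contiguous()  # [inter, dim]
        out[f"{prefix}.up_proj.weight"] = up.T.contiguous()  # [inter, dim]
        out[f"{prefix}.down_proj.weight"] = down_projs[e].T.contiguous()  # [dim, inter]
    return out


n_experts, dim, inter = 2, 4, 3
gate_and_up = torch.randn(n_experts, dim, 2 * inter)
down = torch.randn(n_experts, inter, dim)
hf = to_hf_split_experts(gate_and_up, down, layer=0)
print(hf["model.layers.0.mlp.experts.1.gate_proj.weight"].shape)  # torch.Size([3, 4])
```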
- _from_hf_w_merged_experts(hf_state_dict: dict[str, Any], device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None)
Convert HF checkpoint to DeepEP format. Creates combined gate_and_up_projs and transposed down_projs tensors.
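A sketch of this merge under assumed layouts that this page does not specify: HF `gate_proj`/`up_proj` weights of shape `[inter, dim]` and `down_proj` of shape `[dim, inter]`, producing `gate_and_up_projs` of shape `[n_experts, dim, 2 * inter]` and `down_projs` of shape `[n_experts, inter, dim]`. The hypothetical `from_hf_merged_experts` loads every expert; a rank-aware version would iterate only over the local expert range derived from the device mesh.

```python
import torch


def from_hf_merged_experts(
    hf_state_dict: dict[str, torch.Tensor], layer: int, n_experts: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: per-expert HF-style weights -> grouped expert tensors."""
    gate_up, downs = [], []
    for e in range(n_experts):
        prefix = f"model.layers.{layer}.mlp.experts.{e}"
        gate = hf_state_dict[f"{prefix}.gate_proj.weight"]  # [inter, dim]
        up = hf_state_dict[f"{prefix}.up_proj.weight"]  # [inter, dim]
        down = hf_state_dict[f"{prefix}.down_proj.weight"]  # [dim, inter]
        # Transpose to [dim, inter] and concatenate gate | up along the last dim.
        gate_up.append(torch.cat([gate.T, up.T], dim=-1))  # [dim, 2 * inter]
        downs.append(down.T)  # transposed: [inter, dim]
    return torch.stack(gate_up), torch.stack(downs)


dim, inter, n_experts = 4, 3, 2
hf = {}
for e in range(n_experts):
    p = f"model.layers.0.mlp.experts.{e}"
    hf[f"{p}.gate_proj.weight"] = torch.randn(inter, dim)
    hf[f"{p}.up_proj.weight"] = torch.randn(inter, dim)
    hf[f"{p}.down_proj.weight"] = torch.randn(dim, inter)
gate_and_up_projs, down_projs = from_hf_merged_experts(hf, layer=0, n_experts=n_experts)
print(gate_and_up_projs.shape, down_projs.shape)  # torch.Size([2, 4, 6]) torch.Size([2, 3, 4])
```

Merging gate and up into one tensor lets a kernel compute both projections with a single batched matmul per expert, which is a common motivation for this layout.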