nemo_automodel.components.moe.state_dict_mixin#

Module Contents#

Classes#

MoESplitExpertsStateDictMixin

Mixin class providing MoE state dict conversion utilities.

API#

class nemo_automodel.components.moe.state_dict_mixin.MoESplitExpertsStateDictMixin#

Mixin class providing MoE state dict conversion utilities.

This mixin provides methods for:

  • Expert parallelism calculations (ranges, assignment)

  • Format conversion between HuggingFace and native formats

  • Both GroupedExperts and DeepEP format support

  • DTensor-aware expert loading and conversion

Can be used by any MoE model that needs expert parallelism and format conversion.
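A minimal usage sketch follows; the model class, its n_experts attribute, and the loading flow are hypothetical illustrations, not part of this API, and the mixin's internal requirements may differ.

```python
import torch.nn as nn

from nemo_automodel.components.moe.state_dict_mixin import MoESplitExpertsStateDictMixin


class MyMoEModel(MoESplitExpertsStateDictMixin, nn.Module):
    """Hypothetical MoE model inheriting the conversion utilities."""

    def __init__(self, n_experts: int):
        super().__init__()
        self.n_experts = n_experts

    def load_hf_checkpoint(self, hf_state_dict, device_mesh=None):
        # Fail fast if expert weights needed by this rank are missing.
        self._validate_expert_availability(hf_state_dict, self.n_experts, device_mesh)
        # Convert HF per-expert weights into the grouped DeepEP layout.
        native_sd = self._from_hf_w_merged_experts(hf_state_dict, device_mesh)
        self.load_state_dict(native_sd, strict=False)
```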

_validate_expert_availability(
    hf_state_dict: dict[str, Any],
    n_experts: int,
    device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
) -> None#

Validate that all required experts are available in the HF state dict before loading. Only the experts needed by the current rank, and only the layers present in the state dict, are checked.

Parameters:
  • hf_state_dict – HuggingFace format state dict

  • n_experts – Total number of experts

  • device_mesh – Optional device mesh for expert parallelism

Raises:

RuntimeError – If required expert weights are missing from the checkpoint
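To make the contract concrete, here is a standalone sketch of the validation idea, not the actual implementation; the HF key pattern is an assumption for illustration.

```python
from typing import Any


def check_experts_present(hf_state_dict: dict[str, Any], layer: int, experts: range) -> None:
    """Hypothetical helper: raise if any required expert weight is absent."""
    missing = []
    for e in experts:
        # Assumed HF key pattern; the real module may use a different one.
        key = f"model.layers.{layer}.mlp.experts.{e}.down_proj.weight"
        if key not in hf_state_dict:
            missing.append(key)
    if missing:
        raise RuntimeError(f"Checkpoint is missing expert weights: {missing}")
```

Under expert parallelism, experts would cover only the slice of experts assigned to the current rank, which is why validation need not scan every expert in the checkpoint.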

_split_experts_weights(
    weight: torch.Tensor,
    n_experts: int,
) -> list[torch.Tensor]#

Split grouped expert weights into individual per-expert weights. A grouped weight of shape [n_experts, …] is split into n_experts tensors, each with the remaining shape […]. Supports both regular tensors and DTensors.
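In plain-tensor terms the split behaves like torch.unbind along the expert dimension; DTensor inputs need mesh-aware handling that this sketch omits.

```python
import torch

# One grouped tensor of shape [n_experts, out_dim, in_dim] becomes
# n_experts tensors, each of shape [out_dim, in_dim].
n_experts = 4
grouped = torch.randn(n_experts, 128, 64)
per_expert = list(torch.unbind(grouped, dim=0))

assert len(per_expert) == n_experts
assert per_expert[0].shape == (128, 64)
```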

_concatenate_expert_weights(
    expert_weights_by_layer: dict[str, Any],
    n_experts: int,
) -> Optional[torch.Tensor]#

Concatenate the weights of individual experts into GroupedExperts-format weights.

Parameters:
  • expert_weights_by_layer – Nested dict structure containing expert weights

  • n_experts – Total number of experts expected

Returns:

Stacked tensor if all experts are available for a layer, None otherwise
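The all-or-nothing return value corresponds to a stack that is only performed once every expert's weight has been collected; a minimal sketch:

```python
import torch

n_experts = 4
collected = {e: torch.randn(128, 64) for e in range(n_experts)}

# Stack per-expert weights back into a grouped tensor, but only once
# weights for every expert are present; otherwise return None and wait.
grouped = None
if len(collected) == n_experts:
    grouped = torch.stack([collected[e] for e in range(n_experts)], dim=0)
    assert grouped.shape == (n_experts, 128, 64)
```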

_to_hf_w_split_experts(
    state_dict: dict[str, Any],
) -> dict[str, Any]#

Convert DeepEP format to HuggingFace format. Splits the grouped gate_and_up_projs and down_projs tensors into individual per-expert weights.
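A hedged sketch of this direction: the HF key pattern, the internal layout of gate_and_up_projs (gate and up halves concatenated on the last dimension), and the transpose convention are all assumptions for illustration.

```python
import torch

n_experts, hidden, ffn, layer = 4, 64, 256, 0
gate_and_up_projs = torch.randn(n_experts, hidden, 2 * ffn)  # assumed layout
down_projs = torch.randn(n_experts, ffn, hidden)

hf_sd = {}
for e in range(n_experts):
    gate, up = gate_and_up_projs[e].chunk(2, dim=-1)  # [hidden, ffn] each
    prefix = f"model.layers.{layer}.mlp.experts.{e}"  # assumed HF key pattern
    hf_sd[f"{prefix}.gate_proj.weight"] = gate.T.contiguous()  # [ffn, hidden]
    hf_sd[f"{prefix}.up_proj.weight"] = up.T.contiguous()
    hf_sd[f"{prefix}.down_proj.weight"] = down_projs[e].T.contiguous()  # [hidden, ffn]
```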

_from_hf_w_merged_experts(
    hf_state_dict: dict[str, Any],
    device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
) -> dict[str, Any]#

Convert a HuggingFace checkpoint to DeepEP format. Creates combined gate_and_up_projs tensors and transposed down_projs tensors.
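The inverse direction under the same assumptions: per-expert HF weights are transposed, gate and up halves are concatenated, and everything is stacked along a new expert dimension.

```python
import torch

n_experts, hidden, ffn, layer = 4, 64, 256, 0
# Fake HF-style checkpoint with standard nn.Linear weight shapes;
# the key pattern is an assumption for illustration.
hf_sd = {}
for e in range(n_experts):
    prefix = f"model.layers.{layer}.mlp.experts.{e}"
    hf_sd[f"{prefix}.gate_proj.weight"] = torch.randn(ffn, hidden)
    hf_sd[f"{prefix}.up_proj.weight"] = torch.randn(ffn, hidden)
    hf_sd[f"{prefix}.down_proj.weight"] = torch.randn(hidden, ffn)

def w(e: int, proj: str) -> torch.Tensor:
    return hf_sd[f"model.layers.{layer}.mlp.experts.{e}.{proj}.weight"]

gate_and_up_projs = torch.stack(
    [torch.cat([w(e, "gate_proj").T, w(e, "up_proj").T], dim=-1) for e in range(n_experts)]
)  # [n_experts, hidden, 2 * ffn]
down_projs = torch.stack(
    [w(e, "down_proj").T for e in range(n_experts)]
)  # [n_experts, ffn, hidden]
```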