nemo_automodel.components.moe.state_dict_utils
Module Contents
Functions
API
Create a DTensor from a local tensor for expert parallelism.
Parameters:
Local portion of the tensor on this rank
Device mesh for DTensor creation
Current rank (for device placement)
Returns: torch.Tensor
DTensor if device_mesh is provided and DTensor is available, otherwise local_tensor
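The sketch below illustrates the fallback behaviour described above. It assumes the local shard is wrapped with PyTorch's `DTensor.from_local` and marked as sharded along the expert dimension (`Shard(0)`). The helper name `make_ep_tensor` is hypothetical. The sketch also leaves out the rank-based device placement that the real function performs.

```python
# Sketch only: `make_ep_tensor` is a hypothetical helper, not the library's API.
try:
    # Public DTensor location in recent PyTorch releases.
    from torch.distributed.tensor import DTensor, Shard
    HAVE_DTENSOR = True
except ImportError:
    HAVE_DTENSOR = False


def make_ep_tensor(local_tensor, device_mesh=None):
    """Wrap a local expert shard as a DTensor when possible.

    Returns `local_tensor` unchanged when no mesh is given or DTensor
    cannot be imported. Assumes experts are sharded along dim 0; moving
    the tensor to the rank's device is intentionally omitted.
    """
    if device_mesh is None or not HAVE_DTENSOR:
        return local_tensor
    # Each rank contributes its local slice; Shard(0) records that the
    # global tensor is these slices concatenated along dim 0.
    return DTensor.from_local(local_tensor, device_mesh, [Shard(0)])
```

Without a mesh the call is a no-op, so the same code path works in single-process runs.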
Get the range of experts that should be loaded for the current rank.
Parameters:
Device mesh for expert parallelism
Total number of experts
Returns: tuple[int, int]
Tuple of (start_expert_id, end_expert_id) for this rank
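This function takes a device mesh, but the partition itself is simple integer arithmetic. In the sketch below, `ep_rank` and `ep_size` stand in for values that could be read from the mesh, for example with `DeviceMesh.get_local_rank()` and `DeviceMesh.size()`. Three further assumptions apply: experts are split contiguously and evenly, `end` is exclusive, and the name `expert_range` is hypothetical.

```python
# Sketch only: contiguous, even partition of experts across EP ranks (assumed).
def expert_range(ep_rank: int, ep_size: int, n_experts: int) -> tuple[int, int]:
    """Return (start, end), end exclusive, for the experts owned by ep_rank."""
    if not 0 <= ep_rank < ep_size:
        raise ValueError(f"ep_rank {ep_rank} out of range for ep_size {ep_size}")
    if n_experts % ep_size != 0:
        raise ValueError(f"{n_experts} experts cannot be split evenly over {ep_size} ranks")
    per_rank = n_experts // ep_size
    start = ep_rank * per_rank
    return start, start + per_rank


# 8 experts over 4 EP ranks: each rank owns two consecutive experts.
# expert_range(1, 4, 8) -> (2, 4)
```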
Get the slice of experts present on the current rank for a DTensor.
For non-DTensors, returns the full tensor with start_expert=0, end_expert=n_experts. For DTensors sharded along the expert dimension (dim=0), returns only the local experts.
Parameters:
Input tensor containing expert weights [n_experts, …]
Total number of experts across all ranks
Returns: torch.Tensor
Tuple of (local_tensor, start_expert_id, end_expert_id)
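The sketch below models the two cases described above, using a Python list whose items stand in for rows along the expert dimension. For a real DTensor, the local part would already be the rank's shard, obtainable with `DTensor.to_local()`. Here the list is sliced to simulate that shard. The name `expert_slice`, the optional `ep_rank`/`ep_size` arguments, and the even contiguous split are all assumptions.

```python
# Sketch only: simulates the regular-tensor vs. sharded cases with Python lists.
def expert_slice(weights, n_experts, ep_rank=None, ep_size=None):
    """Return (local_weights, start_expert, end_expert).

    With no EP information every expert is local, mirroring the
    non-DTensor case. Otherwise the rank's contiguous block is selected,
    which is what a DTensor sharded with Shard(0) would hold locally.
    """
    if len(weights) != n_experts:
        raise ValueError(f"expected {n_experts} experts, got {len(weights)}")
    if ep_rank is None or ep_size is None:
        return weights, 0, n_experts
    per_rank = n_experts // ep_size  # assumes even division
    start = ep_rank * per_rank
    end = start + per_rank
    return weights[start:end], start, end
```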
Access a submesh of the given mesh by dimension name(s).
Check if a tensor is a DTensor.
Check if a specific expert should be loaded on the current rank.
Parameters:
The expert ID to check
Device mesh for expert parallelism
Total number of experts
Returns: bool
True if this expert should be loaded on the current rank
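Under the same assumed contiguous, even split, the ownership check reduces to a half-open range test. The sketch below shows this check and a typical use: filtering which expert IDs a rank would read from a checkpoint. The helper name and its explicit `ep_rank`/`ep_size` arguments are hypothetical stand-ins for values taken from the device mesh.

```python
# Sketch only: ownership test under an assumed contiguous, even split.
def should_load_expert(expert_id: int, ep_rank: int, ep_size: int, n_experts: int) -> bool:
    per_rank = n_experts // ep_size
    start = ep_rank * per_rank
    # Half-open interval: start is owned, start + per_rank is not.
    return start <= expert_id < start + per_rank


# Which of 8 experts would rank 1 (of 4) load?
local_ids = [e for e in range(8) if should_load_expert(e, ep_rank=1, ep_size=4, n_experts=8)]
# local_ids == [2, 3]
```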
Split expert weights, handling both regular tensors and DTensors.
For DTensors, only splits the experts present on the current rank.
Parameters:
Expert weights tensor [n_experts, …] (regular tensor or DTensor)
Total number of experts across all ranks
Returns: list[torch.Tensor]
Tuple of (split_weights, expert_ids)
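The sketch below shows the shape of the result: one entry per local expert, paired with its global expert ID. It uses plain lists. With real tensors, `torch.unbind(t, dim=0)` would yield the per-expert views. The function name, the `start_expert` argument, and the `experts.{i}.weight` key format in the usage lines are hypothetical.

```python
# Sketch only: per-expert split that keeps global expert IDs.
def split_expert_weights(local_weights, start_expert: int):
    """Split [n_local, ...] weights into one entry per expert.

    Returns (split_weights, expert_ids). For an unsharded tensor,
    start_expert is 0 and every expert is returned.
    """
    split_weights = list(local_weights)  # one item per expert row
    expert_ids = list(range(start_expert, start_expert + len(split_weights)))
    return split_weights, expert_ids


# Rank owning experts 6 and 7 maps them to (hypothetical) checkpoint keys.
weights, ids = split_expert_weights([[1.0, 2.0], [3.0, 4.0]], start_expert=6)
named = {f"experts.{i}.weight": w for i, w in zip(ids, weights)}
# named == {"experts.6.weight": [1.0, 2.0], "experts.7.weight": [3.0, 4.0]}
```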
Validate that a DTensor is properly sharded for expert parallelism.
Parameters:
Tensor to validate
Expected total number of experts
Name for error messages
Returns: bool
True if valid; raises ValueError otherwise
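The sketch below gives a rough idea of the kind of checks such a validator might run, expressed on plain integers. For a real DTensor, `global_dim0` could come from `t.shape[0]` and `local_dim0` from `t.to_local().shape[0]`, and one would also expect `t.placements` to contain `Shard(0)`. The exact checks the library performs are not stated here, so the function name, its arguments, and the even-split rule are assumptions.

```python
# Sketch only: shape checks assumed for expert-parallel sharding along dim 0.
def validate_expert_sharding(global_dim0: int, local_dim0: int, ep_size: int,
                             n_experts: int, name: str = "tensor") -> bool:
    """Return True if the shapes match an even dim-0 split, else raise ValueError."""
    if global_dim0 != n_experts:
        raise ValueError(f"{name}: expected {n_experts} experts, found {global_dim0}")
    if n_experts % ep_size != 0:
        raise ValueError(f"{name}: {n_experts} experts not divisible by EP size {ep_size}")
    expected_local = n_experts // ep_size
    if local_dim0 != expected_local:
        raise ValueError(f"{name}: expected {expected_local} local experts, found {local_dim0}")
    return True
```

Including `name` in every message makes it easy to tell which checkpoint entry failed when many expert tensors are validated in a loop.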