nemo_automodel.components.moe.state_dict_utils
Module Contents

Functions

| Function | Description |
| --- | --- |
| `is_dtensor` | Check if a tensor is a DTensor. |
| `get_expert_slice_for_rank` | Get the slice of experts present on the current rank for a DTensor. |
| `split_experts_weights_dtensor_aware` | Split expert weights, handling both regular tensors and DTensors. |
| `validate_dtensor_expert_sharding` | Validate that a DTensor is properly sharded for expert parallelism. |
| `create_dtensor_from_local` | Create a DTensor from a local tensor for expert parallelism. |
| `get_expert_range_for_rank_from_mesh` | Get the range of experts that should be loaded for the current rank. |
| `should_load_expert_for_rank` | Check if a specific expert should be loaded on the current rank. |
API
- nemo_automodel.components.moe.state_dict_utils.is_dtensor(tensor: torch.Tensor) → bool
Check if a tensor is a DTensor.
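A minimal usage sketch; the DTensor construction in the comment assumes an initialized process group and an existing `DeviceMesh` named `ep_mesh` (illustrative names):

```python
import torch
from nemo_automodel.components.moe.state_dict_utils import is_dtensor

w = torch.randn(8, 1024)        # plain tensor
print(is_dtensor(w))            # False

# With a DTensor (requires torch.distributed to be initialized and a DeviceMesh):
#   from torch.distributed.tensor import distribute_tensor, Shard
#   dw = distribute_tensor(w, ep_mesh, placements=[Shard(0)])
#   is_dtensor(dw)              # True
```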
- nemo_automodel.components.moe.state_dict_utils.get_submesh(
- device_mesh: torch.distributed.device_mesh.DeviceMesh,
- dims: tuple[str, ...],
- )
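A sketch of how this could be called, assuming it returns the sub-mesh spanning the named dimensions; the mesh shape and dimension names below are illustrative:

```python
from torch.distributed.device_mesh import init_device_mesh
from nemo_automodel.components.moe.state_dict_utils import get_submesh

# Requires an initialized process group with 8 ranks; dim names are illustrative.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "ep"))
ep_mesh = get_submesh(mesh, ("ep",))   # sub-mesh covering only the "ep" dimension
```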
- nemo_automodel.components.moe.state_dict_utils.get_expert_slice_for_rank(
- experts_tensor: torch.Tensor,
- n_experts: int,
- )
Get the slice of experts present on the current rank for a DTensor.
For non-DTensors, returns the full tensor with start_expert=0, end_expert=n_experts. For DTensors sharded along the expert dimension (dim=0), returns only the local experts.
- Parameters:
experts_tensor – Input tensor containing expert weights [n_experts, …]
n_experts – Total number of experts across all ranks
- Returns:
tuple of (local_tensor, start_expert_id, end_expert_id)
local_tensor: The local portion of the tensor
start_expert_id: Global ID of the first expert on this rank
end_expert_id: Global ID after the last expert on this rank (exclusive)
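A minimal sketch with a plain (non-DTensor) tensor; the expert count and weight shapes are illustrative:

```python
import torch
from nemo_automodel.components.moe.state_dict_utils import get_expert_slice_for_rank

n_experts = 8
experts = torch.randn(n_experts, 1024, 4096)   # [n_experts, ...] stacked expert weights

local, start, end = get_expert_slice_for_rank(experts, n_experts)
# Plain tensor: the full tensor comes back and (start, end) == (0, n_experts).
# For a DTensor sharded on dim 0, `local` would hold only this rank's experts
# and (start, end) the corresponding global expert IDs.
print(local.shape, start, end)
```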
- nemo_automodel.components.moe.state_dict_utils.split_experts_weights_dtensor_aware(
- weight: torch.Tensor,
- n_experts: int,
- )
Split expert weights, handling both regular tensors and DTensors.
For DTensors, only splits the experts present on the current rank.
- Parameters:
weight – Expert weights tensor [n_experts, …] (regular tensor or DTensor)
n_experts – Total number of experts across all ranks
- Returns:
tuple of (split_weights, expert_ids)
split_weights: List of individual expert weight tensors
expert_ids: List of global expert IDs corresponding to split_weights
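A minimal sketch on a plain tensor (shapes illustrative); for a DTensor the returned lists would cover only the experts local to this rank:

```python
import torch
from nemo_automodel.components.moe.state_dict_utils import split_experts_weights_dtensor_aware

n_experts = 4
weight = torch.randn(n_experts, 1024, 4096)

split_weights, expert_ids = split_experts_weights_dtensor_aware(weight, n_experts)
for expert_id, w in zip(expert_ids, split_weights):
    print(expert_id, tuple(w.shape))   # one per-expert weight paired with its global expert ID
```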
- nemo_automodel.components.moe.state_dict_utils.validate_dtensor_expert_sharding(
- tensor: torch.Tensor,
- expected_experts: int,
- tensor_name: str = 'tensor',
- )
Validate that a DTensor is properly sharded for expert parallelism.
- Parameters:
tensor – Tensor to validate
expected_experts – Expected total number of experts
tensor_name – Name for error messages
- Returns:
True if valid, raises ValueError if invalid
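A sketch of a guard around checkpoint export, assuming `experts_weight` is the (possibly DTensor) stacked expert weight of a layer; the helper name and `tensor_name` string are illustrative:

```python
import torch
from nemo_automodel.components.moe.state_dict_utils import validate_dtensor_expert_sharding

def check_expert_sharding(experts_weight: torch.Tensor, n_experts: int) -> bool:
    """Return True when the sharding is valid; surface the ValueError message otherwise."""
    try:
        return validate_dtensor_expert_sharding(
            experts_weight, expected_experts=n_experts, tensor_name="experts.weight"
        )
    except ValueError as err:
        print(f"invalid expert sharding: {err}")
        return False
```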
- nemo_automodel.components.moe.state_dict_utils.create_dtensor_from_local(
- local_tensor: torch.Tensor,
- device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
- rank: Optional[int] = None,
- )
Create a DTensor from a local tensor for expert parallelism.
- Parameters:
local_tensor – Local portion of the tensor on this rank
device_mesh – Device mesh for DTensor creation
rank – Current rank (for device placement)
- Returns:
DTensor if device_mesh is provided and DTensor is available, otherwise local_tensor
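A minimal sketch: with `device_mesh=None` the local tensor is returned unchanged, and the commented path shows how an expert-parallel mesh could be passed on a real distributed run (mesh name and shape are illustrative):

```python
import torch
from nemo_automodel.components.moe.state_dict_utils import create_dtensor_from_local, is_dtensor

local_experts = torch.randn(2, 1024, 4096)   # this rank's slice of the expert weights

out = create_dtensor_from_local(local_experts, device_mesh=None)
print(is_dtensor(out))   # False: without a mesh the local tensor is passed through

# On a distributed run (process group initialized):
#   import torch.distributed as dist
#   from torch.distributed.device_mesh import init_device_mesh
#   ep_mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("ep",))
#   out = create_dtensor_from_local(local_experts, device_mesh=ep_mesh, rank=dist.get_rank())
```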
- nemo_automodel.components.moe.state_dict_utils.get_expert_range_for_rank_from_mesh(
- device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
- n_experts: int,
- )
Get the range of experts that should be loaded for the current rank.
- Parameters:
device_mesh – Device mesh for expert parallelism
n_experts – Total number of experts
- Returns:
Tuple of (start_expert_id, end_expert_id) for this rank
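A sketch of using the returned range to decide which experts to read from a checkpoint; it assumes (not stated in the docstring) that passing `device_mesh=None` falls back to the full expert range:

```python
from nemo_automodel.components.moe.state_dict_utils import get_expert_range_for_rank_from_mesh

n_experts = 16
ep_mesh = None   # stand-in for the expert-parallel DeviceMesh on a real run

start, end = get_expert_range_for_rank_from_mesh(ep_mesh, n_experts)
experts_to_load = list(range(start, end))   # global expert IDs this rank is responsible for
```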
- nemo_automodel.components.moe.state_dict_utils.should_load_expert_for_rank(
- expert_id: int,
- device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
- n_experts: int,
- )
Check if a specific expert should be loaded on the current rank.
- Parameters:
expert_id – The expert ID to check
device_mesh – Device mesh for expert parallelism
n_experts – Total number of experts
- Returns:
True if this expert should be loaded on the current rank
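A sketch of filtering experts so that each rank only materializes the ones it owns; the mesh variable is an illustrative stand-in:

```python
from nemo_automodel.components.moe.state_dict_utils import should_load_expert_for_rank

n_experts = 16
ep_mesh = None   # stand-in for the expert-parallel DeviceMesh on a real run

local_expert_ids = [
    expert_id
    for expert_id in range(n_experts)
    if should_load_expert_for_rank(expert_id, ep_mesh, n_experts)
]
print(local_expert_ids)
```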