nemo_automodel.components.moe.state_dict_utils#

Module Contents#

Functions#

is_dtensor

Check if a tensor is a DTensor.

get_submesh

Get a submesh of the device mesh restricted to the given dimension names.

get_expert_slice_for_rank

Get the slice of experts present on the current rank for a DTensor.

split_experts_weights_dtensor_aware

Split expert weights, handling both regular tensors and DTensors.

validate_dtensor_expert_sharding

Validate that a DTensor is properly sharded for expert parallelism.

create_dtensor_from_local

Create a DTensor from a local tensor for expert parallelism.

get_expert_range_for_rank_from_mesh

Get the range of experts that should be loaded for the current rank.

should_load_expert_for_rank

Check if a specific expert should be loaded on the current rank.

API#

nemo_automodel.components.moe.state_dict_utils.is_dtensor(tensor: torch.Tensor) → bool#

Check if a tensor is a DTensor.
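
A minimal usage sketch (the tensor shape is illustrative): a plain torch.Tensor is not a DTensor, so the check returns False; in a distributed run, parameters materialized over a DeviceMesh are DTensors and return True.

```python
# Minimal sketch; the shape below is illustrative.
import torch

from nemo_automodel.components.moe.state_dict_utils import is_dtensor

weight = torch.randn(8, 16, 32)   # plain tensor, not distributed
print(is_dtensor(weight))         # False

# In a distributed run, parameters sharded over a DeviceMesh are DTensors,
# and is_dtensor(param) returns True for them.
```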

nemo_automodel.components.moe.state_dict_utils.get_submesh(
device_mesh: torch.distributed.device_mesh.DeviceMesh,
dims: tuple[str, ...],
) → torch.distributed.device_mesh.DeviceMesh#

Get a submesh of the device mesh restricted to the given dimension names.
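
A hedged sketch, assuming a distributed run (e.g. launched with torchrun) with enough devices for a 2x4 mesh; the dimension names "dp" and "ep" are illustrative, not mandated by the library.

```python
# Hedged sketch: assumes torch.distributed is already initialized and
# eight devices are available; mesh dimension names are illustrative.
from torch.distributed.device_mesh import init_device_mesh

from nemo_automodel.components.moe.state_dict_utils import get_submesh

mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "ep"))
ep_mesh = get_submesh(mesh, ("ep",))   # submesh covering only the expert-parallel dim
```
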
nemo_automodel.components.moe.state_dict_utils.get_expert_slice_for_rank(
experts_tensor: torch.Tensor,
n_experts: int,
) → tuple[torch.Tensor, int, int]#

Get the slice of experts present on the current rank for a DTensor.

For non-DTensors, returns the full tensor with start_expert=0, end_expert=n_experts. For DTensors sharded along the expert dimension (dim=0), returns only the local experts.

Parameters:
  • experts_tensor – Input tensor containing expert weights [n_experts, …]

  • n_experts – Total number of experts across all ranks

Returns:

tuple of (local_tensor, start_expert_id, end_expert_id)

  • local_tensor: The local portion of the tensor

  • start_expert_id: Global ID of the first expert on this rank

  • end_expert_id: Global ID after the last expert on this rank (exclusive)
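
A minimal sketch of the documented non-DTensor path, where the full tensor comes back with the complete expert range:

```python
import torch

from nemo_automodel.components.moe.state_dict_utils import get_expert_slice_for_rank

n_experts = 8
experts = torch.randn(n_experts, 16, 32)   # [n_experts, ...] expert weights

local, start, end = get_expert_slice_for_rank(experts, n_experts)
# Regular tensor: the full tensor is returned with start=0, end=n_experts.
print(local.shape[0], start, end)   # 8 0 8
```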

nemo_automodel.components.moe.state_dict_utils.split_experts_weights_dtensor_aware(
weight: torch.Tensor,
n_experts: int,
) → tuple[list[torch.Tensor], list[int]]#

Split expert weights, handling both regular tensors and DTensors.

For DTensors, only splits the experts present on the current rank.

Parameters:
  • weight – Expert weights tensor [n_experts, …] (regular tensor or DTensor)

  • n_experts – Total number of experts across all ranks

Returns:

tuple of (split_weights, expert_ids)

  • split_weights: List of individual expert weight tensors

  • expert_ids: List of global expert IDs corresponding to split_weights
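
A minimal sketch with a regular (non-DTensor) weight, where every expert lives on the calling process:

```python
import torch

from nemo_automodel.components.moe.state_dict_utils import (
    split_experts_weights_dtensor_aware,
)

n_experts = 4
weight = torch.randn(n_experts, 16, 32)

split_weights, expert_ids = split_experts_weights_dtensor_aware(weight, n_experts)
# One tensor per local expert, paired with its global expert ID.
print(len(split_weights), expert_ids)   # expect 4 and [0, 1, 2, 3]
```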

nemo_automodel.components.moe.state_dict_utils.validate_dtensor_expert_sharding(
tensor: torch.Tensor,
expected_experts: int,
tensor_name: str = 'tensor',
) → bool#

Validate that a DTensor is properly sharded for expert parallelism.

Parameters:
  • tensor – Tensor to validate

  • expected_experts – Expected total number of experts

  • tensor_name – Name for error messages

Returns:

True if valid, raises ValueError if invalid
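
A hedged sketch; the stand-in tensor and the tensor_name label "experts.gate_proj" are illustrative. Per the documented contract, an invalid layout raises ValueError rather than returning False:

```python
import torch

from nemo_automodel.components.moe.state_dict_utils import (
    validate_dtensor_expert_sharding,
)

weight = torch.randn(64, 16, 32)   # stand-in for an expert-parallel parameter

try:
    validate_dtensor_expert_sharding(weight, expected_experts=64,
                                     tensor_name="experts.gate_proj")
except ValueError as err:
    print(f"invalid expert sharding: {err}")
```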

nemo_automodel.components.moe.state_dict_utils.create_dtensor_from_local(
local_tensor: torch.Tensor,
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
rank: Optional[int] = None,
) → torch.Tensor#

Create a DTensor from a local tensor for expert parallelism.

Parameters:
  • local_tensor – Local portion of the tensor on this rank

  • device_mesh – Device mesh for DTensor creation

  • rank – Current rank (for device placement)

Returns:

DTensor if device_mesh is provided and DTensor is available, otherwise local_tensor
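
A minimal sketch of the documented fallback path: without a device mesh, the local tensor is returned as-is and no DTensor is created.

```python
import torch

from nemo_automodel.components.moe.state_dict_utils import (
    create_dtensor_from_local,
    is_dtensor,
)

local_shard = torch.randn(2, 16, 32)   # this rank's slice of the expert weights
result = create_dtensor_from_local(local_shard, device_mesh=None)
print(is_dtensor(result))              # False: no mesh, so the local tensor comes back
```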

nemo_automodel.components.moe.state_dict_utils.get_expert_range_for_rank_from_mesh(
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
n_experts: int,
) → tuple[int, int]#

Get the range of experts that should be loaded for the current rank.

Parameters:
  • device_mesh – Device mesh for expert parallelism

  • n_experts – Total number of experts

Returns:

Tuple of (start_expert_id, end_expert_id) for this rank
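
A hedged sketch, under the assumption that with no expert-parallel mesh the current process owns the full expert range:

```python
from nemo_automodel.components.moe.state_dict_utils import (
    get_expert_range_for_rank_from_mesh,
)

# Assumption: with device_mesh=None every expert belongs to this process,
# so the range is expected to cover [0, n_experts).
start, end = get_expert_range_for_rank_from_mesh(device_mesh=None, n_experts=64)
print(start, end)
```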

nemo_automodel.components.moe.state_dict_utils.should_load_expert_for_rank(
expert_id: int,
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
n_experts: int,
) → bool#

Check if a specific expert should be loaded on the current rank.

Parameters:
  • expert_id – The expert ID to check

  • device_mesh – Device mesh for expert parallelism

  • n_experts – Total number of experts

Returns:

True if this expert should be loaded on the current rank
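
A hedged sketch of a checkpoint-loading loop; expert_parallel_mesh is illustrative and may be None in a single-process run:

```python
from nemo_automodel.components.moe.state_dict_utils import should_load_expert_for_rank

expert_parallel_mesh = None   # or a DeviceMesh built for expert parallelism
n_experts = 64

# Keep only the experts this rank is responsible for loading.
local_expert_ids = [
    expert_id
    for expert_id in range(n_experts)
    if should_load_expert_for_rank(expert_id, expert_parallel_mesh, n_experts)
]
```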