nemo_automodel.components.moe.state_dict_utils


Module Contents

Functions

Name | Description
create_dtensor_from_local | Create a DTensor from a local tensor for expert parallelism.
get_expert_range_for_rank_from_mesh | Get the range of experts that should be loaded for the current rank.
get_expert_slice_for_rank | Get the slice of experts present on the current rank for a DTensor.
get_submesh | Access a submesh by dim names from the given mesh.
is_dtensor | Check if a tensor is a DTensor.
should_load_expert_for_rank | Check if a specific expert should be loaded on the current rank.
split_experts_weights_dtensor_aware | Split expert weights, handling both regular tensors and DTensors.
validate_dtensor_expert_sharding | Validate that a DTensor is properly sharded for expert parallelism.

API

nemo_automodel.components.moe.state_dict_utils.create_dtensor_from_local(
local_tensor: torch.Tensor,
device_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh],
rank: typing.Optional[int] = None
) -> torch.Tensor

Create a DTensor from a local tensor for expert parallelism.

Parameters:

local_tensor
torch.Tensor

Local portion of the tensor on this rank

device_mesh
Optional[DeviceMesh]

Device mesh for DTensor creation

rank
Optional[int], defaults to None

Current rank (for device placement)

Returns: torch.Tensor

DTensor if device_mesh is provided and DTensor is available, otherwise local_tensor
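Wrapping a local expert shard as a DTensor changes the tensor's logical (global) shape. The pure-Python sketch below shows that bookkeeping under two assumptions not stated on this page: the shard is placed with `Shard(0)` (experts split along dim 0), and every rank holds the same number of experts. In PyTorch the corresponding call would be along the lines of `DTensor.from_local(local_tensor, device_mesh, [Shard(0)])`. `global_shape_from_local` is a hypothetical helper, not part of this module.

```python
from typing import Optional, Sequence, Tuple


def global_shape_from_local(
    local_shape: Sequence[int], ep_size: Optional[int]
) -> Tuple[int, ...]:
    """Hypothetical helper: logical shape of a DTensor built from a local shard.

    Assumes a Shard(0) placement (experts split along dim 0) and an equal
    number of experts on every expert-parallel rank.
    """
    if ep_size is None:
        # No device mesh: the local tensor is used as-is.
        return tuple(local_shape)
    # Dim 0 (experts) is concatenated across ranks; other dims are unchanged.
    return (local_shape[0] * ep_size, *local_shape[1:])


# 4 local experts of shape [1024, 4096] on each of 8 expert-parallel ranks.
print(global_shape_from_local((4, 1024, 4096), 8))  # (32, 1024, 4096)
print(global_shape_from_local((4, 1024, 4096), None))  # (4, 1024, 4096)
```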

nemo_automodel.components.moe.state_dict_utils.get_expert_range_for_rank_from_mesh(
device_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh],
n_experts: int
) -> tuple[int, int]

Get the range of experts that should be loaded for the current rank.

Parameters:

device_mesh
Optional[DeviceMesh]

Device mesh for expert parallelism

n_experts
int

Total number of experts

Returns: tuple[int, int]

Tuple of (start_expert_id, end_expert_id) for this rank
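A minimal sketch of how such a range could be computed, assuming experts are partitioned contiguously across the expert-parallel group and that the end id is exclusive (consistent with `end_expert=n_experts` for non-DTensors). The remainder policy (lower ranks take one extra expert when `n_experts` is not divisible) is an assumption, and `expert_range_for_rank` is hypothetical: it takes the rank and group size directly instead of a `DeviceMesh`.

```python
from typing import Tuple


def expert_range_for_rank(ep_rank: int, ep_size: int, n_experts: int) -> Tuple[int, int]:
    """Hypothetical helper: half-open [start, end) range of experts owned by ep_rank.

    Assumes a contiguous split; if n_experts % ep_size != 0, the first
    (n_experts % ep_size) ranks each own one extra expert.
    With ep_size == 1 the single rank owns (0, n_experts).
    """
    base, remainder = divmod(n_experts, ep_size)
    start = ep_rank * base + min(ep_rank, remainder)
    end = start + base + (1 if ep_rank < remainder else 0)
    return start, end


print([expert_range_for_rank(r, 4, 64) for r in range(4)])  # [(0, 16), (16, 32), (32, 48), (48, 64)]
print([expert_range_for_rank(r, 4, 10) for r in range(4)])  # [(0, 3), (3, 6), (6, 8), (8, 10)]
```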

nemo_automodel.components.moe.state_dict_utils.get_expert_slice_for_rank(
experts_tensor: torch.Tensor,
n_experts: int
) -> tuple[torch.Tensor, int, int]

Get the slice of experts present on the current rank for a DTensor.

For non-DTensors, returns the full tensor with start_expert=0, end_expert=n_experts. For DTensors sharded along the expert dimension (dim=0), returns only the local experts.

Parameters:

experts_tensor
torch.Tensor

Input tensor containing expert weights [n_experts, …]

n_experts
int

Total number of experts across all ranks

Returns: tuple[torch.Tensor, int, int]

Tuple of (local_tensor, start_expert_id, end_expert_id)
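The two cases described above can be sketched in plain Python. Here a list of per-expert rows stands in for the `[n_experts, ...]` tensor, and `is_sharded` stands in for "is a DTensor sharded on dim 0". The sketch assumes even shards, so a rank's first global expert id is `ep_rank * n_local`. `expert_slice_for_rank` and its parameters are hypothetical, not this module's implementation.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def expert_slice_for_rank(
    local_experts: Sequence[T], ep_rank: int, is_sharded: bool, n_experts: int
) -> Tuple[List[T], int, int]:
    """Hypothetical sketch of the (local_tensor, start, end) contract.

    local_experts: rows of the full tensor (not sharded) or of this rank's
    local shard (sharded along dim 0). Assumes equal-sized shards.
    """
    if not is_sharded:
        # Regular tensor: every expert is present locally.
        return list(local_experts), 0, n_experts
    n_local = len(local_experts)
    start = ep_rank * n_local
    return list(local_experts), start, start + n_local


full = [f"w{i}" for i in range(8)]
print(expert_slice_for_rank(full, ep_rank=0, is_sharded=False, n_experts=8)[1:])  # (0, 8)
print(expert_slice_for_rank(["w2", "w3"], ep_rank=1, is_sharded=True, n_experts=8))  # (['w2', 'w3'], 2, 4)
```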

nemo_automodel.components.moe.state_dict_utils.get_submesh(
device_mesh: torch.distributed.device_mesh.DeviceMesh,
dims: tuple[str, ...]
) -> torch.distributed.device_mesh.DeviceMesh

Access a submesh by dim names from the given mesh.
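In PyTorch, a `DeviceMesh` can be indexed by dimension name, e.g. `device_mesh["ep"]` or `device_mesh[("dp", "ep")]`. The sketch below illustrates which ranks end up in the current rank's submesh, assuming ranks are laid out row-major over the mesh shape (as with `torch.arange(world_size).reshape(mesh_shape)`) and that `dims` lists names in the parent mesh's order. `submesh_ranks` is a hypothetical helper that works on plain integers, not a `DeviceMesh`.

```python
from itertools import product
from typing import List, Sequence, Tuple


def submesh_ranks(
    mesh_shape: Sequence[int],
    dim_names: Sequence[str],
    dims: Tuple[str, ...],
    rank: int,
) -> List[int]:
    """Hypothetical sketch: ranks in `rank`'s submesh spanning the named dims."""
    index = {name: i for i, name in enumerate(dim_names)}
    keep = [index[d] for d in dims]

    # Recover this rank's coordinate along every mesh dimension (row-major).
    coord = []
    rem = rank
    for size in reversed(mesh_shape):
        coord.append(rem % size)
        rem //= size
    coord.reverse()

    ranks = []
    # Vary only the kept dims; every other dim stays at this rank's coordinate.
    for kept_coord in product(*(range(mesh_shape[k]) for k in keep)):
        c = list(coord)
        for k, v in zip(keep, kept_coord):
            c[k] = v
        flat = 0
        for size, v in zip(mesh_shape, c):
            flat = flat * size + v
        ranks.append(flat)
    return ranks


# 2 x 4 mesh named ("dp", "ep"); rank 5 sits at dp=1, ep=1.
print(submesh_ranks((2, 4), ("dp", "ep"), ("ep",), rank=5))  # [4, 5, 6, 7]
print(submesh_ranks((2, 4), ("dp", "ep"), ("dp",), rank=5))  # [1, 5]
```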

nemo_automodel.components.moe.state_dict_utils.is_dtensor(
tensor: torch.Tensor
) -> bool

Check if a tensor is a DTensor.
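A check like this is typically an `isinstance` test against PyTorch's `DTensor` class. The sketch below is an assumption about how that might be written robustly: it tries the public import path (`torch.distributed.tensor`, newer releases) and then the older private one (`torch.distributed._tensor`), and reports `False` if neither is available. `is_dtensor_sketch` is a hypothetical name, not this module's function.

```python
import importlib


def _dtensor_cls():
    """Return torch's DTensor class if importable, else None."""
    for path in ("torch.distributed.tensor", "torch.distributed._tensor"):
        try:
            return importlib.import_module(path).DTensor
        except (ImportError, AttributeError):
            # Path missing in this torch version (or torch not installed).
            continue
    return None


def is_dtensor_sketch(tensor) -> bool:
    """Hypothetical sketch: True only for DTensor instances."""
    cls = _dtensor_cls()
    return cls is not None and isinstance(tensor, cls)


print(is_dtensor_sketch([1.0, 2.0]))  # False
```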

nemo_automodel.components.moe.state_dict_utils.should_load_expert_for_rank(
expert_id: int,
device_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh],
n_experts: int
) -> bool

Check if a specific expert should be loaded on the current rank.

Parameters:

expert_id
int

The expert ID to check

device_mesh
Optional[DeviceMesh]

Device mesh for expert parallelism

n_experts
int

Total number of experts

Returns: bool

True if this expert should be loaded on the current rank
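Such a check is typically used while reading a checkpoint, to skip the weights of experts that another rank owns. A minimal sketch, assuming an even, contiguous split with a half-open `[start, end)` range per rank. `should_load_expert` is hypothetical and takes the rank and group size directly instead of a `DeviceMesh`.

```python
def should_load_expert(expert_id: int, ep_rank: int, ep_size: int, n_experts: int) -> bool:
    """Hypothetical sketch: does ep_rank own expert_id?

    Assumes n_experts is divisible by ep_size and experts are split
    contiguously, so rank r owns [r * per_rank, (r + 1) * per_rank).
    """
    per_rank = n_experts // ep_size
    start = ep_rank * per_rank
    return start <= expert_id < start + per_rank


# With 64 experts over 8 ranks, rank 2 owns experts 16..23.
owned = [e for e in range(64) if should_load_expert(e, ep_rank=2, ep_size=8, n_experts=64)]
print(owned)  # [16, 17, 18, 19, 20, 21, 22, 23]
```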

nemo_automodel.components.moe.state_dict_utils.split_experts_weights_dtensor_aware(
weight: torch.Tensor,
n_experts: int
) -> tuple[list[torch.Tensor], list[int]]

Split expert weights, handling both regular tensors and DTensors.

For DTensors, only splits the experts present on the current rank.

Parameters:

weight
torch.Tensor

Expert weights tensor [n_experts, …] (regular tensor or DTensor)

n_experts
int

Total number of experts across all ranks

Returns: tuple[list[torch.Tensor], list[int]]

Tuple of (split_weights, expert_ids)
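The split pairs each locally present expert's weight with its global expert id. In this pure-Python sketch, a list of rows stands in for the rank-local `[n_local, ...]` block (the whole tensor for a regular tensor, the local shard for a DTensor), and `start_expert` is the global id of its first row, as returned by a slice helper. With torch, the per-expert pieces could come from `local.unbind(0)`. `split_experts` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_experts(local_weight: Sequence[T], start_expert: int) -> Tuple[List[T], List[int]]:
    """Hypothetical sketch: split rank-local rows into (split_weights, expert_ids)."""
    split_weights = list(local_weight)  # one entry per local expert
    expert_ids = [start_expert + i for i in range(len(split_weights))]
    return split_weights, expert_ids


# Rank 1 of 4 holding experts 2 and 3 out of 8.
print(split_experts(["w2", "w3"], start_expert=2))  # (['w2', 'w3'], [2, 3])
```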

nemo_automodel.components.moe.state_dict_utils.validate_dtensor_expert_sharding(
tensor: torch.Tensor,
expected_experts: int,
tensor_name: str = 'tensor'
) -> bool

Validate that a DTensor is properly sharded for expert parallelism.

Parameters:

tensor
torch.Tensor

Tensor to validate

expected_experts
int

Expected total number of experts

tensor_name
str, defaults to 'tensor'

Name for error messages

Returns: bool

True if the tensor is valid; otherwise raises ValueError
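The kinds of checks such a validator might run can be sketched with plain shapes. The specific checks below (global expert count, divisibility by the group size, local shard size) are assumptions for illustration rather than the documented behavior; only the "return True or raise ValueError" contract comes from this page. `validate_expert_sharding` is hypothetical.

```python
from typing import Sequence


def validate_expert_sharding(
    global_shape: Sequence[int],
    local_shape: Sequence[int],
    ep_size: int,
    expected_experts: int,
    tensor_name: str = "tensor",
) -> bool:
    """Hypothetical sketch: assumes experts are sharded evenly along dim 0."""
    if global_shape[0] != expected_experts:
        raise ValueError(
            f"{tensor_name}: expected {expected_experts} experts, got {global_shape[0]}"
        )
    if expected_experts % ep_size != 0:
        raise ValueError(
            f"{tensor_name}: {expected_experts} experts not divisible by ep_size={ep_size}"
        )
    if local_shape[0] != expected_experts // ep_size:
        raise ValueError(
            f"{tensor_name}: local shard has {local_shape[0]} experts, "
            f"expected {expected_experts // ep_size}"
        )
    return True


print(validate_expert_sharding((64, 1024), (8, 1024), ep_size=8, expected_experts=64))  # True
```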