core.resharding.utils#

Module Contents#

Classes#

TransferOp

Single logical send/recv operation used in a reshard plan.

ParameterMetadata

Metadata for a parameter (used when the parameter resides on a different rank).

ShardingDescriptor

Descriptor for a sharded dimension of a parameter.

ReshardPlan

Reshard plan: the send/recv operations assigned to this rank.

Functions#

_get_rank_in_group

_detect_expert_index_from_param_name

Extract expert index from parameter name for TEGroupedMLP per-expert tensors.

assign_resolved_name_inplace

Compute a canonical resolved_name for EP per-expert parameters and set global_expert_index. For non-EP or non-per-expert params, resolved_name defaults to the original name.

extract_param_metadata

Extract metadata from a parameter for cross-rank communication.

select_src_metadata_balanced

Choose a representative source ParameterMetadata for a destination rank.

Data#

API#

class core.resharding.utils.TransferOp#

Single logical send/recv operation used in a reshard plan.

param_name: str#

None

peer_rank: int#

None

is_send: bool#

None

my_slice: tuple[slice, ...]#

None

peer_slice: tuple[slice, ...]#

None

task_id: int | None#

None
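
A minimal sketch of a matched send/recv pair, assuming TransferOp is a plain dataclass; every value below is illustrative, not taken from a real plan:

```python
from core.resharding.utils import TransferOp

# Hypothetical: rank 0 sends rows 0..512 of a weight shard to rank 2.
send = TransferOp(
    param_name="decoder.layers.0.mlp.linear_fc1.weight",
    peer_rank=2,
    is_send=True,
    my_slice=(slice(0, 512), slice(None)),    # region read from the local tensor
    peer_slice=(slice(0, 512), slice(None)),  # region written on the peer
    task_id=0,
)

# Rank 2 holds the mirror-image op: same slices, roles reversed.
recv = TransferOp(
    param_name="decoder.layers.0.mlp.linear_fc1.weight",
    peer_rank=0,
    is_send=False,
    my_slice=(slice(0, 512), slice(None)),
    peer_slice=(slice(0, 512), slice(None)),
    task_id=0,
)
```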

class core.resharding.utils.ParameterMetadata#

Metadata for a parameter (used when the parameter resides on a different rank).

name: str#

None

shape: tuple[int, ...]#

None

dtype: torch.dtype#

None

element_size: int#

None

is_tp: bool#

False

partition_dim: int#

0

partition_stride: int#

1

is_ep: bool#

False

num_experts: Optional[int]#

None

owner_rank: int#

None

tensor_parallel_group_ranks: list[int] | None#

None

expert_parallel_group_ranks: list[int] | None#

None

data_parallel_group_ranks: list[int] | None#

None

pipeline_parallel_group_ranks: list[int] | None#

None

resolved_name: Optional[str]#

None

global_expert_index: Optional[int]#

None
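
For orientation, a sketch of the metadata a tensor-parallel shard might carry, assuming ParameterMetadata is a plain dataclass; every value is invented:

```python
import torch

from core.resharding.utils import ParameterMetadata

meta = ParameterMetadata(
    name="decoder.layers.0.self_attention.linear_qkv.weight",
    shape=(1536, 4096),   # the local shard's shape, not the full parameter
    dtype=torch.bfloat16,
    element_size=2,       # bytes per element for bfloat16
    is_tp=True,           # sharded along partition_dim across the TP group
    partition_dim=0,
    partition_stride=1,
    is_ep=False,
    num_experts=None,
    owner_rank=3,
    tensor_parallel_group_ranks=[2, 3],
    expert_parallel_group_ranks=None,
    data_parallel_group_ranks=[3, 7],
    pipeline_parallel_group_ranks=[3, 11],
)
# resolved_name and global_expert_index stay None until
# assign_resolved_name_inplace (below) fills them in for EP params.
```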

class core.resharding.utils.ShardingDescriptor#

Descriptor for a sharded dimension of a parameter.

name: str#

None

dim: int#

None

src_stride: int#

None

dst_stride: int#

None

src_dim_ranks: list[int]#

None

dst_dim_ranks: list[int]#

None
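
A sketch of a descriptor for regrouping dim 0 from 2-way to 4-way sharding; the strides and rank lists are invented:

```python
from core.resharding.utils import ShardingDescriptor

desc = ShardingDescriptor(
    name="decoder.layers.0.mlp.linear_fc1.weight",
    dim=0,                      # the dimension that is sharded
    src_stride=1,               # partition stride on the source layout
    dst_stride=1,               # partition stride on the destination layout
    src_dim_ranks=[0, 1],       # ranks sharing this dim before resharding
    dst_dim_ranks=[0, 1, 2, 3], # ranks sharing this dim afterwards
)
```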

class core.resharding.utils.ReshardPlan#

Reshard plan: the send/recv operations assigned to this rank.

send_ops: list[core.resharding.utils.TransferOp]#

None

recv_ops: list[core.resharding.utils.TransferOp]#

None

__str__()#
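
A brief usage sketch, reusing the hypothetical TransferOp pair from above:

```python
from core.resharding.utils import ReshardPlan

plan = ReshardPlan(send_ops=[send], recv_ops=[recv])
print(plan)  # __str__ renders a summary of the queued send/recv ops
```
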
core.resharding.utils._get_rank_in_group(global_rank: int, group_ranks: list[int]) → int#
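
The helper presumably maps a global rank to its position within a group's rank list; a plausible sketch (the real implementation may differ):

```python
def _get_rank_in_group(global_rank: int, group_ranks: list[int]) -> int:
    # e.g. _get_rank_in_group(5, [1, 3, 5, 7]) == 2
    return group_ranks.index(global_rank)
```
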
core.resharding.utils._detect_expert_index_from_param_name(
param_name: str,
) → Optional[int]#

Extract expert index from parameter name for TEGroupedMLP per-expert tensors.
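
A hedged sketch of the likely matching rule, assuming TEGroupedMLP leaf names end in a numeric suffix such as `...linear_fc1.weight3` (that naming convention is an assumption):

```python
import re
from typing import Optional

# Per-expert leaf tensors are assumed to look like "weight<N>" / "bias<N>".
_EXPERT_SUFFIX = re.compile(r"(?:weight|bias)(\d+)$")

def _detect_expert_index_from_param_name(param_name: str) -> Optional[int]:
    match = _EXPERT_SUFFIX.search(param_name)
    return int(match.group(1)) if match else None
```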

core.resharding.utils.assign_resolved_name_inplace(
meta: core.resharding.utils.ParameterMetadata,
) → None#

Compute a canonical resolved_name for EP per-expert parameters and set global_expert_index. For non-EP or non-per-expert params, resolved_name defaults to the original name.
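
A hedged sketch of the intended transformation, built on the two helper sketches above; the offset arithmetic (experts split evenly across EP ranks) is an assumption, not the verified implementation:

```python
def assign_resolved_name_inplace(meta: ParameterMetadata) -> None:
    local_idx = _detect_expert_index_from_param_name(meta.name)
    ep_ranks = meta.expert_parallel_group_ranks or []
    if not meta.is_ep or local_idx is None or not ep_ranks or meta.num_experts is None:
        meta.resolved_name = meta.name  # non-EP / non-per-expert: keep the original name
        return
    # Assumed layout: each EP rank holds num_experts / ep_size consecutive experts.
    ep_rank = _get_rank_in_group(meta.owner_rank, ep_ranks)
    experts_per_rank = meta.num_experts // len(ep_ranks)
    meta.global_expert_index = ep_rank * experts_per_rank + local_idx
    # Canonicalize: swap the local suffix index for the global one, so every
    # rank resolves the same logical expert to the same name.
    meta.resolved_name = meta.name.rstrip("0123456789") + str(meta.global_expert_index)
```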

core.resharding.utils.extract_param_metadata(
param: torch.nn.Parameter,
param_name: str,
owner_rank: int,
pg_collection,
num_experts: Optional[int] = None,
) → core.resharding.utils.ParameterMetadata#

Extract metadata from a parameter for cross-rank communication.
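
A hypothetical call site; `model` and `pg_collection` are stand-ins for objects assumed to exist in the caller, and `num_experts=8` is an arbitrary example value:

```python
import torch.distributed as dist

metas = []
for name, param in model.named_parameters():
    meta = extract_param_metadata(
        param,
        param_name=name,
        owner_rank=dist.get_rank(),
        pg_collection=pg_collection,  # supplies the parallel process groups
        num_experts=8,                # global expert count for MoE models
    )
    assign_resolved_name_inplace(meta)  # canonicalize EP per-expert names
    metas.append(meta)
```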

core.resharding.utils.select_src_metadata_balanced(
src_meta_list: list[core.resharding.utils.ParameterMetadata],
dst_metadata: core.resharding.utils.ParameterMetadata,
dst_rank: int,
) → core.resharding.utils.ParameterMetadata#

Choose a representative source ParameterMetadata for a destination rank.

Multiple source data-parallel (DP) groups may hold the same logical parameter. To avoid always reading from the same group, we:

  • bucket src_meta_list by their DP group (tuple of ranks)

  • if there is only one bucket, just return the first entry

  • otherwise, map the destination rank’s DP index to one of the source DP groups in a round-robin fashion, and pick the first metadata in it (see the sketch after this list).
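
A sketch that follows the three steps above literally; attribute access and deterministic dict ordering (Python 3.7+) are the only assumptions:

```python
def select_src_metadata_balanced(src_meta_list, dst_metadata, dst_rank):
    # 1) Bucket candidates by their DP group (tuple of ranks).
    buckets: dict[tuple, list] = {}
    for meta in src_meta_list:
        key = tuple(meta.data_parallel_group_ranks or ())
        buckets.setdefault(key, []).append(meta)
    groups = list(buckets.values())
    # 2) Single bucket: nothing to balance.
    if len(groups) == 1:
        return groups[0][0]
    # 3) Round-robin: map the destination rank's DP index onto the source groups.
    dst_dp_ranks = dst_metadata.data_parallel_group_ranks or [dst_rank]
    dst_dp_index = dst_dp_ranks.index(dst_rank)
    return groups[dst_dp_index % len(groups)][0]
```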

core.resharding.utils.logger#

‘getLogger(…)’