core.resharding.utils#
Module Contents#
Classes#
- TransferOp: Single logical send/recv operation used in a reshard plan.
- ParameterMetadata: Metadata for a parameter (used when the param is on a different rank).
- ShardingDescriptor: Descriptor for a sharded dimension of a parameter.
- ReshardPlan: Reshard plan listing the operations for this rank.
Functions#
- _detect_expert_index_from_param_name: Extract the expert index from a parameter name for TEGroupedMLP per-expert tensors.
- assign_resolved_name_inplace: Compute a canonical resolved_name for EP per-expert parameters and set global_expert_index. For non-EP or non-per-expert params, resolved_name defaults to the original name.
- extract_param_metadata: Extract metadata from a parameter for cross-rank communication.
- select_src_metadata_balanced: Choose a representative source ParameterMetadata for a destination rank.
Data#
API#
- class core.resharding.utils.TransferOp#
Single logical send/recv operation used in a reshard plan.
- param_name: str#
None
- peer_rank: int#
None
- is_send: bool#
None
- my_slice: tuple[slice, ...]#
None
- peer_slice: tuple[slice, ...]#
None
- task_id: int | None#
None
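To make the field list above concrete, here is a hypothetical stand-in for TransferOp built from the documented fields (the real class may add validation or different defaults). It shows how `my_slice` indexes this rank's local shard while `peer_slice` indexes the peer's shard of the same parameter:

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical stand-in for core.resharding.utils.TransferOp,
# mirroring the fields documented above.
@dataclass
class TransferOp:
    param_name: str
    peer_rank: int
    is_send: bool
    my_slice: tuple          # indexes this rank's local shard
    peer_slice: tuple        # indexes the peer's shard of the same parameter
    task_id: Optional[int] = None

# One logical send: rows 0..3 of our shard land in rows 4..7 on rank 3.
op = TransferOp(
    param_name="decoder.weight",
    peer_rank=3,
    is_send=True,
    my_slice=(slice(0, 4), slice(None)),
    peer_slice=(slice(4, 8), slice(None)),
)

rows_sent = op.my_slice[0].stop - op.my_slice[0].start
print(op.param_name, "->", op.peer_rank, "rows:", rows_sent)
```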
- class core.resharding.utils.ParameterMetadata#
Metadata for a parameter (used when the param is on a different rank).
- name: str#
None
- shape: tuple[int, ...]#
None
- dtype: torch.dtype#
None
- element_size: int#
None
- is_tp: bool#
False
- partition_dim: int#
0
- partition_stride: int#
1
- is_ep: bool#
False
- num_experts: Optional[int]#
None
- owner_rank: int#
None
- tensor_parallel_group_ranks: list[int] | None#
None
- expert_parallel_group_ranks: list[int] | None#
None
- data_parallel_group_ranks: list[int] | None#
None
- pipeline_parallel_group_ranks: list[int] | None#
None
- resolved_name: Optional[str]#
None
- global_expert_index: Optional[int]#
None
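As an illustration of how these fields combine (all values below are made up, not taken from the library): for a tensor-parallel parameter, `shape` is the local shard shape, so the full size of `partition_dim` is the local size times the TP group size, and `element_size` turns the shape into a receive-buffer byte count:

```python
import math

# Illustrative values only; in practice these come from ParameterMetadata.
shape = (1024, 4096)        # local shard shape on the owner rank
is_tp = True
partition_dim = 0
tp_group = [0, 1, 2, 3]     # tensor_parallel_group_ranks
element_size = 2            # e.g. 2 bytes for a 16-bit dtype

# Full (unsharded) size along the partitioned dimension.
full_dim = shape[partition_dim] * (len(tp_group) if is_tp else 1)
# Bytes needed to receive this rank's shard.
local_bytes = math.prod(shape) * element_size
print(full_dim, local_bytes)  # 4096 8388608
```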
- class core.resharding.utils.ShardingDescriptor#
Descriptor for a sharded dimension for a parameter.
- name: str#
None
- dim: int#
None
- src_stride: int#
None
- dst_stride: int#
None
- src_dim_ranks: list[int]#
None
- dst_dim_ranks: list[int]#
None
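A minimal sketch of how a descriptor's dimension ranks can be turned into per-rank slices, assuming an even, contiguous split with stride 1 (the real utilities also handle `src_stride`/`dst_stride` greater than 1):

```python
def dim_slices(dim_size: int, dim_ranks: list) -> dict:
    """Map each rank in dim_ranks to its contiguous slice of the dimension.

    Assumes dim_size divides evenly and stride == 1; a simplification of
    what a ShardingDescriptor-based planner would compute.
    """
    shard = dim_size // len(dim_ranks)
    return {r: slice(i * shard, (i + 1) * shard) for i, r in enumerate(dim_ranks)}

print(dim_slices(8, [0, 2]))  # {0: slice(0, 4, None), 2: slice(4, 8, None)}
```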
- class core.resharding.utils.ReshardPlan#
Reshard plan - operations for this rank.
- send_ops: list[core.resharding.utils.TransferOp]#
None
- recv_ops: list[core.resharding.utils.TransferOp]#
None
- __str__()#
- core.resharding.utils._get_rank_in_group(global_rank: int, group_ranks: list[int]) -> int#
- core.resharding.utils._detect_expert_index_from_param_name(param_name: str)#
Extract expert index from parameter name for TEGroupedMLP per-expert tensors.
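A hedged sketch of such a lookup, assuming the TEGroupedMLP convention of a trailing integer suffix on per-expert tensor names (e.g. `...experts.linear_fc1.weight3`); the actual parsing rules in the module may differ:

```python
import re
from typing import Optional

def detect_expert_index(param_name: str) -> Optional[int]:
    # Assumption: per-expert tensors end in "weight<N>" or "bias<N>".
    m = re.search(r"(?:weight|bias)(\d+)$", param_name)
    return int(m.group(1)) if m else None

print(detect_expert_index("mlp.experts.linear_fc1.weight3"))  # 3
print(detect_expert_index("mlp.linear_fc1.weight"))           # None
```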
- core.resharding.utils.assign_resolved_name_inplace(...) -> None#
Compute a canonical resolved_name for EP per-expert parameters, and set global_expert_index. For non-EP or non-per-expert params, resolved_name defaults to original name.
- core.resharding.utils.extract_param_metadata(param: torch.nn.Parameter, param_name: str, owner_rank: int, pg_collection, num_experts: Optional[int] = None)#
Extract metadata from a parameter for cross-rank communication.
- core.resharding.utils.select_src_metadata_balanced(src_meta_list: list[core.resharding.utils.ParameterMetadata], dst_metadata: core.resharding.utils.ParameterMetadata, dst_rank: int)#
Choose a representative source ParameterMetadata for a destination rank.

Multiple source data-parallel (DP) groups may hold the same logical parameter. To avoid always reading from the same group, we:
- bucket src_meta_list by their DP group (tuple of ranks)
- if there is only one bucket, just return the first entry
- otherwise, map the destination rank's DP index to one of the source DP groups in a round-robin fashion, and pick the first metadata in it.
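The bucketing and round-robin steps described above can be sketched as follows, using a simplified metadata type that carries only the DP group ranks (the real ParameterMetadata has many more fields, and the real function reads the DP groups from the metadata objects):

```python
from dataclasses import dataclass

@dataclass
class Meta:
    owner_rank: int
    dp_ranks: tuple  # this source's data-parallel group, as a tuple of ranks

def select_src_balanced(src_meta_list, dst_dp_ranks, dst_rank):
    # 1) Bucket sources by their DP group.
    buckets = {}
    for m in src_meta_list:
        buckets.setdefault(m.dp_ranks, []).append(m)
    groups = list(buckets.values())
    # 2) Single bucket: just take the first entry.
    if len(groups) == 1:
        return groups[0][0]
    # 3) Round-robin: map the destination's DP index onto the source groups.
    dst_dp_index = dst_dp_ranks.index(dst_rank)
    return groups[dst_dp_index % len(groups)][0]

srcs = [Meta(0, (0, 1)), Meta(1, (0, 1)), Meta(2, (2, 3))]
picked = select_src_balanced(srcs, dst_dp_ranks=[4, 5], dst_rank=5)
print(picked.owner_rank)  # 2
```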
- core.resharding.utils.logger#
'getLogger(…)'