core.resharding.utils#
Module Contents#
Classes#
TransferOp | Single logical send/recv operation used in a reshard plan.
ParameterMetadata | Metadata for a parameter (used when param is on different rank).
ShardingDescriptor | Descriptor for a sharded dimension for a parameter.
ReshardPlan | Reshard plan - operations for this rank.
Functions#
_detect_expert_index_from_param_name | Extract expert index from parameter name for TEGroupedMLP per-expert tensors.
assign_ep_resolved_name_inplace | EP-only canonicalization for per-expert parameters.
assign_resolved_name_inplace | Set meta.resolved_name so the planner can match the same weights across models.
_build_layer_module_prefix_map | Build a mapping local_module_prefix -> global_module_prefix for PP layer modules.
_resolve_global_layer_number_in_name | Rewrite a parameter name to use global layer indices (PP-aware).
extract_param_metadata | Extract metadata from a parameter for cross-rank communication.
select_src_metadata_balanced | Choose a representative source ParameterMetadata for a destination rank.
Data#
API#
- class core.resharding.utils.TransferOp#
Single logical send/recv operation used in a reshard plan.
- param_name: str#
None
- peer_rank: int#
None
- is_send: bool#
None
- my_slice: tuple[slice, ...]#
None
- peer_slice: tuple[slice, ...]#
None
- task_id: int | None#
None
- class core.resharding.utils.ParameterMetadata#
Metadata for a parameter (used when param is on different rank).
- name: str#
None
- shape: tuple[int, ...]#
None
- dtype: torch.dtype#
None
- element_size: int#
None
- is_tp: bool#
False
- partition_dim: int#
0
- partition_stride: int#
1
- is_ep: bool#
False
- num_experts: Optional[int]#
None
- owner_rank: int#
None
- tensor_parallel_group_ranks: list[int] | None#
None
- expert_parallel_group_ranks: list[int] | None#
None
- data_parallel_group_ranks: list[int] | None#
None
- pipeline_parallel_group_ranks: list[int] | None#
None
- resolved_name: Optional[str]#
None
- global_expert_index: Optional[int]#
None
- class core.resharding.utils.ShardingDescriptor#
Descriptor for a sharded dimension for a parameter.
- name: str#
None
- dim: int#
None
- src_stride: int#
None
- dst_stride: int#
None
- src_dim_ranks: list[int]#
None
- dst_dim_ranks: list[int]#
None
- class core.resharding.utils.ReshardPlan#
Reshard plan - operations for this rank.
- send_ops: list[core.resharding.utils.TransferOp]#
None
- recv_ops: list[core.resharding.utils.TransferOp]#
None
- __str__()#
- core.resharding.utils._get_rank_in_group(global_rank: int, group_ranks: list[int]) -> int#
- core.resharding.utils._detect_expert_index_from_param_name(
- param_name: str,
- )#
Extract expert index from parameter name for TEGroupedMLP per-expert tensors.
- core.resharding.utils.assign_ep_resolved_name_inplace(
- meta: core.resharding.utils.ParameterMetadata,
- *,
- base_name: str | None = None,
- )#
EP-only canonicalization for per-expert parameters.
Under Expert Parallelism (EP), each rank owns a subset of experts with local indices (e.g., rank 1 has "weight0" locally, but it's actually global expert 4). The raw param name can't be used to match across source/destination because the same local name refers to different global experts on different ranks. This function remaps local expert indices to global indices in resolved_name and sets global_expert_index.
Effects:
Sets meta.resolved_name (defaults to base_name/meta.name for non-EP).
Sets meta.global_expert_index for per-expert parameters; otherwise leaves it as None.
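The local-to-global remapping described above can be sketched as follows. This is a minimal illustration, not the module's actual code: the helper name `to_global_expert_name` is hypothetical, and a contiguous assignment of experts to EP ranks is assumed.

```python
import re

def to_global_expert_name(param_name, ep_rank, experts_per_rank):
    """Illustrative sketch: rewrite a trailing local expert index (weightK/biasK)
    to its global index, assuming experts are assigned to EP ranks contiguously."""
    m = re.search(r"(weight|bias)(\d+)$", param_name)
    if m is None:
        return param_name, None  # not a per-expert parameter
    local_idx = int(m.group(2))
    global_idx = ep_rank * experts_per_rank + local_idx
    return param_name[: m.start(2)] + str(global_idx), global_idx

# Rank 1 with 4 experts per rank: local "weight0" is global expert 4.
name, idx = to_global_expert_name(
    "mlp.experts.linear_fc1.weight0", ep_rank=1, experts_per_rank=4
)
```

With this canonical name, "weight4" on the source side and "weight4" on the destination side refer to the same global expert regardless of each rank's local numbering.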
- core.resharding.utils.assign_resolved_name_inplace(
- meta: core.resharding.utils.ParameterMetadata,
- *,
- layer_module_prefix_map: Mapping[str, str] | None = None,
- base_name: str | None = None,
- )#
Set meta.resolved_name so the planner can match the same weights across models.
It rewrites PP layer indices to global layer indices (when layer_module_prefix_map is provided) and rewrites EP per-expert indices (weightK/biasK) to global expert indices.
- core.resharding.utils._build_layer_module_prefix_map(
- module: torch.nn.Module,
- )#
Build a mapping local_module_prefix -> global_module_prefix for PP layer modules.
Megatron assigns a global, 1-indexed layer_number to each transformer layer module at construction time (including PP/VPP/layout offsets). We convert that to the 0-indexed naming convention used in parameter names and build a map such as:
“decoder.layers.0” → “decoder.layers.16” (if layer_number == 17)
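The 1-indexed-to-0-indexed conversion can be sketched like this. It is an illustration only, not the real implementation: the function name is hypothetical, and it assumes the layers live under a `decoder.layers.<i>` prefix as in the example above.

```python
def build_layer_module_prefix_map(local_layer_numbers):
    """Illustrative sketch: map local module prefixes to global ones, given
    each local layer's 1-indexed global layer_number (as Megatron assigns
    at construction time)."""
    prefix_map = {}
    for local_idx, layer_number in enumerate(local_layer_numbers):
        # Parameter names use 0-indexed layer positions, so subtract 1.
        prefix_map[f"decoder.layers.{local_idx}"] = f"decoder.layers.{layer_number - 1}"
    return prefix_map

# A PP rank owning global layers 17 and 18 (1-indexed):
# local "decoder.layers.0" maps to global "decoder.layers.16".
mapping = build_layer_module_prefix_map([17, 18])
```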
- core.resharding.utils._resolve_global_layer_number_in_name(
- name: str,
- layer_module_prefix_map: Mapping[str, str],
- )#
Rewrite a parameter name to use global layer indices (PP-aware).
Given a parameter name like decoder.layers.0.self_attention…, this function rewrites the decoder.layers.0 prefix to the corresponding global layer index using the owning layer module’s layer_number.
Implementation:
Build a {local_prefix -> global_prefix} map once (outside the per-parameter loop).
Perform a longest-prefix match replacement so we only rewrite the module path portion.
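A minimal sketch of the longest-prefix replacement step (the function name here is hypothetical; the matching is done on full path components so that a prefix like `decoder.layers.1` never rewrites `decoder.layers.10`):

```python
def resolve_global_layer_name(name, prefix_map):
    """Illustrative sketch: rewrite only the module-path portion of a
    parameter name using the longest matching local prefix in prefix_map."""
    best = None
    for local_prefix in prefix_map:
        # Match on a full path-component boundary to avoid rewriting
        # "decoder.layers.10..." via the shorter prefix "decoder.layers.1".
        if name == local_prefix or name.startswith(local_prefix + "."):
            if best is None or len(local_prefix) > len(best):
                best = local_prefix
    if best is None:
        return name  # no PP layer prefix to rewrite
    return prefix_map[best] + name[len(best):]

pm = {"decoder.layers.0": "decoder.layers.16"}
resolve_global_layer_name("decoder.layers.0.self_attention.linear_qkv.weight", pm)
```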
- core.resharding.utils.extract_param_metadata(
- param: torch.nn.Parameter,
- param_name: str,
- owner_rank: int,
- pg_collection,
- num_experts: Optional[int] = None,
- layer_module_prefix_map: Mapping[str, str] | None = None,
- )#
Extract metadata from a parameter for cross-rank communication.
- core.resharding.utils.select_src_metadata_balanced(
- src_meta_list: list[core.resharding.utils.ParameterMetadata],
- dst_metadata: core.resharding.utils.ParameterMetadata,
- dst_rank: int,
- )#
Choose a representative source ParameterMetadata for a destination rank.
The selected metadata provides topology information (TP/EP/DP group ranks) that the LCM transfer planner uses to compute actual source ranks and slices. This function doesn't perform transfers itself - it just picks which source configuration to use as a reference for planning.
Two scenarios for EP-sharded parameters:
Non-collocated mode (same EP size, different rank numbering):
Filter by matching EP local rank to pair ranks with same expert position
Example: src ranks [0-63] and dst ranks [64-127] both with EP=8
Dst EP local 0 should use src EP local 0 as reference (same experts)
Resharding mode (different EP sizes):
Skip EP local rank filtering (sizes don’t correspond)
Example: EP=8→EP=16 means dst EP local 8 has no matching src EP local
Expert matching handled by resolved_name; LCM handles TP dimension changes
Finally, balances across data-parallel (DP) groups to distribute load:
Groups src_meta_list by DP group
Selects source DP group via round-robin: dst_rank % num_src_dp_groups
Ensures even distribution of transfer load across source DP replicas
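The round-robin balancing step can be sketched as below. This is an assumption-laden illustration, not the module's code: the helper name is hypothetical, the metadata is modeled as plain dicts, and the deterministic sort of DP-group keys is an assumed tie-breaking choice.

```python
from collections import defaultdict

def select_src_dp_group(src_metas, dst_rank):
    """Illustrative sketch of the balancing step: group candidate source
    metadata by DP group, then pick one group round-robin by dst_rank."""
    groups = defaultdict(list)
    for meta in src_metas:
        # Use the (hashable) tuple of DP group ranks as the group key.
        groups[tuple(meta["data_parallel_group_ranks"])].append(meta)
    keys = sorted(groups)  # deterministic ordering on every rank
    chosen = keys[dst_rank % len(keys)]  # round-robin over source DP replicas
    return groups[chosen]

# Two source DP replicas; consecutive dst ranks alternate between them,
# spreading transfer load evenly.
metas = [
    {"owner_rank": 0, "data_parallel_group_ranks": [0, 1]},
    {"owner_rank": 2, "data_parallel_group_ranks": [2, 3]},
]
select_src_dp_group(metas, dst_rank=64)
select_src_dp_group(metas, dst_rank=65)
```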
- core.resharding.utils.logger#
‘getLogger(…)’