core.resharding.utils#
Module Contents#
Classes#
- TransferOp: Single logical send/recv operation used in a reshard plan.
- ParameterMetadata: Metadata for a parameter (used when the param is on a different rank).
- ShardingDescriptor: Descriptor for a sharded dimension of a parameter.
- ReshardPlan: Reshard plan - operations for this rank.
Functions#
- _detect_expert_index_from_param_name: Extract expert index from parameter name for TEGroupedMLP per-expert tensors.
- assign_ep_resolved_name_inplace: EP-only canonicalization for per-expert parameters.
- assign_resolved_name_inplace: Set meta.resolved_name so the planner can match the same weights across models.
- _build_layer_module_prefix_map: Build a mapping local_module_prefix -> global_module_prefix for PP layer modules.
- _resolve_global_layer_number_in_name: Rewrite a parameter name to use global layer indices (PP-aware).
- extract_param_metadata: Extract metadata from a parameter for cross-rank communication.
- select_src_metadata_balanced: Choose a representative source ParameterMetadata for a destination rank.
Data#
API#
- class core.resharding.utils.TransferOp#
Single logical send/recv operation used in a reshard plan.
- param_name: str#
None
- peer_rank: int#
None
- is_send: bool#
None
- my_slice: tuple[slice, ...]#
None
- peer_slice: tuple[slice, ...]#
None
- task_id: int | None#
None
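For orientation, a minimal sketch of what a single op could look like, assuming TransferOp is a plain dataclass constructed by keyword; the field values below are purely illustrative, not taken from a real plan:

```python
import torch  # noqa: F401  (dtype context; not needed for the op itself)
from core.resharding.utils import TransferOp

# Hypothetical send op: this rank ships columns 0..2048 of its local shard of a
# linear weight to rank 5, where those columns occupy offset 2048..4096 of the
# peer's (larger) shard. Slices are expressed per tensor dimension.
op = TransferOp(
    param_name="decoder.layers.16.mlp.linear_fc1.weight",
    peer_rank=5,
    is_send=True,
    my_slice=(slice(None), slice(0, 2048)),       # region of the local shard to send
    peer_slice=(slice(None), slice(2048, 4096)),  # where it lands in the peer's shard
    task_id=None,
)
```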
- class core.resharding.utils.ParameterMetadata#
Metadata for a parameter (used when the param is on a different rank).
- name: str#
None
- shape: tuple[int, ...]#
None
- dtype: torch.dtype#
None
- element_size: int#
None
- is_tp: bool#
False
- partition_dim: int#
0
- partition_stride: int#
1
- partition_sizes: list[int] | None#
None
- is_ep: bool#
False
- num_experts: Optional[int]#
None
- owner_rank: int#
None
- tensor_parallel_group_ranks: list[int] | None#
None
- expert_parallel_group_ranks: list[int] | None#
None
- data_parallel_group_ranks: list[int] | None#
None
- pipeline_parallel_group_ranks: list[int] | None#
None
- resolved_name: Optional[str]#
None
- global_expert_index: Optional[int]#
None
- class core.resharding.utils.ShardingDescriptor#
Descriptor for a sharded dimension of a parameter.
- name: str#
None
- dim: int#
None
- src_stride: int#
None
- dst_stride: int#
None
- src_dim_ranks: list[int]#
None
- dst_dim_ranks: list[int]#
None
- class core.resharding.utils.ReshardPlan#
Reshard plan - operations for this rank.
- send_ops: list[core.resharding.utils.TransferOp]#
None
- recv_ops: list[core.resharding.utils.TransferOp]#
None
- __str__()#
- core.resharding.utils._get_rank_in_group(global_rank: int, group_ranks: list[int]) int#
- core.resharding.utils._detect_expert_index_from_param_name(param_name: str)#
Extract expert index from parameter name for TEGroupedMLP per-expert tensors.
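Per-expert TEGroupedMLP tensors carry a trailing local expert index in their name (the weightK/biasK convention mentioned under assign_resolved_name_inplace). A minimal sketch of such detection, assuming that naming convention; detect_expert_index and the regex below are illustrative, not the module's implementation:

```python
import re

# Per-expert tensors end in "weight<k>" or "bias<k>", where k is the local expert index.
_EXPERT_SUFFIX = re.compile(r"\.(weight|bias)(\d+)$")

def detect_expert_index(param_name: str) -> int | None:
    """Return the local expert index encoded in the name, or None if absent."""
    match = _EXPERT_SUFFIX.search(param_name)
    return int(match.group(2)) if match else None

# detect_expert_index("decoder.layers.0.mlp.experts.linear_fc1.weight3") -> 3
# detect_expert_index("decoder.layers.0.self_attention.linear_qkv.weight") -> None
```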
- core.resharding.utils.assign_ep_resolved_name_inplace(meta: core.resharding.utils.ParameterMetadata, *, base_name: str | None = None)#
EP-only canonicalization for per-expert parameters.
Under Expert Parallelism (EP), each rank owns a subset of experts with local indices (e.g., rank 1 has “weight0” locally, but it’s actually global expert 4). The raw param name can’t be used to match across source/destination because the same local name refers to different global experts on different ranks. This function remaps local expert indices to global indices in resolved_name and sets global_expert_index.
Effects:
- Sets meta.resolved_name (defaults to base_name/meta.name for non-EP parameters).
- Sets meta.global_expert_index for per-expert parameters; otherwise leaves it as None.
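A minimal sketch of the local-to-global remapping idea, assuming experts are split contiguously across EP ranks (num_experts // ep_size per rank); the helper name and variables are illustrative, not the module's own:

```python
def to_global_expert_index(local_index: int, ep_local_rank: int,
                           num_experts: int, ep_size: int) -> int:
    """Map a rank-local expert index to its global expert index (contiguous split assumed)."""
    experts_per_rank = num_experts // ep_size
    return ep_local_rank * experts_per_rank + local_index

# With num_experts=8 and EP=2, EP local rank 1 owns global experts 4..7, so its local
# "weight0" corresponds to global expert 4:
#   to_global_expert_index(0, ep_local_rank=1, num_experts=8, ep_size=2) == 4
# The resolved name then swaps the local suffix for the global one, e.g.
#   "...experts.linear_fc1.weight0" -> "...experts.linear_fc1.weight4".
```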
- core.resharding.utils.assign_resolved_name_inplace(meta: core.resharding.utils.ParameterMetadata, *, layer_module_prefix_map: Mapping[str, str] | None = None, base_name: str | None = None)#
Set meta.resolved_name so the planner can match the same weights across models.
It rewrites PP layer indices to global layer indices (when layer_module_prefix_map is provided) and rewrites EP per-expert indices (weightK/biasK) to global expert indices.
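A hedged usage sketch combining both rewrites, treating ParameterMetadata as a keyword-constructible dataclass; the concrete shapes, ranks, and group lists are illustrative only, and the expected outputs hold only under the contiguous-expert-split assumption above:

```python
import torch
from core.resharding.utils import ParameterMetadata, assign_resolved_name_inplace

# Illustrative metadata: a per-expert tensor owned by EP local rank 1 of 2 (ranks [0, 8]),
# with 8 experts total, on a pipeline stage whose local layer 0 is global layer 16.
meta = ParameterMetadata(
    name="decoder.layers.0.mlp.experts.linear_fc1.weight0",
    shape=(4096, 1024),
    dtype=torch.bfloat16,
    element_size=2,
    owner_rank=8,
    is_ep=True,
    num_experts=8,
    expert_parallel_group_ranks=[0, 8],
)

assign_resolved_name_inplace(
    meta,
    layer_module_prefix_map={"decoder.layers.0": "decoder.layers.16"},
)
# Expected under those assumptions: the PP-local layer 0 becomes global layer 16 and
# local expert 0 on EP local rank 1 becomes global expert 4, i.e.
#   meta.resolved_name       == "decoder.layers.16.mlp.experts.linear_fc1.weight4"
#   meta.global_expert_index == 4
```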
- core.resharding.utils._build_layer_module_prefix_map(module: torch.nn.Module)#
Build a mapping local_module_prefix -> global_module_prefix for PP layer modules.
Megatron assigns a global, 1-indexed layer_number to each transformer layer module at construction time (including PP/VPP/layout offsets). We convert that to the 0-indexed naming convention used in parameter names and build a map such as:
“decoder.layers.0” → “decoder.layers.16” (if layer_number == 17)
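A sketch of how such a map can be derived, assuming each transformer layer module exposes Megatron's 1-indexed layer_number attribute and that parameter names use 0-indexed "<prefix>.layers.<i>" paths; build_layer_module_prefix_map below is illustrative, not the module's implementation:

```python
import re
import torch.nn as nn

def build_layer_module_prefix_map(module: nn.Module) -> dict[str, str]:
    """Map each local '<prefix>.layers.<i>' module path to its global-index equivalent."""
    prefix_map: dict[str, str] = {}
    for local_prefix, submodule in module.named_modules():
        layer_number = getattr(submodule, "layer_number", None)
        if layer_number is None or not re.search(r"\.layers\.\d+$", local_prefix):
            continue
        # Convert the global, 1-indexed layer_number to the 0-indexed naming convention
        # used in parameter names.
        global_prefix = re.sub(r"\.layers\.\d+$", f".layers.{layer_number - 1}", local_prefix)
        prefix_map[local_prefix] = global_prefix
    return prefix_map

# A stage whose first local layer has layer_number == 17 would yield
# {"decoder.layers.0": "decoder.layers.16", ...}
```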
- core.resharding.utils._resolve_global_layer_number_in_name(name: str, layer_module_prefix_map: Mapping[str, str])#
Rewrite a parameter name to use global layer indices (PP-aware).
Given a parameter name like decoder.layers.0.self_attention…, this function rewrites the decoder.layers.0 prefix to the corresponding global layer index using the owning layer module’s layer_number.
Implementation:
- Build a {local_prefix -> global_prefix} map once (outside the per-parameter loop).
- Perform a longest-prefix match replacement so we only rewrite the module path portion (see the sketch below).
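A sketch of that longest-prefix-match rewrite; the function name is illustrative, and the map would come from _build_layer_module_prefix_map (or the sketch above):

```python
from typing import Mapping

def resolve_global_layer_number_in_name(name: str,
                                        layer_module_prefix_map: Mapping[str, str]) -> str:
    """Replace the longest matching local module-path prefix with its global equivalent."""
    best_local = None
    for local_prefix in layer_module_prefix_map:
        if name == local_prefix or name.startswith(local_prefix + "."):
            if best_local is None or len(local_prefix) > len(best_local):
                best_local = local_prefix
    if best_local is None:
        return name  # not a PP layer parameter; leave unchanged
    return layer_module_prefix_map[best_local] + name[len(best_local):]

# resolve_global_layer_number_in_name(
#     "decoder.layers.0.self_attention.linear_qkv.weight",
#     {"decoder.layers.0": "decoder.layers.16"},
# ) -> "decoder.layers.16.self_attention.linear_qkv.weight"
```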
- core.resharding.utils.extract_param_metadata(param: torch.nn.Parameter, param_name: str, owner_rank: int, pg_collection, num_experts: Optional[int] = None, layer_module_prefix_map: Mapping[str, str] | None = None, rank_offset: int = 0, _rank_list_cache: dict | None = None)#
Extract metadata from a parameter for cross-rank communication.
- Parameters:
_rank_list_cache – Optional dict used to deduplicate rank lists so that params sharing the same process group reuse one object. This dramatically shrinks pickle size when metadata is gathered across many ranks (pickle uses backreferences for same-id() objects, avoiding re-serialization of identical group lists).
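An illustration of why the cache helps, using a hypothetical dedup_rank_list helper (not part of this module): identical rank lists collapse to one shared object, so pickling many metadata entries emits memo backreferences instead of repeating the list.

```python
import pickle

def dedup_rank_list(ranks: list[int], cache: dict[tuple[int, ...], list[int]]) -> list[int]:
    """Return one shared list object per distinct rank tuple."""
    return cache.setdefault(tuple(ranks), ranks)

cache: dict[tuple[int, ...], list[int]] = {}
shared = [dedup_rank_list(list(range(64)), cache) for _ in range(1000)]  # 1000 refs to one list
copies = [list(range(64)) for _ in range(1000)]                          # 1000 distinct lists

# The shared version pickles far smaller: identical id() objects become memo backreferences.
print(len(pickle.dumps(shared)), "<", len(pickle.dumps(copies)))
```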
- core.resharding.utils.select_src_metadata_balanced(src_meta_list: list[core.resharding.utils.ParameterMetadata], dst_metadata: core.resharding.utils.ParameterMetadata, dst_rank: int)#
Choose a representative source ParameterMetadata for a destination rank.
The selected metadata provides topology information (TP/EP/DP group ranks) that the LCM transfer planner uses to compute actual source ranks and slices. This function doesn’t perform transfers itself - it just picks which source configuration to use as a reference for planning.
Two scenarios for EP-sharded parameters:
- Non-collocated mode (same EP size, different rank numbering):
  - Filter by matching EP local rank to pair ranks with the same expert position.
  - Example: src ranks [0-63] and dst ranks [64-127], both with EP=8.
  - Dst EP local 0 should use src EP local 0 as reference (same experts).
- Resharding mode (different EP sizes):
  - Skip EP local rank filtering (sizes don’t correspond).
  - Example: EP=8→EP=16 means dst EP local 8 has no matching src EP local.
  - Expert matching is handled by resolved_name; the LCM planner handles TP dimension changes.
Finally, the function balances across data-parallel (DP) groups to distribute load (see the sketch below):
- Groups src_meta_list by DP group.
- Selects the source DP group via round-robin: dst_rank % num_src_dp_groups.
- Ensures even distribution of transfer load across source DP replicas.
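A sketch of the balancing step only, using the metadata fields documented above (data_parallel_group_ranks, owner_rank); pick_balanced_source is illustrative and omits the EP-aware filtering described earlier:

```python
from collections import defaultdict

def pick_balanced_source(src_meta_list, dst_rank: int):
    """Group candidate sources by DP group and pick a group round-robin by dst_rank."""
    by_dp_group = defaultdict(list)
    for meta in src_meta_list:
        dp_key = tuple(meta.data_parallel_group_ranks or [meta.owner_rank])
        by_dp_group[dp_key].append(meta)

    dp_groups = sorted(by_dp_group)                       # deterministic ordering across ranks
    chosen = dp_groups[dst_rank % len(dp_groups)]         # round-robin over source DP groups
    return by_dp_group[chosen][0]
```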
- core.resharding.utils.logger#
‘getLogger(…)’