core.resharding.utils#
Module Contents#
Classes#
TransferOp | Single logical send/recv operation used in a reshard plan.
ParameterMetadata | Metadata for a parameter (used when param is on different rank).
ShardingDescriptor | Descriptor for a sharded dimension for a parameter.
ReshardPlan | Reshard plan - operations for this rank.
Functions#
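_detect_expert_index_from_param_name | Extract expert index from parameter name for TEGroupedMLP per-expert tensors.
assign_ep_resolved_name_inplace | EP-only canonicalization for per-expert parameters.
assign_resolved_name_inplace | Set meta.resolved_name so the planner can match the same weights across models.
named_persistent_buffers | Yield (full_name, parent_module, buf_name, tensor) for every persistent buffer in module.
named_refit_tensors | Yield (name, tensor) pairs for every parameter and persistent buffer.
get_refit_tensor_dict | Return the cached {name: tensor} dict for module, building it if needed.
invalidate_refit_tensor_cache | Drop the cached refit tensor dict so the next call rebuilds it.
_build_layer_module_prefix_map | Build a mapping local_module_prefix -> global_module_prefix for PP layer modules.
_resolve_global_layer_number_in_name | Rewrite a parameter name to use global layer indices (PP-aware).
extract_param_metadata | Extract metadata from a parameter for cross-rank communication.
_filter_by_ep_local_rank | In non-collocated mode with matching EP size, restrict candidates to the source rank holding the same global experts as dst_rank.
_round_robin_dp | Round-robin across source DP groups so transfer load spreads evenly.
select_src_metadata_balanced | Choose a representative source ParameterMetadata for a destination rank.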
Data#
API#
- class core.resharding.utils.TransferOp#
Single logical send/recv operation used in a reshard plan.
- param_name: str#
None
- peer_rank: int#
None
- is_send: bool#
None
- my_slice: tuple[slice, ...]#
None
- peer_slice: tuple[slice, ...]#
None
- task_id: int | None#
None
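To make the field roles concrete, here is a minimal sketch built on a stand-in dataclass that mirrors the fields listed above; it is not the library class. It assumes that a send on one rank is paired with a recv on the peer whose slices are swapped. The helper `mirror`, the parameter name, the ranks, and the slice bounds are all hypothetical.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class TransferOpSketch:
    """Stand-in mirroring the documented TransferOp fields (not the real class)."""
    param_name: str
    peer_rank: int
    is_send: bool
    my_slice: Tuple[slice, ...]
    peer_slice: Tuple[slice, ...]
    task_id: Optional[int] = None


def mirror(op: TransferOpSketch, my_rank: int) -> TransferOpSketch:
    """Build the op the peer would hold: direction flipped, slices swapped."""
    return TransferOpSketch(
        param_name=op.param_name,
        peer_rank=my_rank,
        is_send=not op.is_send,
        my_slice=op.peer_slice,
        peer_slice=op.my_slice,
        task_id=op.task_id,
    )


# Hypothetical: rank 0 sends rows [0, 512) into rows [512, 1024) on rank 3.
send = TransferOpSketch(
    "decoder.layers.0.mlp.weight", 3, True,
    (slice(0, 512),), (slice(512, 1024),), task_id=7,
)
recv = mirror(send, my_rank=0)
```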
- class core.resharding.utils.ParameterMetadata#
Metadata for a parameter (used when param is on different rank).
- name: str#
None
- shape: tuple[int, ...]#
None
- dtype: torch.dtype#
None
- element_size: int#
None
- is_tp: bool#
False
- partition_dim: int#
0
- partition_stride: int#
1
- partition_sizes: list[int] | None#
None
- is_ep: bool#
False
- num_experts: Optional[int]#
None
- owner_rank: int#
None
- tensor_parallel_group_ranks: list[int] | None#
None
- expert_parallel_group_ranks: list[int] | None#
None
- data_parallel_group_ranks: list[int] | None#
None
- pipeline_parallel_group_ranks: list[int] | None#
None
- resolved_name: Optional[str]#
None
- global_expert_index: Optional[int]#
None
- class core.resharding.utils.ShardingDescriptor#
Descriptor for a sharded dimension for a parameter.
- name: str#
None
- dim: int#
None
- src_stride: int#
None
- dst_stride: int#
None
- src_dim_ranks: list[int]#
None
- dst_dim_ranks: list[int]#
None
- class core.resharding.utils.ReshardPlan#
Reshard plan - operations for this rank.
- send_ops: list[core.resharding.utils.TransferOp]#
None
- recv_ops: list[core.resharding.utils.TransferOp]#
None
- transform: Optional[core.resharding.transforms.ReshardTransform]#
None
- buffer_dtypes: Optional[dict[str, torch.dtype]]#
None
- __str__()#
- core.resharding.utils._get_rank_in_group(global_rank: int, group_ranks: list[int]) int#
- core.resharding.utils._EXPERT_PARAM_RE#
‘compile(…)’
- core.resharding.utils._detect_expert_index_from_param_name(
- param_name: str,
Extract expert index from parameter name for TEGroupedMLP per-expert tensors.
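The actual pattern behind _EXPERT_PARAM_RE is not shown on this page. As a rough illustration only, the sketch below assumes per-expert tensors end in a `weight<K>` or `bias<K>` component, which matches the weightK/biasK naming mentioned elsewhere in this module. `EXPERT_SUFFIX_RE` and `detect_expert_index` are hypothetical names.

```python
import re
from typing import Optional

# Hypothetical pattern: a final name component of the form weight<K> or bias<K>.
EXPERT_SUFFIX_RE = re.compile(r"(?:^|\.)(weight|bias)(\d+)$")


def detect_expert_index(param_name: str) -> Optional[int]:
    """Return the trailing expert index, or None for non-per-expert names."""
    match = EXPERT_SUFFIX_RE.search(param_name)
    return int(match.group(2)) if match else None
```

Names without a numeric suffix, such as a plain `weight`, return None under this assumption.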
- core.resharding.utils.assign_ep_resolved_name_inplace(
- meta: core.resharding.utils.ParameterMetadata,
- *,
- base_name: str | None = None,
EP-only canonicalization for per-expert parameters.
Under Expert Parallelism (EP), each rank owns a subset of experts with local indices (e.g., rank 1 has “weight0” locally, but it is actually global expert 4). The raw parameter name cannot be used to match across source and destination because the same local name refers to different global experts on different ranks. This function remaps local expert indices to global indices in resolved_name and sets global_expert_index.
Effects:
Sets meta.resolved_name (defaults to base_name/meta.name for non-EP).
Sets meta.global_expert_index for per-expert parameters; otherwise leaves it as None.
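The local-to-global remapping can be sketched as follows. This is an illustration under one explicit assumption: experts are split into contiguous, equally sized blocks per EP rank, so the global index is `ep_rank * num_local_experts + local_index`. That is consistent with the “rank 1, weight0, global expert 4” example when each rank holds four experts, but the real function may derive the offset differently. The function name and arguments are hypothetical.

```python
import re
from typing import Optional, Tuple

_SUFFIX = re.compile(r"(?:^|\.)(weight|bias)(\d+)$")


def to_global_expert_name(
    name: str, ep_rank: int, num_experts: int, ep_size: int
) -> Tuple[str, Optional[int]]:
    """Rewrite a local per-expert name to its global-index form.

    Assumes contiguous, equal expert blocks per EP rank (an assumption,
    not confirmed by the documented API).
    """
    match = _SUFFIX.search(name)
    if match is None:
        # Not a per-expert parameter: name unchanged, no global index.
        return name, None
    num_local = num_experts // ep_size
    global_idx = ep_rank * num_local + int(match.group(2))
    prefix = name[: match.start(1)]
    return f"{prefix}{match.group(1)}{global_idx}", global_idx
```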
- core.resharding.utils.assign_resolved_name_inplace(
- meta: core.resharding.utils.ParameterMetadata,
- *,
- layer_module_prefix_map: Mapping[str, str] | None = None,
- base_name: str | None = None,
Set meta.resolved_name so the planner can match the same weights across models.
It rewrites PP layer indices to global layer indices (when layer_module_prefix_map is provided) and rewrites EP per-expert indices (weightK/biasK) to global expert indices.
- core.resharding.utils.named_persistent_buffers(module: torch.nn.Module)#
Yield (full_name, parent_module, buf_name, tensor) for every persistent buffer in module. Skips buffers listed in _non_persistent_buffers_set.
Persistent buffers (those saved in state_dict) carry training state that must travel with the weights during refit/resharding, e.g. the MoE router’s expert_bias, which is updated each step by aux-loss-free load balancing. Non-persistent buffers are excluded since they hold ephemeral state (e.g. accumulators reset at the next train step).
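A minimal sketch of the enumeration described above, assuming it walks `named_modules()` and consults each submodule's `_non_persistent_buffers_set` (a private attribute of `torch.nn.Module`). The function name `iter_persistent_buffers` and the `Router` toy module are hypothetical; the real implementation may differ in details.

```python
import torch
from torch import nn


def iter_persistent_buffers(module: nn.Module):
    """Yield (full_name, parent_module, buf_name, tensor) for persistent buffers."""
    for prefix, sub in module.named_modules():
        non_persistent = sub._non_persistent_buffers_set
        for buf_name, buf in sub._buffers.items():
            if buf is None or buf_name in non_persistent:
                continue  # skip unset and non-persistent buffers
            full_name = f"{prefix}.{buf_name}" if prefix else buf_name
            yield full_name, sub, buf_name, buf


class Router(nn.Module):
    """Toy module: one persistent and one non-persistent buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer("expert_bias", torch.zeros(4))
        self.register_buffer("scratch", torch.zeros(4), persistent=False)


model = nn.Sequential(Router())
names = [full_name for full_name, *_ in iter_persistent_buffers(model)]
```

Only `expert_bias` is reported here; `scratch` is registered with `persistent=False` and is therefore skipped.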
- core.resharding.utils.named_refit_tensors(module: torch.nn.Module)#
Yield (name, tensor) pairs for every parameter and persistent buffer.
Used by the refit planner and executor to enumerate which tensors should travel during resharding. Persistent buffers are included alongside parameters because they may carry training state (see named_persistent_buffers).
- core.resharding.utils._REFIT_TENSOR_CACHE_ATTR#
‘_refit_tensor_cache’
- core.resharding.utils.get_refit_tensor_dict(
- module: torch.nn.Module,
Return the cached {name: tensor} dict for module, building it if needed.
Walking named_modules() takes hundreds of milliseconds for multi-billion-parameter models, and the parameter/persistent-buffer set is stable across refits, so the dict is cached on the module itself. invalidate_refit_tensor_cache must be called whenever a persistent buffer is replaced (e.g. by _harmonize_buffer_dtypes) so the cache picks up the new tensor.
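The cache-on-the-object pattern, and why invalidation matters, can be sketched without PyTorch. The attribute name matches the documented _REFIT_TENSOR_CACHE_ATTR value; everything else (`FakeModule`, `get_cached_dict`, `invalidate`, the `build` callback) is a hypothetical stand-in.

```python
from typing import Callable, Dict

CACHE_ATTR = "_refit_tensor_cache"  # same string as the documented constant


def get_cached_dict(obj, build: Callable[[object], Dict[str, object]]):
    """Return the dict cached on obj, building and storing it on first use."""
    cache = getattr(obj, CACHE_ATTR, None)
    if cache is None:
        cache = build(obj)
        setattr(obj, CACHE_ATTR, cache)
    return cache


def invalidate(obj) -> None:
    """Drop the cached dict so the next lookup rebuilds it."""
    if hasattr(obj, CACHE_ATTR):
        delattr(obj, CACHE_ATTR)


class FakeModule:
    """Stand-in for a module; lists play the role of tensors."""

    def __init__(self):
        self.tensors = {"w": [1.0]}


build_calls = []


def build(module):
    build_calls.append(1)  # count expensive rebuilds
    return dict(module.tensors)


m = FakeModule()
first = get_cached_dict(m, build)
second = get_cached_dict(m, build)          # served from the cache
m.tensors["w"] = [2.0]                      # a "buffer" is replaced
stale = get_cached_dict(m, build)["w"]      # cache still holds the old object
invalidate(m)
fresh = get_cached_dict(m, build)["w"]      # rebuilt after invalidation
```

The stale read is the failure mode the invalidation call guards against.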
- core.resharding.utils.invalidate_refit_tensor_cache(module: torch.nn.Module) None#
Drop the cached refit tensor dict so the next call rebuilds it.
- core.resharding.utils._build_layer_module_prefix_map(
- module: torch.nn.Module,
Build a mapping local_module_prefix -> global_module_prefix for PP layer modules.
Megatron assigns a global, 1-indexed layer_number to each transformer layer module at construction time (including PP/VPP/layout offsets). We convert that to the 0-indexed naming convention used in parameter names and build a map such as:
“decoder.layers.0” → “decoder.layers.16” (if layer_number == 17)
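The index conversion can be illustrated with a small sketch. It assumes the local layers live under a single container prefix and that their 1-indexed layer_number values are already known; `build_prefix_map` and its arguments are hypothetical and do not reflect how the real helper discovers layer modules.

```python
from typing import Dict, List


def build_prefix_map(
    layer_numbers: List[int], container: str = "decoder.layers"
) -> Dict[str, str]:
    """Map local 0-indexed prefixes to global 0-indexed prefixes.

    layer_numbers holds the global, 1-indexed layer_number of each local
    layer, in local order.
    """
    return {
        f"{container}.{local_idx}": f"{container}.{layer_number - 1}"
        for local_idx, layer_number in enumerate(layer_numbers)
    }


# A pipeline stage that owns global layers 17 and 18 (1-indexed).
prefix_map = build_prefix_map([17, 18])
```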
- core.resharding.utils._resolve_global_layer_number_in_name(
- name: str,
- layer_module_prefix_map: Mapping[str, str],
Rewrite a parameter name to use global layer indices (PP-aware).
Given a parameter name like decoder.layers.0.self_attention…, this function rewrites the decoder.layers.0 prefix to the corresponding global layer index using the owning layer module’s layer_number.
Implementation:
Build a {local_prefix -> global_prefix} map once (outside the per-parameter loop).
Perform a longest-prefix match replacement so we only rewrite the module path portion.
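A longest-prefix replacement of this kind might look like the sketch below. The boundary check (the prefix must be followed by a dot or end of string) is an assumption added here so that `decoder.layers.1` does not match `decoder.layers.10`; `rewrite_layer_prefix` is a hypothetical name.

```python
from typing import Mapping, Optional


def rewrite_layer_prefix(name: str, prefix_map: Mapping[str, str]) -> str:
    """Replace the longest matching local module prefix with its global form."""
    best: Optional[str] = None
    for local in prefix_map:
        # Match only on a module-path boundary.
        if name == local or name.startswith(local + "."):
            if best is None or len(local) > len(best):
                best = local
    if best is None:
        return name  # not under any mapped layer module
    return prefix_map[best] + name[len(best):]


example_map = {
    "decoder.layers.1": "decoder.layers.17",
    "decoder.layers.10": "decoder.layers.26",
}
```

Only the module path portion changes; the rest of the parameter name is carried over unchanged.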
- core.resharding.utils.extract_param_metadata(
- param: torch.nn.Parameter,
- param_name: str,
- owner_rank: int,
- pg_collection,
- num_experts: Optional[int] = None,
- layer_module_prefix_map: Mapping[str, str] | None = None,
- rank_offset: int = 0,
- _rank_list_cache: dict | None = None,
Extract metadata from a parameter for cross-rank communication.
- Parameters:
_rank_list_cache – Optional dict used to deduplicate rank lists so that params sharing the same process group reuse one object. This dramatically shrinks pickle size when metadata is gathered across many ranks (pickle uses backreferences for same-id() objects, avoiding re-serialization of identical group lists).
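The effect of sharing one list object can be checked directly with the standard library. In this sketch, `intern_rank_list` is a hypothetical helper that plays the role of the cache; the group size and entry count are arbitrary.

```python
import pickle
from typing import Dict, List, Tuple


def intern_rank_list(
    ranks: List[int], cache: Dict[Tuple[int, ...], List[int]]
) -> List[int]:
    """Return one shared list object per distinct rank list."""
    return cache.setdefault(tuple(ranks), list(ranks))


group = list(range(64))
cache: Dict[Tuple[int, ...], List[int]] = {}

# 1000 metadata entries that all reference the same process group.
shared = [intern_rank_list(group, cache) for _ in range(1000)]
distinct = [list(group) for _ in range(1000)]  # equal contents, separate objects

shared_size = len(pickle.dumps(shared))
distinct_size = len(pickle.dumps(distinct))
```

Because pickle memoizes objects by identity, `shared_size` comes out far smaller than `distinct_size`, while both payloads unpickle to equal data.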
- core.resharding.utils._filter_by_ep_local_rank(
- src_meta_list: list[core.resharding.utils.ParameterMetadata],
- dst_metadata: core.resharding.utils.ParameterMetadata,
- dst_rank: int,
In non-collocated mode with matching EP size, restrict candidates to the source rank holding the same global experts as dst_rank.
When EP sizes differ (resharding), expert matching is handled via resolved_name and no filter is applied here.
Why the size check matters:
Same size (EP=8→EP=8): local ranks 0-7 exist in both src and dst → filter ensures dst EP local 0 uses src EP local 0 (same global experts).
Different size (EP=8→EP=16): dst EP local 8 has no corresponding src EP local → skip filter; expert reassignment is handled by resolved_name matching, and the LCM/TP planner handles any TP dimension changes.
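The size check and local-rank match can be sketched as follows. This is a simplified stand-in: `Meta` carries only two of the documented ParameterMetadata fields, the EP local rank is taken as the position of a rank within its EP group, and falling back to the unfiltered list when nothing matches is an assumption of this sketch.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Meta:
    """Reduced stand-in for ParameterMetadata (hypothetical)."""
    owner_rank: int
    expert_parallel_group_ranks: Optional[List[int]] = None


def filter_by_ep_local_rank(
    src_list: List[Meta], dst_meta: Meta, dst_rank: int
) -> List[Meta]:
    dst_group = dst_meta.expert_parallel_group_ranks
    if not dst_group or dst_rank not in dst_group:
        return src_list
    dst_local = dst_group.index(dst_rank)
    matching = [
        s for s in src_list
        if s.expert_parallel_group_ranks
        and len(s.expert_parallel_group_ranks) == len(dst_group)  # same EP size
        and s.owner_rank in s.expert_parallel_group_ranks
        and s.expert_parallel_group_ranks.index(s.owner_rank) == dst_local
    ]
    # Different EP size (or no match): leave the candidate list untouched.
    return matching or src_list


srcs = [Meta(0, [0, 1]), Meta(1, [0, 1])]
same_size = filter_by_ep_local_rank(srcs, Meta(9, [8, 9]), dst_rank=9)
diff_size = filter_by_ep_local_rank(srcs, Meta(10, [8, 9, 10, 11]), dst_rank=10)
```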
- core.resharding.utils._round_robin_dp(
- src_meta_list: list[core.resharding.utils.ParameterMetadata],
- dst_rank: int,
Round-robin across source DP groups so transfer load spreads evenly.
Each DP group holds a complete copy of the model, so we can read from any DP group; choosing via dst_rank % num_src_dp_groups ensures even distribution even when the destination has a different DP configuration. E.g. with 4 src DP groups and 128 dst ranks, each src DP group is selected by 32 dst ranks (128/4=32). Within the chosen DP group we further cycle across available metadata entries to balance load across TP groups in the DP replica.
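The 4-group, 128-rank example can be reproduced with a short sketch. Source entries are modeled as plain dicts with a hypothetical `dp_index` key identifying the DP replica; how the real function derives that grouping from ParameterMetadata, and how it cycles within a group, are assumptions here.

```python
from collections import Counter, defaultdict
from typing import Dict, List


def round_robin_dp(src_list: List[Dict[str, int]], dst_rank: int) -> Dict[str, int]:
    """Pick a source entry: DP group by dst_rank modulo, then cycle within it."""
    groups = defaultdict(list)
    for meta in src_list:
        groups[meta["dp_index"]].append(meta)
    keys = sorted(groups)
    chosen_group = groups[keys[dst_rank % len(keys)]]
    # Secondary cycle so entries inside the replica also share the load.
    return chosen_group[(dst_rank // len(keys)) % len(chosen_group)]


# 4 source DP replicas, each with 2 entries (e.g. two TP groups).
sources = [
    {"owner_rank": dp * 2 + tp, "dp_index": dp}
    for dp in range(4)
    for tp in range(2)
]
picks = [round_robin_dp(sources, dst) for dst in range(128)]
per_group = Counter(p["dp_index"] for p in picks)
per_owner = Counter(p["owner_rank"] for p in picks)
```

Each of the 4 replicas is chosen by 32 destination ranks, and each of the 8 source entries by 16.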
- core.resharding.utils.select_src_metadata_balanced(
- src_meta_list: list[core.resharding.utils.ParameterMetadata],
- dst_metadata: core.resharding.utils.ParameterMetadata,
- dst_rank: int,
Choose a representative source ParameterMetadata for a destination rank.
The selected metadata supplies topology (TP/EP/DP group ranks) to the LCM planner. Selection prefers a local copy when dst_rank itself owns a source replica, then round-robins across source DP groups to balance load. A local copy is essentially free (tensor.copy_() on same GPU), while any remote transfer incurs significant overhead even within the same node.
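The local-first preference can be sketched in a few lines. Entries are plain dicts with an `owner_rank` key, and the fallback shown is a simplified index-modulo stand-in for the DP round-robin rather than the library's own logic; all names are hypothetical.

```python
from typing import Callable, Dict, List

Meta = Dict[str, int]


def select_src_metadata(
    src_list: List[Meta],
    dst_rank: int,
    fallback: Callable[[List[Meta], int], Meta],
) -> Meta:
    """Prefer a source owned by dst_rank itself; otherwise defer to fallback."""
    for meta in src_list:
        if meta["owner_rank"] == dst_rank:
            return meta  # local copy: no remote transfer needed
    return fallback(src_list, dst_rank)


def simple_round_robin(metas: List[Meta], dst_rank: int) -> Meta:
    """Simplified stand-in for the DP round-robin step."""
    return metas[dst_rank % len(metas)]


sources = [{"owner_rank": r} for r in range(4)]
local_pick = select_src_metadata(sources, 2, simple_round_robin)
remote_pick = select_src_metadata(sources, 5, simple_round_robin)
```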