core.resharding.utils

Module Contents

Classes

TransferOp

Single logical send/recv operation used in a reshard plan.

ParameterMetadata

Metadata for a parameter (used when the parameter lives on a different rank).

ShardingDescriptor

Descriptor for a sharded dimension of a parameter.

ReshardPlan

Reshard plan: the send and receive operations for this rank.

Functions

_get_rank_in_group

_detect_expert_index_from_param_name

Extract expert index from parameter name for TEGroupedMLP per-expert tensors.

assign_ep_resolved_name_inplace

EP-only canonicalization for per-expert parameters.

assign_resolved_name_inplace

Set meta.resolved_name so the planner can match the same weights across models.

named_persistent_buffers

Yield (full_name, parent_module, buf_name, tensor) for every persistent buffer in module, skipping buffers listed in _non_persistent_buffers_set.

named_refit_tensors

Yield (name, tensor) pairs for every parameter and persistent buffer.

get_refit_tensor_dict

Return the cached {name: tensor} dict for module, building it if needed.

invalidate_refit_tensor_cache

Drop the cached refit tensor dict so the next call rebuilds it.

_build_layer_module_prefix_map

Build a mapping local_module_prefix -> global_module_prefix for PP layer modules.

_resolve_global_layer_number_in_name

Rewrite a parameter name to use global layer indices (PP-aware).

extract_param_metadata

Extract metadata from a parameter for cross-rank communication.

_filter_by_ep_local_rank

In non-collocated mode with matching EP size, restrict candidates to the source rank holding the same global experts as dst_rank.

_round_robin_dp

Round-robin across source DP groups so transfer load spreads evenly.

select_src_metadata_balanced

Choose a representative source ParameterMetadata for a destination rank.

Data

API

class core.resharding.utils.TransferOp

Single logical send/recv operation used in a reshard plan.

param_name: str

None

peer_rank: int

None

is_send: bool

None

my_slice: tuple[slice, ...]

None

peer_slice: tuple[slice, ...]

None

task_id: int | None

None
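The fields above are listed without descriptions. As a hedged illustration, assuming my_slice indexes this rank's local tensor and peer_slice indexes the peer's local tensor (so a matched send/recv pair carries mirrored slices), the sketch below applies one such pair inside a single process with torch instead of torch.distributed. TransferOpSketch and apply_pair_locally are hypothetical names, not part of this module.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class TransferOpSketch:
    """Hypothetical stand-in mirroring TransferOp's documented fields."""
    param_name: str
    peer_rank: int
    is_send: bool
    my_slice: tuple  # assumed: region of this rank's local tensor
    peer_slice: tuple  # assumed: region of the peer's local tensor
    task_id: Optional[int] = None


def apply_pair_locally(send_op, recv_op, sender_local, receiver_local):
    """Emulate one matched send/recv pair in a single process (no torch.distributed)."""
    assert send_op.is_send and not recv_op.is_send
    # Mirrored slices: what the sender reads is exactly what the receiver writes.
    assert send_op.my_slice == recv_op.peer_slice
    assert send_op.peer_slice == recv_op.my_slice
    receiver_local[recv_op.my_slice].copy_(sender_local[send_op.my_slice])


# Rank 0 holds the full 4x2 weight; rank 1 needs rows 2..3 as its local rows 0..1.
full = torch.arange(8.0).reshape(4, 2)
rank0_local = full.clone()
rank1_local = torch.zeros(2, 2)

rows_2_4, rows_0_2, cols = slice(2, 4), slice(0, 2), slice(0, 2)
send = TransferOpSketch("w", peer_rank=1, is_send=True,
                        my_slice=(rows_2_4, cols), peer_slice=(rows_0_2, cols))
recv = TransferOpSketch("w", peer_rank=0, is_send=False,
                        my_slice=(rows_0_2, cols), peer_slice=(rows_2_4, cols))
apply_pair_locally(send, recv, rank0_local, rank1_local)
```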

class core.resharding.utils.ParameterMetadata

Metadata for a parameter (used when the parameter lives on a different rank).

name: str

None

shape: tuple[int, ...]

None

dtype: torch.dtype

None

element_size: int

None

is_tp: bool

False

partition_dim: int

0

partition_stride: int

1

partition_sizes: list[int] | None

None

is_ep: bool

False

num_experts: Optional[int]

None

owner_rank: int

None

tensor_parallel_group_ranks: list[int] | None

None

expert_parallel_group_ranks: list[int] | None

None

data_parallel_group_ranks: list[int] | None

None

pipeline_parallel_group_ranks: list[int] | None

None

resolved_name: Optional[str]

None

global_expert_index: Optional[int]

None

class core.resharding.utils.ShardingDescriptor

Descriptor for a sharded dimension of a parameter.

name: str

None

dim: int

None

src_stride: int

None

dst_stride: int

None

src_dim_ranks: list[int]

None

dst_dim_ranks: list[int]

None
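ShardingDescriptor records, for one sharded dimension, the source and destination rank lists, and this page refers elsewhere to an LCM planner. A hedged sketch of the LCM idea, assuming an even split with stride 1 (src_stride and dst_stride are ignored): cutting the dimension into lcm(n_src, n_dst) equal chunks makes every chunk belong to exactly one source shard and one destination shard, so adjacent chunks with the same owners merge into one transfer. lcm_transfers is a hypothetical helper, not the planner's real code.

```python
from math import gcd


def lcm_transfers(dim_size: int, n_src: int, n_dst: int):
    """Return (src_idx, dst_idx, start, stop) ranges along one sharded dimension.

    src_idx/dst_idx are positions in the source/destination rank lists; start/stop
    are offsets in the full (unsharded) dimension.
    """
    lcm = n_src * n_dst // gcd(n_src, n_dst)
    if dim_size % lcm:
        raise ValueError(f"dim_size={dim_size} is not divisible by lcm={lcm}")
    chunk = dim_size // lcm
    chunks_per_src, chunks_per_dst = lcm // n_src, lcm // n_dst
    transfers = []
    for c in range(lcm):
        src_idx, dst_idx = c // chunks_per_src, c // chunks_per_dst
        start, stop = c * chunk, (c + 1) * chunk
        last = transfers[-1] if transfers else None
        if last and last[:2] == (src_idx, dst_idx) and last[3] == start:
            transfers[-1] = (src_idx, dst_idx, last[2], stop)  # extend the previous range
        else:
            transfers.append((src_idx, dst_idx, start, stop))
    return transfers


# TP=2 -> TP=3 on a 12-row dimension.
print(lcm_transfers(12, 2, 3))
# [(0, 0, 0, 4), (0, 1, 4, 6), (1, 1, 6, 8), (1, 2, 8, 12)]
```

Source shard 0 covers rows [0, 6) and destination shard 1 covers rows [4, 8), so their overlap is exactly the (0, 1, 4, 6) entry.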

class core.resharding.utils.ReshardPlan

Reshard plan: the send and receive operations for this rank.

send_ops: list[core.resharding.utils.TransferOp]

None

recv_ops: list[core.resharding.utils.TransferOp]

None

transform: Optional[core.resharding.transforms.ReshardTransform]

None

buffer_dtypes: Optional[dict[str, torch.dtype]]

None

__str__()

core.resharding.utils._get_rank_in_group(global_rank: int, group_ranks: list[int]) → int

core.resharding.utils._EXPERT_PARAM_RE

‘compile(…)’

core.resharding.utils._detect_expert_index_from_param_name(
param_name: str,
) → Optional[int]

Extract expert index from parameter name for TEGroupedMLP per-expert tensors.

core.resharding.utils.assign_ep_resolved_name_inplace(
meta: core.resharding.utils.ParameterMetadata,
*,
base_name: str | None = None,
) → None

EP-only canonicalization for per-expert parameters.

Under Expert Parallelism (EP), each rank owns a subset of experts and names them with local indices (e.g., with four experts per rank, rank 1's local “weight0” is actually global expert 4). The raw parameter name therefore cannot be used to match source and destination parameters, because the same local name refers to different global experts on different ranks. This function remaps local expert indices to global indices in resolved_name and sets global_expert_index.

Effects:

  • Sets meta.resolved_name (for non-EP parameters it defaults to base_name, falling back to meta.name).

  • Sets meta.global_expert_index for per-expert parameters; otherwise leaves it as None.
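A hedged sketch of the local-to-global remapping described above. It assumes experts are assigned to EP ranks in contiguous blocks of num_experts // ep_size, that the EP-local rank is owner_rank's position in expert_parallel_group_ranks, and that per-expert names end in weightK or biasK. The regex is a hypothetical stand-in for _EXPERT_PARAM_RE, whose pattern this page elides, and detect_expert_index / ep_resolved_name are illustrative names, not this module's functions.

```python
import re
from typing import List, Optional, Tuple

# Hypothetical pattern; the real _EXPERT_PARAM_RE is elided on this page.
EXPERT_PARAM_RE = re.compile(r"^(?P<prefix>.*\.)(?P<kind>weight|bias)(?P<idx>\d+)$")


def detect_expert_index(name: str) -> Optional[int]:
    """Return the local expert index K from '...weightK' / '...biasK', else None."""
    m = EXPERT_PARAM_RE.match(name)
    return int(m.group("idx")) if m else None


def ep_resolved_name(name: str, owner_rank: int, ep_group_ranks: Optional[List[int]],
                     num_experts: int) -> Tuple[str, Optional[int]]:
    """Return (resolved_name, global_expert_index) under contiguous expert assignment."""
    m = EXPERT_PARAM_RE.match(name)
    if m is None or not ep_group_ranks:
        return name, None  # not a per-expert tensor: keep the raw name
    experts_per_rank = num_experts // len(ep_group_ranks)
    ep_local_rank = ep_group_ranks.index(owner_rank)
    global_idx = ep_local_rank * experts_per_rank + int(m.group("idx"))
    return f"{m.group('prefix')}{m.group('kind')}{global_idx}", global_idx


# 16 experts over EP=4 gives 4 experts per rank, so rank 1's local weight0 is global expert 4.
print(ep_resolved_name("decoder.layers.0.mlp.experts.linear_fc1.weight0",
                       owner_rank=1, ep_group_ranks=[0, 1, 2, 3], num_experts=16))
# ('decoder.layers.0.mlp.experts.linear_fc1.weight4', 4)
```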

core.resharding.utils.assign_resolved_name_inplace(
meta: core.resharding.utils.ParameterMetadata,
*,
layer_module_prefix_map: Mapping[str, str] | None = None,
base_name: str | None = None,
) → None

Set meta.resolved_name so the planner can match the same weights across models.

It rewrites PP layer indices to global layer indices (when layer_module_prefix_map is provided) and rewrites EP per-expert indices (weightK/biasK) to global expert indices.

core.resharding.utils.named_persistent_buffers(module: torch.nn.Module)

Yield (full_name, parent_module, buf_name, tensor) for every persistent buffer in module, skipping buffers listed in _non_persistent_buffers_set.

Persistent buffers (those saved in state_dict) carry training state that must travel with the weights during refit/resharding, e.g. the MoE router’s expert_bias, which is updated each step by aux-loss-free load balancing. Non-persistent buffers are excluded because they hold ephemeral state (e.g. accumulators reset at the next train step).

core.resharding.utils.named_refit_tensors(module: torch.nn.Module)

Yield (name, tensor) pairs for every parameter and persistent buffer.

Used by the refit planner and executor to enumerate which tensors should travel during resharding. Persistent buffers are included alongside parameters because they may carry training state (see named_persistent_buffers).
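The two generators above can be approximated with standard PyTorch module traversal. This sketch reads the nn.Module attributes _buffers and _non_persistent_buffers_set, which exist in current PyTorch but are internal; the *_sketch function names, the Router module, and its scratch buffer are hypothetical. For modules without extra state, the yielded names should line up with state_dict() keys, which makes a handy cross-check.

```python
import torch
from torch import nn


def named_persistent_buffers_sketch(module: nn.Module):
    for mod_name, mod in module.named_modules():
        for buf_name, buf in mod._buffers.items():
            # Skip unset buffers and those registered with persistent=False.
            if buf is None or buf_name in mod._non_persistent_buffers_set:
                continue
            full_name = f"{mod_name}.{buf_name}" if mod_name else buf_name
            yield full_name, mod, buf_name, buf


def named_refit_tensors_sketch(module: nn.Module):
    yield from module.named_parameters()
    for full_name, _parent, _buf_name, buf in named_persistent_buffers_sketch(module):
        yield full_name, buf


class Router(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(4, 2))
        self.register_buffer("expert_bias", torch.zeros(4))  # persistent training state
        self.register_buffer("scratch", torch.zeros(4), persistent=False)  # ephemeral


model = nn.ModuleDict({"router": Router()})
names = [name for name, _ in named_refit_tensors_sketch(model)]
print(names)  # ['router.weight', 'router.expert_bias']
```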

core.resharding.utils._REFIT_TENSOR_CACHE_ATTR

‘_refit_tensor_cache’

core.resharding.utils.get_refit_tensor_dict(
module: torch.nn.Module,
) → dict[str, torch.Tensor]

Return the cached {name: tensor} dict for module, building it if needed.

Walking named_modules() takes hundreds of milliseconds for models with billions of parameters, and the set of parameters and persistent buffers is stable across refits, so the dict is cached on the module itself. invalidate_refit_tensor_cache must be called whenever a persistent buffer is replaced (e.g. by _harmonize_buffer_dtypes) so that the cache picks up the new tensor.
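A hedged sketch of the cache-on-module pattern, reusing the '_refit_tensor_cache' attribute name documented above. For brevity it builds the dict from module.state_dict(keep_vars=True), which returns the live parameter and persistent-buffer objects, instead of from named_refit_tensors; that substitution and the *_sketch names are assumptions. The usage shows why invalidation is needed after a buffer is replaced.

```python
import torch
from torch import nn

CACHE_ATTR = "_refit_tensor_cache"  # same name as _REFIT_TENSOR_CACHE_ATTR


def get_refit_tensor_dict_sketch(module: nn.Module) -> dict:
    cache = getattr(module, CACHE_ATTR, None)
    if cache is None:
        # keep_vars=True returns the actual Parameter/buffer objects, not detached copies.
        cache = dict(module.state_dict(keep_vars=True))
        setattr(module, CACHE_ATTR, cache)  # a plain dict is stored as a regular attribute
    return cache


def invalidate_refit_tensor_cache_sketch(module: nn.Module) -> None:
    if hasattr(module, CACHE_ATTR):
        delattr(module, CACHE_ATTR)


model = nn.Linear(4, 4)
model.register_buffer("expert_bias", torch.zeros(4))

first = get_refit_tensor_dict_sketch(model)
cache_hit = get_refit_tensor_dict_sketch(model) is first  # second call reuses the dict

model.expert_bias = torch.ones(4, dtype=torch.float64)  # replace the buffer (e.g. new dtype)
stale = first["expert_bias"] is not model.expert_bias   # cached dict still holds the old tensor
invalidate_refit_tensor_cache_sketch(model)
fresh = get_refit_tensor_dict_sketch(model)["expert_bias"] is model.expert_bias
```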

core.resharding.utils.invalidate_refit_tensor_cache(module: torch.nn.Module) → None

Drop the cached refit tensor dict so the next call rebuilds it.

core.resharding.utils._build_layer_module_prefix_map(
module: torch.nn.Module,
) → dict[str, str]

Build a mapping local_module_prefix -> global_module_prefix for PP layer modules.

Megatron assigns a global, 1-indexed layer_number to each transformer layer module at construction time (including PP/VPP/layout offsets). We convert that to the 0-indexed naming convention used in parameter names and build a map such as:

  • “decoder.layers.0” → “decoder.layers.16” (if layer_number == 17)
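A hedged sketch of the map construction, assuming the caller has already collected (local_prefix, layer_number) pairs, for example by walking named_modules() and reading each transformer layer's layer_number. build_layer_prefix_map is a hypothetical name; it only shows the 1-indexed to 0-indexed conversion on the last path component.

```python
from __future__ import annotations


def build_layer_prefix_map(layer_numbers: dict[str, int]) -> dict[str, str]:
    """Map each local layer prefix to its global, 0-indexed prefix."""
    prefix_map = {}
    for local_prefix, layer_number in layer_numbers.items():
        parent, _, _local_index = local_prefix.rpartition(".")
        # layer_number is global and 1-indexed; parameter names use 0-indexed positions.
        prefix_map[local_prefix] = f"{parent}.{layer_number - 1}"
    return prefix_map


# A PP stage whose first local layer has global layer_number 17.
print(build_layer_prefix_map({"decoder.layers.0": 17, "decoder.layers.1": 18}))
# {'decoder.layers.0': 'decoder.layers.16', 'decoder.layers.1': 'decoder.layers.17'}
```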

core.resharding.utils._resolve_global_layer_number_in_name(
name: str,
layer_module_prefix_map: Mapping[str, str],
) → str

Rewrite a parameter name to use global layer indices (PP-aware).

Given a parameter name like decoder.layers.0.self_attention…, this function rewrites the decoder.layers.0 prefix to the corresponding global layer index using the owning layer module’s layer_number.

Implementation:

  • Build a {local_prefix -> global_prefix} map once (outside the per-parameter loop).

  • Perform a longest-prefix match replacement so we only rewrite the module path portion.
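The two implementation steps above can be sketched as follows, with the sorted prefix list built once and reused for every name. The "." boundary check keeps decoder.layers.1 from matching decoder.layers.10, and trying longer prefixes first lets the most specific module path win when prefixes nest. make_layer_name_resolver is a hypothetical helper, not this module's API.

```python
from __future__ import annotations

from typing import Callable, Mapping


def make_layer_name_resolver(prefix_map: Mapping[str, str]) -> Callable[[str], str]:
    # Built once: longest local prefixes first so the most specific match wins.
    ordered = sorted(prefix_map.items(), key=lambda kv: len(kv[0]), reverse=True)

    def resolve(name: str) -> str:
        for local_prefix, global_prefix in ordered:
            # Match whole path components only, never a numeric prefix of another index.
            if name == local_prefix or name.startswith(local_prefix + "."):
                return global_prefix + name[len(local_prefix):]
        return name  # not inside a mapped layer (e.g. embeddings)

    return resolve


resolve = make_layer_name_resolver({"decoder.layers.0": "decoder.layers.16",
                                    "decoder.layers.1": "decoder.layers.17"})
print(resolve("decoder.layers.1.self_attention.linear_qkv.weight"))
# decoder.layers.17.self_attention.linear_qkv.weight
```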

core.resharding.utils.extract_param_metadata(
param: torch.nn.Parameter,
param_name: str,
owner_rank: int,
pg_collection,
num_experts: Optional[int] = None,
layer_module_prefix_map: Mapping[str, str] | None = None,
rank_offset: int = 0,
_rank_list_cache: dict | None = None,
) → core.resharding.utils.ParameterMetadata

Extract metadata from a parameter for cross-rank communication.

Parameters:

_rank_list_cache – Optional dict used to deduplicate rank lists so that params sharing the same process group reuse one object. This dramatically shrinks pickle size when metadata is gathered across many ranks (pickle uses backreferences for same-id() objects, avoiding re-serialization of identical group lists).
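A minimal demonstration of why sharing one list object per process group shrinks the pickle: within a single dumps() call, pickle memoizes objects by identity and writes a short back-reference when the same object appears again. intern_ranks and its tuple-keyed cache are hypothetical; how the real _rank_list_cache is keyed is not documented here.

```python
import pickle


def intern_ranks(ranks: list, cache: dict) -> list:
    """Return one shared list object per distinct rank tuple (hypothetical helper)."""
    key = tuple(ranks)
    if key not in cache:
        cache[key] = list(ranks)
    return cache[key]


group = list(range(64))
# 1000 parameters in the same group, each with its own copy of the rank list.
copies = [{"name": f"p{i}", "tp_ranks": list(group)} for i in range(1000)]
# The same metadata, but every entry references one shared list object.
cache: dict = {}
shared = [{"name": f"p{i}", "tp_ranks": intern_ranks(group, cache)} for i in range(1000)]

size_copies = len(pickle.dumps(copies))
size_shared = len(pickle.dumps(shared))
print(size_copies, size_shared)
```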

core.resharding.utils._filter_by_ep_local_rank(
src_meta_list: list[core.resharding.utils.ParameterMetadata],
dst_metadata: core.resharding.utils.ParameterMetadata,
dst_rank: int,
) → list[core.resharding.utils.ParameterMetadata]

In non-collocated mode with matching EP size, restrict candidates to the source rank holding the same global experts as dst_rank.

When EP sizes differ (resharding), expert matching is handled via resolved_name and no filter is applied here.

Why size check matters:

  • Same size (EP=8→EP=8): local ranks 0-7 exist in both src and dst → filter ensures dst EP local 0 uses src EP local 0 (same global experts).

  • Different size (EP=8→EP=16): dst EP local 8 has no corresponding src EP local → skip filter; expert reassignment is handled by resolved_name matching, and the LCM/TP planner handles any TP dimension changes.
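A hedged sketch of the size check and filter described above. It assumes the EP-local rank is a rank's position in expert_parallel_group_ranks, leaves the "non-collocated mode" condition to the caller, and falls back to the unfiltered list when nothing matches; that fallback is an assumption. MetaSketch and filter_by_ep_local_rank_sketch are hypothetical names.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class MetaSketch:
    """Hypothetical stand-in with only the ParameterMetadata fields used here."""
    owner_rank: int
    expert_parallel_group_ranks: list[int] | None = None


def filter_by_ep_local_rank_sketch(src_meta_list, dst_meta, dst_rank):
    dst_ep = dst_meta.expert_parallel_group_ranks
    if not dst_ep:
        return src_meta_list  # not an expert parameter
    # Different EP sizes: leave expert matching to resolved_name.
    if any(len(m.expert_parallel_group_ranks or []) != len(dst_ep) for m in src_meta_list):
        return src_meta_list
    dst_local = dst_ep.index(dst_rank)
    same_local = [m for m in src_meta_list
                  if m.expert_parallel_group_ranks.index(m.owner_rank) == dst_local]
    return same_local or src_meta_list  # assumption: fall back rather than return nothing


# Source: 4 ranks laid out as EP=2 x DP=2 (EP groups [0, 1] and [2, 3]).
src = [MetaSketch(r, [0, 1] if r < 2 else [2, 3]) for r in range(4)]
same = filter_by_ep_local_rank_sketch(src, MetaSketch(5, [4, 5]), dst_rank=5)
print([m.owner_rank for m in same])  # [1, 3]: the copies at EP-local rank 1
diff = filter_by_ep_local_rank_sketch(src, MetaSketch(5, [4, 5, 6, 7]), dst_rank=5)
print([m.owner_rank for m in diff])  # [0, 1, 2, 3]: EP size differs, no filter
```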

core.resharding.utils._round_robin_dp(
src_meta_list: list[core.resharding.utils.ParameterMetadata],
dst_rank: int,
) → core.resharding.utils.ParameterMetadata

Round-robin across source DP groups so transfer load spreads evenly.

Each DP group holds a complete copy of the model, so any DP group can serve as the source. Choosing via dst_rank % num_src_dp_groups distributes load evenly even when the destination uses a different DP configuration. For example, with 4 source DP groups and 128 destination ranks, each source DP group is selected by 32 destination ranks (128/4 = 32). Within the chosen DP group, selection further cycles across the available metadata entries to balance load across the TP groups in that DP replica.
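A hedged sketch of the two-level round-robin, reproducing the 4-replica, 128-destination arithmetic above. It assumes a source copy's DP replica index is owner_rank's position in data_parallel_group_ranks and that the within-replica cycle uses dst_rank // num_replicas; both are assumptions, and MetaSketch / round_robin_dp are hypothetical names.

```python
from __future__ import annotations

from collections import Counter, defaultdict
from dataclasses import dataclass


@dataclass
class MetaSketch:
    """Hypothetical stand-in holding only the fields this sketch reads."""
    owner_rank: int
    data_parallel_group_ranks: list[int]


def round_robin_dp(src_meta_list: list[MetaSketch], dst_rank: int) -> MetaSketch:
    # Assumption: a copy's DP replica index is owner_rank's position in its DP group.
    by_replica: dict[int, list[MetaSketch]] = defaultdict(list)
    for meta in src_meta_list:
        by_replica[meta.data_parallel_group_ranks.index(meta.owner_rank)].append(meta)
    replica_ids = sorted(by_replica)
    num_replicas = len(replica_ids)
    # 1) Pick the DP replica with dst_rank % num_replicas.
    entries = sorted(by_replica[replica_ids[dst_rank % num_replicas]], key=lambda m: m.owner_rank)
    # 2) Cycle within the replica (e.g. across its TP ranks) using the quotient.
    return entries[(dst_rank // num_replicas) % len(entries)]


# 8 source ranks laid out as TP=2 x DP=4: rank r has TP rank r % 2 and DP rank r // 2.
src = [MetaSketch(r, [r % 2 + 2 * k for k in range(4)]) for r in range(8)]
picks = [round_robin_dp(src, dst) for dst in range(128)]
print(sorted(Counter(p.owner_rank // 2 for p in picks).items()))
# [(0, 32), (1, 32), (2, 32), (3, 32)]
```

Each replica is chosen by 32 destination ranks, and within a replica the two TP ranks are each chosen 16 times.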

core.resharding.utils.select_src_metadata_balanced(
src_meta_list: list[core.resharding.utils.ParameterMetadata],
dst_metadata: core.resharding.utils.ParameterMetadata,
dst_rank: int,
) → core.resharding.utils.ParameterMetadata

Choose a representative source ParameterMetadata for a destination rank.

The selected metadata supplies topology (TP/EP/DP group ranks) to the LCM planner. Selection prefers a local copy when dst_rank itself owns a source replica, then round-robins across source DP groups to balance load. A local copy is essentially free (tensor.copy_() on same GPU), while any remote transfer incurs significant overhead even within the same node.
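A hedged sketch of the selection order: take a copy that dst_rank already owns, otherwise spread remote reads. To stay self-contained it replaces the DP round-robin with a plain modulo over owner-sorted entries, which is a simplification; MetaSketch and select_src_metadata_sketch are hypothetical names.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class MetaSketch:
    """Hypothetical stand-in holding only owner_rank."""
    owner_rank: int


def select_src_metadata_sketch(src_meta_list: list[MetaSketch], dst_rank: int) -> MetaSketch:
    # Prefer a copy that dst_rank already owns: a same-GPU tensor.copy_() is nearly free.
    for meta in src_meta_list:
        if meta.owner_rank == dst_rank:
            return meta
    # Otherwise spread remote reads; a plain modulo stands in for the DP round-robin.
    ordered = sorted(src_meta_list, key=lambda m: m.owner_rank)
    return ordered[dst_rank % len(ordered)]


src = [MetaSketch(r) for r in range(4)]
print(select_src_metadata_sketch(src, 2).owner_rank)  # 2 (local copy)
print(select_src_metadata_sketch(src, 5).owner_rank)  # 1 (remote, 5 % 4)
```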