core.resharding.utils#

Module Contents#

Classes#

TransferOp

Single logical send/recv operation used in a reshard plan.

ParameterMetadata

Metadata for a parameter (used when the parameter resides on a different rank).

ShardingDescriptor

Descriptor for a sharded dimension for a parameter.

ReshardPlan

Reshard plan: the send/recv operations this rank must execute.

Functions#

_get_rank_in_group

_detect_expert_index_from_param_name

Extract expert index from parameter name for TEGroupedMLP per-expert tensors.

assign_ep_resolved_name_inplace

EP-only canonicalization for per-expert parameters.

assign_resolved_name_inplace

Set meta.resolved_name so the planner can match the same weights across models.

_build_layer_module_prefix_map

Build a mapping local_module_prefix -> global_module_prefix for PP layer modules.

_resolve_global_layer_number_in_name

Rewrite a parameter name to use global layer indices (PP-aware).

extract_param_metadata

Extract metadata from a parameter for cross-rank communication.

select_src_metadata_balanced

Choose a representative source ParameterMetadata for a destination rank.

Data#

API#

class core.resharding.utils.TransferOp#

Single logical send/recv operation used in a reshard plan.

param_name: str#

None

peer_rank: int#

None

is_send: bool#

None

my_slice: tuple[slice, ...]#

None

peer_slice: tuple[slice, ...]#

None

task_id: int | None#

None
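The fields above can be illustrated with a small stand-in dataclass (illustrative only, not the real class) showing how a matched send/recv pair might describe one exchange: rank 0 sends rows 0:512 of a weight to rank 4, and rank 4 records the mirror-image receive.

```python
from dataclasses import dataclass
from typing import Optional

# Illustrative stand-in mirroring the documented fields (not the real class).
@dataclass
class TransferOp:
    param_name: str
    peer_rank: int
    is_send: bool
    my_slice: tuple
    peer_slice: tuple
    task_id: Optional[int] = None

# On rank 0: send my rows 0:512 to rank 4's rows 0:512.
send = TransferOp("decoder.layers.16.mlp.fc1.weight", peer_rank=4, is_send=True,
                  my_slice=(slice(0, 512), slice(None)),
                  peer_slice=(slice(0, 512), slice(None)))
# On rank 4: the matching receive names rank 0 as the peer.
recv = TransferOp("decoder.layers.16.mlp.fc1.weight", peer_rank=0, is_send=False,
                  my_slice=(slice(0, 512), slice(None)),
                  peer_slice=(slice(0, 512), slice(None)))
```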

class core.resharding.utils.ParameterMetadata#

Metadata for a parameter (used when the parameter resides on a different rank).

name: str#

None

shape: tuple[int, ...]#

None

dtype: torch.dtype#

None

element_size: int#

None

is_tp: bool#

False

partition_dim: int#

0

partition_stride: int#

1

partition_sizes: list[int] | None#

None

is_ep: bool#

False

num_experts: Optional[int]#

None

owner_rank: int#

None

tensor_parallel_group_ranks: list[int] | None#

None

expert_parallel_group_ranks: list[int] | None#

None

data_parallel_group_ranks: list[int] | None#

None

pipeline_parallel_group_ranks: list[int] | None#

None

resolved_name: Optional[str]#

None

global_expert_index: Optional[int]#

None

class core.resharding.utils.ShardingDescriptor#

Descriptor for a sharded dimension for a parameter.

name: str#

None

dim: int#

None

src_stride: int#

None

dst_stride: int#

None

src_dim_ranks: list[int]#

None

dst_dim_ranks: list[int]#

None

class core.resharding.utils.ReshardPlan#

Reshard plan: the send/recv operations this rank must execute.

send_ops: list[core.resharding.utils.TransferOp]#

None

recv_ops: list[core.resharding.utils.TransferOp]#

None

__str__()#
core.resharding.utils._get_rank_in_group(global_rank: int, group_ranks: list[int]) int#
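Judging by the signature, this helper likely translates a global rank into its position inside a process group's ordered rank list; a minimal sketch under that assumption:

```python
def get_rank_in_group(global_rank: int, group_ranks: list) -> int:
    # The local rank within a group is the position of the global rank in
    # the group's ordered rank list (raises ValueError if not a member).
    return group_ranks.index(global_rank)

get_rank_in_group(66, [64, 65, 66, 67])  # -> 2 (local rank in that group)
```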
core.resharding.utils._detect_expert_index_from_param_name(
param_name: str,
) Optional[int]#

Extract expert index from parameter name for TEGroupedMLP per-expert tensors.
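A hedged sketch of the detection: TEGroupedMLP stores per-expert tensors with the local expert index fused into the leaf name (e.g. `weight3`, `bias0`). The exact naming pattern below is an assumption; the real helper may be stricter.

```python
import re
from typing import Optional

# Assumed pattern: a trailing local expert index on "weight"/"bias" leaves.
_EXPERT_SUFFIX = re.compile(r"\.(?:weight|bias)(\d+)$")

def detect_expert_index(param_name: str) -> Optional[int]:
    m = _EXPERT_SUFFIX.search(param_name)
    return int(m.group(1)) if m else None

detect_expert_index("decoder.layers.0.mlp.experts.linear_fc1.weight3")  # -> 3
detect_expert_index("decoder.layers.0.mlp.linear_fc1.weight")           # -> None
```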

core.resharding.utils.assign_ep_resolved_name_inplace(
meta: core.resharding.utils.ParameterMetadata,
*,
base_name: str | None = None,
) None#

EP-only canonicalization for per-expert parameters.

Under Expert Parallelism (EP), each rank owns a subset of experts with local indices (e.g., rank 1 has “weight0” locally, but it’s actually global expert 4). The raw param name can’t be used to match across source/destination because the same local name refers to different global experts on different ranks. This function remaps local expert indices to global indices in resolved_name and sets global_expert_index.

Effects:

  • Sets meta.resolved_name (defaults to base_name/meta.name for non-EP).

  • Sets meta.global_expert_index for per-expert parameters; otherwise leaves it as None.
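The local-to-global remap described above can be sketched with simple arithmetic, assuming a contiguous block assignment of experts to EP ranks (the example matches the docstring: rank 1's local `weight0` is global expert 4 when 8 experts are split over EP=2):

```python
# Assumed block layout: each EP rank owns num_experts // ep_size
# consecutive experts, so the global index is an offset plus the local one.
def to_global_expert(local_idx, ep_local_rank, num_experts, ep_size):
    experts_per_rank = num_experts // ep_size
    return ep_local_rank * experts_per_rank + local_idx

to_global_expert(0, ep_local_rank=1, num_experts=8, ep_size=2)  # -> 4
# resolved_name would then carry the global index, e.g.
# "...linear_fc1.weight0" -> "...linear_fc1.weight4"
```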

core.resharding.utils.assign_resolved_name_inplace(
meta: core.resharding.utils.ParameterMetadata,
*,
layer_module_prefix_map: Mapping[str, str] | None = None,
base_name: str | None = None,
) None#

Set meta.resolved_name so the planner can match the same weights across models.

It rewrites PP layer indices to global layer indices (when layer_module_prefix_map is provided) and rewrites EP per-expert indices (weightK/biasK) to global expert indices.
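A hedged sketch of the two rewrites applied in sequence on an illustrative name (the prefix map and the expert offset of 4 are assumptions for the example, not values the function computes this way):

```python
import re

name = "decoder.layers.0.mlp.experts.linear_fc1.weight0"

# (1) PP rewrite: swap the local layer prefix for the global one.
prefix_map = {"decoder.layers.0": "decoder.layers.16"}
for local, glob in prefix_map.items():
    if name == local or name.startswith(local + "."):
        name = glob + name[len(local):]
        break

# (2) EP rewrite: bump the trailing local expert index to the global one
# (the owning rank's first global expert is assumed to be 4 here).
name = re.sub(r"(\d+)$", lambda m: str(int(m.group(1)) + 4), name)
# name == "decoder.layers.16.mlp.experts.linear_fc1.weight4"
```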

core.resharding.utils._build_layer_module_prefix_map(
module: torch.nn.Module,
) dict[str, str]#

Build a mapping local_module_prefix -> global_module_prefix for PP layer modules.

Megatron assigns a global, 1-indexed layer_number to each transformer layer module at construction time (including PP/VPP/layout offsets). We convert that to the 0-indexed naming convention used in parameter names and build a map such as:

  • “decoder.layers.0” → “decoder.layers.16” (if layer_number == 17)
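The 1-indexed-to-0-indexed conversion can be sketched as follows, with a fake layer object standing in for the real module (the real function walks `module.named_modules()`; the shape below is an assumption):

```python
# Hedged sketch, assuming each layer module exposes Megatron's 1-indexed
# `layer_number` while parameter names use 0-indexed positions.
class FakeLayer:
    def __init__(self, layer_number):
        self.layer_number = layer_number

def build_prefix_map(named_layers):
    # named_layers: iterable of (local_module_prefix, module) pairs.
    mapping = {}
    for local_prefix, module in named_layers:
        base = local_prefix.rsplit(".", 1)[0]    # e.g. "decoder.layers"
        global_idx = module.layer_number - 1     # 1-indexed -> 0-indexed
        mapping[local_prefix] = f"{base}.{global_idx}"
    return mapping

build_prefix_map([("decoder.layers.0", FakeLayer(17)),
                  ("decoder.layers.1", FakeLayer(18))])
# -> {"decoder.layers.0": "decoder.layers.16",
#     "decoder.layers.1": "decoder.layers.17"}
```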

core.resharding.utils._resolve_global_layer_number_in_name(
name: str,
layer_module_prefix_map: Mapping[str, str],
) str#

Rewrite a parameter name to use global layer indices (PP-aware).

Given a parameter name like decoder.layers.0.self_attention…, this function rewrites the decoder.layers.0 prefix to the corresponding global layer index using the owning layer module’s layer_number.

Implementation:

  • Build a {local_prefix -> global_prefix} map once (outside the per-parameter loop).

  • Perform a longest-prefix match replacement so we only rewrite the module path portion.
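The longest-prefix-match replacement described above can be sketched like this (trying longer prefixes first and requiring a following `.` ensures only the module-path portion is rewritten and numeric prefixes never collide):

```python
def resolve_name(name, prefix_map):
    # Try longer prefixes first so the most specific module path wins.
    for local in sorted(prefix_map, key=len, reverse=True):
        if name == local or name.startswith(local + "."):
            return prefix_map[local] + name[len(local):]
    return name  # no PP layer prefix matched; leave the name unchanged

pm = {"decoder.layers.1": "decoder.layers.17",
      "decoder.layers.10": "decoder.layers.26"}
resolve_name("decoder.layers.10.self_attention.linear_qkv.weight", pm)
# -> "decoder.layers.26.self_attention.linear_qkv.weight"
```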

core.resharding.utils.extract_param_metadata(
param: torch.nn.Parameter,
param_name: str,
owner_rank: int,
pg_collection,
num_experts: Optional[int] = None,
layer_module_prefix_map: Mapping[str, str] | None = None,
rank_offset: int = 0,
_rank_list_cache: dict | None = None,
) core.resharding.utils.ParameterMetadata#

Extract metadata from a parameter for cross-rank communication.

Parameters:

_rank_list_cache – Optional dict used to deduplicate rank lists so that params sharing the same process group reuse one object. This dramatically shrinks pickle size when metadata is gathered across many ranks (pickle uses backreferences for same-id() objects, avoiding re-serialization of identical group lists).
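The pickle-size effect described for `_rank_list_cache` can be demonstrated directly: pickle memoizes objects by identity, so many references to one shared list serialize its contents once, while equal-but-distinct lists repeat them every time. A self-contained sketch (the `dedup` helper is illustrative, not the real cache API):

```python
import pickle

def dedup(cache, ranks):
    # Return one canonical list object per distinct rank tuple.
    return cache.setdefault(tuple(ranks), list(ranks))

cache = {}
# 100 metadata entries all referencing the SAME 64-rank group list...
shared = [dedup(cache, list(range(64))) for _ in range(100)]
# ...versus 100 equal but distinct list objects.
distinct = [list(range(64)) for _ in range(100)]

len(pickle.dumps(shared)) < len(pickle.dumps(distinct))  # True: backrefs win
```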

core.resharding.utils.select_src_metadata_balanced(
src_meta_list: list[core.resharding.utils.ParameterMetadata],
dst_metadata: core.resharding.utils.ParameterMetadata,
dst_rank: int,
) core.resharding.utils.ParameterMetadata#

Choose a representative source ParameterMetadata for a destination rank.

The selected metadata provides topology information (TP/EP/DP group ranks) that the LCM transfer planner uses to compute actual source ranks and slices. This function doesn't perform transfers itself; it just picks which source configuration to use as a reference for planning.

Two scenarios for EP-sharded parameters:

  1. Non-collocated mode (same EP size, different rank numbering):

  • Filter by EP-local rank so that source and destination ranks with the same expert position are paired

    • Example: src ranks [0-63] and dst ranks [64-127] both with EP=8

    • Dst EP local 0 should use src EP local 0 as reference (same experts)

  2. Resharding mode (different EP sizes):

    • Skip EP local rank filtering (sizes don’t correspond)

    • Example: EP=8→EP=16 means dst EP local 8 has no matching src EP local

    • Expert matching handled by resolved_name; LCM handles TP dimension changes

Finally, balances across data-parallel (DP) groups to distribute load:

  • Groups src_meta_list by DP group

  • Selects source DP group via round-robin: dst_rank % num_src_dp_groups

  • Ensures even distribution of transfer load across source DP replicas
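The round-robin DP balancing step can be sketched as follows, using plain dicts in place of `ParameterMetadata` and assuming the grouping key is the metadata's DP group rank list:

```python
def pick_src_dp_group(src_meta_list, dst_rank, dp_key):
    # Bucket candidate sources by their DP group...
    groups = {}
    for meta in src_meta_list:
        groups.setdefault(tuple(dp_key(meta)), []).append(meta)
    # ...then pick a bucket round-robin by destination rank.
    keys = sorted(groups)
    return groups[keys[dst_rank % len(keys)]][0]

# Two source DP replicas; consecutive destination ranks alternate between them.
metas = [{"dp": [0, 2], "owner": 0}, {"dp": [1, 3], "owner": 1}]
pick_src_dp_group(metas, dst_rank=64, dp_key=lambda m: m["dp"])["owner"]  # -> 0
pick_src_dp_group(metas, dst_rank=65, dp_key=lambda m: m["dp"])["owner"]  # -> 1
```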

core.resharding.utils.logger#

‘getLogger(…)’