core.resharding.utils#

Module Contents#

Classes#

TransferOp

Single logical send/recv operation used in a reshard plan.

ParameterMetadata

Metadata for a parameter (used when the param is on a different rank).

ShardingDescriptor

Descriptor for a sharded dimension for a parameter.

ReshardPlan

Reshard plan: the operations assigned to this rank.

Functions#

_get_rank_in_group

_detect_expert_index_from_param_name

Extract expert index from parameter name for TEGroupedMLP per-expert tensors.

assign_ep_resolved_name_inplace

EP-only canonicalization for per-expert parameters.

assign_resolved_name_inplace

Set meta.resolved_name so the planner can match the same weights across models.

_build_layer_module_prefix_map

Build a mapping local_module_prefix -> global_module_prefix for PP layer modules.

_resolve_global_layer_number_in_name

Rewrite a parameter name to use global layer indices (PP-aware).

extract_param_metadata

Extract metadata from a parameter for cross-rank communication.

select_src_metadata_balanced

Choose a representative source ParameterMetadata for a destination rank.

Data#

API#

class core.resharding.utils.TransferOp#

Single logical send/recv operation used in a reshard plan.

param_name: str#

None

peer_rank: int#

None

is_send: bool#

None

my_slice: tuple[slice, ...]#

None

peer_slice: tuple[slice, ...]#

None

task_id: int | None#

None

class core.resharding.utils.ParameterMetadata#

Metadata for a parameter (used when the param is on a different rank).

name: str#

None

shape: tuple[int, ...]#

None

dtype: torch.dtype#

None

element_size: int#

None

is_tp: bool#

False

partition_dim: int#

0

partition_stride: int#

1

is_ep: bool#

False

num_experts: Optional[int]#

None

owner_rank: int#

None

tensor_parallel_group_ranks: list[int] | None#

None

expert_parallel_group_ranks: list[int] | None#

None

data_parallel_group_ranks: list[int] | None#

None

pipeline_parallel_group_ranks: list[int] | None#

None

resolved_name: Optional[str]#

None

global_expert_index: Optional[int]#

None

class core.resharding.utils.ShardingDescriptor#

Descriptor for a sharded dimension for a parameter.

name: str#

None

dim: int#

None

src_stride: int#

None

dst_stride: int#

None

src_dim_ranks: list[int]#

None

dst_dim_ranks: list[int]#

None

class core.resharding.utils.ReshardPlan#

Reshard plan: the operations assigned to this rank.

send_ops: list[core.resharding.utils.TransferOp]#

None

recv_ops: list[core.resharding.utils.TransferOp]#

None

__str__()#
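To illustrate how these two classes fit together, here is a minimal sketch using simplified stand-in dataclasses (the real classes live in `core.resharding.utils` and may differ in detail; the tensor is replaced by a plain list):

```python
from __future__ import annotations
from dataclasses import dataclass, field

# Simplified mirrors of TransferOp / ReshardPlan, for illustration only.
@dataclass
class TransferOp:
    param_name: str
    peer_rank: int
    is_send: bool
    my_slice: tuple    # slices into this rank's local shard
    peer_slice: tuple  # slices into the peer's local shard
    task_id: int | None = None

@dataclass
class ReshardPlan:
    send_ops: list = field(default_factory=list)
    recv_ops: list = field(default_factory=list)

# Rank 0 holds rows 0..3 of a weight; the destination layout wants two of
# those rows on rank 1, so this rank plans one send covering the overlap.
plan = ReshardPlan()
plan.send_ops.append(
    TransferOp(
        param_name="decoder.layers.0.mlp.linear_fc1.weight",
        peer_rank=1,
        is_send=True,
        my_slice=(slice(2, 4),),   # rows 2..3 of my local shard
        peer_slice=(slice(0, 2),), # land as rows 0..1 of the peer's shard
    )
)

local_shard = list(range(4))  # stand-in for a local tensor shard
op = plan.send_ops[0]
payload = local_shard[op.my_slice[0]]
print(payload)  # [2, 3]
```

The `my_slice`/`peer_slice` pair lets each side index its own local shard without knowing the other side's full layout.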
core.resharding.utils._get_rank_in_group(global_rank: int, group_ranks: list[int]) int#
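A plausible sketch of this helper, assuming `group_ranks` is the ordered list of global ranks in a process group:

```python
def get_rank_in_group(global_rank: int, group_ranks: list[int]) -> int:
    # Position of this global rank within the ordered process-group list.
    return group_ranks.index(global_rank)

print(get_rank_in_group(6, [2, 6, 10, 14]))  # 1
```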
core.resharding.utils._detect_expert_index_from_param_name(
param_name: str,
) Optional[int]#

Extract expert index from parameter name for TEGroupedMLP per-expert tensors.
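TEGroupedMLP stores per-expert tensors under names ending in `weight<K>`/`bias<K>`. A hedged re-implementation of the detection (the exact regex and name layout in the module may differ):

```python
import re
from typing import Optional

# Match a trailing per-expert index on the final name segment,
# e.g. "...experts.linear_fc1.weight4" -> 4.
_EXPERT_SUFFIX = re.compile(r"\.(?:weight|bias)(\d+)$")

def detect_expert_index(param_name: str) -> Optional[int]:
    m = _EXPERT_SUFFIX.search(param_name)
    return int(m.group(1)) if m else None

print(detect_expert_index("decoder.layers.3.mlp.experts.linear_fc1.weight4"))   # 4
print(detect_expert_index("decoder.layers.3.self_attention.linear_qkv.weight")) # None
```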

core.resharding.utils.assign_ep_resolved_name_inplace(
meta: core.resharding.utils.ParameterMetadata,
*,
base_name: str | None = None,
) None#

EP-only canonicalization for per-expert parameters.

Under Expert Parallelism (EP), each rank owns a subset of experts with local indices (e.g., rank 1 has “weight0” locally, but it’s actually global expert 4). The raw param name can’t be used to match across source/destination because the same local name refers to different global experts on different ranks. This function remaps local expert indices to global indices in resolved_name and sets global_expert_index.

Effects:

  • Sets meta.resolved_name (defaults to base_name/meta.name for non-EP).

  • Sets meta.global_expert_index for per-expert parameters; otherwise leaves it as None.
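The remapping above can be sketched as follows, assuming experts are block-distributed so that EP-local rank `r` owns global experts `[r*per_rank, (r+1)*per_rank)`. The real function mutates `ParameterMetadata` in place; this hypothetical helper just computes the resolved name and global index:

```python
def resolve_expert_name(name: str, local_idx: int, ep_local_rank: int,
                        num_experts: int, ep_size: int) -> tuple[str, int]:
    per_rank = num_experts // ep_size
    global_idx = ep_local_rank * per_rank + local_idx
    # Replace the trailing local index with the global one, e.g. weight0 -> weight4.
    resolved = name[: -len(str(local_idx))] + str(global_idx)
    return resolved, global_idx

# The docstring's example: EP=2 over 8 experts, so rank 1's local "weight0"
# is really global expert 4.
name = "decoder.layers.0.mlp.experts.linear_fc1.weight0"
print(resolve_expert_name(name, local_idx=0, ep_local_rank=1,
                          num_experts=8, ep_size=2))
```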

core.resharding.utils.assign_resolved_name_inplace(
meta: core.resharding.utils.ParameterMetadata,
*,
layer_module_prefix_map: Mapping[str, str] | None = None,
base_name: str | None = None,
) None#

Set meta.resolved_name so the planner can match the same weights across models.

It rewrites PP layer indices to global layer indices (when layer_module_prefix_map is provided) and rewrites EP per-expert indices (weightK/biasK) to global expert indices.
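The composition of the two rewrites can be sketched like this (the helper names and dict shapes here are hypothetical, not the module's internals):

```python
import re

def resolve_name(name: str, prefix_map: dict[str, str],
                 local_to_global_expert: dict[int, int]) -> str:
    # 1) PP: longest-prefix rewrite of the layer module path.
    for local in sorted(prefix_map, key=len, reverse=True):
        if name.startswith(local + "."):
            name = prefix_map[local] + name[len(local):]
            break
    # 2) EP: rewrite a trailing per-expert index weightK/biasK.
    m = re.search(r"\.(weight|bias)(\d+)$", name)
    if m:
        g = local_to_global_expert[int(m.group(2))]
        name = name[: m.start()] + f".{m.group(1)}{g}"
    return name

result = resolve_name(
    "decoder.layers.0.mlp.experts.linear_fc1.weight0",
    {"decoder.layers.0": "decoder.layers.16"},
    {0: 4},
)
print(result)  # decoder.layers.16.mlp.experts.linear_fc1.weight4
```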

core.resharding.utils._build_layer_module_prefix_map(
module: torch.nn.Module,
) dict[str, str]#

Build a mapping local_module_prefix -> global_module_prefix for PP layer modules.

Megatron assigns a global, 1-indexed layer_number to each transformer layer module at construction time (including PP/VPP/layout offsets). We convert that to the 0-indexed naming convention used in parameter names and build a map such as:

  • “decoder.layers.0” → “decoder.layers.16” (if layer_number == 17)
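A sketch of this construction, using a fake layer object in place of a real transformer layer module (the attribute probing is illustrative; the real function walks a `torch.nn.Module`):

```python
class FakeLayer:
    """Stand-in for a transformer layer carrying Megatron's layer_number."""
    def __init__(self, layer_number: int):
        self.layer_number = layer_number  # global, 1-indexed

def build_prefix_map(named_layers: dict[str, FakeLayer]) -> dict[str, str]:
    prefix_map = {}
    for local_prefix, layer in named_layers.items():
        # "decoder.layers.<local>" -> "decoder.layers.<layer_number - 1>"
        parent = local_prefix.rsplit(".", 1)[0]
        prefix_map[local_prefix] = f"{parent}.{layer.layer_number - 1}"
    return prefix_map

# A PP rank holding global layers 17 and 18 under local indices 0 and 1:
mapping = build_prefix_map({
    "decoder.layers.0": FakeLayer(17),
    "decoder.layers.1": FakeLayer(18),
})
print(mapping)
# {'decoder.layers.0': 'decoder.layers.16', 'decoder.layers.1': 'decoder.layers.17'}
```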

core.resharding.utils._resolve_global_layer_number_in_name(
name: str,
layer_module_prefix_map: Mapping[str, str],
) str#

Rewrite a parameter name to use global layer indices (PP-aware).

Given a parameter name like decoder.layers.0.self_attention…, this function rewrites the decoder.layers.0 prefix to the corresponding global layer index using the owning layer module’s layer_number.

Implementation:

  • Build a {local_prefix -> global_prefix} map once (outside the per-parameter loop).

  • Perform a longest-prefix match replacement so we only rewrite the module path portion.
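A hypothetical re-implementation of the replacement step. Matching on whole path segments, longest prefix first, ensures `decoder.layers.1` never captures a parameter under `decoder.layers.10`:

```python
def resolve_layer_prefix(name: str, prefix_map: dict[str, str]) -> str:
    for local in sorted(prefix_map, key=len, reverse=True):
        # Require a "." boundary so "layers.1" cannot match "layers.10...".
        if name == local or name.startswith(local + "."):
            return prefix_map[local] + name[len(local):]
    return name  # not a layer parameter; leave unchanged

prefix_map = {"decoder.layers.1": "decoder.layers.9",
              "decoder.layers.10": "decoder.layers.18"}
print(resolve_layer_prefix(
    "decoder.layers.10.self_attention.linear_qkv.weight", prefix_map))
# decoder.layers.18.self_attention.linear_qkv.weight
```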

core.resharding.utils.extract_param_metadata(
param: torch.nn.Parameter,
param_name: str,
owner_rank: int,
pg_collection,
num_experts: Optional[int] = None,
layer_module_prefix_map: Mapping[str, str] | None = None,
) core.resharding.utils.ParameterMetadata#

Extract metadata from a parameter for cross-rank communication.
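A hedged sketch of the TP portion of this extraction. Megatron-style TP attributes are typically attached to the parameter object itself; the real function also records dtype, shape, element size, and the rank lists from `pg_collection`, all omitted here (a `SimpleNamespace` stands in for a `torch.nn.Parameter`):

```python
from types import SimpleNamespace

def extract_tp_fields(param) -> dict:
    # Defaults mirror the ParameterMetadata defaults above.
    return {
        "is_tp": bool(getattr(param, "tensor_model_parallel", False)),
        "partition_dim": getattr(param, "partition_dim", 0),
        "partition_stride": getattr(param, "partition_stride", 1),
    }

# Stand-in for a column-parallel weight shard:
param = SimpleNamespace(tensor_model_parallel=True, partition_dim=0,
                        partition_stride=1)
print(extract_tp_fields(param))
# {'is_tp': True, 'partition_dim': 0, 'partition_stride': 1}
```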

core.resharding.utils.select_src_metadata_balanced(
src_meta_list: list[core.resharding.utils.ParameterMetadata],
dst_metadata: core.resharding.utils.ParameterMetadata,
dst_rank: int,
) core.resharding.utils.ParameterMetadata#

Choose a representative source ParameterMetadata for a destination rank.

The selected metadata provides topology information (TP/EP/DP group ranks) that the LCM transfer planner uses to compute actual source ranks and slices. This function doesn’t perform transfers itself; it only picks which source configuration to use as a reference for planning.

Two scenarios for EP-sharded parameters:

  1. Non-collocated mode (same EP size, different rank numbering):

    • Filter by matching EP local rank to pair ranks with same expert position

    • Example: src ranks [0-63] and dst ranks [64-127] both with EP=8

    • Dst EP local 0 should use src EP local 0 as reference (same experts)

  2. Resharding mode (different EP sizes):

    • Skip EP local rank filtering (sizes don’t correspond)

    • Example: EP=8→EP=16 means dst EP local 8 has no matching src EP local

    • Expert matching handled by resolved_name; LCM handles TP dimension changes

Finally, balances across data-parallel (DP) groups to distribute load:

  • Groups src_meta_list by DP group

  • Selects source DP group via round-robin: dst_rank % num_src_dp_groups

  • Ensures even distribution of transfer load across source DP replicas
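The DP round-robin step above can be sketched as follows, operating on lightweight dicts instead of `ParameterMetadata` (the field names here are hypothetical):

```python
def select_balanced(src_meta_list: list[dict], dst_rank: int) -> dict:
    # Group candidate sources by DP group (a hashable tuple of ranks),
    # keeping one representative per group.
    by_dp: dict[tuple, dict] = {}
    for meta in src_meta_list:
        by_dp.setdefault(tuple(meta["dp_ranks"]), meta)
    groups = sorted(by_dp)                   # deterministic ordering
    chosen = groups[dst_rank % len(groups)]  # round-robin over DP groups
    return by_dp[chosen]

# Two source DP replicas; consecutive destination ranks alternate between them.
srcs = [{"owner": 0, "dp_ranks": [0, 4]},
        {"owner": 4, "dp_ranks": [0, 4]},
        {"owner": 1, "dp_ranks": [1, 5]}]
print(select_balanced(srcs, dst_rank=8)["owner"])  # 0
print(select_balanced(srcs, dst_rank=9)["owner"])  # 1
```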

core.resharding.utils.logger#

‘getLogger(…)’