core.resharding.planner#

Module Contents#

Functions#

_build_descriptors_for_param

Construct sharding descriptors (currently TP) for this parameter based on its actual layout. Guard the TP descriptor with a size-conservation check so we don't mis-classify replicated tensors.

_plan_multi_dim_lcm

TP-only planner using LCM tiling to support strides on source/destination.

_finalize_dp_transfers

Return the receiver-side transfers for a parameter that is not TP-sharded.

_determine_source_ranks_for_dst_param

Route to dimension-specific planner based on parameter sharding type.

build_centralized_reshard_plan

Centralized planning: rank 0 builds the complete plan for all ranks, then scatters it.

Data#

API#

core.resharding.planner.logger#

‘getLogger(…)’

core.resharding.planner._build_descriptors_for_param(
src_metadata: core.resharding.utils.ParameterMetadata,
dst_metadata: core.resharding.utils.ParameterMetadata,
) → list[core.resharding.utils.ShardingDescriptor]#

Construct sharding descriptors (currently TP) for this parameter based on its actual layout. Guard the TP descriptor with a size-conservation check so we don't mis-classify replicated tensors.
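
A size-conservation guard like the following captures the intent; the attribute names used here (local_shape, global_shape, tp_dim, tp_size) are illustrative assumptions, not the actual ParameterMetadata fields.

    def _tp_descriptor_is_consistent(local_shape, global_shape, tp_dim, tp_size):
        """Accept a TP descriptor only if the shard sizes reconstruct the full tensor."""
        for dim, (local, full) in enumerate(zip(local_shape, global_shape)):
            if dim == tp_dim:
                # The sharded dimension must account for the global size exactly.
                if local * tp_size != full:
                    return False
            elif local != full:
                # Every other dimension must be unsharded (replicated).
                return False
        return True

A fully replicated tensor has local == full on every dimension, so the tp_dim check fails whenever tp_size > 1 and no TP descriptor is emitted for it.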

core.resharding.planner._plan_multi_dim_lcm(
param_name: str,
src_metadata: core.resharding.utils.ParameterMetadata,
dst_metadata: core.resharding.utils.ParameterMetadata,
descriptors: list[core.resharding.utils.ShardingDescriptor],
my_global_rank: int,
) → list[tuple[int, tuple[slice, ...], tuple[slice, ...]]]#

TP-only planner using LCM tiling to support strides on source/destination.

  • Requires exactly one TP descriptor

  • Supports arbitrary integer strides (contiguous micro-tiles)
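
A minimal one-dimensional sketch of the LCM tiling idea (ignoring strides and the multi-dimensional case): the sharded dimension is cut into lcm(src_tp, dst_tp) micro-tiles, and each micro-tile maps to exactly one source shard and one destination shard. All names below are illustrative.

    from math import gcd

    def lcm_tiles_1d(dim_size, src_tp, dst_tp):
        """Yield (dst_rank, src_rank, dst_local_slice, src_local_slice) per micro-tile."""
        lcm = src_tp * dst_tp // gcd(src_tp, dst_tp)
        tile = dim_size // lcm                      # micro-tile length
        for t in range(lcm):
            start, stop = t * tile, (t + 1) * tile  # global extent of this tile
            src_rank = start // (dim_size // src_tp)
            dst_rank = start // (dim_size // dst_tp)
            src_off = src_rank * (dim_size // src_tp)
            dst_off = dst_rank * (dim_size // dst_tp)
            yield (dst_rank, src_rank,
                   slice(start - dst_off, stop - dst_off),
                   slice(start - src_off, stop - src_off))

For example, a dimension of size 12 resharded from TP=2 to TP=3 is cut into 6 micro-tiles of length 2; destination rank 1 receives one tile from source rank 0 and one from source rank 1.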

core.resharding.planner._finalize_dp_transfers(
param_name: str,
src_metadata: core.resharding.utils.ParameterMetadata,
dst_metadata: core.resharding.utils.ParameterMetadata,
my_global_rank: int,
) → list[tuple[int, tuple[slice, ...], tuple[slice, ...]]]#

Return the receiver-side transfers for a parameter that is not TP-sharded.

This path is reached when we cannot build a TP sharding descriptor for the parameter (i.e., it is effectively replicated with respect to sharding). It applies when neither the source nor the destination model uses TP, or when the parameter is replicated on all ranks (e.g., LayerNorm weights). If the source and destination DP groups match, we return a local full-tensor copy; otherwise we pick a source rank from the source DP group in a deterministic round-robin manner based on the receiver's index in its destination DP group.
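
The source-rank choice can be summarized with a small sketch; the group and rank variables here are illustrative, not the actual field names.

    def pick_source_rank(src_dp_ranks, dst_dp_ranks, my_global_rank):
        """Return the global rank this receiver should pull the full tensor from."""
        if src_dp_ranks == dst_dp_ranks:
            # Same DP group on both sides: copy the full tensor locally.
            return my_global_rank
        # Deterministic round-robin: my index in the destination DP group
        # selects a sender from the source DP group.
        my_dst_index = dst_dp_ranks.index(my_global_rank)
        return src_dp_ranks[my_dst_index % len(src_dp_ranks)]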

core.resharding.planner._determine_source_ranks_for_dst_param(
param_name: str,
src_metadata: core.resharding.utils.ParameterMetadata,
dst_metadata: core.resharding.utils.ParameterMetadata,
my_global_rank: int,
) → list[tuple[int, tuple[slice, ...], tuple[slice, ...]]]#

Route to dimension-specific planner based on parameter sharding type.
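
A plausible sketch of this routing, using the helpers documented on this page (the exact dispatch condition is an assumption):

    def _route(param_name, src_metadata, dst_metadata, my_global_rank):
        descriptors = _build_descriptors_for_param(src_metadata, dst_metadata)
        if descriptors:
            # TP-sharded on at least one side: plan with LCM tiling.
            return _plan_multi_dim_lcm(param_name, src_metadata, dst_metadata,
                                       descriptors, my_global_rank)
        # No TP descriptor: treat the parameter as replicated (DP-only path).
        return _finalize_dp_transfers(param_name, src_metadata, dst_metadata,
                                      my_global_rank)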

core.resharding.planner.build_centralized_reshard_plan(
src_module: torch.nn.Module,
dst_module: torch.nn.Module,
num_experts: int = None,
) → core.resharding.utils.ReshardPlan#

Centralized planning: rank 0 builds the complete plan for all ranks, then scatters it.
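
The centralized pattern can be sketched with torch.distributed object collectives; the plan-building callback and the per-rank plan structure are illustrative assumptions.

    import torch.distributed as dist

    def scatter_reshard_plan(build_plan_for_rank):
        """Rank 0 builds every rank's plan, then scatters one plan to each rank."""
        world_size = dist.get_world_size()
        if dist.get_rank() == 0:
            # Rank 0 computes each rank's plan from the gathered metadata.
            all_plans = [build_plan_for_rank(r) for r in range(world_size)]
        else:
            all_plans = None
        my_plan = [None]
        dist.scatter_object_list(my_plan, all_plans, src=0)
        return my_plan[0]

Planning once on rank 0 keeps every rank's transfer lists mutually consistent, since they all come from a single planning pass.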