core.resharding.planner

Module Contents

Functions

_sort_ops_by_dst_offset

Sort transfer ops by destination offset on the sharded dimension.

_build_descriptors_for_param

Construct sharding descriptors (currently TP only) for this parameter based on its actual layout. The TP descriptor is guarded by a size-conservation check so that replicated tensors are not misclassified as sharded.

_emit_lcm_block_ops

Emit (src_rank, src_slice, dst_slice) ops for one LCM-tiled block.

_tp_block_layout

Compute the per-block layout for a TP transfer.

_plan_tp

Plan TP transfers via LCM tiling, supporting both plain and block-interleaved TP.

_finalize_dp_transfers

Return receiver-side transfer for a parameter that is not TP-sharded.

_determine_source_ranks_for_dst_param

Route to dimension-specific planner based on parameter sharding type.

build_centralized_reshard_plan

Centralized planning: Rank 0 builds complete plan for all ranks, then scatters.

Data

API

core.resharding.planner.logger

‘getLogger(…)’

core.resharding.planner._sort_ops_by_dst_offset(ops, dim)

Sort transfer ops by destination offset on the sharded dimension.
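As an illustration only, the sort can be sketched as below. It assumes each op is a (src_rank, src_slices, dst_slices) tuple, matching the return type of the planners documented further down; the helper name and the choice to return a new list rather than sort in place are assumptions, not the module's actual code.

```python
def sort_ops_by_dst_offset(ops, dim):
    """Return ops ordered by the start of their destination slice on `dim`.

    Hypothetical stand-in for the private helper; each op is assumed to be
    (src_rank, src_slices, dst_slices).
    """
    def dst_start(op):
        _src_rank, _src_slices, dst_slices = op
        start = dst_slices[dim].start
        # slice(None) has no explicit start; treat it as offset 0.
        return 0 if start is None else start

    return sorted(ops, key=dst_start)


ops = [
    (1, (slice(0, 4), slice(None)), (slice(4, 8), slice(None))),
    (0, (slice(0, 4), slice(None)), (slice(0, 4), slice(None))),
]
sorted_ops = sort_ops_by_dst_offset(ops, dim=0)
```

After sorting, the op that writes destination rows 0..4 comes before the op that writes rows 4..8.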

core.resharding.planner._build_descriptors_for_param(
src_metadata: core.resharding.utils.ParameterMetadata,
dst_metadata: core.resharding.utils.ParameterMetadata,
) → list[core.resharding.utils.ShardingDescriptor]

Construct sharding descriptors (currently TP only) for this parameter based on its actual layout. The TP descriptor is guarded by a size-conservation check so that replicated tensors are not misclassified as sharded.
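One plausible form of the size-conservation guard is sketched below: a parameter is only treated as TP-sharded if the local size on the partition dimension, multiplied by the TP world size, gives the same full size on both sides. The function name and its arguments are hypothetical; the real check may use other fields of ParameterMetadata.

```python
def tp_size_is_conserved(src_local_shape, dst_local_shape, dim,
                         src_tp_size, dst_tp_size):
    """Hypothetical guard: True if both layouts imply the same full size on `dim`."""
    if len(src_local_shape) != len(dst_local_shape):
        return False
    src_full = src_local_shape[dim] * src_tp_size
    dst_full = dst_local_shape[dim] * dst_tp_size
    return src_full == dst_full


# Sharded weight: 512 rows on each of 2 ranks versus 256 rows on each of 4 ranks.
sharded_ok = tp_size_is_conserved((512, 64), (256, 64), 0, 2, 4)

# Replicated tensor: the same local size on both sides, but different TP sizes,
# so the implied full sizes disagree and no TP descriptor should be built.
replicated_ok = tp_size_is_conserved((64,), (64,), 0, 2, 4)
```

Note that when the source and destination TP sizes are equal, this check alone cannot tell a replicated tensor from a sharded one, so it can only serve as a guard alongside other layout information.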

core.resharding.planner._emit_lcm_block_ops(
*,
param_name: str,
src_shape: tuple[int, ...],
dst_shape: tuple[int, ...],
dim: int,
src_world: int,
dst_world: int,
src_stride: int,
dst_stride: int,
full_block_len: int,
dst_local_rank: int,
src_dim_ranks: list[int],
src_block_offset: int,
dst_block_offset: int,
block_label: str,
ops: list,
) → None

Emit (src_rank, src_slice, dst_slice) ops for one LCM-tiled block.

Used both by the single-block stride-aware TP planner and by the per-block loop of the block-interleaved planner.
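The idea behind LCM tiling can be sketched as follows, under simplifying assumptions: the block is sharded contiguously (stride 1), its length is divisible by lcm(src_world, dst_world), and block offsets are zero. The function below is a hypothetical illustration, not the module's implementation; it returns the ops instead of appending to a caller-supplied list. It requires Python 3.9+ for math.lcm.

```python
from math import lcm


def lcm_tile_ops(full_block_len, src_world, dst_world, dst_local_rank,
                 src_dim_ranks, dim=0, ndim=1):
    """Emit (src_rank, src_slices, dst_slices) for one destination rank.

    Assumes contiguous (stride-1) sharding of a single block.
    """
    tiles = lcm(src_world, dst_world)
    if full_block_len % tiles != 0:
        raise ValueError("block length must be divisible by lcm of world sizes")
    tile = full_block_len // tiles
    src_shard = full_block_len // src_world
    dst_shard = full_block_len // dst_world

    ops = []
    for k in range(tiles):
        g0 = k * tile  # global start of micro-tile k
        if g0 // dst_shard != dst_local_rank:
            continue  # this tile belongs to another destination rank
        src_idx = g0 // src_shard          # owner within the source TP group
        src_lo = g0 - src_idx * src_shard  # offset inside the source shard
        dst_lo = g0 - dst_local_rank * dst_shard
        src_slices = tuple(slice(src_lo, src_lo + tile) if d == dim else slice(None)
                           for d in range(ndim))
        dst_slices = tuple(slice(dst_lo, dst_lo + tile) if d == dim else slice(None)
                           for d in range(ndim))
        ops.append((src_dim_ranks[src_idx], src_slices, dst_slices))
    return ops


# TP=2 -> TP=3 over 12 elements, destination rank 1 owns global [4, 8).
ops_2_to_3 = lcm_tile_ops(12, 2, 3, dst_local_rank=1, src_dim_ranks=[0, 1])
```

Because every micro-tile lies entirely inside one source shard and one destination shard, each tile maps to exactly one op with no further splitting.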

core.resharding.planner._tp_block_layout(
param_name: str,
src_metadata: core.resharding.utils.ParameterMetadata,
dst_metadata: core.resharding.utils.ParameterMetadata,
descriptor: core.resharding.utils.ShardingDescriptor,
src_shape: tuple[int, ...],
dst_shape: tuple[int, ...],
) → list[tuple[int, int, int, int, int, str]]

Compute the per-block layout for a TP transfer.

Returns a list of (src_offset, dst_offset, full_block_len, src_stride, dst_stride, label) tuples that the LCM micro-tiler iterates over.

  • Plain TP (no partition_sizes): single block covering the full partition dim with the descriptor’s strides.

  • Block-interleaved TP (partition_sizes present, e.g. Mamba in_proj): one block per packed component, each independently sharded with stride=1.
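The two cases above can be sketched as follows. This is a hedged illustration with a hypothetical signature: it assumes partition_sizes holds the full (unsharded) length of each packed component, and that src_offset and dst_offset are offsets into each rank's local shard. The real function reads these values from the metadata and descriptor objects, and its labels may differ.

```python
def tp_block_layout(full_dim_len, src_world, dst_world,
                    partition_sizes=None, src_stride=1, dst_stride=1):
    """Return (src_offset, dst_offset, full_block_len, src_stride, dst_stride, label)."""
    if not partition_sizes:
        # Plain TP: one block spanning the whole partition dim.
        return [(0, 0, full_dim_len, src_stride, dst_stride, "full")]

    blocks = []
    src_off = dst_off = 0
    for i, size in enumerate(partition_sizes):
        # Each packed component is sharded independently with stride 1.
        blocks.append((src_off, dst_off, size, 1, 1, f"block{i}"))
        src_off += size // src_world  # local length of this component on a source rank
        dst_off += size // dst_world  # local length on a destination rank
    return blocks


plain = tp_block_layout(12, src_world=2, dst_world=4)
packed = tp_block_layout(12, src_world=2, dst_world=4, partition_sizes=[8, 4])
```

In the packed case the second component starts at local offset 4 on a source rank (8 / 2) but at local offset 2 on a destination rank (8 / 4), which is why the offsets are tracked separately.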

core.resharding.planner._plan_tp(
param_name: str,
src_metadata: core.resharding.utils.ParameterMetadata,
dst_metadata: core.resharding.utils.ParameterMetadata,
descriptors: list[core.resharding.utils.ShardingDescriptor],
my_global_rank: int,
) → list[tuple[int, tuple[slice, ...], tuple[slice, ...]]]

Plan TP transfers via LCM tiling, supporting both plain and block-interleaved TP.

The block layout is derived once by _tp_block_layout. The inner LCM micro-tile math (_emit_lcm_block_ops) is identical for both cases, so the single-block plain-TP path is just a special case of the multi-block partitioned path.

core.resharding.planner._finalize_dp_transfers(
param_name: str,
src_metadata: core.resharding.utils.ParameterMetadata,
dst_metadata: core.resharding.utils.ParameterMetadata,
my_global_rank: int,
) → list[tuple[int, tuple[slice, ...], tuple[slice, ...]]]

Return receiver-side transfer for a parameter that is not TP-sharded.

This is reached when we cannot build a TP sharding descriptor for the parameter (i.e., it is effectively replicated with respect to sharding). That is the case when the source and destination models have no TP, or when the parameter is replicated on all ranks (for example, layernorm parameters). If the source and destination DP groups match, we return a local full-tensor copy; otherwise we pick a source rank from the source DP group in a deterministic round-robin manner based on the receiver’s global rank for better load distribution.
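A minimal sketch of this selection rule is shown below. The function name, the explicit rank lists, and the use of the receiver's own rank as the source of the local copy are assumptions made for illustration; the real function derives the groups from the metadata objects.

```python
def pick_dp_source(src_dp_ranks, dst_dp_ranks, my_global_rank, shape):
    """Return a single full-tensor transfer for a replicated parameter."""
    full = tuple(slice(None) for _ in shape)
    if sorted(src_dp_ranks) == sorted(dst_dp_ranks):
        # Matching DP groups: the receiver already holds the data, so copy locally.
        return [(my_global_rank, full, full)]
    # Otherwise spread receivers over the source group deterministically.
    src_rank = src_dp_ranks[my_global_rank % len(src_dp_ranks)]
    return [(src_rank, full, full)]


local_copy = pick_dp_source([0, 1], [0, 1], my_global_rank=1, shape=(4,))
remote = [pick_dp_source([0, 1], [2, 3, 4, 5], r, shape=(4, 4))[0][0]
          for r in (2, 3, 4, 5)]
```

With two source ranks and four receivers, the modulo rule alternates between the sources, so each source serves two receivers.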

core.resharding.planner._determine_source_ranks_for_dst_param(
param_name: str,
src_metadata: core.resharding.utils.ParameterMetadata,
dst_metadata: core.resharding.utils.ParameterMetadata,
my_global_rank: int,
) → list[tuple[int, tuple[slice, ...], tuple[slice, ...]]]

Route to dimension-specific planner based on parameter sharding type.

core.resharding.planner.build_centralized_reshard_plan(
src_module: torch.nn.Module,
dst_module: torch.nn.Module,
num_experts: int = None,
group=None,
src_rank_offset: int = 0,
dst_rank_offset: int = 0,
) → core.resharding.utils.ReshardPlan

Centralized planning: Rank 0 builds complete plan for all ranks, then scatters.

Supports None for src_module and/or dst_module to enable non-collocated mode:

  • src_module=None: Rank doesn’t have source model (destination-only)

  • dst_module=None: Rank doesn’t have destination model (source-only)

  • Both provided: Rank has both models (collocated mode)

Each rank provides metadata only for the models it owns, including parallel group membership (tensor_parallel_group_ranks, expert_parallel_group_ranks, etc.). This metadata is sufficient for rank 0 to build correct transfer plans without requiring dummy models.
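The gather, plan, scatter flow can be illustrated with a single-process simulation. Everything here is hypothetical: the metadata is reduced to two flags, the plan is a plain dict rather than a ReshardPlan, and list indexing stands in for the collective calls. In a real job the gather and scatter steps would presumably be collectives such as torch.distributed.gather_object and torch.distributed.scatter_object_list, though this page does not state which calls the module uses.

```python
def build_plans_on_rank0(gathered):
    """Build one plan per rank from metadata gathered on rank 0.

    `gathered[r]` is the (toy) metadata that rank r contributed.
    """
    src_ranks = [r for r, md in enumerate(gathered) if md["has_src"]]
    plans = [{"send_to": [], "recv_from": None} for _ in gathered]
    for r, md in enumerate(gathered):
        if not md["has_dst"]:
            continue  # source-only rank: nothing to receive
        # Collocated ranks copy locally; others pick a source round-robin.
        src = r if md["has_src"] else src_ranks[r % len(src_ranks)]
        plans[r]["recv_from"] = src
        plans[src]["send_to"].append(r)
    return plans


# Non-collocated setup: ranks 0-1 hold only the source model, ranks 2-3 only
# the destination model (src_module=None / dst_module=None in the real API).
gathered = [
    {"has_src": True, "has_dst": False},
    {"has_src": True, "has_dst": False},
    {"has_src": False, "has_dst": True},
    {"has_src": False, "has_dst": True},
]
plans = build_plans_on_rank0(gathered)
my_plan_on_rank_2 = plans[2]  # what rank 2 would receive from the scatter
```

Since rank 0 sees every rank's metadata, it can fill in both the send and receive sides of each transfer consistently, which is why no rank needs a dummy model.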