# core.resharding.planner

## Module Contents

### Functions
| Function | Description |
|---|---|
| `_build_descriptors_for_param` | Construct sharding descriptors (currently TP) for this parameter based on its actual layout. The TP descriptor is guarded with a size-conservation check so we don’t mis-classify replicated tensors. |
| `_plan_multi_dim_lcm` | TP-only planner using LCM tiling to support strides on source/destination. |
| `_plan_block_interleaved` | Block-interleaved TP planner for parameters with `partition_sizes`. |
| `_finalize_dp_transfers` | Return the receiver-side transfer for a parameter that is not TP-sharded. |
| `_determine_source_ranks_for_dst_param` | Route to the dimension-specific planner based on the parameter’s sharding type. |
| `build_centralized_reshard_plan` | Centralized planning: rank 0 builds the complete plan for all ranks, then scatters it. |
### Data

### API
- core.resharding.planner.logger#
‘getLogger(…)’
- core.resharding.planner._build_descriptors_for_param(src_metadata: core.resharding.utils.ParameterMetadata, dst_metadata: core.resharding.utils.ParameterMetadata)

  Construct sharding descriptors (currently TP) for this parameter based on its actual layout. The TP descriptor is guarded with a size-conservation check so we don’t mis-classify replicated tensors.
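The size-conservation guard can be pictured with a short sketch. Everything below is illustrative: `local_shape`, `full_shape`, `tp_world_size`, and `partition_dim` are assumed names, not the actual `ParameterMetadata` fields.

```python
# Illustrative size-conservation guard; field names are hypothetical, not the
# actual ParameterMetadata attributes used by _build_descriptors_for_param.
import math

def tp_descriptor_is_plausible(local_shape, full_shape, tp_world_size, partition_dim):
    # The partitioned dim must scale by exactly the TP world size...
    if local_shape[partition_dim] * tp_world_size != full_shape[partition_dim]:
        return False
    # ...every other dim must match the full tensor...
    if any(l != f for d, (l, f) in enumerate(zip(local_shape, full_shape))
           if d != partition_dim):
        return False
    # ...and the total element count must be conserved across all TP shards.
    return math.prod(local_shape) * tp_world_size == math.prod(full_shape)
```

A replicated layernorm weight has `local_shape == full_shape`, so for any `tp_world_size > 1` the check fails, no TP descriptor is built, and the parameter falls through to `_finalize_dp_transfers`.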
- core.resharding.planner._plan_multi_dim_lcm(param_name: str, src_metadata: core.resharding.utils.ParameterMetadata, dst_metadata: core.resharding.utils.ParameterMetadata, descriptors: list[core.resharding.utils.ShardingDescriptor], my_global_rank: int)

  TP-only planner using LCM tiling to support strides on source/destination.

  - Requires exactly one TP descriptor.
  - Supports arbitrary integer strides (contiguous micro-tiles).
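To see what the LCM buys: when resharding a dimension of size N from `src_tp` shards to `dst_tp` shards, cutting the dimension into `lcm(src_tp, dst_tp)` micro-tiles guarantees every tile lies entirely inside exactly one source shard and one destination shard. The sketch below is a simplified, stride-free illustration; the helper name and tuple layout are invented for this example.

```python
# Simplified LCM-tiling sketch (no strides); each micro-tile maps to exactly
# one (src_rank, dst_rank) pair, so no transfer ever splits a tile.
import math

def lcm_micro_tiles(full_size: int, src_tp: int, dst_tp: int):
    n_tiles = math.lcm(src_tp, dst_tp)
    assert full_size % n_tiles == 0, "partition dim must divide the tile count"
    tile = full_size // n_tiles
    src_shard, dst_shard = full_size // src_tp, full_size // dst_tp
    plan = []
    for t in range(n_tiles):
        start = t * tile                                   # global tile offset
        src_rank, dst_rank = start // src_shard, start // dst_shard
        plan.append((src_rank, start - src_rank * src_shard,   # src-local offset
                     dst_rank, start - dst_rank * dst_shard,   # dst-local offset
                     tile))
    return plan
```

For example, `lcm_micro_tiles(12, 2, 3)` yields six 2-element tiles: source rank 0 (holding `[0, 6)`) feeds destination ranks 0 and 1, while source rank 1 feeds destination ranks 1 and 2. The actual planner generalizes this to the strided layouts mentioned above, keeping each micro-tile contiguous.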
- core.resharding.planner._plan_block_interleaved(param_name: str, src_metadata: core.resharding.utils.ParameterMetadata, dst_metadata: core.resharding.utils.ParameterMetadata, descriptors: list[core.resharding.utils.ShardingDescriptor], my_global_rank: int)

  Block-interleaved TP planner for parameters with `partition_sizes`.

  When a parameter packs multiple independently-sharded components of different sizes (e.g. Mamba `in_proj` packs z, x, B, C, dt), a simple contiguous concat produces the wrong layout. This function treats each block independently: it gathers (or scatters) each block across TP ranks before moving to the next block.

  `partition_sizes` lists the per-TP-rank block sizes along the partition dim. Block i occupies `[sum(sizes[:i]), sum(sizes[:i+1]))` in the local tensor on every TP rank. In the full (TP-gathered) tensor, block i occupies `[sum(full_sizes[:i]), sum(full_sizes[:i+1]))`, where `full_sizes[i] = sizes[i] * src_tp_world`.
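The offset arithmetic above is mechanical enough to sketch directly; the helper name `block_ranges` is invented for this illustration.

```python
# Sketch of the per-block range computation described above; illustrative only.
from itertools import accumulate

def block_ranges(partition_sizes: list[int], src_tp_world: int):
    """[start, end) of each block, in the local shard and in the full tensor."""
    local_edges = [0, *accumulate(partition_sizes)]
    full_edges = [0, *accumulate(s * src_tp_world for s in partition_sizes)]
    return [((local_edges[i], local_edges[i + 1]),   # range in every local shard
             (full_edges[i], full_edges[i + 1]))     # range in the gathered tensor
            for i in range(len(partition_sizes))]
```

With `partition_sizes = [4, 2]` and `src_tp_world = 2`, block 0 occupies `[0, 4)` locally but `[0, 8)` in the gathered tensor, and block 1 occupies `[4, 6)` locally but `[8, 12)` gathered, which is exactly why concatenating whole local shards contiguously scrambles the packed layout.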
- core.resharding.planner._finalize_dp_transfers(param_name: str, src_metadata: core.resharding.utils.ParameterMetadata, dst_metadata: core.resharding.utils.ParameterMetadata, my_global_rank: int)

  Return the receiver-side transfer for a parameter that is not TP-sharded.

  This is reached when we cannot build a TP sharding descriptor for the parameter (i.e., it is effectively replicated with respect to sharding): either neither the source nor the destination model uses TP, or the parameter is replicated on all ranks (e.g. a layernorm weight). If the source and destination DP groups match, we return a local full-tensor copy; otherwise we pick a source rank from the source DP group in a deterministic round-robin manner, based on the receiver’s global rank, for better load distribution.
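The round-robin pick reduces to a single modulo; a minimal sketch with a hypothetical helper name:

```python
# Deterministic round-robin source selection for replicated parameters
# (sketch; the real code may index the source DP group differently).
def pick_source_rank(src_dp_group_ranks: list[int], my_global_rank: int) -> int:
    return src_dp_group_ranks[my_global_rank % len(src_dp_group_ranks)]
```

Because the choice depends only on the receiver’s global rank, every rank computes it deterministically without coordination, and receivers spread roughly evenly over the source DP group.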
- core.resharding.planner._determine_source_ranks_for_dst_param(param_name: str, src_metadata: core.resharding.utils.ParameterMetadata, dst_metadata: core.resharding.utils.ParameterMetadata, my_global_rank: int)

  Route to the dimension-specific planner based on the parameter’s sharding type.
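Putting the planners together, the routing plausibly looks like the sketch below. The branch conditions are assumptions inferred from the docstrings in this module, not the exact implementation.

```python
# Assumed routing sketch based on the planners documented above.
def plan_param(param_name, src_md, dst_md, my_global_rank):
    descriptors = _build_descriptors_for_param(src_md, dst_md)
    if not descriptors:
        # No TP descriptor survived the size-conservation guard:
        # treat the parameter as replicated and plan a DP-only transfer.
        return _finalize_dp_transfers(param_name, src_md, dst_md, my_global_rank)
    if getattr(src_md, "partition_sizes", None):
        # Packed multi-block parameter (e.g. Mamba in_proj): per-block planning.
        return _plan_block_interleaved(param_name, src_md, dst_md,
                                       descriptors, my_global_rank)
    # Plain TP-sharded parameter: LCM micro-tiling handles any TP resize.
    return _plan_multi_dim_lcm(param_name, src_md, dst_md,
                               descriptors, my_global_rank)
```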
- core.resharding.planner.build_centralized_reshard_plan(src_module: torch.nn.Module, dst_module: torch.nn.Module, num_experts: int = None, group=None, src_rank_offset: int = 0, dst_rank_offset: int = 0)

  Centralized planning: rank 0 builds the complete plan for all ranks, then scatters it.

  Supports `None` for `src_module` and/or `dst_module` to enable non-collocated mode:

  - `src_module=None`: this rank doesn’t have the source model (destination-only).
  - `dst_module=None`: this rank doesn’t have the destination model (source-only).
  - Both provided: this rank has both models (collocated mode).

  Each rank provides metadata only for the models it owns, including parallel group membership (`tensor_parallel_group_ranks`, `expert_parallel_group_ranks`, etc.). This metadata is sufficient for rank 0 to build correct transfer plans without requiring dummy models.
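The gather/plan/scatter shape of this function can be sketched with standard `torch.distributed` object collectives. This is a skeleton under stated assumptions: `build_plan_for_rank` is a hypothetical stand-in for rank 0’s planning work, and the real function exchanges the `ParameterMetadata` structures described above rather than opaque objects.

```python
# Illustrative gather/plan/scatter skeleton; build_plan_for_rank is hypothetical.
import torch.distributed as dist

def centralized_plan_sketch(local_metadata, build_plan_for_rank, group=None):
    world, rank = dist.get_world_size(group), dist.get_rank(group)
    # Every rank ships metadata for the models it owns to rank 0 (source-only
    # and destination-only ranks in non-collocated mode contribute one side).
    gathered = [None] * world if rank == 0 else None
    dist.gather_object(local_metadata, gathered, dst=0, group=group)
    # Rank 0 builds one plan per rank from the global view...
    plans = ([build_plan_for_rank(gathered, r) for r in range(world)]
             if rank == 0 else None)
    # ...and scatters so each rank receives only its own transfers.
    my_plan = [None]
    dist.scatter_object_list(my_plan, plans, src=0, group=group)
    return my_plan[0]
```

Scattering per-rank plans means each rank receives only its own transfers, so no rank ever needs a dummy copy of a model it doesn’t own.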