core.resharding.refit#
Module Contents#
Classes#
- _PlanCacheKey: Cache key for reshard plans.
Functions#
- _get_config_tuple: Extract (TP, PP, EP, DP, expt_tp) sizes from a model core.
- _build_plan_cache_key: Build cache key for reshard plan.
- get_or_create_service: Get or create a cached CopyService instance for the given backend.
- clear_service_cache: Clear the cached refit services.
- clear_plan_cache: Clear the cached refit plans.
- clear_all_caches: Clear both service and plan caches.
- swap_model_weights: Orchestrate weight swap/refit.
- reshard_model_weights: Reshard and copy model weights from src_model to target_model.
Data#
API#
- core.resharding.refit.RefitBackendName#
None
- class core.resharding.refit._PlanCacheKey#
Cache key for reshard plans.
- rank: int#
None
- src_config: Optional[Tuple[int, int, int, int, int]]#
None
- dst_config: Optional[Tuple[int, int, int, int, int]]#
None
- num_experts: Optional[int]#
None
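The fields above suggest a hashable value object used to index a dict of cached plans. A minimal sketch of that pattern, using a hypothetical `PlanCacheKey` (not the real `_PlanCacheKey` implementation): a frozen dataclass is hashable as long as its fields are, so structurally equal keys hit the same cache entry.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

# Hypothetical stand-in for _PlanCacheKey: frozen=True makes instances
# immutable and hashable, so they can serve as dict keys.
@dataclass(frozen=True)
class PlanCacheKey:
    rank: int
    src_config: Optional[Tuple[int, int, int, int, int]]
    dst_config: Optional[Tuple[int, int, int, int, int]]
    num_experts: Optional[int]

cache = {}
key = PlanCacheKey(rank=0,
                   src_config=(2, 1, 1, 4, 1),
                   dst_config=(4, 1, 1, 2, 1),
                   num_experts=None)
cache[key] = "cached plan"

# A structurally identical key built later finds the same entry.
assert cache[PlanCacheKey(0, (2, 1, 1, 4, 1), (4, 1, 1, 2, 1), None)] == "cached plan"
```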
- core.resharding.refit._get_config_tuple(core)#
Extract (TP, PP, EP, DP, expt_tp) sizes from a model core.
- Returns:
Tuple of (TP, PP, EP, DP, expt_tp) sizes, or None if core is None.
TP: Tensor parallelism
PP: Pipeline parallelism
EP: Expert parallelism
DP: Data parallelism
expt_tp: Expert tensor parallelism
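A minimal sketch of the None-propagating extraction described above. The attribute names on the stand-in core object are assumptions for illustration; the real model core's attributes are not shown in this reference.

```python
from typing import Optional, Tuple

def get_config_tuple(core) -> Optional[Tuple[int, int, int, int, int]]:
    """Return (TP, PP, EP, DP, expt_tp), or None if core is None."""
    if core is None:
        return None  # non-collocated / idle ranks have no model core
    # Attribute names are stand-ins, not the real core API.
    return (core.tp, core.pp, core.ep, core.dp, core.expt_tp)

class FakeCore:
    tp, pp, ep, dp, expt_tp = 2, 4, 1, 8, 1

assert get_config_tuple(FakeCore()) == (2, 4, 1, 8, 1)
assert get_config_tuple(None) is None
```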
- core.resharding.refit._build_plan_cache_key(src_core, tgt_core, num_experts: Optional[int], group=None)#
Build cache key for reshard plan.
- Parameters:
src_core – Source model core (or None for non-collocated destination/idle ranks)
tgt_core – Target model core (or None for non-collocated source/idle ranks)
num_experts – Number of MoE experts (or None for non-MoE models)
group – Optional process group for rank query
- Returns:
Cache key that uniquely identifies this reshard configuration for this rank
- core.resharding.refit._service_cache: dict[str, core.resharding.copy_services.base.CopyService]#
None
- core.resharding.refit._plan_cache: dict[core.resharding.refit._PlanCacheKey, Any]#
None
- core.resharding.refit.get_or_create_service(backend: core.resharding.refit.RefitBackendName, group=None)#
Get or create a cached CopyService instance for the given backend.
This avoids expensive repeated allocations (especially for NVSHMEM buffers) when swap_model_weights is called multiple times with the same backend.
- Parameters:
backend – Backend name ("nccl", "gloo", or "nvshmem").
group – Optional process group for NCCL backend.
- core.resharding.refit.clear_service_cache()#
Clear the cached refit services.
Call this if you need to invalidate the cache, for example when reinitializing distributed state.
This properly finalizes services to free GPU buffers before clearing the cache.
- core.resharding.refit.clear_plan_cache()#
Clear the cached refit plans.
- core.resharding.refit.clear_all_caches()#
Clear both service and plan caches.
- core.resharding.refit.swap_model_weights(
    src_model: megatron.core.models.common.language_module.language_module.LanguageModule,
    target_model: megatron.core.models.common.language_module.language_module.LanguageModule,
    refit_method: Union[core.resharding.refit.RefitBackendName, core.resharding.copy_services.base.CopyService],
    group=None,
    src_rank_offset: int = 0,
    dst_rank_offset: int = 0,
)#
Orchestrate weight swap/refit.
refit_method can be:
a string backend name (one of the supported refit backends), or
a CopyService instance.
group: Optional process group for communication.
src_rank_offset / dst_rank_offset: Offsets applied to local process group ranks so that metadata contains globally unique rank IDs across independent torch.distributed worlds (e.g., separate training and inference clusters).
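The two accepted forms of refit_method imply a small dispatch step, sketched here with a stand-in CopyService class (the real resolution logic inside swap_model_weights is not shown in this reference).

```python
class CopyService:
    """Stand-in for core.resharding.copy_services.base.CopyService."""
    pass

_cache = {}

def resolve_service(refit_method):
    # A CopyService instance is used as-is; a string backend name is
    # looked up in (or added to) the service cache.
    if isinstance(refit_method, CopyService):
        return refit_method
    return _cache.setdefault(refit_method, CopyService())

svc = CopyService()
assert resolve_service(svc) is svc                          # instance passed through
assert resolve_service("nccl") is resolve_service("nccl")   # string resolves to cached service
```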
- core.resharding.refit.reshard_model_weights(
    src_model: megatron.core.models.common.language_module.language_module.LanguageModule,
    target_model: megatron.core.models.common.language_module.language_module.LanguageModule,
    service: core.resharding.copy_services.base.CopyService,
    group=None,
    src_rank_offset: int = 0,
    dst_rank_offset: int = 0,
)#
Reshard and copy model weights from src_model to target_model using service. Supports None for src_model and/or target_model to enable non-collocated mode:
(src_model, target_model): Both models present (collocated mode)
(src_model, None): Source rank - only sends data (non-collocated)
(None, target_model): Destination rank - only receives data (non-collocated)
(None, None): Idle rank - participates in collectives but has no transfers (non-collocated)
In non-collocated mode, metadata includes local rank positions within parallel groups, allowing the planner to correctly map between different process group configurations without requiring dummy models on every rank.
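The four role combinations listed above reduce to a simple presence check on the two model arguments. A hypothetical helper (not part of the module's public API) makes the mapping explicit:

```python
def rank_role(src_model, target_model) -> str:
    """Classify this rank's role from which models it holds."""
    if src_model is not None and target_model is not None:
        return "collocated"   # both models on this rank
    if src_model is not None:
        return "source"       # sends data only
    if target_model is not None:
        return "destination"  # receives data only
    return "idle"             # joins collectives, performs no transfers

# Strings stand in for LanguageModule instances in this sketch.
assert rank_role("src", "tgt") == "collocated"
assert rank_role("src", None) == "source"
assert rank_role(None, "tgt") == "destination"
assert rank_role(None, None) == "idle"
```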
- Parameters:
group – Optional process group for collective communication.
src_rank_offset / dst_rank_offset – Offsets for mapping local ranks to global ranks in independent torch.distributed worlds.