core.resharding.refit

Module Contents

Classes

_PlanCacheKey

Cache key for reshard plans.

Functions

_get_config_tuple

Extract (TP, PP, EP, DP, expt_tp) sizes from a model core, memoized on the core.

_build_plan_cache_key

Build cache key for reshard plan.

get_or_create_service

Get or create a cached CopyService instance for the given backend.

clear_service_cache

Clear the cached refit services.

clear_plan_cache

Clear the cached refit plans.

clear_all_caches

Clear both service and plan caches.

_unwrap_model_cores

Extract (src_core, tgt_core, num_experts) from model arguments.

_build_or_get_plan

Return the cached reshard plan, building it (collectively) if not yet cached.

_needs_mxfp8_conversion

Check if a model uses FlashInfer MXFP8 inference and needs weight conversion.

_setup_mxfp8_transform_on_plan

Detect MXFP8 needs and attach a transform to the plan if required.

prepare_swap_model_weights

Pre-build and cache the reshard plan and any format-conversion transforms.

swap_model_weights

Orchestrate weight swap/refit.

_harmonize_buffer_dtypes

Bring destination persistent-buffer dtypes into agreement with source.

reshard_model_weights

Reshard and copy model weights from src_model to target_model using service.

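Data

API

core.resharding.refit.RefitBackendName

Type alias for the refit backend names accepted by get_or_create_service and swap_model_weights (“nccl”, “gloo”, or “nvshmem”).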

class core.resharding.refit._PlanCacheKey

Cache key for reshard plans.

rank: int

src_config: Optional[Tuple[int, int, int, int, int]]

dst_config: Optional[Tuple[int, int, int, int, int]]

num_experts: Optional[int]

src_rank_offset: int = 0

dst_rank_offset: int = 0
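A plan cache only works if equal keys hash and compare equal. The sketch below is a stand-in, not the real class: PlanCacheKeySketch is a hypothetical name, modelling it as a frozen dataclass is an assumption, and the (TP, PP, EP, DP, expt_tp) values in the example are made up. It mirrors the documented fields and shows that a key rebuilt on a later refit hits the same cache entry, while a different rank offset does not.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

# Hypothetical stand-in mirroring the documented _PlanCacheKey fields.
# frozen=True (with the default eq=True) makes instances immutable and hashable.
@dataclass(frozen=True)
class PlanCacheKeySketch:
    rank: int
    src_config: Optional[Tuple[int, int, int, int, int]]
    dst_config: Optional[Tuple[int, int, int, int, int]]
    num_experts: Optional[int]
    src_rank_offset: int = 0
    dst_rank_offset: int = 0

plan_cache: dict = {}
key = PlanCacheKeySketch(rank=0, src_config=(2, 1, 1, 4, 1),
                         dst_config=(1, 1, 1, 8, 1), num_experts=None)
plan_cache[key] = "plan-A"

# A key built later from the same values (e.g. on the next refit) is equal and hashes the same.
same = PlanCacheKeySketch(0, (2, 1, 1, 4, 1), (1, 1, 1, 8, 1), None)
# A different rank offset yields a distinct key, so it would need its own plan.
shifted = PlanCacheKeySketch(0, (2, 1, 1, 4, 1), (1, 1, 1, 8, 1), None, src_rank_offset=8)
```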

core.resharding.refit._get_config_tuple(
core,
) -> Optional[Tuple[int, int, int, int, int]]

Extract (TP, PP, EP, DP, expt_tp) sizes from a model core, memoized on the core.

Process-group sizes don’t change after initialization, so the result is cached on the core object itself. This avoids repeated get_process_group_ranks calls on the hot path, where each refit looks up the cache key two or three times.
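Memoizing on the object means the expensive query runs once per core no matter how many lookups follow. A minimal sketch of the pattern, assuming a hypothetical attribute name (_refit_config_tuple) and a stand-in _compute_sizes in place of the real process-group queries; returning None for a missing core is an assumption based on the Optional return type.

```python
from types import SimpleNamespace

calls = 0

def _compute_sizes(core):
    # Hypothetical stand-in for the expensive process-group queries.
    global calls
    calls += 1
    return (core.tp, core.pp, core.ep, core.dp, core.expt_tp)

def get_config_tuple_sketch(core):
    if core is None:
        return None  # assumption: ranks without this model report no config
    cached = getattr(core, "_refit_config_tuple", None)  # hypothetical attribute name
    if cached is None:
        cached = _compute_sizes(core)
        setattr(core, "_refit_config_tuple", cached)  # memoize on the core itself
    return cached

core = SimpleNamespace(tp=2, pp=1, ep=1, dp=4, expt_tp=1)
# A single refit may look the key up several times; only the first call does real work.
sizes = [get_config_tuple_sketch(core) for _ in range(3)]
```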

core.resharding.refit._build_plan_cache_key(
src_core,
tgt_core,
num_experts: Optional[int],
group=None,
src_rank_offset: int = 0,
dst_rank_offset: int = 0,
) -> core.resharding.refit._PlanCacheKey

Build the cache key for a reshard plan.

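core.resharding.refit._service_cache: dict[str, core.resharding.copy_services.base.CopyService]

Module-level cache of CopyService instances used by get_or_create_service; emptied by clear_service_cache.

core.resharding.refit._plan_cache: dict[core.resharding.refit._PlanCacheKey, Any]

Module-level cache of reshard plans keyed by _PlanCacheKey; emptied by clear_plan_cache.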

core.resharding.refit.get_or_create_service(
backend: core.resharding.refit.RefitBackendName,
group=None,
) -> core.resharding.copy_services.base.CopyService

Get or create a cached CopyService instance for the given backend.

This avoids expensive repeated allocations (especially for NVSHMEM buffers) when swap_model_weights is called multiple times with the same backend.

Parameters:
  • backend – Backend name (“nccl”, “gloo”, or “nvshmem”).

  • group – Optional process group for the NCCL backend.
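The get-or-create pattern means the expensive construction (for example, NVSHMEM buffer allocation) happens only on the first request for a backend. A minimal sketch, assuming a hypothetical DummyCopyService in place of CopyService and a cache keyed by backend name only (the real key may differ; in this sketch a later call with a different group reuses the first service).

```python
from typing import Dict, Optional

class DummyCopyService:
    """Hypothetical stand-in for CopyService; counts how often it is constructed."""
    created = 0

    def __init__(self, backend: str, group: Optional[object] = None):
        DummyCopyService.created += 1
        self.backend = backend
        self.group = group

_service_cache: Dict[str, DummyCopyService] = {}

def get_or_create_service_sketch(backend: str, group=None) -> DummyCopyService:
    service = _service_cache.get(backend)
    if service is None:
        # Expensive allocation happens only on the first request for this backend.
        service = DummyCopyService(backend, group)
        _service_cache[backend] = service
    return service
```

Repeated refits with the same backend therefore share one service object.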

core.resharding.refit.clear_service_cache()

Clear the cached refit services.

Call this when the cache must be invalidated, for example when reinitializing distributed state. Each cached service is closed with close() before it is dropped, so backends that own GPU buffers (NVSHMEM) release them cleanly.

core.resharding.refit.clear_plan_cache()

Clear the cached refit plans.

core.resharding.refit.clear_all_caches()

Clear both service and plan caches.

core.resharding.refit._unwrap_model_cores(src_model, target_model)

Extract (src_core, tgt_core, num_experts) from model arguments.

Handles list-wrapped modules and None (non-collocated) models. Fills in missing DP groups from Megatron’s parallel state on the source.

Returns:

(src_core, tgt_core, num_experts)
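The unwrapping can be pictured as below. This is a sketch under stated assumptions, not the module's logic: it assumes list-wrapped models use their first element, that wrapper layers expose the inner model as .module, and that num_experts comes from config.num_moe_experts on whichever core is present. It omits the DP-group filling described above.

```python
from types import SimpleNamespace
from typing import Any, Optional

def _unwrap_one(model: Any) -> Optional[Any]:
    if model is None:
        return None                  # non-collocated: this rank holds no such model
    if isinstance(model, (list, tuple)):
        model = model[0]             # assumption: use the first model chunk
    while hasattr(model, "module"):  # assumption: wrappers expose the inner model as .module
        model = model.module
    return model

def unwrap_model_cores_sketch(src_model, target_model):
    src_core = _unwrap_one(src_model)
    tgt_core = _unwrap_one(target_model)
    ref = src_core if src_core is not None else tgt_core
    # Assumption: the expert count is read from the config of whichever core exists.
    num_experts = getattr(getattr(ref, "config", None), "num_moe_experts", None)
    return src_core, tgt_core, num_experts

core = SimpleNamespace(config=SimpleNamespace(num_moe_experts=8))
result = unwrap_model_cores_sketch([SimpleNamespace(module=core)], None)
```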

core.resharding.refit._build_or_get_plan(
src_core,
tgt_core,
num_experts,
group,
src_rank_offset,
dst_rank_offset,
)

Return the cached reshard plan, building it (collectively) if not yet cached.

All participating ranks must call this simultaneously when the plan is not yet cached, because build_centralized_reshard_plan uses collective communication.
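Because the builder is collective, a cache miss must happen on every rank at the same time; a rank that skips the call leaves the others blocked. The sketch below simulates four ranks as threads, with a threading.Barrier standing in for the collective inside the hypothetical collective_build (a stand-in for build_centralized_reshard_plan). Each simulated rank keeps its own cache, as separate processes would.

```python
import threading
from typing import Any, Callable, Dict, Hashable

def build_or_get_plan_sketch(cache: Dict[Hashable, Any], key: Hashable,
                             build: Callable[[], Any]) -> Any:
    plan = cache.get(key)
    if plan is None:
        # `build` is collective: every participating rank must reach this branch together.
        plan = build()
        cache[key] = plan
    return plan

WORLD = 4
barrier = threading.Barrier(WORLD)
builds = []

def collective_build(rank):
    barrier.wait(timeout=5)  # times out here (a real collective would hang) if a rank is missing
    builds.append(rank)
    return f"plan@rank{rank}"

caches = [dict() for _ in range(WORLD)]  # one module-level cache per simulated process
results = {}

def worker(rank):
    results[rank] = build_or_get_plan_sketch(caches[rank], "key",
                                             lambda: collective_build(rank))

threads = [threading.Thread(target=worker, args=(r,)) for r in range(WORLD)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Once cached, later lookups are purely local and need no coordination.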

core.resharding.refit._needs_mxfp8_conversion(model) -> bool

Check if a model uses FlashInfer MXFP8 inference and needs weight conversion.
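prepare_swap_model_weights documents the trigger as config.transformer_impl == 'inference_optimized' and config.fp8_recipe == 'mxfp8'. A minimal sketch of such a check follows; reaching the config as model.config is an assumption, and the real function may test additional conditions.

```python
from types import SimpleNamespace

def needs_mxfp8_conversion_sketch(model) -> bool:
    if model is None:
        return False
    config = getattr(model, "config", None)  # assumption: config is reachable as model.config
    return (
        getattr(config, "transformer_impl", None) == "inference_optimized"
        and getattr(config, "fp8_recipe", None) == "mxfp8"
    )
```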

core.resharding.refit._setup_mxfp8_transform_on_plan(plan, target_model) -> None

Detect MXFP8 needs and attach a transform to the plan if required.

If the target_model uses an inference-optimized layer spec with MXFP8, this function:

  1. Computes which params are eligible for MXFP8 conversion.

  2. Quantizes the target model’s decoder weights to FlashInfer MXFP8Tensor (creating persistent buffers whose addresses are later captured by CUDA graphs).

  3. Builds an MXFP8ReshardTransform and attaches it to plan.transform.

Idempotent: skips re-setup if plan.transform is already populated.
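Idempotence matters here because the quantized buffers are persistent and their addresses are captured by CUDA graphs; repeating the setup would presumably allocate fresh buffers the graphs do not point to. A minimal sketch of the guard, with hypothetical needs_conversion and make_transform callables standing in for the detection, quantization, and transform-building steps.

```python
from types import SimpleNamespace

setup_calls = 0

def setup_transform_on_plan_sketch(plan, target_model, needs_conversion, make_transform):
    """Attach a transform to `plan` at most once; later calls are no-ops."""
    global setup_calls
    if getattr(plan, "transform", None) is not None:
        return  # already set up on an earlier call
    if target_model is None or not needs_conversion(target_model):
        return  # nothing to convert on this rank
    setup_calls += 1
    # Stand-in for: compute eligible params, quantize into persistent buffers, build transform.
    plan.transform = make_transform(target_model)

plan = SimpleNamespace(transform=None)
for _ in range(2):
    setup_transform_on_plan_sketch(plan, object(), lambda m: True, lambda m: "mxfp8-transform")
```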

core.resharding.refit.prepare_swap_model_weights(
src_model: megatron.core.models.common.language_module.language_module.LanguageModule,
target_model: megatron.core.models.common.language_module.language_module.LanguageModule,
group=None,
src_rank_offset: int = 0,
dst_rank_offset: int = 0,
)

Pre-build and cache the reshard plan and any format-conversion transforms.

Call this during initialization, while the models are still in their native (BF16) format and before any weight-format conversion (e.g., MXFP8). The plan is stored in the same module-level cache that swap_model_weights uses, so subsequent swap_model_weights calls reuse it without inspecting named_parameters() again.

If the target_model uses an inference-optimized layer spec with MXFP8 (config.transformer_impl == 'inference_optimized' and config.fp8_recipe == 'mxfp8'), this function also:

  • computes which parameters are eligible for MXFP8 conversion,

  • quantizes the target decoder weights to persistent FlashInfer MXFP8Tensor buffers (whose addresses are later baked into CUDA graphs),

  • creates an MXFP8ReshardTransform that subsequent swap_model_weights calls use automatically.

Callers do not need to know about MXFP8: the transform is created and cached transparently.

All participating ranks must call this simultaneously, because the plan builder uses collective communication internally.

Parameters:
  • src_model – Source model, or None if this rank only receives weights.

  • target_model – Target model, or None if this rank only sends weights.

  • group – Optional process group for collective communication.

  • src_rank_offset – Rank offset for source (training) workers.

  • dst_rank_offset – Rank offset for destination (inference) workers.

core.resharding.refit.swap_model_weights(
src_model: megatron.core.models.common.language_module.language_module.LanguageModule,
target_model: megatron.core.models.common.language_module.language_module.LanguageModule,
refit_method: Union[core.resharding.refit.RefitBackendName, core.resharding.copy_services.base.CopyService],
group=None,
src_rank_offset: int = 0,
dst_rank_offset: int = 0,
transform: Optional[core.resharding.transforms.ReshardTransform] = None,
)

Orchestrate a weight swap (refit) from src_model to target_model.

If transform is not explicitly provided, the function automatically uses any MXFP8ReshardTransform that was created and cached by a prior prepare_swap_model_weights call for the same model pair. This makes MXFP8 handling transparent to callers.

Parameters:
  • refit_method – a string backend name (one of the supported refit backends) or a CopyService instance.

  • group – Optional process group for communication.

  • src_rank_offset / dst_rank_offset – Offsets applied to local process-group ranks so that the metadata contains globally unique rank IDs across independent torch.distributed worlds.

  • transform – Optional ReshardTransform for custom format conversion. If None, the cached transform (from prepare_swap_model_weights) is used automatically when the receiver needs MXFP8 conversion.
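Two pieces of argument resolution are described above: refit_method may be a backend name or a ready service, and an explicit transform takes precedence over the one cached on the plan. A minimal sketch of both rules, using a hypothetical FakeCopyService and a plan object exposing a transform attribute (as _setup_mxfp8_transform_on_plan describes); the real function's internals may differ.

```python
from types import SimpleNamespace
from typing import Any, Dict, Optional, Union

class FakeCopyService:
    """Hypothetical stand-in for CopyService."""
    def __init__(self, backend: str):
        self.backend = backend

_services: Dict[str, FakeCopyService] = {}

def resolve_service(refit_method: Union[str, FakeCopyService]) -> FakeCopyService:
    if isinstance(refit_method, str):
        # A backend name goes through a cached get-or-create path.
        service = _services.get(refit_method)
        if service is None:
            service = FakeCopyService(refit_method)
            _services[refit_method] = service
        return service
    return refit_method  # an explicit service instance is used as-is

def resolve_transform(explicit: Optional[Any], cached_plan: Optional[Any]) -> Optional[Any]:
    # An explicitly passed transform wins; otherwise fall back to the plan's cached one.
    if explicit is not None:
        return explicit
    return getattr(cached_plan, "transform", None)
```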

core.resharding.refit._harmonize_buffer_dtypes(plan, src_core, tgt_core, group=None)

Bring destination persistent-buffer dtypes into agreement with source.

Some buffers (notably the MoE router expert_bias) are upcast to fp32 by _maintain_float32_expert_bias on the trainer’s first forward pass, while the freshly built inference model still holds them in bf16 from the Float16Module wrap. The reshard send/recv path is dtype-strict (sending fp32 bytes into a bf16 receive buffer corrupts the data), so each destination buffer must be converted to the source dtype before the transfer.

The canonical dtype map is collected once via all_gather_object and cached on the plan. Subsequent refits reuse the cached map and only perform the per-buffer dtype check and replacement, with no collective.
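The split between a one-time collective and a cheap per-refit local pass can be sketched as follows. Everything here is a stand-in: dtypes are plain strings instead of torch.dtype, all_gather is a hypothetical callable returning one dict per rank (in place of torch.distributed.all_gather_object), cast replaces a buffer with a converted copy, and buffer_dtype_map is a hypothetical attribute name on the plan.

```python
from types import SimpleNamespace
from typing import Callable, Dict, List

def harmonize_buffer_dtypes_sketch(plan, src_buffers, tgt_buffers,
                                   all_gather: Callable[[Dict[str, str]], List[Dict[str, str]]],
                                   cast: Callable) -> None:
    dtype_map = getattr(plan, "buffer_dtype_map", None)  # hypothetical attribute
    if dtype_map is None:
        # One-time collective: each rank contributes the dtypes of the source buffers it owns.
        local = {name: buf.dtype for name, buf in (src_buffers or {}).items()}
        dtype_map = {}
        for part in all_gather(local):
            dtype_map.update(part)
        plan.buffer_dtype_map = dtype_map
    # Every refit, purely local: replace destination buffers whose dtype disagrees.
    for name, buf in list((tgt_buffers or {}).items()):
        want = dtype_map.get(name)
        if want is not None and buf.dtype != want:
            tgt_buffers[name] = cast(buf, want)

gathers = []

def fake_all_gather(local):
    gathers.append(local)
    return [local]  # single-rank world

cast = lambda buf, dtype: SimpleNamespace(dtype=dtype)
plan = SimpleNamespace()
src = {"router.expert_bias": SimpleNamespace(dtype="fp32")}
tgt = {"router.expert_bias": SimpleNamespace(dtype="bf16")}
harmonize_buffer_dtypes_sketch(plan, src, tgt, fake_all_gather, cast)
```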

core.resharding.refit.reshard_model_weights(
src_model: megatron.core.models.common.language_module.language_module.LanguageModule,
target_model: megatron.core.models.common.language_module.language_module.LanguageModule,
service: core.resharding.copy_services.base.CopyService,
group=None,
src_rank_offset: int = 0,
dst_rank_offset: int = 0,
transform: Optional[core.resharding.transforms.ReshardTransform] = None,
)

Reshard and copy model weights from src_model to target_model using service.

Supports None for src_model and/or target_model to enable non-collocated mode:

  • (src_model, target_model): Both models present (collocated mode)

  • (src_model, None): Source rank - only sends data (non-collocated)

  • (None, target_model): Destination rank - only receives data (non-collocated)

  • (None, None): Idle rank - participates in collectives but has no transfers (non-collocated)
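The four cases above reduce to a simple mapping from which arguments are None to the role a rank plays. The hypothetical helper below only classifies; the point it illustrates is that even an idle rank still calls reshard_model_weights so the collectives line up.

```python
from typing import Any, Optional

def refit_role(src_model: Optional[Any], target_model: Optional[Any]) -> str:
    has_src = src_model is not None
    has_tgt = target_model is not None
    if has_src and has_tgt:
        return "collocated"   # both models on this rank
    if has_src:
        return "source"       # only sends data
    if has_tgt:
        return "destination"  # only receives data
    return "idle"             # no transfers, but still joins the collectives
```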

Parameters:
  • group – Optional process group for collective communication.

  • src_rank_offset / dst_rank_offset – Offsets for mapping local ranks to global ranks in independent torch.distributed worlds.

  • transform – Optional ReshardTransform for custom format conversion.