core.resharding.refit

Module Contents

Functions

get_or_create_service

Get or create a cached CopyService instance for the given backend.

clear_service_cache

Clear the cached refit services.

swap_model_weights

Orchestrate the weight swap (refit) from src_model to target_model.

reshard_model_weights

Reshard and copy model weights from src_model to target_model using service.

Data

API

core.resharding.refit.RefitBackendName

None

core.resharding.refit._service_cache: dict[str, core.resharding.copy_services.base.CopyService]

None

core.resharding.refit.get_or_create_service(
backend: core.resharding.refit.RefitBackendName,
) → core.resharding.copy_services.base.CopyService

Get or create a cached CopyService instance for the given backend.

This avoids expensive repeated allocations (especially for NVSHMEM buffers) when swap_model_weights is called multiple times with the same backend.

core.resharding.refit.clear_service_cache()

Clear the cached refit services.

Call this if you need to invalidate the cache, for example when reinitializing distributed state.

core.resharding.refit.swap_model_weights(
src_model: megatron.core.models.common.language_module.language_module.LanguageModule,
target_model: megatron.core.models.common.language_module.language_module.LanguageModule,
refit_method: Union[core.resharding.refit.RefitBackendName, core.resharding.copy_services.base.CopyService],
)

Orchestrate the weight swap (refit) from src_model to target_model.

  • refit_method can be:

    • a string backend name (one of the supported refit backends), or

    • a CopyService instance.
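One way to handle the two accepted forms of refit_method is to normalize the argument to a service object before doing any copying. The sketch below is an assumption about how such dispatch could look, not the library's code: `resolve_refit_method` is a hypothetical helper, `CopyService` is a stand-in class, and "nvshmem" is an assumed backend name.

```python
from typing import Dict, Union


class CopyService:
    """Hypothetical stand-in for the real CopyService base class."""

    def __init__(self, name: str) -> None:
        self.name = name


_service_cache: Dict[str, CopyService] = {}


def get_or_create_service(backend: str) -> CopyService:
    """Return a cached stand-in service for the given backend name."""
    if backend not in _service_cache:
        _service_cache[backend] = CopyService(backend)
    return _service_cache[backend]


def resolve_refit_method(refit_method: Union[str, CopyService]) -> CopyService:
    """Hypothetical helper: turn either accepted form into a service."""
    if isinstance(refit_method, CopyService):
        return refit_method  # caller supplied a ready-made service
    if isinstance(refit_method, str):
        return get_or_create_service(refit_method)  # look up by backend name
    raise TypeError(
        f"refit_method must be a backend name or CopyService, "
        f"got {type(refit_method).__name__}"
    )


custom = CopyService("custom")
from_instance = resolve_refit_method(custom)
from_name = resolve_refit_method("nvshmem")  # assumed backend name
```

Passing a backend name goes through the cache, while passing an instance leaves construction and lifetime of the service under the caller's control.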

core.resharding.refit.reshard_model_weights(
src_model: megatron.core.models.common.language_module.language_module.LanguageModule,
target_model: megatron.core.models.common.language_module.language_module.LanguageModule,
service: core.resharding.copy_services.base.CopyService,
)

Reshard and copy model weights from src_model to target_model using service.