core.resharding.refit
Module Contents
Functions
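| Function | Description |
| --- | --- |
| get_or_create_service | Get or create a cached CopyService instance for the given backend. |
| clear_service_cache | Clear the cached refit services. |
| swap_model_weights | Orchestrate weight swap/refit. |
| reshard_model_weights | Reshard and copy model weights from src_model to target_model using service. |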
Data
API
- core.resharding.refit.RefitBackendName
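Type of the string backend name (one of the supported refit backends) accepted by get_or_create_service and swap_model_weights.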
- core.resharding.refit._service_cache: dict[str, core.resharding.copy_services.base.CopyService]
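Module-level cache of CopyService instances keyed by backend name; filled by get_or_create_service and emptied by clear_service_cache.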
- core.resharding.refit.get_or_create_service(backend: core.resharding.refit.RefitBackendName)
Get or create a cached CopyService instance for the given backend.
This avoids expensive repeated allocations (especially for NVSHMEM buffers) when swap_model_weights is called multiple times with the same backend.
- core.resharding.refit.clear_service_cache()
Clear the cached refit services.
Call this if you need to invalidate the cache, for example when reinitializing distributed state.
- core.resharding.refit.swap_model_weights(
    src_model: megatron.core.models.common.language_module.language_module.LanguageModule,
    target_model: megatron.core.models.common.language_module.language_module.LanguageModule,
    refit_method: Union[core.resharding.refit.RefitBackendName, core.resharding.copy_services.base.CopyService],
  )
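Orchestrate a weight swap (refit) from src_model into target_model.
refit_method can be either:
  - a string backend name (one of the supported refit backends), which is resolved through the service cache (see get_or_create_service), or
  - a CopyService instance, which is used as given.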
- core.resharding.refit.reshard_model_weights(
    src_model: megatron.core.models.common.language_module.language_module.LanguageModule,
    target_model: megatron.core.models.common.language_module.language_module.LanguageModule,
    service: core.resharding.copy_services.base.CopyService,
  )
Reshard and copy model weights from src_model to target_model using service.
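This page does not say how reshard_model_weights partitions the work, so the following is only a toy illustration of the general idea of resharding: when a tensor split across one number of ranks must be re-split across a different number, each target shard is assembled from the overlapping slices of the source shards. It assumes a 1-D tensor divided evenly into contiguous shards and performs the copies in-process on lists; in Megatron the actual transfers would go through the CopyService. The names shard_bounds, plan_reshard, and apply_plan are hypothetical.

```python
from typing import List, Tuple

# (src_rank, src_lo, src_hi, dst_rank, dst_lo):
# copy src_shards[src_rank][src_lo:src_hi] into dst_shards[dst_rank] starting at dst_lo
CopyOp = Tuple[int, int, int, int, int]


def shard_bounds(length: int, parts: int) -> List[Tuple[int, int]]:
    """Global [start, end) range owned by each rank for an even contiguous split."""
    assert length % parts == 0, "toy example assumes even division"
    size = length // parts
    return [(r * size, (r + 1) * size) for r in range(parts)]


def plan_reshard(length: int, src_parts: int, dst_parts: int) -> List[CopyOp]:
    """List the slice copies needed to go from src_parts shards to dst_parts shards."""
    ops: List[CopyOp] = []
    for d, (d_start, d_end) in enumerate(shard_bounds(length, dst_parts)):
        for s, (s_start, s_end) in enumerate(shard_bounds(length, src_parts)):
            lo, hi = max(d_start, s_start), min(d_end, s_end)
            if lo < hi:  # source shard s and target shard d overlap on [lo, hi)
                ops.append((s, lo - s_start, hi - s_start, d, lo - d_start))
    return ops


def apply_plan(src_shards: List[List[float]], ops: List[CopyOp],
               dst_parts: int, length: int) -> List[List[float]]:
    """Execute the copy plan locally, producing the target shards."""
    size = length // dst_parts
    dst_shards = [[0.0] * size for _ in range(dst_parts)]
    for s, s_lo, s_hi, d, d_lo in ops:
        chunk = src_shards[s][s_lo:s_hi]
        dst_shards[d][d_lo:d_lo + len(chunk)] = chunk
    return dst_shards


full = [float(i) for i in range(12)]
src_shards = [full[a:b] for a, b in shard_bounds(12, 4)]  # 4 source ranks, 3 values each
ops = plan_reshard(12, src_parts=4, dst_parts=3)           # 3 target ranks, 4 values each
dst_shards = apply_plan(src_shards, ops, dst_parts=3, length=12)
print(len(ops), dst_shards[1])  # 6 [4.0, 5.0, 6.0, 7.0]
```

Separating the plan (which slices move where) from its execution mirrors the split between deciding the resharding layout and handing the transfers to a copy backend.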