core.dist_checkpointing.strategies.common#

Common strategies.

Module Contents#

Classes#

TorchCommonSaveStrategy

Common save strategy leveraging native torch save/load.

TorchCommonLoadStrategy

Common load strategy leveraging native torch save/load.

Functions#

register_default_common_strategies

Register default common strategies.

Data#

COMMON_STATE_FNAME

logger

API#

core.dist_checkpointing.strategies.common.COMMON_STATE_FNAME#

‘common.pt’
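
The constant names the file that holds the common (non-sharded) part of a checkpoint. The following is a minimal sketch of how such a file name can be joined with a checkpoint directory, assuming the file sits directly inside that directory. The constant is redefined locally so the snippet runs on its own, and `common_state_path` is a hypothetical helper, not part of this module.

```python
from pathlib import Path
from typing import Union

# Redefined locally for illustration; mirrors the documented value.
COMMON_STATE_FNAME = "common.pt"


def common_state_path(checkpoint_dir: Union[str, Path]) -> Path:
    """Hypothetical helper: location of the common state file in a checkpoint dir."""
    return Path(checkpoint_dir) / COMMON_STATE_FNAME


# Accepts both str and Path, matching the Union[str, pathlib.Path] annotations.
print(common_state_path("ckpt/iter_0000100").name)
```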

core.dist_checkpointing.strategies.common.logger#

‘getLogger(…)’

core.dist_checkpointing.strategies.common.register_default_common_strategies()#

Register default common strategies.

class core.dist_checkpointing.strategies.common.TorchCommonSaveStrategy(backend: str, version: int)#

Bases: megatron.core.dist_checkpointing.strategies.base.SaveCommonStrategy

Common save strategy leveraging native torch save/load.

Initialization

save_common(
common_state_dict: megatron.core.dist_checkpointing.mapping.StateDict,
checkpoint_dir: Union[str, pathlib.Path],
)#

Save common part of the state dict.

save_sharded_objects(
sharded_objects_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: Union[str, pathlib.Path],
)#

Save sharded objects from the state dict.

can_handle_sharded_objects()#

This strategy can handle ShardedObjects.

class core.dist_checkpointing.strategies.common.TorchCommonLoadStrategy#

Bases: megatron.core.dist_checkpointing.strategies.base.LoadCommonStrategy

Common load strategy leveraging native torch save/load.

load_common(checkpoint_dir: Union[str, pathlib.Path])#

Load common (non-sharded) objects state dict from the checkpoint.

Parameters:

checkpoint_dir (Union[str, Path]) – checkpoint directory

Returns:

state dict with non-sharded objects from the checkpoint

Return type:

StateDict
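
A save/load round trip of the common state can be sketched as follows. This is an illustration only: it uses `pickle` as a stand-in for native torch save/load so that it runs without PyTorch, and `save_common_sketch` and `load_common_sketch` are hypothetical names rather than the strategy methods themselves. It also assumes the common state lives in a single file named by `COMMON_STATE_FNAME` inside the checkpoint directory.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict, Union

COMMON_STATE_FNAME = "common.pt"  # mirrors the documented constant


def save_common_sketch(
    common_state_dict: Dict[str, Any], checkpoint_dir: Union[str, Path]
) -> None:
    """Stand-in for save_common: write one file holding the common state."""
    with open(Path(checkpoint_dir) / COMMON_STATE_FNAME, "wb") as f:
        pickle.dump(common_state_dict, f)


def load_common_sketch(checkpoint_dir: Union[str, Path]) -> Dict[str, Any]:
    """Stand-in for load_common: read the common state back as a dict."""
    with open(Path(checkpoint_dir) / COMMON_STATE_FNAME, "rb") as f:
        return pickle.load(f)


# Round trip through a temporary directory acting as the checkpoint directory.
with tempfile.TemporaryDirectory() as ckpt_dir:
    save_common_sketch({"iteration": 100, "args": {"lr": 1e-4}}, ckpt_dir)
    restored = load_common_sketch(ckpt_dir)

print(restored["iteration"])
```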

load_sharded_objects(
sharded_objects_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: Union[str, pathlib.Path],
)#

Replace every ShardedObject in the given state dict with the value loaded from the checkpoint.

Parameters:
  • sharded_objects_state_dict (ShardedStateDict) – sharded state dict defining what objects should be loaded.

  • checkpoint_dir (Union[str, Path]) – checkpoint directory

Returns:

Nothing; the sharded state dict is modified in place.

Return type:

None
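
The in-place contract (placeholders are replaced inside the dict that was passed in, and the call itself returns None) can be illustrated with a small sketch. `PlaceholderObject` is a hypothetical stand-in for ShardedObject, and the `loaded` mapping stands in for values read from the checkpoint; neither reflects the library's actual data layout.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class PlaceholderObject:
    """Hypothetical stand-in for a ShardedObject: carries only a lookup key."""
    key: str


def fill_placeholders(state_dict: Dict[str, Any], loaded: Dict[str, Any]) -> None:
    """Replace every placeholder in a nested dict, in place; returns None."""
    # Iterate over a snapshot of the items so assignment during the loop is safe.
    for name, value in list(state_dict.items()):
        if isinstance(value, PlaceholderObject):
            state_dict[name] = loaded[value.key]
        elif isinstance(value, dict):
            fill_placeholders(value, loaded)


state = {"optimizer": {"step": PlaceholderObject("optimizer.step")}, "epoch": 3}
result = fill_placeholders(state, {"optimizer.step": 1200})

# The caller keeps using `state`; there is no return value to capture.
print(state, result)
```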

load_sharded_metadata(
checkpoint_dir: Union[str, pathlib.Path],
) → megatron.core.dist_checkpointing.mapping.ShardedStateDict#

property can_handle_sharded_objects#

This strategy can handle ShardedObjects.
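
According to the signatures on this page, the save strategy exposes `can_handle_sharded_objects` as a method while the load strategy exposes it as a property, so the two are queried differently. A minimal sketch of that difference, using hypothetical toy classes rather than the real strategies:

```python
class SaveSketch:
    """Hypothetical class exposing the capability as a method."""

    def can_handle_sharded_objects(self) -> bool:
        return True


class LoadSketch:
    """Hypothetical class exposing the capability as a property."""

    @property
    def can_handle_sharded_objects(self) -> bool:
        return True


save_ok = SaveSketch().can_handle_sharded_objects()  # method: needs the call
load_ok = LoadSketch().can_handle_sharded_objects    # property: plain attribute access
print(save_ok, load_ok)
```

Forgetting the parentheses on the method form yields a bound method object, which is always truthy, so a check such as `if strategy.can_handle_sharded_objects:` would pass regardless of the actual return value.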

check_backend_compatibility(loaded_version)#

check_version_compatibility(loaded_version)#