core.dist_checkpointing.strategies.base
Strategies base interfaces.
Module Contents
Classes
StrategyAction | Specifies save vs load and sharded vs common action.
LoadStrategyBase | Base class for a load strategy. Requires implementing checks for compatibility with a given checkpoint version.
SaveStrategyBase | Base class for a save strategy. Requires defining a backend type and version of the saved format.
LoadCommonStrategy | Load strategy for common (non-sharded) objects.
LoadShardedStrategy | Load strategy for sharded tensors.
SaveCommonStrategy | Save strategy for common (non-sharded) objects.
SaveShardedStrategy | Save strategy for sharded tensors.
AsyncSaveShardedStrategy | Save strategy suitable for async save.
Functions
get_default_strategy | Retrieves a default strategy for a given action, backend and version.
register_default_strategy | Adds a given strategy to the registry of default strategies.
Data
default_strategies
async_calls
API
- class core.dist_checkpointing.strategies.base.StrategyAction(*args, **kwds)
Bases: enum.Enum
Specifies save vs load and sharded vs common action.
- LOAD_COMMON = 'load_common'
- LOAD_SHARDED = 'load_sharded'
- SAVE_COMMON = 'save_common'
- SAVE_SHARDED = 'save_sharded'
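The members are ordinary `enum.Enum` members, so a member can be recovered from its string value (for example, one read from a config file). The sketch below re-declares the enum locally so it runs without Megatron-Core installed; only the member names and values are taken from this page.

```python
from enum import Enum


# Local stand-in mirroring the four documented members and their values;
# the real class lives in core.dist_checkpointing.strategies.base.
class StrategyAction(Enum):
    LOAD_COMMON = 'load_common'
    LOAD_SHARDED = 'load_sharded'
    SAVE_COMMON = 'save_common'
    SAVE_SHARDED = 'save_sharded'


# Look a member up by its string value.
action = StrategyAction('save_sharded')
print(action is StrategyAction.SAVE_SHARDED)  # True
print(action.value)                           # save_sharded
```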
- core.dist_checkpointing.strategies.base.default_strategies: DefaultDict[str, dict[tuple, Any]] = defaultdict(…)
- core.dist_checkpointing.strategies.base.async_calls = AsyncCallsQueue(…)
- core.dist_checkpointing.strategies.base.get_default_strategy(action: core.dist_checkpointing.strategies.base.StrategyAction, backend: str, version: int)
Retrieves a default strategy for a given action, backend and version.
- core.dist_checkpointing.strategies.base.register_default_strategy(action: core.dist_checkpointing.strategies.base.StrategyAction, backend: str, version: int, strategy: Union[SaveStrategyBase, LoadStrategyBase])
Adds a given strategy to the registry of default strategies.
- Parameters:
action (StrategyAction) – specifies save/load and sharded/common
backend (str) – backend that the strategy becomes a default for
version (int) – version that the strategy becomes a default for
strategy (SaveStrategyBase, LoadStrategyBase) – strategy to register
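The two functions act as a registry lookup keyed by action, backend and version. The following is a minimal self-contained sketch of that pattern, not the library's code. It assumes the registry is laid out as the `DefaultDict[str, dict[tuple, Any]]` annotation of `default_strategies` suggests (action value outside, `(backend, version)` inside); the backend name `'my_backend'`, the placeholder strategy string, and the `ValueError` raised on a miss are all hypothetical.

```python
from collections import defaultdict
from enum import Enum
from typing import Any, DefaultDict


class StrategyAction(Enum):
    # Subset of the documented members, enough for this sketch.
    LOAD_SHARDED = 'load_sharded'
    SAVE_SHARDED = 'save_sharded'


# Assumed layout: outer key is the action's string value,
# inner key is the (backend, version) pair.
default_strategies: DefaultDict[str, dict[tuple, Any]] = defaultdict(dict)


def register_default_strategy(action, backend, version, strategy):
    # Overwrites any strategy previously registered under the same key.
    default_strategies[action.value][(backend, version)] = strategy


def get_default_strategy(action, backend, version):
    try:
        return default_strategies[action.value][(backend, version)]
    except KeyError as e:
        # Hypothetical error type; the real function may raise something else.
        raise ValueError(f'No default strategy for {action}, {backend}, {version}') from e


register_default_strategy(StrategyAction.SAVE_SHARDED, 'my_backend', 1, 'my-save-strategy')
print(get_default_strategy(StrategyAction.SAVE_SHARDED, 'my_backend', 1))  # my-save-strategy
```

Keying on `action.value` rather than the enum member keeps the outer keys plain strings, consistent with the `str` key type in the annotation.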
- class core.dist_checkpointing.strategies.base.LoadStrategyBase
Bases: abc.ABC
Base class for a load strategy. Requires implementing checks for compatibility with a given checkpoint version.
- abstractmethod check_backend_compatibility(loaded_backend)
Verifies whether this strategy is compatible with loaded_backend.
- abstractmethod check_version_compatibility(loaded_version)
Verifies whether this strategy is compatible with loaded_version.
- property can_handle_sharded_objects
Returns whether or not this strategy can handle loading ShardedObjects.
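This page states only that subclasses must implement the two compatibility checks; it does not say how incompatibility is reported. The sketch below assumes a strategy signals incompatibility by raising an exception. It uses a local stand-in for `LoadStrategyBase` and a hypothetical `MyLoadStrategy` that accepts a hypothetical `'my_backend'` backend at versions 1 and 2.

```python
from abc import ABC, abstractmethod


class LoadStrategyBase(ABC):
    # Local stand-in with only the two abstract checks documented above.
    @abstractmethod
    def check_backend_compatibility(self, loaded_backend): ...

    @abstractmethod
    def check_version_compatibility(self, loaded_version): ...


class MyLoadStrategy(LoadStrategyBase):
    # Hypothetical strategy; raising ValueError on a mismatch is an assumption.
    SUPPORTED_VERSIONS = {1, 2}

    def check_backend_compatibility(self, loaded_backend):
        if loaded_backend != 'my_backend':
            raise ValueError(f'Unsupported backend: {loaded_backend}')

    def check_version_compatibility(self, loaded_version):
        if loaded_version not in self.SUPPORTED_VERSIONS:
            raise ValueError(f'Unsupported version: {loaded_version}')


strategy = MyLoadStrategy()
strategy.check_backend_compatibility('my_backend')  # passes silently
strategy.check_version_compatibility(2)             # passes silently
```

Because both methods are abstract, instantiating the base class directly fails with `TypeError`, which forces every concrete load strategy to state what it can read.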
- class core.dist_checkpointing.strategies.base.SaveStrategyBase(backend: str, version: int)
Bases: abc.ABC
Base class for a save strategy. Requires defining a backend type and version of the saved format.
- property can_handle_sharded_objects
Returns whether or not this strategy can handle saving ShardedObjects.
- __str__()
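A save strategy is constructed with the backend name and format version it writes. The sketch below is a local stand-in, not the library class: the `(backend, version)` constructor comes from the signature above, while the `False` default for `can_handle_sharded_objects` and the `__str__` format are assumptions, since this page documents neither. `MySaveStrategy` and `'my_backend'` are hypothetical.

```python
from abc import ABC


class SaveStrategyBase(ABC):
    # Local stand-in; the (backend, version) constructor is documented above.
    def __init__(self, backend: str, version: int):
        self.backend = backend
        self.version = version

    @property
    def can_handle_sharded_objects(self):
        # Assumed default: strategies opt in by overriding this property.
        return False

    def __str__(self):
        # Assumed format; the real __str__ output is not documented here.
        return f'{self.__class__.__name__}({self.backend}, {self.version})'


class MySaveStrategy(SaveStrategyBase):
    # Hypothetical strategy that can also save ShardedObjects.
    @property
    def can_handle_sharded_objects(self):
        return True


s = MySaveStrategy('my_backend', 1)
print(s)                             # MySaveStrategy(my_backend, 1)
print(s.can_handle_sharded_objects)  # True
```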
- class core.dist_checkpointing.strategies.base.LoadCommonStrategy
Bases: core.dist_checkpointing.strategies.base.LoadStrategyBase
Load strategy for common (non-sharded) objects.
- abstractmethod load_common(checkpoint_dir: Union[str, pathlib.Path])
Load the common part of the checkpoint.
- abstractmethod load_sharded_objects(sharded_objects_state_dict: core.dist_checkpointing.mapping.ShardedStateDict, checkpoint_dir: Union[str, pathlib.Path])
Load sharded objects from the checkpoint.
- load_sharded_metadata(checkpoint_dir: Union[str, pathlib.Path])
Load just the metadata from the checkpoint.
- class core.dist_checkpointing.strategies.base.LoadShardedStrategy
Bases: core.dist_checkpointing.strategies.base.LoadStrategyBase
Load strategy for sharded tensors.
- abstractmethod load(sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict, checkpoint_dir: Union[str, pathlib.Path])
Load the sharded part of the checkpoint.
- abstractmethod load_tensors_metadata(checkpoint_dir: Union[str, pathlib.Path])
Load tensors metadata from the checkpoint for ShardedTensors.
Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (unlike actual sharded state dicts, where keys correspond to state dict keys).
Dict values are ShardedTensors without any data or sharding, so the only useful information they carry is the tensor's global shape and dtype.
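The key difference is easy to miss: a sharded state dict is keyed by state dict keys, while the metadata dict is keyed by the ShardedTensor's own checkpoint key. The sketch below only illustrates the shape of the return value; it does not read a checkpoint. `ShardedTensorStub` is a hypothetical stand-in reduced to a key, global shape, dtype and data, and the state dict is flat, whereas real sharded state dicts may be nested.

```python
from dataclasses import dataclass, replace
from typing import Any, Dict, Optional, Tuple


@dataclass
class ShardedTensorStub:
    # Hypothetical stand-in for ShardedTensor, reduced to the fields discussed above.
    key: str                      # checkpoint (ShardedTensor) key
    global_shape: Tuple[int, ...]
    dtype: str
    data: Optional[Any] = None    # None in metadata-only entries


def metadata_view(sharded_state_dict: Dict[str, ShardedTensorStub]) -> Dict[str, ShardedTensorStub]:
    # Re-key by ShardedTensor key and drop the data, keeping shape and dtype.
    return {t.key: replace(t, data=None) for t in sharded_state_dict.values()}


state = {'model.decoder.weight': ShardedTensorStub('decoder.weight', (8, 4), 'float32', data=[0.0] * 32)}
meta = metadata_view(state)
print(list(meta))  # ['decoder.weight']
print(meta['decoder.weight'].global_shape, meta['decoder.weight'].data)  # (8, 4) None
```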
- load_sharded_metadata(checkpoint_dir: Union[str, pathlib.Path])
Load sharded metadata from the checkpoint for ShardedTensors and ShardedObjects.
Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply sharded keys (unlike actual sharded state dicts, where keys correspond to state dict keys).
Dict values are ShardedTensors or ShardedObjects without any data or sharding.
- abstractmethod remove_sharded_tensors(checkpoint_dir: Union[str, pathlib.Path], key_prefix: str)
Remove all tensors whose key starts with key_prefix.
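Taken literally, "starts with" is a plain string-prefix test, so a prefix without a trailing separator can match more tensors than intended. The sketch below models the selection rule in memory only; `remove_sharded_tensors_in_memory` is a hypothetical helper, and the list of keys stands in for what a checkpoint directory would contain.

```python
def remove_sharded_tensors_in_memory(tensor_keys, key_prefix):
    # Hypothetical in-memory model: returns the keys that survive removal.
    return [k for k in tensor_keys if not k.startswith(key_prefix)]


keys = ['decoder.layers.1.weight', 'decoder.layers.10.weight', 'embedding.weight']

# A plain string prefix also matches 'decoder.layers.10...'.
print(remove_sharded_tensors_in_memory(keys, 'decoder.layers.1'))
# ['embedding.weight']

# Ending the prefix with the separator restricts removal to layer 1.
print(remove_sharded_tensors_in_memory(keys, 'decoder.layers.1.'))
# ['decoder.layers.10.weight', 'embedding.weight']
```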
- class core.dist_checkpointing.strategies.base.SaveCommonStrategy(backend: str, version: int)
Bases: core.dist_checkpointing.strategies.base.SaveStrategyBase
Save strategy for common (non-sharded) objects.
- abstractmethod save_common(common_state_dict: core.dist_checkpointing.mapping.StateDict, checkpoint_dir: Union[str, pathlib.Path])
Save the common part of the state dict.
- abstractmethod save_sharded_objects(sharded_objects_state_dict: core.dist_checkpointing.mapping.ShardedStateDict, checkpoint_dir: Union[str, pathlib.Path])
Save sharded objects from the state dict.
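`save_common` and `load_common` form a pair: whatever one writes into `checkpoint_dir`, the other must read back. The sketch below shows that round trip with hypothetical pickle-based classes; they are not Megatron-Core implementations. They skip the `(backend, version)` constructor and the sharded-object methods for brevity, and the file name `common.pkl` is an assumption.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict, Union

COMMON_FILE = 'common.pkl'  # hypothetical file name


class PickleSaveCommonStrategy:
    # Hypothetical SaveCommonStrategy-like class.
    def save_common(self, common_state_dict: Dict[str, Any], checkpoint_dir: Union[str, Path]):
        with (Path(checkpoint_dir) / COMMON_FILE).open('wb') as f:
            pickle.dump(common_state_dict, f)


class PickleLoadCommonStrategy:
    # Hypothetical LoadCommonStrategy-like counterpart reading the same file.
    def load_common(self, checkpoint_dir: Union[str, Path]) -> Dict[str, Any]:
        with (Path(checkpoint_dir) / COMMON_FILE).open('rb') as f:
            return pickle.load(f)


with tempfile.TemporaryDirectory() as ckpt_dir:
    PickleSaveCommonStrategy().save_common({'iteration': 100, 'args': {'lr': 0.1}}, ckpt_dir)
    loaded = PickleLoadCommonStrategy().load_common(ckpt_dir)

print(loaded)  # {'iteration': 100, 'args': {'lr': 0.1}}
```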
- class core.dist_checkpointing.strategies.base.SaveShardedStrategy(backend: str, version: int)
Bases: core.dist_checkpointing.strategies.base.SaveStrategyBase
Save strategy for sharded tensors.
- abstractmethod save(sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict, checkpoint_dir: Union[str, pathlib.Path])
Save the sharded part of the state dict.
- class core.dist_checkpointing.strategies.base.AsyncSaveShardedStrategy(backend: str, version: int)
Bases: core.dist_checkpointing.strategies.base.SaveShardedStrategy
Save strategy suitable for async save.
- abstractmethod async_save(sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict, checkpoint_dir: Union[str, pathlib.Path])
Perform preparation and return an AsyncRequest to the external caller.
- Parameters:
sharded_state_dict (ShardedStateDict) – sharded state dict to save
checkpoint_dir (Path) – checkpoint target directory
- Returns:
Represents the async save function and the finalization function. It is the caller's responsibility to actually schedule the async save.
- Return type:
AsyncRequest
- save(sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict, checkpoint_dir: Union[str, pathlib.Path])
Each async strategy can be trivially used as a sync strategy.
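The split between `async_save` (prepare and return a request) and the caller (schedule it) is what lets `save` reuse the same code path synchronously. The sketch below illustrates both paths in a self-contained way under stated assumptions. `AsyncRequestStub`, with its `async_fn`, `async_fn_args` and `finalize_fns` fields, is a hypothetical stand-in for AsyncRequest. `InMemoryAsyncSaveStrategy` is a hypothetical strategy that "saves" into a dict. A plain `threading.Thread` substitutes for whatever scheduler (such as `async_calls`) the real code uses.

```python
import threading
from dataclasses import dataclass, field
from typing import Callable, List, Tuple


@dataclass
class AsyncRequestStub:
    # Hypothetical stand-in: a function to run in the background plus
    # finalization callbacks to run once it has finished.
    async_fn: Callable
    async_fn_args: Tuple
    finalize_fns: List[Callable] = field(default_factory=list)


class InMemoryAsyncSaveStrategy:
    # Hypothetical AsyncSaveShardedStrategy-like class that "saves" into a dict.
    def __init__(self):
        self.storage = {}
        self.finalized = False

    def async_save(self, sharded_state_dict, checkpoint_dir):
        # Preparation: snapshot the data now; the write itself is deferred.
        snapshot = dict(sharded_state_dict)

        def write(ckpt_dir, data):
            self.storage[ckpt_dir] = data

        def finalize():
            self.finalized = True

        return AsyncRequestStub(write, (checkpoint_dir, snapshot), [finalize])

    def save(self, sharded_state_dict, checkpoint_dir):
        # Sync path: build the request and run it to completion immediately.
        request = self.async_save(sharded_state_dict, checkpoint_dir)
        request.async_fn(*request.async_fn_args)
        for fn in request.finalize_fns:
            fn()


# Async path: the caller schedules the request itself.
strategy = InMemoryAsyncSaveStrategy()
req = strategy.async_save({'w': 1}, '/ckpt/iter_1')
worker = threading.Thread(target=req.async_fn, args=req.async_fn_args)
worker.start()
worker.join()  # a training loop would do other work before joining
for fn in req.finalize_fns:
    fn()

# Sync path on a fresh instance.
sync_strategy = InMemoryAsyncSaveStrategy()
sync_strategy.save({'w': 2}, '/ckpt/iter_2')
print(strategy.storage, sync_strategy.storage)
```

Snapshotting inside `async_save` matters in the async path: the caller may keep mutating its state dict while the background write runs.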