core.dist_checkpointing.strategies.base#
Strategies base interfaces.
Module Contents#
Classes#
StrategyAction | Specifies save vs load and sharded vs common action. To be removed in future releases.
LoadStrategyBase | Base class for a load strategy. Requires implementing checks for compatibility with a given checkpoint version.
SaveStrategyBase | Base class for a save strategy. Requires defining a backend type and version of the saved format.
LoadShardedStrategy | Base class for load strategies to be removed in future releases.
SaveShardedStrategy | Base class for save strategies to be removed in future releases.
AsyncSaveShardedStrategy | Save strategy suitable for async save. To be removed in future releases.
Functions#
get_default_strategy | Retrieves a default strategy for a given action, backend and version.
Data#
logger
API#
- core.dist_checkpointing.strategies.base.logger#
‘getLogger(…)’
- class core.dist_checkpointing.strategies.base.StrategyAction(*args, **kwds)#
Bases: enum.Enum
Specifies save vs load and sharded vs common action. To be removed in future releases.
Initialization
- LOAD_COMMON#
‘load_common’
- LOAD_SHARDED#
‘load_sharded’
- SAVE_COMMON#
‘save_common’
- SAVE_SHARDED#
‘save_sharded’
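To make the two axes of the enum (save vs load, sharded vs common) concrete, here is a minimal sketch. The `StrategyAction` class below is a stand-in that only mirrors the member names and values documented above so the snippet runs on its own; `is_save` and `is_sharded` are hypothetical helpers, not part of the library.

```python
from enum import Enum


class StrategyAction(Enum):
    """Stand-in mirroring the documented members (not the library class)."""

    LOAD_COMMON = "load_common"
    LOAD_SHARDED = "load_sharded"
    SAVE_COMMON = "save_common"
    SAVE_SHARDED = "save_sharded"


def is_save(action: StrategyAction) -> bool:
    # Hypothetical helper: read the save/load axis from the member name.
    return action.name.startswith("SAVE_")


def is_sharded(action: StrategyAction) -> bool:
    # Hypothetical helper: read the sharded/common axis from the member name.
    return action.name.endswith("_SHARDED")


# Enum members can be looked up by their string value.
action = StrategyAction("load_sharded")
```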
- core.dist_checkpointing.strategies.base.get_default_strategy(action: core.dist_checkpointing.strategies.base.StrategyAction, backend: str, version: int)#
Retrieves a default strategy for a given action, backend and version.
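One plausible way to think about a lookup by action, backend and version is a registry keyed on that triple. The sketch below is an illustration under that assumption only: `_REGISTRY`, `register_strategy` and `lookup_default_strategy` are hypothetical names, and the library's actual storage layout and error type may differ.

```python
from enum import Enum
from typing import Dict, Tuple


class StrategyAction(Enum):
    """Stand-in with two of the documented members (not the library class)."""

    LOAD_SHARDED = "load_sharded"
    SAVE_SHARDED = "save_sharded"


# Hypothetical registry: (action, backend, version) -> strategy object.
_REGISTRY: Dict[Tuple[StrategyAction, str, int], object] = {}


def register_strategy(action: StrategyAction, backend: str, version: int, strategy: object) -> None:
    """Store a strategy under its (action, backend, version) key."""
    _REGISTRY[(action, backend, version)] = strategy


def lookup_default_strategy(action: StrategyAction, backend: str, version: int) -> object:
    """Return the registered strategy, or raise a descriptive error."""
    try:
        return _REGISTRY[(action, backend, version)]
    except KeyError:
        raise ValueError(
            f"no default strategy for action={action.value}, "
            f"backend={backend!r}, version={version}"
        ) from None


# "example_backend" and the placeholder string are made up for illustration.
register_strategy(StrategyAction.SAVE_SHARDED, "example_backend", 1, "save-strategy-placeholder")
```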
- class core.dist_checkpointing.strategies.base.LoadStrategyBase#
Bases: abc.ABC
Base class for a load strategy. Requires implementing checks for compatibility with a given checkpoint version.
Initialization
- abstractmethod check_backend_compatibility(loaded_backend)#
Verifies if this strategy is compatible with loaded_backend.
- abstractmethod check_version_compatibility(loaded_version)#
Verifies if this strategy is compatible with loaded_version.
- property can_handle_sharded_objects#
Returns whether or not this strategy can handle loading ShardedObjects.
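A subclass has to supply both compatibility checks. The following is a minimal sketch, not the library's code: `LoadStrategyBase` here is a stand-in that reproduces only the interface documented above, `ExampleLoadStrategy` and `"example_backend"` are hypothetical, and signalling incompatibility by raising `ValueError` (rather than returning a flag) is an assumption.

```python
from abc import ABC, abstractmethod


class LoadStrategyBase(ABC):
    """Stand-in reproducing the documented interface (not the library class)."""

    @abstractmethod
    def check_backend_compatibility(self, loaded_backend):
        """Verify compatibility with the backend recorded in the checkpoint."""

    @abstractmethod
    def check_version_compatibility(self, loaded_version):
        """Verify compatibility with the version recorded in the checkpoint."""

    @property
    def can_handle_sharded_objects(self):
        # Assumed default for this sketch.
        return False


class ExampleLoadStrategy(LoadStrategyBase):
    """Hypothetical strategy that accepts one backend and versions 1 and 2."""

    BACKEND = "example_backend"
    SUPPORTED_VERSIONS = (1, 2)

    def check_backend_compatibility(self, loaded_backend):
        # Assumption: incompatibility is reported by raising.
        if loaded_backend != self.BACKEND:
            raise ValueError(f"unsupported backend: {loaded_backend!r}")

    def check_version_compatibility(self, loaded_version):
        if loaded_version not in self.SUPPORTED_VERSIONS:
            raise ValueError(f"unsupported version: {loaded_version!r}")
```

Because both checks are abstract, instantiating the base class directly fails with `TypeError`; only subclasses that implement them can be created.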
- class core.dist_checkpointing.strategies.base.SaveStrategyBase(backend: str, version: int)#
Bases: abc.ABC
Base class for a save strategy. Requires defining a backend type and version of the saved format.
Initialization
- property can_handle_sharded_objects#
Returns whether or not this strategy can handle saving ShardedObjects.
- __str__()#
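A save strategy carries the backend name and format version it writes. The sketch below uses a stand-in `SaveStrategyBase` with the documented constructor signature; the attribute names, the `False` default for `can_handle_sharded_objects`, and the `__str__` format are assumptions made for illustration, and `ExampleSaveStrategy` is hypothetical.

```python
from abc import ABC


class SaveStrategyBase(ABC):
    """Stand-in with the documented constructor (not the library class)."""

    def __init__(self, backend: str, version: int):
        self.backend = backend
        self.version = version

    @property
    def can_handle_sharded_objects(self):
        # Assumed default for this sketch.
        return False

    def __str__(self):
        # Assumed format: class name followed by backend and version.
        return f"{self.__class__.__name__}({self.backend}, {self.version})"


class ExampleSaveStrategy(SaveStrategyBase):
    """Hypothetical strategy that fixes its own backend and version."""

    def __init__(self):
        super().__init__(backend="example_backend", version=1)

    @property
    def can_handle_sharded_objects(self):
        # This example opts in to saving ShardedObjects.
        return True


strategy = ExampleSaveStrategy()
```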
- class core.dist_checkpointing.strategies.base.LoadShardedStrategy#
Bases: core.dist_checkpointing.strategies.base.LoadStrategyBase
Base class for load strategies to be removed in future releases.
Initialization
- abstractmethod load(sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict, checkpoint_dir: Union[str, pathlib.Path])#
Load the sharded part of the checkpoint.
- abstractmethod load_tensors_metadata(checkpoint_dir: Union[str, pathlib.Path])#
Load tensors metadata from the checkpoint for ShardedTensors.
Returns a dictionary similar to a sharded state dict, except that the dictionary keys are the ShardedTensor keys themselves (in an actual sharded state dict, keys correspond to state dict keys).
Dict values are ShardedTensors without any data or sharding information, so the only useful information is each tensor's global shape and dtype.
- load_sharded_metadata(checkpoint_dir: Union[str, pathlib.Path])#
Load sharded metadata from the checkpoint for ShardedTensors and ShardedObjects.
Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply sharded keys (contrary to the actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors or ShardedObjects without any data and sharding.
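The difference between state dict keys and sharded keys, which applies to both metadata methods above, can be sketched as follows. `TensorMeta` is a hypothetical stand-in for a data-less ShardedTensor, and the key strings and `to_metadata` helper are invented for illustration; this is not how the library builds its metadata.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass
class TensorMeta:
    """Hypothetical stand-in for a ShardedTensor without data or sharding."""

    key: str
    global_shape: Tuple[int, ...]
    dtype: str


# A sharded state dict is nested and indexed by state dict keys.
sharded_state_dict = {
    "model": {"weight": TensorMeta("decoder.weight", (1024, 512), "float32")},
    "optimizer": {
        "exp_avg": TensorMeta("optimizer.decoder.weight.exp_avg", (1024, 512), "float32")
    },
}


def to_metadata(tree: dict) -> Dict[str, TensorMeta]:
    """Flatten a nested dict into {sharded key: metadata}."""
    out: Dict[str, TensorMeta] = {}
    for value in tree.values():
        if isinstance(value, dict):
            out.update(to_metadata(value))
        else:
            out[value.key] = value
    return out


# The metadata view is flat and indexed by the sharded keys.
metadata = to_metadata(sharded_state_dict)
```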
- abstractmethod remove_sharded_tensors(checkpoint_dir: Union[str, pathlib.Path], key_prefix: str)#
Remove all tensors whose key starts with key_prefix.
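The prefix rule can be illustrated on an in-memory mapping. This is only a sketch of the matching semantics: `remove_by_prefix` and the sample keys are hypothetical, and a real strategy would operate on the files under the checkpoint directory rather than on a dict.

```python
from typing import Dict


def remove_by_prefix(tensors: Dict[str, object], key_prefix: str) -> Dict[str, object]:
    """Return a copy without the entries whose key starts with key_prefix."""
    return {k: v for k, v in tensors.items() if not k.startswith(key_prefix)}


# Made-up keys; the integer values stand in for stored tensors.
stored = {
    "optimizer.state.exp_avg": 1,
    "optimizer.state.exp_avg_sq": 2,
    "model.weight": 3,
}
remaining = remove_by_prefix(stored, "optimizer.")
```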
- class core.dist_checkpointing.strategies.base.SaveShardedStrategy(backend: str, version: int)#
Bases: core.dist_checkpointing.strategies.base.SaveStrategyBase
Base class for save strategies to be removed in future releases.
Initialization
- abstractmethod save(sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict, checkpoint_dir: Union[str, pathlib.Path])#
Save the sharded part of the state dict.
- class core.dist_checkpointing.strategies.base.AsyncSaveShardedStrategy(backend: str, version: int)#
Bases: core.dist_checkpointing.strategies.base.SaveShardedStrategy
Save strategy suitable for async save. To be removed in future releases.
Initialization
- abstractmethod async_save(sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict, checkpoint_dir: Union[str, pathlib.Path])#
Perform preparation and return an AsyncRequest to the external caller.
- Parameters:
sharded_state_dict (ShardedStateDict) – sharded state dict to save
checkpoint_dir (Path) – checkpoint target directory
- Returns:
An AsyncRequest that represents the async save function and the finalization function. It is the caller's responsibility to actually schedule the async save.
- Return type:
AsyncRequest
- save(sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict, checkpoint_dir: Union[str, pathlib.Path])#
Each async strategy can be trivially used as a sync strategy.
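The relationship between `async_save` and `save` can be sketched as follows: the async method only prepares a request, and the sync method builds the same request and runs it immediately. The `AsyncRequest` class here is a hypothetical stand-in (its field names and `execute_sync` method are assumptions for this sketch), and `ExampleAsyncStrategy` is invented for illustration.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Tuple


@dataclass
class AsyncRequest:
    """Hypothetical stand-in: a save function plus finalization functions."""

    async_fn: Callable
    async_fn_args: Tuple
    finalize_fns: List[Callable] = field(default_factory=list)

    def execute_sync(self) -> None:
        # Run the save function, then every finalization step, in order.
        self.async_fn(*self.async_fn_args)
        for finalize in self.finalize_fns:
            finalize()


class ExampleAsyncStrategy:
    """Hypothetical strategy that records what it would have written."""

    def __init__(self):
        self.log: List[str] = []

    def async_save(self, sharded_state_dict, checkpoint_dir) -> AsyncRequest:
        # Preparation only: nothing is written until the caller schedules it.
        def write(state, path):
            self.log.append(f"wrote {len(state)} entries to {path}")

        def finalize():
            self.log.append("finalized")

        return AsyncRequest(write, (sharded_state_dict, checkpoint_dir), [finalize])

    def save(self, sharded_state_dict, checkpoint_dir) -> None:
        # Sync path: build the same request and execute it right away.
        self.async_save(sharded_state_dict, checkpoint_dir).execute_sync()


strategy = ExampleAsyncStrategy()
strategy.save({"a": 1, "b": 2}, "ckpt_dir")
```

Keeping the write and finalization steps inside the request lets a caller run them on another process or thread, while the sync path stays a one-line wrapper.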