core.dist_checkpointing.strategies.base

Strategies base interfaces.

Module Contents

Classes

StrategyAction

Specifies save vs load and sharded vs common action. To be removed in future releases.

LoadStrategyBase

Base class for a load strategy. Requires implementing checks for compatibility with a given checkpoint version.

SaveStrategyBase

Base class for a save strategy. Requires defining a backend type and version of the saved format.

LoadShardedStrategy

Base class for load strategies. To be removed in future releases.

SaveShardedStrategy

Base class for save strategies. To be removed in future releases.

AsyncSaveShardedStrategy

Save strategy suitable for async save. To be removed in future releases.

Functions

get_default_strategy

Retrieves a default strategy for a given action, backend and version.

Data

API

core.dist_checkpointing.strategies.base.logger

‘getLogger(…)’

class core.dist_checkpointing.strategies.base.StrategyAction(*args, **kwds)

Bases: enum.Enum

Specifies save vs load and sharded vs common action. To be removed in future releases.

LOAD_COMMON

‘load_common’

LOAD_SHARDED

‘load_sharded’

SAVE_COMMON

‘save_common’

SAVE_SHARDED

‘save_sharded’
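The documented members and values are enough to round-trip between an action and its string value. The snippet below is a sketch that redefines the enum locally, copying only the members listed above, so it runs without Megatron installed; it is not an import of the real class.

```python
from enum import Enum


class StrategyAction(Enum):
    """Local mirror of the documented enum (members and values copied from the reference)."""

    LOAD_COMMON = 'load_common'
    LOAD_SHARDED = 'load_sharded'
    SAVE_COMMON = 'save_common'
    SAVE_SHARDED = 'save_sharded'


# Look up a member from its string value, e.g. when reading a config option.
action = StrategyAction('save_sharded')
assert action is StrategyAction.SAVE_SHARDED

# The string value is what a string-keyed lookup table would use.
print(action.value)  # save_sharded
```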

core.dist_checkpointing.strategies.base.get_default_strategy(
action: core.dist_checkpointing.strategies.base.StrategyAction,
backend: str,
version: int,
)

Retrieves a default strategy for a given action, backend and version.
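One way to picture this lookup is a registry keyed by the action and a (backend, version) pair. The sketch below assumes that layout; `register_strategy`, the `LookupError` on a miss, the `my_backend` name, and the string placeholder standing in for a strategy object are all hypothetical and not part of the documented API.

```python
from enum import Enum
from typing import Any, Dict, Tuple


class StrategyAction(Enum):
    LOAD_COMMON = 'load_common'
    LOAD_SHARDED = 'load_sharded'
    SAVE_COMMON = 'save_common'
    SAVE_SHARDED = 'save_sharded'


# Hypothetical registry: action value -> {(backend, version): strategy}.
_registry: Dict[str, Dict[Tuple[str, int], Any]] = {}


def register_strategy(action: StrategyAction, backend: str, version: int, strategy: Any) -> None:
    """Hypothetical helper that records a default strategy for (action, backend, version)."""
    _registry.setdefault(action.value, {})[(backend, version)] = strategy


def get_default_strategy(action: StrategyAction, backend: str, version: int) -> Any:
    """Sketch of the lookup; raises LookupError when nothing is registered."""
    try:
        return _registry[action.value][(backend, version)]
    except KeyError as e:
        raise LookupError(f'No default strategy for {(action.value, backend, version)}') from e


# A string placeholder stands in for a real strategy instance.
register_strategy(StrategyAction.SAVE_SHARDED, 'my_backend', 1, 'my-save-strategy')
print(get_default_strategy(StrategyAction.SAVE_SHARDED, 'my_backend', 1))  # my-save-strategy
```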

class core.dist_checkpointing.strategies.base.LoadStrategyBase

Bases: abc.ABC

Base class for a load strategy. Requires implementing checks for compatibility with a given checkpoint version.

abstractmethod check_backend_compatibility(loaded_backend)

Verifies whether this strategy is compatible with loaded_backend.

abstractmethod check_version_compatibility(loaded_version)

Verifies whether this strategy is compatible with loaded_version.

property can_handle_sharded_objects

Returns whether or not this strategy can handle loading ShardedObjects.
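A load strategy subclass supplies the two compatibility checks and may override `can_handle_sharded_objects`. The sketch below uses a local stand-in for `LoadStrategyBase`; the raise-`ValueError`-on-mismatch convention, the `False` default of the property, and the `MyLoadStrategy` name are assumptions for illustration only.

```python
from abc import ABC, abstractmethod


class LoadStrategyBase(ABC):
    """Local stand-in mirroring the documented abstract interface."""

    @abstractmethod
    def check_backend_compatibility(self, loaded_backend):
        ...

    @abstractmethod
    def check_version_compatibility(self, loaded_version):
        ...

    @property
    def can_handle_sharded_objects(self):
        # Assumed default for this sketch; the real default may differ.
        return False


class MyLoadStrategy(LoadStrategyBase):
    """Hypothetical strategy that reads 'my_backend' checkpoints of versions 1 and 2."""

    SUPPORTED_VERSIONS = {1, 2}

    def check_backend_compatibility(self, loaded_backend):
        # Signalling incompatibility by raising is an assumption of this sketch.
        if loaded_backend != 'my_backend':
            raise ValueError(f'Unsupported backend: {loaded_backend!r}')

    def check_version_compatibility(self, loaded_version):
        if loaded_version not in self.SUPPORTED_VERSIONS:
            raise ValueError(f'Unsupported version: {loaded_version!r}')

    @property
    def can_handle_sharded_objects(self):
        return True


strategy = MyLoadStrategy()
strategy.check_backend_compatibility('my_backend')  # passes silently
strategy.check_version_compatibility(2)             # passes silently
```

Because the base declares both checks abstract, forgetting either one makes the subclass impossible to instantiate, which surfaces the omission early.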

class core.dist_checkpointing.strategies.base.SaveStrategyBase(backend: str, version: int)

Bases: abc.ABC

Base class for a save strategy. Requires defining a backend type and version of the saved format.
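The backend and version passed to the constructor identify the format a strategy writes. The following sketch assumes the base simply stores them as attributes and that `__str__` reports them; the stand-in class, its string format, the property default, and `MySaveStrategy` are illustrative, not the library's implementation.

```python
from abc import ABC


class SaveStrategyBase(ABC):
    """Local stand-in: a save strategy records the backend and format version it writes."""

    def __init__(self, backend: str, version: int):
        self.backend = backend
        self.version = version

    @property
    def can_handle_sharded_objects(self):
        # Assumed default for this sketch.
        return False

    def __str__(self):
        # Assumed format: class name plus backend and version.
        return f'{self.__class__.__name__}({self.backend}, {self.version})'


class MySaveStrategy(SaveStrategyBase):
    """Hypothetical strategy writing format version 1 of 'my_backend'."""

    def __init__(self):
        super().__init__('my_backend', 1)


print(MySaveStrategy())  # MySaveStrategy(my_backend, 1)
```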

property can_handle_sharded_objects

Returns whether or not this strategy can handle saving ShardedObjects.

__str__()

class core.dist_checkpointing.strategies.base.LoadShardedStrategy

Bases: core.dist_checkpointing.strategies.base.LoadStrategyBase

Base class for load strategies. To be removed in future releases.

abstractmethod load(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: Union[str, pathlib.Path],
)

Load the sharded part of the checkpoint.

abstractmethod load_tensors_metadata(checkpoint_dir: Union[str, pathlib.Path])

Load tensors metadata from the checkpoint for ShardedTensors.

Returns a dictionary similar to a sharded state dict, except that its keys are simply ShardedTensor keys (unlike actual sharded state dicts, whose keys correspond to state dict keys).

Dict values are ShardedTensors without any data or sharding, so the only useful information they carry is each tensor's global shape and dtype.
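To illustrate the difference between state dict keys and ShardedTensor keys, the sketch below flattens a nested state dict and re-keys each tensor by its own key, dropping its data. `ShardedTensorStub` and `tensors_metadata` are hypothetical stand-ins; the real ShardedTensor carries more fields, including the sharding information that this stub omits.

```python
from dataclasses import dataclass, replace
from typing import Any, Dict, Optional, Tuple


@dataclass
class ShardedTensorStub:
    """Hypothetical stand-in for ShardedTensor (sharding fields omitted)."""

    key: str                      # checkpoint-level (ShardedTensor) key
    global_shape: Tuple[int, ...]
    dtype: str
    data: Optional[Any] = None    # local shard data; None in metadata


def tensors_metadata(sharded_state_dict: Dict[str, Any]) -> Dict[str, ShardedTensorStub]:
    """Flatten a nested state dict into {ShardedTensor key: data-free tensor}."""
    out: Dict[str, ShardedTensorStub] = {}

    def visit(node: Any) -> None:
        if isinstance(node, dict):
            for value in node.values():
                visit(value)
        elif isinstance(node, ShardedTensorStub):
            # Re-key by the tensor's own key and drop the local data (copy, not mutate).
            out[node.key] = replace(node, data=None)

    visit(sharded_state_dict)
    return out


state_dict = {
    'model': {
        # State dict key 'layer.weight' differs from the ShardedTensor key.
        'layer.weight': ShardedTensorStub('decoder.layer.weight', (1024, 1024), 'float32', data=[0.0]),
    }
}
print(list(tensors_metadata(state_dict)))  # ['decoder.layer.weight']
```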

load_sharded_metadata(checkpoint_dir: Union[str, pathlib.Path])

Load sharded metadata from the checkpoint for ShardedTensors and ShardedObjects.

Returns a dictionary similar to a sharded state dict, except that its keys are simply sharded keys (unlike actual sharded state dicts, whose keys correspond to state dict keys).

Dict values are ShardedTensors or ShardedObjects without any data or sharding.

abstractmethod remove_sharded_tensors(
checkpoint_dir: Union[str, pathlib.Path],
key_prefix: str,
)#

Remove all tensors whose key starts with key_prefix
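The removal rule is a plain string-prefix match. The sketch below applies it to an in-memory dict instead of a checkpoint directory; the dict-based storage and the returned list of removed keys are assumptions made to keep the example self-contained.

```python
from typing import Dict, List


def remove_sharded_tensors(storage: Dict[str, bytes], key_prefix: str) -> List[str]:
    """Hypothetical in-memory version: delete entries whose key starts with key_prefix."""
    # Collect first so the dict is not mutated while iterating over it.
    removed = [key for key in storage if key.startswith(key_prefix)]
    for key in removed:
        del storage[key]
    return removed


storage = {
    'decoder.layer1.weight': b'\x00',
    'decoder.layer10.weight': b'\x00',
    'embedding.weight': b'\x00',
}
print(remove_sharded_tensors(storage, 'decoder.layer1'))
# ['decoder.layer1.weight', 'decoder.layer10.weight']
print(list(storage))  # ['embedding.weight']
```

Because the match is a plain prefix, `decoder.layer1` also removes `decoder.layer10.weight`; including a trailing separator such as `decoder.layer1.` avoids this.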

class core.dist_checkpointing.strategies.base.SaveShardedStrategy(backend: str, version: int)

Bases: core.dist_checkpointing.strategies.base.SaveStrategyBase

Base class for save strategies. To be removed in future releases.

abstractmethod save(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: Union[str, pathlib.Path],
)

Save the sharded part of the state dict.
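A concrete save strategy fixes its backend and version and implements `save`. The sketch below, a hypothetical `JsonSaveStrategy` writing one JSON file, only works for JSON-serializable dicts and uses a local stand-in for the base class; it shows that `checkpoint_dir` may be either a `str` or a `pathlib.Path`.

```python
import json
import pathlib
import tempfile
from abc import ABC, abstractmethod
from typing import Any, Dict, Union


class SaveShardedStrategy(ABC):
    """Local stand-in for the documented base class."""

    def __init__(self, backend: str, version: int):
        self.backend = backend
        self.version = version

    @abstractmethod
    def save(self, sharded_state_dict: Dict[str, Any], checkpoint_dir: Union[str, pathlib.Path]):
        ...


class JsonSaveStrategy(SaveShardedStrategy):
    """Hypothetical strategy that writes a JSON-serializable state dict to one file."""

    def __init__(self):
        super().__init__('json_example', 1)

    def save(self, sharded_state_dict, checkpoint_dir):
        checkpoint_dir = pathlib.Path(checkpoint_dir)  # accept str or Path
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        # Record backend and version alongside the data so a loader can check them.
        payload = {'backend': self.backend, 'version': self.version, 'state': sharded_state_dict}
        (checkpoint_dir / 'state.json').write_text(json.dumps(payload))


with tempfile.TemporaryDirectory() as tmp:
    JsonSaveStrategy().save({'w': [1, 2]}, tmp)
    print((pathlib.Path(tmp) / 'state.json').exists())  # True
```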

class core.dist_checkpointing.strategies.base.AsyncSaveShardedStrategy(backend: str, version: int)

Bases: core.dist_checkpointing.strategies.base.SaveShardedStrategy

Save strategy suitable for async save. To be removed in future releases.

abstractmethod async_save(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: Union[str, pathlib.Path],
) → core.dist_checkpointing.strategies.async_utils.AsyncRequest

Perform preparation and return an AsyncRequest to the external caller.

Parameters:
  • sharded_state_dict (ShardedStateDict) – sharded state dict to save

  • checkpoint_dir (Path) – checkpoint target directory

Returns:

An object representing the async save function and the finalization function. It is the caller's responsibility to actually schedule the async save.

Return type:

AsyncRequest
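The split between cheap preparation and deferred execution can be sketched as follows. `AsyncRequestSketch` is a hypothetical stand-in with assumed fields (a function, its arguments, and finalization callbacks); the real AsyncRequest may differ. Scheduling on a background thread is one possible caller choice, not the library's mechanism.

```python
import threading
from dataclasses import dataclass, field
from typing import Any, Callable, List, Tuple


@dataclass
class AsyncRequestSketch:
    """Hypothetical stand-in for AsyncRequest: work to run later plus finalizers."""

    async_fn: Callable[..., None]
    async_fn_args: Tuple[Any, ...]
    finalize_fns: List[Callable[[], None]] = field(default_factory=list)


written: List[str] = []


def async_save(sharded_state_dict, checkpoint_dir) -> AsyncRequestSketch:
    """Do preparation now; defer the expensive write to the caller."""
    items = sorted(sharded_state_dict.items())  # preparation, e.g. snapshotting data

    def write(items, checkpoint_dir):
        for key, _ in items:
            written.append(f'{checkpoint_dir}/{key}')

    def finalize():
        written.append('finalized')

    return AsyncRequestSketch(write, (items, checkpoint_dir), [finalize])


# The caller decides how to schedule the request; here, a background thread.
request = async_save({'b': 2, 'a': 1}, 'ckpt')
worker = threading.Thread(target=request.async_fn, args=request.async_fn_args)
worker.start()
# ... training could continue here ...
worker.join()
for fn in request.finalize_fns:
    fn()
print(written)  # ['ckpt/a', 'ckpt/b', 'finalized']
```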

save(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: Union[str, pathlib.Path],
)

Each async strategy can be trivially used as a sync strategy.
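A sync save can be derived from `async_save` by running the returned request in the calling thread and then its finalizers. The sketch below assumes that approach and uses a hypothetical request object with assumed fields; the actual implementation may route the request through other scheduling machinery. `AsyncSaveSketch` and `InMemorySave` are illustrative names.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Callable, List, Tuple


@dataclass
class AsyncRequestSketch:
    """Hypothetical request: a function, its arguments, and finalizers."""

    async_fn: Callable[..., None]
    async_fn_args: Tuple[Any, ...]
    finalize_fns: List[Callable[[], None]] = field(default_factory=list)


class AsyncSaveSketch(ABC):
    """Hypothetical base: subclasses only implement async_save."""

    @abstractmethod
    def async_save(self, sharded_state_dict, checkpoint_dir) -> AsyncRequestSketch:
        ...

    def save(self, sharded_state_dict, checkpoint_dir):
        # Sync fallback: build the request, run it in the calling thread, then finalize.
        request = self.async_save(sharded_state_dict, checkpoint_dir)
        request.async_fn(*request.async_fn_args)
        for fn in request.finalize_fns:
            fn()


class InMemorySave(AsyncSaveSketch):
    """Hypothetical strategy that 'saves' into a dict."""

    def __init__(self):
        self.store = {}
        self.finalized = False

    def async_save(self, sharded_state_dict, checkpoint_dir):
        snapshot = dict(sharded_state_dict)

        def write(snapshot, checkpoint_dir):
            self.store[str(checkpoint_dir)] = snapshot

        def finalize():
            self.finalized = True

        return AsyncRequestSketch(write, (snapshot, checkpoint_dir), [finalize])


strategy = InMemorySave()
strategy.save({'w': 1}, 'ckpt')
print(strategy.store, strategy.finalized)  # {'ckpt': {'w': 1}} True
```

With this design a subclass implements only `async_save`, and the synchronous path comes for free.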