core.dist_checkpointing.strategies.base#

Strategies base interfaces.

Module Contents#

Classes#

StrategyAction

Specifies save vs load and sharded vs common action.

LoadStrategyBase

Base class for a load strategy. Requires implementing checks for compatibility with a given checkpoint version.

SaveStrategyBase

Base class for a save strategy. Requires defining a backend type and version of the saved format.

LoadCommonStrategy

Load strategy for common (non-sharded) objects.

LoadShardedStrategy

Load strategy for sharded tensors.

SaveCommonStrategy

Save strategy for common (non-sharded) objects.

SaveShardedStrategy

Save strategy for sharded tensors.

AsyncSaveShardedStrategy

Save strategy suitable for async save.

Functions#

get_default_strategy

Retrieves a default strategy for a given action, backend and version.

register_default_strategy

Adds a given strategy to the registry of default strategies.

Data#

API#

class core.dist_checkpointing.strategies.base.StrategyAction(*args, **kwds)#

Bases: enum.Enum

Specifies save vs load and sharded vs common action.

Initialization

LOAD_COMMON#

‘load_common’

LOAD_SHARDED#

‘load_sharded’

SAVE_COMMON#

‘save_common’

SAVE_SHARDED#

‘save_sharded’

core.dist_checkpointing.strategies.base.default_strategies: DefaultDict[str, dict[tuple, Any]]#

‘defaultdict(…)’

core.dist_checkpointing.strategies.base.async_calls#

‘AsyncCallsQueue(…)’

core.dist_checkpointing.strategies.base.get_default_strategy(
action: core.dist_checkpointing.strategies.base.StrategyAction,
backend: str,
version: int,
)#

Retrieves a default strategy for a given action, backend and version.

core.dist_checkpointing.strategies.base.register_default_strategy(
action: core.dist_checkpointing.strategies.base.StrategyAction,
backend: str,
version: int,
strategy: Union[SaveStrategyBase, LoadStrategyBase],
)#

Adds a given strategy to the registry of default strategies.

Parameters:
  • action (StrategyAction) – specifies save/load and sharded/common

  • backend (str) – backend that the strategy becomes a default for

  • version (int) – version that the strategy becomes a default for

  • strategy (SaveStrategyBase, LoadStrategyBase) – strategy to register
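
The registration and lookup described above can be pictured with a small, self-contained sketch. It does not import the library: the functions below are local stand-ins, and the registry layout (action value as the outer key, a `(backend, version)` tuple as the inner key) is an assumption inferred from the documented type `DefaultDict[str, dict[tuple, Any]]`. The `LookupError` raised on a miss, the `DummySaveStrategy` class, and the `"my_backend"` name are likewise hypothetical.

```python
from collections import defaultdict
from enum import Enum
from typing import Any, DefaultDict, Dict, Tuple


class StrategyAction(Enum):
    """Local stand-in mirroring the four documented members."""

    LOAD_COMMON = "load_common"
    LOAD_SHARDED = "load_sharded"
    SAVE_COMMON = "save_common"
    SAVE_SHARDED = "save_sharded"


# Assumed layout: action value -> {(backend, version): strategy}
default_strategies: DefaultDict[str, Dict[Tuple[str, int], Any]] = defaultdict(dict)


def register_default_strategy(action: StrategyAction, backend: str, version: int, strategy: Any) -> None:
    """Stand-in: store the strategy under its action and (backend, version)."""
    default_strategies[action.value][(backend, version)] = strategy


def get_default_strategy(action: StrategyAction, backend: str, version: int) -> Any:
    """Stand-in: look the strategy up, failing loudly if none is registered."""
    try:
        return default_strategies[action.value][(backend, version)]
    except KeyError as exc:
        # The error type is an assumption made for this sketch.
        raise LookupError(
            f"No default strategy for {action.value!r} "
            f"with backend={backend!r}, version={version}"
        ) from exc


class DummySaveStrategy:
    """Hypothetical placeholder for a real save strategy."""


strategy = DummySaveStrategy()
register_default_strategy(StrategyAction.SAVE_SHARDED, "my_backend", 1, strategy)
found = get_default_strategy(StrategyAction.SAVE_SHARDED, "my_backend", 1)
```

Keying on the pair `(backend, version)` lets one backend keep separate defaults for several format versions, which matches the three lookup arguments documented for `get_default_strategy`.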

class core.dist_checkpointing.strategies.base.LoadStrategyBase#

Bases: abc.ABC

Base class for a load strategy. Requires implementing checks for compatibility with a given checkpoint version.

abstractmethod check_backend_compatibility(loaded_backend)#

Verifies if this strategy is compatible with loaded_backend.

abstractmethod check_version_compatibility(loaded_version)#

Verifies if this strategy is compatible with loaded_version.

property can_handle_sharded_objects#

Returns whether or not this strategy can handle loading ShardedObjects.
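
As a rough illustration of what implementing the two compatibility checks might look like, here is a sketch built on a local stand-in base class rather than the library's. Several details are assumptions: that a failed check raises `ValueError` (the reference only says the methods "verify" compatibility), that `can_handle_sharded_objects` defaults to `False`, and the backend name and version numbers themselves.

```python
from abc import ABC, abstractmethod


class LoadStrategyBaseSketch(ABC):
    """Local stand-in with the documented interface; not the library class."""

    @abstractmethod
    def check_backend_compatibility(self, loaded_backend):
        ...

    @abstractmethod
    def check_version_compatibility(self, loaded_version):
        ...

    @property
    def can_handle_sharded_objects(self):
        # Assumed default for this sketch.
        return False


class MyLoadStrategy(LoadStrategyBaseSketch):
    """Hypothetical strategy that accepts one backend and two versions."""

    SUPPORTED_BACKEND = "my_backend"
    SUPPORTED_VERSIONS = (1, 2)

    def check_backend_compatibility(self, loaded_backend):
        if loaded_backend != self.SUPPORTED_BACKEND:
            raise ValueError(f"Unsupported backend: {loaded_backend!r}")

    def check_version_compatibility(self, loaded_version):
        if loaded_version not in self.SUPPORTED_VERSIONS:
            raise ValueError(f"Unsupported version: {loaded_version!r}")


strategy = MyLoadStrategy()
strategy.check_backend_compatibility("my_backend")  # passes silently
strategy.check_version_compatibility(2)             # passes silently
```

Because both checks are abstract, a subclass that omits either one cannot be instantiated, so every concrete load strategy has to state which checkpoints it accepts.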

class core.dist_checkpointing.strategies.base.SaveStrategyBase(backend: str, version: int)#

Bases: abc.ABC

Base class for a save strategy. Requires defining a backend type and version of the saved format.

Initialization

property can_handle_sharded_objects#

Returns whether or not this strategy can handle saving ShardedObjects.

__str__()#

class core.dist_checkpointing.strategies.base.LoadCommonStrategy#

Bases: core.dist_checkpointing.strategies.base.LoadStrategyBase

Load strategy for common (non-sharded) objects.

abstractmethod load_common(checkpoint_dir: Union[str, pathlib.Path])#

Load the common part of the checkpoint.

abstractmethod load_sharded_objects(
sharded_objects_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: Union[str, pathlib.Path],
)#

Load sharded objects from the checkpoint.

load_sharded_metadata(
checkpoint_dir: Union[str, pathlib.Path],
) → core.dist_checkpointing.mapping.ShardedStateDict#

Load just the metadata from the checkpoint.

class core.dist_checkpointing.strategies.base.LoadShardedStrategy#

Bases: core.dist_checkpointing.strategies.base.LoadStrategyBase

Load strategy for sharded tensors.

abstractmethod load(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: Union[str, pathlib.Path],
)#

Load the sharded part of the checkpoint.

abstractmethod load_tensors_metadata(checkpoint_dir: Union[str, pathlib.Path])#

Load tensors metadata from the checkpoint for ShardedTensors.

Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (contrary to the actual sharded state dicts where keys correspond to state dict keys).

Dict values are ShardedTensors without any data or sharding information, so the only useful information they carry is each tensor's global shape and dtype.
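
The difference between the two key layouts can be made concrete with plain Python objects. The sketch below is only meant to contrast the layouts: `FakeShardedTensor` is a hypothetical stand-in with far fewer fields than the real `ShardedTensor`, the tensor keys are invented, and `to_metadata_view` derives the flat view from an in-memory dict, whereas the documented method reads it from `checkpoint_dir`.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


@dataclass
class FakeShardedTensor:
    """Hypothetical stand-in; the real ShardedTensor has more fields."""

    key: str
    global_shape: Tuple[int, ...]
    dtype: str
    data: Optional[Any] = None


def to_metadata_view(state_dict: Dict[str, Any]) -> Dict[str, FakeShardedTensor]:
    """Flatten a nested state dict into {tensor key: data-free tensor}."""
    out: Dict[str, FakeShardedTensor] = {}
    for value in state_dict.values():
        if isinstance(value, dict):
            out.update(to_metadata_view(value))
        elif isinstance(value, FakeShardedTensor):
            # Keep only key, global shape and dtype; drop the data.
            out[value.key] = FakeShardedTensor(value.key, value.global_shape, value.dtype)
    return out


# A sharded state dict is indexed by state dict keys ("model", "weight", ...).
sharded_state_dict = {
    "model": {
        "weight": FakeShardedTensor("decoder.linear.weight", (1024, 1024), "float32", data=[0.0]),
        "bias": FakeShardedTensor("decoder.linear.bias", (1024,), "float32", data=[0.0]),
    }
}

# The metadata view is flat and indexed by the tensors' own keys.
metadata = to_metadata_view(sharded_state_dict)
```

Here `metadata` has the keys `"decoder.linear.weight"` and `"decoder.linear.bias"`, not `"model"`, `"weight"` or `"bias"`.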

load_sharded_metadata(checkpoint_dir: Union[str, pathlib.Path])#

Load sharded metadata from the checkpoint for ShardedTensors and ShardedObjects.

Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply sharded keys (contrary to the actual sharded state dicts where keys correspond to state dict keys).

Dict values are ShardedTensors or ShardedObjects without any data or sharding information.

abstractmethod remove_sharded_tensors(
checkpoint_dir: Union[str, pathlib.Path],
key_prefix: str,
)#

Remove all tensors whose key starts with key_prefix.

class core.dist_checkpointing.strategies.base.SaveCommonStrategy(backend: str, version: int)#

Bases: core.dist_checkpointing.strategies.base.SaveStrategyBase

Save strategy for common (non-sharded) objects.

Initialization

abstractmethod save_common(
common_state_dict: core.dist_checkpointing.mapping.StateDict,
checkpoint_dir: Union[str, pathlib.Path],
)#

Save the common part of the state dict.

abstractmethod save_sharded_objects(
sharded_objects_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: Union[str, pathlib.Path],
)#

Save sharded objects from the state dict.

class core.dist_checkpointing.strategies.base.SaveShardedStrategy(backend: str, version: int)#

Bases: core.dist_checkpointing.strategies.base.SaveStrategyBase

Save strategy for sharded tensors.

Initialization

abstractmethod save(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: Union[str, pathlib.Path],
)#

Save the sharded part of the state dict.

class core.dist_checkpointing.strategies.base.AsyncSaveShardedStrategy(backend: str, version: int)#

Bases: core.dist_checkpointing.strategies.base.SaveShardedStrategy

Save strategy suitable for async save.

Initialization

abstractmethod async_save(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: Union[str, pathlib.Path],
) → core.dist_checkpointing.strategies.async_utils.AsyncRequest#

Perform preparation and return an AsyncRequest to the external caller.

Parameters:
  • sharded_state_dict (ShardedStateDict) – sharded state dict to save

  • checkpoint_dir (Path) – checkpoint target directory

Returns:

An object representing the async save function and the finalization function. It is the caller's responsibility to actually schedule the async save.

Return type:

AsyncRequest

save(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: Union[str, pathlib.Path],
)#

Each async strategy can be trivially used as a sync strategy.
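
One way to read this statement: a synchronous `save` can build the request with `async_save` and then run it immediately instead of scheduling it. The sketch below shows that idea with local stand-ins only. `FakeAsyncRequest`, its field names, and its `execute_sync` helper are assumptions made for illustration, not the library's `AsyncRequest`, and `InMemoryStrategy` writes to a dict rather than to disk.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Tuple


@dataclass
class FakeAsyncRequest:
    """Hypothetical stand-in: a save function plus finalization functions."""

    async_fn: Optional[Callable[..., None]]
    async_fn_args: Tuple[Any, ...] = ()
    finalize_fns: List[Callable[[], None]] = field(default_factory=list)

    def execute_sync(self) -> None:
        """Run the save function, then every finalizer, in the calling thread."""
        if self.async_fn is not None:
            self.async_fn(*self.async_fn_args)
        for finalize in self.finalize_fns:
            finalize()


class AsyncSaveStrategySketch(ABC):
    """Local stand-in for an async-capable sharded save strategy."""

    @abstractmethod
    def async_save(self, sharded_state_dict, checkpoint_dir) -> FakeAsyncRequest:
        ...

    def save(self, sharded_state_dict, checkpoint_dir) -> None:
        # Sync save: prepare the request, then execute it right away.
        request = self.async_save(sharded_state_dict, checkpoint_dir)
        request.execute_sync()


class InMemoryStrategy(AsyncSaveStrategySketch):
    """Hypothetical strategy that 'saves' into a dict for demonstration."""

    def __init__(self):
        self.storage = {}
        self.log = []

    def async_save(self, sharded_state_dict, checkpoint_dir) -> FakeAsyncRequest:
        snapshot = dict(sharded_state_dict)  # preparation step

        def write(target, payload):
            self.storage[target] = payload
            self.log.append("write")

        def finalize():
            self.log.append("finalize")

        return FakeAsyncRequest(write, (str(checkpoint_dir), snapshot), [finalize])


strategy = InMemoryStrategy()
strategy.save({"w": [1, 2, 3]}, "ckpt/iter_1")
```

With this split, a caller that wants asynchronous behavior calls `async_save` and schedules the returned request itself, while a caller that wants blocking behavior calls `save`.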