dist_checkpointing.strategies package
Package defining different checkpoint formats (backends) and saving/loading algorithms (strategies).
Strategies can be used to implement new checkpoint formats, or new ways of saving/loading existing formats that are better suited to a given use case. Strategies are passed to the dist_checkpointing.load and dist_checkpointing.save functions and control the actual saving/loading procedure.
Strategies base interfaces.
- class core.dist_checkpointing.strategies.base.AsyncSaveShardedStrategy(backend: str, version: int)
Bases:
core.dist_checkpointing.strategies.base.SaveShardedStrategy
Save strategy suitable for async save.
- abstract async_save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path) → core.dist_checkpointing.strategies.async_utils.AsyncRequest
Perform preparation and return an AsyncRequest to the external caller.
- Parameters
sharded_state_dict (ShardedStateDict) – sharded state dict to save
checkpoint_dir (Path) – checkpoint target directory
- Returns
- represents the async save function and the finalization function.
It is the caller's responsibility to actually schedule the async save.
- Return type
AsyncRequest
- save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
Each async strategy can be trivially used as a sync strategy.
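A minimal sketch of the async-save contract. AsyncRequest and MyAsyncSaveStrategy below are illustrative stand-ins, not the real core.dist_checkpointing classes; the point is that async_save only prepares work and hands back a request, while save() shows how any async strategy doubles as a sync one.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple

# Stand-in for core.dist_checkpointing's AsyncRequest (illustrative only).
@dataclass
class AsyncRequest:
    async_fn: Callable            # function the caller should schedule asynchronously
    async_fn_args: Tuple          # its positional arguments
    finalize_fns: List[Callable] = field(default_factory=list)

class MyAsyncSaveStrategy:
    """Hypothetical AsyncSaveShardedStrategy subclass (sketch)."""

    def __init__(self, backend: str, version: int):
        self.backend, self.version = backend, version

    def async_save(self, sharded_state_dict: Dict[str, Any],
                   checkpoint_dir: Path) -> AsyncRequest:
        # Only cheap preparation happens here; the heavy I/O is deferred
        # to write_fn, which the external caller schedules.
        def write_fn(state_dict, target_dir):
            pass  # serialize shards into target_dir (omitted)
        return AsyncRequest(write_fn, (sharded_state_dict, checkpoint_dir))

    def save(self, sharded_state_dict: Dict[str, Any], checkpoint_dir: Path):
        # An async strategy used synchronously: run the async function
        # immediately, then run the finalization functions.
        req = self.async_save(sharded_state_dict, checkpoint_dir)
        req.async_fn(*req.async_fn_args)
        for fn in req.finalize_fns:
            fn()
```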
- class core.dist_checkpointing.strategies.base.LoadCommonStrategy
Bases:
core.dist_checkpointing.strategies.base.LoadStrategyBase
Load strategy for common (non-sharded) objects
- abstract load_common(checkpoint_dir: pathlib.Path)
Load common part of the checkpoint.
- load_sharded_metadata(checkpoint_dir: pathlib.Path) → Dict[str, Any]
Load just the metadata from the checkpoint.
- abstract load_sharded_objects(sharded_objects_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
Load sharded objects from the checkpoint.
- class core.dist_checkpointing.strategies.base.LoadShardedStrategy
Bases:
core.dist_checkpointing.strategies.base.LoadStrategyBase
Load strategy for sharded tensors
- abstract load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
Load the sharded part of the checkpoint.
- load_sharded_metadata(checkpoint_dir: pathlib.Path)
Load sharded metadata from the checkpoint for ShardedTensors and ShardedObjects.
Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply sharded keys (contrary to the actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors or ShardedObjects without any data and sharding.
- abstract load_tensors_metadata(checkpoint_dir: pathlib.Path)
Load tensors metadata from the checkpoint for ShardedTensors.
Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (contrary to the actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors without any data and sharding (so the only useful information is the tensors' global shapes and dtypes).
- remove_sharded_tensors(checkpoint_dir: str, key_prefix: str)
Remove all tensors whose key starts with key_prefix.
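To illustrate the load_tensors_metadata return contract: the result maps plain ShardedTensor keys (not nested state-dict keys) to data-free tensor descriptions. TensorMeta here is a hypothetical stand-in for a ShardedTensor stripped of data; only the global shape and dtype are meaningful.

```python
from dataclasses import dataclass
from typing import Dict, Tuple

# Stand-in for a ShardedTensor without data and sharding (illustrative).
@dataclass
class TensorMeta:
    global_shape: Tuple[int, ...]
    dtype: str

def describe_checkpoint(metadata: Dict[str, TensorMeta]):
    # Keys are flat ShardedTensor keys, contrary to actual sharded
    # state dicts where keys mirror the state-dict structure.
    return {key: (m.global_shape, m.dtype) for key, m in metadata.items()}
```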
- class core.dist_checkpointing.strategies.base.LoadStrategyBase
Bases:
abc.ABC
Base class for a load strategy. Requires implementing checks for compatibility with a given checkpoint version.
- property can_handle_sharded_objects
Returns whether or not this strategy can handle loading ShardedObjects.
- abstract check_backend_compatibility(loaded_backend)
Verifies if this strategy is compatible with loaded_backend.
- abstract check_version_compatibility(loaded_version)
Verifies if this strategy is compatible with loaded_version.
- class core.dist_checkpointing.strategies.base.SaveCommonStrategy(backend: str, version: int)
Bases:
core.dist_checkpointing.strategies.base.SaveStrategyBase
Save strategy for common (non-sharded) objects
- abstract save_common(common_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
Save common part of the state dict.
- save_sharded_objects(sharded_objects_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
Save sharded objects from the state dict.
- class core.dist_checkpointing.strategies.base.SaveShardedStrategy(backend: str, version: int)
Bases:
core.dist_checkpointing.strategies.base.SaveStrategyBase
Save strategy for sharded tensors
- abstract save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
Save the sharded part of the state dict.
- class core.dist_checkpointing.strategies.base.SaveStrategyBase(backend: str, version: int)
Bases:
abc.ABC
Base class for a save strategy. Requires defining a backend type and version of the saved format.
- property can_handle_sharded_objects
Returns whether or not this strategy can handle saving ShardedObjects.
- class core.dist_checkpointing.strategies.base.StrategyAction(value)
Bases:
enum.Enum
Specifies save vs load and sharded vs common action.
- LOAD_COMMON = 'load_common'
- LOAD_SHARDED = 'load_sharded'
- SAVE_COMMON = 'save_common'
- SAVE_SHARDED = 'save_sharded'
- core.dist_checkpointing.strategies.base.get_default_strategy(action: core.dist_checkpointing.strategies.base.StrategyAction, backend: str, version: int)
Retrieves a default strategy for a given action, backend and version.
- core.dist_checkpointing.strategies.base.register_default_strategy(action: core.dist_checkpointing.strategies.base.StrategyAction, backend: str, version: int, strategy: Union[core.dist_checkpointing.strategies.base.SaveStrategyBase, core.dist_checkpointing.strategies.base.LoadStrategyBase])
Adds a given strategy to the registry of default strategies.
- Parameters
action (StrategyAction) – specifies save/load and sharded/common
backend (str) – backend that the strategy becomes a default for
version (int) – version that the strategy becomes a default for
strategy (SaveStrategyBase, LoadStrategyBase) – strategy to register
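The registry pattern behind register_default_strategy/get_default_strategy can be sketched as a dict keyed by (action, backend, version). The names below mirror the documented API but the implementation is an illustrative guess, not the real module code.

```python
from enum import Enum

class StrategyAction(Enum):
    LOAD_COMMON = 'load_common'
    LOAD_SHARDED = 'load_sharded'
    SAVE_COMMON = 'save_common'
    SAVE_SHARDED = 'save_sharded'

# Hypothetical registry of default strategies, keyed by (action, backend, version).
_default_strategies = {}

def register_default_strategy(action, backend, version, strategy):
    _default_strategies[(action, backend, version)] = strategy

def get_default_strategy(action, backend, version):
    try:
        return _default_strategies[(action, backend, version)]
    except KeyError:
        raise ValueError(
            f'No default strategy for: {(action.value, backend, version)}'
        ) from None
```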
Strategies using TensorStore to load and save Zarr arrays.
- class core.dist_checkpointing.strategies.tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device: bool = False)
Bases:
core.dist_checkpointing.strategies.base.LoadShardedStrategy
Load strategy for Zarr backend using tensorstore for loading.
- check_backend_compatibility(loaded_version)
Verifies if this strategy is compatible with loaded_backend.
- check_version_compatibility(loaded_version)
Verifies if this strategy is compatible with loaded_version.
- load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
Load the sharded part of the checkpoint.
- load_tensors_metadata(checkpoint_dir: pathlib.Path)
Load tensors metadata from the checkpoint for ShardedTensors.
Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (contrary to the actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors without any data and sharding (so the only useful information is the tensors' global shapes and dtypes).
- core.dist_checkpointing.strategies.tensorstore.merge_global_slice_with_shape(global_slice, actual_shape, key)
Intersects the global slice with the actual shape (to prevent overflow).
- core.dist_checkpointing.strategies.tensorstore.open_ts_array(arr_path: pathlib.Path)
Opens a Zarr file array with TensorStore using basic settings.
- Parameters
arr_path (Path) – path to a Zarr (Tensorstore) array
- core.dist_checkpointing.strategies.tensorstore.register_default_tensorstore_strategies()
Register default strategies leveraging tensorstore.
2-stage checkpoint loading.
- class core.dist_checkpointing.strategies.two_stage.TwoStageDataParallelLoadShardedStrategy(data_parallel_group, cpu_transfer=True)
Bases:
core.dist_checkpointing.strategies.base.LoadShardedStrategy
Loads one checkpoint replica from storage and broadcasts to other nodes.
This strategy loads checkpoint from storage on minimal set of nodes and distributes the checkpoint to other nodes with torch.distributed. Loading is performed with tensorstore.
Steps:
0. (optional) create Gloo distributed groups
1. Exchange ShardedTensors metadata between all nodes
2. Align needed tensors within DP groups
3. For each globally unique tensor:
3.a) on one of the ranks, load it from storage to CPU and move to CUDA
3.b) allocate a CUDA tensor on the other ranks
3.c) broadcast within the DP group
3.d) copy the tensor content to the model param location
3.e) free the tensor buffers from a) and b)
Notes:
1. Loading and broadcasting are done sequentially to avoid both host and device OOMs.
2. There is a lot of overlap potential between the three steps done for each tensor:
2.a) loading from storage to numpy
2.b) moving CPU tensors to CUDA
2.c) broadcast
- check_backend_compatibility(loaded_version)
Verifies if this strategy is compatible with loaded_backend.
- check_version_compatibility(loaded_version)
Verifies if this strategy is compatible with loaded_version.
- deduplicate_chunks(ten_metas: List[core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata])
Group tensors by chunk and then pick the tensor with the lowest rank.
- NOTE: with proper loading overlap, loading from randomized ranks (instead of the smallest one) could be beneficial here.
- load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
Load the sharded part of the checkpoint.
- load_tensor_from_storage(checkpoint_dir, ten_meta: core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata)
- load_tensors_metadata(checkpoint_dir: pathlib.Path)
Load tensors metadata from the checkpoint for ShardedTensors.
Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (contrary to the actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors without any data and sharding (so the only useful information is the tensors' global shapes and dtypes).
- maybe_init_gloo_group()
- summarize_load_times()
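The grouping in deduplicate_chunks above ("group by chunk, pick the lowest rank") can be sketched as follows. ShardMeta is a hypothetical stand-in for _ShardedTensorMetadata with only the fields the algorithm needs.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Tuple

# Stand-in for _ShardedTensorMetadata (illustrative fields only).
@dataclass
class ShardMeta:
    chunk_id: Tuple   # globally unique id of a tensor chunk
    rank: int         # rank holding this replica of the chunk

def deduplicate_chunks(metas: List[ShardMeta]) -> List[ShardMeta]:
    # Group all replicas of the same chunk together.
    groups = defaultdict(list)
    for m in metas:
        groups[m.chunk_id].append(m)
    # Exactly one loader per globally unique chunk: the lowest rank wins.
    return [min(group, key=lambda m: m.rank) for group in groups.values()]
```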
- core.dist_checkpointing.strategies.two_stage.sharded_tensor_chunk_id(sharded_tensor: core.dist_checkpointing.mapping.ShardedTensor)
- core.dist_checkpointing.strategies.two_stage.timed(verbose=True)
Strategies using Zarr as an underlying format.
- class core.dist_checkpointing.strategies.zarr.ZarrLoadShardedStrategy
Bases:
core.dist_checkpointing.strategies.base.LoadShardedStrategy
Load strategy for the Zarr backend.
- check_backend_compatibility(loaded_version)
Verifies if this strategy is compatible with loaded_backend.
- check_version_compatibility(loaded_version)
Verifies if this strategy is compatible with loaded_version.
- load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
Load the sharded part of the checkpoint.
- load_tensors_metadata(checkpoint_dir: pathlib.Path)
Load tensors metadata from the checkpoint for ShardedTensors.
Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (contrary to the actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors without any data and sharding (so the only useful information is the tensors' global shapes and dtypes).
- class core.dist_checkpointing.strategies.zarr.ZarrSaveShardedStrategy(backend: str, version: int)
Bases:
core.dist_checkpointing.strategies.base.SaveShardedStrategy
Save strategy for Zarr backend.
- save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
Save the sharded part of the state dict.
- core.dist_checkpointing.strategies.zarr.flatten_range(sharded_tensor, x)
Apply flattened range to a tensor.
- core.dist_checkpointing.strategies.zarr.load_zarr_based_sharded_metadata(checkpoint_dir: pathlib.Path, get_shape_dtype_fn: Callable[[str], Tuple[Tuple[int], numpy.dtype]]) → Dict[str, Any]
Load metadata of Zarr arrays.
- Parameters
checkpoint_dir (Path) – checkpoint root directory
get_shape_dtype_fn (str -> ((int, ...), np.dtype)) – a function returning an array shape and dtype for a given Zarr array path
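One possible get_shape_dtype_fn, shown here as a sketch: a Zarr v2 array directory stores its shape and dtype in a .zarray JSON file, so metadata can be read without opening the array. The function name is illustrative.

```python
import json
from pathlib import Path
from typing import Tuple

import numpy as np

def zarr_v2_shape_dtype(arr_path: str) -> Tuple[Tuple[int, ...], np.dtype]:
    # Read the Zarr v2 array metadata (.zarray) instead of opening
    # the array; 'shape' and 'dtype' are standard keys in that file.
    meta = json.loads((Path(arr_path) / '.zarray').read_text())
    return tuple(meta['shape']), np.dtype(meta['dtype'])
```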
- core.dist_checkpointing.strategies.zarr.pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: core.dist_checkpointing.mapping.ShardedTensor)
Pad tensor to the expected shape.
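The padding idea can be sketched like this; the real function pads a torch.Tensor against a ShardedTensor's expected local shape, while this illustration uses numpy for self-containedness.

```python
from typing import Tuple

import numpy as np

def pad_to_expected_shape(x: np.ndarray, expected_shape: Tuple[int, ...]) -> np.ndarray:
    # Each dimension may only grow; pad with zeros on the trailing side.
    assert all(e >= s for s, e in zip(x.shape, expected_shape)), 'cannot shrink'
    pad_width = [(0, e - s) for s, e in zip(x.shape, expected_shape)]
    return np.pad(x, pad_width)  # constant zero padding by default
```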
- core.dist_checkpointing.strategies.zarr.postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=True)
Turn a numpy array into a torch tensor.
- core.dist_checkpointing.strategies.zarr.register_default_zarr_strategies()
Register default strategies related to Zarr backend.
Various loading and saving strategies.