Package defining different checkpoint formats (backends) and saving/loading algorithms (strategies).
Strategies can be used to implement new checkpoint formats or new (more efficient for a given use case) ways of saving and loading existing formats. Strategies are passed to the dist_checkpointing.load and dist_checkpointing.save functions and control the actual saving/loading procedure.
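The split between the save/load entry points and the strategy objects can be illustrated with a minimal sketch. The class names below mirror the interfaces documented in this package, but the simplified signatures and the toy JSON backend are purely hypothetical:

```python
import json
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict

class SaveStrategyBase(ABC):
    """Simplified stand-in for strategies.base.SaveStrategyBase."""
    def __init__(self, backend: str, version: int):
        self.backend = backend
        self.version = version

class SaveCommonStrategy(SaveStrategyBase):
    """Save strategy for common (non-sharded) objects."""
    @abstractmethod
    def save(self, common_state_dict: Dict[str, Any], checkpoint_dir: Path): ...

class JsonSaveCommonStrategy(SaveCommonStrategy):
    """Hypothetical backend that dumps the common state dict as JSON."""
    def save(self, common_state_dict, checkpoint_dir):
        (checkpoint_dir / 'common.json').write_text(json.dumps(common_state_dict))

def save(common_state_dict, checkpoint_dir, common_strategy):
    # The entry point only dispatches to the strategy; the strategy
    # controls the actual on-disk format.
    common_strategy.save(common_state_dict, checkpoint_dir)

with tempfile.TemporaryDirectory() as d:
    save({'iteration': 100}, Path(d), JsonSaveCommonStrategy('json', 1))
    restored = json.loads((Path(d) / 'common.json').read_text())
print(restored)
```

Swapping in a different SaveCommonStrategy subclass changes the checkpoint format without touching the calling code, which is the point of the strategy abstraction.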
Strategies base interfaces.
- class core.dist_checkpointing.strategies.base.LoadCommonStrategy
Bases:
core.dist_checkpointing.strategies.base.LoadStrategyBase
Load strategy for common (non-sharded) objects
- abstract load(checkpoint_dir: pathlib.Path)
- class core.dist_checkpointing.strategies.base.LoadShardedStrategy
Bases:
core.dist_checkpointing.strategies.base.LoadStrategyBase
Load strategy for sharded tensors
- abstract load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
- abstract load_tensors_metadata(checkpoint_dir: pathlib.Path)
Load tensors metadata from the checkpoint.
Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (contrary to the actual sharded state dicts, where keys correspond to state dict keys).
Dict values are ShardedTensors without any sharding (so the only useful information is the tensors' global shape and dtype).
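A sketch of the kind of dictionary load_tensors_metadata returns, with a hypothetical dataclass standing in for an unsharded ShardedTensor (the tensor names and shapes below are made up for illustration):

```python
import math
from dataclasses import dataclass
from typing import Tuple

@dataclass
class TensorMeta:
    # Stand-in for an unsharded ShardedTensor: in load_tensors_metadata
    # results only the global shape and dtype carry useful information.
    key: str
    global_shape: Tuple[int, ...]
    dtype: str

# Illustrative result: keys are plain ShardedTensor keys,
# not the nested state-dict keys.
metadata = {
    'embedding.weight': TensorMeta('embedding.weight', (50304, 1024), 'float32'),
    'final_norm.weight': TensorMeta('final_norm.weight', (1024,), 'float32'),
}

# Such metadata is enough to e.g. count checkpoint parameters
# without reading any tensor data.
total_params = sum(math.prod(m.global_shape) for m in metadata.values())
print(total_params)
```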
- class core.dist_checkpointing.strategies.base.LoadStrategyBase
Bases:
abc.ABC
Base class for a load strategy. Requires implementing checks for compatibility with a given checkpoint version.
- property can_handle_sharded_objects
Returns whether or not this strategy can handle loading ShardedObjects.
- abstract check_backend_compatibility(loaded_version)
- abstract check_version_compatibility(loaded_version)
- class core.dist_checkpointing.strategies.base.SaveCommonStrategy(backend: str, version: int)
Bases:
core.dist_checkpointing.strategies.base.SaveStrategyBase
Save strategy for common (non-sharded) objects
- abstract save(common_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
- class core.dist_checkpointing.strategies.base.SaveShardedStrategy(backend: str, version: int)
Bases:
core.dist_checkpointing.strategies.base.SaveStrategyBase
Save strategy for sharded tensors
- abstract save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
- class core.dist_checkpointing.strategies.base.SaveStrategyBase(backend: str, version: int)
Bases:
abc.ABC
Base class for a save strategy. Requires defining a backend type and version of the saved format.
- property can_handle_sharded_objects
Returns whether or not this strategy can handle saving ShardedObjects.
- class core.dist_checkpointing.strategies.base.StrategyAction(value)
Bases:
enum.Enum
Enumeration of supported strategy actions.
- LOAD_COMMON = 'load_common'
- LOAD_SHARDED = 'load_sharded'
- SAVE_COMMON = 'save_common'
- SAVE_SHARDED = 'save_sharded'
- core.dist_checkpointing.strategies.base.get_default_strategy(action: core.dist_checkpointing.strategies.base.StrategyAction, backend: str, version: int)
Retrieves a default strategy for a given action, backend and version.
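get_default_strategy presumably resolves the strategy from a registry keyed by (action, backend, version). A minimal sketch of such a lookup; the registry contents and error handling here are illustrative, not the actual defaults:

```python
from enum import Enum

class StrategyAction(Enum):
    LOAD_COMMON = 'load_common'
    LOAD_SHARDED = 'load_sharded'
    SAVE_COMMON = 'save_common'
    SAVE_SHARDED = 'save_sharded'

class ZarrLoadShardedStrategy:
    """Placeholder for the real strategy class."""

# Hypothetical registry: (action, backend, version) -> strategy factory.
default_strategies = {
    (StrategyAction.LOAD_SHARDED, 'zarr', 1): ZarrLoadShardedStrategy,
}

def get_default_strategy(action, backend, version):
    try:
        return default_strategies[(action, backend, version)]()
    except KeyError:
        raise ValueError(
            f'No default strategy for {(action, backend, version)}'
        ) from None

strategy = get_default_strategy(StrategyAction.LOAD_SHARDED, 'zarr', 1)
print(type(strategy).__name__)
```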
Strategies using TensorStore to load and save Zarr arrays.
- class core.dist_checkpointing.strategies.tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device: bool = False)
Bases:
core.dist_checkpointing.strategies.base.LoadShardedStrategy
- check_backend_compatibility(loaded_version)
- check_version_compatibility(loaded_version)
- load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
- load_tensors_metadata(checkpoint_dir: pathlib.Path)
Load tensors metadata from the checkpoint.
Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (contrary to the actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors without any sharding (so, the only useful information is tensors global shape and dtype).
- core.dist_checkpointing.strategies.tensorstore.merge_global_slice_with_shape(global_slice, actual_shape, key)
- core.dist_checkpointing.strategies.tensorstore.open_ts_array(arr_path: pathlib.Path)
Opens a Zarr file array with TensorStore with basic settings.
- Parameters
arr_path (Path) – path to a Zarr (Tensorstore) array
2-stage checkpoint loading.
- class core.dist_checkpointing.strategies.two_stage.TwoStageDataParallelLoadShardedStrategy(data_parallel_group, cpu_transfer=True)
Bases:
core.dist_checkpointing.strategies.base.LoadShardedStrategy
Loads one checkpoint replica from storage and broadcasts to other nodes.
This strategy loads the checkpoint from storage on a minimal set of nodes and distributes it to the other nodes with torch.distributed. Loading is performed with TensorStore.
Steps:
0. (optional) create Gloo distributed groups
1. Exchange ShardedTensors metadata between all nodes
2. Align needed tensors within DP groups
3. For each globally unique tensor:
   a) on one of the ranks, load it from storage to CPU and move it to CUDA
   b) allocate a CUDA tensor on the other ranks
   c) broadcast within the DP group
   d) copy the tensor content to the model param location
   e) free the tensor buffers from a) and b)
Notes:
1. Loading and broadcasting are done sequentially to avoid both host and device OOMs.
2. There is a lot of overlap potential between the three steps done for each tensor:
   a) loading from storage to numpy
   b) moving CPU tensors to CUDA
   c) broadcast
- check_backend_compatibility(loaded_version)
- check_version_compatibility(loaded_version)
- deduplicate_chunks(ten_metas: List[core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata])
Group tensors by chunk and then pick the tensor with the lowest rank.
NOTE: with proper loading overlap, loading from randomized ranks (instead of the smallest one) could be beneficial here.
- load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
- load_tensor_from_storage(checkpoint_dir, ten_meta: core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata)
- load_tensors_metadata(checkpoint_dir: pathlib.Path)
Load tensors metadata from the checkpoint.
Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (contrary to the actual sharded state dicts, where keys correspond to state dict keys).
Dict values are ShardedTensors without any sharding (so the only useful information is the tensors' global shape and dtype).
- maybe_init_gloo_group()
- summarize_load_times()
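The deduplication and per-tensor broadcast loop of this strategy can be rendered as a schematic, pure-Python sketch. The storage read and the torch.distributed broadcast are replaced with plain-Python stand-ins, and TensorMeta is a hypothetical stand-in for _ShardedTensorMetadata:

```python
from dataclasses import dataclass
from itertools import groupby

@dataclass(frozen=True)
class TensorMeta:
    chunk_id: str   # stand-in for sharded_tensor_chunk_id(...)
    rank: int       # rank holding this replica

def deduplicate_chunks(ten_metas):
    """Group replicas by chunk id and keep the lowest-rank one,
    mirroring deduplicate_chunks in this strategy."""
    key = lambda m: m.chunk_id
    return [min(group, key=lambda m: m.rank)
            for _, group in groupby(sorted(ten_metas, key=key), key=key)]

metas = [TensorMeta('layer0.weight', 2), TensorMeta('layer0.weight', 0),
         TensorMeta('layer1.weight', 1)]
unique = deduplicate_chunks(metas)

# Schematic per-tensor loop of step 3: tensors are handled one at a
# time, so at most one loaded buffer is alive at once (avoiding OOMs).
world = {0: {}, 1: {}, 2: {}}  # rank -> received tensors
for meta in unique:
    loaded = f'data({meta.chunk_id})'      # 3a) load on the owning rank
    for rank in world:                     # 3b-3d) broadcast + copy
        world[rank][meta.chunk_id] = loaded
    del loaded                             # 3e) free the buffer

print(sorted(world[1]))
```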
- core.dist_checkpointing.strategies.two_stage.sharded_tensor_chunk_id(sharded_tensor: core.dist_checkpointing.mapping.ShardedTensor)
- core.dist_checkpointing.strategies.two_stage.timed(verbose=True)
Strategies using Zarr as an underlying format.
- class core.dist_checkpointing.strategies.zarr.ZarrLoadShardedStrategy
Bases:
core.dist_checkpointing.strategies.base.LoadShardedStrategy
- check_backend_compatibility(loaded_version)
- check_version_compatibility(loaded_version)
- load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
- load_tensors_metadata(checkpoint_dir: pathlib.Path)
Load tensors metadata from the checkpoint.
Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (contrary to the actual sharded state dicts, where keys correspond to state dict keys).
Dict values are ShardedTensors without any sharding (so the only useful information is the tensors' global shape and dtype).
- class core.dist_checkpointing.strategies.zarr.ZarrSaveShardedStrategy(backend: str, version: int)
Bases:
core.dist_checkpointing.strategies.base.SaveShardedStrategy
- save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
- core.dist_checkpointing.strategies.zarr.flatten_range(sharded_tensor, x)
- core.dist_checkpointing.strategies.zarr.load_zarr_based_sharded_metadata(checkpoint_dir: pathlib.Path, get_shape_dtype_fn: Callable[[str], Tuple[Tuple[int], numpy.dtype]]) → Dict[str, Any]
Load metadata of Zarr arrays.
- Parameters
checkpoint_dir (Path) – checkpoint root directory
get_shape_dtype_fn (str -> ((int, ...), np.dtype)) – a function returning an array shape and dtype for a given Zarr array path
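The get_shape_dtype_fn callback contract can be sketched as follows. The function and array names are hypothetical, and a plain string stands in for the numpy dtype the real callback returns, to keep the sketch dependency-free:

```python
from pathlib import Path
from typing import Callable, Dict, Tuple

def fake_get_shape_dtype(arr_path: str) -> Tuple[Tuple[int, ...], str]:
    # Stand-in callback: given a Zarr array path, return (shape, dtype).
    # The shapes here are made up for illustration.
    known = {'embedding.weight': ((1024, 512), 'float32')}
    return known[Path(arr_path).name]

def load_zarr_metadata_sketch(
    checkpoint_dir: Path,
    get_shape_dtype_fn: Callable[[str], Tuple[Tuple[int, ...], str]],
) -> Dict[str, Tuple[Tuple[int, ...], str]]:
    """Simplified: apply the callback to each array path under the
    checkpoint root (array discovery is stubbed out here)."""
    arrays = ['embedding.weight']  # would be discovered from checkpoint_dir
    return {name: get_shape_dtype_fn(str(checkpoint_dir / name))
            for name in arrays}

meta = load_zarr_metadata_sketch(Path('/tmp/ckpt'), fake_get_shape_dtype)
print(meta)
```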
- core.dist_checkpointing.strategies.zarr.pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: core.dist_checkpointing.mapping.ShardedTensor)
- core.dist_checkpointing.strategies.zarr.postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=True)
Various loading and saving strategies