dist_checkpointing.strategies package

Package defining different checkpoint formats (backends) and saving/loading algorithms (strategies).

Strategies can be used to implement new checkpoint formats, or new ways of saving/loading existing formats that are better suited to a given use case. Strategies are passed to the dist_checkpointing.load and dist_checkpointing.save functions and control the actual saving/loading procedure.
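
Example (illustrative sketch, not the canonical API): strategy objects are constructed by the caller and handed to the top-level entry points. The sharded_strategy keyword and the shortened core.dist_checkpointing import path mirror this page's naming and are assumptions; sharded_state_dict stands for a mapping produced by the model.

    from pathlib import Path

    from core import dist_checkpointing


    def roundtrip(sharded_state_dict, ckpt_dir: Path, save_strategy, load_strategy):
        # Strategies fully control how tensors reach storage; save/load only
        # orchestrate. `sharded_state_dict` is assumed to come from the model.
        dist_checkpointing.save(sharded_state_dict, ckpt_dir,
                                sharded_strategy=save_strategy)
        return dist_checkpointing.load(sharded_state_dict, ckpt_dir,
                                       sharded_strategy=load_strategy)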

Strategies base interfaces.

class core.dist_checkpointing.strategies.base.LoadCommonStrategy

Bases: core.dist_checkpointing.strategies.base.LoadStrategyBase

Load strategy for common (non-sharded) objects

abstract load(checkpoint_dir: pathlib.Path)

class core.dist_checkpointing.strategies.base.LoadShardedStrategy

Bases: core.dist_checkpointing.strategies.base.LoadStrategyBase

Load strategy for sharded tensors

abstract load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

abstract load_tensors_metadata(checkpoint_dir: pathlib.Path)

Load tensors metadata from the checkpoint.

Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (unlike in actual sharded state dicts, where keys correspond to state dict keys).

Dict values are ShardedTensors without any sharding (so the only useful information is each tensor's global shape and dtype).
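
Example (hedged sketch): inspecting checkpoint metadata without reading tensor data, using the Zarr strategy documented below as one concrete LoadShardedStrategy. The global_shape and dtype attribute names follow ShardedTensor's role described here and should be treated as assumptions.

    from pathlib import Path

    from core.dist_checkpointing.strategies.zarr import ZarrLoadShardedStrategy

    strategy = ZarrLoadShardedStrategy()
    metadata = strategy.load_tensors_metadata(Path('/tmp/model_ckpt'))
    for key, sharded_ten in metadata.items():
        # Only the global shape and dtype are meaningful; there is no sharding.
        print(key, sharded_ten.global_shape, sharded_ten.dtype)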

class core.dist_checkpointing.strategies.base.LoadStrategyBase

Bases: abc.ABC

Base class for a load strategy. Requires implementing checks for compatibility with a given checkpoint version.

property can_handle_sharded_objects

Returns whether this strategy can handle loading ShardedObjects.

abstract check_backend_compatibility(loaded_version)

abstract check_version_compatibility(loaded_version)
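
A custom load strategy must implement both compatibility checks in addition to the load entry points of the chosen subclass. A hypothetical sketch (the torch-serialization body is purely illustrative, not how the built-in strategies work):

    from pathlib import Path
    from typing import Any, Dict

    import torch

    from core.dist_checkpointing.strategies.base import LoadShardedStrategy


    class NaiveTorchLoadStrategy(LoadShardedStrategy):
        """Hypothetical strategy reading a single torch-serialized file."""

        def check_backend_compatibility(self, loaded_version):
            pass  # this sketch accepts any backend

        def check_version_compatibility(self, loaded_version):
            pass  # this sketch accepts any version

        def load(self, sharded_state_dict: Dict[str, Any], checkpoint_dir: Path):
            return torch.load(checkpoint_dir / 'naive.pt')

        def load_tensors_metadata(self, checkpoint_dir: Path):
            raise NotImplementedError('metadata-only loading not supported here')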

class core.dist_checkpointing.strategies.base.SaveCommonStrategy(backend: str, version: int)

Bases: core.dist_checkpointing.strategies.base.SaveStrategyBase

Save strategy for common (non-sharded) objects

abstract save(common_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

class core.dist_checkpointing.strategies.base.SaveShardedStrategy(backend: str, version: int)

Bases: core.dist_checkpointing.strategies.base.SaveStrategyBase

Save strategy for sharded tensors

abstract save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

class core.dist_checkpointing.strategies.base.SaveStrategyBase(backend: str, version: int)

Bases: abc.ABC

Base class for a save strategy. Requires defining a backend type and version of the saved format.

property can_handle_sharded_objects

Returns whether this strategy can handle saving ShardedObjects.
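
The backend and version passed to the constructor identify the on-disk format so that a matching load strategy can be selected later. A hypothetical counterpart to the load sketch above:

    from pathlib import Path
    from typing import Any, Dict

    import torch

    from core.dist_checkpointing.strategies.base import SaveShardedStrategy


    class NaiveTorchSaveStrategy(SaveShardedStrategy):
        """Hypothetical strategy writing a single torch-serialized file."""

        def __init__(self):
            super().__init__(backend='naive_torch', version=1)

        def save(self, sharded_state_dict: Dict[str, Any], checkpoint_dir: Path):
            # A real strategy would write each ShardedTensor shard separately;
            # this sketch dumps the whole dict at once.
            torch.save(sharded_state_dict, checkpoint_dir / 'naive.pt')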

class core.dist_checkpointing.strategies.base.StrategyAction(value)

Bases: enum.Enum

Enumeration of strategy actions: loading or saving the common (non-sharded) or sharded part of a checkpoint.

LOAD_COMMON = 'load_common'

LOAD_SHARDED = 'load_sharded'

SAVE_COMMON = 'save_common'

SAVE_SHARDED = 'save_sharded'

core.dist_checkpointing.strategies.base.get_default_strategy(action: core.dist_checkpointing.strategies.base.StrategyAction, backend: str, version: int)

Retrieves a default strategy for a given action, backend and version.
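
Example (sketch): the default-strategy registry is keyed by action, backend name, and format version. The ('zarr', 1) pair is assumed to be registered, based on the Zarr and TensorStore strategies documented on this page.

    from core.dist_checkpointing.strategies.base import (
        StrategyAction,
        get_default_strategy,
    )

    save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, 'zarr', 1)
    load_strategy = get_default_strategy(StrategyAction.LOAD_SHARDED, 'zarr', 1)
    print(type(save_strategy).__name__, type(load_strategy).__name__)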

Strategies using TensorStore to load and save Zarr arrays.

class core.dist_checkpointing.strategies.tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device: bool = False)

Bases: core.dist_checkpointing.strategies.base.LoadShardedStrategy

check_backend_compatibility(loaded_version)

check_version_compatibility(loaded_version)

load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

load_tensors_metadata(checkpoint_dir: pathlib.Path)

Load tensors metadata from the checkpoint.

Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (unlike in actual sharded state dicts, where keys correspond to state dict keys).

Dict values are ShardedTensors without any sharding (so the only useful information is each tensor's global shape and dtype).
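
Example (sketch): the load_directly_on_device flag requests that loaded chunks be written straight into the device buffers of the sharded state dict, as the flag name suggests. Wiring the strategy into dist_checkpointing.load follows the assumed pattern from the package overview.

    from pathlib import Path

    from core import dist_checkpointing
    from core.dist_checkpointing.strategies.tensorstore import (
        TensorStoreLoadShardedStrategy,
    )


    def load_with_tensorstore(sharded_state_dict, ckpt_dir: Path):
        strategy = TensorStoreLoadShardedStrategy(load_directly_on_device=True)
        return dist_checkpointing.load(sharded_state_dict, ckpt_dir,
                                       sharded_strategy=strategy)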

core.dist_checkpointing.strategies.tensorstore.merge_global_slice_with_shape(global_slice, actual_shape, key)

core.dist_checkpointing.strategies.tensorstore.open_ts_array(arr_path: pathlib.Path)

Opens a Zarr array with TensorStore using basic settings.

Parameters

arr_path (Path) – path to a Zarr (TensorStore) array
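
Example (sketch): reading an array through the returned handle. The read()/result() calls are standard TensorStore API; treating the return value as an opened TensorStore array follows from the description above. The array path is hypothetical.

    from pathlib import Path

    from core.dist_checkpointing.strategies.tensorstore import open_ts_array

    arr = open_ts_array(Path('/tmp/model_ckpt/embedding.weight'))
    print(arr.shape, arr.dtype)   # global shape and dtype
    data = arr.read().result()    # materialize the whole array as numpy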

2-stage checkpoint loading.

class core.dist_checkpointing.strategies.two_stage.TwoStageDataParallelLoadShardedStrategy(data_parallel_group, cpu_transfer=True)

Bases: core.dist_checkpointing.strategies.base.LoadShardedStrategy

Loads one checkpoint replica from storage and broadcasts it to other nodes.

This strategy loads the checkpoint from storage on a minimal set of nodes and distributes it to the other nodes with torch.distributed. Loading is performed with TensorStore.

Steps:

0. (optional) create Gloo distributed groups
1. Exchange ShardedTensors metadata between all nodes
2. Align needed tensors within DP groups
3. For each globally unique tensor:
   a. on one of the ranks load it from storage to CPU and move to CUDA
   b. allocate CUDA tensor on other ranks
   c. broadcast within DP group
   d. copy tensor content to the model param location
   e. free tensor buffers from a) and b)

Notes:

1. Loading and broadcasting are done sequentially to avoid both host and device OOMs
2. There is a lot of overlap potential between the three steps done for each tensor:
   a. loading from storage to numpy
   b. moving CPU tensors to CUDA
   c. broadcast
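
Example (sketch): constructing the strategy with a data-parallel process group. In Megatron-style code the DP group would come from the parallel-state utilities; here it is passed in directly.

    from pathlib import Path

    from core import dist_checkpointing
    from core.dist_checkpointing.strategies.two_stage import (
        TwoStageDataParallelLoadShardedStrategy,
    )


    def two_stage_load(sharded_state_dict, ckpt_dir: Path, dp_group):
        # cpu_transfer=True is assumed to broadcast via CPU (Gloo) tensors;
        # pass False to keep the broadcast on CUDA.
        strategy = TwoStageDataParallelLoadShardedStrategy(
            data_parallel_group=dp_group, cpu_transfer=True
        )
        return dist_checkpointing.load(sharded_state_dict, ckpt_dir,
                                       sharded_strategy=strategy)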

check_backend_compatibility(loaded_version)

check_version_compatibility(loaded_version)

deduplicate_chunks(ten_metas: List[core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata])

Group tensors by chunk and then pick the tensor with the lowest rank.

NOTE: with proper loading overlap, loading from randomized ranks (instead of the smallest one) could be beneficial here.

load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

load_tensor_from_storage(checkpoint_dir, ten_meta: core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata)

load_tensors_metadata(checkpoint_dir: pathlib.Path)

Load tensors metadata from the checkpoint.

Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (unlike in actual sharded state dicts, where keys correspond to state dict keys).

Dict values are ShardedTensors without any sharding (so the only useful information is each tensor's global shape and dtype).

maybe_init_gloo_group()

summarize_load_times()

core.dist_checkpointing.strategies.two_stage.sharded_tensor_chunk_id(sharded_tensor: core.dist_checkpointing.mapping.ShardedTensor)

core.dist_checkpointing.strategies.two_stage.timed(verbose=True)

Strategies using Zarr as an underlying format.

class core.dist_checkpointing.strategies.zarr.ZarrLoadShardedStrategy

Bases: core.dist_checkpointing.strategies.base.LoadShardedStrategy

check_backend_compatibility(loaded_version)

check_version_compatibility(loaded_version)

load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

load_tensors_metadata(checkpoint_dir: pathlib.Path)

Load tensors metadata from the checkpoint.

Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (unlike in actual sharded state dicts, where keys correspond to state dict keys).

Dict values are ShardedTensors without any sharding (so the only useful information is each tensor's global shape and dtype).

class core.dist_checkpointing.strategies.zarr.ZarrSaveShardedStrategy(backend: str, version: int)

Bases: core.dist_checkpointing.strategies.base.SaveShardedStrategy

save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
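
Example (sketch): selecting the Zarr backend explicitly instead of relying on get_default_strategy. The ('zarr', 1) backend/version pair is an assumption consistent with the rest of this page.

    from pathlib import Path

    from core import dist_checkpointing
    from core.dist_checkpointing.strategies.zarr import ZarrSaveShardedStrategy


    def save_with_zarr(sharded_state_dict, ckpt_dir: Path):
        strategy = ZarrSaveShardedStrategy(backend='zarr', version=1)
        dist_checkpointing.save(sharded_state_dict, ckpt_dir,
                                sharded_strategy=strategy)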

core.dist_checkpointing.strategies.zarr.flatten_range(sharded_tensor, x)

core.dist_checkpointing.strategies.zarr.load_zarr_based_sharded_metadata(checkpoint_dir: pathlib.Path, get_shape_dtype_fn: Callable[[str], Tuple[Tuple[int], numpy.dtype]]) → Dict[str, Any]

Load metadata of Zarr arrays.

Parameters
  • checkpoint_dir (Path) – checkpoint root directory

  • get_shape_dtype_fn (str -> ((int, ...), np.dtype)) – a function returning an array shape and dtype for a given Zarr array path
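
Example (sketch): a get_shape_dtype_fn built on the zarr package. Probing shape and dtype with zarr.open is an assumption; any equivalent probe matching the str -> ((int, ...), np.dtype) contract works.

    from pathlib import Path
    from typing import Tuple

    import numpy as np
    import zarr

    from core.dist_checkpointing.strategies.zarr import (
        load_zarr_based_sharded_metadata,
    )


    def zarr_shape_dtype(arr_path: str) -> Tuple[Tuple[int, ...], np.dtype]:
        arr = zarr.open(arr_path, mode='r')  # read-only open of one array
        return arr.shape, arr.dtype


    metadata = load_zarr_based_sharded_metadata(Path('/tmp/model_ckpt'),
                                                zarr_shape_dtype)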

core.dist_checkpointing.strategies.zarr.pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: core.dist_checkpointing.mapping.ShardedTensor)

core.dist_checkpointing.strategies.zarr.postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=True)

