A library for saving and loading distributed checkpoints. A “distributed checkpoint” can have various underlying formats (the current default format is based on Zarr), but it has a distinctive property: a checkpoint saved in one parallel configuration (tensor/pipeline/data parallelism) can be loaded in a different parallel configuration.
Using the library requires defining sharded state_dict dictionaries with functions from the mapping and optimizer modules. Those state dicts can be saved or loaded with the serialization module, using strategies from the strategies module.
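For orientation, a minimal end-to-end sketch is shown below. It assumes torch.distributed is already initialized, the checkpoint directory exists, and the classes and functions are importable from the mapping and serialization modules documented below; the weight name and directory are placeholders.

```python
# Minimal sketch: each rank owns one shard of a 1D-sharded weight.
import torch
from core.dist_checkpointing.mapping import ShardedTensor
from core.dist_checkpointing.serialization import save, load

rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

# Wrap the local tensor with its position in the global tensor:
# fragment number `rank` out of `world_size` fragments along axis 0.
local_weight = torch.randn(16, 32)
sharded_state_dict = {
    'weight': ShardedTensor.from_rank_offsets('weight', local_weight, (0, rank, world_size)),
}

# Collective save: the on-disk layout is independent of the parallel configuration.
save(sharded_state_dict, '/tmp/dist_ckpt')

# Collective load: a (possibly different) parallel configuration describes which
# parts of the global tensors it needs via its own sharded state dict.
loaded_state_dict = load(sharded_state_dict, '/tmp/dist_ckpt')
```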
Entrypoints for saving and loading distributed checkpoints.
The load and save functions are equivalents of torch.load and torch.save, but expect torch.Tensors to be wrapped with classes from the mapping module. Additionally, load expects the sharded state dict argument as guidance for loading the sharded tensors.
- core.dist_checkpointing.serialization.load(sharded_state_dict: Dict[str, Any], checkpoint_dir: str, sharded_strategy: Optional[Union[core.dist_checkpointing.strategies.base.LoadShardedStrategy, Tuple[str, int]]] = None, common_strategy: Optional[Union[core.dist_checkpointing.strategies.base.LoadCommonStrategy, Tuple[str, int]]] = None, validate_access_integrity: bool = True) → Dict[str, Any]
Loading entrypoint.
In the steps below, the following verbs refer to corresponding objects:
- load = load from checkpoint
- extract = extract from sharded_state_dict
- add = add to the final state dict

Steps:
1. Load common state dict and form the base of the result state dict
2. Apply factories to sharded_state_dict
3. Extract LocalNonPersistentObject and add
4. (optional) Extract ShardedObjects, load and add
5. Extract ShardedBase, load, apply factory merges and add
- Parameters
sharded_state_dict (ShardedStateDict) – state dict of the existing model populated with ShardedTensors. Used as a mapping to determine which parts of global tensors stored in the checkpoint should be loaded.
checkpoint_dir (str) – directory with the checkpoint
sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional) – configures loading behavior for sharded tensors
common_strategy (LoadCommonStrategy, Tuple[str, int], optional) – configures loading behavior for common data
validate_access_integrity (bool, default = True) – checks if each tensor shard is accessed exactly once (as the main replica) by some process
- core.dist_checkpointing.serialization.load_common_state_dict(checkpoint_dir: pathlib.Path) → Dict[str, Any]
Load common (non-sharded) objects state dict from the checkpoint.
- Parameters
checkpoint_dir (Path) – checkpoint directory
- Returns
state dict with non-sharded objects from the checkpoint
- Return type
StateDict
- core.dist_checkpointing.serialization.load_plain_tensors(checkpoint_dir: str)
Load checkpoint tensors without any sharding.
NOTE: common state dict is NOT included.
- core.dist_checkpointing.serialization.load_sharded_objects(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
Replaces all ShardedObjects in a given state dict with values loaded from the checkpoint.
- Parameters
sharded_state_dict (ShardedStateDict) – sharded state dict defining what objects should be loaded.
checkpoint_dir (Path) – checkpoint directory
- Returns
state dict is modified in place
- Return type
None
- core.dist_checkpointing.serialization.load_tensors_metadata(checkpoint_dir: str, sharded_strategy: Optional[core.dist_checkpointing.strategies.base.LoadShardedStrategy] = None) → Dict[str, Any]
Load tensors metadata from the checkpoint.
Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (contrary to the actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors without any sharding (so the only useful information is the tensors' global shape and dtype).
Concrete implementation depends on the loading strategy. If no strategy is given, a default for a given backend is used.
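A small sketch of using this entrypoint to inspect a checkpoint without constructing a model (the directory is a placeholder):

```python
# Sketch: list the global shape and dtype of every tensor stored in a checkpoint.
from core.dist_checkpointing.serialization import load_tensors_metadata

metadata = load_tensors_metadata('/tmp/dist_ckpt')
for tensor_key, sharded_tensor in metadata.items():
    # Values carry no data, so only global_shape and dtype are meaningful here.
    print(tensor_key, sharded_tensor.global_shape, sharded_tensor.dtype)
```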
- core.dist_checkpointing.serialization.save(sharded_state_dict: Dict[str, Any], checkpoint_dir: str, sharded_strategy: Optional[Union[core.dist_checkpointing.strategies.base.SaveShardedStrategy, Tuple[str, int]]] = None, common_strategy: Optional[Union[core.dist_checkpointing.strategies.base.SaveCommonStrategy, Tuple[str, int]]] = None, validate_access_integrity: bool = True) → None
Saving entrypoint.
Extracts ShardedTensors from the given state dict. Rank 0 saves the “regular” part of the checkpoint to a common torch file. The ShardedTensors are saved according to a strategy specified by the config.
Steps:
1. Apply factories
2. Extract and discard LocalNonPersistentObject
3. Extract all ShardedBase objects
4. Save all other objects to common.pt
5. (optional) Extract and save ShardedObjects
6. Save all ShardedBase objects
- Parameters
sharded_state_dict (ShardedStateDict) – state dict of the existing model populated with ShardedTensors. Used as a mapping to determine how local tensors should be saved as global tensors in the checkpoint.
checkpoint_dir (str) – directory to save the checkpoint to
sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional) – configures sharded tensors saving behavior and backend
common_strategy (SaveCommonStrategy, Tuple[str, int], optional) – configures common data saving behavior and backend
validate_access_integrity (bool, default = True) – checks if each tensor shard is accessed exactly once (as the main replica) by some process
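A sharded strategy can be passed either as a strategy object or as a (backend, version) tuple. A sketch is shown below; the ('zarr', 1) identifier is an assumption matching the default Zarr-based format mentioned at the top of this page.

```python
# Sketch: save with an explicitly selected sharded backend.
from core.dist_checkpointing.serialization import save

save(
    sharded_state_dict,              # built with ShardedTensor wrappers as above
    '/tmp/dist_ckpt',                # placeholder checkpoint directory
    sharded_strategy=('zarr', 1),    # assumed (backend, version) identifier
    validate_access_integrity=True,
)
```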
- core.dist_checkpointing.serialization.validate_sharding_integrity(sharded_tensors: Iterable[core.dist_checkpointing.mapping.ShardedTensor])
Validate if the ShardedTensors from multiple processes define correct sharding of a global tensor.
Local ShardedTensor metadata is exchanged with torch.distributed.all_gather_object, and then the process with global rank 0 checks if the main replicas of the shards:
- cover the whole global tensors
- don’t overlap
- Parameters
sharded_tensors (Iterable[ShardedTensor]) – sharded tensors local to this process
- Returns
None
- Raises
CheckpointingException – for an invalid access pattern
Core library classes for representing sharding of tensors and objects.
The main expected usage is wrapping torch.Tensors in state dicts with ShardedTensor class (mostly with the ShardedTensor.from_rank_offsets classmethod).
- class core.dist_checkpointing.mapping.LocalNonpersitentObject(obj)
Bases:
object
Object that should not be stored in a checkpoint, but restored locally.
Wrapping any object inside the state dict with LocalNonpersitentObject will result in:
- during saving, this object will not be stored in the checkpoint
- during loading, a local version of this object will be placed in the state dict
- unwrap()
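A short sketch of the behavior described above (the wrapped value is a placeholder for any rank-local object):

```python
# Sketch: rank-local state that should not be written to the checkpoint but
# should survive loading unchanged on each rank.
from core.dist_checkpointing.mapping import LocalNonpersitentObject

local_only_value = {'cuda_rng_state_offset': 0}   # placeholder rank-local object
sharded_state_dict['rng_state'] = LocalNonpersitentObject(local_only_value)

# After `load`, the resulting state dict contains the local object itself
# under 'rng_state'; nothing is read from the checkpoint for this key.
```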
- class core.dist_checkpointing.mapping.ShardedBase
Bases:
abc.ABC
- data: object
- key: str
- replica_id: Union[int, Tuple[int, ...]]
- class core.dist_checkpointing.mapping.ShardedObject(key: str, data: object, global_shape: Tuple[int, ...], global_offset: Tuple[int, ...], replica_id: Union[int, Tuple[int, ...]] = 0)
Bases:
core.dist_checkpointing.mapping.ShardedBase
Represents a mapping between a local object and a global object.
Global object is assumed to consist of many local objects distributed between different processes.
NOTE: Contrary to ShardedTensor, it’s impossible to change global object sharding. Conceptually, ShardedObject is a fully-sharded ShardedTensor with atomic arbitrary typed elements.
- Parameters
key – unique identifier of a global object
data – local object data. Can be None only for consistency validation
global_shape – global object shape
global_offset – offset of a local object in a global object, specified in number of shards
replica_id – indicates local object replication wrt. local objects in different processes
- data: object
- global_offset: Tuple[int, ...]
- global_shape: Tuple[int, ...]
- key: str
- replica_id: Union[int, Tuple[int, ...]] = 0
- property unique_key
- without_data()
- class core.dist_checkpointing.mapping.ShardedTensor(key: str, data: Optional[torch.Tensor], dtype: torch.dtype, local_shape: Tuple[int, ...], global_shape: Tuple[int, ...], global_offset: Tuple[int, ...], axis_fragmentations: Optional[Tuple[int, ...]], replica_id: Union[int, Tuple[int, ...]] = 0, prepend_axis_num: int = 0, allow_shape_mismatch: bool = False, flattened_range: Optional[slice] = None)
Bases:
core.dist_checkpointing.mapping.ShardedBase
Represents a mapping between a local tensor and a global tensor.
Global tensor is assumed to consist of many local tensors distributed between different processes.
- Parameters
key – unique identifier of a global tensor
data – local tensor data. Can be None only for consistency validation
dtype – tensor dtype
local_shape – local tensor shape
global_shape – global tensor shape
global_offset – offset of a local tensor in a global tensor, specified in number of tensor elements
axis_fragmentations – global tensor fragmentation of each axis
replica_id – indicates given local tensor’s replication wrt. local tensors in different processes
prepend_axis_num – number of axes prepended to the local tensor to reflect global tensor shape. The behavior is similar to unsqueezing the local tensor.
allow_shape_mismatch – if True, during loading, the global shape of a stored tensor does not have to match the expected global shape. Useful for representing tensors with flexible shape, e.g. padded.
flattened_range – specifies a slice that should be applied to a flattened tensor with local_shape in order to get the tensor stored as data
- allow_shape_mismatch: bool = False
- axis_fragmentations: Optional[Tuple[int, ...]]
- data: Optional[torch.Tensor]
- dtype: torch.dtype
- flattened_range: Optional[slice] = None
- classmethod from_rank_offsets(key: str, data: torch.Tensor, *rank_offsets: Tuple[int, int, int], replica_id: Union[int, Tuple[int, ...]] = 0, prepend_axis_num: int = 0, allow_shape_mismatch: bool = False)
Allows constructing a ShardedTensor given offsets specified in terms of process ranks (see the sketch after this class listing).
- Parameters
key – unique key
data – local tensor data
rank_offsets – each tuple (axis, axis_rank_offset, axis_fragm) states that if the global tensor is divided into axis_fragm fragments along axis axis, then the local tensor data corresponds to fragment number axis_rank_offset.
replica_id – see ShardedTensor
prepend_axis_num – see ShardedTensor
allow_shape_mismatch – see ShardedTensor
- global_coordinates() → Tuple[numpy.ndarray, ...]
- global_offset: Tuple[int, ...]
- global_shape: Tuple[int, ...]
- global_slice() → Tuple[Union[int, slice], ...]
- init_data(device: torch.device, init_fn=torch.empty)
- key: str
- local_coordinates() → Tuple[numpy.ndarray, ...]
- local_shape: Tuple[int, ...]
- max_allowed_chunks() → Tuple[int, ...]
- prepend_axis_num: int = 0
- replica_id: Union[int, Tuple[int, ...]] = 0
- without_data()
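Below is a sketch of how the rank_offsets tuples and replica_id translate into a sharding layout; tp_rank, tp_size, dp_rank and local_weight are placeholders for values obtained from the application's process groups and model.

```python
# Sketch: a weight split along axis 0 across `tp_size` tensor-parallel ranks
# and replicated across data-parallel ranks.
sharded_weight = ShardedTensor.from_rank_offsets(
    'decoder.layers.0.linear.weight',   # unique global key (illustrative name)
    local_weight,                       # this rank's shard of the global tensor
    (0, tp_rank, tp_size),              # axis 0 has tp_size fragments; this rank
                                        # holds fragment number tp_rank
    replica_id=dp_rank,                 # only dp_rank == 0 acts as the main replica
)
```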
- class core.dist_checkpointing.mapping.ShardedTensorFactory(key: str, data: torch.Tensor, build_fn: Callable[[str, torch.Tensor, Union[int, Tuple[int, ...]]], Dict[str, Any]], merge_fn: Callable[[Dict[str, Any]], torch.Tensor], replica_id: Union[int, Tuple[int, ...]] = 0)
Bases:
core.dist_checkpointing.mapping.ShardedBase
Allows applying transformations to tensors before/after serialization.
The essence of those transformations is that they can be applied to optimizer states the same way they are applied to the model params.
Builder creates a sub-state-dict out of a tensor before saving, and merger merges the corresponding state dict after loading.
- Parameters
key (str) – unique identifier of the factory
data (torch.Tensor) – original model parameter that will be further transformed by this factory
build_fn (callable) – function that transforms the original tensor to a sharded state dict
merge_fn (callable) – function that transforms loaded subtree back into a single tensor (inverse of build_fn)
replica_id (ReplicaId) – indicates factory replication wrt. factories in different processes
- build()
- build_fn: Callable[[str, torch.Tensor, Union[int, Tuple[int, ...]]], Dict[str, Any]]
- data: torch.Tensor
- key: str
- merge_fn: Callable[[Dict[str, Any]], torch.Tensor]
- replica_id: Union[int, Tuple[int, ...]] = 0
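As an illustration, the sketch below defines a factory that splits a fused parameter into two separately keyed ShardedTensors before saving and concatenates them back after loading; the key names, split scheme and fused_weight tensor are illustrative.

```python
# Sketch: save a fused [2N, M] parameter as two sub-state-dict entries.
import torch
from core.dist_checkpointing.mapping import ShardedTensor, ShardedTensorFactory

def build_fn(key, tensor, replica_id):
    # Turn the original tensor into a sub-state-dict of ShardedTensors.
    first, second = torch.chunk(tensor, 2, dim=0)
    return {
        'first': ShardedTensor.from_rank_offsets(f'{key}.first', first, replica_id=replica_id),
        'second': ShardedTensor.from_rank_offsets(f'{key}.second', second, replica_id=replica_id),
    }

def merge_fn(sub_state_dict):
    # Inverse of build_fn: reassemble the fused tensor from the loaded subtree.
    return torch.cat([sub_state_dict['first'], sub_state_dict['second']], dim=0)

factory = ShardedTensorFactory('fused_weight', fused_weight, build_fn, merge_fn)
```

Placing such a factory in the model sharded state dict lets the same transformation be reused for the corresponding optimizer states.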
- core.dist_checkpointing.mapping.apply_factories(sharded_state_dict: Dict[str, Any])
Turn ShardedTensorFactories into ShardedTensors in-place.
- Parameters
sharded_state_dict (ShardedStateDict) – state dict possibly containing ShardedTensorFactory objects
- Returns
state dict is modified in place
- Return type
None
- core.dist_checkpointing.mapping.apply_factory_merges(x1: Dict[str, Any], x2: Dict[str, Any], key: Tuple[str, ...] = ()) → Dict[str, Any]
Apply merges defined by ShardedTensorFactories in-place.
- Parameters
x1 (StateDict) – state dict loaded from the checkpoint
x2 (ShardedStateDict) – subset of x1 (in terms of dict keys) with ShardedTensorFactory as (possibly nested) values that define how to merge objects from the x1 state dict
key (Tuple[str, ...]) – current key in a recursive call. Used only for reporting meaningful errors
- Returns
x1 modified in-place
- Return type
StateDict
- core.dist_checkpointing.mapping.is_main_replica(replica_id: Union[int, Tuple[int, ...]])
Checks if the given replica_id is considered the main replica.
A “main” replica is:
- the integer 0
- or an iterable with all elements equal to 0
It is the application’s responsibility to set correct replica ids for sharded tensors.
- Parameters
replica_id (Union[int, Tuple[int, ...]]) – replica id
- Returns
True for a “main” replica
- Return type
bool
Helpers for defining sharding for optimizer states based on existing sharding for model parameters.
- core.dist_checkpointing.optimizer.get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) → Dict[int, int]
- core.dist_checkpointing.optimizer.get_param_id_to_sharded_param_map(model_sharded_state_dict: Dict[str, Any], optim_params_iter: Iterable[torch.nn.Parameter]) → Dict[int, Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedTensorFactory]]
Generate mapping from optimizer state ids to model sharded parameters.
- Parameters
model_sharded_state_dict – sharded state dict with all model sharded tensors (can have any structure)
optim_params_iter – iterable over the model parameters tracked by the optimizer. The iteration must be in the same order as the optimizer's parameters.
- Returns
mapping from optimizer state ids to model sharded parameters
- Return type
Dict[int, Union[ShardedTensor, ShardedTensorFactory]]
- core.dist_checkpointing.optimizer.make_sharded_optimizer_tensor(model_param: Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedTensorFactory], optim_param: torch.Tensor, prefix: str) → Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedTensorFactory]
Build a ShardedTensor or ShardedTensorFactory for an optimizer param based on the corresponding model param.
- Parameters
model_param (Union[ShardedTensor, ShardedTensorFactory]) – model param
optim_param (torch.Tensor) – corresponding optimizer param
prefix (str) – optimizer prefix for the ShardedTensor or ShardedTensorFactory
- Returns
wrapped optimizer parameter
- Return type
Union[ShardedTensor, ShardedTensorFactory]
- core.dist_checkpointing.optimizer.optim_state_to_sharding_state(optim_state_dict: Dict[str, Any], id_to_sharded_param_map: Dict[int, core.dist_checkpointing.mapping.ShardedTensor], exclude_keys: Tuple[str] = ())
Turn an optimizer state dict into a sharded state dict based on the model state dict, in-place.
Can be used to add sharding information to the most common optimizer state dict layout. Creates separate ShardedTensors for each key in optim_state_dict[‘state’] (e.g. for torch.optim.Adam there will be separate tensors for exp_avg and exp_avg_sq).
- Parameters
optim_state_dict (StateDict) – optimizer state dict with state parameters under state key and group hyperparameters under param_groups -> params key.
id_to_sharded_param_map (Dict[int, ShardedTensor]) – mapping from optimizer param ids to model sharded tensors. Can be generated with get_param_id_to_sharded_param_map function
exclude_keys (Tuple[str]) – optimizer state keys to exclude from the final state dict.
- Returns
state dict is modified in place
- Return type
None
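A sketch of the typical flow for adding sharding information to a plain torch optimizer state dict; `model.sharded_state_dict()` is a placeholder for however the application builds its model sharded state dict.

```python
# Sketch: reuse the model's sharding for the optimizer state (e.g. Adam moments).
from core.dist_checkpointing.optimizer import (
    get_param_id_to_sharded_param_map,
    optim_state_to_sharding_state,
)

model_sharded_sd = model.sharded_state_dict()   # placeholder, application-provided
optim_state_dict = optimizer.state_dict()       # plain torch optimizer state dict

# Map optimizer param ids to the model's ShardedTensors / factories; the iterable
# must follow the same parameter order the optimizer was constructed with.
id_to_sharded_param_map = get_param_id_to_sharded_param_map(
    model_sharded_sd,
    (p for group in optimizer.param_groups for p in group['params']),
)

# Replace plain state tensors (exp_avg, exp_avg_sq, ...) with ShardedTensors in-place.
optim_state_to_sharding_state(optim_state_dict, id_to_sharded_param_map)
```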
Module for managing distributed checkpoints metadata.
- class core.dist_checkpointing.core.CheckpointingConfig(sharded_backend: str, sharded_backend_version: int = 1, common_backend: str = 'torch', common_backend_version: int = 1)
Bases:
object
Documents backends used in the checkpoint.
Checkpoint config keeps track of formats used for storing the sharded tensors (sharded_backend) and other objects (common_backend).
Note that versioning is not for the checkpoint content (which is application specific), but for the checkpoint format itself.
- common_backend: str = 'torch'
- common_backend_version: int = 1
- sharded_backend: str
- sharded_backend_version: int = 1
- exception core.dist_checkpointing.core.CheckpointingException
Bases: Exception
Base checkpointing-related exception.
- core.dist_checkpointing.core.check_is_distributed_checkpoint(checkpoint_dir)
Checks if metadata.json exists in the checkpoint and is a valid config.
- Parameters
checkpoint_dir – checkpoint directory
- Returns
True if metadata.json exists in the checkpoint and is a valid config
- Return type
bool
- core.dist_checkpointing.core.maybe_load_config(checkpoint_dir: str) → Optional[core.dist_checkpointing.core.CheckpointingConfig]
Returns the checkpoint config if checkpoint_dir is a distributed checkpoint, and None otherwise.
- Parameters
checkpoint_dir – checkpoint directory
- Returns
the checkpoint config, or None if the checkpoint is not a valid distributed checkpoint
- Return type
CheckpointingConfig (optional)
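A short sketch of inspecting a directory with these helpers (the path is a placeholder):

```python
# Sketch: detect a distributed checkpoint and read its backend configuration.
from core.dist_checkpointing.core import check_is_distributed_checkpoint, maybe_load_config

if check_is_distributed_checkpoint('/tmp/dist_ckpt'):
    config = maybe_load_config('/tmp/dist_ckpt')
    print(config.sharded_backend, config.sharded_backend_version)
```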
- core.dist_checkpointing.core.save_config(config: core.dist_checkpointing.core.CheckpointingConfig, checkpoint_dir: str)
Save given config to checkpoint directory.
- Parameters
config – checkpoint config
checkpoint_dir – checkpoint directory
- Returns
None
Utilities for operating with dicts and lists.
All functions in this module handle nesting of dicts and lists. Other objects (e.g. tuples) are treated as atomic leaf types that cannot be traversed.
- core.dist_checkpointing.dict_utils.dict_list_map_inplace(f: Callable, x: Union[dict, list])
Maps dicts and lists in-place with a given function.
- core.dist_checkpointing.dict_utils.dict_list_map_outplace(f: Callable, x: Union[dict, list])
Maps dicts and lists out-of-place with a given function.
- core.dist_checkpointing.dict_utils.dict_map(f: Callable, d: dict)
map equivalent for dicts.
- core.dist_checkpointing.dict_utils.dict_map_with_key(f: Callable, d: dict)
map equivalent for dicts with a function that accepts tuple (key, value).
- core.dist_checkpointing.dict_utils.diff(x1: Any, x2: Any, prefix: Tuple = ()) → Tuple[list, list, list]
Recursive diff of dicts.
- Parameters
x1 (object) – left dict
x2 (object) – right dict
prefix (tuple) – tracks recursive calls. Used for reporting differing keys.
- Returns
tuple of:
only_left – prefixes present only in the left dict
only_right – prefixes present only in the right dict
mismatch – values present in both dicts but not equal across dicts. For tensors, equality of all elements is checked. Each element is a tuple (prefix, type of left value, type of right value).
- Return type
Tuple[list, list, list]
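For illustration, a minimal sketch of comparing two nested state dicts (both arguments are placeholders):

```python
# Sketch: report keys present on only one side and values that differ.
from core.dist_checkpointing.dict_utils import diff

only_left, only_right, mismatch = diff(state_dict_a, state_dict_b)
print('missing on the right:', only_left)
print('missing on the left:', only_right)
print('differing values:', mismatch)
```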
- core.dist_checkpointing.dict_utils.extract_matching_values(x: Union[dict, list], predicate: Callable[[Any], bool], return_lists_as_dicts: bool = False) → Tuple[Union[dict, list], Union[dict, list]]
Return matching and nonmatching values. Keeps hierarchy.
- Parameters
x (Union[dict, list]) – state dict to process. Top-level argument must be a dict or list
predicate (object -> bool) – determines matching values
return_lists_as_dicts (bool) – if True, matching lists will be turned into dicts, with keys indicating the indices of original elements. Useful for reconstructing the original hierarchy.
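A sketch of splitting a nested state dict by a predicate (the input state dict is a placeholder):

```python
# Sketch: split a nested state dict into ShardedTensor values and everything else,
# preserving the original hierarchy in both outputs.
from core.dist_checkpointing.dict_utils import extract_matching_values
from core.dist_checkpointing.mapping import ShardedTensor

sharded_part, rest = extract_matching_values(
    sharded_state_dict,
    lambda v: isinstance(v, ShardedTensor),
)
```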
- core.dist_checkpointing.dict_utils.inspect_types(x: Any, prefix: Tuple = (), indent: int = 4)
Helper to print types of (nested) dict values.
- core.dist_checkpointing.dict_utils.map_reduce(xs: typing.Iterable, key_fn: typing.Callable = <function <lambda>>, value_fn: typing.Callable = <function <lambda>>, reduce_fn: typing.Callable = <function <lambda>>) → dict
Simple map-reduce implementation following more_itertools.map_reduce interface.
- core.dist_checkpointing.dict_utils.merge(x1: dict, x2: dict, key: Tuple[str, ...] = ())
Merges dicts and lists recursively.
- core.dist_checkpointing.dict_utils.nested_items_iter(x: Union[dict, list])
Returns iterator over (nested) tuples (container, key, value) of a given dict or list.
- core.dist_checkpointing.dict_utils.nested_values(x: Union[dict, list])
Returns iterator over (nested) values of a given dict or list.
Helpers for manipulating sharded tensors and sharded state dicts.
- core.dist_checkpointing.utils.add_prefix_for_sharding(sharded_state_dict: Dict[str, Any], prefix: str)
Prepend a given prefix to all ShardedBase objects in a given state dict in-place.
- Parameters
sharded_state_dict (ShardedStateDict) – sharded state dict
prefix (str) – prefix to be prepended
- Returns
state dict is modified in-place
- Return type
None
- core.dist_checkpointing.utils.apply_prefix_mapping(sharded_state_dict: Dict[str, Any], prefix_map: Dict[str, str])
Replaces prefixes only in keys matching one of the prefixes in the map.
- Parameters
sharded_state_dict (ShardedStateDict) – sharded state dict to replace keys in
prefix_map (Dict[str, str]) – map of old->new prefixes. The first matching prefix for each key is used
- Returns
state dict is modified in place
- Return type
None
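A sketch of remapping key prefixes, e.g. when a submodule is relocated inside a larger model (the prefixes are illustrative):

```python
# Sketch: rename sharded keys saved under `encoder.` so they load into a model
# that exposes the same weights under `model.encoder.`.
from core.dist_checkpointing.utils import apply_prefix_mapping

apply_prefix_mapping(sharded_state_dict, {'encoder.': 'model.encoder.'})
```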
- core.dist_checkpointing.utils.extract_nonpersistent(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]
- core.dist_checkpointing.utils.extract_sharded_base(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]
- core.dist_checkpointing.utils.extract_sharded_tensors(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]
Extract a dict consisting of only ShardedTensor objects from a given state dict with any objects.
- Parameters
sharded_state_dict – state dict possibly containing ShardedTensor objects
- Returns
tuple of:
state dict with all ShardedTensors (keeping the original state dict structure)
state dict with all objects other than ShardedTensors (keeping the original state dict structure)
- Return type
Tuple[ShardedStateDict, StateDict]
- core.dist_checkpointing.utils.extract_sharded_tensors_and_factories(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]
Extract a dict consisting of only ShardedTensor and ShardedTensorFactory objects from a given state dict with any objects.
- Parameters
sharded_state_dict – state dict possibly containing ShardedTensor and ShardedTensorFactory objects
- Returns
tuple of:
state dict with all ShardedTensor and ShardedTensorFactory objects (keeping the original state dict structure)
state dict with all other objects (keeping the original state dict structure)
- Return type
Tuple[ShardedStateDict, StateDict]
- core.dist_checkpointing.utils.extract_sharded_tensors_or_nonpersistent(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]
Extract a dict consisting of only ShardedTensor, ShardedTensorFactory and LocalNonpersitentObject objects from a given state dict with any objects.
- Parameters
sharded_state_dict – state dict possibly containing ShardedTensor, ShardedTensorFactory and LocalNonpersitentObject objects
- Returns
tuple of:
state dict with all ShardedTensor, ShardedTensorFactory and LocalNonpersitentObject objects (keeping the original state dict structure)
state dict with all other objects (keeping the original state dict structure)
- Return type
Tuple[ShardedStateDict, StateDict]
- core.dist_checkpointing.utils.replace_prefix_for_sharding(sharded_state_dict: Dict[str, Any], old_prefix: str, new_prefix: str)
Replaces the given prefix in all sharded keys in a given state dict.
Errors out if some key does not begin with the given prefix.
- Parameters
sharded_state_dict (ShardedStateDict) – sharded state dict to replace keys in
old_prefix (str) – prefix to be replaced in each key
new_prefix (str) – new prefix
- Returns
state dict is modified in place
- Return type
None