dist_checkpointing package
A library for saving and loading distributed checkpoints. A "distributed checkpoint" can have various underlying formats (the current default format is based on the torch_dist backend) but has one distinctive property: a checkpoint saved in one parallel configuration (tensor/pipeline/data parallelism) can be loaded in a different parallel configuration.
Using the library requires defining sharded state_dict dictionaries with functions from the mapping and optimizer modules. Those state dicts can be saved or loaded with the serialization module, using strategies from the strategies module.
Entrypoints for saving and loading distributed checkpoints.
The load and save functions are equivalents of torch.load and torch.save, but expect torch.Tensors to be wrapped with classes from the mapping module. Additionally, load expects the sharded state dict argument as guidance for loading the sharded tensors.
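A minimal round trip might look like the following sketch. It assumes the package is importable under the core.dist_checkpointing paths used in the signatures below, that torch.distributed is already initialized, and that the checkpoint directory exists.

```python
import torch
import torch.distributed as dist
from core.dist_checkpointing.serialization import save, load
from core.dist_checkpointing.mapping import ShardedTensor

rank, world_size = dist.get_rank(), dist.get_world_size()

# Each rank owns one row-fragment of a (world_size * 4, 8) global tensor:
# axis 0 is split into world_size fragments and this rank holds fragment `rank`.
local_weight = torch.randn(4, 8)
sharded_state_dict = {
    'weight': ShardedTensor.from_rank_offsets(
        'weight', local_weight, (0, rank, world_size)
    ),
}
save(sharded_state_dict, '/tmp/dist_ckpt')  # directory assumed to exist and be empty

# Loading uses a sharded state dict (possibly built under a different
# parallel configuration) as guidance for which shards each rank reads.
state_dict = load(sharded_state_dict, '/tmp/dist_ckpt')
```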
- core.dist_checkpointing.serialization.get_default_load_sharded_strategy(checkpoint_dir: str) → core.dist_checkpointing.strategies.base.LoadShardedStrategy
Get default load sharded strategy.
- core.dist_checkpointing.serialization.get_default_save_common_strategy(backend: str = 'torch', version: int = 1) → core.dist_checkpointing.strategies.base.SaveCommonStrategy
Get default save common strategy.
- core.dist_checkpointing.serialization.get_default_save_sharded_strategy(backend: str = 'torch_dist', version: int = 1) → core.dist_checkpointing.strategies.base.SaveShardedStrategy
Get default save sharded strategy.
- core.dist_checkpointing.serialization.load(sharded_state_dict: Dict[str, Any], checkpoint_dir: str, sharded_strategy: Optional[Union[core.dist_checkpointing.strategies.base.LoadShardedStrategy, Tuple[str, int]]] = None, common_strategy: Optional[Union[core.dist_checkpointing.strategies.base.LoadCommonStrategy, Tuple[str, int]]] = None, validate_access_integrity: bool = True, strict: Union[str, core.dist_checkpointing.validation.StrictHandling] = StrictHandling.ASSUME_OK_UNEXPECTED) → Union[Dict[str, Any], Tuple[Dict[str, Any], Set[str], Set[str]]]
Loading entrypoint.
In the steps below, the following verbs refer to the corresponding objects:
- load = load from the checkpoint
- extract = extract from sharded_state_dict
- add = add to the final state dict

Steps:
1. Load the common state dict and form the base of the result state dict
2. Apply factories to sharded_state_dict
3. Extract LocalNonpersistentObject and add
4. (optional) Extract ShardedObjects, load and add
5. Extract ShardedBase, load, apply factory merges and add
- Parameters
sharded_state_dict (ShardedStateDict) – state dict of the existing model populated with ShardedTensors. Used as a mapping to determine which parts of global tensors stored in the checkpoint should be loaded.
checkpoint_dir (str) – directory with the checkpoint
sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional) – configures loading behavior for sharded tensors
common_strategy (LoadCommonStrategy, Tuple[str, int], optional) – configures loading behavior for common data
validate_access_integrity (bool, default = True) – checks if each tensor shard is accessed exactly once (as the main replica) by some process
strict (StrictHandling, str, optional) – determines the behavior in case of a mismatch between the requested sharded state dict and the checkpoint. See the StrictHandling docs for more details. Some values affect the return value of this function (missing and unexpected keys are returned). Defaults to True (StrictHandling.ASSUME_OK_UNEXPECTED), which doesn't incur any performance overhead. Other recommended values are False (StrictHandling.LOG_UNEXPECTED), which logs only unexpected keys, or StrictHandling.RETURN_ALL, which returns all mismatched keys.
- Returns
in most cases only the loaded state dict is returned. If the strict flag was set to StrictHandling.RETURN_ALL, a tuple of the loaded state dict and the sets of missing and unexpected keys is returned.
- Return type
StateDict or Tuple[StateDict, Set[str], Set[str]]
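As a short illustration of the strict modes (a sketch, reusing the sharded_state_dict from the example above):

```python
from core.dist_checkpointing.serialization import load
from core.dist_checkpointing.validation import StrictHandling

# With RETURN_ALL the entrypoint also returns the mismatched keys.
state_dict, missing_keys, unexpected_keys = load(
    sharded_state_dict, '/tmp/dist_ckpt', strict=StrictHandling.RETURN_ALL
)
```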
- core.dist_checkpointing.serialization.load_common_state_dict(checkpoint_dir: pathlib.Path) → Dict[str, Any]
Load common (non-sharded) objects state dict from the checkpoint.
- Parameters
checkpoint_dir (Path) – checkpoint directory
- Returns
state dict with non-sharded objects from the checkpoint
- Return type
StateDict
- core.dist_checkpointing.serialization.load_plain_tensors(checkpoint_dir: str) → Dict[str, Any]
Load checkpoint tensors without any sharding and with a plain structure.
NOTE: common state dict is NOT included.
- Parameters
checkpoint_dir (str) – checkpoint directory to load the tensors from.
- Returns
checkpoint state dict containing only torch.Tensors.
- Return type
StateDict
- core.dist_checkpointing.serialization.load_sharded_metadata(checkpoint_dir: str, sharded_strategy: Optional[core.dist_checkpointing.strategies.base.LoadShardedStrategy] = None, common_strategy: Optional[core.dist_checkpointing.strategies.base.LoadCommonStrategy] = None) → Dict[str, Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedObject]]
Load sharded metadata from the checkpoint.
Similar to load_tensors_metadata, but also includes ShardedObjects.
Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (unlike in actual sharded state dicts, where keys correspond to state dict keys).
Dict values are ShardedTensors without any sharding (so the only useful information is the tensors' global shape and dtype).
Concrete implementation depends on the loading strategy. If no strategy is given, a default for a given backend is used.
- Parameters
checkpoint_dir (str) – checkpoint directory to load from
sharded_strategy (LoadShardedStrategy, optional) – sharded strategy to load metadata. Defaults to None - in this case a default load strategy for a given checkpoint type is used.
common_strategy (LoadCommonStrategy, optional) – common strategy to load metadata. Defaults to None - in this case a default load strategy for a given checkpoint type is used. This strategy won’t be used unless sharded_strategy can’t handle ShardedObjects
- Returns
flat state dict without data, describing the ShardedTensors and ShardedObjects present in the checkpoint
- Return type
CkptShardedMetadata
- core.dist_checkpointing.serialization.load_tensors_metadata(checkpoint_dir: str, sharded_strategy: Optional[core.dist_checkpointing.strategies.base.LoadShardedStrategy] = None) → Dict[str, Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedObject]]
Load tensors metadata from the checkpoint.
Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (unlike in actual sharded state dicts, where keys correspond to state dict keys).
Dict values are ShardedTensors without any sharding (so the only useful information is the tensors' global shape and dtype).
Concrete implementation depends on the loading strategy. If no strategy is given, a default for a given backend is used.
- Parameters
checkpoint_dir (str) – checkpoint directory to load from
sharded_strategy (LoadShardedStrategy, optional) – sharded strategy to load metadata. Defaults to None - in this case a default load strategy for a given checkpoint type is used.
- Returns
flat state dict without data, describing the ShardedTensors present in the checkpoint
- Return type
CkptShardedMetadata
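A short sketch of metadata inspection: the returned ShardedTensors carry no data, only metadata such as the global shape and dtype, so this is cheap even for large checkpoints.

```python
from core.dist_checkpointing.serialization import load_tensors_metadata

metadata = load_tensors_metadata('/tmp/dist_ckpt')
for key, sh_ten in metadata.items():
    print(key, tuple(sh_ten.global_shape), sh_ten.dtype)
```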
- core.dist_checkpointing.serialization.remove_sharded_tensors(checkpoint_dir: str, key_prefix: str)
Determines the appropriate sharded strategy and delegates the removal to it.
- core.dist_checkpointing.serialization.save(sharded_state_dict: Dict[str, Any], checkpoint_dir: str, sharded_strategy: Optional[Union[core.dist_checkpointing.strategies.base.SaveShardedStrategy, Tuple[str, int]]] = None, common_strategy: Optional[Union[core.dist_checkpointing.strategies.base.SaveCommonStrategy, Tuple[str, int]]] = None, validate_access_integrity: bool = True, async_sharded_save: bool = False, preprocess_common_before_consistancy_check: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None) → Optional[core.dist_checkpointing.strategies.async_utils.AsyncRequest]
Saving entrypoint.
Extracts ShardedTensors from the given state dict. Rank 0 saves the "regular" part of the checkpoint to a common torch file. The ShardedTensors are saved according to a strategy specified by the configuration.
Steps:
1. Apply factories
2. Extract and discard LocalNonpersistentObject
3. Extract all ShardedBase objects
4. Save all other objects to common.pt
5. (optional) Extract and save ShardedObjects
6. Save all ShardedBase objects
7. Write a metadata.json file with backend and version metadata
Step (6) can be performed asynchronously (see async_sharded_save); in this case the actual save is embodied in the returned async request and can be scheduled by the external caller. For an async request, step (7) is added as one of the finalization functions, so that metadata.json is written only if the checkpoint is complete.
- Parameters
sharded_state_dict (ShardedStateDict) – state dict of the model populated with ShardedTensors. Used as a mapping to determine how local tensors should be saved as global tensors in the checkpoint.
checkpoint_dir (str) – directory to save the checkpoint to
sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional) – configures sharded tensors saving behavior and backend
common_strategy (SaveCommonStrategy, Tuple[str, int], optional) – configures common data saving behavior and backend
validate_access_integrity (bool, default = True) – checks if each tensor shard is accessed exactly once (as the main replica) by some process. It also makes sure the common state dict is consistent across all ranks.
async_sharded_save (bool, optional) – if True, an async save implementation is called for the sharded part of the state dict, with the AsyncRequest being returned to the caller. Note that it is the caller's responsibility to actually schedule the async save. Defaults to False.
preprocess_common_before_consistancy_check (Callable[[CommonStateDict], StateDict], optional) – a callable that preprocesses the common state dict before the consistency check (e.g. it can be used to remove keys that are expected to differ across ranks). The function must not modify the original state dict.
- Returns
if async_sharded_save is True, returns an async request that should be scheduled by the caller of this function; None otherwise
- Return type
AsyncRequest (optional)
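The asynchronous flow might look like the sketch below. The AsyncCallsQueue helper and its method names are assumptions about the async_utils module referenced in the signature; the documented contract is only that the returned AsyncRequest must be scheduled by the caller.

```python
from core.dist_checkpointing.serialization import save
from core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue  # assumed helper

async_queue = AsyncCallsQueue()
async_request = save(
    sharded_state_dict, '/tmp/dist_ckpt_async', async_sharded_save=True
)
async_queue.schedule_async_request(async_request)  # assumed scheduling API

# ... training continues; finalization triggers step (7) so that
# metadata.json is written only once the checkpoint is complete.
async_queue.maybe_finalize_async_calls(blocking=True)  # assumed finalization API
```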
Core library classes for representing sharding of tensors and objects.
The main expected usage is wrapping torch.Tensors in state dicts with the ShardedTensor class (mostly with the ShardedTensor.from_rank_offsets classmethod).
- class core.dist_checkpointing.mapping.LocalNonpersistentObject(obj)
Bases:
object
Object that should not be stored in a checkpoint, but restored locally.
Wrapping any object inside the state dict with LocalNonpersistentObject has the following effects:
- during saving, the object is not stored in the checkpoint
- during loading, a local version of the object is placed in the state dict
- unwrap()
Returns the original object.
- core.dist_checkpointing.mapping.LocalNonpersitentObject
alias of core.dist_checkpointing.mapping.LocalNonpersistentObject
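A quick sketch of the intended usage (the wrapped value here is illustrative):

```python
from core.dist_checkpointing.mapping import LocalNonpersistentObject

local_counter = 42  # illustrative, purely local value
state_dict = {
    # Skipped at save time; restored from the local value at load time.
    'iteration_counter': LocalNonpersistentObject(local_counter),
}
assert state_dict['iteration_counter'].unwrap() == 42
```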
- class core.dist_checkpointing.mapping.ShardedBase
Bases:
abc.ABC
Base class for ShardedTensor, ShardedObject and ShardedTensorFactory.
- data: object
- key: str
- replica_id: Union[int, Tuple[int, ...]]
- abstract validate_metadata_integrity()
Codifies the constraints on metadata attributes.
- abstract without_data() → core.dist_checkpointing.mapping.ShardedBase
Returns a new ShardedBase instance with data=None.
- class core.dist_checkpointing.mapping.ShardedObject(key: str, data: object, global_shape: Tuple[int, ...], global_offset: Tuple[int, ...], replica_id: Union[int, Tuple[int, ...]] = 0)
Bases:
core.dist_checkpointing.mapping.ShardedBase
Represents a mapping between a local object and a global object.
Global object is assumed to consist of many local objects distributed between different processes.
NOTE: Unlike ShardedTensor, it's impossible to change the sharding of a global object. Conceptually, a ShardedObject is a fully sharded ShardedTensor with atomic, arbitrarily typed elements.
- Parameters
key – unique identifier of a global tensor
data – local object data. Can be None only for consistency validation
global_shape – global object shape
global_offset – offset of a local object in a global object, specified in number of shards
replica_id – indicates local object replication wrt. local objects in different processes
- data: object
- classmethod empty_from_unique_key(unique_key, replica_id: Union[int, Tuple[int, ...]] = 0) → core.dist_checkpointing.mapping.ShardedObject
Instantiates a ShardedObject from a unique key.
- Parameters
unique_key – a string of the form <key>/shard_<global_offset>_<global_shape>
replica_id – indicates local object replication wrt. local objects in different processes
- Returns
a ShardedObject with data=None
- global_offset: Tuple[int, ...]
- global_shape: Tuple[int, ...]
- key: str
- replica_id: Union[int, Tuple[int, ...]] = 0
- property unique_key
Returns a unique key for this object.
- validate_metadata_integrity()
Codifies the constraints on metadata attributes.
- without_data()
Returns a new ShardedBase instance with data=None.
- class core.dist_checkpointing.mapping.ShardedTensor(key: str, data: Optional[torch.Tensor], dtype: torch.dtype, local_shape: Tuple[int, ...], global_shape: Tuple[int, ...], global_offset: Tuple[int, ...], axis_fragmentations: Optional[Tuple[int, ...]], replica_id: Union[int, Tuple[int, ...]] = 0, prepend_axis_num: int = 0, allow_shape_mismatch: bool = False, flattened_range: Optional[slice] = None)
Bases:
core.dist_checkpointing.mapping.ShardedBase
Represents a mapping between a local tensor and a global tensor.
Global tensor is assumed to consist of many local tensors distributed between different processes.
- Parameters
key – unique identifier of a global tensor
data – local tensor data. Can be None only for consistency validation
dtype – tensor dtype
local_shape – local tensor shape
global_shape – global tensor shape
global_offset – offset of a local tensor in a global tensor, specified in number of tensor elements
axis_fragmentations – global tensor fragmentation of each axis
replica_id – indicates given local tensor’s replication wrt. local tensors in different processes
prepend_axis_num – number of axes prepended to the local tensor to reflect global tensor shape. The behavior is similar to unsqueezing the local tensor.
allow_shape_mismatch – if True, during loading, the global shape of a stored tensor does not have to match the expected global shape. Useful for representing tensors with flexible shape, e.g. padded.
flattened_range – specifies a slice that should be applied to a flattened tensor with local_shape in order to get the tensor stored as data
- allow_shape_mismatch: bool = False
- axis_fragmentations: Optional[Tuple[int, ...]]
- data: Optional[torch.Tensor]
- dtype: torch.dtype
- flattened_range: Optional[slice] = None
- classmethod from_rank_offsets(key: str, data: torch.Tensor, *rank_offsets: Tuple[int, int, int], replica_id: Union[int, Tuple[int, ...]] = 0, prepend_axis_num: int = 0, flattened_range: None = None, **init_kwargs)
Allows constructing a ShardedTensor given offsets specified in terms of process ranks (see the sketch after the parameter list).
- Parameters
key (str) – unique key
data (torch.Tensor) – local tensor data
rank_offsets (Tuple[int, int, int]) – each tuple (axis, axis_rank_offset, axis_fragm) says that if the global tensor is divided into axis_fragm fragments along axis axis, then the local tensor data corresponds to the axis_rank_offset chunk.
replica_id (ReplicaId) – see ShardedTensor
prepend_axis_num (int) – see ShardedTensor
flattened_range (None) – must be None when using this constructor
init_kwargs – passed to ShardedTensor.__init__
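For example, a tensor-parallel layout might be expressed as in this sketch (tp_rank, tp_size and dp_rank are illustrative values; a real application would query its parallel-state utilities for them):

```python
import torch
from core.dist_checkpointing.mapping import ShardedTensor

tp_rank, tp_size, dp_rank = 1, 4, 0  # illustrative parallel-state values
local_tensor = torch.randn(16, 32)

# Axis 0 of the global (64, 32) tensor is split into tp_size fragments;
# this rank holds fragment tp_rank. Data-parallel replicas hold identical
# shards, so only dp_rank == 0 counts as the main replica and gets saved.
sh_ten = ShardedTensor.from_rank_offsets(
    'layer.weight', local_tensor,
    (0, tp_rank, tp_size),
    replica_id=dp_rank,
)
```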
- classmethod from_rank_offsets_flat(key: str, data: torch.Tensor, non_flat_local_shape: Tuple[int, ...], *args, flattened_range: Optional[slice] = None, **kwargs)
Allows constructing a flattened ShardedTensor given offsets specified in terms of process ranks.
- Parameters
key (str) – unique key
data (torch.Tensor) – this should be a flattened data tensor
non_flat_local_shape (Tuple[int, ...]) – expected local shape of a non-flat chunk
*args – passed unchanged to the from_rank_offsets constructor
flattened_range (slice) – see ShardedTensor. Defaults to None, but must be set to a non-None slice.
**kwargs – passed unchanged to the from_rank_offsets constructor
- Returns
constructed ShardedTensor instance
- Return type
ShardedTensor
- global_coordinates() → Tuple[numpy.ndarray, ...]
Returns a tuple of np.ndarrays representing the coordinates of the global tensor that this ShardedTensor corresponds to.
- global_offset: Tuple[int, ...]
- global_shape: Tuple[int, ...]
- global_slice() → Tuple[Union[int, slice], ...]
Returns a tuple of int and slice objects representing a slice of the global tensor that this ShardedTensor corresponds to.
- init_data(device: Union[str, torch.device], init_fn=torch.empty)
Initialize the tensor data of this ShardedTensor.
Only called if data attribute is None.
- Parameters
device (Union[str, torch.device]) – device to place the tensor on
init_fn (Callable, optional) – function to use to initialize the tensor. Defaults to torch.empty.
- key: str
- local_chunk_offset_in_global() → Tuple[int, ...]
Offset of a local chunk in a global array of chunks.
- Returns
the offset of the whole local chunk in a global array of chunks.
- Return type
Tuple[int, …]
- local_coordinates() → Tuple[numpy.ndarray, ...]
Returns a tuple of np.ndarrays representing the coordinates of the local tensor that this ShardedTensor corresponds to.
- local_shape: Tuple[int, ...]
- max_allowed_chunks() → Tuple[int, ...]
Returns the maximum allowed chunks for this ShardedTensor.
- narrow(dim: int, start: int, length: int) → List[core.dist_checkpointing.mapping.ShardedTensor]
This is an analogue of torch.narrow for ShardedTensors.
Narrowing assumes that we narrow a local tensor on each rank. This has consequences for local_shape, global_shape, global_offset, etc.
- Parameters
dim (int) – dimension to narrow. Doesn’t include prepended axes.
start (int) – start element
length (int) – length of the slice
- Returns
narrowed ShardedTensors. For non-flat tensors the list will always have 1 element. For flat ShardedTensors the number of elements varies depending on dim and on overlap, because flat tensors must be contiguous; in particular, the list can be empty.
- Return type
List[ShardedTensor]
- prepend_axis_num: int = 0
- replica_id: Union[int, Tuple[int, ...]] = 0
- validate_metadata_integrity() → None
Codifies the constraints on metadata attributes.
Meeting those constraints is guaranteed when instantiating a ShardedTensor with the from_rank_offsets or from_rank_offsets_flat constructors.
- Returns
None
- without_data()
Returns a new ShardedBase instance with data=None.
- class core.dist_checkpointing.mapping.ShardedTensorFactory(key: str, data: torch.Tensor, build_fn: Callable[[str, torch.Tensor, Union[int, Tuple[int, ...]], Optional[slice]], Dict[str, Any]], merge_fn: Callable[[Dict[str, Any]], torch.Tensor], replica_id: Union[int, Tuple[int, ...]] = 0, flattened_range: Optional[slice] = None)
Bases:
core.dist_checkpointing.mapping.ShardedBase
Allows applying transformations to tensors before and after serialization.
The essence of those transformations is that they can be applied to optimizer states in the same way they are applied to the model params. The ultimate state dict with sharded tensors must depend functionally on the build_fn arguments (key, data, replica_id, flattened_range), which will be provided by the optimizer.
The builder creates a sub-state-dict out of a tensor before saving, and the merger merges the corresponding state dict after loading (see the sketch after this class's members).
- Parameters
key (str) – unique identifier of the factory
data (torch.Tensor) – original model parameter that will be further transformed by this factory
build_fn (callable) – function that transforms the original tensor to a sharded state dict
merge_fn (callable) – function that transforms loaded subtree back into a single tensor (inverse of build_fn)
replica_id (ReplicaId) – indicates factory replication wrt. factories in different processes
flattened_range (slice, optional) – indicates additional flattening applied to the ShardedTensors produced by the factory
- build()
Builds a ShardedStateDict from the original tensor.
- build_fn: Callable[[str, torch.Tensor, Union[int, Tuple[int, ...]], Optional[slice]], Dict[str, Any]]
- data: torch.Tensor
- flattened_range: Optional[slice] = None
- key: str
- merge_fn: Callable[[Dict[str, Any]], torch.Tensor]
- replica_id: Union[int, Tuple[int, ...]] = 0
- validate_metadata_integrity()
No reasonable checks can be applied.
- without_data()
Returns a new ShardedBase instance with data=None.
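A sketch of a factory that splits a fused parameter into two sub-shards before saving and concatenates them back after loading; all names below are illustrative. apply_factories (below) turns such factories into ShardedTensors in-place before saving.

```python
import torch
from core.dist_checkpointing.mapping import ShardedTensor, ShardedTensorFactory

def build_fn(key, tensor, replica_id, flattened_range):
    # Turn the original tensor into a sub-state-dict of ShardedTensors.
    assert flattened_range is None
    half = tensor.shape[0] // 2
    return {
        'a': ShardedTensor.from_rank_offsets(f'{key}.a', tensor[:half], replica_id=replica_id),
        'b': ShardedTensor.from_rank_offsets(f'{key}.b', tensor[half:], replica_id=replica_id),
    }

def merge_fn(sub_state_dict):
    # Inverse of build_fn: merge the loaded subtree back into one tensor.
    return torch.cat([sub_state_dict['a'], sub_state_dict['b']], dim=0)

fused_tensor = torch.randn(8, 4)  # illustrative fused parameter
factory = ShardedTensorFactory('fused.weight', fused_tensor, build_fn, merge_fn)
```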
- core.dist_checkpointing.mapping.apply_factories(sharded_state_dict: Dict[str, Any])
Turn ShardedTensorFactories into ShardedTensors in-place.
- Parameters
sharded_state_dict (ShardedStateDict) – state dict possibly containing ShardedTensorFactory objects
- Returns
state dict is modified in place
- Return type
None
- core.dist_checkpointing.mapping.apply_factory_merges(x1: Dict[str, Any], x2: Dict[str, Any], key: Tuple[str, ...] = ()) → Dict[str, Any]
Apply merges defined by ShardedTensorFactories in-place.
- Parameters
x1 (StateDict) – state dict loaded from the checkpoint
x2 (ShardedStateDict) – subset of x1 (in terms of dict keys) with ShardedTensorFactory as (possibly nested) values that define how to merge objects from the x1 state dict
key (Tuple[str, ...]) – current key in a recursive call. Used only for reporting meaningful errors
- Returns
x1 modified in-place
- Return type
StateDict
- core.dist_checkpointing.mapping.is_main_replica(replica_id: Union[int, Tuple[int, ...]])
Checks if the given replica_id is considered the main one.
A "main" replica is:
- the integer 0, or
- an iterable with all elements equal to 0
It is the application's responsibility to set correct replica ids for sharded tensors.
- Parameters
replica_id (Union[int, Tuple[int, ...]]) – replica id
- Returns
True for a “main” replica
- Return type
bool
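For instance:

```python
from core.dist_checkpointing.mapping import is_main_replica

assert is_main_replica(0)
assert is_main_replica((0, 0, 0))
assert not is_main_replica((0, 1, 0))
```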
Helpers for defining sharding for optimizer states based on existing sharding for model parameters.
- core.dist_checkpointing.optimizer.get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) → Dict[int, int]
Generate mapping from optimizer param to optimizer state id.
- core.dist_checkpointing.optimizer.get_param_id_to_sharded_param_map(model_sharded_state_dict: Dict[str, Any], optim_params_iter: Iterable[torch.nn.Parameter]) → Dict[int, Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedTensorFactory]]
Generate mapping from optimizer state ids to model sharded parameters.
- Parameters
model_sharded_state_dict – sharded state dict with all model sharded tensors (can have any structure)
optim_params_iter – iterable which iterates over model parameters tracked by the optimizer. The iteration must be in the same order as in the optimizer parameters.
- Returns
mapping from optimizer state ids to model sharded parameters
- Return type
Dict[int, Union[ShardedTensor, ShardedTensorFactory]]
- core.dist_checkpointing.optimizer.make_sharded_optimizer_tensor(model_param: Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedTensorFactory], optim_param: torch.Tensor, prefix: str) → Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedTensorFactory]
Build a ShardedTensor or ShardedTensorFactory for an optimizer param based on the model param.
- Parameters
model_param (Union[ShardedTensor, ShardedTensorFactory]) – model param
optim_param (torch.Tensor) – corresponding optimizer param
prefix (str) – optimizer prefix for the ShardedTensor or ShardedTensorFactory
- Returns
wrapped optimizer parameter
- Return type
Union[ShardedTensor, ShardedTensorFactory]
- core.dist_checkpointing.optimizer.optim_state_to_sharding_state(optim_state_dict: Dict[str, Any], id_to_sharded_param_map: Dict[int, core.dist_checkpointing.mapping.ShardedTensor], exclude_keys: Tuple[str] = ())
Turn an optimizer state dict into a sharded state dict, based on the model state dict, in-place.
Can be used to add sharding information to the most common optimizer state dict format. Creates separate ShardedTensors for each key in optim_state_dict['state'] (e.g. for torch.optim.Adam there will be separate tensors for exp_avg and exp_avg_sq).
- Parameters
optim_state_dict (StateDict) – optimizer state dict with state parameters under the state key and group hyperparameters under the param_groups -> params key.
id_to_sharded_param_map (Dict[int, ShardedTensor]) – mapping from optimizer param ids to model sharded tensors. Can be generated with get_param_id_to_sharded_param_map function.
exclude_keys (Tuple[str]) – optimizer state keys to exclude from the final state dict.
- Returns
state dict is modified in place
- Return type
None
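The helpers above are typically combined as in this sketch. The model and optimizer variables are illustrative, and model.sharded_state_dict() is an assumed application-side method producing a sharded state dict:

```python
from core.dist_checkpointing.optimizer import (
    get_param_id_to_sharded_param_map,
    optim_state_to_sharding_state,
)

model_sharded_sd = model.sharded_state_dict()  # assumed application API
optim_state_dict = optimizer.state_dict()

# The iterable must yield params in the same order the optimizer tracks them.
id_to_sharded_param_map = get_param_id_to_sharded_param_map(
    model_sharded_sd,
    (p for group in optimizer.param_groups for p in group['params']),
)

# Adds sharding information to optim_state_dict['state'] in-place.
optim_state_to_sharding_state(optim_state_dict, id_to_sharded_param_map)
```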
Module for managing distributed checkpoint metadata.
- class core.dist_checkpointing.core.CheckpointingConfig(sharded_backend: str, sharded_backend_version: int = 1, common_backend: str = 'torch', common_backend_version: int = 1)
Bases:
object
Documents backends used in the checkpoint.
Checkpoint config keeps track of formats used for storing the sharded tensors (sharded_backend) and other objects (common_backend).
Note that versioning is not for the checkpoint content (which is application specific), but for the checkpoint format itself.
- common_backend: str = 'torch'
- common_backend_version: int = 1
- sharded_backend: str
- sharded_backend_version: int = 1
- exception core.dist_checkpointing.core.CheckpointingException
Bases:
Exception
Base checkpointing-related exception.
- core.dist_checkpointing.core.check_is_distributed_checkpoint(checkpoint_dir)
Checks if metadata.json exists in the checkpoint and is a valid config.
- Parameters
checkpoint_dir – checkpoint directory
- Returns
True if metadata.json exists in the checkpoint and is a valid config.
- Return type
bool
- core.dist_checkpointing.core.maybe_load_config(checkpoint_dir: str) → Optional[core.dist_checkpointing.core.CheckpointingConfig]
Returns the checkpoint config if checkpoint_dir is a distributed checkpoint, and None otherwise.
- Parameters
checkpoint_dir – checkpoint directory
- Returns
None if checkpoint is not a valid distributed checkpoint
- Return type
CheckpointingConfig (optional)
- core.dist_checkpointing.core.save_config(config: core.dist_checkpointing.core.CheckpointingConfig, checkpoint_dir: str)
Save given config to checkpoint directory.
- Parameters
config – checkpoint config
checkpoint_dir – checkpoint directory
- Returns
None
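A sketch of config inspection with the functions above:

```python
from core.dist_checkpointing.core import (
    check_is_distributed_checkpoint,
    maybe_load_config,
)

if check_is_distributed_checkpoint('/tmp/dist_ckpt'):
    config = maybe_load_config('/tmp/dist_ckpt')
    print(config.sharded_backend, config.sharded_backend_version)
```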
Utilities for operating with dicts and lists.
All functions in this module handle nesting of dicts and lists. Other objects (e.g. tuples) are treated as atomic leaf types that cannot be traversed.
- core.dist_checkpointing.dict_utils.dict_list_map_inplace(f: Callable[[core.dist_checkpointing.dict_utils.U], core.dist_checkpointing.dict_utils.V], x: Union[Dict, List, core.dist_checkpointing.dict_utils.U])
Maps dicts and lists in-place with a given function.
- core.dist_checkpointing.dict_utils.dict_list_map_outplace(f: Callable[[core.dist_checkpointing.dict_utils.U], core.dist_checkpointing.dict_utils.V], x: Union[Dict, List, core.dist_checkpointing.dict_utils.U]) → Union[Dict, List, core.dist_checkpointing.dict_utils.V]
Maps dicts and lists out-of-place with a given function.
- core.dist_checkpointing.dict_utils.dict_map(f: Callable, d: dict)
A map equivalent for dicts.
- core.dist_checkpointing.dict_utils.dict_map_with_key(f: Callable, d: dict)
A map equivalent for dicts, with a function that accepts (key, value) tuples.
- core.dist_checkpointing.dict_utils.diff(x1: Any, x2: Any, prefix: Tuple = ()) → Tuple[list, list, list]
Recursive diff of dicts.
- Parameters
x1 (object) – left dict
x2 (object) – right dict
prefix (tuple) – tracks recursive calls. Used for reporting differing keys.
- Returns
- tuple of:
only_left: prefixes present only in the left dict
only_right: prefixes present only in the right dict
mismatch: values present in both dicts but not equal across them. For tensors, equality of all elements is checked. Each element is a tuple (prefix, type of left value, type of right value).
- Return type
Tuple[list, list, list]
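For example (a sketch; prefixes are tuples of keys):

```python
from core.dist_checkpointing.dict_utils import diff

only_left, only_right, mismatch = diff(
    {'a': 1, 'b': {'c': 2}},
    {'a': 1, 'b': {'c': 3, 'd': 4}},
)
# only_left is empty, only_right reports the ('b', 'd') prefix, and
# mismatch reports ('b', 'c') together with the two value types.
```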
- core.dist_checkpointing.dict_utils.extract_matching_values(x: Union[dict, list], predicate: Callable[[Any], bool], return_lists_as_dicts: bool = False) → Tuple[Union[dict, list], Union[dict, list]]
Return matching and nonmatching values. Keeps hierarchy.
- Parameters
x (Union[dict, list]) – state dict to process. Top-level argument must be a dict or list
predicate (object -> bool) – determines matching values
return_lists_as_dicts (bool) – if True, matching lists will be turned into dicts, with keys indicating the indices of original elements. Useful for reconstructing the original hierarchy.
- core.dist_checkpointing.dict_utils.inspect_types(x: Any, prefix: Tuple = (), indent: int = 4)
Helper to print types of (nested) dict values.
- core.dist_checkpointing.dict_utils.map_reduce(xs: typing.Iterable, key_fn: typing.Callable = <function <lambda>>, value_fn: typing.Callable = <function <lambda>>, reduce_fn: typing.Callable = <function <lambda>>) → dict
Simple map-reduce implementation following more_itertools.map_reduce interface.
- core.dist_checkpointing.dict_utils.merge(x1: Union[dict, list], x2: Union[dict, list], key: Tuple[Union[str, int], ...] = ())
Merges dicts and lists recursively.
- core.dist_checkpointing.dict_utils.nested_items_iter(x: Union[dict, list])
Returns iterator over (nested) tuples (container, key, value) of a given dict or list.
- core.dist_checkpointing.dict_utils.nested_values(x: Union[dict, list])
Returns iterator over (nested) values of a given dict or list.
Helpers for manipulating sharded tensors and sharded state dicts.
- core.dist_checkpointing.utils.add_prefix_for_sharding(sharded_state_dict: Dict[str, Any], prefix: str)
Prepend a given prefix to all ShardedBase objects in a given state dict in-place.
- Parameters
sharded_state_dict (ShardedStateDict) – sharded state dict
prefix (str) – prefix to be prepended
- Returns
state dict is modified in-place
- Return type
None
- core.dist_checkpointing.utils.apply_prefix_mapping(sharded_state_dict: Dict[str, Any], prefix_map: Dict[str, str])
Replaces prefixes only in keys matching one of the prefixes in the map.
- Parameters
sharded_state_dict (ShardedStateDict) – sharded state dict to replace keys in
prefix_map (Dict[str, str]) – map of old->new prefixes. The first matching prefix for each key is used
- Returns
state dict is modified in place
- Return type
None
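For instance, remapping module prefixes when the saving and loading module hierarchies differ (the keys shown are illustrative):

```python
from core.dist_checkpointing.utils import apply_prefix_mapping

# Renames e.g. 'decoder.layers.0.weight' to 'transformer.layers.0.weight'
# in-place; keys that match no prefix are left untouched.
apply_prefix_mapping(sharded_state_dict, {'decoder.': 'transformer.'})
```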
- core.dist_checkpointing.utils.extract_nonpersistent(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]
Extract a dict consisting of only LocalNonpersistentObjects from a given state dict.
- Parameters
sharded_state_dict – state dict possibly containing LocalNonpersistentObjects
- Returns
- tuple of:
state dict with all LocalNonpersistentObjects (keeping the original state dict structure)
state dict with all other objects (keeping the original state dict structure)
- Return type
Tuple[ShardedStateDict, StateDict]
- core.dist_checkpointing.utils.extract_sharded_base(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]
Extract a dict consisting of only ShardedBase objects from a given state dict with any objects.
- Parameters
sharded_state_dict – state dict possibly containing ShardedBase objects
- Returns
- tuple of:
state dict with all ShardedBase objects (keeping the original state dict structure)
state dict with all other objects (keeping the original state dict structure)
- Return type
Tuple[ShardedStateDict, StateDict]
- core.dist_checkpointing.utils.extract_sharded_tensors(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]
Extract a dict consisting of only ShardedTensor objects from a given state dict with any objects.
- Parameters
sharded_state_dict – state dict possibly containing ShardedTensor objects
- Returns
- tuple of:
state dict with all ShardedTensor (keeping the original state dict structure)
state dict with all objects other than ShardedTensor (keeping the original state dict structure)
- Return type
Tuple[ShardedStateDict, StateDict]
- core.dist_checkpointing.utils.extract_sharded_tensors_and_factories(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]
Extract a dict consisting of only ShardedTensor and ShardedTensorFactory objects from a given state dict with any objects.
- Parameters
sharded_state_dict – state dict possibly containing ShardedTensor and ShardedTensorFactory objects
- Returns
- tuple of:
state dict with all ShardedTensor and ShardedTensorFactory (keeping the original state dict structure)
state dict with all other objects (keeping the original state dict structure)
- Return type
Tuple[ShardedStateDict, StateDict]
- core.dist_checkpointing.utils.extract_sharded_tensors_or_nonpersistent(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]
Extract a dict consisting of only ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject objects from a given state dict with any objects.
- Parameters
sharded_state_dict – state dict possibly containing ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject objects
- Returns
- tuple of:
state dict with all ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject (keeping the original state dict structure)
state dict with all other objects (keeping the original state dict structure)
- Return type
Tuple[ShardedStateDict, StateDict]
- core.dist_checkpointing.utils.replace_prefix_for_sharding(sharded_state_dict: Dict[str, Any], old_prefix: str, new_prefix: str)
Replaces the given prefix in all sharded keys in a given state dict.
Errors out if some key does not begin with the given prefix.
- Parameters
sharded_state_dict (ShardedStateDict) – sharded state dict to replace keys in
old_prefix (str) – prefix to be replaced in each key
new_prefix (str) – new prefix
- Returns
state dict is modified in place
- Return type
None