A library for saving and loading distributed checkpoints. A “distributed checkpoint” can have various underlying formats (the current default format is based on Zarr) but has a distinctive property: a checkpoint saved in one parallel configuration (tensor/pipeline/data parallelism) can be loaded in a different parallel configuration.
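The resharding property can be illustrated with a toy, framework-free sketch (plain Python lists, not this library's API). Under the simplifying assumption of an even 1-D split, two ranks save their halves of an 8-element tensor keyed by global offset, and four ranks later read back their quarters. All helper names here are hypothetical.

```python
from typing import Dict, List, Tuple

# Toy "checkpoint": global offset -> stored 1-D shard
Checkpoint = Dict[int, List[float]]


def toy_shard(global_data: List[float], world_size: int, rank: int) -> Tuple[int, List[float]]:
    """Return (global_offset, local_data) for an even 1-D split."""
    assert len(global_data) % world_size == 0, "toy example assumes an even split"
    local_len = len(global_data) // world_size
    offset = rank * local_len
    return offset, global_data[offset:offset + local_len]


def toy_save(shards: List[Tuple[int, List[float]]]) -> Checkpoint:
    # Each shard is stored under its position in the global tensor, not under a rank id.
    return {offset: data for offset, data in shards}


def toy_load(ckpt: Checkpoint, global_len: int, world_size: int, rank: int) -> List[float]:
    # Reassemble the global tensor, then cut the slice needed by `rank`
    # under the new parallel configuration.
    full = [0.0] * global_len
    for offset, data in ckpt.items():
        full[offset:offset + len(data)] = data
    local_len = global_len // world_size
    return full[rank * local_len:(rank + 1) * local_len]


weights = [float(i) for i in range(8)]
ckpt = toy_save([toy_shard(weights, 2, r) for r in range(2)])  # saved by 2 ranks
reloaded = [toy_load(ckpt, 8, 4, r) for r in range(4)]        # loaded by 4 ranks
print(reloaded)  # [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]
```

Keying shards by their global position rather than by the saving rank is what decouples the saved layout from the loading layout; a real backend would presumably read only the overlapping regions instead of materializing the whole tensor as this sketch does.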

Using the library requires defining sharded state_dict dictionaries with functions from the mapping and optimizer modules. Those state dicts can be saved or loaded with the serialization module using strategies from the strategies module.

Entrypoints for saving and loading distributed checkpoints.

The load and save functions are equivalents of torch.load and torch.save, but they expect torch.Tensors to be wrapped with classes from the mapping module. Additionally, load expects the sharded state dict argument as guidance for loading the sharded tensors.

core.dist_checkpointing.serialization.load(sharded_state_dict: Dict[str, Any], checkpoint_dir: str, sharded_strategy: Optional[Union[core.dist_checkpointing.strategies.base.LoadShardedStrategy, Tuple[str, int]]] = None, common_strategy: Optional[Union[core.dist_checkpointing.strategies.base.LoadCommonStrategy, Tuple[str, int]]] = None, validate_access_integrity: bool = True) → Dict[str, Any]

Loading entrypoint.

In the steps below, the following verbs refer to the corresponding objects:
  • load = load from the checkpoint
  • extract = extract from sharded_state_dict
  • add = add to the final state dict

Steps:
  1. Load the common state dict and form the base of the result state dict
  2. Apply factories to sharded_state_dict
  3. Extract LocalNonpersitentObject and add
  4. (optional) Extract ShardedObjects, load and add
  5. Extract ShardedBase, load, apply factory merges and add

  • sharded_state_dict (ShardedStateDict) – state dict of the existing model populated with ShardedTensors. Used as a mapping to determine which parts of global tensors stored in the checkpoint should be loaded.

  • checkpoint_dir (str) – directory with the checkpoint

  • sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional) – configures loading behavior for sharded tensors

  • common_strategy (LoadCommonStrategy, Tuple[str, int], optional) – configures loading behavior for common data

  • validate_access_integrity (bool, default = True) – checks whether each tensor shard is accessed exactly once (as the main replica) by some process

core.dist_checkpointing.serialization.load_common_state_dict(checkpoint_dir: pathlib.Path) → Dict[str, Any]

Loads the common (non-sharded) objects state dict from the checkpoint.

Parameters

checkpoint_dir (Path) – checkpoint directory

Returns

state dict with non-sharded objects from the checkpoint

Return type

Dict[str, Any]

core.dist_checkpointing.serialization.load_plain_tensors(checkpoint_dir: str)

Load checkpoint tensors without any sharding.

NOTE: the common state dict is NOT included.

core.dist_checkpointing.serialization.load_sharded_objects(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

Replaces all ShardedObjects in a given state dict with values loaded from the checkpoint.

  • sharded_state_dict (ShardedStateDict) – sharded state dict defining which objects should be loaded

  • checkpoint_dir (Path) – checkpoint directory

Returns

state dict is modified in place

Return type

None

core.dist_checkpointing.serialization.load_tensors_metadata(checkpoint_dir: str, sharded_strategy: Optional[core.dist_checkpointing.strategies.base.LoadShardedStrategy] = None) → Dict[str, Any]

Load tensors metadata from the checkpoint.

Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (unlike actual sharded state dicts, where keys correspond to state dict keys).

Dict values are ShardedTensors without any sharding (so the only useful information is the tensors' global shape and dtype).

The concrete implementation depends on the loading strategy. If no strategy is given, the default for the given backend is used.

core.dist_checkpointing.serialization.save(sharded_state_dict: Dict[str, Any], checkpoint_dir: str, sharded_strategy: Optional[Union[core.dist_checkpointing.strategies.base.SaveShardedStrategy, Tuple[str, int]]] = None, common_strategy: Optional[Union[core.dist_checkpointing.strategies.base.SaveCommonStrategy, Tuple[str, int]]] = None, validate_access_integrity: bool = True) → None

Saving entrypoint.

Extracts ShardedTensors from the given state dict. Rank 0 saves the “regular” part of the checkpoint to a common torch file. The ShardedTensors are saved according to a strategy specified by the config.

Steps:
  1. Apply factories
  2. Extract and discard LocalNonpersitentObject
  3. Extract all ShardedBase objects
  4. Save all other objects to common.pt
  5. (optional) Extract and save ShardedObjects
  6. Save all ShardedBase objects

  • sharded_state_dict (ShardedStateDict) – state dict of the model populated with ShardedTensors. Used as a mapping to determine how local tensors should be saved as global tensors in the checkpoint.

  • checkpoint_dir (str) – directory to save the checkpoint to

  • sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional) – configures sharded tensors saving behavior and backend

  • common_strategy (SaveCommonStrategy, Tuple[str, int], optional) – configures common data saving behavior and backend

  • validate_access_integrity (bool, default = True) – checks whether each tensor shard is accessed exactly once (as the main replica) by some process

core.dist_checkpointing.serialization.validate_sharding_integrity(sharded_tensors: Iterable[core.dist_checkpointing.mapping.ShardedTensor])

Validates whether the ShardedTensors from multiple processes define a correct sharding of a global tensor.

Local ShardedTensor metadata is exchanged with torch.distributed.all_gather_object, and then the process with global rank 0 checks whether the main replicas of the shards:
  • cover the whole global tensors
  • don’t overlap

Parameters

sharded_tensors (Iterable[ShardedTensor]) – sharded tensors local to this process

Raises

CheckpointingException for invalid access pattern
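The two conditions checked on rank 0 can be sketched in one dimension, without torch.distributed. This hypothetical helper assumes the metadata has already been gathered into one list of (global_offset, local_length, replica_id) tuples, treats only replica_id == 0 as a main replica, and raises a plain ValueError where the library raises CheckpointingException.

```python
from typing import List, Tuple

# (global_offset, local_length, replica_id) of one 1-D shard, as if gathered from all ranks
ShardMeta = Tuple[int, int, int]


def check_1d_sharding(shards: List[ShardMeta], global_length: int) -> None:
    """Raise ValueError unless main replicas cover [0, global_length) exactly once."""
    coverage = [0] * global_length
    for offset, length, replica_id in shards:
        if replica_id != 0:  # non-main replicas are ignored
            continue
        if offset < 0 or offset + length > global_length:
            raise ValueError("shard out of bounds")
        for i in range(offset, offset + length):
            coverage[i] += 1
    if any(c == 0 for c in coverage):
        raise ValueError("main replicas do not cover the whole global tensor")
    if any(c > 1 for c in coverage):
        raise ValueError("main replicas overlap")


# Two ranks split 8 elements; a third rank holds a non-main replica of the first half.
check_1d_sharding([(0, 4, 0), (4, 4, 0), (0, 4, 1)], 8)  # passes silently
```

Counting hits per element keeps the sketch obvious but costs time linear in the global size; sorting the main-replica intervals by offset and comparing neighbours would check the same two conditions more cheaply.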

Core library classes for representing sharding of tensors and objects.

The main expected usage is wrapping torch.Tensors in state dicts with the ShardedTensor class (mostly with the ShardedTensor.from_rank_offsets classmethod).

class core.dist_checkpointing.mapping.LocalNonpersitentObject(obj)

Bases: object

Object that should not be stored in a checkpoint, but restored locally.

Wrapping any object inside the state dict with LocalNonpersitentObject will result in:
  • during saving, this object will not be stored in the checkpoint
  • during loading, a local version of this object will be placed in a state dict
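The save/load behavior described above can be sketched with a hypothetical wrapper class on a flat dict. The class and function names are illustrative only, and unlike the library this sketch does not traverse nested state dicts.

```python
from typing import Any, Dict


class LocalOnly:
    """Hypothetical stand-in for a local, non-persistent wrapper."""

    def __init__(self, obj: Any) -> None:
        self.obj = obj

    def unwrap(self) -> Any:
        return self.obj


def drop_local_only(state: Dict[str, Any]) -> Dict[str, Any]:
    # Saving: wrapped objects are simply not written to the checkpoint.
    return {k: v for k, v in state.items() if not isinstance(v, LocalOnly)}


def restore_local_only(stored: Dict[str, Any], template: Dict[str, Any]) -> Dict[str, Any]:
    # Loading: the local version held by the caller's state dict is placed in the result.
    result = dict(stored)
    for k, v in template.items():
        if isinstance(v, LocalOnly):
            result[k] = v.unwrap()
    return result


state = {"step": 100, "param_ids": LocalOnly([0, 1, 2])}
stored = drop_local_only(state)              # {'step': 100}
restored = restore_local_only(stored, state)
print(restored)                              # {'step': 100, 'param_ids': [0, 1, 2]}
```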


class core.dist_checkpointing.mapping.ShardedBase

Bases: abc.ABC

data: object

key: str

replica_id: Union[int, Tuple[int, ...]]

class core.dist_checkpointing.mapping.ShardedObject(key: str, data: object, global_shape: Tuple[int, ...], global_offset: Tuple[int, ...], replica_id: Union[int, Tuple[int, ...]] = 0)

Bases: core.dist_checkpointing.mapping.ShardedBase

Represents a mapping between a local object and a global object.

Global object is assumed to consist of many local objects distributed between different processes.

NOTE: Unlike ShardedTensor, the sharding of a global object cannot be changed. Conceptually, ShardedObject is a fully-sharded ShardedTensor with atomic, arbitrarily typed elements.

  • key – unique identifier of a global object

  • data – local object data. Can be None only for consistency validation

  • global_shape – global object shape

  • global_offset – offset of a local object in a global object, specified in number of shards

  • replica_id – indicates local object replication wrt. local objects in different processes

data: object

global_offset: Tuple[int, ...]

global_shape: Tuple[int, ...]

key: str

replica_id: Union[int, Tuple[int, ...]] = 0

property unique_key


class core.dist_checkpointing.mapping.ShardedTensor(key: str, data: Optional[torch.Tensor], dtype: torch.dtype, local_shape: Tuple[int, ...], global_shape: Tuple[int, ...], global_offset: Tuple[int, ...], axis_fragmentations: Optional[Tuple[int, ...]], replica_id: Union[int, Tuple[int, ...]] = 0, prepend_axis_num: int = 0, allow_shape_mismatch: bool = False, flattened_range: Optional[slice] = None)

Bases: core.dist_checkpointing.mapping.ShardedBase

Represents a mapping between a local tensor and a global tensor.

Global tensor is assumed to consist of many local tensors distributed between different processes.

  • key – unique identifier of a global tensor

  • data – local tensor data. Can be None only for consistency validation

  • dtype – tensor dtype

  • local_shape – local tensor shape

  • global_shape – global tensor shape

  • global_offset – offset of a local tensor in a global tensor, specified in number of tensor elements

  • axis_fragmentations – global tensor fragmentation of each axis

  • replica_id – indicates given local tensor’s replication wrt. local tensors in different processes

  • prepend_axis_num – number of axes prepended to the local tensor to reflect global tensor shape. The behavior is similar to unsqueezing the local tensor.

  • allow_shape_mismatch – if True, during loading, the global shape of a stored tensor does not have to match the expected global shape. Useful for representing tensors with flexible shape, e.g. padded.

  • flattened_range – specifies a slice that should be applied to a flattened tensor with local_shape in order to get the tensor stored as data
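A minimal sketch of the flattened_range semantics, using nested lists instead of tensors and assuming row-major (C-order) flattening as done by torch.Tensor.flatten. With local_shape == (2, 3) and flattened_range == slice(2, 5), the stored data is a contiguous run of three elements that does not form a rectangular block of the local tensor. The helper names are hypothetical.

```python
from itertools import product
from math import prod
from typing import List, Tuple


def flatten_2d(nested: List[List[int]]) -> List[int]:
    # Row-major (C-order) flattening of a 2-D nested list
    return [x for row in nested for x in row]


def covered_coordinates(local_shape: Tuple[int, ...], flattened_range: slice) -> List[Tuple[int, ...]]:
    """Local multi-dimensional coordinates selected by flattened_range."""
    all_coords = list(product(*(range(n) for n in local_shape)))  # row-major order
    assert len(all_coords) == prod(local_shape)
    return all_coords[flattened_range]


local = [[10, 11, 12], [13, 14, 15]]       # a local tensor with local_shape == (2, 3)
flattened_range = slice(2, 5)
data = flatten_2d(local)[flattened_range]   # what would be stored as `data`
print(data)                                 # [12, 13, 14]
print(covered_coordinates((2, 3), flattened_range))  # [(0, 2), (1, 0), (1, 1)]
```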

allow_shape_mismatch: bool = False

axis_fragmentations: Optional[Tuple[int, ...]]

data: Optional[torch.Tensor]

dtype: torch.dtype

flattened_range: Optional[slice] = None

classmethod from_rank_offsets(key: str, data: torch.Tensor, *rank_offsets: Tuple[int, int, int], replica_id: Union[int, Tuple[int, ...]] = 0, prepend_axis_num: int = 0, allow_shape_mismatch: bool = False)

Constructs a ShardedTensor given offsets specified in process ranks (rather than in tensor elements).

  • key – unique key

  • data – local tensor data

  • rank_offsets – each tuple (axis, axis_rank_offset, axis_fragm) specifies that if the global tensor is divided into axis_fragm fragments along axis axis, then the local tensor data corresponds to the chunk with index axis_rank_offset.

  • replica_id – see ShardedTensor

  • prepend_axis_num – see ShardedTensor

  • allow_shape_mismatch – see ShardedTensor
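The relationship between rank_offsets and the resulting global metadata can be sketched without the library. Assuming an even split, each (axis, axis_rank_offset, axis_fragm) tuple multiplies the local size along axis by axis_fragm to get the global size, and by axis_rank_offset to get the global offset; prepended axes have local size 1. The helper below is hypothetical and only mirrors the semantics described above.

```python
from typing import Tuple


def rank_offsets_to_global(
    local_shape: Tuple[int, ...],
    *rank_offsets: Tuple[int, int, int],
    prepend_axis_num: int = 0,
) -> Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]:
    """Return (global_shape, global_offset, axis_fragmentations)."""
    ndim = len(local_shape) + prepend_axis_num
    global_shape = [1] * prepend_axis_num + list(local_shape)
    global_offset = [0] * ndim
    axis_fragmentations = [1] * ndim
    for axis, axis_rank_offset, axis_fragm in rank_offsets:
        assert 0 <= axis_rank_offset < axis_fragm, "rank offset must be lower than fragmentation"
        # Prepended axes behave like unsqueezed dimensions of size 1
        local_axis_size = 1 if axis < prepend_axis_num else local_shape[axis - prepend_axis_num]
        global_shape[axis] = axis_fragm * local_axis_size
        global_offset[axis] = axis_rank_offset * local_axis_size
        axis_fragmentations[axis] = axis_fragm
    return tuple(global_shape), tuple(global_offset), tuple(axis_fragmentations)


# Rank 1 of 4 holds a (256, 1024) chunk of a weight split along axis 0.
print(rank_offsets_to_global((256, 1024), (0, 1, 4)))
# ((1024, 1024), (256, 0), (4, 1))
```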

global_coordinates() → Tuple[numpy.ndarray, ...]

global_offset: Tuple[int, ...]

global_shape: Tuple[int, ...]

global_slice() → Tuple[Union[int, slice], ...]

init_data(device: torch.device, init_fn=torch.empty)

key: str

local_coordinates() → Tuple[numpy.ndarray, ...]

local_shape: Tuple[int, ...]

max_allowed_chunks() → Tuple[int, ...]

prepend_axis_num: int = 0

replica_id: Union[int, Tuple[int, ...]] = 0


class core.dist_checkpointing.mapping.ShardedTensorFactory(key: str, data: torch.Tensor, build_fn: Callable[[str, torch.Tensor, Union[int, Tuple[int, ...]]], Dict[str, Any]], merge_fn: Callable[[Dict[str, Any]], torch.Tensor], replica_id: Union[int, Tuple[int, ...]] = 0)

Bases: core.dist_checkpointing.mapping.ShardedBase

Allows applying transformations to tensors before/after serialization.

The essence of these transformations is that they can be applied to optimizer states the same way they are applied to the model params.

The builder (build_fn) creates a sub-state-dict out of a tensor before saving, and the merger (merge_fn) merges the corresponding state dict after loading.

  • key (str) – unique identifier of the factory

  • data (torch.Tensor) – original model parameter that will be further transformed by this factory

  • build_fn (callable) – function that transforms the original tensor to a sharded state dict

  • merge_fn (callable) – function that transforms loaded subtree back into a single tensor (inverse of build_fn)

  • replica_id (ReplicaId) – indicates factory replication wrt. factories in different processes
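A hypothetical build_fn/merge_fn pair following the Callable signatures above: build_fn(key, data, replica_id) returns a sub-state-dict, and merge_fn turns the loaded sub-state-dict back into one tensor. To stay dependency-free, lists stand in for tensors and (key, data, replica_id) tuples stand in for the ShardedTensors a real build_fn would return; the key "decoder.fc1.weight" is illustrative.

```python
from typing import Any, Dict, List, Tuple, Union

ReplicaId = Union[int, Tuple[int, ...]]


def build_fn(key: str, data: List[float], replica_id: ReplicaId) -> Dict[str, Any]:
    # Split a fused parameter into two halves stored under separate keys.
    half = len(data) // 2
    return {
        "first": (f"{key}.first", data[:half], replica_id),
        "second": (f"{key}.second", data[half:], replica_id),
    }


def merge_fn(sub_state_dict: Dict[str, List[float]]) -> List[float]:
    # Inverse of build_fn: concatenate the loaded halves back into one tensor.
    return sub_state_dict["first"] + sub_state_dict["second"]


fused = [1.0, 2.0, 3.0, 4.0]
built = build_fn("decoder.fc1.weight", fused, 0)
# After loading, each entry holds plain data instead of the sharded wrapper.
loaded = {name: entry[1] for name, entry in built.items()}
print(merge_fn(loaded) == fused)  # True
```

Because the transformation is expressed as a pair of functions rather than baked into the tensor, the same pair could be reused for optimizer states keyed under a different prefix, which is the point made in the description above.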


build_fn: Callable[[str, torch.Tensor, Union[int, Tuple[int, ...]]], Dict[str, Any]]

data: torch.Tensor

key: str

merge_fn: Callable[[Dict[str, Any]], torch.Tensor]

replica_id: Union[int, Tuple[int, ...]] = 0

core.dist_checkpointing.mapping.apply_factories(sharded_state_dict: Dict[str, Any])

Turns ShardedTensorFactories into ShardedTensors in place.

Parameters

sharded_state_dict (ShardedStateDict) – state dict possibly containing ShardedTensorFactory objects

Returns

state dict is modified in place

Return type

None

core.dist_checkpointing.mapping.apply_factory_merges(x1: Dict[str, Any], x2: Dict[str, Any], key: Tuple[str, ...] = ()) → Dict[str, Any]

Apply merges defined by ShardedTensorFactories in-place.

  • x1 (StateDict) – state dict loaded from the checkpoint

  • x2 (ShardedStateDict) – subset of x1 (in terms of dict keys) with ShardedTensorFactory as (possibly nested) values that define how to merge objects from the x1 state dict

  • key (Tuple[str, ...]) – current key in a recursive call. Used only for reporting meaningful errors

Returns

x1 modified in place

Return type

Dict[str, Any]

core.dist_checkpointing.mapping.is_main_replica(replica_id: Union[int, Tuple[int, ...]])

Checks whether the given replica_id is considered the main replica.

The “main” replica is:
  • the integer 0
  • or an iterable with all elements equal to 0

It is the application’s responsibility to set correct replicas for sharded tensors.

Parameters

replica_id (Union[int, Tuple[int, ...]]) – replica id

Returns

True for a “main” replica

Return type

bool
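The rule above fits in a few lines. This is a hypothetical re-implementation for illustration, not the library function; reading a tuple such as (0, 0, 1) as one index per replication dimension is an assumption, since the document leaves replica semantics to the application.

```python
from typing import Tuple, Union


def is_main_replica_sketch(replica_id: Union[int, Tuple[int, ...]]) -> bool:
    """True for 0 or for an iterable whose elements are all 0."""
    if isinstance(replica_id, int):
        return replica_id == 0
    return all(r == 0 for r in replica_id)


print(is_main_replica_sketch((0, 0, 0)), is_main_replica_sketch((0, 0, 1)))  # True False
```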

Helpers for defining sharding for optimizer states based on existing sharding for model parameters.

core.dist_checkpointing.optimizer.get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) → Dict[int, int]

core.dist_checkpointing.optimizer.get_param_id_to_sharded_param_map(model_sharded_state_dict: Dict[str, Any], optim_params_iter: Iterable[torch.nn.Parameter]) → Dict[int, Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedTensorFactory]]

Generates a mapping from optimizer state ids to model sharded parameters.

  • model_sharded_state_dict – sharded state dict with all model sharded tensors (can have any structure)

  • optim_params_iter – iterable which iterates over model parameters tracked by the optimizer. The iteration must be in the same order as in the optimizer parameters.

Returns

mapping from optimizer state ids to model sharded parameters.

Return type

Dict[int, Union[ShardedTensor, ShardedTensorFactory]]
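The id mapping can be sketched with object identity, assuming (as for torch.optim optimizers) that the integer keys of the optimizer state are positions in the optimizer's parameter order. Param and ShardedParam are hypothetical stand-ins for torch.nn.Parameter and ShardedTensor, and the model sharded state dict is flattened to a list here for brevity.

```python
from typing import Dict, Iterable, List


class Param:
    """Stand-in for torch.nn.Parameter (hypothetical)."""


class ShardedParam:
    """Stand-in for a ShardedTensor wrapping a model parameter (hypothetical)."""

    def __init__(self, key: str, data: Param) -> None:
        self.key = key
        self.data = data


def optim_param_to_id_map(optim_params_iter: Iterable[Param]) -> Dict[int, int]:
    # Map Python object identity -> position in the optimizer's parameter order.
    mapping: Dict[int, int] = {}
    for i, param in enumerate(optim_params_iter):
        mapping.setdefault(id(param), i)  # keep the first position of repeated params
    return mapping


def param_id_to_sharded_param_map(
    sharded_params: List[ShardedParam], optim_params_iter: Iterable[Param]
) -> Dict[int, ShardedParam]:
    param_to_id = optim_param_to_id_map(optim_params_iter)
    return {
        param_to_id[id(sh.data)]: sh
        for sh in sharded_params
        if id(sh.data) in param_to_id  # params not tracked by the optimizer are skipped
    }


w, b = Param(), Param()
model_sd = [ShardedParam("layer.weight", w), ShardedParam("layer.bias", b)]
id_map = param_id_to_sharded_param_map(model_sd, [b, w])  # optimizer order: bias, weight
print({i: sh.key for i, sh in id_map.items()})            # {1: 'layer.weight', 0: 'layer.bias'}
```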

core.dist_checkpointing.optimizer.make_sharded_optimizer_tensor(model_param: Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedTensorFactory], optim_param: torch.Tensor, prefix: str) → Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedTensorFactory]

Builds a ShardedTensor or ShardedTensorFactory for an optimizer param based on the corresponding model param.

  • model_param (Union[ShardedTensor, ShardedTensorFactory]) – model param

  • optim_param (torch.Tensor) – corresponding optimizer param

  • prefix (str) – optimizer prefix for the ShardedTensor or ShardedTensorFactory

Returns

wrapped optimizer parameter

Return type

Union[ShardedTensor, ShardedTensorFactory]

core.dist_checkpointing.optimizer.optim_state_to_sharding_state(optim_state_dict: Dict[str, Any], id_to_sharded_param_map: Dict[int, core.dist_checkpointing.mapping.ShardedTensor], exclude_keys: Tuple[str] = ())

Turns an optimizer state dict into a sharded state dict based on the model state dict, in place.

Can be used to add sharding information to most common optimizer state dicts. Creates separate ShardedTensors for each key in optim_state_dict['state'] (e.g. for torch.optim.Adam there will be separate tensors for exp_avg and exp_avg_sq).

  • optim_state_dict (StateDict) – optimizer state dict with state parameters under the state key and group hyperparameters under the param_groups -> params key.

  • id_to_sharded_param_map (Dict[int, ShardedTensor]) – mapping from optimizer param ids to model sharded tensors. Can be generated with the get_param_id_to_sharded_param_map function.

  • exclude_keys (Tuple[str]) – optimizer state keys to exclude from the final state dict.

Returns

state dict is modified in place

Return type

None
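A hypothetical sketch of the in-place conversion, with a minimal dataclass standing in for ShardedTensor. Each optimizer state entry reuses the sharding metadata of the matching model param (copied with dataclasses.replace) and receives a new key and the optimizer data. The "optimizer.state.<name>." key prefix and the choice to exclude "step" are assumptions made for illustration; this sketch also leaves param_groups untouched.

```python
from dataclasses import dataclass, replace
from typing import Any, Dict, Tuple


@dataclass
class ToySharded:
    """Hypothetical stand-in for ShardedTensor: key, data and sharding metadata."""
    key: str
    data: Any
    global_offset: Tuple[int, ...]


def to_sharded_optim_state(
    optim_state_dict: Dict[str, Any],
    id_to_sharded_param: Dict[int, ToySharded],
    exclude_keys: Tuple[str, ...] = (),
) -> None:
    """Replace optim_state_dict['state'] in place with sharded entries."""
    sharded_state: Dict[int, Dict[str, ToySharded]] = {}
    for param_id, param_state in optim_state_dict["state"].items():
        model_param = id_to_sharded_param[param_id]
        sharded_state[param_id] = {
            # Same sharding as the model param; new key (assumed prefix) and optimizer data.
            name: replace(model_param, key=f"optimizer.state.{name}.{model_param.key}", data=value)
            for name, value in param_state.items()
            if name not in exclude_keys
        }
    optim_state_dict["state"] = sharded_state


optim_sd = {
    "state": {0: {"exp_avg": [0.1, 0.2], "exp_avg_sq": [0.01, 0.04], "step": 10}},
    "param_groups": [{"lr": 1e-3, "params": [0]}],
}
id_map = {0: ToySharded("layer.weight", [1.0, 2.0], (2,))}
to_sharded_optim_state(optim_sd, id_map, exclude_keys=("step",))
print(sorted(optim_sd["state"][0]))         # ['exp_avg', 'exp_avg_sq']
print(optim_sd["state"][0]["exp_avg"].key)  # optimizer.state.exp_avg.layer.weight
```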

Module for managing distributed checkpoints metadata.

class core.dist_checkpointing.core.CheckpointingConfig(sharded_backend: str, sharded_backend_version: int = 1, common_backend: str = 'torch', common_backend_version: int = 1)

Bases: object

Documents backends used in the checkpoint.

Checkpoint config keeps track of formats used for storing the sharded tensors (sharded_backend) and other objects (common_backend).

Note that versioning is not for the checkpoint content (which is application specific), but for the checkpoint format itself.

common_backend: str = 'torch'

common_backend_version: int = 1

sharded_backend: str

sharded_backend_version: int = 1

exception core.dist_checkpointing.core.CheckpointingException

Bases: Exception

Base exception for checkpointing-related errors.


Checks whether metadata.json exists in the checkpoint and is a valid config.

Parameters

checkpoint_dir – checkpoint directory

Returns

True if metadata.json exists in the checkpoint and is a valid config.

Return type

bool

core.dist_checkpointing.core.maybe_load_config(checkpoint_dir: str) → Optional[core.dist_checkpointing.core.CheckpointingConfig]

Returns the checkpoint config if checkpoint_dir is a distributed checkpoint, and None otherwise.

Parameters

checkpoint_dir – checkpoint directory

Returns

None if checkpoint is not a valid distributed checkpoint

Return type

CheckpointingConfig (optional)

core.dist_checkpointing.core.save_config(config: core.dist_checkpointing.core.CheckpointingConfig, checkpoint_dir: str)

Save given config to checkpoint directory.

  • config – checkpoint config

  • checkpoint_dir – checkpoint directory
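A round trip of the config through metadata.json might look like the sketch below. Only the file name and the field names come from this page; the JSON encoding, the toy class and function names, and the "zarr" backend string are assumptions for illustration.

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional


@dataclass
class ToyCheckpointingConfig:
    # Same fields and defaults as documented for CheckpointingConfig
    sharded_backend: str
    sharded_backend_version: int = 1
    common_backend: str = "torch"
    common_backend_version: int = 1


def toy_save_config(config: ToyCheckpointingConfig, checkpoint_dir: str) -> None:
    (Path(checkpoint_dir) / "metadata.json").write_text(json.dumps(asdict(config)))


def toy_maybe_load_config(checkpoint_dir: str) -> Optional[ToyCheckpointingConfig]:
    path = Path(checkpoint_dir) / "metadata.json"
    if not path.exists():
        return None  # not a distributed checkpoint
    return ToyCheckpointingConfig(**json.loads(path.read_text()))


with tempfile.TemporaryDirectory() as ckpt_dir:
    before = toy_maybe_load_config(ckpt_dir)  # None: no metadata.json yet
    toy_save_config(ToyCheckpointingConfig(sharded_backend="zarr"), ckpt_dir)
    loaded = toy_maybe_load_config(ckpt_dir)
print(before, loaded)
```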



Utilities for operating with dicts and lists.

All functions in this module handle nesting of dicts and lists. Other objects (e.g. tuples) are treated as atomic leaf types that cannot be traversed.

core.dist_checkpointing.dict_utils.dict_list_map_inplace(f: Callable, x: Union[dict, list])

Maps dicts and lists in-place with a given function.

core.dist_checkpointing.dict_utils.dict_list_map_outplace(f: Callable, x: Union[dict, list])

Maps dicts and lists out-of-place with a given function.
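A minimal sketch of the traversal rule stated above: dicts and lists are traversed, and everything else, including tuples, is a leaf passed to the function. These are hypothetical re-implementations, not the library's functions.

```python
from typing import Any, Callable, Union


def map_outplace(f: Callable[[Any], Any], x: Any) -> Any:
    """Return a new structure with f applied to every leaf."""
    if isinstance(x, dict):
        return {k: map_outplace(f, v) for k, v in x.items()}
    if isinstance(x, list):
        return [map_outplace(f, v) for v in x]
    return f(x)  # tuples and all other objects are leaves


def map_inplace(f: Callable[[Any], Any], x: Union[dict, list]) -> Union[dict, list]:
    """Mutate dicts and lists so that every leaf v is replaced by f(v)."""
    items = x.items() if isinstance(x, dict) else enumerate(x)
    for k, v in list(items):  # materialize before mutating
        x[k] = map_inplace(f, v) if isinstance(v, (dict, list)) else f(v)
    return x


sd = {"a": [1, 2], "b": {"c": 3}, "t": (4, 5)}
doubled = map_outplace(lambda v: v * 2, sd)
print(doubled)  # {'a': [2, 4], 'b': {'c': 6}, 't': (4, 5, 4, 5)}
map_inplace(lambda v: v + 1, sd["a"])
print(sd["a"])  # [2, 3]
```

The tuple entry shows the leaf rule in action: (4, 5) * 2 concatenates the tuple instead of doubling its elements.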

core.dist_checkpointing.dict_utils.dict_map(f: Callable, d: dict)

Equivalent of map for dicts.

core.dist_checkpointing.dict_utils.dict_map_with_key(f: Callable, d: dict)

Equivalent of map for dicts, with a function that accepts a (key, value) tuple.

core.dist_checkpointing.dict_utils.diff(x1: Any, x2: Any, prefix: Tuple = ()) → Tuple[list, list, list]

Recursive diff of dicts.

  • x1 (object) – left dict

  • x2 (object) – right dict

  • prefix (tuple) – tracks recursive calls. Used for reporting differing keys.

Returns

tuple of:
  • only_left: prefixes present only in the left dict

  • only_right: prefixes present only in the right dict

  • mismatch: values present in both dicts but not equal across dicts.

    For tensors, equality of all elements is checked. Each element is a tuple (prefix, type of left value, type of right value).

Return type

Tuple[list, list, list]
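A simplified sketch of the documented return value for nested dicts of plain values. Unlike the described behavior, it does not special-case tensors (it relies on ordinary !=) and does not descend into lists; the function name is hypothetical.

```python
from typing import Any, List, Tuple


def toy_diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[List, List, List]:
    """Return (only_left, only_right, mismatch) for nested dicts."""
    only_left: List[Tuple] = []
    only_right: List[Tuple] = []
    mismatch: List[Tuple] = []
    if isinstance(x1, dict) and isinstance(x2, dict):
        only_left += [prefix + (k,) for k in x1 if k not in x2]
        only_right += [prefix + (k,) for k in x2 if k not in x1]
        for k in x1:
            if k in x2:
                left, right, mis = toy_diff(x1[k], x2[k], prefix + (k,))
                only_left += left
                only_right += right
                mismatch += mis
    elif x1 != x2:
        # Leaf mismatch: (prefix, type of left value, type of right value)
        mismatch.append((prefix, type(x1), type(x2)))
    return only_left, only_right, mismatch


a = {"model": {"w": 1, "b": 2}, "iteration": 5}
b = {"model": {"w": 1, "b": 3.0}, "epoch": 1}
print(toy_diff(a, b))
# ([('iteration',)], [('epoch',)], [(('model', 'b'), <class 'int'>, <class 'float'>)])
```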

core.dist_checkpointing.dict_utils.extract_matching_values(x: Union[dict, list], predicate: Callable[[Any], bool], return_lists_as_dicts: bool = False) → Tuple[Union[dict, list], Union[dict, list]]

Returns matching and non-matching values, keeping the hierarchy.

  • x (Union[dict, list]) – state dict to process. Top-level argument must be a dict or list

  • predicate (object -> bool) – determines matching values

  • return_lists_as_dicts (bool) – if True, matching lists will be turned into dicts, with keys indicating the indices of original elements. Useful for reconstructing the original hierarchy.
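A hypothetical sketch of the documented behavior: leaves are split by the predicate while the dict/list nesting is preserved on both sides, and lists_as_dicts plays the role of return_lists_as_dicts by keeping original list indices as dict keys. How the library treats empty containers may differ from this sketch, which keeps them on the non-matching side.

```python
from typing import Any, Callable, Tuple, Union

Container = Union[dict, list]


def toy_extract(x: Container, predicate: Callable[[Any], bool],
                lists_as_dicts: bool = False) -> Tuple[Container, Container]:
    """Split x into (matching, non_matching) leaves, keeping the hierarchy."""
    if isinstance(x, dict):
        items = list(x.items())
        matching, other = {}, {}
    elif isinstance(x, list):
        items = list(enumerate(x))
        matching, other = ({}, {}) if lists_as_dicts else ([], [])
    else:
        raise TypeError("top-level argument must be a dict or list")

    def put(target: Container, key: Any, value: Any) -> None:
        if isinstance(target, dict):
            target[key] = value   # dict keys (or list indices) are preserved
        else:
            target.append(value)  # plain lists lose the original indices

    for key, value in items:
        if isinstance(value, (dict, list)) and value:
            m, o = toy_extract(value, predicate, lists_as_dicts)
            if m:
                put(matching, key, m)
            if o:
                put(other, key, o)
        elif isinstance(value, (dict, list)):
            put(other, key, value)  # empty containers stay on the non-matching side
        elif predicate(value):
            put(matching, key, value)
        else:
            put(other, key, value)
    return matching, other


def is_float(v: Any) -> bool:
    return isinstance(v, float)


sd = {"w": 1.5, "meta": {"name": "layer", "scale": 0.5}, "lst": ["a", 2.0]}
print(toy_extract(sd, is_float))
# ({'w': 1.5, 'meta': {'scale': 0.5}, 'lst': [2.0]}, {'meta': {'name': 'layer'}, 'lst': ['a']})
print(toy_extract(sd, is_float, lists_as_dicts=True)[0]["lst"])  # {1: 2.0}
```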

core.dist_checkpointing.dict_utils.inspect_types(x: Any, prefix: Tuple = (), indent: int = 4)

Helper to print types of (nested) dict values.

core.dist_checkpointing.dict_utils.map_reduce(xs: typing.Iterable, key_fn: typing.Callable = <function <lambda>>, value_fn: typing.Callable = <function <lambda>>, reduce_fn: typing.Callable = <function <lambda>>) → dict

Simple map-reduce implementation following more_itertools.map_reduce interface.
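In the more_itertools.map_reduce interface, items are grouped by a key function, each item is transformed by a value function, and a reduce function is applied to each group. A minimal sketch with the parameter names from the signature above; the identity defaults are an assumption, since the rendered signature only shows <lambda>.

```python
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable


def toy_map_reduce(
    xs: Iterable[Any],
    key_fn: Callable[[Any], Any] = lambda x: x,
    value_fn: Callable[[Any], Any] = lambda x: x,
    reduce_fn: Callable[[list], Any] = lambda x: x,
) -> Dict[Any, Any]:
    groups: Dict[Any, list] = defaultdict(list)
    for x in xs:
        groups[key_fn(x)].append(value_fn(x))  # map: group transformed values by key
    return {k: reduce_fn(vs) for k, vs in groups.items()}  # reduce each group


shards = [("layer.weight", 0), ("layer.weight", 1), ("layer.bias", 0)]
print(toy_map_reduce(shards, key_fn=lambda s: s[0], value_fn=lambda s: s[1]))
# {'layer.weight': [0, 1], 'layer.bias': [0]}
print(toy_map_reduce(shards, key_fn=lambda s: s[0], reduce_fn=len))
# {'layer.weight': 2, 'layer.bias': 1}
```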

core.dist_checkpointing.dict_utils.merge(x1: dict, x2: dict, key: Tuple[str, ...] = ())

Merges dicts and lists recursively.

core.dist_checkpointing.dict_utils.nested_items_iter(x: Union[dict, list])

Returns an iterator over (container, key, value) tuples of a given (nested) dict or list.

core.dist_checkpointing.dict_utils.nested_values(x: Union[dict, list])

Returns an iterator over the (nested) values of a given dict or list.

Helpers for manipulating sharded tensors and sharded state dicts.

core.dist_checkpointing.utils.add_prefix_for_sharding(sharded_state_dict: Dict[str, Any], prefix: str)

Prepend a given prefix to all ShardedBase objects in a given state dict in-place.

  • sharded_state_dict (ShardedStateDict) – sharded state dict

  • prefix (str) – prefix to be prepended

Returns

state dict is modified in place

Return type

None

core.dist_checkpointing.utils.apply_prefix_mapping(sharded_state_dict: Dict[str, Any], prefix_map: Dict[str, str])

Replaces prefixes only in keys that match one of the prefixes in the map.

  • sharded_state_dict (ShardedStateDict) – sharded state dict to replace keys in

  • prefix_map (Dict[str, str]) – map of old->new prefixes. The first matching prefix for each key is used.

Returns

state dict is modified in place

Return type

None
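The "first matching prefix" rule can be illustrated on plain key strings. remap_key and the example prefixes are hypothetical; the sketch relies on Python dicts preserving insertion order, so more specific prefixes should be listed first.

```python
from typing import Dict


def remap_key(key: str, prefix_map: Dict[str, str]) -> str:
    """Replace the first matching old prefix with its new prefix; leave other keys unchanged."""
    for old_prefix, new_prefix in prefix_map.items():  # insertion order decides "first"
        if key.startswith(old_prefix):
            return new_prefix + key[len(old_prefix):]
    return key


prefix_map = {"decoder.layers.": "module.layers.", "decoder.": "module."}
keys = ["decoder.layers.0.weight", "decoder.final_norm.weight", "embedding.weight"]
print([remap_key(k, prefix_map) for k in keys])
# ['module.layers.0.weight', 'module.final_norm.weight', 'embedding.weight']
```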

core.dist_checkpointing.utils.extract_nonpersistent(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]

core.dist_checkpointing.utils.extract_sharded_base(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]

core.dist_checkpointing.utils.extract_sharded_tensors(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]

Extracts a dict consisting only of ShardedTensor objects from a given state dict with arbitrary objects.

Parameters

sharded_state_dict – state dict possibly containing ShardedTensor objects

Returns

tuple of:
  • state dict with all ShardedTensor (keeping the original state dict structure)

  • state dict with all objects other than ShardedTensor (keeping the original state dict structure)

Return type

Tuple[ShardedStateDict, StateDict]

core.dist_checkpointing.utils.extract_sharded_tensors_and_factories(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]

Extracts a dict consisting only of ShardedTensor and ShardedTensorFactory objects from a given state dict with arbitrary objects.

Parameters

sharded_state_dict – state dict possibly containing ShardedTensor and ShardedTensorFactory objects

Returns

tuple of:
  • state dict with all ShardedTensor and ShardedTensorFactory (keeping the original state dict structure)

  • state dict with all other objects (keeping the original state dict structure)

Return type

Tuple[ShardedStateDict, StateDict]

core.dist_checkpointing.utils.extract_sharded_tensors_or_nonpersistent(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]

Extracts a dict consisting only of ShardedTensor, ShardedTensorFactory and LocalNonpersitentObject objects from a given state dict with arbitrary objects.

Parameters

sharded_state_dict – state dict possibly containing ShardedTensor, ShardedTensorFactory and LocalNonpersitentObject objects

Returns

tuple of:
  • state dict with all ShardedTensor, ShardedTensorFactory and LocalNonpersitentObject (keeping the original state dict structure)

  • state dict with all other objects (keeping the original state dict structure)

Return type

Tuple[ShardedStateDict, StateDict]

core.dist_checkpointing.utils.replace_prefix_for_sharding(sharded_state_dict: Dict[str, Any], old_prefix: str, new_prefix: str)

Replaces the given prefix in all sharded keys in a given state dict.

Raises an error if some key does not begin with the given prefix.

  • sharded_state_dict (ShardedStateDict) – sharded state dict to replace keys in

  • old_prefix (str) – prefix to be replaced in each key

  • new_prefix (str) – new prefix

Returns

state dict is modified in place

Return type

None

