dist_checkpointing package#

A library for saving and loading distributed checkpoints. A “distributed checkpoint” can have various underlying formats (the current default save backend is torch_dist; see get_default_save_sharded_strategy), but it has a distinctive property: a checkpoint saved in one parallel configuration (tensor/pipeline/data parallelism) can be loaded in a different parallel configuration.

Using the library requires defining sharded state dicts with classes and functions from the mapping and optimizer modules. Those state dicts can be saved or loaded with the serialization module, using strategies from the strategies module.
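
A rough end-to-end sketch of that workflow is shown below. It is illustrative only: the import prefix follows the module paths documented in this section (adjust to your installation), torch.distributed is assumed to be initialized, and the checkpoint directory '/tmp/dist_ckpt' is a placeholder that is assumed to exist.

import torch
import torch.distributed as dist
from core.dist_checkpointing.mapping import ShardedTensor
from core.dist_checkpointing.serialization import load, save

rank, world_size = dist.get_rank(), dist.get_world_size()
local_weight = torch.randn(16, 32)  # this rank's shard of a (16 * world_size, 32) global tensor

sharded_state_dict = {
    'embedding.weight': ShardedTensor.from_rank_offsets(
        'embedding.weight', local_weight, (0, rank, world_size)
    ),
}

save(sharded_state_dict, '/tmp/dist_ckpt')           # collective call on all ranks
loaded = load(sharded_state_dict, '/tmp/dist_ckpt')  # returns a plain state dict of torch.Tensors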

Safe Checkpoint Loading#

Since PyTorch 2.6, the default behavior of torch.load is weights_only=True. This ensures that only tensors and allow-listed classes are loaded, reducing the risk of arbitrary code execution.

If you encounter an error such as:

WeightsUnpickler error: Unsupported global: GLOBAL argparse.Namespace was not an allowed global by default.

you can fix it by explicitly allow-listing the missing class in your script:

import torch, argparse

torch.serialization.add_safe_globals([argparse.Namespace])

Subpackages#

Submodules#

dist_checkpointing.serialization module#

Entrypoints for saving and loading distributed checkpoints.

Functions load and save are equivalents of torch.load and torch.save, but expect torch.Tensors to be wrapped with classes from the mapping module. Additionally, load expects the sharded state dict argument as guidance for loading the sharded tensors.

core.dist_checkpointing.serialization.get_default_load_sharded_strategy(
checkpoint_dir: str,
) LoadShardedStrategy#

Get default load sharded strategy.

core.dist_checkpointing.serialization.get_default_save_common_strategy(
backend: str = 'torch',
version: int = 1,
) SaveCommonStrategy#

Get default save common strategy.

core.dist_checkpointing.serialization.get_default_save_sharded_strategy(
backend: str = 'torch_dist',
version: int = 1,
) SaveShardedStrategy#

Get default save sharded strategy.

core.dist_checkpointing.serialization.load(
sharded_state_dict: Dict[str, Any],
checkpoint_dir: str,
sharded_strategy: LoadShardedStrategy | Tuple[str, int] | None = None,
common_strategy: LoadCommonStrategy | Tuple[str, int] | None = None,
validate_access_integrity: bool = True,
strict: str | StrictHandling = StrictHandling.ASSUME_OK_UNEXPECTED,
) Dict[str, Any] | Tuple[Dict[str, Any], Set[str], Set[str]]#

Loading entrypoint.

In the steps below, the following verbs refer to corresponding objects:
  • load = load from checkpoint
  • extract = extract from sharded_state_dict
  • add = add to the final state dict

Steps:
  1. Load the common state dict and form the base of the result state dict.
  2. Apply factories to sharded_state_dict.
  3. Extract LocalNonpersistentObject and add.
  4. (optional) Extract ShardedObjects, load and add.
  5. Extract ShardedBase, load, apply factory merges and add.

Parameters:
  • sharded_state_dict (ShardedStateDict) – state dict of the existing model populated with ShardedTensors. Used as a mapping to determine which parts of global tensors stored in the checkpoint should be loaded.

  • checkpoint_dir (str) – directory with the checkpoint

  • sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional) – configures loading behavior for sharded tensors

  • common_strategy (LoadCommonStrategy, Tuple[str, int], optional) – configures loading behavior for common data

  • validate_access_integrity (bool, default = True) – checks if each tensor shard is accessed exactly once (as the main replica) by some process

  • strict (StrictHandling, str, optional) – determines the behavior in case of a mismatch between the requested sharded state dict and the checkpoint. See StrictHandling docs for more details. Some values affect the return value of this function (missing and unexpected keys are returned). Defaults to True (StrictHandling.ASSUME_OK_UNEXPECTED) which doesn’t incur any performance overhead. Other recommended values are: False (StrictHandling.LOG_UNEXPECTED) which logs only unexpected keys or StrictHandling.RETURN_ALL which returns all mismatch keys.

Returns:

in most cases only the loaded state dict is returned. If the strict flag was set to an option that returns mismatch keys (e.g. StrictHandling.RETURN_ALL), a tuple of the loaded state dict and the sets of missing and unexpected keys is returned.

Return type:

StateDict or Tuple[StateDict, Set[str], Set[str]]
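
For illustration, a minimal hedged sketch of a load call (sharded_state_dict is assumed to be produced by the application, e.g. from its model definition; the checkpoint path is a placeholder):

from core.dist_checkpointing.serialization import load

# Returns a plain state dict with regular torch.Tensors reassembled for this rank.
state_dict = load(sharded_state_dict, checkpoint_dir='/tmp/dist_ckpt')

# With a strict mode that reports mismatches (e.g. StrictHandling.RETURN_ALL),
# the call instead returns (state_dict, missing_keys, unexpected_keys).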

core.dist_checkpointing.serialization.load_common_state_dict(
checkpoint_dir: str | Path,
) Dict[str, Any]#

Load common (non-sharded) objects state dict from the checkpoint.

Parameters:

checkpoint_dir (str) – checkpoint directory

Returns:

state dict with non-sharded objects from the checkpoint

Return type:

StateDict

core.dist_checkpointing.serialization.load_content_metadata(
checkpoint_dir: str | None = None,
*,
preloaded_state_dict: Dict[str, Any] | None = None,
) dict | None#

Load content metadata stored in the checkpoint with save(…, content_metadata=…).

Parameters:
  • checkpoint_dir (str, optional) – checkpoint directory to load the content metadata from.

  • preloaded_state_dict (StateDict, optional) – if the state dict was already loaded, can be provided to avoid double load from storage

Returns:

checkpoint content metadata, or None if there is no content metadata in the checkpoint

Return type:

dict

core.dist_checkpointing.serialization.load_plain_tensors(
checkpoint_dir: str,
) Dict[str, Any]#

Load checkpoint tensors without any sharding, into a plain state dict structure.

NOTE: common state dict is NOT included.

Parameters:

checkpoint_dir (str) – checkpoint directory to load the tensors from.

Returns:

checkpoint state dict containing only torch.Tensors.

Return type:

StateDict

core.dist_checkpointing.serialization.load_sharded_metadata(
checkpoint_dir: str,
sharded_strategy: LoadShardedStrategy | None = None,
common_strategy: LoadCommonStrategy | None = None,
) Dict[str, ShardedTensor | ShardedObject]#

Load sharded metadata from the checkpoint.

Similar to load_tensors_metadata, but also includes ShardedObjects.

Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (contrary to the actual sharded state dicts where keys correspond to state dict keys).

Dict values are ShardedTensors without any sharding (so the only useful information is the tensors’ global shape and dtype).

Concrete implementation depends on the loading strategy. If no strategy is given, a default for a given backend is used.

Parameters:
  • checkpoint_dir (str) – checkpoint directory to load from

  • sharded_strategy (LoadShardedStrategy, optional) – sharded strategy to load metadata. Defaults to None - in this case a default load strategy for a given checkpoint type is used.

  • common_strategy (LoadCommonStrategy, optional) – common strategy to load metadata. Defaults to None - in this case a default load strategy for a given checkpoint type is used. This strategy won’t be used unless sharded_strategy can’t handle ShardedObjects

Returns:

flat state dict (without data) describing ShardedTensors and ShardedObjects in the checkpoint

Return type:

CkptShardedMetadata

core.dist_checkpointing.serialization.load_tensors_metadata(
checkpoint_dir: str,
sharded_strategy: LoadShardedStrategy | None = None,
) Dict[str, ShardedTensor | ShardedObject]#

Load tensors metadata from the checkpoint.

Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (contrary to the actual sharded state dicts where keys correspond to state dict keys).

Dict values are ShardedTensors without any sharding (so the only useful information is the tensors’ global shape and dtype).

Concrete implementation depends on the loading strategy. If no strategy is given, a default for a given backend is used.

Parameters:
  • checkpoint_dir (str) – checkpoint directory to load from

  • sharded_strategy (LoadShardedStrategy, optional) – sharded strategy to load metadata. Defaults to None - in this case a default load strategy for a given checkpoint type is used.

Returns:

flat state dict (without data) describing ShardedTensors in the checkpoint

Return type:

CkptShardedMetadata
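
A short hedged sketch of inspecting a checkpoint without loading any tensor data (the path is a placeholder):

from core.dist_checkpointing.serialization import load_tensors_metadata

metadata = load_tensors_metadata('/tmp/dist_ckpt')
for key, sh_ten in metadata.items():
    # The ShardedTensors carry no data; only global shape and dtype are meaningful here.
    print(key, sh_ten.global_shape, sh_ten.dtype)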

core.dist_checkpointing.serialization.remove_sharded_tensors(checkpoint_dir: str, key_prefix: str)#

Determines the appropriate sharded strategy and delegates the removal of sharded tensors matching key_prefix to it.

core.dist_checkpointing.serialization.save(
sharded_state_dict: Dict[str, Any],
checkpoint_dir: str,
sharded_strategy: SaveShardedStrategy | Tuple[str, int] | None = None,
common_strategy: SaveCommonStrategy | Tuple[str, int] | None = None,
validate_access_integrity: bool = True,
async_sharded_save: bool = False,
preprocess_common_before_consistancy_check: Callable[[Dict[str, Any]], Dict[str, Any]] | None = None,
content_metadata: dict | None = None,
) AsyncRequest | None#

Saving entrypoint.

Extracts ShardedTensors from the given state dict. Rank 0 saves the “regular” part of the checkpoint to a common torch file. The ShardedTensors are saved according to a strategy specified by the config.

Steps:
  1. Apply factories.
  2. Extract and discard LocalNonpersistentObject.
  3. Extract all ShardedBase objects.
  4. Save all other objects to common.pt.
  5. (optional) Extract and save ShardedObjects.
  6. Save all ShardedBase objects.
  7. Write a metadata.json file with backend and version metadata.

Step (6) can be performed asynchronously (see async_sharded_save), in this case the actual save is embodied in the returned async request and can be scheduled by the external caller. For async request, step (7) is added as one of the finalization functions, so that metadata.json is written only if the checkpoint is complete.

Parameters:
  • sharded_state_dict (ShardedStateDict) – state dict of the existing model populated with ShardedTensors. Used as a mapping to determine how local tensors should be saved as global tensors in the checkpoint.

  • checkpoint_dir (str) – directory to save the checkpoint to

  • sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional) – configures sharded tensors saving behavior and backend

  • common_strategy (SaveCommonStrategy, Tuple[str, int], optional) – configures common data saving behavior and backend

  • validate_access_integrity (bool, default = True) – checks if each tensor shard is accessed exactly once (as the main replica) by some process. It also makes sure the common state dict is consistent across all ranks

  • async_sharded_save (bool, optional) – if True, for the sharded state dict part an async save implementation will be called, with the AsyncRequest being returned to the caller. Note that it is the caller’s responsibility to actually schedule the async save. Defaults to False.

  • preprocess_common_before_consistancy_check (Callable[[CommonStateDict], StateDict], optional) – a callable that preprocesses the common state dict before the consistency check (e.g. it can be used to remove keys that are expected to differ across ranks). The function must not modify the original state dict.

  • content_metadata (dict, optional) – metadata to identify the checkpoint content. Useful for framework specific versioning.

Returns:

if async_sharded_save is True, returns an async request that should be scheduled by the caller of this function; None otherwise.

Return type:

AsyncRequest (optional)
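
A hedged sketch of the synchronous and asynchronous save paths (sharded_state_dict and the checkpoint directories are placeholders; scheduling of the returned AsyncRequest is the caller's responsibility and is not shown):

from core.dist_checkpointing.serialization import save

# Synchronous save: blocks until both sharded tensors and common data are written.
save(sharded_state_dict, '/tmp/dist_ckpt')

# Asynchronous save of the sharded part: metadata.json is written only once the
# returned request is scheduled and finalized by the caller.
async_request = save(sharded_state_dict, '/tmp/dist_ckpt_async', async_sharded_save=True)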

dist_checkpointing.mapping module#

Core library classes for representing sharding of tensors and objects.

The main expected usage is wrapping torch.Tensors in state dicts with the ShardedTensor class (mostly with the ShardedTensor.from_rank_offsets classmethod).

class core.dist_checkpointing.mapping.LocalNonpersistentObject(obj)#

Bases: object

Object that should not be stored in a checkpoint, but restored locally.

Wrapping any object inside the state dict with LocalNonpersistentObject will result in:
  • during saving, this object will not be stored in the checkpoint
  • during loading, a local version of this object will be placed in the state dict

unwrap()#

Returns the original object.
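
For example (a minimal sketch; the keys and the wrapped value are arbitrary):

from core.dist_checkpointing.mapping import LocalNonpersistentObject

scratch = {'rank_local_cache': None}  # arbitrary per-rank object
state_dict = {
    'iteration': 1000,                             # saved to the checkpoint as usual
    'scratch': LocalNonpersistentObject(scratch),  # skipped on save; local value restored on load
}
assert state_dict['scratch'].unwrap() is scratch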

class core.dist_checkpointing.mapping.ShardedBase#

Bases: ABC

Base class for ShardedTensor, ShardedObject and ShardedTensorFactory.

data: object#
key: str#
replica_id: int | Tuple[int, ...]#
abstract validate_metadata_integrity()#

Codifies the constraints on metadata attributes.

abstract without_data() ShardedBase#

Returns a new ShardedBase instance with data=None.

class core.dist_checkpointing.mapping.ShardedObject(
key: str,
data: object,
global_shape: Tuple[int, ...],
global_offset: Tuple[int, ...],
replica_id: int | Tuple[int, ...] = 0,
)#

Bases: ShardedBase

Represents a mapping between a local object and a global object.

Global object is assumed to consist of many local objects distributed between different processes.

NOTE: Unlike ShardedTensor, it is impossible to change the global object sharding. Conceptually, ShardedObject is a fully-sharded ShardedTensor with atomic, arbitrarily typed elements.

Parameters:
  • key – unique identifier of a global object

  • data – local object data. Can be None only for consistency validation

  • global_shape – global object shape

  • global_offset – offset of a local object in a global object, specified in number of shards

  • replica_id – indicates local object replication wrt. local objects in different processes

data: object#
classmethod empty_from_unique_key(
unique_key,
replica_id: int | Tuple[int, ...] = 0,
) ShardedObject#

Instantiates a ShardedObject from a unique key.

Parameters:
  • unique_key – a string of the form <key>/shard_<global_offset>_<global_shape>

  • replica_id – indicates local object replication wrt. local objects in different processes

Returns:

a ShardedObject with data=None

global_offset: Tuple[int, ...]#
global_shape: Tuple[int, ...]#
key: str#
replica_id: int | Tuple[int, ...] = 0#
property unique_key#

Returns a unique key for this object.

validate_metadata_integrity()#

Codifies the constraints on metadata attributes.

without_data()#

Returns a new ShardedBase instance with data=None.
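
A hedged sketch of wrapping a per-rank Python object as one shard of a global object (the rank variables would normally come from the application's parallel state):

import torch.distributed as dist
from core.dist_checkpointing.mapping import ShardedObject

dp_rank, dp_size = dist.get_rank(), dist.get_world_size()
rng_state = {'seed': 1234 + dp_rank}  # arbitrary picklable per-rank object

sharded_obj = ShardedObject(
    key='rng_states',
    data=rng_state,
    global_shape=(dp_size,),   # one shard per rank
    global_offset=(dp_rank,),  # this rank's position, expressed in number of shards
)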

class core.dist_checkpointing.mapping.ShardedTensor(
key: str,
data: torch.Tensor | None,
dtype: torch.dtype,
local_shape: Tuple[int, ...],
global_shape: Tuple[int, ...],
global_offset: Tuple[int, ...],
axis_fragmentations: Tuple[int, ...] | None,
replica_id: int | Tuple[int, ...] = 0,
prepend_axis_num: int = 0,
allow_shape_mismatch: bool = False,
flattened_range: slice | None = None,
)#

Bases: ShardedBase

Represents a mapping between a local tensor and a global tensor.

Global tensor is assumed to consist of many local tensors distributed between different processes.

Parameters:
  • key – unique identifier of a global tensor

  • data – local tensor data. Can be None only for consistency validation

  • dtype – tensor dtype

  • local_shape – local tensor shape

  • global_shape – global tensor shape

  • global_offset – offset of a local tensor in a global tensor, specified in number of tensor elements

  • axis_fragmentations – global tensor fragmentation of each axis

  • replica_id – indicates given local tensor’s replication wrt. local tensors in different processes

  • prepend_axis_num – number of axes prepended to the local tensor to reflect global tensor shape. The behavior is similar to unsqueezing the local tensor.

  • allow_shape_mismatch – if True, during loading, the global shape of a stored tensor does not have to match the expected global shape. Useful for representing tensors with flexible shape, e.g. padded.

  • flattened_range – specifies a slice that should be applied to a flattened tensor with local_shape in order to get the tensor stored as data

allow_shape_mismatch: bool = False#
axis_fragmentations: Tuple[int, ...] | None#
data: torch.Tensor | None#
dtype: torch.dtype#
flattened_range: slice | None = None#
classmethod from_rank_offsets(
key: str,
data: torch.Tensor,
*rank_offsets: Tuple[int, int, int],
replica_id: int | Tuple[int, ...] = 0,
prepend_axis_num: int = 0,
flattened_range: None = None,
**init_kwargs,
)#

Allows constructing a ShardedTensor given offsets specified in terms of process ranks (see the sketch after the parameter list).

Parameters:
  • key (str) – unique key

  • data (torch.Tensor) – local tensor data

  • rank_offsets (Tuple[int, int, int]) – each tuple (axis, axis_rank_offset, axis_fragm) says that if the global tensor is divided into axis_fragm fragments along axis axis, then the local tensor data corresponds to the axis_rank_offset-th chunk.

  • replica_id (ReplicaId) – see ShardedTensor

  • prepend_axis_num (int) – see ShardedTensor

  • flattened_range (None) – must be None when using this constructor

  • init_kwargs – passed to ShardedTensor.__init__
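
To make the rank_offsets convention concrete, a hedged sketch for a tensor split along its first axis across tensor-parallel ranks (tp_rank and tp_size are placeholders for the application's parallel state):

import torch
from core.dist_checkpointing.mapping import ShardedTensor

tp_rank, tp_size = 0, 4                 # placeholders
local_weight = torch.randn(1024, 4096)  # this rank's shard

# (axis=0, axis_rank_offset=tp_rank, axis_fragm=tp_size): the global tensor is split
# into tp_size fragments along axis 0 and this rank holds fragment number tp_rank,
# so the implied global shape is (1024 * tp_size, 4096).
sh_ten = ShardedTensor.from_rank_offsets('decoder.weight', local_weight, (0, tp_rank, tp_size))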

classmethod from_rank_offsets_flat(
key: str,
data: torch.Tensor,
non_flat_local_shape: Tuple[int, ...],
*args,
flattened_range: slice | None = None,
**kwargs,
)#

Allows constructing a flattened ShardedTensor given offsets specified in terms of process ranks.

Parameters:
  • key (str)

  • data (torch.Tensor) – this should be a flattened data tensor

  • non_flat_local_shape (Tuple[int, ...]) – expected local shape of a non-flat chunk

  • *args – passed unchanged to the from_rank_offsets constructor

  • flattened_range (slice) – see ShardedTensor. Defaults to None, but must be set to a non-None slice.

  • **kwargs – passed unchanged to the from_rank_offsets constructor

Returns:

constructed ShardedTensor instance

Return type:

ShardedTensor

global_coordinates() Tuple[numpy.ndarray, ...]#

Returns a tuple of np.ndarrays representing the coordinates of the global tensor that this ShardedTensor corresponds to.

global_offset: Tuple[int, ...]#
global_shape: Tuple[int, ...]#
global_slice() Tuple[int | slice, ...]#

Returns a tuple of int and slice objects representing a slice of the global tensor that this ShardedTensor corresponds to.

property has_regular_grid#

Alias for having a regular sharding grid.

init_data(
device: str | torch.device,
init_fn=torch.empty,
)#

Initialize the tensor data of this ShardedTensor.

Only called if data attribute is None.

Parameters:
  • device (Union[str, torch.device]) – device to place the tensor on

  • init_fn (Callable, optional) – function to use to initialize the tensor. Defaults to torch.empty.

key: str#
local_chunk_offset_in_global() Tuple[int, ...]#

Offset of a local chunk in a global array of chunks.

Returns:

the offset of the whole local chunk in a global array of chunks.

Return type:

Tuple[int, …]

local_coordinates() Tuple[numpy.ndarray, ...]#

Returns a tuple of np.ndarrays representing the coordinates of the local tensor that this ShardedTensor corresponds to.

local_shape: Tuple[int, ...]#
max_allowed_chunks() Tuple[int, ...]#

Returns the maximum allowed chunks for this ShardedTensor.

narrow(
dim: int,
start: int,
length: int,
) List[ShardedTensor]#

This is an analogue of torch.narrow for ShardedTensors.

Narrowing assumes that we narrow a local tensor on each rank. This has consequences on local_shape, global_shape, global_offset, etc.

Parameters:
  • dim (int) – dimension to narrow. Doesn’t include prepended axes.

  • start (int) – start element

  • length (int) – length of the slice

Returns:

narrowed ShardedTensors. For non-flat tensors, the list will always have 1 element. For flat ShardedTensors the number of elements varies depending on dim and on overlap, because flat tensors must be contiguous. In particular the list can be empty.

Return type:

List[ShardedTensor]

prepend_axis_num: int = 0#
replica_id: int | Tuple[int, ...] = 0#
validate_metadata_integrity() None#

Codifies the constraints on metadata attributes.

Meeting those constraints is guaranteed when instantiating a ShardedTensor with the from_rank_offsets or from_rank_offsets_flat constructors.

Returns:

None

without_data()#

Returns a new ShardedBase instance with data=None.

class core.dist_checkpointing.mapping.ShardedTensorFactory(
key: str,
data: torch.Tensor,
build_fn: Callable[[str, torch.Tensor, int | Tuple[int, ...], slice | None], Dict[str, Any]],
merge_fn: Callable[[Dict[str, Any]], torch.Tensor],
replica_id: int | Tuple[int, ...] = 0,
flattened_range: slice | None = None,
)#

Bases: ShardedBase

Allows applying transformations to tensors before/after serialization.

The essence of those transformations is that they can be applied to optimizer states the same way they are applied to the model params. The ultimate state dict with sharded tensors must depend functionally on build_fn arguments (key, data, replica_id, flattened_range), which will be provided by the optimizer.

The build_fn creates a sub-state-dict out of a tensor before saving, and the merge_fn merges the corresponding state dict back after loading.

Parameters:
  • key (str) – unique identifier of the factory

  • data (torch.Tensor) – original model parameter that will be further transformed by this factory

  • build_fn (callable) – function that transforms the original tensor to a sharded state dict

  • merge_fn (callable) – function that transforms loaded subtree back into a single tensor (inverse of build_fn)

  • replica_id (ReplicaId) – indicates factory replication wrt. factories in different processes

  • flattened_range (slice, optional) – indicates additional flattening applied to the ShardedTensors produced by the factory

build()#

Builds a ShardedStateDict from the original tensor

build_fn: Callable[[str, torch.Tensor, int | Tuple[int, ...], slice | None], Dict[str, Any]]#
data: torch.Tensor#
flattened_range: slice | None = None#
key: str#
merge_fn: Callable[[Dict[str, Any]], torch.Tensor]#
replica_id: int | Tuple[int, ...] = 0#
validate_metadata_integrity()#

No reasonable checks can be applied

without_data()#

Returns a new ShardedBase instance with data=None.
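
A hedged sketch of a factory that splits a fused parameter into two halves before saving and concatenates them back after loading (key names and the split are illustrative; flattened_range handling is omitted):

import torch
from core.dist_checkpointing.mapping import ShardedTensor, ShardedTensorFactory

def build_fn(key, tensor, replica_id, flattened_range):
    # Describe each half as its own (here unsharded) ShardedTensor.
    first, second = torch.chunk(tensor, 2, dim=0)
    return {
        'first': ShardedTensor.from_rank_offsets(f'{key}.first', first, replica_id=replica_id),
        'second': ShardedTensor.from_rank_offsets(f'{key}.second', second, replica_id=replica_id),
    }

def merge_fn(sub_state_dict):
    # Inverse of build_fn: reassemble the fused tensor from the loaded halves.
    return torch.cat([sub_state_dict['first'], sub_state_dict['second']], dim=0)

fused_weight = torch.randn(8, 16)  # placeholder fused parameter
factory = ShardedTensorFactory('fused.weight', fused_weight, build_fn, merge_fn)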

core.dist_checkpointing.mapping.apply_factories(sharded_state_dict: Dict[str, Any])#

Turn ShardedTensorFactories into ShardedTensors in-place.

Parameters:

sharded_state_dict (ShardedStateDict) – state dict possibly containing ShardedTensorFactory objects

Returns:

state dict is modified in place

Return type:

None

core.dist_checkpointing.mapping.apply_factory_merges(
x1: Dict[str, Any],
x2: Dict[str, Any],
key: Tuple[str, ...] = (),
) Dict[str, Any]#

Apply merges defined by ShardedTensorFactories in-place.

Parameters:
  • x1 (StateDict) – state dict loaded from the checkpoint

  • x2 (ShardedStateDict) – subset of x1 (in terms of dict keys) with ShardedTensorFactory as (possibly nested) values that define how to merge objects from the x1 state dict

  • key (Tuple[str, ...]) – current key in a recursive call. Used only for reporting meaningful errors

Returns:

x1 modified in-place

Return type:

StateDict

core.dist_checkpointing.mapping.is_main_replica(replica_id: int | Tuple[int, ...])#

Checks if given replica_id is considered as main.

“Main” replica is:
  • integer 0
  • or an iterable with all 0 elements

It is the application’s responsibility to set correct replicas for sharded tensors.

Parameters:

replica_id (Union[int, Tuple[int, ...]]) – replica id

Returns:

True for a “main” replica

Return type:

(bool)

dist_checkpointing.optimizer module#

Helpers for defining sharding for optimizer states based on existing sharding for model parameters.

core.dist_checkpointing.optimizer.get_optim_param_to_id_map(
optim_params_iter: Iterable[torch.nn.Parameter],
) Dict[int, int]#

Generate mapping from optimizer param to optimizer state id.

core.dist_checkpointing.optimizer.get_param_id_to_sharded_param_map(
model_sharded_state_dict: Dict[str, Any],
optim_params_iter: Iterable[torch.nn.Parameter],
) Dict[int, ShardedTensor | ShardedTensorFactory]#

Generate mapping from optimizer state ids to model sharded parameters.

Parameters:
  • model_sharded_state_dict – sharded state dict with all model sharded tensors (can have any structure)

  • optim_params_iter – iterable which iterates over model parameters tracked by the optimizer. The iteration must be in the same order as in the optimizer parameters.

Returns:

mapping from optimizer state ids to model sharded parameters.

Return type:

Dict[int, Union[ShardedTensor, ShardedTensorFactory]]

core.dist_checkpointing.optimizer.make_sharded_optimizer_tensor(
model_param: ShardedTensor | ShardedTensorFactory,
optim_param: torch.Tensor,
prefix: str,
) ShardedTensor | ShardedTensorFactory#

Build a ShardedTensor or ShardedTensorFactory for optimizer param based on model param

Parameters:
  • model_param (Union[ShardedTensor, ShardedTensorFactory]) – model param

  • optim_param (torch.Tensor) – corresponding optimizer param

  • prefix (str) – optimizer prefix for the ShardedTensor or ShardedTensorFactory

Returns:

wrapped optimizer parameter

Return type:

Union[ShardedTensor, ShardedTensorFactory]

core.dist_checkpointing.optimizer.optim_state_to_sharding_state(
optim_state_dict: Dict[str, Any],
id_to_sharded_param_map: Dict[int, ShardedTensor],
exclude_keys: Tuple[str] = (),
)#

Turn an optimizer state dict into a sharded state dict based on the model state dict, in-place.

Can be used to add sharding information to the most common optimizer state dict layout. Creates separate ShardedTensors for each key in optim_state_dict['state'] (e.g. for torch.optim.Adam there will be separate tensors for exp_avg and exp_avg_sq).

Parameters:
  • optim_state_dict (StateDict) – optimizer state dict with state parameters under state key and group hyperparameters under param_groups -> params key.

  • id_to_sharded_param_map (Dict[int, ShardedTensor]) – mapping from optimizer param ids to model sharded tensors. Can be generated with get_param_id_to_sharded_param_map function.

  • exclude_keys (Tuple[str]) – optimizer state keys to exclude from the final state dict.

Returns:

state dict is modified in place

Return type:

None
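
A hedged sketch of how these helpers combine for a plain torch optimizer (model, optimizer and model_sharded_sd are placeholders provided by the application; the parameter iterable must follow the optimizer's parameter order):

from core.dist_checkpointing.optimizer import (
    get_param_id_to_sharded_param_map,
    optim_state_to_sharding_state,
)

# Map optimizer param ids to the model's ShardedTensors / factories.
id_to_sharded_param_map = get_param_id_to_sharded_param_map(
    model_sharded_sd, model.parameters()
)

# Rewrite the optimizer state dict in-place so that per-param states
# (e.g. exp_avg, exp_avg_sq for Adam) become ShardedTensors that mirror
# the corresponding model parameter sharding.
optim_state_dict = optimizer.state_dict()
optim_state_to_sharding_state(optim_state_dict, id_to_sharded_param_map)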

dist_checkpointing.core module#

Module for managing distributed checkpoints metadata.

class core.dist_checkpointing.core.CheckpointingConfig(
sharded_backend: str,
sharded_backend_version: int = 1,
common_backend: str = 'torch',
common_backend_version: int = 1,
)#

Bases: object

Documents backends used in the checkpoint.

Checkpoint config keeps track of formats used for storing the sharded tensors (sharded_backend) and other objects (common_backend).

Note that versioning is not for the checkpoint content (which is application specific), but for the checkpoint format itself.

common_backend: str = 'torch'#
common_backend_version: int = 1#
sharded_backend: str#
sharded_backend_version: int = 1#
exception core.dist_checkpointing.core.CheckpointingException#

Bases: Exception

Base checkpointing-related exception.

core.dist_checkpointing.core.check_is_distributed_checkpoint(checkpoint_dir)#

Checks if metadata.json exists in the checkpoint and is a valid config.

Parameters:

checkpoint_dir – checkpoint directory

Returns:

True if metadata.json exists in the checkpoint and is a valid config.

Return type:

bool

core.dist_checkpointing.core.maybe_load_config(
checkpoint_dir: str,
) CheckpointingConfig | None#

Returns checkpoint config if checkpoint_dir is a distributed checkpoint and None otherwise

Parameters:

checkpoint_dir – checkpoint directory

Returns:

None if checkpoint is not a valid distributed checkpoint

Return type:

CheckpointingConfig (optional)
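
For example, a minimal sketch for probing a directory (the path is a placeholder):

from core.dist_checkpointing.core import check_is_distributed_checkpoint, maybe_load_config

if check_is_distributed_checkpoint('/tmp/dist_ckpt'):
    config = maybe_load_config('/tmp/dist_ckpt')
    print(config.sharded_backend, config.sharded_backend_version)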

core.dist_checkpointing.core.save_config(
config: CheckpointingConfig,
checkpoint_dir: str,
)#

Save given config to checkpoint directory.

Parameters:
  • config – checkpoint config

  • checkpoint_dir – checkpoint directory

Returns:

None

dist_checkpointing.dict_utils module#

Utilities for operating with dicts and lists.

All functions in this module handle nesting of dicts and lists. Other objects (e.g. tuples) are treated as atomic leaf types that cannot be traversed.

core.dist_checkpointing.dict_utils.dict_list_map_inplace(
f: Callable[[U], V],
x: Dict | List | U,
)#

Maps dicts and lists in-place with a given function.

core.dist_checkpointing.dict_utils.dict_list_map_outplace(
f: Callable[[U], V],
x: Dict | List | U,
) Dict | List | V#

Maps dicts and lists out-of-place with a given function.

core.dist_checkpointing.dict_utils.dict_map(f: Callable, d: dict)#

map equivalent for dicts.

core.dist_checkpointing.dict_utils.dict_map_with_key(f: Callable, d: dict)#

map equivalent for dicts with a function that accepts tuple (key, value).

core.dist_checkpointing.dict_utils.diff(
x1: Any,
x2: Any,
prefix: Tuple = (),
) Tuple[list, list, list]#

Recursive diff of dicts.

Parameters:
  • x1 (object) – left dict

  • x2 (object) – right dict

  • prefix (tuple) – tracks recursive calls. Used for reporting differing keys.

Returns:

tuple of:
  • only_left: Prefixes present only in left dict

  • only_right: Prefixes present only in right dict

  • mismatch: values present in both dicts but not equal across dicts.

    For tensors, equality of all elements is checked. Each element is a tuple (prefix, type of left value, type of right value).

Return type:

Tuple[list, list, list]
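
A brief hedged usage sketch (the dicts are illustrative):

from core.dist_checkpointing.dict_utils import diff

only_left, only_right, mismatch = diff(
    {'a': 1, 'b': {'c': 2}},
    {'a': 1, 'b': {'c': 3}, 'd': 4},
)
# only_left is empty, only_right reports the extra 'd' key,
# and mismatch reports the differing value under ('b', 'c').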

core.dist_checkpointing.dict_utils.extract_matching_values(
x: dict | list,
predicate: Callable[[Any], bool],
return_lists_as_dicts: bool = False,
) Tuple[dict | list, dict | list]#

Return matching and nonmatching values. Keeps hierarchy.

Parameters:
  • x (Union[dict, list]) – state dict to process. Top-level argument must be a dict or list

  • predicate (object -> bool) – determines matching values

  • return_lists_as_dicts (bool) – if True, matching lists will be turned into dicts, with keys indicating the indices of original elements. Useful for reconstructing the original hierarchy.
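
For instance, a hedged sketch that splits a nested state dict by value type:

import torch
from core.dist_checkpointing.dict_utils import extract_matching_values

state_dict = {'w': torch.ones(2), 'meta': {'step': 3, 'b': torch.zeros(1)}}
tensors, rest = extract_matching_values(state_dict, lambda v: isinstance(v, torch.Tensor))
# `tensors` keeps the nested structure with only tensor leaves; `rest` holds the remainder.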

core.dist_checkpointing.dict_utils.inspect_types(
x: Any,
prefix: Tuple = (),
indent: int = 4,
)#

Helper to print types of (nested) dict values.

core.dist_checkpointing.dict_utils.map_reduce(
xs: Iterable,
key_fn: Callable = <lambda>,
value_fn: Callable = <lambda>,
reduce_fn: Callable = <lambda>,
) dict#

Simple map-reduce implementation following more_itertools.map_reduce interface.

core.dist_checkpointing.dict_utils.merge(
x1: dict | list,
x2: dict | list,
key: Tuple[int | str, ...] = (),
)#

Merges dicts and lists recursively.

core.dist_checkpointing.dict_utils.nested_items_iter(x: dict | list)#

Returns iterator over (nested) tuples (container, key, value) of a given dict or list.

core.dist_checkpointing.dict_utils.nested_values(x: dict | list)#

Returns iterator over (nested) values of a given dict or list.

dist_checkpointing.utils module#

Helpers for manipulating sharded tensors and sharded state dicts.

core.dist_checkpointing.utils.add_prefix_for_sharding(
sharded_state_dict: Dict[str, Any],
prefix: str,
)#

Prepend a given prefix to all ShardedBase objects in a given state dict in-place.

Parameters:
  • sharded_state_dict (ShardedStateDict) – sharded state dict

  • prefix (str) – prefix to be prepended

Returns:

state dict is modified in-place

Return type:

None

core.dist_checkpointing.utils.apply_prefix_mapping(
sharded_state_dict: Dict[str, Any],
prefix_map: Dict[str, str],
)#

Replaces prefixes only in keys matching one of the prefixes in the map.

Parameters:
  • sharded_state_dict (ShardedStateDict) – sharded state dict to replace keys in

  • prefix_map (Dict[str, str]) – map of old->new prefixes. The first matching prefix for each key is used

Returns:

state dict is modified in place

Return type:

None
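
A hedged sketch of both prefix helpers operating in-place on a sharded state dict (submodule_sharded_sd is a placeholder):

from core.dist_checkpointing.utils import add_prefix_for_sharding, apply_prefix_mapping

# Prepend a module prefix to all ShardedBase keys, e.g. when embedding a submodule's
# sharded state dict into the parent model's state dict.
add_prefix_for_sharding(submodule_sharded_sd, 'decoder.')

# Rename selected prefixes, e.g. to keep checkpoint compatibility after a refactor.
apply_prefix_mapping(submodule_sharded_sd, {'decoder.attn.': 'decoder.self_attention.'})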

core.dist_checkpointing.utils.debug_msg(msg: str)#

Logs a debug message using the current logger stack.

This function formats and logs a debug message with the current logger and name stack, preserving context from the logger_stack context manager.

Parameters:

msg (str) – The message to be logged at the debug level.

Example

debug_msg("Checkpoint initialized")  # Logs: "scope_name Checkpoint initialized" if called within logger_stack("scope_name")

core.dist_checkpointing.utils.debug_time(
name: str,
logger: Logger | None = None,
threshold: float = -inf,
level=None,
)#

Simple context manager for timing functions/code blocks.

Parameters:
  • name (str) – Label describing the code being measured.

  • logger (logging.Logger, optional) – Logger for output. Defaults to the lowest logger.

  • threshold (float, optional) – Minimum time (seconds) to log. Skips logging if faster.

  • level (int, optional) – Logging level. Defaults to DEBUG if threshold is unset; WARNING otherwise.

core.dist_checkpointing.utils.extract_nonpersistent(
sharded_state_dict: Dict[str, Any],
) Tuple[Dict[str, Any], Dict[str, Any]]#

Extract a dict consisting of only LocalNonpersistentObjects from a given state dict.

Parameters:

sharded_state_dict – state dict possibly containing LocalNonpersistentObjects

Returns:

tuple of:
  • state dict with all LocalNonpersistentObjects (keeping the original state dict structure)

  • state dict with all other objects (keeping the original state dict structure)

Return type:

Tuple[ShardedStateDict, StateDict]

core.dist_checkpointing.utils.extract_sharded_base(
sharded_state_dict: Dict[str, Any],
) Tuple[Dict[str, Any], Dict[str, Any]]#

Extract a dict consisting of only ShardedBase from a given state dict with any objects.

Parameters:

sharded_state_dict – state dict possibly containing ShardedBase objects

Returns:

tuple of:
  • state dict with all ShardedBase objects (keeping the original state dict structure)

  • state dict with all other objects (keeping the original state dict structure)

Return type:

Tuple[ShardedStateDict, StateDict]

core.dist_checkpointing.utils.extract_sharded_tensors(
sharded_state_dict: Dict[str, Any],
) Tuple[Dict[str, Any], Dict[str, Any]]#

Extract a dict consisting of only ShardedTensor objects from a given state dict with any objects.

Parameters:

sharded_state_dict – state dict possibly containing ShardedTensor objects

Returns:

tuple of:
  • state dict with all ShardedTensor (keeping the original state dict structure)

  • state dict with all objects other than ShardedTensor (keeping the original state dict structure)

Return type:

Tuple[ShardedStateDict, StateDict]

core.dist_checkpointing.utils.extract_sharded_tensors_and_factories(
sharded_state_dict: Dict[str, Any],
) Tuple[Dict[str, Any], Dict[str, Any]]#

Extract a dict consisting of only ShardedTensor and ShardedTensorFactory objects from a given state dict with any objects.

Parameters:

sharded_state_dict – state dict possibly containing ShardedTensor and ShardedTensorFactory objects

Returns:

tuple of:
  • state dict with all ShardedTensor and ShardedTensorFactory (keeping the original state dict structure)

  • state dict with all other objects (keeping the original state dict structure)

Return type:

Tuple[ShardedStateDict, StateDict]

core.dist_checkpointing.utils.extract_sharded_tensors_or_nonpersistent(
sharded_state_dict: Dict[str, Any],
) Tuple[Dict[str, Any], Dict[str, Any]]#

Extract a dict consisting of only ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject objects from a given state dict with any objects.

Parameters:

sharded_state_dict – state dict possibly containing ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject objects

Returns:

tuple of:
  • state dict with all ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject (keeping the original state dict structure)

  • state dict with all other objects (keeping the original state dict structure)

Return type:

Tuple[ShardedStateDict, StateDict]

core.dist_checkpointing.utils.force_all_tensors_to_non_fp8(
sharded_state_dict: Dict[str, Any],
)#

Force all tensors in state dict to be non-fp8.

Parameters:

sharded_state_dict (ShardedStateDict) – sharded state dict.

core.dist_checkpointing.utils.logger_stack(
name: str | None = None,
current_logger: Logger | None = None,
)#

Context manager for managing logger and name stack.

Temporarily pushes a logger and/or name onto their respective stacks, allowing hierarchical logging and contextual logger usage. Ensures the logger stack is restored afterward.

Parameters:
  • name (str, optional) – Name to add to the logger stack. Defaults to None.

  • current_logger (logging.Logger, optional) – Logger to use. Defaults to the last logger in the stack or a fallback if none exist.

Yields:

Tuple[str, logging.Logger] – a tuple with the concatenated logger name stack and the current logger for the block.

Example

with logger_stack("scope", logger):
    logger.info("Log within 'scope'")

core.dist_checkpointing.utils.replace_prefix_for_sharding(
sharded_state_dict: Dict[str, Any],
old_prefix: str,
new_prefix: str,
)#

Replaces the given prefix in all sharded keys in a given state dict.

Errors out if some key does not begin with a given prefix.

Parameters:
  • sharded_state_dict (ShardedStateDict) – sharded state dict to replace keys in

  • old_prefix (str) – prefix to be replaced in each key

  • new_prefix (str) – new prefix

Returns:

state dict is modified in place

Return type:

None

core.dist_checkpointing.utils.zip_strict(*args)#

Alternative to Python's builtin zip(..., strict=True) (available in Python 3.10+). Apart from providing the functionality in earlier Python versions, it is also more verbose: Python's zip does not print the iterables' lengths, only which iterable finished earlier.

Module contents#