core.dist_checkpointing.serialization#

Entrypoints for saving and loading the distributed checkpoints.

Functions load and save are equivalents of torch.load and torch.save, but they expect torch.Tensors to be wrapped with classes from the mapping module. Additionally, load expects a sharded state dict argument as guidance for loading the sharded tensors.

Module Contents#

Functions#

load

Loading entrypoint.

load_common_state_dict

Load common (non-sharded) objects state dict from the checkpoint.

load_tensors_metadata

Load tensors metadata from the checkpoint.

load_sharded_metadata

Load sharded metadata from the checkpoint.

load_plain_tensors

Load checkpoint tensors without any sharding, as a plain state dict.

load_content_metadata

Load content metadata stored in the checkpoint with save(..., content_metadata=...).

remove_sharded_tensors

Determine the appropriate sharding strategy and delegate removal to the sharded strategy.

save

Saving entrypoint.

get_default_save_sharded_strategy

Get default save sharded strategy.

get_default_save_common_strategy

Get default save common strategy.

get_default_load_sharded_strategy

Get default load sharded strategy.

Data#

API#

core.dist_checkpointing.serialization.logger#

‘getLogger(…)’

core.dist_checkpointing.serialization.CkptShardedMetadata#

None

core.dist_checkpointing.serialization._CONTENT_METADATA_KEY#

‘content_metadata’

core.dist_checkpointing.serialization.load(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: str,
sharded_strategy: Union[core.dist_checkpointing.strategies.base.LoadShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[core.dist_checkpointing.strategies.base.LoadCommonStrategy, Tuple[str, int], None] = None,
validate_access_integrity: bool = True,
strict: Union[str, core.dist_checkpointing.validation.StrictHandling] = StrictHandling.ASSUME_OK_UNEXPECTED,
) → Union[core.dist_checkpointing.mapping.StateDict, Tuple[core.dist_checkpointing.mapping.StateDict, Set[str], Set[str]]]#

Loading entrypoint.

In the steps below, the following verbs refer to corresponding objects:

  • load = load from checkpoint

  • extract = extract from sharded_state_dict

  • add = add to the final state dict

Steps:

  1. Load common state dict and form the base of the result state dict

  2. Apply factories to sharded_state_dict

  3. Extract LocalNonPersistentObject and add

  4. (optional) Extract ShardedObjects, load and add

  5. Extract ShardedBase, load, apply factory merges and add
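The ordering of the steps above can be sketched with plain dicts and stand-in marker classes (LocalNonPersistentObject and ShardedTensor below are simplified stand-ins for the real mapping classes, not the actual implementation):

```python
# Simplified stand-ins for the real mapping classes; illustration only.
class LocalNonPersistentObject:
    def __init__(self, obj):
        self.obj = obj

class ShardedTensor:
    def __init__(self, key, data=None):
        self.key = key
        self.data = data

def sketch_load(sharded_state_dict, common_state_dict, load_tensor):
    # Step 1: the common (non-sharded) part forms the base of the result.
    result = dict(common_state_dict)
    # Steps 3-5: walk the requested sharded state dict and resolve each entry.
    for key, value in sharded_state_dict.items():
        if isinstance(value, LocalNonPersistentObject):
            # Step 3: local objects are unwrapped, never read from storage.
            result[key] = value.obj
        elif isinstance(value, ShardedTensor):
            # Steps 4-5: sharded entries are read from checkpoint storage.
            result[key] = load_tensor(value.key)
        else:
            result[key] = value
    return result

ckpt = {'embeddings': [1.0, 2.0]}  # pretend storage: key -> tensor data
loaded = sketch_load(
    {'rng_state': LocalNonPersistentObject('local-rng'),
     'embeddings': ShardedTensor('embeddings')},
    {'iteration': 100},
    load_tensor=ckpt.__getitem__,
)
print(loaded)
```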

Parameters:
  • sharded_state_dict (ShardedStateDict) – state dict of the existing model populated with ShardedTensors. Used as a mapping to determine which parts of global tensors stored in the checkpoint should be loaded.

  • checkpoint_dir (str) – directory with the checkpoint

  • sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional) – configures loading behavior for sharded tensors

  • common_strategy (LoadCommonStrategy, Tuple[str, int], optional) – configures loading behavior for common data

  • validate_access_integrity (bool, default = True) – checks if each tensor shard is accessed exactly once (as main replica) by some process

  • strict (StrictHandling, str, optional) – determines the behavior in case of a mismatch between the requested sharded state dict and the checkpoint. See StrictHandling docs for more details. Some values affect the return value of this function (missing and unexpected keys are returned). Defaults to True (StrictHandling.ASSUME_OK_UNEXPECTED), which doesn’t incur any performance overhead. Other recommended values are False (StrictHandling.LOG_UNEXPECTED), which logs only unexpected keys, or StrictHandling.RETURN_ALL, which returns all mismatched keys.

Returns:

in most cases only the loaded state dict is returned. If the strict flag was set to a value that returns mismatched keys (e.g. StrictHandling.RETURN_ALL), a tuple of (state dict, missing keys, unexpected keys) is returned instead.

Return type:

StateDict or Tuple[StateDict, Set[str], Set[str]]
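The documented bool shorthands for strict can be expressed as a small mapping. A minimal sketch, assuming only the StrictHandling member names mentioned above (the real enum lives in core.dist_checkpointing.validation):

```python
import enum

class StrictHandling(enum.Enum):
    # Minimal stand-in for core.dist_checkpointing.validation.StrictHandling;
    # only the members mentioned in these docs are included.
    ASSUME_OK_UNEXPECTED = 'assume_ok_unexpected'
    LOG_UNEXPECTED = 'log_unexpected'
    RETURN_ALL = 'return_all'

def resolve_strict(strict):
    """Map the documented bool shorthands onto StrictHandling members."""
    if strict is True:
        return StrictHandling.ASSUME_OK_UNEXPECTED  # default, no overhead
    if strict is False:
        return StrictHandling.LOG_UNEXPECTED  # only log unexpected keys
    return strict  # already a StrictHandling member

print(resolve_strict(True).name)
```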

core.dist_checkpointing.serialization.load_common_state_dict(
checkpoint_dir: Union[str, pathlib.Path],
) → core.dist_checkpointing.mapping.StateDict#

Load common (non-sharded) objects state dict from the checkpoint.

Parameters:

checkpoint_dir (str) – checkpoint directory

Returns:

state dict with non-sharded objects from the checkpoint

Return type:

StateDict

core.dist_checkpointing.serialization.load_tensors_metadata(
checkpoint_dir: str,
sharded_strategy: Union[core.dist_checkpointing.strategies.base.LoadShardedStrategy, None] = None,
) → core.dist_checkpointing.serialization.CkptShardedMetadata#

Load tensors metadata from the checkpoint.

Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (contrary to the actual sharded state dicts where keys correspond to state dict keys).

Dict values are ShardedTensors without any sharding (so the only useful information is the tensors’ global shape and dtype).

Concrete implementation depends on the loading strategy. If no strategy is given, a default for a given backend is used.

Parameters:
  • checkpoint_dir (str) – checkpoint directory to load from

  • sharded_strategy (LoadShardedStrategy, optional) – sharded strategy to load metadata. Defaults to None, in which case a default load strategy for the given checkpoint type is used.

Returns:

flat state dict without data describing ShardedTensors in the checkpoint

Return type:

CkptShardedMetadata
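A sketch of how the returned metadata can be consumed: entries behave like ShardedTensors stripped of data, so only global shape and dtype are meaningful. TensorMeta and the checkpoint contents below are hypothetical stand-ins, not the real entry type:

```python
import math
from dataclasses import dataclass

@dataclass
class TensorMeta:
    # Simplified stand-in for a data-less ShardedTensor metadata entry.
    global_shape: tuple
    dtype: str

# Hypothetical contents, as load_tensors_metadata might describe a checkpoint.
ckpt_metadata = {
    'embedding.weight': TensorMeta((50304, 1024), 'bfloat16'),
    'decoder.final_norm.weight': TensorMeta((1024,), 'float32'),
}

# Typical use: inspect global shapes without loading any tensor data.
total_elems = sum(math.prod(m.global_shape) for m in ckpt_metadata.values())
print(total_elems)  # 51512320
```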

core.dist_checkpointing.serialization.load_sharded_metadata(
checkpoint_dir: str,
sharded_strategy: Union[core.dist_checkpointing.strategies.base.LoadShardedStrategy, None] = None,
common_strategy: Union[core.dist_checkpointing.strategies.base.LoadCommonStrategy, None] = None,
) → core.dist_checkpointing.serialization.CkptShardedMetadata#

Load sharded metadata from the checkpoint.

Similar to load_tensors_metadata, but also includes ShardedObjects.

Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (contrary to the actual sharded state dicts where keys correspond to state dict keys).

Dict values are ShardedTensors without any sharding (so the only useful information is the tensors’ global shape and dtype).

Concrete implementation depends on the loading strategy. If no strategy is given, a default for a given backend is used.

Parameters:
  • checkpoint_dir (str) – checkpoint directory to load from

  • sharded_strategy (LoadShardedStrategy, optional) – sharded strategy to load metadata. Defaults to None, in which case a default load strategy for the given checkpoint type is used.

  • common_strategy (LoadCommonStrategy, optional) – common strategy to load metadata. Defaults to None, in which case a default load strategy for the given checkpoint type is used. This strategy won’t be used unless sharded_strategy can’t handle ShardedObjects.

Returns:

flat state dict without data describing ShardedTensors and ShardedObjects in the checkpoint

Return type:

CkptShardedMetadata

core.dist_checkpointing.serialization.load_plain_tensors(
checkpoint_dir: str,
) → core.dist_checkpointing.mapping.StateDict#

Load checkpoint tensors without any sharding, as a plain state dict.

NOTE: common state dict is NOT included.

Parameters:

checkpoint_dir (str) – checkpoint directory to load the tensors from.

Returns:

checkpoint state dict containing only torch.Tensors.

Return type:

StateDict

core.dist_checkpointing.serialization.load_content_metadata(
checkpoint_dir: Optional[str] = None,
*,
preloaded_state_dict: Optional[core.dist_checkpointing.mapping.StateDict] = None,
) → Optional[dict]#

Load content metadata stored in the checkpoint with save(..., content_metadata=...).

Parameters:
  • checkpoint_dir (str, optional) – checkpoint directory to load the content metadata from.

  • preloaded_state_dict (StateDict, optional) – if the state dict was already loaded, can be provided to avoid double load from storage

Returns:

checkpoint content metadata, or None if there is no content metadata in the checkpoint

Return type:

Optional[dict]
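Given the _CONTENT_METADATA_KEY constant documented above ('content_metadata'), the preloaded_state_dict path can be sketched as a plain dict lookup. This is a minimal sketch; the real function also handles reading the common state dict from checkpoint_dir:

```python
_CONTENT_METADATA_KEY = 'content_metadata'  # mirrors the module constant above

def sketch_load_content_metadata(preloaded_state_dict):
    # Return the stored metadata dict, or None when the checkpoint has none.
    return preloaded_state_dict.get(_CONTENT_METADATA_KEY)

# Hypothetical preloaded common state dict with framework-specific metadata.
state_dict = {'iteration': 100, _CONTENT_METADATA_KEY: {'framework_version': '0.9'}}
print(sketch_load_content_metadata(state_dict))   # {'framework_version': '0.9'}
print(sketch_load_content_metadata({'iteration': 100}))  # None
```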

core.dist_checkpointing.serialization.remove_sharded_tensors(checkpoint_dir: str, key_prefix: str)#

Determine the appropriate sharding strategy and delegate removal to the sharded strategy.

core.dist_checkpointing.serialization.save(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: str,
sharded_strategy: Union[core.dist_checkpointing.strategies.base.SaveShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[core.dist_checkpointing.strategies.base.SaveCommonStrategy, Tuple[str, int], None] = None,
validate_access_integrity: bool = True,
async_sharded_save: bool = False,
preprocess_common_before_consistancy_check: Optional[Callable[[core.dist_checkpointing.mapping.CommonStateDict], core.dist_checkpointing.mapping.StateDict]] = None,
content_metadata: Optional[dict] = None,
) → Optional[core.dist_checkpointing.strategies.async_utils.AsyncRequest]#

Saving entrypoint.

Extracts ShardedTensors from the given state dict. Rank 0 saves the “regular” (non-sharded) part of the checkpoint to a common torch file. The ShardedTensors are saved according to the specified sharded strategy.

Steps:

  1. Apply factories

  2. Extract and discard LocalNonPersistentObject

  3. Extract all ShardedBase objects

  4. Save all other objects to common.pt

  5. (optional) Extract and save ShardedObjects

  6. Save all ShardedBase objects

  7. Write metadata.json file with backend and version metadata.

Step (6) can be performed asynchronously (see async_sharded_save), in this case the actual save is embodied in the returned async request and can be scheduled by the external caller. For async request, step (7) is added as one of the finalization functions, so that metadata.json is written only if the checkpoint is complete.

Parameters:
  • sharded_state_dict (ShardedStateDict) – state dict of the model populated with ShardedTensors. Used as a mapping to determine how local tensors should be saved as global tensors in the checkpoint.

  • checkpoint_dir (str) – directory to save the checkpoint to

  • sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional) – configures sharded tensors saving behavior and backend

  • common_strategy (SaveCommonStrategy, Tuple[str, int], optional) – configures common data saving behavior and backend

  • validate_access_integrity (bool, default = True) – checks if each tensor shard is accessed exactly once (as main replica) by some process. It also makes sure the common state dict is consistent across all ranks

  • async_sharded_save (bool, optional) – if True, an async save implementation will be called for the sharded part of the state dict, with the AsyncRequest being returned to the caller. Note that it is the caller’s responsibility to actually schedule the async save. Defaults to False.

  • preprocess_common_before_consistancy_check (Callable[[CommonStateDict], StateDict], optional) – a callable that preprocesses the common state dict before the consistency check (e.g. it can be used to remove keys that are expected to differ across ranks). The function must not modify the original state dict

  • content_metadata (dict, optional) – metadata to identify the checkpoint content. Useful for framework specific versioning.

Returns:

if async_sharded_save is True, returns async request that should be scheduled by the caller of this function. None otherwise.

Return type:

AsyncRequest (optional)
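The validate_access_integrity check can be illustrated with a pure-Python sketch: every shard claimed as main replica must be claimed by exactly one rank. The (key, shard_id) descriptors below are hypothetical, not the real shard representation:

```python
from collections import Counter

def check_access_integrity(main_replica_shards_per_rank):
    """Flag shards claimed as main replica by more than one rank.

    main_replica_shards_per_rank: list (indexed by rank) of (key, shard_id)
    descriptors that each rank claims as main replica. Returns the sorted
    list of duplicated descriptors (empty means the check passes).
    """
    counts = Counter(
        shard
        for rank_shards in main_replica_shards_per_rank
        for shard in rank_shards
    )
    return sorted(shard for shard, n in counts.items() if n > 1)

# Two ranks each claiming disjoint halves of 'weight' -> check passes.
ok = check_access_integrity([[('weight', 0)], [('weight', 1)]])
# Both ranks claiming shard 0 -> duplicate main replica detected.
bad = check_access_integrity([[('weight', 0)], [('weight', 0)]])
print(ok, bad)  # [] [('weight', 0)]
```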

core.dist_checkpointing.serialization.get_default_save_sharded_strategy(
backend: str = 'torch_dist',
version: int = 1,
) → core.dist_checkpointing.strategies.base.SaveShardedStrategy#

Get default save sharded strategy.

core.dist_checkpointing.serialization.get_default_save_common_strategy(
backend: str = 'torch',
version: int = 1,
) → core.dist_checkpointing.strategies.base.SaveCommonStrategy#

Get default save common strategy.

core.dist_checkpointing.serialization.get_default_load_sharded_strategy(
checkpoint_dir: str,
) → core.dist_checkpointing.strategies.base.LoadShardedStrategy#

Get default load sharded strategy.