core.dist_checkpointing.serialization
Entrypoints for saving and loading the distributed checkpoints.
Functions load and save are equivalents of torch.load and torch.save
but expect torch.Tensors to be wrapped with classes from the mapping module.
Additionally, load expects the sharded state dict argument as guidance for
loading the sharded tensors.
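As a rough illustration of how a sharded state dict guides loading, the following stdlib-only sketch uses a hypothetical ToyShard record (not one of the real mapping classes) that stores a tensor's checkpoint key, its global length, and where the local shard sits inside it. Plain 1-D lists stand in for tensors; the real classes carry considerably more information.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToyShard:
    """Hypothetical stand-in for a mapping-module wrapper around a local 1-D tensor."""
    key: str            # key of the global tensor in the checkpoint
    global_length: int  # length of the full (global) tensor
    offset: int         # start of this process's shard within the global tensor
    length: int         # number of elements held locally


def toy_load_local_shards(sharded_state_dict: Dict[str, ToyShard],
                          checkpoint: Dict[str, List[float]]) -> Dict[str, List[float]]:
    """Read only the slices that the shard descriptions ask for."""
    result = {}
    for state_key, shard in sharded_state_dict.items():
        stored = checkpoint[shard.key]  # lookup uses the checkpoint key, not state_key
        if len(stored) != shard.global_length:
            raise ValueError(f"global shape mismatch for {shard.key!r}")
        result[state_key] = stored[shard.offset:shard.offset + shard.length]
    return result


# Process 1 of 2 owns the second half of a 4-element global tensor.
checkpoint = {"decoder.weight": [0.0, 1.0, 2.0, 3.0]}
local = toy_load_local_shards({"weight": ToyShard("decoder.weight", 4, 2, 2)}, checkpoint)
print(local)  # {'weight': [2.0, 3.0]}
```

Note that the result is keyed by the state dict key ("weight"), while the lookup in the checkpoint goes through the tensor's own key ("decoder.weight").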
Module Contents
Functions
load | Loading entrypoint.
load_common_state_dict | Load common (non-sharded) objects state dict from the checkpoint.
load_tensors_metadata | Load tensors metadata from the checkpoint.
load_sharded_metadata | Load sharded metadata from the checkpoint.
load_plain_tensors | Load checkpoint tensors without any sharding and plain structure.
load_content_metadata | Load content metadata stored in the checkpoint with save(..., content_metadata=...).
remove_sharded_tensors | Determine the appropriate sharding strategy and delegate removal to the sharded strategy.
save | Saving entrypoint.
get_default_save_sharded_strategy | Get default save sharded strategy.
get_default_save_common_strategy | Get default save common strategy.
get_default_load_sharded_strategy | Get default load sharded strategy.
Data
API
- core.dist_checkpointing.serialization.logger
‘getLogger(…)’
- core.dist_checkpointing.serialization.CkptShardedMetadata
None
- core.dist_checkpointing.serialization._CONTENT_METADATA_KEY
‘content_metadata’
- core.dist_checkpointing.serialization.load(
    sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
    checkpoint_dir: str,
    sharded_strategy: Union[core.dist_checkpointing.strategies.base.LoadShardedStrategy, Tuple[str, int], None] = None,
    common_strategy: Union[core.dist_checkpointing.strategies.base.LoadCommonStrategy, Tuple[str, int], None] = None,
    validate_access_integrity: bool = True,
    strict: Union[str, core.dist_checkpointing.validation.StrictHandling] = StrictHandling.ASSUME_OK_UNEXPECTED,
  )
Loading entrypoint.
In the steps below, the following verbs are used with these meanings:
load = load from checkpoint
extract = extract from sharded_state_dict
add = add to the final state dict
Steps:
1. Load the common state dict and form the base of the result state dict
2. Apply factories to sharded_state_dict
3. Extract LocalNonPersistentObject and add
4. (optional) Extract ShardedObjects, load and add
5. Extract ShardedBase, load, apply factory merges and add
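The load steps above can be sketched with plain dictionaries. The classes ToyTensor, ToyNonPersistent and ToyFactory are hypothetical stand-ins for ShardedTensor, LocalNonPersistentObject and a tensor factory; the factory's build_fn/merge_fn signatures are simplified assumptions, and step 4 (ShardedObjects) is left out. This shows the control flow only, not the library's implementation.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass
class ToyTensor:
    """Hypothetical stand-in for a ShardedTensor (only its checkpoint key)."""
    key: str


@dataclass
class ToyNonPersistent:
    """Hypothetical stand-in for LocalNonPersistentObject."""
    obj: Any


@dataclass
class ToyFactory:
    """Hypothetical stand-in for a tensor factory with build and merge steps."""
    key: str
    build_fn: Callable[[str], Dict[str, ToyTensor]]  # expands into sub-tensors
    merge_fn: Callable[[Dict[str, Any]], Any]        # recombines loaded sub-tensors


def toy_load_steps(sharded_state_dict: Dict[str, Any], common: Dict[str, Any],
                   stored: Dict[str, Any]) -> Dict[str, Any]:
    result = dict(common)  # step 1: the common state dict is the base of the result
    expanded, merges = {}, {}
    for name, value in sharded_state_dict.items():  # step 2: apply factories
        if isinstance(value, ToyFactory):
            expanded[name] = value.build_fn(value.key)
            merges[name] = value.merge_fn
        else:
            expanded[name] = value
    # Step 4 (ShardedObjects) is omitted in this sketch.
    for name, value in expanded.items():
        if isinstance(value, ToyNonPersistent):  # step 3: added without reading storage
            result[name] = value.obj
        elif isinstance(value, ToyTensor):  # step 5: load plain sharded tensors
            result[name] = stored[value.key]
        elif name in merges:  # step 5: load factory sub-tensors, then merge them
            loaded = {sub: stored[t.key] for sub, t in value.items()}
            result[name] = merges[name](loaded)
    return result


stored = {"w": [1, 2], "qkv.q": [3], "qkv.k": [4]}
sd = {
    "w": ToyTensor("w"),
    "rng": ToyNonPersistent("local-rng"),
    "qkv": ToyFactory(
        "qkv",
        build_fn=lambda k: {"q": ToyTensor(k + ".q"), "k": ToyTensor(k + ".k")},
        merge_fn=lambda parts: parts["q"] + parts["k"],
    ),
}
print(toy_load_steps(sd, {"iteration": 10}, stored))
# {'iteration': 10, 'w': [1, 2], 'rng': 'local-rng', 'qkv': [3, 4]}
```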
- Parameters:
sharded_state_dict (ShardedStateDict) – state dict of the existing model populated with ShardedTensors. Used as a mapping to determine which parts of global tensors stored in the checkpoint should be loaded.
checkpoint_dir (str) – directory with the checkpoint
sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional) – configures loading behavior for sharded tensors
common_strategy (LoadCommonStrategy, Tuple[str, int], optional) – configures loading behavior for common data
validate_access_integrity (bool, default = True) – checks if each tensor shard is accessed exactly once (as main replica) by some process
strict (StrictHandling, str, optional) – determines the behavior in case of a mismatch between the requested sharded state dict and the checkpoint. See the StrictHandling docs for more details. Some values affect the return value of this function (missing and unexpected keys are returned). Defaults to True (StrictHandling.ASSUME_OK_UNEXPECTED), which doesn’t incur any performance overhead. Other recommended values are: False (StrictHandling.LOG_UNEXPECTED), which logs only unexpected keys, or StrictHandling.RETURN_ALL, which returns all mismatch keys.
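- Returns:
in most cases only the loaded state dict is returned. If the strict flag is set to a value that returns mismatch keys (such as StrictHandling.RETURN_ALL), a tuple (state_dict, missing_keys, unexpected_keys) is returned instead.
- Return type:
StateDict or Tuple[StateDict, Set[str], Set[str]]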
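The validate_access_integrity check can be pictured as counting, for every shard, how many processes claim to access it as the main replica. This sketch assumes the per-process access records have already been gathered into one list and uses a hypothetical (rank, shard_id, replica_id) triple, with replica_id == 0 meaning "main replica"; the real check works on ShardedTensor metadata exchanged between processes.

```python
from collections import Counter
from typing import Iterable, Tuple

# Hypothetical record: (rank, shard_id, replica_id), where shard_id is
# (tensor key, global offset) and replica_id == 0 marks the main replica.
Access = Tuple[int, Tuple[str, int], int]


def check_access_integrity(accesses: Iterable[Access]) -> None:
    """Raise unless every shard is accessed exactly once as the main replica."""
    accesses = list(accesses)
    main_counts = Counter(shard for _, shard, replica in accesses if replica == 0)
    all_shards = {shard for _, shard, _ in accesses}
    bad = {shard: main_counts[shard] for shard in all_shards if main_counts[shard] != 1}
    if bad:
        raise RuntimeError(f"shards without exactly one main replica: {bad}")


# Two data-parallel copies of a tensor split into two shards; ranks 0 and 1 are main.
check_access_integrity([
    (0, ("w", 0), 0), (1, ("w", 2), 0),
    (2, ("w", 0), 1), (3, ("w", 2), 1),
])
```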
- core.dist_checkpointing.serialization.load_common_state_dict(
    checkpoint_dir: Union[str, pathlib.Path],
  )
Load common (non-sharded) objects state dict from the checkpoint.
- Parameters:
checkpoint_dir (str or pathlib.Path) – checkpoint directory
- Returns:
state dict with non-sharded objects from the checkpoint
- Return type:
StateDict
- core.dist_checkpointing.serialization.load_tensors_metadata(
    checkpoint_dir: str,
    sharded_strategy: Union[core.dist_checkpointing.strategies.base.LoadShardedStrategy, None] = None,
  )
Load tensors metadata from the checkpoint.
Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (unlike actual sharded state dicts, where keys correspond to state dict keys).
Dict values are ShardedTensors without any sharding (so the only useful information is the tensors' global shape and dtype).
Concrete implementation depends on the loading strategy. If no strategy is given, a default for the given backend is used.
- Parameters:
checkpoint_dir (str) – checkpoint directory to load from
sharded_strategy (LoadShardedStrategy, optional) – sharded strategy to load metadata. Defaults to None, in which case a default load strategy for the given checkpoint type is used.
- Returns:
flat state dict without data describing ShardedTensors in the checkpoint
- Return type:
CkptShardedMetadata
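Because the metadata is keyed by ShardedTensor keys rather than state dict keys, comparing it against a model's sharded state dict means going through each tensor's key. The sketch below does that to report global shape mismatches before loading. ToyTensorInfo is a hypothetical stand-in carrying only a key, a global shape and a dtype, mirroring the "only useful information" noted above; it is not the real ShardedTensor.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


@dataclass
class ToyTensorInfo:
    """Hypothetical stand-in: a tensor's checkpoint key, global shape and dtype."""
    key: str
    global_shape: Tuple[int, ...]
    dtype: str


def find_shape_mismatches(
    sharded_state_dict: Dict[str, Any],
    metadata: Dict[str, ToyTensorInfo],
    prefix: str = "",
) -> Dict[str, Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]]:
    """Map state-dict paths to (expected shape, stored shape or None) on mismatch."""
    mismatches = {}
    for name, value in sharded_state_dict.items():
        path = prefix + name
        if isinstance(value, dict):  # nested state dicts are walked recursively
            mismatches.update(find_shape_mismatches(value, metadata, path + "."))
        elif isinstance(value, ToyTensorInfo):
            stored = metadata.get(value.key)  # metadata is keyed by tensor key
            if stored is None or stored.global_shape != value.global_shape:
                mismatches[path] = (value.global_shape,
                                    None if stored is None else stored.global_shape)
    return mismatches


metadata = {"decoder.layers.0.weight": ToyTensorInfo("decoder.layers.0.weight", (4096, 512), "bfloat16")}
request = {"model": {"layer0.weight": ToyTensorInfo("decoder.layers.0.weight", (4096, 1024), "bfloat16")}}
print(find_shape_mismatches(request, metadata))
# {'model.layer0.weight': ((4096, 1024), (4096, 512))}
```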
- core.dist_checkpointing.serialization.load_sharded_metadata(
    checkpoint_dir: str,
    sharded_strategy: Union[core.dist_checkpointing.strategies.base.LoadShardedStrategy, None] = None,
    common_strategy: Union[core.dist_checkpointing.strategies.base.LoadCommonStrategy, None] = None,
  )
Load sharded metadata from the checkpoint.
Similar to load_tensors_metadata, but also includes ShardedObjects. Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (unlike actual sharded state dicts, where keys correspond to state dict keys).
Dict values are ShardedTensors without any sharding (so the only useful information is the tensors' global shape and dtype).
Concrete implementation depends on the loading strategy. If no strategy is given, a default for the given backend is used.
- Parameters:
checkpoint_dir (str) – checkpoint directory to load from
sharded_strategy (LoadShardedStrategy, optional) – sharded strategy to load metadata. Defaults to None, in which case a default load strategy for the given checkpoint type is used.
common_strategy (LoadCommonStrategy, optional) – common strategy to load metadata. Defaults to None, in which case a default load strategy for the given checkpoint type is used. This strategy won’t be used unless sharded_strategy can’t handle ShardedObjects.
- Returns:
flat state dict without data describing ShardedTensors and ShardedObjects in the checkpoint
- Return type:
CkptShardedMetadata
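The rule that common_strategy is only used when sharded_strategy can't handle ShardedObjects amounts to a simple fallback. In this sketch, the two Protocol classes, their method names and the can_handle_sharded_objects attribute are assumptions made for illustration; consult the strategies module for the real interfaces.

```python
from typing import Any, Dict, Protocol


class ToyShardedStrategy(Protocol):
    """Hypothetical interface, not the real LoadShardedStrategy."""
    can_handle_sharded_objects: bool

    def load_tensors_metadata(self, checkpoint_dir: str) -> Dict[str, Any]: ...

    def load_sharded_objects_metadata(self, checkpoint_dir: str) -> Dict[str, Any]: ...


class ToyCommonStrategy(Protocol):
    """Hypothetical interface, not the real LoadCommonStrategy."""

    def load_sharded_objects_metadata(self, checkpoint_dir: str) -> Dict[str, Any]: ...


def toy_load_sharded_metadata(checkpoint_dir: str, sharded: ToyShardedStrategy,
                              common: ToyCommonStrategy) -> Dict[str, Any]:
    metadata = dict(sharded.load_tensors_metadata(checkpoint_dir))
    # The common strategy is consulted only if the sharded one can't handle ShardedObjects.
    source = sharded if sharded.can_handle_sharded_objects else common
    metadata.update(source.load_sharded_objects_metadata(checkpoint_dir))
    return metadata


class FakeSharded:
    def __init__(self, handles_objects: bool):
        self.can_handle_sharded_objects = handles_objects

    def load_tensors_metadata(self, checkpoint_dir):
        return {"w": "tensor metadata"}

    def load_sharded_objects_metadata(self, checkpoint_dir):
        return {"opt": "objects via sharded strategy"}


class FakeCommon:
    def load_sharded_objects_metadata(self, checkpoint_dir):
        return {"opt": "objects via common strategy"}


print(toy_load_sharded_metadata("ckpt", FakeSharded(False), FakeCommon())["opt"])
# objects via common strategy
```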
- core.dist_checkpointing.serialization.load_plain_tensors(
    checkpoint_dir: str,
  )
Load checkpoint tensors without any sharding and plain structure.
NOTE: common state dict is NOT included.
- Parameters:
checkpoint_dir (str) – checkpoint directory to load the tensors from.
- Returns:
checkpoint state dict containing only torch.Tensors.
- Return type:
StateDict
- core.dist_checkpointing.serialization.load_content_metadata(
    checkpoint_dir: Optional[str] = None,
    *,
    preloaded_state_dict: Optional[core.dist_checkpointing.mapping.StateDict] = None,
  )
Load content metadata stored in the checkpoint with save(..., content_metadata=...).
- Parameters:
checkpoint_dir (str, optional) – checkpoint directory to load the content metadata from.
preloaded_state_dict (StateDict, optional) – if the state dict was already loaded, it can be provided to avoid loading it from storage twice
- Returns:
checkpoint content metadata, or None if there is no content metadata in the checkpoint
- Return type:
dict or None
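A minimal sketch of the preloaded_state_dict shortcut, assuming (this page does not say so) that the content metadata lives in the common state dict under the _CONTENT_METADATA_KEY value 'content_metadata'. The load_common callable is a hypothetical parameter standing in for reading the common state dict from storage.

```python
from typing import Any, Callable, Dict, Optional

CONTENT_METADATA_KEY = "content_metadata"  # same value as _CONTENT_METADATA_KEY above


def toy_load_content_metadata(
    checkpoint_dir: Optional[str] = None,
    *,
    preloaded_state_dict: Optional[Dict[str, Any]] = None,
    load_common: Optional[Callable[[str], Dict[str, Any]]] = None,
) -> Optional[dict]:
    """Return the content metadata, reusing an already-loaded common state dict if given."""
    if preloaded_state_dict is None:
        if checkpoint_dir is None or load_common is None:
            raise ValueError("need checkpoint_dir and load_common, or preloaded_state_dict")
        preloaded_state_dict = load_common(checkpoint_dir)  # the only storage read
    return preloaded_state_dict.get(CONTENT_METADATA_KEY)  # None if absent


reads = []


def fake_load_common(checkpoint_dir: str) -> Dict[str, Any]:
    reads.append(checkpoint_dir)
    return {"iteration": 7, CONTENT_METADATA_KEY: {"format_version": 2}}


common = fake_load_common("ckpt")  # loaded once, e.g. for other purposes
print(toy_load_content_metadata(preloaded_state_dict=common))  # {'format_version': 2}
print(reads)  # ['ckpt'] (no second read)
```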
- core.dist_checkpointing.serialization.remove_sharded_tensors(checkpoint_dir: str, key_prefix: str)
Determine the appropriate sharding strategy for the checkpoint in checkpoint_dir and delegate the removal of the sharded tensors selected by key_prefix to that sharded strategy.
- core.dist_checkpointing.serialization.save(
    sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
    checkpoint_dir: str,
    sharded_strategy: Union[core.dist_checkpointing.strategies.base.SaveShardedStrategy, Tuple[str, int], None] = None,
    common_strategy: Union[core.dist_checkpointing.strategies.base.SaveCommonStrategy, Tuple[str, int], None] = None,
    validate_access_integrity: bool = True,
    async_sharded_save: bool = False,
    preprocess_common_before_consistancy_check: Optional[Callable[[core.dist_checkpointing.mapping.CommonStateDict], core.dist_checkpointing.mapping.StateDict]] = None,
    content_metadata: Optional[dict] = None,
  )
Saving entrypoint.
Extracts ShardedTensors from the given state dict. Rank 0 saves the “regular” part of the checkpoint to a common torch file. The ShardedTensors are saved according to a strategy specified by the config.
Steps:
1. Apply factories
2. Extract and discard LocalNonPersistentObject
3. Extract all ShardedBase objects
4. Save all other objects to common.pt
5. (optional) Extract and save ShardedObjects
6. Save all ShardedBase objects
7. Write metadata.json file with backend and version metadata.
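Steps 2 to 4 amount to partitioning the state dict. The sketch below uses hypothetical ToyTensor and ToyNonPersistent classes in place of ShardedBase and LocalNonPersistentObject, handles a flat dict only, and returns the two parts instead of writing files.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple


@dataclass
class ToyTensor:
    """Hypothetical stand-in for a ShardedBase object such as a ShardedTensor."""
    key: str
    data: Any


@dataclass
class ToyNonPersistent:
    """Hypothetical stand-in for LocalNonPersistentObject."""
    obj: Any


def toy_split_for_save(
    state_dict: Dict[str, Any],
) -> Tuple[Dict[str, Any], Dict[str, ToyTensor]]:
    """Return (common part, sharded part keyed by tensor key)."""
    common, sharded = {}, {}
    for name, value in state_dict.items():
        if isinstance(value, ToyNonPersistent):
            continue  # step 2: extracted and discarded, never written
        if isinstance(value, ToyTensor):
            sharded[value.key] = value  # step 3: handed to the sharded strategy
        else:
            common[name] = value  # step 4: everything else goes to common.pt
    return common, sharded


common, sharded = toy_split_for_save({
    "iteration": 100,
    "rng": ToyNonPersistent("process-local"),
    "w": ToyTensor("decoder.weight", [1.0, 2.0]),
})
print(common)         # {'iteration': 100}
print(list(sharded))  # ['decoder.weight']
```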
Step (6) can be performed asynchronously (see async_sharded_save). In this case, the actual save is embodied in the returned async request and can be scheduled by the external caller. For an async request, step (7) is added as one of the finalization functions, so that metadata.json is written only if the checkpoint is complete.
- Parameters:
sharded_state_dict (ShardedStateDict) – state dict of the model populated with ShardedTensors. Used as a mapping to determine how local tensors should be saved as global tensors in the checkpoint.
checkpoint_dir (str) – directory to save the checkpoint to
sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional) – configures sharded tensors saving behavior and backend
common_strategy (SaveCommonStrategy, Tuple[str, int], optional) – configures common data saving behavior and backend
validate_access_integrity (bool, default = True) – checks if each tensor shard is accessed exactly once (as main replica) by some process. It also makes sure the common state dict is consistent across all ranks.
async_sharded_save (bool, optional) – if True, an async save implementation is called for the sharded state dict part, and the resulting AsyncRequest is returned to the caller. Note that it is the caller's responsibility to actually schedule the async save. Defaults to False.
preprocess_common_before_consistancy_check (Callable[[CommonStateDict], StateDict], None) – a callable that preprocesses the common state dict before the consistency check (e.g. it can be used to remove keys that are expected to differ across ranks). The function must not modify the original state dict.
content_metadata (dict, optional) – metadata to identify the checkpoint content. Useful for framework specific versioning.
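The interplay between the common state dict consistency check and preprocess_common_before_consistancy_check can be sketched as follows. Here a list of per-rank dicts simulates the common state dicts from all ranks (the real check communicates between processes), and the rng_state key is a hypothetical example of a value expected to differ across ranks. The preprocessor works on a deep copy, respecting the requirement that it must not modify the original state dict.

```python
import copy
from typing import Any, Callable, Dict, List, Optional

StateDict = Dict[str, Any]


def toy_inconsistent_ranks(
    per_rank_common: List[StateDict],
    preprocess: Optional[Callable[[StateDict], StateDict]] = None,
) -> List[int]:
    """Return the ranks whose (preprocessed) common state dict differs from rank 0's."""
    prepared = [preprocess(sd) if preprocess else sd for sd in per_rank_common]
    return [rank for rank, sd in enumerate(prepared) if sd != prepared[0]]


def drop_rng_state(common: StateDict) -> StateDict:
    """Example preprocessor: edits a deep copy, so the caller's dict is unchanged."""
    cleaned = copy.deepcopy(common)
    cleaned.pop("rng_state", None)
    return cleaned


ranks = [{"iteration": 5, "rng_state": 11}, {"iteration": 5, "rng_state": 42}]
print(toy_inconsistent_ranks(ranks))                  # [1]
print(toy_inconsistent_ranks(ranks, drop_rng_state))  # []
print("rng_state" in ranks[1])                        # True
```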
- Returns:
if async_sharded_save is True, an async request that should be scheduled by the caller of this function; None otherwise.
- Return type:
AsyncRequest (optional)
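The async contract described above (step 6 deferred into a request, step 7 run as a finalization function) can be sketched with a thread pool. ToyAsyncRequest, toy_save and schedule are hypothetical names, not the library's AsyncRequest API, and JSON files stand in for the real checkpoint format.

```python
import json
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional


@dataclass
class ToyAsyncRequest:
    """Hypothetical stand-in for AsyncRequest: deferred work plus finalizers."""
    async_fn: Callable[[], None]
    finalize_fns: List[Callable[[], None]] = field(default_factory=list)


def toy_save(tensors: Dict[str, Any], checkpoint_dir: str,
             async_sharded_save: bool = False) -> Optional[ToyAsyncRequest]:
    def save_tensors() -> None:  # step 6
        with open(os.path.join(checkpoint_dir, "tensors.json"), "w") as f:
            json.dump(tensors, f)

    def write_metadata() -> None:  # step 7
        with open(os.path.join(checkpoint_dir, "metadata.json"), "w") as f:
            json.dump({"backend": "toy", "version": 1}, f)

    if async_sharded_save:
        return ToyAsyncRequest(save_tensors, [write_metadata])
    save_tensors()
    write_metadata()
    return None


def schedule(request: ToyAsyncRequest, executor: ThreadPoolExecutor) -> None:
    """Caller-side scheduling: finalizers run only if the async part succeeded."""
    executor.submit(request.async_fn).result()  # re-raises if the save failed
    for fn in request.finalize_fns:
        fn()


with tempfile.TemporaryDirectory() as ckpt, ThreadPoolExecutor(max_workers=1) as pool:
    request = toy_save({"w": [1, 2]}, ckpt, async_sharded_save=True)
    meta_path = os.path.join(ckpt, "metadata.json")
    print(os.path.exists(meta_path))  # False: nothing written yet
    schedule(request, pool)
    print(os.path.exists(meta_path))  # True: finalizer ran after the save
```

Blocking on result() keeps the sketch short; a caller that wants to overlap saving with training would instead check for completion later and only then run the finalizers.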
- core.dist_checkpointing.serialization.get_default_save_sharded_strategy(
    backend: str = 'torch_dist',
    version: int = 1,
  )
Get default save sharded strategy.
- core.dist_checkpointing.serialization.get_default_save_common_strategy(
    backend: str = 'torch',
    version: int = 1,
  )
Get default save common strategy.
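The sharded_strategy and common_strategy arguments of save and load also accept a Tuple[str, int]. Judging from the (backend, version) parameters of the default-strategy getters, this sketch assumes the tuple is (backend, version) and resolves all three accepted forms. ToyStrategy and resolve_strategy are hypothetical, and "custom_backend" is a made-up backend name.

```python
from typing import Tuple, Union

DEFAULT_SAVE_SHARDED = ("torch_dist", 1)  # defaults of get_default_save_sharded_strategy
DEFAULT_SAVE_COMMON = ("torch", 1)        # defaults of get_default_save_common_strategy


class ToyStrategy:
    """Hypothetical strategy object identified by backend name and version."""

    def __init__(self, backend: str, version: int):
        self.backend, self.version = backend, version

    def __repr__(self) -> str:
        return f"ToyStrategy({self.backend!r}, {self.version})"


def resolve_strategy(arg: Union[ToyStrategy, Tuple[str, int], None],
                     default: Tuple[str, int]) -> ToyStrategy:
    """Accept a strategy object, a (backend, version) tuple, or None for the default."""
    if arg is None:
        return ToyStrategy(*default)
    if isinstance(arg, tuple):
        backend, version = arg
        return ToyStrategy(backend, version)
    return arg  # already a strategy object: used as-is


print(resolve_strategy(None, DEFAULT_SAVE_SHARDED))                  # ToyStrategy('torch_dist', 1)
print(resolve_strategy(("custom_backend", 2), DEFAULT_SAVE_SHARDED))  # ToyStrategy('custom_backend', 2)
```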
- core.dist_checkpointing.serialization.get_default_load_sharded_strategy(
    checkpoint_dir: str,
  )
Get default load sharded strategy.
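Since get_default_load_sharded_strategy takes only a checkpoint directory, the strategy presumably has to be inferred from what the checkpoint records, such as the backend and version metadata that save writes to metadata.json (step 7). The sketch below assumes the field names sharded_backend and sharded_backend_version and uses a hypothetical registry; it is not the library's lookup logic.

```python
import json
import os
import tempfile
from typing import Dict, Tuple

# Hypothetical registry from (backend, version) to a strategy name.
TOY_LOAD_STRATEGIES: Dict[Tuple[str, int], str] = {
    ("torch_dist", 1): "toy torch_dist loader v1",
}


def toy_default_load_sharded_strategy(checkpoint_dir: str) -> str:
    """Choose a load strategy based on the backend recorded in metadata.json."""
    with open(os.path.join(checkpoint_dir, "metadata.json")) as f:
        meta = json.load(f)
    # Assumed field names for the sharded backend and its version.
    key = (meta["sharded_backend"], meta["sharded_backend_version"])
    if key not in TOY_LOAD_STRATEGIES:
        raise ValueError(f"no load strategy registered for {key}")
    return TOY_LOAD_STRATEGIES[key]


with tempfile.TemporaryDirectory() as ckpt:
    with open(os.path.join(ckpt, "metadata.json"), "w") as f:
        json.dump({"sharded_backend": "torch_dist", "sharded_backend_version": 1}, f)
    print(toy_default_load_sharded_strategy(ckpt))  # toy torch_dist loader v1
```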