core.dist_checkpointing.state_dict_utils
Utilities for transforming state_dict.
Module Contents
Functions
save_preprocess – Preprocesses the given state dictionary by applying factories, discarding non-persistent data, and extracting the common state dictionary. Optionally, it can validate sharding integrity.
load_preprocess – Preprocesses the given state dictionary by applying factories and extracting non-persistent data, without modifying the original dictionary.
filter_out_empty_flatten_tensor – Filters out ShardedTensors with an empty flatten_range. These tensors can cause a PyTorch check to fail.
API
- core.dist_checkpointing.state_dict_utils.save_preprocess(
      sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
      validate_access_integrity: bool = True,
      preprocess_common_before_consistancy_check: Callable[[core.dist_checkpointing.mapping.CommonStateDict], core.dist_checkpointing.mapping.StateDict] = None,
  )
Preprocesses the given state dictionary by applying factories, discarding non-persistent data and extracting the common state dictionary. Optionally, it can validate sharding integrity.
- Parameters:
sharded_state_dict (ShardedStateDict) – The initial state dictionary to be preprocessed.
validate_access_integrity (bool) – If True, triggers validation of sharding integrity.
preprocess_common_before_consistancy_check (Callable or None) – A callable that preprocesses the common state dict before the consistency check, e.g., to remove keys whose values are expected to differ.
- Returns:
The preprocessed sharded state dictionary and the common state dictionary.
- Return type:
Tuple[ShardedStateDict, dict]
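The preprocessing steps described above can be illustrated with a self-contained toy model. This is a minimal sketch, not the Megatron-Core implementation: `ShardedTensor`, `ShardedTensorFactory`, `NonPersistent`, `toy_save_preprocess`, `commons_consistent`, and `drop_volatile_keys` are all hypothetical stand-ins, and the consistency check is assumed to compare the common state dicts of all ranks for equality (two ranks are simulated here as plain dicts).

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple


@dataclass
class ShardedTensor:
    """Hypothetical stand-in for mapping.ShardedTensor."""
    key: str
    data: Any


@dataclass
class ShardedTensorFactory:
    """Hypothetical stand-in: build() turns the factory into a ShardedTensor."""
    key: str
    data: Any

    def build(self) -> ShardedTensor:
        return ShardedTensor(self.key, self.data)


@dataclass
class NonPersistent:
    """Hypothetical marker for data that must not be saved."""
    value: Any


def toy_save_preprocess(
    state: Dict[str, Any],
) -> Tuple[Dict[str, ShardedTensor], Dict[str, Any]]:
    """Apply factories, drop non-persistent data, split sharded vs. common parts."""
    sharded: Dict[str, ShardedTensor] = {}
    common: Dict[str, Any] = {}
    for name, value in state.items():
        if isinstance(value, ShardedTensorFactory):
            value = value.build()          # 1. apply factories
        if isinstance(value, NonPersistent):
            continue                       # 2. discard non-persistent data
        if isinstance(value, ShardedTensor):
            sharded[name] = value          # 3a. sharded part
        else:
            common[name] = value           # 3b. common part
    return sharded, common


def commons_consistent(
    commons_per_rank: List[Dict[str, Any]],
    preprocess: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
) -> bool:
    """Assumed check: every rank's (optionally preprocessed) common dict must match."""
    views = [preprocess(c) if preprocess else c for c in commons_per_rank]
    return all(view == views[0] for view in views)


def drop_volatile_keys(common: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical preprocess callable: ignore a key expected to differ per rank."""
    return {k: v for k, v in common.items() if k != "rng_state"}


rank0 = {"w": ShardedTensorFactory("w", [1, 2]), "cache": NonPersistent(0),
         "iteration": 10, "rng_state": 111}
rank1 = {"w": ShardedTensorFactory("w", [3, 4]), "cache": NonPersistent(1),
         "iteration": 10, "rng_state": 222}

(sh0, c0), (sh1, c1) = toy_save_preprocess(rank0), toy_save_preprocess(rank1)
print(sorted(sh0), c0)                                   # ['w'] {'iteration': 10, 'rng_state': 111}
print(commons_consistent([c0, c1]))                      # False
print(commons_consistent([c0, c1], drop_volatile_keys))  # True
```

In this sketch the callable only changes what is compared: `drop_volatile_keys` returns a new dict, so the common state dict returned by `toy_save_preprocess` still contains `rng_state`.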
- core.dist_checkpointing.state_dict_utils.load_preprocess(
      sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
  )
Preprocesses the given state dictionary by applying factories and extracting non-persistent data, without modifying the original dictionary.
- Parameters:
sharded_state_dict (ShardedStateDict) – The initial state dictionary to be processed (remains unchanged).
- Returns:
A preprocessed copy of the sharded state dictionary.
A dictionary containing non-persistent state data.
A dictionary of ShardedTensorFactory instances.
- Return type:
Tuple[ShardedStateDict, dict, dict]
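The three return values can be illustrated with a toy model that works on a shallow copy, so the caller's dictionary keeps its original entries. This is a minimal sketch under stated assumptions, not the library code: `ShardedTensor`, `ShardedTensorFactory`, `NonPersistent`, and `toy_load_preprocess` are hypothetical stand-ins, and unwrapping the non-persistent values is a choice made for this sketch only.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple


@dataclass
class ShardedTensor:
    """Hypothetical stand-in for mapping.ShardedTensor."""
    key: str


@dataclass
class ShardedTensorFactory:
    """Hypothetical stand-in: build() produces the ShardedTensor to load into."""
    key: str

    def build(self) -> ShardedTensor:
        return ShardedTensor(self.key)


@dataclass
class NonPersistent:
    """Hypothetical wrapper for values that are never read from the checkpoint."""
    value: Any


def toy_load_preprocess(
    state: Dict[str, Any],
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, ShardedTensorFactory]]:
    """Return (processed copy, non-persistent data, factories); `state` is untouched."""
    processed = dict(state)  # shallow copy so the caller's dict is not mutated
    # Record the factories before applying them, so they can be returned separately.
    factories = {k: v for k, v in processed.items()
                 if isinstance(v, ShardedTensorFactory)}
    for key, factory in factories.items():
        processed[key] = factory.build()  # apply factories on the copy only
    # Extract (and, in this sketch, unwrap) non-persistent entries.
    nonpersistent = {k: v.value for k, v in processed.items()
                     if isinstance(v, NonPersistent)}
    for key in nonpersistent:
        del processed[key]
    return processed, nonpersistent, factories


original = {"w": ShardedTensorFactory("w"), "step": NonPersistent(7)}
processed, nonpersistent, factories = toy_load_preprocess(original)
print(processed)                      # {'w': ShardedTensor(key='w')}
print(nonpersistent)                  # {'step': 7}
print(type(original["w"]).__name__)   # ShardedTensorFactory
```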
- core.dist_checkpointing.state_dict_utils.filter_out_empty_flatten_tensor(
      sharded_state_dict: Union[dict, list],
  )
Filters out ShardedTensors with an empty flatten_range. These tensors can cause a PyTorch check to fail.
- Parameters:
sharded_state_dict (Union[dict, list]) – A state dict (or list) possibly containing ShardedTensor objects.
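The filtering rule can be sketched as a recursive walk over nested dicts and lists. This is a minimal sketch, assuming that an "empty flatten_range" means a slice whose start equals its stop and that the filtered container is returned; `ShardedTensor`, `is_empty_flat`, and `toy_filter_out_empty_flatten_tensor` are hypothetical stand-ins, not the Megatron-Core objects.

```python
from dataclasses import dataclass
from typing import Any, Optional, Union


@dataclass
class ShardedTensor:
    """Hypothetical stand-in; flattened_range selects a slice of the flattened shard."""
    key: str
    flattened_range: Optional[slice] = None


def is_empty_flat(value: Any) -> bool:
    """True for a ShardedTensor whose flattened_range covers zero elements."""
    return (
        isinstance(value, ShardedTensor)
        and value.flattened_range is not None
        and value.flattened_range.start == value.flattened_range.stop
    )


def toy_filter_out_empty_flatten_tensor(sd: Union[dict, list]) -> Union[dict, list]:
    """Recursively rebuild dicts/lists, dropping empty flattened ShardedTensors."""
    if isinstance(sd, dict):
        return {k: toy_filter_out_empty_flatten_tensor(v)
                for k, v in sd.items() if not is_empty_flat(v)}
    if isinstance(sd, list):
        return [toy_filter_out_empty_flatten_tensor(v)
                for v in sd if not is_empty_flat(v)]
    return sd  # leaves (tensors, ints, ...) are kept as-is


state = {
    "a": ShardedTensor("a", slice(0, 4)),
    "b": ShardedTensor("b", slice(4, 4)),   # empty range, so it is removed
    "c": [ShardedTensor("c0"), ShardedTensor("c1", slice(2, 2))],
}
filtered = toy_filter_out_empty_flatten_tensor(state)
print(sorted(filtered))      # ['a', 'c']
print(len(filtered["c"]))    # 1
```

The sketch builds new containers instead of deleting entries in place, so the input `state` still holds `"b"` afterwards; tensors without a flattened_range (such as `c0`) are never treated as empty.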