core.dist_checkpointing.state_dict_utils#

Utilities for transforming state_dict.

Module Contents#

Functions#

save_preprocess

Preprocesses the given state dictionary by applying factories, discarding non-persistent data and extracting the common state dictionary. Optionally, it can validate sharding integrity.

load_preprocess

Preprocesses the given state dictionary by applying factories and extracting non-persistent data, without modifying the original dictionary.

filter_out_empty_flatten_tensor

Filter out ShardedTensors with an empty flatten_range. Such tensors can cause the PyTorch check to fail.

API#

core.dist_checkpointing.state_dict_utils.save_preprocess(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
validate_access_integrity: bool = True,
preprocess_common_before_consistancy_check: Callable[[core.dist_checkpointing.mapping.CommonStateDict], core.dist_checkpointing.mapping.StateDict] = None,
)#

Preprocesses the given state dictionary by applying factories, discarding non-persistent data and extracting the common state dictionary. Optionally, it can validate sharding integrity.

Parameters:
  • sharded_state_dict (ShardedStateDict) – The initial state dictionary to be preprocessed.

  • validate_access_integrity (bool) – If True, triggers validation of sharding integrity.

  • preprocess_common_before_consistancy_check (Callable, optional) – A callable that preprocesses the common state dict before the consistency check (e.g., it can be used to remove keys that are expected to differ in the state dict).

Returns:

The preprocessed sharded state dictionary and the common state dictionary.

Return type:

Tuple[ShardedStateDict, dict]
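
A minimal usage sketch (not part of the generated reference): it assumes Megatron-Core is importable under ``megatron.core``, that ``torch.distributed`` is already initialized (sharding validation is a collective operation), and that the key ``'my_param'`` is a made-up example.

```python
# Minimal sketch, assuming megatron.core is importable and torch.distributed
# is initialized (e.g. a single-rank gloo process group).
# 'my_param' is a hypothetical key used only for illustration.
import torch
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.state_dict_utils import save_preprocess

sharded_state_dict = {
    # A trivially sharded tensor that this rank owns in full.
    'my_param': ShardedTensor.from_rank_offsets('my_param', torch.zeros(4), replica_id=0),
}

# Apply factories, drop non-persistent entries, extract the common state dict,
# and (optionally) validate sharding integrity.
sharded_part, common_state_dict = save_preprocess(
    sharded_state_dict, validate_access_integrity=True
)
```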

core.dist_checkpointing.state_dict_utils.load_preprocess(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
)#

Preprocesses the given state dictionary by applying factories and extracting non-persistent data, without modifying the original dictionary.

Parameters:

sharded_state_dict (ShardedStateDict) – The initial state dictionary to be processed (remains unchanged).

Returns:

  • A preprocessed copy of the sharded state dictionary.

  • A dictionary containing non-persistent state data.

  • A dictionary of ShardedTensorFactory instances.

Return type:

Tuple[ShardedStateDict, dict, dict]
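
A minimal usage sketch, reusing the hypothetical ``sharded_state_dict`` from the ``save_preprocess`` example above; the import path assumes the ``megatron.core`` package layout.

```python
# Minimal sketch; the input dict is left unchanged by the call.
from megatron.core.dist_checkpointing.state_dict_utils import load_preprocess

# Three dicts come back: a preprocessed copy of the sharded state dict,
# the extracted non-persistent state data, and the ShardedTensorFactory
# instances that were applied.
preprocessed, nonpersistent, sh_ten_factories = load_preprocess(sharded_state_dict)
```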

core.dist_checkpointing.state_dict_utils.filter_out_empty_flatten_tensor(
sharded_state_dict: Union[dict, list],
)#

Filter out ShardedTensors with an empty flatten_range. Such tensors can cause the PyTorch check to fail.

Parameters:

sharded_state_dict – state dict possibly containing ShardedTensor objects
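
A usage sketch only: the reference above does not document a return value, so this example assumes the filtering happens in place on the passed container; that is an assumption, not a documented guarantee.

```python
# Sketch only: drop ShardedTensors whose flatten_range is empty before the
# state dict reaches the PyTorch check. Whether the function filters in place
# or returns a new container is not documented above; this call assumes
# in-place filtering.
from megatron.core.dist_checkpointing.state_dict_utils import (
    filter_out_empty_flatten_tensor,
)

filter_out_empty_flatten_tensor(sharded_state_dict)
```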