core.dist_checkpointing.tensor_aware_state_dict#
Utilities for transforming state_dict, including a tensor-aware implementation.
Module Contents#
Classes#
MCoreTensorAwareStateDict – MCore-specific class defining the interface between the MCore state dict and checkpoint manager.
Data#
API#
- core.dist_checkpointing.tensor_aware_state_dict.logger#
‘getLogger(…)’
- class core.dist_checkpointing.tensor_aware_state_dict.MCoreTensorAwareStateDict#
Bases: nvidia_resiliency_ext.checkpointing.local.base_state_dict.TensorAwareStateDict

MCore-specific class defining the interface between the MCore state dict and checkpoint manager.
This class distinguishes between raw objects, the common state dict, and sharded state dicts (tensor parts). It also handles optional metadata needed for fully parallel save/load.
- common: core.dist_checkpointing.mapping.StateDict#
None
- sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict#
None
- _is_hollow: bool#
False
- static _validate_params(algo)#
- static _get_distribution(fully_parallel, sharded_part, parallelization_group, cached_distribution=None)#
- static _remove_redundant_data(fully_parallel, sharded_part, shard_to_saving_rank, parallelization_group)#
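The deduplication step behind `_remove_redundant_data` can be illustrated with a toy sketch (not the actual MCore implementation): in a fully parallel save, each shard is assigned to exactly one saving rank, and every other rank drops its replica so the shard is written to the checkpoint only once. The function and variable names below are hypothetical.

```python
# Hypothetical sketch of fully-parallel shard deduplication: keep only the
# shards this rank is responsible for saving; drop replicated copies.
def remove_redundant_shards(local_shards, shard_to_saving_rank, my_rank):
    """local_shards: dict of shard key -> locally held payload.
    shard_to_saving_rank: dict of shard key -> rank assigned to save it."""
    return {
        key: payload
        for key, payload in local_shards.items()
        if shard_to_saving_rank.get(key) == my_rank
    }

# Two ranks both hold replicated shards "a" and "b"; the distribution assigns
# "a" to rank 0 and "b" to rank 1, so each rank saves exactly one shard.
distribution = {"a": 0, "b": 1}
rank0 = remove_redundant_shards({"a": [1.0], "b": [2.0]}, distribution, 0)
rank1 = remove_redundant_shards({"a": [1.0], "b": [2.0]}, distribution, 1)
```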
- classmethod from_state_dict(sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict, algo: str = 'fully_parallel', parallelization_group: Optional[torch.distributed.ProcessGroup] = None, cached_metadata: core.dist_checkpointing.exchange_utils.ShardDistribution = None)#
Constructs a TensorAwareStateDict from a sharded state dictionary.
This method preprocesses the input sharded_state_dict, validates parameters, and extracts the necessary data to create an instance of MCoreTensorAwareStateDict.
- Parameters:
sharded_state_dict – The input sharded state dictionary to be converted.
algo (str, optional) – Initialization algorithm; defaults to 'fully_parallel', which enables fully parallel initialization.
parallelization_group (Optional) – A distributed process group for parallelization.
cached_metadata (Optional) –
Precomputed metadata from previous saves.
Reuses data that doesn’t need recalculation, optimizing the creation process.
- Returns:
An instance initialized with the provided sharded state dictionary and optional cached metadata.
The metadata is stored in memory to speed up future saves.
- Return type:
TensorAwareStateDict
- property is_hollow#
True iff tensors have been extracted and have not been inserted back yet.
- property _sharded_tensors#
- property tensors: Iterator[torch.Tensor]#
Get the tensor data from the state dict.
- property common_state_dict: Dict#
Get the common state dict from the state dict.
- pop_tensors() List[torch.Tensor]#
Extracts the tensor data from the wrapped state dict, preserving metadata.
Replaces the tensor data in sharded_tensors with the device type of the extracted tensors. After this operation, the state dictionary is “hollow”, containing no tensor data. Further calls to pop_tensors will raise an error.
- Returns:
List of extracted tensors.
- insert_tensors(tensor_data: Iterable[torch.Tensor])#
Reverse of
pop_tensors. Replaces device type in sharded_tensors with actual values Value ofselfis considered to be the same after:self.insert_tensors(self.pop_tensors())
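The pop/insert contract can be sketched with a minimal toy class (a stand-in for the MCore implementation, using plain lists instead of torch.Tensor): popping leaves a “hollow” state dict that holds no tensor data, a second pop raises, and inserting the popped tensors restores the original value.

```python
# Toy model (not the MCore implementation) of the pop_tensors/insert_tensors
# contract: popping hollows out the state dict; inserting the same tensors
# back restores it, so insert(pop()) is a no-op on the overall value.
class HollowStateDict:
    def __init__(self, tensors):
        self.tensors = list(tensors)  # stand-ins for torch.Tensor objects
        self.is_hollow = False

    def pop_tensors(self):
        if self.is_hollow:
            raise RuntimeError("tensors already popped")
        extracted, self.tensors = self.tensors, [None] * len(self.tensors)
        self.is_hollow = True
        return extracted

    def insert_tensors(self, tensor_data):
        assert self.is_hollow, "insert_tensors requires a hollow state dict"
        self.tensors = list(tensor_data)
        self.is_hollow = False

sd = HollowStateDict([[1.0, 2.0], [3.0]])
popped = sd.pop_tensors()   # sd is now hollow (no tensor data)
sd.insert_tensors(popped)   # round trip restores the original value
```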
- init_tensors()#
Initializes empty tensors with the same properties as the original tensors.
This function should only be called after the original tensors have been popped. It ensures that the newly created empty tensors match the shape, dtype, and device of the originals, but contain no data.
- copy_tensors_to_cpu(non_blocking=False)#
Stores CPU copies of tensors in the state_dict, replacing the originals, but without destroying them. The original devices are remembered for restoration with restore_tensor_device(). Using non_blocking=True allows for asynchronous copying.
- restore_tensor_device(non_blocking=True)#
Restores all tensors to their original devices, if a move is required. Using non_blocking=True allows for asynchronous copying.
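The copy-to-CPU/restore pattern described above can be sketched in pure Python (names and the (device, data) representation are hypothetical, not the MCore internals): copying to CPU records each tensor's original device so a later restore can move the data back.

```python
# Hypothetical sketch of copy_tensors_to_cpu()/restore_tensor_device(): each
# tensor is modeled as a (device, data) pair; copying to CPU remembers the
# original devices for later restoration.
class DeviceAwareTensors:
    def __init__(self, tensors):
        self.tensors = list(tensors)  # list of (device, data) pairs
        self._orig_devices = None     # remembered by copy_tensors_to_cpu

    def copy_tensors_to_cpu(self):
        self._orig_devices = [dev for dev, _ in self.tensors]
        self.tensors = [("cpu", data) for _, data in self.tensors]

    def restore_tensor_device(self):
        if self._orig_devices is None:
            return  # nothing to restore
        self.tensors = [
            (dev, data)
            for dev, (_, data) in zip(self._orig_devices, self.tensors)
        ]
        self._orig_devices = None

t = DeviceAwareTensors([("cuda:0", [1.0]), ("cpu", [2.0])])
t.copy_tensors_to_cpu()      # everything now tagged "cpu"
t.restore_tensor_device()    # the "cuda:0" tensor moves back
```

In the real API, `non_blocking=True` lets the device-to-host and host-to-device copies overlap with computation; this sketch only models the bookkeeping.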
- _insert_sharded_data(fully_parallel, sharded_part, parallelization_group, exchange_algo)#
- to_state_dict(sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict, algo: str = 'atomic', exchange_algo: str = 'broadcast', validate_access_integrity: bool = True, parallelization_group: Optional[torch.distributed.ProcessGroup] = None, strict: core.dist_checkpointing.validation.StrictHandling = StrictHandling.ASSUME_OK_UNEXPECTED, return_mismatch_keys: bool = False)#
Converts the tensor-aware state dict back to the original state_dict.