core.dist_checkpointing.utils#
Helpers for manipulating sharded tensors and sharded state dicts.
Module Contents#
Functions#
zip_strict | Alternative to Python's builtin zip(…, strict=True) (available in 3.10+). Besides working on earlier Python versions, it is also more verbose: Python's zip reports only which iterable finished earlier, not the lengths.
_sharded_tensor_shard_id | Unique id of the sharded tensor data.
_sharded_object_id | Unique id of the sharded object data.
extract_sharded_tensors | Extract a dict consisting of only ShardedTensor objects from a given state dict with any objects.
extract_sharded_tensors_and_factories | Extract a dict consisting of only ShardedTensor and ShardedTensorFactory objects from a given state dict with any objects.
extract_sharded_tensors_or_nonpersistent | Extract a dict consisting of only ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject objects from a given state dict with any objects.
extract_sharded_base | Extract a dict consisting of only ShardedBase from a given state dict with any objects.
extract_nonpersistent | Extract a dict consisting of only LocalNonpersistentObjects from a given state dict.
add_prefix_for_sharding | Prepend a given prefix to all ShardedBase objects in a given state dict in-place.
replace_prefix_for_sharding | Replaces the given prefix in all sharded keys in a given state dict.
apply_prefix_mapping | Replaces prefixes only in keys that match one of the prefixes in the map.
force_all_tensors_to_non_fp8 | Force all tensors in state dict to be non-fp8.
logger_stack | Context manager for managing logger and name stack.
debug_time | Simple context manager for timing functions/code blocks.
debug_msg | Logs a debug message using the current logger stack.
_clean_metadata_for_serialization | Create a clean copy of metadata for serialization by removing non-serializable objects.
Data#
API#
- core.dist_checkpointing.utils._ShardId#
None
- core.dist_checkpointing.utils.zip_strict(*args)#
Alternative to Python's builtin zip(…, strict=True) (available in 3.10+). Besides working on earlier Python versions, it is also more verbose: Python's zip reports only which iterable finished earlier, not the lengths.
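The behaviour described above can be illustrated with a minimal sketch. This is not the library's implementation: `zip_strict_sketch` is a hypothetical name, and the choice to materialize the inputs and to raise `ValueError` is an assumption made for the example.

```python
def zip_strict_sketch(*iterables):
    """Zip iterables, raising ValueError that lists every length if they differ.

    Hypothetical illustration only; the real zip_strict may differ in
    exception type and message format.
    """
    # Materialize the inputs so their lengths can be reported.
    # This assumes the iterables are finite.
    sequences = [list(it) for it in iterables]
    lengths = [len(seq) for seq in sequences]
    if len(set(lengths)) > 1:
        raise ValueError(f"iterables have different lengths: {lengths}")
    return zip(*sequences)


pairs = list(zip_strict_sketch([1, 2, 3], "abc"))

try:
    list(zip_strict_sketch([1, 2, 3], "ab"))
except ValueError as err:
    message = str(err)  # the message names both lengths, 3 and 2
```

Reporting all lengths at once is what makes a mismatch easier to debug than with the builtin, which only says which argument ran out first.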
- core.dist_checkpointing.utils._sharded_tensor_shard_id(
- sharded_tensor: core.dist_checkpointing.mapping.ShardedTensor,
Unique id of the sharded tensor data.
Should yield the same value for the same data replicated on different ranks.
- Parameters:
sharded_tensor (ShardedTensor) – sharded tensor representing the data shard
Returns (tuple): unique id of a data shard
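The requirement that replicas on different ranks share one id can be sketched as follows. `FakeShard` and `shard_id_sketch` are hypothetical, and the field names `key`, `global_offset` and `replica_id` are assumptions made for illustration; the real ShardedTensor and the real id tuple may contain other fields.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class FakeShard:
    """Hypothetical stand-in for a sharded tensor (field names are assumed)."""
    key: str
    global_offset: Tuple[int, ...]
    replica_id: int


def shard_id_sketch(shard):
    """Build an id from what identifies the data, ignoring which replica holds it."""
    return (shard.key, shard.global_offset)


on_rank0 = FakeShard("decoder.weight", (0, 128), replica_id=0)
on_rank1 = FakeShard("decoder.weight", (0, 128), replica_id=1)  # same data, other rank
other_shard = FakeShard("decoder.weight", (128, 128), replica_id=0)
```

Because the replica identifier is left out of the tuple, `on_rank0` and `on_rank1` map to the same id, while `other_shard` does not.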
- core.dist_checkpointing.utils._sharded_object_id(
- sharded_object: core.dist_checkpointing.mapping.ShardedObject,
Unique id of the sharded object data.
Should yield the same value for the same data replicated on different ranks.
- Parameters:
sharded_object (ShardedObject) – sharded object representing the data shard
Returns (tuple): unique id of a data shard
- core.dist_checkpointing.utils.extract_sharded_tensors(
- sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
Extract a dict consisting of only ShardedTensor objects from a given state dict with any objects.
- Parameters:
sharded_state_dict – state dict possibly containing ShardedTensor objects
- Returns:
A tuple of two state dicts, each keeping the original state dict structure: the first with all ShardedTensor objects, the second with all objects other than ShardedTensor.
- Return type:
Tuple[ShardedStateDict, StateDict]
- core.dist_checkpointing.utils.extract_sharded_tensors_and_factories(
- sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
Extract a dict consisting of only ShardedTensor and ShardedTensorFactory objects from a given state dict with any objects.
- Parameters:
sharded_state_dict – state dict possibly containing ShardedTensor and ShardedTensorFactory objects
- Returns:
A tuple of two state dicts, each keeping the original state dict structure: the first with all ShardedTensor and ShardedTensorFactory objects, the second with all other objects.
- Return type:
Tuple[ShardedStateDict, StateDict]
- core.dist_checkpointing.utils.extract_sharded_tensors_or_nonpersistent(
- sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
Extract a dict consisting of only ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject objects from a given state dict with any objects.
- Parameters:
sharded_state_dict – state dict possibly containing ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject objects
- Returns:
A tuple of two state dicts, each keeping the original state dict structure: the first with all ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject objects, the second with all other objects.
- Return type:
Tuple[ShardedStateDict, StateDict]
- core.dist_checkpointing.utils.extract_sharded_base(
- sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
Extract a dict consisting of only ShardedBase from a given state dict with any objects.
- Parameters:
sharded_state_dict – state dict possibly containing ShardedBase objects
- Returns:
A tuple of two state dicts, each keeping the original state dict structure: the first with all ShardedBase objects, the second with all other objects.
- Return type:
Tuple[ShardedStateDict, StateDict]
- core.dist_checkpointing.utils.extract_nonpersistent(
- sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
Extract a dict consisting of only LocalNonpersistentObjects from a given state dict.
- Parameters:
sharded_state_dict – state dict possibly containing LocalNonpersistentObjects
- Returns:
A tuple of two state dicts, each keeping the original state dict structure: the first with all LocalNonpersistentObjects, the second with all other objects.
- Return type:
Tuple[ShardedStateDict, StateDict]
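The extract_* helpers above all follow one pattern: split a nested state dict into a matching part and a remainder, both mirroring the original nesting. A minimal sketch of that pattern, assuming dict-only nesting, could look like this. `FakeShardedTensor` and `split_by_predicate` are hypothetical names, and dropping empty sub-dicts is a choice of this sketch rather than documented library behaviour.

```python
from dataclasses import dataclass


@dataclass
class FakeShardedTensor:
    """Hypothetical stand-in for ShardedTensor; only its type matters here."""
    key: str


def split_by_predicate(state, predicate):
    """Split a nested dict into (matching, rest), both mirroring the nesting."""
    matching, rest = {}, {}
    for name, value in state.items():
        if isinstance(value, dict):
            # Recurse, then keep only non-empty branches on each side.
            sub_match, sub_rest = split_by_predicate(value, predicate)
            if sub_match:
                matching[name] = sub_match
            if sub_rest:
                rest[name] = sub_rest
        elif predicate(value):
            matching[name] = value
        else:
            rest[name] = value
    return matching, rest


state_dict = {
    "model": {"weight": FakeShardedTensor("model.weight"), "step": 10},
    "lr": 0.1,
}
sharded, other = split_by_predicate(
    state_dict, lambda v: isinstance(v, FakeShardedTensor)
)
```

Swapping the predicate (for example, to test for a different class) gives the other variants of the split without changing the traversal.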
- core.dist_checkpointing.utils.add_prefix_for_sharding(
- sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
- prefix: str,
Prepend a given prefix to all ShardedBase objects in a given state dict in-place.
- Parameters:
sharded_state_dict (ShardedStateDict) – sharded state dict
prefix (str) – prefix to be prepended
- Returns:
state dict is modified in-place
- Return type:
None
- core.dist_checkpointing.utils.replace_prefix_for_sharding(
- sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
- old_prefix: str,
- new_prefix: str,
Replaces the given prefix in all sharded keys in a given state dict.
Errors out if some key does not begin with the given prefix.
- Parameters:
sharded_state_dict (ShardedStateDict) – sharded state dict to replace keys in
old_prefix (str) – prefix to be replaced in each key
new_prefix (str) – new prefix
- Returns:
state dict is modified in place
- Return type:
None
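The in-place prefix operations described above (prepending a prefix, and replacing one with an error on mismatch) can be sketched on a stand-in object. `FakeSharded`, `add_prefix_sketch` and `replace_prefix_sketch` are hypothetical; the assumption that each sharded object exposes a string `key` attribute, and the use of `ValueError`, are for illustration only.

```python
from dataclasses import dataclass


@dataclass
class FakeSharded:
    """Hypothetical stand-in for a ShardedBase object with a string key."""
    key: str


def _leaves(state):
    """Yield every FakeSharded leaf in a nested dict, skipping other values."""
    for value in state.values():
        if isinstance(value, dict):
            yield from _leaves(value)
        elif isinstance(value, FakeSharded):
            yield value


def add_prefix_sketch(state, prefix):
    """Prepend prefix to every sharded key, modifying the objects in place."""
    for leaf in _leaves(state):
        leaf.key = prefix + leaf.key


def replace_prefix_sketch(state, old_prefix, new_prefix):
    """Swap old_prefix for new_prefix in every sharded key, in place."""
    for leaf in _leaves(state):
        if not leaf.key.startswith(old_prefix):
            raise ValueError(f"Key {leaf.key!r} does not start with {old_prefix!r}")
        leaf.key = new_prefix + leaf.key[len(old_prefix):]


state = {"layer": {"w": FakeSharded("weight"), "b": FakeSharded("bias")}, "step": 3}
add_prefix_sketch(state, "decoder.")
keys_after_add = [state["layer"]["w"].key, state["layer"]["b"].key]
replace_prefix_sketch(state, "decoder.", "encoder.")
keys_after_replace = [state["layer"]["w"].key, state["layer"]["b"].key]
```

Both functions return nothing; the caller observes the change through the objects already held in the state dict, and non-sharded values such as `"step"` are left untouched.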
- core.dist_checkpointing.utils.apply_prefix_mapping(
- sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
- prefix_map: Dict[str, str],
Replaces prefixes only in keys that match one of the prefixes in the map.
- Parameters:
sharded_state_dict (ShardedStateDict) – sharded state dict to replace keys in
prefix_map (Dict[str, str]) – map of old->new prefixes. The first matching prefix for each key is used
- Returns:
state dict is modified in place
- Return type:
None
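The "first matching prefix" rule can be shown on plain strings. This is a sketch under the assumption that the map is scanned in insertion order; `remap_key` is a hypothetical helper, not part of the module.

```python
def remap_key(key, prefix_map):
    """Return key with the first matching old prefix replaced, else unchanged."""
    for old_prefix, new_prefix in prefix_map.items():
        if key.startswith(old_prefix):
            return new_prefix + key[len(old_prefix):]
    return key


# Order matters: the more specific prefix is listed first so it wins.
prefix_map = {"decoder.layers.": "layers.", "decoder.": "dec."}
remapped = [
    remap_key(k, prefix_map)
    for k in ["decoder.layers.0.w", "decoder.norm.w", "embedding.w"]
]
```

Unlike a strict prefix replacement, a key that matches no entry (here `"embedding.w"`) passes through unchanged instead of raising an error.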
- core.dist_checkpointing.utils.force_all_tensors_to_non_fp8(
- sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
Force all tensors in state dict to be non-fp8.
- Parameters:
sharded_state_dict (ShardedStateDict) – sharded state dict.
- core.dist_checkpointing.utils.fallback_logger#
‘getLogger(…)’
- core.dist_checkpointing.utils.__LOGGER_NAME_STACK#
[]
- core.dist_checkpointing.utils.__LOGGER_STACK#
[]
- core.dist_checkpointing.utils.logger_stack(
- name: Optional[str] = None,
- current_logger: Optional[logging.Logger] = None,
Context manager for managing logger and name stack.
Temporarily pushes a logger and/or name onto their respective stacks, allowing hierarchical logging and contextual logger usage. Ensures the logger stack is restored afterward.
- Parameters:
name (str, optional) – Name to add to the logger stack. Defaults to None.
current_logger (logging.Logger, optional) – Logger to use. Defaults to the last logger in the stack or a fallback if none exist.
- Yields:
Tuple[str, logging.Logger] –
A tuple with the concatenated logger name stack and the current logger for the block.
Example:
with logger_stack("scope", logger):
    logger.info("Log within 'scope'")
- core.dist_checkpointing.utils.debug_time(
- name: str,
- logger: Optional[logging.Logger] = None,
- threshold: float = float('-inf'),
- level=None,
Simple context manager for timing functions/code blocks.
- Parameters:
name (str) – Label describing the code being measured.
logger (logging.Logger, optional) – Logger for output. Defaults to the lowest logger.
threshold (float, optional) – Minimum time (seconds) to log. Skips logging if faster.
level (int, optional) – Logging level. Defaults to DEBUG if threshold is unset; WARNING otherwise.
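The interplay of threshold and level described above can be sketched as follows. `debug_time_sketch` is a hypothetical re-creation based only on the parameter descriptions; the message format and the fallback to `logging.getLogger(__name__)` are assumptions.

```python
import logging
import time
from contextlib import contextmanager


@contextmanager
def debug_time_sketch(name, logger=None, threshold=float("-inf"), level=None):
    """Time the enclosed block and log it if it ran longer than threshold."""
    if logger is None:
        logger = logging.getLogger(__name__)  # assumed fallback
    if level is None:
        # DEBUG when no threshold was given, WARNING otherwise.
        level = logging.DEBUG if threshold == float("-inf") else logging.WARNING
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        if elapsed > threshold:
            logger.log(level, "%s took %.4fs", name, elapsed)


class ListHandler(logging.Handler):
    """Collects log records in a list so the outcome can be inspected."""

    def __init__(self):
        super().__init__()
        self.records = []

    def emit(self, record):
        self.records.append(record)


timing_logger = logging.getLogger("timing_sketch")
timing_logger.setLevel(logging.DEBUG)
timing_logger.propagate = False
handler = ListHandler()
timing_logger.addHandler(handler)

with debug_time_sketch("fast step", timing_logger):
    pass  # no threshold: always logged, at DEBUG
with debug_time_sketch("slow check", timing_logger, threshold=3600.0):
    pass  # finishes well under one hour: nothing is logged
```

Raising the default level to WARNING when a threshold is set reflects the idea that exceeding an explicit time budget is noteworthy rather than routine.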
- core.dist_checkpointing.utils.debug_msg(msg: str)#
Logs a debug message using the current logger stack.
This function formats and logs a debug message with the current logger and name stack, preserving context from the logger_stack context manager.
- Parameters:
msg (str) – The message to be logged at the debug level.
Example:
debug_msg("Checkpoint initialized")
Logs "scope_name Checkpoint initialized" if called within logger_stack("scope_name").
- core.dist_checkpointing.utils._clean_metadata_for_serialization(metadata: dict) → dict#
Create a clean copy of metadata for serialization by removing non-serializable objects.
- Parameters:
metadata – Original metadata dict
- Returns:
Clean metadata dict suitable for serialization