core.dist_checkpointing.utils

Helpers for manipulating sharded tensors and sharded state dicts.

Module Contents

Functions

zip_strict

Alternative to Python’s builtin zip(…, strict=True) (available in 3.10+). Besides providing this functionality in earlier Python versions, it is also more verbose: Python’s zip does not report the lengths, only which iterable finished first.

_sharded_tensor_shard_id

Unique id of the sharded tensor data.

_sharded_object_id

Unique id of the sharded object data.

extract_sharded_tensors

Extract a dict containing only the ShardedTensor objects from a state dict that may contain objects of any type.

extract_sharded_tensors_and_factories

Extract a dict containing only the ShardedTensor and ShardedTensorFactory objects from a state dict that may contain objects of any type.

extract_sharded_tensors_or_nonpersistent

Extract a dict containing only the ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject objects from a state dict that may contain objects of any type.

extract_sharded_base

Extract a dict containing only the ShardedBase objects from a state dict that may contain objects of any type.

extract_nonpersistent

Extract a dict consisting of only LocalNonpersistentObjects from a given state dict.

add_prefix_for_sharding

Prepend a given prefix to all ShardedBase objects in a given state dict in-place.

replace_prefix_for_sharding

Replaces the given prefix in all sharded keys in a given state dict.

apply_prefix_mapping

Replaces prefixes only in keys that match one of the prefixes in the map.

force_all_tensors_to_non_fp8

Force all tensors in state dict to be non-fp8.

logger_stack

Context manager for managing the logger and name stacks.

debug_time

Simple context manager for timing functions/code blocks.

debug_msg

Logs a debug message using the current logger stack.

_clean_metadata_for_serialization

Create a clean copy of metadata for serialization by removing non-serializable objects.

Data

API

core.dist_checkpointing.utils._ShardId

None

core.dist_checkpointing.utils.zip_strict(*args)

Alternative to Python’s builtin zip(…, strict=True) (available in 3.10+). Besides providing this functionality in earlier Python versions, it is also more verbose: Python’s zip does not report the lengths, only which iterable finished first.
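
A minimal sketch of the behavior described above, assuming the lengths are checked up front with len() (so only sized inputs such as lists and tuples are accepted). The name zip_strict_sketch and the ValueError are illustrative assumptions; the library's implementation and exception type may differ.

```python
from typing import Any, Iterator, Sequence, Tuple


def zip_strict_sketch(*args: Sequence[Any]) -> Iterator[Tuple[Any, ...]]:
    """Zip sized iterables, failing loudly if their lengths differ (hypothetical helper)."""
    lengths = [len(a) for a in args]
    # Unlike zip(..., strict=True), report every length, not just which input ran out.
    if len(set(lengths)) > 1:
        raise ValueError(f"Tried to zip iterables of unequal lengths: {lengths}")
    return zip(*args)


if __name__ == "__main__":
    print(list(zip_strict_sketch([1, 2], ["a", "b"])))  # [(1, 'a'), (2, 'b')]
    try:
        zip_strict_sketch([1, 2, 3], ["a"])
    except ValueError as err:
        print(err)  # Tried to zip iterables of unequal lengths: [3, 1]
```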

core.dist_checkpointing.utils._sharded_tensor_shard_id(
sharded_tensor: core.dist_checkpointing.mapping.ShardedTensor,
) -> core.dist_checkpointing.utils._ShardId

Unique id of the sharded tensor data.

Should yield the same value for the same data replicated on different ranks.

Parameters:

sharded_tensor (ShardedTensor) – sharded tensor representing the data shard

Returns:

unique id of a data shard

Return type:

tuple
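
The key property is that the ID depends only on which slice of the global tensor a shard holds, never on the rank that holds it. A sketch under that assumption, building the ID from the key, the global offset and the optional flattened range; FakeShardedTensor and shard_id_sketch are hypothetical stand-ins, and the exact tuple layout of _ShardId is an assumption.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

ShardIdSketch = Tuple[str, tuple, Optional[tuple]]


@dataclass
class FakeShardedTensor:
    """Hypothetical stand-in for ShardedTensor with only the fields used here."""
    key: str
    global_offset: tuple
    flattened_range: Optional[slice] = None
    rank: int = 0  # where the shard lives; deliberately NOT part of the ID


def shard_id_sketch(st: FakeShardedTensor) -> ShardIdSketch:
    # slice objects are only hashable from Python 3.12, so store a plain tuple instead.
    f_range = st.flattened_range
    range_id = None if f_range is None else (f_range.start, f_range.stop)
    return (st.key, st.global_offset, range_id)


if __name__ == "__main__":
    # The same shard replicated on ranks 0 and 1, plus a different shard on rank 2.
    shards = [
        FakeShardedTensor("decoder.weight", (0, 0), rank=0),
        FakeShardedTensor("decoder.weight", (0, 0), rank=1),
        FakeShardedTensor("decoder.weight", (4, 0), rank=2),
    ]
    print(len({shard_id_sketch(s) for s in shards}))  # 2
```

Collecting the IDs in a set, as in the demo, is one way such IDs can be used to detect replicated shards.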

core.dist_checkpointing.utils._sharded_object_id(
sharded_object: core.dist_checkpointing.mapping.ShardedObject,
) -> core.dist_checkpointing.utils._ShardId

Unique id of the sharded object data.

Should yield the same value for the same data replicated on different ranks.

Parameters:

sharded_object (ShardedObject) – sharded object representing the data shard

Returns:

unique id of a data shard

Return type:

tuple
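
An analogous sketch for objects, assuming the ID is the tuple (key, global_offset, global_shape); FakeShardedObject, object_id_sketch and dedup_sketch are hypothetical. The payload (data) is deliberately ignored, so two replicas of the same shard map to one ID and only the first is kept.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class FakeShardedObject:
    """Hypothetical stand-in for ShardedObject."""
    key: str
    data: object
    global_shape: tuple
    global_offset: tuple


def object_id_sketch(so: FakeShardedObject) -> Tuple[str, tuple, tuple]:
    # Identify the shard by its position in the global layout, not by its payload.
    return (so.key, so.global_offset, so.global_shape)


def dedup_sketch(objs: List[FakeShardedObject]) -> Dict[tuple, FakeShardedObject]:
    # Keep only the first occurrence of each unique shard ID.
    unique: Dict[tuple, FakeShardedObject] = {}
    for obj in objs:
        unique.setdefault(object_id_sketch(obj), obj)
    return unique
```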

core.dist_checkpointing.utils.extract_sharded_tensors(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
) -> Tuple[core.dist_checkpointing.mapping.ShardedStateDict, core.dist_checkpointing.mapping.StateDict]

Extract a dict containing only the ShardedTensor objects from a state dict that may contain objects of any type.

Parameters:

sharded_state_dict – state dict possibly containing ShardedTensor objects

Returns:

tuple of:
  • state dict with all ShardedTensor objects (keeping the original state dict structure)

  • state dict with all objects other than ShardedTensor (keeping the original state dict structure)

Return type:

Tuple[ShardedStateDict, StateDict]
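
The extract_* helpers share one pattern: walk the nested state dict and split it into two dicts with the same nesting, one holding the values that pass a type check and one holding everything else. A minimal sketch of that pattern for plain nested dicts (list handling is omitted), with a hypothetical FakeShardedTensor class in place of ShardedTensor:

```python
from typing import Any, Callable, Dict, Tuple


class FakeShardedTensor:
    """Hypothetical stand-in for ShardedTensor."""

    def __init__(self, key: str) -> None:
        self.key = key


def partition_sketch(
    state_dict: Dict[str, Any], predicate: Callable[[Any], bool]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split a nested dict into (matching, other), preserving the nesting."""
    matching: Dict[str, Any] = {}
    other: Dict[str, Any] = {}
    for k, v in state_dict.items():
        if isinstance(v, dict):
            sub_match, sub_other = partition_sketch(v, predicate)
            # Only create a sub-dict on the matching side if something matched.
            if sub_match:
                matching[k] = sub_match
            # Keep empty sub-dicts on the "other" side so they are not silently lost.
            if sub_other or not v:
                other[k] = sub_other
        elif predicate(v):
            matching[k] = v
        else:
            other[k] = v
    return matching, other
```

For example, partitioning {"model": {"w": FakeShardedTensor("w"), "step": 10}, "lr": 0.1} with lambda v: isinstance(v, FakeShardedTensor) puts the tensor under {"model": {"w": ...}} and leaves {"model": {"step": 10}, "lr": 0.1} as the remainder.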

core.dist_checkpointing.utils.extract_sharded_tensors_and_factories(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
) -> Tuple[core.dist_checkpointing.mapping.ShardedStateDict, core.dist_checkpointing.mapping.StateDict]

Extract a dict containing only the ShardedTensor and ShardedTensorFactory objects from a state dict that may contain objects of any type.

Parameters:

sharded_state_dict – state dict possibly containing ShardedTensor and ShardedTensorFactory objects

Returns:

tuple of:
  • state dict with all ShardedTensor and ShardedTensorFactory objects (keeping the original state dict structure)

  • state dict with all other objects (keeping the original state dict structure)

Return type:

Tuple[ShardedStateDict, StateDict]

core.dist_checkpointing.utils.extract_sharded_tensors_or_nonpersistent(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
) -> Tuple[core.dist_checkpointing.mapping.ShardedStateDict, core.dist_checkpointing.mapping.StateDict]

Extract a dict containing only the ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject objects from a state dict that may contain objects of any type.

Parameters:

sharded_state_dict – state dict possibly containing ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject objects

Returns:

tuple of:
  • state dict with all ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject objects (keeping the original state dict structure)

  • state dict with all other objects (keeping the original state dict structure)

Return type:

Tuple[ShardedStateDict, StateDict]
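
Because both halves keep the original nesting, they can be recombined after the matching half has been processed. A sketch of that round trip, using a predicate that accepts any of three types as this function's description suggests; the stand-in classes, split_sketch and merge_sketch are hypothetical and assume plain nested dicts whose leaves do not overlap.

```python
from typing import Any, Callable, Dict, Tuple


class FakeShardedTensor: ...
class FakeShardedTensorFactory: ...
class FakeLocalNonpersistentObject: ...


MATCH_TYPES = (FakeShardedTensor, FakeShardedTensorFactory, FakeLocalNonpersistentObject)


def split_sketch(
    d: Dict[str, Any], pred: Callable[[Any], bool]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split a nested dict into (matching, rest) with the same nesting."""
    match: Dict[str, Any] = {}
    rest: Dict[str, Any] = {}
    for k, v in d.items():
        if isinstance(v, dict):
            m, r = split_sketch(v, pred)
            if m:
                match[k] = m
            if r or not v:
                rest[k] = r
        elif pred(v):
            match[k] = v
        else:
            rest[k] = v
    return match, rest


def merge_sketch(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge b into a shallow copy of a."""
    out = dict(a)
    for k, v in b.items():
        if k in out and isinstance(out[k], dict) and isinstance(v, dict):
            out[k] = merge_sketch(out[k], v)
        else:
            out[k] = v
    return out
```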

core.dist_checkpointing.utils.extract_sharded_base(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
) -> Tuple[core.dist_checkpointing.mapping.ShardedStateDict, core.dist_checkpointing.mapping.StateDict]

Extract a dict containing only the ShardedBase objects from a state dict that may contain objects of any type.

Parameters:

sharded_state_dict – state dict possibly containing ShardedBase objects

Returns:

tuple of:
  • state dict with all ShardedBase objects (keeping the original state dict structure)

  • state dict with all other objects (keeping the original state dict structure)

Return type:

Tuple[ShardedStateDict, StateDict]

core.dist_checkpointing.utils.extract_nonpersistent(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
) -> Tuple[core.dist_checkpointing.mapping.ShardedStateDict, core.dist_checkpointing.mapping.StateDict]

Extract a dict consisting of only LocalNonpersistentObjects from a given state dict.

Parameters:

sharded_state_dict – state dict possibly containing LocalNonpersistentObjects

Returns:

tuple of:
  • state dict with all LocalNonpersistentObjects (keeping the original state dict structure)

  • state dict with all other objects (keeping the original state dict structure)

Return type:

Tuple[ShardedStateDict, StateDict]

core.dist_checkpointing.utils.add_prefix_for_sharding(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
prefix: str,
)

Prepend a given prefix to all ShardedBase objects in a given state dict in-place.

Parameters:
  • sharded_state_dict (ShardedStateDict) – sharded state dict

  • prefix (str) – prefix to be prepended

Returns:

state dict is modified in-place

Return type:

None
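
A minimal sketch of the in-place update, assuming each sharded object carries a mutable key attribute and the container is a nested dict; FakeShardedBase and add_prefix_sketch are hypothetical. Only the objects' checkpoint keys change, not the dict keys that hold them. The sketch concatenates the prefix directly, so a separator such as a trailing dot must be part of the prefix.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class FakeShardedBase:
    """Hypothetical stand-in: only the checkpoint key matters here."""
    key: str


def add_prefix_sketch(state_dict: Dict[str, Any], prefix: str) -> None:
    """Prepend prefix to the key of every sharded object, in place."""
    for v in state_dict.values():
        if isinstance(v, dict):
            add_prefix_sketch(v, prefix)
        elif isinstance(v, FakeShardedBase):
            v.key = f"{prefix}{v.key}"  # plain concatenation, no separator added
```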

core.dist_checkpointing.utils.replace_prefix_for_sharding(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
old_prefix: str,
new_prefix: str,
)

Replaces the given prefix in all sharded keys in a given state dict.

Raises an error if any key does not begin with old_prefix.

Parameters:
  • sharded_state_dict (ShardedStateDict) – sharded state dict to replace keys in

  • old_prefix (str) – prefix to be replaced in each key

  • new_prefix (str) – new prefix

Returns:

state dict is modified in place

Return type:

None
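
A sketch of the replacement under the same assumptions (mutable key attribute, nested dicts, hypothetical names). It raises ValueError for a non-matching key; the library's exception type may differ. Because the update happens in place, keys visited before the failure have already been rewritten.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class FakeShardedBase:
    """Hypothetical stand-in: only the checkpoint key matters here."""
    key: str


def replace_prefix_sketch(state_dict: Dict[str, Any], old_prefix: str, new_prefix: str) -> None:
    """Swap old_prefix for new_prefix in every sharded key, in place."""
    for v in state_dict.values():
        if isinstance(v, dict):
            replace_prefix_sketch(v, old_prefix, new_prefix)
        elif isinstance(v, FakeShardedBase):
            if not v.key.startswith(old_prefix):
                # Every sharded key is expected to carry old_prefix.
                raise ValueError(f"Expected {v.key} to begin with prefix {old_prefix}")
            v.key = new_prefix + v.key[len(old_prefix):]
```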

core.dist_checkpointing.utils.apply_prefix_mapping(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
prefix_map: Dict[str, str],
)

Replaces prefixes only in keys that match one of the prefixes in the map.

Parameters:
  • sharded_state_dict (ShardedStateDict) – sharded state dict to replace keys in

  • prefix_map (Dict[str, str]) – map of old->new prefixes. The first matching prefix for each key is used

Returns:

state dict is modified in place

Return type:

None
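
A sketch of the mapping, assuming "first" means first in the map's iteration order (Python dicts preserve insertion order) and that keys with no matching prefix are left unchanged rather than raising; FakeShardedBase and apply_prefix_mapping_sketch are hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class FakeShardedBase:
    """Hypothetical stand-in: only the checkpoint key matters here."""
    key: str


def apply_prefix_mapping_sketch(state_dict: Dict[str, Any], prefix_map: Dict[str, str]) -> None:
    """Rewrite each sharded key using the first matching old->new prefix, in place."""
    for v in state_dict.values():
        if isinstance(v, dict):
            apply_prefix_mapping_sketch(v, prefix_map)
        elif isinstance(v, FakeShardedBase):
            for old, new in prefix_map.items():
                if v.key.startswith(old):
                    v.key = new + v.key[len(old):]
                    break  # only the first matching prefix is applied
```

With {"enc.": "encoder.", "enc.layer": "X"}, the key "enc.layer.w" becomes "encoder.layer.w" because "enc." comes first, while "head.w" is left as-is.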

core.dist_checkpointing.utils.force_all_tensors_to_non_fp8(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
)

Force all tensors in state dict to be non-fp8.

Parameters:

sharded_state_dict (ShardedStateDict) – sharded state dict.

core.dist_checkpointing.utils.fallback_logger

‘getLogger(…)’

core.dist_checkpointing.utils.__LOGGER_NAME_STACK

[]

core.dist_checkpointing.utils.__LOGGER_STACK

[]

core.dist_checkpointing.utils.logger_stack(
name: Optional[str] = None,
current_logger: Optional[logging.Logger] = None,
)

Context manager for managing the logger and name stacks.

Temporarily pushes a logger and/or name onto their respective stacks, allowing hierarchical logging and contextual logger usage. Ensures the logger stack is restored afterward.

Parameters:
  • name (str, optional) – Name to add to the logger stack. Defaults to None.

  • current_logger (logging.Logger, optional) – Logger to use. Defaults to the last logger in the stack or a fallback if none exist.

Yields:

Tuple[str, logging.Logger]

A tuple with the concatenated logger name stack and the current logger for the block.

Example:

with logger_stack("scope", logger):
    logger.info("Log within 'scope'")

core.dist_checkpointing.utils.debug_time(
name: str,
logger: Optional[logging.Logger] = None,
threshold: float = float('-inf'),
level=None,
)

Simple context manager for timing functions/code blocks.

Parameters:
  • name (str) – Label describing the code being measured.

  • logger (logging.Logger, optional) – Logger for output. Defaults to the lowest (most recently pushed) logger in the logger stack.

  • threshold (float, optional) – Minimum time (seconds) to log. Skips logging if faster.

  • level (int, optional) – Logging level. Defaults to DEBUG if threshold is unset; WARNING otherwise.
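
A sketch of the timing and threshold logic, assuming time.perf_counter() for measurement and a "<name> took <seconds>s" message format (both assumptions). For brevity it falls back to a module logger instead of consulting the logger stack, and debug_time_sketch is a hypothetical name.

```python
import logging
import time
from contextlib import contextmanager
from typing import Iterator, Optional


@contextmanager
def debug_time_sketch(
    name: str,
    logger: Optional[logging.Logger] = None,
    threshold: float = float("-inf"),
    level: Optional[int] = None,
) -> Iterator[None]:
    """Time the enclosed block and log '<name> took <seconds>s' if it was slow enough."""
    if logger is None:
        logger = logging.getLogger(__name__)  # simplification: no logger stack lookup
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        # Blocks faster than the threshold are not logged at all.
        if elapsed >= threshold:
            if level is None:
                # DEBUG when no threshold was set, WARNING when a threshold was reached.
                level = logging.DEBUG if threshold == float("-inf") else logging.WARNING
            logger.log(level, "%s took %.4fs", name, elapsed)
```

With threshold=1.0, only blocks that take at least one second are logged, at WARNING level unless level is given explicitly.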

core.dist_checkpointing.utils.debug_msg(msg: str)

Logs a debug message using the current logger stack.

This function formats and logs a debug message with the current logger and name stack, preserving context from the logger_stack context manager.

Parameters:

msg (str) – The message to be logged at the debug level.

Example:

debug_msg("Checkpoint initialized")
# Logs: "scope_name Checkpoint initialized" if called within logger_stack("scope_name")
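
A sketch of how the message could be assembled from the stacks, assuming dot-joined scope names followed by a space, which matches the example output above. NAME_STACK, LOGGER_STACK and debug_msg_sketch are hypothetical stand-ins.

```python
import logging
from typing import List

_fallback_logger = logging.getLogger("sketch")
NAME_STACK: List[str] = []  # hypothetical stand-ins for the module's stacks
LOGGER_STACK: List[logging.Logger] = []


def debug_msg_sketch(msg: str) -> None:
    """Log msg at DEBUG level, prefixed with the current name stack (if any)."""
    logger = LOGGER_STACK[-1] if LOGGER_STACK else _fallback_logger
    prefix = ".".join(NAME_STACK)
    logger.debug(f"{prefix} {msg}" if prefix else msg)
```

The message only appears if the chosen logger's effective level is DEBUG, for example after logging.basicConfig(level=logging.DEBUG).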

core.dist_checkpointing.utils._clean_metadata_for_serialization(metadata: dict) -> dict

Create a clean copy of metadata for serialization by removing non-serializable objects.

Parameters:

metadata – Original metadata dict

Returns:

Clean metadata dict suitable for serialization
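
The description does not say how non-serializable objects are detected. One plausible approach, sketched here as an assumption, tries to pickle each top-level value and keeps only those that succeed, returning a new dict so the caller's metadata is left untouched. clean_metadata_sketch is hypothetical, and the library may instead remove specific known keys.

```python
import pickle
from typing import Any, Dict, Optional


def clean_metadata_sketch(metadata: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Return a new dict with only the top-level entries that can be pickled."""
    if metadata is None:
        return None
    clean: Dict[str, Any] = {}
    for key, value in metadata.items():
        try:
            pickle.dumps(value)
        except Exception:  # failures surface as TypeError, PicklingError, AttributeError, ...
            continue  # drop entries such as locks or open file handles
        clean[key] = value
    return clean
```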