core.dist_checkpointing.strategies.torch#

Strategies using PyTorch distributed.checkpoint as an underlying format.

Module Contents#

Classes#

MCoreMetadata

MCoreSavePlan

MCoreSavePlanner

Differs from the default planner by saving BytesIO objects on all ranks.

MCoreLoadPlanner

Adds global shape validation to the default planner.

TorchDistSaveShardedStrategy

Async save strategy for the PyT Distributed format.

TorchDistLoadShardedStrategy

Basic load strategy for the PyT Distributed format.

Functions#

register_default_torch_strategies

Register default strategies related to PyT Distributed backend.

flatten_state_dict

Flattens a state dict into a single-level dict.

sharded_tensor_to_torch_sharded_tensor

Convert MCore ShardedTensor to PyT ShardedTensor. PyT requires information about all chunks.

mcore_to_pyt_state_dict

Convert state dict with ShardedTensors and ShardedObjects to state dict compatible with PyT Dist format.

_unwrap_pyt_sharded_tensor

Unwrap tensor from PyT ShardedTensor instance.

_replace_state_dict_keys_with_sharded_keys

Group ShardedBase objects by keys and return mappings required for recreating the original dict.

_replace_sharded_keys_with_state_dict_keys

Inverse of _replace_state_dict_keys_with_sharded_keys.

_restore_dict_types

Recursively update x keys, based on keys_template.

_get_filesystem_reader

Data#

API#

core.dist_checkpointing.strategies.torch.MSC_PREFIX#

‘msc://’

core.dist_checkpointing.strategies.torch._metadata_fn: str#

‘.metadata’

class core.dist_checkpointing.strategies.torch.MCoreMetadata#
class core.dist_checkpointing.strategies.torch.MCoreSavePlan#
core.dist_checkpointing.strategies.torch.register_default_torch_strategies()#

Register default strategies related to PyT Distributed backend.

core.dist_checkpointing.strategies.torch.logger#

‘getLogger(…)’

core.dist_checkpointing.strategies.torch.flatten_state_dict(
state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
) Tuple[core.dist_checkpointing.mapping.ShardedStateDict, Dict[str, torch.distributed.checkpoint._traverse.OBJ_PATH]]#

Flattens a state dict into a single-level dict.

It is a copy of torch.distributed.checkpoint._nested_dict.flatten_state_dict, which also accepts ShardedBase tensors as terminal objects.

Parameters:

state_dict (ShardedStateDict) – state dict to be flattened

Returns (tuple): the flattened state dict and a mapping that allows recreating the original one
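A minimal sketch of the flattening behaviour, using plain tensors as terminal values (in MCore usage the leaves would be ShardedTensor or ShardedObject instances). The import prefix follows this page's module path; in a Megatron-LM installation it typically starts with megatron.core.

```python
import torch

# Assumed import path (follows this page's module naming).
from core.dist_checkpointing.strategies.torch import flatten_state_dict

nested = {
    'model': {'decoder': {'weight': torch.zeros(4)}},
    'optimizer': {'step': torch.tensor(1)},
}

flat, mapping = flatten_state_dict(nested)
# `flat` is a single-level dict whose keys encode the original paths
# (e.g. 'model.decoder.weight', mirroring the PyT helper this function copies);
# `mapping` records, per flattened key, the path needed to rebuild the nesting.
```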

core.dist_checkpointing.strategies.torch.sharded_tensor_to_torch_sharded_tensor(
sh_tens: List[core.dist_checkpointing.mapping.ShardedTensor],
rank: Optional[int] = None,
load_legacy_1d_flatten_tensors: bool = False,
) torch.distributed._shard.sharded_tensor.ShardedTensor#

Convert MCore ShardedTensor to PyT ShardedTensor. PyT requires information about all chunks.

At a high level, this function follows the logic of torch.distributed.fsdp._shard_utils._create_chunk_sharded_tensor. Additionally, it saves prepend_axis_num and has_flattened_range (specific to MCore) as attributes for further restoration in _unwrap_pyt_sharded_tensor.

NOTE: this function assumes regular (grid) sharding of the MCore ShardedTensor. The only local irregularities can be introduced by a flattened_range attribute.

This function handles two different types of ShardedTensors:

  1. Non-flat regular ShardedTensors (not has_flattened_range)

  2. N-D flattened ShardedTensors (has_flattened_range)

Type (1) tensors are saved according to their original shape. Type (2) tensors, however, require a global shape adjustment for efficiency: a tensor with global shape [X, Y, Z] and local shape [x, y, z] is treated as a [X // x, Y // y, Z // z, x * y * z] tensor whose last axis is partitioned according to the flattened_range slices. This requires special handling during resharding, as illustrated by the worked sketch after this entry.

Parameters:
  • sh_tens (List[ShardedTensor]) – list of sharded tensors to convert

  • rank (int, optional) – current process rank passed to the PyT ShardedTensor. If None, the rank in the default process group is assumed.

  • load_legacy_1d_flatten_tensors (bool, optional) – flag indicating if 1-D flattened tensors should be loaded in a legacy way. Defaults to False.

Returns (TorchShardedTensor): PyT ShardedTensor containing all passed shards.
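A worked sketch of the global-shape adjustment described above for flattened (type 2) tensors; this is plain arithmetic with made-up sizes, not a call into the module.

```python
import math

# Illustrative numbers only: regular grid sharding of a [X, Y, Z] tensor.
global_shape = (8, 16, 32)   # [X, Y, Z]
local_shape = (4, 16, 32)    # [x, y, z]: two chunks along the first axis

# The flattened tensor is treated as [X // x, Y // y, Z // z, x * y * z] ...
grid_shape = tuple(g // l for g, l in zip(global_shape, local_shape))
chunk_numel = math.prod(local_shape)
adjusted_global_shape = grid_shape + (chunk_numel,)   # (2, 1, 1, 2048)

# ... with the last axis partitioned according to flattened_range slices,
# e.g. slice(0, 1024) and slice(1024, 2048) when each chunk is split in two.
```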

core.dist_checkpointing.strategies.torch.mcore_to_pyt_state_dict(
state_dict: Dict[str, List[core.dist_checkpointing.mapping.ShardedBase]],
is_loading: bool = False,
init_device: torch.device = torch.device('cpu'),
load_legacy_1d_flatten_tensors: bool = False,
) Dict[str, Union[torch.distributed._shard.sharded_tensor.ShardedTensor, io.BytesIO]]#

Convert state dict with ShardedTensors and ShardedObjects to state dict compatible with PyT Dist format.

Operates in-place and returns the original state dict.

Parameters:
  • state_dict (Dict[str, List[ShardedBase]]) – flattened state dict, where values are lists of either ShardedTensor or ShardedObjects.

  • is_loading (bool, optional) – flag indicating if loading or saving. Defaults to False.

  • init_device (torch.device, optional) – device to initialize potentially missing tensors during loading. Defaults to ‘cpu’.

Returns (Dict[str, Union[TorchShardedTensor, io.BytesIO]]): original dictionary with values converted either into PyT ShardedTensors or io.BytesIO.
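A hedged sketch of how this conversion is typically combined with the key-grouping helper _replace_state_dict_keys_with_sharded_keys when preparing a save; the flow mirrors what TorchDistSaveShardedStrategy does internally, and the import prefix is an assumption based on this page's module path.

```python
from core.dist_checkpointing.strategies.torch import (
    _replace_state_dict_keys_with_sharded_keys,
    mcore_to_pyt_state_dict,
)

def prepare_pyt_state_dict(sharded_state_dict):
    """Hypothetical helper: convert an MCore sharded state dict into a dict
    that torch.distributed.checkpoint can save."""
    # Group ShardedBase objects by key and keep only the main replicas.
    grouped, flat_mapping, rename_mapping = _replace_state_dict_keys_with_sharded_keys(
        sharded_state_dict, keep_only_main_replica=True
    )
    # Values become PyT ShardedTensors or io.BytesIO objects (in-place).
    pyt_state_dict = mcore_to_pyt_state_dict(grouped, is_loading=False)
    # The two mappings are needed later to restore the original keys/structure.
    return pyt_state_dict, flat_mapping, rename_mapping
```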

core.dist_checkpointing.strategies.torch._unwrap_pyt_sharded_tensor(
sh_ten: Union[torch.distributed._shard.sharded_tensor.ShardedTensor, core.dist_checkpointing.strategies.checkpointable.CheckpointableShardedTensor, core.dist_checkpointing.strategies.checkpointable.LocalShardsContainer, Any],
) Union[List[torch.Tensor], Any]#

Unwrap tensor from PyT ShardedTensor instance.

If prepend_axis_num was non-zero (which is specific to MCore ShardedTensor) then the tensor has additional singleton dimensions which should be squeezed.
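A plain-torch illustration of the squeeze described above; the shapes and the prepend_axis_num value are made up.

```python
import torch

prepend_axis_num = 2  # hypothetical value carried over from the MCore ShardedTensor

# A (4096, 1024) local chunk stored with two prepended singleton dimensions.
stored = torch.zeros(1, 1, 4096, 1024)

# Drop the prepended singleton axes to recover the original local tensor.
restored = stored.reshape(stored.shape[prepend_axis_num:])  # shape (4096, 1024)
```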

core.dist_checkpointing.strategies.torch._replace_state_dict_keys_with_sharded_keys(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
keep_only_main_replica: bool = False,
) Tuple[Dict[str, List[core.dist_checkpointing.mapping.ShardedBase]], torch.distributed.checkpoint._nested_dict.FLATTEN_MAPPING, Dict[str, List[str]]]#

Group ShardedBase objects by keys and return mappings required for recreating the original dict.

core.dist_checkpointing.strategies.torch._replace_sharded_keys_with_state_dict_keys(
state_dict: Dict[str, List[Union[torch.Tensor, io.BytesIO]]],
flat_mapping: torch.distributed.checkpoint._nested_dict.FLATTEN_MAPPING,
rename_mapping: Dict[str, List[str]],
)#

Inverse of _replace_state_dict_keys_with_sharded_keys.

core.dist_checkpointing.strategies.torch._restore_dict_types(
x: Union[dict, list, Any],
keys_template: Union[dict, list, Any],
)#

Recursively update x keys, based on keys_template.

class core.dist_checkpointing.strategies.torch.MCoreSavePlanner(
*args,
dedup_replicated_tensors: Optional[bool] = None,
can_run_decentralized_global_plan: bool = True,
**kwargs,
)#

Bases: torch.distributed.checkpoint.DefaultSavePlanner

Differs from the default planner by saving BytesIO objects on all ranks.

In the integration of MCore with the PyT Distributed format, BytesIO objects come from ShardedObjects, which should be treated as separate objects on each rank (not as objects common to all ranks).

Also, the objects are already packed in io.BytesIO, so there is no need to repack them in transform_object. A wiring sketch follows the method descriptions below.

Initialization

create_local_plan() torch.distributed.checkpoint.SavePlan#

Adds BytesIO write requests on non-coordinator ranks.

create_decentralized_global_plan(
local_plan: torch.distributed.checkpoint.SavePlan,
) torch.distributed.checkpoint.SavePlan#

Nothing to do, just some checks.

Parameters:

local_plan (SavePlan) – local plan to turn to a global plan (without interactions with other ranks)

Returns:

SavePlan - locally transformed plan equivalent to the plan that would be created by the coordinator

transform_object(
write_item: torch.distributed.checkpoint.WriteItem,
object: Any,
)#

Make no transformations - bytes objects are already serialized.
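A hedged sketch of wiring this planner into a torch.distributed.checkpoint save call; this is normally done internally by TorchDistSaveShardedStrategy, and the torch.distributed.checkpoint entry point shown (dcp.save) and its keyword arguments vary across PyTorch versions.

```python
import torch.distributed.checkpoint as dcp

from core.dist_checkpointing.strategies.torch import MCoreSavePlanner

def save_pyt_state_dict(pyt_state_dict, checkpoint_dir):
    """Hypothetical wrapper: `pyt_state_dict` must already be in PyT Dist form
    (e.g. produced by mcore_to_pyt_state_dict)."""
    dcp.save(
        pyt_state_dict,
        storage_writer=dcp.FileSystemWriter(checkpoint_dir),
        # Deduplication is assumed to be handled upstream via replica_id
        # (keep_only_main_replica-style), hence False here.
        planner=MCoreSavePlanner(dedup_replicated_tensors=False),
    )
```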

class core.dist_checkpointing.strategies.torch.MCoreLoadPlanner(
*args,
shapes_validation_sharded_tensors: Iterable[core.dist_checkpointing.mapping.ShardedTensor] = (),
allow_shape_mismatch_sharded_tensors: Optional[Dict[str, core.dist_checkpointing.mapping.ShardedTensor]] = None,
**kwargs,
)#

Bases: torch.distributed.checkpoint.DefaultLoadPlanner

Adds global shape validation to the default planner.

If global shape validation can be ignored (it shouldn't be!), the default load planner can be used.

Initialization

_validate_global_shapes(metadata, sharded_tensors)#
_temporarily_bypass_shape_validation()#

Temporarily set the size of tensors to their expected shapes to bypass DCP shape validation. This is used when validating the shapes during local plan creation.

create_local_plan() torch.distributed.checkpoint.LoadPlan#

Runs additional shape validation.

resolve_tensor(read_item: torch.distributed.checkpoint.ReadItem)#

Override to add FP8 support.

Narrowing a Float8Tensor can create non-contiguous tensors, and there are no copy kernels for such cases. This method creates a contiguous FP8 tensor so that the subsequent copy_ in FileSystemReader succeeds. Note that this requires tracking the original tensor (as the self._intermediate_read_item_and_target attribute) and restoring it in the commit_tensor method; see the plain-tensor sketch after this class.

commit_tensor(
read_item: torch.distributed.checkpoint.ReadItem,
tensor: torch.Tensor,
) None#

Restores the original FP8 tensor saved in resolve_tensor.
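A plain-tensor sketch of the contiguity problem resolve_tensor works around (shown with a regular float tensor rather than Float8Tensor): the reader copies into a contiguous staging tensor, and commit_tensor then copies the staged data back into the original non-contiguous target.

```python
import torch

target = torch.empty(8, 8)
view = target.narrow(1, 0, 4)   # narrowing along dim 1 gives a non-contiguous view
assert not view.is_contiguous()

# resolve_tensor-style step: hand the reader a contiguous staging tensor.
staging = view.contiguous()

# ... the reader performs copy_ into `staging` ...
staging.copy_(torch.randn(8, 4))

# commit_tensor-style step: restore the data into the original target view.
view.copy_(staging)
```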

class core.dist_checkpointing.strategies.torch.TorchDistSaveShardedStrategy(
backend: str,
version: int,
keep_only_main_replica: bool = True,
thread_count: int = 2,
cached_metadata: bool = False,
separation_hint: Optional[str] = None,
)#

Bases: core.dist_checkpointing.strategies.base.AsyncSaveShardedStrategy

Async save strategy for the PyT Distributed format.

The idea is to translate MCore ShardedTensors into PyT ShardedTensors and use the async-adjusted torch.distributed.checkpoint saving mechanism provided by the FileSystemWriterAsync writer.

Initialization

Adds parameters specific to the PyT Distributed format.

Parameters:
  • backend (str) – format backend string

  • version (int) – format version

  • keep_only_main_replica (bool, optional) – PyT Distributed has a mechanism for deduplication, but replica_id aware deduplication is more coherent. Default is True (recommended to keep it).

  • thread_count (int, optional) – threads to use during saving. Affects the number of files in the checkpoint (saving ranks * num_threads).

  • cached_metadata (bool, optional) – Enables using cached global metadata to avoid gathering local metadata on every checkpointing invocation.

  • separation_hint (str, optional) – If provided, all tensors whose keys have this prefix will be saved to a separate file.

async_save(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: pathlib.Path,
) core.dist_checkpointing.strategies.async_utils.AsyncRequest#

Translates MCore ShardedTensors to PyT ShardedTensors and saves in the PyT Distributed format.

Parameters:
  • sharded_state_dict (ShardedStateDict) – sharded state dict to save

  • checkpoint_dir (Path) – checkpoint directory

Returns: AsyncRequest encapsulating the save operation to be scheduled asynchronously
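A hedged usage sketch: constructing the strategy and passing it to the package-level save entry point. The backend string, version number, entry-point signature, and import prefix are assumptions; in Megatron-LM the prefix is typically megatron.core and the backend is usually registered as 'torch_dist'.

```python
from pathlib import Path

from core.dist_checkpointing import save  # assumed package-level entry point
from core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy

def save_model_checkpoint(model, checkpoint_dir):
    """Hypothetical helper; assumes `model.sharded_state_dict()` exists."""
    strategy = TorchDistSaveShardedStrategy('torch_dist', 1, thread_count=2)
    save(
        model.sharded_state_dict(),
        Path(checkpoint_dir),
        sharded_strategy=strategy,
    )
```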

_get_save_and_finalize_callbacks(
writer,
save_state_dict_ret,
) core.dist_checkpointing.strategies.async_utils.AsyncRequest#
can_handle_sharded_objects()#
core.dist_checkpointing.strategies.torch._get_filesystem_reader(
checkpoint_dir: Union[str, pathlib.Path],
cache_metadata: bool = False,
) torch.distributed.checkpoint.FileSystemReader#
class core.dist_checkpointing.strategies.torch.TorchDistLoadShardedStrategy#

Bases: core.dist_checkpointing.strategies.base.LoadShardedStrategy

Basic load strategy for the PyT Distributed format.

Initialization

load(
sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
checkpoint_dir: pathlib.Path,
) core.dist_checkpointing.mapping.StateDict#

Translates MCore ShardedTensors to PyT ShardedTensors and loads from the PyT Distributed format.

Parameters:
  • sharded_state_dict (ShardedStateDict) – sharded state dict with mapping information to instruct loading

  • checkpoint_dir (Path) – checkpoint directory

Returns: loaded state dict
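A hedged sketch of calling the strategy directly; in practice the package-level load entry point usually selects this strategy from the checkpoint metadata, and `model.sharded_state_dict()` is assumed to exist.

```python
from pathlib import Path

from core.dist_checkpointing.strategies.torch import TorchDistLoadShardedStrategy

def load_model_checkpoint(model, checkpoint_dir):
    """Hypothetical helper sketching direct use of the load strategy."""
    strategy = TorchDistLoadShardedStrategy()
    state_dict = strategy.load(model.sharded_state_dict(), Path(checkpoint_dir))
    # `state_dict` mirrors the structure of the input sharded state dict,
    # with real tensors/objects in place of the ShardedTensor placeholders.
    return state_dict
```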

load_tensors_metadata(
checkpoint_dir: pathlib.Path,
metadata: torch.distributed.checkpoint.metadata.Metadata = None,
)#

Uses the tensor metadata stored in the metadata file.

load_sharded_metadata(
checkpoint_dir: pathlib.Path,
) core.dist_checkpointing.mapping.ShardedStateDict#

Uses the tensor and object metadata stored in the metadata file.

remove_sharded_tensors(checkpoint_dir: str, key_prefix: str)#

Removes checkpoint files whose keys have the given prefix.

Performs the following steps (a usage sketch follows the list):

  1. checks whether there are files that start with the key_prefix

  2. loads metadata

  3. removes all entries from the metadata that start with the key_prefix

  4. resaves the new metadata and removes the old metadata

  5. removes the relevant files
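A hedged usage sketch; the checkpoint path and key prefix are hypothetical.

```python
from core.dist_checkpointing.strategies.torch import TorchDistLoadShardedStrategy

# Hypothetical example: drop every tensor whose key starts with 'optimizer.'
# from an existing checkpoint, rewriting the metadata file accordingly.
strategy = TorchDistLoadShardedStrategy()
strategy.remove_sharded_tensors('/path/to/checkpoint', key_prefix='optimizer.')
```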

can_handle_sharded_objects()#
check_backend_compatibility(loaded_version)#
check_version_compatibility(loaded_version)#