core.dist_checkpointing.strategies.torch#
Strategies using PyTorch distributed.checkpoint as an underlying format.
Module Contents#
Classes#
| Class | Summary |
|---|---|
| MCoreSavePlanner | Differs from the default planner by saving BytesIO objects on all ranks. |
| MCoreLoadPlanner | Adds global shape validation to the default planner. |
| TorchDistSaveShardedStrategy | Async save strategy for the PyT Distributed format. |
| TorchDistLoadShardedStrategy | Basic load strategy for the PyT Distributed format. |
Functions#
| Function | Summary |
|---|---|
| register_default_torch_strategies | Register default strategies related to PyT Distributed backend. |
| flatten_state_dict | Flattens state dict into a single-level dict. |
| sharded_tensor_to_torch_sharded_tensor | Convert MCore ShardedTensor to PyT ShardedTensor. PyT requires information about all chunks. |
| mcore_to_pyt_state_dict | Convert state dict with ShardedTensors and ShardedObjects to a state dict compatible with the PyT Dist format. |
| _unwrap_pyt_sharded_tensor | Unwrap tensor from a PyT ShardedTensor instance. |
| _replace_state_dict_keys_with_sharded_keys | Group ShardedBase objects by keys and return mappings required for recreating the original dict. |
| _replace_sharded_keys_with_state_dict_keys | Inverse of _replace_state_dict_keys_with_sharded_keys. |
| _restore_dict_types | Recursively update x keys based on keys_template. |
Data#
API#
- core.dist_checkpointing.strategies.torch.MSC_PREFIX#
‘msc://’
- core.dist_checkpointing.strategies.torch._metadata_fn: str#
‘.metadata’
- class core.dist_checkpointing.strategies.torch.MCoreMetadata#
- class core.dist_checkpointing.strategies.torch.MCoreSavePlan#
- core.dist_checkpointing.strategies.torch.register_default_torch_strategies()#
Register default strategies related to PyT Distributed backend.
- core.dist_checkpointing.strategies.torch.logger#
‘getLogger(…)’
- core.dist_checkpointing.strategies.torch.flatten_state_dict(
- state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
Flattens state dict into a single-level dict.
It's a copy of torch.distributed.checkpoint._nested_dict.flatten_state_dict which also accepts ShardedBase tensors as terminal objects.
- Parameters:
state_dict (ShardedStateDict) – state dict to be flattened
Returns (tuple): flattened state dict and a mapping allowing to recreate the original one
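A minimal usage sketch, assuming the module is importable under the full megatron.core package prefix (this page abbreviates module paths); the dotted key format in the comments is illustrative, not a documented guarantee.

```python
import torch

from megatron.core.dist_checkpointing.strategies.torch import flatten_state_dict

# Plain tensors stand in for ShardedTensor / ShardedObject leaves here;
# ShardedBase instances are treated as terminal objects in the same way.
nested = {
    'model': {'embedding': {'weight': torch.zeros(4)}},
    'args': {'iteration': 100},
}
flat_sd, flat_mapping = flatten_state_dict(nested)
# flat_sd is a single-level dict (e.g. keyed like 'model.embedding.weight');
# flat_mapping records how to rebuild the original nesting.
```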
- core.dist_checkpointing.strategies.torch.sharded_tensor_to_torch_sharded_tensor(
- sh_tens: List[core.dist_checkpointing.mapping.ShardedTensor],
- rank: Optional[int] = None,
- load_legacy_1d_flatten_tensors: bool = False,
Convert MCore ShardedTensor to PyT ShardedTensor. PyT requires information about all chunks.
At a high level, this function follows the logic of torch.distributed.fsdp._shard_utils._create_chunk_sharded_tensor. Additionally, it saves prepend_axis_num and has_flattened_range (specific to MCore) as attributes for further restoration in _unwrap_pyt_sharded_tensor.
NOTE: this function assumes regular (grid) sharding of the MCore ShardedTensor. The only local irregularities could be introduced with a flattened_range attribute.
This function handles two different types of ShardedTensors:
1. Non-flat regular ShardedTensors (not has_flattened_range)
2. N-D flattened ShardedTensors (has_flattened_range)
Type (1) tensors are saved according to their original shape. Type (2), however, requires a global shape adjustment for efficiency: we treat a [X, Y, Z] global shape tensor with local shape [x, y, z] as a [X // x, Y // y, Z // z, x * y * z] tensor with the last axis partitioned according to flattened_range slices. This will need special handling while resharding.
- Parameters:
sh_tens (List[ShardedTensor]) – list of sharded tensors to convert
rank (int, optional) – current process rank passed to PyT ShardedTensor. If None, assumes rank in the default pg.
load_legacy_1d_flatten_tensors (bool, optional) – flag indicating if 1-D flattened tensors should be loaded in a legacy way. Defaults to False.
Returns (TorchShardedTensor): PyT ShardedTensor containing all passed shards.
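A worked instance of the case (2) shape adjustment described above; the numbers are made up for illustration and are not taken from the source.

```python
# Global tensor [X, Y, Z] sharded into regular [x, y, z] chunks, then flattened.
X, Y, Z = 8, 16, 32
x, y, z = 4, 8, 32
adjusted_global_shape = (X // x, Y // y, Z // z, x * y * z)
print(adjusted_global_shape)  # (2, 2, 1, 1024)
# Each rank then owns a slice of the last (flattened) axis determined by its
# flattened_range, e.g. slice(0, 512) for the first half of its chunk.
```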
- core.dist_checkpointing.strategies.torch.mcore_to_pyt_state_dict(
- state_dict: Dict[str, List[core.dist_checkpointing.mapping.ShardedBase]],
- is_loading: bool = False,
- init_device: torch.device = torch.device('cpu'),
- load_legacy_1d_flatten_tensors: bool = False,
Convert state dict with ShardedTensors and ShardedObjects to state dict compatible with PyT Dist format.
Operates in-place and returns the original state dict.
- Parameters:
state_dict (Dict[str, List[ShardedBase]]) – flattened state dict, where values are lists of either ShardedTensor or ShardedObjects.
is_loading (bool, optional) – flag indicating if loading or saving. Defaults to False.
init_device (torch.device, optional) – device to initialize potentially missing tensors during loading. Defaults to ‘cpu’.
Returns (Dict[str, Union[TorchShardedTensor, io.BytesIO]]): original dictionary with values converted either into PyT ShardedTensors or io.BytesIO.
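A rough sketch of how this conversion fits into the pipeline. The three-value return of _replace_state_dict_keys_with_sharded_keys is inferred from the signature of its inverse and is an assumption, not documented API; the megatron.core package prefix is likewise assumed.

```python
from megatron.core.dist_checkpointing.strategies.torch import (
    _replace_state_dict_keys_with_sharded_keys,
    mcore_to_pyt_state_dict,
)

def to_pyt_format(sharded_state_dict):
    # Flatten and group MCore ShardedBase objects by key, keeping the mappings
    # needed later by _replace_sharded_keys_with_state_dict_keys.
    grouped_sd, flat_mapping, rename_mapping = _replace_state_dict_keys_with_sharded_keys(
        sharded_state_dict, keep_only_main_replica=True
    )
    # In-place conversion of values to torch ShardedTensors / io.BytesIO.
    pyt_sd = mcore_to_pyt_state_dict(grouped_sd, is_loading=False)
    return pyt_sd, flat_mapping, rename_mapping
```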
- core.dist_checkpointing.strategies.torch._unwrap_pyt_sharded_tensor(
- sh_ten: Union[torch.distributed._shard.sharded_tensor.ShardedTensor, core.dist_checkpointing.strategies.checkpointable.CheckpointableShardedTensor, core.dist_checkpointing.strategies.checkpointable.LocalShardsContainer, Any],
Unwrap tensor from PyT ShardedTensor instance.
If prepend_axis_num was non-zero (which is specific to MCore ShardedTensor), then the tensor has additional singleton dimensions which should be squeezed.
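A small illustration of the squeeze described above, with assumed shapes; it mimics the effect rather than calling the private helper.

```python
import torch

# An MCore ShardedTensor saved with prepend_axis_num=2 carries two extra
# singleton axes in its local chunk; unwrapping removes them again.
local_chunk = torch.empty(1, 1, 128, 64)
prepend_axis_num = 2
unwrapped = local_chunk.view(local_chunk.shape[prepend_axis_num:])
print(unwrapped.shape)  # torch.Size([128, 64])
```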
- core.dist_checkpointing.strategies.torch._replace_state_dict_keys_with_sharded_keys(
- sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
- keep_only_main_replica: bool = False,
Group ShardedBase objects by keys and return mappings required for recreating the original dict.
- core.dist_checkpointing.strategies.torch._replace_sharded_keys_with_state_dict_keys(
- state_dict: Dict[str, List[Union[torch.Tensor, io.BytesIO]]],
- flat_mapping: torch.distributed.checkpoint._nested_dict.FLATTEN_MAPPING,
- rename_mapping: Dict[str, List[str]],
Inverse of _replace_state_dict_keys_with_sharded_keys.
- core.dist_checkpointing.strategies.torch._restore_dict_types(
- x: Union[dict, list, Any],
- keys_template: Union[dict, list, Any],
Recursively update x keys based on keys_template.
- class core.dist_checkpointing.strategies.torch.MCoreSavePlanner(
- *args,
- dedup_replicated_tensors: Optional[bool] = None,
- can_run_decentralized_global_plan: bool = True,
- **kwargs,
Bases: torch.distributed.checkpoint.DefaultSavePlanner
Differs from the default planner by saving BytesIO objects on all ranks.
In the integration of MCore with PyT Distributed format, BytesIO objects come from ShardedObjects, which should be treated as separate objects on each rank (not common on all ranks).
Also, the objects are already packed in io.BytesIO, so no need to redo it in transform_object.
Initialization
- create_local_plan() → torch.distributed.checkpoint.SavePlan#
Adds IOBytes write request on non-coordinator ranks.
- create_decentralized_global_plan(
- local_plan: torch.distributed.checkpoint.SavePlan,
Nothing to do, just some checks.
- Parameters:
local_plan (SavePlan) – local plan to turn to a global plan (without interactions with other ranks)
- Returns:
SavePlan - locally transformed plan equivalent to the plan that would be created by the coordinator
- transform_object(
- write_item: torch.distributed.checkpoint.WriteItem,
- object: Any,
Make no transformations - bytes objects are already serialized.
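A minimal sketch of plugging MCoreSavePlanner into the torch.distributed.checkpoint save path. It assumes a recent PyTorch (dcp.save), an initialized process group, and a state dict already converted to the PyT Distributed format (e.g. via mcore_to_pyt_state_dict); within MCore this wiring is normally performed by TorchDistSaveShardedStrategy.

```python
import torch.distributed.checkpoint as dcp

from megatron.core.dist_checkpointing.strategies.torch import MCoreSavePlanner

def save_pyt_state_dict(pyt_state_dict, checkpoint_dir: str) -> None:
    # BytesIO objects (coming from ShardedObjects) are written on every rank,
    # not only on the coordinator.
    dcp.save(
        pyt_state_dict,
        storage_writer=dcp.FileSystemWriter(checkpoint_dir),
        planner=MCoreSavePlanner(dedup_replicated_tensors=False),
    )
```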
- class core.dist_checkpointing.strategies.torch.MCoreLoadPlanner(
- *args,
- shapes_validation_sharded_tensors: Iterable[core.dist_checkpointing.mapping.ShardedTensor] = (),
- allow_shape_mismatch_sharded_tensors: Optional[Dict[str, core.dist_checkpointing.mapping.ShardedTensor]] = None,
- **kwargs,
Bases: torch.distributed.checkpoint.DefaultLoadPlanner
Adds global shape validation to the default planner.
If global shape validation can be ignored (shouldn’t!), the default load planner can be used.
Initialization
- _validate_global_shapes(metadata, sharded_tensors)#
- _temporarily_bypass_shape_validation()#
Temporarily set the size of tensors to their expected shapes to bypass DCP shape validation. This is used when validating the shapes during local plan creation.
- create_local_plan() → torch.distributed.checkpoint.LoadPlan#
Runs additional shapes validation.
- resolve_tensor(read_item: torch.distributed.checkpoint.ReadItem)#
Override to add FP8 support.
Narrowing the Float8Tensor can create non-contiguous tensors and there are no copy kernels for such cases. This method creates a contiguous FP8 tensor so that the subsequent copy_ in FileSystemReader succeeds. Note that this requires tracking the original tensor (as the self._intermediate_read_item_and_target attribute) and restoring it in the commit_tensor method.
- commit_tensor(
- read_item: torch.distributed.checkpoint.ReadItem,
- tensor: torch.Tensor,
Restores the original FP8 tensor saved in resolve_tensor.
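A corresponding load-side sketch under the same assumptions (recent PyTorch, initialized process group, megatron.core prefix); in MCore this planner is normally driven by TorchDistLoadShardedStrategy, with shapes_validation_sharded_tensors filled with the tensors whose global shapes must be checked.

```python
import torch.distributed.checkpoint as dcp

from megatron.core.dist_checkpointing.strategies.torch import MCoreLoadPlanner

def load_pyt_state_dict(pyt_state_dict, checkpoint_dir: str, validated_tensors=()) -> None:
    # Loads in place into pyt_state_dict, validating global shapes of the given
    # MCore ShardedTensors against the checkpoint metadata.
    dcp.load(
        pyt_state_dict,
        storage_reader=dcp.FileSystemReader(checkpoint_dir),
        planner=MCoreLoadPlanner(shapes_validation_sharded_tensors=validated_tensors),
    )
```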
- class core.dist_checkpointing.strategies.torch.TorchDistSaveShardedStrategy(
- backend: str,
- version: int,
- keep_only_main_replica: bool = True,
- thread_count: int = 2,
- cached_metadata: bool = False,
- separation_hint: Optional[str] = None,
Bases: core.dist_checkpointing.strategies.base.AsyncSaveShardedStrategy
Async save strategy for the PyT Distributed format.
The idea is to translate MCore ShardedTensors into PyT ShardedTensors and use the async-adjusted torch.distributed.checkpoint saving mechanism provided by the FileSystemWriterAsync writer.
Initialization
Adds parameters specific to the PyT Distributed format.
- Parameters:
backend (str) – format backend string
version (int) – format version
keep_only_main_replica (bool, optional) – PyT Distributed has a mechanism for deduplication, but replica_id aware deduplication is more coherent. Default is True (recommended to keep it).
thread_count (int, optional) – threads to use during saving. Affects the number of files in the checkpoint (saving ranks * num_threads).
cached_metadata (bool, optional) – Enables using cached global metadata to avoid gathering local metadata on every checkpointing invocation.
separation_hint (str, optional) – If provided, all tensors whose keys have this prefix will be saved to a separate file.
- async_save(
- sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
- checkpoint_dir: pathlib.Path,
Translates MCore ShardedTensors to PyT ShardedTensors & saves in PyT Distributed format.
- Parameters:
sharded_state_dict (ShardedStateDict) – sharded state dict to save
checkpoint_dir (Path) – checkpoint directory
Returns: None
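A hedged usage sketch. The backend string and version ('torch_dist', 1) are assumptions about what register_default_torch_strategies registers, and an initialized process group plus an MCore sharded state dict are assumed.

```python
from pathlib import Path

from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy

def async_save_checkpoint(sharded_state_dict, checkpoint_dir: str) -> None:
    strategy = TorchDistSaveShardedStrategy('torch_dist', 1, thread_count=2)
    # Translates MCore ShardedTensors to PyT ShardedTensors and schedules the
    # asynchronous write into checkpoint_dir.
    strategy.async_save(sharded_state_dict, Path(checkpoint_dir))
```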
- _get_save_and_finalize_callbacks(
- writer,
- save_state_dict_ret,
- can_handle_sharded_objects()#
- core.dist_checkpointing.strategies.torch._get_filesystem_reader(
- checkpoint_dir: Union[str, pathlib.Path],
- cache_metadata: bool = False,
- class core.dist_checkpointing.strategies.torch.TorchDistLoadShardedStrategy#
Bases: core.dist_checkpointing.strategies.base.LoadShardedStrategy
Basic load strategy for the PyT Distributed format.
Initialization
- load(
- sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
- checkpoint_dir: pathlib.Path,
Translates MCore ShardedTensors to PyT ShardedTensors & loads from the PyT Distributed format.
- Parameters:
sharded_state_dict (ShardedStateDict) – sharded state dict with mapping information to instruct loading
checkpoint_dir (Path) – checkpoint directory
Returns: loaded state dict
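A minimal load sketch under the same assumptions (initialized process group, megatron.core package prefix); sharded_state_dict provides the mapping information described above.

```python
from pathlib import Path

from megatron.core.dist_checkpointing.strategies.torch import TorchDistLoadShardedStrategy

def load_checkpoint(sharded_state_dict, checkpoint_dir: str):
    strategy = TorchDistLoadShardedStrategy()
    # Returns the loaded state dict, with values resolved from the checkpoint.
    return strategy.load(sharded_state_dict, Path(checkpoint_dir))
```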
- load_tensors_metadata(
- checkpoint_dir: pathlib.Path,
- metadata: torch.distributed.checkpoint.metadata.Metadata = None,
Uses tensors metadata stored in the metadata file.
- load_sharded_metadata(
- checkpoint_dir: pathlib.Path,
Uses tensors and objects metadata stored in the metadata file.
- remove_sharded_tensors(checkpoint_dir: str, key_prefix: str)#
Removes checkpoint files whose keys have the given prefix.
Performs the following steps:
1. checks whether there are files that start with the key_prefix
2. loads metadata
3. removes all entries from the metadata that start with the key_prefix
4. resaves the new metadata and removes the old metadata
5. removes the relevant files
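For illustration, a hypothetical call removing all tensors under an assumed key prefix (both the path and the prefix are made up, not from the source):

```python
from megatron.core.dist_checkpointing.strategies.torch import TorchDistLoadShardedStrategy

strategy = TorchDistLoadShardedStrategy()
# Drops every checkpoint file and metadata entry whose key starts with the prefix.
strategy.remove_sharded_tensors('/path/to/checkpoint', key_prefix='decoder.layers.')
```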
- can_handle_sharded_objects()#
- check_backend_compatibility(loaded_version)#
- check_version_compatibility(loaded_version)#