core.dist_checkpointing.strategies.state_dict_saver#

State dict saver for the PyT Distributed format, allowing asynchronous save.

Module Contents#

Functions#

_compare_dataclasses

save_state_dict_async_plan

First stage of saving a state dict to storage.

verify_global_md_reuse

Verifies that global metadata reuse is possible by checking that the plans loaded from the checkpoint are consistent, i.e. that the same settings are used when resuming training.

save_state_dict_async_finalize

Finalization of save_state_dict_async_plan.

Data#

API#

core.dist_checkpointing.strategies.state_dict_saver.logger#

‘getLogger(…)’

core.dist_checkpointing.strategies.state_dict_saver._compare_dataclasses(obj1, obj2)#
core.dist_checkpointing.strategies.state_dict_saver.save_state_dict_async_plan(
state_dict: torch.distributed.checkpoint.metadata.STATE_DICT_TYPE,
storage_writer: core.dist_checkpointing.strategies.filesystem_async.FileSystemWriterAsync,
process_group: Optional[torch.distributed.ProcessGroup] = None,
coordinator_rank: int = 0,
planner: Optional[Union[torch.distributed.checkpoint.planner.SavePlanner, core.dist_checkpointing.strategies.torch.MCoreSavePlanner]] = None,
cached_ckpt_structure: Optional[Tuple[torch.distributed.checkpoint.planner.SavePlan, torch.distributed.checkpoint.planner.SavePlan, bool]] = None,
loaded_all_plans: Optional[List[torch.distributed.checkpoint.planner.SavePlan]] = None,
) → Tuple[Tuple[core.dist_checkpointing.strategies.filesystem_async.FileSystemWriterAsync, Union[torch.distributed.checkpoint.metadata.Metadata, None], torch.distributed.checkpoint.utils._DistWrapper], torch.distributed.checkpoint.planner.SavePlan, bool]#

First stage of saving a state dict to storage.

This is an asynchronous adaptation of torch.distributed.checkpoint.state_dict_saver. To support asynchronous save, saving is split into three parts:

  1. Planning

  2. Actual saving

  3. Finalization

Of these, step (2) must happen asynchronously. This function implements the first step.

The planning part consists of several steps, described here: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.SavePlanner

Parameters:
  • state_dict (STATE_DICT_TYPE) – state dict to save

  • storage_writer (FileSystemWriterAsync) – currently only an instance of FileSystemWriterAsync is supported

  • process_group (dist.ProcessGroup, optional) – process group used for save planning

  • coordinator_rank (int, optional) – coordinator rank for planning. Defaults to 0.

  • planner (SavePlanner, optional) – save planner for torch.distributed.checkpoint format

  • cached_ckpt_structure (Tuple[SavePlan, SavePlan, bool], optional) – The elements of this tuple are used in the following order:

    – cached_central_plan (SavePlan): the globally coordinated save plan cached in the previous iteration

    – cached_local_plan (SavePlan): the local save plan cached in the previous iteration

    – validated_cache_reuse (bool): whether global_metadata and the planning dict are consistent over iterations

  • loaded_all_plans (List[SavePlan], optional) – save plans loaded from the previous checkpoint's metadata, used to verify whether global metadata can be reused

Returns: Tuple of:

  • the storage writer (the one passed as input)

  • the metadata from planning (or None if cached global metadata is reused)

  • the distributed wrapper used for planning

The return value of this function should be passed as input to save_state_dict_async_finalize and as the cached plan, to skip the reduce_scatter step during planning.
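A hedged sketch of how the three stages might be composed in caller code (the module path, the checkpoint directory, and the toy state dict are assumptions; the asynchronous write in stage (2) is driven by the storage_writer and is only indicated by a comment):

```python
import torch
import torch.distributed as dist

from core.dist_checkpointing.strategies.filesystem_async import FileSystemWriterAsync
from core.dist_checkpointing.strategies.state_dict_saver import (
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)

# Assumed to run inside an initialized torch.distributed process group.
assert dist.is_initialized()

# Hypothetical state dict; in practice this is the sharded state dict prepared
# by the surrounding checkpointing strategy.
state_dict = {"weight": torch.zeros(16)}

storage_writer = FileSystemWriterAsync("/tmp/ckpt_iter_0000000")  # assumed target dir

# Stage (1): planning. The inner tuple of the result is later passed to
# save_state_dict_async_finalize; the remaining items can feed the plan cache.
(storage_writer, global_metadata, dist_wrapper), central_plan, cache_reuse = (
    save_state_dict_async_plan(
        state_dict,
        storage_writer,
        process_group=None,   # default process group
        coordinator_rank=0,
    )
)

# Stage (2): the actual writes are performed asynchronously by the
# storage_writer (e.g. in a background process managed by the caller);
# not shown in this sketch.

# Stage (3): finalization, once the asynchronous writes have completed.
save_state_dict_async_finalize(storage_writer, global_metadata, dist_wrapper)
```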

core.dist_checkpointing.strategies.state_dict_saver.verify_global_md_reuse(
loaded_all_plans: List[torch.distributed.checkpoint.planner.SavePlan],
local_plan: torch.distributed.checkpoint.planner.SavePlan,
rank: int,
dist_wrapper: torch.distributed.checkpoint.utils._DistWrapper,
) → bool#

Verifies that global metadata reuse is possible by checking that the plans loaded from the checkpoint are consistent, i.e. that the same settings are used when resuming training.

Parameters:
  • loaded_all_plans – List[SavePlan], The loaded plans from the checkpoint (stored in checkpoint metadata).

  • local_plan – SavePlan, The local save plan.

  • rank – Current process rank.

  • dist_wrapper (_DistWrapper) – distributed wrapper created during planning

Returns: True iff global metadata reuse is possible.
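A minimal usage sketch, assuming loaded_all_plans has been read from the previous checkpoint's metadata and local_plan/dist_wrapper come from the current planning step (all three names are placeholders, not part of the API):

```python
from core.dist_checkpointing.strategies.state_dict_saver import verify_global_md_reuse

# `loaded_all_plans`, `local_plan` and `dist_wrapper` are assumed to be available
# from the previous checkpoint's metadata and from the current planning step.
can_reuse_global_md = verify_global_md_reuse(
    loaded_all_plans=loaded_all_plans,
    local_plan=local_plan,
    rank=dist_wrapper.rank,
    dist_wrapper=dist_wrapper,
)

if not can_reuse_global_md:
    # Settings changed since the checkpoint was written: recompute global
    # metadata during planning instead of reusing the stored plans.
    loaded_all_plans = None
```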

core.dist_checkpointing.strategies.state_dict_saver.save_state_dict_async_finalize(
storage_writer: core.dist_checkpointing.strategies.filesystem_async.FileSystemWriterAsync,
global_metadata: torch.distributed.checkpoint.metadata.Metadata,
dist_wrapper: torch.distributed.checkpoint.utils._DistWrapper,
) → None#

Finalization of save_state_dict_async_plan.

The input arguments are the same as the output of save_state_dict_async_plan; the write results are retrieved from the storage_writer.

Parameters:
  • storage_writer (FileSystemWriterAsync) – storage writer used for planning

  • global_metadata (Metadata) – metadata created during planning

  • dist_wrapper (_DistWrapper) – distributed wrapper created during planning

Returns: None
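To illustrate how the plan/finalize pair and the cached checkpoint structure might be threaded through successive saves, here is a hedged sketch (the loop, the directory pattern, and the state-dict iterable are illustrative; how the cached local plan is tracked between iterations is planner-dependent and is left as None here):

```python
from core.dist_checkpointing.strategies.filesystem_async import FileSystemWriterAsync
from core.dist_checkpointing.strategies.state_dict_saver import (
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)

cached_central_plan, cached_local_plan, validated_cache_reuse = None, None, False
last_global_metadata = None

for iteration, sharded_state_dict in enumerate(state_dicts_to_save):  # hypothetical iterable
    storage_writer = FileSystemWriterAsync(f"/ckpts/iter_{iteration:07d}")

    (storage_writer, global_metadata, dist_wrapper), central_plan, cache_reuse = (
        save_state_dict_async_plan(
            sharded_state_dict,
            storage_writer,
            cached_ckpt_structure=(
                cached_central_plan,
                cached_local_plan,
                validated_cache_reuse,
            ),
        )
    )

    # Remember the coordinated plan and the reuse flag so that the next
    # planning call can skip the reduce_scatter step when nothing changed.
    cached_central_plan, validated_cache_reuse = central_plan, cache_reuse

    # When cached global metadata is reused, the returned metadata may be None,
    # so keep the latest non-None Metadata for finalization.
    if global_metadata is not None:
        last_global_metadata = global_metadata

    # ... asynchronous write performed by storage_writer (not shown) ...

    save_state_dict_async_finalize(storage_writer, last_global_metadata, dist_wrapper)
```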