core.dist_checkpointing.strategies.filesystem_async#

Storage writer for PyT Distributed format allowing asynchronous save.

Module Contents#

Classes#

FileSystemWriterAsync

Async-enabled implementation of FileSystemWriter using file I/O.

Functions#

_get_write_results_queue

_split_by_size_and_type

Splits write items according to item size into close to uniform bins.

_split_by_separation_hint

Splits buckets into those whose keys begin with the separation_hint and those whose keys do not.

_item_size

Calculates size (in bytes) of a single write item.

_process_memory

Get memory used by the current process.

Data#

API#

core.dist_checkpointing.strategies.filesystem_async.logger#

‘getLogger(…)’

core.dist_checkpointing.strategies.filesystem_async.WriteBucket#

None

core.dist_checkpointing.strategies.filesystem_async._results_queue#

None

core.dist_checkpointing.strategies.filesystem_async._get_write_results_queue()#

class core.dist_checkpointing.strategies.filesystem_async.FileSystemWriterAsync(
path: Union[str, os.PathLike],
*args,
separation_hint: Optional[str] = None,
use_msc: bool = False,
**kwargs,
)#

Bases: torch.distributed.checkpoint.FileSystemWriter

Async-enabled implementation of FileSystemWriter using file I/O.

This class does not spawn the async process itself but relies on an external async mechanism.

Flow:

  1. Call write_data

  2. Externally start an async process with get_save_function_and_args and its arguments.

  3. The async function writer_proxy_func calls write_preloaded_data across multiple processes.

  4. Once saving is finalized on all ranks, call super().finish with the results stored in self.writer_result.

Note: Step (3) can also be executed synchronously.

Currently, it is assumed that a separate writer is created for each ckpt save (intermediate state is stored as writer attributes).
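
A minimal usage sketch of this flow, assuming `prepare_write_data` performs step (1) and that the caller supplies the `plan`/`planner` objects and drives the (a)synchronous execution; all concrete values below are illustrative, not a prescribed API:

```python
import torch.multiprocessing as mp

# Illustrative only: plan/planner come from a PyT Distributed compatible planner.
writer = FileSystemWriterAsync("/tmp/ckpt")

# (1) Copy data to CPU and build the local write buckets.
writer.prepare_write_data(plan, planner)

# (2)+(3) Obtain the save function and run it, here in a separate process.
save_fn, stage_fn, save_args = writer.get_save_function_and_args()
if save_fn is not None:
    proc = mp.Process(target=save_fn, args=tuple(save_args))
    proc.start()
    proc.join()

# (4) Collect per-process write results; super().finish() is called once
#     saving has been finalized on all ranks.
write_results = writer.retrieve_write_results()
```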

Initialization

prepare_write_data(
plan: torch.distributed.checkpoint.planner.SavePlan,
planner: torch.distributed.checkpoint.planner.SavePlanner,
) None#

First stage of async saving. Copy data to CPU and plan the local saving.

Parameters:
  • plan (SavePlan) – save plan generated by the PyT Distributed compatible planner

  • planner (SavePlanner) – save planner used to resolve the bytes and tensor data

Returns: None, but stores the save plan in self.write_buckets

get_save_function_and_args() Tuple[Optional[Callable], Optional[Callable], List]#

Get function that saves the data to storage along with its arguments. Allows the external caller to apply the save function synchronously or asynchronously.

Returns: None if there is nothing to write on this rank, or a tuple of: 1) the function that saves the data, 2) the function that stages the GPU tensors to a destination for async checkpointing (this function should be self-contained), and 3) the arguments to the function in 1).
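
For the synchronous path noted above (step 3 can be executed synchronously), a hedged sketch of applying the returned functions directly; the staging-before-saving order and the unpacking below are assumptions based on the signature:

```python
# `writer` as in the flow sketch above (illustrative, not a prescribed API).
save_fn, stage_fn, save_args = writer.get_save_function_and_args()
if save_fn is None:
    pass  # nothing to write on this rank
else:
    if stage_fn is not None:
        stage_fn()           # stage GPU tensors to their destination (self-contained)
    save_fn(*save_args)      # perform the actual write synchronously
```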

static preload_tensors(
write_buckets: List[core.dist_checkpointing.strategies.filesystem_async.WriteBucket],
non_blocking=True,
) List[core.dist_checkpointing.strategies.filesystem_async.WriteBucket]#

Preloads the state_dict tensors referenced by the write buckets to host (CPU) memory.

Parameters:
  • write_buckets (List) – List of WriteBucket objects that define what to save in a checkpoint.

  • non_blocking (bool, optional) – knob to enable pinned D2H memcpy. Default is True.
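
A hedged sketch of the per-tensor device-to-host copy this implies (the real method walks every tensor in the write buckets; tying `pin_memory` to `non_blocking` is an assumption):

```python
import torch

def _preload_one(tensor: torch.Tensor, non_blocking: bool = True) -> torch.Tensor:
    """Copy a single (possibly GPU) tensor to host memory."""
    if tensor.is_cuda:
        # A non-blocking D2H copy can overlap with compute, but requires a pinned
        # host buffer and a later synchronization before the copy is consumed.
        host = torch.empty(
            tensor.shape, dtype=tensor.dtype, device="cpu", pin_memory=non_blocking
        )
        host.copy_(tensor, non_blocking=non_blocking)
        return host
    return tensor
```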

static write_preloaded_data_multiproc(
transform_list: List[torch.distributed.checkpoint.filesystem._StorageWriterTransforms],
use_msc: bool,
rank: int,
write_buckets: List[core.dist_checkpointing.strategies.filesystem_async.WriteBucket],
global_results_queue: torch.multiprocessing.Queue,
) None#

Performs data saving to storage with multiple processes.

Starts a predefined number of processes and uses two queues to make sure the results are complete:

  • local_results_queue - used to send the actual results

  • count_queue - a small queue used to mark a worker as completed

Using just one queue would not allow proper exception handling.

This method is meant to be run in a forked subprocess. Triggering GC during execution leads to CUDA errors (cleaning up tensors owned by the parent process). To prevent this, we disable the GC explicitly for this function with _disable_gc.

Parameters:
  • write_buckets (List[WriteBucket]) – write plan

  • global_results_queue (mp.Queue) – mp.Queue to collect Dict[List[WriteResults]] (or an Exception) from parallel write processes to the main training process

Returns: None
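
A simplified sketch of the two-queue pattern described above: the results queue carries the payload (or an exception), while the joinable count queue signals completion, so the parent can tell a failed worker apart from a hung one. `do_local_write` is a hypothetical stand-in for the actual per-bucket write:

```python
import torch.multiprocessing as mp

def do_local_write(idx: int) -> str:
    """Hypothetical stand-in for the actual per-bucket write."""
    return f"result-{idx}"

def _worker(idx, results_queue, count_queue):
    try:
        results_queue.put((idx, do_local_write(idx)))
    except Exception as exc:
        results_queue.put((idx, exc))   # report the failure to the parent
    finally:
        count_queue.get()
        count_queue.task_done()         # completion mark, independent of the payload

def _run_workers(n: int):
    ctx = mp.get_context("fork")
    results_queue, count_queue = ctx.SimpleQueue(), ctx.JoinableQueue()
    for i in range(n):
        count_queue.put(i)              # one token per worker
    procs = [ctx.Process(target=_worker, args=(i, results_queue, count_queue))
             for i in range(n)]
    for p in procs:
        p.start()
    count_queue.join()                  # wait for all completion marks
    results = [results_queue.get() for _ in range(n)]
    for p in procs:
        p.join()
    return results
```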

static write_preloaded_data(
transform_list: List[torch.distributed.checkpoint.filesystem._StorageWriterTransforms],
local_proc_idx: int,
write_bucket: core.dist_checkpointing.strategies.filesystem_async.WriteBucket,
results_queue: torch.multiprocessing.SimpleQueue,
count_queue: torch.multiprocessing.JoinableQueue,
use_fsync: bool,
**kwargs,
) None#

Performs actual data saving to storage.

Parameters:
  • local_proc_idx (int) – index of a local process that performs writing

  • write_bucket (WriteBucket) – data to write to storage

  • results_queue (mp.Queue) – queue to return the write results to the proxy checkpoint process.

  • count_queue (mp.JoinableQueue) – queue used to mark the worker task as completed

  • use_fsync (bool) – if True, calls os.fsync at the end of saving

Returns: None, the write results are put into the queue
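
A hedged sketch of the final write step and the `use_fsync` behaviour (serialization of tensors/bytes and the per-item WriteResult bookkeeping are omitted):

```python
import os
from typing import List

def _write_payloads(path: str, payloads: List[bytes], use_fsync: bool) -> None:
    """Append already-serialized payloads to a single checkpoint file."""
    with open(path, "wb") as stream:
        for payload in payloads:
            stream.write(payload)
        stream.flush()
        if use_fsync:
            # Force the page cache to stable storage before the write is
            # reported back through the results queue.
            os.fsync(stream.fileno())
```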

abstractmethod write_data(
plan: torch.distributed.checkpoint.planner.SavePlan,
planner: torch.distributed.checkpoint.planner.SavePlanner,
) torch.futures.Future[List[torch.distributed.checkpoint.storage.WriteResult]]#

Write all items from plan.

retrieve_write_results() Union[List[torch.distributed.checkpoint.storage.WriteResult], torch.distributed.checkpoint.api.WRAPPED_EXCEPTION]#

Turns the latest dict of write results from self.results_queue into a single results list. Includes an error check.

Returns (Union[List[WriteResult], WRAPPED_EXCEPTION]): the list of write results from all local processes performing the save, or a WRAPPED_EXCEPTION if an exception was raised during the writing process.

prepare_decentralized_global_plan(
local_plan: torch.distributed.checkpoint.planner.SavePlan,
) torch.distributed.checkpoint.planner.SavePlan#

Instead of assigning indices by plan order, uses PyT rank (same outcome).

Parameters:

local_plan (SavePlan) – local plan to turn into a global plan (without interaction with other ranks)

Returns:

SavePlan - locally transformed plan equivalent to the plan that would be created by the coordinator
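
A hedged sketch of the idea, assuming the same `_StoragePrefix` storage_data that the PyT Distributed coordinator would assign by plan order; since the rank is already a unique index, the prefix can be derived locally without communication:

```python
import dataclasses
import torch
from torch.distributed.checkpoint.filesystem import _StoragePrefix
from torch.distributed.checkpoint.planner import SavePlan

def _decentralized_global_plan(local_plan: SavePlan) -> SavePlan:
    # Derive the storage prefix from the PyT rank instead of the plan's
    # position in the coordinator-gathered list (same outcome).
    prefix = _StoragePrefix(f"__{torch.distributed.get_rank()}_")
    return dataclasses.replace(local_plan, storage_data=prefix)
```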

finish(
metadata: torch.distributed.checkpoint.metadata.Metadata,
results: List[List[torch.distributed.checkpoint.storage.WriteResult]],
) None#

Finish the checkpointing process.

Parameters:
  • metadata (Metadata) – metadata to save

  • results (List[List[WriteResult]]) – results to save

prepare_local_plan(
plan: torch.distributed.checkpoint.planner.SavePlan,
) torch.distributed.checkpoint.planner.SavePlan#

Prepare the local plan for the checkpointing process.

property checkpoint_id: Union[str, os.PathLike]#

Return the checkpoint_id that will be used to save the checkpoint.

classmethod validate_checkpoint_id(
checkpoint_id: Union[str, os.PathLike],
) bool#

Validate the checkpoint_id that will be used to save the checkpoint.

This method is available in PyTorch 2.3 and above.

core.dist_checkpointing.strategies.filesystem_async._split_by_size_and_type(
bins: int,
items: List[torch.distributed.checkpoint.planner.WriteItem],
) List[List[torch.distributed.checkpoint.planner.WriteItem]]#

Splits write items according to item size into close to uniform bins.

Same as torch.distributed.checkpoint.filesystem._split_by_size_and_type, but with a fixed _item_size function.

Parameters:
  • bins (int) – number of bins to split into

  • items (List[WriteItem]) – list of write items

Returns (List[List[WriteItem]]): write items split into bins
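
The splitting is essentially a greedy balanced partition: process items from largest to smallest and place each into the currently lightest bin. A minimal sketch on plain sizes (the real function works on `WriteItem`s and handles byte-IO items specially):

```python
from typing import List

def _greedy_bins(sizes: List[int], bins: int) -> List[List[int]]:
    """Split sizes into `bins` groups with close-to-uniform total size."""
    buckets: List[List[int]] = [[] for _ in range(bins)]
    totals = [0] * bins
    for size in sorted(sizes, reverse=True):
        idx = totals.index(min(totals))  # currently lightest bin
        buckets[idx].append(size)
        totals[idx] += size
    return buckets
```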

core.dist_checkpointing.strategies.filesystem_async._split_by_separation_hint(
buckets: List[List[torch.distributed.checkpoint.planner.WriteItem]],
separation_hint: Optional[str] = None,
) Dict[str, List[List[torch.distributed.checkpoint.planner.WriteItem]]]#

Splits buckets into those whose keys begin with the separation_hint and those whose keys do not.

Parameters:
  • buckets (List[List[WriteItem]]) – buckets to split

  • separation_hint (Optional[str]) – optional prefix to split on

Returns (Dict[str, List[List[WriteItem]]]): a dictionary mapping the prefix to the relevant buckets
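
A hedged sketch of the shape of the result, assuming plain string keys and a mapping from the empty prefix (no hint) and the hint itself to the corresponding per-bucket splits; the real function inspects `WriteItem` indices rather than strings:

```python
from typing import Dict, List

def _split_by_prefix(
    buckets: List[List[str]], hint: str
) -> Dict[str, List[List[str]]]:
    """Illustrative only: partition each bucket's keys by the prefix hint."""
    result: Dict[str, List[List[str]]] = {"": [], hint: []}
    for bucket in buckets:
        result[""].append([key for key in bucket if not key.startswith(hint)])
        result[hint].append([key for key in bucket if key.startswith(hint)])
    return result
```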

core.dist_checkpointing.strategies.filesystem_async._item_size(
item: torch.distributed.checkpoint.planner.WriteItem,
) int#

Calculates size (in bytes) of a single write item.

Same as torch.distributed.checkpoint.filesystem._item_size, but fixes the chunk size computation (using item.tensor_data.chunk.sizes).

Parameters:

item (WriteItem) – write item to compute the size of

Returns (int): size of an item in bytes
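
A hedged sketch of the size computation for a tensor write item, using `item.tensor_data.chunk.sizes` as documented (the `dtype.itemsize` accessor assumes a recent PyTorch; byte-IO items would need the actual buffer length instead):

```python
import math
from torch.distributed.checkpoint.planner import WriteItem

def _tensor_item_nbytes(item: WriteItem) -> int:
    """Bytes occupied by the chunk described by a tensor write item."""
    num_elements = math.prod(item.tensor_data.chunk.sizes)  # elements in this chunk
    return num_elements * item.tensor_data.properties.dtype.itemsize
```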

core.dist_checkpointing.strategies.filesystem_async._process_memory() int#

Get memory used by the current process.

Returns (int): memory used by the current process
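
One common way to obtain this value is the resident set size reported by `psutil`; a sketch (the exact metric used here is an implementation detail):

```python
import os
import psutil

def _process_memory_bytes() -> int:
    """Resident set size (in bytes) of the current process."""
    return psutil.Process(os.getpid()).memory_info().rss
```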