core.dist_checkpointing.strategies.filesystem_async#
Storage writer for PyT Distributed format allowing asynchronous save.
Module Contents#
Classes#
| Class | Description |
|---|---|
| FileSystemWriterAsync | Async-enabled implementation of FileSystemWriter using file I/O. |
Functions#
| Function | Description |
|---|---|
| _split_by_size_and_type | Splits write items according to item size into close to uniform bins. |
| _split_by_separation_hint | Splits buckets into those whose keys begin with the separation_hint and those whose keys do not. |
| _item_size | Calculates size (in bytes) of a single write item. |
| _process_memory | Get memory used by current process. |
Data#
API#
- core.dist_checkpointing.strategies.filesystem_async.logger#
'getLogger(…)'
- core.dist_checkpointing.strategies.filesystem_async.WriteBucket#
None
- core.dist_checkpointing.strategies.filesystem_async._results_queue#
None
- core.dist_checkpointing.strategies.filesystem_async._get_write_results_queue()#
- class core.dist_checkpointing.strategies.filesystem_async.FileSystemWriterAsync(
- path: Union[str, os.PathLike],
- *args,
- separation_hint: Optional[str] = None,
- use_msc: bool = False,
- **kwargs,
- )#
Bases: torch.distributed.checkpoint.FileSystemWriter
Async-enabled implementation of FileSystemWriter using file I/O.
This class does not spawn the async process itself but relies on an external async mechanism.
Flow:
1. Call write_data.
2. Externally start an async process with get_save_function_and_args and its arguments.
3. The async function writer_proxy_func calls write_preloaded_data across multiple processes.
4. Once saving is finalized on all ranks, call super().finish with the results stored in self.writer_result.
Note: Step (3) can also be executed synchronously.
Currently, it is assumed that a separate writer is created for each checkpoint save (intermediate state is stored as writer attributes).
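A rough, hypothetical sketch of this flow from the caller's side (the plan, planner, and metadata objects are assumed to come from the standard torch.distributed.checkpoint planning phase; step (3) is run synchronously here, which the note above permits):

```python
# Hypothetical driver sketch; `plan`, `planner`, and `metadata` are assumed
# to be produced by the surrounding PyT Distributed planning machinery.
writer = FileSystemWriterAsync("/tmp/ckpt")

writer.prepare_write_data(plan, planner)  # first stage: copy data to CPU
save_fn, stage_fn, save_args = writer.get_save_function_and_args()
if save_fn is not None:
    save_fn(*save_args)  # step (3), run synchronously; an async caller would offload this

results = writer.retrieve_write_results()  # per-process WriteResults or a wrapped exception
writer.finish(metadata, [results])         # step (4): write the global metadata file
```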
Initialization
- prepare_write_data(
- plan: torch.distributed.checkpoint.planner.SavePlan,
- planner: torch.distributed.checkpoint.planner.SavePlanner,
- )#
First stage of async saving. Copy data to CPU and plan the local saving.
- Parameters:
plan (SavePlan) – save plan generated by the PyT Distributed compatible planner
planner (SavePlanner) – save planner used to resolve the bytes and tensor data
Returns: None, but stores the save plan in self.write_buckets.
- get_save_function_and_args() Tuple[Optional[Callable], Optional[Callable], List]#
Get function that saves the data to storage along with its arguments. Allows the external caller to apply the save function synchronously or asynchronously.
Returns: None (if there is nothing to write on this rank) or a tuple of: 1) the function that saves the data, 2) the function that stages the GPU tensors to a destination for async checkpointing (this function should be self-contained), and 3) the arguments to the function in 1).
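Because the save function is self-contained, an external caller could also run it in a separate process. A hedged sketch continuing the driver example above (the no-argument call shape of the staging function is an assumption):

```python
import torch.multiprocessing as mp

save_fn, stage_fn, save_args = writer.get_save_function_and_args()
if stage_fn is not None:
    stage_fn()  # assumed call shape: stage GPU tensors to the host first
if save_fn is not None:
    ctx = mp.get_context("fork")
    proc = ctx.Process(target=save_fn, args=save_args)
    proc.start()  # training can proceed while the write runs in the background
    # ... later, before finalizing the checkpoint ...
    proc.join()
```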
- static preload_tensors(
- write_buckets: List[core.dist_checkpointing.strategies.filesystem_async.WriteBucket],
- non_blocking=True,
- )#
Preloads tensors in state_dict to host memory via CPU memory.
- Parameters:
write_buckets (List) – List of WriteBucket objects that define what to save in a checkpoint.
non_blocking (bool, optional) – knob to enable pinned D2H memcpy. Default is True.
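The pinned D2H copy that non_blocking=True enables can be illustrated with plain PyTorch (not this module's code):

```python
import torch

gpu_tensor = torch.randn(1024, 1024, device="cuda")
# Page-locked (pinned) host memory is required for a truly asynchronous D2H copy.
host_tensor = torch.empty(gpu_tensor.shape, dtype=gpu_tensor.dtype, pin_memory=True)
host_tensor.copy_(gpu_tensor, non_blocking=True)  # returns before the copy finishes
torch.cuda.synchronize()  # wait for the copy before reading host_tensor
```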
- static write_preloaded_data_multiproc(
- transform_list: List[torch.distributed.checkpoint.filesystem._StorageWriterTransforms],
- use_msc: bool,
- rank: int,
- write_buckets: List[core.dist_checkpointing.strategies.filesystem_async.WriteBucket],
- global_results_queue: torch.multiprocessing.Queue,
- )#
Performs saving data to storage with multiple processes.
Starts a predefined number of processes and uses two queues to make sure the results are complete:
local_results_queue - to send the actual results
count_queue - a small queue to mark a worker as completed
Using just one queue would not allow proper exception handling.
This method is meant to be run in a forked subprocess. Triggering GC during execution leads to CUDA errors (cleaning up tensors owned by the parent process). To prevent this, we disable the GC explicitly for this function with _disable_gc.
- Parameters:
write_buckets (List[WriteBucket]) – write plan
global_results_queue (mp.Queue) – mp.Queue to collect Dict[List[WriteResults]] (or an Exception) from parallel write processes to the main training process
Returns: None
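The two-queue completion pattern can be sketched in isolation; this is a simplified stand-in, not the actual implementation:

```python
import torch.multiprocessing as mp

def _worker(idx, results_queue, count_queue):
    try:
        results_queue.put((idx, f"write-results-{idx}"))  # real code writes data here
    except Exception as exc:
        results_queue.put((idx, exc))  # exceptions travel back via the results queue
    finally:
        count_queue.get()
        count_queue.task_done()  # mark this worker as completed

if __name__ == "__main__":
    ctx = mp.get_context("fork")  # the module runs writers in forked subprocesses
    results_queue, count_queue = ctx.SimpleQueue(), ctx.JoinableQueue()
    workers = []
    for idx in range(2):
        count_queue.put(idx)  # one token per worker
        proc = ctx.Process(target=_worker, args=(idx, results_queue, count_queue))
        proc.start()
        workers.append(proc)
    count_queue.join()  # returns once every worker has called task_done()
    results = dict(results_queue.get() for _ in range(2))
    for proc in workers:
        proc.join()
```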
- static write_preloaded_data(
- transform_list: List[torch.distributed.checkpoint.filesystem._StorageWriterTransforms],
- local_proc_idx: int,
- write_bucket: core.dist_checkpointing.strategies.filesystem_async.WriteBucket,
- results_queue: torch.multiprocessing.SimpleQueue,
- count_queue: torch.multiprocessing.JoinableQueue,
- use_fsync: bool,
- **kwargs,
- )#
Performs actual data saving to storage.
- Parameters:
local_proc_idx (int) – index of a local process that performs writing
write_bucket (WriteBucket) – data to write to storage
results_queue (mp.Queue) – queue to return the write results to the proxy checkpoint process.
count_queue (mp.JoinableQueue) – queue to mark the worker task as completed
use_fsync (bool) – if True, calls os.fsync at the end of saving
Returns: None, the write results are put into the queue
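The use_fsync behavior is the standard flush-then-fsync idiom:

```python
import os

payload = b"\x00" * 1024  # placeholder for serialized tensor data

with open("shard_0.bin", "wb") as stream:
    stream.write(payload)
    stream.flush()             # flush Python's userspace buffer to the OS
    os.fsync(stream.fileno())  # ask the OS to commit the page cache to disk
```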
- abstractmethod write_data(
- plan: torch.distributed.checkpoint.planner.SavePlan,
- planner: torch.distributed.checkpoint.planner.SavePlanner,
- )#
Write all items from plan.
- retrieve_write_results() Union[List[torch.distributed.checkpoint.storage.WriteResult], torch.distributed.checkpoint.api.WRAPPED_EXCEPTION]#
Turn the latest dict including write results from self.results_queue into a single results list. Includes an error check.
Returns (Union[List[WriteResult], WRAPPED_EXCEPTION]): the list of write results from all local processes performing the save, or a WRAPPED_EXCEPTION if an exception was raised during the writing process.
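A simplified sketch of the flatten-and-check step (the payload shape is an assumption for illustration):

```python
from itertools import chain

# Hypothetical payload: local process index -> that process's write results,
# or an Exception if that process failed.
results_dict = {0: ["result_a"], 1: ["result_b", "result_c"]}

failures = [val for val in results_dict.values() if isinstance(val, BaseException)]
if failures:
    raise failures[0]  # the real method returns a WRAPPED_EXCEPTION instead of raising
flat_results = list(chain.from_iterable(results_dict[idx] for idx in sorted(results_dict)))
assert flat_results == ["result_a", "result_b", "result_c"]
```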
- prepare_decentralized_global_plan(
- local_plan: torch.distributed.checkpoint.planner.SavePlan,
- )#
Instead of assigning indices by plan order, this uses the PyT rank (with the same outcome).
- Parameters:
local_plan (SavePlan) – local plan to turn to a global plan (without interactions with other ranks)
- Returns:
SavePlan - locally transformed plan equivalent to the plan that would be created by the coordinator
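The idea can be sketched as follows, assuming the __{index}_ storage-prefix convention used by torch's FileSystemWriter coordinator (a sketch, not necessarily the exact implementation):

```python
import dataclasses

import torch
from torch.distributed.checkpoint.filesystem import _StoragePrefix
from torch.distributed.checkpoint.planner import SavePlan

def decentralized_global_plan(local_plan: SavePlan) -> SavePlan:
    # The coordinator assigns prefixes "__{plan_index}_" in rank order, so using
    # the rank directly yields the same prefix without any communication.
    return dataclasses.replace(
        local_plan, storage_data=_StoragePrefix(f"__{torch.distributed.get_rank()}_")
    )
```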
- finish(
- metadata: torch.distributed.checkpoint.metadata.Metadata,
- results: List[List[torch.distributed.checkpoint.storage.WriteResult]],
- )#
Finish the checkpointing process.
- Parameters:
metadata (Metadata) – metadata to save
results (List[List[WriteResult]]) – results to save
- prepare_local_plan(
- plan: torch.distributed.checkpoint.planner.SavePlan,
- )#
Prepare the local plan for the checkpointing process.
- property checkpoint_id: Union[str, os.PathLike]#
Return the checkpoint_id that will be used to save the checkpoint.
- classmethod validate_checkpoint_id(
- checkpoint_id: Union[str, os.PathLike],
- )#
Validate the checkpoint_id that will be used to save the checkpoint.
This method is available in PyTorch 2.3 and above.
- core.dist_checkpointing.strategies.filesystem_async._split_by_size_and_type(
- bins: int,
- items: List[torch.distributed.checkpoint.planner.WriteItem],
- )#
Splits write items according to item size into close to uniform bins.
Same as torch.distributed.checkpoint.filesystem._split_by_size_and_type, but with a fixed _item_size function.
- Parameters:
bins (int) – number of bins to split into
items (List[WriteItem]) – list of write items
Returns (List[List[WriteItem]]): write items split into bins
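The underlying size-balancing strategy is greedy bin packing; a simplified, self-contained sketch operating on plain sizes rather than WriteItems:

```python
import heapq
from typing import List, Tuple

def split_by_size(bins: int, sizes: List[int]) -> List[List[int]]:
    """Greedy bin packing: place each item, largest first, into the lightest bin."""
    heap: List[Tuple[int, int]] = [(0, idx) for idx in range(bins)]  # (bin total, bin index)
    result: List[List[int]] = [[] for _ in range(bins)]
    for size in sorted(sizes, reverse=True):
        total, idx = heapq.heappop(heap)
        result[idx].append(size)
        heapq.heappush(heap, (total + size, idx))
    return result

assert split_by_size(2, [7, 5, 4, 3, 1]) == [[7, 3], [5, 4, 1]]
```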
- core.dist_checkpointing.strategies.filesystem_async._split_by_separation_hint(
- buckets: List[List[torch.distributed.checkpoint.planner.WriteItem]],
- separation_hint: Optional[str] = None,
- )#
Splits buckets into those whose keys begin with the separation_hint and those whose keys do not.
- Parameters:
buckets (List[List[WriteItem]]) – buckets to split
separation_hint (Optional[str]) – optional prefix to split on
Returns (Dict[str, List[List[WriteItem]]]): a dictionary mapping the prefix to the relevant buckets
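A simplified sketch of the split, with plain string keys standing in for WriteItems (the real buckets hold WriteItems keyed by tensor FQNs):

```python
from typing import Dict, List

def split_by_hint(buckets: List[List[str]], hint: str) -> Dict[str, List[List[str]]]:
    split: Dict[str, List[List[str]]] = {"": [], hint: []}
    for bucket in buckets:
        matching = [key for key in bucket if key.startswith(hint)]
        rest = [key for key in bucket if not key.startswith(hint)]
        if rest:
            split[""].append(rest)
        if matching:
            split[hint].append(matching)
    return split

assert split_by_hint([["optimizer.exp_avg", "model.layer1.weight"]], "optimizer") == {
    "": [["model.layer1.weight"]],
    "optimizer": [["optimizer.exp_avg"]],
}
```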
- core.dist_checkpointing.strategies.filesystem_async._item_size(
- item: torch.distributed.checkpoint.planner.WriteItem,
- )#
Calculates size (in bytes) of a single write item.
Same as torch.distributed.checkpoint.filesystem._item_size, but fixes the chunk size computation (using item.tensor_data.chunk.sizes).
- Parameters:
item (WriteItem) – write item to compute the size of
Returns (int): size of an item in bytes
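For a tensor write item the size reduces to element count times element size; a sketch of that computation (dtype.itemsize is available in recent PyTorch):

```python
import math

import torch

def tensor_item_nbytes(chunk_sizes, dtype):
    # bytes = number of elements in the chunk * bytes per element
    return math.prod(chunk_sizes) * dtype.itemsize

assert tensor_item_nbytes([1024, 1024], torch.bfloat16) == 2 * 1024 * 1024
```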
- core.dist_checkpointing.strategies.filesystem_async._process_memory() int#
Get memory used by current process.
Returns (int): memory used by current process
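One way to implement such a probe, assuming psutil is available (an assumption; the module may use a different mechanism):

```python
import psutil

def process_memory() -> int:
    # Resident set size of the current process, in bytes.
    return psutil.Process().memory_info().rss
```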