core.dist_checkpointing.strategies.async_utils#
This module provides async utilities that allow starting a checkpoint save process in the background.
Module Contents#
Classes#
AsyncRequest | Represents an async request that needs to be scheduled for execution.
AsyncCaller | Wrapper around mp.Process that ensures correct semantics of distributed finalization.
TemporalAsyncCaller | Wrapper around mp.Process that ensures correct semantics of distributed finalization.
PersistentAsyncCaller | Wrapper around mp.Process that ensures correct semantics of distributed finalization.
_ActiveAsyncRequest | Helper to represent an active async call.
AsyncCallsQueue | Manages a queue of async calls.
Functions#
_disable_gc | Temporarily disables GC.
Data#
API#
- core.dist_checkpointing.strategies.async_utils.logger#
‘getLogger(…)’
- core.dist_checkpointing.strategies.async_utils._disable_gc()#
Temporarily disables GC.
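The helper is documented only by its summary line. One way to write such a helper is a context manager that switches off Python's cyclic garbage collector and restores the previous state on exit. This is a minimal sketch under that assumption; the context-manager form and the name `disable_gc_sketch` are hypothetical, not the module's code.

```python
import gc
from contextlib import contextmanager


@contextmanager
def disable_gc_sketch():
    """Hypothetical stand-in for _disable_gc: pause automatic GC inside the block."""
    was_enabled = gc.isenabled()
    if was_enabled:
        gc.disable()
    try:
        yield
    finally:
        # Re-enable only if the collector was enabled on entry.
        if was_enabled:
            gc.enable()


with disable_gc_sketch():
    gc_during = gc.isenabled()  # False inside the block
gc_after = gc.isenabled()  # back to the state on entry
```

Restoring the previous state, rather than calling `gc.enable()` unconditionally, keeps such a helper safe to use inside code that had already disabled GC.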
- class core.dist_checkpointing.strategies.async_utils.AsyncRequest#
Bases: typing.NamedTuple
Represents an async request that needs to be scheduled for execution.
- Parameters:
async_fn (Callable, optional) – async function to call. None represents a no-op.
async_fn_args (Tuple) – args to pass to async_fn.
finalize_fns (List[Callable]) – list of functions to call to finalize the request. These functions will be called synchronously after async_fn is done on all ranks.
async_fn_kwargs (Dict) – kwargs to pass to async_fn.
preload_fn (Callable) – preload function to stage tensors from GPU to host. It should be self-contained, with its arguments bound in advance using partial.
is_frozen (bool) – flag indicating whether this async request is frozen, i.e. can no longer be modified.
call_idx (int) – index used to order async requests for synchronization when the async caller preloads and writes tensors.
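The preload_fn contract above (a self-contained callable whose arguments are bound in advance with partial) can be illustrated in plain Python. In this sketch, `stage_to_host` and the dict-of-lists state are hypothetical stand-ins for copying GPU tensors into host buffers; no torch calls are involved.

```python
from functools import partial
from typing import Dict, Iterable, List


def stage_to_host(state_dict: Dict[str, List[float]], keys: Iterable[str]) -> Dict[str, List[float]]:
    """Hypothetical preload step: snapshot the selected 'tensors' into fresh buffers."""
    return {name: list(state_dict[name]) for name in keys}


state = {"layer.weight": [0.1, 0.2], "layer.bias": [0.0]}

# Bind every argument up front, so the worker can simply call preload_fn().
preload_fn = partial(stage_to_host, state, keys=("layer.weight", "layer.bias"))

staged = preload_fn()
state["layer.weight"][0] = 9.9  # training continues; the staged snapshot is unaffected
print(staged["layer.weight"])  # [0.1, 0.2]
```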
- async_fn: Optional[Callable]#
None
- async_fn_args: Tuple#
None
- finalize_fns: List[Callable]#
None
- async_fn_kwargs: Dict#
None
- preload_fn: Callable#
None
- is_frozen: bool#
False
- call_idx: int#
0
- add_finalize_fn(fn: Callable) None#
Adds a new finalize function to the request.
- Parameters:
fn (Callable) – function to add to the async request. This function will be called after existing finalization functions.
- Returns:
None
- execute_sync() None#
Helper to synchronously execute the request.
This logic is equivalent to what should happen in case of the async call.
- freeze() core.dist_checkpointing.strategies.async_utils.AsyncRequest#
Freezes the async request, disallowing adding new finalization functions.
- Returns:
new async request with all the same fields except for the is_frozen flag.
- Return type:
AsyncRequest
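The interplay of add_finalize_fn, freeze, and execute_sync can be sketched with a simplified NamedTuple. `AsyncRequestSketch` is a hypothetical stand-in, not the library's class: raising RuntimeError on a frozen request is an assumption, async_fn_kwargs defaults to None here to avoid a shared mutable default, and the rank synchronization the real request needs before finalization is left out.

```python
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple


class AsyncRequestSketch(NamedTuple):
    """Simplified stand-in that mirrors the documented AsyncRequest fields."""

    async_fn: Optional[Callable]
    async_fn_args: Tuple
    finalize_fns: List[Callable]
    async_fn_kwargs: Optional[Dict] = None
    preload_fn: Optional[Callable] = None
    is_frozen: bool = False
    call_idx: int = 0

    def add_finalize_fn(self, fn: Callable) -> None:
        # Assumption: a frozen request rejects new finalize functions with an error.
        if self.is_frozen:
            raise RuntimeError("cannot add finalize functions to a frozen request")
        self.finalize_fns.append(fn)  # runs after the existing ones

    def execute_sync(self) -> None:
        # Run the save inline, then the finalize functions in order.
        # A distributed version would also synchronize all ranks before finalizing.
        if self.async_fn is not None:
            self.async_fn(*self.async_fn_args, **(self.async_fn_kwargs or {}))
        for fn in self.finalize_fns:
            fn()

    def freeze(self) -> "AsyncRequestSketch":
        # NamedTuples are immutable: return a copy that differs only in is_frozen.
        return self._replace(is_frozen=True)


events: List[str] = []
req = AsyncRequestSketch(lambda path: events.append("save " + path), ("ckpt/iter_100",), [])
req.add_finalize_fn(lambda: events.append("finalize"))
frozen = req.freeze()
frozen.execute_sync()
print(events)  # ['save ckpt/iter_100', 'finalize']
```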
- class core.dist_checkpointing.strategies.async_utils.AsyncCaller#
Bases: abc.ABC
Wrapper around mp.Process that ensures correct semantics of distributed finalization.
Starts process asynchronously and allows checking if all processes on all ranks are done.
- abstractmethod schedule_async_call(async_req: core.dist_checkpointing.strategies.async_utils.AsyncRequest) None#
Schedule async_req, either by forking a new process or by reusing a persistent worker.
This method must be called on all ranks.
- Parameters:
async_req (AsyncRequest) – AsyncRequest object describing the work the async process should run.
- abstractmethod is_current_async_call_done(blocking: bool, no_dist: bool) bool#
Check if async save is finished on all ranks.
For semantic correctness, requires rank synchronization in each check. This method must be called on all ranks.
- Parameters:
blocking (bool, optional) – if True, will wait until the call is done on all ranks. Otherwise, returns immediately if at least one rank is still active. Defaults to False.
no_dist (bool, optional) – if True, each training rank checks only its own asynchronous checkpoint writer, without cross-rank synchronization.
- Returns:
True if all ranks are done (immediately or after an active wait if blocking is True), False if at least one rank is still active.
- Return type:
bool
- sync_all_async_calls(is_alive: int) bool#
Check if all ranks have completed async checkpoint writing.
- Parameters:
is_alive (int) – nonzero if the current async request on this rank has not completed yet.
- Returns:
True if all ranks are done, False if at least one rank is still active.
- Return type:
bool
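One way to get the same answer on every rank, assumed here rather than taken from the source, is to all-reduce each rank's is_alive flag with a sum and treat the call as finished only when the total is zero. The sketch simulates the reduction over a list of per-rank flags instead of calling torch.distributed, so it runs in a single process; `all_ranks_done` and `simulated_all_reduce_sum` are hypothetical names.

```python
from typing import Sequence


def simulated_all_reduce_sum(per_rank_values: Sequence[int]) -> int:
    """Stand-in for an all-reduce with SUM: every rank would receive this total."""
    return sum(per_rank_values)


def all_ranks_done(per_rank_is_alive: Sequence[int]) -> bool:
    # Each rank contributes 1 while its async writer is still running, 0 once it has finished.
    total_alive = simulated_all_reduce_sum(per_rank_is_alive)
    # Done only when no rank is alive; every rank sees the same total, so all agree.
    return total_alive == 0


print(all_ranks_done([0, 0, 1]))  # False: rank 2 is still writing
print(all_ranks_done([0, 0, 0]))  # True
```

Because every rank receives the same total, all ranks agree on whether to run finalization, which matters if finalization itself involves collective operations.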
- abstractmethod close(abort=False)#
Terminate the async caller when the application exits or another termination condition is met.
- abstractmethod __del__()#
- class core.dist_checkpointing.strategies.async_utils.TemporalAsyncCaller#
Bases: core.dist_checkpointing.strategies.async_utils.AsyncCaller
Wrapper around mp.Process that ensures correct semantics of distributed finalization.
Starts process asynchronously and allows checking if all processes on all ranks are done.
- schedule_async_call(async_req: core.dist_checkpointing.strategies.async_utils.AsyncRequest) None#
Spawn a process with async_fn as the target.
This method must be called on all ranks.
- Parameters:
async_req (AsyncRequest) – AsyncRequest object describing the process to start. If its async_fn is None, no process will be started.
- is_current_async_call_done(blocking: bool = False, no_dist: bool = False) bool#
Check if async save is finished on all ranks.
For semantic correctness, requires rank synchronization in each check. This method must be called on all ranks.
- Parameters:
blocking (bool, optional) – if True, will wait until the call is done on all ranks. Otherwise, returns immediately if at least one rank is still active. Defaults to False.
no_dist (bool, optional) – if True, each training rank checks only its own asynchronous checkpoint writer, without cross-rank synchronization.
- Returns:
True if all ranks are done (immediately or after an active wait if blocking is True), False if at least one rank is still active.
- Return type:
bool
- close(abort=False)#
For TemporalAsyncCaller, this method is called explicitly in is_current_async_call_done.
It makes sure that the TemporalAsyncCaller terminates with all of its assigned async requests completed.
- Parameters:
abort (bool, optional) – defaults to False. Must be set to True manually when the checkpoint async process needs to be aborted.
- __del__()#
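The temporal strategy starts a fresh worker for each request and discards it once the work is done. The sketch below shows that lifecycle for a single rank under stated simplifications: threads stand in for mp.Process so it runs anywhere, the cross-rank check is omitted, and `TemporalCallerSketch` with its methods is hypothetical.

```python
import threading
import time
from typing import Callable, List, Optional


class TemporalCallerSketch:
    """Hypothetical sketch: one short-lived worker per request."""

    def __init__(self) -> None:
        self.worker: Optional[threading.Thread] = None

    def schedule(self, fn: Optional[Callable[..., None]], *args) -> None:
        if fn is None:
            return  # nothing to run, so no worker is started
        self.worker = threading.Thread(target=fn, args=args)
        self.worker.start()

    def is_done(self, blocking: bool = False) -> bool:
        if self.worker is None:
            return True  # no request in flight
        if blocking:
            self.worker.join()  # wait for the worker to finish
        done = not self.worker.is_alive()
        if done:
            self.close()  # clean up as soon as the request has completed
        return done

    def close(self) -> None:
        # Join and drop the finished worker; the next request starts a new one.
        if self.worker is not None:
            self.worker.join()
            self.worker = None


results: List[str] = []


def fake_save(path: str) -> None:
    time.sleep(0.05)  # pretend to write a checkpoint
    results.append(path)


caller = TemporalCallerSketch()
caller.schedule(fake_save, "ckpt/iter_200")
finished = caller.is_done(blocking=True)
print(finished, results)  # True ['ckpt/iter_200']
```

Calling close() from is_done mirrors the note above that TemporalAsyncCaller's close is invoked from is_current_async_call_done.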
- class core.dist_checkpointing.strategies.async_utils.PersistentAsyncCaller#
Bases: core.dist_checkpointing.strategies.async_utils.AsyncCaller
Wrapper around mp.Process that ensures correct semantics of distributed finalization.
Starts process asynchronously and allows checking if all processes on all ranks are done.
- schedule_async_call(async_req: core.dist_checkpointing.strategies.async_utils.AsyncRequest) None#
Put async_req into the queue of the persistent async caller.
This method must be called on all ranks.
- Parameters:
async_req (AsyncRequest) – AsyncRequest object describing the checkpointing request to schedule. If its async_fn is None, no process will be started.
- is_current_async_call_done(blocking: bool = False, no_dist: bool = False) bool#
Check if async save is finished on all ranks.
For semantic correctness, requires rank synchronization in each check. This method must be called on all ranks.
- Parameters:
blocking (bool, optional) – if True, will wait until the call is done on all ranks. Otherwise, returns immediately if at least one rank is still active. Defaults to False.
no_dist (bool, optional) – if True, each training rank checks only its own asynchronous checkpoint writer, without cross-rank synchronization.
- Returns:
True if all ranks are done (immediately or after an active wait if blocking is True), False if at least one rank is still active.
- Return type:
bool
- close(abort=False)#
Wait for the remaining async requests to complete and terminate the PersistentAsyncCaller.
The worker is signaled to terminate by sending it a ‘DONE’ message.
- Parameters:
abort (bool, optional) – defaults to False. Must be set to True manually when the checkpoint async process needs to be aborted.
- __del__()#
- static async_loop(rank: int, queue: torch.multiprocessing.JoinableQueue, preload_q: torch.multiprocessing.JoinableQueue, comp_q: torch.multiprocessing.Queue, log_level: int = logging.INFO)#
Main function for the persistent checkpoint worker.
The persistent worker is created once and terminated at exit or when the application calls close() explicitly.
This routine receives an AsyncRequest and runs its preload_fn first, then puts an integer value into preload_q to inform the trainer that it can proceed. When the request's async_fn completes (background saving is done), it puts an integer value into comp_q to notify the trainer of the completion.
- Parameters:
rank (int) – the rank of the trainer on which the persistent worker is created.
queue (mp.JoinableQueue) – the main queue used to receive AsyncRequest objects from the training rank.
preload_q (mp.JoinableQueue) – a queue to inform the trainer that preloading of tensors from GPU to host (or another dedicated location) is completed.
comp_q (mp.Queue) – a queue to inform the training rank of the completion of a scheduled async checkpoint request.
log_level (int, optional) – log level for this spawned process, so that its logging matches the training rank's logging level.
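The request/acknowledge protocol described above can be sketched with threads and queue.Queue standing in for torch.multiprocessing processes and queues (queue.Queue also provides task_done and join). The `Request` tuple, `async_loop_sketch`, and reporting the request's call_idx as the integer value are assumptions made for illustration.

```python
import queue
import threading
from typing import Callable, List, NamedTuple, Optional


class Request(NamedTuple):
    """Hypothetical minimal request carrying only what the loop needs."""

    call_idx: int
    async_fn: Callable[[], None]
    preload_fn: Optional[Callable[[], None]] = None


def async_loop_sketch(work_q: queue.Queue, preload_q: queue.Queue, comp_q: queue.Queue) -> None:
    """Long-lived worker: serve requests until the 'DONE' sentinel arrives."""
    while True:
        item = work_q.get()
        if isinstance(item, str) and item == "DONE":  # shutdown signal sent by close()
            work_q.task_done()
            break
        if item.preload_fn is not None:
            item.preload_fn()  # stage tensors before the trainer resumes
        preload_q.put(item.call_idx)  # trainer may continue training now
        item.async_fn()  # the slow background save
        work_q.task_done()
        comp_q.put(item.call_idx)  # report completion of this request


work_q: queue.Queue = queue.Queue()
preload_q: queue.Queue = queue.Queue()
comp_q: queue.Queue = queue.Queue()
worker = threading.Thread(target=async_loop_sketch, args=(work_q, preload_q, comp_q))
worker.start()  # created once and reused for every request

log: List[str] = []
for idx in range(2):
    work_q.put(Request(idx, async_fn=lambda i=idx: log.append(f"save {i}")))
    preload_q.get()  # wait only for preloading, not for the save itself

work_q.put("DONE")
work_q.join()
worker.join()
completed = [comp_q.get() for _ in range(2)]
print(completed, log)  # [0, 1] ['save 0', 'save 1']
```

The trainer blocks only on preload_q, so training resumes as soon as tensors are staged while the save continues in the background; comp_q is read later to decide when finalization can happen.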
- class core.dist_checkpointing.strategies.async_utils._ActiveAsyncRequest#
Bases: typing.NamedTuple
Helper to represent an active async call.
- Parameters:
idx (int) – index of the call (starting from 0)
async_caller (AsyncCaller) – async caller instance that represents the async process handling the async request
async_request (AsyncRequest) – async request that is being called
- idx: int#
None
- async_caller: core.dist_checkpointing.strategies.async_utils.AsyncCaller#
None
- async_request: core.dist_checkpointing.strategies.async_utils.AsyncRequest#
None
- class core.dist_checkpointing.strategies.async_utils.AsyncCallsQueue(persistent: bool = False)#
Manages a queue of async calls.
Allows adding a new async call with schedule_async_request and finalizing active calls with maybe_finalize_async_calls.
- _get_async_caller()#
- schedule_async_request(async_request: core.dist_checkpointing.strategies.async_utils.AsyncRequest) int#
Start a new async call and add it to a queue of active async calls.
This method must be called on all ranks.
- Parameters:
async_request (AsyncRequest) – async request to start.
- Returns:
index of the async call that was started. This can help the user keep track of the async calls.
- Return type:
int
- maybe_finalize_async_calls(blocking=False, no_dist=False) List[int]#
Finalizes all available calls.
This method must be called on all ranks.
- Parameters:
blocking (bool, optional) – if True, will wait until all active requests are done. Otherwise, finalizes only the async requests that have already finished. Defaults to False.
no_dist (bool, optional) – if True, each training rank checks only its own asynchronous checkpoint writer, without cross-rank synchronization.
- Returns:
list of indices (as returned by schedule_async_request) of async calls that have been successfully finalized.
- Return type:
List[int]
- Raises:
CheckpointException – if any rank(s) raised an exception during checkpoint writing, the exceptions are wrapped and raised on all ranks.
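The bookkeeping behind schedule_async_request and maybe_finalize_async_calls can be sketched as a deque of active calls drained from the front. This is a hypothetical single-process sketch: `CallsQueueSketch` and `ActiveCall` are illustrative names, finalizing strictly in scheduling order is an assumption, and the cross-rank agreement and exception handling of the real class are omitted.

```python
import time
from collections import deque
from typing import Callable, Deque, List, NamedTuple


class ActiveCall(NamedTuple):
    """Hypothetical counterpart of _ActiveAsyncRequest: index, completion probe, finalizer."""

    idx: int
    is_done: Callable[[], bool]
    finalize: Callable[[], None]


class CallsQueueSketch:
    """Hypothetical sketch of the bookkeeping AsyncCallsQueue describes."""

    def __init__(self) -> None:
        self.active: Deque[ActiveCall] = deque()
        self.next_idx = 0

    def schedule(self, is_done: Callable[[], bool], finalize: Callable[[], None]) -> int:
        idx = self.next_idx
        self.next_idx += 1
        self.active.append(ActiveCall(idx, is_done, finalize))
        return idx  # lets the caller track this request

    def maybe_finalize(self, blocking: bool = False) -> List[int]:
        finalized: List[int] = []
        # Assumption: calls are finalized strictly in the order they were scheduled.
        while self.active:
            call = self.active[0]
            if blocking:
                while not call.is_done():
                    time.sleep(0.01)  # wait for the oldest call to finish
            elif not call.is_done():
                break  # oldest call still running; later calls must wait too
            call.finalize()
            self.active.popleft()
            finalized.append(call.idx)
        return finalized

    def num_unfinalized(self) -> int:
        return len(self.active)


done = {0: True, 1: False}
finalized_log: List[int] = []
calls = CallsQueueSketch()
for i in range(2):
    calls.schedule(lambda i=i: done[i], lambda i=i: finalized_log.append(i))

first = calls.maybe_finalize()   # [0]: call 1 is still running
done[1] = True
second = calls.maybe_finalize()  # [1]
```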
- get_num_unfinalized_calls()#
Get the number of active async calls.
- close(abort=False)#
Finalize all calls upon closing.
- Parameters:
abort (bool, optional) – defaults to False. Must be set to True manually when the checkpoint async process needs to be aborted.