core.dist_checkpointing.strategies.async_utils#

This module provides async utilities that allow a checkpoint save process to be started in the background.

Module Contents#

Classes#

AsyncRequest

Represents an async request that needs to be scheduled for execution.

AsyncCaller

Wrapper around mp.Process that ensures correct semantics of distributed finalization.

TemporalAsyncCaller

Wrapper around mp.Process that ensures correct semantics of distributed finalization.

PersistentAsyncCaller

Wrapper around mp.Process that ensures correct semantics of distributed finalization.

_ActiveAsyncRequest

Helper to represent an active async call.

AsyncCallsQueue

Manages a queue of async calls.

Functions#

_disable_gc

Temporarily disables GC.

Data#

API#

core.dist_checkpointing.strategies.async_utils.logger#

‘getLogger(…)’

core.dist_checkpointing.strategies.async_utils._disable_gc()#

Temporarily disables GC.
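
As an illustration, a helper of this kind could be written as a context manager on top of the standard gc module. This is a minimal sketch under that assumption; the name disable_gc_sketch is hypothetical and the actual implementation may differ.

```python
import gc
from contextlib import contextmanager


@contextmanager
def disable_gc_sketch():
    """Hypothetical stand-in: switch the garbage collector off inside the block."""
    was_enabled = gc.isenabled()
    gc.disable()
    try:
        yield
    finally:
        # Re-enable the collector only if it was enabled on entry.
        if was_enabled:
            gc.enable()


gc.enable()
with disable_gc_sketch():
    enabled_inside = gc.isenabled()
enabled_after = gc.isenabled()
```

Restoring the previous state, rather than unconditionally calling gc.enable(), keeps the helper safe to use in code that has already disabled the collector.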

class core.dist_checkpointing.strategies.async_utils.AsyncRequest#

Bases: typing.NamedTuple

Represents an async request that needs to be scheduled for execution.

Parameters:
  • async_fn (Callable, optional) – async function to call. None represents noop.

  • async_fn_args (Tuple) – args to pass to async_fn.

  • finalize_fns (List[Callable]) – list of functions to call to finalize the request. These functions will be called synchronously after async_fn is done on all ranks.

  • async_fn_kwargs (Dict) – kwargs to pass to async_fn.

  • preload_fn (Callable) – preload function to stage tensors from GPU to Host. It should be self-contained, with its arguments already bound via partial.

  • is_frozen (bool) – flag indicating whether this async request is frozen, i.e. can no longer be modified.

  • call_idx (int) – index used to order async requests for synchronization when preloading and writing tensors on the async caller.

async_fn: Optional[Callable]#

None

async_fn_args: Tuple#

None

finalize_fns: List[Callable]#

None

async_fn_kwargs: Dict#

None

preload_fn: Callable#

None

is_frozen: bool#

False

call_idx: int#

0

add_finalize_fn(fn: Callable) None#

Adds a new finalize function to the request.

Parameters:

fn (Callable) – function to add to the async request. This function will be called after existing finalization functions.

Returns:

None

execute_sync() None#

Helper to synchronously execute the request.

This logic is equivalent to what should happen in case of the async call.

freeze() core.dist_checkpointing.strategies.async_utils.AsyncRequest#

Freezes the async request, disallowing adding new finalization functions.

Returns:

new async request with all same fields except for the is_frozen flag.

Return type:

AsyncRequest
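
The interplay of add_finalize_fn, freeze and execute_sync can be sketched with a simplified NamedTuple. RequestSketch below is a hypothetical stand-in, not the library class: it omits preload_fn and call_idx, and it assumes that a frozen request rejects new finalize functions by raising RuntimeError.

```python
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple


class RequestSketch(NamedTuple):
    """Hypothetical, simplified stand-in for AsyncRequest."""

    async_fn: Optional[Callable]
    async_fn_args: Tuple
    finalize_fns: List[Callable]
    async_fn_kwargs: Dict = {}
    is_frozen: bool = False

    def add_finalize_fn(self, fn: Callable) -> None:
        # Assumption: a frozen request refuses new finalization functions.
        if self.is_frozen:
            raise RuntimeError("cannot add a finalize function to a frozen request")
        self.finalize_fns.append(fn)

    def execute_sync(self) -> None:
        # Same order as the async path: the save function first, then finalization.
        if self.async_fn is not None:
            self.async_fn(*self.async_fn_args, **self.async_fn_kwargs)
        for fn in self.finalize_fns:
            fn()

    def freeze(self) -> "RequestSketch":
        # A NamedTuple is immutable, so freezing returns a modified copy.
        return self._replace(is_frozen=True)


events = []
req = RequestSketch(events.append, ("write",), [lambda: events.append("finalize-1")])
req.add_finalize_fn(lambda: events.append("finalize-2"))
frozen = req.freeze()
frozen.execute_sync()
```

Because the tuple itself is immutable, freeze cannot flip the flag in place; callers have to keep using the returned object.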

class core.dist_checkpointing.strategies.async_utils.AsyncCaller#

Bases: abc.ABC

Wrapper around mp.Process that ensures correct semantics of distributed finalization.

Starts process asynchronously and allows checking if all processes on all ranks are done.

abstractmethod schedule_async_call(
async_req: core.dist_checkpointing.strategies.async_utils.AsyncRequest,
) None#

Schedule async_req, either by forking a process or by reusing a persistent worker.

This method must be called on all ranks.

Parameters:

async_req (AsyncRequest) – AsyncRequest object describing the async process to start.

abstractmethod is_current_async_call_done(blocking: bool, no_dist: bool) bool#

Check if async save is finished on all ranks.

For semantic correctness, requires rank synchronization in each check. This method must be called on all ranks.

Parameters:
  • blocking (bool, optional) – if True, will wait until the call is done on all ranks. Otherwise, returns immediately if at least one rank is still active. Defaults to False.

  • no_dist (bool, optional) – if True, each training rank checks only its own asynchronous checkpoint writer, without synchronizing with other ranks.

Returns:

True if all ranks are done (immediately, or after an active wait if blocking is True), False if at least one rank is still active.

Return type:

bool

sync_all_async_calls(is_alive: int) bool#

Check if all ranks have completed async checkpoint writing.

Parameters:

is_alive (int) – nonzero if the current async request on this rank is not yet completed.

Returns:

True if all ranks are done, False if at least one rank is still active.

Return type:

bool
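
The decision rule described here can be modeled in a single process. The sketch below is hypothetical: a plain list stands in for the per-rank is_alive flags, whereas a real job would combine them with a collective operation across ranks.

```python
from typing import Sequence


def all_ranks_done_sketch(is_alive_per_rank: Sequence[int]) -> bool:
    """Hypothetical single-process model of the cross-rank completion check.

    Each entry plays the role of one rank's is_alive flag.
    """
    # The call is finished globally only when no rank reports a live writer.
    return sum(int(flag) for flag in is_alive_per_rank) == 0


done_everywhere = all_ranks_done_sketch([0, 0, 0, 0])
one_rank_busy = all_ranks_done_sketch([0, 1, 0, 0])
```

A single busy rank is enough to make every rank see the call as unfinished, which is why the method must be called on all ranks.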

abstractmethod close(abort=False)#

Terminate the async caller when the application exits or when a termination condition is met.

abstractmethod __del__()#

class core.dist_checkpointing.strategies.async_utils.TemporalAsyncCaller#

Bases: core.dist_checkpointing.strategies.async_utils.AsyncCaller

Wrapper around mp.Process that ensures correct semantics of distributed finalization.

Starts process asynchronously and allows checking if all processes on all ranks are done.

Initialization

schedule_async_call(
async_req: core.dist_checkpointing.strategies.async_utils.AsyncRequest,
) None#

Spawn a process with async_fn as the target.

This method must be called on all ranks.

Parameters:
  • async_req (AsyncRequest) – AsyncRequest object describing the async process to start. If its async_fn is None, no process will be started.

is_current_async_call_done(
blocking: bool = False,
no_dist: bool = False,
) bool#

Check if async save is finished on all ranks.

For semantic correctness, requires rank synchronization in each check. This method must be called on all ranks.

Parameters:
  • blocking (bool, optional) – if True, will wait until the call is done on all ranks. Otherwise, returns immediately if at least one rank is still active. Defaults to False.

  • no_dist (bool, optional) – if True, each training rank checks only its own asynchronous checkpoint writer, without synchronizing with other ranks.

Returns:

True if all ranks are done (immediately, or after an active wait if blocking is True), False if at least one rank is still active.

Return type:

bool
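
The difference between a blocking and a non-blocking check can be sketched on a single rank. LocalCallerSketch is a hypothetical stand-in: a thread replaces the mp.Process, and there is no cross-rank synchronization, so it only loosely corresponds to the no_dist=True case.

```python
import threading
from typing import Optional


class LocalCallerSketch:
    """Hypothetical single-rank stand-in; a thread replaces the mp.Process."""

    def __init__(self) -> None:
        self.worker: Optional[threading.Thread] = None

    def schedule_async_call(self, fn, *args) -> None:
        self.worker = threading.Thread(target=fn, args=args)
        self.worker.start()

    def is_current_async_call_done(self, blocking: bool = False) -> bool:
        if self.worker is None:
            return True
        if blocking:
            # Wait for the background work to finish before answering.
            self.worker.join()
        if self.worker.is_alive():
            return False
        self.worker = None
        return True


release = threading.Event()
caller = LocalCallerSketch()
caller.schedule_async_call(release.wait)  # a "save" that blocks until released
done_before_release = caller.is_current_async_call_done(blocking=False)
release.set()
done_after_wait = caller.is_current_async_call_done(blocking=True)
```

The non-blocking call returns immediately so a training loop can keep polling between steps, while the blocking call is the natural choice at shutdown.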

close(abort=False)#

For TemporalAsyncCaller, this method is called explicitly in is_current_async_call_done.

This method makes sure the TemporalAsyncCaller terminates with all its assigned async requests completed.

Parameters:

abort (bool, optional) – Defaults to False. Must be explicitly set to True when the async checkpoint process needs to be aborted.

__del__()#

class core.dist_checkpointing.strategies.async_utils.PersistentAsyncCaller#

Bases: core.dist_checkpointing.strategies.async_utils.AsyncCaller

Wrapper around mp.Process that ensures correct semantics of distributed finalization.

Starts process asynchronously and allows checking if all processes on all ranks are done.

Initialization

schedule_async_call(
async_req: core.dist_checkpointing.strategies.async_utils.AsyncRequest,
) None#

Submit an AsyncRequest to the persistent async caller.

This method must be called on all ranks.

Parameters:
  • async_req (AsyncRequest) – AsyncRequest object describing the checkpointing request to schedule. If its async_fn is None, nothing will be scheduled.

is_current_async_call_done(
blocking: bool = False,
no_dist: bool = False,
) bool#

Check if async save is finished on all ranks.

For semantic correctness, requires rank synchronization in each check. This method must be called on all ranks.

Parameters:
  • blocking (bool, optional) – if True, will wait until the call is done on all ranks. Otherwise, returns immediately if at least one rank is still active. Defaults to False.

  • no_dist (bool, optional) – if True, each training rank checks only its own asynchronous checkpoint writer, without synchronizing with other ranks.

Returns:

True if all ranks are done (immediately, or after an active wait if blocking is True), False if at least one rank is still active.

Return type:

bool

close(abort=False)#

Wait for the remaining async requests and terminate the PersistentAsyncCaller.

Signals the PersistentAsyncCaller to terminate by sending it a ‘DONE’ message.

Parameters:

abort (bool, optional) – Defaults to False. Must be explicitly set to True when the async checkpoint process needs to be aborted.

__del__()#

static async_loop(
rank: int,
queue: torch.multiprocessing.JoinableQueue,
preload_q: torch.multiprocessing.JoinableQueue,
comp_q: torch.multiprocessing.Queue,
log_level: int = logging.INFO,
)#

Main function for the persistent checkpoint worker.

The persistent worker is created once and terminated at exit or when the application calls close() explicitly.

This routine receives an AsyncRequest, runs its preload_fn first, and puts an integer value in preload_q to inform the trainer that it can proceed. When the async_fn from the request is completed (background saving is done), it puts an integer value in comp_q to notify the trainer of the completion.

Parameters:
  • rank (int) – the rank of the trainer where the persistent worker is created.

  • queue (mp.JoinableQueue) – the main queue used to receive AsyncRequest objects from the training rank

  • preload_q (mp.JoinableQueue) – a queue to inform trainer that preloading of tensors from GPU to Host or dedicated location is completed

  • comp_q (mp.Queue) – a queue to inform the training rank the completion of scheduled async checkpoint request

  • log_level (int, optional) – log level to set in the spawned process so that it matches the training rank’s logging level
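
The message flow described above can be modeled with standard-library queues. This is a hypothetical sketch, not the library's implementation: a thread and queue.Queue stand in for the worker process and the multiprocessing queues, the request is reduced to a (call_idx, preload_fn, async_fn) tuple, and call_idx is assumed to be the integer sent on both notification queues.

```python
import queue
import threading


def worker_loop_sketch(requests, preload_q, comp_q):
    """Hypothetical model of a persistent checkpoint worker loop."""
    while True:
        item = requests.get()
        if item == "DONE":  # termination message
            requests.task_done()
            break
        call_idx, preload_fn, async_fn = item
        staged = preload_fn()     # stage data first (GPU-to-host copy in the real case)
        preload_q.put(call_idx)   # tell the trainer it may proceed
        async_fn(staged)          # background save
        comp_q.put(call_idx)      # tell the trainer the save is complete
        requests.task_done()


saved = []
staged_order = []
requests, preload_q, comp_q = queue.Queue(), queue.Queue(), queue.Queue()
worker = threading.Thread(target=worker_loop_sketch, args=(requests, preload_q, comp_q))
worker.start()

for idx in (1, 2):
    requests.put((idx, lambda idx=idx: f"state-{idx}", saved.append))
    # The trainer waits only for staging, not for the save itself.
    staged_order.append(preload_q.get(timeout=5))

requests.put("DONE")
worker.join(timeout=5)
completed = [comp_q.get_nowait(), comp_q.get_nowait()]
```

Splitting the notifications across two queues lets the trainer resume as soon as its tensors are staged, while completion is reported separately once the write has finished.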

class core.dist_checkpointing.strategies.async_utils._ActiveAsyncRequest#

Bases: typing.NamedTuple

Helper to represent an active async call.

Parameters:
  • idx (int) – index of the call (starting from 0)

  • async_caller (AsyncCaller) – async caller instance that represents the async process handling the async request

  • async_request (AsyncRequest) – async request that is being called

idx: int#

None

async_caller: core.dist_checkpointing.strategies.async_utils.AsyncCaller#

None

async_request: core.dist_checkpointing.strategies.async_utils.AsyncRequest#

None

class core.dist_checkpointing.strategies.async_utils.AsyncCallsQueue(persistent: bool = False)#

Manages a queue of async calls.

Allows adding a new async call with schedule_async_request and finalizing active calls with maybe_finalize_async_calls.

Initialization

_get_async_caller()#

schedule_async_request(
async_request: core.dist_checkpointing.strategies.async_utils.AsyncRequest,
) int#

Start a new async call and add it to a queue of active async calls.

This method must be called on all ranks.

Parameters:

async_request (AsyncRequest) – async request to start.

Returns:

index of the async call that was started. This can help the user keep track of the async calls.

Return type:

int

maybe_finalize_async_calls(
blocking=False,
no_dist=False,
) List[int]#

Finalizes all available calls.

This method must be called on all ranks.

Parameters:
  • blocking (bool, optional) – if True, will wait until all active requests are done. Otherwise, finalizes only the async requests that have already finished. Defaults to False.

  • no_dist (bool, optional) – if True, each training rank checks only its own asynchronous checkpoint writer, without synchronizing with other ranks.

Returns:

list of indices (as returned by schedule_async_request) of async calls that have been successfully finalized.

Return type:

List[int]

Raises:

CheckpointException – if any rank(s) raised an exception during checkpoint writing, the exceptions are wrapped and raised on all ranks.

get_num_unfinalized_calls()#

Get the number of active async calls.
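
The schedule/finalize cycle can be illustrated with a single-rank model. CallsQueueSketch is hypothetical: threads stand in for the async callers, there is no cross-rank synchronization, and it assumes calls are finalized strictly in the order they were scheduled.

```python
import threading
from collections import deque
from typing import Callable, Deque, List, Tuple


class CallsQueueSketch:
    """Hypothetical single-rank model of a FIFO queue of background calls."""

    def __init__(self) -> None:
        self.active: Deque[Tuple[int, threading.Thread, Callable[[], None]]] = deque()
        self.call_idx = -1

    def schedule_async_request(
        self, async_fn: Callable[[], object], finalize_fn: Callable[[], None]
    ) -> int:
        self.call_idx += 1
        worker = threading.Thread(target=async_fn)
        worker.start()
        self.active.append((self.call_idx, worker, finalize_fn))
        return self.call_idx

    def maybe_finalize_async_calls(self, blocking: bool = False) -> List[int]:
        finalized = []
        while self.active:
            idx, worker, finalize_fn = self.active[0]
            if blocking:
                worker.join()
            if worker.is_alive():
                break  # keep FIFO order: stop at the first unfinished call
            finalize_fn()
            self.active.popleft()
            finalized.append(idx)
        return finalized

    def get_num_unfinalized_calls(self) -> int:
        return len(self.active)


gate = threading.Event()
log = []
calls = CallsQueueSketch()
first = calls.schedule_async_request(gate.wait, lambda: log.append("finalized 0"))
second = calls.schedule_async_request(lambda: None, lambda: log.append("finalized 1"))
nothing_yet = calls.maybe_finalize_async_calls()  # the first call is still blocked
gate.set()
all_done = calls.maybe_finalize_async_calls(blocking=True)
```

In this model the second call is not finalized while the first is still running, even if its own work has finished, so finalize functions always run in scheduling order.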

close(abort=False)#

Finalize all calls upon closing.

Parameters:

abort (bool, optional) – Defaults to False. Must be explicitly set to True when the async checkpoint process needs to be aborted.