bridge.training.utils.sig_utils
Module Contents
Classes
DistributedSignalHandler | Context manager to handle signals gracefully in a distributed setting.
Functions
get_device | Get the appropriate torch device based on the distributed backend.
all_gather_item | Perform an all_gather operation on a single Python object.
API
- bridge.training.utils.sig_utils.get_device(local_rank: Optional[int] = None) -> torch.device
Get the appropriate torch device based on the distributed backend.
- Parameters:
local_rank – The local rank, used to specify the CUDA device index for NCCL. If None, uses the default CUDA device.
- Returns:
The torch.device (‘cuda’ for NCCL, ‘cpu’ for Gloo).
- Raises:
RuntimeError – If the distributed backend is neither ‘nccl’ nor ‘gloo’.
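The backend-to-device mapping described above can be sketched as follows. This is a minimal illustration, not the module's implementation: `device_for_backend` is a hypothetical helper that takes the backend name as an argument, so it runs without an initialized process group.

```python
from typing import Optional

import torch


def device_for_backend(backend: str, local_rank: Optional[int] = None) -> torch.device:
    """Hypothetical helper mirroring the documented mapping (assumption, not the library code)."""
    if backend == "nccl":
        # NCCL communicates through GPU tensors; local_rank selects the CUDA index.
        if local_rank is None:
            return torch.device("cuda")  # default CUDA device
        return torch.device(f"cuda:{local_rank}")
    if backend == "gloo":
        # Gloo is used here with CPU tensors.
        return torch.device("cpu")
    raise RuntimeError(f"Unsupported distributed backend: {backend!r}")


if __name__ == "__main__":
    print(device_for_backend("gloo"))
    print(device_for_backend("nccl", local_rank=1))
```

Constructing a `torch.device` does not require the device to be present, so the sketch can be exercised on a CPU-only machine.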
- bridge.training.utils.sig_utils.all_gather_item(item: Any, dtype: torch.dtype, group: Optional[torch.distributed.ProcessGroup] = None, async_op: bool = False, local_rank: Optional[int] = None) -> list[Any]
Perform an all_gather operation on a single Python object.
Converts the item to a tensor, performs all_gather, and converts back to a list of Python objects from all ranks.
- Parameters:
item (Any) – The Python object to gather.
dtype (torch.dtype) – The torch dtype to use for the intermediate tensor.
group (Optional[torch.distributed.ProcessGroup]) – The process group to gather within (defaults to the global group).
async_op (bool) – Whether the operation should be asynchronous.
local_rank (Optional[int]) – The local rank to determine the device.
- Returns:
A list containing the gathered items (of type Any) from all ranks in the group.
- Return type:
list[Any]
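The item-to-tensor round trip described above can be sketched with plain `torch.distributed` calls. The names `gather_item_sketch` and `demo` are hypothetical, and the sketch assumes a scalar item, a Gloo backend, and a single-process group created through a temporary file store; it is not the module's implementation.

```python
import os
import tempfile
from typing import Any, Optional

import torch
import torch.distributed as dist


def gather_item_sketch(
    item: Any, dtype: torch.dtype, group: Optional[dist.ProcessGroup] = None
) -> list:
    """Hypothetical re-creation of the documented steps for a scalar item."""
    # Pick the device from the backend: CUDA for NCCL, CPU otherwise (Gloo assumed).
    device = torch.device("cuda" if dist.get_backend(group) == "nccl" else "cpu")
    # 1. Wrap the Python scalar in a one-element tensor.
    tensor = torch.tensor([item], dtype=dtype, device=device)
    # 2. Allocate one receive buffer per rank and gather into them.
    world_size = dist.get_world_size(group)
    buffers = [torch.zeros(1, dtype=dtype, device=device) for _ in range(world_size)]
    dist.all_gather(buffers, tensor, group=group)
    # 3. Convert each rank's tensor back to a Python value.
    return [buf.item() for buf in buffers]


def demo() -> list:
    """Run the sketch in a single-process Gloo group (world size 1)."""
    with tempfile.TemporaryDirectory() as tmp:
        store_path = os.path.join(tmp, "store")
        dist.init_process_group(
            backend="gloo", init_method=f"file://{store_path}", rank=0, world_size=1
        )
        try:
            return gather_item_sketch(7, torch.int32)
        finally:
            dist.destroy_process_group()


if __name__ == "__main__":
    print(demo())
```

With more than one rank, the returned list holds one entry per rank, ordered by rank within the group.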
- class bridge.training.utils.sig_utils.DistributedSignalHandler(sig: int = signal.SIGTERM)
Context manager to handle signals gracefully in a distributed setting.
Installs a signal handler upon entering the context that sets a flag when the specified signal is received. The signals_received method can be used to check if any rank received the signal (using all_gather). The original signal handler is restored upon exiting the context.
- Parameters:
sig – The signal number to handle (e.g., signal.SIGTERM). Defaults to signal.SIGTERM.
- signals_received() -> list[bool]
Collect the signal status from every rank in the default group.
Uses all_gather, so every rank in the group must call this method. Wrap the result in any() to check whether at least one rank received the signal.
- Returns:
A list of booleans, where each element indicates if the corresponding rank received the signal.
- __exit__(exc_type: Optional[type], exc_val: Optional[Exception], exc_tb: Optional[Any])
Release the signal handler and restore the original handler.
- release() -> bool
Restore the original signal handler.
- Returns:
True if the handler was released, False if it was already released.
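The install, flag, and restore pattern that this class describes can be sketched in a single process using only the standard library. `LocalSignalHandler` and `demo` are hypothetical names; the sketch leaves out the all_gather step, so it only reports whether the current process received the signal.

```python
import signal
from types import FrameType
from typing import Optional


class LocalSignalHandler:
    """Single-process sketch of the documented pattern (hypothetical, no all_gather)."""

    def __init__(self, sig: int = signal.SIGTERM) -> None:
        self.sig = sig
        self.received = False
        self._released = True
        self._original = None

    def __enter__(self) -> "LocalSignalHandler":
        self.received = False
        self._released = False
        # Remember the current handler so it can be restored later.
        self._original = signal.getsignal(self.sig)

        def _handler(signum: int, frame: Optional[FrameType]) -> None:
            # Only set a flag; the caller decides when to act on it.
            self.received = True

        signal.signal(self.sig, _handler)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.release()

    def release(self) -> bool:
        """Restore the original handler; return False if already released."""
        if self._released:
            return False
        if self._original is not None:
            signal.signal(self.sig, self._original)
        self._released = True
        return True


def demo() -> tuple:
    """Simulate a loop that stops at the first step after SIGTERM arrives."""
    stopped_at = -1
    with LocalSignalHandler(signal.SIGTERM) as handler:
        for step in range(5):
            if step == 2:
                signal.raise_signal(signal.SIGTERM)  # stand-in for an external SIGTERM
            if handler.received:
                stopped_at = step
                break
        first = handler.release()
    second = handler.release()
    return stopped_at, first, second


if __name__ == "__main__":
    print(demo())
```

Because the handler only sets a flag, the loop chooses a safe point to stop, for example after saving a checkpoint. In the distributed class, the equivalent check would go through signals_received so that all ranks reach the same decision.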