nemo_automodel.utils.sig_utils
Module Contents
Classes
DistributedSignalHandler | Context manager to handle signals gracefully in a distributed setting.
Functions
get_device | Get the appropriate torch device based on the distributed backend.
all_gather_item | Perform an all_gather operation on a single Python object.
API
- nemo_automodel.utils.sig_utils.get_device(local_rank: Optional[int] = None) → torch.device
Get the appropriate torch device based on the distributed backend.
- Parameters:
local_rank – The local rank, used to specify the CUDA device index for NCCL. If None, uses the default CUDA device.
- Returns:
The torch.device (‘cuda’ for NCCL, ‘cpu’ for Gloo).
- Raises:
RuntimeError – If the distributed backend is neither ‘nccl’ nor ‘gloo’.
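The backend-to-device rule described above can be written as a small pure function. This is a hypothetical mirror of the documented behavior, not the library's source: `device_string_for_backend` is an illustrative name, and it returns a device string rather than a `torch.device` so that it runs without PyTorch.

```python
from typing import Optional


def device_string_for_backend(backend: str, local_rank: Optional[int] = None) -> str:
    """Hypothetical sketch of get_device's documented rule (returns a string, not torch.device)."""
    if backend == "nccl":
        # NCCL operates on CUDA tensors; pin to the local rank's GPU when one is given,
        # otherwise fall back to the default CUDA device.
        return "cuda" if local_rank is None else f"cuda:{local_rank}"
    if backend == "gloo":
        # Gloo works on CPU tensors, so local_rank is irrelevant here.
        return "cpu"
    # Mirrors the documented RuntimeError for any other backend.
    raise RuntimeError(f"Unsupported distributed backend: {backend!r}")
```

With PyTorch available and a process group initialized, `torch.device(device_string_for_backend(torch.distributed.get_backend(), local_rank))` would yield the corresponding device object.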
- nemo_automodel.utils.sig_utils.all_gather_item(item: Any, dtype: torch.dtype, group: Optional[torch.distributed.ProcessGroup] = None, async_op: bool = False, local_rank: Optional[int] = None) → list[Any]
Perform an all_gather operation on a single Python object.
Converts the item to a tensor, performs all_gather, and converts the gathered tensors back to a list of Python objects, one per rank.
- Parameters:
item (Any) – The Python object to gather.
dtype (torch.dtype) – The torch dtype to use for the intermediate tensor.
group (Optional[torch.distributed.ProcessGroup]) – The process group to gather within (defaults to the global group).
async_op (bool) – Whether the operation should be asynchronous.
local_rank (Optional[int]) – The local rank, used to determine the device.
- Returns:
A list containing the gathered items (of type Any) from all ranks in the group.
- Return type:
list[Any]
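The collective semantics described above (every rank contributes one item and every rank receives the same rank-ordered list) can be simulated in a single process. This is a hypothetical sketch that uses threads as stand-in ranks and a `threading.Barrier` in place of a process group; `SimulatedGroup`, `run_ranks`, and the `cast` callable (standing in for `dtype`) are illustrative names, not part of nemo_automodel, and `async_op` and `local_rank` are not modeled.

```python
import threading
from typing import Any, Callable, List, Sequence


class SimulatedGroup:
    """Hypothetical in-process stand-in for a process group: each thread plays one rank."""

    def __init__(self, world_size: int) -> None:
        self.world_size = world_size
        self._slots: List[Any] = [None] * world_size
        self._barrier = threading.Barrier(world_size)

    def all_gather_item(self, rank: int, item: Any, cast: Callable[[Any], Any]) -> List[Any]:
        # Analogous to copying the item into an intermediate value of the chosen dtype.
        self._slots[rank] = cast(item)
        # Collective: no rank proceeds until every rank has contributed.
        self._barrier.wait()
        gathered = list(self._slots)  # rank-ordered copy, identical on every rank
        # Second barrier so a fast rank cannot overwrite a slot before slower ranks have read it.
        self._barrier.wait()
        return gathered


def run_ranks(items: Sequence[Any], cast: Callable[[Any], Any]) -> List[List[Any]]:
    """Run one simulated all_gather_item call per rank and return each rank's result."""
    group = SimulatedGroup(len(items))
    results: List[List[Any]] = [[] for _ in items]

    def worker(rank: int) -> None:
        results[rank] = group.all_gather_item(rank, items[rank], cast)

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(len(items))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

For example, `run_ranks([True, False, False], float)` gives every rank the same list `[1.0, 0.0, 0.0]`. In this simulation the booleans come back as floats because they passed through the intermediate representation, which illustrates why the choice of `dtype` matters.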
- class nemo_automodel.utils.sig_utils.DistributedSignalHandler(sig: int = signal.SIGTERM)
Context manager to handle signals gracefully in a distributed setting.
Installs a signal handler upon entering the context that sets a flag when the specified signal is received. The signals_received method can be used to check if any rank received the signal (using all_gather). The original signal handler is restored upon exiting the context.
- Parameters:
sig – The signal number to handle (e.g., signal.SIGTERM). Defaults to signal.SIGTERM.
Initialization
Constructor for the DistributedSignalHandler.
- Parameters:
sig (int, optional) – The signal to handle. Defaults to signal.SIGTERM.
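The per-process half of this behavior (install a flag-setting handler on entry, restore the original handler on exit) can be sketched with the standard `signal` module alone. `LocalSignalFlag` is a hypothetical name, and the sketch deliberately omits the cross-rank step; it is not the library's implementation.

```python
import signal
from types import FrameType
from typing import Any, Optional


class LocalSignalFlag:
    """Hypothetical single-process sketch: set a flag on `sig`, restore the old handler on exit."""

    def __init__(self, sig: int = signal.SIGTERM) -> None:
        self.sig = sig
        self.received = False
        self._original: Any = None

    def _handler(self, signum: int, frame: Optional[FrameType]) -> None:
        # Only record the signal; the caller decides when and how to stop.
        self.received = True

    def __enter__(self) -> "LocalSignalFlag":
        self.received = False
        # Remember the current handler so it can be put back later.
        self._original = signal.getsignal(self.sig)
        signal.signal(self.sig, self._handler)
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        # getsignal() returns None for handlers not installed from Python; skip restoring those.
        if self._original is not None:
            signal.signal(self.sig, self._original)
        return False  # do not suppress exceptions raised inside the block
```

Python only allows installing signal handlers from the main thread, so a context manager like this has to be entered there. What the documented class adds on top is signals_received, which shares each rank's flag with the others via all_gather.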
- signals_received() → list[bool]
Check if any rank in the default group received the signal.
Uses all_gather to collect the signal status from all ranks.
- Returns:
A list of booleans, where each element indicates if the corresponding rank received the signal.
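Gathering the flags matters because ranks in a distributed job generally need to leave the training loop at the same step. The following is a hypothetical, deterministic simulation (no real processes or signals) that assumes all ranks run in lockstep and check the flag once per step; `exit_steps` and its parameters are illustrative names.

```python
from typing import Dict, List


def exit_steps(world_size: int, signal_step: Dict[int, int], num_steps: int, gather_flags: bool) -> List[int]:
    """Return, per rank, the step at which that rank leaves a lockstep training loop."""

    def local_flag(rank: int, step: int) -> bool:
        # A rank's own flag flips once its signal has arrived, and stays set.
        return rank in signal_step and step >= signal_step[rank]

    result = []
    for rank in range(world_size):
        exit_at = num_steps  # ran to completion unless a stop is triggered
        for step in range(num_steps):
            if gather_flags:
                # Analogue of signals_received(): every rank's flag, in rank order.
                flags = [local_flag(r, step) for r in range(world_size)]
            else:
                # Local-only check: a rank sees just its own flag.
                flags = [local_flag(rank, step)]
            if any(flags):
                exit_at = step
                break
        result.append(exit_at)
    return result
```

If only rank 2 receives the signal at step 5, the local-only check makes rank 2 leave at step 5 while the other ranks keep going (`[10, 10, 5, 10]` for 10 steps), and in a real job they would typically hang in the next collective waiting for it. Applying `any()` to the gathered flags makes every rank stop together (`[5, 5, 5, 5]`).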
- __enter__() → nemo_automodel.utils.sig_utils.DistributedSignalHandler
Enters the signal-managed context and installs the signal handler.
- Returns:
self
- Return type: