nemo_automodel.utils.sig_utils

Module Contents

Classes

DistributedSignalHandler

Context manager to handle signals gracefully in a distributed setting.

Functions

get_device

Get the appropriate torch device based on the distributed backend.

all_gather_item

Perform an all_gather operation on a single Python object.

API

nemo_automodel.utils.sig_utils.get_device(local_rank: Optional[int] = None) → torch.device

Get the appropriate torch device based on the distributed backend.

Parameters:

local_rank – The local rank, used to specify the CUDA device index for NCCL. If None, uses the default CUDA device.

Returns:

The torch.device (‘cuda’ for NCCL, ‘cpu’ for Gloo).

Raises:

RuntimeError – If the distributed backend is neither ‘nccl’ nor ‘gloo’.
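The backend-to-device mapping described above can be sketched roughly as follows. This is an illustration only, not the library's implementation: `device_string_for_backend` is a hypothetical helper that takes the backend name as an argument and returns a torch-style device string, so that the sketch runs without torch or an initialized process group. In real code the backend name would presumably come from `torch.distributed.get_backend()` and the result would be a `torch.device`.

```python
from typing import Optional


def device_string_for_backend(backend: str, local_rank: Optional[int] = None) -> str:
    """Hypothetical helper: map a backend name to a torch-style device string."""
    if backend == "nccl":
        # NCCL implies CUDA; use the local rank as the device index when given,
        # otherwise fall back to the default CUDA device.
        return "cuda" if local_rank is None else f"cuda:{local_rank}"
    if backend == "gloo":
        # Gloo maps to the CPU device.
        return "cpu"
    # Any other backend is rejected, mirroring the documented RuntimeError.
    raise RuntimeError(f"Unsupported distributed backend: {backend!r}")
```

Under these assumptions, `device_string_for_backend("nccl", 3)` yields `"cuda:3"` and `device_string_for_backend("gloo")` yields `"cpu"`.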

nemo_automodel.utils.sig_utils.all_gather_item(
item: Any,
dtype: torch.dtype,
group: Optional[torch.distributed.ProcessGroup] = None,
async_op: bool = False,
local_rank: Optional[int] = None,
) → list[Any]

Perform an all_gather operation on a single Python object.

Converts the item to a tensor, performs all_gather, and converts back to a list of Python objects from all ranks.

Parameters:
  • item (Any) – The Python object to gather.

  • dtype (torch.dtype) – The torch dtype to use for the intermediate tensor.

  • group (Optional[torch.distributed.ProcessGroup]) – The process group to gather within (defaults to the global group).

  • async_op (bool) – Whether the operation should be asynchronous.

  • local_rank (Optional[int]) – The local rank to determine the device.

Returns:

A list containing the gathered items (of type Any) from all ranks in the group.

Return type:

list[Any]

class nemo_automodel.utils.sig_utils.DistributedSignalHandler(sig: int = signal.SIGTERM)

Context manager to handle signals gracefully in a distributed setting.

Installs a signal handler upon entering the context that sets a flag when the specified signal is received. The signals_received method can be used to check if any rank received the signal (using all_gather). The original signal handler is restored upon exiting the context.

Parameters:

sig – The signal number to handle (e.g., signal.SIGTERM). Defaults to signal.SIGTERM.

Initialization

Constructor for the DistributedSignalHandler.

Parameters:

sig (int, optional) – The signal to handle. Defaults to signal.SIGTERM.

signals_received() → list[bool]

Report, for each rank in the default group, whether that rank received the signal.

Uses all_gather to collect the signal status from all ranks.

Returns:

A list of booleans, where each element indicates if the corresponding rank received the signal.
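A typical way to consume the returned list is to poll it once per step and stop when any entry is True. The sketch below shows that loop shape under stated assumptions: `FakeHandler` is a hypothetical stand-in that mimics the documented interface (context manager plus `signals_received`) so the example runs in a single process, and `run_until_signal` is an illustrative helper, not part of the library. With the real handler, `signals_received` relies on all_gather, so every rank should call it at the same point in the loop.

```python
from typing import Callable, List


class FakeHandler:
    """Hypothetical stand-in: reports a signal on rank 1 from the third poll on."""

    def __init__(self) -> None:
        self.polls = 0

    def __enter__(self) -> "FakeHandler":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        return None

    def signals_received(self) -> List[bool]:
        self.polls += 1
        # One entry per rank (two ranks here); index 1 becomes True on poll 3.
        return [False, self.polls >= 3]


def run_until_signal(handler_factory: Callable[[], FakeHandler], max_steps: int) -> int:
    """Run up to max_steps and stop early once any rank reports the signal."""
    completed = 0
    with handler_factory() as handler:
        for _ in range(max_steps):
            completed += 1  # placeholder for one training step
            # Collective check: all ranks see the same list and stop together.
            if any(handler.signals_received()):
                break  # a checkpoint would usually be saved here before exiting
    return completed


steps_done = run_until_signal(FakeHandler, max_steps=10)
```

Using `any(...)` on the gathered list means all ranks make the same decision, even if only one of them was signalled.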

__enter__() → nemo_automodel.utils.sig_utils.DistributedSignalHandler

Enters the signal-managed area.

Returns:

The handler instance itself (self).

Return type:

DistributedSignalHandler

__exit__(
exc_type: Optional[type],
exc_val: BaseException | None,
exc_tb: types.TracebackType | None,
) → None

Release the signal handler and restore the original handler.

release() → bool

Restore the original signal handler.

Returns:

True if the handler was released, False if it was already released.
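The install, flag, and restore lifecycle described for this class can be illustrated with a single-process sketch built only on Python's standard `signal` module. `LocalSignalFlag` is a hypothetical name and this is not the library's implementation: it leaves out the distributed all_gather step entirely and only mirrors the documented behaviour of `__enter__`, `__exit__`, and `release`. It assumes it is used from the main thread, which `signal.signal` requires.

```python
import signal


class LocalSignalFlag:
    """Single-process sketch: set a flag on a signal, restore the old handler on exit."""

    def __init__(self, sig: int = signal.SIGTERM) -> None:
        self.sig = sig
        self.received = False
        self.released = True  # nothing is installed until __enter__ runs
        self._original = signal.SIG_DFL

    def __enter__(self) -> "LocalSignalFlag":
        previous = signal.getsignal(self.sig)
        # getsignal may return None for handlers not installed from Python;
        # fall back to the default action in that case.
        self._original = previous if previous is not None else signal.SIG_DFL

        def _handler(signum, frame) -> None:
            # Only record the event; the main loop decides how to react.
            self.received = True

        signal.signal(self.sig, _handler)
        self.released = False
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.release()

    def release(self) -> bool:
        if self.released:
            return False  # already restored, nothing to do
        signal.signal(self.sig, self._original)
        self.released = True
        return True
```

Keeping the handler body to a single flag assignment is a common design choice: the signal only marks that shutdown was requested, and the actual cleanup happens at a safe point in the main loop.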