Source code for nemo_automodel.utils.sig_utils

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import signal
import types
from typing import Any, Optional

import torch
import torch.distributed


def get_device(local_rank: Optional[int] = None) -> torch.device:
    """
    Get the appropriate torch device based on the distributed backend.

    Args:
        local_rank: The local rank, used to specify the CUDA device index
            for NCCL. If None, uses the default CUDA device.

    Returns:
        The torch.device ('cuda' for NCCL, 'cpu' for Gloo).

    Raises:
        RuntimeError: If the distributed backend is neither 'nccl' nor 'gloo'.
    """
    backend = torch.distributed.get_backend()
    if backend == "nccl":
        if local_rank is None:
            device = torch.device("cuda")
        else:
            device = torch.device(f"cuda:{local_rank}")
    elif backend == "gloo":
        device = torch.device("cpu")
    else:
        raise RuntimeError(f"Unsupported distributed backend: {backend}")
    return device
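
# Example (illustrative sketch, not part of the module): selecting the device
# after a process group has been initialized. ``LOCAL_RANK`` is the environment
# variable set by ``torchrun``; the surrounding usage is hypothetical.
#
#     import os
#     import torch.distributed as dist
#
#     dist.init_process_group(backend="nccl")
#     device = get_device(local_rank=int(os.environ["LOCAL_RANK"]))
#     tensor = torch.zeros(1, device=device)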

def all_gather_item(
    item: Any,
    dtype: torch.dtype,
    group: Optional[torch.distributed.ProcessGroup] = None,
    async_op: bool = False,
    local_rank: Optional[int] = None,
) -> list[Any]:
    """Perform an all_gather operation on a single Python object.

    Converts the item to a tensor, performs all_gather, and converts back
    to a list of Python objects from all ranks.

    Args:
        item (Any): The Python object to gather.
        dtype (torch.dtype): The torch dtype to use for the intermediate tensor.
        group (Optional[torch.distributed.ProcessGroup]): The process group to
            gather within (defaults to the global group).
        async_op (bool): Whether the operation should be asynchronous.
        local_rank (Optional[int]): The local rank to determine the device.

    Returns:
        list[Any]: A list containing the gathered items (of type Any) from all
            ranks in the group.
    """
    if not torch.distributed.is_available() or not torch.distributed.is_initialized():
        return [item]

    device = get_device(local_rank)
    if group is not None:
        group_size = group.size()
    else:
        group_size = torch.distributed.get_world_size()

    tensor = torch.tensor([item], device=device, dtype=dtype)
    output_tensors = [
        torch.zeros(1, dtype=tensor.dtype, device=tensor.device) for _ in range(group_size)
    ]
    torch.distributed.all_gather(output_tensors, tensor, group=group, async_op=async_op)
    output = [elem.item() for elem in output_tensors]
    return output
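
# Example (illustrative sketch, not part of the module): gathering a per-rank
# scalar onto every rank. Assumes the script is launched with ``torchrun`` so
# the default env:// initialization works; the value being gathered is
# hypothetical.
#
#     import torch.distributed as dist
#
#     dist.init_process_group(backend="gloo")
#     local_step = dist.get_rank() * 10
#     all_steps = all_gather_item(local_step, dtype=torch.int64)
#     # On a 2-rank job this yields [0, 10] on every rank.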

class DistributedSignalHandler:
    """
    Context manager to handle signals gracefully in a distributed setting.

    Installs a signal handler upon entering the context that sets a flag when
    the specified signal is received. The `signals_received` method can be
    used to check if any rank received the signal (using all_gather). The
    original signal handler is restored upon exiting the context.

    Args:
        sig: The signal number to handle (e.g., signal.SIGTERM).
            Defaults to signal.SIGTERM.
    """

    def __init__(self, sig: int = signal.SIGTERM) -> None:
        """
        Constructor for the DistributedSignalHandler.

        Args:
            sig (int, optional): The signal to handle. Defaults to signal.SIGTERM.
        """
        self.sig = sig
        self._signal_received = False
        self.released = False
        self.original_handler = None

    def signals_received(self) -> list[bool]:
        """
        Check if any rank in the default group received the signal.

        Uses all_gather to collect the signal status from all ranks.

        Returns:
            A list of booleans, where each element indicates if the
            corresponding rank received the signal.
        """
        all_received = all_gather_item(self._signal_received, dtype=torch.int32)
        return [bool(flag) for flag in all_received]

    def __enter__(self) -> "DistributedSignalHandler":
        """
        Enter the signal-managed context.

        Returns:
            DistributedSignalHandler: returns self.
        """
        self._signal_received = False
        self.released = False
        self.original_handler = signal.getsignal(self.sig)

        def handler(signum: int, frame: Optional[Any]) -> None:
            logging.info("Received signal %s, initiating graceful stop", signum)
            self._signal_received = True

        signal.signal(self.sig, handler)
        logging.info("Signal handler installed for %s", self.sig)
        return self

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Optional[types.TracebackType],
    ) -> None:
        """
        Release the signal handler and restore the original handler.
        """
        self.release()

    def release(self) -> bool:
        """
        Restore the original signal handler.

        Returns:
            True if the handler was released, False if it was already released.
        """
        if self.released:
            return False
        signal.signal(self.sig, self.original_handler)
        self.released = True
        return True
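
# Example (illustrative sketch, not part of the module): stopping a training
# loop cleanly when any rank receives SIGTERM (e.g. on cluster preemption).
# ``dataloader``, ``train_step``, and ``save_checkpoint`` are hypothetical.
#
#     with DistributedSignalHandler(signal.SIGTERM) as sig_handler:
#         for step, batch in enumerate(dataloader):
#             train_step(batch)
#             if any(sig_handler.signals_received()):
#                 save_checkpoint(step)  # hypothetical helper
#                 break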