# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Union
import ray
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from nemo_rl.distributed.batched_data_dict import SlicedDataDict
from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
from nemo_rl.utils.venvs import create_local_venv
@dataclass
class MultiWorkerFuture:
"""Container for Ray futures with associated worker information."""
futures: List[ray.ObjectRef]
used_workers: List[int]
respect_tied_workers: bool = True
def get_results(self, worker_group):
"""Get results from the futures, optionally respecting tied workers.
When respect_tied_workers is True, this method deduplicates results by returning
only one result per tied worker group.
The method uses worker_group.worker_to_tied_group_index to identify which tied
worker group each worker belongs to, then selects only the first result from each group.
Args:
worker_group: The RayWorkerGroup that created this bundle
Returns:
List of results, deduplicated by tied workers if respect_tied_workers is True
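
        Example:
            A minimal sketch (assumes ``worker_group`` is a RayWorkerGroup with two
            tied groups, workers 0/1 and 2/3, whose futures are ``f0``..``f3``)::

                bundle = MultiWorkerFuture(
                    futures=[f0, f1, f2, f3],
                    used_workers=[0, 1, 2, 3],
                )
                # respect_tied_workers defaults to True, so only the first result
                # of each tied group is returned (the results of f0 and f2).
                leader_results = bundle.get_results(worker_group)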
"""
# Basic case: Get all results
all_results = ray.get(self.futures)
# If we don't need to deduplicate by tied workers, return all results
if not self.respect_tied_workers:
return all_results
if not self.used_workers:
return all_results
# Create tied worker sets based on used workers
active_tied_workers = {}
for i, worker_idx in enumerate(self.used_workers):
tied_worker_idx = worker_group.worker_to_tied_group_index.get(worker_idx)
if tied_worker_idx is None:
continue
if tied_worker_idx not in active_tied_workers:
active_tied_workers[tied_worker_idx] = []
active_tied_workers[tied_worker_idx].append(i)
# Take the first result from each tied worker group
tied_worker_results = []
for tied_worker_idx in sorted(active_tied_workers.keys()):
if active_tied_workers[tied_worker_idx]:
result_idx = active_tied_workers[tied_worker_idx][0]
tied_worker_results.append(all_results[result_idx])
return tied_worker_results
class RayWorkerBuilder:
def __init__(self, ray_actor_class: type, *args, **kwargs):
self.ray_actor_class = ray_actor_class
self.args = args
self.kwargs = kwargs
def __call__(
self,
placement_group: PlacementGroup,
placement_group_bundle_index: int,
num_gpus: int,
bundle_indices: Optional[tuple] = None,
**extra_options: Dict[str, Any],
):
"""Create a Ray worker with the specified configuration.
Order of precedence for worker options configuration (from lowest to highest):
1. Options passed by the user to __call__ (extra_options)
2. Options required by the worker via configure_worker (may override user options with warning)
3. Options set by the RayWorkerBuilder.__call__ (specifically scheduling strategy)
If the worker needs to override user-provided options, it should log a warning
to inform the user about the change and the reason for it.
Args:
placement_group: Ray placement group for resource allocation
placement_group_bundle_index: Index of the bundle in the placement group
num_gpus: Number of GPUs to allocate to this worker
bundle_indices: Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable)
extra_options: Additional options to pass to the Ray actor (may be overridden by actor's configure_worker(...) method)
Returns:
A Ray actor reference to the created worker
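
        Example:
            A minimal sketch (``MyPolicyWorker``, ``pg``, and the keyword arguments
            are hypothetical; ``pg`` is a Ray placement group with GPU bundles)::

                builder = RayWorkerBuilder(MyPolicyWorker, model_name="llama")
                worker = builder(
                    placement_group=pg,
                    placement_group_bundle_index=0,
                    num_gpus=1,
                    bundle_indices=(0, [0, 1]),  # leader of a 2-way tied group on node 0
                )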
"""
# Set up worker arguments and resources
worker_class = self.ray_actor_class
worker_kwargs = dict(self.kwargs)
options = deepcopy(extra_options)
# Use the worker's configuration interface if available
if hasattr(worker_class, "configure_worker"):
# Get complete worker configuration from the worker class
resources, env_vars, init_kwargs = worker_class.configure_worker(
num_gpus=num_gpus,
bundle_indices=bundle_indices,
)
# Apply resource configuration
if resources and "num_gpus" in resources:
num_gpus = resources["num_gpus"]
# Apply environment variables if provided
            if env_vars:
                if "runtime_env" not in options:
                    options["runtime_env"] = {}
                # Guard against a runtime_env that has no env_vars entry yet.
                options["runtime_env"].setdefault("env_vars", {})
                for k, v in env_vars.items():
                    options["runtime_env"]["env_vars"][k] = v
# Apply initialization parameters
if init_kwargs:
worker_kwargs.update(init_kwargs)
# Create options for Ray actor
options["scheduling_strategy"] = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_bundle_index=placement_group_bundle_index,
placement_group_capture_child_tasks=True,
)
options["num_gpus"] = num_gpus
# If the user hasn't specified a py_executable, use the worker class's default
if not options.get("runtime_env", {}).get("py_executable", None) and hasattr(
worker_class, "DEFAULT_PY_EXECUTABLE"
):
if "runtime_env" not in options:
options["runtime_env"] = {}
options["runtime_env"]["py_executable"] = worker_class.DEFAULT_PY_EXECUTABLE
if options.get("runtime_env", {}).get("py_executable", "n/a").startswith("uv"):
# If the py_executable begins with uv it signals that we need to create a
# local venv first and then replace the py_executable with the local venv's python.
# The directory the venv will be created in is controlled by the env var
# NEMO_RL_VENV_DIR and defaults to $GIT_ROOT/venvs/.
unwrapped_cls = worker_class.__ray_actor_class__
venv_python = create_local_venv(
py_executable=options["runtime_env"]["py_executable"],
venv_name=f"{unwrapped_cls.__module__}.{unwrapped_cls.__name__}",
)
options["runtime_env"]["py_executable"] = venv_python
return worker_class.options(**options).remote(*self.args, **worker_kwargs)
class RayWorkerGroup:
"""Manages a group of distributed Ray worker/actor processes that execute tasks in parallel.
This class creates and manages Ray actor instances that run on resources
allocated by a RayVirtualCluster. It handles:
- Worker creation and placement on specific GPU resources
- Setting up distributed training environment variables (rank, world size, etc.)
- Executing methods across all workers in parallel
- Collecting and aggregating results
- Support for tied worker groups where multiple workers process the same data
"""
def __init__(
self,
cluster: RayVirtualCluster,
remote_worker_builder: RayWorkerBuilder,
workers_per_node: Optional[Union[int, List[int]]] = None,
name_prefix: str = "",
bundle_indices_list: Optional[List[tuple]] = None,
):
"""Initialize a group of distributed Ray workers.
Args:
            cluster: RayVirtualCluster that provides the placement groups and resources for the workers
            remote_worker_builder: Callable that launches a Ray worker and has updatable options
            workers_per_node: Defaults to launching one worker per bundle in the cluster.
                              Alternatively, specify an int or a per-node list to launch a
                              different number of workers on each node.
name_prefix: Optional prefix for the names of the workers
bundle_indices_list: Explicit list of (node_idx, [local_bundle_indices]) tuples.
Each tuple defines a tied group of workers placed on the same node.
If provided, workers_per_node is ignored.
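
        Example:
            A minimal sketch (``cluster`` is an existing RayVirtualCluster and
            ``MyPolicyWorker`` is a hypothetical Ray actor class)::

                builder = RayWorkerBuilder(MyPolicyWorker, model_name="llama")
                # One worker per bundle in the cluster, each in its own tied group.
                worker_group = RayWorkerGroup(cluster, builder, name_prefix="policy")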
"""
self._workers = []
self._worker_metadata = []
self.cluster = cluster
self.name_prefix = name_prefix
self.tied_workers_groups = []
# Maps worker indices to their corresponding tied group index
# For example, if worker with index 3 belongs to tied worker group 1,
# then worker_to_tied_group_index[3] = 1
self.worker_to_tied_group_index = {}
# If explicit bundle indices are provided, use those
if bundle_indices_list is None:
# Create bundle_indices_list from workers_per_node specification
# In this case, each worker is its own group (no tied workers)
bundle_indices_list = []
# Determine how many workers per node
if workers_per_node is None:
workers_per_node = [
pg.bundle_count for pg in self.cluster.get_placement_groups()
]
elif isinstance(workers_per_node, int):
workers_per_node = [workers_per_node] * self.cluster.node_count()
elif not isinstance(workers_per_node, list):
raise ValueError(
"workers_per_node must be None(for default node distribution), an int, or a list"
)
# Validate workers_per_node
assert len(workers_per_node) == self.cluster.node_count(), (
"workers_per_node_list must be the same length as the number of nodes in the virtual cluster"
)
assert all(
[
workers_per_node[i] <= pg.bundle_count
for i, pg in enumerate(self.cluster.get_placement_groups())
]
), (
"workers_per_node must be less than or equal to the number of bundles in the placement groups"
)
# Create bundle_indices_list where each worker is its own group
for node_idx, worker_count in enumerate(workers_per_node):
for local_idx in range(worker_count):
# Each worker is its own single-element group
bundle_indices_list.append((node_idx, [local_idx]))
# Create workers based on the bundle_indices_list
self._create_workers_from_bundle_indices(
remote_worker_builder, bundle_indices_list
)
def _create_workers_from_bundle_indices(
self, remote_worker_builder, bundle_indices_list
):
"""Create workers based on explicit bundle indices for tied worker groups.
Args:
remote_worker_builder: Builder function for Ray actors
bundle_indices_list: List of (node_idx, local_bundle_indices) tuples, where each tuple
specifies a tied group with its node and local bundle indices.
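
        Example:
            An illustrative bundle_indices_list for a 2-node cluster with 4 GPU
            bundles per node, forming one 4-way tied (tensor-parallel) group per node::

                bundle_indices_list = [(0, [0, 1, 2, 3]), (1, [0, 1, 2, 3])]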
"""
self.master_address, self.master_port = (
self.cluster.get_master_address_and_port()
)
# Count total workers
self.world_size = sum(len(indices) for _, indices in bundle_indices_list)
global_rank = 0
for group_idx, (node_idx, local_bundle_indices) in enumerate(
bundle_indices_list
):
current_group = []
# Get the placement group for this node
pg = self.cluster.get_placement_groups()[node_idx]
is_tp_group = len(local_bundle_indices) > 1
for local_rank, bundle_idx in enumerate(local_bundle_indices):
# Set up basic distributed environment variables
                env_vars = dict(
                    os.environ
                )  # Pass through all user environment variables (at the lowest precedence)
env_vars.update(
{
"RANK": str(global_rank),
"LOCAL_RANK": str(bundle_idx),
"WORLD_SIZE": str(self.world_size),
"MASTER_ADDR": self.master_address,
"MASTER_PORT": str(self.master_port),
"NODE_RANK": str(node_idx),
}
)
# For tensor parallel groups, only the first worker gets bundle_indices
worker_bundle_indices = (
(node_idx, local_bundle_indices) if local_rank == 0 else None
)
# Create a descriptive name based on group structure
name = (
f"{self.name_prefix}-grp{group_idx}-{local_rank}"
if is_tp_group
else f"{self.name_prefix}-{node_idx}-{bundle_idx}"
)
# Calculate GPU resources
num_gpus = (
1 / self.cluster.max_colocated_worker_groups
if self.cluster.use_gpus
else 0
)
# Pass these options to the remote_worker_builder
runtime_env = {"env_vars": env_vars}
extra_options = {"runtime_env": runtime_env, "name": name}
# Create the worker
worker = remote_worker_builder(
placement_group=pg,
placement_group_bundle_index=bundle_idx,
num_gpus=num_gpus,
bundle_indices=worker_bundle_indices,
**extra_options,
)
# Store worker metadata
worker_idx = len(self._workers)
current_group.append(worker_idx)
self.worker_to_tied_group_index[worker_idx] = group_idx
self._workers.append(worker)
self._worker_metadata.append(
{
"node_idx": node_idx,
"local_rank": local_rank,
"global_rank": global_rank,
"name": name,
"bundle_indices": worker_bundle_indices,
"tied_group_idx": group_idx,
}
)
global_rank += 1
# Add this tied group to our list
self.tied_workers_groups.append(current_group)
@property
def workers(self):
return self._workers
@property
def worker_metadata(self):
return self._worker_metadata
@property
def group_count(self):
"""Number of tied worker groups."""
return len(self.tied_workers_groups)
def run_all_workers_multiple_data(
self,
method_name: str,
data: List[SlicedDataDict],
common_kwargs: Optional[Dict[str, Any]] = None,
only_on: Literal["all", "tied_leader", "all_tied_workers"] = "all",
):
"""Run a method on all workers in parallel with different data.
Args:
method_name: Name of the method to call on each worker
data: List of data slices to pass to workers/groups
common_kwargs: Additional keyword arguments to pass to all workers
only_on: Determines which workers receive data and execute the method:
- "all": Each worker gets its own data slice
- "tied_leader": Only the first worker in each tied group receives data
- "all_tied_workers": All workers in each tied group receive the same data slice
Returns:
MultiWorkerFuture: Object containing futures and their associated worker information
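
        Example:
            A minimal sketch of the only_on options (the ``train_step`` method, the
            ``lr`` kwarg, and ``data_shards``, one SlicedDataDict per tied group,
            are hypothetical)::

                bundle = worker_group.run_all_workers_multiple_data(
                    "train_step",
                    data=data_shards,
                    common_kwargs={"lr": 1e-4},
                    only_on="all_tied_workers",
                )
                results = worker_group.get_all_worker_results(bundle)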
"""
# Verify that the data is a list of SlicedDataDict objects
if not all(isinstance(d, SlicedDataDict) for d in data):
            warnings.warn(
                f"Expected all elements in 'data' to be of type SlicedDataDict, but got "
                f"{[type(d).__name__ for d in data]}. This may cause unexpected behavior. "
                f"Please make sure you're passing sharded data (not replicated data) to this function.",
                UserWarning,
            )
if common_kwargs is None:
common_kwargs = {}
futures = []
used_workers = []
respect_tied_workers = only_on in {"tied_leader", "all_tied_workers"}
if only_on == "all":
# Regular case - each worker gets its own data slice
for worker_id, worker in enumerate(self.workers):
if worker_id >= len(data):
break
method = getattr(worker, method_name)
futures.append(method.remote(data[worker_id], **common_kwargs))
used_workers.append(worker_id)
elif respect_tied_workers:
# If there are fewer data slices than tied worker groups, use only the first N tied worker groups
active_tied_worker_count = min(len(data), len(self.tied_workers_groups))
if active_tied_worker_count < len(self.tied_workers_groups):
print(
f"Warning: Using only {active_tied_worker_count} of {len(self.tied_workers_groups)} tied worker groups due to limited data slices"
)
# For each tied worker group, all workers in the group get the same data slice
for tied_worker_idx in range(active_tied_worker_count):
tied_worker_group = self.tied_workers_groups[tied_worker_idx]
tied_worker_data = data[tied_worker_idx]
if only_on == "all_tied_workers":
# Running on all workers in the non-vllm case
for worker_idx in tied_worker_group:
futures.append(
getattr(self._workers[worker_idx], method_name).remote(
tied_worker_data, **common_kwargs
)
)
used_workers.append(worker_idx)
else:
# Running only on the leader of the tied worker group for vllm case
futures.append(
getattr(
self._workers[tied_worker_group[0]], method_name
).remote(tied_worker_data, **common_kwargs)
)
used_workers.append(tied_worker_group[0])
else:
raise ValueError(f"Invalid value for only_on: {only_on}")
# Return a MultiWorkerFuture containing both futures and worker information
return MultiWorkerFuture(
futures=futures,
used_workers=used_workers,
respect_tied_workers=respect_tied_workers,
)
def run_all_workers_single_data(
self,
method_name: str,
*args,
only_on: Literal["all", "tied_leader", "all_tied_workers"] = "all",
**kwargs,
):
"""Run a method on all workers in parallel with the same data.
Args:
method_name: Name of the method to call on each worker
only_on: Determines which workers to run the method on:
- "all": Run on all workers
- "tied_leader": Run only on the first worker of each tied worker group
- "all_tied_workers": Run on all workers in each tied worker group
*args, **kwargs: Arguments to pass to the method
Returns:
List[ray.ObjectRef]: A list of ray futures
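
        Example:
            A minimal sketch (the ``save_checkpoint`` method and its path argument
            are hypothetical)::

                # Broadcast the same call to the leader of every tied group.
                futures = worker_group.run_all_workers_single_data(
                    "save_checkpoint", "/tmp/ckpt", only_on="tied_leader"
                )
                ray.get(futures)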
"""
futures = []
respect_tied_workers = only_on in {"tied_leader", "all_tied_workers"}
if only_on == "all":
for worker in self.workers:
method = getattr(worker, method_name)
futures.append(method.remote(*args, **kwargs))
elif respect_tied_workers:
for tied_worker_group in self.tied_workers_groups:
if only_on == "all_tied_workers":
# Running on all workers in the non-vllm case
for worker_idx in tied_worker_group:
futures.append(
getattr(self._workers[worker_idx], method_name).remote(
*args, **kwargs
)
)
else:
futures.append(
getattr(
self._workers[tied_worker_group[0]], method_name
).remote(*args, **kwargs)
)
else:
raise ValueError(f"Invalid value for only_on: {only_on}")
return futures
def get_all_worker_results(self, future_bundle):
"""Get results from all workers, optionally filtering to get just one result per tied worker group.
Args:
future_bundle: MultiWorkerFuture containing futures and worker information.
When future_bundle.respect_tied_workers is True, only results from
the leaders of tied worker groups are returned.
Returns:
List of results, deduplicated as specified in the future_bundle
"""
return future_bundle.get_results(self)
def shutdown(
self,
cleanup_method: Optional[str] = None,
timeout: Optional[float] = 30.0,
force: bool = False,
):
"""Shutdown all workers in the worker group.
Args:
cleanup_method: Optional method name to call on each worker before termination.
If provided, this method will be called on each worker to allow
for graceful cleanup.
timeout: Timeout in seconds for graceful shutdown. Only applicable if cleanup_method is provided.
If None, wait indefinitely for workers to complete their cleanup.
force: If True, forcefully terminate workers with ray.kill() even if cleanup_method is provided.
If cleanup_method is None, workers are always forcefully terminated.
Returns:
bool: True if all workers were successfully shut down
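
        Example:
            A minimal sketch (``cleanup`` is a hypothetical worker method)::

                # Attempt a graceful cleanup, then force-kill anything still alive.
                worker_group.shutdown(cleanup_method="cleanup", timeout=60.0)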
"""
if not self._workers:
return True
success = True
# First attempt graceful shutdown if cleanup method is provided and force=False
if cleanup_method is not None and not force:
try:
# Call cleanup method on all workers
futures = self.run_all_workers_single_data(cleanup_method)
# Wait for all cleanup operations to complete with timeout
if timeout is not None:
ray.get(futures, timeout=timeout)
else:
ray.get(futures)
except (ray.exceptions.RayTaskError, ray.exceptions.GetTimeoutError) as e:
success = False
print(
f"Error during graceful shutdown: {e}. Falling back to force termination."
)
force = True
# Force kill any remaining workers
if force or cleanup_method is None:
for worker in self._workers:
try:
ray.kill(worker)
except Exception as e:
success = False
print(f"Error killing worker: {e}")
# Clear worker lists
self._workers = []
self._worker_metadata = []
self.tied_workers_groups = []
self.worker_to_tied_group_index = {}
return success
def print_worker_layout(self):
"""Prints a visual representation of the worker layout across the virtual cluster.
This shows which workers are assigned to which nodes and GPUs.
"""
self.cluster.print_cluster_grid(self)