# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Union
import ray
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from nemo_rl.distributed.batched_data_dict import SlicedDataDict
from nemo_rl.distributed.ray_actor_environment_registry import (
get_actor_python_env,
)
from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
from nemo_rl.utils.venvs import create_local_venv
@dataclass
class MultiWorkerFuture:
"""Container for Ray futures with associated worker information."""
futures: List[ray.ObjectRef]
used_workers: List[int]
respect_tied_workers: bool = True
def get_results(self, worker_group):
"""Get results from the futures, optionally respecting tied workers.
When respect_tied_workers is True, this method deduplicates results by returning
only one result per tied worker group.
The method uses worker_group.worker_to_tied_group_index to identify which tied
worker group each worker belongs to, then selects only the first result from each group.
Args:
worker_group: The RayWorkerGroup that created this bundle
Returns:
List of results, deduplicated by tied workers if respect_tied_workers is True
"""
# Basic case: Get all results
all_results = ray.get(self.futures)
# If we don't need to deduplicate by tied workers, return all results
if not self.respect_tied_workers:
return all_results
if not self.used_workers:
return all_results
# Create tied worker sets based on used workers
active_tied_workers = {}
for i, worker_idx in enumerate(self.used_workers):
tied_worker_idx = worker_group.worker_to_tied_group_index.get(worker_idx)
if tied_worker_idx is None:
continue
if tied_worker_idx not in active_tied_workers:
active_tied_workers[tied_worker_idx] = []
active_tied_workers[tied_worker_idx].append(i)
# Take the first result from each tied worker group
tied_worker_results = []
for tied_worker_idx in sorted(active_tied_workers.keys()):
if active_tied_workers[tied_worker_idx]:
result_idx = active_tied_workers[tied_worker_idx][0]
tied_worker_results.append(all_results[result_idx])
return tied_worker_results
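# Illustrative sketch (comment only; the four results and the two tied groups are
# hypothetical): with worker_group.worker_to_tied_group_index == {0: 0, 1: 0, 2: 1, 3: 1},
# a bundle built from four futures with respect_tied_workers=True returns one result per group:
#
#   bundle = MultiWorkerFuture(futures=futures, used_workers=[0, 1, 2, 3])
#   bundle.get_results(worker_group)  # -> [result_from_worker_0, result_from_worker_2]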
class RayWorkerBuilder:
"""Deferred factory for Ray actor workers.

Stores the fully qualified actor class name plus its init args/kwargs, and materializes
the actor later (via __call__) on a specific placement group bundle, resolving the
py_executable/venv and any worker-specific configuration at creation time.
"""
@ray.remote
class IsolatedWorkerInitializer:
"""Helper Ray actor that imports and instantiates the real worker class from within the target runtime_env (py_executable/venv) rather than in the calling process."""
def __init__(self, ray_actor_class_fqn: str, *init_args, **init_kwargs):
self.ray_actor_class_fqn = ray_actor_class_fqn
self.init_args = init_args
self.init_kwargs = init_kwargs
def create_worker(
self,
placement_group: PlacementGroup,
placement_group_bundle_index: int,
num_gpus: int,
bundle_indices: Optional[tuple] = None,
**extra_options: Optional[Dict[str, Any]],
):
"""Create a Ray worker with the specified configuration.
Order of precedence for worker options configuration (from lowest to highest):
1. Options passed by the user to __call__ (extra_options)
2. Options required by the worker via configure_worker (may override user options with warning)
3. Options set by the RayWorkerBuilder.__call__ (specifically scheduling strategy)
If the worker needs to override user-provided options, it should log a warning
to inform the user about the change and the reason for it.
Args:
placement_group: Ray placement group for resource allocation
placement_group_bundle_index: Index of the bundle in the placement group
num_gpus: Number of GPUs to allocate to this worker
bundle_indices: Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable)
extra_options: Additional options to pass to the Ray actor (may be overridden by actor's configure_worker(...) method)
Returns:
A Ray actor reference to the created worker
"""
# Set up worker arguments and resources
module_name, class_name = self.ray_actor_class_fqn.rsplit(".", 1)
module = importlib.import_module(module_name)
worker_class = getattr(module, class_name)
worker_kwargs = dict(self.init_kwargs)
options = deepcopy(extra_options)
# Use the worker's configuration interface if available
if hasattr(worker_class, "configure_worker"):
# Get complete worker configuration from the worker class
resources, env_vars, init_kwargs = worker_class.configure_worker(
num_gpus=num_gpus,
bundle_indices=bundle_indices,
)
# Apply resource configuration
if resources and "num_gpus" in resources:
num_gpus = resources["num_gpus"]
# Apply environment variables if provided
if env_vars:
if "runtime_env" not in options:
options["runtime_env"] = {"env_vars": {}}
if "env_vars" not in options["runtime_env"]:
options["runtime_env"]["env_vars"] = {}
for k, v in env_vars.items():
options["runtime_env"]["env_vars"][k] = v
# Apply initialization parameters
if init_kwargs:
worker_kwargs.update(init_kwargs)
# Create options for Ray actor
options["scheduling_strategy"] = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_bundle_index=placement_group_bundle_index,
placement_group_capture_child_tasks=True,
)
options["num_gpus"] = num_gpus
# If the user hasn't specified a py_executable, use the worker class's default
if not options.get("runtime_env", {}).get("py_executable", None):
if "runtime_env" not in options:
options["runtime_env"] = {}
options["runtime_env"]["py_executable"] = get_actor_python_env(
self.ray_actor_class_fqn
)
if (
options.get("runtime_env", {})
.get("py_executable", "n/a")
.startswith("uv")
):
# If the py_executable begins with "uv", it signals that we need to create a
# local venv first and then replace the py_executable with the local venv's python.
# The directory the venv will be created in is controlled by the env var
# NEMO_RL_VENV_DIR and defaults to $GIT_ROOT/venvs/.
venv_python = create_local_venv(
py_executable=options["runtime_env"]["py_executable"],
venv_name=self.ray_actor_class_fqn,
)
options["runtime_env"]["py_executable"] = venv_python
worker = worker_class.options(**options).remote(
*self.init_args, **worker_kwargs
)
return worker
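# Sketch of the configure_worker hook consumed above (the worker class, env var name, and
# kwargs below are hypothetical, and it is shown as a staticmethod purely for illustration;
# the contract is simply "return (resources, env_vars, init_kwargs)"):
#
#   class MyWorker:
#       @staticmethod
#       def configure_worker(num_gpus, bundle_indices=None):
#           resources = {"num_gpus": num_gpus}                 # may adjust the GPU fraction
#           env_vars = {"MY_TP_BUNDLE": str(bundle_indices)}   # merged into runtime_env["env_vars"]
#           init_kwargs = {"enable_tp": bundle_indices is not None}
#           return resources, env_vars, init_kwargs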
def __init__(self, ray_actor_class_fqn: str, *args, **kwargs):
self.ray_actor_class_fqn = ray_actor_class_fqn
self.args = args
self.kwargs = kwargs
def __call__(
self,
placement_group: PlacementGroup,
placement_group_bundle_index: int,
num_gpus: int,
bundle_indices: Optional[tuple] = None,
**extra_options: Dict[str, Any],
):
"""Create a Ray worker with the specified configuration.
Order of precedence for worker options configuration (from lowest to highest):
1. Options passed by the user to __call__ (extra_options)
2. Options required by the worker via configure_worker (may override user options with warning)
3. Options set by the RayWorkerBuilder.__call__ (specifically scheduling strategy)
If the worker needs to override user-provided options, it should log a warning
to inform the user about the change and the reason for it.
Args:
placement_group: Ray placement group for resource allocation
placement_group_bundle_index: Index of the bundle in the placement group
num_gpus: Number of GPUs to allocate to this worker
bundle_indices: Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable)
extra_options: Additional options to pass to the Ray actor (may be overridden by actor's configure_worker(...) method)
Returns:
A Ray actor reference to the created worker
"""
# Set up worker arguments and resources
options = deepcopy(extra_options)
# If the user hasn't specified a py_executable, use the worker class's default
initializer_options = {}
if not options.get("runtime_env", {}).get("py_executable", None):
if "runtime_env" not in options:
options["runtime_env"] = {}
options["runtime_env"]["py_executable"] = get_actor_python_env(
self.ray_actor_class_fqn
)
if options.get("runtime_env", {}).get("py_executable", "n/a").startswith("uv"):
# If the py_executable begins with "uv", it signals that we need to create a
# local venv first and then replace the py_executable with the local venv's python.
# The directory the venv will be created in is controlled by the env var
# NEMO_RL_VENV_DIR and defaults to $GIT_ROOT/venvs/.
venv_python = create_local_venv(
py_executable=options["runtime_env"]["py_executable"],
venv_name=self.ray_actor_class_fqn,
)
options["runtime_env"]["py_executable"] = venv_python
initializer_options = {"runtime_env": options["runtime_env"]}
isolated_initializer = self.IsolatedWorkerInitializer.options(
**initializer_options
).remote(self.ray_actor_class_fqn, *self.args, **self.kwargs)
worker = ray.get(
isolated_initializer.create_worker.remote(
placement_group,
placement_group_bundle_index,
num_gpus,
bundle_indices,
**options,
)
)
# Hold a reference to the initializer actor so it is not garbage collected (which would kill the child, 'real' actor)
worker._RAY_INITIALIZER_ACTOR_REF_TO_AVOID_GC = isolated_initializer
return worker
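# Usage sketch (comment only; the actor FQN, placement group `pg`, and extra options are
# hypothetical):
#
#   builder = RayWorkerBuilder("my_pkg.workers.MyWorker", config=worker_config)
#   worker = builder(
#       placement_group=pg,
#       placement_group_bundle_index=0,
#       num_gpus=1,
#       bundle_indices=(0, [0, 1]),  # leader of a 2-way tensor-parallel tied group
#       runtime_env={"env_vars": {"RANK": "0"}},
#   )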
class RayWorkerGroup:
"""Manages a group of distributed Ray worker/actor processes that execute tasks in parallel.
This class creates and manages Ray actor instances that run on resources
allocated by a RayVirtualCluster. It handles:
- Worker creation and placement on specific GPU resources
- Setting up distributed training environment variables (rank, world size, etc.)
- Executing methods across all workers in parallel
- Collecting and aggregating results
- Support for tied worker groups where multiple workers process the same data
"""
def __init__(
self,
cluster: RayVirtualCluster,
remote_worker_builder: RayWorkerBuilder,
workers_per_node: Optional[Union[int, List[int]]] = None,
name_prefix: str = "",
bundle_indices_list: Optional[List[tuple]] = None,
):
"""Initialize a group of distributed Ray workers.
Args:
cluster: RayVirtualCluster that provides the placement groups and GPU bundles the workers are scheduled onto
remote_worker_builder: Callable that launches a ray worker and has updatable options
workers_per_node: Defaults to launching one worker per bundle in the cluster.
Alternatively, specify an int or a list to launch a different number of workers per node.
name_prefix: Optional prefix for the names of the workers
bundle_indices_list: Explicit list of (node_idx, [local_bundle_indices]) tuples.
Each tuple defines a tied group of workers placed on the same node.
If provided, workers_per_node is ignored.
"""
self._workers = []
self._worker_metadata = []
self.cluster = cluster
self.name_prefix = name_prefix
self.tied_workers_groups = []
# Maps worker indices to their corresponding tied group index
# For example, if worker with index 3 belongs to tied worker group 1,
# then worker_to_tied_group_index[3] = 1
self.worker_to_tied_group_index = {}
# If explicit bundle indices are provided, use those
if bundle_indices_list is None:
# Create bundle_indices_list from workers_per_node specification
# In this case, each worker is its own group (no tied workers)
bundle_indices_list = []
# Determine how many workers per node
if workers_per_node is None:
workers_per_node = [
pg.bundle_count for pg in self.cluster.get_placement_groups()
]
elif isinstance(workers_per_node, int):
workers_per_node = [workers_per_node] * self.cluster.node_count()
elif not isinstance(workers_per_node, list):
raise ValueError(
"workers_per_node must be None(for default node distribution), an int, or a list"
)
# Validate workers_per_node
assert len(workers_per_node) == self.cluster.node_count(), (
"workers_per_node_list must be the same length as the number of nodes in the virtual cluster"
)
assert all(
[
workers_per_node[i] <= pg.bundle_count
for i, pg in enumerate(self.cluster.get_placement_groups())
]
), (
"workers_per_node must be less than or equal to the number of bundles in the placement groups"
)
# Create bundle_indices_list where each worker is its own group
for node_idx, worker_count in enumerate(workers_per_node):
for local_idx in range(worker_count):
# Each worker is its own single-element group
bundle_indices_list.append((node_idx, [local_idx]))
# Create workers based on the bundle_indices_list
self._create_workers_from_bundle_indices(
remote_worker_builder, bundle_indices_list
)
def _create_workers_from_bundle_indices(
self, remote_worker_builder, bundle_indices_list
):
"""Create workers based on explicit bundle indices for tied worker groups.
Args:
remote_worker_builder: Builder function for Ray actors
bundle_indices_list: List of (node_idx, local_bundle_indices) tuples, where each tuple
specifies a tied group with its node and local bundle indices.
"""
self.master_address, self.master_port = (
self.cluster.get_master_address_and_port()
)
# Count total workers
self.world_size = sum(len(indices) for _, indices in bundle_indices_list)
global_rank = 0
for group_idx, (node_idx, local_bundle_indices) in enumerate(
bundle_indices_list
):
current_group = []
# Get the placement group for this node
pg = self.cluster.get_placement_groups()[node_idx]
is_tp_group = len(local_bundle_indices) > 1
for local_rank, bundle_idx in enumerate(local_bundle_indices):
# Set up basic distributed environment variables
env_vars = dict(
os.environ
) # Pass through all user environment variables (at the lowest precedence)
env_vars.update(
{
"RANK": str(global_rank),
"LOCAL_RANK": str(bundle_idx),
"WORLD_SIZE": str(self.world_size),
"MASTER_ADDR": self.master_address,
"MASTER_PORT": str(self.master_port),
"NODE_RANK": str(node_idx),
}
)
# For tensor parallel groups, only the first worker gets bundle_indices
worker_bundle_indices = (
(node_idx, local_bundle_indices) if local_rank == 0 else None
)
# Create a descriptive name based on group structure
name = (
f"{self.name_prefix}-grp{group_idx}-{local_rank}"
if is_tp_group
else f"{self.name_prefix}-{node_idx}-{bundle_idx}"
)
# Calculate GPU resources
num_gpus = (
1 / self.cluster.max_colocated_worker_groups
if self.cluster.use_gpus
else 0
)
# Pass these options to the remote_worker_builder
runtime_env = {"env_vars": env_vars}
extra_options = {"runtime_env": runtime_env, "name": name}
# Create the worker
worker = remote_worker_builder(
placement_group=pg,
placement_group_bundle_index=bundle_idx,
num_gpus=num_gpus,
bundle_indices=worker_bundle_indices,
**extra_options,
)
# Store worker metadata
worker_idx = len(self._workers)
current_group.append(worker_idx)
self.worker_to_tied_group_index[worker_idx] = group_idx
self._workers.append(worker)
self._worker_metadata.append(
{
"node_idx": node_idx,
"local_rank": local_rank,
"global_rank": global_rank,
"name": name,
"bundle_indices": worker_bundle_indices,
"tied_group_idx": group_idx,
}
)
global_rank += 1
# Add this tied group to our list
self.tied_workers_groups.append(current_group)
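# Worked example of the resulting layout for bundle_indices_list=[(0, [0, 1]), (1, [0, 1])]
# (addresses/ports omitted; every worker also receives WORLD_SIZE=4 and the shared
# MASTER_ADDR/MASTER_PORT):
#
#   worker 0: RANK=0 LOCAL_RANK=0 NODE_RANK=0  bundle_indices=(0, [0, 1])  (tied-group leader)
#   worker 1: RANK=1 LOCAL_RANK=1 NODE_RANK=0  bundle_indices=None
#   worker 2: RANK=2 LOCAL_RANK=0 NODE_RANK=1  bundle_indices=(1, [0, 1])  (tied-group leader)
#   worker 3: RANK=3 LOCAL_RANK=1 NODE_RANK=1  bundle_indices=None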
@property
def workers(self):
return self._workers
@property
def worker_metadata(self):
return self._worker_metadata
@property
def group_count(self):
"""Number of tied worker groups."""
return len(self.tied_workers_groups)
def run_all_workers_multiple_data(
self,
method_name: str,
data: List[SlicedDataDict],
common_kwargs: Optional[Dict[str, Any]] = None,
only_on: Literal["all", "tied_leader", "all_tied_workers"] = "all",
):
"""Run a method on all workers in parallel with different data.
Args:
method_name: Name of the method to call on each worker
data: List of data slices to pass to workers/groups
common_kwargs: Additional keyword arguments to pass to all workers
only_on: Determines which workers receive data and execute the method:
- "all": Each worker gets its own data slice
- "tied_leader": Only the first worker in each tied group receives data
- "all_tied_workers": All workers in each tied group receive the same data slice
Returns:
MultiWorkerFuture: Object containing futures and their associated worker information
"""
# Verify that the data is a list of SlicedDataDict objects
if not all(isinstance(d, SlicedDataDict) for d in data):
warnings.warn(
f"Expected all elements in 'data' to be of type SlicedDataDict, but got "
f"{[type(d).__name__ for d in data]}. This may cause unexpected behavior. "
f"Please use make sure you're passing in Sharded Data to this function (and not replicated data)",
UserWarning,
)
if common_kwargs is None:
common_kwargs = {}
futures = []
used_workers = []
respect_tied_workers = only_on in {"tied_leader", "all_tied_workers"}
if only_on == "all":
# Regular case - each worker gets its own data slice
for worker_id, worker in enumerate(self.workers):
if worker_id >= len(data):
break
method = getattr(worker, method_name)
futures.append(method.remote(data[worker_id], **common_kwargs))
used_workers.append(worker_id)
elif respect_tied_workers:
# If there are fewer data slices than tied worker groups, use only the first N tied worker groups
active_tied_worker_count = min(len(data), len(self.tied_workers_groups))
if active_tied_worker_count < len(self.tied_workers_groups):
print(
f"Warning: Using only {active_tied_worker_count} of {len(self.tied_workers_groups)} tied worker groups due to limited data slices"
)
# For each tied worker group, all workers in the group get the same data slice
for tied_worker_idx in range(active_tied_worker_count):
tied_worker_group = self.tied_workers_groups[tied_worker_idx]
tied_worker_data = data[tied_worker_idx]
if only_on == "all_tied_workers":
# Running on all workers in the tied group (the non-vLLM case)
for worker_idx in tied_worker_group:
futures.append(
getattr(self._workers[worker_idx], method_name).remote(
tied_worker_data, **common_kwargs
)
)
used_workers.append(worker_idx)
else:
# Running only on the leader of the tied worker group (the vLLM case)
futures.append(
getattr(
self._workers[tied_worker_group[0]], method_name
).remote(tied_worker_data, **common_kwargs)
)
used_workers.append(tied_worker_group[0])
else:
raise ValueError(f"Invalid value for only_on: {only_on}")
# Return a MultiWorkerFuture containing both futures and worker information
return MultiWorkerFuture(
futures=futures,
used_workers=used_workers,
respect_tied_workers=respect_tied_workers,
)
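# Usage sketch (comment only; "train_step" and the shard list are hypothetical). Each tied
# group receives one shard, and deduplicated results come back via get_all_worker_results:
#
#   bundle = group.run_all_workers_multiple_data(
#       "train_step",
#       sharded_data,                       # list[SlicedDataDict], one shard per tied group
#       common_kwargs={"eval_mode": False},
#       only_on="all_tied_workers",
#   )
#   results = group.get_all_worker_results(bundle)  # one result per tied group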
def run_all_workers_single_data(
self,
method_name: str,
*args,
only_on: Literal["all", "tied_leader", "all_tied_workers"] = "all",
**kwargs,
):
"""Run a method on all workers in parallel with the same data.
Args:
method_name: Name of the method to call on each worker
only_on: Determines which workers to run the method on:
- "all": Run on all workers
- "tied_leader": Run only on the first worker of each tied worker group
- "all_tied_workers": Run on all workers in each tied worker group
*args, **kwargs: Arguments to pass to the method
Returns:
List[ray.ObjectRef]: A list of ray futures
"""
futures = []
respect_tied_workers = only_on in {"tied_leader", "all_tied_workers"}
if only_on == "all":
for worker in self.workers:
method = getattr(worker, method_name)
futures.append(method.remote(*args, **kwargs))
elif respect_tied_workers:
for tied_worker_group in self.tied_workers_groups:
if only_on == "all_tied_workers":
# Running on all workers in the tied group (the non-vLLM case)
for worker_idx in tied_worker_group:
futures.append(
getattr(self._workers[worker_idx], method_name).remote(
*args, **kwargs
)
)
else:
futures.append(
getattr(
self._workers[tied_worker_group[0]], method_name
).remote(*args, **kwargs)
)
else:
raise ValueError(f"Invalid value for only_on: {only_on}")
return futures
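# Usage sketch (comment only; "prepare_refit_info" is a hypothetical worker method):
#
#   futures = group.run_all_workers_single_data("prepare_refit_info", only_on="tied_leader")
#   ray.get(futures)  # plain list of futures; no tied-group deduplication is applied here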
def get_all_worker_results(self, future_bundle):
"""Get results from all workers, optionally filtering to get just one result per tied worker group.
Args:
future_bundle: MultiWorkerFuture containing futures and worker information.
When future_bundle.respect_tied_workers is True, only results from
the leaders of tied worker groups are returned.
Returns:
List of results, deduplicated as specified in the future_bundle
"""
return future_bundle.get_results(self)
def shutdown(
self,
cleanup_method: Optional[str] = None,
timeout: Optional[float] = 30.0,
force: bool = False,
):
"""Shutdown all workers in the worker group.
Args:
cleanup_method: Optional method name to call on each worker before termination.
If provided, this method will be called on each worker to allow
for graceful cleanup.
timeout: Timeout in seconds for graceful shutdown. Only applicable if cleanup_method is provided.
If None, wait indefinitely for workers to complete their cleanup.
force: If True, forcefully terminate workers with ray.kill() even if cleanup_method is provided.
If cleanup_method is None, workers are always forcefully terminated.
Returns:
bool: True if all workers were successfully shut down
"""
if not self._workers:
return True
success = True
# First attempt graceful shutdown if cleanup method is provided and force=False
if cleanup_method is not None and not force:
try:
# Call cleanup method on all workers
futures = self.run_all_workers_single_data(cleanup_method)
# Wait for all cleanup operations to complete with timeout
if timeout is not None:
ray.get(futures, timeout=timeout)
else:
ray.get(futures)
except (ray.exceptions.RayTaskError, ray.exceptions.GetTimeoutError) as e:
success = False
print(
f"Error during graceful shutdown: {e}. Falling back to force termination."
)
force = True
# Force kill any remaining workers
if force or cleanup_method is None:
initializers_to_kill = []
for worker in self._workers:
if hasattr(worker, "_RAY_INITIALIZER_ACTOR_REF_TO_AVOID_GC"):
# Store the initializer ref before the main worker is killed,
# as killing the worker might affect accessibility of this attribute later.
initializer = getattr(
worker, "_RAY_INITIALIZER_ACTOR_REF_TO_AVOID_GC", None
)
if initializer:
initializers_to_kill.append(initializer)
try:
ray.kill(worker)
except Exception as e:
success = False
print(f"Error killing worker: {e}")
# Now, explicitly kill the initializer actors
# This makes their termination more deterministic than relying solely on Ray's GC.
for initializer in initializers_to_kill:
try:
ray.kill(initializer)
except Exception as e:
print(f"Error killing initializer actor for a worker: {e}")
# Clear worker lists
self._workers = []
self._worker_metadata = []
self.tied_workers_groups = []
self.worker_to_tied_group_index = {}
return success
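# Shutdown sketch (comment only; "cleanup" is a hypothetical worker method):
#
#   group.shutdown(cleanup_method="cleanup", timeout=60.0)  # graceful first, force on failure
#   group.shutdown(force=True)                              # immediate ray.kill() of all workers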
def print_worker_layout(self):
"""Prints a visual representation of the worker layout across the virtual cluster.
This shows which workers are assigned to which nodes and GPUs.
"""
self.cluster.print_cluster_grid(self)