Source code for nemo_rl.distributed.worker_groups

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Optional, Union

import ray
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from nemo_rl.distributed.named_sharding import NamedSharding
from nemo_rl.distributed.ray_actor_environment_registry import (
    get_actor_python_env,
)
from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
from nemo_rl.distributed.worker_group_utils import recursive_merge_options
from nemo_rl.utils.venvs import create_local_venv_on_each_node


@dataclass
class MultiWorkerFuture:
    """Container for Ray futures with associated worker information."""

    futures: list[ray.ObjectRef]
    return_from_workers: Optional[list[int]] = None
    called_workers: Optional[list[int]] = None

    def get_results(
        self, worker_group: "RayWorkerGroup", return_generators_as_proxies: bool = False
    ) -> list[Any]:
        """Get results from the futures, optionally respecting tied workers.

        The method uses worker_group.worker_to_tied_group_index to identify which tied
        worker group each worker belongs to, then selects only the first result from
        each group.

        Args:
            worker_group: The RayWorkerGroup that spawned the futures. The mapping
                contained in worker_group.worker_to_tied_group_index is required for
                the deduplication path.
            return_generators_as_proxies: If True, and a future is an ObjectRefGenerator,
                return the ObjectRefGenerator itself instead of consuming it.

        Returns:
            List of results
        """
        from ray import ObjectRef, ObjectRefGenerator

        if return_generators_as_proxies:
            # Directly return the futures, which are expected to be ObjectRefGenerators (or other proxies).
            # No ray.get() is called on them. The consumer is responsible for handling the proxies.
            if self.return_from_workers is None:
                return self.futures

            if self.called_workers is not None:
                map_called_worker_to_future_idx = {
                    global_idx: i for i, global_idx in enumerate(self.called_workers)
                }
                final_proxies = []
                for global_worker_to_return in self.return_from_workers:
                    if global_worker_to_return in map_called_worker_to_future_idx:
                        future_idx = map_called_worker_to_future_idx[
                            global_worker_to_return
                        ]
                        if future_idx < len(self.futures):
                            final_proxies.append(self.futures[future_idx])
                return final_proxies
            else:
                return [
                    self.futures[i]
                    for i in self.return_from_workers
                    if i < len(self.futures)
                ]

        object_refs: list[ObjectRef] = []
        has_generator = False
        for idx, fut in enumerate(self.futures):
            if isinstance(fut, ObjectRefGenerator):
                # ray.get cannot be called directly on the generator object; it must be
                # iterated to obtain the individual ObjectRef instances first.
                for generated_ref in fut:
                    object_refs.append(generated_ref)
                has_generator = True
            else:
                object_refs.append(fut)

        # Retrieve the concrete results.
        all_results = ray.get(object_refs)

        # If an expanded generator was present, we are in streaming mode.
        # Every ObjectRef now corresponds to a unique, ordered chunk of data.
        if has_generator:
            return all_results

        if self.return_from_workers is not None:
            if self.called_workers is not None:
                # Create a mapping from global worker indices to local indices in all_results
                worker_to_result_idx = {
                    worker: idx for idx, worker in enumerate(self.called_workers)
                }
                # Filter return_from_workers to only include workers that were actually called
                valid_return_workers = [
                    w for w in self.return_from_workers if w in worker_to_result_idx
                ]
                # Map global worker indices to local result indices and get results
                return [
                    all_results[worker_to_result_idx[worker]]
                    for worker in valid_return_workers
                ]
            else:
                return [all_results[worker] for worker in self.return_from_workers]

        return all_results

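# Illustrative sketch (not part of the original module): how MultiWorkerFuture
# deduplicates results. Assume a 4-worker group where only workers 0 and 2 were
# called and only worker 0's result should be returned; the indices and names
# below are hypothetical.
#
#   bundle = MultiWorkerFuture(
#       futures=[fut_w0, fut_w2],    # one future per *called* worker, in call order
#       called_workers=[0, 2],       # global worker indices that were called
#       return_from_workers=[0],     # global indices whose results should be kept
#   )
#   results = bundle.get_results(worker_group)  # -> [result_of_worker_0]
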
class RayWorkerBuilder:
    @ray.remote
    class IsolatedWorkerInitializer:
        def __init__(self, ray_actor_class_fqn: str, *init_args, **init_kwargs):
            self.ray_actor_class_fqn = ray_actor_class_fqn
            self.init_args = init_args
            self.init_kwargs = init_kwargs

        def create_worker(
            self,
            placement_group: PlacementGroup,
            placement_group_bundle_index: int,
            num_gpus: int,
            bundle_indices: Optional[tuple] = None,
            **extra_options: Optional[dict[str, Any]],
        ):
            """Create a Ray worker with the specified configuration.

            Order of precedence for worker options configuration (from lowest to highest):

            1. Options passed by the user to __call__ (extra_options)
            2. Options required by the worker via configure_worker (may override user options with warning)
            3. Options set by the RayWorkerBuilder.__call__ (specifically scheduling strategy)

            If the worker needs to override user-provided options, it should log a warning
            to inform the user about the change and the reason for it.

            Args:
                placement_group: Ray placement group for resource allocation
                placement_group_bundle_index: Index of the bundle in the placement group
                num_gpus: Number of GPUs to allocate to this worker
                bundle_indices: Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable)
                extra_options: Additional options to pass to the Ray actor (may be overridden
                    by the actor's configure_worker(...) method)

            Returns:
                A Ray actor reference to the created worker
            """
            # Set up worker arguments and resources
            module_name, class_name = self.ray_actor_class_fqn.rsplit(".", 1)
            module = importlib.import_module(module_name)
            worker_class = getattr(module, class_name)
            worker_kwargs = dict(self.init_kwargs)

            default_options = getattr(worker_class, "_default_options", {})
            options = recursive_merge_options(default_options, extra_options)

            # Use the worker's configuration interface if available
            if hasattr(worker_class, "configure_worker"):
                # Get complete worker configuration from the worker class
                resources, env_vars, init_kwargs = worker_class.configure_worker(
                    num_gpus=num_gpus,
                    bundle_indices=bundle_indices,
                )

                # Apply resource configuration
                if resources and "num_gpus" in resources:
                    num_gpus = resources["num_gpus"]

                # Apply environment variables if provided
                if env_vars:
                    if "runtime_env" not in options:
                        options["runtime_env"] = {"env_vars": {}}
                    if "env_vars" not in options["runtime_env"]:  # type: ignore
                        options["runtime_env"]["env_vars"] = {}  # type: ignore
                    for k, v in env_vars.items():
                        options["runtime_env"]["env_vars"][k] = v  # type: ignore

                # Apply initialization parameters
                if init_kwargs:
                    worker_kwargs.update(init_kwargs)

            # Create options for Ray actor
            options["scheduling_strategy"] = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_bundle_index=placement_group_bundle_index,
                placement_group_capture_child_tasks=True,
            )
            options["num_gpus"] = num_gpus
            worker = worker_class.options(**options).remote(
                *self.init_args, **worker_kwargs
            )
            return worker

    def __init__(self, ray_actor_class_fqn: str, *args, **kwargs):
        self.ray_actor_class_fqn = ray_actor_class_fqn
        self.args = args
        self.kwargs = kwargs

    def create_worker_async(
        self,
        placement_group: PlacementGroup,
        placement_group_bundle_index: int,
        num_gpus: float | int,
        bundle_indices: Optional[tuple[int, list[int]]] = None,
        **extra_options: Any,
    ) -> tuple[ray.ObjectRef, ray.actor.ActorHandle]:
        """Create a Ray worker asynchronously, returning futures.

        This method returns immediately with futures that can be awaited later.

        Args:
            placement_group: Ray placement group for resource allocation
            placement_group_bundle_index: Index of the bundle in the placement group
            num_gpus: Number of GPUs to allocate to this worker (can be fractional)
            bundle_indices: Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable)
            extra_options: Additional options to pass to the Ray actor

        Returns:
            Tuple of (worker_future, initializer_actor):
            - worker_future: A Ray ObjectRef that will resolve to the worker actor
            - initializer_actor: The initializer actor (needed to prevent GC)
        """
        # Set up worker arguments and resources
        options = deepcopy(extra_options)
        initializer_options = {"runtime_env": options["runtime_env"]}
        isolated_initializer = self.IsolatedWorkerInitializer.options(  # type: ignore # @ray.remote call
            **initializer_options
        ).remote(self.ray_actor_class_fqn, *self.args, **self.kwargs)

        # Return the future and the initializer actor
        worker_future = isolated_initializer.create_worker.remote(
            placement_group,
            placement_group_bundle_index,
            num_gpus,
            bundle_indices,
            **options,
        )
        return worker_future, isolated_initializer

    def __call__(
        self,
        placement_group: PlacementGroup,
        placement_group_bundle_index: int,
        num_gpus: float | int,
        bundle_indices: Optional[tuple[int, list[int]]] = None,
        **extra_options: Any,
    ) -> ray.actor.ActorHandle:
        """Create a Ray worker with the specified configuration.

        Order of precedence for worker options configuration (from lowest to highest):

        1. Options passed by the user to __call__ (extra_options)
        2. Options required by the worker via configure_worker (may override user options with warning)
        3. Options set by the RayWorkerBuilder.__call__ (specifically scheduling strategy)

        If the worker needs to override user-provided options, it should log a warning
        to inform the user about the change and the reason for it.

        Args:
            placement_group: Ray placement group for resource allocation
            placement_group_bundle_index: Index of the bundle in the placement group
            num_gpus: Number of GPUs to allocate to this worker (can be fractional)
            bundle_indices: Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable)
            extra_options: Additional options to pass to the Ray actor (may be overridden
                by the actor's configure_worker(...) method)

        Returns:
            A Ray actor reference to the created worker
        """
        # Use the async method and then block on the result
        worker_future, isolated_initializer = self.create_worker_async(
            placement_group,
            placement_group_bundle_index,
            num_gpus,
            bundle_indices,
            **extra_options,
        )

        # Block to get the worker
        worker = ray.get(worker_future)
        # We hold onto a reference to the initializer actor to avoid GC
        # (which would kill the child, 'real' actor).
        worker._RAY_INITIALIZER_ACTOR_REF_TO_AVOID_GC = isolated_initializer
        return worker

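# Illustrative sketch (not part of the original module): constructing a worker with
# RayWorkerBuilder. The actor class FQN, placement group, and option values below are
# hypothetical; extra options are merged with the worker class's _default_options and
# its configure_worker(...) output before the actor is created, and a runtime_env entry
# is expected because create_worker_async forwards it to the initializer actor.
#
#   builder = RayWorkerBuilder(
#       "my_project.workers.PolicyWorker",  # hypothetical actor class FQN
#       config=policy_config,               # forwarded to PolicyWorker.__init__
#   )
#   worker = builder(
#       placement_group=pg,
#       placement_group_bundle_index=0,
#       num_gpus=1,
#       bundle_indices=(0, [0, 1]),         # first worker of a 2-way tied group
#       runtime_env={"env_vars": {"RANK": "0"}, "py_executable": "python"},
#       name="policy-0-0",
#   )
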
class RayWorkerGroup:
    """Manages a group of distributed Ray worker/actor processes that execute tasks in parallel.

    This class creates and manages Ray actor instances that run on resources allocated
    by a RayVirtualCluster. It handles:

    - Worker creation and placement on specific GPU resources
    - Setting up distributed training environment variables (rank, world size, etc.)
    - Executing methods across all workers in parallel
    - Collecting and aggregating results
    - Support for tied worker groups where multiple workers process the same data
    """

    def __init__(
        self,
        cluster: RayVirtualCluster,
        remote_worker_builder: RayWorkerBuilder,
        workers_per_node: Optional[Union[int, list[int]]] = None,
        name_prefix: str = "",
        bundle_indices_list: Optional[list[tuple[int, list[int]]]] = None,
        sharding_annotations: Optional[NamedSharding] = None,
        env_vars: dict[str, str] = {},
    ):
        """Initialize a group of distributed Ray workers.

        Args:
            cluster: RayVirtualCluster
            remote_worker_builder: Callable that launches a ray worker and has updatable options
            workers_per_node: Defaults to launch one worker per bundle in the cluster.
                Alternatively specify an int or list to launch a different number of
                workers per node.
            name_prefix: Optional prefix for the names of the workers
            bundle_indices_list: Explicit list of (node_idx, [local_bundle_indices]) tuples.
                Each tuple defines a tied group of workers placed on the same node.
                If provided, workers_per_node is ignored.
            sharding_annotations: NamedSharding object representing mapping of named axes
                to ranks (i.e. for TP, PP, etc.)
        """
        self._workers: list[ray.actor.ActorHandle] = []
        self._worker_metadata: list[dict[str, Any]] = []
        self.cluster = cluster
        self.name_prefix = name_prefix
        self.sharding_annotations = sharding_annotations
        self.dp_leader_worker_indices: list[int] = []

        # If explicit bundle indices are provided, use those
        if bundle_indices_list is None:
            # Create bundle_indices_list from workers_per_node specification
            # In this case, each worker is its own group (no tied workers)
            bundle_indices_list = []

            # Get placement groups
            placement_groups = self.cluster.get_placement_groups()

            if len(placement_groups) == 1:
                # Single unified placement group
                pg = placement_groups[0]
                workers_per_group = [pg.bundle_count]
            else:
                # Multiple per-node placement groups
                workers_per_group = [pg.bundle_count for pg in placement_groups]

            # Determine how many workers per node/placement group
            if workers_per_node is None:
                workers_per_group = [pg.bundle_count for pg in placement_groups]
            elif isinstance(workers_per_node, int):
                workers_per_group = [workers_per_node] * len(placement_groups)
            elif isinstance(workers_per_node, list):
                if len(workers_per_node) == 1 and len(placement_groups) == 1:
                    workers_per_group = workers_per_node
                elif len(workers_per_node) != len(placement_groups):
                    raise ValueError(
                        f"workers_per_node list length ({len(workers_per_node)}) must match "
                        f"number of placement groups ({len(placement_groups)})"
                    )
                else:
                    workers_per_group = workers_per_node
            else:
                raise ValueError(
                    "workers_per_node must be None (for default distribution), an int, or a list"
                )

            # Validate workers_per_group
            for i, (pg, worker_count) in enumerate(
                zip(placement_groups, workers_per_group)
            ):
                if worker_count > pg.bundle_count:
                    raise ValueError(
                        f"Placement group {i} has {pg.bundle_count} bundles, "
                        f"but {worker_count} workers were requested"
                    )
                for bundle_idx in range(worker_count):
                    # Each worker is its own single-element group.
                    # The first element is the PG index (node_idx in the context of tied workers).
                    bundle_indices_list.append((i, [bundle_idx]))

        # Create workers based on the bundle_indices_list
        self._create_workers_from_bundle_indices(
            remote_worker_builder,
            bundle_indices_list,
            env_vars=env_vars,
        )

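# Illustrative sketch (not part of the original module): the two placement modes
# accepted by RayWorkerGroup.__init__. The cluster, builder, and counts are hypothetical.
#
#   # One worker per bundle, each worker its own single-element group:
#   group = RayWorkerGroup(cluster, builder, workers_per_node=4)
#
#   # Explicit tied groups: two 2-way groups on node 0, where each tuple is
#   # (node_idx, [local_bundle_indices]) and only the first worker of a group
#   # becomes the model owner / DP leader:
#   group = RayWorkerGroup(
#       cluster,
#       builder,
#       bundle_indices_list=[(0, [0, 1]), (0, [2, 3])],
#   )
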
    def get_dp_leader_worker_idx(self, dp_shard_idx: int) -> int:
        """Returns the index of the primary worker for a given data parallel shard."""
        if not 0 <= dp_shard_idx < len(self.dp_leader_worker_indices):
            raise IndexError(
                f"Data parallel shard index {dp_shard_idx} is out of range. "
                f"Valid range is [0, {len(self.dp_leader_worker_indices) - 1}]"
            )
        return self.dp_leader_worker_indices[dp_shard_idx]

    def _create_workers_from_bundle_indices(
        self,
        remote_worker_builder: RayWorkerBuilder,
        bundle_indices_list: list[tuple[int, list[int]]],
        env_vars: dict[str, str] = {},
    ) -> None:
        """Create workers based on explicit bundle indices for tied worker groups.

        Args:
            remote_worker_builder: Builder function for Ray actors
            bundle_indices_list: List of (node_idx, local_bundle_indices) tuples, where each
                tuple specifies a tied group with its node and local bundle indices.
                If the local_bundle_indices spans multiple nodes, the node_idx will be the
                first node's index in the tied group.
        """
        self.master_address, self.master_port = (
            self.cluster.get_master_address_and_port()
        )

        # Update env_vars with the current environment variables
        for k, v in os.environ.items():
            if k not in env_vars:
                env_vars[k] = v

        # Get the python environment for the actor
        actor_python_env = get_actor_python_env(
            remote_worker_builder.ray_actor_class_fqn
        )
        if actor_python_env.startswith("uv"):
            # If the py_executable begins with uv it signals that we need to create a
            # local venv first and then replace the py_executable with the local venv's python.
            # The directory the venv will be created in is controlled by the env var
            # NEMO_RL_VENV_DIR and defaults to $GIT_ROOT/venvs/.
            py_executable = create_local_venv_on_each_node(
                py_executable=actor_python_env,
                venv_name=remote_worker_builder.ray_actor_class_fqn,
            )
        else:
            py_executable = actor_python_env

        # Count total workers
        self.world_size = sum(len(indices) for _, indices in bundle_indices_list)
        global_rank = 0

        # Collect all async creation calls
        worker_futures = []
        worker_info = []  # Store metadata for each worker

        # Get all placement groups
        placement_groups = self.cluster.get_placement_groups()

        for group_idx, (pg_idx, local_bundle_indices) in enumerate(bundle_indices_list):
            current_group = []

            if len(placement_groups) == 1:
                pg = placement_groups[0]
            else:
                pg = placement_groups[pg_idx]

            is_parallel_group = len(local_bundle_indices) > 1

            for local_rank, bundle_idx in enumerate(local_bundle_indices):
                # Set up basic distributed environment variables
                worker_env_vars = deepcopy(env_vars)
                worker_env_vars.update(
                    {
                        "RANK": str(global_rank),
                        "LOCAL_RANK": str(bundle_idx),
                        "WORLD_SIZE": str(self.world_size),
                        "MASTER_ADDR": self.master_address,
                        "MASTER_PORT": str(self.master_port),
                        "NODE_RANK": str(pg_idx),
                    }
                )
                worker_env_vars.pop("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", None)

                # Only the first worker in each group gets bundle_indices.
                # This ensures only one worker per group is the model owner.
                worker_bundle_indices = None
                if local_rank == 0:
                    worker_bundle_indices = (pg_idx, local_bundle_indices)
                    self.dp_leader_worker_indices.append(global_rank)

                # Create a descriptive name based on group structure
                name = (
                    f"{self.name_prefix}-grp{group_idx}-{local_rank}"
                    if is_parallel_group
                    else f"{self.name_prefix}-{pg_idx}-{bundle_idx}"
                )

                # Calculate GPU resources
                num_gpus = (
                    1 / self.cluster.max_colocated_worker_groups
                    if self.cluster.use_gpus
                    else 0
                )

                # Pass these options to the remote_worker_builder
                runtime_env = {
                    "env_vars": worker_env_vars,
                    "py_executable": py_executable,
                }
                runtime_env["env_vars"]["VIRTUAL_ENV"] = py_executable
                runtime_env["env_vars"]["UV_PROJECT_ENVIRONMENT"] = py_executable
                extra_options = {"runtime_env": runtime_env, "name": name}

                # Start worker creation asynchronously
                worker_future, initializer = remote_worker_builder.create_worker_async(
                    placement_group=pg,
                    placement_group_bundle_index=bundle_idx,
                    num_gpus=num_gpus,
                    bundle_indices=worker_bundle_indices,
                    **extra_options,
                )

                # Store the future and metadata
                worker_idx = len(worker_futures)
                worker_futures.append((worker_future, initializer))
                worker_info.append(
                    {
                        "group_idx": group_idx,
                        "worker_idx": worker_idx,
                        "node_idx": pg_idx,
                        "local_rank": local_rank,
                        "global_rank": global_rank,
                        "name": name,
                        "bundle_indices": worker_bundle_indices,
                        "dp_shard_idx": group_idx,
                    }
                )

                current_group.append(worker_idx)
                global_rank += 1

        print(
            f"Waiting for {len(worker_futures)} workers to finish initializing...",
            flush=True,
        )
        worker_refs = [future for future, _ in worker_futures]
        workers = ray.get(worker_refs)

        for idx, (worker, (_, initializer)) in enumerate(zip(workers, worker_futures)):
            worker._RAY_INITIALIZER_ACTOR_REF_TO_AVOID_GC = initializer
            self._workers.append(worker)

            # Get the corresponding metadata
            info = worker_info[idx]
            self._worker_metadata.append(
                {
                    "node_idx": info["node_idx"],
                    "local_rank": info["local_rank"],
                    "global_rank": info["global_rank"],
                    "name": info["name"],
                    "bundle_indices": info["bundle_indices"],
                    "dp_shard_idx": info["group_idx"],
                }
            )

    @property
    def workers(self) -> list[ray.actor.ActorHandle]:
        return self._workers

    @property
    def worker_metadata(self) -> list[dict[str, Any]]:
        return self._worker_metadata

    @property
    def dp_size(self) -> int:
        """Number of data parallel shards."""
        return len(self.dp_leader_worker_indices)

    def run_single_worker_single_data(
        self,
        method_name: str,
        worker_idx: int,
        *args,
        **kwargs,
    ) -> ray.ObjectRef:
        """Run a method on a single, specific worker.

        Args:
            method_name: Name of the method to call on the worker.
            worker_idx: The index of the worker to run the method on.
            *args, **kwargs: Arguments to pass to the method.

        Returns:
            ray.ObjectRef: A Ray future for the result.
        """
        assert len(args) == 0, (
            "run_single_worker_single_data will fail with args under certain circumstances. "
            "Please use kwargs instead. "
            "See https://github.com/NVIDIA-NeMo/RL/issues/582 for more details."
        )

        worker = self.workers[worker_idx]
        method = getattr(worker, method_name)
        return method.remote(*args, **kwargs)

    def run_all_workers_multiple_data(
        self,
        method_name: str,
        *args,
        run_rank_0_only_axes: list[str] | None = None,
        common_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ) -> list[ray.ObjectRef]:
        """Run a method on all workers in parallel with different data.

        Args:
            method_name: Name of the method to call on each worker
            *args: List of arguments to pass to workers/groups
                e.g. [[arg1_for_worker_1, arg1_for_worker_2], [arg2_for_worker_1, arg2_for_worker_2]]
            run_rank_0_only_axes: List of named axes for which only rank 0 should run the method.
            common_kwargs: Keyword arguments to pass to all workers
            **kwargs: Keyword arguments to pass to workers/groups
                e.g. {"key1": [value_for_worker_1, value_for_worker_2], "key2": [value_for_worker_1, value_for_worker_2]}

        Returns:
            list[ray.ObjectRef]: A list of ray futures
        """
        assert len(args) == 0, (
            "run_all_workers_multiple_data will fail with args under certain circumstances. "
            "Please use kwargs instead. "
            "See https://github.com/NVIDIA-NeMo/RL/issues/582 for more details."
        )

        # Check at least one arg or kwarg is provided
        assert len(args) > 0 or len(kwargs) > 0, (
            "At least one args (positional arguments) or kwargs (keyword arguments) must be provided in run_all_workers_multiple_data. "
            "Otherwise, please use run_all_workers_single_data."
        )

        # Check all args and kwargs have the same length
        args_count = [len(arg) for arg in args]
        assert all(count == args_count[0] for count in args_count), (
            "All args must have the same length"
        )
        args_count = args_count[0] if len(args_count) > 0 else 0

        kwargs_count = [len(value) for value in kwargs.values()]
        assert all(count == kwargs_count[0] for count in kwargs_count), (
            "All kwargs must have the same length"
        )
        kwargs_count = kwargs_count[0] if len(kwargs_count) > 0 else 0

        if args_count > 0 and kwargs_count > 0:
            assert args_count == kwargs_count, (
                "The number of args and kwargs must be the same in run_all_workers_multiple_data. "
                f"args length = {args_count}, kwargs length = {kwargs_count}"
            )
        data_count = max(args_count, kwargs_count)

        # Check the data length is equal to the number of workers
        if run_rank_0_only_axes is None:
            assert data_count == len(self.workers), (
                "data length should be equal to the number of workers: "
                f"data length = {data_count}, number of workers = {len(self.workers)}"
            )

        futures = []

        if run_rank_0_only_axes is None:
            run_rank_0_only_axes = []

        if common_kwargs is None:
            common_kwargs = {}

        data_idx = 0
        for worker_idx, worker in enumerate(self.workers):
            worker_coords = self.sharding_annotations.get_worker_coords(worker_idx)

            # Determine if this worker should receive data
            should_run = True
            for axis in self.sharding_annotations.names:
                if axis not in worker_coords:
                    continue
                if axis in run_rank_0_only_axes and worker_coords[axis] != 0:
                    should_run = False
                    break

            if should_run:
                method = getattr(worker, method_name)
                worker_args = [arg[data_idx] for arg in args]
                worker_kwargs = {key: value[data_idx] for key, value in kwargs.items()}
                futures.append(
                    method.remote(*worker_args, **worker_kwargs, **common_kwargs)
                )
                data_idx += 1

        assert data_idx == data_count, (
            "data length should be equal to the number of workers started: "
            f"data length = {data_count}, number of workers started = {data_idx}"
        )

        return futures

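# Illustrative sketch (not part of the original module): passing per-worker data with
# run_all_workers_multiple_data. Positional args are rejected (see the assert above),
# so each kwarg is a list with one entry per worker that receives data; the method
# and variable names below are hypothetical.
#
#   futures = group.run_all_workers_multiple_data(
#       "process_batch",
#       batch=[batch_for_worker_0, batch_for_worker_1],   # one element per worker
#       common_kwargs={"temperature": 1.0},               # shared by all workers
#   )
#   results = ray.get(futures)
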
    def run_all_workers_single_data(
        self,
        method_name: str,
        *args,
        run_rank_0_only_axes: list[str] | None = None,
        **kwargs,
    ) -> list[ray.ObjectRef]:
        """Run a method on all workers in parallel with the same data.

        Args:
            method_name: Name of the method to call on each worker
            *args, **kwargs: Arguments to pass to the method
            run_rank_0_only_axes: List of named axes for which only rank 0 should run the method.

        Returns:
            list[ray.ObjectRef]: A list of ray futures
        """
        assert len(args) == 0, (
            "run_all_workers_single_data will fail with args under certain circumstances. "
            "Please use kwargs instead. "
            "See https://github.com/NVIDIA-NeMo/RL/issues/582 for more details."
        )

        futures = []

        if run_rank_0_only_axes is None:
            run_rank_0_only_axes = []

        for worker_idx, worker in enumerate(self.workers):
            worker_coords = self.sharding_annotations.get_worker_coords(worker_idx)

            # Determine if this worker should receive data
            should_run = True
            for axis in self.sharding_annotations.names:
                if axis not in worker_coords:
                    continue
                if axis in run_rank_0_only_axes and worker_coords[axis] != 0:
                    should_run = False
                    break

            if should_run:
                method = getattr(worker, method_name)
                futures.append(method.remote(*args, **kwargs))

        return futures

    def run_all_workers_sharded_data(
        self,
        method_name: str,
        *args,
        in_sharded_axes: list[str] | None = None,
        replicate_on_axes: list[str] | None = None,
        output_is_replicated: list[str] | None = None,
        make_dummy_calls_to_free_axes: bool = False,
        common_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ) -> MultiWorkerFuture:
        """Run a method on all workers in parallel with sharded data.

        Axes in in_sharded_axes: Data is already split across these axes, so we just send
            the appropriate slice to each worker (along this axis)
        Axes in replicate_on_axes: Data is replicated to all workers along these dimensions
        Free axes (axes not in either list): Data is only sent to workers at index 0 of these axes

        Args:
            method_name: Name of the method to call on each worker
            *args: List of arguments to pass to workers/groups
                e.g. [[arg1_for_worker_1, arg1_for_worker_2], [arg2_for_worker_1, arg2_for_worker_2]]
            in_sharded_axes: List of axes that are sharded
            replicate_on_axes: List of axes that are to be replicated
            output_is_replicated: List of axes along which the output is replicated (and we
                should just return the first result). We also just return from rank 0 of free axes.
            make_dummy_calls_to_free_axes: Whether to make dummy calls (with None) to workers
                that aren't rank 0 on 'free axes' (axes not in in_sharded_axes or replicate_on_axes).
            common_kwargs: Keyword arguments to pass to all workers
            **kwargs: Keyword arguments to pass to workers/groups
                e.g. {"key1": [value_for_worker_1, value_for_worker_2], "key2": [value_for_worker_1, value_for_worker_2]}

        Returns:
            MultiWorkerFuture: Object containing futures and their associated worker information
        """
        assert len(args) == 0, (
            "run_all_workers_sharded_data will fail with args under certain circumstances. "
            "Please use kwargs instead. "
            "See https://github.com/NVIDIA-NeMo/RL/issues/582 for more details."
        )

        if self.sharding_annotations is None:
            raise ValueError(
                "Sharding annotations must be provided to use sharded data distribution"
            )
        if common_kwargs is None:
            common_kwargs = {}
        if in_sharded_axes is None:
            in_sharded_axes = []
        if replicate_on_axes is None:
            replicate_on_axes = []
        if output_is_replicated is None:
            output_is_replicated = []

        futures = []

        # Validate axes
        for axis in in_sharded_axes + replicate_on_axes:
            if axis not in self.sharding_annotations.names:
                raise ValueError(
                    f"Axis '{axis}' not found in sharding annotations. Valid axes: {self.sharding_annotations.names}"
                )

        # Check for overlapping axes
        overlap = set(in_sharded_axes).intersection(set(replicate_on_axes))
        if overlap:
            raise ValueError(f"Axes cannot be both sharded and replicated: {overlap}")

        called_workers = []
        return_from_workers = []

        # For each worker, determine what data it should receive
        for worker_idx, worker in enumerate(self._workers):
            # Get the worker's coordinates in the sharding space
            worker_coords = self.sharding_annotations.get_worker_coords(worker_idx)

            # Determine if this worker should receive data
            should_receive_data = True
            return_from_this_worker = True
            for axis in self.sharding_annotations.names:
                if axis not in worker_coords:
                    continue
                # We call axes not in in_sharded_axes or replicate_on_axes "free axes".
                if (
                    axis not in in_sharded_axes
                    and axis not in replicate_on_axes
                    and worker_coords[axis] != 0
                ):
                    # For free axes, only workers at index 0 receive data
                    should_receive_data = False
                    return_from_this_worker = False
                    break
                if axis in output_is_replicated:
                    if worker_coords[axis] != 0:
                        return_from_this_worker = False

            if return_from_this_worker:
                return_from_workers.append(worker_idx)

            if should_receive_data:
                # Find the appropriate data slice for this worker
                worker_args = args
                worker_kwargs = kwargs
                for axis in in_sharded_axes:
                    if axis in worker_coords:
                        # Select the appropriate slice for this axis
                        worker_args = [arg[worker_coords[axis]] for arg in worker_args]
                        worker_kwargs = {
                            key: value[worker_coords[axis]]
                            for key, value in worker_kwargs.items()
                        }

                # Call the method on the worker with its data slice
                future = getattr(worker, method_name).remote(
                    *worker_args, **worker_kwargs, **common_kwargs
                )
                futures.append(future)
                called_workers.append(worker_idx)
            else:
                # If this worker doesn't need data:
                if make_dummy_calls_to_free_axes:
                    # If make_dummy_calls_to_free_axes is True, just call the method with None
                    worker_args = [None] * len(args)
                    worker_kwargs = {key: None for key in kwargs.keys()}
                    future = getattr(worker, method_name).remote(
                        *worker_args, **worker_kwargs, **common_kwargs
                    )
                    futures.append(future)
                    called_workers.append(worker_idx)
                else:
                    # Else, don't call the method at all
                    pass

        return MultiWorkerFuture(
            futures=futures,
            called_workers=called_workers,
            return_from_workers=return_from_workers,
        )

    def get_all_worker_results(
        self,
        future_bundle: MultiWorkerFuture,
        return_generators_as_proxies: bool = False,
    ) -> list[Any]:
        """Get results from all workers, optionally filtering to get just one result per tied worker group.

        Args:
            future_bundle: MultiWorkerFuture containing futures and worker information.
            return_generators_as_proxies: If True, and a future in the bundle is an
                ObjectRefGenerator, return the ObjectRefGenerator itself instead of consuming it.

        Returns:
            List of results, deduplicated as specified in the future_bundle
        """
        return future_bundle.get_results(
            self, return_generators_as_proxies=return_generators_as_proxies
        )

    def shutdown(
        self,
        cleanup_method: Optional[str] = None,
        timeout: Optional[float] = 30.0,
        force: bool = False,
    ) -> bool:
        """Shutdown all workers in the worker group.

        Args:
            cleanup_method: Optional method name to call on each worker before termination.
                If provided, this method will be called on each worker to allow for graceful cleanup.
            timeout: Timeout in seconds for graceful shutdown. Only applicable if cleanup_method
                is provided. If None, wait indefinitely for workers to complete their cleanup.
            force: If True, forcefully terminate workers with ray.kill() even if cleanup_method
                is provided. If cleanup_method is None, workers are always forcefully terminated.

        Returns:
            bool: True if all workers were successfully shut down
        """
        if not self._workers:
            return True

        success = True

        # First attempt graceful shutdown if cleanup method is provided and force=False
        if cleanup_method is not None and not force:
            try:
                # Call cleanup method on all workers
                futures = self.run_all_workers_single_data(cleanup_method)

                # Wait for all cleanup operations to complete with timeout
                if timeout is not None:
                    ray.get(futures, timeout=timeout)
                else:
                    ray.get(futures)
            except (ray.exceptions.RayTaskError, ray.exceptions.GetTimeoutError) as e:
                success = False
                print(
                    f"Error during graceful shutdown: {e}. Falling back to force termination."
                )
                force = True

        # Force kill any remaining workers
        if force or cleanup_method is None:
            initializers_to_kill = []
            for worker in self._workers:
                if hasattr(worker, "_RAY_INITIALIZER_ACTOR_REF_TO_AVOID_GC"):
                    # Store the initializer ref before the main worker is killed,
                    # as killing the worker might affect accessibility of this attribute later.
                    initializer = getattr(
                        worker, "_RAY_INITIALIZER_ACTOR_REF_TO_AVOID_GC", None
                    )
                    if initializer:
                        initializers_to_kill.append(initializer)
                try:
                    ray.kill(worker)
                except Exception as e:
                    success = False
                    print(f"Error killing worker: {e}")

            # Now, explicitly kill the initializer actors.
            # This makes their termination more deterministic than relying solely on Ray's GC.
            for initializer in initializers_to_kill:
                try:
                    ray.kill(initializer)
                except Exception as e:
                    print(f"Error killing initializer actor for a worker: {e}")

        # Clear worker lists
        self._workers = []
        self._worker_metadata = []

        return success
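
# Illustrative sketch (not part of the original module): a typical RayWorkerGroup
# lifecycle. The cluster construction, worker class, sharding layout, and method
# names below are hypothetical.
#
#   cluster = RayVirtualCluster(...)                        # configured elsewhere
#   builder = RayWorkerBuilder("my_project.workers.PolicyWorker", config=cfg)
#   group = RayWorkerGroup(
#       cluster,
#       builder,
#       name_prefix="policy",
#       sharding_annotations=sharding,                      # NamedSharding over e.g. ("dp", "tp")
#   )
#
#   # Same data on every worker:
#   futures = group.run_all_workers_single_data("train_step")
#   results = ray.get(futures)
#
#   # Data pre-sharded along the "dp" axis; one slice per dp rank:
#   bundle = group.run_all_workers_sharded_data(
#       "generate", prompts=prompts_per_dp_rank, in_sharded_axes=["dp"]
#   )
#   outputs = group.get_all_worker_results(bundle)
#
#   group.shutdown(cleanup_method="shutdown", timeout=30.0)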