nemo_rl.distributed.worker_groups#

Module Contents#

Classes#

MultiWorkerFuture

Container for Ray futures with associated worker information.

RayWorkerBuilder

RayWorkerGroup

Manages a group of distributed Ray worker/actor processes that execute tasks in parallel.

API#

class nemo_rl.distributed.worker_groups.MultiWorkerFuture[source]#

Container for Ray futures with associated worker information.

futures: list[ray.ObjectRef]#

None

return_from_workers: Optional[list[int]]#

None

called_workers: Optional[list[int]]#

None

get_results(
worker_group: nemo_rl.distributed.worker_groups.RayWorkerGroup,
return_generators_as_proxies: bool = False,
) → list[Any][source]#

Get results from the futures, optionally respecting tied workers.

The method uses worker_group.worker_to_tied_group_index to identify which tied worker group each worker belongs to, then selects only the first result from each group.

Parameters:
  • worker_group – The RayWorkerGroup that spawned the futures. The mapping contained in worker_group.worker_to_tied_group_index is required for the deduplication path.

  • return_generators_as_proxies – If True, and a future is an ObjectRefGenerator, return the ObjectRefGenerator itself instead of consuming it.

Returns:

List of results
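
A minimal sketch of how a MultiWorkerFuture is typically consumed, assuming worker_group is an existing RayWorkerGroup and the bundle came from run_all_workers_sharded_data; the method name, axis name, and batch variables are placeholders, not part of this API:

    # "compute_metrics" and the "data_parallel" axis are illustrative placeholders.
    future_bundle = worker_group.run_all_workers_sharded_data(
        "compute_metrics",
        [batch_for_shard_0, batch_for_shard_1],   # one slice per shard along the sharded axis
        in_sharded_axes=["data_parallel"],
    )
    # Blocks on the underlying futures; worker_group.worker_to_tied_group_index is used
    # to keep only the first result from each tied worker group.
    results = future_bundle.get_results(worker_group)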

class nemo_rl.distributed.worker_groups.RayWorkerBuilder(ray_actor_class_fqn: str, *args, **kwargs)[source]#

Initialization

class IsolatedWorkerInitializer(
ray_actor_class_fqn: str,
*init_args,
**init_kwargs,
)[source]#

Initialization

create_worker(
placement_group: ray.util.placement_group.PlacementGroup,
placement_group_bundle_index: int,
num_gpus: int,
bundle_indices: Optional[tuple] = None,
**extra_options: Optional[dict[str, Any]],
)[source]#

Create a Ray worker with the specified configuration.

Order of precedence for worker options configuration (from lowest to highest):

  1. Options passed by the user to __call__ (extra_options)

  2. Options required by the worker via configure_worker (may override user options with a warning)

  3. Options set by RayWorkerBuilder.__call__ itself (specifically the scheduling strategy)

If the worker needs to override user-provided options, it should log a warning to inform the user about the change and the reason for it.

Parameters:
  • placement_group – Ray placement group for resource allocation

  • placement_group_bundle_index – Index of the bundle in the placement group

  • num_gpus – Number of GPUs to allocate to this worker

  • bundle_indices – Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable)

  • extra_options – Additional options to pass to the Ray actor (may be overridden by actor’s configure_worker(…) method)

Returns:

A Ray actor reference to the created worker

create_worker_async(
placement_group: ray.util.placement_group.PlacementGroup,
placement_group_bundle_index: int,
num_gpus: float | int,
bundle_indices: Optional[tuple[int, list[int]]] = None,
**extra_options: Any,
) → tuple[ray.ObjectRef, ray.actor.ActorHandle][source]#

Create a Ray worker asynchronously, returning futures.

This method returns immediately with futures that can be awaited later.

Parameters:
  • placement_group – Ray placement group for resource allocation

  • placement_group_bundle_index – Index of the bundle in the placement group

  • num_gpus – Number of GPUs to allocate to this worker (can be fractional)

  • bundle_indices – Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable)

  • extra_options – Additional options to pass to the Ray actor

Returns:

  • worker_future: A Ray ObjectRef that will resolve to the worker actor

  • initializer_actor: The initializer actor (needed to prevent GC)

Return type:

Tuple of (worker_future, initializer_actor)
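
A sketch of the asynchronous path, assuming builder is an existing RayWorkerBuilder and pg is a placement group obtained elsewhere; the returned future is resolved with ray.get, and the initializer actor must stay referenced until then:

    import ray

    worker_future, initializer_actor = builder.create_worker_async(
        placement_group=pg,              # placement group assumed to exist
        placement_group_bundle_index=0,
        num_gpus=1,
    )
    # ... do other setup work here ...
    worker = ray.get(worker_future)      # resolves to the actor handle; keep initializer_actor alive until now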

__call__(
placement_group: ray.util.placement_group.PlacementGroup,
placement_group_bundle_index: int,
num_gpus: float | int,
bundle_indices: Optional[tuple[int, list[int]]] = None,
**extra_options: Any,
) → ray.actor.ActorHandle[source]#

Create a Ray worker with the specified configuration.

Order of precedence for worker options configuration (from lowest to highest):

  1. Options passed by the user to __call__ (extra_options)

  2. Options required by the worker via configure_worker (may override user options with a warning)

  3. Options set by RayWorkerBuilder.__call__ itself (specifically the scheduling strategy)

If the worker needs to override user-provided options, it should log a warning to inform the user about the change and the reason for it.

Parameters:
  • placement_group – Ray placement group for resource allocation

  • placement_group_bundle_index – Index of the bundle in the placement group

  • num_gpus – Number of GPUs to allocate to this worker (can be fractional)

  • bundle_indices – Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable)

  • extra_options – Additional options to pass to the Ray actor (may be overridden by actor’s configure_worker(…) method)

Returns:

A Ray actor reference to the created worker
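
A synchronous construction sketch; the fully qualified worker class name, its config kwarg, and the runtime_env extra option are all illustrative assumptions, not part of this API:

    from nemo_rl.distributed.worker_groups import RayWorkerBuilder

    # Placeholder actor class FQN and constructor kwargs.
    builder = RayWorkerBuilder("my_project.workers.PolicyWorker", config={"precision": "bf16"})

    # extra_options flow into the Ray actor options and may be overridden by the
    # actor's configure_worker(...) hook, per the precedence rules above.
    worker = builder(
        placement_group=pg,                               # placement group assumed to exist
        placement_group_bundle_index=0,
        num_gpus=1,
        runtime_env={"env_vars": {"EXAMPLE_FLAG": "1"}},  # illustrative extra option
    )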

class nemo_rl.distributed.worker_groups.RayWorkerGroup(
cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster,
remote_worker_builder: nemo_rl.distributed.worker_groups.RayWorkerBuilder,
workers_per_node: Optional[Union[int, list[int]]] = None,
name_prefix: str = '',
bundle_indices_list: Optional[list[tuple[int, list[int]]]] = None,
sharding_annotations: Optional[nemo_rl.distributed.named_sharding.NamedSharding] = None,
env_vars: dict[str, str] = {},
)[source]#

Manages a group of distributed Ray worker/actor processes that execute tasks in parallel.

This class creates and manages Ray actor instances that run on resources allocated by a RayVirtualCluster. It handles:

  • Worker creation and placement on specific GPU resources

  • Setting up distributed training environment variables (rank, world size, etc.)

  • Executing methods across all workers in parallel

  • Collecting and aggregating results

  • Support for tied worker groups where multiple workers process the same data

Initialization

Initialize a group of distributed Ray workers.

Parameters:
  • cluster – RayVirtualCluster

  • remote_worker_builder – Callable that launches a Ray worker and accepts updatable options (typically a RayWorkerBuilder)

  • workers_per_node – Number of workers to launch per node. Defaults to one worker per bundle in the cluster; alternatively, specify an int or a per-node list to launch a different number of workers on each node.

  • name_prefix – Optional prefix for the names of the workers

  • bundle_indices_list – Explicit list of (node_idx, [local_bundle_indices]) tuples. Each tuple defines a tied group of workers placed on the same node. If provided, workers_per_node is ignored.

  • sharding_annotations – NamedSharding object representing the mapping of named axes to ranks (e.g. for TP, PP, etc.)

  • env_vars – Environment variables to set on each worker
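
A construction sketch, assuming cluster is an already-initialized nemo_rl.distributed.virtual_cluster.RayVirtualCluster and the placeholder worker class from the builder example above:

    from nemo_rl.distributed.worker_groups import RayWorkerBuilder, RayWorkerGroup

    builder = RayWorkerBuilder("my_project.workers.PolicyWorker")   # placeholder FQN
    worker_group = RayWorkerGroup(
        cluster=cluster,                 # existing RayVirtualCluster
        remote_worker_builder=builder,
        name_prefix="policy",
        env_vars={"EXAMPLE_FLAG": "1"},  # set on every worker
    )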

get_dp_leader_worker_idx(dp_shard_idx: int) → int[source]#

Returns the index of the primary worker for a given data parallel shard.

_create_workers_from_bundle_indices(
remote_worker_builder: nemo_rl.distributed.worker_groups.RayWorkerBuilder,
bundle_indices_list: list[tuple[int, list[int]]],
env_vars: dict[str, str] = {},
) → None[source]#

Create workers based on explicit bundle indices for tied worker groups.

Parameters:
  • remote_worker_builder – Builder function for Ray actors

  • bundle_indices_list – List of (node_idx, local_bundle_indices) tuples, where each tuple specifies a tied group with its node and local bundle indices. If the local_bundle_indices spans multiple nodes, the node_idx will be the first node’s index in the tied group.

property workers: list[ray.actor.ActorHandle]#
property worker_metadata: list[dict[str, Any]]#
property dp_size: int#

Number of data parallel shards.
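
A short sketch of the properties and the data-parallel leader lookup, continuing the worker_group from the construction example above:

    handles = worker_group.workers            # one Ray actor handle per worker
    metadata = worker_group.worker_metadata   # per-worker metadata dictionaries

    # Find the primary worker of each data-parallel shard.
    for shard in range(worker_group.dp_size):
        leader = worker_group.get_dp_leader_worker_idx(shard)
        print(f"DP shard {shard} leader: worker {leader}")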

run_single_worker_single_data(
method_name: str,
worker_idx: int,
*args,
**kwargs,
) → ray.ObjectRef[source]#

Run a method on a single, specific worker.

Parameters:
  • method_name – Name of the method to call on the worker.

  • worker_idx – The index of the worker to run the method on.

  • *args – Arguments to pass to the method.

  • **kwargs – Keyword arguments to pass to the method.

Returns:

A Ray future for the result.

Return type:

ray.ObjectRef
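
A minimal sketch; the method name is a placeholder for any method defined on the worker actor:

    import ray

    future = worker_group.run_single_worker_single_data(
        "report_device_id",    # placeholder worker method
        worker_idx=0,
    )
    result = ray.get(future)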

run_all_workers_multiple_data(
method_name: str,
*args,
run_rank_0_only_axes: list[str] | None = None,
common_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
) → list[ray.ObjectRef][source]#

Run a method on all workers in parallel with different data.

Parameters:
  • method_name – Name of the method to call on each worker

  • *args – List of arguments to pass to workers/groups e.g. [[arg1_for_worker_1, arg1_for_worker_2], [arg2_for_worker_1, arg2_for_worker_2]]

  • run_rank_0_only_axes – List of named axes for which only rank 0 should run the method.

  • common_kwargs – Keyword arguments to pass to all workers

  • **kwargs – Keyword arguments to pass to workers/groups e.g. {“key1”: [value_for_worker_1, value_for_worker_2], “key2”: [value_for_worker_1, value_for_worker_2]}

Returns:

A list of ray futures

Return type:

list[ray.ObjectRef]
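
A sketch for a two-worker group; the method name, batch contents, and the learning-rate kwarg are placeholders:

    import ray

    batches = [{"input_ids": [1, 2]}, {"input_ids": [3, 4]}]   # one entry per worker

    futures = worker_group.run_all_workers_multiple_data(
        "train_step",                   # placeholder worker method
        batches,                        # *args: each positional arg is a per-worker list
        common_kwargs={"lr": 1e-4},     # shared by every worker
    )
    losses = ray.get(futures)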

run_all_workers_single_data(
method_name: str,
*args,
run_rank_0_only_axes: list[str] | None = None,
**kwargs,
) → list[ray.ObjectRef][source]#

Run a method on all workers in parallel with the same data.

Parameters:
  • method_name – Name of the method to call on each worker

  • *args – Arguments to pass to the method

  • **kwargs – Keyword arguments to pass to the method

  • run_rank_0_only_axes – List of named axes for which only rank 0 should run the method.

Returns:

A list of ray futures

Return type:

list[ray.ObjectRef]
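
A sketch where every worker receives identical arguments; the method name and the "tensor_parallel" axis name are placeholders that depend on the sharding annotations in use:

    import ray

    futures = worker_group.run_all_workers_single_data(
        "save_checkpoint",                         # placeholder worker method
        "/tmp/checkpoints/step_100",               # same argument for every worker
        run_rank_0_only_axes=["tensor_parallel"],  # only rank 0 along this axis executes
    )
    ray.get(futures)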

run_all_workers_sharded_data(
method_name: str,
*args,
in_sharded_axes: list[str] | None = None,
replicate_on_axes: list[str] | None = None,
output_is_replicated: list[str] | None = None,
make_dummy_calls_to_free_axes: bool = False,
common_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
) → nemo_rl.distributed.worker_groups.MultiWorkerFuture[source]#

Run a method on all workers in parallel with sharded data.

How data is routed depends on the axis type:

  • Axes in in_sharded_axes: the data is already split across these axes, so each worker receives the appropriate slice along that axis.

  • Axes in replicate_on_axes: the data is replicated to all workers along these dimensions.

  • Free axes (axes in neither list): data is only sent to workers at index 0 of these axes.

Parameters:
  • method_name – Name of the method to call on each worker

  • *args – List of arguments to pass to workers/groups e.g. [[arg1_for_worker_1, arg1_for_worker_2], [arg2_for_worker_1, arg2_for_worker_2]]

  • in_sharded_axes – List of axes that are sharded

  • replicate_on_axes – List of axes that are to be replicated

  • output_is_replicated – List of axes along which the output is replicated, so only the first result along these axes is returned. Results are likewise only returned from rank 0 of the free axes.

  • make_dummy_calls_to_free_axes – Whether to make dummy calls (with None) to workers that aren’t rank 0 on ‘free axes’ (axes not in in_sharded_axes or replicate_on_axes).

  • common_kwargs – Keyword arguments to pass to all workers

  • **kwargs – Keyword arguments to pass to workers/groups e.g. {“key1”: [value_for_worker_1, value_for_worker_2], “key2”: [value_for_worker_1, value_for_worker_2]}

Returns:

Object containing futures and their associated worker information

Return type:

MultiWorkerFuture
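
A sketch with one sharded and one replicated axis; the method name, the axis names, and the shard contents are placeholders tied to whatever NamedSharding the group was built with:

    shards = [{"prompts": ["a", "b"]}, {"prompts": ["c", "d"]}]   # one slice per data-parallel shard

    future_bundle = worker_group.run_all_workers_sharded_data(
        "score_batch",                          # placeholder worker method
        shards,                                 # sliced along in_sharded_axes
        in_sharded_axes=["data_parallel"],
        replicate_on_axes=["tensor_parallel"],
        output_is_replicated=["tensor_parallel"],
    )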

get_all_worker_results(
future_bundle: nemo_rl.distributed.worker_groups.MultiWorkerFuture,
return_generators_as_proxies: bool = False,
) → list[Any][source]#

Get results from all workers, optionally filtering to get just one result per tied worker group.

Parameters:
  • future_bundle – MultiWorkerFuture containing futures and worker information.

  • return_generators_as_proxies – If True, and a future in the bundle is an ObjectRefGenerator, return the ObjectRefGenerator itself instead of consuming it.

Returns:

List of results, deduplicated as specified in the future_bundle
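
Continuing the sharded-data sketch above, results are usually collected through the group rather than from the bundle directly:

    # Resolves the futures and keeps one result per tied worker group / replicated axis,
    # as recorded in the bundle's worker bookkeeping.
    results = worker_group.get_all_worker_results(future_bundle)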

shutdown(
cleanup_method: Optional[str] = None,
timeout: Optional[float] = 30.0,
force: bool = False,
) → bool[source]#

Shutdown all workers in the worker group.

Parameters:
  • cleanup_method – Optional method name to call on each worker before termination. If provided, this method will be called on each worker to allow for graceful cleanup.

  • timeout – Timeout in seconds for graceful shutdown. Only applicable if cleanup_method is provided. If None, wait indefinitely for workers to complete their cleanup.

  • force – If True, forcefully terminate workers with ray.kill() even if cleanup_method is provided. If cleanup_method is None, workers are always forcefully terminated.

Returns:

True if all workers were successfully shut down

Return type:

bool
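
A shutdown sketch; the cleanup method name is a placeholder for whatever graceful-cleanup hook the worker class exposes:

    # Try a graceful shutdown first, then fall back to forceful termination.
    if not worker_group.shutdown(cleanup_method="cleanup", timeout=60.0):
        worker_group.shutdown(force=True)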