nemo_rl.models.policy.workers.base_policy_worker#

Module Contents#

Classes#

AbstractPolicyWorker

Base class for policy workers with shared functionality.

API#

class nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker#

Base class for policy workers with shared functionality.

init_collective(
    ip: str,
    port: int,
    world_size: int,
    *,
    train_world_size: int,
) → None#

Initialize the collective communication.

Parameters:
  • ip – IP address for the process group

  • port – Port for the process group

  • world_size – Total world size (train_world_size + inference_world_size)

  • train_world_size – Number of training workers (used in inference cluster)
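
A minimal usage sketch, assuming a joint train-plus-inference process group; the worker handle `worker` and the address values below are placeholders, not part of this API:

```python
# Hypothetical values for illustration only.
train_world_size = 4        # number of training workers
inference_world_size = 4    # number of inference workers

worker.init_collective(
    ip="10.0.0.1",   # placeholder coordinator IP for the process group
    port=29500,      # placeholder free port
    world_size=train_world_size + inference_world_size,  # total ranks
    train_world_size=train_world_size,  # keyword-only argument
)
```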

is_alive() → bool#

Check if the worker is alive.

reset_peak_memory_stats() → None#

Reset peak memory statistics.

get_gpu_info() → dict[str, Any]#

Return information about the GPU being used by this worker.

report_device_id() → str#

Report the UUID of the current CUDA device using NVML.

Returns:

UUID of the device in the format “GPU-xxxxx”

Return type:

str
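
Since report_device_id is documented to use NVML, here is a small sketch of how a “GPU-xxxxx” UUID can be obtained with the pynvml bindings; this illustrates the mechanism and is not necessarily the worker’s actual implementation:

```python
# Sketch: query the UUID of the current CUDA device via NVML (pynvml).
# Assumes the CUDA device ordering matches the NVML ordering
# (i.e. no CUDA_VISIBLE_DEVICES remapping).
import pynvml
import torch

pynvml.nvmlInit()
try:
    handle = pynvml.nvmlDeviceGetHandleByIndex(torch.cuda.current_device())
    uuid = pynvml.nvmlDeviceGetUUID(handle)  # e.g. "GPU-8a5f..."; bytes on older pynvml versions
    if isinstance(uuid, bytes):
        uuid = uuid.decode()
    print(uuid)
finally:
    pynvml.nvmlShutdown()
```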

get_zmq_address() → str#

Get the ZMQ address for the current device.

maybe_init_zmq() → None#

Initialize the ZMQ socket if it doesn’t exist.

get_free_memory_bytes() → int#

Get the available free memory, in bytes.

shutdown() → bool#

Shut down the policy.

start_gpu_profiling() → None#

Start GPU profiling.

stop_gpu_profiling() → None#

Stop GPU profiling.

report_node_ip_and_gpu_id() → tuple[str, int]#

Report the node IP and GPU ID of the current worker.

get_reference_policy_logprobs(
    *,
    data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
    micro_batch_size: Optional[int] = None,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec]#

Get the logprobs from the reference policy for a batch of data.

If micro_batch_size is provided, it will be used instead of the configured logprob_batch_size.

Returns:

a BatchedDataDict with key “reference_logprobs” and shape [batch_size, sequence_length]. By convention, the logprob of the first token is set to 0 so that the sequence length is preserved; the logprob of input token i is stored at position i of the output logprobs tensor.
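
A sketch of this convention, where `worker` and `batch` are placeholders and the batch is assumed to carry token IDs under a hypothetical “input_ids” key:

```python
# Illustration of the documented logprob convention.
# reference_logprobs[i, 0] == 0.0 (the first token has no predecessor)
# reference_logprobs[i, j] == log p(token_j | token_0 .. token_{j-1}) for j >= 1
out = worker.get_reference_policy_logprobs(data=batch, micro_batch_size=8)
ref_logprobs = out["reference_logprobs"]   # shape: [batch_size, sequence_length]
assert ref_logprobs.shape == batch["input_ids"].shape  # "input_ids" is a hypothetical key
```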

finish_training(*args: Any, **kwargs: Any) → None#