nemo_rl.models.policy.workers.base_policy_worker#
Module Contents#
Classes#

AbstractPolicyWorker | Base class for policy workers with shared functionality.
API#
- class nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker#
Base class for policy workers with shared functionality.
- init_collective(
- ip: str,
- port: int,
- world_size: int,
- *,
- train_world_size: int,
- )#
Initialize the collective communication.
- Parameters:
ip – IP address for the process group
port – Port for the process group
world_size – Total world size (train_world_size + inference_world_size)
train_world_size – Number of training workers (used in inference cluster)
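A minimal sketch of how a caller might assemble these arguments. The helper name and the `worker` handle are hypothetical, not part of NeMo RL; the only documented invariant reproduced here is that `world_size` is the sum of the training and inference world sizes.

```python
# Hypothetical helper (not part of NeMo RL): bundles the arguments for
# init_collective(). Encodes the documented invariant
# world_size = train_world_size + inference_world_size.
def collective_init_args(
    ip: str, port: int, train_world_size: int, inference_world_size: int
) -> dict:
    return {
        "ip": ip,
        "port": port,
        "world_size": train_world_size + inference_world_size,
        "train_world_size": train_world_size,
    }

args = collective_init_args("10.0.0.1", 29500, train_world_size=8, inference_world_size=4)
# worker.init_collective(**args)  # assuming `worker` is a concrete subclass instance
```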
- is_alive() bool#
Check if the worker is alive.
- reset_peak_memory_stats() None#
Reset peak memory statistics.
- get_gpu_info() dict[str, Any]#
Return information about the GPU being used by this worker.
- report_device_id() str#
Report the UUID of the current CUDA device using NVML.
- Returns:
UUID of the device in the format “GPU-xxxxx”
- Return type:
str
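A hedged sketch of how a UUID in this format can be obtained through NVML via the `pynvml` bindings. Whether the worker uses exactly these calls is an assumption; the `normalize_uuid` helper is hypothetical and only accounts for older `pynvml` versions returning `bytes`.

```python
def normalize_uuid(raw) -> str:
    # Older pynvml versions return bytes; newer ones return str.
    uuid = raw.decode() if isinstance(raw, bytes) else raw
    assert uuid.startswith("GPU-")
    return uuid

def report_device_id_sketch(index: int = 0) -> str:
    # Requires pynvml and an NVIDIA GPU; sketch only, not the library code.
    import pynvml
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(index)
        return normalize_uuid(pynvml.nvmlDeviceGetUUID(handle))
    finally:
        pynvml.nvmlShutdown()
```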
- get_zmq_address() str#
Get the ZMQ address for the current device.
- maybe_init_zmq() None#
Initialize the ZMQ socket if it doesn’t exist.
- get_free_memory_bytes() int#
Get the available free memory.
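One common way to query free device memory in PyTorch is `torch.cuda.mem_get_info()`, which returns `(free_bytes, total_bytes)` for the current device. Whether this worker uses that exact call is an assumption; the `to_gib` formatter is a hypothetical convenience for logging.

```python
def free_memory_bytes_sketch() -> int:
    # Requires a CUDA-enabled PyTorch build and an available GPU.
    import torch
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    return free_bytes

def to_gib(n_bytes: int) -> float:
    """Format a byte count as GiB, e.g. for memory logging."""
    return n_bytes / (1024 ** 3)
```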
- shutdown() bool#
Shutdown the policy.
- start_gpu_profiling() None#
Start GPU profiling.
- stop_gpu_profiling() None#
Stop GPU profiling.
- report_node_ip_and_gpu_id() tuple[str, int]#
Report the node IP and GPU ID of the current worker.
- get_reference_policy_logprobs(
- *,
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- micro_batch_size: Optional[int] = None,
- )#
Get the logprobs from the reference policy for a batch of data.
If micro_batch_size is provided, it will be used instead of the configured logprob_batch_size.
- Returns:
a BatchedDataDict with key “reference_logprobs” and shape [batch_size, sequence_length]. We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor.
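The alignment convention above can be illustrated with a small sketch (plain Python, not the library implementation): per-token logprobs exist only for positions 1..T-1, so a 0 is prepended to keep the output the same length as the input sequence, with `output[i]` holding the logprob of input token `i`.

```python
def align_logprobs(token_logprobs: list[float]) -> list[float]:
    """Prepend 0.0 so output[i] is the logprob of input token i."""
    return [0.0] + token_logprobs

seq = ["<bos>", "Hello", "world"]
lp = [-0.7, -1.2]  # hypothetical logprobs of tokens 1 and 2 given their prefixes
aligned = align_logprobs(lp)
assert len(aligned) == len(seq)  # sequence length is maintained
```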
- finish_training(*args: Any, **kwargs: Any) None#