nemo_rl.models.policy.dtensor_policy_worker_v2#

Module Contents#

Classes#

API#

class nemo_rl.models.policy.dtensor_policy_worker_v2.DTensorPolicyWorkerV2(
config: nemo_rl.models.policy.PolicyConfig,
tokenizer: transformers.AutoTokenizer,
processor: Optional[transformers.AutoProcessor] = None,
weights_path: Optional[str] = None,
optimizer_path: Optional[str] = None,
init_optimizer: bool = True,
init_reference_model: bool = True,
**kwargs: Any,
)#

Initialization

Initialize the DTensorPolicyWorkerV2.
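
A construction sketch is shown below. In practice this worker typically runs as a Ray actor managed by higher-level policy code, and the `PolicyConfig` contents (omitted here) come from the training configuration; the tokenizer checkpoint and keyword values are placeholders.

```python
# Construction sketch only: the PolicyConfig contents are omitted and the
# tokenizer checkpoint is a placeholder; this worker is normally created as a
# Ray actor by higher-level policy code rather than instantiated by hand.
from transformers import AutoTokenizer

from nemo_rl.models.policy.dtensor_policy_worker_v2 import DTensorPolicyWorkerV2

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")  # placeholder checkpoint
policy_config = {}  # a nemo_rl.models.policy.PolicyConfig; required keys omitted in this sketch

worker = DTensorPolicyWorkerV2(
    config=policy_config,
    tokenizer=tokenizer,
    init_optimizer=True,        # build the optimizer for training
    init_reference_model=True,  # keep a frozen copy for reference logprobs
)
```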

__repr__() str#

Customizes the actor’s prefix in the Ray logs.

This makes it easier to identify which worker is producing specific log messages.

_apply_temperature_scaling(logits: torch.Tensor) torch.Tensor#
init_collective(
ip: str,
port: int,
world_size: int,
*,
train_world_size: int,
) None#
is_alive() bool#
check_model_allow_flash_attn_args(model_config) bool#
reset_peak_memory_stats() None#
get_gpu_info() dict[str, Any]#

Return information about the GPU being used by this worker.

train(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
eval_mode: bool = False,
gbs: Optional[int] = None,
mbs: Optional[int] = None,
) dict[str, Any]#

Train the policy on a batch of data with a given loss function.
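
A call sketch follows, assuming a constructed `worker`. The batch field names and `my_loss_fn` are placeholders, since the expected keys depend on the model and the chosen loss function.

```python
# Call sketch: the batch keys below are placeholders; the actual fields are
# whatever the model and the chosen LossFunction expect.
import torch
from nemo_rl.distributed.batched_data_dict import BatchedDataDict

data = BatchedDataDict(
    {
        "input_ids": torch.randint(0, 32_000, (64, 512)),  # placeholder right-padded tokens
        "input_lengths": torch.full((64,), 512),            # placeholder per-sample lengths
    }
)

metrics = worker.train(
    data=data,
    loss_fn=my_loss_fn,  # placeholder nemo_rl.algorithms.interfaces.LossFunction
    gbs=64,              # optional global batch size override
    mbs=8,               # optional micro-batch size override
)
print(metrics)  # dict of training statistics for this step
```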

get_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
micro_batch_size: Optional[int] = None,
) nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec]#

Get the logprobs of the model for a batch of data.

Uses the configured logprob_batch_size to do microbatching.

Input data is assumed to be right-padded. The method internally converts to left-padded format for computation, and returns outputs in right-padded format.

Returns:

a BatchedDataDict with key “logprobs” and shape [batch_size, sequence_length]. We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor.
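
A small sketch of this convention, given an already constructed `worker`; the batch field names are assumptions.

```python
# Sketch of the logprob convention: the output has the same sequence length as
# the input, with position 0 fixed to 0. Batch field names are assumptions.
import torch
from nemo_rl.distributed.batched_data_dict import BatchedDataDict

data = BatchedDataDict(
    {
        "input_ids": torch.tensor([[11, 42, 7, 0, 0]]),  # right-padded placeholder tokens
        "input_lengths": torch.tensor([3]),
    }
)

out = worker.get_logprobs(data, micro_batch_size=1)
logprobs = out["logprobs"]  # shape [1, 5]: sequence length is preserved
# logprobs[0, 0] == 0.0 by convention (the first token has no preceding context);
# for i >= 1, logprobs[0, i] is log p(input_ids[0, i] | input_ids[0, :i]).
```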

score(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
) nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ScoreOutputSpec]#
get_topk_logits(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
k: int,
micro_batch_size: Optional[int] = None,
) nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]#

Return per-position top-k logits and corresponding global indices.

Notes:

  • Return shapes are [B, S, k].

  • Computes top-k over the full sequence (no trimming of the last position).

  • If alignment with next-token targets is required, the caller should handle it (see the sketch after this list).

  • If logits are a TP-sharded DTensor, performs a distributed global top-k across TP ranks.

  • Supports context parallelism with proper CP gather.

  • Otherwise, computes local top-k on the full-vocabulary tensor.
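
A sketch of the caller-side alignment with next-token targets, reusing a batch like the one in the `get_logprobs` sketch above; the output key names are assumptions, as they are not part of this signature.

```python
# Alignment sketch: logits at position i score candidates for token i + 1, so
# drop the last logit position and the first target position. The output key
# names ("topk_logits", "topk_indices") are assumptions.
out = worker.get_topk_logits(data, k=8, micro_batch_size=1)
topk_logits = out["topk_logits"]    # [B, S, k]
topk_indices = out["topk_indices"]  # [B, S, k] global vocabulary indices

pred_logits = topk_logits[:, :-1, :]    # [B, S - 1, k]
pred_indices = topk_indices[:, :-1, :]  # [B, S - 1, k]
targets = data["input_ids"][:, 1:]      # [B, S - 1] next-token targets
```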

use_reference_model() Generator[None, None, None]#

Context manager that temporarily swaps the reference model and active model.

On entry: moves the model to CPU, moves the reference_model to CUDA, and swaps the two references. On exit: restores the original references and moves each model back to its original CUDA/CPU placement.
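
Usage sketch, assuming a constructed `worker` and a prepared batch:

```python
# Inside the block the reference model is the active CUDA model, so any
# forward pass (e.g. get_logprobs) is evaluated under the reference weights.
with worker.use_reference_model():
    ref_out = worker.get_logprobs(data)
# On exit the trained policy is restored to its original device placement.
```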

get_reference_policy_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
micro_batch_size: Optional[int] = None,
) nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec]#

Get the logprobs from the reference policy for a batch of data.

Returns:

a BatchedDataDict with key “reference_logprobs” and shape [batch_size, sequence_length]. We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor.
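
Since both outputs share the same position-aligned convention, a per-token log-ratio (e.g. for a KL penalty) can be formed directly; padding masks are omitted in this sketch.

```python
# Sketch: per-token log-ratio between the policy and the reference policy.
policy_lp = worker.get_logprobs(data)["logprobs"]                          # [B, S]
ref_lp = worker.get_reference_policy_logprobs(data)["reference_logprobs"]  # [B, S]

log_ratio = policy_lp - ref_lp  # position-aligned: log pi(t_i) - log pi_ref(t_i)
approx_kl = log_ratio.mean()    # crude estimate; real losses mask padded positions
```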

_add_noise_to_weights() None#

Add small Gaussian noise to the weights of the model. Note that this is used for testing purposes only.

return_state_dict()#
return_model_config() dict[str, Any]#

Return the model configuration as a dictionary.

Returns:

Model configuration dictionary

Return type:

dict

report_device_id() str#

Report the UUID of the current CUDA device using NVML.

Returns:

UUID of the device in the format “GPU-xxxxx”

Return type:

str

get_zmq_address()#

Get the ZMQ address for the current device.

maybe_init_zmq()#

Initialize the ZMQ socket if it doesn’t exist.

prepare_refit_info() Optional[dict[str, Any]]#

Prepare state dict metadata for weight refitting and IPC streaming.

get_free_memory_bytes() int#

Get the available free memory, in bytes.

stream_weights_via_ipc_zmq(buffer_size_bytes: int = 0) None#

Stream model weights to peer process via ZMQ IPC socket.

broadcast_weights_for_collective() None#

Broadcast the weights for collective communication.

prepare_for_lp_inference() None#
prepare_for_training(*args, **kwargs) None#
offload_before_refit() None#

Offload the optimizer to the CPU.

offload_after_refit() None#

Offload as much as possible to the CPU.
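
A plausible ordering of these calls during a weight refit is sketched below; the actual orchestration is driven by the surrounding controller code, so this sequence is illustrative only.

```python
# Illustrative refit sequence (assumed ordering, not the controller's exact logic).
worker.offload_before_refit()        # move the optimizer to CPU to free GPU memory
info = worker.prepare_refit_info()   # state-dict metadata for the receiving side
worker.maybe_init_zmq()              # create the IPC socket if it does not exist yet
worker.stream_weights_via_ipc_zmq()  # push weights to the peer process over ZMQ IPC
worker.offload_after_refit()         # offload as much as possible to the CPU
worker.prepare_for_lp_inference()    # or prepare_for_training(...) to resume updates
```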

move_to_device(
model: torch.nn.Module,
device: str | torch.device,
) torch.nn.Module#
move_buffer_to_device(
model: torch.nn.Module,
device: str | torch.device,
) torch.nn.Module#
move_to_cuda(model: torch.nn.Module) torch.nn.Module#
move_to_cpu(model: torch.nn.Module) torch.nn.Module#
save_checkpoint(
weights_path: str,
optimizer_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
checkpointing_cfg: Optional[nemo_rl.utils.checkpoint.CheckpointingConfig] = None,
) None#

Save a checkpoint of the model.

The optimizer state is saved only if both the optimizer and optimizer_path are provided.

load_checkpoint(
weights_path: str,
optimizer_path: Optional[str] = None,
) None#

Load a checkpoint into the model.
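
Usage sketch for saving and restoring a worker; the paths are placeholders.

```python
# Save weights plus (optionally) optimizer state and tokenizer; paths are placeholders.
worker.save_checkpoint(
    weights_path="/results/step_100/policy/weights",
    optimizer_path="/results/step_100/policy/optimizer",  # omit to skip optimizer state
    tokenizer_path="/results/step_100/policy/tokenizer",
)

# Restore the same worker (or a freshly initialized one) from those files.
worker.load_checkpoint(
    weights_path="/results/step_100/policy/weights",
    optimizer_path="/results/step_100/policy/optimizer",
)
```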

shutdown() None#

Shutdown the policy.

start_gpu_profiling() None#

Start GPU profiling.

stop_gpu_profiling() None#

Stop GPU profiling.

report_node_ip_and_gpu_id() list[tuple[str, int]]#

Report the node IP and GPU ID of the current worker.