nemo_rl.models.policy.dtensor_policy_worker_v2
Module Contents
Classes
API
- class nemo_rl.models.policy.dtensor_policy_worker_v2.DTensorPolicyWorkerV2(
- config: nemo_rl.models.policy.PolicyConfig,
- tokenizer: transformers.AutoTokenizer,
- processor: Optional[transformers.AutoProcessor] = None,
- weights_path: Optional[str] = None,
- optimizer_path: Optional[str] = None,
- init_optimizer: bool = True,
- init_reference_model: bool = True,
- **kwargs: Any,
Initialization
Initialize the DTensorPolicyWorkerV2.
- __repr__() -> str
Customizes the actor’s prefix in the Ray logs.
This makes it easier to identify which worker is producing specific log messages.
- _apply_temperature_scaling(logits: torch.Tensor) -> torch.Tensor
- init_collective(
- ip: str,
- port: int,
- world_size: int,
- *,
- train_world_size: int,
- is_alive() -> bool
- check_model_allow_flash_attn_args(model_config) -> bool
- reset_peak_memory_stats() -> None
- get_gpu_info() -> dict[str, Any]
Return information about the GPU being used by this worker.
- train(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
- eval_mode: bool = False,
- gbs: Optional[int] = None,
- mbs: Optional[int] = None,
Train the policy on a batch of data with a given loss function.
- get_logprobs(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- micro_batch_size: Optional[int] = None,
Get the logprobs of the model for a batch of data.
Uses the configured logprob_batch_size to do microbatching.
Input data is assumed to be right-padded. The method internally converts to left-padded format for computation, and returns outputs in right-padded format.
- Returns:
A BatchedDataDict with key “logprobs” and shape [batch_size, sequence_length]. By convention, the logprob of the first token is 0 so that the sequence length is preserved. The logprob of input token i is stored at position i of the output logprobs tensor.
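The indexing convention above can be illustrated with a small framework-free sketch. It assumes the usual causal-LM layout in which the logits at position i - 1 score the token at position i; the helper names `log_softmax` and `token_logprobs` are hypothetical and are not part of nemo_rl.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax over one vocabulary row."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def token_logprobs(logits, token_ids):
    """Per-token logprobs for one sequence.

    logits: list of S rows, each of vocabulary size V.
    Assumption: logits[i - 1] scores token_ids[i] (causal LM).
    """
    # The first token has no preceding context, so its entry is fixed to 0.
    # This keeps the output the same length as the input sequence.
    out = [0.0]
    for i in range(1, len(token_ids)):
        out.append(log_softmax(logits[i - 1])[token_ids[i]])
    return out


# Toy sequence of 3 tokens over a vocabulary of size 2.
logits = [[0.0, 0.0], [0.0, math.log(3.0)], [1.0, 1.0]]
token_ids = [0, 1, 1]
logprobs = token_logprobs(logits, token_ids)
```

Here `logprobs` has one entry per input token, and the entry at position i is the logprob of `token_ids[i]`.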
- score( ) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ScoreOutputSpec]
- get_topk_logits(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- k: int,
- micro_batch_size: Optional[int] = None,
Return per-position top-k logits and corresponding global indices.
Notes:
Return shapes are [B, S, k].
Computes top-k over the full sequence (no trimming of the last position).
If alignment with next-token targets is required, the caller should handle it.
If logits are a TP-sharded DTensor, performs a distributed global top-k across TP.
Supports context parallelism with proper CP gather.
Otherwise, computes local top-k on full-vocab tensor.
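One common way to obtain a global top-k from vocabulary-sharded logits is to take a local top-k on each shard, shift the local indices by the shard's vocabulary offset, and then take the top-k of the merged candidates. The sketch below shows that idea for a single position in plain Python. It is not the library's implementation, and the names `local_topk`, `global_topk`, and `shards` are hypothetical.

```python
import heapq


def local_topk(shard_logits, k, vocab_offset):
    """Top-k of one vocabulary shard as (value, global_index) pairs."""
    pairs = [(v, vocab_offset + j) for j, v in enumerate(shard_logits)]
    return heapq.nlargest(k, pairs, key=lambda p: p[0])


def global_topk(shards, k):
    """Merge per-shard top-k candidates into a global top-k.

    shards: list of logit lists, one per tensor-parallel rank,
    concatenated in rank order to form the full vocabulary.
    """
    candidates = []
    offset = 0
    for shard in shards:
        # Each shard contributes at most k candidates, which is enough
        # to guarantee the true global top-k is among them.
        candidates.extend(local_topk(shard, k, offset))
        offset += len(shard)
    best = heapq.nlargest(k, candidates, key=lambda p: p[0])
    values = [v for v, _ in best]
    indices = [i for _, i in best]
    return values, indices


# Vocabulary of size 6 split across two hypothetical TP ranks.
shards = [[0.1, 2.0, -1.0], [3.5, 0.0, 1.9]]
values, indices = global_topk(shards, k=2)
```

The returned indices refer to positions in the full vocabulary, not to positions inside a shard.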
- use_reference_model() -> Generator[None, None, None]
Context manager that temporarily swaps the reference model and active model.
On entry: moves model to CPU, moves reference_model to CUDA, and swaps the two references.
On exit: restores the original references and moves each model back to its original device.
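The swap-and-restore pattern described above can be sketched with `contextlib.contextmanager`. This is a minimal illustration under stated assumptions: `FakeModel` and `Worker` are hypothetical stand-ins that only track a device string, and the real worker operates on actual modules and devices.

```python
from contextlib import contextmanager


class FakeModel:
    """Hypothetical stand-in for a module that tracks its device."""

    def __init__(self, name, device):
        self.name = name
        self.device = device

    def to(self, device):
        self.device = device
        return self


class Worker:
    """Hypothetical worker holding a policy and a reference model."""

    def __init__(self):
        self.model = FakeModel("policy", "cuda")
        self.reference_model = FakeModel("reference", "cpu")

    @contextmanager
    def use_reference_model(self):
        # On entry: flip devices, then swap the references.
        self.model.to("cpu")
        self.reference_model.to("cuda")
        self.model, self.reference_model = self.reference_model, self.model
        try:
            yield
        finally:
            # On exit: swap the references back and undo the device flip,
            # even if the body raised an exception.
            self.model, self.reference_model = self.reference_model, self.model
            self.model.to("cuda")
            self.reference_model.to("cpu")


worker = Worker()
with worker.use_reference_model():
    inside = (worker.model.name, worker.model.device)
after = (worker.model.name, worker.model.device, worker.reference_model.device)
```

Inside the `with` block, code that reads `worker.model` transparently sees the reference model; afterwards the original policy is active again.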
- get_reference_policy_logprobs(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- micro_batch_size: Optional[int] = None,
Get the logprobs from the reference policy for a batch of data.
- Returns:
A BatchedDataDict with key “reference_logprobs” and shape [batch_size, sequence_length]. By convention, the logprob of the first token is 0 so that the sequence length is preserved. The logprob of input token i is stored at position i of the output logprobs tensor.
- _add_noise_to_weights() -> None
Add small Gaussian noise to the weights of the model. Note that this is used for testing purposes only.
- return_state_dict()
- return_model_config() -> dict[str, Any]
Return the model configuration as a dictionary.
- Returns:
Model configuration dictionary
- Return type:
dict
- report_device_id() -> str
Report the UUID of the current CUDA device using NVML.
- Returns:
UUID of the device in the format “GPU-xxxxx”
- Return type:
str
- get_zmq_address()
Get the ZMQ address for the current device.
- maybe_init_zmq()
Initialize the ZMQ socket if it doesn’t exist.
- prepare_refit_info() -> Optional[dict[str, Any]]
Prepare state dict metadata for weight refitting and IPC streaming.
- get_free_memory_bytes() -> int
Get the available free memory.
- stream_weights_via_ipc_zmq(buffer_size_bytes: int = 0) -> None
Stream model weights to peer process via ZMQ IPC socket.
- broadcast_weights_for_collective() -> None
Broadcast the weights for collective communication.
- prepare_for_lp_inference() -> None
- prepare_for_training(*args, **kwargs) -> None
- offload_before_refit() -> None
Offload the optimizer to the CPU.
- offload_after_refit() -> None
Offload as much as possible to the CPU.
- move_to_device(
- model: torch.nn.Module,
- device: str | torch.device,
- move_buffer_to_device(
- model: torch.nn.Module,
- device: str | torch.device,
- move_to_cuda(model: torch.nn.Module) -> torch.nn.Module
- move_to_cpu(model: torch.nn.Module) -> torch.nn.Module
- save_checkpoint(
- weights_path: str,
- optimizer_path: Optional[str] = None,
- tokenizer_path: Optional[str] = None,
- checkpointing_cfg: Optional[nemo_rl.utils.checkpoint.CheckpointingConfig] = None,
Save a checkpoint of the model.
The optimizer states are saved only if optimizer and optimizer_path are provided.
- load_checkpoint(
- weights_path: str,
- optimizer_path: Optional[str] = None,
Load a checkpoint into the model.
- shutdown() -> None
Shutdown the policy.
- start_gpu_profiling() -> None
Start GPU profiling.
- stop_gpu_profiling() -> None
Stop GPU profiling.
- report_node_ip_and_gpu_id() -> list[tuple[str, int]]
Report the node IP and GPU ID of the current worker.