nemo_rl.models.policy.dtensor_policy_worker#
Module Contents#
Classes#
- DTensorPolicyWorker
Functions#
- unshard_fsdp2_model: Explicitly unshard and then reshard the FSDP2 modules. Useful for logprob inference.
- get_cpu_state_dict: Copy the state dict generator to CPU memory.
API#
- nemo_rl.models.policy.dtensor_policy_worker.unshard_fsdp2_model(
- model: torch.nn.Module,
Explicitly unshard and then reshard the FSDP2 modules. Useful for logprob inference.
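A minimal sketch of how this helper might wrap logprob-only inference, assuming it behaves as a context manager that unshards on entry and reshards on exit, as the description above implies; the microbatch loop and the no-grad forward pass are purely illustrative.

```python
import torch
from nemo_rl.models.policy.dtensor_policy_worker import unshard_fsdp2_model

def batched_logprob_inference(model: torch.nn.Module, microbatches):
    outputs = []
    # Assumption: unshard_fsdp2_model acts as a context manager, so parameters
    # stay materialized for the whole loop instead of being unsharded and
    # resharded once per microbatch.
    with torch.no_grad(), unshard_fsdp2_model(model):
        for mb in microbatches:
            outputs.append(model(**mb))
    return outputs
```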
- nemo_rl.models.policy.dtensor_policy_worker.get_cpu_state_dict(
- state_generator: Iterable[tuple[str, Union[torch.Tensor, torch.distributed.tensor.DTensor]]],
- pin_memory: bool = False,
Copy the state dict generator to CPU memory.
- Parameters:
state_generator (Iterable[tuple[str, Union[torch.Tensor, DTensor]]]) – An iterable that yields (key, tensor) pairs from a model state.
pin_memory (bool, optional) – Whether to allocate the CPU tensors in pinned memory for faster GPU transfer. Defaults to False.
- Returns:
A dictionary mapping parameter names to CPU tensors.
- Return type:
dict[str, torch.Tensor]
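A minimal runnable sketch of the helper on an ordinary (unsharded) module; with a DTensor-sharded model the same call is made on the (name, DTensor) pairs, and `pin_memory=True` can be used to allocate pinned host memory for faster GPU transfers.

```python
import torch
from nemo_rl.models.policy.dtensor_policy_worker import get_cpu_state_dict

model = torch.nn.Linear(16, 16)  # stand-in for a (possibly DTensor-sharded) policy model

# Any iterable of (name, tensor) pairs works; here we use the plain state dict.
cpu_state = get_cpu_state_dict(model.state_dict().items(), pin_memory=False)

for name, tensor in cpu_state.items():
    assert tensor.device.type == "cpu"
    print(name, tuple(tensor.shape))
```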
- class nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker(
- config: nemo_rl.models.policy.PolicyConfig,
- tokenizer: transformers.AutoTokenizer,
- processor: Optional[transformers.AutoProcessor] = None,
- weights_path: Optional[str] = None,
- optimizer_path: Optional[str] = None,
- init_optimizer: bool = True,
- init_reference_model: bool = True,
- **kwargs: Any,
Initialization
Initialize the DTensorPolicyWorker.
- __repr__() → str#
Customizes the actor’s prefix in the Ray logs.
This makes it easier to identify which worker is producing specific log messages.
- static create_context_parallel_ctx(
- cp_mesh: torch.distributed.device_mesh.DeviceMesh,
- cp_buffers: list[torch.Tensor],
- cp_seq_dims: list[int],
- cp_no_restore_buffers: Set[torch.Tensor],
- cp_rotate_method: Optional[str] = None,
Create a context parallel context.
- Parameters:
cp_mesh (DeviceMesh) – The device mesh for context parallel.
cp_buffers (list[torch.Tensor]) – The buffers for context parallel.
cp_seq_dims (list[int]) – The sequence dimensions for context parallel.
cp_no_restore_buffers (Set[torch.Tensor]) – The no restore buffers for context parallel.
cp_rotate_method (str, optional) – The rotation method for context parallel, such as “allgather” or “alltoall”. Defaults to None.
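A hedged sketch of wiring the documented parameters together; it assumes torch.distributed is already initialized, `cp_mesh` is the context-parallel sub-mesh of the worker's device mesh, and the resulting context is later activated via `train_context` (documented below). The buffer names and sequence dimensions are illustrative.

```python
import torch
from nemo_rl.models.policy.dtensor_policy_worker import DTensorPolicyWorker

def make_cp_context(cp_mesh, input_ids: torch.Tensor, position_ids: torch.Tensor):
    # Shard both buffers along their sequence dimension (dim 1 here).
    cp_ctx = DTensorPolicyWorker.create_context_parallel_ctx(
        cp_mesh=cp_mesh,
        cp_buffers=[input_ids, position_ids],
        cp_seq_dims=[1, 1],
        cp_no_restore_buffers={input_ids},   # buffers that do not need restoring afterwards
        cp_rotate_method="allgather",        # or "alltoall"
    )
    # Typically handed to train_context(cp_context=cp_ctx) so that the
    # forward/backward pass runs under context parallelism.
    return cp_ctx
```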
- _apply_temperature_scaling(logits: torch.Tensor) → torch.Tensor#
- static train_context(
- cp_context: Optional[Generator[None, None, None]] = None,
- init_collective(
- ip: str,
- port: int,
- world_size: int,
- *,
- train_world_size: int,
Initialize the collective communication.
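A hedged sketch of joining the refit collective; the rendezvous address, port, and world sizes are illustrative values that would normally be supplied by the controller coordinating the training and inference workers.

```python
def join_refit_collective(worker) -> None:
    # `worker` is an already-constructed DTensorPolicyWorker (or a Ray handle to one).
    worker.init_collective(
        ip="10.0.0.1",          # illustrative rendezvous address
        port=29500,
        world_size=16,          # total ranks participating in the collective
        train_world_size=8,     # ranks that belong to the training job
    )
```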
- is_alive() → bool#
- reset_peak_memory_stats() → None#
- get_gpu_info() → dict[str, Any]#
Return information about the GPU being used by this worker.
- train(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
- eval_mode: bool = False,
- gbs: Optional[int] = None,
- mbs: Optional[int] = None,
Train the policy on a batch of data with a given loss function.
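A hedged sketch of a single training call; `worker`, the batch, and the loss function are assumed to come from the surrounding training loop, and the gbs/mbs overrides are illustrative.

```python
from nemo_rl.distributed.batched_data_dict import BatchedDataDict

def run_train_step(worker, batch: BatchedDataDict, loss_fn):
    # gbs/mbs override the configured global and micro batch sizes when provided;
    # eval_mode=True is presumably used to evaluate the loss without updating weights.
    return worker.train(data=batch, loss_fn=loss_fn, eval_mode=False, gbs=128, mbs=4)
```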
- get_logprobs(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- micro_batch_size: Optional[int] = None,
Get the logprobs of the model for a batch of data.
Uses the configured logprob_batch_size to do microbatching.
Input data is assumed to be right-padded. The method internally converts to left-padded format for computation, and returns outputs in right-padded format.
- Returns:
a BatchedDataDict with key “logprobs” and shape [batch_size, sequence_length]. We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor.
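A small sketch of consuming the returned logprobs; it relies only on the documented key name, shape, and the first-token-is-zero convention.

```python
import torch

def sequence_logprob_sums(worker, batch) -> torch.Tensor:
    out = worker.get_logprobs(batch)      # BatchedDataDict with key "logprobs"
    logprobs = out["logprobs"]            # shape [batch_size, sequence_length]
    # Position 0 holds 0 by convention, so a sum over the sequence dimension
    # gives the total logprob of tokens 1..S-1 under the current policy.
    return logprobs.sum(dim=-1)
```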
- score() → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ScoreOutputSpec]#
- get_topk_logits(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- k: int,
- micro_batch_size: Optional[int] = None,
Return per-position top-k logits and corresponding global indices.
Notes:
- Return shapes are [B, S, k].
- Computes top-k over the full sequence (no trimming of the last position). If alignment with next-token targets is required, the caller should handle it.
- If the logits are a TP-sharded DTensor, performs a distributed global top-k across TP; otherwise, computes a local top-k on the full-vocab tensor.
- Supports context parallelism with a proper CP gather.
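A hedged sketch of requesting per-position top-k logits, e.g. as distillation targets; the key names of the returned container are not listed in this reference, so the result is returned opaquely.

```python
def collect_topk_targets(worker, batch, k: int = 16):
    # Both the top-k values and their global vocabulary indices come back with
    # shape [B, S, k]; any shift/alignment with next-token targets is the
    # caller's responsibility (see the notes above).
    return worker.get_topk_logits(batch, k=k, micro_batch_size=8)
```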
- use_reference_model() → Generator[None, None, None]#
Context manager that temporarily swaps the reference model and active model.
On entry: moves the policy model to CPU and the reference model to CUDA, then swaps the two references. On exit: restores the original references and moves each model back to its original device.
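A sketch of the intended usage pattern: compute reference-model logprobs (e.g. for a KL penalty) while the context manager has swapped the models, relying only on the behavior documented above.

```python
def reference_logprobs(worker, batch):
    with worker.use_reference_model():
        # Inside the block the reference model is on CUDA and acts as the active
        # model; the trained policy weights sit on CPU.
        ref_out = worker.get_logprobs(batch)
    # On exit the original references and device placement are restored.
    return ref_out["logprobs"]
```

For the common case of reference logprobs, the dedicated `get_reference_policy_logprobs` method below returns the same quantity directly.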
- get_reference_policy_logprobs(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- micro_batch_size: Optional[int] = None,
Get the logprobs from the reference policy for a batch of data.
- Returns:
a BatchedDataDict with key “reference_logprobs” and shape [batch_size, sequence_length]. We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor.
- _add_noise_to_weights() → None#
Add small Gaussian noise to the weights of the model. Note that this is used for testing purposes only.
- return_state_dict()#
- return_model_config() → dict[str, Any]#
Return the model configuration as a dictionary.
- Returns:
Model configuration dictionary
- Return type:
dict
- report_device_id() → str#
Report the UUID of the current CUDA device using NVML.
- Returns:
UUID of the device in the format “GPU-xxxxx”
- Return type:
str
- get_zmq_address()#
Get the ZMQ address for the current device.
- maybe_init_zmq()#
Initialize the ZMQ socket if it doesn’t exist.
- prepare_refit_info() → Optional[dict[str, Any]]#
Prepare state dict metadata for weight refitting and IPC streaming.
- get_free_memory_bytes() → int#
Get the available free memory in bytes.
- stream_weights_via_ipc_zmq(buffer_size_bytes: int = 0) → None#
Stream model weights to peer process via ZMQ IPC socket.
- broadcast_weights_for_collective() → None#
Broadcast the weights for collective communication.
- prepare_for_lp_inference() → None#
- prepare_for_training(*args, **kwargs) → None#
- offload_before_refit() → None#
Offload the optimizer to the CPU.
- offload_after_refit() → None#
Offload as much as possible to the CPU.
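A hedged sketch of one plausible refit sequence for handing weights to a colocated inference engine over ZMQ IPC, stitched together from the method descriptions above; the exact ordering and the buffer sizing are assumptions, and the receiving side is outside this worker's API.

```python
def refit_via_ipc(worker):
    worker.maybe_init_zmq()                      # create the ZMQ socket if needed
    refit_info = worker.prepare_refit_info()     # state-dict metadata for the receiver
    worker.offload_before_refit()                # move the optimizer to CPU first
    free_bytes = worker.get_free_memory_bytes()  # illustrative: size the send buffer
    worker.stream_weights_via_ipc_zmq(buffer_size_bytes=free_bytes)
    worker.offload_after_refit()                 # offload whatever else can live on CPU
    return refit_info
```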
- move_to_device(
- model: torch.nn.Module,
- device: str | torch.device,
- move_buffer_to_device(
- model: torch.nn.Module,
- device: str | torch.device,
- move_to_cuda(model: torch.nn.Module) → torch.nn.Module#
- move_to_cpu(model: torch.nn.Module) → torch.nn.Module#
- save_checkpoint(
- weights_path: str,
- optimizer_path: Optional[str] = None,
- tokenizer_path: Optional[str] = None,
Save a checkpoint of the model.
The optimizer states are saved only if `optimizer` and `optimizer_path` are provided.
- load_checkpoint(
- weights_path: str,
- optimizer_path: Optional[str] = None,
Load a checkpoint into the model.
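A hedged sketch of a checkpoint round trip; the directory layout is illustrative, and per the note above the optimizer state is only written when an optimizer exists and `optimizer_path` is given.

```python
def checkpoint_roundtrip(worker, step: int) -> None:
    weights_path = f"/tmp/policy_ckpt/step_{step}/weights"      # illustrative paths
    optimizer_path = f"/tmp/policy_ckpt/step_{step}/optimizer"

    worker.save_checkpoint(
        weights_path=weights_path,
        optimizer_path=optimizer_path,
        tokenizer_path=f"/tmp/policy_ckpt/step_{step}/tokenizer",
    )

    # ... later, e.g. when resuming training ...
    worker.load_checkpoint(weights_path=weights_path, optimizer_path=optimizer_path)
```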
- shutdown() → None#
Shut down the policy.
- start_gpu_profiling() → None#
Start GPU profiling.
- stop_gpu_profiling() → None#
Stop GPU profiling.
- report_node_ip_and_gpu_id() → list[tuple[str, int]]#
Report the node IP and GPU ID of the current worker.