nemo_rl.models.policy.dtensor_policy_worker

Module Contents

Classes

DTensorPolicyWorker

Functions

unshard_fsdp2_model

Explicitly unshard and then reshard the FSDP2 modules. Useful for logprob inference.

get_cpu_state_dict

Copy the tensors yielded by a state dict generator into CPU memory.

API

nemo_rl.models.policy.dtensor_policy_worker.unshard_fsdp2_model(
model: torch.nn.Module,
) → Generator[None, None, None]

Context manager that explicitly unshards the FSDP2 modules on entry and reshards them on exit. Useful for logprob inference.
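
A minimal sketch of the unshard/reshard pattern, assuming the FSDP2-wrapped submodules expose `unshard()` and `reshard()` methods (as `torch.distributed.fsdp.FSDPModule` does in recent PyTorch releases). It uses duck typing instead of an `isinstance` check so that it runs without a distributed setup; `unshard_sketch` is a hypothetical name, not this module's implementation:

```python
from contextlib import contextmanager
from typing import Any, Generator


@contextmanager
def unshard_sketch(model: Any) -> Generator[None, None, None]:
    # Assumption: sharded submodules expose unshard()/reshard(),
    # as FSDP2's FSDPModule does.
    sharded = [
        m for m in model.modules() if hasattr(m, "unshard") and hasattr(m, "reshard")
    ]
    try:
        for m in sharded:
            m.unshard()  # gather the full parameters
        yield
    finally:
        for m in sharded:
            m.reshard()  # return to the sharded layout, even on error
```

Usage would wrap the forward passes of a logprob computation in a `with` block; the `finally` clause guarantees that the modules are resharded even if a forward pass raises.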

nemo_rl.models.policy.dtensor_policy_worker.get_cpu_state_dict(
state_generator: Iterable[tuple[str, Union[torch.Tensor, torch.distributed.tensor.DTensor]]],
pin_memory: bool = False,
) → dict[str, torch.Tensor]

Copy the tensors yielded by a state dict generator into CPU memory.

Parameters:
  • state_generator (Iterable[tuple[str, Union[torch.Tensor, DTensor]]]) – An iterable that yields (key, tensor) pairs from a model state.

  • pin_memory (bool, optional) – Whether to allocate the CPU tensors in pinned memory for faster GPU transfer. Defaults to False.

Returns:

A dictionary mapping parameter names to CPU tensors.

Return type:

dict[str, torch.Tensor]
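
A minimal sketch of copying a stream of (name, tensor) pairs to CPU, assuming plain tensors are copied as-is and a DTensor is first reduced to its local shard with `to_local()`. This reference does not say whether the real function keeps the local shard or gathers the full tensor, so treat that branch as an assumption; `get_cpu_state_dict_sketch` is a hypothetical name:

```python
from typing import Iterable

import torch


def get_cpu_state_dict_sketch(
    state_generator: Iterable[tuple[str, torch.Tensor]],
    pin_memory: bool = False,
) -> dict[str, torch.Tensor]:
    """Copy each (name, tensor) pair into a freshly allocated CPU tensor."""
    cpu_state: dict[str, torch.Tensor] = {}
    for name, tensor in state_generator:
        # Assumption: a DTensor is reduced to its local shard; the real
        # function may instead gather the full tensor.
        if hasattr(tensor, "to_local"):
            tensor = tensor.to_local()
        # Allocate the destination on the CPU, optionally in page-locked memory.
        dst = torch.empty(
            tensor.shape, dtype=tensor.dtype, device="cpu", pin_memory=pin_memory
        )
        # detach() keeps the copy out of the autograd graph.
        dst.copy_(tensor.detach())
        cpu_state[name] = dst
    return cpu_state
```

A natural input is `model.state_dict().items()`. Note that `pin_memory=True` needs a PyTorch build with an accelerator (such as CUDA) available.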

class nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker(
config: nemo_rl.models.policy.PolicyConfig,
tokenizer: transformers.AutoTokenizer,
weights_path: Optional[str] = None,
optimizer_path: Optional[str] = None,
init_optimizer: bool = True,
init_reference_model: bool = True,
**kwargs: Any,
)

Initialization

__repr__() → str

Customizes the actor’s prefix in the Ray logs.

This makes it easier to identify which worker is producing specific log messages.
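
Ray prefixes an actor's log lines with the actor's `repr`, so overriding `__repr__` changes that prefix. A minimal sketch, assuming the prefix should carry the worker's distributed rank; the class name and the `rank` attribute are hypothetical stand-ins, not this worker's actual fields:

```python
from typing import Optional


class WorkerSketch:
    """Hypothetical stand-in for a Ray actor class such as DTensorPolicyWorker."""

    def __init__(self, rank: Optional[int] = None) -> None:
        # Hypothetical: the process's distributed rank, or None before init.
        self.rank = rank

    def __repr__(self) -> str:
        name = type(self).__name__
        # Include the rank when known so each log line identifies its worker.
        return f"{name}[rank={self.rank}]" if self.rank is not None else name
```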

static create_context_parallel_ctx(
cp_mesh: torch.distributed.device_mesh.DeviceMesh,
cp_buffers: list[torch.Tensor],
cp_seq_dims: list[int],
cp_no_restore_buffers: Set[torch.Tensor],
cp_rotate_method: Optional[str] = None,
)

Create a context parallel context.

Parameters:
  • cp_mesh (DeviceMesh) – The device mesh for context parallel.

  • cp_buffers (list[torch.Tensor]) – The buffers to shard along their sequence dimension for context parallel.

  • cp_seq_dims (list[int]) – The sequence dimension of each buffer in cp_buffers.

  • cp_no_restore_buffers (Set[torch.Tensor]) – The subset of cp_buffers that is not restored to its original, unsharded contents when the context exits.

  • cp_rotate_method (str, optional) – The rotation method for context parallel, such as “allgather” or “alltoall”.

_apply_temperature_scaling(logits: torch.Tensor) → torch.Tensor

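
`_apply_temperature_scaling` has no description here. Assuming it divides the logits by the generation temperature (the usual way to make training-side logprobs agree with temperature-scaled sampling), the arithmetic looks like this pure-Python sketch; both helper names are hypothetical:

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    # Numerically stable log-softmax via the max-shifted log-sum-exp.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def apply_temperature(logits: list[float], temperature: float) -> list[float]:
    # Dividing by T > 1 flattens the distribution; T < 1 sharpens it.
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    return [x / temperature for x in logits]
```

With T = 1 the logits are unchanged, so the logprobs match plain sampling.
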
static train_context(
cp_context: Optional[Generator[None, None, None]] = None,
)

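
`train_context` is also undocumented. Assuming its core job is to enter `cp_context` (for example, one built by `create_context_parallel_ctx`) when one is given and to do nothing otherwise, the pattern can be sketched with `contextlib.ExitStack`; the real method may enter additional contexts, and `train_context_sketch` is a hypothetical name:

```python
import contextlib
from typing import ContextManager, Generator, Optional


@contextlib.contextmanager
def train_context_sketch(
    cp_context: Optional[ContextManager[None]] = None,
) -> Generator[None, None, None]:
    # ExitStack unwinds whatever was entered, in reverse order, on exit.
    with contextlib.ExitStack() as stack:
        if cp_context is not None:
            stack.enter_context(cp_context)
        yield
```

ExitStack keeps the optional case free of duplicated `with` bodies and makes it easy to stack further contexts later.
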
init_collective(ip: str, port: int, world_size: int) → None

Initialize the collective communication.

is_alive() → bool

reset_peak_memory_stats() → None

get_gpu_info() → dict[str, Any]

Return information about the GPU being used by this worker.

train(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
eval_mode: bool = False,
gbs: Optional[int] = None,
mbs: Optional[int] = None,
) → dict[str, Any]

Train the policy on a batch of data with a given loss function.
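
The `gbs` and `mbs` arguments suggest the usual gradient-accumulation layout, in which gradients from several micro-batches are accumulated before each optimizer step. A hypothetical planning helper, assuming `gbs` counts samples across all data-parallel ranks and that an incomplete trailing global batch is dropped:

```python
def plan_microbatches(
    num_local_samples: int, gbs: int, mbs: int, dp_size: int
) -> list[list[range]]:
    """Hypothetical helper: index ranges for each micro-batch of each global batch."""
    if gbs % dp_size != 0:
        raise ValueError("gbs must be divisible by dp_size")
    local_gbs = gbs // dp_size  # samples this rank contributes per optimizer step
    if local_gbs % mbs != 0:
        raise ValueError("the per-rank batch must be divisible by mbs")
    plan = []
    # Only full global batches are kept; leftover samples are dropped.
    for start in range(0, num_local_samples - local_gbs + 1, local_gbs):
        micro = [range(s, s + mbs) for s in range(start, start + local_gbs, mbs)]
        plan.append(micro)  # one optimizer step after accumulating these
    return plan
```

Each inner list corresponds to one optimizer step; within it, forward and backward passes run once per micro-batch.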

get_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
micro_batch_size: Optional[int] = None,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec]

Get the logprobs of the model for a batch of data.

Uses micro_batch_size for microbatching if it is given; otherwise falls back to the configured logprob_batch_size.

Input data is assumed to be right-padded. The method internally converts to left-padded format for computation, and returns outputs in right-padded format.
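
The padding conversion amounts to moving each row's valid tokens from the front of the row to the back, and back again. A pure-Python sketch over lists of token IDs, with hypothetical helper names and a pad value of 0:

```python
def right_to_left_padded(
    rows: list[list[int]], lengths: list[int], pad: int = 0
) -> list[list[int]]:
    # Valid tokens move to the end of each row; padding goes in front.
    return [[pad] * (len(row) - n) + row[:n] for row, n in zip(rows, lengths)]


def left_to_right_padded(
    rows: list[list[int]], lengths: list[int], pad: int = 0
) -> list[list[int]]:
    # Inverse: valid tokens move back to the front; padding goes at the end.
    return [row[len(row) - n:] + [pad] * (len(row) - n) for row, n in zip(rows, lengths)]
```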

Returns:

A BatchedDataDict with key “logprobs” and shape [batch_size, sequence_length]. We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor.
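
Under this convention, the logits produced at position i - 1 score token i, and position 0 is filled with 0. A pure-Python sketch, assuming `logits` holds one row of vocabulary scores per input position of a causal language model; `token_logprobs` is a hypothetical name:

```python
import math


def token_logprobs(logits: list[list[float]], tokens: list[int]) -> list[float]:
    """Per-token logprobs with the 'first token is 0' convention."""
    out = [0.0]  # position 0 has no preceding context, so it is set to 0
    for i in range(1, len(tokens)):
        row = logits[i - 1]  # scores produced at position i - 1 predict token i
        m = max(row)
        log_norm = m + math.log(sum(math.exp(x - m) for x in row))  # stable log-sum-exp
        out.append(row[tokens[i]] - log_norm)
    return out
```

The output has the same length as `tokens`, which is what keeps the sequence dimension unchanged; the last row of `logits` is not needed.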

use_reference_model() → Generator[None, None, None]

Context manager that temporarily swaps the reference model and active model.

On entry: moves model to the CPU, moves reference_model to CUDA, and swaps the references.

On exit: restores the original references and moves both models back to their original devices.
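
The swap can be sketched as a context manager. This version uses a hypothetical `Holder` class and a `move(obj, device)` callback standing in for moving a model between devices, and it restores the original references and devices in a `finally` block so that an exception inside the `with` body does not leave the models swapped:

```python
from contextlib import contextmanager
from typing import Any, Callable, Generator


class Holder:
    """Hypothetical container with an active model and a reference model."""

    def __init__(self, model: Any, reference_model: Any) -> None:
        self.model = model
        self.reference_model = reference_model


@contextmanager
def use_reference_sketch(
    holder: Holder, move: Callable[[Any, str], None]
) -> Generator[None, None, None]:
    policy, reference = holder.model, holder.reference_model
    # On entry: policy to CPU, reference to CUDA, then swap the references.
    move(policy, "cpu")
    move(reference, "cuda")
    holder.model, holder.reference_model = reference, policy
    try:
        yield
    finally:
        # On exit: restore the references and flip the devices back.
        holder.model, holder.reference_model = policy, reference
        move(reference, "cpu")
        move(policy, "cuda")
```

Inside the `with` block, any code that reads `holder.model` transparently uses the reference model.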

get_reference_policy_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
micro_batch_size: Optional[int] = None,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec]

Get the logprobs from the reference policy for a batch of data.

Returns:

A BatchedDataDict with key “reference_logprobs” and shape [batch_size, sequence_length]. We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor.

_add_noise_to_weights() → None

Add small Gaussian noise to the weights of the model. Note that this is used for testing purposes only.

return_state_dict()

report_device_id() → str

Report the UUID of the current CUDA device using NVML.

Returns:

UUID of the device in the format “GPU-xxxxx”

Return type:

str

prepare_refit_info() → Optional[dict[str, Any]]

prepare_weights_for_ipc() → tuple[list[tuple[str, int]], float]

Prepare the weights for IPC.

This function:

  • Prepares the state_dict of the model.

  • Collects the info for streaming multiple tensors.

Returns:

A tuple whose first element is the list of (parameter name, size) pairs and whose second element is the total available memory in bytes.

Return type:

tuple[list[tuple[str, int]], float]
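
One way a caller could use the returned sizes is to stream parameters in groups that fit within a memory budget, for example by requesting IPC handles for one group of keys at a time. A hypothetical greedy grouping helper, assuming the budget is derived from the reported available memory:

```python
def group_keys_by_budget(
    param_info: list[tuple[str, int]], budget_bytes: float
) -> list[list[str]]:
    """Greedily group parameter names so each group's total size fits the budget."""
    groups: list[list[str]] = []
    current: list[str] = []
    used = 0
    for name, size in param_info:
        # Close the current group if adding this parameter would exceed the budget.
        if current and used + size > budget_bytes:
            groups.append(current)
            current, used = [], 0
        current.append(name)
        used += size
    if current:
        groups.append(current)
    return groups
```

A parameter larger than the budget still gets a group of its own, so the loop always makes progress.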

get_weights_ipc_handles(
keys: Iterable[str],
) → dict[str, Any]

broadcast_weights_for_collective() → None

Broadcast the weights for collective communication.

prepare_for_lp_inference() → None

prepare_for_training(*args, **kwargs) → None

offload_before_refit() → None

Offload the optimizer to the CPU.

offload_after_refit() → None

move_to_device(
model: torch.nn.Module,
device: str | torch.device,
) → torch.nn.Module

move_buffer_to_device(
model: torch.nn.Module,
device: str | torch.device,
) → torch.nn.Module

move_to_cuda(model: torch.nn.Module) → torch.nn.Module

move_to_cpu(model: torch.nn.Module) → torch.nn.Module

save_checkpoint(
weights_path: str,
optimizer_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
) → None

Save a checkpoint of the model.

The optimizer state is saved only if the worker has an optimizer and optimizer_path is provided.

load_checkpoint(
weights_path: str,
optimizer_path: Optional[str] = None,
) → None

Load a checkpoint into the model.

shutdown() → None

Shut down the policy.

start_gpu_profiling() → None

Start GPU profiling.

stop_gpu_profiling() → None

Stop GPU profiling.