nemo_rl.models.policy.dtensor_policy_worker
Module Contents
Classes
DTensorPolicyWorker
Functions
unshard_fsdp2_model | Explicitly unshard and then reshard the FSDP2 modules. Useful for logprob inference.
get_cpu_state_dict | Copy the state dict generator to CPU memory.
API
- nemo_rl.models.policy.dtensor_policy_worker.unshard_fsdp2_model(model: torch.nn.Module)
Explicitly unshard and then reshard the FSDP2 modules. Useful for logprob inference.
- nemo_rl.models.policy.dtensor_policy_worker.get_cpu_state_dict(
    state_generator: Iterable[Tuple[str, Union[torch.Tensor, torch.distributed.tensor.DTensor]]],
    pin_memory: bool = False,
  )
Copy the state dict generator to CPU memory.
- Parameters:
state_generator (Iterable[Tuple[str, Union[torch.Tensor, DTensor]]]) – An iterable that yields (key, tensor) pairs from a model state.
pin_memory (bool, optional) – Whether to allocate the CPU tensors in pinned memory for faster GPU transfer. Defaults to False.
- Returns:
A dictionary mapping parameter names to CPU tensors.
- Return type:
Dict[str, torch.Tensor]
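As an illustration only, the sketch below shows one way a function of this shape could behave for plain torch.Tensor values. The name cpu_state_dict_sketch is hypothetical, DTensor handling is omitted, and the real implementation may differ. Applying pinning only when CUDA is available is an assumption of this sketch, not documented behavior.

```python
from typing import Dict, Iterable, Tuple

import torch


def cpu_state_dict_sketch(
    state_generator: Iterable[Tuple[str, torch.Tensor]],
    pin_memory: bool = False,
) -> Dict[str, torch.Tensor]:
    """Hypothetical stand-in: copy (key, tensor) pairs into CPU memory."""
    cpu_state: Dict[str, torch.Tensor] = {}
    for name, tensor in state_generator:
        # copy=True forces a fresh CPU copy even if the tensor is already on CPU.
        cpu_tensor = tensor.detach().to("cpu", copy=True)
        # Assumption: only pin when a CUDA runtime is present.
        if pin_memory and torch.cuda.is_available():
            cpu_tensor = cpu_tensor.pin_memory()
        cpu_state[name] = cpu_tensor
    return cpu_state


# Usage: snapshot the parameters of a small module.
model = torch.nn.Linear(4, 2)
snapshot = cpu_state_dict_sketch(model.state_dict().items())
```

Because each entry is an independent copy, later in-place updates to the module's parameters do not alter the snapshot.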
- class nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker(
    config: nemo_rl.models.policy.PolicyConfig,
    tokenizer: transformers.AutoTokenizer,
    weights_path: Optional[str] = None,
    optimizer_path: Optional[str] = None,
    init_optimizer: bool = True,
    init_reference_model: bool = True,
  )
Initialization
- __repr__()
Customizes the actor’s prefix in the Ray logs.
This makes it easier to identify which worker is producing specific log messages.
- train(
    data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
    loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
    eval_mode: bool = False,
    gbs: Optional[int] = None,
    mbs: Optional[int] = None,
  )
Train the policy on a batch of data with a given loss function.
- get_logprobs(
    data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
    micro_batch_size: int = None,
  )
Get the logprobs of the model for a batch of data.
Uses the configured logprob_batch_size to do microbatching.
Input data is assumed to be right-padded. The method internally converts to left-padded format for computation, and returns outputs in right-padded format.
- Returns:
A BatchedDataDict with key “logprobs” whose value has shape [batch_size, sequence_length]. By convention, the logprob of the first token is 0 so that the sequence length is preserved. The logprob of input token i is stored at position i of the output logprobs tensor.
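The padding conversion mentioned for get_logprobs can be pictured with a small pure-Python sketch. The helper names, the pad id of 0, and the explicit lengths list are all assumptions made for illustration. This is not the worker's actual code, which operates on batched tensors.

```python
from typing import List


def right_to_left_padded(
    batch: List[List[int]], lengths: List[int], pad_id: int = 0
) -> List[List[int]]:
    """Move each row's padding from the end of the row to the start."""
    out = []
    for row, n in zip(batch, lengths):
        tokens = row[:n]  # real tokens sit at the front of a right-padded row
        out.append([pad_id] * (len(row) - n) + tokens)
    return out


def left_to_right_padded(
    batch: List[List[int]], lengths: List[int], pad_id: int = 0
) -> List[List[int]]:
    """Inverse of right_to_left_padded."""
    out = []
    for row, n in zip(batch, lengths):
        tokens = row[len(row) - n:]  # real tokens sit at the end of a left-padded row
        out.append(tokens + [pad_id] * (len(row) - n))
    return out


right_padded = [[5, 6, 7, 0, 0], [8, 9, 0, 0, 0]]
lengths = [3, 2]
left_padded = right_to_left_padded(right_padded, lengths)
restored = left_to_right_padded(left_padded, lengths)
```

Round-tripping through both helpers returns the original right-padded batch, which mirrors the documented behavior that callers pass and receive right-padded data.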
- use_reference_model()
Context manager that temporarily swaps the reference model and active model.
On entry: moves model to CPU, moves reference_model to CUDA, and swaps the references.
On exit: restores the original references and moves each model back to its original device.
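The swap described above can be sketched with a toy context manager. ToyModel and ToyWorker are hypothetical stand-ins that track a device string instead of moving real weights, so this shows only the control flow, not the worker's implementation.

```python
from contextlib import contextmanager
from dataclasses import dataclass


@dataclass
class ToyModel:
    name: str
    device: str


class ToyWorker:
    def __init__(self) -> None:
        self.model = ToyModel("policy", "cuda")
        self.reference_model = ToyModel("reference", "cpu")

    @contextmanager
    def use_reference_model(self):
        # On entry: flip devices, then swap the two references.
        self.model.device = "cpu"
        self.reference_model.device = "cuda"
        self.model, self.reference_model = self.reference_model, self.model
        try:
            yield
        finally:
            # On exit: swap back, then restore the original devices.
            self.model, self.reference_model = self.reference_model, self.model
            self.reference_model.device = "cpu"
            self.model.device = "cuda"


worker = ToyWorker()
with worker.use_reference_model():
    inside = (worker.model.name, worker.model.device)
after = (worker.model.name, worker.model.device)
```

Inside the with block, code that reads worker.model sees the reference model on the accelerator. The try/finally ensures the original arrangement is restored even if the body raises.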
- get_reference_policy_logprobs(
    data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
    micro_batch_size: int = None,
  )
Get the logprobs from the reference policy for a batch of data.
- Returns:
A BatchedDataDict with key “reference_logprobs” whose value has shape [batch_size, sequence_length]. By convention, the logprob of the first token is 0 so that the sequence length is preserved. The logprob of input token i is stored at position i of the output logprobs tensor.
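The position convention for the returned logprobs can be illustrated for a single sequence in pure Python. The sketch assumes the usual causal language model layout, in which the logits at position t score the token at position t + 1. The function names are hypothetical and the real method works on batched tensors.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def token_logprobs(logits: List[List[float]], input_ids: List[int]) -> List[float]:
    """Return one logprob per input token, with position 0 fixed to 0."""
    out = [0.0]  # first token has no preceding context, so its slot is 0
    for i in range(1, len(input_ids)):
        # Assumption: logits[i - 1] is the distribution over the token at position i.
        out.append(log_softmax(logits[i - 1])[input_ids[i]])
    return out


input_ids = [2, 0, 1]
logits = [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0], [0.5, 0.5, 0.5]]
logprobs = token_logprobs(logits, input_ids)
```

The output has the same length as input_ids, and entry i is the log-probability assigned to input token i given the tokens before it.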
- _add_noise_to_weights()
Add small Gaussian noise to the weights of the model. Note that this is used for testing purposes only.
- report_device_id() → str
Report the UUID of the current CUDA device using NVML.
- Returns:
UUID of the device in the format “GPU-xxxxx”
- Return type:
str
- save_checkpoint(
    weights_path: str,
    optimizer_path: Optional[str] = None,
    tokenizer_path: Optional[str] = None,
  )
Save a checkpoint of the model.
The optimizer states are saved only if optimizer and optimizer_path are provided.