nemo_rl.models.policy.fsdp1_policy_worker

Module Contents

Classes

API

class nemo_rl.models.policy.fsdp1_policy_worker.FSDP1PolicyWorker(
config: nemo_rl.models.policy.PolicyConfig,
tokenizer: transformers.AutoTokenizer,
weights_path: Optional[str] = None,
optimizer_path: Optional[str] = None,
init_optimizer: bool = True,
init_reference_model: bool = True,
)

Initialization

DEFAULT_PY_EXECUTABLE

None

__repr__()

Customizes the actor’s prefix in the Ray logs.

This makes it easier to identify which worker is producing specific log messages.

is_alive()

reset_peak_memory_stats()

get_gpu_info()

Return information about the GPU being used by this worker.

train(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
eval_mode: bool = False,
gbs: Optional[int] = None,
mbs: Optional[int] = None,
) → Dict[str, Any]

Train the policy on a batch of data with a given loss function.

get_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
micro_batch_size: int = None,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict

Get the logprobs of the model for a batch of data.

If no micro-batch size is provided, uses the configured logprob_batch_size to do microbatching.

Input data is assumed to be right-padded. The method internally converts to left-padded format for computation, and returns outputs in right-padded format.
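The padding round trip described above can be sketched in plain Python. This is only an illustration under stated assumptions, not the worker's code: the helper names `right_to_left_padded` and `left_to_right_padded`, the `pad_id` value, and the list-of-lists batch layout are all hypothetical.

```python
from typing import List


def right_to_left_padded(
    rows: List[List[int]], lengths: List[int], pad_id: int = 0
) -> List[List[int]]:
    """Move each row's padding from the right side to the left side.

    Hypothetical helper: rows[i][:lengths[i]] holds the real tokens.
    """
    out = []
    for row, n in zip(rows, lengths):
        n_pad = len(row) - n
        out.append([pad_id] * n_pad + row[:n])
    return out


def left_to_right_padded(
    rows: List[List[int]], lengths: List[int], pad_id: int = 0
) -> List[List[int]]:
    """Inverse of right_to_left_padded: real tokens sit at the end of each row."""
    out = []
    for row, n in zip(rows, lengths):
        n_pad = len(row) - n
        out.append(row[n_pad:] + [pad_id] * n_pad)
    return out


# Right-padded input, as the method expects.
batch = [[5, 6, 7, 0, 0], [8, 9, 0, 0, 0]]
lengths = [3, 2]

left = right_to_left_padded(batch, lengths)      # format used for computation
restored = left_to_right_padded(left, lengths)   # format returned to the caller
```

The caller only ever sees right-padded data; the left-padded form exists solely inside the computation.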

Returns:

A BatchedDataDict with key “logprobs” whose value has shape [batch_size, sequence_length]. By convention, the logprob of the first token is 0 so that the sequence length is preserved. The logprob of input token i is stored at position i of the output logprobs tensor.
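The indexing convention for a single sequence can be illustrated with a small pure-Python sketch. It assumes `logits[t]` is the next-token distribution produced after the model has seen tokens 0..t. The function names are hypothetical, and the real worker operates on batched tensors rather than lists.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary-sized vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def token_logprobs(logits: List[List[float]], token_ids: List[int]) -> List[float]:
    """Return one logprob per input token, aligned with token_ids.

    Assumption: logits[t] scores the token at position t + 1.
    """
    out = [0.0]  # convention: the first token has nothing to condition on
    for i in range(1, len(token_ids)):
        # Token i is scored by the distribution predicted at position i - 1.
        out.append(log_softmax(logits[i - 1])[token_ids[i]])
    return out


# Toy example: vocabulary of size 2 with uniform predictions everywhere.
toy_logits = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
toy_tokens = [1, 0, 1]
result = token_logprobs(toy_logits, toy_tokens)
```

The output has the same length as the input sequence, which is why position 0 is filled with a placeholder value of 0 instead of being dropped.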

use_reference_model()

Context manager that temporarily swaps the reference model and active model.

On entry: moves model to CPU and reference_model to CUDA, then swaps the references.

On exit: restores the original references and moves each model back to its original device.
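The entry and exit behaviour can be sketched with `contextlib.contextmanager`. The classes below are stand-ins invented for illustration: `FakeModule` only records a device string and `SketchWorker` is not the real worker, so treat this as a picture of the swap pattern, not of the library's implementation.

```python
from contextlib import contextmanager


class FakeModule:
    """Stand-in for a torch module; it only tracks a device string."""

    def __init__(self, name: str, device: str):
        self.name = name
        self.device = device

    def to(self, device: str) -> "FakeModule":
        self.device = device
        return self


class SketchWorker:
    """Hypothetical worker holding an active model and a reference model."""

    def __init__(self):
        self.model = FakeModule("policy", "cuda")
        self.reference_model = FakeModule("reference", "cpu")

    @contextmanager
    def use_reference_model(self):
        # On entry: flip devices, then swap the two references.
        self.model.to("cpu")
        self.reference_model.to("cuda")
        self.model, self.reference_model = self.reference_model, self.model
        try:
            yield
        finally:
            # On exit: swap the references back, then flip devices again.
            self.model, self.reference_model = self.reference_model, self.model
            self.reference_model.to("cpu")
            self.model.to("cuda")


worker = SketchWorker()
with worker.use_reference_model():
    inside = (worker.model.name, worker.model.device)
after = (worker.model.name, worker.model.device)
```

Putting the restore step in `finally` means the original model comes back even if the code inside the `with` block raises.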

get_reference_policy_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
micro_batch_size: int = None,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict

Get the logprobs from the reference policy for a batch of data.

Returns:

A BatchedDataDict with key “reference_logprobs” whose value has shape [batch_size, sequence_length]. By convention, the logprob of the first token is 0 so that the sequence length is preserved. The logprob of input token i is stored at position i of the output logprobs tensor.

generate(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
greedy: bool = False,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]

Generate responses for a batch of data using Hugging Face generation.

Parameters:

data – BatchedDataDict containing input_ids and input_lengths tensors

Returns:

  • output_ids: input + generated token IDs

  • logprobs: Log probabilities for each token

  • generation_lengths: Lengths of each response

Return type:

BatchedDataDict conforming to GenerationOutputSpec

_add_noise_to_weights()

Add small Gaussian noise to the weights of the model. Note that this is used for testing purposes only.
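As a rough picture of what perturbing weights with small Gaussian noise means, here is a pure-Python sketch. The function name `add_gaussian_noise`, the `std` default, and the dict-of-lists weight layout are assumptions made for illustration. Unlike this sketch, which returns a perturbed copy, the real method presumably updates the model's parameters directly.

```python
import random
from typing import Dict, List, Optional


def add_gaussian_noise(
    weights: Dict[str, List[float]], std: float = 0.01, seed: Optional[int] = None
) -> Dict[str, List[float]]:
    """Return a copy of `weights` with zero-mean Gaussian noise added to each value.

    `std` is an assumed noise scale; the source only says the noise is small.
    """
    rng = random.Random(seed)
    return {
        name: [w + rng.gauss(0.0, std) for w in values]
        for name, values in weights.items()
    }


weights = {"layer.weight": [1.0, -2.0, 0.5], "layer.bias": [0.0]}
noisy = add_gaussian_noise(weights, std=0.01, seed=0)
```

A test can use this kind of perturbation to check that a weight update path really propagates changed values.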

report_device_id() → str

Report the UUID of the current CUDA device using NVML.

Returns:

UUID of the device in the format “GPU-xxxxx”

Return type:

str

prepare_weights_for_ipc()

get_weights_ipc_handles(keys)

prepare_for_lp_inference()

prepare_for_training(*args, **kwargs)

offload_before_refit()

Offload the optimizer and buffers to the CPU.

offload_after_refit()

manual_offload_to_cpu(model)

manual_load_to_gpu(model)

save_checkpoint(
weights_path: str,
optimizer_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
)

Save a checkpoint of the model.

The checkpoint is saved in the following format:

weights_path/
    __0_1.distcp
    __1_0.distcp
    …
optimizer_path/
    __0_0.distcp
    __1_0.distcp
    …

Optimizer states are saved only if both an optimizer and optimizer_path are provided.
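To inspect a checkpoint directory laid out as above, one might list its `.distcp` shard files. The sketch below is a hypothetical helper that is not part of nemo_rl; it builds a throwaway directory with empty placeholder files named after the layout shown, so it makes no claim about what real shards contain.

```python
import tempfile
from pathlib import Path
from typing import List


def list_shards(checkpoint_dir: Path) -> List[str]:
    """Return the sorted names of all .distcp files directly inside checkpoint_dir."""
    return sorted(p.name for p in Path(checkpoint_dir).glob("*.distcp"))


with tempfile.TemporaryDirectory() as tmp:
    # Placeholder directory mimicking the weights_path layout shown above.
    weights_dir = Path(tmp) / "weights_path"
    weights_dir.mkdir()
    for name in ("__0_1.distcp", "__1_0.distcp"):
        (weights_dir / name).touch()  # empty files, for illustration only

    shards = list_shards(weights_dir)
```

The same helper can be pointed at an optimizer_path directory to confirm that optimizer shards were written.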

load_checkpoint(
weights_path: str,
optimizer_path: Optional[str] = None,
)

Load a checkpoint into the model.

shutdown()

Shutdown the policy.