nemo_rl.models.value.lm_value

Module Contents

Classes

Value – Value function model for PPO using distributed training with Ray workers.

Data

API

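nemo_rl.models.value.lm_value.PathLike

Module-level name used as the type of the weights_path and optimizer_path arguments of Value. The generated documentation shows its value as None.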

class nemo_rl.models.value.lm_value.Value(
cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster,
config: nemo_rl.models.value.config.ValueConfig,
tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
name_prefix: str = 'lm_value',
workers_per_node: Optional[Union[int, list[int]]] = None,
init_optimizer: bool = True,
weights_path: Optional[nemo_rl.models.value.lm_value.PathLike] = None,
optimizer_path: Optional[nemo_rl.models.value.lm_value.PathLike] = None,
)

Bases: nemo_rl.models.value.interfaces.ValueInterface

Value function model for PPO using distributed training with Ray workers.

Initialization

Initialize the Value model.

Parameters:
  • cluster – Ray virtual cluster for distributed training

  • config – Configuration for the value model

  • tokenizer – Tokenizer for the model

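  • name_prefix – Prefix for worker names (default: 'lm_value')

  • workers_per_node – Number of workers per node: an int applies to every node, a list gives one count per node (default: None)

  • init_optimizer – Whether to initialize the optimizer (default: True)

  • weights_path – Optional path to load model weights from

  • optimizer_path – Optional path to load optimizer state from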
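As a rough illustration of the arguments above, the helper below builds a value model intended only for inference. `build_inference_value` is a hypothetical name, not part of NeMo RL, and the sketch assumes that passing `init_optimizer=False` is a reasonable way to skip optimizer setup when no training is planned. The import is deferred so the function can be defined without NeMo RL installed.

```python
def build_inference_value(cluster, config, tokenizer, weights_path=None):
    """Hypothetical helper: construct a Value without optimizer state.

    cluster, config and tokenizer are expected to be a RayVirtualCluster,
    a ValueConfig and a PreTrainedTokenizerBase, as in the signature above.
    """
    # Deferred import: only needed when the helper is actually called.
    from nemo_rl.models.value.lm_value import Value

    return Value(
        cluster=cluster,
        config=config,
        tokenizer=tokenizer,
        name_prefix="lm_value",
        init_optimizer=False,      # assumption: no training will follow
        weights_path=weights_path,
        optimizer_path=None,       # nothing to restore without an optimizer
    )
```

Keyword arguments are used throughout so that the call stays readable against the long constructor signature.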

get_values(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
timer: Optional[nemo_rl.utils.timer.Timer] = None,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.value.interfaces.ValueOutputSpec]

Get value predictions for a batch of data.

Parameters:
  • data – BatchedDataDict containing input sequences

  • timer – Optional timer for profiling

Returns:

BatchedDataDict containing value predictions of shape [batch_size, sequence_length]
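Because the predictions are laid out as [batch_size, sequence_length], positions past the end of a shorter sequence usually need to be ignored downstream. The sketch below shows one way to do that on plain nested lists. It assumes right-padded sequences and a separate list of per-sequence lengths; `mask_padded_values` is a hypothetical helper, not a NeMo RL function.

```python
def mask_padded_values(values, lengths, pad_value=0.0):
    """Replace predictions at padded positions with pad_value.

    values:  nested list shaped [batch_size, sequence_length]
    lengths: number of real (unpadded) tokens in each sequence
    """
    masked = []
    for row, length in zip(values, lengths):
        # Keep positions before `length`; overwrite the padded tail.
        masked.append([v if t < length else pad_value for t, v in enumerate(row)])
    return masked


values = [[0.5, 0.4, 0.3, 0.9], [0.2, 0.1, 0.7, 0.8]]
lengths = [4, 2]  # second sequence has two padded positions
masked = mask_padded_values(values, lengths)
```

In a tensor-based pipeline the same idea would normally be expressed as a boolean mask multiplied into the value tensor.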

train(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
eval_mode: bool = False,
*,
gbs: Optional[int] = None,
mbs: Optional[int] = None,
timer: Optional[nemo_rl.utils.timer.Timer] = None,
) → dict[str, Any]

Train the value function on a batch of data with a given loss function.

Parameters:
  • data – BatchedDataDict containing training data

  • loss_fn – Loss function to use for training

  • eval_mode – Whether to run in evaluation mode (no gradient updates)

  • gbs – Global batch size override (if None, uses config default)

  • mbs – Micro batch size override (if None, uses config default)

  • timer – Optional timer for profiling

Returns:

Dictionary containing training metrics (loss, grad_norm, etc.)
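The override behaviour of gbs and mbs described above (fall back to the config when None) can be sketched as follows. The config keys `train_global_batch_size` and `train_micro_batch_size` are assumptions made for this illustration and may not match the actual ValueConfig fields.

```python
from typing import Optional


def resolve_batch_sizes(
    config: dict,
    gbs: Optional[int] = None,
    mbs: Optional[int] = None,
) -> tuple[int, int]:
    """Return (global_batch_size, micro_batch_size) after applying overrides."""
    # An explicit argument wins; None means "use the configured default".
    resolved_gbs = gbs if gbs is not None else config["train_global_batch_size"]
    resolved_mbs = mbs if mbs is not None else config["train_micro_batch_size"]
    return resolved_gbs, resolved_mbs


config = {"train_global_batch_size": 64, "train_micro_batch_size": 4}  # hypothetical keys
default_sizes = resolve_batch_sizes(config)
overridden_sizes = resolve_batch_sizes(config, gbs=32)
```

Comparing against None rather than relying on truthiness keeps the two cases ("not provided" and "provided") clearly separate.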

prepare_for_training() → None

Prepare the value model for training (load to GPU).

prepare_for_inference() → None

Prepare the value model for inference (offload gradients, set eval mode).

finish_inference() → None

Offload value model to CPU after inference.

finish_training() → None

Offload value model to CPU after training.
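The four prepare/finish methods suggest a bracketed lifecycle around get_values and train. The sketch below shows one plausible ordering using a stand-in class that only records calls, so it runs without NeMo RL or a Ray cluster. `RecordingValue`, `value_step`, and the `"input_ids"` and `"values"` keys are assumptions made for illustration; the real call order in a PPO loop may differ.

```python
class RecordingValue:
    """Stand-in that only records method calls. NOT the real Value class."""

    def __init__(self):
        self.calls = []

    def prepare_for_inference(self):
        self.calls.append("prepare_for_inference")

    def get_values(self, data, timer=None):
        self.calls.append("get_values")
        # One dummy prediction per token, shaped [batch_size, sequence_length].
        return {"values": [[0.0] * len(seq) for seq in data["input_ids"]]}

    def finish_inference(self):
        self.calls.append("finish_inference")

    def prepare_for_training(self):
        self.calls.append("prepare_for_training")

    def train(self, data, loss_fn, eval_mode=False, *, gbs=None, mbs=None, timer=None):
        self.calls.append("train")
        return {"loss": 0.0}

    def finish_training(self):
        self.calls.append("finish_training")


def value_step(value, batch, loss_fn):
    """Run one inference phase followed by one training phase."""
    value.prepare_for_inference()
    predictions = value.get_values(batch)
    value.finish_inference()

    value.prepare_for_training()
    metrics = value.train(batch, loss_fn)
    value.finish_training()
    return predictions, metrics


stub = RecordingValue()
batch = {"input_ids": [[1, 2, 3], [4, 5, 6]]}
predictions, metrics = value_step(stub, batch, loss_fn=None)
```

Pairing each prepare call with its finish call keeps the model offloaded whenever it is not in use.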

save_checkpoint(
weights_path: str,
optimizer_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
) → None

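Save a checkpoint of the value model.

Parameters:
  • weights_path – Path to save model weights to

  • optimizer_path – Optional path to save optimizer state to

  • tokenizer_path – Optional path to save the tokenizer to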
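Since save_checkpoint takes plain str paths, a small helper can build them per training step. The directory layout below (`step_<n>/value/...`) and the name `checkpoint_paths` are purely illustrative assumptions, not a layout that NeMo RL prescribes.

```python
from pathlib import Path


def checkpoint_paths(root: str, step: int) -> dict:
    """Build keyword arguments for save_checkpoint under a per-step directory."""
    step_dir = Path(root) / f"step_{step}" / "value"  # hypothetical layout
    return {
        # save_checkpoint is annotated with str, so convert explicitly.
        "weights_path": str(step_dir / "weights"),
        "optimizer_path": str(step_dir / "optimizer"),
    }


paths = checkpoint_paths("results", 10)
# With a real Value instance this could be passed as:
#     value.save_checkpoint(**paths)
```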

shutdown() → bool

Shut down all value workers and clean up resources.

__del__() → None

Shut down the worker groups when the object is deleted or garbage collected.
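Python does not guarantee when (or whether) `__del__` runs, so callers may prefer to call shutdown() explicitly. The sketch below wraps that in a context manager. `managed_value` and `StubValue` are hypothetical names introduced here; the stub stands in for a real Value so the example runs on its own.

```python
from contextlib import contextmanager


class StubValue:
    """Stand-in exposing only shutdown(). NOT the real Value class."""

    def __init__(self):
        self.shut_down = False

    def shutdown(self) -> bool:
        self.shut_down = True
        return True


@contextmanager
def managed_value(value):
    """Yield the value model and always shut it down afterwards."""
    try:
        yield value
    finally:
        # Runs on normal exit and when the body raises.
        value.shutdown()


stub = StubValue()
try:
    with managed_value(stub) as value:
        raise RuntimeError("simulated failure during training")
except RuntimeError:
    pass  # the workers were still shut down by the context manager
```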