nemo_rl.models.value.lm_value
Module Contents
Classes
Value | Value function model for PPO using distributed training with Ray workers.
Data
PathLike
API
- nemo_rl.models.value.lm_value.PathLike
None
- class nemo_rl.models.value.lm_value.Value(
- cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster,
- config: nemo_rl.models.value.config.ValueConfig,
- tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
- name_prefix: str = 'lm_value',
- workers_per_node: Optional[Union[int, list[int]]] = None,
- init_optimizer: bool = True,
- weights_path: Optional[nemo_rl.models.value.lm_value.PathLike] = None,
- optimizer_path: Optional[nemo_rl.models.value.lm_value.PathLike] = None,
- )
Bases: nemo_rl.models.value.interfaces.ValueInterface
Value function model for PPO using distributed training with Ray workers.
Initialization
Initialize the Value model.
- Parameters:
cluster – Ray virtual cluster for distributed training
config – Configuration for the value model
tokenizer – Tokenizer for the model
name_prefix – Prefix for worker names
workers_per_node – Number of workers per node, given as a single int or a list of ints
init_optimizer – Whether to initialize the optimizer
weights_path – Path to load model weights from
optimizer_path – Path to load optimizer state from
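A minimal construction sketch, assuming `nemo_rl` is installed and that a cluster, config, and tokenizer have already been built elsewhere. The helper name `build_inference_value` is hypothetical; only the import path and the `Value` keyword arguments come from the signature above. Skipping the optimizer for an inference-only model is an assumption about a sensible use of `init_optimizer`, not documented behavior.

```python
from typing import Any


def build_inference_value(
    cluster: Any, config: Any, tokenizer: Any, weights_path: str
) -> Any:
    """Hypothetical helper: build a Value model without optimizer state."""
    # Imported lazily so that defining this helper does not require nemo_rl.
    from nemo_rl.models.value.lm_value import Value

    return Value(
        cluster=cluster,
        config=config,
        tokenizer=tokenizer,
        init_optimizer=False,  # assumption: no optimizer needed for inference only
        weights_path=weights_path,
    )
```

The arguments left out here (`name_prefix`, `workers_per_node`, `optimizer_path`) fall back to the defaults shown in the signature.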
- get_values(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- timer: Optional[nemo_rl.utils.timer.Timer] = None,
- )
Get value predictions for a batch of data.
- Parameters:
data – BatchedDataDict containing input sequences
timer – Optional timer for profiling
- Returns:
BatchedDataDict containing value predictions of shape [batch_size, sequence_length]
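Because predictions come back as one value per token position, shorter sequences in a batch will typically carry values at padded positions. The sketch below shows one way to blank those out. It works on plain nested lists rather than a `BatchedDataDict`, since the key under which predictions are stored is not stated here; the function name, `lengths`, and `pad_value` are all hypothetical.

```python
def mask_padded_values(values, lengths, pad_value=0.0):
    """Replace values at padded positions with pad_value.

    values:  nested list shaped [batch_size][sequence_length]
    lengths: number of valid (non-padding) tokens in each row
    """
    return [
        [v if t < n else pad_value for t, v in enumerate(row)]
        for row, n in zip(values, lengths)
    ]


# Two sequences padded to length 4; the second has only 2 valid tokens.
values = [[0.1, 0.2, 0.3, 0.9], [0.5, 0.4, 0.7, 0.8]]
lengths = [4, 2]
masked = mask_padded_values(values, lengths)
```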
- train(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
- loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
- eval_mode: bool = False,
- *,
- gbs: Optional[int] = None,
- mbs: Optional[int] = None,
- timer: Optional[nemo_rl.utils.timer.Timer] = None,
- )
Train the value function on a batch of data with a given loss function.
- Parameters:
data – BatchedDataDict containing training data
loss_fn – Loss function to use for training
eval_mode – Whether to run in evaluation mode (no gradient updates)
gbs – Global batch size override (if None, uses config default)
mbs – Micro batch size override (if None, uses config default)
timer – Optional timer for profiling
- Returns:
Dictionary containing training metrics (loss, grad_norm, etc.)
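How `gbs` and `mbs` relate is not spelled out on this page. A common convention, assumed in this sketch and not confirmed here, is that the global batch is split evenly across data-parallel ranks and each rank processes its share in micro-batches of size `mbs`. The function name and the `dp_size` parameter are hypothetical.

```python
def microbatches_per_rank(gbs: int, mbs: int, dp_size: int) -> int:
    """Number of micro-batches each data-parallel rank runs per global batch.

    Assumes an even split of the global batch across ranks; uneven
    configurations are rejected rather than silently rounded.
    """
    if gbs % dp_size != 0:
        raise ValueError("gbs must be divisible by dp_size")
    local_batch = gbs // dp_size
    if local_batch % mbs != 0:
        raise ValueError("per-rank batch must be divisible by mbs")
    return local_batch // mbs


# 64 samples over 8 ranks gives 8 per rank, processed as 2 micro-batches of 4.
steps = microbatches_per_rank(gbs=64, mbs=4, dp_size=8)
```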
- prepare_for_training() → None
Prepare the value model for training (load to GPU).
- prepare_for_inference() → None
Prepare the value model for inference (offload gradients, set eval mode).
- finish_inference() → None
Offload value model to CPU after inference.
- finish_training() → None
Offload value model to CPU after training.
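The four methods above pair up naturally: a prepare call before a phase and a finish call after it. One way to keep the pairs balanced, sketched here as an assumption about usage rather than a documented pattern, is to wrap each pair in a context manager. The helpers and the `_RecordingValue` stand-in are hypothetical; only the four method names come from this page.

```python
from contextlib import contextmanager


@contextmanager
def inference_phase(value):
    """Call prepare_for_inference on entry and finish_inference on exit."""
    value.prepare_for_inference()
    try:
        yield value
    finally:
        value.finish_inference()


@contextmanager
def training_phase(value):
    """Call prepare_for_training on entry and finish_training on exit."""
    value.prepare_for_training()
    try:
        yield value
    finally:
        value.finish_training()


class _RecordingValue:
    """Hypothetical stand-in that only records which methods were called."""

    def __init__(self):
        self.calls = []

    def prepare_for_inference(self):
        self.calls.append("prepare_for_inference")

    def finish_inference(self):
        self.calls.append("finish_inference")

    def prepare_for_training(self):
        self.calls.append("prepare_for_training")

    def finish_training(self):
        self.calls.append("finish_training")


stub = _RecordingValue()
with inference_phase(stub):
    stub.calls.append("get_values")  # a real model would call get_values(...) here
with training_phase(stub):
    stub.calls.append("train")  # a real model would call train(...) here
```

The `finally` clauses ensure the finish call runs even if the body of the `with` block raises.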
- save_checkpoint(
- weights_path: str,
- optimizer_path: Optional[str] = None,
- tokenizer_path: Optional[str] = None,
- )
Save a checkpoint of the value model.
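A small sketch of assembling the three path arguments for one training step. The directory layout (`step_<n>/value/...`) and the helper name are hypothetical; this page does not prescribe where checkpoints should live.

```python
from pathlib import Path


def checkpoint_paths(root: str, step: int) -> dict[str, str]:
    """Build keyword arguments for save_checkpoint under a per-step directory."""
    step_dir = Path(root) / f"step_{step}" / "value"  # hypothetical layout
    return {
        "weights_path": (step_dir / "weights").as_posix(),
        "optimizer_path": (step_dir / "optimizer").as_posix(),
        "tokenizer_path": (step_dir / "tokenizer").as_posix(),
    }


paths = checkpoint_paths("results", 10)
# With a constructed model: value.save_checkpoint(**paths)
```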
- shutdown() → bool
Shut down all value workers and clean up resources.
- __del__() → None
Shuts down the worker groups when the object is deleted or garbage collected.
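Python does not guarantee when, or at interpreter exit whether, `__del__` runs, so calling `shutdown()` explicitly is the safer choice. A minimal sketch of that pattern follows; `run_then_shutdown` and `work` are hypothetical names, and the meaning of the returned bool is passed through without interpretation because this page does not define it.

```python
def run_then_shutdown(value, work):
    """Run work(value), then always call value.shutdown().

    Returns (result of work, return value of shutdown). If work raises,
    shutdown still runs and the exception propagates to the caller.
    """
    try:
        result = work(value)
    finally:
        shut_down = value.shutdown()
    return result, shut_down
```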