nemo_rl.models.value.interfaces
Module Contents
Classes
ValueOutputSpec | values: Tensor of value predictions [batch_size, sequence_length].
ValueInterface | Abstract base class defining the interface for value functions.
API
- class nemo_rl.models.value.interfaces.ValueOutputSpec
Bases: typing.TypedDict
values: Tensor of value predictions [batch_size, sequence_length].
- values: torch.Tensor
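Because ValueOutputSpec is a TypedDict, an instance is an ordinary dict at runtime, and the shape contract on values is not enforced by the type itself. The sketch below illustrates this with a local stand-in: `ValueOutputSketch` and `shape_of` are hypothetical names, and nested lists stand in for the documented torch.Tensor so that the example has no dependencies.

```python
from typing import TypedDict


class ValueOutputSketch(TypedDict):
    # Stand-in for the documented field. The real field is a torch.Tensor of
    # shape [batch_size, sequence_length]; nested lists are used here only to
    # keep the sketch dependency-free.
    values: list[list[float]]


def shape_of(out: ValueOutputSketch) -> tuple[int, int]:
    """Return (batch_size, sequence_length), rejecting ragged rows."""
    rows = out["values"]
    seq_len = len(rows[0]) if rows else 0
    if any(len(row) != seq_len for row in rows):
        raise ValueError("all rows must have the same sequence length")
    return len(rows), seq_len


# A TypedDict instance is built like a plain dict and indexed by key.
out: ValueOutputSketch = {"values": [[0.1, 0.2, 0.3], [0.0, -0.4, 0.5]]}
batch_size, sequence_length = shape_of(out)
```

With a real tensor, the equivalent check would read the tensor's shape attribute instead of counting list elements.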
- class nemo_rl.models.value.interfaces.ValueInterface
Bases: abc.ABC
Abstract base class defining the interface for value functions.
- abstractmethod get_values(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- timer: Optional[nemo_rl.utils.timer.Timer] = None,
)
Get value predictions for observations.
- Parameters:
data – BatchedDataDict containing input sequences (tokens)
timer – Optional timer for profiling
- Returns:
BatchedDataDict containing values: Tensor of value predictions [batch_size, sequence_length]
- abstractmethod train(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
- loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
- eval_mode: bool = False,
- *,
- gbs: Optional[int] = None,
- mbs: Optional[int] = None,
- timer: Optional[nemo_rl.utils.timer.Timer] = None,
)
Train the value function on a global batch of data.
- Parameters:
data – BatchedDataDict containing training data
loss_fn – Loss function to use for training
eval_mode – Whether to run in evaluation mode (no gradient updates)
gbs – Global batch size override (if None, uses config default)
mbs – Micro batch size override (if None, uses config default)
timer – Optional timer for profiling
- Returns:
Dictionary containing training metrics (loss, grad_norm, etc.)
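The gbs and mbs parameters are documented as overrides that fall back to a config default when None. One way an implementation might resolve them is sketched below. The helper name `resolve_batch_sizes`, its `default_gbs` and `default_mbs` arguments, and the divisibility check are assumptions made for illustration; none of them is part of the documented interface.

```python
from typing import Optional


def resolve_batch_sizes(
    gbs: Optional[int],
    mbs: Optional[int],
    *,
    default_gbs: int,
    default_mbs: int,
) -> tuple[int, int]:
    """Apply the 'None means use the config default' rule for gbs and mbs."""
    resolved_gbs = default_gbs if gbs is None else gbs
    resolved_mbs = default_mbs if mbs is None else mbs
    # Extra sanity check added by this sketch (an assumption, not documented):
    # a global batch is expected to split evenly into micro batches.
    if resolved_gbs % resolved_mbs != 0:
        raise ValueError(
            f"gbs={resolved_gbs} is not a multiple of mbs={resolved_mbs}"
        )
    return resolved_gbs, resolved_mbs


# Override only the global batch size; the micro batch size uses its default.
gbs, mbs = resolve_batch_sizes(64, None, default_gbs=32, default_mbs=4)
```

Comparing against None rather than relying on truthiness keeps the fallback from triggering on any value other than an omitted argument.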
- abstractmethod prepare_for_training(*args: Any, **kwargs: Any) → None
Prepare the value model for training (e.g., load to GPU).
- abstractmethod prepare_for_inference(*args: Any, **kwargs: Any) → None
Prepare the value model for inference (e.g., offload gradients).
- abstractmethod finish_training(*args: Any, **kwargs: Any) → None
Clean up after training.
- abstractmethod save_checkpoint(*args: Any, **kwargs: Any) → None
Save model checkpoint.
- abstractmethod shutdown() → bool
Shutdown workers and clean up resources.
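To show what implementing this interface involves, here is a minimal sketch of a toy subclass. It does not import nemo_rl: `ValueInterfaceSketch` is a local stand-in that mirrors the documented method names, plain dicts replace BatchedDataDict, nested lists replace tensors, and `loss_fn` is treated as a plain callable. The `input_ids` key and the call order in the usage lines are assumptions as well.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional


class ValueInterfaceSketch(ABC):
    """Local stand-in mirroring the documented methods (not the real class)."""

    @abstractmethod
    def get_values(self, data: dict, timer: Optional[Any] = None) -> dict: ...

    @abstractmethod
    def train(self, data: dict, loss_fn: Any, eval_mode: bool = False, *,
              gbs: Optional[int] = None, mbs: Optional[int] = None,
              timer: Optional[Any] = None) -> dict: ...

    @abstractmethod
    def prepare_for_training(self, *args: Any, **kwargs: Any) -> None: ...

    @abstractmethod
    def prepare_for_inference(self, *args: Any, **kwargs: Any) -> None: ...

    @abstractmethod
    def finish_training(self, *args: Any, **kwargs: Any) -> None: ...

    @abstractmethod
    def save_checkpoint(self, *args: Any, **kwargs: Any) -> None: ...

    @abstractmethod
    def shutdown(self) -> bool: ...


class ConstantValue(ValueInterfaceSketch):
    """Toy value function: predicts 0.0 for every token and never learns."""

    def __init__(self) -> None:
        self.calls: list[str] = []  # records lifecycle calls for inspection

    def get_values(self, data, timer=None):
        self.calls.append("get_values")
        # One prediction per token: shape [batch_size, sequence_length].
        return {"values": [[0.0 for _ in seq] for seq in data["input_ids"]]}

    def train(self, data, loss_fn, eval_mode=False, *,
              gbs=None, mbs=None, timer=None):
        self.calls.append("train")
        # A real implementation would skip gradient updates when eval_mode
        # is True; this toy model has no parameters, so it only reports loss.
        return {"loss": loss_fn(data)}

    def prepare_for_training(self, *args, **kwargs):
        self.calls.append("prepare_for_training")

    def prepare_for_inference(self, *args, **kwargs):
        self.calls.append("prepare_for_inference")

    def finish_training(self, *args, **kwargs):
        self.calls.append("finish_training")

    def save_checkpoint(self, *args, **kwargs):
        self.calls.append("save_checkpoint")

    def shutdown(self) -> bool:
        self.calls.append("shutdown")
        return True


# One plausible call order (an assumption, not prescribed by the interface).
model = ConstantValue()
batch = {"input_ids": [[1, 2, 3], [4, 5, 6]]}
model.prepare_for_inference()
predictions = model.get_values(batch)
model.prepare_for_training()
metrics = model.train(batch, loss_fn=lambda d: 0.0)
model.finish_training()
model.save_checkpoint()
stopped = model.shutdown()
```

Because every method is declared abstract, a subclass that omits any one of them cannot be instantiated, so gaps in an implementation surface at construction time rather than mid-run.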