nemo_rl.models.value.interfaces

Module Contents

Classes

  • ValueOutputSpec – values: Tensor of value predictions [batch_size, sequence_length].

  • ValueInterface – Abstract base class defining the interface for value functions.

API

class nemo_rl.models.value.interfaces.ValueOutputSpec

Bases: typing.TypedDict

values: Tensor of value predictions [batch_size, sequence_length].


values: torch.Tensor
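As a rough illustration of the shape contract, the sketch below mirrors ValueOutputSpec with a standard-library TypedDict. The names ValueOutputSpecSketch, Matrix, and shape_of are hypothetical, and nested Python lists stand in for torch.Tensor so that the snippet runs without PyTorch installed.

```python
from typing import TypedDict

# Assumption: nested lists stand in for torch.Tensor in this sketch.
Matrix = list[list[float]]


class ValueOutputSpecSketch(TypedDict):
    """Hypothetical mirror of ValueOutputSpec, not the library class."""

    values: Matrix  # [batch_size, sequence_length]


def shape_of(matrix: Matrix) -> tuple[int, int]:
    """Return (batch_size, sequence_length), rejecting ragged rows."""
    lengths = {len(row) for row in matrix}
    if len(lengths) > 1:
        raise ValueError("all sequences must share one length")
    return len(matrix), (lengths.pop() if lengths else 0)


# Two sequences of three tokens each, one value prediction per token.
output: ValueOutputSpecSketch = {"values": [[0.1, 0.2, 0.3], [0.0, -0.5, 1.0]]}
batch_size, sequence_length = shape_of(output["values"])
```

A TypedDict is an ordinary dict at runtime, so the "values" key is checked by static type checkers rather than at call time.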

class nemo_rl.models.value.interfaces.ValueInterface

Bases: abc.ABC

Abstract base class defining the interface for value functions.

abstractmethod get_values(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
timer: Optional[nemo_rl.utils.timer.Timer] = None,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.value.interfaces.ValueOutputSpec]

Get value predictions for observations.

Parameters:
  • data – BatchedDataDict containing input sequences (tokens)

  • timer – Optional timer for profiling

Returns:

BatchedDataDict containing:

  • values: Tensor of value predictions [batch_size, sequence_length]
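The input-to-output mapping of get_values can be sketched with a toy, duck-typed stand-in. Everything here is hypothetical: the class TokenScaleValueFn, the use of plain dicts in place of BatchedDataDict, nested lists in place of tensors, and the assumption that token ids live under an "input_ids" key (the source only says the batch contains input sequences).

```python
from typing import Any, Optional


class TokenScaleValueFn:
    """Hypothetical toy value function: value = scale * token id.

    Assumption: token ids are stored under "input_ids"; a plain dict
    stands in for BatchedDataDict and nested lists stand in for tensors.
    """

    def __init__(self, scale: float) -> None:
        self.scale = scale

    def get_values(
        self, data: dict[str, Any], timer: Optional[Any] = None
    ) -> dict[str, Any]:
        # timer is accepted for signature parity and ignored in this sketch.
        # One prediction per token keeps the [batch_size, sequence_length] shape.
        values = [[self.scale * token for token in seq] for seq in data["input_ids"]]
        return {"values": values}


batch = {"input_ids": [[2, 4, 6], [8, 10, 12]]}
result = TokenScaleValueFn(scale=0.5).get_values(batch)
```

The returned dict has one row per input sequence and one value per token, matching the documented output shape.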

abstractmethod train(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
eval_mode: bool = False,
*,
gbs: Optional[int] = None,
mbs: Optional[int] = None,
timer: Optional[nemo_rl.utils.timer.Timer] = None,
) → dict[str, Any]

Train the value function on a global batch of data.

Parameters:
  • data – BatchedDataDict containing training data

  • loss_fn – Loss function to use for training

  • eval_mode – Whether to run in evaluation mode (no gradient updates)

  • gbs – Global batch size override (if None, uses config default)

  • mbs – Micro batch size override (if None, uses config default)

  • timer – Optional timer for profiling

Returns:

Dictionary containing training metrics (loss, grad_norm, etc.)
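The source does not spell out how gbs and mbs are consumed. One plausible reading, sketched below, is that a global batch of gbs samples is processed in microbatches of mbs samples, with None falling back to a configured default. The names train_sketch, default_gbs, default_mbs, and the num_microbatches and updated metrics are hypothetical; this is an illustration under those assumptions, not the library's implementation.

```python
from typing import Any, Callable, Optional, Sequence


def train_sketch(
    samples: Sequence[float],
    loss_fn: Callable[[Sequence[float]], float],
    eval_mode: bool = False,
    *,
    gbs: Optional[int] = None,
    mbs: Optional[int] = None,
    default_gbs: int = 4,  # hypothetical config default
    default_mbs: int = 2,  # hypothetical config default
) -> dict[str, Any]:
    # None means "use the config default", as the parameter docs state.
    gbs = default_gbs if gbs is None else gbs
    mbs = default_mbs if mbs is None else mbs

    global_batch = samples[:gbs]
    if not global_batch:
        raise ValueError("global batch is empty")

    total_loss = 0.0
    num_microbatches = 0
    for start in range(0, len(global_batch), mbs):
        micro = global_batch[start : start + mbs]
        # Weight by microbatch size so the result is the mean over the global batch.
        total_loss += loss_fn(micro) * len(micro)
        num_microbatches += 1

    return {
        "loss": total_loss / len(global_batch),
        "num_microbatches": num_microbatches,
        # A real implementation would skip the optimizer step in eval mode.
        "updated": not eval_mode,
    }


def mean_square(micro: Sequence[float]) -> float:
    """Toy loss: mean of squared samples."""
    return sum(x * x for x in micro) / len(micro)


samples = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
metrics = train_sketch(samples, mean_square)
eval_metrics = train_sketch(samples, mean_square, eval_mode=True, mbs=4)
```

Because each microbatch loss is weighted by its size, changing mbs alters the number of microbatches but not the reported mean loss over the global batch.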

abstractmethod prepare_for_training(*args: Any, **kwargs: Any) → None

Prepare the value model for training (e.g., load to GPU).

abstractmethod prepare_for_inference(*args: Any, **kwargs: Any) → None

Prepare the value model for inference (e.g., offload gradients).

abstractmethod finish_training(*args: Any, **kwargs: Any) → None

Clean up after training.

abstractmethod save_checkpoint(*args: Any, **kwargs: Any) → None

Save model checkpoint.

abstractmethod shutdown() → bool

Shutdown workers and clean up resources.
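Because every method on the interface is abstract, a subclass must implement all of them before it can be instantiated. The sketch below shows this with a trimmed, hypothetical mirror of the lifecycle methods (ValueLifecycleSketch, RecordingValueFn, and Incomplete are illustration-only names). The call order at the end is one plausible sequence, not one prescribed by the source, and the checkpoint path is a placeholder that is never written to.

```python
from abc import ABC, abstractmethod
from typing import Any


class ValueLifecycleSketch(ABC):
    """Hypothetical, trimmed mirror of the lifecycle methods listed above."""

    @abstractmethod
    def prepare_for_training(self, *args: Any, **kwargs: Any) -> None: ...

    @abstractmethod
    def prepare_for_inference(self, *args: Any, **kwargs: Any) -> None: ...

    @abstractmethod
    def finish_training(self, *args: Any, **kwargs: Any) -> None: ...

    @abstractmethod
    def save_checkpoint(self, *args: Any, **kwargs: Any) -> None: ...

    @abstractmethod
    def shutdown(self) -> bool: ...


class RecordingValueFn(ValueLifecycleSketch):
    """Implements every hook by recording its name instead of doing real work."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def prepare_for_training(self, *args: Any, **kwargs: Any) -> None:
        self.calls.append("prepare_for_training")

    def prepare_for_inference(self, *args: Any, **kwargs: Any) -> None:
        self.calls.append("prepare_for_inference")

    def finish_training(self, *args: Any, **kwargs: Any) -> None:
        self.calls.append("finish_training")

    def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
        self.calls.append("save_checkpoint")

    def shutdown(self) -> bool:
        self.calls.append("shutdown")
        return True


class Incomplete(ValueLifecycleSketch):
    """Overrides only one hook, so it remains abstract."""

    def shutdown(self) -> bool:
        return True


try:
    Incomplete()
    instantiation_failed = False
except TypeError:
    # abc refuses to instantiate a class with unimplemented abstract methods.
    instantiation_failed = True

fn = RecordingValueFn()
fn.prepare_for_training()
fn.finish_training()
fn.save_checkpoint("hypothetical/checkpoint/path")  # placeholder, nothing is written
fn.prepare_for_inference()
shut_down_ok = fn.shutdown()
```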