nemo_rl.models.interfaces
Module Contents
Classes
PolicyInterface: Abstract base class defining the interface for RL policies.
API
- class nemo_rl.models.interfaces.PolicyInterface
Bases: abc.ABC
Abstract base class defining the interface for RL policies.
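Because PolicyInterface derives from abc.ABC, a subclass cannot be instantiated until it overrides every abstract method. The sketch below illustrates that rule with a stand-in class, not with nemo_rl itself: plain dicts replace BatchedDataDict, the "tokens" key is hypothetical, and only the three method names mirror the documented interface.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable


# Stand-in for PolicyInterface (assumption): dicts replace BatchedDataDict.
class PolicySketch(ABC):
    @abstractmethod
    def get_logprobs(self, data: dict) -> dict: ...

    @abstractmethod
    def get_reference_policy_logprobs(self, data: dict) -> dict: ...

    @abstractmethod
    def train(self, data: dict, loss_fn: Callable[..., Any]) -> dict: ...


class IncompletePolicy(PolicySketch):
    # Overrides only one of the three abstract methods, so it stays abstract.
    def get_logprobs(self, data: dict) -> dict:
        return {"logprobs": []}


class ConstantPolicy(PolicySketch):
    # Hypothetical toy policy: every token gets log-probability 0.0.
    def get_logprobs(self, data: dict) -> dict:
        return {"logprobs": [[0.0] * len(seq) for seq in data["tokens"]]}

    def get_reference_policy_logprobs(self, data: dict) -> dict:
        # The toy reference policy is identical to the toy policy.
        return self.get_logprobs(data)

    def train(self, data: dict, loss_fn: Callable[..., Any]) -> dict:
        # No parameters to update; just report the loss for illustration.
        return {"loss": loss_fn(self.get_logprobs(data))}


try:
    IncompletePolicy()
except TypeError as err:
    print("cannot instantiate:", err)
```

ConstantPolicy() instantiates normally because all three methods are overridden.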
- abstractmethod get_logprobs(data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec])
Get logprobs of actions from observations.
- Parameters:
data – BatchedDataDict containing rollouts (tokens)
- Returns:
BatchedDataDict containing logprobs: Tensor of logprobs of actions
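The sketch below shows one common way per-token log-probabilities like these are derived from a causal language model's logits: a numerically stable log-softmax, then a lookup of the token that was actually produced. It is an assumption-laden illustration, not nemo_rl's implementation: plain lists stand in for tensors, and the alignment (logits at step t score token t + 1) is the usual causal-LM convention, not necessarily the exact layout get_logprobs returns.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def token_logprobs(logits_per_step: list[list[float]], tokens: list[int]) -> list[float]:
    """Log-probability of tokens[t + 1] under the distribution predicted at step t.

    logits_per_step[t] is the vocabulary-sized logit vector produced after
    reading tokens[: t + 1]. The first token has no prediction, so the result
    has len(tokens) - 1 entries.
    """
    return [
        log_softmax(logits_per_step[t])[tokens[t + 1]]
        for t in range(len(tokens) - 1)
    ]
```

With uniform logits over a vocabulary of 4, every returned value equals -log(4).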
- abstractmethod get_reference_policy_logprobs(data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec])
Get logprobs of actions from observations under the reference policy.
- Parameters:
data – BatchedDataDict containing rollouts (tokens)
- Returns:
BatchedDataDict containing logprobs: Tensor of logprobs of actions
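Reference-policy logprobs are typically compared against the current policy's logprobs on the same sampled tokens, for example to penalize drift with a KL term. Assuming that use, the sketch below computes a per-token estimate of KL(policy || reference) with the non-negative "k3" estimator, exp(r) - r - 1 where r = ref_logprob - logprob. This is one common choice, swapped in for illustration; the function name is hypothetical and nothing here is taken from nemo_rl's loss code.

```python
import math


def per_token_kl_k3(logprobs: list[float], ref_logprobs: list[float]) -> list[float]:
    """Per-token estimate of KL(policy || reference) for tokens sampled from the policy.

    For each token, r = ref_logprob - logprob and the estimate is
    exp(r) - r - 1, which is always >= 0 and is zero when both agree.
    """
    if len(logprobs) != len(ref_logprobs):
        raise ValueError("logprobs and ref_logprobs must have the same length")
    out = []
    for lp, ref_lp in zip(logprobs, ref_logprobs):
        r = ref_lp - lp
        out.append(math.exp(r) - r - 1.0)
    return out
```

Averaging these values over sampled tokens gives an unbiased estimate of the KL divergence from the reference policy to the current policy.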
- abstractmethod train(data: nemo_rl.distributed.batched_data_dict.BatchedDataDict, loss_fn: nemo_rl.algorithms.interfaces.LossFunction)
Train the policy on a global batch of data.
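As a minimal illustration of the train(data, loss_fn) contract, the toy below takes one gradient step per global batch and returns a metrics dict. Everything here is hypothetical: ToyPolicy, the "targets" key, and squared_error are stand-ins, and the real nemo_rl.algorithms.interfaces.LossFunction protocol has a different signature. The sketch also assumes one optimizer step per call, which the documentation above does not state.

```python
from typing import Callable

# Hypothetical loss signature for this sketch: (theta, batch) -> (loss, dloss/dtheta).
LossFn = Callable[[float, dict], tuple[float, float]]


class ToyPolicy:
    """Hypothetical one-parameter policy used only to illustrate train()."""

    def __init__(self, theta: float = 0.0, lr: float = 0.1):
        self.theta = theta
        self.lr = lr

    def train(self, data: dict, loss_fn: LossFn) -> dict:
        # Evaluate the loss and its gradient on the whole global batch,
        # then take a single gradient-descent step.
        loss, grad = loss_fn(self.theta, data)
        self.theta -= self.lr * grad
        return {"loss": loss}


def squared_error(theta: float, data: dict) -> tuple[float, float]:
    # Mean of (theta - y)^2 over the batch, and its derivative w.r.t. theta.
    targets = data["targets"]
    n = len(targets)
    loss = sum((theta - y) ** 2 for y in targets) / n
    grad = sum(2 * (theta - y) for y in targets) / n
    return loss, grad


policy = ToyPolicy()
for _ in range(50):
    metrics = policy.train({"targets": [1.0, 3.0]}, squared_error)
```

After the loop, theta has converged close to 2.0, the mean of the targets.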
- Parameters:
data – BatchedDataDict containing rollouts (tokens)