nemo_rl.models.interfaces#

Module Contents#

Classes#

PolicyInterface

Abstract base class defining the interface for RL policies.

API#

class nemo_rl.models.interfaces.PolicyInterface[source]#

Bases: abc.ABC

Abstract base class defining the interface for RL policies.

abstractmethod get_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[source]#

Get logprobs of actions from observations.

Parameters:

data – BatchedDataDict containing rollouts (tokens)

Returns:

BatchedDataDict containing:

  • logprobs: Tensor of logprobs of actions

Return type:

nemo_rl.distributed.batched_data_dict.BatchedDataDict

abstractmethod get_reference_policy_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[source]#

Get logprobs of actions from observations, computed with the reference policy.

Parameters:

data – BatchedDataDict containing rollouts (tokens)

Returns:

BatchedDataDict containing:

  • logprobs: Tensor of logprobs of actions

Return type:

nemo_rl.distributed.batched_data_dict.BatchedDataDict

abstractmethod train(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
) → Dict[str, Any][source]#

Train the policy on a global batch of data.

Parameters:

data – BatchedDataDict containing rollouts (tokens)

loss_fn – LossFunction to use for training

abstractmethod prepare_for_training(*args, **kwargs)[source]#
abstractmethod finish_training(*args, **kwargs)[source]#