nemo_rl.models.policy.interfaces#
Module Contents#
Classes#
| Class | Description |
| --- | --- |
| `LogprobOutputSpec` | logprobs: Tensor of log probabilities. |
| `ReferenceLogprobOutputSpec` | reference_logprobs: Tensor of log probabilities. |
| `PolicyInterface` | Abstract base class defining the interface for RL policies. |
API#
- class nemo_rl.models.policy.interfaces.LogprobOutputSpec[source]#
Bases: `typing.TypedDict`
logprobs: Tensor of log probabilities.
- logprobs: torch.Tensor#
- class nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec[source]#
Bases: `typing.TypedDict`
reference_logprobs: Tensor of log probabilities.
- reference_logprobs: torch.Tensor#
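Both output specs are `typing.TypedDict` subclasses, so at runtime they are plain dicts; the class only carries type information for static checkers. A standalone sketch of how they behave (not importing `nemo_rl`, and using `list[float]` where the real specs hold a `torch.Tensor`):

```python
from typing import TypedDict

# Stand-ins for the real specs, whose fields are typed as torch.Tensor.
class LogprobOutputSpec(TypedDict):
    logprobs: list[float]

class ReferenceLogprobOutputSpec(TypedDict):
    reference_logprobs: list[float]

# Constructing a TypedDict produces an ordinary dict; the class annotation
# only tells type checkers which keys and value types are expected.
out: LogprobOutputSpec = {"logprobs": [-0.11, -2.3, -0.05]}
ref: ReferenceLogprobOutputSpec = {"reference_logprobs": [-0.12, -2.1, -0.07]}

print(type(out) is dict)   # True: no special runtime type
print(out["logprobs"][0])  # -0.11
```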
- class nemo_rl.models.policy.interfaces.PolicyInterface[source]#
Bases: `abc.ABC`
Abstract base class defining the interface for RL policies.
- abstractmethod get_logprobs(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- ) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec]#
Get logprobs of actions from observations.
- Parameters:
data – BatchedDataDict containing rollouts (tokens)
- Returns:
logprobs: Tensor of logprobs of actions
- Return type:
BatchedDataDict containing a `logprobs` tensor
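The "logprobs of actions" returned here are per-token log-probabilities of the sampled tokens under the policy. A `nemo_rl`-free sketch of that computation with a numerically stable log-softmax (plain lists stand in for the tensor batches the real implementation operates on):

```python
import math

def token_logprobs(logits: list[list[float]], token_ids: list[int]) -> list[float]:
    """Per-position log-probabilities: log_softmax(logits)[t][token_ids[t]]."""
    out = []
    for row, tok in zip(logits, token_ids):
        m = max(row)  # subtract the max before exponentiating for stability
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        out.append(row[tok] - log_z)
    return out

# Two positions over a 3-token vocabulary; the actions taken were tokens 2 and 0.
lps = token_logprobs([[2.0, 0.5, 1.0], [0.1, 0.2, 0.3]], [2, 0])
print(all(lp < 0 for lp in lps))  # log-probs of a proper distribution are negative
```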
- abstractmethod get_reference_policy_logprobs(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- ) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec]#
Get logprobs of actions from observations under the reference policy.
- Parameters:
data – BatchedDataDict containing rollouts (tokens)
- Returns:
reference_logprobs: Tensor of logprobs of actions under the reference policy
- Return type:
BatchedDataDict containing a `reference_logprobs` tensor
- abstractmethod train(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
- loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
- )#
Train the policy on a global batch of data.
- Parameters:
  - data – BatchedDataDict containing rollouts (tokens)
  - loss_fn – LossFunction used to compute the training loss on the batch
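A minimal concrete policy satisfying this abstract interface can clarify the contract. Since `nemo_rl` is not importable here, the sketch below re-declares a stripped-down `PolicyInterface` with the same three abstract methods and uses dicts of lists in place of `BatchedDataDict` and `torch.Tensor` (the names mirror the API above; everything else is a stand-in):

```python
import math
from abc import ABC, abstractmethod

class PolicyInterface(ABC):
    """Stripped-down mirror of nemo_rl.models.policy.interfaces.PolicyInterface."""

    @abstractmethod
    def get_logprobs(self, data: dict) -> dict: ...

    @abstractmethod
    def get_reference_policy_logprobs(self, data: dict) -> dict: ...

    @abstractmethod
    def train(self, data: dict, loss_fn) -> None: ...

class UniformPolicy(PolicyInterface):
    """Toy policy assigning uniform probability over a fixed vocabulary."""

    def __init__(self, vocab_size: int):
        self.logprob = -math.log(vocab_size)  # log(1 / vocab_size)

    def get_logprobs(self, data: dict) -> dict:
        # One logprob per token in each rollout sequence.
        return {"logprobs": [[self.logprob] * len(seq) for seq in data["tokens"]]}

    def get_reference_policy_logprobs(self, data: dict) -> dict:
        # A frozen policy would normally compute these; here they coincide.
        return {"reference_logprobs": self.get_logprobs(data)["logprobs"]}

    def train(self, data: dict, loss_fn) -> None:
        loss_fn(self.get_logprobs(data), data)  # a real policy would backprop here

policy = UniformPolicy(vocab_size=4)
batch = {"tokens": [[1, 2, 3], [0, 1]]}
print(policy.get_logprobs(batch)["logprobs"][0][0])  # log(1/4) ≈ -1.386
```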
- class nemo_rl.models.policy.interfaces.ColocatablePolicyInterface[source]#