nemo_rl.models.policy.interfaces#

Module Contents#

Classes#

LogprobOutputSpec

logprobs: Tensor of log probabilities.

ReferenceLogprobOutputSpec

reference_logprobs: Tensor of log probabilities.

PolicyInterface

Abstract base class defining the interface for RL policies.

ColocatablePolicyInterface

API#

class nemo_rl.models.policy.interfaces.LogprobOutputSpec[source]#

Bases: typing.TypedDict

logprobs: Tensor of log probabilities.

Initialization

Initialize self. See help(type(self)) for accurate signature.

logprobs: torch.Tensor#

None

class nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec[source]#

Bases: typing.TypedDict

reference_logprobs: Tensor of log probabilities.

Initialization

Initialize self. See help(type(self)) for accurate signature.

reference_logprobs: torch.Tensor#

None

class nemo_rl.models.policy.interfaces.PolicyInterface[source]#

Bases: abc.ABC

Abstract base class defining the interface for RL policies.

abstractmethod get_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
) nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec][source]#

Get logprobs of actions from observations.

Parameters:

data – BatchedDataDict containing rollouts (tokens)

Returns:

  • logprobs: Tensor of logprobs of actions

Return type:

BatchedDataDict containing the logprobs tensor

abstractmethod get_reference_policy_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
) nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec][source]#

Get logprobs of actions from observations under the reference policy.

Parameters:

data – BatchedDataDict containing rollouts (tokens)

Returns:

  • reference_logprobs: Tensor of logprobs of actions

Return type:

BatchedDataDict containing the reference_logprobs tensor

abstractmethod train(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
) dict[str, Any][source]#

Train the policy on a global batch of data.

Parameters:

data – BatchedDataDict containing rollouts (tokens)

abstractmethod prepare_for_training(*args: Any, **kwargs: Any) None[source]#
abstractmethod finish_training(*args: Any, **kwargs: Any) None[source]#
abstractmethod save_checkpoint(*args: Any, **kwargs: Any) None[source]#
abstractmethod shutdown() bool[source]#
class nemo_rl.models.policy.interfaces.ColocatablePolicyInterface[source]#

Bases: nemo_rl.models.policy.interfaces.PolicyInterface

abstractmethod init_collective(
ip: str,
port: int,
world_size: int,
) list[ray.ObjectRef][source]#
abstractmethod offload_before_refit() None[source]#
abstractmethod offload_after_refit() None[source]#
abstractmethod prepare_refit_info() Optional[dict[str, Any]][source]#
abstractmethod prepare_weights_for_ipc(
*args: Any,
**kwargs: Any,
) list[list[str]][source]#
abstractmethod get_weights_ipc_handles(keys: list[str]) dict[str, Any][source]#
abstractmethod broadcast_weights_for_collective() list[ray.ObjectRef][source]#