nemo_rl.models.policy.interfaces#

Module Contents#

Classes#

LogprobOutputSpec

logprobs: Tensor of log probabilities.

ReferenceLogprobOutputSpec

reference_logprobs: Tensor of log probabilities from the reference policy.

PolicyInterface

Abstract base class defining the interface for RL policies.

ColocatablePolicyInterface

API#

class nemo_rl.models.policy.interfaces.LogprobOutputSpec#

Bases: typing.TypedDict

logprobs: Tensor of log probabilities.

Initialization

Initialize self. See help(type(self)) for accurate signature.

logprobs: torch.Tensor#

None

class nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec#

Bases: typing.TypedDict

reference_logprobs: Tensor of log probabilities from the reference policy.

Initialization

Initialize self. See help(type(self)) for accurate signature.

reference_logprobs: torch.Tensor#

None
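
A minimal sketch of the two output shapes follows; the [batch, seq_len] dimensions and zero values are assumptions for illustration only and are not part of the interface.

```python
import torch

from nemo_rl.models.policy.interfaces import (
    LogprobOutputSpec,
    ReferenceLogprobOutputSpec,
)

# Hypothetical batch of 2 sequences with 4 tokens each.
policy_out: LogprobOutputSpec = {
    "logprobs": torch.zeros(2, 4),  # per-token log probabilities from the current policy
}
reference_out: ReferenceLogprobOutputSpec = {
    "reference_logprobs": torch.zeros(2, 4),  # per-token log probabilities from the reference policy
}
```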

class nemo_rl.models.policy.interfaces.PolicyInterface#

Bases: abc.ABC

Abstract base class defining the interface for RL policies.

abstractmethod get_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec]#

Get logprobs of actions from observations.

Parameters:

data – BatchedDataDict containing rollouts (tokens)

Returns:

  • logprobs: Tensor of logprobs of actions

Return type:

BatchedDataDict containing the key above
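
As a hedged usage sketch: `policy` below stands for any concrete PolicyInterface implementation and `rollout_batch` for a BatchedDataDict[GenerationDatumSpec] produced elsewhere; both names are placeholders, and the dict-style access assumes BatchedDataDict supports mapping-style lookup.

```python
# `policy` and `rollout_batch` are placeholders (see note above).
logprob_batch = policy.get_logprobs(rollout_batch)

# The returned BatchedDataDict follows LogprobOutputSpec.
token_logprobs = logprob_batch["logprobs"]  # torch.Tensor of per-token log probabilities
```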

abstractmethod get_reference_policy_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec]#

Get logprobs of actions from observations under the reference policy.

Parameters:

data – BatchedDataDict containing rollouts (tokens)

Returns:

  • reference_logprobs: Tensor of logprobs of actions under the reference policy

Return type:

BatchedDataDict containing the key above
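
Reference-policy logprobs are typically consumed together with the current policy's logprobs, for example to form a per-token KL estimate. The sketch below is illustrative only; the interface does not prescribe any particular estimator, and the placeholder names follow the previous example.

```python
# `policy` and `rollout_batch` are placeholders as before.
logprobs = policy.get_logprobs(rollout_batch)["logprobs"]
ref_logprobs = policy.get_reference_policy_logprobs(rollout_batch)["reference_logprobs"]

# A simple per-token KL estimate (illustrative; concrete algorithms may use a
# different estimator or aggregation).
approx_kl = logprobs - ref_logprobs
print(approx_kl.mean())
```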

abstractmethod train(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
eval_mode: bool = False,
gbs: Optional[int] = None,
mbs: Optional[int] = None,
) → dict[str, Any]#

Train the policy on a global batch of data.

Parameters:
  • data – BatchedDataDict containing rollouts (tokens)

  • loss_fn – Loss function to use for training

  • eval_mode – Whether to run in evaluation mode (no gradient updates)

  • gbs – Global batch size override (if None, uses config default)

  • mbs – Micro batch size override (if None, uses config default)
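
A sketch of how train might be called inside an update step; `policy`, `rollout_batch`, and `loss_fn` are placeholders for a concrete implementation, a BatchedDataDict of rollouts, and a nemo_rl.algorithms.interfaces.LossFunction, and the contents of the returned dict depend on the implementation.

```python
# Placeholder objects; see note above.
metrics = policy.train(
    rollout_batch,
    loss_fn,
    eval_mode=False,  # True would run a validation pass with no gradient updates
    gbs=None,         # None falls back to the configured global batch size
    mbs=None,         # None falls back to the configured micro batch size
)
print(metrics)  # dict[str, Any]; typically loss values and other training statistics
```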

abstractmethod prepare_for_training(*args: Any, **kwargs: Any) → None#

abstractmethod finish_training(*args: Any, **kwargs: Any) → None#

abstractmethod save_checkpoint(*args: Any, **kwargs: Any) → None#

abstractmethod shutdown() → bool#
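
Taken together, these methods suggest a training lifecycle along the following lines. The ordering is a plausible sketch, not a contract stated by the interface, and the checkpoint arguments are implementation-defined.

```python
# `policy`, `batches`, and `loss_fn` are placeholders.
policy.prepare_for_training()   # e.g. move weights and optimizer state into place
for batch in batches:
    metrics = policy.train(batch, loss_fn)
policy.save_checkpoint()        # arguments are implementation-defined
policy.finish_training()        # release training-only resources
ok = policy.shutdown()          # returns True on a clean shutdown
```
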
class nemo_rl.models.policy.interfaces.ColocatablePolicyInterface#

Bases: nemo_rl.models.policy.interfaces.PolicyInterface

abstractmethod init_collective(
ip: str,
port: int,
world_size: int,
) → list[ray.ObjectRef]#

abstractmethod offload_before_refit() → None#

abstractmethod offload_after_refit() → None#

abstractmethod prepare_refit_info() → Optional[dict[str, Any]]#

abstractmethod prepare_weights_for_ipc(
*args: Any,
**kwargs: Any,
) → list[list[str]]#

abstractmethod get_weights_ipc_handles(keys: list[str]) → dict[str, Any]#

abstractmethod broadcast_weights_for_collective() → list[ray.ObjectRef]#
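
These hooks are what a colocated policy exposes for synchronizing ("refitting") weights into a generation engine, either through CUDA IPC handles or through a collective group. The sketch below is a hedged reading of how a caller might drive them; the ip, port, and world_size values are placeholders, and concrete implementations may require additional arguments.

```python
# `policy` is a placeholder for a concrete ColocatablePolicyInterface implementation.

# One-time setup if weights are transferred over a collective group.
refs = policy.init_collective(ip="127.0.0.1", port=12345, world_size=2)

# Free device memory before moving weights across.
policy.offload_before_refit()

# Describe and stage the weights, then hand them over group by group.
refit_info = policy.prepare_refit_info()                 # optional metadata about the weights
for key_group in policy.prepare_weights_for_ipc():       # list[list[str]] of weight keys
    handles = policy.get_weights_ipc_handles(key_group)  # CUDA IPC handles keyed by weight name
    # ...pass `handles` to the generation workers (not shown)...

# Alternatively, broadcast the weights over the collective group.
refs = policy.broadcast_weights_for_collective()

# Restore the policy's own state once the generation engine has the new weights.
policy.offload_after_refit()
```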