nemo_rl.models.policy.interfaces#

Module Contents#

Classes#

LogprobOutputSpec

logprobs: Tensor of log probabilities.

ReferenceLogprobOutputSpec

reference_logprobs: Tensor of log probabilities from the reference policy.

ScoreOutputSpec

scores: Tensor of scores.

TopkLogitsOutputSpec

Per-position top-k logits and corresponding global token indices.

PolicyInterface

Abstract base class defining the interface for RL policies.

ColocatablePolicyInterface

Extends PolicyInterface with collective initialization, memory offload, and weight-refit hooks for colocated policies.

API#

class nemo_rl.models.policy.interfaces.LogprobOutputSpec#

Bases: typing.TypedDict

logprobs: Tensor of log probabilities.


logprobs: torch.Tensor#

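Since LogprobOutputSpec is a TypedDict, values are plain dicts that type checkers validate against the declared keys; a minimal sketch (the tensor shape is an illustrative assumption):

    import torch

    from nemo_rl.models.policy.interfaces import LogprobOutputSpec

    # A LogprobOutputSpec is an ordinary dict at runtime; the TypedDict
    # declaration only constrains its keys and value types statically.
    out: LogprobOutputSpec = {
        "logprobs": torch.zeros(4, 128),  # e.g. [batch, seq_len]; shape is illustrative
    }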

class nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec#

Bases: typing.TypedDict

reference_logprobs: Tensor of log probabilities from the reference policy.


reference_logprobs: torch.Tensor#


class nemo_rl.models.policy.interfaces.ScoreOutputSpec#

Bases: typing.TypedDict

scores: Tensor of scores.


scores: torch.Tensor#


class nemo_rl.models.policy.interfaces.TopkLogitsOutputSpec#

Bases: typing.TypedDict

Per-position top-k logits and corresponding global token indices.


topk_logits: torch.Tensor#


topk_indices: torch.Tensor#


class nemo_rl.models.policy.interfaces.PolicyInterface#

Bases: abc.ABC

Abstract base class defining the interface for RL policies.

abstractmethod get_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec]#

Get logprobs of actions from observations.

Parameters:

data – BatchedDataDict containing rollouts (tokens)

Returns:

  • logprobs: Tensor of logprobs of actions

Return type:

BatchedDataDict conforming to LogprobOutputSpec
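
A minimal usage sketch; policy stands for any concrete PolicyInterface implementation and rollout_batch for a BatchedDataDict of token rollouts (both names are hypothetical):

    # Hypothetical objects: `policy` implements PolicyInterface and
    # `rollout_batch` is a BatchedDataDict[GenerationDatumSpec] of rollouts.
    logprob_batch = policy.get_logprobs(rollout_batch)
    logprobs = logprob_batch["logprobs"]  # torch.Tensor of action logprobs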

abstractmethod get_reference_policy_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec]#

Get logprobs of actions from observations under the reference policy.

Parameters:

data – BatchedDataDict containing rollouts (tokens)

Returns:

  • reference_logprobs: Tensor of logprobs of actions under the reference policy

Return type:

BatchedDataDict conforming to ReferenceLogprobOutputSpec
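
Paired with get_logprobs, this is commonly used to estimate a KL penalty against a frozen reference policy; a sketch with the same hypothetical policy and rollout_batch:

    # Per-token log-ratio between the current policy and the reference policy.
    policy_lp = policy.get_logprobs(rollout_batch)["logprobs"]
    ref_lp = policy.get_reference_policy_logprobs(rollout_batch)["reference_logprobs"]
    approx_kl = policy_lp - ref_lp  # estimator of log(pi / pi_ref) per token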

abstractmethod get_topk_logits(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
k: int,
micro_batch_size: Optional[int] = None,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.TopkLogitsOutputSpec]#

Get per-position top-k logits and global indices for a batch of inputs.

Notes

  • Outputs are aligned to next-token positions, so an input of sequence length S yields S-1 positions.
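
A usage sketch; the shapes in the comments follow from the note above and are otherwise assumptions, with policy and rollout_batch hypothetical as before:

    topk = policy.get_topk_logits(rollout_batch, k=8, micro_batch_size=2)
    # Both tensors cover the S-1 next-token positions of length-S inputs:
    logits = topk["topk_logits"]    # assumed shape [batch, S-1, 8]
    indices = topk["topk_indices"]  # matching global vocab ids, same shape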

abstractmethod train(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
eval_mode: bool = False,
gbs: Optional[int] = None,
mbs: Optional[int] = None,
) → dict[str, Any]#

Train the policy on a global batch of data.

Parameters:
  • data – BatchedDataDict containing rollouts (tokens)

  • loss_fn – Loss function to use for training

  • eval_mode – Whether to run in evaluation mode (no gradient updates)

  • gbs – Global batch size override (if None, uses config default)

  • mbs – Micro batch size override (if None, uses config default)
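
A hedged sketch of a single training call; my_loss_fn stands for any LossFunction implementation and the batch-size overrides shown are arbitrary:

    metrics = policy.train(
        data=rollout_batch,
        loss_fn=my_loss_fn,  # any nemo_rl.algorithms.interfaces.LossFunction
        eval_mode=False,     # True would disable gradient updates
        gbs=256,             # global batch size override (None -> config default)
        mbs=4,               # micro batch size override (None -> config default)
    )
    print(metrics)           # dict[str, Any] of training statistics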

abstractmethod score(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ScoreOutputSpec]#

Score a batch of data using the policy.
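
A one-line usage sketch with the same hypothetical objects:

    scores = policy.score(rollout_batch)["scores"]  # torch.Tensor of scores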

abstractmethod prepare_for_training(*args: Any, **kwargs: Any) → None#
abstractmethod finish_training(*args: Any, **kwargs: Any) → None#
abstractmethod save_checkpoint(*args: Any, **kwargs: Any) → None#
abstractmethod shutdown() → bool#
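
The lifecycle hooks are typically bracketed around train; a sketch in which the concrete argument lists (elided in the signatures above as *args/**kwargs) are implementation-specific:

    policy.prepare_for_training()
    metrics = policy.train(rollout_batch, my_loss_fn)
    policy.finish_training()
    policy.save_checkpoint()   # accepts implementation-specific args/kwargs
    ok = policy.shutdown()     # True on clean shutdown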
class nemo_rl.models.policy.interfaces.ColocatablePolicyInterface#

Bases: nemo_rl.models.policy.interfaces.PolicyInterface

abstractmethod init_collective(
ip: str,
port: int,
world_size: int,
*,
train_world_size: int,
) → list[ray.ObjectRef]#
abstractmethod offload_before_refit() → None#
abstractmethod offload_after_refit() → None#
abstractmethod prepare_refit_info() → Optional[dict[str, Any]]#
abstractmethod stream_weights_via_ipc_zmq(
*args: Any,
**kwargs: Any,
) → list[ray.ObjectRef]#
abstractmethod broadcast_weights_for_collective() → list[ray.ObjectRef]#
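
A hedged sketch of a weight-refit round for a colocated policy; the ordering (offload, refit info, transfer, offload again) is an assumption pieced together from the method names, and train_world_size / gen_world_size are hypothetical values:

    import ray

    # One-time collective setup across training and generation workers.
    refs = policy.init_collective(
        ip="10.0.0.1",
        port=29500,
        world_size=train_world_size + gen_world_size,
        train_world_size=train_world_size,
    )
    ray.get(refs)

    # Per-refit: free training state, describe the weights, transfer them.
    policy.offload_before_refit()
    info = policy.prepare_refit_info()  # Optional[dict[str, Any]] metadata
    ray.get(policy.broadcast_weights_for_collective())
    # (alternatively: policy.stream_weights_via_ipc_zmq(...))
    policy.offload_after_refit()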