nemo_rl.models.policy.interfaces#

Module Contents#

Classes#

LogprobOutputSpec

logprobs: Tensor of log probabilities.

ReferenceLogprobOutputSpec

reference_logprobs: Tensor of log probabilities.

ScoreOutputSpec

scores: Tensor of scores.

TopkLogitsOutputSpec

Per-position top-k logits and corresponding global token indices.

PolicyInterface

Abstract base class defining the interface for RL policies.

ColocatablePolicyInterface

API#

class nemo_rl.models.policy.interfaces.LogprobOutputSpec#

Bases: typing.TypedDict

logprobs: Tensor of log probabilities.

Initialization

Initialize self. See help(type(self)) for accurate signature.

logprobs: torch.Tensor#

None

class nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec#

Bases: typing.TypedDict

reference_logprobs: Tensor of log probabilities.

Initialization

Initialize self. See help(type(self)) for accurate signature.

reference_logprobs: torch.Tensor#

None

class nemo_rl.models.policy.interfaces.ScoreOutputSpec#

Bases: typing.TypedDict

scores: Tensor of scores.

Initialization

Initialize self. See help(type(self)) for accurate signature.

scores: torch.Tensor#

None

class nemo_rl.models.policy.interfaces.TopkLogitsOutputSpec#

Bases: typing.TypedDict

Per-position top-k logits and corresponding global token indices.

Initialization

Initialize self. See help(type(self)) for accurate signature.

topk_logits: torch.Tensor#

None

topk_indices: torch.Tensor#

None

class nemo_rl.models.policy.interfaces.PolicyInterface#

Bases: abc.ABC

Abstract base class defining the interface for RL policies.

abstractmethod get_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec]#

Get logprobs of actions from observations.

Parameters:

data – BatchedDataDict containing rollouts (tokens)

Returns:

  • logprobs: Tensor of logprobs of actions

Return type:

BatchedDataDict containing

abstractmethod get_reference_policy_logprobs(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
micro_batch_size: Optional[int] = None,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec]#

Get logprobs of actions from observations under the reference policy.

Parameters:

data – BatchedDataDict containing rollouts (tokens)

Returns:

  • reference_logprobs: Tensor of reference-policy logprobs of actions

Return type:

BatchedDataDict containing

abstractmethod get_topk_logits(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
k: int,
micro_batch_size: Optional[int] = None,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.TopkLogitsOutputSpec]#

Get per-position top-k logits and global indices for a batch of inputs.

.. rubric:: Notes

  • Aligns to next-token positions → returns S-1 positions.

abstractmethod train(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
eval_mode: bool = False,
*,
gbs: Optional[int] = None,
mbs: Optional[int] = None,
) → dict[str, Any]#

Train the policy on a global batch of data.

Parameters:
  • data – BatchedDataDict containing rollouts (tokens)

  • loss_fn – Loss function to use for training

  • eval_mode – Whether to run in evaluation mode (no gradient updates)

  • gbs – Global batch size override (if None, uses config default)

  • mbs – Micro batch size override (if None, uses config default)

abstractmethod calibrate_qkv_fp8_scales(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
micro_batch_size: Optional[int] = None,
percentile: float = 99.9,
margin: float = 1.05,
include_q: bool = False,
) → dict[str, Any]#

Calibrate FP8 scales for Q/K/V activations used by KV cache.

Parameters:
  • data – BatchedDataDict containing input_ids and input_lengths.

  • micro_batch_size – Optional override for micro batch size during calibration.

  • percentile – Percentile for per-tensor amax estimation.

  • margin – Safety margin multiplier applied to amax.

  • include_q – Whether to also compute scale for Q in addition to K/V.

Returns:

Dict with overall configuration and per-layer scales.

abstractmethod prepare_for_training(*args: Any, **kwargs: Any) → None#
abstractmethod finish_training(*args: Any, **kwargs: Any) → None#
abstractmethod save_checkpoint(*args: Any, **kwargs: Any) → None#
abstractmethod shutdown() → bool#
class nemo_rl.models.policy.interfaces.ColocatablePolicyInterface#

Bases: nemo_rl.models.policy.interfaces.PolicyInterface

abstractmethod init_collective(
ip: str,
port: int,
world_size: int,
*,
train_world_size: int,
) → list[ray.ObjectRef]#
abstractmethod offload_before_refit() → None#
abstractmethod offload_after_refit() → None#
abstractmethod prepare_refit_info() → Optional[dict[str, Any]]#
abstractmethod stream_weights_via_ipc_zmq(
*args: Any,
**kwargs: Any,
) → list[ray.ObjectRef]#
abstractmethod broadcast_weights_for_collective(
kv_scales: Optional[dict[str, float]] = None,
) → list[ray.ObjectRef]#
abstractmethod prepare_for_lp_inference() → None#