nemo_rl.models.policy.interfaces#
Module Contents#
Classes#
| Class | Description |
|---|---|
| LogprobOutputSpec | logprobs: Tensor of log probabilities. |
| ReferenceLogprobOutputSpec | reference_logprobs: Tensor of log probabilities. |
| ScoreOutputSpec | scores: Tensor of scores. |
| TopkLogitsOutputSpec | Per-position top-k logits and corresponding global token indices. |
| PolicyInterface | Abstract base class defining the interface for RL policies. |
API#
- class nemo_rl.models.policy.interfaces.LogprobOutputSpec#
Bases: typing.TypedDict

logprobs: Tensor of log probabilities.
Initialization
Initialize self. See help(type(self)) for accurate signature.
- logprobs: torch.Tensor#
- class nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec#
Bases: typing.TypedDict

reference_logprobs: Tensor of log probabilities.
Initialization
Initialize self. See help(type(self)) for accurate signature.
- reference_logprobs: torch.Tensor#
- class nemo_rl.models.policy.interfaces.ScoreOutputSpec#
Bases: typing.TypedDict

scores: Tensor of scores.
Initialization
Initialize self. See help(type(self)) for accurate signature.
- scores: torch.Tensor#
- class nemo_rl.models.policy.interfaces.TopkLogitsOutputSpec#
Bases: typing.TypedDict

Per-position top-k logits and corresponding global token indices.
Initialization
Initialize self. See help(type(self)) for accurate signature.
- topk_logits: torch.Tensor#
- topk_indices: torch.Tensor#
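These TypedDict specs only declare the shape of policy outputs; they carry no logic. As a sketch of how a `LogprobOutputSpec`-style tensor might be computed from raw logits (the function name, shapes, and values below are illustrative, not taken from the library):

```python
import torch
import torch.nn.functional as F

def token_logprobs(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Log-probability of each observed next token.

    logits: [B, S, V] raw model outputs
    tokens: [B, S] token ids
    Returns [B, S-1]: logprob of tokens[:, 1:] given positions 0..S-2.
    """
    # Drop the last position: it predicts a token beyond the sequence.
    logprobs = F.log_softmax(logits[:, :-1], dim=-1)            # [B, S-1, V]
    # Pick out the logprob of each token that was actually observed.
    return logprobs.gather(-1, tokens[:, 1:].unsqueeze(-1)).squeeze(-1)

B, S, V = 2, 5, 11
out = token_logprobs(torch.randn(B, S, V), torch.randint(0, V, (B, S)))
print(out.shape)  # torch.Size([2, 4])
```

A result dict typed as `LogprobOutputSpec` would then carry this tensor under the `logprobs` key.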
- class nemo_rl.models.policy.interfaces.PolicyInterface#
Bases: abc.ABC

Abstract base class defining the interface for RL policies.
- abstractmethod get_logprobs(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
)#

Get logprobs of actions from observations.

- Parameters:
data – BatchedDataDict containing rollouts (tokens)
- Returns:
logprobs: Tensor of logprobs of actions
- Return type:
BatchedDataDict containing the logprobs
- abstractmethod get_reference_policy_logprobs(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- micro_batch_size: Optional[int] = None,
)#

Get logprobs of actions from observations under the reference policy.

- Parameters:
data – BatchedDataDict containing rollouts (tokens)
micro_batch_size – Optional micro batch size override
- Returns:
reference_logprobs: Tensor of logprobs of actions
- Return type:
BatchedDataDict containing the reference logprobs
- abstractmethod get_topk_logits(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- k: int,
- micro_batch_size: Optional[int] = None,
)#

Get per-position top-k logits and global indices for a batch of inputs.

.. rubric:: Notes

Outputs are aligned to next-token positions, so for sequences of length S the result covers S-1 positions.
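The S-1 alignment can be sketched as follows; the batch size, sequence length, vocabulary size, and k below are arbitrary example values, not library defaults:

```python
import torch

# Position i of the output holds the top-k logits predicting token i+1,
# so the final position (which would predict beyond the sequence) is dropped.
B, S, V, k = 2, 6, 32, 4
logits = torch.randn(B, S, V)              # raw logits, [B, S, V]
next_token_logits = logits[:, :-1, :]      # [B, S-1, V]
topk_logits, topk_indices = torch.topk(next_token_logits, k=k, dim=-1)
print(topk_logits.shape)   # torch.Size([2, 5, 4])
print(topk_indices.shape)  # torch.Size([2, 5, 4])
```

`topk_indices` holds global vocabulary ids, matching the `TopkLogitsOutputSpec` fields above.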
- abstractmethod train(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
- loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
- eval_mode: bool = False,
- *,
- gbs: Optional[int] = None,
- mbs: Optional[int] = None,
)#
Train the policy on a global batch of data.
- Parameters:
data – BatchedDataDict containing rollouts (tokens)
loss_fn – Loss function to use for training
eval_mode – Whether to run in evaluation mode (no gradient updates)
gbs – Global batch size override (if None, uses config default)
mbs – Micro batch size override (if None, uses config default)
- abstractmethod calibrate_qkv_fp8_scales(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- micro_batch_size: Optional[int] = None,
- percentile: float = 99.9,
- margin: float = 1.05,
- include_q: bool = False,
)#
Calibrate FP8 scales for Q/K/V activations used by KV cache.
- Parameters:
data – BatchedDataDict containing input_ids and input_lengths.
micro_batch_size – Optional override for micro batch size during calibration.
percentile – Percentile for per-tensor amax estimation.
margin – Safety margin multiplier applied to amax.
include_q – Whether to also compute scale for Q in addition to K/V.
- Returns:
Dict with overall configuration and per-layer scales.
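The percentile/margin scheme can be sketched as follows. This is a hedged illustration, not the library's implementation: the helper name and the amax-to-scale convention are assumptions, while 448.0 is the largest representable value in the FP8 E4M3 format commonly used for KV caches.

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in E4M3

def fp8_activation_scale(x: torch.Tensor,
                         percentile: float = 99.9,
                         margin: float = 1.05) -> float:
    # A percentile-based amax is more robust to outliers than a plain max().
    amax = torch.quantile(x.abs().float().flatten(), percentile / 100.0)
    # Inflate by the safety margin, then map to an E4M3 scale.
    return float(amax * margin / FP8_E4M3_MAX)

scale = fp8_activation_scale(torch.randn(4096))
```

In the interface above, such a scale would be computed per layer for K/V (and optionally Q when `include_q=True`) and collected into the returned dict.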
- abstractmethod prepare_for_training(*args: Any, **kwargs: Any) → None#
- abstractmethod finish_training(*args: Any, **kwargs: Any) → None#
- abstractmethod save_checkpoint(*args: Any, **kwargs: Any) → None#
- abstractmethod shutdown() → bool#
- class nemo_rl.models.policy.interfaces.ColocatablePolicyInterface#
Bases:
nemo_rl.models.policy.interfaces.PolicyInterface

- abstractmethod init_collective(
- ip: str,
- port: int,
- world_size: int,
- *,
- train_world_size: int,
)#
- abstractmethod offload_before_refit() → None#
- abstractmethod offload_after_refit() → None#
- abstractmethod prepare_refit_info() → Optional[dict[str, Any]]#
- abstractmethod stream_weights_via_ipc_zmq(
- *args: Any,
- **kwargs: Any,
)#
- abstractmethod broadcast_weights_for_collective(
- kv_scales: Optional[dict[str, float]] = None,
)#
- abstractmethod prepare_for_lp_inference() → None#