nemo_rl.models.policy.interfaces#
Module Contents#
Classes#
| Class | Description |
|---|---|
| LogprobOutputSpec | logprobs: Tensor of log probabilities. |
| ReferenceLogprobOutputSpec | logprobs: Tensor of log probabilities. |
| ScoreOutputSpec | scores: Tensor of scores. |
| TopkLogitsOutputSpec | Per-position top-k logits and corresponding global token indices. |
| PolicyInterface | Abstract base class defining the interface for RL policies. |
| ColocatablePolicyInterface | |
API#
- class nemo_rl.models.policy.interfaces.LogprobOutputSpec#
Bases: typing.TypedDict
logprobs: Tensor of log probabilities.
Initialization
Initialize self. See help(type(self)) for accurate signature.
- logprobs: torch.Tensor#
None
- class nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec#
Bases: typing.TypedDict
logprobs: Tensor of log probabilities.
Initialization
Initialize self. See help(type(self)) for accurate signature.
- reference_logprobs: torch.Tensor#
None
- class nemo_rl.models.policy.interfaces.ScoreOutputSpec#
Bases: typing.TypedDict
scores: Tensor of scores.
Initialization
Initialize self. See help(type(self)) for accurate signature.
- scores: torch.Tensor#
None
- class nemo_rl.models.policy.interfaces.TopkLogitsOutputSpec#
Bases: typing.TypedDict
Per-position top-k logits and corresponding global token indices.
Initialization
Initialize self. See help(type(self)) for accurate signature.
- topk_logits: torch.Tensor#
None
- topk_indices: torch.Tensor#
None
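These output specs are plain `typing.TypedDict`s, so any mapping with the listed tensor keys conforms. A minimal sketch, assuming illustrative shapes of `[batch, seq_len]` for logprobs and `[batch, seq_len - 1, k]` for the top-k outputs (the shapes are assumptions, not part of the specs):

```python
import torch

from nemo_rl.models.policy.interfaces import LogprobOutputSpec, TopkLogitsOutputSpec

# Shapes below are illustrative assumptions, not requirements of the specs.
logprob_out: LogprobOutputSpec = {"logprobs": torch.zeros(2, 16)}
topk_out: TopkLogitsOutputSpec = {
    "topk_logits": torch.zeros(2, 15, 8),
    "topk_indices": torch.zeros(2, 15, 8, dtype=torch.long),
}
print(logprob_out["logprobs"].shape)   # torch.Size([2, 16])
print(topk_out["topk_indices"].shape)  # torch.Size([2, 15, 8])
```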
- class nemo_rl.models.policy.interfaces.PolicyInterface#
Bases: abc.ABC
Abstract base class defining the interface for RL policies.
- abstractmethod get_logprobs(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- ) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec]#
Get logprobs of actions from observations.
- Parameters:
data – BatchedDataDict containing rollouts (tokens)
- Returns:
BatchedDataDict containing logprobs: Tensor of logprobs of actions
- Return type:
BatchedDataDict
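A hedged usage sketch for `get_logprobs`: the `policy` and `rollout_data` arguments are assumed to be a concrete `PolicyInterface` implementation and a pre-built `BatchedDataDict`; the only key read is the one declared by `LogprobOutputSpec` above.

```python
import torch

from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.models.policy.interfaces import PolicyInterface


def fetch_rollout_logprobs(
    policy: PolicyInterface, rollout_data: BatchedDataDict
) -> torch.Tensor:
    """Illustrative helper: run get_logprobs and pull out the per-token tensor."""
    logprob_batch = policy.get_logprobs(rollout_data)
    # "logprobs" is the key declared by LogprobOutputSpec.
    return logprob_batch["logprobs"]
```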
- abstractmethod get_reference_policy_logprobs(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- ) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec]#
Get logprobs of actions from observations under the reference policy.
- Parameters:
data – BatchedDataDict containing rollouts (tokens)
- Returns:
BatchedDataDict containing logprobs: Tensor of logprobs of actions
- Return type:
BatchedDataDict
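A common reason to expose both the policy and reference-policy logprob hooks is a per-token KL penalty against a frozen reference model. The sketch below works on raw tensors; the specific estimator (the k3 form) is an assumption, not something this interface prescribes.

```python
import torch


def approx_kl_per_token(
    logprobs: torch.Tensor, reference_logprobs: torch.Tensor
) -> torch.Tensor:
    """Per-token k3 estimate of KL(policy || reference) on sampled actions."""
    # log_ratio = log p_ref(a) - log p_policy(a) for each sampled token a.
    log_ratio = reference_logprobs - logprobs
    return torch.exp(log_ratio) - log_ratio - 1.0
```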
- abstractmethod get_topk_logits(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- k: int,
- micro_batch_size: Optional[int] = None,
- ) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.TopkLogitsOutputSpec]#
Get per-position top-k logits and global indices for a batch of inputs.
Notes
Aligns to next-token positions, so S-1 positions are returned (where S is the input sequence length).
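A plain-tensor sketch of the alignment described in the note: logits at position t score token t+1, so the last position has no target and S-1 positions remain. The dict keys mirror `TopkLogitsOutputSpec`; everything else is an assumption.

```python
import torch


def per_position_topk(full_logits: torch.Tensor, k: int) -> dict[str, torch.Tensor]:
    """Illustrative top-k over next-token positions for logits of shape [batch, S, vocab]."""
    next_token_logits = full_logits[:, :-1, :]  # drop the final position: S - 1 remain
    topk_logits, topk_indices = next_token_logits.topk(k, dim=-1)
    return {"topk_logits": topk_logits, "topk_indices": topk_indices}
```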
- abstractmethod train(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
- loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
- eval_mode: bool = False,
- gbs: Optional[int] = None,
- mbs: Optional[int] = None,
- )#
Train the policy on a global batch of data.
- Parameters:
data – BatchedDataDict containing rollouts (tokens)
loss_fn – Loss function to use for training
eval_mode – Whether to run in evaluation mode (no gradient updates)
gbs – Global batch size override (if None, uses config default)
mbs – Micro batch size override (if None, uses config default)
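A hedged sketch of how a caller might drive `train`, assuming a concrete policy and loss function are available; the return values are whatever the implementation provides, since the interface as shown here does not pin them down.

```python
from typing import Optional

from nemo_rl.algorithms.interfaces import LossFunction
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.models.policy.interfaces import PolicyInterface


def train_then_eval(
    policy: PolicyInterface,
    batch: BatchedDataDict,
    loss_fn: LossFunction,
    gbs: Optional[int] = None,
    mbs: Optional[int] = None,
):
    """Illustrative wrapper: one gradient pass, then an eval-mode pass on the same batch."""
    train_out = policy.train(batch, loss_fn, gbs=gbs, mbs=mbs)
    # eval_mode=True runs without gradient updates, per the parameter description above.
    eval_out = policy.train(batch, loss_fn, eval_mode=True, gbs=gbs, mbs=mbs)
    return train_out, eval_out
```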
- abstractmethod score(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- ) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ScoreOutputSpec]#
Score a batch of data using the policy.
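A hedged sketch for `score`, assuming the returned `BatchedDataDict` exposes the key declared by `ScoreOutputSpec` (for example when the policy wraps a reward model):

```python
import torch

from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.models.policy.interfaces import PolicyInterface


def score_rollouts(policy: PolicyInterface, rollout_data: BatchedDataDict) -> torch.Tensor:
    """Illustrative helper: score a batch and read out the "scores" tensor."""
    scored = policy.score(rollout_data)
    return scored["scores"]  # key declared by ScoreOutputSpec
```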
- abstractmethod prepare_for_training(*args: Any, **kwargs: Any) → None#
- abstractmethod finish_training(*args: Any, **kwargs: Any) → None#
- abstractmethod save_checkpoint(*args: Any, **kwargs: Any) → None#
- abstractmethod shutdown() → bool#
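Taken together, the lifecycle hooks suggest a prepare / train / checkpoint / finish / shutdown ordering. The sequence below is a sketch inferred from the method names; checkpoint arguments are implementation-specific (`*args`, `**kwargs`) and are omitted.

```python
from nemo_rl.algorithms.interfaces import LossFunction
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.models.policy.interfaces import PolicyInterface


def run_training(
    policy: PolicyInterface,
    batches: list[BatchedDataDict],
    loss_fn: LossFunction,
) -> None:
    """Illustrative lifecycle ordering; real callers pass checkpoint paths, step counts, etc."""
    policy.prepare_for_training()
    try:
        for batch in batches:
            policy.train(batch, loss_fn)
        policy.save_checkpoint()  # implementation-specific arguments omitted here
    finally:
        policy.finish_training()
        policy.shutdown()
```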
- class nemo_rl.models.policy.interfaces.ColocatablePolicyInterface#
Bases: nemo_rl.models.policy.interfaces.PolicyInterface
- abstractmethod init_collective(
- ip: str,
- port: int,
- world_size: int,
- *,
- train_world_size: int,
- )#
- abstractmethod offload_before_refit() → None#
- abstractmethod offload_after_refit() → None#
- abstractmethod prepare_refit_info() → Optional[dict[str, Any]]#
- abstractmethod stream_weights_via_ipc_zmq(
- *args: Any,
- **kwargs: Any,
- )#
- abstractmethod broadcast_weights_for_collective() → list[ray.ObjectRef]#
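A hedged sketch of how a colocated policy might hand refreshed weights to a generation engine. The ordering, and the choice between the IPC/ZMQ streaming path and the collective broadcast path, are assumptions drawn from the method names rather than a documented protocol.

```python
from nemo_rl.models.policy.interfaces import ColocatablePolicyInterface


def refit_generation_weights(
    policy: ColocatablePolicyInterface, use_ipc_zmq: bool = True
) -> None:
    """Illustrative refit sequence for a colocated policy (ordering is an assumption)."""
    policy.offload_before_refit()      # free training-side memory before the transfer
    policy.prepare_refit_info()        # optional metadata describing the weights
    if use_ipc_zmq:
        policy.stream_weights_via_ipc_zmq()        # per-node IPC/ZMQ streaming path
    else:
        policy.broadcast_weights_for_collective()  # collective broadcast path
    policy.offload_after_refit()       # clean up once the new weights are in place
```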