nemo_rl.models.policy.interfaces
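Module Contents
Classes
| LogprobOutputSpec | logprobs: Tensor of log probabilities. |
| ReferenceLogprobOutputSpec | reference_logprobs: Tensor of log probabilities from the reference policy. |
| PolicyInterface | Abstract base class defining the interface for RL policies. |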
API
- class nemo_rl.models.policy.interfaces.LogprobOutputSpec
Bases: typing.TypedDict
logprobs: Tensor of log probabilities.
- logprobs: torch.Tensor
- class nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec
Bases: typing.TypedDict
reference_logprobs: Tensor of log probabilities from the reference policy.
- reference_logprobs: torch.Tensor
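The two output specs differ only in their key name. The sketch below mirrors them with local `TypedDict` definitions so it runs without NeMo RL or PyTorch installed; the `Tensor` alias (nested lists of floats) and the sample values are stand-ins chosen for illustration, not part of the library.

```python
from typing import TypedDict

# Assumption: nested lists stand in for torch.Tensor so the sketch has no dependencies.
Tensor = list[list[float]]


class LogprobOutputSpec(TypedDict):
    """Local mirror of the documented spec: a single `logprobs` entry."""
    logprobs: Tensor


class ReferenceLogprobOutputSpec(TypedDict):
    """Local mirror of the documented spec: a single `reference_logprobs` entry."""
    reference_logprobs: Tensor


# Illustrative values only: 2 sequences x 2 tokens.
policy_out: LogprobOutputSpec = {"logprobs": [[-0.1, -2.3], [-0.7, -0.4]]}
reference_out: ReferenceLogprobOutputSpec = {
    "reference_logprobs": [[-0.2, -2.0], [-0.9, -0.4]]
}

# Because the key names differ, both results can be merged into one
# dictionary without either overwriting the other.
merged = {**policy_out, **reference_out}
print(sorted(merged))
```

Keeping the keys distinct is what would let a training batch carry current-policy and reference-policy log probabilities side by side.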
- class nemo_rl.models.policy.interfaces.PolicyInterface
Bases: abc.ABC
Abstract base class defining the interface for RL policies.
- abstractmethod get_logprobs(data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec])
Get logprobs of actions from observations.
- Parameters:
data – BatchedDataDict containing rollouts (tokens)
- Returns:
BatchedDataDict containing logprobs: Tensor of logprobs of actions
- abstractmethod get_reference_policy_logprobs(data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec])
Get logprobs of actions from observations under the reference policy.
- Parameters:
data – BatchedDataDict containing rollouts (tokens)
- Returns:
BatchedDataDict containing reference_logprobs: Tensor of logprobs of actions under the reference policy
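To illustrate how the two log-probability methods relate, here is a minimal sketch of a toy policy that implements both and returns results under the documented keys. Everything here is a stand-in: `LogprobPolicy` is a trimmed, hypothetical copy of the interface, plain dicts replace `BatchedDataDict`, nested lists replace tensors, the `input_ids` key is an assumed name for the rollout tokens, and the unigram probability tables are invented for the example.

```python
import math
from abc import ABC, abstractmethod


class LogprobPolicy(ABC):
    """Hypothetical, trimmed stand-in for the two logprob methods of PolicyInterface."""

    @abstractmethod
    def get_logprobs(self, data: dict) -> dict: ...

    @abstractmethod
    def get_reference_policy_logprobs(self, data: dict) -> dict: ...


class UnigramPolicy(LogprobPolicy):
    """Toy policy: each token's probability comes from a fixed lookup table."""

    def __init__(self, probs: dict, ref_probs: dict) -> None:
        self.probs = probs
        self.ref_probs = ref_probs

    @staticmethod
    def _score(table: dict, data: dict) -> list:
        # One log probability per token, per sequence.
        return [[math.log(table[tok]) for tok in seq] for seq in data["input_ids"]]

    def get_logprobs(self, data: dict) -> dict:
        return {"logprobs": self._score(self.probs, data)}

    def get_reference_policy_logprobs(self, data: dict) -> dict:
        return {"reference_logprobs": self._score(self.ref_probs, data)}


policy = UnigramPolicy(probs={0: 0.5, 1: 0.5}, ref_probs={0: 0.25, 1: 0.75})
batch = {"input_ids": [[0, 1], [1, 1]]}  # assumed key name for rollout tokens

lp = policy.get_logprobs(batch)["logprobs"]
ref_lp = policy.get_reference_policy_logprobs(batch)["reference_logprobs"]

# Per-token log ratio between the current and the reference policy.
log_ratio = [[a - b for a, b in zip(row, ref_row)] for row, ref_row in zip(lp, ref_lp)]
print(log_ratio)
```

A per-token log ratio like this is one common way the two outputs are combined, for example in a KL-style penalty; whether and how a given algorithm does so is outside what this interface specifies.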
- abstractmethod train(data: nemo_rl.distributed.batched_data_dict.BatchedDataDict, loss_fn: nemo_rl.algorithms.interfaces.LossFunction, eval_mode: bool = False, gbs: Optional[int] = None, mbs: Optional[int] = None)
Train the policy on a global batch of data.
- Parameters:
data – BatchedDataDict containing rollouts (tokens)
loss_fn – Loss function to use for training
eval_mode – Whether to run in evaluation mode (no gradient updates)
gbs – Global batch size override (if None, uses config default)
mbs – Micro batch size override (if None, uses config default)
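The `gbs` and `mbs` overrides can be pictured as a simple fallback to configured defaults. The sketch below is an assumption about how an implementation might resolve them, not NeMo RL's actual logic: the `DEFAULTS` dictionary and its key names are hypothetical, and the microbatch count assumes a single data-parallel rank.

```python
from typing import Optional

# Hypothetical defaults; the key names are illustrative, not NeMo RL's config schema.
DEFAULTS = {"global_batch_size": 32, "micro_batch_size": 4}


def resolve_batch_sizes(
    gbs: Optional[int] = None,
    mbs: Optional[int] = None,
    defaults: dict = DEFAULTS,
) -> tuple[int, int, int]:
    """Return (gbs, mbs, num_microbatches), falling back to defaults when None."""
    gbs = defaults["global_batch_size"] if gbs is None else gbs
    mbs = defaults["micro_batch_size"] if mbs is None else mbs
    if gbs % mbs != 0:
        raise ValueError(f"gbs={gbs} must be divisible by mbs={mbs}")
    # Assumption: one data-parallel rank, so every microbatch runs on this process.
    return gbs, mbs, gbs // mbs


print(resolve_batch_sizes())        # both values come from the defaults
print(resolve_batch_sizes(gbs=16))  # only the global batch size is overridden
print(resolve_batch_sizes(mbs=8))   # only the micro batch size is overridden
```

Using `is None` rather than a truthiness check keeps an explicit override distinct from "not provided".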
- abstractmethod prepare_for_training(*args: Any, **kwargs: Any) → None
- abstractmethod finish_training(*args: Any, **kwargs: Any) → None
- abstractmethod save_checkpoint(*args: Any, **kwargs: Any) → None
- abstractmethod shutdown() → bool
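Since every method above is abstract, a concrete policy must implement all of them before it can be instantiated. The following sketch uses a local stand-in ABC (`LifecyclePolicy`, hypothetical) with the same method names, and a toy subclass that records calls. The call order shown and the `dict` return value of `train` are assumptions made for illustration; the source does not state them.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional


class LifecyclePolicy(ABC):
    """Hypothetical stand-in mirroring the training-related abstract methods."""

    @abstractmethod
    def prepare_for_training(self, *args: Any, **kwargs: Any) -> None: ...

    @abstractmethod
    def train(self, data: dict, loss_fn: Any, eval_mode: bool = False,
              gbs: Optional[int] = None, mbs: Optional[int] = None) -> dict: ...

    @abstractmethod
    def finish_training(self, *args: Any, **kwargs: Any) -> None: ...

    @abstractmethod
    def save_checkpoint(self, *args: Any, **kwargs: Any) -> None: ...

    @abstractmethod
    def shutdown(self) -> bool: ...


class RecordingPolicy(LifecyclePolicy):
    """Toy implementation that only records which methods were called."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def prepare_for_training(self, *args: Any, **kwargs: Any) -> None:
        self.calls.append("prepare_for_training")

    def train(self, data, loss_fn, eval_mode=False, gbs=None, mbs=None) -> dict:
        # Assumption: train returns a metrics dict; no parameters are updated here.
        self.calls.append("eval" if eval_mode else "train")
        return {"loss": loss_fn(data)}

    def finish_training(self, *args: Any, **kwargs: Any) -> None:
        self.calls.append("finish_training")

    def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
        self.calls.append("save_checkpoint")

    def shutdown(self) -> bool:
        self.calls.append("shutdown")
        return True


def mean_loss(data: dict) -> float:
    """Hypothetical loss function: the mean of the values under 'rewards'."""
    return sum(data["rewards"]) / len(data["rewards"])


policy = RecordingPolicy()
policy.prepare_for_training()
metrics = policy.train({"rewards": [1.0, 0.0, 1.0, 0.0]}, mean_loss)
policy.finish_training()
policy.save_checkpoint()
stopped = policy.shutdown()
print(policy.calls, metrics, stopped)
```

A subclass that leaves any of these methods unimplemented raises `TypeError` on instantiation, which is the main guarantee an `abc.ABC` base provides.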
- class nemo_rl.models.policy.interfaces.ColocatablePolicyInterface
Bases: nemo_rl.models.policy.interfaces.PolicyInterface
- abstractmethod init_collective(ip: str, port: int, world_size: int)
- abstractmethod offload_before_refit() → None
- abstractmethod offload_after_refit() → None
- abstractmethod prepare_refit_info() → Optional[dict[str, Any]]
- abstractmethod prepare_weights_for_ipc(*args: Any, **kwargs: Any)
- abstractmethod get_weights_ipc_handles(keys: list[str]) → dict[str, Any]
- abstractmethod broadcast_weights_for_collective() → list[ray.ObjectRef]
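The refit-related methods carry no docstrings here, so the sketch below is only one plausible reading of how a colocated weight hand-off could use them. `ToyColocatedPolicy` is a duck-typed stand-in, the call order in `refit` is an assumption, the return value of `prepare_weights_for_ipc` is invented (the source does not show it), and the string "handles" replace real IPC handles. `init_collective` and `broadcast_weights_for_collective` are left out because this toy does not model a collective transfer.

```python
from typing import Any, Optional


class ToyColocatedPolicy:
    """Duck-typed stand-in exposing a subset of the colocatable method names."""

    def __init__(self) -> None:
        # Illustrative weights; plain lists stand in for tensors.
        self.weights = {"layer.weight": [0.1, 0.2], "layer.bias": [0.0]}
        self.calls: list[str] = []

    def prepare_refit_info(self) -> Optional[dict[str, Any]]:
        self.calls.append("prepare_refit_info")
        # Assumption: metadata describing each weight (here, its length).
        return {name: len(values) for name, values in self.weights.items()}

    def offload_before_refit(self) -> None:
        self.calls.append("offload_before_refit")

    def prepare_weights_for_ipc(self, *args: Any, **kwargs: Any) -> list[str]:
        self.calls.append("prepare_weights_for_ipc")
        # Assumption: returns the names of the weights that are ready to share.
        return list(self.weights)

    def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]:
        self.calls.append("get_weights_ipc_handles")
        # Placeholder strings instead of real inter-process handles.
        return {key: f"handle:{key}" for key in keys}

    def offload_after_refit(self) -> None:
        self.calls.append("offload_after_refit")


def refit(policy: ToyColocatedPolicy) -> dict[str, Any]:
    """Assumed ordering of one colocated hand-off; not taken from the source."""
    policy.prepare_refit_info()
    policy.offload_before_refit()
    keys = policy.prepare_weights_for_ipc()
    handles = policy.get_weights_ipc_handles(keys)
    policy.offload_after_refit()
    return handles


toy = ToyColocatedPolicy()
handles = refit(toy)
print(toy.calls)
print(handles)
```

In a real setup the handles would be passed to the generation worker sharing the same devices; that step is omitted because it lies outside this interface.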