nemo_rl.algorithms.ppo#
Module Contents#
Classes#
AdvEstimatorConfig | Configuration for PPO advantage estimator (GAE or raw_reward).
Functions#
setup | Main entry point for running PPO algorithm.
dynamic_sampling | Implements the dynamic sampling algorithm to select prompts with non-zero standard deviation.
_create_advantage_estimator | Create and return an advantage estimator based on configuration.
ppo_train | Run PPO training algorithm.
validate | Run validation on the validation dataset.
Data#
API#
- nemo_rl.algorithms.ppo.TokenizerType#
‘TypeVar(…)’
- class nemo_rl.algorithms.ppo.AdvEstimatorConfig#
Bases: typing.TypedDict
Configuration for PPO advantage estimator (GAE or raw_reward).
- name: str#
None
- gae_lambda: NotRequired[float]#
None
- gae_gamma: NotRequired[float]#
None
- normalize_advantages: NotRequired[bool]#
None
- gae_lambda_value: NotRequired[Optional[float]]#
None
- gae_lambda_policy: NotRequired[Optional[float]]#
None
- length_adaptive_alpha: NotRequired[float]#
None
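A minimal sketch of what an adv_estimator configuration might look like. The TypedDict below is a local stand-in that mirrors the field names listed above so the snippet runs without nemo_rl installed; the class name AdvEstimatorConfigSketch and all numeric values are illustrative assumptions, not library defaults.

```python
from typing import Optional, TypedDict


class _RequiredFields(TypedDict):
    # `name` is the only field not marked NotRequired in the listing above.
    name: str


class AdvEstimatorConfigSketch(_RequiredFields, total=False):
    # Hypothetical stand-in for AdvEstimatorConfig; optional fields only.
    gae_lambda: float
    gae_gamma: float
    normalize_advantages: bool
    gae_lambda_value: Optional[float]
    gae_lambda_policy: Optional[float]
    length_adaptive_alpha: float


# Example GAE configuration (values chosen for illustration only).
gae_config: AdvEstimatorConfigSketch = {
    "name": "gae",
    "gae_lambda": 0.95,
    "gae_gamma": 1.0,
    "normalize_advantages": True,
}

# The other estimator name mentioned in the class description.
raw_reward_config: AdvEstimatorConfigSketch = {"name": "raw_reward"}
```

Because every field except name is optional, a configuration that only selects the estimator is still a valid dictionary of this shape.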
- class nemo_rl.algorithms.ppo.PPOConfig#
Bases: typing.TypedDict
- num_prompts_per_step: int#
None
- num_generations_per_prompt: int#
None
- max_num_epochs: int#
None
- max_num_steps: int#
None
- max_rollout_turns: int#
None
- val_period: int#
None
- val_batch_size: int#
None
- val_at_start: bool#
None
- val_at_end: bool#
None
- max_val_samples: int#
None
- skip_reference_policy_logprobs_calculation: NotRequired[bool]#
None
- seed: int#
None
- overlong_filtering: bool#
None
- use_dynamic_sampling: bool#
None
- dynamic_sampling_max_gen_batches: NotRequired[int]#
None
- batch_multiplier: NotRequired[float]#
None
- ppo_epochs: int#
None
- reward_shaping: nemo_rl.algorithms.reward_functions.RewardShapingConfig#
None
- reward_scaling: nemo_rl.algorithms.grpo.RewardScalingConfig#
None
- calculate_advantages_on_gpu: NotRequired[bool]#
None
- adv_estimator: nemo_rl.algorithms.ppo.AdvEstimatorConfig#
None
- policy_training_start_step: NotRequired[int]#
None
- class nemo_rl.algorithms.ppo.PPOSaveState#
Bases: typing.TypedDict
- consumed_samples: int#
None
- current_step: int#
None
- current_epoch: int#
None
- total_steps: int#
None
- total_valid_tokens: int#
None
- val_reward: NotRequired[float]#
None
- nemo_rl.algorithms.ppo._default_ppo_save_state() → nemo_rl.algorithms.ppo.PPOSaveState#
- class nemo_rl.algorithms.ppo.PPOLoggerConfig#
Bases: nemo_rl.utils.logger.LoggerConfig
- num_val_samples_to_print: int#
None
- class nemo_rl.algorithms.ppo.MasterConfig#
Bases: pydantic.BaseModel
- policy: nemo_rl.models.policy.PolicyConfig#
None
- value: nemo_rl.models.value.ValueConfig#
None
- loss_fn: nemo_rl.algorithms.loss.ClippedPGLossConfig#
None
- value_loss_fn: nemo_rl.algorithms.loss.loss_functions.MseValueLossConfig#
None
- env: dict[str, Any]#
None
- data: nemo_rl.data.DataConfig#
None
- ppo: nemo_rl.algorithms.ppo.PPOConfig#
None
- logger: nemo_rl.algorithms.ppo.PPOLoggerConfig#
None
- cluster: nemo_rl.distributed.virtual_cluster.ClusterConfig#
None
- checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig#
None
- nemo_rl.algorithms.ppo.setup(
- master_config: nemo_rl.algorithms.ppo.MasterConfig,
- tokenizer: nemo_rl.algorithms.ppo.TokenizerType,
- dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
- val_dataset: Optional[nemo_rl.data.datasets.AllTaskProcessedDataset],
- processor: Optional[transformers.AutoProcessor] = None,
)
Main entry point for running PPO algorithm.
- Returns:
tuple of (policy, policy_generation, value_model, clusters, dataloader, val_dataloader, loss_fn, value_loss_fn, logger, checkpointer, ppo_save_state, master_config).
- nemo_rl.algorithms.ppo.dynamic_sampling(
- repeated_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
- std: torch.Tensor,
- baseline: torch.Tensor,
- dynamic_sampling_num_gen_batches: int,
- master_config: nemo_rl.algorithms.ppo.MasterConfig,
- timer: nemo_rl.utils.timer.Timer,
- batch_cache: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec] = None,
)
Implements the dynamic sampling algorithm to select prompts with non-zero standard deviation.
This function filters the current batch to retain only those prompts that have a non-zero standard deviation. The required batch size is num_prompts_per_step * num_generations_per_prompt.
If the filtered batch is smaller than the required batch size, it is stored in batch_cache for use in later iterations, and is_batch_complete is set to False to indicate that the current batch is not enough to meet the required batch size.
If the filtered batch is larger than the required batch size, it is sliced so that its size is exactly num_prompts_per_step * num_generations_per_prompt.
The training loop uses is_batch_complete as a signal to either continue sampling or proceed to training. This approach is based on the dynamic sampling algorithm from the DAPO paper: https://arxiv.org/pdf/2503.14476.
- Parameters:
repeated_batch (BatchedDataDict[DatumSpec]) – The current batch of data containing prompts, responses, rewards, baselines, and std.
std (torch.Tensor) – Tensor representing the standard deviation for each prompt group.
baseline (torch.Tensor) – Baseline values for each prompt group.
dynamic_sampling_num_gen_batches (int) – Number of generation batches processed at the current step.
master_config (MasterConfig) – Configuration containing PPO and policy settings.
batch_cache (BatchedDataDict[DatumSpec], optional) – Cache storing previously selected prompts with non-zero std.
- Returns:
A tuple containing:
- repeated_batch (BatchedDataDict[DatumSpec]): Updated batch with selected prompts.
- is_batch_complete (bool): Indicates if the batch has enough samples with non-zero std for training.
- batch_cache (BatchedDataDict[DatumSpec]): Updated cache for future iterations.
- Return type:
tuple
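The filter, cache, and slice behaviour described for dynamic_sampling can be sketched in plain Python. This is a simplified illustration, not the library implementation: dynamic_sampling_sketch is a hypothetical name, samples are plain dicts with a per-group "std" value instead of a BatchedDataDict, and what the real function returns as the batch while it is still incomplete is an assumption here.

```python
from typing import List, Optional, Tuple


def dynamic_sampling_sketch(
    samples: List[dict],
    required_batch_size: int,
    cache: Optional[List[dict]] = None,
) -> Tuple[List[dict], bool, List[dict]]:
    """Keep samples whose prompt group has non-zero std; cache until full."""
    # Keep only samples from prompt groups with a non-zero standard deviation.
    kept = [s for s in samples if s["std"] != 0.0]
    # Merge with what earlier generation batches already contributed.
    cache = (cache or []) + kept
    if len(cache) < required_batch_size:
        # Not enough yet: the caller should generate another batch.
        return cache, False, cache
    # Enough (or too many): slice down to exactly the required size.
    return cache[:required_batch_size], True, cache


num_prompts_per_step, num_generations_per_prompt = 2, 2
required = num_prompts_per_step * num_generations_per_prompt

# First generation batch: only one prompt group (ids 2, 3) has non-zero std.
first = [
    {"id": 0, "std": 0.0}, {"id": 1, "std": 0.0},
    {"id": 2, "std": 0.5}, {"id": 3, "std": 0.5},
]
batch1, done1, cache1 = dynamic_sampling_sketch(first, required)

# Second generation batch: both groups are usable, so the batch is over-full.
second = [
    {"id": 4, "std": 0.3}, {"id": 5, "std": 0.3},
    {"id": 6, "std": 0.1}, {"id": 7, "std": 0.1},
]
batch2, done2, cache2 = dynamic_sampling_sketch(second, required, cache1)
```

After the first call the batch is incomplete and two samples sit in the cache; after the second call the batch is complete and holds the first four usable samples.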
- nemo_rl.algorithms.ppo._create_advantage_estimator(
- master_config: nemo_rl.algorithms.ppo.MasterConfig,
)
Create and return an advantage estimator based on configuration.
PPO’s training loop consumes an (advantages, returns) pair from a value-model-based estimator, so only gae and raw_reward are supported here. Group-relative estimators like GRPO / Reinforce++ are not compatible with PPO’s loop and live in grpo.py.
- Parameters:
master_config – The master configuration dictionary.
- Returns:
A GeneralizedAdvantageEstimator or RawRewardAdvantageEstimator instance.
- Raises:
ValueError – If the advantage estimator name is not recognized.
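To make the (advantages, returns) pair concrete, here is the textbook GAE recursion in plain Python, together with a name check that raises ValueError for unsupported estimators. Both functions are hypothetical illustrations and do not reproduce GeneralizedAdvantageEstimator or the real factory; mapping gamma and lam to gae_gamma and gae_lambda is an assumption.

```python
from typing import List, Sequence, Tuple

SUPPORTED_ESTIMATORS = ("gae", "raw_reward")  # names taken from the text above


def check_estimator_name(name: str) -> str:
    """Reject estimator names other than the two supported ones."""
    if name not in SUPPORTED_ESTIMATORS:
        raise ValueError(f"Unknown advantage estimator: {name!r}")
    return name


def gae_sketch(
    rewards: Sequence[float],
    values: Sequence[float],
    gamma: float = 1.0,
    lam: float = 0.95,
    bootstrap_value: float = 0.0,
) -> Tuple[List[float], List[float]]:
    """Standard GAE: A_t = delta_t + gamma * lam * A_{t+1}."""
    advantages = [0.0] * len(rewards)
    next_value = bootstrap_value  # value estimate after the last step
    running = 0.0
    for t in reversed(range(len(rewards))):
        # One-step TD error using the value model's estimates.
        delta = rewards[t] + gamma * next_value - values[t]
        running = delta + gamma * lam * running
        advantages[t] = running
        next_value = values[t]
    # Returns are the regression targets for the value model.
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


# With gamma = lam = 1 the advantage is (sum of future rewards) - value.
advantages, returns = gae_sketch(
    rewards=[0.0, 0.0, 1.0], values=[0.5, 0.5, 0.5], gamma=1.0, lam=1.0
)
```

With a single terminal reward of 1.0 and a constant value estimate of 0.5, every advantage is 0.5 and every return is 1.0.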
- nemo_rl.algorithms.ppo.ppo_train(
- policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
- policy_generation: Optional[nemo_rl.models.generation.interfaces.GenerationInterface],
- value_model: nemo_rl.models.value.interfaces.ValueInterface,
- dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
- val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
- tokenizer: nemo_rl.algorithms.ppo.TokenizerType,
- loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
- value_loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
- task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
- val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
- logger: nemo_rl.utils.logger.Logger,
- checkpointer: nemo_rl.utils.checkpoint.CheckpointManager,
- ppo_save_state: nemo_rl.algorithms.ppo.PPOSaveState,
- master_config: nemo_rl.algorithms.ppo.MasterConfig,
)
Run PPO training algorithm.
Based on the grpo_train loop with PPO-specific modifications:
- Value model inference and training (actor-critic)
- GAE advantage estimation with value bootstrap
- Multiple training steps per rollout (ppo_epochs)
- Configurable start of policy training (see policy_training_start_step)
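The interaction between ppo_epochs and a delayed policy start can be pictured with a small scheduling sketch. Everything here is an assumption for illustration: training_schedule_sketch is a hypothetical helper, and the ordering of value and policy updates, as well as the comparison against policy_training_start_step, may differ from what ppo_train actually does.

```python
from typing import List, Tuple


def training_schedule_sketch(
    max_num_steps: int,
    ppo_epochs: int,
    policy_training_start_step: int = 0,
) -> List[Tuple[int, int, str]]:
    """Return (step, epoch, component) update events in order."""
    events: List[Tuple[int, int, str]] = []
    for step in range(max_num_steps):
        # One rollout and one advantage computation would happen here per step.
        for epoch in range(ppo_epochs):
            # Several optimisation passes reuse the same rollout data.
            events.append((step, epoch, "value"))
            # Assumption: policy updates are skipped until the start step,
            # giving the value model time to warm up.
            if step >= policy_training_start_step:
                events.append((step, epoch, "policy"))
    return events


events = training_schedule_sketch(
    max_num_steps=2, ppo_epochs=2, policy_training_start_step=1
)
```

In this sketch, step 0 produces only value updates, while step 1 produces a value and a policy update in each of its two epochs.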
- nemo_rl.algorithms.ppo.validate(
- policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
- val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
- tokenizer,
- val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
- step: int,
- master_config: nemo_rl.algorithms.ppo.MasterConfig,
- logger: Optional[nemo_rl.utils.logger.Logger] = None,
)
Run validation on the validation dataset.