nemo_rl.algorithms.ppo#

Module Contents#

Classes#

AdvEstimatorConfig

Configuration for PPO advantage estimator (GAE or raw_reward).

PPOConfig

PPOSaveState

PPOLoggerConfig

MasterConfig

Functions#

_default_ppo_save_state

setup

Main entry point for running the PPO algorithm.

dynamic_sampling

Implements the dynamic sampling algorithm to select prompts with non-zero standard deviation.

_create_advantage_estimator

Create and return an advantage estimator based on configuration.

ppo_train

Run PPO training algorithm.

validate

Run validation on the validation dataset.

Data#

API#

nemo_rl.algorithms.ppo.TokenizerType#

‘TypeVar(…)’

class nemo_rl.algorithms.ppo.AdvEstimatorConfig#

Bases: typing.TypedDict

Configuration for PPO advantage estimator (GAE or raw_reward).

name: str#

None

gae_lambda: NotRequired[float]#

None

gae_gamma: NotRequired[float]#

None

normalize_advantages: NotRequired[bool]#

None

gae_lambda_value: NotRequired[Optional[float]]#

None

gae_lambda_policy: NotRequired[Optional[float]]#

None

length_adaptive_alpha: NotRequired[float]#

None
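
A minimal sketch of how an estimator configuration with these keys might be written. The TypedDict is mirrored locally (under the hypothetical name AdvEstimatorConfigSketch) so the snippet runs without NeMo RL installed, and the numeric values are illustrative assumptions, not library defaults.

```python
from typing import Optional, TypedDict


class _RequiredKeys(TypedDict):
    # The only required key; the docs name "gae" and "raw_reward" as supported.
    name: str


class AdvEstimatorConfigSketch(_RequiredKeys, total=False):
    # Local stand-in for AdvEstimatorConfig; every key below is optional.
    gae_lambda: float
    gae_gamma: float
    normalize_advantages: bool
    gae_lambda_value: Optional[float]
    gae_lambda_policy: Optional[float]
    length_adaptive_alpha: float


# Illustrative values only (assumed, not NeMo RL defaults).
gae_cfg: AdvEstimatorConfigSketch = {
    "name": "gae",
    "gae_lambda": 0.95,
    "gae_gamma": 1.0,
    "normalize_advantages": True,
}

# The optional keys can simply be left out.
raw_cfg: AdvEstimatorConfigSketch = {"name": "raw_reward"}
```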

class nemo_rl.algorithms.ppo.PPOConfig#

Bases: typing.TypedDict

num_prompts_per_step: int#

None

num_generations_per_prompt: int#

None

max_num_epochs: int#

None

max_num_steps: int#

None

max_rollout_turns: int#

None

val_period: int#

None

val_batch_size: int#

None

val_at_start: bool#

None

val_at_end: bool#

None

max_val_samples: int#

None

skip_reference_policy_logprobs_calculation: NotRequired[bool]#

None

seed: int#

None

overlong_filtering: bool#

None

use_dynamic_sampling: bool#

None

dynamic_sampling_max_gen_batches: NotRequired[int]#

None

batch_multiplier: NotRequired[float]#

None

ppo_epochs: int#

None

reward_shaping: nemo_rl.algorithms.reward_functions.RewardShapingConfig#

None

reward_scaling: nemo_rl.algorithms.grpo.RewardScalingConfig#

None

calculate_advantages_on_gpu: NotRequired[bool]#

None

adv_estimator: nemo_rl.algorithms.ppo.AdvEstimatorConfig#

None

policy_training_start_step: NotRequired[int]#

None
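
As a small worked example, the number of samples gathered per step is num_prompts_per_step * num_generations_per_prompt, which is the required batch size referred to in the dynamic_sampling description. The sketch below uses a plain dict with illustrative values; they are assumptions, not defaults.

```python
# Illustrative values only; these are not NeMo RL defaults.
ppo_cfg = {
    "num_prompts_per_step": 32,
    "num_generations_per_prompt": 8,
}

# Each prompt is sampled num_generations_per_prompt times,
# so one step collects prompts * generations samples.
required_batch_size = (
    ppo_cfg["num_prompts_per_step"] * ppo_cfg["num_generations_per_prompt"]
)
print(required_batch_size)
```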

class nemo_rl.algorithms.ppo.PPOSaveState#

Bases: typing.TypedDict

consumed_samples: int#

None

current_step: int#

None

current_epoch: int#

None

total_steps: int#

None

total_valid_tokens: int#

None

val_reward: NotRequired[float]#

None

nemo_rl.algorithms.ppo._default_ppo_save_state() → nemo_rl.algorithms.ppo.PPOSaveState#

class nemo_rl.algorithms.ppo.PPOLoggerConfig#

Bases: nemo_rl.utils.logger.LoggerConfig

num_val_samples_to_print: int#

None

class nemo_rl.algorithms.ppo.MasterConfig#

Bases: pydantic.BaseModel

policy: nemo_rl.models.policy.PolicyConfig#

None

value: nemo_rl.models.value.ValueConfig#

None

loss_fn: nemo_rl.algorithms.loss.ClippedPGLossConfig#

None

value_loss_fn: nemo_rl.algorithms.loss.loss_functions.MseValueLossConfig#

None

env: dict[str, Any]#

None

data: nemo_rl.data.DataConfig#

None

ppo: nemo_rl.algorithms.ppo.PPOConfig#

None

logger: nemo_rl.algorithms.ppo.PPOLoggerConfig#

None

cluster: nemo_rl.distributed.virtual_cluster.ClusterConfig#

None

checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig#

None

nemo_rl.algorithms.ppo.setup(
master_config: nemo_rl.algorithms.ppo.MasterConfig,
tokenizer: nemo_rl.algorithms.ppo.TokenizerType,
dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
val_dataset: Optional[nemo_rl.data.datasets.AllTaskProcessedDataset],
processor: Optional[transformers.AutoProcessor] = None,
) → tuple[nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, Optional[nemo_rl.models.generation.interfaces.GenerationInterface], nemo_rl.models.value.interfaces.ValueInterface, tuple[nemo_rl.distributed.virtual_cluster.RayVirtualCluster, nemo_rl.distributed.virtual_cluster.RayVirtualCluster], torchdata.stateful_dataloader.StatefulDataLoader, Optional[torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss.ClippedPGLossFn, nemo_rl.algorithms.loss.loss_functions.MseValueLossFn, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.ppo.PPOSaveState, nemo_rl.algorithms.ppo.MasterConfig]#

Main entry point for running the PPO algorithm.

Returns:

tuple of (policy, policy_generation, value_model, clusters, dataloader, val_dataloader, loss_fn, value_loss_fn, logger, checkpointer, ppo_save_state, master_config).
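
Because the return value is a 12-element positional tuple, it is easy to unpack it in the wrong order. One way to guard against that is to wrap it in a named tuple. The sketch below is an assumption for illustration: SetupResult is a hypothetical helper, not part of nemo_rl, and a placeholder tuple stands in for the real return value so the snippet runs on its own.

```python
from typing import Any, NamedTuple


class SetupResult(NamedTuple):
    # Hypothetical wrapper; field order follows the documented return tuple.
    policy: Any
    policy_generation: Any  # documented as Optional
    value_model: Any
    clusters: Any
    dataloader: Any
    val_dataloader: Any  # documented as Optional
    loss_fn: Any
    value_loss_fn: Any
    logger: Any
    checkpointer: Any
    ppo_save_state: Any
    master_config: Any


# Placeholder standing in for the tuple that setup(...) would return.
raw_result = tuple(f"<{name}>" for name in SetupResult._fields)

result = SetupResult(*raw_result)
print(result.value_model)  # access by name instead of by index 2
```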

nemo_rl.algorithms.ppo.dynamic_sampling(
repeated_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
std: torch.Tensor,
baseline: torch.Tensor,
dynamic_sampling_num_gen_batches: int,
master_config: nemo_rl.algorithms.ppo.MasterConfig,
timer: nemo_rl.utils.timer.Timer,
batch_cache: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec] = None,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec]#

Implements the dynamic sampling algorithm to select prompts with non-zero standard deviation.

This function filters the current batch, keeping only the prompts that have a non-zero standard deviation. The required batch size is num_prompts_per_step * num_generations_per_prompt. If the current batch has fewer prompts with non-zero standard deviation than required, they are stored in batch_cache for use in later iterations, and is_batch_complete is set to False to indicate that the current batch is not enough to meet the required batch size. If the current batch has more than required, it is sliced down to the required batch size. The training loop uses is_batch_complete as the signal to either continue sampling or proceed to training. This approach is based on the dynamic sampling algorithm from the DAPO paper: https://arxiv.org/pdf/2503.14476.

Parameters:
  • repeated_batch (BatchedDataDict[DatumSpec]) – The current batch of data containing prompts, responses, rewards, baselines, and std.

  • std (torch.Tensor) – Tensor representing the standard deviation for each prompt group.

  • baseline (torch.Tensor) – Baseline values for each prompt group.

  • dynamic_sampling_num_gen_batches (int) – Number of generation batches processed at the current step.

  • master_config (MasterConfig) – Configuration containing PPO and policy settings.

  • batch_cache (BatchedDataDict[DatumSpec], optional) – Cache storing previously selected prompts with non-zero std.

Returns:

A tuple containing:

  • repeated_batch (BatchedDataDict[DatumSpec]) – Updated batch with selected prompts.

  • is_batch_complete (bool) – Indicates if the batch has enough samples with non-zero std for training.

  • batch_cache (BatchedDataDict[DatumSpec]) – Updated cache for future iterations.

Return type:

tuple
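
The filter-and-accumulate behaviour described above can be sketched in plain Python. This is a simplified illustration under stated assumptions, not the library implementation: dynamic_sampling_sketch is a hypothetical name, prompt groups are modelled as (prompt_id, rewards) pairs instead of a BatchedDataDict, the threshold is counted in prompt groups instead of samples, and surplus groups are simply dropped once the batch is complete.

```python
from statistics import pstdev


def dynamic_sampling_sketch(groups, required_prompts, cache=None):
    """Keep prompt groups with non-zero reward std until enough are collected.

    groups: list of (prompt_id, [rewards]) pairs from the latest generation batch.
    Returns (batch_or_None, is_batch_complete, cache).
    """
    cache = list(cache or [])
    # Groups whose rewards are all identical have std == 0 and are skipped.
    cache.extend(g for g in groups if pstdev(g[1]) > 0)
    if len(cache) < required_prompts:
        # Not enough yet: keep what we have and ask the caller to sample more.
        return None, False, cache
    # Enough groups: slice to the required size (surplus dropped in this sketch).
    return cache[:required_prompts], True, []


# First generation batch: only p1 has varied rewards, so the batch is incomplete.
first = [("p0", [1.0, 1.0]), ("p1", [0.0, 1.0])]
batch1, complete1, cache1 = dynamic_sampling_sketch(first, required_prompts=2)

# Second generation batch: p3 and p4 qualify, and the cached p1 is reused.
second = [("p2", [0.0, 0.0]), ("p3", [1.0, 0.0]), ("p4", [0.5, 1.0])]
batch2, complete2, cache2 = dynamic_sampling_sketch(second, 2, cache1)
```

In a training loop, the caller would keep generating and calling the function until the completeness flag is True, subject to a cap such as dynamic_sampling_max_gen_batches.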

nemo_rl.algorithms.ppo._create_advantage_estimator(
master_config: nemo_rl.algorithms.ppo.MasterConfig,
)#

Create and return an advantage estimator based on configuration.

PPO’s training loop consumes an (advantages, returns) pair from a value-model-based estimator, so only gae and raw_reward are supported here. Group-relative estimators such as GRPO and Reinforce++ are not compatible with PPO’s loop and live in grpo.py.

Parameters:

master_config – The master configuration dictionary.

Returns:

A GeneralizedAdvantageEstimator or RawRewardAdvantageEstimator instance.

Raises:

ValueError – If the advantage estimator name is not recognized.
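
A minimal sketch of the name-based dispatch described above. The stub classes and the function name are hypothetical stand-ins so the snippet is self-contained. The constructor arguments and the 1.0 fallback values are assumptions, not the real estimator signatures or defaults.

```python
from dataclasses import dataclass


@dataclass
class GaeEstimatorStub:
    # Stand-in for a GAE estimator; argument names are assumed.
    gamma: float
    lam: float


@dataclass
class RawRewardEstimatorStub:
    # Stand-in for a raw-reward estimator.
    pass


def create_advantage_estimator_sketch(adv_cfg: dict):
    """Return an estimator stub chosen by adv_cfg["name"]."""
    name = adv_cfg["name"]
    if name == "gae":
        return GaeEstimatorStub(
            gamma=adv_cfg.get("gae_gamma", 1.0),  # assumed fallback
            lam=adv_cfg.get("gae_lambda", 1.0),   # assumed fallback
        )
    if name == "raw_reward":
        return RawRewardEstimatorStub()
    # Anything else, including group-relative estimators, is rejected.
    raise ValueError(
        f"Unrecognized advantage estimator {name!r}; expected 'gae' or 'raw_reward'"
    )


estimator = create_advantage_estimator_sketch({"name": "gae", "gae_lambda": 0.95})
```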

nemo_rl.algorithms.ppo.ppo_train(
policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
policy_generation: Optional[nemo_rl.models.generation.interfaces.GenerationInterface],
value_model: nemo_rl.models.value.interfaces.ValueInterface,
dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
tokenizer: nemo_rl.algorithms.ppo.TokenizerType,
loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
value_loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
logger: nemo_rl.utils.logger.Logger,
checkpointer: nemo_rl.utils.checkpoint.CheckpointManager,
ppo_save_state: nemo_rl.algorithms.ppo.PPOSaveState,
master_config: nemo_rl.algorithms.ppo.MasterConfig,
) → None#

Run the PPO training algorithm.

Based on the grpo_train loop with PPO-specific modifications:

  • Value model inference and training (actor-critic)

  • GAE advantage estimation with value bootstrap

  • Multiple training steps per rollout (ppo_epochs)

  • Configurable policy training start step (policy_training_start_step)
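
To make the GAE bullet concrete, here is a sketch of the textbook Generalized Advantage Estimation recursion for a single trajectory, using plain Python lists. It is not taken from the NeMo RL source: gae_sketch and its parameter names are hypothetical, and the real estimator may differ in batching, masking, and normalization.

```python
def gae_sketch(rewards, values, bootstrap_value=0.0, gamma=1.0, lam=0.95):
    """Compute (advantages, returns) for one trajectory.

    rewards[t] and values[t] are the per-step reward and value estimate;
    bootstrap_value is the value estimate for the state after the last step.
    """
    advantages = [0.0] * len(rewards)
    next_value = bootstrap_value
    running = 0.0
    # Walk backwards so each step can reuse the advantage of the next step.
    for t in reversed(range(len(rewards))):
        # TD residual: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_value - values[t]
        # A_t = delta_t + gamma * lambda * A_{t+1}
        running = delta + gamma * lam * running
        advantages[t] = running
        next_value = values[t]
    # Value-function regression targets: returns_t = A_t + V(s_t)
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


# Sparse terminal reward with a constant value estimate of 0.5.
adv, ret = gae_sketch([0.0, 0.0, 1.0], [0.5, 0.5, 0.5], gamma=1.0, lam=1.0)
```

With gamma = 1 and lambda = 1 the returns reduce to the undiscounted reward-to-go, which is a convenient sanity check for any implementation.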

nemo_rl.algorithms.ppo.validate(
policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
tokenizer,
val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
step: int,
master_config: nemo_rl.algorithms.ppo.MasterConfig,
logger: Optional[nemo_rl.utils.logger.Logger] = None,
) → tuple[dict[str, Any], dict[str, Any]]#

Run validation on the validation dataset.