nemo_rl.algorithms.grpo#

Module Contents#

Classes#

Functions#

_default_grpo_save_state

setup

Main entry point for running the GRPO algorithm.

dynamic_sampling

Implements the dynamic sampling algorithm to select prompts with non-zero standard deviation.

scale_rewards

Linearly scales rewards from a source range to a target range.

_should_use_async_rollouts

Determine if async rollouts should be used based on the configuration.

refit_policy_generation

Refit the policy generation interface with the latest policy weights.

grpo_train

Run the GRPO training algorithm.

validate

Run validation on the validation dataset.

async_grpo_train

Run asynchronous GRPO training with a replay buffer.

Data#

API#

nemo_rl.algorithms.grpo.TokenizerType#

'TypeVar(…)'

class nemo_rl.algorithms.grpo.RewardScalingConfig#

Bases: typing.TypedDict

Configure linear reward scaling with clamping.

When enabled is True, each reward is clamped to the source interval [source_min, source_max] and linearly mapped to the target interval [target_min, target_max]. Refer to the scale_rewards function for the implementation.

Defaults: source_min=0.0, source_max=1.0, target_min=0.0, target_max=1.0

enabled: bool#

None

source_min: NotRequired[float]#

None

source_max: NotRequired[float]#

None

target_min: NotRequired[float]#

None

target_max: NotRequired[float]#

None
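
A minimal sketch of what a RewardScalingConfig might look like in a recipe; the values below are illustrative assumptions, not library defaults (the documented defaults map [0.0, 1.0] onto itself).

```python
# Illustrative RewardScalingConfig; these values are assumptions for the example,
# not defaults taken from any shipped recipe.
reward_scaling = {
    "enabled": True,
    "source_min": 0.0,   # rewards are clamped to [source_min, source_max] ...
    "source_max": 1.0,
    "target_min": -1.0,  # ... and then linearly mapped to [target_min, target_max]
    "target_max": 1.0,
}
```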

class nemo_rl.algorithms.grpo.AsyncGRPOConfig#

Bases: typing.TypedDict

enabled: bool#

None

max_trajectory_age_steps: int#

None

class nemo_rl.algorithms.grpo.GRPOConfig#

Bases: typing.TypedDict

num_prompts_per_step: int#

None

num_generations_per_prompt: int#

None

max_num_epochs: int#

None

max_num_steps: int#

None

max_rollout_turns: int#

None

normalize_rewards: bool#

None

use_leave_one_out_baseline: bool#

None

val_period: int#

None

val_batch_size: int#

None

val_at_start: bool#

None

max_val_samples: int#

None

seed: int#

None

async_grpo: NotRequired[nemo_rl.algorithms.grpo.AsyncGRPOConfig]#

None

overlong_filtering: NotRequired[bool]#

None

use_dynamic_sampling: bool#

None

dynamic_sampling_max_gen_batches: NotRequired[int]#

None

batch_multiplier: NotRequired[float]#

None

reward_shaping: nemo_rl.algorithms.reward_functions.RewardShapingConfig#

None

reward_scaling: nemo_rl.algorithms.grpo.RewardScalingConfig#

None
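
For orientation, a hypothetical GRPOConfig dictionary with the fields listed above; every value is a placeholder chosen for illustration, not a recommended or default setting.

```python
# Hypothetical GRPOConfig; every value is a placeholder for illustration only.
grpo_config = {
    "num_prompts_per_step": 32,
    "num_generations_per_prompt": 8,
    "max_num_epochs": 1,
    "max_num_steps": 1000,
    "max_rollout_turns": 1,
    "normalize_rewards": True,
    "use_leave_one_out_baseline": True,
    "val_period": 10,
    "val_batch_size": 32,
    "val_at_start": False,
    "max_val_samples": 256,
    "seed": 42,
    "use_dynamic_sampling": False,
    "reward_shaping": {},  # fields defined by nemo_rl.algorithms.reward_functions.RewardShapingConfig
    "reward_scaling": {"enabled": False},
    # Optional (NotRequired) fields:
    "async_grpo": {"enabled": False, "max_trajectory_age_steps": 1},
    "overlong_filtering": False,
    # "dynamic_sampling_max_gen_batches": ...,  # only relevant with use_dynamic_sampling=True
    # "batch_multiplier": ...,
}
```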

class nemo_rl.algorithms.grpo.GRPOSaveState#

Bases: typing.TypedDict

consumed_samples: int#

None

current_step: int#

None

current_epoch: int#

None

total_steps: int#

None

total_valid_tokens: int#

None

val_reward: NotRequired[float]#

None

nemo_rl.algorithms.grpo._default_grpo_save_state() nemo_rl.algorithms.grpo.GRPOSaveState#

class nemo_rl.algorithms.grpo.GRPOLoggerConfig#

Bases: nemo_rl.utils.logger.LoggerConfig

num_val_samples_to_print: int#

None

class nemo_rl.algorithms.grpo.MasterConfig#

Bases: typing.TypedDict

policy: nemo_rl.models.policy.PolicyConfig#

None

loss_fn: nemo_rl.algorithms.loss_functions.ClippedPGLossConfig#

None

env: dict[str, Any]#

None

data: nemo_rl.data.DataConfig#

None

grpo: nemo_rl.algorithms.grpo.GRPOConfig#

None

logger: nemo_rl.algorithms.grpo.GRPOLoggerConfig#

None

cluster: nemo_rl.distributed.virtual_cluster.ClusterConfig#

None

checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig#

None

nemo_rl.algorithms.grpo.setup(
master_config: nemo_rl.algorithms.grpo.MasterConfig,
tokenizer: nemo_rl.algorithms.grpo.TokenizerType,
dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
val_dataset: Optional[nemo_rl.data.datasets.AllTaskProcessedDataset],
processor: Optional[transformers.AutoProcessor] = None,
) tuple[nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, Optional[nemo_rl.models.generation.interfaces.GenerationInterface], tuple[nemo_rl.distributed.virtual_cluster.RayVirtualCluster, nemo_rl.distributed.virtual_cluster.RayVirtualCluster], torchdata.stateful_dataloader.StatefulDataLoader, Optional[torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.ClippedPGLossFn, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.grpo.GRPOSaveState, nemo_rl.algorithms.grpo.MasterConfig]#

Main entry point for running the GRPO algorithm.

Returns:

tuple of policy, policy_generation, (train_cluster, inference_cluster), dataloader, val_dataloader, loss_fn, logger, checkpointer, grpo_save_state, and master_config
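
A hedged sketch of how the return value might be unpacked, following the ten-element tuple in the annotated signature; the element names (and the config, tokenizer, and dataset objects) are assumptions for illustration.

```python
from nemo_rl.algorithms.grpo import setup

# Sketch only: master_config, tokenizer, dataset, and val_dataset are assumed to
# have been built elsewhere (e.g. from a YAML recipe and the data pipeline).
(
    policy,                              # ColocatablePolicyInterface
    policy_generation,                   # Optional[GenerationInterface]
    (train_cluster, inference_cluster),  # pair of RayVirtualCluster objects (names assumed)
    dataloader,                          # StatefulDataLoader
    val_dataloader,                      # Optional[StatefulDataLoader]
    loss_fn,                             # ClippedPGLossFn
    logger,                              # Logger
    checkpointer,                        # CheckpointManager
    grpo_save_state,                     # GRPOSaveState
    master_config,                       # MasterConfig (echoed back)
) = setup(master_config, tokenizer, dataset, val_dataset)
```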

nemo_rl.algorithms.grpo.dynamic_sampling(
repeated_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
std: torch.Tensor,
baseline: torch.Tensor,
dynamic_sampling_num_gen_batches: int,
master_config: nemo_rl.algorithms.grpo.MasterConfig,
timer: nemo_rl.utils.timer.Timer,
batch_cache: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec] = None,
) nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec]#

Implements the dynamic sampling algorithm to select prompts with non-zero standard deviation.

This function filters the current batch to retain only those prompts with non-zero reward standard deviation. If the number of retained prompts is smaller than the required batch size, defined as num_prompts_per_step * num_generations_per_prompt, the retained samples are stored in batch_cache for use in later iterations and is_batch_complete is set to False to indicate that the current batch is not yet large enough; the GRPO training loop uses this signal to decide whether to continue sampling or proceed to training. If the number of retained prompts exceeds the required batch size, the batch is sliced so that its size is exactly num_prompts_per_step * num_generations_per_prompt. This approach is based on the dynamic sampling algorithm from the DAPO paper: https://arxiv.org/pdf/2503.14476.

Parameters:
  • repeated_batch (BatchedDataDict[DatumSpec]) – The current batch of data containing prompts, responses, rewards, baselines, and std.

  • std (torch.Tensor) – Tensor representing the standard deviation for each prompt group.

  • baseline (torch.Tensor) – Baseline values for each prompt group.

  • dynamic_sampling_num_gen_batches (int) – Number of generation batches processed at the current step.

  • master_config (MasterConfig) – Configuration containing GRPO and policy settings.

  • batch_cache (BatchedDataDict[DatumSpec], optional) – Cache storing previously selected prompts with non-zero std.

Returns:

A tuple containing:

  • repeated_batch (BatchedDataDict[DatumSpec]) – Updated batch with selected prompts.

  • is_batch_complete (bool) – Indicates whether the batch has enough samples with non-zero std for training.

  • batch_cache (BatchedDataDict[DatumSpec]) – Updated cache for future iterations.

Return type:

tuple
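
To make the filtering step concrete, here is a simplified, standalone sketch of the selection logic using plain tensors instead of BatchedDataDict; the function name, arguments, and return convention are assumptions for illustration, not the library implementation.

```python
import torch

def filter_nonzero_std(std, num_cached, num_prompts_per_step, num_generations_per_prompt):
    """Simplified DAPO-style dynamic-sampling filter (illustrative only)."""
    required = num_prompts_per_step * num_generations_per_prompt
    keep_idx = torch.nonzero(std > 0, as_tuple=True)[0]  # samples whose group std is non-zero
    if num_cached + keep_idx.numel() < required:
        # Not enough yet: the caller would cache these samples and keep sampling.
        return keep_idx, False
    # Enough: take only what is needed to fill the required batch size.
    return keep_idx[: required - num_cached], True

# One std value per sample (repeated per generation); the values are made up.
std = torch.tensor([0.0, 0.0, 0.3, 0.3, 0.5, 0.5])
idx, is_batch_complete = filter_nonzero_std(std, num_cached=0,
                                            num_prompts_per_step=2,
                                            num_generations_per_prompt=2)
print(idx.tolist(), is_batch_complete)  # [2, 3, 4, 5] True
```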

nemo_rl.algorithms.grpo.scale_rewards(
repeated_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
reward_scaling_cfg: nemo_rl.algorithms.grpo.RewardScalingConfig,
) nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec]#

Linearly scales rewards from a source range to a target range.

If reward_scaling.enabled is True, each reward in repeated_batch["total_reward"] is clamped to the configured source interval [source_min, source_max] and then rescaled to the target interval [target_min, target_max].

Default configuration: source_min=0.0, source_max=1.0, target_min=0.0, target_max=1.0
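
As a concrete illustration of the clamp-then-rescale mapping, here is a minimal sketch on a plain reward tensor; the tensor values and the chosen source/target ranges are assumptions for the example.

```python
import torch

# Clamp to [source_min, source_max], then map linearly to [target_min, target_max].
rewards = torch.tensor([-0.2, 0.25, 0.5, 1.3])
source_min, source_max = 0.0, 1.0
target_min, target_max = -1.0, 1.0

clamped = rewards.clamp(source_min, source_max)
scaled = target_min + (clamped - source_min) * (target_max - target_min) / (source_max - source_min)
print(scaled.tolist())  # [-1.0, -0.5, 0.0, 1.0]
```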

nemo_rl.algorithms.grpo._should_use_async_rollouts(
master_config: nemo_rl.algorithms.grpo.MasterConfig,
) bool#

Determine if async rollouts should be used based on the configuration.

Returns True if vLLM backend is used with async_engine enabled.

nemo_rl.algorithms.grpo.refit_policy_generation(
policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
colocated_inference: bool,
_refit_buffer_size_gb: Optional[int] = None,
timer: Optional[nemo_rl.utils.timer.Timer] = None,
) None#

Refit the policy generation interface with the latest policy weights.

Parameters:
  • policy – The policy to provide weights to the inference engine.

  • policy_generation – The inference engine to refit.

  • _refit_buffer_size_gb – The size of the buffer to use for refitting. If None, the buffer size is computed from the remaining available memory. This parameter is primarily used for testing.

  • timer – Optional Timer used to time the prepare/transfer/update phase.
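
A usage sketch based only on the signature above; policy and policy_generation are assumed to come from setup, and colocated_inference reflects an assumed deployment choice.

```python
# Usage sketch: `policy` and `policy_generation` are assumed to come from setup().
refit_policy_generation(
    policy,                  # provides the latest weights
    policy_generation,       # inference engine to refresh
    colocated_inference=True,
    timer=None,              # optionally pass a nemo_rl.utils.timer.Timer
)
```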

nemo_rl.algorithms.grpo.grpo_train(
policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
policy_generation: Optional[nemo_rl.models.generation.interfaces.GenerationInterface],
dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
tokenizer: nemo_rl.algorithms.grpo.TokenizerType,
loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
logger: nemo_rl.utils.logger.Logger,
checkpointer: nemo_rl.utils.checkpoint.CheckpointManager,
grpo_save_state: nemo_rl.algorithms.grpo.GRPOSaveState,
master_config: nemo_rl.algorithms.grpo.MasterConfig,
processor: Optional[transformers.AutoProcessor] = None,
) None#

Run the GRPO training algorithm.
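
Continuing the setup sketch shown earlier, a hedged wiring of grpo_train's arguments; task_to_env and val_task_to_env are assumed to be dictionaries mapping task names to environments built elsewhere.

```python
# Wiring sketch: names come from the setup() example above; task_to_env and
# val_task_to_env are assumed to map task names to EnvironmentInterface instances.
grpo_train(
    policy,
    policy_generation,
    dataloader,
    val_dataloader,
    tokenizer,
    loss_fn,
    task_to_env,
    val_task_to_env,
    logger,
    checkpointer,
    grpo_save_state,
    master_config,
)
```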

nemo_rl.algorithms.grpo.validate(
policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
tokenizer,
val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
step: int,
master_config: nemo_rl.algorithms.grpo.MasterConfig,
) tuple[dict[str, Any], dict[str, Any]]#

Run validation on the validation dataset.

nemo_rl.algorithms.grpo.async_grpo_train(
policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
policy_generation: Optional[nemo_rl.models.generation.interfaces.GenerationInterface],
dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
tokenizer: nemo_rl.algorithms.grpo.TokenizerType,
loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
logger: nemo_rl.utils.logger.Logger,
checkpointer: nemo_rl.utils.checkpoint.CheckpointManager,
grpo_save_state: nemo_rl.algorithms.grpo.GRPOSaveState,
master_config: nemo_rl.algorithms.grpo.MasterConfig,
max_trajectory_age_steps: int = 1,
) None#

Run asynchronous GRPO training with a replay buffer.

Parameters:
  • policy – Training policy

  • policy_generation – Generation interface

  • dataloader – Training data loader

  • val_dataloader – Validation data loader

  • tokenizer – Tokenizer

  • loss_fn – Loss function

  • task_to_env – Training environments

  • val_task_to_env – Validation environments

  • logger – Logger

  • checkpointer – Checkpoint manager

  • grpo_save_state – Training state

  • master_config – Master configuration

  • max_trajectory_age_steps – Maximum age (in training steps) for trajectories to be used in training
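
As with grpo_train, a hedged sketch of how a caller might choose between the synchronous and asynchronous entry points; the assumption that the choice is driven by the async_grpo section of the GRPO config is illustrative, not a statement about the library's own driver script.

```python
# Dispatch sketch (assumption): use the async entry point when async GRPO is enabled.
async_cfg = master_config["grpo"].get("async_grpo", {"enabled": False})
if async_cfg.get("enabled", False):
    async_grpo_train(
        policy, policy_generation, dataloader, val_dataloader, tokenizer, loss_fn,
        task_to_env, val_task_to_env, logger, checkpointer, grpo_save_state, master_config,
        max_trajectory_age_steps=async_cfg.get("max_trajectory_age_steps", 1),
    )
else:
    grpo_train(
        policy, policy_generation, dataloader, val_dataloader, tokenizer, loss_fn,
        task_to_env, val_task_to_env, logger, checkpointer, grpo_save_state, master_config,
    )
```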