nemo_rl.algorithms.grpo#
Module Contents#
Classes#

| RewardScalingConfig | Configure linear reward scaling with clamping. |

Functions#

| setup | Main entry point for running the GRPO algorithm. |
| dynamic_sampling | Implements the dynamic sampling algorithm to select prompts with non-zero standard deviation. |
| scale_rewards | Linearly scales rewards from a source range to a target range. |
| _should_use_async_rollouts | Determine if async rollouts should be used based on the configuration. |
| refit_policy_generation | Refit the policy generation interface with the latest policy weights. |
| grpo_train | Run GRPO training algorithm. |
| validate | Run validation on the validation dataset. |
| async_grpo_train | Run asynchronous GRPO training with replay buffer. |
Data#
API#
- nemo_rl.algorithms.grpo.TokenizerType#
'TypeVar(…)'
- class nemo_rl.algorithms.grpo.RewardScalingConfig#
Bases: typing.TypedDict

Configure linear reward scaling with clamping.

When enabled is True, each reward is clamped to the source interval [source_min, source_max] and linearly mapped to the target interval [target_min, target_max]. Refer to the scale_rewards function for the implementation.

Defaults: source_min=0.0, source_max=1.0, target_min=0.0, target_max=1.0

Initialization

Initialize self. See help(type(self)) for accurate signature.
- enabled: bool#
None
- source_min: NotRequired[float]#
None
- source_max: NotRequired[float]#
None
- target_min: NotRequired[float]#
None
- target_max: NotRequired[float]#
None
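A minimal sketch of a RewardScalingConfig-shaped value, written as a plain dict with illustrative numbers (the values below are an assumption for demonstration, not library defaults beyond those documented above):

```python
# Illustrative RewardScalingConfig-shaped dict (values are hypothetical).
# Only `enabled` is required; the NotRequired keys fall back to the
# documented defaults (0.0 / 1.0 / 0.0 / 1.0) when omitted.
reward_scaling_cfg = {
    "enabled": True,
    "source_min": 0.0,   # rewards are first clamped to [source_min, source_max] ...
    "source_max": 1.0,
    "target_min": -1.0,  # ... then mapped linearly onto [target_min, target_max]
    "target_max": 1.0,
}
```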
- class nemo_rl.algorithms.grpo.AsyncGRPOConfig#
Bases: typing.TypedDict

- enabled: bool#
None
- max_trajectory_age_steps: int#
None
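For illustration, an AsyncGRPOConfig-shaped dict with hypothetical values; when enabled is True, training is expected to go through async_grpo_train with the given trajectory-age limit:

```python
# Illustrative AsyncGRPOConfig-shaped dict (values are hypothetical).
async_grpo_cfg = {
    "enabled": True,
    "max_trajectory_age_steps": 2,  # rollouts older than 2 policy steps are considered stale
}
```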
- class nemo_rl.algorithms.grpo.GRPOConfig#
Bases: typing.TypedDict

- num_prompts_per_step: int#
None
- num_generations_per_prompt: int#
None
- max_num_epochs: int#
None
- max_num_steps: int#
None
- max_rollout_turns: int#
None
- normalize_rewards: bool#
None
- use_leave_one_out_baseline: bool#
None
- val_period: int#
None
- val_batch_size: int#
None
- val_at_start: bool#
None
- max_val_samples: int#
None
- seed: int#
None
- async_grpo: NotRequired[nemo_rl.algorithms.grpo.AsyncGRPOConfig]#
None
- overlong_filtering: NotRequired[bool]#
None
- use_dynamic_sampling: bool#
None
- dynamic_sampling_max_gen_batches: NotRequired[int]#
None
- batch_multiplier: NotRequired[float]#
None
- reward_shaping: nemo_rl.algorithms.reward_functions.RewardShapingConfig#
None
- reward_scaling: nemo_rl.algorithms.grpo.RewardScalingConfig#
None
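The sketch below shows a GRPOConfig-shaped dict with illustrative values (the numbers are assumptions, not defaults shipped with the library); NotRequired keys such as async_grpo, overlong_filtering, dynamic_sampling_max_gen_batches, and batch_multiplier are omitted:

```python
# Illustrative GRPOConfig-shaped dict (all numbers are hypothetical).
grpo_cfg = {
    "num_prompts_per_step": 32,        # prompts sampled per training step
    "num_generations_per_prompt": 8,   # rollouts per prompt, i.e. the GRPO group size
    "max_num_epochs": 1,
    "max_num_steps": 1000,
    "max_rollout_turns": 1,
    "normalize_rewards": True,
    "use_leave_one_out_baseline": True,
    "val_period": 10,                  # run validation every 10 steps
    "val_batch_size": 32,
    "val_at_start": False,
    "max_val_samples": 256,
    "seed": 42,
    "use_dynamic_sampling": False,
    "reward_shaping": {},                  # placeholder; see RewardShapingConfig for its fields
    "reward_scaling": {"enabled": False},  # see RewardScalingConfig above
}
```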
- class nemo_rl.algorithms.grpo.GRPOSaveState#
Bases: typing.TypedDict

- consumed_samples: int#
None
- current_step: int#
None
- current_epoch: int#
None
- total_steps: int#
None
- total_valid_tokens: int#
None
- val_reward: NotRequired[float]#
None
- nemo_rl.algorithms.grpo._default_grpo_save_state() → nemo_rl.algorithms.grpo.GRPOSaveState#
- class nemo_rl.algorithms.grpo.GRPOLoggerConfig#
Bases: nemo_rl.utils.logger.LoggerConfig

- num_val_samples_to_print: int#
None
- class nemo_rl.algorithms.grpo.MasterConfig#
Bases: typing.TypedDict

- policy: nemo_rl.models.policy.PolicyConfig#
None
- loss_fn: nemo_rl.algorithms.loss_functions.ClippedPGLossConfig#
None
- env: dict[str, Any]#
None
- data: nemo_rl.data.DataConfig#
None
- grpo: nemo_rl.algorithms.grpo.GRPOConfig#
None
- logger: nemo_rl.algorithms.grpo.GRPOLoggerConfig#
None
- cluster: nemo_rl.distributed.virtual_cluster.ClusterConfig#
None
- checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig#
None
- nemo_rl.algorithms.grpo.setup(
- master_config: nemo_rl.algorithms.grpo.MasterConfig,
- tokenizer: nemo_rl.algorithms.grpo.TokenizerType,
- dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
- val_dataset: Optional[nemo_rl.data.datasets.AllTaskProcessedDataset],
- processor: Optional[transformers.AutoProcessor] = None,
Main entry point for running the GRPO algorithm.
- Returns:
tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, logger, master_config, val_dataloader
- nemo_rl.algorithms.grpo.dynamic_sampling(
- repeated_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
- std: torch.Tensor,
- baseline: torch.Tensor,
- dynamic_sampling_num_gen_batches: int,
- master_config: nemo_rl.algorithms.grpo.MasterConfig,
- timer: nemo_rl.utils.timer.Timer,
- batch_cache: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec] = None,
Implements the dynamic sampling algorithm to select prompts with non-zero standard deviation.
This function filters the current batch so that only prompts with a non-zero reward standard deviation are retained. If the filtered batch contains fewer samples than the required batch size, defined as num_prompts_per_step * num_generations_per_prompt, the selected samples are stored in batch_cache for use in later iterations and is_batch_complete is set to False, signaling the GRPO training loop to continue sampling rather than proceed to training. If the filtered batch contains more samples than the required batch size, it is sliced down to exactly num_prompts_per_step * num_generations_per_prompt. This approach is based on the dynamic sampling algorithm from the DAPO paper: https://arxiv.org/pdf/2503.14476. A simplified sketch of the filtering step is shown below the return description.
- Parameters:
repeated_batch (BatchedDataDict[DatumSpec]) – The current batch of data containing prompts, responses, rewards, baselines, and std.
std (torch.Tensor) – Tensor representing the standard deviation for each prompt group.
baseline (torch.Tensor) – Baseline values for each prompt group.
dynamic_sampling_num_gen_batches (int) – Number of generation batches processed at the current step.
master_config (MasterConfig) – Configuration containing GRPO and policy settings.
batch_cache (BatchedDataDict[DatumSpec], optional) – Cache storing previously selected prompts with non-zero std.
- Returns:
A tuple containing:
- repeated_batch (BatchedDataDict[DatumSpec]): Updated batch with selected prompts.
- is_batch_complete (bool): Indicates if the batch has enough samples with non-zero std for training.
- batch_cache (BatchedDataDict[DatumSpec]): Updated cache for future iterations.
- Return type:
tuple
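A simplified, self-contained sketch of the filtering idea described above (this is an assumption-level illustration, not the library's implementation; the function and variable names are invented for the example):

```python
import torch

def filter_nonzero_std(rewards: torch.Tensor, group_size: int, required: int):
    """Hypothetical sketch of DAPO-style dynamic sampling on a flat reward tensor.

    rewards: shape (num_prompts * group_size,), ordered by prompt group.
    Returns the indices of kept generations and whether enough were kept.
    """
    groups = rewards.view(-1, group_size)            # (num_prompts, group_size)
    keep_group = groups.std(dim=1) > 0               # prompt groups with reward variation
    keep_idx = keep_group.repeat_interleave(group_size).nonzero(as_tuple=True)[0]
    is_batch_complete = keep_idx.numel() >= required
    return keep_idx[:required], is_batch_complete    # slice down if we over-collected

# Example: 4 prompts x 2 generations each, requiring 4 usable generations.
rewards = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.5, 0.5, 0.0, 1.0])
idx, complete = filter_nonzero_std(rewards, group_size=2, required=4)
print(idx.tolist(), complete)  # [2, 3, 6, 7] True -> prompts 1 and 3 had non-zero std
```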
- nemo_rl.algorithms.grpo.scale_rewards(
- repeated_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
- reward_scaling_cfg: nemo_rl.algorithms.grpo.RewardScalingConfig,
Linearly scales rewards from a source range to a target range.
If reward_scaling.enabled is True, each reward in repeated_batch["total_reward"] is clamped to the configured source interval [source_min, source_max] and then rescaled to the target interval [target_min, target_max].

Default configuration: source_min = 0.0, source_max = 1.0, target_min = 0.0, target_max = 1.0
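As a worked illustration of the arithmetic described above, the following self-contained sketch applies the same clamp-then-rescale transform to a plain tensor (the real function operates on repeated_batch["total_reward"] and reads its bounds from the RewardScalingConfig; the helper name here is invented):

```python
import torch

def linear_scale(rewards: torch.Tensor, src_min=0.0, src_max=1.0,
                 tgt_min=0.0, tgt_max=1.0) -> torch.Tensor:
    """Clamp rewards to [src_min, src_max], then map them linearly onto [tgt_min, tgt_max]."""
    clamped = rewards.clamp(src_min, src_max)
    return (clamped - src_min) / (src_max - src_min) * (tgt_max - tgt_min) + tgt_min

# e.g. map a 0-to-1 correctness reward onto [-1, 1]
print(linear_scale(torch.tensor([0.0, 0.25, 1.0, 2.0]), tgt_min=-1.0, tgt_max=1.0))
# tensor([-1.0000, -0.5000,  1.0000,  1.0000])
```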
- nemo_rl.algorithms.grpo._should_use_async_rollouts(
- master_config: nemo_rl.algorithms.grpo.MasterConfig,
Determine if async rollouts should be used based on the configuration.
Returns True if the vLLM backend is used with async_engine enabled.
- nemo_rl.algorithms.grpo.refit_policy_generation(
- policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
- policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
- colocated_inference: bool,
- _refit_buffer_size_gb: Optional[int] = None,
- timer: Optional[nemo_rl.utils.timer.Timer] = None,
Refit the policy generation interface with the latest policy weights.
- Parameters:
policy – The policy that provides weights to the inference engine.
policy_generation – The inference engine to refit.
_refit_buffer_size_gb – The size of the buffer to use for refitting. If None, the buffer size is computed from the remaining available memory. This parameter is primarily used for testing.
timer – Optional Timer used to time the prepare/transfer/update phases.
- nemo_rl.algorithms.grpo.grpo_train(
- policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
- policy_generation: Optional[nemo_rl.models.generation.interfaces.GenerationInterface],
- dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
- val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
- tokenizer: nemo_rl.algorithms.grpo.TokenizerType,
- loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
- task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
- val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
- logger: nemo_rl.utils.logger.Logger,
- checkpointer: nemo_rl.utils.checkpoint.CheckpointManager,
- grpo_save_state: nemo_rl.algorithms.grpo.GRPOSaveState,
- master_config: nemo_rl.algorithms.grpo.MasterConfig,
- processor: Optional[transformers.AutoProcessor] = None,
Run GRPO training algorithm.
- nemo_rl.algorithms.grpo.validate(
- policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
- val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
- tokenizer,
- val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
- step: int,
- master_config: nemo_rl.algorithms.grpo.MasterConfig,
Run validation on the validation dataset.
- nemo_rl.algorithms.grpo.async_grpo_train(
- policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
- policy_generation: Optional[nemo_rl.models.generation.interfaces.GenerationInterface],
- dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
- val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
- tokenizer: nemo_rl.algorithms.grpo.TokenizerType,
- loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
- task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
- val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
- logger: nemo_rl.utils.logger.Logger,
- checkpointer: nemo_rl.utils.checkpoint.CheckpointManager,
- grpo_save_state: nemo_rl.algorithms.grpo.GRPOSaveState,
- master_config: nemo_rl.algorithms.grpo.MasterConfig,
- max_trajectory_age_steps: int = 1,
Run asynchronous GRPO training with replay buffer.
- Parameters:
policy – Training policy
policy_generation – Generation interface
dataloader – Training data loader
val_dataloader – Validation data loader
tokenizer – Tokenizer
loss_fn – Loss function
task_to_env – Training environments
val_task_to_env – Validation environments
logger – Logger
checkpointer – Checkpoint manager
grpo_save_state – Training state
master_config – Master configuration
max_trajectory_age_steps – Maximum age (in training steps) for trajectories to be used in training
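To make max_trajectory_age_steps concrete, here is a minimal, hypothetical sketch of dropping stale rollouts from a replay buffer; the data structure and the exact off-by-one age convention are assumptions for illustration, not the buffer used by async_grpo_train:

```python
from dataclasses import dataclass

@dataclass
class Trajectory:
    data: dict
    generated_at_step: int  # policy step whose weights produced this rollout

def usable(buffer: list[Trajectory], current_step: int, max_age: int) -> list[Trajectory]:
    """Keep trajectories generated within the last `max_age` training steps."""
    return [t for t in buffer if current_step - t.generated_at_step <= max_age]

buffer = [Trajectory({}, 7), Trajectory({}, 9), Trajectory({}, 10)]
print(len(usable(buffer, current_step=10, max_age=1)))  # 2 -> the step-9 and step-10 rollouts
```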