nemo_rl.algorithms.grpo
Module Contents
Classes
GRPOConfig
GRPOSaveState
GRPOLoggerConfig
MasterConfig
Functions
setup | Main entry point for running the GRPO algorithm.
_should_use_async_rollouts | Determine if async rollouts should be used based on the configuration.
refit_policy_generation | Refit the policy generation interface with the latest policy weights.
grpo_train | Run the GRPO training algorithm.
validate | Run validation on the validation dataset.
Data
TokenizerType
API
- nemo_rl.algorithms.grpo.TokenizerType
  'TypeVar(…)'
- class nemo_rl.algorithms.grpo.GRPOConfig
  Bases: typing.TypedDict
  - num_prompts_per_step: int
  - num_generations_per_prompt: int
  - max_num_steps: int
  - max_rollout_turns: int
  - normalize_rewards: bool
  - use_leave_one_out_baseline: bool
  - val_period: int
  - val_batch_size: int
  - val_at_start: bool
  - max_val_samples: int
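The `normalize_rewards` and `use_leave_one_out_baseline` fields control how per-sample rewards become advantages. The following is a minimal sketch of one common GRPO formulation, not the library's implementation. It rests on several assumptions. Rewards are laid out as consecutive groups of `num_generations_per_prompt` samples per prompt. The leave-one-out baseline is the mean reward of the *other* samples in the same group. Normalization divides by the group's population standard deviation plus a small epsilon. The function name `compute_advantages` and the `eps` value are hypothetical, and the library may differ in details such as which standard deviation it uses.

```python
from statistics import mean, pstdev


def compute_advantages(
    rewards: list[float],
    num_generations_per_prompt: int,
    use_leave_one_out_baseline: bool = True,
    normalize_rewards: bool = True,
    eps: float = 1e-6,
) -> list[float]:
    """Hypothetical sketch of group-relative advantages for GRPO.

    Assumes each consecutive block of `num_generations_per_prompt`
    rewards belongs to the same prompt.
    """
    if num_generations_per_prompt <= 0:
        raise ValueError("num_generations_per_prompt must be positive")
    if len(rewards) % num_generations_per_prompt != 0:
        raise ValueError("rewards must contain whole groups of generations")

    advantages: list[float] = []
    for start in range(0, len(rewards), num_generations_per_prompt):
        group = rewards[start : start + num_generations_per_prompt]
        group_sum = sum(group)
        group_std = pstdev(group)  # assumption: population std of the group
        for r in group:
            if use_leave_one_out_baseline and len(group) > 1:
                # Baseline = mean reward of the other samples for this prompt.
                baseline = (group_sum - r) / (len(group) - 1)
            else:
                # Baseline = mean reward of the whole group.
                baseline = mean(group)
            adv = r - baseline
            if normalize_rewards:
                adv /= group_std + eps
            advantages.append(adv)
    return advantages


rewards = [1.0, 0.0, 1.0, 0.0,  # prompt 0: half the generations succeed
           1.0, 1.0, 1.0, 1.0]  # prompt 1: every generation succeeds
print(compute_advantages(rewards, num_generations_per_prompt=4, normalize_rewards=False))
```

When every generation for a prompt earns the same reward (prompt 1 above), all of its advantages are zero under either baseline. That prompt therefore contributes no policy-gradient signal for the step.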
- class nemo_rl.algorithms.grpo.GRPOSaveState
  Bases: typing.TypedDict
  - step: int
  - val_reward: NotRequired[float]
  - consumed_samples: int
- nemo_rl.algorithms.grpo._default_grpo_save_state() → nemo_rl.algorithms.grpo.GRPOSaveState
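`GRPOSaveState` marks `val_reward` as `NotRequired`, so a valid save state may omit it entirely. This page does not show the values that `_default_grpo_save_state()` actually returns. The sketch below therefore mirrors the documented fields in a local `TypedDict` and assumes a fresh run starts at step 0 with no samples consumed. The helpers `default_save_state` and `advance` are hypothetical. In particular, counting consumed samples as `num_prompts_per_step` per step is an illustrative assumption.

```python
try:
    from typing import NotRequired, TypedDict  # Python 3.11+
except ImportError:  # older Pythons: take both names from typing_extensions
    from typing_extensions import NotRequired, TypedDict


class GRPOSaveState(TypedDict):
    """Local mirror of the documented fields, for illustration only."""

    step: int
    val_reward: NotRequired[float]
    consumed_samples: int


def default_save_state() -> GRPOSaveState:
    # Hypothetical defaults. `val_reward` is NotRequired, so it is simply
    # omitted until a validation pass produces a value.
    return {"step": 0, "consumed_samples": 0}


def advance(state: GRPOSaveState, num_prompts_per_step: int) -> GRPOSaveState:
    # Hypothetical bookkeeping after one training step; any existing
    # val_reward is carried over unchanged.
    return {
        **state,
        "step": state["step"] + 1,
        "consumed_samples": state["consumed_samples"] + num_prompts_per_step,
    }


state = advance(advance(default_save_state(), 32), 32)
print(state)
print(state.get("val_reward"))  # None: the key was never set
```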
- class nemo_rl.algorithms.grpo.GRPOLoggerConfig
  Bases: nemo_rl.utils.logger.LoggerConfig
  - num_val_samples_to_print: int
- class nemo_rl.algorithms.grpo.MasterConfig
  Bases: typing.TypedDict
  - policy: nemo_rl.models.policy.PolicyConfig
  - loss_fn: nemo_rl.algorithms.loss_functions.ClippedPGLossConfig
  - env: dict[str, Any]
  - data: nemo_rl.data.DataConfig
  - grpo: nemo_rl.algorithms.grpo.GRPOConfig
  - logger: nemo_rl.algorithms.grpo.GRPOLoggerConfig
  - cluster: nemo_rl.distributed.virtual_cluster.ClusterConfig
  - checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig
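`MasterConfig` groups one section per component: policy, loss, environments, data, GRPO, logger, cluster, and checkpointing. Because it is a `TypedDict`, its required top-level keys are available at runtime through `__required_keys__`, a standard attribute since Python 3.9. The sketch below is a local mirror that types every section as `dict[str, Any]` instead of the specific config classes listed above. It uses the mirror to report which sections a loaded config is missing. The helper name `missing_sections` is hypothetical.

```python
from typing import Any, TypedDict


class MasterConfig(TypedDict):
    """Local mirror of the documented top-level sections.

    Each section is typed as dict[str, Any] here for brevity; the real
    class uses the specific config types listed in the API reference.
    """

    policy: dict[str, Any]
    loss_fn: dict[str, Any]
    env: dict[str, Any]
    data: dict[str, Any]
    grpo: dict[str, Any]
    logger: dict[str, Any]
    cluster: dict[str, Any]
    checkpointing: dict[str, Any]


def missing_sections(config: dict[str, Any]) -> set[str]:
    """Return the top-level MasterConfig keys that `config` lacks."""
    # Iterating a dict yields its keys, so set.difference works directly.
    return set(MasterConfig.__required_keys__).difference(config)


partial = {"policy": {}, "grpo": {}}
print(sorted(missing_sections(partial)))
```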
- nemo_rl.algorithms.grpo.setup(
      master_config: nemo_rl.algorithms.grpo.MasterConfig,
      tokenizer: nemo_rl.algorithms.grpo.TokenizerType,
      dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
      val_dataset: Optional[nemo_rl.data.datasets.AllTaskProcessedDataset],
  )
  Main entry point for running the GRPO algorithm.
  - Returns:
      A tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, logger, master_config, val_dataloader.
- nemo_rl.algorithms.grpo._should_use_async_rollouts(
      master_config: nemo_rl.algorithms.grpo.MasterConfig,
  )
  Determine if async rollouts should be used based on the configuration.
  Returns True if the vLLM backend is used with async_engine enabled.
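The rule above reduces to two checks on the generation config. The following sketch spells them out on a plain nested dict. The key layout `policy` → `generation` → `backend`, and `generation` → `vllm_cfg` → `async_engine`, is an assumption made for illustration. The function name `should_use_async_rollouts` is a hypothetical stand-in for the private helper.

```python
from typing import Any


def should_use_async_rollouts(master_config: dict[str, Any]) -> bool:
    """Sketch: True only for the vLLM backend with async_engine enabled.

    The key layout used here is an assumption, not a documented schema.
    """
    generation_cfg = master_config.get("policy", {}).get("generation")
    if not generation_cfg:
        # No generation section configured: fall back to synchronous rollouts.
        return False
    if generation_cfg.get("backend") != "vllm":
        return False
    vllm_cfg = generation_cfg.get("vllm_cfg") or {}
    return bool(vllm_cfg.get("async_engine", False))


cfg = {"policy": {"generation": {"backend": "vllm", "vllm_cfg": {"async_engine": True}}}}
print(should_use_async_rollouts(cfg))
```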
- nemo_rl.algorithms.grpo.refit_policy_generation(
      policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
      policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
      colocated_inference: bool,
      _refit_buffer_size_gb: Optional[int] = None,
  )
  Refit the policy generation interface with the latest policy weights.
  - Parameters:
      policy – The policy that provides weights to the inference engine.
      policy_generation – The inference engine to refit.
      _refit_buffer_size_gb – The size of the buffer, in GB, to use for refitting. If None, the buffer size is computed from the remaining memory. This parameter is primarily used for testing.
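A bounded refit buffer implies that weights are transferred in several batches rather than all at once. This page does not document the library's actual batching strategy. The sketch below assumes a simple greedy scheme: parameters are packed in order into batches whose total size stays within the buffer. It treats "GB" as GiB (1024³ bytes). The function `plan_refit_batches` and its inputs are hypothetical.

```python
def plan_refit_batches(
    param_sizes_bytes: dict[str, int], buffer_size_gb: float
) -> list[list[str]]:
    """Hypothetical sketch: greedily pack parameters into buffer-sized batches.

    Parameters are taken in insertion order; a new batch starts whenever
    adding the next parameter would overflow the buffer. GB is treated as
    GiB (1024**3 bytes) here, which is an assumption.
    """
    budget = int(buffer_size_gb * 1024**3)
    if budget <= 0:
        raise ValueError("buffer_size_gb must be positive")

    batches: list[list[str]] = []
    current: list[str] = []
    used = 0
    for name, size in param_sizes_bytes.items():
        if size > budget:
            raise ValueError(f"parameter {name!r} ({size} bytes) exceeds the refit buffer")
        if current and used + size > budget:
            # Flush the current batch before it overflows.
            batches.append(current)
            current, used = [], 0
        current.append(name)
        used += size
    if current:
        batches.append(current)
    return batches


MiB = 1024**2
sizes = {"embed": 600 * MiB, "layer0": 500 * MiB, "layer1": 400 * MiB, "head": 100 * MiB}
print(plan_refit_batches(sizes, buffer_size_gb=1))
```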
- nemo_rl.algorithms.grpo.grpo_train(
      policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
      policy_generation: Optional[nemo_rl.models.generation.interfaces.GenerationInterface],
      dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
      val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
      tokenizer: nemo_rl.algorithms.grpo.TokenizerType,
      loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
      task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
      val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
      logger: nemo_rl.utils.logger.Logger,
      checkpointer: nemo_rl.utils.checkpoint.CheckpointManager,
      grpo_save_state: nemo_rl.algorithms.grpo.GRPOSaveState,
      master_config: nemo_rl.algorithms.grpo.MasterConfig,
  )
  Run the GRPO training algorithm.
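Several `GRPOConfig` fields shape the training loop's cadence. The sketch below shows two derived quantities. Each step produces `num_prompts_per_step * num_generations_per_prompt` sampled responses. It also computes the set of steps after which validation runs. The exact validation condition used by `grpo_train` is not shown on this page, so the schedule rests on three assumptions. Steps are counted from 1, a validation pass follows every step that is a multiple of `val_period`, and a non-positive `val_period` disables periodic validation. A `val_at_start` pass is reported as step 0. Both function names are hypothetical.

```python
def rollouts_per_step(num_prompts_per_step: int, num_generations_per_prompt: int) -> int:
    """Each prompt in a step is sampled num_generations_per_prompt times."""
    return num_prompts_per_step * num_generations_per_prompt


def validation_schedule(max_num_steps: int, val_period: int, val_at_start: bool) -> list[int]:
    """Hypothetical sketch: training steps after which validation runs.

    0 denotes the val_at_start pass before any training; step k (1-based)
    triggers validation when k is a multiple of val_period.
    """
    schedule = [0] if val_at_start else []
    if val_period > 0:
        schedule.extend(k for k in range(1, max_num_steps + 1) if k % val_period == 0)
    return schedule


print(rollouts_per_step(num_prompts_per_step=32, num_generations_per_prompt=16))
print(validation_schedule(max_num_steps=10, val_period=4, val_at_start=True))
```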
- nemo_rl.algorithms.grpo.validate(
      policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
      val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
      tokenizer,
      val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
      step: int,
      master_config: nemo_rl.algorithms.grpo.MasterConfig,
  )
  Run validation on the validation dataset.
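Validation is bounded by `max_val_samples` from `GRPOConfig`. The sketch below aggregates per-batch rewards into a single mean and stops once the cap is reached. It is not the library's `validate`. `reward_batches` is a hypothetical stand-in for the rewards the validation environments return, truncating the final batch is an assumption, and `summarize_validation` and `ValidationSummary` are illustrative names.

```python
from collections.abc import Iterable
from statistics import mean
from typing import NamedTuple


class ValidationSummary(NamedTuple):
    num_samples: int
    mean_reward: float


def summarize_validation(
    reward_batches: Iterable[list[float]], max_val_samples: int
) -> ValidationSummary:
    """Hypothetical sketch: mean reward over at most max_val_samples samples.

    Batches are consumed in order; the last one is truncated so the total
    never exceeds max_val_samples.
    """
    rewards: list[float] = []
    for batch in reward_batches:
        remaining = max_val_samples - len(rewards)
        if remaining <= 0:
            break
        rewards.extend(batch[:remaining])
    if not rewards:
        # Nothing was evaluated; report NaN rather than a misleading 0.
        return ValidationSummary(num_samples=0, mean_reward=float("nan"))
    return ValidationSummary(num_samples=len(rewards), mean_reward=mean(rewards))


batches = [[1.0, 0.0, 1.0], [1.0, 1.0, 0.0], [0.0]]
print(summarize_validation(batches, max_val_samples=5))
```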