nemo_rl.algorithms.grpo#
Module Contents#
Classes#
GRPOConfig, GRPOSaveState, MasterConfig — configuration TypedDicts (see API below).
Functions#
setup — Main entry point for running GRPO algorithm.
refit_policy_generation — Refit the policy generation interface with the latest policy weights.
grpo_train — Run GRPO training algorithm.
validate — Run validation on the validation dataset.
API#
- class nemo_rl.algorithms.grpo.GRPOConfig[source]#
Bases:
typing.TypedDict
- num_prompts_per_step: int#
None
- num_generations_per_prompt: int#
None
- max_num_steps: int#
None
- normalize_rewards: bool#
None
- use_leave_one_out_baseline: bool#
None
- val_period: int#
None
- val_batch_size: int#
None
- val_at_start: bool#
None
- checkpoint_dir: str#
None
- class nemo_rl.algorithms.grpo.GRPOSaveState[source]#
Bases:
typing.TypedDict
- step: int#
None
- val_reward: float#
None
- consumed_samples: int#
None
- nemo_rl.algorithms.grpo._default_grpo_save_state() → nemo_rl.algorithms.grpo.GRPOSaveState[source]#
- class nemo_rl.algorithms.grpo.MasterConfig[source]#
Bases:
typing.TypedDict
- policy: nemo_rl.models.policy.PolicyConfig#
None
- loss_fn: nemo_rl.algorithms.loss_functions.ClippedPGLossConfig#
None
- env_configs: Dict[str, Any]#
None
- data: nemo_rl.data.DataConfig#
None
- grpo: nemo_rl.algorithms.grpo.GRPOConfig#
None
- logger: nemo_rl.utils.logger.LoggerConfig#
None
- cluster: nemo_rl.distributed.virtual_cluster.ClusterConfig#
None
- checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig#
None
- nemo_rl.algorithms.grpo.setup(
- master_config: nemo_rl.algorithms.grpo.MasterConfig,
- tokenizer: transformers.AutoTokenizer,
- dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
- val_dataset: Optional[nemo_rl.data.datasets.AllTaskProcessedDataset],
- )
Main entry point for running GRPO algorithm.
- Returns:
Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, logger, master_config, val_dataloader
- nemo_rl.algorithms.grpo.refit_policy_generation(
- policy: nemo_rl.models.interfaces.PolicyInterface,
- policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
- refit_buffer_size_gb: int,
- )
Refit the policy generation interface with the latest policy weights.
- nemo_rl.algorithms.grpo.grpo_train(
- policy: nemo_rl.models.interfaces.PolicyInterface,
- policy_generation: Optional[nemo_rl.models.generation.interfaces.GenerationInterface],
- dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
- val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
- tokenizer,
- loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
- task_to_env: Dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
- val_task_to_env: Optional[Dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
- logger: nemo_rl.utils.logger.Logger,
- checkpointer: nemo_rl.utils.checkpoint.CheckpointManager,
- grpo_save_state: Optional[nemo_rl.algorithms.grpo.GRPOSaveState],
- master_config: nemo_rl.algorithms.grpo.MasterConfig,
- )
Run GRPO training algorithm.
- nemo_rl.algorithms.grpo.validate(
- policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
- val_dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
- tokenizer,
- val_task_to_env: Dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
- step: int,
- master_config: nemo_rl.algorithms.grpo.MasterConfig,
- )
Run validation on the validation dataset.