nemo_rl.algorithms.grpo#

Module Contents#

Classes#

Functions#

_default_grpo_save_state

setup

Main entry point for running the GRPO algorithm.

refit_policy_generation

Refit the policy generation interface with the latest policy weights.

grpo_train

Run the GRPO training algorithm.

validate

Run validation on the validation dataset.

API#

class nemo_rl.algorithms.grpo.GRPOConfig[source]#

Bases: typing.TypedDict

num_prompts_per_step: int#

None

num_generations_per_prompt: int#

None

max_num_steps: int#

None

normalize_rewards: bool#

None

use_leave_one_out_baseline: bool#

None

val_period: int#

None

val_batch_size: int#

None

val_at_start: bool#

None

checkpoint_dir: str#

None

class nemo_rl.algorithms.grpo.GRPOSaveState[source]#

Bases: typing.TypedDict

step: int#

None

val_reward: float#

None

consumed_samples: int#

None

nemo_rl.algorithms.grpo._default_grpo_save_state() → nemo_rl.algorithms.grpo.GRPOSaveState[source]#
class nemo_rl.algorithms.grpo.MasterConfig[source]#

Bases: typing.TypedDict

policy: nemo_rl.models.policy.PolicyConfig#

None

loss_fn: nemo_rl.algorithms.loss_functions.ClippedPGLossConfig#

None

env_configs: Dict[str, Any]#

None

data: nemo_rl.data.DataConfig#

None

grpo: nemo_rl.algorithms.grpo.GRPOConfig#

None

logger: nemo_rl.utils.logger.LoggerConfig#

None

cluster: nemo_rl.distributed.virtual_cluster.ClusterConfig#

None

checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig#

None

nemo_rl.algorithms.grpo.setup(
master_config: nemo_rl.algorithms.grpo.MasterConfig,
tokenizer: transformers.AutoTokenizer,
dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
val_dataset: Optional[nemo_rl.data.datasets.AllTaskProcessedDataset],
) → Tuple[nemo_rl.models.interfaces.PolicyInterface, nemo_rl.models.generation.interfaces.GenerationInterface, nemo_rl.distributed.virtual_cluster.RayVirtualCluster, torchdata.stateful_dataloader.StatefulDataLoader, Optional[torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.ClippedPGLossFn, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.grpo.GRPOSaveState, nemo_rl.algorithms.grpo.MasterConfig][source]#

Main entry point for running the GRPO algorithm.

Returns:

Tuple of policy, policy_generation, cluster, dataloader, val_dataloader, loss_fn, logger, checkpointer, grpo_save_state, master_config

nemo_rl.algorithms.grpo.refit_policy_generation(
policy: nemo_rl.models.interfaces.PolicyInterface,
policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
refit_buffer_size_gb: int,
)[source]#

Refit the policy generation interface with the latest policy weights.

nemo_rl.algorithms.grpo.grpo_train(
policy: nemo_rl.models.interfaces.PolicyInterface,
policy_generation: Optional[nemo_rl.models.generation.interfaces.GenerationInterface],
dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
tokenizer,
loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
task_to_env: Dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
val_task_to_env: Optional[Dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
logger: nemo_rl.utils.logger.Logger,
checkpointer: nemo_rl.utils.checkpoint.CheckpointManager,
grpo_save_state: Optional[nemo_rl.algorithms.grpo.GRPOSaveState],
master_config: nemo_rl.algorithms.grpo.MasterConfig,
)[source]#

Run the GRPO training algorithm.

nemo_rl.algorithms.grpo.validate(
policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
val_dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
tokenizer,
val_task_to_env: Dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
step: int,
master_config: nemo_rl.algorithms.grpo.MasterConfig,
) → Tuple[Dict[str, Any], Dict[str, Any]][source]#

Run validation on the validation dataset.