nemo_rl.algorithms.grpo#

Module Contents#

Classes#

Functions#

_default_grpo_save_state

setup

Main entry point for running GRPO algorithm.

_should_use_async_rollouts

Determine if async rollouts should be used based on the configuration.

refit_policy_generation

Refit the policy generation interface with the latest policy weights.

grpo_train

Run GRPO training algorithm.

validate

Run validation on the validation dataset.

Data#

API#

nemo_rl.algorithms.grpo.TokenizerType#

'TypeVar(…)'

class nemo_rl.algorithms.grpo.GRPOConfig[source]#

Bases: typing.TypedDict

num_prompts_per_step: int#

None

num_generations_per_prompt: int#

None

max_num_steps: int#

None

max_rollout_turns: int#

None

normalize_rewards: bool#

None

use_leave_one_out_baseline: bool#

None

val_period: int#

None

val_batch_size: int#

None

val_at_start: bool#

None

max_val_samples: int#

None

class nemo_rl.algorithms.grpo.GRPOSaveState[source]#

Bases: typing.TypedDict

step: int#

None

val_reward: NotRequired[float]#

None

consumed_samples: int#

None

nemo_rl.algorithms.grpo._default_grpo_save_state() nemo_rl.algorithms.grpo.GRPOSaveState[source]#
class nemo_rl.algorithms.grpo.GRPOLoggerConfig[source]#

Bases: nemo_rl.utils.logger.LoggerConfig

num_val_samples_to_print: int#

None

class nemo_rl.algorithms.grpo.MasterConfig[source]#

Bases: typing.TypedDict

policy: nemo_rl.models.policy.PolicyConfig#

None

loss_fn: nemo_rl.algorithms.loss_functions.ClippedPGLossConfig#

None

env: dict[str, Any]#

None

data: nemo_rl.data.DataConfig#

None

grpo: nemo_rl.algorithms.grpo.GRPOConfig#

None

logger: nemo_rl.algorithms.grpo.GRPOLoggerConfig#

None

cluster: nemo_rl.distributed.virtual_cluster.ClusterConfig#

None

checkpointing: nemo_rl.utils.checkpoint.CheckpointingConfig#

None

nemo_rl.algorithms.grpo.setup(
master_config: nemo_rl.algorithms.grpo.MasterConfig,
tokenizer: nemo_rl.algorithms.grpo.TokenizerType,
dataset: nemo_rl.data.datasets.AllTaskProcessedDataset,
val_dataset: Optional[nemo_rl.data.datasets.AllTaskProcessedDataset],
) tuple[nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, Optional[nemo_rl.models.generation.interfaces.GenerationInterface], tuple[nemo_rl.distributed.virtual_cluster.RayVirtualCluster, nemo_rl.distributed.virtual_cluster.RayVirtualCluster], torchdata.stateful_dataloader.StatefulDataLoader, Optional[torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.ClippedPGLossFn, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.grpo.GRPOSaveState, nemo_rl.algorithms.grpo.MasterConfig][source]#

Main entry point for running GRPO algorithm.

Returns:

tuple of (policy, policy_generation, (train_cluster, inference_cluster), dataloader, val_dataloader, loss_fn, logger, checkpointer, grpo_save_state, master_config)

nemo_rl.algorithms.grpo._should_use_async_rollouts(
master_config: nemo_rl.algorithms.grpo.MasterConfig,
) bool[source]#

Determine if async rollouts should be used based on the configuration.

Returns True if vLLM backend is used with async_engine enabled.

nemo_rl.algorithms.grpo.refit_policy_generation(
policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
colocated_inference: bool,
_refit_buffer_size_gb: Optional[int] = None,
) None[source]#

Refit the policy generation interface with the latest policy weights.

Parameters:
  • policy – The policy to provide weights to the inference engine.

  • policy_generation – The inference engine to refit.

  • colocated_inference – Whether the inference engine is colocated with the training policy (i.e. shares the same devices); controls how weights are transferred during refit.

  • _refit_buffer_size_gb – The size of the buffer to use for refitting. If it is None, the buffer size will be computed by the remaining memory. This parameter is primarily used for testing.

nemo_rl.algorithms.grpo.grpo_train(
policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
policy_generation: Optional[nemo_rl.models.generation.interfaces.GenerationInterface],
dataloader: torchdata.stateful_dataloader.StatefulDataLoader,
val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
tokenizer: nemo_rl.algorithms.grpo.TokenizerType,
loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
logger: nemo_rl.utils.logger.Logger,
checkpointer: nemo_rl.utils.checkpoint.CheckpointManager,
grpo_save_state: nemo_rl.algorithms.grpo.GRPOSaveState,
master_config: nemo_rl.algorithms.grpo.MasterConfig,
) None[source]#

Run GRPO training algorithm.

nemo_rl.algorithms.grpo.validate(
policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
tokenizer,
val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
step: int,
master_config: nemo_rl.algorithms.grpo.MasterConfig,
) tuple[dict[str, Any], dict[str, Any]][source]#

Run validation on the validation dataset.