nemo_rl.algorithms.grpo_sync
GRPO trainer — TransferQueue-mediated path (sync).
Sibling fork of nemo_rl.algorithms.grpo. Each file has zero
internal branching on whether TQ is engaged; the example script
chooses one or the other based on data_plane.enabled.
Setup and helpers are re-imported from grpo; the training loop body
is duplicated here so the per-step lifecycle hooks (register / seed-put
/ per-rank fetch / clear) can live in straight sequential code.
Validation is implemented locally as validate_sync, a
TQ-mediated sibling of nemo_rl.algorithms.grpo.validate that
routes val rollouts through SyncRolloutActor.rollout_to_tq into a
per-batch "val" partition.
Parity with the legacy path is verified by running the same config against both entrypoints and diffing the wandb runs.
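To illustrate what "per-step lifecycle hooks (register / seed-put / per-rank fetch / clear) in straight sequential code" means, here is a minimal sketch against a toy in-memory store. ToyKVStore, run_step, the train_{step} partition naming and the round-robin key sharding are all hypothetical; the real TransferQueue and TQPolicy APIs are not reproduced here.

```python
class ToyKVStore:
    """Hypothetical in-memory stand-in for a TransferQueue partition store."""

    def __init__(self):
        # partition_id -> {key: row}
        self._partitions = {}

    def register(self, partition_id, keys):
        # Reserve the keys this step will fill; rows start empty.
        self._partitions[partition_id] = {key: None for key in keys}

    def put(self, partition_id, key, row):
        self._partitions[partition_id][key] = row

    def fetch(self, partition_id, keys):
        # A rank reads only the keys assigned to it.
        partition = self._partitions[partition_id]
        return [partition[key] for key in keys]

    def clear(self, partition_id):
        # Drop the whole partition once the step no longer needs it.
        del self._partitions[partition_id]

    def partitions(self):
        return list(self._partitions)


def run_step(store, step, prompts, world_size):
    """One step's data lifecycle written as straight sequential code."""
    partition_id = f"train_{step}"
    keys = list(range(len(prompts)))
    store.register(partition_id, keys)                    # 1. register
    for key, prompt in zip(keys, prompts):                # 2. seed-put
        store.put(partition_id, key, {"prompt": prompt})
    # Assumed round-robin sharding of keys across ranks.
    shards = [keys[rank::world_size] for rank in range(world_size)]
    per_rank = [store.fetch(partition_id, shard) for shard in shards]  # 3. per-rank fetch
    store.clear(partition_id)                             # 4. clear
    return per_rank


store = ToyKVStore()
print(run_step(store, step=0, prompts=["a", "b", "c"], world_size=2))
```

Because every hook is a plain call in order, there is no branch on whether the data plane is engaged; that decision is made once, by picking the entrypoint.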
Module Contents
Functions
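| Function | Description |
|---|---|
| _raise_if_message_level_advantage_penalties_enabled | Raise if message-level advantage penalties are set in the sync trainer. |
| _apply_dynamic_sampling | Process one dynamic-sampling iteration. |
| validate_sync | TQ-mediated counterpart to nemo_rl.algorithms.grpo.validate. |
| grpo_train_sync | Run the GRPO training algorithm (TransferQueue-mediated). |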
API
- nemo_rl.algorithms.grpo_sync._raise_if_message_level_advantage_penalties_enabled(
      master_config: nemo_rl.algorithms.grpo.MasterConfig,
  )
Raise if message-level advantage penalties are set in the sync trainer.
Message-level advantage penalties are not supported with
data_plane.enabled=true. Raises NotImplementedError listing the offending keys so the user can disable them or switch to the legacy GRPO trainer.
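A minimal sketch of this kind of guard, assuming the config is a nested dict and the penalty settings are addressed by dotted keys. The function name raise_if_penalties_enabled and the example key names are hypothetical; the real function takes a nemo_rl.algorithms.grpo.MasterConfig and checks its own set of keys.

```python
def raise_if_penalties_enabled(config, penalty_keys):
    """Raise NotImplementedError naming every penalty key that is set.

    config: nested dict; penalty_keys: dotted paths (hypothetical names).
    """
    offending = []
    for dotted in penalty_keys:
        node = config
        for part in dotted.split("."):
            if not isinstance(node, dict) or part not in node:
                node = None  # a missing key counts as "not set"
                break
            node = node[part]
        if node:  # None, 0, 0.0 and False are treated as disabled
            offending.append(dotted)
    if offending:
        raise NotImplementedError(
            "Message-level advantage penalties are not supported with "
            "data_plane.enabled=true. Disable "
            + ", ".join(offending)
            + " or switch to the legacy GRPO trainer."
        )


try:
    raise_if_penalties_enabled({"grpo": {"penalty_a": 0.5}}, ["grpo.penalty_a"])
except NotImplementedError as exc:
    print(exc)
```

Collecting all offending keys before raising lets the user fix the config in one pass instead of discovering the keys one error at a time.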
- nemo_rl.algorithms.grpo_sync._apply_dynamic_sampling(
      *,
      meta: nemo_rl.data_plane.interfaces.KVBatchMeta,
      driver_carry: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
      pending_meta: Optional[nemo_rl.data_plane.interfaces.KVBatchMeta],
      pending_carry: Optional[nemo_rl.distributed.batched_data_dict.BatchedDataDict],
      pending_unfiltered_rewards: list[torch.Tensor],
      train_prompts_size: int,
      num_gen_batches: int,
      max_gen_batches: int,
      policy: nemo_rl.models.policy.tq_policy.TQPolicy,
  )
Process one dynamic-sampling iteration.
Drops zero-std (filtered) keys, merges survivors into the running pending cache, and reports whether the cache has reached train_prompts_size. When complete, the returned pending_* values are the training batch.
- Parameters:
  meta – This iteration's KVBatchMeta.
  driver_carry – Per-row driver-local tensors for this iteration (rewards, masks, prompt_ids_for_adv, baseline/std, …).
  pending_meta – Survivors accumulated from prior iterations.
  pending_carry – driver_carry rows aligned to pending_meta.
  pending_unfiltered_rewards – All iterations' rewards before filtering, kept for parity with the legacy reward metrics.
  train_prompts_size – Target batch size.
  num_gen_batches – Iteration counter (1-based).
  max_gen_batches – Upper bound on iterations before raising.
  policy – TQPolicy whose discard_samples is used to drop filtered keys.
- Returns:
  (pending_meta, pending_carry, pending_rewards, is_complete, ds_metrics, unfiltered_for_log).
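The filter-and-accumulate logic can be sketched in plain Python, assuming rewards are grouped per prompt as lists of floats rather than tensors and that prompt ids are unique across iterations. apply_dynamic_sampling here is hypothetical: it returns the dropped prompt ids where the real function hands filtered keys to policy.discard_samples, and raising RuntimeError once num_gen_batches reaches max_gen_batches without a full batch is an assumed stopping rule.

```python
from statistics import pstdev


def apply_dynamic_sampling(
    batch_groups,
    pending,
    pending_unfiltered,
    train_prompts_size,
    num_gen_batches,
    max_gen_batches,
):
    """One dynamic-sampling iteration over per-prompt reward lists.

    batch_groups: {prompt_id: [reward of each sampled response]} for this iteration.
    pending: survivors accumulated from earlier iterations, same shape.
    pending_unfiltered: every iteration's rewards before filtering.
    Returns (pending, pending_unfiltered, is_complete, dropped_ids).
    """
    # Keep pre-filter rewards so reward metrics can match the legacy path.
    pending_unfiltered = pending_unfiltered + [
        reward for rewards in batch_groups.values() for reward in rewards
    ]
    # A zero-std group has identical rewards, so every response gets zero
    # advantage and contributes no learning signal: drop it.
    dropped_ids = [pid for pid, rewards in batch_groups.items() if pstdev(rewards) == 0.0]
    survivors = {pid: r for pid, r in batch_groups.items() if pid not in dropped_ids}
    # Prompt ids are assumed unique across iterations, so a dict merge is safe.
    pending = {**pending, **survivors}
    is_complete = len(pending) >= train_prompts_size
    # Assumed stopping rule: give up once the iteration budget is spent.
    if not is_complete and num_gen_batches >= max_gen_batches:
        raise RuntimeError(
            f"only {len(pending)}/{train_prompts_size} prompts survived "
            f"after {num_gen_batches} generation batches"
        )
    return pending, pending_unfiltered, is_complete, dropped_ids


# Iteration 1: prompt "a" has identical rewards and is dropped.
pending, unfiltered, done, dropped = apply_dynamic_sampling(
    {"a": [1.0, 1.0], "b": [0.0, 1.0]}, {}, [], 2, 1, 3
)
print(dropped, list(pending), done)
```

The caller loops, incrementing num_gen_batches, until is_complete is true; at that point the accumulated pending cache is the training batch.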
- nemo_rl.algorithms.grpo_sync.validate_sync(
      *,
      rollout_actor: nemo_rl.experience.sync_rollout_actor.SyncRolloutActor,
      policy: nemo_rl.models.policy.tq_policy.TQPolicy,
      val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
      val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
      step: int,
      master_config: nemo_rl.algorithms.grpo.MasterConfig,
      logger: Optional[nemo_rl.utils.logger.Logger] = None,
      partition_id: str = 'val',
  )
TQ-mediated counterpart to nemo_rl.algorithms.grpo.validate.
Per batch: register the val partition → rollout_to_tq → policy.read_from_dataplane (for message logs) → policy.finish_step. The caller owns policy_generation.prepare_for_generation / finish_generation around the call; the actor's per-rollout finish_generation is suppressed so inference state stays warm across batches.
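The per-batch ordering and the caller-owned generation bracket can be sketched with recording stand-ins. Recorder, validate_batches and register_partition are hypothetical names, and the arguments passed to rollout_to_tq, read_from_dataplane and finish_step are assumptions rather than their documented signatures.

```python
class Recorder:
    """Hypothetical stand-in: every method call is appended to a shared log."""

    def __init__(self, name, log):
        self._name = name
        self._log = log

    def __getattr__(self, method):
        # Only invoked for attributes not set in __init__, i.e. the "methods".
        def call(*args, **kwargs):
            self._log.append(f"{self._name}.{method}")
            return []  # read_from_dataplane would return message logs here
        return call


def validate_batches(batches, policy, rollout_actor, policy_generation, partition_id="val"):
    """Per-batch validation order; method names and arguments are assumptions."""
    message_logs = []
    # The caller brackets the whole validation pass with generation setup/teardown.
    policy_generation.prepare_for_generation()
    try:
        for batch in batches:
            policy.register_partition(partition_id)  # hypothetical name
            rollout_actor.rollout_to_tq(batch, partition_id)
            message_logs.extend(policy.read_from_dataplane(partition_id))
            policy.finish_step(partition_id)
    finally:
        # Finish once, after all batches, so inference state stays warm between them.
        policy_generation.finish_generation()
    return message_logs


log = []
validate_batches(["batch0", "batch1"], Recorder("policy", log), Recorder("actor", log), Recorder("gen", log))
print(log)
```

Wrapping the loop in try/finally keeps the "caller owns finish_generation" contract even if a batch fails midway.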
- nemo_rl.algorithms.grpo_sync._compute_seq_logprob_error_metrics(
      *,
      token_mask: torch.Tensor,
      sample_mask: torch.Tensor,
      prev_logprobs: torch.Tensor,
      generation_logprobs: torch.Tensor,
      rewards: torch.Tensor,
      seq_logprob_error_threshold: Optional[float],
  )
- nemo_rl.algorithms.grpo_sync.grpo_train_sync(
      policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
      policy_generation: Optional[nemo_rl.models.generation.interfaces.GenerationInterface],
      wrapped_dataloader,
      val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
      tokenizer,
      loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
      task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
      val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
      logger: nemo_rl.utils.logger.Logger,
      checkpointer: nemo_rl.utils.checkpoint.CheckpointManager,
      grpo_save_state: nemo_rl.algorithms.grpo.GRPOSaveState,
      master_config: nemo_rl.algorithms.grpo.MasterConfig,
  )
Run GRPO training algorithm — TransferQueue-mediated.
Body mirrors nemo_rl.algorithms.grpo.grpo_train, with TQ-mediated Policy methods substituting for the in-memory dispatch. The TQ lifecycle (controller bootstrap, worker attach, partition register, fan-out, drain, close) is fully encapsulated in the class nemo_rl.models.policy.tq_policy.TQPolicy; this trainer just calls policy.prepare_step, policy.get_logprobs, policy.get_reference_policy_logprobs, and policy.train.
Parity with the legacy path is verified by running the same config against both entrypoints and diffing the wandb runs.
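To show how narrow the trainer's view of the data plane is, here is a sketch with a fake policy exposing the four calls named in the docstring. FakeTQPolicy, train_step, all arguments and all return values are hypothetical; in particular, the sketch passes logprobs back through the driver for simplicity, which may not match how TQPolicy routes them, and placing "register + seed" in prepare_step and "drain + clear" in train is an assumption.

```python
class FakeTQPolicy:
    """Hypothetical policy that hides the data-plane lifecycle behind four calls."""

    def __init__(self):
        self.store = {}  # partition_id -> rows; stands in for TransferQueue
        self.partition_id = None

    def prepare_step(self, rows, step):
        # Register this step's partition and seed it with the batch.
        self.partition_id = f"train_{step}"
        self.store[self.partition_id] = list(rows)

    def get_logprobs(self):
        # Real workers would fetch their shards; here each row gets a fake value.
        return [-1.0 for _ in self.store[self.partition_id]]

    def get_reference_policy_logprobs(self):
        return [-1.5 for _ in self.store[self.partition_id]]

    def train(self, prev_logprobs, reference_logprobs):
        # Drain the partition, then clear it so the next step starts empty.
        rows = self.store.pop(self.partition_id)
        self.partition_id = None
        if not len(rows) == len(prev_logprobs) == len(reference_logprobs):
            raise ValueError("logprobs do not align with the step's rows")
        return {"num_samples": len(rows)}


def train_step(policy, rows, step):
    """Trainer side: no TQ plumbing, only the four policy calls in order."""
    policy.prepare_step(rows, step)
    prev_logprobs = policy.get_logprobs()
    reference_logprobs = policy.get_reference_policy_logprobs()
    return policy.train(prev_logprobs, reference_logprobs)


print(train_step(FakeTQPolicy(), ["x", "y"], step=0))
```

Keeping the lifecycle inside the policy object is what lets the trainer body stay a line-for-line mirror of the legacy loop.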