nemo_rl.algorithms.grpo_sync

GRPO trainer — TransferQueue-mediated path (sync).

Sibling fork of nemo_rl.algorithms.grpo. Each file has zero internal branching on whether TransferQueue (TQ) is engaged; the example script chooses one or the other based on data_plane.enabled.
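The selection described above can be sketched as a one-time dispatch in the example script. This is a minimal illustration, not the script's actual code: `select_grpo_entrypoint` is a hypothetical helper, and reading `data_plane.enabled` from a top-level `data_plane` mapping in the config is an assumption. In real use, `legacy_train` would be nemo_rl.algorithms.grpo.grpo_train and `sync_train` would be grpo_train_sync.

```python
from typing import Any, Callable, Mapping


def select_grpo_entrypoint(
    master_config: Mapping[str, Any],
    legacy_train: Callable[..., None],
    sync_train: Callable[..., None],
) -> Callable[..., None]:
    """Choose the trainer once, so neither trainer has to branch on TQ.

    Hypothetical helper: the config layout (a top-level "data_plane"
    mapping with an "enabled" flag) is assumed for illustration.
    """
    data_plane = master_config.get("data_plane") or {}
    return sync_train if data_plane.get("enabled", False) else legacy_train
```

Keeping the branch in the caller is what allows each trainer file to stay straight sequential code.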

Setup and helpers are re-imported from grpo; the training loop body is duplicated here so the per-step lifecycle hooks (register / seed-put / per-rank fetch / clear) can live in straight sequential code. Validation is implemented locally as :func:validate_sync — a TQ-mediated sibling of :func:nemo_rl.algorithms.grpo.validate that routes val rollouts through SyncRolloutActor.rollout_to_tq into a per-batch "val" partition.

Parity with the legacy path is verified by running the same config against both entrypoints and diffing the wandb runs.

Module Contents

Functions

_raise_if_message_level_advantage_penalties_enabled

Raise if message-level advantage penalties are set in the sync trainer.

_apply_dynamic_sampling

Process one dynamic-sampling iteration.

validate_sync

TQ-mediated counterpart to :func:nemo_rl.algorithms.grpo.validate.

_compute_seq_logprob_error_metrics

grpo_train_sync

Run GRPO training algorithm — TransferQueue-mediated.

API

nemo_rl.algorithms.grpo_sync._raise_if_message_level_advantage_penalties_enabled(
master_config: nemo_rl.algorithms.grpo.MasterConfig,
) -> None

Raise if message-level advantage penalties are set in the sync trainer.

Message-level advantage penalties are not supported with data_plane.enabled=true. Raises NotImplementedError listing the offending keys so the user can disable them or switch to the legacy GRPO trainer.
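A guard of this kind might look like the sketch below. The key names in `PENALTY_KEYS` are placeholders invented for illustration (the real trainer checks its own set of config keys), and `raise_if_penalties_enabled` is a hypothetical stand-in rather than the library function.

```python
from typing import Any, Mapping

# Placeholder key names, for illustration only; not the real config keys.
PENALTY_KEYS = ("example_penalty_a", "example_penalty_b")


def raise_if_penalties_enabled(grpo_cfg: Mapping[str, Any]) -> None:
    """Fail fast, naming every enabled penalty key in the error message."""
    offending = [key for key in PENALTY_KEYS if grpo_cfg.get(key)]
    if offending:
        raise NotImplementedError(
            "Message-level advantage penalties are not supported with "
            f"data_plane.enabled=true; offending keys: {offending}. "
            "Disable them or switch to the legacy GRPO trainer."
        )
```

Collecting all offending keys before raising lets the user fix the config in one pass instead of hitting one error per key.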

nemo_rl.algorithms.grpo_sync._apply_dynamic_sampling(
*,
meta: nemo_rl.data_plane.interfaces.KVBatchMeta,
driver_carry: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
pending_meta: Optional[nemo_rl.data_plane.interfaces.KVBatchMeta],
pending_carry: Optional[nemo_rl.distributed.batched_data_dict.BatchedDataDict],
pending_unfiltered_rewards: list[torch.Tensor],
train_prompts_size: int,
num_gen_batches: int,
max_gen_batches: int,
policy: nemo_rl.models.policy.tq_policy.TQPolicy,
) -> tuple[Optional[nemo_rl.data_plane.interfaces.KVBatchMeta], Optional[nemo_rl.distributed.batched_data_dict.BatchedDataDict], list[torch.Tensor], bool, dict[str, Any], Optional[torch.Tensor]]

Process one dynamic-sampling iteration.

Drops zero-std (filtered) keys, merges survivors into the running pending cache, and reports whether the cache has reached train_prompts_size. When complete, the returned pending_* IS the training batch.

Parameters:
  • meta – This iteration’s KVBatchMeta.

  • driver_carry – Per-row driver-local tensors for this iteration (rewards, masks, prompt_ids_for_adv, baseline/std, …).

  • pending_meta – Survivors accumulated from prior iterations.

  • pending_carry – driver_carry rows aligned to pending_meta.

  • pending_unfiltered_rewards – All iterations’ rewards pre-filter, for legacy reward metric parity.

  • train_prompts_size – Target batch size.

  • num_gen_batches – Iteration counter (1-based).

  • max_gen_batches – Upper bound on iterations before raising.

  • policy – TQPolicy whose discard_samples is used to drop filtered keys.

Returns:

(pending_meta, pending_carry, pending_rewards, is_complete, ds_metrics, unfiltered_for_log).
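The filter-merge-check cycle can be illustrated with a simplified, pure-Python sketch. It is not the library implementation: it uses plain dicts keyed by prompt instead of KVBatchMeta and BatchedDataDict, omits the TQPolicy.discard_samples call, and `dynamic_sampling_step` is a hypothetical name. Counting train_prompts_size in prompt groups, trimming any overshoot, and the exact iteration at which the error is raised are all assumptions.

```python
from statistics import pstdev


def dynamic_sampling_step(
    batch: dict[str, list[float]],
    pending: dict[str, list[float]],
    train_prompts_size: int,
    num_gen_batches: int,
    max_gen_batches: int,
) -> tuple[dict[str, list[float]], bool, list[str]]:
    """One iteration: drop zero-std groups, merge survivors, test completion."""
    # A prompt whose sampled rewards are all equal has zero std and therefore
    # yields no advantage signal, so it is filtered out.
    dropped = [key for key, rewards in batch.items() if pstdev(rewards) == 0.0]
    survivors = {key: r for key, r in batch.items() if key not in dropped}

    # Merge this iteration's survivors into the running pending cache.
    merged = {**pending, **survivors}
    is_complete = len(merged) >= train_prompts_size

    # Assumption: give up once the (1-based) counter reaches the bound.
    if not is_complete and num_gen_batches >= max_gen_batches:
        raise RuntimeError(
            f"only {len(merged)}/{train_prompts_size} prompts survived "
            f"after {num_gen_batches} generation batches"
        )
    if is_complete:
        # Assumption: trim overshoot so the cache is exactly the training batch.
        merged = dict(list(merged.items())[:train_prompts_size])
    return merged, is_complete, dropped
```

A caller would loop, generating a new batch per iteration and incrementing num_gen_batches, until is_complete is true; at that point the returned cache is the training batch.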

nemo_rl.algorithms.grpo_sync.validate_sync(
*,
rollout_actor: nemo_rl.experience.sync_rollout_actor.SyncRolloutActor,
policy: nemo_rl.models.policy.tq_policy.TQPolicy,
val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
step: int,
master_config: nemo_rl.algorithms.grpo.MasterConfig,
logger: Optional[nemo_rl.utils.logger.Logger] = None,
partition_id: str = 'val',
) -> tuple[dict[str, Any], dict[str, Any]]

TQ-mediated counterpart to :func:nemo_rl.algorithms.grpo.validate.

Per-batch: register the val partition → rollout_to_tq → policy.read_from_dataplane for message logs → policy.finish_step. Caller owns policy_generation.prepare_for_generation / finish_generation around the call; the actor’s per-rollout finish_generation is suppressed so inference state stays warm across batches.

nemo_rl.algorithms.grpo_sync._compute_seq_logprob_error_metrics(
*,
token_mask: torch.Tensor,
sample_mask: torch.Tensor,
prev_logprobs: torch.Tensor,
generation_logprobs: torch.Tensor,
rewards: torch.Tensor,
seq_logprob_error_threshold: Optional[float],
) -> tuple[torch.Tensor, dict[str, Any]]

nemo_rl.algorithms.grpo_sync.grpo_train_sync(
policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface,
policy_generation: Optional[nemo_rl.models.generation.interfaces.GenerationInterface],
wrapped_dataloader,
val_dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader],
tokenizer,
loss_fn: nemo_rl.algorithms.loss.interfaces.LossFunction,
task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
val_task_to_env: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]],
logger: nemo_rl.utils.logger.Logger,
checkpointer: nemo_rl.utils.checkpoint.CheckpointManager,
grpo_save_state: nemo_rl.algorithms.grpo.GRPOSaveState,
master_config: nemo_rl.algorithms.grpo.MasterConfig,
) -> None

Run GRPO training algorithm — TransferQueue-mediated.

Body mirrors :func:nemo_rl.algorithms.grpo.grpo_train with TQ-mediated Policy methods replacing the in-memory dispatch. The TQ lifecycle (controller bootstrap, worker attach, partition register, fan-out, drain, close) is fully encapsulated in :class:nemo_rl.models.policy.tq_policy.TQPolicy; this trainer just calls policy.prepare_step, policy.get_logprobs, policy.get_reference_policy_logprobs, and policy.train.

Parity with the legacy path is verified by running the same config against both entrypoints and diffing the wandb runs.
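One way to automate such a diff is sketched below, assuming the metric histories of both runs have already been exported as plain mappings from metric name to a per-step list of floats. `diff_metric_histories` is a hypothetical helper and the tolerances are illustrative; the source does not state what threshold counts as parity.

```python
import math


def diff_metric_histories(
    legacy: dict[str, list[float]],
    sync: dict[str, list[float]],
    rel_tol: float = 1e-3,
    abs_tol: float = 1e-6,
) -> list[tuple]:
    """Return (metric, step, legacy_value, sync_value) for every mismatch.

    A metric that is missing from one run, or logged for a different number
    of steps, is reported once with step=None.
    """
    mismatches = []
    for name in sorted(set(legacy) | set(sync)):
        a, b = legacy.get(name), sync.get(name)
        if a is None or b is None or len(a) != len(b):
            mismatches.append((name, None, a, b))
            continue
        for step, (x, y) in enumerate(zip(a, b)):
            if not math.isclose(x, y, rel_tol=rel_tol, abs_tol=abs_tol):
                mismatches.append((name, step, x, y))
    return mismatches
```

An empty result means the two runs agree on every shared metric within the chosen tolerances.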