nemo_rl.experience.sync_rollout_actor
Sync GRPO rollout actor (sibling of async_utils).
Houses SyncRolloutActor, the Ray actor that owns the multi-turn
rollout loop AND the post-rollout flatten / mask / prompt extraction /
reward shaping / baseline-std for a sync GRPO step. The driver dispatches
a per-step prompt batch plus uids; the actor runs run_multi_turn_rollout
(or the async / nemo_gym variants), then writes the bulk schema to TQ via
nemo_rl.data_plane.column_io.kv_first_write. Only a KVBatchMeta and a small per-sample driver_carry dict (rewards, masks, lengths, baseline/std, prompt_ids_for_adv) cross back to the driver via Ray.
Goal: a one-hop rollout put. The bulk tensors (input_ids, output_ids,
attention_mask, position_ids, multi_modal_inputs, generation_logprobs,
token_mask) stay on the actor until put_samples and then live only in
TQ, so they go from the actor straight into TQ and the driver never holds
these bytes between the end of the rollout and the training fan-out.
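The split can be sketched in plain Python, assuming the post-rollout output is a flat dict of per-sample fields. split_rollout_payload and the dict-based tq_store are hypothetical stand-ins for kv_first_write and TQ (not their real signatures), and the returned meta dict only plays the role of a KVBatchMeta.

```python
from typing import Any

# Bulk fields listed in the module docstring; in this sketch they never
# go back to the caller.
BULK_KEYS = (
    "input_ids",
    "output_ids",
    "attention_mask",
    "position_ids",
    "multi_modal_inputs",
    "generation_logprobs",
    "token_mask",
)


def split_rollout_payload(
    payload: dict[str, Any], tq_store: dict[str, Any], partition_id: str
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Hypothetical: write bulk fields to the store, return (meta, driver_carry)."""
    bulk = {k: v for k, v in payload.items() if k in BULK_KEYS}
    carry = {k: v for k, v in payload.items() if k not in BULK_KEYS}
    # One hop: the actor writes bulk data straight into the store.
    tq_store[partition_id] = bulk
    # Only a small metadata record (standing in for KVBatchMeta) and the
    # per-sample carry are handed back.
    meta = {"partition_id": partition_id, "fields": sorted(bulk)}
    return meta, carry
```

The design point is that the returned values are small and their size does not grow with sequence length, while everything sequence-shaped stays behind in the store.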
The actor is the sync counterpart to
nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector. It intentionally does not buffer or stream: sync GRPO consumes the whole step batch in one call.
Module Contents
Classes
SyncRolloutActor | Per-step rollout dispatcher.
Functions
_flatten_rollout_message_log_for_tq | Prepare rollout message logs for the TQ payload and driver carry.
Data
OPT_IN_CARRY_KEYS
API
- nemo_rl.experience.sync_rollout_actor.OPT_IN_CARRY_KEYS: tuple[str, ...]
('turn_roles', 'turn_contents')
- nemo_rl.experience.sync_rollout_actor._flatten_rollout_message_log_for_tq(
- message_logs: list[Any],
- prompt_lengths: torch.Tensor,
- *,
- pad_token_id: int,
- make_sequence_length_divisible_by: int,
)
Prepare rollout message logs for the TQ payload and driver carry.
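A pure-Python sketch of what flattening can look like, assuming each message log is a list of per-turn dicts with "role" and "token_ids" entries and that only assistant tokens count toward the token mask. flatten_message_logs and its output keys are hypothetical; the real helper takes prompt_lengths as a torch.Tensor and may differ in both layout and outputs.

```python
from typing import Any


def flatten_message_logs(
    message_logs: list[list[dict[str, Any]]],
    prompt_lengths: list[int],
    *,
    pad_token_id: int,
    make_sequence_length_divisible_by: int,
) -> dict[str, list]:
    """Hypothetical: concatenate per-turn token ids into right-padded rows."""
    flat_ids: list[list[int]] = []
    token_mask: list[list[int]] = []
    for log in message_logs:
        ids: list[int] = []
        mask: list[int] = []
        for msg in log:
            ids.extend(msg["token_ids"])
            # Assumption: only assistant-generated tokens are marked trainable.
            mask.extend([int(msg["role"] == "assistant")] * len(msg["token_ids"]))
        flat_ids.append(ids)
        token_mask.append(mask)

    lengths = [len(ids) for ids in flat_ids]
    div = make_sequence_length_divisible_by
    # Round the longest row up to the next multiple of `div` (ceil division).
    padded_len = -(-max(lengths, default=0) // div) * div
    input_ids = [ids + [pad_token_id] * (padded_len - len(ids)) for ids in flat_ids]
    token_mask = [m + [0] * (padded_len - len(m)) for m in token_mask]
    # Prompt-only ids are the leading prompt_lengths[i] tokens of each row.
    prompt_ids = [ids[:n] for ids, n in zip(flat_ids, prompt_lengths)]
    return {
        "input_ids": input_ids,
        "input_lengths": lengths,
        "token_mask": token_mask,
        "prompt_ids": prompt_ids,
    }
```

Padding to a multiple rather than to the exact maximum keeps tensor shapes friendly to kernels or parallelism schemes that require divisible sequence lengths.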
- class nemo_rl.experience.sync_rollout_actor.SyncRolloutActor(
- policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
- tokenizer: Any,
- task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
- master_config: Any,
- dp_cfg: dict[str, Any],
)
Per-step rollout dispatcher.
Runs rollout + flatten + mask + prompt extraction + baseline/std + TQ put, and returns
(meta, driver_carry, rollout_metrics, gen_metrics).
Lifecycle: one instance per grpo_train_sync invocation. The driver instantiates it with the same handles it would normally pass to run_multi_turn_rollout, plus the data-plane config so the actor can attach as a TQ client (bootstrap=False; the controller is bootstrapped on the driver via TQPolicy).
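The lifecycle can be sketched from the driver's side with an in-process stand-in for the actor (no Ray). RecordingRolloutActor, run_driver, and the partition_id strings are hypothetical; the keyword arguments follow the rollout_to_tq signature and parameter descriptions on this page, and calling shutdown() once at the end is an assumption, since the page does not describe it.

```python
from typing import Any, Optional


class RecordingRolloutActor:
    """Hypothetical in-process stand-in for SyncRolloutActor (no Ray)."""

    def __init__(self) -> None:
        self.calls: list[dict[str, Any]] = []
        self.closed = False

    def rollout_to_tq(
        self,
        input_batch: Any,
        *,
        partition_id: str,
        group_size: int = 1,
        first_iter: bool = True,
        finish_generation: bool = True,
        task_to_env_override: Optional[dict[str, Any]] = None,
        carry_keys: Optional[list[str]] = None,
    ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any], dict[str, Any]]:
        # Record the arguments instead of running a rollout.
        self.calls.append(
            {
                "partition_id": partition_id,
                "group_size": group_size,
                "finish_generation": finish_generation,
                "carry_keys": carry_keys,
            }
        )
        return {"partition_id": partition_id}, {}, {}, {}

    def shutdown(self) -> None:
        self.closed = True


def run_driver(actor, train_batches, val_batches, val_task_to_env, num_generations_per_prompt):
    # One actor instance serves the whole invocation.
    for step, batch in enumerate(train_batches):
        # Without dynamic sampling, a step is a single call (first_iter=True).
        meta, driver_carry, rollout_metrics, gen_metrics = actor.rollout_to_tq(
            batch,
            partition_id=f"train_{step}",
            group_size=num_generations_per_prompt,
        )
    for i, batch in enumerate(val_batches):
        # Validation keeps inference state alive across batches, uses the
        # val env set, and asks for a slim carry.
        actor.rollout_to_tq(
            batch,
            partition_id=f"val_{i}",
            group_size=1,
            finish_generation=False,
            task_to_env_override=val_task_to_env,
            carry_keys=["total_reward"],
        )
    # The trainer would call policy_generation.finish_generation() here
    # to end the val pass (not modeled in this sketch).
    actor.shutdown()
```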
- rollout_to_tq(
- input_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- *,
- partition_id: str,
- group_size: int = 1,
- first_iter: bool = True,
- finish_generation: bool = True,
- task_to_env_override: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]] = None,
- carry_keys: Optional[list[str]] = None,
)
Run the full per-step generation cycle and write bulk data to TQ.
Bundles six steps into one Ray round-trip so the driver only sees a single RPC instead of separate calls for each:
1. Reset metrics: policy_generation.clear_logger_metrics() clears the per-step generation accumulators before the rollout.
2. Rollout: runs run_multi_turn_rollout (or the async / nemo-gym variants) to produce final_batch.
3. Flatten + mask + prompt extraction: converts the message_log layout to flat tensors and builds the token mask, sample mask, prompt-only ids, and baseline/std.
4. Write bulk to TQ: kv_first_write puts every tensor field in one flat put_samples call; the driver never touches bulk bytes.
5. Release GPU (when finish_generation=True): policy_generation.finish_generation() frees the KV cache and inference state so the trainer can use the GPU immediately.
6. Capture metrics: policy_generation.get_logger_metrics() collects generation stats (throughput, etc.) and returns them to the driver in the result tuple.
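The ordering of the six steps can be sketched with injected callables and a stub generation handle. StubGeneration, rollout_fn, flatten_fn, write_fn, and rollout_to_tq_sketch are hypothetical; only the three policy_generation method names come from this page, and the sketch simply follows the order in which the steps are listed. The finish_generation switch mirrors the parameter of the same name.

```python
from typing import Any, Callable


class StubGeneration:
    """Hypothetical generation handle that logs which methods were called."""

    def __init__(self) -> None:
        self.log: list[str] = []

    def clear_logger_metrics(self) -> None:
        self.log.append("clear_logger_metrics")

    def finish_generation(self) -> None:
        self.log.append("finish_generation")

    def get_logger_metrics(self) -> dict[str, float]:
        self.log.append("get_logger_metrics")
        return {"tokens_per_sec": 0.0}


def rollout_to_tq_sketch(
    gen: StubGeneration,
    input_batch: Any,
    rollout_fn: Callable[[Any], tuple[Any, dict]],
    flatten_fn: Callable[[Any], tuple[dict, dict]],
    write_fn: Callable[[dict], Any],
    *,
    finish_generation: bool = True,
) -> tuple[Any, dict, dict, dict]:
    # 1. Reset per-step generation metrics.
    gen.clear_logger_metrics()
    # 2. Rollout.
    final_batch, rollout_metrics = rollout_fn(input_batch)
    # 3. Flatten, masks, prompt ids, baseline/std.
    bulk, driver_carry = flatten_fn(final_batch)
    # 4. Single bulk write; only `meta` comes back.
    meta = write_fn(bulk)
    # 5. Release KV cache / inference state unless the caller keeps it.
    if finish_generation:
        gen.finish_generation()
    # 6. Collect generation stats for the result tuple.
    gen_metrics = gen.get_logger_metrics()
    return meta, driver_carry, rollout_metrics, gen_metrics
```

Because all six steps run inside one method, a remote caller pays for a single round-trip per step instead of one per step listed.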
The driver receives (meta, driver_carry, rollout_metrics, generation_logger_metrics) and uses driver_carry for its own per-row compute (rewards, advantages, dynamic sampling).
- Parameters:
input_batch – Per-step prompt batch (already repeat-interleaved).
partition_id – TQ partition target.
group_size – Rollouts per original prompt. One uid is minted per prompt; bulk keys are f"{uid}_g{i}", where i ranges over the per-prompt expansion (group × rollout turns). Training passes num_generations_per_prompt; validation passes 1.
first_iter – True on the first DS (dynamic sampling) iteration of a step; drives policy_generation.snapshot_step_metrics() so per-step metrics align with the legacy grpo.grpo_train path.
finish_generation – Call policy_generation.finish_generation() at the tail. The default True matches the training step (one rollout per step, release KV after). Validation sets False so inference state survives across val batches; the trainer owns the explicit finish_generation() call at the end of the val pass.
task_to_env_override – Per-call task → env map. None uses self.task_to_env (the training envs supplied at construction). Validation passes val_task_to_env here so val rollouts run against the val env set without rebuilding the actor.
carry_keys – Names of per-row tensors to return in driver_carry. None returns every available key (training uses this). Validation passes a slim list (e.g. ["total_reward"]) to avoid wasting Ray transfer on fields it doesn't consume.
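The group_size key layout and carry_keys filtering can be sketched as follows. mint_bulk_keys and select_carry are hypothetical; uuid4 is only a stand-in for however uids are really minted, the number of expanded rows per prompt is passed in explicitly (it may exceed group_size when rollout turns add rows), and skipping unknown carry_keys names rather than raising is an assumption of the sketch.

```python
import uuid
from typing import Any, Optional


def mint_bulk_keys(num_prompts: int, rows_per_prompt: int) -> list[str]:
    """Hypothetical: one uid per original prompt, one key per expanded row."""
    keys: list[str] = []
    for _ in range(num_prompts):
        uid = uuid.uuid4().hex  # stand-in for the real uid scheme
        keys.extend(f"{uid}_g{i}" for i in range(rows_per_prompt))
    return keys


def select_carry(carry: dict[str, Any], carry_keys: Optional[list[str]]) -> dict[str, Any]:
    """None keeps every available key; a list keeps only the named ones."""
    if carry_keys is None:
        return dict(carry)
    # Assumption: names not present in the carry are skipped silently.
    return {k: carry[k] for k in carry_keys if k in carry}
```

Sharing one uid across a prompt's rows lets the driver regroup rows by prompt from the keys alone.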
- Returns:
(meta, driver_carry, rollout_metrics, generation_logger_metrics), where driver_carry is a per-row dict of tensors the driver uses for compute (rewards, masks, lengths, prompt_ids_for_adv, …). Once returned, driver_carry stays on the driver and is never forwarded across another actor boundary.
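Given such a carry, the per-row math can be sketched as below, assuming rows are grouped contiguously by prompt (consistent with the repeat-interleaved input batch) and that the carry exposes per-row "total_reward", "baseline", and "std" entries. The key names "baseline" and "std", both helpers, the plain group mean, the population std, and the eps value are assumptions of the sketch; the actor's actual baseline (for example a leave-one-out variant) and the driver's dynamic-sampling rule may differ. Per the step list, baseline/std is produced actor-side, while the advantage math runs on the driver.

```python
import statistics


def group_baseline_std(rewards: list[float], group_size: int) -> tuple[list[float], list[float]]:
    """Per-row group mean and population std over contiguous groups."""
    baseline: list[float] = []
    std: list[float] = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mu = statistics.fmean(group)
        sigma = statistics.pstdev(group)
        baseline.extend([mu] * len(group))
        std.extend([sigma] * len(group))
    return baseline, std


def advantages_from_carry(
    carry: dict[str, list[float]], eps: float = 1e-6
) -> tuple[list[float], list[bool]]:
    """Std-normalized advantages plus a keep mask for groups with reward spread."""
    advantages = [
        (r - b) / (s + eps)
        for r, b, s in zip(carry["total_reward"], carry["baseline"], carry["std"])
    ]
    # Dynamic-sampling style filter: a group whose rewards are all equal has
    # std == 0 and yields zero advantage, so its rows are dropped.
    keep = [s > 0 for s in carry["std"]]
    return advantages, keep
```

With rewards [1.0, 0.0, 1.0, 1.0] and groups of two, the first group gets advantages close to +1 and -1, while the second group has zero spread and is filtered out.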
- shutdown() → None