nemo_rl.experience.sync_rollout_actor#

Sync GRPO rollout actor — sibling of async_utils.

Houses SyncRolloutActor, the Ray actor that owns the multi-turn rollout loop AND the post-rollout flatten / mask / prompt extraction / reward shaping / baseline-std for a sync GRPO step. The driver dispatches a per-step prompt batch + uids; the actor runs run_multi_turn_rollout (or async / nemo_gym variants), then writes the bulk schema to TQ via nemo_rl.data_plane.column_io.kv_first_write. Only a KVBatchMeta and a small per-sample driver_carry dict (rewards, masks, lengths, baseline/std, prompt_ids_for_adv) cross back to the driver via Ray.

Goal: a one-hop put for rollout data. Bulk tensors (input_ids, output_ids, attention_mask, position_ids, multi_modal_inputs, generation_logprobs, token_mask) stay actor-side until put_samples, then live only in TQ. The driver never holds these bytes between the end of the rollout and the train fan-out.

The actor is the sync counterpart to nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector. It intentionally does not buffer or stream, because sync GRPO consumes the whole step batch in one call.

Module Contents#

Classes#

SyncRolloutActor

Per-step rollout dispatcher.

Functions#

_flatten_rollout_message_log_for_tq

Prepare rollout message logs for the TQ payload and driver carry.

Data#

API#

nemo_rl.experience.sync_rollout_actor.OPT_IN_CARRY_KEYS: tuple[str, ...]#

('turn_roles', 'turn_contents')

nemo_rl.experience.sync_rollout_actor._flatten_rollout_message_log_for_tq(
message_logs: list[Any],
prompt_lengths: torch.Tensor,
*,
pad_token_id: int,
make_sequence_length_divisible_by: int,
) tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any], torch.Tensor, nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]]#

Prepare rollout message logs for the TQ payload and driver carry.
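The pad_token_id and make_sequence_length_divisible_by parameters suggest that flattened rows are right-padded to a common length that is a multiple of the divisor. The following is a minimal sketch of that idea only, using plain Python lists instead of tensors. The helper name pad_to_multiple and the right-padding convention are assumptions for illustration, not the library's implementation.

```python
def pad_to_multiple(
    rows: list[list[int]], pad_token_id: int, divisible_by: int
) -> tuple[list[list[int]], list[int]]:
    """Right-pad every row to a shared length that is a multiple of divisible_by.

    Hypothetical helper: illustrates the rounding arithmetic only.
    """
    lengths = [len(r) for r in rows]
    longest = max(lengths)
    # Round the longest row length up to the next multiple of divisible_by.
    target = -(-longest // divisible_by) * divisible_by
    padded = [r + [pad_token_id] * (target - len(r)) for r in rows]
    return padded, lengths
```

The unpadded lengths are returned alongside the padded rows because masks and per-sample lengths are derived from them rather than from the padded shape.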

class nemo_rl.experience.sync_rollout_actor.SyncRolloutActor(
policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
tokenizer: Any,
task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
master_config: Any,
dp_cfg: dict[str, Any],
)#

Per-step rollout dispatcher.

Runs: rollout + flatten + mask + prompt extraction + baseline/std + TQ put. Returns (meta, driver_carry, rollout_metrics, gen_metrics).

Lifecycle: one instance per grpo_train_sync invocation. The driver instantiates with the same handles it would normally pass to run_multi_turn_rollout plus the data-plane config so the actor can attach as a TQ client (bootstrap=False — controller is bootstrapped on the driver via TQPolicy).

Initialization

rollout_to_tq(
input_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
*,
partition_id: str,
group_size: int = 1,
first_iter: bool = True,
finish_generation: bool = True,
task_to_env_override: Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]] = None,
carry_keys: Optional[list[str]] = None,
) tuple[nemo_rl.data_plane.interfaces.KVBatchMeta, dict[str, Any], dict[str, Any], Optional[dict[str, Any]]]#

Run the full per-step generation cycle and write bulk data to TQ.

Bundles six steps into one Ray round-trip so the driver only sees a single RPC instead of separate calls for each:

  1. Reset metrics: policy_generation.clear_logger_metrics() clears per-step generation accumulators before the rollout.

  2. Rollout: runs run_multi_turn_rollout (or the async / nemo-gym variants) to produce final_batch.

  3. Flatten + mask + prompt extraction: converts message_log layout to flat tensors; builds token mask, sample mask, prompt-only ids, baseline/std.

  4. Write bulk to TQ: kv_first_write puts every tensor field in one flat put_samples; the driver never touches bulk bytes.

  5. Release GPU: policy_generation.finish_generation() frees KV cache and inference state so the trainer can use the GPU immediately.

  6. Capture metrics: policy_generation.get_logger_metrics() collects generation stats (throughput, etc.) and returns them to the driver in the result tuple.

The driver receives (meta, driver_carry, rollout_metrics, generation_logger_metrics) and uses driver_carry for its own per-row compute (rewards, advantages, dynamic sampling).
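The six steps above can be sketched with stand-in objects to show the call order and the split between bulk data and driver_carry. Everything here is hypothetical scaffolding: StubGeneration, rollout_to_tq_sketch, the dict used as a store, and the fake rollout are illustrative assumptions. Only the three policy_generation method names come from the description above.

```python
from typing import Any


class StubGeneration:
    """Hypothetical stand-in for policy_generation; records call order."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def clear_logger_metrics(self) -> None:
        self.calls.append("clear_logger_metrics")

    def finish_generation(self) -> None:
        self.calls.append("finish_generation")

    def get_logger_metrics(self) -> dict[str, Any]:
        self.calls.append("get_logger_metrics")
        return {"tokens_per_sec": 0.0}  # placeholder value


def rollout_to_tq_sketch(
    gen: StubGeneration,
    prompts: list[list[int]],
    store: dict[str, dict[str, Any]],
    *,
    partition_id: str,
    finish_generation: bool = True,
):
    # 1. Reset per-step generation metrics.
    gen.clear_logger_metrics()
    # 2. Rollout (faked): append one token and assign a toy reward.
    rollouts = [
        {"input_ids": p + [0], "total_reward": float(len(p))} for p in prompts
    ]
    # 3. Flatten: separate bulk fields from small per-row carry values.
    bulk = {f"row{i}": r["input_ids"] for i, r in enumerate(rollouts)}
    driver_carry = {"total_reward": [r["total_reward"] for r in rollouts]}
    # 4. Write bulk to the store (a dict standing in for TQ).
    store.setdefault(partition_id, {}).update(bulk)
    meta = {"partition_id": partition_id, "keys": list(bulk)}
    # 5. Release inference state unless the caller wants it kept alive.
    if finish_generation:
        gen.finish_generation()
    # 6. Capture generation metrics for the result tuple.
    gen_metrics = gen.get_logger_metrics()
    rollout_metrics = {"num_rollouts": len(rollouts)}
    return meta, driver_carry, rollout_metrics, gen_metrics
```

In this sketch the returned tuple carries only key metadata and small per-row values; the token lists never appear in it, which mirrors the intent that the driver does not handle bulk bytes.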

Parameters:
  • input_batch – Per-step prompt batch (already repeat-interleaved).

  • partition_id – TQ partition target.

  • group_size – Rollouts per original prompt. One uid is minted per prompt; bulk keys are f"{uid}_g{i}" where i ranges over the per-prompt expansion (group × rollout turns). Train passes num_generations_per_prompt; val passes 1.

  • first_iter – True on the first DS iteration of a step; drives policy_generation.snapshot_step_metrics() so per-step metrics align with the legacy grpo.grpo_train path.

  • finish_generation – Call policy_generation.finish_generation() at the tail. Default True matches the training step (one rollout per step, release KV after). Validation sets False so inference state survives across val batches; the trainer owns the explicit finish_generation() call at the end of the val pass.

  • task_to_env_override – Per-call task → env map. None uses self.task_to_env (training envs supplied at construction). Validation passes val_task_to_env here so val rollouts run against the val env set without rebuilding the actor.

  • carry_keys – Names of per-row tensors to return in driver_carry. None returns every available key (training uses this). Validation passes a slim list (e.g. ["total_reward"]) to avoid wasting Ray transfer on fields it doesn’t consume.
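The bulk key scheme described under group_size can be illustrated as follows. This is a sketch under assumptions: mint_bulk_keys is a hypothetical helper, uuid4 is used only as an example uid source, and the expansion count is passed in directly rather than derived from group size and rollout turns.

```python
import uuid


def mint_bulk_keys(num_prompts: int, expansion: int) -> list[str]:
    """Mint one uid per prompt and derive f"{uid}_g{i}" keys from it.

    Hypothetical helper: the real uid source and expansion rule may differ.
    """
    keys: list[str] = []
    for _ in range(num_prompts):
        uid = uuid.uuid4().hex  # one uid per original prompt
        keys.extend(f"{uid}_g{i}" for i in range(expansion))
    return keys
```

Sharing the uid prefix across a prompt's rows keeps every rollout of that prompt identifiable as one group, which per-group statistics such as baseline/std rely on.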

Returns:

(meta, driver_carry, rollout_metrics, generation_logger_metrics), where driver_carry is a per-row dict of tensors the driver uses for compute (rewards, masks, lengths, prompt_ids_for_adv, …). Once returned, driver_carry stays on the driver and does not cross any further actor boundary.
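The effect of carry_keys on the returned driver_carry can be sketched as a simple dict filter. The helper name select_carry is hypothetical, and silently skipping requested keys that are not available is an assumption of this sketch, not documented behavior.

```python
from typing import Any, Optional


def select_carry(
    available: dict[str, Any], carry_keys: Optional[list[str]] = None
) -> dict[str, Any]:
    """Return every carry field when carry_keys is None, else only the named ones."""
    if carry_keys is None:
        return dict(available)
    # Assumption: unknown names are skipped rather than raising.
    return {k: available[k] for k in carry_keys if k in available}
```

A validation pass that only needs rewards would request ["total_reward"], so masks and lengths are not serialized back over Ray.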

shutdown() None#