nemo_rl.experience.rollout_manager
Module Contents
Classes
AsyncRolloutImpl | Manages per-prompt multi-turn rollouts, producing a PromptGroupRecord per call.
AsyncNemoGymRolloutImpl | Manages per-prompt NeMo-Gym rollouts, producing a PromptGroupRecord per call.
RolloutManager | Factory that routes to AsyncRolloutImpl (native async) or AsyncNemoGymRolloutImpl (NeMo-Gym).
Data
TokenizerType
API
- nemo_rl.experience.rollout_manager.TokenizerType
None
- class nemo_rl.experience.rollout_manager.AsyncRolloutImpl(
    tokenizer: nemo_rl.experience.rollout_manager.TokenizerType,
    task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
    num_generations_per_prompt: int,
    max_seq_len: int,
    policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
    max_rollout_turns: int = 999999,
    **kwargs: Any,
)
Manages per-prompt multi-turn rollouts, producing a PromptGroupRecord per call.
Each run_rollout call takes one prompt and returns num_generations_per_prompt completions, which are generated concurrently via asyncio.gather.
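The per-prompt fan-out described above can be sketched with plain asyncio. This is a minimal illustration and not the module's implementation. GroupSketch and fake_single_rollout are hypothetical stand-ins for PromptGroupRecord and _run_single_rollout. The real class also handles tokenization, environments, and metrics.

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class GroupSketch:
    # Hypothetical stand-in for PromptGroupRecord: one prompt plus its completions.
    prompt: str
    completions: list[str] = field(default_factory=list)


async def fake_single_rollout(prompt: str, traj_idx: int) -> str:
    # Hypothetical stand-in for _run_single_rollout. The sleep(0) yields to the
    # event loop so the coroutines actually interleave.
    await asyncio.sleep(0)
    return f"{prompt} -> completion {traj_idx}"


async def run_rollout_sketch(prompt: str, num_generations_per_prompt: int) -> GroupSketch:
    # One coroutine per generation index, awaited together. asyncio.gather
    # returns results in argument order, so completions[i] came from traj_idx i.
    results = await asyncio.gather(
        *(fake_single_rollout(prompt, i) for i in range(num_generations_per_prompt))
    )
    return GroupSketch(prompt=prompt, completions=list(results))


group = asyncio.run(run_rollout_sketch("What is 2 + 2?", 3))
print(group.completions)
```

Because asyncio.gather preserves argument order, completion i always lines up with generation index i, even when the underlying rollouts finish out of order.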
Initialization
- async run_rollout(
    input_sample: nemo_rl.data.interfaces.DatumSpec,
)
Run num_generations_per_prompt rollouts for one prompt.
- Parameters:
input_sample – A single prompt (one DatumSpec entry).
- Returns:
PromptGroupRecord with num_generations_per_prompt completions.
- async _run_single_rollout(
    input_sample: nemo_rl.data.interfaces.DatumSpec,
    traj_idx: int,
)
Run one multi-turn rollout for a single generation index.
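A multi-turn rollout alternates generation and environment steps until some stopping condition holds. The sketch below assumes three such rules: the environment terminates, max_rollout_turns is reached, or the conversation exceeds max_seq_len. It is synchronous for brevity, although the real method is async. StepResult, fake_generate, fake_env_step, and count_tokens are all hypothetical, and the exact termination logic of _run_single_rollout is not documented here.

```python
from dataclasses import dataclass


@dataclass
class StepResult:
    # Hypothetical environment step output.
    observation: str
    terminated: bool


def fake_generate(history: list[dict]) -> dict:
    # Hypothetical policy that always gives the same answer.
    return {"role": "assistant", "content": "answer"}


def fake_env_step(history: list[dict], turn: int) -> StepResult:
    # Hypothetical environment that ends the episode on the third turn.
    return StepResult(observation=f"feedback {turn}", terminated=turn >= 2)


def count_tokens(history: list[dict]) -> int:
    # Crude whitespace "tokenizer"; a real rollout would use the model tokenizer.
    return sum(len(m["content"].split()) for m in history)


def run_single_rollout_sketch(prompt: str, max_rollout_turns: int, max_seq_len: int) -> list[dict]:
    history = [{"role": "user", "content": prompt}]
    for turn in range(max_rollout_turns):
        history.append(fake_generate(history))
        step = fake_env_step(history, turn)
        if step.terminated:
            break  # environment says the episode is over
        history.append({"role": "environment", "content": step.observation})
        if count_tokens(history) >= max_seq_len:
            break  # length budget exhausted
    return history


history = run_single_rollout_sketch("solve x", max_rollout_turns=10, max_seq_len=100)
print([m["role"] for m in history])
```

With AsyncRolloutImpl's default max_rollout_turns of 999999, the turn cap is effectively disabled. In practice, environment termination or the sequence-length budget ends the rollout.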
- async _generate_response(
    message_log: list[dict],
    stop_strings: list[str] | None,
)
Generate a single-turn response for one sample.
- Returns:
Tuple of (assistant_message, input_lengths, gen_metrics).
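The documented return value is a 3-tuple. The hypothetical stand-in below only illustrates that shape and how a caller might unpack it and extend the message log. The whitespace "token" count, the rule of cutting output at the earliest stop string, and the gen_tokens metric key are all assumptions, not the real behavior.

```python
from __future__ import annotations

import asyncio


async def generate_response_sketch(
    message_log: list[dict], stop_strings: list[str] | None
) -> tuple[dict, int, dict]:
    # Hypothetical prompt length: whitespace-separated words across the log.
    input_lengths = sum(len(m["content"].split()) for m in message_log)
    # Hypothetical model output.
    text = "The answer is 4. Done"
    # Assumed rule: truncate at the earliest occurrence of any stop string.
    hits = [text.find(s) for s in (stop_strings or []) if s in text]
    if hits:
        text = text[: min(hits)]
    assistant_message = {"role": "assistant", "content": text}
    gen_metrics = {"gen_tokens": len(text.split())}
    return assistant_message, input_lengths, gen_metrics


log = [{"role": "user", "content": "what is 2 + 2"}]
assistant_message, input_lengths, gen_metrics = asyncio.run(
    generate_response_sketch(log, ["Done"])
)
log.append(assistant_message)
```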
- _aggregate_rollout_metrics(
    completions: list[nemo_rl.experience.interfaces.Completion],
    all_sample_metrics: list[dict],
)
Aggregate per-sample metrics across all completions.
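One plausible aggregation collects each numeric metric across samples and reports summary statistics. The sketch assumes that mean and max are the reductions of interest and that non-numeric values are skipped. The real _aggregate_rollout_metrics may use different reductions, and it also receives the completions list, which this sketch ignores.

```python
from collections import defaultdict
from statistics import mean


def aggregate_metrics_sketch(all_sample_metrics: list[dict]) -> dict[str, float]:
    # Gather every numeric value under its metric name (bools count as 0/1).
    buckets: dict[str, list[float]] = defaultdict(list)
    for sample in all_sample_metrics:
        for key, value in sample.items():
            if isinstance(value, (int, float)):
                buckets[key].append(float(value))
    # Reduce each metric to a mean and a max across samples.
    summary: dict[str, float] = {}
    for key, values in buckets.items():
        summary[f"{key}/mean"] = mean(values)
        summary[f"{key}/max"] = max(values)
    return summary


metrics = aggregate_metrics_sketch(
    [{"turns": 2, "reward": 1.0}, {"turns": 4, "reward": 0.0, "env": "math"}]
)
```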
- class nemo_rl.experience.rollout_manager.AsyncNemoGymRolloutImpl(
    tokenizer: nemo_rl.experience.rollout_manager.TokenizerType,
    task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
    num_generations_per_prompt: int,
    max_seq_len: int,
    generation_config: nemo_rl.models.generation.interfaces.GenerationConfig,
    max_rollout_turns: Optional[int] = None,
    **kwargs: Any,
)
Manages per-prompt NeMo-Gym rollouts, producing a PromptGroupRecord per call.
Each run_rollout call takes one prompt and returns num_generations_per_prompt completions, batched through a single NeMo-Gym run_rollouts call.
Initialization
- async run_rollout(
    input_sample: nemo_rl.data.interfaces.DatumSpec,
)
Run num_generations_per_prompt rollouts for one prompt.
- Parameters:
input_sample – A single prompt (one DatumSpec entry).
- Returns:
PromptGroupRecord with num_generations_per_prompt completions.
- _validate_init_params() → None
Validate initialization parameters.
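The specific checks performed by _validate_init_params are not listed here. The following hypothetical sketch shows the kind of validation the constructor arguments suggest: positive counts, and max_rollout_turns either None (its default) or positive.

```python
from typing import Optional


def validate_init_params_sketch(
    num_generations_per_prompt: int,
    max_seq_len: int,
    max_rollout_turns: Optional[int] = None,
) -> None:
    # Hypothetical checks inferred from the constructor arguments.
    if num_generations_per_prompt < 1:
        raise ValueError(f"num_generations_per_prompt must be >= 1, got {num_generations_per_prompt}")
    if max_seq_len < 1:
        raise ValueError(f"max_seq_len must be >= 1, got {max_seq_len}")
    # None means "no explicit turn limit"; otherwise require a positive limit.
    if max_rollout_turns is not None and max_rollout_turns < 1:
        raise ValueError(f"max_rollout_turns must be None or >= 1, got {max_rollout_turns}")


validate_init_params_sketch(num_generations_per_prompt=8, max_seq_len=4096)
```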
- _build_inputs(
    input_sample: nemo_rl.data.interfaces.DatumSpec,
)
Build num_generations_per_prompt row dicts (one per generation) from input_sample, applying the generation config parameters to each row.
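Expanding one prompt into N independent rows might look like the following. The forwarded sampling keys (temperature, top_p, max_new_tokens) and the generation_idx field are assumptions for illustration. In this sketch, input_sample and generation_config are plain dicts rather than DatumSpec and GenerationConfig.

```python
import copy


def build_inputs_sketch(
    input_sample: dict, num_generations_per_prompt: int, generation_config: dict
) -> list[dict]:
    # Assumed subset of sampling parameters to forward; other config keys are dropped.
    sampling_keys = ("temperature", "top_p", "max_new_tokens")
    params = {k: generation_config[k] for k in sampling_keys if k in generation_config}
    rows = []
    for i in range(num_generations_per_prompt):
        # Deep copy so mutating one row never leaks into another.
        row = copy.deepcopy(input_sample)
        row.update(params)
        row["generation_idx"] = i  # hypothetical bookkeeping field
        rows.append(row)
    return rows


rows = build_inputs_sketch(
    {"prompt": "What is 2 + 2?", "meta": {"id": 7}},
    num_generations_per_prompt=4,
    generation_config={"temperature": 0.7, "top_p": 0.95, "backend": "vllm"},
)
```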
- async _run_rollouts(
    inputs: list[dict],
    timer: nemo_rl.utils.timer.Timer,
    timer_prefix: str,
)
Dispatch the rows to NeMo-Gym and return the resulting completions and metrics.
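Unlike AsyncRolloutImpl, which awaits N concurrent coroutines, this class sends all N rows to NeMo-Gym in one batched call. The sketch rests on three assumptions. TimerSketch is a hypothetical timer and does not reproduce the nemo_rl.utils.timer.Timer API. fake_gym_run_rollouts stands in for the NeMo-Gym client. The "{timer_prefix}/run_rollouts" key format is invented for illustration.

```python
import asyncio
import time
from contextlib import contextmanager
from typing import Iterator


class TimerSketch:
    # Hypothetical timer; it does not reproduce nemo_rl.utils.timer.Timer.
    def __init__(self) -> None:
        self.durations: dict[str, float] = {}

    @contextmanager
    def measure(self, name: str) -> Iterator[None]:
        start = time.perf_counter()
        try:
            yield
        finally:
            self.durations[name] = time.perf_counter() - start


async def fake_gym_run_rollouts(rows: list[dict]) -> list[dict]:
    # Hypothetical NeMo-Gym client: a single call processes the whole batch.
    await asyncio.sleep(0)
    return [{"row": row, "reward": float(i % 2)} for i, row in enumerate(rows)]


async def run_rollouts_sketch(
    inputs: list[dict], timer: TimerSketch, timer_prefix: str
) -> tuple[list[dict], dict[str, float]]:
    # All N rows are dispatched together and timed under a prefixed key.
    with timer.measure(f"{timer_prefix}/run_rollouts"):
        results = await fake_gym_run_rollouts(inputs)
    rewards = [r["reward"] for r in results]
    metrics = {"mean_reward": sum(rewards) / len(rewards)}
    return results, metrics


timer = TimerSketch()
results, metrics = asyncio.run(run_rollouts_sketch([{"prompt": "hi"}] * 4, timer, "nemo_gym"))
```

Prefixing the timer key lets several callers share one timer without their measurements overwriting each other.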
- _result_to_completion(
    result: dict,
)
Convert one run_rollouts result dict into a Completion.
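Converting a raw result dict into a typed record can be sketched with a dataclass. The result keys (messages, reward, truncated) and the CompletionSketch fields are hypothetical. The actual fields are defined by nemo_rl.experience.interfaces.Completion.

```python
from dataclasses import dataclass


@dataclass
class CompletionSketch:
    # Hypothetical fields; not the real nemo_rl.experience.interfaces.Completion.
    message_log: list[dict]
    reward: float
    truncated: bool


def result_to_completion_sketch(result: dict) -> CompletionSketch:
    # Assumed result schema; missing keys fall back to neutral defaults.
    return CompletionSketch(
        message_log=list(result.get("messages", [])),
        reward=float(result.get("reward", 0.0)),
        truncated=bool(result.get("truncated", False)),
    )


completion = result_to_completion_sketch(
    {"messages": [{"role": "assistant", "content": "4"}], "reward": 1}
)
```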
- _compute_rollout_metrics(
    completions: list[nemo_rl.experience.interfaces.Completion],
    agent_name: str,
)
Aggregate per-sample and per-agent metrics.
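The following hypothetical sketch combines per-sample and per-agent metrics. Batch-level statistics are computed once and then re-emitted under an agent_name prefix. The completion dict shape and the metric names are assumptions.

```python
from statistics import mean


def compute_rollout_metrics_sketch(completions: list[dict], agent_name: str) -> dict[str, float]:
    rewards = [float(c["reward"]) for c in completions]
    lengths = [len(c["message_log"]) for c in completions]
    # Statistics over every sample in this call.
    per_sample = {
        "mean_reward": mean(rewards),
        "success_rate": sum(r > 0 for r in rewards) / len(rewards),
        "mean_num_messages": mean(lengths),
    }
    # Same values under an agent namespace so several agents can share one log.
    per_agent = {f"{agent_name}/{k}": v for k, v in per_sample.items()}
    return {**per_sample, **per_agent}


metrics = compute_rollout_metrics_sketch(
    [{"reward": 1.0, "message_log": [1, 2]}, {"reward": 0.0, "message_log": [1, 2, 3, 4]}],
    agent_name="math_agent",
)
```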
- class nemo_rl.experience.rollout_manager.RolloutManager(
    tokenizer: nemo_rl.experience.rollout_manager.TokenizerType,
    task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
    num_generations_per_prompt: int,
    max_seq_len: int,
    max_rollout_turns: Optional[int] = None,
    policy_generation: Optional[nemo_rl.models.generation.interfaces.GenerationInterface] = None,
    generation_config: Optional[nemo_rl.models.generation.interfaces.GenerationConfig] = None,
    use_nemo_gym: bool = False,
)
Factory that routes to AsyncRolloutImpl (native async) or AsyncNemoGymRolloutImpl (NeMo-Gym).
Initialization
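The routing can be sketched as a small factory. Which argument each branch needs is inferred from the constructor signatures: AsyncRolloutImpl takes policy_generation, and AsyncNemoGymRolloutImpl takes generation_config. The explicit ValueError checks and the *Sketch class names are assumptions, not RolloutManager's documented behavior.

```python
from typing import Any, Optional


class NativeImplSketch:
    # Hypothetical stand-in for AsyncRolloutImpl.
    def __init__(self, policy_generation: Any, **kwargs: Any) -> None:
        self.policy_generation = policy_generation
        self.kwargs = kwargs


class GymImplSketch:
    # Hypothetical stand-in for AsyncNemoGymRolloutImpl.
    def __init__(self, generation_config: dict, **kwargs: Any) -> None:
        self.generation_config = generation_config
        self.kwargs = kwargs


def make_rollout_impl_sketch(
    *,
    use_nemo_gym: bool = False,
    policy_generation: Optional[Any] = None,
    generation_config: Optional[dict] = None,
    **common: Any,
):
    # Shared arguments (tokenizer, num_generations_per_prompt, ...) pass through unchanged.
    if use_nemo_gym:
        if generation_config is None:
            raise ValueError("generation_config is required when use_nemo_gym=True")
        return GymImplSketch(generation_config=generation_config, **common)
    if policy_generation is None:
        raise ValueError("policy_generation is required when use_nemo_gym=False")
    return NativeImplSketch(policy_generation=policy_generation, **common)


impl = make_rollout_impl_sketch(
    use_nemo_gym=True, generation_config={"temperature": 1.0}, num_generations_per_prompt=4
)
```

Keeping the selection in one factory lets callers program against a single run_rollout entry point, whichever backend is active.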
- async run_rollout(
- input_sample: nemo_rl.data.interfaces.DatumSpec,