Source code for nemo_gym.rollout_collection

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
from asyncio import Future, Semaphore
from collections import Counter
from contextlib import nullcontext
from copy import deepcopy
from itertools import count, repeat
from pathlib import Path
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union

import orjson
from omegaconf import OmegaConf
from pydantic import BaseModel, Field
from tqdm.asyncio import tqdm
from wandb import Table

from nemo_gym import PARENT_DIR
from nemo_gym.base_resources_server import AggregateMetrics, AggregateMetricsRequest
from nemo_gym.config_types import BaseNeMoGymCLIConfig, BaseServerConfig
from nemo_gym.global_config import (
    AGENT_REF_KEY_NAME,
    RESPONSES_CREATE_PARAMS_KEY_NAME,
    ROLLOUT_INDEX_KEY_NAME,
    TASK_INDEX_KEY_NAME,
    get_wandb_run,
)
from nemo_gym.prompt import apply_prompt_to_row, load_prompt_config, validate_prompt_compatibility
from nemo_gym.server_utils import (
    GlobalAIOHTTPAsyncClientConfig,
    ServerClient,
    get_global_config_dict,
    get_response_json,
    is_global_aiohttp_client_setup,
    raise_for_status,
    set_global_aiohttp_client,
)


class SharedRolloutCollectionConfig(BaseNeMoGymCLIConfig):
    """Options shared by every rollout-collection configuration."""

    # Destination JSONL for collected rollouts (required, no default).
    output_jsonl_fpath: str = Field(description="The output data jsonl file path.")

    # Concurrency cap; left as None, no cap is applied.
    num_samples_in_parallel: Optional[int] = Field(
        description="Limit the number of concurrent samples running at once.",
        default=None,
    )

    # Shallow-merged over each row's responses_create_params.
    responses_create_params: Dict[str, Any] = Field(
        description="Overrides for the responses_create_params e.g. temperature, max_output_tokens, etc.",
        default_factory=dict,
    )

    upload_rollouts_to_wandb: bool = Field(
        description="Upload the rollouts to W&B. Sometimes this should be off because the rollouts are massive. Default: True",
        default=True,
    )
class E2ERolloutCollectionConfig(SharedRolloutCollectionConfig):
    r"""
    Spin up all necessary servers and perform a batch of rollout collection using each dataset inside the provided configs.

    Examples:

    ```bash
    ng_collect_rollouts \
        +output_jsonl_fpath=weather_rollouts.jsonl \
        +num_samples_in_parallel=10
    ```
    """

    # Presumably selects which dataset split of the configs to collect on — TODO confirm with the caller.
    split: Union[Literal["train"], Literal["validation"], Literal["benchmark"]]

    # NOTE(review): consumed outside this module; its exact semantics are not visible here.
    reuse_existing_data_preparation: bool = False
class RolloutCollectionConfig(SharedRolloutCollectionConfig):
    r"""
    Perform a batch of rollout collection.

    Examples:

    ```bash
    ng_collect_rollouts \
        +agent_name=example_single_tool_call_simple_agent \
        +input_jsonl_fpath=weather_query.jsonl \
        +output_jsonl_fpath=weather_rollouts.jsonl \
        +limit=100 \
        +num_repeats=4 \
        +num_samples_in_parallel=10
    ```
    """

    # Fallback agent for rows that carry no agent ref of their own.
    agent_name: Optional[str] = Field(
        description="The agent to collect rollouts from. If not specified, uses agent_ref from each data row.",
        default=None,
    )

    input_jsonl_fpath: str = Field(
        description="The input data source to use to collect rollouts, in the form of a file path to a jsonl file."
    )

    # A falsy value (None or 0) reads the whole input file.
    limit: Optional[int] = Field(
        description="Maximum number of examples to load and take from the input dataset.",
        default=None,
    )

    # A falsy value (None or 0) is treated as a single repeat.
    num_repeats: Optional[int] = Field(
        description="The number of times to repeat each example to run. Useful if you want to calculate mean@k e.g. mean@4 or mean@16.",
        default=None,
    )

    # NOTE: the seed (equal to the rollout index) is added whenever this flag is set,
    # even for a single repeat.
    num_repeats_add_seed: bool = Field(
        description='When num_repeats > 1, add a "seed" parameter on the Responses create params.',
        default=False,
    )

    resume_from_cache: bool = Field(
        description="If the same command is run multiple times, check the materialized inputs and current outputs and remove the inputs that have already been run",
        default=False,
    )

    prompt_config: Optional[str] = Field(
        description="Path to a prompt YAML file. Builds responses_create_params.input from the template at rollout time. Mutually exclusive with pre-populated responses_create_params.input in the JSONL data.",
        default=None,
    )

    @property
    def materialized_jsonl_fpath(self) -> Path:
        """Sibling of the output file, named ``<output stem>_materialized_inputs.jsonl``."""
        output_path = Path(self.output_jsonl_fpath)
        renamed = output_path.with_stem(f"{output_path.stem}_materialized_inputs")
        return renamed.with_suffix(".jsonl")
class RolloutCollectionHelper(BaseModel):
    """Collect rollouts by posting rows to agent servers and persisting the results.

    ``run_from_config`` is the high-level entry point (materialize inputs, run,
    write results, aggregate metrics). ``run_examples`` is a lower-level
    interface that only fans rows out to the agents' ``/run`` endpoints.
    """

    def _preprocess_rows_from_config(self, config: RolloutCollectionConfig) -> List[Dict]:
        """Load, override and expand the input rows described by ``config``.

        Each input line is parsed as JSON, optionally rendered through a prompt
        config, given an agent ref, shallow-merged with the
        ``responses_create_params`` overrides and then repeated ``num_repeats``
        times consecutively (``abc`` -> ``aabbcc``). Identical input lines share a
        task index; each emitted copy gets the next rollout index for its task.

        Returns:
            The expanded rows, in input order.

        Raises:
            ValueError: if some rows have no agent ref name and
                ``config.agent_name`` is not set.
        """
        # 0-based input line numbers. ``count()`` keeps them distinct when no limit
        # is set; they are reported in the missing-agent error below.
        row_numbers = count()
        if config.limit:
            row_numbers = range(config.limit)
            print(f"Limiting the number of rows to {config.limit}")

        if config.num_repeats_add_seed:
            print("Adding unique `seed` values to each input")

        if config.agent_name:
            print(f"Using `{config.agent_name}` for rows that do not already have an agent ref")

        if config.responses_create_params:
            print(f"Overriding responses_create_params fields with {config.responses_create_params}")
            # Round-trip through OmegaConf to resolve interpolations into a plain dict.
            responses_create_params_overrides = OmegaConf.to_container(
                OmegaConf.create(config.responses_create_params), resolve=True
            )
        else:
            responses_create_params_overrides = dict()

        num_repeats = config.num_repeats or 1
        if num_repeats > 1:
            print(f"Repeating rows {num_repeats} times (in a pattern of abc to aabbcc)!")

        # Load prompt config if specified
        prompt_cfg = None
        if config.prompt_config:
            prompt_cfg = load_prompt_config(config.prompt_config)
            print(f"Using prompt config: {config.prompt_config}")

        # Relative paths are tried against the CWD first, then against PARENT_DIR.
        input_path = Path(config.input_jsonl_fpath)
        if not input_path.is_absolute():
            cwd_path = Path.cwd() / input_path
            input_path = cwd_path if cwd_path.exists() else PARENT_DIR / input_path

        with open(input_path, encoding="utf-8") as input_file:
            lines: Iterator[str] = tqdm(input_file, desc="Reading rows")
            # ``row_numbers`` comes first in the zip so a limit stops reading
            # without pulling one extra line from the file.
            raw_rows = [(row_idx, row_str, orjson.loads(row_str)) for row_idx, row_str in zip(row_numbers, lines)]

        # Validate and apply prompt config before per-row processing
        if prompt_cfg is not None:
            validate_prompt_compatibility([row for _, _, row in raw_rows], prompt_cfg)
            raw_rows = [(idx, s, apply_prompt_to_row(row, prompt_cfg)) for idx, s, row in raw_rows]

        # For ng_reward_profile to match rollouts to tasks
        row_to_task_idx: Dict[str, int] = dict()
        task_idx_to_rollout_idx: Dict[int, int] = Counter()
        row_idxs_missing_agent_ref: List[int] = []
        rows: List[Dict] = []
        for row_idx, row_str, row in raw_rows:
            # Resolve agent name
            if config.agent_name:
                row.setdefault(AGENT_REF_KEY_NAME, {"name": config.agent_name})
            elif not row.get(AGENT_REF_KEY_NAME, dict()).get("name"):
                row_idxs_missing_agent_ref.append(row_idx)

            # Responses create params: shallow merge, overrides win.
            row[RESPONSES_CREATE_PARAMS_KEY_NAME] = (
                row[RESPONSES_CREATE_PARAMS_KEY_NAME] | responses_create_params_overrides
            )

            # Resolve task index: the same raw line always maps to the same task.
            row[TASK_INDEX_KEY_NAME] = row_to_task_idx.setdefault(row_str, len(row_to_task_idx))
            task_idx = row[TASK_INDEX_KEY_NAME]

            for _ in range(num_repeats):
                # Deep copy so repeats never share nested dicts (e.g. the seed).
                repeated_row = deepcopy(row)
                repeated_row[ROLLOUT_INDEX_KEY_NAME] = task_idx_to_rollout_idx[task_idx]
                task_idx_to_rollout_idx[task_idx] += 1
                if config.num_repeats_add_seed:
                    repeated_row[RESPONSES_CREATE_PARAMS_KEY_NAME]["seed"] = repeated_row[ROLLOUT_INDEX_KEY_NAME]
                rows.append(repeated_row)

        if row_idxs_missing_agent_ref:
            raise ValueError(
                f"No agent specified for rows {row_idxs_missing_agent_ref}. Either provide +agent_name config or include agent_ref in data."
            )

        return rows

    def _load_from_cache(
        self, config: RolloutCollectionConfig
    ) -> Tuple[List[Dict], List[Dict], List[Dict], List[List[bytes]]]:
        """Resume a previous run from its materialized inputs and output file.

        Rows are keyed by (task index, rollout index). A materialized input whose
        key already appears in the output file counts as done. Blank lines in
        either file are ignored.

        Returns:
            ``(input_rows, rows, results, result_strs)``: the inputs that still
            need to run, the input row matching each existing result, the parsed
            existing results, and their raw JSON lines (each wrapped in a
            one-element list).
        """
        with config.materialized_jsonl_fpath.open("rb") as f:
            original_input_rows = [orjson.loads(line) for line in f if line.strip()]

        with Path(config.output_jsonl_fpath).open("rb") as f:
            result_strs = [[stripped] for line in f if (stripped := line.strip())]
        results = [orjson.loads(p[0]) for p in result_strs]

        def get_key(r: Dict) -> Tuple[Any, Any]:
            return r[TASK_INDEX_KEY_NAME], r[ROLLOUT_INDEX_KEY_NAME]

        seen_keys = {get_key(result) for result in results}
        input_rows = [row for row in original_input_rows if get_key(row) not in seen_keys]

        key_to_row = {get_key(row): row for row in original_input_rows}
        rows = [key_to_row[get_key(result)] for result in results]

        print(
            "Resumed from cache. Found:\n"
            f"    - {len(original_input_rows)} original input rows\n"
            f"    - {len(rows)} rows that have already been run\n"
            f"    - {len(input_rows)} rows that still need to be run"
        )

        return input_rows, rows, results, result_strs

    async def run_from_config(self, config: RolloutCollectionConfig) -> List[Dict]:
        """Collect (or resume collecting) rollouts as described by ``config``.

        On a fresh run, inputs are materialized to
        ``config.materialized_jsonl_fpath`` and any existing output file is
        removed. With ``resume_from_cache=True`` and both files present, only
        inputs without a matching result are run. Each result is appended and
        flushed to ``config.output_jsonl_fpath`` as soon as it arrives. Afterwards,
        rollouts are optionally uploaded to W&B and per-agent aggregate metrics
        are computed.

        Returns:
            Every result (cached and new), sorted by (task index, rollout index).
        """
        output_fpath = Path(config.output_jsonl_fpath)
        # The materialized inputs file lives next to the output file, so the
        # directory must exist before either file is written.
        output_fpath.parent.mkdir(exist_ok=True, parents=True)

        if config.resume_from_cache and config.materialized_jsonl_fpath.exists() and output_fpath.exists():
            input_rows, rows, results, result_strs = self._load_from_cache(config)
        else:
            if config.resume_from_cache:
                if not output_fpath.exists():
                    print(f"Skipping resume_from_cache because output_fpath {output_fpath} doesn't exist!")
                if not config.materialized_jsonl_fpath.exists():
                    print(
                        f"Skipping resume_from_cache because materialized_jsonl_fpath {config.materialized_jsonl_fpath} doesn't exist!"
                    )
            else:
                print("Clearing output fpath since `resume_from_cache=False`!")

            rows: List[Dict] = []
            results: List[Dict] = []
            result_strs: List[List[bytes]] = []
            input_rows = self._preprocess_rows_from_config(config)
            with config.materialized_jsonl_fpath.open("wb") as f:
                for row in input_rows:
                    f.write(orjson.dumps(row) + b"\n")

            output_fpath.unlink(missing_ok=True)

        semaphore = nullcontext()
        if config.num_samples_in_parallel:
            print(f"Querying with {config.num_samples_in_parallel} concurrent requests")
            semaphore = Semaphore(config.num_samples_in_parallel)

        pcts_to_print = [20, 40, 60, 80, 90, 95, 98, 99, 100]
        counts_left = Counter(r[AGENT_REF_KEY_NAME]["name"] for r in input_rows)
        # Progress is measured over this run only; cached results are excluded
        # so the percentage stays within 0-100 when resuming.
        num_completed = 0
        with output_fpath.open("ab") as results_file:
            for future in self.run_examples(input_rows, semaphore=semaphore):
                row, result = await future
                result[TASK_INDEX_KEY_NAME] = row[TASK_INDEX_KEY_NAME]
                result[ROLLOUT_INDEX_KEY_NAME] = row[ROLLOUT_INDEX_KEY_NAME]
                result[AGENT_REF_KEY_NAME] = row[AGENT_REF_KEY_NAME]

                rows.append(row)
                results.append(result)
                result_strs.append([orjson.dumps(result)])
                # Flush each line so results survive an interruption and can be
                # picked up by resume_from_cache.
                results_file.write(result_strs[-1][0] + b"\n")
                results_file.flush()
                num_completed += 1

                agent_name = row[AGENT_REF_KEY_NAME]["name"]
                counts_left[agent_name] -= 1
                if counts_left[agent_name] <= 0:
                    counts_left.pop(agent_name)

                current_pct = 100 * num_completed / len(input_rows)
                if pcts_to_print and current_pct >= pcts_to_print[0]:
                    while pcts_to_print and current_pct >= pcts_to_print[0]:
                        pcts_to_print.pop(0)

                    # Show the 5 agents with the most examples still outstanding.
                    top_left = counts_left.most_common(5)
                    if top_left:
                        top_left_str = "\n".join(f"{i + 1}. {k}: {v}" for i, (k, v) in enumerate(top_left))
                        # Use tqdm.write here so we can print properly with tqdm being used.
                        tqdm.write(f"Examples left:\n{top_left_str}")

        if config.upload_rollouts_to_wandb and get_wandb_run():  # pragma: no cover
            print("Uploading rollouts to W&B. This may take a few minutes if your data is large.")
            get_wandb_run().log({"Rollouts": Table(data=result_strs, columns=["Rollout"])})

        # The raw lines are no longer needed; release them before sorting.
        del result_strs

        print("Sorting results to ensure consistent ordering")
        rows.sort(key=lambda r: (r[TASK_INDEX_KEY_NAME], r[ROLLOUT_INDEX_KEY_NAME]))
        results.sort(key=lambda r: (r[TASK_INDEX_KEY_NAME], r[ROLLOUT_INDEX_KEY_NAME]))

        # Compute and write aggregate metrics via /aggregate_metrics on each agent server
        print("Computing aggregate metrics")
        aggregate_metrics_fpath = await self._call_aggregate_metrics(results, rows, output_fpath)

        print(
            "Finished rollout collection! View results at:\n"
            f"Fully materialized inputs: {config.materialized_jsonl_fpath}\n"
            f"Rollouts: {output_fpath}\n"
            f"Aggregate metrics: {aggregate_metrics_fpath}"
        )

        return results

    async def _call_aggregate_metrics(
        self,
        results: List[Dict],
        rows: List[Dict],
        output_fpath: Path,
    ) -> Optional[Path]:
        """Call /aggregate_metrics on each agent server after rollouts complete.

        ``rows[i]`` is expected to be the input row that produced ``results[i]``.
        Results are grouped by the agent name in that row; rows without one are
        skipped. The per-agent entries are written, sorted by agent name, to a
        single ``<output stem>_aggregate_metrics.json``. Primitive-valued agent
        and key metrics are logged to W&B when a W&B run is active.

        Returns:
            The metrics file path, or ``None`` when there are no results.
        """
        if not results:
            return None

        # Group results by agent name
        agent_results: Dict[str, List[Dict]] = {}
        for row, result in zip(rows, results):
            agent_name = (row.get(AGENT_REF_KEY_NAME) or {}).get("name")
            if not agent_name:
                continue
            agent_results.setdefault(agent_name, []).append(result)

        server_client = self.setup_server_client()

        async def _fetch_agent_metrics(agent_name: str, agent_result_list: List[Dict]) -> Dict:
            # Strip heavyweight fields before sending, but preserve response.usage
            stripped = []
            for r in agent_result_list:
                entry = {k: v for k, v in r.items() if k not in ("response", "responses_create_params")}
                usage = (r.get("response") or {}).get("usage")
                if usage:
                    entry["response"] = {"usage": usage}
                stripped.append(entry)

            agg_request = AggregateMetricsRequest(verify_responses=stripped)
            agg_response = await server_client.post(
                server_name=agent_name,
                url_path="/aggregate_metrics",
                json=agg_request,
            )
            await raise_for_status(agg_response)
            agg_result = AggregateMetrics.model_validate(await get_response_json(agg_response))

            return {
                AGENT_REF_KEY_NAME: {"name": agent_name},
                "agent_metrics": agg_result.agent_metrics,
                "key_metrics": agg_result.key_metrics,
                "group_level_metrics": agg_result.group_level_metrics,
            }

        all_agent_metrics: List[Dict] = []
        tasks = [_fetch_agent_metrics(name, results_list) for name, results_list in agent_results.items()]
        # Print each agent's key metrics as soon as its response arrives.
        for coro in asyncio.as_completed(tasks):
            agent_entry = await coro
            all_agent_metrics.append(agent_entry)
            agent_name = agent_entry[AGENT_REF_KEY_NAME]["name"]
            key_metrics = agent_entry.get("key_metrics", {})
            print(f"\nKey metrics for {agent_name}:\n" + json.dumps(key_metrics, indent=4))

        # Completion order is arbitrary; sort so the written file is deterministic.
        all_agent_metrics.sort(key=lambda e: e[AGENT_REF_KEY_NAME]["name"])

        # Only primitive values can be logged as scalars.
        primitive_types = (bool, int, float, str, type(None))
        metrics_to_log = dict()
        for agent_entry in all_agent_metrics:
            agent_name = agent_entry[AGENT_REF_KEY_NAME]["name"]
            metrics_to_log.update(
                {
                    f"{agent_name}/{k}": v
                    for k, v in agent_entry["agent_metrics"].items()
                    if isinstance(v, primitive_types)
                }
            )
            metrics_to_log.update(
                {
                    f"key_metrics/{agent_name}/{k}": v
                    for k, v in agent_entry["key_metrics"].items()
                    if isinstance(v, primitive_types)
                }
            )

        if get_wandb_run():  # pragma: no cover
            get_wandb_run().log(metrics_to_log)

        # Write single file with all agents
        metrics_fpath = output_fpath.with_stem(output_fpath.stem + "_aggregate_metrics").with_suffix(".json")
        metrics_fpath.write_bytes(orjson.dumps(all_agent_metrics, option=orjson.OPT_INDENT_2))
        return metrics_fpath

    def run_examples(
        self,
        examples: List[Dict],
        head_server_config: Optional[BaseServerConfig] = None,
        semaphore: Optional[Semaphore] = None,
    ) -> Iterator[Future]:  # pragma: no cover
        """
        We provide this function as a lower level interface for running rollout collection.

        Each example is POSTed to ``/run`` on the agent server named by its
        agent ref. The number in flight is limited by ``semaphore``; with no
        semaphore there is no limit.

        Returns:
            An iterator of awaitables in completion order, each resolving to the
            ``(row, response_json)`` pair for one example.
        """
        server_client = self.setup_server_client(head_server_config)
        semaphore = semaphore or nullcontext()

        async def _post_subroutine(row: Dict) -> Tuple[Dict, Dict]:
            async with semaphore:
                # Same key that _preprocess_rows_from_config writes the agent ref under.
                res = await server_client.post(
                    server_name=row[AGENT_REF_KEY_NAME]["name"], url_path="/run", json=row
                )
                await raise_for_status(res)
                return row, await get_response_json(res)

        return tqdm.as_completed(
            map(_post_subroutine, examples),
            desc="Collecting rollouts",
            miniters=10,
            total=len(examples),
            maxinterval=60,
        )

    def setup_server_client(
        self, head_server_config: Optional[BaseServerConfig] = None
    ) -> ServerClient:  # pragma: no cover
        """Build a ServerClient from the global config.

        If no global aiohttp client has been set up yet, one is created using
        the server client's global config dict.
        """
        server_client = ServerClient.load_from_global_config(head_server_config)

        # We set this rollout global aiohttp client to use the same max connections as the underlying head server global config.
        if not is_global_aiohttp_client_setup():
            set_global_aiohttp_client(
                cfg=GlobalAIOHTTPAsyncClientConfig.model_validate(server_client.global_config_dict)
            )

        return server_client
def collect_rollouts():  # pragma: no cover
    """CLI entry point: validate the global config and run rollout collection to completion."""
    rollout_config = RolloutCollectionConfig.model_validate(get_global_config_dict())
    helper = RolloutCollectionHelper()
    collection = helper.run_from_config(rollout_config)
    asyncio.run(collection)