# Source code for nemo_gym.prompt

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prompt configuration: YAML-based prompt templates applied at rollout time.

Prompt templates are mutually exclusive with pre-populated
``responses_create_params.input`` values. This separation enables prompt
sweeps without re-preparing data.
"""

import json
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Optional

import yaml
from pydantic import BaseModel, Field

from nemo_gym import PARENT_DIR
from nemo_gym.config_types import BaseNeMoGymCLIConfig
from nemo_gym.global_config import GlobalConfigDictParserConfig, get_global_config_dict


class PromptConfig(BaseModel):
    """Schema for a prompt YAML file.

    ``user`` is required, ``system`` is optional.
    """

    # Template for the user message; declared without a default, so it must be present.
    user: str
    # Template for the system message; defaults to None when the YAML omits it.
    system: Optional[str] = None
def _resolve_path(path: str) -> Path:
    """Turn ``path`` into a ``Path``, anchoring relative paths at the Gym root.

    Absolute paths are returned as-is. Relative paths are joined onto
    ``PARENT_DIR``, consistent with config_paths resolution.
    """
    candidate = Path(path)
    if candidate.is_absolute():
        return candidate
    return PARENT_DIR / candidate
@lru_cache(maxsize=64)
def load_prompt_config(path: str) -> PromptConfig:
    """Load and validate a YAML prompt config file.

    Relative paths are resolved against the Gym root directory (``PARENT_DIR``),
    consistent with how ``config_paths`` and other Gym paths are resolved.

    Results are cached so the same file is only parsed once. The cache is keyed
    on the ``path`` string exactly as passed and holds up to 64 entries; repeated
    calls with the same string return the same ``PromptConfig`` instance.

    Args:
        path: Location of the prompt YAML file, absolute or relative to the
            Gym root.

    Returns:
        A ``PromptConfig`` with required ``user`` and optional ``system``
        fields. Each value is a string template with ``{placeholder}`` syntax.

    Raises:
        OSError: If the resolved file cannot be opened.
        Exception: Whatever ``PromptConfig.model_validate`` raises when the
            parsed YAML does not match the schema.
    """
    resolved = _resolve_path(path)
    # Decode explicitly as UTF-8: the locale default used by a bare open()
    # varies by platform and would corrupt or reject non-ASCII prompt text.
    with open(resolved, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    return PromptConfig.model_validate(data)
def fill_prompt(prompt_config: PromptConfig, row: dict) -> List[Dict[str, str]]:
    """Render a prompt template against one data row as chat-style message dicts.

    Placeholders (``{field_name}``) are filled from the row's top-level fields.
    Literal braces must be doubled (``{{`` / ``}}``).

    Args:
        prompt_config: Templates to render; ``system`` is skipped when it is None.
        row: Mapping that supplies the placeholder values.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts: the system message
        first (when configured), followed by the user message.

    Raises:
        KeyError: If a template references a field the row does not have.
    """
    # Gather (role, template) pairs in output order before rendering.
    role_templates = []
    if prompt_config.system is not None:
        role_templates.append(("system", prompt_config.system))
    role_templates.append(("user", prompt_config.user))

    try:
        return [{"role": role, "content": template.format_map(row)} for role, template in role_templates]
    except KeyError as missing:
        # Re-raise with the available field names so the mismatch is easy to spot;
        # ``from None`` hides the less informative original traceback.
        raise KeyError(
            f"Prompt template references field {missing} but the data row only has fields: {list(row.keys())}"
        ) from None
def validate_prompt_compatibility(rows: List[dict], prompt_config: PromptConfig) -> None:
    """Validate that no rows have pre-populated responses_create_params.input when a prompt_config is provided.

    Collects all violating row indices and reports them in a single error.

    A row whose ``responses_create_params`` is missing, ``None`` or otherwise
    not a dict is treated as having no input. ``apply_prompt_to_row`` handles
    such values the same way, by replacing them with a fresh dict.

    Args:
        rows: Data rows to check.
        prompt_config: The prompt config that will be applied. It is not
            inspected here; the check concerns only the rows.

    Raises:
        ValueError: If any row already carries a non-empty
            ``responses_create_params.input``. The message lists the indices
            of all such rows.
    """
    conflicting_indices = []
    for i, row in enumerate(rows):
        rcp = row.get("responses_create_params")
        # Guard against JSON null / non-dict values, which have no ``.get``.
        if isinstance(rcp, dict) and rcp.get("input"):
            conflicting_indices.append(i)

    if conflicting_indices:
        raise ValueError(
            "Some rows have responses_create_params.input but prompt_config is also specified. "
            f"These are mutually exclusive. Use one or the other. Violating rows: {conflicting_indices}"
        )
def apply_prompt_to_row(row: dict, prompt_config: PromptConfig) -> dict:
    """Build ``responses_create_params.input`` for a row from the prompt template.

    Other fields in responses_create_params (tools, metadata, temperature,
    max_output_tokens) are preserved. A ``responses_create_params`` value that
    is not a dict is discarded and replaced by a fresh one.

    Args:
        row: Source data row; it is left unmodified.
        prompt_config: Templates used to render the messages.

    Returns:
        A shallow copy of ``row`` whose ``responses_create_params`` is a new
        dict with ``input`` set to the rendered messages.
    """
    rendered = fill_prompt(prompt_config, row)

    existing_params = row.get("responses_create_params", {})
    # Copy so the caller's nested dict is never mutated.
    params = existing_params.copy() if isinstance(existing_params, dict) else {}
    params["input"] = rendered

    updated_row = row.copy()
    updated_row["responses_create_params"] = params
    return updated_row
def materialize_prompts(input_jsonl: str, prompt_config: str, output_jsonl: str) -> None:
    """Apply a prompt template to raw JSONL data, producing materialized JSONL.

    Reads each row from ``input_jsonl``, validates that no row has pre-populated
    ``responses_create_params.input``, applies the prompt template, and writes
    the result to ``output_jsonl``. Each output row also records the resolved
    prompt file path under ``prompt_config_used``.

    Blank lines in the input are skipped. All rows are rendered before the
    output file is opened, so a validation or missing-field error does not
    leave a truncated output file behind.

    Args:
        input_jsonl: Path to raw JSONL (no responses_create_params.input).
        prompt_config: Path to prompt YAML file.
        output_jsonl: Path to write materialized JSONL (with responses_create_params.input).

    Raises:
        ValueError: If any row already has ``responses_create_params.input``.
        KeyError: If the template references a field missing from a row.
    """
    prompt_cfg = load_prompt_config(prompt_config)
    resolved_prompt_path = str(_resolve_path(prompt_config))

    # Explicit UTF-8 keeps decoding independent of the platform locale;
    # whitespace-only lines are not valid JSON documents, so skip them.
    with open(input_jsonl, encoding="utf-8") as f_in:
        rows = [json.loads(line) for line in f_in if line.strip()]

    validate_prompt_compatibility(rows, prompt_cfg)

    # Render everything up front so a failure happens before any output is written.
    output_lines = []
    for row in rows:
        materialized = apply_prompt_to_row(row, prompt_cfg)
        materialized["prompt_config_used"] = resolved_prompt_path
        output_lines.append(json.dumps(materialized) + "\n")

    output_path = Path(output_jsonl)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f_out:
        f_out.writelines(output_lines)

    print(f"Materialized {len(rows)} rows to {output_path}")
class MaterializePromptsConfig(BaseNeMoGymCLIConfig):
    """
    Apply a prompt template to raw JSONL data, producing materialized JSONL with
    populated ``responses_create_params.input`` for RL training.

    Examples:

    ```bash
    ng_materialize_prompts \\
        +input_jsonl_fpath=data/my_dataset.jsonl \\
        +prompt_config=/path/to/my_prompt.yaml \\
        +output_jsonl_fpath=my_dataset_materialized.jsonl
    ```
    """

    # All three fields are declared without defaults, so each must be supplied.
    # materialize_prompts_cli passes them on to materialize_prompts as
    # input_jsonl, prompt_config and output_jsonl respectively.
    input_jsonl_fpath: str = Field(description="Raw JSONL data (no responses_create_params.input).")
    prompt_config: str = Field(description="Path to prompt YAML file to apply.")
    output_jsonl_fpath: str = Field(description="Output path for materialized JSONL with populated prompts.")
def materialize_prompts_cli() -> None:  # pragma: no cover
    """CLI entry point for ng_materialize_prompts.

    Builds the global config dict starting from
    ``GlobalConfigDictParserConfig.NO_MODEL_GLOBAL_CONFIG_DICT``, validates it
    as a ``MaterializePromptsConfig``, and runs ``materialize_prompts`` with
    the resulting paths.
    """
    parser_config = GlobalConfigDictParserConfig(
        initial_global_config_dict=GlobalConfigDictParserConfig.NO_MODEL_GLOBAL_CONFIG_DICT,
    )
    raw_config = get_global_config_dict(global_config_dict_parser_config=parser_config)
    cli_config = MaterializePromptsConfig.model_validate(raw_config)
    materialize_prompts(
        input_jsonl=cli_config.input_jsonl_fpath,
        prompt_config=cli_config.prompt_config,
        output_jsonl=cli_config.output_jsonl_fpath,
    )