Source code for nemo_gym.profile

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from collections import defaultdict
from pathlib import Path
from typing import Optional

from pydantic import BaseModel, Field

from nemo_gym.config_types import BaseNeMoGymCLIConfig
from nemo_gym.global_config import TASK_INDEX_KEY_NAME, get_global_config_dict


class ProfileConfig(BaseNeMoGymCLIConfig):
    """Settings for profiling a task dataset against its rollouts.

    Points at the original task dataset, the rollouts collected for it, and the
    destination for the profiled dataset. Setting ``pass_threshold`` additionally
    enables the pass-rate statistics; leaving it unset (None) skips them.
    """

    # Original task dataset; read line by line, each line parsed as one JSON task.
    input_jsonl_fpath: str = Field(
        description="Original task dataset.",
    )
    # Rollouts JSONL; each line is parsed as one JSON rollout.
    rollouts_jsonl_fpath: str = Field(
        description="Rollouts file from ng_collect_rollouts with num_repeats.",
    )
    # Destination JSONL; its parent directory is created if missing.
    output_jsonl_fpath: str = Field(
        description="Output file for profiled dataset.",
    )
    # A rollout "passes" when its reward is >= this value; None disables pass-rate fields.
    pass_threshold: Optional[float] = Field(
        None,
        description="Reward threshold for pass_rate. If None, pass_rate not computed.",
    )
class RewardProfilingMetrics(BaseModel):
    """Per-task reward statistics merged into each profiled task record.

    The four ``pass_*`` fields stay None unless a pass threshold was configured.
    Callers that dump with ``exclude_none=True`` therefore omit them entirely
    in that case.
    """

    # --- Always populated -------------------------------------------------
    avg_reward: float = Field(description="Average reward across all rollouts for this task.")
    # profile() computes this as a population standard deviation (divides by N).
    std_reward: float = Field(description="Standard deviation of rewards.")
    min_reward: float = Field(description="Minimum reward observed.")
    max_reward: float = Field(description="Maximum reward observed.")
    total_samples: int = Field(description="Number of rollout samples for this task.")

    # --- Only populated when a pass threshold is supplied -----------------
    pass_rate: Optional[float] = Field(None, description="Fraction of rollouts meeting pass_threshold.")
    pass_rate_total: Optional[int] = Field(None, description="Total rollouts used for pass_rate calculation.")
    pass_rate_passed: Optional[int] = Field(None, description="Number of rollouts that passed.")
    pass_threshold: Optional[float] = Field(None, description="Threshold used for pass_rate calculation.")
def _read_grouped_rewards(rollouts_jsonl_fpath: str) -> dict[int, list[float]]:
    """Group rollout rewards by task index.

    Every non-blank line of the file is parsed as a JSON object. Rollouts without
    a ``TASK_INDEX_KEY_NAME`` entry are ignored. A rollout missing the ``"reward"``
    key contributes ``0.0``.

    Args:
        rollouts_jsonl_fpath: Path to the rollouts JSONL file.

    Returns:
        Mapping from task index to the list of rewards observed for that task.
        Every list in the mapping is non-empty.
    """
    grouped_rewards: dict[int, list[float]] = defaultdict(list)
    with open(rollouts_jsonl_fpath, encoding="utf-8") as f:
        for line in f:
            # Tolerate blank/whitespace-only lines (e.g. an extra newline at EOF);
            # json.loads would raise JSONDecodeError on them.
            if not line.strip():
                continue
            rollout = json.loads(line)
            task_idx = rollout.get(TASK_INDEX_KEY_NAME)
            if task_idx is not None:
                grouped_rewards[task_idx].append(rollout.get("reward", 0.0))
    return grouped_rewards


def _compute_reward_metrics(rewards: list[float], pass_threshold: Optional[float]) -> RewardProfilingMetrics:
    """Summarize a non-empty list of rewards.

    Args:
        rewards: Rewards for a single task; must contain at least one value.
        pass_threshold: If not None, a reward counts as passing when it is
            ``>= pass_threshold``, and the pass-rate fields are filled in.

    Returns:
        A RewardProfilingMetrics with mean, population standard deviation,
        min, max, sample count and, optionally, the pass-rate fields.
    """
    n = len(rewards)
    avg = sum(rewards) / n
    metrics = RewardProfilingMetrics(
        avg_reward=avg,
        # Population standard deviation (divide by N, not N - 1).
        std_reward=(sum((r - avg) ** 2 for r in rewards) / n) ** 0.5,
        min_reward=min(rewards),
        max_reward=max(rewards),
        total_samples=n,
    )
    if pass_threshold is not None:
        passed = sum(1 for r in rewards if r >= pass_threshold)
        metrics.pass_rate = passed / n
        metrics.pass_rate_total = n
        metrics.pass_rate_passed = passed
        metrics.pass_threshold = pass_threshold
    return metrics


def profile():
    """Annotate each task in the input dataset with reward statistics from its rollouts.

    Reads a ProfileConfig from the global config dict and groups rollout rewards
    by ``TASK_INDEX_KEY_NAME``. It then writes one JSON line per input task that
    has at least one rollout. The output line is the original task dict merged
    with the non-None RewardProfilingMetrics fields; metric keys override task
    keys on collision. Tasks with no rollouts are omitted from the output. The
    output file's parent directory is created if missing.
    """
    config = ProfileConfig.model_validate(get_global_config_dict())
    grouped_rewards = _read_grouped_rewards(config.rollouts_jsonl_fpath)

    Path(config.output_jsonl_fpath).parent.mkdir(exist_ok=True, parents=True)
    with open(config.input_jsonl_fpath, encoding="utf-8") as f_in, open(
        config.output_jsonl_fpath, "w", encoding="utf-8"
    ) as f_out:
        # A task's index is its raw 0-based line position in the input file.
        # Blank lines are counted, so indices are not shifted.
        # NOTE(review): presumably this matches how TASK_INDEX_KEY_NAME was
        # assigned upstream (ng_collect_rollouts); confirm.
        for task_idx, line in enumerate(f_in):
            if task_idx not in grouped_rewards:
                continue
            task = json.loads(line)
            metrics = _compute_reward_metrics(grouped_rewards[task_idx], config.pass_threshold)
            profiled_task = {**task, **metrics.model_dump(exclude_none=True)}
            f_out.write(json.dumps(profiled_task) + "\n")