# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
from typing import TypedDict

import ray
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from nemo_rl.algorithms.utils import set_seed
from nemo_rl.data import MathDataConfig
from nemo_rl.data.datasets import AllTaskProcessedDataset, eval_collate_fn
from nemo_rl.data.llm_message_utils import get_keys_from_message_log
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster
from nemo_rl.environments.math_environment import MathEnvConfig
from nemo_rl.models.generation.interfaces import GenerationConfig
from nemo_rl.models.generation.vllm import VllmGeneration

# ===============================================================================
# Configuration
# ===============================================================================
class EvalConfig(TypedDict):
metric: str
num_tests_per_prompt: int
seed: int
pass_k_value: int
class MasterConfig(TypedDict):
eval: EvalConfig
    generation: GenerationConfig
data: MathDataConfig
env: MathEnvConfig
cluster: ClusterConfig
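
# Illustrative, partial MasterConfig sketch (assumed values, for orientation only).
# Only the keys that this module reads are shown; the full schemas are defined by
# the imported *Config TypedDicts, and real configs are normally loaded from YAML:
#
#     master_config: MasterConfig = {
#         "eval": {"metric": "pass@k", "num_tests_per_prompt": 4, "seed": 42, "pass_k_value": 1},
#         "generation": {
#             "backend": "vllm",
#             "model_name": "<hf_model_id>",       # placeholder
#             "temperature": 1.0,
#             "top_p": 1.0,
#             "top_k": -1,
#             "num_prompts_per_step": -1,          # -1 = evaluate the whole dataset in one step
#             "vllm_cfg": {"async_engine": False, "max_model_len": 4096},
#         },
#         "data": {"dataset_name": "<dataset_path>"},  # placeholder
#         "env": {...},                                # MathEnvConfig fields
#         "cluster": {"gpus_per_node": 8, "num_nodes": 1},
#     }
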
# ===============================================================================
# Setup & Initialization
# ===============================================================================
def setup(
master_config: MasterConfig,
tokenizer: AutoTokenizer,
dataset: AllTaskProcessedDataset,
) -> tuple[
VllmGeneration,
DataLoader,
MasterConfig,
]:
"""Set up components for model evaluation.
Initializes the VLLM model and data loader.
Args:
master_config: Configuration settings.
dataset: Dataset to evaluate on.
Returns:
VLLM model, data loader, and config.
"""
# Extract individual configs for easier access
eval_config = master_config["eval"]
generation_config = master_config["generation"]
cluster_config = master_config["cluster"]
# Set seed for reproducibility
set_seed(eval_config["seed"])
# Check settings
metric = eval_config["metric"]
pass_k_value = eval_config["pass_k_value"]
num_tests_per_prompt = eval_config["num_tests_per_prompt"]
temperature = generation_config["temperature"]
top_k = generation_config["top_k"]
# TODO @yukih: support cons@k
# Validate metrics
assert metric in ["pass@k"], f"Invalid metric: {metric}"
if num_tests_per_prompt > 1:
assert temperature > 0 and top_k != 1, (
"temperature > 0 and top_k != 1 are required for multiple samples"
)
assert pass_k_value >= 1, (
"pass_k_value must be greater than or equal to 1 for pass@k metric"
)
assert num_tests_per_prompt >= pass_k_value, (
"num_tests_per_prompt must be greater than or equal to pass_k_value for pass@k metric"
)
# ==========================
# Data
# ==========================
if generation_config["num_prompts_per_step"] == -1:
generation_config["num_prompts_per_step"] = len(dataset)
dataloader = DataLoader(
dataset,
batch_size=generation_config["num_prompts_per_step"],
shuffle=False,
collate_fn=eval_collate_fn,
)
print(f" ✓ Evaluation dataset loaded with {len(dataset)} samples")
# ==========================
# Cluster
# ==========================
print("\n▶ Setting up compute cluster...")
cluster = RayVirtualCluster(
name="eval_cluster",
bundle_ct_per_node_list=[cluster_config["gpus_per_node"]]
* cluster_config["num_nodes"],
use_gpus=True,
num_gpus_per_node=cluster_config["gpus_per_node"],
max_colocated_worker_groups=1,
)
print(f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes")
# ==========================
# Model
# ==========================
print("\n▶ Setting up model...")
# check backend
backend = generation_config["backend"]
assert backend == "vllm", "Only vLLM backend is supported for evaluation"
# initialize vllm generation
vllm_generation = VllmGeneration(cluster=cluster, config=generation_config)
print(
f" ✓ Using vLLM backend for generation with {generation_config['model_name']}"
)
print("\n" + "=" * 60)
print(" " * 18 + "SETUP COMPLETE")
print("=" * 60 + "\n")
return (
vllm_generation,
dataloader,
master_config,
)
# ===============================================================================
# Evaluation
# ===============================================================================
def eval_pass_k(rewards: torch.Tensor, num_tests_per_prompt: int, k: int) -> float:
"""Evaluate pass@k score using an unbiased estimator.
Reference: https://github.com/huggingface/evaluate/blob/32546aafec25cdc2a5d7dd9f941fc5be56ba122f/metrics/code_eval/code_eval.py#L198-L213
Args:
rewards: Tensor of shape (batch_size * num_tests_per_prompt)
k: int (pass@k value)
Returns:
pass_k_score: float
"""
    def eval_single_chunk(n: int, c: int, k: int) -> float:
        """Calculates 1 - comb(n - c, k) / comb(n, k).

        Uses the equivalent product form 1 - prod_{i=n-c+1}^{n} (1 - k / i),
        which avoids computing large binomial coefficients directly.
        """
        if n - c < k:
            return 1.0
        return float(1.0 - torch.prod(1.0 - k / torch.arange(n - c + 1, n + 1)).item())
# rewards is a 1d tensor of size (batch_size * num_tests_per_prompt)
group_rewards = rewards.split(num_tests_per_prompt)
pass_k_score = 0.0
for group_reward in group_rewards:
num_correct = group_reward.sum().item()
pass_k_score += eval_single_chunk(num_tests_per_prompt, num_correct, k)
return pass_k_score
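
# Worked example (illustrative numbers only): for a single prompt sampled 4 times
# with rewards [1, 0, 0, 1], there are c = 2 correct samples out of n = 4, so
# pass@2 = 1 - C(n - c, k) / C(n, k) = 1 - C(2, 2) / C(4, 2) = 1 - 1/6 ≈ 0.8333.
def _pass_k_worked_example() -> float:
    """Tiny sanity check for eval_pass_k; not part of the evaluation flow."""
    rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])
    return eval_pass_k(rewards, num_tests_per_prompt=4, k=2)  # ≈ 0.8333
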
def run_env_eval(vllm_generation, dataloader, env, master_config):
"""Main entry point for running evaluation using environment.
Generates model responses and evaluates them by env.
Args:
vllm_generation: Model for generating responses.
dataloader: Data loader with evaluation samples.
env: Environment that scores responses.
master_config: Configuration settings.
"""
# Check if async engine is enabled and run appropriate version
if master_config["generation"]["vllm_cfg"]["async_engine"]:
asyncio.run(
_run_env_eval_impl(
vllm_generation, dataloader, env, master_config, use_async=True
)
)
else:
asyncio.run(
_run_env_eval_impl(
vllm_generation, dataloader, env, master_config, use_async=False
)
)
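
# Illustrative end-to-end call sequence (hypothetical variable names; constructing
# the tokenizer, dataset, and environment actor is handled by the evaluation script
# that imports this module):
#
#     vllm_generation, dataloader, master_config = setup(master_config, tokenizer, dataset)
#     run_env_eval(vllm_generation, dataloader, env, master_config)
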
async def _run_env_eval_impl(
vllm_generation, dataloader, env, master_config, use_async=False
):
"""Unified implementation for both sync and async evaluation."""
# Extract for easier access
generation_config = master_config["generation"]
eval_config = master_config["eval"]
metric = eval_config["metric"]
num_tests_per_prompt = eval_config["num_tests_per_prompt"]
pass_k_value = eval_config["pass_k_value"]
# Run evaluation loop
score = 0.0
for batch in dataloader:
        # repeat each prompt so it is sampled num_tests_per_prompt times
if num_tests_per_prompt > 1:
batch = batch.repeat_interleave(num_tests_per_prompt)
# get input prompt from message_log
prompts = []
for message_log in batch["message_log"]:
content = [message["content"] for message in message_log]
content = "\n".join(content)
prompts.append(content)
# generate by vllm
inputs = BatchedDataDict({"prompts": prompts})
outputs = await _generate_texts(vllm_generation, inputs, use_async)
# append to message_log
for idx, output in enumerate(outputs):
batch["message_log"][idx].append(
{
"role": "assistant",
"content": output,
}
)
# evaluate generations with the environment
to_env = [
get_keys_from_message_log(batch["message_log"][i], ["role", "content"])
for i in range(len(batch["message_log"]))
]
env_return = ray.get(env.step.remote(to_env, batch["extra_env_info"]))
rewards = env_return.rewards
# update stats
if metric == "pass@k":
score += eval_pass_k(rewards, num_tests_per_prompt, pass_k_value)
else:
raise ValueError(f"Invalid metric: {metric}")
# Cleanup before printing results
ray.get(env.shutdown.remote())
vllm_generation.shutdown()
# Print results
_print_results(
master_config,
generation_config,
score,
len(dataloader.dataset),
metric,
pass_k_value,
num_tests_per_prompt,
)
async def _generate_texts(vllm_generation, inputs, use_async):
"""Generate texts using either sync or async method."""
if use_async:
# Use async generation - collect all results
results = []
async for idx, result in vllm_generation.generate_text_async(inputs):
results.append((idx, result["texts"][0]))
# Sort by index to maintain order
results.sort(key=lambda x: x[0])
return [text for _, text in results]
else:
# Use sync generation
return vllm_generation.generate_text(inputs)["texts"]
def _print_results(
master_config,
generation_config,
score,
dataset_size,
metric,
pass_k_value,
num_tests_per_prompt,
):
"""Print evaluation results."""
dataset_name = os.path.basename(master_config["data"]["dataset_name"])
model_name = os.path.basename(generation_config["model_name"])
max_new_tokens = generation_config["vllm_cfg"]["max_model_len"]
temperature = generation_config["temperature"]
top_p = generation_config["top_p"]
top_k = generation_config["top_k"]
average_score = score / dataset_size
print("\n" + "=" * 60)
print(f"{model_name=} {dataset_name=}")
print(f"{max_new_tokens=} {temperature=} {top_p=} {top_k=}\n")
print(f"{metric=} {pass_k_value=} {num_tests_per_prompt=}\n")
print(f"score={average_score:.4f} ({score}/{dataset_size})")
print("=" * 60 + "\n")