Source code for nemo_rl.evals.eval
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Tuple, TypedDict
import ray
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from nemo_rl.algorithms.utils import set_seed
from nemo_rl.data import MathDataConfig
from nemo_rl.data.datasets import AllTaskProcessedDataset, eval_collate_fn
from nemo_rl.data.llm_message_utils import get_keys_from_message_log
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster
from nemo_rl.environments.math_environment import MathEnvConfig
from nemo_rl.models.generation.interfaces import GenerationConfig
from nemo_rl.models.generation.vllm import VllmGeneration
# ===============================================================================
# Configuration
# ===============================================================================
class EvalConfig(TypedDict):
metric: str
num_tests_per_prompt: int
seed: int
class MasterConfig(TypedDict):
eval: EvalConfig
    generation: GenerationConfig
data: MathDataConfig
env: MathEnvConfig
cluster: ClusterConfig
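# A minimal sketch of what a fully-resolved MasterConfig might look like once the
# caller has assembled it. The concrete values below are illustrative assumptions,
# not defaults shipped with this module:
#
#   master_config: MasterConfig = {
#       "eval": {"metric": "pass@1", "num_tests_per_prompt": 1, "seed": 42},
#       "generation": {...},  # GenerationConfig: backend, model_name, sampling params, vllm_cfg, ...
#       "data": {...},        # MathDataConfig, e.g. dataset_name
#       "env": {...},         # MathEnvConfig
#       "cluster": {"gpus_per_node": 8, "num_nodes": 1},
#   }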
# ===============================================================================
# Setup & Initialization
# ===============================================================================
def setup(
master_config: MasterConfig,
tokenizer: AutoTokenizer,
dataset: AllTaskProcessedDataset,
) -> Tuple[
VllmGeneration,
DataLoader,
MasterConfig,
]:
"""Set up components for model evaluation.
Initializes the VLLM model and data loader.
Args:
master_config: Configuration settings.
dataset: Dataset to evaluate on.
Returns:
VLLM model, data loader, and config.
"""
# Extract individual configs for easier access
eval_config = master_config["eval"]
generation_config = master_config["generation"]
cluster_config = master_config["cluster"]
# Set seed for reproducibility
set_seed(eval_config["seed"])
# Check settings
metric = eval_config["metric"]
num_tests_per_prompt = eval_config["num_tests_per_prompt"]
temperature = generation_config["temperature"]
top_k = generation_config["top_k"]
# TODO @yukih: support pass@k and cons@k
assert metric in ["pass@1"], f"Invalid metric: {metric}"
if num_tests_per_prompt > 1:
assert temperature > 0 and top_k != 1, (
"temperature > 0 and top_k != 1 are required for multiple samples"
)
# ==========================
# Data
# ==========================
if generation_config["num_prompts_per_step"] == -1:
generation_config["num_prompts_per_step"] = len(dataset)
dataloader = DataLoader(
dataset,
batch_size=generation_config["num_prompts_per_step"],
shuffle=False,
collate_fn=eval_collate_fn,
)
print(f" ✓ Evaluation dataset loaded with {len(dataset)} samples")
# ==========================
# Cluster
# ==========================
print("\n▶ Setting up compute cluster...")
cluster = RayVirtualCluster(
name="eval_cluster",
bundle_ct_per_node_list=[cluster_config["gpus_per_node"]]
* cluster_config["num_nodes"],
use_gpus=True,
num_gpus_per_node=cluster_config["gpus_per_node"],
max_colocated_worker_groups=1,
)
print(f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes")
# ==========================
# Model
# ==========================
print("\n▶ Setting up model...")
# check backend
backend = generation_config["backend"]
assert backend == "vllm", "Only vLLM backend is supported for evaluation"
# initialize vllm generation
vllm_generation = VllmGeneration(cluster=cluster, config=generation_config)
print(
f" ✓ Using vLLM backend for generation with {generation_config['model_name']}"
)
print("\n" + "=" * 60)
print(" " * 18 + "SETUP COMPLETE")
print("=" * 60 + "\n")
return (
vllm_generation,
dataloader,
master_config,
)
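# Example of how setup() might be driven by a caller (a sketch; the tokenizer and
# dataset construction shown here are assumptions, not part of this module):
#
#   tokenizer = AutoTokenizer.from_pretrained(master_config["generation"]["model_name"])
#   dataset = AllTaskProcessedDataset(...)  # built from master_config["data"] and the tokenizer
#   vllm_generation, dataloader, master_config = setup(master_config, tokenizer, dataset)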
# ===============================================================================
# Evaluation
# ===============================================================================
def run_env_eval(vllm_generation, dataloader, env, master_config):
"""Main entry point for running evaluation using environment.
Generates model responses and evaluates them by env.
Args:
vllm_generation: Model for generating responses.
dataloader: Data loader with evaluation samples.
env: Environment that scores responses.
master_config: Configuration settings.
"""
# Extract for easier access
generation_config = master_config["generation"]
eval_config = master_config["eval"]
metric = eval_config["metric"]
num_tests_per_prompt = eval_config["num_tests_per_prompt"]
# Run evaluation loop
score, count = 0.0, 0
for batch in dataloader:
# update stats
count += batch.size * num_tests_per_prompt
# measure multiple samples
if num_tests_per_prompt > 1:
batch = batch.repeat_interleave(num_tests_per_prompt)
# get input prompt from message_log
prompts = []
for message_log in batch["message_log"]:
content = [message["content"] for message in message_log]
content = "\n".join(content)
prompts.append(content)
# generate by vllm
inputs = BatchedDataDict({"prompts": prompts})
outputs = vllm_generation.generate_text(inputs)["texts"]
# append to message_log
for idx, output in enumerate(outputs):
batch["message_log"][idx].append(
{
"role": "assistant",
"content": output,
}
)
# evaluate generations with the environment
to_env = [
get_keys_from_message_log(batch["message_log"][i], ["role", "content"])
for i in range(len(batch["message_log"]))
]
env_return = ray.get(env.step.remote(to_env, batch["extra_env_info"]))
# update stats
if metric == "pass@1":
score += env_return.rewards.sum().item()
else:
raise ValueError(f"Invalid metric: {metric}")
# Cleanup before printing results
ray.get(env.shutdown.remote())
vllm_generation.shutdown()
# Print results
dataset_name = os.path.basename(master_config["data"]["dataset_name"])
model_name = os.path.basename(generation_config["model_name"])
max_new_tokens = generation_config["vllm_cfg"]["max_model_len"]
temperature = generation_config["temperature"]
top_p = generation_config["top_p"]
top_k = generation_config["top_k"]
average_score = score / count
print("\n" + "=" * 60)
print(f"{model_name=} {dataset_name=}")
print(f"{max_new_tokens=} {temperature=} {top_p=} {top_k=}\n")
print(f"{metric=} {num_tests_per_prompt=}\n")
print(f"score={average_score:.4f} ({score}/{count})")
print("=" * 60 + "\n")