# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/volcengine/verl/blob/main/verl/utils/reward_score/math_dapo.py
import re
from typing import Optional
def last_boxed_only_string(string: str) -> Optional[str]:
    r"""Return the last ``\boxed{...}`` expression that appears in a string.

    The search starts at the right-most occurrence of ``\boxed{`` and walks
    forward, tracking brace depth, until the brace that opened the box is
    closed again. Nested braces inside the box are therefore kept intact.

    Args:
        string: Input string containing LaTeX code

    Returns:
        The substring from ``\boxed{`` through its matching ``}``, or None
        when there is no ``\boxed{`` or its braces are never balanced.
    """
    start = string.rfind("\\boxed{")
    if start < 0:
        return None

    depth = 0
    for pos in range(start, len(string)):
        ch = string[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                # Matching close brace found: include it in the result.
                return string[start : pos + 1]

    # Ran off the end of the string with the box still open.
    return None
def remove_boxed(s: str) -> str:
    r"""Strip the surrounding ``\boxed{`` ... ``}`` wrapper from a string.

    Args:
        s: String with format "\\boxed{content}"

    Returns:
        The content between the opening ``\boxed{`` and the final ``}``

    Raises:
        AssertionError: If ``s`` does not start with ``\boxed{`` or does not
            end with ``}``.
    """
    prefix = "\\boxed{"
    prefix_len = len(prefix)
    # Both ends of the wrapper must be present before it is peeled off.
    assert s[:prefix_len] == prefix, f"box error: {s}"
    assert s[-1] == "}", f"box error: {s}"
    return s[prefix_len:-1]
# Constants for normalization
#
# SUBSTITUTIONS is an ordered list of (before, after) pairs; each pair is
# applied with str.replace over the whole answer, in list order. Order
# matters: the plain-space removal runs before the "\text{and}" rules, so
# those rules are written without spaces.
SUBSTITUTIONS = [
    # Leading articles.
    ("an ", ""),
    ("a ", ""),
    # Punctuation / dollar-sign clean-up.
    (".$", "$"),
    ("\\$", ""),
    # Whitespace: LaTeX escaped space, then ordinary spaces.
    (r"\ ", ""),
    (" ", ""),
    # Text-command normalization.
    ("mbox", "text"),
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]
# Substrings that are deleted outright (replaced with "") from an answer.
# Entries are matched literally, in list order.
REMOVED_EXPRESSIONS = [
    # Unit words and other filler nouns.
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    # Leftover \text{...} fragments. The "\n" in two of these is a real
    # newline character, because the literals are not raw strings.
    "\\text{s}",
    "\\text{.}",
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    # LaTeX decorations and spacing commands.
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]
def normalize_final_answer(final_answer: str) -> str:
    """Bring a final answer string into a canonical, comparable form.

    The steps run in a fixed order: keep only the text after the last "=",
    apply SUBSTITUTIONS, delete every entry of REMOVED_EXPRESSIONS, run a
    series of regex rewrites (math-mode extraction, wrapper-command removal,
    shorthand TeX expansion), drop "$" signs, remove thousands separators
    from pure-digit answers, and finally strip surrounding whitespace.

    Args:
        final_answer: The answer string to normalize

    Returns:
        Normalized answer string
    """
    # Everything up to and including the last "=" is discarded.
    answer = final_answer.rsplit("=", 1)[-1]

    # Literal replacements first, then literal deletions.
    for old, new in SUBSTITUTIONS:
        answer = answer.replace(old, new)
    for unwanted in REMOVED_EXPRESSIONS:
        answer = answer.replace(unwanted, "")

    # Ordered (pattern, replacement) regex passes; later passes see the
    # output of earlier ones, so the order must not change.
    rewrites = (
        # Keep only the first $...$ math span, if there is one.
        (r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$"),
        # Unwrap \text{..}, \textbf{..}, \overline{..} and \boxed{..}.
        (r"(\\text\{)(.*?)(\})", "\\2"),
        (r"(\\textbf\{)(.*?)(\})", "\\2"),
        (r"(\\overline\{)(.*?)(\})", "\\2"),
        (r"(\\boxed\{)(.*)(\})", "\\2"),
        # Normalize shorthand TeX:
        # \fracab -> \frac{a}{b}
        # \frac{abc}{bef} -> \frac{abc}{bef}
        # \fracabc -> \frac{a}{b}c
        # \sqrta -> \sqrt{a}
        # \sqrtab -> sqrt{a}b
        (r"(frac)([^{])(.)", "frac{\\2}{\\3}"),
        (r"(sqrt)([^{])", "sqrt{\\2}"),
    )
    for pattern, replacement in rewrites:
        answer = re.sub(pattern, replacement, answer)

    answer = answer.replace("$", "")

    # Normalize numbers: "1,234" -> "1234", but only when nothing except
    # digits remains once the commas are taken out.
    without_commas = answer.replace(",", "")
    if without_commas.isdigit():
        answer = without_commas

    return answer.strip()
def is_correct_minerva(
    solution_str: str,
    gt: str,
    gt_need_extract: bool = False,
    answer_pattern: str = r"(?i)Answer\s*:\s*([^\n]+)",
) -> tuple[bool, str]:
    r"""Check if the solution is correct according to Minerva criteria.

    The prediction is the capture of the last ``answer_pattern`` match in
    ``solution_str``; if nothing matches, the placeholder "[INVALID]" is
    used. Prediction and ground truth are both passed through
    ``normalize_final_answer`` and compared for exact string equality.

    Args:
        solution_str: The solution string to check
        gt: The ground truth answer
        gt_need_extract: Whether the ground truth is wrapped in a
            ``\boxed{...}`` expression that must be unwrapped first. If no
            complete boxed expression is found, ``gt`` is used as given.
        answer_pattern: Regex pattern to extract the answer. It should
            contain exactly one capture group, because ``re.findall``
            returns tuples rather than strings when there are several.

    Returns:
        Tuple of (is_correct, normalized_prediction)
    """
    # Extract answer from solution: the last match wins.
    matches = re.findall(answer_pattern, solution_str)
    extracted_answer = matches[-1] if matches else "[INVALID]"
    pred = normalize_final_answer(extracted_answer)

    # Process ground truth
    if gt_need_extract:
        boxed_gt = last_boxed_only_string(gt)
        # last_boxed_only_string returns None when there is no complete
        # \boxed{...}; passing None on to remove_boxed would raise a
        # TypeError, so in that case keep the raw ground truth instead.
        if boxed_gt is not None:
            gt = remove_boxed(boxed_gt)
    gt = normalize_final_answer(gt)

    return (pred == gt), pred
def is_correct_strict_box(
    pred: str, gt: str, pause_tokens_index: Optional[list[int]] = None
) -> tuple[int, Optional[str]]:
    r"""Check if the prediction is correct using strict boxed answer criteria.

    Only a window at the end of ``pred`` is inspected. The content of the
    last ``\boxed{...}`` inside that window is compared to ``gt`` by exact
    string equality; no normalization is applied to either side.

    Args:
        pred: The prediction string
        gt: The ground truth answer
        pause_tokens_index: Indices of pause tokens. When given, it must
            hold exactly four entries and the window starts 100 characters
            before the last one. When omitted, the window is the last 100
            characters of ``pred``.

    Returns:
        Tuple of (score, extracted_prediction). ``score`` is 1 on a match
        and -1 otherwise; ``extracted_prediction`` is None when the window
        contains no complete boxed expression.

    Raises:
        AssertionError: If ``pause_tokens_index`` is given and its length is
            not 4.
    """
    # Extract the relevant part of the prediction
    if pause_tokens_index is not None:
        assert len(pause_tokens_index) == 4
        # Clamp at 0: a negative start would be read by the slice as an
        # offset from the END of the string and select the wrong region
        # whenever the last pause token sits in the first 100 characters.
        window_start = max(pause_tokens_index[-1] - 100, 0)
        pred = pred[window_start:]
    else:
        pred = pred[-100:]

    # Extract and check the boxed answer
    boxed_pred = last_boxed_only_string(pred)
    extracted_pred = remove_boxed(boxed_pred) if boxed_pred is not None else None

    score = 1 if extracted_pred == gt else -1
    return score, extracted_pred
def verify(
    solution_str: str,
    answer: str,
    strict_box_verify: bool = False,
    pause_tokens_index: Optional[list[int]] = None,
) -> tuple[bool, Optional[str]]:
    """Verify if the solution is correct.

    Dispatches to ``is_correct_strict_box`` when ``strict_box_verify`` is
    set, and to ``is_correct_minerva`` (with its default arguments)
    otherwise.

    Args:
        solution_str: The solution string to verify
        answer: The ground truth answer
        strict_box_verify: Whether to use strict box verification
        pause_tokens_index: Indices of pause tokens; only used on the
            strict-box path

    Returns:
        Tuple of (is_correct, prediction). ``prediction`` is the answer
        extracted from ``solution_str`` by the chosen checker; on the
        strict-box path it is None when no boxed answer was found.
    """
    if strict_box_verify:
        score, pred = is_correct_strict_box(solution_str, answer, pause_tokens_index)
        # The strict checker reports 1 / -1; convert that to a boolean.
        return score == 1, pred

    correct, pred = is_correct_minerva(solution_str, answer)
    return correct, pred
def compute_score(
    solution_str: str,
    ground_truth: str,
    strict_box_verify: bool = False,
    pause_tokens_index: Optional[list[int]] = None,
) -> dict[str, object]:
    """Compute the reward score for a solution.

    Args:
        solution_str: The solution string
        ground_truth: The ground truth answer
        strict_box_verify: Whether to use strict box verification
        pause_tokens_index: Indices of pause tokens

    Returns:
        A dict with three entries:
            "score": reward, 1.0 for correct and 0.0 for incorrect
            "acc": the boolean correctness flag returned by ``verify``
            "pred": the prediction extracted from the solution (may be
                None on the strict-box path)
    """
    # Limit solution length for efficiency
    # The longest answer in MATH-500 has 159 characters
    # NOTE(review): pause_tokens_index is forwarded unchanged although the
    # string has just been cut to its last 300 characters - confirm whether
    # callers supply indices relative to the truncated string.
    solution_str = solution_str[-300:]

    # Verify the solution
    correct, pred = verify(
        solution_str, ground_truth, strict_box_verify, pause_tokens_index
    )

    return {
        "score": 1.0 if correct else 0.0,
        "acc": correct,
        "pred": pred,
    }