Source code for nemo_rl.environments.math_environment

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import io
import logging
import re
from typing import Any, Optional, TypedDict

import ray
import torch
from math_verify.errors import TimeoutException
from math_verify.metric import math_metric
from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig

from nemo_rl.data.interfaces import LLMMessageLogType
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.virtual_cluster import PY_EXECUTABLES
from nemo_rl.environments.interfaces import (
    EnvironmentInterface,
    EnvironmentReturn,
)
from nemo_rl.environments.metrics import (
    calculate_pass_rate_per_prompt,
)
from nemo_rl.environments.utils import chunk_list_to_workers
from nemo_rl.evals import answer_parsing


class MathEnvConfig(TypedDict):
    num_workers: int
    stop_strings: Optional[list[str]]  # Default stop strings for this env
    verifier_type: Optional[str]
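
# Illustrative sketch (not part of the original module): a mapping that
# satisfies MathEnvConfig. The concrete values below are assumptions chosen
# for the example, not repository defaults.
_EXAMPLE_MATH_ENV_CONFIG: MathEnvConfig = {
    "num_workers": 2,
    "stop_strings": None,
    # one of "math", "english_multichoice", "multilingual_multichoice"
    "verifier_type": "math",
}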

@contextlib.contextmanager
def _mute_output():
    devnull_out, devnull_err = io.StringIO(), io.StringIO()
    with (
        contextlib.redirect_stdout(devnull_out),
        contextlib.redirect_stderr(devnull_err),
    ):
        yield

@ray.remote  # pragma: no cover
class HFVerifyWorker:
    def __init__(self) -> None:
        logging.getLogger("math_verify").setLevel(logging.CRITICAL)
        # Use LaTeX and plain math extraction from predictions
        # https://github.com/huggingface/Math-Verify?tab=readme-ov-file#extraction-targets
        self.verify_func = math_metric(
            gold_extraction_target=(LatexExtractionConfig(),),
            pred_extraction_target=(
                ExprExtractionConfig(),
                LatexExtractionConfig(),
            ),
        )

    def verify(
        self, pred_responses: list[str], ground_truths: list[str]
    ) -> list[float]:
        """Verify the correctness of the predicted responses against the ground truth.

        Args:
            pred_responses: list[str]. The predicted responses from the LLM.
            ground_truths: list[str]. The ground truth responses.

        Returns:
            list[float]. The rewards for each predicted response.
        """
        results = []
        for response, ground_truth in zip(pred_responses, ground_truths):
            try:
                ground_truth_parsable = "\\boxed{" + ground_truth + "}"
                with _mute_output():
                    try:
                        ret_score, _ = self.verify_func(
                            [ground_truth_parsable], [response]
                        )
                    # math-verify can raise a TimeoutException, which subclasses
                    # BaseException and therefore is not covered by a bare
                    # `except Exception`, so it is listed explicitly here.
                    except (Exception, TimeoutException):
                        ret_score = 0.0
                results.append(float(ret_score))
            except Exception:
                results.append(0.0)
        return results
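
# Illustrative usage sketch (not part of the original module): HFVerifyWorker is
# a Ray actor, so it is constructed with .remote() and verify() is invoked
# through Ray futures, mirroring how MathEnvironment drives it below. It assumes
# Ray has already been initialized, and the strings are made-up example values.
#
#     worker = HFVerifyWorker.remote()
#     rewards = ray.get(
#         worker.verify.remote(
#             pred_responses=["The sum is \\boxed{4}."],
#             ground_truths=["4"],
#         )
#     )
#     # rewards == [1.0] when math_verify judges the answers equivalent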

@ray.remote  # pragma: no cover
class MultilingualMultichoiceVerifyWorker:
    def verify(
        self, pred_responses: list[str], ground_truths: list[str]
    ) -> list[float]:
        """Verify the correctness of the predicted responses against the ground truth.

        Args:
            pred_responses: list[str]. The predicted responses from the LLM.
            ground_truths: list[str]. The ground truth responses.

        Returns:
            list[float]. The rewards for each predicted response.
        """
        results = []
        for response, ground_truth in zip(pred_responses, ground_truths):
            response = answer_parsing.normalize_response(response)
            extracted_answer = None
            for answer_regex in answer_parsing.MULTILINGUAL_ANSWER_REGEXES:
                regex = answer_parsing.MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(
                    answer_regex
                )
                match = re.search(regex, response)
                if match:
                    extracted_answer = answer_parsing.normalize_extracted_answer(
                        match.group(1)
                    )
                    break
            score = 1.0 if extracted_answer == ground_truth else 0.0
            results.append(score)
        return results

@ray.remote  # pragma: no cover
class EnglishMultichoiceVerifyWorker:
    def verify(
        self, pred_responses: list[str], ground_truths: list[str]
    ) -> list[float]:
        """Verify the correctness of the predicted responses against the ground truth.

        Args:
            pred_responses: list[str]. The predicted responses from the LLM.
            ground_truths: list[str]. The ground truth responses.

        Returns:
            list[float]. The rewards for each predicted response.
        """
        results = []
        for response, ground_truth in zip(pred_responses, ground_truths):
            ground_truth = answer_parsing.normalize_response(ground_truth)
            response = answer_parsing.normalize_response(response)
            extracted_answer = None
            match = re.search(r"(?i)Answer\s*:[ \t]*([A-Z])", response)
            if match:
                extracted_answer = answer_parsing.normalize_extracted_answer(
                    match.group(1)
                )
            score = 1.0 if extracted_answer == ground_truth else 0.0
            results.append(score)
        return results
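
# Illustrative sketch (not part of the original module): this worker expects the
# model to state its choice as "Answer: <letter>". The values below are made up,
# and the expected reward assumes the normalization helpers leave a bare letter
# unchanged (and that Ray has already been initialized).
#
#     worker = EnglishMultichoiceVerifyWorker.remote()
#     rewards = ray.get(worker.verify.remote(["I pick C. Answer: C"], ["C"]))
#     # rewards == [1.0]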

class MathEnvironmentMetadata(TypedDict):
    ground_truth: str

@ray.remote(max_restarts=-1, max_task_retries=-1)  # pragma: no cover
class MathEnvironment(EnvironmentInterface[MathEnvironmentMetadata]):
    def __init__(self, cfg: MathEnvConfig):
        self.cfg = cfg
        self.num_workers = cfg["num_workers"]
        # TODO: split out this environment since it's doing more than just math
        verifier_type = cfg.get("verifier_type", "math")
        assert isinstance(verifier_type, str), (
            f"{verifier_type=} must be a string but was {type(verifier_type)}"
        )
        worker_cls = {
            "math": HFVerifyWorker,
            "english_multichoice": EnglishMultichoiceVerifyWorker,
            "multilingual_multichoice": MultilingualMultichoiceVerifyWorker,
        }[verifier_type]
        self.workers = [
            worker_cls.options(  # type: ignore # (decorated with @ray.remote)
                runtime_env={"py_executable": PY_EXECUTABLES.SYSTEM}
            ).remote()
            for _ in range(self.num_workers)
        ]

    def shutdown(self) -> None:
        # shut down all workers
        for worker in self.workers:
            ray.kill(worker)

    def step(
        self,
        message_log_batch: list[LLMMessageLogType],
        metadata: list[MathEnvironmentMetadata],
    ) -> EnvironmentReturn[MathEnvironmentMetadata]:
        """Runs a step in the math environment.

        Args:
            message_log_batch: list[list[dict[str, str]]]. A batch of
                OpenAI-API-like message logs that represent interactions with
                the LLM.
            metadata: list[MathEnvironmentMetadata]. The grader will use the
                'ground_truth' key to evaluate correctness.

        Returns:
            EnvironmentReturn: A tuple containing:
                - list[dict[str, str]]: Observations/responses batch
                - list[dict]: Updated metadata
                - list[str]: Next stop strings for the next turn
                - Tensor: Rewards tensor
                - Tensor: Done flags tensor
        """
        # Extract the assistant's responses from the message history.
        # Each message list should have at least one assistant response.
        assistant_response_batch = []
        for conversation in message_log_batch:
            assistant_responses = [
                str(interaction["content"])
                for interaction in conversation
                if interaction["role"] == "assistant"
            ]
            assistant_response_batch.append("".join(assistant_responses))

        ground_truths = [g["ground_truth"] for g in metadata]

        chunked_assistant_response_batch = chunk_list_to_workers(
            assistant_response_batch, self.num_workers
        )
        chunked_ground_truths = chunk_list_to_workers(ground_truths, self.num_workers)

        # Process each chunk in parallel
        futures = [
            self.workers[i].verify.remote(chunk, ground_truth_chunk)
            for i, (chunk, ground_truth_chunk) in enumerate(
                zip(chunked_assistant_response_batch, chunked_ground_truths)
            )
        ]

        results = ray.get(futures)

        # flatten the results
        results = [item for sublist in results for item in sublist]

        observations = [
            {
                "role": "environment",
                "content": "Environment: correct"
                if result
                else "Environment: incorrect",
            }
            for result in results
        ]

        # create a tensor of rewards and done flags
        rewards = torch.tensor(results).cpu()
        done = torch.ones_like(rewards).cpu()

        next_stop_strings = [None] * len(message_log_batch)

        return EnvironmentReturn(
            observations=observations,
            metadata=metadata,
            next_stop_strings=next_stop_strings,
            rewards=rewards,
            terminateds=done,
        )

    def global_post_process_and_metrics(
        self, batch: BatchedDataDict[Any]
    ) -> tuple[BatchedDataDict[Any], dict[str, float | int]]:
        """Computes metrics for this environment given a global rollout batch.

        Every rank will run this function, so you're free to use distributed
        calculations if you'd prefer for heavy metrics.
        """
        batch["rewards"] = (
            batch["rewards"] * batch["is_end"]
        )  # set a reward of 0 for any incorrectly ended sequences

        if (batch["rewards"] == 1).float().sum() > 0:
            correct_solution_generation_lengths = (
                (batch["generation_lengths"] - batch["prompt_lengths"])[
                    batch["rewards"] == 1
                ]
                .float()
                .mean()
                .item()
            )
        else:
            correct_solution_generation_lengths = 0

        metrics = {
            # "table": table, TODO @sahilj WIP
            "accuracy": batch["rewards"].mean().item(),
            "pass@samples_per_prompt": calculate_pass_rate_per_prompt(
                batch["text"], batch["rewards"]
            ),
            "fraction_of_samples_properly_ended": batch["is_end"]
            .float()
            .mean()
            .item(),
            "num_problems_in_batch": batch["is_end"].shape[0],
            "generation_lengths": batch["generation_lengths"].float().mean().item(),
            "prompt_lengths": batch["prompt_lengths"].float().mean().item(),
            "correct_solution_generation_lengths": correct_solution_generation_lengths,
        }

        return batch, metrics
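
# Illustrative end-to-end sketch (not part of the original module), assuming Ray
# has already been initialized and that EnvironmentReturn exposes its fields as
# attributes (it is constructed with keyword arguments above). The config,
# conversation, and ground truth are made-up example values.
#
#     env = MathEnvironment.remote(
#         {"num_workers": 1, "stop_strings": None, "verifier_type": "math"}
#     )
#     message_log_batch = [
#         [
#             {"role": "user", "content": "What is 2 + 2?"},
#             {"role": "assistant", "content": "The answer is \\boxed{4}."},
#         ]
#     ]
#     metadata = [{"ground_truth": "4"}]
#     result = ray.get(env.step.remote(message_log_batch, metadata))
#     # result.rewards -> tensor([1.]); result.terminateds -> tensor([1.])
#     ray.get(env.shutdown.remote())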