Source code for nemo_evaluator.core.utils

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import errno
import hashlib
import os
import subprocess
import tempfile
import time
from typing import Any, TypeVar

import requests
import yaml

from nemo_evaluator.logging import get_logger

logger = get_logger(__name__)


class MisconfigurationError(Exception):
    """Exception type for misconfiguration errors."""
KeyType = TypeVar("KeyType")


def deep_update(
    mapping: dict[KeyType, Any],
    *updating_mappings: dict[KeyType, Any],
    skip_nones: bool = False,
) -> dict[KeyType, Any]:
    """Return ``mapping`` recursively merged with each of ``updating_mappings``.

    ``mapping`` itself is not modified; the result starts as ``mapping.copy()``
    and the updating mappings are applied left to right.

    For each key of an updating mapping:

    * if both the current value and the new value are dicts, they are merged
      recursively (the nested dict is copied, never mutated in place);
    * otherwise the new value replaces the current one (it is inserted as-is,
      not copied).

    With ``skip_nones=True``, a ``None`` value from an updating mapping is
    ignored at every level where a merge happens. ``None`` values that sit
    inside a sub-dict inserted wholesale (because the key was absent or not a
    dict) are kept.
    """
    merged = mapping.copy()
    for overrides in updating_mappings:
        for key, new_value in overrides.items():
            if isinstance(new_value, dict) and isinstance(merged.get(key), dict):
                merged[key] = deep_update(
                    merged[key], new_value, skip_nones=skip_nones
                )
            elif not (skip_nones and new_value is None):
                merged[key] = new_value
    return merged
def dotlist_to_dict(dotlist: list[str]) -> dict:
    """Resolve dot-list style ``key=value`` overrides into a nested dict.

    Helper for overriding configuration values using command-line arguments
    in dot-list style, e.g. ``["a.b=1", "a.c=x"]`` -> ``{"a": {"b": 1, "c": "x"}}``.

    Each value is parsed with ``yaml.safe_load`` (so ``1`` becomes an int and
    ``true`` a bool). A value that opens a quote without closing it with the
    same quote character, or that YAML fails to parse, is kept as the raw
    (stripped) string.

    Overrides are applied in order and the last one wins. This holds both when
    a scalar replaces an earlier sub-dict (``a.b=1`` then ``a=2``) and when a
    dotted key nests under an earlier scalar (``a=2`` then ``a.b=1``).
    Entries without ``=`` are skipped with a warning.

    Args:
        dotlist: Override strings of the form ``"dotted.key=value"``.

    Returns:
        A new nested dict built from the overrides.
    """
    dotlist_dict: dict = {}
    for override in dotlist:
        key, sep, raw_value = override.strip().partition("=")
        if not sep:
            logger.warning(f"Ignoring override without '=': {override!r}")
            continue
        key = key.strip()
        raw_value = raw_value.strip()

        # An opening quote with no matching closing quote is a malformed YAML
        # string; keep it verbatim instead of letting YAML reject it.
        first = raw_value[:1]
        if first in ('"', "'") and not raw_value.endswith(first):
            value = raw_value
        else:
            try:
                value = yaml.safe_load(raw_value)
            except yaml.YAMLError:
                # If YAML parsing fails, treat it as a raw string
                value = raw_value

        *parents, leaf = key.split(".")
        node = dotlist_dict
        for part in parents:
            child = node.get(part)
            if not isinstance(child, dict):
                # An earlier override stored a scalar here. Replace it so this
                # nested override wins, instead of failing on item assignment
                # into the scalar.
                child = {}
                node[part] = child
            node = child
        node[leaf] = value
    return dotlist_dict
def _stream_pty_output(master_fd: int, capture: bool) -> str:
    """Echo everything read from ``master_fd`` to stdout until the pty closes.

    Returns the decoded text when ``capture`` is true, otherwise ``""``.
    """
    # Decode incrementally so a multi-byte UTF-8 character that straddles two
    # reads is reassembled instead of being dropped by errors="ignore".
    decoder = codecs.getincrementaldecoder("utf-8")(errors="ignore")
    captured = []
    eof = False
    while not eof:
        try:
            data = os.read(master_fd, 1024)
        except OSError as e:
            # Reading the master after the child side has closed raises EIO
            # (at least on Linux); treat it as end of output.
            if e.errno != errno.EIO:
                raise
            data = b""
        eof = not data
        text = decoder.decode(data, final=eof)
        if text:
            print(text, end="", flush=True)
            if capture:
                captured.append(text)
    return "".join(captured)


def run_command(command, cwd=None, verbose=False, propagate_errors=False):
    """Run ``command`` as a bash script, streaming its output to stdout.

    The command text is written to a temporary ``.sh`` file and executed with
    ``bash`` (resolved via ``PATH``). The child's stdout and stderr are both
    attached to a pseudo-terminal, presumably so TTY-aware tools keep their
    interactive output. Everything read from the pty is echoed to this
    process's stdout as it arrives.

    Args:
        command: Shell script text to execute.
        cwd: Working directory for the child process (``None`` keeps the
            current one).
        verbose: Log each setup step through the module logger.
        propagate_errors: Keep a copy of the output and raise on a non-zero
            exit status instead of just returning it.

    Returns:
        The exit status reported by ``subprocess.Popen.wait``.

    Raises:
        RuntimeError: If ``propagate_errors`` is true and the exit status is
            non-zero. The message includes the captured output; stdout and
            stderr are merged on the pty.
    """
    if verbose:
        logger.info(f"Running command: {command}")
        if cwd:
            logger.info(f"Current working directory set to: {cwd}")

    with tempfile.TemporaryDirectory() as tmpdirname:
        if verbose:
            logger.info(f"Temporary directory created at: {tmpdirname}")
        script_path = os.path.join(
            tmpdirname, hashlib.sha1(command.encode("utf-8")).hexdigest() + ".sh"
        )
        if verbose:
            logger.info(f"Script file created: {script_path}")
        # Write UTF-8 explicitly so non-ASCII commands do not depend on the locale.
        with open(script_path, "w", encoding="utf-8") as f:
            f.write(command)
        if verbose:
            logger.info("Command written to script file.")

        master_fd, slave_fd = os.openpty()
        try:
            try:
                # An argv list (no intermediate shell string) keeps working when
                # the temp dir path contains spaces or shell metacharacters.
                process = subprocess.Popen(
                    ["bash", script_path],
                    stdout=slave_fd,
                    stderr=slave_fd,
                    stdin=subprocess.PIPE,
                    cwd=cwd,
                )
            finally:
                # The child holds its own copy of the slave end. Closing ours lets
                # reads on the master end once the child exits.
                os.close(slave_fd)
            if verbose:
                logger.info("Subprocess started.")

            output = _stream_pty_output(master_fd, capture=propagate_errors)
            if verbose:
                logger.info("Output reading completed.")

            rc = process.wait()
            if process.stdin is not None:
                process.stdin.close()
        finally:
            # Always release the master end, including when reading fails.
            os.close(master_fd)

        if verbose:
            logger.info(f"Subprocess finished with return code: {rc}")

        if rc != 0 and propagate_errors:
            error_content = output or "No error details captured"
            raise RuntimeError(
                f"Evaluation failed! Please consult the logs below:\n{error_content}"
            )
        return rc
def check_health(
    health_url: str,
    max_retries: int = 600,
    retry_interval: float = 2,
    request_timeout: float = 60.0,
) -> bool:
    """Poll ``health_url`` with ``GET`` until it answers HTTP 200.

    Args:
        health_url: URL requested on every attempt.
        max_retries: Maximum number of attempts.
        retry_interval: Seconds to sleep between attempts. No sleep follows
            the final attempt.
        request_timeout: Per-request timeout in seconds, passed to
            ``requests``. Without a timeout, a server that accepts the
            connection but never responds would block this loop forever. A
            timeout counts as a failed attempt.

    Returns:
        ``True`` as soon as an attempt gets status 200, ``False`` once all
        attempts are used up.
    """
    for attempt in range(1, max_retries + 1):
        try:
            response = requests.get(health_url, timeout=request_timeout)
        except requests.exceptions.RequestException:
            logger.info("Server is not ready")
        else:
            if response.status_code == 200:
                return True
            logger.info(f"Server replied with status code: {response.status_code}")
        if attempt < max_retries:
            time.sleep(retry_interval)
    return False
def check_endpoint(
    endpoint_url: str,
    endpoint_type: str,
    model_name: str,
    max_retries: int = 600,
    retry_interval: float = 2,
    request_timeout: float = 60.0,
) -> bool:
    """Check that the endpoint answers a minimal generation request.

    Each attempt POSTs a one-token request (``max_tokens=1``) for
    ``model_name``. For ``"completions"`` it sends a ``prompt``; for ``"chat"``
    it sends a single user message.

    Args:
        endpoint_url: URL to POST the request to.
        endpoint_type: Either ``"completions"`` or ``"chat"``.
        model_name: Value sent as the ``model`` field.
        max_retries: Maximum number of attempts.
        retry_interval: Seconds to sleep between attempts. No sleep follows
            the final attempt.
        request_timeout: Per-request timeout in seconds, passed to
            ``requests``. Without a timeout, a server that never responds would
            block this loop forever. A timeout counts as a failed attempt.

    Returns:
        ``True`` as soon as an attempt gets status 200, ``False`` once all
        attempts are used up.

    Raises:
        ValueError: If ``endpoint_type`` is not ``"completions"`` or ``"chat"``.
    """
    payload = {"model": model_name, "max_tokens": 1}
    if endpoint_type == "completions":
        payload["prompt"] = "hello, my name is"
    elif endpoint_type == "chat":
        payload["messages"] = [{"role": "user", "content": "hello, what is your name?"}]
    else:
        raise ValueError(f"Invalid endpoint type: {endpoint_type}")

    for attempt in range(1, max_retries + 1):
        try:
            response = requests.post(
                endpoint_url, json=payload, timeout=request_timeout
            )
        except requests.exceptions.RequestException:
            logger.info("Server is not ready")
        else:
            if response.status_code == 200:
                return True
            logger.info(f"Server replied with status code: {response.status_code}")
        if attempt < max_retries:
            time.sleep(retry_interval)
    return False