# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import os
import subprocess
import tempfile
import time
from typing import Any, Literal, TypeVar
import requests
import yaml
from jinja2 import Environment, StrictUndefined, nodes
from nemo_evaluator.logging import get_logger
__all__ = []
logger = get_logger(__name__)
class MisconfigurationError(Exception):
pass
KeyType = TypeVar("KeyType")
def get_jinja2_environment() -> Environment:
"""Get a configured Jinja2 environment for template operations.
This ensures consistency between template parsing and rendering.
Uses StrictUndefined to match the behavior in api_dataclasses.py.
Returns:
Environment: Configured Jinja2 environment
"""
return Environment(undefined=StrictUndefined)
def deep_update(
mapping: dict[KeyType, Any],
*updating_mappings: dict[KeyType, Any],
skip_nones: bool = False,
) -> dict[KeyType, Any]:
"""Deep update a mapping with other mappings.
If `skip_nones` is True, then the values that are None in the updating mappings are
not updated.
"""
updated_mapping = mapping.copy()
for updating_mapping in updating_mappings:
for k, v in updating_mapping.items():
if (
k in updated_mapping
and isinstance(updated_mapping[k], dict)
and isinstance(v, dict)
):
updated_mapping[k] = deep_update(
updated_mapping[k], v, skip_nones=skip_nones
)
else:
if skip_nones and v is None:
continue
updated_mapping[k] = v
return updated_mapping
def extract_params_from_command(command: str) -> tuple[set[str], set[str]]:
"""Extract all config.params.* parameter names used in a command template.
Args:
command: Jinja2 command template string
Returns:
Tuple of (standard_params, extra_params) where:
- standard_params: Set of param names like {'temperature', 'max_new_tokens'}
- extra_params: Set of params.extra names like {'dummy_score', 'another_param'}
"""
# Use Jinja2's AST parser to extract variable attribute access patterns
# Uses the same environment configuration as template rendering
env = get_jinja2_environment()
ast = env.parse(command)
standard_params = set()
extra_params = set()
def extract_getattr_path(node):
"""Recursively extract the full dotted path from a Getattr node."""
if isinstance(node, nodes.Name):
return node.name
elif isinstance(node, nodes.Getattr):
base = extract_getattr_path(node.node)
if base:
return f"{base}.{node.attr}"
return node.attr
return None
def visit_node(node):
"""Visit all nodes in the AST to find variable references."""
if isinstance(node, nodes.Getattr):
full_path = extract_getattr_path(node)
if full_path:
# Check for config.params.extra.PARAM_NAME pattern
if full_path.startswith("config.params.extra."):
param_name = full_path.replace("config.params.extra.", "")
if param_name: # Only add if there's something after "extra."
# Extract only the first-level key after "extra."
param_name = param_name.split(".")[0]
extra_params.add(param_name)
# Check for config.params.PARAM_NAME pattern (but not just config.params.extra)
elif (
full_path.startswith("config.params.")
and full_path != "config.params.extra"
):
param_name = full_path.replace("config.params.", "")
if param_name != "extra": # Don't add "extra" itself
standard_params.add(param_name)
# Recursively visit child nodes
for child in node.iter_child_nodes():
visit_node(child)
visit_node(ast)
return standard_params, extra_params
def validate_params_in_command(
command: str,
merged_config: dict[KeyType, Any],
) -> None:
"""Validate that all params keys in merged config are used in the command.
Args:
command: The command template from framework.yml
merged_config: The final merged configuration
Raises:
MisconfigurationError: If merged_config contains params keys not used in command
"""
# Extract params keys used in command
command_standard_params, command_extra_params = extract_params_from_command(command)
# Get params from merged config
config_params = merged_config.get("config", {}).get("params", {})
if not config_params:
return # No params to validate
# Check standard params
unused_standard = []
for key, value in config_params.items():
if key == "extra":
continue # Handle extra separately
# Only validate non-None values (None means not set/using default)
if value is not None and key not in command_standard_params:
unused_standard.append(f"config.params.{key}")
# Check params.extra
config_extra = config_params.get("extra", {})
unused_extra = []
for key in config_extra.keys():
if key not in command_extra_params:
unused_extra.append(f"config.params.extra.{key}")
# Raise error if any unused params found
all_unused = unused_standard + unused_extra
if all_unused:
valid_standard = [f"config.params.{p}" for p in sorted(command_standard_params)]
valid_extra = [f"config.params.extra.{p}" for p in sorted(command_extra_params)]
logger.warn(
f"Configuration contains parameter(s) that are not used in the command template: "
f"{', '.join(all_unused)}. "
f"Valid params from command: {valid_standard + valid_extra}. "
f"Remove the unused parameters or update the command template to use them."
)
def dotlist_to_dict(dotlist: list[str]) -> dict:
"""Resolve dot-list style key-value pairs with YAML.
Helper for overriding configuration values using command-line arguments in dot-list style.
"""
dotlist_dict = {}
for override in dotlist:
parts = override.strip().split("=", 1)
if len(parts) == 2:
key = parts[0].strip()
raw_value = parts[1].strip()
# If the value starts with a quote but doesn't end with the same quote,
# it means we have a malformed string. In this case, we'll treat it as a raw string.
if (raw_value.startswith('"') and not raw_value.endswith('"')) or (
raw_value.startswith("'") and not raw_value.endswith("'")
):
value = raw_value
else:
try:
value = yaml.safe_load(raw_value)
except yaml.YAMLError:
# If YAML parsing fails, treat it as a raw string
value = raw_value
keys = key.split(".")
temp = dotlist_dict
for k in keys[:-1]:
temp = temp.setdefault(k, {})
temp[keys[-1]] = value
return dotlist_dict
def run_command(command, cwd=None, verbose=False, propagate_errors=False):
if verbose:
logger.info(f"Running command: {command}")
if cwd:
print(f"Current working directory set to: {cwd}")
with tempfile.TemporaryDirectory() as tmpdirname:
if verbose:
logger.info(f"Temporary directory created at: {tmpdirname}")
file = os.path.join(
tmpdirname, hashlib.sha1(command.encode("utf-8")).hexdigest() + ".sh"
)
if verbose:
logger.info(f"Script file created: {file}")
with open(file, "w") as f:
f.write(command)
f.flush()
if verbose:
logger.info("Command written to script file.")
master, slave = os.openpty()
process = subprocess.Popen(
f"bash {file}",
stdout=slave,
stderr=slave,
stdin=subprocess.PIPE,
cwd=cwd,
shell=True,
executable="/bin/bash",
)
if verbose:
logger.info("Subprocess started.")
os.close(slave)
if propagate_errors:
stderr_output = []
while True:
try:
output = os.read(master, 1024)
if not output:
break
decoded_output = output.decode(errors="ignore")
print(decoded_output, end="", flush=True)
if propagate_errors:
stderr_output.append(decoded_output)
except OSError as e:
if e.errno == 5: # Input/output error is expected at the end of output
break
raise
if verbose:
logger.info("Output reading completed.")
rc = process.wait()
if verbose:
logger.info(f"Subprocess finished with return code: {rc}")
# New error propagation logic
if rc != 0 and propagate_errors:
error_content = (
"".join(stderr_output) if stderr_output else "No error details captured"
)
raise RuntimeError(
f"Evaluation failed! Please consult the logs below:\n{error_content}"
)
return rc
def check_health(
health_url: str, max_retries: int = 600, retry_interval: int = 2
) -> bool:
"""
Check the health of the server.
"""
for _ in range(max_retries):
try:
response = requests.get(health_url)
if response.status_code == 200:
return True
logger.info(f"Server replied with status code: {response.status_code}")
time.sleep(retry_interval)
except requests.exceptions.RequestException:
logger.info("Server is not ready")
time.sleep(retry_interval)
return False
[docs]
def check_endpoint(
endpoint_url: str,
endpoint_type: Literal["completions", "chat"],
model_name: str,
max_retries: int = 600,
retry_interval: int = 2,
) -> bool:
"""Checks if the OpenAI-compatible endpoint is alive by sending a simple prompt.
Args:
endpoint_url (str): Full endpoint URL. For most servers that means either ``/v1/chat/completions`` or ``/completions`` must be provided
endpoint_type (Literal[completions, chat]): indicates if the model is instruction-tuned (chat) or a base model (completions). Used to constuct a proper payload structure.
model_name (str): model name that is linked to payload. Might be required by some endpoint.
max_retries (int, optional): How many attempt before returning false. Defaults to 600.
retry_interval (int, optional): How many seconds to wait between attempts. Defaults to 2.
Raises:
ValueError: if endpoint_type was not one of "completions", "chat"
Returns:
bool: whether the endpoint is alive
"""
payload = {"model": model_name, "max_tokens": 1}
if endpoint_type == "completions":
payload["prompt"] = "hello, my name is"
elif endpoint_type == "chat":
payload["messages"] = [{"role": "user", "content": "hello, what is your name?"}]
else:
raise ValueError(f"Invalid endpoint type: {endpoint_type}")
for _ in range(max_retries):
try:
response = requests.post(endpoint_url, json=payload)
if response.status_code == 200:
return True
logger.info(f"Server replied with status code: {response.status_code}")
time.sleep(retry_interval)
except requests.exceptions.RequestException:
logger.info("Server is not ready")
time.sleep(retry_interval)
return False