Source code for nemo_evaluator.core.input

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pkgutil
from typing import Optional

import yaml

from nemo_evaluator.adapters.adapter_config import AdapterConfig
from nemo_evaluator.api.api_dataclasses import (
    Evaluation,
    EvaluationConfig,
    EvaluationTarget,
)
from nemo_evaluator.core.utils import (
    MisconfigurationError,
    deep_update,
    dotlist_to_dict,
)
from nemo_evaluator.logging import get_logger

logger = get_logger(__name__)


def load_run_config(yaml_file: str) -> dict:
    """Load the run configuration from the YAML file.

    NOTE: The YAML config allows overriding all of the run configuration parameters.
    """
    with open(yaml_file, "r") as file:
        config = yaml.safe_load(file)
    return config


def parse_cli_args(args) -> dict:
    """Parse CLI arguments into the run configuration format.

    NOTE: The CLI args allow overriding a subset of the run configuration parameters.
    """
    config = {
        "config": {},
        "target": {
            "api_endpoint": {},
        },
    }
    if args.eval_type:
        config["config"]["type"] = args.eval_type
    if args.output_dir:
        config["config"]["output_dir"] = args.output_dir
    if args.api_key_name:
        config["target"]["api_endpoint"]["api_key"] = args.api_key_name
    if args.model_id:
        config["target"]["api_endpoint"]["model_id"] = args.model_id
    if args.model_type:
        config["target"]["api_endpoint"]["type"] = args.model_type
    if args.model_url:
        config["target"]["api_endpoint"]["url"] = args.model_url

    overrides = parse_override_params(args.overrides)
    # --overrides takes precedence over other CLI args (e.g. --model_id)
    config = deep_update(config, overrides, skip_nones=True)
    return config


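# --- Illustrative usage sketch (not part of the original module). -----------
# Shows the nested dict shape produced by parse_cli_args for an argparse-style
# namespace. The concrete values ("mmlu", the localhost URL, "MY_API_KEY") are
# made-up examples, and the shown result assumes deep_update with an empty
# override dict leaves the config unchanged.
def _example_parse_cli_args() -> dict:
    from argparse import Namespace

    args = Namespace(
        eval_type="mmlu",
        output_dir="/tmp/results",
        api_key_name="MY_API_KEY",
        model_id="my-model",
        model_type="chat",
        model_url="http://localhost:8080/v1/chat/completions",
        overrides=None,
    )
    # Expected shape:
    # {"config": {"type": "mmlu", "output_dir": "/tmp/results"},
    #  "target": {"api_endpoint": {"api_key": "MY_API_KEY", "model_id": "my-model",
    #                              "type": "chat",
    #                              "url": "http://localhost:8080/v1/chat/completions"}}}
    return parse_cli_args(args)

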
def parse_override_params(override_params_str: Optional[str] = None) -> dict:
    if not override_params_str:
        return {}

    # Split the string into key-value pairs, handling commas inside quotes
    pairs = []
    current_pair = ""
    in_quotes = False
    quote_char = None

    for char in override_params_str:
        if char in ('"', "'") and not in_quotes:
            in_quotes = True
            quote_char = char
            current_pair += char
        elif char == quote_char and in_quotes:
            in_quotes = False
            quote_char = None
            current_pair += char
        elif char == "," and not in_quotes:
            pairs.append(current_pair.strip())
            current_pair = ""
        else:
            current_pair += char

    if current_pair:
        pairs.append(current_pair.strip())

    return dotlist_to_dict(pairs)


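# --- Illustrative usage sketch (not part of the original module). -----------
# Demonstrates the quote-aware splitting above: the comma inside the quoted
# value stays within a single key=value pair. The keys below are made up, and
# the final nesting and value typing depend on dotlist_to_dict (assumed here
# to map "a.b=1" to {"a": {"b": 1}}).
def _example_parse_override_params() -> dict:
    overrides = (
        "config.params.limit_samples=10,"
        "config.params.extra.args='--first,--second'"
    )
    # Quote-aware splitting yields two pairs:
    #   ["config.params.limit_samples=10",
    #    "config.params.extra.args='--first,--second'"]
    # which dotlist_to_dict then turns into a nested dict.
    return parse_override_params(overrides)

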
def get_framework_evaluations(filepath: str) -> tuple[str, dict, dict[str, Evaluation]]:
    framework = {}
    with open(filepath, "r") as f:
        framework = yaml.safe_load(f)

    framework_name = framework["framework"]["name"]
    pkg_name = framework["framework"]["pkg_name"]
    run_config_framework_defaults = framework["defaults"]
    run_config_framework_defaults["framework_name"] = framework_name
    run_config_framework_defaults["pkg_name"] = pkg_name

    evaluations = dict()
    for evaluation_dict in framework["evaluations"]:
        # Apply run config evaluation defaults onto the framework defaults
        run_config_task_defaults = deep_update(
            run_config_framework_defaults, evaluation_dict["defaults"], skip_nones=True
        )
        evaluation = Evaluation(
            **run_config_task_defaults,
        )
        evaluations[evaluation_dict["defaults"]["config"]["type"]] = evaluation

    return framework_name, run_config_framework_defaults, evaluations


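# Structure of a Framework Definition File (framework.yml) as implied by the
# parsing above -- a sketch with placeholder values, not a verbatim copy of any
# shipped file:
#
#   framework:
#     name: <framework_name>
#     pkg_name: <python_package_name>
#   defaults:
#     ...                      # run-config defaults shared by all evaluations
#   evaluations:
#     - defaults:
#         config:
#           type: <task_name>  # used as the key in the returned mapping
#         ...

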
# improve typing
def _get_framework_evaluations(
    def_file: str,
) -> tuple[dict[str, dict[str, Evaluation]], dict[str, dict], dict[str, Evaluation]]:
    # We should decide if this should raise at this point. Probably not, because
    # this function is used with task invocations that might come from a
    # different harness.
    if not os.path.exists(def_file):
        raise ValueError(f"Framework Definition File does not exist at {def_file}")

    framework_eval_mapping = {}  # framework name -> set of tasks | used in 'framework.task' invocation
    eval_name_mapping = {}  # task name -> set of tasks | used in 'task' invocation

    logger.debug("Loading task definitions", filepath=def_file)
    (
        framework_name,
        framework_defaults,
        framework_evaluations,
    ) = get_framework_evaluations(def_file)
    framework_eval_mapping[framework_name] = framework_evaluations
    eval_name_mapping.update(framework_evaluations)
    framework_defaults = {framework_name: framework_defaults}

    return framework_eval_mapping, framework_defaults, eval_name_mapping


def merge_dicts(dict1, dict2):
    merged = {}
    all_keys = set(dict1.keys()) | set(dict2.keys())

    for key in all_keys:
        v1 = dict1.get(key)
        v2 = dict2.get(key)

        if key in dict1 and key in dict2:
            result = []
            # Handle case where value is a list or not
            if isinstance(v1, list):
                result.extend(v1)
            elif v1 is not None:
                result.append(v1)
            if isinstance(v2, list):
                result.extend(v2)
            elif v2 is not None:
                result.append(v2)
            merged[key] = result
        elif key in dict1:
            merged[key] = v1
        else:
            merged[key] = v2

    return merged


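# --- Illustrative usage sketch (not part of the original module). -----------
# merge_dicts collects values for keys present in both inputs into a list
# (flattening values that are already lists), while keys present in only one
# input are copied through unchanged. The keys and values below are made up.
def _example_merge_dicts() -> dict:
    left = {"mmlu": "eval_a", "gsm8k": "eval_b"}
    right = {"mmlu": ["eval_c"], "arc": "eval_d"}
    # -> {"mmlu": ["eval_a", "eval_c"], "gsm8k": "eval_b", "arc": "eval_d"}
    return merge_dicts(left, right)

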
def get_available_evaluations() -> tuple[
    dict[str, dict[str, Evaluation]], dict[str, Evaluation], dict
]:
    all_framework_eval_mappings = {}
    all_framework_defaults = {}
    all_eval_name_mapping = {}

    try:
        import core_evals

        core_evals_pkg = list(pkgutil.iter_modules(core_evals.__path__))
    except ImportError:
        core_evals_pkg = []

    for pkg in core_evals_pkg:
        (
            framework_eval_mapping,
            framework_defaults,
            eval_name_mapping,
        ) = _get_framework_evaluations(
            os.path.join(pkg.module_finder.path, pkg.name, "framework.yml")
        )
        all_framework_eval_mappings.update(framework_eval_mapping)
        all_framework_defaults.update(framework_defaults)
        all_eval_name_mapping = merge_dicts(all_eval_name_mapping, eval_name_mapping)

    return (
        all_framework_eval_mappings,
        all_framework_defaults,
        all_eval_name_mapping,
    )


def check_task_invocation(run_config: dict):
    """Check that the task invocation is formatted correctly and that the requested framework or task is available.

    Args:
        run_config (dict): the merged run configuration

    Raises:
        MisconfigurationError: if the eval type does not follow the expected format
        MisconfigurationError: if the provided framework is not available
        MisconfigurationError: if the provided task is not available
    """
    # The evaluation type can be either 'framework.task' or 'task'
    eval_type_components = run_config["config"]["type"].split(".")
    if len(eval_type_components) == 2:
        # framework.task invocation
        framework_name, evaluation_name = eval_type_components
    elif len(eval_type_components) == 1:
        # task invocation
        framework_name, evaluation_name = None, eval_type_components[0]
    else:
        raise MisconfigurationError(
            "eval_type must follow 'framework_name.evaluation_name'. No additional dots are allowed."
        )

    framework_evals_mapping, _, all_evals_mapping = get_available_evaluations()

    # framework.task invocation
    if framework_name:
        try:
            framework_evals_mapping[framework_name]
        except KeyError:
            raise MisconfigurationError(
                f"Unknown framework {framework_name}. Frameworks available: {', '.join(framework_evals_mapping.keys())}"
            )
    else:
        try:
            all_evals_mapping[evaluation_name]
        except KeyError:
            raise MisconfigurationError(
                f"Unknown evaluation {evaluation_name}. Evaluations available: {', '.join(all_evals_mapping.keys())}"
            )


def check_required_default_missing(run_config: dict):
    if run_config["config"].get("type") is None:
        raise MisconfigurationError(
            "Missing required argument: config.type (cli: --eval_type)"
        )
    if run_config["config"].get("output_dir") is None:
        raise MisconfigurationError(
            "Missing required argument: config.output_dir (cli: --output_dir)"
        )


def check_adapter_config(run_config):
    adapter_config: AdapterConfig | None = AdapterConfig.get_validated_config(
        run_config
    )
    if adapter_config:
        if run_config["target"].get("api_endpoint") is None:
            raise MisconfigurationError(
                "You need to define target.api_endpoint in order to use an adapter (cli: --model_id, --model_url, --model_type)"
            )
        if run_config["target"]["api_endpoint"].get("url") is None:
            raise MisconfigurationError(
                "You need to define target.api_endpoint.url in order to use an adapter (cli: --model_url)"
            )


def get_evaluation(
    evaluation_config: EvaluationConfig, target_config: EvaluationTarget
) -> Evaluation:  # type: ignore
    """Infer harness information from the evaluation config and wrap it into an Evaluation.

    Args:
        evaluation_config (EvaluationConfig): the user-provided evaluation configuration
        target_config (EvaluationTarget): the user-provided target configuration

    Returns:
        Evaluation: the framework/task defaults merged with the user configuration
    """
    eval_type_components = evaluation_config.type.split(".")
    if len(eval_type_components) == 2:
        # framework.task invocation
        framework_name, evaluation_name = eval_type_components
    elif len(eval_type_components) == 1:
        # task invocation
        framework_name, evaluation_name = None, eval_type_components[0]
    else:
        raise MisconfigurationError(
            "eval_type must follow 'framework_name.evaluation_name'. No additional dots are allowed."
        )

    all_framework_eval_mappings, all_framework_defaults, all_eval_name_mapping = (
        get_available_evaluations()
    )

    # First, get the default Evaluation
    # "framework.task" invocation
    if framework_name:
        try:
            default_evaluation = all_framework_eval_mappings[framework_name][
                evaluation_name
            ]
        except KeyError:
            default_evaluation = Evaluation(**all_framework_defaults[framework_name])
            evaluation_config.type = evaluation_name
            default_evaluation.config.params.task = evaluation_name
    else:
        if isinstance(all_eval_name_mapping[evaluation_name], list):
            framework_handlers = [
                evaluation.framework_name
                for evaluation in all_eval_name_mapping[evaluation_name]
            ]
            raise MisconfigurationError(
                f"{evaluation_name} is available in multiple frameworks: "
                f"{','.join(framework_handlers)}. Please indicate which implementation "
                "you would like to choose by using the 'framework.task' invocation. "
                f"For example: {framework_handlers[0]}.{evaluation_name}."
            )
        default_evaluation = all_eval_name_mapping[evaluation_name]

    default_configuration = default_evaluation.model_dump(exclude_none=True)
    user_configuration = {
        "config": evaluation_config.model_dump(),
        "target": target_config.model_dump(),
    }
    merged_configuration = deep_update(
        default_configuration, user_configuration, skip_nones=True
    )
    return Evaluation(**merged_configuration)


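# Invocation and precedence summary for get_evaluation (sketch; the task and
# framework names are placeholders, not names of shipped packages):
#
#   config.type = "my_framework.my_task"  # 'framework.task': use that framework's task,
#                                         # or fall back to the framework defaults when
#                                         # the task is not pre-defined there
#   config.type = "my_task"               # 'task': resolved across all installed
#                                         # frameworks; ambiguous names must be
#                                         # qualified as 'framework.task'
#
# The returned Evaluation is the resolved defaults deep-updated with the
# user-provided EvaluationConfig and EvaluationTarget (user values win,
# None values are skipped).

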
def check_type_compatibility(evaluation: Evaluation):
    if (
        evaluation.config.supported_endpoint_types is not None
        and evaluation.target.api_endpoint.type
        not in evaluation.config.supported_endpoint_types
    ):
        if evaluation.target.api_endpoint.type is None:
            raise MisconfigurationError(
                "target.api_endpoint.type should be defined and match one of the endpoint "
                f"types supported by the benchmark: '{evaluation.config.supported_endpoint_types}'",
            )
        if (
            evaluation.target.api_endpoint.type
            not in evaluation.config.supported_endpoint_types
        ):
            raise MisconfigurationError(
                f"The benchmark '{evaluation.config.type}' does not support the model type '{evaluation.target.api_endpoint.type}'. "
                f"The benchmark supports '{evaluation.config.supported_endpoint_types}'."
            )

    if evaluation.target.api_endpoint.type:
        # Check this only if the model is really required (to accommodate non-model evals)
        if evaluation.target.api_endpoint.url is None:
            raise MisconfigurationError(
                "target.api_endpoint.url (CLI: --model_url) should be defined to run model evaluation!"
            )
        if evaluation.target.api_endpoint.model_id is None:
            raise MisconfigurationError(
                "target.api_endpoint.model_id (CLI: --model_id) should be defined to run model evaluation!"
            )


def prepare_output_directory(evaluation: Evaluation):
    try:
        os.makedirs(evaluation.config.output_dir, exist_ok=True)
    except OSError as error:
        print(f"An error occurred while creating output directory: {error}")

    with open(os.path.join(evaluation.config.output_dir, "run_config.yml"), "w") as f:
        yaml.dump(evaluation.model_dump(), f)


def validate_configuration(run_config: dict) -> Evaluation:
    """Validate the requested task through a dataclass and create the Evaluation.

    Task creation follows this logic:
    - the evaluation type can be either 'framework.task' or 'task'
    - FDF stands for Framework Definition File

    Args:
        run_config (dict): run configuration merged from the config file and CLI overrides

    Raises:
        MisconfigurationError: if the run configuration is incomplete or inconsistent
    """
    check_required_default_missing(run_config)
    check_task_invocation(run_config)
    check_adapter_config(run_config)
    evaluation = get_evaluation(
        EvaluationConfig(**run_config["config"]),
        EvaluationTarget(**run_config["target"]),
    )
    check_type_compatibility(evaluation)
    logger.info(f"User-invoked config: \n{yaml.dump(evaluation.model_dump())}")
    return evaluation
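

# --- End-to-end usage sketch (not part of the original module). -------------
# Ties the helpers above together: a YAML run config plus optional dotlist
# overrides are merged and validated into an Evaluation. The file path and
# override string are made-up examples; the YAML is assumed to define the
# "config" and "target" sections, and a matching core_evals framework package
# must be installed for validation to succeed.
def _example_validate_configuration() -> Evaluation:
    run_config = load_run_config("run_config.yml")
    overrides = parse_override_params("config.params.limit_samples=10")
    run_config = deep_update(run_config, overrides, skip_nones=True)
    evaluation = validate_configuration(run_config)
    prepare_output_directory(evaluation)
    return evaluation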