# Source code for nemo_evaluator.adapters.interceptors.response_stats_interceptor

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Response stats interceptor that collects aggregated statistics from API responses."""

import copy
import datetime
import json
import threading
import time
from pathlib import Path
from typing import Any, Optional, final

from pydantic import Field

from nemo_evaluator.adapters.caching.diskcaching import Cache
from nemo_evaluator.adapters.decorators import register_for_adapter
from nemo_evaluator.adapters.types import (
    AdapterGlobalContext,
    AdapterResponse,
    PostEvalHook,
    ResponseInterceptor,
)
from nemo_evaluator.logging import BaseLoggingParams, get_logger


@register_for_adapter(
    name="response_stats",
    description="Collects aggregated statistics from API responses for metrics collection",
)
@final
class ResponseStatsInterceptor(ResponseInterceptor, PostEvalHook):
    """Collects aggregated statistics from API responses for metrics collection.

    Tracks the following statistics:
    - Token usage (prompt, completion, total) with averages and maximums
    - Response status codes and counts
    - Finish reasons and stop reasons
    - Tool calls and function calls counts
    - Response latency (average and maximum)
    - Total response count
    - Number of runs, inference times (approximated by processing time from
      the first to the last response)

    All mutations of the aggregated ``_stats`` dictionary happen while holding
    ``_lock``. The lock is a plain (non-reentrant) ``threading.Lock``, so
    helpers that are documented as "caller holds the lock" must never try to
    acquire it themselves.
    """

    class Params(BaseLoggingParams):
        """Configuration parameters for response stats collection."""

        collect_token_stats: bool = Field(
            default=True, description="Whether to collect token statistics"
        )
        collect_finish_reasons: bool = Field(
            default=True, description="Whether to collect finish reasons"
        )
        collect_tool_calls: bool = Field(
            default=True, description="Whether to collect tool call statistics"
        )
        stats_file_saving_interval: Optional[int] = Field(
            default=100,
            description="How often (every how many responses) to save stats to a file. If None, stats are only saved via post_eval_hook.",
        )
        save_individuals: bool = Field(
            default=True,
            description="Whether to save individual request statistics. If True, saves all individuals; if False, saves only aggregated stats.",
        )
        cache_dir: str = Field(
            default="/tmp/response_stats_interceptor",
            description="Custom cache directory for response stats interceptor.",
        )
        logging_aggregated_stats_interval: int = Field(
            default=100,
            description="How often (every how many responses) to log aggregated response statistics. Default is 100.",
        )

    # Keys of a per-run entry in ``_stats["inference_run_times"]`` that hold
    # epoch timestamps (floats in memory/cache, ISO strings in the metrics file).
    _RUN_TIMESTAMP_KEYS = ("run_start", "first_request_time", "last_request_time")

    def __init__(self, params: Params):
        """Initialize the response stats interceptor.

        Sets up the default statistics, opens the on-disk cache under
        ``params.cache_dir``, restores previously cached aggregated stats (a
        failure to do so is logged and otherwise ignored) and registers a new
        run id in the cached interceptor state.

        Args:
            params: Configuration parameters
        """
        self.collect_token_stats = params.collect_token_stats
        self.collect_finish_reasons = params.collect_finish_reasons
        self.collect_tool_calls = params.collect_tool_calls
        self.stats_file_saving_interval = params.stats_file_saving_interval
        self.save_individuals = params.save_individuals
        self.cache_dir = params.cache_dir
        self.logging_aggregated_stats_interval = (
            params.logging_aggregated_stats_interval
        )

        # Get logger for this interceptor with interceptor context
        self.logger = get_logger(self.__class__.__name__)

        # Initialize lock and stats first; everything below relies on them
        self._lock = threading.Lock()
        self._adapter_start_time = time.time()  # Record adapter initialization time
        self._stats = {
            # Average statistics
            "avg_prompt_tokens": None,
            "avg_total_tokens": None,
            "avg_completion_tokens": None,
            "avg_latency_ms": None,
            # Maximum statistics
            "max_prompt_tokens": None,
            "max_total_tokens": None,
            "max_completion_tokens": None,
            "max_latency_ms": None,
            # Counters and totals
            "count": 0,
            "successful_count": 0,
            "tool_calls_count": 0,
            "function_calls_count": 0,
            "finish_reason": {},
            "stop_reason": {},
            "status_codes": {},
            # Time tracking
            "inference_time": 0.0,
            "run_id": 0,
            "last_request_time": None,
            # {run_id: {"run_start": ts, "first_request_time": ts,
            #           "last_request_time": ts, "inference_time": seconds}}
            "inference_run_times": {},
        }

        # Always initialize cache database
        cache_path = Path(self.cache_dir)
        cache_path.mkdir(parents=True, exist_ok=True)
        self._request_stats_cache = Cache(cache_path)

        # Load existing aggregated stats if available (best effort)
        try:
            self._load_aggregated_cached_stats()
        except Exception as e:
            self.logger.warning(f"Failed to load cached stats: {e}")

        # Save run info immediately on initialization (also assigns the run_id)
        self._save_run_ids_info()

        self.logger.info(
            "Response stats interceptor initialized",
            collect_token_stats=self.collect_token_stats,
            collect_finish_reasons=self.collect_finish_reasons,
            collect_tool_calls=self.collect_tool_calls,
            stats_file_saving_interval=self.stats_file_saving_interval,
            save_individuals=self.save_individuals,
            cache_dir=self.cache_dir,
            logging_aggregated_stats_interval=self.logging_aggregated_stats_interval,
        )

    def _load_aggregated_cached_stats(self) -> None:
        """Restore aggregated stats from the cached interceptor state.

        ISO-formatted timestamps inside ``inference_run_times`` are converted
        back to epoch floats, and string keys of ``inference_run_times`` and
        ``status_codes`` are converted back to integers where possible. The
        cached values are merged over the current defaults, so a cache entry
        that lacks some keys cannot cause a ``KeyError`` later on.
        """
        interceptor_state = self._load_interceptor_state()

        if "aggregated_stats" not in interceptor_state:
            self.logger.info("No cached interceptor state found")
            return

        aggregated_stats = interceptor_state["aggregated_stats"]

        # Convert ISO timestamps back to floats for inference_run_times
        # (if they are strings) and string run ids back to integers.
        if "inference_run_times" in aggregated_stats:
            converted_run_times = {}
            for run_id, run_data in aggregated_stats["inference_run_times"].items():
                int_run_id = int(run_id) if isinstance(run_id, str) else run_id
                for key in self._RUN_TIMESTAMP_KEYS:
                    value = run_data.get(key)
                    if value and isinstance(value, str):
                        run_data[key] = datetime.datetime.fromisoformat(
                            value
                        ).timestamp()
                converted_run_times[int_run_id] = run_data
            aggregated_stats["inference_run_times"] = converted_run_times

        # Convert string keys back to integers for status codes
        if "status_codes" in aggregated_stats and isinstance(
            aggregated_stats["status_codes"], dict
        ):
            status_codes = {}
            for key, value in aggregated_stats["status_codes"].items():
                try:
                    status_codes[int(key)] = value
                except ValueError:
                    # Keep keys that are not numeric as they are
                    status_codes[key] = value
            aggregated_stats["status_codes"] = status_codes

        # Cached stats already contain accumulated data; layering them over
        # the defaults keeps every expected key present.
        # Note: run_id increment is handled in _save_run_ids_info()
        self._stats = {**self._stats, **aggregated_stats}
        self.logger.info(
            f"Loaded interceptor state with run_id {aggregated_stats.get('run_id', 0)}, count={aggregated_stats.get('count', 0)}"
        )

    def _update_basic_stats(self, resp: AdapterResponse, current_time: float) -> None:
        """Update per-run and global time tracking (thread-safe).

        Args:
            resp: Response being processed; its optional ``latency_ms`` is
                used to estimate when the first request of a run was sent.
            current_time: Epoch timestamp at which the response was observed.
        """
        with self._lock:
            self._stats["last_request_time"] = current_time

            run_id = self._stats["run_id"]
            run_times = self._stats["inference_run_times"]

            if run_id not in run_times:
                # First request in this run - estimate when inference actually
                # started by subtracting the latency from the arrival time.
                first_request_start = current_time
                latency_ms = getattr(resp, "latency_ms", None)
                if latency_ms is not None:
                    first_request_start = current_time - (latency_ms / 1000.0)

                run_times[run_id] = {
                    "run_start": self._adapter_start_time,
                    "first_request_time": first_request_start,
                    "last_request_time": current_time,
                    "inference_time": 0.0,
                }
                return

            # Extend the run's time span and add only the delta to the global
            # inference_time so that earlier runs remain accounted for.
            run_data = run_times[run_id]
            previous_inference_time = run_data["inference_time"]
            run_data["last_request_time"] = current_time
            run_data["inference_time"] = current_time - run_data["first_request_time"]
            self._stats["inference_time"] += (
                run_data["inference_time"] - previous_inference_time
            )

    def _update_running_stats(
        self, stat_name: str, value: float, sample_count: Optional[int] = None
    ) -> None:
        """Update running average and max for a given statistic.

        The caller must hold ``self._lock``.

        Args:
            stat_name: Statistic suffix; ``avg_<stat_name>`` and
                ``max_<stat_name>`` in ``_stats`` are updated.
            value: New sample. Non-numeric values are logged and skipped.
            sample_count: Number of samples already folded into the average.
                Defaults to ``_stats["successful_count"]``, which is correct
                for statistics sampled once per successful response; callers
                tracking a statistic with a different cadence must pass the
                matching count.
        """
        # Skip if value is not a valid number
        if not isinstance(value, (int, float)):
            self.logger.warning(
                f"Invalid value for {stat_name}: {value} (expected number)"
            )
            return

        if sample_count is None:
            sample_count = self._stats["successful_count"]

        avg_key = f"avg_{stat_name}"
        current_avg = self._stats[avg_key]
        if current_avg is None:
            self._stats[avg_key] = value
        else:
            self._stats[avg_key] = round(
                (current_avg * sample_count + value) / (sample_count + 1),
                2,
            )

        # Update max value
        max_key = f"max_{stat_name}"
        if self._stats[max_key] is None or value > self._stats[max_key]:
            self._stats[max_key] = value

    def _update_time_tracking(self, current_time: float) -> None:
        """Record the time of the most recent request (thread-safe)."""
        with self._lock:
            self._stats["last_request_time"] = current_time

    def _update_response_stats(self, individual_stats: dict[str, Any]) -> None:
        """Fold one successful response into the aggregated stats (thread-safe).

        Each category is aggregated only when the corresponding ``collect_*``
        parameter is enabled.

        Args:
            individual_stats: Output of ``_extract_detailed_response_stats``.
        """
        with self._lock:
            # Update token running means BEFORE incrementing successful_count,
            # because the running average is weighted by the previous count.
            if self.collect_token_stats:
                for token_type in ("prompt_tokens", "total_tokens", "completion_tokens"):
                    self._update_running_stats(
                        token_type, individual_stats.get(token_type, 0)
                    )

            self._stats["successful_count"] += 1

            if self.collect_finish_reasons:
                finish_reason = individual_stats.get("finish_reason")
                if isinstance(finish_reason, str):
                    reasons = self._stats["finish_reason"]
                    reasons[finish_reason] = reasons.get(finish_reason, 0) + 1

            if self.collect_tool_calls:
                tool_calls_count = individual_stats.get("tool_calls_count", 0)
                if tool_calls_count > 0:
                    self._stats["tool_calls_count"] += tool_calls_count

                function_calls_count = individual_stats.get("function_calls_count", 0)
                if function_calls_count > 0:
                    self._stats["function_calls_count"] += function_calls_count

            # Log aggregated stats at the specified interval; a zero interval
            # disables periodic logging instead of dividing by zero.
            interval = self.logging_aggregated_stats_interval
            if interval and self._stats["successful_count"] % interval == 0:
                self.logger.info("Aggregated response stats", **self._stats)

    def _add_basic_response_stats(
        self, adapter_response: AdapterResponse, context: AdapterGlobalContext
    ) -> None:
        """Add basic statistics for any response (JSON or non-JSON).

        Increments the total response count, the per-status-code counter and,
        when the response carries a ``latency_ms``, the latency average/max.
        """
        with self._lock:
            self._stats["count"] += 1

            # Track the specific status code
            status_code = adapter_response.r.status_code
            codes = self._stats["status_codes"]
            codes[status_code] = codes.get(status_code, 0) + 1

            # Latency is sampled for every response (not only successful
            # ones), so its average is weighted by the number of responses
            # seen before this one rather than by successful_count.
            # NOTE(review): this assumes every counted response carries a
            # latency; responses without one slightly dilute the average.
            latency_ms = getattr(adapter_response, "latency_ms", None)
            if latency_ms is not None:
                self._update_running_stats(
                    "latency_ms",
                    latency_ms,
                    sample_count=self._stats["count"] - 1,
                )

    def _extract_detailed_response_stats(self, response_data: dict) -> dict:
        """Extract detailed response statistics from response data.

        Reads token counts from ``usage`` and, from the first dict entry of
        ``choices``, the finish reason and tool/function call counts. Each
        category is extracted only when the corresponding ``collect_*``
        parameter is enabled. Any error is logged and the stats gathered so
        far are returned.

        Args:
            response_data: Parsed JSON body of the response.

        Returns:
            Dictionary with whichever of ``prompt_tokens``, ``total_tokens``,
            ``completion_tokens``, ``finish_reason``, ``tool_calls_count`` and
            ``function_calls_count`` could be extracted.
        """
        detailed_stats = {}

        try:
            # Extract usage information
            usage = response_data.get("usage", {})
            if self.collect_token_stats and isinstance(usage, dict):
                detailed_stats["prompt_tokens"] = usage.get("prompt_tokens", 0)
                detailed_stats["total_tokens"] = usage.get("total_tokens", 0)
                detailed_stats["completion_tokens"] = usage.get("completion_tokens", 0)

            # Extract choices information
            choices = response_data.get("choices", [])
            if isinstance(choices, list):
                for choice in choices:
                    if not isinstance(choice, dict):
                        continue

                    if self.collect_finish_reasons:
                        finish_reason = choice.get("finish_reason")
                        if isinstance(finish_reason, str):
                            detailed_stats["finish_reason"] = finish_reason

                    message = choice.get("message", {})
                    if self.collect_tool_calls and isinstance(message, dict):
                        tool_calls = message.get("tool_calls", [])
                        if isinstance(tool_calls, list):
                            detailed_stats["tool_calls_count"] = len(tool_calls)

                        function_call = message.get("function_call")
                        detailed_stats["function_calls_count"] = (
                            1 if function_call else 0
                        )

                    break  # Only process first choice for individual stats

        except Exception as e:
            self.logger.warning(
                "Failed to extract detailed response stats",
                error=str(e),
            )

        return detailed_stats

    def _cache_request_stats(self, request_id: str, stats: dict[str, Any]) -> None:
        """Cache individual request stats by request ID.

        Does nothing unless ``save_individuals`` is enabled. The stats are
        stored as a JSON string that also carries the ``request_id``.
        """
        if not self.save_individuals or self._request_stats_cache is None:
            return

        stats_with_id = {**stats, "request_id": request_id}
        stats_json = json.dumps(stats_with_id, ensure_ascii=False)
        self._request_stats_cache.set(request_id, stats_json)

    def intercept_response(
        self, resp: AdapterResponse, context: AdapterGlobalContext
    ) -> AdapterResponse:
        """Collect aggregated statistics from the response.

        Responses served from cache are skipped entirely. For all other
        responses the basic stats are always updated; detailed stats are only
        extracted for HTTP 200 responses with a JSON body.

        Args:
            resp: The adapter response to inspect.
            context: Global adapter context (provides ``metrics_path``).

        Returns:
            The unmodified ``resp``.
        """
        if resp.rctx.cache_hit:
            self.logger.debug(
                "Response was from cache, skipping response stats collection"
            )
            return resp

        # Get status code once and reuse it
        status_code = resp.r.status_code

        # Update time tracking with current timestamp
        current_time = time.time()
        self._update_time_tracking(current_time)
        self._update_basic_stats(resp, current_time)

        # Always add basic response stats (count, status_code, latency)
        self._add_basic_response_stats(resp, context)

        # Extract detailed stats once and reuse them
        detailed_stats = None
        try:
            response_data = resp.r.json()

            if status_code == 200:
                detailed_stats = self._extract_detailed_response_stats(response_data)
                self._update_response_stats(detailed_stats)

                self.logger.debug(
                    "Collected detailed response stats",
                    request_id=resp.rctx.request_id,
                    response_count=self._stats["count"],
                    status_code=status_code,
                )
        except Exception as e:
            # Covers JSON parsing errors as well as anything raised while
            # aggregating; in that case only basic stats are collected.
            self.logger.warning(f"Error parsing response body for token counting: {e}")

        # Save stats to file if interval reached (None or 0 disables this)
        saving_interval = self.stats_file_saving_interval
        if saving_interval and self._stats["count"] % saving_interval == 0:
            self._save_stats_to_file(context)

        # Cache individual request stats if enabled
        if self.save_individuals:
            individual_stats = {
                "timestamp": current_time,
                "status_code": status_code,
                "count": 1,  # This is just one response
                "run_id": self._stats["run_id"],
            }
            # Add detailed stats if available (reuse the extracted stats)
            if detailed_stats:
                individual_stats.update(detailed_stats)

            self._cache_request_stats(resp.rctx.request_id, individual_stats)

        # Save aggregated stats to cache
        self._save_aggregated_stats_to_cache()

        return resp

    def _save_stats_to_file(self, context: AdapterGlobalContext) -> None:
        """Save current stats to the same file as post-eval hook.

        A snapshot of the stats is written under the ``response_stats`` key of
        ``context.metrics_path``, preserving any other top-level keys already
        present in that file. Timestamps in ``inference_run_times`` are written
        as ISO strings and each run gets a ``time_to_first_request_seconds``
        entry. Nothing is written while no response has been counted.
        """
        # Get stats in a thread-safe manner
        with self._lock:
            stats = copy.deepcopy(self._stats)

        if stats["count"] == 0:
            self.logger.debug("No response statistics collected, skipping file write")
            return

        for run_data in stats.get("inference_run_times", {}).values():
            run_start = run_data.get("run_start")
            first_request = run_data.get("first_request_time")

            # Compute the delay on the raw epoch floats, before they are
            # replaced by their ISO representation below.
            if (
                run_start
                and first_request
                and isinstance(run_start, (int, float))
                and isinstance(first_request, (int, float))
            ):
                time_to_first_request = round(first_request - run_start, 3)
            else:
                time_to_first_request = None

            # Convert timestamps to readable dates
            for key in self._RUN_TIMESTAMP_KEYS:
                value = run_data.get(key)
                if value and isinstance(value, (int, float)):
                    run_data[key] = datetime.datetime.fromtimestamp(value).isoformat()

            if time_to_first_request is not None:
                run_data["time_to_first_request_seconds"] = time_to_first_request

        # Prepare metrics data under adapter name
        metrics_data = {
            "response_stats": {
                "description": "Response statistics saved during processing",
                **stats,
            }
        }

        # The lock also serializes concurrent read-modify-write cycles on the
        # metrics file issued by this interceptor.
        with self._lock:
            metrics_path = context.metrics_path
            metrics_path.parent.mkdir(parents=True, exist_ok=True)

            # Read existing metrics if file exists
            existing_metrics = {}
            if metrics_path.exists():
                try:
                    with open(metrics_path, "r", encoding="utf-8") as f:
                        loaded_metrics = json.load(f)
                except (ValueError, OSError) as e:
                    # Start fresh if the file is corrupted or unreadable
                    self.logger.warning(
                        f"Could not read existing metrics file, starting fresh: {e}"
                    )
                else:
                    # Only a JSON object can be merged with our metrics
                    if isinstance(loaded_metrics, dict):
                        existing_metrics = loaded_metrics

            # Merge with existing metrics
            merged_metrics = {**existing_metrics, **metrics_data}

            # Write merged metrics to file
            with open(metrics_path, "w", encoding="utf-8") as f:
                json.dump(merged_metrics, f, indent=2, ensure_ascii=False)

        self.logger.debug(
            "Saved response stats to file",
            path=str(context.metrics_path),
            response_count=stats["count"],
        )

    def _save_run_ids_info(self) -> None:
        """Register the current run in the cached interceptor state.

        The new ``run_id`` is one greater than the highest run id recorded so
        far (``0`` for the first run). It is stored in ``_stats["run_id"]`` and
        appended, together with the adapter start time, to the ``run_ids``
        list of the interceptor state.
        """
        interceptor_state = self._load_interceptor_state()

        # Determine the next run_id based on existing run_ids
        previous_runs = interceptor_state.get("run_ids")
        if previous_runs:
            self._stats["run_id"] = max(run["run_id"] for run in previous_runs) + 1
        else:
            # First run, start with 0
            self._stats["run_id"] = 0

        run_info = {
            "run_id": self._stats["run_id"],
            "start": datetime.datetime.fromtimestamp(
                self._adapter_start_time
            ).isoformat(),
        }

        # Add current run info
        if "run_ids" not in interceptor_state:
            interceptor_state["run_ids"] = []
        interceptor_state["run_ids"].append(run_info)

        self._save_interceptor_state(interceptor_state)

        self.logger.debug(
            "Saved run info to interceptor state",
            run_id=self._stats["run_id"],
            start_time=run_info["start"],
        )

    def _save_aggregated_stats_to_cache(self) -> None:
        """Save a snapshot of the aggregated stats to the interceptor state.

        Timestamps are kept as floats for caching (not converted to ISO).
        """
        # Deep-copy under the lock so that nested dictionaries cannot be
        # mutated by another thread while the cache serializes them.
        with self._lock:
            stats_to_cache = copy.deepcopy(self._stats)

        interceptor_state = self._load_interceptor_state()
        interceptor_state["aggregated_stats"] = stats_to_cache
        self._save_interceptor_state(interceptor_state)

        self.logger.debug(
            "Saved aggregated stats to interceptor state",
            run_id=stats_to_cache["run_id"],
        )

    def _load_interceptor_state(self) -> dict:
        """Load interceptor state from cache.

        Returns:
            The cached state (decoded from JSON if it was stored as a string),
            or an empty dict when nothing is cached.
        """
        if self._request_stats_cache is None:
            return {}

        state = self._request_stats_cache.get("interceptor_state")
        if not state:
            return {}
        if isinstance(state, str):
            return json.loads(state)
        return state

    def _save_interceptor_state(self, state: dict) -> None:
        """Save interceptor state to cache."""
        if self._request_stats_cache is not None:
            self._request_stats_cache.set("interceptor_state", state)

    def post_eval_hook(self, context: AdapterGlobalContext) -> None:
        """Write collected response statistics to eval_factory_metrics.json."""
        self.logger.info(
            "Writing response statistics to metrics",
            total_responses=self._stats["count"],
            successful_responses=self._stats["successful_count"],
            output_dir=context.output_dir,
        )

        self._save_stats_to_file(context)