Source code for nemo_evaluator.adapters.interceptors.caching_interceptor

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Caching interceptor with registry support."""

import hashlib
import json
import re
import threading
from typing import Any, final

import requests
import requests.structures
from pydantic import Field

from nemo_evaluator.adapters.caching.diskcaching import Cache
from nemo_evaluator.adapters.decorators import register_for_adapter
from nemo_evaluator.adapters.types import (
    AdapterGlobalContext,
    AdapterRequest,
    AdapterResponse,
    RequestToResponseInterceptor,
    ResponseInterceptor,
)
from nemo_evaluator.logging import BaseLoggingParams, get_logger


[docs] @register_for_adapter( name="caching", description="Caches requests and responses with disk storage", ) @final class CachingInterceptor(RequestToResponseInterceptor, ResponseInterceptor): """Caching interceptor is special in the sense that it intercepts both requests and responses."""
[docs] class Params(BaseLoggingParams): """Configuration parameters for caching.""" cache_dir: str = Field( default="/tmp", description="Directory to store cache files" ) reuse_cached_responses: bool = Field( default=False, description="Whether to reuse cached responses. If True, this overrides save_responses (sets it to True) and max_saved_responses (sets it to None)", ) save_requests: bool = Field( default=False, description="Whether to save requests to cache" ) save_responses: bool = Field( default=True, description="Whether to save responses to cache. Note: This is automatically set to True if reuse_cached_responses is True", ) max_saved_requests: int | None = Field( default=None, description="Maximum number of requests to save" ) max_saved_responses: int | None = Field( default=None, description="Maximum number of responses to cache. Note: This is automatically set to None if reuse_cached_responses is True", )
responses_cache: Cache requests_cache: Cache headers_cache: Cache def __init__(self, params: Params): """ Initialize the caching interceptor. Args: params: Configuration parameters """ # Initialize caches immediately self.responses_cache = Cache(directory=f"{params.cache_dir}/responses") self.requests_cache = Cache(directory=f"{params.cache_dir}/requests") self.headers_cache = Cache(directory=f"{params.cache_dir}/headers") self.reuse_cached_responses = params.reuse_cached_responses self.save_requests = params.save_requests # If reuse_cached_responses is True, override save_responses and max_saved_responses if params.reuse_cached_responses: self.save_responses = True self.max_saved_responses = None else: self.save_responses = params.save_responses self.max_saved_responses = params.max_saved_responses self.max_saved_requests = params.max_saved_requests # Counters for cache management self._cached_requests_count = 0 self._cached_responses_count = 0 # Thread safety self._count_lock = threading.Lock() # Get logger for this interceptor with interceptor context self.logger = get_logger(self.__class__.__name__) self.logger.info( "Caching interceptor initialized", cache_dir=params.cache_dir, reuse_cached_responses=self.reuse_cached_responses, save_requests=self.save_requests, save_responses=self.save_responses, max_saved_requests=self.max_saved_requests, max_saved_responses=self.max_saved_responses, ) _MEDIA_DATA_RE = re.compile(r"data:(image|audio|video)/([^;]+);base64,(.+)") @staticmethod def _sanitize_media_url(url: str) -> str: """Replace a base64 data URL with a brief human-readable description.""" match = CachingInterceptor._MEDIA_DATA_RE.fullmatch(url) if match: media_type, fmt, b64 = match.group(1), match.group(2), match.group(3) size_bytes = len(b64) * 3 // 4 # Approximate decoded size return f"<{media_type}: format={fmt}, size≈{size_bytes} bytes>" return url
[docs] @staticmethod def sanitize_request_data_for_logging(data: Any) -> Any: """ Sanitize request data for logging by replacing media content with brief descriptions. Args: data: Request data to sanitize (can be dict, list, or any other type) Returns: Sanitized version of the data with media replaced by brief descriptions """ if isinstance(data, dict): sanitized = {} for key, value in data.items(): if key == "url" and isinstance(value, str): sanitized[key] = CachingInterceptor._sanitize_media_url(value) else: sanitized[key] = ( CachingInterceptor.sanitize_request_data_for_logging(value) ) return sanitized elif isinstance(data, list): return [ CachingInterceptor.sanitize_request_data_for_logging(item) for item in data ] else: return data
@staticmethod def _generate_cache_key(data: Any) -> str: """ Generate a hash for the request data to be used as the cache key. Args: data: Data to be hashed Returns: str: Hash of the data """ data_str = json.dumps(data, sort_keys=True) return hashlib.sha256(data_str.encode("utf-8")).hexdigest() def _get_from_cache(self, cache_key: str) -> tuple[Any, Any] | None: """ Attempt to retrieve content and headers from cache. Args: cache_key (str): Cache key to lookup Returns: Optional[tuple[Any, Any]]: Tuple of (content, headers) if found, None if not """ try: cached_content = self.responses_cache[cache_key] cached_headers = self.headers_cache[cache_key] self.logger.debug("Cache hit", cache_key=cache_key[:8] + "...") return cached_content, cached_headers except KeyError: self.logger.debug("Cache miss", cache_key=cache_key[:8] + "...") return None def _save_to_cache(self, cache_key: str, content: Any, headers: Any) -> None: """ Save content and headers to cache. Args: cache_key (str): Cache key to store under content: Content to cache headers: Headers to cache """ # Check if we've reached the max responses limit if self.max_saved_responses is not None: with self._count_lock: if self._cached_responses_count >= self.max_saved_responses: self.logger.warning( "Maximum cached responses limit reached", max_saved_responses=self.max_saved_responses, ) return self._cached_responses_count += 1 # Save content to cache self.responses_cache[cache_key] = content # NOTE: headers are `CaseInsensitiveDict()` which is not serializable # by `Cache` class. If this is the case, transform to a normal dict. if isinstance(headers, requests.structures.CaseInsensitiveDict): cached_headers = dict(headers) else: cached_headers = headers self.headers_cache[cache_key] = cached_headers self.logger.debug( "Saved response to cache", cache_key=cache_key[:8] + "...", content_size=len(content) if hasattr(content, "__len__") else "unknown", )
[docs] @final def intercept_request( self, req: AdapterRequest, context: AdapterGlobalContext ) -> AdapterRequest | AdapterResponse: """Shall return request if no cache hit, and response if it is. Args: req (AdapterRequest): The adapter request to intercept context (AdapterGlobalContext): Global context containing server-level configuration """ request_data = req.r.get_json() # Check cache. Create cache key that will be used everywhere (also if no cache hit) req.rctx.cache_key = self._generate_cache_key(request_data) # Sanitize request data for logging (replace images with brief descriptions) sanitized_request_data = self.sanitize_request_data_for_logging(request_data) self.logger.debug("Intercepted request", request_data=sanitized_request_data) self.logger.debug( "Processing request for caching", cache_key=req.rctx.cache_key[:8] + "...", request_data_keys=( list(request_data.keys()) if isinstance(request_data, dict) else "unknown" ), ) # Cache request if needed and within limit if self.save_requests: with self._count_lock: if ( self.max_saved_requests is None or self._cached_requests_count < self.max_saved_requests ): self.requests_cache[req.rctx.cache_key] = request_data self._cached_requests_count += 1 self.logger.debug( "Saved request to cache", cache_key=req.rctx.cache_key[:8] + "...", ) else: self.logger.warning( "Maximum cached requests limit reached", max_saved_requests=self.max_saved_requests, ) # Only check cache if response reuse is enabled if self.reuse_cached_responses: cached_result = self._get_from_cache(req.rctx.cache_key) if cached_result: content, headers = cached_result requests_response = requests.Response() requests_response._content = content requests_response.status_code = 200 requests_response.reason = "OK" requests_response.headers = requests.utils.CaseInsensitiveDict(headers) requests_response.request = request_data # Make downstream know req.rctx.cache_hit = True self.logger.info( "Returning cached response", cache_key=req.rctx.cache_key[:8] + "...", status_code=200, ) return AdapterResponse(r=requests_response, rctx=req.rctx) self.logger.debug( "No cache hit, proceeding with request", cache_key=req.rctx.cache_key[:8] + "...", ) return req
[docs] @final def intercept_response( self, resp: AdapterResponse, context: AdapterGlobalContext ) -> AdapterResponse: """Cache the response if caching is enabled and response is successful.""" # first, if caching was used, we do nothing if resp.rctx.cache_hit: self.logger.debug( "Response was from cache, skipping caching", cache_key=( resp.rctx.cache_key[:8] + "..." if hasattr(resp.rctx, "cache_key") else "unknown" ), ) return resp if resp.r.status_code == 200 and self.save_responses: # Save both content and headers to cache try: assert resp.rctx.cache_key, "cache key is unset, this is a bug" self._save_to_cache( cache_key=resp.rctx.cache_key, content=resp.r.content, headers=resp.r.headers, ) self.logger.info( "Cached successful response", cache_key=resp.rctx.cache_key[:8] + "...", ) except Exception as e: self.logger.error( "Could not cache response", error=str(e), cache_key=( resp.rctx.cache_key[:8] + "..." if hasattr(resp.rctx, "cache_key") else "unknown" ), ) else: self.logger.debug( "Response not cached", status_code=resp.r.status_code, save_responses=self.save_responses, ) # And just propagate return resp