Source code for nemo_evaluator.adapters.interceptors.caching_interceptor

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Caching interceptor with registry support."""

import hashlib
import json
import threading
from typing import Any, final

import requests
import requests.structures
from pydantic import Field

from nemo_evaluator.adapters.caching.diskcaching import Cache
from nemo_evaluator.adapters.decorators import register_for_adapter
from nemo_evaluator.adapters.types import (
    AdapterGlobalContext,
    AdapterRequest,
    AdapterResponse,
    RequestToResponseInterceptor,
    ResponseInterceptor,
)
from nemo_evaluator.logging import BaseLoggingParams, get_logger


@register_for_adapter(
    name="caching",
    description="Caches requests and responses with disk storage",
)
@final
class CachingInterceptor(RequestToResponseInterceptor, ResponseInterceptor):
    """Caching interceptor is special in the sense that it intercepts both requests and responses.

    On the request side it derives a cache key from the JSON request body,
    optionally stores the request, and (when ``reuse_cached_responses`` is
    enabled) short-circuits the chain by returning a previously stored
    response. On the response side it stores successful (HTTP 200) responses
    that did not themselves come from the cache.
    """

    class Params(BaseLoggingParams):
        """Configuration parameters for caching."""

        cache_dir: str = Field(
            default="/tmp", description="Directory to store cache files"
        )
        reuse_cached_responses: bool = Field(
            default=False,
            description="Whether to reuse cached responses. If True, this overrides save_responses (sets it to True) and max_saved_responses (sets it to None)",
        )
        save_requests: bool = Field(
            default=False, description="Whether to save requests to cache"
        )
        save_responses: bool = Field(
            default=True,
            description="Whether to save responses to cache. Note: This is automatically set to True if reuse_cached_responses is True",
        )
        max_saved_requests: int | None = Field(
            default=None, description="Maximum number of requests to save"
        )
        max_saved_responses: int | None = Field(
            default=None,
            description="Maximum number of responses to cache. Note: This is automatically set to None if reuse_cached_responses is True",
        )

    responses_cache: Cache
    requests_cache: Cache
    headers_cache: Cache

    def __init__(self, params: Params):
        """
        Initialize the caching interceptor.

        Args:
            params: Configuration parameters
        """
        # Initialize caches immediately; each one lives in its own
        # sub-directory of `cache_dir`.
        self.responses_cache = Cache(directory=f"{params.cache_dir}/responses")
        self.requests_cache = Cache(directory=f"{params.cache_dir}/requests")
        self.headers_cache = Cache(directory=f"{params.cache_dir}/headers")

        self.reuse_cached_responses = params.reuse_cached_responses
        self.save_requests = params.save_requests

        # If reuse_cached_responses is True, override save_responses and
        # max_saved_responses: reuse requires every response to be stored.
        if params.reuse_cached_responses:
            self.save_responses = True
            self.max_saved_responses = None
        else:
            self.save_responses = params.save_responses
            self.max_saved_responses = params.max_saved_responses

        self.max_saved_requests = params.max_saved_requests

        # Counters for cache management (per interceptor instance; they start
        # at zero regardless of what the cache directories already contain).
        self._cached_requests_count = 0
        self._cached_responses_count = 0

        # Guards both counters above.
        self._count_lock = threading.Lock()

        # Get logger for this interceptor with interceptor context
        self.logger = get_logger(self.__class__.__name__)

        self.logger.info(
            "Caching interceptor initialized",
            cache_dir=params.cache_dir,
            reuse_cached_responses=self.reuse_cached_responses,
            save_requests=self.save_requests,
            save_responses=self.save_responses,
            max_saved_requests=self.max_saved_requests,
            max_saved_responses=self.max_saved_responses,
        )

    @staticmethod
    def _short_key(cache_key: Any) -> str:
        """
        Abbreviate a cache key for log output.

        Args:
            cache_key: The cache key; may be None or missing when it was never set

        Returns:
            str: First 8 characters followed by "...", or "unknown" if the key
                is not a string
        """
        if isinstance(cache_key, str):
            return cache_key[:8] + "..."
        return "unknown"

    @staticmethod
    def _generate_cache_key(data: Any) -> str:
        """
        Generate a hash for the request data to be used as the cache key.

        Args:
            data: Data to be hashed

        Returns:
            str: Hash of the data
        """
        # sort_keys makes the key independent of dict insertion order.
        data_str = json.dumps(data, sort_keys=True)
        return hashlib.sha256(data_str.encode("utf-8")).hexdigest()

    def _get_from_cache(self, cache_key: str) -> tuple[Any, Any] | None:
        """
        Attempt to retrieve content and headers from cache.

        Args:
            cache_key (str): Cache key to lookup

        Returns:
            Optional[tuple[Any, Any]]: Tuple of (content, headers) if found, None if not
        """
        try:
            # Both entries must be present; a missing one counts as a miss.
            cached_content = self.responses_cache[cache_key]
            cached_headers = self.headers_cache[cache_key]
        except KeyError:
            self.logger.debug("Cache miss", cache_key=self._short_key(cache_key))
            return None

        self.logger.debug("Cache hit", cache_key=self._short_key(cache_key))
        return cached_content, cached_headers

    def _save_to_cache(self, cache_key: str, content: Any, headers: Any) -> bool:
        """
        Save content and headers to cache.

        Args:
            cache_key (str): Cache key to store under
            content: Content to cache
            headers: Headers to cache

        Returns:
            bool: True if the entry was written, False if it was skipped
                because the ``max_saved_responses`` limit has been reached
        """
        # Check if we've reached the max responses limit. The slot is reserved
        # under the lock before the write itself happens.
        if self.max_saved_responses is not None:
            with self._count_lock:
                if self._cached_responses_count >= self.max_saved_responses:
                    self.logger.warning(
                        "Maximum cached responses limit reached",
                        max_saved_responses=self.max_saved_responses,
                    )
                    return False
                self._cached_responses_count += 1

        # Save content to cache
        self.responses_cache[cache_key] = content

        # NOTE: headers are `CaseInsensitiveDict()` which is not serializable
        # by `Cache` class. If this is the case, transform to a normal dict.
        if isinstance(headers, requests.structures.CaseInsensitiveDict):
            cached_headers = dict(headers)
        else:
            cached_headers = headers
        self.headers_cache[cache_key] = cached_headers

        self.logger.debug(
            "Saved response to cache",
            cache_key=self._short_key(cache_key),
            content_size=len(content) if hasattr(content, "__len__") else "unknown",
        )
        return True

    @final
    def intercept_request(
        self, req: AdapterRequest, context: AdapterGlobalContext
    ) -> AdapterRequest | AdapterResponse:
        """Shall return request if no cache hit, and response if it is.

        Args:
            req (AdapterRequest): The adapter request to intercept
            context (AdapterGlobalContext): Global context containing server-level configuration

        Returns:
            AdapterRequest | AdapterResponse: The unchanged request on a cache
                miss (or when reuse is disabled), otherwise a response rebuilt
                from the cached content and headers.
        """
        request_data = req.r.get_json()

        # Check cache. Create cache key that will be used everywhere (also if no cache hit)
        req.rctx.cache_key = self._generate_cache_key(request_data)
        short_key = self._short_key(req.rctx.cache_key)

        self.logger.debug("Request", request_data=request_data)

        self.logger.debug(
            "Processing request for caching",
            cache_key=short_key,
            request_data_keys=(
                list(request_data.keys())
                if isinstance(request_data, dict)
                else "unknown"
            ),
        )

        # Cache request if needed and within limit
        if self.save_requests:
            with self._count_lock:
                if (
                    self.max_saved_requests is None
                    or self._cached_requests_count < self.max_saved_requests
                ):
                    self.requests_cache[req.rctx.cache_key] = request_data
                    self._cached_requests_count += 1
                    self.logger.debug(
                        "Saved request to cache",
                        cache_key=short_key,
                    )
                else:
                    self.logger.warning(
                        "Maximum cached requests limit reached",
                        max_saved_requests=self.max_saved_requests,
                    )

        # Only check cache if response reuse is enabled
        if self.reuse_cached_responses:
            cached_result = self._get_from_cache(req.rctx.cache_key)
            if cached_result is not None:
                content, headers = cached_result

                # Rebuild a `requests.Response`; only 200 responses are ever
                # stored, so status and reason are fixed.
                requests_response = requests.Response()
                requests_response._content = content
                requests_response.status_code = 200
                requests_response.reason = "OK"
                requests_response.headers = requests.structures.CaseInsensitiveDict(
                    headers
                )
                requests_response.request = request_data

                # Make downstream know
                req.rctx.cache_hit = True

                self.logger.info(
                    "Returning cached response",
                    cache_key=short_key,
                    status_code=200,
                )

                return AdapterResponse(r=requests_response, rctx=req.rctx)

            self.logger.debug(
                "No cache hit, proceeding with request",
                cache_key=short_key,
            )

        return req

    @final
    def intercept_response(
        self, resp: AdapterResponse, context: AdapterGlobalContext
    ) -> AdapterResponse:
        """Cache the response if caching is enabled and response is successful.

        Args:
            resp (AdapterResponse): The adapter response to intercept
            context (AdapterGlobalContext): Global context containing server-level configuration

        Returns:
            AdapterResponse: The same response, always propagated unchanged;
                caching failures are logged and never raised.
        """
        # first, if caching was used, we do nothing
        if resp.rctx.cache_hit:
            self.logger.debug(
                "Response was from cache, skipping caching",
                cache_key=self._short_key(getattr(resp.rctx, "cache_key", None)),
            )
            return resp

        if resp.r.status_code == 200 and self.save_responses:
            # Save both content and headers to cache
            try:
                cache_key = resp.rctx.cache_key
                # Explicit check instead of `assert`, which is stripped under
                # `python -O` and would let a None key reach the cache.
                if not cache_key:
                    raise ValueError("cache key is unset, this is a bug")
                saved = self._save_to_cache(
                    cache_key=cache_key,
                    content=resp.r.content,
                    headers=resp.r.headers,
                )
                # Do not claim success when the response limit skipped the write.
                if saved:
                    self.logger.info(
                        "Cached successful response",
                        cache_key=self._short_key(cache_key),
                    )
            except Exception as e:
                # Best-effort: a caching failure must not break the response
                # path. The key may be None or absent here, so abbreviate it
                # defensively rather than slicing it directly.
                self.logger.error(
                    "Could not cache response",
                    error=str(e),
                    cache_key=self._short_key(getattr(resp.rctx, "cache_key", None)),
                )
        else:
            self.logger.debug(
                "Response not cached",
                status_code=resp.r.status_code,
                save_responses=self.save_responses,
            )

        # And just propagate
        return resp