# Source code for nv_ingest.framework.util.service.impl.ingest.redis_ingest_service

# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import asyncio
import json
import logging
import os
from json import JSONDecodeError
from typing import Optional, Dict, Any

from typing import List

import redis

from nv_ingest.framework.schemas.framework_message_wrapper_schema import MessageWrapper
from nv_ingest.framework.schemas.framework_processing_job_schema import ProcessingJob
from nv_ingest.framework.util.service.meta.ingest.ingest_service_meta import IngestServiceMeta
from nv_ingest_api.internal.schemas.meta.ingest_job_schema import validate_ingest_job
from nv_ingest_api.util.service_clients.client_base import FetchMode
from nv_ingest_api.util.service_clients.redis.redis_client import RedisClient

# Use uvicorn's logger (rather than __name__) so service log records are emitted
# through the ASGI server's logging configuration and appear in its output.
logger = logging.getLogger("uvicorn")


def get_fetch_mode_from_env() -> "FetchMode":
    """
    Resolve the service fetch mode from the FETCH_MODE environment variable.

    Returns
    -------
    FetchMode
        The mode named by FETCH_MODE (NON_DESTRUCTIVE when the variable is
        unset). An unrecognized value logs a warning and falls back to
        DESTRUCTIVE.
    """
    mode_str: str = os.getenv("FETCH_MODE", "NON_DESTRUCTIVE").upper()
    # Membership check instead of catching KeyError; same outcome either way.
    if mode_str in FetchMode.__members__:
        return FetchMode[mode_str]
    logger.warning(f"Invalid FETCH_MODE '{mode_str}' in environment. Defaulting to DESTRUCTIVE.")
    return FetchMode.DESTRUCTIVE
class RedisIngestService(IngestServiceMeta):
    """
    Submits jobs and fetches results via Redis, supporting multiple fetch modes
    and state management with TTLs. Operates asynchronously using
    asyncio.to_thread for synchronous Redis client operations.
    """

    # Max Redis connection-pool size; read from the environment once, at class-definition time.
    _concurrency_level: int = int(os.getenv("CONCURRENCY_LEVEL", "10"))
    # Singleton instance; populated lazily by get_instance().
    __shared_instance: Optional["RedisIngestService"] = None
[docs] @staticmethod def get_instance() -> "RedisIngestService": """ Static access method implementing the Singleton pattern. Returns ------- RedisIngestService The singleton instance of the RedisIngestService. """ if RedisIngestService.__shared_instance is None: redis_host: str = os.getenv("MESSAGE_CLIENT_HOST", "localhost") redis_port: int = int(os.getenv("MESSAGE_CLIENT_PORT", "6379")) redis_task_queue: str = os.getenv("REDIS_INGEST_TASK_QUEUE", "ingest_task_queue") fetch_mode: "FetchMode" = get_fetch_mode_from_env() result_data_ttl: int = int(os.getenv("RESULT_DATA_TTL_SECONDS", "3600")) state_ttl: int = int(os.getenv("STATE_TTL_SECONDS", "7200")) cache_config: Dict[str, Any] = { "directory": os.getenv("FETCH_CACHE_DIR", "./.fetch_cache"), "ttl": int(os.getenv("FETCH_CACHE_TTL_SECONDS", "3600")), } use_ssl: bool = os.getenv("REDIS_USE_SSL", "false").lower() == "true" RedisIngestService.__shared_instance = RedisIngestService( redis_hostname=redis_host, redis_port=redis_port, redis_task_queue=redis_task_queue, fetch_mode=fetch_mode, result_data_ttl_seconds=result_data_ttl if result_data_ttl > 0 else None, state_ttl_seconds=state_ttl if state_ttl > 0 else None, cache_config=cache_config, use_ssl=use_ssl, ) logger.debug(f"RedisIngestService configured with FetchMode: {fetch_mode.name}") else: logger.debug("Returning existing RedisIngestService Singleton instance.") return RedisIngestService.__shared_instance
    def __init__(
        self,
        redis_hostname: str,
        redis_port: int,
        redis_task_queue: str,
        fetch_mode: "FetchMode",
        result_data_ttl_seconds: Optional[int],
        state_ttl_seconds: Optional[int],
        cache_config: Optional[Dict[str, Any]],
        use_ssl: bool,
    ) -> None:
        """
        Initializes the service and the underlying RedisClient.

        Parameters
        ----------
        redis_hostname : str
            Redis server hostname.
        redis_port : int
            Redis server port.
        redis_task_queue : str
            The Redis queue name for tasks.
        fetch_mode : FetchMode
            The fetch mode configuration.
        result_data_ttl_seconds : int or None
            TTL for result data in seconds, or None to disable.
        state_ttl_seconds : int or None
            TTL for the job state record, or None to disable.
        cache_config : dict or None
            Configuration for caching.
        use_ssl : bool
            Whether to use SSL for the Redis connection.
        """
        self._redis_hostname: str = redis_hostname
        self._redis_port: int = redis_port
        self._redis_task_queue: str = redis_task_queue
        self._fetch_mode: "FetchMode" = fetch_mode
        self._result_data_ttl_seconds: Optional[int] = result_data_ttl_seconds
        self._state_ttl_seconds: Optional[int] = state_ttl_seconds

        # Key-prefix namespaces used to build Redis keys for the various records.
        self._bulk_vdb_cache_prefix: str = "vdb_bulk_upload_cache:"
        self._cache_prefix: str = "processing_cache:"
        self._state_prefix: str = "job_state:"

        # Synchronous Redis client; retry/backoff/timeout knobs come from the environment.
        # Pool size is bounded by the class-level _concurrency_level.
        self._ingest_client = RedisClient(
            host=self._redis_hostname,
            port=self._redis_port,
            max_pool_size=self._concurrency_level,
            fetch_mode=self._fetch_mode,
            cache_config=cache_config,
            message_ttl_seconds=self._result_data_ttl_seconds,
            use_ssl=use_ssl,
            max_retries=int(os.getenv("REDIS_MAX_RETRIES", "3")),
            max_backoff=int(os.getenv("REDIS_MAX_BACKOFF", "32")),
            connection_timeout=int(os.getenv("REDIS_CONNECTION_TIMEOUT", "300")),
        )
        logger.debug(
            f"RedisClient initialized for service. Host: {redis_hostname}:{redis_port}, "
            f"FetchMode: {fetch_mode.name}, ResultTTL: {result_data_ttl_seconds}, StateTTL: {state_ttl_seconds}"
        )
[docs] async def submit_job(self, job_spec_wrapper: "MessageWrapper", trace_id: str) -> str: """ Validates, prepares, and submits a job specification to the Redis task queue. Sets result data TTL if configured for NON_DESTRUCTIVE mode. Parameters ---------- job_spec_wrapper : MessageWrapper A wrapper containing the job specification payload. trace_id : str A unique identifier for the job. Returns ------- str The job trace_id. Raises ------ ValueError If the payload is missing or invalid. JSONDecodeError, TypeError For payload parsing errors. RedisError, ConnectionError For Redis-related errors. """ try: json_data = job_spec_wrapper.model_dump(mode="json").get("payload") if not json_data: raise ValueError("MessageWrapper payload is missing or empty.") if isinstance(json_data, str): job_spec = json.loads(json_data) elif isinstance(json_data, dict): job_spec = json_data else: raise TypeError(f"Unexpected payload type: {type(json_data)}") validate_ingest_job(job_spec) job_spec["job_id"] = trace_id tasks = job_spec.get("tasks", []) updated_tasks = [] for task in tasks: task_prop = task.get("task_properties", {}) if hasattr(task_prop, "model_dump") and callable(task_prop.model_dump): task["task_properties"] = task_prop.model_dump(mode="json") elif not isinstance(task_prop, dict): try: task["task_properties"] = dict(task_prop) except (TypeError, ValueError): logger.error(f"Cannot convert task_properties to dict: {task_prop}. 
Skipping properties.") task["task_properties"] = {} updated_tasks.append(task) job_spec["tasks"] = updated_tasks job_spec_json = json.dumps(job_spec) ttl_for_result: Optional[int] = ( self._result_data_ttl_seconds if self._fetch_mode == FetchMode.NON_DESTRUCTIVE else None ) logger.debug( f"Submitting job {trace_id} to queue '{self._redis_task_queue}' with result TTL: {ttl_for_result}" ) await asyncio.to_thread( self._ingest_client.submit_message, channel_name=self._redis_task_queue, message=job_spec_json, ttl_seconds=ttl_for_result, ) logger.debug(f"Successfully submitted job {trace_id}") return trace_id except (JSONDecodeError, TypeError, ValueError) as err: logger.exception(f"Data validation or parsing error for job {trace_id}: {err}") raise ValueError(f"Invalid job specification: {err}") from err except (redis.RedisError, ConnectionError) as err: logger.exception(f"Redis error submitting job {trace_id}: {err}") raise err except Exception as err: logger.exception(f"Unexpected error submitting job {trace_id}: {err}") raise
[docs] async def fetch_job(self, job_id: str) -> Optional[Dict]: """ Fetches the job result using the configured RedisClient fetch mode and timeout. Executes the synchronous client call asynchronously. Parameters ---------- job_id : str The unique identifier of the job. Returns ------- dict or None The job result message. Raises ------ TimeoutError, RedisError, ConnectionError, ValueError, RuntimeError If the fetch operation fails. """ try: result_channel: str = f"{job_id}" logger.debug(f"Attempting to fetch job result for {job_id} using mode {self._fetch_mode.name}") message = await asyncio.to_thread( self._ingest_client.fetch_message, channel_name=result_channel, timeout=10, ) if message is not None: logger.debug(f"Successfully fetched result for job {job_id}.") return message else: logger.warning(f"fetch_message for {job_id} returned None unexpectedly.") raise TimeoutError("No data found (unexpected None response).") except (TimeoutError, redis.RedisError, ConnectionError, ValueError, RuntimeError) as e: logger.info(f"Fetch operation for job {job_id} did not complete: ({type(e).__name__}) {e}") raise e except Exception as e: logger.exception(f"Unexpected error during async fetch_job for {job_id}: {e}") raise RuntimeError(f"Unexpected error fetching job {job_id}") from e
[docs] async def set_job_state(self, job_id: str, state: str) -> None: """ Sets the explicit state of a job and refreshes its TTL. Parameters ---------- job_id : str The unique identifier of the job. state : str The state to be assigned to the job. Returns ------- None """ state_key: str = f"{self._state_prefix}{job_id}" ttl_to_set: Optional[int] = self._state_ttl_seconds try: logger.debug(f"Setting state for {job_id} to {state} with TTL {ttl_to_set}") await asyncio.to_thread( self._ingest_client.get_client().set, state_key, state, ex=ttl_to_set, ) logger.debug(f"Successfully set state for {job_id}.") except (redis.RedisError, ConnectionError) as err: logger.error(f"Failed to set state for {state_key}: {err}") except Exception as err: logger.exception(f"Unexpected error setting state for {state_key}: {err}")
[docs] async def get_job_state(self, job_id: str) -> Optional[str]: """ Retrieves the explicit state of a job. Parameters ---------- job_id : str The unique identifier of the job. Returns ------- str or None The state of the job, or None if not found or upon error. """ state_key: str = f"{self._state_prefix}{job_id}" try: data_bytes: Optional[bytes] = await asyncio.to_thread(self._ingest_client.get_client().get, state_key) if data_bytes: state: str = data_bytes.decode("utf-8") logger.debug(f"Retrieved state for {job_id}: {state}") return state else: logger.debug(f"No state found for {job_id} (key: {state_key})") return None except (redis.RedisError, ConnectionError) as err: logger.error(f"Redis error getting state for {state_key}: {err}") return None except Exception as err: logger.exception(f"Unexpected error getting state for {state_key}: {err}") return None
[docs] async def set_processing_cache(self, job_id: str, jobs_data: List["ProcessingJob"]) -> None: """ Stores processing jobs data in a simple key-value cache. Parameters ---------- job_id : str The unique identifier of the job. jobs_data : list of ProcessingJob The processing job data to be cached. Returns ------- None """ cache_key: str = f"{self._cache_prefix}{job_id}" try: data_to_store: str = json.dumps([job.model_dump(mode="json") for job in jobs_data]) await asyncio.to_thread( self._ingest_client.get_client().set, cache_key, data_to_store, ex=3600, ) except Exception as err: logger.exception(f"Error setting cache for {cache_key}: {err}")
[docs] async def get_processing_cache(self, job_id: str) -> List["ProcessingJob"]: """ Retrieves processing jobs data from the simple key-value cache. Parameters ---------- job_id : str The unique identifier of the job. Returns ------- list of ProcessingJob A list of processing jobs, or an empty list if not found or upon error. """ cache_key: str = f"{self._cache_prefix}{job_id}" try: data_bytes: Optional[bytes] = await asyncio.to_thread(self._ingest_client.get_client().get, cache_key) if data_bytes is None: return [] return [ProcessingJob(**job) for job in json.loads(data_bytes)] except Exception as err: logger.exception(f"Error getting cache for {cache_key}: {err}") return []
[docs] async def get_fetch_mode(self) -> "FetchMode": """ Returns the configured fetch mode for the service. Returns ------- FetchMode The current fetch mode. """ return self._fetch_mode