# Source code for nv_ingest.util.message_brokers.redis.redis_client

# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import json
import logging
import time
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import redis
from redis.exceptions import RedisError

from nv_ingest.util.message_brokers.client_base import MessageBrokerClientBase

# pylint: skip-file

logger = logging.getLogger(__name__)


class RedisClient(MessageBrokerClientBase):
    """
    A client for interfacing with Redis, providing mechanisms for sending and receiving messages with
    retry logic and connection management.

    Parameters
    ----------
    host : str
        The hostname of the Redis server.
    port : int
        The port number of the Redis server.
    db : int, optional
        The database number to connect to. Default is 0.
    max_retries : int, optional
        The maximum number of retry attempts for operations. Default is 0 (no retries).
    max_backoff : int, optional
        The maximum backoff delay between retries in seconds. Default is 32 seconds.
    connection_timeout : int, optional
        The timeout in seconds for connecting to the Redis server. Default is 300 seconds.
    max_pool_size : int, optional
        The maximum number of connections in the Redis connection pool. Default is 128.
    use_ssl : bool, optional
        Specifies if SSL should be used for the connection. Default is False.
    redis_allocator : Any, optional
        The Redis client allocator, allowing for custom Redis client instances. Default is redis.Redis.

    Attributes
    ----------
    client : Any
        The Redis client instance used for operations.
    """

    def __init__(
        self,
        host: str,
        port: int,
        db: int = 0,
        max_retries: int = 0,
        max_backoff: int = 32,
        connection_timeout: int = 300,
        max_pool_size: int = 128,
        use_ssl: bool = False,
        redis_allocator: Any = redis.Redis,  # Type hint as 'Any' due to dynamic nature
    ):
        self._host = host
        self._port = port
        self._db = db
        self._max_retries = max_retries
        self._max_backoff = max_backoff
        self._connection_timeout = connection_timeout
        self._use_ssl = use_ssl
        self._pool = redis.ConnectionPool(
            host=self._host,
            port=self._port,
            db=self._db,
            socket_connect_timeout=self._connection_timeout,
            max_connections=max_pool_size,
            # FIX: `use_ssl` was previously stored but never applied, so SSL
            # requests silently produced plaintext connections. Selecting the
            # connection class here honors the flag; the default branch
            # (redis.Connection) is the pool's own default, preserving prior
            # behavior for use_ssl=False.
            connection_class=redis.SSLConnection if use_ssl else redis.Connection,
        )
        self._redis_allocator = redis_allocator
        self._client = self._redis_allocator(connection_pool=self._pool)
        self._retries = 0

    def _connect(self) -> None:
        """
        Attempts to reconnect to the Redis server if the current connection is not responsive.
        """
        if not self.ping():
            logger.debug("Reconnecting to Redis")
            self._client = self._redis_allocator(connection_pool=self._pool)

    @property
    def max_retries(self) -> int:
        # Maximum retry attempts for fetch/submit operations; 0 means no retries.
        return self._max_retries

    @max_retries.setter
    def max_retries(self, value: int) -> None:
        self._max_retries = value

    def get_client(self) -> Any:
        """
        Returns a Redis client instance, reconnecting if necessary.

        Returns
        -------
        Any
            The Redis client instance.
        """
        if self._client is None or not self.ping():
            self._connect()
        return self._client

    def ping(self) -> bool:
        """
        Checks if the Redis server is responsive.

        Returns
        -------
        bool
            True if the server responds to a ping, False otherwise.
        """
        try:
            self._client.ping()
            return True
        except (RedisError, AttributeError):
            # AttributeError covers the case where the client has been
            # invalidated (set to None) pending reconnection.
            return False

    def _check_response(
        self, channel_name: str, timeout: float
    ) -> Tuple[Optional[Dict[str, Any]], Optional[int], Optional[int]]:
        """
        Checks for a response from the Redis queue and processes it into a message, fragment, and
        fragment count.

        Parameters
        ----------
        channel_name : str
            The name of the Redis channel from which to receive the response.
        timeout : float
            The time in seconds to wait for a response from the Redis queue before timing out.

        Returns
        -------
        Tuple[Optional[Dict[str, Any]], Optional[int], Optional[int]]
            A tuple containing:
            - message: A dictionary containing the decoded message if successful, or None if no
              message was retrieved.
            - fragment: An integer representing the fragment number of the message, or None if no
              fragment was found.
            - fragment_count: An integer representing the total number of message fragments, or
              None if no fragment count was found.

        Raises
        ------
        TimeoutError
            If no response arrives within the timeout period.
        ValueError
            If the message retrieved from Redis cannot be decoded from JSON.
        """
        # blpop blocks up to `timeout` seconds; returns (channel, payload) or None.
        response = self.get_client().blpop([channel_name], timeout)
        if response is None:
            raise TimeoutError("No response was received in the specified timeout period")

        if len(response) > 1 and response[1]:
            try:
                message = json.loads(response[1])
                # Fragment metadata defaults describe a single, unfragmented message.
                fragment = message.get("fragment", 0)
                fragment_count = message.get("fragment_count", 1)
                return message, fragment, fragment_count
            except json.JSONDecodeError as e:
                logger.error(f"Failed to decode message: {e}")
                raise ValueError(f"Failed to decode message from Redis: {e}")

        return None, None, None

    def fetch_message(self, channel_name: str, timeout: float = 10) -> Optional[Union[str, Dict]]:
        """
        Fetches a message from the specified queue with retries on failure. If the message is
        fragmented, it will continue fetching fragments until all parts have been collected.

        Parameters
        ----------
        channel_name: str
            Channel to fetch the message from.
        timeout : float
            The timeout in seconds for blocking until a message is available. If we receive a
            multi-part message, this value will be temporarily extended in order to collect all
            fragments.

        Returns
        -------
        Optional[str]
            The full fetched message, or None if no message could be fetched after retries.

        Raises
        ------
        ValueError
            If fetching the message fails after the specified number of retries or due to other
            critical errors.
        """
        accumulated_time = 0
        collected_fragments = []
        fragment_count = None
        retries = 0
        while True:
            try:
                # Attempt to fetch a message from the Redis queue
                message, fragment, fragment_count = self._check_response(channel_name, timeout)

                if message is not None:
                    if fragment_count == 1:
                        # No fragmentation, return the message as is
                        return message

                    collected_fragments.append(message)

                    # If we have collected all fragments, combine and return
                    if len(collected_fragments) == fragment_count:
                        # Sort fragments by the 'fragment' field to ensure correct order
                        collected_fragments.sort(key=lambda x: x["fragment"])

                        # Combine fragments (assuming they are part of a larger payload)
                        reconstructed_message = self._combine_fragments(collected_fragments)
                        return reconstructed_message
                else:
                    # Return None if the response is empty
                    return message

            except TimeoutError:
                # TODO(Devin) Once we start accumulating fragments, we can no longer fully recover
                # from a timeout, so we should consider this a failure. Look into caching partial
                # results for retries in the future.
                if fragment_count and fragment_count > 1:
                    # Allow up to `timeout * fragment_count` total wait for a multi-part message.
                    accumulated_time += timeout
                    if accumulated_time >= (timeout * fragment_count):
                        err_msg = f"Failed to reconstruct message from {channel_name} after {accumulated_time} sec."
                        logger.error(err_msg)
                        raise ValueError(err_msg)
                else:
                    raise  # This is expected in many cases, so re-raise it

            except RedisError as err:
                retries += 1
                logger.error(f"Redis error during fetch: {err}")
                backoff_delay = min(2**retries, self._max_backoff)

                if self.max_retries > 0 and retries <= self.max_retries:
                    logger.error(f"Fetch attempt failed, retrying in {backoff_delay}s...")
                    time.sleep(backoff_delay)
                else:
                    logger.error(f"Failed to fetch message from {channel_name} after {retries} attempts.")
                    raise ValueError(f"Failed to fetch message from Redis queue after {retries} attempts: {err}")

                # Invalidate client to force reconnection on the next try
                self._client = None

            except Exception as e:
                # Handle non-Redis specific exceptions
                logger.error(f"Unexpected error during fetch from {channel_name}: {e}")
                raise ValueError(f"Unexpected error during fetch: {e}")

    @staticmethod
    def _combine_fragments(fragments: List[Dict[str, Any]]) -> Dict:
        """
        Combines multiple message fragments into a single message by extending the 'data' elements,
        retaining the 'status' and 'description' of the first fragment, and removing 'fragment' and
        'fragment_counts'.

        Parameters
        ----------
        fragments : List[Dict[str, Any]]
            A list of fragments to be combined.

        Returns
        -------
        Dict
            The combined message, containing 'status', 'description', 'trace', and combined 'data'.

        Raises
        ------
        ValueError
            If `fragments` is empty.
        """
        if not fragments:
            raise ValueError("Fragments list is empty")

        # Use 'status' and 'description' from the first fragment
        combined_message = {
            "status": fragments[0]["status"],
            "description": fragments[0]["description"],
            "data": [],
            "trace": fragments[0].get("trace", {}),
        }

        # Combine the 'data' elements from all fragments
        for fragment in fragments:
            combined_message["data"].extend(fragment["data"])

        return combined_message

    def submit_message(self, channel_name: str, message: str) -> None:
        """
        Submits a message to a specified Redis queue with retries on failure.

        Parameters
        ----------
        channel_name : str
            The name of the queue to submit the message to.
        message : str
            The message to submit.

        Raises
        ------
        RedisError
            If submitting the message fails after the specified number of retries.
        """
        retries = 0
        while True:
            try:
                self.get_client().rpush(channel_name, message)
                logger.debug(f"Message submitted to {channel_name}")
                break
            except RedisError as e:
                logger.error(f"Failed to submit message, retrying... Error: {e}")
                self._client = None  # Invalidate client to force reconnection
                retries += 1
                backoff_delay = min(2**retries, self._max_backoff)

                # FIX: previously this retried forever when max_retries == 0
                # (`max_retries == 0 or retries < max_retries`), contradicting
                # both the class docstring ("0 = no retries") and the policy in
                # fetch_message. Aligned with fetch_message's condition.
                if self.max_retries > 0 and retries <= self.max_retries:
                    logger.error(f"Submit attempt failed, retrying in {backoff_delay}s...")
                    time.sleep(backoff_delay)
                else:
                    logger.error(f"Failed to submit message to {channel_name} after {retries} attempts.")
                    raise