Source code for nv_ingest_api.internal.primitives.nim.nim_client

# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from typing import Optional
from typing import Tuple

import numpy as np
import requests
import tritonclient.grpc as grpcclient

from nv_ingest_api.internal.primitives.tracing.tagging import traceable_func
from nv_ingest_api.util.string_processing import generate_url

logger = logging.getLogger(__name__)


class NimClient:
    """
    A client for interfacing with a model inference server using gRPC or HTTP protocols.
    """

    def __init__(
        self,
        model_interface,
        protocol: str,
        endpoints: Tuple[str, str],
        auth_token: Optional[str] = None,
        timeout: float = 120.0,
        max_retries: int = 5,
    ):
        """
        Initialize the NimClient with the specified model interface, protocol, and server endpoints.

        Parameters
        ----------
        model_interface : ModelInterface
            The model interface implementation to use.
        protocol : str
            The protocol to use ("grpc" or "http").
        endpoints : tuple
            A tuple containing the gRPC and HTTP endpoints.
        auth_token : str, optional
            Authorization token for HTTP requests (default: None).
        timeout : float, optional
            Timeout for HTTP requests in seconds (default: 120.0).
        max_retries : int, optional
            Maximum number of attempts for HTTP requests (default: 5).

        Raises
        ------
        ValueError
            If an invalid protocol is specified or if required endpoints are missing.
        """
        self.client = None
        self.model_interface = model_interface
        self.protocol = protocol.lower()
        self.auth_token = auth_token
        self.timeout = timeout  # Timeout for HTTP requests
        self.max_retries = max_retries

        self._grpc_endpoint, self._http_endpoint = endpoints
        self._max_batch_sizes = {}
        self._lock = threading.Lock()

        if self.protocol == "grpc":
            if not self._grpc_endpoint:
                raise ValueError("gRPC endpoint must be provided for gRPC protocol")
            logger.debug(f"Creating gRPC client with {self._grpc_endpoint}")
            self.client = grpcclient.InferenceServerClient(url=self._grpc_endpoint)
        elif self.protocol == "http":
            if not self._http_endpoint:
                raise ValueError("HTTP endpoint must be provided for HTTP protocol")
            logger.debug(f"Creating HTTP client with {self._http_endpoint}")
            self.endpoint_url = generate_url(self._http_endpoint)
            self.headers = {"accept": "application/json", "content-type": "application/json"}
            if self.auth_token:
                self.headers["Authorization"] = f"Bearer {self.auth_token}"
        else:
            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

    def _fetch_max_batch_size(self, model_name, model_version: str = "") -> int:
        """Fetch the maximum batch size from the Triton model configuration in a thread-safe manner."""
        if model_name in self._max_batch_sizes:
            return self._max_batch_sizes[model_name]

        with self._lock:
            # Double check, just in case another thread set the value while we were waiting
            if model_name in self._max_batch_sizes:
                return self._max_batch_sizes[model_name]

            if not self._grpc_endpoint:
                self._max_batch_sizes[model_name] = 1
                return 1

            try:
                client = self.client if self.client else grpcclient.InferenceServerClient(url=self._grpc_endpoint)
                model_config = client.get_model_config(model_name=model_name, model_version=model_version)
                self._max_batch_sizes[model_name] = model_config.config.max_batch_size
                logger.debug(f"Max batch size for model '{model_name}': {self._max_batch_sizes[model_name]}")
            except Exception as e:
                self._max_batch_sizes[model_name] = 1
                logger.warning(f"Failed to retrieve max batch size: {e}, defaulting to 1")

            return self._max_batch_sizes[model_name]

    def _process_batch(self, batch_input, *, batch_data, model_name, **kwargs):
        """
        Process a single batch input for inference using its corresponding batch_data.

        Parameters
        ----------
        batch_input : Any
            The input data for this batch.
        batch_data : Any
            The corresponding scratch-pad data for this batch as returned by format_input.
        model_name : str
            The model name for inference.
        kwargs : dict
            Additional parameters.

        Returns
        -------
        tuple
            A tuple (parsed_output, batch_data) for subsequent post-processing.
        """
        if self.protocol == "grpc":
            logger.debug("Performing gRPC inference for a batch...")
            response = self._grpc_infer(batch_input, model_name, **kwargs)
            logger.debug("gRPC inference received response for a batch")
        elif self.protocol == "http":
            logger.debug("Performing HTTP inference for a batch...")
            response = self._http_infer(batch_input)
            logger.debug("HTTP inference received response for a batch")
        else:
            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

        parsed_output = self.model_interface.parse_output(response, protocol=self.protocol, data=batch_data, **kwargs)
        return parsed_output, batch_data
    def try_set_max_batch_size(self, model_name, model_version: str = ""):
        """Attempt to set the max batch size for the model if it is not already set, ensuring thread safety."""
        self._fetch_max_batch_size(model_name, model_version)
    @traceable_func(trace_name="{stage_name}::{model_name}")
    def infer(self, data: dict, model_name: str, **kwargs) -> Any:
        """
        Perform inference using the specified model and input data.

        Parameters
        ----------
        data : dict
            The input data for inference.
        model_name : str
            The model name.
        kwargs : dict
            Additional parameters for inference.

        Returns
        -------
        Any
            The processed inference results, coalesced in the same order as the input images.
        """
        try:
            # 1. Retrieve or default to the model's maximum batch size.
            batch_size = self._fetch_max_batch_size(model_name)
            max_requested_batch_size = kwargs.get("max_batch_size", batch_size)
            force_requested_batch_size = kwargs.get("force_max_batch_size", False)
            max_batch_size = (
                min(batch_size, max_requested_batch_size)
                if not force_requested_batch_size
                else max_requested_batch_size
            )

            # 2. Prepare data for inference.
            data = self.model_interface.prepare_data_for_inference(data)

            # 3. Format the input based on protocol.
            formatted_batches, formatted_batch_data = self.model_interface.format_input(
                data, protocol=self.protocol, max_batch_size=max_batch_size, model_name=model_name
            )

            # Check for a custom maximum pool worker count, and remove it from kwargs.
            max_pool_workers = kwargs.pop("max_pool_workers", 16)

            # 4. Process each batch concurrently using a thread pool.
            #    We enumerate the batches so that we can later reassemble results in order.
            results = [None] * len(formatted_batches)
            with ThreadPoolExecutor(max_workers=max_pool_workers) as executor:
                futures = []
                for idx, (batch, batch_data) in enumerate(zip(formatted_batches, formatted_batch_data)):
                    future = executor.submit(
                        self._process_batch, batch, batch_data=batch_data, model_name=model_name, **kwargs
                    )
                    futures.append((idx, future))
                for idx, future in futures:
                    results[idx] = future.result()

            # 5. Process the parsed outputs for each batch using its corresponding batch_data.
            #    As the batches are in order, we coalesce their outputs accordingly.
            all_results = []
            for parsed_output, batch_data in results:
                batch_results = self.model_interface.process_inference_results(
                    parsed_output,
                    original_image_shapes=batch_data.get("original_image_shapes"),
                    protocol=self.protocol,
                    **kwargs,
                )
                if isinstance(batch_results, list):
                    all_results.extend(batch_results)
                else:
                    all_results.append(batch_results)

        except Exception as err:
            error_str = f"Error during NimClient inference [{self.model_interface.name()}, {self.protocol}]: {err}"
            logger.error(error_str)
            raise RuntimeError(error_str)

        return all_results
    def _grpc_infer(self, formatted_input: np.ndarray, model_name: str, **kwargs) -> np.ndarray:
        """
        Perform inference using the gRPC protocol.

        Parameters
        ----------
        formatted_input : np.ndarray
            The input data formatted as a numpy array.
        model_name : str
            The name of the model to use for inference.

        Returns
        -------
        np.ndarray
            The output of the model as a numpy array.
        """
        parameters = kwargs.get("parameters", {})
        output_names = kwargs.get("outputs", ["output"])
        dtype = kwargs.get("dtype", "FP32")
        input_name = kwargs.get("input_name", "input")

        input_tensors = grpcclient.InferInput(input_name, formatted_input.shape, datatype=dtype)
        input_tensors.set_data_from_numpy(formatted_input)

        outputs = [grpcclient.InferRequestedOutput(output_name) for output_name in output_names]
        response = self.client.infer(
            model_name=model_name, parameters=parameters, inputs=[input_tensors], outputs=outputs
        )
        logger.debug(f"gRPC inference response: {response}")

        if len(outputs) == 1:
            return response.as_numpy(outputs[0].name())
        else:
            return [response.as_numpy(output.name()) for output in outputs]

    def _http_infer(self, formatted_input: dict) -> dict:
        """
        Perform inference using the HTTP protocol, retrying timeouts, HTTP 429, and 5xx errors
        up to ``max_retries`` times.

        Parameters
        ----------
        formatted_input : dict
            The input data formatted as a dictionary.

        Returns
        -------
        dict
            The output of the model as a dictionary.

        Raises
        ------
        TimeoutError
            If the HTTP request times out repeatedly, up to the max retries.
        requests.RequestException
            For other HTTP-related errors that persist after max retries.
        """
        base_delay = 2.0
        attempt = 0

        while attempt < self.max_retries:
            try:
                response = requests.post(
                    self.endpoint_url, json=formatted_input, headers=self.headers, timeout=self.timeout
                )
                status_code = response.status_code

                # Check for server-side or rate-limit type errors
                # e.g. 5xx => server error, 429 => too many requests
                if status_code == 429 or status_code == 503 or (500 <= status_code < 600):
                    logger.warning(
                        f"Received HTTP {status_code} ({response.reason}) from "
                        f"{self.model_interface.name()}. Attempt {attempt + 1} of {self.max_retries}."
                    )
                    if attempt == self.max_retries - 1:
                        # No more retries left
                        logger.error(f"Max retries exceeded after receiving HTTP {status_code}.")
                        response.raise_for_status()  # raise the appropriate HTTPError
                    else:
                        # Exponential backoff
                        backoff_time = base_delay * (2**attempt)
                        time.sleep(backoff_time)
                        attempt += 1
                        continue
                else:
                    # Not in our "retry" category => just raise_for_status or return
                    response.raise_for_status()
                    logger.debug(f"HTTP inference response: {response.json()}")
                    return response.json()

            except requests.Timeout:
                # Treat timeouts similarly to 5xx => attempt a retry
                logger.warning(
                    f"HTTP request timed out after {self.timeout} seconds during {self.model_interface.name()} "
                    f"inference. Attempt {attempt + 1} of {self.max_retries}."
                )
                if attempt == self.max_retries - 1:
                    logger.error("Max retries exceeded after repeated timeouts.")
                    raise TimeoutError(
                        f"Repeated timeouts for {self.model_interface.name()} after {attempt + 1} attempts."
                    )

                # Exponential backoff
                backoff_time = base_delay * (2**attempt)
                time.sleep(backoff_time)
                attempt += 1

            except requests.HTTPError as http_err:
                # If we ended up here, it's a non-retryable 4xx or final 5xx after final attempt
                logger.error(f"HTTP request failed with status code {response.status_code}: {http_err}")
                raise

            except requests.RequestException as e:
                # ConnectionError or other non-HTTPError
                logger.error(f"HTTP request encountered a network issue: {e}")
                if attempt == self.max_retries - 1:
                    raise
                # Else retry on next loop iteration
                backoff_time = base_delay * (2**attempt)
                time.sleep(backoff_time)
                attempt += 1

        # If we exit the loop without returning, we've exhausted all attempts
        logger.error(f"Failed to get a successful response after {self.max_retries} retries.")
        raise Exception(f"Failed to get a successful response after {self.max_retries} retries.")
    def close(self):
        if self.protocol == "grpc" and hasattr(self.client, "close"):
            self.client.close()
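

Usage sketch (not part of the module source above): a minimal example of how a caller might drive NimClient. `SomeModelInterface`, the endpoint addresses, and the model name are placeholders for illustration only; any concrete ModelInterface implementation from this package would be substituted in practice.

# Minimal usage sketch; names below are hypothetical, not APIs defined in this module.
from nv_ingest_api.internal.primitives.nim.nim_client import NimClient

model_interface = SomeModelInterface()  # placeholder for a concrete ModelInterface implementation

client = NimClient(
    model_interface=model_interface,
    protocol="http",
    # (gRPC endpoint, HTTP endpoint); placeholder addresses for a locally running NIM.
    endpoints=("localhost:8001", "http://localhost:8000/v1/infer"),
    auth_token=None,
    timeout=120.0,
    max_retries=5,
)
try:
    # The shape of `data` is defined by the ModelInterface's prepare_data_for_inference/format_input;
    # the key and model name here are illustrative.
    results = client.infer({"images": []}, model_name="example_model")
finally:
    client.close()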