Source code for nv_ingest_api.util.nim

# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Tuple, Optional
import re
import logging

from nv_ingest_api.internal.primitives.nim import NimClient
from nv_ingest_api.internal.primitives.nim.model_interface.text_embedding import EmbeddingModelInterface
from nv_ingest_api.internal.primitives.nim.nim_client import get_nim_client_manager
from nv_ingest_api.internal.primitives.nim.nim_model_interface import ModelInterface

logger = logging.getLogger(__name__)


__all__ = ["create_inference_client", "infer_microservice"]

_VALID_INFER_PROTOCOLS = ("grpc", "http", "local")


@dataclass(frozen=True)
class LocalInferenceClient:
    """
    Small adapter that mimics NimClient's `infer()` surface, but runs locally.

    It delegates to `model_interface.infer(data, **kwargs)`.
    """

    model_interface: ModelInterface

    @property
    def protocol(self) -> str:
        return "local"

    def infer(self, data: dict, model_name: Optional[str] = None, **kwargs) -> Any:
        # `model_name` is accepted for API compatibility but may be unused by local implementations.
        infer_fn = getattr(self.model_interface, "infer", None)
        if not callable(infer_fn):
            raise RuntimeError(
                f"Local inference requested but model_interface '{self.model_interface.name()}' "
                "does not implement infer()."
            )
        return infer_fn(data, **kwargs)

    def close(self) -> None:
        return None
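

# Usage sketch (illustrative only, not part of the original module): at runtime
# LocalInferenceClient only needs an object exposing infer() (and name(), used in
# its error message), so a hypothetical duck-typed stand-in works here; real code
# would pass a concrete ModelInterface implementation.
class _EchoModel:
    """Hypothetical in-process model that returns its inputs unchanged."""

    def name(self) -> str:
        return "echo"

    def infer(self, data: dict, **kwargs) -> Any:
        return data


def _example_local_inference() -> Any:  # pragma: no cover - illustrative sketch
    client = LocalInferenceClient(model_interface=_EchoModel())
    return client.infer({"prompts": ["hello"]})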


def create_inference_client(
    endpoints: Tuple[str, str],
    model_interface: ModelInterface,
    auth_token: Optional[str] = None,
    infer_protocol: Optional[str] = None,
    timeout: float = 120.0,
    max_retries: int = 10,
    **kwargs,
) -> Any:
    """
    Create a client for interfacing with a model inference server.

    Parameters
    ----------
    endpoints : tuple
        A tuple containing the gRPC and HTTP endpoints.
    model_interface : ModelInterface
        The model interface implementation to use.
    auth_token : str, optional
        Authorization token for HTTP requests (default: None).
    infer_protocol : str, optional
        The protocol to use ("grpc", "http", or "local"). If not specified, it is
        inferred from the endpoints.
    timeout : float, optional
        The timeout for the request in seconds (default: 120.0).
    max_retries : int, optional
        The maximum number of retries for the request (default: 10).
    **kwargs : dict, optional
        Additional keyword arguments to pass to the NimClientManager.

    Returns
    -------
    NimClient or LocalInferenceClient
        The initialized inference client.

    Raises
    ------
    ValueError
        If an invalid infer_protocol is specified.
    """
    grpc_endpoint, http_endpoint = endpoints
    infer_protocol = (infer_protocol or "").strip().lower() or None

    # Auto-infer protocol from endpoints if not specified.
    if infer_protocol is None:
        if grpc_endpoint and str(grpc_endpoint).strip():
            infer_protocol = "grpc"
        elif http_endpoint and str(http_endpoint).strip():
            infer_protocol = "http"
        else:
            infer_protocol = "local"

    if infer_protocol not in _VALID_INFER_PROTOCOLS:
        raise ValueError("Invalid infer_protocol specified. Must be 'grpc', 'http', or 'local'.")

    if infer_protocol == "local":
        # If the interface has a backend attribute, force it to local.
        try:
            if hasattr(model_interface, "backend"):
                setattr(model_interface, "backend", "local")
        except Exception:
            pass
        return LocalInferenceClient(model_interface=model_interface)

    manager = get_nim_client_manager()
    client = manager.get_client(
        model_interface=model_interface,
        protocol=infer_protocol,
        endpoints=endpoints,
        auth_token=auth_token,
        timeout=timeout,
        max_retries=max_retries,
        **kwargs,
    )
    return client
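

# Usage sketch (illustrative; the endpoint URL and model name below are
# placeholder assumptions, not values from this module): build an HTTP client
# for an embedding NIM and run one inference. Wrapped in a helper so importing
# this module stays side-effect free.
def _example_create_inference_client() -> Any:  # pragma: no cover - illustrative sketch
    client = create_inference_client(
        endpoints=("", "http://localhost:8000/v1"),  # (grpc, http) pair
        model_interface=EmbeddingModelInterface(),
    )
    # With no infer_protocol given, the non-empty HTTP endpoint selects "http".
    return client.infer({"prompts": ["hello world"]}, model_name="nvidia/nv-embedqa-e5-v5")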


def infer_microservice(
    data,
    model_name: Optional[str] = None,
    embedding_endpoint: Optional[str] = None,
    nvidia_api_key: Optional[str] = None,
    input_type: str = "passage",
    truncate: str = "END",
    batch_size: int = 8191,
    grpc: bool = False,
    input_names: Optional[list] = None,
    output_names: Optional[list] = None,
    dtypes: Optional[list] = None,
):
    """
    Embed the input data using the NVIDIA embedding microservice.

    Parameters
    ----------
    data : list
        The input data to be embedded: either a list of strings, or a list of
        result dictionaries whose ``metadata.content`` fields hold the text.
    model_name : str
        The name of the embedding model to use.
    embedding_endpoint : str
        The endpoint of the embedding microservice.
    nvidia_api_key : str
        The API key for the NVIDIA embedding microservice.
    input_type : str
        The type of input to be embedded (e.g., "passage").
    truncate : str
        The truncation strategy applied to inputs that exceed the model's limit.
    batch_size : int
        The number of inputs sent per request.
    grpc : bool
        Whether to use gRPC instead of HTTP.
    input_names : list
        The names of the gRPC input tensors (default: ["text"]).
    output_names : list
        The names of the gRPC output tensors (default: ["embeddings"]).
    dtypes : list
        The data types of the gRPC input tensors (default: ["BYTES"]).

    Returns
    -------
    list
        The list of embeddings.
    """
    # Resolve defaults here rather than in the signature to avoid mutable
    # default arguments.
    input_names = ["text"] if input_names is None else input_names
    output_names = ["embeddings"] if output_names is None else output_names
    dtypes = ["BYTES"] if dtypes is None else dtypes

    if isinstance(data[0], str):
        data = {"prompts": data}
    else:
        data = {"prompts": [res["metadata"]["content"] for res in data]}

    if grpc:
        # Normalize the model name: gRPC model names may only contain
        # alphanumeric characters and underscores.
        model_name = re.sub(r"[^a-zA-Z0-9]", "_", model_name)
        client = NimClient(
            model_interface=EmbeddingModelInterface(),
            protocol="grpc",
            endpoints=(embedding_endpoint, None),
            auth_token=nvidia_api_key,
        )
        return client.infer(
            data,
            model_name,
            parameters={"input_type": input_type, "truncate": truncate},
            dtypes=dtypes,
            input_names=input_names,
            batch_size=batch_size,
            output_names=output_names,
        )
    else:
        embedding_endpoint = f"{embedding_endpoint}/embeddings"
        client = NimClient(
            model_interface=EmbeddingModelInterface(),
            protocol="http",
            endpoints=(None, embedding_endpoint),
            auth_token=nvidia_api_key,
        )
        return client.infer(
            data,
            model_name,
            input_type=input_type,
            truncate=truncate,
            batch_size=batch_size,
        )
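

# Usage sketch (illustrative; the endpoint, API key, and model name are
# placeholder assumptions): embed a small batch of strings over HTTP. With
# grpc=False the function appends "/embeddings" to the endpoint before issuing
# the request.
def _example_infer_microservice() -> list:  # pragma: no cover - illustrative sketch
    return infer_microservice(
        ["first passage", "second passage"],
        model_name="nvidia/nv-embedqa-e5-v5",
        embedding_endpoint="http://localhost:8000/v1",
        nvidia_api_key="nvapi-...",  # placeholder
        input_type="passage",
    )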