Source code for nv_ingest_api.internal.primitives.nim.model_interface.text_embedding

# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional, Tuple

from nv_ingest_api.internal.primitives.nim import ModelInterface
import numpy as np


class EmbeddingModelInterface(ModelInterface):
    """
    An interface for handling inference with an embedding model endpoint.

    Requests can be formatted for, and responses parsed from, either the
    "http" or the "grpc" protocol.
    """

    def name(self) -> str:
        """
        Return the name of this model interface.
        """
        return "Embedding"

    def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Normalize input data for embedding inference.

        Parameters
        ----------
        data : dict
            Input data; must contain the key "prompts". A value that is not a
            list is wrapped in a single-element list.

        Returns
        -------
        dict
            A new dictionary ``{"prompts": [...]}``. The caller's dictionary
            is left unmodified.

        Raises
        ------
        KeyError
            If "prompts" is not present in ``data``.
        """
        if "prompts" not in data:
            raise KeyError("Input data must include 'prompts'.")

        prompts = data["prompts"]
        if not isinstance(prompts, list):
            # Wrap locally instead of writing back into the caller's dict.
            prompts = [prompts]
        return {"prompts": prompts}

    def format_input(
        self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs
    ) -> Tuple[List[Any], List[Dict[str, Any]]]:
        """
        Format the input payloads for the embedding endpoint.

        The prompts are split, in order, into batches of at most
        ``max_batch_size`` items and one payload is built per batch.

        Parameters
        ----------
        data : dict
            The input data containing "prompts" (a list of text prompts).
        protocol : str
            Either "http" or "grpc".
        max_batch_size : int
            Maximum number of prompts per payload; must be at least 1.
        kwargs : dict
            Additional parameters read for "http" payloads: model_name,
            encoding_format (default "float"), input_type (default "passage"),
            and truncate (default "NONE").

        Returns
        -------
        tuple
            A tuple (payloads, batch_data_list) where:
              - payloads holds one entry per batch: a JSON-serializable dict
                for "http", or an object-dtype numpy array with one
                single-element row of UTF-8 encoded bytes per prompt for
                "grpc".
              - batch_data_list is a list of dictionaries containing the key
                "prompts" corresponding to each batch.

        Raises
        ------
        ValueError
            If ``max_batch_size`` is smaller than 1 or ``protocol`` is neither
            "http" nor "grpc".
        """
        if max_batch_size < 1:
            # A non-positive size would either break range() or silently drop every prompt.
            raise ValueError(f"max_batch_size must be a positive integer, got {max_batch_size}.")

        prompts = data["prompts"]
        batches = [prompts[start : start + max_batch_size] for start in range(0, len(prompts), max_batch_size)]

        if protocol == "http":
            payloads = [
                {
                    "model": kwargs.get("model_name"),
                    "input": batch,
                    "encoding_format": kwargs.get("encoding_format", "float"),
                    "input_type": kwargs.get("input_type", "passage"),
                    "truncate": kwargs.get("truncate", "NONE"),
                }
                for batch in batches
            ]
        elif protocol == "grpc":
            payloads = [
                np.array([[text.encode("utf-8")] for text in batch], dtype=np.object_) for batch in batches
            ]
        else:
            # Fail explicitly instead of hitting an unbound local at the return below.
            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

        batch_data_list = [{"prompts": batch} for batch in batches]
        return payloads, batch_data_list

    def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
        """
        Parse the response from the embedding endpoint.

        Parameters
        ----------
        response : Any
            For "http", the response already decoded from JSON; a dict is
            expected to carry the embeddings under the "data" key. For "grpc",
            an iterable whose elements expose ``flatten()`` (presumably numpy
            arrays -- confirm against the client).
        protocol : str
            Either "http" or "grpc".
        data : dict, optional
            The original input data (unused here).
        kwargs : dict
            Additional keyword arguments (unused here).

        Returns
        -------
        list
            For "http" with a dict response, the "embedding" field of each
            item (None where the field is absent); for "http" with any other
            response, a single-element list holding ``str(response)``; for
            "grpc", the flattened elements of the response.

        Raises
        ------
        RuntimeError
            If an "http" dict response has a missing or empty "data" key.
        ValueError
            If ``protocol`` is neither "http" nor "grpc".
        """
        if protocol == "http":
            if not isinstance(response, dict):
                return [str(response)]

            embeddings = response.get("data")
            if not embeddings:
                raise RuntimeError("Unexpected response format: 'data' key is missing or empty.")
            # Each item in embeddings is expected to have an 'embedding' field.
            return [item.get("embedding") for item in embeddings]

        if protocol == "grpc":
            return [res.flatten() for res in response]

        # Previously an unknown protocol fell through and returned None silently.
        raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

    def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
        """
        Process inference results for the embedding model.

        The output is passed through unchanged.

        Returns
        -------
        Any
            The ``output`` argument as received (expected to be a list of
            embeddings).
        """
        return output