# Source code for nv_ingest_client.util.transport
import re
import logging
from nv_ingest_api.internal.primitives.nim import NimClient
from nv_ingest_api.internal.primitives.nim.model_interface.text_embedding import EmbeddingModelInterface
logger = logging.getLogger(__name__)
# [docs]
def infer_microservice(
    data,
    model_name: str = None,
    embedding_endpoint: str = None,
    nvidia_api_key: str = None,
    input_type: str = "passage",
    truncate: str = "END",
    batch_size: int = 8191,
    grpc: bool = False,
):
    """
    Embed document content using the NVIDIA embedding microservice.

    Parameters
    ----------
    data : list[dict]
        Ingest results; each element must expose the text to embed at
        ``element["metadata"]["content"]``.
    model_name : str, optional
        Embedding model to invoke. Required when ``grpc=True``; for gRPC the
        name is normalized so it contains only ``[a-zA-Z0-9_]`` characters.
    embedding_endpoint : str, optional
        Service endpoint: a gRPC host:port when ``grpc=True``, otherwise the
        HTTP base URL (``/embeddings`` is appended automatically).
    nvidia_api_key : str, optional
        API key forwarded to the client as the auth token.
    input_type : str
        Embedding input type hint passed to the service (e.g. "passage").
    truncate : str
        Server-side truncation strategy (e.g. "END").
    batch_size : int
        Maximum number of prompts sent per inference request.
    grpc : bool
        If True, use the gRPC protocol; otherwise use HTTP.

    Returns
    -------
    The embeddings produced by ``NimClient.infer``.

    Raises
    ------
    ValueError
        If ``grpc=True`` and ``model_name`` is not provided (previously this
        surfaced as an opaque ``TypeError`` from ``re.sub``).
    """
    # Keep the input argument intact; build the request payload separately.
    payload = {"prompts": [item["metadata"]["content"] for item in data]}

    if grpc:
        if model_name is None:
            raise ValueError("model_name is required when grpc=True")
        # gRPC/Triton model names disallow most punctuation; normalize to underscores.
        sanitized_model = re.sub(r"[^a-zA-Z0-9]", "_", model_name)
        client = NimClient(
            model_interface=EmbeddingModelInterface(),
            protocol="grpc",
            endpoints=(embedding_endpoint, None),
            auth_token=nvidia_api_key,
        )
        return client.infer(
            payload,
            sanitized_model,
            parameters={"input_type": input_type, "truncate": truncate},
            outputs=["embeddings"],
            dtype=["BYTES"],
            input_name=["text"],
            batch_size=batch_size,
        )

    # HTTP path: the REST embeddings API lives at <base>/embeddings.
    client = NimClient(
        model_interface=EmbeddingModelInterface(),
        protocol="http",
        endpoints=(None, f"{embedding_endpoint}/embeddings"),
        auth_token=nvidia_api_key,
    )
    return client.infer(payload, model_name, input_type=input_type, truncate=truncate, batch_size=batch_size)