Use the API (gRPC) for NVIDIA NeMo Retriever Embedding NIM
Use the examples in this documentation to help you get started using the API for NVIDIA NeMo Retriever Embedding NIM.
For the full API reference, refer to API Reference (gRPC).
gRPC Support
The NeMo Retriever Embedding NIM supports the Open Inference Protocol (KServe V2 Protocol). You can make gRPC inference requests by using the Triton Client Libraries.
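For example, after the NIM is running (see the following sections), you can issue standard KServe V2 calls, such as a server metadata query, through the Triton gRPC client. The following is a minimal sketch that assumes the gRPC endpoint is exposed on localhost:8001.

import tritonclient.grpc as grpcclient

# Query KServe V2 server metadata over gRPC (assumes a running NIM on localhost:8001)
client = grpcclient.InferenceServerClient(url="localhost:8001")
metadata = client.get_server_metadata()
print(metadata.name, metadata.version)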
Launch the NeMo Retriever Embedding NIM
Launch the NeMo Retriever Embedding NIM by following the Get Started guide.
In the code to launch the NIM, include the additional argument -p 8001:8001, which publishes the gRPC port, as shown in the following example.
# Start the NIM and publish the gRPC port 8001 in addition to the HTTP port 8000
docker run -it --rm --name=$CONTAINER_NAME \
  --runtime=nvidia \
  --gpus all \
  --shm-size=16GB \
  -e NGC_API_KEY \
  -v "$LOCAL_NIM_CACHE:/opt/nim/.cache" \
  -u $(id -u) \
  -p 8000:8000 \
  -p 8001:8001 \
  $IMG_NAME
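Optionally, confirm that the container publishes both the HTTP and gRPC ports. For example:

# List the port mappings of the running container
docker port $CONTAINER_NAME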
Make Inference Calls
After you launch the NeMo Retriever Embedding NIM, you can make inference calls by using the following code.
Install Python dependencies.
python3 -m pip install "tritonclient[all]==2.53.0"
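Optionally, verify the installation. The following quick check assumes that python3 resolves to the same interpreter you used for the install.

python3 -c "import tritonclient.grpc.aio; print('tritonclient OK')"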
Make inference calls.
import asyncio
import os
from dataclasses import dataclass
from enum import StrEnum, auto  # StrEnum requires Python 3.11 or later
from typing import Any, List, Optional, Union

import numpy as np
import tritonclient.grpc.aio as grpcclient
from numpy.typing import NDArray
from pydantic import BaseModel, Field, field_validator, model_validator

SMALL_EMBEDDING_DIMENSION = 2
REQ_TIMEOUT = int(os.environ.get("MAXIMUM_REQUEST_TIMEOUT", "720"))
class InputTypeValue(StrEnum):
    QUERY = auto()
    PASSAGE = auto()


class TruncateValue(StrEnum):
    START = auto()
    END = auto()
    NONE = auto()


class EmbeddingTypeValue(StrEnum):
    FLOAT = auto()
    BINARY = auto()
    UBINARY = auto()
    INT8 = auto()
    UINT8 = auto()


class ModalityValue(StrEnum):
    TEXT = auto()
    IMAGE = auto()
    TEXT_IMAGE = auto()


class TritonInputName(StrEnum):
    """Triton inference input names"""

    TEXT = auto()
    MODALITY = auto()
class TritonInputParameters(BaseModel):
    input_type: Optional[InputTypeValue] = None
    truncate: TruncateValue = TruncateValue.NONE
    dimensions: Optional[int] = Field(
        default=None,
        ge=SMALL_EMBEDDING_DIMENSION,
        description="The number of dimensions the output embeddings should have for models supporting Matryoshka Representation Learning. "
        f"To ensure numerical stability, the minimum number of dimensions is {SMALL_EMBEDDING_DIMENSION}.",
    )
    embedding_type: EmbeddingTypeValue = Field(
        default=EmbeddingTypeValue.FLOAT, description="Controls output type of the embeddings."
    )
    nvcf_asset_dir: Optional[str] = Field(
        default=None,
        description="Directory path where NVCF (NVIDIA Cloud Functions) asset files are stored. "
        "Required when using 'asset_id' format in image data URLs (e.g., 'data:image/png;asset_id,my_image.png'). "
        "This path is provided by NVCF in the 'NVCF-ASSET-DIR' header and contains the absolute path to the directory "
        "where uploaded asset files can be accessed during function invocation. "
        "See https://docs.nvidia.com/cloud-functions/user-guide/latest/cloud-function/assets.html for more details.",
    )

    @field_validator("truncate", mode="before")
    @classmethod
    def value_to_lower(cls, v: str) -> str:
        return v.lower() if isinstance(v, str) else v

    @model_validator(mode="before")
    @classmethod
    def normalize_keys(cls, values: Any) -> Any:
        normalized = {}
        for k, v in values.items():
            if k.lower().replace("-", "_") == "nvcf_asset_dir":
                normalized["nvcf_asset_dir"] = v
            else:
                normalized[k] = v
        return normalized
class TritonOutputName(StrEnum):
    """Triton inference output names"""

    EMBEDDINGS = auto()
    EMBEDDINGS_INT8 = auto()
    EMBEDDINGS_UINT8 = auto()
    EMBEDDINGS_BINARY = auto()
    EMBEDDINGS_UBINARY = auto()
    TOKEN_COUNT = auto()


@dataclass
class TritonEmbeddingResponse:
    embeddings: NDArray[Union[np.float32, np.int8, np.uint8]]
    token_count: int
class TritonComputeClient:
    def __init__(self, endpoint: str = "localhost:8001", verbose: bool = False) -> None:
        self._endpoint = endpoint
        self._verbose = verbose
        self._client: Optional[grpcclient.InferenceServerClient] = None

    async def _ensure_client(self) -> None:
        """Ensure client is created and connected"""
        if self._client is None:
            self._client = grpcclient.InferenceServerClient(
                url=self._endpoint, verbose=self._verbose
            )

    async def compute(
        self,
        model_name: str,
        text_batch: List[str],
        *,
        input_type: InputTypeValue = InputTypeValue.QUERY,
        modality_batch: Optional[List[ModalityValue]] = None,
        truncate: TruncateValue = TruncateValue.NONE,
        dimensions: Optional[int] = None,
        embedding_type: EmbeddingTypeValue = EmbeddingTypeValue.FLOAT,
        nvcf_asset_dir: Optional[str] = None,
    ) -> TritonEmbeddingResponse:
        await self._ensure_client()
        if self._client is None:
            raise RuntimeError("Failed to create Triton client")

        # Prepare text input
        text_np = np.array([[text.encode("utf-8")] for text in text_batch], dtype=np.object_)
        text_in = grpcclient.InferInput(
            TritonInputName.TEXT.value, shape=text_np.shape, datatype="BYTES"
        )
        text_in.set_data_from_numpy(text_np)
        infer_inputs = [text_in]

        # Prepare modality input if provided
        if modality_batch is not None:
            modality_np = np.array(
                [[modality.encode("utf-8")] for modality in modality_batch], dtype=np.object_
            )
            modality_in = grpcclient.InferInput(
                TritonInputName.MODALITY.value, modality_np.shape, "BYTES"
            )
            modality_in.set_data_from_numpy(modality_np)
            infer_inputs.append(modality_in)

        # Make inference request
        result = await self._client.infer(
            model_name=model_name,
            inputs=infer_inputs,
            parameters=TritonInputParameters(
                input_type=input_type,
                truncate=truncate,
                dimensions=dimensions,
                embedding_type=embedding_type,
                nvcf_asset_dir=nvcf_asset_dir,
            ).model_dump(exclude_none=True),
            client_timeout=max(REQ_TIMEOUT - 5, 5),
        )

        # Extract embeddings and token count
        _output_name_map = {
            EmbeddingTypeValue.FLOAT: TritonOutputName.EMBEDDINGS,
            EmbeddingTypeValue.INT8: TritonOutputName.EMBEDDINGS_INT8,
            EmbeddingTypeValue.UINT8: TritonOutputName.EMBEDDINGS_UINT8,
            EmbeddingTypeValue.BINARY: TritonOutputName.EMBEDDINGS_BINARY,
            EmbeddingTypeValue.UBINARY: TritonOutputName.EMBEDDINGS_UBINARY,
        }
        embeddings_np = result.as_numpy(_output_name_map[embedding_type].value)
        token_count_np = result.as_numpy(TritonOutputName.TOKEN_COUNT.value)
        total_token_count = int(token_count_np.prod()) if token_count_np is not None else 0
        return TritonEmbeddingResponse(embeddings=embeddings_np, token_count=total_token_count)
# Run the example. In a notebook, you can instead call: await client.compute(...)
client = TritonComputeClient()
response = asyncio.run(
    client.compute(
        model_name="nvidia_llama_nemotron_embed_1b_v2",
        text_batch=["hi", "hello world"],
        dimensions=5,
    )
)
print(response)
The result should look similar to the following.
TritonEmbeddingResponse(embeddings=array([[-0.5575916 , 0.31149718, 0.683525 , 0.11063856, -0.3355797 ],
[ 0.32987896, 0.1812951 , 0.8195797 , 0.38031602, 0.20484307]],
dtype=float32), token_count=10)
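The same client supports the other parameters that the preceding code defines. The following sketch, using the same model name as the previous example, embeds passages with end-truncation and requests quantized int8 embeddings; whether a given embedding type or dimension is supported depends on the model you deploy.

async def embed_passages() -> None:
    # Create a fresh client inside the coroutine so the gRPC channel
    # is bound to the currently running event loop
    client = TritonComputeClient()
    response = await client.compute(
        model_name="nvidia_llama_nemotron_embed_1b_v2",
        text_batch=["First passage to index.", "Second passage to index."],
        input_type=InputTypeValue.PASSAGE,       # index-side embeddings
        truncate=TruncateValue.END,              # truncate long inputs at the end
        embedding_type=EmbeddingTypeValue.INT8,  # quantized int8 output
    )
    print(response.embeddings.dtype, response.embeddings.shape)
    print(response.token_count)


asyncio.run(embed_passages())

Creating the client inside the coroutine matters because the asynchronous Triton client's channel is tied to the event loop that is running when the client is first used; reusing a client across separate asyncio.run calls can fail.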