gRPC Reference for NeMo Retriever Text Embedding NIM

This documentation contains the gRPC reference for NeMo Retriever Text Embedding NIM.

The Text Embedding NIM supports the Open Inference Protocol (KServe V2). You can make gRPC inference requests by using the Triton Client Libraries.

Launch the Text Embedding NIM

Launch the Text Embedding NIM by following the Get Started guide. When you run the container, add the argument -p 8001:8001 to publish the gRPC port, as shown in the following example.

# Start the NIM.
# -p 8000:8000 publishes the HTTP port; -p 8001:8001 is the additional argument that publishes the gRPC port.
docker run -it --rm --name=$CONTAINER_NAME \
  --runtime=nvidia \
  --gpus all \
  --shm-size=16GB \
  -e NGC_API_KEY \
  -v "$LOCAL_NIM_CACHE:/opt/nim/.cache" \
  -u $(id -u) \
  -p 8000:8000 \
  -p 8001:8001 \
  $IMG_NAME
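To confirm that the gRPC port is published before you make inference calls, you can run a quick connectivity check. The following is a minimal sketch that uses only the Python standard library; the function name grpc_port_is_open is hypothetical, and the default host and port assume the docker run command above. It checks TCP connectivity only and does not confirm that the model is loaded. The Triton client's InferenceServerClient also provides is_server_ready() and is_model_ready() for that purpose.

```python
import socket


def grpc_port_is_open(host: str = "localhost", port: int = 8001, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within `timeout` seconds.

    This only shows that something is listening on the port; it does not
    confirm that the Triton server or the embedding model is ready.
    """
    try:
        # create_connection resolves the host and tries each address in turn.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers connection refused, timeouts, and name-resolution failures.
        return False


if __name__ == "__main__":
    print("gRPC port reachable:", grpc_port_is_open())
```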

Make Inference Calls

After you launch the Text Embedding NIM, you can make inference calls by using the following code.

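Install the Python dependencies. The example requires Python 3.11 or later (it uses enum.StrEnum) and Pydantic v2. Quote the package specifiers so that shells such as zsh do not interpret the square brackets.

python3 -m pip install "tritonclient[all]==2.53.0" "pydantic>=2"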
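Because the example client depends on enum.StrEnum and on Pydantic v2 validators, a quick environment check can save a confusing ImportError later. This is a minimal sketch that uses only the standard library; the function name check_environment is hypothetical.

```python
import enum
import importlib.metadata
import sys


def check_environment() -> list[str]:
    """Return human-readable problems that would prevent the example client from running."""
    problems: list[str] = []
    # enum.StrEnum was added in Python 3.11.
    if not hasattr(enum, "StrEnum"):
        problems.append(f"Python 3.11 or later is required (found {sys.version.split()[0]})")
    # field_validator and model_validator are Pydantic v2 APIs.
    try:
        pydantic_major = int(importlib.metadata.version("pydantic").split(".")[0])
        if pydantic_major < 2:
            problems.append("Pydantic v2 is required")
    except importlib.metadata.PackageNotFoundError:
        problems.append("pydantic is not installed")
    # The example imports tritonclient.grpc.aio.
    try:
        importlib.metadata.version("tritonclient")
    except importlib.metadata.PackageNotFoundError:
        problems.append("tritonclient is not installed")
    return problems


if __name__ == "__main__":
    issues = check_environment()
    print("\n".join(issues) if issues else "Environment looks OK")
```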

Make inference calls.

import os
from dataclasses import dataclass
from enum import StrEnum, auto  # StrEnum requires Python 3.11 or later
from typing import Any, List, Optional, Union

import numpy as np
import tritonclient.grpc.aio as grpcclient
from numpy.typing import NDArray
from pydantic import BaseModel, Field, field_validator, model_validator

SMALL_EMBEDDING_DIMENSION = 2


REQ_TIMEOUT = int(os.environ.get("MAXIMUM_REQUEST_TIMEOUT", "720"))


class InputTypeValue(StrEnum):
    QUERY = auto()
    PASSAGE = auto()


class TruncateValue(StrEnum):
    START = auto()
    END = auto()
    NONE = auto()


class EmbeddingTypeValue(StrEnum):
    FLOAT = auto()
    BINARY = auto()
    UBINARY = auto()
    INT8 = auto()
    UINT8 = auto()


class ModalityValue(StrEnum):
    TEXT = auto()
    IMAGE = auto()
    TEXT_IMAGE = auto()


class TritonInputName(StrEnum):
    """Triton inference input names"""

    TEXT = auto()
    MODALITY = auto()


class TritonInputParameters(BaseModel):
    input_type: Optional[InputTypeValue] = None
    truncate: TruncateValue = TruncateValue.NONE
    dimensions: Optional[int] = Field(
        default=None,
        ge=SMALL_EMBEDDING_DIMENSION,
        description="The number of dimensions the output embeddings should have for models supporting Matryoshka Representation Learning. "
        f"To ensure numerical stability, the minimum number of dimensions is {SMALL_EMBEDDING_DIMENSION}.",
    )
    embedding_type: EmbeddingTypeValue = Field(
        default=EmbeddingTypeValue.FLOAT, description="Controls output type of the embeddings."
    )
    nvcf_asset_dir: Optional[str] = Field(
        default=None,
        description="Directory path where NVCF (NVIDIA Cloud Functions) asset files are stored. "
        "Required when using 'asset_id' format in image data URLs (e.g., 'data:image/png;asset_id,my_image.png'). "
        "This path is provided by NVCF in the 'NVCF-ASSET-DIR' header and contains the absolute path to the directory "
        "where uploaded asset files can be accessed during function invocation. "
        "See https://docs.nvidia.com/cloud-functions/user-guide/latest/cloud-function/assets.html for more details.",
    )

    @field_validator("truncate", mode="before")
    @classmethod
    def value_to_lower(cls, v: str) -> str:
        return v.lower() if isinstance(v, str) else v

    @model_validator(mode="before")
    @classmethod
    def normalize_keys(cls, values: Any) -> Any:
        normalized = {}
        for k, v in values.items():
            if k.lower().replace("-", "_") == "nvcf_asset_dir":
                normalized["nvcf_asset_dir"] = v
            else:
                normalized[k] = v
        return normalized


class TritonOutputName(StrEnum):
    """Triton inference output names"""

    EMBEDDINGS = auto()
    EMBEDDINGS_INT8 = auto()
    EMBEDDINGS_UINT8 = auto()
    EMBEDDINGS_BINARY = auto()
    EMBEDDINGS_UBINARY = auto()
    TOKEN_COUNT = auto()


@dataclass
class TritonEmbeddingResponse:
    embeddings: NDArray[Union[np.float32, np.int8, np.uint8]]
    token_count: int


class TritonComputeClient:
    def __init__(self, endpoint: str = "localhost:8001", verbose: bool = False) -> None:
        self._endpoint = endpoint
        self._verbose = verbose
        self._client: Optional[grpcclient.InferenceServerClient] = None

    async def _ensure_client(self) -> None:
        """Ensure client is created and connected"""
        if self._client is None:
            self._client = grpcclient.InferenceServerClient(
                url=self._endpoint, verbose=self._verbose
            )

    async def compute(
        self,
        model_name: str,
        text_batch: List[str],
        *,
        input_type: InputTypeValue = InputTypeValue.QUERY,
        modality_batch: Optional[List[ModalityValue]] = None,
        truncate: TruncateValue = TruncateValue.NONE,
        dimensions: Optional[int] = None,
        embedding_type: EmbeddingTypeValue = EmbeddingTypeValue.FLOAT,
        nvcf_asset_dir: Optional[str] = None,
    ) -> TritonEmbeddingResponse:
        await self._ensure_client()

        if self._client is None:
            raise RuntimeError("Failed to create Triton client")

        # Prepare text input
        text_np = np.array([[text.encode("utf-8")] for text in text_batch], dtype=np.object_)
        text_in = grpcclient.InferInput(
            TritonInputName.TEXT.value, shape=text_np.shape, datatype="BYTES"
        )
        text_in.set_data_from_numpy(text_np)

        infer_inputs = [text_in]

        # Prepare modality input if provided
        if modality_batch is not None:
            modality_np = np.array(
                [[modality.encode("utf-8")] for modality in modality_batch], dtype=np.object_
            )
            modality_in = grpcclient.InferInput(
                TritonInputName.MODALITY.value, modality_np.shape, "BYTES"
            )
            modality_in.set_data_from_numpy(modality_np)
            infer_inputs.append(modality_in)

        # Make inference request
        result = await self._client.infer(
            model_name=model_name,
            inputs=infer_inputs,
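            # mode="json" converts enum members to plain strings before they are
            # sent as Triton request parameters.
            parameters=TritonInputParameters(
                input_type=input_type,
                truncate=truncate,
                dimensions=dimensions,
                embedding_type=embedding_type,
                nvcf_asset_dir=nvcf_asset_dir,
            ).model_dump(mode="json", exclude_none=True),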
            client_timeout=max(REQ_TIMEOUT - 5, 5),
        )

        # Extract embeddings and token count
        _output_name_map = {
            EmbeddingTypeValue.FLOAT: TritonOutputName.EMBEDDINGS,
            EmbeddingTypeValue.INT8: TritonOutputName.EMBEDDINGS_INT8,
            EmbeddingTypeValue.UINT8: TritonOutputName.EMBEDDINGS_UINT8,
            EmbeddingTypeValue.BINARY: TritonOutputName.EMBEDDINGS_BINARY,
            EmbeddingTypeValue.UBINARY: TritonOutputName.EMBEDDINGS_UBINARY,
        }

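        # Read the output tensor that matches the requested embedding type.
        embeddings_np = result.as_numpy(_output_name_map[embedding_type].value)
        # token_count holds one count per input text, so sum it for the batch total.
        token_count_np = result.as_numpy(TritonOutputName.TOKEN_COUNT.value)
        total_token_count = int(token_count_np.sum()) if token_count_np is not None else 0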

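        return TritonEmbeddingResponse(embeddings=embeddings_np, token_count=total_token_count)

    async def close(self) -> None:
        """Close the underlying gRPC channel, if one was opened."""
        if self._client is not None:
            await self._client.close()
            self._client = None


import asyncio


async def main() -> None:
    client = TritonComputeClient()
    try:
        result = await client.compute(
            model_name="nvidia_llama_3_2_nv_embedqa_1b_v2",
            text_batch=["hi", "hello world"],
            dimensions=5,
        )
        print(result)
    finally:
        await client.close()


# In a script, run the coroutine with asyncio.run(). In a Jupyter notebook,
# which already runs an event loop, use `await main()` instead.
asyncio.run(main())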

The result should look similar to the following.

TritonEmbeddingResponse(embeddings=array([[-0.5575916 ,  0.31149718,  0.683525  ,  0.11063856, -0.3355797 ],
       [ 0.32987896,  0.1812951 ,  0.8195797 ,  0.38031602,  0.20484307]],
      dtype=float32), token_count=10)
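A common next step is to compare the returned vectors. The following sketch computes pairwise cosine similarity with NumPy; the function name cosine_similarity_matrix is hypothetical, and it assumes the default float output (the embeddings field of the TritonEmbeddingResponse shown above).

```python
import numpy as np
from numpy.typing import NDArray


def cosine_similarity_matrix(embeddings: NDArray[np.floating]) -> NDArray[np.float32]:
    """Return the pairwise cosine similarity of the rows of a [batch_size, dim] array."""
    vectors = np.asarray(embeddings, dtype=np.float32)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    # Clamp tiny norms to avoid dividing by zero for all-zero rows.
    normalized = vectors / np.maximum(norms, np.finfo(np.float32).eps)
    # Entry [i, j] is the cosine similarity between input i and input j.
    return normalized @ normalized.T
```

For example, cosine_similarity_matrix(response.embeddings)[0, 1] is the similarity between the two inputs, where response is the value returned by compute().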

API Reference

gRPC Models

The gRPC model names differ from the NIM model IDs shown in the Support Matrix. The following table contains the mapping of the names.

| Model ID | gRPC Model Name |
| --- | --- |
| nvidia/llama-3.2-nv-embedqa-1b-v2 | nvidia_llama_3_2_nv_embedqa_1b_v2 |
| nvidia/nv-embedqa-e5-v5 | nvidia_nv_embedqa_e5_v5 |
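Both rows of the mapping table follow the same pattern: "/", "-", and "." become "_". The following helper is a sketch based on that observed pattern only; it is an assumption, not a documented rule, and the function name to_grpc_model_name is hypothetical. Check the table for the authoritative name of any model.

```python
def to_grpc_model_name(model_id: str) -> str:
    """Derive a gRPC model name from a NIM model ID.

    Assumption: replace "/", "-", and "." with "_". This matches both rows of
    the mapping table but is an inferred pattern, not a documented rule.
    """
    return model_id.translate(str.maketrans({"/": "_", "-": "_", ".": "_"}))
```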

Request Inputs

| Input | Shape | Data Type | Description | Required |
| --- | --- | --- | --- | --- |
| text | [batch_size, 1] | BYTES | A list of UTF-8 encoded strings to embed. For details on how to encode multimodal data as a string, refer to Specify Modality. | Yes |
| modality | [batch_size, 1] | BYTES | A list of UTF-8 modality strings, one for each element of the text input. If you don't specify modality, the modality is inferred. For supported modalities, refer to Specify Modality. | No |
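The following sketch builds NumPy arrays with the [batch_size, 1] BYTES layout described in the table and checks that modality, when given, has one entry per text. The function name build_input_arrays is hypothetical; the modality string "text" comes from the ModalityValue enum in the example client, so refer to Specify Modality for the authoritative list. Each array can then be wrapped with grpcclient.InferInput(name, array.shape, "BYTES") and set_data_from_numpy(), as the example client does.

```python
from typing import Optional, Sequence

import numpy as np


def build_input_arrays(
    texts: Sequence[str], modalities: Optional[Sequence[str]] = None
) -> dict[str, np.ndarray]:
    """Encode texts (and optional modalities) as [batch_size, 1] object arrays of UTF-8 bytes."""
    if not texts:
        raise ValueError("text must contain at least one element")
    # Each row holds a single bytes object, giving shape [batch_size, 1].
    arrays = {"text": np.array([[t.encode("utf-8")] for t in texts], dtype=np.object_)}
    if modalities is not None:
        if len(modalities) != len(texts):
            raise ValueError("modality must have one entry per text input")
        arrays["modality"] = np.array(
            [[m.encode("utf-8")] for m in modalities], dtype=np.object_
        )
    return arrays
```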

Request Parameters

| Parameter | Data Type | Description | Valid Values | Default | Required |
| --- | --- | --- | --- | --- | --- |
| input_type | String | The context of the embedding: "query" for search queries, "passage" for documents to be retrieved. | "query", "passage" | "query" | Yes |
| truncate | String | How to handle text that exceeds the maximum token length. | "END", "START", "NONE" | "NONE" | Yes |
| dimensions | Integer | The desired dimensionality of the output embeddings. The model must support the requested value. | - | The model's default dimension | No |
| embedding_type | String | The output type of the embeddings. For how each output type is handled, refer to the OpenAI API Reference. | "float", "binary", "ubinary", "int8", "uint8" | "float" | No |
| nvcf_asset_dir | String | Directory path where NVCF (NVIDIA Cloud Functions) asset files are stored. For more details, refer to the OpenAI API Reference. | - | - | No |
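The following sketch validates request parameters against the valid values in the table and returns a dict that could be passed as the parameters argument of infer() on a tritonclient gRPC client. The function name build_request_parameters is hypothetical. It assumes the server accepts the values as spelled in the table (the example client above sends lowercase truncate values instead), and the minimum of 2 for dimensions is taken from SMALL_EMBEDDING_DIMENSION in the example client rather than from the table.

```python
from typing import Optional

VALID_INPUT_TYPES = {"query", "passage"}
VALID_TRUNCATE = {"END", "START", "NONE"}
VALID_EMBEDDING_TYPES = {"float", "binary", "ubinary", "int8", "uint8"}
MIN_DIMENSIONS = 2  # Assumption: mirrors SMALL_EMBEDDING_DIMENSION in the example client.


def build_request_parameters(
    input_type: str = "query",
    truncate: str = "NONE",
    dimensions: Optional[int] = None,
    embedding_type: str = "float",
    nvcf_asset_dir: Optional[str] = None,
) -> dict[str, object]:
    """Validate request parameters and return them as a dict, omitting unset optional ones."""
    if input_type not in VALID_INPUT_TYPES:
        raise ValueError(f"input_type must be one of {sorted(VALID_INPUT_TYPES)}")
    # Accept any case for truncate, but send the spelling listed in the table.
    if truncate.upper() not in VALID_TRUNCATE:
        raise ValueError(f"truncate must be one of {sorted(VALID_TRUNCATE)}")
    if embedding_type not in VALID_EMBEDDING_TYPES:
        raise ValueError(f"embedding_type must be one of {sorted(VALID_EMBEDDING_TYPES)}")
    if dimensions is not None and (
        isinstance(dimensions, bool) or not isinstance(dimensions, int) or dimensions < MIN_DIMENSIONS
    ):
        raise ValueError(f"dimensions must be an integer >= {MIN_DIMENSIONS}")

    params: dict[str, object] = {
        "input_type": input_type,
        "truncate": truncate.upper(),
        "embedding_type": embedding_type,
    }
    # Optional parameters are sent only when set, like model_dump(exclude_none=True).
    if dimensions is not None:
        params["dimensions"] = dimensions
    if nvcf_asset_dir is not None:
        params["nvcf_asset_dir"] = nvcf_asset_dir
    return params
```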

Response

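| Output | Shape | Data Type | Description |
| --- | --- | --- | --- |
| token_count | [batch_size] | INT32 | The number of tokens in each input text. |
| embeddings | [batch_size, embedding_dimension] | Set by the embedding_type request parameter; FLOAT32 by default. | The resulting embedding vectors. For non-float types, the example client reads the outputs named embeddings_int8, embeddings_uint8, embeddings_binary, and embeddings_ubinary. |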
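Given the arrays returned by result.as_numpy() for the embeddings and token_count outputs, the following sketch checks them against the shapes in the response table and sums the per-input token counts. The function name summarize_response is hypothetical, and it assumes token_count has shape [batch_size] as documented.

```python
import numpy as np


def summarize_response(embeddings: np.ndarray, token_count: np.ndarray) -> dict[str, int]:
    """Check output shapes against the response table and return simple totals."""
    if embeddings.ndim != 2:
        raise ValueError("embeddings should have shape [batch_size, embedding_dimension]")
    counts = np.asarray(token_count).reshape(-1)
    if counts.shape[0] != embeddings.shape[0]:
        raise ValueError("token_count should have one entry per input text")
    return {
        "batch_size": int(embeddings.shape[0]),
        "embedding_dimension": int(embeddings.shape[1]),
        # Summing per-input counts gives the total tokens processed for the batch.
        "total_tokens": int(counts.sum()),
    }
```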