Use the API (gRPC) for NVIDIA NeMo Retriever Embedding NIM

Use the examples in this documentation to help you get started using the API for NVIDIA NeMo Retriever Embedding NIM.

For the full API reference, refer to API Reference (gRPC).

gRPC Support

The NeMo Retriever Embedding NIM supports the Open Inference Protocol (KServe V2 Protocol). You can make gRPC inference requests by using the KServe V2 protocol buffers with a gRPC client.

Launch the NeMo Retriever Embedding NIM

Launch the NeMo Retriever Embedding NIM by following the Get Started guide. In the docker run command that launches the NIM, include the additional argument -p 8001:8001 to publish the gRPC port, as shown in the following example.

# Start the NIM
# The second -p flag (8001:8001) is the additional argument: it publishes the
# port used for gRPC requests. The note lives on its own line because a comment
# placed after a trailing backslash ends the command at that line.
docker run -it --rm --name=$CONTAINER_NAME \
  --runtime=nvidia \
  --gpus all \
  --shm-size=16GB \
  -e NGC_API_KEY \
  -v "$LOCAL_NIM_CACHE/cache:/opt/cache" \
  -v "$LOCAL_NIM_CACHE/weights:/model" \
  -u $(id -u) \
  -p 8000:8000 \
  -p 8001:8001 \
  $IMG_NAME

Make Inference Calls

After you launch the NeMo Retriever Embedding NIM, you can make inference calls by using the following code.

Install Python dependencies and generate the KServe V2 gRPC Python stubs.

# Install the gRPC runtime, the protoc-based code generator, and NumPy.
python3 -m pip install grpcio grpcio-tools numpy
# Download the Open Inference Protocol (KServe V2) gRPC service definition.
curl --location --remote-name https://raw.githubusercontent.com/kserve/open-inference-protocol/main/specification/protocol/open_inference_grpc.proto
# Generate the open_inference_grpc_pb2 and open_inference_grpc_pb2_grpc modules in the current directory.
python3 -m grpc_tools.protoc --proto_path=. --python_out=. --grpc_python_out=. open_inference_grpc.proto

Make inference calls. The following example requires Python 3.11 or later because it uses enum.StrEnum.

import os
from dataclasses import dataclass
from enum import StrEnum, auto
from typing import Optional

import grpc
import numpy as np
from numpy.typing import NDArray

import open_inference_grpc_pb2 as oip
import open_inference_grpc_pb2_grpc as oip_grpc

# Lower bound enforced on the optional ``dimensions`` argument of
# ``EmbeddingClient.compute``.
SMALL_EMBEDDING_DIMENSION = 2


# Overall request time budget in seconds, read once at import time from the
# MAXIMUM_REQUEST_TIMEOUT environment variable (720 when it is unset).
REQ_TIMEOUT = int(os.getenv("MAXIMUM_REQUEST_TIMEOUT", "720"))


class InputTypeValue(StrEnum):
    QUERY = auto()
    PASSAGE = auto()


class TruncateValue(StrEnum):
    START = auto()
    END = auto()
    NONE = auto()


class EmbeddingTypeValue(StrEnum):
    FLOAT = auto()
    BINARY = auto()
    UBINARY = auto()
    INT8 = auto()
    UINT8 = auto()


class ModalityValue(StrEnum):
    TEXT = auto()
    IMAGE = auto()
    TEXT_IMAGE = auto()


class InputName(StrEnum):
    TEXT = auto()
    MODALITY = auto()


class OutputName(StrEnum):
    EMBEDDINGS = auto()
    EMBEDDINGS_INT8 = auto()
    EMBEDDINGS_UINT8 = auto()
    EMBEDDINGS_BINARY = auto()
    EMBEDDINGS_UBINARY = auto()
    TOKEN_COUNT = auto()


@dataclass
class EmbeddingResponse:
    """Result returned by ``EmbeddingClient.compute``."""

    # Array decoded from the response output that matches the requested
    # embedding type.
    embeddings: NDArray
    # Sum of every value in the "token_count" output (0 when that output is
    # empty).
    token_count: int


def _string_param(value: StrEnum | str) -> oip.InferParameter:
    """Build an ``InferParameter`` whose ``string_param`` holds *value*.

    Enum members are unwrapped to their underlying string first; plain strings
    are passed through unchanged.
    """
    if isinstance(value, StrEnum):
        text = value.value
    else:
        text = value
    return oip.InferParameter(string_param=text)


def _bytes_tensor(name: str, values: list[str]) -> oip.ModelInferRequest.InferInputTensor:
    """Pack a list of strings into a ``BYTES`` input tensor named *name*.

    Each string is UTF-8 encoded and the tensor is declared with shape
    ``[len(values), 1]``.
    """
    payload = [item.encode("utf-8") for item in values]
    tensor_contents = oip.InferTensorContents(bytes_contents=payload)
    return oip.ModelInferRequest.InferInputTensor(
        name=name,
        datatype="BYTES",
        shape=[len(payload), 1],
        contents=tensor_contents,
    )


def _numpy_dtype(datatype: str) -> np.dtype:
    datatype_map = {
        "FP32": np.float32,
        "FP64": np.float64,
        "INT8": np.int8,
        "INT16": np.int16,
        "INT32": np.int32,
        "INT64": np.int64,
        "UINT8": np.uint8,
        "UINT16": np.uint16,
        "UINT32": np.uint32,
        "UINT64": np.uint64,
    }
    return np.dtype(datatype_map[datatype.upper()])


def _output_to_numpy(response: oip.ModelInferResponse, name: str) -> NDArray:
    """Decode the output tensor called *name* from *response* into an array.

    The first output whose name matches is used. When a non-empty entry exists
    at the same position in ``raw_output_contents`` it is decoded from raw
    bytes; otherwise the values are read from the typed ``contents`` field
    that corresponds to the tensor's datatype. The result is reshaped to the
    tensor's declared shape.

    Raises:
        ValueError: If no output is named *name*, or if the typed-contents
            path meets a datatype it has no field for.
        KeyError: If the raw-bytes path meets a datatype that
            ``_numpy_dtype`` does not know.
    """
    # Typed ``contents`` field that carries the values for each datatype.
    contents_field = {
        "FP32": "fp32_contents",
        "FP64": "fp64_contents",
        "INT8": "int_contents",
        "INT16": "int_contents",
        "INT32": "int_contents",
        "INT64": "int64_contents",
        "UINT8": "uint_contents",
        "UINT16": "uint_contents",
        "UINT32": "uint_contents",
        "UINT64": "uint64_contents",
        "BYTES": "bytes_contents",
    }

    match = next(
        ((pos, tensor) for pos, tensor in enumerate(response.outputs) if tensor.name == name),
        None,
    )
    if match is None:
        raise ValueError(f"'{name}' not found in the response")
    position, tensor = match
    dims = tuple(tensor.shape)

    # Raw bytes are looked up at the same position as the output itself.
    raw_blobs = response.raw_output_contents
    if position < len(raw_blobs) and raw_blobs[position]:
        raw_dtype = _numpy_dtype(tensor.datatype)
        return np.frombuffer(raw_blobs[position], dtype=raw_dtype).reshape(dims)

    kind = tensor.datatype.upper()
    if kind not in contents_field:
        raise ValueError(f"Unsupported output datatype: {tensor.datatype}")

    # BYTES elements stay as Python objects; every other kind maps to a
    # fixed-width NumPy dtype.
    element_type = object if kind == "BYTES" else _numpy_dtype(kind)
    elements = getattr(tensor.contents, contents_field[kind])
    return np.array(elements, dtype=element_type).reshape(dims)


class EmbeddingClient:
    """Minimal asynchronous client that requests embeddings over gRPC.

    Each call to ``compute`` opens an insecure channel to the configured
    endpoint, issues a single ``ModelInfer`` request, and closes the channel.
    """

    def __init__(self, endpoint: str = "localhost:8001") -> None:
        """Remember the ``host:port`` address of the gRPC endpoint."""
        self._endpoint = endpoint

    async def compute(
        self,
        model_name: str,
        text_batch: list[str],
        *,
        input_type: InputTypeValue = InputTypeValue.QUERY,
        modality_batch: Optional[list[ModalityValue | str]] = None,
        truncate: TruncateValue = TruncateValue.NONE,
        dimensions: Optional[int] = None,
        embedding_type: EmbeddingTypeValue | str = EmbeddingTypeValue.FLOAT,
        nvcf_asset_dir: Optional[str] = None,
    ) -> EmbeddingResponse:
        """Embed *text_batch* with *model_name* and return the decoded result.

        Args:
            model_name: Model name placed in the inference request.
            text_batch: Strings sent as the ``text`` input tensor.
            input_type: Sent as the ``input_type`` request parameter.
            modality_batch: Optional per-item values sent as the ``modality``
                input tensor; the tensor is omitted when this is ``None``.
            truncate: Sent as the ``truncate`` request parameter.
            dimensions: Optional value sent as the ``dimensions`` request
                parameter; must be at least ``SMALL_EMBEDDING_DIMENSION``.
            embedding_type: Sent as the ``embedding_type`` request parameter;
                also selects which output tensor is decoded.
            nvcf_asset_dir: Optional value sent as the ``nvcf_asset_dir``
                request parameter.

        Returns:
            An ``EmbeddingResponse`` holding the decoded embeddings and the
            summed token count.

        Raises:
            ValueError: If *dimensions* is below the minimum, or an expected
                output tensor is missing from the response.
            KeyError: If *embedding_type* is not one of the
                ``EmbeddingTypeValue`` values. Raised before any request is
                sent.

        Errors raised by the gRPC call itself are not caught here.
        """
        if dimensions is not None and dimensions < SMALL_EMBEDDING_DIMENSION:
            raise ValueError(f"dimensions must be at least {SMALL_EMBEDDING_DIMENSION}")

        # Resolve the output tensor name before any network I/O so that an
        # unsupported embedding_type fails fast instead of after a complete
        # inference round trip.
        embedding_type_value = (
            embedding_type.value if isinstance(embedding_type, StrEnum) else embedding_type
        )
        output_name = {
            EmbeddingTypeValue.FLOAT.value: OutputName.EMBEDDINGS.value,
            EmbeddingTypeValue.INT8.value: OutputName.EMBEDDINGS_INT8.value,
            EmbeddingTypeValue.UINT8.value: OutputName.EMBEDDINGS_UINT8.value,
            EmbeddingTypeValue.BINARY.value: OutputName.EMBEDDINGS_BINARY.value,
            EmbeddingTypeValue.UBINARY.value: OutputName.EMBEDDINGS_UBINARY.value,
        }[embedding_type_value]

        inputs = [_bytes_tensor(InputName.TEXT.value, text_batch)]
        if modality_batch is not None:
            modality_values = [
                modality.value if isinstance(modality, StrEnum) else modality
                for modality in modality_batch
            ]
            inputs.append(_bytes_tensor(InputName.MODALITY.value, modality_values))

        parameters = {
            "input_type": _string_param(input_type),
            "truncate": _string_param(truncate),
            "embedding_type": _string_param(embedding_type),
        }
        if dimensions is not None:
            parameters["dimensions"] = oip.InferParameter(int64_param=dimensions)
        if nvcf_asset_dir is not None:
            parameters["nvcf_asset_dir"] = oip.InferParameter(string_param=nvcf_asset_dir)

        request = oip.ModelInferRequest(
            model_name=model_name,
            inputs=inputs,
            parameters=parameters,
        )

        async with grpc.aio.insecure_channel(self._endpoint) as channel:
            client = oip_grpc.GRPCInferenceServiceStub(channel)
            # Keep the per-call deadline 5 seconds under REQ_TIMEOUT, but never
            # shorter than 5 seconds.
            response = await client.ModelInfer(request, timeout=max(REQ_TIMEOUT - 5, 5))

        embeddings = _output_to_numpy(response, output_name)
        token_count = _output_to_numpy(response, OutputName.TOKEN_COUNT.value)
        # Collapse the token counts into a single total; an empty output counts
        # as zero.
        total_token_count = int(token_count.sum()) if token_count.size else 0

        return EmbeddingResponse(embeddings=embeddings, token_count=total_token_count)


client = EmbeddingClient()


async def main() -> EmbeddingResponse:
    """Request 5-dimensional embeddings for two short example texts."""
    return await client.compute(
        model_name="nvidia_llama_nemotron_embed_1b_v2",
        text_batch=["hi", "hello world"],
        dimensions=5,
    )


# A bare top-level ``await`` is a SyntaxError when this file runs as a regular
# script, so the coroutine is driven with asyncio.run(). Inside a notebook or
# another already-running event loop, call ``await main()`` instead.
if __name__ == "__main__":
    import asyncio

    print(asyncio.run(main()))

The result should look similar to the following.

EmbeddingResponse(embeddings=array([[-0.5575916 ,  0.31149718,  0.683525  ,  0.11063856, -0.3355797 ],
       [ 0.32987896,  0.1812951 ,  0.8195797 ,  0.38031602,  0.20484307]],
      dtype=float32), token_count=10)