Use the API (gRPC) for NVIDIA NeMo Retriever Reranking NIM#

Use the examples in this documentation to help you get started using the API for NVIDIA NeMo Retriever Reranking NIM.

For the full API reference, refer to API Reference (gRPC).

gRPC Support#

The NeMo Retriever Reranking NIM supports the Open Inference Protocol (KServe V2). You can make gRPC inference requests by using the Triton Client Libraries.
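
For example, you can confirm that the endpoint speaks the protocol by requesting the server metadata, a standard KServe V2 call. The following is a minimal sketch that assumes the gRPC port is exposed at localhost:8001, as shown in the next section.

import tritonclient.grpc as grpcclient

# ServerMetadata is a standard Open Inference Protocol (KServe V2) request.
client = grpcclient.InferenceServerClient(url="localhost:8001")
print(client.get_server_metadata())
client.close()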

Launch the NeMo Retriever Reranking NIM#

Launch NeMo Retriever Reranking NIM by following the Get Started guide. In the launch command, include the additional argument -p 8001:8001, which publishes the gRPC port, as shown in the following example.

# Start the NIM
docker run -it --rm --name=$CONTAINER_NAME \
  --runtime=nvidia \
  --gpus all \
  --shm-size=16GB \
  -e NGC_API_KEY \
  -v "$LOCAL_NIM_CACHE:/opt/nim/.cache" \
  -u $(id -u) \
  -p 8000:8000 \
  -p 8001:8001 \
  $IMG_NAME
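
After the container starts, you can optionally confirm that the gRPC endpoint is ready before sending inference requests. This is a minimal check that assumes the port mappings above and the Triton client installed in the next section.

import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="localhost:8001")
print(client.is_server_live())   # True once the gRPC endpoint accepts connections
print(client.is_server_ready())  # True once the model is loaded and ready to serve
client.close()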

Make Inference Calls#

After you launch the NeMo Retriever Reranking NIM, you can make inference calls by using the following code.

First, install the Python dependencies. The example code also uses pydantic, which is not included with the Triton client, so install both.

python3 -m pip install "tritonclient[all]==2.53.0" pydantic

Next, make inference calls.

# NOTE: enum.StrEnum requires Python 3.11 or later.
from dataclasses import dataclass
from enum import StrEnum, auto
from typing import Optional

import numpy as np
import tritonclient.grpc.aio as grpcclient
from numpy.typing import NDArray
from pydantic import BaseModel, field_validator


# Names of the input tensors expected by the reranking model.
class TritonInputName(StrEnum):
    QUERY = auto()
    PASSAGE = auto()


# Supported values for the truncate request parameter.
class TruncateValue(StrEnum):
    END = auto()
    NONE = auto()


# Optional request parameters sent alongside the inference inputs.
class TritonInputParameters(BaseModel):
    truncate: Optional[TruncateValue] = None

    @field_validator("truncate", mode="before")
    @classmethod
    def value_to_lower(cls, v: str) -> str:
        if isinstance(v, str):
            return v.lower()
        return v


# Names of the output tensors returned by the reranking model.
class TritonOutputName(StrEnum):
    INDEX = auto()
    LOGIT = auto()
    TOKEN_COUNT = auto()


# The parsed reranking response: logit holds one relevance score per passage,
# and index identifies the corresponding passage in the request.
@dataclass
class TritonRankingResponse:
    index: NDArray[np.int32]
    logit: NDArray[np.float32]
    token_count: NDArray[np.int32]


class TritonComputeClient:
    def __init__(self, endpoint: str = "localhost:8001", verbose: bool = False) -> None:
        self._endpoint = endpoint
        self._verbose = verbose

    async def compute(
        self,
        model_name: str,
        query: str,
        passages: list[str],
        truncate: TruncateValue = TruncateValue.NONE,
    ) -> TritonRankingResponse:
        query_np = np.array([query.encode("utf-8")] * len(passages), dtype=object)
        passage_np = np.array([passage.encode("utf-8") for passage in passages], dtype=object)
        query_in = grpcclient.InferInput(
            TritonInputName.QUERY.value, shape=query_np.shape, datatype="BYTES"
        )
        passage_in = grpcclient.InferInput(
            TritonInputName.PASSAGE.value, shape=passage_np.shape, datatype="BYTES"
        )
        query_in.set_data_from_numpy(query_np)
        passage_in.set_data_from_numpy(passage_np)

        # InferenceServerClient is not thread safe
        async with grpcclient.InferenceServerClient(
            url=self._endpoint, verbose=self._verbose
        ) as client:
            result = await client.infer(
                model_name=model_name,
                inputs=[query_in, passage_in],
                parameters=TritonInputParameters(truncate=truncate).model_dump(exclude_none=True),
            )

            result_np = {ee.value: result.as_numpy(ee.value) for ee in TritonOutputName}
            for kk, vv in result_np.items():
                if vv is None:
                    raise ValueError(f"'{kk}' not found in the response")
            return TritonRankingResponse(**result_np)


import asyncio

async def main() -> None:
    client = TritonComputeClient()
    response = await client.compute(
        model_name="nvidia_llama_nemotron_rerank_1b_v2",
        query="hello",
        passages=["world"],
        truncate=TruncateValue.NONE,
    )
    print(response)

asyncio.run(main())

The result should look similar to the following.

TritonRankingResponse(index=array([0], dtype=int32), logit=array([-4.3242188], dtype=float32), token_count=array([8], dtype=int32))
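
In the response, logit holds one relevance score per passage, and index identifies the corresponding passage in the request, so you can rank several passages by sorting the logits in descending order. The following sketch reuses the client defined above; the query and passages are illustrative.

import asyncio

import numpy as np

async def rank_passages() -> None:
    client = TritonComputeClient()
    passages = [
        "Paris is the capital and largest city of France.",
        "Berlin is the capital of Germany.",
        "The Eiffel Tower is in Paris.",
    ]
    # truncate=END truncates over-long inputs instead of rejecting them.
    response = await client.compute(
        model_name="nvidia_llama_nemotron_rerank_1b_v2",
        query="What is the capital of France?",
        passages=passages,
        truncate=TruncateValue.END,
    )
    # Sort descending by logit and map each score back to its passage.
    for rank, ii in enumerate(np.argsort(response.logit)[::-1], start=1):
        print(f"{rank}. logit={response.logit[ii]:.4f}  {passages[response.index[ii]]}")

asyncio.run(rank_passages())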