gRPC Reference for NeMo Retriever Text Reranking NIM#

This documentation contains the gRPC reference for NeMo Retriever Text Reranking NIM.

The Text Reranking NIM supports the Open Inference Protocol (KServe V2). You can make gRPC inference requests by using the Triton Client Libraries.

Launch the Text Reranking NIM#

Launch the Text Reranking NIM by following the Get Started guide. In the code to launch the NIM, include the additional argument -p 8001:8001, which publishes the gRPC port, as shown in the following example.

# Start the NIM
docker run -it --rm --name=$CONTAINER_NAME \
  --runtime=nvidia \
  --gpus all \
  --shm-size=16GB \
  -e NGC_API_KEY \
  -v "$LOCAL_NIM_CACHE:/opt/nim/.cache" \
  -u $(id -u) \
  -p 8000:8000 \
  -p 8001:8001 \
  $IMG_NAME

Make Inference Calls#

After you launch the Text Reranking NIM, you can make inference calls by using the following code.

First, install Python dependencies.

python3 -m pip install "tritonclient[all]==2.53.0"
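Optionally, before making inference calls, you can confirm that the gRPC endpoint is reachable. The following is a minimal sketch that assumes the NIM is running on the default gRPC port 8001 and serves the model used in the example below.

import asyncio

import tritonclient.grpc.aio as grpcclient


async def check_ready(endpoint: str = "localhost:8001") -> None:
    # Open a gRPC connection and issue KServe V2 readiness checks.
    async with grpcclient.InferenceServerClient(url=endpoint) as client:
        print("server ready:", await client.is_server_ready())
        print("model ready:", await client.is_model_ready("nvidia_llama_3_2_nv_rerankqa_1b_v2"))


asyncio.run(check_ready())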

Next, make inference calls. The following example requires Python 3.11 or later because it uses enum.StrEnum.

from dataclasses import dataclass
from enum import StrEnum, auto
from typing import Optional

import numpy as np
import tritonclient.grpc.aio as grpcclient
from numpy.typing import NDArray
from pydantic import BaseModel, field_validator


# Input tensor names; with StrEnum, auto() yields the lowercase member name.
class TritonInputName(StrEnum):
    QUERY = auto()
    PASSAGE = auto()


# Valid values for the truncate request parameter.
class TruncateValue(StrEnum):
    END = auto()
    NONE = auto()


# Optional request parameters; serialized and sent with the inference request.
class TritonInputParameters(BaseModel):
    truncate: Optional[TruncateValue] = None

    @field_validator("truncate", mode="before")
    @classmethod
    def value_to_lower(cls, v: str) -> str:
        if isinstance(v, str):
            return v.lower()
        return v


# Output tensor names returned in the inference response.
class TritonOutputName(StrEnum):
    INDEX = auto()
    LOGIT = auto()
    TOKEN_COUNT = auto()


# Typed container for the three model outputs; see the Response table for shapes.
@dataclass
class TritonRankingResponse:
    index: NDArray[np.int32]
    logit: NDArray[np.float32]
    token_count: NDArray[np.int32]


# Thin async wrapper around the Triton gRPC client for reranking requests.
class TritonComputeClient:
    def __init__(self, endpoint: str = "localhost:8001", verbose: bool = False) -> None:
        self._endpoint = endpoint
        self._verbose = verbose

    async def compute(
        self,
        model_name: str,
        query: str,
        passages: list[str],
        truncate: TruncateValue = TruncateValue.NONE,
    ) -> TritonRankingResponse:
        # Repeat the query once per passage so both input tensors share the same batch size.
        query_np = np.array([query.encode("utf-8")] * len(passages), dtype=object)
        passage_np = np.array([passage.encode("utf-8") for passage in passages], dtype=object)
        query_in = grpcclient.InferInput(
            TritonInputName.QUERY.value, shape=query_np.shape, datatype="BYTES"
        )
        passage_in = grpcclient.InferInput(
            TritonInputName.PASSAGE.value, shape=passage_np.shape, datatype="BYTES"
        )
        query_in.set_data_from_numpy(query_np)
        passage_in.set_data_from_numpy(passage_np)

        # InferenceServerClient is not thread safe
        async with grpcclient.InferenceServerClient(
            url=self._endpoint, verbose=self._verbose
        ) as client:
            result = await client.infer(
                model_name=model_name,
                inputs=[query_in, passage_in],
                parameters=TritonInputParameters(truncate=truncate).model_dump(exclude_none=True),
            )

            result_np = {ee.value: result.as_numpy(ee.value) for ee in TritonOutputName}
            for kk, vv in result_np.items():
                if vv is None:
                    raise ValueError(f"'{kk}' not found in the response")
            return TritonRankingResponse(**result_np)


client = TritonComputeClient()
await client.compute(
    model_name="nvidia_llama_3_2_nv_rerankqa_1b_v2",
    query="hello",
    passages=["world"],
    truncate=TruncateValue.NONE,
)
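The trailing call uses top-level await, which works in an async environment such as a Jupyter notebook. To run the example as a plain script instead, wrap the call in asyncio.run, as in this sketch:

import asyncio


async def main() -> None:
    client = TritonComputeClient()
    response = await client.compute(
        model_name="nvidia_llama_3_2_nv_rerankqa_1b_v2",
        query="hello",
        passages=["world"],
        truncate=TruncateValue.NONE,
    )
    print(response)


asyncio.run(main())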

The result should look similar to the following.

TritonRankingResponse(index=array([0], dtype=int32), logit=array([-4.3242188], dtype=float32), token_count=array([8], dtype=int32))
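Because index lists the passage positions in descending order of relevance and logit is aligned with it (see the Response table), you can pair the two arrays to recover a ranked result list. A minimal sketch, assuming response holds the TritonRankingResponse from the script sketch above and passages holds the original input list:

# ranked[0] is the most relevant passage; ravel() tolerates [batch_size, 1] shapes.
ranked = [
    (int(idx), passages[int(idx)], float(score))
    for idx, score in zip(response.index.ravel(), response.logit.ravel())
]
for idx, text, score in ranked:
    print(f"passage {idx}: logit={score:.4f} text={text!r}")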

API Reference#

gRPC Models#

The gRPC model names differ from the NIM model IDs that appear in the Support Matrix. The following table contains the mapping of the names.

| Model ID | gRPC Model Name |
|----------|-----------------|
| nvidia/llama-3-2-nv-rerankqa-1b-v2 | nvidia_llama_3_2_nv_rerankqa_1b_v2 |
| nvidia/llama-3-2-nemoretriever-500m-rerank-v2 | nvidia_llama_3_2_nemoretriever_500m_rerank_v2 |
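If you are not sure which gRPC model name a running NIM serves, you can query the model repository at run time. A sketch, again assuming the default endpoint:

import asyncio

import tritonclient.grpc.aio as grpcclient


async def list_models(endpoint: str = "localhost:8001") -> None:
    async with grpcclient.InferenceServerClient(url=endpoint) as client:
        # The KServe V2 repository index lists each model with its readiness state.
        index = await client.get_model_repository_index()
        for model in index.models:
            print(model.name, model.state)


asyncio.run(list_models())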

Request Inputs#

| Input | Shape | Data Type | Description |
|-------|-------|-----------|-------------|
| query | [batch_size, 1] | BYTES | The query for reranking, repeated once per passage and encoded as UTF-8. |
| passage | [batch_size, 1] | BYTES | The passages for reranking, encoded as UTF-8. |

Request Parameters#

| Parameter | Data Type | Description | Valid Values | Default | Required |
|-----------|-----------|-------------|--------------|---------|----------|
| truncate | String | How to handle text that exceeds the maximum token length. | "END", "NONE" | "NONE" | No |
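For example, to truncate over-length text instead of receiving an error, pass truncate when you call the compute method from the earlier example (in an async context):

# END truncates input that exceeds the model's maximum token length.
response = await client.compute(
    model_name="nvidia_llama_3_2_nv_rerankqa_1b_v2",
    query="hello",
    passages=["a passage that might exceed the maximum token length ..."],
    truncate=TruncateValue.END,
)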

Response#

| Output | Shape | Data Type | Description |
|--------|-------|-----------|-------------|
| index | [batch_size, 1] | INT32 | The indices of the input passages, sorted in descending order of relevance. |
| logit | [batch_size, 1] | FP32 | The relevance logit for each passage, aligned with index. |
| token_count | [batch_size] | INT32 | The number of tokens in each input text. |
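The logit values are unbounded; higher means more relevant. If you want scores normalized to the (0, 1) range, a common post-processing step, which is an assumption here rather than part of the NIM response, is to apply a sigmoid, which preserves the ranking order:

import numpy as np


def to_scores(logits: np.ndarray) -> np.ndarray:
    # The sigmoid maps unbounded logits to (0, 1) without changing their order.
    return 1.0 / (1.0 + np.exp(-logits))


print(to_scores(response.logit))  # response from the earlier inference example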