gRPC Reference for NeMo Retriever Text Reranking NIM#
This documentation contains the gRPC reference for NeMo Retriever Text Reranking NIM.
The Text Reranking NIM supports the Open Inference Protocol (KServe V2). You can make gRPC inference requests by using the Triton Client Libraries.
Launch the Text Reranking NIM#
Launch the Text Reranking NIM by following the Get Started guide.
In the launch command, include the additional argument -p 8001:8001
to expose the gRPC port, as shown in the following example.
# Start the NIM
docker run -it --rm --name=$CONTAINER_NAME \
--runtime=nvidia \
--gpus all \
--shm-size=16GB \
-e NGC_API_KEY \
-v "$LOCAL_NIM_CACHE:/opt/nim/.cache" \
-u $(id -u) \
-p 8000:8000 \
-p 8001:8001 \
$IMG_NAME
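After the container starts, you can verify that the gRPC endpoint is reachable before making inference calls. The following is a minimal sketch that uses the synchronous Triton client; the localhost:8001 endpoint is an assumption based on the port mapping above.

import tritonclient.grpc as grpcclient

# Assumes the -p 8001:8001 mapping from the launch command above.
client = grpcclient.InferenceServerClient(url="localhost:8001")
print(client.is_server_ready())  # prints True once the server is ready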
Make Inference Calls#
After you launch the Text Reranking NIM, you can make inference calls by using the following code.
First, install the Python dependencies.
python3 -m pip install "tritonclient[all]==2.53.0"
Next, make inference calls. The following example uses enum.StrEnum, so it requires Python 3.11 or later, and the top-level await assumes an async context such as a Jupyter notebook.
from dataclasses import dataclass
from enum import StrEnum, auto
from typing import Optional

import numpy as np
import tritonclient.grpc.aio as grpcclient
from numpy.typing import NDArray
from pydantic import BaseModel, field_validator


class TritonInputName(StrEnum):
    """Names of the model's input tensors ("query" and "passage")."""

    QUERY = auto()
    PASSAGE = auto()


class TruncateValue(StrEnum):
    """Valid values for the truncate request parameter."""

    END = auto()
    NONE = auto()


class TritonInputParameters(BaseModel):
    truncate: Optional[TruncateValue] = None

    @field_validator("truncate", mode="before")
    @classmethod
    def value_to_lower(cls, v: str) -> str:
        # Accept "END"/"NONE" as well as "end"/"none".
        if isinstance(v, str):
            return v.lower()
        return v


class TritonOutputName(StrEnum):
    """Names of the model's output tensors."""

    INDEX = auto()
    LOGIT = auto()
    TOKEN_COUNT = auto()


@dataclass
class TritonRankingResponse:
    index: NDArray[np.int32]
    logit: NDArray[np.float32]
    token_count: NDArray[np.int32]


class TritonComputeClient:
    def __init__(self, endpoint: str = "localhost:8001", verbose: bool = False) -> None:
        self._endpoint = endpoint
        self._verbose = verbose

    async def compute(
        self,
        model_name: str,
        query: str,
        passages: list[str],
        truncate: TruncateValue = TruncateValue.NONE,
    ) -> TritonRankingResponse:
        # Repeat the query once per passage so that both inputs share the
        # same batch dimension.
        query_np = np.array([query.encode("utf-8")] * len(passages), dtype=object)
        passage_np = np.array([passage.encode("utf-8") for passage in passages], dtype=object)

        query_in = grpcclient.InferInput(
            TritonInputName.QUERY.value, shape=query_np.shape, datatype="BYTES"
        )
        passage_in = grpcclient.InferInput(
            TritonInputName.PASSAGE.value, shape=passage_np.shape, datatype="BYTES"
        )
        query_in.set_data_from_numpy(query_np)
        passage_in.set_data_from_numpy(passage_np)

        # InferenceServerClient is not thread safe, so create one per call.
        async with grpcclient.InferenceServerClient(
            url=self._endpoint, verbose=self._verbose
        ) as client:
            result = await client.infer(
                model_name=model_name,
                inputs=[query_in, passage_in],
                parameters=TritonInputParameters(truncate=truncate).model_dump(exclude_none=True),
            )

        result_np = {ee.value: result.as_numpy(ee.value) for ee in TritonOutputName}
        for kk, vv in result_np.items():
            if vv is None:
                raise ValueError(f"'{kk}' not found in the response")
        return TritonRankingResponse(**result_np)


# Run in an async context such as a Jupyter notebook; in a plain script,
# wrap the call in asyncio.run() instead.
client = TritonComputeClient()
await client.compute(
    model_name="nvidia_llama_3_2_nv_rerankqa_1b_v2",
    query="hello",
    passages=["world"],
    truncate="none",
)
The result should look similar to the following.
TritonRankingResponse(index=array([0], dtype=int32), logit=array([-4.3242188], dtype=float32), token_count=array([8], dtype=int32))
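Each ranked result refers back to an input passage by position. The following is a minimal sketch of reading the ranking, assuming you bind the result of the compute call above to a variable.

passages = ["world"]
response = await client.compute(
    model_name="nvidia_llama_3_2_nv_rerankqa_1b_v2",
    query="hello",
    passages=passages,
)
# index lists passage positions in descending order of relevance and logit
# holds the corresponding scores (see the Response table below).
for rank, (idx, score) in enumerate(zip(response.index.ravel(), response.logit.ravel()), start=1):
    print(f"{rank}. passages[{idx}] logit={score:.4f}: {passages[idx]}")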
API Reference#
gRPC Models#
The gRPC model names differ from the NIM model IDs that appear in the Support Matrix. The following table maps the model IDs to the gRPC model names.
| Model ID | gRPC Model Name |
|---|---|
| nvidia/llama-3-2-nv-rerankqa-1b-v2 | nvidia_llama_3_2_nv_rerankqa_1b_v2 |
| nvidia/llama-3-2-nemoretriever-500m-rerank-v2 | nvidia_llama_3_2_nemoretriever_500m_rerank_v2 |
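When you make gRPC requests, use the gRPC model name rather than the model ID. As a quick sketch, you can confirm that a model is loaded under its gRPC name, assuming the endpoint from the launch example.

import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="localhost:8001")
# Pass the gRPC model name from the table above, not the model ID.
print(client.is_model_ready("nvidia_llama_3_2_nv_rerankqa_1b_v2"))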
Request Inputs#
| Input | Shape | Data Type | Description |
|---|---|---|---|
| query | [batch_size, 1] | BYTES | The queries for reranking, encoded as UTF-8. |
| passage | [batch_size, 1] | BYTES | The passages for reranking, encoded as UTF-8. |
Request Parameters#
| Parameter | Data Type | Description | Valid Values | Default | Required |
|---|---|---|---|---|---|
| truncate | String | How to handle text that exceeds the maximum token length. | END, NONE | NONE | No |
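The example client passes this parameter through the parameters argument of infer, serialized by the TritonInputParameters model. As a sketch, reusing the TritonComputeClient and TruncateValue from the inference example above, you can request end truncation for over-long text as follows.

# Reuses client and TruncateValue from the inference example above.
response = await client.compute(
    model_name="nvidia_llama_3_2_nv_rerankqa_1b_v2",
    query="hello",
    passages=["a passage that might exceed the maximum token length"],
    truncate=TruncateValue.END,
)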
Response#
| Output | Shape | Data Type | Description |
|---|---|---|---|
| index | [batch_size, 1] | INT32 | The indices of the passages, in descending order of relevance. |
| logit | [batch_size, 1] | FP32 | The relevance logit of each passage. |
| token_count | [batch_size] | INT32 | The number of tokens in each input text. |
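The logit output is an unbounded raw score. If you want a score in the range (0, 1), a common post-processing step, shown here as an assumption rather than part of the gRPC response, is to apply a sigmoid.

import numpy as np

# Post-processing sketch (an assumption, not part of the NIM API): squash
# the raw logits from a TritonRankingResponse into (0, 1).
probabilities = 1.0 / (1.0 + np.exp(-response.logit))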