NimClient Usage Guide for NeMo Retriever Extraction
The NimClient class provides a unified interface for connecting to and interacting with NVIDIA NIM microservices.
This documentation demonstrates how to create custom NIM integrations for use in NeMo Retriever extraction pipelines and user-defined functions (UDFs).
Note
NeMo Retriever extraction is also known as NVIDIA Ingest and nv-ingest.
The NimClient architecture consists of two main components:
- NimClient: The client class that handles communication with NIM endpoints via gRPC or HTTP protocols
- ModelInterface: An abstract base class that defines how to format input data, parse output responses, and process inference results for specific models
For advanced usage patterns, see the existing model interfaces in api/src/nv_ingest_api/internal/primitives/nim/model_interface/.
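The interface methods map directly onto the inference flow: prepare and validate input, format it for the chosen protocol, parse the raw response, and post-process the result. The following is a signature-only sketch of that contract; the full template with docstrings appears under Creating Custom Model Interfaces below.
from typing import Any, Dict, List, Optional, Tuple
from nv_ingest_api.internal.primitives.nim import ModelInterface

class ContractSketch(ModelInterface):
    # Human-readable name used in logs and error messages
    def name(self) -> str: ...

    # Validate and normalize raw input before formatting
    def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]: ...

    # Build protocol-specific request batches (gRPC arrays or HTTP payloads)
    def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs) -> Tuple[List[Any], List[Dict[str, Any]]]: ...

    # Decode the raw response returned by the NIM
    def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any: ...

    # Apply final post-processing before results reach the caller
    def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any: ...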
Quick Start
Basic NimClient Creation
from nv_ingest_api.util.nim import create_inference_client
from nv_ingest_api.internal.primitives.nim import ModelInterface
# Create a custom model interface (see examples below)
model_interface = MyCustomModelInterface()
# Define endpoints (gRPC, HTTP)
endpoints = ("grpc://my-nim-service:8001", "http://my-nim-service:8000")
# Create the client
client = create_inference_client(
endpoints=endpoints,
model_interface=model_interface,
auth_token="your-ngc-api-key", # Optional
infer_protocol="grpc", # Optional: "grpc" or "http"
timeout=120.0, # Optional: request timeout
max_retries=5 # Optional: retry attempts
)
# Perform inference
data = {"input": "your input data"}
results = client.infer(data, model_name="your-model-name")
Using Environment Variables
import os
from nv_ingest_api.util.nim import create_inference_client
# Use environment variables for configuration
auth_token = os.getenv("NGC_API_KEY")
grpc_endpoint = os.getenv("NIM_GRPC_ENDPOINT", "grpc://localhost:8001")
http_endpoint = os.getenv("NIM_HTTP_ENDPOINT", "http://localhost:8000")
client = create_inference_client(
endpoints=(grpc_endpoint, http_endpoint),
model_interface=model_interface,
auth_token=auth_token
)
Creating Custom Model Interfaces
To integrate a new NIM, you need to create a custom ModelInterface subclass that implements the required methods.
Basic Model Interface Template
from typing import Dict, Any, List, Tuple, Optional
import numpy as np
from nv_ingest_api.internal.primitives.nim import ModelInterface
class MyCustomModelInterface(ModelInterface):
"""
Custom model interface for My Custom NIM.
"""
def __init__(self, model_name: str = "my-custom-model"):
"""Initialize the model interface."""
self.model_name = model_name
def name(self) -> str:
"""Return the name of this model interface."""
return "MyCustomModel"
def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Prepare and validate input data before formatting.
Parameters
----------
data : dict
Raw input data
Returns
-------
dict
Validated and prepared data
"""
# Validate required fields
if "input_text" not in data:
raise KeyError("Input data must include 'input_text'")
# Ensure input is in the expected format
if not isinstance(data["input_text"], str):
raise ValueError("input_text must be a string")
return data
def format_input(
self,
data: Dict[str, Any],
protocol: str,
max_batch_size: int,
**kwargs
) -> Tuple[List[Any], List[Dict[str, Any]]]:
"""
Format input data for the specified protocol.
Parameters
----------
data : dict
Prepared input data
protocol : str
Communication protocol ("grpc" or "http")
max_batch_size : int
Maximum batch size for processing
**kwargs : dict
Additional parameters
Returns
-------
tuple
(formatted_batches, batch_data_list)
"""
if protocol == "http":
return self._format_http_input(data, max_batch_size, **kwargs)
elif protocol == "grpc":
return self._format_grpc_input(data, max_batch_size, **kwargs)
else:
raise ValueError("Invalid protocol. Must be 'grpc' or 'http'")
def _format_http_input(
self,
data: Dict[str, Any],
max_batch_size: int,
**kwargs
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""Format input for HTTP protocol."""
input_text = data["input_text"]
# Create HTTP payload
payload = {
"model": kwargs.get("model_name", self.model_name),
"input": input_text,
"max_tokens": kwargs.get("max_tokens", 512),
"temperature": kwargs.get("temperature", 0.7),
}
# Return as single batch
return [payload], [{"original_input": input_text}]
def _format_grpc_input(
self,
data: Dict[str, Any],
max_batch_size: int,
**kwargs
) -> Tuple[List[np.ndarray], List[Dict[str, Any]]]:
"""Format input for gRPC protocol."""
input_text = data["input_text"]
# Convert to numpy array for gRPC
text_array = np.array([[input_text.encode("utf-8")]], dtype=np.object_)
return [text_array], [{"original_input": input_text}]
def parse_output(
self,
response: Any,
protocol: str,
data: Optional[Dict[str, Any]] = None,
**kwargs
) -> Any:
"""
Parse the raw model response.
Parameters
----------
response : Any
Raw response from the model
protocol : str
Communication protocol used
data : dict, optional
Original batch data
**kwargs : dict
Additional parameters
Returns
-------
Any
Parsed response data
"""
if protocol == "http":
return self._parse_http_response(response)
elif protocol == "grpc":
return self._parse_grpc_response(response)
else:
raise ValueError("Invalid protocol. Must be 'grpc' or 'http'")
def _parse_http_response(self, response: Dict[str, Any]) -> str:
"""Parse HTTP response."""
if isinstance(response, dict):
# Extract the generated text from response
if "choices" in response:
return response["choices"][0].get("text", "")
elif "output" in response:
return response["output"]
else:
raise RuntimeError("Unexpected response format")
return str(response)
def _parse_grpc_response(self, response: np.ndarray) -> str:
"""Parse gRPC response."""
if isinstance(response, np.ndarray):
# Decode bytes response
return response.flatten()[0].decode("utf-8")
return str(response)
def process_inference_results(
self,
output: Any,
protocol: str,
**kwargs
) -> Any:
"""
Post-process the parsed inference results.
Parameters
----------
output : Any
Parsed output from parse_output
protocol : str
Communication protocol used
**kwargs : dict
Additional parameters
Returns
-------
Any
Final processed results
"""
# Apply any final processing (e.g., filtering, formatting)
if isinstance(output, str):
return output.strip()
return output
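With the interface defined, creating a client and running inference follows the same pattern as the Quick Start. A minimal usage sketch; the endpoint URLs and model name are placeholders for your own deployment.
from nv_ingest_api.util.nim import create_inference_client

# Instantiate the custom interface defined above
model_interface = MyCustomModelInterface(model_name="my-custom-model")

# Placeholder endpoints; point these at your own NIM deployment
client = create_inference_client(
    endpoints=("grpc://my-nim-service:8001", "http://my-nim-service:8000"),
    model_interface=model_interface,
    infer_protocol="http",
)

# prepare_data_for_inference validates 'input_text', format_input builds the HTTP
# payload, and parse_output/process_inference_results clean up the response
result = client.infer(
    data={"input_text": "Summarize the key points of this document."},
    model_name="my-custom-model",
    max_tokens=256,
)
print(result)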
Real-World Examples
Text Generation Model Interface
class TextGenerationModelInterface(ModelInterface):
"""Interface for text generation NIMs (e.g., LLaMA, GPT-style models)."""
def name(self) -> str:
return "TextGeneration"
def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
if "prompt" not in data:
raise KeyError("Input data must include 'prompt'")
return data
def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs):
prompt = data["prompt"]
if protocol == "http":
payload = {
"model": kwargs.get("model_name", "llama-2-7b-chat"),
"messages": [{"role": "user", "content": prompt}],
"max_tokens": kwargs.get("max_tokens", 512),
"temperature": kwargs.get("temperature", 0.7),
"top_p": kwargs.get("top_p", 0.9),
"stream": False
}
return [payload], [{"prompt": prompt}]
else:
raise ValueError("Only HTTP protocol supported for this model")
def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs):
if protocol == "http" and isinstance(response, dict):
choices = response.get("choices", [])
if choices:
return choices[0].get("message", {}).get("content", "")
return str(response)
def process_inference_results(self, output: Any, protocol: str, **kwargs):
return output.strip() if isinstance(output, str) else output
Image Analysis Model Interface
import base64
from nv_ingest_api.util.image_processing.transforms import numpy_to_base64
class ImageAnalysisModelInterface(ModelInterface):
"""Interface for image analysis NIMs (e.g., vision models)."""
def name(self) -> str:
return "ImageAnalysis"
def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
if "images" not in data:
raise KeyError("Input data must include 'images'")
# Ensure images is a list
if not isinstance(data["images"], list):
data["images"] = [data["images"]]
return data
def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs):
images = data["images"]
prompt = data.get("prompt", "Describe this image.")
# Convert images to base64 if needed
base64_images = []
for img in images:
if isinstance(img, np.ndarray):
base64_images.append(numpy_to_base64(img))
elif isinstance(img, str) and img.startswith("data:image"):
# Already base64 encoded
base64_images.append(img.split(",")[1])
else:
base64_images.append(str(img))
# Batch images
batches = [base64_images[i:i + max_batch_size]
for i in range(0, len(base64_images), max_batch_size)]
payloads = []
batch_data_list = []
# Fail loudly for unsupported protocols instead of returning empty batches
if protocol != "http":
raise ValueError("Only HTTP protocol supported for this model")
for batch in batches:
messages = []
for img_b64 in batch:
messages.append({
"role": "user",
"content": f'{prompt} <img src="data:image/png;base64,{img_b64}" />'
})
payload = {
"model": kwargs.get("model_name", "llava-1.5-7b-hf"),
"messages": messages,
"max_tokens": kwargs.get("max_tokens", 512),
"temperature": kwargs.get("temperature", 0.1)
}
payloads.append(payload)
batch_data_list.append({"images": batch, "prompt": prompt})
return payloads, batch_data_list
def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs):
if protocol == "http" and isinstance(response, dict):
choices = response.get("choices", [])
return [choice.get("message", {}).get("content", "") for choice in choices]
return [str(response)]
def process_inference_results(self, output: Any, protocol: str, **kwargs):
if isinstance(output, list):
return [result.strip() for result in output]
return output
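A brief usage sketch for the image interface, assuming the images are available as NumPy arrays; the endpoints and model name are placeholders.
import os

import numpy as np

from nv_ingest_api.util.nim import create_inference_client

model_interface = ImageAnalysisModelInterface()
client = create_inference_client(
    endpoints=("grpc://vision-nim:8001", "http://vision-nim:8000"),
    model_interface=model_interface,
    auth_token=os.getenv("NGC_API_KEY"),
    infer_protocol="http",
)

# Placeholder image; in a pipeline these typically come from extracted page or figure images
image = np.zeros((224, 224, 3), dtype=np.uint8)

descriptions = client.infer(
    data={"images": [image], "prompt": "Describe this image."},
    model_name="llava-1.5-7b-hf",
    max_tokens=200,
)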
Using NimClient in UDFs
Basic UDF with NimClient
from nv_ingest_api.internal.primitives.control_message import IngestControlMessage
from nv_ingest_api.util.nim import create_inference_client
import os
def analyze_document_with_nim(control_message: IngestControlMessage) -> IngestControlMessage:
"""UDF that uses a custom NIM to analyze document content."""
# Create NIM client
model_interface = TextGenerationModelInterface()
client = create_inference_client(
endpoints=(
os.getenv("ANALYSIS_NIM_GRPC", "grpc://analysis-nim:8001"),
os.getenv("ANALYSIS_NIM_HTTP", "http://analysis-nim:8000")
),
model_interface=model_interface,
auth_token=os.getenv("NGC_API_KEY"),
infer_protocol="http"
)
# Get the document DataFrame
df = control_message.get_payload()
# Process each document
for idx, row in df.iterrows():
if row.get("content"):
# Prepare analysis prompt
prompt = f"Analyze the following document content and provide a summary: {row['content'][:1000]}"
# Perform inference
try:
results = client.infer(
data={"prompt": prompt},
model_name="llama-2-7b-chat",
max_tokens=256,
temperature=0.3
)
# Add analysis to metadata
if results:
analysis = results[0] if isinstance(results, list) else results
df.at[idx, "custom_analysis"] = analysis
except Exception as e:
print(f"NIM inference failed: {e}")
df.at[idx, "custom_analysis"] = "Analysis failed"
# Update the control message with processed data
control_message.payload(df)
return control_message
Advanced UDF with Batching
def batch_image_analysis_udf(control_message: IngestControlMessage) -> IngestControlMessage:
"""UDF that performs batched image analysis using NIM."""
# Create image analysis client
model_interface = ImageAnalysisModelInterface()
client = create_inference_client(
endpoints=(
os.getenv("VISION_NIM_GRPC", "grpc://vision-nim:8001"),
os.getenv("VISION_NIM_HTTP", "http://vision-nim:8000")
),
model_interface=model_interface,
auth_token=os.getenv("NGC_API_KEY")
)
df = control_message.get_payload()
# Collect all images for batch processing
image_rows = []
images = []
for idx, row in df.iterrows():
if "image_data" in row and row["image_data"]:
image_rows.append(idx)
images.append(row["image_data"])
if images:
try:
# Batch process all images
results = client.infer(
data={
"images": images,
"prompt": "Describe the content and key elements in this image."
},
model_name="llava-1.5-7b-hf",
max_tokens=200
)
# Apply results back to DataFrame
for idx, result in zip(image_rows, results):
df.at[idx, "image_description"] = result
except Exception as e:
print(f"Batch image analysis failed: {e}")
for idx in image_rows:
df.at[idx, "image_description"] = "Analysis failed"
control_message.payload(df)
return control_message
Configuration and Best Practices
Environment Variables
Set these environment variables for your NIM endpoints:
# NIM endpoints
export MY_NIM_GRPC_ENDPOINT="grpc://my-nim-service:8001"
export MY_NIM_HTTP_ENDPOINT="http://my-nim-service:8000"
# Authentication
export NGC_API_KEY="your-ngc-api-key"
# Optional: timeouts and retries
export NIM_TIMEOUT=120
export NIM_MAX_RETRIES=5
Performance Optimization
- Use gRPC when possible: Generally faster than HTTP for high-throughput scenarios
- Batch processing: Process multiple items together to reduce overhead
- Connection reuse: Create NimClient instances once and reuse them (see the sketch after this list)
- Appropriate timeouts: Set reasonable timeouts based on your model's response time
- Error handling: Always handle inference failures gracefully
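For example, a UDF can cache a single client at module scope and reuse it across invocations rather than recreating it on every call, as the UDF examples above do for brevity. A minimal sketch, assuming the same environment variables and TextGenerationModelInterface shown earlier:
import os

from nv_ingest_api.util.nim import create_inference_client

_analysis_client = None  # cached at module scope so repeated UDF calls reuse the connection


def get_analysis_client():
    """Create the NIM client on first use and reuse it afterwards."""
    global _analysis_client
    if _analysis_client is None:
        _analysis_client = create_inference_client(
            endpoints=(
                os.getenv("ANALYSIS_NIM_GRPC", "grpc://analysis-nim:8001"),
                os.getenv("ANALYSIS_NIM_HTTP", "http://analysis-nim:8000"),
            ),
            model_interface=TextGenerationModelInterface(),
            auth_token=os.getenv("NGC_API_KEY"),
            infer_protocol="http",
        )
    return _analysis_client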
Error Handling
def robust_nim_udf(control_message: IngestControlMessage) -> IngestControlMessage:
"""UDF with comprehensive error handling."""
# Resolve configuration from the environment, as in the earlier examples
grpc_endpoint = os.getenv("MY_NIM_GRPC_ENDPOINT", "grpc://my-nim-service:8001")
http_endpoint = os.getenv("MY_NIM_HTTP_ENDPOINT", "http://my-nim-service:8000")
auth_token = os.getenv("NGC_API_KEY")
model_interface = TextGenerationModelInterface()
try:
client = create_inference_client(
endpoints=(grpc_endpoint, http_endpoint),
model_interface=model_interface,
auth_token=auth_token,
timeout=60.0,
max_retries=3
)
except Exception as e:
print(f"Failed to create NIM client: {e}")
return control_message
df = control_message.get_payload()
for idx, row in df.iterrows():
try:
# Build the per-row input expected by the model interface
input_data = {"prompt": row.get("content", "")}
results = client.infer(data=input_data, model_name="my-model")
df.at[idx, "nim_result"] = results
except TimeoutError:
print(f"NIM request timed out for row {idx}")
df.at[idx, "nim_result"] = "timeout"
except Exception as e:
print(f"NIM inference failed for row {idx}: {e}")
df.at[idx, "nim_result"] = "error"
control_message.payload(df)
return control_message
Troubleshooting
Common Issues
- Connection Errors: Verify that the NIM service is running and that the endpoints are correct (see the readiness check after this list)
- Authentication Failures: Check NGC_API_KEY is valid and properly set
- Timeout Errors: Increase timeout values or check NIM service performance
- Format Errors: Ensure your ModelInterface formats data correctly for your NIM
- Memory Issues: Use appropriate batch sizes to avoid memory exhaustion
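To rule out connectivity problems before debugging your interface, probe the service directly. A minimal check, assuming the NIM exposes the standard /v1/health/ready route over HTTP; adjust the URL and path for your deployment.
import requests

# A 200 response indicates the NIM is up and ready to serve requests
response = requests.get("http://my-nim-service:8000/v1/health/ready", timeout=10)
print(response.status_code, response.text)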
Debugging Tips
import logging
# Enable debug logging (basicConfig attaches a handler so the log output is visible)
logging.basicConfig(level=logging.INFO)
logging.getLogger("nv_ingest_api.internal.primitives.nim").setLevel(logging.DEBUG)
# Test your model interface separately
model_interface = MyCustomModelInterface()
test_data = {"input_text": "test"}
# Test data preparation
prepared = model_interface.prepare_data_for_inference(test_data)
print(f"Prepared data: {prepared}")
# Test input formatting
formatted, batch_data = model_interface.format_input(prepared, "http", 1)
print(f"Formatted input: {formatted}")