# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import logging
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from nv_ingest.util.image_processing.transforms import numpy_to_base64
from nv_ingest.util.nim.helpers import ModelInterface
# Class labels grouped by the kind of content they denote. These sets are not
# referenced elsewhere in this module; presumably they are consumed by callers
# that filter the model's per-element class labels -- verify against callers.

# Labels treated as textual content.
ACCEPTED_TEXT_CLASSES = {
    "Text",
    "Title",
    "Section-header",
    "List-item",
    "TOC",
    "Bibliography",
    "Formula",
    "Page-header",
    "Page-footer",
    "Caption",
    "Footnote",
    "Floating-text",
}

# Labels treated as tables.
ACCEPTED_TABLE_CLASSES = {
    "Table",
}

# Labels treated as images.
ACCEPTED_IMAGE_CLASSES = {
    "Picture",
}

# Union of every accepted label across the three groups above.
ACCEPTED_CLASSES = ACCEPTED_TEXT_CLASSES | ACCEPTED_TABLE_CLASSES | ACCEPTED_IMAGE_CLASSES

logger = logging.getLogger(__name__)
[docs]
class NemoRetrieverParseModelInterface(ModelInterface):
    """
    An interface for handling inference with a NemoRetrieverParse model.

    Only HTTP responses are parsed; ``parse_output`` rejects the gRPC protocol.
    """

    def __init__(self, model_name: str = "nvidia/nemoretriever-parse"):
        """
        Initialize the instance with a specified model name.

        Parameters
        ----------
        model_name : str, optional
            The name of the model to be used, by default "nvidia/nemoretriever-parse".
            It is placed in the ``"model"`` field of payloads built by
            ``_prepare_nemoretriever_parse_payload``.
        """
        self.model_name = model_name

    def name(self) -> str:
        """
        Get the name of the model interface.

        Returns
        -------
        str
            The name of the model interface.
        """
        return "nemoretriever_parse"

    def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Prepare input data for inference.

        No preprocessing is performed for this model; the input is passed through as-is.

        Parameters
        ----------
        data : dict
            The input data.

        Returns
        -------
        dict
            The same ``data`` object, unmodified.
        """
        return data

    def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
        """
        Parse the output from the model's inference response.

        Parameters
        ----------
        response : Any
            The response from the model inference.
        protocol : str
            The protocol used. Only "http" is handled; "grpc" is rejected.
        data : dict, optional
            Additional input data passed to the function. Not used by this implementation.
        **kwargs
            Extra keyword arguments; ignored.

        Returns
        -------
        Any
            The parsed output data.

        Raises
        ------
        ValueError
            If the protocol is "grpc" (unsupported) or any value other than "http".
        RuntimeError
            If the HTTP response does not have the expected structure.
        """
        if protocol == "grpc":
            raise ValueError("gRPC protocol is not supported for NemoRetrieverParse.")
        elif protocol == "http":
            logger.debug("Parsing output from HTTP NemoRetrieverParse model")
            return self._extract_content_from_nemoretriever_parse_response(response)
        else:
            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

    def process_inference_results(self, output: Any, **kwargs) -> Any:
        """
        Process inference results for the NemoRetrieverParse model.

        No post-processing is applied; the output is passed through as-is.

        Parameters
        ----------
        output : Any
            The raw output from the model.
        **kwargs
            Extra keyword arguments; ignored.

        Returns
        -------
        Any
            The same ``output`` object, unmodified.
        """
        return output

    def _prepare_nemoretriever_parse_payload(self, base64_list: List[str]) -> Dict[str, Any]:
        """
        Build a request payload from a list of base64-encoded images.

        Each image becomes its own "user" message whose content is a single
        ``image_url`` entry carrying the image as a PNG data URL.

        Parameters
        ----------
        base64_list : list of str
            Base64-encoded image strings. They are embedded verbatim in a
            ``data:image/png;base64,`` URL, so they are presumably PNG-encoded
            -- TODO confirm with callers.

        Returns
        -------
        dict
            A payload with the configured ``"model"`` name and the list of ``"messages"``.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{b64_img}",
                        },
                    }
                ],
            }
            for b64_img in base64_list
        ]

        return {
            "model": self.model_name,
            "messages": messages,
        }

    def _extract_content_from_nemoretriever_parse_response(self, json_response: Dict[str, Any]) -> Any:
        """
        Extract content from the JSON response of a NemoRetrieverParse HTTP API request.

        The content is read from the first tool call of the first choice, i.e.
        ``choices[0]["message"]["tool_calls"][0]["function"]["arguments"]``, and
        decoded from its JSON string form.

        Parameters
        ----------
        json_response : dict
            The JSON response from the NemoRetrieverParse API.

        Returns
        -------
        Any
            The JSON-decoded tool-call arguments.

        Raises
        ------
        RuntimeError
            If the response does not contain a non-empty "choices" key, or if the
            first choice does not contain a tool call with function arguments.
        json.JSONDecodeError
            If the tool-call arguments are not valid JSON.
        """
        if "choices" not in json_response or not json_response["choices"]:
            raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.")

        # Surface any structural mismatch (missing key, empty or null tool_calls,
        # wrong container type) as the documented RuntimeError rather than a bare
        # KeyError/IndexError/TypeError; the original error is kept as the cause.
        try:
            tool_call = json_response["choices"][0]["message"]["tool_calls"][0]
            arguments = tool_call["function"]["arguments"]
        except (KeyError, IndexError, TypeError) as err:
            raise RuntimeError(
                "Unexpected response format: first choice does not contain a tool call with function arguments."
            ) from err

        return json.loads(arguments)