# Source code for nv_ingest.util.nim.deplot
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, Any, Optional, List
import numpy as np
import logging
from nv_ingest.util.image_processing.transforms import base64_to_numpy
from nv_ingest.util.nim.helpers import ModelInterface
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(name=__name__)
[docs]
class DeplotModelInterface(ModelInterface):
    """
    An interface for handling inference with a Deplot model, supporting both gRPC and HTTP protocols.

    Input may carry a single base64-encoded image ('base64_image') or a list of them
    ('base64_images'). Images are decoded to NumPy arrays and split into batches of at
    most ``max_batch_size`` images before being formatted for the chosen protocol.
    """

    def name(self) -> str:
        """
        Get the name of the model interface.

        Returns
        -------
        str
            The name of the model interface ("Deplot").
        """
        return "Deplot"

    def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Prepare input data by decoding one or more base64-encoded images into NumPy arrays.

        Parameters
        ----------
        data : dict
            The input data containing either 'base64_image' (single image)
            or 'base64_images' (a list of images). If both are present,
            'base64_images' takes precedence.

        Returns
        -------
        dict
            The same dictionary, updated in place with 'image_arrays': a list of
            decoded NumPy arrays, one per input image, in input order.

        Raises
        ------
        ValueError
            If 'base64_images' is present but is not a list.
        KeyError
            If neither 'base64_image' nor 'base64_images' is present.
        """
        if "base64_images" in data:
            base64_list = data["base64_images"]
            if not isinstance(base64_list, list):
                raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.")
            image_arrays = [base64_to_numpy(b64) for b64 in base64_list]
        elif "base64_image" in data:
            # Single-image fallback: wrap it so downstream code always sees a list.
            image_arrays = [base64_to_numpy(data["base64_image"])]
        else:
            raise KeyError("Input data must include 'base64_image' or 'base64_images'.")

        data["image_arrays"] = image_arrays
        return data

    def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs) -> Any:
        """
        Format input data for the specified protocol (gRPC or HTTP) for Deplot.

        Images are split into consecutive batches of at most ``max_batch_size`` images.
        Each formatted batch is paired with batch data that carries the original image
        arrays and their dimensions.

        Parameters
        ----------
        data : dict of str -> Any
            The input data dictionary. It must contain "image_arrays" (a list of np.ndarray).
            For HTTP it must also contain the matching 'base64_images' or 'base64_image'.
        protocol : str
            The protocol to use, "grpc" or "http".
        max_batch_size : int
            The maximum number of images per batch; must be >= 1.
        kwargs : dict
            HTTP payload options: "max_tokens" (default 500), "temperature"
            (default 0.5) and "top_p" (default 0.9). Ignored for gRPC.

        Returns
        -------
        tuple
            (formatted_batches, formatted_batch_data) where:
            - For gRPC: formatted_batches is a list of float32 NumPy arrays. 3-D images
              first get a leading batch axis, then up to max_batch_size images are
              concatenated along axis 0. Pixel values are divided by 255.
            - For HTTP: formatted_batches is a list of JSON-serializable payload dicts.
            - In both cases, formatted_batch_data is a list of dicts containing:
              "image_arrays": the list of original np.ndarray images for that batch, and
              "image_dims": a list of (shape[0], shape[1]) tuples for each image in the batch.

        Raises
        ------
        KeyError
            If "image_arrays" is missing, or (HTTP only) neither 'base64_images'
            nor 'base64_image' is present.
        ValueError
            If the protocol is invalid, max_batch_size is < 1, no images are found,
            or (HTTP only) the number of base64 strings differs from the number of
            image arrays.
        """
        if "image_arrays" not in data:
            raise KeyError("Expected 'image_arrays' in data. Call prepare_data_for_inference first.")
        if protocol not in ("grpc", "http"):
            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
        # A non-positive step would make range() either raise a cryptic error (0) or
        # produce no chunks at all (negative), silently dropping every image.
        if max_batch_size < 1:
            raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}.")

        image_arrays = data["image_arrays"]
        # len() rather than truthiness so a stacked ndarray input does not raise.
        if len(image_arrays) == 0:
            label = "gRPC" if protocol == "grpc" else "HTTP"
            raise ValueError(f"No valid images found for {label} formatting.")

        # (shape[0], shape[1]) of each original image, carried through in the batch data.
        image_dims = [(img.shape[0], img.shape[1]) for img in image_arrays]
        formatted_batch_data = [
            {"image_arrays": orig_chunk, "image_dims": dims_chunk}
            for orig_chunk, dims_chunk in zip(
                self._chunk_list(image_arrays, max_batch_size),
                self._chunk_list(image_dims, max_batch_size),
            )
        ]

        if protocol == "grpc":
            logger.debug("Formatting input for gRPC Deplot model (potentially batched).")
            processed = [self._to_grpc_array(arr) for arr in image_arrays]
            formatted_batches = [
                # Concatenate along the batch dimension to form a single input per batch.
                np.concatenate(chunk, axis=0)
                for chunk in self._chunk_list(processed, max_batch_size)
            ]
            return formatted_batches, formatted_batch_data

        logger.debug("Formatting input for HTTP Deplot model (multiple messages).")
        if "base64_images" in data:
            base64_list = data["base64_images"]
        elif "base64_image" in data:
            base64_list = [data["base64_image"]]
        else:
            raise KeyError("HTTP formatting requires 'base64_image' or 'base64_images' in data.")
        # The base64 strings and the decoded arrays are chunked independently, so they
        # must line up one-to-one or batches would pair the wrong images.
        if len(base64_list) != len(image_arrays):
            raise ValueError(
                f"Number of base64 images ({len(base64_list)}) does not match "
                f"number of image arrays ({len(image_arrays)})."
            )

        # NOTE(review): _extract_content_from_deplot_response reads only the first choice,
        # so with max_batch_size > 1 only one result per request is surfaced. Callers
        # presumably pass max_batch_size=1 for HTTP -- TODO confirm.
        formatted_batches = [
            self._prepare_deplot_payload(
                base64_list=b64_chunk,
                max_tokens=kwargs.get("max_tokens", 500),
                temperature=kwargs.get("temperature", 0.5),
                top_p=kwargs.get("top_p", 0.9),
            )
            for b64_chunk in self._chunk_list(base64_list, max_batch_size)
        ]
        return formatted_batches, formatted_batch_data

    def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
        """
        Parse the model's inference response.

        Parameters
        ----------
        response : Any
            For gRPC: an iterable with one element per image. Each element may be
            bytes, str, a list/tuple of bytes/str, or a NumPy array of such values.
            For HTTP: the decoded JSON response dict.
        protocol : str
            "grpc" or "http".
        data : dict, optional
            Not used by this implementation.

        Returns
        -------
        list of str or Any
            gRPC: one string per element of ``response``; multi-part elements are
            decoded (UTF-8) and joined with single spaces.
            HTTP: the ``message.content`` of the first choice.

        Raises
        ------
        ValueError
            If the protocol is invalid.
        RuntimeError
            (HTTP) If the response has no non-empty 'choices'.
        """
        if protocol == "grpc":
            logger.debug("Parsing output from gRPC Deplot model (batched).")
            return [self._decode_grpc_item(item) for item in response]
        if protocol == "http":
            logger.debug("Parsing output from HTTP Deplot model.")
            return self._extract_content_from_deplot_response(response)
        raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

    def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
        """
        Process inference results for the Deplot model.

        Parameters
        ----------
        output : Any
            The parsed output from the model.
        protocol : str
            The protocol used for inference (gRPC or HTTP).

        Returns
        -------
        Any
            ``output`` unchanged; Deplot needs no post-processing here.
        """
        return output

    @staticmethod
    def _chunk_list(lst: list, chunk_size: int) -> List[list]:
        """Split ``lst`` into consecutive slices of at most ``chunk_size`` items."""
        return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]

    @staticmethod
    def _to_grpc_array(arr: np.ndarray) -> np.ndarray:
        """Give 3-D images a leading batch axis and convert to float32 scaled by 1/255."""
        if arr.ndim == 3:
            arr = np.expand_dims(arr, axis=0)
        # astype() copies by default, so the in-place division never mutates the caller's array.
        arr = arr.astype(np.float32)
        arr /= 255.0  # maps 8-bit pixel values to [0, 1] (assumes uint8-range input)
        return arr

    @staticmethod
    def _to_text(value: Any) -> str:
        """Decode bytes-like values as UTF-8; stringify anything else."""
        if isinstance(value, (bytes, bytearray)):
            return value.decode("utf-8")
        return str(value)

    @staticmethod
    def _decode_grpc_item(item: Any) -> str:
        """Convert one per-image gRPC output element into a single string."""
        # Flatten NumPy arrays to Python lists so their elements are decoded individually
        # instead of being rendered as "[b'...']" by str().
        if isinstance(item, np.ndarray):
            item = item.ravel().tolist()
        if isinstance(item, (list, tuple)):
            return " ".join(DeplotModelInterface._to_text(part) for part in item)
        return DeplotModelInterface._to_text(item)

    @staticmethod
    def _prepare_deplot_payload(
        base64_list: list,
        max_tokens: int = 500,
        temperature: float = 0.5,
        top_p: float = 0.9,
    ) -> Dict[str, Any]:
        """
        Build an HTTP chat-style payload for the "google/deplot" model.

        One user message is created per image, each shaped like:

            {
                "role": "user",
                "content": "Generate the underlying data table of the figure below: "
                           "<img src=\"data:image/png;base64,...\" />"
            }

        Parameters
        ----------
        base64_list : list
            Base64-encoded images; one message is produced for each.
        max_tokens : int
            Generation limit forwarded as "max_tokens".
        temperature : float
            Sampling temperature forwarded as "temperature".
        top_p : float
            Nucleus-sampling value forwarded as "top_p".

        Returns
        -------
        dict
            The request payload with "stream" set to False.

        Notes
        -----
        Per the note in the body below, the deplot NIM currently accepts only a single
        message per request, so a payload built from more than one image may not be
        handled as one result per image -- TODO confirm against the service.
        """
        messages = []
        # Note: deplot NIM currently only supports a single message per request
        for b64_img in base64_list:
            messages.append(
                {
                    "role": "user",
                    "content": (
                        "Generate the underlying data table of the figure below: "
                        f'<img src="data:image/png;base64,{b64_img}" />'
                    ),
                }
            )
        payload = {
            "model": "google/deplot",
            "messages": messages,  # one user message per image
            "max_tokens": max_tokens,
            "stream": False,
            "temperature": temperature,
            "top_p": top_p,
        }
        return payload

    @staticmethod
    def _extract_content_from_deplot_response(json_response: Dict[str, Any]) -> Any:
        """
        Extract the generated content from the JSON response of a Deplot HTTP request.

        Only the first entry of ``choices`` is read; its ``message.content`` is returned.

        Raises
        ------
        RuntimeError
            If the response is not a dict or its 'choices' is missing or empty.
        """
        # Guard the type first so e.g. a None response raises the documented RuntimeError
        # rather than a TypeError from the membership test.
        choices = json_response.get("choices") if isinstance(json_response, dict) else None
        if not choices:
            raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.")
        return choices[0]["message"]["content"]