# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import pandas as pd
from nv_ingest_api.internal.enums.common import ContentTypeEnum
from nv_ingest_api.internal.primitives.nim import NimClient
from nv_ingest_api.internal.primitives.nim.model_interface.ocr import PaddleOCRModelInterface
from nv_ingest_api.internal.primitives.nim.model_interface.ocr import NemoRetrieverOCRModelInterface
from nv_ingest_api.internal.primitives.nim.model_interface.ocr import get_ocr_model_name
from nv_ingest_api.internal.schemas.extract.extract_ocr_schema import OCRExtractorSchema
from nv_ingest_api.util.image_processing.transforms import base64_to_numpy
from nv_ingest_api.util.nim import create_inference_client
logger = logging.getLogger(__name__)
PADDLE_MIN_WIDTH = 32
PADDLE_MIN_HEIGHT = 32
def _filter_text_images(
    base64_images: List[str],
    min_width: int = PADDLE_MIN_WIDTH,
    min_height: int = PADDLE_MIN_HEIGHT,
) -> Tuple[List[str], List[int]]:
    """
    Filter base64-encoded images by minimum pixel dimensions.

    Parameters
    ----------
    base64_images : List[str]
        List of base64-encoded image strings.
    min_width : int, optional
        Minimum acceptable image width in pixels, by default PADDLE_MIN_WIDTH.
    min_height : int, optional
        Minimum acceptable image height in pixels, by default PADDLE_MIN_HEIGHT.

    Returns
    -------
    Tuple[List[str], List[int]]
        - valid_images: images that meet the size requirements.
        - valid_indices: original indices (into ``base64_images``) of the valid images.
    """
    valid_images: List[str] = []
    valid_indices: List[int] = []
    for i, img in enumerate(base64_images):
        # Decode only to inspect dimensions; numpy shape is (H, W[, C]).
        array = base64_to_numpy(img)
        height, width = array.shape[0], array.shape[1]
        if width >= min_width and height >= min_height:
            valid_images.append(img)
            valid_indices.append(i)
    return valid_images, valid_indices
def _update_text_metadata(
    base64_images: List[str],
    ocr_client: NimClient,
    ocr_model_name: str,
    worker_pool_size: int = 8,  # Not currently used
    trace_info: Optional[Dict] = None,
) -> List[Tuple[str, Optional[Any], Optional[Any]]]:
    """
    Run OCR on size-eligible images and return per-image results in input order.

    Images are first filtered (paddle enforces the minimum width/height; other
    models accept any non-empty size), then a single ``ocr_client.infer`` call
    is issued for the surviving images. Images that were filtered out appear in
    the output as ``(None, None, None)``.

    Parameters
    ----------
    base64_images : List[str]
        List of base64-encoded images.
    ocr_client : NimClient
        Client instance used for OCR inference.
    ocr_model_name : str
        OCR model to target; must be "paddle" or one of the scene-text model
        names. Anything else raises ``ValueError``.
    worker_pool_size : int, optional
        Retained for interface compatibility; not used, by default 8.
    trace_info : Optional[Dict], optional
        Optional trace information for debugging.

    Returns
    -------
    List[Tuple[str, Optional[Any], Optional[Any]]]
        One tuple per input image, in the same order as ``base64_images``.

    Raises
    ------
    ValueError
        If ``ocr_model_name`` is unknown, or the result count does not match
        the number of images sent for inference.
    """
    logger.debug(f"Running text extraction using protocol {ocr_client.protocol}")

    # worker_pool_size is kept only for signature compatibility.
    _ = worker_pool_size

    is_paddle = ocr_model_name == "paddle"
    scene_text_models = {"scene_text_ensemble", "scene_text_wrapper", "scene_text_python", "pipeline"}

    # Paddle enforces a minimum size; other models take effectively any image.
    if is_paddle:
        usable_images, original_positions = _filter_text_images(base64_images)
    else:
        usable_images, original_positions = _filter_text_images(base64_images, min_width=1, min_height=1)

    payload = {"base64_images": usable_images}

    infer_kwargs: Dict[str, Any] = {
        "stage_name": "ocr_extraction",
        "trace_info": trace_info,
    }
    if is_paddle:
        infer_kwargs["model_name"] = "paddle"
        # gRPC paddle inference is limited to single-image batches.
        infer_kwargs["max_batch_size"] = 1 if ocr_client.protocol == "grpc" else 2
    elif ocr_model_name in scene_text_models:
        infer_kwargs["model_name"] = ocr_model_name
        infer_kwargs["input_names"] = ["INPUT_IMAGE_URLS", "MERGE_LEVELS"]
        infer_kwargs["output_names"] = ["OUTPUT"]
        infer_kwargs["dtypes"] = ["BYTES", "BYTES"]
        infer_kwargs["merge_level"] = "paragraph"
    else:
        raise ValueError(f"Unknown OCR model name: {ocr_model_name}")

    try:
        ocr_results = ocr_client.infer(payload, **infer_kwargs)
    except Exception as e:
        logger.error(f"Error calling ocr_client.infer: {e}", exc_info=True)
        raise

    if len(ocr_results) != len(usable_images):
        raise ValueError(f"Expected {len(usable_images)} ocr results, got {len(ocr_results)}")

    # Scatter results back to their original positions; skipped images stay
    # as (None, None, None).
    results: List[Tuple[str, Optional[Any], Optional[Any]]] = [(None, None, None)] * len(base64_images)
    for position, ocr_result in enumerate(ocr_results):
        results[original_positions[position]] = ocr_result
    return results
def _create_ocr_client(
    ocr_endpoints: Tuple[str, str],
    ocr_protocol: str,
    ocr_model_name: str,
    auth_token: str,
) -> NimClient:
    """
    Build a NimClient configured for the requested OCR model.

    Scene-text model names get the NemoRetriever OCR interface with dynamic
    batching enabled; everything else falls back to the Paddle interface
    without dynamic batching.

    Parameters
    ----------
    ocr_endpoints : Tuple[str, str]
        Endpoint pair passed through to ``create_inference_client``.
    ocr_protocol : str
        Inference protocol (e.g. grpc/http) for the client.
    ocr_model_name : str
        OCR model identifier used to select the model interface.
    auth_token : str
        Authentication token for the inference service.

    Returns
    -------
    NimClient
        The configured inference client.
    """
    scene_text_models = {"scene_text_ensemble", "scene_text_wrapper", "scene_text_python", "pipeline"}
    use_scene_text = ocr_model_name in scene_text_models

    model_interface = NemoRetrieverOCRModelInterface() if use_scene_text else PaddleOCRModelInterface()

    return create_inference_client(
        endpoints=ocr_endpoints,
        model_interface=model_interface,
        auth_token=auth_token,
        infer_protocol=ocr_protocol,
        enable_dynamic_batching=use_scene_text,
        dynamic_batch_memory_budget_mb=32,
    )
def _meets_page_elements_text_criteria(row: pd.Series) -> bool:
    """
    Check whether a DataFrame row qualifies for page-element text extraction.

    A row qualifies when all of the following hold:
    - it carries a truthy 'metadata' dictionary;
    - ``metadata["content_metadata"]`` has type ``ContentTypeEnum.TEXT`` and a
      subtype of "title", "paragraph", or "header_footer";
    - ``metadata["content"]`` is neither ``None`` nor the empty string.

    Parameters
    ----------
    row : pd.Series
        A row from the DataFrame.

    Returns
    -------
    bool
        True if the row meets all criteria; False otherwise.
    """
    metadata = row.get("metadata", {})
    if not metadata:
        return False

    content_md = metadata.get("content_metadata", {})
    if content_md.get("type") != ContentTypeEnum.TEXT:
        return False
    if content_md.get("subtype") not in {"paragraph", "title", "header_footer"}:
        return False

    return metadata.get("content") not in {None, ""}
def _meets_page_image_criteria(row: pd.Series) -> bool:
    """
    Check whether a DataFrame row qualifies for page-image text extraction.

    A row qualifies when all of the following hold:
    - it carries a truthy 'metadata' dictionary;
    - ``metadata["content_metadata"]`` has type ``ContentTypeEnum.IMAGE`` and
      subtype ``ContentTypeEnum.PAGE_IMAGE``;
    - ``metadata["content"]`` is neither ``None`` nor the empty string.

    Parameters
    ----------
    row : pd.Series
        A row from the DataFrame.

    Returns
    -------
    bool
        True if the row meets all criteria; False otherwise.
    """
    metadata = row.get("metadata", {})
    if not metadata:
        return False

    content_md = metadata.get("content_metadata", {})
    if content_md.get("type") != ContentTypeEnum.IMAGE:
        return False
    if content_md.get("subtype") not in {ContentTypeEnum.PAGE_IMAGE}:
        return False

    return metadata.get("content") not in {None, ""}
def _process_page_images(df_to_process: pd.DataFrame, ocr_results: List[Tuple]):
valid_indices = df_to_process.index.tolist()
for result_idx, df_idx in enumerate(valid_indices):
# Unpack result: (bounding_boxes, text_predictions, confidence_scores)
bboxes, texts, _ = ocr_results[result_idx]
if not bboxes or not texts:
df_to_process.loc[df_idx, "metadata"]["image_metadata"]["text"] = ""
continue
df_to_process.loc[df_idx, "metadata"]["image_metadata"]["text"] = " ".join([t for t in texts])
return df_to_process
def _process_page_elements(df_to_process: pd.DataFrame, ocr_results: List[Tuple]):
    """
    Merge OCR text into page-element rows and aggregate the text per page.

    ``ocr_results`` is aligned positionally with ``df_to_process``'s index.
    For each row, the original base64 content is preserved under
    ``text_metadata["source_image"]`` and the row's content is replaced with
    the OCR text sorted into reading order (top-to-bottom, then
    left-to-right). All blocks on each page are then concatenated (ordered by
    the page-element's own ``text_location``) and written into that page's
    first "text" document row; non-text rows pass through unchanged.

    Parameters
    ----------
    df_to_process : pd.DataFrame
        Rows to update; each row must carry a 'metadata' dict with
        'content', 'content_metadata' (including 'page_number'), and
        'text_metadata' (including 'text_location'). Mutated in place.
    ocr_results : List[Tuple]
        One (bounding_boxes, text_predictions, confidence_scores) tuple per
        row, in the same order as ``df_to_process.index``.

    Returns
    -------
    pd.DataFrame
        The updated DataFrame with temporary helper columns removed.
    """
    valid_indices = df_to_process.index.tolist()
    if not valid_indices:
        return df_to_process
    for result_idx, df_idx in enumerate(valid_indices):
        # Preserve the original base64 image before overwriting with OCR text.
        # This enables text_image modality for multimodal embeddings.
        original_image = df_to_process.loc[df_idx, "metadata"]["content"]
        df_to_process.loc[df_idx, "metadata"]["text_metadata"]["source_image"] = original_image
        # Unpack result: (bounding_boxes, text_predictions, confidence_scores)
        bboxes, texts, _ = ocr_results[result_idx]
        if not bboxes or not texts:
            df_to_process.loc[df_idx, "_x0"] = None
            df_to_process.loc[df_idx, "_y0"] = None
            df_to_process.loc[df_idx, "metadata"]["content"] = ""
            continue
        combined_data = list(zip(bboxes, texts))
        try:
            # Sort by reading order (y_min, then x_min)
            combined_data.sort(key=lambda item: (min(p[1] for p in item[0]), min(p[0] for p in item[0])))
        except (ValueError, IndexError):
            # Best-effort: fall through with the unsorted pairing.
            logger.warning("Could not sort OCR results due to malformed bounding box.")
        # NOTE(review): these per-row _x0/_y0 values are dropped just below and
        # recomputed from text_metadata["text_location"] instead — the writes
        # here appear unused; confirm before removing.
        df_to_process.loc[df_idx, "_x0"] = min(point[0] for item in combined_data for point in item[0])
        df_to_process.loc[df_idx, "_y0"] = min(point[1] for item in combined_data for point in item[0])
        df_to_process.loc[df_idx, "metadata"]["content"] = " ".join([item[1] for item in combined_data])
    df_to_process = df_to_process.drop(["_x0", "_y0"], axis=1)
    df_to_process.loc[:, "_page_number"] = df_to_process["metadata"].apply(
        lambda meta: meta["content_metadata"]["page_number"]
    )
    # Group by page number to aggregate all text blocks on each page
    grouped = df_to_process.groupby("_page_number")
    new_text = {}
    for page_num, group_df in grouped:
        if group_df.empty:
            continue
        # Sort text blocks by their original position for correct reading order
        # (assumes text_location is ordered (x, y, ...) — TODO confirm).
        group_df.loc[:, "_x0"] = group_df["metadata"].apply(lambda meta: meta["text_metadata"]["text_location"][0])
        group_df.loc[:, "_y0"] = group_df["metadata"].apply(lambda meta: meta["text_metadata"]["text_location"][1])
        # Only blocks with a complete location participate in the page text.
        loc_mask = group_df[["_y0", "_x0"]].notna().all(axis=1)
        sorted_group = group_df.loc[loc_mask].sort_values(by=["_y0", "_x0"], ascending=[True, True])
        page_text = " ".join(sorted_group["metadata"].apply(lambda meta: meta["content"]).tolist())
        if page_text.strip():
            new_text[page_num] = page_text
    # One representative "text" row per page receives the aggregated content.
    df_text = df_to_process[df_to_process["document_type"] == "text"].drop_duplicates(
        subset=["_page_number"], keep="first"
    )
    for page_num, page_text in new_text.items():
        page_num_mask = df_text["_page_number"] == page_num
        # The lambda's page_text is bound now because .apply runs immediately
        # inside this loop iteration.
        df_text.loc[page_num_mask, "metadata"] = df_text.loc[page_num_mask, "metadata"].apply(
            lambda meta: {**meta, "content": page_text}
        )
    df_non_text = df_to_process[df_to_process["document_type"] != "text"]
    df_to_process = pd.concat([df_text, df_non_text])
    # Strip the temporary helper columns before returning.
    for col in {"_y0", "_x0", "_page_number"}:
        if col in df_to_process:
            df_to_process = df_to_process.drop(col, axis=1)
    return df_to_process