# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Table/chart/infographic content reconstruction utilities.
Ports bbox-matching and content-reconstruction algorithms from
``nemo_retriever.api.util.image_processing.table_and_chart`` and adds adapter
functions that convert the retriever's detection/OCR formats into the
pixel-coordinate representations expected by the core joining routines.
"""
from __future__ import annotations
import logging
import re
from typing import Any, Dict, List, Optional, Sequence, Tuple # noqa: F401
import numpy as np
import pandas as pd
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Core algorithms ported from `nemo_retriever.api`
# ---------------------------------------------------------------------------
[docs]
def match_bboxes(
    yolox_box: np.ndarray,
    ocr_boxes: np.ndarray,
    already_matched: Optional[list] = None,
    delta: float = 2.0,
) -> np.ndarray:
    """Union-based IoU matching for chart graphic elements.

    Computes the intersection-over-union between one detection box and every
    OCR box, then keeps each OCR box whose IoU is strictly greater than
    ``max_iou / delta``.

    Parameters
    ----------
    yolox_box : array-like of 4 numbers
        Detection box as ``(x0, y0, x1, y1)``.
    ocr_boxes : ndarray of shape (N, 4)
        OCR boxes as ``(x0, y0, x1, y1)`` rows.
    already_matched : collection of int, optional
        OCR indices to exclude from the result.
    delta : float
        Divisor applied to the best IoU to obtain the acceptance threshold.

    Returns
    -------
    ndarray
        Integer indices into ``ocr_boxes``.  An empty integer array is
        returned when there are no OCR boxes, when the best IoU is at most
        ``0.01``, or when every candidate is excluded by ``already_matched``.
    """
    no_match = np.array([], dtype=int)

    ocr_boxes = np.asarray(ocr_boxes)
    if ocr_boxes.size == 0:
        # np.max() below would raise on an empty array.
        return no_match

    x0_1, y0_1, x1_1, y1_1 = yolox_box
    x0_2, y0_2, x1_2, y1_2 = (
        ocr_boxes[:, 0],
        ocr_boxes[:, 1],
        ocr_boxes[:, 2],
        ocr_boxes[:, 3],
    )

    inter_y0 = np.maximum(y0_1, y0_2)
    inter_y1 = np.minimum(y1_1, y1_2)
    inter_x0 = np.maximum(x0_1, x0_2)
    inter_x1 = np.minimum(x1_1, x1_2)
    inter_area = np.maximum(0, inter_y1 - inter_y0) * np.maximum(0, inter_x1 - inter_x0)

    area_1 = (y1_1 - y0_1) * (x1_1 - x0_1)
    area_2 = (y1_2 - y0_2) * (x1_2 - x0_2)
    union_area = area_1 + area_2 - inter_area

    # Degenerate (zero-area) pairs have a non-positive union; treat their IoU
    # as 0 instead of dividing by zero and propagating NaN/inf.
    ious = np.divide(
        inter_area,
        union_area,
        out=np.zeros(union_area.shape, dtype=float),
        where=union_area > 0,
    )

    max_iou = np.max(ious)
    if max_iou <= 0.01:
        return no_match

    matches = np.where(ious > (max_iou / delta))[0]
    if already_matched is not None and len(already_matched):
        # Boolean masking keeps the integer dtype even when nothing survives,
        # so the result is always usable as an index array.
        matches = matches[~np.isin(matches, list(already_matched))]
    return matches
[docs]
def assign_boxes(
    ocr_box: np.ndarray,
    boxes: np.ndarray,
    delta: float = 2.0,
    min_overlap: float = 0.25,
) -> np.ndarray:
    """Area-normalized overlap matching for table structure.

    For every candidate box the intersection with ``ocr_box`` is divided by
    the area of ``ocr_box`` (not by the union), i.e. the score is the
    fraction of ``ocr_box`` covered by the candidate.

    Parameters
    ----------
    ocr_box : array-like of 4 numbers
        Query box as ``(x0, y0, x1, y1)``.
    boxes : array-like of shape (N, 4)
        Candidate boxes as ``(x0, y0, x1, y1)`` rows.
    delta : float
        Candidates scoring at least ``best_score / delta`` are kept.
    min_overlap : float
        If the best score does not exceed this value nothing is returned.

    Returns
    -------
    ndarray
        Integer indices into ``boxes`` ordered by decreasing score; an empty
        integer array when ``boxes`` is empty or no candidate overlaps enough.
    """
    no_match = np.array([], dtype=int)
    if not len(boxes):
        return no_match

    boxes = np.asarray(boxes)
    x0_1, y0_1, x1_1, y1_1 = ocr_box
    x0_2, y0_2, x1_2, y1_2 = (
        boxes[:, 0],
        boxes[:, 1],
        boxes[:, 2],
        boxes[:, 3],
    )

    inter_y0 = np.maximum(y0_1, y0_2)
    inter_y1 = np.minimum(y1_1, y1_2)
    inter_x0 = np.maximum(x0_1, x0_2)
    inter_x1 = np.minimum(x1_1, x1_2)
    inter_area = np.maximum(0, inter_y1 - inter_y0) * np.maximum(0, inter_x1 - inter_x0)

    area_1 = (y1_1 - y0_1) * (x1_1 - x0_1)
    # The small epsilon avoids a division by zero for a zero-area query box.
    overlaps = inter_area / (area_1 + 1e-6)

    best = np.max(overlaps)
    if best <= min_overlap:
        return no_match

    # Keep as many of the top-scoring candidates as pass the relative threshold.
    n_keep = int(np.count_nonzero(overlaps >= (best / delta)))
    return np.argsort(-overlaps)[:n_keep]
def _join_yolox_graphic_elements_and_ocr_output(
    yolox_output: Dict[str, list],
    ocr_boxes: np.ndarray,
    ocr_txts: list,
) -> Dict[str, str]:
    """Attach OCR text to graphic-element detections, class by class.

    ``ocr_boxes`` is indexed as ``(N, 4, 2)`` quadrilaterals and reduced to
    axis-aligned ``(x0, y0, x1, y1)`` boxes.  For each known class present in
    ``yolox_output`` every detection (its first four values) is matched
    against the OCR boxes with ``match_bboxes(..., delta=4)``; the matched
    texts are space-joined per detection, whitespace runs and dot runs are
    collapsed, and the per-detection strings are joined with ``" "`` for
    classes whose name contains ``"title"`` and with ``" - "`` otherwise.

    Returns a ``{class_name: text}`` dict, or ``{}`` when there is no OCR
    text or no OCR box.
    """
    kept_classes = (
        "chart_title",
        "x_title",
        "y_title",
        "xlabel",
        "ylabel",
        "other",
        "legend_label",
        "legend_title",
        "mark_label",
        "value_label",
    )

    text_arr = np.array(ocr_txts)
    quad_arr = np.array(ocr_boxes)
    if text_arr.size == 0 or quad_arr.size == 0:
        return {}

    # Convert quadrilateral (N,4,2) → xyxy (N,4).
    xs = quad_arr[:, :, 0]
    ys = quad_arr[:, :, 1]
    xyxy = np.stack([xs.min(-1), ys.min(-1), xs.max(-1), ys.max(-1)], axis=1)

    # NOTE(review): this list is passed to match_bboxes but nothing is ever
    # appended to it here, so no OCR box is actually excluded — confirm
    # whether that is intended.
    already_matched: list = []
    results: Dict[str, str] = {}

    for label in kept_classes:
        detections = yolox_output.get(label, [])
        if not len(detections):
            continue

        joined = []
        for det in detections:
            # Only the first four values are box coordinates.
            ocr_ids = match_bboxes(det[:4], xyxy, already_matched=already_matched, delta=4)
            if len(ocr_ids) > 0:
                joined.append(" ".join(text_arr[ocr_ids].tolist()))

        cleaned = [re.sub(r"\.+", ".", re.sub(r"\s+", " ", piece)) for piece in joined]
        separator = " " if "title" in label else " - "
        results[label] = separator.join(cleaned)

    return results
[docs]
def process_yolox_graphic_elements(yolox_text_dict: Dict[str, str]) -> str:
    """Concatenate chart text by semantic region.

    The values for a fixed sequence of keys are joined with single spaces
    (a missing key contributes an empty string) and the result is stripped
    of leading and trailing whitespace.
    """
    region_order = (
        "chart_title",
        "caption",
        "x_title",
        "xlabel",
        "y_title",
        "ylabel",
        "legend_label",
        "legend_title",
        "mark_label",
        "value_label",
        "other",
    )
    pieces = [yolox_text_dict.get(region, "") for region in region_order]
    return " ".join(pieces).strip()
def _join_yolox_table_structure_and_ocr_output(
    yolox_cell_preds: Dict[str, np.ndarray],
    ocr_boxes: list,
    ocr_txts: list,
) -> str:
    """Combine table-structure cell/row/column predictions with OCR text.

    Each OCR box is assigned to a cell, then (via that cell, or via the OCR
    box itself when no cell matches) to row and column indices.  OCR items
    that received both row and column indices are rendered as one markdown
    table; the remaining items are kept as free text.  The table and the free
    text are finally ordered by a coarse (y, x) grid and joined with newlines.

    Parameters
    ----------
    yolox_cell_preds : dict
        Must provide ``"cell"``, ``"row"`` and ``"column"`` arrays whose rows
        are ``(x0, y0, x1, y1)`` boxes.  The dict is not modified.
    ocr_boxes : sequence
        OCR quadrilaterals, indexed as ``(N, 4, 2)``.
    ocr_txts : sequence of str
        Text for each OCR box.

    Returns
    -------
    str
        The reconstructed content, or ``""`` when there are no OCR boxes or
        no OCR texts.
    """
    # len()-based checks also work for NumPy inputs, whose truth value is
    # ambiguous.
    if ocr_boxes is None or ocr_txts is None or len(ocr_boxes) == 0 or len(ocr_txts) == 0:
        return ""

    cell_boxes = yolox_cell_preds["cell"]
    row_boxes = yolox_cell_preds["row"]
    col_boxes = yolox_cell_preds["column"]

    # Sort rows top-to-bottom and columns left-to-right so that
    # assign_boxes indices correspond to spatial positions.  The sorted
    # arrays live in locals so the caller's dict is left untouched.
    if row_boxes.shape[0] > 0:
        row_boxes = row_boxes[row_boxes[:, 1].argsort()]
    if col_boxes.shape[0] > 0:
        col_boxes = col_boxes[col_boxes[:, 0].argsort()]

    # Convert quadrilateral (N,4,2) → xyxy (N,4).
    ocr_boxes = np.array(ocr_boxes)
    ocr_boxes_ = np.array(
        [
            ocr_boxes[:, :, 0].min(-1),
            ocr_boxes[:, :, 1].min(-1),
            ocr_boxes[:, :, 0].max(-1),
            ocr_boxes[:, :, 1].max(-1),
        ]
    ).T

    assignments = []
    for i, (b, t) in enumerate(zip(ocr_boxes_, ocr_txts)):
        matches_cell = assign_boxes(b, cell_boxes, delta=1)
        # Fall back to the OCR box itself when no cell covers it.
        cell = cell_boxes[matches_cell[0]] if len(matches_cell) else b

        matches_row = assign_boxes(cell, row_boxes, delta=1)
        row_ids = matches_row if len(matches_row) else -1

        # Items assigned to row index 0 use a looser column threshold
        # (delta=2), so they may be assigned to several columns
        # (presumably to let header cells span columns — confirm).
        if isinstance(row_ids, np.ndarray):
            delta = 2 if row_ids.min() == 0 else 1
        else:
            delta = 1
        matches_col = assign_boxes(cell, col_boxes, delta=delta)
        col_ids = matches_col if len(matches_col) else -1

        assignments.append(
            {
                "index": i,
                "ocr_box": b,
                "is_table": isinstance(col_ids, np.ndarray) and isinstance(row_ids, np.ndarray),
                "cell_id": matches_cell[0] if len(matches_cell) else -1,
                "cell": cell,
                "col_ids": col_ids,
                "row_ids": row_ids,
                "text": t,
            }
        )

    df_assign = pd.DataFrame(assignments)

    # Collapse several OCR items that landed in the same cell into one row.
    dfs = []
    for cell_id, df_cell in df_assign.groupby("cell_id"):
        if len(df_cell) > 1 and cell_id > -1:
            df_cell = merge_text_in_cell(df_cell)
        dfs.append(df_cell)
    df_assign = pd.concat(dfs)

    df_text = df_assign[~df_assign["is_table"]].reset_index(drop=True)
    df_table = df_assign[df_assign["is_table"]].reset_index(drop=True)

    if len(df_table):
        mat = build_markdown(df_table)
        markdown_table = display_markdown(mat, use_header=True)

        # Represent the whole table as a single item whose box is the
        # bounding box of all table items.
        all_boxes = np.stack(df_table.ocr_box.values)
        table_box = np.concatenate([all_boxes[:, [0, 1]].min(0), all_boxes[:, [2, 3]].max(0)])
        df_table_to_text = pd.DataFrame(
            [
                {
                    "ocr_box": table_box,
                    "text": markdown_table,
                    "is_table": True,
                }
            ]
        )
        df_text = pd.concat([df_text, df_table_to_text], ignore_index=True)

    df_text = df_text.rename(columns={"ocr_box": "box"})

    # Bucket box centres into a coarse grid (10 units in x, 20 in y) and
    # order the items by (y, x).
    df_text["x"] = df_text["box"].apply(lambda box: (box[0] + box[2]) / 2)
    df_text["y"] = df_text["box"].apply(lambda box: (box[1] + box[3]) / 2)
    df_text["x"] = (df_text["x"] - df_text["x"].min()) // 10
    df_text["y"] = (df_text["y"] - df_text["y"].min()) // 20
    df_text = df_text.sort_values(["y", "x"], ignore_index=True)

    rows_list = []
    for _, df_row in df_text.groupby("y"):
        if df_row["is_table"].values.any():
            table = df_row[df_row["is_table"]]
            df_row = df_row[~df_row["is_table"]]
        else:
            table = None

        # Several free-text items in the same y-bucket become one entry.
        if len(df_row) > 1:
            df_row = df_row.reset_index(drop=True)
            df_row["text"] = "\n".join(df_row["text"].values.tolist())
        rows_list.append(df_row.head(1))

        if table is not None:
            rows_list.append(table)

    df_display = pd.concat(rows_list, ignore_index=True)
    return "\n".join(df_display.text.values.tolist())
[docs]
def build_markdown(df: pd.DataFrame) -> list:
    """Convert a dataframe with row_ids/col_ids/text into a markdown matrix.

    A grid of empty strings sized by the largest row and column index is
    filled by appending each item's text (space separated) to every
    ``(row, column)`` pair it is assigned to.  Entries whose ``row_ids`` or
    ``col_ids`` is a plain ``int`` are skipped.  Rows and then columns that
    contain only empty strings are removed.
    """
    df = df.reset_index(drop=True)

    last_col = max(np.max(ids) for ids in df["col_ids"].values)
    last_row = max(np.max(ids) for ids in df["row_ids"].values)
    grid = [[""] * (int(last_col) + 1) for _ in range(int(last_row) + 1)]

    entries = zip(df["row_ids"].values, df["col_ids"].values, df["text"].values)
    for row_ids, col_ids, text in entries:
        if isinstance(row_ids, int) or isinstance(col_ids, int):
            continue
        for r in row_ids:
            for c in col_ids:
                grid[r][c] = (grid[r][c] + " " + text).strip()

    # Drop empty rows, then drop empty columns by filtering the transpose.
    grid = remove_empty_row(grid)
    transposed = np.array(grid).T.tolist()
    return np.array(remove_empty_row(transposed)).T.tolist()
[docs]
def display_markdown(data: list, use_header: bool = False) -> str:
    """Convert a list-of-lists into a markdown table string.

    Short rows are right-padded with empty strings to the widest row.  With
    ``use_header`` the first row is followed by a ``---`` separator line;
    otherwise every row is rendered as a plain table line.  Empty input
    yields the literal ``"EMPTY TABLE"``.
    """
    if len(data) == 0:
        return "EMPTY TABLE"

    width = max(len(row) for row in data)
    padded = [row + [""] * (width - len(row)) for row in data]

    def render(cells: list) -> str:
        return "| " + " | ".join(cells) + " |"

    if not use_header:
        return "\n".join(render(row) for row in padded)

    lines = [render(padded[0]), render(["---"] * width)]
    lines.extend(render(row) for row in padded[1:])
    return "\n".join(lines)
[docs]
def merge_text_in_cell(df_cell: pd.DataFrame) -> pd.DataFrame:
    """Merge text from multiple OCR items inside one table cell.

    The items are ordered by the top-left corner of their ``ocr_box``,
    bucketed in steps of 10 (y first, then x).  Their texts are joined with
    single spaces into the first item, whose ``ocr_box`` is replaced by its
    ``cell`` box.  A one-row frame is returned; the input is not modified.
    """
    corners = np.stack(df_cell["ocr_box"].values)
    left, top = corners[:, 0], corners[:, 1]

    ordered = df_cell.assign(
        x=(left - left.min()) // 10,
        y=(top - top.min()) // 10,
    ).sort_values(["y", "x"])

    ordered["text"] = " ".join(ordered["text"].values.tolist())

    merged = ordered.head(1).copy()
    merged["ocr_box"] = merged["cell"]
    return merged.drop(columns=["x", "y"])
[docs]
def remove_empty_row(mat: list) -> list:
    """Remove empty rows from a matrix.

    A row is kept when at least one of its cells has non-zero length.  Rows
    with no cells at all are treated as empty and dropped.  The input is not
    modified; a new list holding the kept row objects is returned.
    """
    # any() handles a zero-length row, where max() of an empty list would raise.
    return [row for row in mat if any(len(cell) for cell in row)]
[docs]
def reorder_boxes(
    boxes: np.ndarray,
    texts: list,
    confs: list,
    mode: str = "top_left",
    dbscan_eps: float = 10,
) -> Tuple[list, list, list]:
    """Reorder OCR boxes in reading order using DBSCAN clustering.

    Each box is indexed as a sequence of ``(x, y)`` points.  With
    ``mode="top_left"`` the anchor is point 0; with ``mode="center"`` it is
    the midpoint of points 0 and 2.

    When ``dbscan_eps`` is truthy the anchor y-values are clustered with
    scikit-learn's ``DBSCAN`` (``min_samples=1``) and the items are sorted by
    the integer-truncated mean y of their cluster, then by x.  If scikit-learn
    cannot be imported, a ``ValueError`` is raised on that path, or
    ``dbscan_eps`` is falsy, the items are instead sorted by y bucketed in
    steps of 5 (relative to the smallest y), then by x.

    Returns the reordered ``(boxes, texts, confs)`` as three lists.
    """
    frame = pd.DataFrame(
        [[box, text, conf] for box, text, conf in zip(boxes, texts, confs)],
        columns=["bbox", "text", "conf"],
    )

    if mode == "center":
        frame["x"] = frame["bbox"].apply(lambda quad: (quad[0][0] + quad[2][0]) / 2)
        frame["y"] = frame["bbox"].apply(lambda quad: (quad[0][1] + quad[2][1]) / 2)
    elif mode == "top_left":
        frame["x"] = frame["bbox"].apply(lambda quad: quad[0][0])
        frame["y"] = frame["bbox"].apply(lambda quad: quad[0][1])

    clustered = False
    if dbscan_eps:
        try:
            from sklearn.cluster import DBSCAN

            model = DBSCAN(eps=dbscan_eps, min_samples=1)
            model.fit(frame["y"].values[:, None])
            frame["cluster"] = model.labels_
            frame["cluster_centers"] = frame.groupby("cluster")["y"].transform("mean").astype(int)
            frame = frame.sort_values(["cluster_centers", "x"], ascending=[True, True], ignore_index=True)
            clustered = True
        except (ImportError, ValueError):
            # Fall through to the bucket-based ordering below.
            clustered = False

    if not clustered:
        frame["y"] = np.round((frame["y"] - frame["y"].min()) // 5, 0)
        frame = frame.sort_values(["y", "x"], ascending=[True, True], ignore_index=True)

    return (
        frame["bbox"].values.tolist(),
        frame["text"].values.tolist(),
        frame["conf"].values.tolist(),
    )
# ---------------------------------------------------------------------------
# Adapter functions (retriever formats → `nemo_retriever.api` formats)
# ---------------------------------------------------------------------------
def _normalize_ocr_items(preds: Any) -> List[Dict[str, Any]]:
    """Normalize any OCR output format to ``[{"left", "right", "upper", "lower", "text"}, ...]``.

    Two input layouts are recognised:

    * a ``list`` of dicts that each carry ``left``/``right``/``upper``/``lower``
      (converted with ``float``) and an optional ``text``; entries that are
      not dicts, lack a coordinate key, or whose stripped text is empty or
      ``"nan"`` are skipped;
    * a ``dict`` with ``boxes`` (or ``bboxes``) and ``texts`` (or ``text``)
      lists, where each box is either four ``(x, y)`` points (reduced to
      their extent) or four numbers read as ``left, upper, right, lower``;
      pairs with non-string or blank text, or with any other box layout,
      are skipped.

    Any other input yields an empty list.
    """
    normalized: List[Dict[str, Any]] = []

    if isinstance(preds, list):
        required = ("left", "right", "upper", "lower")
        for entry in preds:
            if not isinstance(entry, dict):
                continue
            if any(key not in entry for key in required):
                continue
            text = str(entry.get("text") or "").strip()
            if text in ("", "nan"):
                continue
            normalized.append(
                {
                    "left": float(entry["left"]),
                    "right": float(entry["right"]),
                    "upper": float(entry["upper"]),
                    "lower": float(entry["lower"]),
                    "text": text,
                }
            )
        return normalized

    if not isinstance(preds, dict):
        return normalized

    raw_boxes = preds.get("boxes") or preds.get("bboxes") or []
    raw_texts = preds.get("texts") or preds.get("text") or []
    if not (isinstance(raw_boxes, list) and isinstance(raw_texts, list)):
        return normalized

    for box, text in zip(raw_boxes, raw_texts):
        if not isinstance(text, str) or not text.strip():
            continue
        if not (isinstance(box, (list, tuple)) and len(box) == 4):
            continue

        if all(isinstance(point, (list, tuple)) and len(point) == 2 for point in box):
            # Four (x, y) points: keep their axis-aligned extent.
            xs = [float(point[0]) for point in box]
            ys = [float(point[1]) for point in box]
            normalized.append(
                {
                    "left": min(xs),
                    "right": max(xs),
                    "upper": min(ys),
                    "lower": max(ys),
                    "text": text.strip(),
                }
            )
        elif all(isinstance(value, (int, float)) for value in box):
            # Four scalars are read in (left, upper, right, lower) order.
            normalized.append(
                {
                    "left": float(box[0]),
                    "upper": float(box[1]),
                    "right": float(box[2]),
                    "lower": float(box[3]),
                    "text": text.strip(),
                }
            )

    return normalized
def _ocr_items_to_pixel_quad_boxes(
    ocr_items: List[Dict[str, Any]],
    crop_hw: Tuple[int, int],
) -> Tuple[list, list]:
    """Convert normalized OCR items to pixel-coordinate quadrilateral boxes.

    ``left``/``right`` are multiplied by the crop width and ``upper``/``lower``
    by the crop height.  Returns ``(quad_boxes, texts)`` where each quad is
    ``[[left, upper], [right, upper], [right, lower], [left, lower]]`` and
    *texts* holds the matching ``text`` values in the same order.
    """
    height, width = crop_hw
    quads: list = []
    strings: list = []
    for entry in ocr_items:
        x_min = float(entry["left"]) * width
        x_max = float(entry["right"]) * width
        y_min = float(entry["upper"]) * height
        y_max = float(entry["lower"]) * height
        quads.append([[x_min, y_min], [x_max, y_min], [x_max, y_max], [x_min, y_max]])
        strings.append(entry["text"])
    return quads, strings
def _structure_dets_to_class_boxes(
    dets: List[Dict[str, Any]],
    crop_hw: Tuple[int, int],
) -> Dict[str, np.ndarray]:
    """Group structure-model detections by label_name and scale to pixel coords.

    Parameters
    ----------
    dets : list[dict]
        Detections; each dict is read for ``bbox_xyxy_norm`` (four values)
        and ``label_name`` (defaults to ``""`` when absent).
    crop_hw : (int, int)
        ``(H, W)`` of the crop image; x values are multiplied by ``W`` and
        y values by ``H``.

    Returns
    -------
    dict[str, ndarray]
        ``{label_name: array_of_shape_(N, 4)}`` in pixel coordinates.
        Detections with a missing/empty box or a box that does not have
        exactly four values are skipped.
    """
    height, width = crop_hw
    per_label: Dict[str, list] = {}

    for det in dets:
        label = det.get("label_name", "")
        norm_box = det.get("bbox_xyxy_norm")
        if not norm_box:
            continue
        if len(norm_box) != 4:
            continue

        scaled = [
            float(norm_box[0]) * width,
            float(norm_box[1]) * height,
            float(norm_box[2]) * width,
            float(norm_box[3]) * height,
        ]
        if label not in per_label:
            per_label[label] = []
        per_label[label].append(scaled)

    return {label: np.array(rows) for label, rows in per_label.items()}
[docs]
def join_table_structure_and_ocr_output(
    structure_dets: List[Dict[str, Any]],
    ocr_preds: Any,
    crop_hw: Tuple[int, int],
) -> str:
    """Adapter: turn table-structure detections and OCR output into table text.

    The OCR output is normalized, the detections are grouped per label in
    pixel coordinates, and both are handed to the core table-joining routine.

    Parameters
    ----------
    structure_dets : list[dict]
        Detections with ``label_name`` (``cell``/``row``/``column`` are used)
        and ``bbox_xyxy_norm``.
    ocr_preds : list | dict
        OCR output in either form accepted by ``_normalize_ocr_items``.
    crop_hw : (int, int)
        ``(H, W)`` of the crop image.

    Returns
    -------
    str
        The reconstructed content; ``""`` when no usable OCR item exists or
        when there is no ``cell`` detection.
    """
    items = _normalize_ocr_items(ocr_preds)
    if not items:
        return ""

    by_label = _structure_dets_to_class_boxes(structure_dets, crop_hw)

    # Ensure all three required keys exist, defaulting to an empty (0, 4) array.
    cell_preds: Dict[str, np.ndarray] = {
        label: by_label.get(label, np.empty((0, 4))) for label in ("cell", "row", "column")
    }
    if cell_preds["cell"].shape[0] == 0:
        return ""

    quads, strings = _ocr_items_to_pixel_quad_boxes(items, crop_hw)
    return _join_yolox_table_structure_and_ocr_output(cell_preds, quads, strings)
[docs]
def join_graphic_elements_and_ocr_output(
    ge_dets: List[Dict[str, Any]],
    ocr_preds: Any,
    crop_hw: Tuple[int, int],
) -> str:
    """Adapter: turn graphic-element detections and OCR output into chart text.

    The OCR output is normalized, the detections are grouped per label in
    pixel coordinates, each detection is matched to OCR text, and the
    per-class texts are concatenated into one string.

    Parameters
    ----------
    ge_dets : list[dict]
        Detections with chart-element ``label_name`` values and
        ``bbox_xyxy_norm``.
    ocr_preds : list | dict
        OCR output in either form accepted by ``_normalize_ocr_items``.
    crop_hw : (int, int)
        ``(H, W)`` of the crop image.

    Returns
    -------
    str
        The concatenated chart text; ``""`` when there is no usable OCR item,
        no usable detection, or no matched class.
    """
    items = _normalize_ocr_items(ocr_preds)
    if not items:
        return ""

    by_label = _structure_dets_to_class_boxes(ge_dets, crop_hw)
    if not by_label:
        return ""

    # Split each (N, 4) array into a list of N per-detection arrays.
    yolox_output: Dict[str, list] = {label: list(arr) for label, arr in by_label.items()}

    quads, strings = _ocr_items_to_pixel_quad_boxes(items, crop_hw)
    matched = _join_yolox_graphic_elements_and_ocr_output(yolox_output, quads, strings)
    return process_yolox_graphic_elements(matched) if matched else ""
[docs]
def reorder_ocr_for_infographic(
    ocr_preds: Any,
    crop_hw: Tuple[int, int],
) -> str:
    """Adapter: return the OCR text of an infographic in reading order.

    The OCR output is normalized and converted to pixel-coordinate quads,
    reordered with ``reorder_boxes`` (``mode="top_left"``, ``dbscan_eps=10``,
    a constant confidence of ``1.0`` per item), and the non-empty texts are
    joined with newlines.  Returns ``""`` when there is nothing to order.
    """
    items = _normalize_ocr_items(ocr_preds)
    if not items:
        return ""

    quads, strings = _ocr_items_to_pixel_quad_boxes(items, crop_hw)
    if not quads:
        return ""

    # reorder_boxes requires confidences; they are not used for ordering here.
    placeholder_confs = [1.0] * len(strings)
    ordered = reorder_boxes(
        np.array(quads),
        strings,
        placeholder_confs,
        mode="top_left",
        dbscan_eps=10,
    )
    ordered_texts = ordered[1]
    return "\n".join(text for text in ordered_texts if text)