# Source code for nemo_retriever.utils.detection_summary

# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Shared detection summary logic.

Provides a single function that accumulates per-page detection counters from
an iterable of ``(page_key, metadata_dict, row_dict)`` tuples.  Both the
batch pipeline (reading from LanceDB) and inprocess pipeline (reading from
a DataFrame) can produce these tuples, allowing the summary computation to
be shared.
"""

from __future__ import annotations

from datetime import datetime
import json
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Tuple


def _safe_int(value: object, default: int = 0) -> int:
    """Convert *value* to ``int`` on a best-effort basis.

    Returns *default* when *value* is ``None`` or when ``int(value)`` raises
    any exception (non-numeric strings, NaN, containers, ...).
    """
    if value is None:
        return default
    try:
        converted = int(value)
    except Exception:
        return default
    return converted


# Sources for each OCR counter: (summary field, metadata key, raw-row column).
# The metadata key is preferred; when it yields 0, the length of the
# list-valued raw-row column is used instead.
_OCR_COUNTER_SOURCES: Tuple[Tuple[str, str, str], ...] = (
    ("ocr_table", "ocr_table_detections", "table"),
    ("ocr_chart", "ocr_chart_detections", "chart"),
    ("ocr_infographic", "ocr_infographic_detections", "infographic"),
)


def _list_like_len(value: object) -> int:
    """Return the number of items in a list-like column value, else 0.

    Lists and tuples are counted directly. Objects exposing a callable
    ``tolist()`` (e.g. NumPy arrays, which list columns may become after a
    pandas/Arrow round-trip) are converted first. Anything that does not end
    up as a list/tuple (``None``, NaN, scalars, strings) counts as 0.
    """
    if not isinstance(value, (list, tuple)):
        to_list = getattr(value, "tolist", None)
        if not callable(to_list):
            return 0
        try:
            value = to_list()
        except Exception:
            return 0
    return len(value) if isinstance(value, (list, tuple)) else 0


def _new_page_entry() -> Dict[str, Any]:
    """Return a zeroed per-page counter record."""
    return {
        "pe": 0,
        "ocr_table": 0,
        "ocr_chart": 0,
        "ocr_infographic": 0,
        "pe_by_label": defaultdict(int),
    }


def _merge_row_into_entry(entry: Dict[str, Any], meta: Dict[str, Any], raw_row: Dict[str, Any]) -> None:
    """Fold one row's counters into *entry*, keeping the per-page maximum.

    Several rows can share a page key; taking the max instead of the sum is
    what deduplicates them.
    """
    pe = _safe_int(meta.get("page_elements_v3_num_detections") or raw_row.get("page_elements_v3_num_detections"))
    entry["pe"] = max(entry["pe"], pe)

    for field, meta_key, col_key in _OCR_COUNTER_SOURCES:
        val = _safe_int(meta.get(meta_key))
        if val == 0:
            val = _list_like_len(raw_row.get(col_key))
        entry[field] = max(entry[field], val)

    label_counts = meta.get("page_elements_v3_counts_by_label") or raw_row.get("page_elements_v3_counts_by_label")
    if isinstance(label_counts, dict):
        by_label = entry["pe_by_label"]
        for label, count in label_counts.items():
            key = str(label)
            by_label[key] = max(by_label[key], _safe_int(count))


def _sum_page_entries(per_page: Dict[Any, Dict[str, Any]]) -> Dict[str, Any]:
    """Sum the per-page counter records into the public summary dict."""
    pe_by_label_totals: Dict[str, int] = defaultdict(int)
    pe_total = ocr_table_total = ocr_chart_total = ocr_infographic_total = 0
    for e in per_page.values():
        pe_total += e["pe"]
        ocr_table_total += e["ocr_table"]
        ocr_chart_total += e["ocr_chart"]
        ocr_infographic_total += e["ocr_infographic"]
        for label, count in e["pe_by_label"].items():
            pe_by_label_totals[label] += count
    return {
        "pages_seen": len(per_page),
        "page_elements_v3_total_detections": pe_total,
        "page_elements_v3_counts_by_label": dict(sorted(pe_by_label_totals.items())),
        "ocr_table_total_detections": ocr_table_total,
        "ocr_chart_total_detections": ocr_chart_total,
        "ocr_infographic_total_detections": ocr_infographic_total,
    }


def compute_detection_summary(
    rows: Iterable[Tuple[Any, Dict[str, Any], Dict[str, Any]]],
) -> Dict[str, Any]:
    """Compute deduped detection totals from an iterable of page data.

    Each element is ``(page_key, metadata_dict, row_dict)`` where:

    - *page_key* is a hashable value used to deduplicate exploded content
      rows (e.g. ``(source_id, page_number)``).
    - *metadata_dict* is the parsed JSON metadata (may contain counters from
      the LanceDB metadata column or from direct DataFrame columns).
    - *row_dict* is the raw row dict, used as fallback for counters stored as
      top-level DataFrame columns (e.g. ``table``, ``chart`` lists). Tuples
      and array-likes exposing ``tolist()`` are counted the same as lists.

    For each page the maximum of every counter across its rows is kept; the
    returned dict sums those per-page maxima and reports ``pages_seen``.
    Label counts are returned sorted by label.
    """
    per_page: Dict[Any, Dict[str, Any]] = {}
    for page_key, meta, raw_row in rows:
        entry = per_page.get(page_key)
        if entry is None:
            # Only allocate a fresh record for pages we have not seen yet.
            entry = per_page[page_key] = _new_page_entry()
        _merge_row_into_entry(entry, meta, raw_row)
    return _sum_page_entries(per_page)
def _first_present(*values: Any) -> Any:
    """Return the first truthy value, treating NaN and pandas ``NA`` as unset.

    ``float("nan")`` is truthy, so a plain ``a or b`` chain would pick a NaN
    placeholder over a real fallback. ``pandas.NA`` raises on truth-testing.
    Both are skipped here. Returns ``None`` when nothing qualifies.
    """
    for value in values:
        if value is None:
            continue
        if isinstance(value, float) and value != value:  # NaN is unequal to itself
            continue
        try:
            if not value:
                continue
        except (TypeError, ValueError):
            # Ambiguous truth value (e.g. pandas.NA); treat as missing.
            continue
        return value
    return None


def iter_dataframe_rows(df):
    """Yield ``(page_key, meta, row_dict)`` tuples from a pandas DataFrame.

    - ``page_key`` is ``(path, page_number)``.
      - ``path`` comes from the ``path`` column, falling back to
        ``source_id`` and then ``""``. Missing placeholders (``None``, NaN,
        pandas ``NA``, empty strings) are skipped, so a NaN ``path`` does not
        become the literal key ``"nan"``.
      - ``page_number`` is coerced with ``_safe_int``; it is ``-1`` when
        absent or not convertible.
    - ``meta`` is the ``metadata`` column, JSON-decoded when it is a string.
      Anything that is not, or does not decode to, a dict becomes ``{}``.
    - ``row_dict`` is the full row as a plain dict.
    """
    for _, row in df.iterrows():
        row_dict = row.to_dict()
        source = _first_present(row_dict.get("path"), row_dict.get("source_id"))
        path = "" if source is None else str(source)
        page_number = _safe_int(row_dict.get("page_number", -1), default=-1)
        meta = row_dict.get("metadata")
        if isinstance(meta, str):
            try:
                meta = json.loads(meta)
            except (ValueError, RecursionError):
                # json.JSONDecodeError subclasses ValueError: malformed metadata
                # is treated as "no metadata" rather than aborting the scan.
                meta = {}
        if not isinstance(meta, dict):
            meta = {}
        yield (path, page_number), meta, row_dict
def collect_detection_summary_from_lancedb(uri: str, table_name: str) -> Optional[Dict[str, Any]]:
    """Collect detection summary from a LanceDB table.

    Best-effort: any failure while importing the LanceDB reader, reading the
    table, or summarizing its rows is logged at WARNING level with its
    traceback, and ``None`` is returned instead of raising.
    """
    import logging

    try:
        from nemo_retriever.vdb.lancedb_read import iter_lancedb_rows

        return compute_detection_summary(iter_lancedb_rows(uri, table_name))
    except Exception:
        # Keep the historical best-effort contract (return None), but leave a
        # trace so failures are diagnosable instead of silently disappearing.
        logging.getLogger(__name__).warning(
            "Failed to collect detection summary from LanceDB table %r at %r",
            table_name,
            uri,
            exc_info=True,
        )
        return None
def collect_detection_summary_from_df(df) -> Dict[str, Any]:
    """Summarize detections for every row of the pandas DataFrame *df*.

    The rows are turned into ``(page_key, meta, row_dict)`` tuples by
    ``iter_dataframe_rows`` and reduced by ``compute_detection_summary``.
    """
    page_rows = iter_dataframe_rows(df)
    summary = compute_detection_summary(page_rows)
    return summary
def write_detection_summary(path: Path, summary: Optional[Dict[str, Any]]) -> None:
    """Serialize *summary* as pretty-printed, key-sorted JSON at *path*.

    ``~`` in *path* is expanded, the path is resolved, and missing parent
    directories are created. A ``None`` summary is written as an
    ``{"error": ...}`` placeholder object so the output file always exists.
    """
    if summary is None:
        summary = {"error": "Detection summary unavailable."}
    destination = Path(path).expanduser().resolve()
    destination.parent.mkdir(parents=True, exist_ok=True)
    text = json.dumps(summary, indent=2, sort_keys=True)
    destination.write_text(text, encoding="utf-8")
def _fmt_time(seconds: float) -> str:
    """Format *seconds* as ``raw / H:MM:SS.mmm``.

    The raw part is *seconds* with two decimals. The clock part is rounded
    to the nearest millisecond. Negative durations get a leading ``-`` on
    the clock part; without that, ``divmod`` of a negative count would
    produce wrapped-around values such as ``-1:59:58.500``.
    """
    ms = int(round(seconds * 1000))
    sign = "-" if ms < 0 else ""
    h, remainder = divmod(abs(ms), 3_600_000)
    m, remainder = divmod(remainder, 60_000)
    s, millis = divmod(remainder, 1000)
    return f"{seconds:.2f}s / {sign}{h}:{m:02d}:{s:02d}.{millis:03d}"


def _evaluation_metric_sort_key(item: tuple[str, float]) -> tuple[str, int, str]:
    """Sort metrics like ndcg@1, ndcg@3, ..., recall@1, recall@3, ... .

    *item* is a ``(metric_key, value)`` pair; only the key is used. Keys with
    an integer ``@k`` suffix sort numerically by ``k`` within their metric
    name. Keys without ``@``, or with a non-integer suffix, use 0 as the
    numeric part. The full key is the final tie-breaker.
    """
    key, _value = item
    metric_name, sep, suffix = str(key).partition("@")
    if sep:
        try:
            return metric_name, int(suffix), str(key)
        except ValueError:
            pass
    return metric_name, 0, str(key)