Source code for nv_ingest_client.client.ingest_job_handler

# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import json
import logging
import time
import os
import io
import base64
from collections import defaultdict
from typing import Any, Dict, List, Tuple

from tqdm import tqdm

# Reuse existing CLI utilities to avoid duplicating behavior
from concurrent.futures import as_completed
from nv_ingest_client.util.util import check_ingest_result
from PIL import Image

logger = logging.getLogger(__name__)


class IngestJobHandler:
    """
    A modular job handler that mirrors the CLI's create_and_process_jobs flow,
    so the same proven scheduling/retry behavior can be reused by other entry points.

    Usage:
        handler = IngestJobHandler(client, files, tasks, output_dir, batch_size)
        total_files, trace_times, total_pages, trace_ids = handler.run()
    """

    def __init__(
        self,
        client: Any,
        files: List[str],
        tasks: Dict[str, Any],
        output_directory: str,
        batch_size: int,
        fail_on_error: bool = False,
        save_images_separately: bool = False,
        show_progress: bool = True,
        show_telemetry: bool = False,
        job_queue_id: str = "ingest_task_queue",
        pdf_split_page_count: int | None = None,
    ) -> None:
        """
        Store the configuration and reset the per-run bookkeeping.

        Parameters
        ----------
        client : Any
            Client object. This class calls ``create_jobs_for_batch``,
            ``submit_job_async``, ``fetch_job_result_async`` and, when present,
            ``summarize_telemetry`` on it.
        files : List[str]
            Source file paths to ingest.
        tasks : Dict[str, Any]
            Task definitions forwarded unchanged to ``client.create_jobs_for_batch``.
        output_directory : str
            Directory for per-type metadata JSON output. A falsy value disables saving.
        batch_size : int
            Upper bound on newly created jobs per iteration (retries count toward it).
        fail_on_error : bool, optional
            If True, raise when job specs cannot be created for some files.
        save_images_separately : bool, optional
            If True, decoded images are written to disk next to the metadata.
        show_progress : bool, optional
            If True, display a tqdm progress bar while running.
        show_telemetry : bool, optional
            If True, log the client's telemetry summary at the end of ``run``.
        job_queue_id : str, optional
            Queue name passed to ``client.submit_job_async``.
        pdf_split_page_count : int | None, optional
            Forwarded unchanged to ``client.create_jobs_for_batch``.
        """
        self.client = client
        self.files = files
        self.tasks = tasks
        self.output_directory = output_directory
        self.batch_size = batch_size
        self.fail_on_error = fail_on_error
        self.save_images_separately = save_images_separately
        self.show_progress = show_progress
        self.show_telemetry = show_telemetry
        self.job_queue_id = job_queue_id
        self.pdf_split_page_count = pdf_split_page_count
        self._pbar = None

        # Internal state used across iterations
        self._retry_job_ids: List[str] = []
        self._processed: int = 0
        self._job_ids_batch: List[str] = []
        self._job_id_map: Dict[str, str] = {}
        self._trace_times: Dict[str, List[float]] = defaultdict(list)

        # Constants
        self._IMAGE_TYPES: set = {"png", "bmp", "jpeg", "jpg", "tiff"}

    # ---------------------------
    # Progress bar helpers
    # ---------------------------
    def _init_progress_bar(self, total: int) -> None:
        """Create the progress bar when ``show_progress`` is set, otherwise clear it."""
        if self.show_progress:
            self._pbar = tqdm(total=total, desc="Processing files", unit="file")
        else:
            self._pbar = None

    def _update_progress(self, n: int = 1, pages_per_sec: float | None = None) -> None:
        """Advance the progress bar by ``n`` and optionally show the throughput postfix."""
        # Compare against None explicitly: a tqdm bar's truthiness follows its
        # ``total``, so a bar created for zero files would evaluate as False.
        if self._pbar is None:
            return
        if pages_per_sec is not None:
            self._pbar.set_postfix(pages_per_sec=f"{pages_per_sec:.2f}")
        self._pbar.update(n)

    def _close_progress_bar(self) -> None:
        """Close and drop the progress bar if one exists."""
        if self._pbar is not None:
            self._pbar.close()
            self._pbar = None

    # ---------------------------
    # Small pure helpers
    # ---------------------------
    @staticmethod
    def _resolve_image_format(image_type: str) -> Tuple[str, str]:
        """
        Map a lower-case image type to ``(file_extension, pillow_format)``.

        The extension keeps the historical naming (``jpeg`` is written as ``.jpg``).
        The Pillow format name for both ``jpg`` and ``jpeg`` is ``"JPEG"``; every
        other supported type uses its upper-cased name.
        """
        image_ext = "jpg" if image_type == "jpeg" else image_type
        pil_format = "JPEG" if image_ext == "jpg" else image_ext.upper()
        return image_ext, pil_format

    @staticmethod
    def _extract_page_count(response: Dict[str, Any]) -> int:
        """
        Return the page count reported by a job response, or 0 when absent.

        The top-level ``metadata.total_pages`` value is preferred; otherwise the
        first data element's ``content_metadata.hierarchy.page_count`` is used.
        """
        page_count = None

        # Try V2 metadata location first (top-level metadata.total_pages)
        response_metadata = response.get("metadata")
        if response_metadata:
            page_count = response_metadata.get("total_pages")

        # Fall back to V1 location (first data element's hierarchy.page_count)
        if page_count is None and response.get("data"):
            try:
                first_page_metadata = response["data"][0]["metadata"]
                page_count = first_page_metadata["content_metadata"]["hierarchy"]["page_count"]
            except (KeyError, IndexError, TypeError):
                # Nothing usable at the V1 location either; use the default.
                pass

        return 0 if page_count is None else page_count

    def _generate_job_batch_for_iteration(self) -> None:
        """
        Build the next batch of jobs for processing and submit newly created jobs.

        Pending retry jobs are placed first, then new jobs are created for the
        next unprocessed files until the batch holds ``batch_size`` jobs. Only
        the newly created jobs are submitted; retry jobs are carried over as-is.

        Side Effects
        ------------
        - Overwrites ``self._job_ids_batch`` with the ordered job indices for this
          iteration (``retry`` first, then newly created jobs).
        - Merges new job-index -> source-file entries into ``self._job_id_map``.
        - Advances ``self._processed`` by the number of files attempted, even if
          some job specs are missing (unless ``self.fail_on_error`` raises first).
        - Advances the progress bar by the number of missing job specs.
        - Calls ``self.client.create_jobs_for_batch(...)`` and
          ``self.client.submit_job_async(..., self.job_queue_id)``.

        Raises
        ------
        RuntimeError
            If fewer job specs than requested were created and
            ``self.fail_on_error`` is True.

        Notes
        -----
        - This method does not fetch results.
        - NOTE(review): when job specs are missing, the index -> file mapping is
          built by pairing the returned indices with the leading files of the
          batch, which presumably can mislabel jobs if an earlier file is the one
          that failed -- confirm against ``create_jobs_for_batch``.

        Examples
        --------
        >>> handler = IngestJobHandler(client, files, tasks, "/tmp/out", batch_size=32)
        >>> handler._generate_job_batch_for_iteration()
        >>> len(handler._job_ids_batch) <= 32
        True
        """
        job_indices: List[str] = list(self._retry_job_ids)
        job_index_map_updates: Dict[str, str] = {}
        cur_job_count: int = len(job_indices)

        remaining_files: int = len(self.files) - self._processed
        if cur_job_count < self.batch_size and remaining_files > 0:
            new_job_count: int = min(self.batch_size - cur_job_count, remaining_files)
            batch_files: List[str] = self.files[self._processed : self._processed + new_job_count]

            new_job_indices: List[str] = self.client.create_jobs_for_batch(
                batch_files, self.tasks, pdf_split_page_count=self.pdf_split_page_count
            )

            if len(new_job_indices) != new_job_count:
                missing_jobs: int = new_job_count - len(new_job_indices)
                error_msg: str = (
                    f"Missing {missing_jobs} job specs -- this is likely due to bad reads or file corruption"
                )
                logger.warning(error_msg)
                if self.fail_on_error:
                    raise RuntimeError(error_msg)
                # Files without a job spec will never complete; account for them now.
                self._update_progress(missing_jobs)

            job_index_map_updates = dict(zip(new_job_indices, batch_files))
            self._processed += new_job_count

            # Submit newly created jobs asynchronously to the configured queue
            if new_job_indices:
                self.client.submit_job_async(new_job_indices, self.job_queue_id)
            job_indices.extend(new_job_indices)

        # Save into class state
        self._job_ids_batch = job_indices
        # Merge new mappings (do not drop existing entries for retry jobs)
        self._job_id_map.update(job_index_map_updates)

    def _handle_future_result(self, future, timeout: int = 10):
        """
        Handle the result of a completed future job and process annotations.

        Parameters
        ----------
        future : concurrent.futures.Future
            Future representing an asynchronous job.
        timeout : int, optional
            Maximum seconds to wait for the future result.

        Returns
        -------
        Tuple[Dict[str, Any], str]
            The decoded result dictionary and the trace_id for the job.

        Raises
        ------
        RuntimeError
            If the job result indicates failure per check_ingest_result.
        """
        # The future resolves to a sequence whose first item is a 3-tuple;
        # the middle element is not used here.
        result, _, trace_id = future.result(timeout=timeout)[0]

        if ("annotations" in result) and result["annotations"]:
            # Serialising annotations is only worth the cost when it will be emitted.
            if logger.isEnabledFor(logging.DEBUG):
                for key, value in result["annotations"].items():
                    logger.debug(f"Annotation: {key} -> {json.dumps(value, indent=2)}")

        failed, description = check_ingest_result(result)
        if failed:
            raise RuntimeError(f"Ingest job failed: {description}")

        return result, trace_id

    def _process_response(self, response: Dict[str, Any]) -> None:
        """
        Extract trace timing entries from a response and accumulate per-stage
        elapsed times into ``self._trace_times``.

        Parameters
        ----------
        response : Dict[str, Any]
            Full response payload containing an optional ``trace`` dictionary
            with entry/exit timestamps. A missing or ``None`` trace is ignored.
        """
        trace_data: Dict[str, Any] = response.get("trace") or {}

        for key, entry_time in trace_data.items():
            if "entry" not in key:
                continue
            exit_time: Any = trace_data.get(key.replace("entry", "exit"))
            if not exit_time:
                continue
            # The stage name is the third "::"-separated component of the key.
            stage_parts = key.split("::")
            if len(stage_parts) >= 3:
                self._trace_times[stage_parts[2]].append(exit_time - entry_time)

    def _save_response_data(
        self, response: Dict[str, Any], output_directory: str, images_to_disk: bool = False
    ) -> None:
        """
        Save the response data into categorized metadata JSON files and
        optionally save images to disk.

        Documents are grouped by ``metadata.content_metadata.type`` (``"unknown"``
        when absent) and each group is written to
        ``<output_directory>/<type>/<source basename>.metadata.json``.

        Parameters
        ----------
        response : Dict[str, Any]
            Full response payload with a "data" list of documents.
        output_directory : str
            Output directory where per-type metadata JSON files (and any media) are written.
        images_to_disk : bool, optional
            If True, decode and write image contents to disk (for the ``image``
            and ``structured`` types), blank the inline content and record the
            file path under ``content_url``.
        """
        response_data = response.get("data")
        if not response_data:
            logger.debug("Data is not in the response or response.data is empty")
            return
        if not isinstance(response_data, list):
            logger.debug("Response data is not a list or the list is empty.")
            return

        # The output file name is derived from the first document's source id.
        doc_name = response_data[0]["metadata"]["source_metadata"]["source_id"]
        clean_doc_name = os.path.basename(doc_name)
        output_name = f"{clean_doc_name}.metadata.json"

        # Organize by document type
        doc_map: Dict[str, List[Dict[str, Any]]] = {}
        for document in response_data:
            content_meta: Dict[str, Any] = document.get("metadata", {}).get("content_metadata", {})
            doc_map.setdefault(content_meta.get("type", "unknown"), []).append(document)

        for doc_type, documents in doc_map.items():
            doc_type_path = os.path.join(output_directory, doc_type)
            os.makedirs(doc_type_path, exist_ok=True)

            if images_to_disk and doc_type in ("image", "structured"):
                for i, doc in enumerate(documents):
                    meta: Dict[str, Any] = doc.get("metadata", {})
                    image_content = meta.get("content")
                    # Only "image" documents carry an explicit type; others are treated as PNG.
                    if doc_type == "image":
                        image_type = meta.get("image_metadata", {}).get("image_type", "png").lower()
                    else:
                        image_type = "png"

                    if not image_content or image_type not in self._IMAGE_TYPES:
                        continue

                    try:
                        image = Image.open(io.BytesIO(base64.b64decode(image_content)))
                        # Pillow only knows the writer as "JPEG"; "JPG" is not a registered format.
                        image_ext, pil_format = self._resolve_image_format(image_type)
                        image_output_path = os.path.join(
                            doc_type_path, "media", f"{clean_doc_name}_{i}.{image_ext}"
                        )
                        os.makedirs(os.path.dirname(image_output_path), exist_ok=True)
                        image.save(image_output_path, format=pil_format)

                        meta["content"] = ""
                        meta["content_url"] = os.path.realpath(image_output_path)
                        logger.debug(f"Saved image to {image_output_path}")
                    except Exception as e:
                        # Best effort: a bad image must not prevent the metadata from being written.
                        logger.error(f"Failed to save image {i} for {clean_doc_name}: {e}")

            # Write the metadata JSON file for this type
            with open(os.path.join(doc_type_path, output_name), "w", encoding="utf-8") as f:
                json.dump(documents, f, indent=2)

    def run(self) -> Tuple[int, Dict[str, List[float]], int, Dict[str, str]]:
        """
        Process all configured files in batches until every job has completed
        or failed.

        Each iteration builds a batch (retries first, then new jobs), fetches
        the results asynchronously and handles them as they complete. Jobs that
        raise ``TimeoutError`` are queued for the next iteration; any other
        error marks the job as failed and processing continues.

        Returns
        -------
        Tuple[int, Dict[str, List[float]], int, Dict[str, str]]
            ``(total_files, trace_times, total_pages_processed, trace_ids)`` where
            ``trace_times`` maps stage name to elapsed times and ``trace_ids``
            maps source file to the trace id of its job.

        Raises
        ------
        ValueError
            If there are files to process but ``batch_size`` is smaller than 1,
            since no job could ever be created and the loop would never end.
        RuntimeError
            Propagated from batch creation when ``fail_on_error`` is True.
        """
        if self.files and self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1 to process files, got {self.batch_size}")

        total_files: int = len(self.files)
        total_pages_processed: int = 0
        # Kept as a defaultdict for backward compatibility with existing callers.
        trace_ids: Dict[str, str] = defaultdict(list)  # type: ignore
        failed_jobs: List[str] = []
        retry_counts: Dict[str, int] = defaultdict(int)

        start_time_ns: int = time.time_ns()
        self._init_progress_bar(total_files)
        try:
            self._processed = 0
            while (self._processed < len(self.files)) or self._retry_job_ids:
                # Create a batch (retries first, then new jobs up to batch_size)
                self._generate_job_batch_for_iteration()
                job_id_map = self._job_id_map
                self._retry_job_ids = []

                futures_dict: Dict[Any, str] = self.client.fetch_job_result_async(
                    self._job_ids_batch, data_only=False
                )
                for future in as_completed(futures_dict.keys()):
                    pages_per_sec: float | None = None
                    job_id: str = futures_dict[future]
                    # Resolve the source up front so the error handlers below cannot
                    # themselves fail on a missing mapping; fall back to the job id.
                    src_name: str = job_id_map.get(job_id, job_id)
                    try:
                        # Block as each future completes; this mirrors CLI behavior
                        future_response, trace_id = self._handle_future_result(future)
                        trace_ids[src_name] = trace_id

                        page_count = self._extract_page_count(future_response)

                        if self.output_directory:
                            self._save_response_data(
                                future_response,
                                self.output_directory,
                                images_to_disk=self.save_images_separately,
                            )

                        total_pages_processed += page_count
                        elapsed_time: float = (time.time_ns() - start_time_ns) / 1e9
                        if elapsed_time > 0:
                            pages_per_sec = total_pages_processed / elapsed_time

                        self._process_response(future_response)
                    except TimeoutError:
                        retry_counts[src_name] += 1
                        self._retry_job_ids.append(job_id)
                        logger.debug(
                            "Timed out waiting for %s(%s); retry #%d scheduled",
                            job_id,
                            src_name,
                            retry_counts[src_name],
                        )
                    except json.JSONDecodeError as e:
                        logger.error(f"Decoding error while processing {job_id}({src_name}): {e}")
                        failed_jobs.append(f"{job_id}::{src_name}")
                    except RuntimeError as e:
                        logger.error(f"Error while processing '{job_id}' - ({src_name}):\n{e}")
                        failed_jobs.append(f"{job_id}::{src_name}")
                    except Exception as e:
                        logger.exception(f"Unhandled error while processing {job_id}({src_name}): {e}")
                        failed_jobs.append(f"{job_id}::{src_name}")
                    finally:
                        # Do not update the pbar if this job is going to be retried
                        if job_id not in self._retry_job_ids:
                            self._update_progress(1, pages_per_sec)
        finally:
            self._close_progress_bar()

        if failed_jobs:
            logger.warning("%d job(s) failed: %s", len(failed_jobs), ", ".join(failed_jobs))

        # Optionally print telemetry summary
        if self.show_telemetry and hasattr(self.client, "summarize_telemetry"):
            try:
                summary = self.client.summarize_telemetry()
                logger.info("NvIngestClient Telemetry Summary: %s", json.dumps(summary, indent=2))
            except Exception:
                # Telemetry is best effort and must never fail the run, but leave a trace.
                logger.debug("Could not produce the telemetry summary", exc_info=True)

        return total_files, self._trace_times, total_pages_processed, trace_ids