# Source code for nv_ingest_client.primitives.tasks.extract

# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0


# pylint: disable=too-few-public-methods
# pylint: disable=too-many-arguments

import logging
import os
from typing import Any
from typing import Dict
from typing import Literal
from typing import Optional
from typing import get_args

from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import field_validator
from pydantic import model_validator

from .task_base import Task

logger = logging.getLogger(__name__)

# Settings for the "unstructured_io" extract method, captured once at import time.
# Each one falls back to the literal shown when the environment variable is unset.
UNSTRUCTURED_API_KEY = os.getenv("UNSTRUCTURED_API_KEY")
UNSTRUCTURED_URL = os.getenv("UNSTRUCTURED_URL", "https://api.unstructured.io/general/v0/general")
UNSTRUCTURED_STRATEGY = os.getenv("UNSTRUCTURED_STRATEGY", "auto")
# NOTE(review): this is a `str` when the variable is set and the `int` 10 otherwise;
# confirm the consumer accepts both before normalising the type.
UNSTRUCTURED_CONCURRENCY_LEVEL = os.getenv("UNSTRUCTURED_CONCURRENCY_LEVEL", 10)

# Credentials for the "adobe" extract method; None when not configured.
ADOBE_CLIENT_ID = os.getenv("ADOBE_CLIENT_ID")
ADOBE_CLIENT_SECRET = os.getenv("ADOBE_CLIENT_SECRET")

# Maps each supported document type to the extract method used when the caller
# does not choose one. The keys double as the list of supported document types:
# ExtractTaskSchema rejects any document type that is not a key here, and the
# key order is the order in which supported types are listed in error messages.
_DEFAULT_EXTRACTOR_MAP = {
    "bmp": "image",
    "csv": "pandas",
    "docx": "python_docx",
    "excel": "openpyxl",
    "html": "markitdown",
    "jpeg": "image",
    "jpg": "image",
    "parquet": "pandas",
    "pdf": "pdfium",
    "png": "image",
    "pptx": "python_pptx",
    "text": "txt",
    "tiff": "image",
    "txt": "txt",
    "xml": "lxml",
    "mp3": "audio",
    "wav": "audio",
    "json": "txt",
    "md": "txt",
    "sh": "txt",
}

# Literal aliases enumerating the extract methods that may be selected for each
# family of document types. They are used both as type annotations and, via
# `get_args`, as runtime lists of allowed values.
_Type_Extract_Method_PDF = Literal[
    "adobe",
    "nemoretriever_parse",
    "haystack",
    "llama_parse",
    "pdfium",
    "tika",
    "unstructured_io",
]

_Type_Extract_Method_DOCX = Literal["python_docx", "haystack", "unstructured_local", "unstructured_service"]

_Type_Extract_Method_PPTX = Literal["python_pptx", "haystack", "unstructured_local", "unstructured_service"]

_Type_Extract_Method_Image = Literal["image"]

_Type_Extract_Method_Audio = Literal["audio"]

_Type_Extract_Method_Text = Literal["txt"]

_Type_Extract_Method_Html = Literal["markitdown"]

# Document type -> tuple of extract methods accepted for that type.
# Not every key of _DEFAULT_EXTRACTOR_MAP appears here: csv, excel, parquet, xml,
# json, md and sh have a default extractor but no entry in this map.
_Type_Extract_Method_Map = {
    "bmp": get_args(_Type_Extract_Method_Image),
    "docx": get_args(_Type_Extract_Method_DOCX),
    "html": get_args(_Type_Extract_Method_Html),
    "jpeg": get_args(_Type_Extract_Method_Image),
    "jpg": get_args(_Type_Extract_Method_Image),
    "pdf": get_args(_Type_Extract_Method_PDF),
    "png": get_args(_Type_Extract_Method_Image),
    "pptx": get_args(_Type_Extract_Method_PPTX),
    "text": get_args(_Type_Extract_Method_Text),
    "tiff": get_args(_Type_Extract_Method_Image),
    "txt": get_args(_Type_Extract_Method_Text),
    "mp3": get_args(_Type_Extract_Method_Audio),
    "wav": get_args(_Type_Extract_Method_Audio),
}

# Literal aliases enumerating the table-extraction methods per document family.
_Type_Extract_Tables_Method_PDF = Literal["yolox", "pdfium", "nemoretriever_parse"]

_Type_Extract_Tables_Method_DOCX = Literal["python_docx",]

_Type_Extract_Tables_Method_PPTX = Literal["python_pptx",]

# Document type -> tuple of accepted table-extraction methods.
# Only pdf, docx and pptx have an entry; no other document type is listed.
_Type_Extract_Tables_Method_Map = {
    "pdf": get_args(_Type_Extract_Tables_Method_PDF),
    "docx": get_args(_Type_Extract_Tables_Method_DOCX),
    "pptx": get_args(_Type_Extract_Tables_Method_PPTX),
}

# Accepted values for the `extract_images_method` option.
_Type_Extract_Images_Method = Literal["simple", "group"]


class ExtractTaskSchema(BaseModel):
    """
    Validation schema for the options of an ``extract`` task.

    Field declaration order matters: field validators run in that order and can
    only see fields declared (and successfully validated) before their own, so
    ``document_type`` must stay first and ``extract_tables`` must stay ahead of
    ``extract_charts``.

    Defaults that depend on other fields are resolved after field validation:

    * ``extract_method`` left as ``None`` becomes the default extractor for
      ``document_type`` (see ``_DEFAULT_EXTRACTOR_MAP``).
    * ``extract_charts`` left as ``None`` follows ``extract_tables``.

    Unknown fields are rejected (``extra="forbid"``).
    """

    document_type: str
    extract_method: Optional[str] = None  # None -> default extractor for document_type
    extract_text: bool = True
    extract_images: bool = True
    extract_images_method: str = "group"
    extract_images_params: Optional[Dict[str, Any]] = None
    extract_tables: bool = True
    extract_tables_method: str = "yolox"
    extract_charts: Optional[bool] = None  # None -> follow extract_tables
    extract_infographics: bool = False
    extract_audio_params: Optional[Dict[str, Any]] = None
    text_depth: str = "document"
    paddle_output_format: str = "pseudo_markdown"

    model_config = ConfigDict(extra="forbid")

    @model_validator(mode="after")
    def set_default_extract_method(self):
        """
        Fill in the defaults that depend on other fields.

        Runs once all fields have validated. Field validators are skipped for
        omitted fields, so the "None means derive it" defaults of
        ``extract_method`` and ``extract_charts`` have to be resolved here.

        Raises
        ------
        ValueError
            If ``document_type`` is not a key of ``_DEFAULT_EXTRACTOR_MAP``.
        """
        document_type = self.document_type.lower()  # Ensure case-insensitive comparison

        if document_type not in _DEFAULT_EXTRACTOR_MAP:
            raise ValueError(
                f"Unsupported document type: {document_type}."
                f" Supported types are: {list(_DEFAULT_EXTRACTOR_MAP.keys())}"
            )

        if self.extract_method is None:
            self.extract_method = _DEFAULT_EXTRACTOR_MAP[document_type]

        # `extract_charts` defaults to None for backward compatibility; when it was
        # omitted its field validator never ran, so mirror `extract_tables` here.
        if self.extract_charts is None:
            self.extract_charts = self.extract_tables

        return self

    @field_validator("extract_charts")
    @classmethod
    def set_default_extract_charts(cls, v, values):
        """
        Resolve an explicit ``extract_charts=None`` to the value of ``extract_tables``.

        {extract_tables: true, extract_charts: None} or
        {extract_tables: true, extract_charts: true} enables both table and chart
        extraction; {extract_tables: true, extract_charts: false} enables table
        extraction only.
        """
        if v is None:
            return values.data.get("extract_tables")
        return v

    @field_validator("extract_method")
    @classmethod
    def extract_method_must_be_valid(cls, v, values, **kwargs):
        """
        Check that ``extract_method`` is allowed for the given ``document_type``.

        Returns ``None`` (so the default extractor is applied afterwards) when
        the value is ``None`` or the document type is text-like. Document types
        without an entry in ``_Type_Extract_Method_Map`` accept only their
        default extractor.

        Raises
        ------
        ValueError
            If the method is not allowed for the document type.
        """
        if v is None:
            return None

        document_type = values.data.get("document_type", "").lower()  # Ensure case-insensitive comparison

        # Skip validation for text-like types, since they do not have 'extract' stages.
        if document_type in ("txt", "text", "json", "md", "sh"):
            return None

        # `document_type` is missing from the validated data when it failed its own
        # validation; that error is already reported, so do not raise a second one.
        if document_type not in _DEFAULT_EXTRACTOR_MAP:
            return v

        # Supported types with no explicit method list (e.g. csv, xml) would raise a
        # raw KeyError on direct indexing; fall back to their default extractor.
        valid_methods = set(
            _Type_Extract_Method_Map.get(document_type, (_DEFAULT_EXTRACTOR_MAP[document_type],))
        )
        if v not in valid_methods:
            raise ValueError(f"extract_method must be one of {valid_methods}")

        return v

    @field_validator("document_type")
    @classmethod
    def document_type_must_be_supported(cls, v):
        """
        Normalise ``document_type`` to lower case and check that it is supported.

        Raises
        ------
        ValueError
            If the lower-cased value is not a key of ``_DEFAULT_EXTRACTOR_MAP``.
        """
        normalized = v.lower()
        if normalized not in _DEFAULT_EXTRACTOR_MAP:
            raise ValueError(
                f"Unsupported document type '{v}'. Supported types are: {', '.join(_DEFAULT_EXTRACTOR_MAP.keys())}"
            )
        return normalized

    @field_validator("extract_tables_method")
    @classmethod
    def extract_tables_method_must_be_valid(cls, v, values, **kwargs):
        """
        Check that ``extract_tables_method`` is allowed for the given ``document_type``.

        Only runs when the field is passed explicitly.

        Raises
        ------
        ValueError
            If the document type has no entry in
            ``_Type_Extract_Tables_Method_Map``, or the method is not one of the
            values listed for it.
        """
        document_type = values.data.get("document_type", "").lower()  # Ensure case-insensitive comparison

        # `document_type` already failed validation; avoid a duplicate error.
        if document_type not in _DEFAULT_EXTRACTOR_MAP:
            return v

        valid_methods = _Type_Extract_Tables_Method_Map.get(document_type)
        if valid_methods is None:
            # Direct indexing used to escape as a raw KeyError for these types;
            # report it as a regular validation error instead.
            raise ValueError(
                f"extract_tables_method cannot be set for document type '{document_type}'."
                f" It is supported for: {', '.join(_Type_Extract_Tables_Method_Map.keys())}"
            )

        if v not in valid_methods:
            raise ValueError(f"extract_tables_method must be one of {set(valid_methods)}")

        return v

    @field_validator("extract_images_method")
    @classmethod
    def extract_images_method_must_be_valid(cls, v):
        """
        Normalise ``extract_images_method`` to lower case and check that it is allowed.

        Raises
        ------
        ValueError
            If the lower-cased value is not one of ``_Type_Extract_Images_Method``.
        """
        # A Literal alias cannot be passed to str.join; expand it with get_args first.
        valid_methods = get_args(_Type_Extract_Images_Method)
        normalized = v.lower()
        if normalized not in valid_methods:
            raise ValueError(
                f"Unsupported extract_images_method '{v}'. Supported methods are: {', '.join(valid_methods)}"
            )
        return normalized
class ExtractTask(Task):
    """
    Client-side description of a document extraction task.

    Stores the extraction options given at construction time and serialises
    them with `to_dict` into the payload for an ``extract`` task.
    """

    def __init__(
        self,
        document_type,
        extract_method: _Type_Extract_Method_PDF = "pdfium",
        extract_text: bool = False,
        extract_images: bool = False,
        extract_tables: bool = False,
        extract_charts: Optional[bool] = None,
        extract_audio_params: Optional[Dict[str, Any]] = None,
        extract_images_method: _Type_Extract_Images_Method = "group",
        extract_images_params: Optional[Dict[str, Any]] = None,
        extract_tables_method: _Type_Extract_Tables_Method_PDF = "yolox",
        extract_infographics: bool = False,
        text_depth: str = "document",
        paddle_output_format: str = "pseudo_markdown",
    ) -> None:
        """
        Record the extraction options on the instance.

        `extract_charts` defaults to None for backward compatibility: when it is
        None the task uses the value of `extract_tables`, so enabling tables also
        enables charts unless `extract_charts=False` is passed explicitly.
        """
        super().__init__()

        # What to extract from, and with which extractor.
        self._document_type = document_type
        self._extract_method = extract_method

        # Which content kinds to extract.
        self._extract_text = extract_text
        self._extract_images = extract_images
        self._extract_tables = extract_tables
        self._extract_infographics = extract_infographics
        if extract_charts is None:
            self._extract_charts = extract_tables
        else:
            self._extract_charts = extract_charts

        # How to extract them.
        self._extract_images_method = extract_images_method
        self._extract_images_params = extract_images_params
        self._extract_tables_method = extract_tables_method
        self._extract_audio_params = extract_audio_params
        self._text_depth = text_depth
        self._paddle_output_format = paddle_output_format

    def __str__(self) -> str:
        """
        Return a multi-line, human-readable summary of the task configuration.

        The optional image and audio parameter lines are included only when the
        corresponding parameters are set to a non-empty value.
        """
        summary = [
            "Extract Task:",
            f" document type: {self._document_type}",
            f" extract method: {self._extract_method}",
            f" extract text: {self._extract_text}",
            f" extract images: {self._extract_images}",
            f" extract tables: {self._extract_tables}",
            f" extract charts: {self._extract_charts}",
            f" extract infographics: {self._extract_infographics}",
            f" extract images method: {self._extract_images_method}",
            f" extract tables method: {self._extract_tables_method}",
            f" text depth: {self._text_depth}",
            f" paddle_output_format: {self._paddle_output_format}",
        ]

        if self._extract_images_params:
            summary.append(f" extract images params: {self._extract_images_params}")
        if self._extract_audio_params:
            summary.append(f" extract audio params: {self._extract_audio_params}")

        # Every line, including the last one, is newline-terminated.
        return "\n".join(summary) + "\n"

    def to_dict(self) -> Dict:
        """
        Serialise the task as ``{"type": "extract", "task_properties": {...}}``.

        ``task_properties`` carries the extract method, the document type and a
        ``params`` dict. Image and audio parameters are added to ``params`` only
        when non-empty. For the ``unstructured_local``, ``unstructured_io`` and
        ``adobe`` methods, method-specific settings are appended to ``params``;
        the latter two read the environment at call time and fall back to the
        values captured when the module was imported.
        """
        params = {
            "extract_text": self._extract_text,
            "extract_images": self._extract_images,
            "extract_tables": self._extract_tables,
            "extract_images_method": self._extract_images_method,
            "extract_tables_method": self._extract_tables_method,
            "extract_charts": self._extract_charts,
            "extract_infographics": self._extract_infographics,
            "text_depth": self._text_depth,
            "paddle_output_format": self._paddle_output_format,
        }

        if self._extract_images_params:
            params["extract_images_params"] = self._extract_images_params
        if self._extract_audio_params:
            params["extract_audio_params"] = self._extract_audio_params

        # TODO(Devin): I like the idea of Derived classes augmenting the to_dict method, but its not logically
        #  consistent with how we define tasks, we don't have multiple extract tasks, we have extraction paths
        #  based on the method and the document type.
        method = self._extract_method
        if method == "unstructured_local":
            params["api_key"] = ""  # TODO(Devin): Should be an environment variable or configurable parameter
            params["unstructured_url"] = ""  # TODO(Devin): Should be an environment variable
        elif method == "unstructured_io":
            params["unstructured_api_key"] = os.environ.get("UNSTRUCTURED_API_KEY", UNSTRUCTURED_API_KEY)
            params["unstructured_url"] = os.environ.get("UNSTRUCTURED_URL", UNSTRUCTURED_URL)
            params["unstructured_strategy"] = os.environ.get("UNSTRUCTURED_STRATEGY", UNSTRUCTURED_STRATEGY)
            params["unstructured_concurrency_level"] = os.environ.get(
                "UNSTRUCTURED_CONCURRENCY_LEVEL", UNSTRUCTURED_CONCURRENCY_LEVEL
            )
        elif method == "adobe":
            params["adobe_client_id"] = os.environ.get("ADOBE_CLIENT_ID", ADOBE_CLIENT_ID)
            # NOTE(review): the key is spelled "secrect"; it is kept as-is because it is
            # part of the submitted payload. Confirm with the consumer before renaming.
            params["adobe_client_secrect"] = os.environ.get("ADOBE_CLIENT_SECRET", ADOBE_CLIENT_SECRET)

        task_properties = {
            "method": method,
            "document_type": self._document_type,
            "params": params,
        }
        return {"type": "extract", "task_properties": task_properties}

    @property
    def document_type(self):
        """The document type this task was created for."""
        return self._document_type