Source code for nv_ingest_api.interface

# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
import functools
import inspect
import pprint
from typing import Dict, Any, Optional, List

from pydantic import BaseModel

from nv_ingest_api.internal.schemas.extract.extract_pdf_schema import PDFiumConfigSchema, NemoRetrieverParseConfigSchema

logger = logging.getLogger(__name__)

## CONFIG_SCHEMAS is a global dictionary that maps extraction methods to Pydantic schemas.
CONFIG_SCHEMAS: Dict[str, Any] = {
    "adobe": PDFiumConfigSchema,
    "llama": PDFiumConfigSchema,
    "nemoretriever_parse": NemoRetrieverParseConfigSchema,
    "pdfium": PDFiumConfigSchema,
    "tika": PDFiumConfigSchema,
    "unstructured_io": PDFiumConfigSchema,
}


def _build_config_from_schema(schema_class: type[BaseModel], args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build and validate a configuration dictionary from the provided arguments using a Pydantic schema.

    This function filters the supplied arguments to include only those keys defined in the given
    Pydantic schema (using Pydantic v2's `model_fields`), instantiates the schema for validation,
    and returns the validated configuration as a dictionary.

    Parameters
    ----------
    schema_class : type[BaseModel]
        The Pydantic BaseModel subclass used for validating the configuration.
    args : dict
        A dictionary of arguments from which to extract and validate configuration data.

    Returns
    -------
    dict
        A dictionary containing the validated configuration data as defined by the schema.

    Raises
    ------
    pydantic.ValidationError
        If the provided arguments do not conform to the schema.
    """
    field_names = schema_class.model_fields.keys()
    config_data = {k: v for k, v in args.items() if k in field_names}
    # Instantiate the schema to perform validation, then return the model's dictionary representation.

    return schema_class(**config_data).dict()


[docs] def extraction_interface_relay_constructor(api_fn, task_keys: Optional[List[str]] = None): """ Decorator for constructing and validating configuration using Pydantic schemas. This decorator wraps a user-facing interface function. It extracts common task parameters (using the provided task_keys, or defaults if not specified) and method-specific configuration parameters based on a required 'extract_method' keyword argument. It then uses the corresponding Pydantic schema (from the global CONFIG_SCHEMAS registry) to validate and build a method-specific configuration. The resulting composite configuration, along with the extraction ledger and execution trace log, is passed to the backend API function. Parameters ---------- api_fn : callable The backend API function that will be called with the extraction ledger, the task configuration dictionary, the extractor configuration, and the execution trace log. This function must conform to the signature: extract_primitives_from_pdf_internal(df_extraction_ledger: pd.DataFrame, task_config: Dict[str, Any], extractor_config: Any, execution_trace_log: Optional[List[Any]] = None) task_keys : list of str, optional A list of keyword names that should be extracted from the user function as common task parameters. If not provided, defaults to ["extract_text", "extract_images", "extract_tables", "extract_charts"]. Returns ------- callable A wrapped function that builds and validates the configuration before invoking the backend API function. Raises ------ ValueError If the extraction method specified is not supported (i.e., no corresponding Pydantic schema exists in CONFIG_SCHEMAS), if api_fn does not conform to the expected signature, or if the required 'extract_method' parameter is not provided. """ # Verify that api_fn conforms to the expected signature. try: # Try binding four arguments: ledger, task_config, extractor_config, and execution_trace_log. inspect.signature(api_fn).bind("dummy_ledger", {"dummy": True}, {"dummy": True}, {}) except TypeError as e: raise ValueError( "api_fn must conform to the signature: " "extract_primitives_from_pdf(df_extraction_ledger, task_config, extractor_config, execution_trace_log)" ) from e if task_keys is None: task_keys = [] def decorator(user_fn): @functools.wraps(user_fn) def wrapper(*args, **kwargs): # Use bind_partial so that missing required arguments can be handled gracefully. sig = inspect.signature(user_fn) bound = sig.bind_partial(*args, **kwargs) bound.apply_defaults() # The first parameter is assumed to be the extraction ledger. param_names = list(sig.parameters.keys()) if param_names[0] not in bound.arguments: raise ValueError("Missing required ledger argument.") ledger = bound.arguments[param_names[0]] # Process reserved 'execution_trace_log'. execution_trace_log = bound.arguments.get("execution_trace_log", None) if execution_trace_log is None: execution_trace_log = {} # Replace None with an empty dict. if "execution_trace_log" in bound.arguments: del bound.arguments["execution_trace_log"] # Ensure that 'extract_method' is provided. if "extract_method" not in bound.arguments or bound.arguments["extract_method"] is None: raise ValueError("The 'extract_method' parameter is required.") extract_method = bound.arguments["extract_method"] del bound.arguments["extract_method"] # Extract common task parameters using the specified task_keys. task_params = {key: bound.arguments[key] for key in task_keys if key in bound.arguments} task_params["extract_method"] = extract_method task_config = {"params": task_params} # Look up the appropriate Pydantic schema. schema_class = CONFIG_SCHEMAS.get(extract_method) if schema_class is None: raise ValueError(f"Unsupported extraction method: {extract_method}") # Build the method-specific configuration using the schema class. extraction_config_dict = _build_config_from_schema(schema_class, bound.arguments) # Create a Pydantic object instead of a dictionary for the specific extractor config extractor_schema = None try: # Find the appropriate extractor schema class based on the extraction method extractor_schema_name = f"{extract_method.capitalize()}ExtractorSchema" extractor_schema_class = globals().get(extractor_schema_name) if extractor_schema_class is None: # Try another common naming pattern extractor_schema_name = f"{extract_method.upper()}ExtractorSchema" extractor_schema_class = globals().get(extractor_schema_name) if extractor_schema_class is None: # Final fallback attempt with camelCase extractor_schema_name = f"{extract_method[0].upper() + extract_method[1:]}ExtractorSchema" extractor_schema_class = globals().get(extractor_schema_name) if extractor_schema_class is not None: # Create the extractor schema with the method-specific config config_key = f"{extract_method}_config" extractor_schema = extractor_schema_class(**{config_key: extraction_config_dict}) else: logger.warning(f"Could not find extractor schema class for method: {extract_method}") except Exception as e: logger.warning(f"Error creating extractor schema: {str(e)}") # Fall back to dictionary approach if schema creation fails extractor_schema = {f"{extract_method}_config": extraction_config_dict} # If schema creation failed, fall back to dictionary if extractor_schema is None: extractor_schema = {f"{extract_method}_config": extraction_config_dict} # Log the task and extractor configurations for debugging logger.debug("\n" + "=" * 80) logger.debug(f"DEBUG - API Function: {api_fn.__name__}") logger.debug(f"DEBUG - Extract Method: {extract_method}") logger.debug("-" * 80) # Format the task config as a string and log it task_config_str = pprint.pformat(task_config, width=100, sort_dicts=False) logger.debug(f"DEBUG - Task Config:\n{task_config_str}") logger.debug("-" * 80) # Format the extractor config as a string and log it if hasattr(extractor_schema, "model_dump"): extractor_config_str = pprint.pformat(extractor_schema.model_dump(), width=100, sort_dicts=False) else: extractor_config_str = pprint.pformat(extractor_schema, width=100, sort_dicts=False) logger.debug(f"DEBUG - Extractor Config Type: {type(extractor_schema)}") logger.debug(f"DEBUG - Extractor Config:\n{extractor_config_str}") logger.debug("=" * 80 + "\n") # Call the backend API function. pprint.pprint(task_config) pprint.pprint(extractor_schema) result = api_fn(ledger, task_config, extractor_schema, execution_trace_log) # If the result is a tuple, return only the first element if isinstance(result, tuple): return result[0] return result return wrapper return decorator