# Source code for nv_ingest_api.internal.transform.caption_image

# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Any, Dict, List, Optional, Tuple, Union

import pandas as pd
from pydantic import BaseModel

from nv_ingest_api.internal.primitives.nim.model_interface.vlm import VLMModelInterface
from nv_ingest_api.internal.enums.common import ContentTypeEnum
from nv_ingest_api.util.exception_handlers.decorators import unified_exception_handler
from nv_ingest_api.util.image_processing import scale_image_to_encoding_size
from nv_ingest_api.util.nim import create_inference_client

logger = logging.getLogger(__name__)

_MAX_CONTEXT_TEXT_CHARS = 4096


def _gather_context_text_for_image(
    image_meta: Dict[str, Any],
    page_text_map: Dict[int, List[str]],
    max_chars: int,
) -> str:
    """
    Gather surrounding OCR text for an image to provide as VLM prompt context.

    The text entries recorded for the image's page are joined with single
    spaces and truncated to the effective character budget.

    Parameters
    ----------
    image_meta : dict
        The full metadata dict for the image row. The page is read from
        ``image_meta["content_metadata"]["page_number"]``; a missing or ``None``
        ``content_metadata`` (or a missing ``page_number``) maps to page ``-1``.
    page_text_map : dict
        Mapping of page number -> list of text strings, precomputed from the
        DataFrame's text rows.
    max_chars : int
        Maximum number of characters to return. Clamped to the range
        ``[0, _MAX_CONTEXT_TEXT_CHARS]``.

    Returns
    -------
    str
        Surrounding text (possibly truncated), or empty string if none found
        or if the effective budget is zero.
    """
    # Clamp into [0, _MAX_CONTEXT_TEXT_CHARS]. Without the lower bound a negative
    # budget would turn into a negative slice index and return text measured
    # from the end of the string instead of nothing.
    effective_max = max(0, min(max_chars, _MAX_CONTEXT_TEXT_CHARS))
    if effective_max == 0:
        return ""

    # ``content_metadata`` may be present but explicitly None; treat it like a
    # missing key so the lookup falls back to page -1 instead of raising.
    content_meta = image_meta.get("content_metadata") or {}
    page_num = content_meta.get("page_number", -1)

    page_texts = page_text_map.get(page_num)
    if not page_texts:
        return ""

    return " ".join(page_texts)[:effective_max]


def _build_prompt_with_context(base_prompt: str, context_text: str) -> str:
    """
    Compose the final VLM prompt, optionally prefixed with nearby page text.

    When *context_text* is non-empty it is placed in a delimited
    "Text near this image" section ahead of *base_prompt*. When it is empty
    (or otherwise falsy) *base_prompt* is returned as-is.
    """
    if context_text:
        return f"Text near this image:\n---\n{context_text}\n---\n\n{base_prompt}"
    return base_prompt


def _build_page_text_map(df: pd.DataFrame) -> Dict[int, List[str]]:
    """
    Build a mapping of page number -> list of text content strings from text
    rows in the DataFrame.  Computed once per call to avoid O(images * rows).

    Only the ``metadata`` column is read. A row contributes when its metadata
    is a dict whose ``content_metadata["type"]`` equals ``"text"`` and whose
    ``content`` is non-empty. Rows without a ``page_number`` are grouped under
    page ``-1``. Row order is preserved within each page.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame whose ``metadata`` column holds per-row metadata dicts.
        A frame without that column yields an empty mapping.

    Returns
    -------
    Dict[int, List[str]]
        Page number -> text contents found on that page.
    """
    page_text_map: Dict[int, List[str]] = {}
    if "metadata" not in df.columns:
        return page_text_map

    # Iterate the single column directly; ``iterrows`` would materialise a
    # Series for every row only to read this one cell.
    for meta in df["metadata"]:
        # Skip missing cells (None / NaN) and anything else that is not a dict.
        if not isinstance(meta, dict):
            continue
        # ``content_metadata`` may be present but None.
        cm = meta.get("content_metadata") or {}
        if cm.get("type") != "text":
            continue
        content = meta.get("content")
        if not content:
            continue
        page_text_map.setdefault(cm.get("page_number", -1), []).append(content)
    return page_text_map


def _prepare_dataframes_mod(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series]:
    """
    Split out the IMAGE rows of a DataFrame based on its "document_type" column.

    If the frame has no rows, or lacks a "document_type" column, nothing can be
    matched: the input is handed back together with an empty DataFrame and an
    empty boolean Series. Otherwise a boolean mask marking rows whose
    "document_type" equals ``ContentTypeEnum.IMAGE`` is computed and used to
    select the matching rows.

    Parameters
    ----------
    df : pd.DataFrame
        The input DataFrame, expected to carry a "document_type" column.

    Returns
    -------
    Tuple[pd.DataFrame, pd.DataFrame, pd.Series]
        ``(df, image_rows, image_mask)`` where:
          - ``df`` is the input DataFrame itself,
          - ``image_rows`` holds only the rows flagged as IMAGE,
          - ``image_mask`` is the boolean Series used for the selection.

    Raises
    ------
    Exception
        Any error raised while preparing the frames is logged and re-raised as
        the same exception type with a message naming this function.
    """
    try:
        can_filter = not df.empty and "document_type" in df.columns
        if not can_filter:
            return df, pd.DataFrame(), pd.Series(dtype=bool)

        image_mask: pd.Series = df["document_type"] == ContentTypeEnum.IMAGE
        image_rows: pd.DataFrame = df.loc[image_mask]
        return df, image_rows, image_mask

    except Exception as exc:
        err_msg = f"_prepare_dataframes_mod: Error preparing dataframes. Original error: {exc}"
        logger.error(err_msg, exc_info=True)
        raise type(exc)(err_msg) from exc


def _generate_captions(
    base64_images: List[str],
    prompt: str,
    system_prompt: Optional[str],
    api_key: str,
    endpoint_url: str,
    model_name: str,
    temperature: float = 1.0,
) -> List[str]:
    """
    Generates captions for a list of base64-encoded PNG images using the VLM model API.

    This function performs the following steps:
      1. Returns an empty list immediately when there are no images.
      2. Scales each image to meet encoding size requirements using `scale_image_to_encoding_size`.
      3. Constructs the input payload containing the scaled images, the prompt and, when
         provided, the system prompt.
      4. Creates an HTTP inference client using the VLMModelInterface.
      5. Calls the client's infer method to obtain a list of captions corresponding to the images.

    Parameters
    ----------
    base64_images : List[str]
        List of base64-encoded PNG image strings.
    prompt : str
        Text prompt to guide caption generation.
    system_prompt : Optional[str]
        Optional system prompt. It is added to the payload only when truthy.
    api_key : str
        API key for authenticating with the VLM endpoint.
    endpoint_url : str
        URL of the VLM model HTTP endpoint.
    model_name : str
        The name of the model to use for inference.
    temperature : float, default=1.0
        Sampling temperature forwarded to the client's ``infer`` call.

    Returns
    -------
    List[str]
        A list of generated captions, each corresponding to an input image.
        Empty when ``base64_images`` is empty.

    Raises
    ------
    Exception
        Propagates any exception encountered during caption generation, with added context.
    """
    # Nothing to caption: skip client construction and the inference round trip.
    if not base64_images:
        return []

    try:
        # Scale each image to ensure it meets encoding size requirements.
        # The helper returns a sequence; only its first element (the image) is used.
        scaled_images: List[str] = [scale_image_to_encoding_size(b64)[0] for b64 in base64_images]

        # Build the input payload for the VLM model.
        data: Dict[str, Any] = {
            "base64_images": scaled_images,
            "prompt": prompt,
        }
        if system_prompt:
            data["system_prompt"] = system_prompt

        # Create the inference client using the VLMModelInterface.
        # Only the second (HTTP) endpoint slot is populated.
        nim_client = create_inference_client(
            model_interface=VLMModelInterface(),
            endpoints=(None, endpoint_url),
            auth_token=api_key,
            infer_protocol="http",
        )

        # Perform inference to generate captions.
        captions: List[str] = nim_client.infer(data, model_name=model_name, temperature=temperature)
        return captions

    except Exception as e:
        err_msg = f"_generate_captions: Error generating captions: {e}"
        logger.error(err_msg, exc_info=True)
        raise type(e)(err_msg) from e


def _resolve_task_option(
    task_config: Dict[str, Any],
    transform_config: Any,
    key: str,
    default: Any,
) -> Any:
    """
    Return the task-level value for *key*, falling back to the pipeline default.

    Only ``None`` (or a missing key) triggers the fallback, so explicit falsy
    values such as ``0`` or ``0.0`` supplied by the task are honoured.
    """
    value = task_config.get(key)
    if value is None:
        value = getattr(transform_config, key, default)
    return value


def _store_caption(df: pd.DataFrame, idx: Any, caption: str) -> None:
    """Write *caption* into ``metadata["image_metadata"]["caption"]`` for row *idx*."""
    meta: Dict[str, Any] = df.at[idx, "metadata"]
    # ``image_metadata`` may be absent or explicitly None; start from an empty dict then.
    image_meta: Dict[str, Any] = meta.get("image_metadata") or {}
    image_meta["caption"] = caption
    meta["image_metadata"] = image_meta
    df.at[idx, "metadata"] = meta


@unified_exception_handler
def transform_image_create_vlm_caption_internal(
    df_transform_ledger: pd.DataFrame,
    task_config: Union[BaseModel, Dict[str, Any]],
    transform_config: Any,
    execution_trace_log: Optional[Dict[str, Any]] = None,
) -> pd.DataFrame:
    """
    Extracts and adds captions for image content in a DataFrame using the VLM model API.

    This function updates the 'metadata' column for rows where the content type is "image".
    It uses configuration values from task_config (or falls back to transform_config defaults)
    to determine the API key, prompt, system prompt, endpoint URL, and model name for caption
    generation. The generated captions are added under the 'image_metadata.caption' key in
    the metadata.

    When a positive ``context_text_max_chars`` is in effect, each image is captioned
    individually with a prompt enriched by the text found on the same page. Otherwise all
    images are sent in a single captioning call with the plain prompt.

    Parameters
    ----------
    df_transform_ledger : pd.DataFrame
        The input DataFrame containing image data. Each row must have a 'metadata' column
        with at least the 'content' and 'content_metadata' keys.
    task_config : Union[BaseModel, Dict[str, Any]]
        Configuration parameters for caption extraction. If provided as a Pydantic model,
        it will be converted to a dictionary. Expected keys include "api_key", "prompt",
        "system_prompt", "endpoint_url", "model_name", "context_text_max_chars" and
        "temperature".
    transform_config : Any
        A configuration object providing default values for caption extraction. It should
        have attributes: api_key, prompt, system_prompt, endpoint_url, and model_name.
        ``context_text_max_chars`` (default 0) and ``temperature`` (default 1.0) are read
        if present.
    execution_trace_log : Optional[Dict[str, Any]], default=None
        Optional trace information for debugging or logging purposes. Currently unused.

    Returns
    -------
    pd.DataFrame
        The updated DataFrame with generated captions added to the 'image_metadata.caption'
        field within the 'metadata' column for each image row. The input frame is modified
        in place and returned.

    Raises
    ------
    Exception
        Propagates any exception encountered during the caption extraction process,
        with added context.
    """
    _ = execution_trace_log  # Unused variable; placeholder to prevent linter warnings.

    logger.debug("Attempting to caption image content")

    # Convert task_config to dictionary if it is a Pydantic model.
    if isinstance(task_config, BaseModel):
        task_config = task_config.model_dump()

    # Retrieve configuration values with fallback to transform_config defaults.
    api_key: str = task_config.get("api_key") or transform_config.api_key
    prompt: str = task_config.get("prompt") or transform_config.prompt
    system_prompt: Optional[str] = task_config.get("system_prompt") or transform_config.system_prompt
    endpoint_url: str = task_config.get("endpoint_url") or transform_config.endpoint_url
    model_name: str = task_config.get("model_name") or transform_config.model_name

    # Context text and temperature: task config overrides pipeline default.
    # These are numeric, so the fallback must key on None rather than truthiness;
    # otherwise an explicit 0 / 0.0 from the task would be silently replaced.
    context_text_max_chars: int = _resolve_task_option(task_config, transform_config, "context_text_max_chars", 0)
    temperature: float = _resolve_task_option(task_config, transform_config, "temperature", 1.0)

    # Create a mask for rows where the content type is "image".
    df_mask: pd.Series = df_transform_ledger["metadata"].apply(
        lambda meta: meta.get("content_metadata", {}).get("type") == "image"
    )

    # If no image rows exist, return the original DataFrame.
    if not df_mask.any():
        return df_transform_ledger

    image_indices = df_transform_ledger.loc[df_mask].index

    if context_text_max_chars and context_text_max_chars > 0:
        # Context mode: one request per image, each with its own page-aware prompt.
        page_text_map = _build_page_text_map(df_transform_ledger)

        for idx in image_indices:
            meta: Dict[str, Any] = df_transform_ledger.at[idx, "metadata"]
            base64_image: str = meta["content"]

            context_text = _gather_context_text_for_image(meta, page_text_map, context_text_max_chars)
            enriched_prompt = _build_prompt_with_context(prompt, context_text)

            captions: List[str] = _generate_captions(
                [base64_image],
                enriched_prompt,
                system_prompt,
                api_key,
                endpoint_url,
                model_name,
                temperature=temperature,
            )
            _store_caption(df_transform_ledger, idx, captions[0] if captions else "")
    else:
        # Batch mode: all images share the same prompt and go out in one call.
        base64_images: List[str] = [df_transform_ledger.at[idx, "metadata"]["content"] for idx in image_indices]

        captions = _generate_captions(
            base64_images,
            prompt,
            system_prompt,
            api_key,
            endpoint_url,
            model_name,
            temperature=temperature,
        )

        # ``zip`` stops at the shorter input; surface a mismatch instead of
        # silently leaving some images without a caption.
        if len(captions) != len(image_indices):
            logger.warning(
                "Expected %d captions but received %d; some image rows will not be updated",
                len(image_indices),
                len(captions),
            )

        for idx, caption in zip(image_indices, captions):
            _store_caption(df_transform_ledger, idx, caption)

    logger.debug("Image content captioning complete")

    return df_transform_ledger