Source code for nv_ingest_api.internal.transform.caption_image

# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Any, Dict, List, Optional, Tuple, Union

import pandas as pd
from pydantic import BaseModel

from nv_ingest_api.internal.primitives.nim.model_interface.vlm import VLMModelInterface
from nv_ingest_api.internal.enums.common import ContentTypeEnum
from nv_ingest_api.util.exception_handlers.decorators import unified_exception_handler
from nv_ingest_api.util.image_processing import scale_image_to_encoding_size
from nv_ingest_api.util.nim import create_inference_client

logger = logging.getLogger(__name__)


def _prepare_dataframes_mod(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series]:
    """
    Prepares and returns three DataFrame-related objects from the input DataFrame.

    The function performs the following:
      1. If the DataFrame is empty or the "document_type" column is missing, returns the
         original DataFrame, an empty DataFrame, and an empty boolean Series.
      2. Otherwise, creates a boolean Series identifying rows where "document_type" equals IMAGE.
      3. Extracts a DataFrame containing only those rows.

    Parameters
    ----------
    df : pd.DataFrame
        The input DataFrame that should contain a "document_type" column.

    Returns
    -------
    Tuple[pd.DataFrame, pd.DataFrame, pd.Series]
        A tuple containing:
          - The original DataFrame.
          - A DataFrame filtered to rows where "document_type" is IMAGE.
          - A boolean Series indicating which rows in the original DataFrame are IMAGE rows.
    """
    try:
        if df.empty or "document_type" not in df.columns:
            return df, pd.DataFrame(), pd.Series(dtype=bool)

        bool_index: pd.Series = df["document_type"] == ContentTypeEnum.IMAGE
        df_matched: pd.DataFrame = df.loc[bool_index]

        return df, df_matched, bool_index

    except Exception as e:
        err_msg = f"_prepare_dataframes_mod: Error preparing dataframes. Original error: {e}"
        logger.error(err_msg, exc_info=True)
        raise type(e)(err_msg) from e


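# --- Illustrative usage sketch (not part of the original module) ------------------------------
# A minimal, hypothetical example of how _prepare_dataframes_mod splits a ledger. The column
# values below are invented for illustration, and ContentTypeEnum.TEXT is assumed to exist as
# a non-image document type.
def _example_prepare_dataframes() -> None:
    ledger = pd.DataFrame(
        {
            "document_type": [ContentTypeEnum.IMAGE, ContentTypeEnum.TEXT, ContentTypeEnum.IMAGE],
            "metadata": [{}, {}, {}],
        }
    )
    full_df, image_df, image_mask = _prepare_dataframes_mod(ledger)
    # full_df is the original ledger; image_df holds the two IMAGE rows;
    # image_mask is a boolean Series over the original index: [True, False, True].
    assert len(image_df) == int(image_mask.sum())

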
def _generate_captions(
    base64_images: List[str], prompt: str, api_key: str, endpoint_url: str, model_name: str
) -> List[str]:
    """
    Generates captions for a list of base64-encoded PNG images using the VLM model API.

    This function performs the following steps:
      1. Scales each image to meet encoding size requirements using `scale_image_to_encoding_size`.
      2. Constructs the input payload containing the scaled images and the provided prompt.
      3. Creates an inference client using the VLMModelInterface.
      4. Calls the client's infer method to obtain a list of captions corresponding to the images.

    Parameters
    ----------
    base64_images : List[str]
        List of base64-encoded PNG image strings.
    prompt : str
        Text prompt to guide caption generation.
    api_key : str
        API key for authenticating with the VLM endpoint.
    endpoint_url : str
        URL of the VLM model HTTP endpoint.
    model_name : str
        The name of the model to use for inference.

    Returns
    -------
    List[str]
        A list of generated captions, each corresponding to an input image.

    Raises
    ------
    Exception
        Propagates any exception encountered during caption generation, with added context.
    """
    try:
        # Scale each image to ensure it meets encoding size requirements.
        scaled_images: List[str] = [scale_image_to_encoding_size(b64)[0] for b64 in base64_images]

        # Build the input payload for the VLM model.
        data: Dict[str, Any] = {
            "base64_images": scaled_images,
            "prompt": prompt,
        }

        # Create the inference client using the VLMModelInterface.
        nim_client = create_inference_client(
            model_interface=VLMModelInterface(),
            endpoints=(None, endpoint_url),
            auth_token=api_key,
            infer_protocol="http",
        )

        logger.debug(f"Calling VLM endpoint: {endpoint_url} with model: {model_name}")
        # Perform inference to generate captions.
        captions: List[str] = nim_client.infer(data, model_name=model_name)
        return captions

    except Exception as e:
        err_msg = f"_generate_captions: Error generating captions: {e}"
        logger.error(err_msg, exc_info=True)
        raise type(e)(err_msg) from e


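# --- Illustrative usage sketch (not part of the original module) ------------------------------
# A minimal, hypothetical call to _generate_captions. The API key, endpoint URL, and model name
# below are placeholders for illustration only, not defaults shipped with this library.
def _example_generate_captions(base64_png: str) -> List[str]:
    return _generate_captions(
        base64_images=[base64_png],
        prompt="Caption the content of this image:",
        api_key="nvapi-...",  # placeholder credential
        endpoint_url="https://integrate.api.nvidia.com/v1/chat/completions",  # placeholder endpoint
        model_name="meta/llama-3.2-11b-vision-instruct",  # placeholder model
    )

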
@unified_exception_handler
def transform_image_create_vlm_caption_internal(
    df_transform_ledger: pd.DataFrame,
    task_config: Union[BaseModel, Dict[str, Any]],
    transform_config: Any,
    execution_trace_log: Optional[Dict[str, Any]] = None,
) -> pd.DataFrame:
    """
    Extracts and adds captions for image content in a DataFrame using the VLM model API.

    This function updates the 'metadata' column for rows where the content type is "image".
    It uses configuration values from task_config (or falls back to transform_config defaults)
    to determine the API key, prompt, endpoint URL, and model name for caption generation.
    The generated captions are added under the 'image_metadata.caption' key in the metadata.

    Parameters
    ----------
    df_transform_ledger : pd.DataFrame
        The input DataFrame containing image data. Each row must have a 'metadata' column
        with at least the 'content' and 'content_metadata' keys.
    task_config : Union[BaseModel, Dict[str, Any]]
        Configuration parameters for caption extraction. If provided as a Pydantic model,
        it will be converted to a dictionary. Expected keys include "api_key", "prompt",
        "endpoint_url", and "model_name".
    transform_config : Any
        A configuration object providing default values for caption extraction. It should
        have attributes: api_key, prompt, endpoint_url, and model_name.
    execution_trace_log : Optional[Dict[str, Any]], default=None
        Optional trace information for debugging or logging purposes.

    Returns
    -------
    pd.DataFrame
        The updated DataFrame with generated captions added to the 'image_metadata.caption'
        field within the 'metadata' column for each image row.

    Raises
    ------
    Exception
        Propagates any exception encountered during the caption extraction process, with
        added context.
    """
    _ = execution_trace_log  # Unused variable; placeholder to prevent linter warnings.

    logger.debug("Attempting to caption image content")

    # Convert task_config to a dictionary if it is a Pydantic model.
    if isinstance(task_config, BaseModel):
        task_config = task_config.model_dump()

    # Retrieve configuration values with fallback to transform_config defaults.
    api_key: str = task_config.get("api_key") or transform_config.api_key
    prompt: str = task_config.get("prompt") or transform_config.prompt
    endpoint_url: str = task_config.get("endpoint_url") or transform_config.endpoint_url
    model_name: str = task_config.get("model_name") or transform_config.model_name

    # Create a mask for rows where the content type is "image".
    df_mask: pd.Series = df_transform_ledger["metadata"].apply(
        lambda meta: meta.get("content_metadata", {}).get("type") == "image"
    )

    # If no image rows exist, return the original DataFrame.
    if not df_mask.any():
        return df_transform_ledger

    # Collect base64-encoded images from the rows where the content type is "image".
    base64_images: List[str] = (
        df_transform_ledger.loc[df_mask, "metadata"].apply(lambda meta: meta["content"]).tolist()
    )

    # Generate captions for the collected images.
    captions: List[str] = _generate_captions(base64_images, prompt, api_key, endpoint_url, model_name)

    # Update the DataFrame: assign each generated caption to the corresponding row.
    for idx, caption in zip(df_transform_ledger.loc[df_mask].index, captions):
        meta: Dict[str, Any] = df_transform_ledger.at[idx, "metadata"]
        image_meta: Dict[str, Any] = meta.get("image_metadata", {})
        image_meta["caption"] = caption
        meta["image_metadata"] = image_meta
        df_transform_ledger.at[idx, "metadata"] = meta

    logger.debug("Image content captioning complete")

    result, execution_trace_log = df_transform_ledger, {}
    _ = execution_trace_log  # Unused variable; placeholder to prevent linter warnings.

    return result
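

# --- Illustrative usage sketch (not part of the original module) ------------------------------
# A minimal, hypothetical invocation of transform_image_create_vlm_caption_internal. The
# SimpleNamespace stands in for the real transform config object; its values and the ledger
# row below are placeholders for illustration only (a real call needs a reachable VLM endpoint).
def _example_transform_image_create_vlm_caption() -> pd.DataFrame:
    from types import SimpleNamespace

    ledger = pd.DataFrame(
        {
            "document_type": [ContentTypeEnum.IMAGE],
            "metadata": [
                {
                    "content": "<base64-encoded PNG>",  # placeholder image payload
                    "content_metadata": {"type": "image"},
                    "image_metadata": {},
                }
            ],
        }
    )
    transform_config = SimpleNamespace(
        api_key="nvapi-...",  # placeholder credential
        prompt="Caption the content of this image:",
        endpoint_url="https://integrate.api.nvidia.com/v1/chat/completions",  # placeholder endpoint
        model_name="meta/llama-3.2-11b-vision-instruct",  # placeholder model
    )
    # task_config entries override transform_config defaults; here we rely on the defaults.
    return transform_image_create_vlm_caption_internal(
        df_transform_ledger=ledger,
        task_config={},
        transform_config=transform_config,
    )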