Source code for nv_ingest_api.internal.transform.embed_text

# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Sequence, Tuple

import pandas as pd
from openai import OpenAI

from nv_ingest_api.internal.enums.common import ContentTypeEnum, StatusEnum, TaskTypeEnum
from nv_ingest_api.internal.schemas.meta.metadata_schema import (
    InfoMessageMetadataSchema,
)
from nv_ingest_api.internal.schemas.transform.transform_text_embedding_schema import TextEmbeddingSchema
from nv_ingest_api.util.schema.schema_validator import validate_schema

logger = logging.getLogger(__name__)


# ------------------------------------------------------------------------------
# Asynchronous Embedding Requests
# ------------------------------------------------------------------------------


def _make_async_request(
    prompts: List[str],
    api_key: str,
    embedding_nim_endpoint: str,
    embedding_model: str,
    encoding_format: str,
    input_type: str,
    truncate: str,
    filter_errors: bool,
) -> dict:
    """
    Interacts directly with the NIM embedding service to calculate embeddings for a batch of prompts.

    Parameters
    ----------
    prompts : List[str]
        A list of prompt strings for which embeddings are to be calculated.
    api_key : str
        API key for authentication with the embedding service.
    embedding_nim_endpoint : str
        Base URL for the NIM embedding service.
    embedding_model : str
        The model to use for generating embeddings.
    encoding_format : str
        The desired encoding format.
    input_type : str
        The type of input data.
    truncate : str
        Truncation setting for the input data.
    filter_errors : bool
        Flag indicating whether to filter errors in the response.

    Returns
    -------
    dict
        A dictionary with keys "embedding" (the embedding results) and "info_msg" (any error info).

    Raises
    ------
    RuntimeError
        If an error occurs during the embedding request, with an info message attached.
    """
    response = {}

    try:
        client = OpenAI(
            api_key=api_key,
            base_url=embedding_nim_endpoint,
        )

        resp = client.embeddings.create(
            input=prompts,
            model=embedding_model,
            encoding_format=encoding_format,
            extra_body={"input_type": input_type, "truncate": truncate},
        )

        response["embedding"] = resp.data
        response["info_msg"] = None

    except Exception as err:
        info_msg = {
            "task": TaskTypeEnum.EMBED.value,
            "status": StatusEnum.ERROR.value,
            "message": f"Embedding error: {err}",
            "filter": filter_errors,
        }
        validated_info_msg = validate_schema(info_msg, InfoMessageMetadataSchema).model_dump()

        response["embedding"] = [None] * len(prompts)
        response["info_msg"] = validated_info_msg

        raise RuntimeError(f"Embedding error occurred. Info message: {validated_info_msg}") from err

    return response
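
# Illustrative sketch of a single-batch call (not part of the module's public
# surface; the endpoint, key, and model below are placeholders, not defaults
# shipped with this library):
#
#   response = _make_async_request(
#       prompts=["first chunk", "second chunk"],
#       api_key="nvapi-...",
#       embedding_nim_endpoint="http://localhost:8000/v1",
#       embedding_model="nvidia/nv-embedqa-e5-v5",
#       encoding_format="float",
#       input_type="passage",
#       truncate="END",
#       filter_errors=False,
#   )
#   # On success, response["embedding"] holds one embedding object per prompt
#   # (each exposing an .embedding vector) and response["info_msg"] is None.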


def _async_request_handler(
    prompts: List[List[str]],
    api_key: str,
    embedding_nim_endpoint: str,
    embedding_model: str,
    encoding_format: str,
    input_type: str,
    truncate: str,
    filter_errors: bool,
) -> List[dict]:
    """
    Submits each batch of prompts to the NIM embedding service concurrently and gathers the results.

    Parameters
    ----------
    prompts : List[List[str]]
        A list of prompt batches, where each batch is a list of prompt strings.
    api_key : str
        API key for authentication.
    embedding_nim_endpoint : str
        Base URL for the NIM embedding service.
    embedding_model : str
        The model to use for generating embeddings.
    encoding_format : str
        The desired encoding format.
    input_type : str
        The type of input data.
    truncate : str
        Truncation setting for the input data.
    filter_errors : bool
        Flag indicating whether to filter errors in the response.

    Returns
    -------
    List[dict]
        A list of response dictionaries from the embedding service.
    """
    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(
                _make_async_request,
                prompts=prompt_batch,
                api_key=api_key,
                embedding_nim_endpoint=embedding_nim_endpoint,
                embedding_model=embedding_model,
                encoding_format=encoding_format,
                input_type=input_type,
                truncate=truncate,
                filter_errors=filter_errors,
            )
            for prompt_batch in prompts
        ]
        results = [future.result() for future in futures]

    return results


def _async_runner(
    prompts: List[List[str]],
    api_key: str,
    embedding_nim_endpoint: str,
    embedding_model: str,
    encoding_format: str,
    input_type: str,
    truncate: str,
    filter_errors: bool,
) -> dict:
    """
    Concurrently launches all NIM embedding requests and flattens the results.

    Parameters
    ----------
    prompts : List[List[str]]
        A list of prompt batches, where each batch is a list of prompt strings.
    api_key : str
        API key for authentication.
    embedding_nim_endpoint : str
        Base URL for the NIM embedding service.
    embedding_model : str
        The model to use for generating embeddings.
    encoding_format : str
        The desired encoding format.
    input_type : str
        The type of input data.
    truncate : str
        Truncation setting for the input data.
    filter_errors : bool
        Flag indicating whether to filter errors in the response.

    Returns
    -------
    dict
        A dictionary with keys "embeddings" (flattened embedding results) and "info_msgs" (error messages).
    """
    results = _async_request_handler(
        prompts,
        api_key,
        embedding_nim_endpoint,
        embedding_model,
        encoding_format,
        input_type,
        truncate,
        filter_errors,
    )

    flat_results = {"embeddings": [], "info_msgs": []}
    for batch_dict in results:
        info_msg = batch_dict["info_msg"]
        for embedding in batch_dict["embedding"]:
            # Successful responses carry OpenAI embedding objects; unwrap the raw
            # vector. Failed prompts arrive as None and pass through unchanged, as
            # do values that are already plain lists.
            if not isinstance(embedding, list) and embedding is not None:
                flat_results["embeddings"].append(embedding.embedding)
            else:
                flat_results["embeddings"].append(embedding)
            flat_results["info_msgs"].append(info_msg)

    return flat_results
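
# Sketch of how batches flow through the runner (placeholder endpoint, key, and
# model; _generate_batches is defined later in this module):
#
#   batches = _generate_batches(["a", "b", "c"], batch_size=2)  # [["a", "b"], ["c"]]
#   flat = _async_runner(
#       batches,
#       "nvapi-...",
#       "http://localhost:8000/v1",
#       "nvidia/nv-embedqa-e5-v5",
#       "float",
#       "passage",
#       "END",
#       False,
#   )
#   # flat["embeddings"] is a single flat list aligned 1:1 with the original
#   # prompts; flat["info_msgs"] holds a matching error dict (or None) per prompt.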


# ------------------------------------------------------------------------------
# Pandas UDFs for Content Extraction
# ------------------------------------------------------------------------------


def _add_embeddings(row, embeddings, info_msgs):
    """
    Updates a DataFrame row with embedding data and associated error info.
    Ensures the 'embedding' field is always present, even if None.

    Parameters
    ----------
    row : pandas.Series
        A row of the DataFrame.
    embeddings : dict
        Dictionary mapping row indices to embeddings.
    info_msgs : dict
        Dictionary mapping row indices to info message dicts.

    Returns
    -------
    pandas.Series
        The updated row with 'embedding', 'info_message_metadata', and
        '_contains_embeddings' appropriately set.
    """
    embedding = embeddings.get(row.name, None)
    info_msg = info_msgs.get(row.name, None)

    # Always set embedding, even if None
    row["metadata"]["embedding"] = embedding

    if info_msg:
        row["metadata"]["info_message_metadata"] = info_msg
        row["document_type"] = ContentTypeEnum.INFO_MSG
        row["_contains_embeddings"] = False
    else:
        row["_contains_embeddings"] = embedding is not None

    return row
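
# A minimal sketch of applying this updater row-wise (toy two-row frame; the
# maps are keyed by row index, mirroring the pipeline at the end of this module):
#
#   df = pd.DataFrame({"metadata": [{}, {}], "document_type": ["text", "text"]})
#   df = df.apply(_add_embeddings, embeddings={0: [0.1, 0.2]}, info_msgs={}, axis=1)
#   # Row 0: metadata["embedding"] == [0.1, 0.2], _contains_embeddings == True.
#   # Row 1: metadata["embedding"] is None, _contains_embeddings == False.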


def _get_pandas_text_content(row):
    """
    Extracts text content from a DataFrame row.

    Parameters
    ----------
    row : pandas.Series
        A row containing the 'content' key.

    Returns
    -------
    str
        The text content from the row.
    """
    return row["content"]


def _get_pandas_table_content(row):
    """
    Extracts table/chart content from a DataFrame row.

    Parameters
    ----------
    row : pandas.Series
        A row containing 'table_metadata' with 'table_content'.

    Returns
    -------
    str
        The table/chart content from the row.
    """
    return row.get("table_metadata", {}).get("table_content")


def _get_pandas_image_content(row):
    """
    Extracts image caption content from a DataFrame row.

    Parameters
    ----------
    row : pandas.Series
        A row containing 'image_metadata' with 'caption'.

    Returns
    -------
    str
        The image caption from the row.
    """
    return row.get("image_metadata", {}).get("caption")


def _get_pandas_audio_content(row):
    """
    A pandas UDF used to select extracted audio transcription to be used to create embeddings.
    """
    return row.get("audio_metadata", {}).get("audio_transcript")


# ------------------------------------------------------------------------------
# Batch Processing Utilities
# ------------------------------------------------------------------------------


def _batch_generator(iterable: Sequence, batch_size: int = 10):
    """
    Yields batches of a specified size from a sequence.

    Parameters
    ----------
    iterable : Sequence
        The sequence to batch; must support len() and slicing.
    batch_size : int, optional
        The size of each batch (default is 10).

    Yields
    ------
    list
        A batch of items from the iterable.
    """
    iter_len = len(iterable)
    for idx in range(0, iter_len, batch_size):
        yield iterable[idx : min(idx + batch_size, iter_len)]
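
# Example (a quick illustration of the batching behavior):
#
#   >>> list(_batch_generator([1, 2, 3, 4, 5, 6], batch_size=4))
#   [[1, 2, 3, 4], [5, 6]]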


def _generate_batches(prompts: List[str], batch_size: int = 100) -> List[List[str]]:
    """
    Splits a list of prompts into batches.

    Parameters
    ----------
    prompts : List[str]
        The list of prompt strings.
    batch_size : int, optional
        The desired batch size (default is 100).

    Returns
    -------
    List[List[str]]
        A list of batches, each containing a subset of the prompts.
    """
    return list(_batch_generator(prompts, batch_size))


# ------------------------------------------------------------------------------
# DataFrame Concatenation Utility
# ------------------------------------------------------------------------------


def _concatenate_extractions_pandas(
    base_df: pd.DataFrame, dataframes: List[pd.DataFrame], masks: List[pd.Series]
) -> pd.DataFrame:
    """
    Concatenates processed DataFrame rows (with embeddings) with unprocessed rows from the base DataFrame.

    Parameters
    ----------
    base_df : pd.DataFrame
        The original DataFrame.
    dataframes : List[pd.DataFrame]
        List of DataFrames that have been enriched with embeddings.
    masks : List[pd.Series]
        List of boolean masks indicating the rows that were processed.

    Returns
    -------
    pd.DataFrame
        The concatenated DataFrame with embeddings applied where available.
    """
    unified_mask = pd.Series(False, index=base_df.index)
    for mask in masks:
        unified_mask = unified_mask | mask

    df_no_text = base_df.loc[~unified_mask].copy()
    df_no_text["_contains_embeddings"] = False

    dataframes.append(df_no_text)
    combined_df = pd.concat(dataframes, axis=0, ignore_index=True).reset_index(drop=True)
    return combined_df
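
# Sketch with a toy frame: rows covered by any mask are replaced by their
# enriched copies; the rest pass through flagged _contains_embeddings=False:
#
#   base = pd.DataFrame({"metadata": [{}, {}, {}]})
#   mask = pd.Series([True, False, True])  # rows 0 and 2 were processed
#   enriched = base.loc[mask].copy()
#   enriched["_contains_embeddings"] = True
#   out = _concatenate_extractions_pandas(base, [enriched], [mask])
#   # out has three rows: the two enriched rows, then row 1 with the flag False.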


# ------------------------------------------------------------------------------
# Embedding Extraction Pipeline
# ------------------------------------------------------------------------------


def transform_create_text_embeddings_internal(
    df_transform_ledger: pd.DataFrame,
    task_config: Dict[str, Any],
    transform_config: TextEmbeddingSchema = TextEmbeddingSchema(),
    execution_trace_log: Optional[Dict] = None,
) -> Tuple[pd.DataFrame, Dict]:
    """
    Generates text embeddings for supported content types (TEXT, STRUCTURED, IMAGE, AUDIO)
    from a pandas DataFrame using asynchronous requests.

    This function ensures that even if the extracted content is empty or None, the
    embedding field is explicitly created and set to None.

    Parameters
    ----------
    df_transform_ledger : pd.DataFrame
        The DataFrame containing content for embedding extraction.
    task_config : Dict[str, Any]
        Dictionary containing task properties (e.g., filter error flag).
    transform_config : TextEmbeddingSchema, optional
        Validated configuration for text embedding extraction.
    execution_trace_log : Optional[Dict], optional
        Optional trace information for debugging or logging (default is None).

    Returns
    -------
    Tuple[pd.DataFrame, Dict]
        A tuple containing:
          - The updated DataFrame with embeddings applied.
          - A dictionary with trace information.
    """
    api_key = task_config.get("api_key") or transform_config.api_key
    endpoint_url = task_config.get("endpoint_url") or transform_config.embedding_nim_endpoint
    model_name = task_config.get("model_name") or transform_config.embedding_model

    if execution_trace_log is None:
        execution_trace_log = {}
        logger.debug("No trace_info provided. Initialized empty trace_info dictionary.")

    if df_transform_ledger.empty:
        return df_transform_ledger, {"trace_info": execution_trace_log}

    embedding_dataframes = []
    content_masks = []

    pandas_content_extractor = {
        ContentTypeEnum.TEXT: _get_pandas_text_content,
        ContentTypeEnum.STRUCTURED: _get_pandas_table_content,
        ContentTypeEnum.IMAGE: _get_pandas_image_content,
        ContentTypeEnum.AUDIO: _get_pandas_audio_content,
        ContentTypeEnum.VIDEO: lambda x: None,  # Not supported yet.
    }

    def _content_type_getter(row):
        return row["content_metadata"]["type"]

    for content_type, content_getter in pandas_content_extractor.items():
        if not content_getter:
            logger.debug(f"Skipping unsupported content type: {content_type}")
            continue

        # Get rows matching the content type
        content_mask = df_transform_ledger["metadata"].apply(_content_type_getter) == content_type.value
        if not content_mask.any():
            continue

        # Always include all content_mask rows and prepare them
        df_content = df_transform_ledger.loc[content_mask].copy().reset_index(drop=True)

        # Extract content and normalize empty or non-str to None
        extracted_content = (
            df_content["metadata"]
            .apply(content_getter)
            .apply(lambda x: x.strip() if isinstance(x, str) and x.strip() else None)
        )
        df_content["_content"] = extracted_content

        # Prepare batches for only valid (non-None) content
        valid_content_mask = df_content["_content"].notna()
        if valid_content_mask.any():
            filtered_content_batches = _generate_batches(
                df_content.loc[valid_content_mask, "_content"].tolist(), batch_size=transform_config.batch_size
            )
            content_embeddings = _async_runner(
                filtered_content_batches,
                api_key,
                endpoint_url,
                model_name,
                transform_config.encoding_format,
                transform_config.input_type,
                transform_config.truncate,
                False,
            )
            # Build a simple row index -> embedding map
            embeddings_dict = dict(
                zip(df_content.loc[valid_content_mask].index, content_embeddings.get("embeddings", []))
            )
            info_msgs_dict = dict(
                zip(df_content.loc[valid_content_mask].index, content_embeddings.get("info_msgs", []))
            )
        else:
            embeddings_dict = {}
            info_msgs_dict = {}

        # Apply embeddings or None to all rows
        df_content = df_content.apply(_add_embeddings, embeddings=embeddings_dict, info_msgs=info_msgs_dict, axis=1)

        embedding_dataframes.append(df_content)
        content_masks.append(content_mask)

    combined_df = _concatenate_extractions_pandas(df_transform_ledger, embedding_dataframes, content_masks)
    return combined_df, {"trace_info": execution_trace_log}
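
# End-to-end sketch. The ledger layout below is a minimal guess at the fields
# this function reads (see the content getters above), and the endpoint and key
# are placeholders:
#
#   ledger = pd.DataFrame(
#       {
#           "document_type": [ContentTypeEnum.TEXT.value],
#           "metadata": [
#               {
#                   "content": "Hello world",
#                   "content_metadata": {"type": ContentTypeEnum.TEXT.value},
#               }
#           ],
#       }
#   )
#   result_df, trace = transform_create_text_embeddings_internal(
#       ledger,
#       task_config={"endpoint_url": "http://localhost:8000/v1", "api_key": "nvapi-..."},
#   )
#   # Each result_df["metadata"] dict now carries an "embedding" key (a vector
#   # or None), plus a _contains_embeddings flag on the row.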