# Source code for nv_ingest_api.internal.extract.audio.audio_extraction

# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging

import pandas as pd
import functools
import uuid
from typing import Any
from typing import Dict
from typing import Optional
from typing import Tuple
import base64
from pathlib import Path

from nv_ingest_api.internal.enums.common import ContentTypeEnum
from nv_ingest_api.internal.primitives.nim.model_interface.parakeet import create_audio_inference_client
from nv_ingest_api.internal.schemas.extract.extract_audio_schema import AudioExtractorSchema
from nv_ingest_api.internal.schemas.meta.metadata_schema import MetadataSchema, AudioMetadataSchema
from nv_ingest_api.util.exception_handlers.decorators import unified_exception_handler
from nv_ingest_api.util.schema.schema_validator import validate_schema
from nv_ingest_api.interface.utility import read_file_as_base64

logger = logging.getLogger(__name__)


def _resolve_audio_payload(content: Any) -> Tuple[Any, Optional[Path]]:
    """
    Resolve a row's ``content`` value into a base64 audio payload.

    ``content`` is first base64-decoded and interpreted as UTF-8 text. If that text names
    an existing filesystem entry, the file is read through ``read_file_as_base64`` and its
    path is returned so the caller can remove it. In every other case ``content`` itself is
    treated as the (inline) base64 audio payload.

    Parameters
    ----------
    content : Any
        The non-empty ``content`` value taken from the row metadata.

    Returns
    -------
    Tuple[Any, Optional[Path]]
        ``(payload, source_path)``. ``payload`` is ``None`` when ``content`` decodes to an
        empty string. ``source_path`` is the file the payload was read from, or ``None``
        when the payload was supplied inline.
    """
    try:
        decoded = base64.b64decode(content).decode("utf-8")
    except (UnicodeDecodeError, base64.binascii.Error):
        # Not base64-encoded UTF-8 text, so it cannot be a path reference.
        return content, None

    if not decoded:
        return None, None

    # NOTE(review): the path originates from row content. If that content can be influenced
    # by untrusted callers, this permits reading (and later deleting) arbitrary files —
    # confirm that only trusted upstream stages produce path references.
    candidate = Path(decoded)
    try:
        is_file_reference = candidate.exists()
    except (OSError, ValueError):
        # Over-long or otherwise invalid path strings are inline content, not file references.
        is_file_reference = False

    if not is_file_reference:
        return content, None

    return read_file_as_base64(decoded), candidate


@unified_exception_handler
def _extract_from_audio(row: pd.Series, audio_client: Any, trace_info: Dict, segment_audio: bool = False) -> list:
    """
    Transcribe the audio referenced by a ledger row and build the resulting output rows.

    Rows whose content type is not audio, or whose content is empty, are passed through
    unchanged (including their ``content``). For audio rows, the content is either inline
    base64 audio or the base64 encoding of a path to an audio file; in the latter case the
    file is read and then deleted before inference is invoked.

    Parameters
    ----------
    row : pd.Series
        A row from the DataFrame containing metadata for the audio extraction.

    audio_client : Any
        The client used to call the audio inference model. Its ``infer`` method must return
        a ``(segments, transcript)`` pair.

    trace_info : Dict
        Trace information forwarded to the inference call.

    segment_audio : bool, optional
        When True, emit one output row per returned segment (using the segment's ``text``,
        ``start`` and ``end`` entries). When False, emit a single row holding the full
        transcript. Defaults to False.

    Returns
    -------
    list
        A list of rows. Pass-through results contain the original row as a list; extraction
        results contain ``[ContentTypeEnum.AUDIO, metadata_dict, uuid_str]`` entries.

    Raises
    ------
    ValueError
        If critical information (such as metadata) is missing from the row.
    """

    metadata = row.get("metadata")

    if metadata is None:
        logger.error("Row does not contain 'metadata'.")
        raise ValueError("Row does not contain 'metadata'.")

    content_metadata = metadata.get("content_metadata") or {}
    raw_content = metadata.get("content")

    # Only extract a transcript for audio rows that actually carry content. The check is done
    # before touching the filesystem and without removing ``content`` from the row, so
    # pass-through rows are returned exactly as they came in.
    if content_metadata.get("type") != ContentTypeEnum.AUDIO or not raw_content:
        return [row.to_list()]

    base64_audio, source_path = _resolve_audio_payload(raw_content)
    if base64_audio in (None, ""):
        return [row.to_list()]

    # Only delete when the payload really came from a file; inline payloads are not paths.
    # The removal happens before inference, so the file is gone even if the call below raises.
    if source_path is not None:
        logger.debug(f"Removing file {source_path}")
        source_path.unlink(missing_ok=True)

    # Get the result from the inference model
    segments, transcript = audio_client.infer(
        base64_audio,
        model_name="parakeet",
        trace_info=trace_info,  # traceable_func arg
        stage_name="audio_extraction",
    )

    # Output rows never carry the raw audio payload; work on a copy so the input row's
    # metadata is left untouched.
    base_metadata = {key: value for key, value in metadata.items() if key != "content"}

    if not segment_audio:
        base_metadata["audio_metadata"] = validate_schema(
            {"audio_transcript": transcript}, AudioMetadataSchema
        ).model_dump()
        return [[ContentTypeEnum.AUDIO, validate_schema(base_metadata, MetadataSchema).model_dump(), str(uuid.uuid4())]]

    extracted_data = []
    for segment in segments:
        segment_metadata = dict(base_metadata)
        segment_metadata["audio_metadata"] = validate_schema(
            {"audio_transcript": segment["text"]}, AudioMetadataSchema
        ).model_dump()
        # Give every segment its own content_metadata so timings are not written into a
        # dictionary shared between segments and the input row.
        segment_metadata["content_metadata"] = {
            **content_metadata,
            "start_time": segment["start"],
            "end_time": segment["end"],
        }

        extracted_data.append(
            [
                ContentTypeEnum.AUDIO,
                validate_schema(segment_metadata, MetadataSchema).model_dump(),
                str(uuid.uuid4()),
            ]
        )

    return extracted_data


def _task_override(params: Dict[str, Any], key: str, default: Any) -> Any:
    """Return ``params[key]`` unless it is missing or ``None``, in which case return ``default``."""
    value = params.get(key)
    return default if value is None else value


def extract_text_from_audio_internal(
    df_extraction_ledger: pd.DataFrame,
    task_config: Dict[str, Any],
    extraction_config: AudioExtractorSchema,
    execution_trace_log: Optional[Dict] = None,
) -> Tuple[pd.DataFrame, Dict]:
    """
    Extracts audio data from a DataFrame.

    Connection settings are taken from ``task_config["params"]["extract_audio_params"]``
    when present and fall back to ``extraction_config.audio_extraction_config`` otherwise.
    Each ledger row is then passed to ``_extract_from_audio`` and the returned rows are
    collected into a new DataFrame.

    Parameters
    ----------
    df_extraction_ledger : pd.DataFrame
        DataFrame containing the content from which audio data is to be extracted.
    task_config : Dict[str, Any]
        Dictionary containing task properties and configurations.
    extraction_config : AudioExtractorSchema
        The validated configuration object for audio extraction.
    execution_trace_log : Optional[Dict], optional
        Optional trace information for debugging or logging. Defaults to None, in which
        case a new empty dictionary is created and returned.

    Returns
    -------
    Tuple[pd.DataFrame, Dict]
        A tuple containing a DataFrame with ``document_type``, ``metadata`` and ``uuid``
        columns, and the trace information.

    Raises
    ------
    Exception
        If any error occurs during the audio data extraction process.
    """
    logger.debug(f"Entering audio extraction stage with {len(df_extraction_ledger)} rows.")

    if execution_trace_log is None:
        execution_trace_log = {}
        logger.debug("No trace_info provided. Initialized empty trace_info dictionary.")

    # Nothing to do for an empty ledger. This also avoids DataFrame.apply(axis=1), which can
    # return a DataFrame (not a Series) for empty input and would break the explode() below.
    if df_extraction_ledger.empty:
        return pd.DataFrame({"document_type": [], "metadata": [], "uuid": []}), execution_trace_log

    # Tolerate "params" / "extract_audio_params" being present but set to None.
    extract_params = (task_config.get("params") or {}).get("extract_audio_params") or {}
    audio_extraction_config = extraction_config.audio_extraction_config

    # String-valued settings: an empty task value falls back to the stage configuration.
    grpc_endpoint = extract_params.get("grpc_endpoint") or audio_extraction_config.audio_endpoints[0]
    http_endpoint = extract_params.get("http_endpoint") or audio_extraction_config.audio_endpoints[1]
    infer_protocol = extract_params.get("infer_protocol") or audio_extraction_config.audio_infer_protocol
    auth_token = extract_params.get("auth_token") or audio_extraction_config.auth_token
    function_id = extract_params.get("function_id") or audio_extraction_config.function_id
    ssl_cert = extract_params.get("ssl_cert") or audio_extraction_config.ssl_cert

    # Boolean settings: only fall back when the task did not specify a value, so that an
    # explicit False in the task can override a True default from the configuration.
    use_ssl = _task_override(extract_params, "use_ssl", audio_extraction_config.use_ssl)
    segment_audio = _task_override(extract_params, "segment_audio", audio_extraction_config.segment_audio)

    parakeet_client = create_audio_inference_client(
        (grpc_endpoint, http_endpoint),
        infer_protocol=infer_protocol,
        auth_token=auth_token,
        function_id=function_id,
        use_ssl=use_ssl,
        ssl_cert=ssl_cert,
    )

    try:
        # Create a partial function to extract using the provided configurations.
        _extract_from_audio_partial = functools.partial(
            _extract_from_audio,
            audio_client=parakeet_client,
            trace_info=execution_trace_log,
            segment_audio=segment_audio,
        )

        # Apply the _extract_from_audio_partial function to each row in the DataFrame
        extraction_series = df_extraction_ledger.apply(_extract_from_audio_partial, axis=1)

        # Each row yields a list of output rows; explode flattens them to one entry per row.
        extraction_series = extraction_series.explode().dropna()

        # Convert the extracted results into a DataFrame.
        if not extraction_series.empty:
            extracted_df = pd.DataFrame(extraction_series.to_list(), columns=["document_type", "metadata", "uuid"])
        else:
            extracted_df = pd.DataFrame({"document_type": [], "metadata": [], "uuid": []})

        return extracted_df, execution_trace_log

    except Exception as e:
        # logger.exception already records the active traceback.
        logger.exception(f"Error occurred while extracting audio data: {e}")
        raise