Source code for nv_ingest_api.internal.extract.audio.audio_extraction
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import pandas as pd
import functools
import uuid
from typing import Any
from typing import Dict
from typing import Optional
from typing import Tuple
from nv_ingest_api.internal.enums.common import ContentTypeEnum
from nv_ingest_api.internal.primitives.nim.model_interface.parakeet import create_audio_inference_client
from nv_ingest_api.internal.schemas.extract.extract_audio_schema import AudioExtractorSchema
from nv_ingest_api.internal.schemas.meta.metadata_schema import MetadataSchema, AudioMetadataSchema
from nv_ingest_api.util.exception_handlers.decorators import unified_exception_handler
from nv_ingest_api.util.schema.schema_validator import validate_schema
logger = logging.getLogger(__name__)
@unified_exception_handler
def _extract_from_audio(row: pd.Series, audio_client: Any, trace_info: Dict, segment_audio: bool = False) -> list:
    """
    Transcribe the audio payload carried by a single ledger row.

    Parameters
    ----------
    row : pd.Series
        A row from the DataFrame. Its "metadata" entry is read for the audio
        payload ("content") and the content type ("content_metadata" -> "type").
    audio_client : Any
        The client used to call the audio inference model. Its ``infer`` method
        is unpacked as a ``(segments, transcript)`` pair.
    trace_info : Dict
        Trace information forwarded to the inference call.
    segment_audio : bool, optional
        When True, one output row is emitted per returned segment, with the
        segment's "start"/"end" stored as "start_time"/"end_time" in its content
        metadata. When False, a single row with the full transcript is emitted.
        Defaults to False.

    Returns
    -------
    list
        A list of rows. For rows that are not audio (or have no audio payload)
        this is ``[row.to_list()]`` with the row left untouched. Otherwise each
        entry is ``[ContentTypeEnum.AUDIO, validated_metadata_dict, uuid_string]``.

    Raises
    ------
    ValueError
        If the row does not contain 'metadata'.
    """
    metadata = row.get("metadata")
    if metadata is None:
        logger.error("Row does not contain 'metadata'.")
        raise ValueError("Row does not contain 'metadata'.")

    # Read (do not pop) the payload first: rows that are passed through below
    # must keep their content intact, and a missing key should simply mean
    # "nothing to transcribe" rather than a KeyError.
    base64_audio = metadata.get("content")
    content_metadata = metadata.get("content_metadata") or {}

    # Only extract transcript if content type is audio
    if (content_metadata.get("type") != ContentTypeEnum.AUDIO) or (base64_audio in (None, "")):
        return [row.to_list()]

    # From here on the row is replaced by transcript rows, so the raw audio
    # payload is dropped from the metadata that gets validated and emitted.
    metadata.pop("content", None)

    # Get the result from the inference model
    segments, transcript = audio_client.infer(
        base64_audio,
        model_name="parakeet",
        trace_info=trace_info,  # traceable_func arg
        stage_name="audio_extraction",
    )

    if not segment_audio:
        audio_metadata = {"audio_transcript": transcript}
        metadata["audio_metadata"] = validate_schema(audio_metadata, AudioMetadataSchema).model_dump()
        return [[ContentTypeEnum.AUDIO, validate_schema(metadata, MetadataSchema).model_dump(), str(uuid.uuid4())]]

    extracted_data = []
    for segment in segments:
        # metadata.copy() is shallow, so give every segment its own
        # content_metadata dict instead of writing the timestamps into the
        # dict shared with the input row and with the other segments.
        segment_metadata = metadata.copy()
        segment_metadata["content_metadata"] = {
            **content_metadata,
            "start_time": segment["start"],
            "end_time": segment["end"],
        }
        audio_metadata = {"audio_transcript": segment["text"]}
        segment_metadata["audio_metadata"] = validate_schema(audio_metadata, AudioMetadataSchema).model_dump()
        extracted_data.append(
            [
                ContentTypeEnum.AUDIO,
                validate_schema(segment_metadata, MetadataSchema).model_dump(),
                str(uuid.uuid4()),
            ]
        )

    return extracted_data
[docs]
def extract_text_from_audio_internal(
    df_extraction_ledger: pd.DataFrame,
    task_config: Dict[str, Any],
    extraction_config: AudioExtractorSchema,
    execution_trace_log: Optional[Dict] = None,
) -> Tuple[pd.DataFrame, Dict]:
    """
    Extracts audio data from a DataFrame.

    Parameters
    ----------
    df_extraction_ledger : pd.DataFrame
        DataFrame containing the content from which audio data is to be extracted.
    task_config : Dict[str, Any]
        Dictionary containing task properties and configurations. Per-task
        overrides are read from ``task_config["params"]["extract_audio_params"]``;
        any value not supplied there falls back to ``extraction_config``.
    extraction_config : AudioExtractorSchema
        The validated configuration object for audio extraction; its
        ``audio_extraction_config`` attribute supplies the default endpoints,
        protocol, credentials and flags.
    execution_trace_log : Optional[Dict], optional
        Optional trace information for debugging or logging. Defaults to None,
        in which case a new empty dictionary is created and returned.

    Returns
    -------
    Tuple[pd.DataFrame, Dict]
        A tuple containing a DataFrame with columns "document_type", "metadata"
        and "uuid" built from the per-row extraction results, and the trace
        information.

    Raises
    ------
    Exception
        If any error occurs during the audio data extraction process; the error
        is logged and re-raised unchanged.
    """
    logger.debug(f"Entering audio extraction stage with {len(df_extraction_ledger)} rows.")

    if execution_trace_log is None:
        execution_trace_log = {}
        logger.debug("No trace_info provided. Initialized empty trace_info dictionary.")

    # Nothing to process: row-wise apply on an empty frame does not reliably
    # yield a Series of per-row results, so return the empty result directly
    # (this also avoids creating an inference client that would never be used).
    if df_extraction_ledger.empty:
        return pd.DataFrame({"document_type": [], "metadata": [], "uuid": []}), execution_trace_log

    # Tolerate "params" / "extract_audio_params" being absent or explicitly None.
    extract_params = (task_config.get("params") or {}).get("extract_audio_params") or {}
    audio_extraction_config = extraction_config.audio_extraction_config

    grpc_endpoint = extract_params.get("grpc_endpoint") or audio_extraction_config.audio_endpoints[0]
    http_endpoint = extract_params.get("http_endpoint") or audio_extraction_config.audio_endpoints[1]
    infer_protocol = extract_params.get("infer_protocol") or audio_extraction_config.audio_infer_protocol
    auth_token = extract_params.get("auth_token") or audio_extraction_config.auth_token
    function_id = extract_params.get("function_id") or audio_extraction_config.function_id
    ssl_cert = extract_params.get("ssl_cert") or audio_extraction_config.ssl_cert

    # Boolean flags must fall back only when the task did not set them; using
    # `or` here would silently discard an explicit False override.
    use_ssl = extract_params.get("use_ssl")
    if use_ssl is None:
        use_ssl = audio_extraction_config.use_ssl

    segment_audio = extract_params.get("segment_audio")
    if segment_audio is None:
        segment_audio = audio_extraction_config.segment_audio

    parakeet_client = create_audio_inference_client(
        (grpc_endpoint, http_endpoint),
        infer_protocol=infer_protocol,
        auth_token=auth_token,
        function_id=function_id,
        use_ssl=use_ssl,
        ssl_cert=ssl_cert,
    )

    try:
        # Create a partial function to extract using the provided configurations.
        _extract_from_audio_partial = functools.partial(
            _extract_from_audio,
            audio_client=parakeet_client,
            trace_info=execution_trace_log,
            segment_audio=segment_audio,
        )

        # Apply the _extract_from_audio_partial function to each row in the DataFrame
        extraction_series = df_extraction_ledger.apply(_extract_from_audio_partial, axis=1)

        # Each row yields a list of output rows; explode flattens them to one
        # output row per entry and dropna discards rows that produced nothing.
        extraction_series = extraction_series.explode().dropna()

        # Convert the extracted results into a DataFrame.
        if not extraction_series.empty:
            extracted_df = pd.DataFrame(extraction_series.to_list(), columns=["document_type", "metadata", "uuid"])
        else:
            extracted_df = pd.DataFrame({"document_type": [], "metadata": [], "uuid": []})

        return extracted_df, execution_trace_log

    except Exception as e:
        # logger.exception already attaches the active traceback.
        logger.exception(f"Error occurred while extracting audio data: {e}")
        raise