# Source code for nv_ingest.framework.orchestration.morpheus.stages.transforms.embed_text_stage

# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import functools
import logging
from typing import Dict, Any

from nv_ingest.framework.orchestration.morpheus.stages.meta.multiprocessing_stage import MultiProcessingBaseStage
from nv_ingest_api.internal.schemas.transform.transform_text_embedding_schema import TextEmbeddingSchema
from nv_ingest_api.internal.transform.embed_text import transform_create_text_embeddings_internal

logger = logging.getLogger(__name__)


# ------------------------------------------------------------------------------
# Stage Generation
# ------------------------------------------------------------------------------


def generate_text_embed_extractor_stage(
    c: Any,
    transform_config: Dict[str, Any],
    task: str = "embed",
    task_desc: str = "text_embed_extraction",
    pe_count: int = 1,
):
    """
    Generate a multiprocessing stage that performs text embedding extraction.

    ``transform_config`` is validated up front by constructing a
    ``TextEmbeddingSchema`` from it. The validated schema object is then
    pre-bound, as the ``transform_config`` keyword argument, to
    ``transform_create_text_embeddings_internal``. That callable becomes the
    stage's ``process_fn``.

    Parameters
    ----------
    c : Any
        Global configuration object, forwarded unchanged to
        ``MultiProcessingBaseStage``.
    transform_config : Dict[str, Any]
        Configuration parameters for the text embedding extractor. They are
        unpacked as keyword arguments into ``TextEmbeddingSchema`` for
        validation.
    task : str, optional
        The task name for the stage worker function (default is "embed").
    task_desc : str, optional
        A descriptor used for latency tracing and logging (default is
        "text_embed_extraction").
    pe_count : int, optional
        Number of process engines to use concurrently (default is 1).

    Returns
    -------
    MultiProcessingBaseStage
        A configured stage whose worker is
        ``transform_create_text_embeddings_internal`` with the validated
        configuration already bound. Per the original documentation, the
        stage processes a pandas DataFrame and returns a tuple of
        (DataFrame, trace_info dict); confirm against the worker's signature.

    Raises
    ------
    TypeError
        If ``transform_config`` is not a mapping, because it is unpacked
        with ``**``.
    Exception
        Any validation error raised by ``TextEmbeddingSchema`` for an invalid
        configuration. This is presumably a pydantic ``ValidationError``;
        confirm against the schema definition.
    """
    # Validate eagerly, so a bad configuration fails when the pipeline is
    # built rather than inside a worker process.
    validated_config = TextEmbeddingSchema(**transform_config)

    # Bind the validated config so the stage only has to supply the
    # per-invocation arguments to the worker function.
    _wrapped_process_fn = functools.partial(
        transform_create_text_embeddings_internal,
        transform_config=validated_config,
    )

    return MultiProcessingBaseStage(
        c=c,
        pe_count=pe_count,
        task=task,
        task_desc=task_desc,
        process_fn=_wrapped_process_fn,
    )