Source code for nv_ingest.framework.orchestration.morpheus.modules.sinks.vdb_task_sink

# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import logging
import os
import pickle
import time
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
from urllib.parse import urlparse

import mrc
import pandas as pd
from minio import Minio
from morpheus.utils.control_message_utils import cm_skip_processing_if_failed
from morpheus.utils.module_ids import WRITE_TO_VECTOR_DB
from morpheus.utils.module_utils import ModuleLoaderFactory
from morpheus.utils.module_utils import register_module
from morpheus_llm.service.vdb.milvus_client import DATA_TYPE_MAP
from morpheus_llm.service.vdb.utils import VectorDBServiceFactory
from morpheus_llm.service.vdb.vector_db_service import VectorDBService
from mrc.core import operators as ops
from pymilvus import BulkInsertState
from pymilvus import connections
from pymilvus import utility

from nv_ingest.framework.schemas.framework_vdb_task_sink_schema import VdbTaskSinkSchema
from nv_ingest_api.util.exception_handlers.decorators import nv_ingest_node_failure_context_manager
from nv_ingest.framework.util.flow_control import filter_by_task
from nv_ingest.framework.orchestration.morpheus.util.modules.config_validator import (
    fetch_and_validate_module_config,
)
from nv_ingest_api.internal.primitives.tracing.tagging import traceable
from nv_ingest_api.internal.primitives.ingest_control_message import IngestControlMessage, remove_task_by_type

logger = logging.getLogger(__name__)

# Identifiers under which this module is registered with the Morpheus module registry
# (see the `register_module` decorator on `_vdb_task_sink`).
MODULE_NAME = "vdb_task_sink"
MODULE_NAMESPACE = "nv_ingest"

VDBTaskSinkLoaderFactory = ModuleLoaderFactory(MODULE_NAME, MODULE_NAMESPACE, VdbTaskSinkSchema)

# Bulk-ingest fallbacks, overridable through the environment:
#   - endpoint used when a task's `params` carry no "endpoint"
#   - bucket used when a task carries no "bucket_name"
_DEFAULT_ENDPOINT = os.getenv("MINIO_INTERNAL_ADDRESS", "minio:9000")
_DEFAULT_BUCKET_NAME = os.getenv("MINIO_BUCKET", "nv-ingest")


def _bulk_ingest(
    milvus_uri: str = None,
    collection_name: str = None,
    bucket_name: str = None,
    bulk_ingest_path: str = None,
    extra_params: dict = None,
):
    """
    Bulk-import files staged in a MinIO bucket into a Milvus collection and block until done.

    Every object found under ``bulk_ingest_path`` (recursively) is submitted as its own
    bulk-insert task. The function polls the tasks until each one has completed or failed,
    then polls index-building progress until all rows are indexed.

    Parameters
    ----------
    milvus_uri : str
        Milvus URI; only its hostname and port are used to open the connection.
    collection_name : str
        Target Milvus collection.
    bucket_name : str
        Bucket containing the staged files.
    bulk_ingest_path : str
        Object-name prefix selecting the files to import.
    extra_params : dict, optional
        MinIO client options: ``endpoint`` (defaults to ``_DEFAULT_ENDPOINT``), ``access_key``,
        ``secret_key``, ``session_token``, ``secure`` (defaults to False) and ``region``.
        ``None`` is treated as an empty dict.

    Raises
    ------
    ValueError
        If ``bucket_name`` does not exist.

    Notes
    -----
    Failed import tasks are logged, not raised. Neither polling loop has a timeout.
    """
    extra_params = extra_params or {}

    client = Minio(
        extra_params.get("endpoint", _DEFAULT_ENDPOINT),
        access_key=extra_params.get("access_key", None),
        secret_key=extra_params.get("secret_key", None),
        session_token=extra_params.get("session_token", None),
        secure=extra_params.get("secure", False),
        region=extra_params.get("region", None),
    )
    if not client.bucket_exists(bucket_name):
        raise ValueError(f"Could not find bucket {bucket_name}")

    # One single-file list per object: each object becomes a separate bulk-insert task.
    batch_files = [
        [f"{obj.object_name}"] for obj in client.list_objects(bucket_name, prefix=bulk_ingest_path, recursive=True)
    ]

    uri_parsed = urlparse(milvus_uri)
    _ = connections.connect(host=uri_parsed.hostname, port=uri_parsed.port)

    pending = [utility.do_bulk_insert(collection_name=collection_name, files=files) for files in batch_files]

    failed_states = (BulkInsertState.ImportFailed, BulkInsertState.ImportFailedAndCleaned)
    while pending:
        logger.debug("Wait 1 second to check bulkinsert tasks state...")
        time.sleep(1)
        # Build a fresh list instead of removing from `pending` while iterating over it,
        # which would skip the element following each removed task.
        still_pending = []
        for task_id in pending:
            state = utility.get_bulk_insert_state(task_id=task_id)
            if state.state in failed_states:
                logger.error("The task %s failed, reason: %s", state.task_id, state.failed_reason)
            elif state.state == BulkInsertState.ImportCompleted:
                logger.debug("The task %s completed", state.task_id)
            else:
                still_pending.append(task_id)
        pending = still_pending

    # Wait until the collection's index covers every imported row.
    while True:
        progress = utility.index_building_progress(collection_name)
        logger.info(progress)
        if progress.get("total_rows") == progress.get("indexed_rows"):
            break
        time.sleep(5)


def _preprocess_vdb_resources(service, recreate: bool, resource_schemas: dict):
    for resource_name, resource_schema_config in resource_schemas.items():
        has_object = service.has_store_object(name=resource_name)

        if recreate and has_object:
            # Delete the existing resource
            service.drop(name=resource_name)
            has_object = False

        # Ensure that the resource exists
        if not has_object:
            # TODO(Devin)
            import pymilvus

            schema_fields = []
            for field_data in resource_schema_config["schema_conf"]["schema_fields"]:
                if "dtype" in field_data:
                    field_data["dtype"] = DATA_TYPE_MAP.get(field_data["dtype"])
                    field_schema = pymilvus.FieldSchema(**field_data)
                    schema_fields.append(field_schema.to_dict())
                else:
                    schema_fields.append(field_data)

            resource_schema_config["schema_conf"]["schema_fields"] = schema_fields
            # function that we need to call first to turn resource_kwargs into a milvus config spec.

            service.create(name=resource_name, **resource_schema_config)


def _create_vdb_service(
    service: str, is_service_serialized: bool, service_kwargs: dict, recreate: bool, resource_schemas: dict
):
    """
    Instantiate a `VectorDBService` if a running VDB is available and a connection can be established.

    Any exception raised while building the service or preparing its resources is logged and
    reported as a failed connection, so the caller can retry later.

    Parameters
    ----------
    service : str
        Either the name of a supported `VectorDBService`, or, when ``is_service_serialized`` is
        set, a latin-1 string holding a pickled service instance.
    is_service_serialized : bool
        A flag to identify if the supplied service is serialized or needs to be instantiated.
    service_kwargs : dict
        Additional parameters needed to connect to the specified `VectorDBService`.
        Unused when the service is serialized.
    recreate : bool
        A flag specifying whether or not to drop and re-create the VDB resources.
    resource_schemas : dict
        Defines the schemas of the VDB resources (collections) that must exist.

    Returns
    -------
    VectorDBService or str
        If a connection is established, a `VectorDBService` instance is returned. Otherwise the
        original ``service`` string is returned so the connection can be attempted again.
    bool
        A flag used to signify the successful instantiation of a `VectorDBService`.
    """

    service_str = service

    try:
        # SECURITY: pickle.loads can execute arbitrary code embedded in the payload. The serialized
        # service must come only from trusted pipeline configuration, never from untrusted input.
        service: VectorDBService = (
            pickle.loads(bytes(service, "latin1"))
            if is_service_serialized
            else VectorDBServiceFactory.create_instance(service_name=service, **service_kwargs)
        )
        _preprocess_vdb_resources(service, recreate, resource_schemas)

        return service, True

    except Exception as e:
        logger.error("Failed to connect to %s: %s", service_str, e)

        return service_str, False


@dataclass
class AccumulationStats:
    """
    Per-resource buffer that lets database inserts be batched dynamically.

    Attributes
    ----------
    msg_count : int
        Number of rows buffered since the last insert.
    last_insert_time : float
        Time of the most recent insert; a value of -1 disables the time-based flush check.
    data : list[pd.DataFrame]
        DataFrames buffered since the last insert; they are concatenated when flushed.
    """

    # Rows currently waiting to be written.
    msg_count: int
    # Timestamp of the latest write (-1 => never flush on time).
    last_insert_time: float
    # Pending frames, in arrival order.
    data: list[pd.DataFrame]
def _extract_dataframe_from_control_message(
    ctrl_msg: IngestControlMessage, filter_errors: bool
) -> Tuple[pd.DataFrame, Optional[str]]:
    """
    Extracts a DataFrame from the control message and applies filtering to remove error messages.

    When ``filter_errors`` is set, rows whose ``metadata.info_message_metadata.filter`` is true are
    dropped. Only rows flagged by ``_contains_embeddings`` are kept. They are returned with the
    columns ``vector``, ``text``, ``source`` and ``content_metadata``.

    NOTE(review): assumes ``metadata`` is a struct-typed column supporting the ``.struct`` accessor
    (e.g. pyarrow-backed) — TODO confirm against the producer. When ``filter_errors`` is False no
    copy is made before adding columns. If ``payload()`` returns the live frame, this mutates the
    message payload.

    Returns a tuple of the processed DataFrame and an optional resource name (always None in this case).
    """
    df_payload = ctrl_msg.payload()
    if filter_errors:
        info_msg_mask = df_payload["metadata"].struct.field("info_message_metadata").struct.field("filter")
        df_payload = df_payload.loc[~info_msg_mask].copy()

    # Extract necessary fields from metadata.
    df_payload["embedding"] = df_payload["metadata"].struct.field("embedding")
    df_payload["_source_metadata"] = df_payload["metadata"].struct.field("source_metadata")
    df_payload["_content_metadata"] = df_payload["metadata"].struct.field("content_metadata")

    # Filter rows that contain embeddings and select columns.
    df = df_payload[df_payload["_contains_embeddings"]].copy()
    df = df[
        [
            "embedding",
            "_content",
            "_source_metadata",
            "_content_metadata",
        ]
    ]
    df.columns = ["vector", "text", "source", "content_metadata"]

    return df, None


def _update_accumulator_and_flush(
    df: pd.DataFrame,
    resource_name: Optional[str],
    accumulator_dict: Dict[str, AccumulationStats],
    batch_size: int,
    write_time_interval: float,
    service: VectorDBService,
    resource_kwargs: dict,
    ctrl_msg: IngestControlMessage,
    default_resource_name: str,
) -> None:
    """
    Updates the accumulator for the given resource with the new DataFrame.

    Flushes (inserts) data if the batch size or time interval criteria are met. Every accumulator
    is checked, not only the one just updated. The outcome is recorded on ``ctrl_msg`` under the
    ``insert_response`` metadata key.

    Raises
    ------
    ValueError
        If the target resource does not exist in the vector database.
    """
    if resource_name is None:
        resource_name = default_resource_name

    if not service.has_store_object(resource_name):
        logger.error("Resource not exists in the vector database: %s", resource_name)
        raise ValueError(f"Resource not exists in the vector database: {resource_name}")

    if resource_name in accumulator_dict:
        accumulator = accumulator_dict[resource_name]
        accumulator.msg_count += len(df)
        accumulator.data.append(df)
    else:
        accumulator_dict[resource_name] = AccumulationStats(msg_count=len(df), last_insert_time=time.time(), data=[df])

    current_time = time.time()
    for key, accum_stats in accumulator_dict.items():
        # Flush when enough rows are buffered, or when the write interval has elapsed
        # (a last_insert_time of -1 disables the time-based trigger).
        if accum_stats.msg_count >= batch_size or (
            accum_stats.last_insert_time != -1 and (current_time - accum_stats.last_insert_time) >= write_time_interval
        ):
            if accum_stats.data:
                merged_df = pd.concat(accum_stats.data, ignore_index=True)
                service.insert_dataframe(name=key, df=merged_df, **resource_kwargs)
                accum_stats.data.clear()
                accum_stats.last_insert_time = current_time
                accum_stats.msg_count = 0
            ctrl_msg.set_metadata(
                "insert_response",
                {
                    "status": "inserted",
                    "accum_count": 0,
                    "insert_count": len(df),
                    "succ_count": len(df),
                    "err_count": 0,
                },
            )
        else:
            logger.debug("Accumulated %d rows for collection: %s", accum_stats.msg_count, key)
            ctrl_msg.set_metadata(
                "insert_response",
                {
                    "status": "accumulated",
                    "accum_count": len(df),
                    "insert_count": 0,
                    "succ_count": 0,
                    "err_count": 0,
                },
            )


def _finalize_vector_db_service(
    accumulator_dict: Dict[str, AccumulationStats],
    service: VectorDBService,
    resource_kwargs: Optional[dict] = None,
) -> None:
    """
    Flushes any remaining accumulated data to the vector database and closes the service connection.

    Parameters
    ----------
    accumulator_dict : Dict[str, AccumulationStats]
        Per-resource buffers still holding unflushed DataFrames.
    service : VectorDBService or str
        The service to write to. If the connection was never established this is the service
        string returned by ``_create_vdb_service``; the insert attempts then fail and are logged,
        and ``close()`` is skipped.
    resource_kwargs : dict, optional
        Keyword arguments forwarded to ``insert_dataframe``, the same ones the regular flush path
        uses. ``None`` means no extra arguments.
    """
    resource_kwargs = resource_kwargs or {}

    for key, accum_stats in accumulator_dict.items():
        # Best effort at shutdown: one failing resource must not prevent flushing the others.
        try:
            if accum_stats.data:
                merged_df = pd.concat(accum_stats.data, ignore_index=True)
                service.insert_dataframe(name=key, df=merged_df, **resource_kwargs)
        except Exception as e:
            logger.error("Unable to upload dataframe entries to vector database: %s", e)

    if isinstance(service, VectorDBService):
        service.close()


def _process_control_message_data(
    ctrl_msg: IngestControlMessage,
    service: VectorDBService,
    accumulator_dict: Dict[str, AccumulationStats],
    default_resource_name: str,
    batch_size: int,
    write_time_interval: float,
    resource_kwargs: dict,
    service_kwargs: dict,
    filter_errors: bool,
) -> IngestControlMessage:
    """
    Processes the control message for data ingestion.

    This is the single place where the message's "vdb_upload" task is removed and its properties
    read. If bulk ingestion is enabled, the work is delegated to ``_bulk_ingest``. Otherwise the
    DataFrame is extracted from the message and the accumulator is updated/flushed accordingly.
    The task's ``filter_errors`` property, when present, overrides the ``filter_errors`` argument.
    """
    task_props = remove_task_by_type(ctrl_msg, "vdb_upload")
    bulk_ingest = task_props.get("bulk_ingest", False)
    bulk_ingest_path = task_props.get("bulk_ingest_path", None)
    bucket_name = task_props.get("bucket_name", _DEFAULT_BUCKET_NAME)
    extra_params = task_props.get("params", {})
    filter_errors = task_props.get("filter_errors", filter_errors)

    if bulk_ingest:
        _bulk_ingest(service_kwargs["uri"], default_resource_name, bucket_name, bulk_ingest_path, extra_params)
        return ctrl_msg

    df, msg_resource_target = _extract_dataframe_from_control_message(ctrl_msg, filter_errors)
    if df is not None and not df.empty:
        # Ensure that df is a pandas DataFrame.
        if not isinstance(df, pd.DataFrame):
            df = pd.DataFrame(df)
        _update_accumulator_and_flush(
            df,
            msg_resource_target,
            accumulator_dict,
            batch_size,
            write_time_interval,
            service,
            resource_kwargs,
            ctrl_msg,
            default_resource_name,
        )

    return ctrl_msg


@register_module(MODULE_NAME, MODULE_NAMESPACE)
def _vdb_task_sink(builder: mrc.Builder):
    """
    Receives incoming messages in IngestControlMessage format and writes data to a vector database.

    The module configuration (validated using VdbTaskSinkSchema) should include various parameters controlling
    resource creation, batching, write intervals, and retry intervals.
    """
    validated_config = fetch_and_validate_module_config(builder, VdbTaskSinkSchema)

    recreate = validated_config.recreate
    service = validated_config.service
    is_service_serialized = validated_config.is_service_serialized
    default_resource_name = validated_config.default_resource_name
    resource_kwargs = validated_config.resource_kwargs
    resource_schemas = validated_config.resource_schemas
    service_kwargs = validated_config.service_kwargs
    batch_size = validated_config.batch_size
    write_time_interval = validated_config.write_time_interval
    retry_interval = validated_config.retry_interval

    start_time = time.time()

    service, service_status = _create_vdb_service(
        service, is_service_serialized, service_kwargs, recreate, resource_schemas
    )

    accumulator_dict = {default_resource_name: AccumulationStats(msg_count=0, last_insert_time=time.time(), data=[])}

    # on_completed callback
    def on_completed():
        # Reads the enclosing `service` at call time, so it sees any reconnection made in on_data.
        _finalize_vector_db_service(accumulator_dict, service, resource_kwargs)

    @filter_by_task(["vdb_upload"])
    @traceable(MODULE_NAME)
    @cm_skip_processing_if_failed
    @nv_ingest_node_failure_context_manager(
        annotation_id=MODULE_NAME,
        raise_on_failure=validated_config.raise_on_failure,
    )
    def on_data(ctrl_msg: IngestControlMessage):
        nonlocal service_status, start_time, service

        try:
            # The "vdb_upload" task is removed and read inside _process_control_message_data;
            # it must not be removed here, or its properties would never reach that function.

            # Reconnect service if necessary, at most once per retry_interval.
            if not service_status:
                curr_time = time.time()
                if curr_time - start_time >= retry_interval:
                    service, service_status = _create_vdb_service(
                        service, is_service_serialized, service_kwargs, recreate, resource_schemas
                    )
                    start_time = curr_time

            if not service_status:
                logger.error("Not connected to vector database.")
                raise ValueError("Not connected to vector database")

            ctrl_msg = _process_control_message_data(
                ctrl_msg,
                service,
                accumulator_dict,
                default_resource_name,
                batch_size,
                write_time_interval,
                resource_kwargs,
                service_kwargs,
                filter_errors=True,
            )
        except Exception as e:
            raise ValueError(f"Failed to insert upload to vector database: {e}") from e

        return ctrl_msg

    node = builder.make_node(
        WRITE_TO_VECTOR_DB,
        ops.map(on_data),
        ops.filter(lambda val: val is not None),
        ops.on_completed(on_completed),
    )
    node.launch_options.engines_per_pe = validated_config.progress_engines

    builder.register_module_input("input", node)
    builder.register_module_output("output", node)