Source code for nv_ingest.modules.sinks.vdb_task_sink

# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import logging
import os
import pickle
import time
from dataclasses import dataclass
from urllib.parse import urlparse

import mrc
from minio import Minio
from morpheus.utils.control_message_utils import cm_skip_processing_if_failed
from morpheus.utils.module_ids import WRITE_TO_VECTOR_DB
from morpheus.utils.module_utils import ModuleLoaderFactory
from morpheus.utils.module_utils import register_module
from morpheus_llm.service.vdb.milvus_client import DATA_TYPE_MAP
from morpheus_llm.service.vdb.utils import VectorDBServiceFactory
from morpheus_llm.service.vdb.vector_db_service import VectorDBService
from mrc.core import operators as ops
from pymilvus import BulkInsertState
from pymilvus import connections
from pymilvus import utility

import cudf

from nv_ingest.schemas.vdb_task_sink_schema import VdbTaskSinkSchema
from nv_ingest.util.exception_handlers.decorators import nv_ingest_node_failure_context_manager
from nv_ingest.util.flow_control import filter_by_task
from nv_ingest.util.modules.config_validator import fetch_and_validate_module_config
from nv_ingest.util.tracing import traceable
from nv_ingest_api.primitives.ingest_control_message import IngestControlMessage, remove_task_by_type

logger = logging.getLogger(__name__)

MODULE_NAME = "vdb_task_sink"
MODULE_NAMESPACE = "nv_ingest"

VDBTaskSinkLoaderFactory = ModuleLoaderFactory(MODULE_NAME, MODULE_NAMESPACE, VdbTaskSinkSchema)
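# Usage sketch (illustrative, not executed here): a loader for this sink is typically obtained
# from the factory and wired into a Morpheus pipeline. The instance name and config values
# below are hypothetical; valid keys are defined by ``VdbTaskSinkSchema``.
#
#   vdb_sink_loader = VDBTaskSinkLoaderFactory.get_instance(
#       "vdb_task_sink",
#       module_config={"service": "milvus", "default_resource_name": "nv_ingest_collection"},
#   )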

_DEFAULT_ENDPOINT = os.environ.get("MINIO_INTERNAL_ADDRESS", "minio:9000")
_DEFAULT_BUCKET_NAME = os.environ.get("MINIO_BUCKET", "nv-ingest")


def _bulk_ingest(
    milvus_uri: str = None,
    collection_name: str = None,
    bucket_name: str = None,
    bulk_ingest_path: str = None,
    extra_params: dict = None,
):
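    """
    Bulk-load files previously staged in a MinIO bucket into a Milvus collection, blocking until
    the bulk insert tasks finish and the collection's index build reports completion.
    """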

    endpoint = extra_params.get("endpoint", _DEFAULT_ENDPOINT)
    access_key = extra_params.get("access_key", None)
    secret_key = extra_params.get("secret_key", None)

    client = Minio(
        endpoint,
        access_key=access_key,
        secret_key=secret_key,
        session_token=extra_params.get("session_token", None),
        secure=extra_params.get("secure", False),
        region=extra_params.get("region", None),
    )
    bucket_found = client.bucket_exists(bucket_name)
    if not bucket_found:
        raise ValueError(f"Could not find bucket {bucket_name}")
    batch_files = [
        [f"{file.object_name}"] for file in client.list_objects(bucket_name, prefix=bulk_ingest_path, recursive=True)
    ]

    uri_parsed = urlparse(milvus_uri)
    _ = connections.connect(host=uri_parsed.hostname, port=uri_parsed.port)

    task_ids = []
    for file in batch_files:
        task_id = utility.do_bulk_insert(collection_name=collection_name, files=file)
        task_ids.append(task_id)

    while len(task_ids) > 0:
        logger.debug("Wait 1 second to check bulk insert tasks state...")
        time.sleep(1)
        # Iterate over a copy so finished tasks can be removed without skipping entries,
        # and avoid shadowing the built-in `id`.
        for task_id in list(task_ids):
            state = utility.get_bulk_insert_state(task_id=task_id)
            if state.state in (BulkInsertState.ImportFailed, BulkInsertState.ImportFailedAndCleaned):
                logger.error(f"The task {state.task_id} failed, reason: {state.failed_reason}")
                task_ids.remove(task_id)
            elif state.state == BulkInsertState.ImportCompleted:
                logger.debug(f"The task {state.task_id} completed")
                task_ids.remove(task_id)

    while True:
        progress = utility.index_building_progress(collection_name)
        logger.info(progress)
        if progress.get("total_rows") == progress.get("indexed_rows"):
            break
        time.sleep(5)

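# Illustrative call (all values hypothetical): files staged under the ``bulk/`` prefix of the
# ``nv-ingest`` MinIO bucket are imported into the ``nv_ingest_collection`` Milvus collection.
#
#   _bulk_ingest(
#       milvus_uri="http://milvus:19530",
#       collection_name="nv_ingest_collection",
#       bucket_name="nv-ingest",
#       bulk_ingest_path="bulk/",
#       extra_params={"access_key": "minioadmin", "secret_key": "minioadmin"},
#   )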


def preprocess_vdb_resources(service, recreate: bool, resource_schemas: dict):
    for resource_name, resource_schema_config in resource_schemas.items():
        has_object = service.has_store_object(name=resource_name)
        if recreate and has_object:
            # Delete the existing resource
            service.drop(name=resource_name)
            has_object = False

        # Ensure that the resource exists
        if not has_object:
            # TODO(Devin)
            import pymilvus

            schema_fields = []
            for field_data in resource_schema_config["schema_conf"]["schema_fields"]:
                if "dtype" in field_data:
                    field_data["dtype"] = DATA_TYPE_MAP.get(field_data["dtype"])
                    field_schema = pymilvus.FieldSchema(**field_data)
                    schema_fields.append(field_schema.to_dict())
                else:
                    schema_fields.append(field_data)
            resource_schema_config["schema_conf"]["schema_fields"] = schema_fields

            # Call create() first to turn the resource schema config into a Milvus config spec.
            service.create(name=resource_name, **resource_schema_config)
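
# Sketch of the ``resource_schemas`` layout consumed by ``preprocess_vdb_resources`` (field names,
# dimensions, and dtype strings are hypothetical; ``dtype`` values must be keys of ``DATA_TYPE_MAP``,
# and the remaining field keys are forwarded to ``pymilvus.FieldSchema``):
#
#   {
#       "nv_ingest_collection": {
#           "schema_conf": {
#               "schema_fields": [
#                   {"name": "pk", "dtype": "INT64", "is_primary": True, "auto_id": True},
#                   {"name": "vector", "dtype": "FLOAT_VECTOR", "dim": 1024},
#                   {"name": "text", "dtype": "VARCHAR", "max_length": 65535},
#               ],
#           },
#       },
#   }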


def _create_vdb_service(
    service: str, is_service_serialized: bool, service_kwargs: dict, recreate: bool, resource_schemas: dict
):
    """
    Instantiate a `VectorDBService` if a running VDB is available and a connection can be established.

    Parameters
    ----------
    service : str
        A string mapping to a supported `VectorDBService`.
    is_service_serialized : bool
        A flag identifying whether the supplied service is serialized or needs to be instantiated.
    service_kwargs : dict
        Additional parameters needed to connect to the specified `VectorDBService`.
    recreate : bool
        A flag specifying whether or not to re-instantiate the VDB collection.
    resource_schemas : dict
        Defines the schemas of the VDB collection.

    Returns
    -------
    VectorDBService or str
        If a connection is established, a `VectorDBService` instance is returned; otherwise, a string
        representing a supported VDB service is returned to allow repeat connection attempts.
    bool
        A flag signifying whether a `VectorDBService` was successfully instantiated.
    """
    service_str = service
    try:
        service: VectorDBService = (
            pickle.loads(bytes(service, "latin1"))
            if is_service_serialized
            else VectorDBServiceFactory.create_instance(service_name=service, **service_kwargs)
        )
        preprocess_vdb_resources(service, recreate, resource_schemas)

        return service, True
    except Exception as e:
        logger.error(f"Failed to connect to {service_str}: {e}")
        return service_str, False


@dataclass
class AccumulationStats:
    """
    A data class storing accumulation statistics to support dynamic batching of database inserts.

    Attributes
    ----------
    msg_count : int
        Total number of accumulated records.
    last_insert_time : float
        The time of the most recent database insert.
    data : list[cudf.DataFrame]
        A list containing the batches accumulated since the last database insert.
    """

    msg_count: int
    last_insert_time: float
    data: list[cudf.DataFrame]


@register_module(MODULE_NAME, MODULE_NAMESPACE)
def _vdb_task_sink(builder: mrc.Builder):
    """
    Receives incoming messages in IngestControlMessage format.

    Parameters
    ----------
    builder : mrc.Builder
        The Morpheus builder instance to attach this module to.

    Notes
    -----
    The `module_config` should contain:
    - 'recreate': bool, whether to recreate the resource if it already exists (default is False).
    - 'service': str, the name of the service or a serialized instance of VectorDBService.
    - 'is_service_serialized': bool, whether the provided service is serialized (default is False).
    - 'default_resource_name': str, the name of the collection resource (must not be None or empty).
    - 'resource_kwargs': dict, additional keyword arguments for resource creation.
    - 'resource_schemas': dict, schema definitions for the collection resources.
    - 'service_kwargs': dict, additional keyword arguments for VectorDBService creation.
    - 'batch_size': int, accumulates messages until the specified batch size is reached before writing to the VDB.
    - 'write_time_interval': float, the time interval (in seconds) at which accumulated messages are written,
      in addition to writes triggered when the accumulated batch size is reached.
    - 'retry_interval': float, the interval (in seconds) at which to retry connections to Milvus.

    Raises
    ------
    ValueError
        If 'default_resource_name' is None or empty.
        If 'service' is not provided or is not a valid service name or a serialized instance of VectorDBService.
    """
    validated_config = fetch_and_validate_module_config(builder, VdbTaskSinkSchema)

    recreate = validated_config.recreate
    service = validated_config.service
    is_service_serialized = validated_config.is_service_serialized
    default_resource_name = validated_config.default_resource_name
    resource_kwargs = validated_config.resource_kwargs
    resource_schemas = validated_config.resource_schemas
    service_kwargs = validated_config.service_kwargs
    batch_size = validated_config.batch_size
    write_time_interval = validated_config.write_time_interval
    retry_interval = validated_config.retry_interval
    start_time = time.time()

    service, service_status = _create_vdb_service(
        service, is_service_serialized, service_kwargs, recreate, resource_schemas
    )

    accumulator_dict = {default_resource_name: AccumulationStats(msg_count=0, last_insert_time=time.time(), data=[])}

    def on_completed():
        final_df_references = []

        # Push any remaining accumulated messages before shutdown.
        for key, accum_stats in accumulator_dict.items():
            try:
                if accum_stats.data:
                    merged_df = cudf.concat(accum_stats.data)
                    service.insert_dataframe(name=key, df=merged_df)
                    final_df_references.append(accum_stats.data)
            except Exception as e:
                logger.error("Unable to upload dataframe entries to vector database: %s", e)

        # Close the vector database service connection.
        if isinstance(service, VectorDBService):
            service.close()

    def extract_df(ctrl_msg: IngestControlMessage, filter_errors: bool):
        df = None
        resource_name = None

        mdf = ctrl_msg.payload()

        if filter_errors:
            info_msg_mask = mdf["metadata"].struct.field("info_message_metadata").struct.field("filter")
            mdf = mdf.loc[~info_msg_mask].copy()

        mdf["embedding"] = mdf["metadata"].struct.field("embedding")
        mdf["_source_metadata"] = mdf["metadata"].struct.field("source_metadata")
        mdf["_content_metadata"] = mdf["metadata"].struct.field("content_metadata")
        df = mdf[mdf["_contains_embeddings"]].copy()
        df = df[
            [
                "embedding",
                "_content",
                "_source_metadata",
                "_content_metadata",
            ]
        ]
        df.columns = ["vector", "text", "source", "content_metadata"]

        return df, resource_name

    @filter_by_task(["vdb_upload"])
    @traceable(MODULE_NAME)
    @cm_skip_processing_if_failed
    @nv_ingest_node_failure_context_manager(
        annotation_id=MODULE_NAME,
        raise_on_failure=validated_config.raise_on_failure,
    )
    def on_data(ctrl_msg: IngestControlMessage):
        nonlocal service_status
        nonlocal start_time
        nonlocal service

        try:
            task_props = remove_task_by_type(ctrl_msg, "vdb_upload")

            bulk_ingest = task_props.get("bulk_ingest", False)
            bulk_ingest_path = task_props.get("bulk_ingest_path", None)
            bucket_name = task_props.get("bucket_name", _DEFAULT_BUCKET_NAME)
            extra_params = task_props.get("params", {})
            filter_errors = task_props.get("filter_errors", True)

            if not service_status:
                curr_time = time.time()
                delta_t = curr_time - start_time
                if delta_t >= retry_interval:
                    service, service_status = _create_vdb_service(
                        service, is_service_serialized, service_kwargs, recreate, resource_schemas
                    )
                    start_time = curr_time
                if not service_status:
                    logger.error("Not connected to vector database.")
                    raise ValueError("Not connected to vector database")

            if bulk_ingest:
                _bulk_ingest(service_kwargs["uri"], default_resource_name, bucket_name, bulk_ingest_path, extra_params)
            else:
                df, msg_resource_target = extract_df(ctrl_msg, filter_errors)

                if df is not None and not df.empty:
                    if not isinstance(df, cudf.DataFrame):
                        df = cudf.DataFrame(df)

                    df_size = len(df)
                    current_time = time.time()

                    # Fall back to the default resource name when none is specified.
                    if not msg_resource_target:
                        msg_resource_target = default_resource_name

                    if not service.has_store_object(msg_resource_target):
                        logger.error("Resource does not exist in the vector database: %s", msg_resource_target)
                        raise ValueError(f"Resource does not exist in the vector database: {msg_resource_target}")

                    if msg_resource_target in accumulator_dict:
                        accumulator: AccumulationStats = accumulator_dict[msg_resource_target]
                        accumulator.msg_count += df_size
                        accumulator.data.append(df)
                    else:
                        accumulator_dict[msg_resource_target] = AccumulationStats(
                            msg_count=df_size, last_insert_time=time.time(), data=[df]
                        )

                    for key, accum_stats in accumulator_dict.items():
                        if accum_stats.msg_count >= batch_size or (
                            accum_stats.last_insert_time != -1
                            and (current_time - accum_stats.last_insert_time) >= write_time_interval
                        ):
                            if accum_stats.data:
                                merged_df = cudf.concat(accum_stats.data)
                                # pylint: disable=not-a-mapping
                                service.insert_dataframe(name=key, df=merged_df, **resource_kwargs)

                                # Reset accumulator stats
                                accum_stats.data.clear()
                                accum_stats.last_insert_time = current_time
                                accum_stats.msg_count = 0

                                if isinstance(ctrl_msg, IngestControlMessage):
                                    ctrl_msg.set_metadata(
                                        "insert_response",
                                        {
                                            "status": "inserted",
                                            "accum_count": 0,
                                            "insert_count": df_size,
                                            "succ_count": df_size,
                                            "err_count": 0,
                                        },
                                    )
                        else:
                            logger.debug("Accumulated %d rows for collection: %s", accum_stats.msg_count, key)
                            if isinstance(ctrl_msg, IngestControlMessage):
                                ctrl_msg.set_metadata(
                                    "insert_response",
                                    {
                                        "status": "accumulated",
                                        "accum_count": df_size,
                                        "insert_count": 0,
                                        "succ_count": 0,
                                        "err_count": 0,
                                    },
                                )
        except Exception as e:
            raise ValueError(f"Failed to insert upload to vector database: {e}")

        return ctrl_msg

    node = builder.make_node(
        WRITE_TO_VECTOR_DB, ops.map(on_data), ops.filter(lambda val: val is not None), ops.on_completed(on_completed)
    )
    node.launch_options.engines_per_pe = validated_config.progress_engines

    builder.register_module_input("input", node)
    builder.register_module_output("output", node)
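
# For reference, the ``vdb_upload`` task properties consumed by ``on_data`` follow roughly this
# shape (values are hypothetical; keys and defaults mirror the ``task_props.get`` calls above):
#
#   {
#       "bulk_ingest": False,        # when True, import pre-staged MinIO files instead of streaming inserts
#       "bulk_ingest_path": None,    # MinIO prefix holding the staged files for bulk ingest
#       "bucket_name": "nv-ingest",  # defaults to _DEFAULT_BUCKET_NAME
#       "params": {},                # extra MinIO connection parameters (access_key, secret_key, ...)
#       "filter_errors": True,       # drop rows flagged by info_message_metadata.filter
#   }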