import datetime
import logging
import time
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
from urllib.parse import urlparse
from pathlib import Path
import pandas as pd
from functools import partial
import json
import os
import requests
from nv_ingest_client.util.process_json_files import ingest_json_results_to_blob
from nv_ingest_client.util.util import ClientConfigSchema
from pymilvus import AnnSearchRequest
from pymilvus import BulkInsertState
from pymilvus import Collection
from pymilvus import CollectionSchema
from pymilvus import DataType
from pymilvus import Function
from pymilvus import FunctionType
from pymilvus import MilvusClient
from pymilvus import RRFRanker
from pymilvus import connections
from pymilvus import utility
from pymilvus.bulk_writer import BulkFileType
from pymilvus.bulk_writer import RemoteBulkWriter
from pymilvus.milvus_client.index import IndexParams
from pymilvus.model.sparse import BM25EmbeddingFunction
from pymilvus.model.sparse.bm25.tokenizers import build_default_analyzer
from pymilvus.orm.types import CONSISTENCY_STRONG
from scipy.sparse import csr_array
from nv_ingest_client.util.transport import infer_microservice
from nv_ingest_client.util.vdb.adt_vdb import VDB
logger = logging.getLogger(__name__)
CONSISTENCY = CONSISTENCY_STRONG
pandas_reader_map = {
".json": pd.read_json,
".csv": partial(pd.read_csv, index_col=0),
".parquet": pd.read_parquet,
".pq": pd.read_parquet,
}
def pandas_file_reader(input_file: str):
path_file = Path(input_file)
if not path_file.exists():
raise ValueError(f"File does not exist: {input_file}")
file_type = path_file.suffix
return pandas_reader_map[file_type](input_file)
def _dict_to_params(collections_dict: dict, write_params: dict):
params_tuple_list = []
for coll_name, data_type in collections_dict.items():
cp_write_params = write_params.copy()
enabled_dtypes = {
"enable_text": False,
"enable_charts": False,
"enable_tables": False,
"enable_images": False,
"enable_infographics": False,
}
if not isinstance(data_type, list):
data_type = [data_type]
for d_type in data_type:
enabled_dtypes[f"enable_{d_type}"] = True
cp_write_params.update(enabled_dtypes)
params_tuple_list.append((coll_name, cp_write_params))
return params_tuple_list
def create_nvingest_schema(dense_dim: int = 1024, sparse: bool = False, local_index: bool = False) -> CollectionSchema:
"""
Creates a schema for nv-ingest produced data. This is currently set up to follow
the default expected schema fields in nv-ingest. You can see more about the declared fields
in the `nv_ingest.schemas.vdb_task_sink_schema.build_default_milvus_config` function. This
schema should contain, at a minimum, the fields declared in that function to ensure proper
data propagation to milvus.
Parameters
----------
dense_dim : int, optional
The size of the embedding dimension.
sparse : bool, optional
When set to true, this adds a sparse field to the schema, usually activated for
hybrid search.
local_index : bool, optional
When set to true (e.g. for milvus-lite), the sparse field is added without the
Milvus-managed BM25 function, so sparse embeddings are computed client-side.
Returns
-------
CollectionSchema
Returns a milvus collection schema, with the minimum required nv-ingest fields
and extra fields (sparse), if specified by the user.
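Examples
--------
A minimal sketch of building a hybrid-search schema for a server-side Milvus
deployment (no running Milvus instance is needed just to construct the schema object):
>>> schema = create_nvingest_schema(dense_dim=1024, sparse=True, local_index=False)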
"""
schema = MilvusClient.create_schema(auto_id=True, enable_dynamic_field=True)
schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True, auto_id=True)
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dense_dim)
schema.add_field(field_name="source", datatype=DataType.JSON)
schema.add_field(field_name="content_metadata", datatype=DataType.JSON)
if sparse and local_index:
schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
elif sparse:
schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
schema.add_field(
field_name="text",
datatype=DataType.VARCHAR,
max_length=65535,
enable_analyzer=True,
analyzer_params={"type": "english"},
enable_match=True,
)
schema.add_function(
Function(
name="bm25",
function_type=FunctionType.BM25,
input_field_names=["text"],
output_field_names="sparse",
)
)
else:
schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535)
return schema
def create_nvingest_index_params(
sparse: bool = False, gpu_index: bool = True, gpu_search: bool = True, local_index: bool = True
) -> IndexParams:
"""
Creates index params necessary to create an index for a collection. At a minimum,
this function will create a dense embedding index but can also create a sparse
embedding index (BM25) for hybrid search.
Parameters
----------
sparse : bool, optional
When set to true, this adds a Sparse index to the IndexParams, usually activated for
hybrid search.
gpu_index : bool, optional
When set to true, creates an index on the GPU. The index is GPU_CAGRA.
gpu_search : bool, optional
When set to true, if using a gpu index, the search will be conducted using the GPU.
Otherwise the search will be conducted on the CPU (index will be turned into HNSW).
local_index : bool, optional
When set to true (e.g. for milvus-lite), a FLAT dense index is created and any sparse
index uses an inverted index with the IP metric.
Returns
-------
IndexParams
Returns index params setup for a dense embedding index and if specified, a sparse
embedding index.
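Examples
--------
A minimal sketch creating CPU-friendly index params for a hybrid-search collection
(GPU indexing disabled; a Milvus server, not milvus-lite, is assumed):
>>> index_params = create_nvingest_index_params(sparse=True, gpu_index=False, gpu_search=False, local_index=False)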
"""
index_params = MilvusClient.prepare_index_params()
if local_index:
index_params.add_index(
field_name="vector",
index_name="dense_index",
index_type="FLAT",
metric_type="L2",
)
else:
if gpu_index:
index_params.add_index(
field_name="vector",
index_name="dense_index",
index_type="GPU_CAGRA",
metric_type="L2",
params={
"intermediate_graph_degree": 128,
"graph_degree": 100,
"build_algo": "NN_DESCENT",
"adapt_for_cpu": "false" if gpu_search else "true",
},
)
else:
index_params.add_index(
field_name="vector",
index_name="dense_index",
index_type="HNSW",
metric_type="L2",
params={"M": 64, "efConstruction": 512},
)
if sparse and local_index:
index_params.add_index(
field_name="sparse",
index_name="sparse_index",
index_type="SPARSE_INVERTED_INDEX", # Index type for sparse vectors
metric_type="IP", # Currently, only IP (Inner Product) is supported for sparse vectors
params={"drop_ratio_build": 0.2}, # The ratio of small vector values to be dropped during indexing
)
elif sparse:
index_params.add_index(
field_name="sparse",
index_name="sparse_index",
index_type="SPARSE_INVERTED_INDEX",
metric_type="BM25",
)
return index_params
def create_collection(
client: MilvusClient,
collection_name: str,
schema: CollectionSchema,
index_params: IndexParams = None,
recreate=True,
):
"""
Creates a milvus collection with the supplied name and schema. Within that collection,
this function ensures that the desired indexes are created based on the IndexParams
supplied.
Parameters
----------
client : MilvusClient
Client connected to milvus instance.
collection_name : str
Name of the collection to be created.
schema : CollectionSchema,
Schema that identifies the fields of data that will be available in the collection.
index_params : IndexParams, optional
The parameters used to create the index(es) for the associated collection fields.
recreate : bool, optional
If true, and the collection is detected, it will be dropped before being created
again with the provided information (schema, index_params).
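Examples
--------
A minimal sketch, assuming a Milvus instance is reachable at the default local URI:
>>> client = MilvusClient("http://localhost:19530")
>>> schema = create_nvingest_schema(dense_dim=1024)
>>> index_params = create_nvingest_index_params(gpu_index=False, local_index=False)
>>> create_collection(client, "my_collection", schema, index_params, recreate=True)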
"""
if recreate and client.has_collection(collection_name):
client.drop_collection(collection_name)
if not client.has_collection(collection_name):
client.create_collection(
collection_name=collection_name,
schema=schema,
index_params=index_params,
consistency_level=CONSISTENCY,
)
def create_nvingest_collection(
collection_name: str,
milvus_uri: str = "http://localhost:19530",
sparse: bool = False,
recreate: bool = True,
gpu_index: bool = True,
gpu_search: bool = True,
dense_dim: int = 2048,
recreate_meta: bool = False,
) -> CollectionSchema:
"""
Creates a milvus collection with an nv-ingest compatible schema under
the target name.
Parameters
----------
collection_name : str
Name of the collection to be created.
milvus_uri : str,
Milvus address with http(s) prefix and port. Can also be a file path, to activate
milvus-lite.
sparse : bool, optional
When set to true, this adds a Sparse index to the IndexParams, usually activated for
hybrid search.
recreate : bool, optional
If true, and the collection is detected, it will be dropped before being created
again with the provided information (schema, index_params).
gpu_index : bool, optional
If true, creates a GPU_CAGRA index for dense embeddings.
gpu_search : bool, optional
If true, and gpu_index is enabled, searches are conducted on the GPU.
dense_dim : int, optional
Sets the dimension size for the dense embedding in the milvus schema.
Returns
-------
CollectionSchema
Returns a milvus collection schema, that represents the fields in the created
collection.
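Examples
--------
A minimal sketch, assuming a local milvus-lite database file is acceptable:
>>> schema = create_nvingest_collection("my_collection", milvus_uri="milvus_demo.db", sparse=False, dense_dim=2048)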
"""
local_index = False
if urlparse(milvus_uri).scheme:
connections.connect(uri=milvus_uri)
server_version = utility.get_server_version()
if "lite" in server_version:
gpu_index = False
else:
gpu_index = False
if milvus_uri.endswith(".db"):
local_index = True
client = MilvusClient(milvus_uri)
schema = create_nvingest_schema(dense_dim=dense_dim, sparse=sparse, local_index=local_index)
index_params = create_nvingest_index_params(
sparse=sparse, gpu_index=gpu_index, gpu_search=gpu_search, local_index=local_index
)
create_collection(client, collection_name, schema, index_params, recreate=recreate)
d_idx, s_idx = _get_index_types(index_params, sparse=sparse)
log_new_meta_collection(
collection_name,
fields=schema.fields,
milvus_uri=milvus_uri,
dense_index=str(d_idx),
dense_dim=dense_dim,
sparse_index=str(s_idx),
recreate=recreate_meta,
)
return schema
def _get_index_types(index_params: IndexParams, sparse: bool = False) -> Tuple[str, str]:
"""
Returns the dense and optional sparse index types from Milvus index_params,
handling both old (dict) and new (list) formats.
Parameters:
index_params: The index parameters object with a _indexes attribute.
sparse (bool): Whether to look for sparse_index as well.
Returns:
tuple: (dense_index_type, sparse_index_type or None)
"""
d_idx = None
s_idx = None
indexes = getattr(index_params, "_indexes", None)
if indexes is None:
indexes = {(idx, index_param.index_name): index_param for idx, index_param in enumerate(index_params)}
if isinstance(indexes, dict):
# Old Milvus behavior (< 2.5.6)
for k, v in indexes.items():
if k[1] == "dense_index" and hasattr(v, "_index_type"):
d_idx = v._index_type
if sparse and k[1] == "sparse_index" and hasattr(v, "_index_type"):
s_idx = v._index_type
elif isinstance(indexes, list):
# New Milvus behavior (>= 2.5.6)
for idx in indexes:
index_name = getattr(idx, "index_name", None)
index_type = getattr(idx, "index_type", None)
if index_name == "dense_index":
d_idx = index_type
if sparse and index_name == "sparse_index":
s_idx = index_type
else:
raise TypeError(f"Unexpected type for index_params._indexes: {type(indexes)}")
return str(d_idx), str(s_idx)
def _format_sparse_embedding(sparse_vector: csr_array):
sparse_embedding = {int(k[1]): float(v) for k, v in sparse_vector.todok()._dict.items()}
return sparse_embedding if len(sparse_embedding) > 0 else {int(0): float(0)}
def _record_dict(text, element, sparse_vector: csr_array = None):
record = {
"text": text,
"vector": element["metadata"]["embedding"],
"source": element["metadata"]["source_metadata"],
"content_metadata": element["metadata"]["content_metadata"],
}
if sparse_vector is not None:
record["sparse"] = _format_sparse_embedding(sparse_vector)
return record
def verify_embedding(element):
if element["metadata"]["embedding"] is not None:
return True
return False
def cleanup_records(
records,
enable_text: bool = True,
enable_charts: bool = True,
enable_tables: bool = True,
enable_images: bool = True,
enable_infographics: bool = True,
enable_audio: bool = True,
meta_dataframe: pd.DataFrame = None,
meta_source_field: str = None,
meta_fields: list[str] = None,
record_func=_record_dict,
sparse_model=None,
):
cleaned_records = []
for result in records:
if result is not None:
if not isinstance(result, list):
result = [result]
for element in result:
text = _pull_text(
element, enable_text, enable_charts, enable_tables, enable_images, enable_infographics, enable_audio
)
_insert_location_into_content_metadata(
element, enable_charts, enable_tables, enable_images, enable_infographics
)
if meta_dataframe is not None and meta_source_field and meta_fields:
add_metadata(element, meta_dataframe, meta_source_field, meta_fields)
if text:
if sparse_model is not None:
element = record_func(text, element, sparse_model.encode_documents([text]))
else:
element = record_func(text, element)
cleaned_records.append(element)
return cleaned_records
def _pull_text(
element,
enable_text: bool,
enable_charts: bool,
enable_tables: bool,
enable_images: bool,
enable_infographics: bool,
enable_audio: bool,
):
text = None
if element["document_type"] == "text" and enable_text:
text = element["metadata"]["content"]
elif element["document_type"] == "structured":
text = element["metadata"]["table_metadata"]["table_content"]
if element["metadata"]["content_metadata"]["subtype"] == "chart" and not enable_charts:
text = None
elif element["metadata"]["content_metadata"]["subtype"] == "table" and not enable_tables:
text = None
elif element["metadata"]["content_metadata"]["subtype"] == "infographic" and not enable_infographics:
text = None
elif element["document_type"] == "image" and enable_images:
text = element["metadata"]["image_metadata"]["caption"]
elif element["document_type"] == "audio" and enable_audio:
text = element["metadata"]["audio_metadata"]["audio_transcript"]
verify_emb = verify_embedding(element)
if not text or not verify_emb:
source_name = element["metadata"]["source_metadata"]["source_name"]
pg_num = element["metadata"]["content_metadata"].get("page_number", None)
doc_type = element["document_type"]
if not verify_emb:
logger.debug(f"failed to find embedding for entity: {source_name} page: {pg_num} type: {doc_type}")
if not text:
logger.debug(f"failed to find text for entity: {source_name} page: {pg_num} type: {doc_type}")
# if we do find text but no embedding remove anyway
text = None
if text and len(text) > 65535:
logger.warning(
f"Text is too long, skipping. It is advised to use SplitTask, to make smaller chunk sizes."
f"text_length: {len(text)}, file_name: {element['metadata']['source_metadata'].get('source_name', None)} "
f"page_number: {element['metadata']['content_metadata'].get('page_number', None)}"
)
text = None
return text
def _insert_location_into_content_metadata(
element, enable_charts: bool, enable_tables: bool, enable_images: bool, enable_infographic: bool
):
location = max_dimensions = None
if element["document_type"] == "structured":
location = element["metadata"]["table_metadata"]["table_location"]
max_dimensions = element["metadata"]["table_metadata"]["table_location_max_dimensions"]
if element["metadata"]["content_metadata"]["subtype"] == "chart" and not enable_charts:
location = max_dimensions = None
elif element["metadata"]["content_metadata"]["subtype"] == "table" and not enable_tables:
location = max_dimensions = None
elif element["metadata"]["content_metadata"]["subtype"] == "infographic" and not enable_infographic:
location = max_dimensions = None
elif element["document_type"] == "image" and enable_images:
location = element["metadata"]["image_metadata"]["image_location"]
max_dimensions = element["metadata"]["image_metadata"]["image_location_max_dimensions"]
if (not location) and (element["document_type"] != "text"):
source_name = element["metadata"]["source_metadata"]["source_name"]
pg_num = element["metadata"]["content_metadata"].get("page_number")
doc_type = element["document_type"]
logger.info(f"failed to find location for entity: {source_name} page: {pg_num} type: {doc_type}")
location = max_dimensions = None
element["metadata"]["content_metadata"]["location"] = location
element["metadata"]["content_metadata"]["max_dimensions"] = max_dimensions
def write_records_minio(records, writer: RemoteBulkWriter) -> RemoteBulkWriter:
"""
Writes the supplied records to minio using the supplied writer. The records are
expected to already be cleaned and filtered (see `cleanup_records`); this function
simply appends each record to the writer and commits the batch.
Parameters
----------
records : List
List of cleaned chunks with attached metadata.
writer : RemoteBulkWriter
The Milvus Remote BulkWriter instance that was created with necessary
params to access the minio instance corresponding to milvus.
Returns
-------
RemoteBulkWriter
Returns the writer supplied, with information related to minio records upload.
"""
for element in records:
writer.append_row(element)
writer.commit()
print(f"Wrote data to: {writer.batch_files}")
return writer
def bulk_insert_milvus(collection_name: str, writer: RemoteBulkWriter, milvus_uri: str = "http://localhost:19530"):
"""
This function initializes the bulk insert of all minio uploaded records and waits for
the milvus import task to complete. Once the function returns, all records have been uploaded
to the milvus collection.
Parameters
----------
collection_name : str
Name of the milvus collection.
writer : RemoteBulkWriter
The Milvus Remote BulkWriter instance that was created with necessary
params to access the minio instance corresponding to milvus.
milvus_uri : str,
Milvus address with http(s) prefix and port. Can also be a file path, to activate
milvus-lite.
"""
connections.connect(uri=milvus_uri)
t_bulk_start = time.time()
task_id = utility.do_bulk_insert(
collection_name=collection_name, files=writer.batch_files[0], consistency_level=CONSISTENCY
)
# list_bulk_insert_tasks = utility.list_bulk_insert_tasks(collection_name=collection_name)
state = "Pending"
while state != "Completed":
task = utility.get_bulk_insert_state(task_id=task_id)
state = task.state_name
if state == "Completed":
t_bulk_end = time.time()
print("Start time:", task.create_time_str)
print("Imported row count:", task.row_count)
print(f"Bulk {collection_name} upload took {t_bulk_end - t_bulk_start} s")
if task.state == BulkInsertState.ImportFailed:
print("Failed reason:", task.failed_reason)
time.sleep(1)
def create_bm25_model(
records,
enable_text: bool = True,
enable_charts: bool = True,
enable_tables: bool = True,
enable_images: bool = True,
enable_infographics: bool = True,
enable_audio: bool = True,
) -> BM25EmbeddingFunction:
"""
This function takes the input records and creates a corpus,
factoring in filters (i.e. texts, charts, tables) and fits
a BM25 model with that information. If the user sets the log
level to info, any time a record fails ingestion, it will be
reported to the user.
Parameters
----------
records : List
List of chunks with attached metadata
enable_text : bool, optional
When true, ensure all text type records are used.
enable_charts : bool, optional
When true, ensure all chart type records are used.
enable_tables : bool, optional
When true, ensure all table type records are used.
enable_images : bool, optional
When true, ensure all image type records are used.
enable_infographics : bool, optional
When true, ensure all infographic type records are used.
enable_audio : bool, optional
When true, ensure all audio transcript type records are used.
Returns
-------
BM25EmbeddingFunction
Returns the model fitted to the selected corpus.
"""
all_text = []
for result in records:
if not isinstance(result, list):
result = [result]
for element in result:
text = _pull_text(
element, enable_text, enable_charts, enable_tables, enable_images, enable_infographics, enable_audio
)
if text:
all_text.append(text)
analyzer = build_default_analyzer(language="en")
bm25_ef = BM25EmbeddingFunction(analyzer)
bm25_ef.fit(all_text)
return bm25_ef
def stream_insert_milvus(records, client: MilvusClient, collection_name: str):
"""
This function takes the cleaned records and streams them into the
target milvus collection one record at a time using the supplied
client. The total number of inserted records is logged at info level.
Parameters
----------
records : List
List of chunks with attached metadata
client : MilvusClient
Milvus client instance
collection_name : str
Milvus Collection to search against
"""
count = 0
for element in records:
client.insert(collection_name=collection_name, data=[element])
count += 1
logger.info(f"streamed {count} records")
def write_to_nvingest_collection(
records,
collection_name: str,
milvus_uri: str = "http://localhost:19530",
minio_endpoint: str = "localhost:9000",
sparse: bool = True,
enable_text: bool = True,
enable_charts: bool = True,
enable_tables: bool = True,
enable_images: bool = True,
enable_infographics: bool = True,
bm25_save_path: str = "bm25_model.json",
compute_bm25_stats: bool = True,
access_key: str = "minioadmin",
secret_key: str = "minioadmin",
bucket_name: str = "a-bucket",
threshold: int = 1000,
meta_dataframe=None,
meta_source_field=None,
meta_fields=None,
stream: bool = False,
**kwargs,
):
"""
This function takes the input records, cleans and filters them based on the
enabled types (i.e. texts, charts, tables), optionally fits or loads a BM25 model
for sparse embeddings, and writes the results to the target milvus collection
via streaming or bulk insert.
Parameters
----------
records : List
List of chunks with attached metadata
collection_name : str
Milvus Collection to search against
milvus_uri : str,
Milvus address with http(s) prefix and port. Can also be a file path, to activate
milvus-lite.
minio_endpoint : str,
Endpoint for the minio instance attached to your milvus.
enable_text : bool, optional
When true, ensure all text type records are used.
enable_charts : bool, optional
When true, ensure all chart type records are used.
enable_tables : bool, optional
When true, ensure all table type records are used.
enable_images : bool, optional
When true, ensure all image type records are used.
enable_infographics : bool, optional
When true, ensure all infographic type records are used.
sparse : bool, optional
When true, incorporates sparse embedding representations for records.
bm25_save_path : str, optional
The desired filepath for the sparse model if sparse is True.
access_key : str, optional
Minio access key.
secret_key : str, optional
Minio secret key.
bucket_name : str, optional
Minio bucket name.
stream : bool, optional
When true, the records will be inserted into milvus using the stream insert method.
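Examples
--------
A minimal sketch, assuming `records` holds nv-ingest results with embeddings and a
Milvus instance (with its MinIO) is running on the default local endpoints:
>>> write_to_nvingest_collection(records, "my_collection", milvus_uri="http://localhost:19530", sparse=False)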
"""
local_index = False
connections.connect(uri=milvus_uri)
if urlparse(milvus_uri).scheme:
server_version = utility.get_server_version()
if "lite" in server_version:
stream = True
else:
stream = True
if milvus_uri.endswith(".db"):
local_index = True
bm25_ef = None
if local_index and sparse and compute_bm25_stats:
bm25_ef = create_bm25_model(
records,
enable_text=enable_text,
enable_charts=enable_charts,
enable_tables=enable_tables,
enable_images=enable_images,
enable_infographics=enable_infographics,
)
bm25_ef.save(bm25_save_path)
elif local_index and sparse:
bm25_ef = BM25EmbeddingFunction(build_default_analyzer(language="en"))
bm25_ef.load(bm25_save_path)
client = MilvusClient(milvus_uri)
schema = Collection(collection_name).schema
if isinstance(meta_dataframe, str):
meta_dataframe = pandas_file_reader(meta_dataframe)
cleaned_records = cleanup_records(
records,
enable_text=enable_text,
enable_charts=enable_charts,
enable_tables=enable_tables,
enable_images=enable_images,
enable_infographics=enable_infographics,
meta_dataframe=meta_dataframe,
meta_source_field=meta_source_field,
meta_fields=meta_fields,
sparse_model=bm25_ef,
)
num_elements = len(cleaned_records)
if num_elements == 0:
raise ValueError("No records with Embeddings to insert detected.")
logger.info(f"{num_elements} elements to insert to milvus")
logger.info(f"threshold for streaming is {threshold}")
if num_elements < threshold:
stream = True
if stream:
stream_insert_milvus(
cleaned_records,
client,
collection_name,
)
else:
# Connections parameters to access the remote bucket
conn = RemoteBulkWriter.S3ConnectParam(
endpoint=minio_endpoint, # the default MinIO service started along with Milvus
access_key=access_key,
secret_key=secret_key,
bucket_name=bucket_name,
secure=False,
)
text_writer = RemoteBulkWriter(
schema=schema, remote_path="/", connect_param=conn, file_type=BulkFileType.PARQUET
)
writer = write_records_minio(
cleaned_records,
text_writer,
)
bulk_insert_milvus(collection_name, writer, milvus_uri)
# fixes bulk insert lag time https://github.com/milvus-io/milvus/issues/21746
client.refresh_load(collection_name)
def dense_retrieval(
queries,
collection_name: str,
client: MilvusClient,
dense_model,
top_k: int,
dense_field: str = "vector",
output_fields: List[str] = ["text"],
_filter: str = "",
):
"""
This function takes the input queries and conducts a dense
embedding search against the dense vector field, returning the top_k
nearest records in the collection.
Parameters
----------
queries : List
List of queries
collection_name : str
Name of the milvus collection to search against.
client : MilvusClient
Client connected to milvus instance.
dense_model : NVIDIAEmbedding
Dense model to generate dense embeddings for queries.
top_k : int
Number of search results to return per query.
dense_field : str
The name of the anns_field that holds the dense embedding
vector in the collection.
Returns
-------
List
Nested list of top_k results per query.
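Examples
--------
A minimal sketch, assuming an embedding NIM is reachable at the hypothetical URL and
model name below:
>>> from llama_index.embeddings.nvidia import NVIDIAEmbedding
>>> embed_model = NVIDIAEmbedding(base_url="http://localhost:8000/v1", model="nvidia/nv-embedqa-e5-v5")
>>> client = MilvusClient("http://localhost:19530")
>>> hits = dense_retrieval(["What is nv-ingest?"], "my_collection", client, embed_model, top_k=5)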
"""
dense_embeddings = []
for query in queries:
dense_embeddings.append(dense_model.get_query_embedding(query))
results = client.search(
collection_name=collection_name,
data=dense_embeddings,
anns_field=dense_field,
limit=top_k,
output_fields=output_fields,
filter=_filter,
consistency_level=CONSISTENCY,
)
return results
def hybrid_retrieval(
queries,
collection_name: str,
client: MilvusClient,
dense_model,
sparse_model,
top_k: int,
dense_field: str = "vector",
sparse_field: str = "sparse",
output_fields: List[str] = ["text"],
gpu_search: bool = True,
local_index: bool = False,
_filter: str = "",
):
"""
This function takes the input queries and conducts a hybrid
embedding search against the dense and sparse vectors, returning
the top_k nearest records in the collection.
Parameters
----------
queries : List
List of queries
collection_name : str
Name of the milvus collection to search against.
client : MilvusClient
Client connected to milvus instance.
dense_model : NVIDIAEmbedding
Dense model to generate dense embeddings for queries.
sparse_model : model,
Sparse model used to generate sparse embedding in the form of
scipy.sparse.csr_array
top_k : int
Number of search results to return per query.
dense_field : str
The name of the anns_field that holds the dense embedding
vector in the collection.
sparse_field : str
The name of the anns_field that holds the sparse embedding
vector in the collection.
Returns
-------
List
Nested list of top_k results per query.
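Examples
--------
A minimal sketch, assuming `embed_model` and `client` are set up as in the
`dense_retrieval` example and the collection was created with sparse=True on a
Milvus server (so BM25 sparse embeddings are computed server-side and
`sparse_model` can be None):
>>> hits = hybrid_retrieval(["What is nv-ingest?"], "my_collection", client, embed_model, None, top_k=5)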
"""
dense_embeddings = []
sparse_embeddings = []
for query in queries:
dense_embeddings.append(dense_model.get_query_embedding(query))
if sparse_model:
sparse_embeddings.append(_format_sparse_embedding(sparse_model.encode_queries([query])))
else:
sparse_embeddings.append(query)
s_param_1 = {
"metric_type": "L2",
}
if not gpu_search and not local_index:
s_param_1["params"] = {"ef": top_k}
# Create search requests for both vector types
search_param_1 = {
"data": dense_embeddings,
"anns_field": dense_field,
"param": s_param_1,
"limit": top_k,
"expr": _filter,
}
dense_req = AnnSearchRequest(**search_param_1)
s_param_2 = {"metric_type": "BM25"}
if local_index:
s_param_2 = {"metric_type": "IP", "params": {"drop_ratio_build": 0.0}}
search_param_2 = {
"data": sparse_embeddings,
"anns_field": sparse_field,
"param": s_param_2,
"limit": top_k,
"expr": _filter,
}
sparse_req = AnnSearchRequest(**search_param_2)
results = client.hybrid_search(
collection_name,
[sparse_req, dense_req],
RRFRanker(),
limit=top_k,
output_fields=output_fields,
consistency_level=CONSISTENCY,
)
return results
def nvingest_retrieval(
queries,
collection_name: str = None,
vdb_op: VDB = None,
milvus_uri: str = "http://localhost:19530",
top_k: int = 5,
hybrid: bool = False,
dense_field: str = "vector",
sparse_field: str = "sparse",
embedding_endpoint=None,
sparse_model_filepath: str = "bm25_model.json",
model_name: str = None,
output_fields: List[str] = ["text", "source", "content_metadata"],
gpu_search: bool = True,
nv_ranker: bool = False,
nv_ranker_endpoint: str = None,
nv_ranker_model_name: str = None,
nv_ranker_nvidia_api_key: str = None,
nv_ranker_truncate: str = "END",
nv_ranker_top_k: int = 50,
nv_ranker_max_batch_size: int = 64,
_filter: str = "",
**kwargs,
):
"""
This function takes the input queries and conducts a hybrid/dense
embedding search against the vectors, returning the top_k nearest
records in the collection.
Parameters
----------
queries : List
List of queries
collection_name : str
Name of the milvus collection to search against.
milvus_uri : str,
Milvus address with http(s) prefix and port. Can also be a file path, to activate
milvus-lite.
top_k : int
Number of search results to return per query.
hybrid: bool, optional
If True, will calculate distances for both dense and sparse embeddings.
dense_field : str, optional
The name of the anns_field that holds the dense embedding
vector in the collection.
sparse_field : str, optional
The name of the anns_field that holds the sparse embedding
vector in the collection.
embedding_endpoint : str, optional
The endpoint for the NVIDIA embedding service used to embed the queries.
sparse_model_filepath : str, optional
The path from which the sparse (BM25) model will be loaded.
model_name : str, optional
The name of the dense embedding model available in the NIM embedding endpoint.
nv_ranker : bool
Set to True to use the nvidia reranker.
nv_ranker_endpoint : str
The endpoint to the nvidia reranker
nv_ranker_model_name: str
The name of the model hosted in the nvidia reranker.
nv_ranker_nvidia_api_key : str,
The nvidia reranker api key, necessary when using non-local assets.
nv_ranker_truncate : str [`END`, `NONE`]
Truncate the incoming texts if length is longer than the model allows.
nv_ranker_max_batch_size : int
Max size for the number of candidates to rerank.
nv_ranker_top_k : int,
The number of candidates to return after reranking.
Returns
-------
List
Nested list of top_k results per query.
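Examples
--------
A minimal sketch of a dense query, assuming the collection was written by nv-ingest
and the embedding endpoint, model name, and API key are configured via environment
variables (see `ClientConfigSchema`):
>>> results = nvingest_retrieval(["What does the document say about X?"], "my_collection", top_k=5)
A hybrid query additionally requires the collection to have been created with sparse=True:
>>> results = nvingest_retrieval(["What does the document say about X?"], "my_collection", hybrid=True, top_k=5)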
"""
if vdb_op is not None and not isinstance(vdb_op, VDB):
raise ValueError("vdb_op must be a VDB object")
if isinstance(vdb_op, VDB):
kwargs = locals().copy()
kwargs.pop("vdb_op", None)
queries = kwargs.pop("queries", [])
return vdb_op.retrieval(queries, **kwargs)
from llama_index.embeddings.nvidia import NVIDIAEmbedding
client_config = ClientConfigSchema()
nvidia_api_key = client_config.nvidia_build_api_key
# required for NVIDIAEmbedding call if the endpoint is Nvidia build api.
embedding_endpoint = embedding_endpoint if embedding_endpoint else client_config.embedding_nim_endpoint
model_name = model_name if model_name else client_config.embedding_nim_model_name
local_index = False
embed_model = NVIDIAEmbedding(base_url=embedding_endpoint, model=model_name, nvidia_api_key=nvidia_api_key)
client = MilvusClient(milvus_uri)
final_top_k = top_k
if nv_ranker:
top_k = nv_ranker_top_k
if milvus_uri.endswith(".db"):
local_index = True
if hybrid:
bm25_ef = None
if local_index:
bm25_ef = BM25EmbeddingFunction(build_default_analyzer(language="en"))
bm25_ef.load(sparse_model_filepath)
results = hybrid_retrieval(
queries,
collection_name,
client,
embed_model,
bm25_ef,
top_k,
output_fields=output_fields,
gpu_search=gpu_search,
local_index=local_index,
_filter=_filter,
)
else:
results = dense_retrieval(
queries, collection_name, client, embed_model, top_k, output_fields=output_fields, _filter=_filter
)
if nv_ranker:
rerank_results = []
for query, candidates in zip(queries, results):
rerank_results.append(
nv_rerank(
query,
candidates,
reranker_endpoint=nv_ranker_endpoint,
model_name=nv_ranker_model_name,
nvidia_api_key=nv_ranker_nvidia_api_key,
truncate=nv_ranker_truncate,
topk=final_top_k,
max_batch_size=nv_ranker_max_batch_size,
)
)
results = rerank_results
return results
def remove_records(source_name: str, collection_name: str, milvus_uri: str = "http://localhost:19530"):
"""
This function allows a user to remove chunks associated with an ingested file.
Supply the full path of the file you would like to remove and this function will
remove all the chunks associated with that file in the target collection.
Parameters
----------
source_name : str
The full file path of the file you would like to remove from the collection.
collection_name : str
Milvus Collection to query against
milvus_uri : str,
Milvus address with http(s) prefix and port. Can also be a file path, to activate
milvus-lite.
Returns
-------
Dict
Dictionary with one key, `delete_cnt`. The value represents the number of entities
removed.
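Examples
--------
A minimal sketch, assuming the hypothetical file below was previously ingested into
the collection:
>>> remove_records("/path/to/report.pdf", "my_collection", "http://localhost:19530")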
"""
client = MilvusClient(milvus_uri)
result_ids = client.delete(
collection_name=collection_name,
filter=f'(source["source_name"] == "{source_name}")',
)
return result_ids
def nv_rerank(
query,
candidates,
reranker_endpoint: str = None,
model_name: str = None,
nvidia_api_key: str = None,
truncate: str = "END",
max_batch_size: int = 64,
topk: int = 5,
):
"""
This function allows a user to rerank a set of candidates using the nvidia reranker nim.
Parameters
----------
query : str
Query the candidates are supposed to answer.
candidates : list
List of the candidates to rerank.
reranker_endpoint : str
The endpoint to the nvidia reranker
model_name: str
The name of the model hosted in the nvidia reranker.
nvidia_api_key : str,
The nvidia reranker api key, necessary when using non-local assets.
truncate : str [`END`, `NONE`]
Truncate the incoming texts if length is longer than the model allows.
max_batch_size : int
Max size for the number of candidates to rerank.
topk : int,
The number of candidates to return after reranking.
Returns
-------
Dict
Dictionary with top_k reranked candidates.
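Examples
--------
A minimal sketch, assuming `hits` is the candidate list returned for one query by
`nvingest_retrieval` and a reranker NIM is reachable at the hypothetical URL below:
>>> reranked = nv_rerank("What is nv-ingest?", hits, reranker_endpoint="http://localhost:8001/v1/ranking", topk=5)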
"""
client_config = ClientConfigSchema()
# reranker = NVIDIARerank(base_url=reranker_endpoint, nvidia_api_key=nvidia_api_key, top_n=top_k)
reranker_endpoint = reranker_endpoint if reranker_endpoint else client_config.nv_ranker_nim_endpoint
model_name = model_name if model_name else client_config.nv_ranker_nim_model_name
nvidia_api_key = nvidia_api_key if nvidia_api_key else client_config.nvidia_build_api_key
headers = {"accept": "application/json", "Content-Type": "application/json"}
if nvidia_api_key:
headers["Authorization"] = f"Bearer {nvidia_api_key}"
texts = []
map_candidates = {}
for idx, candidate in enumerate(candidates):
map_candidates[idx] = candidate
texts.append({"text": candidate["entity"]["text"]})
payload = {"model": model_name, "query": {"text": query}, "passages": texts, "truncate": truncate}
start = time.time()
response = requests.post(f"{reranker_endpoint}", headers=headers, json=payload)
logger.debug(f"RERANKER time: {time.time() - start}")
if response.status_code != 200:
raise ValueError(f"Failed retrieving ranking results: {response.status_code} - {response.text}")
rank_results = []
for rank_vals in response.json()["rankings"]:
idx = rank_vals["index"]
rank_results.append(map_candidates[idx])
return rank_results
def recreate_elements(data):
"""
This function takes the input data and creates a list of elements
with the necessary metadata for ingestion.
Parameters
----------
data : List
List of chunks with attached metadata
Returns
-------
List
List of elements with metadata.
"""
elements = []
for element in data:
element["metadata"] = {}
element["metadata"]["content_metadata"] = element.pop("content_metadata")
element["metadata"]["source_metadata"] = element.pop("source")
element["metadata"]["content"] = element.pop("text")
elements.append(element)
return elements
def pull_all_milvus(
collection_name: str, milvus_uri: str = "http://localhost:19530", write_dir: str = None, batch_size: int = 1000
):
"""
This function takes the input collection name and pulls all the records
from the collection. It will either return the records as a list of
dictionaries or write them to a specified directory in JSON format.
Parameters
----------
collection_name : str
Milvus Collection to query against
milvus_uri : str,
Milvus address with http(s) prefix and port. Can also be a file path, to activate
milvus-lite.
write_dir : str, optional
Directory to write the records to. If None, the records will be returned as a list.
batch_size : int, optional
The number of records to pull in each batch. Defaults to 1000.
Returns
-------
List
List of records/files with records from the collection.
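Examples
--------
A minimal sketch pulling every record into memory, assuming a local Milvus instance:
>>> records = pull_all_milvus("my_collection", "http://localhost:19530")
Writing batches to an existing directory instead returns the list of files written:
>>> files = pull_all_milvus("my_collection", "http://localhost:19530", write_dir="./dump", batch_size=500)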
"""
client = MilvusClient(milvus_uri)
iterator = client.query_iterator(
collection_name=collection_name,
filter="pk >= 0",
output_fields=["source", "content_metadata", "text"],
batch_size=batch_size,
consistency_level=CONSISTENCY,
)
full_results = []
write_dir = Path(write_dir) if write_dir else None
batch_num = 0
while True:
results = recreate_elements(iterator.next())
if not results:
iterator.close()
break
if write_dir:
# write to disk
file_name = write_dir / f"milvus_data_{batch_num}.json"
full_results.append(file_name)
with open(file_name, "w") as outfile:
outfile.write(json.dumps(results))
else:
full_results += results
batch_num += 1
return full_results
def get_embeddings(full_records, embedder, batch_size=256):
"""
This function takes the input records and creates a list of embeddings.
The default batch size is 256, but this can be adjusted based on the
available resources, up to a maximum of 259, which is the limit set by the NVIDIA
embedding microservice.
"""
embedded = []
embed_payload = [res["metadata"]["content"] for res in full_records]
for i in range(0, len(embed_payload), batch_size):
payload = embed_payload[i : i + batch_size]
embedded += embedder._get_text_embeddings(payload)
return embedded
def embed_index_collection(
data,
collection_name,
batch_size: int = 256,
embedding_endpoint: str = None,
model_name: str = None,
nvidia_api_key: str = None,
milvus_uri: str = "http://localhost:19530",
sparse: bool = False,
recreate: bool = True,
gpu_index: bool = True,
gpu_search: bool = True,
dense_dim: int = 2048,
minio_endpoint: str = "localhost:9000",
enable_text: bool = True,
enable_charts: bool = True,
enable_tables: bool = True,
enable_images: bool = True,
enable_infographics: bool = True,
bm25_save_path: str = "bm25_model.json",
compute_bm25_stats: bool = True,
access_key: str = "minioadmin",
secret_key: str = "minioadmin",
bucket_name: str = "a-bucket",
meta_dataframe: Union[str, pd.DataFrame] = None,
meta_source_field: str = None,
meta_fields: list[str] = None,
intput_type: str = "passage",
truncate: str = "END",
**kwargs,
):
"""
This function takes the input data and creates a collection in Milvus. It embeds
the records using the NVIDIA embedding model and stores them in the collection.
After embedding the records, it runs the same ingestion process as the vdb_upload stage in
the Ingestor pipeline.
Args:
data (Union[str, List]): The data to be ingested. Can be a list of records or a file path.
collection_name (Union[str, Dict]): The name of the Milvus collection or a dictionary
containing collection configuration.
batch_size (int, optional): The batch size used when embedding the records. Defaults to 256.
embedding_endpoint (str, optional): The endpoint for the NVIDIA embedding service. Defaults to None.
model_name (str, optional): The name of the embedding model. Defaults to None.
nvidia_api_key (str, optional): The API key for NVIDIA services. Defaults to None.
milvus_uri (str, optional): The URI of the Milvus server. Defaults to "http://localhost:19530".
sparse (bool, optional): Whether to use sparse indexing. Defaults to False.
recreate (bool, optional): Whether to recreate the collection if it already exists. Defaults to True.
gpu_index (bool, optional): Whether to use GPU for indexing. Defaults to True.
gpu_search (bool, optional): Whether to use GPU for search operations. Defaults to True.
dense_dim (int, optional): The dimensionality of dense vectors. Defaults to 2048.
minio_endpoint (str, optional): The endpoint for the MinIO server. Defaults to "localhost:9000".
enable_text (bool, optional): Whether to enable text data ingestion. Defaults to True.
enable_charts (bool, optional): Whether to enable chart data ingestion. Defaults to True.
enable_tables (bool, optional): Whether to enable table data ingestion. Defaults to True.
enable_images (bool, optional): Whether to enable image data ingestion. Defaults to True.
enable_infographics (bool, optional): Whether to enable infographic data ingestion. Defaults to True.
bm25_save_path (str, optional): The file path to save the BM25 model. Defaults to "bm25_model.json".
compute_bm25_stats (bool, optional): Whether to compute BM25 statistics. Defaults to True.
access_key (str, optional): The access key for MinIO authentication. Defaults to "minioadmin".
secret_key (str, optional): The secret key for MinIO authentication. Defaults to "minioadmin".
bucket_name (str, optional): The name of the MinIO bucket. Defaults to "a-bucket".
meta_dataframe (Union[str, pd.DataFrame], optional): A metadata DataFrame or the path to a CSV file
containing metadata. Defaults to None.
meta_source_field (str, optional): The field in the metadata that serves as the source identifier.
Defaults to None.
meta_fields (list[str], optional): A list of metadata fields to include. Defaults to None.
**kwargs: Additional keyword arguments for customization.
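Example:
A minimal sketch, assuming `files` is a list of JSON batch files produced by
`pull_all_milvus(..., write_dir=...)` and that the embedding endpoint, model name,
and API key are available via environment variables:
embed_index_collection(files, "rebuilt_collection", dense_dim=2048, milvus_uri="http://localhost:19530")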
"""
client_config = ClientConfigSchema()
nvidia_api_key = nvidia_api_key if nvidia_api_key else client_config.nvidia_build_api_key
# required for NVIDIAEmbedding call if the endpoint is Nvidia build api.
embedding_endpoint = embedding_endpoint if embedding_endpoint else client_config.embedding_nim_endpoint
model_name = model_name if model_name else client_config.embedding_nim_model_name
# if not scheme we assume we are using grpc
grpc = "http" not in urlparse(embedding_endpoint).scheme
kwargs.pop("input_type", None)
kwargs.pop("truncate", None)
mil_op = Milvus(
collection_name=collection_name,
milvus_uri=milvus_uri,
sparse=sparse,
recreate=recreate,
gpu_index=gpu_index,
gpu_search=gpu_search,
dense_dim=dense_dim,
minio_endpoint=minio_endpoint,
enable_text=enable_text,
enable_charts=enable_charts,
enable_tables=enable_tables,
enable_images=enable_images,
enable_infographics=enable_infographics,
bm25_save_path=bm25_save_path,
compute_bm25_stats=compute_bm25_stats,
access_key=access_key,
secret_key=secret_key,
bucket_name=bucket_name,
meta_dataframe=meta_dataframe,
meta_source_field=meta_source_field,
meta_fields=meta_fields,
**kwargs,
)
# running in parts
if data is not None and isinstance(data[0], (str, os.PathLike)):
for results_file in data:
results = None
with open(results_file, "r") as infile:
results = json.loads(infile.read())
embeddings = infer_microservice(
results, model_name, embedding_endpoint, nvidia_api_key, input_type, truncate, batch_size, grpc
)
for record, emb in zip(results, embeddings):
record["metadata"]["embedding"] = emb
record["document_type"] = "text"
if results is not None and len(results) > 0:
mil_op.run(results)
mil_op.milvus_kwargs["recreate"] = False
# running all at once
else:
embeddings = infer_microservice(
data, model_name, embedding_endpoint, nvidia_api_key, input_type, truncate, batch_size, grpc
)
for record, emb in zip(data, embeddings):
record["metadata"]["embedding"] = emb
record["document_type"] = "text"
# this check ensures that we do not purge the current collection
# without having some actual data to insert.
if data is not None and len(data) > 0:
mil_op.run(data)
def reindex_collection(
vdb_op: VDB = None,
current_collection_name: str = None,
new_collection_name: str = None,
write_dir: str = None,
embedding_endpoint: str = None,
model_name: str = None,
nvidia_api_key: str = None,
milvus_uri: str = "http://localhost:19530",
sparse: bool = False,
recreate: bool = True,
gpu_index: bool = True,
gpu_search: bool = True,
dense_dim: int = 2048,
minio_endpoint: str = "localhost:9000",
enable_text: bool = True,
enable_charts: bool = True,
enable_tables: bool = True,
enable_images: bool = True,
enable_infographics: bool = True,
bm25_save_path: str = "bm25_model.json",
compute_bm25_stats: bool = True,
access_key: str = "minioadmin",
secret_key: str = "minioadmin",
bucket_name: str = "a-bucket",
meta_dataframe: Union[str, pd.DataFrame] = None,
meta_source_field: str = None,
meta_fields: list[str] = None,
embed_batch_size: int = 256,
query_batch_size: int = 1000,
input_type: str = "passage",
truncate: str = "END",
**kwargs,
):
"""
This function will reindex a collection in Milvus. It will pull all the records from the
current collection, embed them using the NVIDIA embedding model, and store them in a new
collection. After embedding the records, it will run the same ingestion process as the vdb_upload
stage in the Ingestor pipeline. This function will get embedding_endpoint, model_name, and nvidia_api_key
defaults from environment variables if they are not explicitly set in the function call.
Parameters
----------
current_collection_name (str): The name of the current Milvus collection.
new_collection_name (str, optional): The name of the new Milvus collection. Defaults to None.
write_dir (str, optional): The directory to write the pulled records to. Defaults to None.
embedding_endpoint (str, optional): The endpoint for the NVIDIA embedding service. Defaults to None.
model_name (str, optional): The name of the embedding model. Defaults to None.
nvidia_api_key (str, optional): The API key for NVIDIA services. Defaults to None.
milvus_uri (str, optional): The URI of the Milvus server. Defaults to "http://localhost:19530".
sparse (bool, optional): Whether to use sparse indexing. Defaults to False.
recreate (bool, optional): Whether to recreate the collection if it already exists. Defaults to True.
gpu_index (bool, optional): Whether to use GPU for indexing. Defaults to True.
gpu_search (bool, optional): Whether to use GPU for search operations. Defaults to True.
dense_dim (int, optional): The dimensionality of dense vectors. Defaults to 2048.
minio_endpoint (str, optional): The endpoint for the MinIO server. Defaults to "localhost:9000".
enable_text (bool, optional): Whether to enable text data ingestion. Defaults to True.
enable_charts (bool, optional): Whether to enable chart data ingestion. Defaults to True.
enable_tables (bool, optional): Whether to enable table data ingestion. Defaults to True.
enable_images (bool, optional): Whether to enable image data ingestion. Defaults to True.
enable_infographics (bool, optional): Whether to enable infographic data ingestion. Defaults to True.
bm25_save_path (str, optional): The file path to save the BM25 model. Defaults to "bm25_model.json".
compute_bm25_stats (bool, optional): Whether to compute BM25 statistics. Defaults to True.
access_key (str, optional): The access key for MinIO authentication. Defaults to "minioadmin".
secret_key (str, optional): The secret key for MinIO authentication. Defaults to "minioadmin".
bucket_name (str, optional): The name of the MinIO bucket. Defaults to "a-bucket".
meta_dataframe (Union[str, pd.DataFrame], optional): A metadata DataFrame or the path to a CSV file
containing metadata. Defaults to None.
meta_source_field (str, optional): The field in the metadata that serves as the source identifier.
Defaults to None.
meta_fields (list[str], optional): A list of metadata fields to include. Defaults to None.
embed_batch_size (int, optional): The batch size for embedding. Defaults to 256.
query_batch_size (int, optional): The batch size for querying. Defaults to 1000.
**kwargs: Additional keyword arguments for customization.
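Example:
A minimal sketch, assuming the embedding endpoint, model name, and API key are
configured via environment variables:
reindex_collection(current_collection_name="my_collection", new_collection_name="my_collection_v2", dense_dim=2048)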
"""
if vdb_op is not None and not isinstance(vdb_op, VDB):
raise ValueError("vdb_op must be a VDB object")
if isinstance(vdb_op, VDB):
kwargs = locals().copy()
kwargs.pop("vdb_op", None)
return vdb_op.reindex(**kwargs)
new_collection_name = new_collection_name if new_collection_name else current_collection_name
pull_results = pull_all_milvus(current_collection_name, milvus_uri, write_dir, query_batch_size)
embed_index_collection(
pull_results,
new_collection_name,
batch_size=embed_batch_size,
embedding_endpoint=embedding_endpoint,
model_name=model_name,
nvidia_api_key=nvidia_api_key,
milvus_uri=milvus_uri,
sparse=sparse,
recreate=recreate,
gpu_index=gpu_index,
gpu_search=gpu_search,
dense_dim=dense_dim,
minio_endpoint=minio_endpoint,
enable_text=enable_text,
enable_charts=enable_charts,
enable_tables=enable_tables,
enable_images=enable_images,
enable_infographics=enable_infographics,
bm25_save_path=bm25_save_path,
compute_bm25_stats=compute_bm25_stats,
access_key=access_key,
secret_key=secret_key,
bucket_name=bucket_name,
meta_dataframe=meta_dataframe,
meta_source_field=meta_source_field,
meta_fields=meta_fields,
input_type=input_type,
truncate=truncate,
**kwargs,
)
def reconstruct_pages(anchor_record, records_list, page_signum: int = 0):
"""
This function allows a user to reconstruct the page(s) for a retrieved chunk.
Parameters
----------
anchor_record : dict
The retrieved chunk (search hit) whose surrounding page(s) should be reconstructed.
records_list : list
Nested list of ingestion records to pull the page content from.
page_signum : int
Number of additional pages to include before and after the anchor record's page.
Returns
-------
String
Full page(s) corresponding to anchor record.
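Examples
--------
A minimal sketch, assuming `results` came from `nvingest_retrieval` and
`ingest_results` holds the nested nv-ingest results (one list of records per document):
>>> page_text = reconstruct_pages(results[0][0], ingest_results, page_signum=1)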
"""
source_file = anchor_record["entity"]["source"]["source_name"]
page_number = anchor_record["entity"]["content_metadata"]["page_number"]
min_page = page_number - page_signum
max_page = page_number + 1 + page_signum
page_numbers = list(range(min_page, max_page))
target_records = []
for sub_records in records_list:
for record in sub_records:
rec_src_file = record["metadata"]["source_metadata"]["source_name"]
rec_pg_num = record["metadata"]["content_metadata"]["page_number"]
if source_file == rec_src_file and rec_pg_num in page_numbers:
target_records.append(record)
return ingest_json_results_to_blob(target_records)
class Milvus(VDB):
def __init__(
self,
collection_name: Union[str, Dict] = "nv_ingest_collection",
milvus_uri: str = "http://localhost:19530",
sparse: bool = False,
recreate: bool = True,
gpu_index: bool = True,
gpu_search: bool = True,
dense_dim: int = 2048,
minio_endpoint: str = "localhost:9000",
enable_text: bool = True,
enable_charts: bool = True,
enable_tables: bool = True,
enable_images: bool = True,
enable_infographics: bool = True,
bm25_save_path: str = "bm25_model.json",
compute_bm25_stats: bool = True,
access_key: str = "minioadmin",
secret_key: str = "minioadmin",
bucket_name: str = "a-bucket",
meta_dataframe: Union[str, pd.DataFrame] = None,
meta_source_field: str = None,
meta_fields: list[str] = None,
stream: bool = False,
**kwargs,
):
"""
Initializes the Milvus operator class with the specified configuration parameters.
Args:
collection_name (Union[str, Dict], optional): The name of the Milvus collection or a dictionary
containing collection configuration. Defaults to "nv_ingest_collection".
milvus_uri (str, optional): The URI of the Milvus server. Defaults to "http://localhost:19530".
sparse (bool, optional): Whether to use sparse indexing. Defaults to False.
recreate (bool, optional): Whether to recreate the collection if it already exists. Defaults to True.
gpu_index (bool, optional): Whether to use GPU for indexing. Defaults to True.
gpu_search (bool, optional): Whether to use GPU for search operations. Defaults to True.
dense_dim (int, optional): The dimensionality of dense vectors. Defaults to 2048.
minio_endpoint (str, optional): The endpoint for the MinIO server. Defaults to "localhost:9000".
enable_text (bool, optional): Whether to enable text data ingestion. Defaults to True.
enable_charts (bool, optional): Whether to enable chart data ingestion. Defaults to True.
enable_tables (bool, optional): Whether to enable table data ingestion. Defaults to True.
enable_images (bool, optional): Whether to enable image data ingestion. Defaults to True.
enable_infographics (bool, optional): Whether to enable infographic data ingestion. Defaults to True.
bm25_save_path (str, optional): The file path to save the BM25 model. Defaults to "bm25_model.json".
compute_bm25_stats (bool, optional): Whether to compute BM25 statistics. Defaults to True.
access_key (str, optional): The access key for MinIO authentication. Defaults to "minioadmin".
secret_key (str, optional): The secret key for MinIO authentication. Defaults to "minioadmin".
bucket_name (str, optional): The name of the MinIO bucket. Defaults to "a-bucket".
meta_dataframe (Union[str, pd.DataFrame], optional): A metadata DataFrame or the path to a CSV file
containing metadata. Defaults to None.
meta_source_field (str, optional): The field in the metadata that serves as the source identifier.
Defaults to None.
meta_fields (list[str], optional): A list of metadata fields to include. Defaults to None.
**kwargs: Additional keyword arguments for customization.
stream (bool, optional): When true, the records will be inserted into milvus using the stream
insert method.
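Example:
A minimal sketch, assuming `records` holds nv-ingest results that already contain
embeddings and that a local milvus-lite file is acceptable:
vdb_op = Milvus(collection_name="my_collection", milvus_uri="milvus_demo.db", sparse=False, dense_dim=2048)
vdb_op.run(records)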
"""
kwargs = locals().copy()
kwargs.pop("self", None)
super().__init__(**kwargs)
def create_index(self, **kwargs):
collection_name = kwargs.pop("collection_name")
return create_nvingest_collection(collection_name, **kwargs)
def write_to_index(self, records, **kwargs):
collection_name = kwargs.pop("collection_name")
write_to_nvingest_collection(records, collection_name=collection_name, **kwargs)
def retrieval(self, queries, **kwargs):
collection_name = kwargs.pop("collection_name")
return nvingest_retrieval(queries, collection_name=collection_name, **kwargs)
def reindex(self, **kwargs):
collection_name = kwargs.pop("current_collection_name")
reindex_collection(current_collection_name=collection_name, **kwargs)
def get_connection_params(self):
conn_dict = {
"milvus_uri": self.__dict__.get("milvus_uri", "http://localhost:19530"),
"sparse": self.__dict__.get("sparse", True),
"recreate": self.__dict__.get("recreate", True),
"gpu_index": self.__dict__.get("gpu_index", True),
"gpu_search": self.__dict__.get("gpu_search", True),
"dense_dim": self.__dict__.get("dense_dim", 2048),
}
return (self.collection_name, conn_dict)
def get_write_params(self):
write_params = self.__dict__.copy()
write_params.pop("recreate", True)
write_params.pop("gpu_index", True)
write_params.pop("gpu_search", True)
write_params.pop("dense_dim", 2048)
return (self.collection_name, write_params)
def run(self, records):
collection_name, create_params = self.get_connection_params()
_, write_params = self.get_write_params()
if isinstance(collection_name, str):
self.create_index(collection_name=collection_name, **create_params)
self.write_to_index(records, **write_params)
elif isinstance(collection_name, dict):
split_params_list = _dict_to_params(collection_name, write_params)
for sub_params in split_params_list:
coll_name, sub_write_params = sub_params
sub_write_params.pop("collection_name", None)
self.create_index(collection_name=coll_name, **create_params)
self.write_to_index(records, collection_name=coll_name, **sub_write_params)
else:
raise ValueError(f"Unsupported type for collection_name detected: {type(collection_name)}")