Source code for nv_ingest.framework.orchestration.ray.examples.pipeline_test_harness

# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import json
import os
import ray
import logging
import time
from typing import Dict, Any

# Import our new pipeline class.
from nv_ingest.framework.orchestration.ray.primitives.ray_pipeline import RayPipeline
from nv_ingest.framework.orchestration.ray.stages.extractors.audio_extractor import AudioExtractorStage
from nv_ingest.framework.orchestration.ray.stages.extractors.chart_extractor import ChartExtractorStage
from nv_ingest.framework.orchestration.ray.stages.extractors.docx_extractor import DocxExtractorStage
from nv_ingest.framework.orchestration.ray.stages.extractors.image_extractor import ImageExtractorStage
from nv_ingest.framework.orchestration.ray.stages.extractors.pdf_extractor import PDFExtractorStage
from nv_ingest.framework.orchestration.ray.stages.extractors.table_extractor import TableExtractorStage

# Import stage implementations and configuration models.
from nv_ingest.framework.orchestration.ray.stages.injectors.metadata_injector import MetadataInjectionStage
from nv_ingest.framework.orchestration.ray.stages.mutate.image_dedup import ImageDedupStage
from nv_ingest.framework.orchestration.ray.stages.mutate.image_filter import ImageFilterStage
from nv_ingest.framework.orchestration.ray.stages.sinks.message_broker_task_sink import (
    MessageBrokerTaskSinkStage,
    MessageBrokerTaskSinkConfig,
)
from nv_ingest.framework.orchestration.ray.stages.sources.message_broker_task_source import (
    MessageBrokerTaskSourceStage,
    MessageBrokerTaskSourceConfig,
    start_simple_message_broker,
)
from nv_ingest.framework.orchestration.ray.stages.storage.image_storage import ImageStorageStage
from nv_ingest.framework.orchestration.ray.stages.storage.store_embeddings import EmbeddingStorageStage
from nv_ingest.framework.orchestration.ray.stages.transforms.image_caption import ImageCaptionTransformStage
from nv_ingest.framework.orchestration.ray.stages.transforms.text_embed import TextEmbeddingTransformStage
from nv_ingest.framework.orchestration.ray.stages.transforms.text_splitter import TextSplitterStage
from nv_ingest.framework.schemas.framework_metadata_injector_schema import MetadataInjectorSchema
from nv_ingest_api.internal.schemas.extract.extract_audio_schema import AudioExtractorSchema
from nv_ingest_api.internal.schemas.extract.extract_chart_schema import ChartExtractorSchema
from nv_ingest_api.internal.schemas.extract.extract_docx_schema import DocxExtractorSchema
from nv_ingest_api.internal.schemas.extract.extract_image_schema import ImageExtractorSchema
from nv_ingest_api.internal.schemas.extract.extract_pdf_schema import PDFExtractorSchema
from nv_ingest_api.internal.schemas.extract.extract_table_schema import TableExtractorSchema
from nv_ingest_api.internal.schemas.mutate.mutate_image_dedup_schema import ImageDedupSchema
from nv_ingest_api.internal.schemas.store.store_embedding_schema import EmbeddingStorageSchema
from nv_ingest_api.internal.schemas.store.store_image_schema import ImageStorageModuleSchema
from nv_ingest_api.internal.schemas.transform.transform_image_caption_schema import ImageCaptionExtractionSchema
from nv_ingest_api.internal.schemas.transform.transform_image_filter_schema import ImageFilterSchema
from nv_ingest_api.internal.schemas.transform.transform_text_embedding_schema import TextEmbeddingSchema
from nv_ingest_api.internal.schemas.transform.transform_text_splitter_schema import TextSplitterSchema


# Module-level logger so get_nim_service() also works when this module is imported
# rather than run as a script. The __main__ block rebinds ``logger`` to the harness
# logger, so script output is unchanged.
logger = logging.getLogger(__name__)


def get_nim_service(env_var_prefix):
    """
    Resolve a NIM service's connection settings from environment variables.

    The variables read are ``<PREFIX>_GRPC_ENDPOINT``, ``<PREFIX>_HTTP_ENDPOINT`` and
    ``<PREFIX>_INFER_PROTOCOL``, where PREFIX is ``env_var_prefix`` upper-cased. The
    auth token comes from ``NVIDIA_API_KEY``, falling back to ``NGC_API_KEY``.

    Parameters
    ----------
    env_var_prefix : str
        Service name used to build the environment variable names, e.g. ``"yolox"``.

    Returns
    -------
    tuple of (str, str, str, str)
        ``(grpc_endpoint, http_endpoint, auth_token, infer_protocol)``. Missing values
        are empty strings. When ``<PREFIX>_INFER_PROTOCOL`` is unset, the protocol
        defaults to:

        - ``"http"`` if an HTTP endpoint is configured;
        - otherwise ``"grpc"`` if a gRPC endpoint is configured;
        - otherwise ``""``.
    """
    prefix = env_var_prefix.upper()

    grpc_endpoint = os.environ.get(f"{prefix}_GRPC_ENDPOINT", "")
    http_endpoint = os.environ.get(f"{prefix}_HTTP_ENDPOINT", "")
    # A single credential is shared by all services; NVIDIA_API_KEY takes precedence.
    auth_token = os.environ.get("NVIDIA_API_KEY", "") or os.environ.get("NGC_API_KEY", "")

    # HTTP wins when both endpoints are configured.
    if http_endpoint:
        default_protocol = "http"
    elif grpc_endpoint:
        default_protocol = "grpc"
    else:
        default_protocol = ""
    infer_protocol = os.environ.get(f"{prefix}_INFER_PROTOCOL", default_protocol)

    # The auth token is intentionally not logged.
    logger.info(f"{prefix}_GRPC_ENDPOINT: {grpc_endpoint}")
    logger.info(f"{prefix}_HTTP_ENDPOINT: {http_endpoint}")
    logger.info(f"{prefix}_INFER_PROTOCOL: {infer_protocol}")

    return grpc_endpoint, http_endpoint, auth_token, infer_protocol
# Broker configuration – using a simple client on a fixed port.
simple_config: Dict[str, Any] = {
    "client_type": "simple",
    "host": "localhost",
    "port": 7671,
    "max_retries": 3,
    "max_backoff": 2,
    "connection_timeout": 5,
    "broker_params": {"max_queue_size": 1000},
}

if __name__ == "__main__":
    ray.init(
        ignore_reinit_error=True,
        _system_config={
            "local_fs_capacity_threshold": 0.9,
            "object_spilling_config": json.dumps(
                {
                    "type": "filesystem",
                    "params": {
                        "directory_path": [
                            "/tmp/ray_spill_testing_0",
                            "/tmp/ray_spill_testing_1",
                            "/tmp/ray_spill_testing_2",
                            "/tmp/ray_spill_testing_3",
                        ],
                        "buffer_size": 100_000_000,
                    },
                },
            ),
        },
    )

    logging.basicConfig(level=logging.INFO)
    # Bound at module scope on purpose: get_nim_service() logs through the
    # module-global ``logger``.
    logger = logging.getLogger("RayPipelineHarness")
    logger.info("Starting multi-stage pipeline test.")

    # Start the SimpleMessageBroker server externally.
    logger.info("Starting SimpleMessageBroker server.")
    broker_process = start_simple_message_broker(simple_config)
    logger.info("SimpleMessageBroker server started.")

    # Build the pipeline.
    pipeline = RayPipeline()
    logger.info("Created RayPipeline instance.")

    # Create configuration instances for the source and sink stages.
    source_config = MessageBrokerTaskSourceConfig(
        broker_client=simple_config,
        task_queue="ingest_task_queue",
        poll_interval=0.1,
    )
    sink_config = MessageBrokerTaskSinkConfig(
        broker_client=simple_config,
        poll_interval=0.1,
    )
    logger.info("Source and sink configurations created.")

    # Set environment variables for various services.
    # These assignments overwrite any values already present in the environment.
    os.environ["YOLOX_GRPC_ENDPOINT"] = "localhost:8001"
    os.environ["YOLOX_INFER_PROTOCOL"] = "grpc"
    os.environ["YOLOX_TABLE_STRUCTURE_GRPC_ENDPOINT"] = "127.0.0.1:8007"
    os.environ["YOLOX_TABLE_STRUCTURE_INFER_PROTOCOL"] = "grpc"
    os.environ["YOLOX_GRAPHIC_ELEMENTS_GRPC_ENDPOINT"] = "127.0.0.1:8004"
    os.environ["YOLOX_GRAPHIC_ELEMENTS_HTTP_ENDPOINT"] = "http://localhost:8003/v1/infer"
    os.environ["YOLOX_GRAPHIC_ELEMENTS_INFER_PROTOCOL"] = "http"
    os.environ["OCR_GRPC_ENDPOINT"] = "localhost:8010"
    os.environ["OCR_INFER_PROTOCOL"] = "grpc"
    os.environ["OCR_MODEL_NAME"] = "paddle"
    os.environ["NEMORETRIEVER_PARSE_HTTP_ENDPOINT"] = "https://integrate.api.nvidia.com/v1/chat/completions"
    os.environ["VLM_CAPTION_ENDPOINT"] = "https://integrate.api.nvidia.com/v1/chat/completions"
    os.environ["VLM_CAPTION_MODEL_NAME"] = "nvidia/llama-3.1-nemotron-nano-vl-8b-v1"
    logger.info("Environment variables set.")

    # Caption settings are read from the VLM_CAPTION_* variables set above. They are
    # kept in their own names so the nemoretriever-parse model name, resolved below,
    # cannot overwrite the caption model.
    image_caption_endpoint_url = os.environ.get(
        "VLM_CAPTION_ENDPOINT", "https://integrate.api.nvidia.com/v1/chat/completions"
    )
    caption_model_name = os.environ.get("VLM_CAPTION_MODEL_NAME", "nvidia/llama-3.1-nemotron-nano-vl-8b-v1")

    yolox_grpc, yolox_http, yolox_auth, yolox_protocol = get_nim_service("yolox")
    (
        yolox_table_structure_grpc,
        yolox_table_structure_http,
        yolox_table_structure_auth,
        yolox_table_structure_protocol,
    ) = get_nim_service("yolox_table_structure")
    (
        yolox_graphic_elements_grpc,
        yolox_graphic_elements_http,
        yolox_graphic_elements_auth,
        yolox_graphic_elements_protocol,
    ) = get_nim_service("yolox_graphic_elements")
    nemoretriever_parse_grpc, nemoretriever_parse_http, nemoretriever_parse_auth, nemoretriever_parse_protocol = (
        get_nim_service("nemoretriever_parse")
    )
    ocr_grpc, ocr_http, ocr_auth, ocr_protocol = get_nim_service("ocr")

    nemoretriever_parse_model_name = os.environ.get("NEMORETRIEVER_PARSE_MODEL_NAME", "nvidia/nemoretriever-parse")

    pdf_extractor_config = {
        "pdfium_config": {
            "auth_token": yolox_auth,  # All auth tokens are the same for the moment
            "yolox_endpoints": (yolox_grpc, yolox_http),
            "yolox_infer_protocol": yolox_protocol,
        },
        "nemoretriever_parse_config": {
            "auth_token": nemoretriever_parse_auth,
            "nemoretriever_parse_endpoints": (nemoretriever_parse_grpc, nemoretriever_parse_http),
            "nemoretriever_parse_infer_protocol": nemoretriever_parse_protocol,
            "nemoretriever_parse_model_name": nemoretriever_parse_model_name,
            "yolox_endpoints": (yolox_grpc, yolox_http),
            "yolox_infer_protocol": yolox_protocol,
        },
    }
    docx_extractor_config = {
        "docx_extraction_config": {
            "yolox_endpoints": (yolox_grpc, yolox_http),
            "yolox_infer_protocol": yolox_protocol,
            "auth_token": yolox_auth,
        }
    }
    chart_extractor_config = {
        "endpoint_config": {
            "yolox_endpoints": (yolox_graphic_elements_grpc, yolox_graphic_elements_http),
            "yolox_infer_protocol": yolox_graphic_elements_protocol,
            "ocr_endpoints": (ocr_grpc, ocr_http),
            "ocr_infer_protocol": ocr_protocol,
            "auth_token": yolox_auth,
        }
    }
    table_extractor_config = {
        "endpoint_config": {
            "yolox_endpoints": (yolox_table_structure_grpc, yolox_table_structure_http),
            "yolox_infer_protocol": yolox_table_structure_protocol,
            "ocr_endpoints": (ocr_grpc, ocr_http),
            "ocr_infer_protocol": ocr_protocol,
            "auth_token": yolox_auth,
        }
    }
    text_embedding_config = {
        "api_key": yolox_auth,
        "embedding_nim_endpoint": "http://localhost:8012/v1",
        "embedding_model": "nvidia/llama-3.2-nv-embedqa-1b-v2",
    }
    image_extraction_config = {
        "yolox_endpoints": (yolox_grpc, yolox_http),
        "yolox_infer_protocol": yolox_protocol,
        "auth_token": yolox_auth,  # All auth tokens are the same for the moment
    }
    image_caption_config = {
        "api_key": yolox_auth,
        "endpoint_url": image_caption_endpoint_url,
        "model_name": caption_model_name,
        "prompt": "Caption the content of this image:",
    }
    logger.info("Service configuration retrieved from get_nim_service and environment variables.")

    # Add stages:
    pipeline.add_source(
        name="source",
        source_actor=MessageBrokerTaskSourceStage,
        config=source_config,
    )
    # TODO(Job_Counter): Utilizes a global that isn't compatible with Ray, will need to make it a shared object
    # pipeline.add_stage(
    #     name="job_counter",
    #     stage_actor=JobCounterStage,
    #     config=JobCounterSchema(),
    #     min_replicas=1,
    #     max_replicas=1,
    # )
    pipeline.add_stage(
        name="metadata_injection",
        stage_actor=MetadataInjectionStage,
        config=MetadataInjectorSchema(),  # Use stage-specific config if needed.
        min_replicas=0,
        max_replicas=2,
    )
    pipeline.add_stage(
        name="pdf_extractor",
        stage_actor=PDFExtractorStage,
        config=PDFExtractorSchema(**pdf_extractor_config),
        min_replicas=0,
        max_replicas=16,
    )
    pipeline.add_stage(
        name="docx_extractor",
        stage_actor=DocxExtractorStage,
        config=DocxExtractorSchema(**docx_extractor_config),
        min_replicas=0,
        max_replicas=8,
    )
    pipeline.add_stage(
        name="audio_extractor",
        stage_actor=AudioExtractorStage,
        config=AudioExtractorSchema(),
        min_replicas=0,
        max_replicas=8,
    )
    pipeline.add_stage(
        name="image_extractor",
        stage_actor=ImageExtractorStage,
        config=ImageExtractorSchema(**image_extraction_config),
        min_replicas=0,
        max_replicas=8,
    )
    pipeline.add_stage(
        name="table_extractor",
        stage_actor=TableExtractorStage,
        config=TableExtractorSchema(**table_extractor_config),
        min_replicas=0,
        max_replicas=8,
    )
    pipeline.add_stage(
        name="chart_extractor",
        stage_actor=ChartExtractorStage,
        config=ChartExtractorSchema(**chart_extractor_config),
        min_replicas=0,
        max_replicas=8,
    )
    pipeline.add_stage(
        name="text_embedding",
        stage_actor=TextEmbeddingTransformStage,
        config=TextEmbeddingSchema(**text_embedding_config),
        min_replicas=0,
        max_replicas=8,
    )
    pipeline.add_stage(
        name="image_filter",
        stage_actor=ImageFilterStage,
        config=ImageFilterSchema(),
        min_replicas=0,
        max_replicas=4,
    )
    pipeline.add_stage(
        name="image_dedup",
        stage_actor=ImageDedupStage,
        config=ImageDedupSchema(),
        min_replicas=0,
        max_replicas=4,
    )
    pipeline.add_stage(
        name="image_storage",
        stage_actor=ImageStorageStage,
        config=ImageStorageModuleSchema(),
        min_replicas=0,
        max_replicas=4,
    )
    pipeline.add_stage(
        name="embedding_storage",
        stage_actor=EmbeddingStorageStage,
        config=EmbeddingStorageSchema(),
        min_replicas=0,
        max_replicas=4,
    )
    pipeline.add_stage(
        name="text_splitter",
        stage_actor=TextSplitterStage,
        config=TextSplitterSchema(),
        min_replicas=0,
        max_replicas=4,
    )
    pipeline.add_stage(
        name="image_caption",
        stage_actor=ImageCaptionTransformStage,
        config=ImageCaptionExtractionSchema(**image_caption_config),
        min_replicas=0,
        max_replicas=4,
    )
    pipeline.add_sink(
        name="sink",
        sink_actor=MessageBrokerTaskSinkStage,
        config=sink_config,
        min_replicas=0,
        max_replicas=2,
    )
    logger.info("Added sink stage to pipeline.")

    # Wire the stages together via ThreadedQueueEdge actors.
    ###### INTAKE STAGES ########
    pipeline.make_edge("source", "metadata_injection", queue_size=16)
    # pipeline.make_edge("job_counter", "metadata_injection", queue_size=16)
    pipeline.make_edge("metadata_injection", "pdf_extractor", queue_size=128)  # to limit memory pressure

    ###### Document Extractors ########
    pipeline.make_edge("pdf_extractor", "audio_extractor", queue_size=16)
    pipeline.make_edge("audio_extractor", "docx_extractor", queue_size=16)
    pipeline.make_edge("docx_extractor", "image_extractor", queue_size=16)
    pipeline.make_edge("image_extractor", "table_extractor", queue_size=16)

    ###### Primitive Extractors ########
    pipeline.make_edge("table_extractor", "chart_extractor", queue_size=16)
    pipeline.make_edge("chart_extractor", "image_filter", queue_size=16)

    ###### Primitive Mutators ########
    pipeline.make_edge("image_filter", "image_dedup", queue_size=16)
    pipeline.make_edge("image_dedup", "text_splitter", queue_size=16)

    ###### Primitive Transforms ########
    pipeline.make_edge("text_splitter", "text_embedding", queue_size=16)
    # NOTE(review): image_caption runs after text_embedding, so caption text may not be
    # embedded. Confirm this ordering is intended.
    pipeline.make_edge("text_embedding", "image_caption", queue_size=16)
    pipeline.make_edge("image_caption", "image_storage", queue_size=16)

    ###### Primitive Storage ########
    pipeline.make_edge("image_storage", "embedding_storage", queue_size=16)
    pipeline.make_edge("embedding_storage", "sink", queue_size=16)
    logger.info("Completed wiring of pipeline edges.")

    # Build the pipeline (this instantiates actors and wires edges).
    logger.info("Building pipeline...")
    pipeline.build()
    logger.info("Pipeline build complete.")

    # Optionally, visualize the pipeline graph.
    # pipeline.visualize(mode="text", verbose=True, max_width=120)

    # Start the pipeline.
    logger.info("Starting pipeline...")
    pipeline.start()
    logger.info("Pipeline started successfully.")

    try:
        # Keep the driver process alive until interrupted.
        while True:
            time.sleep(5)
    except KeyboardInterrupt:
        logger.info("Interrupt received, shutting down pipeline.")
    finally:
        # Runs on Ctrl-C and on any other exception, so teardown always happens.
        pipeline.stop()
        ray.shutdown()
        logger.info("Ray shutdown complete.")
        # NOTE(review): assumes start_simple_message_broker returns a process handle
        # exposing terminate() (e.g. multiprocessing.Process). Confirm against its definition.
        stop_broker = getattr(broker_process, "terminate", None)
        if callable(stop_broker):
            stop_broker()
            logger.info("SimpleMessageBroker server stopped.")