# Source code for nv_ingest_client.cli.util.click

# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import json
import logging
import os
import random
from enum import Enum
from pprint import pprint

import click
from nv_ingest_client.cli.util.processing import check_schema
from nv_ingest_client.primitives.tasks import CaptionTask
from nv_ingest_client.primitives.tasks import DedupTask
from nv_ingest_client.primitives.tasks import EmbedTask
from nv_ingest_client.primitives.tasks import ExtractTask
from nv_ingest_client.primitives.tasks import FilterTask
from nv_ingest_client.primitives.tasks import SplitTask
from nv_ingest_client.primitives.tasks import StoreEmbedTask
from nv_ingest_client.primitives.tasks import StoreTask
from nv_ingest_client.primitives.tasks.caption import CaptionTaskSchema
from nv_ingest_client.primitives.tasks.dedup import DedupTaskSchema
from nv_ingest_client.primitives.tasks.embed import EmbedTaskSchema
from nv_ingest_client.primitives.tasks.extract import ExtractTaskSchema
from nv_ingest_client.primitives.tasks.filter import FilterTaskSchema
from nv_ingest_client.primitives.tasks.split import SplitTaskSchema
from nv_ingest_client.primitives.tasks.store import StoreEmbedTaskSchema
from nv_ingest_client.primitives.tasks.store import StoreTaskSchema
from nv_ingest_client.util.util import generate_matching_files

logger = logging.getLogger(__name__)


class LogLevel(str, Enum):
    """Permitted logging verbosity levels for the CLI (string-valued enum)."""

    DEBUG = "DEBUG"
    INFO = "INFO"
    WARNING = "WARNING"
    ERROR = "ERROR"
    CRITICAL = "CRITICAL"
class ClientType(str, Enum):
    """Supported client transport types (string-valued enum)."""

    REST = "REST"
    REDIS = "REDIS"
    KAFKA = "KAFKA"
# Example TaskId validation set VALID_TASK_IDS = {"task1", "task2", "task3"} _MODULE_UNDER_TEST = "nv_ingest_client.cli.util.click"
def debug_print_click_options(ctx):
    """
    Pretty-print every option value captured on the Click context.

    Parameters
    ----------
    ctx : click.Context
        The Click context object from which to retrieve the command options.
    """
    # Collect only Option parameters (skip arguments etc.) into a name -> value map.
    option_values = {
        param.name: ctx.params[param.name]
        for param in ctx.command.params
        if isinstance(param, click.Option)
    }
    pprint(option_values)
def click_validate_file_exists(ctx, param, value):
    """
    Click callback: normalize *value* to a list of paths and require each to exist.

    Returns an empty list for a falsy value. A single string becomes a
    one-element list; any other iterable is converted to a list.

    Raises
    ------
    click.BadParameter
        If any of the given paths does not exist on disk.
    """
    if not value:
        return []

    paths = [value] if isinstance(value, str) else list(value)

    # Report the first path (in input order) that is missing.
    missing = [p for p in paths if not os.path.exists(p)]
    if missing:
        raise click.BadParameter(f"File does not exist: {missing[0]}")

    return paths
def click_validate_task(ctx, param, value):
    """
    Click callback: parse and validate ``name[:json-options]`` task strings.

    Each entry is split into a task id and a JSON options payload (defaulting
    to ``{}``), validated against the matching schema, and instantiated into
    its task object. All failures are collected and reported together.

    Returns
    -------
    dict
        Mapping of unique task id -> instantiated task object.

    Raises
    ------
    click.BadParameter
        If any task fails validation, is unsupported, or is duplicated.
    """
    # task id -> (schema class used for validation, task class to instantiate)
    task_registry = {
        "split": (SplitTaskSchema, SplitTask),
        "extract": (ExtractTaskSchema, ExtractTask),
        "store": (StoreTaskSchema, StoreTask),
        "store_embedding": (StoreEmbedTaskSchema, StoreEmbedTask),
        "caption": (CaptionTaskSchema, CaptionTask),
        "dedup": (DedupTaskSchema, DedupTask),
        "filter": (FilterTaskSchema, FilterTask),
        "embed": (EmbedTaskSchema, EmbedTask),
    }

    validated_tasks = {}
    validation_errors = []
    for task_str in value:
        # "name:json" -> split once; a bare "name" gets empty JSON options.
        if ":" in task_str:
            task_id, json_options = task_str.split(":", 1)
        else:
            task_id, json_options = task_str, "{}"
        try:
            options = json.loads(json_options)
            try:
                schema_cls, task_cls = task_registry[task_id]
            except KeyError:
                raise ValueError(f"Unsupported task type: {task_id}")
            task_options = check_schema(schema_cls, options, task_id, json_options)
            # Only "extract" keys by document type so several extract tasks
            # (one per document type) can coexist.
            if task_id == "extract":
                new_task_id = f"{task_id}_{task_options.document_type}"
            else:
                new_task_id = task_id
            if new_task_id in validated_tasks:
                raise ValueError(f"Duplicate task detected: {new_task_id}")
            logger.debug("Adding task: %s", new_task_id)
            validated_tasks[new_task_id] = task_cls(**task_options.model_dump())
        except ValueError as e:
            # JSONDecodeError is a ValueError subclass, so bad JSON lands here too.
            validation_errors.append(str(e))

    if validation_errors:
        # Aggregate error messages with original values highlighted
        raise click.BadParameter("\n".join(validation_errors))

    return validated_tasks
def click_validate_batch_size(ctx, param, value):
    """
    Click callback: ensure the batch size is a positive integer.

    Raises
    ------
    click.BadParameter
        If the value is smaller than 1.
    """
    if value >= 1:
        return value
    raise click.BadParameter("Batch size must be >= 1.")
def pre_process_dataset(dataset_json: str, shuffle_dataset: bool):
    """
    Load the "sampled_files" list from a dataset JSON file, optionally shuffled.

    Parameters
    ----------
    dataset_json : str
        The path to the dataset JSON file.
    shuffle_dataset : bool
        Whether to shuffle the file list in place before returning it.

    Returns
    -------
    list
        The list of files from the dataset, possibly shuffled.

    Raises
    ------
    click.BadParameter
        If the file is missing or does not contain valid JSON.
    """
    try:
        with open(dataset_json, "r") as handle:
            payload = json.load(handle)
    except FileNotFoundError:
        raise click.BadParameter(f"Dataset JSON file not found: {dataset_json}")
    except json.JSONDecodeError:
        raise click.BadParameter(f"Invalid JSON format in file: {dataset_json}")

    # Missing key degrades to an empty dataset rather than an error.
    files = payload.get("sampled_files", [])
    if shuffle_dataset:
        random.shuffle(files)
    return files
def click_match_and_validate_files(ctx, param, value):
    """
    Matches and validates files based on the provided file source patterns.

    Parameters
    ----------
    value : list of str
        A list containing file source patterns to match against.

    Returns
    -------
    list of str
        Matching file paths; an empty list when no patterns were supplied
        or nothing matched (a warning is logged in the latter case).
    """
    if not value:
        return []

    matched = list(generate_matching_files(value))
    if not matched:
        logger.warning("No files found matching the specified patterns.")
    return matched