Source code for nv_ingest_client.primitives.tasks.task_factory

# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import inspect
from typing import Callable
from typing import Dict
from typing import Type
from typing import Union

from nv_ingest_client.primitives.tasks.task_base import Task, TaskType, is_valid_task_type
from nv_ingest_client.primitives.tasks.caption import CaptionTask
from nv_ingest_client.primitives.tasks.dedup import DedupTask
from nv_ingest_client.primitives.tasks.embed import EmbedTask
from nv_ingest_client.primitives.tasks.extract import ExtractTask
from nv_ingest_client.primitives.tasks.filter import FilterTask
from nv_ingest_client.primitives.tasks.split import SplitTask
from nv_ingest_client.primitives.tasks.store import StoreEmbedTask, StoreTask
from nv_ingest_client.primitives.tasks.udf import UDFTask


class TaskUnimplemented(Task):
    """
    Placeholder for task types that have no implementation yet.

    Instantiating this class always fails: the base ``Task`` initializer runs
    first, after which ``NotImplementedError`` is raised. Any keyword
    arguments passed in are accepted by the signature but never used.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()
        # Refuse to produce a usable object for an unimplemented task type.
        message = "Task type is not implemented"
        raise NotImplementedError(message)
# Dispatch table used by ``task_factory``: each ``TaskType`` member maps to the
# class that is instantiated for it. Entries follow alphabetical order of the
# task type name, except that STORE_EMBEDDING is listed before STORE.
# TRANSFORM has no implementation and maps to the ``TaskUnimplemented``
# placeholder, whose constructor raises ``NotImplementedError``.
_TASK_MAP: Dict[TaskType, Type[Task]] = {
    TaskType.CAPTION: CaptionTask,
    TaskType.DEDUP: DedupTask,
    TaskType.EMBED: EmbedTask,
    TaskType.EXTRACT: ExtractTask,
    TaskType.FILTER: FilterTask,
    TaskType.SPLIT: SplitTask,
    TaskType.STORE_EMBEDDING: StoreEmbedTask,
    TaskType.STORE: StoreTask,
    TaskType.TRANSFORM: TaskUnimplemented,
    TaskType.UDF: UDFTask,
}
def task_factory(task_type: Union[TaskType, str], **kwargs) -> Task:
    """
    Factory method for creating tasks based on the provided task type.

    Parameters
    ----------
    task_type : TaskType or str
        The type of the task to create. Either a ``TaskType`` member, or a
        string accepted by ``is_valid_task_type`` that is then resolved with
        ``TaskType[task_type]`` (i.e. looked up by member name).
    **kwargs : dict
        Keyword arguments forwarded to the task class's constructor. Every
        name must match a keyword-capable parameter of that constructor,
        unless the constructor itself declares ``**kwargs``, in which case
        any keyword is passed through unchecked.

    Returns
    -------
    Task
        An instance of the task corresponding to the given task type.

    Raises
    ------
    ValueError
        If ``task_type`` is neither a ``TaskType`` member nor a valid task type
        string, if no task class is registered for the resolved type, or if a
        keyword argument is not accepted by the task's constructor.
    """
    if isinstance(task_type, str):
        if is_valid_task_type(task_type):
            task_type = TaskType[task_type]
        else:
            raise ValueError(f"Invalid task type string: '{task_type}'")
    elif not isinstance(task_type, TaskType):
        raise ValueError("task_type must be a TaskType enum member or a valid task type string")

    # Surface a missing registry entry as the documented ValueError rather
    # than letting a bare KeyError escape.
    try:
        task_class: Type[Task] = _TASK_MAP[task_type]
    except KeyError:
        raise ValueError(f"No task class registered for task type '{task_type.name}'") from None

    # Inspect the constructor (__init__) of the task class to get its parameters.
    params = inspect.signature(task_class.__init__).parameters.values()

    # A constructor that declares **kwargs accepts arbitrary keywords, so a
    # name-based check would wrongly reject everything it was meant to receive.
    accepts_any_keyword = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params)

    if not accepts_any_keyword:
        # Only parameters that can be supplied by keyword count; 'self' and
        # positional-only parameters are excluded.
        keyword_kinds = (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        valid_kwargs = {p.name for p in params if p.kind in keyword_kinds and p.name != "self"}

        for kwarg in kwargs:
            if kwarg not in valid_kwargs:
                raise ValueError(f"Unexpected keyword argument '{kwarg}' for task type '{task_type.name}'")

    # Create and return the task instance with the provided kwargs.
    return task_class(**kwargs)