Source code for nv_ingest_client.primitives.tasks.task_factory
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import inspect
from typing import Callable
from typing import Dict
from typing import Type
from typing import Union
from .caption import CaptionTask
from .dedup import DedupTask
from .embed import EmbedTask
from .extract import ExtractTask
from .filter import FilterTask
from .split import SplitTask
from .store import StoreEmbedTask
from .store import StoreTask
from .task_base import Task
from .task_base import TaskType
from .task_base import is_valid_task_type
[docs]
class TaskUnimplemented(Task):
    """
    Stand-in task class for task types that have no implementation yet.

    Any attempt to instantiate it fails with ``NotImplementedError``; keyword
    arguments are accepted by the signature but never used.
    """

    def __init__(self, **kwargs) -> None:
        # Run the base initializer first, then refuse construction.
        super().__init__()
        message = "Task type is not implemented"
        raise NotImplementedError(message)
# Mapping of TaskType to Task classes, arranged alphabetically by task class name.
# task_factory looks up the class to instantiate in this table. TRANSFORM is
# mapped to the TaskUnimplemented placeholder, whose constructor raises
# NotImplementedError.
_TASK_MAP: Dict[TaskType, Type[Task]] = {
TaskType.CAPTION: CaptionTask,
TaskType.DEDUP: DedupTask,
TaskType.EMBED: EmbedTask,
TaskType.EXTRACT: ExtractTask,
TaskType.FILTER: FilterTask,
TaskType.SPLIT: SplitTask,
TaskType.STORE_EMBEDDING: StoreEmbedTask,
TaskType.STORE: StoreTask,
TaskType.TRANSFORM: TaskUnimplemented,
}
[docs]
def task_factory(task_type: Union[TaskType, str], **kwargs) -> Task:
    """
    Factory method for creating tasks based on the provided task type.

    Parameters
    ----------
    task_type : Union[TaskType, str]
        The type of the task to create, given either as a ``TaskType`` member
        or as a string that ``is_valid_task_type`` accepts and that is then
        used as a ``TaskType`` member name.
    **kwargs : dict
        Additional keyword arguments to pass to the task's constructor.

    Returns
    -------
    Task
        An instance of the task corresponding to the given task type.

    Raises
    ------
    ValueError
        If an invalid task type is provided, or if a keyword argument is not
        accepted by the constructor of the selected task class.
    KeyError
        If ``task_type`` is a ``TaskType`` member with no entry in ``_TASK_MAP``.
    """
    if isinstance(task_type, str):
        if not is_valid_task_type(task_type):
            raise ValueError(f"Invalid task type string: '{task_type}'")
        # NOTE(review): assumes every string accepted by is_valid_task_type is
        # an exact TaskType member name -- confirm against task_base.
        task_type = TaskType[task_type]
    elif not isinstance(task_type, TaskType):
        raise ValueError("task_type must be a TaskType enum member or a valid task type string")

    task_class: Type[Task] = _TASK_MAP[task_type]

    # Inspect the constructor (__init__) of the task class to get its parameters
    params = inspect.signature(task_class.__init__).parameters

    # A constructor that declares **kwargs accepts arbitrary keyword names, so
    # there is nothing to validate by name; let the constructor decide what to
    # do with them (e.g. a placeholder class raising NotImplementedError).
    accepts_any_keyword = any(param.kind is inspect.Parameter.VAR_KEYWORD for param in params.values())

    if not accepts_any_keyword:
        # Only parameters that can be passed by keyword count; exclude 'self'
        # and positional-only / *args parameters.
        keyword_kinds = (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        valid_kwargs = {name for name, param in params.items() if param.kind in keyword_kinds and name != "self"}

        # Fail early with a message naming the task type, rather than with the
        # constructor's generic TypeError.
        for kwarg in kwargs:
            if kwarg not in valid_kwargs:
                raise ValueError(f"Unexpected keyword argument '{kwarg}' for task type '{task_type.name}'")

    # Create and return the task instance with the provided kwargs
    return task_class(**kwargs)