# Source code for nv_ingest_client.primitives.tasks.embed
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=too-few-public-methods
# pylint: disable=too-many-arguments
import logging
from typing import Any
from typing import Dict
from typing import Literal
from typing import Optional
from typing import Type
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import model_validator
from .task_base import Task
logger = logging.getLogger(__name__)
# [docs]
class EmbedTaskSchema(BaseModel):
    """
    Schema for embed task configuration.

    This schema contains configuration details for an embedding task,
    including the endpoint URL, model name, API key, error filtering flag,
    and the per-element-type modality settings.

    Unknown keys are rejected (``extra="forbid"``), except for the deprecated
    ``text`` and ``tables`` keys, which are dropped with a warning before
    validation.

    Attributes
    ----------
    endpoint_url : Optional[str]
        URL of the embedding endpoint. Default is None.
    model_name : Optional[str]
        Name of the embedding model. Default is None.
    api_key : Optional[str]
        API key for authentication with the embedding service. Default is None.
    filter_errors : bool
        Flag to indicate whether errors should be filtered. Default is False.
    text_elements_modality : Optional[Literal["text", "image", "text_image"]]
        Modality setting for text elements. Default is None.
    image_elements_modality : Optional[Literal["text", "image", "text_image"]]
        Modality setting for image elements. Default is None.
    structured_elements_modality : Optional[Literal["text", "image", "text_image"]]
        Modality setting for structured elements. Default is None.
    audio_elements_modality : Optional[Literal["text"]]
        Modality setting for audio elements; only "text" is accepted. Default is None.
    """

    # extra="forbid" rejects unknown keys. protected_namespaces=() clears
    # pydantic's protected-prefix list so the `model_name` field can be
    # declared without a namespace-conflict warning.
    model_config = ConfigDict(extra="forbid", protected_namespaces=())

    endpoint_url: Optional[str] = None
    model_name: Optional[str] = None
    api_key: Optional[str] = None
    filter_errors: bool = False
    text_elements_modality: Optional[Literal["text", "image", "text_image"]] = None
    image_elements_modality: Optional[Literal["text", "image", "text_image"]] = None
    structured_elements_modality: Optional[Literal["text", "image", "text_image"]] = None
    audio_elements_modality: Optional[Literal["text"]] = None

    @model_validator(mode="before")
    @classmethod
    def handle_deprecated_fields(cls: Type["EmbedTaskSchema"], values: Any) -> Any:
        """
        Handle deprecated fields before model validation.

        This validator checks for the presence of deprecated keys ('text' and 'tables')
        in the input dictionary and removes them. A warning is logged for each key found.
        The caller's dictionary is not modified; a cleaned shallow copy is returned.

        Parameters
        ----------
        values : Any
            Raw input to the model. Normally a dictionary of model values; any
            other kind of input is passed through untouched so that pydantic
            can validate (or reject) it itself.

        Returns
        -------
        Any
            A copy of the input dictionary with deprecated fields removed, or
            the original object when it is not a dictionary.
        """
        # A "before" validator receives the raw input, which is not guaranteed
        # to be a dict; only dict input can carry the deprecated keys.
        if not isinstance(values, dict):
            return values

        # Work on a shallow copy so the dictionary handed in by the caller
        # (e.g. via model_validate) is left intact.
        cleaned: Dict[str, Any] = dict(values)
        for deprecated_key in ("text", "tables"):
            if deprecated_key in cleaned:
                logger.warning(
                    "'%s' parameter is deprecated and will be ignored. Future versions will remove this argument.",
                    deprecated_key,
                )
                del cleaned[deprecated_key]
        return cleaned
# [docs]
class EmbedTask(Task):
    """
    Object for document embedding tasks.

    This class encapsulates the configuration for an embedding task,
    including details like the endpoint URL, model name, API key, and the
    per-element-type modality settings.
    """

    def __init__(
        self,
        endpoint_url: Optional[str] = None,
        model_name: Optional[str] = None,
        api_key: Optional[str] = None,
        text: Optional[bool] = None,
        tables: Optional[bool] = None,
        filter_errors: bool = False,
        text_elements_modality: Optional[str] = None,
        image_elements_modality: Optional[str] = None,
        structured_elements_modality: Optional[str] = None,
        audio_elements_modality: Optional[str] = None,
    ) -> None:
        """
        Initialize the EmbedTask configuration.

        Parameters
        ----------
        endpoint_url : Optional[str], optional
            URL of the embedding endpoint. Defaults to None.
        model_name : Optional[str], optional
            Name of the embedding model. Defaults to None.
        api_key : Optional[str], optional
            API key for the embedding service. Defaults to None.
        text : Optional[bool], optional
            Deprecated. This parameter is ignored if provided.
        tables : Optional[bool], optional
            Deprecated. This parameter is ignored if provided.
        filter_errors : bool, optional
            Flag indicating whether errors should be filtered. Defaults to False.
        text_elements_modality : Optional[str], optional
            Modality setting for text elements. Stored as given. Defaults to None.
        image_elements_modality : Optional[str], optional
            Modality setting for image elements. Stored as given. Defaults to None.
        structured_elements_modality : Optional[str], optional
            Modality setting for structured elements. Stored as given. Defaults to None.
        audio_elements_modality : Optional[str], optional
            Modality setting for audio elements. Stored as given. Defaults to None.
        """
        super().__init__()

        # The deprecated flags are accepted for backward compatibility only:
        # they trigger a warning and are otherwise discarded.
        for deprecated_name, deprecated_value in (("text", text), ("tables", tables)):
            if deprecated_value is not None:
                logger.warning(
                    "'%s' parameter is deprecated and will be ignored. Future versions will remove this argument.",
                    deprecated_name,
                )

        self._endpoint_url: Optional[str] = endpoint_url
        self._model_name: Optional[str] = model_name
        self._api_key: Optional[str] = api_key
        self._filter_errors: bool = filter_errors
        self._text_elements_modality: Optional[str] = text_elements_modality
        self._image_elements_modality: Optional[str] = image_elements_modality
        self._structured_elements_modality: Optional[str] = structured_elements_modality
        self._audio_elements_modality: Optional[str] = audio_elements_modality

    def _set_modalities(self) -> Dict[str, str]:
        """Return the modality settings that are set (truthy), keyed by property name."""
        candidates = (
            ("text_elements_modality", self._text_elements_modality),
            ("image_elements_modality", self._image_elements_modality),
            ("structured_elements_modality", self._structured_elements_modality),
            ("audio_elements_modality", self._audio_elements_modality),
        )
        return {name: modality for name, modality in candidates if modality}

    def __str__(self) -> str:
        """
        Return the string representation of the EmbedTask.

        The string includes the endpoint URL, model name, a redacted API key,
        the error filtering flag, and any modality settings. Optional values
        that are unset (falsy) are left out; the API key value itself is
        never printed.

        Returns
        -------
        str
            A newline-terminated, multi-line description of the EmbedTask configuration.
        """
        lines = ["Embed Task:"]
        if self._endpoint_url:
            lines.append(f"  endpoint_url: {self._endpoint_url}")
        if self._model_name:
            lines.append(f"  model_name: {self._model_name}")
        if self._api_key:
            # Never echo the secret; only signal that one is configured.
            lines.append("  api_key: [redacted]")
        lines.append(f"  filter_errors: {self._filter_errors}")
        lines.extend(f"  {name}: {modality}" for name, modality in self._set_modalities().items())
        return "\n".join(lines) + "\n"

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the EmbedTask configuration to a dictionary for submission.

        ``filter_errors`` is always included; every other property is included
        only when it is set (truthy). Unlike ``__str__``, the API key is
        emitted unredacted.

        Returns
        -------
        Dict[str, Any]
            A dictionary containing the task type and properties, suitable for submission
            (e.g., to a Redis database).
        """
        task_properties: Dict[str, Any] = {"filter_errors": self._filter_errors}
        if self._endpoint_url:
            task_properties["endpoint_url"] = self._endpoint_url
        if self._model_name:
            task_properties["model_name"] = self._model_name
        if self._api_key:
            task_properties["api_key"] = self._api_key
        task_properties.update(self._set_modalities())
        return {"type": "embed", "task_properties": task_properties}