Source code for nv_ingest_client.primitives.tasks.embed

# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=too-few-public-methods
# pylint: disable=too-many-arguments

import logging
from typing import Dict, Any, Type
from typing import Optional

from pydantic import BaseModel, ConfigDict, model_validator

from .task_base import Task

logger = logging.getLogger(__name__)


class EmbedTaskSchema(BaseModel):
    """
    Schema for embed task configuration.

    This schema contains configuration details for an embedding task, including
    the endpoint URL, model name, API key, and error filtering flag. Unknown
    fields are rejected (``extra="forbid"``), except for the deprecated
    ``text`` and ``tables`` keys, which are dropped with a warning before
    validation.

    Attributes
    ----------
    endpoint_url : Optional[str]
        URL of the embedding endpoint. Default is None.
    model_name : Optional[str]
        Name of the embedding model. Default is None.
    api_key : Optional[str]
        API key for authentication with the embedding service. Default is None.
    filter_errors : bool
        Flag to indicate whether errors should be filtered. Default is False.
    """

    endpoint_url: Optional[str] = None
    model_name: Optional[str] = None
    api_key: Optional[str] = None
    filter_errors: bool = False

    @model_validator(mode="before")
    @classmethod
    def handle_deprecated_fields(cls: Type["EmbedTaskSchema"], values: Any) -> Any:
        """
        Handle deprecated fields before model validation.

        This validator checks for the presence of deprecated keys ('text' and
        'tables') in the input dictionary and removes them. A warning is logged
        for each deprecated key that is found. The caller's dictionary is not
        modified; a shallow copy is returned instead.

        Parameters
        ----------
        values : Any
            Raw input passed to model validation. Usually a dictionary of model
            values, but a "before" validator may receive other objects.

        Returns
        -------
        Any
            A copy of the input dictionary with deprecated fields removed, or
            the input unchanged if it is not a dictionary.
        """
        # "before" validators see the raw input, which need not be a dict.
        # Leave anything else untouched so pydantic reports a proper
        # ValidationError instead of failing here with a TypeError.
        if not isinstance(values, dict):
            return values

        # Copy so that stripping deprecated keys never mutates the caller's dict.
        values = dict(values)

        if "text" in values:
            logger.warning(
                "'text' parameter is deprecated and will be ignored. Future versions will remove this argument."
            )
            values.pop("text")

        if "tables" in values:
            logger.warning(
                "'tables' parameter is deprecated and will be ignored. Future versions will remove this argument."
            )
            values.pop("tables")

        return values

    # Reject any field not declared above (the deprecated keys are removed
    # by the validator before this check applies).
    model_config = ConfigDict(extra="forbid")
class EmbedTask(Task):
    """
    Client-side description of a document embedding task.

    Stores the embedding endpoint URL, model name, API key, and the
    error-filtering flag. It can render itself for display (``__str__``) and
    serialize itself for submission (``to_dict``).
    """

    def __init__(
        self,
        endpoint_url: Optional[str] = None,
        model_name: Optional[str] = None,
        api_key: Optional[str] = None,
        text: Optional[bool] = None,
        tables: Optional[bool] = None,
        filter_errors: bool = False,
    ) -> None:
        """
        Set up the embedding task configuration.

        Parameters
        ----------
        endpoint_url : Optional[str], optional
            URL of the embedding endpoint. Defaults to None.
        model_name : Optional[str], optional
            Name of the embedding model. Defaults to None.
        api_key : Optional[str], optional
            API key for the embedding service. Defaults to None.
        text : Optional[bool], optional
            Deprecated and ignored; a warning is logged when it is supplied.
        tables : Optional[bool], optional
            Deprecated and ignored; a warning is logged when it is supplied.
        filter_errors : bool, optional
            Whether errors should be filtered. Defaults to False.
        """
        super().__init__()

        # Deprecated flags are accepted only for backward compatibility:
        # each one that was passed (i.e. is not None) logs a warning and is
        # otherwise discarded.
        deprecated_args = (
            (text, "'text' parameter is deprecated and will be ignored. Future versions will remove this argument."),
            (tables, "'tables' parameter is deprecated and will be ignored. Future versions will remove this argument."),
        )
        for supplied, message in deprecated_args:
            if supplied is not None:
                logger.warning(message)

        self._endpoint_url: Optional[str] = endpoint_url
        self._model_name: Optional[str] = model_name
        self._api_key: Optional[str] = api_key
        self._filter_errors: bool = filter_errors

    def __str__(self) -> str:
        """
        Describe the task in a human-readable, multi-line form.

        Only truthy endpoint URL / model name values are listed. The API key
        is never printed; it appears as "[redacted]" when set. The
        filter_errors flag is always listed.

        Returns
        -------
        str
            The formatted task description, one setting per line.
        """
        parts = ["Embed Task:\n"]
        if self._endpoint_url:
            parts.append(f" endpoint_url: {self._endpoint_url}\n")
        if self._model_name:
            parts.append(f" model_name: {self._model_name}\n")
        if self._api_key:
            parts.append(" api_key: [redacted]\n")
        parts.append(f" filter_errors: {self._filter_errors}\n")
        return "".join(parts)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize the task into the dictionary shape used for submission.

        ``filter_errors`` is always included. ``endpoint_url``, ``model_name``
        and ``api_key`` are added (in that order) only when they are truthy.

        Returns
        -------
        Dict[str, Any]
            ``{"type": "embed", "task_properties": {...}}``, suitable for
            submission (e.g., to a Redis database).
        """
        optional_properties = (
            ("endpoint_url", self._endpoint_url),
            ("model_name", self._model_name),
            ("api_key", self._api_key),
        )
        task_properties: Dict[str, Any] = {"filter_errors": self._filter_errors}
        # Skip unset (None) or empty values so they are omitted from the payload.
        task_properties.update((key, value) for key, value in optional_properties if value)
        return {"type": "embed", "task_properties": task_properties}