Source code for nv_ingest_client.primitives.tasks.embed
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=too-few-public-methods
# pylint: disable=too-many-arguments
import logging
from typing import Dict
from typing import Optional
from pydantic import BaseModel, root_validator
from .task_base import Task
logger = logging.getLogger(__name__)
[docs]
class EmbedTaskSchema(BaseModel):
    """
    Validation schema for the options accepted by the embed task.

    The legacy ``text`` and ``tables`` options are still accepted for backward
    compatibility: they are logged as deprecated and removed before field
    validation, so they do not trip ``extra = "forbid"``.
    """

    endpoint_url: Optional[str] = None
    model_name: Optional[str] = None
    api_key: Optional[str] = None
    filter_errors: bool = False

    @root_validator(pre=True)
    def handle_deprecated_fields(cls, values):
        """
        Warn about and strip the deprecated ``text`` / ``tables`` options.

        Runs before field validation (``pre=True``). When a deprecated key is
        present, the input mapping is copied before the key is removed, so the
        object supplied by the caller is never modified.

        Parameters
        ----------
        values : mapping
            Raw input values for the model.

        Returns
        -------
        mapping
            ``values`` unchanged when no deprecated key is present, otherwise a
            copy with ``text`` / ``tables`` removed.
        """
        deprecated = [key for key in ("text", "tables") if key in values]
        if not deprecated:
            return values

        # Copy before popping. Depending on the pydantic version and entry point,
        # the mapping handed to a pre-validator may be the caller's own dict.
        values = dict(values)
        for key in deprecated:
            logger.warning(
                "'%s' parameter is deprecated and will be ignored. Future versions will remove this argument.",
                key,
            )
            values.pop(key)
        return values

    class Config:
        # Reject unknown options instead of silently accepting them.
        extra = "forbid"
[docs]
class EmbedTask(Task):
    """
    Object for document embedding task.

    Stores the caller-supplied embed settings and serializes them into the
    ``{"type": "embed", "task_properties": {...}}`` payload built by
    :meth:`to_dict`.
    """

    def __init__(
        self,
        endpoint_url: Optional[str] = None,
        model_name: Optional[str] = None,
        api_key: Optional[str] = None,
        text: Optional[bool] = None,
        tables: Optional[bool] = None,
        filter_errors: bool = False,
    ) -> None:
        """
        Setup Embed Task Config.

        Parameters
        ----------
        endpoint_url : Optional[str]
            Forwarded as ``endpoint_url``; omitted from the payload when falsy.
        model_name : Optional[str]
            Forwarded as ``model_name``; omitted from the payload when falsy.
        api_key : Optional[str]
            Forwarded as ``api_key``; omitted from the payload when falsy and
            shown only as ``[redacted]`` in ``str()`` output.
        text : Optional[bool]
            Deprecated and ignored; a warning is logged when it is not None.
        tables : Optional[bool]
            Deprecated and ignored; a warning is logged when it is not None.
        filter_errors : bool
            Forwarded as ``filter_errors``; always included in the payload.
        """
        super().__init__()

        # The deprecated options are still accepted so existing callers keep
        # working, but their values are not stored anywhere.
        for name, value in (("text", text), ("tables", tables)):
            if value is not None:
                logger.warning(
                    "'%s' parameter is deprecated and will be ignored. Future versions will remove this argument.",
                    name,
                )

        self._endpoint_url = endpoint_url
        self._model_name = model_name
        self._api_key = api_key
        self._filter_errors = filter_errors

    def __str__(self) -> str:
        """
        Returns a string with the object's config and run time state.

        Only truthy optional settings are listed. The API key value is never
        printed, only a ``[redacted]`` placeholder.
        """
        info = "Embed Task:\n"
        if self._endpoint_url:
            info += f" endpoint_url: {self._endpoint_url}\n"
        if self._model_name:
            info += f" model_name: {self._model_name}\n"
        if self._api_key:
            info += " api_key: [redacted]\n"
        info += f" filter_errors: {self._filter_errors}\n"
        return info

    def to_dict(self) -> Dict:
        """
        Convert to a dict for submission to redis.

        Returns
        -------
        Dict
            ``{"type": "embed", "task_properties": {...}}``. ``filter_errors``
            is always present; ``endpoint_url``, ``model_name`` and ``api_key``
            are included only when truthy.
        """
        task_properties = {
            "filter_errors": self._filter_errors,
        }
        # Falsy values (None or "") are left out so the receiver can use its defaults.
        if self._endpoint_url:
            task_properties["endpoint_url"] = self._endpoint_url
        if self._model_name:
            task_properties["model_name"] = self._model_name
        if self._api_key:
            task_properties["api_key"] = self._api_key
        return {"type": "embed", "task_properties": task_properties}