Source code for nv_ingest_client.primitives.tasks.caption

# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0


# pylint: disable=too-few-public-methods
# pylint: disable=too-many-arguments

import logging
from typing import Dict
from typing import Optional

from pydantic import ConfigDict, BaseModel

from .task_base import Task

logger = logging.getLogger(__name__)


class CaptionTaskSchema(BaseModel):
    # Validation schema for the options of a caption task.
    #
    # Every field is an optional string defaulting to None. Unknown fields
    # are rejected (extra="forbid").
    #
    # protected_namespaces is emptied, presumably so that the `model_name`
    # field does not collide with pydantic's protected "model_" prefix
    # -- NOTE(review): confirm against the pydantic version in use.
    model_config = ConfigDict(extra="forbid", protected_namespaces=())

    api_key: Optional[str] = None
    endpoint_url: Optional[str] = None
    prompt: Optional[str] = None
    model_name: Optional[str] = None
class CaptionTask(Task):
    """
    Client-side description of an image caption task.

    Stores four optional settings (API key, endpoint URL, prompt and model
    name) and can render them either as a human-readable summary
    (``__str__``) or as a plain dictionary payload (``to_dict``).
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        endpoint_url: Optional[str] = None,
        prompt: Optional[str] = None,
        model_name: Optional[str] = None,
    ) -> None:
        """
        Store the caption settings on the instance.

        Parameters
        ----------
        api_key : Optional[str]
            API key for the captioning endpoint; never printed by ``__str__``.
        endpoint_url : Optional[str]
            URL of the captioning endpoint.
        prompt : Optional[str]
            Prompt text for the captioning model.
        model_name : Optional[str]
            Name of the captioning model.
        """
        super().__init__()

        self._api_key = api_key
        self._endpoint_url = endpoint_url
        self._prompt = prompt
        self._model_name = model_name

    def __str__(self) -> str:
        """
        Returns a string with the object's config and run time state
        """
        parts = ["Image Caption Task:\n"]

        if self._api_key:
            # The key itself is deliberately not echoed.
            parts.append(" api_key: [redacted]\n")
        if self._endpoint_url:
            parts.append(f" endpoint_url: {self._endpoint_url}\n")
        if self._prompt:
            parts.append(f" prompt: {self._prompt}\n")
        if self._model_name:
            parts.append(f" model_name: {self._model_name}\n")

        return "".join(parts)

    def to_dict(self) -> Dict:
        """
        Convert to a dict for submission to redis

        Returns
        -------
        Dict
            ``{"type": "caption", "task_properties": {...}}`` where
            ``task_properties`` holds only the settings with a truthy value
            (``None`` and empty strings are left out).
        """
        # Insertion order here fixes the key order of the emitted properties.
        candidates = {
            "api_key": self._api_key,
            "endpoint_url": self._endpoint_url,
            "prompt": self._prompt,
            "model_name": self._model_name,
        }
        task_properties = {key: value for key, value in candidates.items() if value}

        return {"type": "caption", "task_properties": task_properties}