# Source code for nv_ingest_client.primitives.tasks.caption

# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0


# pylint: disable=too-few-public-methods
# pylint: disable=too-many-arguments

import logging
from typing import Any, Dict, Optional

from nv_ingest_api.internal.schemas.meta.ingest_job_schema import IngestTaskCaptionSchema

from .task_base import Task

# Logger scoped to this module's dotted import path.
logger = logging.getLogger(name=__name__)


class CaptionTask(Task):
    """
    Client-side task describing an image-captioning request.

    Constructor arguments are validated through ``IngestTaskCaptionSchema``
    and stored privately. ``to_dict`` serializes the ones that are set into a
    ``{"type": "caption", "task_properties": {...}}`` payload.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        endpoint_url: Optional[str] = None,
        prompt: Optional[str] = None,
        system_prompt: Optional[str] = None,
        model_name: Optional[str] = None,
        context_text_max_chars: Optional[int] = None,
        temperature: Optional[float] = None,
    ) -> None:
        """
        Validate and store the caption task configuration.

        Every argument is optional. A value that is unset (see
        ``_populated_properties`` for the exact rule per field) is omitted
        from both ``__str__`` and ``to_dict``.

        Parameters
        ----------
        api_key : str, optional
            API key sent with the task. ``__str__`` never prints its value.
        endpoint_url : str, optional
            Endpoint URL sent with the task.
        prompt : str, optional
            Prompt sent with the task.
        system_prompt : str, optional
            System prompt sent with the task.
        model_name : str, optional
            Model name sent with the task.
        context_text_max_chars : int, optional
            Sent with the task only when truthy, so ``0`` is treated as unset.
        temperature : float, optional
            Sent with the task whenever it is not ``None``, including ``0.0``.

        Raises
        ------
        Exception
            Whatever ``IngestTaskCaptionSchema`` raises for invalid values
            (presumably a pydantic ``ValidationError`` -- confirm against the
            schema definition).
        """
        super().__init__()

        # Use the API schema for validation, then keep the values from the
        # validated model rather than the raw arguments.
        validated_data = IngestTaskCaptionSchema(
            api_key=api_key,
            endpoint_url=endpoint_url,
            prompt=prompt,
            system_prompt=system_prompt,
            model_name=model_name,
            context_text_max_chars=context_text_max_chars,
            temperature=temperature,
        )

        self._api_key = validated_data.api_key
        self._endpoint_url = validated_data.endpoint_url
        self._prompt = validated_data.prompt
        self._system_prompt = validated_data.system_prompt
        self._model_name = validated_data.model_name
        self._context_text_max_chars = validated_data.context_text_max_chars
        self._temperature = validated_data.temperature

    def _populated_properties(self) -> Dict[str, Any]:
        """
        Return the configured properties that count as "set", in submission order.

        This is the single source of truth for which fields ``__str__`` shows
        and ``to_dict`` submits, so the two cannot disagree.

        Returns
        -------
        dict
            Maps the property name to its stored value. String fields and
            ``context_text_max_chars`` are included only when truthy (``""``
            and ``0`` count as unset). ``temperature`` is included whenever it
            is not ``None``, so an explicit ``0.0`` is kept.
        """
        truthy_fields = {
            "api_key": self._api_key,
            "endpoint_url": self._endpoint_url,
            "prompt": self._prompt,
            "system_prompt": self._system_prompt,
            "model_name": self._model_name,
            "context_text_max_chars": self._context_text_max_chars,
        }
        properties = {name: value for name, value in truthy_fields.items() if value}

        # Checked against None (not truthiness) so a temperature of 0.0 survives.
        if self._temperature is not None:
            properties["temperature"] = self._temperature

        return properties

    def __str__(self) -> str:
        """
        Return a human-readable summary of the task's configuration.

        Only set properties are listed, one per line. The API key is shown as
        ``[redacted]`` so it is not leaked into logs or console output.
        """
        info = "Image Caption Task:\n"
        for name, value in self._populated_properties().items():
            shown = "[redacted]" if name == "api_key" else value
            info += f" {name}: {shown}\n"
        return info

    def to_dict(self) -> Dict:
        """
        Convert to a dict for submission to redis.

        Returns
        -------
        dict
            ``{"type": "caption", "task_properties": {...}}``, where
            ``task_properties`` holds only the set properties. This includes the
            real ``api_key`` value when one was provided.
        """
        return {"type": "caption", "task_properties": self._populated_properties()}