Source code for nemo_deploy.service.fastapi_interface_to_pytriton

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import numpy as np
import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, model_validator
from pydantic_settings import BaseSettings

from nemo_deploy.nlp import NemoQueryLLMPyTorch

try:
    from nemo.utils import logging
except (ImportError, ModuleNotFoundError):
    import logging

    logging = logging.getLogger(__name__)


class TritonSettings(BaseSettings):
    """TritonSettings class that gets the values of TRITON_HTTP_ADDRESS and TRITON_PORT."""

    _triton_service_port: int
    _triton_service_ip: str

    def __init__(self):
        super(TritonSettings, self).__init__()
        try:
            self._triton_service_port = int(os.environ.get("TRITON_PORT", 8000))
            self._triton_service_ip = os.environ.get("TRITON_HTTP_ADDRESS", "0.0.0.0")
        except Exception as error:
            logging.error(
                "An exception occurred trying to retrieve set args in TritonSettings class. Error: %s",
                error,
            )
            return

    @property
    def triton_service_port(self):
        """Returns the port number for the Triton service."""
        return self._triton_service_port

    @property
    def triton_service_ip(self):
        """Returns the IP address for the Triton service."""
        return self._triton_service_ip

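# Illustrative environment configuration read by TritonSettings; the values below simply match
# the defaults used in the class above and are placeholders for a real deployment:
#   export TRITON_HTTP_ADDRESS=0.0.0.0
#   export TRITON_PORT=8000
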
app = FastAPI()
triton_settings = TritonSettings()

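# Illustrative way to serve this app (assumption: uvicorn is installed and the module is
# importable under this path; host and port are placeholders):
#   uvicorn nemo_deploy.service.fastapi_interface_to_pytriton:app --host 0.0.0.0 --port 8080
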
class BaseRequest(BaseModel):
    """Common parameters for completions and chat requests for the server.

    Attributes:
        model (str): The name of the model to use for completion.
        max_tokens (int): The maximum number of tokens to generate in the response.
        temperature (float): Sampling temperature for randomness in generation.
        top_p (float): Cumulative probability for nucleus sampling.
        top_k (int): Number of highest-probability tokens to consider for sampling.
    """

    model: str
    max_tokens: int = 512
    temperature: float = 1.0
    top_p: float = 0.0
    top_k: int = 0

    @model_validator(mode="after")
    def set_greedy_params(self):
        """Validate parameters for greedy decoding."""
        if self.temperature == 0 and self.top_p == 0:
            logging.warning("Both temperature and top_p are 0. Setting top_k to 1 to ensure greedy sampling.")
            self.top_k = 1
        return self

class CompletionRequest(BaseRequest):
    """Represents a request for text completion.

    Attributes:
        prompt (str): The input text to generate a response from.
        logprobs (int): Number of log probabilities to include in the response, if applicable.
        echo (bool): Whether to return the input text as part of the response.
    """

    prompt: str
    logprobs: int = None
    echo: bool = False

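# Illustrative check of the greedy-decoding validator defined on BaseRequest: when both
# temperature and top_p are 0, top_k is forced to 1 after validation. The model name and prompt
# below are placeholders.
#   req = CompletionRequest(model="my-model", prompt="Hello", temperature=0.0, top_p=0.0)
#   assert req.top_k == 1
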
class ChatCompletionRequest(BaseRequest):
    """Represents a request for chat completion.

    Attributes:
        messages (list[dict]): A list of message dictionaries for chat completion.
        logprobs (bool): Whether to return log probabilities for output tokens.
        top_logprobs (int): Number of log probabilities to include in the response, if applicable.
            logprobs must be set to true if this parameter is used.
    """

    messages: list[dict]

@app.get("/v1/health")
def health_check():
    """Health check endpoint to verify that the API is running.

    Returns:
        dict: A dictionary indicating the status of the application.
    """
    return {"status": "ok"}

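# Illustrative health probe (host and port are placeholders for wherever this app is served):
#   curl http://localhost:8080/v1/health
#   -> {"status": "ok"}
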
@app.get("/v1/triton_health")
async def check_triton_health():
    """This method exposes endpoint "/v1/triton_health".

    This can be used to verify if the Triton server is accessible while running the REST or
    FastAPI application. Verify by running:
        curl http://service_http_address:service_port/v1/triton_health
    and the returned status should inform if the server is accessible.
    """
    triton_url = (
        f"http://{triton_settings.triton_service_ip}:{str(triton_settings.triton_service_port)}/v2/health/ready"
    )
    logging.info(f"Attempting to connect to Triton server at: {triton_url}")
    try:
        response = requests.get(triton_url, timeout=5)
        if response.status_code == 200:
            return {"status": "Triton server is reachable and ready"}
        else:
            raise HTTPException(status_code=503, detail="Triton server is not ready")
    except requests.RequestException as e:
        raise HTTPException(status_code=503, detail=f"Cannot reach Triton server: {str(e)}")

def convert_numpy(obj):
    """Convert NumPy arrays in output to lists."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, dict):
        return {k: convert_numpy(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [convert_numpy(i) for i in obj]
    else:
        return obj

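# Illustrative behaviour of convert_numpy on a nested structure: arrays are converted
# recursively, everything else passes through unchanged.
#   convert_numpy({"a": np.array([1, 2]), "b": [np.array([3.0])]})
#   -> {"a": [1, 2], "b": [[3.0]]}
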
def _helper_fun(
    url,
    model,
    prompts,
    temperature,
    top_k,
    top_p,
    compute_logprob,
    max_length,
    apply_chat_template,
    n_top_logprobs,
    echo,
):
    """run_in_executor doesn't allow passing kwargs, so this helper function passes the args positionally."""
    nq = NemoQueryLLMPyTorch(url=url, model_name=model)
    output = nq.query_llm(
        prompts=prompts,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        compute_logprob=compute_logprob,
        max_length=max_length,
        apply_chat_template=apply_chat_template,
        n_top_logprobs=n_top_logprobs,
        init_timeout=300,
        echo=echo,
    )
    return output

async def query_llm_async(
    *,
    url,
    model,
    prompts,
    temperature,
    top_k,
    top_p,
    compute_logprob,
    max_length,
    apply_chat_template,
    n_top_logprobs,
    echo,
):
    """Sends requests to `NemoQueryLLMPyTorch.query_llm` in a non-blocking way.

    This allows the server to process concurrent requests and enables batching of requests in the
    underlying Triton server.
    """
    import asyncio
    import concurrent.futures

    loop = asyncio.get_event_loop()
    with concurrent.futures.ThreadPoolExecutor() as pool:
        result = await loop.run_in_executor(
            pool,
            _helper_fun,
            url,
            model,
            prompts,
            temperature,
            top_k,
            top_p,
            compute_logprob,
            max_length,
            apply_chat_template,
            n_top_logprobs,
            echo,
        )
        return result

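# Illustrative standalone use of the async wrapper; inside the FastAPI endpoints below it is
# simply awaited. All values here are placeholders.
#   output = await query_llm_async(
#       url="http://0.0.0.0:8000", model="my-model", prompts=["Hello"],
#       temperature=1.0, top_k=0, top_p=0.0, compute_logprob=False,
#       max_length=32, apply_chat_template=False, n_top_logprobs=None, echo=False,
#   )
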
@app.post("/v1/completions/")
async def completions_v1(request: CompletionRequest):
    """Defines the completions endpoint and queries the model deployed on PyTriton server."""
    url = f"http://{triton_settings.triton_service_ip}:{triton_settings.triton_service_port}"
    logging.info(f"Request: {request}")
    prompts = request.prompt
    if not isinstance(request.prompt, list):
        prompts = [request.prompt]
    output = await query_llm_async(
        url=url,
        model=request.model,
        prompts=prompts,
        temperature=request.temperature,
        top_k=request.top_k,
        top_p=request.top_p,
        compute_logprob=(request.logprobs is not None and request.logprobs > 0),
        max_length=request.max_tokens,
        apply_chat_template=False,
        n_top_logprobs=request.logprobs,
        echo=request.echo,
    )
    output_serializable = convert_numpy(output)
    output_serializable["choices"][0]["text"] = output_serializable["choices"][0]["text"][0][0]
    if request.logprobs is not None and request.logprobs > 0:
        output_serializable["choices"][0]["logprobs"]["token_logprobs"] = output_serializable["choices"][0]["logprobs"][
            "token_logprobs"
        ][0]
        output_serializable["choices"][0]["logprobs"]["top_logprobs"] = output_serializable["choices"][0]["logprobs"][
            "top_logprobs"
        ][0]
        if request.echo:
            # output format requires empty logprobs for the 1st token
            output_serializable["choices"][0]["logprobs"]["token_logprobs"].insert(0, None)
    else:
        output_serializable["choices"][0]["logprobs"] = None
    logging.info(f"Output: {output_serializable}")
    return output_serializable

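# Illustrative completions request (host, port, and model name are placeholders):
#   curl http://localhost:8080/v1/completions/ \
#     -H "Content-Type: application/json" \
#     -d '{"model": "my-model", "prompt": "Hello, world", "max_tokens": 16, "temperature": 0.0, "top_p": 0.0}'
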
def dict_to_str(messages):
    """Serializes dict to str."""
    return json.dumps(messages)

@app.post("/v1/chat/completions/")
async def chat_completions_v1(request: ChatCompletionRequest):
    """Defines the chat completions endpoint and queries the model deployed on PyTriton server."""
    url = f"http://{triton_settings.triton_service_ip}:{triton_settings.triton_service_port}"
    logging.info(f"Request: {request}")
    prompts = request.messages
    if not isinstance(request.messages, list):
        prompts = [request.messages]
    # Serialize the dictionary to a JSON string representation to be able to convert to a numpy array
    # (str_list2numpy) and back to a list (str_ndarray2list) as required by PyTriton. Using the dictionaries
    # directly with these methods is not possible as they expect string type.
    json_prompts = [dict_to_str(prompts)]
    output = await query_llm_async(
        url=url,
        model=request.model,
        prompts=json_prompts,
        temperature=request.temperature,
        top_k=request.top_k,
        top_p=request.top_p,
        compute_logprob=False,  # disable logprobs because we don't need them for any benchmark
        max_length=request.max_tokens,
        apply_chat_template=True,
        n_top_logprobs=None,
        echo=False,  # chat request doesn't support echo
    )
    # Add 'role' as 'assistant' key to the output dict
    output["choices"][0]["message"] = {
        "role": "assistant",
        "content": output["choices"][0]["text"],
    }
    output["object"] = "chat.completion"
    output["choices"][0]["logprobs"] = None
    del output["choices"][0]["text"]
    output_serializable = convert_numpy(output)
    output_serializable["choices"][0]["message"]["content"] = output_serializable["choices"][0]["message"]["content"][
        0
    ][0]
    logging.info(f"Output: {output_serializable}")
    return output_serializable

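# Illustrative chat completions request (host, port, and model name are placeholders):
#   curl http://localhost:8080/v1/chat/completions/ \
#     -H "Content-Type: application/json" \
#     -d '{"model": "my-model", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 16}'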