# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import numpy as np
from nemo_deploy import ITritonDeployable
from nemo_deploy.utils import cast_output, str_ndarray2list
from nemo_export_deploy_common.import_utils import MISSING_TRITON_MSG, MISSING_VLLM_MSG, UnavailableError
try:
from pytriton.decorators import batch, first_value
from pytriton.model_config import Tensor
HAVE_PYTRITON = True
except (ImportError, ModuleNotFoundError):
from unittest.mock import MagicMock
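    # PyTriton is unavailable: stub the decorators so the module can still be imported;
    # __init__ below raises UnavailableError before anything is actually served.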
batch = MagicMock()
first_value = MagicMock()
HAVE_PYTRITON = False
try:
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
HAVE_VLLM = True
except (ImportError, ModuleNotFoundError):
HAVE_VLLM = False
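
# Without vLLM there is nothing to export; __init__ raises UnavailableError in that case.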
class vLLMHFExporter(ITritonDeployable):
"""The Exporter class uses vLLM APIs to convert a HF model to vLLM and makes the class, deployable with Triton server.
Example:
from nemo_export import vLLMHFExporter
from nemo_deploy import DeployPyTriton
exporter = vLLMHFExporter()
exporter.export(model="/path/to/model/")
server = DeployPyTriton(
model=exporter,
triton_model_name='model'
)
server.deploy()
server.serve()
server.stop()
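
        # A minimal LoRA sketch (assumes the base model was exported with enable_lora=True;
        # the adapter name and path below are placeholders, not shipped with this module):
        exporter.export(model="/path/to/model/", enable_lora=True)
        exporter.add_lora_models(lora_model_name="my_adapter", lora_model="/path/to/lora_adapter")
        output = exporter.forward(["Hello"], lora_model_name="my_adapter")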
"""
    def __init__(self):
        """Initializes the exporter and checks that vLLM and PyTriton are available."""
        self.model = None
        self.lora_models = None
if not HAVE_VLLM:
raise UnavailableError(MISSING_VLLM_MSG)
if not HAVE_PYTRITON:
raise UnavailableError(MISSING_TRITON_MSG)
    def export(self, model, enable_lora: bool = False):
        """Exports the HF checkpoint to vLLM and initializes the engine.

        Args:
            model (str): model name or path to the checkpoint.
            enable_lora (bool): enables LoRA adapter support in the vLLM engine. Defaults to False.
        """
self.model = LLM(model=model, enable_lora=enable_lora)
    def add_lora_models(self, lora_model_name, lora_model):
        """Registers a LoRA adapter so it can be selected by name in forward().

        Args:
            lora_model_name (str): name under which the adapter is registered.
            lora_model (str): path to the LoRA adapter checkpoint.
        """
        if self.lora_models is None:
            self.lora_models = {}
        self.lora_models[lora_model_name] = lora_model
    @property
    def get_triton_input(self):
        """Defines the Triton input tensors: a batch of prompts plus optional sampling parameters."""
        inputs = (
            Tensor(name="prompts", shape=(-1,), dtype=bytes),
            Tensor(name="max_output_len", shape=(-1,), dtype=np.int_, optional=True),
            Tensor(name="top_k", shape=(-1,), dtype=np.int_, optional=True),
            Tensor(name="top_p", shape=(-1,), dtype=np.single, optional=True),
            Tensor(name="temperature", shape=(-1,), dtype=np.single, optional=True),
        )
        return inputs

    @property
    def get_triton_output(self):
        """Defines the single Triton output tensor holding the generated texts."""
        outputs = (Tensor(name="outputs", shape=(-1,), dtype=bytes),)
        return outputs
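
    # PyTriton's @batch gathers incoming requests into batched numpy arrays, and
    # @first_value passes only the first element of the listed scalar parameters.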
@batch
@first_value("max_output_len", "top_k", "top_p", "temperature")
    def triton_infer_fn(self, **inputs: np.ndarray):  # pragma: no cover
        """Triton inference callback: decodes the request tensors, runs forward(), and returns encoded outputs."""
try:
infer_input = {"input_texts": str_ndarray2list(inputs.pop("prompts"))}
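            # Forward only the optional sampling parameters that the client actually supplied.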
if "max_output_len" in inputs:
infer_input["max_output_len"] = inputs.pop("max_output_len")
if "top_k" in inputs:
infer_input["top_k"] = inputs.pop("top_k")
if "top_p" in inputs:
infer_input["top_p"] = inputs.pop("top_p")
if "temperature" in inputs:
infer_input["temperature"] = inputs.pop("temperature")
output_texts = self.forward(**infer_input)
output = cast_output(output_texts, np.bytes_)
except Exception as error:
err_msg = "An error occurred: {0}".format(str(error))
output = cast_output([err_msg], np.bytes_)
return {"outputs": output}
    def forward(
        self,
        input_texts: List[str],
        max_output_len: int = 64,
        top_k: int = 1,
        top_p: float = 0.1,
        temperature: float = 1.0,
        lora_model_name: Optional[str] = None,
    ):
        """Runs generation for the given prompts and returns one output text per prompt.

        Args:
            input_texts (List[str]): prompts to generate from.
            max_output_len (int): maximum number of tokens to generate.
            top_k (int): top-k sampling parameter.
            top_p (float): top-p (nucleus) sampling parameter.
            temperature (float): sampling temperature.
            lora_model_name (Optional[str]): name of a LoRA adapter registered via add_lora_models().
        """
        assert self.model is not None, "Model is not initialized."
        lora_request = None
        if lora_model_name is not None:
            if self.lora_models is None:
                raise Exception("No LoRA models are available.")
            assert lora_model_name in self.lora_models, "LoRA model was not added before."
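            # LoRARequest takes an adapter name, a unique integer id (fixed to 1 here), and
            # the adapter path registered via add_lora_models().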
lora_request = LoRARequest(lora_model_name, 1, self.lora_models[lora_model_name])
sampling_params = SamplingParams(
max_tokens=max_output_len,
temperature=temperature,
top_k=int(top_k),
top_p=top_p,
)
request_output = self.model.generate(input_texts, sampling_params, lora_request=lora_request)
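        # vLLM returns one RequestOutput per prompt; keep only the first generated candidate for each.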
output = []
for o in request_output:
output.append(o.outputs[0].text)
return output