Source code for nemo_deploy.deploy_pytriton

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging

from nemo_deploy.deploy_base import DeployBase
from nemo_export_deploy_common.import_utils import MISSING_TRITON_MSG, UnavailableError

LOGGER = logging.getLogger("NeMo")

try:
    from pytriton.model_config import ModelConfig
    from pytriton.triton import Triton, TritonConfig

    HAVE_TRITON = True
except (ImportError, ModuleNotFoundError):
    HAVE_TRITON = False


class DeployPyTriton(DeployBase):
    """Deploys models that implement the ITritonDeployable interface in nemo_deploy to Triton Inference Server.

    Example:
        from nemo_deploy import DeployPyTriton, NemoQueryLLM
        from nemo_export.tensorrt_llm import TensorRTLLM

        trt_llm_exporter = TensorRTLLM(model_dir="/path/for/model/files")
        trt_llm_exporter.export(
            nemo_checkpoint_path="/path/for/nemo/checkpoint",
            model_type="llama",
            tensor_parallelism_size=1,
        )

        nm = DeployPyTriton(model=trt_llm_exporter, triton_model_name="model_name", http_port=8000)
        nm.deploy()
        nm.run()
        nq = NemoQueryLLM(url="localhost", model_name="model_name")

        prompts = ["hello, testing GPT inference", "another GPT inference test?"]
        output = nq.query_llm(prompts=prompts, max_output_len=100)
        print("prompts: ", prompts)
        print("")
        print("output: ", output)
        print("")

        prompts = [
            "Give me some info about Paris",
            "Do you think London is a good city to visit?",
            "What do you think about Rome?",
        ]
        output = nq.query_llm(prompts=prompts, max_output_len=250)
        print("prompts: ", prompts)
        print("")
        print("output: ", output)
        print("")
    """

    def __init__(
        self,
        triton_model_name: str,
        triton_model_version: int = 1,
        model=None,
        max_batch_size: int = 128,
        http_port: int = 8000,
        grpc_port: int = 8001,
        address="0.0.0.0",
        allow_grpc=True,
        allow_http=True,
        streaming=False,
        pytriton_log_verbose=0,
    ):
        """A NeMo checkpoint or model is expected for serving on Triton Inference Server.

        Args:
            triton_model_name (str): name for the service.
            triton_model_version (int): version for the service.
            model (ITritonDeployable): a model that implements the ITritonDeployable interface from nemo_deploy.
            max_batch_size (int): maximum batch size.
            http_port (int): HTTP port for the Triton server.
            grpc_port (int): gRPC port for the Triton server.
            address (str): address for the Triton server to bind.
            allow_grpc (bool): enable the gRPC endpoint.
            allow_http (bool): enable the HTTP endpoint.
            streaming (bool): enable streaming (decoupled) responses.
            pytriton_log_verbose (int): verbosity level passed to the PyTriton logger.
        """
        super().__init__(
            triton_model_name=triton_model_name,
            triton_model_version=triton_model_version,
            model=model,
            max_batch_size=max_batch_size,
            http_port=http_port,
            grpc_port=grpc_port,
            address=address,
            allow_grpc=allow_grpc,
            allow_http=allow_http,
            streaming=streaming,
        )

        self.pytriton_log_verbose = pytriton_log_verbose
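    # Illustrative note (a sketch, not part of the module's own docstring): constructing
    # with streaming=True makes deploy() bind the model's triton_infer_fn_streaming with
    # ModelConfig(decoupled=True), so responses can be streamed back in parts. Reusing the
    # exporter from the class docstring example:
    #
    #     nm = DeployPyTriton(model=trt_llm_exporter, triton_model_name="model_name", streaming=True)
    #     nm.deploy()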
    def deploy(self):
        """Deploys the model to Triton Inference Server."""
        if not HAVE_TRITON:
            raise UnavailableError(MISSING_TRITON_MSG)

        try:
            if self.streaming:
                triton_config = TritonConfig(
                    log_verbose=self.pytriton_log_verbose,
                    allow_grpc=self.allow_grpc,
                    allow_http=self.allow_http,
                    grpc_address=self.address,
                )
                self.triton = Triton(config=triton_config)
                self.triton.bind(
                    model_name=self.triton_model_name,
                    model_version=self.triton_model_version,
                    infer_func=self.model.triton_infer_fn_streaming,
                    inputs=self.model.get_triton_input,
                    outputs=self.model.get_triton_output,
                    config=ModelConfig(decoupled=True),
                )
            else:
                triton_config = TritonConfig(
                    http_address=self.address,
                    http_port=self.http_port,
                    grpc_address=self.address,
                    grpc_port=self.grpc_port,
                    allow_grpc=self.allow_grpc,
                    allow_http=self.allow_http,
                )
                self.triton = Triton(config=triton_config)
                self.triton.bind(
                    model_name=self.triton_model_name,
                    model_version=self.triton_model_version,
                    infer_func=self.model.triton_infer_fn,
                    inputs=self.model.get_triton_input,
                    outputs=self.model.get_triton_output,
                    config=ModelConfig(max_batch_size=self.max_batch_size),
                )
        except Exception as e:
            self.triton = None
            LOGGER.error(e)
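    # The bound model supplies the input/output tensor specs and the inference callable
    # used by deploy() above. A minimal sketch of such a deployable (class and tensor names
    # are hypothetical; real deployables such as the TensorRT-LLM exporter define their
    # own schema):
    #
    #     from pytriton.decorators import batch
    #     from pytriton.model_config import Tensor
    #
    #     class EchoDeployable:
    #         @property
    #         def get_triton_input(self):
    #             return (Tensor(name="text", shape=(1,), dtype=bytes),)
    #
    #         @property
    #         def get_triton_output(self):
    #             return (Tensor(name="text_out", shape=(1,), dtype=bytes),)
    #
    #         @batch
    #         def triton_infer_fn(self, **inputs):
    #             # Echo the request back; a real model would run inference here.
    #             return {"text_out": inputs["text"]}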
    def serve(self):
        """Starts serving the model and waits for requests."""
        if self.triton is None:
            raise Exception("deploy should be called first.")

        try:
            self.triton.serve()
        except Exception as e:
            self.triton = None
            LOGGER.error(e)
    def run(self):
        """Starts serving the model asynchronously."""
        if self.triton is None:
            raise Exception("deploy should be called first.")

        self.triton.run()
    def stop(self):
        """Stops serving the model."""
        if self.triton is None:
            raise Exception("deploy should be called first.")

        self.triton.stop()
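A minimal end-to-end sketch of the lifecycle above, assuming a TensorRT-LLM engine has already been exported to the given directory (the paths and model name are placeholders):

    from nemo_deploy import DeployPyTriton
    from nemo_export.tensorrt_llm import TensorRTLLM

    exporter = TensorRTLLM(model_dir="/path/for/model/files")

    nm = DeployPyTriton(model=exporter, triton_model_name="model_name", http_port=8000)
    nm.deploy()  # build the TritonConfig and bind the model
    nm.run()     # start serving without blocking; use nm.serve() to block and wait for requests
    # ... send queries, e.g. with NemoQueryLLM as shown in the class docstring ...
    nm.stop()    # shut the Triton server down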