Deploying NeMo Models in the NeMo Framework Inference Container

The NeMo Framework Inference Container contains modules and scripts that help you export NeMo LLMs to TensorRT-LLM and deploy them to Triton Inference Server through easy-to-use APIs. This section shows how to deploy a NeMo checkpoint with TensorRT-LLM.

  • Supported GPUs:
    • A100

    • H100

  • Supported models and parameter counts (distributed NeMo checkpoint format only) are listed below.

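Model Name   | Model Parameters | NeMo Precision | TensorRT-LLM Precision | Fine Tuning
-------------|------------------|----------------|------------------------|-------------------
GPT          | 2B, 8B, 43B      | bfloat16       | bfloat16               | SFT, RLHF, SteerLM
LLAMA2       | 7B, 13B, 70B     | bfloat16       | bfloat16               | SFT, RLHF, SteerLM
Gemma*       | 2B, 7B           | bfloat16       | bfloat16               | SFT
StarCoder2** | 3B, 7B, 15B      | bfloat16       | bfloat16               | SFT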

*Gemma models are supported in a dedicated container hosted on NGC called nvcr.io/nvidia/nemo:24.01.gemma.

**StarCoder2 models are supported in a dedicated container hosted on NGC called nvcr.io/nvidia/nemo:24.01.starcoder2.
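Because Gemma and StarCoder2 need dedicated containers, it can help to encode the table's container requirements in a small lookup when scripting deployments. The sketch below is a hypothetical helper, not part of NeMo; it assumes that GPT and LLAMA2 run in the general inference container used later in this document, whose tag placeholder vr must be replaced with a real version.

```python
# Hypothetical helper (not a NeMo API): maps a model family from the table
# to the container image that supports it. "vr" is this document's version
# placeholder and must be replaced with a real tag.
GENERAL_IMAGE = "nvcr.io/ea-bignlp/ga-participants/nemofw-inference:vr"

CONTAINER_FOR_MODEL = {
    "gpt": GENERAL_IMAGE,      # assumption: served by the general container
    "llama2": GENERAL_IMAGE,   # assumption: served by the general container
    "gemma": "nvcr.io/nvidia/nemo:24.01.gemma",
    "starcoder2": "nvcr.io/nvidia/nemo:24.01.starcoder2",
}


def container_for(model_family: str) -> str:
    """Return the container image for a model family (case-insensitive)."""
    key = model_family.lower()
    if key not in CONTAINER_FOR_MODEL:
        raise ValueError(f"Unsupported model family: {model_family!r}")
    return CONTAINER_FOR_MODEL[key]


print(container_for("Gemma"))  # nvcr.io/nvidia/nemo:24.01.gemma
```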

The table above lists the supported NeMo and TensorRT-LLM precisions and the supported fine-tuned variants for each model. Note that only NeMo models in the distributed checkpoint format are supported.

The following steps work with all NVIDIA GPT-based models. As an example, we will deploy the GPT-2B bfloat16 checkpoint.

First, run the following command to download a NeMo checkpoint:

wget https://huggingface.co/nvidia/GPT-2B-001/resolve/main/GPT-2B-001_bf16_tp1.nemo

The GPT-2B-001_bf16_tp1.nemo checkpoint contains a trained GPT model and is used as the example throughout this document. Pull and run the container as shown below, replacing vr with the version of the container you want to use:

docker pull nvcr.io/ea-bignlp/ga-participants/nemofw-inference:vr
docker run --gpus all -it --rm --shm-size=4g -p 8000:8000 \
    -v ${PWD}/GPT-2B-001_bf16_tp1.nemo:/opt/checkpoints/GPT-2B.nemo \
    -w /opt/NeMo nvcr.io/ea-bignlp/ga-participants/nemofw-inference:vr

Then run the following pytest to verify that everything works:

py.test -s tests/deploy/test_nemo_deploy.py::test_GPT_2B_001_bf16_tp1_1gpu

This pytest exports the downloaded NeMo checkpoint to TensorRT-LLM, starts serving it on Triton, and runs a simple query to check that the service works.

If the test fails with a shared-memory-related error, increase the shared memory size with the --shm-size option of docker run.

The container supports three main use cases:

  • Run a script that serves an LLM from a given NeMo checkpoint on Triton, with inference running on TensorRT-LLM.

  • Export a NeMo checkpoint to TensorRT-LLM and run inference.

  • Import the export and deploy modules into your own code and use them as needed.

Serve an in-framework or TensorRT-LLM model on Triton

An LLM stored in a NeMo checkpoint can be served on Triton with the following script, which deploys the model for TensorRT-LLM-based inference. When executed, the script first exports the model to TensorRT-LLM and then starts the service on Triton.

Assuming the container has already been started using the steps above, please run the following script to start serving the downloaded model:

python scripts/deploy/deploy_triton.py --nemo_checkpoint /opt/checkpoints/GPT-2B.nemo --model_type="gptnext" --triton_model_name GPT-2B

Parameters of the deploy_triton.py script:

  • nemo_checkpoint - path of the NeMo checkpoint.

  • ptuning_nemo_checkpoint - source .nemo file for the prompt embedding table (p-tuning).

  • model_type - type of the model. choices=[“gptnext”, “llama”].

  • triton_model_name - name of the model on Triton.

  • triton_model_version - version of the model. Default is 1.

  • triton_port - port for the Triton server to listen for requests. Default is 8000.

  • triton_http_address - HTTP address for the Triton server. Default is 0.0.0.0

  • triton_model_repository - folder where the exported TensorRT-LLM model is stored. Default is /tmp/trt_llm_model_dir/.

  • num_gpus - number of GPUs to use for inference. Large models require a multi-GPU export.

  • dtype - data type of the model on TensorRT-LLM. Default is “bf16”. Currently only “bf16” is supported.

  • max_input_len - maximum input length of the model.

  • max_output_len - maximum output length of the model.

  • max_batch_size - maximum batch size of the model.

Although these parameters are generic and should work with any NeMo checkpoint, the code has only been tested with GPT and LLaMA v2 NeMo checkpoints. Support for other checkpoints is in progress.
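When launching the script from automation, it can be convenient to assemble the command from keyword arguments. The sketch below is a hypothetical helper, not part of NeMo. It assumes that every parameter listed above is accepted as a --name flag, following the pattern of the flags shown in the command earlier; arguments left as None are omitted so the script falls back to its defaults.

```python
import shlex


def build_deploy_command(**params):
    """Hypothetical helper: build a deploy_triton.py argument list.

    Assumes each documented parameter maps to a flag of the form --<name>.
    Parameters set to None are skipped so the script uses its own default.
    """
    cmd = ["python", "scripts/deploy/deploy_triton.py"]
    for name, value in params.items():
        if value is None:
            continue
        cmd += [f"--{name}", str(value)]
    return cmd


cmd = build_deploy_command(
    nemo_checkpoint="/opt/checkpoints/GPT-2B.nemo",
    model_type="gptnext",
    triton_model_name="GPT-2B",
    num_gpus=1,
    max_batch_size=None,  # omitted: use the script's default
)
print(shlex.join(cmd))
```

Inside the container, the resulting list can be passed directly to subprocess.run(cmd, check=True); keeping it as a list avoids shell-quoting problems with paths.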

Each time the script runs, it exports the NeMo checkpoint to TensorRT-LLM before starting the service. To be able to skip the export step on later runs, first export the model into a persistent, empty directory as shown below:

mkdir tmp_triton_model_repository
docker run --gpus all -it --rm --shm-size=4g -p 8000:8000 \
    -v ${PWD}:/opt/checkpoints/ -w /opt/NeMo \
    nvcr.io/ea-bignlp/ga-participants/nemofw-inference:vr \
    python scripts/deploy/deploy_triton.py \
        --nemo_checkpoint /opt/checkpoints/GPT-2B-001_bf16_tp1.nemo \
        --model_type="gptnext" \
        --triton_model_name GPT-2B \
        --triton_model_repository /opt/checkpoints/tmp_triton_model_repository

Running the command above exports the checkpoint into the provided folder. To load the already exported model without exporting it again, run the following command in the container:

python scripts/deploy/deploy_triton.py --triton_model_name GPT-2B --triton_model_repository /opt/checkpoints/tmp_triton_model_repository --model_type="gptnext"
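The two deploy_triton.py invocations above differ only in whether --nemo_checkpoint is passed. The following hypothetical helper (not part of NeMo) picks between them. It relies on a simplifying assumption: any file already present in the model repository means a previous export completed, which a real setup may need to check more carefully.

```python
import tempfile
from pathlib import Path


def deploy_args(repo_dir, nemo_checkpoint, model_name="GPT-2B", model_type="gptnext"):
    """Hypothetical helper: export on the first run, only load afterwards.

    Simplifying assumption: a non-empty repo_dir holds a finished export.
    """
    repo = Path(repo_dir)
    args = [
        "python", "scripts/deploy/deploy_triton.py",
        "--triton_model_name", model_name,
        "--triton_model_repository", str(repo),
        "--model_type", model_type,
    ]
    already_exported = repo.is_dir() and any(repo.iterdir())
    if not already_exported:
        # First run: pass the checkpoint so the script exports it into repo_dir.
        args += ["--nemo_checkpoint", nemo_checkpoint]
    return args


with tempfile.TemporaryDirectory() as repo:
    first = deploy_args(repo, "/opt/checkpoints/GPT-2B-001_bf16_tp1.nemo")
    (Path(repo) / "exported.marker").touch()  # simulate a finished export
    second = deploy_args(repo, "/opt/checkpoints/GPT-2B-001_bf16_tp1.nemo")

print("--nemo_checkpoint" in first, "--nemo_checkpoint" in second)  # True False
```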

If you only want to export a NeMo checkpoint to TensorRT-LLM, use scripts/export/export_to_trt.py. Its parameters are almost identical to those of scripts/deploy/deploy_triton.py; it simply omits the deployment step.

Sending a Query using NeMo APIs

Once the service has been started with the scripts above, it waits for incoming requests. You can send a query in several ways: with the NeMo classes, as in the following example, either from the running container or from another container (possibly on another machine); with PyTriton, used only as a client to send the request; or with a plain HTTP request from any tool or library.
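For the plain HTTP option, Triton accepts JSON requests on its KServe v2 inference endpoint, POST /v2/models/<model_name>/infer. The sketch below uses only the Python standard library. It assumes the served model's input names and datatypes match those used in the PyTriton example in this section (prompts as BYTES, max_output_token and top_k as INT64, top_p and temperature as FP32, each shaped (batch, 1)); verify these against your deployment before relying on it.

```python
import json
import urllib.request


def build_infer_payload(prompts, max_output_token=15, top_k=1, top_p=0.0, temperature=1.0):
    """Build a KServe v2 JSON request body, one (batch, 1) tensor per input.

    Assumption: input names/datatypes mirror the PyTriton example.
    """
    n = len(prompts)

    def tensor(name, datatype, values):
        return {"name": name, "shape": [n, 1], "datatype": datatype, "data": values}

    return {
        "inputs": [
            tensor("prompts", "BYTES", list(prompts)),
            tensor("max_output_token", "INT64", [max_output_token] * n),
            tensor("top_k", "INT64", [top_k] * n),
            tensor("top_p", "FP32", [top_p] * n),
            tensor("temperature", "FP32", [temperature] * n),
        ]
    }


def send_infer_request(url, model_name, payload, timeout=60.0):
    """POST the payload to Triton and return the decoded JSON response."""
    req = urllib.request.Request(
        f"http://{url}/v2/models/{model_name}/infer",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read())
```

With the GPT-2B service running, send_infer_request("localhost:8000", "GPT-2B", build_infer_payload(["What color is a banana?"])) should return a JSON object whose outputs list carries the generated text.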

Below is a request example using the NeMo APIs. Put it in a Python file (or paste it into an interactive Python session) and run it:

from nemo.deploy import NemoQuery

nq = NemoQuery(url="localhost:8000", model_name="GPT-2B")
output = nq.query_llm(
    prompts=["What is the capital of United States?"],
    max_output_token=15,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
)
print(output)

Change url and model_name to match your server and the model name of your service. See the NemoQuery docstrings for details.
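Because deploy_triton.py exports the checkpoint before Triton starts accepting requests, a query sent immediately after launch can fail. A small polling loop against Triton's standard KServe v2 readiness endpoints (/v2/health/ready for the server, /v2/models/<model_name>/ready for a specific model) avoids this. The sketch below is a generic illustration using the standard library, not a NeMo API; the timeout values are arbitrary assumptions.

```python
import time
import urllib.request


def wait_until_ready(url="localhost:8000", model_name=None, timeout_s=600.0, interval_s=5.0):
    """Poll Triton's readiness endpoint until it returns HTTP 200 or time runs out."""
    path = f"/v2/models/{model_name}/ready" if model_name else "/v2/health/ready"
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(f"http://{url}{path}", timeout=interval_s) as resp:
                if resp.status == 200:
                    return True
        except OSError:
            # Connection refused, timeouts and HTTP errors (URLError is an OSError).
            pass
        time.sleep(interval_s)
    return False
```

For example, call wait_until_ready("localhost:8000", model_name="GPT-2B") and create the NemoQuery object only after it returns True.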

Sending a Query using PyTriton

You can also install PyTriton with pip in your environment, as described at https://github.com/triton-inference-server/pytriton, and send a query to the service. The following code example sends a query with PyTriton:

from pytriton.client import ModelClient
import numpy as np


def query_llm(url, model_name, prompts, max_output_token=128, top_k=1, top_p=0.0, temperature=1.0, init_timeout=600.0):
    # Build a (batch, 1) array of UTF-8 encoded prompts.
    str_ndarray = np.array(prompts)[..., np.newaxis]
    prompts = np.char.encode(str_ndarray, "utf-8")
    # Repeat each sampling parameter once per prompt.
    max_output_token = np.full(prompts.shape, max_output_token, dtype=np.int_)
    top_k = np.full(prompts.shape, top_k, dtype=np.int_)
    top_p = np.full(prompts.shape, top_p, dtype=np.single)
    temperature = np.full(prompts.shape, temperature, dtype=np.single)

    with ModelClient(url, model_name, init_timeout_s=init_timeout) as client:
        result_dict = client.infer_batch(
            prompts=prompts,
            max_output_token=max_output_token,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )
        # Decode byte outputs to strings; return other output types unchanged.
        output_type = client.model_config.outputs[0].dtype
        if output_type == np.bytes_:
            sentences = np.char.decode(result_dict["outputs"].astype("bytes"), "utf-8")
            return sentences
        else:
            return result_dict["outputs"]


output = query_llm(
    url="localhost:8000",
    model_name="GPT-2B",
    prompts=["What color is a banana?"],
    max_output_token=10,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
)
print(output)
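The input preparation in query_llm is the part most likely to go wrong when adapting it: every input must be a batch of rows, one per prompt. The snippet below repeats that preparation with NumPy alone (no server needed) so you can inspect the resulting shapes and dtypes; it is an illustration, not additional client code.

```python
import numpy as np

prompts = ["What color is a banana?", "What is the capital of France?"]

# Same steps as query_llm: add a trailing axis, then UTF-8 encode.
str_ndarray = np.array(prompts)[..., np.newaxis]   # shape (2, 1), unicode strings
encoded = np.char.encode(str_ndarray, "utf-8")     # shape (2, 1), bytes
top_k = np.full(encoded.shape, 1, dtype=np.int_)   # one top_k per prompt
temperature = np.full(encoded.shape, 1.0, dtype=np.single)

print(encoded.shape, encoded.dtype.kind, top_k.shape, temperature.dtype)
# (2, 1) S (2, 1) float32

# Decoding recovers the original prompt text.
print(np.char.decode(encoded, "utf-8")[0, 0])
```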

So far, only scripts have been used to export and deploy LLMs. The NeMo Deploy and Export modules also provide easy-to-use APIs for deploying models to Triton and exporting NeMo checkpoints to TensorRT-LLM.

Export an LLM to TensorRT-LLM with NeMo APIs

You can use the APIs in the export module to export a NeMo checkpoint to TensorRT-LLM, as in the following code example. The code assumes that the GPT-2B-001_bf16_tp1.nemo checkpoint has already been downloaded and mounted to the /opt/checkpoints/ path, and that the /opt/checkpoints/tmp_triton_model_repository path exists:

from nemo.export import TensorRTLLM

trt_llm_exporter = TensorRTLLM(model_dir="/opt/checkpoints/tmp_triton_model_repository/")
trt_llm_exporter.export(
    nemo_checkpoint_path="/opt/checkpoints/GPT-2B-001_bf16_tp1.nemo",
    model_type="gptnext",
    n_gpus=1,
)
output = trt_llm_exporter.forward(
    ["What is the best city in the world?"],
    max_output_token=17,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
)
print("output: ", output)

Please check the TensorRTLLM class docstrings for details.

Deploy an LLM with NeMo APIs

You can use the APIs in the deploy module to deploy a TensorRT-LLM model to Triton, as in the following code example. The code assumes that the GPT-2B-001_bf16_tp1.nemo checkpoint has already been downloaded and mounted to the /opt/checkpoints/ path, and that the /opt/checkpoints/tmp_triton_model_repository path exists:

from nemo.export import TensorRTLLM
from nemo.deploy import DeployPyTriton

trt_llm_exporter = TensorRTLLM(model_dir="/opt/checkpoints/tmp_triton_model_repository/")
trt_llm_exporter.export(
    nemo_checkpoint_path="/opt/checkpoints/GPT-2B-001_bf16_tp1.nemo",
    model_type="gptnext",
    n_gpus=1,
)

nm = DeployPyTriton(model=trt_llm_exporter, triton_model_name="GPT-2B", port=8000)
nm.deploy()
nm.serve()

Performance Measurements

To take performance measurements, use the following benchmark script:

python scripts/deploy/benchmark.py --nemo_checkpoint /opt/checkpoints/GPT-2B-001_bf16_tp1.nemo --model_type="gptnext" --triton_model_name GPT-2B -ng 1 -mil 150 -mol 20 -mbs 10 -nr 50

Parameters of the benchmark.py script:

  • nemo_checkpoint (nc) - path of the NeMo checkpoint.

  • model_type (mt) - type of the model. choices=[“gptnext”, “llama”].

  • triton_model_name (tmn) - name of the model on Triton.

  • triton_model_version (tmv) - version of the model. Default is 1.

  • triton_port (tp) - port for the Triton server to listen for requests. Default is 8000.

  • triton_http_address (tha) - HTTP address for the Triton server. Default is 0.0.0.0

  • trt_llm_folder (tlf) - folder where the exported TensorRT-LLM model is stored. Default is /tmp/trt_llm_model_dir/.

  • num_gpus (ng) - number of GPUs to use for inference. Large models require a multi-GPU export.

  • dtype (dt) - data type of the model on TensorRT-LLM. Default is “bf16”. Currently only “bf16” is supported.

  • max_input_len (mil) - maximum input length of the model.

  • max_output_len (mol) - maximum output length of the model.

  • max_batch_size (mbs) - maximum batch size of the model.

  • batch_size (bs) - batch sizes that will be tested for inference. Default=[“1”, “2”, “4”, “8”].

  • out_lens - output lengths (in tokens) to test. Default=[20, 100, 200, 300].

  • top_k - top k parameter for the sampler.

  • top_p - top p parameter for the sampler.

  • num_runs - number of measurement runs to average.

  • run_trt_llm - run inference only on TensorRT-LLM and do not start Triton server.
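The batch_size, out_lens and num_runs parameters together describe a sweep: each (batch size, output length) pair is measured several times and the timings are averaged. The sketch below illustrates that idea generically; it is not the implementation of benchmark.py, and infer and fake_infer are hypothetical stand-ins for a function that runs one batch of inference.

```python
import itertools
import statistics
import time


def benchmark(infer, batch_sizes=(1, 2, 4, 8), out_lens=(20, 100, 200, 300), num_runs=50):
    """Average wall-clock latency of infer() for every (batch size, output length) pair."""
    results = {}
    for bs, out_len in itertools.product(batch_sizes, out_lens):
        timings = []
        for _ in range(num_runs):
            start = time.perf_counter()
            infer(batch_size=bs, max_output_token=out_len)
            timings.append(time.perf_counter() - start)
        results[(bs, out_len)] = statistics.mean(timings)
    return results


def fake_infer(batch_size, max_output_token):
    # Hypothetical placeholder workload so the sketch runs without a GPU.
    sum(range(batch_size * max_output_token))


for (bs, out_len), avg_s in benchmark(fake_infer, num_runs=5).items():
    print(f"batch_size={bs:<2} out_len={out_len:<4} avg_latency={avg_s * 1000:.3f} ms")
```

Averaging over several runs smooths out one-off effects such as warm-up; for latency-sensitive services, percentiles (for example via statistics.quantiles) are often reported alongside the mean.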

For the GPT and LLAMA2 inference performance numbers, please check the Performance section of the documentation.

Measuring Model Accuracy using PyTests for Deploy and Export Modules

For more examples of using the NeMo APIs for export and deploy operations, see the following pytests in the container:

  • Deploy test: /opt/NeMo/tests/deploy/test_nemo_deploy.py

  • Export test: /opt/NeMo/tests/export/test_nemo_export.py

These two tests also include accuracy measurements. Make sure your NeMo checkpoints are mounted at the paths the tests expect.
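As a rough illustration of what an accuracy check on generated text can look like, the hypothetical sketch below scores the fraction of outputs that contain the expected answer (case-insensitive). The metric actually used by test_nemo_deploy.py and test_nemo_export.py may differ; consult those files for the real logic.

```python
def contains_answer_accuracy(outputs, expected):
    """Hypothetical metric: share of outputs containing the expected answer."""
    if len(outputs) != len(expected):
        raise ValueError("outputs and expected must have the same length")
    if not expected:
        return 0.0
    hits = sum(exp.strip().lower() in out.lower() for out, exp in zip(outputs, expected))
    return hits / len(expected)


outputs = [" Washington, D.C. is the capital", " A banana is yellow."]
expected = ["Washington", "yellow"]
print(contains_answer_accuracy(outputs, expected))  # 1.0
```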

© Copyright 2023-2024, NVIDIA. Last updated on Apr 25, 2024.