TensorRT-LLM Multimodal#
This document provides a comprehensive guide to multimodal inference with the TensorRT-LLM backend in Dynamo.
You can provide multimodal inputs in two ways:
- By sending image URLs
- By providing paths to pre-computed embedding files

Note: A single request should contain either image URLs or embedding file paths, not both.
Support Matrix#
| Modality | Input Format | Aggregated | Disaggregated | Notes |
|---|---|---|---|---|
| Image | HTTP/HTTPS URL | Yes | Yes | Full support for all image models |
| Image | Pre-computed Embeddings (.pt, .pth, .bin) | Yes | Yes | Direct embedding files |
| Video | HTTP/HTTPS URL | No | No | Not implemented |
| Audio | HTTP/HTTPS URL | No | No | Not implemented |
Supported URL Formats#
| Format | Example | Description |
|---|---|---|
| HTTP/HTTPS | `https://example.com/image.png` | Remote media files |
| Pre-computed Embeddings | `/path/to/embedding.pt` | Local embedding files (.pt, .pth, .bin) |
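Both formats are passed through the same `image_url` content field of an OpenAI-style chat request. A minimal illustration of the two variants (the URL and path are placeholders; the field layout follows the curl examples later in this guide):

```python
# Two ways to reference media in a chat message "content" list.
# The URL and path below are placeholders, not real files.
remote_image_item = {
    "type": "image_url",
    "image_url": {"url": "https://example.com/image.png"},  # remote media file
}
precomputed_embedding_item = {
    "type": "image_url",
    "image_url": {"url": "/path/to/embedding.pt"},  # local pre-computed embedding file
}
```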
Deployment Patterns#
TRT-LLM supports aggregated and traditional disaggregated patterns. See Architecture Patterns for detailed explanations.
| Pattern | Supported | Launch Script | Notes |
|---|---|---|---|
| EPD (Simple Aggregated) | ✅ | `agg.sh` | Easiest setup |
| E/PD (Encode Separate) | ❌ | N/A | Not supported |
| E/P/D (Full Disaggregation) | 🚧 WIP | N/A | PR #4668 in progress |
| EP/D (Traditional Disaggregated) | ✅ | `disagg.sh` | Prefill handles encoding |
Component Flags#
| Component | Flag | Purpose |
|---|---|---|
| Worker | | Complete pipeline (aggregated) |
| Prefill Worker | | Image processing + Prefill (multimodal tokenization happens here) |
| Decode Worker | | Decode only |
| Encode Worker (WIP) | | Image encoding (E/P/D flow) |
Aggregated Serving#
Quick steps to launch Llama-4 Maverick BF16 in aggregated mode:
```bash
cd $DYNAMO_HOME
export AGG_ENGINE_ARGS=./examples/backends/trtllm/engine_configs/llama4/multimodal/agg.yaml
export SERVED_MODEL_NAME="meta-llama/Llama-4-Maverick-17B-128E-Instruct"
export MODEL_PATH="meta-llama/Llama-4-Maverick-17B-128E-Instruct"
./examples/backends/trtllm/launch/agg.sh
```
Client:
```bash
curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
  "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
  "messages": [
    {
      "role": "user",
      "content": [
        {
          "type": "text",
          "text": "Describe the image"
        },
        {
          "type": "image_url",
          "image_url": {
            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"
          }
        }
      ]
    }
  ],
  "stream": false,
  "max_tokens": 160
}'
```
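The same request can be sent from any OpenAI-compatible client. Below is a minimal sketch using the `openai` Python package, assuming the frontend is listening on `localhost:8000` as launched above:

```python
from openai import OpenAI

# The Dynamo frontend is OpenAI-compatible; no real API key is required.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-used")

response = client.chat.completions.create(
    model="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe the image"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"
                    },
                },
            ],
        }
    ],
    max_tokens=160,
)
print(response.choices[0].message.content)
```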
Disaggregated Serving#
Example using Qwen/Qwen2-VL-7B-Instruct:
```bash
cd $DYNAMO_HOME
export MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct"
export SERVED_MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
export PREFILL_ENGINE_ARGS="examples/backends/trtllm/engine_configs/qwen2-vl-7b-instruct/prefill.yaml"
export DECODE_ENGINE_ARGS="examples/backends/trtllm/engine_configs/qwen2-vl-7b-instruct/decode.yaml"
export MODALITY="multimodal"
./examples/backends/trtllm/launch/disagg.sh
```
```bash
curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
  "model": "Qwen/Qwen2-VL-7B-Instruct",
  "messages": [
    {
      "role": "user",
      "content": [
        {
          "type": "text",
          "text": "Describe the image"
        },
        {
          "type": "image_url",
          "image_url": {
            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"
          }
        }
      ]
    }
  ],
  "stream": false,
  "max_tokens": 160
}'
```
For a large model like meta-llama/Llama-4-Maverick-17B-128E-Instruct, a multi-node setup is required for disaggregated serving (see Multi-node Deployment below), while aggregated serving can run on a single node. This is because the disaggregated configuration of the model is too large to fit on a single node's GPUs. For instance, running this model in disaggregated mode requires 2 nodes with 8xH200 GPUs or 4 nodes with 4xGB200 GPUs.
Pre-computed Embeddings with E/P/D Flow#
For high-performance multimodal inference, Dynamo supports pre-computed embeddings with an Encode-Prefill-Decode (E/P/D) flow using NIXL (RDMA) for zero-copy tensor transfer.
Supported File Types#
- `.pt` - PyTorch tensor files
- `.pth` - PyTorch checkpoint files
- `.bin` - Binary tensor files
Embedding File Formats#
TRT-LLM supports two formats for embedding files:
1. Simple Tensor Format
Direct tensor saved as .pt file containing only the embedding tensor:
```python
import torch

embedding_tensor = torch.rand(1, 576, 4096)  # [batch, seq_len, hidden_dim]
torch.save(embedding_tensor, "embedding.pt")
```
2. Dictionary Format with Auxiliary Data
Dictionary containing multiple keys, used by models like Llama-4 that require additional metadata:
```python
import torch

embedding_dict = {
    "mm_embeddings": torch.rand(1, 576, 4096),
    "special_tokens": [128256, 128257],
    "image_token_offsets": [[0, 576]],
    # ... other model-specific metadata
}
torch.save(embedding_dict, "llama4_embedding.pt")
```
- Simple tensors: loaded directly and passed to the `mm_embeddings` parameter
- Dictionary format: the `mm_embeddings` key is extracted as the main tensor; other keys are preserved as auxiliary data
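For illustration, a loader along these lines would handle both formats. This is a simplified sketch of the behavior described above, not the actual worker code:

```python
import torch

def load_embedding_file(path: str):
    """Load an embedding file and split it into the main tensor plus auxiliary data.

    Simplified sketch of the two supported formats; the real worker logic may differ.
    """
    obj = torch.load(path, map_location="cpu")
    if isinstance(obj, dict):
        # Dictionary format: "mm_embeddings" holds the main tensor,
        # remaining keys are model-specific auxiliary metadata.
        mm_embeddings = obj["mm_embeddings"]
        aux_data = {k: v for k, v in obj.items() if k != "mm_embeddings"}
        return mm_embeddings, aux_data
    # Simple tensor format: the file contains only the embedding tensor.
    return obj, {}
```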
How to Launch#
```bash
cd $DYNAMO_HOME/examples/backends/trtllm

# Launch 3-worker E/P/D flow with NIXL
./launch/epd_disagg.sh
```
Note: This script is designed for an 8-node H200 setup with the `Llama-4-Scout-17B-16E-Instruct` model and assumes you have a model-specific embedding file ready.
Configuration#
```bash
# Encode endpoint for Prefill → Encode communication
export ENCODE_ENDPOINT="dyn://dynamo.tensorrt_llm_encode.generate"

# Security: Allowed directory for embedding files (default: /tmp)
export ALLOWED_LOCAL_MEDIA_PATH="/tmp"

# Security: Max file size to prevent DoS attacks (default: 50MB)
export MAX_FILE_SIZE_MB=50
```
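To make the two security settings concrete, the following is a rough sketch of the kind of check they imply before an embedding file is loaded. It is illustrative only, not the actual implementation:

```python
import os

# Defaults mirror the environment variables documented above.
ALLOWED_LOCAL_MEDIA_PATH = os.environ.get("ALLOWED_LOCAL_MEDIA_PATH", "/tmp")
MAX_FILE_SIZE_MB = int(os.environ.get("MAX_FILE_SIZE_MB", "50"))

def validate_embedding_path(path: str) -> None:
    """Reject files outside the allowed directory or above the size limit (sketch)."""
    real_path = os.path.realpath(path)
    allowed_root = os.path.realpath(ALLOWED_LOCAL_MEDIA_PATH)
    if not real_path.startswith(allowed_root + os.sep):
        raise ValueError(f"{path} is outside {ALLOWED_LOCAL_MEDIA_PATH}")
    size_mb = os.path.getsize(real_path) / (1024 * 1024)
    if size_mb > MAX_FILE_SIZE_MB:
        raise ValueError(f"{path} is {size_mb:.1f} MB, above the {MAX_FILE_SIZE_MB} MB limit")
```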
Example Request with Pre-computed Embeddings#
```bash
curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
  "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
  "messages": [
    {
      "role": "user",
      "content": [
        {"type": "text", "text": "Describe the image"},
        {"type": "image_url", "image_url": {"url": "/path/to/embedding.pt"}}
      ]
    }
  ],
  "max_tokens": 160
}'
```
E/P/D Architecture#
The E/P/D flow implements a 3-worker architecture:
Encode Worker: Loads pre-computed embeddings, transfers via NIXL
Prefill Worker: Receives embeddings, handles context processing and KV-cache generation
Decode Worker: Performs streaming token generation
```mermaid
sequenceDiagram
    participant Client
    participant Frontend
    participant PrefillWorker as "Prefill Worker"
    participant EncodeWorker as "Encode Worker"
    participant DecodeWorker as "Decode Worker"
    participant NIXL as "NIXL (RDMA)"

    Client->>Frontend: POST /v1/chat/completions
    Frontend->>PrefillWorker: Route to prefill worker
    PrefillWorker->>EncodeWorker: Send request (embedding paths)
    EncodeWorker->>NIXL: Create readable operation
    EncodeWorker->>PrefillWorker: Send metadata + NIXL info
    PrefillWorker->>NIXL: Begin read operation
    NIXL-->>PrefillWorker: Zero-copy transfer complete
    PrefillWorker->>Frontend: Return prefill response
    Frontend->>DecodeWorker: Route to decode worker
    DecodeWorker->>Frontend: Stream response chunks
    Frontend->>Client: Stream response
```
Multi-node Deployment (Slurm)#
This section demonstrates how to deploy large multimodal models that require a multi-node setup using Slurm.
Note: The scripts referenced in this section can be found in `examples/basics/multinode/trtllm/`.
Environment Setup#
Assuming you have allocated your nodes via salloc and are inside an interactive shell:
```bash
# Container image (build using docs/backends/trtllm/README.md#build-container)
export IMAGE="<dynamo_trtllm_image>"

# Host:container path pairs for mounting
export MOUNTS="${PWD}/../../../../:/mnt"

# Model configuration
export MODEL_PATH="meta-llama/Llama-4-Maverick-17B-128E-Instruct"
export SERVED_MODEL_NAME="meta-llama/Llama-4-Maverick-17B-128E-Instruct"
export MODALITY=${MODALITY:-"multimodal"}
```
Multi-node Disaggregated Launch#
For four 4xGB200 nodes (two for prefill, two for decode):
```bash
# Customize parallelism to match your engine configs
# export PREFILL_ENGINE_CONFIG="/mnt/examples/backends/trtllm/engine_configs/llama4/multimodal/prefill.yaml"
# export DECODE_ENGINE_CONFIG="/mnt/examples/backends/trtllm/engine_configs/llama4/multimodal/decode.yaml"
# export NUM_PREFILL_NODES=2
# export NUM_DECODE_NODES=2
# export NUM_GPUS_PER_NODE=4

# Launches frontend + etcd/nats on head node, plus prefill and decode workers
./srun_disaggregated.sh
```
Understanding the Output#
- `srun_disaggregated.sh` launches three srun jobs: frontend, prefill worker, and decode worker
- The OpenAI frontend will dynamically discover workers as they register:

```
INFO dynamo_run::input::http: Watching for remote model at models
INFO dynamo_llm::http::service::service_v2: Starting HTTP service on: 0.0.0.0:8000
```

- TRT-LLM workers output progress from each MPI rank while loading
- When ready, the frontend logs:

```
INFO dynamo_llm::discovery::watcher: added model model_name="meta-llama/Llama-4-Maverick-17B-128E-Instruct"
```
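Once that log line appears, you can sanity-check registration by listing the models the frontend currently serves (assuming the standard OpenAI-compatible `/v1/models` route on the head node):

```python
import json
import urllib.request

# Query the frontend's OpenAI-compatible model listing.
with urllib.request.urlopen("http://localhost:8000/v1/models") as resp:
    models = json.load(resp)
print([m["id"] for m in models.get("data", [])])
```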
Cleanup#
```bash
pkill srun
```
NIXL Usage#
| Use Case | Script | NIXL Used? | Data Transfer |
|---|---|---|---|
| EPD (Simple Aggregated) | `agg.sh` | No | All in one worker |
| EP/D (Traditional Disaggregated) | `disagg.sh` | Optional | Prefill → Decode (KV cache via UCX or NIXL) |
| E/P/D (pre-computed embeddings) | `epd_disagg.sh` | Yes | Encoder → Prefill (embeddings via NIXL) |
| E/P/D (WIP) | N/A | No | Encoder → Prefill (handles via params), Prefill → Decode (KV cache) |
Note: NIXL for KV cache transfer is currently beta and only supported on AMD64 (x86_64) architecture.
ModelInput Types and Registration#
TRT-LLM workers register with Dynamo using:
| ModelInput Type | Preprocessing | Use Case |
|---|---|---|
| `ModelInput.Tokens` | Rust frontend may tokenize, but multimodal flows re-tokenize and build inputs in the Python worker; Rust token_ids are ignored | All TRT-LLM workers |
```python
# TRT-LLM Worker - Register with Tokens
await register_llm(
    ModelInput.Tokens,   # Rust does minimal preprocessing
    model_type,          # ModelType.Chat or ModelType.Prefill
    generate_endpoint,
    model_name,
    ...
)
```
Inter-Component Communication#
| Transfer Stage | Message | NIXL Transfer |
|---|---|---|
| Frontend → Prefill | Request with image URL or embedding path | No |
| Encode → Prefill (pre-computed) | NIXL metadata | Yes (embeddings tensor) |
| Encode → Prefill (Image URL) (WIP) | Disaggregated params with multimodal handles | No |
| Prefill → Decode | Disaggregated params | Configurable (KV cache: NIXL default, UCX optional) |
Known Limitations#
- No Data URL support - Only HTTP/HTTPS URLs are supported; `data:image/...` base64 URLs are not supported
- No video support - No video encoder implementation
- No audio support - No audio encoder implementation
- Multimodal preprocessing/tokenization happens in Python - Rust may forward token_ids, but multimodal requests are parsed and re-tokenized in the Python worker
- E/P/D mode is WIP - Full E/P/D with image URLs is under development
- Multi-node H100 limitation - Loading `meta-llama/Llama-4-Maverick-17B-128E-Instruct` on 8 nodes of H100 with TP=16 is not possible due to head count divisibility (`num_attention_heads: 40` is not divisible by `tp_size: 16`)
Supported Models#
Multimodal models listed in TensorRT-LLM supported models are supported by Dynamo.
Common examples:
- Llama 4 Vision models (Maverick, Scout)
- Qwen2-VL models
- Other vision-language models with TRT-LLM support
Key Files#
| File | Description |
|---|---|
| | Worker initialization and setup |
| | Command-line argument parsing |
| | Multimodal request processing |
| | Request handler factory |
| | Base handler and disaggregation modes |