TensorRT-LLM Multimodal#

This document provides a comprehensive guide for multimodal inference with the TensorRT-LLM backend in Dynamo.

You can provide multimodal inputs in the following ways:

  • By sending image URLs

  • By providing paths to pre-computed embedding files

Note: Provide either image URLs or embedding file paths in a single request, not both.
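
Both forms are passed through the standard OpenAI-style image_url content item; only the URL value differs. A small illustration (the values are placeholders):

# Image referenced by an HTTP/HTTPS URL
image_by_url = {
    "type": "image_url",
    "image_url": {"url": "https://example.com/image.jpg"},
}

# Image referenced by a pre-computed embedding file on the worker's filesystem
image_by_embedding = {
    "type": "image_url",
    "image_url": {"url": "/tmp/embedding.pt"},
}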

Support Matrix#

| Modality | Input Format | Aggregated | Disaggregated | Notes |
|---|---|---|---|---|
| Image | HTTP/HTTPS URL | Yes | Yes | Full support for all image models |
| Image | Pre-computed Embeddings (.pt, .pth, .bin) | Yes | Yes | Direct embedding files |
| Video | HTTP/HTTPS URL | No | No | Not implemented |
| Audio | HTTP/HTTPS URL | No | No | Not implemented |

Supported URL Formats#

| Format | Example | Description |
|---|---|---|
| HTTP/HTTPS | http://example.com/image.jpg | Remote media files |
| Pre-computed Embeddings | /path/to/embedding.pt | Local embedding files (.pt, .pth, .bin) |

Deployment Patterns#

TRT-LLM supports aggregated and traditional disaggregated patterns. See Architecture Patterns for detailed explanations.

| Pattern | Supported | Launch Script | Notes |
|---|---|---|---|
| EPD (Simple Aggregated) | ✅ | agg.sh | Easiest setup |
| E/PD (Encode Separate) | ❌ | N/A | Not supported |
| E/P/D (Full Disaggregation) | 🚧 WIP | N/A | PR #4668 in progress |
| EP/D (Traditional Disaggregated) | ✅ | disagg_multimodal.sh | Prefill handles encoding |

Component Flags#

| Component | Flag | Purpose |
|---|---|---|
| Worker | --modality multimodal | Complete pipeline (aggregated) |
| Prefill Worker | --disaggregation-mode prefill | Image processing + Prefill (multimodal tokenization happens here) |
| Decode Worker | --disaggregation-mode decode | Decode only |
| Encode Worker (WIP) | --disaggregation-mode encode | Image encoding (E/P/D flow) |

Aggregated Serving#

Quick steps to launch Llama-4 Maverick BF16 in aggregated mode:

cd $DYNAMO_HOME

export AGG_ENGINE_ARGS=./examples/backends/trtllm/engine_configs/llama4/multimodal/agg.yaml
export SERVED_MODEL_NAME="meta-llama/Llama-4-Maverick-17B-128E-Instruct"
export MODEL_PATH="meta-llama/Llama-4-Maverick-17B-128E-Instruct"
./examples/backends/trtllm/launch/agg.sh

Client:

curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
    "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Describe the image"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"
                    }
                }
            ]
        }
    ],
    "stream": false,
    "max_tokens": 160
}'
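
The same request can be sent with any OpenAI-compatible client. A minimal sketch using the openai Python package, assuming the frontend is reachable at localhost:8000 and does not enforce an API key:

from openai import OpenAI

# The api_key value is a placeholder; the Dynamo frontend exposes an OpenAI-compatible API.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-used")

response = client.chat.completions.create(
    model="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe the image"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"
                    },
                },
            ],
        }
    ],
    max_tokens=160,
)
print(response.choices[0].message.content)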

Disaggregated Serving#

Example using Qwen/Qwen2-VL-7B-Instruct:

cd $DYNAMO_HOME

export MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct"
export SERVED_MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
export PREFILL_ENGINE_ARGS="examples/backends/trtllm/engine_configs/qwen2-vl-7b-instruct/prefill.yaml"
export DECODE_ENGINE_ARGS="examples/backends/trtllm/engine_configs/qwen2-vl-7b-instruct/decode.yaml"
export MODALITY="multimodal"

./examples/backends/trtllm/launch/disagg.sh

Client:

curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
    "model": "Qwen/Qwen2-VL-7B-Instruct",
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Describe the image"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"
                    }
                }
            ]
        }
    ],
    "stream": false,
    "max_tokens": 160
}'

For a large model like meta-llama/Llama-4-Maverick-17B-128E-Instruct, a multi-node setup is required for disaggregated serving (see Multi-node Deployment below), while aggregated serving can run on a single node. This is because the model with a disaggregated configuration is too large to fit on a single node's GPUs. For instance, running this model in disaggregated mode requires 2 nodes with 8xH200 GPUs or 4 nodes with 4xGB200 GPUs.

Pre-computed Embeddings with E/P/D Flow#

For high-performance multimodal inference, Dynamo supports pre-computed embeddings with an Encode-Prefill-Decode (E/P/D) flow using NIXL (RDMA) for zero-copy tensor transfer.

Supported File Types#

  • .pt - PyTorch tensor files

  • .pth - PyTorch checkpoint files

  • .bin - Binary tensor files

Embedding File Formats#

TRT-LLM supports two formats for embedding files:

1. Simple Tensor Format

Direct tensor saved as .pt file containing only the embedding tensor:

import torch

embedding_tensor = torch.rand(1, 576, 4096)  # [batch, seq_len, hidden_dim]
torch.save(embedding_tensor, "embedding.pt")

2. Dictionary Format with Auxiliary Data

Dictionary containing multiple keys, used by models like Llama-4 that require additional metadata:

embedding_dict = {
    "mm_embeddings": torch.rand(1, 576, 4096),
    "special_tokens": [128256, 128257],
    "image_token_offsets": [[0, 576]],
    # ... other model-specific metadata
}
torch.save(embedding_dict, "llama4_embedding.pt")

  • Simple tensors: Loaded directly and passed as the mm_embeddings parameter

  • Dictionary format: The mm_embeddings key is extracted as the main tensor; the remaining keys are preserved as auxiliary data (see the loader sketch below)
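
To make the distinction concrete, here is a minimal loader sketch. load_embedding_file is a hypothetical helper written for illustration, not a Dynamo API; it simply mirrors the loading behavior described above:

import torch

def load_embedding_file(path: str):
    # Hypothetical helper: returns (mm_embeddings, auxiliary_data) for either supported format.
    data = torch.load(path, map_location="cpu")
    if isinstance(data, dict):
        # Dictionary format: extract the main embedding tensor, keep other keys as auxiliary data.
        embeddings = data["mm_embeddings"]
        auxiliary = {k: v for k, v in data.items() if k != "mm_embeddings"}
        return embeddings, auxiliary
    # Simple tensor format: the file contains only the embedding tensor itself.
    return data, {}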

How to Launch#

cd $DYNAMO_HOME/examples/backends/trtllm

# Launch 3-worker E/P/D flow with NIXL
./launch/epd_disagg.sh

Note: This script is designed for an 8-node H200 setup with the Llama-4-Scout-17B-16E-Instruct model and assumes you have a model-specific embedding file ready.

Configuration#

# Encode endpoint for Prefill → Encode communication
export ENCODE_ENDPOINT="dyn://dynamo.tensorrt_llm_encode.generate"

# Security: Allowed directory for embedding files (default: /tmp)
export ALLOWED_LOCAL_MEDIA_PATH="/tmp"

# Security: Max file size to prevent DoS attacks (default: 50MB)
export MAX_FILE_SIZE_MB=50
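
Conceptually, these settings constrain which embedding files a request may reference. A minimal sketch of the kind of validation they imply (illustrative only, not the actual Dynamo implementation):

import os

ALLOWED_LOCAL_MEDIA_PATH = os.environ.get("ALLOWED_LOCAL_MEDIA_PATH", "/tmp")
MAX_FILE_SIZE_MB = int(os.environ.get("MAX_FILE_SIZE_MB", "50"))

def validate_embedding_path(path: str) -> str:
    # Embedding files must live under the allowed directory.
    resolved = os.path.realpath(path)
    allowed = os.path.realpath(ALLOWED_LOCAL_MEDIA_PATH)
    if not resolved.startswith(allowed + os.sep):
        raise ValueError(f"{path} is outside {ALLOWED_LOCAL_MEDIA_PATH}")
    # Oversized files are rejected to limit denial-of-service risk.
    if os.path.getsize(resolved) > MAX_FILE_SIZE_MB * 1024 * 1024:
        raise ValueError(f"{path} exceeds {MAX_FILE_SIZE_MB} MB")
    return resolved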

Example Request with Pre-computed Embeddings#

curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
    "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe the image"},
                {"type": "image_url", "image_url": {"url": "/path/to/embedding.pt"}}
            ]
        }
    ],
    "max_tokens": 160
}'
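
The embedding path in the request must point to a file under ALLOWED_LOCAL_MEDIA_PATH (default /tmp) and stay within the MAX_FILE_SIZE_MB limit described in the Configuration section above.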

E/P/D Architecture#

The E/P/D flow implements a 3-worker architecture:

  • Encode Worker: Loads pre-computed embeddings, transfers via NIXL

  • Prefill Worker: Receives embeddings, handles context processing and KV-cache generation

  • Decode Worker: Performs streaming token generation

sequenceDiagram
    participant Client
    participant Frontend
    participant PrefillWorker as "Prefill Worker"
    participant EncodeWorker as "Encode Worker"
    participant DecodeWorker as "Decode Worker"
    participant NIXL as "NIXL (RDMA)"

    Client->>Frontend: POST /v1/chat/completions
    Frontend->>PrefillWorker: Route to prefill worker
    PrefillWorker->>EncodeWorker: Send request (embedding paths)
    EncodeWorker->>NIXL: Create readable operation
    EncodeWorker->>PrefillWorker: Send metadata + NIXL info
    PrefillWorker->>NIXL: Begin read operation
    NIXL-->>PrefillWorker: Zero-copy transfer complete
    PrefillWorker->>Frontend: Return prefill response
    Frontend->>DecodeWorker: Route to decode worker
    DecodeWorker->>Frontend: Stream response chunks
    Frontend->>Client: Stream response
    

Multi-node Deployment (Slurm)#

This section demonstrates how to deploy large multimodal models that require a multi-node setup using Slurm.

Note: The scripts referenced in this section can be found in examples/basics/multinode/trtllm/.

Environment Setup#

Assuming you have allocated your nodes via salloc and are inside an interactive shell:

# Container image (build using docs/backends/trtllm/README.md#build-container)
export IMAGE="<dynamo_trtllm_image>"

# Host:container path pairs for mounting
export MOUNTS="${PWD}/../../../../:/mnt"

# Model configuration
export MODEL_PATH="meta-llama/Llama-4-Maverick-17B-128E-Instruct"
export SERVED_MODEL_NAME="meta-llama/Llama-4-Maverick-17B-128E-Instruct"
export MODALITY=${MODALITY:-"multimodal"}

Multi-node Disaggregated Launch#

For four 4xGB200 nodes (two for prefill, two for decode):

# Customize parallelism to match your engine configs
# export PREFILL_ENGINE_CONFIG="/mnt/examples/backends/trtllm/engine_configs/llama4/multimodal/prefill.yaml"
# export DECODE_ENGINE_CONFIG="/mnt/examples/backends/trtllm/engine_configs/llama4/multimodal/decode.yaml"
# export NUM_PREFILL_NODES=2
# export NUM_DECODE_NODES=2
# export NUM_GPUS_PER_NODE=4

# Launches frontend + etcd/nats on head node, plus prefill and decode workers
./srun_disaggregated.sh

Understanding the Output#

  1. srun_disaggregated.sh launches three srun jobs: frontend, prefill worker, and decode worker

  2. The OpenAI frontend will dynamically discover workers as they register:

    INFO dynamo_run::input::http: Watching for remote model at models
    INFO dynamo_llm::http::service::service_v2: Starting HTTP service on: 0.0.0.0:8000
    
  3. TRT-LLM workers output progress from each MPI rank while loading

  4. When ready, the frontend logs:

    INFO dynamo_llm::discovery::watcher: added model model_name="meta-llama/Llama-4-Maverick-17B-128E-Instruct"
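
Rather than watching the logs, you can also poll the frontend until the model is registered. A minimal sketch, assuming the frontend exposes the standard OpenAI-compatible /v1/models listing on port 8000:

import json
import time
import urllib.request

MODEL = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"

# Poll the OpenAI-compatible model listing until the served model appears.
while True:
    try:
        with urllib.request.urlopen("http://localhost:8000/v1/models") as resp:
            models = [m["id"] for m in json.load(resp).get("data", [])]
        if MODEL in models:
            print("Model is ready")
            break
    except OSError:
        pass  # frontend not reachable yet
    time.sleep(10)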
    

Cleanup#

pkill srun

NIXL Usage#

| Use Case | Script | NIXL Used? | Data Transfer |
|---|---|---|---|
| EPD (Simple Aggregated) | agg.sh | No | All in one worker |
| EP/D (Traditional Disaggregated) | disagg_multimodal.sh | Optional | Prefill → Decode (KV cache via UCX or NIXL) |
| E/P/D (pre-computed embeddings) | epd_disagg.sh | Yes | Encoder → Prefill (embeddings via NIXL) |
| E/P/D (WIP) | N/A | No | Encoder → Prefill (handles via params), Prefill → Decode (KV cache) |

Note: NIXL for KV cache transfer is currently in beta and only supported on the AMD64 (x86_64) architecture.

ModelInput Types and Registration#

TRT-LLM workers register with Dynamo using:

| ModelInput Type | Preprocessing | Use Case |
|---|---|---|
| ModelInput.Tokens | Rust frontend may tokenize, but multimodal flows re-tokenize and build inputs in the Python worker; Rust token_ids are ignored | All TRT-LLM workers |

# TRT-LLM Worker - Register with Tokens
await register_llm(
    ModelInput.Tokens,      # Rust does minimal preprocessing
    model_type,             # ModelType.Chat or ModelType.Prefill
    generate_endpoint,
    model_name,
    ...
)

Inter-Component Communication#

| Transfer Stage | Message | NIXL Transfer |
|---|---|---|
| Frontend → Prefill | Request with image URL or embedding path | No |
| Encode → Prefill (pre-computed) | NIXL metadata | Yes (embeddings tensor) |
| Encode → Prefill (Image URL) (WIP) | Disaggregated params with multimodal handles | No |
| Prefill → Decode | Disaggregated params | Configurable (KV cache: NIXL default, UCX optional) |

Known Limitations#

  • No Data URL support - Only HTTP/HTTPS URLs supported; data:image/... base64 URLs not supported

  • No video support - No video encoder implementation

  • No audio support - No audio encoder implementation

  • Multimodal preprocessing/tokenization happens in Python - Rust may forward token_ids, but multimodal requests are parsed and re-tokenized in the Python worker

  • E/P/D mode is WIP - Full E/P/D with image URLs under development

  • Multi-node H100 limitation - Loading meta-llama/Llama-4-Maverick-17B-128E-Instruct with 8 nodes of H100 with TP=16 is not possible due to head count divisibility (num_attention_heads: 40 not divisible by tp_size: 16)

Supported Models#

Multimodal models included in the TensorRT-LLM supported models list are supported by Dynamo.

Common examples:

  • Llama 4 Vision models (Maverick, Scout)

  • Qwen2-VL models

  • Other vision-language models with TRT-LLM support

Key Files#

| File | Description |
|---|---|
| components/src/dynamo/trtllm/main.py | Worker initialization and setup |
| components/src/dynamo/trtllm/utils/trtllm_utils.py | Command-line argument parsing |
| components/src/dynamo/trtllm/multimodal_processor.py | Multimodal request processing |
| components/src/dynamo/trtllm/request_handlers/handlers.py | Request handler factory |
| components/src/dynamo/trtllm/request_handlers/handler_base.py | Base handler and disaggregation modes |
Base handler and disaggregation modes