Multi-Device Attention
TensorRT supports multi-device attention through context parallelism, which splits the key-value sequence across multiple GPUs. Each GPU processes a portion of the sequence, and NCCL handles inter-device communication. Refer to the attention-mdtrt sample for an example.
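To make the idea of splitting the key-value sequence concrete, the following is a minimal pure-Python sketch of context parallelism for a single query vector. It is not TensorRT's implementation and uses no TensorRT or NCCL calls: the helper names (`shard`, `attend_partial`, `merge_partials`, `context_parallel_attention`) are hypothetical, the "ranks" are simulated in one process, and the merge step assumes a standard log-sum-exp rescaling of per-shard softmax results. It only illustrates that per-rank partial results can be combined into the same output as attention over the full sequence.

```python
import math


def shard(seq, nb_ranks):
    """Split seq into nb_ranks contiguous, near-equal chunks (hypothetical helper)."""
    if nb_ranks < 1 or len(seq) < nb_ranks:
        raise ValueError("need at least one sequence element per rank")
    base, extra = divmod(len(seq), nb_ranks)
    chunks, start = [], 0
    for rank in range(nb_ranks):
        size = base + (1 if rank < extra else 0)
        chunks.append(seq[start:start + size])
        start += size
    return chunks


def attend_partial(q, keys, values):
    """Softmax attention of one query over one K/V shard.

    Returns (max score, softmax denominator, unnormalized output) so that
    shards can be merged exactly afterwards.
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return m, denom, out


def merge_partials(partials):
    """Combine per-rank partial results by rescaling each to the global max score."""
    m_global = max(m for m, _, _ in partials)
    dim = len(partials[0][2])
    denom = 0.0
    out = [0.0] * dim
    for m, shard_denom, shard_out in partials:
        c = math.exp(m - m_global)
        denom += c * shard_denom
        for d in range(dim):
            out[d] += c * shard_out[d]
    return [x / denom for x in out]


def context_parallel_attention(q, keys, values, nb_ranks):
    """Simulate nb_ranks devices, each attending over its own K/V shard."""
    partials = [
        attend_partial(q, k_shard, v_shard)
        for k_shard, v_shard in zip(shard(keys, nb_ranks), shard(values, nb_ranks))
    ]
    return merge_partials(partials)
```

Calling `context_parallel_attention(q, keys, values, 1)` and `context_parallel_attention(q, keys, values, 2)` on the same inputs should agree up to floating-point rounding. In a real deployment the merge step corresponds to inter-device communication, which the text above attributes to NCCL.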
Supported data types: BF16 and FP16
Platform support: Blackwell (SM 100) and later
For more information, refer to the GPU Architecture and Precision Support section.
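The constraints above can be encoded as a small pre-flight check before building an engine. This is only a sketch: the function name, the string dtype labels, and the idea of passing the CUDA compute capability as `(major, minor)` are assumptions for illustration, not part of the TensorRT API. It relies on SM 100 corresponding to compute capability 10.0, and the referenced support section remains the authoritative matrix.

```python
# Values taken from the requirements listed above.
SUPPORTED_DTYPES = {"bf16", "fp16"}
MIN_SM = 100  # Blackwell


def meets_multi_device_attention_requirements(dtype, cc_major, cc_minor):
    """Return True if dtype and GPU architecture satisfy the listed requirements.

    dtype is a lowercase label such as "fp16" (hypothetical convention);
    cc_major and cc_minor are the CUDA compute capability, e.g. 10 and 0.
    """
    sm = cc_major * 10 + cc_minor
    return dtype.lower() in SUPPORTED_DTYPES and sm >= MIN_SM
```

For example, `meets_multi_device_attention_requirements("fp16", 10, 0)` passes, while FP32 inputs or a compute capability 9.0 device do not.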
Note
There is no ONNX-compliant operator for multi-device attention. However, standard ONNX files generated by a sharding tool are supported when the Attention op includes an nbRanks field. You can also use the C++ or Python API directly.
The IAttention layer runs in multi-device mode when nbRanks > 1. Use IAttention::setNbRanks (C++) or IAttention.num_ranks (Python) to set the number of devices. Once set to a value greater than 1, this value cannot be changed.
Build-time configuration
- Add the attention layer with the IAttention API.
- Call IAttention::setNbRanks(nbRanks) (C++) or set attention.num_ranks = nbRanks (Python).
Runtime requirements
- Initialize an NCCL communicator with ncclCommInitRank or ncclCommInitAll.
- Call IExecutionContext::setCommunicator(ncclComm_t) before inference.
- All participating ranks must execute the same network with synchronized execution calls.
C++
// Add attention layer and set number of ranks
auto attention = network->addAttentionV2(query, key, value, AttentionNormalizationOp::kSOFTMAX, CausalMaskKind::kNONE);
attention->setNbRanks(nbRanks);

// At runtime, on each rank, set communicator and run inference
context->setCommunicator(ncclComm);
context->enqueueV3(stream);

Python
# Add attention layer and set number of ranks
attention = network.add_attention_v2(query, key, value, trt.AttentionNormalizationOp.SOFTMAX, trt.CausalMaskKind.NONE)
attention.num_ranks = num_gpus

# At runtime, on each rank, set communicator and run inference
context.set_communicator(nccl_comm)
context.execute_async_v3(stream)
See also
- Working with Loops
Loop constructs for autoregressive decoding patterns.
- Working with Conditionals
Conditional execution for dynamic decoder logic.