Multi-Device Attention

TensorRT supports multi-device attention through context parallelism, which splits the key-value sequence across multiple GPUs. Each GPU processes its portion of the sequence, and NCCL handles communication between the devices. For a complete example, refer to the attention-mdtrt sample.
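Splitting the key-value sequence does not change the result, because softmax attention can be computed in pieces and then merged exactly. The following plain-Python sketch illustrates that math for a single query vector. It assumes contiguous, roughly equal KV shards. The helper names (`shard`, `partial_attention`, `merge_partials`, `full_attention`) are hypothetical. The sketch does not describe how TensorRT actually partitions the sequence or which NCCL collectives it issues.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def full_attention(q, keys, values):
    """Reference: softmax(q . k / sqrt(d))-weighted sum of values over the whole sequence."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total for i in range(dim)]


def partial_attention(q, keys, values):
    """Work one rank does on its KV shard: local max, local exp-sum, unnormalized output."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    dim = len(values[0])
    out = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]
    return m, sum(weights), out


def merge_partials(partials):
    """Combine per-rank results by rescaling each one to a shared global max."""
    global_max = max(m for m, _, _ in partials)
    dim = len(partials[0][2])
    total = 0.0
    acc = [0.0] * dim
    for m, s, out in partials:
        factor = math.exp(m - global_max)
        total += s * factor
        for i in range(dim):
            acc[i] += out[i] * factor
    return [x / total for x in acc]


def shard(seq, nb_ranks):
    """Split seq into nb_ranks contiguous chunks whose sizes differ by at most one."""
    base, extra = divmod(len(seq), nb_ranks)
    chunks, start = [], 0
    for r in range(nb_ranks):
        end = start + base + (1 if r < extra else 0)
        chunks.append(seq[start:end])
        start = end
    return chunks


# Usage: two "ranks" each attend over half of the KV sequence, then merge.
q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.3, 0.3]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
nb_ranks = 2
partials = [partial_attention(q, k, v)
            for k, v in zip(shard(keys, nb_ranks), shard(values, nb_ranks))]
merged = merge_partials(partials)
assert all(abs(a - b) < 1e-9 for a, b in zip(merged, full_attention(q, keys, values)))
```

Each rank only needs to exchange three small quantities per query: the local max, the local exp-sum, and the unnormalized output. It never needs the other ranks' full key-value data. This sketch assumes every shard is non-empty, which means `nb_ranks` must not exceed the sequence length.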

  • Supported data types: BF16 and FP16

  • Platform support: Blackwell (SM 100) and later

For more information, refer to the GPU Architecture and Precision Support section.
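Before choosing the multi-device path, an application may want to check these requirements up front. The following sketch encodes only the support matrix listed above. `multi_device_attention_supported` is a hypothetical helper, not a TensorRT API, and data types are represented as plain strings for illustration. If PyTorch is available, `torch.cuda.get_device_capability()` is one way to obtain the `(major, minor)` tuple.

```python
# Hypothetical helper (not part of TensorRT): encodes the documented support matrix.
SUPPORTED_DTYPES = frozenset({"float16", "bfloat16"})  # FP16 and BF16
MIN_COMPUTE_CAPABILITY = (10, 0)  # Blackwell (SM 100)


def multi_device_attention_supported(compute_capability, dtype):
    """Return True if the GPU architecture and data type meet the listed requirements."""
    major, minor = compute_capability
    return (major, minor) >= MIN_COMPUTE_CAPABILITY and dtype in SUPPORTED_DTYPES


print(multi_device_attention_supported((10, 0), "bfloat16"))  # True
print(multi_device_attention_supported((9, 0), "float16"))    # False: Hopper is SM 90
print(multi_device_attention_supported((12, 0), "float32"))   # False: FP32 is not listed
```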

Note

There is no ONNX-compliant operator for multi-device attention. However, standard ONNX files generated by a sharding tool are supported when the Attention op includes an nbRanks field. Alternatively, you can use the C++ or Python API directly.

The IAttention layer runs in multi-device mode when nbRanks is greater than 1. Use IAttention::setNbRanks (C++) or IAttention.num_ranks (Python) to set the number of devices. Once nbRanks is set to a value greater than 1, it cannot be changed.

Build-time configuration

  1. Add the attention layer with the IAttention API.

  2. Call IAttention::setNbRanks(nbRanks) (C++) or set attention.num_ranks = nbRanks (Python).

Runtime requirements

  1. Initialize an NCCL communicator with ncclCommInitRank or ncclCommInitAll.

  2. Pass the resulting ncclComm_t to IExecutionContext::setCommunicator before running inference.

  3. All participating ranks must execute the same network with synchronized execution calls.

C++

// Add the attention layer and set the number of ranks (must be > 1 for multi-device mode)
auto attention = network->addAttentionV2(query, key, value, AttentionNormalizationOp::kSOFTMAX, CausalMaskKind::kNONE);
attention->setNbRanks(nbRanks);

// At runtime, on each rank, set the NCCL communicator and run inference
context->setCommunicator(ncclComm);
context->enqueueV3(stream);

Python

# Add the attention layer and set the number of ranks (must be > 1 for multi-device mode)
attention = network.add_attention_v2(query, key, value, trt.AttentionNormalizationOp.SOFTMAX, trt.CausalMaskKind.NONE)
attention.num_ranks = num_gpus

# At runtime, on each rank, set the NCCL communicator and run inference
context.set_communicator(nccl_comm)
context.execute_async_v3(stream)

See also

Working with Loops

Loop constructs for autoregressive decoding patterns.

Working with Conditionals

Conditional execution for dynamic decoder logic.