Attention¶
Computes attention over the query, key, and value input tensors.
Attributes¶
normalizationOp parameter specifies the normalization applied to the attention scores and can be one of:
NONE: No normalization is applied.
SOFTMAX: Apply softmax normalization to the attention scores along the s_kv dimension.
causal boolean parameter determines whether causal masking is applied, so that each query position can attend only to key positions at or before it.
decomposable boolean parameter determines whether the attention can be decomposed into multiple kernels if no fused kernel support is found.
normalizationQuantizeToType optional parameter determines the data type the attention normalization is quantized to:
DataType::kFP8: The attention normalization is quantized to FP8.
DataType::kINT8: The attention normalization is quantized to INT8.
nbRanks parameter (default: 1) specifies the number of ranks for multi-device attention execution. When greater than 1, it hints that the attention should be executed as context-parallel multi-device attention.
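Putting the attributes together, the computation can be sketched as a plain NumPy reference (an illustrative sketch, not TensorRT's implementation; any score scaling such as \(1/\sqrt{h}\) is omitted, and the layout follows the \([b, d, s, h]\) shapes described under Shape Information):

```python
import numpy as np

def attention_ref(query, key, value, causal=False):
    # query: [b, d_q, s_q, h]; key/value: [b, d_kv, s_kv, h]
    # (assumes d_q == d_kv for simplicity)
    scores = query @ key.transpose(0, 1, 3, 2)  # BMM1 -> [b, d, s_q, s_kv]
    if causal:
        s_q, s_kv = scores.shape[-2:]
        # each query position may attend only to key positions at or before it
        allowed = np.tril(np.ones((s_q, s_kv), dtype=bool))
        scores = np.where(allowed, scores, -np.inf)
    # SOFTMAX normalization along the s_kv dimension
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ value  # BMM2 -> [b, d, s_q, h]
```

With causal=True, the first query position sees only the first key/value position, so its output row equals the first value row exactly.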
Inputs¶
query a tensor of type T1
key a tensor of type T1
value a tensor of type T1
mask optional: tensor of type T2. For a bool mask, a True value indicates that the corresponding position is allowed to attend. For a float mask, the mask values will be added to the BMM1 output, known as an add mask.
normalizationQuantizeScale optional: tensor of type T1. The quantization scale for the attention normalization output.
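The relationship between the two mask flavors can be illustrated in NumPy (a sketch, not TensorRT code): a bool mask is equivalent to an additive float mask that adds \(-\infty\) (in practice, a large negative value) to the BMM1 output at disallowed positions, which drives their softmax weight to zero:

```python
import numpy as np

def bool_to_additive_mask(mask_bool):
    # True  -> 0.0   (position may be attended to; score unchanged)
    # False -> -inf  (position masked out; softmax weight becomes 0)
    return np.where(mask_bool, 0.0, -np.inf)

# attention scores for one (batch, head) pair: [s_q=1, s_kv=3]
scores = np.array([[1.0, 2.0, 3.0]])
mask = np.array([[True, True, False]])  # last position disallowed
masked = scores + bool_to_additive_mask(mask)
weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
# weights[0, 2] is exactly 0: the masked position contributes nothing
```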
Outputs¶
outputs tensors of type T1
Data Types¶
T1: float32, float16, bfloat16
T2: float32, float16, bfloat16, bool. If T2 is not bool, it must match T1.
Shape Information¶
query and outputs are tensors with shape \([b, d_q, s_q, h]\)
key and value are tensors with shape \([b, d_{kv}, s_{kv}, h]\)
mask is a tensor with shape \([a_0, a_1, s_q, s_{kv}]\), where \(a_0\) and \(a_1\) must be broadcastable to \(b\) and \(d_q\) respectively.
normalizationQuantizeScale is a tensor with a shape of \([a_0,...,a_n], 0 \leq n \leq 1\)
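The broadcasting requirement on the mask can be checked with NumPy semantics (an illustrative sketch; \(a_0 = a_1 = 1\) is the common case of a single mask shared across the batch and all heads):

```python
import numpy as np

b, d_q, s_q, s_kv = 2, 8, 4, 4
# a_0 = a_1 = 1: one mask broadcast across the batch and heads dimensions
mask = np.tril(np.ones((1, 1, s_q, s_kv), dtype=bool))
scores = np.zeros((b, d_q, s_q, s_kv))
# NumPy broadcasting mirrors the operator's requirement that a_0 and a_1
# be broadcastable to the batch and heads dimensions of the scores
broadcast_mask = np.broadcast_to(mask, scores.shape)
```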
DLA Support¶
Not supported.
Multi-Device Runtime Requirements¶
To use Attention with nbRanks > 1:
Multi-device attention requires GPUs with SM >= 100 (Blackwell or newer) and supports only bfloat16 and float16 data types.
The engine must be built with PreviewFeature::kMULTIDEVICE_RUNTIME_10_16 enabled in the builder config.
An NCCL communicator must be initialized and set on the execution context via IExecutionContext::setCommunicator before inference.
All participating ranks must execute the same network with synchronized execution calls.
Examples¶
Attention
network = get_runner.builder.create_network(flags=1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))
qkv_shape = (1, 8, 1, 16)
mask_shape = (1, 1, 1, 1)
query = network.add_input("query", dtype=trt.float16, shape=qkv_shape)
key = network.add_input("key", dtype=trt.float16, shape=qkv_shape)
value = network.add_input("value", dtype=trt.float16, shape=qkv_shape)
mask = network.add_input("mask", dtype=trt.bool, shape=mask_shape)
layer = network.add_attention(query, key, value, trt.AttentionNormalizationOp.SOFTMAX, False)
layer.mask = mask
network.mark_output(layer.get_output(0))
query_data = np.ones(qkv_shape, dtype=np.float16)
for i in range(qkv_shape[1]):
    query_data[0, i, 0, :] = i + 1
inputs[query.name] = query_data
key_data = np.ones(qkv_shape, dtype=np.float16)
for i in range(qkv_shape[1]):
    key_data[0, i, 0, :] = i + 1
inputs[key.name] = key_data
value_data = np.ones(qkv_shape, dtype=np.float16)
for i in range(qkv_shape[1]):
    value_data[0, i, 0, :] = i + 1
inputs[value.name] = value_data
mask_data = np.ones(mask_shape, dtype=np.bool_)
inputs[mask.name] = mask_data
outputs[layer.get_output(0).name] = layer.get_output(0).shape
# With a sequence length of 1, softmax assigns the single key/value
# position a weight of exactly 1, so the output equals value.
expected_data = np.ones(qkv_shape, dtype=np.float16)
for i in range(qkv_shape[1]):
    expected_data[0, i, 0, :] = i + 1
expected[layer.get_output(0).name] = expected_data
# Set get_runner.network back to the new STRONGLY_TYPED network
get_runner.network = network
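The expected output above can be sanity-checked without TensorRT: with s_q = s_kv = 1, softmax over a single key position always yields a weight of exactly 1, so the attention output equals the value tensor regardless of the query and key contents:

```python
import numpy as np

qkv_shape = (1, 8, 1, 16)
value_data = np.ones(qkv_shape, dtype=np.float16)
for i in range(qkv_shape[1]):
    value_data[0, i, 0, :] = i + 1
# softmax over a single s_kv position is exactly 1.0
weights = np.ones((1, 8, 1, 1), dtype=np.float16)
output = weights @ value_data  # BMM2 with trivial weights
assert np.array_equal(output, value_data)
```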
C++ API¶
For more information about the C++ IAttention operator, refer to the C++ IAttention documentation.
Python API¶
For more information about the Python IAttention operator, refer to the Python IAttention documentation.