Attention

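Computes attention from query, key, and value tensors: the query-key scores produced by the first batched matrix multiplication (BMM1) are optionally masked, normalized according to normalizationOp, and multiplied by value.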

Attributes

normalizationOp specifies the normalization applied to the attention scores. It can be one of:

  • NONE No normalization is applied.

  • SOFTMAX Applies softmax normalization to the attention scores along the s_kv dimension.
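
The score pipeline selected by normalizationOp, combined with the mask semantics listed under Inputs, can be sketched in NumPy as below. This is a minimal reference sketch, not TensorRT's implementation: the function name `reference_attention` and its `scale` argument are hypothetical, the sketch assumes the PADDED_BHND layout with equal head counts for query and key/value, and because this section does not say whether IAttention scales the scores (for example by \(1/\sqrt{h}\)), no scaling is applied by default. How a bool mask interacts with NONE is also an assumption here (blocked scores are set to 0).

```python
import numpy as np

def reference_attention(q, k, v, normalization="SOFTMAX", mask=None, scale=1.0):
    """Hypothetical NumPy reference for PADDED_BHND attention.

    q: [b, d, s_q, h]; k, v: [b, d, s_kv, h] (same head count d assumed).
    mask: optional bool or float array broadcastable to [b, d, s_q, s_kv].
    scale: hypothetical score multiplier; 1.0 leaves the scores unscaled.
    """
    # BMM1: one score per (query, key) pair -> [b, d, s_q, s_kv]
    scores = np.matmul(q, np.swapaxes(k, -1, -2)).astype(np.float32) * scale
    if mask is not None:
        if mask.dtype == np.bool_:
            # True = allowed to attend. Blocked scores become -inf before softmax;
            # zeroing them for NONE is an assumption of this sketch.
            fill = -np.inf if normalization == "SOFTMAX" else 0.0
            scores = np.where(mask, scores, fill)
        else:
            # Float "add mask": added to the BMM1 output
            scores = scores + mask
    if normalization == "SOFTMAX":
        # Softmax along s_kv (last axis), shifted by the row max for stability
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights = weights / weights.sum(axis=-1, keepdims=True)
    else:
        # NONE: the scores are used directly as weights
        weights = scores
    # BMM2: weighted sum of value rows -> [b, d, s_q, h]
    return np.matmul(weights, v.astype(np.float32))

# Zero query/key give equal scores, so softmax averages the visible value rows
q = np.zeros((1, 1, 2, 4))
k = np.zeros((1, 1, 3, 4))
v = np.broadcast_to(np.array([1.0, 2.0, 4.0])[:, None], (1, 1, 3, 4))
out = reference_attention(q, k, v)  # every entry is (1 + 2 + 4) / 3
```

A fully masked row (every key blocked) would produce NaN under softmax in this sketch; that case is left unhandled.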

causal (deprecated) boolean parameter determines whether causal masking is applied. Superseded by causalKind.

causalKind parameter determines where the diagonal of the causal mask is anchored. Possible values are:

  • NONE No causal masking applied.

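  • UPPER_LEFT Diagonal anchored at the top-left corner: query position i attends to key positions 0 through i (legacy default when causal=true).

  • LOWER_RIGHT Diagonal anchored at the bottom-right corner: query position i attends to key positions 0 through i + (s_kv - s_q) (decode-aligned semantics for LLM generation).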
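
The two diagonal orientations can be made concrete as boolean masks. The NumPy sketch below follows the semantics given in the expected-output comments of the causal examples in the Examples section (query i sees keys 0..i for UPPER_LEFT, and keys 0..i + (s_kv - s_q) for LOWER_RIGHT); the helper name `causal_mask` and the string kind names are hypothetical stand-ins for trt.CausalMaskKind.

```python
import numpy as np

def causal_mask(s_q, s_kv, kind):
    """Boolean [s_q, s_kv] mask; True means query i may attend to key j."""
    q_idx = np.arange(s_q)[:, None]    # query positions as a column
    kv_idx = np.arange(s_kv)[None, :]  # key positions as a row
    if kind == "NONE":
        return np.ones((s_q, s_kv), dtype=bool)
    if kind == "UPPER_LEFT":
        # Diagonal anchored at (0, 0): query i sees keys 0..i
        return kv_idx <= q_idx
    if kind == "LOWER_RIGHT":
        # Diagonal anchored at (s_q - 1, s_kv - 1): query i sees keys 0..i + (s_kv - s_q)
        return kv_idx <= q_idx + (s_kv - s_q)
    raise ValueError(f"unknown causal kind: {kind}")

# Equal Q/K scores (all-ones inputs in the examples) mean each query simply
# averages the values it can see; key position j holds value j + 1.
values = np.arange(1, 9, dtype=np.float64)
results = {
    kind: [float(values[row].mean()) for row in causal_mask(4, 8, kind)]
    for kind in ("UPPER_LEFT", "LOWER_RIGHT")
}
print(results)
# {'UPPER_LEFT': [1.0, 1.5, 2.0, 2.5], 'LOWER_RIGHT': [3.0, 3.5, 4.0, 4.5]}
```

When s_q equals s_kv the offset is zero and both kinds produce the same lower-triangular mask; they differ only when the query and key/value lengths differ, as in decoding, where a short block of new queries must see the whole key/value history.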

decomposable boolean parameter determines whether the attention can be decomposed into multiple kernels when no fused kernel is available.

normalizationQuantizeToType optional parameter determines the data type to which the attention normalization output is quantized.

  • DataType::kFP8: The attention normalization is quantized to FP8.

  • DataType::kINT8: The attention normalization is quantized to INT8.
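
The effect of quantizing the normalization output can be illustrated for the INT8 case. This sketch assumes the scale is applied the way TensorRT's IQuantizeLayer applies its scale (divide by the scale, round to nearest with ties to even, clamp to [-128, 127]); this section does not spell out the convention used with normalizationQuantizeScale, and the helper names and the 1/127 scale are illustrative choices. FP8 is not shown because NumPy has no native FP8 dtype.

```python
import numpy as np

def quantize_int8(x, scale):
    # Assumed convention: x / scale, round half to even (np.rint), clamp to INT8
    return np.clip(np.rint(x / scale), -128, 127).astype(np.int8)

def dequantize(x_q, scale):
    # Map INT8 codes back to real values
    return x_q.astype(np.float32) * scale

# Softmax probabilities lie in [0, 1]; a scale of 1/127 maps 1.0 to code 127
probs = np.array([0.7, 0.2, 0.1], dtype=np.float32)
scale = np.float32(1.0 / 127.0)
codes = quantize_int8(probs, scale)
restored = dequantize(codes, scale)
print(codes)  # [89 25 13]
```

The round trip loses at most half a quantization step per value, which is why the FP8 example in the Examples section relaxes its tolerance.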

nbRanks parameter (default: 1) specifies the number of ranks for multi-device attention execution. When greater than 1, it hints that the attention should run as context-parallel multi-device attention.
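
This section does not describe how TensorRT partitions work across ranks. As a hedged illustration of why splitting attention along the sequence can still give an exact result, the sketch below shows one common technique, not necessarily TensorRT's: the key/value sequence is split into nbRanks shards, each shard's softmax attention is computed together with its log-sum-exp, and the partial outputs are merged. Communication between ranks (for example via NCCL) is omitted, and all names are hypothetical.

```python
import numpy as np

def partial_attention(q, k, v):
    """Softmax attention over one key/value shard; returns (output, log-sum-exp)."""
    s = q @ k.T                            # [s_q, shard_len] scores
    m = s.max(axis=-1, keepdims=True)      # row max for stability
    p = np.exp(s - m)
    l = p.sum(axis=-1, keepdims=True)
    return (p @ v) / l, m + np.log(l)      # [s_q, h], [s_q, 1]

def merge_partials(partials):
    """Combine per-shard results into attention over the full key/value sequence."""
    outs, lses = zip(*partials)
    lse = np.logaddexp.reduce(np.stack(lses), axis=0)  # global log-sum-exp
    # Reweight each shard's output by its share of the global softmax mass
    return sum(np.exp(shard_lse - lse) * out for out, shard_lse in zip(outs, lses))

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8))    # [s_q, h] for one batch entry and head
k = rng.standard_normal((12, 8))   # [s_kv, h]
v = rng.standard_normal((12, 8))
nb_ranks = 3                       # one shard of 4 key/value positions per rank
partials = [partial_attention(q, k_r, v_r)
            for k_r, v_r in zip(np.split(k, nb_ranks), np.split(v, nb_ranks))]
merged = merge_partials(partials)
```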

queryForm and keyValueForm parameters determine the tensor layouts: queryForm applies to query and the output, and keyValueForm applies to key and value. Possible values are:

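  • PADDED_BHND Dense padded layout with shape \([b, d, s, h]\), where \(b\) is the batch size, \(d\) the number of heads, \(s\) the sequence length, and \(h\) the head size.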

  • PACKED_NHD Packed ragged layout with shape \([t, d, h]\), where \(t\) is the total number of packed tokens.
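
Converting a dense PADDED_BHND tensor into the PACKED_NHD layout amounts to moving the head axis next to the head-size axis and dropping padded positions. The sketch below assumes each sequence's valid tokens occupy the first positions of the \(s\) axis (padding at the end); the helper name `pack_bhnd` is hypothetical. Its example packs the same token values (1, 2, 4) used for value in the packed example of the Examples section.

```python
import numpy as np

def pack_bhnd(padded, seq_lens):
    """Pack a padded [b, d, s, h] tensor into [t, d, h], where t = sum(seq_lens)."""
    pieces = []
    for i, n in enumerate(seq_lens):
        # First n positions of sequence i: [d, n, h] -> [n, d, h]
        pieces.append(np.transpose(padded[i, :, :n, :], (1, 0, 2)))
    return np.concatenate(pieces, axis=0)

b, d, s, h = 2, 1, 3, 4
padded = np.zeros((b, d, s, h), dtype=np.float16)
padded[0, :, 0, :] = 1  # sequence 0: one valid token
padded[1, :, 0, :] = 2  # sequence 1: two valid tokens
padded[1, :, 1, :] = 4  # padded[0, :, 1:] and padded[1, :, 2] are padding
packed = pack_bhnd(padded, [1, 2])
print(packed.shape)  # (3, 1, 4)
```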

Inputs

query a tensor of type T1

key a tensor of type T1

value a tensor of type T1

mask optional: tensor of type T2. For a bool mask, a True value indicates that the corresponding position is allowed to attend. For a float mask, known as an add mask, the mask values are added to the BMM1 output (the query-key scores) before normalization.

normalizationQuantizeScale optional: tensor of type T1. The quantization scale for the attention normalization output.

queryLengths optional: tensor of type int32. Required when queryForm is PACKED_NHD. Contains cumulative packed-token offsets with shape \([b + 1]\), starting with 0.

keyValueLengths optional: tensor of type int32. Required when keyValueForm is PACKED_NHD. Contains cumulative packed-token offsets with shape \([b + 1]\), starting with 0. When keyValueForm is PADDED_BHND, this tensor can optionally contain per-batch key-value lengths with shape \([b]\).
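
The cumulative offsets expected by queryLengths and keyValueLengths are a prefix sum of the per-sequence lengths, and sequence i occupies packed rows offsets[i] up to (but not including) offsets[i + 1]. A short NumPy sketch follows; the `check_offsets` helper is hypothetical and encodes the stated constraints (shape \([b + 1]\), starting with 0) plus the non-decreasing, ends-at-\(t\) properties that follow from them.

```python
import numpy as np

def check_offsets(offsets, total_tokens):
    # Cumulative offsets: start at 0, never decrease, end at the packed token count t
    return bool(offsets[0] == 0 and np.all(np.diff(offsets) >= 0)
                and offsets[-1] == total_tokens)

seq_lens = np.array([1, 2], dtype=np.int32)  # b = 2 sequences
offsets = np.concatenate(([0], np.cumsum(seq_lens))).astype(np.int32)
print(offsets)  # [0 1 3], shape [b + 1]

# Rows of a packed [t, d, h] tensor that belong to each sequence
rows = [np.arange(offsets[i], offsets[i + 1]) for i in range(len(seq_lens))]

# Per-batch lengths with shape [b], as accepted for PADDED_BHND key/value
per_batch = np.diff(offsets)
```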

Outputs

outputs tensors of type T1

Data Types

T1: float32, float16, bfloat16

T2: float32, float16, bfloat16, bool. If T2 is not bool, it must be the same as T1.

Shape Information

For PADDED_BHND query form, query and outputs have the same shape \([b, d_q, s_q, h]\).

For PADDED_BHND key-value form, key and value have the same shape \([b, d_{kv}, s_{kv}, h]\).

For PACKED_NHD query form, query and outputs have the same shape \([t_q, d_q, h]\), and queryLengths contains the cumulative packed-token offsets of the \(b\) sequences, with shape \([b + 1]\) and starting with 0.

For PACKED_NHD key-value form, key and value have the same shape \([t_{kv}, d_{kv}, h]\), and keyValueLengths contains the cumulative packed-token offsets of the \(b\) sequences, with shape \([b + 1]\) and starting with 0.

mask is a tensor with shape \([a_0, a_1, s_q, s_{kv}]\), where \(a_0\) and \(a_1\) must be broadcastable to \(b\) and \(d_q\), respectively.
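
NumPy's broadcasting rules give a quick way to check the mask shape constraint. The sketch below assumes the scores the mask applies to have shape \([b, d_q, s_q, s_{kv}]\) and that each of \(a_0\), \(a_1\) is either 1 or equal to its target, as in NumPy-style broadcasting; it also builds a float add mask that is equivalent under softmax, whose type must then match T1 (float16 here). The helper name `mask_broadcastable` is hypothetical.

```python
import numpy as np

def mask_broadcastable(mask_shape, scores_shape):
    """True if a mask of mask_shape broadcasts to scores_shape without enlarging it."""
    try:
        return np.broadcast_shapes(mask_shape, scores_shape) == tuple(scores_shape)
    except ValueError:
        return False

b, d_q, s_q, s_kv = 2, 8, 4, 8
scores_shape = (b, d_q, s_q, s_kv)

# [b, 1, s_q, s_kv]: one mask per batch entry, shared by every head
bool_mask = np.ones((b, 1, s_q, s_kv), dtype=bool)
bool_mask[1, :, :, 6:] = False  # hypothetical: last two keys of batch 1 are padding
full = np.broadcast_to(bool_mask, scores_shape)

# Add mask equivalent under softmax: 0 where attending is allowed, -inf where blocked
add_mask = np.where(bool_mask, 0.0, -np.inf).astype(np.float16)

print(mask_broadcastable((1, 1, s_q, s_kv), scores_shape))  # True
print(mask_broadcastable((3, 1, s_q, s_kv), scores_shape))  # False
```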

normalizationQuantizeScale is a tensor with a shape of \([a_0,...,a_n]\), where \(0 \leq n \leq 1\).

DLA Support

Not supported.

Multi-Device Runtime Requirements

To use Attention with nbRanks > 1:

  • Multi-device attention requires GPUs with SM >= 100 (Blackwell or newer) and supports only bfloat16 and float16 data types.

  • An NCCL communicator must be initialized and set on the execution context via IExecutionContext::setCommunicator before inference.

  • All participating ranks must execute the same network with synchronized execution calls.

Examples

Attention
network = get_runner.builder.create_network(flags=1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))
qkv_shape = (1, 8, 1, 16)
mask_shape = (1, 1, 1, 1)
query = network.add_input("query", dtype=trt.float16, shape=qkv_shape)
key = network.add_input("key", dtype=trt.float16, shape=qkv_shape)
value = network.add_input("value", dtype=trt.float16, shape=qkv_shape)
mask = network.add_input("mask", dtype=trt.bool, shape=mask_shape)
layer = network.add_attention_v2(query, key, value, trt.AttentionNormalizationOp.SOFTMAX, trt.CausalMaskKind.NONE)
layer.mask = mask
network.mark_output(layer.get_output(0))

query_data = np.ones(qkv_shape, dtype=np.float16)
for i in range(qkv_shape[1]):
    query_data[0, i, 0, :] = i + 1
inputs[query.name] = query_data

key_data = np.ones(qkv_shape, dtype=np.float16)
for i in range(qkv_shape[1]):
    key_data[0, i, 0, :] = i + 1
inputs[key.name] = key_data

value_data = np.ones(qkv_shape, dtype=np.float16)
for i in range(qkv_shape[1]):
    value_data[0, i, 0, :] = i + 1
inputs[value.name] = value_data

mask_data = np.ones(mask_shape, dtype=np.bool_)
inputs[mask.name] = mask_data

outputs[layer.get_output(0).name] = layer.get_output(0).shape

# With s_q = s_kv = 1, softmax over the single key position yields a weight of 1, so the output equals value
expected_data = np.ones(qkv_shape, dtype=np.float16)
for i in range(qkv_shape[1]):
    expected_data[0, i, 0, :] = i + 1
expected[layer.get_output(0).name] = expected_data

# Set get_runner.network back to the new STRONGLY_TYPED network
get_runner.network = network

Attention With FP8 Quantization
network = get_runner.builder.create_network(flags=1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))
qkv_shape = (1, 8, 1, 32)
mask_shape = (1, 1, 1, 1)

query = network.add_input("query", dtype=trt.float16, shape=qkv_shape)
q_scale = network.add_constant(shape=(), weights=np.array(2.0, dtype=np.float16))
query_fp8 = network.add_quantize(query, q_scale.get_output(0), trt.fp8)
query_fp16 = network.add_dequantize(query_fp8.get_output(0), q_scale.get_output(0), trt.float16)

key = network.add_input("key", dtype=trt.float16, shape=qkv_shape)
k_scale = network.add_constant(shape=(), weights=np.array(3.0, dtype=np.float16))
key_fp8 = network.add_quantize(key, k_scale.get_output(0), trt.fp8)
key_fp16 = network.add_dequantize(key_fp8.get_output(0), k_scale.get_output(0), trt.float16)

value = network.add_input("value", dtype=trt.float16, shape=qkv_shape)
v_scale = network.add_constant(shape=(), weights=np.array(4.0, dtype=np.float16))
value_fp8 = network.add_quantize(value, v_scale.get_output(0), trt.fp8)
value_fp16 = network.add_dequantize(value_fp8.get_output(0), v_scale.get_output(0), trt.float16)

mask = network.add_input("mask", dtype=trt.bool, shape=mask_shape)

normalization_quantize_scale = network.add_constant(shape=(), weights=np.array(5.0, dtype=np.float16))
layer = network.add_attention_v2(query_fp16.get_output(0), key_fp16.get_output(0), value_fp16.get_output(0), trt.AttentionNormalizationOp.SOFTMAX, trt.CausalMaskKind.NONE)

layer.mask = mask
layer.normalization_quantize_scale = normalization_quantize_scale.get_output(0)
layer.normalization_quantize_to_type = trt.fp8

output = layer.get_output(0)
output_scale = network.add_constant(shape=(), weights=np.array(6.0, dtype=np.float16))
output_fp8 = network.add_quantize(output, output_scale.get_output(0), trt.fp8)
output_fp16 = network.add_dequantize(output_fp8.get_output(0), output_scale.get_output(0), trt.float16)
network.mark_output(output_fp16.get_output(0))

query_data = np.ones(qkv_shape, dtype=np.float16)
for i in range(qkv_shape[1]):
    query_data[0, i, 0, :] = i + 1
inputs[query.name] = query_data

key_data = np.ones(qkv_shape, dtype=np.float16)
for i in range(qkv_shape[1]):
    key_data[0, i, 0, :] = i + 1
inputs[key.name] = key_data

value_data = np.ones(qkv_shape, dtype=np.float16)
for i in range(qkv_shape[1]):
    value_data[0, i, 0, :] = i + 1
inputs[value.name] = value_data

mask_data = np.ones(mask_shape, dtype=np.bool_)
inputs[mask.name] = mask_data

outputs[output_fp16.get_output(0).name] = output_fp16.get_output(0).shape

# With s_q = s_kv = 1, the softmax weight is 1, so the output equals value up to FP8 quantization error
expected_data = np.ones(qkv_shape, dtype=np.float16)
for i in range(qkv_shape[1]):
    expected_data[0, i, 0, :] = i + 1
expected[output_fp16.get_output(0).name] = expected_data

# Set get_runner.network back to the new STRONGLY_TYPED network
get_runner.network = network

# Adjust tolerance for FP8 quantization precision loss
get_runner.atol = 0.3

Attention With Upper-Left Causal Mask
network = get_runner.builder.create_network(flags=1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))

# s_q=4, s_kv=8 to test the difference between upper-left and lower-right
batch, heads, s_q, s_kv, head_dim = 1, 1, 4, 8, 16
query_shape = (batch, heads, s_q, head_dim)
kv_shape = (batch, heads, s_kv, head_dim)

query = network.add_input("query", dtype=trt.float16, shape=query_shape)
key = network.add_input("key", dtype=trt.float16, shape=kv_shape)
value = network.add_input("value", dtype=trt.float16, shape=kv_shape)

# Use CausalMaskKind.UPPER_LEFT
layer = network.add_attention_v2(query, key, value, trt.AttentionNormalizationOp.SOFTMAX, trt.CausalMaskKind.UPPER_LEFT)
layer.decomposable = True
network.mark_output(layer.get_output(0))

# Input data: all ones
query_data = np.ones(query_shape, dtype=np.float16)
key_data = np.ones(kv_shape, dtype=np.float16)
# Value data: each key position has a distinct value
value_data = np.zeros(kv_shape, dtype=np.float16)
for j in range(s_kv):
    value_data[0, 0, j, :] = j + 1  # key pos j has value (j+1)

inputs[query.name] = query_data
inputs[key.name] = key_data
inputs[value.name] = value_data

outputs[layer.get_output(0).name] = layer.get_output(0).shape

# Expected output for upper-left causal:
# Query pos 0: attends to key[0] only -> output = value[0] = 1
# Query pos 1: attends to key[0:2] uniformly -> output = avg(1, 2) = 1.5
# Query pos 2: attends to key[0:3] uniformly -> output = avg(1, 2, 3) = 2
# Query pos 3: attends to key[0:4] uniformly -> output = avg(1, 2, 3, 4) = 2.5
expected_data = np.zeros(query_shape, dtype=np.float16)
for i in range(s_q):
    # Upper-left: attend to positions 0..i
    attended_values = [j + 1 for j in range(i + 1)]
    expected_data[0, 0, i, :] = np.mean(attended_values)

expected[layer.get_output(0).name] = expected_data

get_runner.network = network
get_runner.atol = 0.1

Attention With Lower-Right Causal Mask
network = get_runner.builder.create_network(flags=1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))

# s_q=4, s_kv=8 to test the difference between upper-left and lower-right
batch, heads, s_q, s_kv, head_dim = 1, 1, 4, 8, 16
query_shape = (batch, heads, s_q, head_dim)
kv_shape = (batch, heads, s_kv, head_dim)

query = network.add_input("query", dtype=trt.float16, shape=query_shape)
key = network.add_input("key", dtype=trt.float16, shape=kv_shape)
value = network.add_input("value", dtype=trt.float16, shape=kv_shape)

# Use CausalMaskKind.LOWER_RIGHT
layer = network.add_attention_v2(query, key, value, trt.AttentionNormalizationOp.SOFTMAX, trt.CausalMaskKind.LOWER_RIGHT)
layer.decomposable = True
network.mark_output(layer.get_output(0))

# Input data: all ones for Q and K
query_data = np.ones(query_shape, dtype=np.float16)
key_data = np.ones(kv_shape, dtype=np.float16)
# Value data: each key position has a distinct value
value_data = np.zeros(kv_shape, dtype=np.float16)
for j in range(s_kv):
    value_data[0, 0, j, :] = j + 1  # key pos j has value (j+1)

inputs[query.name] = query_data
inputs[key.name] = key_data
inputs[value.name] = value_data

outputs[layer.get_output(0).name] = layer.get_output(0).shape

# Expected output for lower-right causal:
# offset = s_kv - s_q = 4
# Query pos 0: attends to key[0:5] -> output = avg(1,2,3,4,5) = 3
# Query pos 1: attends to key[0:6] -> output = avg(1,2,3,4,5,6) = 3.5
# Query pos 2: attends to key[0:7] -> output = avg(1,2,3,4,5,6,7) = 4
# Query pos 3: attends to key[0:8] -> output = avg(1,2,3,4,5,6,7,8) = 4.5
offset = s_kv - s_q
expected_data = np.zeros(query_shape, dtype=np.float16)
for i in range(s_q):
    # Lower-right: attend to positions 0..(i + offset)
    attended_values = [j + 1 for j in range(i + offset + 1)]
    expected_data[0, 0, i, :] = np.mean(attended_values)

expected[layer.get_output(0).name] = expected_data

get_runner.network = network
get_runner.atol = 0.1

Packed Attention With Cumulative Lengths
# Packed attention with key-value lengths currently has fused support on SM100, SM103, and SM110.
network = get_runner.builder.create_network(flags=1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))

total_tokens, num_heads, head_dim = 3, 1, 128
qkv_shape = (total_tokens, num_heads, head_dim)
lengths_shape = (3,)

query = network.add_input("query", dtype=trt.float16, shape=qkv_shape)
key = network.add_input("key", dtype=trt.float16, shape=qkv_shape)
value = network.add_input("value", dtype=trt.float16, shape=qkv_shape)
query_lengths = network.add_input("query_lengths", dtype=trt.int32, shape=lengths_shape)
key_value_lengths = network.add_input("key_value_lengths", dtype=trt.int32, shape=lengths_shape)

layer = network.add_attention_v2(query, key, value, trt.AttentionNormalizationOp.SOFTMAX, trt.CausalMaskKind.NONE)
layer.query_form = trt.AttentionIOForm.PACKED_NHD
layer.key_value_form = trt.AttentionIOForm.PACKED_NHD
layer.query_lengths = query_lengths
layer.key_value_lengths = key_value_lengths
network.mark_output(layer.get_output(0))

inputs[query.name] = np.zeros(qkv_shape, dtype=np.float16)
inputs[key.name] = np.zeros(qkv_shape, dtype=np.float16)

value_data = np.zeros(qkv_shape, dtype=np.float16)
value_data[0, :, :] = 1
value_data[1, :, :] = 2
value_data[2, :, :] = 4
inputs[value.name] = value_data

sequence_lengths = np.array([1, 2], dtype=np.int32)
cumulative_lengths = np.concatenate(([0], np.cumsum(sequence_lengths))).astype(np.int32)
inputs[query_lengths.name] = cumulative_lengths
inputs[key_value_lengths.name] = cumulative_lengths

expected_data = np.zeros(qkv_shape, dtype=np.float16)
# Zero Q/K scores produce uniform attention within each packed sequence. The first sequence
# contains token 0, and the second sequence contains tokens 1 and 2.
expected_data[0, :, :] = 1
expected_data[1, :, :] = 3
expected_data[2, :, :] = 3

outputs[layer.get_output(0).name] = layer.get_output(0).shape
expected[layer.get_output(0).name] = expected_data

get_runner.network = network
get_runner.atol = 0.1

C++ API

For more information about the C++ IAttention operator, refer to the C++ IAttention documentation.

Python API

For more information about the Python IAttention operator, refer to the Python IAttention documentation.