KVCacheUpdate
Performs Key (K) / Value (V) cache update for attention computations.
Users provide the newly computed K/V values as inputs, and the layer will output the updated K/V cache. The writeIndices input specifies where to write K/V updates for each sequence in the batch.
Separate KVCacheUpdate layers should be used for K and V.
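The bookkeeping that usually surrounds this layer can be sketched in NumPy. The sketch makes an assumption that this page does not state: writeIndices[i] is taken to be the number of tokens already cached for sequence i, so each step appends after the existing entries. It keeps a separate K cache and V cache and passes the same write indices to both. `linear_update_` and `KVCacheState` are hypothetical helpers, not TensorRT APIs.

```python
import numpy as np


def linear_update_(cache, update, write_indices):
    """Hypothetical helper: LINEAR update written into `cache` in place."""
    seq_len = update.shape[2]
    for i, start in enumerate(write_indices):
        cache[i, :, start:start + seq_len, :] = update[i]
    return cache


class KVCacheState:
    """Hypothetical bookkeeping: one K cache, one V cache, per-sequence lengths."""

    def __init__(self, b, d, s_max, h):
        self.k = np.zeros((b, d, s_max, h), dtype=np.float32)
        self.v = np.zeros((b, d, s_max, h), dtype=np.float32)
        # Assumption: the current length of each sequence is used as writeIndices.
        self.lengths = np.zeros(b, dtype=np.int32)

    def append(self, new_k, new_v):
        seq_len = new_k.shape[2]
        s_max = self.k.shape[2]
        if np.any(self.lengths + seq_len > s_max):
            raise ValueError("writeIndices[i] + s would exceed s_max")
        linear_update_(self.k, new_k, self.lengths)  # one update for K ...
        linear_update_(self.v, new_v, self.lengths)  # ... and a separate one for V
        self.lengths = self.lengths + seq_len


b, d, s_max, h = 2, 1, 6, 1
state = KVCacheState(b, d, s_max, h)
# Prefill two tokens per sequence, then decode a single token.
state.append(np.full((b, d, 2, h), 1.0, np.float32), np.full((b, d, 2, h), -1.0, np.float32))
state.append(np.full((b, d, 1, h), 2.0, np.float32), np.full((b, d, 1, h), -2.0, np.float32))
print(state.k[0, 0, :, 0])  # [1. 1. 2. 0. 0. 0.]
print(state.lengths)        # [3 3]
```

Keeping K and V as two independent buffers mirrors the rule that each gets its own KVCacheUpdate layer; only the write indices are shared.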
Attributes
cacheMode specifies the cache update mode:
LINEAR: for each batch element i and each update position s (0 ≤ s < update sequence length): output[i, :, writeIndices[i] + s, :] = update[i, :, s, :]. All other positions of output keep the values they had in cache.
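The LINEAR formula can be checked against a plain NumPy reference that follows it index for index. This is a sketch for understanding the semantics, not TensorRT's kernel. `kv_cache_update_linear` is a hypothetical name, and it returns a copy, whereas the real layer updates the cache in place. The data are the four sequences from the example on this page, reduced to one head (d = 1) with head size h = 1.

```python
import numpy as np


def kv_cache_update_linear(cache, update, write_indices):
    """Reference for LINEAR mode on [b, d, s_max, h] / [b, d, s, h] tensors.

    output[i, :, write_indices[i] + s, :] = update[i, :, s, :] for every s in
    [0, update.shape[2]); every other position keeps its value from `cache`.
    """
    output = cache.copy()  # the real layer writes into `cache` itself
    seq_len = update.shape[2]
    for i, start in enumerate(write_indices):
        output[i, :, start:start + seq_len, :] = update[i]
    return output


cache = np.array([[0.53, 0.88, 0, 0, 0, 0, 0, 0],
                  [0.41, 0, 0, 0, 0, 0, 0, 0],
                  [0.67, 0, 0, 0, 0, 0, 0, 0],
                  [0.32, 0.79, 0.64, 0, 0, 0, 0, 0]], dtype=np.float32)[:, None, :, None]
update = np.array([[0.72, 0, 0, 0],
                   [0.55, 0.94, 0, 0],
                   [0.61, 0.28, 0, 0],
                   [0.83, 0, 0, 0]], dtype=np.float32)[:, None, :, None]
out = kv_cache_update_linear(cache, update, np.array([2, 1, 1, 3], dtype=np.int32))
print(out[:, 0, :, 0])
```

All s positions of the update are written, including trailing zeros. For example, row 3 receives 0.83 at position 3 and zeros at positions 4 to 6.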
Inputs
cache: tensor of type T, the key/value cache tensor. Must be a network input and have a static sequence length dimension.
update: tensor of type T, the newly computed key/value tensor to write into the cache.
writeIndices: tensor of type M, the write position index for each batch element i. Values must satisfy writeIndices[i] + s <= s_max, where s is the update sequence length and s_max is the maximum sequence length of the cache.
Outputs
output: tensor of type T, the updated cache tensor. Must be a network output. It shares the same device memory address as the cache input, so the update is performed in place.
Data Types
T: float32, float16, bfloat16
M: int32, int64
Shape Information
cache and output are tensors with the same shape \([b, d, s_{max}, h]\)
update is a tensor with the shape of \([b, d, s, h]\) where \(s \leq s_{max}\)
writeIndices is a tensor with the shape of \([b]\)
Where:
\(b\) is the batch size
\(d\) is the number of heads
\(s_{max}\) is the maximum sequence length (must be static)
\(s\) is the update sequence length
\(h\) is the head size
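The shape, type, and bound rules can be collected into a pre-flight check run on host arrays before feeding a network. This is a hypothetical helper (`check_kv_cache_update_args` is not part of TensorRT). It covers only float32 and float16 for T because NumPy has no native bfloat16 dtype. Rejecting negative write indices is an assumption added here and is not stated on this page.

```python
import numpy as np

ALLOWED_T = {np.dtype(np.float32), np.dtype(np.float16)}  # bfloat16 omitted: no NumPy dtype
ALLOWED_M = {np.dtype(np.int32), np.dtype(np.int64)}


def check_kv_cache_update_args(cache, update, write_indices):
    """Raise ValueError if the arguments break the documented shape/type rules."""
    if cache.ndim != 4 or update.ndim != 4:
        raise ValueError("cache and update must be 4-D: [b, d, s, h]")
    b, d, s_max, h = cache.shape
    ub, ud, s, uh = update.shape
    if (ub, ud, uh) != (b, d, h):
        raise ValueError(f"update shape {update.shape} does not match cache shape {cache.shape}")
    if s > s_max:
        raise ValueError(f"update length s={s} exceeds s_max={s_max}")
    if write_indices.shape != (b,):
        raise ValueError(f"writeIndices must have shape ({b},), got {write_indices.shape}")
    if cache.dtype != update.dtype or cache.dtype not in ALLOWED_T:
        raise ValueError("cache and update must share one type T")
    if write_indices.dtype not in ALLOWED_M:
        raise ValueError("writeIndices must be int32 or int64")
    # Assumption: indices are non-negative; the documented bound is index + s <= s_max.
    if np.any(write_indices < 0) or np.any(write_indices + s > s_max):
        raise ValueError("need 0 <= writeIndices[i] and writeIndices[i] + s <= s_max")


cache = np.zeros((4, 2, 8, 1), dtype=np.float32)   # [b, d, s_max, h]
update = np.zeros((4, 2, 4, 1), dtype=np.float32)  # [b, d, s, h]
check_kv_cache_update_args(cache, update, np.array([2, 1, 1, 3], dtype=np.int32))  # passes
```

With these shapes, a write index of 5 would fail because 5 + 4 > 8.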
DLA Support
Not supported.
Examples
KVCacheUpdate
```python
import numpy as np
import tensorrt as trt

# get_runner, inputs, outputs, and expected are assumed to be supplied by the
# surrounding documentation test harness; they are not defined in this snippet.
network = get_runner.builder.create_network(flags=1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))
cache_shape = (4, 2, 8, 1)  # [b, d, s_max, h]
update_shape = (4, 2, 4, 1)  # [b, d, s, h]
write_indices_shape = (4,)  # [b]
cache = network.add_input("cache", dtype=trt.float32, shape=cache_shape)
update = network.add_input("update", dtype=trt.float32, shape=update_shape)
write_indices = network.add_input("write_indices", dtype=trt.int32, shape=write_indices_shape)
layer = network.add_kv_cache_update(cache, update, write_indices, trt.KVCacheMode.LINEAR)
network.mark_output(layer.get_output(0))

# One row per sequence; each row is broadcast across the d = 2 heads.
# np.zeros is given dtype=np.float32 so the data stays float32 like the inputs.
cache_data = np.array(
    [
        [0.53, 0.88, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.41, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.67, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.32, 0.79, 0.64, 0.0, 0.0, 0.0, 0.0, 0.0],
    ],
    dtype=np.float32,
)
inputs[cache.name] = cache_data[:, None, :, None] + np.zeros((1, 2, 1, 1), dtype=np.float32)
update_data = np.array(
    [
        [0.72, 0.0, 0.0, 0.0],
        [0.55, 0.94, 0.0, 0.0],
        [0.61, 0.28, 0.0, 0.0],
        [0.83, 0.0, 0.0, 0.0],
    ],
    dtype=np.float32,
)
inputs[update.name] = update_data[:, None, :, None] + np.zeros((1, 2, 1, 1), dtype=np.float32)
# Row i of the update is written starting at position write_indices_data[i].
write_indices_data = np.array([2, 1, 1, 3], dtype=np.int32)
inputs[write_indices.name] = write_indices_data
outputs[layer.get_output(0).name] = layer.get_output(0).shape
expected_data = np.array(
    [
        [0.53, 0.88, 0.72, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.41, 0.55, 0.94, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.67, 0.61, 0.28, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.32, 0.79, 0.64, 0.83, 0.0, 0.0, 0.0, 0.0],
    ],
    dtype=np.float32,
)
expected[layer.get_output(0).name] = expected_data[:, None, :, None] + np.zeros((1, 2, 1, 1), dtype=np.float32)
# Set get_runner.network back to the new STRONGLY_TYPED network
get_runner.network = network
```
C++ API
For more information about the C++ IKVCacheUpdateLayer operator, refer to the C++ IKVCacheUpdateLayer documentation.
Python API
For more information about the Python IKVCacheUpdateLayer operator, refer to the Python IKVCacheUpdateLayer documentation.