KVCacheUpdate

Updates a key (K) or value (V) cache for attention computation.

The newly computed K or V values are provided through the update input, and the layer outputs the updated cache. The writeIndices input specifies, for each sequence in the batch, the cache position at which its update is written.

Separate KVCacheUpdate layers should be used for K and V.
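The following NumPy sketch shows how a K layer and a V layer might be driven together over one prefill step and two decode steps. It runs on the host only and is not TensorRT code. kv_cache_update_linear is a hypothetical helper that applies the LINEAR rule in place. The sketch also assumes that writeIndices[i] holds the number of tokens already cached for sequence i, so the same indices serve both the K and V updates.

```python
import numpy as np

def kv_cache_update_linear(cache, update, write_indices):
    """Hypothetical in-place LINEAR update: cache[i, :, w_i:w_i + s, :] = update[i]."""
    s = update.shape[2]
    for i, w in enumerate(write_indices):
        cache[i, :, w:w + s, :] = update[i]
    return cache

b, d, s_max, h = 2, 2, 8, 4
k_cache = np.zeros((b, d, s_max, h), dtype=np.float32)
v_cache = np.zeros((b, d, s_max, h), dtype=np.float32)
# Assumed bookkeeping: tokens already cached per sequence (sequence 1 starts with 2).
lengths = np.array([0, 2], dtype=np.int32)

rng = np.random.default_rng(0)
# Prefill 3 tokens, then decode 1 token at a time for two steps.
for step_len in (3, 1, 1):
    k_new = rng.standard_normal((b, d, step_len, h)).astype(np.float32)
    v_new = rng.standard_normal((b, d, step_len, h)).astype(np.float32)
    # K and V are updated separately but share the same write indices.
    kv_cache_update_linear(k_cache, k_new, lengths)
    kv_cache_update_linear(v_cache, v_new, lengths)
    lengths += step_len

print(lengths)  # [5 7]
```

After each step the indices advance by the number of tokens written, which keeps every write window within \(s_{max}\) as long as the total stays at or below 8 here.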

Attributes

cacheMode specifies the cache update mode:

  • LINEAR: for each batch element i and each update position j with 0 <= j < s (s is the update sequence length), output[i, :, writeIndices[i] + j, :] = update[i, :, j, :]. All other positions keep their values from cache.
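A NumPy reference for the LINEAR rule can make the indexing concrete. This is a host-side sketch for checking results, not the TensorRT kernel. The function name kv_cache_update_linear_ref is hypothetical, and the function returns a new array instead of updating in place.

```python
import numpy as np

def kv_cache_update_linear_ref(cache, update, write_indices):
    """NumPy reference for LINEAR mode (illustrative, not the TensorRT kernel).

    cache:         [b, d, s_max, h]
    update:        [b, d, s, h]
    write_indices: [b]
    Positions outside each written window keep their cache values.
    """
    output = cache.copy()
    s = update.shape[2]
    for i in range(cache.shape[0]):
        w = int(write_indices[i])
        # output[i, :, w + j, :] = update[i, :, j, :] for 0 <= j < s
        output[i, :, w:w + s, :] = update[i]
    return output

cache = np.zeros((1, 1, 6, 1), dtype=np.float32)
update = np.array([1.0, 2.0], dtype=np.float32).reshape(1, 1, 2, 1)
out = kv_cache_update_linear_ref(cache, update, np.array([3]))
print(out[0, 0, :, 0])  # [0. 0. 0. 1. 2. 0.]
```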

Inputs

cache: tensor of type T, the key/value cache tensor. Must be a network input and have a static sequence length dimension.

update: tensor of type T, the newly computed key/value tensor to write into the cache.

writeIndices: tensor of type M, the write position for each batch element i. Values must satisfy writeIndices[i] + s <= s_max, where s is the update sequence length and s_max is the maximum sequence length of cache (see Shape Information).
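The bound on writeIndices can be checked on the host before inference. check_write_indices below is a hypothetical helper, not part of TensorRT. Besides the documented bound, it also rejects negative indices, which this sketch assumes are invalid.

```python
import numpy as np

def check_write_indices(write_indices, update_len, max_len):
    """Return True if every window [w, w + update_len) fits inside [0, max_len)."""
    w = np.asarray(write_indices)
    # Documented bound: w + s <= s_max. Non-negativity is an assumption of this sketch.
    return bool(np.all(w >= 0) and np.all(w + update_len <= max_len))

print(check_write_indices([2, 1, 1, 3], update_len=4, max_len=8))  # True
print(check_write_indices([5], update_len=4, max_len=8))           # False
```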

Outputs

output: tensor of type T, the updated cache. Must be a network output; it shares the same device memory as the cache input, so the update is performed in place.

Data Types

T: float32, float16, bfloat16

M: int32, int64

Shape Information

cache and output both have shape \([b, d, s_{max}, h]\)

update has shape \([b, d, s, h]\), where \(s \leq s_{max}\)

writeIndices has shape \([b]\)

Where:

  • \(b\) is the batch size

  • \(d\) is the number of heads

  • \(s_{max}\) is the maximum sequence length (must be static)

  • \(s\) is the update sequence length

  • \(h\) is the head size
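The shape relationships above can be expressed as a small consistency check. check_kv_cache_shapes is a hypothetical helper written for illustration. It checks only the rules listed in this section and does not query TensorRT.

```python
def check_kv_cache_shapes(cache_shape, update_shape, write_indices_shape):
    """Validate cache [b, d, s_max, h], update [b, d, s, h], writeIndices [b]."""
    if len(cache_shape) != 4 or len(update_shape) != 4:
        raise ValueError("cache and update must both be 4-D")
    b, d, s_max, h = cache_shape
    ub, ud, s, uh = update_shape
    # update must agree with cache on batch size, head count, and head size.
    if (ub, ud, uh) != (b, d, h):
        raise ValueError(f"update {update_shape} must match cache {cache_shape} in b, d, h")
    if s > s_max:
        raise ValueError(f"update length s={s} exceeds s_max={s_max}")
    if tuple(write_indices_shape) != (b,):
        raise ValueError(f"writeIndices must have shape ({b},), got {write_indices_shape}")
    return {"b": b, "d": d, "s_max": s_max, "s": s, "h": h}

dims = check_kv_cache_shapes((4, 2, 8, 1), (4, 2, 4, 1), (4,))
print(dims)  # {'b': 4, 'd': 2, 's_max': 8, 's': 4, 'h': 1}
```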

DLA Support

Not supported.

Examples

KVCacheUpdate
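import numpy as np
import tensorrt as trt

# get_runner, inputs, outputs, and expected are provided by the surrounding test harness.
network = get_runner.builder.create_network(flags=1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))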
cache_shape = (4, 2, 8, 1)
update_shape = (4, 2, 4, 1)
write_indices_shape = (4,)

cache = network.add_input("cache", dtype=trt.float32, shape=cache_shape)
update = network.add_input("update", dtype=trt.float32, shape=update_shape)
write_indices = network.add_input("write_indices", dtype=trt.int32, shape=write_indices_shape)
# Write update into cache starting at each row's write index; output 0 aliases cache.
layer = network.add_kv_cache_update(cache, update, write_indices, trt.KVCacheMode.LINEAR)
network.mark_output(layer.get_output(0))

cache_data = np.array(
    [
        [0.53, 0.88, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.41, 0.0,  0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.67, 0.0,  0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.32, 0.79, 0.64, 0.0, 0.0, 0.0, 0.0, 0.0],
    ],
    dtype=np.float32,
)
# Broadcast each row across both heads: (4, 8) -> (4, 2, 8, 1). A float32 zeros
# array keeps the result float32, matching the declared input type.
inputs[cache.name] = cache_data[:, None, :, None] + np.zeros((1, 2, 1, 1), dtype=np.float32)

update_data = np.array(
    [
        [0.72, 0.0,  0.0, 0.0],
        [0.55, 0.94, 0.0, 0.0],
        [0.61, 0.28, 0.0, 0.0],
        [0.83, 0.0,  0.0, 0.0],
    ],
    dtype=np.float32,
)
# Same broadcast for the update: (4, 4) -> (4, 2, 4, 1), kept as float32.
inputs[update.name] = update_data[:, None, :, None] + np.zeros((1, 2, 1, 1), dtype=np.float32)

write_indices_data = np.array([2, 1, 1, 3], dtype=np.int32)
inputs[write_indices.name] = write_indices_data

outputs[layer.get_output(0).name] = layer.get_output(0).shape

expected_data = np.array(
    [
        [0.53, 0.88, 0.72, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.41, 0.55, 0.94, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.67, 0.61, 0.28, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.32, 0.79, 0.64, 0.83, 0.0, 0.0, 0.0, 0.0],
    ],
    dtype=np.float32,
)

expected[layer.get_output(0).name] = expected_data[:, None, :, None] + np.zeros((1, 2, 1, 1), dtype=np.float32)

# Point get_runner at the strongly typed network built above
get_runner.network = network
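As a host-side cross-check, the expected values in this example can be reproduced with a few lines of NumPy. The sketch works on the 2-D per-row arrays before they are broadcast across the two heads and does not run TensorRT. The variable names result and result_4d are illustrative.

```python
import numpy as np

# Per-row data from the example (head and head-size axes omitted).
cache_data = np.array(
    [
        [0.53, 0.88, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.41, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.67, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.32, 0.79, 0.64, 0.0, 0.0, 0.0, 0.0, 0.0],
    ],
    dtype=np.float32,
)
update_data = np.array(
    [
        [0.72, 0.0, 0.0, 0.0],
        [0.55, 0.94, 0.0, 0.0],
        [0.61, 0.28, 0.0, 0.0],
        [0.83, 0.0, 0.0, 0.0],
    ],
    dtype=np.float32,
)
write_indices_data = np.array([2, 1, 1, 3], dtype=np.int32)

# LINEAR rule per row: overwrite s consecutive slots starting at writeIndices[i].
result = cache_data.copy()
s = update_data.shape[1]
for i, w in enumerate(write_indices_data):
    result[i, w:w + s] = update_data[i]

# Lift to the example's 4-D layout [b, d, s_max, h] = [4, 2, 8, 1].
result_4d = np.repeat(result[:, None, :, None], 2, axis=1)
print(result_4d.shape)  # (4, 2, 8, 1)
```

Note that the trailing zeros in each update row are written too. Here they land on slots that were already zero, so only the leading values appear to change.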

C++ API

For more information about the C++ IKVCacheUpdateLayer operator, refer to the C++ IKVCacheUpdateLayer documentation.

Python API

For more information about the Python IKVCacheUpdateLayer operator, refer to the Python IKVCacheUpdateLayer documentation.