core.inference.contexts.fused_kv_append_kernel
Module Contents
Functions
| _append_kv_cache_kernel | Triton kernel to append key and value vectors to pre-sliced paged KV cache tensors. |
| triton_append_key_value_cache | Append to KV cache using a high-performance, standalone Triton kernel. |
API
- core.inference.contexts.fused_kv_append_kernel._append_kv_cache_kernel(
      key_ptr,
      value_ptr,
      key_cache_ptr,
      value_cache_ptr,
      block_idx_ptr,
      local_kv_seq_idx_ptr,
      stride_key_token,
      stride_key_head,
      stride_key_hdim,
      stride_value_token,
      stride_value_head,
      stride_value_hdim,
      stride_cache_block,
      stride_cache_pos,
      stride_cache_head,
      stride_cache_hdim,
      n_tokens: triton.language.int32,
      num_heads: triton.language.int32,
      H_DIM: triton.language.int32,
      BLOCK_SIZE_H: triton.language.constexpr,
  )
Triton kernel to append key and value vectors to pre-sliced paged KV cache tensors.
Each program instance handles one head of one token. The grid is 2D: (n_tokens, num_heads). Each instance:
1. Identifies which token and head it is responsible for using tl.program_id.
2. Loads the block_idx and local_pos for that token.
3. Loads the h_dim vector for its assigned key/value head.
4. Calculates the destination address in the 4D cache slices.
5. Writes (scatters) the head vector to its destination in the cache.
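The per-program steps above amount to a strided gather from the key (or value) tensor followed by a strided scatter into a 4D cache slice. The following NumPy sketch emulates the grid on the CPU, with two nested loops standing in for tl.program_id(0) and tl.program_id(1). It is not the Triton kernel: `append_kv_reference` and `element_strides` are hypothetical helpers, strides are counted in elements rather than bytes, `key` is assumed to have the documented (batch_size, 1, num_heads, h_dim) shape with the singleton axis skipped, and BLOCK_SIZE_H is assumed to be a power of two no smaller than H_DIM, with the extra lanes masked off (Triton's tl.arange requires a power-of-two range).

```python
import numpy as np


def append_kv_reference(src_flat, cache_flat, block_idx, local_pos,
                        src_strides, cache_strides,
                        n_tokens, num_heads, h_dim, block_size_h):
    """CPU emulation of the per-(token, head) programs (hypothetical helper).

    Strides are in elements. src_strides = (token, head, hdim);
    cache_strides = (block, pos, head, hdim).
    """
    s_tok, s_head, s_hdim = src_strides
    c_block, c_pos, c_head, c_hdim = cache_strides
    # Lane offsets padded to block_size_h, then masked back down to h_dim.
    offs = np.arange(block_size_h)
    offs = offs[offs < h_dim]
    for token in range(n_tokens):        # plays the role of tl.program_id(0)
        for head in range(num_heads):    # plays the role of tl.program_id(1)
            b = int(block_idx[token])    # destination block for this token
            p = int(local_pos[token])    # slot within that block
            src = token * s_tok + head * s_head + offs * s_hdim
            dst = b * c_block + p * c_pos + head * c_head + offs * c_hdim
            cache_flat[dst] = src_flat[src]  # scatter one head vector


def element_strides(a):
    # NumPy reports strides in bytes; convert them to element counts.
    return tuple(s // a.itemsize for s in a.strides)


# Usage: 3 tokens, 2 heads, h_dim = 3 (padded to 4 lanes), 4 blocks of 8 slots.
n_tokens, num_heads, h_dim = 3, 2, 3
num_blocks, block_size = 4, 8
key = np.arange(1, n_tokens * num_heads * h_dim + 1, dtype=np.float32)
key = key.reshape(n_tokens, 1, num_heads, h_dim)  # (batch, 1, heads, h_dim)
key_cache = np.zeros((num_blocks, block_size, num_heads, h_dim), dtype=np.float32)
block_idx = np.array([2, 2, 0])
local_pos = np.array([5, 6, 0])

ks = element_strides(key)
# Both arrays are contiguous, so reshape(-1) returns views and writes persist.
append_kv_reference(key.reshape(-1), key_cache.reshape(-1), block_idx, local_pos,
                    (ks[0], ks[2], ks[3]), element_strides(key_cache),
                    n_tokens, num_heads, h_dim, block_size_h=4)
```

After the call, key_cache[block_idx[t], local_pos[t]] holds key[t, 0] for every token t, and every other slot is untouched.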
- core.inference.contexts.fused_kv_append_kernel.triton_append_key_value_cache(
      layer_number: int,
      key: torch.Tensor,
      value: torch.Tensor,
      memory_buffer: torch.Tensor,
      padded_active_token_count: int,
      token_to_block_idx: torch.Tensor,
      token_to_local_position_within_kv_block: torch.Tensor,
  )
Append to KV cache using a high-performance, standalone Triton kernel.
- Parameters:
layer_number (int) – Layer number (1-based).
key (Tensor) – Key tensor of shape (batch_size, 1, num_heads, h_dim).
value (Tensor) – Value tensor of shape (batch_size, 1, num_heads, h_dim).
memory_buffer (Tensor) – The 6D KV cache tensor to write to.
padded_active_token_count (int) – The number of active tokens to process.
token_to_block_idx (Tensor) – Tensor mapping token index to its block index in the cache.
token_to_local_position_within_kv_block (Tensor) – Tensor mapping token index to its position within a block.
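Functionally, these parameters describe a scatter of the first padded_active_token_count rows of key and value into one layer of the 6D buffer. The vectorized NumPy sketch below is a CPU reference for checking that behavior on small inputs, not the library's implementation: `append_key_value_cache_reference` is a hypothetical name, and the buffer layout (2, num_layers, num_blocks, block_size, num_heads, h_dim), with keys at index 0 and values at index 1 of the first axis, is an assumption. It is chosen to be consistent with the kernel's per-layer 4D cache slices and its block/position/head/hdim strides, but the source does not state it.

```python
import numpy as np


def append_key_value_cache_reference(layer_number, key, value, memory_buffer,
                                     padded_active_token_count,
                                     token_to_block_idx,
                                     token_to_local_position_within_kv_block):
    """Vectorized CPU reference for the documented append (hypothetical helper).

    ASSUMPTION: memory_buffer has layout
    (2, num_layers, num_blocks, block_size, num_heads, h_dim),
    where index 0 on the first axis holds keys and index 1 holds values.
    """
    n = padded_active_token_count
    # Basic integer indexing returns 4D views, so writes land in memory_buffer.
    key_cache = memory_buffer[0, layer_number - 1]    # layer_number is 1-based
    value_cache = memory_buffer[1, layer_number - 1]
    blocks = token_to_block_idx[:n]
    positions = token_to_local_position_within_kv_block[:n]
    # Drop the singleton axis: (n, 1, heads, h_dim) -> (n, heads, h_dim).
    key_cache[blocks, positions] = key[:n, 0]
    value_cache[blocks, positions] = value[:n, 0]


# Usage: 2 layers, 3 blocks of 4 slots, 2 heads, h_dim = 4.
num_layers, num_blocks, block_size, num_heads, h_dim = 2, 3, 4, 2, 4
memory_buffer = np.zeros((2, num_layers, num_blocks, block_size, num_heads, h_dim),
                         dtype=np.float32)
key = np.arange(1, 3 * num_heads * h_dim + 1, dtype=np.float32)
key = key.reshape(3, 1, num_heads, h_dim)
value = -key
token_to_block_idx = np.array([1, 1, 2])
token_to_local_position_within_kv_block = np.array([0, 1, 3])

# Write into layer 2; only the first 2 tokens are processed, the third is ignored.
append_key_value_cache_reference(2, key, value, memory_buffer, 2,
                                 token_to_block_idx,
                                 token_to_local_position_within_kv_block)
```

Because it uses whole-array indexing instead of per-(token, head) programs, this reference is convenient for comparing against the kernel's output, while the loop-based emulation of _append_kv_cache_kernel mirrors the kernel's addressing one program at a time.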