core.inference.contexts.fused_kv_append_kernel

Module Contents

Functions

_append_kv_cache_kernel

Triton kernel to append key and value vectors to pre-sliced paged KV cache tensors.

triton_append_key_value_cache

Append to KV cache using a high-performance, standalone Triton kernel.

API

core.inference.contexts.fused_kv_append_kernel._append_kv_cache_kernel(
key_ptr,
value_ptr,
key_cache_ptr,
value_cache_ptr,
block_idx_ptr,
local_kv_seq_idx_ptr,
stride_key_token,
stride_key_head,
stride_key_hdim,
stride_value_token,
stride_value_head,
stride_value_hdim,
stride_cache_block,
stride_cache_pos,
stride_cache_head,
stride_cache_hdim,
n_tokens: triton.language.int32,
num_heads: triton.language.int32,
H_DIM: triton.language.int32,
BLOCK_SIZE_H: triton.language.constexpr,
)

Triton kernel to append key and value vectors to pre-sliced paged KV cache tensors.

Each program instance handles one head of one token. The grid is 2D: (n_tokens, num_heads).

  1. It identifies which token and head it is responsible for using tl.program_id.

  2. It loads the block_idx and local_pos for that token.

  3. It loads the h_dim-element vector of its assigned key/value head.

  4. It calculates the destination address in the 4D cache slices.

  5. It writes (scatters) the head vector to its destination in the cache.
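The five steps above can be sketched as a plain NumPy loop. This is only a reference emulation under stated assumptions, not the Triton kernel itself: the function name `append_kv_reference` and the helper `element_strides` are hypothetical, the key/value inputs are assumed to be already reshaped to (n_tokens, num_heads, h_dim), and both cache slices are assumed to be contiguous with shape (num_blocks, block_size, num_heads, h_dim) and identical strides.

```python
import numpy as np


def element_strides(a):
    """Strides of `a` in elements (NumPy reports them in bytes)."""
    return tuple(s // a.itemsize for s in a.strides)


def append_kv_reference(key, value, key_cache, value_cache, block_idx, local_pos):
    """Hypothetical loop-based reference for the scatter described above.

    Assumed shapes (not confirmed by the source):
      key, value:             (n_tokens, num_heads, h_dim)
      key_cache, value_cache: (num_blocks, block_size, num_heads, h_dim)
      block_idx, local_pos:   (n_tokens,) integer arrays
    """
    assert key_cache.flags.c_contiguous and value_cache.flags.c_contiguous
    assert key_cache.strides == value_cache.strides  # one stride set for both caches

    n_tokens, num_heads, h_dim = key.shape
    s_block, s_pos, s_head, s_hdim = element_strides(key_cache)

    # Flat views, so the destination can be addressed with raw offsets.
    flat_k = key_cache.reshape(-1)
    flat_v = value_cache.reshape(-1)
    d = np.arange(h_dim)

    for token in range(n_tokens):          # step 1: grid axis 0
        b = int(block_idx[token])          # step 2: block index of this token
        p = int(local_pos[token])          # step 2: position inside that block
        for head in range(num_heads):      # step 1: grid axis 1
            k_vec = key[token, head]       # step 3: load the head vector
            v_vec = value[token, head]
            # step 4: destination offsets inside the 4D cache slice
            dst = b * s_block + p * s_pos + head * s_head + d * s_hdim
            flat_k[dst] = k_vec            # step 5: scatter into the cache
            flat_v[dst] = v_vec


# Small usage example: 2 tokens, 2 heads, h_dim = 4, 3 blocks of 4 positions.
key = np.arange(16, dtype=np.float32).reshape(2, 2, 4)
value = -key
key_cache = np.zeros((3, 4, 2, 4), dtype=np.float32)
value_cache = np.zeros_like(key_cache)
append_kv_reference(key, value, key_cache, value_cache,
                    np.array([2, 0]), np.array([1, 3]))
```

In the real kernel each (token, head) pair would presumably be handled by its own program instance rather than by nested loops; the loops here only make the addressing arithmetic easy to follow.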

core.inference.contexts.fused_kv_append_kernel.triton_append_key_value_cache(
layer_number: int,
key: torch.Tensor,
value: torch.Tensor,
memory_buffer: torch.Tensor,
padded_active_token_count: int,
token_to_block_idx: torch.Tensor,
token_to_local_position_within_kv_block: torch.Tensor,
) → None

Append to KV cache using a high-performance, standalone Triton kernel.

Parameters:
  • layer_number (int) – Layer number (1-based).

  • key (Tensor) – Key tensor of shape (batch_size, 1, num_heads, h_dim).

  • value (Tensor) – Value tensor of shape (batch_size, 1, num_heads, h_dim).

  • memory_buffer (Tensor) – The 6D KV cache tensor to write to.

  • padded_active_token_count (int) – The number of active tokens to process.

  • token_to_block_idx (Tensor) – Tensor mapping token index to its block index in the cache.

  • token_to_local_position_within_kv_block (Tensor) – Tensor mapping token index to its position within a block.
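As a rough illustration of how the parameters above fit together, the following NumPy sketch mimics what the wrapper appears to do before the scatter. It is not the library's implementation. The function name `append_kv_wrapper_reference` is hypothetical, and the 6D layout (2, num_layers, num_blocks, block_size, num_heads, h_dim), with index 0 holding keys and index 1 holding values, is an assumption that the source does not confirm.

```python
import numpy as np


def append_kv_wrapper_reference(
    layer_number,
    key,
    value,
    memory_buffer,
    padded_active_token_count,
    token_to_block_idx,
    token_to_local_position_within_kv_block,
):
    """Hypothetical NumPy stand-in for the wrapper; writes into memory_buffer in place."""
    n = padded_active_token_count

    # ASSUMPTION: memory_buffer is laid out as
    # (2, num_layers, num_blocks, block_size, num_heads, h_dim).
    # layer_number is 1-based, so subtract 1 to get the layer index.
    key_cache = memory_buffer[0, layer_number - 1]    # 4D view, no copy
    value_cache = memory_buffer[1, layer_number - 1]  # 4D view, no copy

    # Drop the singleton axis: (batch, 1, heads, h_dim) -> (n, heads, h_dim).
    k = key[:n, 0]
    v = value[:n, 0]
    blocks = token_to_block_idx[:n]
    positions = token_to_local_position_within_kv_block[:n]

    # Scatter every token's heads to (block, position); this assumes each
    # (block, position) pair occurs at most once among the active tokens.
    key_cache[blocks, positions] = k
    value_cache[blocks, positions] = v


# Usage: 2 layers, 3 blocks of 4 positions, 2 heads, h_dim = 4.
rng = np.random.default_rng(0)
memory_buffer = np.zeros((2, 2, 3, 4, 2, 4), dtype=np.float32)
key = rng.standard_normal((3, 1, 2, 4)).astype(np.float32)
value = rng.standard_normal((3, 1, 2, 4)).astype(np.float32)
append_kv_wrapper_reference(
    2, key, value, memory_buffer, 2,
    np.array([1, 2, 0]), np.array([0, 3, 1]),
)
```

Only the first `padded_active_token_count` tokens are written in this sketch, so the third token in the usage example leaves the buffer untouched.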