core.inference.contexts.mamba_slot_allocator#

Module Contents#

Classes#

MambaSlotAllocator

Manages Mamba state caching for prefix caching in hybrid models.

Data#

API#

core.inference.contexts.mamba_slot_allocator.MAX_INTERMEDIATE_OFFSETS_PER_REQUEST#

3

class core.inference.contexts.mamba_slot_allocator.MambaSlotAllocator(
context: core.inference.contexts.dynamic_context.DynamicInferenceContext,
max_slots: int,
num_mamba_layers: int,
conv_states_shape: tuple,
ssm_states_shape: tuple,
conv_states_dtype: torch.dtype,
ssm_states_dtype: torch.dtype,
)#

Manages Mamba state caching for prefix caching in hybrid models.

Owns the Mamba cache slot pool, block-to-slot mappings, hash-to-block mapping, and intermediate state tracking. Accesses KV allocator state (ref counts, timestamps, block hashes) via the parent context.

Parameters:
  • context – The DynamicInferenceContext that owns this allocator.

  • max_slots – Maximum number of cache slots.

  • num_mamba_layers – Number of Mamba layers in the model.

  • conv_states_shape – Shape of per-slot conv state (excluding layer/slot dims).

  • ssm_states_shape – Shape of per-slot SSM state (excluding layer/slot dims).

  • conv_states_dtype – Dtype for conv state tensors.

  • ssm_states_dtype – Dtype for SSM state tensors.

Initialization
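The constructor parameters determine the size of the two per-slot state pools, which are conceptually shaped [num_mamba_layers, max_slots, *state_shape]. A minimal sketch for budgeting that memory; the function name, the Mamba2-style example shapes, and the byte sizes are assumptions for illustration, not part of this API:

```python
from math import prod

def mamba_cache_bytes(max_slots, num_mamba_layers,
                      conv_states_shape, ssm_states_shape,
                      conv_bytes_per_elem, ssm_bytes_per_elem):
    """Estimate the footprint of the allocator's two state pools by
    multiplying out element counts for conv and SSM states."""
    conv_elems = num_mamba_layers * max_slots * prod(conv_states_shape)
    ssm_elems = num_mamba_layers * max_slots * prod(ssm_states_shape)
    return conv_elems * conv_bytes_per_elem + ssm_elems * ssm_bytes_per_elem

# Hypothetical Mamba2-style shapes: conv state [channels, d_conv - 1],
# SSM state [heads, head_dim, d_state]; bf16 conv, fp32 SSM.
total = mamba_cache_bytes(
    max_slots=1024, num_mamba_layers=24,
    conv_states_shape=(4096, 3), ssm_states_shape=(32, 64, 128),
    conv_bytes_per_elem=2, ssm_bytes_per_elem=4,
)
```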

allocate_slots_batch(block_ids: list) list#

Get free Mamba cache slots for multiple blocks, evicting if necessary.

Handles deduplication: if the same block_id appears multiple times, only one slot is allocated and all occurrences get the same slot.

Parameters:

block_ids – List of KV block IDs to associate with slots.

Returns:

List of allocated slot indices (same length as block_ids).
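The deduplication contract can be sketched with a plain dict and free list (hypothetical names, not the real class; eviction is elided):

```python
def allocate_slots_batch_sketch(block_ids, block_to_slot, free_slots):
    """Each unique block_id gets exactly one slot; repeated occurrences
    in the same batch reuse it, per the documented dedup behavior."""
    out = []
    for bid in block_ids:
        if bid not in block_to_slot:
            # The real allocator evicts LRU slots here when the pool runs dry.
            block_to_slot[bid] = free_slots.pop()
        out.append(block_to_slot[bid])
    return out

mapping = {}
slots = allocate_slots_batch_sketch([7, 7, 9], mapping, free_slots=[0, 1, 2])
# block 7 appears twice but consumes only one slot
```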

_evict_lru_slots_batch(num_needed: int) list#

Evict the least recently used Mamba cache slots.

Does NOT return slots to the free pool — caller takes ownership.

Parameters:

num_needed – Number of slots to evict.

Returns:

List of freed slot indices.
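The LRU pick and the "caller takes ownership" contract might look like this sketch (hypothetical names; the real class reads ref counts and timestamps from the parent context):

```python
def evict_lru_slots_sketch(num_needed, slot_to_block, timestamps, ref_counts):
    """Pick the num_needed oldest slots whose backing block is unreferenced.
    Victims are removed from the mapping but NOT returned to the free pool;
    the caller takes ownership, matching the documented contract."""
    candidates = [s for s, bid in slot_to_block.items()
                  if ref_counts.get(bid, 0) == 0]
    candidates.sort(key=lambda s: timestamps[s])  # oldest first
    victims = candidates[:num_needed]
    for s in victims:
        del slot_to_block[s]
    return victims
```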

get_slot(block_id: int) int#

Return the cache slot for a block, or -1 if none.

Parameters:

block_id – The KV block ID.

Returns:

Slot index or -1.

has_state(block_id: int) bool#

Check if a block has cached Mamba state.

invalidate_block(block_id: int) None#

Free cache slot and clear mappings for a block.

Parameters:

block_id – The KV block ID.

_invalidate_blocks_batch(block_ids_list: list) None#

Free cache slots and clear mappings for multiple blocks at once.

Vectorized version of invalidate_block that avoids per-block .item() GPU syncs. Used by on_kv_blocks_deregistered for bulk eviction.

Parameters:

block_ids_list – List of block IDs to invalidate.
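The batching win is doing all lookups in one pass before freeing, rather than one `.item()` host sync per block. A pure-Python sketch of the shape of that logic (hypothetical names; real tensors would use a single fancy-indexed gather):

```python
def invalidate_blocks_batch_sketch(block_ids, block_to_slot, free_slots):
    """Look up every block's slot first, then return them all to the
    free pool in one pass. On GPU this corresponds to one batched
    gather instead of N per-block .item() syncs."""
    slots = [block_to_slot.pop(bid) for bid in block_ids if bid in block_to_slot]
    free_slots.extend(slots)
    return slots
```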

on_kv_blocks_deregistered(
block_ids_list: list,
hashes_to_delete: set,
) None#

Handle KV block deregistration by cleaning up Mamba state.

Called by KVBlockAllocator._deregister_blocks via callback.

Parameters:
  • block_ids_list – List of deregistered block IDs.

  • hashes_to_delete – Set of hashes being deregistered (excludes -1).

store_from_tensors(
block_id: int,
layer_idx: int,
ssm_state: torch.Tensor,
conv_state: torch.Tensor,
) None#

Write provided state tensors to a cache slot for a specific layer.

Parameters:
  • block_id – The KV block ID.

  • layer_idx – The Mamba layer index.

  • ssm_state – SSM state tensor to store.

  • conv_state – Conv state tensor to store.

store_from_live_batch(slots: list, request_indices: list) None#

Copy all layers from live per-request buffers to cache slots.

Parameters:
  • slots – List of cache slot indices.

  • request_indices – List of context request indices.

restore_to_live(request_idx: int, block_id: int) bool#

Copy all layers from cache slot to live request state.

Parameters:
  • request_idx – The context request index.

  • block_id – The KV block ID.

Returns:

True if state was restored, False if block has no cached state.
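A typical prefix-cache lookup checks for cached state and falls back to recomputation on a miss. A toy dict-backed stand-in for that restore path (hypothetical class, not the real allocator):

```python
class ToyMambaCache:
    """Toy stand-in illustrating the has_state / restore_to_live contract."""

    def __init__(self):
        self._state = {}  # block_id -> (ssm_state, conv_state)

    def has_state(self, block_id):
        return block_id in self._state

    def restore_to_live(self, live, block_id):
        """Copy cached state into the live request buffers; return False
        on a miss so the caller recomputes the prefix instead."""
        if not self.has_state(block_id):
            return False
        live["ssm"], live["conv"] = self._state[block_id]
        return True
```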

register_block_hashes_batch(block_ids: list, hashes: list) None#

Register multiple blocks as having cached Mamba state.

Only registers entries where hash > 0.

Parameters:
  • block_ids – List of block IDs.

  • hashes – List of hash values (same length as block_ids).
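The "hash > 0" filter means sentinel hashes (such as -1 for partially filled blocks) are never registered. A minimal sketch with hypothetical names:

```python
def register_block_hashes_sketch(block_ids, hashes, hash_to_block):
    """Register only entries with a positive hash; sentinel values
    (e.g. -1 or 0) are skipped, per the documented behavior."""
    for bid, h in zip(block_ids, hashes):
        if h > 0:
            hash_to_block[h] = bid
```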

compute_and_store_offsets(
req,
current_id: int,
skip_tokens: int,
prefill_chunk_length: int,
num_matched_blocks: int,
matched_block_ids: list,
overall_required_blocks: int,
) None#

Compute intermediate state extraction offsets and store per-request.

Parameters:
  • req – The inference request.

  • current_id – Context request index.

  • skip_tokens – Number of tokens skipped due to a Mamba state match.

  • prefill_chunk_length – Total prefill chunk length before skipping.

  • num_matched_blocks – Number of KV-matched blocks.

  • matched_block_ids – List of matched KV block IDs.

  • overall_required_blocks – Total blocks needed for this request.

get_intermediate_gpu_data()#

Get intermediate offsets and counts as GPU tensor slices for the current prefill batch.

Returns:

(offsets_gpu, counts_gpu), where offsets_gpu is a [prefill_count, 3] int32 GPU tensor and counts_gpu is a [prefill_count] int32 GPU tensor. Returns (None, None) if there are no prefill requests or no intermediates.

Return type:

Tuple of (offsets_gpu, counts_gpu)

commit_intermediate_states() None#

Commit intermediate states from pre-allocated output buffers to cache.

Called after the forward pass (including CUDA graph replay) completes. Batched pipeline: collect data, allocate slots, copy states, register hashes.

_collect_commit_data()#

Extract commit data from GPU intermediate state tracking.

Returns:

Tuple of (intermediate_bids, src_offsets, eos_bids, eos_ctx_indices, all_hashes) or None if nothing to commit. all_hashes covers intermediate_bids + eos_bids in that order.

_copy_intermediate_to_cache(src_offsets: list, slots: list) None#

Copy intermediate states from output buffers to cache slots.

Uses a fancy-indexed GPU device-to-device (D2D) copy: two kernel launches (one per state pool) instead of 2N per-block copies.

Parameters:
  • src_offsets – Source indices into intermediate_ssm_out/intermediate_conv_out.

  • slots – Destination cache slot indices.
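The semantics of the fancy-indexed copy, emulated with plain lists (hypothetical names). With real tensors this is one advanced-indexing assignment per pool, e.g. `cache_ssm[slots] = intermediate_ssm_out[src_offsets]`, so SSM plus conv states cost two kernel launches total:

```python
def fancy_copy_sketch(cache, out_buf, slots, src_offsets):
    """Pure-Python stand-in for cache[slots] = out_buf[src_offsets]:
    gather rows at src_offsets from the output buffer and scatter
    them into the cache at the destination slots."""
    for dst, src in zip(slots, src_offsets):
        cache[dst] = out_buf[src]
```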

_clear_intermediate_state() None#

Clear all per-request intermediate state tracking.

reset() None#

Reset all state (mappings, free pool, cache, intermediate tracking).