core.inference.contexts.mamba_slot_allocator#
Module Contents#
Classes#
MambaSlotAllocator – Manages Mamba state caching for prefix caching in hybrid models.
Data#
API#
- core.inference.contexts.mamba_slot_allocator.MAX_INTERMEDIATE_OFFSETS_PER_REQUEST#
3
- class core.inference.contexts.mamba_slot_allocator.MambaSlotAllocator(
- context: core.inference.contexts.dynamic_context.DynamicInferenceContext,
- max_slots: int,
- num_mamba_layers: int,
- conv_states_shape: tuple,
- ssm_states_shape: tuple,
- conv_states_dtype: torch.dtype,
- ssm_states_dtype: torch.dtype,
- )#
Manages Mamba state caching for prefix caching in hybrid models.
Owns the Mamba cache slot pool, block-to-slot mappings, hash-to-block mapping, and intermediate state tracking. Accesses KV allocator state (ref counts, timestamps, block hashes) via the parent context.
- Parameters:
context – The DynamicInferenceContext that owns this allocator.
max_slots – Maximum number of cache slots.
num_mamba_layers – Number of Mamba layers in the model.
conv_states_shape – Shape of per-slot conv state (excluding layer/slot dims).
ssm_states_shape – Shape of per-slot SSM state (excluding layer/slot dims).
conv_states_dtype – Dtype for conv state tensors.
ssm_states_dtype – Dtype for SSM state tensors.
Initialization
- allocate_slots_batch(block_ids: list) list#
Get free Mamba cache slots for multiple blocks, evicting if necessary.
Handles deduplication: if the same block_id appears multiple times, only one slot is allocated and all occurrences get the same slot.
- Parameters:
block_ids – List of KV block IDs to associate with slots.
- Returns:
List of allocated slot indices (same length as block_ids).
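The deduplication rule above can be sketched in plain Python (hypothetical names; a minimal sketch that omits the LRU eviction the real allocator performs when the free pool runs dry):

```python
def allocate_slots_batch(block_to_slot: dict, free_slots: list, block_ids: list) -> list:
    """Assign one slot per unique block_id; repeated block_ids share a slot."""
    result = []
    for bid in block_ids:
        if bid not in block_to_slot:
            # New block: take a slot from the free pool.
            block_to_slot[bid] = free_slots.pop()
        result.append(block_to_slot[bid])
    return result

mapping, pool = {}, [2, 1, 0]
slots = allocate_slots_batch(mapping, pool, [7, 8, 7])  # block 7 appears twice
# slots has length 3, but only two slots were consumed from the pool.
```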
- _evict_lru_slots_batch(num_needed: int) list#
Evict the least recently used Mamba cache slots.
Does NOT return slots to the free pool — caller takes ownership.
- Parameters:
num_needed – Number of slots to evict.
- Returns:
List of freed slot indices.
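A host-side sketch of the LRU selection (illustrative structures; the real allocator reads timestamps from the parent context). Note the freed slots are handed to the caller rather than appended to the free pool:

```python
def evict_lru_slots_batch(block_to_slot: dict, timestamps: dict, num_needed: int) -> list:
    """Evict the num_needed least recently used blocks; return their slots."""
    victims = sorted(block_to_slot, key=lambda bid: timestamps[bid])[:num_needed]
    # Freed slots go to the caller, NOT back into the free pool.
    return [block_to_slot.pop(bid) for bid in victims]

b2s = {10: 0, 11: 1, 12: 2}
ts = {10: 5.0, 11: 1.0, 12: 3.0}  # block 11 is the oldest, then 12
freed = evict_lru_slots_batch(b2s, ts, 2)
```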
- get_slot(block_id: int) int#
Return the cache slot for a block, or -1 if none.
- Parameters:
block_id – The KV block ID.
- Returns:
Slot index or -1.
- has_state(block_id: int) bool#
Check if a block has cached Mamba state.
- invalidate_block(block_id: int) None#
Free cache slot and clear mappings for a block.
- Parameters:
block_id – The KV block ID.
- _invalidate_blocks_batch(block_ids_list: list) None#
Free cache slots and clear mappings for multiple blocks at once.
Vectorized version of invalidate_block that avoids per-block .item() GPU syncs. Used by on_kv_blocks_deregistered for bulk eviction.
- Parameters:
block_ids_list – List of block IDs to invalidate.
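A pure-Python sketch of batched invalidation (hypothetical mapping names; the real method operates on GPU tensors, where processing the whole batch at once is what avoids a `.item()` host sync per block):

```python
def invalidate_blocks_batch(block_to_slot: dict, hash_to_block: dict,
                            free_slots: list, block_ids: list) -> None:
    """Invalidate a whole batch of blocks in one pass rather than one by one."""
    doomed = set(block_ids)
    # Return every freed slot to the pool with a single extend.
    free_slots.extend(block_to_slot.pop(bid) for bid in block_ids
                      if bid in block_to_slot)
    # Drop hash entries that point at any invalidated block.
    for h in [h for h, b in hash_to_block.items() if b in doomed]:
        del hash_to_block[h]

b2s = {5: 0, 6: 1}
h2b = {111: 5, 222: 9}
pool = []
invalidate_blocks_batch(b2s, h2b, pool, [5, 7])  # block 7 has no slot; ignored
```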
- on_kv_blocks_deregistered(
- block_ids_list: list,
- hashes_to_delete: set,
- )#
Handle KV block deregistration by cleaning up Mamba state.
Called by KVBlockAllocator._deregister_blocks via callback.
- Parameters:
block_ids_list – List of deregistered block IDs.
hashes_to_delete – Set of hashes being deregistered (excludes -1).
- store_from_tensors(
- block_id: int,
- layer_idx: int,
- ssm_state: torch.Tensor,
- conv_state: torch.Tensor,
- )#
Write provided state tensors to a cache slot for a specific layer.
- Parameters:
block_id – The KV block ID.
layer_idx – The Mamba layer index.
ssm_state – SSM state tensor to store.
conv_state – Conv state tensor to store.
- store_from_live_batch(slots: list, request_indices: list) None#
Copy all layers from live per-request buffers to cache slots.
- Parameters:
slots – List of cache slot indices.
request_indices – List of context request indices.
- restore_to_live(request_idx: int, block_id: int) bool#
Copy all layers from cache slot to live request state.
- Parameters:
request_idx – The context request index.
block_id – The KV block ID.
- Returns:
True if state was restored, False if block has no cached state.
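A hypothetical caller-side sketch of the restore path: on a prefix-cache hit, cached state is copied into the live request buffers; otherwise the caller falls back to recomputing the prefix (all names here are illustrative, not the module's actual call sites):

```python
def try_restore(allocator, request_idx: int, block_id: int) -> bool:
    """Return True if cached Mamba state was restored for this request."""
    if not allocator.has_state(block_id):
        return False  # no cached state: the prefix must be recomputed
    return allocator.restore_to_live(request_idx, block_id)

class _FakeAllocator:
    """Stand-in exposing only the two methods the sketch relies on."""
    def __init__(self, cached):
        self.cached = cached
    def has_state(self, block_id):
        return block_id in self.cached
    def restore_to_live(self, request_idx, block_id):
        return True  # the real method copies all layers, then reports success

alloc = _FakeAllocator(cached={42})
hit = try_restore(alloc, request_idx=0, block_id=42)
miss = try_restore(alloc, request_idx=0, block_id=43)
```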
- register_block_hashes_batch(block_ids: list, hashes: list) None#
Register multiple blocks as having cached Mamba state.
Only registers entries where hash > 0.
- Parameters:
block_ids – List of block IDs.
hashes – List of hash values (same length as block_ids).
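A small sketch of the filtering rule (hypothetical mapping name): entries whose hash is not positive, such as the -1 sentinel, are skipped:

```python
def register_block_hashes_batch(hash_to_block: dict, block_ids: list, hashes: list) -> None:
    """Register cached-state blocks, skipping sentinel (non-positive) hashes."""
    for bid, h in zip(block_ids, hashes):
        if h > 0:
            hash_to_block[h] = bid

h2b = {}
register_block_hashes_batch(h2b, [1, 2, 3], [101, -1, 103])  # block 2 skipped
```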
- compute_and_store_offsets(
- req,
- current_id: int,
- skip_tokens: int,
- prefill_chunk_length: int,
- num_matched_blocks: int,
- matched_block_ids: list,
- overall_required_blocks: int,
- )#
Compute intermediate state extraction offsets and store per-request.
- Parameters:
req – The inference request.
current_id – Context request index.
skip_tokens – Number of tokens being skipped (mamba match).
prefill_chunk_length – Total prefill chunk length before skipping.
num_matched_blocks – Number of KV-matched blocks.
matched_block_ids – List of matched KV block IDs.
overall_required_blocks – Total blocks needed for this request.
- get_intermediate_gpu_data()#
Get intermediate offsets and counts as GPU tensor slices for current prefill batch.
- Returns:
offsets_gpu – [prefill_count, 3] int32 GPU tensor; counts_gpu – [prefill_count] int32 GPU tensor. Returns (None, None) if there are no prefill requests or no intermediates.
- Return type:
Tuple of (offsets_gpu, counts_gpu)
- commit_intermediate_states() None#
Commit intermediate states from pre-allocated output buffers to cache.
Called after the forward pass (including CUDA graph replay) completes. Batched pipeline: collect data, allocate slots, copy states, register hashes.
- _collect_commit_data()#
Extract commit data from GPU intermediate state tracking.
- Returns:
Tuple of (intermediate_bids, src_offsets, eos_bids, eos_ctx_indices, all_hashes) or None if nothing to commit. all_hashes covers intermediate_bids + eos_bids in that order.
- _copy_intermediate_to_cache(src_offsets: list, slots: list) None#
Copy intermediate states from output buffers to cache slots.
Uses fancy-indexed GPU D2D copy (2 kernel launches instead of 2N).
- Parameters:
src_offsets – Source indices into intermediate_ssm_out/intermediate_conv_out.
slots – Destination cache slot indices.
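A CPU sketch of the fancy-indexed copy, using NumPy in place of CUDA tensors (shapes are illustrative): one advanced-indexing assignment per state buffer, i.e. two bulk copies regardless of how many slots are written, which is what replaces the 2N per-slot copies:

```python
import numpy as np

def copy_intermediate_to_cache(ssm_out, conv_out, ssm_cache, conv_cache,
                               src_offsets, slots):
    """Scatter intermediate states into cache slots with two bulk copies."""
    src = np.asarray(src_offsets)
    dst = np.asarray(slots)
    # Advanced indexing copies every selected slot (across all layers) at once.
    ssm_cache[:, dst] = ssm_out[:, src]
    conv_cache[:, dst] = conv_out[:, src]

layers, n_out, n_slots, d = 2, 4, 8, 3
ssm_out = np.arange(layers * n_out * d, dtype=np.float32).reshape(layers, n_out, d)
conv_out = -ssm_out
ssm_cache = np.zeros((layers, n_slots, d), dtype=np.float32)
conv_cache = np.zeros_like(ssm_cache)
copy_intermediate_to_cache(ssm_out, conv_out, ssm_cache, conv_cache,
                           src_offsets=[1, 3], slots=[5, 0])
```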
- _clear_intermediate_state() None#
Clear all per-request intermediate state tracking.
- reset() None#
Reset all state (mappings, free pool, cache, intermediate tracking).