core.inference.text_generation_controllers.mtp_utils_pytorch#
Module Contents#
Functions#
`rewind_kv_cache`: Update the KV cache bookkeeping for speculative decoding.

`verify_speculative_tokens`: Verify speculative tokens against input tokens and compute acceptance.

`prepare_next_forward_pass`: Prepare data for the next forward pass after speculative token verification.

`mamba_state_selective_copy`: Mamba speculative rewind state update.
API#
- core.inference.text_generation_controllers.mtp_utils_pytorch.rewind_kv_cache(
- accepted_counts,
- prefill_status,
- last_kv_block_offset,
- kv_length_offsets,
- kv_block_counts,
- last_kv_block_id,
- kv_block_ids,
- num_speculative_tokens,
- block_size_tokens,
- num_active_requests=None,
- )
Update the KV cache bookkeeping for speculative decoding.
After a forward pass with speculative tokens, some of those tokens may be rejected. This function “rewinds” the KV cache bookkeeping so that it reflects only the accepted tokens.
When speculative tokens are rejected, we need to:

- Update kv_length_offsets (total sequence length)
- Update last_kv_block_offset (position within the last block)
- If rewinding crosses a block boundary:
  - Reduce kv_block_counts
  - Update last_kv_block_id to point to the previous block
  - Clear the entry in kv_block_ids for the released block
Mutates the input tensors in-place.
Returns (blocks_to_release, remove_mask).
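The per-request rewind arithmetic can be sketched in plain Python. This is an illustrative scalar version only; the helper name and its scalar arguments are hypothetical stand-ins for the batched tensor bookkeeping the real function mutates in-place:

```python
def rewind_one_request(kv_length, last_block_offset, block_count,
                       num_rejected, block_size):
    """Illustrative scalar sketch of the per-request rewind logic.

    kv_length:         total tokens currently in the KV cache for this request
    last_block_offset: index of the last token within the last block (0-based)
    block_count:       number of KV blocks currently allocated
    num_rejected:      speculative tokens that were not accepted
    """
    kv_length -= num_rejected
    last_block_offset -= num_rejected
    blocks_released = 0
    # A negative offset means the rewind crossed one or more block
    # boundaries: release blocks until the offset is back in range.
    while last_block_offset < 0:
        last_block_offset += block_size
        block_count -= 1
        blocks_released += 1
    return kv_length, last_block_offset, block_count, blocks_released
```

For example, with block_size=4, a request holding 9 tokens (last token at offset 0 of its third block) that rejects 2 speculative tokens rewinds to 7 tokens at offset 2 of its second block, releasing one block.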
- core.inference.text_generation_controllers.mtp_utils_pytorch.verify_speculative_tokens(
- input_tokens,
- output_tokens,
- num_decode_requests,
- num_prefill_requests,
- num_speculative_tokens,
- )
Verify speculative tokens against input tokens and compute acceptance.
Creates an accepted tokens mask where:
- For prefill requests, the token is always accepted.
- For decode requests, the first token (the base token) is always accepted; each subsequent sampled token is compared with the corresponding input token, and consecutive matches are accepted.

The function then finds the index of the last accepted token per request.
Example (assume 1, 2, and 0 speculative tokens are accepted in the first 3 decode requests):

    input_tokens_required   [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ]   # size 11
    output_tokens           [ a6o a7o a8o | b4o b5o b6o | c7o c8o c9o | d3o | e5o ]
    output tokens, shifted  [ d3o a6o a7o | a8o b4o b5o | b6o c7o c8o | c9o | d3o ]
    accepted_tokens_mask    [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ]
    last_one_indices        [ 1 | 5 | 6 | 9 | 10 ]
- Returns:
(last_one_indices, accepted_tokens_mask, input_tokens) where last_one_indices contains the index of the last accepted token per request.
- Return type:
tuple
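The acceptance rule for a single decode request can be sketched as follows. This is a pure-Python illustration with a hypothetical helper name; the real function computes the mask for all requests at once on packed tensors:

```python
def accepted_mask_for_request(input_tokens, output_tokens):
    """Compute the accepted-tokens mask for one decode request.

    input_tokens:  [base_token, spec_token_1, ..., spec_token_k]
    output_tokens: the token sampled at each of those positions
    """
    mask = [1]            # the base token is always accepted
    still_matching = True
    for j in range(1, len(input_tokens)):
        # A speculative token is accepted only if it matches the token
        # sampled at the previous position AND all earlier ones matched,
        # so acceptance stops at the first mismatch.
        still_matching = still_matching and input_tokens[j] == output_tokens[j - 1]
        mask.append(1 if still_matching else 0)
    return mask
```

Mapping onto the example above: for the first decode request, a6s matches a6o but a7s does not match a7o, so the mask is [1, 1, 0].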
- core.inference.text_generation_controllers.mtp_utils_pytorch.prepare_next_forward_pass(
- num_decode_requests,
- output_tokens,
- required_logit_indices,
- last_one_indices,
- accepted_tokens_mask,
- input_tokens,
- sampled_tokens_buf,
- last_accepted_seq_buf,
- accepted_tokens_per_request,
- accepted_token_counts,
- num_speculative_tokens,
- )
Prepare data for the next forward pass after speculative token verification.
For each active request:
Store the final sampled tokens for the next forward pass.
Store the last accepted positions in the packed sequence for serial MTP computation after verification.
For decode requests, extract accepted tokens and counts:

    input_tokens_required   [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ]
    accepted_tokens_mask    [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ]
    accepted tokens         [ [a6s -1] | [b4s b5s] | [-1 -1] ]   # only decode requests (prefill defaults to -1)
    accepted token counts   [ 1 | 2 | 0 ]                        # prefill defaults to 0
Writes results into the pre-allocated buffers provided by the caller.
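The per-request extraction for decode requests can be sketched like this. The helper below is hypothetical and pure Python; the real function writes into the pre-allocated buffers named in the signature above:

```python
def extract_accepted_tokens(input_tokens, accepted_mask, num_speculative_tokens):
    """For one decode request, return (accepted_tokens, count).

    input_tokens:  [base_token, spec_token_1, ..., spec_token_k]
    accepted_mask: mask from verification, e.g. [1, 1, 0]

    Speculative input tokens whose mask bit is 1 are kept; the rest of
    the fixed-size slot is padded with -1.
    """
    accepted = [tok for tok, m in zip(input_tokens[1:], accepted_mask[1:]) if m]
    count = len(accepted)
    accepted += [-1] * (num_speculative_tokens - count)
    return accepted, count
```

This reproduces the example rows: a mask of [1 1 0] yields one accepted speculative token padded with -1, [1 1 1] yields both, and [1 0 0] yields [-1, -1] with a count of 0.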
- core.inference.text_generation_controllers.mtp_utils_pytorch.mamba_state_selective_copy(
- intermediate_states,
- current_states,
- prefill_status,
- state_idx,
- accepted_counts,
- num_layers,
- )
Mamba speculative rewind state update.
For each decode request, copies `intermediate[layer, slot, accepted_count, ...]` → `current[layer, slot, ...]` for every Mamba layer.
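The copy pattern can be illustrated with nested lists standing in for the state tensors. This is a minimal sketch under the assumption that `state_idx[r]` maps request `r` to its state slot; the actual function indexes PyTorch tensors:

```python
def mamba_selective_copy_sketch(intermediate, current, state_idx,
                                accepted_counts, num_layers):
    """For each decode request r living in state slot state_idx[r], select
    the intermediate state recorded after accepted_counts[r] accepted
    tokens and write it into the current state, for every Mamba layer."""
    for layer in range(num_layers):
        for r, slot in enumerate(state_idx):
            current[layer][slot] = intermediate[layer][slot][accepted_counts[r]]
```

Because the rolled-back state is chosen per request by its accepted-token count, a request that rejected all speculative tokens simply keeps the state from its base token, while a request that accepted all of them keeps the most advanced intermediate state.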