# core.inference.text_generation_controllers.mtp_utils_triton

## Module Contents

### Functions
| Function | Description |
| --- | --- |
| `_rewind_kv_cache_kernel` | Rewind KV-cache bookkeeping for one request after speculative verification. |
| `rewind_kv_cache` | Launch the KV-cache rewind Triton kernel. |
| `_verify_speculative_tokens_kernel` | Verify speculative tokens for one request. |
| `verify_speculative_tokens` | Launch the speculative-token verification Triton kernel. |
| `_prepare_next_forward_pass_kernel` | Gather final tokens and extract accepted speculative tokens per request. |
| `prepare_next_forward_pass` | Launch the prepare-next-forward-pass Triton kernel. |
| `_mamba_state_selective_copy_kernel` | Copy intermediate Mamba state to current state for decode requests. |
| `mamba_state_selective_copy` | Copy accepted intermediate Mamba states to current states in-place. |
### API
```
core.inference.text_generation_controllers.mtp_utils_triton._rewind_kv_cache_kernel(
    ACCEPTED_COUNTS_PTR,
    PREFILL_STATUS_PTR,
    LAST_KV_BLOCK_OFFSET_PTR,
    KV_LENGTH_OFFSETS_PTR,
    KV_BLOCK_COUNTS_PTR,
    LAST_KV_BLOCK_ID_PTR,
    KV_BLOCK_IDS_PTR,
    BLOCKS_TO_RELEASE_PTR,
    REMOVE_MASK_PTR,
    kv_block_ids_stride,
    max_blocks_minus_1,
    num_active_requests,
    NUM_SPEC_TOKENS: triton.language.constexpr,
    BLOCK_SIZE_TOKENS: triton.language.constexpr,
)
```
Rewind KV-cache bookkeeping for one request after speculative verification.
Grid: may be padded beyond active requests for CUDA-graph compatibility. Each program handles exactly one request. Programs with `pid >= num_active_requests` are padding and produce safe no-op outputs.
```
core.inference.text_generation_controllers.mtp_utils_triton.rewind_kv_cache(
    accepted_counts,
    prefill_status,
    last_kv_block_offset,
    kv_length_offsets,
    kv_block_counts,
    last_kv_block_id,
    kv_block_ids,
    num_speculative_tokens,
    block_size_tokens,
    num_active_requests=None,
)
```
Launch the KV-cache rewind Triton kernel.
- Parameters:
  - **num_active_requests** – Number of real (non-padding) requests. When the grid is padded beyond this count, the kernel skips padding programs so stale data in padding slots cannot corrupt bookkeeping. Defaults to `accepted_counts.shape[0]` (no padding).
- Returns:
  `(blocks_to_release, remove_mask)` — same semantics as the original torch.compile'd `_rewind_kv_cache` (KV-cache portion only; Mamba state updates are handled separately by the caller).
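The bookkeeping this launcher performs can be sketched per request in plain Python. This is a hypothetical reference, not the kernel: `rewind_kv_cache_ref` and its scalar arguments are illustrative names, and it assumes the cached length still counts every speculative token before the rewind.

```python
def rewind_kv_cache_ref(kv_length, accepted_count, num_spec_tokens, block_size):
    """Illustrative scalar version of the KV-cache rewind for one request.

    Drops the rejected speculative tokens from the length bookkeeping and
    reports how many trailing cache blocks become releasable.
    """
    rejected = num_spec_tokens - accepted_count
    new_length = kv_length - rejected
    # ceil-division block counts before and after the rewind
    old_blocks = -(-kv_length // block_size)
    new_blocks = -(-new_length // block_size)
    blocks_freed = old_blocks - new_blocks
    return new_length, blocks_freed


# e.g. 130 cached tokens, 1 of 3 speculative tokens accepted, 64-token blocks:
# 2 tokens are rewound, and the now-empty third block can be released.
print(rewind_kv_cache_ref(130, 1, 3, 64))
```

In the real launcher this runs as one Triton program per request, with padding programs skipped as described above.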
```
core.inference.text_generation_controllers.mtp_utils_triton._verify_speculative_tokens_kernel(
    INPUT_TOKENS_PTR,
    OUTPUT_TOKENS_PTR,
    ACCEPTED_MASK_PTR,
    LAST_ONE_INDICES_PTR,
    num_decode_requests,
    decode_len,
    STRIDE: triton.language.constexpr,
    BLOCK_SIZE: triton.language.constexpr,
)
```
Verify speculative tokens for one request.
Grid: `(active_request_count,)`. Programs `0..num_decode_requests-1` handle decode requests; programs `num_decode_requests..end` handle prefill requests.
```
core.inference.text_generation_controllers.mtp_utils_triton.verify_speculative_tokens(
    input_tokens,
    output_tokens,
    num_decode_requests,
    num_prefill_requests,
    num_speculative_tokens,
)
```
Launch the speculative-token verification Triton kernel.
- Returns:
  `(last_one_indices, accepted_tokens_mask, input_tokens)` matching the original `_verify_speculative_tokens` signature.
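For one decode request, the verification itself follows the standard greedy speculative-decoding rule: a draft token counts as accepted only if it matches the target model's token and every earlier draft token also matched. A minimal pure-Python sketch (the function name and argument names are illustrative, not the module's API):

```python
def verify_speculative_ref(draft_tokens, target_tokens):
    """Greedy prefix-match verification for a single request.

    draft_tokens:  speculative tokens proposed by the draft head
    target_tokens: tokens the target model sampled at the same positions
    Returns the per-position accepted mask and the accepted count
    (the index of the last accepted position plus one).
    """
    accepted = 0
    for d, t in zip(draft_tokens, target_tokens):
        if d != t:
            break  # first mismatch rejects this and all later positions
        accepted += 1
    mask = [i < accepted for i in range(len(draft_tokens))]
    return mask, accepted


print(verify_speculative_ref([5, 7, 9], [5, 7, 2]))
```

The Triton kernel vectorizes this over positions per request and batches one program per request.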
```
core.inference.text_generation_controllers.mtp_utils_triton._prepare_next_forward_pass_kernel(
    OUTPUT_TOKENS_PTR,
    REQUIRED_LOGIT_INDICES_PTR,
    LAST_ONE_INDICES_PTR,
    INPUT_TOKENS_PTR,
    ACCEPTED_MASK_PTR,
    SAMPLED_TOKENS_OUT_PTR,
    LAST_ACCEPTED_SEQ_OUT_PTR,
    ACCEPTED_TOKENS_OUT_PTR,
    ACCEPTED_COUNTS_OUT_PTR,
    accepted_tokens_out_stride,
    num_decode_requests,
    STRIDE: triton.language.constexpr,
    NUM_SPEC_TOKENS: triton.language.constexpr,
    SPEC_BLOCK_SIZE: triton.language.constexpr,
)
```
Gather final tokens and extract accepted speculative tokens per request.
Grid: (active_request_count,)
```
core.inference.text_generation_controllers.mtp_utils_triton.prepare_next_forward_pass(
    num_decode_requests,
    output_tokens,
    required_logit_indices,
    last_one_indices,
    accepted_tokens_mask,
    input_tokens,
    sampled_tokens_buf,
    last_accepted_seq_buf,
    accepted_tokens_per_request,
    accepted_token_counts,
    num_speculative_tokens,
)
```
Launch the prepare-next-forward-pass Triton kernel.
Writes results into the pre-allocated buffers provided by the caller.
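The per-request gather can be illustrated with a small sketch, assuming the usual speculative-decoding convention that the accepted tokens are kept along with the one correction/bonus token sampled right after them. The function and names here are hypothetical, not this module's API:

```python
def prepare_next_input_ref(output_tokens, accepted_count):
    """Illustrative gather for one decode request.

    output_tokens:  tokens the target model sampled at each verified position
    accepted_count: number of accepted speculative tokens for this request
    Keeps the accepted tokens plus the correction/bonus token that follows,
    and returns the token that seeds the next forward pass.
    """
    kept = output_tokens[: accepted_count + 1]
    next_token = kept[-1]  # last kept token becomes the next-step input
    return kept, next_token


# 2 of 3 speculative tokens accepted: keep positions 0..2, feed token 13 next.
print(prepare_next_input_ref([11, 12, 13, 14], 2))
```

The kernel does this for every request in one launch and writes into the caller's pre-allocated buffers instead of returning new tensors, which keeps the launch CUDA-graph friendly.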
```
core.inference.text_generation_controllers.mtp_utils_triton._mamba_state_selective_copy_kernel(
    SRC_PTR,
    DST_PTR,
    PREFILL_STATUS_PTR,
    STATE_IDX_PTR,
    ACCEPTED_PTR,
    src_stride_layer,
    src_stride_slot,
    src_stride_spec,
    dst_stride_layer,
    dst_stride_slot,
    STATE_SIZE,
    BLOCK_SIZE: triton.language.constexpr,
)
```
Copy intermediate Mamba state to current state for decode requests.
Grid: `(N, L, num_chunks)`
- dim 0: active request index
- dim 1: Mamba layer index
- dim 2: chunk of the flattened state vector
No-op for prefill requests.
```
core.inference.text_generation_controllers.mtp_utils_triton.mamba_state_selective_copy(
    intermediate_states,
    current_states,
    prefill_status,
    state_idx,
    accepted_counts,
    num_layers,
)
```
Copy accepted intermediate Mamba states to current states in-place.
For each decode request, copies `intermediate[layer, slot, accepted_count, ...]` → `current[layer, slot, ...]` for every Mamba layer.

- Parameters:
  - **intermediate_states** – `(L, M, S+1, *state_shape)` — intermediate buffer.
  - **current_states** – `(L, M, *state_shape)` — current state buffer (updated in-place).
  - **prefill_status** – `(N,)` int tensor — 0 for decode, 1 for prefill.
  - **state_idx** – `(N,)` int tensor — Mamba state slot index per request.
  - **accepted_counts** – `(N,)` int tensor — accepted token index per request.
  - **num_layers** – number of Mamba layers (first dim of the state tensors).
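The copy semantics above can be stated as a small pure-Python reference using nested lists in place of tensors (the function name is illustrative; the real launcher operates on strided GPU tensors via the Triton kernel):

```python
def mamba_selective_copy_ref(intermediate, current, prefill_status,
                             state_idx, accepted_counts):
    """Reference semantics: current[layer][slot] = intermediate[layer][slot][k].

    intermediate: [L][M][S+1] list of state entries
    current:      [L][M] list of state entries, mutated in place
    """
    num_layers = len(current)
    for req, is_prefill in enumerate(prefill_status):
        if is_prefill:
            continue  # no-op for prefill requests
        slot = state_idx[req]
        k = accepted_counts[req]  # accepted token index selects the state
        for layer in range(num_layers):
            current[layer][slot] = intermediate[layer][slot][k]
    return current


# One layer, two state slots, up to S+1 = 3 intermediate states per slot.
intermediate = [[["a0", "a1", "a2"], ["b0", "b1", "b2"]]]
current = [[None, None]]
# Request 0 is decode (slot 1, accepted index 2); request 1 is prefill.
print(mamba_selective_copy_ref(intermediate, current, [0, 1], [1, 0], [2, 0]))
```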