core.inference.text_generation_controllers.mtp_utils_triton#

Module Contents#

Functions#

_rewind_kv_cache_kernel

Rewind KV-cache bookkeeping for one request after speculative verification.

rewind_kv_cache

Launch the KV-cache rewind Triton kernel.

_verify_speculative_tokens_kernel

Verify speculative tokens for one request.

verify_speculative_tokens

Launch the speculative-token verification Triton kernel.

_prepare_next_forward_pass_kernel

Gather final tokens and extract accepted speculative tokens per request.

prepare_next_forward_pass

Launch the prepare-next-forward-pass Triton kernel.

_mamba_state_selective_copy_kernel

Copy intermediate Mamba state to current state for decode requests.

mamba_state_selective_copy

Copy accepted intermediate Mamba states to current states in-place.

API#

core.inference.text_generation_controllers.mtp_utils_triton._rewind_kv_cache_kernel(
ACCEPTED_COUNTS_PTR,
PREFILL_STATUS_PTR,
LAST_KV_BLOCK_OFFSET_PTR,
KV_LENGTH_OFFSETS_PTR,
KV_BLOCK_COUNTS_PTR,
LAST_KV_BLOCK_ID_PTR,
KV_BLOCK_IDS_PTR,
BLOCKS_TO_RELEASE_PTR,
REMOVE_MASK_PTR,
kv_block_ids_stride,
max_blocks_minus_1,
num_active_requests,
NUM_SPEC_TOKENS: triton.language.constexpr,
BLOCK_SIZE_TOKENS: triton.language.constexpr,
)#

Rewind KV-cache bookkeeping for one request after speculative verification.

Grid: may be padded beyond active requests for CUDA-graph compatibility. Each program handles exactly one request. Programs with pid >= num_active_requests are padding and produce safe no-op outputs.

core.inference.text_generation_controllers.mtp_utils_triton.rewind_kv_cache(
accepted_counts,
prefill_status,
last_kv_block_offset,
kv_length_offsets,
kv_block_counts,
last_kv_block_id,
kv_block_ids,
num_speculative_tokens,
block_size_tokens,
num_active_requests=None,
)#

Launch the KV-cache rewind Triton kernel.

Parameters:

num_active_requests – Number of real (non-padding) requests. When the grid is padded beyond this count, the kernel skips padding programs so stale data in padding slots cannot corrupt bookkeeping. Defaults to accepted_counts.shape[0] (no padding).

Returns:

(blocks_to_release, remove_mask) — same semantics as the original torch.compile’d _rewind_kv_cache (KV-cache portion only; Mamba state updates are handled separately by the caller).
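The bookkeeping arithmetic is not spelled out above; the following is a plain-Python sketch of one plausible per-request rewind, assuming rejected speculative tokens are trimmed off the cached KV length and fully emptied blocks are released. The function name, arguments, and block size are illustrative, not the module's API.

```python
BLOCK_SIZE_TOKENS = 4  # tokens per KV block (illustrative value)

def rewind_one_request(kv_length, num_spec_tokens, accepted_count):
    """Sketch of per-request KV rewind after speculative verification.

    Drops the speculative tokens that verification rejected, then
    counts how many trailing KV blocks became empty and can be
    released. Hypothetical names; the real kernel updates flat
    bookkeeping tensors instead of returning values.
    """
    rejected = num_spec_tokens - accepted_count
    new_length = kv_length - rejected
    # Ceil-division: number of blocks needed before and after rewind.
    old_blocks = -(-kv_length // BLOCK_SIZE_TOKENS)
    new_blocks = -(-new_length // BLOCK_SIZE_TOKENS)
    blocks_to_release = old_blocks - new_blocks
    last_block_offset = new_length % BLOCK_SIZE_TOKENS
    return new_length, blocks_to_release, last_block_offset
```

For example, a request with 10 cached tokens, 3 speculative tokens, and 1 accepted token rewinds 2 tokens: its KV length drops from 10 to 8, so it shrinks from 3 blocks to 2 and frees one block.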

core.inference.text_generation_controllers.mtp_utils_triton._verify_speculative_tokens_kernel(
INPUT_TOKENS_PTR,
OUTPUT_TOKENS_PTR,
ACCEPTED_MASK_PTR,
LAST_ONE_INDICES_PTR,
num_decode_requests,
decode_len,
STRIDE: triton.language.constexpr,
BLOCK_SIZE: triton.language.constexpr,
)#

Verify speculative tokens for one request.

Grid: (active_request_count,). Programs 0..num_decode_requests-1 handle decode requests; programs num_decode_requests..end handle prefill requests.

core.inference.text_generation_controllers.mtp_utils_triton.verify_speculative_tokens(
input_tokens,
output_tokens,
num_decode_requests,
num_prefill_requests,
num_speculative_tokens,
)#

Launch the speculative-token verification Triton kernel.

Returns:

(last_one_indices, accepted_tokens_mask, input_tokens) matching the original _verify_speculative_tokens signature.
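The accept/reject rule itself is not documented here; a plain-Python sketch of the standard greedy verification rule (a speculative token is accepted iff it matches the target model's token at that position and every earlier token was accepted) is shown below. This rule is an assumption, and all names are hypothetical; the kernel operates on flat token buffers with a per-request stride.

```python
def verify_one_request(proposed, sampled):
    """Sketch of greedy speculative verification for one request.

    proposed: tokens the draft model suggested.
    sampled:  tokens the target model produced at the same positions.
    Returns (accepted_mask, last_one_index); last_one_index is -1
    when no speculative token was accepted.
    """
    mask = []
    ok = True
    for p, s in zip(proposed, sampled):
        ok = ok and (p == s)       # rejection is sticky: once a token
        mask.append(1 if ok else 0)  # mismatches, the rest are rejected
    last_one_index = sum(mask) - 1
    return mask, last_one_index
```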

core.inference.text_generation_controllers.mtp_utils_triton._prepare_next_forward_pass_kernel(
OUTPUT_TOKENS_PTR,
REQUIRED_LOGIT_INDICES_PTR,
LAST_ONE_INDICES_PTR,
INPUT_TOKENS_PTR,
ACCEPTED_MASK_PTR,
SAMPLED_TOKENS_OUT_PTR,
LAST_ACCEPTED_SEQ_OUT_PTR,
ACCEPTED_TOKENS_OUT_PTR,
ACCEPTED_COUNTS_OUT_PTR,
accepted_tokens_out_stride,
num_decode_requests,
STRIDE: triton.language.constexpr,
NUM_SPEC_TOKENS: triton.language.constexpr,
SPEC_BLOCK_SIZE: triton.language.constexpr,
)#

Gather final tokens and extract accepted speculative tokens per request.

Grid: (active_request_count,)

core.inference.text_generation_controllers.mtp_utils_triton.prepare_next_forward_pass(
num_decode_requests,
output_tokens,
required_logit_indices,
last_one_indices,
accepted_tokens_mask,
input_tokens,
sampled_tokens_buf,
last_accepted_seq_buf,
accepted_tokens_per_request,
accepted_token_counts,
num_speculative_tokens,
)#

Launch the prepare-next-forward-pass Triton kernel.

Writes results into the pre-allocated buffers provided by the caller.
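As a rough illustration of the per-request gather, under the assumption that the tokens fed to the next forward pass are the accepted speculative tokens plus the one bonus token sampled at the last accepted position (names and buffer layout are hypothetical):

```python
def prepare_one_request(output_tokens, accepted_mask):
    """Sketch of the per-request gather for the next forward pass.

    output_tokens: tokens sampled at each verified position.
    accepted_mask: 1/0 mask from verification (see above).
    Returns the accepted tokens plus one bonus token, and the
    accepted count. Assumed semantics, not the module's API.
    """
    accepted_count = sum(accepted_mask)
    # accepted speculative tokens plus the bonus token sampled
    # right after the last accepted position
    tokens_for_next_pass = output_tokens[:accepted_count + 1]
    return tokens_for_next_pass, accepted_count
```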

core.inference.text_generation_controllers.mtp_utils_triton._mamba_state_selective_copy_kernel(
SRC_PTR,
DST_PTR,
PREFILL_STATUS_PTR,
STATE_IDX_PTR,
ACCEPTED_PTR,
src_stride_layer,
src_stride_slot,
src_stride_spec,
dst_stride_layer,
dst_stride_slot,
STATE_SIZE,
BLOCK_SIZE: triton.language.constexpr,
)#

Copy intermediate Mamba state to current state for decode requests.

Grid: (N, L, num_chunks)

  • dim 0: active request index

  • dim 1: mamba layer index

  • dim 2: chunk of the flattened state vector

No-op for prefill requests.

core.inference.text_generation_controllers.mtp_utils_triton.mamba_state_selective_copy(
intermediate_states,
current_states,
prefill_status,
state_idx,
accepted_counts,
num_layers,
)#

Copy accepted intermediate Mamba states to current states in-place.

For each decode request, copies intermediate[layer, slot, accepted_count, ...] → current[layer, slot, ...] for every Mamba layer.

Parameters:
  • intermediate_states – (L, M, S+1, *state_shape) intermediate buffer.

  • current_states – (L, M, *state_shape) current state buffer (updated in-place).

  • prefill_status – (N,) int tensor; 0 for decode, 1 for prefill.

  • state_idx – (N,) int tensor; Mamba state slot index per request.

  • accepted_counts – (N,) int tensor; accepted token index per request.

  • num_layers – number of Mamba layers (first dim of the state tensors).
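The copy semantics stated above can be expressed as a plain-Python reference over nested lists (a stand-in for the real tensors; the function name is hypothetical):

```python
def mamba_state_selective_copy_ref(intermediate, current, prefill_status,
                                   state_idx, accepted_counts):
    """Reference for the selective-copy semantics.

    For each decode request (prefill_status == 0), copies
    intermediate[layer][slot][accepted_count] into
    current[layer][slot] for every layer. Prefill requests are
    skipped, matching the kernel's no-op behaviour.
    """
    num_layers = len(intermediate)
    for req, status in enumerate(prefill_status):
        if status != 0:              # prefill request: no-op
            continue
        slot = state_idx[req]
        k = accepted_counts[req]
        for layer in range(num_layers):
            current[layer][slot] = list(intermediate[layer][slot][k])
```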