core.inference.sampling.base#

Module Contents#

Classes#

Sampling

Abstract base for inference sampling backends.

API#

class core.inference.sampling.base.Sampling#

Bases: abc.ABC

Abstract base for inference sampling backends.

Subclasses implement sample_kernel. CUDA graphs are added via CudaGraphManager.
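As an illustration, a minimal backend could subclass Sampling as in the sketch below. The GreedySampling name and its argmax body are hypothetical, not part of this module; a real backend would honor the per-request sampling parameters held by the context.

```python
import torch
from typing import Any, Optional

from core.inference.sampling.base import Sampling


class GreedySampling(Sampling):
    """Hypothetical backend that always takes the argmax token."""

    def sample_kernel(
        self,
        logits: torch.Tensor,
        n: int,
        context,
        *,
        gather_indices: Optional[torch.Tensor] = None,
        token_to_request_index: Optional[torch.Tensor] = None,
        eager: bool = False,
        cache_key: Any = None,
    ) -> torch.Tensor:
        # Restrict to the requested rows, honoring gather_indices if given.
        rows = logits[gather_indices[:n]] if gather_indices is not None else logits[:n]
        # Greedy pick; a real backend would apply temperature/top-k/top-p
        # taken from `context` (per-token when token_to_request_index is set).
        return torch.argmax(rows, dim=-1)
```

The `eager` and `cache_key` arguments are simply accepted and passed through; they only matter once CudaGraphManager wraps the kernel.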

abstractmethod sample_kernel(
logits: torch.Tensor,
n: int,
context,
*,
gather_indices: Optional[torch.Tensor] = None,
token_to_request_index: Optional[torch.Tensor] = None,
eager: bool = False,
cache_key: Any = None,
) → torch.Tensor#

Sample n tokens from logits and return them.

Parameters:
  • logits – Logits tensor of shape [>=n, vocab_size].

  • n – Number of rows to sample.

  • context – The active DynamicInferenceContext.

  • gather_indices – If provided, only sample from logits[gather_indices[:n], :].

  • token_to_request_index – Per-token request mapping; when set, sampling parameters are gathered per-token instead of per-request.

  • eager – Consumed by CudaGraphManager when it wraps this kernel.

  • cache_key – Consumed by CudaGraphManager when it wraps this kernel.

Returns:

Sampled token ids of shape [n]. Under CUDA graph replay, this is a static buffer.
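For example, using the hypothetical GreedySampling sketch above (which ignores context, so None stands in for a DynamicInferenceContext), gather_indices restricts sampling to selected rows:

```python
import torch

logits = torch.randn(8, 32000)            # [>=n, vocab_size]
gather_indices = torch.tensor([0, 2, 5])  # sample only these rows

sampler = GreedySampling()
tokens = sampler.sample_kernel(
    logits, n=3, context=None, gather_indices=gather_indices
)
assert tokens.shape == (3,)
```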

sample_speculative(
required_logits: torch.Tensor,
num_decode: int,
num_prefill: int,
num_speculative_tokens: int,
context,
*,
gather_indices: Optional[torch.Tensor] = None,
eager: bool = False,
cache_key: Any = None,
) → torch.Tensor#

Sample tokens for the speculative-verify path.

Decode requests contribute 1 + num_speculative_tokens rows; prefill requests contribute 1. Builds the per-token request mapping and dispatches to sample_kernel with eager=True, so the kernel's own CudaGraphManager wrapper does not fire.

When gather_indices is supplied, the kernel selects rows via logits[gather_indices[:n], :]. When gather_indices is None, required_logits must already be gathered into the layout described above (e.g. when materialize_only_last_token_logits=True upstream).
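The row layout and the derived mapping can be sketched as follows. This assumes decode rows precede prefill rows; the actual ordering is an implementation detail of the verify path, so treat the snippet as illustrative only.

```python
import torch

num_decode, num_prefill, num_speculative_tokens = 2, 3, 2

# Each decode request owns 1 + num_speculative_tokens consecutive rows;
# each prefill request owns a single row after them.
decode_map = torch.arange(num_decode).repeat_interleave(1 + num_speculative_tokens)
prefill_map = num_decode + torch.arange(num_prefill)

token_to_request_index = torch.cat([decode_map, prefill_map])
print(token_to_request_index)  # tensor([0, 0, 0, 1, 1, 1, 2, 3, 4])
```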