core.inference.sampling.flashinfer_sampling#

Module Contents#

Classes#

FlashInferSampling

Fused FlashInfer sampling, with optional CUDA graph capture/replay.

API#

class core.inference.sampling.flashinfer_sampling.FlashInferSampling(
vocab_size: int,
rng: torch.Generator,
config=None,
enable_cuda_graph: bool = False,
)#

Bases: megatron.core.inference.sampling.base.Sampling

Fused FlashInfer sampling, with optional CUDA graph capture/replay.

Initialization
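
A minimal construction sketch, assuming the class resolves under the `megatron.core` namespace (matching the `Bases` path above) and a CUDA device is available; the vocabulary size and seed are illustrative:

```python
import torch

from megatron.core.inference.sampling.flashinfer_sampling import FlashInferSampling

# Seeded generator on the device where sampling runs.
rng = torch.Generator(device="cuda").manual_seed(0)

sampler = FlashInferSampling(
    vocab_size=32_000,        # illustrative vocabulary size
    rng=rng,
    enable_cuda_graph=False,  # eager execution; no graph capture/replay
)
```

Passing `enable_cuda_graph=True` instead opts the kernel into CUDA graph capture/replay, in which case the outputs come from static buffers (see the return note under `sample_kernel` below).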

sample_kernel(
logits: torch.Tensor,
n: int,
context,
*,
gather_indices: Optional[torch.Tensor] = None,
token_to_request_index: Optional[torch.Tensor] = None,
eager: bool = False,
cache_key: Any = None,
) → torch.Tensor#

FlashInfer fused top-k / top-p sampling kernel.

Parameters:
  • logits – Logits tensor of shape [>=n, vocab_size].

  • n – Number of rows to sample.

  • context – The active DynamicInferenceContext.

  • gather_indices – When set, sample from logits[gather_indices[:n], :].

  • token_to_request_index – When set, sampling parameters are gathered per token rather than per request (used by the speculative decoding path).

  • eager – Consumed by CudaGraphManager when it wraps this kernel.

  • cache_key – Consumed by CudaGraphManager when it wraps this kernel.

Returns:

Sampled token ids of shape [n]. Under CUDA graph replay, this is a static buffer that is reused across replays, so copy it out if the values must persist.
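
A hedged usage sketch, assuming `sampler` was constructed as in the snippet above and that `ctx` is an already-active `DynamicInferenceContext` supplied by the inference engine (its setup is outside this page's scope); the index layout in the speculative branch is likewise an assumption for illustration:

```python
import torch

n = 4  # number of rows to sample
logits = torch.randn(n, 32_000, device="cuda")  # shape [>= n, vocab_size]

# Decode path: one logits row per request.
tokens = sampler.sample_kernel(logits, n, ctx)
assert tokens.shape == (n,)

# Speculative path (assumed layout): select rows explicitly and gather
# sampling parameters per token via token_to_request_index.
gather = torch.arange(n, device="cuda")              # rows of `logits` to sample
tok2req = torch.tensor([0, 0, 1, 1], device="cuda")  # token -> owning request
spec_tokens = sampler.sample_kernel(
    logits,
    n,
    ctx,
    gather_indices=gather,
    token_to_request_index=tok2req,
)

# Under CUDA graph replay the result is a static buffer; clone it if the
# values must outlive the next replay.
spec_tokens = spec_tokens.clone()
```

`eager` and `cache_key` are left at their defaults here; per the parameter list above, they are consumed by `CudaGraphManager` only when it wraps this kernel.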