core.inference.sampling.flashinfer_sampling

Module Contents

Classes
FlashInferSampling – Fused FlashInfer sampling, with optional CUDA graph capture/replay.
API
class core.inference.sampling.flashinfer_sampling.FlashInferSampling(
    vocab_size: int,
    rng: torch.Generator,
    config=None,
    enable_cuda_graph: bool = False,
)

Bases: megatron.core.inference.sampling.base.Sampling

Fused FlashInfer sampling, with optional CUDA graph capture/replay.

Initialization
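A minimal construction sketch based on the signature above, assuming a CUDA device is available; the seed and vocabulary size are illustrative:

```python
import torch

from megatron.core.inference.sampling.flashinfer_sampling import FlashInferSampling

# Generator that seeds the sampling kernel; the seed value is illustrative.
rng = torch.Generator(device="cuda").manual_seed(1234)

sampler = FlashInferSampling(
    vocab_size=128_000,       # illustrative vocabulary size
    rng=rng,
    enable_cuda_graph=False,  # True enables CUDA graph capture/replay
)
```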
sample_kernel(
    logits: torch.Tensor,
    n: int,
    context,
    *,
    gather_indices: Optional[torch.Tensor] = None,
    token_to_request_index: Optional[torch.Tensor] = None,
    eager: bool = False,
    cache_key: Any = None,
)

FlashInfer fused top-k / top-p sampling kernel.
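For intuition, here is a minimal eager PyTorch sketch of the joint top-k / top-p sampling semantics that a fused kernel computes in a single launch. This is a reference illustration only, not the FlashInfer kernel; the helper name is ours:

```python
import torch

def topk_topp_sample(logits: torch.Tensor, top_k: int, top_p: float,
                     rng: torch.Generator) -> torch.Tensor:
    """Reference top-k / top-p sampling over [n, vocab_size] logits."""
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)

    # Top-k: keep only the k highest-probability tokens per row.
    if top_k > 0:
        sorted_probs[:, top_k:] = 0.0

    # Top-p: drop tokens whose cumulative mass before them already exceeds
    # top_p; the single most likely token is always kept.
    cumulative = sorted_probs.cumsum(dim=-1)
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0

    # Renormalize, sample in sorted space, and map back to vocabulary ids.
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1, generator=rng)
    return sorted_idx.gather(-1, choice).squeeze(-1)

# Usage with toy shapes:
ids = topk_topp_sample(torch.randn(4, 32), top_k=8, top_p=0.9,
                       rng=torch.Generator().manual_seed(0))
```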
Parameters:
- logits – Logits tensor of shape [>=n, vocab_size].
- n – Number of rows to sample.
- context – The active DynamicInferenceContext.
- gather_indices – When set, sample from logits[gather_indices[:n], :].
- token_to_request_index – When set, sampling parameters are gathered per-token rather than per-request (used by the speculative path).
- eager – Consumed by CudaGraphManager when it wraps this kernel.
- cache_key – Consumed by CudaGraphManager when it wraps this kernel.
Returns:
Sampled token ids of shape [n]. Under CUDA graph replay, this is a static buffer.
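A hypothetical call sketch tying the parameters together; `sampler` is the instance constructed earlier, and `ctx` stands in for the active DynamicInferenceContext, whose construction is out of scope here:

```python
import torch

logits = torch.randn(8, 128_000, device="cuda")  # [>=n, vocab_size]
gather = torch.arange(4, device="cuda")          # rows of `logits` to sample

tokens = sampler.sample_kernel(
    logits,
    n=4,
    context=ctx,             # active DynamicInferenceContext (assumed in scope)
    gather_indices=gather,   # sample from logits[gather[:4], :]
)
assert tokens.shape == (4,)

# Under CUDA graph replay the result may alias a static buffer, so copy it
# if the ids must survive the next replay.
tokens = tokens.clone()
```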