core.inference.sampling.torch_sampling#
Module Contents#
Classes#
- TorchSampling – Sampling via bucketed torch.multinomial.
API#
- class core.inference.sampling.torch_sampling.TorchSampling(rng: torch.Generator, vocab_size: int)#
Bases: megatron.core.inference.sampling.base.Sampling

Sampling via bucketed torch.multinomial. Groups requests into unique buckets by (temperature, top_k, top_p) for separate launches.

Initialization
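To illustrate the grouping, here is a minimal sketch of the bucketing idea, assuming hypothetical per-request records with `temperature`, `top_k`, and `top_p` fields (the real class reads these from its requests/context, whose layout may differ):

```python
from collections import defaultdict

# Hypothetical request records; field names are illustrative only.
requests = [
    {"idx": 0, "temperature": 1.0, "top_k": 0, "top_p": 0.9},
    {"idx": 1, "temperature": 1.0, "top_k": 0, "top_p": 0.9},
    {"idx": 2, "temperature": 0.7, "top_k": 50, "top_p": 0.0},
]

# Group request indices by their (temperature, top_k, top_p) triple;
# each unique bucket then gets its own torch.multinomial launch.
buckets = defaultdict(list)
for r in requests:
    buckets[(r["temperature"], r["top_k"], r["top_p"])].append(r["idx"])

print(dict(buckets))
# {(1.0, 0, 0.9): [0, 1], (0.7, 50, 0.0): [2]}
```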
- static sample_from_logits(last_token_logits: torch.Tensor, temperature: float, top_k: int, top_p: float, *, generator: torch.Generator, vocab_size: Optional[int] = None)#
Sample tokens from logits with temperature, top-k, and top-p filtering.
Shared between dynamic batching and static batching.
- Parameters:
  - last_token_logits – Logits of shape [batch_size, vocab_size].
  - temperature – Temperature scaling factor.
  - top_k – Top-k filtering value (0 = disabled).
  - top_p – Top-p (nucleus) filtering value (0.0 = disabled).
  - generator – RNG used by torch.multinomial.
  - vocab_size – When provided, asserts top_k < vocab_size and clamps the sampled ids to [0, vocab_size - 1].
- Returns:
  Sampled token ids of shape [batch_size].
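As a reference for these semantics, here is a minimal sketch of temperature, top-k, and top-p filtering followed by `torch.multinomial`, assuming temperature > 0; the helper name `sample_from_logits_sketch` is ours, not the module's actual implementation:

```python
from typing import Optional

import torch


def sample_from_logits_sketch(
    last_token_logits: torch.Tensor,  # [batch_size, vocab_size]
    temperature: float,
    top_k: int,
    top_p: float,
    *,
    generator: torch.Generator,
    vocab_size: Optional[int] = None,
) -> torch.Tensor:
    logits = last_token_logits.float() / temperature  # temperature scaling
    if top_k > 0:
        # Keep only the k largest logits in each row.
        kth = torch.topk(logits, top_k, dim=-1).values[:, -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p > 0.0:
        # Nucleus filtering: drop tokens outside the smallest prefix of the
        # sorted distribution whose cumulative probability exceeds top_p.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[:, 1:] = remove[:, :-1].clone()  # shift so the boundary token stays
        remove[:, 0] = False  # always keep the most likely token
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits = logits.scatter(-1, sorted_idx, sorted_logits)
    probs = torch.softmax(logits, dim=-1)
    samples = torch.multinomial(probs, num_samples=1, generator=generator).squeeze(-1)
    if vocab_size is not None:
        assert top_k < vocab_size
        samples = samples.clamp(0, vocab_size - 1)  # mirror the documented clamp
    return samples
```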
- sample_kernel(logits: torch.Tensor, n: int, context, *, gather_indices: Optional[torch.Tensor] = None, token_to_request_index: Optional[torch.Tensor] = None, eager: bool = False, cache_key: Any = None)#
Bucket active requests by (temperature, top_k, top_p) and sample each bucket.

- Parameters:
  - logits – Logits tensor of shape [>=n, vocab_size].
  - n – Number of rows to sample.
  - context – The active DynamicInferenceContext.
  - gather_indices – When set, sample from logits[gather_indices[:n], :].
  - token_to_request_index – When set, the loop dispatches per-token rather than per-request (used by the speculative path).
  - eager – Accepted for API symmetry; ignored (TorchSampling has no graph wrapper).
  - cache_key – Accepted for API symmetry; ignored.
- Returns:
  Sampled token ids of shape [n].
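For illustration, a sketch of a per-bucket driver loop in the spirit of sample_kernel, reusing the `sample_from_logits_sketch` helper above; the explicit `params_per_row` list stands in for the per-request parameters that the real kernel reads from the DynamicInferenceContext, and the gather_indices and speculative per-token paths are omitted:

```python
import torch

def sample_kernel_sketch(logits, n, params_per_row, *, generator):
    # params_per_row[i] is the (temperature, top_k, top_p) triple for row i.
    out = torch.empty(n, dtype=torch.long, device=logits.device)
    buckets = {}
    for row, params in enumerate(params_per_row[:n]):
        buckets.setdefault(params, []).append(row)
    # One sampling launch per unique (temperature, top_k, top_p) bucket.
    for (temperature, top_k, top_p), rows in buckets.items():
        idx = torch.tensor(rows, dtype=torch.long, device=logits.device)
        out[idx] = sample_from_logits_sketch(
            logits[idx], temperature, top_k, top_p, generator=generator
        )
    return out

# Example usage:
gen = torch.Generator().manual_seed(0)
logits = torch.randn(3, 128)
params = [(1.0, 0, 0.9), (1.0, 0, 0.9), (0.7, 50, 0.0)]
tokens = sample_kernel_sketch(logits, n=3, params_per_row=params, generator=gen)
print(tokens.shape)  # torch.Size([3])
```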