core.inference.sampling.torch_sampling#

Module Contents#

Classes#

TorchSampling

Sampling via bucketed torch.multinomial.

API#

class core.inference.sampling.torch_sampling.TorchSampling(rng: torch.Generator, vocab_size: int)#

Bases: megatron.core.inference.sampling.base.Sampling

Sampling via bucketed torch.multinomial.

Groups requests into buckets by unique (temperature, top_k, top_p) triples, launching one sampling call per bucket.

Initialization
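
A minimal construction sketch. It assumes the class is importable as megatron.core.inference.sampling.torch_sampling.TorchSampling (consistent with the Bases line above); the seed and vocabulary size are illustrative, and in a real deployment the generator would typically live on the model's CUDA device:

```python
import torch

from megatron.core.inference.sampling.torch_sampling import TorchSampling

# Seeded generator so repeated runs sample the same tokens; a CUDA
# generator would be used when the logits live on the GPU.
rng = torch.Generator().manual_seed(1234)

sampler = TorchSampling(rng=rng, vocab_size=32000)
```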

static sample_from_logits(
    last_token_logits: torch.Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
    *,
    generator: torch.Generator,
    vocab_size: Optional[int] = None,
) → torch.Tensor#

Sample tokens from logits with temperature, top-k, and top-p filtering.

Shared between dynamic batching and static batching.

Parameters:
  • last_token_logits – Logits of shape [batch_size, vocab_size].

  • temperature – Temperature scaling factor.

  • top_k – Top-k filtering value (0 = disabled).

  • top_p – Top-p (nucleus) filtering value (0.0 = disabled).

  • generator – RNG used by torch.multinomial.

  • vocab_size – When provided, asserts top_k < vocab_size and clamps the sampled ids to [0, vocab_size - 1].

Returns:

Sampled token ids of shape [batch_size].
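
A usage sketch under the same import assumption. Shapes and parameter values are illustrative, and tensors are kept on CPU for portability; in practice they live on the model's GPU:

```python
import torch

from megatron.core.inference.sampling.torch_sampling import TorchSampling

batch_size, vocab_size = 4, 32000
last_token_logits = torch.randn(batch_size, vocab_size)
generator = torch.Generator().manual_seed(0)

token_ids = TorchSampling.sample_from_logits(
    last_token_logits,
    temperature=0.8,  # < 1.0 sharpens the distribution before sampling
    top_k=50,         # keep only the 50 highest-probability tokens (0 = off)
    top_p=0.95,       # keep the smallest set with cumulative prob >= 0.95 (0.0 = off)
    generator=generator,
    vocab_size=vocab_size,  # optional: asserts top_k < vocab_size, clamps ids
)

assert token_ids.shape == (batch_size,)
```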

sample_kernel(
    logits: torch.Tensor,
    n: int,
    context,
    *,
    gather_indices: Optional[torch.Tensor] = None,
    token_to_request_index: Optional[torch.Tensor] = None,
    eager: bool = False,
    cache_key: Any = None,
) → torch.Tensor#

Bucket active requests by (temperature, top_k, top_p) and sample each bucket.

Parameters:
  • logits – Logits tensor of shape [>=n, vocab_size].

  • n – Number of rows to sample.

  • context – The active DynamicInferenceContext.

  • gather_indices – When set, sample from logits[gather_indices[:n], :].

  • token_to_request_index – When set, the loop dispatches per-token rather than per-request (used by the speculative path).

  • eager – Accepted for API symmetry; ignored (TorchSampling has no graph wrapper).

  • cache_key – Accepted for API symmetry; ignored.

Returns:

Sampled token ids of shape [n].
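
To make the bucketing behavior concrete, here is a rough, self-contained sketch of what the per-bucket loop computes. It is not the actual implementation, which reads per-request sampling parameters from the DynamicInferenceContext and handles the gather and speculative paths; the hypothetical helper below takes explicit parameter lists instead:

```python
import torch
from collections import defaultdict

from megatron.core.inference.sampling.torch_sampling import TorchSampling

def bucketed_sample(logits, temperatures, top_ks, top_ps, generator):
    """Sample row i of `logits` with (temperatures[i], top_ks[i], top_ps[i]),
    grouping rows that share a parameter triple into one launch."""
    n = logits.shape[0]
    out = torch.empty(n, dtype=torch.long, device=logits.device)
    buckets = defaultdict(list)
    for i in range(n):
        buckets[(temperatures[i], top_ks[i], top_ps[i])].append(i)
    # One sample_from_logits call (and thus one torch.multinomial launch)
    # per bucket covers every row that shares the parameter triple.
    for (temp, k, p), rows in buckets.items():
        out[rows] = TorchSampling.sample_from_logits(
            logits[rows], temp, k, p, generator=generator
        )
    return out

# Three requests, two unique parameter triples -> two sampling launches.
logits = torch.randn(3, 32000)
tokens = bucketed_sample(
    logits,
    temperatures=[1.0, 0.7, 1.0],
    top_ks=[0, 50, 0],
    top_ps=[0.0, 0.9, 0.0],
    generator=torch.Generator().manual_seed(0),
)
assert tokens.shape == (3,)
```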