core.inference.text_generation_controllers.text_generation_controller#
Module Contents#
Classes#
TextGenerationController | The text generation controller (the main sampling loop)
API#
- class core.inference.text_generation_controllers.text_generation_controller.TextGenerationController(
- inference_wrapped_model: megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper.AbstractModelInferenceWrapper,
- tokenizer,
The text generation controller (the main sampling loop)
This class tokenizes the input, runs inference, samples from logits, and detokenizes the output.
- Parameters:
inference_wrapped_model (AbstractModelInferenceWrapper) – A model that is wrapped using the specs given in the abstract_model_inference_wrapper.py
tokenizer (type) – Tokenizer used for tokenizing and detokenizing the prompts
Initialization
- _get_mtp_num_heads() int#
Get the number of MTP layers from the model config.
- set_stop_word_finished_ids_callback(callback)#
Set a callback to get request IDs that should be marked as finished due to stop words.
The callback should have the signature callback(active_request_ids: List[int]) -> Set[int]. It returns the set of request IDs from active_request_ids that should be marked as finished.
- Parameters:
callback – Function that returns request IDs to mark as finished.
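A minimal sketch of a callback with the documented signature. The make_stop_word_callback helper and the generated_text mapping are hypothetical and not part of the library; only the callback signature comes from the description above.

```python
from typing import Callable, Dict, List, Set


def make_stop_word_callback(
    generated_text: Dict[int, str], stop_words: List[str]
) -> Callable[[List[int]], Set[int]]:
    """Build a callback matching callback(active_request_ids) -> Set[int].

    `generated_text` is a hypothetical mapping from request ID to the text
    generated so far; how that text is tracked is up to the caller.
    """

    def callback(active_request_ids: List[int]) -> Set[int]:
        finished = set()
        for request_id in active_request_ids:
            text = generated_text.get(request_id, "")
            # Mark the request as finished if its text ends with any stop word.
            if any(text.endswith(word) for word in stop_words):
                finished.add(request_id)
        return finished

    return callback


callback = make_stop_word_callback({0: "Hello.\n\n", 1: "Hello"}, ["\n\n"])
finished_ids = callback([0, 1, 2])
```

Such a callback would then be registered with set_stop_word_finished_ids_callback(callback) on a controller instance.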
- _init_dynamic_sampling_tensors()#
Initialize tensors needed for dynamic sampling.
- _init_mtp_sampling_tensor()#
Initialize the MTP sampling tensor after num_speculative_tokens is set.
- static tokenize_prompt(
- tokenizer,
- prompt: str,
- add_BOS: bool = False,
Utility to tokenize the input prompts.
- Parameters:
tokenizer – The tokenizer to use.
prompt (str) – The input prompt.
add_BOS (bool) – Whether to add a BOS token.
- Returns:
Returns the tokenized prompt.
- Return type:
List[int]
- static detokenize(
- tokenizer,
- tokens: List[int],
- remove_EOD: bool = True,
- skip_special_tokens: bool = True,
Detokenize a sequence of token IDs, optionally removing trailing EOD tokens and handling skip_special_tokens for different tokenizer APIs.
- Parameters:
tokenizer – The tokenizer to use for detokenization.
tokens (List[int]) – The token IDs to convert back to text.
remove_EOD (bool) – Whether to remove trailing EOD tokens before detokenization. Defaults to True.
skip_special_tokens (bool) – Whether to remove special tokens (e.g. BOS/EOS) during detokenization. Only passed through if the tokenizer supports it.
- Returns:
The detokenized string.
- Return type:
str
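The two behaviors described above (dropping trailing EOD tokens, and passing skip_special_tokens only to tokenizers that accept it) can be sketched as follows. This is an illustration under stated assumptions, not the library's code: the eod_id argument, the signature inspection, and ToyTokenizer are all hypothetical.

```python
import inspect
from typing import List


def strip_trailing_eod(tokens: List[int], eod_id: int) -> List[int]:
    """Drop EOD tokens from the end of the sequence only."""
    end = len(tokens)
    while end > 0 and tokens[end - 1] == eod_id:
        end -= 1
    return tokens[:end]


def detokenize(tokenizer, tokens: List[int], eod_id: int, skip_special_tokens: bool = True) -> str:
    tokens = strip_trailing_eod(tokens, eod_id)
    # Forward skip_special_tokens only when the tokenizer's method accepts it.
    if "skip_special_tokens" in inspect.signature(tokenizer.detokenize).parameters:
        return tokenizer.detokenize(tokens, skip_special_tokens=skip_special_tokens)
    return tokenizer.detokenize(tokens)


class ToyTokenizer:
    """Hypothetical tokenizer whose detokenize() takes no extra flags."""

    def detokenize(self, tokens: List[int]) -> str:
        return " ".join(str(t) for t in tokens)


text = detokenize(ToyTokenizer(), [5, 0, 7, 0, 0], eod_id=0)
```

Note that only the trailing EOD tokens are removed; the EOD token in the middle of the sequence is kept.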
- detokenize_generations(
- tokens_gpu_tensor: torch.Tensor,
- lengths_gpu_tensor: torch.Tensor,
- detokenize_segments: bool,
- skip_special_tokens: bool = True,
Detokenize the generated tokens.
- Parameters:
tokens_gpu_tensor (torch.Tensor) – Tensor containing the tokens
lengths_gpu_tensor (torch.Tensor) – Tensor containing the lengths of each sequence
detokenize_segments (bool) – If True, returns individually detokenized tokens. If False, returns None as second element. Helpful for understanding per-token boundaries in generated text.
skip_special_tokens (bool) – If True, removes special tokens like BOS during detokenization.
- Returns:
A tuple containing:
str: The complete detokenized text
List[str] | None: List of segmented tokens if detokenize_segments is True, else None
- Return type:
tuple[str, List[str] | None]
- _torch_sampling_func(
- last_token_logits: torch.Tensor,
- temperature: float,
- top_k: int,
- top_p: float,
- vocab_size: Optional[int] = None,
Samples the logits to generate outputs
Given the logits of the last token, this function samples it according to the parameters defined in sampling_params and returns the samples. If sampling parameters top_n_logprobs > 0 at each step it also updates the top_n_logprobs dict.
- Parameters:
last_token_logits (torch.Tensor) – The last token logits. A tensor of size [batch_size, vocab_size].
temperature (float) – The temperature to use for sampling.
top_k (int) – The top-k value to use for sampling.
top_p (float) – The top-p value to use for sampling.
vocab_size (int) – Obtained from the tokenizer. Defaults to None.
- Returns:
sampled_logits (torch.Tensor): 1D tensor with [batch_size] elements
- Return type:
torch.Tensor
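To illustrate how temperature, top_k, and top_p interact, here is a pure-Python sketch for a single row of logits. It is not the library's torch implementation: the function names are hypothetical, and treating top_k == 0 and top_p == 0.0 as "disabled" is an assumption made for this example.

```python
import math
import random
from typing import List


def filtered_probs(logits: List[float], temperature: float, top_k: int, top_p: float) -> List[float]:
    """Return the renormalized distribution after temperature, top-k and top-p."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Token indices ordered from most to least likely.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]
    if top_p > 0.0:
        kept, cumulative = [], 0.0
        for i in order:
            kept.append(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        order = kept

    mass = sum(probs[i] for i in order)
    keep = set(order)
    return [probs[i] / mass if i in keep else 0.0 for i in range(len(probs))]


def sample(logits: List[float], temperature: float, top_k: int, top_p: float, rng: random.Random) -> int:
    probs = filtered_probs(logits, temperature, top_k, top_p)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


probs = filtered_probs([2.0, 1.0, 0.0, -1.0], temperature=1.0, top_k=2, top_p=0.0)
token = sample([2.0, 1.0, 0.0, -1.0], 1.0, top_k=1, top_p=0.0, rng=random.Random(0))
```

With top_k=1 the sampler always returns the most likely token, which makes the filtering step easy to check in isolation.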
- sample_from_logits(
- last_token_logits: torch.Tensor,
- sampling_params: Optional[megatron.core.inference.sampling_params.SamplingParams] = None,
- vocab_size: Optional[int] = None,
- generation_started: Optional[torch.Tensor] = None,
- top_n_logprobs_dict: Dict[int, List[Dict[str, float]]] = None,
- logits: Optional[torch.Tensor] = None,
- **kwargs,
Samples the logits to generate outputs
Given the logits of the last token, this function samples it according to the parameters defined in sampling_params and returns the samples. If sampling parameters top_n_logprobs > 0 at each step it also updates the top_n_logprobs dict.
- Parameters:
last_token_logits (torch.Tensor) – The last token logits. A tensor of size [batch_size, vocab_size]
sampling_params (SamplingParams) – The parameters to use for inference.
vocab_size (int) – Obtained from the tokenizer. Defaults to None
generation_started (torch.Tensor) – A boolean tensor of shape [batch_size]. True indicates the prompt at that index has started generating tokens.
top_n_logprobs_dict (Dict[int, List[Dict[str, float]]]) – The dict to be updated
- Returns:
sampled_logits (torch.Tensor): 1D tensor with [batch_size] elements.
top_n_logprobs_this_step (torch.return_types.topk): A topk tensor with values as logits and indices as the top k elements. None if sampling params top_n_logprobs is 0.
- update_generation_status(
- updated_prompts_tokens: torch.Tensor,
- generation_started: torch.Tensor,
- current_context_end_position: int,
- is_generation_done_tensor: torch.Tensor,
- generated_sequence_lengths: torch.Tensor,
- termination_id: Optional[int] = None,
Checks which prompts have reached an end condition
We check which prompts have reached an end condition and set the corresponding flags of the is_generation_done_tensor to True. The generated sequence length of a prompt increases as we keep generating, until that prompt hits an end condition. The generation_started tensor determines which prompts have started generating.
- Parameters:
updated_prompts_tokens (torch.Tensor) – The prompts tokens updated with the latest generated tokens. A tensor of shape [batch_size, max_seq_len] (i.e., max_seq_len = max_prompt_len + tokens_to_generate)
generation_started (torch.Tensor) – A boolean tensor of shape [batch_size]. True indicates the prompt at that index has started generating tokens.
current_context_end_position (int) – An integer indicating which position to extract from the prompts tokens to get the latest generated tokens.
is_generation_done_tensor (torch.Tensor) – A boolean tensor of shape [batch_size]. True indicates the prompt at that index has reached end condition.
generated_sequence_lengths (torch.Tensor) – An int tensor of shape [batch_size]. Each value represents the generated sequence length for that prompt.
- Returns:
Returns the boolean is_generation_done_tensor and the generated_sequence_lengths after updating it
- Return type:
Tuple[torch.Tensor, torch.Tensor]
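The status update can be sketched with plain Python lists in place of tensors. The update_status name is hypothetical, and the choice not to count the terminating token toward the generated length is an assumption of this sketch, not something the description above states.

```python
from typing import List, Tuple


def update_status(
    latest_tokens: List[int],
    generation_started: List[bool],
    is_done: List[bool],
    generated_lengths: List[int],
    termination_id: int,
) -> Tuple[List[bool], List[int]]:
    new_done, new_lengths = [], []
    for token, started, done, length in zip(
        latest_tokens, generation_started, is_done, generated_lengths
    ):
        # Only prompts that have started generating can hit the end condition.
        done = done or (started and token == termination_id)
        # Lengths keep growing only for prompts that are generating and not done.
        if started and not done:
            length += 1
        new_done.append(done)
        new_lengths.append(length)
    return new_done, new_lengths


done, lengths = update_status(
    latest_tokens=[5, 2, 2],
    generation_started=[True, True, False],
    is_done=[False, False, False],
    generated_lengths=[3, 4, 0],
    termination_id=2,
)
```

The third prompt sees the termination ID at the inspected position but is not marked done, because it has not started generating yet.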
- pad_input_prompt_tokens(
- batch_prompt_tokens_list: List[List[int]],
- padded_batch_size: int,
- padded_sequence_length: int,
Method to pad input prompts
Given a list of prompts, pad them all to uniform length
- Parameters:
batch_prompt_tokens_list (List[List[int]]) – A list containing the prompt tokens
padded_batch_size (int) – The maximum number of requests for this batch
padded_sequence_length (int) – The maximum number of input + output tokens for this batch
- Returns:
A torch tensor of shape [padded_batch_size, padded_sequence_length]
- Return type:
torch.Tensor
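A minimal sketch of the padding and the matching truncation, using nested lists instead of a tensor. The pad_prompts name and the pad_id default are hypothetical; which token the library pads with is not stated here.

```python
from typing import List


def pad_prompts(
    batch: List[List[int]],
    padded_batch_size: int,
    padded_sequence_length: int,
    pad_id: int = 0,  # hypothetical padding token
) -> List[List[int]]:
    # Pad every prompt on the right up to the common sequence length.
    rows = [tokens + [pad_id] * (padded_sequence_length - len(tokens)) for tokens in batch]
    # Add all-padding rows until the batch reaches the padded batch size.
    rows += [[pad_id] * padded_sequence_length for _ in range(padded_batch_size - len(batch))]
    return rows


padded = pad_prompts([[1, 2, 3], [4]], padded_batch_size=4, padded_sequence_length=5)
# Undoing the batch padding keeps only the original requests.
unpadded = padded[:2]
```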
- unpad_input_prompt_tokens(
- padded_batch_prompt_tokens: torch.Tensor,
- original_batch_size: int,
Truncates the given input tensor back to the original prompt size before padding.
- Parameters:
padded_batch_prompt_tokens (torch.Tensor) – The padded tokens tensor
original_batch_size (int) – The original batch size before padding
- _dynamic_step_context_init(
- construct_graph_dimensions: Optional[megatron.core.inference.batch_dimensions_utils.InferenceBatchDimensions] = None,
- is_dummy_forward: bool = False,
Initializes the inference context for dynamic batching.
- Parameters:
construct_graph_dimensions (Optional[InferenceBatchDimensions]) – The graph config to use for constructing the cuda graphs.
is_dummy_forward (bool) – Whether we are running an expert parallel dummy forward pass
- Returns:
input_ids (Tensor): The active input IDs.
position_ids (Tensor): The active position IDs.
- Return type:
Tuple[Tensor, Tensor]
- _dynamic_step_forward_logits(
- input_ids: torch.Tensor,
- position_ids: torch.Tensor,
Forward step the model to get logits for dynamic batching.
This also handles logits-broadcasting for pipeline parallelism.
- Parameters:
input_ids (Tensor) – The input token IDs.
position_ids (Tensor) – The position IDs.
- _dynamic_step_sample_bookkeeping()#
Perform bookkeeping necessary to sample logits for dynamic batching.
- _rewind_kv_cache()#
Update the KV cache bookkeeping for speculative decoding.
After forward pass with speculative tokens, some tokens may be rejected. This function “rewinds” the KV cache bookkeeping to reflect only the accepted tokens.
When speculative tokens are rejected, we need to:
1. Update request_kv_length_offsets (total sequence length)
2. Update request_last_kv_block_offset (position within last block)
3. If rewinding crosses a block boundary:
   a. Reduce request_kv_block_counts
   b. Update request_last_kv_block_id to point to the previous block
   c. Clear the entry in request_to_kv_block_ids for the released block
   d. Release the block back to the allocator
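The rewind steps above can be sketched for a single request as follows. Everything here is illustrative: the RequestKV class, its field names, and the free_blocks list are hypothetical, and the sketch assumes last_kv_block_offset is the index of the last written token within the last block.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class RequestKV:
    """Hypothetical per-request bookkeeping, loosely named after the fields above."""

    kv_length: int                 # total number of cached tokens
    last_kv_block_offset: int      # index of the last token within the last block
    kv_block_ids: List[int] = field(default_factory=list)


def rewind(req: RequestKV, num_rejected: int, block_size: int, free_blocks: List[int]) -> None:
    req.kv_length -= num_rejected
    req.last_kv_block_offset -= num_rejected
    # A negative offset means the rewind crossed into the previous block.
    while req.last_kv_block_offset < 0:
        # Drop the block from the request and hand it back to the allocator.
        free_blocks.append(req.kv_block_ids.pop())
        req.last_kv_block_offset += block_size


# Six cached tokens with block_size 4: block 7 is full, block 9 holds two tokens.
req = RequestKV(kv_length=6, last_kv_block_offset=1, kv_block_ids=[7, 9])
free_blocks: List[int] = []
rewind(req, num_rejected=3, block_size=4, free_blocks=free_blocks)
```

In this sketch the block count is len(kv_block_ids) and the last block ID is kv_block_ids[-1], so popping the list covers the count, last-ID, and block-table updates in one step.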
- _sample_from_logits_2d(logits_2d: torch.Tensor) torch.Tensor#
Sample tokens from 2D logits using existing sampling parameters.
- Parameters:
logits_2d (Tensor) – Logits of shape [num_requests, vocab_size].
- Returns:
Sampled tokens of shape [num_requests].
- Return type:
Tensor
- _compute_serial_mtp_and_sample()#
Compute MTP logits serially after verification and sample speculative tokens.
This ensures that MTP predictions are always conditioned on verified tokens. Each MTP depth receives the correctly sampled token from the previous depth (or the base token for depth 0) rather than stale speculative tokens from the previous step.
- _get_required_logit_indices(
- request_in_prefill_status_tensor: torch.Tensor,
- request_query_lengths: torch.Tensor,
- num_decode_requests: int,
- num_prefill_requests: int,
- device: torch.device,
Get indices into the logits tensor for tokens that need sampling.
For decode requests, all tokens (base + speculative) are needed. For prefill requests, only the last token logits are needed. Decode requests will always be on the left, followed by prefill requests.
Example with 5 requests (2 spec tokens):
Assume input ids       : [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d1 d2 | e1 e2 e3 e4 ]
Request to prefill     : [ 0 | 0 | 0 | 1 | 1 ]
Request query lengths  : [ 3 | 3 | 3 | 2 | 4 ]
OUTPUT:
required_logit_indices : [ 0 1 2 | 3 4 5 | 6 7 8 | 10 | 14 ]
- Returns:
Indices into the sequence dimension of the logits tensor.
- Return type:
Tensor
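The index selection in the example above can be reproduced with a short pure-Python sketch. The required_logit_indices function is a hypothetical stand-in that takes lists rather than tensors and omits the count and device arguments.

```python
from typing import List


def required_logit_indices(query_lengths: List[int], in_prefill: List[bool]) -> List[int]:
    indices, start = [], 0
    for length, prefill in zip(query_lengths, in_prefill):
        if prefill:
            # Prefill requests only need the logits of their last token.
            indices.append(start + length - 1)
        else:
            # Decode requests need every token (base + speculative).
            indices.extend(range(start, start + length))
        start += length
    return indices


indices = required_logit_indices([3, 3, 3, 2, 4], [False, False, False, True, True])
```

For the five requests in the example this yields the indices 0 through 8 for the three decode requests, then 10 and 14 for the two prefill requests.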
- _sample_speculative_logits(
- required_logits: torch.Tensor,
- request_in_prefill_status_tensor: torch.Tensor,
Sample tokens from logits using sampling buckets.
For torch sampling buckets: [request_indices, temp, top_k, top_p]
Example with 5 requests:
token_to_request_idx : [ 0 0 0 | 1 1 1 | 2 2 2 | 3 | 4 ]
required_logits      : [ a5l a6l a7l | b3l b4l b5l | c6l c7l c8l | d2l | e4l ]  # Shape [11, vocab_size]
Sampling buckets     : [[[0,2], temp1, top_k1, top_p1], [[1], temp3, top_k3, top_p3], [[3, 4], temp2, top_k2, top_p2]]
Final output tokens  : [a5s a6s a7s c6s c7s c8s b3s b4s b5s d2s e4s]  # Shape [11]
(Rearranged from sampling bucket order back to input order using token_order)
- Returns:
(output_tokens, repeats) where output_tokens has shape [total_required_tokens]
- Return type:
tuple
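A sketch of the reordering step, assuming sampled tokens come out grouped by bucket and must be scattered back to input order. The helper names are hypothetical, and strings stand in for token IDs so the result can be compared with the example above.

```python
from typing import List, Optional


def token_to_request_idx(repeats: List[int]) -> List[int]:
    """Expand per-request token counts into one request index per token."""
    return [req for req, count in enumerate(repeats) for _ in range(count)]


def restore_input_order(
    bucket_outputs: List[str], buckets: List[List[int]], repeats: List[int]
) -> List[Optional[str]]:
    token_req = token_to_request_idx(repeats)
    # token_order[k] is the input position of the k-th token in bucket order.
    token_order = [
        pos
        for bucket in buckets
        for req in bucket
        for pos, owner in enumerate(token_req)
        if owner == req
    ]
    restored: List[Optional[str]] = [None] * len(token_req)
    for value, pos in zip(bucket_outputs, token_order):
        restored[pos] = value
    return restored


repeats = [3, 3, 3, 1, 1]
buckets = [[0, 2], [1], [3, 4]]
bucket_outputs = ["a5s", "a6s", "a7s", "c6s", "c7s", "c8s", "b3s", "b4s", "b5s", "d2s", "e4s"]
restored = restore_input_order(bucket_outputs, buckets, repeats)
```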
- _verify_speculative_tokens(
- output_tokens: torch.Tensor,
- input_tokens_required: torch.Tensor,
- request_in_prefill_status_tensor: torch.Tensor,
- repeats: torch.Tensor,
- num_decode_requests: int,
- num_prefill_requests: int,
- active_request_count: int,
Verify speculative tokens against input tokens and compute acceptance.
Creates an accepted tokens mask where:
For prefill requests, the token is always accepted.
For decode requests, the first token (base token) is always accepted, then we compare sampled tokens with input tokens and accept consecutive matches.
Then finds the index of the last accepted token per request.
Example (assume 1, 2, and 0 spec tokens are accepted in the first 3 decode requests):
input_tokens_required     : [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ]  # Size 11
Output tokens             : [ a6o a7o a8o | b4o b5o b6o | c7o c8o c9o | d3o | e5o ]
Output tokens right shift : [ d3o a6o a7o | a8o b4o b5o | b6o c7o c8o | c9o | d3o ]
Accepted tokens mask      : [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ]
Last one indices          : [ 1 | 5 | 6 | 9 | 10 ]
- Returns:
(last_one_indices, accepted_tokens_mask, input_tokens_required) where last_one_indices contains the index of the last accepted token per request.
- Return type:
tuple
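The acceptance rule can be sketched per request with plain lists. The verify function is hypothetical and uses a loop rather than the shifted-tensor comparison shown in the example; integers stand in for token IDs, chosen so that 1, 2, and 0 speculative tokens are accepted in the three decode requests.

```python
from typing import List, Tuple


def verify(
    input_tokens: List[int], output_tokens: List[int], repeats: List[int]
) -> Tuple[List[int], List[int]]:
    mask, last_one_indices, start = [], [], 0
    for count in repeats:
        accepted = True  # the base (or prefill) token is always accepted
        last = start
        for offset in range(count):
            pos = start + offset
            if offset > 0:
                # A speculative input token is accepted only if it equals the token
                # sampled at the previous position and all earlier ones were accepted.
                accepted = accepted and input_tokens[pos] == output_tokens[pos - 1]
            mask.append(1 if accepted else 0)
            if accepted:
                last = pos
        last_one_indices.append(last)
        start += count
    return mask, last_one_indices


input_tokens = [5, 6, 7, 13, 14, 15, 26, 27, 28, 32, 44]
output_tokens = [6, 99, 8, 14, 15, 16, 20, 28, 29, 33, 45]
mask, last_one_indices = verify(input_tokens, output_tokens, [3, 3, 3, 1, 1])
```

In the third request the second speculative token happens to match, but it is still rejected because the first one did not, which is what "consecutive matches" requires.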
- _dynamic_step_sample_logits_and_verify_tokens(
- logits: torch.Tensor,
- input_ids: torch.Tensor,
Sample tokens from logits for dynamic batching with speculative tokens and verify the tokens.
- _dynamic_step_sample_logits(logits: torch.Tensor)#
Sample tokens from logits for dynamic batching.
- Parameters:
logits (Tensor) – The logits from the forward pass.
- _dynamic_step_log_probs_bookkeeping() Tuple[bool, bool]#
Perform bookkeeping necessary to compute log probs for dynamic batching.
- Returns:
return_log_probs (bool): Whether to return the sampled log_probs.
- Return type:
Tuple[bool, bool]
- _router_record_bookkeeping() Optional[Dict[int, torch.Tensor]]#
Collect and map routing indices per request for MoE router recording.
This method retrieves recorded routing decisions and maps them to individual requests using the context’s request_ids and query_lengths. Uses the context’s routing_metadata when available (which handles CUDA graph static buffers automatically). Must be called while context attributes are still valid (before request transitions).
- Returns:
A dictionary mapping request_id to a tensor of shape [num_tokens, num_layers, topk]. Returns None if routing replay is disabled or no routing data was recorded.
- Return type:
Optional[Dict[int, Tensor]]
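The mapping from a flat token dimension to per-request slices can be sketched as below. The split_by_request name is hypothetical, and nested lists of shape [num_tokens][num_layers][topk] stand in for the recorded routing tensor.

```python
from typing import Dict, List

Routing = List[List[List[int]]]  # [num_tokens][num_layers][topk]


def split_by_request(
    routing_indices: Routing, request_ids: List[int], query_lengths: List[int]
) -> Dict[int, Routing]:
    result, start = {}, 0
    for request_id, length in zip(request_ids, query_lengths):
        # Each request owns a contiguous run of tokens of its query length.
        result[request_id] = routing_indices[start:start + length]
        start += length
    return result


# Three tokens, one layer, topk = 2; request 10 owns two tokens, request 11 one.
routing = [[[0, 1]], [[2, 3]], [[1, 3]]]
per_request = split_by_request(routing, request_ids=[10, 11], query_lengths=[2, 1])
```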
- _dynamic_step_calculate_log_probs(
- logits: torch.Tensor,
Calculate log probs from logits.
- _dynamic_step_calculate_log_probs_speculative(
- logits: torch.Tensor,
Calculate log probs from logits for speculative decoding.
For decode requests, computes log probs for each accepted speculative token and the newly sampled token using the main model logits. For prefill requests, handles prompt log probs the same way as non-speculative decoding.
The main model logits at position j predict the token at position j+1. So:
log_prob(accepted_token[j]) comes from logits at position j
log_prob(newly_sampled_token) comes from logits at position accepted_count
- Parameters:
logits (Tensor) – The main model logits [1, seq_len, vocab_size].
- Returns:
log_probs_list: List of lists, one per active request, containing log probs for the tokens emitted in this step.
log_probs_tensor: Full log_softmax tensor for top-n computation.
- Return type:
Tuple of (log_probs_list, log_probs_tensor)
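The position rule above (logits at position j score the token at position j+1) can be sketched for one decode request in pure Python. The function names are hypothetical and lists stand in for tensors.

```python
import math
from typing import List


def log_softmax(row: List[float]) -> List[float]:
    peak = max(row)
    log_total = peak + math.log(sum(math.exp(x - peak) for x in row))
    return [x - log_total for x in row]


def emitted_log_probs(
    logits: List[List[float]], accepted_tokens: List[int], new_token: int
) -> List[float]:
    """logits[j] predicts the token at position j + 1 of one decode request."""
    log_probs = [log_softmax(row) for row in logits]
    emitted = accepted_tokens + [new_token]
    # Emitted token j is scored by the logits at position j, so the newly
    # sampled token uses the logits at position len(accepted_tokens).
    return [log_probs[j][token] for j, token in enumerate(emitted)]


logits = [[0.0, 0.0], [math.log(3.0), 0.0], [0.0, 0.0]]
result = emitted_log_probs(logits, accepted_tokens=[1], new_token=0)
```

Here one speculative token was accepted, so two log probs are emitted and the logits at the last position go unused.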
- _dynamic_step_calculate_top_n_logprobs_speculative(
- log_probs_tensor: torch.Tensor,
Calculate top-n log probs for speculative decoding.
For decode requests, computes top-n at each position that produced an emitted token (accepted speculative positions + the newly sampled position). For prefill requests, behaves identically to the non-speculative path.
- Parameters:
log_probs_tensor (Tensor) – Pre-computed log_softmax tensor from _dynamic_step_calculate_log_probs_speculative.
- Returns:
A dictionary mapping request_idx to list of (top_n_values, top_n_indices) tuples, one per emitted token position.
- _dynamic_step_calculate_top_n_logprobs(
- logits: torch.Tensor,
- log_probs_tensor: Optional[torch.Tensor] = None,
Calculate top-n log probs from logits for dynamic batching.
- Parameters:
logits (Tensor) – The logits to compute top-n log probs from.
log_probs_tensor (Optional[Tensor]) – Pre-computed log probabilities tensor. If provided, avoids recomputing log_softmax. Should be the tensor returned by calculate_log_probs.
- Returns:
A dictionary mapping request_idx to list of (top_n_logprobs, top_n_indices) tuples. Each tuple in the list represents one token position.
- dummy_forward()#
Perform a dummy forward pass. This is used in expert model parallelism on ranks that do not have any real requests. It may run in eager mode.
- _dummy_serial_mtp_forward()#
Run dummy MTP forward passes to participate in EP collectives.
When speculative decoding is active and MTP layers contain MoE sublayers (inherited from the decoder layer spec), each serial MTP step triggers EP all-to-all collectives. The dummy EP rank must issue matching collective calls so the real ranks do not hang.
This mirrors the structure of _compute_serial_mtp_and_sample:
On the last PP stage (where MTP resides): run compute_mtp_single_step with dummy tensors so the MoE all-to-all is executed.
When PP > 1: participate in the broadcast_from_last_pipeline_stage that the real ranks also perform.
- _dynamic_step_context_bookkeeping() Dict[str, torch.Tensor]#
Update the dynamic inference context after sampling.
- Parameters:
new_sample (Tensor) – The newly sampled tokens.
request_metadata (Optional[Dict[str, Tensor]]) – An override for the tensors that manage request metadata, such as sampling parameters. By default, this metadata is retrieved from the context.
- Returns:
A dictionary containing:
active_request_ids (Tensor): Current active request IDs.
newly_paused_request_ids (Tensor): Newly paused request IDs.
finished_request_ids (Tensor): Finished request IDs.
- Return type:
Dict[str, Tensor]
- async async_generate_output_tokens_dynamic_batch(
- skip_bookkeeping: Optional[bool] = False,
Forward step the model and update the inference context.
- Parameters:
skip_bookkeeping (Optional[bool]) – If true, skip the context bookkeeping step.
- Returns:
A dictionary containing:
active_request_ids (Tensor): Current active request IDs.
newly_paused_request_ids (Tensor): Newly paused request IDs.
finished_request_ids (Tensor): Finished request IDs.
sample (Tensor): New sample.
log_probs (Optional[Tensor]): Log probabilities of the new sample, if requested.
cuda_graph_request_count (Optional[int]): Size of cuda graph used for this step.
- Return type:
Optional[Dict]
- generate_output_tokens_dynamic_batch(
- loop: Optional[asyncio.AbstractEventLoop] = None,
Synchronous wrapper for self.async_generate_output_tokens_dynamic_batch.
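A generic sketch of wrapping a coroutine in a synchronous call with an optional event loop. The async_step coroutine is a hypothetical stand-in for the async method, and creating a fresh loop when none is passed is an assumption of this example, not a statement about what the library does.

```python
import asyncio
from typing import Optional


async def async_step() -> dict:
    """Stand-in for an async generation step (hypothetical)."""
    await asyncio.sleep(0)
    return {"sample": [1, 2, 3]}


def step(loop: Optional[asyncio.AbstractEventLoop] = None) -> dict:
    """Run the coroutine to completion on the given loop, or on a fresh one."""
    owns_loop = loop is None
    if owns_loop:
        loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(async_step())
    finally:
        # Only close loops created here; a caller-supplied loop stays open.
        if owns_loop:
            loop.close()


result = step()
```

This pattern only works from code that is not already running inside an event loop, since run_until_complete cannot be called on a loop that is running.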
- _update_top_n_logprobs_dict(
- top_n_logprobs_this_step: torch.Tensor,
- top_n_logprobs_indices: torch.Tensor,
- mask: torch.Tensor,
- top_n_logprobs_dict: Dict[int, List[Dict[str, float]]],
Function to update the top_n_logprobs at each step
This function goes through the top-n logprobs generated for each request in the batch. For every request that has started generating tokens, it updates top_n_logprobs_dict with the decoded token (string) as the key and the logit as the value. The keys of top_n_logprobs_dict are batch indices; each value is a list in which every element is a dictionary, mapping decoded token to logit, for one generation step.
- Parameters:
top_n_logprobs_this_step (torch.Tensor) – The top n logprob values
top_n_logprobs_indices (torch.Tensor) – The indices corresponding to the top n logprobs
mask (torch.Tensor) – A mask to indicate which requests should append to the dict
top_n_logprobs_dict (Dict[int, List[Dict[str, float]]]) – The dict to be updated
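The dictionary layout described above can be sketched with plain lists. The function below is a hypothetical stand-in: it takes a decode callable in place of the controller's tokenizer, and lists in place of tensors.

```python
from typing import Callable, Dict, List


def update_top_n_logprobs_dict(
    values: List[List[float]],
    indices: List[List[int]],
    mask: List[bool],
    top_n_logprobs_dict: Dict[int, List[Dict[str, float]]],
    decode: Callable[[int], str],  # hypothetical token-ID-to-string function
) -> None:
    for batch_idx, (row_values, row_indices, active) in enumerate(zip(values, indices, mask)):
        if not active:
            continue  # this request should not append anything for this step
        # One dict per step: decoded token string -> value.
        step = {decode(token_id): value for token_id, value in zip(row_indices, row_values)}
        top_n_logprobs_dict.setdefault(batch_idx, []).append(step)


vocab = ["a", "b", "c"]
top_n: Dict[int, List[Dict[str, float]]] = {}
update_top_n_logprobs_dict(
    values=[[-0.1, -2.5], [-0.3, -1.4]],
    indices=[[2, 0], [1, 2]],
    mask=[True, False],
    top_n_logprobs_dict=top_n,
    decode=lambda token_id: vocab[token_id],
)
```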
- generate_all_output_tokens_static_batch(
- active_requests: OrderedDict[int, megatron.core.inference.inference_request.InferenceRequest],
- active_streams: Optional[OrderedDict[str, megatron.core.inference.async_stream.AsyncStream]] = None,
Utility to generate all the output tokens and probabilities for the prompts.
This utility generates the output tokens for a static batch. It runs the forward steps until all prompts complete generation, updates the status of these requests to completed, adds the generated results, and returns the requests.
- Parameters:
active_requests (OrderedDict[int, InferenceRequest]) – The input active requests.
- Returns:
The result for each of the incoming requests
- Return type:
OrderedDict[int, InferenceRequest]
- prep_inference_input(
- prompts_tokens: torch.Tensor,
- active_requests: OrderedDict[int, megatron.core.inference.inference_request.InferenceRequest],
- use_attention_mask: bool = False,
Prepares input data for inference, using the respective wrapper’s prep_inference_input method.
- Parameters:
prompts_tokens (torch.Tensor) – A tensor of shape [batch_size, max_sequence_length]
active_requests (OrderedDict[int, InferenceRequest]) – The input active requests
use_attention_mask (bool) – Whether to use an attention mask. Should be set to True only when exclusively doing prefill (no decode) with variable prompt lengths.
- Returns:
A dict of the inference input for the current batch.
- stream_tokens(
- sampling_params: megatron.core.inference.sampling_params.SamplingParams,
- request_ids: List[int],
- requests: List[megatron.core.inference.inference_request.InferenceRequest],
- streams: List[megatron.core.inference.async_stream.AsyncStream],
- generation_started: List[bool],
- is_generation_done: List[bool],
- tokens: torch.Tensor,
- prompt_lengths: List[int],
- generated_lengths: List[int],
- output_log_probs: Union[torch.Tensor, None],
Asynchronously streams tokens for the given requests.
- Parameters:
sampling_params (SamplingParams) – The sampling parameters.
request_ids (List[int]) – The request IDs.
requests (List[InferenceRequest]) – The requests.
streams (List[AsyncStream]) – The streams over which to send tokens.
generation_started (List[bool]) – Whether the decode step has started.
is_generation_done (List[bool]) – Whether generation has completed.
tokens (torch.Tensor) – The tokens for this request.
prompt_lengths (List[int]) – The number of prompt tokens for each request.
generated_lengths (List[int]) – The number of output tokens for each request.
output_log_probs (torch.Tensor, optional) – The log probs for each request.