core.inference.text_generation_controllers.text_generation_controller#

Module Contents#

Classes#

TextGenerationController

The text generation controller (the main sampling loop)

API#

class core.inference.text_generation_controllers.text_generation_controller.TextGenerationController(
inference_wrapped_model: megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper.AbstractModelInferenceWrapper,
tokenizer,
)#

The text generation controller (the main sampling loop)

This class tokenizes the input, runs inference, samples from logits, and detokenizes the output.

Parameters:
  • inference_wrapped_model (AbstractModelInferenceWrapper) – A model that is wrapped using the specs given in the abstract_model_inference_wrapper.py

  • tokenizer (type) – Tokenizer used for tokenizing and detokenizing the prompts

Initialization

set_stop_word_finished_ids_callback(callback)#

Set a callback to get request IDs that should be marked as finished due to stop words.

The callback should have the signature callback(active_request_ids: List[int]) -> Set[int] and return the set of request IDs from active_request_ids that should be marked as finished.

Parameters:

callback – Function that returns request IDs to mark as finished.
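A minimal sketch of such a callback, using a hypothetical per-request text store (generated_text, STOP_WORDS, and stop_word_callback are illustrative names, not part of the real API):

```python
# Illustrative stop-word callback: marks a request finished once its
# generated text ends with any configured stop word.
from typing import Dict, List, Set

STOP_WORDS = ("</answer>", "\n\nUser:")  # assumption: app-specific stop words

# Hypothetical mapping from request ID to the text generated so far.
generated_text: Dict[int, str] = {}

def stop_word_callback(active_request_ids: List[int]) -> Set[int]:
    """Return the subset of active request IDs whose output hit a stop word."""
    finished: Set[int] = set()
    for request_id in active_request_ids:
        text = generated_text.get(request_id, "")
        if text.endswith(STOP_WORDS):  # str.endswith accepts a tuple
            finished.add(request_id)
    return finished

# Registered on the controller, e.g.:
# controller.set_stop_word_finished_ids_callback(stop_word_callback)
```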

_init_dynamic_sampling_tensors()#

Initialize tensors needed for dynamic sampling.

tokenize_prompt(
prompt: str,
add_BOS: bool = False,
) List[int]#

Utility to tokenize the input prompts.

Parameters:
  • prompt (str) – The input prompt.

  • add_BOS (bool) – Whether to prepend a beginning-of-sequence token to the prompt. Defaults to False.

Returns:

Returns the tokenized prompt.

Return type:

List[int]
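The tokenize-then-optionally-prepend-BOS pattern can be sketched with a toy whitespace "tokenizer" (VOCAB and BOS_ID here are stand-ins, not the real tokenizer):

```python
# Toy sketch of tokenize_prompt: tokenize, then optionally prepend BOS.
from typing import List

BOS_ID = 1                        # assumption: illustrative BOS token ID
VOCAB = {"hello": 10, "world": 11}  # assumption: toy vocabulary

def tokenize_prompt(prompt: str, add_BOS: bool = False) -> List[int]:
    """Whitespace-split the prompt and map each word to a token ID."""
    tokens = [VOCAB[word] for word in prompt.split()]
    return ([BOS_ID] + tokens) if add_BOS else tokens
```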

_detokenize(
tokens: List[int],
skip_special_tokens: bool = True,
) str#

Detokenize a sequence of token IDs, handling skip_special_tokens for different tokenizer APIs.

On the first call, inspects self.tokenizer.detokenize to see if it accepts a skip_special_tokens keyword argument, and caches that result on self. Subsequent calls will use the cached flag to invoke detokenize with the correct signature (with or without skip_special_tokens).

Parameters:
  • tokens (List[int]) – The token IDs to convert back to text.

  • skip_special_tokens (bool) – Whether to remove special tokens (e.g. BOS/EOS) during detokenization. Only passed through if the tokenizer supports it.

Returns:

The detokenized string.

Return type:

str
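The inspect-and-cache pattern described above can be sketched as follows (ToyTokenizer and Controller are illustrative stand-ins for the real classes):

```python
# Sketch of capability detection: on the first call, inspect whether the
# tokenizer's detokenize accepts skip_special_tokens, and cache the result.
import inspect
from typing import List

class ToyTokenizer:
    def detokenize(self, tokens: List[int]) -> str:
        # This toy tokenizer does NOT accept skip_special_tokens.
        return " ".join(str(t) for t in tokens)

class Controller:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self._detokenize_accepts_skip = None  # cached on first call

    def _detokenize(self, tokens: List[int],
                    skip_special_tokens: bool = True) -> str:
        if self._detokenize_accepts_skip is None:
            params = inspect.signature(self.tokenizer.detokenize).parameters
            self._detokenize_accepts_skip = "skip_special_tokens" in params
        if self._detokenize_accepts_skip:
            return self.tokenizer.detokenize(
                tokens, skip_special_tokens=skip_special_tokens)
        return self.tokenizer.detokenize(tokens)
```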

detokenize_generations(
tokens_gpu_tensor: torch.Tensor,
lengths_gpu_tensor: torch.Tensor,
detokenize_segments: bool,
skip_special_tokens: bool = True,
) tuple[str, Optional[List[List[str]]]]#

Detokenize the generated tokens.

Parameters:
  • tokens_gpu_tensor (torch.Tensor) – Tensor containing the tokens

  • lengths_gpu_tensor (torch.Tensor) – Tensor containing the lengths of each sequence

  • detokenize_segments (bool) – If True, also returns the individually detokenized tokens; if False, returns None as the second element. Helpful for understanding per-token boundaries in generated text.

  • skip_special_tokens (bool) – If True, removes special tokens like BOS during detokenization.

Returns:

A tuple containing:

  • str: The complete detokenized text

  • List[List[str]] | None: Per-sequence lists of detokenized segments if detokenize_segments is True, else None

Return type:

tuple[str, Optional[List[List[str]]]]

_torch_sampling_func(
last_token_logits: torch.Tensor,
temperature: float,
top_k: int,
top_p: float,
vocab_size: Optional[int] = None,
)#

Samples the logits to generate outputs.

Given the logits of the last token, this function samples them according to the provided temperature, top-k, and top-p parameters and returns the sampled tokens.

Parameters:
  • last_token_logits (torch.Tensor) – The last token logits. A tensor of size [batch_size, vocab_size].

  • temperature (float) – The temperature to use for sampling.

  • top_k (int) – The top-k value to use for sampling.

  • top_p (float) – The top-p value to use for sampling.

  • vocab_size (int) – Obtained from the tokenizer. Defaults to None.

Returns:

1D tensor with [batch_size] elements

Return type:

sampled_logits (torch.Tensor)
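The temperature / top-k / top-p sampling scheme can be illustrated in pure Python over a single logit vector (the real method runs batched on GPU tensors; this is a sketch, not the Megatron implementation):

```python
# Pure-Python sketch of temperature, top-k, and top-p (nucleus) sampling
# over one logit vector. temperature == 0.0 is treated as greedy decoding.
import math
import random
from typing import List

def sample_from_logits(logits: List[float], temperature: float = 1.0,
                       top_k: int = 0, top_p: float = 0.0,
                       rng=random) -> int:
    if temperature == 0.0:  # greedy: pick the argmax
        return max(range(len(logits)), key=lambda i: logits[i])
    # Temperature-scaled softmax (subtract max for numerical stability).
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(l - m) for l in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Candidate indices, highest probability first.
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    if top_k > 0:
        order = order[:top_k]
    if top_p > 0.0:
        cum, kept = 0.0, []
        for i in order:
            kept.append(i)
            cum += probs[i]
            if cum >= top_p:  # keep the smallest nucleus covering top_p mass
                break
        order = kept
    # Sample from the renormalized candidate distribution.
    mass = sum(probs[i] for i in order)
    r = rng.random() * mass
    for i in order:
        r -= probs[i]
        if r <= 0:
            return i
    return order[-1]
```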

sample_from_logits(
last_token_logits: torch.Tensor,
sampling_params: Optional[megatron.core.inference.sampling_params.SamplingParams] = None,
vocab_size: Optional[int] = None,
generation_started: Optional[torch.Tensor] = None,
top_n_logprobs_dict: Dict[int, List[Dict[str, float]]] = None,
logits: Optional[torch.Tensor] = None,
**kwargs,
) torch.Tensor#

Samples the logits to generate outputs.

Given the logits of the last token, this function samples them according to the parameters defined in sampling_params and returns the samples. If sampling_params.top_n_logprobs > 0, it also updates the top_n_logprobs dict at each step.

Parameters:
  • last_token_logits (torch.Tensor) – The last token logits. A tensor of size [batch_size, vocab_size]

  • sampling_params (SamplingParams) – The parameters to use for inference.

  • vocab_size (int) – Obtained from the tokenizer. Defaults to None

  • generation_started (torch.Tensor) – A boolean tensor of shape [batch_size]. True indicates the prompt at that index has started generating tokens.

  • top_n_logprobs_dict (Dict[int, List[Dict[str, float]]]) – The dict to be updated

Returns:

  • sampled_logits (torch.Tensor): 1D tensor with [batch_size] elements.

  • top_n_logprobs_this_step (torch.return_types.topk): A topk result with values as logits and indices as the top-k elements. None if sampling_params.top_n_logprobs is 0.

update_generation_status(
updated_prompts_tokens: torch.Tensor,
generation_started: torch.Tensor,
current_context_end_position: int,
is_generation_done_tensor: torch.Tensor,
generated_sequence_lengths: torch.Tensor,
termination_id: Optional[int] = None,
) Tuple[torch.Tensor, torch.Tensor]#

Checks which prompts have reached an end condition

We check which prompts have reached an end condition and set the corresponding flags of the is_generation_done_tensor to True. The generated sequence lengths increase as we keep generating, until that prompt hits an end condition. The generation_started tensor determines which prompts have started generating.

Parameters:
  • updated_prompts_tokens (torch.Tensor) – The prompt tokens updated with the latest generated tokens. A tensor of shape [batch_size, max_seq_len] (i.e. max_seq_len = max_prompt_len + tokens_to_generate)

  • generation_started (torch.Tensor) – A boolean tensor of shape [batch_size]. True indicates the prompt at that index has started generating tokens.

  • current_context_end_position (int) – An integer indicating which position to extract from the prompts tokens to get the latest generated tokens.

  • is_generation_done_tensor (torch.Tensor) – A boolean tensor of shape [batch_size]. True indicates the prompt at that index has reached end condition.

  • generated_sequence_lengths (torch.Tensor) – An int tensor of shape [batch_size]. Each value represents the generated sequence length for that prompt.

Returns:

Returns the boolean is_generation_done_tensor and the generated_sequence_lengths after updating them

Return type:

Tuple[torch.Tensor, torch.Tensor]
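The bookkeeping above can be sketched with plain Python lists in place of batched GPU tensors (a simplification; names and the exact accounting of the terminating token are illustrative):

```python
# List-based sketch of per-prompt end-condition bookkeeping: for every
# prompt that has started generating and is not yet done, count the newly
# generated token and mark the prompt done when it emits termination_id.
from typing import List, Optional, Tuple

def update_generation_status(
    latest_tokens: List[int],          # token just generated per prompt
    generation_started: List[bool],
    is_generation_done: List[bool],
    generated_lengths: List[int],
    termination_id: Optional[int] = None,
) -> Tuple[List[bool], List[int]]:
    for i, token in enumerate(latest_tokens):
        if not generation_started[i] or is_generation_done[i]:
            continue  # prompt still in prefill, or already finished
        if termination_id is not None and token == termination_id:
            is_generation_done[i] = True
        generated_lengths[i] += 1  # the terminating token is counted here
    return is_generation_done, generated_lengths
```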

pad_input_prompt_tokens(
batch_prompt_tokens_list: List[List[int]],
padded_batch_size: int,
padded_sequence_length: int,
) torch.Tensor#

Method to pad input prompts

Given a list of prompts, pad them all to uniform length

Parameters:
  • batch_prompt_tokens_list (List[List[int]]) – A list containing the prompt tokens

  • padded_batch_size (int) – The maximum number of requests for this batch

  • padded_sequence_length (int) – The maximum number of input + output tokens for this batch

Returns:

A torch tensor of shape [padded_batch_size, padded_sequence_length]

Return type:

torch.Tensor
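The padding scheme can be sketched with nested lists instead of a torch.Tensor (PAD_ID is an assumed placeholder for whatever pad token the tokenizer uses):

```python
# Sketch of padding a ragged batch of token lists to a rectangular
# [padded_batch_size, padded_sequence_length] shape.
from typing import List

PAD_ID = 0  # assumption: illustrative pad token ID

def pad_input_prompt_tokens(batch: List[List[int]],
                            padded_batch_size: int,
                            padded_sequence_length: int) -> List[List[int]]:
    # Pad each prompt on the right up to padded_sequence_length.
    padded = [row + [PAD_ID] * (padded_sequence_length - len(row))
              for row in batch]
    # Pad the batch dimension with all-PAD rows up to padded_batch_size.
    while len(padded) < padded_batch_size:
        padded.append([PAD_ID] * padded_sequence_length)
    return padded
```

Truncating back to the original batch size (the unpad step) is then just `padded[:original_batch_size]`.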

unpad_input_prompt_tokens(
padded_batch_prompt_tokens: torch.Tensor,
original_batch_size: int,
)#

Truncates the given input tensor back to the original prompt size before padding.

Parameters:
  • padded_batch_prompt_tokens (torch.Tensor) – The padded tokens tensor

  • original_batch_size (int) – The original batch size before padding

_dynamic_step_context_init(
construct_graph_dimensions: Optional[megatron.core.inference.batch_dimensions_utils.InferenceBatchDimensions] = None,
is_dummy_forward: bool = False,
)#

Initializes the inference context for dynamic batching.

Parameters:
  • construct_graph_dimensions (Optional[InferenceBatchDimensions]) – The graph config to use for constructing the cuda graphs.

  • is_dummy_forward (bool) – Whether we are running an expert parallel dummy forward pass

Returns:

  • input_ids (Tensor): The active input IDs.

  • position_ids (Tensor): The active position IDs.

_dynamic_step_forward_logits(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
) torch.Tensor#

Forward step the model to get logits for dynamic batching.

This also handles logits-broadcasting for pipeline parallelism.

Parameters:
  • input_ids (Tensor) – The input token IDs.

  • position_ids (Tensor) – The position IDs.

_dynamic_step_sample_bookkeeping()#

Perform bookkeeping necessary to sample logits for dynamic batching.

_dynamic_step_sample_logits(logits: torch.Tensor)#

Sample tokens from logits for dynamic batching.

Parameters:

logits (Tensor) – The logits from the forward pass.

_dynamic_step_log_probs_bookkeeping() Tuple[bool, bool]#

Perform bookkeeping necessary to compute log probs for dynamic batching.

Returns:

Whether to return the sampled log_probs.

Return type:

return_log_probs (bool)

_router_record_bookkeeping() Optional[Dict[int, torch.Tensor]]#

Collect and map routing indices per request for MoE router recording.

This method retrieves recorded routing decisions and maps them to individual requests using the context’s request_ids and query_lengths. Uses the context’s routing_metadata when available (which handles CUDA graph static buffers automatically). Must be called while context attributes are still valid (before request transitions).

Returns:

A dictionary mapping request_id to a tensor of shape [num_tokens, num_layers, topk]. Returns None if routing replay is disabled or no routing data was recorded.

Return type:

Optional[Dict[int, Tensor]]

_dynamic_step_calculate_log_probs(
logits: torch.Tensor,
) Optional[torch.Tensor]#

Calculate log probs from logits.

_dynamic_step_calculate_top_n_logprobs(
logits: torch.Tensor,
log_probs_tensor: Optional[torch.Tensor] = None,
) Optional[Dict[int, List[Tuple[torch.Tensor, torch.Tensor]]]]#

Calculate top-n log probs from logits for dynamic batching.

Parameters:
  • logits (Tensor) – The logits to compute top-n log probs from.

  • log_probs_tensor (Optional[Tensor]) – Pre-computed log probabilities tensor. If provided, avoids recomputing log_softmax. Should be the tensor returned by calculate_log_probs.

Returns:

A dictionary mapping request_idx to list of (top_n_logprobs, top_n_indices) tuples. Each tuple in the list represents one token position.

dummy_forward()#

Perform a dummy forward pass. This is used in expert model parallelism on ranks that do not have any real requests.

_dynamic_step_context_bookkeeping() Dict[str, torch.Tensor]#

Update the dynamic inference context after sampling.

Parameters:
  • new_sample (Tensor) – The newly sampled tokens.

  • request_metadata (Optional[Dict[str, Tensor]]) – An override for the tensors that manage request metadata, such as sampling parameters. By default, this metadata is retrieved from the context.

Returns:

A dictionary containing:

  • active_request_ids (Tensor): Current active request IDs.

  • newly_paused_request_ids (Tensor): Newly paused request IDs.

  • finished_request_ids (Tensor): Finished request IDs.

Return type:

Dict[str, Tensor]

async async_generate_output_tokens_dynamic_batch(
skip_bookkeeping: Optional[bool] = False,
) Optional[Dict]#

Forward step the model and update the inference context.

Parameters:

skip_bookkeeping (Optional[bool]) – If true, skip the context bookkeeping step.

Returns:

A dictionary containing:

  • active_request_ids (Tensor): Current active request IDs.

  • newly_paused_request_ids (Tensor): Newly paused request IDs.

  • finished_request_ids (Tensor): Finished request IDs.

  • sample (Tensor): The new sample.

  • log_probs (Optional[Tensor]): Log probabilities of the new sample, if requested.

  • cuda_graph_request_count (Optional[int]): Size of the CUDA graph used for this step.

Return type:

Optional[Dict]

generate_output_tokens_dynamic_batch(
loop: Optional[asyncio.AbstractEventLoop] = None,
) Optional[Dict]#

Synchronous wrapper for self.async_generate_output_tokens_dynamic_batch.
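The sync-over-async wrapper pattern can be sketched as follows (async_step and sync_step are illustrative stand-ins, not the real methods):

```python
# Sketch of running an async step method to completion on an event loop,
# creating and tearing down a fresh loop when none is supplied.
import asyncio
from typing import Optional

async def async_step() -> dict:  # stand-in for the async generation step
    return {"finished_request_ids": []}

def sync_step(loop: Optional[asyncio.AbstractEventLoop] = None) -> dict:
    if loop is None:
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(async_step())
        finally:
            loop.close()  # only close loops we created ourselves
    return loop.run_until_complete(async_step())
```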

_update_top_n_logprobs_dict(
top_n_logprobs_this_step: torch.Tensor,
top_n_logprobs_indices: torch.Tensor,
mask: torch.Tensor,
top_n_logprobs_dict: Dict[int, List[Dict[str, float]]],
)#

Function to update the top_n_logprobs at each step.

This function iterates over the top-n logprobs generated at this step and, for each batch entry that has started generating tokens, updates top_n_logprobs_dict with the decoded token (string) as the key and the logit as the value. The dict's keys are batch indices; each value is a list in which each element is a dictionary mapping decoded tokens to the logits generated at one step.

Parameters:
  • top_n_logprobs_this_step (torch.Tensor) – The top n logprob values

  • top_n_logprobs_indices (torch.Tensor) – The indices corresponding to the top n logprobs

  • mask (torch.Tensor) – A mask to indicate which requests should append to the dict

  • top_n_logprobs_dict (Dict[int, List[Dict[str, float]]]) – The dict to be updated
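The per-step bookkeeping can be sketched with plain lists and a hypothetical decode() helper standing in for the tokenizer (the real method operates on torch tensors):

```python
# Toy sketch of appending one step's top-n logprobs, keyed by decoded
# token string, for every batch entry that is actively generating.
from typing import Dict, List

def decode(token_id: int) -> str:  # hypothetical stand-in for detokenization
    return f"<tok{token_id}>"

def update_top_n_logprobs_dict(
    top_n_values: List[List[float]],   # [batch, n] logprob values this step
    top_n_indices: List[List[int]],    # [batch, n] token IDs this step
    mask: List[bool],                  # which batch entries should append
    top_n_logprobs_dict: Dict[int, List[Dict[str, float]]],
) -> None:
    for batch_idx, generating in enumerate(mask):
        if not generating:
            continue
        step_entry = {
            decode(tok): value
            for tok, value in zip(top_n_indices[batch_idx],
                                  top_n_values[batch_idx])
        }
        top_n_logprobs_dict.setdefault(batch_idx, []).append(step_entry)
```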

generate_all_output_tokens_static_batch(
active_requests: OrderedDict[int, megatron.core.inference.inference_request.InferenceRequest],
active_streams: Optional[OrderedDict[str, megatron.core.inference.async_stream.AsyncStream]] = None,
) OrderedDict[int, megatron.core.inference.inference_request.InferenceRequest]#

Utility to generate all the output tokens and probabilities for the prompts.

This utility generates the output tokens for a static batch. It runs forward steps until all prompts complete generation, updates the status of these requests to completed, attaches the generated results, and returns the requests.

Parameters:

active_requests (OrderedDict[int, InferenceRequest]) – The input active requests.

Returns:

The result for each of the incoming requests

Return type:

OrderedDict[int, InferenceRequest]

prep_inference_input(
prompts_tokens: torch.Tensor,
active_requests: OrderedDict[int, megatron.core.inference.inference_request.InferenceRequest],
use_attention_mask: bool = False,
) Dict[str, Any]#

Prepares the input data for inference using the respective wrapper's prep_inference_input method.

Parameters:
  • prompts_tokens (torch.Tensor) – A tensor of shape [batch_size, max_sequence_length]

  • active_requests (OrderedDict[int, InferenceRequest]) – The input active requests

  • use_attention_mask (bool) – Whether to use an attention mask. Should be set to True only when exclusively doing prefill (no decode) with variable prompt lengths.

Returns:

A dict of the inference input for the current batch.

stream_tokens(
sampling_params: megatron.core.inference.sampling_params.SamplingParams,
request_ids: List[int],
requests: List[megatron.core.inference.inference_request.InferenceRequest],
streams: List[megatron.core.inference.async_stream.AsyncStream],
generation_started: List[bool],
is_generation_done: List[bool],
tokens: torch.Tensor,
prompt_lengths: List[int],
generated_lengths: List[int],
output_log_probs: Union[torch.Tensor, None],
)#

Asynchronously streams tokens for the given requests.

Parameters:
  • sampling_params (SamplingParams) – The sampling parameters.

  • request_ids (List[int]) – The request IDs.

  • requests (List[InferenceRequest]) – The requests.

  • streams (List[AsyncStream]) – The streams over which to send tokens.

  • generation_started (List[bool]) – Whether the decode step has started.

  • is_generation_done (List[bool]) – Whether generation has completed.

  • tokens (torch.Tensor) – The tokens for this request.

  • prompt_lengths (List[int]) – The number of prompt tokens for each request.

  • generated_lengths (List[int]) – The number of output tokens for each request.

  • output_log_probs (torch.Tensor, optional) – The log probs for each request.