core.inference.text_generation_controllers.text_generation_controller#
Module Contents#
Classes#
TextGenerationController – The text generation controller (the main sampling loop)
API#
- class core.inference.text_generation_controllers.text_generation_controller.TextGenerationController(
- inference_wrapped_model: megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper.AbstractModelInferenceWrapper,
- tokenizer,
The text generation controller (the main sampling loop)
This class tokenizes the input, runs inference, samples from logits, and detokenizes the output.
- Parameters:
inference_wrapped_model (AbstractModelInferenceWrapper) – A model wrapped according to the spec given in abstract_model_inference_wrapper.py
tokenizer (type) – Tokenizer used for tokenizing and detokenizing the prompts
Initialization
- set_stop_word_finished_ids_callback(callback)#
Set a callback to get request IDs that should be marked as finished due to stop words.
The callback should have the signature `callback(active_request_ids: List[int]) -> Set[int]` and return the set of request IDs from active_request_ids that should be marked as finished.
- Parameters:
callback – Function that returns request IDs to mark as finished.
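A minimal sketch of such a callback in plain Python. The per-request text store and stop-word table below are hypothetical stand-ins for whatever bookkeeping the caller maintains; only the callback signature comes from the API above.

```python
# Hypothetical sketch of a stop-word callback. The generated_text and
# stop_words dicts are illustrative, not part of Megatron core.
from typing import List, Set

# Illustrative per-request record of generated text and stop words.
generated_text = {1: "Hello world. STOP", 2: "Still going", 3: "Done END"}
stop_words = {1: ["STOP"], 2: ["STOP"], 3: ["END"]}

def stop_word_callback(active_request_ids: List[int]) -> Set[int]:
    """Return the subset of active requests whose output contains a stop word."""
    finished = set()
    for rid in active_request_ids:
        text = generated_text.get(rid, "")
        if any(word in text for word in stop_words.get(rid, [])):
            finished.add(rid)
    return finished
```

A caller would then register it with `controller.set_stop_word_finished_ids_callback(stop_word_callback)`.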
- _init_dynamic_sampling_tensors()#
Initialize tensors needed for dynamic sampling.
- tokenize_prompt(
- prompt: str,
- add_BOS: bool = False,
Utility to tokenize the input prompts.
- Parameters:
prompt (str) – The input prompt.
add_BOS (bool) – Whether to prepend a beginning-of-sequence token. Defaults to False.
- Returns:
Returns the tokenized prompt.
- Return type:
List[int]
- _detokenize(
- tokens: List[int],
- skip_special_tokens: bool = True,
Detokenize a sequence of token IDs, handling skip_special_tokens for different tokenizer APIs.
On the first call, inspects `self.tokenizer.detokenize` to see if it accepts a `skip_special_tokens` keyword argument, and caches that result on `self`. Subsequent calls use the cached flag to invoke `detokenize` with the correct signature (with or without `skip_special_tokens`).
- Parameters:
tokens (List[int]) – The token IDs to convert back to text.
skip_special_tokens (bool) – Whether to remove special tokens (e.g. BOS/EOS) during detokenization. Only passed through if the tokenizer supports it.
- Returns:
The detokenized string.
- Return type:
str
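The signature probe described above can be sketched with `inspect.signature`. The `LegacyTokenizer` class and `detokenize_compat` helper below are hypothetical illustrations (the real method also caches the probe result on `self` so the inspection runs only once):

```python
# Sketch of detecting whether a tokenizer's detokenize() accepts
# skip_special_tokens, and calling it with the right signature.
import inspect

class LegacyTokenizer:
    """Illustrative tokenizer whose detokenize() lacks skip_special_tokens."""
    def detokenize(self, tokens):
        return " ".join(str(t) for t in tokens)

def detokenize_compat(tokenizer, tokens, skip_special_tokens=True):
    """Pass skip_special_tokens through only if the tokenizer supports it."""
    params = inspect.signature(tokenizer.detokenize).parameters
    if "skip_special_tokens" in params:
        return tokenizer.detokenize(tokens, skip_special_tokens=skip_special_tokens)
    return tokenizer.detokenize(tokens)
```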
- detokenize_generations(
- tokens_gpu_tensor: torch.Tensor,
- lengths_gpu_tensor: torch.Tensor,
- detokenize_segments: bool,
- skip_special_tokens: bool = True,
Detokenize the generated tokens.
- Parameters:
tokens_gpu_tensor (torch.Tensor) – Tensor containing the tokens
lengths_gpu_tensor (torch.Tensor) – Tensor containing the lengths of each sequence
detokenize_segments (bool) – If True, returns the individually detokenized tokens as the second element; helpful for understanding per-token boundaries in the generated text. If False, returns None as the second element.
skip_special_tokens (bool) – If True, removes special tokens like BOS during detokenization.
- Returns:
A tuple containing:
str: The complete detokenized text
List[str] | None: List of segmented tokens if detokenize_segments is True, else None
- Return type:
tuple[str, List[str] | None]
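The shape of this return value can be illustrated with a pure-Python sketch; the `decode` callable is a hypothetical stand-in for the tokenizer, and the real method operates on GPU tensors rather than lists:

```python
# Pure-Python sketch of per-sequence detokenization with lengths.
def detokenize_generations(token_rows, lengths, detokenize_segments, decode):
    """Join decoded tokens up to each sequence's length; optionally keep segments."""
    texts = []
    segments = [] if detokenize_segments else None
    for row, length in zip(token_rows, lengths):
        # Only the first `length` tokens of each row are real output.
        pieces = [decode(tok) for tok in row[:length]]
        texts.append("".join(pieces))
        if detokenize_segments:
            segments.append(pieces)
    return texts, segments
```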
- _torch_sampling_func(
- last_token_logits: torch.Tensor,
- temperature: float,
- top_k: int,
- top_p: float,
- vocab_size: Optional[int] = None,
Samples the logits to generate outputs
Given the logits of the last token, this function samples them according to the given temperature, top-k, and top-p values and returns the sampled tokens.
- Parameters:
last_token_logits (torch.Tensor) – The last token logits. A tensor of size [batch_size, vocab_size].
temperature (float) – The temperature to use for sampling.
top_k (int) – The top-k value to use for sampling.
top_p (float) – The top-p value to use for sampling.
vocab_size (int) – Obtained from the tokenizer. Defaults to None.
- Returns:
1D tensor with [batch_size] elements
- Return type:
sampled_logits (torch.Tensor)
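The filtering half of this sampling step can be sketched in plain Python for a single logits vector. This is an illustrative approximation, not Megatron's implementation: the real method works on batched torch tensors and then draws from the filtered distribution.

```python
# Minimal sketch of temperature + top-k + top-p (nucleus) filtering for one
# logits vector. Returns the surviving (token_index, probability) pairs.
import math

def filter_logits(logits, temperature=1.0, top_k=0, top_p=1.0):
    scaled = [l / temperature for l in logits]
    # Numerically stable softmax over the scaled logits.
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = sorted(((e / total, i) for i, e in enumerate(exps)), reverse=True)
    # top-k: keep only the k most probable tokens (0 disables the filter).
    if top_k > 0:
        probs = probs[:top_k]
    # top-p: keep the smallest prefix whose cumulative probability >= top_p.
    if top_p < 1.0:
        kept, cum = [], 0.0
        for p, i in probs:
            kept.append((p, i))
            cum += p
            if cum >= top_p:
                break
        probs = kept
    return [(i, p) for p, i in probs]
```

A sampler would then renormalize the surviving probabilities and draw one token from them.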
- sample_from_logits(
- last_token_logits: torch.Tensor,
- sampling_params: Optional[megatron.core.inference.sampling_params.SamplingParams] = None,
- vocab_size: Optional[int] = None,
- generation_started: Optional[torch.Tensor] = None,
- top_n_logprobs_dict: Dict[int, List[Dict[str, float]]] = None,
- logits: Optional[torch.Tensor] = None,
- **kwargs,
Samples the logits to generate outputs
Given the logits of the last token, this function samples them according to the parameters defined in sampling_params and returns the samples. If sampling_params.top_n_logprobs > 0, it also updates the top_n_logprobs dict at each step.
- Parameters:
last_token_logits (torch.Tensor) – The last token logits. A tensor of size [batch_size, vocab_size]
sampling_params (SamplingParams) – The parameters to use for inference.
vocab_size (int) – Obtained from the tokenizer. Defaults to None
generation_started (torch.Tensor) – A boolean tensor of shape [batch_size]. True indicates the prompt at that index has started generating tokens.
top_n_logprobs_dict (Dict[int, List[Dict[str, float]]]) – The dict to be updated.
- Returns:
A tuple containing:
sampled_logits (torch.Tensor): 1D tensor with [batch_size] elements.
top_n_logprobs_this_step (torch.return_types.topk): A topk result with values as logits and indices as the top-k elements; None if sampling_params.top_n_logprobs is 0.
- Return type:
Tuple[torch.Tensor, Optional[torch.return_types.topk]]
- update_generation_status(
- updated_prompts_tokens: torch.Tensor,
- generation_started: torch.Tensor,
- current_context_end_position: int,
- is_generation_done_tensor: torch.Tensor,
- generated_sequence_lengths: torch.Tensor,
- termination_id: Optional[int] = None,
Checks which prompts have reached an end condition
We check which prompts have reached an end condition and set the corresponding flags of is_generation_done_tensor to True. The generated sequence lengths increase as we keep generating, until a prompt hits an end condition. The generation_started tensor determines which prompts have started generating.
- Parameters:
updated_prompts_tokens (torch.Tensor) – The prompt tokens updated with the latest generated tokens. A tensor of shape [batch_size, max_seq_len], where max_seq_len = max_prompt_len + tokens_to_generate.
generation_started (torch.Tensor) – A boolean tensor of shape [batch_size]. True indicates the prompt at that index has started generating tokens.
current_context_end_position (int) – An integer indicating which position to extract from the prompts tokens to get the latest generated tokens.
is_generation_done_tensor (torch.Tensor) – A boolean tensor of shape [batch_size]. True indicates the prompt at that index has reached end condition.
generated_sequence_lengths (torch.Tensor) – An int tensor of shape [batch_size]. Each value represents the generated sequence length for that prompt.
- Returns:
Returns the updated boolean is_generation_done_tensor and the updated generated_sequence_lengths.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
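The bookkeeping described above can be sketched in plain Python over lists; the real method operates on batched torch tensors, and this illustrative version checks only the termination-token condition:

```python
# Sketch of per-step end-condition bookkeeping. Sequences that have started
# generating are marked done when they emit termination_id; otherwise their
# generated length grows by one.
def update_generation_status(latest_tokens, generation_started,
                             is_generation_done, generated_lengths,
                             termination_id):
    for i, tok in enumerate(latest_tokens):
        # Prompts that have not started, or are already done, are untouched.
        if not generation_started[i] or is_generation_done[i]:
            continue
        if termination_id is not None and tok == termination_id:
            is_generation_done[i] = True
        else:
            generated_lengths[i] += 1
    return is_generation_done, generated_lengths
```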
- pad_input_prompt_tokens(
- batch_prompt_tokens_list: List[List[int]],
- padded_batch_size: int,
- padded_sequence_length: int,
Method to pad input prompts
Given a list of prompts, pad them all to uniform length
- Parameters:
batch_prompt_tokens_list (List[List[int]]) – A list containing the prompt tokens
padded_batch_size (int) – The maximum number of requests for this batch
padded_sequence_length (int) – The maximum number of input + output tokens for this batch
- Returns:
A torch tensor of shape [padded_batch_size, padded_sequence_length]
- Return type:
torch.Tensor
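The padding scheme can be sketched in plain Python. The pad token id of 0 is an assumption here (the real method uses the tokenizer's pad id), and the real return value is a torch tensor rather than nested lists:

```python
# Sketch of right-padding a ragged batch to fixed dimensions: each prompt is
# padded to padded_sequence_length, then the batch is padded to
# padded_batch_size with all-pad rows.
def pad_prompts(batch_prompt_tokens_list, padded_batch_size,
                padded_sequence_length, pad_id=0):
    padded = []
    for tokens in batch_prompt_tokens_list:
        padded.append(tokens + [pad_id] * (padded_sequence_length - len(tokens)))
    # Pad the batch dimension with rows of pad tokens.
    while len(padded) < padded_batch_size:
        padded.append([pad_id] * padded_sequence_length)
    return padded
```

The companion unpad step is then a simple slice back to the original batch size.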
- unpad_input_prompt_tokens(
- padded_batch_prompt_tokens: torch.Tensor,
- original_batch_size: int,
Truncates the given padded tokens tensor back to the original batch size.
- Parameters:
padded_batch_prompt_tokens (torch.Tensor) – The padded tokens tensor
original_batch_size (int) – The original batch size before padding
- _dynamic_step_context_init(
- construct_graph_dimensions: Optional[megatron.core.inference.batch_dimensions_utils.InferenceBatchDimensions] = None,
- is_dummy_forward: bool = False,
Initializes the inference context for dynamic batching.
- Parameters:
construct_graph_dimensions (Optional[InferenceBatchDimensions]) – The graph config to use for constructing the CUDA graphs.
is_dummy_forward (bool) – Whether we are running an expert parallel dummy forward pass
- Returns:
input_ids (Tensor): The active input IDs.
position_ids (Tensor): The active position IDs.
- Return type:
Tuple[Tensor, Tensor]
- _dynamic_step_forward_logits(
- input_ids: torch.Tensor,
- position_ids: torch.Tensor,
Forward step the model to get logits for dynamic batching.
This also handles logits-broadcasting for pipeline parallelism.
- Parameters:
input_ids (Tensor) – The input token IDs.
position_ids (Tensor) – The position IDs.
- _dynamic_step_sample_bookkeeping()#
Perform bookkeeping necessary to sample logits for dynamic batching.
- _dynamic_step_sample_logits(logits: torch.Tensor)#
Sample tokens from logits for dynamic batching.
- Parameters:
logits (Tensor) – The logits from the forward pass.
- _dynamic_step_log_probs_bookkeeping() Tuple[bool, bool]#
Perform bookkeeping necessary to compute log probs for dynamic batching.
- Returns:
Whether to return the sampled log_probs.
- Return type:
return_log_probs (bool)
- _router_record_bookkeeping() Optional[Dict[int, torch.Tensor]]#
Collect and map routing indices per request for MoE router recording.
This method retrieves recorded routing decisions and maps them to individual requests using the context’s request_ids and query_lengths. Uses the context’s routing_metadata when available (which handles CUDA graph static buffers automatically). Must be called while context attributes are still valid (before request transitions).
- Returns:
A dictionary mapping request_id to a tensor of shape [num_tokens, num_layers, topk]. Returns None if routing replay is disabled or no routing data was recorded.
- Return type:
Optional[Dict[int, Tensor]]
- _dynamic_step_calculate_log_probs(
- logits: torch.Tensor,
Calculate log probs from logits.
- _dynamic_step_calculate_top_n_logprobs(
- logits: torch.Tensor,
- log_probs_tensor: Optional[torch.Tensor] = None,
Calculate top-n log probs from logits for dynamic batching.
- Parameters:
logits (Tensor) – The logits to compute top-n log probs from.
log_probs_tensor (Optional[Tensor]) – Pre-computed log probabilities tensor. If provided, avoids recomputing log_softmax. Should be the tensor returned by calculate_log_probs.
- Returns:
A dictionary mapping request_idx to list of (top_n_logprobs, top_n_indices) tuples. Each tuple in the list represents one token position.
- dummy_forward()#
Perform a dummy forward pass. This is used in expert model parallelism on ranks that do not have any real requests.
- _dynamic_step_context_bookkeeping() Dict[str, torch.Tensor]#
Update the dynamic inference context after sampling.
- Parameters:
new_sample (Tensor) – The newly sampled tokens.
request_metadata (Optional[Dict[str, Tensor]]) – An override for the tensors that manage request metadata, such as sampling parameters. By default, this metadata is retrieved from the context.
- Returns:
A dictionary containing:
active_request_ids (Tensor): Current active request IDs.
newly_paused_request_ids (Tensor): Newly paused request IDs.
finished_request_ids (Tensor): Finished request IDs.
- Return type:
Dict[str, Tensor]
- async async_generate_output_tokens_dynamic_batch(
- skip_bookkeeping: Optional[bool] = False,
Forward step the model and update the inference context.
- Parameters:
skip_bookkeeping (Optional[bool]) – If true, skip the context bookkeeping step.
- Returns:
A dictionary containing:
active_request_ids (Tensor): Current active request IDs.
newly_paused_request_ids (Tensor): Newly paused request IDs.
finished_request_ids (Tensor): Finished request IDs.
sample (Tensor): The new sample.
log_probs (Optional[Tensor]): Log probabilities of the new sample, if requested.
cuda_graph_request_count (Optional[int]): Size of the CUDA graph used for this step.
- Return type:
Optional[Dict]
- generate_output_tokens_dynamic_batch(
- loop: Optional[asyncio.AbstractEventLoop] = None,
Synchronous wrapper for `async_generate_output_tokens_dynamic_batch`.
- _update_top_n_logprobs_dict(
- top_n_logprobs_this_step: torch.Tensor,
- top_n_logprobs_indices: torch.Tensor,
- mask: torch.Tensor,
- top_n_logprobs_dict: Dict[int, List[Dict[str, float]]],
Function to update the top_n_logprobs at each step
This function iterates over the top-n logprobs generated at each step and, for every batch entry that has started generating tokens, updates top_n_logprobs_dict with the decoded token (string) as the key and the logit as the value. The keys of top_n_logprobs_dict are batch indices; each value is a list in which every element is a dictionary mapping a decoded token to the logit generated at that step.
- Parameters:
top_n_logprobs_this_step (torch.Tensor) – The top n logprob values
top_n_logprobs_indices (torch.Tensor) – The indices corresponding to the top n logprobs
mask (torch.Tensor) – A mask to indicate which requests should append to the dict
top_n_logprobs_dict (Dict[int, List[Dict[str, float]]]) – The dict to be updated.
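The dict-update described above can be sketched in plain Python; the `decode` callable stands in for the controller's detokenizer, and the real method works with torch tensors rather than nested lists:

```python
# Sketch of per-step top-n logprob bookkeeping: for each unmasked batch
# index, append one {decoded_token: logprob} dict for this step.
from typing import Dict, List

def update_top_n_logprobs(values, indices, mask,
                          top_n_logprobs_dict: Dict[int, List[Dict[str, float]]],
                          decode):
    for batch_idx, active in enumerate(mask):
        # Skip requests that should not record logprobs this step.
        if not active:
            continue
        step = {decode(tok): val
                for tok, val in zip(indices[batch_idx], values[batch_idx])}
        top_n_logprobs_dict.setdefault(batch_idx, []).append(step)
    return top_n_logprobs_dict
```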
- generate_all_output_tokens_static_batch(
- active_requests: OrderedDict[int, megatron.core.inference.inference_request.InferenceRequest],
- active_streams: Optional[OrderedDict[str, megatron.core.inference.async_stream.AsyncStream]] = None,
Utility to generate all the output tokens and probabilities for the prompts.
This utility generates the output tokens for a static batch. It runs forward steps until all prompts complete generation, updates the status of these requests to completed, adds the generated results, and returns the requests.
- Parameters:
active_requests (OrderedDict[int, InferenceRequest]) – The input active requests.
- Returns:
The result for each of the incoming requests
- Return type:
OrderedDict[int, InferenceRequest]
- prep_inference_input(
- prompts_tokens: torch.Tensor,
- active_requests: OrderedDict[int, megatron.core.inference.inference_request.InferenceRequest],
- use_attention_mask: bool = False,
Prepares input data for inference using the respective wrapper's prep_inference_input method.
- Parameters:
prompts_tokens (torch.Tensor) – A tensor of shape [batch_size, max_sequence_length]
active_requests (OrderedDict[int, InferenceRequest]) – The input active requests
use_attention_mask (bool) – Whether to use an attention mask. Should be set to True only when exclusively doing prefill (no decode) with variable prompt lengths.
- Returns:
A dict of the inference input for the current batch.
- stream_tokens(
- sampling_params: megatron.core.inference.sampling_params.SamplingParams,
- request_ids: List[int],
- requests: List[megatron.core.inference.inference_request.InferenceRequest],
- streams: List[megatron.core.inference.async_stream.AsyncStream],
- generation_started: List[bool],
- is_generation_done: List[bool],
- tokens: torch.Tensor,
- prompt_lengths: List[int],
- generated_lengths: List[int],
- output_log_probs: Union[torch.Tensor, None],
Asynchronously streams tokens for the given requests.
- Parameters:
sampling_params (SamplingParams) – The sampling parameters.
request_ids (List[int]) – The request IDs.
requests (List[InferenceRequest]) – The requests.
streams (List[AsyncStream]) – The streams over which to send tokens.
generation_started (List[bool]) – Whether the decode step has started.
is_generation_done (List[bool]) – Whether generation has completed.
tokens (torch.Tensor) – The tokens for this request.
prompt_lengths (List[int]) – The number of prompt tokens for each request.
generated_lengths (List[int]) – The number of output tokens for each request.
output_log_probs (torch.Tensor, optional) – The log probs for each request.