core.inference.model_inference_wrappers.t5.t5_inference_wrapper#
Module Contents#
Classes#
T5InferenceWrapper – Inference wrapper for T5 model.
API#
- class core.inference.model_inference_wrappers.t5.t5_inference_wrapper.T5InferenceWrapper(
- model: megatron.core.models.T5.T5Model,
- inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- use_local: bool = False,
Bases:
megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper.AbstractModelInferenceWrapper

Inference wrapper for T5 model.
The wrapper prepares the model for inference, provides the required input data, and runs the forward pass.
- Parameters:
model (T5Model) – The T5 model (MCore or legacy)
inference_context (BaseInferenceContext) – Manages KV cache, and tracks sequence/token/batch offsets.
use_local (bool) – Whether the T5 model’s transformer impl is local (vs transformer_engine)
Initialization
- prep_inference_input(
- prompts_tokens: torch.Tensor,
- encoder_prompts: Optional[List[str]] = None,
- tokenizer: Any = None,
Prepares the inference input data.
- Parameters:
prompts_tokens (torch.Tensor) – A tensor of shape [batch_size, max_seq_len]
encoder_prompts (List[str]) – List of encoder input prompt strings
tokenizer (type) – Tokenizer used for tokenizing and detokenizing text
- Returns:
A dict with all the inference input needed for the batch.
- tokenize_encoder_prompt(
- encoder_prompt: str,
- tokenizer,
Utility to tokenize the encoder_prompt
- Parameters:
encoder_prompt (str) – The encoder_prompt
tokenizer (type) – Tokenizer used for tokenizing and detokenizing string
- Returns:
Returns the tokenized prompt
- Return type:
torch.Tensor
- pad_encoder_prompts_tokens(
- encoder_prompts_tokens_list: List[List[int]],
- max_sequence_length: int,
- tokenizer,
Method to pad input prompts.
Given a list of prompts, pads them all to a uniform length.
- Parameters:
encoder_prompts_tokens_list (List[List[int]]) – A list containing the encoder_input_tokens
max_sequence_length (int) – Maximum length of the encoder input tokens
tokenizer (type) – Tokenizer used for tokenizing and detokenizing text
- Returns:
A torch tensor of shape [batch_size, max_sequence_length]
- Return type:
torch.Tensor
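The padding step can be sketched in plain Python. This is an illustrative sketch, not the Megatron implementation: `pad_id` stands in for the tokenizer's pad token id, and plain lists stand in for the torch tensor the real method returns.

```python
# Illustrative sketch of right-padding token lists to a uniform length.
# `pad_id` is an assumed stand-in for the tokenizer's pad token id.
from typing import List


def pad_prompts_tokens(
    prompts_tokens_list: List[List[int]],
    max_sequence_length: int,
    pad_id: int,
) -> List[List[int]]:
    """Pad every token list on the right so all rows share one length."""
    return [
        tokens + [pad_id] * (max_sequence_length - len(tokens))
        for tokens in prompts_tokens_list
    ]


rows = pad_prompts_tokens([[5, 7], [1, 2, 3]], max_sequence_length=4, pad_id=0)
# rows == [[5, 7, 0, 0], [1, 2, 3, 0]]
```

In the real method the padded rows are stacked into a single tensor of shape [batch_size, max_sequence_length].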
- get_batch_for_context_window(
- inference_input: Dict[str, Any],
- context_start_position: int,
- context_end_position: int,
Returns the inference data for the given context window.
This function is called iteratively in a loop. Given the start and end context positions, it extracts the appropriate data.
- Parameters:
inference_input (Dict[str, Any]) – The inference input for the batch.
context_start_position (int) – Start of the context window. During the first inference step it is typically 0.
context_end_position (int) – End of the context window. During the last inference step it is typically the maximum generated sequence length.
- Returns:
A dict of inputs that will be used by your model in the forward step
- Return type:
Dict
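The windowing idea can be sketched as follows. This is a simplified illustration, not the wrapper's actual code: the key names ("tokens", "position_ids") are assumptions, and plain nested lists stand in for the torch tensors the real method slices.

```python
# Illustrative sketch: slice every per-token field of a batch dict
# down to the [context_start_position, context_end_position) window.
# Key names are assumed for illustration; lists stand in for tensors.
from typing import Any, Dict, List


def slice_context_window(
    inference_input: Dict[str, List[List[int]]],
    context_start_position: int,
    context_end_position: int,
) -> Dict[str, List[List[int]]]:
    """Return the same dict with each row cut to the context window."""
    return {
        key: [row[context_start_position:context_end_position] for row in rows]
        for key, rows in inference_input.items()
    }


batch = {
    "tokens": [[11, 12, 13, 14], [21, 22, 23, 24]],
    "position_ids": [[0, 1, 2, 3], [0, 1, 2, 3]],
}
window = slice_context_window(batch, 0, 2)
# window["tokens"] == [[11, 12], [21, 22]]
```

Each decoding iteration widens the window, so successive calls hand the forward step progressively longer slices of the batch.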
- forward_pass_without_pipeline_parallel(
- inference_input: Dict[str, Any],
Utility to carry out a simple forward pass for models with tensor parallelism or no model parallelism.
Runs a simple forward pass for the model. Used for models without any parallelism, or with tensor parallelism only.
- Parameters:
inference_input (Dict[str, Any]) – A dict containing the inputs for the model [tokens, position ids, attention mask]
- Returns:
The output logits of shape [batch_size, seq_len, padded_vocab_size]
- Return type:
torch.Tensor