core.inference.model_inference_wrappers.t5.t5_inference_wrapper#

Module Contents#

Classes#

T5InferenceWrapper

Inference wrapper for T5 model.

API#

class core.inference.model_inference_wrappers.t5.t5_inference_wrapper.T5InferenceWrapper(
model: megatron.core.models.T5.T5Model,
inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
use_local: bool = False,
)#

Bases: megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper.AbstractModelInferenceWrapper

Inference wrapper for T5 model.

The wrapper prepares the model for inference, provides the required input data, and runs the forward pass.

Parameters:
  • model (T5Model) – The T5 model (MCore or legacy)

  • inference_context (BaseInferenceContext) – Manages the KV cache and tracks sequence/token/batch offsets.

  • use_local (bool) – Whether the T5 model’s transformer impl is local (vs transformer_engine)

Initialization

prep_inference_input(
prompts_tokens: torch.Tensor,
encoder_prompts: Optional[List[str]] = None,
tokenizer: Any = None,
) Dict[str, Any]#

Prepares the inference input data.

Parameters:
  • prompts_tokens (torch.Tensor) – A tensor of shape [batch_size, max_seq_len]

  • encoder_prompts (List[str]) – List of encoder input prompt strings

  • tokenizer (Any) – Tokenizer used for tokenizing and detokenizing text

Returns:

A dict with all the inference input needed for the batch.
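The exact keys of the returned dict are internal to the wrapper; as a rough illustration, the flow can be sketched like this (the key names and the plain-list stand-ins for torch tensors are assumptions for the sketch, not the Megatron implementation):

```python
def prep_inference_input(prompts_tokens, encoder_prompts, tokenize, pad_id):
    """Sketch: tokenize each encoder prompt, pad to a common length,
    and bundle everything into one dict for the forward step."""
    encoder_tokens = [tokenize(p) for p in encoder_prompts]
    max_len = max(len(t) for t in encoder_tokens)
    # right-pad every encoder prompt to the batch maximum
    encoder_batch = [t + [pad_id] * (max_len - len(t)) for t in encoder_tokens]
    return {
        "tokens": prompts_tokens,         # decoder prompt tokens
        "encoder_tokens": encoder_batch,  # assumed key name, for illustration only
    }

out = prep_inference_input(
    [[1, 2]],
    ["a b", "a"],
    lambda s: [ord(c) for c in s.replace(" ", "")],
    pad_id=0,
)
# out["encoder_tokens"] == [[97, 98], [97, 0]]
```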

tokenize_encoder_prompt(
encoder_prompt: str,
tokenizer,
) torch.Tensor#

Utility to tokenize the encoder_prompt.

Parameters:
  • encoder_prompt (str) – The encoder_prompt

  • tokenizer (Any) – Tokenizer used for tokenizing and detokenizing text

Returns:

Returns the tokenized prompt

Return type:

torch.Tensor
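The tokenizer argument is whatever tokenizer object the caller supplies. A hypothetical whitespace tokenizer illustrates the expected shape of the interaction (the real wrapper returns a torch.Tensor of ids, not a plain list):

```python
class ToyTokenizer:
    """Hypothetical stand-in for the tokenizer argument; splits on
    whitespace and assigns ids on first sight."""

    def __init__(self):
        self.vocab = {}

    def tokenize(self, text):
        # assign the next free id to each unseen word
        return [self.vocab.setdefault(w, len(self.vocab)) for w in text.split()]

tok = ToyTokenizer()
ids = tok.tokenize("translate English to German : hello")
# ids == [0, 1, 2, 3, 4, 5]
```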

pad_encoder_prompts_tokens(
encoder_prompts_tokens_list: List[List[int]],
max_sequence_length: int,
tokenizer,
) torch.Tensor#

Method to pad input prompts

Given a list of prompts, pads them all to a uniform length.

Parameters:
  • encoder_prompts_tokens_list (List[List[int]]) – A list containing the encoder_input_tokens

  • max_sequence_length (int) – Maximum length of the encoder input tokens

  • tokenizer (Any) – Tokenizer used for tokenizing and detokenizing text

Returns:

A torch tensor of shape [batch_size, max_sequence_length]

Return type:

torch.Tensor
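A minimal sketch of the padding step, assuming right-padding with the tokenizer's pad id (plain lists stand in for the torch.Tensor the real method returns):

```python
def pad_encoder_prompts_tokens(encoder_prompts_tokens_list, max_sequence_length, pad_id):
    """Sketch (not the Megatron implementation): right-pad each token
    list to max_sequence_length with pad_id."""
    padded = []
    for tokens in encoder_prompts_tokens_list:
        assert len(tokens) <= max_sequence_length
        padded.append(tokens + [pad_id] * (max_sequence_length - len(tokens)))
    return padded

batch = pad_encoder_prompts_tokens([[5, 6], [7]], max_sequence_length=4, pad_id=0)
# batch == [[5, 6, 0, 0], [7, 0, 0, 0]]
```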

get_batch_for_context_window(
inference_input: Dict[str, Any],
context_start_position: int,
context_end_position: int,
) Dict[str, Any]#

Returns the inference data for the given context window

This function is called iteratively in a loop. Given the start and end context positions, it extracts the appropriate data.

Parameters:
  • inference_input (Dict[str, Any]) – The inference input for the batch.

  • context_start_position (int) – Start of the context window. During the first inference step it is typically 0

  • context_end_position (int) – End of the context window. During the last inference step it will typically be the maximum generated sequence length.

Returns:

A dict of inputs that will be used by your model in the forward step

Return type:

Dict
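The slicing pattern can be sketched as follows (the "tokens" key and the list-of-lists layout are illustrative assumptions; the real method operates on torch tensors and additional fields):

```python
def get_batch_for_context_window(inference_input, context_start_position, context_end_position):
    """Sketch: slice the columns [start:end) out of every batched field
    so the forward step only sees the current context window."""
    return {
        key: [row[context_start_position:context_end_position] for row in value]
        for key, value in inference_input.items()
    }

batch = get_batch_for_context_window(
    {"tokens": [[1, 2, 3, 4], [5, 6, 7, 8]]},  # assumed key, toy data
    context_start_position=0,
    context_end_position=2,
)
# batch["tokens"] == [[1, 2], [5, 6]]
```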

forward_pass_without_pipeline_parallel(
inference_input: Dict[str, Any],
) torch.Tensor#

Utility to carry out a simple forward pass for tensor-parallel or non-model-parallel models

Runs a simple forward pass for the model. Used for models without any parallelism or with only tensor parallelism.

Parameters:

inference_input (Dict[str, Any]) – A dict containing the inputs for the model (tokens, position IDs, attention mask)

Returns:

The output logits of shape [batch_size, seq_len, padded_vocab_size]

Return type:

torch.Tensor