core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper
Module Contents
Classes
GPTInferenceWrapper – Inference wrapper for GPT model.
Data
API
- core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper.DEPRECATED_ARGS
['inference_wrapper_config', 'pg_collection']
- class core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper.GPTInferenceWrapper(
- model: megatron.core.models.gpt.GPTModel,
- inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- )
Bases:
megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper.AbstractModelInferenceWrapper
Inference wrapper for GPT model.
The wrapper prepares the model for inference, provides the required input data, and runs the forward pass.
- Parameters:
model (GPTModel) – The GPT model (MCore or legacy)
inference_context (BaseInferenceContext) – Manages the KV cache and tracks sequence/token/batch offsets.
Initialization
- prep_inference_input(
- prompts_tokens: torch.Tensor,
- )
Prepares the inference input data.
- Parameters:
prompts_tokens (torch.Tensor) – A tensor of shape [batch_size, max_seq_len]
- Returns:
A dict with all the inference input needed for the batch.
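To make the return value concrete, here is a minimal sketch of the kind of dict this method assembles. The key names (`tokens`, `attention_mask`, `position_ids`) are assumptions for illustration, not the wrapper's documented contract.

```python
import torch


def prep_inference_input_sketch(prompts_tokens: torch.Tensor) -> dict:
    """Hypothetical sketch of the inference-input dict; key names are
    assumptions, not the exact Megatron-Core contract."""
    batch_size, max_seq_len = prompts_tokens.shape
    # Causal mask (True = position is masked out) and per-row position ids,
    # precomputed once for the full max_seq_len so later decode steps can
    # slice into them.
    attention_mask = (
        torch.tril(torch.ones((1, 1, max_seq_len, max_seq_len))) < 0.5
    )
    position_ids = (
        torch.arange(max_seq_len).unsqueeze(0).expand(batch_size, -1)
    )
    return {
        "tokens": prompts_tokens,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }
```

Precomputing these once per batch lets each subsequent decode step reuse slices of the same tensors instead of rebuilding them.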
- _build_attention_mask_and_position_ids(
- prompts_tokens: torch.Tensor,
- )
Builds the full attention mask and position ids for the input tokens.
- Parameters:
prompts_tokens (torch.Tensor) – A tensor of shape [batch_size, max_seq_len]
- Returns:
The attention mask of shape [1, 1, max_seq_len, max_seq_len] and position ids of shape [batch_size, max_seq_len]
- Return type:
Tuple[torch.Tensor, torch.Tensor]
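The construction described above can be sketched as follows; this is a plausible implementation under the documented shapes, and the mask convention (True marks positions the model must not attend to) is an assumption that may differ from the actual method.

```python
import torch


def build_attention_mask_and_position_ids_sketch(prompts_tokens):
    """Hypothetical sketch: causal mask of shape [1, 1, L, L] and position
    ids of shape [B, L], matching the documented return shapes."""
    batch_size, max_seq_len = prompts_tokens.shape
    # Lower-triangular ones, then invert via "< 0.5": entries above the
    # diagonal become True, i.e. future tokens are masked out.
    attention_mask = (
        torch.tril(torch.ones((1, 1, max_seq_len, max_seq_len))) < 0.5
    )
    # Position ids are simply 0..max_seq_len-1, repeated per batch row.
    position_ids = (
        torch.arange(max_seq_len).unsqueeze(0).expand(batch_size, -1)
    )
    return attention_mask, position_ids
```

The mask's leading `[1, 1, ...]` dimensions broadcast over batch and attention heads, so one mask serves the whole batch.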
- get_batch_for_context_window(
- inference_input: Dict[str, Any],
- context_start_position: int,
- context_end_position: int,
- )
Returns the inference data for the given context window.
This function is called iteratively in a loop. Given the start and end context positions, it extracts the appropriate data.
- Parameters:
inference_input (Dict[str, Any]) – The inference input for the batch.
context_start_position (int) – Start of the context window. During the first inference step it is typically 0.
context_end_position (int) – End of the context window. During the last inference step it is typically the maximum generated sequence length.
- Returns:
A dict of inputs that will be used by the model in the forward step.
- Return type:
Dict[str, Any]
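A minimal sketch of the windowing step: slice the precomputed tokens, position ids, and mask down to the current context window. The key names mirror the sketch conventions above and are assumptions, as is the exact mask-cropping scheme.

```python
import torch


def get_batch_for_context_window_sketch(
    inference_input, context_start_position, context_end_position
):
    """Hypothetical sketch of extracting the per-step batch from the
    precomputed full-length inference inputs."""
    tokens = inference_input["tokens"]
    position_ids = inference_input["position_ids"]
    attention_mask = inference_input["attention_mask"]
    # Only the tokens/positions inside the window are fed to the model.
    # The mask is cropped so its query dim matches the window size while
    # its key dim covers everything up to the window end.
    return {
        "tokens": tokens[:, context_start_position:context_end_position],
        "position_ids": position_ids[
            :, context_start_position:context_end_position
        ],
        "attention_mask": attention_mask[
            ..., context_start_position:context_end_position,
            :context_end_position,
        ],
    }
```

On the first call the window usually spans the whole prompt (prefill); on later calls it advances one token at a time (decode), so the sliced tensors shrink accordingly.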