core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper#

Module Contents#

Classes#

GPTInferenceWrapper

Inference wrapper for GPT model.

Data#

API#

core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper.DEPRECATED_ARGS#

['inference_wrapper_config', 'pg_collection']

class core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper.GPTInferenceWrapper(
model: megatron.core.models.gpt.GPTModel,
inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
)#

Bases: megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper.AbstractModelInferenceWrapper

Inference wrapper for GPT model.

The wrapper prepares the model for inference, provides the required input data, and runs the forward pass.

Parameters:
  • model (GPTModel) – The GPT model (MCore or legacy)

  • inference_context (BaseInferenceContext) – Manages the KV cache and tracks sequence/token/batch offsets.

Initialization

prep_inference_input(
prompts_tokens: torch.Tensor,
) → Dict[str, Any]#

Prepares the inference input data.

Parameters:

prompts_tokens (torch.Tensor) – A tensor of shape [batch_size, max_seq_len]

Returns:

A dict with all the inference input needed for the batch.
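The exact keys of the returned dict are not listed here, so the sketch below is an assumption based on the tensor shapes documented for this wrapper; the function name `build_inference_input` and the key names are illustrative, not part of the API.

```python
import torch


def build_inference_input(prompts_tokens: torch.Tensor) -> dict:
    """Illustrative stand-in for prep_inference_input (key names are assumptions)."""
    batch_size, max_seq_len = prompts_tokens.shape
    # Causal mask of shape [1, 1, max_seq_len, max_seq_len];
    # here True marks positions that must not be attended to (an assumed convention).
    attention_mask = torch.tril(torch.ones(1, 1, max_seq_len, max_seq_len)) < 0.5
    # Position ids simply enumerate token positions per sample.
    position_ids = torch.arange(max_seq_len).unsqueeze(0).expand(batch_size, -1)
    return {
        "tokens": prompts_tokens,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }
```

Everything the forward step needs for the batch travels in this single dict, which is what allows `get_batch_for_context_window` below to slice it uniformly.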

_build_attention_mask_and_position_ids(
prompts_tokens: torch.Tensor,
) → Tuple[torch.Tensor, torch.Tensor]#

Builds the full attention mask and position ids for the input tokens.

Parameters:

prompts_tokens (torch.Tensor) – A tensor of shape [batch_size, max_seq_len]

Returns:

The attention mask of shape [1, 1, max_seq_len, max_seq_len] and position ids of shape [batch_size, max_seq_len]

Return type:

Tuple[torch.Tensor, torch.Tensor]
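A minimal sketch of how a mask and position ids with the documented shapes can be built; the boolean convention (True marks masked positions) is an assumption, not something this page states.

```python
import torch


def build_attention_mask_and_position_ids(prompts_tokens: torch.Tensor):
    """Sketch: causal mask [1, 1, s, s] plus position ids [b, s]."""
    batch_size, max_seq_len = prompts_tokens.shape
    # Lower-triangular causal structure, then inverted so True marks
    # disallowed (future) positions -- an assumed convention.
    attention_mask = torch.tril(torch.ones(1, 1, max_seq_len, max_seq_len)) < 0.5
    position_ids = (
        torch.arange(max_seq_len, dtype=torch.long)
        .unsqueeze(0)
        .expand(batch_size, -1)
    )
    return attention_mask, position_ids
```

The mask is batch-independent (leading dims are 1), so it can be broadcast across all samples, while position ids are expanded per sample.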

get_batch_for_context_window(
inference_input: Dict[str, Any],
context_start_position: int,
context_end_position: int,
) → Dict[str, Any]#

Returns the inference data for a given context window.

This method is called iteratively during generation. Given the start and end context positions, it extracts the corresponding slice of the input data.

Parameters:
  • inference_input (Dict[str, Any]) – The inference input for the batch.

  • context_start_position (int) – Start of the context window; typically 0 during the first inference step.

  • context_end_position (int) – End of the context window; typically the maximum generated sequence length during the last inference step.

Returns:

A dict of inputs used by the model in the forward step.

Return type:

Dict[str, Any]
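The slicing below is a hedged reconstruction of what such a method might do, assuming the input dict holds `tokens`, `position_ids`, and an `attention_mask` shaped as documented above; the key names and the exact mask slicing are assumptions.

```python
import torch


def get_batch_for_context_window(inference_input: dict, start: int, end: int) -> dict:
    """Sketch: slice each batch tensor down to the active context window."""
    return {
        # Only the tokens and positions inside the window are fed forward.
        "tokens": inference_input["tokens"][:, start:end],
        "position_ids": inference_input["position_ids"][:, start:end],
        # Query rows cover the window; key columns cover everything up to `end`.
        "attention_mask": inference_input["attention_mask"][..., start:end, :end],
    }
```

In a generation loop the first call would pass the full prompt window (e.g. `0` to prompt length), and later calls would pass one-token windows as new tokens are appended.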