core.inference.model_inference_wrappers.abstract_model_inference_wrapper#
Module Contents#
Classes#
AbstractModelInferenceWrapper: Abstract inference wrapper
Data#
API#
- core.inference.model_inference_wrappers.abstract_model_inference_wrapper.DEPRECATED_ARGS#
['inference_wrapper_config', 'pg_collection']
- class core.inference.model_inference_wrappers.abstract_model_inference_wrapper.AbstractModelInferenceWrapper(
- model: Union[LegacyGPTModel, megatron.core.models.gpt.gpt_model.GPTModel],
- inference_context: megatron.core.inference.contexts.BaseInferenceContext,
Bases:
abc.ABC
Abstract inference wrapper
Extend this to create a version for your model.
The wrapper prepares the model for inference, provides the required input data and runs the forward pass.
- Parameters:
model (Union[GPTModel, LegacyGPTModel]) – The actual GPT model (MCore or MLM).
inference_context (BaseInferenceContext) – Context for managing KV cache and other inference params.
Initialization
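The extension pattern the class asks for can be sketched with a simplified, hypothetical stand-in (the names `SimpleWrapperBase` and `EchoWrapper` are illustrative; the real class takes a GPTModel and a BaseInferenceContext):

```python
from abc import ABC, abstractmethod
from typing import Any, Dict

# Hypothetical, simplified stand-in for the wrapper interface described above.
# A bare object with an eval() method is enough to show the extension pattern.
class SimpleWrapperBase(ABC):
    def __init__(self, model):
        self.model = model

    def prep_model_for_inference(self):
        # Called once before the autoregressive loop; puts the model in eval mode.
        self.model.eval()

    @abstractmethod
    def prep_inference_input(self, prompts_tokens) -> Dict[str, Any]:
        """Return everything the forward pass needs for this batch."""

class EchoWrapper(SimpleWrapperBase):
    def prep_inference_input(self, prompts_tokens) -> Dict[str, Any]:
        # A trivial concrete implementation for illustration.
        return {"tokens": prompts_tokens}
```

A concrete subclass only needs to implement the abstract methods; the shared preparation logic lives in the base class.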
- prep_model_for_inference()#
A utility function for preparing the model for inference.
The function gets called once before the autoregressive inference loop. It puts the model in eval mode.
- abstractmethod prep_inference_input(prompts_tokens) → Dict[str, Any]#
Prepares the inference input data.
- Parameters:
prompts_tokens (torch.Tensor) – A tensor of prompt tokens of shape [batch_size, max_seq_len].
- Returns:
A dict with all the inference input needed for the batch.
- abstractmethod get_batch_for_context_window(
- *args,
- **kwargs,
Returns the input data for inference.
This function gets called iteratively in the inference loop. It can be used to extract the relevant input (prompt tokens, attention mask, etc.) required for each step of inference.
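One way such a per-step extraction might look, sketched with plain Python lists in place of torch tensors (`context_start`/`context_end` are hypothetical step bounds, not part of the real signature):

```python
from typing import Any, Dict, List

# Illustrative sketch only: slice out the token window one inference step needs.
def get_batch_for_context_window(prompt_tokens: List[List[int]],
                                 context_start: int,
                                 context_end: int) -> Dict[str, Any]:
    # Take the [context_start, context_end) window from every row of the batch.
    tokens2use = [row[context_start:context_end] for row in prompt_tokens]
    # Position ids for the selected window.
    position_ids = list(range(context_start, context_end))
    return {"tokens": tokens2use, "position_ids": position_ids}
```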
- _forward(inference_input)#
Runs a forward pass of the model.
- Parameters:
inference_input (Dict[str, Any]) – The input data.
- Returns:
The model output logits.
- dummy_forward()#
Run a dummy forward pass through the model with a single token. Use case: in expert parallelism (EP), ranks that have no work still need to participate in the all-to-all communication.
- _get_batch_size_and_seq_len(
- tokens: torch.Tensor,
- recv_buffer_seq_len: Optional[int] = None,
Returns the batch size and sequence length based on the tokens tensor and recv_buffer_seq_len.
- Parameters:
tokens (torch.Tensor) – The input tensor of shape (batch_size, seq_len).
recv_buffer_seq_len (int, optional) – An optional recv buffer sequence length.
- Returns:
A tuple (batch_size, seq_len), where batch_size is the first dimension of tokens and seq_len is either the second dimension or recv_buffer_seq_len.
- Return type:
tuple
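The shape logic described above amounts to the following sketch, which uses a plain (batch, seq) shape tuple in place of a torch.Tensor:

```python
from typing import Optional, Tuple

def get_batch_size_and_seq_len(
    tokens_shape: Tuple[int, int],
    recv_buffer_seq_len: Optional[int] = None,
) -> Tuple[int, int]:
    # batch_size always comes from the first dimension of the tokens tensor;
    # seq_len is overridden by recv_buffer_seq_len when one is supplied.
    batch_size, seq_len = tokens_shape
    if recv_buffer_seq_len is not None:
        seq_len = recv_buffer_seq_len
    return batch_size, seq_len
```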
- _allocate_recv_buffer(batch_size, seq_len)#
Allocates the recv buffer; receives between the layers happen with tensors of size [seq_len, batch_size, hidden_size].
- forward_pass_without_pipeline_parallel(
- inference_input: Dict[str, Any],
Utility to carry out a simple forward pass for models with tensor parallelism (TP) or no model parallelism.
Runs a very simple forward pass of the model. Used for models without any parallelism or with only tensor parallelism.
- Parameters:
inference_input (Dict[str, Any]) – A dict containing the inputs for the GPT model [tokens, position ids, attention mask]
- Returns:
The output logits of shape [batch_size, seq_len, padded_vocab_size]
- Return type:
torch.Tensor
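Conceptually, the non-pipeline path just unpacks the input dict and calls the model once; a minimal sketch (here `model` is any callable standing in for the GPT model):

```python
from typing import Any, Callable, Dict

def forward_without_pipeline_parallel(model: Callable[..., Any],
                                      inference_input: Dict[str, Any]) -> Any:
    # Single forward call; with tensor parallelism the call is identical on
    # every rank, so no send/recv orchestration is needed.
    return model(
        tokens=inference_input["tokens"],
        position_ids=inference_input["position_ids"],
        attention_mask=inference_input["attention_mask"],
    )
```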
- forward_pass_with_pipeline_parallel(
- inference_input: Dict[str, Any],
- recv_buffer_seq_len: Optional[int] = None,
Utility to carry out the forward pass for pipeline-parallel (PP) models
TODO: Add support for asynchronous microbatches
- Parameters:
inference_input (Dict[str, Any]) – A dict containing the inputs for the GPT model [tokens, position ids, attention mask]
recv_buffer_seq_len (int, optional) – An optional sequence length for the pipeline-parallel recv buffer.
- Returns:
The output logits of shape [batch_size, seq_len, padded_vocab_size]
- Return type:
torch.Tensor
- run_one_forward_step(
- inference_input: Dict[str, Any],
- recv_buffer_seq_len: Optional[int] = None,
The forward pass of the model for inference.
The appropriate utility is called for the forward pass depending on the type of model parallelism used.
- Parameters:
inference_input (Dict[str, Any]) – A dict containing the inputs for the GPT model [tokens, position ids, attention mask]
recv_buffer_seq_len (int, optional) – An optional sequence length for the pipeline-parallel recv buffer.
- Returns:
The output logits of shape [batch_size, seq_len, padded_vocab_size]. The logits are returned only in the last pipeline stage for PP models.
- Return type:
torch.Tensor
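The dispatch described above can be sketched as follows; `pp_world_size` is a hypothetical stand-in for the parallel-state query the real code performs:

```python
from typing import Any, Dict, Optional

def run_one_forward_step(wrapper: Any,
                         inference_input: Dict[str, Any],
                         pp_world_size: int,
                         recv_buffer_seq_len: Optional[int] = None) -> Any:
    # Choose the pipeline-parallel path only when more than one PP stage exists.
    if pp_world_size > 1:
        return wrapper.forward_pass_with_pipeline_parallel(
            inference_input, recv_buffer_seq_len)
    return wrapper.forward_pass_without_pipeline_parallel(inference_input)
```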