core.inference.model_inference_wrappers.abstract_model_inference_wrapper#
Module Contents#
Classes#
AbstractModelInferenceWrapper – Abstract inference wrapper
API#
- class core.inference.model_inference_wrappers.abstract_model_inference_wrapper.AbstractModelInferenceWrapper(
- model: Union[LegacyGPTModel, megatron.core.models.gpt.gpt_model.GPTModel],
- inference_wrapper_config: megatron.core.inference.model_inference_wrappers.inference_wrapper_config.InferenceWrapperConfig,
- inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
Bases:
abc.ABC
Abstract inference wrapper.
Extend this to create a version for your model.
The wrapper prepares the model for inference, provides the required input data and runs the forward pass.
- Parameters:
model (Union[GPTModel, LegacyGPTModel]) – The actual GPT model (MCore or MLM).
inference_wrapper_config (InferenceWrapperConfig) – Configuration containing information such as the hidden size, vocabulary size, etc.
inference_context (BaseInferenceContext) – Context for managing KV cache and other inference params.
pg_collection (ProcessGroupCollection) – Process groups for model communication.
Initialization
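The wrapper pattern described above can be sketched with a self-contained toy example. All names here (`ToyInferenceWrapper`, `EchoWrapper`, the dict-based "model") are hypothetical stand-ins for illustration; the real class takes a `GPTModel` and an `InferenceWrapperConfig`, which are omitted for brevity:

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class ToyInferenceWrapper(ABC):
    """Hypothetical, simplified stand-in for AbstractModelInferenceWrapper."""

    def __init__(self, model: Dict[str, Any]):
        self.model = model

    def prep_model_for_inference(self) -> None:
        # Called once before the autoregressive loop: put the model in eval mode.
        self.model["training"] = False

    @abstractmethod
    def prep_inference_input(self, prompt_tokens: List[int]) -> Dict[str, Any]:
        """Subclasses build the per-batch input dict (tokens, position ids, mask, ...)."""

    def run_one_forward_step(self, inference_input: Dict[str, Any]) -> List[int]:
        # The real wrapper dispatches to TP/PP-specific utilities here.
        return self.model["fn"](inference_input["tokens"])


class EchoWrapper(ToyInferenceWrapper):
    """Minimal concrete subclass: only the abstract hook must be implemented."""

    def prep_inference_input(self, prompt_tokens: List[int]) -> Dict[str, Any]:
        return {"tokens": prompt_tokens}


# A toy "model": a dict holding a mode flag and a forward function.
model = {"training": True, "fn": lambda toks: [t + 1 for t in toks]}
wrapper = EchoWrapper(model)
wrapper.prep_model_for_inference()
logits = wrapper.run_one_forward_step(wrapper.prep_inference_input([1, 2, 3]))
```

The key design point is that the base class owns the inference loop plumbing while subclasses only supply `prep_inference_input` (and, in the real API, `get_batch_for_context_window`).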
- property inference_params#
Getter for the deprecated inference_params attribute.
- prep_model_for_inference(
- prompts_tokens: Optional[torch.Tensor] = None,
A utility function for preparing the model for inference.
This function is called once before the autoregressive inference loop; it puts the model in eval mode.
- Parameters:
prompts_tokens (torch.Tensor, optional) – Deprecated; will be removed in megatron-core 0.13.
- abstractmethod prep_inference_input(prompt_tokens) → Dict[str, Any]#
Prepares the inference input data.
- Parameters:
prompts_tokens (torch.Tensor) – A tensor of shape [batch_size, max_seq_len]
- Returns:
A dict with all the inference input needed for the batch.
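A minimal sketch of what an implementation of `prep_inference_input` might produce, using plain nested lists in place of `torch.Tensor` so the example is self-contained. The exact keys and mask convention are assumptions for illustration, not the library's mandated contract:

```python
from typing import Any, Dict, List


def prep_inference_input(prompt_tokens: List[List[int]]) -> Dict[str, Any]:
    """Illustrative only: builds position ids and a causal attention mask
    for a [batch_size, max_seq_len] batch of prompt tokens."""
    batch_size = len(prompt_tokens)
    seq_len = len(prompt_tokens[0])
    # One position id per token, per sequence in the batch.
    position_ids = [list(range(seq_len)) for _ in range(batch_size)]
    # Lower-triangular causal mask: position i may attend to positions j <= i.
    attention_mask = [[j <= i for j in range(seq_len)] for i in range(seq_len)]
    return {
        "tokens": prompt_tokens,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }


batch = prep_inference_input([[5, 6, 7], [8, 9, 10]])
```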
- abstractmethod get_batch_for_context_window(
- *args,
- **kwargs,
Returns the input data for inference.
This function is called iteratively in the inference loop. It can be used to extract the relevant inputs (e.g., prompt tokens, attention mask) required for each step of inference.
- _forward(inference_input)#
Runs a forward pass of the model.
- Parameters:
inference_input (Dict[str, Any]) – The input data.
- Returns:
The model output logits.
- dummy_forward()#
Run a dummy forward pass through the model with a single token. Use case: in expert parallelism (EP), ranks that have no work of their own still need to participate in the all-to-all communication.
- _get_batch_size_and_seq_len(
- tokens: torch.Tensor,
- recv_buffer_seq_len: Optional[int] = None,
Returns the batch size and sequence length based on the tokens tensor and recv_buffer_seq_len.
- Parameters:
tokens (torch.Tensor) – The input tensor of shape (batch_size, seq_len).
recv_buffer_seq_len (int, optional) – An optional recv buffer sequence length.
- Returns:
A tuple (batch_size, seq_len), where batch_size is the first dimension of tokens and seq_len is either the second dimension or recv_buffer_seq_len.
- Return type:
tuple
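The batch-size/sequence-length logic described above is simple enough to restate directly. This sketch uses nested lists in place of a `torch.Tensor` of shape `(batch_size, seq_len)`; the function name mirrors the documented method but is a free-standing illustration:

```python
from typing import List, Optional, Tuple


def get_batch_size_and_seq_len(
    tokens: List[List[int]],
    recv_buffer_seq_len: Optional[int] = None,
) -> Tuple[int, int]:
    """Illustrative re-implementation of the documented behaviour."""
    batch_size = len(tokens)
    # When a recv buffer length is given, it overrides the tokens' own
    # sequence length (the second dimension).
    seq_len = recv_buffer_seq_len if recv_buffer_seq_len is not None else len(tokens[0])
    return batch_size, seq_len
```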
- _allocate_recv_buffer(batch_size, seq_len)#
Allocates a receive buffer for pipeline-parallel communication; the tensor received between stages has shape [seq_len, batch_size, hidden_size].
- forward_pass_without_pipeline_parallel(
- inference_input: Dict[str, Any],
Utility to carry out a simple forward pass for tensor-parallel or non-model-parallel models.
Runs a straightforward forward pass of the model. Used for models without any parallelism, or with only tensor parallelism.
- Parameters:
inference_input (Dict[str, Any]) – A dict containing the inputs for the GPT model [tokens, position ids, attention mask]
- Returns:
The output logits of shape [batch_size, seq_len, padded_vocab_size]
- Return type:
torch.Tensor
- forward_pass_with_pipeline_parallel_small_input_batch(
- inference_input: Dict[str, Any],
- recv_buffer_seq_len: Optional[int] = None,
Utility to carry out the forward pass for pipeline-parallel (PP) models with very small inputs.
If a model is pipeline parallel yet the global input batch is very small, we compute a forward pass on the entire global batch rather than splitting it into micro batches, as the forward_pass_with_pipeline_parallel_large_input_batch method does.
- Parameters:
inference_input (Dict[str, Any]) – A dict containing the inputs for the gpt model [tokens, position ids, attention mask]
recv_buffer_seq_len (int) – An optional sequence length for the pipeline parallel recv buffer.
- Returns:
The output logits of shape [batch_size, seq_len, padded_vocab_size]
- Return type:
torch.Tensor
- forward_pass_with_pipeline_parallel_large_input_batch(
- inference_input: Dict[str, Any],
- recv_buffer_seq_len=None,
Utility to carry out the forward pass for PP models.
Runs the forward pass for pipeline-parallel models. This is more complex than forward_pass_with_pipeline_parallel_small_input_batch because it splits the global batch into small micro batches and runs them through the model.
- Parameters:
inference_input (Dict[str, Any]) – A dict containing the inputs for the GPT model [tokens, position ids, attention mask]
recv_buffer_seq_len (int) – An optional sequence length for the pipeline parallel recv buffer.
- Returns:
The output logits of shape [batch_size, seq_len, padded_vocab_size]
- Return type:
torch.Tensor
- run_one_forward_step(
- inference_input: Dict[str, Any],
- recv_buffer_seq_len: Optional[int] = None,
The forward pass of the model for inference.
The appropriate utility is called for the forward pass depending on the type of model parallelism used.
- Parameters:
inference_input (Dict[str, Any]) – A dict containing the inputs for the gpt model [tokens, position ids, attention mask]
recv_buffer_seq_len (int) – An optional sequence length for the pipeline parallel recv buffer.
- Returns:
The output logits of shape [batch_size, seq_len, padded_vocab_size]. The logits are returned only in the last pipeline stage for PP models.
- Return type:
torch.Tensor
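The dispatch performed by `run_one_forward_step` can be sketched as follows. The threshold parameter and the three path callables are hypothetical stand-ins for the documented `forward_pass_*` utilities; the actual selection criteria inside Megatron-Core may differ:

```python
from typing import Any, Callable, Dict

ForwardPath = Callable[[Dict[str, Any]], str]


def run_one_forward_step(
    inference_input: Dict[str, Any],
    pipeline_parallel_size: int,
    global_batch_size: int,
    micro_batch_threshold: int,
    no_pp_path: ForwardPath,
    small_pp_path: ForwardPath,
    large_pp_path: ForwardPath,
) -> str:
    """Illustrative dispatch between the three documented forward-pass utilities."""
    # No pipeline parallelism (TP-only or no parallelism): simple forward pass.
    if pipeline_parallel_size == 1:
        return no_pp_path(inference_input)
    # Small global batch: run it through the pipeline in one piece.
    if global_batch_size <= micro_batch_threshold:
        return small_pp_path(inference_input)
    # Large global batch: split into micro batches.
    return large_pp_path(inference_input)


# Stub paths that just report which branch ran.
chosen = run_one_forward_step(
    {}, 4, 1024, 256,
    lambda inp: "no_pp", lambda inp: "small_pp", lambda inp: "large_pp",
)
```

Note the documented caveat: for PP models the logits are only produced on the last pipeline stage, so callers on other stages should not expect a usable return value.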