core.inference.model_inference_wrappers.abstract_model_inference_wrapper#

Module Contents#

Classes#

AbstractModelInferenceWrapper

Abstract inference wrapper

Data#

API#

core.inference.model_inference_wrappers.abstract_model_inference_wrapper.DEPRECATED_ARGS#

['inference_wrapper_config', 'pg_collection']

class core.inference.model_inference_wrappers.abstract_model_inference_wrapper.AbstractModelInferenceWrapper(
model: Union[LegacyGPTModel, megatron.core.models.gpt.gpt_model.GPTModel],
inference_context: megatron.core.inference.contexts.BaseInferenceContext,
)#

Bases: abc.ABC

Abstract inference wrapper

Extend this to create a version for your model.

The wrapper prepares the model for inference, provides the required input data and runs the forward pass.

Parameters:
  • model (Union[GPTModel, LegacyGPTModel]) – The actual GPT model (MCore or MLM).

  • inference_context (BaseInferenceContext) – Context for managing KV cache and other inference params.

Initialization
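To make the contract concrete, here is a hypothetical, dependency-free sketch of the subclassing pattern. `AbstractModelInferenceWrapperSketch`, `MyGPTWrapper`, and `_StubModel` are illustrative names, not part of megatron.core; the real base class operates on `torch.Tensor` inputs and a real GPT model.

```python
# Illustrative sketch only: mirrors the abstract surface described above
# without importing megatron.core or torch.
import abc
from typing import Any, Dict


class AbstractModelInferenceWrapperSketch(abc.ABC):
    """Stand-in for AbstractModelInferenceWrapper (hypothetical)."""

    def __init__(self, model: Any, inference_context: Any) -> None:
        self.model = model
        self.inference_context = inference_context

    def prep_model_for_inference(self) -> None:
        # The real method puts the model in eval mode before the loop.
        self.model.eval()

    @abc.abstractmethod
    def prep_inference_input(self, prompt_tokens) -> Dict[str, Any]:
        ...

    @abc.abstractmethod
    def get_batch_for_context_window(self, *args, **kwargs) -> Dict[str, Any]:
        ...


class MyGPTWrapper(AbstractModelInferenceWrapperSketch):
    """A minimal concrete subclass (hypothetical)."""

    def prep_inference_input(self, prompt_tokens) -> Dict[str, Any]:
        return {"tokens": prompt_tokens}

    def get_batch_for_context_window(self, inference_input, start, end):
        return {"tokens": inference_input["tokens"][start:end]}


class _StubModel:
    """Stands in for a GPT model; only eval() is needed here."""

    def eval(self):
        self.training = False
        return self


wrapper = MyGPTWrapper(model=_StubModel(), inference_context=None)
wrapper.prep_model_for_inference()
batch = wrapper.prep_inference_input([1, 2, 3, 4])
```

Only the two abstract methods must be overridden; the forward-pass utilities below are inherited.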

prep_model_for_inference()#

A utility function for preparing the model for inference.

This function is called once before the autoregressive inference loop. It puts the model in eval mode.

abstractmethod prep_inference_input(prompt_tokens) Dict[str, Any]#

Prepares the inference input data.

Parameters:

prompt_tokens (torch.Tensor) – A tensor of shape [batch_size, max_seq_len].

Returns:

A dict with all the inference input needed for the batch.
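As an illustration of the kind of dict an override might assemble, here is a hypothetical sketch using plain Python lists; the real implementation builds `torch.Tensor` objects, and the exact keys depend on the concrete subclass.

```python
# Illustrative sketch (not the real implementation): assemble tokens,
# position ids, and a causal attention mask for a batch of prompts.
from typing import Any, Dict, List


def prep_inference_input(prompt_tokens: List[List[int]]) -> Dict[str, Any]:
    batch_size = len(prompt_tokens)
    max_seq_len = len(prompt_tokens[0])
    # Position ids: 0 .. max_seq_len-1 for every row of the batch.
    position_ids = [list(range(max_seq_len)) for _ in range(batch_size)]
    # Causal mask: position i may attend only to positions j <= i.
    attention_mask = [
        [j <= i for j in range(max_seq_len)] for i in range(max_seq_len)
    ]
    return {
        "tokens": prompt_tokens,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
```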

abstractmethod get_batch_for_context_window(
*args,
**kwargs,
) Dict[str, Any]#

Returns the input data for inference

This function is called iteratively in the inference loop. It can be used to extract the relevant inputs (prompt tokens, attention mask, etc.) required for each inference step.
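A hypothetical sketch of the slicing such an override typically performs: each per-step batch is a window over the full inference input. The function name and keys here are illustrative, and the real method slices `torch.Tensor` objects.

```python
# Illustrative sketch: slice every input field to the current
# context window [context_start, context_end).
from typing import Any, Dict


def get_batch_for_context_window(
    inference_input: Dict[str, Any],
    context_start: int,
    context_end: int,
) -> Dict[str, Any]:
    return {
        key: [row[context_start:context_end] for row in value]
        for key, value in inference_input.items()
    }
```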

_forward(inference_input)#

Runs a forward pass of the model.

Parameters:

inference_input (Dict[str, Any]) – The input data.

Returns:

The model output logits.

dummy_forward()#

Run a dummy forward pass through the model, with a single token. Use-case: Used in EP on ranks which do not have any work, but are needed for the all-to-all communication.

_get_batch_size_and_seq_len(
tokens: torch.Tensor,
recv_buffer_seq_len: Optional[int] = None,
)#

Returns the batch size and sequence length based on the tokens tensor and recv_buffer_seq_len.

Parameters:
  • tokens (torch.Tensor) – The input tensor of shape (batch_size, seq_len).

  • recv_buffer_seq_len (int, optional) – An optional recv buffer sequence length.

Returns:

A tuple (batch_size, seq_len), where batch_size is the first dimension of tokens and seq_len is either the second dimension or recv_buffer_seq_len.

Return type:

tuple
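The selection logic can be sketched as follows, operating on a plain shape tuple rather than a tensor (the real method reads `tokens.shape`); the function name here is illustrative.

```python
# Illustrative sketch: batch size comes from the first dimension;
# seq_len comes from the second unless recv_buffer_seq_len overrides it.
from typing import Optional, Tuple


def get_batch_size_and_seq_len(
    tokens_shape: Tuple[int, int],
    recv_buffer_seq_len: Optional[int] = None,
) -> Tuple[int, int]:
    batch_size, seq_len = tokens_shape
    if recv_buffer_seq_len is not None:
        seq_len = recv_buffer_seq_len
    return batch_size, seq_len
```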

_allocate_recv_buffer(batch_size, seq_len)#

Allocates the pipeline-parallel receive buffer. Receives between pipeline stages use tensors of shape [seq_len, batch_size, hidden_size].

forward_pass_without_pipeline_parallel(
inference_input: Dict[str, Any],
) torch.Tensor#

Utility to carry out a simple forward pass for models with tensor parallelism or no model parallelism

Runs a straightforward forward pass of the model. Used for models without any parallelism or with only tensor parallelism.

Parameters:

inference_input (Dict[str, Any]) – A dict containing the inputs for the GPT model [tokens, position ids, attention mask]

Returns:

The output logits of shape [batch_size, seq_len, padded_vocab_size]

Return type:

torch.Tensor

forward_pass_with_pipeline_parallel(
inference_input: Dict[str, Any],
recv_buffer_seq_len: Optional[int] = None,
) torch.Tensor#

Utility to carry out a forward pass for pipeline-parallel (PP) models

TODO: Add support for asynchronous microbatches

Parameters:
  • inference_input (Dict[str, Any]) – A dict containing the inputs for the GPT model [tokens, position ids, attention mask]

  • recv_buffer_seq_len (int) – An optional sequence length for the pipeline parallel recv buffer.

Returns:

The output logits of shape [batch_size, seq_len, padded_vocab_size]

Return type:

torch.Tensor

run_one_forward_step(
inference_input: Dict[str, Any],
recv_buffer_seq_len: Optional[int] = None,
) torch.Tensor#

The forward pass of the model for inference.

The appropriate utility is called for the forward pass depending on the type of model parallelism used.

Parameters:
  • inference_input (Dict[str, Any]) – A dict containing the inputs for the GPT model [tokens, position ids, attention mask]

  • recv_buffer_seq_len (int) – An optional sequence length for the pipeline parallel recv buffer.

Returns:

The output logits of shape [batch_size, seq_len, padded_vocab_size]. The logits are returned only in the last pipeline stage for PP models.

Return type:

torch.Tensor
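The dispatch described above can be sketched as a simple branch on the pipeline-parallel world size. This is a hypothetical illustration of the control flow only; the real method calls the two utilities documented above and queries megatron.core's parallel state.

```python
# Illustrative sketch: pick the forward-pass utility based on whether
# the model is split across pipeline stages (an assumed criterion).
def choose_forward_utility(pipeline_model_parallel_size: int) -> str:
    if pipeline_model_parallel_size > 1:
        # PP models need the stage-to-stage recv buffer path.
        return "forward_pass_with_pipeline_parallel"
    # TP-only or unparallelized models take the simple path.
    return "forward_pass_without_pipeline_parallel"
```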