core.inference.model_inference_wrappers.abstract_model_inference_wrapper#

Module Contents#

Classes#

AbstractModelInferenceWrapper

Abstract inference wrapper

API#

class core.inference.model_inference_wrappers.abstract_model_inference_wrapper.AbstractModelInferenceWrapper(
model: Union[LegacyGPTModel, megatron.core.models.gpt.gpt_model.GPTModel],
inference_wrapper_config: megatron.core.inference.model_inference_wrappers.inference_wrapper_config.InferenceWrapperConfig,
inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#

Bases: abc.ABC

Abstract inference wrapper

Extend this to create a version for your model.

The wrapper prepares the model for inference, provides the required input data and runs the forward pass.

Parameters:
  • model (Union[GPTModel, LegacyGPTModel]) – The actual GPT model (MCore or MLM).

  • inference_wrapper_config (InferenceWrapperConfig) – Has info like hidden size, vocab size etc.

  • inference_context (BaseInferenceContext) – Context for managing KV cache and other inference params.

  • pg_collection (ProcessGroupCollection) – Process groups for model communication.

Initialization
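The extension pattern described above can be sketched in plain Python. This is a hypothetical illustration only (lists stand in for torch tensors, and `MyModelInferenceWrapper` is an invented subclass name), not the library's actual code:

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class AbstractModelInferenceWrapper(ABC):
    """Hypothetical stand-in for the real wrapper (illustration only)."""

    def __init__(self, model):
        self.model = model

    def prep_model_for_inference(self):
        # In the real wrapper this is called once, before the
        # autoregressive inference loop, and puts the model in eval mode.
        pass

    @abstractmethod
    def prep_inference_input(self, prompt_tokens) -> Dict[str, Any]:
        """Prepare the inference input data for the batch."""

    @abstractmethod
    def get_batch_for_context_window(self, *args, **kwargs) -> Dict[str, Any]:
        """Extract the inputs needed for one step of the inference loop."""


class MyModelInferenceWrapper(AbstractModelInferenceWrapper):
    # Hypothetical subclass showing the two methods a concrete
    # wrapper must override.
    def prep_inference_input(self, prompt_tokens) -> Dict[str, Any]:
        return {"tokens": prompt_tokens}

    def get_batch_for_context_window(self, inference_input, start, end) -> Dict[str, Any]:
        return {"tokens": inference_input["tokens"][start:end]}
```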

property inference_params#

Getter for deprecated inference_params.

prep_model_for_inference(
prompts_tokens: Optional[torch.Tensor] = None,
)#

A utility function for preparing the model for inference.

The function gets called once before the autoregressive inference loop. It puts the model in eval mode.

Parameters:

prompts_tokens (torch.Tensor, optional) – Deprecated, will be removed in megatron-core 0.13

abstractmethod prep_inference_input(prompt_tokens) Dict[str, Any]#

Prepares the inference input data.

Parameters:

prompt_tokens (torch.Tensor) – A tensor of shape [batch_size, max_seq_len]

Returns:

A dict with all the inference input needed for the batch.

abstractmethod get_batch_for_context_window(
*args,
**kwargs,
) Dict[str, Any]#

Returns the input data for inference

This function gets called iteratively in the inference loop. It can be used to extract the relevant inputs (prompt tokens, attention mask, etc.) required for each step of inference.

_forward(inference_input)#

Runs a forward pass of the model.

Parameters:

inference_input (Dict[str, Any]) – The input data.

Returns:

The model output logits.

dummy_forward()#

Run a dummy forward pass through the model with a single token. Use case: in expert parallelism (EP), ranks that have no work of their own still need to participate in the all-to-all communication.

_get_batch_size_and_seq_len(
tokens: torch.Tensor,
recv_buffer_seq_len: Optional[int] = None,
)#

Returns the batch size and sequence length based on the tokens tensor and recv_buffer_seq_len.

Parameters:
  • tokens (torch.Tensor) – The input tensor of shape (batch_size, seq_len).

  • recv_buffer_seq_len (int, optional) – An optional recv buffer sequence length.

Returns:

A tuple (batch_size, seq_len), where batch_size is the first dimension of tokens and seq_len is either the second dimension or recv_buffer_seq_len.

Return type:

tuple
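The described fallback behavior can be illustrated with a short pure-Python sketch (not the library's actual code; a plain shape tuple stands in for the tokens tensor):

```python
from typing import Optional, Tuple


def get_batch_size_and_seq_len(
    shape: Tuple[int, int],
    recv_buffer_seq_len: Optional[int] = None,
) -> Tuple[int, int]:
    # batch_size is always the first dimension of the tokens tensor;
    # seq_len is the second dimension unless a recv buffer length
    # overrides it.
    batch_size, seq_len = shape[0], shape[1]
    if recv_buffer_seq_len is not None:
        seq_len = recv_buffer_seq_len
    return batch_size, seq_len
```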

_allocate_recv_buffer(batch_size, seq_len)#

Allocates the buffer for receiving activations between pipeline stages; the buffer has shape [seq_len, batch_size, hidden_size].
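The sequence-major buffer layout can be made explicit with a tiny helper (a hypothetical illustration, assuming the shape convention stated above):

```python
from typing import Tuple


def recv_buffer_shape(batch_size: int, seq_len: int, hidden_size: int) -> Tuple[int, int, int]:
    # Recv buffers are sequence-major: [seq_len, batch_size, hidden_size],
    # not [batch_size, seq_len, hidden_size].
    return (seq_len, batch_size, hidden_size)
```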

forward_pass_without_pipeline_parallel(
inference_input: Dict[str, Any],
) torch.Tensor#

Utility to carry out a simple forward pass for models with tensor parallelism only or no model parallelism.

Runs a simple forward pass for the model. Used for models without any parallelism or with only tensor parallelism.

Parameters:

inference_input (Dict[str, Any]) – A dict containing the inputs for the gpt model [tokens, position ids, attention mask]

Returns:

The output logits of shape [batch_size, seq_len, padded_vocab_size]

Return type:

torch.Tensor

forward_pass_with_pipeline_parallel_small_input_batch(
inference_input: Dict[str, Any],
recv_buffer_seq_len: Optional[int] = None,
) torch.Tensor#

Utility to carry out the forward pass for PP models with very small inputs.

If a model is pipeline parallel but the input global batch is very small, we compute a forward pass on the entire global batch rather than splitting it into micro batches, avoiding the more complex handling in the forward_pass_with_pipeline_parallel_large_input_batch method.

Parameters:
  • inference_input (Dict[str, Any]) – A dict containing the inputs for the gpt model [tokens, position ids, attention mask]

  • recv_buffer_seq_len (int) – An optional sequence length for the pipeline parallel recv buffer.

Returns:

The output logits of shape [batch_size, seq_len, padded_vocab_size]

Return type:

torch.Tensor

forward_pass_with_pipeline_parallel_large_input_batch(
inference_input: Dict[str, Any],
recv_buffer_seq_len=None,
) torch.Tensor#

Utility to carry out the forward pass for PP models.

Runs the forward pass for models that are pipeline parallel. This is more complex than forward_pass_with_pipeline_parallel_small_input_batch because it splits the global batch into small micro batches and runs them through the model.

Parameters:
  • inference_input (Dict[str, Any]) – A dict containing the inputs for the gpt model [tokens, position ids, attention mask]

  • recv_buffer_seq_len (int) – An optional sequence length for the pipeline parallel recv buffer.

Returns:

The output logits of shape [batch_size, seq_len, padded_vocab_size]

Return type:

torch.Tensor
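The micro-batching described above can be sketched in isolation. This is an illustrative helper (not the library's code), assuming the global batch is split along the batch dimension into fixed-size chunks:

```python
from typing import List


def split_into_micro_batches(tokens: list, micro_batch_size: int) -> List[list]:
    # Slice the global batch along the batch dimension; the last
    # micro batch may be smaller than micro_batch_size.
    return [
        tokens[i:i + micro_batch_size]
        for i in range(0, len(tokens), micro_batch_size)
    ]
```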

run_one_forward_step(
inference_input: Dict[str, Any],
recv_buffer_seq_len: Optional[int] = None,
) torch.Tensor#

The forward pass of the model for inference.

The appropriate utility is called for the forward pass depending on the type of model parallelism used.

Parameters:
  • inference_input (Dict[str, Any]) – A dict containing the inputs for the gpt model [tokens, position ids, attention mask]

  • recv_buffer_seq_len (int) – An optional sequence length for the pipeline parallel recv buffer.

Returns:

The output logits of shape [batch_size, seq_len, padded_vocab_size]. The logits are returned only in the last pipeline stage for PP models.

Return type:

torch.Tensor