pipeline_parallel package
This package contains implementations for two different pipeline parallelism schedules (one without interleaving and one with interleaving, see Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM for details), and a default no-pipelining schedule. It also contains methods for the point-to-point communication that is needed between pipeline stages.
Contains implementations of the point-to-point communication primitives (e.g., recv_forward and recv_backward) needed by the different pipeline parallelism schedules.
- core.pipeline_parallel.p2p_communication.recv_backward(tensor_shape: Union[List[int], torch.Size], config: megatron.core.ModelParallelConfig) → torch.Tensor
Receive tensor from next rank in pipeline (backward receive).
See _communicate for argument details.
- core.pipeline_parallel.p2p_communication.recv_forward(tensor_shape: Union[List[int], torch.Size], config: megatron.core.ModelParallelConfig) → torch.Tensor
Receive tensor from previous rank in pipeline (forward receive).
See _communicate for argument details.
- core.pipeline_parallel.p2p_communication.send_backward(input_tensor_grad: torch.Tensor, config: megatron.core.ModelParallelConfig) → None
Send tensor to previous rank in pipeline (backward send).
See _communicate for argument details.
- core.pipeline_parallel.p2p_communication.send_backward_recv_backward(input_tensor_grad: torch.Tensor, recv_next: bool, tensor_shape: Union[List[int], torch.Size], config: megatron.core.ModelParallelConfig, overlap_p2p_comm: bool = False) → torch.Tensor
Batched recv from next rank and send to previous rank in pipeline.
See _communicate for argument details.
- core.pipeline_parallel.p2p_communication.send_backward_recv_forward(input_tensor_grad: torch.Tensor, tensor_shape: Union[List[int], torch.Size], config: megatron.core.ModelParallelConfig) → torch.Tensor
Batched send and recv with previous rank in pipeline.
See _communicate for argument details.
- core.pipeline_parallel.p2p_communication.send_forward(output_tensor: torch.Tensor, config: megatron.core.ModelParallelConfig) → None
Send tensor to next rank in pipeline (forward send).
See _communicate for argument details.
- core.pipeline_parallel.p2p_communication.send_forward_backward_recv_forward_backward(output_tensor: torch.Tensor, input_tensor_grad: torch.Tensor, recv_prev: bool, recv_next: bool, tensor_shape: Union[List[int], torch.Size], config: megatron.core.ModelParallelConfig) → torch.Tensor
Batched send and recv with previous and next ranks in pipeline.
See _communicate for argument details.
- core.pipeline_parallel.p2p_communication.send_forward_recv_backward(output_tensor: torch.Tensor, tensor_shape: Union[List[int], torch.Size], config: megatron.core.ModelParallelConfig) → torch.Tensor
Batched send and recv with next rank in pipeline.
See _communicate for argument details.
- core.pipeline_parallel.p2p_communication.send_forward_recv_forward(output_tensor: torch.Tensor, recv_prev: bool, tensor_shape: Union[List[int], torch.Size], config: megatron.core.ModelParallelConfig, overlap_p2p_comm: bool = False) → torch.Tensor
Batched recv from previous rank and send to next rank in pipeline.
See _communicate for argument details.
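As a rough illustration of how these primitives fit together, an intermediate pipeline stage in a non-interleaved schedule might wrap one microbatch's forward and backward passes as sketched below. This is a minimal sketch, not code from this module; run_forward_pass, run_backward_pass, and shape are hypothetical placeholders.

from megatron.core.pipeline_parallel import p2p_communication

def one_microbatch(model, shape, config, run_forward_pass, run_backward_pass):
    # Receive activations from the previous stage (returns None on the first stage).
    input_tensor = p2p_communication.recv_forward(shape, config)

    # Compute this stage's activations and hand them to the next stage.
    output_tensor = run_forward_pass(model, input_tensor)
    p2p_communication.send_forward(output_tensor, config)

    # Receive the gradient w.r.t. this stage's output from the next stage ...
    output_tensor_grad = p2p_communication.recv_backward(shape, config)

    # ... backpropagate through this stage, then pass the gradient
    # w.r.t. this stage's input back to the previous stage.
    input_tensor_grad = run_backward_pass(input_tensor, output_tensor, output_tensor_grad)
    p2p_communication.send_backward(input_tensor_grad, config)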
Contains implementations for two pipeline parallelism schedules (forward_backward_pipelining_with_interleaving for pipeline parallelism with interleaving, forward_backward_pipelining_without_interleaving for pipeline parallelism without interleaving) and a default no-pipelining schedule (forward_backward_no_pipelining). get_forward_backward_func returns the right scheduling function to use based on the configuration being trained (e.g., if pipeline-parallel size is 1, use forward_backward_no_pipelining).
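Based on the description above, the selection logic in get_forward_backward_func is roughly the following (a simplified sketch; the real function reads the sizes from megatron.core.parallel_state and the exact checks may differ by version):

from megatron.core import parallel_state
from megatron.core.pipeline_parallel.schedules import (
    forward_backward_no_pipelining,
    forward_backward_pipelining_with_interleaving,
    forward_backward_pipelining_without_interleaving,
)

def get_forward_backward_func_sketch():
    if parallel_state.get_pipeline_model_parallel_world_size() > 1:
        # Virtual pipeline stages (model chunks) imply the interleaved schedule.
        if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
            return forward_backward_pipelining_with_interleaving
        return forward_backward_pipelining_without_interleaving
    # Pipeline-parallel size 1: no inter-stage communication is needed.
    return forward_backward_no_pipelining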
- core.pipeline_parallel.schedules.backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
Backward step through passed-in output tensor.
If this is the last stage, output_tensor_grad is None; otherwise it is the gradient of the loss with respect to the stage's output tensor.
Returns gradient of loss with respect to input tensor (None if first stage).
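Conceptually, the step runs autograd backward from the stage's output and reads off the gradient of the stage's input. The sketch below illustrates the idea for a single tensor; it is not the actual implementation, which also handles lists of tensors, model_type-specific cases, and config options such as deallocate_pipeline_outputs.

import torch

def backward_step_sketch(input_tensor, output_tensor, output_tensor_grad):
    if input_tensor is not None:
        input_tensor.retain_grad()

    # On the last stage, output_tensor is the (scaled) loss and
    # output_tensor_grad is None, so autograd uses an implicit gradient of 1.
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    # Gradient of the loss w.r.t. this stage's input; None on the first stage.
    return None if input_tensor is None else input_tensor.grad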
- core.pipeline_parallel.schedules.check_first_val_step(first_val_step, forward_only, cond)
Check if it is the first validation step.
- core.pipeline_parallel.schedules.clear_embedding_activation_buffer(config, model)
Clear embedding activation buffer.
- core.pipeline_parallel.schedules.custom_backward(output, grad_output)
Directly call C++ autograd engine.
To make the deallocate_output_tensor optimization (below) work, the C++ autograd engine must be called directly, bypassing PyTorch's torch.autograd.backward. PyTorch's backward checks that the output and the gradient have the same shape, while the C++ backward does not.
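A sketch of what such a direct call can look like is shown below. Variable._execution_engine.run_backward is an internal PyTorch API whose keyword arguments vary across PyTorch versions, so treat this purely as an illustration of the idea rather than the exact implementation.

import torch
from torch.autograd import Variable

def custom_backward_sketch(output, grad_output):
    if grad_output is None:
        # Implicit gradient of 1 for a scalar output (e.g., the loss).
        grad_output = torch.ones_like(output, memory_format=torch.preserve_format)

    # Call the C++ autograd engine directly, skipping the shape check that
    # torch.autograd.backward performs on (output, grad_output).
    Variable._execution_engine.run_backward(
        tensors=(output,),
        grad_tensors=(grad_output,),
        keep_graph=False,
        create_graph=False,
        inputs=tuple(),
        allow_unreachable=True,
        accumulate_grad=True,
    )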
- core.pipeline_parallel.schedules.deallocate_output_tensor(out, deallocate_pipeline_outputs=False)
Pseudo-deallocate (i.e., set to scalar) the output tensor’s ‘.data’ field.
This method should be called right after the output tensor has been sent to the next pipeline stage. At this point, the output tensor is only useful for its ‘.grad_fn’ field, and not its ‘.data’.
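In essence, the trick is to replace the tensor's storage with a one-element placeholder while leaving the autograd graph reachable through .grad_fn. A minimal sketch, assuming deallocation is enabled via deallocate_pipeline_outputs:

import torch

def deallocate_output_tensor_sketch(out, deallocate_pipeline_outputs=False):
    if (out is None) or (not deallocate_pipeline_outputs):
        return
    # Replace the storage with a 1-element tensor of the same dtype/device.
    # The graph hanging off out.grad_fn is untouched, so the tensor can still
    # serve as the root of a later backward pass (see custom_backward above).
    out.data = torch.empty((1,), device=out.device, dtype=out.dtype)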
- core.pipeline_parallel.schedules.finish_embedding_wgrad_compute(config, embedding_module)
Finish embedding wgrad compute.
- core.pipeline_parallel.schedules.forward_backward_no_pipelining(*, forward_step_func, data_iterator: Union[Iterator, List[Iterator]], model: Union[torch.nn.Module, List[torch.nn.Module]], num_microbatches: int, seq_length: int, micro_batch_size: int, decoder_seq_length: Optional[int] = None, forward_only: bool = False, collect_non_loss_data: bool = False, first_val_step: Optional[bool] = None)
Run forward and backward passes with no pipeline parallelism (no inter-stage communication).
Returns dictionary with losses.
See get_forward_backward_func() for argument details.
- core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving(*, forward_step_func, data_iterator: Union[Iterator, List[Iterator]], model: Union[torch.nn.Module, List[torch.nn.Module]], num_microbatches: int, seq_length: int, micro_batch_size: int, decoder_seq_length: Optional[int] = None, forward_only: bool = False, collect_non_loss_data: bool = False, first_val_step: Optional[bool] = None)
Run interleaved 1F1B schedule (model split into model chunks), with communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise.
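With interleaving, each pipeline rank hosts several model chunks, so model and data_iterator are passed as lists with one entry per local chunk. The call below is only a sketch with placeholder names (forward_step, the chunk modules, the iterators, and the sizes are assumptions); in practice the function is obtained via get_forward_backward_func rather than called directly.

# Interleaved schedule with a virtual pipeline size of 2 on this rank.
losses = forward_backward_func(
    forward_step_func=forward_step,
    data_iterator=[data_iter_chunk_0, data_iter_chunk_1],  # one iterator per chunk
    model=[model_chunk_0, model_chunk_1],                   # one module per chunk
    num_microbatches=8,
    seq_length=2048,
    micro_batch_size=4,
)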
- core.pipeline_parallel.schedules.forward_backward_pipelining_without_interleaving(*, forward_step_func, data_iterator: Union[Iterator, List[Iterator]], model: Union[torch.nn.Module, List[torch.nn.Module]], num_microbatches: int, seq_length: int, micro_batch_size: int, decoder_seq_length: Optional[int] = None, forward_only: bool = False, collect_non_loss_data: bool = False, first_val_step: Optional[bool] = None)
Run non-interleaved 1F1B schedule, with communication between pipeline stages. Returns dictionary with losses if the last stage, empty dict otherwise.
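Schematically, the non-interleaved 1F1B schedule consists of a warmup phase of forward passes, a steady-state phase that alternates one forward with one backward pass, and a cooldown phase that drains the remaining backward passes. The pseudocode below is only a structural sketch: communication and bookkeeping are omitted, do_forward and do_backward are placeholders, and num_warmup_microbatches depends on the rank's position in the pipeline.

def pipelining_without_interleaving_sketch(num_microbatches, num_warmup_microbatches,
                                           do_forward, do_backward):
    # Warmup: fill the pipeline with forward passes.
    for i in range(num_warmup_microbatches):
        do_forward(i)

    # Steady state: one forward pass followed by one backward pass (1F1B).
    num_remaining = num_microbatches - num_warmup_microbatches
    for i in range(num_remaining):
        do_forward(num_warmup_microbatches + i)
        do_backward(i)

    # Cooldown: drain the pipeline with the remaining backward passes.
    for i in range(num_remaining, num_microbatches):
        do_backward(i)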
- core.pipeline_parallel.schedules.forward_step(forward_step_func, data_iterator, model, num_microbatches, input_tensor, forward_data_store, config, collect_non_loss_data=False, checkpoint_activations_microbatch=None, is_first_microbatch=False, current_microbatch=None, encoder_decoder_xattn=False)
Forward step for passed-in model.
If it is the first stage, the input tensor is obtained from the data_iterator. Otherwise, the passed-in input_tensor is used.
- Parameters
forward_step_func (callable) – The forward step function for the model, which takes the data iterator as its first argument and the model as its second. This user-provided forward step is expected to output a tuple of two elements:
1. The output object from the forward step. This output object needs to be a tensor or some kind of collection of tensors. The only hard requirement for this object is that it must be acceptable as input into the second function.
2. A function to (optionally) reduce the output from the forward step. This could be a reduction over the loss from the model, a function that grabs the output from the model and reformats it, or a function that just passes through the model output. This function must follow one of the following patterns, and depending on the pattern different things happen internally (see the sketch after this entry):
   a. A tuple of reduced loss and some other data. Note that in this case the first element is divided by the number of global microbatches, assuming it is a loss, so that the loss is stable as a function of the number of devices the step is split across.
   b. A triple of reduced loss, number of tokens, and some other data. This is similar to case (a), but the loss is further averaged across the number of tokens in the batch. If the user is not already averaging across the number of tokens, this pattern is the one to use.
   c. Any arbitrary data the user wants (e.g., a dictionary of tensors, a list of tensors, etc., in the case of inference). To trigger case (c) you need to specify collect_non_loss_data=True, and you may also want to specify forward_only=True in the call to the parent forward_backward function.
data_iterator (iterator) – The data iterator.
model (nn.Module) – The model to perform the forward step on.
num_microbatches (int) – The number of microbatches.
input_tensor (Tensor or list[Tensor]) – The input tensor(s) for the forward step.
forward_data_store (list) – The list to store the forward data. If you go down path 2.a or 2.b for the return of your forward reduction function then this will store only the final dimension of the output, for example the metadata output by the loss function. If you go down the path of 2.c then this will store the entire output of the forward reduction function applied to the model output.
config (object) – The configuration object.
collect_non_loss_data (bool, optional) – Whether to collect non-loss data. This is the path to use if you want to collect arbitrary output from the model's forward pass, such as in inference use cases. Defaults to False.
checkpoint_activations_microbatch (int, optional) – The microbatch to checkpoint activations. Defaults to None.
is_first_microbatch (bool, optional) – Whether it is the first microbatch. Defaults to False.
current_microbatch (int, optional) – The current microbatch. Defaults to None.
- Returns
The output object(s) from the forward step, and the number of tokens (as a Tensor).
- Return type
Tensor or list[Tensor]
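To make patterns (a)-(c) of the forward-step reduction function concrete, the sketch below shows a pattern (b) style function that returns a (loss, num_tokens, logging data) triple. It is an illustrative example, not code from this module; the loss_mask and output_tensor names mirror the loss_func example shown further below.

import torch

def loss_func_with_token_count(loss_mask, output_tensor):
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()

    # Unreduced sum of per-token losses and the number of contributing tokens.
    loss = torch.sum(losses.view(-1) * loss_mask)
    num_tokens = loss_mask.sum().to(torch.int)

    # Pattern (b): (reduced loss, number of tokens, other data). The schedule
    # then averages the loss over the token count across microbatches.
    return loss, num_tokens, {'lm loss': loss.detach()}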
- core.pipeline_parallel.schedules.get_forward_backward_func()
Retrieves the appropriate forward_backward function given the configuration of parallel_state.
Returns a function that will perform all of the forward and backward passes of the model given the pipeline model parallel world size and virtual pipeline model parallel world size in the global parallel_state.
Note that if using sequence parallelism, the sequence length component of the tensor shape is updated to original_sequence_length / tensor_model_parallel_world_size.
The function returned takes the following arguments:
- forward_step_func (required): A function that takes a data iterator and a model as its arguments and returns the model's forward output and the loss function. The loss function should take one torch.Tensor and return a torch.Tensor of loss and a dictionary of string -> torch.Tensor.
A third argument, checkpoint_activations_microbatch, indicates whether the activations for this microbatch should be checkpointed. A None value for this argument indicates that the default from the configuration should be used. This is used when num_microbatches_with_partial_activation_checkpoints is set.
For example:
def loss_func(loss_mask, output_tensor):
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {'lm loss': averaged_loss[0]}

def forward_step(data_iterator, model):
    data, loss_mask = next(data_iterator)
    output = model(data)
    return output, partial(loss_func, loss_mask)

forward_backward_func(forward_step_func=forward_step, …)
- data_iterator (required): an iterator over the data, will be
passed as is to forward_step_func. Expected to be a list of iterators in the case of interleaved pipeline parallelism.
- model (required): the actual model. Expected to be a list of modules in the case of interleaved
pipeline parallelism. Must be a (potentially wrapped) megatron.core.models.MegatronModule.
- num_microbatches (int, required): The number of microbatches to go through.
- seq_length (int, required): Sequence length of the current global batch. If this is a dual-stack
transformer, this is the encoder’s sequence length. This is ignored if variable_seq_lengths in the config is True. Otherwise, each microbatch in the current global batch size must use this sequence length.
- micro_batch_size (int, required): The number of sequences in a microbatch.
- decoder_seq_length (int, optional): The sequence length for the decoder in a dual-stack
transformer. This is ignored for a single-stack transformer.
- forward_only (optional, default=False): Perform only the forward step.
- collect_non_loss_data (optional, bool, default=False): Whether to collect arbitrary (non-loss) output from the model's forward pass, e.g., for inference use cases.
- first_val_step (bool, optional): Whether this is the first step of the validation phase. Used by Transformer Engine modules to update their fp8 weights only on the first validation step.
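Putting the arguments together, a typical training-loop invocation looks roughly like the following. This is a sketch under assumed names: forward_step is the user function from the example above, and train_data_iterator, gpt_model, and the concrete sizes are placeholders from a hypothetical training script.

from megatron.core.pipeline_parallel import get_forward_backward_func

forward_backward_func = get_forward_backward_func()

losses_reduced = forward_backward_func(
    forward_step_func=forward_step,
    data_iterator=train_data_iterator,    # placeholder data iterator
    model=gpt_model,                       # placeholder (possibly wrapped) MegatronModule
    num_microbatches=8,
    seq_length=2048,
    micro_batch_size=4,
    decoder_seq_length=None,
    forward_only=False,
)
# On the last pipeline stage, losses_reduced holds the outputs of the loss
# reduction function per microbatch; on other stages it is empty.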
- core.pipeline_parallel.schedules.get_tensor_shapes(*, rank: int, model_type: megatron.core.enums.ModelType, seq_length: int, micro_batch_size: int, decoder_seq_length: int, config, encoder_decoder_xattn: bool)
Determine right tensor sizes (based on position of rank with respect to split rank) and model size. Send two tensors if model decoder requires the encoder’s output (via cross-attention) and rank is in decoder stage. First tensor is decoder. Second tensor is encoder. If model has an encoder & decoder and rank is at the boundary, send one tensor. Otherwise, send one tensor.
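For a plain decoder-only model this boils down to a single shape per microbatch. The sketch below is a simplification under assumed config fields (sequence_parallel, tensor_model_parallel_size, hidden_size) and ignores the encoder/decoder and context-parallel cases handled by the real function; the sequence-parallel division matches the note under get_forward_backward_func above.

def get_tensor_shapes_sketch(seq_length, micro_batch_size, config):
    # With sequence parallelism, activations are split along the sequence
    # dimension across the tensor-model-parallel group.
    if config.sequence_parallel:
        seq_length = seq_length // config.tensor_model_parallel_size

    # One (seq_length, micro_batch_size, hidden_size) tensor crosses each
    # pipeline-stage boundary per microbatch.
    return [(seq_length, micro_batch_size, config.hidden_size)]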
- core.pipeline_parallel.schedules.recv_backward(tensor_shapes, config)
Wrapper for p2p_communication.recv_backward used with non-interleaving schedule.
- core.pipeline_parallel.schedules.recv_forward(tensor_shapes, config)
Wrapper for p2p_communication.recv_forward used with non-interleaving schedule.
- core.pipeline_parallel.schedules.send_backward(input_tensor_grads, tensor_shapes, config)
Wrapper for p2p_communication.send_backward used with non-interleaving schedule.
- core.pipeline_parallel.schedules.send_backward_recv_forward(input_tensor_grads, tensor_shapes, config)
Wrapper for p2p_communication.send_backward_recv_forward used with non-interleaving schedule.
- core.pipeline_parallel.schedules.send_forward(output_tensors, tensor_shapes, config)
Wrapper for p2p_communication.send_forward used with non-interleaving schedule.
- core.pipeline_parallel.schedules.send_forward_recv_backward(output_tensors, tensor_shapes, config)
Wrapper for p2p_communication.send_forward_recv_backward used with non-interleaving schedule.
- core.pipeline_parallel.schedules.set_current_microbatch(model, microbatch_id)
Set the current microbatch.