pipeline_parallel package

This package contains implementations for two different pipeline parallelism schedules (one without interleaving and one with interleaving, see Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM for details), and a default no-pipelining schedule. It also contains methods for the point-to-point communication that is needed between pipeline stages.

Contains implementations of the various point-to-point communication operations (e.g., recv_forward and recv_backward) needed by the different pipeline parallelism schedules.

core.pipeline_parallel.p2p_communication.recv_backward(tensor_shape: Union[List[int], torch.Size], config: megatron.core.ModelParallelConfig) → torch.Tensor

Receive tensor from next rank in pipeline (backward receive).

See _communicate for argument details.

core.pipeline_parallel.p2p_communication.recv_forward(tensor_shape: Union[List[int], torch.Size], config: megatron.core.ModelParallelConfig) → torch.Tensor

Receive tensor from previous rank in pipeline (forward receive).

See _communicate for argument details.

core.pipeline_parallel.p2p_communication.send_backward(input_tensor_grad: torch.Tensor, config: megatron.core.ModelParallelConfig) → None

Send tensor to previous rank in pipeline (backward send).

See _communicate for argument details.

core.pipeline_parallel.p2p_communication.send_backward_recv_backward(input_tensor_grad: torch.Tensor, recv_next: bool, tensor_shape: Union[List[int], torch.Size], config: megatron.core.ModelParallelConfig, overlap_p2p_comm: bool = False) → torch.Tensor

Batched recv from next rank and send to previous rank in pipeline.

See _communicate for argument details.

core.pipeline_parallel.p2p_communication.send_backward_recv_forward(input_tensor_grad: torch.Tensor, tensor_shape: Union[List[int], torch.Size], config: megatron.core.ModelParallelConfig) → torch.Tensor

Batched send and recv with previous rank in pipeline.

See _communicate for argument details.

core.pipeline_parallel.p2p_communication.send_forward(output_tensor: torch.Tensor, config: megatron.core.ModelParallelConfig) → None

Send tensor to next rank in pipeline (forward send).

See _communicate for argument details.

core.pipeline_parallel.p2p_communication.send_forward_backward_recv_forward_backward(output_tensor: torch.Tensor, input_tensor_grad: torch.Tensor, recv_prev: bool, recv_next: bool, tensor_shape: Union[List[int], torch.Size], config: megatron.core.ModelParallelConfig) → torch.Tensor

Batched send and recv with previous and next ranks in pipeline.

See _communicate for argument details.

core.pipeline_parallel.p2p_communication.send_forward_recv_backward(output_tensor: torch.Tensor, tensor_shape: Union[List[int], torch.Size], config: megatron.core.ModelParallelConfig) → torch.Tensor

Batched send and recv with next rank in pipeline.

See _communicate for argument details.

core.pipeline_parallel.p2p_communication.send_forward_recv_forward(output_tensor: torch.Tensor, recv_prev: bool, tensor_shape: Union[List[int], torch.Size], config: megatron.core.ModelParallelConfig, overlap_p2p_comm: bool = False) → torch.Tensor

Batched recv from previous rank and send to next rank in pipeline.

See _communicate for argument details.
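As a hedged illustration of how these primitives fit together, the sketch below strings recv_forward, send_forward, recv_backward, and send_backward into a single forward/backward pass for one microbatch on an intermediate pipeline stage. The forward_fn, backward_fn, tensor_shape, and config objects are placeholders, not part of this module; only the four p2p calls use the signatures documented above.

from megatron.core.pipeline_parallel import p2p_communication

# Hedged sketch: one microbatch on an intermediate pipeline stage, using the
# p2p_communication primitives documented above. forward_fn, backward_fn,
# tensor_shape, and config are hypothetical placeholders.
def run_one_microbatch(forward_fn, backward_fn, tensor_shape, config):
    # Receive activations from the previous stage and run this stage's forward pass.
    input_tensor = p2p_communication.recv_forward(tensor_shape, config)
    output_tensor = forward_fn(input_tensor)

    # Send activations to the next stage, then wait for its gradient.
    p2p_communication.send_forward(output_tensor, config)
    output_tensor_grad = p2p_communication.recv_backward(tensor_shape, config)

    # Run this stage's backward pass and pass the input gradient upstream.
    input_tensor_grad = backward_fn(input_tensor, output_tensor, output_tensor_grad)
    p2p_communication.send_backward(input_tensor_grad, config)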

Contains implementations for two pipeline parallelism schedules (forward_backward_pipelining_with_interleaving for pipeline parallelism with interleaving, forward_backward_pipelining_without_interleaving for pipeline parallelism without interleaving) and a default no-pipelining schedule (forward_backward_no_pipelining). get_forward_backward_func returns the right scheduling function to use based on the configuration being trained (e.g., if pipeline-parallel size is 1, use forward_backward_no_pipelining).
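A minimal sketch of that selection rule, assuming the parallel_state helpers get_pipeline_model_parallel_world_size and get_virtual_pipeline_model_parallel_world_size; this illustrates the dispatch logic described above, not necessarily the library's exact code.

from megatron.core import parallel_state
from megatron.core.pipeline_parallel.schedules import (
    forward_backward_no_pipelining,
    forward_backward_pipelining_with_interleaving,
    forward_backward_pipelining_without_interleaving,
)

# Sketch of the schedule-selection rule described above (illustrative only).
def pick_schedule():
    if parallel_state.get_pipeline_model_parallel_world_size() > 1:
        # Interleaving requires a virtual pipeline parallel size to be configured.
        if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
            return forward_backward_pipelining_with_interleaving
        return forward_backward_pipelining_without_interleaving
    return forward_backward_no_pipelining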

core.pipeline_parallel.schedules.backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)

Backward step through passed-in output tensor.

If last stage, output_tensor_grad is None, otherwise gradient of loss with respect to stage’s output tensor.

Returns gradient of loss with respect to input tensor (None if first stage).
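The contract above can be pictured with a small sketch (not the library's implementation): the backward pass is seeded with output_tensor_grad, which is None on the last stage because there the output tensor is the loss itself, and the gradient flowing into the stage's input is what gets handed to the previous stage.

import torch

# Illustrative sketch of the backward_step contract (not the actual implementation).
def backward_step_sketch(input_tensor, output_tensor, output_tensor_grad):
    if input_tensor is not None:
        input_tensor.retain_grad()
    # On the last stage output_tensor is the loss and output_tensor_grad is None,
    # so autograd seeds the backward pass from the scalar loss itself.
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
    # The first stage has no upstream input tensor, hence returns None.
    return input_tensor.grad if input_tensor is not None else None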

core.pipeline_parallel.schedules.check_first_val_step(first_val_step, forward_only, cond)

core.pipeline_parallel.schedules.custom_backward(output, grad_output)

Directly call C++ autograd engine.

To make the deallocate_output_tensor optimization (described below) work, the C++ autograd engine must be called directly, bypassing PyTorch's torch.autograd.backward. PyTorch's backward checks that the output and grad have the same shape, while the C++ backward does not.
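As a rough sketch of what "calling the C++ engine directly" means, the snippet below uses PyTorch's internal Variable._execution_engine, which is an implementation detail that may change between PyTorch versions; treat the exact keyword arguments as an assumption rather than a stable API.

import torch
from torch.autograd import Variable

# Rough sketch: call the C++ autograd engine directly, bypassing the Python-side
# shape check in torch.autograd.backward. The keyword arguments below are an
# assumption about an internal, version-dependent PyTorch interface.
def custom_backward_sketch(output, grad_output):
    if grad_output is None:
        # Seed with ones when backpropagating from a scalar loss.
        grad_output = torch.ones_like(output)
    Variable._execution_engine.run_backward(
        tensors=(output,),
        grad_tensors=(grad_output,),
        keep_graph=False,
        create_graph=False,
        inputs=tuple(),
        allow_unreachable=True,
        accumulate_grad=True,
    )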

core.pipeline_parallel.schedules.deallocate_output_tensor(out, deallocate_pipeline_outputs=False)

Pseudo-deallocate (i.e., set to scalar) the output tensor’s ‘.data’ field.

This method should be called right after the output tensor has been sent to the next pipeline stage. At this point, the output tensor is only useful for its ‘.grad_fn’ field, and not its ‘.data’.
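In other words, the optimization keeps the autograd graph reachable through .grad_fn while freeing the activation storage. A minimal sketch of that idea, assuming the tensor does not share storage with another tensor:

import torch

# Minimal sketch of pseudo-deallocation: replace the (already sent) activation
# storage with a single-element placeholder while keeping .grad_fn intact.
def pseudo_deallocate(out: torch.Tensor) -> None:
    assert out._base is None, "only safe for tensors that do not share storage"
    out.data = torch.empty((1,), device=out.device, dtype=out.dtype)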

core.pipeline_parallel.schedules.forward_backward_no_pipelining(*, forward_step_func, data_iterator: Union[Iterator, List[Iterator]], model: Union[torch.nn.Module, List[torch.nn.Module]], num_microbatches: int, seq_length: int, micro_batch_size: int, decoder_seq_length: Optional[int] = None, forward_only: bool = False, collect_non_loss_data: bool = False, first_val_step: Optional[bool] = None)

Run forward and backward passes with no pipeline parallelism (no inter-stage communication).

Returns dictionary with losses.

See get_forward_backward_func() for argument details.
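Conceptually this schedule is just a loop over microbatches with no inter-stage communication; a hedged sketch follows, where the helper names are placeholders for the module's forward_step/backward_step machinery.

# Hedged sketch of the no-pipelining schedule: run every microbatch's forward
# (and optionally backward) pass locally, with no p2p communication.
def no_pipelining_sketch(run_forward, run_backward, num_microbatches, forward_only=False):
    forward_data_store = []
    for _ in range(num_microbatches):
        output_tensor = run_forward()       # placeholder for forward_step(...)
        forward_data_store.append(output_tensor)
        if not forward_only:
            run_backward(output_tensor)     # placeholder for backward_step(...)
    return forward_data_store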

core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving(*, forward_step_func, data_iterator: Union[Iterator, List[Iterator]], model: Union[torch.nn.Module, List[torch.nn.Module]], num_microbatches: int, seq_length: int, micro_batch_size: int, decoder_seq_length: Optional[int] = None, forward_only: bool = False, collect_non_loss_data: bool = False, first_val_step: Optional[bool] = None)

Run interleaved 1F1B schedule (model split into model chunks), with communication between pipeline stages as needed.

Returns dictionary with losses if the last stage, empty dict otherwise.

core.pipeline_parallel.schedules.forward_backward_pipelining_without_interleaving(*, forward_step_func, data_iterator: Union[Iterator, List[Iterator]], model: Union[torch.nn.Module, List[torch.nn.Module]], num_microbatches: int, seq_length: int, micro_batch_size: int, decoder_seq_length: Optional[int] = None, forward_only: bool = False, collect_non_loss_data: bool = False, first_val_step: Optional[bool] = None)

Run non-interleaved 1F1B schedule, with communication between pipeline stages.

Returns dictionary with losses if the last stage, empty dict otherwise.
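The non-interleaved 1F1B schedule can be summarized as three phases: a warmup phase of forward passes, a steady-state phase that alternates one forward with one backward (hence 1F1B), and a cooldown phase that drains the remaining backward passes. A schematic sketch of the phase arithmetic for a stage at pipeline rank rank out of pipeline_parallel_size stages; the do_forward/do_backward helpers are placeholders, not the library's implementation.

# Schematic sketch of the non-interleaved 1F1B phases (placeholder helpers).
def one_f_one_b_sketch(rank, pipeline_parallel_size, num_microbatches,
                       do_forward, do_backward):
    # Later stages need fewer warmup forwards before they can start 1F1B.
    num_warmup = min(pipeline_parallel_size - rank - 1, num_microbatches)
    num_steady = num_microbatches - num_warmup

    for _ in range(num_warmup):      # warmup: forwards only
        do_forward()
    for _ in range(num_steady):      # steady state: one forward, one backward
        do_forward()
        do_backward()
    for _ in range(num_warmup):      # cooldown: drain remaining backwards
        do_backward()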

core.pipeline_parallel.schedules.forward_step(forward_step_func, data_iterator, model, num_microbatches, input_tensor, forward_data_store, config, collect_non_loss_data=False, checkpoint_activations_microbatch=None, is_first_microbatch=False)

Forward step for passed-in model.

If first stage, input tensor is obtained from data_iterator, otherwise passed-in input_tensor is used.

Returns output tensor.
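A hedged sketch of the control flow described above (placeholder logic, not the library's exact code): the first stage reads its input from the data iterator inside forward_step_func, later stages use the tensor received from the previous stage, and only the last stage applies the loss function returned by forward_step_func.

# Illustrative control flow for forward_step (placeholders, not the real code).
def forward_step_sketch(forward_step_func, data_iterator, model, input_tensor,
                        is_first_stage, is_last_stage, forward_data_store):
    if not is_first_stage:
        # Assumes the (Megatron) model exposes a set_input_tensor hook so later
        # stages consume the tensor received from the previous stage.
        model.set_input_tensor(input_tensor)
    output_tensor, loss_func = forward_step_func(data_iterator, model)
    if is_last_stage:
        # Only the last stage evaluates the loss function returned above.
        loss, loss_reduced = loss_func(output_tensor)
        forward_data_store.append(loss_reduced)
        output_tensor = loss
    return output_tensor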

core.pipeline_parallel.schedules.get_forward_backward_func()

Retrieves the appropriate forward_backward function given the configuration of parallel_state.

Returns a function that will perform all of the forward and backward passes of the model given the pipeline model parallel world size and virtual pipeline model parallel world size in the global parallel_state.

Note that if using sequence parallelism, the sequence length component of the tensor shape is updated to original_sequence_length / tensor_model_parallel_world_size.
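A small worked example of that adjustment (values are illustrative):

# Illustrative values only: with sequence parallelism, the communicated tensor's
# sequence dimension is the original sequence length divided by the
# tensor-model-parallel world size.
original_sequence_length = 4096
tensor_model_parallel_world_size = 8
communicated_sequence_length = original_sequence_length // tensor_model_parallel_world_size  # 512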

The function returned takes the following arguments:

forward_step_func (required): A function that takes a data iterator and a model as its arguments and returns the model's forward output and the loss function. The loss function should take one torch.Tensor and return a torch.Tensor of loss and a dictionary of string -> torch.Tensor.

A third argument, checkpoint_activations_microbatch, indicates that the activations for this microbatch should be checkpointed. A None value for this argument indicates that the default from the configuration should be used. This is used when num_microbatches_with_partial_activation_checkpoints is set.

For example:

import torch
from functools import partial

def loss_func(loss_mask, output_tensor):
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss for logging.
    # average_losses_across_data_parallel_group is a Megatron training utility.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {'lm loss': averaged_loss[0]}

def forward_step(data_iterator, model):
    data, loss_mask = next(data_iterator)
    output = model(data)
    return output, partial(loss_func, loss_mask)

forward_backward_func(forward_step_func=forward_step, ...)

data_iterator (required): an iterator over the data, which will be passed as-is to forward_step_func. Expected to be a list of iterators in the case of interleaved pipeline parallelism.

model (required): the actual model. Expected to be a list of modules in the case of interleaved pipeline parallelism. Must be a (potentially wrapped) megatron.core.models.MegatronModule.

num_microbatches (int, required): The number of microbatches to go through.

seq_length (int, required): Sequence length of the current global batch. If this is a dual-stack transformer, this is the encoder's sequence length. This is ignored if variable_seq_lengths in the config is True. Otherwise, each microbatch in the current global batch must use this sequence length.

micro_batch_size (int, required): The number of sequences in a microbatch.

decoder_seq_length (int, optional): The sequence length for the decoder in a dual-stack transformer. This is ignored for a single-stack transformer.

forward_only (optional, default = False): Perform only the forward step.

collect_non_loss_data (optional, bool, default=False): TODO

first_val_step (bool, optional): Is the first step of the validation phase. Used by Transformer Engine modules to update their fp8 weights only on the first validation step.
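Putting the argument list together, a hedged usage sketch: the iterator, model, and batch sizes are placeholders; the keyword names follow the list above and the forward_step example earlier.

from megatron.core.pipeline_parallel.schedules import get_forward_backward_func

# Hedged usage sketch: data/model/sizes are placeholders, keyword names as above.
forward_backward_func = get_forward_backward_func()
losses_reduced = forward_backward_func(
    forward_step_func=forward_step,      # as in the example above
    data_iterator=train_data_iterator,   # hypothetical; a list of iterators if interleaved
    model=model,                         # hypothetical (possibly wrapped) MegatronModule
    num_microbatches=8,
    seq_length=2048,
    micro_batch_size=4,
    decoder_seq_length=None,
    forward_only=False,
)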

core.pipeline_parallel.schedules.get_tensor_shapes(*, rank: int, model_type: megatron.core.enums.ModelType, seq_length: int, micro_batch_size: int, decoder_seq_length: int, config)

core.pipeline_parallel.schedules.recv_backward(tensor_shapes, config)

core.pipeline_parallel.schedules.recv_forward(tensor_shapes, config)

core.pipeline_parallel.schedules.send_backward(input_tensor_grads, tensor_shapes, config)

core.pipeline_parallel.schedules.send_backward_recv_forward(input_tensor_grads, tensor_shapes, config)

core.pipeline_parallel.schedules.send_forward(output_tensors, tensor_shapes, config)

core.pipeline_parallel.schedules.send_forward_recv_backward(output_tensors, tensor_shapes, config)
