core.pipeline_parallel.schedules#

Module Contents#

Functions#

get_forward_backward_func

Retrieves the appropriate forward_backward function given the configuration of parallel_state.

deallocate_output_tensor

Pseudo-deallocate (i.e., set to scalar) the output tensor’s ‘.data’ field.

custom_backward

Directly call C++ autograd engine.

set_current_microbatch

Set the current microbatch.

forward_step_calc_loss

Calculate the loss and number of tokens for forward_step()

forward_step

Forward step for passed-in model.

backward_step

Backward step through passed-in output tensor.

check_first_val_step

Check if it is the first validation step.

forward_backward_no_pipelining

Run forward and backward passes with no pipeline parallelism

clear_embedding_activation_buffer

Clear embedding activation buffer.

finish_embedding_wgrad_compute

Finish embedding wgrad compute.

get_pp_rank_microbatches

Get the number of total, warmup, and remaining microbatches in PP scheduling.

get_schedule_table

Get the schedule table for PP scheduling.

convert_schedule_table_to_order

Convert a tunable schedule lookup table to the order format accepted by te.make_graphed_callables().

forward_backward_pipelining_with_interleaving

Run interleaved 1F1B schedule (model split into model chunks), with communication between pipeline stages as needed.

get_tensor_shapes

Determine the right tensor sizes (based on the rank's position with respect to the split rank) and model size.

forward_backward_pipelining_without_interleaving

Run non-interleaved 1F1B schedule, with communication between pipeline stages. Returns dictionary with losses if the last stage, empty dict otherwise.

Data#

API#

core.pipeline_parallel.schedules.Shape#

None

core.pipeline_parallel.schedules.get_forward_backward_func()#

Retrieves the appropriate forward_backward function given the configuration of parallel_state.

Returns a function that will perform all of the forward and backward passes of the model given the pipeline model parallel world size and virtual pipeline model parallel world size in the global parallel_state.

Note that if using sequence parallelism, the sequence length component of the tensor shape is updated to original_sequence_length / tensor_model_parallel_world_size.

The function returned takes the following arguments:

forward_step_func (required): A function that takes a data iterator and a model as its arguments and returns the model’s forward output and the loss function. The loss function should take one torch.Tensor and return a torch.Tensor of loss and a dictionary of string -> torch.Tensor.

A third argument, checkpoint_activations_microbatch, indicates whether the activations for this microbatch should be checkpointed. A None value for this argument indicates that the default from the configuration should be used. This argument is used when num_microbatches_with_partial_activation_checkpoints is set.

For example:

from functools import partial

import torch

# average_losses_across_data_parallel_group is a logging helper provided by
# Megatron's training utilities.

def loss_func(loss_mask, output_tensor):
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss across the data-parallel group for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {'lm loss': averaged_loss[0]}

def forward_step(data_iterator, model):
    data, loss_mask = next(data_iterator)
    output = model(data)
    return output, partial(loss_func, loss_mask)


forward_backward_func(forward_step_func=forward_step, ...)
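When num_microbatches_with_partial_activation_checkpoints is in use, the schedule also passes the third argument described above. A hypothetical variant of the same forward step accepting it (the parameter name is illustrative; what matters is the third positional argument):

def forward_step(data_iterator, model, checkpoint_activations_microbatch=None):
    data, loss_mask = next(data_iterator)
    # A real implementation would use checkpoint_activations_microbatch to
    # decide whether to checkpoint this microbatch's activations.
    output = model(data)
    return output, partial(loss_func, loss_mask)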

data_iterator (required): an iterator over the data; it will be passed as-is to forward_step_func. Expected to be a list of iterators in the case of interleaved pipeline parallelism.

model (required): the actual model. Expected to be a list of modules in the case of interleaved pipeline parallelism. Must be a (potentially wrapped) megatron.core.models.MegatronModule.

num_microbatches (int, required): The number of microbatches to go through

seq_length (int, required): Sequence length of the current global batch. If this is a dual-stack transformer, this is the encoder’s sequence length. This is ignored if variable_seq_lengths in the config is True. Otherwise, each microbatch in the current global batch must use this sequence length.

micro_batch_size (int, required): The number of sequences in a microbatch.

decoder_seq_length (int, optional): The sequence length for the decoder in a dual-stack transformer. This is ignored for a single-stack transformer.

forward_only (optional, default = False): Perform only the forward step

collect_non_loss_data (optional, bool, default=False): If True, store arbitrary (non-loss) output of the forward step in forward_data_store instead of the reduced loss; intended for inference-style use cases.

first_val_step (bool, optional): Whether this is the first step of the validation phase. Used by Transformer Engine modules to update their fp8 weights only on the first validation step.

adjust_tensor_shapes_fn (Callable, optional): A function that adjusts the receive and send tensor shapes. Only applicable in forward_backward_pipelining_without_interleaving for now. Takes in a list of receive shapes and a list of send shapes and returns the adjusted respective list of shapes. Thus it is not used in the other forward-backward functions which have different shape handling.
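Putting these arguments together, a minimal usage sketch (the batch-size and sequence-length values here are illustrative, and it assumes the forward_step and data_iterator from the example above):

from megatron.core.pipeline_parallel import get_forward_backward_func

forward_backward_func = get_forward_backward_func()
losses_reduced = forward_backward_func(
    forward_step_func=forward_step,
    data_iterator=data_iterator,
    model=model,
    num_microbatches=8,      # illustrative
    seq_length=2048,         # illustrative
    micro_batch_size=4,      # illustrative
    forward_only=False,
)
# On the last pipeline stage this contains the losses collected by the schedule;
# on other stages it is empty.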

core.pipeline_parallel.schedules.deallocate_output_tensor(out, deallocate_pipeline_outputs=False)#

Pseudo-deallocate (i.e., set to scalar) the output tensor’s ‘.data’ field.

This method should be called right after the output tensor has been sent to the next pipeline stage. At this point, the output tensor is only useful for its ‘.grad_fn’ field, and not its ‘.data’.
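Conceptually, the pseudo-deallocation replaces the tensor’s storage with a scalar-sized placeholder while keeping the tensor object and its ‘.grad_fn’ alive. A minimal sketch of the idea (an illustration, not necessarily the exact implementation):

import torch

def pseudo_deallocate(out: torch.Tensor) -> None:
    # Swap the (potentially large) activation storage for a one-element buffer;
    # backward can still run through out.grad_fn.
    out.data = torch.empty((1,), device=out.device, dtype=out.dtype)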

core.pipeline_parallel.schedules.custom_backward(output, grad_output)#

Directly call C++ autograd engine.

To make the ‘deallocate_output_tensor’ (above) optimization work, the C++ autograd engine must be called directly, bypassing PyTorch’s torch.autograd.backward. PyTorch’s ‘backward’ checks that the output and grad have the same shape, while C++’s ‘backward’ does not.
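A sketch of what calling the engine directly can look like; this mirrors how torch.autograd.backward is implemented internally, but the exact keyword set below is an assumption rather than the verbatim Megatron code:

import torch
from torch.autograd import Variable

def custom_backward_sketch(output: torch.Tensor, grad_output: torch.Tensor) -> None:
    # Bypass torch.autograd.backward so the output/grad_output shape check is
    # skipped (output.data may have been pseudo-deallocated above).
    Variable._execution_engine.run_backward(
        tensors=(output,),
        grad_tensors=(grad_output,),
        keep_graph=False,
        create_graph=False,
        inputs=tuple(),
        allow_unreachable=True,
        accumulate_grad=True,
    )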

core.pipeline_parallel.schedules.set_current_microbatch(model, microbatch_id)#

Set the current microbatch.

core.pipeline_parallel.schedules.forward_step_calc_loss(
model,
output_tensor,
loss_func,
config,
vp_stage,
collect_non_loss_data,
num_microbatches,
forward_data_store,
cp_group_size=None,
is_last_stage=None,
)#

Calculate the loss and number of tokens for forward_step()

core.pipeline_parallel.schedules.forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
cp_group_size,
collect_non_loss_data=False,
checkpoint_activations_microbatch=None,
is_first_microbatch=False,
current_microbatch=None,
vp_stage=None,
is_last_stage=True,
)#

Forward step for passed-in model.

If it is the first stage, the input tensor is obtained from the data_iterator. Otherwise, the passed-in input_tensor is used.

Parameters:
  • forward_step_func (callable) –

    The forward step function for the model that takes the data iterator as the first argument and the model as the second. The user’s forward step is expected to output a tuple of two elements:

    1. The output object from the forward step. This output object needs to be a
        tensor or some kind of collection of tensors. The only hard requirement
        for this object is that it needs to be acceptable as input into the second
        function.
    2. A function to (optionally) reduce the output from the forward step. This
        could be a reduction over the loss from the model, a function that grabs
        the output from the model and reformats it, or a function that simply
        passes through the model output. This function must follow one of the
        following patterns, and depending on the pattern different things happen
        internally:

            a. A tuple of the reduced loss and some other data. Note that in this case
                the first argument is divided by the number of global microbatches,
                assuming it is a loss, so that the loss is stable as a function of
                the number of devices the step is split across.
            b. A triple of the reduced loss, the number of tokens, and some other data.
                This is similar to case (a), but the loss is further averaged across the
                number of tokens in the batch. If the user is not already averaging
                across the number of tokens, this pattern is useful to use.
            c. Any arbitrary data the user wants (e.g., a dictionary of tensors, a list
                of tensors, etc. in the case of inference). To trigger case (c) you need
                to specify `collect_non_loss_data=True`, and you may also want to
                specify `forward_only=True` in the call to the parent forward-backward
                function.
    

  • data_iterator (iterator) – The data iterator.

  • model (nn.Module) – The model to perform the forward step on.

  • num_microbatches (int) – The number of microbatches.

  • input_tensor (Tensor or list[Tensor]) – The input tensor(s) for the forward step.

  • forward_data_store (list) – The list to store the forward data. If you go down path 2.a or 2.b for the return of your forward reduction function then this will store only the final dimension of the output, for example the metadata output by the loss function. If you go down the path of 2.c then this will store the entire output of the forward reduction function applied to the model output.

  • config (object) – The configuration object.

  • collect_non_loss_data (bool, optional) – Whether to collect non-loss data. This is the path to use if you want to collect arbitrary output from the model forward, such as with inference use cases. Defaults to False.

  • checkpoint_activations_microbatch (int, optional) – The microbatch to checkpoint activations. Defaults to None.

  • is_first_microbatch (bool, optional) – Whether it is the first microbatch. Defaults to False.

  • current_microbatch (int, optional) – The current microbatch. Defaults to None.

  • vp_stage (int, optional) – The virtual pipeline stage. Defaults to None.

  • is_last_stage (bool, optional) – Whether this is the last stage, also taking virtual stages into account (with PP/VPP this corresponds to is_last_stage/is_vp_last_stage). Defaults to True.

Returns:

The output object(s) from the forward step, and a Tensor with the number of tokens.

Return type:

Tensor or list[Tensor]
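To illustrate pattern (b) above, a hedged sketch of a reduction function that also returns the token count (names and logging keys are illustrative; pattern (c) is triggered simply by returning arbitrary data together with collect_non_loss_data=True):

import torch

def loss_func_with_tokens(loss_mask, output_tensor):
    # Pattern (b): return (summed loss, number of tokens, logging data); the
    # schedule then averages the loss over the token count as well.
    losses = output_tensor.float().view(-1)
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses * loss_mask)
    num_tokens = loss_mask.sum().to(torch.int)
    return loss, num_tokens, {'lm loss': loss.clone().detach()}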

core.pipeline_parallel.schedules.backward_step(
input_tensor,
output_tensor,
output_tensor_grad,
model_type,
config,
)#

Backward step through passed-in output tensor.

If this is the last stage, output_tensor_grad is None; otherwise it is the gradient of the loss with respect to the stage’s output tensor.

Returns the gradient of the loss with respect to the input tensor (None if this is the first stage).
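A minimal sketch of the flow described above (ignoring the model_type/config specifics handled by the real function; an illustration only):

import torch

def backward_step_sketch(input_tensor, output_tensor, output_tensor_grad):
    if input_tensor is not None:
        input_tensor.retain_grad()
    if output_tensor_grad is None:
        # Last stage: output_tensor is the (already scaled) loss scalar.
        torch.autograd.backward(output_tensor)
    else:
        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
    # Gradient of the loss w.r.t. this stage's input; None on the first stage.
    return None if input_tensor is None else input_tensor.grad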

core.pipeline_parallel.schedules.check_first_val_step(first_val_step, forward_only, cond)#

Check if it is the first validation step.

core.pipeline_parallel.schedules.forward_backward_no_pipelining(
*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int,
micro_batch_size: int,
decoder_seq_length: Optional[int] = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
first_val_step: Optional[bool] = None,
adjust_tensor_shapes_fn: Optional[Callable] = None,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#

Run forward and backward passes with no pipeline parallelism

core.pipeline_parallel.schedules.clear_embedding_activation_buffer(config, model, is_last_stage)#

Clear embedding activation buffer.

core.pipeline_parallel.schedules.finish_embedding_wgrad_compute(
config,
embedding_module,
is_last_stage,
tp_group,
)#

Finish embedding wgrad compute.

core.pipeline_parallel.schedules.get_pp_rank_microbatches(
num_microbatches,
num_model_chunks,
microbatch_group_size_per_vp_stage,
forward_only=False,
overlap_moe_expert_parallel_comm=False,
p2p_communicator: Optional[megatron.core.pipeline_parallel.p2p_communication.P2PCommunicator] = None,
)#

Get the number of total, warmup, and remaining microbatches in PP scheduling.

core.pipeline_parallel.schedules.get_schedule_table(
num_microbatches,
num_model_chunks,
microbatch_group_size_per_vp_stage,
)#

Get the schedule table for PP scheduling.
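As a concrete illustration, and assuming the table is a list of (microbatch_id, model_chunk_id) pairs consistent with the PP2 N3M5 VP2 example shown under convert_schedule_table_to_order below:

schedule_table = get_schedule_table(
    num_microbatches=5,
    num_model_chunks=2,
    microbatch_group_size_per_vp_stage=3,
)
# Expected pairs for that example:
# [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1), (3, 0), (4, 0), (3, 1), (4, 1)]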

core.pipeline_parallel.schedules.convert_schedule_table_to_order(
num_warmup_microbatches,
num_model_chunks,
schedule_table,
)#

Convert a tunable schedule lookup table to the order format accepted by te.make_graphed_callables().

For example, the tunable schedule table for PP2 N3M5 with VP2 is:

virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
microbatch_id         | 0 1 2 0 1 2 3 4 3 4
model_chunk_id        | 0 0 0 1 1 1 0 0 1 1

The forward/backward separated order is then:

forward  |  1  1  1  2  2  2  1  1  2  2
backward | -2 -2 -2 -1 -1 -1 -2 -2 -1 -1

If num_warmup_microbatches is 5, the output order is:
1 1 1 2 2 2 -2 1 -2 1 -2 2 -1 2 -1 -1 -2 -2 -1 -1
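The conversion implied by this example can be sketched as follows; this is a reconstruction from the documented behavior (forward entries are model_chunk_id + 1, backward entries are model_chunk_id - num_model_chunks, warmup forwards come first, then forwards and backwards interleave, and the remaining backwards are appended), not necessarily the verbatim implementation:

def convert_sketch(num_warmup_microbatches, num_model_chunks, schedule_table):
    _, model_chunk_ids = zip(*schedule_table)
    forward = [chunk_id + 1 for chunk_id in model_chunk_ids]
    backward = [chunk_id - num_model_chunks for chunk_id in model_chunk_ids]
    order = forward[:num_warmup_microbatches]
    for i in range(num_warmup_microbatches, len(forward)):
        order.append(forward[i])
        order.append(backward[i - num_warmup_microbatches])
    order.extend(backward[len(forward) - num_warmup_microbatches:])
    return order

# With the PP2 N3M5 VP2 table and num_warmup_microbatches=5, this reproduces:
# [1, 1, 1, 2, 2, 2, -2, 1, -2, 1, -2, 2, -1, 2, -1, -1, -2, -2, -1, -1]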

core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving(
*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int,
micro_batch_size: int,
decoder_seq_length: Optional[int] = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
first_val_step: Optional[bool] = None,
adjust_tensor_shapes_fn: Optional[Callable] = None,
p2p_communicator: Optional[megatron.core.pipeline_parallel.p2p_communication.P2PCommunicator] = None,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#

Run interleaved 1F1B schedule (model split into model chunks), with communication between pipeline stages as needed.

Returns dictionary with losses if the last stage, empty dict otherwise.

core.pipeline_parallel.schedules.get_tensor_shapes(
*,
seq_length: int,
micro_batch_size: int,
decoder_seq_length: int,
config,
tp_group: torch.distributed.ProcessGroup,
cp_group: torch.distributed.ProcessGroup,
)#

Determine the right tensor sizes (based on the rank's position with respect to the split rank) and model size.
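Building on the sequence-parallelism note under get_forward_backward_func, a rough sketch of the shape computation for a decoder-only model (an illustration under assumptions; the actual function also handles encoder/decoder configurations):

def tensor_shape_sketch(seq_length, micro_batch_size, config, tp_group, cp_group):
    # The sequence dimension is sharded across context-parallel ranks, and
    # additionally across tensor-parallel ranks when sequence parallelism is on.
    seq_length = seq_length // cp_group.size()
    if config.sequence_parallel:
        seq_length = seq_length // tp_group.size()
    return [(seq_length, micro_batch_size, config.hidden_size)]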

core.pipeline_parallel.schedules.forward_backward_pipelining_without_interleaving(
*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int,
micro_batch_size: int,
decoder_seq_length: Optional[int] = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
first_val_step: Optional[bool] = None,
adjust_tensor_shapes_fn: Optional[Callable] = None,
p2p_communicator: Optional[megatron.core.pipeline_parallel.p2p_communication.P2PCommunicator] = None,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#

Run non-interleaved 1F1B schedule, with communication between pipeline stages. Returns dictionary with losses if the last stage, empty dict otherwise.