core.pipeline_parallel.combined_1f1b#

Module Contents#

Functions#

combined_1f1b_schedule_for_no_pipelining

Scheduler for 1f1b with no pipelining.

combined_1f1b_schedule_for_interleaved_pipelining

Helper method that runs a combined forward and backward step to hide A2A communication. It merges the functionality of forward_step_helper and backward_step_helper and ultimately calls the combined_forward_backward_step method defined in combined_1f1b.py. Called only if overlap_moe_expert_parallel_comm is true.

combined_forward_backward_step

Merged forward and backward step for the combined 1f1b scheduler.

Data#

API#

core.pipeline_parallel.combined_1f1b.Shape#

None

core.pipeline_parallel.combined_1f1b.combined_1f1b_schedule_for_no_pipelining(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
output_tensor_grad,
forward_data_store,
config,
collect_non_loss_data,
first_val_step,
forward_only,
no_sync_func,
total_num_tokens,
check_first_val_step,
)#

Scheduler for 1f1b with no pipelining.

This function schedules micro-batches so that the forward pass of the Transformer layers for one micro-batch runs in parallel with the backward pass of another. Each layer's forward and backward operations are co-scheduled to maximize the overlap of their computation and communication: the EP A2A in the forward step is hidden by the attention/MLP computation in the backward step, and vice versa. Assuming we have 4 microbatches, the schedule is as follows:

  • Phase 0: 1st microbatch forward

  • Phase 1: 1st microbatch backward + 2nd microbatch forward

  • Phase 2: 2nd microbatch backward + 3rd microbatch forward

  • Phase 3: 3rd microbatch backward + 4th microbatch forward

  • Phase 4: 4th microbatch backward
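A minimal sketch of this phase layout, assuming a simple generator over (forward, backward) microbatch pairs; the helper name and the loop below are illustrative only, not part of the actual API:

    # Hypothetical sketch: enumerate the phases of the no-pipelining
    # combined 1f1b schedule. In the real scheduler, each (forward, backward)
    # pair is executed by combined_forward_backward_step so that the EP A2A
    # of one direction overlaps with the compute of the other.
    def combined_1f1b_phases(num_microbatches):
        for phase in range(num_microbatches + 1):
            f_mb = phase if phase < num_microbatches else None  # forward microbatch
            b_mb = phase - 1 if phase >= 1 else None            # backward microbatch
            yield f_mb, b_mb

    for f_mb, b_mb in combined_1f1b_phases(4):
        print(f"forward={f_mb}, backward={b_mb}")
    # forward=0, backward=None
    # forward=1, backward=0
    # forward=2, backward=1
    # forward=3, backward=2
    # forward=None, backward=3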

core.pipeline_parallel.combined_1f1b.combined_1f1b_schedule_for_interleaved_pipelining(
config,
forward_step_func,
data_iterator,
model,
num_microbatches,
forward_data_store,
forward_step_helper_preprocess,
forward_step_helper_postprocess,
backward_step_helper_preprocess,
backward_step_helper_postprocess,
get_microbatch_id_in_model_chunk,
get_model_chunk_id,
check_first_val_step,
is_first_microbatch_for_model_chunk,
collect_non_loss_data,
f_virtual_microbatch_id=None,
b_virtual_microbatch_id=None,
pre_forward=None,
pre_backward=None,
post_forward=None,
post_backward=None,
)#

Helper method that runs a combined forward and backward step to hide A2A communication. It merges the functionality of forward_step_helper and backward_step_helper and ultimately calls the combined_forward_backward_step method defined in combined_1f1b.py. This method is called only if overlap_moe_expert_parallel_comm is true.

Parameters:

The arguments can be categorized into three groups:

  • Common arguments:

    • f_virtual_microbatch_id, b_virtual_microbatch_id

  • Arguments for combined_forward_backward_step():

    • config, forward_step_func, data_iterator, model, num_microbatches, forward_data_store

    • check_first_val_step, is_first_microbatch_for_model_chunk, collect_non_loss_data

    • pre_forward, pre_backward, post_forward, post_backward

  • Callables for the forward_step_helper and backward_step_helper:

    • forward_step_helper_preprocess, forward_step_helper_postprocess

    • backward_step_helper_preprocess, backward_step_helper_postprocess

    • get_microbatch_id_in_model_chunk, get_model_chunk_id

Returns:

  • output_tensor (Tensor or list[Tensor]) – The output object(s) from the forward step.

  • input_tensor_grad (Tensor) – The grad of the input tensor.

Description: This method merges forward_step_helper() and backward_step_helper() from schedules.py. Assuming that:

    def forward_step_helper():
        # forward_step_helper_preprocess()
        # forward_step()
        # forward_step_helper_postprocess()

    def backward_step_helper():
        # backward_step_helper_preprocess()
        # backward_step()
        # backward_step_helper_postprocess()

then combined_1f1b_schedule_for_interleaved_pipelining() becomes:

    def combined_1f1b_schedule_for_interleaved_pipelining():
        # forward_step_helper_preprocess()
        # backward_step_helper_preprocess()
        # combined_forward_backward_step()  # merged forward_step() and backward_step()
        # forward_step_helper_postprocess()
        # backward_step_helper_postprocess()
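As a rough illustration of that call order, here is a minimal sketch, assuming the pre/postprocess hooks are passed in as plain callables; the function below and its return-value plumbing are hypothetical, not the actual implementation:

    # Hypothetical sketch of the merged helper. Model-chunk bookkeeping,
    # virtual microbatch ids, and P2P communication are all omitted.
    def merged_helper_sketch(
        forward_step_helper_preprocess,
        backward_step_helper_preprocess,
        combined_forward_backward_step,
        forward_step_helper_postprocess,
        backward_step_helper_postprocess,
    ):
        # Both preprocess hooks run first so that one microbatch's forward
        # and another's backward can be co-scheduled in a single fused step.
        input_tensor = forward_step_helper_preprocess()
        b_input, b_output, b_output_grad = backward_step_helper_preprocess()

        # Fused step: the forward EP A2A hides behind the backward compute,
        # and vice versa.
        output_tensor, input_tensor_grad = combined_forward_backward_step(
            input_tensor, b_input, b_output, b_output_grad
        )

        forward_step_helper_postprocess(output_tensor)
        backward_step_helper_postprocess(input_tensor_grad)
        return output_tensor, input_tensor_grad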

core.pipeline_parallel.combined_1f1b.combined_forward_backward_step(
forward_step_func,
data_iterator,
f_model,
num_microbatches,
input_tensor,
forward_data_store,
b_model,
b_input_tensor,
b_output_tensor,
b_output_tensor_grad,
config,
f_model_chunk_id=None,
pre_forward=None,
pre_backward=None,
post_forward=None,
post_backward=None,
collect_non_loss_data=False,
checkpoint_activations_microbatch=None,
is_first_microbatch=False,
current_microbatch=None,
encoder_decoder_xattn=False,
)#

Merged forward and backward step for combined 1f1b scheduler.

Parameters:

This function needs to accept the arguments of both forward_step and backward_step. The notable ones are:

  • forward_step_func (callable) – A function that returns a forward schedule plan, which is an input of the schedule_chunk_1f1b function.

  • pre_forward (callable) – The function to call before the forward step. Only exists in the 1f1b steady state with p2p overlap.

  • pre_backward (callable) – The function to call before the backward step. Only exists in the 1f1b steady state with p2p overlap.

  • post_forward (callable) – The function to call after the forward step. Only exists in the 1f1b steady state with p2p overlap.

  • post_backward (callable) – The function to call after the backward step. Only exists in the 1f1b steady state with p2p overlap.

Returns:

  • forward_output_tensor (Tensor or list[Tensor]) – The output object(s) from the forward step.

  • forward_num_tokens (Tensor) – The number of tokens.

  • backward_input_tensor_grad (Tensor) – The grad of the input tensor.

Description: This method merges the forward_step() and backward_step() methods in the schedules.py file. Assuming that:

    def forward_step():
        # forward_preprocess()
        # forward_compute()
        # forward_postprocess()

    def backward_step():
        # backward_preprocess()
        # backward_compute()
        # backward_postprocess()

then combined_forward_backward_step() becomes:

    def combined_forward_backward_step():
        # forward_preprocess()                 # the same as in forward_step()
        # GENERATE f_schedule_plan             # scheduling happens in schedule_chunk_1f1b()
        # backward_preprocess()                # the same as in backward_step()
        # COMBINED_FORWARD_BACKWARD_COMPUTE()  # by calling schedule_chunk_1f1b()
        # forward_postprocess()                # the same as in forward_step()
        # backward_postprocess()               # the same as in backward_step()
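To make that ordering concrete, here is a minimal runnable sketch, assuming placeholder bodies for each stage; every placeholder below (including the schedule_chunk_1f1b stand-in) is hypothetical, not the real Megatron-Core internals:

    # Hypothetical placeholders for the stages named above.
    def forward_preprocess():
        print("forward preprocess: set up inputs for the forward microbatch")
        return "f_schedule_plan"  # produced via forward_step_func in the real code

    def backward_preprocess():
        print("backward preprocess: retain grads, fetch the output tensor grad")

    def schedule_chunk_1f1b(f_schedule_plan):
        # In the real code this interleaves the forward and backward passes
        # layer by layer, so each direction's EP A2A overlaps with the other
        # direction's attention/MLP compute.
        print(f"combined compute driven by {f_schedule_plan}")

    def forward_postprocess():
        print("forward postprocess: store loss/logits in forward_data_store")

    def backward_postprocess():
        print("backward postprocess: return the input tensor grad")

    def combined_forward_backward_step_sketch():
        f_schedule_plan = forward_preprocess()
        backward_preprocess()
        schedule_chunk_1f1b(f_schedule_plan)
        forward_postprocess()
        backward_postprocess()

    combined_forward_backward_step_sketch()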