core.models.common.model_chunk_schedule_plan#
Module Contents#
Classes#
ModelChunkState | State shared across a model chunk. |
TransformerLayerSchedulePlan | Schedule the execution plan of the nodes in a transformer/mtp layer. |
TransformerModelChunkSchedulePlan | Schedule the execution plan of the sub-modules in a model chunk. |
API#
- class core.models.common.model_chunk_schedule_plan.ModelChunkState#
State shared across a model chunk.
This class holds state that is shared between different components of a model chunk, such as input tensors, parameters, and configuration.
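As a rough illustration only (the real attribute set is not documented here), ModelChunkState can be thought of as a plain container whose fields are written once per chunk and read by later nodes; the field names below are hypothetical placeholders:

```python
from megatron.core.models.common.model_chunk_schedule_plan import ModelChunkState

state = ModelChunkState()
# Illustrative placeholder fields, not the class's actual attributes.
state.attention_mask = None    # shared, read-only input for every layer's attn node
state.packed_seq_params = None # sequence-packing parameters reused across layers
```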
- class core.models.common.model_chunk_schedule_plan.TransformerLayerSchedulePlan(
- layer,
- event,
- chunk_state,
- comp_stream,
- comm_stream,
- extra_args={},
)#
Schedule the execution plan of the nodes in a transformer/mtp layer.
This class organizes the sub-modules of a transformer/mtp layer, including the attention, post-attention, MLP, dispatch, combine, and mtp post-process nodes.
layer (TransformerLayerSchedulePlan)
├── attn (TransformerLayerNode): attention module
├── post_attn (TransformerLayerNode): layernorm -> router -> dispatch preprocess
├── moe_dispatch (TransformerLayerNode): dispatch All2All
├── mlp (TransformerLayerNode): mlp module
├── moe_combine (TransformerLayerNode): combine All2All
└── mtp_post_process (PostProcessNode): mtp post process
Note that an MTP layer has the same operations and execution order as a TransformerLayer for post_attn, moe_dispatch, mlp, and moe_combine, but contains extra operations in attn and mtp_post_process:
mtp.attn wraps transformer_layer.attn with extra norm, proj, and embedding operations.
mtp.mtp_post_process contains the output_layer and mtp loss operations, whereas transformer_layer.mtp_post_process is empty.
Initialization
Initializes a transformer layer schedule plan.
- Parameters:
layer (TransformerLayer) – the transformer layer to split into multiple nodes for fine-grained scheduling.
event (torch.cuda.Event) – CUDA event recorded across multiple nodes on different streams for synchronization.
chunk_state (ModelChunkState) – model state shared in the model chunk.
comp_stream (torch.cuda.Stream) – CUDA stream for computation.
comm_stream (torch.cuda.Stream) – CUDA stream for communication.
extra_args (dict) – extra arguments for the layer.
The event and chunk_state are bound to the TransformerModelChunkSchedulePlan and shared across all layers in the model chunk.
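A minimal construction sketch, assuming `layer` is an existing TransformerLayer and `chunk_state` an existing ModelChunkState (both placeholders); only the keyword names come from the signature above:

```python
import torch

comp_stream = torch.cuda.current_stream()  # computation stays on the default stream
comm_stream = torch.cuda.Stream()          # dispatch/combine All2All run here
event = torch.cuda.Event()                 # shared across all layers of the chunk

layer_plan = TransformerLayerSchedulePlan(
    layer,          # placeholder: the TransformerLayer to split into schedulable nodes
    event,
    chunk_state,    # placeholder: the chunk-wide ModelChunkState
    comp_stream,
    comm_stream,
    extra_args={},
)
```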
- attn#
None
- post_attn#
None
- moe_dispatch#
None
- mlp#
None
- moe_combine#
None
- mtp_post_process#
None
- _build_callable_nodes(event, comp_stream, comm_stream, extra_args)#
Builds the callable nodes for the transformer/mtp layer: attn, post_attn, moe_dispatch, mlp, moe_combine, and mtp_post_process.
- get_fp8_context()#
Get the fp8 context for the transformer layer.
- static run(
- f_layer,
- b_layer,
- f_input=None,
- b_grad=None,
- is_last_layer_in_bwd=False,
Schedule one-forward-one-backward operations for a single transformer layer.
This function interleaves forward and backward operations, overlapping the communication (dispatch or combine) of one with the computation (attn or mlp) of the other to maximize parallelism and efficiency.
When f_layer and b_layer are not None, the forward and backward passes are overlapped as follows:
comm_stream: combine_bwd | dispatch_fwd -> dispatch_bwd | combine_fwd
comp_stream: attn_fwd -> post_attn_fwd | mlp_bwd -> mlp_bwd_dw -> mlp_fwd | post_attn_bwd -> attn_bwd
For MTP, mtp_post_process_fwd is executed after combine_fwd in the comp_stream, and mtp_post_process_bwd is executed before combine_bwd in the comp_stream.
- Parameters:
f_layer (TransformerLayerSchedulePlan) – Forward layer (for current microbatch)
b_layer (TransformerLayerSchedulePlan) – Backward layer (for previous microbatch)
f_input (Tensor) – Input for forward computation
b_grad (Tensor) – Gradient for backward computation
is_last_layer_in_bwd (bool) – Whether the current layer is the last layer in the backward pass.
- Returns:
Functions or values for next iteration’s computation
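A sketch of the per-layer 1F1B loop this method supports. The lists `f_layers` and `b_layers` are assumed to hold the forward and backward TransformerLayerSchedulePlan objects, and the assumption that `run` returns the next `(f_input, b_grad)` pair is illustrative; the docstring above only states that it returns values for the next iteration's computation:

```python
num_layers = len(f_layers)
for i in range(num_layers):
    f_layer = f_layers[i]                   # forward traverses layers front to back
    b_layer = b_layers[num_layers - 1 - i]  # backward traverses them back to front
    f_input, b_grad = TransformerLayerSchedulePlan.run(
        f_layer,
        b_layer,
        f_input=f_input,
        b_grad=b_grad,
        is_last_layer_in_bwd=(i == num_layers - 1),
    )
```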
- class core.models.common.model_chunk_schedule_plan.TransformerModelChunkSchedulePlan(
- model,
- input_ids: torch.Tensor,
- position_ids: torch.Tensor,
- attention_mask: torch.Tensor,
- decoder_input: torch.Tensor = None,
- labels: torch.Tensor = None,
- packed_seq_params=None,
- extra_block_kwargs=None,
- runtime_gather_output: Optional[bool] = None,
- loss_mask: Optional[torch.Tensor] = None,
)#
Bases:
megatron.core.pipeline_parallel.utils.AbstractSchedulePlan
Schedule the execution plan of the sub-modules in a model chunk.
This class organizes the computation nodes for a model chunk, including preprocessing, transformer layers, and postprocessing.
TransformerModelChunkSchedulePlan
├── pre_process: PreProcessNode
├── layers: List[TransformerLayerSchedulePlan]
│   ├── layer[0]: TransformerLayerSchedulePlan
│   ├── layer[1]: TransformerLayerSchedulePlan
│   └── …
└── post_process: PostProcessNode
Initialization
Initialize the schedule plan of all Transformer layers’ sub-modules.
This function creates a schedule plan for a model chunk, including preprocessing, transformer layers, and postprocessing.
- Parameters:
model – The model to build a schedule plan for.
input_ids – Input token IDs.
position_ids – Position IDs.
attention_mask – Attention mask.
decoder_input – Decoder input tensor.
labels – Labels for loss computation.
packed_seq_params – Parameters for packed sequences.
extra_block_kwargs – Additional keyword arguments for blocks.
runtime_gather_output – Whether to gather output at runtime.
loss_mask (torch.Tensor) – Used to mask out some portions of the loss
- Returns:
The model chunk schedule plan.
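A hedged construction example; the `model` object and the input tensors are placeholders, and only the parameter names are taken from the signature above:

```python
schedule_plan = TransformerModelChunkSchedulePlan(
    model,                        # placeholder: a model chunk with transformer layers
    input_ids=input_ids,          # placeholder: [batch, seq_len] token IDs
    position_ids=position_ids,    # placeholder: [batch, seq_len] positions
    attention_mask=attention_mask,
    labels=labels,                # enables loss computation in post-process
    loss_mask=loss_mask,          # masks out some portions of the loss
)
```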
- property event#
Gets the CUDA event for synchronization.
- record_current_stream()#
Records the current CUDA stream in the event.
- wait_current_stream()#
Waits for the event to complete on the current CUDA stream.
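These two helpers presumably wrap the standard `torch.cuda.Event` record/wait pattern; the sketch below shows the intended usage, with `comm_stream` and the work launched on it as placeholders:

```python
import torch

comm_stream = torch.cuda.Stream()

# Compute runs on the default stream; record its progress into the plan's event.
schedule_plan.record_current_stream()

# Make the communication stream wait on that point before launching its kernels.
with torch.cuda.stream(comm_stream):
    schedule_plan.wait_current_stream()
    # ... communication kernels would be launched here ...
```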
- get_layer(i)#
Gets the transformer layer at the specified index.
- num_layers()#
Gets the number of transformer layers.
- property state#
Gets the model chunk state.
- release_state()#
Releases references; this helps avoid memory leaks.
- static run(
- f_schedule_plan,
- b_schedule_plan,
- b_grad=None,
- pre_forward=None,
- pre_backward=None,
- post_forward=None,
- post_backward=None,
Model Chunk level 1f1b fine-grained scheduler.
This function schedules the forward and backward passes for a model chunk. It interleaves the forward and backward functions of multiple transformer layers within the chunk, which is needed to overlap the sub-modules of the individual forward and backward functions.
Assume there are 4 layers in the given model chunk:
Phase 0: p2p_comm_sync -> forward_preprocess -> p2p_comm_sync -> backward_postprocess
Phase 1: forward_layer[0] + backward_layer[3], overlapped execution by schedule_layer_1f1b
Phase 2: forward_layer[1] + backward_layer[2], overlapped execution by schedule_layer_1f1b
Phase 3: forward_layer[2] + backward_layer[1], overlapped execution by schedule_layer_1f1b
Phase 4: forward_layer[3] + backward_layer[0], overlapped execution by schedule_layer_1f1b
Phase 5: send_forward_recv_backward -> send_backward_recv_forward
Phase 6: backward_dw of the first layer -> forward_postprocess -> backward_preprocess
- Parameters:
f_schedule_plan (TransformerModelChunkSchedulePlan) – The forward schedule plan
b_schedule_plan (TransformerModelChunkSchedulePlan) – The backward schedule plan
b_grad (Tensor or None) – The gradient of the loss function
pre_forward (callable or None) – The function to call before the forward pass
pre_backward (callable or None) – The function to call before the backward pass
post_forward (callable or None) – The function to call after the forward pass
post_backward (callable or None) – The function to call after the backward pass
- Returns:
The output of the forward pass.
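A hedged end-to-end sketch of one overlapped step at the model-chunk level; the two schedule plans, `b_grad`, and the callbacks are placeholders, with only the keyword names taken from the parameter list above:

```python
output = TransformerModelChunkSchedulePlan.run(
    f_schedule_plan,          # placeholder: forward plan for the current microbatch
    b_schedule_plan,          # placeholder: backward plan for the previous microbatch
    b_grad=b_grad,            # gradient fed into the backward pass
    pre_forward=None,         # optional hooks around the two passes
    pre_backward=None,
    post_forward=None,
    post_backward=None,
)
# `output` is the forward pass output, per the Returns section above.
```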