core.models.common.model_chunk_schedule_plan#

Module Contents#

Classes#

ModelChunkState

State shared across a model chunk.

TransformerLayerSchedulePlan

Schedule the execution plan of the nodes in a transformer/MTP layer.

TransformerModelChunkSchedulePlan

Schedule the execution plan of the sub-modules in a model chunk.

API#

class core.models.common.model_chunk_schedule_plan.ModelChunkState#

State shared across a model chunk.

This class holds state that is shared between different components of a model chunk, such as input tensors, parameters, and configuration.

class core.models.common.model_chunk_schedule_plan.TransformerLayerSchedulePlan(
layer,
event,
chunk_state,
comp_stream,
comm_stream,
extra_args={},
)#

Schedule the execution plan of the nodes in a transformer/MTP layer.

This class organizes the sub-modules of a transformer/MTP layer, including the attention, post-attention, MLP, dispatch, combine, and MTP post-process nodes.

layer (TransformerLayerSchedulePlan)
├── attn (TransformerLayerNode): attention module
├── post_attn (TransformerLayerNode): layernorm -> router -> dispatch preprocess
├── moe_dispatch (TransformerLayerNode): dispatch All2All
├── mlp (TransformerLayerNode): mlp module
├── moe_combine (TransformerLayerNode): combine All2All
└── mtp_post_process (PostProcessNode): mtp post process

Note that an MTP layer has the same operations and execution order as a TransformerLayer for post_attn, moe_dispatch, mlp, and moe_combine, but contains extra operations in attn and mtp_post_process:

  • mtp.attn wraps around transformer_layer.attn with extra norm, proj and embedding operations.

  • mtp.mtp_post_process contains the output_layer and MTP loss operations, whereas transformer_layer.mtp_post_process is empty.

Initialization

Initializes a transformer layer schedule plan.

Parameters:
  • layer (TransformerLayer) – the transformer layer to split into multiple nodes for fine-grained scheduling.

  • event (torch.cuda.Event) – CUDA event recorded across multiple nodes on different streams for synchronization.

  • chunk_state (ModelChunkState) – model state shared in the model chunk.

  • comp_stream (torch.cuda.Stream) – CUDA stream for computation.

  • comm_stream (torch.cuda.Stream) – CUDA stream for communication.

  • extra_args (dict) – extra arguments for the layer.

The event and chunk_state are bound to the TransformerModelChunkSchedulePlan and shared across all layers in the model chunk. A construction sketch follows.
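The streams and event passed to this constructor are ordinary torch.cuda objects. Below is a minimal sketch of that plumbing, assuming they are created once per model chunk; `layer` and `chunk_state` are placeholders for an existing TransformerLayer and the chunk's ModelChunkState, so the constructor call itself is left commented out.

```python
import torch

# Created once per model chunk and shared by all of its layer plans.
comp_stream = torch.cuda.Stream()  # computation stream (attn / mlp kernels)
comm_stream = torch.cuda.Stream()  # communication stream (dispatch / combine All2All)
event = torch.cuda.Event()         # shared synchronization point between the streams

# Hypothetical wiring: `layer` is an already-built Megatron TransformerLayer and
# `chunk_state` is the ModelChunkState owned by the enclosing chunk schedule plan.
# layer_plan = TransformerLayerSchedulePlan(
#     layer, event, chunk_state, comp_stream, comm_stream, extra_args={},
# )
```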

attn#

None

post_attn#

None

moe_dispatch#

None

mlp#

None

moe_combine#

None

mtp_post_process#

None

_build_callable_nodes(event, comp_stream, comm_stream, extra_args)#

Builds the callable nodes for the transformer/MTP layer: attn, post_attn, moe_dispatch, mlp, moe_combine, and mtp_post_process.

get_fp8_context()#

Gets the FP8 context for the transformer layer.
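A hedged usage sketch, assuming the returned object is a context manager (as FP8 autocast contexts typically are); `layer_plan` and `run_forward` are hypothetical placeholders, not names from this module.

```python
def forward_under_fp8(layer_plan, run_forward):
    # Enter the layer's FP8 context (assumed to be a context manager) so the
    # wrapped forward computation runs with the layer's FP8 settings.
    with layer_plan.get_fp8_context():
        return run_forward()
```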

static run(
f_layer,
b_layer,
f_input=None,
b_grad=None,
is_last_layer_in_bwd=False,
)#

Schedule one-forward-one-backward operations for a single transformer layer.

This function interleaves forward and backward operations, overlapping the communications (dispatch or combine) of one with the computations (attn or mlp) of the other to maximize parallelism and efficiency.

When f_layer and b_layer are not None, the forward and backward passes are overlapped as follows (see the usage sketch after the parameter list):

comm_stream: combine_bwd | dispatch_fwd -> dispatch_bwd | combine_fwd
comp_stream: attn_fwd -> post_attn_fwd | mlp_bwd -> mlp_bwd_dw -> mlp_fwd | post_attn_bwd -> attn_bwd

For MTP, mtp_post_process_fwd is executed after the combine_fwd in the comp_stream, and mtp_post_process_bwd is executed before the combine_bwd in the comp_stream.

Parameters:
  • f_layer (TransformerLayerSchedulePlan) – Forward layer (for current microbatch)

  • b_layer (TransformerLayerSchedulePlan) – Backward layer (for previous microbatch)

  • f_input (Tensor) – Input for forward computation

  • b_grad (Tensor) – Gradient for backward computation

  • is_last_layer_in_bwd (bool) – Whether the current layer is the last layer in the backward pass.

Returns:

Functions or values for next iteration’s computation
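As a hedged illustration of how this static method can be driven, the sketch below pairs forward layer i of the current microbatch with backward layer N-1-i of the previous one and threads the returned values into the next iteration. The helper name, the import path, and the tuple unpacking of the return value are assumptions for illustration, not the library's own driver.

```python
# Import path assumed from the module name in the page title.
from megatron.core.models.common.model_chunk_schedule_plan import (
    TransformerLayerSchedulePlan,
)


def interleave_layers_1f1b(f_plan, b_plan, f_input, b_grad):
    """Hypothetical driver: overlap forward layers of one microbatch with
    backward layers of the previous microbatch, one pair per iteration."""
    num_layers = f_plan.num_layers()
    assert num_layers == b_plan.num_layers()
    for i in range(num_layers):
        f_layer = f_plan.get_layer(i)                   # forward: front to back
        b_layer = b_plan.get_layer(num_layers - 1 - i)  # backward: back to front
        # Each call overlaps the communication of one pass with the computation
        # of the other; the return value feeds the next iteration (assumed here
        # to unpack into the next forward input and backward gradient).
        f_input, b_grad = TransformerLayerSchedulePlan.run(
            f_layer,
            b_layer,
            f_input=f_input,
            b_grad=b_grad,
            is_last_layer_in_bwd=(i == num_layers - 1),
        )
    return f_input, b_grad
```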

class core.models.common.model_chunk_schedule_plan.TransformerModelChunkSchedulePlan(
model,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,
decoder_input: torch.Tensor = None,
labels: torch.Tensor = None,
packed_seq_params=None,
extra_block_kwargs=None,
runtime_gather_output: Optional[bool] = None,
loss_mask: Optional[torch.Tensor] = None,
)#

Bases: megatron.core.pipeline_parallel.utils.AbstractSchedulePlan

Schedule the execution plan of the sub-modules in a model chunk.

This class organizes the computation nodes for a model chunk, including preprocessing, transformer layers, and postprocessing.

TransformerModelChunkSchedulePlan
├── pre_process: PreProcessNode
├── layers: List[TransformerLayerSchedulePlan]
│   ├── layer[0]: TransformerLayerSchedulePlan
│   ├── layer[1]: TransformerLayerSchedulePlan
│   └── …
└── post_process: PostProcessNode

Initialization

Initialize the schedule plan of all Transformer layers’ sub-modules.

This function creates a schedule plan for a model chunk, including preprocessing, transformer layers, and postprocessing.

Parameters:
  • model – The model to build a schedule plan for.

  • input_ids – Input token IDs.

  • position_ids – Position IDs.

  • attention_mask – Attention mask.

  • decoder_input – Decoder input tensor.

  • labels – Labels for loss computation.

  • packed_seq_params – Parameters for packed sequences.

  • extra_block_kwargs – Additional keyword arguments for blocks.

  • runtime_gather_output – Whether to gather output at runtime.

  • loss_mask (torch.Tensor) – Used to mask out some portions of the loss

Returns:

The model chunk schedule plan.
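A hedged construction sketch follows. `model` is assumed to be an already-built Megatron model chunk, the tensors use the usual [batch, sequence] layout, and the vocabulary size is a placeholder; passing None for the attention mask mirrors common GPT usage and is an assumption, since the annotation above is a plain torch.Tensor.

```python
import torch

batch_size, seq_len, vocab_size = 1, 512, 32000  # placeholder sizes
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
position_ids = torch.arange(seq_len, device="cuda").unsqueeze(0)

# Hypothetical construction; `model` stands in for an existing model chunk.
# schedule_plan = TransformerModelChunkSchedulePlan(
#     model,
#     input_ids=input_ids,
#     position_ids=position_ids,
#     attention_mask=None,
#     labels=None,
# )
```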

property event#

Gets the CUDA event for synchronization.

record_current_stream()#

Records the current CUDA stream in the event.

wait_current_stream()#

Waits for the event to complete on the current CUDA stream.
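These two helpers wrap the chunk's shared CUDA event. Conceptually they reduce to the standard torch record/wait pattern sketched below; that this is what the methods do internally is an assumption based on the descriptions above.

```python
import torch

event = torch.cuda.Event()

# Roughly what record_current_stream() does: mark a point on the current stream.
event.record(torch.cuda.current_stream())

# Roughly what wait_current_stream() does: make the current stream wait on that point.
torch.cuda.current_stream().wait_event(event)
```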

get_layer(i)#

Gets the transformer layer at the specified index.

num_layers()#

Gets the number of transformer layers.

property state#

Gets the model chunk state.

release_state()#

Releases references; this helps avoid memory leaks.

static run(
f_schedule_plan,
b_schedule_plan,
b_grad=None,
pre_forward=None,
pre_backward=None,
post_forward=None,
post_backward=None,
)#

Model-chunk-level 1F1B fine-grained scheduler.

This function schedules the forward and backward passes for a model chunk. It interleaves the forward and backward functions of multiple Transformer layers within the chunk, which is needed to overlap sub-modules across the individual forward and backward passes. A usage sketch follows the parameter list below.

Assume there are 4 layers in the given model chunk:

Phase 0: p2p_comm_sync -> forward_preprocess -> p2p_comm_sync -> backward_postprocess
Phase 1: forward_layer[0] + backward_layer[3], overlapped execution by schedule_layer_1f1b
Phase 2: forward_layer[1] + backward_layer[2], overlapped execution by schedule_layer_1f1b
Phase 3: forward_layer[2] + backward_layer[1], overlapped execution by schedule_layer_1f1b
Phase 4: forward_layer[3] + backward_layer[0], overlapped execution by schedule_layer_1f1b
Phase 5: send_forward_recv_backward -> send_backward_recv_forward
Phase 6: backward_dw of the first layer -> forward_postprocess -> backward_preprocess

Parameters:
  • f_schedule_plan (TransformerModelChunkSchedulePlan) – The forward schedule plan

  • b_schedule_plan (TransformerModelChunkSchedulePlan) – The backward schedule plan

  • b_grad (Tensor or None) – The gradient of the loss function

  • pre_forward (callable or None) – The function to call before the forward pass

  • pre_backward (callable or None) – The function to call before the backward pass

  • post_forward (callable or None) – The function to call after the forward pass

  • post_backward (callable or None) – The function to call after the backward pass

Returns:

The output of the forward pass.
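A hedged usage sketch of the chunk-level scheduler, as referenced above: overlap the forward pass of the current microbatch with the backward pass of the previous one. The wrapper function, the import path, and the p2p hook names are hypothetical; only the run() signature follows the parameter list above.

```python
# Import path assumed from the module name in the page title.
from megatron.core.models.common.model_chunk_schedule_plan import (
    TransformerModelChunkSchedulePlan,
)


def run_chunk_1f1b(fwd_plan, bwd_plan, output_grad,
                   recv_forward_sync=None, send_backward_sync=None):
    """Hypothetical driver around TransformerModelChunkSchedulePlan.run()."""
    output = TransformerModelChunkSchedulePlan.run(
        f_schedule_plan=fwd_plan,        # forward plan for the current microbatch
        b_schedule_plan=bwd_plan,        # backward plan for the previous microbatch
        b_grad=output_grad,              # gradient fed into the backward pass
        pre_forward=recv_forward_sync,   # e.g. a p2p communication sync callback
        post_backward=send_backward_sync,
    )
    # Assumed cleanup: once this chunk's backward pass has run, release shared
    # references to help avoid memory leaks (see release_state above).
    bwd_plan.release_state()
    return output
```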