core.models.gpt.fine_grained_callables#
Module Contents#
Classes#
| Class | Description |
|---|---|
| `TransformerLayerState` | State shared within a transformer layer. |
| `PreProcessNode` | Node responsible for preprocessing operations in the model. |
| `PostProcessNode` | Node responsible for postprocessing operations in the model. |
| `TransformerLayerNode` | Base class for transformer layer computation nodes. |
Functions#
| Function | Description |
|---|---|
| `weak_method` | Creates a weak reference to a method to prevent circular references. |
| `should_free_input` | Determine if the node should free its input memory. |
| `build_transformer_layer_callables` | Create callables for transformer layer nodes. |
| `build_mtp_layer_callables` | Callables for multi-token prediction layer nodes. |
| `build_layer_callables` | Builds the callable functions (forward and dw) for the given layer. |
API#
- core.models.gpt.fine_grained_callables.weak_method(method)#
Creates a weak reference to a method to prevent circular references.
This function creates a weak reference to a method and returns a wrapper function that calls the method when invoked. This helps prevent memory leaks from circular references.
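A minimal sketch of this pattern, assuming the standard-library `weakref.WeakMethod` (the actual implementation may differ in details):

```python
import weakref

def weak_method(method):
    """Wrap a bound method weakly so the wrapper does not keep its owner alive."""
    # WeakMethod references the bound method without creating a strong
    # reference cycle back to the owning instance.
    method_ref = weakref.WeakMethod(method)

    def wrapped_func(*args, **kwargs):
        m = method_ref()  # re-resolve the bound method at call time
        if m is None:
            raise ReferenceError("object owning the method was garbage collected")
        return m(*args, **kwargs)

    return wrapped_func
```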
- core.models.gpt.fine_grained_callables.should_free_input(name, is_moe, is_deepep)#
Determine if the node should free its input memory.
- Parameters:
name – Node name
is_moe – Whether it's a MoE model
is_deepep – Whether DeepEP is used for MoE communication
- Returns:
Whether to free input memory
- Return type:
bool
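A hedged usage sketch; the node-name string below is an illustrative stand-in, not necessarily a name the scheduler actually uses, and the import path is assumed from this module's documented location:

```python
# Assumed import path for illustration.
from megatron.core.models.gpt.fine_grained_callables import should_free_input

# "moe_dispatch" is a hypothetical node name.
if should_free_input("moe_dispatch", is_moe=True, is_deepep=True):
    ...  # the scheduler may release the node's input buffer after forward
```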
- class core.models.gpt.fine_grained_callables.TransformerLayerState#
State shared within a transformer layer.
This class holds state that is shared between different nodes within a transformer layer.
- class core.models.gpt.fine_grained_callables.PreProcessNode(gpt_model, chunk_state, event, stream)#
Bases: megatron.core.pipeline_parallel.utils.ScheduleNode
Node responsible for preprocessing operations in the model.
This node handles embedding and rotary positional embedding computations before the main transformer layers.
Initialization
Initializes a preprocessing node.
- Parameters:
gpt_model – The GPT model instance.
chunk_state (TransformerChunkState) – State shared within a chunk
event – CUDA event for synchronization.
stream – CUDA stream for execution.
- forward_impl()#
Implements the forward pass for preprocessing.
This method handles:
- Decoder embedding computation
- Rotary positional embedding computation
- Sequence length offset computation for flash decoding
- Returns:
The processed decoder input tensor.
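A hedged sketch of driving this node, assuming `ScheduleNode` exposes a `forward()` wrapper around `forward_impl` (an assumption for illustration; `gpt_model` and `chunk_state` are placeholders supplied by the surrounding schedule):

```python
import torch
# Assumed import path for illustration.
from megatron.core.models.gpt.fine_grained_callables import PreProcessNode

stream = torch.cuda.Stream()
event = torch.cuda.Event()

pre = PreProcessNode(gpt_model, chunk_state, event, stream)
# Assumed ScheduleNode API: forward() runs forward_impl on the node's stream.
decoder_input = pre.forward()
```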
- class core.models.gpt.fine_grained_callables.PostProcessNode(gpt_model, chunk_state, event, stream)#
Bases: megatron.core.pipeline_parallel.utils.ScheduleNode
Node responsible for postprocessing operations in the model.
This node handles final layer normalization and output layer computation after the main transformer layers.
Initialization
Initializes a postprocessing node.
- Parameters:
gpt_model – The GPT model instance.
chunk_state (TransformerChunkState) – State shared within a chunk
event – CUDA event for synchronization.
stream – CUDA stream for execution.
- forward_impl(hidden_states)#
Implements the forward pass for postprocessing.
This method handles:
- Final layer normalization
- Output layer computation
- Loss computation if labels are provided
- Parameters:
hidden_states – The hidden states from the transformer layers.
- Returns:
The logits or loss depending on whether labels are provided.
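Continuing the sketch above (same assumed `forward()` wrapper), the postprocessing node consumes the final hidden states; what it returns depends on whether labels were supplied to the model:

```python
post = PostProcessNode(gpt_model, chunk_state, event, stream)
# hidden_states: output of the last transformer layer node in the chunk.
logits_or_loss = post.forward(hidden_states)
```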
- class core.models.gpt.fine_grained_callables.TransformerLayerNode(stream, event, layer_state, chunk_state, submodule, name='default', bwd_dw_callables=None, extra_args={})#
Bases: megatron.core.pipeline_parallel.utils.ScheduleNode
Base class for transformer layer computation nodes.
This class provides common functionality for different types of transformer layer nodes (attention, MLP, etc.)
Initialization
Initialize a transformer layer node.
- Parameters:
stream (torch.cuda.Stream) – CUDA stream for execution
event (torch.cuda.Event) – Synchronization event
layer_state (TransformerLayerState) – State shared within a layer
chunk_state (TransformerChunkState) – State shared within a chunk
submodule (function) – The submodule containing the forward and dw functions
per_batch_state_context – The per-batch state context for the first node of a layer; nullcontext otherwise
name (str) – Node name, also used to determine memory strategy
bwd_dw_callables (list) – List of weight gradient functions for the layer.
extra_args (dict) – Extra arguments for the node: is_moe, enable_deepep.
- detach(t)#
Detaches a tensor and stores it for backward computation.
- forward_impl(*args)#
Calls the submodule as the forward pass.
- backward_impl(outputs, output_grad)#
Implements the backward pass for the transformer layer node.
- backward_dw()#
Computes the weight gradients for the transformer layer node.
- _release_state()#
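The `detach`/`backward_impl` pair reflects the common graph-splitting pattern: each node cuts the autograd graph at its output boundary and later replays gradients through its own segment. A minimal, self-contained sketch of that pattern (not the class's actual code):

```python
import torch

x = torch.randn(4, 8, requires_grad=True)

# --- node A: forward, then detach at the boundary ---
y = x * 2
y_boundary = y.detach().requires_grad_(True)  # cut the autograd graph here

# --- node B: forward + backward through its own segment ---
z = (y_boundary ** 2).sum()
z.backward()                 # gradients stop at y_boundary

# --- node A: replay gradients through its own segment ---
y.backward(y_boundary.grad)  # propagates the boundary grad into x.grad
```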
- core.models.gpt.fine_grained_callables.build_transformer_layer_callables(layer: megatron.core.transformer.transformer_layer.TransformerLayer)#
Create callables for transformer layer nodes. Divides the transformer layer’s operations into a sequence of smaller, independent functions. This decomposition separates computation-heavy tasks (e.g., self-attention, MLP) from communication-heavy tasks (e.g., MoE’s All-to-All).
The five callables are:
- Attention (computation)
- Post-Attention (computation)
- MoE Dispatch (communication)
- MLP / MoE Experts (computation)
- MoE Combine (communication)
By assigning these functions to different CUDA streams (e.g., a compute stream and a communication stream), the scheduler can overlap their execution, preventing tasks from competing for resources and hiding communication latency by running them in parallel with functions from other micro-batches.
- Parameters:
layer – The transformer layer to build callables for.
- Returns:
A tuple (forward_funcs, backward_dw), where forward_funcs is the list of callable functions for the layer and backward_dw is the dict of weight gradient functions for the layer.
- Return type:
tuple
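A hedged sketch of the overlap idea described above, assuming the returned `forward_funcs` follows the five-callable split (the unpacking order is an assumption for illustration):

```python
import torch

comp_stream = torch.cuda.Stream()
comm_stream = torch.cuda.Stream()
boundary = torch.cuda.Event()

attn, post_attn, dispatch, experts, combine = forward_funcs  # assumed order

with torch.cuda.stream(comp_stream):
    h = post_attn(attn(hidden_states))
    boundary.record(comp_stream)

with torch.cuda.stream(comm_stream):
    comm_stream.wait_event(boundary)  # dispatch must see attention's output
    routed = dispatch(h)  # the all-to-all can now overlap with another
                          # micro-batch's compute on comp_stream
```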
- core.models.gpt.fine_grained_callables.build_mtp_layer_callables(layer)#
Builds the callables for multi-token prediction layer nodes.
This function creates the callable functions for the different types of multi-token prediction layer nodes (attention, MLP, etc.).
- core.models.gpt.fine_grained_callables.build_layer_callables(layer)#
Builds the callable functions (forward and dw) for the given layer. For now, 1f1b overlap only supports TransformerLayer and MultiTokenPredictionLayer.
- Parameters:
layer – The layer to build callables for.
- Returns:
A tuple (forward_funcs, backward_dw), where forward_funcs is the list of callable functions for the layer and backward_dw is the dict of weight gradient functions for the layer.
- Return type:
tuple
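A hedged sketch of the dispatch this function performs (the exact checks and error handling may differ):

```python
from megatron.core.transformer.transformer_layer import TransformerLayer

def build_layer_callables_sketch(layer):
    # 1f1b overlap currently supports only TransformerLayer and
    # MultiTokenPredictionLayer.
    if isinstance(layer, TransformerLayer):
        return build_transformer_layer_callables(layer)
    # Assumed: anything else accepted here is a MultiTokenPredictionLayer.
    return build_mtp_layer_callables(layer)
```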