core.models.gpt.fine_grained_callables#

Module Contents#

Classes#

TransformerLayerState

State shared within a transformer layer.

PreProcessNode

Node responsible for preprocessing operations in the model.

PostProcessNode

Node responsible for postprocessing operations in the model.

TransformerLayerNode

Base class for transformer layer computation nodes.

Functions#

weak_method

Creates a weak reference to a method to prevent circular references.

should_free_input

Determine if the node should free its input memory.

build_transformer_layer_callables

Create callables for transformer layer nodes. Divides the transformer layer’s operations into a sequence of smaller, independent functions. This decomposition separates computation-heavy tasks (e.g., self-attention, MLP) from communication-heavy tasks (e.g., MoE’s All-to-All).

build_mtp_layer_callables

Builds callables for multi-token prediction layer nodes.

build_layer_callables

Builds the callable functions (forward and dw) for the given layer. Currently, 1f1b overlap supports only TransformerLayer and MultiTokenPredictionLayer.

API#

core.models.gpt.fine_grained_callables.weak_method(method)#

Creates a weak reference to a method to prevent circular references.

This function creates a weak reference to a method and returns a wrapper function that calls the method when invoked. This helps prevent memory leaks from circular references.
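The pattern can be sketched with the standard library's `weakref.WeakMethod`. This is a minimal illustration of the idea, not the module's exact implementation:

```python
import weakref

def weak_method(method):
    # Minimal illustration (not the module's exact code): hold only a weak
    # reference to the bound method so the wrapper cannot keep its owning
    # object alive and form a reference cycle.
    method_ref = weakref.WeakMethod(method)

    def wrapped_func(*args, **kwargs):
        m = method_ref()  # None once the owning object has been collected
        if m is None:
            raise ReferenceError("bound object has been garbage collected")
        return m(*args, **kwargs)

    return wrapped_func

class Counter:
    def __init__(self):
        self.n = 0
    def bump(self):
        self.n += 1
        return self.n

c = Counter()
bump = weak_method(c.bump)
print(bump())  # → 1
```

Because `wrapped_func` holds only the weak reference, dropping the last strong reference to the object frees it immediately, and later calls raise `ReferenceError`.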

core.models.gpt.fine_grained_callables.should_free_input(name, is_moe, is_deepep)#

Determine if the node should free its input memory.

Parameters:
  • name – Node name

  • is_moe – Whether it’s a MoE model

  • is_deepep – Whether the DeepEP token dispatcher is used

Returns:

Whether to free input memory

Return type:

bool
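As an illustration of how such a predicate might look, the rule below is hypothetical: it frees inputs only for MoE communication nodes when DeepEP is used, and does not reproduce Megatron-Core's actual policy or node names.

```python
def should_free_input(name, is_moe, is_deepep):
    # Hypothetical policy for illustration only: free the input buffer for
    # MoE communication nodes under DeepEP, on the assumption that their
    # outputs live in freshly allocated communication buffers rather than
    # aliasing the input. The real criterion lives in Megatron-Core.
    if not (is_moe and is_deepep):
        return False
    return name in ("moe_dispatch", "moe_combine")
```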

class core.models.gpt.fine_grained_callables.TransformerLayerState#

State shared within a transformer layer.

This class holds state that is shared between different nodes within a transformer layer.

class core.models.gpt.fine_grained_callables.PreProcessNode(gpt_model, chunk_state, event, stream)#

Bases: megatron.core.pipeline_parallel.utils.ScheduleNode

Node responsible for preprocessing operations in the model.

This node handles embedding and rotary positional embedding computations before the main transformer layers.

Initialization

Initializes a preprocessing node.

Parameters:
  • gpt_model – The GPT model instance.

  • chunk_state (TransformerChunkState) – State shared within a chunk

  • event – CUDA event for synchronization.

  • stream – CUDA stream for execution.

forward_impl()#

Implements the forward pass for preprocessing.

This method handles:

  1. Decoder embedding computation

  2. Rotary positional embedding computation

  3. Sequence length offset computation for flash decoding

Returns:

The processed decoder input tensor.

class core.models.gpt.fine_grained_callables.PostProcessNode(gpt_model, chunk_state, event, stream)#

Bases: megatron.core.pipeline_parallel.utils.ScheduleNode

Node responsible for postprocessing operations in the model.

This node handles final layer normalization and output layer computation after the main transformer layers.

Initialization

Initializes a postprocessing node.

Parameters:
  • gpt_model – The GPT model instance.

  • chunk_state (TransformerChunkState) – State shared within a chunk

  • event – CUDA event for synchronization.

  • stream – CUDA stream for execution.

forward_impl(hidden_states)#

Implements the forward pass for postprocessing.

This method handles:

  1. Final layer normalization

  2. Output layer computation

  3. Loss computation if labels are provided

Parameters:

hidden_states – The hidden states from the transformer layers.

Returns:

The logits or loss depending on whether labels are provided.

class core.models.gpt.fine_grained_callables.TransformerLayerNode(
stream,
event,
layer_state,
chunk_state,
submodule,
name='default',
bwd_dw_callables=None,
extra_args={},
)#

Bases: megatron.core.pipeline_parallel.utils.ScheduleNode

Base class for transformer layer computation nodes.

This class provides common functionality for different types of transformer layer nodes (attention, MLP, etc.)

Initialization

Initialize a transformer layer node.

Parameters:
  • stream (torch.cuda.Stream) – CUDA stream for execution

  • event (torch.cuda.Event) – Synchronization event

  • layer_state (TransformerLayerState) – State shared within a layer

  • chunk_state (TransformerChunkState) – State shared within a chunk

  • submodule (function) – The submodule containing the forward and dw functions

  • name (str) – Node name, also used to determine memory strategy

  • bwd_dw_callables (list) – List of weight gradient functions for the layer.

  • extra_args (dict) – Extra arguments for the node: is_moe, enable_deepep.

detach(t)#

Detaches a tensor and stores it for backward computation.

forward_impl(*args)#

Calls the submodule as the forward pass.

backward_impl(outputs, output_grad)#

Implements the backward pass for the transformer layer node.

backward_dw()#

Computes the weight gradients for the transformer layer node.

_release_state()#
core.models.gpt.fine_grained_callables.build_transformer_layer_callables(
layer: megatron.core.transformer.transformer_layer.TransformerLayer,
)#

Create callables for transformer layer nodes. Divides the transformer layer’s operations into a sequence of smaller, independent functions. This decomposition separates computation-heavy tasks (e.g., self-attention, MLP) from communication-heavy tasks (e.g., MoE’s All-to-All).

The five callables are:

  1. Attention (computation)

  2. Post-Attention (computation)

  3. MoE Dispatch (communication)

  4. MLP / MoE Experts (computation)

  5. MoE Combine (communication)

By assigning these functions to different CUDA streams (e.g., a compute stream and a communication stream), the scheduler can overlap their execution, preventing tasks from competing for resources and hiding communication latency by running them in parallel with functions from other micro-batches.

Parameters:

layer – The transformer layer to build callables for.

Returns:

  • forward_funcs: List of callable functions for the layer

  • backward_dw: Dict of weight gradient functions for the layer

Return type:

A tuple containing
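The overlap idea can be sketched without CUDA by assigning each micro-batch's five operations to one of two work queues. The op names and the `build_schedule` helper below are illustrative stand-ins; in a real scheduler, `torch.cuda.Stream` objects take the place of the lists.

```python
# Schematic sketch of the five-way split (names are illustrative, not
# Megatron-Core API). Communication ops go to one queue, compute ops to
# another; a real scheduler would launch them on separate CUDA streams.
COMM_OPS = {"dispatch", "combine"}
LAYER_OPS = ["attn", "post_attn", "dispatch", "mlp", "combine"]

def build_schedule(num_microbatches):
    compute_stream, comm_stream = [], []
    for mb in range(num_microbatches):
        for op in LAYER_OPS:
            target = comm_stream if op in COMM_OPS else compute_stream
            target.append((mb, op))
    return compute_stream, comm_stream

compute, comm = build_schedule(2)
# While micro-batch 0's All-to-All ("dispatch") occupies the comm queue,
# the compute queue can already run micro-batch 1's attention.
```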

core.models.gpt.fine_grained_callables.build_mtp_layer_callables(layer)#

Builds callables for multi-token prediction layer nodes.

This function builds the callable functions for the different types of multi-token prediction layer nodes (attention, MLP, etc.)

core.models.gpt.fine_grained_callables.build_layer_callables(layer)#

Builds the callable functions (forward and dw) for the given layer. Currently, 1f1b overlap supports only TransformerLayer and MultiTokenPredictionLayer.

Parameters:

layer – The layer to build callables for.

Returns:

  • forward_funcs: List of callable functions for the layer

  • backward_dw: Dict of weight gradient functions for the layer

Return type:

A tuple containing
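A hedged sketch of the dispatch this function performs: pick the builder by layer type and return the `(forward_funcs, backward_dw)` pair. The classes and builder bodies below are stand-ins, not Megatron-Core imports.

```python
# Hypothetical stand-ins for the two supported layer types and their
# builders; only the dispatch structure mirrors the documented behavior.
class TransformerLayer: ...
class MultiTokenPredictionLayer: ...

def build_transformer_layer_callables(layer):
    # Placeholder: the real builder returns the five callables listed above.
    return ["attn", "post_attn", "dispatch", "mlp", "combine"], {}

def build_mtp_layer_callables(layer):
    return ["mtp_forward"], {}

def build_layer_callables(layer):
    # 1f1b overlap currently supports only these two layer types.
    if isinstance(layer, TransformerLayer):
        return build_transformer_layer_callables(layer)
    if isinstance(layer, MultiTokenPredictionLayer):
        return build_mtp_layer_callables(layer)
    raise ValueError(f"1f1b overlap does not support {type(layer).__name__}")
```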