core.models.gpt.fine_grained_callables#

Module Contents#

Classes#

TransformerLayerState

State shared within a transformer layer.

PreProcessNode

Node responsible for preprocessing operations in the model.

PostProcessNode

Node responsible for postprocessing operations in the model.

TransformerLayerNode

Base class for transformer layer computation nodes.

_BackwardDWWrapper

Wrapper for managing backward weight gradient computation of the attention module.

Functions#

weak_method

Creates a weak reference to a method to prevent circular references.

should_free_input

Determine if the node should free its input memory.

build_transformer_layer_callables

Create callables for transformer layer nodes. Divides the transformer layer’s operations into a sequence of smaller, independent functions. This decomposition separates computation-heavy tasks (e.g., self-attention, MLP) from communication-heavy tasks (e.g., MoE’s All-to-All).

build_mtp_layer_callables

Creates callables for multi-token prediction layer nodes.

build_layer_callables

Builds the callable functions (forward and dw) for the given layer. For now, 1F1B overlap only supports TransformerLayer and MultiTokenPredictionLayer.

API#

core.models.gpt.fine_grained_callables.weak_method(method)#

Creates a weak reference to a method to prevent circular references.

This function creates a weak reference to a method and returns a wrapper function that calls the method when invoked. This helps prevent memory leaks from circular references.
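
A minimal sketch of this pattern, built on the standard-library weakref.WeakMethod; the module's real implementation may differ in details:

```python
import weakref


def weak_method(method):
    """Wrap a bound method in a weak reference so that holding the wrapper
    does not keep the owning object (and its reference cycle) alive."""
    ref = weakref.WeakMethod(method)

    def wrapped(*args, **kwargs):
        bound = ref()
        if bound is None:
            raise ReferenceError("bound object has been garbage collected")
        return bound(*args, **kwargs)

    return wrapped
```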

core.models.gpt.fine_grained_callables.should_free_input(name, is_moe, config)#

Determine if the node should free its input memory.

Parameters:
  • name – Node name

  • is_moe – Whether it’s a MoE model

  • config – TransformerConfig object

Returns:

Whether to free input memory

Return type:

bool
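
Hypothetical usage sketch; the node name "mlp" is a placeholder, and the minimal TransformerConfig arguments are illustrative rather than a recommended configuration:

```python
from megatron.core.models.gpt.fine_grained_callables import should_free_input
from megatron.core.transformer import TransformerConfig

config = TransformerConfig(num_layers=1, hidden_size=64, num_attention_heads=4)

if should_free_input("mlp", is_moe=True, config=config):
    # The scheduler can release this node's input activations right after its
    # forward pass to lower peak activation memory.
    pass
```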

class core.models.gpt.fine_grained_callables.TransformerLayerState#

State shared within a transformer layer.

This class holds state that is shared between different nodes within a transformer layer.

class core.models.gpt.fine_grained_callables.PreProcessNode(gpt_model, chunk_state, event, stream)#

Bases: megatron.core.pipeline_parallel.utils.ScheduleNode

Node responsible for preprocessing operations in the model.

This node handles embedding and rotary positional embedding computations before the main transformer layers.

Initialization

Initializes a preprocessing node.

Parameters:
  • gpt_model – The GPT model instance.

  • chunk_state (TransformerChunkState) – State shared within a chunk

  • event – CUDA event for synchronization.

  • stream – CUDA stream for execution.

forward_impl()#

Implements the forward pass for preprocessing.

This method handles:

  1. Decoder embedding computation

  2. Rotary positional embedding computation

  3. Sequence length offset computation for flash decoding

Returns:

The processed decoder input tensor.
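
A toy analogue (not the module's real code) of the kind of work this node front-loads: token embedding plus rotary position angles that are cached in chunk-level state for the transformer layers to reuse:

```python
import torch
import torch.nn as nn


class ToyChunkState:
    rotary_pos_emb = None  # written by preprocessing, read by later layers


def toy_preprocess(input_ids, embedding: nn.Embedding, rotary_dim: int, state):
    decoder_input = embedding(input_ids)                       # 1. embedding
    positions = torch.arange(input_ids.size(1), dtype=torch.float32)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
    state.rotary_pos_emb = torch.outer(positions, inv_freq)    # 2. rotary angles
    # 3. flash-decoding sequence-length offsets are omitted in this toy version
    return decoder_input


state = ToyChunkState()
out = toy_preprocess(torch.randint(0, 100, (2, 8)), nn.Embedding(100, 32), 16, state)
```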

class core.models.gpt.fine_grained_callables.PostProcessNode(gpt_model, chunk_state, event, stream)#

Bases: megatron.core.pipeline_parallel.utils.ScheduleNode

Node responsible for postprocessing operations in the model.

This node handles final layer normalization and output layer computation after the main transformer layers.

Initialization

Initializes a postprocessing node.

Parameters:
  • gpt_model – The GPT model instance.

  • chunk_state (TransformerChunkState) – State shared within a chunk

  • event – CUDA event for synchronization.

  • stream – CUDA stream for execution.

forward_impl(hidden_states)#

Implements the forward pass for postprocessing.

This method handles:

  1. Output layer computation

  2. Loss computation if labels are provided

Parameters:

hidden_states – The hidden states from the transformer layers.

Returns:

The logits or loss depending on whether labels are provided.
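
A toy analogue (not the module's real code) of the two steps above: project hidden states through an output layer, then return logits or, when labels are supplied, a loss:

```python
import torch
import torch.nn.functional as F


def toy_postprocess(hidden_states, output_weight, labels=None):
    logits = hidden_states @ output_weight.t()       # 1. output layer projection
    if labels is None:
        return logits
    return F.cross_entropy(                          # 2. loss when labels are given
        logits.view(-1, logits.size(-1)), labels.view(-1)
    )


loss = toy_postprocess(torch.randn(2, 8, 32), torch.randn(100, 32),
                       labels=torch.randint(0, 100, (2, 8)))
```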

class core.models.gpt.fine_grained_callables.TransformerLayerNode(
stream,
event,
layer_state,
chunk_state,
submodule,
name='default',
bwd_dw_callables=None,
extra_args={},
)#

Bases: megatron.core.pipeline_parallel.utils.ScheduleNode

Base class for transformer layer computation nodes.

This class provides common functionality for different types of transformer layer nodes (attention, MLP, etc.)

Initialization

Initialize a transformer layer node.

Parameters:
  • stream (torch.cuda.Stream) – CUDA stream for execution

  • event (torch.cuda.Event) – Synchronization event

  • layer_state (TransformerLayerState) – State shared within a layer

  • chunk_state (TransformerChunkState) – State shared within a chunk

  • submodule (function) – The submodule containing the forward and dw functions.

  • name (str) – Node name, also used to determine memory strategy

  • bwd_dw_callables (list) – List of weight gradient functions for the layer.

  • extra_args (dict) – Extra arguments for the node: is_moe, config.

detach(t)#

Detaches a tensor and stores it for backward computation.
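
A minimal sketch of the detach-and-record pattern such a node can use to cut the autograd graph at node boundaries; this is an illustration, not the class's actual implementation:

```python
import torch


def detach_for_backward(t: torch.Tensor, record: list) -> torch.Tensor:
    """Cut the autograd graph at t, keeping the (original, detached) pair so a
    later manual backward pass can bridge gradients across the cut."""
    detached = t.detach()
    detached.requires_grad_(t.requires_grad)
    record.append((t, detached))
    return detached
```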

forward_impl(*args)#

Calls the submodule as the forward pass.

backward_impl(outputs, output_grad)#

Implements the backward pass for the transformer layer node.

backward_dw()#

Computes the weight gradients for the transformer layer node.

__del__()#

class core.models.gpt.fine_grained_callables._BackwardDWWrapper(layer)#

Wrapper for managing backward weight gradient computation of the attention module.

This class handles the execution of weight gradient computations for transformer layers, coordinating between CUDA graphed and non-graphed components. It is used when overlap_moe_expert_parallel_comm and delay_wgrad_compute are enabled to manage the delayed weight gradient computation in MoE models.

The wrapper stores references to the attention and shared expert backward weight gradient callables, and determines which components should be executed based on whether CUDA graphs are being replayed and which scopes are covered by the graphs.

Initialization

backward_dw()#

Execute weight gradients, skipping CUDA graphed components during replay.

set_graphed_backward_dw_callable(graphed_backward_dw_callable)#

Store the CUDA graphed backward weight gradient callable.
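
A toy analogue of the skip-when-graphed behaviour described above; the attribute names and the shared-expert callable are assumptions for illustration, not the wrapper's real fields:

```python
class ToyBackwardDWWrapper:
    def __init__(self, attn_dw, shared_expert_dw=None):
        self.attn_dw = attn_dw                  # attention wgrad callable
        self.shared_expert_dw = shared_expert_dw
        self.graphed_dw = None                  # set once CUDA graphs capture the scope

    def set_graphed_backward_dw_callable(self, graphed_dw):
        self.graphed_dw = graphed_dw

    def backward_dw(self):
        if self.graphed_dw is None:
            # No CUDA graph covers this scope: run the attention wgrad directly.
            self.attn_dw()
        # Otherwise the attention wgrad is replayed inside the captured graph and
        # is skipped here; only work outside the graphed scope still runs.
        if self.shared_expert_dw is not None:
            self.shared_expert_dw()
```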

core.models.gpt.fine_grained_callables.build_transformer_layer_callables(
layer: megatron.core.transformer.transformer_layer.TransformerLayer,
)#

Create callables for transformer layer nodes. Divides the transformer layer’s operations into a sequence of smaller, independent functions. This decomposition separates computation-heavy tasks (e.g., self-attention, MLP) from communication-heavy tasks (e.g., MoE’s All-to-All).

The five callables are:

  1. Attention (computation)

  2. Post-Attention (computation)

  3. MoE Dispatch (communication)

  4. MLP / MoE Experts (computation)

  5. MoE Combine (communication)

By assigning these functions to different CUDA streams (e.g., a compute stream and a communication stream), the scheduler can overlap their execution, preventing tasks from competing for resources and hiding communication latency by running them in parallel with functions from other micro-batches.

Parameters:

layer – The transformer layer to build callables for.

Returns:

  • forward_funcs: List of callable functions for the layer

  • backward_dw: Dict of weight gradient functions for the layer

Return type:

A tuple containing
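
A hypothetical usage sketch, assuming an already constructed TransformerLayer instance named layer and the module path shown in this page's headings:

```python
from megatron.core.models.gpt.fine_grained_callables import (
    build_transformer_layer_callables,
)

forward_funcs, backward_dw = build_transformer_layer_callables(layer)
# forward_funcs covers, in order: attention, post-attention, MoE dispatch,
# MLP / MoE experts, and MoE combine. A scheduler can place the computation
# entries on a compute stream and the dispatch/combine entries on a
# communication stream so that micro-batches overlap.
```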

core.models.gpt.fine_grained_callables.build_mtp_layer_callables(layer)#

Creates callables for multi-token prediction layer nodes.

This function builds the callable functions for the different types of multi-token prediction layer nodes (attention, MLP, etc.).

core.models.gpt.fine_grained_callables.build_layer_callables(layer)#

Builds the callable functions (forward and dw) for the given layer. For now, 1F1B overlap only supports TransformerLayer and MultiTokenPredictionLayer.

Parameters:

layer – The layer to build callables for.

Returns:

  • forward_funcs: List of callable functions for the layer

  • backward_dw: Dict of weight gradient functions for the layer

Return type:

A tuple containing
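
An illustrative-only sketch of the dispatch this implies; the real function's checks and error handling may differ, and the name-based check for MultiTokenPredictionLayer is used here only to avoid assuming its import path:

```python
from megatron.core.models.gpt.fine_grained_callables import (
    build_mtp_layer_callables,
    build_transformer_layer_callables,
)
from megatron.core.transformer.transformer_layer import TransformerLayer


def build_layer_callables_sketch(layer):
    if isinstance(layer, TransformerLayer):
        return build_transformer_layer_callables(layer)
    if type(layer).__name__ == "MultiTokenPredictionLayer":
        return build_mtp_layer_callables(layer)
    raise NotImplementedError(
        f"1F1B overlap does not support layer type {type(layer).__name__}"
    )
```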