core.transformer.cuda_graphs#

Module Contents#

Classes#

CudagraphBufferMetadata

Metadata saved to tensors during cudagraph capture. This data is used during graph capture to determine when a cudagraph can reuse a buffer or write its output directly into a subsequent graph’s input.

ArgMetadata

Metadata describing an argument passed to a cudagraph.

TensorReusePool

A pool-like list of tensors that can be reused as input and output buffers during graph capture. Also maintains strong references to all tensors created by this pool, so that they will never be freed by the memory allocator.

_CudagraphGlobalRecord

A global data structure that records the ordering of all _CudaGraphRunners’ first fwd or bwd passes. ‘create_cudagraphs’ uses this to create cudagraphs in execution order, which is required for cudagraphs sharing a mempool.

_GraphStatus

An Enum to track if a cudagraph is ready to perform a forward or backward pass.

_CudagraphRecordNode

Inserts a noop node into the autograd graph, used to record when a bwd graph needs to be created.

_CudagraphReplayNode

Replays the runner’s cudagraphs with autograd. Handles copying data into and out of the cudagraph I/O buffers, and FP8/FP4 if used.

_CudaGraphRunner

Represents the execution of a cudagraphed module for a single microbatch. If there are multiple outstanding microbatches per module, such as for pipeline parallelism, CudaGraphManager automatically creates multiple _CudaGraphRunners per module.

CudaGraphManager

Creates and runs cudagraphs for a megatron module.

TECudaGraphHelper

Helper class to capture CUDA Graphs using TE make_graphed_callables(). It is used at the beginning of the training loop to capture per-layer CUDA Graphs. self.create_cudagraphs() should be called to capture the CUDA Graphs, and self.cuda_graph_set_manual_hooks() should be called to set manual pre-forward hooks for the parameters that are covered by cudagraphs.

VisionTECudaGraphHelper

Helper to capture CUDA Graphs for vision encoder layers using TE.

Functions#

is_graph_capturing

Query if currently capturing.

_set_capture_start

Mark that graph capture has started.

_set_capture_end

Mark that graph capture has ended.

is_graph_warmup

Query if currently warming up for graph capture.

_set_warmup_start

Mark that graph warmup has started.

_set_warmup_end

Mark that graph warmup has ended.

tree_map

Wrapper around pytorch’s tree_map, but also recurses into dataclasses.

_check_supported_type

Check if arg meta is a supported type for cudagraph input/outputs.

_determine_if_first_last_layer_of_this_vp_chunk

Determine if the given module is the first/last layer of the PP+VPP chunk it belongs to. Returns a tuple of two booleans indicating if the module is the first/last layer of the chunk.

_clone_nested_tensors

Recursively clone tensors inside nested containers.

_ensure_generator_state_is_cudagraph_safe

Make generator state safe for CUDA graph capture/replay.

create_cudagraphs

Should be called at the end of each schedule function (e.g. forward_backward_pipelining_with_interleaving) in megatron.core.pipeline_parallel.schedules.py. During the first step, _CudaGraphRunners populate _CudagraphGlobalRecord with the global order in which cudagraphs should be created. At the end of the first step, this function calls each runner’s create_fwd_graph and create_bwd_graph in the order recorded in _CudagraphGlobalRecord. Creating cudagraphs in execution order allows multiple cudagraphs to share a single memory pool, minimizing cudagraph memory usage.

delete_cuda_graphs

Delete all CUDA graphs.

_layer_is_graphable

Check if a layer is graphable.

convert_schedule_table_to_order

Convert a tunable schedule lookup table to the te.make_graphed_callables() accepted order format. For example, the tunable schedule table for PP2 N3M5 with VP2 is as below:

virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
microbatch_id         | 0 1 2 0 1 2 3 4 3 4
model_chunk_id        | 0 0 0 1 1 1 0 0 1 1

get_overlap_moe_expert_parallel_comm_order

This function gets the order for the overlap_moe_expert_parallel_comm schedule from the original chunk-wise order list. Each chunk is transformed into chunks with only 1 layer, so that layers between 2 chunks can now overlap with each other while following the graph order. If capture_wgrad_graph is True, the wgrad backward graph is also added to the order by decreasing the layer id by 0.5.

set_current_microbatch

Set the current microbatch on all layers that use TE CUDA graph replay.

_wrap_graph_for_vision

Wrap a graphed callable to filter out None outputs.

get_vision_cuda_graph_seq_length

Calculate the sequence length for vision encoder CUDA graphs.

Data#

API#

core.transformer.cuda_graphs._IS_GRAPH_CAPTURING#

False

core.transformer.cuda_graphs._IS_GRAPH_WARMUP#

False

core.transformer.cuda_graphs.logger#

‘getLogger(…)’

core.transformer.cuda_graphs.FREEZE_GC#

None

core.transformer.cuda_graphs.is_graph_capturing()#

Query if currently capturing.

core.transformer.cuda_graphs._set_capture_start()#

Mark that graph capture has started.

core.transformer.cuda_graphs._set_capture_end()#

Mark that graph capture has ended.

core.transformer.cuda_graphs.is_graph_warmup()#

Query if currently warming up for graph capture.

core.transformer.cuda_graphs._set_warmup_start()#

Mark that graph warmup has started.

core.transformer.cuda_graphs._set_warmup_end()#

Mark that graph warmup has ended.

class core.transformer.cuda_graphs.CudagraphBufferMetadata#

Metadata saved to tensors during cudagraph capture. This data is used during graph capture to determine when a cudagraph can reuse a buffer or write its output directly into a subsequent graph’s input.

is_cudagraph_input: bool#

False

is_cudagraph_output: bool#

False

input_use_count: int#

0

cudagraph_reuse_ref_count: int#

0

capture_reuse_count: int#

0

fwd_cudagraph_buffer: torch.Tensor#

None

bwd_cudagraph_buffer: torch.Tensor#

None

class core.transformer.cuda_graphs.ArgMetadata(arg)#

Metadata describing an argument passed to a cudagraph.

Initialization

zeros_like()#

Reconstruct a tensor with the same properties as the meta arg.

class core.transformer.cuda_graphs.TensorReusePool#

A pool-like list of tensors that can be reused as input and output buffers during graph capture. Also maintains strong references to all tensors created by this pool, so that they will never be freed by the memory allocator.

tensor_strong_refs: list#

[]

Records the data_ptrs of buffers created by the pool, used to check whether a tensor was allocated from this pool.

tensor_strong_refs_dataptrs: set#

‘set(…)’

Buffers that have been returned to the pool and are available for reuse.

pool: list[torch.Tensor]#

[]

insert(tensor: torch.Tensor)#

Return a tensor to the pool for reuse.

owns(tensor: torch.Tensor)#

Check if a tensor was created from this pool.

get(meta: core.transformer.cuda_graphs.ArgMetadata)#

Try to get a buffer from the pool. If a matching tensor is already in the pool, it is assumed to be available and is returned. Otherwise, a new buffer is allocated.
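A torch-free sketch of this matching behaviour, with a plain (shape, dtype) tuple standing in for ArgMetadata and bytearrays standing in for tensors:

```python
# Illustrative stand-in for TensorReusePool: free buffers are keyed by a
# signature, and every allocation is kept alive via a strong reference.
class ReusePool:
    def __init__(self):
        self.strong_refs = []   # strong refs: pool allocations are never freed
        self.free = []          # (signature, buffer) pairs available for reuse

    def insert(self, sig, buf):
        # Return a buffer to the pool so a later get() can reuse it.
        self.free.append((sig, buf))

    def get(self, sig):
        # Reuse the first free buffer with a matching signature ...
        for i, (s, buf) in enumerate(self.free):
            if s == sig:
                return self.free.pop(i)[1]
        # ... otherwise allocate a new one (placeholder for a real tensor).
        buf = bytearray(16)
        self.strong_refs.append(buf)
        return buf
```

Because the strong references are held by the pool, a buffer handed out and later re-inserted is the same object, never a fresh allocation.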

core.transformer.cuda_graphs.tree_map(func, tree)#

Wrapper around pytorch’s tree_map, but also recurses into dataclasses.
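A minimal sketch of such a wrapper (the real helper delegates container handling to PyTorch’s tree_map; here the recursion is written out for illustration):

```python
import dataclasses

def tree_map(func, tree):
    # Recurse into standard containers ...
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(func, t) for t in tree)
    if isinstance(tree, dict):
        return {k: tree_map(func, v) for k, v in tree.items()}
    # ... and also into dataclass instances, rebuilding them field by field.
    if dataclasses.is_dataclass(tree) and not isinstance(tree, type):
        return dataclasses.replace(
            tree,
            **{f.name: tree_map(func, getattr(tree, f.name))
               for f in dataclasses.fields(tree)},
        )
    return func(tree)  # leaf value
```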

core.transformer.cuda_graphs._check_supported_type(meta)#

Check if arg meta is a supported type for cudagraph input/outputs.

core.transformer.cuda_graphs._determine_if_first_last_layer_of_this_vp_chunk(base_module)#

Determine if the given module is the first/last layer of the PP+VPP chunk it belongs to. Returns a tuple of two booleans indicating if the module is the first/last layer of the chunk.
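Assuming the chunk’s ordered layer list is available, the check reduces to identity comparison against its ends (chunk_layers is a hypothetical input; the real function derives this from the module and parallel state):

```python
def first_last_of_chunk(module, chunk_layers):
    # chunk_layers: ordered list of layers in the PP+VPP chunk this module
    # belongs to (illustrative input).
    if not chunk_layers:
        return (False, False)
    return (module is chunk_layers[0], module is chunk_layers[-1])
```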

core.transformer.cuda_graphs._clone_nested_tensors(value: Any) Any#

Recursively clone tensors inside nested containers.

core.transformer.cuda_graphs._ensure_generator_state_is_cudagraph_safe(
gen: torch.Generator,
) torch.Generator#

Make generator state safe for CUDA graph capture/replay.

Generator state tensors can become inference tensors if created under torch.inference_mode(). CUDA graph capture may later attempt in-place updates on that state; this fails for inference tensors. Fix the generator in-place (preserving identity) by cloning its state outside inference mode and setting it back.

core.transformer.cuda_graphs.fwd_buffer_reuse_ref_count#

0

core.transformer.cuda_graphs.bwd_buffer_reuse_ref_count#

0

class core.transformer.cuda_graphs._CudagraphGlobalRecord#

A global data structure that records the ordering of all _CudaGraphRunners’ first fwd or bwd passes. ‘create_cudagraphs’ uses this to create cudagraphs in execution order, which is required for cudagraphs sharing a mempool.

cudagraph_created#

False

A record of fwd and bwd graph creation, populated by ‘record_fwd_graph’ and ‘record_bwd_graph’.

cudagraph_record: list[tuple]#

[]

cudagraph_inference_record: list[tuple]#

[]

A pool-like data structure to reuse input and output buffers across cudagraphs.

tensor_reuse_pool#

‘TensorReusePool(…)’

classmethod record_fwd_graph(runner, args, kwargs, out)#

Record a fwd graph to ‘cudagraph_record’.

classmethod record_bwd_graph(runner)#

Record a bwd graph to ‘cudagraph_record’.

classmethod create_cudagraphs()#

Iterate through ‘cudagraph_record’, creating graphs in the order in which they were recorded.

core.transformer.cuda_graphs.create_cudagraphs()#

Should be called at the end of each schedule function (e.g. forward_backward_pipelining_with_interleaving) in megatron.core.pipeline_parallel.schedules.py. During the first step, _CudaGraphRunners populate _CudagraphGlobalRecord with the global order in which cudagraphs should be created. At the end of the first step, this function calls each runner’s create_fwd_graph and create_bwd_graph in the order recorded in _CudagraphGlobalRecord. Creating cudagraphs in execution order allows multiple cudagraphs to share a single memory pool, minimizing cudagraph memory usage.
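The deferred-creation flow can be sketched torch-free; GlobalRecord and Runner below are illustrative stand-ins for _CudagraphGlobalRecord and _CudaGraphRunner, not the real classes:

```python
class GlobalRecord:
    """Stand-in for _CudagraphGlobalRecord."""
    record = []            # (runner, 'fwd'/'bwd') in first-step execution order
    created = False

class Runner:
    """Stand-in for _CudaGraphRunner: records on first use, replays after."""
    def __init__(self):
        self.graphs = []

    def forward(self):
        if not GlobalRecord.created:
            # First step: only record; actual capture is deferred.
            GlobalRecord.record.append((self, "fwd"))
        # After creation, this would replay the captured fwd graph instead.

def create_cudagraphs():
    # Capture graphs in recorded (execution) order, which is what lets the
    # graphs share a single memory pool.
    for runner, kind in GlobalRecord.record:
        runner.graphs.append(kind)
    GlobalRecord.record.clear()
    GlobalRecord.created = True
```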

core.transformer.cuda_graphs.delete_cuda_graphs()#

Delete all CUDA graphs.

class core.transformer.cuda_graphs._GraphStatus(*args, **kwds)#

Bases: enum.Enum

An Enum to track if a cudagraph is ready to perform a forward or backward pass.

Initialization

FWD_READY#

0

BWD_READY#

1

class core.transformer.cuda_graphs._CudagraphRecordNode#

Bases: torch.autograd.Function

Inserts a noop node into the autograd graph, used to record when a bwd graph needs to be created.

static forward(ctx, runner, inputs)#

Forward pass, does nothing but registers an autograd node.

static backward(ctx, grads)#

If this is the first bwd pass of this runner, record that a bwd graph needs to be created.

class core.transformer.cuda_graphs._CudagraphReplayNode#

Bases: torch.autograd.Function

Replays the runner’s cudagraphs with autograd. Handles copying data into and out of the cudagraph I/O buffers, and FP8/FP4 if used.

static forward(ctx, runner, is_first_microbatch, *inputs)#

Replay the forward graph of the passed runner.

static backward(ctx, *grads)#

Replay the backward graph of the passed runner.

class core.transformer.cuda_graphs._CudaGraphRunner(
base_module: megatron.core.transformer.module.MegatronModule,
mempool: int,
fwd_graph_input_args: List[Any],
fwd_graph_input_kwargs: Dict[str, Any],
func,
need_backward,
)#

Bases: torch.nn.Module

Represents the execution of a cudagraphed module for a single microbatch. If there are multiple outstanding microbatches per module, such as for pipeline parallelism, CudaGraphManager automatically creates multiple _CudaGraphRunners per module.

Initialization

Creates a _CudaGraphRunner, which holds a single pair of fwd and bwd cudagraphs. These are not created until this runner records its graph creation into ‘_CudagraphGlobalRecord’ and ‘create_cudagraphs()’ is called.

__str__()#

get_quantization_context()#

Return appropriate quantization context (FP8 or FP4) in cudagraph mode.

get_connected_params(outputs)#

Iterates through the autograd graph of ‘outputs’ and returns all connected parameters. In theory this returns all parameters that have a nonzero wgrad when computing the backward pass of ‘outputs’.

create_fwd_graph(args, kwargs, outputs=None, clone_inputs=True)#

Create a fwd cudagraph for this runner. Should be called inside ‘create_cudagraphs()’.

create_bwd_graph()#

Create a bwd cudagraph for this runner. Should be called inside ‘create_cudagraphs()’.

apply_cudagraph_record_metadata(args, kwargs, outputs)#

Attaches graph capture metadata to all passed in tensors.

record_graph_capture(args, kwargs)#

Records the data needed to create this runner’s forward cudagraph. The first pass records a graph and appends the runner to _CudagraphGlobalRecord; the actual cudagraph is created when ‘create_cudagraphs()’ is called. Subsequent passes should replay the graph.

replay_graph_capture(is_first_microbatch, args, kwargs)#

Replay the fwd cuda graph with autograd.

get_mismatch_errors(args, kwargs)#

Return list of detailed errors for mismatched cudagraph args.

get_arg_metas(args, kwargs=None)#

Replaces all passed-in tensors with ‘ArgMetadata’ and returns them as a list.

get_tensors(args, kwargs=None, check_types=True)#

Filter and flatten all tensors from args and kwargs using list comprehensions and itertools.chain for faster flattening.

to_list(x)#

Helper function to wrap an input into a list.

class core.transformer.cuda_graphs.CudaGraphManager(
config: megatron.core.transformer.transformer_config.TransformerConfig,
base_module=None,
function_name=None,
need_backward=True,
)#

Bases: torch.nn.Module

Creates and runs cudagraphs for a megatron module.

Initialization

global_mempool#

None

call_ddp_preforward_hook(module)#

Call any DDP pre-forward hooks which are used to launch async data parallel param gather. Any other pre-forward hooks are not allowed.

get_cudagraph_runner(megatron_module, args, kwargs, reuse_cudagraphs)#

Returns a valid cudagraph runner for the current forward call. The runner corresponding to this call is the first element of ‘self.cudagraph_runners’; the list is rotated by one on each call, and the number of calls equals the length of ‘self.cudagraph_runners’. Otherwise, we assign a mempool per microbatch, which allows cudagraphs to be reused over different microbatches by tracking their respective fwd and bwd passes.
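The rotation can be sketched as follows; make_runner and the creation policy are simplified stand-ins for the real matching logic:

```python
def get_runner(runners, make_runner):
    # The runner for this call is the head of the list; rotate by one so the
    # next call sees the next microbatch's runner. Creating a new runner when
    # the list is empty is a simplification of the real policy, which also
    # checks argument metadata for a match.
    if not runners:
        runners.append(make_runner())
        return runners[0]
    runner = runners.pop(0)
    runners.append(runner)
    return runner
```

After as many calls as there are runners, the list has rotated back to its original order, so each microbatch consistently maps to the same runner.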

__call__(megatron_module, args, kwargs)#

Calls the forward pass of the cudagraphed module.

Parameters:
  • megatron_module (torch.nn.Module) – The megatron module to be graphed and run.

  • args (tuple) – The positional args to be passed to the module.

  • kwargs (dict) – The keyword args to be passed to the module.

core.transformer.cuda_graphs._layer_is_graphable(layer, config)#

Check if a layer is graphable.

class core.transformer.cuda_graphs.TECudaGraphHelper(
model,
config,
seq_length,
micro_batch_size,
optimizers=[],
)#

Helper class to capture CUDA Graphs using TE make_graphed_callables(). It is used at the beginning of the training loop to capture per-layer CUDA Graphs. self.create_cudagraphs() should be called to capture the CUDA Graphs, and self.cuda_graph_set_manual_hooks() should be called to set manual pre-forward hooks for the parameters that are covered by cudagraphs.

Initialization

_discover_layers()#

Discover capturable layers from the model and populate internal data structures.

graphs_created()#

Returns whether the CUDA Graphs have been created.

_get_sample_arguments(order, chunk_id_list=None)#

Generate sample arguments and keyword arguments for CUDA Graph capturing with memory-optimized buffer reuse.

This method creates static input tensors for each (layer, microbatch) pair needed by TE’s make_graphed_callables(). It optimizes memory usage by reusing input buffers across non-overlapping forward passes based on the pipeline parallel schedule. This optimization is essential for reducing peak memory during CUDA Graph capturing with many microbatches, as it allows buffers to be reused instead of allocating new ones for later microbatches.

Memory Optimization Strategy: The 1F1B (one-forward-one-backward) interleaved schedule in pipeline parallelism means that once a microbatch’s backward pass completes, its input buffers are no longer needed. This method tracks buffer lifecycle and reuses “consumed” buffers (those whose backward has completed) for new forward passes with matching tensor signatures (shape, dtype, layout).

Example schedule: [1, 1, 1, 2, 2, 2, -2, 1, -2, 1, -2, 2, -1, 2, -1, -1, -2, -2, -1, -1]
- Positive values indicate forward passes (chunk_id = value)
- Negative values indicate backward passes (chunk_id = -value)
- When processing -2 (backward of chunk 2), its buffers become available for reuse
- The next forward with matching signature can reuse those buffers
Parameters:
  • order (List[int]) – The forward/backward execution order from convert_schedule_table_to_order(). Positive integers represent forward passes (1-indexed chunk ID), negative integers represent backward passes.

  • chunk_id_list (List[Tuple[int, int]]) – The list of chunk IDs and layer IDs in the order. This is useful only when overlap_moe_expert_parallel_comm is enabled; the order maps each layer’s index to its original chunk id.

Returns:

A tuple containing:

  • sample_args: List of positional argument tuples for each (layer, microbatch). Length = num_layers * num_microbatches. Elements with the same tensor signature may share references to reduce memory allocation.

  • sample_kwargs: List of keyword argument dicts for each (layer, microbatch). Length = num_layers * num_microbatches. Elements with the same tensor signature may share references to reduce memory allocation.

Return type:

Tuple[List[Tuple], List[Dict]]

Data Structures:

  • fwd_sample_queues: Dict[chunk_id, List[Tuple[sample_keys, fwd_idx]]]. Queue of forward samples per chunk awaiting their backward pass.

  • consumed_sample_queue: Dict[sample_keys, List[fwd_idx]]. Pool of buffer indices whose backward is complete, keyed by tensor signature.

  • sample_keys: Tuple of (shape, dtype, layout) for args plus (key, shape, dtype, layout) for kwargs, used to match compatible buffers for reuse.
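The lifecycle tracking can be sketched torch-free: each backward releases its chunk’s oldest pending forward buffers into a signature-keyed pool, and later forwards with a matching signature reuse them. Signatures are plain values here, standing in for the (shape, dtype, layout) sample_keys:

```python
from collections import defaultdict

def plan_buffers(order, signature_of_chunk):
    """For each forward pass in `order`, return the buffer index it uses."""
    fwd_queues = defaultdict(list)    # chunk_id -> [(sig, buf_idx)] awaiting bwd
    consumed = defaultdict(list)      # sig -> buffer indices free for reuse
    plan, next_buf = [], 0
    for step in order:
        if step > 0:                              # forward pass of chunk `step`
            sig = signature_of_chunk[step]
            if consumed[sig]:
                idx = consumed[sig].pop(0)        # reuse a released buffer
            else:
                idx, next_buf = next_buf, next_buf + 1   # fresh allocation
            fwd_queues[step].append((sig, idx))
            plan.append(idx)
        else:                                     # backward of chunk `-step`
            sig, idx = fwd_queues[-step].pop(0)   # oldest forward completes
            consumed[sig].append(idx)
    return plan
```

On the example schedule above with a single shared signature, ten forward passes are served by six buffers instead of ten.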

_get_cuda_graph_input_data()#

Create the CUDA Graph capturing input data. The data is organized per-chunk per-microbatch per-layer.

_start_capturing()#

Start capturing CUDA Graphs.

_finish_capturing(start_time)#

Finish capturing CUDA Graphs and clean up the related state.

create_cudagraphs()#

Capture CUDA Graphs per TransformerLayer per microbatch.

cuda_graph_set_manual_hooks()#

Set CUDA Graph manual hooks for the modules that contain direct parameters and are covered by cudagraphs.

delete_cuda_graphs()#

Delete all CUDA graphs.

core.transformer.cuda_graphs.convert_schedule_table_to_order(
num_warmup_microbatches,
num_model_chunks,
schedule_table,
)#

Convert a tunable schedule lookup table to the te.make_graphed_callables() accepted order format. For example, the tunable schedule table for PP2 N3M5 with VP2 is as below:

virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
microbatch_id         | 0 1 2 0 1 2 3 4 3 4
model_chunk_id        | 0 0 0 1 1 1 0 0 1 1

Then the forward/backward separated order is:

forward  |  1  1  1  2  2  2  1  1  2  2
backward | -2 -2 -2 -1 -1 -1 -2 -2 -1 -1

If num_warmup_microbatches is 5, the output order is:

1 1 1 2 2 2 -2 1 -2 1 -2 2 -1 2 -1 -1 -2 -2 -1 -1
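The interleaving can be reproduced with a short, torch-free sketch. It takes the model_chunk_id row of the table directly, a simplification of the real signature: forwards are chunk_id + 1 per virtual microbatch, backwards mirror the chunks as -(num_model_chunks - chunk_id), and the two streams interleave one-forward-one-backward after the warmup forwards:

```python
def schedule_to_order(num_warmup, num_chunks, chunk_ids):
    # chunk_ids: model_chunk_id per virtual microbatch, e.g. the PP2 N3M5 VP2
    # row [0, 0, 0, 1, 1, 1, 0, 0, 1, 1] from the example above.
    fwd = [c + 1 for c in chunk_ids]              # forward pass order
    bwd = [-(num_chunks - c) for c in chunk_ids]  # backward mirrors the chunks
    order = fwd[:num_warmup]                      # warmup: forwards only
    b = 0
    for f in fwd[num_warmup:]:                    # steady state: 1F1B
        order += [f, bwd[b]]
        b += 1
    order += bwd[b:]                              # cooldown: remaining backwards
    return order
```

Running it on the documented PP2 N3M5 VP2 table with num_warmup_microbatches=5 reproduces the output order shown in the docstring.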

core.transformer.cuda_graphs.get_overlap_moe_expert_parallel_comm_order(
order,
num_layers_per_chunk,
capture_wgrad_graph,
)#

This function gets the order for the overlap_moe_expert_parallel_comm schedule from the original chunk-wise order list. Each chunk is transformed into chunks with only 1 layer, so that layers between 2 chunks can now overlap with each other while following the graph order. If capture_wgrad_graph is True, the wgrad backward graph is also added to the order by decreasing the layer id by 0.5.

Parameters:
  • order (List[int]) – The original chunk-wise order list. Positive values represent forward passes for chunks, negative values represent backward passes. The absolute value indicates the chunk ID (1-indexed).

  • num_layers_per_chunk (List[int]) – Number of graphable layers in each chunk. The length of this list equals the number of chunks.

  • capture_wgrad_graph (bool) – If True, weight gradient computation graphs are added to the order by appending entries with layer_id - 0.5.

Returns:

A tuple containing:

  • new_order: The layer-wise order list where each chunk is expanded to individual layers. Positive values are forward passes, negative values are backward passes. Values with a .5 suffix indicate weight gradient computations.

  • chunk_id_list: A list parallel to new_order. For forward passes, contains [chunk_id, layer_index_within_chunk]. For backward passes, contains None.

Return type:

Tuple[List[float], List[Optional[List[int]]]]

.. rubric:: Example

original_order: [1, 2, -2, 1, -1, -1]
num_layers_per_chunk: [1, 2]

capture_wgrad_graph=True:
  new_order: [1, 2, 3, 1, -3, -3.5, -2, -2.5, -1, -1.5, -1, -1.5]
  chunk_id_list: [[0, 0], [1, 0], [1, 1], [0, 0], None, None, None, None, None, None, None, None]

capture_wgrad_graph=False:
  new_order: [1, 2, 3, 1, -3, -2, -1, -1]
  chunk_id_list: [[0, 0], [1, 0], [1, 1], [0, 0], None, None, None, None]

core.transformer.cuda_graphs.set_current_microbatch(model, microbatch_id)#

Set the current microbatch on all layers that use TE CUDA graph replay.

current_microbatch is read by _te_cuda_graph_replay to select the correct graph index. This helper is called from the pipeline-parallel schedule before each forward step.

core.transformer.cuda_graphs._wrap_graph_for_vision(graph_fn)#

Wrap a graphed callable to filter out None outputs.

During make_graphed_callables warmup, vision encoder layers go through their normal forward() path which returns (output, context=None). _te_cuda_graph_replay asserts len(output) == 1 but gets 2 elements. This wrapper filters out None values so replay sees (output,) instead of (output, None).

core.transformer.cuda_graphs.get_vision_cuda_graph_seq_length(
vision_config,
default_seq_length: int = 4096,
) int#

Calculate the sequence length for vision encoder CUDA graphs.

For vision encoders, the sequence length depends on:

  • max_vision_cuda_graph_seq_length: explicit maximum (if set)

  • num_position_embeddings: maximum number of patches

  • spatial_merge_size: pooling factor that reduces sequence length

Parameters:
  • vision_config – The TransformerConfig for vision encoder

  • default_seq_length – Default sequence length if cannot be calculated

Returns:

The sequence length to use for CUDA graph capture
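A hypothetical sketch of the precedence implied by the bullet list above; the parameter names and the merge-factor arithmetic are assumptions for illustration, not the actual implementation:

```python
def vision_seq_length(max_len=None, num_pos_emb=None, merge_size=1, default=4096):
    # Assumed precedence: an explicit maximum wins; otherwise derive the
    # length from the patch count reduced by the spatial-merge pooling
    # factor; otherwise fall back to the default.
    if max_len is not None:
        return max_len
    if num_pos_emb is not None:
        return num_pos_emb // (merge_size * merge_size)
    return default
```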

class core.transformer.cuda_graphs.VisionTECudaGraphHelper(
model,
vision_config,
vision_seq_length: int,
micro_batch_size: int,
num_microbatches: int = 1,
)#

Bases: core.transformer.cuda_graphs.TECudaGraphHelper

Helper to capture CUDA Graphs for vision encoder layers using TE.

Inherits from TECudaGraphHelper and overrides only the vision-specific behaviour:

  • Layer discovery finds vision_model.decoder.layers instead of the language decoder layers.

  • num_model_chunks is always 1 (vision has no virtual pipeline stages).

  • Batch dimension is always 1 (images are concatenated along the sequence dimension).

  • Sample argument generation uses a simple loop (no rotary embeddings or buffer-reuse optimization).

  • Captured graph outputs are wrapped to filter None values that arise from vision encoder layers returning (output, None).

Parameters:
  • model – The full model (list of model chunks) containing vision_model.

  • vision_config – TransformerConfig for the vision encoder.

  • vision_seq_length – Sequence length for vision (max vision tokens).

  • micro_batch_size – Micro-batch size (unused for sample-arg generation since the vision encoder always uses batch-dim = 1).

  • num_microbatches – Number of microbatches per step.

Initialization

_discover_layers()#

Discover capturable layers from the vision encoder.

_start_capturing()#

Start capturing for vision encoder.

Unlike the parent, this skips torch.distributed.barrier() because with PP > 1 only the first pipeline stage has vision layers — other ranks return early from create_cudagraphs and never reach this point, so a barrier would deadlock.

_finish_capturing(start_time)#

Finish capturing for vision encoder.

Unlike the parent, this skips:

  • torch.distributed.barrier() (asymmetric: only first PP stage captures).

  • model_chunk.zero_grad_buffer() / optimizer.zero_grad() (handled by the LM decoder helper’s _finish_capturing which runs on all ranks).

  • clear_aux_losses_tracker / reset_model_temporary_tensors (LM-specific cleanup already handled by the LM helper).

_get_sample_arguments(order, chunk_id_list=None)#

Generate sample arguments for vision encoder CUDA Graph capturing.

Vision uses a simple per-layer-per-microbatch loop with batch_dim=1 and no rotary embeddings (unlike the parent’s buffer-reuse optimization). The order and chunk_id_list arguments are unused because vision has num_model_chunks=1 and does not need the pipeline-schedule-aware buffer lifecycle tracking.

Returns:

Tuple of (sample_args, sample_kwargs) lists for each (layer, microbatch) pair.

create_cudagraphs()#

Capture CUDA Graphs for vision encoder layers per microbatch.

Delegates to the parent’s capture workflow, then wraps the captured graphs with _wrap_graph_for_vision to filter None from (output, None) tuples so that _te_cuda_graph_replay’s len == 1 assertion passes.

cuda_graph_set_manual_hooks()#

No-op: vision encoder layers do not use DDP parameter-gather hooks.

The parent derives hooks from model_chunk._make_forward_pre_hook which requires overlap_param_gather=True. Vision encoder parameters are not distributed with the same overlap strategy, so we skip hook setup.