core.transformer.cuda_graphs#

Module Contents#

Classes#

CudagraphBufferMetadata

Metadata saved to tensors during cudagraph capture. This data is used during graph capture to determine when a cudagraph can reuse a buffer or write its output directly into a subsequent graph's input.

ArgMetadata

Metadata describing a single module argument.

TensorReusePool

A pool-like list of tensors that can be reused as input and output buffers during graph capture. Also maintains strong references to all tensors created by this pool, so that they will never be freed by the memory allocator.

_CudagraphGlobalRecord

A global data structure that records the ordering of all _CudaGraphRunners' first fwd or bwd passes. 'create_cudagraphs' uses this record to create cudagraphs in execution order, which is required for cudagraphs sharing a mempool.

_GraphStatus

An Enum to track if a cudagraph is ready to perform a forward or backward pass.

_CudagraphRecordNode

Inserts a noop node into the autograd graph, used to record when a bwd graph needs to be created.

_CudagraphReplayNode

Replays the runner's cudagraphs with autograd. Handles copying data into and out of the cudagraph I/O buffers, and FP8/FP4 handling if used.

_CudaGraphRunner

Represents the execution of a cudagraphed module for a single microbatch. If there are multiple outstanding microbatches per module, such as for pipeline parallelism, CudaGraphManager automatically creates multiple _CudaGraphRunners per module.

CudaGraphManager

Creates and runs cudagraphs for a Megatron module.

TECudaGraphHelper

Helper class to capture CUDA Graphs using TE make_graphed_callables(). It is used at the beginning of the training loop to capture per-layer CUDA Graphs. self.create_cudagraphs() should be called to capture the CUDA Graphs, and self.cuda_graph_set_manual_hooks() should be called to set manual pre-forward hooks for the parameters that are covered by cudagraphs.

Functions#

is_graph_capturing

Query if currently capturing.

_set_capture_start

Set graph capture has started.

_set_capture_end

Set graph capture has ended.

is_graph_warmup

Query if currently warming up for graph capture.

_set_warmup_start

Set graph warmup has started.

_set_warmup_end

Set graph warmup has ended.

tree_map

Wrapper around PyTorch's tree_map that also recurses into dataclasses.

_check_supported_type

Check if arg meta is a supported type for cudagraph input/outputs.

_determine_if_first_last_layer_of_this_vp_chunk

Determine if the given module is the first/last layer of the PP+VPP chunk it belongs to. Returns a tuple of two booleans indicating if the module is the first/last layer of the chunk.

_clone_nested_tensors

Recursively clone tensors inside nested containers.

_ensure_generator_state_is_cudagraph_safe

Make generator state safe for CUDA graph capture/replay.

create_cudagraphs

Should be called at the end of each schedule function (e.g. forward_backward_pipelining_with_interleaving in megatron.core.pipeline_parallel.schedules.py). During the first step, _CudaGraphRunners populate _CudagraphGlobalRecord with the global order in which cudagraphs should be created. At the end of the first step, this function calls each runner's create_fwd_graph and create_bwd_graph in the order recorded in _CudagraphGlobalRecord. Creating cudagraphs in execution order allows multiple cudagraphs to share a single memory pool, minimizing cudagraph memory usage.

delete_cuda_graphs

Delete all CUDA graphs.

_layer_is_graphable

Check if a layer is graphable.

convert_schedule_table_to_order

Convert a tunable schedule lookup table to the te.make_graphed_callables() accepted order format. For example, the tunable schedule table for PP2 N3M5 with VP2 is as below:

virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
microbatch_id         | 0 1 2 0 1 2 3 4 3 4
model_chunk_id        | 0 0 0 1 1 1 0 0 1 1

get_overlap_moe_expert_parallel_comm_order

This function computes the order for the overlap_moe_expert_parallel_comm schedule from the original chunk-wise order list. Each chunk is transformed into chunks of only one layer, so that layers from two chunks can overlap with each other while following the graph order. If capture_wgrad_graph is True, the wgrad backward graph is also added to the order by decreasing the layer id by 0.5.

Data#

API#

core.transformer.cuda_graphs._IS_GRAPH_CAPTURING#

False

core.transformer.cuda_graphs._IS_GRAPH_WARMUP#

False

core.transformer.cuda_graphs.logger#

‘getLogger(…)’

core.transformer.cuda_graphs.FREEZE_GC#

None

core.transformer.cuda_graphs.is_graph_capturing()#

Query if currently capturing.

core.transformer.cuda_graphs._set_capture_start()#

Set graph capture has started.

core.transformer.cuda_graphs._set_capture_end()#

Set graph capture has ended.

core.transformer.cuda_graphs.is_graph_warmup()#

Query if currently warming up for graph capture.

core.transformer.cuda_graphs._set_warmup_start()#

Set graph warmup has started.

core.transformer.cuda_graphs._set_warmup_end()#

Set graph warmup has ended.

class core.transformer.cuda_graphs.CudagraphBufferMetadata#

Metadata saved to tensors during cudagraph capture. This data is used during graph capture to determine when a cudagraph can reuse a buffer or write its output directly into a subsequent graph's input.

is_cudagraph_input: bool#

False

is_cudagraph_output: bool#

False

input_use_count: int#

0

cudagraph_reuse_ref_count: int#

0

capture_reuse_count: int#

0

fwd_cudagraph_buffer: torch.Tensor#

None

bwd_cudagraph_buffer: torch.Tensor#

None

class core.transformer.cuda_graphs.ArgMetadata(arg)#

Metadata describing a single module argument.

Initialization

zeros_like()#

Reconstruct a tensor with the same properties as the meta arg.
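A minimal sketch of what ArgMetadata plausibly captures and how zeros_like() could rebuild a placeholder tensor; the field names here are assumptions for illustration, not the actual implementation:

```python
import torch

# Hypothetical sketch: field names and exact behavior are assumptions.
class ArgMetadataSketch:
    def __init__(self, arg):
        self.type = type(arg)
        if isinstance(arg, torch.Tensor):
            self.shape = arg.shape
            self.dtype = arg.dtype
            self.device = arg.device
        else:
            self.value = arg  # non-tensor args are stored by value

    def zeros_like(self):
        # Reconstruct a tensor with the same properties as the meta arg.
        assert issubclass(self.type, torch.Tensor)
        return torch.zeros(self.shape, dtype=self.dtype, device=self.device)
```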

class core.transformer.cuda_graphs.TensorReusePool#

A pool-like list of tensors that can be reused as input and output buffers during graph capture. Also maintains strong references to all tensors created by this pool, so that they will never be freed by the memory allocator.

tensor_strong_refs: list#

[]

Records the data_ptrs of buffers created by the pool, used to check whether a tensor was allocated from this pool.

tensor_strong_refs_dataptrs: set#

‘set(…)’

Buffers that have been returned to the pool and are available for reuse.

pool: list[torch.Tensor]#

[]

insert(tensor: torch.Tensor)#

Return a tensor to the pool for reuse.

owns(tensor: torch.Tensor)#

Check if a tensor was created from this pool.

get(meta: core.transformer.cuda_graphs.ArgMetadata)#

Try to get a buffer from the pool. If a matching tensor is already in the pool, it's assumed to be available and is returned. Otherwise, allocate a new buffer.
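A hedged usage sketch of the get/insert/owns flow described above; the buffer shape and dtype are illustrative:

```python
import torch

pool = TensorReusePool()
meta = ArgMetadata(torch.empty(2, 1024, device="cuda", dtype=torch.bfloat16))

buf = pool.get(meta)   # allocates a new buffer on first use
pool.insert(buf)       # return it to the pool for reuse
assert pool.owns(buf)  # the pool keeps a strong reference and its data_ptr
buf2 = pool.get(meta)  # a matching request can now reuse `buf`
```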

core.transformer.cuda_graphs.tree_map(func, tree)#

Wrapper around PyTorch's tree_map that also recurses into dataclasses.
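A minimal sketch of how such a wrapper might be written, assuming torch.utils._pytree treats dataclasses as leaves; this is an illustration, not the actual Megatron implementation:

```python
import dataclasses
from torch.utils import _pytree

def tree_map_sketch(func, tree):
    def visit(leaf):
        if dataclasses.is_dataclass(leaf) and not isinstance(leaf, type):
            # Rebuild the dataclass, mapping each field recursively.
            return dataclasses.replace(
                leaf,
                **{f.name: tree_map_sketch(func, getattr(leaf, f.name))
                   for f in dataclasses.fields(leaf)},
            )
        return func(leaf)
    return _pytree.tree_map(visit, tree)
```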

core.transformer.cuda_graphs._check_supported_type(meta)#

Check if arg meta is a supported type for cudagraph input/outputs.

core.transformer.cuda_graphs._determine_if_first_last_layer_of_this_vp_chunk(base_module)#

Determine if the given module is the first/last layer of the PP+VPP chunk it belongs to. Returns a tuple of two booleans indicating if the module is the first/last layer of the chunk.

core.transformer.cuda_graphs._clone_nested_tensors(value: Any) Any#

Recursively clone tensors inside nested containers.
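Given a dataclass-aware tree_map like the sketch above, a plausible implementation (an assumption, not the actual code) is a one-liner:

```python
import torch

def clone_nested_tensors_sketch(value):
    # Clone tensors; pass every other leaf through unchanged.
    return tree_map_sketch(
        lambda x: x.clone() if isinstance(x, torch.Tensor) else x, value
    )
```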

core.transformer.cuda_graphs._ensure_generator_state_is_cudagraph_safe(
gen: torch.Generator,
) torch.Generator#

Make generator state safe for CUDA graph capture/replay.

Generator state tensors can become inference tensors if created under torch.inference_mode(). CUDA graph capture may later attempt in-place updates on that state; this fails for inference tensors. Fix the generator in-place (preserving identity) by cloning its state outside inference mode and setting it back.
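A hedged sketch of the fix the docstring describes: clone the state outside inference mode and set it back on the same generator object, preserving identity:

```python
import torch

def ensure_generator_state_safe_sketch(gen: torch.Generator) -> torch.Generator:
    state = gen.get_state()
    if state.is_inference():
        # Cloning outside inference mode yields a normal (mutable) tensor;
        # set_state keeps the generator object itself unchanged.
        with torch.inference_mode(False):
            gen.set_state(state.clone())
    return gen
```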

core.transformer.cuda_graphs.fwd_buffer_reuse_ref_count#

0

core.transformer.cuda_graphs.bwd_buffer_reuse_ref_count#

0

class core.transformer.cuda_graphs._CudagraphGlobalRecord#

A global data structure that records the ordering of all _CudaGraphRunners' first fwd or bwd passes. 'create_cudagraphs' uses this record to create cudagraphs in execution order, which is required for cudagraphs sharing a mempool.

cudagraph_created#

False

A record of fwd and bwd graph creation, populated by 'record_fwd_graph' and 'record_bwd_graph'.

cudagraph_record: list[tuple]#

[]

cudagraph_inference_record: list[tuple]#

[]

A pool-like data structure to reuse input and output buffers across cudagraphs.

tensor_reuse_pool#

‘TensorReusePool(…)’

classmethod record_fwd_graph(runner, args, kwargs, out)#

Record a fwd graph to 'cudagraph_record'.

classmethod record_bwd_graph(runner)#

Record a bwd graph to 'cudagraph_record'.

classmethod create_cudagraphs()#

Iterate through ‘cudagraph_record’ creating graphs in the order in which they were recorded.

core.transformer.cuda_graphs.create_cudagraphs()#

Should be called at the end of each schedule function (e.g. forward_backward_pipelining_with_interleaving in megatron.core.pipeline_parallel.schedules.py). During the first step, _CudaGraphRunners populate _CudagraphGlobalRecord with the global order in which cudagraphs should be created. At the end of the first step, this function calls each runner's create_fwd_graph and create_bwd_graph in the order recorded in _CudagraphGlobalRecord. Creating cudagraphs in execution order allows multiple cudagraphs to share a single memory pool, minimizing cudagraph memory usage.
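A hedged sketch of the call site this docstring describes; the schedule internals are elided and the function name is illustrative:

```python
from megatron.core.transformer.cuda_graphs import create_cudagraphs

def run_schedule(model, data_iterator, num_microbatches):
    # ... forward/backward passes for all microbatches; during the first
    # step, each _CudaGraphRunner records itself into _CudagraphGlobalRecord ...
    create_cudagraphs()  # builds all recorded fwd/bwd graphs in execution order
```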

core.transformer.cuda_graphs.delete_cuda_graphs()#

Delete all CUDA graphs.

class core.transformer.cuda_graphs._GraphStatus(*args, **kwds)#

Bases: enum.Enum

An Enum to track if a cudagraph is ready to perform a forward or backward pass.

Initialization

FWD_READY#

0

BWD_READY#

1

class core.transformer.cuda_graphs._CudagraphRecordNode#

Bases: torch.autograd.Function

Inserts a noop node into the autograd graph, used to record when a bwd graph needs to be created.

static forward(ctx, runner, inputs)#

Forward pass, does nothing but registers an autograd node.

static backward(ctx, grads)#

If this is the first bwd pass of this runner, record that a bwd graph needs to be created.
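For intuition, a noop autograd node of this shape might look like the following sketch; the runner flag name is hypothetical and the real recording logic lives in _CudagraphGlobalRecord:

```python
import torch

class RecordNodeSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, runner, inputs):
        ctx.runner = runner
        return inputs  # identity pass-through; exists only to hook backward

    @staticmethod
    def backward(ctx, grads):
        if not getattr(ctx.runner, "bwd_graph_recorded", False):
            # First bwd pass of this runner: record that a bwd graph is needed.
            ctx.runner.bwd_graph_recorded = True  # hypothetical flag
        return None, grads  # no grad for the runner argument
```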

class core.transformer.cuda_graphs._CudagraphReplayNode#

Bases: torch.autograd.Function

Replays the runner's cudagraphs with autograd. Handles copying data into and out of the cudagraph I/O buffers, and FP8/FP4 handling if used.

static forward(ctx, runner, is_first_microbatch, *inputs)#

Replay the forward graph of the passed runner.

static backward(ctx, *grads)#

Replay the backward graph of the passed runner.
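The standard replay pattern (copy live tensors into the graph's static inputs, replay, hand back the static outputs) might be sketched as follows; all runner attribute names here are assumptions:

```python
import torch

class ReplayNodeSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, runner, is_first_microbatch, *inputs):
        ctx.runner = runner
        for static, live in zip(runner.static_fwd_inputs, inputs):  # hypothetical
            static.copy_(live)
        runner.fwd_graph.replay()
        return tuple(runner.static_fwd_outputs)                     # hypothetical

    @staticmethod
    def backward(ctx, *grads):
        runner = ctx.runner
        for static, live in zip(runner.static_bwd_inputs, grads):   # hypothetical
            static.copy_(live)
        runner.bwd_graph.replay()
        # None for the runner and is_first_microbatch arguments.
        return (None, None, *runner.static_bwd_outputs)
```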

class core.transformer.cuda_graphs._CudaGraphRunner(
base_module: megatron.core.transformer.module.MegatronModule,
mempool: int,
fwd_graph_input_args: List[Any],
fwd_graph_input_kwargs: Dict[str, Any],
func,
need_backward,
)#

Bases: torch.nn.Module

Represents the execution of a cudagraphed module for a single microbatch. If there are multiple outstanding microbatches per module, such as for pipeline parallelism, CudaGraphManager automatically creates multiple _CudaGraphRunners per module.

Initialization

Creates a _CudaGraphRunner, which holds a single pair of fwd and bwd cudagraphs. These are not created until this runner records its graph creation into '_CudagraphGlobalRecord' and 'create_cudagraphs()' is called.

__str__()#

get_quantization_context()#

Return appropriate quantization context (FP8 or FP4) in cudagraph mode.

get_connected_params(outputs)#

Iterate through the autograd graph of 'outputs' and return all connected parameters. In theory this should return all parameters that have a nonzero wgrad when computing the backward pass of 'outputs'.

create_fwd_graph(args, kwargs, outputs=None, clone_inputs=True)#

Create a fwd cudagraph for this runner. Should be called inside ‘create_cudagraphs()’.

create_bwd_graph()#

Create a bwd cudagraph for this runner. Should be called inside ‘create_cudagraphs()’.

apply_cudagraph_record_metadata(args, kwargs, outputs)#

Attaches graph capture metadata to all passed in tensors.

record_graph_capture(args, kwargs)#

Records the data needed to create this runner's forward cudagraph. The first pass records a graph and appends the runner to _CudagraphGlobalRecord; the actual cudagraph is created when 'create_cudagraphs()' is called. Subsequent passes should replay the graph.

replay_graph_capture(is_first_microbatch, args, kwargs)#

Replay the fwd cuda graph with autograd.

get_mismatch_errors(args, kwargs)#

Return list of detailed errors for mismatched cudagraph args.

get_arg_metas(args, kwargs=None)#

Replaces all passed-in tensors with 'ArgMetadata' and returns them as a list.

get_tensors(args, kwargs=None, check_types=True)#

Filter and flatten all tensors from args and kwargs using list comprehensions and itertools.chain for faster flattening.

to_list(x)#

Helper function to wrap an input into a list.

class core.transformer.cuda_graphs.CudaGraphManager(
config: megatron.core.transformer.transformer_config.TransformerConfig,
base_module=None,
function_name=None,
need_backward=True,
)#

Bases: torch.nn.Module

Creates and runs cudagraphs for a megatron module

Initialization

global_mempool#

None

call_ddp_preforward_hook(module)#

Call any DDP pre-forward hooks which are used to launch async data parallel param gather. Any other pre-forward hooks are not allowed.

get_cudagraph_runner(megatron_module, args, kwargs, reuse_cudagraphs)#

Returns a valid cudagraph runner for the current forward call. The runner corresponding to this call is the first element of 'self.cudagraph_runners'; we advance through the list by one for each call, and the number of calls equals the length of 'self.cudagraph_runners'. Otherwise, a mempool is assigned per microbatch, which allows cudagraphs to be reused across different microbatches by tracking their respective fwd and bwd passes.

__call__(megatron_module, args, kwargs)#

Calls the forward pass of the cudagraphed module.

Parameters:
  • megatron_module (torch.nn.Module) – The Megatron module to be graphed and run.

  • args (tuple) – The positional args to be passed to the module.

  • kwargs (dict) – The keyword args to be passed to the module.
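A hedged sketch of how a module's forward might be routed through the manager; the wiring is an assumption for illustration, not the actual Megatron integration:

```python
import torch

class GraphedLayer(torch.nn.Module):
    def __init__(self, layer, config):
        super().__init__()
        self.layer = layer
        self.cudagraph_manager = CudaGraphManager(config, base_module=layer)

    def forward(self, *args, **kwargs):
        # The manager captures on the first step and replays afterwards.
        return self.cudagraph_manager(self.layer, args, kwargs)
```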

core.transformer.cuda_graphs._layer_is_graphable(layer, config)#

Check if a layer is graphable.

class core.transformer.cuda_graphs.TECudaGraphHelper(
model,
config,
seq_length,
micro_batch_size,
optimizers=[],
)#

Helper class to capture CUDA Graphs using TE make_graphed_callables(). It is used at the beginning of the training loop to capture per-layer CUDA Graphs. self.create_cudagraphs() should be called to capture the CUDA Graphs, and self.cuda_graph_set_manual_hooks() should be called to set manual pre-forward hooks for the parameters that are covered by cudagraphs (see the usage sketch after the method list below).

Initialization

graphs_created()#

Returns whether the CUDA Graphs have been created.

_get_sample_arguments(order, chunk_id_list=None)#

Generate sample arguments and keyword arguments for CUDA Graph capturing with memory-optimized buffer reuse.

This method creates static input tensors for each (layer, microbatch) pair needed by TE’s make_graphed_callables(). It optimizes memory usage by reusing input buffers across non-overlapping forward passes based on the pipeline parallel schedule. This optimization is essential for reducing peak memory during CUDA Graph capturing with many microbatches, as it allows buffers to be reused instead of allocating new ones for later microbatches.

Memory Optimization Strategy: The 1F1B (one-forward-one-backward) interleaved schedule in pipeline parallelism means that once a microbatch’s backward pass completes, its input buffers are no longer needed. This method tracks buffer lifecycle and reuses “consumed” buffers (those whose backward has completed) for new forward passes with matching tensor signatures (shape, dtype, layout).

Example schedule: [1, 1, 1, 2, 2, 2, -2, 1, -2, 1, -2, 2, -1, 2, -1, -1, -2, -2, -1, -1]
- Positive values indicate forward passes (chunk_id = value)
- Negative values indicate backward passes (chunk_id = -value)
- When processing -2 (backward of chunk 2), its buffers become available for reuse
- The next forward with matching signature can reuse those buffers
Parameters:
  • order (List[int]) – The forward/backward execution order from convert_schedule_table_to_order(). Positive integers represent forward passes (1-indexed chunk ID), negative integers represent backward passes.

  • chunk_id_list (List[Tuple[int, int]]) – The list of chunk IDs and layer IDs in the order. This is useful only when overlap_moe_expert_parallel_comm is enabled, in which case the order maps each layer's index to its original chunk id.

Returns:

A tuple containing:

  • sample_args: List of positional argument tuples for each (layer, microbatch). Length = num_layers * num_microbatches. Elements with the same tensor signature may share references to reduce memory allocation.

  • sample_kwargs: List of keyword argument dicts for each (layer, microbatch). Length = num_layers * num_microbatches. Elements with the same tensor signature may share references to reduce memory allocation.

Return type:

Tuple[List[Tuple], List[Dict]]

Data Structures:

  • fwd_sample_queues: Dict[chunk_id, List[Tuple[sample_keys, fwd_idx]]] – Queue of forward samples per chunk awaiting their backward pass.

  • consumed_sample_queue: Dict[sample_keys, List[fwd_idx]] – Pool of buffer indices whose backward is complete, keyed by tensor signature.

  • sample_keys: Tuple of (shape, dtype, layout) for args + (key, shape, dtype, layout) for kwargs, used to match compatible buffers for reuse.
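A simplified sketch of the reuse bookkeeping these structures imply; signatures are reduced to a single hashable key per forward pass, whereas the real code also tracks kwargs and per-layer detail:

```python
from collections import defaultdict

def plan_buffer_reuse(order, signatures):
    """order: fwd(+chunk_id)/bwd(-chunk_id) steps; signatures[i]: key for forward i."""
    fwd_sample_queues = defaultdict(list)      # chunk_id -> [(key, fwd_idx), ...]
    consumed_sample_queue = defaultdict(list)  # key -> reusable fwd indices
    reuse = {}                                 # fwd_idx -> earlier fwd_idx it reuses
    fwd_idx = 0
    for step in order:
        if step > 0:                           # forward pass of chunk `step`
            key = signatures[fwd_idx]
            if consumed_sample_queue[key]:
                # A buffer with a matching signature is free: reuse it.
                reuse[fwd_idx] = consumed_sample_queue[key].pop(0)
            fwd_sample_queues[step].append((key, fwd_idx))
            fwd_idx += 1
        else:                                  # backward frees that chunk's oldest fwd
            key, idx = fwd_sample_queues[-step].pop(0)
            consumed_sample_queue[key].append(idx)
    return reuse
```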

_get_cuda_graph_input_data()#

Create the CUDA Graph capturing input data. The data is organized per-chunk per-microbatch per-layer.

_start_capturing()#

Start capturing CUDA Graphs.

_finish_capturing(start_time)#

Finish capturing CUDA Graphs and clean up the related state.

create_cudagraphs()#

Capture CUDA Graphs per TransformerLayer per microbatch.

cuda_graph_set_manual_hooks()#

Set CUDA Graph manual hooks for the modules that contain direct parameters and are covered by cudagraphs.

delete_cuda_graphs()#

Delete all CUDA graphs.
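A hedged usage sketch of the helper's prescribed flow, following the call order stated in the class docstring above:

```python
helper = TECudaGraphHelper(model, config, seq_length, micro_batch_size)
helper.create_cudagraphs()            # capture per-layer CUDA Graphs via TE
helper.cuda_graph_set_manual_hooks()  # re-arm param pre-forward hooks
assert helper.graphs_created()
# ... training loop ...
helper.delete_cuda_graphs()           # tear down when done
```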

core.transformer.cuda_graphs.convert_schedule_table_to_order(
num_warmup_microbatches,
num_model_chunks,
schedule_table,
)#

Convert a tunable schedule lookup table to the te.make_graphed_callables() accepted order format. For example, the tunable schedule table for PP2 N3M5 with VP2 is as below:

virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
microbatch_id         | 0 1 2 0 1 2 3 4 3 4
model_chunk_id        | 0 0 0 1 1 1 0 0 1 1

Then the forward/backward separated order is:

forward  |  1  1  1  2  2  2  1  1  2  2
backward | -2 -2 -2 -1 -1 -1 -2 -2 -1 -1

If num_warmup_microbatches is 5, the output order is: 1 1 1 2 2 2 -2 1 -2 1 -2 2 -1 2 -1 -1 -2 -2 -1 -1
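A hedged call sketch: the schedule_table encoding below (a list of (microbatch_id, model_chunk_id) pairs indexed by virtual microbatch) is an assumption inferred from the table above, while the expected output is taken from this docstring:

```python
# Assumed encoding: schedule_table[virtual_microbatch_id] = (microbatch_id, chunk_id)
schedule_table = [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1),
                  (2, 1), (3, 0), (4, 0), (3, 1), (4, 1)]
order = convert_schedule_table_to_order(
    num_warmup_microbatches=5, num_model_chunks=2, schedule_table=schedule_table
)
# Per the docstring:
# order == [1, 1, 1, 2, 2, 2, -2, 1, -2, 1, -2, 2, -1, 2, -1, -1, -2, -2, -1, -1]
```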

core.transformer.cuda_graphs.get_overlap_moe_expert_parallel_comm_order(
order,
num_layers_per_chunk,
capture_wgrad_graph,
)#

This function computes the order for the overlap_moe_expert_parallel_comm schedule from the original chunk-wise order list. Each chunk is transformed into chunks of only one layer, so that layers from two chunks can overlap with each other while following the graph order. If capture_wgrad_graph is True, the wgrad backward graph is also added to the order by decreasing the layer id by 0.5.

Parameters:
  • order (List[int]) – The original chunk-wise order list. Positive values represent forward passes for chunks, negative values represent backward passes. The absolute value indicates the chunk ID (1-indexed).

  • num_layers_per_chunk (List[int]) – Number of graphable layers in each chunk. The length of this list equals the number of chunks.

  • capture_wgrad_graph (bool) – If True, weight gradient computation graphs are added to the order by appending entries with layer_id - 0.5.

Returns:

A tuple containing:

  • new_order: The layer-wise order list where each chunk is expanded to individual layers. Positive values are forward passes, negative values are backward passes. Values with a .5 suffix indicate weight gradient computations.

  • chunk_id_list: A list parallel to new_order. For forward passes, contains [chunk_id, layer_index_within_chunk]. For backward passes, contains None.

Return type:

Tuple[List[float], List[Optional[List[int]]]]

Example:

original_order: [1, 2, -2, 1, -1, -1]
num_layers_per_chunk: [1, 2]

capture_wgrad_graph=True:
  new_order: [1, 2, 3, 1, -3, -3.5, -2, -2.5, -1, -1.5, -1, -1.5]
  chunk_id_list: [[0, 0], [1, 0], [1, 1], [0, 0], None, None, None, None, None, None, None, None]

capture_wgrad_graph=False:
  new_order: [1, 2, 3, 1, -3, -2, -1, -1]
  chunk_id_list: [[0, 0], [1, 0], [1, 1], [0, 0], None, None, None, None]
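The same example as a direct call, with input and output values taken verbatim from the docstring:

```python
new_order, chunk_id_list = get_overlap_moe_expert_parallel_comm_order(
    order=[1, 2, -2, 1, -1, -1],
    num_layers_per_chunk=[1, 2],
    capture_wgrad_graph=True,
)
# new_order     == [1, 2, 3, 1, -3, -3.5, -2, -2.5, -1, -1.5, -1, -1.5]
# chunk_id_list == [[0, 0], [1, 0], [1, 1], [0, 0],
#                   None, None, None, None, None, None, None, None]
```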