core.transformer.cuda_graphs#
Module Contents#
Classes#
| Class | Description |
|---|---|
| CudagraphBufferMetadata | Metadata saved to tensors during cudagraph capture. This data is used during graph capture to determine when a cudagraph can reuse a buffer or write its output directly into a subsequent graph's input. |
| ArgMetadata | Metadata describing an argument to a cudagraphed call. |
| TensorReusePool | A pool-like list of tensors that can be reused as input and output buffers during graph capture. Also maintains strong references to all tensors created by this pool, so that they are never freed by the memory allocator. |
| _CudagraphGlobalRecord | A global data structure that records the ordering of all _CudaGraphRunners' first fwd and bwd passes. 'create_cudagraphs' uses this to create cudagraphs in execution order, which is required for cudagraphs sharing a mempool. |
| _GraphStatus | An Enum to track if a cudagraph is ready to perform a forward or backward pass. |
| _CudagraphRecordNode | Inserts a noop node into the autograd graph, used to record when a bwd graph needs to be created. |
| _CudagraphReplayNode | Replays the runner's cudagraphs with autograd. Handles copying data into/out of the cudagraph I/O and fp8/fp4 if used. |
| _CudaGraphRunner | Represents the execution of a cudagraphed module for a single microbatch. If there are multiple outstanding microbatches per module, such as for pipeline parallelism, CudaGraphManager automatically creates multiple _CudaGraphRunners per module. |
| CudaGraphManager | Creates and runs cudagraphs for a megatron module. |
| TECudaGraphHelper | Helper class to capture CUDA Graphs using TE make_graphed_callables(). Used at the beginning of the training loop to capture per-layer CUDA Graphs. |
Functions#
| Function | Description |
|---|---|
| is_graph_capturing | Query if currently capturing. |
| _set_capture_start | Set graph capture has started. |
| _set_capture_end | Set graph capture has ended. |
| is_graph_warmup | Query if currently warming up for graph capture. |
| _set_warmup_start | Set graph warmup has started. |
| _set_warmup_end | Set graph warmup has ended. |
| tree_map | Wrapper around pytorch's tree_map, but also recurses into dataclasses. |
| _check_supported_type | Check if arg meta is a supported type for cudagraph inputs/outputs. |
| _determine_if_first_last_layer_of_this_vp_chunk | Determine if the given module is the first/last layer of the PP+VPP chunk it belongs to. Returns a tuple of two booleans indicating if the module is the first/last layer of the chunk. |
| _clone_nested_tensors | Recursively clone tensors inside nested containers. |
| _ensure_generator_state_is_cudagraph_safe | Make generator state safe for CUDA graph capture/replay. |
| create_cudagraphs | Should be called at the end of each schedule function (e.g. forward_backward_pipelining_with_interleaving) in megatron.core.pipeline_parallel.schedules.py. |
| delete_cuda_graphs | Delete all CUDA graphs. |
| _layer_is_graphable | Check if a layer is graphable. |
| convert_schedule_table_to_order | Convert a tunable schedule lookup table to the te.make_graphed_callables() accepted order format. |
| get_overlap_moe_expert_parallel_comm_order | Get the order for the overlap_moe_expert_parallel_comm schedule from the original chunk-wise order list. |
Data#
API#
- core.transformer.cuda_graphs._IS_GRAPH_CAPTURING#
False
- core.transformer.cuda_graphs._IS_GRAPH_WARMUP#
False
- core.transformer.cuda_graphs.logger#
‘getLogger(…)’
- core.transformer.cuda_graphs.FREEZE_GC#
None
- core.transformer.cuda_graphs.is_graph_capturing()#
Query if currently capturing.
- core.transformer.cuda_graphs._set_capture_start()#
Set graph capture has started.
- core.transformer.cuda_graphs._set_capture_end()#
Set graph capture has ended.
- core.transformer.cuda_graphs.is_graph_warmup()#
Query if currently warming up for graph capture.
- core.transformer.cuda_graphs._set_warmup_start()#
Set graph warmup has started.
- core.transformer.cuda_graphs._set_warmup_end()#
Set graph warmup has ended.
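A minimal sketch of how these capture-state helpers plausibly fit together, assuming they simply toggle the module-level booleans listed under Data above; the warmup helpers follow the same pattern. This is illustrative, not the verbatim implementation.

```python
# Minimal sketch of the module-level capture flag and its accessors.
_IS_GRAPH_CAPTURING = False

def is_graph_capturing():
    """Query if currently capturing."""
    return _IS_GRAPH_CAPTURING

def _set_capture_start():
    """Set graph capture has started."""
    global _IS_GRAPH_CAPTURING
    _IS_GRAPH_CAPTURING = True

def _set_capture_end():
    """Set graph capture has ended."""
    global _IS_GRAPH_CAPTURING
    _IS_GRAPH_CAPTURING = False
```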
- class core.transformer.cuda_graphs.CudagraphBufferMetadata#
Metadata saved to tensors during cudagraph capture. This data is used during graph capture to determine when a cudagraph can reuse a buffer or write its output directly into a subsequent graph's input.
- is_cudagraph_input: bool#
False
- is_cudagraph_output: bool#
False
- input_use_count: int#
0
- cudagraph_reuse_ref_count: int#
0
- capture_reuse_count: int#
0
- fwd_cudagraph_buffer: torch.Tensor#
None
- bwd_cudagraph_buffer: torch.Tensor#
None
- class core.transformer.cuda_graphs.ArgMetadata(arg)#
Metadata describing an argument to a cudagraphed call, used in place of the real tensor when recording and matching cudagraph inputs/outputs.
Initialization
- zeros_like()#
Reconstruct a tensor with the same properties as the meta arg.
- class core.transformer.cuda_graphs.TensorReusePool#
A pool-like list of tensors that can be reused as input and output buffers during graph capture. Also maintains strong references to all tensors created by this pool, so that they will never be freed by the memory allocator.
- tensor_strong_refs: list#
[]
Records the data_ptrs of buffers created by the pool, used to check whether a tensor was allocated from this pool.
- tensor_strong_refs_dataptrs: set#
‘set(…)’
Buffers that have been returned to the pool and are available for reuse.
- pool: list[torch.Tensor]#
[]
- insert(tensor: torch.Tensor)#
Return a tensor to the pool for reuse.
- owns(tensor: torch.Tensor)#
Check if a tensor was created from this pool.
- get(meta: core.transformer.cuda_graphs.ArgMetadata)#
Try to get a buffer from the pool. If a matching tensor is already in the pool, it is assumed to be available and is returned. Otherwise, a new buffer is allocated.
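A rough sketch of the pool behaviour described above. The attribute names follow the documentation; the matching logic and the assumed ArgMetadata fields (shape, dtype, device) are illustrative, not the real implementation.

```python
import torch

class TensorReusePoolSketch:
    """Illustrative, simplified version of the documented pool (not the real class)."""

    def __init__(self):
        self.tensor_strong_refs = []            # keep every buffer allocated here alive
        self.tensor_strong_refs_dataptrs = set()
        self.pool = []                          # buffers currently available for reuse

    def insert(self, tensor):
        # Return a tensor to the pool for reuse.
        self.pool.append(tensor)

    def owns(self, tensor):
        # A tensor belongs to the pool if its storage was allocated by this pool.
        return tensor.data_ptr() in self.tensor_strong_refs_dataptrs

    def get(self, meta):
        # Reuse a matching available buffer if possible, otherwise allocate a new one.
        # `meta.shape/dtype/device` are assumed ArgMetadata fields; `zeros_like()` is
        # the documented ArgMetadata method for reconstructing a tensor.
        for i, t in enumerate(self.pool):
            if t.shape == meta.shape and t.dtype == meta.dtype and t.device == meta.device:
                return self.pool.pop(i)
        buf = meta.zeros_like()
        self.tensor_strong_refs.append(buf)
        self.tensor_strong_refs_dataptrs.add(buf.data_ptr())
        return buf
```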
- core.transformer.cuda_graphs.tree_map(func, tree)#
Wrapper around pytorch’s tree_map, but also recurses into dataclasses.
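A plausible sketch of such a wrapper, assuming "pytorch's tree_map" refers to torch.utils._pytree.tree_map and that dataclass fields are rebuilt with dataclasses.replace; the real implementation may differ.

```python
import dataclasses
from torch.utils._pytree import tree_map as _torch_tree_map

def tree_map_sketch(func, tree):
    """Apply `func` to leaves, recursing into dataclasses as well as pytree containers."""
    if dataclasses.is_dataclass(tree) and not isinstance(tree, type):
        # Recurse into dataclass fields, producing a new instance of the same type.
        return dataclasses.replace(
            tree,
            **{f.name: tree_map_sketch(func, getattr(tree, f.name))
               for f in dataclasses.fields(tree)},
        )
    # Dataclasses nested inside lists/dicts/tuples appear as pytree leaves, so recurse there too.
    return _torch_tree_map(
        lambda leaf: tree_map_sketch(func, leaf) if dataclasses.is_dataclass(leaf) else func(leaf),
        tree,
    )
```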
- core.transformer.cuda_graphs._check_supported_type(meta)#
Check if arg meta is a supported type for cudagraph input/outputs.
- core.transformer.cuda_graphs._determine_if_first_last_layer_of_this_vp_chunk(base_module)#
Determine if the given module is the first/last layer of the PP+VPP chunk it belongs to. Returns a tuple of two booleans indicating if the module is the first/last layer of the chunk.
- core.transformer.cuda_graphs._clone_nested_tensors(value: Any) -> Any#
Recursively clone tensors inside nested containers.
- core.transformer.cuda_graphs._ensure_generator_state_is_cudagraph_safe(
- gen: torch.Generator,
Make generator state safe for CUDA graph capture/replay.
Generator state tensors can become inference tensors if created under
torch.inference_mode(). CUDA graph capture may later attempt in-place updates on that state; this fails for inference tensors. Fix the generator in-place (preserving identity) by cloning its state outside inference mode and setting it back.
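A minimal sketch of the fix described above, assuming the inference state is detected via Tensor.is_inference() and repaired by cloning outside inference mode; treat it as an illustration of the technique rather than the actual code.

```python
import torch

def _ensure_generator_state_is_cudagraph_safe_sketch(gen: torch.Generator) -> None:
    state = gen.get_state()
    if state.is_inference():
        # Clone outside inference mode so the new state tensor supports in-place
        # updates during graph capture, then write it back into the same generator
        # (the generator object's identity is preserved).
        with torch.inference_mode(mode=False):
            gen.set_state(state.clone())
```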
- core.transformer.cuda_graphs.fwd_buffer_reuse_ref_count#
0
- core.transformer.cuda_graphs.bwd_buffer_reuse_ref_count#
0
- class core.transformer.cuda_graphs._CudagraphGlobalRecord#
A global data structure that records the ordering of all _CudaGraphRunners' first fwd and bwd passes. 'create_cudagraphs' uses this to create cudagraphs in execution order, which is required for cudagraphs sharing a mempool.
- cudagraph_created#
False
A record of fwd and bwd graph creation, populated by 'record_fwd_graph' and 'record_bwd_graph'.
- cudagraph_record: list[tuple]#
[]
- cudagraph_inference_record: list[tuple]#
[]
A pool-like data structure to reuse input and output buffers across cudagraphs.
- tensor_reuse_pool#
‘TensorReusePool(…)’
- classmethod record_fwd_graph(runner, args, kwargs, out)#
Record a fwd graph to 'cudagraph_record'.
- classmethod record_bwd_graph(runner)#
Record a bwd graph to 'cudagraph_record'.
- classmethod create_cudagraphs()#
Iterate through ‘cudagraph_record’ creating graphs in the order in which they were recorded.
- core.transformer.cuda_graphs.create_cudagraphs()#
Should be called at the end of each schedule function (e.g. forward_backward_pipelining_with_interleaving) in
megatron.core.pipeline_parallel.schedules.py. During the first step, _CudaGraphRunners populate _CudagraphGlobalRecord with the global order in which cudagraphs should be created. At the end of the first step, this function calls each runner's create_fwd_graph and create_bwd_graph in the order recorded in _CudagraphGlobalRecord. Creating cudagraphs in execution order allows multiple cudagraphs to share a single memory pool, minimizing cudagraph memory usage.
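A placement sketch, assuming the module is importable as megatron.core.transformer.cuda_graphs; the function body here is a stand-in, not the real schedule. The point is only that create_cudagraphs() sits at the very end of the schedule function.

```python
from megatron.core.transformer.cuda_graphs import create_cudagraphs

def forward_backward_pipelining_with_interleaving(*args, **kwargs):
    # ... run the interleaved 1F1B schedule as usual; on the first training step
    # each _CudaGraphRunner records itself into _CudagraphGlobalRecord instead of
    # capturing a graph eagerly ...
    forward_data_store = []  # placeholder for the schedule's usual bookkeeping
    # At the very end of the schedule, materialize every recorded fwd/bwd graph
    # in execution order so that all graphs can share a single memory pool.
    create_cudagraphs()
    return forward_data_store
```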
- core.transformer.cuda_graphs.delete_cuda_graphs()#
Delete all CUDA graphs.
- class core.transformer.cuda_graphs._GraphStatus(*args, **kwds)#
Bases:
enum.Enum
An Enum to track if a cudagraph is ready to perform a forward or backward pass.
Initialization
- FWD_READY#
0
- BWD_READY#
1
- class core.transformer.cuda_graphs._CudagraphRecordNode#
Bases:
torch.autograd.Function
Inserts a noop node into the autograd graph, used to record when a bwd graph needs to be created.
- static forward(ctx, runner, inputs)#
Forward pass, does nothing but registers an autograd node.
- static backward(ctx, grads)#
If this is the first bwd pass of this runner, record that a bwd graph needs to be created.
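A sketch of the noop-node pattern this class describes, matching the documented forward(ctx, runner, inputs) and backward(ctx, grads) signatures; the runner object and its bwd_graph_recorded flag are hypothetical names, and the real recording logic lives in _CudagraphGlobalRecord.

```python
import torch

class _NoopRecordNodeSketch(torch.autograd.Function):
    """Illustrative pass-through autograd node (not the actual Megatron class)."""

    @staticmethod
    def forward(ctx, runner, inputs):
        ctx.runner = runner
        return inputs                  # pass-through; only registers an autograd node

    @staticmethod
    def backward(ctx, grads):
        # On the first backward pass of this runner, note that a bwd graph is needed.
        if not getattr(ctx.runner, "bwd_graph_recorded", False):
            ctx.runner.bwd_graph_recorded = True   # hypothetical flag
        return None, grads             # no grad for `runner`; pass grads through
```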
- class core.transformer.cuda_graphs._CudagraphReplayNode#
Bases:
torch.autograd.Function
Replays the runner's cudagraphs with autograd. Handles copying data into/out of the cudagraph I/O and fp8/fp4 if used.
- static forward(ctx, runner, is_first_microbatch, *inputs)#
Replay the forward graph of the passed runner.
- static backward(ctx, *grads)#
Replay the backward graph of the passed runner.
- class core.transformer.cuda_graphs._CudaGraphRunner(
- base_module: megatron.core.transformer.module.MegatronModule,
- mempool: int,
- fwd_graph_input_args: List[Any],
- fwd_graph_input_kwargs: Dict[str, Any],
- func,
- need_backward,
Bases:
torch.nn.Module
Represents the execution of a cudagraphed module for a single microbatch. If there are multiple outstanding microbatches per module, such as for pipeline parallelism, CudaGraphManager automatically creates multiple _CudaGraphRunners per module.
Initialization
Creates a _CudaGraphRunner, which holds a single pair of fwd and bwd cudagraphs. The graphs are not created until this runner records its graph creation into '_CudagraphGlobalRecord' and 'create_cudagraphs()' is called.
- __str__()#
- get_quantization_context()#
Return appropriate quantization context (FP8 or FP4) in cudagraph mode.
- get_connected_params(outputs)#
Iterates through the autograd graph of 'outputs' and returns all connected parameters. In theory this should return all parameters that produce a nonzero wgrad when computing the backward pass of 'outputs'.
- create_fwd_graph(args, kwargs, outputs=None, clone_inputs=True)#
Create a fwd cudagraph for this runner. Should be called inside ‘create_cudagraphs()’.
- create_bwd_graph()#
Create a bwd cudagraph for this runner. Should be called inside ‘create_cudagraphs()’.
- apply_cudagraph_record_metadata(args, kwargs, outputs)#
Attaches graph capture metadata to all passed in tensors.
- record_graph_capture(args, kwargs)#
Records the data needed to create this runner's forward cudagraph. The first pass records a graph and appends the runner to _CudagraphGlobalRecord. The actual cudagraph will be created when 'create_cudagraphs()' is called. Subsequent passes should replay the graph.
- replay_graph_capture(is_first_microbatch, args, kwargs)#
Replay the fwd cuda graph with autograd.
- get_mismatch_errors(args, kwargs)#
Return list of detailed errors for mismatched cudagraph args.
- get_arg_metas(args, kwargs=None)#
Replaces all passed-in tensors with 'ArgMetadata' and returns them as a list.
- get_tensors(args, kwargs=None, check_types=True)#
Filter and flatten all tensors from args and kwargs using list comprehensions and itertools.chain for faster flattening.
- to_list(x)#
Helper function to wrap an input into a list
- class core.transformer.cuda_graphs.CudaGraphManager(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- base_module=None,
- function_name=None,
- need_backward=True,
Bases:
torch.nn.Module
Creates and runs cudagraphs for a megatron module.
Initialization
- global_mempool#
None
- call_ddp_preforward_hook(module)#
Call any DDP pre-forward hooks which are used to launch async data parallel param gather. Any other pre-forward hooks are not allowed.
- get_cudagraph_runner(megatron_module, args, kwargs, reuse_cudagraphs)#
Returns a valid cudagraph runner for the current forward call. The runner corresponding to this call is the first element of 'self.cudagraph_runners'; the list is advanced by one for each call, and the number of calls equals the length of 'self.cudagraph_runners'. Otherwise, a mempool is assigned per microbatch, which allows cudagraphs to be reused across different microbatches by tracking their respective fwd and bwd passes.
- __call__(megatron_module, args, kwargs)#
Calls the forward pass of the cudagraphed module.
- Parameters:
megatron_module (torch.nn.module) – The megatron module to be graphed and run
args (tuple) – The positional args to be passed to the module.
kwargs (dict) – The keyword args to be passed to the module.
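A hypothetical usage sketch based only on the constructor and __call__ signatures documented here; the config and layer objects are assumed to come from the surrounding Megatron setup, and the shape/dtype requirements are the usual cudagraph constraint that replayed inputs keep a fixed signature.

```python
from megatron.core.transformer.cuda_graphs import CudaGraphManager

def run_layer_with_cudagraph(layer, config, hidden_states, attention_mask=None):
    """Hypothetical wrapper: graph `layer` and run one microbatch through it.

    `layer` is assumed to be a MegatronModule (e.g. a TransformerLayer) and
    `config` its TransformerConfig.
    """
    manager = CudaGraphManager(config, base_module=layer)
    # The first call per microbatch records (and later captures) the graph;
    # subsequent training steps replay it. args/kwargs must keep the same
    # shapes and dtypes across steps for replay to be valid.
    return manager(layer, (hidden_states,), {"attention_mask": attention_mask})
```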
- core.transformer.cuda_graphs._layer_is_graphable(layer, config)#
Check if a layer is graphable.
- class core.transformer.cuda_graphs.TECudaGraphHelper(
- model,
- config,
- seq_length,
- micro_batch_size,
- optimizers=[],
Helper class to capture CUDA Graphs using TE make_graphed_callables(). It is used in the beginning of the training loop to capture per-layer CUDA Graphs.
self.create_cudagraphs() should be called to capture the CUDA Graphs, and self.cuda_graph_set_manual_hooks() should be called to set manual pre-forward hooks for the parameters that are covered by cudagraphs (see the usage sketch after the method list below).
Initialization
- graphs_created()#
Returns whether the CUDA Graphs have been created.
- _get_sample_arguments(order, chunk_id_list=None)#
Generate sample arguments and keyword arguments for CUDA Graph capturing with memory-optimized buffer reuse.
This method creates static input tensors for each (layer, microbatch) pair needed by TE’s make_graphed_callables(). It optimizes memory usage by reusing input buffers across non-overlapping forward passes based on the pipeline parallel schedule. This optimization is essential for reducing peak memory during CUDA Graph capturing with many microbatches, as it allows buffers to be reused instead of allocating new ones for later microbatches.
Memory Optimization Strategy: The 1F1B (one-forward-one-backward) interleaved schedule in pipeline parallelism means that once a microbatch’s backward pass completes, its input buffers are no longer needed. This method tracks buffer lifecycle and reuses “consumed” buffers (those whose backward has completed) for new forward passes with matching tensor signatures (shape, dtype, layout).
Example schedule: [1, 1, 1, 2, 2, 2, -2, 1, -2, 1, -2, 2, -1, 2, -1, -1, -2, -2, -1, -1]
- Positive values indicate forward passes (chunk_id = value)
- Negative values indicate backward passes (chunk_id = -value)
- When processing -2 (backward of chunk 2), its buffers become available for reuse
- The next forward with a matching signature can reuse those buffers
- Parameters:
order (List[int]) – The forward/backward execution order from convert_schedule_table_to_order(). Positive integers represent forward passes (1-indexed chunk ID), negative integers represent backward passes.
chunk_id_list (List[Tuple[int, int]]) – The list of chunk IDs and layer IDs in the order. This is useful only when overlap_moe_expert_parallel_comm is enabled, the order maps each layers’ idx to their original chunk id.
- Returns:
A tuple containing:
- sample_args: List of positional argument tuples for each (layer, microbatch). Length = num_layers * num_microbatches. Elements with the same tensor signature may share references to reduce memory allocation.
- sample_kwargs: List of keyword argument dicts for each (layer, microbatch). Length = num_layers * num_microbatches. Elements with the same tensor signature may share references to reduce memory allocation.
- Return type:
Tuple[List[Tuple], List[Dict]]
Data Structures:
- fwd_sample_queues: Dict[chunk_id, List[Tuple[sample_keys, fwd_idx]]]. Queue of forward samples per chunk awaiting their backward pass.
- consumed_sample_queue: Dict[sample_keys, List[fwd_idx]]. Pool of buffer indices whose backward is complete, keyed by tensor signature.
- sample_keys: Tuple of (shape, dtype, layout) for args + (key, shape, dtype, layout) for kwargs, used to match compatible buffers for reuse.
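An illustrative sketch of the bookkeeping described above, reducing the tensor signature to a stand-in key; it is not the actual method, only a demonstration of how consumed buffers can be recycled for later forward passes in the documented 1F1B order.

```python
from collections import defaultdict

def assign_sample_buffers(order):
    """Toy version of the buffer-reuse strategy: map each forward pass to a buffer index."""
    fwd_sample_queues = defaultdict(list)      # chunk_id -> [(sample_key, buf_idx), ...]
    consumed_sample_queue = defaultdict(list)  # sample_key -> [buf_idx, ...] free for reuse
    assignments, next_idx = [], 0
    for step in order:
        chunk_id = abs(step)
        if step > 0:
            # Forward pass: reuse a consumed buffer with a matching signature if available.
            key = ("signature", chunk_id)      # stand-in for the (shape, dtype, layout) signature
            if consumed_sample_queue[key]:
                idx = consumed_sample_queue[key].pop(0)
            else:
                idx, next_idx = next_idx, next_idx + 1
            fwd_sample_queues[chunk_id].append((key, idx))
            assignments.append(idx)
        else:
            # Backward pass: the oldest outstanding forward of this chunk is now consumed.
            key, idx = fwd_sample_queues[chunk_id].pop(0)
            consumed_sample_queue[key].append(idx)
    return assignments

# With the example schedule above, later forwards reuse buffers freed by completed
# backwards (indices 3 and 4 appear twice) instead of allocating new ones.
example = [1, 1, 1, 2, 2, 2, -2, 1, -2, 1, -2, 2, -1, 2, -1, -1, -2, -2, -1, -1]
print(assign_sample_buffers(example))  # [0, 1, 2, 3, 4, 5, 6, 7, 3, 4]
```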
- _get_cuda_graph_input_data()#
Create the CUDA Graph capturing input data. The data is organized per-chunk per-microbatch per-layer.
- _start_capturing()#
Start capturing CUDA Graphs.
- _finish_capturing(start_time)#
Finish capturing CUDA Graphs and clean up the related state.
- create_cudagraphs()#
Capture CUDA Graphs per TransformerLayer per microbatch.
- cuda_graph_set_manual_hooks()#
Set CUDA Graph manual hooks for the modules that contain direct parameters and are covered by cudagraphs.
- delete_cuda_graphs()#
Delete all CUDA graphs.
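A hypothetical end-to-end sketch of the workflow the class description refers to, based only on the constructor and method names documented here; the model, config, and optimizer objects are assumed to come from the surrounding training setup.

```python
from megatron.core.transformer.cuda_graphs import TECudaGraphHelper

def capture_layer_graphs(model, config, optimizer, seq_length, micro_batch_size):
    """Capture per-layer CUDA Graphs at the start of the training loop (sketch)."""
    helper = TECudaGraphHelper(
        model=model,
        config=config,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        optimizers=[optimizer],
    )
    # Capture per-layer graphs via TE make_graphed_callables().
    helper.create_cudagraphs()
    # Set manual pre-forward hooks for parameters covered by the captured graphs.
    helper.cuda_graph_set_manual_hooks()
    return helper
```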
- core.transformer.cuda_graphs.convert_schedule_table_to_order(
- num_warmup_microbatches,
- num_model_chunks,
- schedule_table,
Convert a tunable schedule lookup table to the te.make_graphed_callables() accepted order format. For example, the tunable schedule table for PP2 N3M5 with VP2 is as below:
- virtual_microbatch_id: 0 1 2 3 4 5 6 7 8 9
- microbatch_id: 0 1 2 0 1 2 3 4 3 4
- model_chunk_id: 0 0 0 1 1 1 0 0 1 1
Then the forward/backward separated order is:
- forward: 1 1 1 2 2 2 1 1 2 2
- backward: -2 -2 -2 -1 -1 -1 -2 -2 -1 -1
If num_warmup_microbatches is 5, the output order is: 1 1 1 2 2 2 -2 1 -2 1 -2 2 -1 2 -1 -1 -2 -2 -1 -1
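A hedged illustration of the example above. The exact structure of schedule_table is an assumption here (one (microbatch_id, model_chunk_id) pair per virtual microbatch, read off the table), so treat the call as a sketch rather than the definitive API.

```python
# Hypothetical illustration of the docstring example (PP2 N3M5, VP2).
schedule_table = [
    (0, 0), (1, 0), (2, 0), (0, 1), (1, 1),
    (2, 1), (3, 0), (4, 0), (3, 1), (4, 1),
]
order = convert_schedule_table_to_order(
    num_warmup_microbatches=5,
    num_model_chunks=2,
    schedule_table=schedule_table,
)
# Expected per the docstring:
# [1, 1, 1, 2, 2, 2, -2, 1, -2, 1, -2, 2, -1, 2, -1, -1, -2, -2, -1, -1]
```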
- core.transformer.cuda_graphs.get_overlap_moe_expert_parallel_comm_order(
- order,
- num_layers_per_chunk,
- capture_wgrad_graph,
This function gets the order for the overlap_moe_expert_parallel_comm schedule from the original chunk-wise order list. Each chunk is transformed into chunks with only 1 layer, so that layers from 2 different chunks can overlap with each other while following the graph order. If capture_wgrad_graph is True, the wgrad backward graph is also added to the order by decreasing the layer id by 0.5.
- Parameters:
order (List[int]) – The original chunk-wise order list. Positive values represent forward passes for chunks, negative values represent backward passes. The absolute value indicates the chunk ID (1-indexed).
num_layers_per_chunk (List[int]) – Number of graphable layers in each chunk. The length of this list equals the number of chunks.
capture_wgrad_graph (bool) – If True, weight gradient computation graphs are added to the order by appending entries with layer_id - 0.5.
- Returns:
A tuple containing:
- new_order: The layer-wise order list where each chunk is expanded to individual layers. Positive values are forward passes, negative values are backward passes. Values with a .5 suffix indicate weight gradient computations.
- chunk_id_list: A list parallel to new_order. For forward passes, contains [chunk_id, layer_index_within_chunk]. For backward passes, contains None.
- Return type:
Tuple[List[float], List[Optional[List[int]]]]
.. rubric:: Example
original_order: [1, 2, -2, 1, -1, -1], num_layers_per_chunk: [1, 2]
- capture_wgrad_graph=True:
  - new_order: [1, 2, 3, 1, -3, -3.5, -2, -2.5, -1, -1.5, -1, -1.5]
  - chunk_id_list: [[0, 0], [1, 0], [1, 1], [0, 0], None, None, None, None, None, None, None, None]
- capture_wgrad_graph=False:
  - new_order: [1, 2, 3, 1, -3, -2, -1, -1]
  - chunk_id_list: [[0, 0], [1, 0], [1, 1], [0, 0], None, None, None, None]
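The same example written as a call, for readability; it simply restates the documented inputs and outputs.

```python
new_order, chunk_id_list = get_overlap_moe_expert_parallel_comm_order(
    order=[1, 2, -2, 1, -1, -1],
    num_layers_per_chunk=[1, 2],
    capture_wgrad_graph=True,
)
# Per the example above:
# new_order     == [1, 2, 3, 1, -3, -3.5, -2, -2.5, -1, -1.5, -1, -1.5]
# chunk_id_list == [[0, 0], [1, 0], [1, 1], [0, 0]] + [None] * 8
```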