core.transformer.cuda_graphs
Module Contents
Classes
| Class | Description |
|---|---|
| ArgMetadata | Argument metadata. |
| _CudagraphGlobalRecord | A global data structure that records the order of every _CudaGraphRunner’s first forward or backward pass. ‘create_cudagraphs’ uses this record to create cudagraphs in execution order, which is required for cudagraphs that share a memory pool. |
| _GraphStatus | An Enum that tracks whether a cudagraph is ready to perform a forward or backward pass. |
| _CudagraphRecordNode | Inserts a no-op node into the autograd graph, used to record when a backward graph needs to be created. |
| _CudagraphReplayNode | Replays the runner’s cudagraphs with autograd. Handles copying data into and out of the cudagraph I/O buffers, as well as FP8/FP4 handling if used. |
| _CudaGraphRunner | Represents the execution of a cudagraphed module for a single microbatch. If a module has multiple outstanding microbatches, as with pipeline parallelism, CudaGraphManager automatically creates multiple _CudaGraphRunners for it. |
| CudaGraphManager | Creates and runs cudagraphs for a Megatron module. |
| TECudaGraphHelper | Helper class to capture CUDA Graphs using TE make_graphed_callables(). It is used at the beginning of the training loop to capture per-layer CUDA Graphs. |
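Functions
| Function | Description |
|---|---|
| is_graph_capturing | Query whether CUDA graph capture is currently in progress. |
| _set_capture_start | Mark that graph capture has started. |
| _set_capture_end | Mark that graph capture has ended. |
| _check_supported_type | Check whether an arg meta is a supported type for cudagraph inputs/outputs. |
| _determine_if_transformer_decoder_layer | Determine whether the given module is a transformer decoder layer. |
| _determine_if_first_last_layer_of_this_vp_chunk | Determine whether the given module is the first and/or last layer of the PP+VPP chunk it belongs to. Returns a tuple of two booleans indicating whether the module is the first layer and whether it is the last layer of the chunk. |
| create_cudagraphs | Should be called at the end of each schedule function (e.g. forward_backward_pipelining_with_interleaving) in megatron.core.pipeline_parallel.schedules.py. |
| delete_cuda_graphs | Delete all CUDA graphs. |
| _layer_is_graphable | Check whether a layer is graphable. |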
Data
| Name |
|---|
| _IS_GRAPH_CAPTURING |
| logger |
| FREEZE_GC |

API
- core.transformer.cuda_graphs._IS_GRAPH_CAPTURING
False
- core.transformer.cuda_graphs.logger
‘getLogger(…)’
- core.transformer.cuda_graphs.FREEZE_GC
None
- core.transformer.cuda_graphs.is_graph_capturing()
Query whether CUDA graph capture is currently in progress.
- core.transformer.cuda_graphs._set_capture_start()
Mark that graph capture has started.
- core.transformer.cuda_graphs._set_capture_end()
Mark that graph capture has ended.
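These three functions manage a single module-level flag, documented under Data as _IS_GRAPH_CAPTURING with a default of False. The following is a minimal sketch of that pattern, assuming the flag is a plain module-level boolean; it is not necessarily how Megatron implements these functions, and the `graph_capture_scope` context manager is a hypothetical convenience that does not exist in this module.

```python
from contextlib import contextmanager

# Module-level flag, mirroring the documented default of _IS_GRAPH_CAPTURING.
_IS_GRAPH_CAPTURING = False


def is_graph_capturing():
    """Return True while a graph capture is in progress."""
    return _IS_GRAPH_CAPTURING


def _set_capture_start():
    global _IS_GRAPH_CAPTURING
    _IS_GRAPH_CAPTURING = True


def _set_capture_end():
    global _IS_GRAPH_CAPTURING
    _IS_GRAPH_CAPTURING = False


@contextmanager
def graph_capture_scope():
    """Hypothetical helper: always resets the flag, even if capture raises."""
    _set_capture_start()
    try:
        yield
    finally:
        _set_capture_end()
```

Inside `with graph_capture_scope():`, `is_graph_capturing()` returns True, and it returns False again once the block exits. The try/finally keeps the flag from staying set after a failed capture.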
- class core.transformer.cuda_graphs.ArgMetadata(arg)
Argument metadata.
Initialization
- core.transformer.cuda_graphs._check_supported_type(meta)
Check whether an arg meta is a supported type for cudagraph inputs/outputs.
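ArgMetadata and _check_supported_type are documented only briefly here. The following plain-Python sketch shows one plausible shape for such metadata. It rests on several assumptions: a stand-in `FakeTensor` class replaces torch.Tensor so the code runs without PyTorch, tensors are described by shape and dtype, non-tensor values are stored directly, and the allow-list `SUPPORTED_TYPES` is hypothetical rather than the module's actual list.

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass
class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""
    shape: Tuple[int, ...]
    dtype: str = "float32"


# Hypothetical allow-list of types a cudagraph input/output may carry.
SUPPORTED_TYPES = (FakeTensor, int, float, bool, str, type(None))


@dataclass
class ArgMeta:
    """Hypothetical arg metadata: the type plus shape/dtype or the value."""
    arg_type: type
    shape: Optional[Tuple[int, ...]] = None
    dtype: Optional[str] = None
    value: Any = None

    @classmethod
    def from_arg(cls, arg):
        if isinstance(arg, FakeTensor):
            # Tensors are described by layout only; their data changes per replay.
            return cls(arg_type=FakeTensor, shape=arg.shape, dtype=arg.dtype)
        # Non-tensor values are fixed at capture time, so keep the value itself.
        return cls(arg_type=type(arg), value=arg)


def check_supported_type(meta):
    """Raise TypeError if the argument type cannot be a cudagraph input/output."""
    if not issubclass(meta.arg_type, SUPPORTED_TYPES):
        raise TypeError(f"Unsupported cudagraph arg type: {meta.arg_type.__name__}")
```

Recording the value of non-tensor arguments reflects a general property of CUDA graphs: anything that is not read from a device buffer at replay time is frozen into the captured graph.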
- core.transformer.cuda_graphs._determine_if_transformer_decoder_layer(base_module)
Determine whether the given module is a transformer decoder layer.
- core.transformer.cuda_graphs._determine_if_first_last_layer_of_this_vp_chunk(base_module)
Determine whether the given module is the first and/or last layer of the PP+VPP (pipeline-parallel plus virtual-pipeline-parallel) chunk it belongs to. Returns a tuple of two booleans indicating whether the module is the first layer and whether it is the last layer of the chunk.
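Conceptually, this check compares a layer's position with the range of layers owned by its chunk. The simplified sketch below assumes that every PP+VPP chunk holds the same number of consecutive layers and that layers are numbered globally from 0. The real function inspects the module itself and may handle uneven layouts differently. The name `first_last_in_chunk` is hypothetical.

```python
def first_last_in_chunk(global_layer_idx, num_layers, pp_size, vpp_size):
    """Return (is_first, is_last) for a layer within its PP+VPP chunk.

    Assumes num_layers splits evenly into pp_size * vpp_size chunks of
    consecutive layers (a simplification for illustration).
    """
    num_chunks = pp_size * vpp_size
    if num_layers % num_chunks != 0:
        raise ValueError("sketch assumes an even split of layers across chunks")
    layers_per_chunk = num_layers // num_chunks
    # Position of this layer inside its own chunk.
    offset = global_layer_idx % layers_per_chunk
    return offset == 0, offset == layers_per_chunk - 1
```

With 8 layers, PP=2, and VPP=2, each chunk holds 2 layers, so layer 0 gives (True, False) and layer 1 gives (False, True). If every chunk holds a single layer, that layer is both first and last.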
- class core.transformer.cuda_graphs._CudagraphGlobalRecord
A global data structure that records the order of every _CudaGraphRunner’s first forward or backward pass. ‘create_cudagraphs’ uses this record to create cudagraphs in execution order, which is required for cudagraphs that share a memory pool.
- cudagraph_created
False
- cudagraph_record
[]
A record of forward and backward graph creation, populated by ‘record_fwd_graph’ and ‘record_bwd_graph’.
- cudagraph_inference_record
[]
- classmethod record_fwd_graph(runner, args, kwargs)
Record a forward graph in ‘cudagraph_record’.
- classmethod record_bwd_graph(runner)
Record a backward graph in ‘cudagraph_record’.
- classmethod create_cudagraphs()
Iterate through ‘cudagraph_record’, creating graphs in the order in which they were recorded.
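The ordering idea can be illustrated without CUDA. The minimal sketch below assumes runners expose create_fwd_graph and create_bwd_graph as documented for _CudaGraphRunner. The `LoggingRunner` class and the tuple layout of the record entries are illustrative assumptions and do not reflect the module's actual internals.

```python
class CudagraphGlobalRecordSketch:
    """Records runners' first fwd/bwd passes and creates graphs in that order."""

    cudagraph_created = False
    # Entries: (runner, "fwd", args, kwargs) or (runner, "bwd").
    cudagraph_record = []

    @classmethod
    def record_fwd_graph(cls, runner, args, kwargs):
        cls.cudagraph_record.append((runner, "fwd", args, kwargs))

    @classmethod
    def record_bwd_graph(cls, runner):
        cls.cudagraph_record.append((runner, "bwd"))

    @classmethod
    def create_cudagraphs(cls):
        # Create graphs in exactly the order they were first executed, so graphs
        # sharing one memory pool are captured in execution order.
        for entry in cls.cudagraph_record:
            runner, kind = entry[0], entry[1]
            if kind == "fwd":
                runner.create_fwd_graph(entry[2], entry[3])
            else:
                runner.create_bwd_graph()
        cls.cudagraph_created = True


class LoggingRunner:
    """Hypothetical runner that logs graph creation instead of capturing."""

    def __init__(self, name, log):
        self.name, self.log = name, log

    def create_fwd_graph(self, args, kwargs):
        self.log.append(f"{self.name}.fwd")

    def create_bwd_graph(self):
        self.log.append(f"{self.name}.bwd")
```

Suppose the first step runs layer0 forward, layer1 forward, layer1 backward, and layer0 backward. The graphs are then created in that same order, rather than all forward graphs first, which is what lets them share a memory pool safely.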
- core.transformer.cuda_graphs.create_cudagraphs()
Should be called at the end of each schedule function (e.g. forward_backward_pipelining_with_interleaving) in megatron.core.pipeline_parallel.schedules.py. During the first step, _CudaGraphRunners populate _CudagraphGlobalRecord with the global order in which cudagraphs should be created. At the end of the first step, this function calls each runner’s create_fwd_graph and create_bwd_graph in the order recorded in _CudagraphGlobalRecord. Creating cudagraphs in execution order allows multiple cudagraphs to share a single memory pool, minimizing cudagraph memory usage.
- core.transformer.cuda_graphs.delete_cuda_graphs()
Delete all CUDA graphs.
- class core.transformer.cuda_graphs._GraphStatus(*args, **kwds)
Bases: enum.Enum
An Enum that tracks whether a cudagraph is ready to perform a forward or backward pass.
Initialization
- FWD_READY
0
- BWD_READY
1
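One way a runner could use these two states is to alternate between them: a forward replay leaves it waiting for its backward pass, and the backward replay makes it ready for the next forward. The minimal sketch below shows that state machine. The `RunnerState` class and its error handling are hypothetical, and the real runner may enforce transitions differently.

```python
from enum import Enum


class GraphStatus(Enum):
    FWD_READY = 0  # ready to run the forward graph
    BWD_READY = 1  # forward has run; waiting for the backward graph


class RunnerState:
    """Hypothetical tracker enforcing fwd -> bwd -> fwd ordering."""

    def __init__(self):
        self.status = GraphStatus.FWD_READY

    def run_fwd(self):
        if self.status is not GraphStatus.FWD_READY:
            raise RuntimeError("backward pass for the previous forward is still pending")
        self.status = GraphStatus.BWD_READY

    def run_bwd(self):
        if self.status is not GraphStatus.BWD_READY:
            raise RuntimeError("no forward pass to run backward for")
        self.status = GraphStatus.FWD_READY
```

Under this reading, a runner in FWD_READY is free to serve another microbatch, which fits the manager's description of reusing cudagraphs across microbatches by tracking their forward and backward passes.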
- class core.transformer.cuda_graphs._CudagraphRecordNode
Bases: torch.autograd.Function
Inserts a no-op node into the autograd graph, used to record when a backward graph needs to be created.
- static forward(ctx, runner, inputs)
Forward pass; does nothing except register an autograd node.
- static backward(ctx, grads)
If this is the first backward pass of this runner, record that a backward graph needs to be created.
- class core.transformer.cuda_graphs._CudagraphReplayNode
Bases: torch.autograd.Function
Replays the runner’s cudagraphs with autograd. Handles copying data into and out of the cudagraph I/O buffers, as well as FP8/FP4 handling if used.
- static forward(ctx, runner, is_first_microbatch, *inputs)
Replay the forward graph of the given runner.
- static backward(ctx, *grads)
Replay the backward graph of the given runner.
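A CUDA graph reads from and writes to fixed ("static") buffers that are bound at capture time. This is why replay has to copy fresh inputs in and results out. The sketch below simulates that contract in plain Python: lists stand in for device tensors, and a stored callable stands in for the captured graph. `StaticGraphSketch` is hypothetical and leaves out autograd and FP8/FP4 handling.

```python
class StaticGraphSketch:
    """Simulates a captured graph bound to static input/output buffers."""

    def __init__(self, fn, example_inputs):
        # Allocate static buffers once, shaped like the example inputs.
        self.static_inputs = [list(x) for x in example_inputs]
        self.fn = fn
        self.static_outputs = fn(*self.static_inputs)

    def replay(self, *inputs):
        # Copy new data into the existing buffers in place (like Tensor.copy_).
        for buf, new in zip(self.static_inputs, inputs):
            if len(buf) != len(new):
                raise ValueError("replay inputs must match the captured shapes")
            buf[:] = new
        # Re-run the "graph" and write its results into the static output buffer.
        self.static_outputs[:] = self.fn(*self.static_inputs)
        # Return a copy so callers do not alias the static buffer.
        return list(self.static_outputs)


def add(a, b):
    return [x + y for x, y in zip(a, b)]
```

The copy-out step matters because the next replay overwrites the static output buffer. Sharing cudagraph I/O buffers between layers, as the runner's share_cudagraph_io_buffers option describes, is a way to avoid some of these copies.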
- class core.transformer.cuda_graphs._CudaGraphRunner(
    base_module: megatron.core.transformer.module.MegatronModule,
    fwd_mempool: int,
    bwd_mempool: int,
    fwd_graph_input_args: List[Any],
    fwd_graph_input_kwargs: Dict[str, Any],
    share_cudagraph_io_buffers=None,
  )
Bases: torch.nn.Module
Represents the execution of a cudagraphed module for a single microbatch. If a module has multiple outstanding microbatches, as with pipeline parallelism, CudaGraphManager automatically creates multiple _CudaGraphRunners for it.
Initialization
Creates a _CudaGraphRunner, which holds a single pair of forward and backward cudagraphs. These graphs are not created until the runner has recorded its graph creation into ‘_CudagraphGlobalRecord’ and ‘create_cudagraphs()’ is called. share_cudagraph_io_buffers is a boolean flag indicating whether to reuse the cudagraph input and output buffers, a transformer-layer-specific optimization that reduces memory usage and tensor copies.
- __str__()
- get_fp8_context()
Return a new FP8 context in cudagraph mode.
- get_fp4_context()
Return a new FP4 context in cudagraph mode.
- get_quantization_context()
Return the appropriate quantization context (FP8 or FP4) in cudagraph mode.
- create_fwd_graph(args, kwargs, clone_inputs=True)
Create a forward cudagraph for this runner. Should be called inside ‘create_cudagraphs()’.
- create_bwd_graph(static_grad_outputs=None)
Create a backward cudagraph for this runner. Should be called inside ‘create_cudagraphs()’.
- get_input_grads_with_dummy_flags()
Get the input grads returned by the backward cudagraph call. If gradient-accumulation fusion is used, the weight gradients (wgrads) have already been accumulated, so dummy wgrads are returned.
- record_graph_capture(args, kwargs)
Records the data needed to create this runner’s forward cudagraph. The first pass records a graph and appends the runner to _CudagraphGlobalRecord; the actual cudagraph is created when ‘create_cudagraphs()’ is called. Subsequent passes should replay the graph.
- replay_graph_capture(is_first_microbatch, args, kwargs)
Replay the forward cudagraph with autograd.
- get_mismatch_errors(args, kwargs)
Return a list of detailed errors for mismatched cudagraph args.
- zero_out_tensors(args, kwargs=None)
Replace all tensors inside args and kwargs with zeroed copies.
- classmethod get_tensors(args, kwargs=None)
Filter and flatten all tensors from args and kwargs.
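get_tensors and zero_out_tensors both walk nested args and kwargs. The minimal sketch below shows that traversal. It assumes arguments nest only through plain lists, tuples, and dicts, and it uses a hypothetical `FakeTensor` stand-in with a `zeros_like` helper so it runs without PyTorch. With real tensors, torch.zeros_like would fill the role of the helper.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""
    data: List[float]


def zeros_like(t):
    return FakeTensor([0.0] * len(t.data))


def get_tensors(args, kwargs=None):
    """Flatten every tensor found in args and kwargs, in traversal order."""
    found = []

    def visit(obj):
        if isinstance(obj, FakeTensor):
            found.append(obj)
        elif isinstance(obj, (list, tuple)):
            for item in obj:
                visit(item)
        elif isinstance(obj, dict):
            for item in obj.values():
                visit(item)

    visit(args)
    visit(kwargs or {})
    return found


def zero_out_tensors(args, kwargs=None):
    """Return copies of args/kwargs with every tensor replaced by zeros."""

    def rebuild(obj):
        if isinstance(obj, FakeTensor):
            return zeros_like(obj)
        if isinstance(obj, (list, tuple)):
            return type(obj)(rebuild(item) for item in obj)
        if isinstance(obj, dict):
            return {key: rebuild(val) for key, val in obj.items()}
        return obj  # non-tensor values are kept unchanged

    return rebuild(args), rebuild(kwargs or {})
```

Because zero_out_tensors rebuilds the containers, the caller's original args and kwargs are left untouched.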
- class core.transformer.cuda_graphs.CudaGraphManager(
    config: megatron.core.transformer.transformer_config.TransformerConfig,
    share_cudagraph_io_buffers: bool = True,
    vp_stage: Optional[int] = None,
  )
Bases: torch.nn.Module
Creates and runs cudagraphs for a Megatron module.
Initialization
- global_mempool
None
- fwd_mempools
None
Forward pass mempools, used with cudagraph reuse mode.
- bwd_mempool
None
Backward pass mempool, used with cudagraph reuse mode.
- set_is_first_microbatch(is_first_microbatch: bool)
Update the is_first_microbatch flag for weight caching.
- Parameters:
is_first_microbatch (bool) – Whether this is the first microbatch in the step.
- call_ddp_preforward_hook(module)
Call any DDP pre-forward hooks, which are used to launch the asynchronous data-parallel parameter gather. Any other pre-forward hooks are not allowed.
- get_cudagraph_runner(megatron_module, args, kwargs)
Returns a valid cudagraph runner for the current forward call. In single-mempool mode, a separate cudagraph is created for each call if the module is called multiple times per step (for instance, with pipeline parallelism). The first call uses the first element of ‘self.cudagraph_runners’, each subsequent call advances one position through the list, and the number of calls equals the length of ‘self.cudagraph_runners’. Otherwise, a mempool is assigned per microbatch, which allows cudagraphs to be reused across microbatches by tracking their respective forward and backward passes.
- __call__(megatron_module, args, kwargs)
Calls the forward pass of the cudagraphed module.
- Parameters:
megatron_module (torch.nn.Module) – The Megatron module to be graphed and run.
args (tuple) – The positional args to be passed to the module.
kwargs (dict) – The keyword args to be passed to the module.
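One reading of the single-mempool behavior of get_cudagraph_runner is a per-step cursor over a list of runners. During the first step, a new runner is appended for each call; in later steps, each call picks up the next existing runner. The simplified sketch below follows that reading. It assumes a hypothetical `reset_step` hook that is called at the start of each step and a placeholder runner factory, and it does not model the per-microbatch mempool (reuse) mode.

```python
class RunnerCursorSketch:
    """Hands out one runner per forward call within a training step."""

    def __init__(self, make_runner):
        self.make_runner = make_runner  # placeholder factory for a new runner
        self.cudagraph_runners = []
        self.call_idx = 0

    def reset_step(self):
        # Hypothetical hook: start again from the first runner each step.
        self.call_idx = 0

    def get_cudagraph_runner(self):
        if self.call_idx == len(self.cudagraph_runners):
            # First time this call position is seen: create a new runner.
            self.cudagraph_runners.append(self.make_runner(self.call_idx))
        runner = self.cudagraph_runners[self.call_idx]
        self.call_idx += 1
        return runner
```

With three forward calls per step, for example three microbatches under pipeline parallelism, the first step creates three runners, and every later step receives the same three runners in the same order.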
- core.transformer.cuda_graphs._layer_is_graphable(layer, config)
Check whether a layer is graphable.
- class core.transformer.cuda_graphs.TECudaGraphHelper(
    model,
    config,
    seq_length,
    micro_batch_size,
    optimizers=[],
  )
Helper class to capture CUDA Graphs using TE make_graphed_callables(). It is used at the beginning of the training loop to capture per-layer CUDA Graphs.
self.create_cudagraphs() should be called to capture the CUDA Graphs, and self.cuda_graph_set_manual_hooks() should be called to set manual pre-forward hooks for the parameters that are covered by cudagraphs.
Initialization
- _get_cuda_graph_input_data()
Create the input data for CUDA Graph capture. The data is organized per chunk, per microbatch, per layer.
- _start_capturing()
Start capturing CUDA Graphs.
- _finish_capturing(start_time)
Finish capturing CUDA Graphs and clean up the related state.
- create_cudagraphs()
Capture CUDA Graphs per TransformerLayer per microbatch.
- cuda_graph_set_manual_hooks()
Set CUDA Graph manual hooks for the modules that contain direct parameters and are covered by cudagraphs.