core.transformer.cuda_graphs#

Module Contents#

Classes#

ArgMetadata

Metadata describing a cudagraph input/output argument.

_CudagraphGlobalRecord

A global data structure that records the ordering of each _CudaGraphRunner’s first fwd or bwd pass. ‘create_cudagraphs’ uses this to create cudagraphs in execution order, which is required for cudagraphs sharing a mempool.

_GraphStatus

An Enum to track if a cudagraph is ready to perform a forward or backward pass.

_CudagraphRecordNode

Inserts a noop node into the autograd graph, used to record when a bwd graph needs to be created.

_CudagraphReplayNode

Replays the runner’s cudagraphs with autograd, copying data into and out of the cudagraph I/O buffers and handling FP8/FP4 if used.

_CudaGraphRunner

Represents the execution of a cudagraphed module for a single microbatch. If there are multiple outstanding microbatches per module, such as for pipeline parallelism, CudaGraphManager automatically creates multiple _CudaGraphRunners per module.

CudaGraphManager

Creates and runs cudagraphs for a Megatron module.

TECudaGraphHelper

Helper class to capture CUDA Graphs using TE make_graphed_callables(). It is used at the beginning of the training loop to capture per-layer CUDA Graphs. Call self.create_cudagraphs() to capture the CUDA Graphs, then self.cuda_graph_set_manual_hooks() to set manual pre-forward hooks for the parameters covered by cudagraphs.

Functions#

is_graph_capturing

Query if currently capturing.

_set_capture_start

Mark that graph capture has started.

_set_capture_end

Mark that graph capture has ended.

_check_supported_type

Check if arg meta is a supported type for cudagraph input/outputs.

_determine_if_transformer_decoder_layer

Determine if the given module is a transformer decoder layer.

_determine_if_first_last_layer_of_this_vp_chunk

Determine if the given module is the first/last layer of the PP+VPP chunk it belongs to. Returns a tuple of two booleans indicating if the module is the first/last layer of the chunk.

create_cudagraphs

Should be called at the end of each schedule function (e.g. forward_backward_pipelining_with_interleaving) in megatron.core.pipeline_parallel.schedules.py. During the first step, _CudaGraphRunners populate _CudagraphGlobalRecord with the global order in which cudagraphs should be created. At the end of the first step, this function calls each runner’s create_fwd_graph and create_bwd_graph in the order recorded in _CudagraphGlobalRecord, so cudagraphs are created in execution order, which lets multiple cudagraphs share a single memory pool and minimizes cudagraph memory usage.

delete_cuda_graphs

Delete all CUDA graphs.

_layer_is_graphable

Check if a layer is graphable.

Data#

_IS_GRAPH_CAPTURING

logger

FREEZE_GC

API#

core.transformer.cuda_graphs._IS_GRAPH_CAPTURING#

False

core.transformer.cuda_graphs.logger#

‘getLogger(…)’

core.transformer.cuda_graphs.FREEZE_GC#

None

core.transformer.cuda_graphs.is_graph_capturing()#

Query if currently capturing.

core.transformer.cuda_graphs._set_capture_start()#

Mark that graph capture has started.

core.transformer.cuda_graphs._set_capture_end()#

Mark that graph capture has ended.
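
As a rough illustration of how these module-level helpers fit together (a sketch only; guarded_capture and capture_fn are hypothetical names, not part of this module):

```python
from megatron.core.transformer.cuda_graphs import (
    _set_capture_end,
    _set_capture_start,
    is_graph_capturing,
)

def guarded_capture(capture_fn):
    """Hypothetical wrapper: raise the capture flag around a capture routine."""
    _set_capture_start()
    try:
        return capture_fn()
    finally:
        _set_capture_end()

# Code elsewhere can branch on the flag, e.g. to skip host-side work that
# must not be recorded into a graph.
if is_graph_capturing():
    pass
```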

class core.transformer.cuda_graphs.ArgMetadata(arg)#

Metadata describing a cudagraph input/output argument.

Initialization

core.transformer.cuda_graphs._check_supported_type(meta)#

Check if arg meta is a supported type for cudagraph input/outputs.

core.transformer.cuda_graphs._determine_if_transformer_decoder_layer(base_module)#

Determine if the given module is a transformer decoder layer.

core.transformer.cuda_graphs._determine_if_first_last_layer_of_this_vp_chunk(base_module)#

Determine if the given module is the first/last layer of the PP+VPP chunk it belongs to. Returns a tuple of two booleans indicating if the module is the first/last layer of the chunk.

class core.transformer.cuda_graphs._CudagraphGlobalRecord#

A global data structure that records the ordering of each _CudaGraphRunner’s first fwd or bwd pass. ‘create_cudagraphs’ uses this to create cudagraphs in execution order, which is required for cudagraphs sharing a mempool.

cudagraph_created#

False

cudagraph_record#

[]

A record of fwd and bwd graph creation, populated with ‘record_fwd_graph’ and ‘record_bwd_graph’.

cudagraph_inference_record#

[]

classmethod record_fwd_graph(runner, args, kwargs)#

Record a fwd graph to ‘cudagraph_record’.

classmethod record_bwd_graph(runner)#

Record a bwd graph to ‘cudagraph_record’.

classmethod create_cudagraphs()#

Iterate through ‘cudagraph_record’ creating graphs in the order in which they were recorded.

core.transformer.cuda_graphs.create_cudagraphs()#

Should be called at the end of each schedule function (e.g. forward_backward_pipelining_with_interleaving) in megatron.core.pipeline_parallel.schedules.py. During the first step, _CudaGraphRunners populate _CudagraphGlobalRecord with the global order in which cudagraphs should be created. At the end of the first step, this function calls each runner’s create_fwd_graph and create_bwd_graph in the order recorded in _CudagraphGlobalRecord, so cudagraphs are created in execution order, which lets multiple cudagraphs share a single memory pool and minimizes cudagraph memory usage.
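
A minimal sketch of the intended call site, assuming a hypothetical wrapper run_step around one of the schedule functions (the wrapper and its arguments are illustrative, not Megatron API):

```python
from megatron.core.transformer.cuda_graphs import create_cudagraphs

def run_step(schedule_fn, *args, **kwargs):
    # e.g. schedule_fn = forward_backward_pipelining_with_interleaving
    losses = schedule_fn(*args, **kwargs)
    # At the end of the first step this creates every recorded fwd/bwd graph
    # in execution order, so all graphs can share a single mempool.
    create_cudagraphs()
    return losses
```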

core.transformer.cuda_graphs.delete_cuda_graphs()#

Delete all CUDA graphs.

class core.transformer.cuda_graphs._GraphStatus(*args, **kwds)#

Bases: enum.Enum

An Enum to track if a cudagraph is ready to perform a forward or backward pass.

Initialization

FWD_READY#

0

BWD_READY#

1

class core.transformer.cuda_graphs._CudagraphRecordNode#

Bases: torch.autograd.Function

Inserts a noop node into the autograd graph, used to record when a bwd graph needs to be created.

static forward(ctx, runner, inputs)#

Forward pass, does nothing but registers an autograd node.

static backward(ctx, grads)#

If this is the first bwd pass of this runner, record that a bwd graph needs to be created.
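
For orientation, a generic sketch of this kind of no-op autograd node (not the actual implementation; the bwd_recorded flag is a stand-in for whatever bookkeeping the real runner does):

```python
import torch

class RecordNodeSketch(torch.autograd.Function):
    """Identity-like node whose backward fires a side effect exactly once."""

    @staticmethod
    def forward(ctx, runner, inputs):
        ctx.runner = runner            # stash the runner for use in backward
        return inputs.view_as(inputs)  # functionally a pass-through

    @staticmethod
    def backward(ctx, grads):
        if not getattr(ctx.runner, "bwd_recorded", False):  # hypothetical flag
            ctx.runner.bwd_recorded = True
            # the real node records here that a bwd graph must be created
        return None, grads             # no gradient w.r.t. the runner argument
```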

class core.transformer.cuda_graphs._CudagraphReplayNode#

Bases: torch.autograd.Function

Replays the runner’s cudagraphs with autograd, copying data into and out of the cudagraph I/O buffers and handling FP8/FP4 if used.

static forward(ctx, runner, is_first_microbatch, *inputs)#

Replay the forward graph of the passed runner.

static backward(ctx, *grads)#

Replay the backward graph of the passed runner.
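
The replay mechanics follow PyTorch’s standard CUDA-graph idiom, sketched below (illustrative only; the buffer names are assumptions, and the real class additionally handles FP8/FP4):

```python
import torch

def replay_fwd(fwd_graph: torch.cuda.CUDAGraph, static_inputs, static_outputs, live_inputs):
    # Copy live activations into the static buffers the graph was captured on.
    for static, live in zip(static_inputs, live_inputs):
        static.copy_(live)
    fwd_graph.replay()      # launch the recorded kernels
    return static_outputs   # results land in graph-owned static storage
```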

class core.transformer.cuda_graphs._CudaGraphRunner(
base_module: megatron.core.transformer.module.MegatronModule,
fwd_mempool: int,
bwd_mempool: int,
fwd_graph_input_args: List[Any],
fwd_graph_input_kwargs: Dict[str, Any],
share_cudagraph_io_buffers=None,
)#

Bases: torch.nn.Module

Represents the execution of a cudagraphed module for a single microbatch. If there are multiple outstanding microbatches per module, such as for pipeline parallelism, CudaGraphManager automatically creates multiple _CudaGraphRunners per module.

Initialization

Creates a _CudaGraphRunner, which holds a single pair of fwd and bwd cudagraphs. The graphs are not created until this runner records its graph creation into ‘_CudagraphGlobalRecord’ and ‘create_cudagraphs()’ is called. share_cudagraph_io_buffers is a boolean flag indicating whether to reuse the cudagraph input and output buffers for transformer-layer-specific optimizations that reduce memory usage and tensor copies.

__str__()#

get_fp8_context()#

Return a new fp8 context in cudagraph mode.

get_fp4_context()#

Return a new fp4 context in cudagraph mode.

get_quantization_context()#

Return appropriate quantization context (FP8 or FP4) in cudagraph mode.

create_fwd_graph(args, kwargs, clone_inputs=True)#

Create a fwd cudagraph for this runner. Should be called inside ‘create_cudagraphs()’.

create_bwd_graph(static_grad_outputs=None)#

Create a bwd cudagraph for this runner. Should be called inside ‘create_cudagraphs()’.
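
The shared-mempool capture these methods rely on is PyTorch’s documented CUDA-graph pattern; a minimal, self-contained sketch with toy callables (not the runner’s actual fwd/bwd capture):

```python
import torch

mempool = torch.cuda.graph_pool_handle()  # one pool shared by all graphs

def capture_into_pool(fn, static_input, pool):
    # Warm up on a side stream, as required before capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(3):
            fn(static_input)
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=pool):
        static_output = fn(static_input)  # allocations are served from `pool`
    return graph, static_output
```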

get_input_grads_with_dummy_flags()#

Get the input grads that are returned by the bwd cudagraph call. If using grad accum fusion, wgrads have already been accumulated, so return dummy wgrads.

record_graph_capture(args, kwargs)#

Records the data needed to create this runner’s forward cudagraph. The first pass records a graph and appends the runner to _CudagraphGlobalRecord. The actual cudagraph will be created when ‘create_cudagraphs()’ is called. Subsequent passes should replay the graph.

replay_graph_capture(is_first_microbatch, args, kwargs)#

Replay the fwd cuda graph with autograd.

get_mismatch_errors(args, kwargs)#

Return list of detailed errors for mismatched cudagraph args.

zero_out_tensors(args, kwargs=None)#

Replace all tensors inside args and kwargs with zeroed copies.

classmethod get_tensors(args, kwargs=None)#

Filter and flatten all tensors from args and kwargs.

class core.transformer.cuda_graphs.CudaGraphManager(
config: megatron.core.transformer.transformer_config.TransformerConfig,
share_cudagraph_io_buffers: bool = True,
vp_stage: Optional[int] = None,
)#

Bases: torch.nn.Module

Creates and runs cudagraphs for a Megatron module.

Initialization

global_mempool#

None

fwd_mempools#

None

Forward pass mempools, used with cudagraph reuse mode.

bwd_mempool#

None

Backward pass mempool, used with cudagraph reuse mode.

set_is_first_microbatch(is_first_microbatch: bool)#

Update the is_first_microbatch flag for weight caching.

Parameters:

is_first_microbatch (bool) – Whether this is the first microbatch in the step.

call_ddp_preforward_hook(module)#

Call any DDP pre-forward hooks which are used to launch async data parallel param gather. Any other pre-forward hooks are not allowed.

get_cudagraph_runner(megatron_module, args, kwargs)#

Returns a valid cudagraph runner for the current forward call. In single-mempool mode, a cudagraph is created for each call when the module is called multiple times per step, for instance with pipeline parallelism. The cudagraph corresponding to the current call is the first element of ‘self.cudagraph_runners’; the list is advanced by one for each call, and the number of calls per step equals the length of ‘self.cudagraph_runners’. Otherwise, a mempool is assigned per microbatch, which allows cudagraphs to be reused across microbatches by tracking their respective fwd and bwd passes.
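
An illustrative sketch of the rotation described above for single-mempool mode (not the actual method body):

```python
def next_runner(cudagraph_runners):
    # The runner for the current call sits at the front; rotate the list so
    # the next call in this step picks up the next runner.
    runner = cudagraph_runners.pop(0)
    cudagraph_runners.append(runner)
    return runner
```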

__call__(megatron_module, args, kwargs)#

Calls the forward pass of the cudagraphed module.

Parameters:
  • megatron_module (torch.nn.Module) – The Megatron module to be graphed and run.

  • args (tuple) – The positional args to be passed to the module.

  • kwargs (dict) – The keyword args to be passed to the module.
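
A hedged usage sketch (in Megatron the manager is attached inside the transformer layer itself; layer, config, and the tensors below are placeholders):

```python
from megatron.core.transformer.cuda_graphs import CudaGraphManager

manager = CudaGraphManager(config)  # config: a TransformerConfig (placeholder)

def graphed_forward(layer, hidden_states, attention_mask=None):
    # The first call records the graph; subsequent calls replay it.
    return manager(layer, (hidden_states,), {"attention_mask": attention_mask})
```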

core.transformer.cuda_graphs._layer_is_graphable(layer, config)#

Check if a layer is graphable.

class core.transformer.cuda_graphs.TECudaGraphHelper(
model,
config,
seq_length,
micro_batch_size,
optimizers=[],
)#

Helper class to capture CUDA Graphs using TE make_graphed_callables(). It is used at the beginning of the training loop to capture per-layer CUDA Graphs. Call self.create_cudagraphs() to capture the CUDA Graphs, then self.cuda_graph_set_manual_hooks() to set manual pre-forward hooks for the parameters covered by cudagraphs.

Initialization

_get_cuda_graph_input_data()#

Create the CUDA Graph capturing input data. The data is organized per-chunk per-microbatch per-layer.

_start_capturing()#

Start capturing CUDA Graphs.

_finish_capturing(start_time)#

Finish capturing CUDA Graphs and clean up the related state.

create_cudagraphs()#

Capture CUDA Graphs per TransformerLayer per microbatch.

cuda_graph_set_manual_hooks()#

Set CUDA Graph manual hooks for the modules that contain direct parameters and are covered by cudagraphs.
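
A hedged sketch of the usage sequence the class docstring describes; model, config, seq_length, micro_batch_size, and optimizer come from the surrounding training setup and are placeholders here:

```python
from megatron.core.transformer.cuda_graphs import TECudaGraphHelper

helper = TECudaGraphHelper(
    model=model,
    config=config,
    seq_length=seq_length,
    micro_batch_size=micro_batch_size,
    optimizers=[optimizer],
)
helper.create_cudagraphs()            # capture per-layer CUDA Graphs up front
helper.cuda_graph_set_manual_hooks()  # manual pre-forward hooks for graphed params
```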