CUDA Graph#
CUDA Graphs reduce kernel-launch overhead by recording GPU operations once and replaying the recording on subsequent iterations. Megatron-LM provides three CUDA graph implementations controlled by --cuda-graph-impl.
For implementation background and design details, see NVIDIA’s
Transformer Engine and Megatron-LM CUDA Graph Support.
That article is a useful conceptual reference, but some examples there still use older flags such as
--enable-cuda-graph or --cuda-graph-scope full_iteration; in this repository, prefer
--cuda-graph-impl local|transformer_engine|full_iteration as documented below.
Overview#
CUDA graph behavior is set by three orthogonal flags:
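| Flag | Values | Purpose |
|---|---|---|
| --cuda-graph-impl | none, local, transformer_engine, full_iteration | Which capture backend or strategy to use |
| --cuda-graph-modules | module names such as attn, mlp, moe_router, moe_preprocess | Per-layer training capture coverage; multi-valued and only meaningful for local and transformer_engine |
| --inference-cuda-graph-scope | none, layer, block | Granularity of CUDA graphs during inference; only local supports layer and block |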
Supported combinations:
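| --cuda-graph-impl | Backend | Training capture | Inference capture |
|---|---|---|---|
| none | n/a | off | off |
| local | MCore | per-layer, controlled by --cuda-graph-modules | layer or block, controlled by --inference-cuda-graph-scope |
| transformer_engine | TE | per-layer, controlled by --cuda-graph-modules | not supported (--inference-cuda-graph-scope none) |
| full_iteration | MCore | one graph per training iteration; --cuda-graph-modules must be omitted | not supported (--inference-cuda-graph-scope none) |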
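The combination rules stated in this document can be expressed as a small check. The sketch below is illustrative only: `check_cuda_graph_flags` is a hypothetical helper, not part of Megatron-LM, and it assumes that `none` is the value that disables CUDA graphs. It encodes three rules from this page: per-module selection must be omitted with full_iteration, inference scopes other than none require local, and full_iteration requires NaN checks to be disabled.

```python
from typing import List, Sequence

# Assumed value sets; "none" as the disabling value is an assumption.
VALID_IMPLS = {"none", "local", "transformer_engine", "full_iteration"}
VALID_INFERENCE_SCOPES = {"none", "layer", "block"}


def check_cuda_graph_flags(
    impl: str,
    modules: Sequence[str] = (),
    inference_scope: str = "none",
    nan_check_enabled: bool = True,
) -> List[str]:
    """Hypothetical helper: return a list of problems (empty means allowed)."""
    if impl not in VALID_IMPLS:
        return [f"unknown --cuda-graph-impl value: {impl!r}"]
    if inference_scope not in VALID_INFERENCE_SCOPES:
        return [f"unknown --inference-cuda-graph-scope value: {inference_scope!r}"]

    problems = []
    # Per-module selection has no meaning when one graph covers the iteration.
    if impl == "full_iteration" and modules:
        problems.append("--cuda-graph-modules must be omitted with full_iteration")
    # Only the local implementation creates inference CUDA graphs.
    if inference_scope != "none" and impl != "local":
        problems.append("--inference-cuda-graph-scope layer|block requires local")
    # NaN checks synchronize CPU and GPU, which cannot run inside a graph.
    if impl == "full_iteration" and nan_check_enabled:
        problems.append("full_iteration requires --no-check-for-nan-in-loss-and-grad")
    return problems


if __name__ == "__main__":
    for problem in check_cuda_graph_flags("full_iteration", ["attn"]):
        print(problem)
```

Running a check like this before launching a job can surface an invalid combination early; the authoritative validation remains Megatron-LM's own argument parsing.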
CUDA Graph — Local Implementation (--cuda-graph-impl local)#
Uses MCore’s built-in CudaGraphManager. During training, this is a per-layer mode:
leaving --cuda-graph-modules unset captures the whole Transformer layer, while specifying
modules restricts capture to selected sub-regions. During inference, local can instead attach
graphs at either the layer boundary or the enclosing block boundary, as controlled by
--inference-cuda-graph-scope.
Operationally, this path is tightly integrated into MCore training and inference:
- graphable modules create and own their CudaGraphManager instances automatically
- the existing training schedules drive warmup/capture/replay automatically
- users select the mode through config flags only; there is no separate helper API to wire into a custom training loop, and no need to manage static input buffers by hand
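For readers unfamiliar with the warmup/capture/replay cycle and static input buffers that the local implementation handles automatically, the following minimal sketch shows the same cycle with PyTorch's generic `torch.cuda.CUDAGraph` API. It is not MCore's CudaGraphManager and makes no claim about its internals; the function name `capture_and_replay` and the tiny `Linear` model are illustrative assumptions. The function returns `None` when PyTorch or a CUDA device is unavailable.

```python
try:
    import torch
except ImportError:  # keep the sketch importable on machines without PyTorch
    torch = None


def capture_and_replay(warmup_steps: int = 3):
    """Return True if graph replay matches eager output, or None without CUDA."""
    if torch is None or not torch.cuda.is_available():
        return None

    model = torch.nn.Linear(16, 16).cuda()
    # Static buffer: the captured graph always reads from this exact tensor.
    static_input = torch.randn(8, 16, device="cuda")

    # Warmup on a side stream so lazy initialization happens before capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side), torch.no_grad():
        for _ in range(warmup_steps):
            model(static_input)
    torch.cuda.current_stream().wait_stream(side)

    # Capture: kernels are recorded into the graph rather than executed.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph), torch.no_grad():
        static_output = model(static_input)

    # Replay: copy new data into the static buffer, then launch the recording.
    new_batch = torch.randn(8, 16, device="cuda")
    static_input.copy_(new_batch)
    graph.replay()

    with torch.no_grad():
        expected = model(new_batch)
    return bool(torch.allclose(static_output, expected, atol=1e-4))


if __name__ == "__main__":
    print(capture_and_replay())
```

The in-place `copy_` into `static_input` is the bookkeeping that a hand-written loop would otherwise have to do for every replayed iteration.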
Usage#
--cuda-graph-impl local
--cuda-graph-modules options#
| Module | What is captured |
|---|---|
| (empty / not set) | Entire Transformer layer (default) |
| | |
| | |
| | |
| moe_router | MoE router + shared experts (if not EP-comm-overlapped) |
| | |
| | Mamba SSM layer |
Example: MoE model, capturing attention, router, and preprocess:
# Optionally restrict captured modules. The default captures the whole layer, which does not work with MoE dynamic shapes.
--cuda-graph-impl local \
--cuda-graph-modules attn moe_router moe_preprocess
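When a launch command is assembled programmatically instead of with shell line continuations, the same flags can be built as a token list. This is a minimal sketch: `cuda_graph_cli_args` is a hypothetical helper and `train.py` is a placeholder script name, not a real Megatron-LM entry point.

```python
from typing import List, Optional, Sequence


def cuda_graph_cli_args(
    impl: str, modules: Optional[Sequence[str]] = None
) -> List[str]:
    """Hypothetical helper: build the CUDA graph part of a launch command."""
    args = ["--cuda-graph-impl", impl]
    if modules:
        # The flag is multi-valued: one flag followed by every module name.
        args += ["--cuda-graph-modules", *modules]
    return args


# "train.py" is a placeholder; substitute the actual training script.
command = [
    "python",
    "train.py",
    *cuda_graph_cli_args("local", ["attn", "moe_router", "moe_preprocess"]),
]
print(" ".join(command))
```

Omitting `modules` yields only `--cuda-graph-impl <impl>`, which corresponds to the whole-layer default described above.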
CUDA Graph — Transformer Engine Implementation (--cuda-graph-impl transformer_engine)#
Uses Transformer Engine’s make_graphed_callables() path. In Megatron-LM’s CLI, this has the
same training granularity as local: leaving --cuda-graph-modules unset captures the whole
Transformer layer, while specifying modules restricts capture to selected sub-regions. The main difference from
local is the backend implementation and feature compatibility. Unlike local, this path does
not support inference CUDA graphs.
Compared to local, this path exposes a more general and self-contained API via TE’s
make_graphed_callables(), giving users greater flexibility and control over how CUDA graphs are
wired into custom training loops. The trade-off is that it requires more manual setup:
- the training loop must instantiate TECudaGraphHelper
- the training loop must call helper methods such as create_cudagraphs() and cuda_graph_set_manual_hooks() at the correct points
Megatron-LM’s stock training loop already wires these calls in megatron/training/training.py,
but custom training scripts must do the same work themselves.
Usage#
--cuda-graph-impl transformer_engine \
--cuda-graph-modules attn moe_router moe_preprocess
The same training --cuda-graph-modules options apply as for local, and the default is likewise
whole-layer training capture when the flag is omitted.
Full-Iteration Training CUDA Graph (--cuda-graph-impl full_iteration)#
Captures the entire training iteration (excluding the optimizer) as a single CUDA graph. The same wrapper is also used for training-loop validation/eval in forward-only mode. Of the three implementations, this provides the largest training/validation latency reduction.
This implementation does not create inference CUDA graphs. For inference, use
--cuda-graph-impl local --inference-cuda-graph-scope layer|block.
Requirements#
- --no-check-for-nan-in-loss-and-grad is required: NaN checks involve CPU-GPU synchronization, which cannot run inside a CUDA graph.
- --cuda-graph-modules must be omitted (or left empty): per-module selection has no meaning when the entire iteration is captured as a single graph.
Example#
--cuda-graph-impl full_iteration \
--no-check-for-nan-in-loss-and-grad
Common Configuration Examples#
Dense Model Training#
All three implementations work for dense models:
# Per-layer (local)
--cuda-graph-impl local
# equivalent: --cuda-graph-impl local --cuda-graph-modules attn mlp
# Per-layer (TE)
--cuda-graph-impl transformer_engine
# equivalent: --cuda-graph-impl transformer_engine --cuda-graph-modules attn mlp
# Full-iteration
--cuda-graph-impl full_iteration \
--no-check-for-nan-in-loss-and-grad
MoE Model Training#
MoE expert dispatch involves dynamic shapes and cannot be captured. --cuda-graph-modules is used
to capture only the static parts (attention, router, preprocess) while leaving expert compute in
eager mode. Example using transformer_engine (local works the same way):
--cuda-graph-impl transformer_engine \
--cuda-graph-modules attn moe_router moe_preprocess
With paged stash (currently available only on dev; see
docs/user-guide/features/paged_stash.md on the dev branch), expert dispatch shapes become
static (pre-sized via --moe-expert-rank-capacity-factor), which allows full-iteration CUDA
graphs to be used on MoE models as well:
--cuda-graph-impl full_iteration \
--no-check-for-nan-in-loss-and-grad \
--moe-flex-dispatcher-backend hybridep \
--use-transformer-engine-op-fuser \
--moe-expert-rank-capacity-factor <float> \
--moe-paged-stash
Additional Notes#
- --cuda-graph-warmup-steps (default: 3) controls how many warmup steps run before CUDA graph capture. Setting it to 0 is not recommended: some operations rely on the first few iterations for lazy initialization or autotuning, and capturing too early may produce incorrect or suboptimal graphs.
- Inference CUDA graphs (serving or RL rollout) currently require --cuda-graph-impl local. Use --inference-cuda-graph-scope layer|block with local; all other implementations must set --inference-cuda-graph-scope none, meaning inference runs in eager mode.
- Background reference: Transformer Engine and Megatron-LM CUDA Graph Support, which also covers PyTorch CUDA Graph best practices and lessons learned.
Migration Guide#
Legacy configurations are still accepted and automatically migrated at runtime. This covers
--enable-cuda-graph, --external-cuda-graph, the --cuda-graph-scope flag (renamed to
--cuda-graph-modules), and deprecated module values such as full_iteration and
full_iteration_inference. We encourage updating your configs to the new forms:
| Old command | New command |
|---|---|
| | |
| | |
| | |
| | |
| | |
| | |