CUDA Graphs#
Stable docs: docs/training/cuda-graphs.md
Card: card.yaml (co-located)
What It Is#
CUDA graphs capture GPU operations once and replay them with minimal host-driver overhead. Bridge supports two implementations:
| Mechanism | Scope support |
|---|---|
| MCore (`cuda_graph_impl = "local"`) | `full_iteration` |
| TE (`cuda_graph_impl = "transformer_engine"`) | `attn`, `mlp`, `moe`, `moe_router`, `moe_preprocess`, `mamba` |
Enablement#
Local full-iteration graph#
```python
cfg.model.cuda_graph_impl = "local"
cfg.model.cuda_graph_scope = ["full_iteration"]
cfg.model.cuda_graph_warmup_steps = 3
cfg.model.use_te_rng_tracker = True
cfg.rng.te_rng_tracker = True
cfg.rerun_state_machine.check_for_nan_in_loss = False
cfg.ddp.check_for_nan_in_grad = False
```
TE scoped graph (dense model)#
```python
cfg.model.cuda_graph_impl = "transformer_engine"
cfg.model.cuda_graph_scope = ["attn"]  # or ["attn", "mlp"]
cfg.model.cuda_graph_warmup_steps = 3
cfg.model.use_te_rng_tracker = True
cfg.rng.te_rng_tracker = True
```
TE scoped graph (MoE model)#
```python
cfg.model.cuda_graph_impl = "transformer_engine"
cfg.model.cuda_graph_scope = ["attn", "moe_router", "moe_preprocess"]
cfg.model.cuda_graph_warmup_steps = 3
cfg.model.use_te_rng_tracker = True
cfg.rng.te_rng_tracker = True
```
Performance harness CLI#
```shell
python scripts/performance/run_performance_workload.py \
  --cuda_graph_impl transformer_engine \
  --cuda_graph_scope attn moe_router moe_preprocess \
  ...
```
Valid CLI values live in scripts/performance/argument_parser.py:
- `VALID_CUDA_GRAPH_IMPLS`: `["none", "local", "transformer_engine"]`
- `VALID_CUDA_GRAPH_SCOPES`: `["full_iteration", "attn", "mlp", "moe", "moe_router", "moe_preprocess", "mamba"]`
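These value lists map naturally onto argparse `choices`; a minimal sketch of how the harness flags could be parsed (the parser fragment is illustrative, not the actual `argument_parser.py` code):

```python
import argparse

# Values mirrored from the lists documented above.
VALID_CUDA_GRAPH_IMPLS = ["none", "local", "transformer_engine"]
VALID_CUDA_GRAPH_SCOPES = [
    "full_iteration", "attn", "mlp", "moe",
    "moe_router", "moe_preprocess", "mamba",
]

def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser fragment; the real harness defines many more flags."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda_graph_impl",
                        choices=VALID_CUDA_GRAPH_IMPLS, default="none")
    # nargs="+" accepts multiple scopes, e.g. --cuda_graph_scope attn moe_router
    parser.add_argument("--cuda_graph_scope", nargs="+",
                        choices=VALID_CUDA_GRAPH_SCOPES, default=[])
    return parser

args = build_parser().parse_args(
    ["--cuda_graph_impl", "transformer_engine",
     "--cuda_graph_scope", "attn", "moe_router", "moe_preprocess"]
)
```

Invalid values (e.g. `--cuda_graph_scope attention`) are rejected at parse time rather than failing deep inside training.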
Required constraints#
- `use_te_rng_tracker = True` (enforced in `gpt_provider.py`)
- `full_iteration` scope only with `cuda_graph_impl = "local"`
- `full_iteration` scope requires `check_for_nan_in_loss = False`
- Do not combine `moe` scope and `moe_router` scope
- Tensor shapes must be static (fixed `seq_length`, fixed `micro_batch_size`)
- MoE token-dropless routing limits graphable scope to dense modules
- With `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`, set `NCCL_GRAPH_REGISTER=0` (MCore enforces for local impl on arch < sm_100; TE impl asserts unconditionally)
- CPU offloading is incompatible with CUDA graphs
- `moe_preprocess` scope requires `moe_router` scope to also be set
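Taken together, these constraints can be encoded as a plain validation helper; a sketch assuming string-valued scopes (the function name and signature are hypothetical, not a Bridge API):

```python
def validate_cuda_graph_config(impl, scopes, use_te_rng_tracker,
                               check_for_nan_in_loss, cpu_offloading=False):
    """Raise ValueError if the documented CUDA graph constraints are violated."""
    if impl == "none":
        return  # nothing to check when graphs are disabled
    if not use_te_rng_tracker:
        raise ValueError("use_te_rng_tracker=True is required for CUDA graphs")
    if "full_iteration" in scopes:
        if impl != "local":
            raise ValueError("full_iteration scope requires cuda_graph_impl='local'")
        if check_for_nan_in_loss:
            raise ValueError("full_iteration scope requires check_for_nan_in_loss=False")
    if "moe" in scopes and "moe_router" in scopes:
        raise ValueError("moe and moe_router scopes are mutually exclusive")
    if "moe_preprocess" in scopes and "moe_router" not in scopes:
        raise ValueError("moe_preprocess scope requires moe_router scope")
    if cpu_offloading:
        raise ValueError("CPU offloading is incompatible with CUDA graphs")

# A valid TE MoE configuration passes silently:
validate_cuda_graph_config(
    "transformer_engine", ["attn", "moe_router", "moe_preprocess"],
    use_te_rng_tracker=True, check_for_nan_in_loss=True,
)
```

Static-shape requirements are runtime properties of the data pipeline and cannot be checked this way; the remaining constraints are simple config predicates.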
Code Anchors#
Bridge config and validation#
```python
# CUDA graph scope validation: check_for_nan_in_loss must be disabled with full_iteration graph
if self.model.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in self.model.cuda_graph_scope:
    assert not self.rerun_state_machine.check_for_nan_in_loss, (
        "check_for_nan_in_loss must be disabled when using full_iteration CUDA graph. "
        "Set rerun_state_machine.check_for_nan_in_loss=False."
    )
if self.model.cuda_graph_impl == "none":
    self.model.cuda_graph_scope = []
```
TE RNG tracker requirement#
```python
if self.cuda_graph_impl != "none":
    assert getattr(self, "use_te_rng_tracker", False), (
        "Transformer engine's RNG tracker is required for cudagraphs, it can be "
        "enabled with use_te_rng_tracker=True."
    )
```
Graph creation and capture in training loop#
```python
# Capture CUDA Graphs.
cuda_graph_helper = None
if model_config.cuda_graph_impl == "transformer_engine":
    cuda_graph_helper = TECudaGraphHelper(...)
# ...
if config.model.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in config.model.cuda_graph_scope:
    forward_backward_func = FullCudaGraphWrapper(
        forward_backward_func, cuda_graph_warmup_steps=config.model.cuda_graph_warmup_steps
    )
```
TE graph capture after warmup#
```python
# Capture CUDA Graphs after warmup.
if (
    model_config.cuda_graph_impl == "transformer_engine"
    and cuda_graph_helper is not None
    and not cuda_graph_helper.graphs_created()
    and global_state.train_state.step - start_iteration == model_config.cuda_graph_warmup_steps
):
    if model_config.cuda_graph_warmup_steps > 0 and should_toggle_forward_pre_hook:
        disable_forward_pre_hook(model, param_sync=False)
    cuda_graph_helper.create_cudagraphs()
    if model_config.cuda_graph_warmup_steps > 0 and should_toggle_forward_pre_hook:
        enable_forward_pre_hook(model)
    cuda_graph_helper.cuda_graph_set_manual_hooks()
```
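The gating condition above reduces to a small predicate: capture fires exactly once, on the first step after warmup completes. A sketch (the function is illustrative, not Bridge code):

```python
def should_capture_te_graphs(impl, has_helper, graphs_created,
                             step, start_iteration, warmup_steps):
    """True exactly when TE graph capture should fire: once, after warmup."""
    return (
        impl == "transformer_engine"
        and has_helper
        and not graphs_created
        and step - start_iteration == warmup_steps
    )

# Fires on the first post-warmup step, and never again once graphs exist:
assert should_capture_te_graphs("transformer_engine", True, False,
                                step=13, start_iteration=10, warmup_steps=3)
assert not should_capture_te_graphs("transformer_engine", True, True,
                                    step=13, start_iteration=10, warmup_steps=3)
```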
RNG initialization#
```python
_set_random_seed(
    rng_config.seed,
    rng_config.data_parallel_random_init,
    rng_config.te_rng_tracker,
    rng_config.inference_rng_tracker,
    use_cudagraphable_rng=(model_config.cuda_graph_impl != "none"),
    pg_collection=pg_collection,
)
```
Delayed wgrad + CUDA graph interaction#
```python
cuda_graph_scope = getattr(model_cfg, "cuda_graph_scope", []) or []
# ... scope parsing ...
if wgrad_in_graph_scope:
    assert is_te_min_version("2.12.0"), ...
    assert model_cfg.gradient_accumulation_fusion, ...
if attn_scope_enabled:
    assert not model_cfg.add_bias_linear and not model_cfg.add_qkv_bias, ...
```
Perf harness override helper#
```python
def _set_cuda_graph_overrides(
    recipe, cuda_graph_impl=None, cuda_graph_scope=None
):
    # Sets impl, scope, and auto-enables te_rng_tracker
```
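The helper body is elided above; a plausible sketch of what such an override helper could do, assuming an attribute-style recipe with `model` and `rng` sub-configs matching the enablement snippets (this is not the harness implementation):

```python
from types import SimpleNamespace

def apply_cuda_graph_overrides(recipe, cuda_graph_impl=None, cuda_graph_scope=None):
    """Apply CUDA graph overrides and auto-enable the TE RNG tracker."""
    if cuda_graph_impl is not None:
        recipe.model.cuda_graph_impl = cuda_graph_impl
    if cuda_graph_scope is not None:
        recipe.model.cuda_graph_scope = list(cuda_graph_scope)
    if recipe.model.cuda_graph_impl != "none":
        # CUDA graphs require the cudagraphable TE RNG tracker on both configs.
        recipe.model.use_te_rng_tracker = True
        recipe.rng.te_rng_tracker = True
    return recipe

# Stand-in recipe object for illustration:
recipe = SimpleNamespace(
    model=SimpleNamespace(cuda_graph_impl="none", cuda_graph_scope=[],
                          use_te_rng_tracker=False),
    rng=SimpleNamespace(te_rng_tracker=False),
)
apply_cuda_graph_overrides(recipe, "transformer_engine", ["attn"])
```

Auto-enabling the tracker here keeps users from tripping the provider assertion shown earlier.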
Graph cleanup#
```python
def _delete_cuda_graphs(cuda_graph_helper):
    # Deletes FullCudaGraphWrapper and TE graph objects to free NCCL buffers
```
MCore classes (in 3rdparty/Megatron-LM)#
- `CudaGraphManager`: `megatron/core/transformer/cuda_graphs.py`
- `TECudaGraphHelper`: `megatron/core/transformer/cuda_graphs.py`
- `FullCudaGraphWrapper`: `megatron/core/full_cuda_graph.py`
- `CudaGraphScope` enum: `megatron/core/transformer/enums.py`
Positive recipe anchors#
- `scripts/performance/configs/deepseek/deepseek_workload_base_configs.py`
- `scripts/performance/configs/qwen/qwen3_workload_base_configs.py`
- `scripts/performance/configs/gpt_oss/gpt_oss_workload_base_configs.py`
Tests#
| File | Coverage |
|---|---|
| `tests/unit_tests/models/test_gpt_full_te_layer_autocast_spec.py` | TE autocast with CUDA graphs |
| `tests/functional_tests/recipes/test_llama_recipes_pretrain_cuda_graphs.py` | End-to-end local and TE graph smoke tests |
| | TE + CUDA graph recipe config |
| | TE + CUDA graph recipe config |
| | VLM CUDA graph settings |
Pitfalls#
- **TE RNG tracker is mandatory:** Setting `cuda_graph_impl` without `use_te_rng_tracker=True` and `rng.te_rng_tracker=True` will assert in the provider.
- **`full_iteration` requires NaN checks disabled:** The entire fwd+bwd is captured, so loss-NaN checking cannot inspect intermediate values.
- **MoE scope restrictions:** `moe` scope and `moe_router` scope are mutually exclusive. Token-dropless MoE can only graph `moe_router` and `moe_preprocess`, not the full expert dispatch.
- **Memory overhead:** CUDA graphs pin all intermediate buffers for the graph's lifetime (no memory reuse). TE scoped graphs add a few GB; full-iteration graphs can increase peak memory by 1.5-2x. `PP > 1` compounds overhead since each stage holds its own graph.
- **Delayed wgrad interaction:** When `delay_wgrad_compute=True` and attention or MoE router is in `cuda_graph_scope`, additional constraints apply: TE >= 2.12.0, `gradient_accumulation_fusion=True`, and no attention bias.
- **Variable-length sequences break graphs:** Sequence lengths must be constant across steps. Use padded packed sequences if packing is needed.
- **Graph cleanup is required:** CUDA graph objects hold NCCL buffer references. Bridge handles this in `_delete_cuda_graphs()` at the end of training, but early exits must call it explicitly.
- **Older GPU architectures:** On GPUs with compute capability < 10.0 (pre-Blackwell), set `NCCL_GRAPH_REGISTER=0` when using `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`. Enforced in MCore `CudaGraphManager` (`cuda_graphs.py:1428`) and `TECudaGraphHelper` (`cuda_graphs.py:1697`). The TE impl asserts unconditionally regardless of arch.
- **CPU offloading incompatible:** CUDA graphs cannot be used with CPU offloading. Enforced in MCore `transformer_config.py:1907`.
- **MoE recompute + `moe_router` scope:** MoE recompute is not supported with `moe_router` CUDA graph scope when using `cuda_graph_impl = "transformer_engine"`. Enforced in MCore `transformer_config.py:1977`.
Verification#
Unit tests#
```shell
uv run python -m pytest \
  tests/unit_tests/training/test_config.py -k "cuda_graph" \
  tests/unit_tests/training/test_comm_overlap.py -k "cuda_graph" \
  tests/unit_tests/models/test_gpt_full_te_layer_autocast_spec.py -k "cuda_graph" -q
```
Functional smoke test (requires GPU)#
```shell
uv run python -m pytest \
  tests/functional_tests/recipes/test_llama_recipes_pretrain_cuda_graphs.py -q
```
Success criteria#
- Unit tests pass, covering config validation for both `local` and `transformer_engine` implementations.
- Functional test completes training steps with both CUDA graph implementations.
- No NCCL errors or illegal memory access in logs.