Handling Dynamic Patterns#
Note
This section provides practical patterns for handling common dynamic behaviors when adopting CUDA Graphs in PyTorch.
CUDA Graphs traditionally require static execution patterns—the same sequence of operations with the same memory addresses every replay. Recent CUDA versions (12.3+) introduced conditional nodes that support limited dynamic control flow at the graph level, but as of PyTorch 2.9, these features are not yet exposed through PyTorch’s Python APIs. This document focuses on practical patterns for adapting real-world code with dynamic behavior to work with current PyTorch CUDA graph implementations.
Types of Dynamic Patterns#
Dynamic patterns fall into four categories:
Dynamic control flow: Conditional execution based on runtime GPU values that requires checking values on the CPU.
Dynamic tensors: Tensors whose memory addresses change across iterations, requiring new allocations or reallocations.
Dynamic scalars: Scalar values that change each iteration and must be correctly passed to GPU operations.
Dynamic shapes: Input dimensions that vary across iterations, involving both different kernel launches and dynamic memory allocation.
For each pattern, we’ll explore solutions ranging from simple workarounds to advanced techniques. When these patterns cannot be easily resolved, the last resort is to move them outside the CUDA graph capture range and execute them in eager mode.
Dynamic Control Flow#
Dynamic control flow refers to conditional execution based on runtime GPU values that requires checking values on the CPU. These patterns are problematic because checking GPU values in Python requires CPU-GPU synchronization (usually through .item() or implicit .is_nonzero()), which breaks CUDA graph capture.
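These hidden synchronization points can be hard to spot by reading code alone. PyTorch can flag them at runtime: the sketch below uses torch.cuda.set_sync_debug_mode, which emits a warning (or raises an error) whenever an operation forces a CPU-GPU synchronization, making it a useful first pass before attempting capture.
import torch

torch.cuda.set_sync_debug_mode("warn")     # or "error" to fail fast while debugging

loss = torch.randn((), device="cuda")
if loss.isnan():                           # evaluating a GPU tensor in `if` forces a sync -> warning here
    print("Invalid loss")

torch.cuda.set_sync_debug_mode("default")  # restore normal behavior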
Common Patterns and Why They Break#
Dynamic control flow appears in many real-world scenarios. The following code snippets illustrate common patterns that break graph capture:
Training stability checks: Detecting NaN/Inf loss values to skip corrupted iterations
Gradient management: Clipping gradients when norms exceed thresholds to prevent exploding gradients
Mixed precision training: Conditionally skipping optimizer steps when gradients contain invalid values
Adaptive inference: Early exit mechanisms that route simple inputs through cheaper paths while complex inputs use the full network
All these patterns share a common problem: Python needs to inspect GPU-computed values to make branching decisions, forcing CPU-GPU synchronization.
Example 1: NaN/Inf Loss Detection
Training stability often requires checking for invalid loss values. However, checking whether a GPU tensor contains NaN or Inf requires moving that information to the CPU, causing synchronization.
# ❌ Breaks graph capture - synchronization on loss.isnan()
loss = criterion(output, target)
if loss.isnan() or loss.isinf():
print("Invalid loss, skipping iteration")
continue
loss.backward()
Example 2: Gradient Clipping
Gradient clipping prevents exploding gradients by computing the global norm and scaling gradients when the norm exceeds a threshold. The clip_grad_norm_() function returns the computed norm to the CPU, triggering synchronization.
# ❌ Breaks graph capture - synchronization in clip_grad_norm_
loss.backward()
# What clip_grad_norm_() does internally (simplified):
# 1. Compute total gradient norm across all parameters
total_norm = torch.sqrt(sum(p.grad.norm()**2 for p in model.parameters()))
# 2. Check if clipping is needed
# This comparison implicitly calls .is_nonzero() - synchronization point!
if total_norm > max_norm:
clip_coef = max_norm / total_norm
for p in model.parameters():
p.grad.mul_(clip_coef)
optimizer.step()
Example 3: Conditional Optimizer Step (AMP)
Automatic Mixed Precision (AMP) training uses gradient scaling to prevent underflow in fp16. The gradient scaler checks for NaN/Inf gradients and conditionally skips the optimizer step, which involves checking GPU values on the CPU.
# ❌ Breaks graph capture - scaler.step() checks for NaN/Inf internally
scaler.scale(loss).backward()
# What scaler.step() does internally (simplified):
# See: https://github.com/pytorch/pytorch/blob/v2.4.0/torch/amp/grad_scaler.py#L343
# 1. Unscale gradients and check for infs/nans
optimizer_state = scaler._per_optimizer_states[id(optimizer)]
inf_grad_tensor = optimizer_state["found_inf_per_device"]
# 2. Check if any inf/nan was found - this calls .item()!
if inf_grad_tensor.item() == 0: # Synchronization point
# No inf/nan found, safe to step
optimizer.step()
# Otherwise skip the optimizer step
scaler.update()
Example 4: Early Exit and Adaptive Inference
Some models use early exit mechanisms where simpler inputs can exit through earlier, cheaper layers while complex inputs go through the full network. The decision to exit early requires checking confidence scores on the CPU, causing synchronization.
# ❌ Breaks graph capture - early exit based on confidence threshold
# Example: Early exit neural network (like BERTExit, DeeBERT)
hidden = early_layers(input)
exit_logits = early_exit_classifier(hidden)
confidence = exit_logits.softmax(dim=-1).max().item() # Synchronization point
if confidence > early_exit_threshold:
# High confidence - exit early (cheaper)
output = exit_logits
else:
# Low confidence - continue through full model (expensive)
hidden = remaining_layers(hidden)
output = final_classifier(hidden)
loss = criterion(output, target)
General Solutions#
The following subsections demonstrate specific solutions for each dynamic control flow pattern above. However, most solutions follow three general strategies that can be applied across different scenarios:
Use GPU-native alternatives: Replace Python conditionals with GPU operations like torch.where(), torch.clamp(), or torch.masked_fill_() that keep computation on GPU without synchronization.
Implement custom CUDA kernels: For non-trivial logic, write CUDA kernels that perform conditional operations entirely on GPU. This is worthwhile when the control flow is moderately complex but can be expressed in CUDA.
Move outside graph with partial capture: For complex dynamic control flow (e.g., dispatching to different models, multi-path execution), keep the decision-making outside the graph. Capture individual paths as separate graphs and select which to replay. This approach benefits from graph trees or checkpointing techniques to manage multiple graph instances efficiently.
Note
Future: CUDA Conditional Nodes
CUDA 12.3+ introduced conditional nodes (IF and WHILE, with SWITCH added in CUDA 12.8) that allow limited branching within graphs without CPU synchronization. These enable GPU-side conditionals based on device memory values. However, PyTorch does not yet expose these features through its Python APIs. For now, use the workarounds described below. Future PyTorch versions may support conditional nodes, enabling more flexible dynamic control flow within graphs.
For more information about CUDA conditional nodes, see Dynamic Control Flow in CUDA Graphs with Conditional Nodes on the NVIDIA Developer Blog.
Example 1: Gradient Clipping#
PyTorch’s clip_grad_norm_ implementation evolution provides an excellent example of how to remove synchronization and dynamic control flow to make code CUDA graph compatible.
Problem: PyTorch 1.0 implementation with synchronization
The PyTorch 1.0 clip_grad_norm_ implementation contained explicit synchronization:
# Simplified from PyTorch 1.0 clip_grad_norm_ implementation
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type # Synchronization here!
total_norm = total_norm ** (1. / norm_type)
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1: # Dynamic control flow - different code paths
for p in parameters:
p.grad.data.mul_(clip_coef)
This version had a critical synchronization point: calling .item() during norm computation, which moves the GPU value to CPU. Once total_norm is a CPU scalar, the subsequent if clip_coef < 1: conditional executes on CPU without additional sync, but it creates dynamic control flow—different code paths execute depending on runtime values.
Solution: PyTorch 1.13+ sync-free implementation
The PyTorch 1.13 clip_grad_norm_ implementation eliminates synchronization and converts dynamic control flow into pure GPU operations:
# Simplified from PyTorch 1.13+ clip_grad_norm_ implementation
total_norm = torch.norm(
torch.stack([torch.norm(g.detach(), norm_type) for g in grads]),
norm_type
)
clip_coef = max_norm / (total_norm + 1e-6)
# Clamp avoids the conditional - always multiply, no sync needed
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for g in grads:
g.detach().mul_(clip_coef_clamped)
Key techniques PyTorch used:
Keep computation on GPU: Compute total_norm entirely on GPU as a tensor, avoiding .item() that would sync.
Replace conditionals with GPU operations: Use torch.clamp(clip_coef, max=1.0) instead of if clip_coef < 1:. This eliminates the conditional branch—the code always executes the same path (multiplying gradients), making the control flow static.
Accept redundant work: Always multiply gradients by the clamped coefficient, even when it equals 1.0. This redundant multiplication is cheap and maintains static execution.
Applying this pattern to your code:
When you encounter similar synchronization or dynamic control flow issues, apply the same principles: keep values on GPU, replace Python conditionals with GPU operations like torch.where() or torch.clamp(), and accept redundant computation to maintain static control flow. For complex logic that can’t be expressed with PyTorch operations, consider implementing a custom CUDA kernel that performs the entire operation on GPU. If the logic is too complex even for a custom kernel, use partial capture—keep the dynamic control flow outside the graph and capture only the static computation regions.
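As a concrete illustration, the NaN/Inf loss check from Example 1 can be rewritten so that a bad iteration simply contributes a zero update instead of being skipped with a Python branch. The sketch below shows one possible policy (it zeroes the gradients rather than literally skipping, so optimizer state such as Adam moments still advances slightly); it assumes a standard model/criterion/optimizer setup and is not the only way to handle invalid losses.
loss = criterion(output, target)
loss.backward()

grads = [p.grad for p in model.parameters() if p.grad is not None]

# 1.0 if every gradient element is finite, 0.0 otherwise -- computed entirely on GPU, no .item()
all_finite = torch.stack([torch.isfinite(grad).all() for grad in grads]).all().float()

for grad in grads:
    torch.nan_to_num_(grad, nan=0.0, posinf=0.0, neginf=0.0)  # remove NaN/Inf entries
    grad.mul_(all_finite)                                     # zero the whole update on bad iterations

optimizer.step()  # same code path every iteration -> static control flow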
Example 2: Early Exit and Adaptive Inference#
Recall Example 4 above where early exit models check confidence and conditionally route through different paths. This creates dynamic control flow that breaks graph capture. Here are two approaches to make this pattern CUDA graph compatible.
Note
Simplified Code Examples
The code examples below are simplified to illustrate the core logic. Production implementations require additional warmup iterations, proper error handling, and memory management.
Solution 1 - Capture shared layers once, replay remaining layers conditionally:
This approach captures the early layers and remaining layers as separate graphs, then replays them based on confidence.
# Capture early layers + early exit classifier (always executed)
graph_early = torch.cuda.CUDAGraph()
static_input = torch.zeros_like(input)
static_hidden = None
static_early_logits = None
with torch.cuda.graph(graph_early):
static_hidden = early_layers(static_input)
static_early_logits = early_exit_classifier(static_hidden)
# Capture remaining layers (conditional)
graph_remaining = torch.cuda.CUDAGraph()
static_hidden_remaining = torch.zeros_like(static_hidden)
static_full_output = None
with torch.cuda.graph(graph_remaining):
static_full_output = remaining_layers(static_hidden_remaining)
# Inference loop
for input in dataloader:
# Always execute early layers + early exit classifier
static_input.copy_(input)
graph_early.replay()
# Check confidence outside graph - .item() call here
confidence = static_early_logits.softmax(dim=-1).max().item()
if confidence > threshold:
# High confidence - use early exit
output = static_early_logits
else:
# Low confidence - continue with remaining layers
static_hidden_remaining.copy_(static_hidden)
graph_remaining.replay()
output = static_full_output
Pros: Memory efficient (shared early layers), flexible routing decision
Cons: One sync per iteration for confidence check
Solution 2 - Capture complete paths separately:
This approach captures the full early exit path and full model path as separate end-to-end graphs.
# Capture early exit path
graph_early_exit = torch.cuda.CUDAGraph()
static_input_early = torch.zeros_like(input)
static_early_output = None
with torch.cuda.graph(graph_early_exit):
hidden = early_layers(static_input_early)
static_early_output = early_exit_classifier(hidden)
# Capture full model path
graph_full = torch.cuda.CUDAGraph()
static_input_full = torch.zeros_like(input)
static_full_output = None
with torch.cuda.graph(graph_full):
hidden = early_layers(static_input_full)
hidden = remaining_layers(hidden)
static_full_output = final_classifier(hidden)
# Inference loop with heuristic-based routing (no sync)
for input in dataloader:
# Decide which path based on cheap heuristic
use_early_exit = (input.norm() < threshold)
if use_early_exit:
static_input_early.copy_(input)
graph_early_exit.replay()
output = static_early_output
else:
static_input_full.copy_(input)
graph_full.replay()
output = static_full_output
Pros: Fully captured computation, routing can be sync-free with cheap heuristics
Cons: Higher memory (duplicate early layers), less flexible (can’t check actual confidence without sync)
Which solution to choose:
Solution 1 if you need confidence-based routing and early layers are expensive
Solution 2 if you can use heuristic routing or want maximum speedup from full graph capture
Both patterns apply to hierarchical classification, cascaded models, and adaptive computation architectures.
Dynamic Tensors#
Dynamic tensors refer to graph input tensors whose memory addresses change across different executions of the graphed region. This primarily concerns graph input tensors—tensors created and used entirely within the graph maintain stable addresses thanks to PyTorch’s caching allocator. Unlike tensors with changing values (which can be updated in-place with .copy_()), changing input tensor addresses break CUDA graphs because graphs record specific memory locations during capture.
In this context, “static” (or “persistent”) means a tensor remains allocated with the same memory address for the entire CUDA graph lifecycle (from capture through all replays). Dynamic tensors, by contrast, are recreated or reallocated, causing their addresses to change between iterations.
Understanding which tensors can be dynamic:
Not all tensors in your PyTorch program are graph inputs—only tensors that cross the graph boundary (coming from outside the graphed region) matter. The table below shows how different tensor types typically behave without and with CUDA graphs:
Note
Typical Cases
The behaviors shown are typical patterns. Actual behavior may vary depending on your specific implementation, model architecture, and training setup.
| Tensor Type | Without CUDA Graph | With CUDA Graph | Notes |
|---|---|---|---|
| Model parameters | Static (same address) | Static (same address) | Allocated once at initialization, reused across iterations |
| Optimizer states | Static (same address) | Static (same address) | Momentum buffers, Adam moments—created once, updated in-place |
| Dataloader outputs | Dynamic (new tensors) | Graph input—must be static | Most common issue: new objects each iteration. Solution: preallocate static buffers, use .copy_() |
| Type-casted parameters | May change | Usually inside graph | E.g., FP32→FP16 in AMP—typically handled within graph automatically. Only manual casting outside graph requires stable addresses |
| Internal activations | Dynamic (address may change) | Static (same address) | Without graphs: caching allocator may return different addresses. With graphs: capture fixes addresses, caching allocator ensures reuse |
| External tensors | May change | Graph input—must be static | Preprocessing results, global state—preallocate and use .copy_() |
Key insight: Internal tensors (created and consumed within the graphed region) are automatically handled by PyTorch’s caching allocator during graph capture—their addresses become fixed. The challenge is graph input tensors (coming from outside): dataloader outputs, external preprocessing, or any tensor passing into the graph. These must be explicitly managed to maintain stable addresses across replays.
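A quick way to audit this in practice is to record data_ptr() values across a few iterations: buffers you preallocate keep the same address, while tensors created fresh each iteration (such as dataloader outputs) usually do not. A minimal sketch, assuming the dataloader yields plain tensors:
static_input = torch.zeros(batch_size, num_features, device='cuda')  # preallocated graph input

static_addrs, batch_addrs = set(), set()
for i, batch in enumerate(dataloader):
    batch = batch.cuda(non_blocking=True)
    static_input.copy_(batch)
    static_addrs.add(static_input.data_ptr())   # stays constant -> safe graph input
    batch_addrs.add(batch.data_ptr())           # typically varies -> not safe to use inside the graph
    if i == 10:
        break

print(f"static buffer addresses seen: {len(static_addrs)}, batch addresses seen: {len(batch_addrs)}")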
Why this problem occurs:
Dynamic input tensors are problematic because CUDA graphs record specific memory addresses during capture and expect those addresses to remain unchanged during replay. If tensor addresses change between replays, the graph continues to access the old (stale) memory addresses, leading to incorrect results or memory errors.
Graph input tensors can become dynamic in several scenarios:
Unnoticed input tensors: Users don’t realize an input-changing tensor is being used inside the graphed region.
Assignment instead of copy_(): Using tensor = new_tensor instead of tensor.copy_(new_tensor) creates a new tensor object with a different address.
Tensor address collections: Tensors that store pointers to other tensors (e.g., grouped GEMM operations with lists of tensor addresses) can become invalid when the referenced tensors change addresses.
Consequences:
Silent correctness issues: Model diverges or produces wrong results without obvious errors—the graph reads stale data or data from unrelated tensors at the old addresses
Illegal memory access: If the old tensor is deallocated and its memory is recycled by PyTorch’s caching allocator, the graph accesses freed memory, causing CUDA errors or crashes
General Solutions#
The key principle is to maintain stable memory addresses for all graph input tensors across replays:
Identify all graph input tensors: Carefully audit the graphed region to find every tensor that comes from outside. This includes obvious inputs (data, targets) and subtle ones (learning rates, counters, hyperparameters stored as tensors).
Preallocate input tensors before capture: Create all graph input tensors before capture and reuse them across all replays. Never create new input tensors during replay.
Use .copy_() for updates: Always update graph input tensors with static_tensor.copy_(new_data) instead of reassignment (static_tensor = new_data). The .copy_() method preserves the tensor's memory address while updating its contents.
Wrap dataloader with static buffer (optional): For cleaner code, wrap your dataloader to automatically copy data into static buffers. This encapsulates the .copy_() logic and makes graph replay code simpler. See NeMo's StaticBufferLoader for a reference implementation; a minimal sketch follows this list.
Be cautious with tensor address collections: Some operations (e.g., grouped GEMM) use tensors that store pointers to other tensors. For these cases, ensure both the pointer-holding tensor and all referenced tensors maintain stable addresses across replays. The pointer-holding tensor itself must also be preallocated and reused.
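The wrapper below is a minimal sketch of the static-buffer-loader idea (loosely modeled on NeMo's StaticBufferLoader, not a copy of its implementation). It allocates device buffers once, from the first batch, and copies every subsequent batch into those same buffers so their addresses never change; it assumes every batch has the same shape.
import torch

class StaticBufferLoader:
    """Wraps a dataloader so every batch is copied into fixed device buffers."""

    def __init__(self, loader, device='cuda'):
        self.loader = loader
        self.device = device
        self.buffers = None

    def __iter__(self):
        for batch in self.loader:
            tensors = batch if isinstance(batch, (list, tuple)) else [batch]
            if self.buffers is None:
                # Allocate static buffers once, matching the first batch's shapes and dtypes
                self.buffers = [torch.empty_like(t, device=self.device) for t in tensors]
            for buf, t in zip(self.buffers, tensors):
                buf.copy_(t, non_blocking=True)  # same addresses, new contents
            yield self.buffers

# Usage sketch: the yielded tensors keep stable addresses, so they can serve as graph inputs
# for inputs, targets in StaticBufferLoader(train_loader): ...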
The following subsections demonstrate these solutions for specific common patterns.
Example 1: Global Statistics Tensor for Normalization#
Some models maintain running statistics (e.g., mean/std for normalization, EMA of activations) as global tensors that are updated each iteration. These are easy to overlook as graph inputs.
Problem - Unnoticed global tensor as graph input:
# Global running statistics for custom normalization layer
# Common in research code, video models, or online learning scenarios
running_mean = torch.zeros(num_features, device='cuda')
running_std = torch.ones(num_features, device='cuda')
class CustomNormLayer(nn.Module):
def forward(self, x):
# Uses global running statistics
normalized = (x - running_mean) / (running_std + 1e-5)
return normalized
model = nn.Sequential(CustomNormLayer(), ...)
# Capture the graph
g = torch.cuda.CUDAGraph()
static_input = torch.zeros(batch_size, num_features, device='cuda')
with torch.cuda.graph(g):
output = model(static_input)
# During training - this breaks the graph!
for batch in dataloader:
# Update statistics (common pattern: compute from current batch)
running_mean = batch.mean(dim=0) # ❌ New tensor with different address!
running_std = batch.std(dim=0) # ❌ New tensor with different address!
static_input.copy_(batch)
g.replay() # Graph still uses OLD running_mean/running_std - incorrect normalization!
Why it fails: The graph captured the original running_mean and running_std addresses. When you reassign these variables to new tensors (from .mean() and .std()), the graph continues reading from the old addresses, which may contain stale data or be deallocated.
Solution - Use .copy_() to update in-place:
# Preallocate statistics tensors before capture
running_mean = torch.zeros(num_features, device='cuda')
running_std = torch.ones(num_features, device='cuda')
# Capture graph (same as before)
with torch.cuda.graph(g):
output = model(static_input)
# During training - update statistics in-place
for batch in dataloader:
# Compute new statistics
new_mean = batch.mean(dim=0)
new_std = batch.std(dim=0)
# Update at the SAME addresses
running_mean.copy_(new_mean) # ✅ Same address, updated value
running_std.copy_(new_std) # ✅ Same address, updated value
static_input.copy_(batch)
g.replay() # Graph reads updated values at same addresses
Key lesson: Global tensors referenced inside the graphed region are graph inputs, even if they seem “internal” to the model. Always identify them before capture and update using .copy_() to preserve memory addresses.
Example 2: Grouped GEMM in MOE Models#
Mixture of Experts (MOE) models use a gating network to route each input token to a subset of “expert” sub-networks. Different tokens route to different experts, resulting in variable-sized inputs per expert. Grouped GEMM efficiently processes these variable-sized batches in one kernel call, performing multiple independent matrix multiplications where each pair can have different dimensions.
# Pseudo code for grouped GEMM
def grouped_gemm(A_list, B_list):
"""
Computes: [A_0 @ B_0, A_1 @ B_1, ..., A_n @ B_n]
where A_i has shape [M_i, K_i] and B_i has shape [K_i, N_i]
Implementation (simplified):
1. Extract pointers to each A_i and B_i tensor
2. Store pointers in host arrays (CPU memory)
3. Copy pointer arrays from host to device
4. Call cuBLAS grouped GEMM with device pointer arrays
"""
n_groups = len(A_list)
# Step 1: Extract tensor addresses into Python lists (host memory)
A_ptrs = [A.data_ptr() for A in A_list] # Python list of integers (CPU)
B_ptrs = [B.data_ptr() for B in B_list] # Python list of integers (CPU)
# Step 2: Create host tensors with correct data type (int64 for pointers)
A_ptrs_host = torch.tensor(A_ptrs, dtype=torch.int64, device='cpu')
B_ptrs_host = torch.tensor(B_ptrs, dtype=torch.int64, device='cpu')
# Step 3: Copy pointer arrays to device (non-blocking H2D copy - this gets captured!)
A_ptrs_device = A_ptrs_host.to('cuda', non_blocking=True) # Non-blocking H2D
B_ptrs_device = B_ptrs_host.to('cuda', non_blocking=True) # Non-blocking H2D
# Step 4: Call CUDA kernel with device pointer arrays
outputs = _cuda_grouped_gemm_kernel(A_ptrs_device, B_ptrs_device, ...)
return outputs
# Problem: A_ptrs_host and B_ptrs_host are DYNAMIC tensors - they are recreated
# and recycled on every call. But the H2D copy makes them CUDA graph inputs!
# Graph inputs must be STATIC (persistent) tensors with stable addresses.
If a user tries to capture this grouped GEMM function into a CUDA graph like this:
# MOE training setup
num_experts = 8
hidden_dim = 1024
expert_sizes = [256, 512, 128, 1024, 256, 512, 256, 128] # Variable sizes per expert
# Create expert input tensors
expert_inputs = [torch.randn(size, hidden_dim, device='cuda') for size in expert_sizes]
expert_weights = [torch.randn(hidden_dim, hidden_dim, device='cuda') for _ in range(num_experts)]
# Capture graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
# Grouped GEMM captures addresses of expert_inputs
outputs = grouped_gemm(expert_inputs, expert_weights)
# During training - update expert inputs and replay
for batch in dataloader:
# Route batch to experts (common in MOE models)
routed_data = [route_to_expert(batch, i) for i in range(num_experts)]
# Update expert inputs in-place
for i in range(num_experts):
expert_inputs[i][:routed_data[i].shape[0]].copy_(routed_data[i]) # ✅ Update device tensors
g.replay() # This still fails! Why?
Why it fails:
In normal (eager) execution, A_ptrs_host and B_ptrs_host are dynamic tensors—created fresh on every call and immediately recycled. When CUDA graph captures the non-blocking H2D copy, it records the source host memory addresses, making these host tensors CUDA graph inputs. All graph inputs must be static (persistent) with stable addresses, but these are dynamic (recreated each call).
During capture: grouped_gemm creates temporary host tensors A_ptrs_host at host address 0x5000 containing device pointers [0x7f3a2000, 0x7f3a4000, ...]. The graph captures: "copy from host address 0x5000 to device".
After capture: When the function returns, A_ptrs_host is deallocated and host address 0x5000 is recycled by Python's memory allocator.
During replay: The captured graph executes: "copy from host address 0x5000 to device". But 0x5000 now contains unrelated data, garbage, or is unmapped → segmentation fault or incorrect pointer values.
Solution - Keep host pointer arrays alive:
# Preallocate expert input tensors before capture
max_expert_size = max(expert_sizes)
expert_inputs = [
torch.zeros(max_expert_size, hidden_dim, device='cuda')
for _ in range(num_experts)
]
expert_weights = [torch.randn(hidden_dim, hidden_dim, device='cuda') for _ in range(num_experts)]
# Preallocate host pointer arrays (keep alive for graph lifecycle!)
# These must persist across all graph replays
A_ptrs_host_list = [] # Store host tensors for each grouped_gemm call
B_ptrs_host_list = []
def grouped_gemm_static(A_list, B_list, call_index):
"""Modified grouped_gemm that reuses preallocated host pointer arrays"""
# Extract current device pointers
A_ptrs = [A.data_ptr() for A in A_list]
B_ptrs = [B.data_ptr() for B in B_list]
# Update preallocated host pointer arrays (same host addresses!)
A_ptrs_host_list[call_index].copy_(torch.tensor(A_ptrs, dtype=torch.int64))
B_ptrs_host_list[call_index].copy_(torch.tensor(B_ptrs, dtype=torch.int64))
# Copy to device from persistent host addresses (non-blocking)
A_ptrs_device = A_ptrs_host_list[call_index].to('cuda', non_blocking=True)
B_ptrs_device = B_ptrs_host_list[call_index].to('cuda', non_blocking=True)
outputs = _cuda_grouped_gemm_kernel(A_ptrs_device, B_ptrs_device, ...)
return outputs
# Preallocate host pointer arrays before capture (assume 2 grouped_gemm calls)
num_gemm_calls = 2
for _ in range(num_gemm_calls):
A_ptrs_host_list.append(torch.zeros(num_experts, dtype=torch.int64, device='cpu'))
B_ptrs_host_list.append(torch.zeros(num_experts, dtype=torch.int64, device='cpu'))
# Capture graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
outputs1 = grouped_gemm_static(expert_inputs, expert_weights, call_index=0)
outputs2 = grouped_gemm_static(expert_inputs, expert_weights, call_index=1)
# During training - update both expert inputs and host pointer arrays
for batch in dataloader:
routed_data = [route_to_expert(batch, i) for i in range(num_experts)]
# Copy routed data into preallocated tensors (maintains device pointer stability)
for i in range(num_experts):
expert_inputs[i][:routed_data[i].shape[0]].copy_(routed_data[i]) # ✅ Same addresses
# Host pointer arrays are updated via .copy_() inside grouped_gemm_static
g.replay() # Graph reads from persistent host addresses
Key lesson: When operations store tensor addresses in pointer arrays and copy them via H2D transfers, you must ensure:
Device tensors (the actual data) maintain stable addresses across replays (use .copy_() for updates)
Host pointer arrays remain allocated for the entire graph lifecycle (not just during capture)
For multiple grouped GEMM calls, preallocate one host pointer array per call. The non-blocking nature of H2D copies means the host memory must stay valid throughout all graph replays.
Tip
Performance Optimization
The H2D copy in the captured graph is redundant during replay—it copies the same pointer values from the same host address every time. You can optimize by moving the H2D copy outside the capture region: perform it once before graph capturing instead of replaying it every iteration. This way, you only need to ensure A_ptrs_device (the device pointer array) is static/persistent for the graph lifecycle, not the host arrays. This eliminates the redundant H2D copy on every replay, improving performance.
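A sketch of this variant, reusing the placeholder _cuda_grouped_gemm_kernel and the preallocated expert_inputs/expert_weights from the example above: because the expert tensors are preallocated, their device pointers never change, so the pointer arrays can be built and copied to the GPU once, before capture.
# Build the device pointer arrays once, before capture, and keep references alive
A_ptrs_device = torch.tensor([A.data_ptr() for A in expert_inputs], dtype=torch.int64).cuda()
B_ptrs_device = torch.tensor([B.data_ptr() for B in expert_weights], dtype=torch.int64).cuda()

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Only the GEMM kernel is captured; no H2D pointer copy is replayed each iteration
    outputs = _cuda_grouped_gemm_kernel(A_ptrs_device, B_ptrs_device, ...)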
Dynamic Scalars#
Dynamic scalars refer to CPU-side scalar values that change across iterations and are used in GPU operations. Unlike dynamic tensors (which concern memory addresses), dynamic scalars are about changing values that originate on the CPU. Common examples include random number generator seeds, learning rates, global step counters, and gradient scaling factors in mixed precision training.
# ❌ Breaks graph - exponent changes but graph uses captured constant
exponent = 2.0
# Capture graph with exponent=2.0
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
result = torch.pow(x, exponent) # Captures exponent=2.0 as constant
# Try to use different exponents during replay
for iteration in range(num_iters):
exponent = compute_exponent(iteration) # Returns 2.5, then 3.0, then 3.5...
g.replay() # Always computes x^2.0, ignoring current exponent value!
Why this is problematic:
When you pass a Python scalar (like exponent=2.0) to a GPU operation, the CUDA kernel is launched with that scalar value as a kernel parameter. During graph capture, CUDA records this kernel launch with the specific scalar value—essentially turning the scalar into a constant in the captured graph. On replay, the graph launches the same kernel with the same constant value (2.0), completely ignoring any changes you made to the Python variable exponent (now 2.5, 3.0, etc.). The scalar is “baked into” the captured graph as a constant.
General Solutions#
Identify all changing scalars: Before attempting to graph your code, audit all scalar values passed to GPU operations within the graphed region. Look for learning rates, step counters, scale factors, exponents, and any other CPU scalars that change across iterations. These must be converted to GPU tensors if they are captured into the CUDA graph.
To make dynamic scalars work with CUDA graphs, follow these steps:
Step 1: Store scalars as GPU tensors and update in-place:
Convert all identified CPU scalars to GPU tensors before capture, then update them in-place between replays:
# Preallocate GPU tensor for scalar
exponent_tensor = torch.tensor(2.0, device='cuda')
# Capture graph using GPU tensor
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
result = torch.pow(x, exponent_tensor) # ✅ Uses GPU tensor instead of Python scalar
# Update scalar tensor before each replay
for iteration in range(num_iters):
new_exponent = compute_exponent(iteration) # Returns 2.5, then 3.0, then 3.5...
exponent_tensor.fill_(new_exponent)
g.replay() # Uses updated exponent value!
Step 2: Modify custom CUDA kernels to accept device pointers (only for custom kernels):
For custom CUDA kernels (not PyTorch native operators), you must also modify the kernel implementation to accept device pointers instead of scalar values as kernel parameters. For example, if you wrote a custom pow kernel:
// ❌ Original custom pow kernel - exponent captured as constant
__global__ void custom_pow_kernel(float* output, float* input, float exponent, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
output[idx] = powf(input[idx], exponent); // exponent captured as constant
}
}
// Launch: custom_pow_kernel<<<...>>>(output, input, exponent, n); // exponent captured as constant
// ✅ Modified custom pow kernel - exponent read from device memory
__global__ void custom_pow_kernel(float* output, float* input, float* exponent_ptr, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
float exponent = *exponent_ptr; // Dereference pointer to read current value
output[idx] = powf(input[idx], exponent);
}
}
// Launch: custom_pow_kernel<<<...>>>(output, input, exponent_tensor.data_ptr(), n);
When this is needed: Only for custom CUDA kernels (via PyTorch C++ extensions or third-party libraries). PyTorch native operators like torch.pow() don’t require kernel modification—they already handle GPU tensors correctly.
The following examples demonstrate these steps in real-world scenarios.
Example 1: Random Number Generator State#
Random operations like dropout require different random values each iteration, but naive graph capture would replay the same random sequence.
Background - How CUDA RNG works:
CUDA random number generation uses the Philox4_32_10 algorithm, which requires three parameters to initialize the random state in each CUDA thread:
Seed (uint64): Determines which random sequence to use—same seed produces the same sequence
Subsequence ID (typically thread index): Each CUDA thread uses its own subsequence to generate independent random numbers in parallel
Offset (uint64): Position within the random sequence—the offset advances by the number of random values consumed (e.g., dropout over 1024 elements advances offset by 1024)
Each random operation in PyTorch follows this pattern:
Host prepares PhiloxCudaState containing the current seed and offset from the global RNG state
CUDA kernel receives this state as a parameter
Each thread initializes its own curandStatePhilox4_32_10_t with curand_init(seed, thread_idx, offset, &state)
Threads generate random numbers from their independent subsequences
Host increments the global offset for the next operation
The CUDA graph problem: During graph capture, if the seed and offset values are captured as scalar constants in the kernel launch parameters, every replay would use the exact same random sequence—your dropout mask would be identical across all iterations of graph replays!
PyTorch’s solution: Register the RNG generator state with the graph using register_generator_state(). This changes how PhiloxCudaState is constructed during capture:
Without registration: PhiloxCudaState captures the seed and offset as constant scalar values—replays always use the same values
With registration: PhiloxCudaState captures pointers to GPU tensors holding the seed and offset—replays read the current values from GPU memory
# Example with custom generator (default generator is auto-registered)
custom_gen = torch.Generator(device='cuda')
g = torch.cuda.CUDAGraph()
# Register custom generator if using one (default generator auto-registers)
g.register_generator_state(custom_gen)
# Capture a graph that draws random numbers from the custom generator
# (F.dropout has no generator argument, so build the dropout mask explicitly)
with torch.cuda.graph(g):
    mask = torch.rand(input.shape, device='cuda', generator=custom_gen) > 0.5
    output = input * mask / 0.5  # dropout with p=0.5
# Each replay uses different random values
for _ in range(10):
g.replay() # PyTorch automatically advances RNG offset after each replay
How it works internally (see PyTorch v2.4.0 source code):
During capture (capture_prologue): Instead of passing PhiloxCudaState(seed_value, offset_value) to kernels, PyTorch passes PhiloxCudaState(seed_ptr, offset_ptr, intra_graph_offset), where seed_ptr and offset_ptr are device pointers to 1-element GPU tensors. At capture end (capture_epilogue), PyTorch records the total offset consumed by the graph (wholegraph_increment)
Before each replay (replay_prologue): PyTorch increments the CPU-side offset counter by wholegraph_increment, then copies the updated seed and offset values to GPU tensors using .fill_(): seed_extragraph_.fill_(seed) and offset_extragraph_.fill_(offset)
During replay (kernel execution): Kernels dereference the device pointers to read the current seed and offset from GPU memory: seed = *seed_ptr; offset = *offset_ptr + intra_graph_offset
This follows the same pattern as our “General Solutions”—store changing scalars in GPU tensors, update them on the CPU side, copy to GPU before each replay, and pass device pointers to kernels.
Key lesson: For RNG-based operations, use PyTorch’s built-in register_generator_state() API. This is handled automatically by torch.cuda.graph() and make_graphed_callables()—you rarely need to call it manually.
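For example, with the default generator nothing needs to be registered manually: capturing a dropout call under torch.cuda.graph() already produces a different mask on every replay. A minimal sketch (warmup omitted for brevity, as in the other simplified examples):
import torch
import torch.nn.functional as F

x = torch.ones(8, device='cuda')

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = F.dropout(x, p=0.5, training=True)  # default CUDA generator, auto-registered

g.replay(); print(static_out)  # different mask
g.replay(); print(static_out)  # different mask again -- the offset advances between replays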
Example 2: Learning Rate and Global Step in APEX FusedAdam#
Optimizers like Adam use learning rate and global step (for bias correction) that change each iteration. APEX’s FusedAdam provides a CUDA graph-compatible implementation by moving these scalars to GPU.
Problem with standard Adam:
# Standard Adam - not CUDA graph compatible
# Capture graph at iteration 0 with lr=0.001
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
# Adam internally uses:
# - lr (learning rate) as Python float (0.001)
# - step (global step counter) as Python int (0)
optimizer.step() # ❌ Captures lr=0.001 and step=0 as constants!
# Try to use different learning rates during replay
for iteration in range(num_iters):
lr = scheduler.get_lr() # Returns 0.0005, 0.0001, etc.
g.replay() # Always uses lr=0.001 and step=0 from capture!
APEX FusedAdam’s solution (see APEX FusedAdam implementation):
Store lr and step as GPU tensors instead of CPU scalars: lr is initialized as a CPU tensor, then moved to device if capturable=True (lines 78, 105); step is a GPU tensor when capturable=True (lines 151-154)
Use the capturable CUDA kernel variant that reads from device pointers (lines 219-232, which call the multi_tensor_adam_capturable kernel)
from apex.optimizers import FusedAdam
# APEX FusedAdam stores lr and step on GPU when capturable=True
optimizer = FusedAdam(model.parameters(), lr=0.001, capturable=True, set_grad_none=True)
# Capture graph - FusedAdam kernel reads from device pointers
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
optimizer.step()
# Access GPU tensors for lr and step (after capture)
lr_tensor = optimizer.param_groups[0]['lr'] # GPU tensor, not Python float
step_tensor = optimizer.state[p]['step'] # GPU tensor, not Python int
# Update lr before each replay
for iteration in range(num_iters):
lr_tensor.fill_(scheduler.get_lr()[0]) # Update GPU tensor in-place
# step is automatically incremented by FusedAdam's CUDA kernel
g.replay() # Uses current lr value from GPU memory
How it works: APEX modified the Adam CUDA kernel signature to accept GPU tensors instead of scalars (see multi_tensor_adam.cu):
// Standard Adam - takes scalars as kernel parameters (captured as constants)
adam_kernel<<<...>>>(param_ptr, grad_ptr, lr_value, step_value, ...)
// FusedAdam capturable - takes GPU tensors (reads dynamic values from device memory)
fused_adam_capturable<<<...>>>(param_ptr, grad_ptr, lr_tensor, step_tensor, ...)
The kernel dereferences lr_ptr and step_ptr to read the current values from GPU memory on each replay, allowing them to change between iterations.
Key lesson: To make optimizers CUDA graph-compatible with dynamic scalars:
Store scalars as GPU tensors (not CPU Python values)
Modify CUDA kernels to accept device pointers and dereference them, rather than embedding scalar values as kernel parameters
Update GPU tensors in-place before each replay
This pattern applies to any custom CUDA kernel that needs dynamic scalar parameters.
Dynamic Shapes#
Dynamic shapes occur when tensor dimensions vary across iterations. Common examples include variable batch sizes, variable sequence lengths in NLP, and different image resolutions. This is the most challenging dynamic pattern because changing shapes means different tensor sizes, different memory allocations, and often different kernels, all of which break CUDA graph capture.
Why this is problematic: CUDA graphs capture not just the sequence of kernel launches, but also the specific kernel configurations (grid/block dimensions) and memory allocation patterns. When shapes change:
Different kernel configurations: A matmul with shape
[64, 512] @ [512, 1024]launches different grid/block sizes than[128, 512] @ [512, 1024]Different memory allocations: Changing shapes trigger dynamic memory allocation during replay, which is forbidden in CUDA graphs
Different execution paths: Some kernels (e.g., cuDNN convolutions) select different algorithms based on input shapes
During graph replay, the captured kernel configurations and memory addresses are reused regardless of actual input shapes, leading to incorrect results or crashes.
General Solutions#
There are two main approaches to handle dynamic shapes:
Padding to fixed size: Pad all inputs to a fixed size, capture a single graph. Can pad during data loading or at runtime
Bucketing: Capture multiple graphs for common shape configurations (e.g., sequence lengths 128, 256, 512), select the appropriate graph at runtime. Usually works together with padding—pad inputs to the nearest bucket size
Start with padding (simplest), then move to bucketing only if profiling shows significant wasted computation.
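To choose between the two, it helps to first measure how much computation padding would actually waste on your data. A rough sketch, assuming the dataloader yields [batch, seq_len] token tensors and a chosen max_seq_len:
lengths = []
for i, input_ids in enumerate(dataloader):
    lengths.append(input_ids.shape[1])
    if i >= 500:  # a few hundred batches is usually enough for an estimate
        break

used = sum(lengths)
padded = len(lengths) * max_seq_len
print(f"Fraction of positions that would be padding: {1 - used / padded:.1%}")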
When shapes are too dynamic: If shapes vary unpredictably with no pattern (e.g., MoE token routing, variable image sizes, irregular scientific data), neither padding nor bucketing works well. In these cases:
Graph only the static parts (see Example 2 below), or
Keep the model in eager mode entirely
CUDA graphs are not a one-size-fits-all solution—sometimes eager mode is the right answer.
Example 1: Variable Sequence Lengths (NLP)#
Problem: NLP models often have variable sequence lengths (and sometimes variable batch sizes):
# Sequence lengths vary: 128, 256, 512, 1024, ...
# Batch sizes may also vary: 16, 32, 64, ...
for input_ids in dataloader: # Shape: [batch, seq_len]
output = model(input_ids)
Solution 1 - Padding to fixed size:
Pad all inputs to a fixed size and capture a single graph. This trades computational efficiency for implementation simplicity.
max_seq_len = 2048
batch_size = 32
static_input = torch.zeros(batch_size, max_seq_len, dtype=torch.long, device='cuda')
static_output = None
# Capture with max size
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
static_output = model(static_input)
# Pad variable-length sequences at runtime
def forward(input_ids):
actual_seq_len = input_ids.shape[1]
# Pad to max length
if actual_seq_len < max_seq_len:
input_ids = F.pad(input_ids, (0, max_seq_len - actual_seq_len), value=pad_token_id)
# Copy to static buffer and replay
static_input.copy_(input_ids)
g.replay()
# Extract actual results (trim padding)
return static_output[:, :actual_seq_len].clone()
How it works: The model always processes max_seq_len tokens, even if the actual input is shorter. For transformers with attention masking, padding tokens are masked out so they don't affect results—but the model still computes over them. A 128-token sequence padded to 2048 wastes roughly 94% of the computation in the linear layers, and even more in the quadratic attention layers.
Alternative: You can also pad during data loading (e.g., in DataLoader’s collate_fn) to eliminate runtime padding overhead, at the cost of less flexibility in handling variable lengths.
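A minimal sketch of that alternative: a hypothetical collate_fn that pads each sample to max_seq_len so batches already arrive at the fixed shape (assumes samples are 1-D token tensors no longer than max_seq_len).
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

def pad_collate(samples, max_seq_len=2048, pad_token_id=0):
    """Pad each 1-D token sequence to max_seq_len and stack into [batch, max_seq_len]."""
    padded = [F.pad(s, (0, max_seq_len - s.shape[0]), value=pad_token_id) for s in samples]
    return torch.stack(padded)

loader = DataLoader(dataset, batch_size=32, collate_fn=pad_collate)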
When this is acceptable: If sequence lengths cluster near the maximum (e.g., most sequences are 1800-2048 tokens out of 2048 max), wasted computation is minimal. This approach makes sense when simplicity and memory constraints (only one graph) outweigh efficiency concerns.
Pros: Single graph, simple implementation, handles any size ≤ max, minimal memory overhead
Cons: Wastes computation on padding (especially for short sequences), can significantly reduce throughput
Solution 2 - Bucketing:
Capture multiple graphs, one for each common shape configuration. At runtime, select the graph that best matches the input shape (usually rounding up to the nearest bucket). This provides the best efficiency for diverse shape distributions at the cost of increased complexity and memory usage.
class MultiGraphModel:
def __init__(self, model, seq_lengths=[128, 256, 512, 1024, 2048], batch_size=32):
self.model = model
self.graphs = {}
# Capture one graph per sequence length bucket
for seq_len in seq_lengths:
static_input = torch.zeros(batch_size, seq_len, dtype=torch.long, device='cuda')
static_output = None
# Warmup
for _ in range(3):
with torch.no_grad():
_ = model(static_input)
# Capture graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
static_output = model(static_input)
self.graphs[seq_len] = {
'graph': g,
'input': static_input,
'output': static_output
}
def __call__(self, input_ids):
seq_len = input_ids.shape[1]
# Round up to nearest bucket
bucket = min([s for s in self.graphs.keys() if s >= seq_len],
default=max(self.graphs.keys()))
# Pad to bucket size if needed
if seq_len < bucket:
input_ids = F.pad(input_ids, (0, bucket - seq_len), value=pad_token_id)
# Copy to static buffer and replay
self.graphs[bucket]['input'].copy_(input_ids)
self.graphs[bucket]['graph'].replay()
return self.graphs[bucket]['output'][:, :seq_len].clone() # Trim padding
How it works: Each bucket graph is optimized for its specific shape—a 512-token sequence runs different kernel configurations than a 1024-token sequence. By capturing multiple buckets, you avoid the wasted computation of always padding to maximum length. The runtime overhead is minimal: selecting a bucket (dictionary lookup), padding if needed (cheap for small gaps like 480→512), and trimming output (simple slicing).
Memory considerations: Each graph stores the entire computation graph including all intermediate activations at that specific shape. For a transformer model with 5 buckets, expect 5× the memory overhead compared to a single graph. This is acceptable if you have enough GPU memory to accommodate multiple graph instances—the performance gain justifies the cost when memory permits.
Pros: Optimal performance for each bucket, handles diverse lengths efficiently, minimal wasted computation
Cons: Memory overhead (one graph per bucket), complex management, need to choose buckets wisely
When to use each:
Start with padding for simplicity. Move to bucketing only if profiling shows significant wasted computation (e.g., many short sequences padded to a very long max) and you have sufficient GPU memory for multiple graphs.
Key decision factors:
Shape distribution: If 90%+ of sequences cluster in a narrow range (e.g., 1800-2048 tokens), padding to max is simplest and wastes minimal computation. If shapes spread widely (e.g., distributed across 128, 256, 512, 1024), bucketing avoids wasted work
Memory constraints: Bucketing requires more memory for multiple graph instances. If memory-limited, padding is the only option
Hybrid approach: Graph only the most frequent shapes (e.g., top 3 buckets covering 95% of data) and fall back to eager mode for rare shapes. This balances performance, memory, and complexity
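A sketch of the hybrid dispatch, reusing the MultiGraphModel class from Solution 2 and assuming the original eager-mode model is still available for rare shapes:
graphed = MultiGraphModel(model, seq_lengths=[128, 256, 512])  # only the most frequent buckets

def forward(input_ids):
    seq_len = input_ids.shape[1]          # Python int from tensor metadata -- no GPU sync
    if seq_len <= max(graphed.graphs.keys()):
        return graphed(input_ids)         # frequent shapes: replay a captured graph
    return model(input_ids)               # rare shapes: fall back to eager mode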
Example 2: MoE Models with Token Dropless#
Problem: Mixture-of-Experts (MoE) models with token dropless mechanisms route different numbers of tokens to each expert based on the gating network’s output. This creates highly dynamic shapes that change every forward pass:
class MoELayer(nn.Module):
def __init__(self, num_experts=8, expert_capacity=None):
self.experts = nn.ModuleList([Expert() for _ in range(num_experts)])
self.gate = nn.Linear(hidden_dim, num_experts)
self.expert_capacity = expert_capacity # tokens per expert
def forward(self, tokens): # Shape: [batch * seq_len, hidden_dim]
# Gate decides which tokens go to which experts
routing_weights = self.gate(tokens) # [num_tokens, num_experts]
expert_assignments = routing_weights.argmax(dim=-1) # Dynamic routing
# Group tokens by expert - HIGHLY DYNAMIC SHAPES!
expert_inputs = []
for expert_id in range(len(self.experts)):
# Number of tokens varies per expert, per iteration
mask = (expert_assignments == expert_id)
expert_tokens = tokens[mask] # Shape: [dynamic_count, hidden_dim]
expert_inputs.append(expert_tokens)
# Process each expert - each has different input shape
expert_outputs = []
for expert, inputs in zip(self.experts, expert_inputs):
if len(inputs) > 0:
outputs = expert(inputs) # Shape varies!
expert_outputs.append(outputs)
# Combine outputs (also dynamic)
return combine_expert_outputs(expert_outputs, expert_assignments)
Why this is extremely dynamic:
Token distribution changes every forward pass based on input data
Each expert processes a different number of tokens (could be 0 to all tokens)
No predictable pattern—can’t use bucketing effectively
Even with capacity limits, shapes vary significantly
Solution - Graph only static parts:
Since the expert routing is inherently dynamic, graph the non-MoE layers and keep MoE layers in eager mode:
class MoEModel(nn.Module):
def __init__(self):
self.pre_layers = nn.Sequential(...) # Static computation
self.moe_layer = MoELayer(...) # Dynamic computation
self.post_layers = nn.Sequential(...) # Static computation
def forward(self, x):
x = self.pre_layers(x) # Can be graphed
x = self.moe_layer(x) # Must stay in eager mode
x = self.post_layers(x) # Can be graphed
return x
# Graph only the static parts
pre_layers_graphed = torch.cuda.make_graphed_callables(
model.pre_layers,
sample_args=(torch.zeros(batch_size, seq_len, hidden_dim, device='cuda'),)
)
post_layers_graphed = torch.cuda.make_graphed_callables(
model.post_layers,
sample_args=(torch.zeros(batch_size, seq_len, hidden_dim, device='cuda'),)
)
# Training loop with partial graphing
for batch in dataloader:
# Pre-processing: graphed (fixed shape)
pre_output = pre_layers_graphed(batch)
# MoE routing: eager mode (dynamic shapes)
moe_output = model.moe_layer(pre_output)
# Post-processing: graphed (fixed shape)
final_output = post_layers_graphed(moe_output)
loss = criterion(final_output, targets)
loss.backward()
optimizer.step()
Performance considerations:
Even with partial graphing, you can still achieve significant speedup if the non-MoE layers dominate compute time
For MoE-heavy models (many MoE layers), CUDA graphs may provide limited benefit
Alternative solutions for fully graphing MoE:
Token dropping: Set a fixed expert capacity (e.g., max 128 tokens per expert). If routing assigns more tokens, drop the excess. This makes shapes static and fully graphable, but may degrade model quality if too many tokens are dropped (see the sketch after this list)
Static expert assignment: Assign tokens to experts in a round-robin or predetermined pattern instead of learned routing. This makes shapes predictable and graphable, but sacrifices the adaptive routing that makes MoE effective
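A minimal sketch of the token dropping alternative, assuming top-1 routing (expert_assignments from argmax as in the MoELayer above): each expert gets a fixed-capacity buffer, overflow tokens are redirected to a discarded trash row, and every shape depends only on num_experts and expert_capacity, never on the routing result.
import torch
import torch.nn.functional as F

def dispatch_with_drop(tokens, expert_assignments, num_experts, expert_capacity):
    """Fixed-capacity dispatch: output is always [num_experts, expert_capacity, hidden_dim]."""
    num_tokens, hidden_dim = tokens.shape
    one_hot = F.one_hot(expert_assignments, num_experts)              # [tokens, experts]
    # 0-based slot of each token within its expert (running count of earlier tokens)
    slot = (one_hot.cumsum(dim=0) * one_hot).sum(dim=-1) - 1
    keep = slot < expert_capacity                                     # overflow tokens get dropped
    # Flattened destination; dropped tokens go to an extra trash row at the end
    dest = expert_assignments * expert_capacity + slot.clamp(max=expert_capacity - 1)
    dest = torch.where(keep, dest, torch.full_like(dest, num_experts * expert_capacity))
    buffer = torch.zeros(num_experts * expert_capacity + 1, hidden_dim,
                         device=tokens.device, dtype=tokens.dtype)
    buffer.index_copy_(0, dest, tokens)   # duplicate writes only hit the trash row, discarded below
    return buffer[:-1].view(num_experts, expert_capacity, hidden_dim)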
Key lesson: Don’t force CUDA graphs where they don’t fit. MoE models with dynamic routing are a prime example where partial graphing (static layers only) is the pragmatic solution. The dynamic parts stay in eager mode where PyTorch handles variable shapes naturally.
What’s Next?#
Quick Checklist: Complete checklist to verify your code is ready for capture
Examples: Real-world implementations in RNN-T, Stable Diffusion, and Llama 3.1 405B
Troubleshooting: Comprehensive guide for debugging capture failures, silent failures, memory errors, and performance issues
Constraints: Full list of CUDA Graph restrictions
PyTorch-Specific Constraints: PyTorch requirements