Quick Checklist

Note

A complete checklist to verify your PyTorch code is ready for CUDA Graph capture. Check each item before attempting to graph your workload.

Asynchronous Execution

  • No host-device synchronization (sync-free code, capture failures)

    • No explicit sync: torch.cuda.synchronize(), stream.synchronize(), event.synchronize() (details)

    • No blocking GPU→CPU transfers: .item(), .cpu(), .numpy(), print(tensor) (details); see the sync-free sketch after this list

    • No direct CUDA tensor creation from Python objects (details)

    • No data-dependent control flow: if tensor:, loss.item() (details)

    • No GPU tensor indexing with CPU tensors or Python lists (details)

    • No slicing with CUDA tensor bounds: x[i:j] where i, j are CUDA tensors (details)

  • No default stream usage

    • Tensors with gradient tape execute on side stream before capture (details)

    • DDP/FSDP initialized on side stream (details)

    • Ensure extensions/libraries use PyTorch’s current stream, not default stream (details)

  • No event/stream query

    • No stream.query(), event.query() during capture (details)

    • No background thread queries (details)

    • No pinned memory allocation during capture (triggers hidden event query) (details)

    • DataLoader with pin_memory=True: use thread_local mode or disable pin_memory (details)

    • NCCL watchdog handled (auto in PyTorch 2.2+) (details)
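
One way to verify this group of items is to run a representative step with PyTorch's synchronization debug mode enabled and to keep scalars on the GPU until you are outside the region you intend to capture. A minimal sketch (the linear model and the loss are placeholders):

```python
import torch

device = torch.device("cuda")
model = torch.nn.Linear(1024, 1024).to(device)   # placeholder model
x = torch.randn(32, 1024, device=device)
running_loss = torch.zeros((), device=device)    # GPU accumulator instead of a Python float

# Flag any operation that implicitly synchronizes with the host:
# "warn" prints a warning, "error" raises, which is handy while cleaning up a training step.
torch.cuda.set_sync_debug_mode("error")

loss = model(x).square().mean()
running_loss += loss.detach()      # stays on the GPU: no .item(), .cpu(), or print(tensor)

torch.cuda.set_sync_debug_mode("default")
print(running_loss.item())         # blocking transfer, done outside the sync-free region
```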

Static Graph

  • Static graph topology (details)

    • No dynamic control flow (if/else based on tensor values); use torch.where() or capture multiple graphs (details)

    • Gradient clipping: use sync-free clip_grad_norm_ (PyTorch 1.13+) (details)

    • Early exit and adaptive inference: capture separate graphs per path (details)

    • Capture-aware code (is_current_stream_capturing()) doesn’t change computation (details)

  • Static memory addresses (details)

    • Static input tensors allocated before capture, updated via .copy_() (details); see the capture-and-replay sketch after this list

    • Global tensors used within graph are persistent (details)

    • Grouped GEMM / pointer arrays: keep host pointer tensors alive (details)

    • AMP autocast cache disabled (cache_enabled=False) or capture autocast inside graph (details)

  • Static scalars (details)

    • CPU variable scalars converted to GPU tensors, updated via .fill_() (details)

    • Learning rate / global step: use capturable optimizer (e.g., APEX FusedAdam) (details)

    • RNG state handled correctly

      • Custom generators registered with graph.register_generator_state() (details)

      • Use graph-safe APIs: graphsafe_get_state(), graphsafe_set_state() (details)

      • Activation checkpointing uses preserve_rng_state=False (details)

      • Partial graphing uses use_reentrant=False (details)

      • torch.compile functions warmed up before capture (details)

  • Static shapes (details)

    • Tensor shapes fixed across replays; use padding or bucketing (details); see the padding sketch after this list

    • MoE with dynamic routing: graph only static parts (details)
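
Most of the items under “Static memory addresses” and “Static scalars” reduce to one pattern: allocate placeholder tensors once, capture the work that reads them, and refresh them in place before every replay. The sketch below follows the usual torch.cuda.graph warmup/capture/replay recipe; the model, shapes, and the use of torch.optim.Adam's capturable flag (in place of the APEX FusedAdam named above) are illustrative choices:

```python
import torch

device = torch.device("cuda")
model = torch.nn.Linear(512, 512).to(device)
# The checklist names APEX FusedAdam; torch.optim.Adam's capturable flag is used here instead.
opt = torch.optim.Adam(model.parameters(), lr=1e-3, capturable=True)

# Placeholders allocated once, before capture; their addresses must stay fixed.
static_input = torch.zeros(64, 512, device=device)
static_scale = torch.ones((), device=device)      # GPU scalar instead of a Python float

# Warmup on a side stream (see "Other Considerations").
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        opt.zero_grad(set_to_none=True)
        (model(static_input) * static_scale).sum().backward()
        opt.step()
torch.cuda.current_stream().wait_stream(s)

# Capture forward, backward, and the optimizer step.
g = torch.cuda.CUDAGraph()
opt.zero_grad(set_to_none=True)   # grads set to None so backward allocates them from the graph pool
with torch.cuda.graph(g):
    static_loss = (model(static_input) * static_scale).sum()
    static_loss.backward()
    opt.step()

# Replay loop: refresh the placeholders in place; never rebind the names.
for step in range(10):
    batch = torch.randn(64, 512, device=device)   # stand-in for real data
    static_input.copy_(batch)                     # same storage, new values
    static_scale.fill_(0.5)                       # update the GPU scalar via .fill_()
    g.replay()                                    # no zero_grad() needed: the captured backward
                                                  # rewrites the static .grad tensors in place
```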
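
For the control-flow and shape items, branches can usually be rewritten as tensor arithmetic with torch.where, and variable-length inputs padded to a fixed bucket so every replay sees the same shapes. A rough sketch (the bucket length and the scaling rule are illustrative):

```python
import torch
import torch.nn.functional as F

device = torch.device("cuda")

# Branch on a tensor value without leaving the GPU: torch.where instead of `if loss > 1.0:`.
loss = torch.randn((), device=device).abs()
threshold = torch.tensor(1.0, device=device)
scaled_loss = torch.where(loss > threshold, 0.5 * loss, loss)

# Keep shapes static across replays by padding to a fixed bucket length.
BUCKET_LEN = 128                                            # illustrative bucket size
tokens = torch.randint(0, 1000, (1, 97), device=device)     # this batch happens to be length 97
padded = F.pad(tokens, (0, BUCKET_LEN - tokens.shape[1]))   # always (1, 128) at replay time
```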

Self-Contained Stream Capture

  • Side streams fork from capture stream via side_stream.wait_stream(capture_stream) (details)

  • Side streams join back to capture stream via capture_stream.wait_stream(side_stream) (details)

  • No dependency on external work, or use external events (details)
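
If the captured region itself spans multiple streams, the fork and the join must both be expressed against the capture stream, as in this sketch (the matmul and the elementwise add stand in for whatever work you want to overlap):

```python
import torch

device = torch.device("cuda")
a = torch.randn(1024, 1024, device=device)
b = torch.randn(1024, 1024, device=device)

# Warm up the kernels outside the capture.
_ = a @ b
_ = a + b
torch.cuda.synchronize()

side = torch.cuda.Stream()
g = torch.cuda.CUDAGraph()

with torch.cuda.graph(g):
    capture_stream = torch.cuda.current_stream()

    # Fork: the side stream waits on the capture stream before doing any work.
    side.wait_stream(capture_stream)
    with torch.cuda.stream(side):
        out_side = a @ b              # recorded on the side stream

    out_main = a + b                  # recorded on the capture stream

    # Join: the capture stream waits on the side stream before the capture ends.
    capture_stream.wait_stream(side)
    out = out_main + out_side         # safe to combine after the join

g.replay()
```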

CPU Code Is Not Captured

  • No host state mutation inside graph that affects code outside (details)

  • CPU code requiring execution on every replay moved outside graph (details)

  • Use cudaLaunchHostFunc() for CPU code that must run inside the graph
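
The pitfall is easy to demonstrate: Python statements inside the captured region run exactly once, at capture time, so per-replay host bookkeeping has to live outside the graph. A small sketch:

```python
import torch

device = torch.device("cuda")
x = torch.randn(8, device=device)
host_counter = 0

# Minimal warmup so the kernel exists before capture.
_ = x * 2
torch.cuda.synchronize()

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    host_counter += 1     # Python statement: runs once, at capture time only
    y = x * 2             # GPU kernel: recorded into the graph

for _ in range(5):
    g.replay()            # replays only the recorded GPU work

print(host_counter)       # 1, not 6: per-replay CPU work has to live outside the graph
```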

Memory Requirements

  • No pinned memory alloc/free in global mode (details)

  • Persistent graph input tensors

    • CUDA input tensors not freed before graph replay (details)

    • CPU input tensors (for H2D copy) kept alive for graph lifetime (details)

  • No cross-iteration reuse of output tensor without cloning (details); see the sketch after this list

  • Memory pool sharing (if using shared pools):

    • Intermediate tensors exposed as output handled carefully (details)

    • Replay order matches capture order (details)

    • No parallel replay of graphs sharing pools (details)

  • Memory usage awareness (for OOM prevention):

    • Reuse static input tensors across graphs when possible (details)

    • Chain graph outputs as inputs to next graph (details)

    • Be aware: intermediate tensors can’t be reused across different pools (details)

    • Be aware: operations after capture can’t reuse graph pool memory (details)

    • Be aware: memory fragmentation across pools (details)

    • Be aware: deferred memory recycling with multi-streams during capture (details)

    • Be aware: gradient accumulator cross-stream growth (details)

    • Be aware: cudaFree is suppressed during capture (details)
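
Two of the rules above are easy to violate in a replay loop: the graph's input placeholders must stay alive, and an output you want to keep across iterations must be cloned, because the next replay overwrites its storage. The sketch below also chains two graphs through a shared memory pool and replays them in capture order; the shapes and workloads are placeholders:

```python
import torch

device = torch.device("cuda")
static_x = torch.randn(256, 256, device=device)   # keep this reference alive for the graphs' lifetime

# Warm up the kernels outside the capture.
_ = static_x @ static_x
_ = static_x + 1
torch.cuda.synchronize()

g1 = torch.cuda.CUDAGraph()
with torch.cuda.graph(g1):
    out1 = static_x @ static_x                    # output lives in g1's private pool

# g2 shares g1's pool and chains g1's output as its input.
g2 = torch.cuda.CUDAGraph()
with torch.cuda.graph(g2, pool=g1.pool()):
    out2 = out1 + 1

results = []
for step in range(3):
    static_x.copy_(torch.randn(256, 256, device=device))
    g1.replay()
    g2.replay()                   # replay in the same order as capture, never in parallel
    results.append(out2.clone())  # clone before reuse: the next replay overwrites out2's storage
```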

Other Considerations

  • Warmup iterations before capture on the same side stream (details)

  • Capture mode: Use global mode unless specific multi-threading needs require thread_local or relaxed (details); see the capture-mode sketch after this list

  • Module hooks: Only top-level module hooks fire with make_graphed_callables (details); see the make_graphed_callables sketch after this list

  • Deferred gradient hooks: make_graphed_callables defers gradient accumulation and DDP hooks (details)

  • NCCL communicator lifecycle: Destroy CUDA graphs before NCCL communicators (details)

  • Pinned memory race condition: Synchronize before CPU writes to pinned memory (details)

  • Stream count: Avoid too many streams to prevent channel serialization (details)

  • NCCL buffer registration: Set NCCL_GRAPH_REGISTER=0 if using expandable segments with older NCCL (details)
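
For partial graphing, torch.cuda.make_graphed_callables performs the warmup and placeholder bookkeeping internally; keep in mind the hook caveats above (only top-level module hooks fire, and gradient-accumulation/DDP hooks are deferred). A minimal sketch with a placeholder model:

```python
import torch

device = torch.device("cuda")
model = torch.nn.Sequential(
    torch.nn.Linear(256, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 10),
).to(device)

# Sample args fix the shapes, dtypes, and requires_grad states the graphs assume;
# real inputs at each step must match them.
sample_input = torch.randn(32, 256, device=device)
graphed_model = torch.cuda.make_graphed_callables(model, (sample_input,))

opt = torch.optim.SGD(model.parameters(), lr=0.1)
for step in range(3):
    opt.zero_grad(set_to_none=True)
    x = torch.randn(32, 256, device=device)   # same shape and requires_grad as the sample
    loss = graphed_model(x).sum()
    loss.backward()                            # replays the captured backward graph
    opt.step()                                 # left eager (not graphed)
```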
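
If a background thread must keep touching the CUDA driver while you capture (for example a DataLoader pin-memory thread), the capture error mode can be relaxed from the default global mode. A sketch, assuming a PyTorch version whose torch.cuda.graph exposes the capture_error_mode argument:

```python
import torch

device = torch.device("cuda")
x = torch.zeros(16, device=device)

# Warm up the kernel outside the capture.
_ = x + 1
torch.cuda.synchronize()

g = torch.cuda.CUDAGraph()
# In "thread_local" mode only disallowed CUDA calls made from this thread invalidate
# the capture, so pin-memory workers or watchdog threads elsewhere do not break it.
with torch.cuda.graph(g, capture_error_mode="thread_local"):
    y = x + 1
```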

What’s Next?