nemo_automodel.components.distributed.activation_checkpointing
Selective activation checkpointing core.
TorchTitan-style selective activation checkpointing: a policy decides, per op, whether to save or recompute an activation. Expensive ops (attention, every other matmul, communication collectives) are saved, while cheap ones are recomputed.
This module holds the parts of the AC implementation that do not depend on the
rest of parallelizer.py (notably the heavy, transformers-aware
_extract_model_layers). parallelizer.py imports from here — never the
other way around — so the dependency stays one-directional and the central
parallelizer file stays small.
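The per-op decision described above can be sketched in plain Python. This is a minimal illustration and not the module's implementation: `Policy` stands in for `torch.utils.checkpoint.CheckpointPolicy`, ops are represented by hypothetical string names rather than torch op overloads, and `make_policy` is a hypothetical helper. It assumes the alternating-matmul rule saves the odd-numbered matmuls and keeps separate counters for the forward and recompute passes.

```python
from collections import defaultdict
from enum import Enum


class Policy(Enum):
    """Stand-in for torch.utils.checkpoint.CheckpointPolicy."""

    MUST_SAVE = "must_save"
    PREFER_RECOMPUTE = "prefer_recompute"


# Hypothetical op names; a real save-set would hold torch op overloads.
SAVE_OPS = {"attention", "mm", "all_reduce"}
MATMUL_OPS = {"mm"}


def make_policy(save_ops=SAVE_OPS, matmul_ops=MATMUL_OPS):
    # Separate matmul counters for the forward and recompute passes, so the
    # recompute pass reaches the same decisions as the forward pass did.
    counts = defaultdict(int)

    def policy(is_recompute, op):
        phase = "recompute" if is_recompute else "forward"
        if op in matmul_ops:
            counts[phase] += 1
            # Assumption: every second matmul is recomputed instead of saved.
            if counts[phase] % 2 == 0:
                return Policy.PREFER_RECOMPUTE
        return Policy.MUST_SAVE if op in save_ops else Policy.PREFER_RECOMPUTE

    return policy
```

Calling the returned `policy` on the op sequence `mm, relu, mm, attention, mm` in the forward pass saves the first and third matmul and the attention op, and recomputes the rest.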
API
Build the set of ops whose activations are always saved under selective AC.
The set is seeded from PyTorch’s compute-intensive op list and supplemented with attention variants, low-precision/reduction ops, the compiled HOP, and communication collectives whose outputs are expensive to recompute.
Compute-intensive aten ops from PyTorch’s partitioner, or () if unavailable.
Mirrors TorchTitan: seeding from PyTorch’s own compute_intensive_ops list
keeps the save-set in sync with upstream rather than relying on a frozen,
hand-maintained list. torch._functorch.partitioners is a private API, so
any failure falls back to the curated supplement in
_build_selective_ac_save_ops.
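The seed-plus-supplement pattern with a silent fallback might look like the sketch below. The names `load_seed_ops`, `build_save_ops`, and `broken_loader` are hypothetical, ops are plain strings, and `loader` stands in for whatever call reads the compute-intensive list from the private PyTorch partitioner API. The real lookup is deliberately left out.

```python
def load_seed_ops(loader):
    """Return the ops produced by ``loader`` as a tuple, or () if it fails."""
    try:
        return tuple(loader())
    except Exception:
        # Private APIs can move or change shape between releases, so any
        # failure degrades to an empty seed instead of raising.
        return ()


def build_save_ops(loader, supplement):
    """Union of the upstream seed (if available) and the curated supplement."""
    return frozenset(load_seed_ops(loader)) | frozenset(supplement)


def broken_loader():
    # Simulates the private module being unavailable.
    raise ImportError("private module moved")
```

With a working loader, the result is the union of both sources. With `broken_loader`, only the curated supplement remains, so the save-set is never empty as long as the supplement is not.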
Best-effort disable of TorchDynamo’s LRU cache for selective AC + compile.
With multiple pipeline microbatches, dynamo may compile a second graph with dynamic shapes and then select it over the static graph whose compiled-HOP output SAC cached for microbatch 0, tripping a missing-symint assertion. Selecting graphs in insertion order avoids this. Mirrors TorchTitan. The underlying API is private, so failures are swallowed.
Log a selective-AC decision once per op (no-op unless tracing is enabled).
Parameters:
The op the policy was queried about.
The CheckpointPolicy the policy returned for func.
Whether func is an alternating-save matmul op.
Whether the policy was queried during the recompute pass; decisions are only logged on the forward pass to avoid duplicates.
Replace target with replacement in root’s module tree.
Resolve a dotted attribute path from root, or None if any part is absent.
Used for ops that live outside torch.ops (higher-order ops, optional
custom backends such as DeepEP/HybridEP). Missing namespaces/ops raise
AttributeError on access, so they are swallowed and reported as None.
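Dotted-path resolution that swallows `AttributeError` can be sketched with the standard library alone. `resolve_attr_path` is a hypothetical name, and the `backend` object below is a made-up namespace standing in for an optional custom backend.

```python
from types import SimpleNamespace


def resolve_attr_path(root, dotted_path):
    """Follow ``dotted_path`` from ``root``; return None if any part is missing."""
    obj = root
    for part in dotted_path.split("."):
        try:
            obj = getattr(obj, part)
        except AttributeError:
            # A missing namespace or op is reported as None, not raised.
            return None
    return obj


# Hypothetical namespace standing in for an optional custom backend.
backend = SimpleNamespace(ops=SimpleNamespace(dispatch="dispatch-op"))
```

Returning `None` lets the caller drop absent ops from the save-set without wrapping each lookup in its own `try` block.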
Resolve torch.ops.<namespace>.<name>.<overload>, or None if absent.
Wrap whole transformer blocks with the selective-AC policy.
KV-shared models cannot checkpoint attention through the DynamicCache,
so they fall back to sub-module checkpointing. layers is mutated in
place so callers that retain the list (e.g. for subsequent FSDP sharding)
see the wrapped modules. Works without FSDP/distributed, so it is shared by
the FSDP2 strategy and the single-GPU path.
Wrap a transformer block’s sub-modules with checkpoint_wrapper.
This is the sub-module granularity path used both as the default (non-compile) behavior and as the fallback for selective activation checkpointing on KV-shared models, which cannot checkpoint the whole block.
self_attn is skipped for KV-shared models: recomputing attention during
backward would double-write to the DynamicCache, corrupting the K/V
entries that later shared layers depend on.
Parameters:
Transformer decoder layers to wrap (mutated in place).
Whether the model reuses K/V across layers via the cache.
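The skip rule for `self_attn` can be illustrated without PyTorch. In this sketch `Wrapped` is a stand-in for the result of `checkpoint_wrapper`, `wrap_submodules` is a hypothetical name, and the sub-module list `("self_attn", "mlp")` is an assumption. Only `self_attn` is named in the text above.

```python
from types import SimpleNamespace


class Wrapped:
    """Stand-in for a checkpoint-wrapped module."""

    def __init__(self, inner):
        self.inner = inner


def wrap_submodules(layers, kv_shared, names=("self_attn", "mlp")):
    """Wrap the named sub-modules of each layer, mutating the layers in place."""
    for layer in layers:
        for name in names:
            if kv_shared and name == "self_attn":
                # Recomputing attention would write to the cache a second time.
                continue
            child = getattr(layer, name, None)
            if child is not None:
                setattr(layer, name, Wrapped(child))
    return layers
```

Because each layer object is modified in place, any caller that still holds the original list sees the wrapped sub-modules.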
Detect KV-sharing and disable use_cache for non-KV-shared models.
Models with KV-shared layers (e.g. Gemma4 2B/4B) pass K/V from earlier
layers to later layers through the DynamicCache; disabling the cache
breaks that dependency, so use_cache is left untouched for them.
Returns: bool
Whether the model uses KV-sharing.
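One way to express this detect-then-disable behavior is sketched below. The source does not say how KV-sharing is detected, so the config attribute `num_kv_shared_layers` is an assumption, as is the function name `detect_kv_sharing`. The config is modeled as a plain namespace object.

```python
from types import SimpleNamespace


def detect_kv_sharing(config, attr="num_kv_shared_layers"):
    """Return True if KV-shared; otherwise turn the cache off and return False."""
    # Assumption: a positive count of shared layers marks a KV-shared model.
    kv_shared = bool(getattr(config, attr, 0) or 0)
    if not kv_shared and hasattr(config, "use_cache"):
        # Safe to disable: no later layer depends on cached K/V.
        config.use_cache = False
    return kv_shared
```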
Return whether the config value selects selective activation checkpointing.
Parameters:
The configured value (bool or string such as
"selective"/"full").
Returns: bool
True only for the string "selective" (case- and
Build a TorchTitan-style selective activation checkpointing context.