nemo_automodel.components.distributed.activation_checkpointing


Selective activation checkpointing core.

TorchTitan-style selective activation checkpointing: a policy decides, per op, whether to save or recompute each activation. It saves the outputs of expensive ops (attention, every other matmul, and communication collectives) and recomputes the cheap ones.
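The per-op rule can be illustrated without PyTorch. This is a minimal sketch in which strings are hypothetical stand-ins for ops and a local `Decision` enum replaces torch's `CheckpointPolicy`. Following TorchTitan, it assumes the 1st, 3rd, 5th, ... matmul is saved and the others are recomputed.

```python
from enum import Enum, auto


class Decision(Enum):
    SAVE = auto()
    RECOMPUTE = auto()


# Hypothetical stand-ins: a real policy keys on torch op objects, not strings.
MUST_SAVE = frozenset({"attention", "all_reduce"})
MATMULS = frozenset({"mm"})


def decide_all(ops):
    """Return one decision per op, in forward order."""
    mm_count = 0
    decisions = []
    for op in ops:
        if op in MUST_SAVE:
            decisions.append(Decision.SAVE)
        elif op in MATMULS:
            mm_count += 1
            # Save odd-numbered matmuls, recompute even-numbered ones.
            decisions.append(Decision.SAVE if mm_count % 2 == 1 else Decision.RECOMPUTE)
        else:
            # Cheap ops (activations, norms, adds) are recomputed in backward.
            decisions.append(Decision.RECOMPUTE)
    return decisions
```

For example, `decide_all(["mm", "gelu", "mm", "attention", "mm"])` yields SAVE, RECOMPUTE, RECOMPUTE, SAVE, SAVE.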

This module holds the parts of the AC implementation that do not depend on the rest of parallelizer.py (notably the heavy, transformers-aware _extract_model_layers). parallelizer.py imports from this module, never the other way around, so the dependency stays one-directional and the central parallelizer file stays small.

Module Contents

Functions

| Name | Description |
| --- | --- |
| _build_selective_ac_save_ops | Build the set of ops whose activations are always saved under selective AC. |
| _default_compute_intensive_ops | Compute-intensive aten ops from PyTorch’s partitioner, or () if unavailable. |
| _disable_dynamo_lru_cache | Best-effort disable of TorchDynamo’s LRU cache for selective AC + compile. |
| _existing_ops | - |
| _is_cuda_to_cpu_copy | - |
| _maybe_trace_selective_ac_decision | Log a selective-AC decision once per op (no-op unless tracing is enabled). |
| _replace_child_module | Replace target with replacement in root’s module tree. |
| _resolve_op_attr | Resolve a dotted attribute path from root, or None if any part is absent. |
| _resolve_torch_op | Resolve torch.ops.<namespace>.<name>.<overload>, or None if absent. |
| apply_selective_checkpointing_to_layers | Wrap whole transformer blocks with the selective-AC policy. |
| apply_submodule_checkpointing | Wrap a transformer block’s sub-modules with checkpoint_wrapper. |
| detect_kv_sharing_and_maybe_disable_cache | Detect KV-sharing and disable use_cache for non-KV-shared models. |
| is_selective_activation_checkpointing | Return whether the config value selects selective activation checkpointing. |
| make_selective_checkpoint_context_fn | Build a TorchTitan-style selective activation checkpointing context. |

Data

SELECTIVE_AC_WRAPPER_FLAG

_SELECTIVE_AC_MATMUL_OPS

_SELECTIVE_AC_MUST_SAVE_OPS

_SELECTIVE_AC_TO_COPY_OP

_SELECTIVE_AC_TRACE

_SELECTIVE_AC_TRACE_SEEN

logger

API

nemo_automodel.components.distributed.activation_checkpointing._build_selective_ac_save_ops() -> frozenset

Build the set of ops whose activations are always saved under selective AC.

The set is seeded from PyTorch’s compute-intensive op list and supplemented with attention variants, low-precision/reduction ops, the compiled HOP, and communication collectives whose outputs are expensive to recompute.
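A minimal sketch of how such a set might be assembled, assuming unresolved ops come back as None (as _resolve_torch_op and _resolve_op_attr document). `build_save_ops` is a hypothetical name, and strings stand in for op objects.

```python
from typing import Iterable, Optional


def build_save_ops(seed: Iterable[object], supplement: Iterable[Optional[object]]) -> frozenset:
    """Union the upstream seed with a curated supplement, dropping unresolved (None) ops."""
    # Optional backends (e.g. an absent custom collective) resolve to None and are skipped.
    return frozenset(op for op in (*seed, *supplement) if op is not None)
```

Returning a frozenset keeps the membership test that the policy runs on every op cheap and makes the set immutable once built.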

nemo_automodel.components.distributed.activation_checkpointing._default_compute_intensive_ops() -> tuple

Compute-intensive aten ops from PyTorch’s partitioner, or () if unavailable.

Mirrors TorchTitan: seeding from PyTorch’s own compute_intensive_ops list keeps the save-set in sync with upstream rather than relying on a frozen, hand-maintained list. torch._functorch.partitioners is a private API, so any failure falls back to the curated supplement in _build_selective_ac_save_ops.
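The fallback pattern can be sketched as below. It assumes the private module exposes `get_default_op_list()` returning an object with a `compute_intensive_ops` collection; since that is a private API, the names are assumptions, and any failure (including torch not being installed) yields `()`.

```python
def default_compute_intensive_ops() -> tuple:
    """Seed ops from PyTorch's partitioner, or () if the private API is unavailable."""
    try:
        # Assumed private API: module path and attribute names may change between releases.
        from torch._functorch.partitioners import get_default_op_list

        return tuple(get_default_op_list().compute_intensive_ops)
    except Exception:
        # torch missing, module moved, or attribute renamed: fall back to an empty seed.
        return ()
```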

nemo_automodel.components.distributed.activation_checkpointing._disable_dynamo_lru_cache() -> None

Best-effort disable of TorchDynamo’s LRU cache for selective AC + compile.

With multiple pipeline microbatches, Dynamo may compile a second, dynamic-shape graph and then select it over the static graph for which SAC cached the compiled higher-order op (HOP) output during microbatch 0; this trips an assertion about a missing SymInt. Selecting graphs in insertion order avoids this. This mirrors TorchTitan. The underlying API is private, so failures are swallowed.
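A best-effort toggle of this kind might look like the sketch below. The hook `torch._C._dynamo.eval_frame._set_lru_cache` is an assumption (a private API that TorchTitan has used), and `disable_dynamo_lru_cache` is a hypothetical name. Unlike the documented function, which returns None, the sketch returns a bool so the outcome is observable.

```python
import importlib


def disable_dynamo_lru_cache() -> bool:
    """Best-effort toggle; returns True only if the private hook was found and called."""
    try:
        torch = importlib.import_module("torch")
        # Assumed private hook: with the LRU cache off, Dynamo picks graphs in insertion order.
        torch._C._dynamo.eval_frame._set_lru_cache(False)
        return True
    except Exception:
        # Swallow all failures (no torch, hook missing or renamed): the call is best-effort.
        return False
```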

nemo_automodel.components.distributed.activation_checkpointing._existing_ops(
ops = ()
)

nemo_automodel.components.distributed.activation_checkpointing._is_cuda_to_cpu_copy(
func,
args,
kwargs
) -> bool

nemo_automodel.components.distributed.activation_checkpointing._maybe_trace_selective_ac_decision(
func,
decision,
is_alternating: bool,
is_recompute: bool
) -> None

Log a selective-AC decision once per op (no-op unless tracing is enabled).

Parameters:

func

The op the policy was queried about.

decision

The CheckpointPolicy the policy returned for func.

is_alternating
bool

Whether func is an alternating-save matmul op.

is_recompute
bool

Whether the policy was queried during the recompute pass; decisions are only logged on the forward pass to avoid duplicates.
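A minimal sketch of "log once per op, forward pass only", assuming the op's string form is a stable key. `maybe_trace_decision` and `TRACE_ENABLED` are hypothetical names; the documented falsy-value tuple for NEMO_SELECTIVE_AC_TRACE is truncated, so only its visible entries are used here.

```python
import logging
import os
from typing import Set

logger = logging.getLogger("selective_ac_trace_sketch")

# Hypothetical parsing: only the visible falsy values "0", "" and "false" are treated as off.
TRACE_ENABLED = os.environ.get("NEMO_SELECTIVE_AC_TRACE", "0").lower() not in ("0", "", "false")
_seen: Set[str] = set()


def maybe_trace_decision(func, decision, is_alternating: bool, is_recompute: bool,
                         enabled: bool = TRACE_ENABLED) -> bool:
    """Log the first forward-pass decision per op; return True if a line was logged."""
    if not enabled or is_recompute:
        # Tracing off, or recompute pass (its decisions would duplicate the forward ones).
        return False
    key = str(func)
    if key in _seen:
        return False
    _seen.add(key)
    logger.info("selective AC: %s -> %s (alternating=%s)", key, decision, is_alternating)
    return True
```

Because only the first decision per op is logged, alternating matmuls appear once with whichever decision came first.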

nemo_automodel.components.distributed.activation_checkpointing._replace_child_module(
root: torch.nn.Module,
target: torch.nn.Module,
replacement: torch.nn.Module
) -> bool

Replace target with replacement in root’s module tree.
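One way to implement such a swap, sketched under the assumption that identity (`is`) is the match criterion and that the first matching slot is replaced. `replace_child_module` is a hypothetical name; the bool return mirrors the documented signature.

```python
import torch.nn as nn


def replace_child_module(root: nn.Module, target: nn.Module, replacement: nn.Module) -> bool:
    """Swap the first child slot that holds `target`; return False if it is not found."""
    for parent in root.modules():
        for name, child in parent.named_children():
            if child is target:
                # setattr on an nn.Module re-registers the submodule under the same name.
                setattr(parent, name, replacement)
                return True  # Stop before the module tree is iterated again.
    return False
```

Since `root` is never its own child, asking this sketch to replace `root` itself returns False.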

nemo_automodel.components.distributed.activation_checkpointing._resolve_op_attr(
root: object,
dotted_path: str
)

Resolve a dotted attribute path from root, or None if any part is absent.

Used for ops that live outside torch.ops (higher-order ops, optional custom backends such as DeepEP/HybridEP). Missing namespaces/ops raise AttributeError on access, so they are swallowed and reported as None.
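The lookup reduces to chained `getattr` calls with AttributeError mapped to None. `resolve_op_attr` is a hypothetical name, and the usage below uses a stand-in namespace rather than a real backend.

```python
from functools import reduce
from types import SimpleNamespace
from typing import Optional


def resolve_op_attr(root: object, dotted_path: str) -> Optional[object]:
    """Follow `a.b.c` from root; return None as soon as any segment is missing."""
    try:
        return reduce(getattr, dotted_path.split("."), root)
    except AttributeError:
        # Missing namespace or op (e.g. an optional backend that is not installed).
        return None


# Usage: a stand-in namespace instead of torch.ops or an optional custom backend.
ops = SimpleNamespace(higher_order=SimpleNamespace(flex_attention="hop"))
```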

nemo_automodel.components.distributed.activation_checkpointing._resolve_torch_op(
namespace: str,
name: str,
overload: str = 'default'
)

Resolve torch.ops.<namespace>.<name>.<overload>, or None if absent.
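A sketch of the same idea for registered torch ops. It relies on missing ops and overloads raising AttributeError on access (RuntimeError is also caught to be safe), and it returns None when torch itself is unavailable. `resolve_torch_op` is a hypothetical name.

```python
import importlib
from typing import Optional


def resolve_torch_op(namespace: str, name: str, overload: str = "default") -> Optional[object]:
    """Return torch.ops.<namespace>.<name>.<overload>, or None if torch or the op is absent."""
    try:
        ops = importlib.import_module("torch").ops
        return getattr(getattr(getattr(ops, namespace), name), overload)
    except (ImportError, AttributeError, RuntimeError):
        return None
```

For example, `resolve_torch_op("aten", "mm")` gives the `aten.mm.default` overload object when torch is installed.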

nemo_automodel.components.distributed.activation_checkpointing.apply_selective_checkpointing_to_layers(
model: torch.nn.Module,
layers: typing.List[torch.nn.Module],
has_kv_sharing: bool,
enable_compile: bool = False
) -> None

Wrap whole transformer blocks with the selective-AC policy.

KV-shared models cannot checkpoint attention through the DynamicCache, so they fall back to sub-module checkpointing. layers is mutated in place so callers that retain the list (e.g. for subsequent FSDP sharding) see the wrapped modules. Works without FSDP/distributed, so it is shared by the FSDP2 strategy and the single-GPU path.

nemo_automodel.components.distributed.activation_checkpointing.apply_submodule_checkpointing(
layers: typing.List[torch.nn.Module],
has_kv_sharing: bool
) -> None

Wrap a transformer block’s sub-modules with checkpoint_wrapper.

This is the sub-module granularity path used both as the default (non-compile) behavior and as the fallback for selective activation checkpointing on KV-shared models, which cannot checkpoint the whole block.

self_attn is skipped for KV-shared models: recomputing attention during backward would double-write to the DynamicCache, corrupting the K/V entries that later shared layers depend on.

Parameters:

layers
List[nn.Module]

Transformer decoder layers to wrap (mutated in place).

has_kv_sharing
bool

Whether the model reuses K/V across layers via the cache.
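The sub-module path can be sketched as follows. The set of wrapped attribute names (`self_attn`, `mlp`) is a hypothetical choice, since real decoder layers may name their sub-modules differently, and `wrap_submodules` is a hypothetical name.

```python
from typing import List, Sequence

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

# Hypothetical sub-module names; real decoder layers may use different attributes.
SUBMODULE_NAMES: Sequence[str] = ("self_attn", "mlp")


def wrap_submodules(layers: List[nn.Module], has_kv_sharing: bool) -> None:
    """Wrap selected sub-modules of each layer with checkpoint_wrapper, in place."""
    for layer in layers:
        for name in SUBMODULE_NAMES:
            if name == "self_attn" and has_kv_sharing:
                # Recomputing attention would write K/V into the shared cache a second time.
                continue
            child = getattr(layer, name, None)
            if isinstance(child, nn.Module):
                setattr(layer, name, checkpoint_wrapper(child))
```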

nemo_automodel.components.distributed.activation_checkpointing.detect_kv_sharing_and_maybe_disable_cache(
model: torch.nn.Module
) -> bool

Detect KV-sharing and disable use_cache for non-KV-shared models.

Models with KV-shared layers (e.g. Gemma4 2B/4B) pass K/V from earlier layers to later layers through the DynamicCache; disabling the cache breaks that dependency, so use_cache is left untouched for them.

Returns: bool

Whether the model uses KV-sharing.
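A duck-typed sketch of the detect-then-disable logic. How KV-sharing is actually detected is not stated here; this sketch assumes it is advertised by a positive `num_kv_shared_layers` count on `model.config`, and `detect_kv_sharing` is a hypothetical name.

```python
from types import SimpleNamespace


def detect_kv_sharing(model) -> bool:
    """Return True if the config reports KV-shared layers; otherwise turn off use_cache."""
    config = getattr(model, "config", None)
    # Assumption: KV-sharing is advertised via a `num_kv_shared_layers` count on the config.
    has_kv_sharing = (getattr(config, "num_kv_shared_layers", 0) or 0) > 0
    if config is not None and not has_kv_sharing and hasattr(config, "use_cache"):
        # Safe only when no later layer reads K/V that an earlier layer put in the cache.
        config.use_cache = False
    return has_kv_sharing


# Usage with stand-in models (real ones would carry a transformers config).
dense = SimpleNamespace(config=SimpleNamespace(use_cache=True))
shared = SimpleNamespace(config=SimpleNamespace(use_cache=True, num_kv_shared_layers=10))
```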

nemo_automodel.components.distributed.activation_checkpointing.is_selective_activation_checkpointing(
activation_checkpointing: object
) -> bool

Return whether the config value selects selective activation checkpointing.

Parameters:

activation_checkpointing
object

The configured value (bool or string such as "selective"/"full").

Returns: bool

True only for the string "selective" (case- and
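A minimal sketch of the check: non-strings (including booleans) are never selective, and the string comparison is case-insensitive. Stripping surrounding whitespace is an assumption of this sketch, and `is_selective` is a hypothetical name.

```python
def is_selective(activation_checkpointing: object) -> bool:
    """True only for strings spelling "selective"; booleans and other values are False."""
    if not isinstance(activation_checkpointing, str):
        return False
    # Case-insensitive comparison; the .strip() is an assumption of this sketch.
    return activation_checkpointing.strip().lower() == "selective"
```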

nemo_automodel.components.distributed.activation_checkpointing.make_selective_checkpoint_context_fn()

Build a TorchTitan-style selective activation checkpointing context.

nemo_automodel.components.distributed.activation_checkpointing.SELECTIVE_AC_WRAPPER_FLAG = '_nemo_selective_ac'
nemo_automodel.components.distributed.activation_checkpointing._SELECTIVE_AC_MATMUL_OPS = _existing_ops(_resolve_torch_op('aten', 'mm'), _resolve_torch_op('aten', 'linear...
nemo_automodel.components.distributed.activation_checkpointing._SELECTIVE_AC_MUST_SAVE_OPS = _build_selective_ac_save_ops()
nemo_automodel.components.distributed.activation_checkpointing._SELECTIVE_AC_TO_COPY_OP = _resolve_torch_op('aten', '_to_copy')
nemo_automodel.components.distributed.activation_checkpointing._SELECTIVE_AC_TRACE = os.environ.get('NEMO_SELECTIVE_AC_TRACE', '0').lower() not in ('0', '', 'false',...
nemo_automodel.components.distributed.activation_checkpointing._SELECTIVE_AC_TRACE_SEEN: set[str] = set()
nemo_automodel.components.distributed.activation_checkpointing.logger = logging.getLogger(__name__)