# Activation Recompute

Stable docs: `docs/training/activation-recomputation.md`
Card: `card.yaml` (co-located)
## What It Is

Activation recompute trades GPU compute for memory by discarding intermediate activations during the forward pass and recomputing them during the backward pass. Megatron Bridge supports two granularities:

| Granularity | What you specify | What gets recomputed | Memory savings | Compute cost |
|---|---|---|---|---|
| `selective` | `recompute_modules` (list of submodules) | specific submodules within each layer | moderate (module-dependent) | low to high |
| `full` | `recompute_method` + `recompute_num_layers` | entire transformer layers (N layers) | strongest | highest |
Note: MCore names these “selective” (submodule-level) vs “full” (layer-level).
“Full” means recomputing full layers, not the full model — you still choose
how many layers via recompute_num_layers.
## Quick Decision

1. Set `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` first — most borderline OOMs are caused by memory fragmentation, not capacity. This fixes them at zero cost. See `skills/perf-techniques/memory-tuning/SKILL.md`.
2. Start with `recompute_granularity=selective`, `recompute_modules=[core_attn]` (often already the default in recipes).
3. Add `layernorm` to the recompute modules — nearly free compute-wise, but it saves negligible memory and only helps in extremely borderline cases.
4. Add `mlp` as a last resort — saves ~3 GB but costs ~16% GPU utilization on large dense models (Llama3 70B).
5. Use `recompute_granularity=full` only when selective recompute still does not fit.
CPU offloading (cpu_offloading=True) is an alternative that avoids recompute
cost entirely, but it is incompatible with PP > 1.
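The allocator setting from step 1 is an environment variable, set before the training process launches (a minimal sketch; the variable name comes from this doc, the launch command is elided):

```shell
# Zero-cost first step: avoid fragmentation-driven borderline OOMs
# before reaching for recompute. Must be set before the process starts.
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
```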
## Enablement

### Selective recompute (default for most recipes)

```python
cfg.model.recompute_granularity = "selective"
cfg.model.recompute_modules = ["core_attn"]
```
### Selective recompute with additional modules

```python
cfg.model.recompute_granularity = "selective"
cfg.model.recompute_modules = ["core_attn", "layernorm"]  # or ["mlp"] or ["mlp", "core_attn"]
```
### Full-layer recompute

```python
cfg.model.recompute_granularity = "full"
cfg.model.recompute_method = "uniform"
cfg.model.recompute_num_layers = 4
```
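As a mental model of the `uniform` method (a simplified sketch, not MCore's actual implementation): the block's layers are divided into equal groups of `recompute_num_layers`, and each group stores only its input and is recomputed as one unit during backward.

```python
def uniform_recompute_groups(num_layers: int, recompute_num_layers: int) -> list[list[int]]:
    """Sketch of the 'uniform' method: split layer indices into groups of
    recompute_num_layers; each group is checkpointed and recomputed as a unit."""
    return [
        list(range(start, min(start + recompute_num_layers, num_layers)))
        for start in range(0, num_layers, recompute_num_layers)
    ]

# e.g. a 16-layer block with recompute_num_layers=4 -> 4 checkpointed groups
print(uniform_recompute_groups(16, 4))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
```

Larger `recompute_num_layers` means fewer stored group inputs (less memory) but longer recompute chains during backward.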
## Available `recompute_modules`

| Module | What it recomputes | Compute cost | Memory savings |
|---|---|---|---|
| `core_attn` | attention softmax/dropout/QKV dot product | low (Flash Attention already recomputes internally) | moderate |
| `layernorm` | layer normalization | negligible (~0%) | negligible |
| `mlp` | full FFN block | high (~16% on Llama3 70B, hidden=28672) | ~3 GB |
| `moe` | MoE expert dispatch | varies | varies |
| `moe_act` | MoE activation functions | low | small |
|  | shared expert layers | moderate | moderate |
| `mla_up_proj` | Multi-Latent Attention up projection | moderate | moderate |
## Performance harness CLI

```bash
python scripts/performance/run_performance_workload.py \
    --recompute_granularity selective \
    --recompute_modules core_attn layernorm \
    ...
```
## Compatibility and Constraints

- `recompute_granularity=selective` requires a non-empty `recompute_modules` list.
- `recompute_granularity=full` requires `recompute_method` and `recompute_num_layers`.
- Layer-level recompute (`recompute_granularity="full"` + `recompute_num_layers`) is incompatible with TE-scoped CUDA graphs: any TE-scoped scope (`attn`, `mlp`, `moe_router`, etc.) will assert. This commonly hits FP8 configs that enable TE-scoped graphs by default (e.g. `LLAMA3_70B_SFT_CONFIG_H100_FP8_CS_V1` sets `cuda_graph_impl="transformer_engine"`, `cuda_graph_scope="mlp"`). Options:
  - use submodule recompute (`recompute_granularity="selective"` + `recompute_modules`) — compatible with TE-scoped graphs
  - disable CUDA graphs (`cuda_graph_impl="none"`) and use layer-level recompute
  - switch to `cuda_graph_impl="local"`, `cuda_graph_scope="full_iteration"`
- `distribute_saved_activations=True` cannot be combined with `sequence_parallel=True`.
- Combining `mlp` + `core_attn` recompute is slightly worse than `mlp` alone due to double recompute overhead.
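The constraints above can be encoded as a pre-flight check. This is a hypothetical helper, not part of Megatron Bridge or MCore; `cfg` is a plain dict stand-in for the model config:

```python
def validate_recompute_config(cfg: dict) -> None:
    """Hypothetical pre-flight check mirroring the constraints listed above."""
    gran = cfg.get("recompute_granularity")
    if gran == "selective" and not cfg.get("recompute_modules"):
        raise ValueError("selective recompute requires a non-empty recompute_modules list")
    if gran == "full" and not (cfg.get("recompute_method") and cfg.get("recompute_num_layers")):
        raise ValueError("full recompute requires recompute_method and recompute_num_layers")
    if cfg.get("distribute_saved_activations") and cfg.get("sequence_parallel"):
        raise ValueError("distribute_saved_activations cannot be combined with sequence_parallel")
    if cfg.get("cpu_offloading") and cfg.get("pipeline_model_parallel_size", 1) > 1:
        raise ValueError("CPU offloading is incompatible with PP > 1")

# A valid selective config passes silently:
validate_recompute_config({"recompute_granularity": "selective",
                           "recompute_modules": ["core_attn"]})
```

Running such a check before launch surfaces config errors in seconds instead of minutes into job startup.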
## Measured Results

Llama3 70B SFT on 32x H100 80GB, FP8 (Current Scaling):

- Baseline: TP=4, PP=4, VPP=5, DP=2, MBS=1, GBS=32, seq_len=4096
- Golden GPU utilization: 709.93 TFLOP/s/GPU
- Regression threshold: 5%

| Experiment | recompute_modules | TFLOP/s/GPU | vs Golden | Peak Mem (GB) | Result |
|---|---|---|---|---|---|
| Baseline | `[core_attn]` | ~704 | -0.8% | 58.8 (OOM rank0) | OOM |
| Exp 1 | `[mlp]` | 593.6 | -16.4% | 55.6 | Perf regression |
| Exp 2 | `[mlp, core_attn]` | 586.8 | -17.3% | 55.6 | Perf regression |
| Exp 3 | `[core_attn, layernorm]` | ~702 | -1.1% | 59.6 (OOM rank0) | OOM |
Key takeaways:

- `layernorm` recompute is nearly free compute-wise but saves negligible memory.
- `mlp` recompute saves ~3 GB peak but costs ~16% because the Llama3 70B FFN (hidden=28672) is expensive to recompute.
- Combining `mlp` + `core_attn` is slightly worse than `mlp` alone.
- For this workload, the actual OOM fix was `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` (memory fragmentation, not capacity). See `skills/perf-techniques/memory-tuning/SKILL.md`.
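The "vs Golden" percentages in the table are plain ratios against the 709.93 TFLOP/s/GPU golden number:

```python
golden = 709.93  # golden GPU utilization from the baseline above

for label, tflops in [("mlp", 593.6), ("mlp+core_attn", 586.8)]:
    delta = (tflops / golden - 1) * 100  # percent change vs golden
    print(f"{label}: {delta:.1f}%")
# mlp: -16.4%
# mlp+core_attn: -17.3%
```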
## Code Anchors

### Recompute modules enum and selective checkpoint logic

```python
# 3rdparty/Megatron-LM/megatron/core/transformer/transformer_block.py
# _checkpointed_forward() applies selective recompute based on recompute_modules
```

### Recompute config validation

```python
# 3rdparty/Megatron-LM/megatron/core/transformer/transformer_config.py
# Validates recompute_granularity, recompute_method, recompute_num_layers
```

### Llama3 recipe defaults

```python
# Memory saving (recompute & offloading)
cfg.model.recompute_granularity = None
cfg.model.recompute_modules = None
cfg.model.fine_grained_activation_offloading = False
cfg.model.offload_modules = None
```

### Full recompute + CUDA graph assertion (MCore)

```python
if self.recompute_granularity:
    if self.recompute_granularity != "selective":
        assert self.cuda_graph_scope == [
            CudaGraphScope.full_iteration
        ], "full recompute is only supported with full iteration CUDA graph."
```

### CPU offloading PP incompatibility (MCore)

```python
if self.cpu_offloading and self.pipeline_model_parallel_size > 1:
    raise ValueError(
        "Currently there is no support for Pipeline parallelism with CPU offloading"
    )
```
## Failure Diagnosis

| Symptom | Cause | Confirm | Fix |
|---|---|---|---|
| >15% GPU utilization drop | `mlp` recompute on large FFN | compare TFLOP/s/GPU against the golden baseline | drop `mlp` from `recompute_modules`; fix OOM via `expandable_segments` instead |
| Still OOM after adding `layernorm` | `layernorm` activations are too small | compare peak memory before/after | add `mlp` recompute or check `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` |
| Assertion: "full recompute is only supported with full iteration CUDA graph" | layer-level recompute (`recompute_granularity="full"`) with TE-scoped CUDA graphs | check `cuda_graph_impl` / `cuda_graph_scope` | use submodule recompute (`recompute_granularity="selective"` + `recompute_modules`) |
| ValueError: PP + CPU offloading | `cpu_offloading=True` with PP > 1 | check PP config | disable CPU offloading or set PP=1 |
| `mlp`+`core_attn` worse than `mlp` alone | double recompute overhead | compare Exp 1 vs Exp 2 | use `mlp` alone |
## Known Limitations

- Per-module memory savings vary significantly by model architecture and hidden dimension.
- No automatic module selection — users must choose which modules to recompute.
- `layernorm` recompute is almost never worth it as a standalone fix.
- CPU offloading (the zero-compute-cost alternative) is blocked when PP > 1.
## Verification

```bash
uv run python -m pytest \
    tests/unit_tests/training/test_config.py -k "recompute" -q
```

Success criteria:

- Unit tests pass for recompute config validation
- No assertion errors from config validation