bridge.peft.recompute#

Helpers for PEFT-specific activation recompute fixes.

Module Contents#

Functions#

_iter_unwrapped_models

Yield unwrapped Megatron modules regardless of list/list-like inputs.

maybe_enable_recompute_inputs_grad

Enable grad on TransformerBlock inputs when only adapters are trainable.

Data#

API#

bridge.peft.recompute.PEFT_RECOMPUTE_PATCHED: Set[int]#

'set(…)'

bridge.peft.recompute._iter_unwrapped_models(model) → Iterable[torch.nn.Module]#

Yield unwrapped Megatron modules regardless of list/list-like inputs.
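The helper's behavior can be sketched as follows (an illustrative reimplementation, not the actual source: it assumes wrappers such as DDP or Float16Module expose the inner model via a `.module` attribute, and the real unwrapping logic may differ):

```python
import torch

def iter_unwrapped_models_sketch(model):
    # Accept a single module or a list/tuple of (possibly wrapped) model chunks.
    chunks = model if isinstance(model, (list, tuple)) else [model]
    for chunk in chunks:
        # Peel wrapper layers (e.g. DDP or Float16Module) that expose the
        # wrapped model via a ``.module`` attribute.
        while hasattr(chunk, "module"):
            chunk = chunk.module
        yield chunk
```

This lets callers pass either a single model or the list of virtual-pipeline model chunks that Megatron produces, and always receive the bare modules.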

bridge.peft.recompute.maybe_enable_recompute_inputs_grad(
model,
peft_recompute_patched: Set[int] | None = None,
) → Set[int]#

Enable grad on TransformerBlock inputs when only adapters are trainable.

Root cause analysis:

  • Megatron’s CheckpointFunction.backward() is only invoked by PyTorch autograd when at least one input tensor requires grad.

  • With PP>1, tensors received from earlier pipeline stages have requires_grad=True, so the checkpoint backward is always called.

  • With PP=1 and a frozen base model, the embedding outputs have requires_grad=False, so CheckpointFunction.backward() is never called and the LoRA gradients inside the checkpointed region are never computed.
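The PP=1 failure mode can be reproduced with plain PyTorch reentrant checkpointing (a minimal sketch; the Linear stands in for a LoRA adapter inside an otherwise frozen block):

```python
import torch
from torch.utils.checkpoint import checkpoint

lora = torch.nn.Linear(4, 4)   # stands in for a trainable adapter
x = torch.randn(2, 4)          # requires_grad=False, like a frozen
                               # embedding's output when PP=1

out = checkpoint(lora, x, use_reentrant=True)
# No input requires grad, so no grad_fn is attached to the output:
# CheckpointFunction.backward can never run and the adapter's weight
# would receive no gradient.
print(out.requires_grad)       # False

x.requires_grad_(True)         # the fix this module applies
out = checkpoint(lora, x, use_reentrant=True)
out.sum().backward()
print(lora.weight.grad is not None)  # True: backward now recomputes the
                                     # forward and reaches the adapter
```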

Solution: hook TransformerBlock.forward to ensure hidden_states.requires_grad=True before it enters the checkpointed computation. This doesn’t unfreeze any parameters; it only ensures that the autograd machinery invokes the checkpoint’s backward pass.
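The patch amounts to a forward pre-hook along these lines (an illustrative sketch, not the actual source; the real code hooks TransformerBlock.forward and likely records module ids in PEFT_RECOMPUTE_PATCHED so the patch is applied only once per block):

```python
import torch

def require_inputs_grad_pre_hook(module, args):
    # Forward pre-hook: make the incoming hidden_states participate in
    # autograd so the checkpointed backward is actually invoked.
    hidden_states = args[0]
    if (
        torch.is_tensor(hidden_states)
        and torch.is_floating_point(hidden_states)
        and not hidden_states.requires_grad
    ):
        # Safe: a tensor with requires_grad=False has no grad_fn, so the
        # flag can be flipped in place. No parameter is unfrozen here.
        hidden_states.requires_grad_(True)
    return (hidden_states, *args[1:])
```

A block would then be patched with something like `block.register_forward_pre_hook(require_inputs_grad_pre_hook)`; integer inputs (e.g. token ids) are left untouched by the floating-point check.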

Borrowed (with modifications) from https://github.com/HollowMan6/verl/blob/4285f0601028aee7ddcb9ec5a15198ebfc69bba3/verl/utils/megatron_peft_utils.py

bridge.peft.recompute.__all__#

['maybe_enable_recompute_inputs_grad', 'PEFT_RECOMPUTE_PATCHED']