bridge.peft.recompute#
Helpers for PEFT-specific activation recompute fixes.
Module Contents#
Functions#
| Function | Description |
|---|---|
| `_iter_unwrapped_models` | Yield unwrapped Megatron modules regardless of list/list-like inputs. |
| `maybe_enable_recompute_inputs_grad` | Enable grad on TransformerBlock inputs when only adapters are trainable. |
Data#
API#
- bridge.peft.recompute.PEFT_RECOMPUTE_PATCHED: Set[int]#
`set(...)`
- bridge.peft.recompute._iter_unwrapped_models(model) → Iterable[torch.nn.Module]#
Yield unwrapped Megatron modules regardless of list/list-like inputs.
- bridge.peft.recompute.maybe_enable_recompute_inputs_grad(model, peft_recompute_patched: Set[int] | None = None)#
Enable grad on TransformerBlock inputs when only adapters are trainable.
Root cause analysis:
- Megatron's `CheckpointFunction.backward()` is only invoked by PyTorch autograd when at least one input tensor requires grad.
- With PP>1, tensors received from other pipeline stages have `requires_grad=True`, so the checkpoint backward is always called.
- With PP=1 and a frozen base model, embedding outputs have `requires_grad=False`. As a result, `CheckpointFunction.backward()` is never called, and LoRA gradients inside the checkpointed region are never computed.
Solution: Hook TransformerBlock.forward to ensure hidden_states.requires_grad=True before it enters checkpointed computation. This doesn’t unfreeze any parameters; it just ensures the autograd machinery calls checkpoint’s backward.
Borrowed (with modifications) from https://github.com/HollowMan6/verl/blob/4285f0601028aee7ddcb9ec5a15198ebfc69bba3/verl/utils/megatron_peft_utils.py
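The failure mode and the fix can be reproduced with plain `torch.utils.checkpoint`, independent of Megatron. Below is a hedged, self-contained sketch: the frozen linear layer and the `adapter` parameter are illustrative stand-ins for the frozen base model and a LoRA adapter, not the real `TransformerBlock`:

```python
import torch
from torch.utils.checkpoint import checkpoint

# Frozen linear layer standing in for the frozen base model; a trainable
# "adapter" parameter lives inside the checkpointed function.
frozen = torch.nn.Linear(4, 4)
for p in frozen.parameters():
    p.requires_grad_(False)
adapter = torch.nn.Parameter(torch.zeros(4))


def block(hidden):
    # Only `adapter` is trainable, mimicking a PEFT setup.
    return frozen(hidden) + adapter


# Without the fix: the input does not require grad (like the output of a
# frozen embedding with PP=1), so no backward graph is built through the
# checkpoint and the adapter would never receive gradients.
x = torch.randn(2, 4)
out = checkpoint(block, x, use_reentrant=True)
assert out.requires_grad is False  # checkpoint backward will never run

# With the fix: force requires_grad on the block input before it enters
# the checkpointed computation. No parameter is unfrozen; autograd is
# merely given a reason to call the checkpoint's backward.
x2 = torch.randn(2, 4)
if x2.is_floating_point() and not x2.requires_grad:
    x2.requires_grad_(True)
out2 = checkpoint(block, x2, use_reentrant=True)
out2.sum().backward()
assert adapter.grad is not None   # adapter gradient now flows
assert frozen.weight.grad is None  # base model stays frozen
```

This mirrors the patch's stated guarantee: base-model parameters keep `requires_grad=False`; only the block's input tensor is marked as requiring grad so that `CheckpointFunction.backward()` fires.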
- bridge.peft.recompute.__all__#
['maybe_enable_recompute_inputs_grad', 'PEFT_RECOMPUTE_PATCHED']