bridge.peft.lora_layers#
Module Contents#
Classes#
- LoRALinear – An adapter wrapper that adds the output of the adapter to the output of the wrapped module.
- TELinearAdapter – TELinear + LoRA; maintains the checkpoint structure (i.e. Linear’s weight/bias remain at the same FQN).
- LinearAdapter – Linear + LoRA; maintains the checkpoint structure (i.e. Linear’s weight/bias remain at the same FQN).
Functions#
- patch_linear_module – Monkey-patch a nn.Linear or te.Linear to be a LinearAdapter.
API#
- class bridge.peft.lora_layers.LoRALinear#
Bases:
megatron.bridge.peft.adapter_wrapper.AdapterWrapper
An adapter wrapper that adds the output of the adapter to the output of the wrapped module.
This class is designed to be used with LoRA (Low-Rank Adaptation) and similar techniques where the adapter’s output is added to the main module’s output. It extends the AdapterWrapper class to provide a specific implementation of the forward method.
- forward(
- x: torch.Tensor,
- *args: Any,
- **kwargs: Any,
- )#
Forward pass that combines the wrapped module output with the adapter output.
- Parameters:
x – Input tensor.
*args – Additional positional arguments for the wrapped module.
**kwargs – Additional keyword arguments for the wrapped module.
- Returns:
A tuple containing:
- Combined output (linear_output + adapter_output)
- Bias term (if present, otherwise None)
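For orientation, here is a minimal sketch of the computation that forward performs, written with plain torch modules as stand-ins. The rank-8 shapes, the zero-initialized up-projection, and the alpha/dim scaling follow the usual LoRA recipe and are illustrative assumptions, not the class’s documented internals.

```python
import torch
import torch.nn as nn

# Stand-ins for the wrapped module and its adapter.
linear = nn.Linear(16, 32)
lora_A = nn.Linear(16, 8, bias=False)   # down-projection to rank r = 8
lora_B = nn.Linear(8, 32, bias=False)   # up-projection, starts at zero
nn.init.zeros_(lora_B.weight)           # so the adapter is a no-op at init

x = torch.randn(4, 16)
linear_output = linear(x)
adapter_output = lora_B(lora_A(x)) * (32 / 8)      # alpha / dim scaling
combined = linear_output + adapter_output          # first element of the returned tuple
bias = linear.bias                                 # second element (or None)
```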
- class bridge.peft.lora_layers.TELinearAdapter(
- orig_linear: transformer_engine.pytorch.Linear,
- dim: int = 8,
- alpha: int = 32,
- dropout: float = 0.0,
- dropout_position: Literal['pre', 'post'] = 'post',
- lora_A_init_method: Literal['xavier', 'uniform'] = 'xavier',
- lora_dtype: Optional[torch.dtype] = None,
- )#
Bases:
transformer_engine.pytorch.Linear
TELinear + LoRA; maintains the checkpoint structure (i.e. Linear’s weight/bias remain at the same FQN).
The _init_adapter and forward methods provide the LoRA functionality. They are decorated with @staticmethod so that the same code can be used both inside LinearAdapter and for monkey-patching modules, without duplication.
- Parameters:
orig_linear – The linear module to augment.
dim – LoRA’s dimension (in_features -> dim -> out_features).
alpha – LoRA’s scaling alpha.
dropout – Dropout probability (default: 0.0).
dropout_position – Where to apply dropout relative to LoRA (choices: [‘pre’, ‘post’], default=’post’).
lora_A_init_method – Initialization method for lora_A (choices: [‘xavier’, ‘uniform’]).
lora_dtype – dtype of the LoRA weights. Defaults to orig_linear’s dtype, but must be specified explicitly if orig_linear uses quantized weights (e.g. 4-bit).
Initialization
Initialize TELinearAdapter by copying from original TELinear and adding LoRA components.
- Parameters:
orig_linear – The original TELinear module to adapt.
dim – LoRA rank dimension.
alpha – LoRA scaling factor.
dropout – Dropout probability.
dropout_position – When to apply dropout (‘pre’ or ‘post’ LoRA computation).
lora_A_init_method – Initialization method for LoRA matrix A.
lora_dtype – Data type for LoRA weights.
- static _init_adapter(
- obj: Union[bridge.peft.lora_layers.TELinearAdapter, torch.nn.Module],
- dim: int = 8,
- alpha: int = 32,
- dropout: float = 0.0,
- dropout_position: Literal['pre', 'post'] = 'post',
- lora_A_init_method: Literal['xavier', 'uniform'] = 'xavier',
- lora_dtype: Optional[torch.dtype] = None,
- )#
Add LoRA weights to obj. The obj is either a LinearAdapter or an nn.Module (when monkey-patching).
- Parameters:
obj – Input module to adapt (LinearAdapter or nn.Module).
dim – LoRA’s dimension (in_features -> dim -> out_features).
alpha – LoRA’s scaling alpha.
dropout – Dropout probability (default: 0.0).
dropout_position – Where to apply dropout relative to LoRA (choices: [‘pre’, ‘post’], default=’post’).
lora_A_init_method – Initialization method for lora_A (choices: [‘xavier’, ‘uniform’]).
lora_dtype – dtype of the LoRA weights. Defaults to orig_linear’s dtype, but must be specified explicitly if orig_linear uses quantized weights (e.g. 4-bit).
- forward(x: torch.Tensor) → torch.Tensor #
Forward pass combining TELinear output with LoRA adaptation.
- Parameters:
x – Input tensor.
- Returns:
Combined output from original linear layer and LoRA adaptation.
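A usage sketch for TELinearAdapter, assuming Transformer Engine is installed and a CUDA device is available; the layer sizes and the dim/alpha values are illustrative.

```python
import torch
import transformer_engine.pytorch as te

from megatron.bridge.peft.lora_layers import TELinearAdapter

base = te.Linear(1024, 1024, bias=True)            # parameters live on the CUDA device
adapted = TELinearAdapter(base, dim=16, alpha=32)  # base weight/bias keep their FQNs

x = torch.randn(8, 1024, device="cuda")
y = adapted(x)   # TELinear output plus the scaled LoRA contribution
```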
- class bridge.peft.lora_layers.LinearAdapter(
- orig_linear: torch.nn.Linear,
- dim: int = 8,
- alpha: int = 32,
- dropout: float = 0.0,
- dropout_position: Literal['pre', 'post'] = 'post',
- lora_A_init_method: Literal['xavier', 'uniform'] = 'xavier',
- lora_dtype: Optional[torch.dtype] = None,
- )#
Bases:
torch.nn.Linear
Linear + LoRA; maintains the checkpoint structure (i.e. Linear’s weight/bias remain at the same FQN).
The _init_adapter and forward methods provide the LoRA functionality. They are decorated with @staticmethod so that the same code can be used both inside LinearAdapter and for monkey-patching modules, without duplication.
- Parameters:
orig_linear – The linear module to augment.
dim – LoRA’s dimension (in_features -> dim -> out_features).
alpha – LoRA’s scaling alpha.
dropout – Dropout probability (default: 0.0).
dropout_position – Where to apply dropout relative to LoRA (choices: [‘pre’, ‘post’], default=’post’).
lora_A_init_method – Initialization method for lora_A (choices: [‘xavier’, ‘uniform’]).
lora_dtype – dtype of the LoRA weights. Defaults to orig_linear’s dtype, but must be specified explicitly if orig_linear uses quantized weights (e.g. 4-bit).
Initialization
Initialize LinearAdapter by copying from original Linear and adding LoRA components.
- Parameters:
orig_linear – The original Linear module to adapt.
dim – LoRA rank dimension.
alpha – LoRA scaling factor.
dropout – Dropout probability.
dropout_position – When to apply dropout (‘pre’ or ‘post’ LoRA computation).
lora_A_init_method – Initialization method for LoRA matrix A.
lora_dtype – Data type for LoRA weights.
- static _init_adapter(
- obj: Union[bridge.peft.lora_layers.LinearAdapter, torch.nn.Module],
- dim: int = 8,
- alpha: int = 32,
- dropout: float = 0.0,
- dropout_position: Literal['pre', 'post'] = 'post',
- lora_A_init_method: Literal['xavier', 'uniform'] = 'xavier',
- lora_dtype: Optional[torch.dtype] = None,
- )#
Add LoRA weights to obj. The obj is either a LinearAdapter or an nn.Module (when monkey-patching).
- Parameters:
obj – Input module to adapt (LinearAdapter or nn.Module).
dim – LoRA’s dimension (in_features -> dim -> out_features).
alpha – LoRA’s scaling alpha.
dropout – Dropout probability (default: 0.0).
dropout_position – Where to apply dropout relative to LoRA (choices: [‘pre’, ‘post’], default=’post’).
lora_A_init_method – Initialization method for lora_A (choices: [‘xavier’, ‘uniform’]).
lora_dtype – dtype of the LoRA weights. Defaults to orig_linear’s dtype, but must be specified explicitly if orig_linear uses quantized weights (e.g. 4-bit).
- forward(x: torch.Tensor) → torch.Tensor #
Forward pass combining Linear output with LoRA adaptation.
- Parameters:
x – Input tensor.
- Returns:
Combined output from original linear layer and LoRA adaptation.
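A minimal usage sketch for LinearAdapter, based on the constructor documented above; the layer sizes and hyperparameters are illustrative.

```python
import torch
import torch.nn as nn

from megatron.bridge.peft.lora_layers import LinearAdapter

base = nn.Linear(512, 512)
adapted = LinearAdapter(base, dim=8, alpha=32, dropout=0.1, dropout_position="post")

x = torch.randn(4, 512)
y = adapted(x)                  # Linear output plus the scaled LoRA contribution
assert y.shape == (4, 512)

# weight/bias keep the same FQNs as in the original Linear, so a state_dict
# saved for `base` still maps onto `adapted`.
```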
- bridge.peft.lora_layers.patch_linear_module(
- orig_linear: Union[torch.nn.Linear, transformer_engine.pytorch.Linear],
- dim: int = 8,
- alpha: int = 32,
- dropout: float = 0.0,
- dropout_position: Literal['pre', 'post'] = 'post',
- lora_A_init_method: Literal['xavier', 'uniform'] = 'xavier',
- lora_dtype: Optional[torch.dtype] = None,
- )#
Monkey-patch a nn.Linear or te.Linear to be a LinearAdapter.
This function replaces a nn.Linear with a LinearAdapter without copying weights, making it suitable for cases where the original module was initialized with meta device.
orig_linear might not contain valid weights, for example when it was initialized inside a context manager that uses the “meta” device. In that case the weight/bias cannot be copied from orig_linear to the LinearAdapter, because they have not been allocated.
To handle this, LinearAdapter’s additional functionality (_init_adapter, _forward) is implemented as static functions, so the same code can be used both when patching an existing module and when allocating a new LinearAdapter object.
- Parameters:
orig_linear – The module we add adapter to.
dim – LoRA dimension. Defaults to 8.
alpha – LoRA alpha scale. Defaults to 32.
dropout – Dropout probability. Defaults to 0.0.
dropout_position – Where to apply dropout relative to LoRA. Defaults to ‘post’ (choices: ‘pre’, ‘post’).
lora_A_init_method – LoRA_A initialization method. Defaults to ‘xavier’.
lora_dtype – dtype of the LoRA weights. By default uses orig_linear’s dtype, but orig_linear may use a non-trainable dtype (e.g. 4-bit), in which case the dtype must be specified explicitly. Defaults to None.
- Returns:
The monkey-patched (nn.Linear + LoRA) nn.Module.
- Raises:
NotImplementedError – If orig_linear is not nn.Linear or te.Linear.
AssertionError – If orig_linear already has super_fwd attribute.
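A usage sketch of the meta-device scenario described above; the layer size, the dim/alpha values, and the explicit lora_dtype are illustrative choices.

```python
import torch
import torch.nn as nn

from megatron.bridge.peft.lora_layers import patch_linear_module

# Build the layer on the meta device: its weight/bias are never materialized,
# so they cannot be copied into a fresh LinearAdapter.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)

# Patch in place: the module gains LoRA parameters and a LoRA-aware forward,
# while the original weight/bias keep their fully qualified names.
layer = patch_linear_module(layer, dim=16, alpha=32, lora_dtype=torch.bfloat16)
```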