Source code for nemo_automodel.components._peft.lora

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass, field
from typing import Any, Literal, Optional

import torch
import torch.nn.functional as F
from torch import nn

from nemo_automodel.components._peft.lora_kernel import (
    lora_da_dx_update_wrapper,
    lora_db_update_wrapper,
    lora_forward_wrapper,
)
from nemo_automodel.components._peft.module_matcher import ModuleMatcher
from nemo_automodel.shared.import_utils import safe_import
from nemo_automodel.shared.utils import dtype_from_str

HAS_BNB, bitsandbytes = safe_import("bitsandbytes")


@dataclass
class PeftConfig:
    """Configuration for applying LoRA adapters to a model's linear layers."""

    target_modules: list = field(default_factory=list)
    exclude_modules: list = field(default_factory=list)
    match_all_linear: bool = False
    dim: int = 8
    alpha: int = 32
    dropout: float = 0.0
    dropout_position: Literal["pre", "post"] = "post"
    lora_A_init: str = "xavier"
    lora_dtype: Optional[torch.dtype] = None
    use_triton: bool = False
    def to_dict(self):
        """Return a shallow copy of the config as a plain dict."""
        return self.__dict__.copy()
    @classmethod
    def from_dict(cls, d: dict[str, Any]):
        """Build a PeftConfig from a plain dict, using the dataclass defaults for missing keys."""
        return cls(
            target_modules=d.get("target_modules", []),
            exclude_modules=d.get("exclude_modules", []),
            match_all_linear=d.get("match_all_linear", False),
            dim=d.get("dim", 8),
            alpha=d.get("alpha", 32),
            dropout=d.get("dropout", 0.0),
            dropout_position=d.get("dropout_position", "post"),
            lora_A_init=d.get("lora_A_init", "xavier"),
            lora_dtype=d.get("lora_dtype", None),
            use_triton=d.get("use_triton", False),
        )
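
The snippet below is an illustrative sketch, not part of the module: it shows one way a PeftConfig could be built from a plain dict (for instance, one parsed from a recipe file) and round-tripped through to_dict. All values shown are arbitrary.

# Illustrative sketch only: construct a PeftConfig from a plain dict and round-trip it.
cfg = PeftConfig.from_dict(
    {
        "target_modules": ["q_proj", "v_proj"],
        "dim": 16,
        "alpha": 32,
        "dropout": 0.05,
        "use_triton": False,
    }
)
assert cfg.dim == 16 and cfg.dropout_position == "post"  # unspecified keys fall back to defaults
print(cfg.to_dict())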

class LinearLoRA(nn.Linear):
    """
    Linear + LoRA, maintains ckpts structure (i.e. Linear's weight/bias remain at the same FQN).

    The _init_adapter and forward methods provide the LoRA functionality. We want to be able to use
    them inside LinearLoRA but also for monkey-patching modules, without repeating the same code;
    therefore _init_adapter is decorated with @staticmethod.
    """

    def __init__(
        self,
        orig_linear,
        dim=8,
        alpha=32,
        dropout=0.0,
        dropout_position="post",
        lora_A_init_method="xavier",
        lora_dtype=None,
    ):
        """
        LinearLoRA constructor.

        Args:
            orig_linear (nn.Module): the linear module to augment.
            dim (int): LoRA's dim, in_features -> dim -> out_features.
            alpha (int): LoRA's scaling alpha.
            dropout (float): dropout probability (default: 0.0).
            dropout_position (str): where to apply dropout relative to LoRA
                (choices: ['pre', 'post'], default: 'post').
            lora_A_init_method (str): init method for lora_A (choices: ['xavier', 'uniform']).
            lora_dtype (torch.dtype): weights' dtype; by default uses orig_linear's, but if those are
                quantized weights (e.g. 4bit) it needs to be specified explicitly.
        """
        assert isinstance(orig_linear, nn.Linear)
        super(LinearLoRA, self).__init__(
            in_features=orig_linear.in_features,
            out_features=orig_linear.out_features,
            bias=orig_linear.bias is not None,
            device=orig_linear.weight.device,
            dtype=orig_linear.weight.dtype,
        )
        # copy weights
        self.weight.data.copy_(orig_linear.weight.data)
        if orig_linear.bias is not None:
            self.bias.data.copy_(orig_linear.bias.data)
        # initialize the adapter
        LinearLoRA._init_adapter(
            self,
            dim=dim,
            alpha=alpha,
            dropout=dropout,
            dropout_position=dropout_position,
            lora_A_init_method=lora_A_init_method,
            lora_dtype=lora_dtype,
        )
    @torch.no_grad
    def init_lora_weights(self, init_method: str):
        """
        Initialize the LoRA weights.

        Args:
            init_method (str): Method to initialize the LoRA weights.
        """
        if init_method == "xavier":
            torch.nn.init.uniform_(self.lora_A.weight.data)
        else:
            nn.init.kaiming_uniform_(self.lora_A.weight.data, a=math.sqrt(5))

        self.lora_B.weight.data.fill_(0)
    @torch.no_grad
    @staticmethod
    def _init_adapter(
        obj,
        dim=8,
        alpha=32,
        dropout=0.0,
        dropout_position="post",
        lora_A_init_method="xavier",
        lora_dtype=None,
    ):
        """
        Adds LoRA weights to obj. obj is either a LinearLoRA or an nn.Module (when monkey-patching).

        Args:
            obj (LinearLoRA | nn.Module): input module to adapt.
            dim (int): LoRA's dim, in_features -> dim -> out_features.
            alpha (int): LoRA's scaling alpha.
            dropout (float): dropout probability (default: 0.0).
            dropout_position (str): where to apply dropout relative to LoRA
                (choices: ['pre', 'post'], default: 'post').
            lora_A_init_method (str): init method for lora_A (choices: ['xavier', 'uniform']).
            lora_dtype (torch.dtype): weights' dtype; by default uses orig_linear's, but if those are
                quantized weights (e.g. 4bit) it needs to be specified explicitly.
        """
        obj.dim = dim
        obj.scale = alpha / dim

        # Freeze the original weights; only the LoRA adapter remains trainable.
        device = obj.weight.device
        obj.weight.requires_grad = False
        if obj.bias is not None:
            obj.bias.requires_grad = False

        in_features = obj.in_features
        out_features = obj.out_features
        if isinstance(lora_dtype, str):
            lora_dtype = dtype_from_str(lora_dtype)
        assert lora_dtype is None or isinstance(lora_dtype, torch.dtype)
        dtype = lora_dtype or obj.weight.dtype

        obj.lora_A = nn.Linear(in_features, dim, bias=False, dtype=dtype, device=device)
        obj.lora_B = nn.Linear(dim, out_features, bias=False, dtype=dtype, device=device)
        LinearLoRA.init_lora_weights(obj, lora_A_init_method)

        obj.dropout = nn.Dropout(p=dropout)
        assert dropout_position in ["pre", "post"], ("dropout position can only be pre/post", dropout_position)
        obj.dropout_position = dropout_position
    def forward(self, x):
        """
        Forward pass through the original linear layer augmented with the LoRA pathway.

        Applies LoRA either before or after the dropout, depending on the configuration.
        The result of the original linear transformation is combined with the LoRA output.

        Args:
            x (Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            Tensor: Output tensor of shape (batch_size, out_features).
        """
        # pylint: disable=C0115,C0116
        # If LinearLoRA is used to monkey-patch a nn.Linear module, we want to use nn.Linear's
        # forward in the case where it uses quantized weights. We store a reference to nn.Linear's
        # forward in the `super_fwd` attribute. If the attribute does not exist, we do the usual linear.
        if (fwd := getattr(self, "super_fwd", None)) is not None:
            assert fwd != self.forward
            res = fwd(x)
        else:
            res = F.linear(x, self.weight, self.bias)

        if self.dropout_position == "pre":
            x = self.dropout(x)
        lora_res = self.lora_B(self.lora_A(x))
        lora_res = lora_res * self.scale
        if self.dropout_position == "post":
            lora_res = self.dropout(lora_res)
        return res + lora_res
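
As a quick illustration (not part of the module), the sketch below wraps a small CPU nn.Linear in a LinearLoRA and checks that only the adapter weights remain trainable; the layer size and LoRA hyperparameters are arbitrary.

# Illustrative sketch only.
base = nn.Linear(64, 64)
lora_layer = LinearLoRA(base, dim=8, alpha=32, dropout=0.0)

# Only the adapter parameters require gradients after _init_adapter freezes the base weights.
trainable = {n for n, p in lora_layer.named_parameters() if p.requires_grad}
assert trainable == {"lora_A.weight", "lora_B.weight"}

# Since lora_B is zero-initialized, the initial output matches the frozen base layer.
x = torch.randn(2, 64)
assert torch.allclose(lora_layer(x), F.linear(x, lora_layer.weight, lora_layer.bias))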

class TritonLinearLoRA(LinearLoRA):
    """
    Subclass of LinearLoRA that uses Triton kernels for the forward and backward passes.

    Args:
        orig_linear (nn.Module): the linear module to augment.
        dim (int): LoRA's dim, in_features -> dim -> out_features.
        alpha (int): LoRA's scaling alpha.
        dropout (float): dropout probability (default: 0.0).
        dropout_position (str): where to apply dropout relative to LoRA
            (choices: ['pre', 'post'], default: 'post').
        lora_A_init_method (str): init method for lora_A (choices: ['xavier', 'uniform']).
        lora_dtype (torch.dtype): weights' dtype; by default uses orig_linear's, but if those are
            quantized weights (e.g. 4bit) it needs to be specified explicitly.
    """
    def forward(self, x):
        """
        Forward function for LoRA with Triton kernels.

        Args:
            x (torch.Tensor): the input tensor.

        Returns:
            torch.Tensor: the output tensor.
        """
        # If LinearLoRA is used to monkey-patch a nn.Linear module, we want to use nn.Linear's
        # forward in the case where it uses quantized weights. We store a reference to nn.Linear's
        # forward in the `super_fwd` attribute. If the attribute does not exist, we do the usual linear.
        if (fwd := getattr(self, "super_fwd", None)) is not None:
            assert fwd != self.forward
            res = fwd(x)
        else:
            res = F.linear(x, self.weight, self.bias)

        if self.dropout_position == "pre":
            x = self.dropout(x)
        lora_res = LoRATritonFunction.apply(x, self.lora_A.weight, self.lora_B.weight, self.scale, x.dtype)
        if self.dropout_position == "post":
            lora_res = self.dropout(lora_res)
        return res + lora_res

def patch_linear_module(
    orig_linear,
    dim=8,
    alpha=32,
    dropout=0.0,
    dropout_position="post",
    lora_A_init_method="xavier",
    lora_dtype=None,
    use_triton=True,
):
    """
    Monkey-patches an nn.Linear (orig_linear param) to be a LinearLoRA.

    The orig_linear might not contain valid weights, for example, if it was initialized within a
    context manager that uses a "meta" device. In that case we cannot copy the weight/bias from
    orig_linear to the LinearLoRA, since those have not been allocated. To circumvent this scenario,
    LinearLoRA's additional functionality (_init_adapter, forward) is based on static functions, so
    that we can use them for patching or when allocating a new LinearLoRA object.

    Args:
        orig_linear (nn.Linear): the module we add the adapter to.
        dim (int, optional): LoRA dim. Defaults to 8.
        alpha (int, optional): LoRA alpha scale. Defaults to 32.
        dropout (float, optional): dropout probability. Defaults to 0.0.
        dropout_position (str, optional): location to apply dropout w.r.t. LoRA.
            Defaults to 'post' (choices: 'pre', 'post').
        lora_A_init_method (str, optional): lora_A init method. Defaults to 'xavier'.
        lora_dtype (torch.dtype, optional): LoRA weights' dtype. By default will use orig_linear's
            dtype, but orig_linear might use a non-trainable dtype (e.g., 4bit), in which case the
            user must specify the dtype manually. Defaults to None.
        use_triton (bool, optional): if True (default), use the Triton kernel LoRA implementation.

    Returns:
        (nn.Module): the monkey-patched (nn.Linear + LoRA) nn.Module
    """
    assert isinstance(orig_linear, nn.Linear), type(orig_linear)
    assert not hasattr(orig_linear, "super_fwd"), orig_linear.super_fwd

    if isinstance(orig_linear, nn.Linear):
        linear_lora_cls = TritonLinearLoRA if use_triton else LinearLoRA
        linear_lora_cls._init_adapter(
            orig_linear, dim, alpha, dropout, dropout_position, lora_A_init_method, lora_dtype
        )
        cls = orig_linear.__class__
        new_cls = type("PatchedLinearLoRA", (linear_lora_cls, cls), {})
    else:
        raise NotImplementedError("Expected isinstance(orig_linear, nn.Linear)")

    # If the model uses quantized weights, we want to use orig_linear's forward
    if (
        getattr(orig_linear, "quant_state", None) is not None
        and orig_linear.quant_state.__class__ == bitsandbytes.functional.QuantState
    ):
        orig_linear.super_fwd = orig_linear.forward

    orig_linear.__class__ = new_cls
    return orig_linear
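
Another illustrative sketch (not part of the module): patching a standalone nn.Linear in place with the eager (non-Triton) LoRA path. The layer size and hyperparameters are made up.

# Illustrative sketch only; use_triton=False keeps the example CPU-friendly.
layer = nn.Linear(32, 32)
patched = patch_linear_module(layer, dim=4, alpha=16, use_triton=False)

assert patched is layer                 # patched in place, same object
assert isinstance(patched, LinearLoRA)  # class swapped to PatchedLinearLoRA
assert hasattr(patched, "lora_A") and not patched.weight.requires_grad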

# -----------------------------------------------------------------------------#
# 2. Convenience: patch a model in-place                                        #
# -----------------------------------------------------------------------------#

def apply_lora_to_linear_modules(
    model: nn.Module,
    peft_config: PeftConfig,
) -> int:
    """
    Replace selected nn.Linear layers with LinearLoRA layers (in-place).

    target_modules accepts wildcard fragments, e.g. ["q_proj", "k_proj", ".*fc.*"].
    """
    # Freeze base model parameters
    for w in model.parameters():
        w.requires_grad_(False)

    is_causal_lm = False
    try:
        if hasattr(model, "config") and "CausalLM" in model.config.architectures[0]:
            # for example, LlamaForCausalLM
            is_causal_lm = True
    except AttributeError:
        is_causal_lm = False

    matcher = ModuleMatcher(
        peft_config.target_modules, peft_config.exclude_modules, peft_config.match_all_linear, is_causal_lm
    )
    num_modules_matched = 0
    for name, module in list(model.named_modules()):
        if matcher.match(module, name):
            num_modules_matched += 1
            patch_linear_module(
                module,
                dim=peft_config.dim,
                alpha=peft_config.alpha,
                dropout=peft_config.dropout,
                dropout_position=peft_config.dropout_position,
                lora_A_init_method=peft_config.lora_A_init,
                lora_dtype=peft_config.lora_dtype,
                use_triton=peft_config.use_triton,
            )

    return num_modules_matched
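
The sketch below (illustrative only) applies LoRA to a toy model. It assumes that ModuleMatcher resolves the wildcard fragment ".*fc.*" against the submodule names fc1/fc2, as the docstring above describes, and it disables the Triton path so the example stays CPU-only.

# Illustrative sketch only; exact matching semantics live in ModuleMatcher.
import collections

toy = nn.Sequential(
    collections.OrderedDict(
        fc1=nn.Linear(32, 64),
        act=nn.ReLU(),
        fc2=nn.Linear(64, 32),
    )
)
cfg = PeftConfig(target_modules=[".*fc.*"], dim=4, alpha=8, use_triton=False)
num_patched = apply_lora_to_linear_modules(toy, cfg)
print(f"patched {num_patched} linear layers")  # expected: 2, given the wildcard above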

class LoRATritonFunction(torch.autograd.Function):
    """
    Autograd function that calls the Triton kernel wrappers for the LoRA forward and backward passes.
    """
    @staticmethod
    def setup_context(ctx, inputs, output):
        """
        Stores context for the LoRA backward pass.
        """
        x, lora_A, lora_B, scale, _ = inputs
        ctx.save_for_backward(x, lora_A, lora_B)
        ctx.scale = scale
    @staticmethod
    def forward(x, lora_A, lora_B, scale, dtype):
        """
        Forward method for LoRATriton. Reshapes 3D tensors into 2D and then calls the Triton kernel.
        """
        reshape = x.dim() == 3
        if reshape:
            bs, seq_len, d = x.shape
            x = x.reshape(-1, d)

        lora_res = lora_forward_wrapper(x, lora_A.t(), lora_B.t(), res=None, scale=scale, dtype=dtype)
        if reshape:
            return lora_res.view(bs, seq_len, -1)
        else:
            return lora_res
    @staticmethod
    def backward(ctx, d_y):
        """
        Backward method for LoRATriton. Reshapes 3D tensors into 2D and then calls the kernels to
        update d_lora_A, d_lora_B, and d_x.
        """
        x, lora_A, lora_B = ctx.saved_tensors
        scale = ctx.scale
        dtype = x.dtype
        reshape = x.dim() == 3

        if reshape:
            bs, seq_len, d = x.shape
            d_y = d_y.reshape(-1, d_y.shape[-1])
            x = x.reshape(-1, d)

        d_lora_A, d_x = lora_da_dx_update_wrapper(x.t(), d_y, lora_B, lora_A, scale, dtype=dtype)
        d_lora_B = lora_db_update_wrapper(lora_A, x.t(), d_y, scale, dtype)

        if reshape:
            d_x = d_x.view(bs, seq_len, d)

        return d_x, d_lora_A.t(), d_lora_B, None, None, None
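
Finally, a hedged end-to-end sketch of the Triton path (not part of the module). It assumes a CUDA device with Triton installed and that the kernel wrappers accept bfloat16 activations of this shape; on such a setup, gradients should flow only to the adapter weights.

# Illustrative sketch only: requires a CUDA device with Triton available.
if torch.cuda.is_available():
    layer = patch_linear_module(
        nn.Linear(128, 128, device="cuda", dtype=torch.bfloat16), dim=8, use_triton=True
    )
    x = torch.randn(2, 16, 128, device="cuda", dtype=torch.bfloat16)
    y = layer(x)        # routed through LoRATritonFunction.apply
    y.sum().backward()  # Triton backward fills d_lora_A / d_lora_B; the base weight stays frozen
    assert layer.lora_A.weight.grad is not None and layer.weight.grad is None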