Source code for nemo_automodel.components._peft.module_matcher
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import dataclass, field
from typing import List

import torch.nn as nn
def wildcard_match(pattern, key):
    """
    Return whether the pattern (target module to add LoRA) matches the key (model weight name).

    Example:
    --------
    >>> wildcard_match("*.layers.0.*.linear_qkv", "decoder.layers.0.self_attention.linear_qkv")
    True
    >>> wildcard_match("*.layers.0.*.linear_qkv", "decoder.layers.1.self_attention.linear_qkv")
    False
    """
    if key is None:
        return None
    regex_pattern = re.compile("^" + pattern.replace("*", "(.*)") + "$")
    match = regex_pattern.match(key)
    return match is not None
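# Illustrative note (not part of the original module): the same wildcard
# translation also covers Hugging Face-style module paths. The module names
# below are hypothetical examples chosen only for demonstration.
#
# >>> wildcard_match("*.layers.*.q_proj", "model.layers.3.self_attn.q_proj")
# True
# >>> wildcard_match("*.layers.0.*", "model.layers.12.mlp.gate_proj")
# False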
@dataclass
class ModuleMatcher:
    """
    Matches modules to apply PEFT adapters on.

    Args:
        target_modules (List[str], optional): A list of module names to apply LoRA to.
            Defaults to all linear layers ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2'].

            - 'linear_qkv': Apply LoRA to the fused linear layer used for query, key, and value projections
              in self-attention.
            - 'linear_proj': Apply LoRA to the linear layer used for projecting the output of self-attention.
            - 'linear_fc1': Apply LoRA to the first fully-connected layer in MLP.
            - 'linear_fc2': Apply LoRA to the second fully-connected layer in MLP.

            Target modules can also contain wildcards. For example, you can specify
            target_modules=['*.layers.0.*.linear_qkv', '*.layers.1.*.linear_qkv'] to add LoRA to only linear_qkv
            on the first two layers.
        exclude_modules (List[str], optional): A list of module names to exclude from applying LoRA to.
        match_all_linear (bool, optional): Whether to match all linear layers.
        is_causal_lm (bool, optional): Whether the model is a causal language model.
    """

    target_modules: List[str] = field(
        default_factory=lambda: ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"]
    )
    exclude_modules: List[str] = field(default_factory=list)
    match_all_linear: bool = field(default=False)
    is_causal_lm: bool = field(default=False)
    def __post_init__(self):
        """
        Input validation.
        """
        if isinstance(self.target_modules, str):
            self.target_modules = [self.target_modules]
        if isinstance(self.exclude_modules, str):
            self.exclude_modules = [self.exclude_modules]
        if (
            self.match_all_linear is False
            and (not isinstance(self.target_modules, list) or len(self.target_modules) == 0)
            and (not isinstance(self.exclude_modules, list) or len(self.exclude_modules) == 0)
        ):
            raise ValueError("Expected match_all_linear to be true or target_modules/exclude_modules to be non-empty")
    # --------------------------------------------------------------------- #
    # Public API                                                             #
    # --------------------------------------------------------------------- #
    def match(self, m: nn.Module, name: str = None, prefix: str = None):
        """
        Return True if the module matches the configured rules, otherwise False (or None).
        """
        full_name = f"{prefix}.{name}" if prefix else name
        if self.is_causal_lm:
            if "lm_head" in full_name:
                return False
        # 1. matching by layer type takes absolute precedence
        if self.match_all_linear and isinstance(m, nn.Linear):
            return True
        # 2. target_modules is the next most-specific rule set
        elif self.target_modules:
            assert not self.exclude_modules, "`exclude_modules` must be empty when `target_modules` is used."
            for pattern in self.target_modules:
                if name == pattern or wildcard_match(pattern, full_name):
                    return True
        # 3. Fallback: “all linear layers except those explicitly excluded”
        else:
            return (
                name not in self.exclude_modules
                and not any(wildcard_match(pattern, full_name) for pattern in self.exclude_modules)
                and isinstance(m, nn.Linear)
            )
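# Illustrative usage sketch (not part of the original module): it builds a toy
# model whose submodule names ("q_proj", "up_proj", "lm_head") are hypothetical,
# and shows how `match` is typically driven from `named_modules()` when deciding
# where to attach LoRA adapters.
if __name__ == "__main__":

    class _ToyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.q_proj = nn.Linear(8, 8)
            self.up_proj = nn.Linear(8, 16)

    class _ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([_ToyBlock() for _ in range(2)])
            self.lm_head = nn.Linear(8, 32)

    # Wildcards are matched against the fully qualified name; bare names match the leaf.
    matcher = ModuleMatcher(target_modules=["*.q_proj", "up_proj"])

    model = _ToyModel()
    for full_name, module in model.named_modules():
        if not full_name:  # skip the unnamed root module
            continue
        prefix, _, leaf = full_name.rpartition(".")
        if matcher.match(module, name=leaf, prefix=prefix or None):
            print(f"would wrap {full_name} with a LoRA adapter")
    # Expected output: layers.{0,1}.q_proj and layers.{0,1}.up_proj, but not lm_head.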