bridge.peft.lora_layers#

Module Contents#

Classes#

LoRALinear

An adapter wrapper that adds the output of the adapter to the output of the wrapped module.

TELinearAdapter

TELinear + LoRA; maintains the checkpoint structure (i.e. the Linear’s weight/bias remain at the same FQN)

LinearAdapter

Linear + LoRA; maintains the checkpoint structure (i.e. the Linear’s weight/bias remain at the same FQN)

Functions#

patch_linear_module

Monkey-patch an nn.Linear or te.Linear to be a LinearAdapter.

API#

class bridge.peft.lora_layers.LoRALinear#

Bases: megatron.bridge.peft.adapter_wrapper.AdapterWrapper

An adapter wrapper that adds the output of the adapter to the output of the wrapped module.

This class is designed to be used with LoRA (Low-Rank Adaptation) and similar techniques where the adapter’s output is added to the main module’s output. It extends the AdapterWrapper class to provide a specific implementation of the forward method.

forward(
x: torch.Tensor,
*args: Any,
**kwargs: Any,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]#

Forward pass that combines the wrapped module output with the adapter output.

Parameters:
  • x – Input tensor.

  • *args – Additional positional arguments for the wrapped module.

  • **kwargs – Additional keyword arguments for the wrapped module.

Returns:

A tuple containing:

  • Combined output (linear_output + adapter_output)

  • Bias term (if present, otherwise None)

Return type:

Tuple[torch.Tensor, Optional[torch.Tensor]]
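
To make the combination concrete, here is a minimal, self-contained sketch of the additive LoRA pattern in plain PyTorch. It does not construct a LoRALinear; the zero-initialization of lora_B and the alpha/dim scaling follow the standard LoRA recipe and are shown for illustration only.

```python
import torch
import torch.nn as nn

# Illustrative stand-ins for the wrapped module and the low-rank adapter.
dim, alpha = 8, 32
base = nn.Linear(64, 64)                  # the wrapped module
lora_A = nn.Linear(64, dim, bias=False)   # low-rank down-projection
lora_B = nn.Linear(dim, 64, bias=False)   # low-rank up-projection
nn.init.zeros_(lora_B.weight)             # standard LoRA init: adapter starts as a no-op

x = torch.randn(4, 64)
linear_output = base(x)
adapter_output = lora_B(lora_A(x)) * (alpha / dim)   # standard LoRA scaling (assumption)
combined = linear_output + adapter_output            # the additive combination returned by forward
```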

class bridge.peft.lora_layers.TELinearAdapter(
orig_linear: transformer_engine.pytorch.Linear,
dim: int = 8,
alpha: int = 32,
dropout: float = 0.0,
dropout_position: Literal['pre', 'post'] = 'post',
lora_A_init_method: Literal['xavier', 'uniform'] = 'xavier',
lora_dtype: Optional[torch.dtype] = None,
)#

Bases: transformer_engine.pytorch.Linear

TELinear + LoRA; maintains the checkpoint structure (i.e. the Linear’s weight/bias remain at the same FQN)

The _init_adapter and forward methods provide the LoRA functionality. Because they are needed both inside LinearAdapter and for monkey-patching existing modules, they are implemented as static methods (@staticmethod) so the same code is not duplicated.

Parameters:
  • orig_linear – The linear module to augment.

  • dim – LoRA’s dimension (in_features -> dim -> out_features).

  • alpha – LoRA’s scaling alpha.

  • dropout – Dropout probability (default: 0.0).

  • dropout_position – Where to apply dropout relative to LoRA (choices: [‘pre’, ‘post’], default=’post’).

  • lora_A_init_method – Initialization method for lora_A (choices: [‘xavier’, ‘uniform’]).

  • lora_dtype – The LoRA weights’ dtype. Defaults to orig_linear’s dtype, but must be specified explicitly if orig_linear uses quantized weights (e.g. 4-bit).

Initialization

Initialize TELinearAdapter by copying from original TELinear and adding LoRA components.

Parameters:
  • orig_linear – The original TELinear module to adapt.

  • dim – LoRA rank dimension.

  • alpha – LoRA scaling factor.

  • dropout – Dropout probability.

  • dropout_position – When to apply dropout (‘pre’ or ‘post’ LoRA computation).

  • lora_A_init_method – Initialization method for LoRA matrix A.

  • lora_dtype – Data type for LoRA weights.

static _init_adapter(
obj: Union[bridge.peft.lora_layers.TELinearAdapter, torch.nn.Module],
dim: int = 8,
alpha: int = 32,
dropout: float = 0.0,
dropout_position: Literal['pre', 'post'] = 'post',
lora_A_init_method: Literal['xavier', 'uniform'] = 'xavier',
lora_dtype: Optional[torch.dtype] = None,
) -> None#

Add LoRA weights to obj. The obj is either a LinearAdapter or an nn.Module (when monkey-patching).

Parameters:
  • obj – Input module to adapt (LinearAdapter or nn.Module).

  • dim – LoRA’s dimension (in_features -> dim -> out_features).

  • alpha – LoRA’s scaling alpha.

  • dropout – Dropout probability (default: 0.0).

  • dropout_position – Where to apply dropout relative to LoRA (choices: [‘pre’, ‘post’], default=’post’).

  • lora_A_init_method – Initialization method for lora_A (choices: [‘xavier’, ‘uniform’]).

  • lora_dtype – The LoRA weights’ dtype. Defaults to orig_linear’s dtype, but must be specified explicitly if orig_linear uses quantized weights (e.g. 4-bit).

forward(x: torch.Tensor) -> torch.Tensor#

Forward pass combining TELinear output with LoRA adaptation.

Parameters:

x – Input tensor.

Returns:

Combined output from original linear layer and LoRA adaptation.
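
A minimal usage sketch, assuming Transformer Engine is installed, a CUDA device is available, and the package is importable as megatron.bridge.peft.lora_layers:

```python
import torch
import transformer_engine.pytorch as te

from megatron.bridge.peft.lora_layers import TELinearAdapter

base = te.Linear(1024, 1024, bias=True)   # te.Linear allocates its parameters on the GPU
adapted = TELinearAdapter(base, dim=8, alpha=32, dropout=0.0, lora_A_init_method="xavier")

# The original weight/bias keep their FQNs, so checkpoints saved for the plain
# te.Linear still load into the adapted module.
x = torch.randn(2, 1024, device="cuda")
y = adapted(x)                            # TELinear output + LoRA update
```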

class bridge.peft.lora_layers.LinearAdapter(
orig_linear: torch.nn.Linear,
dim: int = 8,
alpha: int = 32,
dropout: float = 0.0,
dropout_position: Literal['pre', 'post'] = 'post',
lora_A_init_method: Literal['xavier', 'uniform'] = 'xavier',
lora_dtype: Optional[torch.dtype] = None,
)#

Bases: torch.nn.Linear

Linear + LoRA; maintains the checkpoint structure (i.e. the Linear’s weight/bias remain at the same FQN)

The _init_adapter and forward methods provide the LoRA functionality. Because they are needed both inside LinearAdapter and for monkey-patching existing modules, they are implemented as static methods (@staticmethod) so the same code is not duplicated.

Parameters:
  • orig_linear – The linear module to augment.

  • dim – LoRA’s dimension (in_features -> dim -> out_features).

  • alpha – LoRA’s scaling alpha.

  • dropout – Dropout probability (default: 0.0).

  • dropout_position – Where to apply dropout relative to LoRA (choices: [‘pre’, ‘post’], default=’post’).

  • lora_A_init_method – Initialization method for lora_A (choices: [‘xavier’, ‘uniform’]).

  • lora_dtype – The LoRA weights’ dtype. Defaults to orig_linear’s dtype, but must be specified explicitly if orig_linear uses quantized weights (e.g. 4-bit).

Initialization

Initialize LinearAdapter by copying from original Linear and adding LoRA components.

Parameters:
  • orig_linear – The original Linear module to adapt.

  • dim – LoRA rank dimension.

  • alpha – LoRA scaling factor.

  • dropout – Dropout probability.

  • dropout_position – When to apply dropout (‘pre’ or ‘post’ LoRA computation).

  • lora_A_init_method – Initialization method for LoRA matrix A.

  • lora_dtype – Data type for LoRA weights.

static _init_adapter(
obj: Union[bridge.peft.lora_layers.LinearAdapter, torch.nn.Module],
dim: int = 8,
alpha: int = 32,
dropout: float = 0.0,
dropout_position: Literal['pre', 'post'] = 'post',
lora_A_init_method: Literal['xavier', 'uniform'] = 'xavier',
lora_dtype: Optional[torch.dtype] = None,
) -> None#

Add LoRA weights to obj. The obj is either a LinearAdapter or an nn.Module (when monkey-patching).

Parameters:
  • obj – Input module to adapt (LinearAdapter or nn.Module).

  • dim – LoRA’s dimension (in_features -> dim -> out_features).

  • alpha – LoRA’s scaling alpha.

  • dropout – Dropout probability (default: 0.0).

  • dropout_position – Where to apply dropout relative to LoRA (choices: [‘pre’, ‘post’], default=’post’).

  • lora_A_init_method – Initialization method for lora_A (choices: [‘xavier’, ‘uniform’]).

  • lora_dtype – The LoRA weights’ dtype. Defaults to orig_linear’s dtype, but must be specified explicitly if orig_linear uses quantized weights (e.g. 4-bit).

forward(x: torch.Tensor) -> torch.Tensor#

Forward pass combining Linear output with LoRA adaptation.

Parameters:

x – Input tensor.

Returns:

Combined output from original linear layer and LoRA adaptation.
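
A minimal usage sketch, assuming the package is importable as megatron.bridge.peft.lora_layers:

```python
import torch
import torch.nn as nn

from megatron.bridge.peft.lora_layers import LinearAdapter

base = nn.Linear(128, 256, bias=True)
adapted = LinearAdapter(base, dim=8, alpha=32, dropout=0.0, dropout_position="post")

# Only the LoRA parameters are new; the original weight/bias keep their FQNs,
# so existing checkpoints still map onto the adapted module.
x = torch.randn(4, 128)
y = adapted(x)                 # Linear output + scaled low-rank update
assert y.shape == (4, 256)
```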

bridge.peft.lora_layers.patch_linear_module(
orig_linear: Union[torch.nn.Linear, transformer_engine.pytorch.Linear],
dim: int = 8,
alpha: int = 32,
dropout: float = 0.0,
dropout_position: Literal['pre', 'post'] = 'post',
lora_A_init_method: Literal['xavier', 'uniform'] = 'xavier',
lora_dtype: Optional[torch.dtype] = None,
) -> Union[torch.nn.Linear, transformer_engine.pytorch.Linear]#

Monkey-patch an nn.Linear or te.Linear to be a LinearAdapter.

This function converts an nn.Linear (or te.Linear) into a LinearAdapter in place, without copying weights, which makes it suitable for modules that were initialized on the meta device.

orig_linear might not contain valid weights: for example, it may have been created inside a context manager that uses the “meta” device, in which case its weight/bias have not been allocated and cannot be copied into a LinearAdapter. To handle this, LinearAdapter’s additional functionality (_init_adapter, _forward) is implemented as static functions, so it can be used both when patching an existing module and when allocating a new LinearAdapter object.

Parameters:
  • orig_linear – The module to add the adapter to.

  • dim – LoRA dimension. Defaults to 8.

  • alpha – LoRA alpha scale. Defaults to 32.

  • dropout – Dropout probability. Defaults to 0.0.

  • dropout_position – Where to apply dropout relative to LoRA. Defaults to ‘post’ (choices: ‘pre’, ‘post’).

  • lora_A_init_method – LoRA_A initialization method. Defaults to ‘xavier’.

  • lora_dtype – The LoRA weights’ dtype. Defaults to None, in which case orig_linear’s dtype is used; if orig_linear uses a non-trainable dtype (e.g. 4-bit quantized weights), the dtype must be specified explicitly.

Returns:

The monkey-patched (nn.Linear + LoRA) nn.Module.

Raises:
  • NotImplementedError – If orig_linear is not nn.Linear or te.Linear.

  • AssertionError – If orig_linear already has a super_fwd attribute.
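
A minimal usage sketch of the meta-device scenario described above, again assuming the package is importable as megatron.bridge.peft.lora_layers:

```python
import torch
import torch.nn as nn

from megatron.bridge.peft.lora_layers import patch_linear_module

# Create a linear layer on the meta device, so its weight/bias are not allocated.
with torch.device("meta"):
    proj = nn.Linear(512, 512)

# Patch in place: the same object now carries LoRA parameters and an adapted
# forward, without any attempt to copy the (unallocated) original weights.
proj = patch_linear_module(proj, dim=8, alpha=32)
```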