bridge.peft.dora_layers
Module Contents
Classes
ParallelLinearDoRAAdapter | Adapter class for DoRA to handle the additional weight_magnitude parameter.
DoRALinear | An adapter wrapper that is designed to be used with DoRA.
API
- class bridge.peft.dora_layers.ParallelLinearDoRAAdapter
Bases: megatron.bridge.peft.utils.ParallelLinearAdapter
Adapter class for DoRA to handle the additional weight_magnitude parameter.
This class extends ParallelLinearAdapter to add DoRA-specific functionality, including weight magnitude tracking and sharded state dict support for distributed training.
- init_weight_magnitude(value: torch.Tensor) → None
Initialize weight_magnitude with shape (d,), where d is the output dim of the linear layer.
- Parameters:
value (torch.Tensor) – Initial values for the weight magnitude parameter.
- get_weight_magnitude() → torch.Tensor
Public function to get the weight magnitude parameter.
- Returns:
The weight magnitude parameter.
- Return type:
torch.Tensor
- sharded_state_dict(prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None) → ShardedStateDict
Sharded state dict implementation for the DoRA adapter. The weight magnitude is tensor-parallel (TP) sharded for linear_qkv and linear_fc1 only.
- Parameters:
prefix (str) – Prefix for parameter names. Defaults to ''.
sharded_offsets (tuple) – Offsets for sharded parameters. Defaults to ().
metadata (Optional[dict]) – Additional metadata. Defaults to None.
- Returns:
The sharded state dictionary.
- Return type:
ShardedStateDict
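As a rough illustration of what TP sharding of the magnitude means, the sketch below splits a length-d vector into contiguous per-rank slices. It assumes that linear_qkv and linear_fc1 partition their output dimension across tensor-parallel ranks, so the magnitude (one value per output feature) is partitioned the same way. The helper `shard_magnitude` and its `tp_size` argument are hypothetical names used only for this example; they are not part of the bridge or Megatron API, and plain lists stand in for tensors.

```python
def shard_magnitude(magnitude, tp_size):
    """Split a length-d magnitude vector into tp_size contiguous shards.

    Returns a list of (offset, shard) pairs, one per tensor-parallel rank.
    Hypothetical helper for illustration only.
    """
    d = len(magnitude)
    if d % tp_size != 0:
        raise ValueError("output dim must be divisible by tp_size")
    chunk = d // tp_size
    return [
        (rank * chunk, magnitude[rank * chunk:(rank + 1) * chunk])
        for rank in range(tp_size)
    ]


# A magnitude vector for a layer with 6 output features, split over 2 ranks.
magnitude = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
shards = shard_magnitude(magnitude, tp_size=2)

# Concatenating the shards in rank order restores the full vector.
restored = [value for _, shard in shards for value in shard]
```

For layers outside linear_qkv and linear_fc1, the description above implies the magnitude is not TP sharded, so each rank would hold the full vector.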
- class bridge.peft.dora_layers.DoRALinear(to_wrap: torch.nn.Module, adapter: bridge.peft.dora_layers.ParallelLinearDoRAAdapter)
Bases: megatron.bridge.peft.adapter_wrapper.AdapterWrapper
An adapter wrapper that is designed to be used with DoRA.
DoRA (Weight-Decomposed Low-Rank Adaptation) extends LoRA by decomposing the pre-trained weight into magnitude and direction components. This class implements the DoRA forward pass that applies magnitude scaling to the combined base and adapter outputs.
It extends the AdapterWrapper class to provide DoRA-specific implementation of the forward method.
Initialization
Initialize the DoRALinear wrapper.
- Parameters:
to_wrap (nn.Module) – The base linear module to wrap.
adapter (ParallelLinearDoRAAdapter) – The DoRA adapter instance.
- _get_weight_norm() → torch.Tensor
Calculate the norm of the combined weight matrix (W_0 + B A).
This method handles tensor parallel communication to gather weights when needed and computes the L2 norm along the appropriate dimension.
- Returns:
The L2 norm of the combined weight matrix.
- Return type:
torch.Tensor
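The norm computation can be sketched in a few lines. This version assumes the PyTorch weight layout (output features as rows), takes the L2 norm over the input dimension so the result has one entry per output feature, and leaves out any LoRA scaling factor and all tensor-parallel gathering. Plain lists stand in for tensors, and `matmul` and `combined_weight_norm` are hypothetical helper names.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def combined_weight_norm(w0, b, a):
    """Row-wise L2 norm of W_0 + B A, one value per output feature."""
    ba = matmul(b, a)
    combined = [
        [w + delta for w, delta in zip(w_row, ba_row)]
        for w_row, ba_row in zip(w0, ba)
    ]
    return [math.sqrt(sum(v * v for v in row)) for row in combined]


# W_0 has shape (out=2, in=2); B is (2, 1) and A is (1, 2), i.e. rank 1.
w0 = [[3.0, 4.0], [0.0, 1.0]]
b = [[1.0], [0.0]]
a = [[0.0, 1.0]]

# B A = [[0, 1], [0, 0]], so W_0 + B A = [[3, 5], [0, 1]].
norms = combined_weight_norm(w0, b, a)
```

The result has the same shape (d,) as weight_magnitude, which is what allows the two to be divided elementwise in the forward pass.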
- forward(x: torch.Tensor) → tuple[torch.Tensor, torch.Tensor]
Forward method for DoRA.
The DoRA forward pass implements:
  mag_norm_scale * (linear_output + adapter_output)
  = ||W_0 + B_0 A_0|| / ||W_0 + B A|| * (W_0 x + B A x)
  = ||W_0 + B_0 A_0|| ((W_0 + B A) / ||W_0 + B A||) x
  = m ((W_0 + B A) / ||W_0 + B A||) x
  = equation 5 in DoRA paper
When dropout is used, equation becomes:
  W_0 x + (m / ||W_0 + B A|| - 1) W_0 dropout(x) + m / ||W_0 + B A|| B A dropout(x)
  = ...
  = m / ||W_0 + B A|| (W_0 x + B A dropout(x)) + (m / ||W_0 + B A|| - 1) W_0 (dropout(x) - x)
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
A tuple containing the DoRA output and bias term.
- Return type:
tuple[torch.Tensor, torch.Tensor]
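The two identities in the forward description can be checked numerically. The sketch below is a plain-Python check, not the library's implementation: lists stand in for tensors, `ba` stands in for the product B A after some training, `m` is taken as the row-wise norm of W_0 (assuming B_0 = 0 at initialization), and `dropped` is a hand-picked stand-in for dropout(x). All variable names are chosen for this example.

```python
import math


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, vector)) for row in matrix]


w0 = [[3.0, 4.0], [0.0, 2.0]]   # frozen base weight W_0
ba = [[0.0, 1.0], [1.0, 0.0]]   # stand-in for the low-rank update B A
m = [5.0, 2.0]                  # row-wise ||W_0||, i.e. ||W_0 + B_0 A_0|| with B_0 = 0
x = [1.0, 2.0]
dropped = [2.0, 0.0]            # dropout(x) with p = 0.5: kept entry scaled by 2

combined = [[p + q for p, q in zip(r0, r1)] for r0, r1 in zip(w0, ba)]
norm = [math.sqrt(sum(v * v for v in row)) for row in combined]
scale = [mi / ni for mi, ni in zip(m, norm)]  # mag_norm_scale

# Without dropout: scale * (W_0 x + B A x) equals m * ((W_0 + B A) / norm) x.
w0_x = matvec(w0, x)
ba_x = matvec(ba, x)
scaled_sum = [s * (p + q) for s, p, q in zip(scale, w0_x, ba_x)]
eq5 = [mi * y / ni for mi, y, ni in zip(m, matvec(combined, x), norm)]

# With dropout: the first and last lines of the second derivation agree.
w0_d = matvec(w0, dropped)
ba_d = matvec(ba, dropped)
w0_diff = matvec(w0, [d - xi for d, xi in zip(dropped, x)])
direct = [p + (s - 1) * q + s * r for s, p, q, r in zip(scale, w0_x, w0_d, ba_d)]
rearranged = [
    s * (p + r) + (s - 1) * t
    for s, p, r, t in zip(scale, w0_x, ba_d, w0_diff)
]
```

In the rearranged form, the first term has the same shape as the no-dropout expression with the adapter input replaced by dropout(x); the second term is a correction that vanishes when dropout(x) equals x.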