bridge.peft.lora_merge
Module Contents
Classes
LoRAMerge | Tensor helper for merging LoRA adapter weights into base weights.
API
- class bridge.peft.lora_merge.LoRAMerge
Tensor helper for merging LoRA adapter weights into base weights.
- merge(
      base_weight: torch.Tensor,
      linear_out: torch.Tensor,
      linear_in: torch.Tensor,
      alpha: int,
      dim: int,
      *,
      tp_group: torch.distributed.ProcessGroup | None,
      scale: float | None = None,
  ) -> torch.Tensor
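Merges the LoRA adapter weights into the base model weights, i.e. computes base_weight + scale * (linear_out @ linear_in). Handles tensor parallelism by gathering the sharded LoRA matrix over tp_group before the matrix product.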
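On a single device with no tensor parallelism, the merge reduces to the standard LoRA update W + scale * (B @ A), where B = linear_out and A = linear_in. The sketch below assumes that formula and the documented alpha / dim default. `reference_merge` is a hypothetical helper for illustration, not the `LoRAMerge.merge` implementation.

```python
from __future__ import annotations

import torch


def reference_merge(
    base_weight: torch.Tensor,
    linear_out: torch.Tensor,
    linear_in: torch.Tensor,
    alpha: int,
    dim: int,
    scale: float | None = None,
) -> torch.Tensor:
    """Hypothetical single-device LoRA merge: W + scale * (B @ A)."""
    if scale is None:
        # Fall back to the conventional LoRA scaling alpha / dim.
        scale = alpha / dim
    # (out_features, dim) @ (dim, in_features) -> (out_features, in_features)
    delta = linear_out @ linear_in
    return base_weight + scale * delta


torch.manual_seed(0)
out_features, in_features, rank = 8, 6, 4
W = torch.randn(out_features, in_features)
B = torch.randn(out_features, rank)  # plays the role of linear_out
A = torch.randn(rank, in_features)   # plays the role of linear_in

merged = reference_merge(W, B, A, alpha=16, dim=rank)
print(merged.shape)  # torch.Size([8, 6])
```

Passing an explicit `scale` skips the alpha / dim computation. This matches how the `scale` parameter is described below.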
For ColumnParallelLinear (e.g., linear_qkv, linear_fc1):
  - base_weight: (out_features/TP, in_features)
  - linear_in: (dim/TP, in_features); must be gathered to (dim, in_features)
  - linear_out: (out_features/TP, dim)
  - Target: (out_features/TP, dim) @ (dim, in_features) = (out_features/TP, in_features)
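The column-parallel shape bookkeeping can be checked in one process. The sketch simulates the TP shards with `torch.chunk` and stands in for the all-gather over tp_group with `torch.cat`, so no real `torch.distributed` group is created. `merge_column_parallel_shard` is a hypothetical name. If the shapes are right, each rank's merged shard equals the matching row block of the full, unsharded merge.

```python
import torch


def merge_column_parallel_shard(
    base_shard: torch.Tensor,              # (out_features/TP, in_features)
    linear_out_shard: torch.Tensor,        # (out_features/TP, dim)
    linear_in_shards: list[torch.Tensor],  # TP pieces, each (dim/TP, in_features)
    scale: float,
) -> torch.Tensor:
    # Stand-in for an all-gather over tp_group: rebuild the full (dim, in_features) A.
    linear_in_full = torch.cat(linear_in_shards, dim=0)
    # (out_features/TP, dim) @ (dim, in_features) -> (out_features/TP, in_features)
    return base_shard + scale * (linear_out_shard @ linear_in_full)


torch.manual_seed(0)
tp, out_features, in_features, rank, alpha = 2, 8, 6, 4, 16
scale = alpha / rank
W = torch.randn(out_features, in_features, dtype=torch.float64)
B = torch.randn(out_features, rank, dtype=torch.float64)
A = torch.randn(rank, in_features, dtype=torch.float64)
full = W + scale * (B @ A)

# Column parallel: W and B are split on the output dimension; A is split along dim.
W_shards = torch.chunk(W, tp, dim=0)
B_shards = torch.chunk(B, tp, dim=0)
A_shards = list(torch.chunk(A, tp, dim=0))
merged_shards = [
    merge_column_parallel_shard(W_shards[r], B_shards[r], A_shards, scale)
    for r in range(tp)
]
print(torch.allclose(torch.cat(merged_shards, dim=0), full))  # True
```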
For RowParallelLinear (e.g., linear_proj, linear_fc2):
  - base_weight: (out_features, in_features/TP)
  - linear_in: (dim, in_features/TP)
  - linear_out: (out_features/TP, dim); must be gathered to (out_features, dim)
  - Target: (out_features, dim) @ (dim, in_features/TP) = (out_features, in_features/TP)
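The row-parallel case can be simulated the same way. Here the gathered tensor is linear_out, and each rank's merged shard should equal a column block of the full merge. This is again a single-process sketch: `torch.cat` stands in for the all-gather over tp_group, and `merge_row_parallel_shard` is a hypothetical name.

```python
import torch


def merge_row_parallel_shard(
    base_shard: torch.Tensor,               # (out_features, in_features/TP)
    linear_out_shards: list[torch.Tensor],  # TP pieces, each (out_features/TP, dim)
    linear_in_shard: torch.Tensor,          # (dim, in_features/TP)
    scale: float,
) -> torch.Tensor:
    # Stand-in for an all-gather over tp_group: rebuild the full (out_features, dim) B.
    linear_out_full = torch.cat(linear_out_shards, dim=0)
    # (out_features, dim) @ (dim, in_features/TP) -> (out_features, in_features/TP)
    return base_shard + scale * (linear_out_full @ linear_in_shard)


torch.manual_seed(0)
tp, out_features, in_features, rank, alpha = 2, 8, 6, 4, 16
scale = alpha / rank
W = torch.randn(out_features, in_features, dtype=torch.float64)
B = torch.randn(out_features, rank, dtype=torch.float64)
A = torch.randn(rank, in_features, dtype=torch.float64)
full = W + scale * (B @ A)

# Row parallel: W and A are split on the input dimension; B is split on its output dimension.
W_shards = torch.chunk(W, tp, dim=1)
A_shards = torch.chunk(A, tp, dim=1)
B_shards = list(torch.chunk(B, tp, dim=0))
merged_shards = [
    merge_row_parallel_shard(W_shards[r], B_shards, A_shards[r], scale)
    for r in range(tp)
]
print(torch.allclose(torch.cat(merged_shards, dim=1), full))  # True
```

In both layouts, only the LoRA factor whose dimension is not already aligned with the base weight's shard needs to be gathered. The merged result keeps the base weight's sharded shape.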
- Parameters:
base_weight (torch.Tensor) – The base model weights (this rank's shard under tensor parallelism).
linear_out (torch.Tensor) – LoRA's B matrix.
linear_in (torch.Tensor) – LoRA's A matrix.
alpha (int) – Weighting factor for the low-rank update.
dim (int) – Dimension of the low-rank projection space (the LoRA rank).
tp_group (torch.distributed.ProcessGroup | None) – Tensor-parallel process group for the adapter shard.
scale (float | None) – Optional precomputed LoRA scale. Defaults to alpha / dim.
- Returns:
The merged weights, with the same (sharded) shape as base_weight.
- Return type:
torch.Tensor