bridge.peft.lora_merge

Module Contents

Classes

LoRAMerge

Tensor helper for merging LoRA adapter weights into base weights.

API

class bridge.peft.lora_merge.LoRAMerge

Tensor helper for merging LoRA adapter weights into base weights.

merge(
base_weight: torch.Tensor,
linear_out: torch.Tensor,
linear_in: torch.Tensor,
alpha: int,
dim: int,
*,
tp_group: torch.distributed.ProcessGroup | None,
scale: float | None = None,
) → torch.Tensor

Merges the LoRA adapter weights with the base model weights. Handles tensor parallelism by gathering sharded dimensions.

For ColumnParallelLinear (e.g., linear_qkv, linear_fc1):
  • base_weight: (out_features/TP, in_features)
  • linear_in: (dim/TP, in_features) <- needs to be gathered
  • linear_out: (out_features/TP, dim)
  • Target: (out_features/TP, dim) @ (dim, in_features) = (out_features/TP, in_features)

For RowParallelLinear (e.g., linear_proj, linear_fc2):
  • base_weight: (out_features, in_features/TP)
  • linear_in: (dim, in_features/TP)
  • linear_out: (out_features/TP, dim) <- needs to be gathered
  • Target: (out_features, dim) @ (dim, in_features/TP) = (out_features, in_features/TP)
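The shape bookkeeping above can be checked on a single process. The following is a minimal sketch, not the library's implementation: it assumes the merged weight is base_weight + scale * (linear_out @ linear_in) with scale = alpha / dim, and it simulates the tensor-parallel gather with torch.cat instead of a real collective. The helper name merge_dense and all sizes are hypothetical.

```python
import torch


def merge_dense(base_weight, linear_out, linear_in, scale):
    # Hypothetical reference merge: W + scale * (B @ A).
    return base_weight + scale * (linear_out @ linear_in)


torch.manual_seed(0)
tp, out_features, in_features, dim, alpha = 2, 8, 6, 4, 8
scale = alpha / dim  # assumed default LoRA scale

W = torch.randn(out_features, in_features, dtype=torch.float64)  # base_weight
B = torch.randn(out_features, dim, dtype=torch.float64)          # linear_out
A = torch.randn(dim, in_features, dtype=torch.float64)           # linear_in
reference = merge_dense(W, B, A, scale)  # unsharded result

# ColumnParallelLinear layout: W and B are split along out_features,
# A is split along dim and has to be gathered before the matmul.
W_col = W.chunk(tp, dim=0)
B_col = B.chunk(tp, dim=0)
A_full = torch.cat(A.chunk(tp, dim=0), dim=0)  # stands in for an all-gather
column_merged = [merge_dense(W_col[r], B_col[r], A_full, scale) for r in range(tp)]

# RowParallelLinear layout: W and A are split along in_features,
# B is split along out_features and has to be gathered.
W_row = W.chunk(tp, dim=1)
A_row = A.chunk(tp, dim=1)
B_full = torch.cat(B.chunk(tp, dim=0), dim=0)  # stands in for an all-gather
row_merged = [merge_dense(W_row[r], B_full, A_row[r], scale) for r in range(tp)]

# Each rank keeps the shape of its base_weight shard, and the shards
# reassemble into the unsharded merge.
assert column_merged[0].shape == (out_features // tp, in_features)
assert row_merged[0].shape == (out_features, in_features // tp)
assert torch.allclose(torch.cat(column_merged, dim=0), reference)
assert torch.allclose(torch.cat(row_merged, dim=1), reference)
```

In both layouts only the adapter matrix that is sharded along a dimension the local matmul needs in full is gathered; the base weight shard is never communicated.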

Parameters:
  • base_weight (torch.Tensor) – The base model weights.

  • linear_out (torch.Tensor) – LoRA’s B matrix.

  • linear_in (torch.Tensor) – LoRA’s A matrix.

  • alpha (int) – Weighting factor for the low-rank projection.

  • dim (int) – Dimension of the low-rank projection space.

  • tp_group (torch.distributed.ProcessGroup | None) – Tensor-parallel process group for the adapter shard.

  • scale (float | None) – Optional precomputed LoRA scale. If None, alpha / dim is used.

Returns:

The merged weights.

Return type:

torch.Tensor
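
The interaction between alpha, dim, and scale described in the parameter list can be illustrated as follows. This is a sketch under the assumption that an explicit scale fully replaces alpha / dim; resolve_scale is a hypothetical helper, not part of the documented API.

```python
from typing import Optional


def resolve_scale(alpha: int, dim: int, scale: Optional[float] = None) -> float:
    # Hypothetical helper: an explicit scale wins; otherwise fall back to alpha / dim.
    # The check is "is None" so that an explicit scale of 0.0 is still honored.
    return alpha / dim if scale is None else scale


default_scale = resolve_scale(alpha=32, dim=16)              # falls back to 32 / 16
override_scale = resolve_scale(alpha=32, dim=16, scale=0.5)  # explicit value is used
zero_scale = resolve_scale(alpha=32, dim=16, scale=0.0)      # adapter contributes nothing
```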