bridge.peft.adapter_wrapper#
Module Contents#
Classes#
AdapterWrapper – Abstract base class for wrapping modules with adapters in Parameter-Efficient Fine-Tuning (PEFT).
API#
- class bridge.peft.adapter_wrapper.AdapterWrapper(to_wrap: torch.nn.Module, adapter: torch.nn.Module)#
Bases:
torch.nn.Module
Abstract base class for wrapping modules with adapters in Parameter-Efficient Fine-Tuning (PEFT).
This class wraps a module and its associated adapter, providing methods for managing the state dictionaries of both the main module and the adapter. It does not implement the forward method, which must be implemented by concrete subclasses.
- Attributes:
to_wrap (nn.Module) – The main module to be wrapped.
adapter (nn.Module) – The adapter module to be applied.
Note: This class is abstract and cannot be instantiated directly. Subclasses must implement the forward method.
Example:

```python
import torch.nn as nn

class LoRALinear(AdapterWrapper):
    def __init__(self, to_wrap, adapter):
        super().__init__(to_wrap, adapter)

    def forward(self, x):
        return self.to_wrap(x) + self.adapter(x)

main_module = nn.Linear(100, 100)
adapter = nn.Linear(100, 100)
parallel_adapter = LoRALinear(main_module, adapter)
```
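Calling the wrapped module then adds the adapter's output to the base module's output; for instance, `parallel_adapter(torch.randn(8, 100))` returns the sum of the two linear projections (the batch size here is illustrative).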
Initialization
Initialize the AdapterWrapper with a main module and adapter.
- Parameters:
to_wrap – The main module to be wrapped.
adapter – The adapter module to be applied.
- base_linear_forward(x: torch.Tensor, *args: Any, **kwargs: Any)#
Run the forward method of the linear module to_wrap.
This method handles the complex return patterns of Megatron’s linear layers, which can return different combinations of outputs, biases, and layernorm outputs.
The flow is: x -> [layernorm/identity] -> layernorm_output -> [linear] -> linear_output, bias
- Parameters:
x – Input tensor.
*args – Additional positional arguments for the wrapped module.
**kwargs – Additional keyword arguments for the wrapped module.
- Returns:
A tuple containing:
linear_output – The output from the linear layer.
bias – The bias term (if present, otherwise None).
layernorm_output – The output from layernorm (differs from x only for LayerNormColumnParallelLinear, otherwise equals x).
Note: The wrapped module can return values in four different patterns:
nothing: (out, None)
return_bias: (out, bias)
return_layernorm_output: ((out, ln_out), None)
both: (out, bias, ln_out)
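As a rough sketch (not from the source), a concrete subclass can use base_linear_forward to hide these four patterns behind a single tuple. The class name, adapter placement, and the (output, bias) return convention below are assumptions:

```python
import torch
from typing import Any

from bridge.peft.adapter_wrapper import AdapterWrapper  # module path as documented on this page

class ParallelAdapterLinear(AdapterWrapper):
    """Hypothetical subclass relying only on the documented base_linear_forward contract."""

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
        # base_linear_forward normalizes Megatron's return patterns into
        # (linear_output, bias, layernorm_output).
        linear_output, bias, layernorm_output = self.base_linear_forward(x, *args, **kwargs)
        # Feed the adapter the same tensor the linear layer consumed;
        # layernorm_output equals x unless the wrapped layer fuses layernorm.
        adapter_output = self.adapter(layernorm_output)
        # Assumed return convention: (output, bias), mirroring Megatron linear layers.
        return linear_output + adapter_output, bias
```

Handling the tuple in one place keeps the subclass's forward independent of whether the wrapped layer fuses layernorm or defers the bias.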
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False)#
Retrieve the state dictionary of the wrapped module and adapter.
This method overrides the default state_dict behavior to include both the main module’s state and the adapter’s state under a special ‘adapter’ prefix.
- Parameters:
destination – A dictionary to store the state. If None, a new dictionary is created. Defaults to None.
prefix – A prefix added to parameter and buffer names. Defaults to ‘’.
keep_vars – If True, returns variables instead of tensor values. Defaults to False.
- Returns:
The state dictionary containing both the main module and adapter states.
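For illustration only, reusing the LoRALinear subclass from the example above, the combined dictionary can be inspected like this; the exact key names produced by the override are not specified here, so those in the comment are assumptions:

```python
import torch.nn as nn

wrapped = LoRALinear(nn.Linear(100, 100), nn.Linear(100, 100))
combined = wrapped.state_dict()
# Expect the main module's parameters alongside the adapter's parameters,
# with the adapter entries grouped under an 'adapter' prefix
# (e.g. 'adapter.weight', 'adapter.bias' -- illustrative names).
print(sorted(combined.keys()))
```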
- sharded_state_dict(prefix: str = '', sharded_offsets: Tuple[Tuple[int, int, int]] = (), metadata: Optional[Dict[str, Any]] = None)#
Retrieve the sharded state dictionary of the wrapped module and adapter.
This method is used for distributed checkpointing, combining the sharded states of both the main module and the adapter.
- Parameters:
prefix – A prefix added to parameter and buffer names. Defaults to ‘’.
sharded_offsets – Offsets for sharded parameters. Defaults to an empty tuple.
metadata – Additional metadata for the sharded state. Defaults to None.
- Returns:
The combined sharded state dictionary.
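A minimal sketch of how the combined sharded state might be passed to Megatron Core's distributed checkpointing. It assumes wrapped_layer is an AdapterWrapper subclass around a Megatron tensor-parallel linear layer (not the plain nn.Linear modules used earlier), that torch.distributed and model-parallel state are already initialized, and that the prefix and checkpoint directory are hypothetical:

```python
# Illustrative only: requires an initialized Megatron model-parallel environment.
from megatron.core import dist_checkpointing

sharded_sd = wrapped_layer.sharded_state_dict(prefix="module.")
# The mapping combines sharded tensors from both to_wrap and the adapter,
# so a single save call checkpoints base weights and adapter weights together.
dist_checkpointing.save(sharded_sd, "/tmp/adapter_checkpoint")  # hypothetical path
```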