Adapters

In NeMo, we often train models and fine-tune them for a specific task. This is a reasonable approach when models have only a few million parameters. However, it quickly becomes infeasible as models approach hundreds of millions or even billions of parameters. In such scenarios, where fine-tuning a massive model is no longer practical, we turn to Adapters [2] to specialize the model for a specific domain or task. Adapters require only a fraction of the original model's parameters and are much more efficient to fine-tune.

Note

For a detailed tutorial on adding Adapter support to any PyTorch module, please refer to the Tutorials for NeMo Adapters.

What are Adapters?

Adapters are a straightforward concept; one formulation is shown in the diagram below. At their simplest, they are residual feedforward layers that compress the input dimension \(D\) to a small bottleneck dimension \(H\) (\(\mathbb{R}^D \rightarrow \mathbb{R}^H\)), apply an activation (such as ReLU), and finally map back with another feedforward layer (\(\mathbb{R}^H \rightarrow \mathbb{R}^D\)). This output is then added to the input via a simple residual connection.
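Written out for a single input vector \(x \in \mathbb{R}^D\), this formulation reads as follows. The weights \(W_1, W_2\), the biases \(b_1, b_2\), and the activation \(\sigma\) are notation introduced here for illustration (one common parameterization, with biases included):

\[
y = x + W_2\,\sigma(W_1 x + b_1) + b_2, \qquad W_1 \in \mathbb{R}^{H \times D},\ b_1 \in \mathbb{R}^{H},\ W_2 \in \mathbb{R}^{D \times H},\ b_2 \in \mathbb{R}^{D}
\]

Counting the entries of \(W_1, b_1, W_2, b_2\), such an adapter has \(2DH + H + D\) parameters. For example, with \(D = 512\) and \(H = 32\) this is \(2 \cdot 512 \cdot 32 + 32 + 512 = 33{,}312\) parameters, compared with \(512^2 + 512 = 262{,}656\) for a single full \(D \times D\) linear layer. If \(W_2\) and \(b_2\) are initialized to zero, then \(y = x\) at initialization.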

Adapter modules such as this are usually initialized so that the adapter's initial output is zero, which prevents the newly added modules from degrading the original model's performance.
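A minimal sketch of this formulation in plain PyTorch follows. The class name BottleneckAdapter and its arguments are hypothetical and for illustration only; this is not NeMo's own adapter implementation (NeMo ships its own modules in nemo.collections.common.parts.adapter_modules). The sketch shows that zero-initializing the final layer makes the adapter an identity map at initialization.

```python
import torch
import torch.nn as nn


class BottleneckAdapter(nn.Module):
    """Hypothetical residual bottleneck adapter: x + up(act(down(x)))."""

    def __init__(self, dim: int, bottleneck: int):
        super().__init__()
        self.down = nn.Linear(dim, bottleneck)  # R^D -> R^H
        self.act = nn.ReLU()
        self.up = nn.Linear(bottleneck, dim)    # R^H -> R^D
        # Zero-initialize the final layer so the adapter initially outputs zeros
        # and the residual connection returns the input unchanged.
        nn.init.zeros_(self.up.weight)
        nn.init.zeros_(self.up.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.up(self.act(self.down(x)))


adapter = BottleneckAdapter(dim=512, bottleneck=32)
x = torch.randn(4, 512)
y = adapter(x)
num_params = sum(p.numel() for p in adapter.parameters())
print(torch.equal(y, x), num_params)  # True 33312
```

Once training updates the up-projection, the adapter's contribution becomes non-zero, while the base model's own weights stay untouched.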

torch.nn.Module with Adapters

In NeMo, Adapters are supported via a Mixin class that can be attached to any torch.nn.Module. A module that inherits this mixin gains several additional methods that enable adapter capabilities.

# Import torch and the adapter mixin from NeMo
import torch
from nemo.core import adapter_mixins

# NOTE: MyModule inherits from *two* classes here!
class MyModule(torch.nn.Module, adapter_mixins.AdapterModuleMixin):
    pass

AdapterModuleMixin

Let’s look at what AdapterModuleMixin adds to a general PyTorch module. Some of the most important methods are listed below:

  1. add_adapter: Used to add an adapter with a unique name to the module.

  2. get_enabled_adapters: Returns a list of names of all enabled adapter modules.

  3. set_enabled_adapters: Enables or disables a single adapter (or all adapters).

  4. is_adapter_available: Checks whether any adapter has been added, regardless of whether it is enabled.

Modules that inherit this mixin can usually use these methods directly without overriding them, but we cover a case below where you may wish to override them.
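To make the semantics of these four methods concrete, here is a toy, pure-Python stand-in. ToyAdapterBookkeeping is hypothetical and only mirrors the behavior described on this page (globally unique names, enabling or disabling one adapter or all of them, and is_adapter_available returning True even when every adapter is disabled); it is not how AdapterModuleMixin is implemented.

```python
from typing import Dict, List, Optional


class ToyAdapterBookkeeping:
    """Hypothetical stand-in mirroring the documented semantics of the four methods."""

    def __init__(self) -> None:
        self._adapters: Dict[str, object] = {}  # adapter name -> adapter module
        self._enabled: Dict[str, bool] = {}     # adapter name -> enabled flag

    def add_adapter(self, name: str, adapter: object) -> None:
        if name in self._adapters:
            raise ValueError(f"Adapter '{name}' already exists; names must be unique.")
        self._adapters[name] = adapter
        self._enabled[name] = True  # a newly added adapter starts out enabled

    def get_enabled_adapters(self) -> List[str]:
        return [name for name, is_on in self._enabled.items() if is_on]

    def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True) -> None:
        if name is not None and name not in self._adapters:
            raise ValueError(f"Unknown adapter '{name}'.")
        targets = list(self._enabled) if name is None else [name]  # None -> all adapters
        for target in targets:
            self._enabled[target] = enabled

    def is_adapter_available(self) -> bool:
        # True as soon as any adapter exists, whether it is enabled or disabled.
        return len(self._adapters) > 0


module = ToyAdapterBookkeeping()
module.add_adapter("first_adapter", object())
module.add_adapter("second_adapter", object())
module.set_enabled_adapters(enabled=False)                        # disable everything
module.set_enabled_adapters(name="second_adapter", enabled=True)  # then enable just one
print(module.get_enabled_adapters(), module.is_adapter_available())  # ['second_adapter'] True
```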

class nemo.core.adapter_mixins.AdapterModuleMixin

Bases: abc.ABC

Generic Adapter Mixin that can augment any torch.nn.Module with Adapter module support.

This mixin class provides a hierarchical way to add any type of Adapter module to a pre-existing module. Since Models are also nn.Module instances, this mixin can be attached to any Model or Module. This mixin class adds several utility methods which are utilized or overridden as necessary.

An Adapter module is any PyTorch nn.Module that possesses a few properties:

  • Its input and output dimensions are the same, while the hidden dimension need not be the same.

  • The final layer of the Adapter module is zero-initialized, so that the residual connection to the adapter yields the original output.

This mixin adds the following instance variables to the class that inherits it:

  • adapter_layer: A torch.nn.ModuleDict(), whose keys are the names of the adapters (globally unique) and whose values are the Adapter nn.Module().

  • adapter_cfg: An OmegaConf DictConfig object that holds the config of the adapters that are initialized.

  • adapter_name: A str resolved name that is a globally unique key, though more than one module may share this name.

  • adapter_global_cfg_key: A str representing a key in the model config that can be provided by the user. The value resolves to global_cfg, and can be overridden via model.cfg.adapters.global_cfg.*.

  • adapter_metadata_cfg_key: A str representing a key in the model config that is used to preserve the metadata of the adapter config.

Note: This module is not responsible for maintaining its config. Subclasses must ensure the config is updated or preserved as needed. It is the responsibility of the subclasses to propagate the most up-to-date config to lower layers.

adapter_global_cfg_key = 'global_cfg'
adapter_metadata_cfg_key = 'adapter_meta_cfg'
set_accepted_adapter_types(adapter_types: List[str]) → None

A module with this mixin can define a list of adapter types that it will accept. This method should be called in the module's __init__ method to set the adapter types the module expects to be added.

get_accepted_adapter_types() → List[str]

Returns the list of accepted adapter types.

get_from_adapter_layer(name: str)
add_adapter(name: str, cfg: omegaconf.DictConfig)

Add an Adapter module to this module.

Parameters
  • name – A globally unique name for the adapter. Will be used to access, enable and disable adapters.

  • cfg – A DictConfig or Dataclass that contains, at a bare minimum, _target_ to instantiate a new Adapter module.
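The cfg argument follows the Hydra convention in which a _target_ key holds the import path of the class to build and the remaining keys are passed to its constructor. The sketch below is a simplified, hypothetical stand-in (instantiate_from_cfg is not a NeMo function) that shows the idea with only the standard library and a plain dict; real configs here are OmegaConf objects or dataclasses, and Hydra's hydra.utils.instantiate is the full-featured version of this mechanism.

```python
import importlib
from typing import Any, Dict


def instantiate_from_cfg(cfg: Dict[str, Any]) -> Any:
    """Hypothetical, simplified Hydra-style instantiation from a `_target_` key."""
    cfg = dict(cfg)                                 # copy so the caller's config is not mutated
    target = cfg.pop("_target_")                    # e.g. "torch.nn.Linear"
    module_path, _, class_name = target.rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**cfg)                               # remaining keys become constructor kwargs


layer = instantiate_from_cfg({"_target_": "torch.nn.Linear", "in_features": 8, "out_features": 8})
print(type(layer).__name__, tuple(layer.weight.shape))  # Linear (8, 8)
```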

is_adapter_available() → bool

Checks if any Adapter module has been instantiated.

Returns

A bool indicating whether any Adapter module has been instantiated. Returns True whether the adapters are enabled or disabled, and False only if no adapters exist.

set_enabled_adapters(name: Optional[str] = None, enabled: bool = True)

Updates the internal adapter config, determining whether an adapter (or all adapters) is enabled or disabled.

A common user pattern would be to disable all adapters (either after adding them, or restoring a model with pre-existing adapters) and then simply enable one of the adapters.

module.set_enabled_adapters(enabled=False)
module.set_enabled_adapters(name="my_adapter", enabled=True)  # replace "my_adapter" with an existing adapter name

Parameters
  • name – Optional str. If a str name is given, the config will be updated to the value of enabled. If no name is given, then all adapters will be enabled/disabled.

  • enabled – Bool, determines if the adapter(s) will be enabled/disabled.

get_enabled_adapters() → List[str]

Returns a list of all enabled adapter names. The names will always be the resolved names, without module info.

Returns

A list of str names, one for each enabled adapter.

unfreeze_enabled_adapters(freeze_batchnorm: bool = True) → None

Utility method to unfreeze only the enabled Adapter module(s).

A common user pattern is to freeze all the modules (including all the adapters), and then unfreeze just the required adapters.

module.freeze()  # only available to nemo.core.NeuralModule !
module.unfreeze_enabled_adapters()
Parameters

freeze_batchnorm – Optional (and recommended): freezes updates to the moving-average buffers of all BatchNorm*D layers. This is necessary to ensure that disabling all adapters precisely yields the original (base) model’s outputs.
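The freeze-then-unfreeze pattern, including freezing BatchNorm statistics, can be sketched in plain PyTorch. The toy model and its adapter_layer ModuleDict are assumptions for illustration, and putting BatchNorm layers in eval mode is one standard way to stop running-statistics updates; it is not necessarily how unfreeze_enabled_adapters is implemented.

```python
import torch
import torch.nn as nn

# Hypothetical toy model: a base block with BatchNorm plus one adapter, registered
# under an "adapter_layer" ModuleDict as described for AdapterModuleMixin above.
model = nn.Module()
model.base = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))
model.adapter_layer = nn.ModuleDict({"first_adapter": nn.Linear(8, 8)})

# 1) Freeze every parameter (comparable to module.freeze() on a NeuralModule).
for p in model.parameters():
    p.requires_grad = False

# 2) Unfreeze only the enabled adapter's parameters.
for p in model.adapter_layer["first_adapter"].parameters():
    p.requires_grad = True

# 3) Stop BatchNorm running-statistics updates by putting BatchNorm layers in eval mode.
model.train()
for m in model.modules():
    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        m.eval()

before = model.base[1].running_mean.clone()
_ = model.base(torch.randn(16, 8))
print(torch.equal(before, model.base[1].running_mean))  # True: statistics were not updated
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['adapter_layer.first_adapter.weight', 'adapter_layer.first_adapter.bias']
```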

forward_enabled_adapters(input: torch.Tensor)

Forwards all active adapters one by one with the provided input, chaining the output of each adapter layer to the next.

Uses each adapter’s implicit merge strategy when computing the adapter’s output, which determines how that output is merged back with the original input.

Parameters

input – The output tensor of the calling module is the input to the first adapter, whose output is then chained to the next adapter until all adapters are consumed.

Returns

The result tensor, after all active adapters have finished their forward passes.
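The chaining behavior described above can be sketched in plain PyTorch. forward_chain is a hypothetical helper, not part of NeMo, and it assumes every adapter uses a residual-add merge (the adapter's output is added to its input, as in the formulation at the top of this page); in NeMo, the merge is instead decided by each adapter's strategy. Adapters are applied in the order their names are listed.

```python
from typing import Dict, List

import torch
import torch.nn as nn


def forward_chain(x: torch.Tensor, adapters: Dict[str, nn.Module], enabled: List[str]) -> torch.Tensor:
    """Hypothetical sketch: apply each enabled adapter in order with a residual-add merge."""
    for name in enabled:
        x = x + adapters[name](x)  # the merged output of this adapter feeds the next one
    return x


adapters = {"a": nn.Linear(4, 4), "b": nn.Linear(4, 4)}
x = torch.randn(2, 4)
out = forward_chain(x, adapters, ["a", "b"])  # equivalent to h = x + a(x); out = h + b(h)
```

With no enabled adapters the loop body never runs and the input is returned unchanged.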

resolve_adapter_module_name_(name: str) → Tuple[str, str]

Utility method to resolve a given global/module adapter name to its components. Always returns a tuple representing (module_name, adapter_name). “:” is used as the delimiter for denoting the module name vs the adapter name.

Will attempt to also resolve a given adapter_name alone back to (module_name, adapter_name) if the metadata config exists for access.

Parameters

name – A global adapter, or a module adapter name (with structure module_name:adapter_name).

Returns

A tuple representing (module_name, adapter_name). If a global adapter is provided, module_name is set to ‘’.
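A minimal sketch of the splitting rule, assuming only the “:” delimiter described above. resolve_adapter_name is a hypothetical helper; unlike the real method, it does not consult the metadata config, so a bare name is always treated as a global adapter.

```python
from typing import Tuple


def resolve_adapter_name(name: str) -> Tuple[str, str]:
    """Hypothetical sketch: split 'module_name:adapter_name' into its two parts."""
    if ":" in name:
        module_name, adapter_name = name.split(":", 1)
        return module_name, adapter_name
    return "", name  # a global adapter has an empty module name


print(resolve_adapter_name("encoder:my_adapter"))  # ('encoder', 'my_adapter')
print(resolve_adapter_name("my_adapter"))          # ('', 'my_adapter')
```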

forward_single_enabled_adapter_(input: torch.Tensor, adapter_module: torch.nn.Module, *, adapter_name: str, adapter_strategy: nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy)

Perform the forward step of a single adapter module on some input data.

Note: Subclasses can override this method to accommodate more complicated adapter forward steps.

Parameters
  • input – The output tensor of the calling module is the input to the first adapter, whose output is then chained to the next adapter until all adapters are consumed.

  • adapter_module – The adapter module that is currently required to perform the forward pass.

  • adapter_name – The resolved name of the adapter that is undergoing the current forward pass.

  • adapter_strategy – A subclass of AbstractAdapterStrategy, that determines how the output of the adapter should be merged with the input, or if it should be merged at all.

Returns

The result tensor, after the current active adapter has finished its forward pass.
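The adapter_strategy argument separates what the adapter computes from how its output is merged. The sketch below illustrates that idea with a small strategy interface; MergeStrategy, ResidualAdd, Replace, and forward_single are names invented for this example and are not the API of NeMo's AbstractAdapterStrategy, whose actual call signature may differ.

```python
from abc import ABC, abstractmethod

import torch
import torch.nn as nn


class MergeStrategy(ABC):
    """Hypothetical analogue of an adapter strategy: decides how the adapter output is merged."""

    @abstractmethod
    def __call__(self, x: torch.Tensor, adapter: nn.Module) -> torch.Tensor:
        ...


class ResidualAdd(MergeStrategy):
    def __call__(self, x: torch.Tensor, adapter: nn.Module) -> torch.Tensor:
        return x + adapter(x)  # merge by residual addition


class Replace(MergeStrategy):
    def __call__(self, x: torch.Tensor, adapter: nn.Module) -> torch.Tensor:
        return adapter(x)  # no merge: the adapter output replaces the input


def forward_single(x: torch.Tensor, adapter: nn.Module, strategy: MergeStrategy) -> torch.Tensor:
    # The adapter module stays the same; only the strategy changes how outputs combine.
    return strategy(x, adapter)


adapter = nn.Linear(3, 3)
x = torch.randn(2, 3)
residual_out = forward_single(x, adapter, ResidualAdd())
replaced_out = forward_single(x, adapter, Replace())
```

Keeping the merge logic in a separate object lets the same adapter module be reused under different merge rules without subclassing it.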

Using the Adapter Module

Now that MyModule supports adapters, we can easily add adapters, set their state, check if they are available, and perform their forward pass. Note that if multiple adapters are enabled, they are called in a chain: the output of each adapter is passed as input to the next adapter, and so on.

# Import the adapter mixin and modules from NeMo
import torch
from nemo.core import adapter_mixins
from nemo.collections.common.parts import adapter_modules

class MyModule(torch.nn.Module, adapter_mixins.AdapterModuleMixin):

    def __init__(self, dim: int):
        super().__init__()
        # Some stack of layers whose input and output dimensions are both `dim`
        self.layers = torch.nn.Sequential(torch.nn.Linear(dim, dim), torch.nn.ReLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self.layers(x)

        if self.is_adapter_available():  # check if adapters were added or not
            output = self.forward_enabled_adapters(output)  # perform the forward of all enabled adapters in a chain

        return output

# Now let us create a module, add an adapter and do a forward pass with some random inputs
dim = 16  # input and output dimension of the module
module = MyModule(dim)

# Add an adapter; add_adapter expects a config (here a dataclass), not an instantiated module
module.add_adapter("first_adapter", cfg=adapter_modules.LinearAdapterConfig(in_features=dim, dim=5))

# Check if adapter is available
module.is_adapter_available()  # returns True

# Check the name(s) of the enabled adapters
module.get_enabled_adapters()  # returns ['first_adapter']

# Set the state of the adapter (by name)
module.set_enabled_adapters(name="first_adapter", enabled=True)

# Freeze all the weights of the original module  (equivalent to calling module.freeze() for a NeuralModule)
for param in module.parameters():
    param.requires_grad = False

# Unfreeze only the adapter weights (so that we finetune only the adapters and not the original weights !)
module.unfreeze_enabled_adapters()

# Now you can train this model's adapters !
input_data = torch.randn(4, dim)  # assume dim is the input-output dim of the module
outputs_with_adapter = module(input_data)

# Compute loss and backward ...

Adapter Compatible Models

If the goal was to support adapters in a single module, then the goal has been accomplished. In practice, however, we build large composite models out of multiple modules and train the combined model. Adapter support at the model level is provided by the AdapterModelPTMixin.

Note

For an in-depth guide to supporting hierarchical adapter modules, please refer to the Tutorials for NeMo Adapters.

class nemo.core.adapter_mixins.AdapterModelPTMixin

Bases: nemo.core.classes.mixins.adapter_mixins.AdapterModuleMixin

Adapter Mixin that can augment a ModelPT subclass with Adapter support.

This mixin class should be used only with a top-level ModelPT subclass. This mixin class adds several utility methods which should be overridden by subclasses to propagate calls to submodules as necessary.

An Adapter module is any PyTorch nn.Module that possesses a few properties:

  • Its input and output dimensions are the same, while the hidden dimension need not be the same.

  • The final layer of the Adapter module is zero-initialized, so that the residual connection to the adapter yields the original output.

This mixin adds the following instance variables to the class that inherits it:

  • adapter_layer: A torch.nn.ModuleDict(), whose keys are the names of the adapters (globally unique) and whose values are the Adapter nn.Module().

  • adapter_cfg: An OmegaConf DictConfig object that holds the config of the adapters that are initialized.

  • adapter_global_cfg_key: A str representing a key in the model config that can be provided by the user. The value resolves to global_cfg, and can be overridden via model.cfg.adapters.global_cfg.*.

Note

This module is responsible for maintaining its config. At the ModelPT level, it will access and write Adapter config information to self.cfg.adapters.

setup_adapters()

Utility method that is called in the ASR ModelPT-implementation constructor, so as to restore any adapters that were previously added.

Should be overridden by the subclass for additional setup steps as required.

This method should be called just once at constructor time.

add_adapter(name: str, cfg: omegaconf.DictConfig)

Add an Adapter module to this model.

Should be overridden by the subclass, and the super() call must be used, since it sets up the config. After calling super(), forward this call to the modules that implement the mixin.

Parameters
  • name – A globally unique name for the adapter. Will be used to access, enable and disable adapters.

  • cfg – A DictConfig that contains, at a bare minimum, _target_ to instantiate a new Adapter module.
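The override pattern for add_adapter (validate, call super() so the config is recorded, then forward the call to submodules) can be sketched with hypothetical toy classes. Names such as ToyModel, ToySubmodule, ToyModelMixin, split_name, and adapter_cfgs are invented here, and routing global adapters (module name '') to the encoder is an assumption of this sketch; a real ModelPT subclass decides routing itself, using the names it declares in adapter_module_names.

```python
from typing import Dict, List, Tuple


def split_name(name: str) -> Tuple[str, str]:
    # "module:adapter" -> ("module", "adapter"); a bare name is a global adapter ("", name)
    if ":" in name:
        module_name, adapter_name = name.split(":", 1)
        return module_name, adapter_name
    return "", name


class ToySubmodule:
    """Hypothetical submodule that implements the module-level mixin."""

    def __init__(self) -> None:
        self.adapters: List[str] = []

    def add_adapter(self, name: str, cfg: dict) -> None:
        self.adapters.append(name)


class ToyModelMixin:
    """Hypothetical stand-in for the model-level mixin: it only records the config."""

    def __init__(self) -> None:
        self.adapter_cfgs: Dict[str, dict] = {}

    def add_adapter(self, name: str, cfg: dict) -> None:
        self.adapter_cfgs[name] = cfg


class ToyModel(ToyModelMixin):
    # "" is the global adapter location; "encoder" is the only supported submodule.
    adapter_module_names = ["", "encoder"]

    def __init__(self) -> None:
        super().__init__()
        self.encoder = ToySubmodule()

    def add_adapter(self, name: str, cfg: dict) -> None:
        module_name, adapter_name = split_name(name)
        if module_name not in self.adapter_module_names:
            raise ValueError(f"Module '{module_name}' does not support adapters.")
        super().add_adapter(name, cfg)  # 1) let the base class set up the config
        # 2) forward the call to the submodule; sending global ("") adapters to the
        #    encoder as well is an assumption made for this sketch.
        self.encoder.add_adapter(adapter_name, cfg)


model = ToyModel()
model.add_adapter("encoder:enc_adapter", {"dim": 32})
model.add_adapter("global_adapter", {"dim": 32})
print(list(model.adapter_cfgs), model.encoder.adapters)
# ['encoder:enc_adapter', 'global_adapter'] ['enc_adapter', 'global_adapter']
```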

is_adapter_available() → bool

Checks if any Adapter module has been instantiated.

Should be overridden by the subclass.

Returns

A bool indicating whether any Adapter module has been instantiated. Returns True whether the adapters are enabled or disabled, and False only if no adapters exist.

set_enabled_adapters(name: Optional[str] = None, enabled: bool = True)

Updates the internal adapter config, determining whether an adapter (or all adapters) is enabled or disabled.

A common user pattern would be to disable all adapters (either after adding them, or restoring a model with pre-existing adapters) and then simply enable one of the adapters.

Should be overridden by the subclass, and the super() call must be used, since it sets up the config. After calling super(), forward this call to the modules that implement the mixin.

model.set_enabled_adapters(enabled=False)
model.set_enabled_adapters(name="my_adapter", enabled=True)  # replace "my_adapter" with an existing adapter name

Parameters
  • name – Optional str. If a str name is given, the config will be updated to the value of enabled. If no name is given, then all adapters will be enabled/disabled.

  • enabled – Bool, determines if the adapter(s) will be enabled/disabled.

get_enabled_adapters() → List[str]

Returns a list of all enabled adapters.

Should be implemented by the subclass.

Returns

A list of str names of each enabled adapter(s).

check_valid_model_with_adapter_support_()

Utility method to test if the subclass of this mixin is an appropriate subclass of ModelPT itself.

Should be implemented by the subclass.

save_adapters(filepath: str, name: Optional[str] = None)

Utility method that saves only the adapter module(s), and not the entire model itself. This allows sharing adapters, which are often just a fraction of the size of the full model, enabling easier delivery.

Note: The saved file is a PyTorch-compatible pickle file containing the state dicts of the adapter(s), as well as a binary representation of the adapter config.

Parameters
  • filepath – The str filepath of the .pt file that will contain the adapter state dict.

  • name – Optional name of the adapter that will be saved to this file. If None is passed, all adapters will be saved to the file. The name can be either the global name (adapter_name), or the module level name (module:adapter_name).

load_adapters(filepath: str, name: Optional[str] = None, map_location: Optional[str] = None, strict: bool = True)

Utility method that restores only the adapter module(s), and not the entire model itself. This allows sharing adapters, which are often just a fraction of the size of the full model, enabling easier delivery.

Note: During restoration, it is assumed that the model does not already have an adapter with the given name (if provided), or any adapter that shares a name with the state dict’s modules (if name is not provided). This ensures that each adapter name is globally unique in a model.

Parameters
  • filepath – Filepath of the .pt file.

  • name – Optional name of the adapter that will be restored from this file. If None is passed, all adapters in the file will be restored. The name must be either the global name (adapter_name) or the module-level name (module:adapter_name), whichever exactly matches the state dict.

  • map_location – PyTorch flag, where to place the adapter(s) state dict(s).

  • strict – PyTorch flag, whether to load the weights of the adapter(s) strictly or not.
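The idea behind save_adapters and load_adapters (persisting only the adapter weights and loading them into a model with the same base architecture) can be approximated in plain PyTorch. This sketch is a stand-in built on assumptions, not NeMo's file format: it assumes adapter parameters live under an adapter_layer ModuleDict (as described for AdapterModuleMixin above), it omits the adapter config that NeMo also stores, and make_model is a hypothetical helper.

```python
import io

import torch
import torch.nn as nn


def make_model() -> nn.Module:
    """Hypothetical model with its adapters registered under an `adapter_layer` ModuleDict."""
    model = nn.Module()
    model.base = nn.Linear(8, 8)
    model.adapter_layer = nn.ModuleDict({"first_adapter": nn.Linear(8, 8)})
    return model


trained = make_model()

# Save: keep only the adapter entries of the state dict.
adapter_state = {k: v for k, v in trained.state_dict().items() if k.startswith("adapter_layer.")}
buffer = io.BytesIO()  # stands in for a .pt file on disk
torch.save(adapter_state, buffer)

# Load the adapter weights into a fresh copy of the same base architecture.
buffer.seek(0)
restored = torch.load(buffer, map_location="cpu")
fresh = make_model()
result = fresh.load_state_dict(restored, strict=False)  # base weights are absent on purpose
print(sorted(restored))  # ['adapter_layer.first_adapter.bias', 'adapter_layer.first_adapter.weight']
```

Here strict=False is needed only because the file deliberately omits the base weights; result.missing_keys lists exactly those base entries.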

update_adapter_cfg(cfg: omegaconf.DictConfig)

Utility method to recursively update all of the Adapter module configs with the provided config.

Note

It is not a (deep) copy, but a reference copy. Changes made to the config will be reflected in adapter submodules, but it is still encouraged to explicitly update the adapter_cfg using this method.

Parameters

cfg – DictConfig containing the value of model.cfg.adapters.

property adapter_module_names: List[str]

List of valid adapter modules that are supported by the model.

Note: Subclasses should override this property and return a list of str names of all the modules that they support, which will enable users to determine where to place the adapter modules.

Returns

A list of str, one for each of the adapter modules that are supported. By default, the subclass should support the “global adapter” (‘’).

Below, we will discuss some useful functionality of Adapter Compatible Models.

  1. Save and restore a Model with adapter capability: Any NeMo model that correctly implements this mixin can be saved and restored together with its adapters, thereby allowing adapters to be shared.

  2. save_adapters and load_adapters: Adapters usually have a very small number of parameters, so there is no need to duplicate the entire model for each adapter. These methods store just the adapter module(s) separately from the Model, so that you can reuse the same “base” model and share only the Adapter modules.

References

[1] Junxian He, Chunting Zhou, Xuezhe Ma, Taylor Berg-Kirkpatrick, and Graham Neubig. Towards a unified view of parameter-efficient transfer learning. 2021. URL: https://arxiv.org/abs/2110.04366, doi:10.48550/ARXIV.2110.04366.

[2] Neil Houlsby, Andrei Giurgiu, Stanislaw Jastrzebski, Bruna Morrone, Quentin De Laroussilhe, Andrea Gesmundo, Mona Attariyan, and Sylvain Gelly. Parameter-efficient transfer learning for NLP. In International Conference on Machine Learning, 2790–2799. PMLR, 2019.