Adapters

In NeMo, we often train models and fine-tune them for a specific task. This is a reasonable approach when models have just a few million parameters. However, it quickly becomes infeasible as models approach hundreds of millions or even billions of parameters. As a potential solution for such scenarios, where fine-tuning a massive model is no longer feasible, we look to Adapters [2] to specialize our model on a specific domain or task. Adapters require only a fraction of the original model's parameters and are much more efficient to fine-tune.

Note

For a detailed tutorial on adding Adapter support to any PyTorch module, please refer to the Tutorials for NeMo Adapters.

Adapters are a straightforward concept; one formulation is shown in the diagram below. At their simplest, they are residual feedforward layers that compress the input dimension (\(D\)) to a small bottleneck dimension (\(H\)), such that \(\mathbb{R}^D \rightarrow \mathbb{R}^H\). They then apply an activation (such as ReLU) and finally map \(\mathbb{R}^H \rightarrow \mathbb{R}^D\) with another feedforward layer. This output is then added to the input via a simple residual connection.
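The equations below write this out, with \(\sigma\) denoting the activation. Bias terms are omitted for brevity, although real layers usually include them. The symbols \(W_{\text{down}}\), \(W_{\text{up}}\) and \(h\) are introduced here for illustration. The adapter maps an input \(x \in \mathbb{R}^D\) to an output \(y \in \mathbb{R}^D\) as:

```latex
\begin{aligned}
h &= \sigma\left(W_{\text{down}}\, x\right), & W_{\text{down}} &\in \mathbb{R}^{H \times D},\\
y &= x + W_{\text{up}}\, h, & W_{\text{up}} &\in \mathbb{R}^{D \times H}.
\end{aligned}
```

The two projections contribute \(2DH\) weights. With illustrative sizes \(D = 1024\) and \(H = 64\), that is \(2 \cdot 1024 \cdot 64 = 131{,}072\) weights. A single full \(D \times D\) layer would need \(1024^2 = 1{,}048{,}576\).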


Adapter modules such as this are usually initialized so that the adapter's initial output is all zeros. This prevents the addition of such modules from degrading the original model's performance.
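The zero-initialization argument can be made concrete with a framework-free sketch in plain Python. It deliberately avoids NeMo and PyTorch. The class name BottleneckAdapter and its helper functions are hypothetical. A real adapter would use torch.nn.Linear layers and an activation module, but the arithmetic is the same: the up-projection starts at zero, so the adapter contributes nothing and the residual output equals the input.

```python
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Multiply a matrix (a list of rows) by a vector."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def relu(v: Vector) -> Vector:
    return [max(0.0, v_i) for v_i in v]


class BottleneckAdapter:
    """Hypothetical residual bottleneck adapter: y = x + W_up @ relu(W_down @ x)."""

    def __init__(self, d: int, h: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        # Down-projection R^D -> R^H (H rows of length D), small random values.
        self.w_down: Matrix = [[rng.gauss(0.0, 0.02) for _ in range(d)] for _ in range(h)]
        # Up-projection R^H -> R^D (D rows of length H), zero-initialized.
        self.w_up: Matrix = [[0.0] * h for _ in range(d)]

    def __call__(self, x: Vector) -> Vector:
        hidden = relu(matvec(self.w_down, x))  # compress to the bottleneck and activate
        delta = matvec(self.w_up, hidden)      # project back up to dimension D
        return [x_i + d_i for x_i, d_i in zip(x, delta)]  # residual connection


adapter = BottleneckAdapter(d=8, h=2)
x = [float(i) for i in range(8)]
print(adapter(x) == x)  # True: at initialization the adapter leaves its input unchanged
```

Once training updates the up-projection, the adapter's contribution becomes non-zero. Until then, the wrapped model behaves exactly like the original.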

In NeMo, Adapters are supported via a Mixin class that can be attached to any torch.nn.Module. A module that inherits it gains several additional methods that enable adapter capabilities.


```python
# Import torch and the adapter mixin from NeMo
import torch
from nemo.core import adapter_mixins


# NOTE: See the *two* classes being inherited here!
class MyModule(torch.nn.Module, adapter_mixins.AdapterModuleMixin):
    pass
```

Let’s look into what AdapterModuleMixin adds to the general PyTorch module. Some of the most important methods are listed below:

  1. add_adapter: Used to add an adapter with a unique name to the module.

  2. get_enabled_adapters: Returns a list of names of all enabled adapter modules.

  3. set_enabled_adapters: Enables or disables a single adapter (or all adapters).

  4. is_adapter_available: Checks whether any adapter has been added to the module, regardless of whether it is enabled.

Modules that extend this mixin can usually use these methods directly without overriding them, but we will cover a case below where you may wish to extend them.

class nemo.core.adapter_mixins.AdapterModuleMixin

Bases: abc.ABC

Generic Adapter Mixin that can augment any torch.nn.Module with Adapter module support.

This mixin class provides a hierarchical way to add any type of Adapter module to a pre-existing module. Since Models are inherently also nn.Module instances, this mixin can be attached to any Model or Module. This mixin class adds several utility methods, which are used or overridden as necessary.

An Adapter module is any PyTorch nn.Module that possesses the following properties:

  • Its input and output dimensions are the same, while its hidden dimension may differ.

  • The final layer of the Adapter module is zero-initialized, so that the residual connection to the adapter yields the original output.

This mixin adds the following instance variables to the class that inherits it:

  • adapter_layer: A torch.nn.ModuleDict(), whose keys are the names of the adapters (globally unique) and whose values are the Adapter nn.Module() instances.

  • adapter_cfg: An OmegaConf DictConfig object that holds the config of the adapters that are initialized.

  • adapter_name: A str resolved name that is a globally unique key, although more than one module may share this name.

  • adapter_global_cfg_key: A str representing a key in the model config that can be provided by the user. The value resolves to global_cfg, and can be overridden via model.cfg.adapters.global_cfg.*.

  • adapter_metadata_cfg_key: A str representing a key in the model config that is used to preserve the metadata of the adapter config.

Note

This module is not responsible for maintaining its config. Subclasses must ensure that the config is updated or preserved as needed. They are also responsible for propagating the most up-to-date config to lower layers.

adapter_global_cfg_key = 'global_cfg'

adapter_metadata_cfg_key = 'adapter_meta_cfg'

add_adapter(name: str, cfg: Union[omegaconf.DictConfig, nemo.core.classes.mixins.adapter_mixins.AdapterConfig], **kwargs)

Add an Adapter module to this module.

Parameters
  • name – A globally unique name for the adapter. Will be used to access, enable and disable adapters.

  • cfg – A DictConfig or Dataclass that contains, at the bare minimum, _target_ to instantiate a new Adapter module.
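The _target_ key follows the Hydra-style instantiation convention. The config names a class by its dotted import path, and the remaining keys become constructor arguments. The helper below, instantiate_from_cfg, is a hypothetical sketch of that idea that uses only the standard library. A datetime.timedelta stands in for an adapter class. This is not NeMo's implementation, which may delegate to Hydra's own utilities.

```python
import importlib
from typing import Any, Dict


def instantiate_from_cfg(cfg: Dict[str, Any]) -> Any:
    """Hypothetical helper: build an object from a config holding a `_target_` class path."""
    cfg = dict(cfg)  # shallow copy so the caller's config is not mutated
    target = cfg.pop("_target_")  # e.g. "package.module.ClassName"
    module_path, _, class_name = target.rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    # Every remaining key is passed to the constructor as a keyword argument.
    return cls(**cfg)


# Stand-in target from the standard library; an adapter config would name an adapter class.
obj = instantiate_from_cfg({"_target_": "datetime.timedelta", "minutes": 1, "seconds": 30})
print(obj.total_seconds())  # 90.0
```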

is_adapter_available() → bool

Checks if any Adapter module has been instantiated.

Returns

A bool indicating whether any Adapter module has been instantiated. Returns True regardless of whether the adapters are enabled or disabled, and False only if no adapters exist.

set_enabled_adapters(name: Optional[str] = None, enabled: bool = True)

Updates the internal adapter config, determining whether an adapter (or all adapters) is enabled or disabled.

A common user pattern would be to disable all adapters (either after adding them, or restoring a model with pre-existing adapters) and then simply enable one of the adapters.


```python
module.set_enabled_adapters(enabled=False)  # disable all adapters
module.set_enabled_adapters(name="<some adapter name>", enabled=True)  # then enable one by name
```

Parameters
  • name – Optional str. If a str name is given, the config will be updated to the value of enabled. If no name is given, then all adapters will be enabled/disabled.

  • enabled – Bool, determines if the adapter(s) will be enabled/disabled.
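The bookkeeping behind add_adapter, set_enabled_adapters, get_enabled_adapters and is_adapter_available can be sketched with a plain dictionary. ToyAdapterRegistry is a hypothetical stand-in that tracks only names and enabled flags, with no modules or configs. It shows the disable-all-then-enable-one pattern. It also shows that is_adapter_available stays True while adapters are disabled.

```python
from typing import Dict, List, Optional


class ToyAdapterRegistry:
    """Hypothetical stand-in that tracks only adapter names and their enabled flags."""

    def __init__(self) -> None:
        self._enabled: Dict[str, bool] = {}  # adapter name -> enabled flag

    def add_adapter(self, name: str) -> None:
        if name in self._enabled:
            raise ValueError(f"Adapter '{name}' already exists; names must be unique")
        self._enabled[name] = True  # newly added adapters start out enabled

    def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True) -> None:
        # name=None applies the flag to every adapter; otherwise only to the named one.
        targets = list(self._enabled) if name is None else [name]
        for n in targets:
            if n not in self._enabled:
                raise KeyError(f"Unknown adapter '{n}'")
            self._enabled[n] = enabled

    def get_enabled_adapters(self) -> List[str]:
        return [n for n, is_on in self._enabled.items() if is_on]

    def is_adapter_available(self) -> bool:
        # True whenever any adapter exists, enabled or not.
        return len(self._enabled) > 0


reg = ToyAdapterRegistry()
for lang in ["en", "de", "fr"]:
    reg.add_adapter(lang)
reg.set_enabled_adapters(enabled=False)            # disable all adapters
reg.set_enabled_adapters(name="de", enabled=True)  # then enable just one
print(reg.get_enabled_adapters())  # ['de']
print(reg.is_adapter_available())  # True, even though two adapters are disabled
```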

get_enabled_adapters() → List[str]

Returns a list of the names of all enabled adapters. The names will always be the resolved names, without module info.

Returns

A list of str names, one for each enabled adapter.

get_adapter_module(name: str)

Gets an adapter module by name if possible, otherwise returns None.

Parameters

name – A str name (resolved or not) corresponding to an Adapter.

Returns

An nn.Module if the name could be resolved and matched, otherwise None.

set_accepted_adapter_types(adapter_types: List[Union[type, str]]) → None

The module with this mixin can define a list of adapter types that it will accept. This method should be called in the module's __init__ method to set the adapter types that the module expects to be added.

Parameters

adapter_types – A list of classes, or of str paths that correspond to classes. Each class path is resolved to its class, which ensures that the class path is correct.

get_accepted_adapter_types() → Set[type]

Utility function to get the set of all classes that are accepted by the module.

Returns

The set of accepted adapter types as classes, or an empty set if none have been set.
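The sketch below is a hypothetical illustration of how accepted adapter types might be resolved and checked, using only the standard library. String class paths are imported into classes. A candidate adapter class passes if it subclasses any accepted type. Treating an empty set as "no restriction" is an assumption of this sketch. fractions.Fraction and decimal.Decimal merely stand in for adapter classes.

```python
import decimal
import fractions
import importlib
from typing import List, Set, Union


def resolve_types(adapter_types: List[Union[type, str]]) -> Set[type]:
    """Hypothetical helper: turn classes or dotted class paths into a set of classes."""
    resolved: Set[type] = set()
    for t in adapter_types:
        if isinstance(t, str):
            module_path, _, class_name = t.rpartition(".")
            # Importing fails loudly (ImportError/AttributeError) if the class path is wrong.
            t = getattr(importlib.import_module(module_path), class_name)
        resolved.add(t)
    return resolved


def is_accepted(adapter_cls: type, accepted: Set[type]) -> bool:
    """Return True if adapter_cls subclasses any accepted type."""
    if not accepted:
        return True  # assumption of this sketch: an empty set places no restriction
    return any(issubclass(adapter_cls, t) for t in accepted)


accepted = resolve_types(["fractions.Fraction", decimal.Decimal])
print(is_accepted(fractions.Fraction, accepted))  # True
print(is_accepted(float, accepted))               # False
```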

unfreeze_enabled_adapters(freeze_batchnorm: bool = True) → None

Utility method to unfreeze only the enabled Adapter module(s).

A common user pattern is to freeze all the modules (including all the adapters), and then unfreeze just the required adapters.


```python
module.freeze()  # only available to nemo.core.NeuralModule!
module.unfreeze_enabled_adapters()
```

Parameters

freeze_batchnorm – An optional (and recommended) practice of freezing the updates to the moving average buffers of any and all BatchNorm*D layers. This is necessary to ensure that disabling all adapters will precisely yield the original (base) model’s outputs.

forward_enabled_adapters(input: torch.Tensor)

Forwards all active adapters one by one with the provided input, chaining the output of each adapter layer to the next.

Uses each adapter's implicit merge strategy to compute the adapter's output and to determine how that output is merged back with the original input.

Parameters

input – The output tensor of the calling module is the input to the first adapter, whose output is then chained to the next adapter until all adapters are consumed.

Returns

The result tensor, after all active adapters have finished their forward passes.
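The chaining described above can be sketched in plain Python. This is not NeMo's implementation. Here, forward_enabled_adapters is a hypothetical free function, the adapters are toy functions over lists, and every adapter is assumed to use a residual-add merge. Running the same adapters in two orders shows that the chain order affects the result.

```python
from typing import Callable, Dict, List

Vector = List[float]


def forward_enabled_adapters(
    x: Vector, adapters: Dict[str, Callable[[Vector], Vector]], enabled: List[str]
) -> Vector:
    """Hypothetical sketch: apply enabled adapters in order, chaining each result into the next."""
    out = x
    for name in enabled:
        delta = adapters[name](out)
        # Assumed residual-add merge: each adapter's output is added to that adapter's own input.
        out = [o + d for o, d in zip(out, delta)]
    return out


def scale(v: Vector) -> Vector:  # toy adapter body
    return [0.5 * e for e in v]


def shift(v: Vector) -> Vector:  # toy adapter body
    return [1.0 for _ in v]


toy_adapters = {"scale": scale, "shift": shift}
print(forward_enabled_adapters([2.0, 4.0], toy_adapters, ["scale", "shift"]))  # [4.0, 7.0]
print(forward_enabled_adapters([2.0, 4.0], toy_adapters, ["shift", "scale"]))  # [4.5, 7.5]
```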

resolve_adapter_module_name_(name: str) → Tuple[str, str]

Utility method to resolve a given global/module adapter name to its components. Always returns a tuple representing (module_name, adapter_name). “:” is used as the delimiter for denoting the module name vs the adapter name.

If the metadata config is accessible, it will also attempt to resolve a bare adapter_name back to (module_name, adapter_name).

Parameters

name – A global adapter, or a module adapter name (with structure module_name:adapter_name).

Returns

A tuple representing (module_name, adapter_name). If a global adapter is provided, module_name is set to ‘’.
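The sketch below is a hypothetical illustration of this resolution rule. It assumes the metadata is a simple mapping from adapter name to module name; NeMo's actual metadata config may be structured differently.

```python
from typing import Dict, Optional, Tuple


def resolve_adapter_module_name(
    name: str, metadata: Optional[Dict[str, str]] = None
) -> Tuple[str, str]:
    """Hypothetical sketch: split 'module_name:adapter_name' into (module_name, adapter_name)."""
    if ":" in name:
        module_name, adapter_name = name.split(":", 1)
        return module_name, adapter_name
    # A bare name may be mapped back to its module through stored metadata, if available.
    if metadata and name in metadata:
        return metadata[name], name
    # Otherwise the name is treated as a global adapter, whose module name is ''.
    return "", name


print(resolve_adapter_module_name("encoder:lang_adapter"))  # ('encoder', 'lang_adapter')
print(resolve_adapter_module_name("lang_adapter", {"lang_adapter": "decoder"}))  # ('decoder', 'lang_adapter')
print(resolve_adapter_module_name("global_adapter"))  # ('', 'global_adapter')
```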

forward_single_enabled_adapter_(input: torch.Tensor, adapter_module: torch.nn.Module, *, adapter_name: str, adapter_strategy: nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy)

Perform the forward step of a single adapter module on some input data.

Note

Subclasses can override this method to accommodate more complicated adapter forward steps.

Parameters
  • input – The output tensor of the calling module is the input to the first adapter, whose output is then chained to the next adapter until all adapters are consumed.

  • adapter_module – The adapter module that is currently required to perform the forward pass.

  • adapter_name – The resolved name of the adapter that is undergoing the current forward pass.

  • adapter_strategy – A subclass of AbstractAdapterStrategy that determines how the output of the adapter should be merged with the input, or whether it should be merged at all.

Returns

The result tensor, after the current active adapter has finished its forward pass.
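The role of adapter_strategy can be illustrated in plain Python with two hypothetical strategies. One merges the adapter output into its input by residual addition. The other does not merge at all and returns the adapter output as-is. The class and function names are illustrative and do not correspond to NeMo classes.

```python
from abc import ABC, abstractmethod
from typing import Callable, List

Vector = List[float]
Adapter = Callable[[Vector], Vector]


class ToyMergeStrategy(ABC):
    """Hypothetical stand-in for a strategy deciding how adapter output meets its input."""

    @abstractmethod
    def __call__(self, x: Vector, adapter: Adapter) -> Vector:
        ...


class ToyResidualAdd(ToyMergeStrategy):
    def __call__(self, x: Vector, adapter: Adapter) -> Vector:
        # Merge: add the adapter output onto the adapter's input.
        return [a + b for a, b in zip(x, adapter(x))]


class ToyNoMerge(ToyMergeStrategy):
    def __call__(self, x: Vector, adapter: Adapter) -> Vector:
        # No merge: the adapter output is returned unchanged.
        return adapter(x)


def forward_single_adapter(x: Vector, adapter: Adapter, strategy: ToyMergeStrategy) -> Vector:
    """Run one adapter, letting the strategy decide how (or whether) to merge."""
    return strategy(x, adapter)


def double(v: Vector) -> Vector:  # toy adapter body
    return [2.0 * e for e in v]


print(forward_single_adapter([1.0, 2.0], double, ToyResidualAdd()))  # [3.0, 6.0]
print(forward_single_adapter([1.0, 2.0], double, ToyNoMerge()))      # [2.0, 4.0]
```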

Now that MyModule supports adapters, we can easily add adapters, set their state, check if they are available and perform their forward pass. Note that if multiple adapters are enabled, they are called in a chain: the output of the previous adapter is passed as input to the next adapter, and so on.


```python
# Import torch plus the adapter mixin and modules from NeMo
import torch
from nemo.core import adapter_mixins
from nemo.collections.common.parts import adapter_modules


class MyModule(torch.nn.Module, adapter_mixins.AdapterModuleMixin):
    def __init__(self, dim: int):
        super().__init__()
        self.layers = torch.nn.Sequential(torch.nn.Linear(dim, dim))  # the original (base) layers

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self.layers(x)
        if self.is_adapter_available():  # check if adapters were added or not
            # Perform the forward of all enabled adapters in a chain
            output = self.forward_enabled_adapters(output)
        return output


# Now let us create a module, add an adapter and do a forward pass with some random inputs
dim = 16  # input and output dimension of the module
module = MyModule(dim)

# Add an adapter (cfg is an adapter config, not an instantiated adapter module)
module.add_adapter("first_adapter", cfg=adapter_modules.LinearAdapterConfig(in_features=dim, dim=5))

# Check if adapter is available
module.is_adapter_available()  # returns True

# Check the name(s) of the enabled adapters
module.get_enabled_adapters()  # returns ['first_adapter']

# Set the state of the adapter (by name)
module.set_enabled_adapters(name="first_adapter", enabled=True)

# Freeze all the weights of the original module (equivalent to calling module.freeze() for a NeuralModule)
for param in module.parameters():
    param.requires_grad = False

# Unfreeze only the adapter weights (so that we finetune only the adapters and not the original weights!)
module.unfreeze_enabled_adapters()

# Now you can train this model's adapters!
input_data = torch.randn(4, dim)
outputs_with_adapter = module(input_data)

# Compute loss and backward ...
```

If the goal was to support adapters in a single module, then the goal has been accomplished. In the real world, however, we build large composite models out of multiple modules and combine them into a final model that we then train. Adapter support for such models is provided by the AdapterModelPTMixin.

Note

For an in-depth guide to supporting hierarchical adapter modules, please refer to the Tutorials for NeMo Adapters.

class nemo.core.adapter_mixins.AdapterModelPTMixin

Bases: nemo.core.classes.mixins.adapter_mixins.AdapterModuleMixin

Adapter Mixin that can augment a ModelPT subclass with Adapter support.

This mixin class should be used only with a top-level ModelPT subclass. It adds several utility methods, which subclasses should override so that calls are propagated to their submodules as necessary.

An Adapter module is any PyTorch nn.Module that possesses the following properties:

  • Its input and output dimensions are the same, while its hidden dimension may differ.

  • The final layer of the Adapter module is zero-initialized, so that the residual connection to the adapter yields the original output.

This mixin adds the following instance variables to the class that inherits it:

  • adapter_layer: A torch.nn.ModuleDict(), whose keys are the names of the adapters (globally unique) and whose values are the Adapter nn.Module() instances.

  • adapter_cfg: An OmegaConf DictConfig object that holds the config of the adapters that are initialized.

  • adapter_global_cfg_key: A str representing a key in the model config that can be provided by the user. The value resolves to global_cfg, and can be overridden via model.cfg.adapters.global_cfg.*.

Note

This module is responsible for maintaining its config. At the ModelPT level, it will access and write Adapter config information to self.cfg.adapters.

setup_adapters()

Utility method that is called in the ASR ModelPT-implementation constructor to restore any adapters that were previously added.

Should be overridden by the subclass for additional setup steps as required.

This method should be called just once at constructor time.

add_adapter(name: str, cfg: Union[omegaconf.DictConfig, nemo.core.classes.mixins.adapter_mixins.AdapterConfig])

Add an Adapter module to this model.

Should be overridden by the subclass, and the super() call must be used because it sets up the config. After calling super(), forward this call to the modules that implement the mixin.

Parameters
  • name – A globally unique name for the adapter. Will be used to access, enable and disable adapters.

  • cfg – A DictConfig that contains, at the bare minimum, _target_ to instantiate a new Adapter module.

is_adapter_available() → bool

Checks if any Adapter module has been instantiated.

Should be overridden by the subclass.

Returns

A bool indicating whether any Adapter module has been instantiated. Returns True regardless of whether the adapters are enabled or disabled, and False only if no adapters exist.

set_enabled_adapters(name: Optional[str] = None, enabled: bool = True)

Updates the internal adapter config, determining whether an adapter (or all adapters) is enabled or disabled.

A common user pattern would be to disable all adapters (either after adding them, or restoring a model with pre-existing adapters) and then simply enable one of the adapters.

Should be overridden by the subclass, and the super() call must be used because it sets up the config. After calling super(), forward this call to the modules that implement the mixin.


```python
model.set_enabled_adapters(enabled=False)  # disable all adapters
model.set_enabled_adapters(name="<some adapter name>", enabled=True)  # then enable one by name
```

Parameters
  • name – Optional str. If a str name is given, the config will be updated to the value of enabled. If no name is given, then all adapters will be enabled/disabled.

  • enabled – Bool, determines if the adapter(s) will be enabled/disabled.

get_enabled_adapters() → List[str]

Returns a list of all enabled adapters.

Should be implemented by the subclass.

Returns

A list of str names, one for each enabled adapter.

check_valid_model_with_adapter_support_()

Utility method to test if the subclass of this mixin is an appropriate subclass of ModelPT itself.

Should be implemented by the subclass.

save_adapters(filepath: str, name: Optional[str] = None)

Utility method that saves only the adapter module(s), and not the entire model itself. This allows sharing adapters, which are often just a fraction of the size of the full model, and makes them easier to deliver.

Note

The saved file is a PyTorch-compatible pickle file containing the state dicts of the adapter(s), as well as a binary representation of the adapter config.

Parameters
  • filepath – A str filepath to the .pt file that will contain the adapter state dict.

  • name – Optional name of the adapter that will be saved to this file. If None is passed, all adapters will be saved to the file. The name can be either the global name (adapter_name), or the module level name (module:adapter_name).
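Saving only the adapters amounts to selecting the adapter entries from the model's state dict before writing it out (for example with torch.save). The sketch below assumes that adapter weights appear under keys containing adapter_layer.<name>., because adapters live in the adapter_layer ModuleDict described earlier. The function select_adapter_state is hypothetical. Plain strings stand in for tensors, and only global names (not module:adapter_name) are handled.

```python
from typing import Dict, Optional


def select_adapter_state(state_dict: Dict[str, object], name: Optional[str] = None) -> Dict[str, object]:
    """Hypothetical sketch: keep only adapter entries of a flat state dict (optionally one adapter)."""
    marker = "adapter_layer."  # assumption: adapter weights live under `adapter_layer.<name>.`
    selected = {}
    for key, value in state_dict.items():
        if marker not in key:
            continue  # skip base-model weights
        adapter_name = key.split(marker, 1)[1].split(".", 1)[0]
        if name is None or adapter_name == name:
            selected[key] = value
    return selected


state = {
    "encoder.linear.weight": "W0",
    "encoder.adapter_layer.en.weight": "A_en",
    "encoder.adapter_layer.de.weight": "A_de",
}
print(sorted(select_adapter_state(state)))     # ['encoder.adapter_layer.de.weight', 'encoder.adapter_layer.en.weight']
print(list(select_adapter_state(state, "de")))  # ['encoder.adapter_layer.de.weight']
```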

load_adapters(filepath: str, name: Optional[str] = None, map_location: Optional[str] = None, strict: bool = True)

Utility method that restores only the adapter module(s), and not the entire model itself. This allows sharing adapters, which are often just a fraction of the size of the full model, and makes them easier to deliver.

Note

During restoration, this method assumes that the model does not already have an adapter with the given name (if provided), or any adapter that shares a name with the state dict's modules (if no name is provided). This ensures that each adapter name is globally unique in a model.

Parameters
  • filepath – Filepath of the .pt file.

  • name – Optional name of the adapter that will be restored from this file. If None is passed, all adapters in the file will be restored. The name must be either the global name (adapter_name), or the module level name (module:adapter_name), whichever exactly matches the state dict.

  • map_location – PyTorch flag, where to place the adapter(s) state dict(s).

  • strict – PyTorch flag, whether to load the weights of the adapter(s) strictly or not.
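The uniqueness requirement in the note above can be expressed as a set intersection between the model's existing adapter names and the names found in the file. Both helpers below are hypothetical sketches, not part of NeMo.

```python
from typing import Iterable, List


def find_name_clashes(existing: Iterable[str], incoming: Iterable[str]) -> List[str]:
    """Hypothetical helper: names present both in the model and in the file being loaded."""
    return sorted(set(existing) & set(incoming))


def check_no_name_clash(existing: Iterable[str], incoming: Iterable[str]) -> None:
    """Refuse to load adapters whose names already exist, keeping names globally unique."""
    clashes = find_name_clashes(existing, incoming)
    if clashes:
        raise ValueError(f"Adapter name(s) already present in the model: {clashes}")


check_no_name_clash(["en"], ["de", "fr"])  # no overlap, so nothing is raised
try:
    check_no_name_clash(["en", "de"], ["de"])
except ValueError as err:
    print(err)  # Adapter name(s) already present in the model: ['de']
```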

update_adapter_cfg(cfg: omegaconf.DictConfig)

Utility method to recursively update all of the Adapter module configs with the provided config.

Note

It is not a (deep) copy, but a reference copy. Changes made to the config will be reflected in the adapter submodules, but it is still encouraged to explicitly update the adapter_cfg using this method.

Parameters

cfg – DictConfig containing the value of model.cfg.adapters.
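The difference between a reference copy and a deep copy can be seen with plain Python dictionaries. They stand in for DictConfig objects here, an assumption made to keep the example dependency-free.

```python
import copy

# Plain dicts standing in for model.cfg and its adapters sub-config.
model_cfg = {"adapters": {"en": {"enabled": True}}}

shared = model_cfg["adapters"]                   # reference copy: the same underlying object
snapshot = copy.deepcopy(model_cfg["adapters"])  # deep copy: an independent object

model_cfg["adapters"]["en"]["enabled"] = False
print(shared["en"]["enabled"])    # False: the change is visible through the reference
print(snapshot["en"]["enabled"])  # True: the deep copy still holds the old value
```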

property adapter_module_names: List[str]

List of valid adapter modules that are supported by the model.

Note

Subclasses should override this property and return a list of str names of all the modules that they support, which lets users determine where to place the adapter modules.

Returns

A list of str, one for each of the adapter modules that are supported. By default, the subclass should support the “global adapter” (‘’).

Below, we will discuss some useful functionality of Adapter Compatible Models.

  1. Save and restore a Model with adapter capability: Any NeMo model that implements this class correctly can be saved and restored together with its adapters, which allows adapters to be shared.

  2. save_adapters and load_adapters: Because adapters usually have a very small number of parameters, there is no need to duplicate the entire model for each adapter. These methods store just the adapter module(s) separately from the Model, so that you can use the same “base” model and share just the Adapter modules.

References

[1]

Junxian He, Chunting Zhou, Xuezhe Ma, Taylor Berg-Kirkpatrick, and Graham Neubig. Towards a unified view of parameter-efficient transfer learning. 2021. URL: https://arxiv.org/abs/2110.04366, doi:10.48550/ARXIV.2110.04366.

[2]

Neil Houlsby, Andrei Giurgiu, Stanislaw Jastrzebski, Bruna Morrone, Quentin De Laroussilhe, Andrea Gesmundo, Mona Attariyan, and Sylvain Gelly. Parameter-efficient transfer learning for NLP. In International Conference on Machine Learning, 2790–2799. PMLR, 2019.

© Copyright 2023-2024, NVIDIA. Last updated on Apr 25, 2024.