Important
NeMo 2.0 is an experimental feature and currently released in the dev container only: nvcr.io/nvidia/nemo:dev. Please refer to the Migration Guide for information on getting started.
Adapters
In NeMo, we often train models and fine-tune them for a specific task. This is a reasonable approach when the models have only a few million parameters. However, it quickly becomes infeasible as models approach hundreds of millions or even billions of parameters. When fine-tuning a massive model is no longer feasible, we can instead use Adapters [adapters1] to specialize the model for a specific domain or task. Adapters require only a fraction of the parameters of the original model and are much more efficient to fine-tune.
Note
For a detailed tutorial on adding Adapter support to any PyTorch module, please refer to the Tutorials for NeMo Adapters.
What are Adapters?
Adapters are a straightforward concept; one formulation is shown in the diagram below. At their simplest, they are residual feedforward layers that compress the input dimension (\(D\)) to a small bottleneck dimension (\(H\)), such that \(\mathbb{R}^D \rightarrow \mathbb{R}^H\), apply an activation (such as ReLU), and finally map back with another feedforward layer, such that \(\mathbb{R}^H \rightarrow \mathbb{R}^D\). This output is then added to the input via a simple residual connection.
Adapter modules such as this are usually initialized so that the adapter's initial output is always zero. This prevents the newly added modules from degrading the original model's performance.
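The residual bottleneck described above can be sketched in plain Python (lists instead of tensors, so it runs without PyTorch). The helpers `matvec`, `relu`, and `linear_adapter` are hypothetical and are not NeMo's adapter modules; they compute \(x + W_{up}\,\mathrm{ReLU}(W_{down}\,x)\), and zero-initializing \(W_{up}\) makes a freshly added adapter return its input unchanged.

```python
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # row-major: shape (rows, cols)


def matvec(w: Matrix, x: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def relu(x: Vector) -> Vector:
    return [max(0.0, v) for v in x]


def linear_adapter(x: Vector, w_down: Matrix, w_up: Matrix) -> Vector:
    """Hypothetical residual bottleneck adapter: x + W_up . ReLU(W_down . x)."""
    hidden = relu(matvec(w_down, x))  # R^D -> R^H
    delta = matvec(w_up, hidden)  # R^H -> R^D
    return [xi + di for xi, di in zip(x, delta)]  # residual connection


D, H = 8, 2  # input/output dimension D and bottleneck dimension H (H << D)
random.seed(0)
w_down = [[random.gauss(0.0, 0.02) for _ in range(D)] for _ in range(H)]  # H x D, small random init
w_up = [[0.0] * H for _ in range(D)]  # D x H, zero-initialized final layer

x = [random.gauss(0.0, 1.0) for _ in range(D)]
y = linear_adapter(x, w_down, w_up)
print(y == x)  # True: a freshly initialized adapter leaves the input unchanged
```

Only the final layer is zeroed; \(W_{down}\) keeps a small random initialization, so the hidden activations, and hence the gradient for \(W_{up}\), are generally non-zero from the first training step.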
torch.nn.Module with Adapters
In NeMo, Adapters are supported via a Mixin class that can be attached to any torch.nn.Module. Such a module will have multiple additional methods which enable adapter capabilities in that module.
# Import torch and the adapter mixin from NeMo
import torch
from nemo.core import adapter_mixins

# NOTE: See the *two* classes being inherited here!
class MyModule(torch.nn.Module, adapter_mixins.AdapterModuleMixin):
    pass
AdapterModuleMixin
Let’s look into what AdapterModuleMixin adds to the general PyTorch module. Some of the most important methods are listed below:
- add_adapter: Used to add an adapter with a unique name to the module.
- get_enabled_adapters: Returns a list of names of all enabled adapter modules.
- set_enabled_adapters: Sets whether a single adapter (or all adapters) is enabled or disabled.
- is_adapter_available: Checks whether any adapter has been added to the module, regardless of whether it is enabled.
Modules that inherit this mixin can usually use these methods directly without overriding them, but we will cover a case below where you may wish to override them.
- class nemo.core.adapter_mixins.AdapterModuleMixin
Bases: abc.ABC
Generic Adapter Mixin that can augment any torch.nn.Module with Adapter module support.
This mixin class adds a hierarchical way to add any type of Adapter modules to a pre-existing module. Since Models are inherently also nn.Module, this mixin can be attached to any Model or Module. This mixin class adds several utility methods which are utilized or overridden as necessary.
An Adapter module is any PyTorch nn.Module that possesses a few properties:
- Its input and output dimensions are the same, while the hidden dimension need not be the same.
- The final layer of the Adapter module is zero-initialized, so that the residual connection to the adapter yields the original output.
This mixin adds the following instance variables to the class that inherits it:
- adapter_layer: A torch.nn.ModuleDict(), whose keys are the names of the adapters (globally unique), and whose values are the Adapter nn.Module().
- adapter_cfg: An OmegaConf DictConfig object that holds the config of the adapters that are initialized.
- adapter_name: A resolved str name that is a globally unique key, although more than one module may share this name.
- adapter_global_cfg_key: A str representing a key in the model config that can be provided by the user. The value resolves to global_cfg, and can be overridden via model.cfg.adapters.global_cfg.*.
- adapter_metadata_cfg_key: A str representing a key in the model config that is used to preserve the metadata of the adapter config.
Note
This module is not responsible for maintaining its config. Subclasses must ensure the config is updated or preserved as needed. It is the responsibility of the subclasses to propagate the most up-to-date config to lower layers.
- adapter_global_cfg_key = 'global_cfg'
- adapter_metadata_cfg_key = 'adapter_meta_cfg'
- add_adapter(name: str, cfg: Union[omegaconf.DictConfig, nemo.core.classes.mixins.adapter_mixins.AdapterConfig], **kwargs)
Add an Adapter module to this module.
- Parameters
name – A globally unique name for the adapter. Will be used to access, enable and disable adapters.
cfg – A DictConfig or Dataclass that contains, at the bare minimum, _target_ to instantiate a new Adapter module.
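The `_target_` key follows the Hydra instantiation convention: it holds a dotted class path, and the remaining keys are passed to that class's constructor. The `instantiate_from_cfg` helper below is a hypothetical, stdlib-only stand-in for that mechanism, not the code NeMo runs; `fractions.Fraction` is used as the target only so that the sketch runs without NeMo installed.

```python
import importlib
from typing import Any, Dict


def instantiate_from_cfg(cfg: Dict[str, Any]) -> Any:
    """Hypothetical helper: build an object from a dict holding a `_target_` class path."""
    kwargs = dict(cfg)  # copy so the caller's config is left untouched
    target = kwargs.pop("_target_")  # e.g. "package.module.ClassName"
    module_path, _, class_name = target.rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**kwargs)  # every other key becomes a keyword argument


# In NeMo the target would be an adapter class, for example
# {"_target_": "nemo.collections.common.parts.adapter_modules.LinearAdapter", "in_features": 16, "dim": 5}
obj = instantiate_from_cfg({"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4})
print(obj)  # 3/4
```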
- is_adapter_available() bool
Checks if any Adapter module has been instantiated.
- Returns
bool, determining if any Adapter module has been instantiated. Returns True whether the adapters are enabled or disabled, and False only if no adapters exist.
- set_enabled_adapters(name: Optional[str] = None, enabled: bool = True)
Updates the internal adapter config, determining whether an adapter (or all adapters) is enabled or disabled.
A common user pattern is to disable all adapters (either after adding them, or after restoring a model with pre-existing adapters) and then enable just one of them.
module.set_enabled_adapters(enabled=False)
module.set_enabled_adapters(name=<some adapter name>, enabled=True)
- Parameters
name – Optional str. If a str name is given, the config will be updated to the value of enabled. If no name is given, then all adapters will be enabled/disabled.
enabled – Bool, determines if the adapter(s) will be enabled/disabled.
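The enable/disable pattern above can be illustrated with a toy registry. `ToyAdapterRegistry` is hypothetical and only mimics the documented behavior of add_adapter, set_enabled_adapters, get_enabled_adapters, and is_adapter_available (in particular, is_adapter_available stays True even when every adapter is disabled); it is not NeMo's implementation.

```python
from typing import Dict, List, Optional


class ToyAdapterRegistry:
    """Hypothetical stand-in that tracks adapter names and their enabled state."""

    def __init__(self) -> None:
        self.enabled: Dict[str, bool] = {}  # adapter name -> enabled flag

    def add_adapter(self, name: str) -> None:
        if name in self.enabled:
            raise ValueError(f"Adapter '{name}' already exists; names must be unique")
        self.enabled[name] = True  # newly added adapters start enabled

    def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True) -> None:
        names = list(self.enabled) if name is None else [name]  # None means "all adapters"
        for n in names:
            self.enabled[n] = enabled

    def get_enabled_adapters(self) -> List[str]:
        return [n for n, on in self.enabled.items() if on]

    def is_adapter_available(self) -> bool:
        return len(self.enabled) > 0  # True even if every adapter is disabled


registry = ToyAdapterRegistry()
for adapter_name in ("adapter_a", "adapter_b", "adapter_c"):
    registry.add_adapter(adapter_name)

registry.set_enabled_adapters(enabled=False)  # disable everything first ...
registry.set_enabled_adapters(name="adapter_b", enabled=True)  # ... then enable just one
print(registry.get_enabled_adapters())  # ['adapter_b']
print(registry.is_adapter_available())  # True
```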
- get_enabled_adapters() List[str]
Returns a list of the names of all enabled adapters. The names will always be the resolved names, without module info.
- Returns
A list of str names, one for each enabled adapter.
- get_adapter_module(name: str)
Gets an adapter module by name if possible, otherwise returns None.
- Parameters
name – A str name (resolved or not) corresponding to an Adapter.
- Returns
An nn.Module if the name could be resolved and matched, otherwise None.
- get_adapter_cfg(name: str)
Same logic as get_adapter_module, but returns the adapter's config instead of the module.
- set_accepted_adapter_types(adapter_types: List[Union[type, str]]) None
The module with this mixin can define a list of adapter types that it will accept. This method should be called in the module's __init__ method to set the adapter types that the module expects to be added.
- Parameters
adapter_types – A list of str paths that correspond to classes. The class paths will be instantiated to ensure that the class path is correct.
- get_accepted_adapter_types() Set[type]
Utility function to get the set of all classes that are accepted by the module.
- Returns
Returns the set of accepted adapter types as classes, otherwise an empty set.
- unfreeze_enabled_adapters(freeze_batchnorm: bool = True) None
Utility method to unfreeze only the enabled Adapter module(s).
A common user pattern is to freeze all the modules (including all the adapters), and then unfreeze just the required adapters.
module.freeze()  # only available to nemo.core.NeuralModule!
module.unfreeze_enabled_adapters()
- Parameters
freeze_batchnorm – An optional (and recommended) practice of freezing the updates to the moving average buffers of any and all BatchNorm*D layers. This is necessary to ensure that disabling all adapters will precisely yield the original (base) model’s outputs.
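Freezing and unfreezing amounts to toggling requires_grad on each parameter. In this sketch a plain dict stands in for the module's parameters, and the parameter names (including the `adapter_layer.<adapter name>.` prefix, modelled on the adapter_layer ModuleDict described above) are hypothetical; real code would simply call module.freeze() and module.unfreeze_enabled_adapters().

```python
from typing import Dict, List

# Hypothetical parameter names -> requires_grad flags for a module with two adapters
requires_grad: Dict[str, bool] = {
    "layers.0.weight": True,
    "layers.0.bias": True,
    "adapter_layer.adapter_a.linear_in.weight": True,
    "adapter_layer.adapter_b.linear_in.weight": True,
}


def freeze_all(flags: Dict[str, bool]) -> None:
    for name in flags:
        flags[name] = False


def unfreeze_enabled(flags: Dict[str, bool], enabled_adapters: List[str]) -> None:
    # Only parameters that live under an enabled adapter become trainable again
    prefixes = tuple(f"adapter_layer.{a}." for a in enabled_adapters)
    for name in flags:
        if name.startswith(prefixes):
            flags[name] = True


freeze_all(requires_grad)
unfreeze_enabled(requires_grad, enabled_adapters=["adapter_b"])
print([n for n, trainable in requires_grad.items() if trainable])
# ['adapter_layer.adapter_b.linear_in.weight']
```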
- forward_enabled_adapters(input: torch.Tensor)
Forwards all active adapters one by one with the provided input, chaining the output of each adapter layer to the next.
Utilizes the implicit merge strategy of each adapter when computing the adapter’s output, and how that output will be merged back with the original input.
- Parameters
input – The output tensor of the calling module is the input to the first adapter, whose output is then chained to the next adapter until all adapters are consumed.
- Returns
The result tensor, after all active adapters have finished their forward passes.
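Chaining can be pictured with plain functions standing in for adapter modules and scalars standing in for tensors. `residual_add` plays the role of a merge strategy that adds each adapter's output back onto its input (the residual connection described earlier); all names here are hypothetical.

```python
from typing import Callable, List, Tuple

Adapter = Callable[[float], float]
Strategy = Callable[[float, Adapter], float]


def residual_add(x: float, adapter: Adapter) -> float:
    """Merge strategy: add the adapter's output back onto its input."""
    return x + adapter(x)


def forward_enabled_adapters(x: float, adapters: List[Tuple[str, Adapter, Strategy]]) -> float:
    # The merged output of each adapter becomes the input of the next one
    for _name, adapter, strategy in adapters:
        x = strategy(x, adapter)
    return x


enabled = [
    ("double", lambda v: 2.0 * v, residual_add),  # 1.0 -> 1.0 + 2.0 = 3.0
    ("shift", lambda v: v + 1.0, residual_add),  # 3.0 -> 3.0 + 4.0 = 7.0
]
print(forward_enabled_adapters(1.0, enabled))  # 7.0
```

Because each adapter sees the previous adapter's merged output, the order of the enabled adapters changes the result.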
- resolve_adapter_module_name_(name: str) Tuple[str, str]
Utility method to resolve a given global/module adapter name to its components. Always returns a tuple representing (module_name, adapter_name). “:” is used as the delimiter for denoting the module name vs the adapter name.
Will attempt to also resolve a given adapter_name alone back to (module_name, adapter_name) if the metadata config exists for access.
- Parameters
name – A global adapter, or a module adapter name (with structure module_name:adapter_name).
- Returns
A tuple representing (module_name, adapter_name). If a global adapter is provided, module_name is set to ‘’.
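The name-splitting rule can be sketched as follows. This hypothetical helper only handles the explicit `module_name:adapter_name` form and skips the metadata-config lookup described above; `encoder` is an illustrative module name.

```python
from typing import Tuple


def resolve_adapter_module_name(name: str) -> Tuple[str, str]:
    """Split 'module_name:adapter_name' into its parts; global names get module_name ''."""
    if ":" in name:
        module_name, adapter_name = name.split(":", 1)
        return module_name, adapter_name
    return "", name


print(resolve_adapter_module_name("encoder:adapter_1"))  # ('encoder', 'adapter_1')
print(resolve_adapter_module_name("adapter_1"))  # ('', 'adapter_1')
```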
- forward_single_enabled_adapter_(input: torch.Tensor, adapter_module: torch.nn.Module, *, adapter_name: str, adapter_strategy: nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy)
Perform the forward step of a single adapter module on some input data.
Note
Subclasses can override this method to accommodate more complicated adapter forward steps.
- Parameters
input – The output tensor of the calling module is the input to the first adapter, whose output is then chained to the next adapter until all adapters are consumed.
adapter_module – The adapter module that is currently required to perform the forward pass.
adapter_name – The resolved name of the adapter that is undergoing the current forward pass.
adapter_strategy – A subclass of AbstractAdapterStrategy, that determines how the output of the adapter should be merged with the input, or if it should be merged at all.
- Returns
The result tensor, after the current active adapter has finished its forward pass.
- check_supported_adapter_type_(adapter_cfg: omegaconf.DictConfig, supported_adapter_types: Optional[Iterable[type]] = None)
Utility method to check if the adapter module is a supported type by the module.
This method should be called by the subclass to ensure that the adapter module is a supported type.
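The accepted-type bookkeeping can be sketched with importlib: str class paths are resolved once, and an adapter config is accepted when the class named by its `_target_` is a subclass of an accepted type. The helpers are hypothetical, stdlib classes stand in for adapter classes, and the sketch returns a bool where the real method may raise an error instead.

```python
import importlib
from typing import Iterable, List, Set, Union


def resolve_class(path: str) -> type:
    """Import a dotted class path such as 'package.module.ClassName'."""
    module_path, _, class_name = path.rpartition(".")
    return getattr(importlib.import_module(module_path), class_name)


def accepted_types(adapter_types: List[Union[type, str]]) -> Set[type]:
    # str entries are resolved eagerly, so a wrong class path fails immediately
    return {resolve_class(t) if isinstance(t, str) else t for t in adapter_types}


def is_supported(adapter_cfg: dict, supported: Iterable[type]) -> bool:
    # The class named by `_target_` must be (a subclass of) one of the accepted types
    adapter_cls = resolve_class(adapter_cfg["_target_"])
    return any(issubclass(adapter_cls, s) for s in supported)


# Stdlib classes stand in for adapter classes so the sketch runs anywhere
supported = accepted_types(["builtins.dict"])
print(is_supported({"_target_": "collections.OrderedDict"}, supported))  # True (subclass of dict)
print(is_supported({"_target_": "fractions.Fraction"}, supported))  # False
```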
Using the Adapter Module
Now that MyModule supports adapters, we can easily add adapters, set their state, check if they are available, and perform their forward pass. Note that if multiple adapters are enabled, they are called in a chain: the output of the previous adapter is passed as input to the next adapter, and so on.
# Import torch, the adapter mixin, and the adapter modules from NeMo
import torch
from nemo.core import adapter_mixins
from nemo.collections.common.parts import adapter_modules

class MyModule(torch.nn.Module, adapter_mixins.AdapterModuleMixin):
    def __init__(self, dim: int):
        super().__init__()
        # Any stack of layers whose input and output dimension is `dim`
        self.layers = torch.nn.Sequential(torch.nn.Linear(dim, dim), torch.nn.ReLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self.layers(x)
        if self.is_adapter_available():  # check if adapters were added or not
            # Pass the module's output through all enabled adapters, in a chain
            output = self.forward_enabled_adapters(output)
        return output

# Now let us create a module, add an adapter and do a forward pass with some random inputs
dim = 16  # input and output dimension of the module
module = MyModule(dim)
# Add an adapter (cfg is a config dataclass, not an instantiated adapter module)
module.add_adapter("first_adapter", cfg=adapter_modules.LinearAdapterConfig(in_features=dim, dim=5))
# Check if adapter is available
module.is_adapter_available()  # returns True
# Check the name(s) of the enabled adapters
module.get_enabled_adapters()  # returns ['first_adapter']
# Set the state of the adapter (by name)
module.set_enabled_adapters(name="first_adapter", enabled=True)
# Freeze all the weights of the original module (equivalent to calling module.freeze() for a NeuralModule)
for param in module.parameters():
    param.requires_grad = False
# Unfreeze only the adapter weights (so that we finetune only the adapters and not the original weights!)
module.unfreeze_enabled_adapters()
# Now you can train this model's adapters!
input_data = torch.randn(4, dim)  # a batch of 4 inputs, each of size dim
outputs_with_adapter = module(input_data)
# Compute loss and backward ...
Adapter Compatible Models
If the goal were only to support adapters in a single module, it has now been accomplished. In the real world, however, we build large composite models out of multiple modules and combine them into a final model that we then train. Adapter support for such models is provided by the AdapterModelPTMixin.
Note
For an in-depth guide to supporting hierarchical adapter modules, please refer to the Tutorials for NeMo Adapters.
- class nemo.core.adapter_mixins.AdapterModelPTMixin
Bases: nemo.core.classes.mixins.adapter_mixins.AdapterModuleMixin
Adapter Mixin that can augment a ModelPT subclass with Adapter support.
This mixin class should be used only with a top-level ModelPT subclass. It adds several utility methods, which should be overridden by the subclass to propagate calls to the submodules as necessary.
An Adapter module is any PyTorch nn.Module that possesses a few properties:
- Its input and output dimensions are the same, while the hidden dimension need not be the same.
- The final layer of the Adapter module is zero-initialized, so that the residual connection to the adapter yields the original output.
This mixin adds the following instance variables to the class that inherits it:
- adapter_layer: A torch.nn.ModuleDict(), whose keys are the names of the adapters (globally unique), and whose values are the Adapter nn.Module().
- adapter_cfg: An OmegaConf DictConfig object that holds the config of the adapters that are initialized.
- adapter_global_cfg_key: A str representing a key in the model config that can be provided by the user. The value resolves to global_cfg, and can be overridden via model.cfg.adapters.global_cfg.*.
Note
This module is responsible for maintaining its config. At the ModelPT level, it will access and write Adapter config information to self.cfg.adapters.
- setup_adapters()
Utility method that is called in the ASR ModelPT-implementation constructor, so as to restore any adapters that were previously added.
Should be overridden by the subclass for additional setup steps as required.
This method should be called just once at constructor time.
- add_adapter(name: str, cfg: Union[omegaconf.DictConfig, nemo.core.classes.mixins.adapter_mixins.AdapterConfig])
Add an Adapter module to this model.
Should be overridden by the subclass, and super() must be called; this sets up the config. After calling super(), forward this call to the modules that implement the mixin.
- Parameters
name – A globally unique name for the adapter. Will be used to access, enable and disable adapters.
cfg – A DictConfig that contains, at the bare minimum, _target_ to instantiate a new Adapter module.
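The override-and-forward pattern for a top-level model can be sketched with two toy classes. `ToyModel` and `ToyModule` are hypothetical: the model first records the adapter config (the role played by the super() call, which writes to self.cfg.adapters) and then forwards the call to the submodule named in `module:adapter`, falling back to a default module when no module name is given.

```python
from typing import Dict


class ToyModule:
    """Hypothetical submodule that simply records the adapters added to it."""

    def __init__(self) -> None:
        self.adapters: Dict[str, dict] = {}

    def add_adapter(self, name: str, cfg: dict) -> None:
        self.adapters[name] = cfg


class ToyModel:
    """Hypothetical top-level model: stores the config, then forwards the call."""

    default_adapter_module_name = "encoder"

    def __init__(self) -> None:
        self.submodules = {"encoder": ToyModule(), "decoder": ToyModule()}
        self.cfg: Dict[str, dict] = {"adapters": {}}

    def add_adapter(self, name: str, cfg: dict) -> None:
        # Step 1 (the super() call in NeMo): record the adapter in the model config
        self.cfg["adapters"][name] = cfg
        # Step 2: forward the call to the submodule named in 'module:adapter'
        module_name, _, adapter_name = name.rpartition(":")
        module_name = module_name or self.default_adapter_module_name
        self.submodules[module_name].add_adapter(adapter_name, cfg)


model = ToyModel()
model.add_adapter("decoder:adapter_1", {"dim": 5})
model.add_adapter("adapter_2", {"dim": 8})  # no module given -> default module
print(sorted(model.submodules["encoder"].adapters))  # ['adapter_2']
print(sorted(model.submodules["decoder"].adapters))  # ['adapter_1']
```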
- is_adapter_available() bool
Checks if any Adapter module has been instantiated.
Should be overridden by the subclass.
- Returns
bool, determining if any Adapter module has been instantiated. Returns True whether the adapters are enabled or disabled, and False only if no adapters exist.
- set_enabled_adapters(name: Optional[str] = None, enabled: bool = True)
Updates the internal adapter config, determining whether an adapter (or all adapters) is enabled or disabled.
A common user pattern is to disable all adapters (either after adding them, or after restoring a model with pre-existing adapters) and then enable just one of them.
Should be overridden by the subclass, and super() must be called; this sets up the config. After calling super(), forward this call to the modules that implement the mixin.
model.set_enabled_adapters(enabled=False)
model.set_enabled_adapters(name=<some adapter name>, enabled=True)
- Parameters
name – Optional str. If a str name is given, the config will be updated to the value of enabled. If no name is given, then all adapters will be enabled/disabled.
enabled – Bool, determines if the adapter(s) will be enabled/disabled.
- get_enabled_adapters() List[str]
Returns a list of all enabled adapters.
Should be implemented by the subclass.
- Returns
A list of str names of each enabled adapter(s).
- check_valid_model_with_adapter_support_()
Utility method to test if the subclass of this mixin is an appropriate subclass of ModelPT itself.
Should be implemented by the subclass.
- save_adapters(filepath: str, name: Optional[str] = None)
Utility method that saves only the adapter module(s), and not the entire model itself. This allows sharing adapters, which are often just a fraction of the size of the full model, enabling easier delivery.
Note
The saved file is a PyTorch-compatible pickle file, containing the state dicts of the adapter(s), as well as a binary representation of the adapter config.
- Parameters
filepath – A str filepath to the .pt file that will contain the adapter state dict.
name – Optional name of the adapter that will be saved to this file. If None is passed, all adapters will be saved to the file. The name can be either the global name (adapter_name), or the module level name (module:adapter_name).
- load_adapters(filepath: str, name: Optional[str] = None, map_location: Optional[str] = None, strict: bool = True)
Utility method that restores only the adapter module(s), and not the entire model itself. This allows sharing adapters, which are often just a fraction of the size of the full model, enabling easier delivery.
Note
During restoration, this method assumes that the model does not already have an adapter with the given name (if provided), or any adapter that shares a name with the state dict’s modules (if name is not provided). This ensures that each adapter name is globally unique in a model.
- Parameters
filepath – Filepath of the .pt file.
name – Optional name of the adapter that will be restored from this file. If None is passed, all adapters in the file will be restored. The name must be either the global name (adapter_name), or the module level name (module:adapter_name), whichever exactly matches the state dict.
map_location – PyTorch flag, where to place the adapter(s) state dict(s).
strict – PyTorch flag, whether to load the weights of the adapter(s) strictly or not.
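The idea behind save_adapters and load_adapters (store only the adapter weights, and refuse to restore an adapter whose name already exists in the model) can be sketched with plain pickle in place of torch.save and torch.load. The flattened `<adapter name>.<param>` key layout and the functions `save_adapter_subset` and `load_adapter_subset` are hypothetical simplifications, not NeMo's methods.

```python
import os
import pickle
import tempfile
from typing import Dict, List, Optional

# Hypothetical flattened state dict: "<adapter name>.<param>" -> weights (lists stand in for tensors)
StateDict = Dict[str, List[float]]


def adapter_of(key: str) -> str:
    return key.split(".", 1)[0]


def save_adapter_subset(filepath: str, adapters: StateDict, name: Optional[str] = None) -> None:
    # Keep only the requested adapter's entries (or every adapter if name is None)
    selected = {k: v for k, v in adapters.items() if name is None or adapter_of(k) == name}
    with open(filepath, "wb") as f:
        pickle.dump(selected, f)


def load_adapter_subset(filepath: str, existing: StateDict) -> StateDict:
    with open(filepath, "rb") as f:
        loaded: StateDict = pickle.load(f)
    clashes = {adapter_of(k) for k in loaded} & {adapter_of(k) for k in existing}
    if clashes:  # adapter names must stay globally unique within a model
        raise ValueError(f"Adapter name(s) already present in the model: {sorted(clashes)}")
    return {**existing, **loaded}


adapters = {"adapter_a.weight": [0.1, 0.2], "adapter_b.weight": [0.3]}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "adapter_a.pt")
    save_adapter_subset(path, adapters, name="adapter_a")  # ship just a small file
    restored = load_adapter_subset(path, existing={})  # e.g. into a fresh copy of the base model
print(sorted(restored))  # ['adapter_a.weight']
```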
- update_adapter_cfg(cfg: omegaconf.DictConfig)
Utility method to recursively update all of the Adapter module configs with the provided config.
Note
It is not a (deep) copy, but a reference copy. Changes made to the config will be reflected in the adapter submodules, but it is still encouraged to explicitly update the adapter_cfg using this method.
- Parameters
cfg – DictConfig containing the value of model.cfg.adapters.
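The difference between a reference copy and a deep copy, which the note above relies on, can be seen with plain dicts standing in for OmegaConf DictConfig objects (a simplification; the sketch does not use OmegaConf).

```python
import copy

model_cfg = {"adapters": {"adapter_a": {"dim": 5}}}

shared = model_cfg["adapters"]  # reference copy: the same underlying object
snapshot = copy.deepcopy(model_cfg["adapters"])  # independent deep copy

model_cfg["adapters"]["adapter_a"]["dim"] = 32  # update the model-level config

print(shared["adapter_a"]["dim"])  # 32: the change is visible through the reference
print(snapshot["adapter_a"]["dim"])  # 5: the deep copy is unaffected
```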
- replace_adapter_compatible_modules(update_config: bool = True, verbose: bool = True)
Utility method to replace all child modules with Adapter variants, if they exist. Does NOT recurse through children of children modules (only immediate children).
- Parameters
update_config – A flag that determines if the config should be updated or not.
verbose – A flag that determines if the method should log the changes made or not.
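The one-level-deep replacement can be sketched with toy classes. `Linear`, `LinearWithAdapters`, `Block`, and the `ADAPTER_VARIANTS` registry are hypothetical; the sketch only illustrates that immediate children are swapped while grandchildren are left alone, and it does not copy any weights.

```python
from typing import Dict, List, Type


class Linear:
    """Placeholder for a regular layer."""


class LinearWithAdapters(Linear):
    """Hypothetical adapter-capable variant of Linear."""


class Block:
    """A container whose own child is a grandchild of the top level."""

    def __init__(self) -> None:
        self.children: Dict[str, object] = {"proj": Linear()}


# Hypothetical registry: base class -> adapter-compatible variant
ADAPTER_VARIANTS: Dict[Type, Type] = {Linear: LinearWithAdapters}


def replace_adapter_compatible_modules(children: Dict[str, object], verbose: bool = True) -> List[str]:
    replaced = []
    for name, child in list(children.items()):  # immediate children only, no recursion
        variant = ADAPTER_VARIANTS.get(type(child))
        if variant is not None:
            children[name] = variant()  # this sketch does not carry over any weights
            replaced.append(name)
            if verbose:
                print(f"Replaced {name}: {type(child).__name__} -> {variant.__name__}")
    return replaced


top_level = {"proj": Linear(), "block": Block()}
replace_adapter_compatible_modules(top_level)  # prints: Replaced proj: Linear -> LinearWithAdapters
print(type(top_level["proj"]).__name__)  # LinearWithAdapters
print(type(top_level["block"].children["proj"]).__name__)  # Linear
```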
- property adapter_module_names: List[str]
List of valid adapter modules that are supported by the model.
Note
Subclasses should override this property and return a list of str names of all the modules that they support, which will enable users to determine where to place the adapter modules.
- Returns
A list of str, one for each of the adapter modules that are supported. By default, the subclass should support the “default adapter” (‘’).
- property default_adapter_module_name: Optional[str]
Name of the adapter module that is used as “default” if a name of ‘’ is provided.
Note
Subclasses should override this property and return a str name of the module that they wish to denote as the default.
- Returns
A str name of a module, which is denoted as ‘default’ adapter or None. If None, then no default adapter is supported.
Below, we will discuss some useful functionality of Adapter Compatible Models.
- Save and restore a Model with adapter capability: Any NeMo model that implements this class correctly can save and restore NeMo models with adapter capabilities, thereby allowing adapters to be shared.
- save_adapters and load_adapters: Adapters usually have a very small number of parameters, so there is no need to duplicate the entire model for each adapter. These methods allow storing just the adapter module(s) separately from the model, so that you can use the same “base” model and share just the Adapter modules.
References
- [adapters1] Neil Houlsby, Andrei Giurgiu, Stanislaw Jastrzebski, Bruna Morrone, Quentin De Laroussilhe, Andrea Gesmundo, Mona Attariyan, and Sylvain Gelly. Parameter-efficient transfer learning for NLP. In International Conference on Machine Learning, 2790–2799. PMLR, 2019.