Important
NeMo 2.0 is an experimental feature and is currently released only in the dev container: nvcr.io/nvidia/nemo:dev. Please refer to the NeMo 2.0 overview for information on getting started.
Adapters API
Core
- class nemo.core.adapter_mixins.AdapterModuleMixin
Bases:
abc.ABC
Generic Adapter Mixin that can augment any torch.nn.Module with Adapter module support.
This mixin class adds a hierarchical way to add any type of Adapter modules to a pre-existing module. Since Models are inherently also nn.Module, this mixin can be attached to any Model or Module. This mixin class adds several utility methods which are utilized or overridden as necessary.
An Adapter module is any PyTorch nn.Module that possesses a few properties:
- Its input and output dimensions are the same, while the hidden dimension need not be the same.
- The final layer of the Adapter module is zero-initialized, so that the residual connection to the adapter yields the original output.
This mixin adds the following instance variables to the class that inherits it:
- adapter_layer: A torch.nn.ModuleDict(), whose keys are the names of the adapters (globally unique), and whose values are the Adapter nn.Module().
- adapter_cfg: An OmegaConf DictConfig object that holds the config of the adapters that are initialized.
- adapter_name: A str resolved name which is a globally unique key, though more than one module may share this name.
- adapter_global_cfg_key: A str representing a key in the model config that can be provided by the user. The value resolves to global_cfg, and can be overridden via model.cfg.adapters.global_cfg.*.
- adapter_metadata_cfg_key: A str representing a key in the model config that is used to preserve the metadata of the adapter config.
Note
This module is not responsible for maintaining its config. Subclasses must ensure the config is updated or preserved as needed. It is the responsibility of the subclasses to propagate the most up-to-date config to lower layers.
- adapter_global_cfg_key = 'global_cfg'
- adapter_metadata_cfg_key = 'adapter_meta_cfg'
- add_adapter(name: str, cfg: Union[omegaconf.DictConfig, nemo.core.classes.mixins.adapter_mixins.AdapterConfig], **kwargs)
Add an Adapter module to this module.
- Parameters
name – A globally unique name for the adapter. Will be used to access, enable and disable adapters.
cfg – A DictConfig or Dataclass that contains at the bare minimum a _target_ field to instantiate a new Adapter module.
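A minimal sketch of adding an adapter to a module that implements this mixin. The module instance, adapter name, and dimensions are illustrative assumptions; the target is the LinearAdapter documented later on this page:

    from omegaconf import OmegaConf

    # 'module' is any nn.Module that implements AdapterModuleMixin (hypothetical).
    cfg = OmegaConf.create({
        '_target_': 'nemo.collections.common.parts.adapter_modules.LinearAdapter',
        'in_features': 512,  # must equal the module's output dimension
        'dim': 32,           # bottleneck dimension of the adapter
    })
    module.add_adapter(name='adapter_1', cfg=cfg)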
- is_adapter_available() bool
Checks if any Adapter module has been instantiated.
- Returns
bool, determining if any Adapter module has been instantiated. Returns True whether the adapters are enabled or disabled; returns False only if no adapters exist.
- set_enabled_adapters(name: Optional[str] = None, enabled: bool = True)
Updates the internal adapter config, determining if an adapter (or all adapters) are either enabled or disabled.
A common user pattern would be to disable all adapters (either after adding them, or restoring a model with pre-existing adapters) and then simply enable one of the adapters.
    module.set_enabled_adapters(enabled=False)
    module.set_enabled_adapters(name=<some adapter name>, enabled=True)
- Parameters
name – Optional str. If a str name is given, the config will be updated to the value of enabled. If no name is given, then all adapters will be enabled/disabled.
enabled – Bool, determines if the adapter(s) will be enabled/disabled.
- get_enabled_adapters() List[str]
Returns a list of all enabled adapter names. The names will always be the resolved names, without module info.
- Returns
A list of str names of each enabled adapter.
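Taken together with set_enabled_adapters, a common sketch (the adapter name is an illustrative assumption):

    # Disable all adapters, then enable just the one we want to use.
    module.set_enabled_adapters(enabled=False)
    module.set_enabled_adapters(name='adapter_1', enabled=True)

    print(module.get_enabled_adapters())  # ['adapter_1']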
- get_adapter_module(name: str)
Gets an adapter module by name if possible, otherwise returns None.
- Parameters
name – A str name (resolved or not) corresponding to an Adapter.
- Returns
An nn.Module if the name could be resolved and matched, otherwise None.
- get_adapter_cfg(name: str)
Same logic as get_adapter_module, but returns the adapter config instead.
- set_accepted_adapter_types(adapter_types: List[Union[type, str]]) None
The module with this mixin can define a list of adapter names that it will accept. This method should be called in the module's init method and set the adapter names the module will expect to be added.
- Parameters
adapter_types – A list of str paths that correspond to classes. The class paths will be instantiated to ensure that the class path is correct.
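For example, a module could restrict itself to linear adapters by calling this inside its __init__ method; a sketch using the LinearAdapter class documented later on this page:

    # Inside the module's __init__: accept only LinearAdapter (and subclasses);
    # other adapter types will be rejected when add_adapter() is called.
    self.set_accepted_adapter_types(
        ['nemo.collections.common.parts.adapter_modules.LinearAdapter']
    )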
- get_accepted_adapter_types() Set[type]
Utility function to get the set of all classes that are accepted by the module.
- Returns
Returns the set of accepted adapter types as classes, or an empty set if no types were set.
- unfreeze_enabled_adapters(freeze_batchnorm: bool = True) None
Utility method to unfreeze only the enabled Adapter module(s).
A common user pattern is to freeze all the modules (including all the adapters), and then unfreeze just the required adapters.
    module.freeze()  # only available to nemo.core.NeuralModule!
    module.unfreeze_enabled_adapters()
- Parameters
freeze_batchnorm – An optional (and recommended) practice of freezing the updates to the moving average buffers of any and all BatchNorm*D layers. This is necessary to ensure that disabling all adapters will precisely yield the original (base) model’s outputs.
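A typical parameter-efficient fine-tuning setup, sketched below (assumes module is a nemo.core.NeuralModule, so that freeze() is available):

    module.freeze()                     # freeze all base weights
    module.unfreeze_enabled_adapters()  # unfreeze only the enabled adapters

    # Only the adapter parameters should now require gradients.
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    print(f'Trainable parameters: {trainable}')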
- forward_enabled_adapters(input: torch.Tensor)
Forwards all active adapters one by one with the provided input, chaining the output of each adapter layer to the next.
Utilizes the implicit merge strategy of each adapter when computing the adapter’s output, and how that output will be merged back with the original input.
- Parameters
input – The output tensor of the calling module is the input to the first adapter, whose output is then chained to the next adapter until all adapters are consumed.
- Returns
The result tensor, after all active adapters have finished their forward passes.
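A module that inherits this mixin would typically call this method at the end of its own forward pass. A minimal sketch, where MyModule and self.layers are hypothetical:

    import torch
    from nemo.core.adapter_mixins import AdapterModuleMixin

    class MyModule(torch.nn.Module, AdapterModuleMixin):  # hypothetical
        def forward(self, x):
            output = self.layers(x)  # the module's ordinary computation
            if self.is_adapter_available():
                output = self.forward_enabled_adapters(output)
            return output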
- resolve_adapter_module_name_(name: str) Tuple[str, str]
Utility method to resolve a given global/module adapter name to its components. Always returns a tuple representing (module_name, adapter_name). “:” is used as the delimiter for denoting the module name vs the adapter name.
Will attempt to also resolve a given adapter_name alone back to (module_name, adapter_name) if the metadata config exists for access.
- Parameters
name – A global adapter, or a module adapter name (with structure module_name:adapter_name).
- Returns
A tuple representing (module_name, adapter_name). If a global adapter is provided, module_name is set to ‘’.
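For illustration, assuming a module named encoder and an adapter named adapter_1:

    module.resolve_adapter_module_name_('encoder:adapter_1')  # -> ('encoder', 'adapter_1')
    module.resolve_adapter_module_name_('adapter_1')          # -> ('', 'adapter_1'),
    # or ('encoder', 'adapter_1') if the metadata config can resolve the module name.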
- forward_single_enabled_adapter_(input: torch.Tensor, adapter_module: torch.nn.Module, *, adapter_name: str, adapter_strategy: nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy)
Perform the forward step of a single adapter module on some input data.
Note
Subclasses can override this method to accommodate more complicated adapter forward steps.
- Parameters
input – The output tensor of the calling module is the input to the first adapter, whose output is then chained to the next adapter until all adapters are consumed.
adapter_module – The adapter module that is currently required to perform the forward pass.
adapter_name – The resolved name of the adapter that is undergoing the current forward pass.
adapter_strategy – A subclass of AbstractAdapterStrategy, that determines how the output of the adapter should be merged with the input, or if it should be merged at all.
- Returns
The result tensor, after the current active adapter has finished its forward pass.
- check_supported_adapter_type_(adapter_cfg: omegaconf.DictConfig, supported_adapter_types: Optional[Iterable[type]] = None)
Utility method to check if the adapter module is of a type supported by the module.
This method should be called by the subclass to ensure that the adapter module is a supported type.
- class nemo.core.adapter_mixins.AdapterModelPTMixin
Bases:
nemo.core.classes.mixins.adapter_mixins.AdapterModuleMixin
Adapter Mixin that can augment a ModelPT subclass with Adapter support.
This mixin class should be used only with a top-level ModelPT subclass. This mixin class adds several utility methods which should be subclassed and overridden so that the calls are propagated to the submodules as necessary.
An Adapter module is any PyTorch nn.Module that possesses a few properties:
- Its input and output dimensions are the same, while the hidden dimension need not be the same.
- The final layer of the Adapter module is zero-initialized, so that the residual connection to the adapter yields the original output.
This mixin adds the following instance variables to the class that inherits it:
- adapter_layer: A torch.nn.ModuleDict(), whose keys are the names of the adapters (globally unique), and whose values are the Adapter nn.Module().
- adapter_cfg: An OmegaConf DictConfig object that holds the config of the adapters that are initialized.
- adapter_global_cfg_key: A str representing a key in the model config that can be provided by the user. The value resolves to global_cfg, and can be overridden via model.cfg.adapters.global_cfg.*.
Note
This module is responsible for maintaining its config. At the ModelPT level, it will access and write Adapter config information to self.cfg.adapters.
- setup_adapters()
Utility method that is called in the constructor of the ASR ModelPT implementation, so as to restore any adapters that were previously added.
Should be overridden by the subclass for additional setup steps as required.
This method should be called just once at constructor time.
- add_adapter(name: str, cfg: Union[omegaconf.DictConfig, nemo.core.classes.mixins.adapter_mixins.AdapterConfig])
Add an Adapter module to this model.
Should be overridden by the subclass, and the super() call must be used - this will set up the config. After calling super(), forward this call to modules that implement the mixin.
- Parameters
name – A globally unique name for the adapter. Will be used to access, enable and disable adapters.
cfg – A DictConfig that contains at the bare minimum a _target_ field to instantiate a new Adapter module.
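A sketch of adding a module-scoped adapter at the model level; the model instance, the encoder module name, and the dimensions are illustrative assumptions:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        '_target_': 'nemo.collections.common.parts.adapter_modules.LinearAdapter',
        'in_features': 512,
        'dim': 32,
    })
    # 'encoder' must be one of model.adapter_module_names.
    model.add_adapter(name='encoder:adapter_1', cfg=cfg)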
- is_adapter_available() bool
Checks if any Adapter module has been instantiated.
Should be overridden by the subclass.
- Returns
bool, determining if any Adapter module has been instantiated. Returns True whether the adapters are enabled or disabled; returns False only if no adapters exist.
- set_enabled_adapters(name: Optional[str] = None, enabled: bool = True)
Updates the internal adapter config, determining if an adapter (or all adapters) are either enabled or disabled.
A common user pattern would be to disable all adapters (either after adding them, or restoring a model with pre-existing adapters) and then simply enable one of the adapters.
Should be overridden by the subclass, and the super() call must be used - this will set up the config. After calling super(), forward this call to modules that implement the mixin.
    model.set_enabled_adapters(enabled=False)
    model.set_enabled_adapters(name=<some adapter name>, enabled=True)
- Parameters
name – Optional str. If a str name is given, the config will be updated to the value of enabled. If no name is given, then all adapters will be enabled/disabled.
enabled – Bool, determines if the adapter(s) will be enabled/disabled.
- get_enabled_adapters() List[str]
Returns a list of all enabled adapters.
Should be implemented by the subclass.
- Returns
A list of str names of each enabled adapter.
- check_valid_model_with_adapter_support_()
Utility method to test if the subclass of this mixin is an appropriate subclass of ModelPT itself.
Should be implemented by the subclass.
- save_adapters(filepath: str, name: Optional[str] = None)
Utility method that saves only the adapter module(s), and not the entire model itself. This allows the sharing of adapters, which are often just a fraction of the size of the full model, enabling easier delivery.
Note
The saved file is a PyTorch-compatible pickle file, containing the state dicts of the adapter(s), as well as a binary representation of the adapter config.
- Parameters
filepath – A str filepath for the .pt file that will contain the adapter state dict.
name – Optional name of the adapter that will be saved to this file. If None is passed, all adapters will be saved to the file. The name can be either the global name (adapter_name), or the module level name (module:adapter_name).
- load_adapters(filepath: str, name: Optional[str] = None, map_location: Optional[str] = None, strict: bool = True)
Utility method that restores only the adapter module(s), and not the entire model itself. This allows the sharing of adapters, which are often just a fraction of the size of the full model, enabling easier delivery.
Note
During restoration, assumes that the model does not currently already have an adapter with the name (if provided), or any adapter that shares a name with the state dict’s modules (if name is not provided). This is to ensure that each adapter name is globally unique in a model.
- Parameters
filepath – Filepath of the .pt file.
name – Optional name of the adapter that will be restored from this file. If None is passed, all adapters in the file will be restored. The name must be either the global name (adapter_name), or the module level name (module:adapter_name), whichever exactly matches the state dict.
map_location – PyTorch flag, where to place the adapter(s) state dict(s).
strict – PyTorch flag, whether to load the weights of the adapter(s) strictly or not.
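A save/restore round trip, sketched with illustrative paths and names:

    # Save only the adapter weights (a small file, not the full model).
    model.save_adapters('adapter_1.pt', name='adapter_1')

    # Later, restore the adapter into a fresh instance of the same base model.
    restored_model.load_adapters('adapter_1.pt', map_location='cpu')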
- update_adapter_cfg(cfg: omegaconf.DictConfig)
Utility method to recursively update all of the Adapter module configs with the provided config.
Note
It is not a (deep)copy, but a reference copy. Changes made to the config will be reflected in the adapter submodules, but it is still encouraged to explicitly update the adapter_cfg using this method.
- Parameters
cfg – DictConfig containing the value of model.cfg.adapters.
- replace_adapter_compatible_modules(update_config: bool = True, verbose: bool = True)
Utility method to replace all child modules with Adapter variants, if they exist. Does NOT recurse through the children of child modules (only immediate children).
- Parameters
update_config – A flag that determines if the config should be updated or not.
verbose – A flag that determines if the method should log the changes made or not.
- property adapter_module_names: List[str]
List of valid adapter modules that are supported by the model.
Note
Subclasses should override this property and return a list of str names of all the modules that they support, which will enable users to determine where to place the adapter modules.
- Returns
A list of str, one for each of the adapter modules that are supported. By default, the subclass should support the “default adapter” (‘’).
- property default_adapter_module_name: Optional[str]
Name of the adapter module that is used as “default” if a name of ‘’ is provided.
Note
Subclasses should override this property and return a str name of the module that they wish to denote as the default.
- Returns
A str name of a module, which is denoted as ‘default’ adapter or None. If None, then no default adapter is supported.
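A subclass would typically override both properties together. A sketch for a hypothetical model with encoder and decoder adapter locations:

    from typing import List, Optional
    from nemo.core.adapter_mixins import AdapterModelPTMixin

    class MyAdapterModel(AdapterModelPTMixin):  # real models also subclass ModelPT
        @property
        def adapter_module_names(self) -> List[str]:
            # '' denotes the "default adapter" location.
            return ['', 'encoder', 'decoder']

        @property
        def default_adapter_module_name(self) -> Optional[str]:
            # A plain adapter name (no 'module:' prefix) resolves to the encoder.
            return 'encoder'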
Adapter Networks
- class nemo.collections.common.parts.adapter_modules.AdapterModuleUtil
Bases:
nemo.core.classes.mixins.access_mixins.AccessMixin
Base class of Adapter Modules, providing common functionality to all Adapter Modules.
- setup_adapter_strategy(adapter_strategy: Optional[nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy])
Sets up the adapter strategy of this class, enabling dynamic changes in the way the adapter output is merged with the input.
When called successfully, will assign the variable adapter_strategy to the module.
- Parameters
adapter_strategy – Can be a None or an implementation of AbstractAdapterStrategy.
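For example, to swap an adapter's merge behavior at runtime (the adapter name is an illustrative assumption; the strategy classes are documented below):

    from nemo.core.classes.mixins import adapter_mixin_strategies

    adapter = module.get_adapter_module('adapter_1')
    adapter.setup_adapter_strategy(
        adapter_mixin_strategies.ReturnResultAdapterStrategy()
    )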
- get_default_strategy_config() dataclasses.dataclass
Returns a default adapter module strategy.
- adapter_unfreeze()
Sets requires_grad to True for all parameters in the adapter. This method should be overridden for any custom unfreeze behavior that is required, for example if not all params of the adapter should be unfrozen.
- class nemo.collections.common.parts.adapter_modules.LinearAdapter(*args: Any, **kwargs: Any)
Bases:
torch.nn.Module, nemo.collections.common.parts.adapter_modules.AdapterModuleUtil
Simple Linear Feedforward Adapter module with LayerNorm and a single hidden layer with an activation function. Note: The adapter explicitly initializes its final layer with all zeros, in order to avoid affecting the original model when all adapters are disabled.
- Parameters
in_features – Input dimension of the module. Note that for adapters, input_dim == output_dim.
dim – Hidden dimension of the feed forward network.
activation – Str name for an activation function.
norm_position – Str, can be pre or post. Defaults to pre. Determines whether the normalization will occur in the first layer or the last layer. Certain architectures may prefer one over the other.
dropout – float value, the dropout probability applied to the output of the last layer of the adapter.
adapter_strategy – By default, ResidualAddAdapterStrategyConfig. An adapter composition function object.
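A quick check of the zero-initialization property described above, with illustrative dimensions:

    import torch
    from nemo.collections.common.parts.adapter_modules import LinearAdapter

    adapter = LinearAdapter(in_features=512, dim=32)
    x = torch.randn(4, 512)
    # The final layer is zero-initialized, so a freshly created adapter outputs
    # zeros, and a residual merge leaves the base model's output unchanged.
    print(adapter(x).abs().max())  # tensor(0.)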
Adapter Strategies
- class nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy
Bases:
abc.ABC
- forward(input: torch.Tensor, adapter: torch.nn.Module, *, module: AdapterModuleMixin)
Forward method that defines how the output of the adapter should be merged with the input, or if it should be merged at all.
Also provides the module that called this strategy - thereby allowing access to all other adapters in the calling module. This can be useful if one adapter is a meta adapter that combines the outputs of various adapters. In such a case, the input can be forwarded across all other adapters, collecting their outputs, and those outputs can then be merged via some strategy. For example, refer to:
AdapterFusion: Non-Destructive Task Composition for Transfer Learning (https://arxiv.org/abs/2005.00247)
Exploiting Adapters for Cross-lingual Low-resource Speech Recognition (https://arxiv.org/abs/2105.11905)
- Parameters
input – Original output tensor of the module, or the output of the previous adapter (if more than one adapter is enabled).
adapter – The adapter module that is currently required to perform the forward pass.
module – The calling module, in its entirety. It is a module that implements AdapterModuleMixin, therefore the strategy can access all other adapters in this module via module.adapter_layer.
- Returns
The result tensor, after one of the active adapters has finished its forward pass.
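A minimal custom strategy, sketched below, that merges via a scaled residual addition. The class and its scale parameter are illustrative assumptions, not part of NeMo:

    import torch
    from nemo.core.classes.mixins.adapter_mixin_strategies import AbstractAdapterStrategy

    class ScaledResidualStrategy(AbstractAdapterStrategy):  # hypothetical
        def __init__(self, scale: float = 0.5):
            super().__init__()
            self.scale = scale

        def forward(self, input: torch.Tensor, adapter: torch.nn.Module, *, module):
            # Scale the adapter's output before the residual merge with the input.
            return input + self.scale * adapter(input)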
- class nemo.core.classes.mixins.adapter_mixin_strategies.ReturnResultAdapterStrategy
Bases:
nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy
An implementation of an adapter strategy that simply returns the result of the adapter as the output.
- forward(input: torch.Tensor, adapter: torch.nn.Module, *, module: AdapterModuleMixin)
A basic strategy, which simply returns the result of the adapter’s calculation as the output.
- Parameters
input – Original output tensor of the module, or the output of the previous adapter (if more than one adapter is enabled).
adapter – The adapter module that is currently required to perform the forward pass.
module – The calling module, in its entirety. It is a module that implements AdapterModuleMixin, therefore the strategy can access all other adapters in this module via module.adapter_layer.
- Returns
The result tensor, after one of the active adapters has finished its forward pass.
- compute_output(input: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor], Dict[str, Any]], adapter: torch.nn.Module, *, module: AdapterModuleMixin) torch.Tensor
Compute the output of a single adapter on some input.
- Parameters
input – Original output tensor of the module, or the output of the previous adapter (if more than one adapter is enabled).
adapter – The adapter module that is currently required to perform the forward pass.
module – The calling module, in its entirety. It is a module that implements AdapterModuleMixin, therefore the strategy can access all other adapters in this module via module.adapter_layer.
- Returns
The result tensor, after one of the active adapters has finished its forward pass.
- class nemo.core.classes.mixins.adapter_mixin_strategies.ResidualAddAdapterStrategy(stochastic_depth: float = 0.0, l2_lambda: float = 0.0)
Bases:
nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy
An implementation of residual addition of an adapter module with its input. Supports stochastic depth regularization.
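Both regularizations are configured through the constructor; a sketch attaching such a strategy to an existing adapter module (the adapter instance is an illustrative assumption):

    from nemo.core.classes.mixins.adapter_mixin_strategies import ResidualAddAdapterStrategy

    # Drop the adapter's contribution 10% of the time during training, and
    # register an auxiliary L2 loss between the adapter's input and output.
    strategy = ResidualAddAdapterStrategy(stochastic_depth=0.1, l2_lambda=1e-4)
    adapter.setup_adapter_strategy(strategy)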
- forward(input: torch.Tensor, adapter: torch.nn.Module, *, module: AdapterModuleMixin)
A basic strategy, comprising a residual connection over the input after a forward pass through the underlying adapter.
- Parameters
input – Original output tensor of the module, or the output of the previous adapter (if more than one adapter is enabled).
adapter – The adapter module that is currently required to perform the forward pass.
module – The calling module, in its entirety. It is a module that implements AdapterModuleMixin, therefore the strategy can access all other adapters in this module via module.adapter_layer.
- Returns
The result tensor, after one of the active adapters has finished its forward pass.
- compute_output(input: torch.Tensor, adapter: torch.nn.Module, *, module: AdapterModuleMixin) torch.Tensor
Compute the output of a single adapter on some input.
- Parameters
input – Original output tensor of the module, or the output of the previous adapter (if more than one adapter is enabled).
adapter – The adapter module that is currently required to perform the forward pass.
module – The calling module, in its entirety. It is a module that implements AdapterModuleMixin, therefore the strategy can access all other adapters in this module via module.adapter_layer.
- Returns
The result tensor, after one of the active adapters has finished its forward pass.
- apply_stochastic_depth(output: torch.Tensor, input: torch.Tensor, adapter: torch.nn.Module, *, module: AdapterModuleMixin)
Compute and apply stochastic depth if probability is greater than 0.
- Parameters
output – The result tensor, after one of the active adapters has finished its forward pass.
input – Original output tensor of the module, or the output of the previous adapter (if more than one adapter is enabled).
adapter – The adapter module that is currently required to perform the forward pass.
module – The calling module, in its entirety. It is a module that implements AdapterModuleMixin, therefore the strategy can access all other adapters in this module via module.adapter_layer.
- Returns
The result tensor, after stochastic depth has been potentially applied to it.
- compute_auxiliary_losses(output: torch.Tensor, input: torch.Tensor, adapter: torch.nn.Module, *, module: AdapterModuleMixin)
Compute any auxiliary losses and preserve them in the tensor registry.
- Parameters
output – The result tensor, after one of the active adapters has finished its forward pass.
input – Original output tensor of the module, or the output of the previous adapter (if more than one adapter is enabled).
adapter – The adapter module that is currently required to perform the forward pass.
module – The calling module, in its entirety. It is a module that implements AdapterModuleMixin, therefore the strategy can access all other adapters in this module via module.adapter_layer.