Adapters API#

Core#

class nemo.core.adapter_mixins.AdapterModuleMixin#

Bases: abc.ABC

Generic Adapter Mixin that can augment any torch.nn.Module with Adapter module support.

This mixin class adds a hierarchical way to add any type of Adapter module to a pre-existing module. Since Models are inherently also nn.Modules, this mixin can be attached to any Model or Module. It also adds several utility methods, which are used or overridden as necessary.

An Adapter module is any PyTorch nn.Module that possesses a few properties:

  • Its input and output dimensions are the same, while the hidden dimension need not be the same.

  • The final layer of the Adapter module is zero-initialized, so that the residual connection to the adapter yields the original output.
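
As a minimal sketch of these two properties, the toy adapter below uses plain Python lists in place of tensors. The class name ToyAdapter and its helpers are hypothetical and are not part of NeMo; the point is only that a zero-initialized final layer makes the residual sum equal to the input.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(weight: Matrix, x: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a vector of length cols."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


class ToyAdapter:
    """Hypothetical stand-in for an adapter: in_features -> dim -> in_features."""

    def __init__(self, in_features: int, dim: int) -> None:
        # Down-projection to the hidden dimension (arbitrary non-zero values).
        self.down: Matrix = [[0.1] * in_features for _ in range(dim)]
        # Final (up-projection) layer is zero-initialized.
        self.up: Matrix = [[0.0] * dim for _ in range(in_features)]

    def __call__(self, x: Vector) -> Vector:
        hidden = [max(0.0, h) for h in matvec(self.down, x)]  # ReLU
        return matvec(self.up, hidden)


def residual_add(x: Vector, adapter: ToyAdapter) -> Vector:
    """Residual connection: input plus adapter output."""
    return [a + b for a, b in zip(x, adapter(x))]


x = [1.0, -2.0, 3.0, 0.5]
adapter = ToyAdapter(in_features=4, dim=2)
y = residual_add(x, adapter)  # equals x while the final layer is all zeros
```

Because the input and output dimensions match, the adapter output can be added element-wise to the input regardless of the hidden dimension chosen.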

This mixin adds the following instance variables to the class that inherits it:

  • adapter_layer: A torch.nn.ModuleDict(), whose keys are the names of the adapters (globally unique), and whose values are the Adapter nn.Module().

  • adapter_cfg: An OmegaConf DictConfig object that holds the config of the adapters that are initialized.

  • adapter_name: A str resolved name that is a globally unique key, although more than one module may share this name.

  • adapter_global_cfg_key: A str representing a key in the model config that can be provided by the user. The value resolves to global_cfg, and can be overridden via model.cfg.adapters.global_cfg.*.

  • adapter_metadata_cfg_key: A str representing a key in the model config that is used to preserve the metadata of the adapter config.

Note: This module is not responsible for maintaining its config. Subclasses must ensure the config is updated or preserved as needed. It is the responsibility of the subclasses to propagate the most up-to-date config to lower layers.

adapter_global_cfg_key = 'global_cfg'#
adapter_metadata_cfg_key = 'adapter_meta_cfg'#
add_adapter(name: str, cfg: omegaconf.DictConfig)#

Add an Adapter module to this module.

Parameters
  • name – A globally unique name for the adapter. Will be used to access, enable and disable adapters.

  • cfg – A DictConfig or Dataclass that contains, at a bare minimum, the _target_ key needed to instantiate a new Adapter module.

is_adapter_available() bool#

Checks if any Adapter module has been instantiated.

Returns

bool, determining if any Adapter module has been instantiated. Returns True whether the adapters are enabled or disabled, and False only if no adapters exist.

set_enabled_adapters(name: Optional[str] = None, enabled: bool = True)#

Updates the internal adapter config, determining whether an adapter (or all adapters) is enabled or disabled.

A common user pattern would be to disable all adapters (either after adding them, or restoring a model with pre-existing adapters) and then simply enable one of the adapters.

module.set_enabled_adapters(enabled=False)
module.set_enabled_adapters(name=<some adapter name>, enabled=True)
Parameters
  • name – Optional str. If a str name is given, the config will be updated to the value of enabled. If no name is given, then all adapters will be enabled/disabled.

  • enabled – Bool, determines if the adapter(s) will be enabled/disabled.
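
The effect of this call on the adapter config can be sketched with a plain dictionary. This assumes each adapter's config entry carries an enabled flag; the helper functions below are hypothetical stand-ins, not the NeMo implementation.

```python
from typing import Dict, List, Optional

# Hypothetical adapter config: adapter name -> per-adapter settings.
adapter_cfg: Dict[str, Dict[str, object]] = {
    "adapter_a": {"enabled": True},
    "adapter_b": {"enabled": True},
}


def set_enabled(cfg: Dict[str, Dict[str, object]],
                name: Optional[str] = None,
                enabled: bool = True) -> None:
    """Toggle one adapter by name, or every adapter when name is None."""
    if name is None:
        for entry in cfg.values():
            entry["enabled"] = enabled
    else:
        if name not in cfg:
            raise KeyError(f"Unknown adapter: {name}")
        cfg[name]["enabled"] = enabled


def get_enabled(cfg: Dict[str, Dict[str, object]]) -> List[str]:
    """Return the names of all adapters whose flag is set."""
    return [name for name, entry in cfg.items() if entry["enabled"]]


set_enabled(adapter_cfg, enabled=False)                   # disable everything
set_enabled(adapter_cfg, name="adapter_b", enabled=True)  # enable just one
enabled_now = get_enabled(adapter_cfg)
```

Disabling everything first and then enabling a single name leaves exactly one adapter active, which is the usage pattern described above.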

get_enabled_adapters() List[str]#

Returns a list of all enabled adapter names. The names will always be the resolved names, without module info.

Returns

A list of the str names of each enabled adapter.

unfreeze_enabled_adapters(freeze_batchnorm: bool = True) None#

Utility method to unfreeze only the enabled Adapter module(s).

A common user pattern is to freeze all the modules (including all the adapters), and then unfreeze just the required adapters.

module.freeze()  # only available to nemo.core.NeuralModule !
module.unfreeze_enabled_adapters()
Parameters

freeze_batchnorm – An optional (and recommended) practice of freezing the updates to the moving average buffers of any and all BatchNorm*D layers. This is necessary to ensure that disabling all adapters will precisely yield the original (base) model’s outputs.

forward_enabled_adapters(input: torch.Tensor)#

Forwards the provided input through all active adapters one by one, chaining the output of each adapter layer to the next.

Uses the implicit merge strategy of each adapter to compute the adapter’s output and to determine how that output is merged back with the original input.

Parameters

input – The output tensor of the calling module is the input to the first adapter, whose output is then chained to the next adapter until all adapters are consumed.

Returns

The result tensor, after all active adapters have finished their forward passes.
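
The chaining described here can be sketched as a simple loop. Scalars stand in for tensors, and forward_enabled, the adapter table, and the residual merge are all illustrative assumptions rather than NeMo code.

```python
from typing import Callable, Dict, Tuple

Adapter = Callable[[float], float]

# Hypothetical registry: name -> (enabled flag, adapter function).
adapters: Dict[str, Tuple[bool, Adapter]] = {
    "double": (True, lambda x: 2.0 * x),
    "skipped": (False, lambda x: 100.0),  # disabled, so never applied
    "plus_one": (True, lambda x: 1.0),
}


def residual_merge(x: float, adapter: Adapter) -> float:
    """Assumed merge strategy: add the adapter output back onto its input."""
    return x + adapter(x)


def forward_enabled(x: float) -> float:
    """Feed x through each enabled adapter, chaining outputs in order."""
    for enabled, adapter in adapters.values():
        if enabled:
            x = residual_merge(x, adapter)
    return x


result = forward_enabled(3.0)  # 3.0 -> 9.0 (double) -> 10.0 (plus_one)
```

Each adapter sees the merged output of the previous one rather than the original input, so the order of the enabled adapters matters.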

resolve_adapter_module_name_(name: str) Tuple[str, str]#

Utility method to resolve a given global/module adapter name to its components. Always returns a tuple representing (module_name, adapter_name). “:” is used as the delimiter for denoting the module name vs the adapter name.

Will attempt to also resolve a given adapter_name alone back to (module_name, adapter_name) if the metadata config exists for access.

Parameters

name – A global adapter, or a module adapter name (with structure module_name:adapter_name).

Returns

A tuple representing (module_name, adapter_name). If a global adapter is provided, module_name is set to ‘’.
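
A minimal sketch of the delimiter handling follows. The function name split_adapter_name and the module name "encoder" are hypothetical, and the metadata-based lookup for bare adapter names is deliberately left out.

```python
from typing import Tuple


def split_adapter_name(name: str) -> Tuple[str, str]:
    """Split 'module_name:adapter_name' into its two parts.

    A name without ':' is treated as a global adapter, so the module
    name is the empty string.
    """
    if ":" in name:
        module_name, adapter_name = name.split(":", 1)
        return module_name, adapter_name
    return "", name


module_scoped = split_adapter_name("encoder:adapter_1")
global_scoped = split_adapter_name("adapter_1")
```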

forward_single_enabled_adapter_(input: torch.Tensor, adapter_module: torch.nn.Module, *, adapter_name: str, adapter_strategy: nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy)#

Perform the forward step of a single adapter module on some input data.

Note: Subclasses can override this method to accommodate more complicated adapter forward steps.

Parameters
  • input – The output tensor of the calling module is the input to the first adapter, whose output is then chained to the next adapter until all adapters are consumed.

  • adapter_module – The adapter module that is currently required to perform the forward pass.

  • adapter_name – The resolved name of the adapter that is undergoing the current forward pass.

  • adapter_strategy – A subclass of AbstractAdapterStrategy, that determines how the output of the adapter should be merged with the input, or if it should be merged at all.

Returns

The result tensor, after the current active adapter has finished its forward pass.


class nemo.core.adapter_mixins.AdapterModelPTMixin#

Bases: nemo.core.classes.mixins.adapter_mixins.AdapterModuleMixin

Adapter Mixin that can augment a ModelPT subclass with Adapter support.

This mixin class should be used only with a top-level ModelPT subclass. It adds several utility methods that subclasses should override so that calls are propagated to the submodules as necessary.

An Adapter module is any PyTorch nn.Module that possesses a few properties:

  • Its input and output dimensions are the same, while the hidden dimension need not be the same.

  • The final layer of the Adapter module is zero-initialized, so that the residual connection to the adapter yields the original output.

This mixin adds the following instance variables to the class that inherits it:

  • adapter_layer: A torch.nn.ModuleDict(), whose keys are the names of the adapters (globally unique), and whose values are the Adapter nn.Module().

  • adapter_cfg: An OmegaConf DictConfig object that holds the config of the adapters that are initialized.

  • adapter_global_cfg_key: A str representing a key in the model config that can be provided by the user. The value resolves to global_cfg, and can be overridden via model.cfg.adapters.global_cfg.*.

Note

This module is responsible for maintaining its config. At the ModelPT level, it will access and write Adapter config information to self.cfg.adapters.

setup_adapters()#

Utility method that is called in the ASR ModelPT-implementation constructor, so as to restore any adapters that were previously added.

Should be overridden by the subclass for additional setup steps as required.

This method should be called just once at constructor time.

add_adapter(name: str, cfg: omegaconf.DictConfig)#

Add an Adapter module to this model.

Should be overridden by the subclass, and a super() call must be used, since this will set up the config. After calling super(), forward this call to modules that implement the mixin.

Parameters
  • name – A globally unique name for the adapter. Will be used to access, enable and disable adapters.

  • cfg – A DictConfig that contains, at a bare minimum, the _target_ key needed to instantiate a new Adapter module.

is_adapter_available() bool#

Checks if any Adapter module has been instantiated.

Should be overridden by the subclass.

Returns

bool, determining if any Adapter module has been instantiated. Returns True whether the adapters are enabled or disabled, and False only if no adapters exist.

set_enabled_adapters(name: Optional[str] = None, enabled: bool = True)#

Updates the internal adapter config, determining whether an adapter (or all adapters) is enabled or disabled.

A common user pattern would be to disable all adapters (either after adding them, or restoring a model with pre-existing adapters) and then simply enable one of the adapters.

Should be overridden by the subclass, and a super() call must be used, since this will set up the config. After calling super(), forward this call to modules that implement the mixin.

model.set_enabled_adapters(enabled=False)
model.set_enabled_adapters(name=<some adapter name>, enabled=True)
Parameters
  • name – Optional str. If a str name is given, the config will be updated to the value of enabled. If no name is given, then all adapters will be enabled/disabled.

  • enabled – Bool, determines if the adapter(s) will be enabled/disabled.

get_enabled_adapters() List[str]#

Returns a list of all enabled adapters.

Should be implemented by the subclass.

Returns

A list of the str names of each enabled adapter.

check_valid_model_with_adapter_support_()#

Utility method to test if the subclass of this mixin is an appropriate subclass of ModelPT itself.

Should be implemented by the subclass.

save_adapters(filepath: str, name: Optional[str] = None)#

Utility method that saves only the adapter module(s), and not the entire model itself. This allows the sharing of adapters, which are often just a fraction of the size of the full model, enabling easier delivery.

Note: The saved file is a PyTorch-compatible pickle file, containing the state dicts of the adapter(s), as well as a binary representation of the adapter config.

Parameters
  • filepath – A str filepath where the .pt file containing the adapter state dict will be written.

  • name – Optional name of the adapter that will be saved to this file. If None is passed, all adapters will be saved to the file. The name can be either the global name (adapter_name), or the module level name (module:adapter_name).

load_adapters(filepath: str, name: Optional[str] = None, map_location: Optional[str] = None, strict: bool = True)#

Utility method that restores only the adapter module(s), and not the entire model itself. This allows the sharing of adapters, which are often just a fraction of the size of the full model, enabling easier delivery.

Note: During restoration, assumes that the model does not already have an adapter with the given name (if provided), or any adapter that shares a name with the state dict’s modules (if name is not provided). This is to ensure that each adapter name is globally unique in a model.

Parameters
  • filepath – Filepath of the .pt file.

  • name – Optional name of the adapter that will be loaded from this file. If None is passed, all adapters will be loaded from the file. The name must be either the global name (adapter_name), or the module level name (module:adapter_name), whichever exactly matches the state dict.

  • map_location – PyTorch flag, where to place the adapter(s) state dict(s).

  • strict – PyTorch flag, whether to load the weights of the adapter(s) strictly or not.
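
The save and restore flow can be sketched with the standard library alone. Everything here is a simplified assumption: a flat dict stands in for the model's state dict, pickle stands in for the PyTorch serializer, and the helper names and key layout are hypothetical.

```python
import os
import pickle
import tempfile
from typing import Dict, List, Optional

State = Dict[str, List[float]]

# Hypothetical model state: only keys under "adapter_layer" belong to adapters.
model_state: State = {
    "encoder.weight": [0.5, 0.5],
    "adapter_layer.adapter_1.weight": [0.0, 0.0],
    "adapter_layer.adapter_2.weight": [0.1, 0.2],
}


def toy_save_adapters(state: State, filepath: str, name: Optional[str] = None) -> None:
    """Write only adapter entries (optionally a single adapter) to filepath."""
    prefix = "adapter_layer." + (name + "." if name else "")
    subset = {k: v for k, v in state.items() if k.startswith(prefix)}
    with open(filepath, "wb") as f:
        pickle.dump(subset, f)


def toy_load_adapters(state: State, filepath: str) -> None:
    """Merge saved adapter entries, refusing to overwrite an existing adapter."""
    with open(filepath, "rb") as f:
        subset: State = pickle.load(f)
    clashes = [k for k in subset if k in state]
    if clashes:
        raise ValueError(f"Adapter entries already exist: {clashes}")
    state.update(subset)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "adapters.pt")
    toy_save_adapters(model_state, path, name="adapter_2")
    fresh_state: State = {"encoder.weight": [0.5, 0.5]}
    toy_load_adapters(fresh_state, path)
```

The name clash check mirrors the restoration note above: loading is refused when an adapter entry with the same name is already present, which keeps adapter names unique within the model.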

update_adapter_cfg(cfg: omegaconf.DictConfig)#

Utility method to recursively update all of the Adapter module configs with the provided config.

Note

It is not a (deep)copy, but a reference copy. Changes made to the config will be reflected to adapter submodules, but it is still encouraged to explicitly update the adapter_cfg using this method.
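
The difference between a reference copy and a deep copy can be illustrated with plain dictionaries. The ToySubmodule class is a hypothetical stand-in for a module that implements the mixin, and a dict stands in for the DictConfig.

```python
import copy
from typing import Dict

Cfg = Dict[str, Dict[str, bool]]


class ToySubmodule:
    """Hypothetical submodule that keeps a reference to the shared config."""

    def __init__(self) -> None:
        self.adapter_cfg: Cfg = {}

    def update_adapter_cfg(self, cfg: Cfg) -> None:
        self.adapter_cfg = cfg  # reference copy, not copy.deepcopy(cfg)


model_cfg: Cfg = {"adapter_1": {"enabled": True}}
sub = ToySubmodule()
sub.update_adapter_cfg(model_cfg)

# A later change to the model-level config is visible in the submodule.
model_cfg["adapter_1"]["enabled"] = False
shared_value = sub.adapter_cfg["adapter_1"]["enabled"]

# A deep copy, by contrast, does not see later changes.
snapshot = copy.deepcopy(model_cfg)
model_cfg["adapter_1"]["enabled"] = True
snapshot_value = snapshot["adapter_1"]["enabled"]
```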

Parameters

cfg – DictConfig containing the value of model.cfg.adapters.

property adapter_module_names: List[str]#

List of valid adapter modules that are supported by the model.

Note: Subclasses should override this property and return a list of str names of all the modules that they support, which will enable users to determine where to place the adapter modules.

Returns

A list of str, one for each of the adapter modules that are supported. By default, the subclass should support the “global adapter” (‘’).


Adapter Networks#

class nemo.collections.common.parts.adapter_modules.AbstractAdapterModule(*args: Any, **kwargs: Any)[source]#

Bases: torch.nn.Module, nemo.core.classes.mixins.access_mixins.AccessMixin

Base class of Adapter Modules, providing common functionality to all Adapter Modules.

setup_adapter_strategy(adapter_strategy: Optional[nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy])[source]#

Setup adapter strategy of this class, enabling dynamic change in the way the adapter output is merged with the input.

When called successfully, will assign the variable adapter_strategy to the module.

Parameters

adapter_strategy – Can be a None or an implementation of AbstractAdapterStrategy.


class nemo.collections.common.parts.adapter_modules.LinearAdapter(*args: Any, **kwargs: Any)[source]#

Bases: nemo.collections.common.parts.adapter_modules.AbstractAdapterModule

Simple Linear Feedforward Adapter module with LayerNorm and a single hidden layer with an activation function. Note: The adapter explicitly initializes its final layer with all zeros in order to avoid affecting the original model when all adapters are disabled.

Parameters
  • in_features – Input dimension of the module. Note that for adapters, input_dim == output_dim.

  • dim – Hidden dimension of the feed forward network.

  • activation – Str name for an activation function.

  • norm_position – Str, can be pre or post. Defaults to post. Determines whether the normalization will occur in the first layer or the last layer. Certain architectures may prefer one over the other.

  • dropout – Float value, the dropout rate applied to the output of the last layer of the adapter.

  • adapter_strategy – By default, ResidualAddAdapterStrategyConfig. An adapter composition function object.
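
A config for this adapter might look like the sketch below, written as a plain dict so that it runs without NeMo installed. The _target_ key follows the Hydra instantiation convention; the numeric values, the activation name, and the adapter name in the comments are illustrative assumptions, not recommended defaults.

```python
from typing import Any, Dict

# Keys mirror the parameters listed above; values are illustrative only.
linear_adapter_cfg: Dict[str, Any] = {
    "_target_": "nemo.collections.common.parts.adapter_modules.LinearAdapter",
    "in_features": 256,       # must match the output size of the adapted module
    "dim": 32,                # hidden dimension of the feed forward network
    "activation": "relu",     # assumed to be an accepted activation name
    "norm_position": "post",  # "pre" or "post"
    "dropout": 0.0,
}

# With NeMo and OmegaConf available, the dict would typically be wrapped and
# passed to a model that implements the adapter mixin, for example:
#   cfg = OmegaConf.create(linear_adapter_cfg)
#   model.add_adapter(name="adapter_1", cfg=cfg)

# Keys this sketch expects to be present before the config is used.
expected_keys = {"_target_", "in_features", "dim"}
missing = expected_keys - linear_adapter_cfg.keys()
```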

reset_parameters()[source]#
forward(x)[source]#

Adapter Strategies#

class nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy[source]#

Bases: abc.ABC

forward(input: torch.Tensor, adapter: torch.nn.Module, *, module: AdapterModuleMixin)[source]#

Forward method that defines how the output of the adapter should be merged with the input, or if it should be merged at all.

Also provides the module that called this strategy - thereby allowing access to all other adapters in the calling module. This can be useful if one adapter is a meta adapter, that combines the outputs of various adapters. In such a case, the input can be forwarded across all other adapters, collecting their outputs, and those outputs can then be merged via some strategy. For example, refer to :

Parameters
  • input – Original output tensor of the module, or the output of the previous adapter (if more than one adapter is enabled).

  • adapter – The adapter module that is currently required to perform the forward pass.

  • module – The calling module, in its entirety. It is a module that implements AdapterModuleMixin, therefore the strategy can access all other adapters in this module via module.adapter_layer.

Returns

The result tensor, after one of the active adapters has finished its forward pass.


class nemo.core.classes.mixins.adapter_mixin_strategies.ResidualAddAdapterStrategy(stochastic_depth: float = 0.0, l2_lambda: float = 0.0)[source]#

Bases: nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy

An implementation of residual addition of an adapter module with its input. Supports stochastic depth regularization.

forward(input: torch.Tensor, adapter: torch.nn.Module, *, module: AdapterModuleMixin)[source]#

A basic strategy comprising a residual connection over the input, applied after the forward pass of the underlying adapter.

Parameters
  • input – Original output tensor of the module, or the output of the previous adapter (if more than one adapter is enabled).

  • adapter – The adapter module that is currently required to perform the forward pass.

  • module – The calling module, in its entirety. It is a module that implements AdapterModuleMixin, therefore the strategy can access all other adapters in this module via module.adapter_layer.

Returns

The result tensor, after one of the active adapters has finished its forward pass.
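
The residual merge with stochastic depth can be sketched on scalars as follows. This is a simplified assumption about the behavior, not the NeMo implementation: the class ToyResidualAddStrategy, its seed argument, and the training flag are hypothetical, and the l2_lambda term is not modeled.

```python
import random
from typing import Callable

Adapter = Callable[[float], float]


class ToyResidualAddStrategy:
    """Hypothetical scalar version of a residual-add merge strategy."""

    def __init__(self, stochastic_depth: float = 0.0, seed: int = 0) -> None:
        self.stochastic_depth = stochastic_depth
        self._rng = random.Random(seed)

    def forward(self, x: float, adapter: Adapter, *, training: bool = False) -> float:
        # Stochastic depth (assumed form): while training, skip the adapter
        # branch with probability `stochastic_depth` and return x unchanged.
        if training and self._rng.random() < self.stochastic_depth:
            return x
        return x + adapter(x)


def half(x: float) -> float:
    """Toy adapter that returns half of its input."""
    return 0.5 * x


always_keep = ToyResidualAddStrategy(stochastic_depth=0.0)
always_drop = ToyResidualAddStrategy(stochastic_depth=1.0)

kept = always_keep.forward(4.0, half, training=True)       # 4.0 + 2.0
dropped = always_drop.forward(4.0, half, training=True)    # branch skipped
inference = always_drop.forward(4.0, half, training=False) # never skipped
```

Skipping the branch only during training means the adapter is always applied at inference time, which is the usual intent of stochastic depth regularization.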