Source code for nemo.collections.common.parts.adapter_modules

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field, is_dataclass
from typing import Any, Optional

from hydra.utils import instantiate
from omegaconf import OmegaConf
from torch import nn as nn

from nemo.collections.common.parts.utils import activation_registry
from nemo.core.classes.mixins import access_mixins, adapter_mixin_strategies


class AdapterModuleUtil(access_mixins.AccessMixin):
    """
    Shared helper base for Adapter Modules.

    Provides the logic that resolves and attaches an adapter strategy to the module.
    """

    def setup_adapter_strategy(self, adapter_strategy: Optional[adapter_mixin_strategies.AbstractAdapterStrategy]):
        """
        Resolve `adapter_strategy` and store the result on the module as `self.adapter_strategy`.

        The strategy controls how the adapter output is merged with the input, and can be
        swapped dynamically by calling this method again.

        Args:
            adapter_strategy: One of
                - None, in which case `get_default_strategy_config()` supplies the value,
                - a dataclass config, which is converted to a non-struct OmegaConf config,
                - a dict or OmegaConf config, which is instantiated via Hydra,
                - an `AbstractAdapterStrategy` instance, which is stored as-is.

        Raises:
            AttributeError: If the (possibly converted) value is none of the supported kinds.
        """
        # Fall back to the module's default strategy config when nothing was provided.
        adapter_strategy = self.get_default_strategy_config() if adapter_strategy is None else adapter_strategy

        # Dataclass configs are turned into OmegaConf configs with struct mode switched off.
        if is_dataclass(adapter_strategy):
            adapter_strategy = OmegaConf.structured(adapter_strategy)
            OmegaConf.set_struct(adapter_strategy, False)

        # Config-like values are expected to carry a `_target_` field naming the strategy class,
        # which lets `instantiate` build the strategy object dynamically.
        is_config_like = isinstance(adapter_strategy, dict) or OmegaConf.is_config(adapter_strategy)
        if is_config_like:
            self.adapter_strategy = instantiate(adapter_strategy)
            return

        # An already-constructed strategy object is attached directly.
        if isinstance(adapter_strategy, adapter_mixin_strategies.AbstractAdapterStrategy):
            self.adapter_strategy = adapter_strategy
            return

        raise AttributeError(f'`adapter_strategy` provided is invalid : {adapter_strategy}')

    def get_default_strategy_config(self) -> 'dataclass':
        """
        Build the default adapter strategy config.

        Returns:
            A new `ResidualAddAdapterStrategyConfig` instance.
        """
        default_config = adapter_mixin_strategies.ResidualAddAdapterStrategyConfig()
        return default_config


class LinearAdapter(nn.Module, AdapterModuleUtil):
    """
    Simple Linear Feedforward Adapter module with LayerNorm and single hidden layer with activation function.

    Note: The adapter explicitly initializes its final layer with all zeros, so a freshly constructed
    adapter outputs zeros. This is intended to avoid affecting the original model when the adapter
    is first added.

    Args:
        in_features: Input dimension of the module. Note that for adapters, input_dim == output_dim.
        dim: Hidden dimension of the feed forward network.
        activation: Str name for an activation function (a key of `activation_registry`).
        norm_position: Str, can be `pre` or `post`. Defaults to `pre`. Determines whether the normalization
            will occur in the first layer or the last layer. Certain architectures may prefer one over the other.
        dropout: float value. Dropout is applied to the output of the last layer of the adapter
            only when this is greater than 0.
        adapter_strategy: By default, ResidualAddAdapterStrategyConfig. An adapter composition function object.

    Raises:
        AssertionError: If `norm_position` is neither `pre` nor `post`.
    """

    def __init__(
        self,
        in_features: int,
        dim: int,
        activation: str = 'swish',
        norm_position: str = 'pre',
        dropout: float = 0.0,
        adapter_strategy: Optional[adapter_mixin_strategies.ResidualAddAdapterStrategyConfig] = None,
    ):
        super().__init__()

        activation = activation_registry[activation]()
        # If the activation can be executed in place, do so.
        if hasattr(activation, 'inplace'):
            activation.inplace = True

        # Raised explicitly rather than via `assert` so the check still runs under `python -O`;
        # AssertionError is kept so callers observe the same exception type as before.
        if norm_position not in ('pre', 'post'):
            raise AssertionError(f"`norm_position` must be one of 'pre' or 'post', got: {norm_position!r}")
        self.norm_position = norm_position

        if norm_position == 'pre':
            # LayerNorm -> down projection -> activation -> up projection
            self.module = nn.Sequential(
                nn.LayerNorm(in_features),
                nn.Linear(in_features, dim, bias=False),
                activation,
                nn.Linear(dim, in_features, bias=False),
            )
        else:
            # down projection -> activation -> up projection -> LayerNorm
            self.module = nn.Sequential(
                nn.Linear(in_features, dim, bias=False),
                activation,
                nn.Linear(dim, in_features, bias=False),
                nn.LayerNorm(in_features),
            )

        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None

        # Setup adapter strategy
        self.setup_adapter_strategy(adapter_strategy)

        # reset parameters
        self.reset_parameters()

    def reset_parameters(self):
        """
        Zero the parameters of the final layer of `self.module`.

        For `pre` the final layer is the bias-free up-projection Linear, so only its weight is zeroed.
        For `post` the final layer is the LayerNorm, so both its weight and bias are zeroed.
        """
        # Final layer initializations must be 0.
        # `zeros_` overwrites the values, whereas multiplying by 0 would leave NaN/Inf entries as NaN.
        if self.norm_position == 'pre':
            nn.init.zeros_(self.module[-1].weight)

        elif self.norm_position == 'post':
            nn.init.zeros_(self.module[-1].weight)
            nn.init.zeros_(self.module[-1].bias)

    def forward(self, x):
        """
        Run the adapter stack on `x`, then apply dropout if it was configured.

        Args:
            x: Input passed to `self.module`.

        Returns:
            The output of `self.module`, passed through `self.dropout` when it is not None.
        """
        x = self.module(x)

        # Add dropout if available
        if self.dropout is not None:
            x = self.dropout(x)

        return x
@dataclass
class LinearAdapterConfig:
    """
    Dataclass config mirroring the constructor arguments of `LinearAdapter`.

    `_target_` holds the fully qualified class path of `LinearAdapter`.
    """

    in_features: int
    dim: int
    activation: str = 'swish'
    norm_position: str = 'pre'
    dropout: float = 0.0
    # Built per instance through `default_factory`: a config object used directly as the default
    # would be shared by every LinearAdapterConfig, and Python 3.11+ rejects unhashable defaults.
    adapter_strategy: Optional[Any] = field(default_factory=adapter_mixin_strategies.ResidualAddAdapterStrategyConfig)
    _target_: str = "{0}.{1}".format(LinearAdapter.__module__, LinearAdapter.__name__)