Source code for nemo.collections.common.parts.adapter_modules

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field, is_dataclass
from typing import Optional

from hydra.utils import instantiate
from omegaconf import OmegaConf
from torch import nn as nn

from nemo.collections.common.parts.utils import activation_registry
from nemo.core.classes.mixins import access_mixins, adapter_mixin_strategies


class AbstractAdapterModule(nn.Module, access_mixins.AccessMixin):
    """
    Base class of Adapter Modules, providing common functionality to all Adapter Modules.
    """

    def setup_adapter_strategy(self, adapter_strategy: Optional[adapter_mixin_strategies.AbstractAdapterStrategy]):
        """
        Setup the adapter strategy of this class, enabling dynamic changes in the way the adapter output is
        merged with the input.

        When called successfully, will assign the variable `adapter_strategy` to the module.

        Args:
            adapter_strategy: Can be None or an implementation of AbstractAdapterStrategy.
        """
        # set default adapter strategy
        if adapter_strategy is None:
            adapter_strategy = adapter_mixin_strategies.ResidualAddAdapterStrategyConfig()

        if is_dataclass(adapter_strategy):
            adapter_strategy = OmegaConf.structured(adapter_strategy)
            OmegaConf.set_struct(adapter_strategy, False)

        # The config must have the `_target_` field pointing to the actual adapter strategy class,
        # which will be loaded dynamically into this module.
        if isinstance(adapter_strategy, dict) or OmegaConf.is_config(adapter_strategy):
            self.adapter_strategy = instantiate(adapter_strategy)
        elif isinstance(adapter_strategy, adapter_mixin_strategies.AbstractAdapterStrategy):
            self.adapter_strategy = adapter_strategy
        else:
            raise AttributeError(f'`adapter_strategy` provided is invalid: {adapter_strategy}')
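

# ---------------------------------------------------------------------------
# A minimal sketch (not part of the NeMo API): `_NoOpAdapterExample` below is a
# hypothetical subclass illustrating how a concrete adapter wires itself up by
# calling `setup_adapter_strategy`, which accepts None (falls back to the
# default residual-add strategy), a dataclass / DictConfig carrying a
# `_target_` field, or an AbstractAdapterStrategy instance.
class _NoOpAdapterExample(AbstractAdapterModule):
    """Hypothetical adapter that returns its input unchanged."""

    def __init__(self, adapter_strategy: Optional[adapter_mixin_strategies.AbstractAdapterStrategy] = None):
        super().__init__()
        # None, a config dataclass, a DictConfig, or a strategy instance all work here.
        self.setup_adapter_strategy(adapter_strategy)

    def forward(self, x):
        return x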
class LinearAdapter(AbstractAdapterModule):
    """
    Simple linear feedforward adapter module with LayerNorm and a single hidden layer with an activation function.

    Note: The adapter explicitly initializes its final layer with all zeros in order to avoid affecting the
        original model when all adapters are disabled.

    Args:
        in_features: Input dimension of the module. Note that for adapters, input_dim == output_dim.
        dim: Hidden dimension of the feedforward network.
        activation: Str name of an activation function.
        norm_position: Str, can be `pre` or `post`. Defaults to `post`. Determines whether the normalization
            is applied in the first layer or the last layer. Certain architectures may prefer one over the other.
        dropout: Float value; if greater than 0, applies dropout to the output of the last layer of the adapter.
        adapter_strategy: By default, ResidualAddAdapterStrategyConfig. An adapter composition function object.
    """

    def __init__(
        self,
        in_features: int,
        dim: int,
        activation: str = 'swish',
        norm_position: str = "post",
        dropout: float = 0.0,
        adapter_strategy: adapter_mixin_strategies.ResidualAddAdapterStrategyConfig = None,
    ):
        super().__init__()

        activation = activation_registry[activation]()
        # If the activation can be executed in place, do so.
        if hasattr(activation, 'inplace'):
            activation.inplace = True

        assert norm_position in ['pre', 'post']
        self.norm_position = norm_position

        if norm_position == 'pre':
            self.module = nn.Sequential(
                nn.LayerNorm(in_features),
                nn.Linear(in_features, dim, bias=False),
                activation,
                nn.Linear(dim, in_features, bias=False),
            )

        elif norm_position == 'post':
            self.module = nn.Sequential(
                nn.Linear(in_features, dim, bias=False),
                activation,
                nn.Linear(dim, in_features, bias=False),
                nn.LayerNorm(in_features),
            )

        if dropout > 0.0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

        # Setup adapter strategy
        self.setup_adapter_strategy(adapter_strategy)

        # Reset parameters so the adapter starts as an identity mapping under the residual strategy.
        self.reset_parameters()

    def reset_parameters(self):
        # The final layer's initialization must be all zeros.
        if self.norm_position == 'pre':
            # Final layer is a bias-free Linear; zero its weight.
            self.module[-1].weight.data *= 0

        elif self.norm_position == 'post':
            # Final layer is a LayerNorm; zero both its weight and bias.
            self.module[-1].weight.data *= 0
            self.module[-1].bias.data *= 0

    def forward(self, x):
        x = self.module(x)

        # Apply dropout if configured
        if self.dropout is not None:
            x = self.dropout(x)

        return x
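

# A minimal sketch (an illustration, not part of the library): the hypothetical
# helper below builds both `norm_position` variants side by side. With 'pre',
# LayerNorm precedes the bottleneck and the final (zero-initialized) Linear has
# no bias; with 'post', LayerNorm comes last and is zero-initialized instead.
def _build_example_adapters(in_features: int = 512, dim: int = 32):
    """Hypothetical helper returning one 'pre'-norm and one 'post'-norm LinearAdapter."""
    pre_adapter = LinearAdapter(in_features=in_features, dim=dim, norm_position='pre')
    post_adapter = LinearAdapter(in_features=in_features, dim=dim, norm_position='post')
    return pre_adapter, post_adapter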
@dataclass
class LinearAdapterConfig:
    in_features: int
    dim: int
    activation: str = 'swish'
    norm_position: str = 'post'
    dropout: float = 0.0
    adapter_strategy: Optional[dict] = field(
        default_factory=adapter_mixin_strategies.ResidualAddAdapterStrategyConfig
    )
    _target_: str = "{0}.{1}".format(LinearAdapter.__module__, LinearAdapter.__name__)
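

# ---------------------------------------------------------------------------
# Usage sketch (a hedged example, not a prescribed workflow): build a
# LinearAdapter from its dataclass config via hydra's `instantiate`, mirroring
# how NeMo's adapter mixins consume such configs, then run a forward pass.
# The tensor shape used below is an arbitrary assumption for illustration.
if __name__ == "__main__":
    import torch

    cfg = OmegaConf.structured(LinearAdapterConfig(in_features=512, dim=32))
    adapter = instantiate(cfg)

    inp = torch.randn(4, 10, 512)  # assumed (batch, time, features) layout
    out = adapter(inp)

    # Because the final layer is zero-initialized, a freshly created adapter
    # produces all-zero outputs and leaves the base model unchanged when the
    # residual-add strategy merges it with the input.
    assert out.shape == inp.shape
    assert torch.allclose(out, torch.zeros_like(out))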