Source code for nemo.collections.asr.modules.conformer_encoder

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from collections import OrderedDict

import torch
import torch.nn as nn

from nemo.collections.asr.parts.conformer_modules import ConformerLayer
from nemo.collections.asr.parts.multi_head_attention import PositionalEncoding, RelPositionalEncoding
from nemo.collections.asr.parts.subsampling import ConvSubsampling
from nemo.core.classes.common import typecheck
from nemo.core.classes.exportable import Exportable
from nemo.core.classes.module import NeuralModule
from nemo.core.neural_types import AcousticEncodedRepresentation, LengthsType, NeuralType, SpectrogramType

__all__ = ['ConformerEncoder']


class ConformerEncoder(NeuralModule, Exportable):
    """
    The encoder for ASR model of Conformer.
    Based on this paper:
    'Conformer: Convolution-augmented Transformer for Speech Recognition' by Anmol Gulati et al.
    https://arxiv.org/abs/2005.08100

    Args:
        feat_in (int): the size of feature channels
        n_layers (int): number of layers of ConformerBlock
        d_model (int): the hidden size of the model
        feat_out (int): the size of the output features
            Defaults to -1 (meaning feat_out is d_model)
        subsampling (str): the method of subsampling, choices=['vggnet', 'striding']
            Defaults to 'striding'.
        subsampling_factor (int): the subsampling factor, which should be a power of 2
            Defaults to 4.
        subsampling_conv_channels (int): the size of the convolutions in the subsampling module
            Defaults to -1, which sets it to d_model.
        ff_expansion_factor (int): the expansion factor in the feed forward layers
            Defaults to 4.
        self_attention_model (str): type of the attention layer and positional encoding
            'rel_pos': relative positional embedding and Transformer-XL
            'abs_pos': absolute positional embedding and Transformer
            Defaults to 'rel_pos'.
        pos_emb_max_len (int): the maximum length of positional embeddings
            Defaults to 5000.
        n_heads (int): number of heads in multi-headed attention layers
            Defaults to 4.
        xscaling (bool): enables scaling the inputs to the multi-headed attention layers by sqrt(d_model)
            Defaults to True.
        untie_biases (bool): whether to not share (untie) the bias weights between layers of Transformer-XL
            Defaults to True.
        conv_kernel_size (int): the size of the convolutions in the convolutional modules
            Defaults to 31.
        dropout (float): the dropout rate used in all layers except the attention layers
            Defaults to 0.1.
        dropout_emb (float): the dropout rate used for the positional embeddings
            Defaults to 0.1.
        dropout_att (float): the dropout rate used for the attention layer
            Defaults to 0.0.
    """

    def _prepare_for_export(self):
        Exportable._prepare_for_export(self)
    def input_example(self):
        """
        Generates input examples for tracing etc.

        Returns:
            A tuple of input examples.
        """
        input_example = torch.randn(16, self._feat_in, 256).to(next(self.parameters()).device)
        input_example_length = torch.randint(0, 256, (16,)).to(next(self.parameters()).device)
        return tuple([input_example, input_example_length])
    @property
    def input_types(self):
        """Returns definitions of module input ports."""
        return OrderedDict(
            {
                "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
                "length": NeuralType(tuple('B'), LengthsType()),
            }
        )

    @property
    def output_types(self):
        """Returns definitions of module output ports."""
        return OrderedDict(
            {
                "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
                "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
            }
        )

    def __init__(
        self,
        feat_in,
        n_layers,
        d_model,
        feat_out=-1,
        subsampling='striding',
        subsampling_factor=4,
        subsampling_conv_channels=-1,
        ff_expansion_factor=4,
        self_attention_model='rel_pos',
        n_heads=4,
        att_context_size=None,
        xscaling=True,
        untie_biases=True,
        pos_emb_max_len=5000,
        conv_kernel_size=31,
        dropout=0.1,
        dropout_emb=0.1,
        dropout_att=0.0,
    ):
        super().__init__()

        d_ff = d_model * ff_expansion_factor
        self.d_model = d_model
        self._feat_in = feat_in
        self.scale = math.sqrt(self.d_model)
        if att_context_size:
            self.att_context_size = att_context_size
        else:
            self.att_context_size = [-1, -1]

        if xscaling:
            self.xscale = math.sqrt(d_model)
        else:
            self.xscale = None

        if subsampling_conv_channels == -1:
            subsampling_conv_channels = d_model
        if subsampling:
            self.pre_encode = ConvSubsampling(
                subsampling=subsampling,
                subsampling_factor=subsampling_factor,
                feat_in=feat_in,
                feat_out=d_model,
                conv_channels=subsampling_conv_channels,
                activation=nn.ReLU(),
            )
            self._feat_out = d_model
        else:
            self._feat_out = d_model
            self.pre_encode = nn.Linear(feat_in, d_model)

        if not untie_biases and self_attention_model == "rel_pos":
            d_head = d_model // n_heads
            pos_bias_u = nn.Parameter(torch.Tensor(n_heads, d_head))
            pos_bias_v = nn.Parameter(torch.Tensor(n_heads, d_head))
            nn.init.zeros_(pos_bias_u)
            nn.init.zeros_(pos_bias_v)
        else:
            pos_bias_u = None
            pos_bias_v = None

        if self_attention_model == "rel_pos":
            self.pos_enc = RelPositionalEncoding(
                d_model=d_model,
                dropout_rate=dropout,
                max_len=pos_emb_max_len,
                xscale=self.xscale,
                dropout_rate_emb=dropout_emb,
            )
        elif self_attention_model == "abs_pos":
            pos_bias_u = None
            pos_bias_v = None
            self.pos_enc = PositionalEncoding(
                d_model=d_model, dropout_rate=dropout, max_len=pos_emb_max_len, xscale=self.xscale
            )
        else:
            raise ValueError(f"Not valid self_attention_model: '{self_attention_model}'!")

        self.layers = nn.ModuleList()
        for i in range(n_layers):
            layer = ConformerLayer(
                d_model=d_model,
                d_ff=d_ff,
                self_attention_model=self_attention_model,
                n_heads=n_heads,
                conv_kernel_size=conv_kernel_size,
                dropout=dropout,
                dropout_att=dropout_att,
                pos_bias_u=pos_bias_u,
                pos_bias_v=pos_bias_v,
            )
            self.layers.append(layer)

        if feat_out > 0 and feat_out != self._feat_out:
            self.out_proj = nn.Linear(self._feat_out, feat_out)
            self._feat_out = feat_out
        else:
            self.out_proj = None
            self._feat_out = d_model
    @typecheck()
    def forward(self, audio_signal, length=None):
        if length is None:
            length = torch.tensor(audio_signal.size(-1)).repeat(audio_signal.size(0)).to(audio_signal)

        audio_signal = torch.transpose(audio_signal, 1, 2)

        if isinstance(self.pre_encode, ConvSubsampling):
            audio_signal, length = self.pre_encode(audio_signal, length)
        else:
            audio_signal = self.pre_encode(audio_signal)

        audio_signal, pos_emb = self.pos_enc(audio_signal)
        bs, xmax, idim = audio_signal.size()

        # Create the self-attention and padding masks
        pad_mask = self.make_pad_mask(length, max_time=xmax, device=audio_signal.device)
        att_mask = pad_mask.unsqueeze(1).repeat([1, xmax, 1])
        att_mask = att_mask & att_mask.transpose(1, 2)
        if self.att_context_size[0] >= 0:
            att_mask = att_mask.triu(diagonal=-self.att_context_size[0])
        if self.att_context_size[1] >= 0:
            att_mask = att_mask.tril(diagonal=self.att_context_size[1])
        att_mask = ~att_mask
        pad_mask = ~pad_mask

        for lth, layer in enumerate(self.layers):
            audio_signal = layer(x=audio_signal, att_mask=att_mask, pos_emb=pos_emb, pad_mask=pad_mask)

        if self.out_proj is not None:
            audio_signal = self.out_proj(audio_signal)

        audio_signal = torch.transpose(audio_signal, 1, 2)
        return audio_signal, length
    @staticmethod
    def make_pad_mask(seq_lens, max_time, device=None):
        """Make masking for padding."""
        bs = seq_lens.size(0)
        seq_range = torch.arange(0, max_time, dtype=torch.int32)
        seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_time)
        seq_lens = seq_lens.type(seq_range_expand.dtype).to(seq_range_expand.device)
        seq_length_expand = seq_lens.unsqueeze(-1)
        mask = seq_range_expand < seq_length_expand
        if device:
            mask = mask.to(device)
        return mask
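

Usage sketch (not part of the original module): a minimal example of how this encoder might be instantiated and run on a batch of spectrogram features, assuming 80-dim log-mel inputs and the small hyperparameter values shown below; the shapes and values are illustrative assumptions, not prescribed by this file. Note that `forward` is wrapped by `@typecheck()`, so it should be called with keyword arguments.

# --- Hypothetical usage sketch, not part of nemo.collections.asr.modules.conformer_encoder ---
if __name__ == "__main__":
    # Small encoder: 4 Conformer layers, hidden size 176, 4 attention heads (assumed values).
    encoder = ConformerEncoder(
        feat_in=80,                     # e.g. 80-dim log-mel features (assumption)
        n_layers=4,
        d_model=176,
        subsampling='striding',
        subsampling_factor=4,
        self_attention_model='rel_pos',
        n_heads=4,
    )
    # audio_signal is (B, D, T) as declared in input_types; length gives valid frames per example.
    audio_signal = torch.randn(2, 80, 256)
    length = torch.tensor([256, 200])
    encoded, encoded_len = encoder(audio_signal=audio_signal, length=length)
    # With subsampling_factor=4 the time axis is reduced roughly 4x,
    # so encoded is (B, d_model, ~T/4) and encoded_len holds the subsampled lengths.
    print(encoded.shape, encoded_len)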