Source code for nemo.collections.asr.modules.conformer_encoder
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections import OrderedDict
from typing import List, Optional
import torch
import torch.distributed
import torch.nn as nn
from nemo.collections.asr.parts.submodules.conformer_modules import ConformerLayer
from nemo.collections.asr.parts.submodules.multi_head_attention import PositionalEncoding, RelPositionalEncoding
from nemo.collections.asr.parts.submodules.subsampling import ConvSubsampling, StackingSubsampling
from nemo.core.classes.common import typecheck
from nemo.core.classes.exportable import Exportable
from nemo.core.classes.mixins import adapter_mixins
from nemo.core.classes.module import NeuralModule
from nemo.core.neural_types import AcousticEncodedRepresentation, LengthsType, NeuralType, SpectrogramType
__all__ = ['ConformerEncoder']
class ConformerEncoder(NeuralModule, Exportable):
"""
    The Conformer encoder for ASR models.
Based on this paper:
'Conformer: Convolution-augmented Transformer for Speech Recognition' by Anmol Gulati et al.
https://arxiv.org/abs/2005.08100
Args:
feat_in (int): the size of feature channels
n_layers (int): number of layers of ConformerBlock
d_model (int): the hidden size of the model
        feat_out (int): the size of the output features
            Defaults to -1 (which means feat_out equals d_model).
        subsampling (str): the method of subsampling, choices=['vggnet', 'striding', 'stacking']
            Defaults to striding.
        subsampling_factor (int): the subsampling factor, which should be a power of 2
            Defaults to 4.
subsampling_conv_channels (int): the size of the convolutions in the subsampling module
Defaults to -1 which would set it to d_model.
ff_expansion_factor (int): the expansion factor in feed forward layers
Defaults to 4.
        self_attention_model (str): type of the attention layer and positional encoding
            'rel_pos': relative positional embedding and Transformer-XL
            'abs_pos': absolute positional embedding and Transformer
            Defaults to rel_pos.
        pos_emb_max_len (int): the maximum length of positional embeddings
            Defaults to 5000.
        n_heads (int): number of heads in multi-headed attention layers
            Defaults to 4.
        att_context_size (List[int]): [left, right] attention context sizes; -1 on
            either side means unrestricted context on that side.
            Defaults to None, which is treated as [-1, -1].
xscaling (bool): enables scaling the inputs to the multi-headed attention layers by sqrt(d_model)
Defaults to True.
untie_biases (bool): whether to not share (untie) the bias weights between layers of Transformer-XL
Defaults to True.
conv_kernel_size (int): the size of the convolutions in the convolutional modules
Defaults to 31.
conv_norm_type (str): the type of the normalization in the convolutional modules
Defaults to 'batch_norm'.
dropout (float): the dropout rate used in all layers except the attention layers
Defaults to 0.1.
dropout_emb (float): the dropout rate used for the positional embeddings
Defaults to 0.1.
dropout_att (float): the dropout rate used for the attention layer
Defaults to 0.0.
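
    Example:
        A minimal usage sketch; the values below are illustrative rather than
        recommended settings::

            encoder = ConformerEncoder(feat_in=80, n_layers=16, d_model=256, n_heads=4)
            audio_signal = torch.randn(2, 80, 200)   # (batch, feat_in, time)
            length = torch.tensor([200, 160])        # number of valid frames per utterance
            outputs, encoded_lengths = encoder(audio_signal=audio_signal, length=length)
            # outputs has shape (batch, d_model, time reduced by subsampling_factor)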
"""
    def input_example(self, max_batch=1, max_dim=256):
"""
Generates input examples for tracing etc.
Returns:
A tuple of input examples.
"""
dev = next(self.parameters()).device
input_example = torch.randn(max_batch, self._feat_in, max_dim).to(dev)
input_example_length = torch.randint(1, max_dim, (max_batch,)).to(dev)
return tuple([input_example, input_example_length])
@property
def input_types(self):
"""Returns definitions of module input ports.
"""
return OrderedDict(
{
"audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
"length": NeuralType(tuple('B'), LengthsType()),
}
)
@property
def output_types(self):
"""Returns definitions of module output ports.
"""
return OrderedDict(
{
"outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
"encoded_lengths": NeuralType(tuple('B'), LengthsType()),
}
)
def __init__(
self,
feat_in,
n_layers,
d_model,
feat_out=-1,
subsampling='striding',
subsampling_factor=4,
subsampling_conv_channels=-1,
ff_expansion_factor=4,
self_attention_model='rel_pos',
n_heads=4,
att_context_size=None,
xscaling=True,
untie_biases=True,
pos_emb_max_len=5000,
conv_kernel_size=31,
conv_norm_type='batch_norm',
dropout=0.1,
dropout_emb=0.1,
dropout_att=0.0,
):
super().__init__()
d_ff = d_model * ff_expansion_factor
self.d_model = d_model
self._feat_in = feat_in
self.scale = math.sqrt(self.d_model)
if att_context_size:
self.att_context_size = att_context_size
else:
self.att_context_size = [-1, -1]
if xscaling:
self.xscale = math.sqrt(d_model)
else:
self.xscale = None
if subsampling_conv_channels == -1:
subsampling_conv_channels = d_model
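        # Select the subsampling front-end: 'stacking' concatenates consecutive
        # frames via StackingSubsampling, while 'striding' and 'vggnet' reduce the
        # time resolution with ConvSubsampling.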
if subsampling and subsampling_factor > 1:
if subsampling == 'stacking':
self.pre_encode = StackingSubsampling(
subsampling_factor=subsampling_factor, feat_in=feat_in, feat_out=d_model
)
else:
self.pre_encode = ConvSubsampling(
subsampling=subsampling,
subsampling_factor=subsampling_factor,
feat_in=feat_in,
feat_out=d_model,
conv_channels=subsampling_conv_channels,
activation=nn.ReLU(),
)
else:
self.pre_encode = nn.Linear(feat_in, d_model)
self._feat_out = d_model
if not untie_biases and self_attention_model == "rel_pos":
d_head = d_model // n_heads
pos_bias_u = nn.Parameter(torch.Tensor(n_heads, d_head))
pos_bias_v = nn.Parameter(torch.Tensor(n_heads, d_head))
nn.init.zeros_(pos_bias_u)
nn.init.zeros_(pos_bias_v)
else:
pos_bias_u = None
pos_bias_v = None
self.pos_emb_max_len = pos_emb_max_len
if self_attention_model == "rel_pos":
self.pos_enc = RelPositionalEncoding(
d_model=d_model,
dropout_rate=dropout,
max_len=pos_emb_max_len,
xscale=self.xscale,
dropout_rate_emb=dropout_emb,
)
elif self_attention_model == "abs_pos":
pos_bias_u = None
pos_bias_v = None
self.pos_enc = PositionalEncoding(
d_model=d_model, dropout_rate=dropout, max_len=pos_emb_max_len, xscale=self.xscale
)
else:
raise ValueError(f"Not valid self_attention_model: '{self_attention_model}'!")
self.layers = nn.ModuleList()
for i in range(n_layers):
layer = ConformerLayer(
d_model=d_model,
d_ff=d_ff,
self_attention_model=self_attention_model,
n_heads=n_heads,
conv_kernel_size=conv_kernel_size,
conv_norm_type=conv_norm_type,
dropout=dropout,
dropout_att=dropout_att,
pos_bias_u=pos_bias_u,
pos_bias_v=pos_bias_v,
)
self.layers.append(layer)
if feat_out > 0 and feat_out != self._feat_out:
self.out_proj = nn.Linear(self._feat_out, feat_out)
self._feat_out = feat_out
else:
self.out_proj = None
self._feat_out = d_model
self.set_max_audio_length(self.pos_emb_max_len)
self.use_pad_mask = True
    def set_max_audio_length(self, max_audio_length):
""" Sets maximum input length.
Pre-calculates internal seq_range mask.
"""
self.max_audio_length = max_audio_length
device = next(self.parameters()).device
seq_range = torch.arange(0, self.max_audio_length, device=device)
if hasattr(self, 'seq_range'):
self.seq_range = seq_range
else:
self.register_buffer('seq_range', seq_range, persistent=False)
self.pos_enc.extend_pe(max_audio_length, device)
    @typecheck()
def forward(self, audio_signal, length=None):
self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device)
return self.forward_for_export(audio_signal=audio_signal, length=length)
    @typecheck()
def forward_for_export(self, audio_signal, length):
max_audio_length: int = audio_signal.size(-1)
if max_audio_length > self.max_audio_length:
self.set_max_audio_length(max_audio_length)
if length is None:
            length = audio_signal.new_full(
                (audio_signal.size(0),), max_audio_length, dtype=torch.int32, device=self.seq_range.device
            )
audio_signal = torch.transpose(audio_signal, 1, 2)
if isinstance(self.pre_encode, nn.Linear):
audio_signal = self.pre_encode(audio_signal)
else:
audio_signal, length = self.pre_encode(audio_signal, length)
audio_signal, pos_emb = self.pos_enc(audio_signal)
# adjust size
max_audio_length = audio_signal.size(1)
# Create the self-attention and padding masks
pad_mask = self.make_pad_mask(max_audio_length, length)
att_mask = pad_mask.unsqueeze(1).repeat([1, max_audio_length, 1])
att_mask = torch.logical_and(att_mask, att_mask.transpose(1, 2))
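        # Optionally restrict self-attention to a limited context window:
        # att_context_size = [left, right] keeps at most `left` past and `right`
        # future frames per position; -1 leaves that side unrestricted.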
if self.att_context_size[0] >= 0:
att_mask = att_mask.triu(diagonal=-self.att_context_size[0])
if self.att_context_size[1] >= 0:
att_mask = att_mask.tril(diagonal=self.att_context_size[1])
att_mask = ~att_mask
if self.use_pad_mask:
pad_mask = ~pad_mask
else:
pad_mask = None
for lth, layer in enumerate(self.layers):
audio_signal = layer(x=audio_signal, att_mask=att_mask, pos_emb=pos_emb, pad_mask=pad_mask)
if self.out_proj is not None:
audio_signal = self.out_proj(audio_signal)
audio_signal = torch.transpose(audio_signal, 1, 2)
return audio_signal, length
    def update_max_seq_length(self, seq_length: int, device):
# Find global max audio length across all nodes
if torch.distributed.is_initialized():
global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device)
# Update across all ranks in the distributed system
torch.distributed.all_reduce(global_max_len, op=torch.distributed.ReduceOp.MAX)
seq_length = global_max_len.int().item()
if seq_length > self.max_audio_length:
self.set_max_audio_length(seq_length)
    def make_pad_mask(self, max_audio_length, seq_lens):
"""Make masking for padding."""
mask = self.seq_range[:max_audio_length].expand(seq_lens.size(0), -1) < seq_lens.unsqueeze(-1)
return mask
    def enable_pad_mask(self, on=True):
        # On inference, the user may choose to disable the pad mask
mask = self.use_pad_mask
self.use_pad_mask = on
return mask
class ConformerEncoderAdapter(ConformerEncoder, adapter_mixins.AdapterModuleMixin):
# Higher level forwarding
def add_adapter(self, name: str, cfg: dict):
for conformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin
conformer_layer.add_adapter(name, cfg)
def is_adapter_available(self) -> bool:
return any([conformer_layer.is_adapter_available() for conformer_layer in self.layers])
def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True):
for conformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin
conformer_layer.set_enabled_adapters(name=name, enabled=enabled)
def get_enabled_adapters(self) -> List[str]:
names = set([])
for conformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin
names.update(conformer_layer.get_enabled_adapters())
names = sorted(list(names))
return names
"""
Register any additional information
"""
if adapter_mixins.get_registered_adapter(ConformerEncoder) is None:
adapter_mixins.register_adapter(base_class=ConformerEncoder, adapter_class=ConformerEncoderAdapter)