# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import ABC, abstractmethod
from typing import List
from omegaconf import DictConfig, OmegaConf, open_dict
from nemo.collections.asr.parts.mixins.asr_adapter_mixins import ASRAdapterModelMixin
from nemo.collections.asr.parts.utils import asr_module_utils
from nemo.collections.common import tokenizers
from nemo.utils import logging
class ASRBPEMixin(ABC):
""" ASR BPE Mixin class that sets up a Tokenizer via a config
This mixin class adds the method `_setup_tokenizer(...)`, which can be used by ASR models
which depend on subword tokenization.
The setup_tokenizer method adds the following parameters to the class -
- tokenizer_cfg: The resolved config supplied to the tokenizer (with `dir` and `type` arguments).
- tokenizer_dir: The directory path to the tokenizer vocabulary + additional metadata.
- tokenizer_type: The type of the tokenizer. Currently supports `bpe` and `wpe`, as well as `agg`.
- vocab_path: Resolved path to the vocabulary text file.
In addition to these variables, the method will also instantiate and preserve a tokenizer
(subclass of TokenizerSpec) if successful, and assign it to self.tokenizer.
The mixin also supports aggregate tokenizers, which consist of ordinary, monolingual tokenizers.
If a conversion between a monolingual and an aggregate tokenizer (or vice versa) is detected,
all registered artifacts will be cleaned up.
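Example of a minimal monolingual tokenizer config consumed by `_setup_tokenizer` (a sketch only;
the directory path is an illustrative assumption and should point to a folder containing the
tokenizer artifacts, e.g. `tokenizer.model` for `bpe` or `vocab.txt` for `wpe`)::

    tokenizer:
      dir: /path/to/tokenizer_dir
      type: bpe  # one of `bpe`, `wpe` or `agg`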
"""
# this will be used in configs and nemo artifacts
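# e.g. aggregate tokenizer artifacts are registered under keys such as
# 'tokenizer.langs.<lang>.model_path' (see `_make_tokenizer` below)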
AGGREGATE_TOKENIZERS_DICT_PREFIX = 'langs'
def _setup_tokenizer(self, tokenizer_cfg: DictConfig):
tokenizer_type = tokenizer_cfg.get('type')
if tokenizer_type is None:
raise ValueError("`tokenizer.type` cannot be None")
elif tokenizer_type.lower() == 'agg':
self._setup_aggregate_tokenizer(tokenizer_cfg)
else:
self._setup_monolingual_tokenizer(tokenizer_cfg)
def _setup_monolingual_tokenizer(self, tokenizer_cfg: DictConfig):
# Prevent tokenizer parallelism (unless user has explicitly set it)
if 'TOKENIZERS_PARALLELISM' not in os.environ:
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
self.tokenizer_cfg = OmegaConf.to_container(tokenizer_cfg, resolve=True) # type: dict
self.tokenizer_dir = self.tokenizer_cfg.pop('dir') # Remove tokenizer directory
self.tokenizer_type = self.tokenizer_cfg.pop('type').lower() # Remove tokenizer_type
self.hf_tokenizer_kwargs = self.tokenizer_cfg.pop("hf_kwargs", {}) # Remove HF tokenizer kwargs
# just in case the previous tokenizer was an aggregate
self._cleanup_aggregate_config_and_artifacts_if_needed()
# Preserve config
if hasattr(self, 'cfg') and 'tokenizer' in self.cfg:
self.cfg.tokenizer.dir = self.tokenizer_dir
self.cfg.tokenizer.type = self.tokenizer_type
if 'hf_kwargs' in tokenizer_cfg:
with open_dict(self.cfg.tokenizer):
self.cfg.tokenizer.hf_kwargs = tokenizer_cfg.get('hf_kwargs')
if self.tokenizer_type not in ['bpe', 'wpe']:
raise ValueError(
"`tokenizer.type` must be either `bpe` for SentencePiece tokenizer or "
"`wpe` for BERT based tokenizer"
)
if self.tokenizer_type == 'bpe':
# This is a BPE Tokenizer
if 'model_path' in self.tokenizer_cfg:
model_path = self.tokenizer_cfg.get('model_path')
else:
model_path = os.path.join(self.tokenizer_dir, 'tokenizer.model')
model_path = self.register_artifact('tokenizer.model_path', model_path)
self.model_path = model_path
if 'special_tokens' in self.tokenizer_cfg:
special_tokens = self.tokenizer_cfg['special_tokens']
if special_tokens is not None:
raise ValueError("`special_tokens` are no longer supported for SentencePiece based tokenizers.")
# Instantiate the SentencePiece tokenizer from the resolved model path
self.tokenizer = tokenizers.SentencePieceTokenizer(model_path=model_path)
if 'vocab_path' in self.tokenizer_cfg:
vocab_path = self.tokenizer_cfg.get('vocab_path')
else:
vocab_path = os.path.join(self.tokenizer_dir, 'vocab.txt')
vocab_path = self.register_artifact('tokenizer.vocab_path', vocab_path)
self.vocab_path = vocab_path
try:
if 'spe_tokenizer_vocab' in self.tokenizer_cfg:
spe_vocab_path = self.tokenizer_cfg.get('spe_tokenizer_vocab')
else:
spe_vocab_path = os.path.join(self.tokenizer_dir, 'tokenizer.vocab')
spe_vocab_path = self.register_artifact('tokenizer.spe_tokenizer_vocab', spe_vocab_path)
self.spe_vocab_path = spe_vocab_path
except FileNotFoundError:
# fallback case for older checkpoints that did not preserve the tokenizer.vocab
self.spe_vocab_path = None
vocabulary = {}
for i in range(self.tokenizer.vocab_size):
piece = self.tokenizer.ids_to_tokens([i])
piece = piece[0]
vocabulary[piece] = i + 1
# wrapper method to get vocabulary conveniently
def get_vocab():
return vocabulary
# attach utility values to the tokenizer wrapper
self.tokenizer.tokenizer.vocab_size = len(vocabulary)
self.tokenizer.tokenizer.get_vocab = get_vocab
self.tokenizer.tokenizer.all_special_tokens = self.tokenizer.special_token_to_id
else:
# This is a WPE Tokenizer
# Resolve the vocabulary path, either explicitly from the config or from the tokenizer directory
if 'vocab_path' in self.tokenizer_cfg:
vocab_path = self.tokenizer_cfg.get('vocab_path')
else:
vocab_path = os.path.join(self.tokenizer_dir, 'vocab.txt')
vocab_path = self.register_artifact('tokenizer.vocab_path', vocab_path)
self.vocab_path = vocab_path
# If path from previous registration exists, remove it
if 'vocab_path' in self.tokenizer_cfg:
self.tokenizer_cfg.pop('vocab_path')
self.tokenizer = tokenizers.AutoTokenizer(
pretrained_model_name='bert-base-cased',
vocab_file=self.vocab_path,
mask_token=self.hf_tokenizer_kwargs.get('mask_token', None),
bos_token=self.hf_tokenizer_kwargs.get('bos_token', None),
eos_token=self.hf_tokenizer_kwargs.get('eos_token', None),
pad_token=self.hf_tokenizer_kwargs.get('pad_token', None),
sep_token=self.hf_tokenizer_kwargs.get('sep_token', None),
cls_token=self.hf_tokenizer_kwargs.get('cls_token', None),
unk_token=self.hf_tokenizer_kwargs.get('unk_token', None),
use_fast=self.hf_tokenizer_kwargs.get('use_fast', False),
)
logging.info(
"Tokenizer {} initialized with {} tokens".format(
self.tokenizer.__class__.__name__, self.tokenizer.vocab_size
)
)
def _setup_aggregate_tokenizer(self, tokenizer_cfg: DictConfig):
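"""
Sets up an AggregateTokenizer from a config that maps language keys to monolingual
tokenizer configs. A minimal sketch of such a config (the language keys `en`/`es` and
the directory paths are illustrative assumptions, not requirements)::

    tokenizer:
      type: agg
      langs:
        en:
          dir: /path/to/en_tokenizer_dir
          type: bpe
        es:
          dir: /path/to/es_tokenizer_dir
          type: wpe
"""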
# Prevent tokenizer parallelism (unless user has explicitly set it)
if 'TOKENIZERS_PARALLELISM' not in os.environ:
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
self.tokenizer_cfg = OmegaConf.to_container(tokenizer_cfg, resolve=True) # type: dict
# the aggregate tokenizer does not have a single tokenizer_dir, but one per language
self.tokenizer_dir = None
self.tokenizer_cfg.pop('dir', None) # Remove tokenizer directory, if any
# Remove tokenizer_type -- obviously if we are here, the type is 'agg'
self.tokenizer_type = self.tokenizer_cfg.pop('type').lower()
# the aggregate tokenizer should not have these
self.hf_tokenizer_kwargs = {}
self.tokenizer_cfg.pop("hf_kwargs", {}) # Remove HF tokenizer kwargs, if any
logging.info('_setup_tokenizer: detected an aggregate tokenizer')
# need to de-register any monolingual config items if they exist
self._cleanup_monolingual_and_aggregate_config_and_artifacts_if_needed()
# overwrite tokenizer type
if hasattr(self, 'cfg') and 'tokenizer' in self.cfg:
self.cfg.tokenizer.type = self.tokenizer_type
tokenizers_dict = {}
# init each of the monolingual tokenizers found in the config and assemble into AggregateTokenizer
for lang, tokenizer_config in self.tokenizer_cfg[self.AGGREGATE_TOKENIZERS_DICT_PREFIX].items():
tokenizer, model_path, vocab_path, spe_vocab_path = self._make_tokenizer(tokenizer_config, lang)
tokenizers_dict[lang] = tokenizer
if hasattr(self, 'cfg'):
with open_dict(self.cfg.tokenizer):
self.cfg.tokenizer[self.AGGREGATE_TOKENIZERS_DICT_PREFIX][lang]['dir'] = self.tokenizer_cfg[
self.AGGREGATE_TOKENIZERS_DICT_PREFIX
][lang]['dir']
self.cfg.tokenizer[self.AGGREGATE_TOKENIZERS_DICT_PREFIX][lang]['type'] = self.tokenizer_cfg[
self.AGGREGATE_TOKENIZERS_DICT_PREFIX
][lang]['type']
self.tokenizer = tokenizers.AggregateTokenizer(tokenizers_dict)
def _make_tokenizer(self, tokenizer_cfg: DictConfig, lang=None):
tokenizer_type = tokenizer_cfg.get('type').lower()
tokenizer_dir = tokenizer_cfg.get('dir')
if tokenizer_type not in ['bpe', 'wpe']:
raise ValueError(
'`tokenizer.type` must be either `bpe` for SentencePiece tokenizer or `wpe` for BERT based tokenizer'
)
# defaults
model_path = None
vocab_path = None
spe_vocab_path = None
if tokenizer_type == 'bpe':
# This is a BPE Tokenizer
if 'model_path' in tokenizer_cfg:
model_path = tokenizer_cfg.get('model_path')
else:
model_path = os.path.join(tokenizer_dir, 'tokenizer.model')
model_path = self.register_artifact(
'tokenizer.' + self.AGGREGATE_TOKENIZERS_DICT_PREFIX + '.' + lang + '.model_path', model_path
)
if 'special_tokens' in tokenizer_cfg:
special_tokens = tokenizer_cfg['special_tokens']
if special_tokens is not None:
raise ValueError('`special_tokens` are no longer supported for SentencePiece based tokenizers.')
# Instantiate the SentencePiece tokenizer from the resolved model path
tokenizer = tokenizers.SentencePieceTokenizer(model_path=model_path)
if 'vocab_path' in tokenizer_cfg:
vocab_path = tokenizer_cfg.get('vocab_path')
else:
vocab_path = os.path.join(tokenizer_dir, 'vocab.txt')
vocab_path = self.register_artifact(
'tokenizer.' + self.AGGREGATE_TOKENIZERS_DICT_PREFIX + '.' + lang + '.vocab_path', vocab_path
)
try:
if 'spe_tokenizer_vocab' in tokenizer_cfg:
spe_vocab_path = tokenizer_cfg.get('spe_tokenizer_vocab')
else:
spe_vocab_path = os.path.join(tokenizer_dir, 'tokenizer.vocab')
spe_vocab_path = self.register_artifact(
'tokenizer.' + self.AGGREGATE_TOKENIZERS_DICT_PREFIX + '.' + lang + '.spe_tokenizer_vocab',
spe_vocab_path,
)
except FileNotFoundError:
# fallback case for older checkpoints that did not preserve the tokenizer.vocab
spe_vocab_path = None
vocabulary = {}
for i in range(tokenizer.vocab_size):
piece = tokenizer.ids_to_tokens([i])
piece = piece[0]
vocabulary[piece] = i + 1
# wrapper method to get vocabulary conveniently
def get_vocab():
return vocabulary
# attach utility values to the tokenizer wrapper
tokenizer.tokenizer.vocab_size = len(vocabulary)
tokenizer.tokenizer.get_vocab = get_vocab
tokenizer.tokenizer.all_special_tokens = tokenizer.special_token_to_id
else:
# This is a WPE Tokenizer
# Resolve the vocabulary path, either explicitly from the config or from the tokenizer directory
if 'vocab_path' in tokenizer_cfg:
vocab_path = tokenizer_cfg.get('vocab_path')
else:
vocab_path = os.path.join(tokenizer_dir, 'vocab.txt')
vocab_path = self.register_artifact(
'tokenizer.' + self.AGGREGATE_TOKENIZERS_DICT_PREFIX + '.' + lang + '.vocab_path', vocab_path
)
# If path from previous registration exists, remove it
if 'vocab_path' in tokenizer_cfg:
tokenizer_cfg.pop('vocab_path')
hf_tokenizer_kwargs = tokenizer_cfg.get('hf_kwargs', {})
tokenizer = tokenizers.AutoTokenizer(
pretrained_model_name='bert-base-cased',
vocab_file=vocab_path,
mask_token=hf_tokenizer_kwargs.get('mask_token', None),
bos_token=hf_tokenizer_kwargs.get('bos_token', None),
eos_token=hf_tokenizer_kwargs.get('eos_token', None),
pad_token=hf_tokenizer_kwargs.get('pad_token', None),
sep_token=hf_tokenizer_kwargs.get('sep_token', None),
cls_token=hf_tokenizer_kwargs.get('cls_token', None),
unk_token=hf_tokenizer_kwargs.get('unk_token', None),
use_fast=hf_tokenizer_kwargs.get('use_fast', False),
)
logging.info(
'Tokenizer {} initialized with {} tokens'.format(tokenizer.__class__.__name__, tokenizer.vocab_size)
)
return tokenizer, model_path, vocab_path, spe_vocab_path
def _cleanup_monolingual_and_aggregate_config_and_artifacts_if_needed(self):
"""
Cleans up any monolingual and some aggregate config items and artifacts.
We need to do this when we switch from a monolingual tokenizer to an aggregate one,
or between aggregate tokenizers which could have a different number of languages.
"""
if hasattr(self, 'cfg'):
with open_dict(self.cfg.tokenizer):
self.cfg.tokenizer.pop('dir', None)
self.cfg.tokenizer.pop('model_path', None)
self.cfg.tokenizer.pop('vocab_path', None)
self.cfg.tokenizer.pop('spe_tokenizer_vocab', None)
self.cfg.tokenizer.pop('hf_kwargs', None)
# need to de-register any monolingual artifacts if they exist
if hasattr(self, 'artifacts'):
self.artifacts.pop('tokenizer.model_path', None)
self.artifacts.pop('tokenizer.vocab_path', None)
self.artifacts.pop('tokenizer.spe_tokenizer_vocab', None)
# just in case we are replacing one aggregate tokenizer with another one, we better
# clean up the old aggregate artifacts as well
for akey in list(self.artifacts.keys()):
if akey.startswith('tokenizer.' + self.AGGREGATE_TOKENIZERS_DICT_PREFIX + '.'):
self.artifacts.pop(akey)
def _cleanup_aggregate_config_and_artifacts_if_needed(self):
"""
Cleans up any aggregate config items and artifacts.
We need to do this when we switch from an aggregate tokenizer to a monolingual one.
"""
if hasattr(self, 'cfg'):
with open_dict(self.cfg.tokenizer):
self.cfg.tokenizer.pop(self.AGGREGATE_TOKENIZERS_DICT_PREFIX, None)
# clean up the old aggregate artifacts as well
if hasattr(self, 'artifacts'):
for akey in list(self.artifacts.keys()):
if akey.startswith('tokenizer.' + self.AGGREGATE_TOKENIZERS_DICT_PREFIX + '.'):
self.artifacts.pop(akey)
class ASRModuleMixin(ASRAdapterModelMixin):
"""
ASRModuleMixin is a mixin class added to ASR models in order to add methods that are specific
to a particular instantiation of a module inside of an ASRModel.
Each method should first check that the module is present within the subclass, and provide the
additional functionality only if the corresponding module exists.
"""
def change_conv_asr_se_context_window(self, context_window: int, update_config: bool = True):
"""
Update the context window of the SqueezeExcitation module if the provided model contains an
`encoder` which is an instance of `ConvASREncoder`.
Args:
context_window: An integer representing the number of input timeframes that will be used
to compute the context. Each timeframe corresponds to a single window stride of the
STFT features.
For example, if window_stride = 0.01 s, then a context window of 128 represents
1.28 s (128 * 0.01 s) of context used to compute the Squeeze step.
update_config: Whether or not to update the config with the new context window.
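Example (an illustrative usage sketch; `model` stands for any ASR model that uses this
mixin and whose encoder is a ConvASREncoder, and the 0.01 s window stride is an assumption)::

    # limit the SE context to roughly 1.28 s of audio (128 frames * 0.01 s stride)
    model.change_conv_asr_se_context_window(context_window=128)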
"""
asr_module_utils.change_conv_asr_se_context_window(
self, context_window=context_window, update_config=update_config
)
class DiarizationMixin(ABC):
@abstractmethod
def diarize(self, paths2audio_files: List[str], batch_size: int = 1) -> List[str]:
"""
Takes paths to audio files and returns speaker labels.
Args:
paths2audio_files: Paths to the audio files to be diarized.
batch_size: Batch size to use during inference.
Returns:
Speaker labels for each audio file.
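Example (illustrative only; `diar_model` stands for any concrete subclass that
implements this method, and the audio path is a placeholder)::

    labels = diar_model.diarize(paths2audio_files=['/path/to/audio.wav'], batch_size=1)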
"""
pass