Source code for nemo.collections.asr.parts.mixins.mixins

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import tarfile
from abc import ABC, abstractmethod
from typing import List

import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from torch import Tensor

import nemo.collections.asr.models as asr_models
from nemo.collections.asr.parts.mixins.asr_adapter_mixins import ASRAdapterModelMixin
from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder
from nemo.collections.asr.parts.utils import asr_module_utils
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
from nemo.collections.asr.parts.utils.tokenizer_utils import (
    extract_capitalized_tokens_from_vocab,
    extract_punctuation_from_vocab,
)
from nemo.collections.common import tokenizers
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
from nemo.utils import app_state, logging
from nemo.utils.file_utils import robust_copy


class ASRBPEMixin(ABC):
    """ASR BPE Mixin class that sets up a Tokenizer via a config

    This mixin class adds the method `_setup_tokenizer(...)`, which can be used by ASR models
    which depend on subword tokenization.

    The setup_tokenizer method adds the following parameters to the class -
        -   tokenizer_cfg: The resolved config supplied to the tokenizer (with `dir` and `type` arguments).
        -   tokenizer_dir: The directory path to the tokenizer vocabulary + additional metadata.
        -   tokenizer_type: The type of the tokenizer. Currently supports `bpe` and `wpe`, as well as `agg`.
        -   vocab_path: Resolved path to the vocabulary text file.

    In addition to these variables, the method will also instantiate and preserve a tokenizer
    (subclass of TokenizerSpec) if successful, and assign it to self.tokenizer.

    The mixin also supports aggregate tokenizers, which consist of ordinary, monolingual tokenizers.
    If a conversion between a monolingual and an aggregate tokenizer (or vice versa) is detected,
    all registered artifacts will be cleaned up.
    """

    # this will be used in configs and nemo artifacts
    AGGREGATE_TOKENIZERS_DICT_PREFIX = 'langs'

    def _setup_tokenizer(self, tokenizer_cfg: DictConfig):
        """Instantiate ``self.tokenizer`` from ``tokenizer_cfg``.

        ``type == 'agg'`` (case-insensitive) builds an aggregate tokenizer; any other type is handed to the
        monolingual setup. Afterwards, capitalization / punctuation support is derived from the vocabulary.

        Args:
            tokenizer_cfg: Tokenizer config; must contain a non-None ``type``.

        Raises:
            ValueError: If ``tokenizer_cfg.type`` is missing or None.
        """
        tokenizer_type = tokenizer_cfg.get('type')
        if tokenizer_type is None:
            raise ValueError("`tokenizer.type` cannot be None")
        elif tokenizer_type.lower() == 'agg':
            self._setup_aggregate_tokenizer(tokenizer_cfg)
        else:
            self._setup_monolingual_tokenizer(tokenizer_cfg)

        self._derive_tokenizer_properties()

    @staticmethod
    def _attach_spe_vocab_helpers(tokenizer):
        """Attach ``vocab_size``, ``get_vocab`` and ``all_special_tokens`` to a SentencePiece wrapper's inner tokenizer.

        The vocabulary maps every piece to ``id + 1``.
        NOTE(review): the +1 offset is preserved from the original code; its purpose is not evident here.
        """
        vocabulary = {tokenizer.ids_to_tokens([i])[0]: i + 1 for i in range(tokenizer.vocab_size)}

        # wrapper method to get vocabulary conveniently
        def get_vocab():
            return vocabulary

        # attach utility values to the tokenizer wrapper
        tokenizer.tokenizer.vocab_size = len(vocabulary)
        tokenizer.tokenizer.get_vocab = get_vocab
        tokenizer.tokenizer.all_special_tokens = tokenizer.special_token_to_id

    def _setup_monolingual_tokenizer(self, tokenizer_cfg: DictConfig):
        """Set up a single ``bpe`` (SentencePiece) or ``wpe`` (HuggingFace AutoTokenizer) tokenizer.

        Explicit ``model_path`` / ``vocab_path`` / ``spe_tokenizer_vocab`` entries take precedence over the
        default file names inside ``dir``. Every resolved file is registered as a model artifact.
        The resolved ``dir`` / ``type`` / ``hf_kwargs`` are written back to ``self.cfg.tokenizer`` (when present).
        Any aggregate-tokenizer config and artifacts from a previous setup are removed.

        Args:
            tokenizer_cfg: Tokenizer config; ``dir`` and ``type`` are required keys.

        Raises:
            ValueError: If the type is not ``bpe``/``wpe``, or if non-null ``special_tokens`` are given for ``bpe``.
        """
        # Prevent tokenizer parallelism (unless user has explicitly set it)
        if 'TOKENIZERS_PARALLELISM' not in os.environ:
            os.environ['TOKENIZERS_PARALLELISM'] = 'false'

        self.tokenizer_cfg = OmegaConf.to_container(tokenizer_cfg, resolve=True)  # type: dict
        self.tokenizer_dir = self.tokenizer_cfg.pop('dir')  # Remove tokenizer directory
        self.tokenizer_type = self.tokenizer_cfg.pop('type').lower()  # Remove tokenizer_type

        # Remove HF tokenizer kwargs; `hf_kwargs: null` in a config is treated as "no kwargs".
        self.hf_tokenizer_kwargs = self.tokenizer_cfg.pop("hf_kwargs", None) or {}

        # just in case the previous tokenizer was an aggregate
        self._cleanup_aggregate_config_and_artifacts_if_needed()

        # Preserve config
        if hasattr(self, 'cfg') and 'tokenizer' in self.cfg:
            self.cfg.tokenizer.dir = self.tokenizer_dir
            self.cfg.tokenizer.type = self.tokenizer_type

            if 'hf_kwargs' in tokenizer_cfg:
                with open_dict(self.cfg.tokenizer):
                    self.cfg.tokenizer.hf_kwargs = tokenizer_cfg.get('hf_kwargs')

        if self.tokenizer_type not in ['bpe', 'wpe']:
            raise ValueError(
                "`tokenizer.type` must be either `bpe` for SentencePiece tokenizer or "
                "`wpe` for BERT based tokenizer"
            )

        if self.tokenizer_type == 'bpe':
            # This is a BPE Tokenizer
            if 'model_path' in self.tokenizer_cfg:
                model_path = self.tokenizer_cfg.get('model_path')
            else:
                model_path = os.path.join(self.tokenizer_dir, 'tokenizer.model')
            model_path = self.register_artifact('tokenizer.model_path', model_path)
            self.model_path = model_path

            if 'special_tokens' in self.tokenizer_cfg:
                special_tokens = self.tokenizer_cfg['special_tokens']

                if special_tokens is not None:
                    raise ValueError("`special_tokens` are no longer supported for SentencePiece based tokenizers.")

            if "custom_tokenizer" in self.tokenizer_cfg:
                self.tokenizer = self.from_config_dict(
                    {"_target_": tokenizer_cfg["custom_tokenizer"]["_target_"], "model_path": model_path}
                )
            else:
                self.tokenizer = tokenizers.SentencePieceTokenizer(model_path=model_path)

            if 'vocab_path' in self.tokenizer_cfg:
                vocab_path = self.tokenizer_cfg.get('vocab_path')
            else:
                vocab_path = os.path.join(self.tokenizer_dir, 'vocab.txt')
            vocab_path = self.register_artifact('tokenizer.vocab_path', vocab_path)
            self.vocab_path = vocab_path

            try:
                if 'spe_tokenizer_vocab' in self.tokenizer_cfg:
                    spe_vocab_path = self.tokenizer_cfg.get('spe_tokenizer_vocab')
                else:
                    spe_vocab_path = os.path.join(self.tokenizer_dir, 'tokenizer.vocab')
                spe_vocab_path = self.register_artifact('tokenizer.spe_tokenizer_vocab', spe_vocab_path)
                self.spe_vocab_path = spe_vocab_path
            except FileNotFoundError:
                # fallback case for older checkpoints that did not preserve the tokenizer.vocab
                self.spe_vocab_path = None

            self._attach_spe_vocab_helpers(self.tokenizer)

        else:
            # This is a WPE Tokenizer
            if 'vocab_path' in self.tokenizer_cfg:
                vocab_path = self.tokenizer_cfg.get('vocab_path')
            else:
                vocab_path = os.path.join(self.tokenizer_dir, 'vocab.txt')
            vocab_path = self.register_artifact('tokenizer.vocab_path', vocab_path)
            self.vocab_path = vocab_path

            # If path from previous registration exists, remove it
            if 'vocab_path' in self.tokenizer_cfg:
                self.tokenizer_cfg.pop('vocab_path')

            hf_kwargs = self.hf_tokenizer_kwargs
            # `pretrained_model_name` is honoured here as it is in `_make_tokenizer` (default unchanged).
            self.tokenizer = tokenizers.AutoTokenizer(
                pretrained_model_name=hf_kwargs.get('pretrained_model_name', 'bert-base-cased'),
                vocab_file=self.vocab_path,
                mask_token=hf_kwargs.get('mask_token', None),
                bos_token=hf_kwargs.get('bos_token', None),
                eos_token=hf_kwargs.get('eos_token', None),
                pad_token=hf_kwargs.get('pad_token', None),
                sep_token=hf_kwargs.get('sep_token', None),
                cls_token=hf_kwargs.get('cls_token', None),
                unk_token=hf_kwargs.get('unk_token', None),
                use_fast=hf_kwargs.get('use_fast', False),
            )

        logging.info(
            "Tokenizer {} initialized with {} tokens".format(
                self.tokenizer.__class__.__name__, self.tokenizer.vocab_size
            )
        )

    def _setup_aggregate_tokenizer(self, tokenizer_cfg: DictConfig):
        """Set up an aggregate tokenizer built from one monolingual tokenizer per language.

        The languages are read from ``tokenizer_cfg[AGGREGATE_TOKENIZERS_DICT_PREFIX]``. Each entry is built
        with ``_make_tokenizer``. The tokenizers are combined into ``tokenizers.AggregateTokenizer``, or into
        ``custom_tokenizer._target_`` when that key is present. Monolingual config and artifacts from a
        previous setup, and older aggregate ones, are removed first.

        Args:
            tokenizer_cfg: Tokenizer config whose ``type`` is ``agg``.
        """
        # Prevent tokenizer parallelism (unless user has explicitly set it)
        if 'TOKENIZERS_PARALLELISM' not in os.environ:
            os.environ['TOKENIZERS_PARALLELISM'] = 'false'

        self.tokenizer_cfg = OmegaConf.to_container(tokenizer_cfg, resolve=True)  # type: dict

        # the aggregate tokenizer does not have one tokenizer_dir but multiple ones
        self.tokenizer_dir = None

        self.tokenizer_cfg.pop('dir', None)  # Remove tokenizer directory, if any
        # Remove tokenizer_type -- obviously if we are here, the type is 'agg'
        self.tokenizer_type = self.tokenizer_cfg.pop('type').lower()

        # the aggregate tokenizer should not have these
        self.hf_tokenizer_kwargs = {}
        self.tokenizer_cfg.pop("hf_kwargs", {})  # Remove HF tokenizer kwargs, if any

        logging.info('_setup_tokenizer: detected an aggregate tokenizer')
        # need to de-register any monolingual config items if they exist
        self._cleanup_monolingual_and_aggregate_config_and_artifacts_if_needed()

        # overwrite tokenizer type
        if hasattr(self, 'cfg') and 'tokenizer' in self.cfg:
            self.cfg.tokenizer.type = self.tokenizer_type

        prefix = self.AGGREGATE_TOKENIZERS_DICT_PREFIX
        tokenizers_dict = {}
        # init each of the monolingual tokenizers found in the config and assemble into AggregateTokenizer
        for lang, tokenizer_config in self.tokenizer_cfg[prefix].items():
            tokenizer, _, _, _ = self._make_tokenizer(tokenizer_config, lang)
            tokenizers_dict[lang] = tokenizer

            if hasattr(self, 'cfg'):
                # `tokenizer_config` is the same dict object as self.tokenizer_cfg[prefix][lang]
                with open_dict(self.cfg.tokenizer):
                    self.cfg.tokenizer[prefix][lang]['dir'] = tokenizer_config['dir']
                    self.cfg.tokenizer[prefix][lang]['type'] = tokenizer_config['type']

        if "custom_tokenizer" in tokenizer_cfg:
            # Class which implements this is usually a ModelPT, has access to Serializable mixin by extension
            self.tokenizer = self.from_config_dict(
                {"_target_": tokenizer_cfg["custom_tokenizer"]["_target_"], "tokenizers": tokenizers_dict}
            )
        else:
            self.tokenizer = tokenizers.AggregateTokenizer(tokenizers_dict)

    def _make_tokenizer(self, tokenizer_cfg: DictConfig, lang=None):
        """Build the monolingual tokenizer for one language of an aggregate tokenizer.

        Files are registered as artifacts named ``tokenizer.<AGGREGATE_TOKENIZERS_DICT_PREFIX>.<lang>.<key>``.

        Args:
            tokenizer_cfg: Per-language tokenizer config with ``type`` and ``dir``. The aggregate setup passes an
                entry of the resolved ``self.tokenizer_cfg``. For ``wpe`` a ``vocab_path`` key is popped from it
                in place.
            lang: Language key. Required in practice because it is concatenated into the artifact names.

        Returns:
            Tuple ``(tokenizer, model_path, vocab_path, spe_vocab_path)``. ``model_path`` and ``spe_vocab_path``
            are None for ``wpe``. ``spe_vocab_path`` is also None when the SentencePiece vocab file is missing.

        Raises:
            ValueError: If the type is not ``bpe``/``wpe``, or if non-null ``special_tokens`` are given for ``bpe``.
        """
        tokenizer_type = tokenizer_cfg.get('type').lower()
        tokenizer_dir = tokenizer_cfg.get('dir')

        if tokenizer_type not in ['bpe', 'wpe']:
            raise ValueError(
                '`tokenizer.type` must be either `bpe` for SentencePiece tokenizer or '
                '`wpe` for BERT based tokenizer'
            )

        artifact_prefix = 'tokenizer.' + self.AGGREGATE_TOKENIZERS_DICT_PREFIX + '.' + lang

        # defaults
        model_path = None
        vocab_path = None
        spe_vocab_path = None

        if tokenizer_type == 'bpe':
            # This is a BPE Tokenizer
            if 'model_path' in tokenizer_cfg:
                model_path = tokenizer_cfg.get('model_path')
            else:
                model_path = os.path.join(tokenizer_dir, 'tokenizer.model')

            model_path = self.register_artifact(artifact_prefix + '.model_path', model_path)

            if 'special_tokens' in tokenizer_cfg:
                special_tokens = tokenizer_cfg['special_tokens']
                if special_tokens is not None:
                    raise ValueError('`special_tokens` are no longer supported for SentencePiece based tokenizers.')

            tokenizer = tokenizers.SentencePieceTokenizer(model_path=model_path)

            if 'vocab_path' in tokenizer_cfg:
                vocab_path = tokenizer_cfg.get('vocab_path')
            else:
                vocab_path = os.path.join(tokenizer_dir, 'vocab.txt')

            vocab_path = self.register_artifact(artifact_prefix + '.vocab_path', vocab_path)

            try:
                if 'spe_tokenizer_vocab' in tokenizer_cfg:
                    spe_vocab_path = tokenizer_cfg.get('spe_tokenizer_vocab')
                else:
                    spe_vocab_path = os.path.join(tokenizer_dir, 'tokenizer.vocab')

                spe_vocab_path = self.register_artifact(artifact_prefix + '.spe_tokenizer_vocab', spe_vocab_path)
            except FileNotFoundError:
                # fallback case for older checkpoints that did not preserve the tokenizer.vocab
                spe_vocab_path = None

            self._attach_spe_vocab_helpers(tokenizer)

        else:
            # This is a WPE Tokenizer
            if 'vocab_path' in tokenizer_cfg:
                vocab_path = tokenizer_cfg.get('vocab_path')
            else:
                vocab_path = os.path.join(tokenizer_dir, 'vocab.txt')

            vocab_path = self.register_artifact(artifact_prefix + '.vocab_path', vocab_path)

            # If path from previous registration exists, remove it
            if 'vocab_path' in tokenizer_cfg:
                tokenizer_cfg.pop('vocab_path')

            # `hf_kwargs: null` is treated as "no kwargs"
            hf_tokenizer_kwargs = tokenizer_cfg.get('hf_kwargs') or {}
            tokenizer = tokenizers.AutoTokenizer(
                pretrained_model_name=hf_tokenizer_kwargs.get('pretrained_model_name', 'bert-base-cased'),
                vocab_file=vocab_path,
                mask_token=hf_tokenizer_kwargs.get('mask_token', None),
                bos_token=hf_tokenizer_kwargs.get('bos_token', None),
                eos_token=hf_tokenizer_kwargs.get('eos_token', None),
                pad_token=hf_tokenizer_kwargs.get('pad_token', None),
                sep_token=hf_tokenizer_kwargs.get('sep_token', None),
                cls_token=hf_tokenizer_kwargs.get('cls_token', None),
                unk_token=hf_tokenizer_kwargs.get('unk_token', None),
                use_fast=hf_tokenizer_kwargs.get('use_fast', False),
            )

        logging.info(
            'Tokenizer {} initialized with {} tokens'.format(tokenizer.__class__.__name__, tokenizer.vocab_size)
        )

        return tokenizer, model_path, vocab_path, spe_vocab_path

    def _cleanup_monolingual_and_aggregate_config_and_artifacts_if_needed(self):
        """
        Clean up any monolingual config items and artifacts, and any aggregate artifacts.

        We need to do this when we switch from a monolingual tokenizer to an aggregate one, or go between
        aggregate tokenizers which could have a different number of languages.
        """
        if hasattr(self, 'cfg'):
            with open_dict(self.cfg.tokenizer):
                self.cfg.tokenizer.pop('dir', None)
                self.cfg.tokenizer.pop('model_path', None)
                self.cfg.tokenizer.pop('vocab_path', None)
                self.cfg.tokenizer.pop('spe_tokenizer_vocab', None)
                self.cfg.tokenizer.pop('hf_kwargs', None)

        # need to de-register any monolingual artifacts if they exist
        if hasattr(self, 'artifacts'):
            self.artifacts.pop('tokenizer.model_path', None)
            self.artifacts.pop('tokenizer.vocab_path', None)
            self.artifacts.pop('tokenizer.spe_tokenizer_vocab', None)

            # just in case we are replacing one aggregate tokenizer with another one, we better
            # clean up the old aggregate artifacts as well
            agg_prefix = 'tokenizer.' + self.AGGREGATE_TOKENIZERS_DICT_PREFIX + '.'
            for akey in list(self.artifacts.keys()):
                if akey.startswith(agg_prefix):
                    self.artifacts.pop(akey)

    def _cleanup_aggregate_config_and_artifacts_if_needed(self):
        """
        Clean up any aggregate config items and artifacts.

        We need to do this when we switch from an aggregate tokenizer to a monolingual one.
        """
        if hasattr(self, 'cfg'):
            with open_dict(self.cfg.tokenizer):
                self.cfg.tokenizer.pop(self.AGGREGATE_TOKENIZERS_DICT_PREFIX, None)

        # clean up the old aggregate artifacts as well
        if hasattr(self, 'artifacts'):
            agg_prefix = 'tokenizer.' + self.AGGREGATE_TOKENIZERS_DICT_PREFIX + '.'
            for akey in list(self.artifacts.keys()):
                if akey.startswith(agg_prefix):
                    self.artifacts.pop(akey)

    def save_tokenizers(self, directory: str):
        """
        Save the model tokenizer(s) to the specified directory.

        For an aggregate tokenizer, each language's files go to ``<directory>/<lang>``.

        Args:
            directory: The directory to save the tokenizer(s) to.

        Raises:
            RuntimeError: If the model has no ``cfg`` attribute yet.
        """
        if not hasattr(self, 'cfg'):
            raise RuntimeError(
                "The model has not been initialized with a tokenizer yet. Please call the model's "
                "__init__ and _setup_tokenizer methods first."
            )

        if self.tokenizer_type == 'agg':
            langs_cfg = self.cfg.tokenizer[self.AGGREGATE_TOKENIZERS_DICT_PREFIX]
            for lang in self.tokenizer.langs:
                subconfig = langs_cfg.get(lang)
                new_dir = os.path.join(directory, lang)
                self._extract_tokenizer_from_config(subconfig, new_dir)
        else:
            self._extract_tokenizer_from_config(self.cfg.tokenizer, directory)

    def _extract_tokenizer_from_config(self, tokenizer_cfg: DictConfig, dir: str):
        """
        Extract the tokenizer files referenced by the config and write them to ``dir``.

        A config value may be a local path (new model init) or start with ``nemo:`` (restored model).
        Local paths are copied to ``dir``. ``nemo:`` entries are extracted from the .nemo archive the
        model was restored from. The text up to and including the first ``_`` of each extracted file name
        is then stripped.

        Args:
            tokenizer_cfg: The tokenizer config to extract the tokenizer from.
            dir: The directory to write the tokenizer objects to.

        Raises:
            ValueError: If ``nemo:`` entries exist but the model has no ``model_guid`` or no restoration path.
        """
        if not os.path.exists(dir):
            os.makedirs(dir, exist_ok=True)

        nemo_file_objects = []

        for k, v in tokenizer_cfg.items():
            if not isinstance(v, str):
                continue

            # Check if the value is a filepath (new model init) or has `nemo:` in it (restored model)
            if os.path.exists(v):
                # local file from first instantiation
                loc = robust_copy(v, dir)
                logging.info(f"Saved {k} at {loc}")

            if v.startswith('nemo:'):
                nemo_object_name = v[5:]
                nemo_file_objects.append(nemo_object_name)

        if len(nemo_file_objects) > 0:
            logging.debug(f"Copying the following nemo file objects to {dir}: {nemo_file_objects}")

            if not hasattr(self, 'model_guid'):
                raise ValueError(
                    "The model does not have a model_guid attribute. "
                    "Please ensure that the model has been restored from a .nemo file."
                )

            appstate = app_state.AppState()
            restore_path = appstate.get_model_metadata_from_guid(self.model_guid).restoration_path
            if restore_path is None:
                raise ValueError(
                    "The model has not been restored from a .nemo file. Cannot extract the tokenizer "
                    "as the nemo file cannot be located."
                )

            # Read the nemo file without fully extracting all contents
            # we start with an assumption of uncompressed tar,
            # which should be true for versions 1.7.0 and above
            tar_header = "r:"
            try:
                with tarfile.open(restore_path, tar_header):
                    pass
            except tarfile.ReadError:
                # can be older checkpoint => try compressed tar
                tar_header = "r:gz"

            with tarfile.open(restore_path, tar_header) as tar:
                all_members = tar.getmembers()
                for nemo_object_name in nemo_file_objects:
                    members = [x for x in all_members if nemo_object_name in x.name]
                    extracted_members = SaveRestoreConnector._safe_extract(tar, dir, members=members)

                    for member in extracted_members:
                        # Strip everything up to the first "_" (presumably the unique prefix added to
                        # artifact names when they are packed -- confirm against SaveRestoreConnector).
                        _, sep, new_name = member.name.partition("_")
                        if not sep:
                            # Nothing to strip: keep the archived name instead of failing with IndexError.
                            new_name = member.name

                        os.rename(os.path.join(dir, member.name), os.path.join(dir, new_name))

                        logging.info(f"Saved {nemo_object_name} at {os.path.join(dir, new_name)}")

    def _derive_tokenizer_properties(self):
        """Set ``supports_capitalization`` and ``supported_punctuation`` on ``self.tokenizer`` from its vocabulary."""
        vocab = self.tokenizer.tokenizer.get_vocab()

        self.tokenizer.supports_capitalization = bool(extract_capitalized_tokens_from_vocab(vocab))
        self.tokenizer.supported_punctuation = extract_punctuation_from_vocab(vocab)
class ASRModuleMixin(ASRAdapterModelMixin):
    """
    ASRModuleMixin is a mixin class added to ASR models in order to add methods that are specific
    to a particular instantiation of a module inside of an ASRModel.

    Each method should first check that the module is present within the subclass, and support
    additional functionality if the corresponding module is present.
    """

    def change_conv_asr_se_context_window(self, context_window: int, update_config: bool = True):
        """
        Update the context window of the SqueezeExcitation module if the provided model contains an
        `encoder` which is an instance of `ConvASREncoder`.

        Args:
            context_window:  An integer representing the number of input timeframes that will be used
                to compute the context. Each timeframe corresponds to a single window stride of the
                STFT features.

                Say the window_stride = 0.01s, then a context window of 128 represents 128 * 0.01 s
                of context to compute the Squeeze step.
            update_config: Whether to update the config or not with the new context window.
        """
        asr_module_utils.change_conv_asr_se_context_window(
            self, context_window=context_window, update_config=update_config
        )

    def change_attention_model(
        self, self_attention_model: str = None, att_context_size: List[int] = None, update_config: bool = True
    ):
        """
        Update the self_attention_model if function is available in encoder.

        Args:
            self_attention_model (str): type of the attention layer and positional encoding

                'rel_pos':
                    relative positional embedding and Transformer-XL

                'rel_pos_local_attn':
                    relative positional embedding and Transformer-XL with local attention using
                    overlapping windows. Attention context is determined by att_context_size parameter.

                'abs_pos':
                    absolute positional embedding and Transformer

                If None is provided, the self_attention_model isn't changed. Defaults to None.
            att_context_size (List[int]): List of 2 ints corresponding to left and right attention context sizes,
                or None to keep as it is. Defaults to None.
            update_config (bool): Whether to update the config or not with the new attention model.
                Only the arguments that are not None are written to ``cfg.encoder``. Defaults to True.
        """
        if self_attention_model is None and att_context_size is None:
            return

        if not hasattr(self, 'encoder'):
            logging.info(
                "Could not change the self_attention_model in encoder "
                "since the model provided does not contain an `encoder` module in its config."
            )
            return

        if not hasattr(self.encoder, "change_attention_model"):
            logging.info("Model encoder doesn't have a change_attention_model method ")
            return

        self.encoder.change_attention_model(self_attention_model, att_context_size, update_config, self.device)
        if update_config:
            with open_dict(self.cfg):
                # None means "unchanged" (see docstring), so never write None over an existing value.
                if self_attention_model is not None:
                    self.cfg.encoder.self_attention_model = self_attention_model
                if att_context_size is not None:
                    self.cfg.encoder.att_context_size = att_context_size

    def change_subsampling_conv_chunking_factor(
        self, subsampling_conv_chunking_factor: int, update_config: bool = True
    ):
        """
        Update the conv_chunking_factor (int) if function is available in encoder.
        Default is 1 (auto)
        Set it to -1 (disabled) or to a specific value (power of 2) if you OOM in the conv subsampling layers

        Args:
            subsampling_conv_chunking_factor (int): the new chunking factor passed to the encoder.
            update_config (bool): Whether to also store the new value in ``cfg.encoder``. Defaults to True.
        """
        if not hasattr(self, 'encoder'):
            logging.info(
                "Could not call the change_subsampling_conv_chunking_factor method in encoder "
                "since the model provided does not contain an `encoder` module in its config."
            )
            return

        if not hasattr(self.encoder, "change_subsampling_conv_chunking_factor"):
            logging.info("Model encoder doesn't have a change_subsampling_conv_chunking_factor method ")
            return

        self.encoder.change_subsampling_conv_chunking_factor(subsampling_conv_chunking_factor)
        if update_config:
            with open_dict(self.cfg):
                self.cfg.encoder.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor

    def _apply_prompt_to_encoded(self, encoded: Tensor) -> Tensor:
        """Hook for prompt-conditioned subclasses to inject a language prompt into the encoder output.

        Default: no-op. See ``PromptStreamingMixin`` for the prompt-aware override."""
        return encoded

    def conformer_stream_step(
        self,
        processed_signal: Tensor,
        processed_signal_length: Tensor = None,
        cache_last_channel: Tensor = None,
        cache_last_time: Tensor = None,
        cache_last_channel_len: Tensor = None,
        keep_all_outputs: bool = True,
        previous_hypotheses: List[Hypothesis] = None,
        previous_pred_out: Tensor = None,
        drop_extra_pre_encoded: int = None,
        return_transcription: bool = True,
        return_log_probs: bool = False,
        bypass_pre_encode: bool = False,
    ):
        """
        It simulates a forward step with caching for streaming purposes.
        It supports the ASR models where their encoder supports streaming like Conformer.

        Args:
            processed_signal: the input audio signals
            processed_signal_length: the length of the audios
            cache_last_channel: the cache tensor for last channel layers like MHA
            cache_last_channel_len: lengths for cache_last_channel
            cache_last_time: the cache tensor for last time layers like convolutions
            keep_all_outputs: if set to True, would not drop the extra outputs specified by encoder.streaming_cfg.valid_out_len
            previous_hypotheses: the hypotheses from the previous step for RNNT models
            previous_pred_out: the predicted outputs from the previous step for CTC models
            drop_extra_pre_encoded: number of steps to drop from the beginning of the outputs after the downsampling module.
                This can be used if extra paddings are added on the left side of the input.
            return_transcription: whether to decode and return the transcriptions. It can not get disabled for Transducer models.
            return_log_probs: whether to return the log probs, only valid for ctc model
            bypass_pre_encode: forwarded unchanged to ``encoder.cache_aware_stream_step``

        Returns:
            greedy_predictions: the greedy predictions from the decoder
            all_hyp_or_transcribed_texts: the decoder hypotheses for Transducer models and the transcriptions for CTC models
            cache_last_channel_next: the updated tensor cache for last channel layers to be used for next streaming step
            cache_last_time_next: the updated tensor cache for last time layers to be used for next streaming step
            cache_last_channel_next_len: the updated lengths for cache_last_channel
            best_hyp: the best hypotheses for the Transducer models
            log_probs: the logits tensor of current streaming chunk, only returned when return_log_probs=True
                (None when the Transducer decoder is used)
            encoded_len: the length of the output log_probs + history chunk log_probs, only returned when return_log_probs=True

        Raises:
            NotImplementedError: If the model is neither CTC nor RNNT, or its encoder is not a StreamingEncoder.
        """
        if not isinstance(self, asr_models.EncDecRNNTModel) and not isinstance(self, asr_models.EncDecCTCModel):
            raise NotImplementedError(f"stream_step does not support {type(self)}!")

        if not isinstance(self.encoder, StreamingEncoder):
            raise NotImplementedError("Encoder of this model does not support streaming!")

        # Pure CTC models, and hybrid RNNT-CTC models whose active decoder is CTC, take the CTC path below.
        use_ctc_decoding = isinstance(self, asr_models.EncDecCTCModel) or (
            isinstance(self, asr_models.EncDecHybridRNNTCTCModel) and self.cur_decoder == "ctc"
        )

        if not use_ctc_decoding and return_transcription is False:
            logging.info(
                "return_transcription can not be False for Transducer models as decoder returns the transcriptions too."
            )

        if not use_ctc_decoding and return_log_probs is True:
            logging.info("return_log_probs can only be True for CTC models.")

        (
            encoded,
            encoded_len,
            cache_last_channel_next,
            cache_last_time_next,
            cache_last_channel_next_len,
        ) = self.encoder.cache_aware_stream_step(
            processed_signal=processed_signal,
            processed_signal_length=processed_signal_length,
            cache_last_channel=cache_last_channel,
            cache_last_time=cache_last_time,
            cache_last_channel_len=cache_last_channel_len,
            keep_all_outputs=keep_all_outputs,
            drop_extra_pre_encoded=drop_extra_pre_encoded,
            bypass_pre_encode=bypass_pre_encode,
        )
        encoded = self._apply_prompt_to_encoded(encoded)

        if use_ctc_decoding:
            if hasattr(self, "ctc_decoder"):
                decoding = self.ctc_decoding
                decoder = self.ctc_decoder
            else:
                decoding = self.decoding
                decoder = self.decoder

            log_probs = decoder(encoder_output=encoded)
            predictions_tensor = log_probs.argmax(dim=-1, keepdim=False)

            # Concatenate the previous predictions with the current one to have the full predictions.
            # We drop the extra predictions for each sample by using the lengths returned by the encoder (encoded_len)
            # Then create a list of the predictions for the batch. The predictions can have different lengths because of the paddings.
            greedy_predictions = []
            all_hyp_or_transcribed_texts = [] if return_transcription else None

            for preds_idx, preds in enumerate(predictions_tensor):
                preds_cur = preds if encoded_len is None else preds[: encoded_len[preds_idx]]

                if previous_pred_out is not None:
                    greedy_predictions_concat = torch.cat((previous_pred_out[preds_idx], preds_cur), dim=-1)
                    # encoded_len is updated in place so it covers history + current chunk
                    encoded_len[preds_idx] += len(previous_pred_out[preds_idx])
                else:
                    greedy_predictions_concat = preds_cur
                greedy_predictions.append(greedy_predictions_concat)

                # TODO: make decoding more efficient by avoiding the decoding process from the beginning
                if return_transcription:
                    decoded_out = decoding.ctc_decoder_predictions_tensor(
                        decoder_outputs=greedy_predictions_concat.unsqueeze(0),
                        decoder_lengths=encoded_len[preds_idx : preds_idx + 1],
                        return_hypotheses=False,
                    )
                    all_hyp_or_transcribed_texts.append(decoded_out[0])
            best_hyp = None
        else:
            best_hyp = self.decoding.rnnt_decoder_predictions_tensor(
                encoder_output=encoded,
                encoded_lengths=encoded_len,
                return_hypotheses=True,
                partial_hypotheses=previous_hypotheses,
            )
            greedy_predictions = [hyp.y_sequence for hyp in best_hyp]
            all_hyp_or_transcribed_texts = best_hyp
            # The Transducer path produces no CTC log-probs; return None rather than raising
            # UnboundLocalError when return_log_probs=True.
            log_probs = None

        result = [
            greedy_predictions,
            all_hyp_or_transcribed_texts,
            cache_last_channel_next,
            cache_last_time_next,
            cache_last_channel_next_len,
            best_hyp,
        ]
        if return_log_probs:
            result.append(log_probs)
            result.append(encoded_len)

        return tuple(result)

    @torch.no_grad()
    def transcribe_simulate_cache_aware_streaming(
        self,
        paths2audio_files: List[str],
        batch_size: int = 4,
        logprobs: bool = False,
        return_hypotheses: bool = False,
        online_normalization: bool = False,
    ):
        """
        Transcribe audio files by feeding them chunk by chunk through ``conformer_stream_step``.

        Args:
            paths2audio_files: (a list) of paths to audio files.
            batch_size: (int) batch size to use during inference.
                Bigger will result in better throughput performance but would use more memory.
            logprobs: (bool) pass True to get log probabilities instead of transcripts.
            return_hypotheses: (bool) Either return hypotheses or text
                With hypotheses can do some postprocessing like getting timestamp or rescoring
            online_normalization: (bool) Perform normalization on the run per chunk.

        Returns:
            A list of transcriptions, raw log probabilities (if logprobs is True) or Hypothesis objects
            (if return_hypotheses is True), in the same order as paths2audio_files.
            An empty list when no files are given.

        Raises:
            ValueError: If both ``logprobs`` and ``return_hypotheses`` are True.
            NotImplementedError: If the model is not a CTC model or its encoder does not support streaming.
        """
        if paths2audio_files is None or len(paths2audio_files) == 0:
            return []

        if return_hypotheses and logprobs:
            raise ValueError(
                "Either `return_hypotheses` or `logprobs` can be True at any given time. "
                "Returned hypotheses will contain the logprobs."
            )

        if not isinstance(self, asr_models.EncDecCTCModel):
            raise NotImplementedError(f"simulate streaming does not support {type(self)}!")

        if not isinstance(self.encoder, StreamingEncoder):
            raise NotImplementedError("Encoder of this model does not support streaming!")

        data_loader = self._setup_streaming_transcribe_dataloader(paths2audio_files, batch_size, online_normalization)

        need_log_probs = logprobs or return_hypotheses
        total_log_probs = []
        total_texts = []

        for streaming_buffer in data_loader:
            streaming_buffer_iter = iter(streaming_buffer)
            batch_size = len(streaming_buffer.streams_length)
            cache_last_channel, cache_last_time, cache_last_channel_len = self.encoder.get_initial_cache_state(
                batch_size=batch_size
            )
            previous_hypotheses = None
            pred_out_stream = None
            encoded_len = None
            transcribed_texts = None
            batch_log_probs = []

            for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter):
                drop_extra_pre_encoded = self.encoder.streaming_cfg.drop_extra_pre_encoded if step_num != 0 else 0
                with torch.inference_mode():
                    result = self.conformer_stream_step(
                        processed_signal=chunk_audio,
                        processed_signal_length=chunk_lengths,
                        cache_last_channel=cache_last_channel,
                        cache_last_time=cache_last_time,
                        cache_last_channel_len=cache_last_channel_len,
                        keep_all_outputs=streaming_buffer.is_buffer_empty(),
                        previous_hypotheses=previous_hypotheses,
                        previous_pred_out=pred_out_stream,
                        drop_extra_pre_encoded=drop_extra_pre_encoded,
                        return_transcription=True,
                        return_log_probs=need_log_probs,
                    )

                if need_log_probs:
                    (
                        pred_out_stream,
                        transcribed_texts,
                        cache_last_channel,
                        cache_last_time,
                        cache_last_channel_len,
                        previous_hypotheses,
                        cur_chunk_log_probs,
                        encoded_len,
                    ) = result
                    batch_log_probs.append(cur_chunk_log_probs.cpu())
                else:
                    (
                        pred_out_stream,
                        transcribed_texts,
                        cache_last_channel,
                        cache_last_time,
                        cache_last_channel_len,
                        previous_hypotheses,
                    ) = result

            if need_log_probs:
                # concatenate chunk log probs on T dim
                batch_log_probs = torch.cat(batch_log_probs, dim=1)
                for log_probs, log_prob_len in zip(batch_log_probs, encoded_len):
                    total_log_probs.append(log_probs[0:log_prob_len])

            if transcribed_texts is None:
                total_texts += [''] * batch_size
            else:
                total_texts += transcribed_texts

        if logprobs:
            return total_log_probs

        if not return_hypotheses:
            return total_texts

        return [
            Hypothesis(y_sequence=log_probs, text=text, score=0.0, dec_state=None)
            for log_probs, text in zip(total_log_probs, total_texts)
        ]

    def _setup_streaming_transcribe_dataloader(
        self, paths2audio_files: List[str], batch_size: int, online_normalization=False
    ):
        """
        Generator that fills a cache-aware streaming buffer with the provided audio files.

        Args:
            paths2audio_files: (a list) of paths to audio files.
            batch_size: (int) batch size to use during inference. \
                Bigger will result in better throughput performance but would use more memory.
            online_normalization: whether to do online normalization

        Yields:
            The same CacheAwareStreamingAudioBuffer, holding up to ``batch_size`` streams each time.
            The buffer is reset when the consumer resumes the generator, so consume it before advancing.
        """
        from nemo.collections.asr.parts.utils.streaming_utils import CacheAwareStreamingAudioBuffer

        streaming_buffer = CacheAwareStreamingAudioBuffer(model=self, online_normalization=online_normalization)
        last_idx = len(paths2audio_files) - 1
        for sample_idx, sample in enumerate(paths2audio_files):
            streaming_buffer.append_audio_file(sample, stream_id=-1)
            logging.info(f'Added this sample to the buffer: {sample}')
            if (sample_idx + 1) % batch_size == 0 or sample_idx == last_idx:
                logging.info(f"Starting to stream samples {sample_idx - len(streaming_buffer) + 1} to {sample_idx}...")
                yield streaming_buffer
                streaming_buffer.reset_buffer()
class VerificationMixin(ABC):
    """Mixin providing a helper that turns a list of audio paths into a JSON-lines manifest."""

    @staticmethod
    def path2audio_files_to_manifest(paths2audio_files, manifest_filepath):
        """
        Write a manifest file with one JSON entry per audio file.

        Each path is stripped of surrounding whitespace. The entry uses placeholder metadata:
        offset 0.0, no duration, text '-', and label 'infer'.

        Args:
            paths2audio_files: paths to audio fragments to be verified
            manifest_filepath: path of the manifest file to be created (overwritten if it exists)
        """

        def _manifest_line(audio_path):
            # Key order is preserved by json.dumps, so the emitted line layout is stable.
            record = {
                'audio_filepath': audio_path.strip(),
                'offset': 0.0,
                'duration': None,
                'text': '-',
                'label': 'infer',
            }
            return json.dumps(record) + '\n'

        with open(manifest_filepath, 'w', encoding='utf-8') as manifest:
            manifest.writelines(_manifest_line(audio_path) for audio_path in paths2audio_files)
class DiarizationMixin(VerificationMixin):
    """Abstract interface for models that produce speaker labels for audio files."""

    @abstractmethod
    def diarize(self, paths2audio_files: List[str], batch_size: int = 1) -> List[str]:
        """
        Produce speaker labels for the given audio files.

        Concrete diarization models must override this method.

        Args:
            paths2audio_files: paths to audio fragments to be diarized
            batch_size: batch size to use during inference (defaults to 1)

        Returns:
            Speaker labels
        """
class PromptStreamingMixin:
    """Adds language-ID prompt conditioning to a cache-aware streaming ASR model.

    Overrides ``ASRModuleMixin._apply_prompt_to_encoded`` so that
    ``conformer_stream_step`` injects a one-hot language prompt into the encoder
    output. Subclasses must call ``super().initialize_prompt_feature()`` to populate
    ``self.concat``, ``self.num_prompts``, and ``self.prompt_kernel``; they may
    then attach their own decoding / WER objects.
    """

    # Plain class-level defaults document the mixin contract. ``prompt_kernel``
    # is intentionally NOT declared here — it's an ``nn.Module`` and a class-level
    # default would shadow ``nn.Module.__getattr__``'s lookup into ``_modules``
    # after the real Sequential is registered in ``initialize_prompt_feature``.
    concat: bool = False
    num_prompts: int = None  # stays None until ``initialize_prompt_feature`` runs

    def initialize_prompt_feature(self):
        """Populate the attributes ``_apply_prompt_to_encoded`` depends on.

        Subclasses should call ``super().initialize_prompt_feature()`` first, then
        attach their decoding / WER / joint objects. The mixin sets ``self.concat``,
        ``self.num_prompts``, and ``self.prompt_kernel``.
        """
        self.concat = True
        # NOTE(review): ``num_prompts`` is read from the top level of the config, while
        # ``set_inference_prompt`` reads ``prompt_dictionary`` from ``model_defaults`` —
        # confirm this matches the config schema (changing it would alter checkpoint shapes).
        self.num_prompts = self.cfg.get('num_prompts', 128)
        proj_in_size = self.num_prompts + self._cfg.model_defaults.enc_hidden
        proj_out_size = self._cfg.model_defaults.enc_hidden
        # MLP that maps [encoder features ‖ one-hot prompt] back to the encoder width.
        self.prompt_kernel = torch.nn.Sequential(
            torch.nn.Linear(proj_in_size, proj_out_size * 2),
            torch.nn.ReLU(),
            torch.nn.Linear(proj_out_size * 2, proj_out_size),
        )

    def _validate_prompt_index(self, index) -> int:
        """Return ``index`` as an ``int`` after checking it addresses a valid prompt slot.

        The range check against ``num_prompts`` is skipped while ``num_prompts`` is still
        ``None`` (i.e. before ``initialize_prompt_feature`` has run).

        Raises:
            ValueError: if ``index`` lies outside ``[0, num_prompts)``.
        """
        index = int(index)
        if self.num_prompts is not None and not 0 <= index < self.num_prompts:
            raise ValueError(f"Prompt index {index} is out of range for num_prompts={self.num_prompts}.")
        return index

    def set_inference_prompt(self, target_lang: str):
        """
        Set the language prompt for streaming inference.

        Call this before ``conformer_stream_step`` to condition decoding on a
        specific language, following the same pattern as ``change_decoding_strategy``.

        Args:
            target_lang: A key from the model's ``prompt_dictionary``
                (e.g. ``"en-US"``, ``"auto"``).

        Raises:
            ValueError: if ``target_lang`` is not in ``prompt_dictionary``, or if its index
                does not fit into ``num_prompts`` prompt slots. On error the previously
                selected prompt (if any) is left unchanged.
        """
        prompt_dict = self.cfg.model_defaults.get('prompt_dictionary', {})
        if target_lang not in prompt_dict:
            available = list(prompt_dict.keys())
            raise ValueError(
                f"Unknown target language '{target_lang}'. "
                f"Available: {available[:20]}{'...' if len(available) > 20 else ''}"
            )
        # Validate eagerly so a bad dictionary entry fails here rather than deep inside a stream step.
        self._inference_prompt_index = self._validate_prompt_index(prompt_dict[target_lang])
        logging.info(f"Inference prompt set to '{target_lang}' (index {self._inference_prompt_index})")

    def _apply_prompt_to_encoded(self, encoded: Tensor) -> Tensor:
        """
        Inject the language-ID prompt into encoder output during streaming.

        ``encoded`` arrives as (B, D, T) from the encoder cache-aware step.
        Returns the same shape after prompt concatenation + projection. When prompting
        is disabled or no prompt has been selected, ``encoded`` is returned unchanged.
        """
        if not self.concat or not hasattr(self, '_inference_prompt_index'):
            return encoded

        prompt_index = self._validate_prompt_index(self._inference_prompt_index)
        encoded = encoded.transpose(1, 2)  # (B, D, T) -> (B, T, D)
        batch_size, time_steps, _ = encoded.shape
        # Same one-hot prompt for every frame of every utterance in the batch.
        prompt = encoded.new_zeros(batch_size, time_steps, self.num_prompts)
        prompt[:, :, prompt_index] = 1.0
        out_dtype = encoded.dtype
        # Cast back in case the kernel returns a different dtype (presumably under autocast).
        encoded = self.prompt_kernel(torch.cat([encoded, prompt], dim=-1)).to(out_dtype)
        return encoded.transpose(1, 2)  # (B, T, D) -> (B, D, T)