Source code for nemo.collections.asr.parts.submodules.rnnt_beam_decoding

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from tqdm import tqdm

from nemo.collections.asr.modules import rnnt_abstract
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses, is_prefix, select_k_expansions
from nemo.core.classes import Typing, typecheck
from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType
from nemo.utils import logging


def pack_hypotheses(hypotheses: List[Hypothesis]) -> List[Hypothesis]:
    """Convert the token id list of each hypothesis into a LongTensor and move its decoder state to CPU."""
    for idx, hyp in enumerate(hypotheses):  # type: Hypothesis
        hyp.y_sequence = torch.tensor(hyp.y_sequence, dtype=torch.long)

        if hyp.dec_state is not None:
            hyp.dec_state = _states_to_device(hyp.dec_state)

    return hypotheses


def _states_to_device(dec_state, device='cpu'):
    """Recursively move a decoder state (a tensor, or a nested list/tuple of tensors) to the given device."""
    if torch.is_tensor(dec_state):
        dec_state = dec_state.to(device)

    elif isinstance(dec_state, (list, tuple)):
        dec_state = tuple(_states_to_device(dec_i, device) for dec_i in dec_state)

    return dec_state
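
# Illustrative usage sketch (not part of the original module): `_states_to_device` walks a
# tensor or an arbitrarily nested list/tuple of tensors (e.g. the (h, c) state of an LSTM
# decoder) and moves every tensor to the target device, while `pack_hypotheses` converts the
# token id list of each hypothesis into a LongTensor. The helper below is hypothetical and
# exists only to demonstrate the expected behaviour.
def _example_pack_hypotheses_usage():
    h_state = torch.randn(1, 2, 8)  # [L, B, H]
    c_state = torch.randn(1, 2, 8)  # [L, B, H]

    hyp = Hypothesis(score=0.0, y_sequence=[1, 5, 3], dec_state=(h_state, c_state), timestep=[0, 1, 2])
    packed = pack_hypotheses([hyp])[0]

    # Token ids become a LongTensor; the decoder states are returned as a tuple of CPU tensors.
    assert packed.y_sequence.dtype == torch.long
    assert all(state.device.type == 'cpu' for state in packed.dec_state)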


class BeamRNNTInfer(Typing):
    """
    Beam Search implementation ported from ESPNet implementation -
    https://github.com/espnet/espnet/blob/master/espnet/nets/beam_search_transducer.py

    Sequence level beam decoding or batched-beam decoding, performed autoregressively
    depending on the search type chosen.

    Args:
        decoder_model: rnnt_utils.AbstractRNNTDecoder implementation.
        joint_model: rnnt_utils.AbstractRNNTJoint implementation.

        beam_size: number of beams for beam search. Must be a positive integer >= 1.
            If beam size is 1, defaults to stateful greedy search.
            This greedy search might result in slightly different results than
            the greedy results obtained by GreedyRNNTInfer due to implementation differences.

            For accurate greedy results, please use GreedyRNNTInfer or GreedyBatchedRNNTInfer.

        search_type: str representing the type of beam search to perform.
            Must be one of ['default', 'tsd', 'alsd', 'maes']. 'nsc' is currently not supported.

            Algorithm used:

            `default` - basic beam search strategy. Larger beams generally result in better decoding,
                however the time required for the search also grows steadily.

            `tsd` - time synchronous decoding. Please refer to the paper:
                [Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040)
                for details on the algorithm implemented.

                Time synchronous decoding (TSD) execution time grows by the factor T * max_symmetric_expansions.
                For longer sequences, T is greater, and can therefore take a long time for beams to obtain
                good results. This also requires greater memory to execute.

            `alsd` - alignment-length synchronous decoding. Please refer to the paper:
                [Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040)
                for details on the algorithm implemented.

                Alignment-length synchronous decoding (ALSD) execution time is faster than TSD, with a growth
                factor of T + U_max, where U_max is the maximum target length expected during execution.

                Generally, T + U_max < T * max_symmetric_expansions. However, ALSD beams are non-unique,
                therefore it is required to use larger beam sizes to achieve the same (or close to the same)
                decoding accuracy as TSD.

                For a given decoding accuracy, it is possible to attain faster decoding via ALSD than TSD.

            `maes` - modified adaptive expansion search. Please refer to the paper:
                [Accelerating RNN Transducer Inference via Adaptive Expansion Search](https://ieeexplore.ieee.org/document/9250505)

                Modified Adaptive Synchronous Decoding (mAES) execution time is adaptive w.r.t the
                number of expansions (for tokens) required per timestep. The number of expansions can usually
                be constrained to 1 or 2, and in most cases 2 is sufficient.

                This beam search technique can possibly obtain superior WER while sacrificing some evaluation time.

        score_norm: bool, whether to normalize the scores of the log probabilities.

        return_best_hypothesis: bool, decides whether to return a single hypothesis (the best out of N),
            or return all N hypotheses (sorted with best score first). The container class changes based
            on this flag -
            When set to True (default), returns a single Hypothesis.
            When set to False, returns a NBestHypotheses container, which contains a list of Hypothesis.

        # The following arguments are specific to the chosen `search_type`

        tsd_max_sym_exp_per_step: Used for `search_type=tsd`. The maximum symmetric expansions allowed
            per timestep during beam search. Larger values should be used to attempt decoding of longer
            sequences, but this in turn increases execution time and memory usage.

        alsd_max_target_len: Used for `search_type=alsd`. The maximum expected target sequence length
            during beam search. Larger values allow decoding of longer sequences at the expense of
            execution time and memory.

        # The following two flags are placeholders and unused until `nsc` implementation is stabilized.
        nsc_max_timesteps_expansion: Unused int.

        nsc_prefix_alpha: Unused int.

        # mAES flags
        maes_num_steps: Number of adaptive steps to take. From the paper, 2 steps is generally sufficient. int > 1.

        maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to keep this as 1
            in order to reduce expensive beam search cost later. int >= 0.

        maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size.
            Effectively, the number of hypotheses = beam_size + maes_expansion_beta. Must be an int >= 0,
            and affects the speed of inference since large values will perform a large beam search in the next step.

        maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions.
            The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v])
            where v is all vocabulary indices in the Vocab set and max_log_prob is the "most" likely token to be
            predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for
            expansion apart from the "most likely" candidate.
            Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed
            but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value,
            thereby reducing speed but potentially improving accuracy). This is a hyperparameter to be experimentally
            tuned on a validation set.

        softmax_temperature: Scales the logits of the joint prior to computing log_softmax.

        preserve_alignments: Bool flag which preserves the history of alignments generated during
            beam decoding (sample). When set to true, the Hypothesis will contain the non-null value
            for `alignments` in it. Here, `alignments` is a List of List of ints.

            The length of the list corresponds to the Acoustic Length (T).
            Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary.
            U is the number of target tokens for the current timestep Ti.

            NOTE: `preserve_alignments` is an invalid argument for any `search_type`
            other than basic beam search.
    """

    @property
    def input_types(self):
        """Returns definitions of module input ports."""
        return {
            "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
            "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
            "partial_hypotheses": [NeuralType(elements_type=HypothesisType(), optional=True)],  # must always be last
        }

    @property
    def output_types(self):
        """Returns definitions of module output ports."""
        return {"predictions": [NeuralType(elements_type=HypothesisType())]}

    def __init__(
        self,
        decoder_model: rnnt_abstract.AbstractRNNTDecoder,
        joint_model: rnnt_abstract.AbstractRNNTJoint,
        beam_size: int,
        search_type: str = 'default',
        score_norm: bool = True,
        return_best_hypothesis: bool = True,
        tsd_max_sym_exp_per_step: Optional[int] = 50,
        alsd_max_target_len: Union[int, float] = 1.0,
        nsc_max_timesteps_expansion: int = 1,
        nsc_prefix_alpha: int = 1,
        maes_num_steps: int = 2,
        maes_prefix_alpha: int = 1,
        maes_expansion_gamma: float = 2.3,
        maes_expansion_beta: int = 2,
        language_model: Optional[Dict[str, Any]] = None,
        softmax_temperature: float = 1.0,
        preserve_alignments: bool = False,
    ):
        self.decoder = decoder_model
        self.joint = joint_model

        self.blank = decoder_model.blank_idx
        self.vocab_size = decoder_model.vocab_size
        self.search_type = search_type
        self.return_best_hypothesis = return_best_hypothesis

        if beam_size < 1:
            raise ValueError("Beam search size cannot be less than 1!")

        self.beam_size = beam_size
        self.score_norm = score_norm

        if self.beam_size == 1:
            logging.info("Beam size of 1 was used, switching to sample level `greedy_search`")
            self.search_algorithm = self.greedy_search
        elif search_type == "default":
            self.search_algorithm = self.default_beam_search
        elif search_type == "tsd":
            self.search_algorithm = self.time_sync_decoding
        elif search_type == "alsd":
            self.search_algorithm = self.align_length_sync_decoding
        elif search_type == "nsc":
            raise NotImplementedError("`nsc` (Constrained Beam Search) has not been implemented.")
            # self.search_algorithm = self.nsc_beam_search
        elif search_type == "maes":
            self.search_algorithm = self.modified_adaptive_expansion_search
        else:
            raise NotImplementedError(
                f"The search type ({search_type}) supplied is not supported!\n"
                f"Please use one of : (default, tsd, alsd, maes)"
            )

        if tsd_max_sym_exp_per_step is None:
            tsd_max_sym_exp_per_step = -1

        if search_type in ['tsd', 'alsd', 'nsc'] and not self.decoder.blank_as_pad:
            raise ValueError(
                f"Search type was chosen as '{search_type}', however the decoder module provided "
                f"does not support the `blank` token as a pad value. {search_type} requires "
                f"the blank token as pad value support in order to perform batched beam search. "
                f"Please choose one of the other beam search methods, or re-train your model "
                f"with this support."
            )

        self.tsd_max_symmetric_expansion_per_step = tsd_max_sym_exp_per_step
        self.alsd_max_target_length = alsd_max_target_len
        self.nsc_max_timesteps_expansion = nsc_max_timesteps_expansion
        self.nsc_prefix_alpha = int(nsc_prefix_alpha)
        self.maes_prefix_alpha = int(maes_prefix_alpha)
        self.maes_num_steps = int(maes_num_steps)
        self.maes_expansion_gamma = float(maes_expansion_gamma)
        self.maes_expansion_beta = int(maes_expansion_beta)

        if self.maes_prefix_alpha < 0:
            raise ValueError("`maes_prefix_alpha` must be a non-negative integer.")

        if self.maes_num_steps < 2:
            raise ValueError("`maes_num_steps` must be greater than 1.")

        if softmax_temperature != 1.0 and language_model is not None:
            logging.warning(
                "Softmax temperature is not supported with LM decoding. "
                "Setting softmax-temperature value to 1.0."
            )
            self.softmax_temperature = 1.0
        else:
            self.softmax_temperature = softmax_temperature

        self.language_model = language_model
        self.preserve_alignments = preserve_alignments

    @typecheck()
    def __call__(
        self,
        encoder_output: torch.Tensor,
        encoded_lengths: torch.Tensor,
        partial_hypotheses: Optional[List[Hypothesis]] = None,
    ) -> Union[Hypothesis, NBestHypotheses]:
        """Perform general beam search.

        Args:
            encoder_output: Encoded speech features (B, T_max, D_enc)
            encoded_lengths: Lengths of the encoder outputs

        Returns:
            Either a list containing a single Hypothesis (when `return_best_hypothesis=True`),
            otherwise a list containing a single NBestHypotheses, which itself contains a list of
            Hypothesis. This list is sorted such that the best hypothesis is the first element.
        """
        # Preserve decoder and joint training state
        decoder_training_state = self.decoder.training
        joint_training_state = self.joint.training

        with torch.no_grad():
            # Apply optional preprocessing
            encoder_output = encoder_output.transpose(1, 2)  # (B, T, D)

            self.decoder.eval()
            self.joint.eval()

            hypotheses = []
            with tqdm(
                range(encoder_output.size(0)),
                desc='Beam search progress:',
                total=encoder_output.size(0),
                unit='sample',
            ) as idx_gen:

                # Freeze the decoder and joint to prevent recording of gradients
                # during the beam loop.
                with self.decoder.as_frozen(), self.joint.as_frozen():

                    _p = next(self.joint.parameters())
                    dtype = _p.dtype

                    # Decode every sample in the batch independently.
                    for batch_idx in idx_gen:
                        inseq = encoder_output[batch_idx : batch_idx + 1, : encoded_lengths[batch_idx], :]  # [1, T, D]
                        logitlen = encoded_lengths[batch_idx]

                        if inseq.dtype != dtype:
                            inseq = inseq.to(dtype=dtype)

                        # Extract partial hypothesis if it exists
                        partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None

                        # Execute the specific search strategy
                        nbest_hyps = self.search_algorithm(
                            inseq, logitlen, partial_hypotheses=partial_hypothesis
                        )  # sorted list of hypotheses

                        # Prepare the list of hypotheses
                        nbest_hyps = pack_hypotheses(nbest_hyps)

                        # Pack the result
                        if self.return_best_hypothesis:
                            best_hypothesis = nbest_hyps[0]  # type: Hypothesis
                        else:
                            best_hypothesis = NBestHypotheses(nbest_hyps)  # type: NBestHypotheses
                        hypotheses.append(best_hypothesis)

        self.decoder.train(decoder_training_state)
        self.joint.train(joint_training_state)

        return (hypotheses,)

    def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
        """Sort hypotheses by score or score given sequence length.

        Args:
            hyps: list of hypotheses

        Return:
            hyps: sorted list of hypotheses
        """
        if self.score_norm:
            return sorted(hyps, key=lambda x: x.score / len(x.y_sequence), reverse=True)
        else:
            return sorted(hyps, key=lambda x: x.score, reverse=True)

    def time_sync_decoding(
        self, h: torch.Tensor, encoded_lengths: torch.Tensor, partial_hypotheses: Optional[Hypothesis] = None
    ) -> List[Hypothesis]:
        """Time synchronous beam search implementation.
        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            h: Encoded speech features (1, T_max, D_enc)

        Returns:
            nbest_hyps: N-best decoding results
        """
        if self.preserve_alignments:
            raise NotImplementedError("`preserve_alignments` is not implemented for Time-Synchronous Decoding.")

        if partial_hypotheses is not None:
            raise NotImplementedError("`partial_hypotheses` is not supported for Time-Synchronous Decoding.")

        # Precompute some constants for blank position
        ids = list(range(self.vocab_size + 1))
        ids.remove(self.blank)

        # Used when blank token is first vs last token
        if self.blank == 0:
            index_incr = 1
        else:
            index_incr = 0

        # prepare the batched beam states
        beam = min(self.beam_size, self.vocab_size)
        beam_state = self.decoder.initialize_state(
            torch.zeros(beam, device=h.device, dtype=h.dtype)
        )  # [L, B, H], [L, B, H] (for LSTMs)

        # Initialize first hypothesis for the beam (blank)
        B = [
            Hypothesis(
                y_sequence=[self.blank],
                score=0.0,
                dec_state=self.decoder.batch_select_state(beam_state, 0),
                timestep=[-1],
                length=0,
            )
        ]
        cache = {}

        for i in range(int(encoded_lengths)):
            hi = h[:, i : i + 1, :]

            # Update caches
            A = []
            C = B

            h_enc = hi

            # For a limited number of symmetric expansions per timestep "i"
            for v in range(self.tsd_max_symmetric_expansion_per_step):
                D = []

                # Decode a batch of beam states and scores
                beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score_hypothesis(C, cache, beam_state)

                # Extract the log probabilities and the predicted tokens
                beam_logp = torch.log_softmax(
                    self.joint.joint(h_enc, beam_y) / self.softmax_temperature, dim=-1
                )  # [B, 1, 1, V + 1]
                beam_logp = beam_logp[:, 0, 0, :]  # [B, V + 1]
                beam_topk = beam_logp[:, ids].topk(beam, dim=-1)

                seq_A = [h.y_sequence for h in A]

                for j, hyp in enumerate(C):
                    # create a new hypothesis in A
                    if hyp.y_sequence not in seq_A:
                        # If the sequence is not in seq_A, add it as the blank token
                        # In this step, we don't add a token but simply update the score
                        A.append(
                            Hypothesis(
                                score=(hyp.score + float(beam_logp[j, self.blank])),
                                y_sequence=hyp.y_sequence[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                                timestep=hyp.timestep[:],
                                length=encoded_lengths,
                            )
                        )
                    else:
                        # merge the existing blank hypothesis score with the current score.
                        dict_pos = seq_A.index(hyp.y_sequence)

                        A[dict_pos].score = np.logaddexp(
                            A[dict_pos].score, (hyp.score + float(beam_logp[j, self.blank]))
                        )

                if v < self.tsd_max_symmetric_expansion_per_step:
                    for j, hyp in enumerate(C):
                        # for each current hypothesis j
                        # extract the top token score and top token id for the jth hypothesis
                        for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + index_incr):
                            # create new hypothesis and store in D
                            # Note: This loop does *not* include the blank token!
                            new_hyp = Hypothesis(
                                score=(hyp.score + float(logp)),
                                y_sequence=(hyp.y_sequence + [int(k)]),
                                dec_state=self.decoder.batch_select_state(beam_state, j),
                                lm_state=hyp.lm_state,
                                timestep=hyp.timestep[:] + [i],
                                length=encoded_lengths,
                            )

                            D.append(new_hyp)

                # Prune beam
                C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]

            # Prune beam
            B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]

        return self.sort_nbest(B)

    def align_length_sync_decoding(
        self, h: torch.Tensor, encoded_lengths: torch.Tensor, partial_hypotheses: Optional[Hypothesis] = None
    ) -> List[Hypothesis]:
        """Alignment-length synchronous beam search implementation.
        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            h: Encoded speech features (1, T_max, D_enc)

        Returns:
            nbest_hyps: N-best decoding results
        """
        if self.preserve_alignments:
            raise NotImplementedError(
                "`preserve_alignments` is not implemented for Alignment-length Synchronous Decoding."
            )

        if partial_hypotheses is not None:
            raise NotImplementedError(
                "`partial_hypotheses` is not supported for Alignment-length Synchronous Decoding."
            )

        # Precompute some constants for blank position
        ids = list(range(self.vocab_size + 1))
        ids.remove(self.blank)

        # Used when blank token is first vs last token
        if self.blank == 0:
            index_incr = 1
        else:
            index_incr = 0

        # prepare the batched beam states
        beam = min(self.beam_size, self.vocab_size)

        h = h[0]  # [T, D]
        h_length = int(encoded_lengths)
        beam_state = self.decoder.initialize_state(
            torch.zeros(beam, device=h.device, dtype=h.dtype)
        )  # [L, B, H], [L, B, H] (for LSTMs)

        # compute u_max as either a specific static limit,
        # or a multiple of current `h_length` dynamically.
        if type(self.alsd_max_target_length) == float:
            u_max = int(self.alsd_max_target_length * h_length)
        else:
            u_max = int(self.alsd_max_target_length)

        # Initialize first hypothesis for the beam (blank)
        B = [
            Hypothesis(
                y_sequence=[self.blank],
                score=0.0,
                dec_state=self.decoder.batch_select_state(beam_state, 0),
                timestep=[-1],
                length=0,
            )
        ]

        final = []
        cache = {}

        # ALSD runs for T + U_max steps
        for i in range(h_length + u_max):
            # Update caches
            A = []
            B_ = []
            h_states = []

            # preserve the list of batch indices which are added into the list
            # and those which are removed from the list
            # This is necessary to perform state updates in the correct batch indices later
            batch_ids = list(range(len(B)))  # initialize as a list of all batch ids
            batch_removal_ids = []  # update with sample ids which are removed

            for bid, hyp in enumerate(B):
                u = len(hyp.y_sequence) - 1
                t = i - u

                if t > (h_length - 1):
                    batch_removal_ids.append(bid)
                    continue

                B_.append(hyp)
                h_states.append((t, h[t]))

            if B_:
                # Compute the subset of batch ids which were *not* removed from the list above
                sub_batch_ids = None
                if len(B_) != beam:
                    sub_batch_ids = batch_ids
                    for id in batch_removal_ids:
                        # sub_batch_ids contains list of ids *that were not removed*
                        sub_batch_ids.remove(id)

                    # extract the states of the sub batch only.
                    beam_state_ = [beam_state[state_id][:, sub_batch_ids, :] for state_id in range(len(beam_state))]
                else:
                    # If entire batch was used (none were removed), simply take all the states
                    beam_state_ = beam_state

                # Decode a batch/sub-batch of beam states and scores
                beam_y, beam_state_, beam_lm_tokens = self.decoder.batch_score_hypothesis(B_, cache, beam_state_)

                # If only a subset of batch ids were updated (some were removed)
                if sub_batch_ids is not None:
                    # For each state in the RNN (2 for LSTM)
                    for state_id in range(len(beam_state)):
                        # Update the current batch states with the sub-batch states (in the correct indices)
                        # These indices are specified by sub_batch_ids, the ids of samples which were updated.
                        beam_state[state_id][:, sub_batch_ids, :] = beam_state_[state_id][...]
                else:
                    # If entire batch was updated, simply update all the states
                    beam_state = beam_state_

                # h_states = list of [t, h[t]]
                # so h[1] here is a h[t] of shape [D]
                # Simply stack all of the h[t] within the sub_batch/batch (T <= beam)
                h_enc = torch.stack([h[1] for h in h_states])  # [T=beam, D]
                h_enc = h_enc.unsqueeze(1)  # [B=beam, T=1, D]; batch over the beams

                # Extract the log probabilities and the predicted tokens
                beam_logp = torch.log_softmax(
                    self.joint.joint(h_enc, beam_y) / self.softmax_temperature, dim=-1
                )  # [B=beam, 1, 1, V + 1]
                beam_logp = beam_logp[:, 0, 0, :]  # [B=beam, V + 1]
                beam_topk = beam_logp[:, ids].topk(beam, dim=-1)

                for j, hyp in enumerate(B_):
                    # For all updated samples in the batch, add it as the blank token
                    # In this step, we don't add a token but simply update the score
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(beam_logp[j, self.blank])),
                        y_sequence=hyp.y_sequence[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                        timestep=hyp.timestep[:],
                        length=i,
                    )

                    # Add blank prediction to A
                    A.append(new_hyp)

                    # If the prediction "timestep" t has reached the length of the input sequence
                    # we can add it to the "finished" hypothesis list.
                    if h_states[j][0] == (h_length - 1):
                        final.append(new_hyp)

                    # Here, we carefully select the indices of the states that we want to preserve
                    # for the next token (non-blank) update.
                    if sub_batch_ids is not None:
                        h_states_idx = sub_batch_ids[j]
                    else:
                        h_states_idx = j

                    # for each current hypothesis j
                    # extract the top token score and top token id for the jth hypothesis
                    for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + index_incr):
                        # create new hypothesis and store in A
                        # Note: This loop does *not* include the blank token!
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            y_sequence=(hyp.y_sequence[:] + [int(k)]),
                            dec_state=self.decoder.batch_select_state(beam_state, h_states_idx),
                            lm_state=hyp.lm_state,
                            timestep=hyp.timestep[:] + [i],
                            length=i,
                        )

                        A.append(new_hyp)

                # Prune and recombine same hypotheses
                # This may cause the next beam to be smaller than the max beam size
                # Therefore larger beam sizes may be required for better decoding.
                B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
                B = self.recombine_hypotheses(B)

            # If B_ is an empty list, then we may be able to exit early
            elif len(batch_ids) == len(batch_removal_ids):
                break

        if final:
            return self.sort_nbest(final)
        else:
            return B

    def recombine_hypotheses(self, hypotheses: List[Hypothesis]) -> List[Hypothesis]:
        """Recombine hypotheses with equivalent output sequence.

        Args:
            hypotheses (list): list of hypotheses

        Returns:
            final (list): list of recombined hypotheses
        """
        final = []

        for hyp in hypotheses:
            seq_final = [f.y_sequence for f in final if f.y_sequence]

            if hyp.y_sequence in seq_final:
                seq_pos = seq_final.index(hyp.y_sequence)

                final[seq_pos].score = np.logaddexp(final[seq_pos].score, hyp.score)
            else:
                final.append(hyp)

        return final


@dataclass
class BeamRNNTInferConfig:
    beam_size: int
    search_type: str = 'default'
    score_norm: bool = True
    return_best_hypothesis: bool = True
    tsd_max_sym_exp_per_step: Optional[int] = 50
    alsd_max_target_len: float = 1.0
    nsc_max_timesteps_expansion: int = 1
    nsc_prefix_alpha: int = 1
    maes_num_steps: int = 2
    maes_prefix_alpha: int = 1
    maes_expansion_gamma: float = 2.3
    maes_expansion_beta: int = 2
    language_model: Optional[Dict[str, Any]] = None
    softmax_temperature: float = 1.0
    preserve_alignments: bool = False
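

# Illustrative usage sketch (not part of the original module): the dataclass above mirrors the
# constructor arguments of `BeamRNNTInfer`, so a config can be mapped onto the class directly.
# The `decoder` and `joint` arguments are assumed to be the prediction and joint networks of an
# already-instantiated RNNT model; the helper name below is hypothetical.
def _example_beam_infer_from_config(
    decoder: rnnt_abstract.AbstractRNNTDecoder, joint: rnnt_abstract.AbstractRNNTJoint, cfg: BeamRNNTInferConfig
) -> BeamRNNTInfer:
    return BeamRNNTInfer(
        decoder_model=decoder,
        joint_model=joint,
        beam_size=cfg.beam_size,
        search_type=cfg.search_type,
        score_norm=cfg.score_norm,
        return_best_hypothesis=cfg.return_best_hypothesis,
        tsd_max_sym_exp_per_step=cfg.tsd_max_sym_exp_per_step,
        alsd_max_target_len=cfg.alsd_max_target_len,
        maes_num_steps=cfg.maes_num_steps,
        maes_prefix_alpha=cfg.maes_prefix_alpha,
        maes_expansion_gamma=cfg.maes_expansion_gamma,
        maes_expansion_beta=cfg.maes_expansion_beta,
        softmax_temperature=cfg.softmax_temperature,
        preserve_alignments=cfg.preserve_alignments,
    )


# Usage (sketch): `infer = _example_beam_infer_from_config(model.decoder, model.joint, cfg)` followed by
# `(hyps,) = infer(encoder_output=encoded, encoded_lengths=encoded_len)` on the outputs of the encoder.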