Source code for nemo.collections.asr.parts.submodules.ctc_beam_decoding

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Union

import torch

from nemo.collections.asr.parts.utils import rnnt_utils
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.core.classes import Typing, typecheck
from nemo.core.neural_types import HypothesisType, LengthsType, LogprobsType, NeuralType
from nemo.utils import logging

DEFAULT_TOKEN_OFFSET = 100
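# Illustrative sketch (not part of the module): for subword models the decoders below map each
# vocabulary index to a single unicode character by shifting it with this offset, and recover the
# index with ord() when parsing the decoded text, e.g.:
#   encoded = [chr(idx + DEFAULT_TOKEN_OFFSET) for idx in range(4)]  # ['d', 'e', 'f', 'g']
#   decoded = [ord(c) - DEFAULT_TOKEN_OFFSET for c in encoded]       # [0, 1, 2, 3]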


def pack_hypotheses(
    hypotheses: List[rnnt_utils.NBestHypotheses], logitlen: torch.Tensor,
) -> List[rnnt_utils.NBestHypotheses]:

    if logitlen is not None:
        if hasattr(logitlen, 'cpu'):
            logitlen_cpu = logitlen.to('cpu')
        else:
            logitlen_cpu = logitlen

    for idx, hyp in enumerate(hypotheses):  # type: rnnt_utils.NBestHypotheses
        for candidate_idx, cand in enumerate(hyp.n_best_hypotheses):
            cand.y_sequence = torch.tensor(cand.y_sequence, dtype=torch.long)

            if logitlen is not None:
                cand.length = logitlen_cpu[idx]

            if cand.dec_state is not None:
                cand.dec_state = _states_to_device(cand.dec_state)

    return hypotheses


def _states_to_device(dec_state, device='cpu'):
    if torch.is_tensor(dec_state):
        dec_state = dec_state.to(device)

    elif isinstance(dec_state, (list, tuple)):
        dec_state = tuple(_states_to_device(dec_i, device) for dec_i in dec_state)

    return dec_state
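
# Illustrative example (an assumption, not part of the API surface): given a nested decoder state,
# every tensor is moved to the target device and list/tuple containers come back as tuples:
#   state = (torch.zeros(2, 3), [torch.ones(4)])
#   _states_to_device(state, device='cpu')  # -> (tensor of shape [2, 3], (tensor of shape [4],))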


class AbstractBeamCTCInfer(Typing):
    """A beam CTC decoder.

    Provides a common abstraction for sample level beam decoding.

    Args:
        blank_id: int, index of the blank token. Can be 0 or len(vocabulary).
        beam_size: int, size of the beam used in the underlying beam search engine.

    """

    @property
    def input_types(self):
        """Returns definitions of module input ports.
        """
        return {
            "decoder_output": NeuralType(('B', 'T', 'D'), LogprobsType()),
            "decoder_lengths": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        """Returns definitions of module output ports.
        """
        return {"predictions": [NeuralType(elements_type=HypothesisType())]}

    def __init__(self, blank_id: int, beam_size: int):
        self.blank_id = blank_id

        if beam_size < 1:
            raise ValueError("Beam search size cannot be less than 1!")

        self.beam_size = beam_size

        # Variables set by corresponding setter methods
        self.vocab = None
        self.decoding_type = None
        self.tokenizer = None

        # Utility maps for vocabulary
        self.vocab_index_map = None
        self.index_vocab_map = None

        # Internal variable, used to prevent double reduction of consecutive tokens (ctc collapse)
        self.override_fold_consecutive_value = None

    def set_vocabulary(self, vocab: List[str]):
        """
        Set the vocabulary of the decoding framework.

        Args:
            vocab: List of str. Each token corresponds to its location in the vocabulary emitted by the model.
                Note that this vocabulary must NOT contain the "BLANK" token.
        """
        self.vocab = vocab
        self.vocab_index_map = {v: i for i, v in enumerate(vocab)}
        self.index_vocab_map = {i: v for i, v in enumerate(vocab)}

    def set_decoding_type(self, decoding_type: str):
        """
        Sets the decoding type of the framework. Can support either char or subword models.

        Args:
            decoding_type: Str corresponding to decoding type. Only supports "char" and "subword".
        """
        decoding_type = decoding_type.lower()
        supported_types = ['char', 'subword']

        if decoding_type not in supported_types:
            raise ValueError(
                f"Unsupported decoding type. Supported types = {supported_types}.\n" f"Given = {decoding_type}"
            )

        self.decoding_type = decoding_type

    def set_tokenizer(self, tokenizer: TokenizerSpec):
        """
        Set the tokenizer of the decoding framework.

        Args:
            tokenizer: NeMo tokenizer object, which inherits from TokenizerSpec.
        """
        self.tokenizer = tokenizer

    @typecheck()
    def forward(
        self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor,
    ) -> Tuple[List[Union[rnnt_utils.Hypothesis, rnnt_utils.NBestHypotheses]]]:
        """Returns a list of hypotheses given an input batch of the encoder hidden embedding.
        Output token is generated auto-regressively.

        Args:
            decoder_output: A tensor of size (batch, timesteps, features) or (batch, timesteps) (each timestep is a label).
            decoder_lengths: list of int representing the length of each output sequence.

        Returns:
            packed list containing batch number of sentences (Hypotheses).
        """
        raise NotImplementedError()

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class BeamCTCInfer(AbstractBeamCTCInfer):
    """A beam CTC decoder.

    Provides a common abstraction for sample level and batch level beam decoding.

    Args:
        blank_id: int, index of the blank token. Can be 0 or len(vocabulary).
        beam_size: int, size of the beam used in the underlying beam search engine.
        preserve_alignments: Bool flag which preserves the history of logprobs generated during
            decoding (sample / batched). When set to true, the Hypothesis will contain
            the non-null value for `logprobs` in it. Here, `logprobs` is a torch.Tensor.
        compute_timestamps: A bool flag, which determines whether to compute the character/subword,
            or word based timestamp mapping the output log-probabilities to discrete intervals of
            timestamps. The timestamps will be available in the returned Hypothesis.timestep as a
            dictionary.
    """

    def __init__(
        self,
        blank_id: int,
        beam_size: int,
        search_type: str = "default",
        return_best_hypothesis: bool = True,
        preserve_alignments: bool = False,
        compute_timestamps: bool = False,
        beam_alpha: float = 1.0,
        beam_beta: float = 0.0,
        kenlm_path: str = None,
        flashlight_cfg: Optional['FlashlightConfig'] = None,
        pyctcdecode_cfg: Optional['PyCTCDecodeConfig'] = None,
    ):
        super().__init__(blank_id=blank_id, beam_size=beam_size)

        self.search_type = search_type
        self.return_best_hypothesis = return_best_hypothesis
        self.preserve_alignments = preserve_alignments
        self.compute_timestamps = compute_timestamps

        if self.compute_timestamps:
            raise ValueError("Currently this flag is not supported for beam search algorithms.")

        self.vocab = None  # This must be set by the user via set_vocabulary() before calling forward()!

        if search_type == "default" or search_type == "nemo":
            self.search_algorithm = self.default_beam_search
        elif search_type == "pyctcdecode":
            self.search_algorithm = self._pyctcdecode_beam_search
        elif search_type == "flashlight":
            self.search_algorithm = self.flashlight_beam_search
        else:
            raise NotImplementedError(
                f"The search type ({search_type}) supplied is not supported!\n"
                f"Please use one of : (default, nemo, pyctcdecode, flashlight)"
            )

        # Log the beam search algorithm
        logging.info(f"Beam search algorithm: {search_type}")

        self.beam_alpha = beam_alpha
        self.beam_beta = beam_beta

        # Default beam search args
        self.kenlm_path = kenlm_path

        # PyCTCDecode params
        if pyctcdecode_cfg is None:
            pyctcdecode_cfg = PyCTCDecodeConfig()
        self.pyctcdecode_cfg = pyctcdecode_cfg  # type: PyCTCDecodeConfig

        if flashlight_cfg is None:
            flashlight_cfg = FlashlightConfig()
        self.flashlight_cfg = flashlight_cfg

        # Default beam search scorer functions
        self.default_beam_scorer = None
        self.pyctcdecode_beam_scorer = None
        self.flashlight_beam_scorer = None
        self.token_offset = 0

    @typecheck()
    def forward(
        self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor,
    ) -> Tuple[List[Union[rnnt_utils.Hypothesis, rnnt_utils.NBestHypotheses]]]:
        """Returns a list of hypotheses given an input batch of the encoder hidden embedding.
        Output token is generated auto-regressively.

        Args:
            decoder_output: A tensor of size (batch, timesteps, features).
            decoder_lengths: list of int representing the length of each output sequence.

        Returns:
            packed list containing batch number of sentences (Hypotheses).
        """
        if self.vocab is None:
            raise RuntimeError("Please set the vocabulary with `set_vocabulary()` before calling this function.")

        if self.decoding_type is None:
            raise ValueError("Please set the decoding type with `set_decoding_type()` before calling this function.")

        with torch.no_grad(), torch.inference_mode():
            # Process each sequence independently
            prediction_tensor = decoder_output

            if prediction_tensor.ndim != 3:
                raise ValueError(
                    f"`decoder_output` must be a tensor of shape [B, T, V] (log probs, float). "
                    f"Provided shape = {prediction_tensor.shape}"
                )

            # determine type of input - logprobs or labels
            out_len = decoder_lengths if decoder_lengths is not None else None
            hypotheses = self.search_algorithm(prediction_tensor, out_len)

            # Pack results into Hypotheses
            packed_result = pack_hypotheses(hypotheses, decoder_lengths)

            # Pack the result
            if self.return_best_hypothesis and isinstance(packed_result[0], rnnt_utils.NBestHypotheses):
                packed_result = [res.n_best_hypotheses[0] for res in packed_result]  # type: Hypothesis

        return (packed_result,)

    @torch.no_grad()
    def default_beam_search(
        self, x: torch.Tensor, out_len: torch.Tensor
    ) -> List[Union[rnnt_utils.Hypothesis, rnnt_utils.NBestHypotheses]]:
        """
        Open Seq2Seq Beam Search Algorithm (DeepSpeech)

        Args:
            x: Tensor of shape [B, T, V+1], where B is the batch size, T is the maximum sequence length,
                and V is the vocabulary size. The tensor contains log-probabilities.
            out_len: Tensor of shape [B], contains lengths of each sequence in the batch.

        Returns:
            A list of NBestHypotheses objects, one for each sequence in the batch.
        """
        if self.compute_timestamps:
            raise ValueError(
                f"Beam Search with strategy `{self.search_type}` does not support time stamp calculation!"
            )

        if self.default_beam_scorer is None:
            # Check for filepath
            if self.kenlm_path is None or not os.path.exists(self.kenlm_path):
                raise FileNotFoundError(
                    f"KenLM binary file not found at : {self.kenlm_path}. "
                    f"Please set a valid path in the decoding config."
                )

            # perform token offset for subword models
            if self.decoding_type == 'subword':
                vocab = [chr(idx + self.token_offset) for idx in range(len(self.vocab))]
            else:
                # char models
                vocab = self.vocab

            # Must import at runtime to avoid circular dependency due to module level import.
            from nemo.collections.asr.modules.beam_search_decoder import BeamSearchDecoderWithLM

            self.default_beam_scorer = BeamSearchDecoderWithLM(
                vocab=vocab,
                lm_path=self.kenlm_path,
                beam_width=self.beam_size,
                alpha=self.beam_alpha,
                beta=self.beam_beta,
                num_cpus=max(1, os.cpu_count()),
                input_tensor=False,
            )

        x = x.to('cpu')

        with typecheck.disable_checks():
            data = [x[sample_id, : out_len[sample_id], :].softmax(dim=-1) for sample_id in range(len(x))]
            beams_batch = self.default_beam_scorer.forward(log_probs=data, log_probs_length=None)

        # For each sample in the batch
        nbest_hypotheses = []
        for beams_idx, beams in enumerate(beams_batch):
            # For each beam candidate / hypothesis in each sample
            hypotheses = []
            for candidate_idx, candidate in enumerate(beams):
                hypothesis = rnnt_utils.Hypothesis(
                    score=0.0, y_sequence=[], dec_state=None, timestep=[], last_token=None
                )

                # For subword encoding, NeMo will double encode the subword (multiple tokens) into a
                # singular unicode id. In doing so, we preserve the semantic of the unicode token, and
                # compress the size of the final KenLM ARPA / Binary file.
                # In order to do double encoding, we shift the subword by some token offset.
                # This step is ignored for character based models.
                if self.decoding_type == 'subword':
                    pred_token_ids = [ord(c) - self.token_offset for c in candidate[1]]
                else:
                    # Char models
                    pred_token_ids = [self.vocab_index_map[c] for c in candidate[1]]

                # We preserve the token ids and the score for this hypothesis
                hypothesis.y_sequence = pred_token_ids
                hypothesis.score = candidate[0]

                # If alignment must be preserved, we preserve a view of the output logprobs.
                # Note this view is shared amongst all beams within the sample, be sure to clone it if you
                # require specific processing for each sample in the beam.
                # This is done to preserve memory.
                if self.preserve_alignments:
                    hypothesis.alignments = x[beams_idx][: out_len[beams_idx]]

                hypotheses.append(hypothesis)

            # Wrap the result in NBestHypotheses.
            hypotheses = rnnt_utils.NBestHypotheses(hypotheses)
            nbest_hypotheses.append(hypotheses)

        return nbest_hypotheses

    @torch.no_grad()
    def _pyctcdecode_beam_search(
        self, x: torch.Tensor, out_len: torch.Tensor
    ) -> List[Union[rnnt_utils.Hypothesis, rnnt_utils.NBestHypotheses]]:
        """
        PyCTCDecode Beam Search Algorithm. Should support Char and Subword models.

        Args:
            x: Tensor of shape [B, T, V+1], where B is the batch size, T is the maximum sequence length,
                and V is the vocabulary size. The tensor contains log-probabilities.
            out_len: Tensor of shape [B], contains lengths of each sequence in the batch.

        Returns:
            A list of NBestHypotheses objects, one for each sequence in the batch.
        """
        if self.compute_timestamps:
            raise ValueError(
                f"Beam Search with strategy `{self.search_type}` does not support time stamp calculation!"
            )

        try:
            import pyctcdecode
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                f"Could not load `pyctcdecode` library. Please install it from pip using :\n"
                f"pip install --upgrade pyctcdecode"
            )

        if self.pyctcdecode_beam_scorer is None:
            self.pyctcdecode_beam_scorer = pyctcdecode.build_ctcdecoder(
                labels=self.vocab, kenlm_model_path=self.kenlm_path, alpha=self.beam_alpha, beta=self.beam_beta
            )  # type: pyctcdecode.BeamSearchDecoderCTC

        x = x.to('cpu').numpy()

        with typecheck.disable_checks():
            beams_batch = []
            for sample_id in range(len(x)):
                logprobs = x[sample_id, : out_len[sample_id], :]
                result = self.pyctcdecode_beam_scorer.decode_beams(
                    logprobs,
                    beam_width=self.beam_size,
                    beam_prune_logp=self.pyctcdecode_cfg.beam_prune_logp,
                    token_min_logp=self.pyctcdecode_cfg.token_min_logp,
                    prune_history=self.pyctcdecode_cfg.prune_history,
                    hotwords=self.pyctcdecode_cfg.hotwords,
                    hotword_weight=self.pyctcdecode_cfg.hotword_weight,
                    lm_start_state=None,
                )  # Output format: text, last_lm_state, text_frames, logit_score, lm_score
                beams_batch.append(result)

        nbest_hypotheses = []
        for beams_idx, beams in enumerate(beams_batch):
            hypotheses = []
            for candidate_idx, candidate in enumerate(beams):
                # Candidate = (text, last_lm_state, text_frames, logit_score, lm_score)
                hypothesis = rnnt_utils.Hypothesis(
                    score=0.0, y_sequence=[], dec_state=None, timestep=[], last_token=None
                )

                # TODO: Requires token ids to be returned rather than text.
                if self.decoding_type == 'subword':
                    if self.tokenizer is None:
                        raise ValueError("Tokenizer must be provided for subword decoding. Use set_tokenizer().")
                    pred_token_ids = self.tokenizer.text_to_ids(candidate[0])
                else:
                    if self.vocab is None:
                        raise ValueError("Vocab must be provided for character decoding. Use set_vocabulary().")
                    chars = list(candidate[0])
                    pred_token_ids = [self.vocab_index_map[c] for c in chars]

                hypothesis.y_sequence = pred_token_ids
                hypothesis.text = candidate[0]  # text
                hypothesis.score = candidate[4]  # score

                # Inject word level timestamps
                hypothesis.timestep = candidate[2]  # text_frames

                if self.preserve_alignments:
                    hypothesis.alignments = torch.from_numpy(x[beams_idx][: out_len[beams_idx]])

                hypotheses.append(hypothesis)

            hypotheses = rnnt_utils.NBestHypotheses(hypotheses)
            nbest_hypotheses.append(hypotheses)

        return nbest_hypotheses

    @torch.no_grad()
    def flashlight_beam_search(
        self, x: torch.Tensor, out_len: torch.Tensor
    ) -> List[Union[rnnt_utils.Hypothesis, rnnt_utils.NBestHypotheses]]:
        """
        Flashlight Beam Search Algorithm. Should support Char and Subword models.

        Args:
            x: Tensor of shape [B, T, V+1], where B is the batch size, T is the maximum sequence length,
                and V is the vocabulary size. The tensor contains log-probabilities.
            out_len: Tensor of shape [B], contains lengths of each sequence in the batch.

        Returns:
            A list of NBestHypotheses objects, one for each sequence in the batch.
        """
        if self.compute_timestamps:
            raise ValueError(
                f"Beam Search with strategy `{self.search_type}` does not support time stamp calculation!"
            )

        if self.flashlight_beam_scorer is None:
            # Check for filepath
            if self.kenlm_path is None or not os.path.exists(self.kenlm_path):
                raise FileNotFoundError(
                    f"KenLM binary file not found at : {self.kenlm_path}. "
                    f"Please set a valid path in the decoding config."
                )

            # perform token offset for subword models
            # if self.decoding_type == 'subword':
            #     vocab = [chr(idx + self.token_offset) for idx in range(len(self.vocab))]
            # else:
            #     # char models
            #     vocab = self.vocab

            # Must import at runtime to avoid circular dependency due to module level import.
            from nemo.collections.asr.modules.flashlight_decoder import FlashLightKenLMBeamSearchDecoder

            self.flashlight_beam_scorer = FlashLightKenLMBeamSearchDecoder(
                lm_path=self.kenlm_path,
                vocabulary=self.vocab,
                tokenizer=self.tokenizer,
                lexicon_path=self.flashlight_cfg.lexicon_path,
                boost_path=self.flashlight_cfg.boost_path,
                beam_size=self.beam_size,
                beam_size_token=self.flashlight_cfg.beam_size_token,
                beam_threshold=self.flashlight_cfg.beam_threshold,
                lm_weight=self.beam_alpha,
                word_score=self.beam_beta,
                unk_weight=self.flashlight_cfg.unk_weight,
                sil_weight=self.flashlight_cfg.sil_weight,
            )

        x = x.to('cpu')

        with typecheck.disable_checks():
            beams_batch = self.flashlight_beam_scorer.forward(log_probs=x)

        # For each sample in the batch
        nbest_hypotheses = []
        for beams_idx, beams in enumerate(beams_batch):
            # For each beam candidate / hypothesis in each sample
            hypotheses = []
            for candidate_idx, candidate in enumerate(beams):
                hypothesis = rnnt_utils.Hypothesis(
                    score=0.0, y_sequence=[], dec_state=None, timestep=[], last_token=None
                )

                # We preserve the token ids and the score for this hypothesis
                hypothesis.y_sequence = candidate['tokens'].tolist()
                hypothesis.score = candidate['score']

                # If alignment must be preserved, we preserve a view of the output logprobs.
                # Note this view is shared amongst all beams within the sample, be sure to clone it if you
                # require specific processing for each sample in the beam.
                # This is done to preserve memory.
                if self.preserve_alignments:
                    hypothesis.alignments = x[beams_idx][: out_len[beams_idx]]

                hypotheses.append(hypothesis)

            # Wrap the result in NBestHypotheses.
            hypotheses = rnnt_utils.NBestHypotheses(hypotheses)
            nbest_hypotheses.append(hypotheses)

        return nbest_hypotheses

    def set_decoding_type(self, decoding_type: str):
        super().set_decoding_type(decoding_type)

        # Please check train_kenlm.py in scripts/asr_language_modeling/ to find out why we need
        # TOKEN_OFFSET for BPE-based models
        if self.decoding_type == 'subword':
            self.token_offset = DEFAULT_TOKEN_OFFSET

@dataclass
class PyCTCDecodeConfig:
    # These arguments cannot be imported from pyctcdecode (optional dependency)
    # Therefore we copy the values explicitly
    # Taken from pyctcdecode.constants
    beam_prune_logp: float = -10.0
    token_min_logp: float = -5.0
    prune_history: bool = False
    hotwords: Optional[List[str]] = None
    hotword_weight: float = 10.0


@dataclass
class FlashlightConfig:
    lexicon_path: Optional[str] = None
    boost_path: Optional[str] = None
    beam_size_token: int = 16
    beam_threshold: float = 20.0
    unk_weight: float = -math.inf
    sil_weight: float = 0.0


@dataclass
class BeamCTCInferConfig:
    beam_size: int
    search_type: str = 'default'
    preserve_alignments: bool = False
    compute_timestamps: bool = False
    return_best_hypothesis: bool = True

    beam_alpha: float = 1.0
    beam_beta: float = 0.0
    kenlm_path: Optional[str] = None

    flashlight_cfg: Optional[FlashlightConfig] = field(default_factory=lambda: FlashlightConfig())
    pyctcdecode_cfg: Optional[PyCTCDecodeConfig] = field(default_factory=lambda: PyCTCDecodeConfig())
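
# Example usage (illustrative sketch only, not executed by the library). It assumes a character
# vocabulary `vocab` (no blank token), log-probabilities `log_probs` of shape [B, T, V+1] and a
# lengths tensor `lengths`; the KenLM path is optional for the pyctcdecode backend.
#
#   decoder = BeamCTCInfer(
#       blank_id=len(vocab),
#       beam_size=4,
#       search_type="pyctcdecode",
#       kenlm_path=None,  # or a path to a KenLM ARPA / binary file
#   )
#   decoder.set_vocabulary(vocab)
#   decoder.set_decoding_type('char')
#
#   (best_hyps,) = decoder(decoder_output=log_probs, decoder_lengths=lengths)
#   print(best_hyps[0].text)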