# Source code for nemo.collections.asr.metrics.rnnt_wer

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Union

import editdistance
import torch
from pytorch_lightning.metrics import Metric

from nemo.collections.asr.parts import rnnt_beam_decoding as beam_decode
from nemo.collections.asr.parts import rnnt_greedy_decoding as greedy_decode
from nemo.collections.asr.parts.rnnt_utils import Hypothesis, NBestHypotheses
from nemo.utils import logging

__all__ = ['RNNTDecoding', 'RNNTWER']


class AbstractRNNTDecoding(ABC):
    """
    Used for performing RNN-T auto-regressive decoding of the Decoder+Joint network given the encoder state.

    Subclasses must implement `decode_tokens_to_str` and `decode_ids_to_tokens`, which turn a list of
    token ids into a string / a list of token strings respectively.

    Args:
        decoding_cfg: A dict-like object which contains the following key-value pairs.
            strategy: str value which represents the type of decoding that can occur.
                Possible values are :
                -   greedy, greedy_batch (for greedy decoding).
                -   beam, tsd, alsd (for beam search decoding).

            compute_hypothesis_token_set: A bool flag, which determines whether to compute a list of decoded
                tokens as well as the decoded string. Default is False in order to avoid double decoding
                unless required.

            The config may further contain the following sub-dictionaries:
            "greedy":
                max_symbols: int, describing the maximum number of target tokens to decode per
                    timestep during greedy decoding. Setting to larger values allows longer sentences
                    to be decoded, at the cost of increased execution time.

            "beam":
                beam_size: int, defining the beam size for beam search. Must be >= 1.
                    If beam_size == 1, will perform cached greedy search. This might give slightly different
                    results compared to the greedy search above.

                score_norm: optional bool, whether to normalize the returned beam score in the hypotheses.
                    Set to True by default.

                return_best_hypothesis: optional bool, whether to return just the best hypothesis or all of the
                    hypotheses after beam search has concluded. This flag is set by default.

                tsd_max_sym_exp: optional int, determines number of symmetric expansions of the target symbols
                    per timestep of the acoustic model. Larger values will allow longer sentences to be decoded,
                    at increased cost to execution time. Defaults to 50 when not present in the config.

                alsd_max_target_len: optional int or float, determines the potential maximum target sequence length.
                    If an integer is provided, it can decode sequences of that particular maximum length.
                    If a float is provided, it can decode sequences of int(alsd_max_target_len * seq_len),
                    where seq_len is the length of the acoustic model output (T).

                    NOTE:
                        If a float is provided, it can be greater than 1!
                        By default, a float of 2.0 is used so that a target sequence can be at most twice
                        as long as the acoustic model output length T.

        decoder: The Decoder/Prediction network module.
        joint: The Joint network module.
        blank_id: The id of the RNNT blank token.

    Raises:
        ValueError: If `decoding_cfg.strategy` is not one of the supported strategies.
    """

    def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
        super().__init__()
        self.cfg = decoding_cfg
        self.blank_id = blank_id
        self.compute_hypothesis_token_set = self.cfg.get("compute_hypothesis_token_set", False)

        possible_strategies = ['greedy', 'greedy_batch', 'beam', 'tsd', 'alsd']
        strategy = self.cfg.strategy
        if strategy not in possible_strategies:
            raise ValueError(f"Decoding strategy must be one of {possible_strategies}")

        if strategy == 'greedy':
            self.decoding = greedy_decode.GreedyRNNTInfer(
                decoder_model=decoder,
                joint_model=joint,
                blank_index=self.blank_id,
                max_symbols_per_step=self.cfg.greedy.get('max_symbols', None),
            )
        elif strategy == 'greedy_batch':
            self.decoding = greedy_decode.GreedyBatchedRNNTInfer(
                decoder_model=decoder,
                joint_model=joint,
                blank_index=self.blank_id,
                max_symbols_per_step=self.cfg.greedy.get('max_symbols', None),
            )
        else:
            # 'beam', 'tsd' and 'alsd' all use BeamRNNTInfer and differ only in the search type
            # and in one search-specific keyword argument.
            self.decoding = self._build_beam_decoding(strategy, decoder, joint)

    def _build_beam_decoding(self, strategy: str, decoder, joint):
        """
        Build the BeamRNNTInfer object for one of the beam search strategies.

        Args:
            strategy: One of 'beam', 'tsd' or 'alsd'.
            decoder: The Decoder/Prediction network module.
            joint: The Joint network module.

        Returns:
            The constructed beam_decode.BeamRNNTInfer instance.
        """
        beam_cfg = self.cfg.beam

        # Keyword arguments that are only passed for a specific search type.
        search_kwargs = {}
        if strategy == 'tsd':
            search_kwargs['tsd_max_sym_exp_per_step'] = beam_cfg.get('tsd_max_sym_exp', 50)
        elif strategy == 'alsd':
            # The fallback must be the *float* 2.0 (a multiple of the acoustic length T), as documented
            # in the class docstring. An int is documented as an absolute maximum target length.
            search_kwargs['alsd_max_target_len'] = beam_cfg.get('alsd_max_target_len', 2.0)

        return beam_decode.BeamRNNTInfer(
            decoder_model=decoder,
            joint_model=joint,
            beam_size=beam_cfg.beam_size,
            return_best_hypothesis=beam_cfg.get('return_best_hypothesis', True),
            search_type='default' if strategy == 'beam' else strategy,
            score_norm=beam_cfg.get('score_norm', True),
            **search_kwargs,
        )

    def rnnt_decoder_predictions_tensor(
        self, encoder_output: torch.Tensor, encoded_lengths: torch.Tensor, return_hypotheses: bool = False
    ) -> Tuple[Union[List[str], List['Hypothesis']], Optional[Union[List[str], List[List['Hypothesis']]]]]:
        """
        Decode an encoder output by autoregressive decoding of the Decoder+Joint networks.

        Args:
            encoder_output: torch.Tensor of shape [B, D, T].
            encoded_lengths: torch.Tensor containing lengths of the padded encoder outputs. Shape [B].
            return_hypotheses: bool. If True, the decoded Hypothesis objects are returned.
                If False (default), only their decoded `text` strings are returned.

        Returns:
            A tuple (best, all).

            If the decoding returned one Hypothesis per sample:
                best - list with one entry per sample: the Hypothesis (return_hypotheses=True)
                    or its text (return_hypotheses=False).
                all - None.

            If the decoding returned one NBestHypotheses per sample:
                best - list with one entry per sample, taken from the first entry of that sample's
                    `n_best_hypotheses`: the Hypothesis or its text, as above.
                all - with return_hypotheses=True, a list (one entry per sample) of lists of Hypothesis.
                    With return_hypotheses=False, a single flat list holding the texts of all n-best
                    hypotheses of all samples, in batch order.

            If the decoding returned no hypotheses at all, ([], None) is returned.
        """
        # Compute hypotheses; the decoding object returns a tuple whose first element is the list we need.
        with torch.no_grad():
            decoding_output = self.decoding(encoder_output=encoder_output, encoded_lengths=encoded_lengths)
        prediction_list = decoding_output[0]

        # Nothing was decoded (e.g. empty batch); avoid indexing into an empty list below.
        if not prediction_list:
            return [], None

        if isinstance(prediction_list[0], NBestHypotheses):
            best_hypotheses = []
            all_hypotheses = []
            for nbest_hyp in prediction_list:
                # Decode every candidate of this sample; the first candidate is treated as the best one.
                decoded_hyps = self.decode_hypothesis(nbest_hyp.n_best_hypotheses)
                best_hypotheses.append(decoded_hyps[0])
                all_hypotheses.append(decoded_hyps)

            if return_hypotheses:
                return best_hypotheses, all_hypotheses

            best_hyp_text = [hyp.text for hyp in best_hypotheses]
            all_hyp_text = [hyp.text for sample_hyps in all_hypotheses for hyp in sample_hyps]
            return best_hyp_text, all_hyp_text

        hypotheses = self.decode_hypothesis(prediction_list)
        if return_hypotheses:
            return hypotheses, None
        return [hyp.text for hyp in hypotheses], None

    def decode_hypothesis(self, hypotheses_list: List['Hypothesis']) -> List['Hypothesis']:
        """
        Decode the token ids of each hypothesis and store the result on the hypothesis itself.

        Each hypothesis is modified in place: `text` is set to the decoded string and, if
        `compute_hypothesis_token_set` is enabled, `tokens` is set to the list of decoded tokens.

        Args:
            hypotheses_list: List of Hypothesis. `y_sequence` of each entry is either a list of ints
                or an object providing `tolist()`.

        Returns:
            The same list that was passed in, with every hypothesis updated.
        """
        for hypothesis in hypotheses_list:
            # Extract the integer encoded hypothesis
            prediction = hypothesis.y_sequence
            if not isinstance(prediction, list):
                prediction = prediction.tolist()

            # RNN-T sample level is already preprocessed by implicit CTC decoding
            # Simply remove any blank tokens
            prediction = [token for token in prediction if token != self.blank_id]

            # De-tokenize the integer tokens
            hypothesis.text = self.decode_tokens_to_str(prediction)

            if self.compute_hypothesis_token_set:
                hypothesis.tokens = self.decode_ids_to_tokens(prediction)
        return hypotheses_list

    @abstractmethod
    def decode_tokens_to_str(self, tokens: List[int]) -> str:
        """
        Implemented by subclass in order to decode a token id list into a string.

        Args:
            tokens: List of int representing the token ids.

        Returns:
            A decoded string.
        """
        raise NotImplementedError()

    @abstractmethod
    def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]:
        """
        Implemented by subclass in order to decode a token id list into a token list.
        A token list is the string representation of each token id.

        Args:
            tokens: List of int representing the token ids.

        Returns:
            A list of decoded tokens.
        """
        raise NotImplementedError()


class RNNTDecoding(AbstractRNNTDecoding):
    """
    Used for performing RNN-T auto-regressive decoding of the Decoder+Joint network given the encoder state.

    Token ids are mapped to strings through `vocabulary`; the blank id is set to `len(vocabulary)`.

    Args:
        decoding_cfg: A dict-like object which contains the following key-value pairs.
            strategy: str value which represents the type of decoding that can occur.
                Possible values are :
                -   greedy, greedy_batch (for greedy decoding).
                -   beam, tsd, alsd (for beam search decoding).

            compute_hypothesis_token_set: A bool flag, which determines whether to compute a list of decoded
                tokens as well as the decoded string. Default is False in order to avoid double decoding
                unless required.

            The config may further contain the following sub-dictionaries:
            "greedy":
                max_symbols: int, describing the maximum number of target tokens to decode per
                    timestep during greedy decoding. Setting to larger values allows longer sentences
                    to be decoded, at the cost of increased execution time.

            "beam":
                beam_size: int, defining the beam size for beam search. Must be >= 1.
                    If beam_size == 1, will perform cached greedy search. This might give slightly different
                    results compared to the greedy search above.

                score_norm: optional bool, whether to normalize the returned beam score in the hypotheses.
                    Set to True by default.

                return_best_hypothesis: optional bool, whether to return just the best hypothesis or all of the
                    hypotheses after beam search has concluded. This flag is set by default.

                tsd_max_sym_exp: optional int, determines number of symmetric expansions of the target symbols
                    per timestep of the acoustic model. Larger values will allow longer sentences to be decoded,
                    at increased cost to execution time.

                alsd_max_target_len: optional int or float, determines the potential maximum target sequence length.
                    If an integer is provided, it can decode sequences of that particular maximum length.
                    If a float is provided, it can decode sequences of int(alsd_max_target_len * seq_len),
                    where seq_len is the length of the acoustic model output (T).

                    NOTE:
                        If a float is provided, it can be greater than 1!
                        By default, a float of 2.0 is used so that a target sequence can be at most twice
                        as long as the acoustic model output length T.

        decoder: The Decoder/Prediction network module.
        joint: The Joint network module.
        vocabulary: The vocabulary (excluding the RNNT blank token) which will be used for decoding.
    """

    def __init__(
        self, decoding_cfg, decoder, joint, vocabulary,
    ):
        # The blank token is not part of the vocabulary; it takes the first id after the last label.
        blank_id = len(vocabulary)
        # id -> label lookup used by both decode methods below.
        self.labels_map = dict(enumerate(vocabulary))

        super(RNNTDecoding, self).__init__(decoding_cfg=decoding_cfg, decoder=decoder, joint=joint, blank_id=blank_id)

    def decode_tokens_to_str(self, tokens: List[int]) -> str:
        """
        Decode a token id list into a string by concatenating the label of every non-blank id.

        Args:
            tokens: List of int representing the token ids.

        Returns:
            A decoded string.
        """
        hypothesis = ''.join([self.labels_map[c] for c in tokens if c != self.blank_id])
        return hypothesis

    def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]:
        """
        Decode a token id list into a token list.
        A token list is the string representation of each token id; blank ids are skipped.

        Args:
            tokens: List of int representing the token ids.

        Returns:
            A list of decoded tokens.
        """
        token_list = [self.labels_map[c] for c in tokens if c != self.blank_id]
        return token_list


class RNNTWER(Metric):
    """
    This metric computes numerator and denominator for Overall Word Error Rate (WER) between prediction and reference
    texts. When doing distributed training/evaluation the result of res=WER(predictions, targets, target_lengths) calls
    will be all-reduced between all workers using SUM operations.
    Here `res` contains two numbers res=[wer_numerator, wer_denominator]. WER=wer_numerator/wer_denominator.

    If used with PytorchLightning LightningModule, include wer_numerator and wer_denominators inside validation_step
    results. Then aggregate (sum) them at the end of validation epoch to correctly compute validation WER.

    Example:
        def validation_step(self, batch, batch_idx):
            ...
            wer_num, wer_denom = self.__wer(predictions, transcript, transcript_len)
            return {'val_loss': loss_value, 'val_wer_num': wer_num, 'val_wer_denom': wer_denom}

        def validation_epoch_end(self, outputs):
            ...
            wer_num = torch.stack([x['val_wer_num'] for x in outputs]).sum()
            wer_denom = torch.stack([x['val_wer_denom'] for x in outputs]).sum()
            tensorboard_logs = {'validation_loss': val_loss_mean, 'validation_avg_wer': wer_num / wer_denom}
            return {'val_loss': val_loss_mean, 'log': tensorboard_logs}

    Args:
        decoding: RNNTDecoding object that will perform autoregressive decoding of the RNNT model.
        batch_dim_index: Index of the batch dimension of `targets`.
        use_cer: Whether to use Character Error Rate instead of Word Error Rate.
        log_prediction: Whether to log a single decoded sample per call.
        dist_sync_on_step: Forwarded unchanged to the `Metric` base class constructor.

    Returns:
        `compute()` returns a tuple (wer, wer_numerator, wer_denominator), where
        wer = wer_numerator / wer_denominator over everything accumulated by `update()`.
    """

    def __init__(
        self, decoding: RNNTDecoding, batch_dim_index=0, use_cer=False, log_prediction=True, dist_sync_on_step=False
    ):
        super(RNNTWER, self).__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False)
        self.decoding = decoding
        self.batch_dim_index = batch_dim_index
        self.use_cer = use_cer
        self.log_prediction = log_prediction
        self.blank_id = self.decoding.blank_id
        self.labels_map = self.decoding.labels_map

        # Accumulated edit distance (numerator) and reference length (denominator); summed across workers.
        self.add_state("scores", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False)
        self.add_state("words", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False)

    def update(
        self,
        encoder_output: torch.Tensor,
        encoded_lengths: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> None:
        """
        Decode one batch and add its edit distance / reference length to the metric state.

        Args:
            encoder_output: Encoder output, passed as-is to `decoding.rnnt_decoder_predictions_tensor`.
            encoded_lengths: Lengths of the encoder outputs, passed as-is to the decoding object.
            targets: Tensor of reference token ids; the batch dimension is `batch_dim_index`.
            target_lengths: Tensor with the number of valid token ids of each reference.
        """
        words = 0
        scores = 0
        references = []

        with torch.no_grad():
            targets_cpu_tensor = targets.long().cpu()
            tgt_lengths_cpu_tensor = target_lengths.long().cpu()

            # iterate over batch
            for ind in range(targets_cpu_tensor.shape[self.batch_dim_index]):
                tgt_len = tgt_lengths_cpu_tensor[ind].item()
                # Pick the sample along the configured batch dimension (not always dim 0),
                # then drop the padding beyond the reference length.
                target = targets_cpu_tensor.select(self.batch_dim_index, ind)[:tgt_len].numpy().tolist()
                references.append(self.decoding.decode_tokens_to_str(target))

            hypotheses, _ = self.decoding.rnnt_decoder_predictions_tensor(encoder_output, encoded_lengths)

        # Only log when there is something to show; an empty batch has no first sample.
        if self.log_prediction and references and hypotheses:
            logging.info("\n")
            logging.info(f"reference :{references[0]}")
            logging.info(f"predicted :{hypotheses[0]}")

        for hypothesis, reference in zip(hypotheses, references):
            if self.use_cer:
                # Character level comparison
                h_list = list(hypothesis)
                r_list = list(reference)
            else:
                # Word level comparison (whitespace separated)
                h_list = hypothesis.split()
                r_list = reference.split()
            words += len(r_list)
            # Compute Levenshtein's distance
            scores += editdistance.eval(h_list, r_list)

        self.scores += torch.tensor(scores, device=self.scores.device, dtype=self.scores.dtype)
        self.words += torch.tensor(words, device=self.words.device, dtype=self.words.dtype)

    def compute(self):
        """
        Compute the error rate from the accumulated state.

        Returns:
            A tuple (wer, scores, words): the error rate `scores / words`, the accumulated edit
            distance and the accumulated reference length (the latter two detached).
        """
        wer = self.scores.float() / self.words
        return wer, self.scores.detach(), self.words.detach()


@dataclass
class RNNTDecodingConfig:
    # Decoding strategy name, see AbstractRNNTDecoding for the accepted values.
    strategy: str = "greedy_batch"
    # Whether to also compute the list of decoded tokens for every hypothesis.
    compute_hypothesis_token_set: bool = False

    # greedy decoding config
    # default_factory gives every config its own sub-config instance; a dataclass instance used directly
    # as a default is shared between instances and is rejected as a mutable default on Python >= 3.11.
    greedy: greedy_decode.GreedyRNNTInferConfig = field(default_factory=greedy_decode.GreedyRNNTInferConfig)

    # beam decoding config
    beam: beam_decode.BeamRNNTInferConfig = field(
        default_factory=lambda: beam_decode.BeamRNNTInferConfig(beam_size=4)
    )