# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional
import editdistance
import torch
from pytorch_lightning.metrics import Metric
from nemo.collections.asr.metrics.rnnt_wer import AbstractRNNTDecoding
from nemo.collections.asr.parts import rnnt_beam_decoding as beam_decode
from nemo.collections.asr.parts import rnnt_greedy_decoding as greedy_decode
from nemo.collections.asr.parts.rnnt_utils import Hypothesis, NBestHypotheses
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.utils import logging

__all__ = ['RNNTBPEDecoding', 'RNNTBPEWER']


class RNNTBPEDecoding(AbstractRNNTDecoding):
"""
Used for performing RNN-T auto-regressive decoding of the Decoder+Joint network given the encoder state.
Args:
decoding_cfg: A dict-like object which contains the following key-value pairs.
strategy: str value which represents the type of decoding that can occur.
                Possible values are:
- greedy, greedy_batch (for greedy decoding).
- beam, tsd, alsd (for beam search decoding).
compute_hypothesis_token_set: A bool flag, which determines whether to compute a list of decoded
tokens as well as the decoded string. Default is False in order to avoid double decoding
unless required.
The config may further contain the following sub-dictionaries:
"greedy":
max_symbols: int, describing the maximum number of target tokens to decode per
timestep during greedy decoding. Setting to larger values allows longer sentences
to be decoded, at the cost of increased execution time.
"beam":
beam_size: int, defining the beam size for beam search. Must be >= 1.
                If beam_size == 1, will perform cached greedy search. This might produce slightly different
                results compared to the greedy search above.
score_norm: optional bool, whether to normalize the returned beam score in the hypotheses.
Set to True by default.
return_best_hypothesis: optional bool, whether to return just the best hypothesis or all of the
hypotheses after beam search has concluded.
            tsd_max_sym_exp: optional int, determines the maximum number of expansions of the target symbols
per timestep of the acoustic model. Larger values will allow longer sentences to be decoded,
at increased cost to execution time.
alsd_max_target_len: optional int or float, determines the potential maximum target sequence length.
If an integer is provided, it can decode sequences of that particular maximum length.
If a float is provided, it can decode sequences of int(alsd_max_target_len * seq_len),
where seq_len is the length of the acoustic model output (T).
NOTE:
If a float is provided, it can be greater than 1!
By default, a float of 2.0 is used so that a target sequence can be at most twice
as long as the acoustic model output length T.
decoder: The Decoder/Prediction network module.
joint: The Joint network module.
tokenizer: The tokenizer which will be used for decoding.
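
    Example:
        A minimal sketch, assuming `model` is an instantiated RNNT BPE model exposing
        `decoder`, `joint` and `tokenizer` attributes (the config is built from the
        `RNNTBPEDecodingConfig` dataclass defined at the bottom of this file)::

            from omegaconf import OmegaConf

            decoding_cfg = OmegaConf.structured(RNNTBPEDecodingConfig(strategy='greedy_batch'))
            decoding = RNNTBPEDecoding(
                decoding_cfg=decoding_cfg, decoder=model.decoder, joint=model.joint, tokenizer=model.tokenizer
            )
            # encoder activations and their lengths are assumed to come from the model's encoder
            hypotheses, _ = decoding.rnnt_decoder_predictions_tensor(encoder_output, encoded_lengths)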
"""
def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec):
blank_id = tokenizer.tokenizer.vocab_size
self.tokenizer = tokenizer
super(RNNTBPEDecoding, self).__init__(
decoding_cfg=decoding_cfg, decoder=decoder, joint=joint, blank_id=blank_id
)

    def decode_tokens_to_str(self, tokens: List[int]) -> str:
        """
        Implemented by subclass in order to decode a token list into a string.
Args:
tokens: List of int representing the token ids.
Returns:
A decoded string.
"""
hypothesis = self.tokenizer.ids_to_text(tokens)
return hypothesis

    def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]:
"""
Implemented by subclass in order to decode a token id list into a token list.
A token list is the string representation of each token id.
Args:
tokens: List of int representing the token ids.
Returns:
A list of decoded tokens.
"""
token_list = self.tokenizer.ids_to_tokens(tokens)
return token_list
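
    # Example (hypothetical token ids; the actual mapping depends on the trained
    # BPE tokenizer):
    #
    #     decoding.decode_tokens_to_str([16, 5, 80])    # -> decoded text, e.g. "hello"
    #     decoding.decode_ids_to_tokens([16, 5, 80])    # -> sub-word pieces, e.g. ['▁he', 'll', 'o']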


class RNNTBPEWER(Metric):
    """
    This metric computes the numerator and denominator for the overall Word Error Rate (WER) between prediction and
    reference texts. When doing distributed training/evaluation, the result of res=WER(predictions, targets, target_lengths)
    calls will be all-reduced between all workers using SUM operations.
    Here, res contains two numbers: res=[wer_numerator, wer_denominator], and WER=wer_numerator/wer_denominator.
    If used with a PyTorch Lightning LightningModule, include wer_numerator and wer_denominator inside validation_step
    results. Then aggregate (sum) them at the end of the validation epoch to correctly compute the validation WER.
Example:
def validation_step(self, batch, batch_idx):
...
wer_num, wer_denom = self.__wer(predictions, transcript, transcript_len)
return {'val_loss': loss_value, 'val_wer_num': wer_num, 'val_wer_denom': wer_denom}
def validation_epoch_end(self, outputs):
...
wer_num = torch.stack([x['val_wer_num'] for x in outputs]).sum()
wer_denom = torch.stack([x['val_wer_denom'] for x in outputs]).sum()
tensorboard_logs = {'validation_loss': val_loss_mean, 'validation_avg_wer': wer_num / wer_denom}
return {'val_loss': val_loss_mean, 'log': tensorboard_logs}
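
    For example, if the epoch accumulates wer_num = 6 edit operations over wer_denom = 120 reference words
    across all batches and workers, the resulting validation WER is 6 / 120 = 0.05 (i.e. 5%).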
Args:
decoding: RNNTBPEDecoding object that will perform autoregressive decoding of the RNNT model.
batch_dim_index: Index of the batch dimension.
        use_cer: Whether to use Character Error Rate instead of Word Error Rate.
log_prediction: Whether to log a single decoded sample per call.
    Returns:
        res: a tuple of three zero-dimensional tensors returned by compute(): (wer, wer_numerator, wer_denominator).
            To correctly compute the average word error rate, use wer=wer_numerator/wer_denominator.
"""
def __init__(
self,
decoding: RNNTBPEDecoding,
batch_dim_index=0,
use_cer: bool = False,
log_prediction: bool = True,
dist_sync_on_step=False,
):
super(RNNTBPEWER, self).__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False)
self.decoding = decoding
self.batch_dim_index = batch_dim_index
self.use_cer = use_cer
self.log_prediction = log_prediction
self.blank_id = self.decoding.blank_id
self.tokenizer = self.decoding.tokenizer
self.add_state("scores", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False)
self.add_state("words", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False)

    def update(
self,
encoder_output: torch.Tensor,
encoded_lengths: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
    ) -> None:
words = 0.0
scores = 0.0
references = []
with torch.no_grad():
            # move targets to CPU before decoding the reference transcripts
            targets_cpu_tensor = targets.long().cpu()
            tgt_lengths_cpu_tensor = target_lengths.long().cpu()
# iterate over batch
for ind in range(targets_cpu_tensor.shape[self.batch_dim_index]):
                tgt_len = tgt_lengths_cpu_tensor[ind].item()
target = targets_cpu_tensor[ind][:tgt_len].numpy().tolist()
reference = self.decoding.decode_tokens_to_str(target)
references.append(reference)
hypotheses, _ = self.decoding.rnnt_decoder_predictions_tensor(encoder_output, encoded_lengths)
        if self.log_prediction:
            logging.info("\n")
            logging.info(f"reference:{references[0]}")
            logging.info(f"predicted:{hypotheses[0]}")
for h, r in zip(hypotheses, references):
if self.use_cer:
h_list = list(h)
r_list = list(r)
else:
h_list = h.split()
r_list = r.split()
words += len(r_list)
            # compute the Levenshtein distance between hypothesis and reference
scores += editdistance.eval(h_list, r_list)
del hypotheses
self.scores += torch.tensor(scores, device=self.scores.device, dtype=self.scores.dtype)
self.words += torch.tensor(words, device=self.words.device, dtype=self.words.dtype)

    def compute(self):
wer = self.scores.float() / self.words
return wer, self.scores.detach(), self.words.detach()
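
    # A minimal sketch of the per-utterance arithmetic performed in `update` above
    # (the strings are hypothetical examples): with reference "the cat sat" and
    # hypothesis "the cat sit", editdistance.eval over the word lists returns 1,
    # so this pair contributes scores += 1 and words += 3, i.e. a WER of 1/3.
    #
    #     import editdistance
    #     editdistance.eval("the cat sit".split(), "the cat sat".split())  # -> 1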


@dataclass
class RNNTBPEDecodingConfig:
strategy: str = "greedy_batch"
compute_hypothesis_token_set: bool = False
# greedy decoding config
greedy: greedy_decode.GreedyRNNTInferConfig = greedy_decode.GreedyRNNTInferConfig()
# beam decoding config
beam: beam_decode.BeamRNNTInferConfig = beam_decode.BeamRNNTInferConfig(beam_size=4)
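

# A minimal usage sketch for the config dataclass above (assumes OmegaConf, which NeMo
# builds its structured configs on): materialize the defaults, then override fields
# before passing the config to RNNTBPEDecoding.
#
#     from omegaconf import OmegaConf
#
#     cfg = OmegaConf.structured(RNNTBPEDecodingConfig())
#     cfg.strategy = 'beam'        # switch from greedy_batch to beam search
#     cfg.beam.beam_size = 8       # widen the beam (must be >= 1)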