Source code for nemo.collections.asr.parts.rnnt_greedy_decoding

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch

from nemo.collections.asr.modules import rnnt_abstract
from nemo.collections.asr.parts import rnnt_utils
from nemo.collections.common.parts.rnn import label_collate
from nemo.core.classes import Typing, typecheck
from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType


def pack_hypotheses(
    hypotheses: List[List[int]], timesteps: List[List[int]], logitlen: torch.Tensor
) -> List[rnnt_utils.Hypothesis]:
    logitlen_cpu = logitlen.to("cpu")
    return [
        rnnt_utils.Hypothesis(
            y_sequence=torch.tensor(sent, dtype=torch.long), score=-1.0, timestep=timestep, length=length
        )
        for sent, timestep, length in zip(hypotheses, timesteps, logitlen_cpu)
    ]
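
# Illustrative example (not part of the module's execution): how `pack_hypotheses`
# is expected to be called. The token ids, timesteps and lengths below are made-up
# values for demonstration only.
#
#   hyps = [[2, 5, 5], [7]]                 # decoded token ids per utterance
#   times = [[0, 1, 3], [2]]                # encoder frame at which each token was emitted
#   lengths = torch.tensor([4, 3])          # encoded length of each utterance
#   packed = pack_hypotheses(hyps, times, lengths)
#   packed[0].y_sequence                    # -> tensor([2, 5, 5])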


class _GreedyRNNTInfer(Typing):
    """A greedy transducer decoder.

    Provides a common abstraction for sample level and batch level greedy decoding.

    Args:
        decoder_model: rnnt_abstract.AbstractRNNTDecoder implementation.
        joint_model: rnnt_abstract.AbstractRNNTJoint implementation.
        blank_index: int index of the blank token. Can be 0 or len(vocabulary).
        max_symbols_per_step: Optional int. The maximum number of symbols that can be added
            to a sequence in a single time step; if set to None then there is
            no limit.
    """

    @property
    def input_types(self):
        """Returns definitions of module input ports.
        """
        return {
            "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
            "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        """Returns definitions of module output ports.
        """
        return {"predictions": [NeuralType(elements_type=HypothesisType())]}

    def __init__(
        self,
        decoder_model: rnnt_abstract.AbstractRNNTDecoder,
        joint_model: rnnt_abstract.AbstractRNNTJoint,
        blank_index: int,
        max_symbols_per_step: Optional[int] = None,
    ):
        super().__init__()
        self.decoder = decoder_model
        self.joint = joint_model

        self._blank_index = blank_index
        self._SOS = blank_index  # Start-of-Signal index (treated the same as blank)
        self.max_symbols = max_symbols_per_step

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    @torch.no_grad()
    def _pred_step(
        self,
        label: Union[torch.Tensor, int],
        hidden: Optional[torch.Tensor],
        add_sos: bool = False,
        batch_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Common prediction step based on the AbstractRNNTDecoder implementation.

        Args:
            label (int or torch.Tensor): Label or "Start-of-Signal" token.
            hidden (Optional[torch.Tensor]): RNN state vector.
            add_sos (bool): Whether to add a zero vector at the beginning as a "start of sentence" token.
            batch_size: Batch size of the output tensor.

        Returns:
            g: (B, U, H) if add_sos is false, else (B, U + 1, H)
            hid: (h, c) where h is the final sequence hidden state and c is
                the final cell state:
                    h (tensor), shape (L, B, H)
                    c (tensor), shape (L, B, H)
        """
        if isinstance(label, torch.Tensor):
            # label: [batch, 1]
            if label.dtype != torch.long:
                label = label.long()

        else:
            # Label is an integer
            if label == self._SOS:
                return self.decoder.predict(None, hidden, add_sos=add_sos, batch_size=batch_size)

            label = label_collate([[label]])

        # output: [B, 1, K]
        return self.decoder.predict(label, hidden, add_sos=add_sos, batch_size=batch_size)

    def _joint_step(self, enc, pred, log_normalize: Optional[bool] = None):
        """
        Common joint step based on AbstractRNNTJoint implementation.

        Args:
            enc: Output of the Encoder model. A torch.Tensor of shape [B, 1, H1]
            pred: Output of the Decoder model. A torch.Tensor of shape [B, 1, H2]
            log_normalize: Whether to log normalize or not. None will log normalize only for CPU.

        Returns:
             logits of shape (B, T=1, U=1, V + 1)
        """
        with torch.no_grad():
            logits = self.joint.joint(enc, pred)

            if log_normalize is None:
                if not logits.is_cuda:  # Use log softmax only if on CPU
                    logits = logits.log_softmax(dim=len(logits.shape) - 1)
            else:
                if log_normalize:
                    logits = logits.log_softmax(dim=len(logits.shape) - 1)

        return logits
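
    # Illustrative sketch (comments only) of how one greedy step chains `_pred_step`
    # and `_joint_step` in the subclasses below. Here `f` is assumed to be a single
    # encoder frame of shape [1, 1, D] and `hidden` the current prediction-network state:
    #
    #   g, hidden_prime = self._pred_step(self._SOS, hidden)           # prime decoder with blank/SOS
    #   logp = self._joint_step(f, g, log_normalize=None)[0, 0, 0, :]  # [V + 1] scores for this frame
    #   k = logp.argmax(dim=-1).item()                                 # greedy label (may be blank)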


class GreedyRNNTInfer(_GreedyRNNTInfer):
    """A greedy transducer decoder.

    Sequence level greedy decoding, performed auto-regressively.

    Args:
        decoder_model: rnnt_abstract.AbstractRNNTDecoder implementation.
        joint_model: rnnt_abstract.AbstractRNNTJoint implementation.
        blank_index: int index of the blank token. Can be 0 or len(vocabulary).
        max_symbols_per_step: Optional int. The maximum number of symbols that can be added
            to a sequence in a single time step; if set to None then there is
            no limit.
    """

    def __init__(
        self,
        decoder_model: rnnt_abstract.AbstractRNNTDecoder,
        joint_model: rnnt_abstract.AbstractRNNTJoint,
        blank_index: int,
        max_symbols_per_step: Optional[int] = None,
    ):
        super().__init__(
            decoder_model=decoder_model,
            joint_model=joint_model,
            blank_index=blank_index,
            max_symbols_per_step=max_symbols_per_step,
        )

    @typecheck()
    def forward(self, encoder_output: torch.Tensor, encoded_lengths: torch.Tensor):
        """Returns a list of hypotheses given an input batch of the encoder hidden embedding.
        Output tokens are generated auto-regressively.

        Args:
            encoder_output: A tensor of size (batch, features, timesteps).
            encoded_lengths: list of int representing the length of each sequence in the batch.

        Returns:
            packed list containing batch number of sentences (Hypotheses).
        """
        # Preserve decoder and joint training state
        decoder_training_state = self.decoder.training
        joint_training_state = self.joint.training

        with torch.no_grad():
            # Apply optional preprocessing
            encoder_output = encoder_output.transpose(1, 2)  # (B, T, D)

            self.decoder.eval()
            self.joint.eval()

            hypotheses = []
            timesteps = []
            # Process each sequence independently
            with self.decoder.as_frozen(), self.joint.as_frozen():
                for batch_idx in range(encoder_output.size(0)):
                    inseq = encoder_output[batch_idx, :, :].unsqueeze(1)  # [T, 1, D]
                    logitlen = encoded_lengths[batch_idx]
                    sentence, timestep = self._greedy_decode(inseq, logitlen)
                    hypotheses.append(sentence)
                    timesteps.append(timestep)

            # Pack results into Hypotheses
            packed_result = pack_hypotheses(hypotheses, timesteps, encoded_lengths)

        self.decoder.train(decoder_training_state)
        self.joint.train(joint_training_state)

        return (packed_result,)

    @torch.no_grad()
    def _greedy_decode(self, x: torch.Tensor, out_len: torch.Tensor):
        # x: [T, 1, D]
        # out_len: [seq_len]

        # Initialize blank state and empty label set
        hidden = None
        label = []
        timesteps = []

        # For timestep t in X_t
        for time_idx in range(out_len):
            # Extract encoder embedding at timestep t
            # f = x[time_idx, :, :].unsqueeze(0)  # [1, 1, D]
            f = x.narrow(dim=0, start=time_idx, length=1)

            # Setup exit flags and counter
            not_blank = True
            symbols_added = 0

            # While blank is not predicted, and we don't run out of max symbols per timestep
            while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols):
                # In the first timestep, we initialize the network with RNNT Blank
                # In later timesteps, we provide the previous predicted label as input.
                last_label = self._SOS if label == [] else label[-1]

                # Perform prediction network and joint network steps.
                g, hidden_prime = self._pred_step(last_label, hidden)
                logp = self._joint_step(f, g, log_normalize=None)[0, 0, 0, :]
                del g

                # torch.max(0) op doesn't exist for FP16.
                if logp.dtype != torch.float32:
                    logp = logp.float()

                # get index k, of max prob
                v, k = logp.max(0)
                k = k.item()  # K is the label at timestep t_s in inner loop, s >= 0.
                del logp

                # If blank token is predicted, exit inner loop, move onto next timestep t
                if k == self._blank_index:
                    not_blank = False
                else:
                    # Append token to label set, update RNN state.
                    label.append(k)
                    timesteps.append(time_idx)
                    hidden = hidden_prime

                # Increment token counter.
                symbols_added += 1

        return label, timesteps
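
# Illustrative usage sketch (not executed here). `decoder`, `joint`, `encoded`,
# `encoded_len` and `vocabulary` are assumed to come from a trained NeMo RNNT model
# (prediction network, joint network, encoder outputs and token vocabulary respectively):
#
#   greedy_decoder = GreedyRNNTInfer(
#       decoder_model=decoder,
#       joint_model=joint,
#       blank_index=len(vocabulary),
#       max_symbols_per_step=30,
#   )
#   (hypotheses,) = greedy_decoder(encoder_output=encoded, encoded_lengths=encoded_len)
#   hypotheses[0].y_sequence  # greedy token ids of the first utterance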


class GreedyBatchedRNNTInfer(_GreedyRNNTInfer):
    """A batch level greedy transducer decoder.

    Batch level greedy decoding, performed auto-regressively.

    Args:
        decoder_model: rnnt_abstract.AbstractRNNTDecoder implementation.
        joint_model: rnnt_abstract.AbstractRNNTJoint implementation.
        blank_index: int index of the blank token. Can be 0 or len(vocabulary).
        max_symbols_per_step: Optional int. The maximum number of symbols that can be added
            to a sequence in a single time step; if set to None then there is
            no limit.
    """

    def __init__(
        self,
        decoder_model: rnnt_abstract.AbstractRNNTDecoder,
        joint_model: rnnt_abstract.AbstractRNNTJoint,
        blank_index: int,
        max_symbols_per_step: Optional[int] = None,
    ):
        super().__init__(
            decoder_model=decoder_model,
            joint_model=joint_model,
            blank_index=blank_index,
            max_symbols_per_step=max_symbols_per_step,
        )

        # Depending on availability of `blank_as_pad` support,
        # switch to the more efficient batch decoding technique
        if self.decoder.blank_as_pad:
            self._greedy_decode = self._greedy_decode_blank_as_pad
        else:
            self._greedy_decode = self._greedy_decode_masked

    @typecheck()
    def forward(self, encoder_output: torch.Tensor, encoded_lengths: torch.Tensor):
        """Returns a list of hypotheses given an input batch of the encoder hidden embedding.
        Output tokens are generated auto-regressively.

        Args:
            encoder_output: A tensor of size (batch, features, timesteps).
            encoded_lengths: list of int representing the length of each sequence in the batch.

        Returns:
            packed list containing batch number of sentences (Hypotheses).
        """
        # Preserve decoder and joint training state
        decoder_training_state = self.decoder.training
        joint_training_state = self.joint.training

        with torch.no_grad():
            # Apply optional preprocessing
            encoder_output = encoder_output.transpose(1, 2)  # (B, T, D)
            logitlen = encoded_lengths

            self.decoder.eval()
            self.joint.eval()

            with self.decoder.as_frozen(), self.joint.as_frozen():
                inseq = encoder_output  # [B, T, D]
                hypotheses, timesteps = self._greedy_decode(inseq, logitlen, device=inseq.device)

            # Pack the hypotheses results
            packed_result = pack_hypotheses(hypotheses, timesteps, logitlen)

            del hypotheses, timesteps

        self.decoder.train(decoder_training_state)
        self.joint.train(joint_training_state)

        return (packed_result,)

    def _greedy_decode_blank_as_pad(self, x: torch.Tensor, out_len: torch.Tensor, device: torch.device):
        with torch.no_grad():
            # x: [B, T, D]
            # out_len: [B]
            # device: torch.device

            # Initialize state
            hidden = None
            batchsize = x.shape[0]

            # Output string buffer
            label = [[] for _ in range(batchsize)]
            timesteps = [[] for _ in range(batchsize)]

            # Last Label buffer + Last Label without blank buffer
            # batch level equivalent of the last_label
            last_label = torch.full([batchsize, 1], fill_value=self._blank_index, dtype=torch.long, device=device)

            # Mask buffers
            blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device)

            # Get max sequence length
            max_out_len = out_len.max()
            for time_idx in range(max_out_len):
                f = x.narrow(dim=1, start=time_idx, length=1)  # [B, 1, D]

                # Prepare t timestamp batch variables
                not_blank = True
                symbols_added = 0

                # Reset blank mask
                blank_mask.mul_(False)

                # Update blank mask with time mask
                # Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch)
                # Forcibly mask with "blank" tokens, for all samples where current time step T > seq_len
                blank_mask = time_idx >= out_len

                # Start inner loop
                while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols):
                    # Batch prediction and joint network steps
                    # If very first prediction step, submit SOS tag (blank) to pred_step.
                    # This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state
                    if time_idx == 0 and symbols_added == 0:
                        g, hidden_prime = self._pred_step(self._SOS, hidden, batch_size=batchsize)
                    else:
                        # Perform batch step prediction of decoder, getting new states and scores ("g")
                        g, hidden_prime = self._pred_step(last_label, hidden, batch_size=batchsize)

                    # Batched joint step - Output = [B, V + 1]
                    logp = self._joint_step(f, g, log_normalize=None)[:, 0, 0, :]

                    if logp.dtype != torch.float32:
                        logp = logp.float()

                    # Get index k, of max prob for batch
                    v, k = logp.max(1)
                    del v, g, logp

                    # Update blank mask with current predicted blanks
                    # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U)
                    k_is_blank = k == self._blank_index
                    blank_mask.bitwise_or_(k_is_blank)

                    del k_is_blank

                    # If all samples predict / have predicted prior blanks, exit loop early
                    # This is equivalent to a single sample having predicted blank
                    if blank_mask.all():
                        not_blank = False
                    else:
                        # Collect batch indices where blanks occurred now/past
                        blank_indices = []
                        if hidden is not None:
                            blank_indices = (blank_mask == 1).nonzero(as_tuple=False)

                        # Recover prior state for all samples which predicted blank now/past
                        if hidden is not None:
                            # LSTM has 2 states
                            for state_id in range(len(hidden)):
                                hidden_prime[state_id][:, blank_indices, :] = hidden[state_id][:, blank_indices, :]

                        # Recover prior predicted label for all samples which predicted blank now/past
                        k[blank_indices] = last_label[blank_indices, 0]

                        # Update new label and hidden state for next iteration
                        last_label = k.clone().view(-1, 1)
                        hidden = hidden_prime

                        # Update predicted labels, accounting for time mask
                        # If blank was predicted even once, now or in the past,
                        # force the current predicted label to also be blank.
                        # This ensures that blanks propagate across all timesteps
                        # once they have occurred (normally the stopping condition of the sample level loop).
                        for kidx, ki in enumerate(k):
                            if blank_mask[kidx] == 0:
                                label[kidx].append(ki)
                                timesteps[kidx].append(time_idx)

                        symbols_added += 1

        return label, timesteps

    @torch.no_grad()
    def _greedy_decode_masked(self, x: torch.Tensor, out_len: torch.Tensor, device: torch.device):
        # x: [B, T, D]
        # out_len: [B]
        # device: torch.device

        # Initialize state
        hidden = None
        batchsize = x.shape[0]

        # Output string buffer
        label = [[] for _ in range(batchsize)]
        timesteps = [[] for _ in range(batchsize)]

        # Last Label buffer + Last Label without blank buffer
        # batch level equivalent of the last_label
        last_label = torch.full([batchsize, 1], fill_value=self._blank_index, dtype=torch.long, device=device)
        last_label_without_blank = last_label.clone()

        # Mask buffers
        blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device)

        # Get max sequence length
        max_out_len = out_len.max()
        for time_idx in range(max_out_len):
            f = x.narrow(dim=1, start=time_idx, length=1)  # [B, 1, D]

            # Prepare t timestamp batch variables
            not_blank = True
            symbols_added = 0

            # Reset blank mask
            blank_mask.mul_(False)

            # Update blank mask with time mask
            # Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch)
            # Forcibly mask with "blank" tokens, for all samples where current time step T > seq_len
            blank_mask = time_idx >= out_len

            # Start inner loop
            while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols):
                # Batch prediction and joint network steps
                # If very first prediction step, submit SOS tag (blank) to pred_step.
                # This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state
                if time_idx == 0 and symbols_added == 0:
                    g, hidden_prime = self._pred_step(self._SOS, hidden, batch_size=batchsize)
                else:
                    # Set a dummy label for the blank value.
                    # This value will be overwritten by "blank" again by the last label update below.
                    # This is done as the vocabulary of the prediction network does not contain the "blank" token of RNNT
                    last_label_without_blank_mask = last_label == self._blank_index
                    last_label_without_blank[last_label_without_blank_mask] = 0  # temp change of label
                    last_label_without_blank[~last_label_without_blank_mask] = last_label[
                        ~last_label_without_blank_mask
                    ]

                    # Perform batch step prediction of decoder, getting new states and scores ("g")
                    g, hidden_prime = self._pred_step(last_label_without_blank, hidden, batch_size=batchsize)

                # Batched joint step - Output = [B, V + 1]
                logp = self._joint_step(f, g, log_normalize=None)[:, 0, 0, :]

                if logp.dtype != torch.float32:
                    logp = logp.float()

                # Get index k, of max prob for batch
                v, k = logp.max(1)
                del v, g, logp

                # Update blank mask with current predicted blanks
                # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U)
                k_is_blank = k == self._blank_index
                blank_mask.bitwise_or_(k_is_blank)

                # If all samples predict / have predicted prior blanks, exit loop early
                # This is equivalent to a single sample having predicted blank
                if blank_mask.all():
                    not_blank = False
                else:
                    # Collect batch indices where blanks occurred now/past
                    blank_indices = []
                    if hidden is not None:
                        blank_indices = (blank_mask == 1).nonzero(as_tuple=False)

                    # Recover prior state for all samples which predicted blank now/past
                    if hidden is not None:
                        # LSTM has 2 states
                        for state_id in range(len(hidden)):
                            hidden_prime[state_id][:, blank_indices, :] = hidden[state_id][:, blank_indices, :]

                    # Recover prior predicted label for all samples which predicted blank now/past
                    k[blank_indices] = last_label[blank_indices, 0]

                    # Update new label and hidden state for next iteration
                    last_label = k.view(-1, 1)
                    hidden = hidden_prime

                    # Update predicted labels, accounting for time mask
                    # If blank was predicted even once, now or in the past,
                    # force the current predicted label to also be blank.
                    # This ensures that blanks propagate across all timesteps
                    # once they have occurred (normally the stopping condition of the sample level loop).
                    for kidx, ki in enumerate(k):
                        if blank_mask[kidx] == 0:
                            label[kidx].append(ki)
                            timesteps[kidx].append(time_idx)

                    symbols_added += 1

        return label, timesteps
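
# Illustrative usage sketch (not executed here): construction mirrors GreedyRNNTInfer,
# but all utterances in the batch are decoded together. Whether the faster
# `_greedy_decode_blank_as_pad` path or the `_greedy_decode_masked` fallback runs is
# decided once in __init__ from `decoder.blank_as_pad`; the call site is identical.
#
#   batched_decoder = GreedyBatchedRNNTInfer(
#       decoder_model=decoder,              # assumed trained prediction network
#       joint_model=joint,                  # assumed trained joint network
#       blank_index=len(vocabulary),
#       max_symbols_per_step=30,
#   )
#   (hypotheses,) = batched_decoder(encoder_output=encoded, encoded_lengths=encoded_len)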


@dataclass
class GreedyRNNTInferConfig:
    max_symbols_per_step: Optional[int] = None


@dataclass
class GreedyBatchedRNNTInferConfig:
    max_symbols_per_step: Optional[int] = None
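
# Illustrative sketch of using these dataclasses as structured configs. The
# OmegaConf usage below is an assumption about the calling code, not something
# this module requires:
#
#   from omegaconf import OmegaConf
#
#   cfg = OmegaConf.structured(GreedyBatchedRNNTInferConfig(max_symbols_per_step=10))
#   decoding = GreedyBatchedRNNTInfer(
#       decoder_model=decoder,
#       joint_model=joint,
#       blank_index=len(vocabulary),
#       max_symbols_per_step=cfg.max_symbols_per_step,
#   )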