Source code for nemo.collections.asr.parts.submodules.ctc_greedy_decoding

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf

from nemo.collections.asr.parts.context_biasing import BoostingTreeModelConfig, GPUBoostingTreeModel
from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel
from nemo.collections.asr.parts.utils import rnnt_utils
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin
from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.core.classes import Typing, typecheck
from nemo.core.neural_types import HypothesisType, LengthsType, LogprobsType, NeuralType
from nemo.core.utils.cuda_python_utils import (
    NeMoCUDAPythonException,
    check_cuda_python_cuda_graphs_conditional_nodes_supported,
    cu_call,
    run_nvrtc,
    with_conditional_node,
)
from nemo.core.utils.optional_libs import CUDA_PYTHON_AVAILABLE, cuda_python_required
from nemo.utils import logging, logging_mode
from nemo.utils.enum import PrettyStrEnum

# The CUDA runtime bindings are imported only when the optional-libs check reports them as available,
# so importing this module does not fail on installations without them.
if CUDA_PYTHON_AVAILABLE:
    from cuda.bindings import runtime as cudart

# Negative infinity as a plain Python float.
NEG_INF = float("-inf")


class CTCDecoderCudaGraphsState:
    """
    State for greedy CTC with NGPU-LM decoding. Used only with CUDA graphs.
    In initialization phase it is possible to assign values (tensors) to the state.
    For algorithm code the storage should be reused (prefer copy data instead of assigning tensors).

    All tensors are preallocated in ``__init__`` with the requested ``batch_size`` / ``max_time`` / ``vocab_dim``
    on the requested ``device``; ``need_reinit`` tells whether a new input still fits into this storage.
    """

    max_time: int  # maximum length of internal storage for time dimension
    batch_size: int  # (maximum) length of internal storage for batch dimension
    device: torch.device  # device to store preallocated tensors
    float_dtype: torch.dtype  # dtype used for all floating point storage

    frame_idx: torch.Tensor  # scalar (0-dim) long tensor with the current frame index
    active_mask: torch.Tensor  # scalar (0-dim) bool tensor, initialised to True

    decoder_outputs: torch.Tensor  # decoder output (probs), shape [batch_size, max_time, vocab_dim]
    decoder_lengths: torch.Tensor  # decoder output lengths, shape [batch_size]

    labels: torch.Tensor  # storage for current labels
    last_labels: torch.Tensor  # storage for previous labels
    scores: torch.Tensor  # storage for current scores

    batch_indices: torch.Tensor  # indices of elements in batch (constant, range [0, batch_size-1])

    # LM-related storage. `batch_lm_states` is allocated in `__init__`; the other two stay None
    # until something outside of this class assigns them.
    batch_lm_states: Optional[torch.Tensor] = None
    lm_scores: Optional[torch.Tensor] = None
    batch_lm_states_candidates: Optional[torch.Tensor] = None

    # Names match the attributes that `__init__` actually creates (note the plural "predictions").
    predictions_labels: torch.Tensor  # shape [batch_size, max_time], long
    predictions_logprobs: torch.Tensor  # shape [batch_size, max_time], float_dtype

    full_graph = None

    def __init__(
        self,
        batch_size: int,
        max_time: int,
        vocab_dim: int,
        device: torch.device,
        float_dtype: torch.dtype,
    ):
        """

        Args:
            batch_size: batch size for encoder output storage
            max_time: maximum time for encoder output storage
            vocab_dim: number of vocabulary tokens (including blank)
            device: device to store tensors
            float_dtype: default float dtype for tensors (should match projected encoder output)
        """
        self.device = device
        self.float_dtype = float_dtype
        self.batch_size = batch_size
        self.max_time = max_time

        # current frame index, kept as a scalar tensor (used to check if the decoding is finished)
        self.frame_idx = torch.tensor(0, dtype=torch.long, device=device)
        self.active_mask = torch.tensor(True, dtype=torch.bool, device=device)

        self.decoder_outputs = torch.zeros(
            (self.batch_size, self.max_time, vocab_dim),
            dtype=float_dtype,
            device=self.device,
        )
        self.decoder_lengths = torch.zeros((self.batch_size,), dtype=torch.long, device=self.device)

        self.labels = torch.zeros([self.batch_size], dtype=torch.long, device=self.device)
        self.last_labels = torch.zeros([self.batch_size], dtype=torch.long, device=self.device)
        self.scores = torch.zeros([self.batch_size], dtype=float_dtype, device=self.device)

        # indices of elements in batch (constant)
        self.batch_indices = torch.arange(self.batch_size, dtype=torch.long, device=self.device)

        # LM states
        self.batch_lm_states = torch.zeros([batch_size], dtype=torch.long, device=device)

        # per-frame outputs
        self.predictions_labels = torch.zeros([batch_size, max_time], device=device, dtype=torch.long)
        self.predictions_logprobs = torch.zeros([batch_size, max_time], device=device, dtype=float_dtype)

    def need_reinit(self, logits: torch.Tensor) -> bool:
        """
        Check if need to reinit state: larger batch_size/max_time, or new device

        Args:
            logits: tensor whose first two dimensions are compared against ``batch_size`` and ``max_time``

        Returns:
            True if ``logits`` does not fit into the preallocated storage or lives on a device with a different index
        """
        return (
            self.batch_size < logits.shape[0]
            or self.max_time < logits.shape[1]
            or self.device.index != logits.device.index
        )


def pack_hypotheses(
    hypotheses: List[rnnt_utils.Hypothesis],
    logitlen: torch.Tensor,
) -> List[rnnt_utils.Hypothesis]:
    """
    Finalize a list of hypotheses in place.

    For every hypothesis: `y_sequence` is converted to a long tensor, `length` is filled from
    `logitlen` (when given) and a non-None `dec_state` is moved to CPU.

    Args:
        hypotheses: hypotheses to finalize; the objects are modified in place
        logitlen: per-hypothesis lengths, indexed by position in `hypotheses`; may be None,
            in which case `length` is left untouched

    Returns:
        The same list that was passed in.
    """
    lengths_on_cpu = None
    if logitlen is not None:
        # anything exposing `.cpu` is treated as a tensor and moved once, up front
        lengths_on_cpu = logitlen.to('cpu') if hasattr(logitlen, 'cpu') else logitlen

    for position, hypothesis in enumerate(hypotheses):
        hypothesis.y_sequence = torch.tensor(hypothesis.y_sequence, dtype=torch.long)

        if lengths_on_cpu is not None:
            hypothesis.length = lengths_on_cpu[position]

        state = hypothesis.dec_state
        if state is not None:
            hypothesis.dec_state = _states_to_device(state)

    return hypotheses


def _states_to_device(dec_state, device='cpu'):
    if torch.is_tensor(dec_state):
        dec_state = dec_state.to(device)

    elif isinstance(dec_state, (list, tuple)):
        dec_state = tuple(_states_to_device(dec_i, device) for dec_i in dec_state)

    return dec_state


_DECODER_LENGTHS_NONE_WARNING = "Passing in decoder_lengths=None for CTC decoding is likely to be an error, since it is unlikely that each element of your batch has exactly the same length. decoder_lengths will default to decoder_output.shape[0]."


class GreedyCTCInfer(Typing, ConfidenceMethodMixin):
    """A greedy CTC decoder.

    Provides a common abstraction for sample level and batch level greedy decoding.

    Args:
        blank_id: int index of the blank token. Can be 0 or len(vocabulary).
        preserve_alignments: Bool flag which preserves the history of logprobs generated during
            decoding (sample / batched). When set to true, the Hypothesis will contain a non-null value for
            `alignments`: a tuple of (log-probabilities, arg-max labels), both torch.Tensors.
        compute_timestamps: A bool flag, which determines whether to compute the character/subword, or
            word based timestamp mapping the output log-probabilities to discrete intervals of timestamps.
            In this class `Hypothesis.timestamp` is filled with the list of frame indices whose
            predicted label is not blank. Forced to True when `preserve_frame_confidence` is set.
        preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores
            generated during decoding. When set to true, the Hypothesis will contain
            the non-null value for `frame_confidence` in it. Here, `frame_confidence` is a List of floats.
        confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame
            confidence scores.

            name: The method name (str).
                Supported values:
                    - 'max_prob' for using the maximum token probability as a confidence.
                    - 'entropy' for using a normalized entropy of a log-likelihood vector.

            entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`.
                Supported values:
                    - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided,
                        the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)).
                        Note that for this entropy, the alpha should comply the following inequality:
                        (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1)
                        where V is the model vocabulary size.
                    - 'tsallis' for the Tsallis entropy with the Boltzmann constant one.
                        Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)),
                        where α is a parameter. When α == 1, it works like the Gibbs entropy.
                        More: https://en.wikipedia.org/wiki/Tsallis_entropy
                    - 'renyi' for the Rényi entropy.
                        Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)),
                        where α is a parameter. When α == 1, it works like the Gibbs entropy.
                        More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy

            alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0.
                When the alpha equals one, scaling is not applied to 'max_prob',
                and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i))

            entropy_norm: A mapping of the entropy value to the interval [0,1].
                Supported values:
                    - 'lin' for using the linear mapping.
                    - 'exp' for using exponential mapping with linear shift.

    """

    @property
    def input_types(self):
        """Returns definitions of module input ports."""
        # Input can be of dimension -
        # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]

        return {
            "decoder_output": NeuralType(None, LogprobsType()),
            "decoder_lengths": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        """Returns definitions of module output ports."""
        return {"predictions": [NeuralType(elements_type=HypothesisType())]}

    def __init__(
        self,
        blank_id: int,
        preserve_alignments: bool = False,
        compute_timestamps: bool = False,
        preserve_frame_confidence: bool = False,
        confidence_method_cfg: Optional[DictConfig] = None,
    ):
        super().__init__()

        self.blank_id = blank_id
        self.preserve_alignments = preserve_alignments
        # we need timestamps to extract non-blank per-frame confidence
        self.compute_timestamps = compute_timestamps | preserve_frame_confidence
        self.preserve_frame_confidence = preserve_frame_confidence

        # set confidence calculation method
        self._init_confidence_method(confidence_method_cfg)

    @typecheck()
    def forward(
        self,
        decoder_output: torch.Tensor,
        decoder_lengths: Optional[torch.Tensor],
    ):
        """Returns a list of hypotheses given an input batch of decoder outputs.
        Every frame is decoded independently by taking its arg-max label
        (or by using the provided label directly); blanks and repeats are not removed here.

        Args:
            decoder_output: A tensor of size (batch, timesteps, features) with log probabilities,
                or (batch, timesteps) where each timestep is a label.
            decoder_lengths: per-utterance number of valid timesteps in `decoder_output`.
                If None, a warning is logged and every utterance is decoded over all timesteps.

        Returns:
            A 1-tuple holding the packed list of `batch` Hypotheses.

        Raises:
            ValueError: if `decoder_output` is neither 2- nor 3-dimensional, or if alignments /
                frame confidence are requested while labels (2-dimensional input) were provided.
        """

        logging.warning(
            "CTC decoding strategy 'greedy' is slower than 'greedy_batch', which implements the same exact interface. "
            "Consider changing your strategy to 'greedy_batch' for a free performance improvement.",
            mode=logging_mode.ONCE,
        )

        if decoder_lengths is None:
            logging.warning(_DECODER_LENGTHS_NONE_WARNING, mode=logging_mode.ONCE)

        with torch.inference_mode():
            hypotheses = []
            # Process each sequence independently

            if decoder_output.is_cuda:
                # This two-liner is around twenty times faster than:
                # `prediction_cpu_tensor = decoder_output.cpu()`
                # cpu() does not use pinned memory, meaning that a slow pageable
                # copy must be done instead.
                prediction_cpu_tensor = torch.empty(
                    decoder_output.shape, dtype=decoder_output.dtype, device=torch.device("cpu"), pin_memory=True
                )
                prediction_cpu_tensor.copy_(decoder_output, non_blocking=True)
                # A non-blocking device-to-host copy returns before the data has arrived.
                # The host reads `prediction_cpu_tensor` below, so wait for the copy explicitly:
                # relying on `decoder_lengths.cpu()` to synchronize is not enough, since
                # `decoder_lengths` may be None or may already live on the CPU.
                torch.cuda.current_stream(decoder_output.device).synchronize()
            else:
                prediction_cpu_tensor = decoder_output

            if decoder_lengths is not None and isinstance(decoder_lengths, torch.Tensor):
                # Before this change, self._greedy_decode_labels would copy
                # each scalar from GPU to CPU one at a time, in the line:
                # prediction = prediction[:out_len]
                # Doing one GPU to CPU copy ahead of time amortizes that overhead.
                decoder_lengths = decoder_lengths.cpu()

            if prediction_cpu_tensor.ndim < 2 or prediction_cpu_tensor.ndim > 3:
                raise ValueError(
                    f"`decoder_output` must be a tensor of shape [B, T] (labels, int) or "
                    f"[B, T, V] (log probs, float). Provided shape = {prediction_cpu_tensor.shape}"
                )

            # determine type of input - logprobs or labels
            if prediction_cpu_tensor.ndim == 2:  # labels
                greedy_decode = self._greedy_decode_labels
            else:
                greedy_decode = self._greedy_decode_logprobs

            for ind in range(prediction_cpu_tensor.shape[0]):
                out_len = decoder_lengths[ind] if decoder_lengths is not None else None
                hypothesis = greedy_decode(prediction_cpu_tensor[ind], out_len)
                hypotheses.append(hypothesis)

            # Pack results into Hypotheses
            packed_result = pack_hypotheses(hypotheses, decoder_lengths)

        return (packed_result,)

    @torch.no_grad()
    def _greedy_decode_logprobs(self, x: torch.Tensor, out_len: Optional[torch.Tensor]):
        """Decode one utterance given per-frame log probabilities.

        Args:
            x: log probabilities of a single utterance, shape [T, D]
            out_len: number of valid frames (scalar); frames past it are dropped. None keeps all frames.

        Returns:
            Hypothesis whose `y_sequence` is the list of per-frame arg-max labels and whose `score`
            is the sum of the arg-max log probabilities over the non-blank frames.
        """
        # Initialize blank state and empty label set in Hypothesis
        hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestamp=[], last_token=None)
        prediction = x.cpu()

        if out_len is not None:
            prediction = prediction[:out_len]

        prediction_logprobs, prediction_labels = prediction.max(dim=-1)

        non_blank_ids = prediction_labels != self.blank_id
        hypothesis.y_sequence = prediction_labels.tolist()
        hypothesis.score = (prediction_logprobs[non_blank_ids]).sum()

        if self.preserve_alignments:
            # Preserve the logprobs, as well as labels after argmax
            hypothesis.alignments = (prediction.clone(), prediction_labels.clone())

        if self.compute_timestamps:
            # frame indices at which a non-blank label was predicted
            hypothesis.timestamp = torch.nonzero(non_blank_ids, as_tuple=False)[:, 0].tolist()

        if self.preserve_frame_confidence:
            hypothesis.frame_confidence = self._get_confidence(prediction)

        return hypothesis

    @torch.no_grad()
    def _greedy_decode_labels(self, x: torch.Tensor, out_len: Optional[torch.Tensor]):
        """Decode one utterance given per-frame labels.

        Args:
            x: labels of a single utterance, shape [T]
            out_len: number of valid frames (scalar); frames past it are dropped. None keeps all frames.

        Returns:
            Hypothesis whose `y_sequence` is the list of per-frame labels; `score` is set to -1.0
            because no log probabilities are available.

        Raises:
            ValueError: if alignments or per-frame confidence were requested, since both need log probabilities.
        """
        # Initialize blank state and empty label set in Hypothesis
        hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestamp=[], last_token=None)
        prediction_labels = x.cpu()

        if out_len is not None:
            prediction_labels = prediction_labels[:out_len]

        non_blank_ids = prediction_labels != self.blank_id
        hypothesis.y_sequence = prediction_labels.tolist()
        hypothesis.score = -1.0

        if self.preserve_alignments:
            raise ValueError("Requested for alignments, but predictions provided were labels, not log probabilities.")

        if self.compute_timestamps:
            # frame indices at which a non-blank label was provided
            hypothesis.timestamp = torch.nonzero(non_blank_ids, as_tuple=False)[:, 0].tolist()

        if self.preserve_frame_confidence:
            raise ValueError(
                "Requested for per-frame confidence, but predictions provided were labels, not log probabilities."
            )

        return hypothesis

    def __call__(self, *args, **kwargs):
        """Calls `forward` directly with the given arguments."""
        return self.forward(*args, **kwargs)
class GreedyBatchedCTCInfer(Typing, ConfidenceMethodMixin, WithOptionalCudaGraphs):
    """A vectorized greedy CTC decoder.

    This is basically always faster than GreedyCTCInfer, and supports
    the same interface. See issue #8891 on github for what is wrong
    with GreedyCTCInfer. GreedyCTCInfer loops over each element in the
    batch, running kernels at batch size one. CPU overheads end up
    dominating. This implementation does appropriate masking to
    do the same operation in a batched manner.

    Args:
        blank_id: int index of the blank token. Can be 0 or len(vocabulary).
            NOTE(review): the fusion-model decoding paths slice the vocabulary with `[:, :-1]`,
            i.e. they assume the blank is the last index - confirm before using fusion models with blank_id=0.
        preserve_alignments: Bool flag which preserves the history of logprobs generated during
            decoding (sample / batched). When set to true, the Hypothesis will contain
            the non-null value for `alignments` in it. Here, `alignments` is a tuple of
            (log probabilities, argmax labels) torch.Tensors restricted to the valid frames.
        compute_timestamps: A bool flag, which determines whether to compute the character/subword, or
            word based timestamp mapping the output log-probabilities to discrete intervals of timestamps.
            The timestamps will be available in the returned Hypothesis.timestamp; this decoder fills it
            with the list of frame indices at which a non-blank label was predicted.
            It is forced on when `preserve_frame_confidence` is set.
        preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores
            generated during decoding. When set to true, the Hypothesis will contain
            the non-null value for `frame_confidence` in it. Here, `frame_confidence` is a List of floats.
        confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame
            confidence scores.

            name: The method name (str).
                Supported values:
                    - 'max_prob' for using the maximum token probability as a confidence.
                    - 'entropy' for using a normalized entropy of a log-likelihood vector.

            entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`.
                Supported values:
                    - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided,
                        the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)).
                        Note that for this entropy, the alpha should comply the following inequality:
                        (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1)
                        where V is the model vocabulary size.
                    - 'tsallis' for the Tsallis entropy with the Boltzmann constant one.
                        Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)),
                        where α is a parameter. When α == 1, it works like the Gibbs entropy.
                        More: https://en.wikipedia.org/wiki/Tsallis_entropy
                    - 'renyi' for the Rényi entropy.
                        Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)),
                        where α is a parameter. When α == 1, it works like the Gibbs entropy.
                        More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy

            alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0.
                When the alpha equals one, scaling is not applied to 'max_prob',
                and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i))

            entropy_norm: A mapping of the entropy value to the interval [0,1].
                Supported values:
                    - 'lin' for using the linear mapping.
                    - 'exp' for using exponential mapping with linear shift.
        ngram_lm_model: Optional path to an n-gram language model; when given, it is loaded with
            `NGramGPULanguageModel.from_file(..., vocab_size=blank_id)` and used as a fusion model.
        ngram_lm_alpha: Weight applied to the n-gram LM scores before they are added to the log probabilities.
        boosting_tree: Optional boosting tree config; when non-empty, a `GPUBoostingTreeModel` is built
            from it and used as a fusion model.
        boosting_tree_alpha: Weight applied to the boosting tree scores.
        allow_cuda_graphs: Whether CUDA graphs may be used. Only honoured when at least one fusion model
            is configured; without fusion models CUDA graphs are always disabled.
        tokenizer: Optional tokenizer forwarded to `GPUBoostingTreeModel.from_config`.
    """

    fusion_models: Optional[List[NGramGPULanguageModel]]
    fusion_models_alpha: Optional[List[float]]

    class CudaGraphsMode(PrettyStrEnum):
        FULL_GRAPH = "full_graph"  # Cuda graphs with conditional nodes, fastest implementation
        NO_WHILE_LOOPS = "no_while_loops"  # Decoding with PyTorch while loops + partial Cuda graphs
        NO_GRAPHS = "no_graphs"  # decoding without graphs: plain Python loop over the same step functions

    @property
    def input_types(self):
        """Returns definitions of module input ports."""
        # Input can be of dimension -
        # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]

        return {
            "decoder_output": NeuralType(None, LogprobsType()),
            "decoder_lengths": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        """Returns definitions of module output ports."""
        return {"predictions": [NeuralType(elements_type=HypothesisType())]}

    def __init__(
        self,
        blank_id: int,
        preserve_alignments: bool = False,
        compute_timestamps: bool = False,
        preserve_frame_confidence: bool = False,
        confidence_method_cfg: Optional[DictConfig] = None,
        ngram_lm_model: Optional[str | Path] = None,
        ngram_lm_alpha: float = 0.0,
        boosting_tree: Optional[BoostingTreeModelConfig] = None,
        boosting_tree_alpha: float = 0.0,
        allow_cuda_graphs: bool = True,
        tokenizer: Optional[TokenizerSpec] = None,
    ):
        """Configure the decoder and load the optional fusion models (see the class docstring for arguments)."""
        super().__init__()

        self.blank_id = blank_id
        self.preserve_alignments = preserve_alignments
        # we need timestamps to extract non-blank per-frame confidence
        self.compute_timestamps = compute_timestamps or preserve_frame_confidence
        self.preserve_frame_confidence = preserve_frame_confidence

        # set confidence calculation method
        self._init_confidence_method(confidence_method_cfg)

        # load fusion models from paths (ngram_lm_model and boosting_tree_model)
        self.fusion_models, self.fusion_models_alpha = [], []
        if ngram_lm_model is not None:
            self.fusion_models.append(
                NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self.blank_id)
            )
            self.fusion_models_alpha.append(ngram_lm_alpha)
        if boosting_tree and not BoostingTreeModelConfig.is_empty(boosting_tree):
            self.fusion_models.append(GPUBoostingTreeModel.from_config(boosting_tree, tokenizer=tokenizer))
            self.fusion_models_alpha.append(boosting_tree_alpha)

        if not self.fusion_models:
            # No fusion models: plain argmax decoding, CUDA graphs are never used.
            self.fusion_models = None
            self.fusion_models_alpha = None
            self.allow_cuda_graphs = False
            self.cuda_graphs_mode = None
        else:
            self.allow_cuda_graphs = allow_cuda_graphs
            self.cuda_graphs_mode = None
            self.maybe_enable_cuda_graphs()

        # Always define the CUDA-graph bookkeeping attributes so that later mode
        # switches never hit a missing attribute, whichever branch ran above.
        self.state: CTCDecoderCudaGraphsState | None = None
        self.cuda_graphs_allow_fallback = True

    @typecheck()
    def forward(
        self,
        decoder_output: torch.Tensor,
        decoder_lengths: Optional[torch.Tensor],
    ):
        """Returns a list of hypotheses given an input batch of decoder outputs.

        Args:
            decoder_output: A tensor of size (batch, timesteps, features) or (batch, timesteps)
                (each timestep is a label).
            decoder_lengths: tensor with the valid length of each sequence in the batch, or None,
                in which case every sequence is assumed to span all timesteps (a warning is logged once).

        Returns:
            A 1-tuple holding the packed list of Hypotheses, one per batch element.

        Raises:
            NotImplementedError: if labels (2-D input) are given while fusion models are configured.
        """

        input_decoder_lengths = decoder_lengths

        if decoder_lengths is None:
            logging.warning(_DECODER_LENGTHS_NONE_WARNING, mode=logging_mode.ONCE)
            decoder_lengths = torch.tensor(
                [decoder_output.shape[1]], dtype=torch.long, device=decoder_output.device
            ).expand(decoder_output.shape[0])

        # GreedyCTCInfer::forward(), by accident, works with
        # decoder_lengths on either CPU or GPU when decoder_output is
        # on GPU. For the sake of backwards compatibility, we also
        # allow decoder_lengths to be on the CPU device. In this case,
        # we simply copy the decoder_lengths from CPU to GPU. If both
        # tensors are already on the same device, this is a no-op.
        decoder_lengths = decoder_lengths.to(decoder_output.device)

        if decoder_output.ndim == 2:
            if self.fusion_models is not None:
                raise NotImplementedError(
                    "Decoding with fusion models requires log probabilities of shape (batch, timesteps, features); "
                    "got labels of shape (batch, timesteps)."
                )
            hypotheses = self._greedy_decode_labels_batched(decoder_output, decoder_lengths)
        else:
            hypotheses = self._greedy_decode_logprobs_batched(decoder_output, decoder_lengths)
        # pack_hypotheses receives the caller-provided lengths (possibly None), not the synthesized ones.
        packed_result = pack_hypotheses(hypotheses, input_decoder_lengths)
        return (packed_result,)

    @torch.no_grad()
    def _greedy_decode_logprobs_batched(self, x: torch.Tensor, out_len: torch.Tensor):
        """Greedy-decode a batch of log probabilities into one Hypothesis per utterance.

        Args:
            x: log probabilities, indexed as [B, T, D].
            out_len: valid lengths, indexed as [B], on the same device as `x`.

        Returns:
            List of B Hypotheses. `y_sequence` holds the per-frame labels of the valid frames
            (blanks/repeats are not removed here), `score` is the sum of the selected
            log probabilities over valid non-blank frames.
        """
        # x: [B, T, D]
        # out_len: [B]

        batch_size = x.shape[0]
        max_time = x.shape[1]

        predictions = x
        if self.fusion_models is None:
            # In CTC greedy decoding, each output maximum likelihood token
            # is calculated independent of the other tokens.
            predictions_logprobs, predictions_labels = predictions.max(dim=-1)
        else:
            for fusion_model in self.fusion_models:
                fusion_model.to(x.device)

            # decoding with NGPU-LM and Boosting Tree
            if self.cuda_graphs_mode is not None and x.device.type == "cuda":
                predictions_labels, predictions_logprobs = (
                    self._greedy_decode_logprobs_batched_fusion_models_cuda_graphs(logits=x, out_len=out_len)
                )
            else:
                predictions_labels, predictions_logprobs = self._greedy_decode_logprobs_batched_fusion_models_torch(
                    logits=x, out_len=out_len
                )

        # Since predictions_logprobs is a padded matrix in the time
        # dimension, we consider invalid timesteps to be "blank".
        time_steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(batch_size, max_time)
        non_blank_ids_mask = torch.logical_and(predictions_labels != self.blank_id, time_steps < out_len.unsqueeze(1))
        # Sum the non-blank labels to compute the score of the
        # transcription. This follows from Eq. (3) of "Connectionist
        # Temporal Classification: Labelling Unsegmented Sequence Data
        # with Recurrent Neural Networks".
        scores = torch.where(non_blank_ids_mask, predictions_logprobs, 0.0).sum(axis=1)

        scores = scores.cpu()
        predictions_labels = predictions_labels.cpu()
        out_len = out_len.cpu()
        if self.compute_timestamps:
            # One device-to-host copy for the whole batch instead of one
            # synchronizing copy per utterance inside the loop below.
            non_blank_ids_mask = non_blank_ids_mask.cpu()

        if self.preserve_alignments or self.preserve_frame_confidence:
            predictions = predictions.cpu()

        # Loop-invariant sanity check on the label dtype.
        assert predictions_labels.dtype == torch.int64

        hypotheses = []

        # This mimics the for loop in GreedyCTCInfer::forward.
        for i in range(batch_size):
            hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestamp=[], last_token=None)
            hypothesis.score = scores[i]

            prediction_labels_no_padding = predictions_labels[i, : out_len[i]].tolist()

            hypothesis.y_sequence = prediction_labels_no_padding

            if self.preserve_alignments:
                hypothesis.alignments = (
                    predictions[i, : out_len[i], :].clone(),
                    predictions_labels[i, : out_len[i]].clone(),
                )
            if self.compute_timestamps:
                # TODO: Could do this in a vectorized manner... Would
                # prefer to have nonzero_static, though, for sanity.
                # Or do a prefix sum on out_len
                hypothesis.timestamp = torch.nonzero(non_blank_ids_mask[i], as_tuple=False)[:, 0].tolist()
            if self.preserve_frame_confidence:
                hypothesis.frame_confidence = self._get_confidence(predictions[i, : out_len[i], :])
            hypotheses.append(hypothesis)

        return hypotheses

    @torch.no_grad()
    def _greedy_decode_labels_batched(self, x: torch.Tensor, out_len: torch.Tensor):
        """
        This does greedy decoding in the case where you have already found the
        most likely token at each timestep.

        Args:
            x: label ids, indexed as [B, T].
            out_len: valid lengths, indexed as [B], on the same device as `x`.

        Returns:
            List of B Hypotheses with `score` fixed to -1.0 (no log probabilities available).

        Raises:
            ValueError: if alignments or per-frame confidence were requested (raised while
                processing the first utterance, so an empty batch returns an empty list).
        """
        # x: [B, T]
        # out_len: [B]

        batch_size = x.shape[0]
        max_time = x.shape[1]

        predictions_labels = x
        # Padded timesteps are treated as blank.
        time_steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(batch_size, max_time)
        non_blank_ids_mask = torch.logical_and(predictions_labels != self.blank_id, time_steps < out_len.unsqueeze(1))
        predictions_labels = predictions_labels.cpu()
        out_len = out_len.cpu()
        if self.compute_timestamps:
            # One device-to-host copy for the whole batch instead of one
            # synchronizing copy per utterance inside the loop below.
            non_blank_ids_mask = non_blank_ids_mask.cpu()

        hypotheses = []

        for i in range(batch_size):
            hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestamp=[], last_token=None)
            hypothesis.y_sequence = predictions_labels[i, : out_len[i]].tolist()
            hypothesis.score = -1.0

            if self.preserve_alignments:
                raise ValueError(
                    "Requested for alignments, but predictions provided were labels, not log probabilities."
                )
            if self.compute_timestamps:
                # TODO: Could do this in a vectorized manner... Would
                # prefer to have nonzero_static, though, for sanity.
                # Or do a prefix sum on out_len
                hypothesis.timestamp = torch.nonzero(non_blank_ids_mask[i], as_tuple=False)[:, 0].tolist()
            if self.preserve_frame_confidence:
                raise ValueError(
                    "Requested for per-frame confidence, but predictions provided were labels, not log probabilities."
                )

            hypotheses.append(hypothesis)

        return hypotheses

    @torch.no_grad()
    def _greedy_decode_logprobs_batched_fusion_models_torch(self, logits: torch.Tensor, out_len: torch.Tensor):
        """Frame-by-frame greedy decoding with fusion-model rescoring, in plain PyTorch.

        For every frame the plain argmax label is kept when it is blank or repeats the previous
        label; otherwise the label is re-selected among the non-blank labels after adding the
        weighted fusion scores, and the fusion states are advanced along the chosen label.

        Args:
            logits: log probabilities, indexed as [B, T, D]; the last vocabulary entry is
                excluded from fusion scoring (treated as blank).
            out_len: valid lengths; not used here, every frame up to T is decoded and the
                caller masks the padded frames.

        Returns:
            Tuple (labels [B, T], log probabilities [B, T]) of the selected tokens.
        """
        batch_size = logits.shape[0]
        max_time = logits.shape[1]
        device = logits.device
        float_dtype = logits.dtype

        batch_indices = torch.arange(batch_size, device=device, dtype=torch.long)

        # Step 1: Initialization
        batch_fusion_states_list = []
        for fusion_model in self.fusion_models:
            batch_fusion_states_list.append(fusion_model.get_init_states(batch_size=batch_size, bos=True))
        last_labels = torch.full([batch_size], fill_value=self.blank_id, device=device, dtype=torch.long)
        # resulting labels and logprobs storage
        predictions_labels = torch.zeros([batch_size, max_time], device=device, dtype=torch.long)
        predictions_logprobs = torch.zeros([batch_size, max_time], device=device, dtype=float_dtype)

        for i in range(max_time):
            # Step 2: Get most likely labels for current frame
            log_probs, labels = logits[:, i].max(dim=-1)
            log_probs_w_fusion = logits[:, i].clone()

            # Step 3: Get fusion scores
            fusion_states_candidates_list = []
            for fusion_idx, fusion_model in enumerate(self.fusion_models):
                fusion_scores, batch_fusion_states_candidates = fusion_model.advance(
                    states=batch_fusion_states_list[fusion_idx]
                )
                fusion_scores = fusion_scores.to(dtype=float_dtype)
                log_probs_w_fusion[:, :-1] += self.fusion_models_alpha[fusion_idx] * fusion_scores
                fusion_states_candidates_list.append(batch_fusion_states_candidates)

            # Step 4: Get most likely labels with fusion scores. Labels that are blank or repeated are ignored.
            # Note: no need to mask blank labels log_probs_w_fusion[:, -1] = NEG_INF, as argmax is without blanks
            # Note: for efficiency, use scatter instead of log_probs_w_fusion[batch_indices, last_labels] = NEG_INF
            log_probs_w_fusion.scatter_(dim=1, index=last_labels.unsqueeze(-1), value=NEG_INF)
            log_probs_w_fusion, labels_w_fusion = log_probs_w_fusion[:, :-1].max(dim=-1)

            # Step 5: Update labels if they initially weren't blank or repeated
            blank_or_repeated = (labels == self.blank_id) | (labels == last_labels)
            torch.where(blank_or_repeated, labels, labels_w_fusion, out=labels)
            torch.where(blank_or_repeated, log_probs, log_probs_w_fusion, out=log_probs_w_fusion)

            # Step 6: Update fusion states and scores for non-blank and non-repeated labels
            for fusion_idx, fusion_model in enumerate(self.fusion_models):
                torch.where(
                    blank_or_repeated,
                    batch_fusion_states_list[fusion_idx],
                    # `labels * ~blank_or_repeated` maps blank/repeated entries to index 0 so the
                    # gather stays in range; those entries are discarded by the `where` anyway.
                    fusion_states_candidates_list[fusion_idx][batch_indices, labels * ~blank_or_repeated],
                    out=batch_fusion_states_list[fusion_idx],
                )

            predictions_labels[:, i] = labels
            predictions_logprobs[:, i] = log_probs_w_fusion
            last_labels = labels

        return predictions_labels, predictions_logprobs

    @torch.no_grad()
    def _before_loop(self):
        """
        Initializes the state.
        """
        # Step 1: Initialization for fusion models
        self.state.fusion_states_list = []
        self.state.fusion_states_candidates_list = []
        for fusion_model in self.fusion_models:
            self.state.fusion_states_list.append(
                fusion_model.get_init_states(batch_size=self.state.batch_size, bos=True)
            )
            self.state.fusion_states_candidates_list.append(
                torch.zeros(
                    [self.state.batch_size, fusion_model.vocab_size], dtype=torch.long, device=self.state.device
                )
            )

        self.state.last_labels.fill_(self.blank_id)
        self.state.frame_idx.fill_(0)
        # decoding is active while at least one utterance has frames left
        self.state.active_mask.copy_((self.state.decoder_lengths > 0).any())
        # resulting labels and logprobs storage
        self.state.predictions_labels.zero_()
        self.state.predictions_logprobs.zero_()

    @torch.no_grad()
    def _inner_loop(self):
        """
        Performs a decoding step.

        Mirrors one iteration of `_greedy_decode_logprobs_batched_fusion_models_torch`, but reads and
        writes the pre-allocated tensors in `self.state` in place and indexes the frame with the
        `frame_idx` tensor.
        """
        # Step 2: Get most likely labels for current frame
        logits = self.state.decoder_outputs[:, self.state.frame_idx.unsqueeze(0)].squeeze(1)
        log_probs, labels = logits.max(dim=-1)
        log_probs_w_fusion = logits.clone()

        # Step 3: Get fusion scores
        for fusion_idx, fusion_model in enumerate(self.fusion_models):
            fusion_scores, fusion_states_candidates = fusion_model.advance(
                states=self.state.fusion_states_list[fusion_idx]
            )
            fusion_scores = fusion_scores.to(dtype=self.state.float_dtype)
            log_probs_w_fusion[:, :-1] += self.fusion_models_alpha[fusion_idx] * fusion_scores
            self.state.fusion_states_candidates_list[fusion_idx].copy_(fusion_states_candidates)

        # Step 4: Get most likely labels with fusion scores. Labels that are blank or repeated are ignored.
        # Note: no need to mask blank labels log_probs_w_fusion[:, -1] = NEG_INF, as argmax is without blanks
        # Note: for efficiency, use scatter instead of log_probs_w_fusion[batch_indices, last_labels] = NEG_INF
        log_probs_w_fusion.scatter_(dim=1, index=self.state.last_labels.unsqueeze(-1), value=NEG_INF)
        log_probs_w_fusion, labels_w_fusion = log_probs_w_fusion[:, :-1].max(dim=-1)

        # Step 5: Update labels if they initially weren't blank or repeated
        blank_or_repeated = (labels == self.blank_id) | (labels == self.state.last_labels)
        torch.where(blank_or_repeated, labels, labels_w_fusion, out=labels)
        torch.where(blank_or_repeated, log_probs, log_probs_w_fusion, out=log_probs_w_fusion)

        self.state.predictions_labels[:, self.state.frame_idx.unsqueeze(0)] = labels.unsqueeze(-1)
        self.state.predictions_logprobs[:, self.state.frame_idx.unsqueeze(0)] = log_probs_w_fusion.unsqueeze(-1)

        # Step 6: Update fusion states and scores for non-blank and non-repeated labels
        for fusion_idx, fusion_model in enumerate(self.fusion_models):
            torch.where(
                blank_or_repeated,
                self.state.fusion_states_list[fusion_idx],
                self.state.fusion_states_candidates_list[fusion_idx][
                    self.state.batch_indices, labels * ~blank_or_repeated
                ],
                out=self.state.fusion_states_list[fusion_idx],
            )

        self.state.last_labels.copy_(labels)
        self.state.frame_idx += 1
        self.state.active_mask.copy_((self.state.decoder_lengths > self.state.frame_idx).any())

    @classmethod
    def _create_while_loop_kernel(cls):
        """
        Creates a kernel that evaluates whether to enter the outer loop body (not all hypotheses are decoded).
        Condition: while(active_mask_any).
        """
        kernel_string = r"""\
        typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;

        extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);

        extern "C" __global__
        void ctc_loop_conditional(cudaGraphConditionalHandle handle, const bool *decoding_active)
        {
         cudaGraphSetConditional(handle, *decoding_active);
        }
        """
        return run_nvrtc(kernel_string, b"ctc_loop_conditional", b"while_conditional_ctc.cu")

    def _graph_reinitialize(self, logits, logits_len):
        """(Re)allocate the decoding state for the given input and compile the graphs for the current mode.

        Args:
            logits: log probabilities, indexed as [B, T, D]; shape, device and dtype size the state.
            logits_len: valid lengths; currently unused, kept for interface stability.

        Raises:
            RuntimeError: if full-graph compilation fails and fallback is disallowed (mode was forced).
            NotImplementedError: if the current graphs mode is not one of `CudaGraphsMode`.
        """
        batch_size, max_time, vocab_dim = logits.shape
        self.state = CTCDecoderCudaGraphsState(
            batch_size=batch_size,
            # NOTE(review): 375 is a lower bound on the allocated time dimension, presumably to
            # avoid re-initialization for short inputs - confirm the origin of the constant.
            max_time=max(max_time, 375),
            vocab_dim=vocab_dim,
            device=logits.device,
            float_dtype=logits.dtype,
        )
        if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH:
            try:
                self._full_graph_compile()
            except NeMoCUDAPythonException as e:
                if not self.cuda_graphs_allow_fallback:
                    raise RuntimeError("Full CUDA graph decoding failed. Mode is forced, raising exception") from e
                logging.warning(
                    f"Full CUDA graph compilation failed: {e}. "
                    "Falling back to native PyTorch CUDA graphs. Decoding will be slower."
                )
                self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS
                self._partial_graphs_compile()
        elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS:
            self._partial_graphs_compile()
        elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_GRAPHS:
            # no graphs needed
            pass
        else:
            raise NotImplementedError(f"Unknown graph mode: {self.cuda_graphs_mode}")

    @cuda_python_required
    def _full_graph_compile(self):
        """Compiling full graph: state initialization plus the decoding loop as a conditional (while) node."""
        stream_for_graph = torch.cuda.Stream(self.state.device)
        stream_for_graph.wait_stream(torch.cuda.default_stream(self.state.device))
        self.state.full_graph = torch.cuda.CUDAGraph()
        with (
            torch.cuda.stream(stream_for_graph),
            torch.inference_mode(),
            torch.cuda.graph(self.state.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
        ):
            self._before_loop()
            # NB: depending on cuda-python version, cudaStreamGetCaptureInfo can return either 5 or 6 elements
            capture_status, _, graph, *_ = cu_call(
                cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream)
            )

            assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive

            # capture: while decoding_active:
            (loop_conditional_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0))
            loop_kernel = self._create_while_loop_kernel()
            decoding_active_ptr = np.array([self.state.active_mask.data_ptr()], dtype=np.uint64)
            loop_args = np.array(
                [loop_conditional_handle.getPtr(), decoding_active_ptr.ctypes.data],
                dtype=np.uint64,
            )
            # loop while there are active utterances
            with with_conditional_node(
                loop_kernel,
                loop_args,
                loop_conditional_handle,
                device=self.state.device,
            ):
                self._inner_loop()

    def _partial_graphs_compile(self):
        """Compiling partial graphs: one graph for state initialization, one for a single decoding step."""
        stream_for_graph = torch.cuda.Stream(self.state.device)
        stream_for_graph.wait_stream(torch.cuda.default_stream(self.state.device))
        self.state.before_loop_graph = torch.cuda.CUDAGraph()
        self.state.inner_loop_graph = torch.cuda.CUDAGraph()
        with (
            torch.cuda.stream(stream_for_graph),
            torch.inference_mode(),
            torch.cuda.graph(
                self.state.before_loop_graph,
                stream=stream_for_graph,
                capture_error_mode="thread_local",
            ),
        ):
            self._before_loop()

        with (
            torch.cuda.stream(stream_for_graph),
            torch.inference_mode(),
            torch.cuda.graph(
                self.state.inner_loop_graph,
                stream=stream_for_graph,
                capture_error_mode="thread_local",
            ),
        ):
            self._inner_loop()

    def _greedy_decode_logprobs_batched_fusion_models_cuda_graphs(self, logits: torch.Tensor, out_len: torch.Tensor):
        """Fusion-model greedy decoding driven by the pre-allocated state and (optionally) CUDA graphs.

        Args:
            logits: log probabilities, indexed as [B, T, D].
            out_len: valid lengths, indexed as [B].

        Returns:
            Tuple (labels [B, T], log probabilities [B, T]), cloned out of the shared state buffers.

        Raises:
            NotImplementedError: if the current graphs mode is not one of `CudaGraphsMode`.
        """
        current_batch_size = logits.shape[0]
        current_max_time = logits.shape[1]

        if torch.is_autocast_enabled():
            logits = logits.to(torch.get_autocast_gpu_dtype())

        # init or reinit graph
        if self.state is None or self.state.need_reinit(logits):
            self._graph_reinitialize(logits=logits, logits_len=out_len)

        # copy decoder outputs and lengths
        self.state.decoder_outputs[:current_batch_size, :current_max_time, ...].copy_(logits)
        self.state.decoder_lengths[:current_batch_size].copy_(out_len)
        # set length to zero for elements outside the current batch
        self.state.decoder_lengths[current_batch_size:].fill_(0)

        if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH:
            self.state.full_graph.replay()
        elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS:
            self.state.before_loop_graph.replay()
            for _ in range(current_max_time):
                self.state.inner_loop_graph.replay()
        elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_GRAPHS:
            # this mode is only for testing purposes
            # manual loop instead of using graphs
            self._before_loop()
            for _ in range(current_max_time):
                self._inner_loop()
        else:
            raise NotImplementedError(f"Unknown graph mode: {self.cuda_graphs_mode}")

        # Clone so the results survive the next call, which overwrites the state buffers.
        return (
            self.state.predictions_labels[:current_batch_size, :current_max_time].clone(),
            self.state.predictions_logprobs[:current_batch_size, :current_max_time].clone(),
        )

    def force_cuda_graphs_mode(self, mode: Optional[Union[str, CudaGraphsMode]]):
        """
        Method to set graphs mode. Use only for testing purposes.
        For debugging the algorithm use "no_graphs" mode, since it is impossible to debug CUDA graphs directly.

        Forcing a mode also disables the fallback on full-graph compilation failure and drops the
        current decoding state.
        """
        self.cuda_graphs_mode = self.CudaGraphsMode(mode) if mode is not None else None
        self.cuda_graphs_allow_fallback = False
        self.state = None

    def maybe_enable_cuda_graphs(self):
        """Enable CUDA graphs if conditions met.

        Returns:
            False if a mode was already set (nothing changes); otherwise True when a graphs mode
            ended up selected and False when CUDA graphs are not allowed.
        """
        if self.cuda_graphs_mode is not None:
            # CUDA graphs are already enabled
            return False

        if not self.allow_cuda_graphs:
            self.cuda_graphs_mode = None
        else:
            # cuda graphs are allowed
            # check while loops
            try:
                check_cuda_python_cuda_graphs_conditional_nodes_supported()
                self.cuda_graphs_mode = self.CudaGraphsMode.FULL_GRAPH
            except (ImportError, ModuleNotFoundError, EnvironmentError) as e:
                logging.warning(
                    "No conditional node support for Cuda.\n"
                    "Cuda graphs with while loops are disabled, decoding speed will be slower\n"
                    f"Reason: {e}"
                )
                # NOTE(review): this falls back to NO_GRAPHS, whereas a failed full-graph compilation
                # in _graph_reinitialize falls back to NO_WHILE_LOOPS - confirm this is intended.
                self.cuda_graphs_mode = self.CudaGraphsMode.NO_GRAPHS
        self.reset_cuda_graphs_state()
        return self.cuda_graphs_mode is not None

    def disable_cuda_graphs(self) -> bool:
        """Disable CUDA graphs, can be used to disable graphs temporarily, e.g., in the training process.

        Returns:
            True if a graphs mode was active and has been cleared, False if there was nothing to disable.
        """
        if self.cuda_graphs_mode is None:
            # nothing to disable
            return False
        self.cuda_graphs_mode = None
        self.reset_cuda_graphs_state()
        return True

    def reset_cuda_graphs_state(self):
        """Reset state to release memory (for CUDA graphs implementations)"""
        self.state = None

    def __call__(self, *args, **kwargs):
        """Delegate calls on the instance directly to `forward`."""
        return self.forward(*args, **kwargs)


@dataclass
class GreedyCTCInferConfig:
    """Configuration of the greedy CTC decoders; field meanings match the decoder constructor arguments."""

    preserve_alignments: bool = False
    compute_timestamps: bool = False
    preserve_frame_confidence: bool = False
    confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig())

    ngram_lm_model: Optional[str] = None
    ngram_lm_alpha: float = 0.0
    boosting_tree: BoostingTreeModelConfig = field(default_factory=BoostingTreeModelConfig)
    boosting_tree_alpha: Optional[float] = 0.0
    allow_cuda_graphs: bool = True

    def __post_init__(self):
        """Normalize `confidence_method_cfg` into an OmegaConf structured config.

        A mapping is first converted to `ConfidenceMethodConfig`; an explicit `None`
        (allowed by the `Optional` annotation) is left untouched instead of failing
        on `ConfidenceMethodConfig(**None)`.
        """
        if self.confidence_method_cfg is None:
            return

        # OmegaConf.structured ensures that post_init check is always executed
        self.confidence_method_cfg = OmegaConf.structured(
            self.confidence_method_cfg
            if isinstance(self.confidence_method_cfg, ConfidenceMethodConfig)
            else ConfidenceMethodConfig(**self.confidence_method_cfg)
        )