NeMo ASR collection API#

Model Classes#

class nemo.collections.asr.models.EncDecCTCModel(*args: Any, **kwargs: Any)[source]#

Bases: nemo.collections.asr.models.asr_model.ASRModel, nemo.collections.asr.models.asr_model.ExportableEncDecModel, nemo.collections.asr.parts.mixins.mixins.ASRModuleMixin

Base class for encoder decoder CTC-based models.

change_vocabulary(new_vocabulary: List[str])[source]#

Changes the vocabulary used during the CTC decoding process. Use this method when fine-tuning a pre-trained model. It changes only the decoder and leaves the encoder and pre-processing modules unchanged. For example, you would use it if you want to use a pretrained encoder when fine-tuning on data in another language, or when you need the model to learn capitalization, punctuation and/or special characters.

If new_vocabulary == self.decoder.vocabulary then nothing will be changed.

Parameters

new_vocabulary – list with new vocabulary. Must contain at least 2 elements. Typically, this is the target alphabet.

Returns: None
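A minimal fine-tuning sketch (the checkpoint name and target alphabet below are illustrative assumptions):

import nemo.collections.asr as nemo_asr

# Load a pretrained CTC model (checkpoint name is an assumption for illustration).
model = nemo_asr.models.EncDecCTCModel.from_pretrained("QuartzNet15x5Base-En")

# Swap the decoder vocabulary for a new target alphabet; the encoder and
# preprocessor keep their pretrained weights.
model.change_vocabulary(new_vocabulary=[" ", "a", "b", "c", "'"])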

forward(input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None)[source]#

Forward pass of the model.

Parameters
  • input_signal – Tensor that represents a batch of raw audio signals, of shape [B, T]. T here represents timesteps, with 1 second of audio represented as self.sample_rate number of floating point values.

  • input_signal_length – Vector of length B, that contains the individual lengths of the audio sequences.

  • processed_signal – Tensor that represents a batch of processed audio signals, of shape (B, D, T) that has undergone processing via some DALI preprocessor.

  • processed_signal_length – Vector of length B, that contains the individual lengths of the processed audio sequences.

Returns

A tuple of 3 elements - 1) The log probabilities tensor of shape [B, T, D]. 2) The lengths of the acoustic sequence after propagation through the encoder, of shape [B]. 3) The greedy token predictions of the model of shape [B, T] (via argmax)
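For example, a forward-pass sketch with random audio (the checkpoint name is an assumption, and the model config is assumed to expose sample_rate):

import torch
import nemo.collections.asr as nemo_asr

model = nemo_asr.models.EncDecCTCModel.from_pretrained("QuartzNet15x5Base-En")
model.eval()

batch_size, seconds = 2, 4
input_signal = torch.randn(batch_size, seconds * model.cfg.sample_rate)  # [B, T]
input_signal_length = torch.tensor([input_signal.shape[1]] * batch_size)

with torch.no_grad():
    log_probs, encoded_len, greedy_predictions = model.forward(
        input_signal=input_signal, input_signal_length=input_signal_length
    )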

property input_types: Optional[Dict[str, nemo.core.neural_types.neural_type.NeuralType]]#

Define these to enable input neural type checks

classmethod list_available_models() Optional[nemo.core.classes.common.PretrainedModelInfo][source]#

This method returns a list of pre-trained models which can be instantiated directly from NVIDIA’s NGC cloud.

Returns

List of available pre-trained models.

property output_types: Optional[Dict[str, nemo.core.neural_types.neural_type.NeuralType]]#

Define these to enable output neural type checks

predict_step(batch, batch_idx, dataloader_idx=0)[source]#
setup_test_data(test_data_config: Optional[Union[omegaconf.DictConfig, Dict]])[source]#

Sets up the test data loader via a Dict-like object.

Parameters

test_data_config – A config that contains the information regarding construction of an ASR test dataset.

Supported dataset types are described in the Datasets section below.
setup_training_data(train_data_config: Optional[Union[omegaconf.DictConfig, Dict]])[source]#

Sets up the training data loader via a Dict-like object.

Parameters

train_data_config – A config that contains the information regarding construction of an ASR training dataset.

Supported dataset types are described in the Datasets section below.
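For instance, a minimal training-data config sketch, assuming model is an already-instantiated EncDecCTCModel and all values below are illustrative:

from omegaconf import OmegaConf

train_data_config = OmegaConf.create({
    "manifest_filepath": "/path/to/train_manifest.json",  # placeholder path
    "sample_rate": 16000,
    "labels": model.decoder.vocabulary,  # character labels for CTC models
    "batch_size": 32,
    "shuffle": True,
    "num_workers": 4,
})
model.setup_training_data(train_data_config=train_data_config)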
setup_validation_data(val_data_config: Optional[Union[omegaconf.DictConfig, Dict]])[source]#

Sets up the validation data loader via a Dict-like object.

Parameters

val_data_config – A config that contains the information regarding construction of an ASR validation dataset.

Supported dataset types are described in the Datasets section below.
test_dataloader()[source]#
test_step(batch, batch_idx, dataloader_idx=0)[source]#
training_step(batch, batch_nb)[source]#
transcribe(paths2audio_files: List[str], batch_size: int = 4, logprobs: bool = False, return_hypotheses: bool = False, num_workers: int = 0) List[str]#

Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.

Parameters
  • paths2audio_files – (a list) of paths to audio files. The recommended length per file is between 5 and 25 seconds, but a file several hours long can be passed if enough GPU memory is available.

  • batch_size – (int) batch size to use during inference. Larger values give better throughput but use more memory.

  • logprobs – (bool) pass True to get log probabilities instead of transcripts.

  • return_hypotheses – (bool) whether to return hypotheses instead of text. Hypotheses allow post-processing such as extracting timestamps or rescoring.

  • num_workers – (int) number of workers for the DataLoader.

Returns

A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files
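A typical debugging sketch (checkpoint name and wav paths are placeholders):

import nemo.collections.asr as nemo_asr

model = nemo_asr.models.EncDecCTCModel.from_pretrained("QuartzNet15x5Base-En")
transcripts = model.transcribe(paths2audio_files=["utt1.wav", "utt2.wav"], batch_size=2)
print(transcripts)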

validation_step(batch, batch_idx, dataloader_idx=0)[source]#
class nemo.collections.asr.models.EncDecCTCModelBPE(*args: Any, **kwargs: Any)[source]#

Bases: nemo.collections.asr.models.ctc_models.EncDecCTCModel, nemo.collections.asr.parts.mixins.mixins.ASRBPEMixin

Encoder decoder CTC-based models with Byte Pair Encoding.

change_vocabulary(new_tokenizer_dir: Union[str, omegaconf.DictConfig], new_tokenizer_type: str)[source]#

Changes the vocabulary of the tokenizer used during the CTC decoding process. Use this method when fine-tuning a pre-trained model. It changes only the decoder and leaves the encoder and pre-processing modules unchanged. For example, you would use it if you want to use a pretrained encoder when fine-tuning on data in another language, or when you need the model to learn capitalization, punctuation and/or special characters.

Parameters
  • new_tokenizer_dir – Directory path to tokenizer or a config for a new tokenizer (if the tokenizer type is agg)

  • new_tokenizer_type – Either agg, bpe or wpe. bpe is used for SentencePiece tokenizers, whereas wpe is used for BertTokenizer.

  • new_tokenizer_cfg – A config for the new tokenizer. If provided, it pre-empts the dir and type arguments.

Returns: None
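A minimal sketch, assuming the checkpoint name and new tokenizer directory below (both illustrative):

import nemo.collections.asr as nemo_asr

model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained("stt_en_conformer_ctc_small")
model.change_vocabulary(
    new_tokenizer_dir="/path/to/new_tokenizer",  # placeholder directory
    new_tokenizer_type="bpe",  # "bpe" for SentencePiece, "wpe" for BertTokenizer
)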

classmethod list_available_models() Optional[nemo.core.classes.common.PretrainedModelInfo][source]#

This method returns a list of pre-trained models which can be instantiated directly from NVIDIA’s NGC cloud.

Returns

List of available pre-trained models.

class nemo.collections.asr.models.EncDecRNNTModel(*args: Any, **kwargs: Any)[source]#

Bases: nemo.collections.asr.models.asr_model.ASRModel, nemo.collections.asr.parts.mixins.mixins.ASRModuleMixin, nemo.core.classes.exportable.Exportable

Base class for encoder decoder RNNT-based models.

change_decoding_strategy(decoding_cfg: omegaconf.DictConfig)[source]#

Changes decoding strategy used during RNNT decoding process.

Parameters

decoding_cfg – A config for the decoder, which is optional. If the decoding type needs to be changed (e.g., from greedy to beam decoding), the config can be passed here.
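For example, switching from greedy to beam decoding (the checkpoint name is illustrative, and the partial config below assumes it is merged with the model’s decoding defaults):

from omegaconf import OmegaConf
import nemo.collections.asr as nemo_asr

model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained("stt_en_contextnet_256")
decoding_cfg = OmegaConf.create({"strategy": "beam", "beam": {"beam_size": 4}})
model.change_decoding_strategy(decoding_cfg)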

change_vocabulary(new_vocabulary: List[str], decoding_cfg: Optional[omegaconf.DictConfig] = None)[source]#

Changes the vocabulary used during the RNNT decoding process. Use this method when fine-tuning a pre-trained model. It changes only the decoder and leaves the encoder and pre-processing modules unchanged. For example, you would use it if you want to use a pretrained encoder when fine-tuning on data in another language, or when you need the model to learn capitalization, punctuation and/or special characters.

Parameters
  • new_vocabulary – list with new vocabulary. Must contain at least 2 elements. Typically, this is the target alphabet.

  • decoding_cfg – A config for the decoder, which is optional. If the decoding type needs to be changed (e.g., from greedy to beam decoding), the config can be passed here.

Returns: None

property decoder_joint#
extract_rnnt_loss_cfg(cfg: Optional[omegaconf.DictConfig])[source]#

Helper method to extract the RNNT loss name, and potentially its kwargs, to be passed to the loss function.

Parameters

cfg

Should contain loss_name as a string which is resolved to a RNNT loss name. If the default should be used, then "default" can be passed. Optionally, one can pass additional kwargs to the loss function; the subdict should have a key name of the form {loss_name}_kwargs.

Note that whichever loss_name is selected, the corresponding kwargs will be selected. For the “default” case, the “{resolved_default}_kwargs” will be used.

Examples

loss_name: "default"
warprnnt_numba_kwargs:
    kwargs2: some_other_val
Returns

A tuple, the resolved loss name as well as its kwargs (if found).

forward(input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None)[source]#

Forward pass of the model. Note that for RNNT Models, the forward pass of the model is a 3 step process, and this method only performs the first step - forward of the acoustic model.

Please refer to the training_step in order to see the full forward step for training - which performs the forward of the acoustic model, the prediction network and then the joint network. Finally, it computes the loss and possibly computes the detokenized text via the decoding step.

Please refer to the validation_step in order to see the full forward step for inference - which performs the forward of the acoustic model, the prediction network and then the joint network. Finally, it computes the decoded tokens via the decoding step and possibly computes the batch metrics.

Parameters
  • input_signal – Tensor that represents a batch of raw audio signals, of shape [B, T]. T here represents timesteps, with 1 second of audio represented as self.sample_rate number of floating point values.

  • input_signal_length – Vector of length B, that contains the individual lengths of the audio sequences.

  • processed_signal – Tensor that represents a batch of processed audio signals, of shape (B, D, T) that has undergone processing via some DALI preprocessor.

  • processed_signal_length – Vector of length B, that contains the individual lengths of the processed audio sequences.

Returns

A tuple of 2 elements - 1) The log probabilities tensor of shape [B, T, D]. 2) The lengths of the acoustic sequence after propagation through the encoder, of shape [B].

property input_types: Optional[Dict[str, nemo.core.neural_types.neural_type.NeuralType]]#

Define these to enable input neural type checks

classmethod list_available_models() Optional[nemo.core.classes.common.PretrainedModelInfo][source]#

This method returns a list of pre-trained models which can be instantiated directly from NVIDIA’s NGC cloud.

Returns

List of available pre-trained models.

list_export_subnets()[source]#

Returns the default set of subnet names exported for this model. The first subnet is the one receiving the input (input_example).

multi_test_epoch_end(outputs, dataloader_idx: int = 0)[source]#

Adds support for multiple test datasets. Should be overridden by the subclass so as to obtain appropriate logs for each of the dataloaders.

Parameters
  • outputs – Same as that provided by LightningModule.validation_epoch_end() for a single dataloader.

  • dataloader_idx – int representing the index of the dataloader.

Returns

A dictionary of values, optionally containing a sub-dict log, such that the values in the log will be pre-pended by the dataloader prefix.

multi_validation_epoch_end(outputs, dataloader_idx: int = 0)[source]#

Adds support for multiple validation datasets. Should be overridden by the subclass so as to obtain appropriate logs for each of the dataloaders.

Parameters
  • outputs – Same as that provided by LightningModule.validation_epoch_end() for a single dataloader.

  • dataloader_idx – int representing the index of the dataloader.

Returns

A dictionary of values, optionally containing a sub-dict log, such that the values in the log will be pre-pended by the dataloader prefix.

on_after_backward()[source]#

Zero out the gradients if any of them are NaN or Inf.

property output_types: Optional[Dict[str, nemo.core.neural_types.neural_type.NeuralType]]#

Define these to enable output neural type checks

predict_step(batch, batch_idx, dataloader_idx=0)[source]#
setup_optim_normalization()[source]#

Helper method to set up normalization of certain parts of the model prior to the optimization step.

Supported pre-optimization normalizations are as follows:

# Variational noise injection
model:
    variational_noise:
        std: 0.0
        start_step: 0

# Joint - Length normalization
model:
    normalize_joint_txu: false

# Encoder Network - gradient normalization
model:
    normalize_encoder_norm: false

# Decoder / Prediction Network - gradient normalization
model:
    normalize_decoder_norm: false

# Joint - gradient normalization
model:
    normalize_joint_norm: false
setup_test_data(test_data_config: Optional[Union[omegaconf.DictConfig, Dict]])[source]#

Sets up the test data loader via a Dict-like object.

Parameters

test_data_config – A config that contains the information regarding construction of an ASR test dataset.

Supported dataset types are described in the Datasets section below.
setup_training_data(train_data_config: Optional[Union[omegaconf.DictConfig, Dict]])[source]#

Sets up the training data loader via a Dict-like object.

Parameters

train_data_config – A config that contains the information regarding construction of an ASR training dataset.

Supported dataset types are described in the Datasets section below.
setup_validation_data(val_data_config: Optional[Union[omegaconf.DictConfig, Dict]])[source]#

Sets up the validation data loader via a Dict-like object.

Parameters

val_data_config – A config that contains the information regarding construction of an ASR validation dataset.

Supported dataset types are described in the Datasets section below.
test_step(batch, batch_idx, dataloader_idx=0)[source]#
training_step(batch, batch_nb)[source]#
transcribe(paths2audio_files: List[str], batch_size: int = 4, return_hypotheses: bool = False, partial_hypothesis: Optional[List[Hypothesis]] = None, num_workers: int = 0)#

Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.

Parameters
  • paths2audio_files – (a list) of paths to audio files. The recommended length per file is between 5 and 25 seconds, but a file several hours long can be passed if enough GPU memory is available.

  • batch_size – (int) batch size to use during inference. Larger values give better throughput but use more memory.

  • return_hypotheses – (bool) whether to return hypotheses instead of text. Hypotheses allow post-processing such as extracting timestamps or rescoring.

  • num_workers – (int) number of workers for the DataLoader.

Returns

A list of transcriptions in the same order as paths2audio_files (hypotheses instead, if return_hypotheses is True).

validation_step(batch, batch_idx, dataloader_idx=0)[source]#
class nemo.collections.asr.models.EncDecRNNTBPEModel(*args: Any, **kwargs: Any)[source]#

Bases: nemo.collections.asr.models.rnnt_models.EncDecRNNTModel, nemo.collections.asr.parts.mixins.mixins.ASRBPEMixin

Base class for encoder decoder RNNT-based models with subword tokenization.

change_decoding_strategy(decoding_cfg: omegaconf.DictConfig)[source]#

Changes decoding strategy used during RNNT decoding process.

Parameters

decoding_cfg – A config for the decoder, which is optional. If the decoding type needs to be changed (e.g., from greedy to beam decoding), the config can be passed here.

change_vocabulary(new_tokenizer_dir: Union[str, omegaconf.DictConfig], new_tokenizer_type: str, decoding_cfg: Optional[omegaconf.DictConfig] = None)[source]#

Changes the vocabulary used during the RNNT decoding process. Use this method when fine-tuning a pre-trained model. It changes only the decoder and leaves the encoder and pre-processing modules unchanged. For example, you would use it if you want to use a pretrained encoder when fine-tuning on data in another language, or when you need the model to learn capitalization, punctuation and/or special characters.

Parameters
  • new_tokenizer_dir – Directory path to tokenizer or a config for a new tokenizer (if the tokenizer type is agg)

  • new_tokenizer_type – Type of tokenizer. Can be either agg, bpe or wpe.

  • decoding_cfg – A config for the decoder, which is optional. If the decoding type needs to be changed (e.g., from greedy to beam decoding), the config can be passed here.

Returns: None

classmethod list_available_models() List[nemo.core.classes.common.PretrainedModelInfo][source]#

This method returns a list of pre-trained models which can be instantiated directly from NVIDIA’s NGC cloud.

Returns

List of available pre-trained models.

class nemo.collections.asr.models.EncDecClassificationModel(*args: Any, **kwargs: Any)[source]#

Bases: nemo.collections.asr.models.classification_models._EncDecBaseModel

Encoder decoder Classification models.

change_labels(new_labels: List[str])[source]#

Changes the labels used by the decoder model. Use this method when fine-tuning a pre-trained model. It changes only the decoder and leaves the encoder and pre-processing modules unchanged. For example, you would use it if you want to use a pretrained encoder when fine-tuning on data from another dataset.

If new_labels == self.decoder.vocabulary then nothing will be changed.

Parameters

new_labels – list with new labels. Must contain at least 2 elements. Typically, this is the set of labels for the dataset.

Returns: None

forward(input_signal, input_signal_length)[source]#
classmethod list_available_models() Optional[List[nemo.core.classes.common.PretrainedModelInfo]][source]#

This method returns a list of pre-trained models which can be instantiated directly from NVIDIA’s NGC cloud.

Returns

List of available pre-trained models.

multi_test_epoch_end(outputs, dataloader_idx: int = 0)[source]#

Adds support for multiple test datasets. Should be overridden by the subclass so as to obtain appropriate logs for each of the dataloaders.

Parameters
  • outputs – Same as that provided by LightningModule.validation_epoch_end() for a single dataloader.

  • dataloader_idx – int representing the index of the dataloader.

Returns

A dictionary of values, optionally containing a sub-dict log, such that the values in the log will be pre-pended by the dataloader prefix.

multi_validation_epoch_end(outputs, dataloader_idx: int = 0)[source]#

Adds support for multiple validation datasets. Should be overridden by the subclass so as to obtain appropriate logs for each of the dataloaders.

Parameters
  • outputs – Same as that provided by LightningModule.validation_epoch_end() for a single dataloader.

  • dataloader_idx – int representing the index of the dataloader.

Returns

A dictionary of values, optionally containing a sub-dict log, such that the values in the log will be pre-pended by the dataloader prefix.

property output_types: Optional[Dict[str, nemo.core.neural_types.neural_type.NeuralType]]#

Define these to enable output neural type checks

test_step(batch, batch_idx, dataloader_idx=0)[source]#
training_step(batch, batch_nb)[source]#
validation_step(batch, batch_idx, dataloader_idx=0)[source]#
class nemo.collections.asr.models.EncDecSpeakerLabelModel(*args: Any, **kwargs: Any)[source]#

Bases: nemo.core.classes.modelPT.ModelPT, nemo.collections.asr.models.asr_model.ExportableEncDecModel

Encoder decoder class for speaker label models. The model class creates training and validation methods for setting up data and performing the model forward pass. Expects a config dict for:

  • preprocessor

  • Jasper/Quartznet Encoder

  • Speaker Decoder

static extract_labels(data_layer_config)[source]#
forward(input_signal, input_signal_length)[source]#
forward_for_export(processed_signal, processed_signal_len)[source]#
static get_batch_embeddings(speaker_model, manifest_filepath, batch_size=32, sample_rate=16000, device='cuda')#
get_embedding(path2audio_file)#

Returns the speaker embeddings for a provided audio file.

Parameters

path2audio_file – path to audio wav file

Returns

speaker embeddings

Return type

embs

property input_types: Optional[Dict[str, nemo.core.neural_types.neural_type.NeuralType]]#

Define these to enable input neural type checks

classmethod list_available_models() List[nemo.core.classes.common.PretrainedModelInfo][source]#

This method returns a list of pre-trained models which can be instantiated directly from NVIDIA’s NGC cloud.

Returns

List of available pre-trained models.

multi_test_epoch_end(outputs, dataloader_idx: int = 0)[source]#

Adds support for multiple test datasets. Should be overridden by the subclass so as to obtain appropriate logs for each of the dataloaders.

Parameters
  • outputs – Same as that provided by LightningModule.validation_epoch_end() for a single dataloader.

  • dataloader_idx – int representing the index of the dataloader.

Returns

A dictionary of values, optionally containing a sub-dict log, such that the values in the log will be pre-pended by the dataloader prefix.

multi_validation_epoch_end(outputs, dataloader_idx: int = 0)[source]#

Adds support for multiple validation datasets. Should be overridden by the subclass so as to obtain appropriate logs for each of the dataloaders.

Parameters
  • outputs – Same as that provided by LightningModule.validation_epoch_end() for a single dataloader.

  • dataloader_idx – int representing the index of the dataloader.

Returns

A dictionary of values, optionally containing a sub-dict log, such that the values in the log will be pre-pended by the dataloader prefix.

property output_types: Optional[Dict[str, nemo.core.neural_types.neural_type.NeuralType]]#

Define these to enable output neural type checks

setup_finetune_model(model_config: omegaconf.DictConfig)[source]#

setup_finetune_model sets up training, validation and test data with the newly provided config. It checks for the labels set up during training from scratch; if None, it sets up labels for the provided fine-tuning data from the manifest files.

Parameters
  • model_config – cfg which has train_ds, optional validation_ds, optional test_ds, and the mandatory encoder and decoder model params. Make sure you set num_classes correctly for the fine-tuning data.

Returns

None

setup_test_data(test_data_layer_params: Optional[Union[omegaconf.DictConfig, Dict]])[source]#

(Optionally) Sets up the data loader to be used in testing.

Parameters

test_data_layer_params – test data layer parameters.

Returns:

setup_training_data(train_data_layer_config: Optional[Union[omegaconf.DictConfig, Dict]])[source]#

Sets up the data loader to be used in training.

Parameters

train_data_layer_config – training data layer parameters.

Returns:

setup_validation_data(val_data_layer_config: Optional[Union[omegaconf.DictConfig, Dict]])[source]#

Sets up the data loader to be used in validation.

Parameters

val_data_layer_config – validation data layer parameters.

Returns:

test_dataloader()[source]#
test_step(batch, batch_idx, dataloader_idx: int = 0)[source]#
training_step(batch, batch_idx)[source]#
validation_step(batch, batch_idx, dataloader_idx: int = 0)[source]#
verify_speakers(path2audio_file1, path2audio_file2, threshold=0.7)#

Verify if two audio files are from the same speaker or not.

Parameters
  • path2audio_file1 – path to audio wav file of speaker 1

  • path2audio_file2 – path to audio wav file of speaker 2

  • threshold – cosine similarity score used as a threshold to distinguish two embeddings (default = 0.7)

Returns

True if both audio files are from the same speaker, False otherwise
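An end-to-end sketch of embedding extraction and pairwise verification (checkpoint name and wav paths are placeholders):

import nemo.collections.asr as nemo_asr

model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained("titanet_large")

emb = model.get_embedding("speaker1_utt1.wav")  # speaker embedding tensor
same = model.verify_speakers("speaker1_utt1.wav", "speaker2_utt1.wav", threshold=0.7)
print("same speaker" if same else "different speakers")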

Modules#

class nemo.collections.asr.modules.ConvASREncoder(*args: Any, **kwargs: Any)[source]#

Bases: nemo.core.classes.module.NeuralModule, nemo.core.classes.exportable.Exportable

Convolutional encoder for ASR models. With this class you can implement JasperNet and QuartzNet models.

Based on these papers:

https://arxiv.org/pdf/1904.03288.pdf https://arxiv.org/pdf/1910.10261.pdf

forward(audio_signal, length)[source]#
input_example(max_batch=1, max_dim=8192)[source]#

Generates input examples for tracing, etc.

Returns

A tuple of input examples.

property input_types#

Returns definitions of module input ports.

property output_types#

Returns definitions of module output ports.

update_max_sequence_length(seq_length: int, device)[source]#
class nemo.collections.asr.modules.ConvASRDecoder(*args: Any, **kwargs: Any)[source]#

Bases: nemo.core.classes.module.NeuralModule, nemo.core.classes.exportable.Exportable, nemo.core.classes.mixins.adapter_mixins.AdapterModuleMixin

Simple ASR Decoder for use with CTC-based models such as JasperNet and QuartzNet

Based on these papers:

https://arxiv.org/pdf/1904.03288.pdf https://arxiv.org/pdf/1910.10261.pdf https://arxiv.org/pdf/2005.04290.pdf

add_adapter(name: str, cfg: omegaconf.DictConfig)[source]#

Add an Adapter module to this module.

Parameters
  • name – A globally unique name for the adapter. Will be used to access, enable and disable adapters.

  • cfg – A DictConfig or Dataclass that contains at the bare minimum __target__ to instantiate a new Adapter module.

forward(encoder_output)[source]#
input_example(max_batch=1, max_dim=256)[source]#

Generates input examples for tracing, etc.

Returns

A tuple of input examples.

property input_types#

Define these to enable input neural type checks

property num_classes_with_blank#
property output_types#

Define these to enable output neural type checks

property vocabulary#
class nemo.collections.asr.modules.ConvASRDecoderClassification(*args: Any, **kwargs: Any)[source]#

Bases: nemo.core.classes.module.NeuralModule, nemo.core.classes.exportable.Exportable

Simple ASR Decoder for use with classification models such as JasperNet and QuartzNet

Based on these papers:

https://arxiv.org/pdf/2005.04290.pdf

forward(encoder_output)[source]#
input_example(max_batch=1, max_dim=256)[source]#

Generates input examples for tracing, etc.

Returns

A tuple of input examples.

property input_types#

Define these to enable input neural type checks

property num_classes#
property output_types#

Define these to enable output neural type checks

class nemo.collections.asr.modules.SpeakerDecoder(*args: Any, **kwargs: Any)[source]#

Bases: nemo.core.classes.module.NeuralModule, nemo.core.classes.exportable.Exportable

Speaker Decoder creates the final neural layers that map the outputs of the Jasper Encoder to the embedding layer, followed by a speaker-based softmax loss.

Parameters
  • feat_in (int) – Number of channels being input to this module

  • num_classes (int) – Number of unique speakers in dataset

  • emb_sizes – Shapes of intermediate embedding layers (speaker embeddings are taken from the first of these layers). Defaults to [1024, 1024].

  • pool_mode (str) – Pooling strategy type. Options are ‘xvector’ (mean and variance), ‘tap’ (temporal average pooling: just the mean) and ‘attention’ (attention-based pooling). Defaults to ‘xvector’.

  • init_mode (str) – Describes how neural network parameters are initialized. Options are [‘xavier_uniform’, ‘xavier_normal’, ‘kaiming_uniform’, ‘kaiming_normal’]. Defaults to “xavier_uniform”.

affine_layer(inp_shape, out_shape, learn_mean=True, affine_type='conv')[source]#
forward(encoder_output, length=None)[source]#
input_example(max_batch=1, max_dim=256)[source]#

Generates input examples for tracing, etc.

Returns

A tuple of input examples.

property input_types#

Define these to enable input neural type checks

property output_types#

Define these to enable output neural type checks

class nemo.collections.asr.modules.ConformerEncoder(*args: Any, **kwargs: Any)[source]#

Bases: nemo.core.classes.module.NeuralModule, nemo.core.classes.exportable.Exportable

The encoder for ASR model of Conformer. Based on this paper: ‘Conformer: Convolution-augmented Transformer for Speech Recognition’ by Anmol Gulati et al. https://arxiv.org/abs/2005.08100

Parameters
  • feat_in (int) – the size of feature channels

  • n_layers (int) – number of layers of ConformerBlock

  • d_model (int) – the hidden size of the model

  • feat_out (int) – the size of the output features. Defaults to -1 (meaning feat_out equals d_model).

  • subsampling (str) – the method of subsampling, choices=[‘vggnet’, ‘striding’] Defaults to striding.

  • subsampling_factor (int) – the subsampling factor, which should be a power of 2. Defaults to 4.

  • subsampling_conv_channels (int) – the size of the convolutions in the subsampling module Defaults to -1 which would set it to d_model.

  • ff_expansion_factor (int) – the expansion factor in feed forward layers Defaults to 4.

  • self_attention_model (str) – type of the attention layer and positional encoding ‘rel_pos’: relative positional embedding and Transformer-XL ‘abs_pos’: absolute positional embedding and Transformer default is rel_pos.

  • pos_emb_max_len (int) – the maximum length of positional embeddings. Defaults to 5000.

  • n_heads (int) – number of heads in multi-headed attention layers Defaults to 4.

  • xscaling (bool) – enables scaling the inputs to the multi-headed attention layers by sqrt(d_model) Defaults to True.

  • untie_biases (bool) – whether to not share (untie) the bias weights between layers of Transformer-XL Defaults to True.

  • conv_kernel_size (int) – the size of the convolutions in the convolutional modules Defaults to 31.

  • conv_norm_type (str) – the type of the normalization in the convolutional modules Defaults to ‘batch_norm’.

  • dropout (float) – the dropout rate used in all layers except the attention layers Defaults to 0.1.

  • dropout_emb (float) – the dropout rate used for the positional embeddings Defaults to 0.1.

  • dropout_att (float) – the dropout rate used for the attention layer Defaults to 0.0.
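A small instantiation sketch (the dimensions are illustrative; keyword names follow the parameters documented above):

import torch
from nemo.collections.asr.modules import ConformerEncoder

encoder = ConformerEncoder(
    feat_in=80,              # e.g. 80 mel-spectrogram features
    n_layers=8,
    d_model=256,
    subsampling="striding",
    subsampling_factor=4,
    n_heads=4,
    conv_kernel_size=31,
)

features = torch.randn(1, 80, 400)   # processed signal of shape [B, D, T]
lengths = torch.tensor([400])
encoded, encoded_len = encoder(audio_signal=features, length=lengths)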

enable_pad_mask(on=True)[source]#
forward(audio_signal, length=None)[source]#
forward_for_export(audio_signal, length)[source]#
input_example(max_batch=1, max_dim=256)[source]#

Generates input examples for tracing, etc.

Returns

A tuple of input examples.

property input_types#

Returns definitions of module input ports.

make_pad_mask(max_audio_length, seq_lens)[source]#

Creates the mask used for padding.

property output_types#

Returns definitions of module output ports.

set_max_audio_length(max_audio_length)[source]#

Sets maximum input length. Pre-calculates internal seq_range mask.

update_max_seq_length(seq_length: int, device)[source]#

Parts#

class nemo.collections.asr.parts.submodules.jasper.JasperBlock(*args: Any, **kwargs: Any)[source]#

Bases: torch.nn.Module, nemo.core.classes.mixins.adapter_mixins.AdapterModuleMixin, nemo.core.classes.mixins.access_mixins.AccessMixin

Constructs a single “Jasper” block. With modified parameters, also constructs other blocks for models such as QuartzNet and Citrinet.

  • For Jasper : separable flag should be False

  • For QuartzNet : separable flag should be True

  • For Citrinet : separable flag and se flag should be True

Note that above are general distinctions, each model has intricate differences that expand over multiple such blocks.

For further information about the differences between models which use JasperBlock, please review the configs for ASR models found in the ASR examples directory.

Parameters
  • inplanes – Number of input channels.

  • planes – Number of output channels.

  • repeat – Number of repeated sub-blocks (R) for this block.

  • kernel_size – Convolution kernel size across all repeated sub-blocks.

  • kernel_size_factor – Floating point scale value that is multiplied with kernel size, then rounded down to nearest odd integer to compose the kernel size. Defaults to 1.0.

  • stride – Stride of the convolutional layers.

  • dilation – Integer which defines the dilation factor of the kernel. Note that when dilation > 1, stride must be equal to 1.

  • padding – String representing type of padding. Currently only supports “same” padding, which symmetrically pads the input tensor with zeros.

  • dropout – Floating point value that determines the percentage of the output that is zeroed out.

  • activation – String representing activation functions. Valid activation functions are : {“hardtanh”: nn.Hardtanh, “relu”: nn.ReLU, “selu”: nn.SELU, “swish”: Swish}. Defaults to “relu”.

  • residual – Bool that determines whether a residual branch should be added or not. All residual branches are constructed using a pointwise convolution kernel, which may or may not perform strided convolution depending on the parameter residual_mode.

  • groups – Number of groups for Grouped Convolutions. Defaults to 1.

  • separable – Bool flag that describes whether Time-Channel depthwise separable convolution should be constructed, or ordinary convolution should be constructed.

  • heads – Number of “heads” for the masked convolution. Defaults to -1, which disables it.

  • normalization – String that represents type of normalization performed. Can be one of “batch”, “group”, “instance” or “layer” to compute BatchNorm1D, GroupNorm1D, InstanceNorm or LayerNorm (which are special cases of GroupNorm1D).

  • norm_groups – Number of groups used for GroupNorm (if normalization == “group”).

  • residual_mode – String argument which describes whether the residual branch should be simply added (“add”) or should first stride, then add (“stride_add”). Required when performing stride on parallel branch as well as utilizing residual add.

  • residual_panes – Number of residual panes, used for Jasper-DR models. Please refer to the paper.

  • conv_mask – Bool flag which determines whether to utilize masked convolutions or not. In general, it should be set to True.

  • se – Bool flag that determines whether Squeeze-and-Excitation layer should be used.

  • se_reduction_ratio – Integer value which determines to what extent the hidden dimension of the SE intermediate step should be reduced. Larger values reduce the number of parameters, but also limit the effectiveness of SE layers.

  • se_context_window – Integer value determining the number of timesteps that should be utilized in order to compute the averaged context window. Defaults to -1, which means it uses global context - such that all timesteps are averaged. If any positive integer is used, it will utilize limited context window of that size.

  • se_interpolation_mode – String used for interpolation mode of timestep dimension for SE blocks. Used only if context window is > 1. The modes available for resizing are: nearest, linear (3D-only), bilinear, area.

  • stride_last – Bool flag that determines whether all repeated blocks should stride at once, (stride of S^R when this flag is False) or just the last repeated block should stride (stride of S when this flag is True).

  • future_context

    Int value that determines how many “right” / “future” context frames will be utilized when calculating the output of the conv kernel. All calculations are done for odd kernel sizes only.

    By default, this is -1, which is recomputed as the symmetric padding case.

    When future_context >= 0, will compute the asymmetric padding as follows : (left context, right context) = [K - 1 - future_context, future_context]

    Determining an exact formula to limit future context is dependent on global layout of the model. As such, we provide both “local” and “global” guidelines below.

    Local context limit (should always be enforced) - future context should be <= half the kernel size for any given layer - future context > kernel size defaults to symmetric kernel - future context of layer = number of future frames * width of each frame (dependent on stride)

    Global context limit (should be carefully considered) - future context should be laid out in an ever-reducing pattern. Initial layers should restrict future context less than later layers, since shallow depth (and reduced stride) means each frame uses less future context. - Beyond a certain point, future context should remain static for a given stride level. This is the upper bound of the amount of future context that can be provided to the model on a global scale. - future context is calculated (roughly) as - (2 ^ stride) * (K // 2) number of future frames. This resultant value should be bound to some global maximum amount of future audio (in ms).

    Note: In the special case where K < future_context, it is assumed that the kernel is too small to limit its future context, so symmetric padding is used instead.

    Note: There is no explicit limitation on the amount of future context used, as long as K > future_context constraint is maintained. This might lead to cases where future_context is more than half the actual kernel size K! In such cases, the conv layer is utilizing more of the future context than its current and past context to compute the output. While this is possible to do, it is not recommended and the layer will raise a warning to notify the user of such cases. It is advised to simply use symmetric padding for such cases.

    Example: Say we have a model that performs 8x stride and receives spectrogram frames with stride of 0.01s. Say we wish to upper bound future context to 80 ms.

    Layer ID, Kernel Size, Stride, Future Context, Global Context
    0, K=5,  S=1, FC=8, GC = 2 * (2^0) = 2 * 0.01 ms (special case, K < FC so use symmetric pad)
    1, K=7,  S=1, FC=3, GC = 3 * (2^0) = 3 * 0.01 ms (note that symmetric pad here uses 3 FC frames!)
    2, K=11, S=2, FC=4, GC = 4 * (2^1) = 8 * 0.01 ms (note that symmetric pad here uses 5 FC frames!)
    3, K=15, S=1, FC=4, GC = 4 * (2^1) = 8 * 0.01 ms (note that symmetric pad here uses 7 FC frames!)
    4, K=21, S=2, FC=2, GC = 2 * (2^2) = 8 * 0.01 ms (note that symmetric pad here uses 10 FC frames!)
    5, K=25, S=2, FC=1, GC = 1 * (2^3) = 8 * 0.01 ms (note that symmetric pad here uses 14 FC frames!)
    6, K=29, S=1, FC=1, GC = 1 * (2^3) = 8 * 0.01 ms
    …

  • quantize – Bool flag whether to quantize the Convolutional blocks.

forward(input_: Tuple[List[torch.Tensor], Optional[torch.Tensor]])[source]#

Forward pass of the module.

Parameters

input – The input is a tuple of two values - the preprocessed audio signal as well as the lengths of the audio signal. The audio signal is padded to the shape [B, D, T] and the lengths are a torch vector of length B.

Returns

The output of the block after processing the input through repeat number of sub-blocks, as well as the lengths of the encoded audio after padding/striding.

Mixins#

class nemo.collections.asr.parts.mixins.mixins.ASRBPEMixin[source]#

Bases: abc.ABC

ASR BPE Mixin class that sets up a Tokenizer via a config

This mixin class adds the method _setup_tokenizer(…), which can be used by ASR models which depend on subword tokenization.

The setup_tokenizer method adds the following parameters to the class -
  • tokenizer_cfg: The resolved config supplied to the tokenizer (with dir and type arguments).

  • tokenizer_dir: The directory path to the tokenizer vocabulary + additional metadata.

  • tokenizer_type: The type of the tokenizer. Currently supports bpe and wpe, as well as agg.

  • vocab_path: Resolved path to the vocabulary text file.

In addition to these variables, the method will also instantiate and preserve a tokenizer (subclass of TokenizerSpec) if successful, and assign it to self.tokenizer.

The mixin also supports aggregate tokenizers, which consist of ordinary, monolingual tokenizers. If a conversion between a monolingual and an aggregate tokenizer (or vice versa) is detected, all registered artifacts will be cleaned up.

AGGREGATE_TOKENIZERS_DICT_PREFIX = 'langs'#
class nemo.collections.asr.parts.mixins.mixins.ASRModuleMixin[source]#

Bases: nemo.collections.asr.parts.mixins.asr_adapter_mixins.ASRAdapterModelMixin

ASRModuleMixin is a mixin class added to ASR models in order to add methods that are specific to a particular instantiation of a module inside of an ASRModel.

Each method should first check that the module is present within the subclass, and support additional functionality if the corresponding module is present.

change_conv_asr_se_context_window(context_window: int, update_config: bool = True)[source]#

Update the context window of the SqueezeExcitation module if the provided model contains an encoder which is an instance of ConvASREncoder.

Parameters
  • context_window

    An integer representing the number of input timeframes that will be used to compute the context. Each timeframe corresponds to a single window stride of the STFT features.

    Say the window_stride = 0.01s, then a context window of 128 represents 128 * 0.01 s of context to compute the Squeeze step.

  • update_config – Whether to update the config or not with the new context window.
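A usage sketch, assuming a Citrinet-style model whose encoder is a ConvASREncoder with SE blocks (the checkpoint name is illustrative):

import nemo.collections.asr as nemo_asr

model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained("stt_en_citrinet_256")

# 128 frames * 0.01 s window stride = 1.28 s of context for the Squeeze step.
model.change_conv_asr_se_context_window(context_window=128, update_config=True)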

Datasets#

Character Encoding Datasets#

class nemo.collections.asr.data.audio_to_text.AudioToCharDataset(*args: Any, **kwargs: Any)[source]#

Bases: nemo.collections.asr.data.audio_to_text._AudioTextDataset

Dataset that loads tensors via a json file containing paths to audio files, transcripts, and durations (in seconds). Each new line is a different sample. Example below: {"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147} … {"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt": "utterance_id", "ctm_utt": "en_4156", "side": "A"}

Parameters
  • manifest_filepath – Path to manifest json as described above. Can be comma-separated paths.
  • labels – String containing all the possible characters to map to

  • sample_rate (int) – Sample rate to resample loaded audio to

  • int_values (bool) – If true, load samples as 32-bit integers. Defaults to False.

  • augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor) – An AudioAugmentor object used to augment loaded audio

  • max_duration – If audio exceeds this length, do not include in dataset

  • min_duration – If audio is less than this length, do not include in dataset

  • max_utts – Limit number of utterances

  • blank_index – blank character index, default = -1

  • unk_index – unk_character index, default = -1

  • normalize – whether to normalize transcript text. Defaults to True.

  • bos_id – Id of beginning of sequence symbol to append if not None

  • eos_id – Id of end of sequence symbol to append if not None

  • return_sample_id (bool) – whether to return the sample_id as a part of each sample
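A sketch of writing such a manifest (paths, transcripts and durations are placeholders):

import json

samples = [
    {"audio_filepath": "/path/to/audio1.wav", "text": "the transcription", "duration": 2.4},
    {"audio_filepath": "/path/to/audio2.wav", "text": "another utterance", "duration": 3.1},
]
with open("train_manifest.json", "w") as f:
    for sample in samples:
        f.write(json.dumps(sample) + "\n")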

property output_types: Optional[Dict[str, nemo.core.neural_types.neural_type.NeuralType]]#

Returns definitions of module output ports.

class nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset(*args: Any, **kwargs: Any)[source]#

Bases: nemo.collections.asr.data.audio_to_text._TarredAudioToTextDataset

A similar Dataset to the AudioToCharDataset, but which loads tarred audio files.

Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToCharDataset), as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should contain the information for one audio file, including at least the transcript and name of the audio file within the tarball.

Valid formats for the audio_tar_filepaths argument include: (1) a single string that can be brace-expanded, e.g. ‘path/to/audio.tar’ or ‘path/to/audio_{1..100}.tar.gz’, or (2) a list of file paths that will not be brace-expanded, e.g. [‘audio_1.tar’, ‘audio_2.tar’, …].

See the WebDataset documentation for more information about accepted data and input formats.

If using multiple workers the number of shards should be divisible by world_size to ensure an even split among workers. If it is not divisible, logging will give a warning but training will proceed. In addition, if using multiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering is applied. We currently do not check for this, but your program may hang if the shards are uneven!

Notice that a few arguments are different from the AudioToCharDataset; for example, shuffle (bool) has been replaced by shuffle_n (int).

Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest after filtering. An incorrect manifest length may lead to some DataLoader issues down the line.

Parameters
  • audio_tar_filepaths – Either a list of audio tarball filepaths, or a string (can be brace-expandable).

  • manifest_filepath (str) – Path to the manifest.

  • labels (list) – List of characters that can be output by the ASR model. For Jasper, this is the 28 character set {a-z ‘}. The CTC blank symbol is automatically added later for models using ctc.

  • sample_rate (int) – Sample rate to resample loaded audio to

  • int_values (bool) – If true, load samples as 32-bit integers. Defaults to False.

  • augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor) – An AudioAugmentor object used to augment loaded audio

  • shuffle_n (int) – How many samples to look ahead and load to be shuffled. See WebDataset documentation for more details. Defaults to 0.

  • min_duration (float) – Dataset parameter. All training files which have a duration less than min_duration are dropped. Note: Duration is read from the manifest JSON. Defaults to 0.1.

  • max_duration (float) – Dataset parameter. All training files which have a duration more than max_duration are dropped. Note: Duration is read from the manifest JSON. Defaults to None.

  • max_utts (int) – Limit number of utterances. 0 means no maximum.

  • blank_index (int) – Blank character index, defaults to -1.

  • unk_index (int) – Unknown character index, defaults to -1.

  • normalize (bool) – Dataset parameter. Whether to use automatic text cleaning. It is highly recommended to manually clean text for best results. Defaults to True.

  • trim (bool) – Whether to trim silence from the beginning and end of the audio signal using librosa.effects.trim(). Defaults to False.

  • bos_id (id) – Dataset parameter. Beginning of string symbol id used for seq2seq models. Defaults to None.

  • eos_id (id) – Dataset parameter. End of string symbol id used for seq2seq models. Defaults to None.

  • pad_id (id) – Token used to pad when collating samples in batches. If this is None, pads using 0s. Defaults to None.

  • shard_strategy (str) –

    Tarred dataset shard distribution strategy chosen as a str value during ddp.

    • scatter: The default shard strategy applied by WebDataset, where each node gets a unique set of shards, which are permanently pre-allocated and never changed at runtime.

    • replicate: Optional shard strategy, where each node gets all of the set of shards available in the tarred dataset, which are permanently pre-allocated and never changed at runtime. The benefit of replication is that it allows each node to sample data points from the entire dataset independently of other nodes, and reduces dependence on the value of shuffle_n.

      Note: The replicated strategy allows every node to sample the entire set of available tarfiles, and therefore more than one node may sample the same tarfile, and even sample the same data points! As such, there is no assured guarantee that all samples in the dataset will be sampled at least once during 1 epoch.

  • global_rank (int) – Worker rank, used for partitioning shards. Defaults to 0.

  • world_size (int) – Total number of processes, used for partitioning shards. Defaults to 0.

  • return_sample_id (bool) – whether to return the sample_id as a part of each sample

Subword Encoding Datasets#

class nemo.collections.asr.data.audio_to_text.AudioToBPEDataset(*args: Any, **kwargs: Any)[source]#

Bases: nemo.collections.asr.data.audio_to_text._AudioTextDataset

Dataset that loads tensors via a json file containing paths to audio files, transcripts, and durations (in seconds). Each new line is a different sample. Example below: {“audio_filepath”: “/path/to/audio.wav”, “text_filepath”: “/path/to/audio.txt”, “duration”: 23.147} … {“audio_filepath”: “/path/to/audio.wav”, “text”: “the transcription”, “offset”: 301.75, “duration”: 0.82, “utt”: “utterance_id”, “ctm_utt”: “en_4156”, “side”: “A”}

In practice, the dataset and manifest used for character encoding and byte pair encoding are exactly the same. The only difference lies in how the dataset tokenizes the text in the manifest.

Parameters
  • manifest_filepath – Path to manifest json as described above. Can be comma-separated paths.

  • tokenizer – A subclass of the Tokenizer wrapper found in the common collection, nemo.collections.common.tokenizers.TokenizerSpec. ASR Models support a subset of all available tokenizers.

  • sample_rate (int) – Sample rate to resample loaded audio to

  • int_values (bool) – If true, load samples as 32-bit integers. Defaults to False.

  • augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor) – An AudioAugmentor object used to augment loaded audio

  • max_duration – If audio exceeds this length, do not include in dataset

  • min_duration – If audio is less than this length, do not include in dataset

  • max_utts – Limit number of utterances

  • trim – Whether to trim silence segments

  • use_start_end_token – Boolean which dictates whether to add [BOS] and [EOS] tokens to beginning and ending of speech respectively.

  • return_sample_id (bool) – whether to return the sample_id as a part of each sample

property output_types: Optional[Dict[str, nemo.core.neural_types.neural_type.NeuralType]]#

Returns definitions of module output ports.

class nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset(*args: Any, **kwargs: Any)[source]#

Bases: nemo.collections.asr.data.audio_to_text._TarredAudioToTextDataset

A similar Dataset to the AudioToBPEDataset, but which loads tarred audio files.

Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToBPEDataset), as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should contain the information for one audio file, including at least the transcript and name of the audio file within the tarball.

Valid formats for the audio_tar_filepaths argument include: (1) a single string that can be brace-expanded, e.g. ‘path/to/audio.tar’ or ‘path/to/audio_{1..100}.tar.gz’, or (2) a list of file paths that will not be brace-expanded, e.g. [‘audio_1.tar’, ‘audio_2.tar’, …].

See the WebDataset documentation for more information about accepted data and input formats.

If using multiple workers the number of shards should be divisible by world_size to ensure an even split among workers. If it is not divisible, logging will give a warning but training will proceed. In addition, if using multiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering is applied. We currently do not check for this, but your program may hang if the shards are uneven!

Notice that a few arguments are different from the AudioToBPEDataset; for example, shuffle (bool) has been replaced by shuffle_n (int).

Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest after filtering. An incorrect manifest length may lead to some DataLoader issues down the line.

Parameters
  • audio_tar_filepaths – Either a list of audio tarball filepaths, or a string (can be brace-expandable).

  • manifest_filepath (str) – Path to the manifest.

  • tokenizer (TokenizerSpec) – Either a Word Piece Encoding tokenizer (BERT), or a Sentence Piece Encoding tokenizer (BPE). The CTC blank symbol is automatically added later for models using ctc.

  • sample_rate (int) – Sample rate to resample loaded audio to

  • int_values (bool) – If true, load samples as 32-bit integers. Defaults to False.

  • augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor) – An AudioAugmentor object used to augment loaded audio

  • shuffle_n (int) – How many samples to look ahead and load to be shuffled. See WebDataset documentation for more details. Defaults to 0.

  • min_duration (float) – Dataset parameter. All training files which have a duration less than min_duration are dropped. Note: Duration is read from the manifest JSON. Defaults to 0.1.

  • max_duration (float) – Dataset parameter. All training files which have a duration more than max_duration are dropped. Note: Duration is read from the manifest JSON. Defaults to None.

  • max_utts (int) – Limit number of utterances. 0 means no maximum.

  • trim (bool) – Whether to trim silence from the beginning and end of the audio signal using librosa.effects.trim(). Defaults to False.

  • use_start_end_token – Boolean which dictates whether to add [BOS] and [EOS] tokens to beginning and ending of speech respectively.

  • pad_id (id) – Token used to pad when collating samples in batches. If this is None, pads using 0s. Defaults to None.

  • shard_strategy (str) –

    Tarred dataset shard distribution strategy chosen as a str value during ddp.

    • scatter: The default shard strategy applied by WebDataset, where each node gets a unique set of shards, which are permanently pre-allocated and never changed at runtime.

    • replicate: Optional shard strategy, where each node gets all of the set of shards available in the tarred dataset, which are permanently pre-allocated and never changed at runtime. The benefit of replication is that it allows each node to sample data points from the entire dataset independently of other nodes, and reduces dependence on the value of shuffle_n.

      Note: The replicated strategy allows every node to sample the entire set of available tarfiles, and therefore more than one node may sample the same tarfile, and even sample the same data points! As such, there is no assured guarantee that all samples in the dataset will be sampled at least once during 1 epoch.

  • global_rank (int) – Worker rank, used for partitioning shards. Defaults to 0.

  • world_size (int) – Total number of processes, used for partitioning shards. Defaults to 0.

  • return_sample_id (bool) – whether to return the sample_id as a part of each sample

Audio Preprocessors#

class nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor(*args: Any, **kwargs: Any)[source]#

Bases: nemo.collections.asr.modules.audio_preprocessing.AudioPreprocessor

Featurizer module that converts wavs to mel spectrograms. We don’t use torchaudio’s implementation here because the original implementation is not the same, so for the sake of backwards-compatibility this will use the old FilterbankFeatures for now.

Parameters
  • sample_rate – Sample rate of the input audio data. Defaults to 16000.
  • window_size (float) – Size of window for fft in seconds Defaults to 0.02

  • window_stride (float) – Stride of window for fft in seconds Defaults to 0.01

  • n_window_size (int) – Size of window for fft in samples Defaults to None. Use one of window_size or n_window_size.

  • n_window_stride (int) – Stride of window for fft in samples Defaults to None. Use one of window_stride or n_window_stride.

  • window (str) – Windowing function for fft. can be one of [‘hann’, ‘hamming’, ‘blackman’, ‘bartlett’] Defaults to “hann”

  • normalize (str) – Can be one of [‘per_feature’, ‘all_features’]; all other options disable feature normalization. ‘all_features’ normalizes the entire spectrogram to be mean 0 with std 1. ‘per_feature’ normalizes per channel / freq instead. Defaults to “per_feature”

  • n_fft (int) – Length of FT window. If None, it uses the smallest power of 2 that is larger than n_window_size. Defaults to None

  • preemph (float) – Amount of pre emphasis to add to audio. Can be disabled by passing None. Defaults to 0.97

  • features (int) – Number of mel spectrogram freq bins to output. Defaults to 64

  • lowfreq (int) – Lower bound on mel basis in Hz. Defaults to 0

  • highfreq (int) – Upper bound on mel basis in Hz. Defaults to None

  • log (bool) – Log features. Defaults to True

  • log_zero_guard_type (str) – Need to avoid taking the log of zero. There are two options: “add” or “clamp”. Defaults to “add”.

  • log_zero_guard_value (float, or str) – Add or clamp requires the number to add with or clamp to. log_zero_guard_value can either be a float or “tiny” or “eps”. torch.finfo is used if “tiny” or “eps” is passed. Defaults to 2**-24.

  • dither (float) – Amount of white-noise dithering. Defaults to 1e-5

  • pad_to (int) – Ensures that the output size of the time dimension is a multiple of pad_to. Defaults to 16

  • frame_splicing (int) – Defaults to 1

  • exact_pad (bool) – If True, sets stft center to False and adds padding, such that num_frames = audio_length // hop_length. Defaults to False.

  • pad_value (float) – The value that shorter mels are padded with. Defaults to 0

  • mag_power (float) – The power that the linear spectrogram is raised to prior to multiplication with mel basis. Defaults to 2 for a power spec

  • rng – Random number generator

  • nb_augmentation_prob (float) – Probability with which narrowband augmentation would be applied to samples in the batch. Defaults to 0.0

  • nb_max_freq (int) – Frequency above which all frequencies will be masked for narrowband augmentation. Defaults to 4000

  • stft_exact_pad – Deprecated argument, kept for compatibility with older checkpoints.

  • stft_conv – Deprecated argument, kept for compatibility with older checkpoints.
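A standalone usage sketch of the featurizer (all values are illustrative):

import torch
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor

preprocessor = AudioToMelSpectrogramPreprocessor(
    sample_rate=16000, window_size=0.02, window_stride=0.01, features=64
)

waveform = torch.randn(1, 16000)   # one second of audio, shape [B, T]
length = torch.tensor([16000])
processed_signal, processed_length = preprocessor(
    input_signal=waveform, length=length
)
# processed_signal has shape [B, features, T_frames]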

property filter_banks#
get_features(input_signal, length)[source]#
property input_types#

Returns definitions of module input ports.

property output_types#

Returns definitions of module output ports.

processed_signal:

0: AxisType(BatchTag)
1: AxisType(MelSpectrogramSignalTag)
2: AxisType(ProcessedTimeTag)

processed_length:

0: AxisType(BatchTag)

classmethod restore_from(restore_path: str)[source]#

Restores model instance (weights and configuration) from a .nemo file

Parameters
  • restore_path – path to .nemo file from which model should be instantiated

  • override_config_path – path to a yaml config that will override the internal config file or an OmegaConf / DictConfig object representing the model config.

  • map_location – Optional torch.device() to map the instantiated model to a device. By default (None), it will select a GPU if available, falling back to CPU otherwise.

  • strict – Passed to load_state_dict. By default True

  • return_config – If set to true, will return just the underlying config of the restored model as an OmegaConf DictConfig object without instantiating the model.

  • trainer – An optional Trainer object, passed to the model constructor.

  • save_restore_connector – An optional SaveRestoreConnector object that defines the implementation of the restore_from() method.

save_to(save_path: str)[source]#

Standardized method to save a tarfile containing the checkpoint, config, and any additional artifacts. Implemented via nemo.core.connectors.save_restore_connector.SaveRestoreConnector.save_to().

Parameters

save_path – str, path to where the file should be saved.
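
For example, save_to and restore_from round-trip the module through a .nemo archive; a short sketch continuing the preprocessor example above (the file name is illustrative):

    # Serialize, then rebuild the module (weights and configuration) from the archive.
    preprocessor.save_to("mel_preprocessor.nemo")
    restored = AudioToMelSpectrogramPreprocessor.restore_from("mel_preprocessor.nemo")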

class nemo.collections.asr.modules.AudioToMFCCPreprocessor(*args: Any, **kwargs: Any)[source]#

Bases: nemo.collections.asr.modules.audio_preprocessing.AudioPreprocessor

Preprocessor that converts wavs to MFCCs. Uses torchaudio.transforms.MFCC.

Parameters
  • sample_rate – The sample rate of the audio. Defaults to 16000.

  • window_size – Size of window for fft in seconds. Used to calculate the win_length arg for mel spectrogram. Defaults to 0.02

  • window_stride – Stride of window for fft in seconds. Used to calculate the hop_length arg for mel spectrogram. Defaults to 0.01

  • n_window_size – Size of window for fft in samples. Defaults to None. Use one of window_size or n_window_size.

  • n_window_stride – Stride of window for fft in samples. Defaults to None. Use one of window_stride or n_window_stride.

  • window – Windowing function for fft. Can be one of [‘hann’, ‘hamming’, ‘blackman’, ‘bartlett’, ‘none’, ‘null’]. Defaults to ‘hann’

  • n_fft – Length of FT window. If None, it uses the smallest power of 2 that is larger than n_window_size. Defaults to None

  • lowfreq (int) – Lower bound on mel basis in Hz. Defaults to 0

  • highfreq (int) – Upper bound on mel basis in Hz. Defaults to None

  • n_mels – Number of mel filterbanks. Defaults to 64

  • n_mfcc – Number of coefficients to retain. Defaults to 64

  • dct_type – Type of discrete cosine transform to use

  • norm – Type of norm to use

  • log – Whether to use log-mel spectrograms instead of db-scaled. Defaults to True.
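
Usage mirrors the mel preprocessor sketch above; a minimal example assuming 16 kHz input:

    import torch

    from nemo.collections.asr.modules import AudioToMFCCPreprocessor

    mfcc = AudioToMFCCPreprocessor(sample_rate=16000, n_mels=64, n_mfcc=40)

    audio = torch.randn(1, 16000)  # one second of random audio
    feats, feat_len = mfcc(input_signal=audio, length=torch.tensor([16000]))
    # feats: [B, n_mfcc, T]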

get_features(input_signal, length)[source]#
property input_types#

Returns definitions of module input ports.

property output_types#

Returns definitions of module output ports.

classmethod restore_from(restore_path: str)[source]#

Restores model instance (weights and configuration) from a .nemo file

Parameters
  • restore_path – path to .nemo file from which model should be instantiated

  • override_config_path – path to a yaml config that will override the internal config file or an OmegaConf / DictConfig object representing the model config.

  • map_location – Optional torch.device() to map the instantiated model to a device. By default (None), it will select a GPU if available, falling back to CPU otherwise.

  • strict – Passed to load_state_dict. By default True

  • return_config – If set to true, will return just the underlying config of the restored model as an OmegaConf DictConfig object without instantiating the model.

  • trainer – An optional Trainer object, passed to the model constructor.

  • save_restore_connector – An optional SaveRestoreConnector object that defines the implementation of the restore_from() method.

save_to(save_path: str)[source]#

Standardized method to save a tarfile containing the checkpoint, config, and any additional artifacts. Implemented via nemo.core.connectors.save_restore_connector.SaveRestoreConnector.save_to().

Parameters

save_path – str, path to where the file should be saved.

Audio Augmentors#

class nemo.collections.asr.modules.SpectrogramAugmentation(*args: Any, **kwargs: Any)[source]#

Bases: nemo.core.classes.module.NeuralModule

Performs time and frequency cuts in one of two ways. SpecAugment zeroes out vertical and horizontal sections as described in SpecAugment (https://arxiv.org/abs/1904.08779). Arguments for use with SpecAugment are freq_masks, time_masks, freq_width, and time_width. SpecCutout zeroes out rectangular regions as described in Cutout (https://arxiv.org/abs/1708.04552). Arguments for use with SpecCutout are rect_masks, rect_freq, and rect_time.

Parameters
  • freq_masks (int) – how many frequency segments should be cut. Defaults to 0.

  • time_masks (int) – how many time segments should be cut. Defaults to 0.

  • freq_width (int) – maximum number of frequencies to be cut in one segment. Defaults to 10.

  • time_width (int) – maximum number of time steps to be cut in one segment. Defaults to 10.

  • rect_masks (int) – how many rectangular masks should be cut. Defaults to 0.

  • rect_freq (int) – maximum size of cut rectangles along the frequency dimension. Defaults to 5.

  • rect_time (int) – maximum size of cut rectangles along the time dimension. Defaults to 25.
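
A short sketch applying SpecAugment-style masking to a spectrogram batch (the shapes below are illustrative):

    import torch

    from nemo.collections.asr.modules import SpectrogramAugmentation

    spec_augment = SpectrogramAugmentation(
        freq_masks=2, time_masks=2, freq_width=15, time_width=25,
    )

    spec = torch.randn(4, 80, 120)               # [B, D, T] spectrogram batch
    lengths = torch.tensor([120, 100, 90, 80])   # valid frames per utterance
    augmented = spec_augment(input_spec=spec, length=lengths)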

forward(input_spec, length)[source]#
property input_types#

Returns definitions of module input types

property output_types#

Returns definitions of module output types

class nemo.collections.asr.modules.CropOrPadSpectrogramAugmentation(*args: Any, **kwargs: Any)[source]#

Bases: nemo.core.classes.module.NeuralModule

Pad or crop the incoming spectrogram to a certain shape.

Parameters

audio_length – the final number of timesteps required. The signal will be either padded or cropped temporally to this size.
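
A construction sketch; note that the forward argument names below are an assumption based on this module’s input ports, since they are not spelled out above:

    import torch

    from nemo.collections.asr.modules import CropOrPadSpectrogramAugmentation

    crop_or_pad = CropOrPadSpectrogramAugmentation(audio_length=128)

    spec = torch.randn(2, 80, 200)
    lengths = torch.tensor([200, 150])
    # Assumed port names; the time axis is cropped (or padded) to exactly 128 frames.
    cropped, new_len = crop_or_pad(input_signal=spec, length=lengths)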

forward#
property input_types#

Returns definitions of module input ports.

property output_types#

Returns definitions of module output ports.

classmethod restore_from(restore_path: str)[source]#

Restores model instance (weights and configuration) from a .nemo file

Parameters
  • restore_path – path to .nemo file from which model should be instantiated

  • override_config_path – path to a yaml config that will override the internal config file or an OmegaConf / DictConfig object representing the model config.

  • map_location – Optional torch.device() to map the instantiated model to a device. By default (None), it will select a GPU if available, falling back to CPU otherwise.

  • strict – Passed to load_state_dict. By default True

  • return_config – If set to true, will return just the underlying config of the restored model as an OmegaConf DictConfig object without instantiating the model.

  • trainer – An optional Trainer object, passed to the model constructor.

  • save_restore_connector – An optional SaveRestoreConnector object that defines the implementation of the restore_from() method.

save_to(save_path: str)[source]#

Standardized method to save a tarfile containing the checkpoint, config, and any additional artifacts. Implemented via nemo.core.connectors.save_restore_connector.SaveRestoreConnector.save_to().

Parameters

save_path – str, path to where the file should be saved.

class nemo.collections.asr.parts.preprocessing.perturb.SpeedPerturbation(sr, resample_type, min_speed_rate=0.9, max_speed_rate=1.1, num_rates=5, rng=None)[source]#

Bases: nemo.collections.asr.parts.preprocessing.perturb.Perturbation

Performs Speed Augmentation by re-sampling the data to a different sampling rate, which does not preserve pitch.

Note: This is a very slow operation for online augmentation. If space allows, it is preferable to pre-compute and save the files to augment the dataset.

Parameters
  • sr – Original sampling rate.

  • resample_type – Type of resampling operation that will be performed. For better speed using resampy’s fast resampling method, use resample_type=’kaiser_fast’. For high-quality resampling, set resample_type=’kaiser_best’. To use scipy.signal.resample, set resample_type=’fft’ or resample_type=’scipy’

  • min_speed_rate – Minimum sampling rate modifier.

  • max_speed_rate – Maximum sampling rate modifier.

  • num_rates – Number of discrete rates to allow. Can be a positive or negative integer. If a positive integer greater than 0 is provided, the range of speed rates will be discretized into num_rates values. If a negative integer or 0 is provided, the full range of speed rates will be sampled uniformly. Note: If a positive integer is provided and the resultant discretized range of rates contains the value ‘1.0’, then those samples with rate=1.0 will not be augmented at all and simply skipped. This is to avoid unnecessary augmentation and increased computation time. Effective augmentation chance in such a case is prob * ((num_rates - 1) / num_rates) * 100%, where prob is the global probability of a sample being augmented.

  • rng – Random seed number.
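
A hedged usage sketch; the AudioSegment loader and the file path are illustrative assumptions, the point being that perturbations operate on audio segment objects in place:

    from nemo.collections.asr.parts.preprocessing.perturb import SpeedPerturbation
    from nemo.collections.asr.parts.preprocessing.segment import AudioSegment

    perturber = SpeedPerturbation(
        sr=16000,
        resample_type="kaiser_fast",
        min_speed_rate=0.9,
        max_speed_rate=1.1,
        num_rates=5,
    )

    # 'utt.wav' is a placeholder; perturb() resamples the segment in place.
    audio = AudioSegment.from_file("utt.wav", target_sr=16000)
    perturber.perturb(audio)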

max_augmentation_length(length)[source]#
perturb(data)[source]#
class nemo.collections.asr.parts.preprocessing.perturb.TimeStretchPerturbation(min_speed_rate=0.9, max_speed_rate=1.1, num_rates=5, n_fft=512, rng=None)[source]#

Bases: nemo.collections.asr.parts.preprocessing.perturb.Perturbation

Time-stretch an audio series by a fixed rate while preserving pitch, based on [1, 2].

Note: This is a simplified implementation, intended primarily for reference and pedagogical purposes. It makes no attempt to handle transients, and is likely to produce audible artifacts.

Reference [1] [Ellis, D. P. W. “A phase vocoder in Matlab.” Columbia University, 2002.] (http://www.ee.columbia.edu/~dpwe/resources/matlab/pvoc/) [2] [librosa.effects.time_stretch] (https://librosa.github.io/librosa/generated/librosa.effects.time_stretch.html)

Parameters
  • min_speed_rate – Minimum sampling rate modifier.

  • max_speed_rate – Maximum sampling rate modifier.

  • num_rates – Number of discrete rates to allow. Can be a positive or negative integer. If a positive integer greater than 0 is provided, the range of speed rates will be discretized into num_rates values. If a negative integer or 0 is provided, the full range of speed rates will be sampled uniformly. Note: If a positive integer is provided and the resultant discretized range of rates contains the value ‘1.0’, then those samples with rate=1.0 will not be augmented at all and simply skipped. This is to avoid unnecessary augmentation and increased computation time. Effective augmentation chance in such a case is prob * ((num_rates - 1) / num_rates) * 100%, where prob is the global probability of a sample being augmented.

  • n_fft – Number of fft filters to be computed.

  • rng – Random seed number.

max_augmentation_length(length)[source]#
perturb(data)[source]#
class nemo.collections.asr.parts.preprocessing.perturb.GainPerturbation(min_gain_dbfs=-10, max_gain_dbfs=10, rng=None)[source]#

Bases: nemo.collections.asr.parts.preprocessing.perturb.Perturbation

Applies random gain to the audio.

Parameters
  • min_gain_dbfs (float) – Min gain level in dBFS

  • max_gain_dbfs (float) – Max gain level in dBFS

  • rng – Random number generator

perturb(data)[source]#
class nemo.collections.asr.parts.preprocessing.perturb.ImpulsePerturbation(manifest_path=None, rng=None, audio_tar_filepaths=None, shuffle_n=128, shift_impulse=False)[source]#

Bases: nemo.collections.asr.parts.preprocessing.perturb.Perturbation

Convolves audio with a Room Impulse Response.

Parameters
  • manifest_path (list) – Manifest file for RIRs

  • audio_tar_filepaths (list) – Tar files, if RIR audio files are tarred

  • shuffle_n (int) – Shuffle parameter for shuffling buffered files from the tar files

  • shift_impulse (bool) – Shift impulse response to adjust for delay at the beginning

perturb(data)[source]#
class nemo.collections.asr.parts.preprocessing.perturb.ShiftPerturbation(min_shift_ms=-5.0, max_shift_ms=5.0, rng=None)[source]#

Bases: nemo.collections.asr.parts.preprocessing.perturb.Perturbation

Perturbs audio by shifting the audio in time by a random amount between min_shift_ms and max_shift_ms. The final length of the audio is kept unaltered by padding the audio with zeros.

Parameters
  • min_shift_ms (float) – Minimum time in milliseconds by which audio will be shifted

  • max_shift_ms (float) – Maximum time in milliseconds by which audio will be shifted

  • rng – Random number generator

perturb(data)[source]#
class nemo.collections.asr.parts.preprocessing.perturb.NoisePerturbation(manifest_path=None, min_snr_db=10, max_snr_db=50, max_gain_db=300.0, rng=None, audio_tar_filepaths=None, shuffle_n=100, orig_sr=16000)[source]#

Bases: nemo.collections.asr.parts.preprocessing.perturb.Perturbation

Perturbation that adds noise to input audio.

Parameters
  • manifest_path (str) – Manifest file with paths to noise files

  • min_snr_db (float) – Minimum SNR of audio after noise is added

  • max_snr_db (float) – Maximum SNR of audio after noise is added

  • max_gain_db (float) – Maximum gain that can be applied on the noise sample

  • audio_tar_filepaths (list) – Tar files, if noise audio files are tarred

  • shuffle_n (int) – Shuffle parameter for shuffling buffered files from the tar files

  • orig_sr (int) – Original sampling rate of the noise files

  • rng – Random number generator
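
A construction sketch; the manifest path below is a placeholder whose entries should point to noise audio files:

    from nemo.collections.asr.parts.preprocessing.perturb import NoisePerturbation

    noise = NoisePerturbation(
        manifest_path="noise_manifest.json",  # placeholder manifest of noise files
        min_snr_db=10,
        max_snr_db=50,
        orig_sr=16000,
    )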

get_one_noise_sample(target_sr)[source]#
property orig_sr#
perturb(data)[source]#
perturb_with_foreground_noise(data, noise, data_rms=None, max_noise_dur=2, max_additions=1)[source]#
perturb_with_input_noise(data, noise, data_rms=None)[source]#
class nemo.collections.asr.parts.preprocessing.perturb.WhiteNoisePerturbation(min_level=-90, max_level=-46, rng=None)[source]#

Bases: nemo.collections.asr.parts.preprocessing.perturb.Perturbation

Perturbation that adds white noise to an audio file in the training dataset.

Parameters
  • min_level (int) – Minimum level in dB at which white noise should be added

  • max_level (int) – Maximum level in dB at which white noise should be added

  • rng – Random number generator

perturb(data)[source]#
class nemo.collections.asr.parts.preprocessing.perturb.RirAndNoisePerturbation(rir_manifest_path=None, rir_prob=0.5, noise_manifest_paths=None, min_snr_db=0, max_snr_db=50, rir_tar_filepaths=None, rir_shuffle_n=100, noise_tar_filepaths=None, apply_noise_rir=False, orig_sample_rate=None, max_additions=5, max_duration=2.0, bg_noise_manifest_paths=None, bg_min_snr_db=10, bg_max_snr_db=50, bg_noise_tar_filepaths=None, bg_orig_sample_rate=None)[source]#

Bases: nemo.collections.asr.parts.preprocessing.perturb.Perturbation

RIR augmentation with additive foreground and background noise. In this implementation audio data is augmented by first convolving the audio with a Room Impulse Response and then adding foreground noise and background noise at various SNRs. RIR, foreground and background noises should either be supplied with a manifest file or as tarred audio files (faster).

Different sets of noise audio files can be supplied based on the original sampling rate of the noise. This is useful while training a mixed sample rate model. For example, when training a mixed model with 8 kHz and 16 kHz audio with a target sampling rate of 16 kHz, one would want to augment 8 kHz data with 8 kHz noise rather than 16 kHz noise.

Parameters
  • rir_manifest_path – Manifest file for RIRs

  • rir_tar_filepaths – Tar files, if RIR audio files are tarred

  • rir_prob – Probability of applying a RIR

  • noise_manifest_paths – Foreground noise manifest path

  • min_snr_db – Min SNR for foreground noise

  • max_snr_db – Max SNR for foreground noise

  • noise_tar_filepaths – Tar files, if noise files are tarred

  • apply_noise_rir – Whether to convolve foreground noise with a random RIR

  • orig_sample_rate – Original sampling rate of foreground noise audio

  • max_additions – Max number of times foreground noise is added to an utterance

  • max_duration – Max duration of foreground noise

  • bg_noise_manifest_paths – Background noise manifest path

  • bg_min_snr_db – Min SNR for background noise

  • bg_max_snr_db – Max SNR for background noise

  • bg_noise_tar_filepaths – Tar files, if noise files are tarred

  • bg_orig_sample_rate – Original sampling rate of background noise audio

perturb(data)[source]#
class nemo.collections.asr.parts.preprocessing.perturb.TranscodePerturbation(codecs=None, rng=None)[source]#

Bases: nemo.collections.asr.parts.preprocessing.perturb.Perturbation

Audio codec augmentation. This implementation uses sox to transcode audio with low rate audio codecs, so users need to make sure that the installed sox version supports the codecs used here (G711 and amr-nb).

Parameters

rng – Random number generator

perturb(data)[source]#
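
In practice, several perturbations are composed with per-sample application probabilities. A hedged sketch using AudioAugmentor from the same module; its (probability, perturbation) pairing is an assumption here, as that class is not documented in this section:

    from nemo.collections.asr.parts.preprocessing.perturb import (
        AudioAugmentor,
        GainPerturbation,
        WhiteNoisePerturbation,
    )

    # Each entry pairs an application probability with a perturbation instance.
    augmentor = AudioAugmentor(
        perturbations=[
            (0.5, GainPerturbation(min_gain_dbfs=-10, max_gain_dbfs=10)),
            (0.3, WhiteNoisePerturbation(min_level=-90, max_level=-46)),
        ]
    )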

Miscellaneous Classes#

RNNT Decoding#

class nemo.collections.asr.metrics.rnnt_wer.RNNTDecoding(decoding_cfg, decoder, joint, vocabulary)[source]#

Bases: nemo.collections.asr.metrics.rnnt_wer.AbstractRNNTDecoding

Used for performing RNN-T auto-regressive decoding of the Decoder+Joint network given the encoder state.

Parameters
  • decoding_cfg

    A dict-like object which contains the following key-value pairs.

    strategy: str value which represents the type of decoding that can occur. Possible values are: greedy, greedy_batch (for greedy decoding); beam, tsd, alsd (for beam search decoding).

    compute_hypothesis_token_set: A bool flag which determines whether to compute a list of decoded tokens as well as the decoded string. Default is False in order to avoid double decoding unless required.

    preserve_alignments: Bool flag which preserves the history of logprobs generated during decoding (sample / batched). When set to true, the Hypothesis will contain the non-null value for logprobs in it. Here, logprobs is a List of torch.Tensors. In order to obtain this hypothesis, please utilize the rnnt_decoder_predictions_tensor function with the return_hypotheses flag set to True. The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary. U is the number of target tokens for the current timestep Ti.

    The config may further contain the following sub-dictionaries:

    “greedy”:

    max_symbols: int, describing the maximum number of target tokens to decode per timestep during greedy decoding. Setting to larger values allows longer sentences to be decoded, at the cost of increased execution time.

    ”beam”:

    beam_size: int, defining the beam size for beam search. Must be >= 1. If beam_size == 1, a cached greedy search will be performed, which might give slightly different results compared to the greedy search above.

    score_norm: optional bool, whether to normalize the returned beam score in the hypotheses. Set to True by default.

    return_best_hypothesis: optional bool, whether to return just the best hypothesis or all of the hypotheses after beam search has concluded. This flag is set by default.

    tsd_max_sym_exp: optional int, determines the number of symmetric expansions of the target symbols per timestep of the acoustic model. Larger values will allow longer sentences to be decoded, at increased cost to execution time.

    alsd_max_target_len: optional int or float, determines the potential maximum target sequence length. If an integer is provided, it can decode sequences of that particular maximum length. If a float is provided, it can decode sequences of int(alsd_max_target_len * seq_len), where seq_len is the length of the acoustic model output (T). NOTE: If a float is provided, it can be greater than 1! By default, a float of 2.0 is used so that a target sequence can be at most twice as long as the acoustic model output length T.

    maes_num_steps: Number of adaptive steps to take. From the paper, 2 steps is generally sufficient, and can be reduced to 1 to improve decoding speed while sacrificing some accuracy. int > 0.

    maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and it is advised to keep this as 1 in order to reduce expensive beam search cost later. int >= 0.

    maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size. Effectively, the number of hypotheses = beam_size + maes_expansion_beta. Must be an int >= 0, and affects the speed of inference since large values will perform a large beam search in the next step.

    maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions. The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and max_log_prob is the “most” likely token to be predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for expansion apart from the “most likely” candidate. Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value, thereby reducing speed but potentially improving accuracy). This is a hyperparameter to be experimentally tuned on a validation set.

    softmax_temperature: Scales the logits of the joint prior to computing log_softmax.

  • decoder – The Decoder/Prediction network module.

  • joint – The Joint network module.

  • vocabulary – The vocabulary (excluding the RNNT blank token) which will be used for decoding.
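
A minimal construction sketch: decoding_cfg is typically an OmegaConf object, and model below stands in for an already-instantiated RNN-T model exposing decoder and joint modules (those attribute names are illustrative assumptions):

    from omegaconf import OmegaConf

    from nemo.collections.asr.metrics.rnnt_wer import RNNTDecoding

    decoding_cfg = OmegaConf.create(
        {
            "strategy": "beam",
            "beam": {"beam_size": 4, "score_norm": True, "return_best_hypothesis": True},
        }
    )

    decoding = RNNTDecoding(
        decoding_cfg=decoding_cfg,
        decoder=model.decoder,              # assumed pre-built Decoder/Prediction network
        joint=model.joint,                  # assumed pre-built Joint network
        vocabulary=model.joint.vocabulary,  # vocabulary excluding the blank token
    )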

decode_ids_to_tokens(tokens: List[int]) List[str][source]#

Implemented by subclass in order to decode a token id list into a token list. A token list is the string representation of each token id.

Parameters

tokens – List of int representing the token ids.

Returns

A list of decoded tokens.

decode_tokens_to_str(tokens: List[int]) str[source]#

Implemented by subclass in order to decode a token list into a string.

Parameters

tokens – List of int representing the token ids.

Returns

A decoded string.

class nemo.collections.asr.metrics.rnnt_wer_bpe.RNNTBPEDecoding(decoding_cfg, decoder, joint, tokenizer: nemo.collections.common.tokenizers.tokenizer_spec.TokenizerSpec)[source]#

Bases: nemo.collections.asr.metrics.rnnt_wer.AbstractRNNTDecoding

Used for performing RNN-T auto-regressive decoding of the Decoder+Joint network given the encoder state.

Parameters
  • decoding_cfg

    A dict-like object which contains the following key-value pairs.

    strategy: str value which represents the type of decoding that can occur. Possible values are: greedy, greedy_batch (for greedy decoding); beam, tsd, alsd (for beam search decoding).

    compute_hypothesis_token_set: A bool flag which determines whether to compute a list of decoded tokens as well as the decoded string. Default is False in order to avoid double decoding unless required.

    preserve_alignments: Bool flag which preserves the history of logprobs generated during decoding (sample / batched). When set to true, the Hypothesis will contain the non-null value for logprobs in it. Here, logprobs is a List of torch.Tensors. In order to obtain this hypothesis, please utilize the rnnt_decoder_predictions_tensor function with the return_hypotheses flag set to True. The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary. U is the number of target tokens for the current timestep Ti.

    The config may further contain the following sub-dictionaries:

    “greedy”:

    max_symbols: int, describing the maximum number of target tokens to decode per timestep during greedy decoding. Setting to larger values allows longer sentences to be decoded, at the cost of increased execution time.

    ”beam”:

    beam_size: int, defining the beam size for beam search. Must be >= 1. If beam_size == 1, a cached greedy search will be performed, which might give slightly different results compared to the greedy search above.

    score_norm: optional bool, whether to normalize the returned beam score in the hypotheses. Set to True by default.

    return_best_hypothesis: optional bool, whether to return just the best hypothesis or all of the hypotheses after beam search has concluded.

    tsd_max_sym_exp: optional int, determines the number of symmetric expansions of the target symbols per timestep of the acoustic model. Larger values will allow longer sentences to be decoded, at increased cost to execution time.

    alsd_max_target_len: optional int or float, determines the potential maximum target sequence length. If an integer is provided, it can decode sequences of that particular maximum length. If a float is provided, it can decode sequences of int(alsd_max_target_len * seq_len), where seq_len is the length of the acoustic model output (T). NOTE: If a float is provided, it can be greater than 1! By default, a float of 2.0 is used so that a target sequence can be at most twice as long as the acoustic model output length T.

    maes_num_steps: Number of adaptive steps to take. From the paper, 2 steps is generally sufficient, and can be reduced to 1 to improve decoding speed while sacrificing some accuracy. int > 0.

    maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and it is advised to keep this as 1 in order to reduce expensive beam search cost later. int >= 0.

    maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size. Effectively, the number of hypotheses = beam_size + maes_expansion_beta. Must be an int >= 0, and affects the speed of inference since large values will perform a large beam search in the next step.

    maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions. The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and max_log_prob is the “most” likely token to be predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for expansion apart from the “most likely” candidate. Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value, thereby reducing speed but potentially improving accuracy). This is a hyperparameter to be experimentally tuned on a validation set.

    softmax_temperature: Scales the logits of the joint prior to computing log_softmax.

  • decoder – The Decoder/Prediction network module.

  • joint – The Joint network module.

  • tokenizer – The tokenizer which will be used for decoding.

decode_ids_to_tokens(tokens: List[int]) List[str][source]#

Implemented by subclass in order to decode a token id list into a token list. A token list is the string representation of each token id.

Parameters

tokens – List of int representing the token ids.

Returns

A list of decoded tokens.

decode_tokens_to_str(tokens: List[int]) str[source]#

Implemented by subclass in order to decode a token list into a string.

Parameters

tokens – List of int representing the token ids.

Returns

A decoded string.

class nemo.collections.asr.parts.submodules.rnnt_greedy_decoding.GreedyRNNTInfer(decoder_model: nemo.collections.asr.modules.rnnt_abstract.AbstractRNNTDecoder, joint_model: nemo.collections.asr.modules.rnnt_abstract.AbstractRNNTJoint, blank_index: int, max_symbols_per_step: Optional[int] = None, preserve_alignments: bool = False)[source]#

Bases: nemo.collections.asr.parts.submodules.rnnt_greedy_decoding._GreedyRNNTInfer

A greedy transducer decoder.

Sequence level greedy decoding, performed auto-regressively.

Parameters
  • decoder_model – rnnt_utils.AbstractRNNTDecoder implementation.

  • joint_model – rnnt_utils.AbstractRNNTJoint implementation.

  • blank_index – int index of the blank token. Can be 0 or len(vocabulary).

  • max_symbols_per_step – Optional int. The maximum number of symbols that can be added to a sequence in a single time step; if set to None then there is no limit.

  • preserve_alignments

    Bool flag which preserves the history of alignments generated during greedy decoding (sample / batched). When set to true, the Hypothesis will contain the non-null value for alignments in it. Here, alignments is a List of List of ints.

    The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary. U is the number of target tokens for the current timestep Ti.

forward(encoder_output: torch.Tensor, encoded_lengths: torch.Tensor, partial_hypotheses: Optional[List[nemo.collections.asr.parts.utils.rnnt_utils.Hypothesis]] = None)[source]#

Returns a list of hypotheses given an input batch of the encoder hidden embedding. Output tokens are generated auto-regressively.

Parameters
  • encoder_output – A tensor of size (batch, features, timesteps).

  • encoded_lengths – list of int representing the length of each output sequence.

Returns

packed list containing batch number of sentences (Hypotheses).
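
A hedged sketch of standalone use; model, enc, and enc_len stand in for a loaded RNN-T model and its encoder outputs, and blank_idx is an assumed attribute name for the blank token index:

    from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyRNNTInfer

    greedy = GreedyRNNTInfer(
        decoder_model=model.decoder,
        joint_model=model.joint,
        blank_index=model.decoder.blank_idx,  # assumed attribute name
        max_symbols_per_step=10,
    )

    # enc: [B, D, T] encoder output; enc_len: [B] valid encoded lengths.
    hypotheses = greedy(encoder_output=enc, encoded_lengths=enc_len)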

class nemo.collections.asr.parts.submodules.rnnt_greedy_decoding.GreedyBatchedRNNTInfer(decoder_model: nemo.collections.asr.modules.rnnt_abstract.AbstractRNNTDecoder, joint_model: nemo.collections.asr.modules.rnnt_abstract.AbstractRNNTJoint, blank_index: int, max_symbols_per_step: Optional[int] = None, preserve_alignments: bool = False)[source]#

Bases: nemo.collections.asr.parts.submodules.rnnt_greedy_decoding._GreedyRNNTInfer

A batch level greedy transducer decoder.

Batch level greedy decoding, performed auto-regressively.

Parameters
  • decoder_model – rnnt_utils.AbstractRNNTDecoder implementation.

  • joint_model – rnnt_utils.AbstractRNNTJoint implementation.

  • blank_index – int index of the blank token. Can be 0 or len(vocabulary).

  • max_symbols_per_step – Optional int. The maximum number of symbols that can be added to a sequence in a single time step; if set to None then there is no limit.

  • preserve_alignments

    Bool flag which preserves the history of alignments generated during greedy decoding (sample / batched). When set to true, the Hypothesis will contain the non-null value for alignments in it. Here, alignments is a List of List of ints.

    The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary. U is the number of target tokens for the current timestep Ti.

forward(encoder_output: torch.Tensor, encoded_lengths: torch.Tensor, partial_hypotheses: Optional[List[nemo.collections.asr.parts.utils.rnnt_utils.Hypothesis]] = None)[source]#

Returns a list of hypotheses given an input batch of the encoder hidden embedding. Output tokens are generated auto-regressively.

Parameters
  • encoder_output – A tensor of size (batch, features, timesteps).

  • encoded_lengths – list of int representing the length of each output sequence.

Returns

packed list containing batch number of sentences (Hypotheses).

class nemo.collections.asr.parts.submodules.rnnt_beam_decoding.BeamRNNTInfer(decoder_model: nemo.collections.asr.modules.rnnt_abstract.AbstractRNNTDecoder, joint_model: nemo.collections.asr.modules.rnnt_abstract.AbstractRNNTJoint, beam_size: int, search_type: str = 'default', score_norm: bool = True, return_best_hypothesis: bool = True, tsd_max_sym_exp_per_step: Optional[int] = 50, alsd_max_target_len: Union[int, float] = 1.0, nsc_max_timesteps_expansion: int = 1, nsc_prefix_alpha: int = 1, maes_num_steps: int = 2, maes_prefix_alpha: int = 1, maes_expansion_gamma: float = 2.3, maes_expansion_beta: int = 2, language_model: Optional[Dict[str, Any]] = None, softmax_temperature: float = 1.0, preserve_alignments: bool = False)[source]#

Bases: nemo.core.classes.common.Typing

Beam search implementation ported from the ESPnet implementation - https://github.com/espnet/espnet/blob/master/espnet/nets/beam_search_transducer.py

Sequence level beam decoding or batched-beam decoding, performed auto-regressively depending on the search type chosen.

Parameters
  • decoder_model – rnnt_utils.AbstractRNNTDecoder implementation.

  • joint_model – rnnt_utils.AbstractRNNTJoint implementation.

  • beam_size

    number of beams for beam search. Must be a positive integer >= 1. If beam size is 1, defaults to stateful greedy search. This greedy search might result in slightly different results than the greedy results obtained by GreedyRNNTInfer due to implementation differences.

    For accurate greedy results, please use GreedyRNNTInfer or GreedyBatchedRNNTInfer.

  • search_type

    str representing the type of beam search to perform. Must be one of [‘beam’, ‘tsd’, ‘alsd’]. ‘nsc’ is currently not supported. (The remaining arguments below are specific to the chosen search type.)

    Algorithm used:

    beam - basic beam search strategy. Larger beams generally result in better decoding, however the time required for the search also grows steadily.

    tsd - time synchronous decoding. Please refer to the paper Alignment-Length Synchronous Decoding for RNN Transducer (https://ieeexplore.ieee.org/document/9053040) for details on the algorithm implemented. Time synchronous decoding (TSD) execution time grows by the factor T * max_symmetric_expansions. For longer sequences, T is greater, and it can therefore take a long time for beams to obtain good results. This also requires greater memory to execute.

    alsd - alignment-length synchronous decoding. Please refer to the paper Alignment-Length Synchronous Decoding for RNN Transducer (https://ieeexplore.ieee.org/document/9053040) for details on the algorithm implemented. Alignment-length synchronous decoding (ALSD) execution time is faster than TSD, with a growth factor of T + U_max, where U_max is the maximum target length expected during execution. Generally, T + U_max < T * max_symmetric_expansions. However, ALSD beams are non-unique, therefore it is required to use larger beam sizes to achieve the same (or close to the same) decoding accuracy as TSD. For a given decoding accuracy, it is possible to attain faster decoding via ALSD than TSD.

    maes - modified adaptive expansion search. Please refer to the paper Accelerating RNN Transducer Inference via Adaptive Expansion Search (https://ieeexplore.ieee.org/document/9250505). Modified adaptive expansion search (mAES) execution time is adaptive w.r.t. the number of expansions (for tokens) required per timestep. The number of expansions can usually be constrained to 1 or 2, and in most cases 2 is sufficient. This beam search technique can possibly obtain superior WER while sacrificing some evaluation time.

  • score_norm – bool, whether to normalize the scores of the log probabilities.

  • return_best_hypothesis – bool, decides whether to return a single hypothesis (the best out of N), or return all N hypotheses (sorted with best score first). The container class changes based on this flag - When set to True (default), returns a single Hypothesis. When set to False, returns a NBestHypotheses container, which contains a list of Hypothesis.

  • tsd_max_sym_exp_per_step – Used for search_type=tsd. The maximum symmetric expansions allowed per timestep during beam search. Larger values should be used to attempt decoding of longer sequences, but this in turn increases execution time and memory usage.

  • alsd_max_target_len – Used for search_type=alsd. The maximum expected target sequence length during beam search. Larger values allow decoding of longer sequences at the expense of execution time and memory.

  • nsc_max_timesteps_expansion – Unused int. (Placeholder until the nsc implementation is stabilized.)

  • nsc_prefix_alpha – Unused int. (Placeholder until the nsc implementation is stabilized.)

  • maes_num_steps – Number of adaptive steps to take. From the paper, 2 steps is generally sufficient. int > 1.

  • maes_prefix_alpha – Maximum prefix length in prefix search. Must be an integer, and is advised to keep this as 1 in order to reduce expensive beam search cost later. int >= 0.

  • maes_expansion_beta – Maximum number of prefix expansions allowed, in addition to the beam size. Effectively, the number of hypotheses = beam_size + maes_expansion_beta. Must be an int >= 0, and affects the speed of inference since large values will perform a large beam search in the next step.

  • maes_expansion_gamma – Float pruning threshold used in the prune-by-value step when computing the expansions. The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and max_log_prob is the “most” likely token to be predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for expansion apart from the “most likely” candidate. Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value, thereby reducing speed but potentially improving accuracy). This is a hyperparameter to be experimentally tuned on a validation set.

  • softmax_temperature – Scales the logits of the joint prior to computing log_softmax.

  • preserve_alignments

    Bool flag which preserves the history of alignments generated during beam decoding (sample). When set to true, the Hypothesis will contain the non-null value for alignments in it. Here, alignments is a List of List of ints.

    The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary. U is the number of target tokens for the current timestep Ti.

    NOTE: preserve_alignments is an invalid argument for any search_type other than basic beam search.
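
A construction sketch under the same assumptions as the greedy example above (model stands in for a loaded RNN-T model):

    from nemo.collections.asr.parts.submodules.rnnt_beam_decoding import BeamRNNTInfer

    beam = BeamRNNTInfer(
        decoder_model=model.decoder,
        joint_model=model.joint,
        beam_size=4,
        search_type="tsd",            # one of the search types described above
        tsd_max_sym_exp_per_step=50,
        score_norm=True,
        return_best_hypothesis=True,
    )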

align_length_sync_decoding(h: torch.Tensor, encoded_lengths: torch.Tensor, partial_hypotheses: Optional[nemo.collections.asr.parts.utils.rnnt_utils.Hypothesis] = None) List[nemo.collections.asr.parts.utils.rnnt_utils.Hypothesis][source]#

Alignment-length synchronous beam search implementation. Based on https://ieeexplore.ieee.org/document/9053040

Parameters

h – Encoded speech features (1, T_max, D_enc)

Returns

N-best decoding results

Return type

nbest_hyps

default_beam_search#

Beam search implementation.

Parameters

x – Encoded speech features (1, T_max, D_enc)

Returns

N-best decoding results

Return type

nbest_hyps

greedy_search#

Greedy search implementation for transducer. Generic case when beam size = 1. Results might differ slightly due to implementation details as compared to GreedyRNNTInfer and GreedyBatchedRNNTInfer.

Parameters

h – Encoded speech features (1, T_max, D_enc)

Returns

1-best decoding results

Return type

hyp

property input_types#

Returns definitions of module input ports.

modified_adaptive_expansion_search#

Modified adaptive expansion search (mAES) implementation. Based on/modified from https://ieeexplore.ieee.org/document/9250505

Parameters

h – Encoded speech features (1, T_max, D_enc)

Returns

N-best decoding results

Return type

nbest_hyps

property output_types#

Returns definitions of module output ports.

prefix_search#

Prefix search for NSC and mAES strategies. Based on https://arxiv.org/pdf/1211.3711.pdf

recombine_hypotheses(hypotheses: List[nemo.collections.asr.parts.utils.rnnt_utils.Hypothesis]) List[nemo.collections.asr.parts.utils.rnnt_utils.Hypothesis][source]#

Recombine hypotheses with equivalent output sequence.

Parameters

hypotheses (list) – list of hypotheses

Returns

list of recombined hypotheses

Return type

final (list)

sort_nbest(hyps: List[nemo.collections.asr.parts.utils.rnnt_utils.Hypothesis]) List[nemo.collections.asr.parts.utils.rnnt_utils.Hypothesis][source]#

Sort hypotheses by score or score given sequence length.

Parameters

hyps – list of hypotheses

Returns

sorted list of hypotheses

Return type

hyps

time_sync_decoding(h: torch.Tensor, encoded_lengths: torch.Tensor, partial_hypotheses: Optional[nemo.collections.asr.parts.utils.rnnt_utils.Hypothesis] = None) List[nemo.collections.asr.parts.utils.rnnt_utils.Hypothesis][source]#

Time synchronous beam search implementation. Based on https://ieeexplore.ieee.org/document/9053040

Parameters

h – Encoded speech features (1, T_max, D_enc)

Returns

N-best decoding results

Return type

nbest_hyps

Hypotheses#

class nemo.collections.asr.parts.utils.rnnt_utils.Hypothesis(score: float, y_sequence: typing.Union[typing.List[int], torch.Tensor], text: typing.Optional[str] = None, dec_out: typing.Optional[typing.List[torch.Tensor]] = None, dec_state: typing.Optional[typing.Union[typing.List[typing.List[torch.Tensor]], typing.List[torch.Tensor]]] = None, timestep: typing.Union[typing.List[int], torch.Tensor] = <factory>, alignments: typing.Optional[typing.Union[typing.List[int], typing.List[typing.List[int]]]] = None, length: typing.Union[int, torch.Tensor] = 0, y: typing.Optional[typing.List[torch.tensor]] = None, lm_state: typing.Optional[typing.Union[typing.Dict[str, typing.Any], typing.List[typing.Any]]] = None, lm_scores: typing.Optional[torch.Tensor] = None, tokens: typing.Optional[typing.Union[typing.List[int], torch.Tensor]] = None, last_token: typing.Optional[torch.Tensor] = None)[source]#

Bases: object

Hypothesis class for beam search algorithms.

score: A float score obtained from an AbstractRNNTDecoder module’s score_hypothesis method.

y_sequence: Either a sequence of integer ids pointing to some vocabulary, or a packed torch.Tensor behaving in the same manner. dtype must be torch.Long in the latter case.

dec_state: A list (or list of list) of LSTM-RNN decoder states. Can be None.

text: (Optional) A decoded string after processing via CTC / RNN-T decoding (removing the CTC / RNNT blank tokens, and optionally merging word-pieces). Should be used as the decoded string for Word Error Rate calculation.

timestep: (Optional) A list of integer indices representing at which index in the decoding process the token appeared. Should be of the same length as the number of non-blank tokens.

alignments: (Optional) Represents the CTC / RNNT token alignments as integer tokens along an axis of time T (for CTC) or Time x Target (TxU) (for RNNT). For CTC, represented as a single list of integer indices. For RNNT, represented as a dangling list of lists of integer indices. The outer list represents the Time dimension (T), the inner list represents the Target dimension (U). The set of valid indices includes the CTC / RNNT blank token in order to represent alignments.

length: Represents the length of the sequence (the original length without padding); otherwise defaults to 0.

y: (Unused) A list of torch.Tensors representing the list of hypotheses.

lm_state: (Unused) A dictionary state cache used by an external Language Model.

lm_scores: (Unused) Score of the external Language Model.

tokens: (Optional) A list of decoded tokens (can be characters or word-pieces).

last_token: (Optional) A token or batch of tokens which was predicted in the last step.

class nemo.collections.asr.parts.utils.rnnt_utils.NBestHypotheses(n_best_hypotheses: Optional[List[nemo.collections.asr.parts.utils.rnnt_utils.Hypothesis]])[source]#

Bases: object

List of N best hypotheses
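
When beam search is configured with return_best_hypothesis=False, it yields this container; a short access sketch (nbest stands in for a returned instance):

    # Hypotheses are sorted with the best score first (see sort_nbest above).
    best = nbest.n_best_hypotheses[0]
    print(best.score, best.y_sequence, best.text)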