Checkpoints

There are two main ways to load pretrained checkpoints in NeMo:

  • Using the restore_from() method to load a local checkpoint file (.nemo), or

  • Using the from_pretrained() method to download and set up a checkpoint from NGC.

Refer to the following sections for instructions and examples for each.

Note that these instructions are for loading fully trained checkpoints for evaluation or fine-tuning. To resume an unfinished training experiment, use the Experiment Manager instead, by setting the resume_if_exists flag to True.
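For reference, a minimal sketch of enabling resumption through the Experiment Manager is shown below. The trainer settings, experiment directory, and run name are illustrative placeholders, and the available config keys may vary between NeMo versions.

    import pytorch_lightning as pl
    from omegaconf import OmegaConf

    from nemo.utils.exp_manager import exp_manager

    trainer = pl.Trainer(max_epochs=100, accelerator="gpu", devices=1)

    exp_cfg = OmegaConf.create({
        "exp_dir": "experiments",             # hypothetical experiment directory
        "name": "my_asr_run",                 # hypothetical run name
        "resume_if_exists": True,             # resume from the latest checkpoint if one exists
        "resume_ignore_no_checkpoint": True,  # do not fail if no checkpoint has been written yet
    })
    exp_manager(trainer, exp_cfg)

    # trainer.fit(model) would then continue training from the last saved checkpoint.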

NeMo automatically saves checkpoints of a trained model in the .nemo format. Alternatively, to save the model manually at any point, call model.save_to(<checkpoint_path>.nemo).
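For example, assuming model is any instantiated NeMo model (the file name is a placeholder):

    # Save the current model (weights, config, and artifacts) into a single .nemo file.
    model.save_to("my_model.nemo")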

If there is a local .nemo checkpoint that you’d like to load, use the restore_from() method:

    import nemo.collections.asr as nemo_asr

    model = nemo_asr.models.<MODEL_BASE_CLASS>.restore_from(restore_path="<path/to/checkpoint/file.nemo>")

Here, the model base class is the ASR model class of the original checkpoint, or the generic ASRModel class.
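For example, a CTC checkpoint can be restored either with its specific class or with the generic base class; the paths below are placeholders.

    import nemo.collections.asr as nemo_asr

    # Restore with the checkpoint's specific class ...
    ctc_model = nemo_asr.models.EncDecCTCModel.restore_from(restore_path="/path/to/ctc_model.nemo")

    # ... or with the generic base class when the exact class is not known in advance.
    asr_model = nemo_asr.models.ASRModel.restore_from(restore_path="/path/to/ctc_model.nemo")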

The hybrid ASR-TTS model is a transparent wrapper around an ASR model, a text-to-mel-spectrogram generator, and an optional enhancer. The model is saved as a single .nemo checkpoint containing all of these parts. Because the wrapper is transparent, the ASR model can be extracted separately after training or fine-tuning, either through the asr_model attribute (a NeMo submodel) with hybrid_model.asr_model.save_to(<asr_checkpoint_path>.nemo), or through the convenience wrapper hybrid_model.save_asr_model_to(<asr_checkpoint_path>.nemo).
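A minimal sketch of this extraction is shown below; the wrapper class name ASRWithTTSModel and the file names are assumptions for illustration and may differ in your NeMo version.

    from nemo.collections.asr.models import ASRWithTTSModel

    hybrid_model = ASRWithTTSModel.restore_from("hybrid_asr_tts.nemo")

    # Extract the ASR part via the submodel attribute ...
    hybrid_model.asr_model.save_to("asr_only.nemo")
    # ... or via the convenience wrapper.
    hybrid_model.save_asr_model_to("asr_only.nemo")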

The ASR collection has checkpoints of several models trained on various datasets for a variety of tasks. These checkpoints are available via the NGC NeMo Automatic Speech Recognition collection. The model cards on NGC contain more information about each available checkpoint.

The tables below list the ASR models available from NGC. The models can be accessed via the from_pretrained() method inside the ASR Model class. In general, you can load any of these models with code in the following format:

    import nemo.collections.asr as nemo_asr

    model = nemo_asr.models.ASRModel.from_pretrained(model_name="<MODEL_NAME>")

Here, the model name is the value in the "Model Name" column of the tables below.

For example, to load the base English QuartzNet model for speech recognition, run:

    model = nemo_asr.models.ASRModel.from_pretrained(model_name="QuartzNet15x5Base-En")

You can also call from_pretrained() from the specific model class (such as EncDecCTCModel for QuartzNet) if you need to access a specific model functionality.
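For example, the QuartzNet checkpoint above can also be loaded through its specific class:

    import nemo.collections.asr as nemo_asr

    model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name="QuartzNet15x5Base-En")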

If you would like to programmatically list the models available for a particular base class, you can use the list_available_models() method.

    nemo_asr.models.<MODEL_BASE_CLASS>.list_available_models()

Transcribing/Inference

To perform inference and transcribe a sample of speech after loading the model, use the transcribe() method:

    model.transcribe(audio=[list of audio files], batch_size=BATCH_SIZE)

audio can be a string path to a file, a list of string paths to multiple files, a NumPy array or PyTorch tensor containing an audio signal loaded via soundfile or a similar library, or a list of such arrays/tensors. This expanded input support makes it easier to integrate NeMo into existing pipelines.


You can run inference on a NumPy array representing an audio signal as follows. Note that it is your responsibility to convert the audio to mono-channel, 16 kHz sample rate before passing it to the model (a resampling sketch follows the example below).

    import os

    import soundfile as sf
    import torch

    from nemo.collections.asr.models import ASRModel

    model = ASRModel.from_pretrained("<model_name>")
    model.eval()

    # Load audio files (test_data_dir points to the directory containing the sample data)
    audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav")
    audio, sr = sf.read(audio_file, dtype='float32')

    audio_file_2 = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an104-mrcb-b.wav")
    audio_2, sr = sf.read(audio_file_2, dtype='float32')

    # Mix one numpy array audio segment with a torch audio tensor
    audio_2 = torch.from_numpy(audio_2)

    # Numpy array + torch tensor mixed input (for batched inference)
    outputs = model.transcribe([audio, audio_2], batch_size=2)
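If your audio is not already mono-channel 16 kHz, resample and downmix it before transcription. A minimal sketch using librosa (an assumed dependency; any resampling library will do, and the file name is a placeholder):

    import librosa

    # librosa.load resamples to the requested rate and downmixes to mono in one call.
    audio, sr = librosa.load("my_audio.wav", sr=16000, mono=True)
    outputs = model.transcribe([audio], batch_size=1)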


To obtain alignments (previously exposed as logprobs) from CTC or RNNT models, use the following code:

    hyps = model.transcribe(audio=[list of audio files], batch_size=BATCH_SIZE, return_hypotheses=True)
    logprobs = hyps[0].alignments  # or hyps[0][0].alignments for RNNT


Often, we want to transcribe a large number of files at once (for example, from a manifest). In this case, calling transcribe() directly may be impractical because it does not return any results until every sample in the input has been processed. One workaround is to call transcribe() multiple times, each time on a small subset of the data. This workflow is now supported directly via transcribe_generator().

    import nemo.collections.asr as nemo_asr

    model = nemo_asr.models.ASRModel.from_pretrained("<model_name>")
    config = model.get_transcribe_config()
    config.batch_size = 32

    generator = model.transcribe_generator(audio, config)
    for processed_outputs in generator:
        # process a batch of 32 results (or fewer if the last batch does not contain 32 elements)
        ...

For more information, see nemo.collections.asr.modules. For details on the general transcription API, refer to TranscriptionMixin. The audio files should be 16 kHz mono-channel WAV files.


Inference with Multi-task Models

Multi-task models that use structured prompts require additional task tokens as input, so it is recommended to use a manifest as input. Below is an example using the nvidia/canary-1b model:

    from nemo.collections.asr.models import EncDecMultiTaskModel

    # load model
    canary_model = EncDecMultiTaskModel.from_pretrained('nvidia/canary-1b')

    # update decode params
    decode_cfg = canary_model.cfg.decoding
    decode_cfg.beam.beam_size = 1
    canary_model.change_decoding_strategy(decode_cfg)

    # run transcription
    predicted_text = canary_model.transcribe(
        "<path to input manifest file>",
        batch_size=16,  # batch size to run the inference with
    )

Here, the manifest file should be a JSON-lines file where each line has the following format (the comments below are explanatory and should not appear in the actual file):

    {
        "audio_filepath": "/path/to/audio.wav",  # path to the audio file
        "duration": None,      # duration of the audio in seconds; set to `None` to use the full audio
        "taskname": "asr",     # use "ast" for speech-to-text translation
        "source_lang": "en",   # language of the audio input; set `source_lang`==`target_lang` for ASR
        "target_lang": "en",   # language of the text output
        "pnc": "yes",          # whether to have PnC output, choices=['yes', 'no']
        "answer": "na",        # set to non-dummy strings to calculate WER/BLEU scores
    }

Note that using a manifest allows you to specify the task configuration for each audio file individually. If you want to use the same task configuration for all audio files, it can be passed to the transcribe method directly:

    canary_model.transcribe(
        audio=[list of audio files],
        batch_size=4,        # batch size to run the inference with
        task="asr",          # use "ast" for speech-to-text translation
        source_lang="en",    # language of the audio input; set `source_lang`==`target_lang` for ASR
        target_lang="en",    # language of the text output
        pnc=True,            # whether to have PnC output, choices=[True, False]
    )

Inference on Long Audio

In some cases the audio is too long for standard inference, especially if you’re using a model such as Conformer, where the time and memory costs of the attention layers scale quadratically with the duration.

There are two main ways of performing inference on long audio files in NeMo:

The first way is to use buffered inference, where the audio is divided into chunks that are run separately, and the outputs are merged afterwards. The relevant scripts are provided with the ASR examples in the NeMo repository.
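The sketch below illustrates only the chunk-then-merge idea: it naively splits the waveform into fixed-size chunks and concatenates the per-chunk transcripts, whereas the bundled scripts merge overlapping chunks far more carefully. It assumes a 16 kHz mono recording, a model whose transcribe() returns plain strings (the CTC default), and a placeholder file name.

    import soundfile as sf

    audio, sr = sf.read("long_audio.wav", dtype="float32")  # placeholder path; 16 kHz mono assumed
    chunk_len = 30 * sr  # 30-second chunks

    # Split the waveform into consecutive chunks and transcribe them in batches.
    chunks = [audio[start:start + chunk_len] for start in range(0, len(audio), chunk_len)]
    texts = model.transcribe(chunks, batch_size=4)

    # Naive merge: simply join the per-chunk transcripts.
    full_text = " ".join(texts)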

The second way, specifically for models with the Conformer/Fast Conformer encoder, is to use local attention, which makes the costs linear. You can train Fast Conformer models with Longformer-style (https://arxiv.org/abs/2004.05150) local+global attention using one of the following configs: the CTC config at <NeMo_git_root>/examples/asr/conf/fastconformer/fast-conformer-long_ctc_bpe.yaml and the transducer config at <NeMo_git_root>/examples/asr/conf/fastconformer/fast-conformer-long_transducer_bpe.yaml. You can also convert any model trained with full-context attention to local attention, though this may result in lower WER in some cases. You can switch to local attention when running the transcribe or evaluation scripts as follows:

    python speech_to_text_eval.py \
        (...other parameters...) \
        ++model_change.conformer.self_attention_model="rel_pos_local_attn" \
        ++model_change.conformer.att_context_size=[128, 128]

Alternatively, you can change the attention model after loading a checkpoint:

    asr_model = ASRModel.from_pretrained('stt_en_conformer_ctc_large')
    asr_model.change_attention_model(
        self_attention_model="rel_pos_local_attn",
        att_context_size=[128, 128],
    )

Sometimes, the downsampling module at the earliest stage of the model can take more memory than the rest of the forward pass, since it operates directly on the audio sequence, which may not fit in memory for very long audio files. To reduce the memory consumption of the subsampling module, you can ask the model to auto-chunk the input sequence and process it piece by piece, taking more time but avoiding an OutOfMemoryError:

    asr_model = ASRModel.from_pretrained('stt_en_fastconformer_ctc_large')

    # Set the chunking factor of the conv subsampling module; 1 = auto select
    asr_model.change_subsampling_conv_chunking_factor(1)

Note

Only certain models which use depthwise separable convolutions in the downsampling layer support this operation. Please try it out on your model and see if it is supported.

Inference on Apple M-Series GPU

To perform inference on an Apple Mac M-Series GPU (the mps PyTorch device), use PyTorch 2.0 or higher (see the Mac installation section: https://github.com/NVIDIA/NeMo/blob/stable/README.rst#mac-computers-with-apple-silicon). The environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 should be set, since not all PyTorch operations are currently implemented for the mps device.

If the allow_mps=true flag is passed to speech_to_text_eval.py, the mps device will be selected automatically:

    PYTORCH_ENABLE_MPS_FALLBACK=1 python speech_to_text_eval.py \
        (...other parameters...) \
        allow_mps=true

Fine-tuning on Different Datasets

There are multiple ASR tutorials provided in the Tutorials section. Most of them demonstrate how to instantiate a pre-trained model and prepare it for fine-tuning on a dataset in the same language.

When preparing your own inference scripts, please follow the execution flow diagram order for correct inference; it can be found in the examples directory of the ASR collection.
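As a rough sketch of the fine-tuning flow (not taken from the tutorials), the example below loads a pre-trained checkpoint, points it at new training data, and fine-tunes it. The manifest paths, batch sizes, and trainer settings are placeholders, and the exact data-config keys can differ between models and NeMo versions.

    import pytorch_lightning as pl

    import nemo.collections.asr as nemo_asr

    model = nemo_asr.models.ASRModel.from_pretrained(model_name="stt_en_conformer_ctc_large")

    # Hypothetical data configs; the manifest paths are placeholders.
    train_cfg = {"manifest_filepath": "train_manifest.json", "sample_rate": 16000, "batch_size": 16, "shuffle": True}
    val_cfg = {"manifest_filepath": "val_manifest.json", "sample_rate": 16000, "batch_size": 16, "shuffle": False}
    model.setup_training_data(train_data_config=train_cfg)
    model.setup_validation_data(val_data_config=val_cfg)

    # Fine-tune for a few epochs on the new data.
    trainer = pl.Trainer(max_epochs=5, accelerator="gpu", devices=1)
    model.set_trainer(trainer)
    trainer.fit(model)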

Below is a list of all the ASR models that are available in NeMo for specific languages, as well as auxiliary language models for certain languages.

Language Models for ASR

Model Name | Model Base Class | Model Card
asrlm_en_transformer_large_ls | TransformerLMModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:asrlm_en_transformer_large_ls

English

Model Name | Model Base Class | Model Card
QuartzNet15x5Base-En | EncDecCTCModel | https://ngc.nvidia.com/catalog/models/nvidia:nemospeechmodels
stt_en_jasper10x5dr | EncDecCTCModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_jasper10x5dr
stt_en_citrinet_256 | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_citrinet_256
stt_en_citrinet_512 | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_citrinet_512
stt_en_citrinet_1024 | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_citrinet_1024
stt_en_citrinet_256_gamma_0_25 | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_citrinet_256_gamma_0_25
stt_en_citrinet_512_gamma_0_25 | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_citrinet_512_gamma_0_25
stt_en_citrinet_1024_gamma_0_25 | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_citrinet_1024_gamma_0_25
stt_en_contextnet_256_mls | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_contextnet_256_mls
stt_en_contextnet_512_mls | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_contextnet_512_mls
stt_en_contextnet_1024_mls | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_contextnet_1024_mls
stt_en_contextnet_256 | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_contextnet_256
stt_en_contextnet_512 | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_contextnet_512
stt_en_contextnet_1024 | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_contextnet_1024
stt_en_conformer_ctc_small | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_ctc_small
stt_en_conformer_ctc_medium | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_ctc_medium
stt_en_conformer_ctc_large | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_ctc_large
stt_en_conformer_ctc_xlarge | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_ctc_xlarge
stt_en_conformer_ctc_small_ls | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_ctc_small_ls
stt_en_conformer_ctc_medium_ls | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_ctc_medium_ls
stt_en_conformer_ctc_large_ls | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_ctc_large_ls
stt_en_conformer_transducer_large_ls | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_large_ls
stt_en_conformer_transducer_small | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_small
stt_en_conformer_transducer_medium | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_medium
stt_en_conformer_transducer_large | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_large
stt_en_conformer_transducer_xlarge | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_xlarge
stt_en_conformer_transducer_xxlarge | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_xxlarge
stt_en_fastconformer_ctc_large_ls | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_ctc_large_ls
stt_en_fastconformer_transducer_large_ls | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_transducer_large_ls
stt_en_fastconformer_transducer_large | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_transducer_large
stt_en_fastconformer_ctc_large | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_ctc_large
stt_en_fastconformer_hybrid_large_pc | EncDecHybridRNNTCTCBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_pc
stt_en_fastconformer_transducer_xlarge | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_transducer_xlarge
stt_en_fastconformer_ctc_xlarge | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_ctc_xlarge
stt_en_fastconformer_transducer_xxlarge | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_transducer_xxlarge
stt_en_fastconformer_ctc_xxlarge | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_ctc_xxlarge
stt_en_fastconformer_hybrid_large_streaming_80ms | EncDecHybridRNNTCTCBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_80ms
stt_en_fastconformer_hybrid_large_streaming_480ms | EncDecHybridRNNTCTCBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_480ms
stt_en_fastconformer_hybrid_large_streaming_1040ms | EncDecHybridRNNTCTCBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_1040ms
stt_en_fastconformer_hybrid_large_streaming_multi | EncDecHybridRNNTCTCBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_multi

Mandarin

Model | Model Base Class | Model Card
stt_zh_citrinet_512 | EncDecCTCModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_zh_citrinet_512
stt_zh_citrinet_1024_gamma_0_25 | EncDecCTCModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_zh_citrinet_1024_gamma_0_25
stt_zh_conformer_transducer_large | EncDecRNNTModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_zh_conformer_transducer_large

German

Model | Model Base Class | Model Card
stt_de_quartznet15x5 | EncDecCTCModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_de_quartznet15x5
stt_de_citrinet_1024 | EncDecCTCModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_de_citrinet_1024
stt_de_contextnet_1024 | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_de_contextnet_1024
stt_de_conformer_ctc_large | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_de_conformer_ctc_large
stt_de_conformer_transducer_large | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_de_conformer_transducer_large
stt_de_fastconformer_hybrid_large_pc | EncDecHybridRNNTCTCBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_de_fastconformer_hybrid_large_pc

French

Model | Model Base Class | Model Card
stt_fr_quartznet15x5 | EncDecCTCModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_fr_quartznet15x5
stt_fr_citrinet_1024_gamma_0_25 | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_fr_citrinet_1024_gamma_0_25
stt_fr_no_hyphen_citrinet_1024_gamma_0_25 | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_fr_citrinet_1024_gamma_0_25
stt_fr_contextnet_1024 | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_fr_contextnet_1024
stt_fr_conformer_ctc_large | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_fr_conformer_ctc_large
stt_fr_no_hyphen_conformer_ctc_large | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_fr_conformer_ctc_large
stt_fr_conformer_transducer_large | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_fr_conformer_transducer_large
stt_fr_fastconformer_hybrid_large_pc | EncDecHybridRNNTCTCBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_fr_fastconformer_hybrid_large_pc

Polish

Model | Model Base Class | Model Card
stt_pl_quartznet15x5 | EncDecCTCModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_pl_quartznet15x5
stt_pl_fastconformer_hybrid_large_pc | EncDecHybridRNNTCTCBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_pl_fastconformer_hybrid_large_pc

Italian

Model | Model Base Class | Model Card
stt_it_quartznet15x5 | EncDecCTCModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_it_quartznet15x5
stt_it_fastconformer_hybrid_large_pc | EncDecHybridRNNTCTCBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_it_fastconformer_hybrid_large_pc

Russian

Model | Model Base Class | Model Card
stt_ru_quartznet15x5 | EncDecCTCModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_ru_quartznet15x5
stt_ru_fastconformer_hybrid_large_pc | EncDecHybridRNNTCTCBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_ru_fastconformer_hybrid_large_pc

Spanish

Model | Model Base Class | Model Card
stt_es_quartznet15x5 | EncDecCTCModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_es_quartznet15x5
stt_es_citrinet_512 | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_es_citrinet_512
stt_es_citrinet_1024_gamma_0_25 | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_es_citrinet_1024_gamma_0_25
stt_es_conformer_ctc_large | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_es_conformer_ctc_large
stt_es_conformer_transducer_large | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_es_conformer_transducer_large
stt_es_contextnet_1024 | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_es_contextnet_1024
stt_es_fastconformer_hybrid_large_pc | EncDecHybridRNNTCTCBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_es_fastconformer_hybrid_large_pc

Catalan

Model | Model Base Class | Model Card
stt_ca_quartznet15x5 | EncDecCTCModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_ca_quartznet15x5
stt_ca_conformer_ctc_large | EncDecCTCModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_ca_conformer_ctc_large
stt_ca_conformer_transducer_large | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_ca_conformer_transducer_large

Hindi

Model Name | Model Base Class | Model Card
stt_hi_conformer_ctc_medium | EncDecCTCModelBPE | https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_hi_conformer_ctc_medium

Marathi

Model Name | Model Base Class | Model Card
stt_mr_conformer_ctc_medium | EncDecCTCModelBPE | https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_mr_conformer_ctc_medium

Kinyarwanda

Model | Model Base Class | Model Card
stt_rw_conformer_ctc_large | EncDecCTCModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_rw_conformer_ctc_large
stt_rw_conformer_transducer_large | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_rw_conformer_transducer_large

Belarusian

Model | Model Base Class | Model Card
stt_by_fastconformer_hybrid_large_pc | EncDecHybridRNNTCTCBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_by_fastconformer_hybrid_large_pc

Ukrainian

Model | Model Base Class | Model Card
stt_ua_fastconformer_hybrid_large_pc | EncDecHybridRNNTCTCBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_ua_fastconformer_hybrid_large_pc

Multilingual

Model | Model Base Class | Model Card
stt_enes_conformer_ctc_large | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_enes_conformer_ctc_large
stt_enes_conformer_transducer_large | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_enes_conformer_transducer_large
stt_multilingual_fastconformer_hybrid_large_pc | EncDecHybridRNNTCTCBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_multilingual_fastconformer_hybrid_large_pc
stt_multilingual_fastconformer_hybrid_large_pc_blend_eu | EncDecHybridRNNTCTCBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_multilingual_fastconformer_hybrid_large_pc_blend_eu

Code-Switching

Model | Model Base Class | Model Card
stt_enes_conformer_ctc_large_codesw | EncDecCTCModelBPE | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_enes_conformer_ctc_large_codesw
stt_enes_conformer_transducer_large_codesw | EncDecRNNTBPEModel | https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_enes_conformer_transducer_large_codesw