How To Train, Evaluate, and Fine-Tune an n-gram Language Model#

A language model assigns a probability distribution over sequences of words. Besides assigning a probability to an entire sequence, it also assigns a probability to a given word (or sequence of words) following a given preceding sequence.

For example, the language model would score the sentence "all of a sudden I notice three guys standing on the sidewalk" higher than the scrambled sequence "on guys all I of notice sidewalk three a sudden standing the".
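To make this concrete, here is a toy sketch (with made-up bigram probabilities, for illustration only) of how a bigram model scores a word sequence as a product of conditional probabilities:

# Toy bigram model with made-up probabilities, for illustration only.
bigram_prob = {
    ("<s>", "all"): 0.10, ("all", "of"): 0.30, ("of", "a"): 0.20,
    ("a", "sudden"): 0.40, ("<s>", "on"): 0.05, ("on", "guys"): 1e-6,
}

def score(words, lm, floor=1e-8):
    # P(w1..wn) ~= P(w1|<s>) * P(w2|w1) * ... under a bigram assumption.
    p = 1.0
    for prev, cur in zip(["<s>"] + words, words):
        p *= lm.get((prev, cur), floor)  # unseen bigrams get a tiny floor probability
    return p

print(score("all of a sudden".split(), bigram_prob))  # relatively high
print(score("on guys all of".split(), bigram_prob))   # far lower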

A language model trained on a large corpus (that is, a large dataset) can significantly improve the accuracy of an ASR system, as recent research suggests.

n-gram Language Model#

There are primarily two types of language models:

  • n-gram language models: These models use the frequency of n-grams to learn the probability distribution over words. Two benefits of n-gram language models are simplicity and scalability – with a larger n, a model can store more context with a well-understood space–time tradeoff, enabling small experiments to scale up efficiently.

  • Neural language models: These models use different kinds of neural networks to model the probability distribution over words, and have surpassed the n-gram language models in the ability to model language, but are generally slower to evaluate.

In this tutorial, we will show how to train, evaluate, and optionally fine-tune an n-gram language model leveraging NeMo.

Prerequisites#

Ensure you meet the following prerequisites.

  1. You have access to, and are logged into, NVIDIA NGC. For step-by-step instructions, refer to the NGC Getting Started Guide.

  2. You have installed the Kaggle API. For step-by-step instructions, refer to the Kaggle API installation and authentication instructions.


Training and Fine-tuning LM with KenLM and NeMo#

Installing and Setting up NeMo#

Clone and install NeMo from source.

## Install NeMo
BRANCH = 'main'
!git clone https://github.com/NVIDIA/NeMo.git
!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]
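After installation, a quick sanity check confirms that the toolkit imports:

# Verify that NeMo installed correctly.
import nemo
print(nemo.__version__)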

Installing and Setting up KenLM#

Install KenLM from source.

# Install KenLM build dependencies. Depending on your environment, you may also
# need build-essential, cmake, Boost, and zlib/bzip2/xz development packages.
!apt install libeigen3-dev
!git clone https://github.com/kpu/kenlm.git
!cd kenlm && mkdir build && cd build && cmake .. && make -j

!pip3 install git+https://github.com/kpu/kenlm.git
!pip3 install git+https://github.com/flashlight/text.git

Installing and Setting up NGC CLI#

To install and set up the NGC CLI, follow the instructions from here.


Preparing the Dataset#

LibriSpeech LM Normalized Dataset#

For this tutorial, we use the normalized version of the LibriSpeech LM dataset to train our n-gram language model. This dataset is available here.
The training data is publicly available here and can be downloaded directly.

Downloading the Dataset#

# Set the path to a folder where you want your data and results to be saved.
DATA_DOWNLOAD_DIR="content/datasets"
RESULTS_DIR="content/results"
MODELS_DIR="content/models"

!mkdir -p $DATA_DOWNLOAD_DIR $RESULTS_DIR $MODELS_DIR
# Note: Ensure that wget and unzip utilities are available. If not, install them.
!wget 'https://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz' -P $DATA_DOWNLOAD_DIR

# Extract the data
!gzip -dk $DATA_DOWNLOAD_DIR/librispeech-lm-norm.txt.gz

To keep this tutorial's runtime short, we reduce the training dataset to a subset of its lines. Feel free to adjust the number of lines used.

# Use a random 100,000 lines for training
!shuf -n 100000 $DATA_DOWNLOAD_DIR/librispeech-lm-norm.txt  > $DATA_DOWNLOAD_DIR/reduced_training.txt

Download the Evaluation Dataset#

# Note: This data can be used only with NVIDIA’s products or services for evaluation and benchmarking purposes.
!source ~/.bash_profile && ngc registry resource download-version --dest $DATA_DOWNLOAD_DIR nvidia/riva/healthcare_eval_set:1.0
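Once the download completes, you can list the evaluation manifests; the directory name below matches the paths used in the rest of this tutorial:

# List the downloaded evaluation set (general.json and healthcare.json manifests).
!ls $DATA_DOWNLOAD_DIR/healthcare_eval_set_v1.0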

Generating the Base Language Model#

KENLM_BASE="kenlm/build/bin"

$KENLM_BASE/lmplz
Required Arguments:

  • -o: Order of the language model to estimate.

Optional Arguments:

  • -S: Memory to use. This is a number followed by single-character suffix: % for percentage of physical memory (on platforms where this is measured), b for bytes, K for kilobytes, M for megabytes, and so on for G and T. If no suffix is given, kilobytes are assumed for compatibility with GNU sort. The sort program is not used; the command line is simply designed to be compatible.

  • -T: Temporary file location.

  • --discount_fallback: Kneser-Ney smoothing discounts are estimated from counts of counts, including the number of singletons; this flag falls back to default discounts when those estimates fail (for example, on very small corpora).

!$KENLM_BASE/lmplz -o 4 < $DATA_DOWNLOAD_DIR/reduced_training.txt > $RESULTS_DIR/base_lm.arpa
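If memory is constrained, the -S and -T options described above can be combined with the same command; for example (illustrative values):

# Limit lmplz to 20% of physical memory and write temporary files to /tmp.
!$KENLM_BASE/lmplz -o 4 -S 20% -T /tmp < $DATA_DOWNLOAD_DIR/reduced_training.txt > $RESULTS_DIR/base_lm.arpa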

$KENLM_BASE/build_binary
Arguments:

  • -q: Quantization flag for the probabilities. For example, -q 8 stores 8 bits of probability.

  • -b: Quantization flag for the back-off weights. For example, -b 7 stores 7 bits of back-off.

  • -a: Maximum number of bits to be removed and stored implicitly using a table of offsets to reduce memory footprint.

!$KENLM_BASE/build_binary  trie -q 8 -b 7 -a 256 $RESULTS_DIR/base_lm.arpa $RESULTS_DIR/base_lm.bin
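Optionally, you can sanity-check the binary LM with the kenlm Python bindings installed earlier. Scores are total log10 probabilities, so the natural sentence from the introduction should score higher (less negative) than the scrambled one; note that the normalized LibriSpeech corpus is uppercase, so we match its casing here:

import kenlm

# Load the binary LM built above and compare scores of two word sequences.
lm = kenlm.Model(f"{RESULTS_DIR}/base_lm.bin")
print("LM order:", lm.order)

natural = "ALL OF A SUDDEN I NOTICE THREE GUYS STANDING ON THE SIDEWALK"
scrambled = "ON GUYS ALL I OF NOTICE SIDEWALK THREE A SUDDEN STANDING THE"
print(lm.score(natural), lm.score(scrambled))  # natural should be less negative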

Load the ASR Model#

import torch
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.metrics.wer import CTCDecodingConfig

Download the English Conformer Model from NGC#

!source ~/.bash_profile && ngc registry model download-version "nvidia/riva/speechtotext_en_us_conformer:trainable_v3.1" --dest $MODELS_DIR

Create a Lexicon File for Flashlight Decoder#

!python NeMo/scripts/asr_language_modeling/ngram_lm/create_lexicon_from_arpa.py --arpa $RESULTS_DIR/base_lm.arpa --model $MODELS_DIR/speechtotext_en_us_conformer_vtrainable_v3.1/Conformer-CTC-L_spe-1024_en-US_Riva-ASR-SET-3.1.nemo --lower --dst $RESULTS_DIR
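You can peek at the generated lexicon; each line maps a word to the token sequence the decoder is allowed to emit for it (assuming the script writes base_lm.lexicon, the name referenced in the decoding configuration below):

# Inspect the first few lexicon entries.
!head -n 5 $RESULTS_DIR/base_lm.lexicon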

Updating the Decoder Type, Language Model, and Lexicon#

device = torch.device('cuda')
asr_model = ASRModel.restore_from(f"{MODELS_DIR}/speechtotext_en_us_conformer_vtrainable_v3.1/Conformer-CTC-L_spe-1024_en-US_Riva-ASR-SET-3.1.nemo").to(device)


decoding_cfg = CTCDecodingConfig()

# Use the Flashlight beam-search decoder with our KenLM binary and lexicon.
decoding_cfg.strategy = "flashlight"
decoding_cfg.beam.search_type = "flashlight"
decoding_cfg.beam.kenlm_path = f'{RESULTS_DIR}/base_lm.bin'
decoding_cfg.beam.flashlight_cfg.lexicon_path = f'{RESULTS_DIR}/base_lm.lexicon'
decoding_cfg.beam.beam_size = 32
decoding_cfg.beam.beam_alpha = 0.2  # language model weight
decoding_cfg.beam.beam_beta = 0.2   # word insertion weight
decoding_cfg.beam.flashlight_cfg.beam_size_token = 32
decoding_cfg.beam.flashlight_cfg.beam_threshold = 25.0

asr_model.change_decoding_strategy(decoding_cfg)
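Roughly speaking, beam_alpha weights the language model scores against the acoustic model, and beam_beta acts as a word insertion score; both are typically tuned on a held-out set for the best WER.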

Evaluate#

import json
import os

def transcribe_json(asr_model, json_path, output_json):
    """Transcribe all audio files in a manifest and write a manifest with the ASR output."""
    dataset_root = os.path.split(json_path)[0]
    with open(json_path) as fin, open(output_json, 'w') as fout:
        manifest = []
        audios = []
        for line in fin:
            dt = json.loads(line.strip())
            manifest.append(dt)
            # Manifest paths start with /data; remap them to the local dataset folder.
            audios.append(dt['audio_filepath'].replace("/data", dataset_root))
        transcripts = asr_model.transcribe(paths2audio_files=audios)
        for i in range(len(transcripts)):
            dt = {
                'audio_filepath': manifest[i]['audio_filepath'],
                'text': transcripts[i]
            }
            fout.write(json.dumps(dt) + "\n")

transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_base_lm.json")
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_base_lm.json")

Calculate the Word Error Rate#

!pip install jiwer
from jiwer import wer
import json

def calculate_wer(ground_truth_manifest, asr_transcript):
    """Compute WER (%) between a ground-truth manifest and an ASR transcript manifest."""
    data = {}
    ground_truths = []
    predictions = []
    with open(ground_truth_manifest) as file:
        for line in file:
            dt = json.loads(line)
            data[dt['audio_filepath']] = dt['text']
    with open(asr_transcript) as file:
        for line in file:
            dt = json.loads(line)
            # Pair each prediction with its reference by audio file path.
            if dt['audio_filepath'] in data:
                ground_truths.append(data[dt['audio_filepath']])
                predictions.append(dt['text'])
    return round(100 * wer(ground_truths, predictions), 2)

print( "WER of base model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_base_lm.json"))
print("WER of base model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_base_lm.json"))

Fine-Tuning and Interpolation#

The fine-tuning process continues training from a previously trained model by training a second model on new domain data and interpolating it with the original model. Fine-tuning requires that the original model was built with the --intermediate option enabled. A fine-tuned model cannot be fine-tuned again.

Downloading and Processing Domain Data (healthcare) for LM Fine-Tuning#

To fine-tune on the healthcare domain, we use the Kaggle dataset PubMed 200k RCT: a Dataset for Sequential Sentence Classification in Medical Abstracts.
This dataset is available here.
Follow the instructions to install and authenticate the Kaggle API.
Note: Each user is responsible for checking the content of datasets and the applicable licenses and determining if they are suitable for the intended use.

!kaggle datasets download -d anshulmehtakaggl/200000-abstracts-for-seq-sentence-classification
!unzip -d $DATA_DOWNLOAD_DIR 200000-abstracts-for-seq-sentence-classification.zip
# Perform basic text cleaning and generate domain data
import re

def clean_text(text):
    # Keep only lowercase letters, apostrophes, and spaces.
    text = re.sub(r"[^a-z' ]+", "", text.lower().strip())
    text = ' '.join(text.split())
    # Drop very short lines (returns None for them).
    if len(text.split()) > 5:
        return text.strip()

# Using the dev file since we want a small amount of fine-tuning data.
# For better text normalization, use NeMo [https://github.com/NVIDIA/NeMo/tree/main/nemo_text_processing]
with open(f'{DATA_DOWNLOAD_DIR}/20k_abstracts_numbers_with_@/dev.txt') as file, open(f'{DATA_DOWNLOAD_DIR}/domain_data_all.txt', 'w') as outfile:
    for line in file:
        # Skip abstract-ID headers ("###...") and blank lines.
        if line.startswith("###") or not line.strip():
            continue
        # Each remaining line has the form "<LABEL>\t<sentence>".
        _, text = line.strip().split('\t', 1)
        text = clean_text(text)
        if text:
            outfile.write(text + '\n')
            
# Pick the first 10,000 lines from the dataset
!head -10000 $DATA_DOWNLOAD_DIR/domain_data_all.txt > $DATA_DOWNLOAD_DIR/domain_data.txt


For fine-tuning an n-gram language model with KenLM, perform the following steps:

  1. Generate intermediate ARPA files for Base LM and Domain LM.

  2. Interpolate Base LM and Domain LM with suitable weights.

Generating Intermediate ARPA#

# Base LM: --intermediate writes the intermediate files that interpolate consumes
!mkdir -p base_intermediate
!$KENLM_BASE/lmplz -o 4 --intermediate base_intermediate/inter < $DATA_DOWNLOAD_DIR/reduced_training.txt

# Healthcare LM
!mkdir -p healthcare_intermediate
!$KENLM_BASE/lmplz -o 4 --intermediate healthcare_intermediate/inter < $DATA_DOWNLOAD_DIR/domain_data.txt

Interpolate#

Weights for interpolation are passed with the -w option:

 -w 0.6 0.4

Here, a weight of 0.6 is assigned to the base LM and 0.4 to the domain LM.
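Conceptually, the interpolated model mixes the two models' probability estimates with these weights. A minimal sketch of linear interpolation for a single n-gram (illustrative only; the interpolate binary operates on whole ARPA files):

# Illustrative only: linear interpolation of two LM probabilities for the
# same n-gram, using the -w weights above.
def interpolate_prob(p_base, p_domain, w_base=0.6, w_domain=0.4):
    assert abs(w_base + w_domain - 1.0) < 1e-9  # weights should sum to 1
    return w_base * p_base + w_domain * p_domain

# A word sequence that is rare in the base corpus but common in the domain
# corpus gets a meaningful boost:
print(interpolate_prob(p_base=1e-6, p_domain=1e-3))  # ~4e-4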

!$KENLM_BASE/interpolate -w 0.6 0.4 -m base_intermediate/inter healthcare_intermediate/inter > $RESULTS_DIR/interpolated_lm_60-40.arpa
!$KENLM_BASE/build_binary  trie -q 8 -b 7 -a 256 $RESULTS_DIR/interpolated_lm_60-40.arpa $RESULTS_DIR/interpolated_lm_60-40.bin
!python NeMo/scripts/asr_language_modeling/ngram_lm/create_lexicon_from_arpa.py --arpa $RESULTS_DIR/interpolated_lm_60-40.arpa --model $MODELS_DIR/speechtotext_en_us_conformer_vtrainable_v3.1/Conformer-CTC-L_spe-1024_en-US_Riva-ASR-SET-3.1.nemo --lower --dst $RESULTS_DIR

Evaluate#

decoding_cfg.beam.kenlm_path = f'{RESULTS_DIR}/interpolated_lm_60-40.bin'
decoding_cfg.beam.flashlight_cfg.lexicon_path = f'{RESULTS_DIR}/interpolated_lm_60-40.lexicon'

asr_model.change_decoding_strategy(decoding_cfg)
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_interpolated_lm_60-40.json")
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_interpolated_lm_60-40.json")

Calculate WER#

print( "WER of base model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_base_lm.json"))
print( "WER of Domain model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_interpolated_lm_60-40.json"))
print("WER of base model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_base_lm.json"))
print("WER of Domain model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_interpolated_lm_60-40.json"))

Pruning#

An LM generated by simply passing the text corpus to KenLM contains some n-grams that occur infrequently in the corpus and thus have very low probabilities. Such n-grams can be removed by pruning.
Pruning requires thresholds, which are passed with the --prune parameter followed by space-separated count thresholds, one per n-gram order, while generating the ARPA file:

 --prune 0 1 7 9

All n-grams with a frequency less than or equal to the specified threshold are eliminated.
Here, 2-grams with frequency <= 1, 3-grams with frequency <= 7, and 4-grams with frequency <= 9 are eliminated.
There is a tradeoff between the degree of pruning and accuracy: aggressive pruning reduces the size of the language model, but at the cost of model accuracy.

Note#

Pruning of 1-grams is not supported; the threshold for 1-grams must always be 0.

!$KENLM_BASE/lmplz -o 4 --prune 0 1 7 9 < $DATA_DOWNLOAD_DIR/reduced_training.txt > $RESULTS_DIR/pruned_lm.arpa
!$KENLM_BASE/build_binary  trie -q 8 -b 7 -a 256 $RESULTS_DIR/pruned_lm.arpa $RESULTS_DIR/pruned_lm.bin
!python NeMo/scripts/asr_language_modeling/ngram_lm/create_lexicon_from_arpa.py --arpa $RESULTS_DIR/pruned_lm.arpa --model $MODELS_DIR/speechtotext_en_us_conformer_vtrainable_v3.1/Conformer-CTC-L_spe-1024_en-US_Riva-ASR-SET-3.1.nemo --lower --dst $RESULTS_DIR
# Let's check the size of the original LM and the pruned LM
!echo "Size of unpruned ARPA: $(du -h $RESULTS_DIR/base_lm.arpa | cut -f 1)"
!echo "Size of pruned ARPA: $(du -h $RESULTS_DIR/pruned_lm.arpa | cut -f 1)"

Evaluate#

decoding_cfg.beam.kenlm_path = f'{RESULTS_DIR}/pruned_lm.bin'
decoding_cfg.beam.flashlight_cfg.lexicon_path = f'{RESULTS_DIR}/base_lm.lexicon'

asr_model.change_decoding_strategy(decoding_cfg)
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_pruned_lm.json")
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_pruned_lm.json")

Calculate WER#

print( "WER of base model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_base_lm.json"))
print( "WER of Pruned base model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_pruned_lm.json"))
print("WER of base model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_base_lm.json"))
print("WER of Pruned base model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_pruned_lm.json"))