How To Train, Evaluate, and Fine-Tune an n-gram Language Model#
A language model assigns a probability distribution over sequences of words. Besides assigning a probability to a whole sequence of words, it also assigns a probability to a given word (or sequence of words) following a sequence of words.
For example, the sentence all of a sudden I notice three guys standing on the sidewalk would be scored higher by the language model than the sentence on guys all I of notice sidewalk three a sudden standing the.
A language model trained on a large corpus, that is, a large text dataset, can significantly improve the accuracy of an ASR system, as recent research suggests.
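To make this concrete, the snippet below is a minimal sketch of scoring the two example sentences with a trained KenLM model; it assumes the kenlm Python bindings (installed later in this tutorial) and an already-built model file such as the base_lm.arpa created below.
# Minimal sketch: comparing sentence scores with a trained KenLM model.
# Assumes the kenlm Python bindings are installed and that a model file
# (for example, content/results/base_lm.arpa built later in this tutorial) exists.
import kenlm

lm = kenlm.Model("content/results/base_lm.arpa")

fluent = "all of a sudden i notice three guys standing on the sidewalk"
scrambled = "on guys all i of notice sidewalk three a sudden standing the"

# score() returns the total log10 probability of the sentence
# (including begin- and end-of-sentence markers by default).
print("fluent   :", lm.score(fluent))
print("scrambled:", lm.score(scrambled))
# The fluent sentence should receive a noticeably higher log probability.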
n-gram Language Model#
There are primarily two types of language models:
n-gram language models: These models use the frequency of n-grams to learn the probability distribution over words. Two benefits of n-gram language models are simplicity and scalability – with a larger n, a model can store more context with a well-understood space–time tradeoff, enabling small experiments to scale up efficiently (see the sketch after this list).
Neural language models: These models use different kinds of neural networks to model the probability distribution over words. They have surpassed n-gram language models in the ability to model language, but are generally slower to evaluate.
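To illustrate the first point, the following is a minimal sketch (not part of the tutorial pipeline) of estimating bigram probabilities directly from n-gram counts; real toolkits such as KenLM additionally apply smoothing and back-off.
# Minimal sketch: maximum-likelihood bigram probabilities from raw counts.
# KenLM and other toolkits add smoothing and back-off on top of such counts.
from collections import Counter

corpus = [
    "all of a sudden i notice three guys standing on the sidewalk",
    "three guys were standing on the corner",
]

unigrams, bigrams = Counter(), Counter()
for sentence in corpus:
    tokens = ["<s>"] + sentence.split() + ["</s>"]
    unigrams.update(tokens[:-1])                  # context counts
    bigrams.update(zip(tokens[:-1], tokens[1:]))  # (context, word) counts

def p(word, context):
    # P(word | context) = count(context, word) / count(context)
    return bigrams[(context, word)] / unigrams[context] if unigrams[context] else 0.0

print(p("guys", "three"))     # 1.0 -- "three" is always followed by "guys"
print(p("standing", "guys"))  # 0.5 -- "guys" is followed by "standing" in one of two cases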
In this tutorial, we will show how to train, evaluate, and optionally fine-tune an n-gram language model leveraging NeMo.
Prerequisites#
Ensure you meet the following prerequisites.
You have access to and are logged into NVIDIA NGC. For step-by-step instructions, refer to the NGC Getting Started Guide.
You have installed the Kaggle API. For step-by-step instructions, refer to the install and authenticate Kaggle API guide.
Training and Fine-tuning LM with KenLM and NeMo#
Installing and Setting up NeMo#
Cloning and installing NeMo from source.
## Install NeMo
BRANCH = 'main'
!git clone https://github.com/NVIDIA/NeMo.git
!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]
Installing and Setting up KenLM#
Installing KenLM from source.
!apt install libeigen3-dev
!git clone https://github.com/kpu/kenlm.git
!cd kenlm && mkdir build && cd build && cmake .. && make -j
!pip3 install git+https://github.com/kpu/kenlm.git
!pip3 install git+https://github.com/flashlight/text.git
Installing and Setting up NGC CLI#
To install and set up the NGC CLI, follow the instructions here.
Preparing the Dataset#
LibriSpeech LM Normalized Dataset#
For this tutorial, we use the normalized version of the LibriSpeech LM dataset to train our n-gram language model.
The training data is publicly available here and can be downloaded directly.
Downloading the Dataset#
# Set the path to a folder where you want your data and results to be saved.
DATA_DOWNLOAD_DIR="content/datasets"
RESULTS_DIR="content/results"
MODELS_DIR="content/models"
!mkdir -p $DATA_DOWNLOAD_DIR $RESULTS_DIR $MODELS_DIR
# Note: Ensure that wget and unzip utilities are available. If not, install them.
!wget 'https://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz' -P $DATA_DOWNLOAD_DIR
# Extract the data
!gzip -dk $DATA_DOWNLOAD_DIR/librispeech-lm-norm.txt.gz
To reduce the time this tutorial takes, we use only a subset of the training data. Feel free to modify the number of lines used.
# Use a random 100,000 lines for training
!shuf -n 100000 $DATA_DOWNLOAD_DIR/librispeech-lm-norm.txt > $DATA_DOWNLOAD_DIR/reduced_training.txt
Download the Evaluation Dataset#
#Note: This data can be used only with NVIDIA’s products or services for evaluation and benchmarking purposes.
!source ~/.bash_profile && ngc registry resource download-version --dest $DATA_DOWNLOAD_DIR nvidia/riva/healthcare_eval_set:1.0
Generating the Base Language Model#
KENLM_BASE="kenlm/build/bin/"
We use KenLM's lmplz tool to estimate the n-gram language model from the training text:
$KENLM_BASE/lmplz
Required Arguments:
-o: Order of the language model to estimate.
Optional Arguments:
-S: Memory to use. This is a number followed by a single-character suffix: % for a percentage of physical memory (on platforms where this is measured), b for bytes, K for kilobytes, M for megabytes, and so on for G and T. If no suffix is given, kilobytes are assumed, for compatibility with GNU sort. The sort program is not used; the command line is simply designed to be compatible.
-T: Temporary file location.
--discount_fallback: Kneser-Ney smoothing discounts are estimated from counts of counts, including the number of singletons.
!$KENLM_BASE/lmplz -o 4 < $DATA_DOWNLOAD_DIR/reduced_training.txt > $RESULTS_DIR/base_lm.arpa
The build_binary tool converts the ARPA file into a binary format that is smaller and faster to load:
$KENLM_BASE/build_binary
Arguments:
-q: Quantization flag for the probabilities. For example, -q 8 stores 8 bits of probability.
-b: Quantization flag for back-off weights. For example, -b 7 stores 7 bits of back-off.
-a: Maximum number of bits to be removed and stored implicitly using a table of offsets, to reduce the memory footprint.
!$KENLM_BASE/build_binary trie -q 8 -b 7 -a 256 $RESULTS_DIR/base_lm.arpa $RESULTS_DIR/base_lm.bin
Load the ASR Model#
import torch
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.metrics.wer import CTCDecodingConfig
Download the English Conformer Model from NGC#
!source ~/.bash_profile && ngc registry model download-version "nvidia/riva/speechtotext_en_us_conformer:trainable_v3.1" --dest $MODELS_DIR
Create a Lexicon File for Flashlight Decoder#
!python NeMo/scripts/asr_language_modeling/ngram_lm/create_lexicon_from_arpa.py --arpa $RESULTS_DIR/base_lm.arpa --model $MODELS_DIR/speechtotext_en_us_conformer_vtrainable_v3.1/Conformer-CTC-L_spe-1024_en-US_Riva-ASR-SET-3.1.nemo --lower --dst $RESULTS_DIR
Updating the Decoder Type, Language Model, and Lexicon#
device = torch.device('cuda')
asr_model = ASRModel.restore_from(f"{MODELS_DIR}/speechtotext_en_us_conformer_vtrainable_v3.1/Conformer-CTC-L_spe-1024_en-US_Riva-ASR-SET-3.1.nemo").to(device)
decoding_cfg = CTCDecodingConfig()
# Use the Flashlight beam-search decoder with the KenLM binary and lexicon built above.
decoding_cfg.strategy = "flashlight"
decoding_cfg.beam.search_type = "flashlight"
decoding_cfg.beam.kenlm_path = f'{RESULTS_DIR}/base_lm.bin'
decoding_cfg.beam.flashlight_cfg.lexicon_path = f'{RESULTS_DIR}/base_lm.lexicon'
decoding_cfg.beam.beam_size = 32
decoding_cfg.beam.beam_alpha = 0.2  # language model weight
decoding_cfg.beam.beam_beta = 0.2   # word insertion weight
decoding_cfg.beam.flashlight_cfg.beam_size_token = 32
decoding_cfg.beam.flashlight_cfg.beam_threshold = 25.0
asr_model.change_decoding_strategy(decoding_cfg)
Evaluate#
import json
import os
def transcribe_json(asr_model, json_path, output_json):
    # Transcribe every audio file listed in a NeMo-style manifest and write
    # the predictions to a new manifest.
    dataset_root = os.path.split(json_path)[0]
    with open(json_path) as fin, open(output_json, 'w') as fout:
        manifest = []
        audios = []
        for line in fin:
            dt = json.loads(line.strip())
            manifest.append(dt)
            # The manifest stores paths under /data; remap them to the local dataset location.
            audios.append(dt['audio_filepath'].replace("/data", dataset_root))
        transcripts = asr_model.transcribe(paths2audio_files=audios)
        for i in range(len(transcripts)):
            dt = {
                'audio_filepath': manifest[i]['audio_filepath'],
                'text': transcripts[i]
            }
            fout.write(json.dumps(dt) + "\n")
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_base_lm.json")
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_base_lm.json")
Calculate the Word Error Rate#
!pip install jiwer
from jiwer import wer
import json
def calculate_wer(ground_truth_manifest, asr_transcript):
    # Collect reference texts keyed by audio file, align them with the ASR
    # predictions, and return the word error rate as a percentage.
    data = {}
    ground_truths = []
    predictions = []
    with open(ground_truth_manifest) as file:
        for line in file:
            dt = json.loads(line)
            data[dt['audio_filepath']] = dt['text']
    with open(asr_transcript) as file:
        for line in file:
            dt = json.loads(line)
            if dt['audio_filepath'] in data:
                ground_truths.append(data[dt['audio_filepath']])
                predictions.append(dt['text'])
    return round(100 * wer(ground_truths, predictions), 2)
print( "WER of base model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_base_lm.json"))
print("WER of base model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_base_lm.json"))
Fine-Tuning and Interpolation#
The fine-tuning process continues from a previously trained model by training a second model on new, domain-specific data and interpolating it with the original model. Fine-tuning requires that the original model was trained with intermediate output enabled (the --intermediate option shown below). A fine-tuned model cannot be used for fine-tuning again.
Downloading and Processing Domain Data (healthcare) for LM Fine-Tuning#
For fine-tuning on the healthcare domain, we can make use of the Kaggle dataset PubMed 200k RCT: a Dataset for Sequential Sentence Classification in Medical Abstracts.
This dataset is available here.
Follow the instructions to install and authenticate Kaggle API.
Note: Each user is responsible for checking the content of datasets and the applicable licenses and determining if they are suitable for the intended use.
!kaggle datasets download -d anshulmehtakaggl/200000-abstracts-for-seq-sentence-classification
!unzip -d $DATA_DOWNLOAD_DIR 200000-abstracts-for-seq-sentence-classification.zip
# Perform basic text cleaning and generate domain data
import re

def clean_text(text):
    # Keep only lowercase letters, apostrophes, and spaces, collapse repeated
    # whitespace, and drop very short lines (5 words or fewer return None).
    text = re.sub(r"[^a-z' ]+", "", text.lower().strip())
    text = ' '.join(text.split())
    if len(text.split()) > 5:
        return text.strip()

# Using the dev file since we only need a small amount of fine-tuning data.
# For better text normalization, use NeMo [https://github.com/NVIDIA/NeMo/tree/main/nemo_text_processing]
with open(f'{DATA_DOWNLOAD_DIR}/20k_abstracts_numbers_with_@/dev.txt') as file, open(f'{DATA_DOWNLOAD_DIR}/domain_data_all.txt', 'w') as outfile:
    for line in file:
        # Skip abstract ID headers (lines starting with ###) and blank lines.
        if line.startswith("###") or not line.strip():
            continue
        _, text = line.strip().split('\t')
        text = clean_text(text)
        if text:
            outfile.write(text + '\n')
# Picking top 10000 lines from dataset
!head -10000 $DATA_DOWNLOAD_DIR/domain_data_all.txt > $DATA_DOWNLOAD_DIR/domain_data.txt
To fine-tune an n-gram language model with KenLM, perform the following steps:
Generating Intermediate ARPA#
# Base LM
!mkdir base_intermediate
!$KENLM_BASE/lmplz -o 4 --intermediate base_intermediate/inter < $DATA_DOWNLOAD_DIR/reduced_training.txt
# Healthcare LM
!mkdir healthcare_intermediate
!$KENLM_BASE/lmplz -o 4 --intermediate healthcare_intermediate/inter < $DATA_DOWNLOAD_DIR/domain_data.txt
Interpolate#
Weights for interpolation can be passed with
-w 0.6 0.4
Here, a weight of 60% is assigned to the base LM and 40% to the domain LM; conceptually, each n-gram probability in the interpolated model is a weighted combination of the two models' probabilities.
!$KENLM_BASE/interpolate -w 0.6 0.4 -m base_intermediate/inter healthcare_intermediate/inter > $RESULTS_DIR/interpolated_lm_60-40.arpa
!$KENLM_BASE/build_binary trie -q 8 -b 7 -a 256 $RESULTS_DIR/interpolated_lm_60-40.arpa $RESULTS_DIR/interpolated_lm_60-40.bin
!python NeMo/scripts/asr_language_modeling/ngram_lm/create_lexicon_from_arpa.py --arpa $RESULTS_DIR/interpolated_lm_60-40.arpa --model $MODELS_DIR/speechtotext_en_us_conformer_vtrainable_v3.1/Conformer-CTC-L_spe-1024_en-US_Riva-ASR-SET-3.1.nemo --lower --dst $RESULTS_DIR
Evaluate#
decoding_cfg.beam.kenlm_path = f'{RESULTS_DIR}/interpolated_lm_60-40.bin'
decoding_cfg.beam.flashlight_cfg.lexicon_path = f'{RESULTS_DIR}/interpolated_lm_60-40.lexicon'
asr_model.change_decoding_strategy(decoding_cfg)
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_interpolated_lm_60-40.json")
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_interpolated_lm_60-40.json")
Calculate WER#
print( "WER of base model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_base_lm.json"))
print( "WER of Domain model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_interpolated_lm_60-40.json"))
print("WER of base model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_base_lm.json"))
print("WER of Domain model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_interpolated_lm_60-40.json"))
Pruning#
An LM generated by simply passing the text corpus to KenLM contains some n-grams that occur only rarely in the corpus and thus have very low probabilities. Such n-grams can be removed by pruning.
Pruning requires count thresholds, which are passed with the --prune parameter followed by space-separated thresholds, one per n-gram order, while generating the ARPA file:
--prune 0 1 7 9
All n-grams with a frequency less than or equal to the specified threshold for their order are eliminated.
Here, 2-grams with frequency <= 1, 3-grams with frequency <= 7, and 4-grams with frequency <= 9 are eliminated.
There is a tradeoff between the degree of pruning and accuracy: higher pruning thresholds reduce the size of the language model, but at the cost of model accuracy.
Note: Pruning of 1-grams is not supported; the threshold for 1-grams must always be 0.
!kenlm/build/bin/lmplz -o 4 --prune 0 1 7 9 < $DATA_DOWNLOAD_DIR/reduced_training.txt > $RESULTS_DIR/pruned_lm.arpa
!$KENLM_BASE/build_binary trie -q 8 -b 7 -a 256 $RESULTS_DIR/pruned_lm.arpa $RESULTS_DIR/pruned_lm.bin
!python NeMo/scripts/asr_language_modeling/ngram_lm/create_lexicon_from_arpa.py --arpa $RESULTS_DIR/pruned_lm.arpa --model $MODELS_DIR/speechtotext_en_us_conformer_vtrainable_v3.1/Conformer-CTC-L_spe-1024_en-US_Riva-ASR-SET-3.1.nemo --lower --dst $RESULTS_DIR
# Let's check the size of the original LM and the pruned LM
!echo "Size of unpruned ARPA: $(du -h $RESULTS_DIR/base_lm.arpa | cut -f 1)"
!echo "Size of pruned ARPA: $(du -h $RESULTS_DIR/pruned_lm.arpa | cut -f 1)"
Evaluate#
decoding_cfg.beam.kenlm_path = f'{RESULTS_DIR}/pruned_lm.bin'
decoding_cfg.beam.flashlight_cfg.lexicon_path = f'{RESULTS_DIR}/base_lm.lexicon'
asr_model.change_decoding_strategy(decoding_cfg)
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_pruned_lm.json")
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_pruned_lm.json")
Calculate WER#
print( "WER of base model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_base_lm.json"))
print( "WER of Pruned base model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_pruned_lm.json"))
print("WER of base model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_base_lm.json"))
print("WER of Pruned base model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_pruned_lm.json"))