Text to Speech Finetuning using NeMo#

NeMo Toolkit is a Python-based AI toolkit for training and customizing purpose-built pre-trained AI models with your own data.

Transfer learning transfers the features learned by an existing neural network to a new one. It is often used when creating a large training dataset is not feasible.

Developers, researchers, and software partners building intelligent AI apps and services can bring their own data to fine-tune pre-trained models instead of going through the hassle of training from scratch.

Let’s see this in action with a use case for Speech Synthesis!

Text to Speech#

Text to Speech (TTS) is often the last step in building a Conversational AI model. A TTS model converts text into audible speech. The main objective is to synthesize reasonable and natural speech for a given text. Since there is no universal standard to measure the quality of synthesized speech, you will need to listen to some inferred speech to tell whether a TTS model is well trained.

In this tutorial, we will look at two models: FastPitch for spectrogram generation and HiFiGAN as the vocoder.


Let’s Dig in: TTS using NeMo#

This notebook assumes that you are already familiar with TTS Training using NeMo, as described in the text-to-speech-training notebook, and that you have a pretrained TTS model.

After installing NeMo, the next step is to set up the paths to save data and results. NeMo can be used with Docker containers or virtual environments.

Replace each FIXME placeholder with the required path, written as a string enclosed in quotes.

IMPORTANT NOTE: Here, we map directories in which we save the data, specs, results and cache. You should configure it for your specific case so these directories are correctly visible to the docker container. Make sure this tutorial is in the NeMo folder.

Installation of packages and importing of files#

We will first install all necessary packages.

! pip install "numba>=0.53"
! pip install librosa
! pip install soundfile
! pip install tqdm

Install the following packages only if you want to export your models to .riva format; otherwise, you can skip this step. We will now install NeMo and nemo2riva. nemo2riva is available on NGC. Make sure you install the NGC CLI before running the following commands.

!pip install nvidia-pyindex
!pip install "nemo_toolkit[all]"
!ngc registry resource download-version "nvidia/riva/riva_quickstart:2.8.1"
!pip install "riva_quickstart_v2.8.1/nemo2riva-2.8.1-py3-none-any.whl"
!pip install protobuf==3.20.0
# Installing pynini separately
!wget https://raw.githubusercontent.com/NVIDIA/NeMo/main/nemo_text_processing/install_pynini.sh \
&& bash install_pynini.sh

We will now download all the relevant files from NeMo.

! wget https://raw.githubusercontent.com/NVIDIA/NeMo/main/scripts/dataset_processing/tts/ljspeech/get_data.py
    
! wget https://raw.githubusercontent.com/NVIDIA/NeMo/main/scripts/dataset_processing/tts/extract_sup_data.py
! mkdir -p ljspeech && cd ljspeech \
&& wget https://raw.githubusercontent.com/NVIDIA/NeMo/main/scripts/dataset_processing/tts/ljspeech/ds_conf/ds_for_fastpitch_align.yaml \
&& cd ..
    
# additional files
!wget https://raw.githubusercontent.com/nvidia/NeMo/main/examples/tts/fastpitch_finetune.py

!mkdir -p conf \
&& cd conf \
&& wget https://raw.githubusercontent.com/nvidia/NeMo/main/examples/tts/conf/fastpitch_align_v1.05.yaml \
&& cd ..

!mkdir -p tts_dataset_files && cd tts_dataset_files \
&& wget https://raw.githubusercontent.com/NVIDIA/NeMo/main/scripts/tts_dataset_files/cmudict-0.7b_nv22.10 \
&& wget https://raw.githubusercontent.com/NVIDIA/NeMo/main/scripts/tts_dataset_files/heteronyms-052722 \
&& wget https://raw.githubusercontent.com/NVIDIA/NeMo/main/nemo_text_processing/text_normalization/en/data/whitelist/tts.tsv \
&& cd ..
            
! wget https://raw.githubusercontent.com/NVIDIA/NeMo/main/scripts/dataset_processing/tts/generate_mels.py
    
! wget https://raw.githubusercontent.com/nvidia/NeMo/main/examples/tts/hifigan_finetune.py

Set Relevant Paths#

# NOTE: The following paths are set from the perspective of the NeMo Docker.

import os
from pathlib import Path

# The data is saved here
DATA_DIR = FIXME
RESULTS_DIR = FIXME

! mkdir -p {DATA_DIR}
! mkdir -p {RESULTS_DIR}

os.environ["DATA_DIR"] = DATA_DIR
os.environ["RESULTS_DIR"] = RESULTS_DIR

Data#

For the rest of this notebook, it is assumed that you have:

  • Pretrained FastPitch and HiFiGAN models that were trained on LJSpeech sampled at 22kHz

If you are not using a TTS model trained on LJSpeech at the correct sampling rate, please ensure that you have the original data, including wav files and a .json manifest file. If you have a TTS model but not at 22kHz, please ensure that you set the correct sampling rate and FFT parameters.

For the rest of this notebook, we will be using a subset of audio samples from the Hi-Fi TTS dataset adding up to about one minute. This dataset is for demo purposes only. For a good quality model, we recommend at least 30 minutes of audio. If you want to record your own dataset, you can follow the Guidelines to Record a TTS Dataset at Home. Sample scripts to download and preprocess datasets supported by NeMo can be found here.

Let’s first download and pre-process the original LJSpeech dataset and set variables that point to this as the original data’s .json file.

Pre-Processing#

This step downloads audio-to-text file lists from NVIDIA for LJSpeech and generates the manifest files. If you use your own dataset, you have to generate three files yourself: ljs_audio_text_train_manifest.json, ljs_audio_text_val_manifest.json, and ljs_audio_text_test_manifest.json. Those files correspond to your train / val / test split. For each file, the number of rows should equal the number of samples in that split, and each row for a single-speaker dataset should look like:

{"audio_filepath": "path_to_audio_file", "text": "text_of_the_audio", "duration": duration_of_the_audio}

For a multi-speaker dataset, each row should also include the speaker ID:

{"audio_filepath": "path_to_audio_file", "text": "text_of_the_audio", "duration": duration_of_the_audio, "speaker": speaker_id}

An example row is:

{"audio_filepath": "actressinhighlife_01_bowen_0001.flac", "text": "the pleasant season did my heart employ", "duration": 2.4}
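Before training, it can help to check that every manifest row follows the format above. The following is a minimal sketch rather than a NeMo utility: the helper name validate_manifest is hypothetical, and it only checks the keys listed in this section.

```python
import json

REQUIRED_KEYS = ("audio_filepath", "text", "duration")


def validate_manifest(manifest_path, multi_speaker=False):
    """Return a list of (line_number, problem) tuples for a JSON-lines manifest."""
    required = REQUIRED_KEYS + (("speaker",) if multi_speaker else ())
    problems = []
    with open(manifest_path, encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            try:
                row = json.loads(line)
            except json.JSONDecodeError as err:
                problems.append((line_number, f"invalid JSON: {err.msg}"))
                continue
            if not isinstance(row, dict):
                problems.append((line_number, "row is not a JSON object"))
                continue
            missing = [key for key in required if key not in row]
            if missing:
                problems.append((line_number, f"missing keys: {missing}"))
            elif not isinstance(row["duration"], (int, float)) or row["duration"] <= 0:
                problems.append((line_number, "duration must be a positive number"))
    return problems
```

An empty list means that every row has the expected keys. Pass multi_speaker=True to also require the speaker key.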

We will now download the audio and the manifest files, convert them to the above format, and normalize the text. These steps for LJSpeech can be found in the NeMo script scripts/dataset_processing/tts/ljspeech/get_data.py. Be patient; this step is expected to take some time.

! python get_data.py --data-root {DATA_DIR}

import os

original_data_json = os.path.join(os.environ["DATA_DIR"], "LJSpeech-1.1/train_manifest.json")
os.environ["original_data_json"] = original_data_json

Let’s now download the Hi-Fi TTS audio samples, and place the data in the DATA_DIR. Create a manifest file named manifest.json and copy the contents of both dev.json and train.json into it.

import os

# Name of the untarred Hi-Fi TTS audio samples directory.
finetune_data_name = FIXME
# Absolute path of finetuning dataset from the perspective of NeMo container
finetune_data_path = os.path.join(os.environ["DATA_DIR"], finetune_data_name)

os.environ["finetune_data_name"] = finetune_data_name
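The manifest.json file described above can be assembled by hand or with a short script. The sketch below is one possible way to do it; the helper name concatenate_manifests is hypothetical, and it assumes that dev.json and train.json are JSON-lines manifests located in the finetuning data directory.

```python
import os


def concatenate_manifests(input_paths, output_path):
    """Concatenate JSON-lines manifests into one file; return the number of rows written."""
    rows_written = 0
    with open(output_path, "w", encoding="utf-8") as out_file:
        for path in input_paths:
            with open(path, encoding="utf-8") as in_file:
                for line in in_file:
                    line = line.strip()
                    if line:  # skip blank lines and normalize line endings
                        out_file.write(line + "\n")
                        rows_written += 1
    return rows_written


# Example usage (assumes dev.json and train.json sit in finetune_data_path):
# concatenate_manifests(
#     [os.path.join(finetune_data_path, "dev.json"),
#      os.path.join(finetune_data_path, "train.json")],
#     os.path.join(finetune_data_path, "manifest.json"),
# )
```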

Now that you have downloaded the data, let’s make sure that the audio clips are sampled at the same sampling frequency as the clips used to train the pretrained model. For the course of this notebook, NVIDIA recommends using a model trained on the LJSpeech dataset. The sampling rate for this model is 22.05kHz.

import soundfile
import librosa
import json
import os

def resample_audio(input_file_path, output_path, target_sampling_rate=22050):
    """Resample a single audio file.
    
    Args:
        input_file_path (str): Path to the input audio file.
        output_path (str): Path to the output audio file.
        target_sampling_rate (int): Sampling rate for output audio file.
        
    Returns:
        str or None: File name of the resampled clip, or None if the clip is empty.
    """
    if not input_file_path.endswith(".wav"):
        raise NotImplementedError("Loading only implemented for wav files.")
    if not os.path.exists(input_file_path):
        raise FileNotFoundError(f"Cannot find input file at {input_file_path}")
    audio, sampling_rate = librosa.load(
      input_file_path,
      sr=target_sampling_rate
    )
    # Filtering out empty audio files.
    if librosa.get_duration(y=audio, sr=sampling_rate) == 0:
        print(f"0 duration audio file encountered at {input_file_path}")
        return None
    filename = os.path.basename(input_file_path)
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    soundfile.write(
        os.path.join(output_path, filename),
        audio,
        samplerate=target_sampling_rate,
        format="wav"
    )
    return filename

from tqdm.notebook import tqdm

relative_path = f"{finetune_data_name}/clips_resampled"
resampled_manifest_file = os.path.join(
    os.environ["DATA_DIR"],
    f"{finetune_data_name}/manifest_resampled.json"
)
input_manifest_file = os.path.join(
    os.environ["DATA_DIR"],
    f"{finetune_data_name}/manifest.json"
)
sampling_rate = 22050
output_path = os.path.join(os.environ["DATA_DIR"], relative_path)

# Resampling the audio clip.
with open(input_manifest_file, "r") as finetune_file:
    with open(resampled_manifest_file, "w") as resampled_file:
        for line in tqdm(finetune_file.readlines()):
            data = json.loads(line)
            filename = resample_audio(
                os.path.join(
                    os.environ["DATA_DIR"],
                    finetune_data_name,
                    data["audio_filepath"]
                ),
                output_path,
                target_sampling_rate=sampling_rate
            )
            if not filename:
                print(f"Skipping clip {data['audio_filepath']} from training dataset")
                continue
            data["audio_filepath"] = os.path.join(
                os.environ["DATA_DIR"],
                relative_path, filename
            )
            resampled_file.write(f"{json.dumps(data)}\n")

assert resampled_file.closed, "Output file wasn't closed properly"
assert finetune_file.closed, "Input file wasn't closed properly"

# Splitting the dataset to train and val set.
! cat $finetune_data_path/manifest_resampled.json | tail -n 2 > $finetune_data_path/manifest_val.json
! cat $finetune_data_path/manifest_resampled.json | head -n -2 > $finetune_data_path/manifest_train.json

from pathlib import Path

finetune_data_json = os.path.join(os.environ["DATA_DIR"], f'{finetune_data_name}/manifest_train.json')
os.environ["finetune_data_json"] = finetune_data_json
os.environ["finetune_val_data_json"] = os.path.join(os.environ["DATA_DIR"], f'{finetune_data_name}/manifest_val.json')
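After resampling, it is worth confirming that the clips really are at 22050 Hz before training on them. The sketch below reads the rate from each WAV header using only the standard library. It assumes the clips are PCM-encoded WAV files, which the wave module can read; the helper names are hypothetical.

```python
import wave


def wav_sample_rate(path):
    """Read the sampling rate from a PCM WAV file header."""
    with wave.open(path, "rb") as wav_file:
        return wav_file.getframerate()


def find_mismatched_clips(paths, expected_rate=22050):
    """Return (path, rate) pairs for clips whose rate differs from expected_rate."""
    mismatched = []
    for path in paths:
        rate = wav_sample_rate(path)
        if rate != expected_rate:
            mismatched.append((path, rate))
    return mismatched
```

The paths can be collected from the audio_filepath field of each row in manifest_resampled.json. An empty result means that every clip matches the expected rate.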

The first step is to create a json that contains data from both the original data and the finetuning data. Since the original data is much larger than the finetuning data, we merge the finetuning data with a sample of the original data. We can do this using the following:

import random
import json

def json_reader(filename):
    with open(filename) as f:
        for line in f:
            yield json.loads(line)
            
            
def json_writer(file, json_objects):
    with open(file, "w") as f:
        for jsonobj in json_objects:
            jsonstr = json.dumps(jsonobj)
            f.write(jsonstr + "\n")
            
            
def dataset_merge(original_manifest, finetune_manifest, num_records_original=50):
    original_ds = list(json_reader(original_manifest))
    finetune_ds = list(json_reader(finetune_manifest))
    original_ds = random.sample(original_ds, num_records_original)
    merged_ds = original_ds + finetune_ds
    random.shuffle(merged_ds)
    return merged_ds

merged_ds = dataset_merge(os.environ["original_data_json"], os.environ["finetune_data_json"])

os.environ["merged_data_json"] = f"{DATA_DIR}/{finetune_data_name}/merged_train.json"
json_writer(os.environ["merged_data_json"], merged_ds)
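The merge above draws a different random sample on every run, and random.sample raises an error if the original manifest has fewer rows than requested. A variant that is repeatable and tolerant of small manifests could look like the following sketch; the function name sample_and_merge and the seed parameter are assumptions for illustration, not part of the notebook.

```python
import random


def sample_and_merge(original_rows, finetune_rows, num_records_original=50, seed=1234):
    """Mix a random sample of original rows with all finetuning rows."""
    rng = random.Random(seed)  # local RNG so the global random state is untouched
    sample_size = min(num_records_original, len(original_rows))
    merged = rng.sample(original_rows, sample_size) + list(finetune_rows)
    rng.shuffle(merged)
    return merged
```

With a fixed seed, the same subset of the original data is selected each time, which makes finetuning runs easier to compare.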

Getting Pitch Statistics#

Training FastPitch requires you to set two values for pitch extraction:

  • avg: The average used to normalize the pitch

  • std: The std deviation used to normalize the pitch
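To make the role of these two values concrete, here is a small worked sketch of pitch standardization. It is an illustration, not the NeMo implementation: it assumes that unvoiced frames are marked with a pitch of 0 and are excluded from the statistics, and it uses the population standard deviation.

```python
import statistics


def pitch_stats(f0_values):
    """Mean and standard deviation over voiced frames (f0 > 0)."""
    voiced = [f0 for f0 in f0_values if f0 > 0]
    return statistics.fmean(voiced), statistics.pstdev(voiced)


def normalize_pitch(f0_values, mean, std):
    """Standardize voiced frames; keep unvoiced frames (0) at 0."""
    return [(f0 - mean) / std if f0 > 0 else 0.0 for f0 in f0_values]


# Toy contour in Hz: two unvoiced frames and two voiced frames.
contour = [0.0, 100.0, 200.0, 0.0]
mean, std = pitch_stats(contour)
normalized = normalize_pitch(contour, mean, std)
```

For this toy contour, the voiced frames have a mean of 150 Hz and a standard deviation of 50 Hz, so they map to -1 and 1 after normalization.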

We can compute pitch for the training data using scripts/dataset_processing/tts/extract_sup_data.py and extract pitch statistics using the NeMo script scripts/dataset_processing/tts/compute_speaker_stats.py. We downloaded extract_sup_data.py earlier in the tutorial, so let’s use it to get pitch_mean and pitch_std.

First, we will extract the pitch supplementary data using the extract_sup_data.py file. This file works with a YAML config file, ds_for_fastpitch_align.yaml, which we downloaded above. To make this work for your dataset, simply change manifest_filepath to your manifest path. The argument sup_data_path determines where the supplementary data is stored.

sup_data_path = f'{finetune_data_path}/sup_data_path'
pitch_stats_path = f'{finetune_data_path}/pitch_stats.json'

# The script extract_sup_data.py prints the pitch mean and pitch std to the command line output. We parse that output to get both values.
cmd_str_list = !python extract_sup_data.py --config-path "ljspeech" manifest_filepath={os.environ["merged_data_json"]} sup_data_path={sup_data_path}
# Select only the line that contains PITCH_MEAN
cmd_str = [c for c in cmd_str_list if "PITCH_MEAN" in c][0]
# Extract pitch mean and std from the commandline
pitch_mean_str = cmd_str.split(',')[0]
pitch_mean = float(pitch_mean_str.split('=')[1])
pitch_std_str = cmd_str.split(',')[1]
pitch_std = float(pitch_std_str.split('=')[1])
pitch_mean, pitch_std
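The parsing above fails with an IndexError if no output line contains PITCH_MEAN. A slightly more defensive variant is sketched below. It assumes the output line has the form PITCH_MEAN=<number>, PITCH_STD=<number>, which is inferred from the parsing code above rather than from the script documentation; the function name parse_pitch_stats is hypothetical.

```python
import re

_FLOAT = r"[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?"
_PITCH_RE = re.compile(rf"PITCH_MEAN=({_FLOAT}),\s*PITCH_STD=({_FLOAT})")


def parse_pitch_stats(output_lines):
    """Return (pitch_mean, pitch_std) from the first matching line of script output."""
    for line in output_lines:
        match = _PITCH_RE.search(line)
        if match:
            return float(match.group(1)), float(match.group(2))
    raise ValueError("No line containing PITCH_MEAN=..., PITCH_STD=... was found")
```

In the notebook, this could be called as parse_pitch_stats(cmd_str_list). If the script failed, the error message points at the missing statistics instead of a list index.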

Set pitch_mean and pitch_std based on the results from the cell above.

os.environ["pitch_mean"] = str(pitch_mean)
os.environ["pitch_std"] = str(pitch_std)

print(f"pitch mean: {pitch_mean}")
print(f"pitch std: {pitch_std}")

Finetuning#

We are now ready to finetune our TTS pipeline. In order to do so, you need to finetune FastPitch. For best results, you need to finetune HiFiGAN as well.

Here we are using the pretrained FastPitch and HiFiGAN checkpoints from NGC.

Finetuning FastPitch#

We will need some additional files from NeMo to run finetuning on FastPitch; we downloaded them earlier in the tutorial. In NeMo, you can find the fastpitch_finetune.py script and the config in the examples section.

!(python fastpitch_finetune.py --config-name=fastpitch_align_v1.05.yaml \
  train_dataset={os.environ["merged_data_json"]} \
  validation_datasets={os.environ["finetune_val_data_json"]} \
  sup_data_path={sup_data_path} \
  phoneme_dict_path=tts_dataset_files/cmudict-0.7b_nv22.10 \
  heteronyms_path=tts_dataset_files/heteronyms-052722 \
  whitelist_path=tts_dataset_files/tts.tsv \
  exp_manager.exp_dir={os.environ["RESULTS_DIR"]} \
  +init_from_pretrained_model="tts_en_fastpitch" \
  +trainer.max_steps=1000 \
  ~trainer.max_epochs \
  trainer.check_val_every_n_epoch=10 \
  model.train_ds.dataloader_params.batch_size=24 \
  model.validation_ds.dataloader_params.batch_size=24 \
  model.n_speakers=1 \
  model.pitch_mean={os.environ["pitch_mean"]} model.pitch_std={os.environ["pitch_std"]} \
  model.optim.lr=2e-4 \
  ~model.optim.sched \
  model.optim.name=adam \
  trainer.devices=1 \
  trainer.strategy=null \
  +model.text_tokenizer.add_blank_at=true \
)

Let’s take a closer look at the training command:

  • --config-name=fastpitch_align_v1.05.yaml

    • We first tell the script what config file to use.

  • train_dataset={os.environ["merged_data_json"]}  validation_datasets={os.environ["finetune_val_data_json"]}  sup_data_path={sup_data_path}

    • We tell the script what manifest files to train and eval on, as well as where supplementary data is located (or will be calculated and saved during training if not provided).

  • phoneme_dict_path=tts_dataset_files/cmudict-0.7b_nv22.10  heteronyms_path=tts_dataset_files/heteronyms-052722 whitelist_path=tts_dataset_files/tts.tsv 

    • We tell the script where the phoneme dictionary, the heteronyms file, and the whitelist are located. These are the additional files we downloaded earlier, and they are used in preprocessing the data.

  • exp_manager.exp_dir={os.environ["RESULTS_DIR"]}

    • Where we want to save our log files, tensorboard file, checkpoints, and more.

  • +init_from_pretrained_model="tts_en_fastpitch"

    • We tell the script which pretrained model to finetune from. Here the model is referenced by name rather than by a path to a local .nemo file.

  • +trainer.max_steps=1000 ~trainer.max_epochs trainer.check_val_every_n_epoch=10

    • For this experiment, we tell the script to train for 1000 training steps/iterations rather than specifying a number of epochs to run. Since the config file specifies max_epochs instead, we need to remove that using ~trainer.max_epochs.

  • model.train_ds.dataloader_params.batch_size=24 model.validation_ds.dataloader_params.batch_size=24

    • Set batch sizes for the training and validation data loaders.

  • model.n_speakers=1

    • The number of speakers in the data. There is only 1 in this notebook.

  • model.pitch_mean={os.environ["pitch_mean"]} model.pitch_std={os.environ["pitch_std"]}

    • For the new speaker, we need to define new pitch hyperparameters for better audio quality.

    • These are the pitch statistics computed by extract_sup_data.py in the Getting Pitch Statistics section.

    • If you are using a custom dataset, running the script python <NeMo_base>/scripts/dataset_processing/tts/extract_sup_data.py manifest_filepath=<your_manifest_path> will precalculate supplementary data and print these pitch stats.

    • model.pitch_fmin and model.pitch_fmax can be overridden in the same way. They are hyperparameters to librosa’s pyin function. We recommend tweaking these only if the speaker is in a noisy environment, such that background noise isn’t predicted to be speech.

  • model.optim.lr=2e-4 ~model.optim.sched model.optim.name=adam

    • For fine-tuning, we lower the learning rate.

    • We use a fixed learning rate of 2e-4.

    • We switch from the lamb optimizer to the adam optimizer.

  • trainer.devices=1 trainer.strategy=null

    • For this notebook, we default to 1 GPU, which means that we do not need DDP.

    • If you have the compute resources, feel free to scale this up to the number of free gpus you have available.

    • Please remove the trainer.strategy=null section if you intend on multi-gpu training.
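The training command relies on three kinds of Hydra overrides: key=value changes a key that already exists in the config, +key=value adds a key that the config does not define, and ~key removes a key. The sketch below assembles a shortened version of the command to make that distinction explicit. The helper build_overrides is hypothetical and only a subset of the overrides is shown.

```python
import shlex


def build_overrides(set_values, add_values=None, delete_keys=None):
    """Build Hydra command-line overrides.

    set_values:  keys that already exist in the config  -> key=value
    add_values:  keys that are new to the config        -> +key=value
    delete_keys: keys to remove from the config         -> ~key
    """
    overrides = [f"{key}={value}" for key, value in set_values.items()]
    overrides += [f"+{key}={value}" for key, value in (add_values or {}).items()]
    overrides += [f"~{key}" for key in (delete_keys or [])]
    return overrides


command = ["python", "fastpitch_finetune.py", "--config-name=fastpitch_align_v1.05.yaml"]
command += build_overrides(
    set_values={"model.optim.lr": "2e-4", "model.optim.name": "adam", "trainer.devices": 1},
    add_values={"trainer.max_steps": 1000},
    delete_keys=["trainer.max_epochs", "model.optim.sched"],
)
# shlex.join quotes arguments where the shell needs it.
print(shlex.join(command))
```

Keeping the three groups separate makes it easier to see why trainer.max_steps needs a leading + while model.optim.lr does not.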

Finetuning HiFiGAN#

In order to get the best audio from HiFiGAN, we need to finetune it:

  • on the new speaker

  • using mel spectrograms from our finetuned FastPitch Model

Let’s first generate mels from our FastPitch model, and save them to a new .json manifest for use with HiFiGAN. We can generate the mels using the generate_mels.py file from NeMo.

fastpitch_checkpoint = FIXME
mel_dir = f"{finetune_data_path}/mels"
! mkdir -p {mel_dir}

!(python generate_mels.py \
  --fastpitch-model-ckpt {fastpitch_checkpoint} \
  --input-json-manifests {os.environ["merged_data_json"]} \
  --output-json-manifest-root {mel_dir} \
 )
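Once the script finishes, a quick check that every manifest row points to an existing mel file can catch problems before HiFiGAN finetuning starts. This is a sketch under an assumption: the key name mel_filepath is what we expect generate_mels.py to write, so adjust mel_key if your manifest uses a different name. The helper name find_missing_mels is hypothetical.

```python
import json
import os


def find_missing_mels(manifest_path, mel_key="mel_filepath"):
    """Return audio paths of rows whose mel file is absent or not recorded."""
    missing = []
    with open(manifest_path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            row = json.loads(line)
            mel_path = row.get(mel_key)
            if not mel_path or not os.path.exists(mel_path):
                missing.append(row.get("audio_filepath"))
    return missing
```

An empty list means that all mel files referenced by the manifest are present on disk.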

Finetuning HiFiGAN#

Now let’s finetune HiFiGAN. Finetuning HiFiGAN can be done in NeMo using the script examples/tts/hifigan_finetune.py and the configs present in examples/tts/conf/hifigan.

Create a small validation dataset for HiFiGAN finetuning

hifigan_full_ds = f"{finetune_data_path}/mels/merged_full_mel.json"
hifigan_train_ds = f"{finetune_data_path}/mels/merged_train_mel.json"
hifigan_val_ds = f"{finetune_data_path}/mels/merged_val_mel.json"

! cat {hifigan_train_ds} > {hifigan_full_ds}
! cat {hifigan_full_ds} | tail -n 2 > {hifigan_val_ds}
! cat {hifigan_full_ds} | head -n -2 > {hifigan_train_ds}
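The head -n -2 form used above is a GNU extension and may not be available on every system. A portable alternative in Python is sketched below; the function name split_manifest is hypothetical. Because it reads the whole file before writing, the input path may be the same as the training output path.

```python
def split_manifest(full_path, train_path, val_path, num_val=2):
    """Write the last num_val rows to val_path and the rest to train_path."""
    with open(full_path, encoding="utf-8") as f:
        # Drop blank lines and make sure every row ends with a newline.
        rows = [line.rstrip("\n") + "\n" for line in f if line.strip()]
    if not 0 < num_val < len(rows):
        raise ValueError(f"num_val must be between 1 and {len(rows) - 1}, got {num_val}")
    with open(train_path, "w", encoding="utf-8") as f:
        f.writelines(rows[:-num_val])
    with open(val_path, "w", encoding="utf-8") as f:
        f.writelines(rows[-num_val:])
    return len(rows) - num_val, num_val
```

For example, split_manifest(hifigan_train_ds, hifigan_train_ds, hifigan_val_ds) would keep all but the last two rows for training and move the last two to the validation manifest.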

Run the following command for HiFiGAN finetuning

!(python examples/tts/hifigan_finetune.py \
--config-name=hifigan.yaml \
model.train_ds.dataloader_params.batch_size=32 \
model.max_steps=1000 \
model.optim.lr=0.00001 \
~model.optim.sched \
train_dataset={hifigan_train_ds} \
validation_datasets={hifigan_val_ds} \
exp_manager.exp_dir={os.environ["RESULTS_DIR"]} \
+init_from_pretrained_model=tts_hifigan \
trainer.check_val_every_n_epoch=10 \
model/train_ds=train_ds_finetune \
model/validation_ds=val_ds_finetune)

TTS Inference#

As mentioned earlier, since there is no universal standard to measure the quality of synthesized speech, you will need to listen to some inferred speech to tell whether a TTS model is well trained. Therefore, NeMo Toolkit does not provide evaluate functionality for TTS, only infer functionality.

Generate spectrogram and audio#

The first step for inference is generating a spectrogram: a numpy array (saved as a .npy file) for a sentence, which a vocoder can convert to voice. We use the FastPitch model we just trained to generate the spectrogram.

Please update the hifigan_checkpoint variable with the path to the HiFiGAN checkpoint you want to use.

Let’s load the two models, FastPitch and HiFiGAN, for inference

from nemo.collections.tts.models import FastPitchModel, HifiGanModel

hifigan_checkpoint = FIXME
vocoder = HifiGanModel.load_from_checkpoint(hifigan_checkpoint)
vocoder = vocoder.eval().cuda()
spec_model = FastPitchModel.load_from_checkpoint(fastpitch_checkpoint)
spec_model.eval().cuda()

Let’s create a helper method to do inference given a string input. In case of multi-speaker inference the same method can be used by passing the speaker ID as a parameter.

import torch

def infer(spec_gen_model, vocoder_model, str_input, speaker=None):
    """
    Synthesizes spectrogram and audio from a text string given a spectrogram synthesis and vocoder model.
    
    Args:
        spec_gen_model: Spectrogram generator model (FastPitch in our case)
        vocoder_model: Vocoder model (HiFiGAN in our case)
        str_input: Text input for the synthesis
        speaker: Speaker ID
    
    Returns:
        spectrogram and waveform of the synthesized audio.
    """
    with torch.no_grad():
        parsed = spec_gen_model.parse(str_input)
        if speaker is not None:
            speaker = torch.tensor([speaker]).long().to(device=spec_gen_model.device)
        spectrogram = spec_gen_model.generate_spectrogram(tokens=parsed, speaker=speaker)
        audio = vocoder_model.convert_spectrogram_to_audio(spec=spectrogram)
        
    if spectrogram is not None:
        if isinstance(spectrogram, torch.Tensor):
            spectrogram = spectrogram.to('cpu').numpy()
        if len(spectrogram.shape) == 3:
            spectrogram = spectrogram[0]
    if isinstance(audio, torch.Tensor):
        audio = audio.to('cpu').numpy()
    return spectrogram, audio

import IPython.display as ipd
from matplotlib.pyplot import imshow
from matplotlib import pyplot as plt

# Path to test manifest file (.json)
test_records_path = FIXME
test_records = list(json_reader(test_records_path))
new_speaker_id = FIXME

for test_record in test_records:
    print("Real validation audio")
    ipd.display(ipd.Audio(test_record['audio_filepath'], rate=22050))
    duration_secs = test_record['duration']
    if 'speaker' in test_record:
        speaker_id = test_record['speaker']
    else:
        speaker_id = new_speaker_id
    print(f"SYNTHESIZED | Duration: {duration_secs} secs | Text: {test_record['text']}")
    spec, audio = infer(spec_model, vocoder, test_record['text'], speaker=speaker_id)
    ipd.display(ipd.Audio(audio, rate=22050))
    %matplotlib inline
    imshow(spec, origin="lower", aspect="auto")
    plt.show()
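To keep a synthesized clip for later comparison, the waveform can be written to disk. The sketch below uses only the standard library and is not part of the notebook: the helper name save_wav is hypothetical, and it assumes mono float samples in the range [-1.0, 1.0].

```python
import struct
import wave


def save_wav(path, samples, sample_rate=22050):
    """Write mono float samples in [-1.0, 1.0] to a 16-bit PCM WAV file."""
    frames = bytearray()
    for sample in samples:
        clipped = max(-1.0, min(1.0, float(sample)))  # guard against overshoot
        frames += struct.pack("<h", int(round(clipped * 32767)))
    with wave.open(path, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 2 bytes = 16 bits per sample
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(bytes(frames))
```

Assuming the audio array returned by infer has shape (1, num_samples), a call such as save_wav("synthesized.wav", audio[0]) would store the first (and only) waveform in the batch.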

Debug#

The data provided is only meant to be a sample to understand how finetuning works in NeMo. In order to generate better speech quality, we recommend recording at least 30 minutes of audio, and increasing the number of finetuning steps from the current 1000 to 5000 for both models (trainer.max_steps for FastPitch and model.max_steps for HiFiGAN in the commands above).
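To see how far a dataset is from the recommended 30 minutes, the durations in a manifest can be summed. This is a minimal sketch with a hypothetical helper name; it assumes the duration field is stored in seconds.

```python
import json


def total_duration_minutes(manifest_path):
    """Sum the 'duration' field (assumed to be in seconds) and return minutes."""
    total_seconds = 0.0
    with open(manifest_path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                total_seconds += float(json.loads(line)["duration"])
    return total_seconds / 60.0
```

Running it on the finetuning manifest, for example total_duration_minutes(finetune_data_json), gives the amount of new-speaker audio available for training.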

TTS model export#

You can also export your model in a format that can be deployed using NVIDIA Riva, a highly performant application framework for multi-modal conversational AI services using GPUs!

Export to RIVA#

Executing the snippets in the cells below allows you to generate a .riva model file for the spectrogram generator and vocoder models that were trained in the preceding cells. These models are required to generate a complete Text-To-Speech pipeline.

Convert to riva.#

Convert the downloaded model to .riva format. We will use the encryption key nemotoriva; change this when generating .riva models for production.

hifigan_nemo_file_path = FIXME
hifigan_riva_file_path = hifigan_nemo_file_path[:-5]+".riva"
fastpitch_nemo_file_path = FIXME
fastpitch_riva_file_path = fastpitch_nemo_file_path[:-5]+".riva"

!nemo2riva --out {fastpitch_riva_file_path} --key=nemotoriva {fastpitch_nemo_file_path}
!nemo2riva --out {hifigan_riva_file_path} --key=nemotoriva {hifigan_nemo_file_path}
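The [:-5] slicing above silently produces a wrong output name if the input path does not end in .nemo. A stricter way to derive the output path is sketched below; the helper name riva_path_for is hypothetical.

```python
from pathlib import Path


def riva_path_for(nemo_file_path):
    """Return the path with its .nemo suffix replaced by .riva."""
    path = Path(nemo_file_path)
    if path.suffix != ".nemo":
        raise ValueError(f"Expected a .nemo file, got: {nemo_file_path}")
    return str(path.with_suffix(".riva"))
```

For example, riva_path_for(fastpitch_nemo_file_path) could be used in place of the slicing expression, and a mistyped path would raise an error before nemo2riva is run.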

What’s Next ?#

You could use NeMo to build custom models for your own applications, and deploy them to NVIDIA Riva! To try deploying these models to Riva, use the tts-deploy.ipynb as a quick sample.