Punctuation and Capitalization Lexical Audio Model

Sometimes punctuation and capitalization cannot be restored from text alone. In such cases, audio can be used to improve the model's accuracy.

For example:

Oh yeah? or Oh yeah.

We need to go? or We need to go.

Yeah, they make you work. Yeah, over there you walk a lot? or Yeah, they make you work. Yeah, over there you walk a lot.

You can find more details on text-only punctuation and capitalization in the Punctuation And Capitalization page. In this document, we focus on the model changes needed to use acoustic features.

Quick Start Guide

from nemo.collections.nlp.models import PunctuationCapitalizationLexicalAudioModel

# to get the list of pre-trained models
PunctuationCapitalizationLexicalAudioModel.list_available_models()

# Download and load the pre-trained model
model = PunctuationCapitalizationLexicalAudioModel.from_pretrained("<PATH to .nemo file>")

# try the model on a few examples
model.add_punctuation_capitalization(['how are you', 'great how about you'], audio_queries=['/path/to/1.wav', '/path/to/2.wav'], target_sr=16000)

Model Description

In addition to the Punctuation and Capitalization model, this model adds an audio encoder (e.g. a Conformer encoder) and attention-based fusion of lexical and audio features. The architecture is based on Multimodal Semi-supervised Learning Framework for Punctuation Prediction in Conversational Speech [NLP-PUNCT-LEX1].
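The attention-based fusion can be pictured with a small cross-attention sketch. This is an illustration only, not the NeMo implementation; the class name, feature dimensions, and the residual-plus-LayerNorm combination are assumptions made for the example.

import torch
from torch import nn

# Illustrative sketch (not NeMo's code): lexical token features attend over
# audio-encoder frame features, and the attended audio context is added back
# to the lexical stream before the punctuation/capitalization heads.
class LexicalAudioFusion(nn.Module):
    def __init__(self, lexical_dim=768, audio_dim=512, num_heads=4):
        super().__init__()
        self.audio_proj = nn.Linear(audio_dim, lexical_dim)  # match feature sizes
        self.cross_attn = nn.MultiheadAttention(lexical_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(lexical_dim)

    def forward(self, lexical_feats, audio_feats):
        # lexical_feats: (batch, num_tokens, lexical_dim) from the text encoder
        # audio_feats:   (batch, num_frames, audio_dim) from the ASR encoder
        audio_feats = self.audio_proj(audio_feats)
        audio_context, _ = self.cross_attn(lexical_feats, audio_feats, audio_feats)
        return self.norm(lexical_feats + audio_context)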

Note

An example script on how to train and evaluate the model can be found at: NeMo/examples/nlp/token_classification/punctuation_capitalization_lexical_audio_train_evaluate.py.

The default configuration file for the model can be found at: NeMo/examples/nlp/token_classification/conf/punctuation_capitalization_lexical_audio_config.yaml.

The script for inference can be found at: NeMo/examples/nlp/token_classification/punctuate_capitalize_infer.py.

Raw Data Format

In addition to the Punctuation and Capitalization Raw Data Format, this model also requires audio data. You have to provide audio_train.txt and audio_dev.txt (and optionally audio_test.txt), each containing one valid path to an audio file per line.

Example of the audio_train.txt/audio_dev.txt file:

/path/to/1.wav
/path/to/2.wav
....

In this case source_data_dir structure should look similar to the following:

.
|--source_data_dir
  |-- dev.txt
  |-- train.txt
  |-- audio_train.txt
  |-- audio_dev.txt
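Each line of audio_train.txt is assumed to correspond to the same-numbered line of train.txt (one audio file per text example). Below is a minimal sketch, not part of NeMo, for sanity-checking that correspondence before training; the file paths are placeholders.

from pathlib import Path

def check_audio_list(text_file: str, audio_file: str) -> None:
    # One text example per line must be matched by one existing audio path per line.
    texts = Path(text_file).read_text().splitlines()
    audios = Path(audio_file).read_text().splitlines()
    assert len(texts) == len(audios), f'{len(texts)} text lines vs {len(audios)} audio paths'
    missing = [p for p in audios if not Path(p).is_file()]
    assert not missing, f'missing audio files, e.g. {missing[:3]}'

check_audio_list('source_data_dir/train.txt', 'source_data_dir/audio_train.txt')
check_audio_list('source_data_dir/dev.txt', 'source_data_dir/audio_dev.txt')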

Tarred dataset

It is recommended to use a tarred dataset when training on large amounts of data (>500 hours), since loading all audio into memory consumes a large amount of RAM and CPU.

To create a tarred dataset with audio, you will need data in NeMo format:

python examples/nlp/token_classification/data/create_punctuation_capitalization_tarred_dataset.py \
    --text <PATH/TO/LOWERCASED/TEXT/WITHOUT/PUNCTUATION> \
    --labels <PATH/TO/LABELS/IN/NEMO/FORMAT> \
    --output_dir <PATH/TO/DIRECTORY/WITH/OUTPUT/TARRED/DATASET> \
    --num_batches_per_tarfile 100 \
    --use_audio \
    --audio_file <PATH/TO/AUDIO/PATHS/FILE> \
    --sample_rate 16000

Note

You can change the sample rate to any positive integer. It is passed to the constructor of AudioSegment. It is recommended to set sample_rate to the same value as the data used to train the ASR model.
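To see what you are resampling from, you can inspect the native sample rate of a few files before picking --sample_rate. The snippet below is an illustration that uses the soundfile package for inspection only; the actual loading and resampling at dataset-creation time is done by AudioSegment.

import soundfile as sf

# Print the native sample rate of each listed audio file (paths are placeholders).
for path in ['/path/to/1.wav', '/path/to/2.wav']:
    print(path, sf.info(path).samplerate, 'Hz')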

Training Punctuation and Capitalization Model

The audio encoder is initialized with a pretrained ASR model. You can use any model from EncDecCTCModel.list_available_models() or your own checkpoint; either one should be provided in model.audio_encoder.pretrained_model. To reduce compute, you can freeze the audio encoder during training and add additional ConformerLayer modules on top of it with model.audio_encoder.freeze. You can also add Adapters to reduce compute with model.audio_encoder.adapter. Parameters of the fusion module are stored in model.audio_encoder.fusion. An example model configuration file for training can be found at: NeMo/examples/nlp/token_classification/conf/punctuation_capitalization_lexical_audio_config.yaml.
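As a hedged sketch of the model.audio_encoder section discussed above (with freezing enabled for illustration), the snippet below mirrors the default values listed in the tables in the Configs section; check the shipped punctuation_capitalization_lexical_audio_config.yaml for the authoritative keys.

from omegaconf import OmegaConf

# Illustrative audio_encoder config; keys follow the Configs tables below.
audio_encoder = OmegaConf.create({
    'pretrained_model': 'stt_en_conformer_ctc_medium',  # ASR model to take the encoder from
    'freeze': {'is_enabled': True, 'd_model': 256, 'd_ff': 1024, 'num_layers': 4},
    'adapter': {'enable': False, 'config': None},
    'fusion': {'num_layers': 4, 'num_attention_heads': 4, 'inner_size': 2048},
})
print(OmegaConf.to_yaml(audio_encoder))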

Configs

Note

This page contains only parameters specific to the lexical and audio model. Other parameters can be found in the Punctuation And Capitalization page.

Model config

A configuration of the PunctuationCapitalizationLexicalAudioModel model.

Model config

Parameter     | Data type            | Default value        | Description
--------------|----------------------|----------------------|----------------------------------------
audio_encoder | audio encoder config | audio encoder config | A configuration for the audio encoder.

Data config

Location of data configs in parent configs

Parent config | Keys in parent config
--------------|----------------------------------------------------
Run config    | model.train_ds, model.validation_ds, model.test_ds
Model config  | train_ds, validation_ds, test_ds

Parameters for regular (BertPunctuationCapitalizationDataset) dataset

Parameter      | Data type | Default value | Description
---------------|-----------|---------------|------------------------------------------------------------
use_audio      | bool      | false         | If set to true, the dataset returns audio as well as text.
audio_file     | string    | null          | A path to a file with audio paths.
sample_rate    | int       | null          | Target sample rate of audio. Can be used for upsampling or downsampling of audio.
use_bucketing  | bool      | true          | If set to true, samples are sorted by audio length and batches are assembled more efficiently (less padding per batch). If set to false, the dataset returns batches of batch_size examples instead of number_of_tokens tokens.
preload_audios | bool      | true          | If set to true, batches include waveforms; if set to false, audio_filepaths are stored instead and audio is loaded during the collate_fn call.
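For illustration, the audio-specific keys above could be set as follows. This is a hedged sketch; the remaining train_ds parameters (ds_item, text_file, labels_file, and so on) come from the base Punctuation and Capitalization config, and the file name here is a placeholder.

from omegaconf import OmegaConf

# Audio-specific train_ds overrides matching the table above.
train_ds_audio = OmegaConf.create({
    'use_audio': True,
    'audio_file': 'audio_train.txt',
    'sample_rate': 16000,
    'use_bucketing': True,
    'preload_audios': True,
})
print(OmegaConf.to_yaml(train_ds_audio))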

Audio Encoder config

Audio Encoder Config

Parameter        | Data type      | Default value               | Description
-----------------|----------------|-----------------------------|----------------------------------------------------------------
pretrained_model | string         | stt_en_conformer_ctc_medium | Pretrained model name or path to a .nemo file to take the audio encoder from.
freeze           | freeze config  | freeze config               | Configuration for freezing the audio encoder's weights.
adapter          | adapter config | adapter config              | Configuration for the adapter.
fusion           | fusion config  | fusion config               | Configuration for fusion.

Freeze Config

Parameter  | Data type | Default value | Description
-----------|-----------|---------------|------------------------------------------------------------
is_enabled | bool      | false         | If set to true, the encoder's weights are not updated during training and additional ConformerLayer layers are added.
d_model    | int       | 256           | Input dimension of MultiheadAttentionMechanism and PositionwiseFeedForward of the additional ConformerLayer layers.
d_ff       | int       | 1024          | Hidden dimension of PositionwiseFeedForward of the additional ConformerLayer layers.
num_layers | int       | 4             | Number of additional ConformerLayer layers.

Adapter Config

Parameter | Data type           | Default value | Description
----------|---------------------|---------------|------------------------------------------------------------------
enable    | bool                | false         | If set to true, adapters are enabled for the audio encoder.
config    | LinearAdapterConfig | null          | For more details, see the nemo.collections.common.parts.LinearAdapterConfig class.

Fusion Config

Parameter           | Data type | Default value | Description
--------------------|-----------|---------------|-------------------------------------
num_layers          | int       | 4             | Number of layers to use in fusion.
num_attention_heads | int       | 4             | Number of attention heads to use in fusion.
inner_size          | int       | 2048          | Fusion inner size.

Model training

For more information, refer to the Model NLP section.

To train the model from scratch, run:

python examples/nlp/token_classification/punctuation_capitalization_lexical_audio_train_evaluate.py \
       model.train_ds.ds_item=<PATH/TO/TRAIN/DATA_DIR> \
       model.train_ds.text_file=<NAME_OF_TRAIN_INPUT_TEXT_FILE> \
       model.train_ds.labels_file=<NAME_OF_TRAIN_LABELS_FILE> \
       model.validation_ds.ds_item=<PATH/TO/DEV/DATA_DIR> \
       model.validation_ds.text_file=<NAME_OF_DEV_TEXT_FILE> \
       model.validation_ds.labels_file=<NAME_OF_DEV_LABELS_FILE> \
       trainer.devices=[0,1] \
       trainer.accelerator='gpu' \
       optim.name=adam \
       optim.lr=0.0001 \
       model.train_ds.audio_file=<NAME_OF_TRAIN_AUDIO_FILE> \
       model.validation_ds.audio_file=<NAME_OF_DEV_AUDIO_FILE>

The above command starts model training on GPUs 0 and 1 with the Adam optimizer and a learning rate of 0.0001; the trained model is stored in the nemo_experiments/Punctuation_and_Capitalization folder.

To train from the pre-trained model, run:

python examples/nlp/token_classification/punctuation_capitalization_lexical_audio_train_evaluate.py \
       model.train_ds.ds_item=<PATH/TO/TRAIN/DATA_DIR> \
       model.train_ds.text_file=<NAME_OF_TRAIN_INPUT_TEXT_FILE> \
       model.train_ds.labels_file=<NAME_OF_TRAIN_LABELS_FILE> \
       model.validation_ds.ds_item=<PATH/TO/DEV/DATA/DIR> \
       model.validation_ds.text_file=<NAME_OF_DEV_TEXT_FILE> \
       model.validation_ds.labels_file=<NAME_OF_DEV_LABELS_FILE> \
       model.train_ds.audio_file=<NAME_OF_TRAIN_AUDIO_FILE> \
       model.validation_ds.audio_file=<NAME_OF_DEV_AUDIO_FILE> \
       pretrained_model=<PATH/TO/SAVE/.nemo>

Note

All parameters defined in the configuration file can be changed with command-line arguments. For example, the sample config file mentioned above sets train_ds.tokens_in_batch to 2048. However, if you see that GPU utilization can be optimized further by using a larger batch size, you may override this value by adding train_ds.tokens_in_batch=4096 to the command line. You can repeat this for any of the parameters defined in the sample configuration file.

Inference

Inference is performed by the script examples/nlp/token_classification/punctuate_capitalize_infer.py:

python punctuate_capitalize_infer.py \
    --input_manifest <PATH/TO/INPUT/MANIFEST> \
    --output_manifest <PATH/TO/OUTPUT/MANIFEST> \
    --pretrained_name <PATH to .nemo file> \
    --max_seq_length 64 \
    --margin 16 \
    --step 8 \
    --use_audio

Long audio is split just as in the text-only case, and audio sequences are treated the same as text sequences, except that max_seq_length for audio equals max_seq_length * 4000.
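If this limit is measured in audio samples (an assumption here, consistent with the 16 kHz sample rate used in the examples above), a quick back-of-the-envelope check looks like this:

# Rough size of the audio window implied by --max_seq_length 64, assuming 16 kHz audio samples.
max_seq_length = 64
audio_window = max_seq_length * 4000
print(audio_window, 'samples =', audio_window / 16000, 'seconds')  # 256000 samples = 16.0 seconds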

Model Evaluation

Model evaluation is performed by the same script as training: examples/nlp/token_classification/punctuation_capitalization_lexical_audio_train_evaluate.py.

Use the config parameter do_training=false to disable training and do_testing=true to enable testing. If both do_training and do_testing are true, the model is first trained and then tested.

To start evaluation of the pre-trained model, run:

python punctuation_capitalization_lexical_audio_train_evaluate.py \
       +model.do_training=false \
       +model.do_testing=true \
       model.test_ds.ds_item=<PATH/TO/TEST/DATA/DIR>  \
       pretrained_model=<PATH to .nemo file> \
       model.test_ds.text_file=<NAME_OF_TEST_INPUT_TEXT_FILE> \
       model.test_ds.labels_file=<NAME_OF_TEST_LABELS_FILE> \
       model.test_ds.audio_file=<NAME_OF_TEST_AUDIO_FILE>

Required Arguments

  • pretrained_model: pretrained Punctuation and Capitalization Lexical Audio model from list_available_models() or path to a .nemo file. For example: your_model.nemo.

  • model.test_ds.ds_item: path to the directory that contains model.test_ds.text_file, model.test_ds.labels_file and model.test_ds.audio_file

References

NLP-PUNCT-LEX1

Monica Sunkara, Srikanth Ronanki, Dhanush Bekal, Sravan Bodapati, and Katrin Kirchhoff. Multimodal Semi-Supervised Learning Framework for Punctuation Prediction in Conversational Speech. In Proc. Interspeech 2020, 4911–4915. 2020. doi:10.21437/Interspeech.2020-3074.