Important
NeMo 2.0 is an experimental feature and currently released in the dev container only: nvcr.io/nvidia/nemo:dev. Please refer to the Migration Guide for information on getting started.
Punctuation and Capitalization Lexical Audio Model
Sometimes punctuation and capitalization cannot be restored from text alone. In such cases, audio can be used to improve the model's accuracy, as in these examples:
Oh yeah? or Oh yeah.
We need to go? or We need to go.
Yeah, they make you work. Yeah, over there you walk a lot? or Yeah, they make you work. Yeah, over there you walk a lot.
You can find more details on text-only punctuation and capitalization in the Punctuation And Capitalization page. This document focuses on the model changes needed to use acoustic features.
Quick Start Guide
from nemo.collections.nlp.models import PunctuationCapitalizationLexicalAudioModel
# Get the list of pre-trained models
PunctuationCapitalizationLexicalAudioModel.list_available_models()
# Load the model from a local .nemo file
# (to download a pre-trained model by name, use from_pretrained("<MODEL_NAME>") instead)
model = PunctuationCapitalizationLexicalAudioModel.restore_from("<PATH to .nemo file>")
# Try the model on a few examples; the i-th audio file is the recording of the i-th text query
model.add_punctuation_capitalization(
    ['how are you', 'great how about you'],
    audio_queries=['/path/to/1.wav', '/path/to/2.wav'],
    target_sr=16000,
)
Model Description
In addition to the Punctuation And Capitalization model, this model adds an audio encoder (e.g., a Conformer encoder) and an attention-based fusion of lexical and audio features. The architecture is based on Multimodal Semi-supervised Learning Framework for Punctuation Prediction in Conversational Speech [NLP-PUNCT-LEX1].
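The attention-based fusion can be pictured as cross-attention in which each lexical (token) position queries the audio frames. The following is a minimal single-head sketch in plain Python, assuming text and audio states already share the same dimension d. It omits the learned projections, multiple heads, and stacked layers that a real fusion module would have, and it is not NeMo's implementation.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(scores: List[float]) -> List[float]:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention_fusion(text: Matrix, audio: Matrix) -> Matrix:
    """Each text position attends over all audio frames; the result is added back to the text state.

    text:  T_text x d lexical states (queries)
    audio: T_audio x d acoustic states (keys and values), same d as text (assumption)
    """
    d = len(text[0])
    scale = 1.0 / math.sqrt(d)  # scaled dot-product attention
    fused = []
    for q in text:
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in audio]
        weights = softmax(scores)
        # Weighted sum of audio frames for this token
        context = [sum(w * frame[j] for w, frame in zip(weights, audio)) for j in range(d)]
        fused.append([qj + cj for qj, cj in zip(q, context)])  # residual connection
    return fused


text_states = [[1.0, 0.0], [0.0, 1.0]]                # 2 tokens, d = 2
audio_states = [[0.5, 0.5], [1.0, -1.0], [0.0, 2.0]]  # 3 audio frames, d = 2
fused = cross_attention_fusion(text_states, audio_states)
print(len(fused), len(fused[0]))  # 2 2
```

The output keeps one vector per text token, so the punctuation and capitalization heads can still classify tokens while seeing acoustic context.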
Note
An example script for training and evaluating the model can be found at: NeMo/examples/nlp/token_classification/punctuation_capitalization_lexical_audio_train_evaluate.py.
The default configuration file for the model can be found at: NeMo/examples/nlp/token_classification/conf/punctuation_capitalization_lexical_audio_config.yaml.
The script for inference can be found at: NeMo/examples/nlp/token_classification/punctuate_capitalize_infer.py.
Raw Data Format
In addition to the data described in Punctuation And Capitalization Raw Data Format, this model also requires audio data.
You have to provide audio_train.txt and audio_dev.txt (and optionally audio_test.txt), which contain one valid audio path per row.
Example of an audio_train.txt/audio_dev.txt file:
/path/to/1.wav
/path/to/2.wav
....
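A minimal sketch for producing such a file, assuming (as the one-path-per-row format suggests) that line i of the audio file is the recording of line i of the matching text file, e.g. audio_train.txt paired with train.txt. write_audio_list is a hypothetical helper, not part of NeMo; it only checks line counts and file existence before writing.

```python
import os
from typing import List


def write_audio_list(text_path: str, audio_paths: List[str], out_path: str) -> None:
    """Write audio_paths to out_path, one per line, aligned with the lines of text_path.

    Hypothetical helper: raises if the number of audio paths differs from the
    number of text lines or if any audio file does not exist.
    """
    with open(text_path, encoding="utf-8") as f:
        n_text_lines = sum(1 for _ in f)
    if n_text_lines != len(audio_paths):
        raise ValueError(
            f"{text_path} has {n_text_lines} lines, but {len(audio_paths)} audio paths were given"
        )
    missing = [p for p in audio_paths if not os.path.isfile(p)]
    if missing:
        raise FileNotFoundError(f"audio files not found: {missing[:5]}")
    with open(out_path, "w", encoding="utf-8") as f:
        for path in audio_paths:
            f.write(path + "\n")
```

For example, write_audio_list("source_data_dir/train.txt", wav_paths, "source_data_dir/audio_train.txt") writes the training audio list only if every text line has an existing audio file.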
With audio files included, the source_data_dir structure should look similar to the following:
.
|--source_data_dir
   |-- dev.txt
   |-- train.txt
   |-- audio_train.txt
   |-- audio_dev.txt
Tarred dataset
A tarred dataset is recommended for training on large amounts of data (more than 500 hours), because loading all of the audio into memory consumes a lot of RAM and CPU.
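To judge whether your data is above the 500-hour mark, you can sum the durations of the listed files. This sketch uses Python's standard wave module, so it assumes uncompressed PCM WAV files; total_audio_hours is a hypothetical helper, not a NeMo utility.

```python
import wave


def total_audio_hours(audio_list_path: str) -> float:
    """Sum the durations of the WAV files listed (one path per line) in audio_list_path."""
    total_seconds = 0.0
    with open(audio_list_path, encoding="utf-8") as f:
        for line in f:
            path = line.strip()
            if not path:
                continue  # skip blank lines
            with wave.open(path, "rb") as w:
                # duration = number of frames / frames per second
                total_seconds += w.getnframes() / w.getframerate()
    return total_seconds / 3600.0
```

For example, total_audio_hours("audio_train.txt") > 500 suggests that a tarred dataset is worth creating.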
To create a tarred dataset with audio, you need data in NeMo format:
python examples/nlp/token_classification/data/create_punctuation_capitalization_tarred_dataset.py \
--text <PATH/TO/LOWERCASED/TEXT/WITHOUT/PUNCTUATION> \
--labels <PATH/TO/LABELS/IN/NEMO/FORMAT> \
--output_dir <PATH/TO/DIRECTORY/WITH/OUTPUT/TARRED/DATASET> \
--num_batches_per_tarfile 100 \
--use_audio \
--audio_file <PATH/TO/AUDIO/PATHS/FILE> \
--sample_rate 16000
Note
You can set the sample rate to any positive integer; it is passed to the constructor of AudioSegment. It is recommended to set sample_rate to the sample rate of the data that was used to train the ASR model.
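Before choosing sample_rate, it can help to see which sample rates your audio actually has; files at a different rate are up- or downsampled to sample_rate when loaded. A minimal sketch with the standard wave module (PCM WAV only); sample_rate_histogram is a hypothetical helper.

```python
import wave
from collections import Counter


def sample_rate_histogram(audio_list_path: str) -> Counter:
    """Count how many listed WAV files use each sample rate (one path per line)."""
    rates = Counter()
    with open(audio_list_path, encoding="utf-8") as f:
        for line in f:
            path = line.strip()
            if not path:
                continue  # skip blank lines
            with wave.open(path, "rb") as w:
                rates[w.getframerate()] += 1
    return rates
```

If most files already match the rate the ASR model was trained on, setting sample_rate to that value avoids resampling them.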
Training Punctuation and Capitalization Model
The audio encoder is initialized from a pretrained ASR model. You can use any model returned by list_available_models() of EncDecCTCModel, or your own checkpoint; specify either one in model.audio_encoder.pretrained_model.
To reduce compute, you can freeze the audio encoder during training and add additional ConformerLayer layers on top of it with model.audio_encoder.freeze. You can also add adapters to reduce compute with model.audio_encoder.adapter. Parameters of the fusion module are stored in model.audio_encoder.fusion.
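Because these options live under model.audio_encoder, they can also be set as dotted command-line overrides. The sketch below flattens a nested Python dict into Hydra-style key=value strings; to_overrides is a hypothetical helper, the values are placeholders rather than recommended settings, and the key names are taken from the Audio Encoder config tables below.

```python
from typing import Any, Dict, List


def to_overrides(cfg: Dict[str, Any], prefix: str = "") -> List[str]:
    """Flatten a nested dict into dotted overrides such as 'a.b=1'."""
    overrides = []
    for key, value in cfg.items():
        path = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            overrides.extend(to_overrides(value, path))  # recurse into sub-configs
        else:
            # Booleans are written in lowercase, as in YAML
            text = str(value).lower() if isinstance(value, bool) else str(value)
            overrides.append(f"{path}={text}")
    return overrides


# Placeholder values for illustration only
audio_encoder = {
    "pretrained_model": "<ASR_MODEL_NAME_OR_PATH>",
    "freeze": {"is_enabled": True, "num_layers": 2},
    "adapter": {"enable": False},
    "fusion": {"num_layers": 4, "num_attention_heads": 4},
}
print(" ".join(to_overrides({"model": {"audio_encoder": audio_encoder}})))
```

The printed string starts with model.audio_encoder.pretrained_model=<ASR_MODEL_NAME_OR_PATH> model.audio_encoder.freeze.is_enabled=true and can be appended to the training command.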
An example of a model configuration file for training the model can be found at:
NeMo/examples/nlp/token_classification/conf/punctuation_capitalization_lexical_audio_config.yaml.
Configs
Note
This page contains only the parameters specific to the lexical audio model. Other parameters can be found in the Punctuation And Capitalization page.
Model config
A configuration of the PunctuationCapitalizationLexicalAudioModel model.
Parameter | Data type | Default value | Description |
---|---|---|---|
audio_encoder | | | A configuration for the audio encoder. |
Data config
Parent config | Keys in parent config |
---|---|
model | train_ds, validation_ds, test_ds |
Parameter | Data type | Default value | Description |
---|---|---|---|
use_audio | bool | | If set to |
audio_file | string | | A path to a file with audio paths. |
sample_rate | int | | Target sample rate of the audio. Can be used for upsampling or downsampling of audio. |
use_bucketing | bool | | If set to True, samples are sorted by audio length and batches are assembled more efficiently (less padding in a batch). If set to False, the dataset will return |
preload_audios | bool | | If set to True, batches will include waveforms; if set to False, audio_filepaths are stored instead and the audio is loaded during |
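The effect of use_bucketing can be illustrated by counting padding: when utterances of similar length share a batch, fewer padding samples are needed. This is a simplified sketch with a fixed number of utterances per batch and hypothetical lengths; it is not NeMo's bucketing implementation.

```python
from typing import List


def padded_samples(lengths: List[int], batch_size: int) -> int:
    """Total padding when consecutive items are batched and padded to the longest in each batch."""
    pad = 0
    for i in range(0, len(lengths), batch_size):
        batch = lengths[i:i + batch_size]
        pad += sum(max(batch) - n for n in batch)
    return pad


def bucketed_order(lengths: List[int]) -> List[int]:
    """Indices of samples sorted by audio length, so similar lengths end up in the same batch."""
    return sorted(range(len(lengths)), key=lambda i: lengths[i])


audio_lengths = [16000, 160000, 32000, 144000]  # samples per utterance (hypothetical)
print(padded_samples(audio_lengths, batch_size=2))  # 256000
order = bucketed_order(audio_lengths)
print(padded_samples([audio_lengths[i] for i in order], batch_size=2))  # 32000
```

With these example lengths, sorting reduces padding from 256000 to 32000 samples.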
Audio Encoder config
Parameter | Data type | Default value | Description |
---|---|---|---|
pretrained_model | string | | Pretrained model name or path to .nemo file. |
freeze | | | Configuration for freezing the audio encoder's weights. |
adapter | | | Configuration for adapters. |
fusion | | | Configuration for fusion. |
Freeze config (model.audio_encoder.freeze)
Parameter | Data type | Default value | Description |
---|---|---|---|
is_enabled | bool | | If set to |
d_model | int | | Input dimension of |
d_ff | int | | Hidden dimension of |
num_layers | int | | Number of additional |
Adapter config (model.audio_encoder.adapter)
Parameter | Data type | Default value | Description |
---|---|---|---|
enable | bool | | If set to |
config | | | For more details, see the nemo.collections.common.parts.LinearAdapterConfig class. |
Fusion config (model.audio_encoder.fusion)
Parameter | Data type | Default value | Description |
---|---|---|---|
num_layers | int | | Number of layers to use in fusion. |
num_attention_heads | int | | Number of attention heads to use in fusion. |
inner_size | int | | Fusion inner size. |
Model training
For more information, refer to the Model NLP section.
To train the model from scratch, run:
python examples/nlp/token_classification/punctuation_capitalization_lexical_audio_train_evaluate.py \
model.train_ds.ds_item=<PATH/TO/TRAIN/DATA_DIR> \
model.train_ds.text_file=<NAME_OF_TRAIN_INPUT_TEXT_FILE> \
model.train_ds.labels_file=<NAME_OF_TRAIN_LABELS_FILE> \
model.validation_ds.ds_item=<PATH/TO/DEV/DATA_DIR> \
model.validation_ds.text_file=<NAME_OF_DEV_TEXT_FILE> \
model.validation_ds.labels_file=<NAME_OF_DEV_LABELS_FILE> \
trainer.devices=[0,1] \
trainer.accelerator='gpu' \
model.optim.name=adam \
model.optim.lr=0.0001 \
model.train_ds.audio_file=<NAME_OF_TRAIN_AUDIO_FILE> \
model.validation_ds.audio_file=<NAME_OF_DEV_AUDIO_FILE>
The above command starts model training on GPUs 0 and 1 with the Adam optimizer and a learning rate of 0.0001; the trained model is stored in the nemo_experiments/Punctuation_and_Capitalization folder.
To start training from a pre-trained model, run:
python examples/nlp/token_classification/punctuation_capitalization_lexical_audio_train_evaluate.py \
model.train_ds.ds_item=<PATH/TO/TRAIN/DATA_DIR> \
model.train_ds.text_file=<NAME_OF_TRAIN_INPUT_TEXT_FILE> \
model.train_ds.labels_file=<NAME_OF_TRAIN_LABELS_FILE> \
model.validation_ds.ds_item=<PATH/TO/DEV/DATA/DIR> \
model.validation_ds.text_file=<NAME_OF_DEV_TEXT_FILE> \
model.validation_ds.labels_file=<NAME_OF_DEV_LABELS_FILE> \
model.train_ds.audio_file=<NAME_OF_TRAIN_AUDIO_FILE> \
model.validation_ds.audio_file=<NAME_OF_DEV_AUDIO_FILE> \
pretrained_model=<PATH/TO/SAVE/.nemo>
Note
All parameters defined in the configuration file can be changed with command-line arguments. For example, the sample config file mentioned above has model.train_ds.tokens_in_batch set to 2048. However, if you see that GPU utilization can be improved with a larger batch size, you can override it by adding model.train_ds.tokens_in_batch=4096 to the command line. You can do the same with any parameter defined in the sample configuration file.
Inference
Inference is performed by the script examples/nlp/token_classification/punctuate_capitalize_infer.py:
python punctuate_capitalize_infer.py \
--input_manifest <PATH/TO/INPUT/MANIFEST> \
--output_manifest <PATH/TO/OUTPUT/MANIFEST> \
--model_path <PATH to .nemo file> \
--max_seq_length 64 \
--margin 16 \
--step 8 \
--use_audio
Long audio is split in the same way as long text in the text-only case: audio sequences are treated like text sequences, except that max_seq_length for audio equals max_seq_length * 4000.
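As a worked example, assuming the max_seq_length * 4000 limit is counted in audio samples (the unit is not stated above), the default --max_seq_length 64 allows 256000 samples per segment, which is 16 seconds at 16 kHz. max_audio_window is a hypothetical helper for this arithmetic.

```python
from typing import Tuple


def max_audio_window(max_seq_length: int, sample_rate: int = 16000) -> Tuple[int, float]:
    """Audio limit per segment, assuming the max_seq_length * 4000 limit counts samples."""
    max_samples = max_seq_length * 4000
    return max_samples, max_samples / sample_rate  # (samples, seconds)


samples, seconds = max_audio_window(64)
print(samples, seconds)  # 256000 16.0
```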
Model Evaluation
Model evaluation is performed by the same script examples/nlp/token_classification/punctuation_capitalization_lexical_audio_train_evaluate.py as training.
Use the config parameter do_training=false to disable training and do_testing=true to enable testing. If both do_training and do_testing are true, the model is trained and then tested.
To start evaluation of the pre-trained model, run:
python punctuation_capitalization_lexical_audio_train_evaluate.py \
do_training=false \
do_testing=true \
model.test_ds.ds_item=<PATH/TO/TEST/DATA/DIR> \
pretrained_model=<PATH to .nemo file> \
model.test_ds.text_file=<NAME_OF_TEST_INPUT_TEXT_FILE> \
model.test_ds.labels_file=<NAME_OF_TEST_LABELS_FILE> \
model.test_ds.audio_file=<NAME_OF_TEST_AUDIO_FILE>
Required Arguments
- pretrained_model: a pretrained Punctuation and Capitalization Lexical Audio model from list_available_models() or a path to a .nemo file, for example: your_model.nemo.
- model.test_ds.ds_item: path to the directory that contains model.test_ds.text_file, model.test_ds.labels_file, and model.test_ds.audio_file.
References
- NLP-PUNCT-LEX1
Monica Sunkara, Srikanth Ronanki, Dhanush Bekal, Sravan Bodapati, and Katrin Kirchhoff. Multimodal Semi-Supervised Learning Framework for Punctuation Prediction in Conversational Speech. In Proc. Interspeech 2020, 4911–4915. 2020. doi:10.21437/Interspeech.2020-3074.