# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Dict, List, Optional

import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer

from nemo.collections.common.losses import AggregatorLoss, CrossEntropyLoss
from nemo.collections.nlp.data.token_classification.punctuation_capitalization_dataset import (
    BertPunctuationCapitalizationDataset,
    BertPunctuationCapitalizationInferDataset,
)
from nemo.collections.nlp.metrics.classification_report import ClassificationReport
from nemo.collections.nlp.models.nlp_model import NLPModel
from nemo.collections.nlp.modules.common import TokenClassifier
from nemo.collections.nlp.modules.common.lm_utils import get_lm_model
from nemo.collections.nlp.parts.utils_funcs import tensor2list
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.classes.exportable import Exportable
from nemo.core.neural_types import LogitsType, NeuralType
from nemo.utils import logging

__all__ = ['PunctuationCapitalizationModel']


class PunctuationCapitalizationModel(NLPModel, Exportable):
@property
def input_types(self) -> Optional[Dict[str, NeuralType]]:
return self.bert_model.input_types

    @property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
return {
"punct_logits": NeuralType(('B', 'T', 'C'), LogitsType()),
"capit_logits": NeuralType(('B', 'T', 'C'), LogitsType()),
}
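    # Two independent token-classification heads sit on top of a shared BERT
    # encoder: one predicts punctuation marks, the other capitalization.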

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Initializes the BERT-based punctuation and capitalization model.

        Args:
            cfg: model configuration
            trainer: PyTorch Lightning trainer that will run training and validation
        """
self.setup_tokenizer(cfg.tokenizer)
super().__init__(cfg=cfg, trainer=trainer)
self.bert_model = get_lm_model(
pretrained_model_name=cfg.language_model.pretrained_model_name,
config_file=cfg.language_model.config_file,
config_dict=OmegaConf.to_container(cfg.language_model.config) if cfg.language_model.config else None,
checkpoint_file=cfg.language_model.lm_checkpoint,
vocab_file=cfg.tokenizer.vocab_file,
)
self.punct_classifier = TokenClassifier(
hidden_size=self.bert_model.config.hidden_size,
num_classes=len(self._cfg.punct_label_ids),
activation=cfg.punct_head.activation,
log_softmax=False,
dropout=cfg.punct_head.fc_dropout,
num_layers=cfg.punct_head.punct_num_fc_layers,
use_transformer_init=cfg.punct_head.use_transformer_init,
)
self.capit_classifier = TokenClassifier(
hidden_size=self.bert_model.config.hidden_size,
num_classes=len(self._cfg.capit_label_ids),
activation=cfg.capit_head.activation,
log_softmax=False,
dropout=cfg.capit_head.fc_dropout,
num_layers=cfg.capit_head.capit_num_fc_layers,
use_transformer_init=cfg.capit_head.use_transformer_init,
)
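        # one cross-entropy loss per head; the two losses are combined into a
        # single training objective (by default an unweighted sum)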
self.loss = CrossEntropyLoss(logits_ndim=3)
self.agg_loss = AggregatorLoss(num_inputs=2)
# setup to track metrics
self.punct_class_report = ClassificationReport(
num_classes=len(self._cfg.punct_label_ids),
label_ids=self._cfg.punct_label_ids,
mode='macro',
dist_sync_on_step=True,
)
self.capit_class_report = ClassificationReport(
num_classes=len(self._cfg.capit_label_ids),
label_ids=self._cfg.capit_label_ids,
mode='macro',
dist_sync_on_step=True,
)

    @typecheck()
def forward(self, input_ids, attention_mask, token_type_ids=None):
"""
No special modification required for Lightning, define it as you normally would
in the `nn.Module` in vanilla PyTorch.
"""
hidden_states = self.bert_model(
input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
)
punct_logits = self.punct_classifier(hidden_states=hidden_states)
capit_logits = self.capit_classifier(hidden_states=hidden_states)
return punct_logits, capit_logits
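    # A minimal sketch of calling the model directly (shapes and names are
    # illustrative; the number of classes comes from the label maps in the config):
    #
    #     input_ids = torch.randint(0, model.tokenizer.vocab_size, (2, 16))
    #     punct_logits, capit_logits = model(
    #         input_ids=input_ids,
    #         token_type_ids=torch.zeros_like(input_ids),
    #         attention_mask=torch.ones_like(input_ids),
    #     )
    #     # punct_logits: [2, 16, len(punct_label_ids)]
    #     # capit_logits: [2, 16, len(capit_label_ids)]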

    def _make_step(self, batch):
input_ids, input_type_ids, input_mask, subtokens_mask, loss_mask, punct_labels, capit_labels = batch
punct_logits, capit_logits = self(
input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
)
punct_loss = self.loss(logits=punct_logits, labels=punct_labels, loss_mask=loss_mask)
capit_loss = self.loss(logits=capit_logits, labels=capit_labels, loss_mask=loss_mask)
loss = self.agg_loss(loss_1=punct_loss, loss_2=capit_loss)
return loss, punct_logits, capit_logits

    def training_step(self, batch, batch_idx):
"""
Lightning calls this inside the training loop with the data from the training dataloader
passed in as `batch`.
"""
loss, _, _ = self._make_step(batch)
lr = self._optimizer.param_groups[0]['lr']
self.log('lr', lr, prog_bar=True)
self.log('train_loss', loss)
return {'loss': loss, 'lr': lr}

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
"""
Lightning calls this inside the validation loop with the data from the validation dataloader
passed in as `batch`.
"""
_, _, _, subtokens_mask, _, punct_labels, capit_labels = batch
val_loss, punct_logits, capit_logits = self._make_step(batch)
subtokens_mask = subtokens_mask > 0.5
        punct_preds = torch.argmax(punct_logits, dim=-1)[subtokens_mask]
        punct_labels = punct_labels[subtokens_mask]
        self.punct_class_report.update(punct_preds, punct_labels)
        capit_preds = torch.argmax(capit_logits, dim=-1)[subtokens_mask]
capit_labels = capit_labels[subtokens_mask]
self.capit_class_report.update(capit_preds, capit_labels)
return {
'val_loss': val_loss,
'punct_tp': self.punct_class_report.tp,
'punct_fn': self.punct_class_report.fn,
'punct_fp': self.punct_class_report.fp,
'capit_tp': self.capit_class_report.tp,
'capit_fn': self.capit_class_report.fn,
'capit_fp': self.capit_class_report.fp,
}

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Lightning calls this inside the test loop with the data from the test dataloader
        passed in as `batch`.
        """
_, _, _, subtokens_mask, _, punct_labels, capit_labels = batch
test_loss, punct_logits, capit_logits = self._make_step(batch)
subtokens_mask = subtokens_mask > 0.5
        punct_preds = torch.argmax(punct_logits, dim=-1)[subtokens_mask]
        punct_labels = punct_labels[subtokens_mask]
        self.punct_class_report.update(punct_preds, punct_labels)
        capit_preds = torch.argmax(capit_logits, dim=-1)[subtokens_mask]
capit_labels = capit_labels[subtokens_mask]
self.capit_class_report.update(capit_preds, capit_labels)
return {
'test_loss': test_loss,
'punct_tp': self.punct_class_report.tp,
'punct_fn': self.punct_class_report.fn,
'punct_fp': self.punct_class_report.fp,
'capit_tp': self.capit_class_report.tp,
'capit_fn': self.capit_class_report.fn,
'capit_fp': self.capit_class_report.fp,
}

    def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
"""
Called at the end of validation to aggregate outputs.
outputs: list of individual outputs of each validation step.
"""
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
# calculate metrics and log classification report for Punctuation task
punct_precision, punct_recall, punct_f1, punct_report = self.punct_class_report.compute()
logging.info(f'Punctuation report: {punct_report}')
# calculate metrics and log classification report for Capitalization task
capit_precision, capit_recall, capit_f1, capit_report = self.capit_class_report.compute()
logging.info(f'Capitalization report: {capit_report}')
self.log('val_loss', avg_loss, prog_bar=True)
self.log('punct_precision', punct_precision)
self.log('punct_f1', punct_f1)
self.log('punct_recall', punct_recall)
self.log('capit_precision', capit_precision)
self.log('capit_f1', capit_f1)
self.log('capit_recall', capit_recall)

    def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
        """
        Called at the end of test to aggregate outputs.
        outputs: list of individual outputs of each test step.
        """
avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
# calculate metrics and log classification report for Punctuation task
punct_precision, punct_recall, punct_f1, punct_report = self.punct_class_report.compute()
logging.info(f'Punctuation report: {punct_report}')
# calculate metrics and log classification report for Capitalization task
capit_precision, capit_recall, capit_f1, capit_report = self.capit_class_report.compute()
logging.info(f'Capitalization report: {capit_report}')
self.log('test_loss', avg_loss, prog_bar=True)
self.log('punct_precision', punct_precision)
self.log('punct_f1', punct_f1)
self.log('punct_recall', punct_recall)
self.log('capit_precision', capit_precision)
self.log('capit_f1', capit_f1)
self.log('capit_recall', capit_recall)

    def update_data_dir(self, data_dir: str) -> None:
"""
Update data directory
Args:
data_dir: path to data directory
"""
if os.path.exists(data_dir):
logging.info(f'Setting model.dataset.data_dir to {data_dir}.')
self._cfg.dataset.data_dir = data_dir
else:
raise ValueError(f'{data_dir} not found')

    def setup_training_data(self, train_data_config: Optional[DictConfig] = None):
"""Setup training data"""
if train_data_config is None:
train_data_config = self._cfg.train_ds
        # for compatibility with older (pre-1.0.0.b3) configs
        if not hasattr(self._cfg, "class_labels") or self._cfg.class_labels is None:
            OmegaConf.set_struct(self._cfg, False)
            self._cfg.class_labels = OmegaConf.create(
                {'punct_labels_file': 'punct_label_ids.csv', 'capit_labels_file': 'capit_label_ids.csv'}
            )
self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config)
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
self.register_artifact(
self._cfg.class_labels.punct_labels_file, self._train_dl.dataset.punct_label_ids_file
)
self.register_artifact(
self._cfg.class_labels.capit_labels_file, self._train_dl.dataset.capit_label_ids_file
)
# save label maps to the config
self._cfg.punct_label_ids = OmegaConf.create(self._train_dl.dataset.punct_label_ids)
self._cfg.capit_label_ids = OmegaConf.create(self._train_dl.dataset.capit_label_ids)

    def setup_validation_data(self, val_data_config: Optional[Dict] = None):
        """
        Setup validation data

        val_data_config: validation data config
        """
if val_data_config is None:
val_data_config = self._cfg.validation_ds
self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config)

    def setup_test_data(self, test_data_config: Optional[Dict] = None):
        """Setup test data"""
if test_data_config is None:
test_data_config = self._cfg.test_ds
self._test_dl = self._setup_dataloader_from_config(cfg=test_data_config)

    def _setup_dataloader_from_config(self, cfg: DictConfig):
        # use the data_dir specified in ds_item to run evaluation on multiple datasets
if 'ds_item' in cfg and cfg.ds_item is not None:
data_dir = cfg.ds_item
else:
data_dir = self._cfg.dataset.data_dir
text_file = os.path.join(data_dir, cfg.text_file)
label_file = os.path.join(data_dir, cfg.labels_file)
dataset = BertPunctuationCapitalizationDataset(
tokenizer=self.tokenizer,
text_file=text_file,
label_file=label_file,
pad_label=self._cfg.dataset.pad_label,
punct_label_ids=self._cfg.punct_label_ids,
capit_label_ids=self._cfg.capit_label_ids,
max_seq_length=self._cfg.dataset.max_seq_length,
ignore_extra_tokens=self._cfg.dataset.ignore_extra_tokens,
ignore_start_end=self._cfg.dataset.ignore_start_end,
use_cache=self._cfg.dataset.use_cache,
num_samples=cfg.num_samples,
punct_label_ids_file=self._cfg.class_labels.punct_labels_file
if 'class_labels' in self._cfg
else 'punct_label_ids.csv',
capit_label_ids_file=self._cfg.class_labels.capit_labels_file
if 'class_labels' in self._cfg
else 'capit_label_ids.csv',
)
return torch.utils.data.DataLoader(
dataset=dataset,
collate_fn=dataset.collate_fn,
batch_size=cfg.batch_size,
shuffle=cfg.shuffle,
num_workers=self._cfg.dataset.num_workers,
pin_memory=self._cfg.dataset.pin_memory,
drop_last=self._cfg.dataset.drop_last,
)
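    # The config fields read above are expected to look roughly like this
    # (a sketch; file names and values are illustrative):
    #
    #     model:
    #       dataset:
    #         data_dir: ???          # set via update_data_dir() or the config
    #         pad_label: 'O'
    #         max_seq_length: 128
    #         num_workers: 2
    #         pin_memory: false
    #         drop_last: false
    #       train_ds:
    #         text_file: text_train.txt
    #         labels_file: labels_train.txt
    #         batch_size: 64
    #         shuffle: true
    #         num_samples: -1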

    def _setup_infer_dataloader(
        self, queries: List[str], batch_size: int, max_seq_length: Optional[int] = None
    ) -> 'torch.utils.data.DataLoader':
"""
Setup function for a infer data loader.
Args:
queries: lower cased text without punctuation
batch_size: batch size to use during inference
max_seq_length: maximum sequence length after tokenization
Returns:
A pytorch DataLoader.
"""
if max_seq_length is None:
max_seq_length = self._cfg.dataset.max_seq_length
dataset = BertPunctuationCapitalizationInferDataset(
tokenizer=self.tokenizer, queries=queries, max_seq_length=max_seq_length
)
return torch.utils.data.DataLoader(
dataset=dataset,
collate_fn=dataset.collate_fn,
batch_size=batch_size,
shuffle=False,
num_workers=self._cfg.dataset.num_workers,
pin_memory=self._cfg.dataset.pin_memory,
drop_last=False,
)

    def add_punctuation_capitalization(
        self, queries: List[str], batch_size: Optional[int] = None, max_seq_length: int = 512
    ) -> List[str]:
"""
Adds punctuation and capitalization to the queries. Use this method for debugging and prototyping.
Args:
queries: lower cased text without punctuation
batch_size: batch size to use during inference
max_seq_length: maximum sequence length after tokenization
Returns:
result: text with added capitalization and punctuation
"""
if queries is None or len(queries) == 0:
return []
if batch_size is None:
batch_size = len(queries)
logging.info(f'Using batch size {batch_size} for inference')
# We will store the output here
result = []
# Model's mode and device
mode = self.training
device = 'cuda' if torch.cuda.is_available() else 'cpu'
try:
# Switch model to evaluation mode
self.eval()
self = self.to(device)
infer_datalayer = self._setup_infer_dataloader(queries, batch_size, max_seq_length)
# store predictions for all queries in a single list
all_punct_preds = []
all_capit_preds = []
for batch in infer_datalayer:
input_ids, input_type_ids, input_mask, subtokens_mask = batch
punct_logits, capit_logits = self.forward(
input_ids=input_ids.to(device),
token_type_ids=input_type_ids.to(device),
attention_mask=input_mask.to(device),
)
subtokens_mask = subtokens_mask > 0.5
                punct_preds = [
                    tensor2list(p_l[subtokens_mask[i]]) for i, p_l in enumerate(torch.argmax(punct_logits, dim=-1))
                ]
                capit_preds = [
                    tensor2list(c_l[subtokens_mask[i]]) for i, c_l in enumerate(torch.argmax(capit_logits, dim=-1))
                ]
all_punct_preds.extend(punct_preds)
all_capit_preds.extend(capit_preds)
punct_ids_to_labels = {v: k for k, v in self._cfg.punct_label_ids.items()}
capit_ids_to_labels = {v: k for k, v in self._cfg.capit_label_ids.items()}
queries = [q.strip().split() for q in queries]
for i, query in enumerate(queries):
punct_preds = all_punct_preds[i]
capit_preds = all_capit_preds[i]
                if len(query) != len(punct_preds):
                    logging.warning(
                        f'Query {query} exceeds the maximum sequence length of {max_seq_length} after '
                        'tokenization. Truncating the input.'
                    )
                    # drop the end-of-phrase punctuation prediction of the truncated
                    # segment (label id 0 corresponds to the pad label, i.e. no punctuation)
                    punct_preds[-1] = 0
max_len = len(punct_preds)
query = query[:max_len]
query_with_punct_and_capit = ''
for j, word in enumerate(query):
punct_label = punct_ids_to_labels[punct_preds[j]]
capit_label = capit_ids_to_labels[capit_preds[j]]
if capit_label != self._cfg.dataset.pad_label:
word = word.capitalize()
query_with_punct_and_capit += word
if punct_label != self._cfg.dataset.pad_label:
query_with_punct_and_capit += punct_label
query_with_punct_and_capit += ' '
result.append(query_with_punct_and_capit.strip())
finally:
# set mode back to its original value
self.train(mode=mode)
return result
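    # Example usage (a sketch; the inputs and outputs below are illustrative):
    #
    #     model = PunctuationCapitalizationModel.from_pretrained("punctuation_en_bert")
    #     model.add_punctuation_capitalization(['how are you', 'great see you tomorrow'])
    #     # -> ['How are you?', 'Great, see you tomorrow.']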

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
result = []
result.append(
PretrainedModelInfo(
pretrained_model_name="punctuation_en_bert",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/punctuation_en_bert/versions/1.0.0rc1/files/punctuation_en_bert.nemo",
description="The model was trained with NeMo BERT base uncased checkpoint on a subset of data from the following sources: Tatoeba sentences, books from Project Gutenberg, Fisher transcripts.",
)
)
result.append(
PretrainedModelInfo(
pretrained_model_name="punctuation_en_distilbert",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/punctuation_en_distilbert/versions/1.0.0rc1/files/punctuation_en_distilbert.nemo",
description="The model was trained with DiltilBERT base uncased checkpoint from HuggingFace on a subset of data from the following sources: Tatoeba sentences, books from Project Gutenberg, Fisher transcripts.",
)
)
return result
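    # For instance (a sketch):
    #
    #     for model_info in PunctuationCapitalizationModel.list_available_models():
    #         print(model_info.pretrained_model_name)
    #     # punctuation_en_bert
    #     # punctuation_en_distilbert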

    @property
def input_module(self):
return self.bert_model

    @property
def output_module(self):
return self

    def _prepare_for_export(self):
self.bert_model._prepare_for_export()
self.punct_classifier._prepare_for_export()
self.capit_classifier._prepare_for_export()