Source code for nemo.collections.nlp.models.language_modeling.bert_lm_model

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Dict, List, Optional

import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

from nemo.collections.common.losses import AggregatorLoss, CrossEntropyLoss, SmoothedCrossEntropyLoss
from nemo.collections.common.metrics import Perplexity
from nemo.collections.nlp.data.language_modeling.lm_bert_dataset import (
    BertPretrainingDataset,
    BertPretrainingPreprocessedDataloader,
)
from nemo.collections.nlp.modules.common import BertPretrainingTokenClassifier, SequenceClassifier
from nemo.collections.nlp.modules.common.lm_utils import get_lm_model
from nemo.collections.nlp.modules.common.tokenizer_utils import get_tokenizer
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.classes.modelPT import ModelPT
from nemo.core.neural_types import NeuralType
from nemo.utils import logging

__all__ = ["BERTLMModel"]


class BERTLMModel(ModelPT):
    """
    BERT language model pretraining.
    """

    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        output_types_dict = {"mlm_log_probs": self.mlm_classifier.output_types["log_probs"]}
        if not self.only_mlm_loss:
            output_types_dict["nsp_logits"] = self.nsp_classifier.output_types["logits"]
        return output_types_dict

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        if cfg.tokenizer is not None:
            self._setup_tokenizer(cfg.tokenizer)
        else:
            self.tokenizer = None

        super().__init__(cfg=cfg, trainer=trainer)

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config) if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
            vocab_file=cfg.tokenizer.get('vocab_file') if cfg.tokenizer is not None else None,
        )

        self.hidden_size = self.bert_model.config.hidden_size
        self.vocab_size = self.bert_model.config.vocab_size
        self.only_mlm_loss = cfg.only_mlm_loss

        self.mlm_classifier = BertPretrainingTokenClassifier(
            hidden_size=self.hidden_size,
            num_classes=self.vocab_size,
            num_layers=cfg.num_tok_classification_layers,
            activation="gelu",
            log_softmax=True,
            use_transformer_init=True,
        )

        self.mlm_loss = SmoothedCrossEntropyLoss()

        if not self.only_mlm_loss:
            self.nsp_classifier = SequenceClassifier(
                hidden_size=self.hidden_size,
                num_classes=2,
                num_layers=cfg.num_seq_classification_layers,
                log_softmax=False,
                activation="tanh",
                use_transformer_init=True,
            )
            self.nsp_loss = CrossEntropyLoss()
            self.agg_loss = AggregatorLoss(num_inputs=2)

        # tie the weights of the MLM softmax layer and the embedding layer of the encoder
        if (
            self.mlm_classifier.mlp.last_linear_layer.weight.shape
            != self.bert_model.embeddings.word_embeddings.weight.shape
        ):
            raise ValueError("Final classification layer does not match embedding layer.")
        self.mlm_classifier.mlp.last_linear_layer.weight = self.bert_model.embeddings.word_embeddings.weight

        # setup to track metrics
        self.validation_perplexity = Perplexity(compute_on_step=False)

        self.setup_optimization(cfg.optim)
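
    # Illustrative config sketch (not part of the original module): a minimal `cfg` layout that
    # the constructor and the _setup_* helpers below read. The key names come from the attribute
    # accesses in this class; the values are placeholders, not recommended settings.
    #
    #   cfg = OmegaConf.create(
    #       {
    #           "tokenizer": None,  # or {"tokenizer_name": ..., "tokenizer_model": ..., "special_tokens": ..., "vocab_file": ...}
    #           "language_model": {
    #               "pretrained_model_name": "bert-base-uncased",
    #               "config_file": None,
    #               "config": None,
    #               "lm_checkpoint": None,
    #           },
    #           "only_mlm_loss": False,
    #           "num_tok_classification_layers": 1,
    #           "num_seq_classification_layers": 2,
    #           "optim": {"name": "adamw", "lr": 1e-4},
    #       }
    #   )
    #   model = BERTLMModel(cfg=cfg, trainer=Trainer())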

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        """
        No special modification is required for Lightning; define the forward pass as you
        normally would for an `nn.Module` in vanilla PyTorch.
        """
        hidden_states = self.bert_model(
            input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
        )
        mlm_log_probs = self.mlm_classifier(hidden_states=hidden_states)
        if self.only_mlm_loss:
            return (mlm_log_probs,)
        nsp_logits = self.nsp_classifier(hidden_states=hidden_states)
        return mlm_log_probs, nsp_logits
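
    # Shape sketch for forward (illustrative, assuming batch size B and sequence length T):
    #   hidden_states : [B, T, hidden_size]  -- encoder output from self.bert_model
    #   mlm_log_probs : [B, T, vocab_size]   -- per-token log-probabilities for masked LM
    #   nsp_logits    : [B, 2]               -- sequence-level logits for next-sentence prediction
    # With only_mlm_loss=True the NSP head is never constructed, so forward returns a 1-tuple.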

    def _compute_losses(self, mlm_log_probs, nsp_logits, output_ids, output_mask, labels):
        mlm_loss = self.mlm_loss(log_probs=mlm_log_probs, labels=output_ids, output_mask=output_mask)
        if self.only_mlm_loss:
            loss, nsp_loss = mlm_loss, None
        else:
            nsp_loss = self.nsp_loss(logits=nsp_logits, labels=labels)
            loss = self.agg_loss(loss_1=mlm_loss, loss_2=nsp_loss)
        return mlm_loss, nsp_loss, loss

    def _parse_forward_outputs(self, forward_outputs):
        if self.only_mlm_loss:
            return forward_outputs[0], None
        else:
            return forward_outputs
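
    # Loss combination note (illustrative): with both objectives enabled, the total loss is the
    # AggregatorLoss over the two terms, roughly loss = mlm_loss + nsp_loss under the default
    # (equal) weighting; with only_mlm_loss=True the NSP term and its head are skipped entirely.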

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training
        dataloader passed in as `batch`.
        """
        # forward pass
        input_ids, input_type_ids, input_mask, output_ids, output_mask, labels = batch
        forward_outputs = self.forward(input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)
        mlm_log_probs, nsp_logits = self._parse_forward_outputs(forward_outputs)
        _, _, loss = self._compute_losses(mlm_log_probs, nsp_logits, output_ids, output_mask, labels)
        tensorboard_logs = {"train_loss": loss}
        return {"loss": loss, "log": tensorboard_logs}
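
    # Batch layout (as unpacked in training_step and validation_step): a 6-tuple of
    #   input_ids, input_type_ids, input_mask, output_ids, output_mask, labels
    # where output_ids/output_mask describe the masked-LM targets and labels is the
    # next-sentence-prediction target produced by the dataloaders set up below.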

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation
        dataloader passed in as `batch`.
        """
        input_ids, input_type_ids, input_mask, output_ids, output_mask, labels = batch
        forward_outputs = self.forward(input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)
        mlm_log_probs, nsp_logits = self._parse_forward_outputs(forward_outputs)
        _, _, loss = self._compute_losses(mlm_log_probs, nsp_logits, output_ids, output_mask, labels)
        self.validation_perplexity(logits=mlm_log_probs)
        tensorboard_logs = {'val_loss': loss}
        return {'val_loss': loss, 'log': tensorboard_logs}
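
    # Metric note: Perplexity(compute_on_step=False) only accumulates state during
    # validation_step; the aggregated value is read out once per epoch via
    # self.validation_perplexity.compute() in validation_epoch_end below.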

    def validation_epoch_end(self, outputs):
        """Called at the end of validation to aggregate outputs.

        Args:
            outputs (list): The individual outputs of each validation step.

        Returns:
            dict: Validation loss and tensorboard logs.
        """
        if outputs:
            avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
            perplexity = self.validation_perplexity.compute()
            tensorboard_logs = {'val_loss': avg_loss, 'perplexity': perplexity}
            logging.info(f"evaluation perplexity {perplexity.cpu().item()}")
            return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = (
            self._setup_preprocessed_dataloader(train_data_config)
            if self.tokenizer is None
            else self._setup_dataloader(train_data_config)
        )

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        self._validation_dl = (
            self._setup_preprocessed_dataloader(val_data_config)
            if self.tokenizer is None
            else self._setup_dataloader(val_data_config)
        )

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        pass

    def _setup_preprocessed_dataloader(self, cfg: Optional[DictConfig]):
        dataset = cfg.data_file
        max_predictions_per_seq = cfg.max_predictions_per_seq
        batch_size = cfg.batch_size

        if os.path.isdir(dataset):
            files = [
                os.path.join(dataset, f)
                for f in os.listdir(dataset)
                if os.path.isfile(os.path.join(dataset, f))
            ]
        else:
            files = [dataset]
        files.sort()
        dl = BertPretrainingPreprocessedDataloader(
            data_files=files, max_predictions_per_seq=max_predictions_per_seq, batch_size=batch_size,
        )
        return dl

    def _setup_tokenizer(self, cfg: DictConfig):
        tokenizer = get_tokenizer(
            tokenizer_name=cfg.tokenizer_name,
            tokenizer_model=cfg.tokenizer_model,
            special_tokens=OmegaConf.to_container(cfg.special_tokens) if cfg.special_tokens else None,
            vocab_file=cfg.vocab_file,
        )
        self.tokenizer = tokenizer

    def _setup_dataloader(self, cfg: DictConfig):
        dataset = BertPretrainingDataset(
            tokenizer=self.tokenizer,
            data_file=cfg.data_file,
            max_seq_length=cfg.max_seq_length,
            mask_prob=cfg.mask_prob,
            short_seq_prob=cfg.short_seq_prob,
        )
        dl = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            collate_fn=dataset.collate_fn,
            drop_last=cfg.get("drop_last", False),
            shuffle=cfg.shuffle,
            num_workers=cfg.get("num_workers", 0),
        )
        return dl
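
    # Illustrative data-config sketch (key names taken from the accesses above, values are
    # placeholders). With a tokenizer configured, _setup_dataloader expects raw text plus
    # masking parameters, e.g. in YAML:
    #
    #   train_ds:
    #     data_file: /path/to/train.txt
    #     max_seq_length: 128
    #     mask_prob: 0.15
    #     short_seq_prob: 0.1
    #     batch_size: 64
    #     shuffle: true
    #
    # Without a tokenizer, _setup_preprocessed_dataloader is used instead and expects
    # already-preprocessed shards via data_file (a single file or a directory), plus
    # max_predictions_per_seq and batch_size.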

    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained models which can be instantiated directly
        from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        result = []
        result.append(
            PretrainedModelInfo(
                pretrained_model_name="bertbaseuncased",
                location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/bertbaseuncased/versions/1.0.0rc1/files/bertbaseuncased.nemo",
                description="The model was trained on EN Wikipedia and BookCorpus with a sequence length of 512.",
            )
        )
        result.append(
            PretrainedModelInfo(
                pretrained_model_name="bertlargeuncased",
                location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/bertlargeuncased/versions/1.0.0rc1/files/bertlargeuncased.nemo",
                description="The model was trained on EN Wikipedia and BookCorpus with a sequence length of 512.",
            )
        )
        return result
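

# Illustrative usage sketch (not part of the original module): lists the NGC checkpoints
# advertised by list_available_models and shows how one could be restored through the generic
# ModelPT.from_pretrained API. Downloading requires network access, so the restore call is
# left commented out; the __main__ guard keeps module imports unaffected.
if __name__ == "__main__":
    for model_info in BERTLMModel.list_available_models():
        print(model_info.pretrained_model_name, model_info.location)
    # model = BERTLMModel.from_pretrained(model_name="bertbaseuncased")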