Source code for nemo.collections.nlp.modules.common.lm_utils

# Copyright 2018 The Google AI Language Team Authors and
# The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List, Optional

from nemo.collections.nlp.modules.common.bert_module import BertModule
from nemo.collections.nlp.modules.common.huggingface.huggingface_utils import (
    get_huggingface_lm_model,
    get_huggingface_pretrained_lm_models_list,
)
from nemo.collections.nlp.modules.common.megatron.megatron_utils import (
    get_megatron_lm_model,
    get_megatron_lm_models_list,
)
from nemo.utils import logging

# Public API of this module: the names exported by `from <module> import *`.
__all__ = ['get_pretrained_lm_models_list', 'get_lm_model']


def get_pretrained_lm_models_list(include_external: bool = False) -> List[str]:
    """
    Return the names of the supported pretrained language models.

    The result is the Megatron-LM model names followed by the HuggingFace model names.

    Args:
        include_external: if true, includes all HuggingFace model names,
            not only those supported language models in NeMo.

    Returns:
        The concatenated list of model names.
    """
    # Megatron names are collected first so they lead the combined list.
    megatron_names = get_megatron_lm_models_list()
    huggingface_names = get_huggingface_pretrained_lm_models_list(include_external=include_external)
    return megatron_names + huggingface_names
def get_lm_model(
    pretrained_model_name: str,
    config_dict: Optional[dict] = None,
    config_file: Optional[str] = None,
    checkpoint_file: Optional[str] = None,
    vocab_file: Optional[str] = None,
) -> BertModule:
    """
    Helper function to instantiate a language model encoder, either from scratch or a pretrained model.
    If only pretrained_model_name is passed, a pretrained model is returned.
    If a configuration is passed, whether as a file or dictionary, the model is initialized with random weights.

    Args:
        pretrained_model_name: pretrained model name, for example, bert-base-uncased or megatron-bert-cased.
            See get_pretrained_lm_models_list() for full list. A name containing "megatron" selects the
            Megatron-LM loader; any other value (including an empty or missing name) is passed to the
            HuggingFace loader.
        config_dict: model configuration dictionary
        config_file: path to the model configuration file
        checkpoint_file: path to the pretrained model checkpoint
        vocab_file: path to vocab_file to be used with Megatron-LM

    Returns:
        The instantiated BertModule. If checkpoint_file is set and exists on disk,
        restore_weights() has been called on the model with that path.
    """
    # check valid model type
    supported_models = get_pretrained_lm_models_list(include_external=False)
    if not pretrained_model_name or pretrained_model_name not in supported_models:
        logging.warning(
            f'{pretrained_model_name} is not in get_pretrained_lm_models_list(include_external=False), '
            f'will be using AutoModel from HuggingFace.'
        )

    # warning when user passes both configuration dict and file
    if config_dict and config_file:
        logging.warning(
            f"Both config_dict and config_file were found, defaulting to use config_file: {config_file} will be used."
        )

    # Guard the substring test: a missing name (None) must fall through to the HuggingFace
    # path announced by the warning above instead of raising TypeError on `in`.
    if pretrained_model_name and "megatron" in pretrained_model_name:
        # The Megatron helper also returns the checkpoint path to restore from,
        # which replaces the one passed in by the caller.
        model, checkpoint_file = get_megatron_lm_model(
            config_dict=config_dict,
            config_file=config_file,
            pretrained_model_name=pretrained_model_name,
            checkpoint_file=checkpoint_file,
            vocab_file=vocab_file,
        )
    else:
        model = get_huggingface_lm_model(
            config_dict=config_dict, config_file=config_file, pretrained_model_name=pretrained_model_name,
        )

    if checkpoint_file:
        if os.path.exists(checkpoint_file):
            model.restore_weights(restore_path=checkpoint_file)
        else:
            # Stay best-effort (no exception), but do not skip the restore silently.
            logging.warning(
                f'Checkpoint file {checkpoint_file} was not found, model weights were not restored from it.'
            )

    return model