# Copyright 2018 The Google AI Language Team Authors and
# The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import torch
from megatron import get_args, initialize_megatron
from megatron.checkpointing import set_checkpoint_version
from megatron.model import get_language_model
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.mpu import get_model_parallel_group, model_parallel_is_initialized
from omegaconf import OmegaConf

from nemo.collections.nlp.modules.common.bert_module import BertModule
from nemo.core.classes import typecheck
from nemo.utils import logging
from nemo.utils.app_state import AppState
__all__ = ['MegatronBertEncoder']


class MegatronBertEncoder(BertModule):
"""
    MegatronBertEncoder wraps the BERT language model from Megatron-LM
    (https://github.com/NVIDIA/Megatron-LM).

    Args:
        model_name (str): name of the model.
        config (dict): Megatron-LM configuration passed to initialize_megatron as argument defaults.
        vocab_file (str): path to the vocabulary file.
        model_parallel_size (int, optional): number of model parallel partitions, when using model parallelism.
        model_parallel_rank (int, optional): model parallel rank of the current process, when using model parallelism.

    Note:
        The tokenizer type is fixed to 'BertWordPieceLowerCase', currently the only supported type.
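    Example:
        A minimal construction sketch. The config keys mirror Megatron-LM argparse
        defaults, and the values, model name, and vocab path below are illustrative
        placeholders, not a complete or tested configuration:

            config = {
                'num_layers': 12,
                'hidden_size': 768,
                'num_attention_heads': 12,
                'max_position_embeddings': 512,
            }
            encoder = MegatronBertEncoder(
                model_name='megatron-bert-uncased',
                config=config,
                vocab_file='/path/to/vocab.txt',
            )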
"""
def __init__(self, model_name, config, vocab_file, model_parallel_size=None, model_parallel_rank=None):
super().__init__()
self._model_parallel_size = model_parallel_size
self._model_parallel_rank = model_parallel_rank
self._restore_path = None
self._app_state = None
self._model_name = model_name
if not os.path.exists(vocab_file):
raise ValueError(f'Vocab file not found at {vocab_file}')
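        # Adjust the Megatron-LM config before initialization:
        #   vocab_file      -- the vocabulary file validated above
        #   tokenizer_type  -- only 'BertWordPieceLowerCase' is supported by this wrapper
        #   lazy_mpu_init   -- defer model-parallel setup until torch.distributed is initialized
        #                      (see the lazy init notes below and complete_lazy_init)
        #   onnx_safe       -- request Megatron's ONNX-export-safe operations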
config["vocab_file"] = vocab_file
config['tokenizer_type'] = 'BertWordPieceLowerCase'
config['lazy_mpu_init'] = True
config['onnx_safe'] = True
if self._model_parallel_size is not None:
app_state = AppState()
self._app_state = app_state
# must be set for model parallel megatron-lm
os.environ["WORLD_SIZE"] = str(app_state.world_size)
os.environ["RANK"] = str(self._model_parallel_rank)
# used to set model_parallel_size in megatron-lm argparser
def _update_model_parallel_arg(parser):
parser.set_defaults(model_parallel_size=self._model_parallel_size)
return parser
extra_args_provider = _update_model_parallel_arg
else:
extra_args_provider = None
        # Initialize the part of Megatron's global state that its constructor needs.
        # We set the 'lazy_mpu_init' flag so that Megatron performs only the initialization that
        # does not require DDP to be set up yet (we also don't want Megatron to initialize DDP itself),
        # and returns a hook for us to call after PTL has initialized torch.distributed
        # (or, when running inference without PTL, after we have initialized it ourselves).
        # The hook is called and cleared on the first call to forward().
self._lazy_init_fn = initialize_megatron(
extra_args_provider=extra_args_provider, args_defaults=config, ignore_unknown_args=True
)
# read Megatron arguments back
args = get_args()
logging.info(f'Megatron-lm argparse args: {args}')
        # _language_model_key is the key under which the language model's weights are
        # stored in Megatron checkpoints (see _load_checkpoint)
        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func, num_tokentypes=2, add_pooler=False
        )
        self.config = OmegaConf.create(config)
        self._hidden_size = self.language_model.hidden_size

    def complete_lazy_init(self):
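        """Run the deferred Megatron-LM initialization hook returned by
        initialize_megatron in __init__ (if any), then clear it. Called on the
        first forward() pass, after torch.distributed has been initialized."""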
# finish megatron-lm initialization
if hasattr(self, "_lazy_init_fn") and self._lazy_init_fn is not None:
self._lazy_init_fn()
            self._lazy_init_fn = None

    @property
def hidden_size(self):
"""
Property returning hidden size.
Returns:
Hidden size.
"""
        return self._hidden_size

    @typecheck()
def forward(self, input_ids, attention_mask, token_type_ids):
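        """Runs the wrapped Megatron language model.

        Completes the deferred Megatron-LM initialization on the first call, builds
        the extended attention mask and position ids that Megatron expects, and
        returns the sequence output of the language model.
        """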
self.complete_lazy_init()
extended_attention_mask = bert_extended_attention_mask(attention_mask)
position_ids = bert_position_ids(input_ids)
sequence_output = self.language_model(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=extended_attention_mask,
tokentype_ids=token_type_ids,
)
        return sequence_output

    def _load_checkpoint(self, filename):
"""Helper function for loading megatron checkpoints.
Args:
filename (str): Path to megatron checkpoint.
"""
state_dict = torch.load(filename, map_location='cpu')
if 'checkpoint_version' in state_dict:
if state_dict['checkpoint_version'] is not None:
set_checkpoint_version(state_dict['checkpoint_version'])
else:
logging.warning('Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.')
set_checkpoint_version(0)
# to load from Megatron pretrained checkpoint
if 'model' in state_dict:
self.language_model.load_state_dict(state_dict['model'][self._language_model_key])
else:
self.load_state_dict(state_dict)
logging.info(f"Checkpoint loaded from from {filename}")
[docs] def restore_weights(self, restore_path: str):
"""Restores module/model's weights.
For model parallel checkpoints the directory structure
should be restore_path/mp_rank_0X/model_optim_rng.pt
Args:
restore_path (str): restore_path should a file or a directory if using model parallel
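        Example of a model parallel checkpoint directory with two partitions
        (an illustrative layout following the mp_rank_0X pattern above):

            restore_path/
                mp_rank_00/model_optim_rng.pt
                mp_rank_01/model_optim_rng.pt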
"""
self._restore_path = restore_path
if os.path.isfile(restore_path):
self._load_checkpoint(restore_path)
elif os.path.isdir(restore_path):
# need model parallel groups to restore model parallel checkpoints
if model_parallel_is_initialized():
model_parallel_rank = torch.distributed.get_rank(group=get_model_parallel_group())
mp_restore_path = f'{restore_path}/mp_rank_{model_parallel_rank:02d}/model_optim_rng.pt'
self._load_checkpoint(mp_restore_path)
else:
                logging.info('torch.distributed is not initialized yet; will not restore the model parallel checkpoint.')
else:
logging.error(f'restore_path: {restore_path} must be a file or directory.')