# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright 2015 and onwards Google, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import tarfile
import tempfile
import uuid
from typing import Optional, Union
import torch
from omegaconf import DictConfig, OmegaConf
from omegaconf.omegaconf import open_dict
from pytorch_lightning.trainer.trainer import Trainer
from nemo.utils import logging, model_utils
from nemo.utils.app_state import AppState
from nemo.utils.get_rank import is_global_rank_zero
from nemo.utils.model_utils import inject_model_parallel_rank
class SaveRestoreConnector:
    """Serializes a ModelPT instance into a ``.nemo`` archive and restores it back.

    A ``.nemo`` file is a tar archive (uncompressed for checkpoints from NeMo >= 1.7.0,
    gzip-compressed for older ones) containing:
        model_config.yaml  - model configuration, deserializable into the model's ``cfg``
        model_weights.ckpt - model checkpoint (state dict)
    plus any artifact files registered through :meth:`register_artifact`.
    """

    def __init__(self) -> None:
        # Names of the two mandatory archive members; overridable via the
        # `model_config_yaml` / `model_weights_ckpt` properties below.
        self._model_config_yaml = "model_config.yaml"
        self._model_weights_ckpt = "model_weights.ckpt"
        # Optional path to an already-extracted .nemo archive. When set to a valid
        # directory, restoration reads from it directly instead of untarring.
        self._model_extracted_dir = None

    def save_to(self, model, save_path: str):
        """
        Saves model instance (weights and configuration) into .nemo file.
        You can use "restore_from" method to fully restore instance from .nemo file.

        .nemo file is an archive (tar.gz) with the following:
            model_config.yaml - model configuration in .yaml format.
                You can deserialize this into cfg argument for model's constructor
            model_weights.ckpt - model checkpoint

        Args:
            model: ModelPT object to be saved.
            save_path: Path to .nemo file where model instance should be saved
        """
        # Only the global rank-0 process writes the archive; all other ranks no-op
        # so that distributed training does not produce conflicting writes.
        if not is_global_rank_zero():
            return

        with tempfile.TemporaryDirectory() as tmpdir:
            config_yaml = os.path.join(tmpdir, self.model_config_yaml)
            model_weights = os.path.join(tmpdir, self.model_weights_ckpt)
            model.to_config_file(path2yaml_file=config_yaml)
            if hasattr(model, 'artifacts') and model.artifacts is not None:
                # Copy registered artifacts next to the config inside the archive dir.
                self._handle_artifacts(model, nemo_file_folder=tmpdir)
                # We should not update self._cfg here - the model can still be in use
                self._update_artifact_paths(model, path2yaml_file=config_yaml)
            self._save_state_dict_to_disk(model.state_dict(), model_weights)
            self._make_nemo_file_from_folder(filename=save_path, source_dir=tmpdir)

    def load_config_and_state_dict(
        self,
        calling_cls,
        restore_path: str,
        override_config_path: Optional[Union[OmegaConf, str]] = None,
        map_location: Optional[torch.device] = None,
        strict: bool = True,
        return_config: bool = False,
        trainer: Trainer = None,
    ):
        """
        Restores model instance (weights and configuration) from a .nemo file.

        Args:
            calling_cls: The ModelPT subclass being restored; used for its
                ``from_config_dict`` factory and restore-state hooks.
            restore_path: path to .nemo file from which model should be instantiated
            override_config_path: path to a yaml config that will override the internal
                config file or an OmegaConf / DictConfig object representing the model config.
            map_location: Optional torch.device() to map the instantiated model to a device.
                By default (None), it will select a GPU if available, falling back to CPU otherwise.
            strict: Passed to load_state_dict. By default True
            return_config: If set to true, will return just the underlying config of the restored
                model as an OmegaConf DictConfig object without instantiating the model.
            trainer: An optional Trainer object, passed to the model constructor.

        Example:
            ```
            model = nemo.collections.asr.models.EncDecCTCModel.restore_from('asr.nemo')
            assert isinstance(model, nemo.collections.asr.models.EncDecCTCModel)
            ```

        Returns:
            A ``(config, instance, state_dict)`` tuple, or just the config if
            ``return_config`` is set.
        """
        # Get path where the command is executed - the artifacts will be "retrieved" there
        # (original .nemo behavior)
        cwd = os.getcwd()

        if map_location is None:
            if torch.cuda.is_available():
                map_location = torch.device('cuda')
            else:
                map_location = torch.device('cpu')

        app_state = AppState()
        with tempfile.TemporaryDirectory() as tmpdir:
            try:
                # Check if self.model_extracted_dir is set, and is a valid path
                if self.model_extracted_dir is not None and os.path.isdir(self.model_extracted_dir):
                    # Log that NeMo will use the provided `model_extracted_dir`
                    logging.info(
                        f"Restoration will occur within pre-extracted directory : " f"`{self.model_extracted_dir}`."
                    )

                    # Override `tmpdir` above with the pre-extracted `model_extracted_dir`
                    tmpdir = self.model_extracted_dir
                else:
                    # Extract the nemo file into the temporary directory
                    self._unpack_nemo_file(path2file=restore_path, out_folder=tmpdir)

                # Change current working directory to the extracted archive so that
                # relative artifact paths inside the config resolve correctly.
                os.chdir(tmpdir)
                if override_config_path is None:
                    config_yaml = os.path.join(tmpdir, self.model_config_yaml)
                else:
                    # can be str path or OmegaConf / DictConfig object
                    config_yaml = override_config_path
                if not isinstance(config_yaml, (OmegaConf, DictConfig)):
                    conf = OmegaConf.load(config_yaml)
                else:
                    conf = config_yaml
                    if override_config_path is not None:
                        # Resolve the override config
                        conf = OmegaConf.to_container(conf, resolve=True)
                        conf = OmegaConf.create(conf)
                # If override is top level config, extract just `model` from it
                if 'model' in conf:
                    conf = conf.model

                if return_config:
                    # Caller only wants the config; skip instantiation entirely.
                    instance = conf
                    return instance
                else:
                    if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
                        # Model-parallel checkpoints are sharded per rank.
                        model_weights = self._inject_model_parallel_rank_for_ckpt(tmpdir, self.model_weights_ckpt)
                    else:
                        model_weights = os.path.join(tmpdir, self.model_weights_ckpt)
                OmegaConf.set_struct(conf, True)
                os.chdir(cwd)
                # get the class
                calling_cls._set_model_restore_state(is_being_restored=True, folder=tmpdir)
                instance = calling_cls.from_config_dict(config=conf, trainer=trainer)
                instance = instance.to(map_location)
                # add load_state_dict override
                if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
                    model_weights = self._inject_model_parallel_rank_for_ckpt(tmpdir, self.model_weights_ckpt)
                state_dict = self._load_state_dict_from_disk(model_weights, map_location=map_location)
            finally:
                # Always restore the original working directory, even on failure.
                os.chdir(cwd)

        return (conf, instance, state_dict)

    def modify_state_dict(self, conf, state_dict):
        """
        Utility method that allows to modify the state dict before loading parameters into a model.

        Args:
            conf: A model level OmegaConf object.
            state_dict: The state dict restored from the checkpoint.

        Returns:
            A potentially modified state dict.
        """
        # NOTE and TODO (sandeepsub) This is duplicated across save_restore_connector and nlp_save_restore_connector. This shouldn't be here.
        if conf.get('megatron_amp_O2', False):
            # O2-style AMP wraps the model in a `module` attribute; remap keys to match.
            new_state_dict = {}
            for key in state_dict.keys():
                new_key = key.replace('model.', 'model.module.', 1)
                new_state_dict[new_key] = state_dict[key]
            state_dict = new_state_dict
        return state_dict

    def load_instance_with_state_dict(self, instance, state_dict, strict):
        """
        Utility method that loads a model instance with the (potentially modified) state dict.

        Args:
            instance: ModelPT subclass instance.
            state_dict: The state dict (which may have been modified)
            strict: Bool, whether to perform strict checks when loading the state dict.
        """
        instance.load_state_dict(state_dict, strict=strict)
        instance._set_model_restore_state(is_being_restored=False)

    def restore_from(
        self,
        calling_cls,
        restore_path: str,
        override_config_path: Optional[Union[OmegaConf, str]] = None,
        map_location: Optional[torch.device] = None,
        strict: bool = True,
        return_config: bool = False,
        trainer: Trainer = None,
    ):
        """
        Restores model instance (weights and configuration) from a .nemo file.

        Args:
            calling_cls: The ModelPT subclass being restored.
            restore_path: path to .nemo file from which model should be instantiated
            override_config_path: path to a yaml config that will override the internal
                config file or an OmegaConf / DictConfig object representing the model config.
            map_location: Optional torch.device() to map the instantiated model to a device.
                By default (None), it will select a GPU if available, falling back to CPU otherwise.
            strict: Passed to load_state_dict. By default True
            return_config: If set to true, will return just the underlying config of the restored
                model as an OmegaConf DictConfig object without instantiating the model.
            trainer: An optional Trainer object, passed to the model constructor.

        Example:
            ```
            model = nemo.collections.asr.models.EncDecCTCModel.restore_from('asr.nemo')
            assert isinstance(model, nemo.collections.asr.models.EncDecCTCModel)
            ```

        Returns:
            An instance of type cls or its underlying config (if return_config is set).
        """
        # Get path where the command is executed - the artifacts will be "retrieved" there
        # (original .nemo behavior)
        loaded_params = self.load_config_and_state_dict(
            calling_cls, restore_path, override_config_path, map_location, strict, return_config, trainer,
        )
        if not isinstance(loaded_params, tuple):
            # return_config=True path: load_config_and_state_dict returned just the config.
            return loaded_params
        conf, instance, state_dict = loaded_params
        state_dict = self.modify_state_dict(conf, state_dict)
        self.load_instance_with_state_dict(instance, state_dict, strict)
        logging.info(f'Model {instance.__class__.__name__} was successfully restored from {restore_path}.')
        return instance

    def register_artifact(self, model, config_path: str, src: str, verify_src_exists: bool = True):
        """
        Register model artifacts with this function. These artifacts (files) will be included inside .nemo file
        when model.save_to("mymodel.nemo") is called.

        How it works:
        1. It always returns existing absolute path which can be used during Model constructor call
            EXCEPTION: src is None or "" in which case nothing will be done and src will be returned
        2. It will add (config_path, model_utils.ArtifactItem()) pair to self.artifacts
            If "src" is local existing path, then it will be returned in absolute path form.
            elif "src" starts with "nemo_file:unique_artifact_name":
                .nemo will be untarred to a temporary folder location and an actual existing path will be returned
            else an error will be raised.

        WARNING: use .register_artifact calls in your models' constructors.
        The returned path is not guaranteed to exist after you have exited your model's constructor.

        Args:
            model: ModelPT object to register artifact for.
            config_path (str): Artifact key. Usually corresponds to the model config.
            src (str): Path to artifact.
            verify_src_exists (bool): If set to False, then the artifact is optional and register_artifact will return
                None even if src is not found. Defaults to True.

        Returns:
            str: If src is not None or empty it always returns absolute path which is guaranteed to exists during model
            instance life
        """
        # Documented contract: a missing src is a no-op. Without this guard,
        # os.path.basename(None) below would raise TypeError.
        if src is None or src == "":
            return src

        app_state = AppState()

        artifact_item = model_utils.ArtifactItem()

        # This is for backward compatibility, if the src objects exists simply inside of the tarfile
        # without its key having been overridden, this pathway will be used.
        src_obj_name = os.path.basename(src)
        if app_state.nemo_file_folder is not None:
            src_obj_path = os.path.abspath(os.path.join(app_state.nemo_file_folder, src_obj_name))
        else:
            src_obj_path = src_obj_name

        # src is a local existing path - register artifact and return exact same path for usage by the model
        if os.path.exists(os.path.abspath(src)):
            return_path = os.path.abspath(src)
            artifact_item.path_type = model_utils.ArtifactPathType.LOCAL_PATH

        # this is the case when artifact must be retrieved from the nemo file
        # we are assuming that the location of the right nemo file is available from _MODEL_RESTORE_PATH
        elif src.startswith("nemo:"):
            return_path = os.path.abspath(os.path.join(app_state.nemo_file_folder, src[5:]))
            artifact_item.path_type = model_utils.ArtifactPathType.TAR_PATH

        # backward compatibility implementation
        elif os.path.exists(src_obj_path):
            return_path = src_obj_path
            artifact_item.path_type = model_utils.ArtifactPathType.TAR_PATH
        else:
            if verify_src_exists:
                raise FileNotFoundError(
                    f"src path does not exist or it is not a path in nemo file. src value I got was: {src}. Absolute: {os.path.abspath(src)}"
                )
            else:
                # artifact is optional and we simply return None
                return None

        assert os.path.exists(return_path)

        artifact_item.path = os.path.abspath(src)
        model.artifacts[config_path] = artifact_item
        # we were called by ModelPT
        if hasattr(model, "cfg"):
            with open_dict(model._cfg):
                OmegaConf.update(model.cfg, config_path, return_path)
        return return_path

    def _handle_artifacts(self, model, nemo_file_folder):
        """Copy all registered artifacts into `nemo_file_folder` ahead of archiving.

        LOCAL_PATH artifacts are copied directly under a uuid-prefixed unique name;
        TAR_PATH artifacts are extracted from the model's original .nemo archive first.
        """
        tarfile_artifacts = []
        app_state = AppState()
        for conf_path, artiitem in model.artifacts.items():
            if artiitem.path_type == model_utils.ArtifactPathType.LOCAL_PATH:
                if not os.path.exists(artiitem.path):
                    raise FileNotFoundError(f"Artifact {conf_path} not found at location: {artiitem.path}")

                # Generate new uniq artifact name and copy it to nemo_file_folder
                # Note uuid.uuid4().hex is guaranteed to be 32 character long
                artifact_base_name = os.path.basename(artiitem.path)
                artifact_uniq_name = f"{uuid.uuid4().hex}_{artifact_base_name}"
                shutil.copy2(artiitem.path, os.path.join(nemo_file_folder, artifact_uniq_name))

                # Update artifacts registry
                artiitem.hashed_path = "nemo:" + artifact_uniq_name
                model.artifacts[conf_path] = artiitem

            elif artiitem.path_type == model_utils.ArtifactPathType.TAR_PATH:
                # process all tarfile artifacts in one go, so preserve key-value pair
                tarfile_artifacts.append((conf_path, artiitem))

            else:
                raise ValueError("Directly referencing artifacts from other nemo files isn't supported yet")

        # Process current tarfile artifacts by unpacking the previous tarfile and extract the artifacts
        # that are currently required.
        model_metadata = app_state.get_model_metadata_from_guid(model.model_guid)
        if len(tarfile_artifacts) > 0 and model_metadata.restoration_path is not None:
            # Need to step into nemo archive to extract file
            # Get path where the command is executed - the artifacts will be "retrieved" there
            # (original .nemo behavior)
            cwd = os.getcwd()
            try:
                # Step into the nemo archive to try and find the file
                with tempfile.TemporaryDirectory() as archive_dir:
                    self._unpack_nemo_file(path2file=model_metadata.restoration_path, out_folder=archive_dir)
                    os.chdir(archive_dir)
                    for conf_path, artiitem in tarfile_artifacts:
                        # Get basename and copy it to nemo_file_folder
                        if 'nemo:' in artiitem.path:
                            artifact_base_name = artiitem.path.split('nemo:')[1]
                        else:
                            artifact_base_name = os.path.basename(artiitem.path)
                        # no need to hash here as we are in tarfile_artifacts which are already hashed
                        artifact_uniq_name = artifact_base_name
                        shutil.copy2(artifact_base_name, os.path.join(nemo_file_folder, artifact_uniq_name))

                        # Update artifacts registry
                        new_artiitem = model_utils.ArtifactItem()
                        new_artiitem.path = "nemo:" + artifact_uniq_name
                        new_artiitem.path_type = model_utils.ArtifactPathType.TAR_PATH
                        model.artifacts[conf_path] = new_artiitem
            finally:
                # change back working directory
                os.chdir(cwd)

    def _update_artifact_paths(self, model, path2yaml_file):
        """Rewrite artifact entries inside the saved YAML config to their in-archive
        ("nemo:<uniq_name>") paths, so the archive is self-contained."""
        if model.artifacts is not None and len(model.artifacts) > 0:
            conf = OmegaConf.load(path2yaml_file)
            for conf_path, item in model.artifacts.items():
                if item.hashed_path is None:
                    OmegaConf.update(conf, conf_path, item.path)
                else:
                    OmegaConf.update(conf, conf_path, item.hashed_path)
            with open(path2yaml_file, 'w', encoding='utf-8') as fout:
                OmegaConf.save(config=conf, f=fout, resolve=True)

    def _inject_model_parallel_rank_for_ckpt(self, dirname, basename):
        """Return the checkpoint path with the model-parallel rank injected
        (used when the checkpoint is sharded across model-parallel ranks)."""
        model_weights = os.path.join(dirname, basename)
        model_weights = inject_model_parallel_rank(model_weights)
        return model_weights

    @staticmethod
    def _make_nemo_file_from_folder(filename, source_dir):
        """Pack the contents of `source_dir` into an uncompressed tar at `filename`."""
        dirname = os.path.dirname(filename)
        # Guard against a bare filename with no directory component:
        # os.makedirs("") raises FileNotFoundError.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        with tarfile.open(filename, "w:") as tar:
            tar.add(source_dir, arcname=".")

    @staticmethod
    def _unpack_nemo_file(path2file: str, out_folder: str) -> str:
        """Extract a .nemo archive into `out_folder`, handling both uncompressed
        (NeMo >= 1.7.0) and gzip-compressed (older) tar formats.

        Returns:
            The output folder path.
        """
        if not os.path.exists(path2file):
            raise FileNotFoundError(f"{path2file} does not exist")

        # we start with an assumption of uncompressed tar,
        # which should be true for versions 1.7.0 and above
        tar_header = "r:"
        try:
            tar_test = tarfile.open(path2file, tar_header)
            tar_test.close()
        except tarfile.ReadError:
            # can be older checkpoint => try compressed tar
            tar_header = "r:gz"
        tar = tarfile.open(path2file, tar_header)
        # NOTE(review): extractall() on an untrusted archive can write outside
        # out_folder (path traversal). Presumably .nemo files are trusted here;
        # consider the `filter="data"` argument (Python 3.12+) — TODO confirm.
        tar.extractall(path=out_folder)
        tar.close()
        return out_folder

    @staticmethod
    def _save_state_dict_to_disk(state_dict, filepath):
        """Persist a state dict to `filepath` via torch.save."""
        torch.save(state_dict, filepath)

    @staticmethod
    def _load_state_dict_from_disk(model_weights, map_location=None):
        """Load a state dict from `model_weights`, optionally remapping devices."""
        return torch.load(model_weights, map_location=map_location)

    @property
    def model_config_yaml(self) -> str:
        """Name of the config YAML member inside the .nemo archive."""
        return self._model_config_yaml

    @model_config_yaml.setter
    def model_config_yaml(self, path: str):
        self._model_config_yaml = path

    @property
    def model_weights_ckpt(self) -> str:
        """Name of the checkpoint member inside the .nemo archive."""
        return self._model_weights_ckpt

    @model_weights_ckpt.setter
    def model_weights_ckpt(self, path: str):
        self._model_weights_ckpt = path

    @property
    def model_extracted_dir(self) -> Optional[str]:
        """Optional pre-extracted archive directory used to bypass untarring."""
        return self._model_extracted_dir

    @model_extracted_dir.setter
    def model_extracted_dir(self, path: Optional[str]):
        self._model_extracted_dir = path