Source code for nemo_export.utils.model_loader
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os.path
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Union
import numpy
# tensorstore is needed to register 'bfloat16' dtype with numpy for zarr compatibility
import tensorstore # noqa: F401 pylint: disable=unused-import
import torch
from torch.distributed.checkpoint import FileSystemReader, load
from torch.distributed.checkpoint.metadata import (
BytesStorageMetadata,
TensorStorageMetadata,
)
from nemo_export.tarutils import TarPath, ZarrPathStore
from nemo_export.utils._mock_import import _mock_import
LOGGER = logging.getLogger("NeMo")
[docs]
def nemo_to_path(nemo_checkpoint: Union[Path, str]) -> Union[Path, TarPath]:
    """Builds a path-like object for browsing the contents of a NeMo checkpoint.

    The checkpoint location is converted to a string first. If that string
    names an existing directory, a regular ``Path`` is returned; in every
    other case the string is wrapped in a ``TarPath``.

    Args:
        nemo_checkpoint (Path, str): Location of the NeMo checkpoint.

    Returns:
        Path | TarPath: Object that can be used to navigate inside the checkpoint.
    """
    checkpoint_str = str(nemo_checkpoint)
    # Directories are handled by pathlib directly; anything else goes through TarPath.
    path_type = Path if os.path.isdir(checkpoint_str) else TarPath
    return path_type(checkpoint_str)
[docs]
class TarFileSystemReader(FileSystemReader):
    """FileSystemReader variant usable with either a Path or a TarPath checkpoint directory.

    FileSystemReader is able to operate on a TarPath, but its constructor
    expects a pure Path. Bypassing that check in ``__init__`` is sufficient
    to make it work.
    """

    def __init__(self, path: Union[Path, TarPath]) -> None:
        """Initializes the parent with a plain path, then restores the TarPath if one was given."""
        is_tar = isinstance(path, TarPath)
        # The parent constructor receives the string form of a TarPath; other
        # inputs are forwarded untouched.
        super().__init__(str(path) if is_tar else path)
        if is_tar:
            # Replace the path stored by the parent constructor with the
            # original TarPath object.
            self.path = path
[docs]
def nemo_weights_directory(nemo_path: Union[Path, TarPath]) -> Union[Path, TarPath]:
    """Locates the directory that holds the weights within a NeMo checkpoint.

    The sub-directories ``model_weights`` and ``weights`` are probed in that
    order and the first one that exists is returned. When neither exists,
    the checkpoint path itself is returned.

    Args:
        nemo_path (Path | TarPath): Path to the nemo checkpoint.

    Returns:
        Path | TarPath: Path to the weights directory inside the model checkpoint.
    """
    # Order matters: "model_weights" takes precedence over "weights".
    for subdir in ("model_weights", "weights"):
        candidate = nemo_path / subdir
        if candidate.exists():
            return candidate
    return nemo_path
[docs]
def load_model_weights(checkpoint_path: Union[str, Path], load_extra_states: bool = False) -> Dict[str, Any]:
    """Reads the state dictionary of a NeMo checkpoint.

    The weights directory is located inside the checkpoint, its
    ``metadata.json`` is parsed, and loading is dispatched on the
    ``sharded_backend`` entry found there (``zarr`` or ``torch_dist``).
    Weights are stored in torch.Tensor.

    Args:
        checkpoint_path (str | Path): Path to the NeMo checkpoint.
        load_extra_states (bool): If True, loads BytesIO objects, corresponding to the extra states.

    Returns:
        dict: Model state dictionary.

    Raises:
        NotImplementedError: If the ``sharded_backend`` value is neither
            ``zarr`` nor ``torch_dist``.
    """
    weights_dir = nemo_weights_directory(nemo_to_path(checkpoint_path))

    with (weights_dir / "metadata.json").open(mode="r") as metadata_file:
        metadata = json.load(metadata_file)

    backend = metadata["sharded_backend"]

    if backend == "zarr":
        return load_sharded_metadata_zarr(weights_dir, load_extra_states=load_extra_states)

    if backend == "torch_dist":
        # TODO: Remove mocking imports once MCore is available in NIM containers
        with _mock_import("megatron.core.dist_checkpointing.strategies.torch"):
            return load_sharded_metadata_torch_dist(weights_dir, load_extra_states=load_extra_states)

    raise NotImplementedError(f"Distributed checkpoint backend {backend} not supported")