Source code for nemo_export.utils.utils
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from collections import Counter
from pathlib import Path
from typing import Dict, Optional, Union
import torch
from transformers import PreTrainedTokenizerBase
[docs]
def is_nemo2_checkpoint(checkpoint_path: str) -> bool:
"""Checks if the checkpoint is in NeMo 2.0 format.
Args:
checkpoint_path (str): Path to a checkpoint.
Returns:
bool: True if the path points to a NeMo 2.0 checkpoint; otherwise false.
"""
ckpt_path = Path(checkpoint_path)
return (ckpt_path / "context").is_dir()
[docs]
def prepare_directory_for_export(
model_dir: Union[str, Path],
delete_existing_files: bool,
subdir: Optional[str] = None,
) -> None:
"""Prepares model_dir path for the TensorRTT-LLM / vLLM export.
Makes sure that the model_dir directory exists and is empty.
Args:
model_dir (str): Path to the target directory for the export.
delete_existing_files (bool): Attempt to delete existing files if they exist.
subdir (Optional[str]): Subdirectory to create inside the model_dir.
Returns:
None
"""
model_path = Path(model_dir)
if model_path.exists():
if delete_existing_files:
shutil.rmtree(model_path)
elif any(model_path.iterdir()):
raise RuntimeError(f"There are files in {model_path} folder: try setting delete_existing_files=True.")
if subdir is not None:
model_path /= subdir
model_path.mkdir(parents=True, exist_ok=True)
[docs]
def is_nemo_tarfile(path: str) -> bool:
"""Checks if the path exists and points to packed NeMo 1 checkpoint.
Args:
path (str): Path to possible checkpoint.
Returns:
bool: NeMo 1 checkpoint exists and is in '.nemo' format.
"""
checkpoint_path = Path(path)
return checkpoint_path.exists() and checkpoint_path.suffix == ".nemo"
# Copied from nemo.collections.nlp.parts.utils_funcs to avoid introducing extra NeMo dependencies:
[docs]
def torch_dtype_from_precision(precision: Union[int, str], megatron_amp_O2: bool = True) -> torch.dtype:
"""Mapping from PyTorch Lighthing (PTL) precision types to corresponding PyTorch parameter data type.
Args:
precision (Union[int, str]): The PTL precision type used.
megatron_amp_O2 (bool): A flag indicating if Megatron AMP O2 is enabled.
Returns:
torch.dtype: The corresponding PyTorch data type based on the provided precision.
"""
if not megatron_amp_O2:
return torch.float32
if precision in ["bf16", "bf16-mixed"]:
return torch.bfloat16
elif precision in [16, "16", "16-mixed"]:
return torch.float16
elif precision in [32, "32", "32-true"]:
return torch.float32
else:
raise ValueError(f"Could not parse the precision of '{precision}' to a valid torch.dtype")
[docs]
def get_model_device_type(module: torch.nn.Module) -> str:
"""Find the device type the model is assigned to and ensure consistency."""
# Collect device types of all parameters and buffers
param_device_types = {param.device.type for param in module.parameters()}
buffer_device_types = {buffer.device.type for buffer in module.buffers()}
all_device_types = param_device_types.union(buffer_device_types)
if len(all_device_types) > 1:
raise ValueError(
f"Model parameters and buffers are on multiple device types: {all_device_types}. "
"Ensure all parameters and buffers are on the same device type."
)
# Return the single device type, or default to 'cpu' if no parameters or buffers
return all_device_types.pop() if all_device_types else "cpu"
[docs]
def validate_fp8_network(network) -> None:
"""Checks the network to ensure it's compatible with fp8 precison.
Raises:
ValueError if netowrk doesn't container Q/DQ FP8 layers
"""
import tensorrt as trt
quantize_dequantize_layers = []
for layer in network:
if layer.type in {trt.LayerType.QUANTIZE, trt.LayerType.DEQUANTIZE}:
quantize_dequantize_layers.append(layer)
if not quantize_dequantize_layers:
error_msg = "No Quantize/Dequantize layers found"
raise ValueError(error_msg)
quantize_dequantize_layer_dtypes = Counter(layer.precision for layer in quantize_dequantize_layers)
if trt.DataType.FP8 not in quantize_dequantize_layer_dtypes:
error_msg = "Found Quantize/Dequantize layers. But none with FP8 precision."
raise ValueError(error_msg)