#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import contextlib
import importlib
import os
import subprocess as sp
import sys
from polygraphy.mod import util as mod_util
# Names of every module requested through lazy_import, except Polygraphy's own.
_all_external_lazy_imports = set()

# Module name -> pip package name, for modules whose package is named differently.
_MODULE_TO_PKG_NAME = {"tensorrt": "nvidia-tensorrt"}

# Module name -> additional install flags required by that package.
# Each module gets its own list object.
_MODULE_EXTRA_FLAGS = {
    modname: ["--extra-index-url=https://pypi.ngc.nvidia.com"] for modname in ("tensorrt", "onnx_graphsurgeon")
}

LATEST_VERSION = "latest"
"""Indicates that the latest version of the package is preferred in lazy_import"""
def _version_ok(ver, preferred):
    """
    Checks whether an installed version satisfies a preferred-version specifier.

    Args:
        ver (str):
            The installed version string, e.g. a module's ``__version__``.
        preferred (str):
            Either ``LATEST_VERSION``, or a comparison operator (one of
            ``==``, ``>=``, ``>``, ``<=``, ``<``) followed by a version,
            for example ``'>=0.5.0'``. Surrounding whitespace is ignored.

    Returns:
        bool:
            Whether ``ver`` satisfies ``preferred``. Always False when
            ``preferred`` is ``LATEST_VERSION``, since the installed version
            is never assumed to be the latest one.

    Raises:
        KeyError: If ``preferred`` does not start with a supported operator.
    """
    if preferred == LATEST_VERSION:
        return False

    # Split the specifier at the end of its leading run of operator characters.
    # Slicing is used instead of ``rstrip(pref_ver)``: the strip methods treat
    # their argument as a set of characters rather than a suffix, which made
    # the old parse fail on inputs such as ">=1.0 " (trailing whitespace).
    spec = preferred.strip()
    remainder = spec.lstrip("<=>")
    cond = spec[: len(spec) - len(remainder)]
    pref_ver = remainder.strip()

    check = {
        "==": lambda x, y: x == y,
        ">=": lambda x, y: x >= y,
        ">": lambda x, y: x > y,
        "<=": lambda x, y: x <= y,
        "<": lambda x, y: x < y,
    }[cond]
    # NOTE(review): presumably mod_util.version parses a version string into
    # something that orders correctly - confirm against polygraphy.mod.util.
    return check(mod_util.version(ver), mod_util.version(pref_ver))
def lazy_import(name, log=True, version=None):
    """
    Lazily import a module.

    If the POLYGRAPHY_AUTOINSTALL_DEPS environment variable is set to 1,
    missing modules are automatically installed, and existing modules may be
    upgraded if newer versions are required.

    Args:
        name (str):
            The name of the module.
        log (bool):
            Whether to log information about the module.
        version (str):
            The preferred version of the package, formatted as a version string.
            For example, ``'>=0.5.0'`` or ``'==1.8.0'``. Use ``LATEST_VERSION`` to
            indicate that the latest version of the package is preferred.

    Returns:
        LazyModule:
            A lazily loaded module. Nothing is imported when this function
            returns; the import (plus any install or version check) runs
            each time an attribute is read or set on the returned object.
    """
    assert (
        version is None or version == LATEST_VERSION or version.startswith(("=", ">", "<"))
    ), "version must be formatted as a version string!"

    if "polygraphy" not in name:
        _all_external_lazy_imports.add(name)

    def import_mod():
        # Imported here rather than at module scope so that calling
        # lazy_import itself does not pull in these modules.
        from polygraphy import config
        from polygraphy.logger import G_LOGGER, LogMode

        def install_mod(raise_error=True):
            """Runs the configured install command for the package, then imports the module."""
            modname = name.split(".")[0]
            pkg = _MODULE_TO_PKG_NAME.get(modname, modname)
            # Copy the flags: appending to the list stored in _MODULE_EXTRA_FLAGS
            # would add another "--upgrade" to the shared entry on every call.
            extra_flags = list(_MODULE_EXTRA_FLAGS.get(modname, []))

            if version == LATEST_VERSION:
                extra_flags.append("--upgrade")
            elif version is not None:
                pkg += version

            cmd = config.INSTALL_CMD + [pkg] + extra_flags
            G_LOGGER.info(
                "{:} is required, but not installed. Attempting to install now.\n"
                "Running: {:}".format(pkg, " ".join(cmd))
            )
            status = sp.run(cmd)
            if status.returncode != 0:
                # A failed upgrade is only a warning: the existing version may still work.
                log_func = G_LOGGER.critical if raise_error else G_LOGGER.warning
                log_func(
                    "Could not automatically install required package: {:}. Please install it manually.".format(pkg)
                )

            mod = importlib.import_module(name)
            return mod

        mod = None
        try:
            mod = importlib.import_module(name)
        except Exception:
            # Deliberately broader than ImportError, since a module can fail to
            # import for many reasons, but no longer a bare ``except`` so that
            # KeyboardInterrupt/SystemExit are not turned into an install attempt.
            if config.AUTOINSTALL_DEPS:
                mod = install_mod()
            else:
                G_LOGGER.critical(
                    "Module: {:} is required but could not be imported.\n"
                    "You can try setting POLYGRAPHY_AUTOINSTALL_DEPS=1 in your environment variables "
                    "to allow Polygraphy to automatically install missing packages.\n"
                    "Note that this may cause existing packages to be overwritten - hence, it may be "
                    "desirable to use a Python virtual environment or container. ".format(name)
                )

        # Auto-upgrade if necessary
        if version is not None and hasattr(mod, "__version__") and not _version_ok(mod.__version__, version):
            if config.AUTOINSTALL_DEPS:
                G_LOGGER.info(
                    "Note: Package: '{name}' version {cur_ver} is installed, but version {rec_ver} is recommended.\n"
                    "Upgrading...".format(name=name, cur_ver=mod.__version__, rec_ver=version)
                )
                mod = install_mod(raise_error=False)  # We can try to use the other version if install fails.
            elif version != LATEST_VERSION:
                G_LOGGER.warning(
                    "Package: '{name}' version {cur_ver} is installed, but version {rec_ver} is recommended.\n"
                    "Consider installing the recommended version or setting POLYGRAPHY_AUTOINSTALL_DEPS=1 in your "
                    "environment variables to do so automatically. ".format(
                        name=name, cur_ver=mod.__version__, rec_ver=version
                    ),
                    mode=LogMode.ONCE,
                )

        if log:
            G_LOGGER.module_info(mod)

        return mod

    class LazyModule(object):
        """Stand-in object that forwards attribute reads and writes to the real module."""

        def __getattr__(self, name):
            self = import_mod()
            return getattr(self, name)

        def __setattr__(self, name, value):
            self = import_mod()
            return setattr(self, name, value)

    return LazyModule()
def has_mod(lazy_mod, with_attr="__version__"):
    """
    Checks whether a module is available.

    Args:
        lazy_mod (LazyModule):
            A lazy module, like that returned by ``lazy_import``.
        with_attr (str):
            The name of an attribute to check for.
            This helps distinguish mock modules from real ones.

    Returns:
        bool:
            True if reading ``with_attr`` from ``lazy_mod`` succeeds,
            False if doing so raises an exception.
    """
    try:
        getattr(lazy_mod, with_attr)
    except Exception:
        # Any ordinary failure means "not available". KeyboardInterrupt and
        # SystemExit are no longer swallowed, so Ctrl-C is not reported as a
        # missing module.
        # NOTE(review): assumes the logger's critical() signals failure with an
        # Exception subclass - confirm against polygraphy.logger.
        return False
    return True
def import_from_script(path, name):
    """
    Imports a specified symbol from a Python script.

    The script's directory is placed at the front of ``sys.path`` for the
    duration of the import and removed again afterwards.

    Args:
        path (str): A path to the Python script. The path must include a '.py' extension.
        name (str): The name of the symbol to import from the script.

    Returns:
        object: The loaded symbol.
    """
    script_dir = os.path.dirname(path)
    modname, ext = os.path.splitext(os.path.basename(path))

    sys.path.insert(0, script_dir)
    try:
        # The script may have been written after the interpreter started, in
        # which case cached directory listings could hide it from the import system.
        importlib.invalidate_caches()
        # NOTE: import_module returns an already-imported module of the same
        # name from sys.modules, so distinct scripts sharing a basename are not
        # distinguished here.
        mod = importlib.import_module(modname)
        return getattr(mod, name)
    except Exception as err:
        # Only the failure path needs the logger, so it is imported here.
        from polygraphy.logger import G_LOGGER

        err_msg = "Could not import symbol: {:} from script: {:}".format(name, path)
        if ext != ".py":
            err_msg += "\nThis could be because the extension of the file is not '.py'. Note: The extension is: {:}".format(
                ext
            )
        err_msg += "\nNote: Error was: {:}".format(err)
        G_LOGGER.critical(err_msg)
    finally:
        # Remove by value rather than ``del sys.path[0]``: the script may have
        # prepended its own entries, in which case index 0 is no longer ours.
        try:
            sys.path.remove(script_dir)
        except ValueError:
            pass