# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import importlib
import importlib.util
import os
import subprocess as sp
import sys
from typing import List

try:
    # Available in Python 3.8+
    import importlib.metadata
except ModuleNotFoundError:
    import pkg_resources

from polygraphy import constants
from polygraphy.mod import util as mod_util

# Tracks all of Polygraphy's lazy imports, excluding internal ones.
_all_external_lazy_imports = set()

# Sometimes the Python package name differs from the module name.
# Maps a top-level module name to the package that provides it; `lazy_import` falls back to the
# module name itself when there is no entry here and no explicit `pkg_name` was given.
# NOTE(review): no entries survive in this copy of the file - restore any module -> package overrides here.
_PKG_NAME_FROM_MODULE = {}

# Some packages need additional flags to install correctly.
# Maps a top-level module name to extra flags for the installation command; used by `lazy_import`
# when no explicit `install_flags` were given.
_EXTRA_FLAGS_FOR_MODULE = {
    # NOTE(review): the index URL after `=` is empty here - confirm the intended value.
    "onnx_graphsurgeon": ["--extra-index-url="],
}

LATEST_VERSION = "==latest"
"""Indicates that the latest version of the package is preferred in lazy_import"""

def _version_ok(ver, preferred):
    """
    Checks whether an installed version satisfies a preferred version specifier.

    Args:
        ver (str): The installed version, e.g. ``"1.8.0"``.
        preferred (str):
                A comparison operator (``==``, ``>=``, ``>``, ``<=``, ``<``) followed by a version,
                e.g. ``">=0.5.0"``, or ``LATEST_VERSION``.

    Returns:
        bool:
                Whether ``ver`` satisfies ``preferred``. Always False for ``LATEST_VERSION``,
                so that an upgrade to the latest version is always attempted.

    Raises:
        KeyError: If the operator in ``preferred`` is not one of the supported operators.
    """
    if preferred == LATEST_VERSION:
        return False

    spec = preferred.strip()
    pref_ver = spec.lstrip("<=>").strip()
    # `pref_ver` is a suffix of `spec`, so everything in front of it is the operator.
    # (Slicing is used instead of `rstrip(pref_ver)`, which strips a *set of characters*.)
    cond = spec[: len(spec) - len(pref_ver)].strip()
    check = {
        "==": lambda x, y: x == y,
        ">=": lambda x, y: x >= y,
        ">": lambda x, y: x > y,
        "<=": lambda x, y: x <= y,
        "<": lambda x, y: x < y,
    }[cond]
    return check(mod_util.version(ver), mod_util.version(pref_ver))

def lazy_import(
    name: str,
    log: bool = None,
    pkg_name: str = None,
    install_flags: List[str] = None,
    requires: List[str] = None,
):
    """
    Lazily import a module.

    If config.AUTOINSTALL_DEPS is set to 1,
    missing modules are automatically installed, and existing modules may be
    upgraded if newer versions are required.

    Args:
        name (str):
                The name of the module and optionally the preferred version of the package,
                formatted as a version string.
                For example, ``'example_module>=0.5.0'`` or ``'example_module==1.8.0'``.
        log (bool):
                Whether to log information about the module.
                Defaults to True.
        pkg_name (str):
                The name of the package that provides this module, if it is different from the module name.
                Used only if automatic installation of dependencies is enabled.
        install_flags (List[str]):
                Additional flags to provide to the installation command.
                Used only if automatic installation of dependencies is enabled.
        requires (List[str]):
                Additional dependencies required by the module which are *not* specified as dependencies.
                This parameter should only be required when a module does not correctly specify dependencies.
                Defaults to [].

    Returns:
        LazyModule:
                A lazily loaded module. When an attribute is first accessed,
                the module will be imported.
    """
    VERSION_CHARS = ["=", ">", "<"]

    log = True if log is None else log
    requires = [] if requires is None else requires

    def split_name_version(inp):
        # Splits e.g. "example_module>=0.5.0" into ("example_module", ">=0.5.0").
        version_char_indices = [inp.index(char) for char in VERSION_CHARS if char in inp]
        if not version_char_indices:
            return inp, None
        min_index = min(version_char_indices)
        return inp[:min_index], inp[min_index:]

    name, version = split_name_version(name)
    # The primary module goes last, so the install loop below leaves `mod` pointing at it.
    all_required_mods = list(map(split_name_version, requires)) + [(name, version)]

    if "polygraphy" not in name:
        _all_external_lazy_imports.add(name)

    def get_installed_version():
        # Returns the installed distribution version of `name`, or None if it cannot be determined.
        # Tries the module name first, then the package name (which may differ, see `_PKG_NAME_FROM_MODULE`).
        modname = name.split(".")[0]
        pkg = pkg_name if pkg_name is not None else _PKG_NAME_FROM_MODULE.get(modname, modname)
        metadata = getattr(importlib, "metadata", None)  # Only available in Python 3.8+
        for dist_name in dict.fromkeys([name, pkg]):
            if metadata is not None:
                try:
                    return metadata.version(dist_name)
                except metadata.PackageNotFoundError:
                    continue
            try:
                # This may fail if setuptools is not installed.
                return pkg_resources.get_distribution(dist_name).version
            except Exception:
                continue
        return None

    def import_mod():
        from polygraphy import config
        from polygraphy.logger import G_LOGGER, LogMode

        def install_mod(install_name, install_version, raise_error=True):
            modname = install_name.split(".")[0]
            pkg = pkg_name if pkg_name is not None else _PKG_NAME_FROM_MODULE.get(modname, modname)
            # Copy so that appending "--upgrade" below never mutates the caller's `install_flags`
            # or the shared `_EXTRA_FLAGS_FOR_MODULE` entry (which would accumulate across calls).
            extra_flags = list(
                install_flags if install_flags is not None else _EXTRA_FLAGS_FOR_MODULE.get(modname, [])
            )

            def fail():
                log_func = G_LOGGER.critical if raise_error else G_LOGGER.warning
                log_func(f"Could not automatically install required module: {pkg}. Please install it manually.")

            if config.ASK_BEFORE_INSTALL:
                res = None
                while res not in ["y", "n"]:
                    res = input(f"Automatically install '{pkg}' (version: {install_version or 'any'}) ([Y]/n)? ")
                    res = res.strip()[:1].lower() or "y"

                if res == "n":
                    fail()
                    # The user declined, so never run the installer. If `fail()` did not abort,
                    # fall back to whatever version of the module is currently importable.
                    return importlib.import_module(install_name)

            if install_version == LATEST_VERSION:
                extra_flags.append("--upgrade")
            elif install_version is not None:
                pkg += install_version

            cmd = config.INSTALL_CMD + [pkg] + extra_flags
            G_LOGGER.info(f"Running installation command: {' '.join(cmd)}")
            status = sp.run(cmd, stdout=sp.PIPE, stderr=sp.PIPE)
            if status.returncode != 0:
                G_LOGGER.error(f"Error during installation:\n{constants.TAB}{status.stderr.decode()}")
                fail()

            mod = importlib.import_module(install_name)
            return mod

        mod = None
        try:
            mod = importlib.import_module(name)
        except ImportError as err:
            if config.AUTOINSTALL_DEPS:
                for install_name, install_version in all_required_mods:
                    G_LOGGER.info(
                        f"Module: '{install_name}' is required, but not installed. Attempting to install now."
                    )
                    mod = install_mod(install_name, install_version)
            else:
                G_LOGGER.critical(
                    f"Module: '{name}' is required but could not be imported.\nNote: Error was: {err}\n"
                    f"You can set POLYGRAPHY_AUTOINSTALL_DEPS=1 in your environment variables to allow "
                    f"Polygraphy to automatically install missing modules.\n"
                )

        # Auto-upgrade if necessary
        for install_name, install_version in all_required_mods:
            # NOTE(review): `import_module` returns the cached sys.modules entry for modules that were
            # already imported, so an in-process upgrade may not take effect until restart - confirm.
            installed_mod = importlib.import_module(install_name)
            if (
                install_version is not None
                and hasattr(installed_mod, "__version__")
                and not _version_ok(installed_mod.__version__, install_version)
            ):
                if config.AUTOINSTALL_DEPS:
                    G_LOGGER.info(
                        f"Note: Module: '{install_name}' version '{installed_mod.__version__}' is installed, but version '{install_version}' is required.\n"
                        f"Attempting to upgrade now."
                    )
                    # We can try to use the other version if install fails, so this is non-fatal.
                    installed_mod = install_mod(install_name, install_version, raise_error=False)
                    if install_name == name:
                        mod = installed_mod
                elif install_version != LATEST_VERSION:
                    G_LOGGER.error(
                        f"Module: '{install_name}' version '{installed_mod.__version__}' is installed, but version '{install_version}' is required.\n"
                        f"Please install the required version or set POLYGRAPHY_AUTOINSTALL_DEPS=1 in your environment variables "
                        f"to allow Polygraphy to do so automatically.\n"
                        f"Attempting to continue with the currently installed version of this module, but note that this may cause errors!",
                        mode=LogMode.ONCE,
                    )

        if log:
            G_LOGGER.module_info(mod)

        return mod

    MODULE_VAR_NAME = "module"

    class LazyModule:
        def __init__(self):
            # Bypass our own __setattr__, which would otherwise trigger the import.
            super().__setattr__(MODULE_VAR_NAME, None)

        def __polygraphy_import_mod(self):
            if self.module is None:
                super().__setattr__(MODULE_VAR_NAME, import_mod())
            return self.module

        def __getattr__(self, attr_name):
            module = self.__polygraphy_import_mod()
            return getattr(module, attr_name)

        def __setattr__(self, attr_name, value):
            module = self.__polygraphy_import_mod()
            return setattr(module, attr_name, value)

        def is_installed(self):
            """
            Checks whether this module is installed and has the correct version.
            The module will not be imported by this method.

            Returns:
                bool: Whether the module is installed and has the correct version.
            """
            try:
                has_mod = name in sys.modules or (importlib.util.find_spec(name) is not None)
            except (ValueError, ModuleNotFoundError):
                # ModuleNotFoundError: `find_spec` on a dotted name whose parent package is missing.
                return False

            if not has_mod:
                return False

            if version is None:
                return True

            installed_version = get_installed_version()
            if installed_version is None:
                # We cannot determine the version. Assume that we have the correct version of the
                # package - if we don't, a warning will be issued when we try to import.
                return True
            return _version_ok(installed_version, version)

        def is_importable(self):
            """
            Checks whether this module is importable.
            Note that a module may be installed but not importable.

            Returns:
                bool: Whether the module is importable.
            """
            try:
                importlib.import_module(name)
                return True
            except Exception:
                return False

    return LazyModule()
def has_mod(modname):
    """
    Reports whether a module is available, without importing the module itself.

    Args:
        modname (str): The name of the module to check.

    Returns:
        bool:
                True if the module is already loaded in ``sys.modules`` or an import spec
                can be found for it; False otherwise.
    """
    if modname in sys.modules:
        return True

    try:
        spec = importlib.util.find_spec(modname)
    except ValueError:
        # `find_spec` may raise ValueError (e.g. for a module whose `__spec__` is missing);
        # treat that as "not installed".
        return False
    return spec is not None
def autoinstall(lazy_mod):
    """
    If the config.AUTOINSTALL_DEPS is set to 1, automatically install or upgrade a module.
    Does nothing if autoinstallation is disabled.

    Args:
        lazy_mod (LazyModule):
                A lazy module, like that returned by ``lazy_import``.
    """
    from polygraphy import config

    if not config.AUTOINSTALL_DEPS:
        return

    # It doesn't matter which attribute we try to get as any call to `__getattr__` will
    # trigger the automatic installation. The lookup of this fake attribute is expected to fail
    # afterwards, so errors are suppressed - but only `Exception`s, so that e.g. Ctrl-C
    # (KeyboardInterrupt) at an installation prompt still propagates.
    with contextlib.suppress(Exception):
        getattr(lazy_mod, "__fake_polygraphy_autoinstall_attr")
def import_from_script(path, name):
    """
    Imports a specified symbol from a Python script.

    Args:
        path (str): A path to the Python script. The path must include a '.py' extension.
        name (str): The name of the symbol to import from the script.

    Returns:
        object: The loaded symbol.
    """
    script_dir = os.path.dirname(path)
    modname = os.path.splitext(os.path.basename(path))[0]

    # Temporarily put the script's directory first on sys.path so it can be imported by module name.
    sys.path.insert(0, script_dir)
    try:
        # NOTE(review): `import_module` caches by module name in sys.modules, so two scripts with the
        # same file name in different directories resolve to whichever was imported first.
        mod = importlib.import_module(modname)
        return getattr(mod, name)
    except Exception as err:
        from polygraphy.logger import G_LOGGER

        ext = os.path.splitext(path)[1]
        err_msg = f"Could not import symbol: {name} from script: {path}"
        if ext != ".py":
            err_msg += (
                f"\nThis could be because the extension of the file is not '.py'. Note: The extension is: {ext}"
            )
        err_msg += f"\nNote: Error was: {err}"
        err_msg += f"\nNote: sys.path was: {sys.path}"
        G_LOGGER.critical(err_msg)
    finally:
        # Remove our entry by value rather than `del sys.path[0]`: the script may have inserted
        # its own entries at the front of sys.path while it was being imported.
        try:
            sys.path.remove(script_dir)
        except ValueError:
            pass  # The script itself already removed it.