# Source code for polygraphy.mod.exporter

# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import inspect
import sys
import warnings
from textwrap import dedent

import polygraphy
from polygraphy import config
from polygraphy.logger import G_LOGGER
from polygraphy.mod.util import version

def _add_to_all(symbol, module):
    if hasattr(module, "__all__"):
        module.__all__ = [symbol]

def _define_in_module(name, symbol, module):
    """
    Binds ``symbol`` to ``name`` in ``module``'s namespace and exports it.

    Args:
        name (str): The attribute name to create in the module.
        symbol (object): The object to bind.
        module (module): The target module. ``name`` must not already be defined there.
    """
    namespace = vars(module)
    assert name not in namespace, "This symbol is already defined!"
    namespace[name] = symbol
    _add_to_all(name, module)

def export(funcify=False, func_name=None):
    """
    Decorator that exports a symbol into the ``__all__`` attribute of
    the caller's module. This makes the symbol visible in a ``*`` import
    (e.g. ``from module import *``) and hides other symbols unless they are
    also present in ``__all__``.

    Args:
        funcify (bool):
                Whether to create and export a function that will call a decorated Polygraphy loader.
                The decorated type *must* be a subclass of ``BaseLoader`` if ``funcify=True``.

                This is useful to provide convenient short-hands to immediately evaluate loaders.
                For example:
                ::

                    class SuperCoolModelFromPath(BaseLoader):
                        def __init__(self, init_params):
                            ...

                        def call_impl(self, call_params):
                            ...

                    # We can now magically access an immediately evaluated functional
                    # variant of the loader:
                    model = super_cool_model_from_path(init_params, call_params)

                    # Which is equivalent to:
                    load_model = SuperCoolModelFromPath(init_params)
                    model = load_model(call_params)

                The signature of the generated function is a combination of the signatures
                of ``__init__`` and ``call_impl``. Specifically, parameters without defaults will
                precede those with defaults, and ``__init__`` parameters will precede ``call_impl``
                parameters. Special parameters like ``*args`` and ``**kwargs`` will always be the last
                parameters in the generated signature if they are present in the loader method signatures.
                Keyword-only parameters of the loader methods become ordinary parameters of the
                generated function and are forwarded to the loader by keyword.
                The return value(s) will always come from ``call_impl``.

                For example:
                ::

                    # With __init__ signature:
                    def __init__(a, b=0) -> None:

                    # And call_impl signature:
                    def call_impl(c, d=0) -> z:

                    # The generated function will have a signature:
                    def generated(a, c, b=0, d=0) -> z:

        func_name (str):
                If funcify is True, this controls the name of the generated function.
                By default, the exported function will use the same name as the loader, but
                ``snake_case`` instead of ``PascalCase``.

    Returns:
        Callable:
                A decorator that adds its argument's name to the caller module's ``__all__``
                (and, with ``funcify=True``, defines the generated function there) and returns
                the decorated object unchanged.
    """
    module = inspect.getmodule(sys._getframe(1))

    # Find a method by walking the inheritance hierarchy (MRO) of a type.
    def find_method(symbol, method):
        for ancestor in inspect.getmro(symbol):
            if method in vars(ancestor):
                return vars(ancestor)[method]

        # Raise explicitly rather than ``assert False`` so the check survives ``python -O``.
        raise AssertionError(
            f"Could not find method: {method} in the inheritance hierarcy of: {symbol}"
        )

    def export_impl(func_or_cls):
        _add_to_all(func_or_cls.__name__, module)

        if funcify:
            # We only support funcify-ing BaseLoaders, and only if __init__ and call_impl
            # have no overlapping parameters.
            from polygraphy.backend.base import BaseLoader

            if not inspect.isclass(func_or_cls):
                raise AssertionError("Decorated type must be a loader to use funcify=True")
            if BaseLoader not in inspect.getmro(func_or_cls):
                raise AssertionError(
                    "Decorated type must derive from BaseLoader to use funcify=True"
                )

            def get_params(method):
                # Drop the leading ``self`` parameter of the (unbound) method.
                return list(
                    inspect.signature(find_method(func_or_cls, method)).parameters.values()
                )[1:]

            def is_variadic(param):
                return param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD)

            def has_default(param):
                # Identity check: ``!=`` is ambiguous for defaults such as NumPy arrays.
                return param.default is not param.empty

            def get_param_name(p):
                # Variadic parameters keep their ``*``/``**`` prefix so the same text is valid
                # both in the generated signature and in the forwarding call.
                if p.kind == p.VAR_POSITIONAL:
                    return f"*{p.name}"
                if p.kind == p.VAR_KEYWORD:
                    return f"**{p.name}"
                return p.name

            loader = func_or_cls

            init_params = get_params("__init__")
            call_impl_params = get_params("call_impl")

            # Compare bare names (without ``*``/``**``): e.g. ``*args`` in one method and ``args``
            # in the other would still collide in the generated signature.
            if {p.name for p in init_params} & {p.name for p in call_impl_params}:
                raise AssertionError(
                    "Cannot funcify a type where call_impl and __init__ have the same argument names!"
                )

            # Globals for the generated function. Default values are bound here by reference
            # instead of embedding ``str(default)`` in the generated source, which is wrong for
            # many values (the string "1" would become the int 1; "abc" a bare name).
            namespace = {"loader_binding": loader}

            def signature_entry(p):
                if not has_default(p):
                    return get_param_name(p)
                default_name = f"_polygraphy_default_{len(namespace)}"
                namespace[default_name] = p.default
                return f"{p.name}={default_name}"

            # Dynamically generate a function with the right signature: required arguments
            # (i.e. without default values) first, then those with defaults, then variadic ones
            # (``*args`` before ``**kwargs`` so the signature is valid). Within each group,
            # __init__ parameters precede call_impl parameters.
            all_params = init_params + call_impl_params
            required = [p for p in all_params if not is_variadic(p) and not has_default(p)]
            optional = [p for p in all_params if not is_variadic(p) and has_default(p)]
            variadic = sorted(
                (p for p in all_params if is_variadic(p)),
                key=lambda p: p.kind == p.VAR_KEYWORD,
            )
            signature = ", ".join(signature_entry(p) for p in required + optional + variadic)

            def forwarding_args(params):
                # Keyword-only parameters must be forwarded by keyword. Everything else keeps
                # its original position; signature order (positional, *args, keyword-only,
                # **kwargs) already yields a valid call.
                return ", ".join(
                    f"{p.name}={p.name}" if p.kind == p.KEYWORD_ONLY else get_param_name(p)
                    for p in params
                )

            init_args = forwarding_args(init_params)
            call_impl_args = forwarding_args(call_impl_params)

            def pascal_to_snake(name):
                return "".join(f"_{c.lower()}" if c.isupper() else c for c in name).lstrip("_")

            nonlocal func_name
            func_name = func_name or pascal_to_snake(loader.__name__)

            func_code = dedent(
                f"""
                def {func_name}({signature}):
                    return loader_binding({init_args})({call_impl_args})
                """
            )
            # Execute into an explicit namespace. Writing into ``locals()`` via ``exec`` is not
            # reliable inside a function; since Python 3.13 (PEP 667) every ``locals()`` call
            # returns a fresh snapshot, so a later lookup would not see the new name.
            exec(func_code, namespace)
            func = namespace[func_name]

            # Next we set up the docstring so that it is a combination of the __init__
            # and call_impl docstrings.
            func.__doc__ = f"Immediately evaluated functional variant of :class:`{loader.__name__}` .\n"

            def try_add_method_doc(method):
                method_impl = find_method(loader, method)
                if method_impl.__doc__:
                    func.__doc__ += dedent(method_impl.__doc__)

            try_add_method_doc("__init__")
            try_add_method_doc("call_impl")

            # ``exec``'s namespace has no ``__name__``, so attribute the function to the
            # caller's module for introspection/documentation tools.
            if module is not None:
                func.__module__ = module.__name__

            # Now that the function has been defined, we just need to add it into the module's
            # __dict__ so it is accessible like a normal symbol.
            _define_in_module(func_name, func, module)

        # We don't actually want to modify the decorated object.
        return func_or_cls

    return export_impl

def warn_deprecated(
    name, use_instead, remove_in, module_name=None, always_show_warning=False
):
    """
    Issues a ``DeprecationWarning`` for ``name``.

    Args:
        name (str):
                The name of the deprecated object.
        use_instead (str):
                What to use instead. If ``None``, the message omits this hint.
        remove_in (str):
                The Polygraphy version in which the object will be removed.
        module_name (str):
                The name of the containing module, used to qualify ``name`` in the message.
                Defaults to None.
        always_show_warning (bool):
                Whether to additionally log the message through ``G_LOGGER``, so that it is
                visible even if Python's warning filters hide ``DeprecationWarning``.
                Defaults to False.
    """
    # A deprecated object still present at or after its removal version is an internal error.
    if version(polygraphy.__version__) >= version(remove_in):
        G_LOGGER.internal_error(
            f"{name} should have been removed in version: {remove_in}"
        )

    full_obj_name = f"{module_name}.{name}" if module_name else name
    msg = (
        f"{full_obj_name} is deprecated and will be removed in Polygraphy {remove_in}."
    )
    if use_instead is not None:
        msg += f" Use {use_instead} instead."

    # stacklevel=3 skips this helper and the deprecated wrapper that called it,
    # attributing the warning to the user's call site.
    warnings.warn(msg, DeprecationWarning, stacklevel=3)
    if always_show_warning:
        G_LOGGER.warning(msg)

def deprecate(remove_in, use_instead, module_name=None, name=None):
    """
    Decorator that marks a function or class as deprecated.
    When the function or class is used, a warning will be issued.

    Args:
        remove_in (str):
                The version in which the decorated type will be removed.
        use_instead (str):
                The function or class to use instead.
        module_name (str):
                The name of the containing module. This will be used to
                generate more informative warnings.
                Defaults to None.
        name (str):
                The name of the object being deprecated.
                If not provided, this is automatically determined based on the decorated type.
                Defaults to None.

    Returns:
        Callable:
                A decorator that accepts a module, class, or function and returns a deprecated
                stand-in. Module attribute reads/writes, class construction, or function calls
                (respectively) issue a deprecation warning before delegating to the original.
    """

    def deprecate_impl(obj):
        if config.INTERNAL_CORRECTNESS_CHECKS and version(
            polygraphy.__version__
        ) >= version(remove_in):
            G_LOGGER.internal_error(
                f"{obj} should have been removed in version: {remove_in}"
            )

        # Use a local rather than rebinding the enclosing ``name``, so reusing the same
        # decorator on several objects does not keep the first object's name.
        obj_name = name or obj.__name__

        if inspect.ismodule(obj):

            class DeprecatedModule:
                # Proxies attribute access to the real module, warning on every access.
                def __getattr__(self, attr_name):
                    warn_deprecated(obj_name, use_instead, remove_in, module_name)
                    return getattr(obj, attr_name)

                def __setattr__(self, attr_name, value):
                    warn_deprecated(obj_name, use_instead, remove_in, module_name)
                    return setattr(obj, attr_name, value)

            DeprecatedModule.__doc__ = f"Deprecated: Use {use_instead} instead"
            return DeprecatedModule()
        elif inspect.isclass(obj):

            class Deprecated(obj):
                def __init__(self, *args, **kwargs):
                    warn_deprecated(obj_name, use_instead, remove_in, module_name)
                    super().__init__(*args, **kwargs)

            Deprecated.__doc__ = f"Deprecated: Use {use_instead} instead"
            return Deprecated
        elif inspect.isfunction(obj):

            # ``wraps`` preserves the name, module and signature (via ``__wrapped__``) for introspection.
            @functools.wraps(obj)
            def wrapped(*args, **kwargs):
                warn_deprecated(obj_name, use_instead, remove_in, module_name)
                return obj(*args, **kwargs)

            wrapped.__doc__ = f"Deprecated: Use {use_instead} instead"
            return wrapped
        else:
            G_LOGGER.internal_error(f"deprecate is not implemented for: {obj}")

    return deprecate_impl

def export_deprecated_alias(name, remove_in, use_instead=None):
    """
    Decorator that creates and exports a deprecated alias for
    the decorated class or function.

    The alias will behave like the decorated type, except it will
    issue a deprecation warning when used.

    To create a deprecated alias for an entire module, invoke the
    function manually within the module like so:
    ::

        mod.export_deprecated_alias("old_mod_name", remove_in="0.0.0")(sys.modules[__name__])

    Args:
        name (str):
                The name of the deprecated alias.
        remove_in (str):
                The version, as a string, in which the deprecated alias will be removed.
        use_instead (str):
                The name of the function, class, or module to use instead.
                If this is ``None``, the new name will be automatically determined.
                Defaults to None.

    Returns:
        Callable:
                A decorator that defines the deprecated alias in the caller's module and
                returns the decorated object itself unchanged.
    """
    module = inspect.getmodule(sys._getframe(1))

    def export_deprecated_alias_impl(obj):
        new_obj = deprecate(
            remove_in=remove_in,
            use_instead=use_instead or obj.__name__,
            module_name=module.__name__,
            name=name,
        )(obj)
        _define_in_module(name, new_obj, module)
        # The original object stays usable under its own (non-deprecated) name.
        return obj

    return export_deprecated_alias_impl