# Source code for polygraphy.backend.base.runner

#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import time
from collections import defaultdict

from polygraphy import config, func, mod, util
from polygraphy.logger import G_LOGGER, LogMode

# NumPy is bound through mod.lazy_import rather than a regular import statement
# (presumably this defers the real import until first use - confirm in polygraphy.mod).
np = mod.lazy_import("numpy")


@mod.export()
class BaseRunner(object):
    """
    Base class for Polygraphy runners. All runners should override the functions and attributes specified here.

    Derived classes are expected to override the ``*_impl`` methods (``activate_impl``,
    ``infer_impl``, ``get_input_metadata_impl``, ``deactivate_impl``); the public
    wrappers defined here add the activation bookkeeping and input checks around them.
    """

    # Number of auto-named runners created so far, keyed by name prefix.
    RUNNER_COUNTS = defaultdict(int)

    def __init__(self, name=None, prefix=None):
        """
        Args:
            name (str):
                    The name to use for this runner.
            prefix (str):
                    The human-readable name prefix to use for this runner.
                    A runner count and timestamp will be appended to this prefix.
                    Only used if name is not provided.
        """
        prefix = util.default(prefix, "Runner")
        if name is None:
            count = BaseRunner.RUNNER_COUNTS[prefix]
            BaseRunner.RUNNER_COUNTS[prefix] += 1
            name = "{:}-N{:}-{:}-{:}".format(prefix, count, time.strftime("%x"), time.strftime("%X"))
        self.name = name
        self.inference_time = None

        self.is_active = False
        """bool: Whether this runner has been activated, either via context manager, or by calling ``activate()``."""

    @func.constantmethod
    def last_inference_time(self):
        """
        Returns the total inference time required during the last call to ``infer()``.

        Returns:
            float: The time in seconds, or None if runtime was not measured by the runner.
        """
        if self.inference_time is None:
            # Each fragment ends with a space: adjacent string literals are joined
            # verbatim, so without it the sentences run together in the log output.
            G_LOGGER.warning(
                "{:35} | inference_time was not set. Inference time will be incorrect! "
                "To correctly compare runtimes, please set the inference_time property in the "
                "infer() function".format(self.name),
                mode=LogMode.ONCE,
            )
            return None
        return self.inference_time

    def __enter__(self):
        """
        Activate the runner for inference. This may involve allocating GPU buffers, for example.

        Returns:
            BaseRunner: This runner, so it can be bound with ``with ... as runner``.
        """
        self.activate()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Deactivate the runner.

        If the POLYGRAPHY_INTERNAL_CORRECTNESS_CHECKS environment variable is set to `1`, this
        will also check that the runner was reset to its state prior to activation.

        Nothing is returned, so an exception raised inside the ``with`` body is not suppressed.
        """
        self.deactivate()

    def activate_impl(self):
        """
        Implementation for runner activation. Derived classes should override this function
        rather than ``activate()``.
        """
        pass

    def activate(self):
        """
        Activate the runner for inference. This may involve allocating GPU buffers, for example.

        If the runner is already active, a warning is logged and nothing else happens.

        Generally, you should use a context manager instead of manually activating and deactivating.
        For example:
        ::

            with RunnerType(...) as runner:
                runner.infer(...)
        """
        if self.is_active:
            G_LOGGER.warning(
                "{:35} | Already active; will not activate again. If you really want to "
                "activate this runner again, call activate_impl() directly".format(self.name)
            )
            return

        if config.INTERNAL_CORRECTNESS_CHECKS:
            # Shallow snapshot of the instance attributes; deactivate() compares against it.
            self._pre_activate_runner_state = copy.copy(vars(self))

        try:
            self.activate_impl()
        except BaseException:
            # Drop the snapshot on failure. Otherwise a retried activate() would capture the
            # stale snapshot inside the new one, and the next deactivate() would report a
            # spurious state mismatch.
            vars(self).pop("_pre_activate_runner_state", None)
            raise

        self.is_active = True

    def infer_impl(self, feed_dict):
        """
        Implementation for runner inference. Derived classes should override this function
        rather than ``infer()``

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("BaseRunner is an abstract class")

    def infer(self, feed_dict, check_inputs=True):
        """
        Runs inference using the provided feed_dict.

        Args:
            feed_dict (OrderedDict[str, numpy.ndarray]):
                    A mapping of input tensor names to corresponding input NumPy arrays.

            check_inputs (bool):
                    Whether to check that the provided ``feed_dict`` includes the expected inputs
                    with the expected data types and shapes.

        Returns:
            OrderedDict[str, numpy.ndarray]:
                    A mapping of output tensor names to their corresponding NumPy arrays.

                    IMPORTANT: Runners may reuse these output buffers. Thus, if you need to save
                    outputs from multiple inferences, you should make a copy with ``copy.deepcopy(outputs)``.
        """
        if not self.is_active:
            G_LOGGER.critical("{:35} | Must be activated prior to calling infer()".format(self.name))

        if check_inputs:
            input_metadata = self.get_input_metadata()
            G_LOGGER.verbose("Runner input metadata is: {:}".format(input_metadata))

            # Name check first; problems are reported through G_LOGGER.critical.
            util.check_dict_contains(
                feed_dict, input_metadata.keys(), dict_name="feed_dict", log_func=G_LOGGER.critical
            )

            # Then validate every provided array against the runner's metadata entry.
            for name, inp in feed_dict.items():
                meta = input_metadata[name]
                if not np.issubdtype(inp.dtype, meta.dtype):
                    G_LOGGER.critical(
                        "Input tensor: {:} | Received unexpected dtype: {:}.\n"
                        "Note: Expected type: {:}".format(name, inp.dtype, meta.dtype)
                    )

                if not util.is_valid_shape_override(inp.shape, meta.shape):
                    G_LOGGER.critical(
                        "Input tensor: {:} | Received incompatible shape: {:}.\n"
                        "Note: Expected a shape compatible with: {:}".format(name, inp.shape, meta.shape)
                    )

        return self.infer_impl(feed_dict)

    @func.constantmethod
    def get_input_metadata_impl(self):
        """
        Implementation for `get_input_metadata`. Derived classes should override this function
        rather than `get_input_metadata`.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("BaseRunner is an abstract class")

    def get_input_metadata(self):
        """
        Returns information about the inputs of the model.
        Shapes here may include dynamic dimensions, represented by ``None``.
        Must be called only after activate() and before deactivate().

        Returns:
            TensorMetadata: Input names, shapes, and data types.
        """
        return self.get_input_metadata_impl()

    def deactivate_impl(self):
        """
        Implementation for runner deactivation. Derived classes should override this function
        rather than ``deactivate()``.
        """
        pass

    def deactivate(self):
        """
        Deactivate the runner.

        If the runner is not active, a warning is logged and nothing else happens.

        If the POLYGRAPHY_INTERNAL_CORRECTNESS_CHECKS environment variable is set to `1`, this
        will also check that the runner was reset to its state prior to activation.

        Generally, you should use a context manager instead of manually activating and deactivating.
        For example:
        ::

            with RunnerType(...) as runner:
                runner.infer(...)
        """
        if not self.is_active:
            G_LOGGER.warning(
                "{:35} | Not active; will not deactivate. If you really want to "
                "deactivate this runner, call deactivate_impl() directly".format(self.name)
            )
            return

        self.inference_time = None
        # Set to None (not False) before calling deactivate_impl(): if that call raises,
        # is_active stays None. None is falsy, so __del__ does not warn and activate()
        # may be attempted again.
        self.is_active = None

        self.deactivate_impl()

        # Reached only when deactivate_impl() completed without raising.
        self.is_active = False
        if config.INTERNAL_CORRECTNESS_CHECKS:
            old_state = self._pre_activate_runner_state
            # The snapshot was taken before this attribute existed, so remove it
            # before comparing the current attributes against the snapshot.
            del self._pre_activate_runner_state
            if old_state != vars(self):
                G_LOGGER.internal_error(
                    "Runner state was not reset after deactivation. "
                    "Note:\nOld state: {:}\nNew state: {:}".format(old_state, vars(self))
                )

    def __del__(self):
        # __del__ also runs for instances whose __init__ never completed (for example when a
        # derived constructor raises before calling BaseRunner.__init__), in which case
        # is_active was never assigned - hence getattr with a default.
        if getattr(self, "is_active", False):
            # __del__ is not guaranteed to be called, but when it is, this could be a useful warning.
            print("[W] {:35} | Was activated but never deactivated. This could cause a memory leak!".format(self.name))