# Source code for polygraphy.backend.pluginref.runner
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import time
from collections import OrderedDict

from polygraphy import mod, util
from polygraphy.backend.base import BaseRunner
from polygraphy.backend.pluginref.references import OP_REGISTRY
from polygraphy.logger import G_LOGGER

# Lazily imported to avoid paying the import cost (and requiring the
# dependency) unless the runner is actually used.
np = mod.lazy_import("numpy")
onnx_util = mod.lazy_import("polygraphy.backend.onnx.util")
@mod.export()
class PluginRefRunner(BaseRunner):
    """
    Runs inference using custom CPU reference implementations
    """

    def __init__(self, graph, name=None):
        """
        Args:
            graph (Union[onnx_graphsurgeon.Graph, Callable() -> onnx_graphsurgeon.Graph]):
                    An ONNX-GraphSurgeon graph or a callable that returns one.

            name (str):
                    The human-readable name prefix to use for this runner.
                    A runner count and timestamp will be appended to this prefix.
        """
        super().__init__(name=name, prefix="pluginref-runner")
        self._graph = graph

    @util.check_called_by("activate")
    def activate_impl(self):
        # The graph may be provided either directly or as a loader callable;
        # resolve it here so inference can use `self.graph`.
        self.graph, _ = util.invoke_if_callable(self._graph)

    @util.check_called_by("get_input_metadata")
    def get_input_metadata_impl(self):
        # Derive input metadata from the graph's input tensors.
        return onnx_util.meta_from_gs_tensors(self.graph.inputs)

    @util.check_called_by("infer")
    def infer_impl(self, feed_dict):
        start_time = time.time()

        # Shallow-copy the feed dict so graph inputs seed the tensor map
        # without mutating the caller's dictionary.
        tensor_map = copy.copy(feed_dict)
        for node in self.graph.nodes:
            if node.op not in OP_REGISTRY:
                G_LOGGER.critical(f"Op: {node.op} does not have a reference implementation registered!")
            # Each reference implementation returns a mapping of the node's
            # output tensor names to computed values.
            tensor_map.update(OP_REGISTRY[node.op](node, tensor_map))

        # Collect only the graph outputs, preserving their declared order.
        outputs = OrderedDict((out.name, tensor_map[out.name]) for out in self.graph.outputs)

        self.inference_time = time.time() - start_time
        return outputs

    @util.check_called_by("deactivate")
    def deactivate_impl(self):
        # Drop the resolved graph so deactivation releases it.
        del self.graph