Source code for polygraphy.tools.args.backend.onnxrt.loader
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from polygraphy import mod
from polygraphy.tools.args import util as args_util
from polygraphy.tools.args.base import BaseArgs
from polygraphy.tools.args.model import ModelArgs
from polygraphy.tools.args.backend.onnx.loader import OnnxLoadArgs
from polygraphy.tools.script import make_invocable
@mod.export()
class OnnxrtSessionArgs(BaseArgs):
    """
    ONNX-Runtime Session Creation: creating an ONNX-Runtime Inference Session

    Depends on:

        - OnnxLoadArgs
        - ModelArgs
    """

    # NOTE(review): the class docstring above is kept verbatim; presumably the
    # BaseArgs machinery reads its "Title: description" / "Depends on:" layout.
    # Confirm against BaseArgs before rewording it.

    def add_parser_args_impl(self):
        """
        Registers the ``--providers`` / ``--execution-providers`` option on this
        argument group. The option accepts one or more values and defaults to None.
        """
        providers_help = (
            "A list of execution providers to use in order of priority. "
            "Each provider may be either an exact match or a case-insensitive partial match "
            "for the execution providers available in ONNX-Runtime. For example, a value of 'cpu' would "
            "match the 'CPUExecutionProvider'"
        )
        self.group.add_argument(
            "--providers",
            "--execution-providers",
            dest="providers",
            help=providers_help,
            nargs="+",
            default=None,
        )

    def parse_impl(self, args):
        """
        Parses command-line arguments and populates the following attributes:

        Attributes:
            providers (List[str]): A list of execution providers.
        """
        self.providers = args_util.get(args, "providers")

    def add_to_script_impl(self, script, onnx_name=None):
        """
        Adds an import of ``SessionFromOnnx`` and a loader invoking it to the script.

        Args:
            script: The script object to which the import and loader are added.
            onnx_name: The model argument forwarded to ``SessionFromOnnx``.
                    Defaults to None, in which case it is derived from the
                    ``OnnxLoadArgs`` / ``ModelArgs`` argument groups.

        Returns:
            The value returned by ``script.add_loader`` for the session loader.
        """
        if onnx_name is None:
            # Default behavior according to self.arg_groups: go through the ONNX
            # loader when it reports that it must be used, otherwise fall back to
            # the path held by ModelArgs.
            onnx_load_args = self.arg_groups[OnnxLoadArgs]
            if onnx_load_args.must_use_onnx_loader():
                onnx_name = onnx_load_args.add_to_script(script, serialize_model=True)
            else:
                onnx_name = self.arg_groups[ModelArgs].path

        script.add_import(imports=["SessionFromOnnx"], frm="polygraphy.backend.onnxrt")
        session_invocation = make_invocable("SessionFromOnnx", onnx_name, providers=self.providers)
        return script.add_loader(session_invocation, "build_onnxrt_session")

    def load_onnxrt_session(self, model=None):
        """
        Loads an ONNX-Runtime Inference Session according to arguments provided on the command-line.

        Args:
            model (Union[bytes, str]):
                    The model bytes or path to a model.
                    Defaults to None, in which case, the model specified on the command-line is used.

        Returns:
            onnxruntime.InferenceSession
        """
        # run_script yields a callable loader; invoking it produces the session.
        build_session = args_util.run_script(self.add_to_script, model)
        return build_session()