#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from polygraphy import mod, util
from polygraphy.logger.logger import G_LOGGER
from polygraphy.tools.args import util as args_util
from polygraphy.tools.args.base import BaseArgs
from polygraphy.tools.args.model import ModelArgs
from polygraphy.tools.args.backend.onnx.loader import OnnxLoadArgs
from polygraphy.tools.args.backend.trt.config import TrtConfigArgs
from polygraphy.tools.script import make_invocable


@mod.export()
class TrtLoadPluginsArgs(BaseArgs):
    """
    TensorRT Plugin Loading: loading TensorRT plugins.
    """

    def add_parser_args_impl(self):
        self.group.add_argument("--plugins", help="Path(s) of plugin libraries to load", nargs="+", default=None)

    def parse_impl(self, args):
        """
        Parses command-line arguments and populates the following attributes:

        Attributes:
            plugins (List[str]): Path(s) to plugin libraries.
        """
        self.plugins = args_util.get(args, "plugins")

    # If plugins are present, wrap the provided loader/object with LoadPlugins
    def add_to_script_impl(self, script, loader_name: str):
        """
        Args:
            loader_name (str):
                The name of the loader which should be consumed by the ``LoadPlugins`` loader.

        Returns:
            str: The name of the ``LoadPlugins`` loader added to the script, or ``loader_name``
                unchanged if no plugins were specified.
        """
        if self.plugins:
            script.add_import(imports=["LoadPlugins"], frm="polygraphy.backend.trt")
            loader_str = make_invocable("LoadPlugins", plugins=self.plugins, obj=loader_name)
            loader_name = script.add_loader(loader_str, "load_plugins")
        return loader_name
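
# A hedged sketch (not part of the original module) of what TrtLoadPluginsArgs emits
# when `--plugins` is supplied; the generated script wraps the upstream loader roughly as:
#
#     from polygraphy.backend.trt import LoadPlugins
#
#     load_plugins = LoadPlugins(plugins=["./libcustom_plugin.so"], obj=parse_network_from_onnx)
#
# where "./libcustom_plugin.so" is a hypothetical plugin library path and
# `parse_network_from_onnx` a hypothetical upstream loader variable.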


@mod.export()
class TrtLoadNetworkArgs(BaseArgs):
    """
    TensorRT Network Loading: loading TensorRT networks.

    Depends on:

        - ModelArgs
        - TrtLoadPluginsArgs
        - OnnxLoadArgs: if allow_onnx_loading == True
    """

    def __init__(self, allow_custom_outputs: bool = None, allow_onnx_loading: bool = None):
        """
        Args:
            allow_custom_outputs (bool):
                Whether to allow marking custom output tensors.
                Defaults to True.
            allow_onnx_loading (bool):
                Whether to allow parsing networks from an ONNX model.
                Defaults to True.
        """
        super().__init__()
        self._allow_custom_outputs = util.default(allow_custom_outputs, True)
        self._allow_onnx_loading = util.default(allow_onnx_loading, True)

    def add_parser_args_impl(self):
        self.group.add_argument(
            "--explicit-precision",
            help="[DEPRECATED] Enable explicit precision mode",
            action="store_true",
            default=None,
        )
        if self._allow_custom_outputs:
            self.group.add_argument(
                "--trt-outputs",
                help="Name(s) of TensorRT output(s). "
                "Using '--trt-outputs mark all' indicates that all tensors should be used as outputs",
                nargs="+",
                default=None,
            )
            self.group.add_argument(
                "--trt-exclude-outputs",
                help="[EXPERIMENTAL] Name(s) of TensorRT output(s) to unmark as outputs.",
                nargs="+",
                default=None,
            )
        self.group.add_argument(
            "--trt-network-func-name",
            help="When using a trt-network-script instead of other model types, this specifies the name "
            "of the function that loads the network. Defaults to `load_network`.",
            default="load_network",
        )
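
    # Illustrative command-line usage for the flags above (tensor/function names are hypothetical):
    #
    #     --trt-outputs out0 out1           # mark specific tensors as outputs
    #     --trt-outputs mark all            # mark every tensor as an output
    #     --trt-exclude-outputs out1        # unmark a tensor (experimental)
    #     --trt-network-func-name my_func   # entry point in a trt-network-script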

    def parse_impl(self, args):
        """
        Parses command-line arguments and populates the following attributes:

        Attributes:
            outputs (List[str]): Names of output tensors.
            exclude_outputs (List[str]): Names of tensors which should be unmarked as outputs.
            trt_network_func_name (str): The name of the function in a custom network script that creates the network.
        """
        self.outputs = args_util.get_outputs(args, "trt_outputs")

        self.explicit_precision = args_util.get(args, "explicit_precision")
        if self.explicit_precision is not None:
            mod.warn_deprecated("--explicit-precision", use_instead=None, remove_in="0.42.0", always_show_warning=True)

        self.exclude_outputs = args_util.get(args, "trt_exclude_outputs")
        self.trt_network_func_name = args_util.get(args, "trt_network_func_name")

    def add_to_script_impl(self, script):
        model_file = self.arg_groups[ModelArgs].path
        model_type = self.arg_groups[ModelArgs].model_type
        outputs = args_util.get_outputs_for_script(script, self.outputs)

        if model_type == "trt-network-script":
            script.add_import(imports=["InvokeFromScript"], frm="polygraphy.backend.common")
            loader_str = make_invocable("InvokeFromScript", model_file, name=self.trt_network_func_name)
            loader_name = script.add_loader(loader_str, "load_network")
        elif self._allow_onnx_loading:
            if self.arg_groups[OnnxLoadArgs].must_use_onnx_loader(disable_custom_outputs=True):
                # When loading from ONNX, we need to disable custom outputs since TRT requires dtypes on outputs,
                # which our marking function doesn't guarantee.
                script.add_import(imports=["NetworkFromOnnxBytes"], frm="polygraphy.backend.trt")
                onnx_loader = self.arg_groups[OnnxLoadArgs].add_to_script(
                    script, disable_custom_outputs=True, serialize_model=True
                )
                loader_str = make_invocable(
                    "NetworkFromOnnxBytes",
                    self.arg_groups[TrtLoadPluginsArgs].add_to_script(script, onnx_loader),
                    explicit_precision=self.explicit_precision,
                )
                loader_name = script.add_loader(loader_str, "parse_network_from_onnx")
            else:
                script.add_import(imports=["NetworkFromOnnxPath"], frm="polygraphy.backend.trt")
                loader_str = make_invocable(
                    "NetworkFromOnnxPath",
                    self.arg_groups[TrtLoadPluginsArgs].add_to_script(script, model_file),
                    explicit_precision=self.explicit_precision,
                )
                loader_name = script.add_loader(loader_str, "parse_network_from_onnx")
        else:
            G_LOGGER.internal_error("Loading from ONNX is not enabled and a network script was not provided!")

        MODIFY_NETWORK = "ModifyNetworkOutputs"
        modify_network_str = make_invocable(
            MODIFY_NETWORK, loader_name, outputs=outputs, exclude_outputs=self.exclude_outputs
        )
        if str(modify_network_str) != str(make_invocable(MODIFY_NETWORK, loader_name)):
            script.add_import(imports=[MODIFY_NETWORK], frm="polygraphy.backend.trt")
            loader_name = script.add_loader(modify_network_str, "modify_network")

        return loader_name
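
    # Hedged sketch (assumption, not from the original source) of the loader chain this
    # method emits for an ONNX model when custom outputs are also requested:
    #
    #     from polygraphy.backend.trt import NetworkFromOnnxPath, ModifyNetworkOutputs
    #
    #     parse_network_from_onnx = NetworkFromOnnxPath("model.onnx")
    #     modify_network = ModifyNetworkOutputs(parse_network_from_onnx, outputs=["out0"])
    #
    # "model.onnx" and "out0" are hypothetical placeholders.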

    def load_network(self):
        """
        Loads a TensorRT network according to arguments provided on the command-line.

        Returns:
            tensorrt.INetworkDefinition
        """
        loader = args_util.run_script(self.add_to_script)
        return loader()


@mod.export()
class TrtSaveEngineArgs(BaseArgs):
    """
    TensorRT Engine Saving: saving TensorRT engines.
    """

    def __init__(self, output_opt: str = None, output_short_opt: str = None):
        """
        Args:
            output_opt (str):
                The name of the output path option.
                Defaults to "output".
                Use a value of ``False`` to disable the option.
            output_short_opt (str):
                The short option to use for the output path.
                Defaults to "-o".
                Use a value of ``False`` to disable the short option.
        """
        super().__init__()
        self._output_opt = util.default(output_opt, "output")
        self._output_short_opt = util.default(output_short_opt, "-o")

    def add_parser_args_impl(self):
        if self._output_opt:
            params = ([self._output_short_opt] if self._output_short_opt else []) + [f"--{self._output_opt}"]
            self.group.add_argument(*params, help="Path to save the TensorRT Engine", dest="save_engine", default=None)

    def parse_impl(self, args):
        """
        Parses command-line arguments and populates the following attributes:

        Attributes:
            path (str): The path at which to save the TensorRT engine.
        """
        self.path = args_util.get(args, "save_engine")

    def add_to_script_impl(self, script, loader_name):
        """
        Args:
            loader_name (str):
                The name of the loader which should be consumed by the ``SaveEngine`` loader.

        Returns:
            str: The name of the ``SaveEngine`` loader added to the script.
        """
        if self.path is None:
            return loader_name

        script.add_import(imports=["SaveEngine"], frm="polygraphy.backend.trt")
        return script.add_loader(make_invocable("SaveEngine", loader_name, path=self.path), "save_engine")

    def save_engine(self, engine, path=None):
        """
        Saves a TensorRT engine according to arguments provided on the command-line.

        Args:
            engine (tensorrt.ICudaEngine): The TensorRT engine to save.
            path (str):
                The path at which to save the engine.
                If no path is provided, it is determined from command-line arguments.

        Returns:
            tensorrt.ICudaEngine: The engine that was saved.
        """
        with util.TempAttrChange(self, {"path": path}):
            loader = args_util.run_script(self.add_to_script, engine)
            return loader()
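
    # Hedged usage sketch (assumption, not part of the original module): from within a
    # tool whose arg groups include TrtSaveEngineArgs, an already-built engine could be
    # saved through this group, e.g.:
    #
    #     engine = ...  # a hypothetical tensorrt.ICudaEngine built elsewhere
    #     self.arg_groups[TrtSaveEngineArgs].save_engine(engine)            # path from -o/--output
    #     self.arg_groups[TrtSaveEngineArgs].save_engine(engine, "a.plan")  # explicit path override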


@mod.export()
class TrtLoadEngineArgs(BaseArgs):
    """
    TensorRT Engine: loading TensorRT engines.

    Depends on:

        - ModelArgs
        - TrtLoadPluginsArgs
        - TrtLoadNetworkArgs: if support for building engines is required
        - TrtConfigArgs: if support for building engines is required
        - TrtSaveEngineArgs: if allow_saving == True
    """

    def __init__(self, allow_saving: bool = None):
        """
        Args:
            allow_saving (bool):
                Whether to allow loaded models to be saved.
                Defaults to False.
        """
        super().__init__()
        self._allow_saving = util.default(allow_saving, False)

    def add_parser_args_impl(self):
        self.group.add_argument(
            "--save-timing-cache",
            help="Path at which to save the tactic timing cache if building an engine. "
            "New timing information is appended to any existing cache.",
            default=None,
        )

    def parse_impl(self, args):
        """
        Parses command-line arguments and populates the following attributes:

        Attributes:
            save_timing_cache (str): Path at which to save the tactic timing cache.
        """
        self.save_timing_cache = args_util.get(args, "save_timing_cache")

    def add_to_script_impl(self, script, network_name=None):
        """
        Args:
            network_name (str): The name of a variable in the script pointing to a network loader.
        """
        if self.arg_groups[ModelArgs].model_type == "engine":
            script.add_import(imports=["EngineFromBytes"], frm="polygraphy.backend.trt")
            script.add_import(imports=["BytesFromPath"], frm="polygraphy.backend.common")

            load_engine = script.add_loader(
                make_invocable("BytesFromPath", self.arg_groups[ModelArgs].path), "load_engine_bytes"
            )
            return script.add_loader(
                make_invocable(
                    "EngineFromBytes", self.arg_groups[TrtLoadPluginsArgs].add_to_script(script, load_engine)
                ),
                "deserialize_engine",
            )

        network_loader_name = network_name
        if network_loader_name is None:
            if TrtLoadNetworkArgs not in self.arg_groups:
                G_LOGGER.internal_error("TrtLoadNetworkArgs is required for engine building!")
            network_loader_name = self.arg_groups[TrtLoadNetworkArgs].add_to_script(script)

        if TrtConfigArgs not in self.arg_groups:
            G_LOGGER.internal_error("TrtConfigArgs is required for engine building!")

        script.add_import(imports=["EngineFromNetwork"], frm="polygraphy.backend.trt")
        config_loader_name = self.arg_groups[TrtConfigArgs].add_to_script(script)

        loader_str = make_invocable(
            "EngineFromNetwork",
            self.arg_groups[TrtLoadPluginsArgs].add_to_script(script, network_loader_name),
            config=config_loader_name,
            # Needed to support the legacy --timing-cache argument
            save_timing_cache=self.save_timing_cache or self.arg_groups[TrtConfigArgs].timing_cache,
        )
        loader_name = script.add_loader(loader_str, "build_engine")

        if self._allow_saving:
            loader_name = self.arg_groups[TrtSaveEngineArgs].add_to_script(script, loader_name)
        return loader_name

    def load_engine(self, network=None):
        """
        Loads a TensorRT engine according to arguments provided on the command-line.

        Args:
            network (Tuple[trt.Builder, trt.INetworkDefinition, Optional[parser]]):
                A tuple containing a TensorRT builder, network, and optionally a parser.

        Returns:
            tensorrt.ICudaEngine: The engine.
        """
        loader = args_util.run_script(self.add_to_script, network)
        return loader()
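

# Hedged end-to-end sketch (assumption, not part of the original module) of the script
# these argument groups cooperate to generate when building and saving an engine from ONNX:
#
#     from polygraphy.backend.trt import CreateConfig, EngineFromNetwork, NetworkFromOnnxPath, SaveEngine
#
#     parse_network_from_onnx = NetworkFromOnnxPath("model.onnx")
#     build_engine = EngineFromNetwork(parse_network_from_onnx, config=CreateConfig())
#     save_engine = SaveEngine(build_engine, path="model.engine")
#
# "model.onnx" and "model.engine" are hypothetical paths; CreateConfig stands in for the
# config loader typically emitted by TrtConfigArgs, whose exact arguments depend on the
# command line.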