TensorRT 10.6.0
NvInferRuntimePlugin.h
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef NV_INFER_RUNTIME_PLUGIN_H
19#define NV_INFER_RUNTIME_PLUGIN_H
20
21#define NV_INFER_INTERNAL_INCLUDE 1
22#include "NvInferPluginBase.h"
23#undef NV_INFER_INTERNAL_INCLUDE
24
33
39namespace nvinfer1
40{
41
42enum class TensorFormat : int32_t;
43namespace v_1_0
44{
45class IGpuAllocator;
46}
48
55
//! Flag bit (0x40) OR-ed into a PluginVersion or PluginCreatorVersion value to mark Python-based
//! plugins / plugin creators (used by PluginVersion::kV2_DYNAMICEXT_PYTHON and PluginCreatorVersion::kV1_PYTHON).
59static constexpr int32_t kPLUGIN_VERSION_PYTHON_BIT = 0x40;
60
73{
81 float scale;
82};
83
//! Definition of plugin versions.
//! The one-byte value is placed in the upper byte of getTensorRTVersion() by the IPluginV2Ext and
//! IPluginV2IOExt overrides, so TensorRT can tell which plugin interface an object implements.
91enum class PluginVersion : uint8_t
92{
94 kV2 = 0, //!< IPluginV2.
96 kV2_EXT = 1, //!< IPluginV2Ext.
98 kV2_IOEXT = 2, //!< IPluginV2IOExt.
100 kV2_DYNAMICEXT = 3, //!< IPluginV2DynamicExt.
102 kV2_DYNAMICEXT_PYTHON = kPLUGIN_VERSION_PYTHON_BIT | 3 //!< IPluginV2DynamicExt-based Python plugins (Python bit OR-ed with kV2_DYNAMICEXT's value 3).
103};
104
//! Enum to identify the version of the plugin creator.
110enum class PluginCreatorVersion : int32_t
111{
113 kV1 = 0, //!< IPluginCreator-based plugin creators without the Python bit (i.e. non-Python creators).
115 kV1_PYTHON = kPLUGIN_VERSION_PYTHON_BIT //!< IPluginCreator-based Python plugin creators.
116};
117
134{
135public:
//! Return the API version with which this plugin was built.
//! The base implementation returns NV_TENSORRT_VERSION as seen when the plugin was compiled; the
//! IPluginV2Ext and IPluginV2IOExt overrides additionally tag the upper byte with a PluginVersion.
148 virtual int32_t getTensorRTVersion() const noexcept
149 {
150 return NV_TENSORRT_VERSION;
151 }
152
166 virtual AsciiChar const* getPluginType() const noexcept = 0;
167
181 virtual AsciiChar const* getPluginVersion() const noexcept = 0;
182
196 virtual int32_t getNbOutputs() const noexcept = 0;
197
221 virtual Dims getOutputDimensions(int32_t index, Dims const* inputs, int32_t nbInputDims) noexcept = 0;
222
246 virtual bool supportsFormat(DataType type, PluginFormat format) const noexcept = 0;
247
280 virtual void configureWithFormat(Dims const* inputDims, int32_t nbInputs, Dims const* outputDims, int32_t nbOutputs,
281 DataType type, PluginFormat format, int32_t maxBatchSize) noexcept
282 = 0;
283
295 virtual int32_t initialize() noexcept = 0;
296
310 virtual void terminate() noexcept = 0;
311
329 virtual size_t getWorkspaceSize(int32_t maxBatchSize) const noexcept = 0;
330
352 virtual int32_t enqueue(int32_t batchSize, void const* const* inputs, void* const* outputs, void* workspace,
353 cudaStream_t stream) noexcept
354 = 0;
355
366 virtual size_t getSerializationSize() const noexcept = 0;
367
381 virtual void serialize(void* buffer) const noexcept = 0;
382
391 virtual void destroy() noexcept = 0;
392
410 virtual IPluginV2* clone() const noexcept = 0;
411
426 virtual void setPluginNamespace(AsciiChar const* pluginNamespace) noexcept = 0;
427
439 virtual AsciiChar const* getPluginNamespace() const noexcept = 0;
440
441 // @cond SuppressDoxyWarnings
442 IPluginV2() = default;
443 virtual ~IPluginV2() noexcept = default;
444// @endcond
445
446protected:
447// @cond SuppressDoxyWarnings
448 IPluginV2(IPluginV2 const&) = default;
449 IPluginV2(IPluginV2&&) = default;
450 IPluginV2& operator=(IPluginV2 const&) & = default;
451 IPluginV2& operator=(IPluginV2&&) & = default;
452// @endcond
453};
454
470{
471public:
496 int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept
497 = 0;
498
520 int32_t outputIndex, bool const* inputIsBroadcasted, int32_t nbInputs) const noexcept
521 = 0;
522
548 TRT_DEPRECATED virtual bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept = 0;
549
587 virtual void configurePlugin(Dims const* inputDims, int32_t nbInputs, Dims const* outputDims, int32_t nbOutputs,
588 DataType const* inputTypes, DataType const* outputTypes, bool const* inputIsBroadcast,
589 bool const* outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept
590 = 0;
591
592 IPluginV2Ext() = default;
593 ~IPluginV2Ext() override = default;
594
//! Attach the plugin object to an execution context and grant the plugin access to some context resources.
//! The default implementation ignores all three arguments (cuDNN context, cuBLAS context, GPU allocator)
//! and does nothing; plugins that need these resources override it.
629 virtual void attachToContext(
630 cudnnContext* /*cudnn*/, cublasContext* /*cublas*/, IGpuAllocator* /*allocator*/) noexcept
631 {
632 }
633
//! Detach the plugin object from its execution context. The default implementation does nothing.
647 virtual void detachFromContext() noexcept {}
648
663 IPluginV2Ext* clone() const noexcept override = 0;
664
665protected:
666 // @cond SuppressDoxyWarnings
667 IPluginV2Ext(IPluginV2Ext const&) = default;
668 IPluginV2Ext(IPluginV2Ext&&) = default;
669 IPluginV2Ext& operator=(IPluginV2Ext const&) & = default;
670 IPluginV2Ext& operator=(IPluginV2Ext&&) & = default;
671// @endcond
672
//! Return the API version with which this plugin was built, tagged as an IPluginV2Ext.
//! Bits 24-31 hold PluginVersion::kV2_EXT; bits 0-23 hold NV_TENSORRT_VERSION masked to 24 bits.
688 int32_t getTensorRTVersion() const noexcept override
689 {
690 return static_cast<int32_t>((static_cast<uint32_t>(PluginVersion::kV2_EXT) << 24U) // upper byte: interface tag
691 | (static_cast<uint32_t>(NV_TENSORRT_VERSION) & 0xFFFFFFU)); // lower 24 bits: TensorRT version
692 }
693
//! Empty override of IPluginV2::configureWithFormat(); IPluginV2Ext plugins are configured through
//! configurePlugin() instead. Derived classes must not implement this (in a C++11 API it would be
//! declared override final).
700 void configureWithFormat(Dims const* /*inputDims*/, int32_t /*nbInputs*/, Dims const* /*outputDims*/,
701 int32_t /*nbOutputs*/, DataType /*type*/, PluginFormat /*format*/, int32_t /*maxBatchSize*/) noexcept override
702 {
703 }
704};
705
719{
720public:
738 virtual void configurePlugin(
739 PluginTensorDesc const* in, int32_t nbInput, PluginTensorDesc const* out, int32_t nbOutput) noexcept
740 = 0;
741
780 int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) const noexcept
781 = 0;
782
783 // @cond SuppressDoxyWarnings
784 IPluginV2IOExt() = default;
785 ~IPluginV2IOExt() override = default;
786// @endcond
787
788protected:
789// @cond SuppressDoxyWarnings
790 IPluginV2IOExt(IPluginV2IOExt const&) = default;
791 IPluginV2IOExt(IPluginV2IOExt&&) = default;
792 IPluginV2IOExt& operator=(IPluginV2IOExt const&) & = default;
793 IPluginV2IOExt& operator=(IPluginV2IOExt&&) & = default;
794// @endcond
795
//! Return the API version with which this plugin was built, tagged as an IPluginV2IOExt.
//! Bits 24-31 hold PluginVersion::kV2_IOEXT; bits 0-23 hold NV_TENSORRT_VERSION masked to 24 bits.
807 int32_t getTensorRTVersion() const noexcept override
808 {
809 return static_cast<int32_t>((static_cast<uint32_t>(PluginVersion::kV2_IOEXT) << 24U) // upper byte: interface tag
810 | (static_cast<uint32_t>(NV_TENSORRT_VERSION) & 0xFFFFFFU)); // lower 24 bits: TensorRT version
811 }
812
813private:
814 // Following are obsolete base class methods, and must not be implemented or used.
815
//! Obsolete IPluginV2Ext configuration entry point, sealed here as a private, final no-op.
//! IPluginV2IOExt plugins implement configurePlugin(PluginTensorDesc const*, ...) instead.
819 void configurePlugin(Dims const*, int32_t, Dims const*, int32_t, DataType const*, DataType const*, bool const*,
820 bool const*, PluginFormat, int32_t) noexcept final
821 {
822 }
823
//! Obsolete IPluginV2 format query, sealed here as private and final; it always reports "unsupported".
//! IPluginV2IOExt plugins report supported formats through supportsFormatCombination() instead.
827 bool supportsFormat(DataType, PluginFormat) const noexcept final
828 {
829 return false;
830 }
831};
832
833namespace v_1_0
834{
836{
837public:
850 virtual AsciiChar const* getPluginName() const noexcept = 0;
851
864 virtual AsciiChar const* getPluginVersion() const noexcept = 0;
865
877 virtual PluginFieldCollection const* getFieldNames() noexcept = 0;
878
891 virtual IPluginV2* createPlugin(AsciiChar const* name, PluginFieldCollection const* fc) noexcept = 0;
892
908 virtual IPluginV2* deserializePlugin(AsciiChar const* name, void const* serialData, size_t serialLength) noexcept
909 = 0;
910
925 virtual void setPluginNamespace(AsciiChar const* pluginNamespace) noexcept = 0;
926
939 virtual AsciiChar const* getPluginNamespace() const noexcept = 0;
940
941 IPluginCreator() = default;
942 ~IPluginCreator() override = default;
943
944protected:
945 // @cond SuppressDoxyWarnings
946 IPluginCreator(IPluginCreator const&) = default;
947 IPluginCreator(IPluginCreator&&) = default;
948 IPluginCreator& operator=(IPluginCreator const&) & = default;
949 IPluginCreator& operator=(IPluginCreator&&) & = default;
950 // @endcond
951public:
//! Return the version information identifying this interface: the string "PLUGIN CREATOR_V1" with
//! numeric versions 1 and 0 (presumably the interface kind, major and minor version — see InterfaceInfo).
955 InterfaceInfo getInterfaceInfo() const noexcept override
956 {
957 return InterfaceInfo{"PLUGIN CREATOR_V1", 1, 0};
958 }
959};
960} // namespace v_1_0
961
973
974} // namespace nvinfer1
975
976#endif // NV_INFER_RUNTIME_PLUGIN_H
#define NV_TENSORRT_VERSION
Definition: NvInferRuntimeBase.h:91
#define TRT_DEPRECATED
Definition: NvInferRuntimeBase.h:45
Application-implemented class for controlling allocation on the GPU.
Definition: NvInferRuntimeBase.h:200
Plugin class for user-implemented layers.
Definition: NvInferRuntimePlugin.h:470
virtual TRT_DEPRECATED bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept=0
Return true if the plugin can use an input tensor that is broadcast across batch without replication.
~IPluginV2Ext() override=default
void configureWithFormat(Dims const *, int32_t, Dims const *, int32_t, DataType, PluginFormat, int32_t) noexcept override
Derived classes must not implement this. In a C++11 API, it would be declared override final.
Definition: NvInferRuntimePlugin.h:700
IPluginV2Ext * clone() const noexcept override=0
Clone the plugin object. This also copies the internal plugin parameters and returns a new plugin object with these parameters.
virtual void configurePlugin(Dims const *inputDims, int32_t nbInputs, Dims const *outputDims, int32_t nbOutputs, DataType const *inputTypes, DataType const *outputTypes, bool const *inputIsBroadcast, bool const *outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept=0
Configure the layer with input and output data types.
virtual void detachFromContext() noexcept
Detach the plugin object from its execution context.
Definition: NvInferRuntimePlugin.h:647
virtual TRT_DEPRECATED bool isOutputBroadcastAcrossBatch(int32_t outputIndex, bool const *inputIsBroadcasted, int32_t nbInputs) const noexcept=0
Return true if the output tensor is broadcast across a batch.
virtual void attachToContext(cudnnContext *, cublasContext *, IGpuAllocator *) noexcept
Attach the plugin object to an execution context and grant the plugin access to some context resources.
Definition: NvInferRuntimePlugin.h:629
virtual nvinfer1::DataType getOutputDataType(int32_t index, nvinfer1::DataType const *inputTypes, int32_t nbInputs) const noexcept=0
Return the DataType of the plugin output at the requested index.
Plugin class for user-implemented layers.
Definition: NvInferRuntimePlugin.h:134
virtual AsciiChar const * getPluginType() const noexcept=0
Return the plugin type. Should match the plugin name returned by the corresponding plugin creator.
virtual int32_t getTensorRTVersion() const noexcept
Return the API version with which this plugin was built.
Definition: NvInferRuntimePlugin.h:148
Plugin class for user-implemented layers.
Definition: NvInferRuntimePlugin.h:719
int32_t getTensorRTVersion() const noexcept override
Return the API version with which this plugin was built. The upper byte is reserved by TensorRT and is used to differentiate this interface from IPluginV2 and IPluginV2Ext.
Definition: NvInferRuntimePlugin.h:807
virtual void configurePlugin(PluginTensorDesc const *in, int32_t nbInput, PluginTensorDesc const *out, int32_t nbOutput) noexcept=0
Configure the layer.
virtual bool supportsFormatCombination(int32_t pos, PluginTensorDesc const *inOut, int32_t nbInputs, int32_t nbOutputs) const noexcept=0
Return true if plugin supports the format and datatype for the input/output indexed by pos.
Version information associated with a TRT interface.
Definition: NvInferRuntimeBase.h:225
Definition: NvInferRuntime.h:1532
Definition: NvInferRuntimePlugin.h:836
virtual AsciiChar const * getPluginName() const noexcept=0
Return the plugin name.
Definition: NvInferPluginBase.h:186
The TensorRT API version 1 namespace.
PluginCreatorVersion
Enum to identify the version of the plugin creator.
Definition: NvInferRuntimePlugin.h:111
@ kV1_PYTHON
IPluginCreator-based Python plugin creators.
v_1_0::IPluginCreator IPluginCreator
Definition: NvInferRuntimePlugin.h:972
v_1_0::IGpuAllocator IGpuAllocator
Definition: NvInferRuntime.h:1731
char_t AsciiChar
Definition: NvInferRuntimeBase.h:105
@ kV2_DYNAMICEXT
IPluginV2DynamicExt.
@ kV2_IOEXT
IPluginV2IOExt.
@ kV2_EXT
IPluginV2Ext.
@ kV2_DYNAMICEXT_PYTHON
IPluginV2DynamicExt-based Python plugins.
DataType
The type of weights and tensors.
Definition: NvInferRuntimeBase.h:133
TensorFormat PluginFormat
PluginFormat is reserved for backward compatibility.
Definition: NvInferRuntimePlugin.h:54
TensorFormat
Format of the input/output tensors.
Definition: NvInferRuntime.h:1306
Definition of plugin versions.
Plugin field collection struct.
Definition: NvInferPluginBase.h:96
Fields that a plugin might see for an input or output.
Definition: NvInferRuntimePlugin.h:73
DataType type
Definition: NvInferRuntimePlugin.h:77
Dims dims
Dimensions.
Definition: NvInferRuntimePlugin.h:75
TensorFormat format
Tensor format.
Definition: NvInferRuntimePlugin.h:79
float scale
Scale for INT8 data type.
Definition: NvInferRuntimePlugin.h:81

  Copyright © 2024 NVIDIA Corporation
  Privacy Policy | Manage My Privacy | Do Not Sell or Share My Data | Terms of Service | Accessibility | Corporate Policies | Product Security | Contact