TensorRT 10.0.1
NvInferRuntimePlugin.h
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef NV_INFER_RUNTIME_PLUGIN_H
19#define NV_INFER_RUNTIME_PLUGIN_H
20
21#define NV_INFER_INTERNAL_INCLUDE_RUNTIME_BASE 1
22#include "NvInferRuntimeBase.h"
23#undef NV_INFER_INTERNAL_INCLUDE_RUNTIME_BASE
24
34
40namespace nvinfer1
41{
42
49
//!
//! \brief Bit that marks the Python flavour of a plugin or plugin-creator version.
//!
//! It is OR-ed into PluginVersion::kV2_DYNAMICEXT_PYTHON and is the value of
//! PluginCreatorVersion::kV1_PYTHON.
//!
static constexpr int32_t kPLUGIN_VERSION_PYTHON_BIT = 0x40;
54
67{
75 float scale;
76};
77
85enum class PluginVersion : uint8_t
86{
88 kV2 = 0,
90 kV2_EXT = 1,
92 kV2_IOEXT = 2,
96 kV2_DYNAMICEXT_PYTHON = kPLUGIN_VERSION_PYTHON_BIT | 3
97};
98
104enum class PluginCreatorVersion : int32_t
105{
107 kV1 = 0,
109 kV1_PYTHON = kPLUGIN_VERSION_PYTHON_BIT
110};
111
128{
129public:
142 virtual int32_t getTensorRTVersion() const noexcept
143 {
144 return NV_TENSORRT_VERSION;
145 }
146
160 virtual AsciiChar const* getPluginType() const noexcept = 0;
161
175 virtual AsciiChar const* getPluginVersion() const noexcept = 0;
176
190 virtual int32_t getNbOutputs() const noexcept = 0;
191
215 virtual Dims getOutputDimensions(int32_t index, Dims const* inputs, int32_t nbInputDims) noexcept = 0;
216
240 virtual bool supportsFormat(DataType type, PluginFormat format) const noexcept = 0;
241
274 virtual void configureWithFormat(Dims const* inputDims, int32_t nbInputs, Dims const* outputDims, int32_t nbOutputs,
275 DataType type, PluginFormat format, int32_t maxBatchSize) noexcept
276 = 0;
277
289 virtual int32_t initialize() noexcept = 0;
290
304 virtual void terminate() noexcept = 0;
305
323 virtual size_t getWorkspaceSize(int32_t maxBatchSize) const noexcept = 0;
324
346 virtual int32_t enqueue(int32_t batchSize, void const* const* inputs, void* const* outputs, void* workspace,
347 cudaStream_t stream) noexcept
348 = 0;
349
360 virtual size_t getSerializationSize() const noexcept = 0;
361
375 virtual void serialize(void* buffer) const noexcept = 0;
376
385 virtual void destroy() noexcept = 0;
386
404 virtual IPluginV2* clone() const noexcept = 0;
405
420 virtual void setPluginNamespace(AsciiChar const* pluginNamespace) noexcept = 0;
421
433 virtual AsciiChar const* getPluginNamespace() const noexcept = 0;
434
435 // @cond SuppressDoxyWarnings
436 IPluginV2() = default;
437 virtual ~IPluginV2() noexcept = default;
438// @endcond
439
440protected:
441// @cond SuppressDoxyWarnings
442 IPluginV2(IPluginV2 const&) = default;
443 IPluginV2(IPluginV2&&) = default;
444 IPluginV2& operator=(IPluginV2 const&) & = default;
445 IPluginV2& operator=(IPluginV2&&) & = default;
446// @endcond
447};
448
464{
465public:
490 int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept
491 = 0;
492
514 int32_t outputIndex, bool const* inputIsBroadcasted, int32_t nbInputs) const noexcept
515 = 0;
516
542 TRT_DEPRECATED virtual bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept = 0;
543
581 virtual void configurePlugin(Dims const* inputDims, int32_t nbInputs, Dims const* outputDims, int32_t nbOutputs,
582 DataType const* inputTypes, DataType const* outputTypes, bool const* inputIsBroadcast,
583 bool const* outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept
584 = 0;
585
586 IPluginV2Ext() = default;
587 ~IPluginV2Ext() override = default;
588
623 virtual void attachToContext(
624 cudnnContext* /*cudnn*/, cublasContext* /*cublas*/, IGpuAllocator* /*allocator*/) noexcept
625 {
626 }
627
641 virtual void detachFromContext() noexcept {}
642
657 IPluginV2Ext* clone() const noexcept override = 0;
658
659protected:
660 // @cond SuppressDoxyWarnings
661 IPluginV2Ext(IPluginV2Ext const&) = default;
662 IPluginV2Ext(IPluginV2Ext&&) = default;
663 IPluginV2Ext& operator=(IPluginV2Ext const&) & = default;
664 IPluginV2Ext& operator=(IPluginV2Ext&&) & = default;
665// @endcond
666
682 int32_t getTensorRTVersion() const noexcept override
683 {
684 return static_cast<int32_t>((static_cast<uint32_t>(PluginVersion::kV2_EXT) << 24U)
685 | (static_cast<uint32_t>(NV_TENSORRT_VERSION) & 0xFFFFFFU));
686 }
687
694 void configureWithFormat(Dims const* /*inputDims*/, int32_t /*nbInputs*/, Dims const* /*outputDims*/,
695 int32_t /*nbOutputs*/, DataType /*type*/, PluginFormat /*format*/, int32_t /*maxBatchSize*/) noexcept override
696 {
697 }
698};
699
713{
714public:
732 virtual void configurePlugin(
733 PluginTensorDesc const* in, int32_t nbInput, PluginTensorDesc const* out, int32_t nbOutput) noexcept
734 = 0;
735
774 int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) const noexcept
775 = 0;
776
777 // @cond SuppressDoxyWarnings
778 IPluginV2IOExt() = default;
779 ~IPluginV2IOExt() override = default;
780// @endcond
781
782protected:
783// @cond SuppressDoxyWarnings
784 IPluginV2IOExt(IPluginV2IOExt const&) = default;
785 IPluginV2IOExt(IPluginV2IOExt&&) = default;
786 IPluginV2IOExt& operator=(IPluginV2IOExt const&) & = default;
787 IPluginV2IOExt& operator=(IPluginV2IOExt&&) & = default;
788// @endcond
789
801 int32_t getTensorRTVersion() const noexcept override
802 {
803 return static_cast<int32_t>((static_cast<uint32_t>(PluginVersion::kV2_IOEXT) << 24U)
804 | (static_cast<uint32_t>(NV_TENSORRT_VERSION) & 0xFFFFFFU));
805 }
806
807private:
808 // Following are obsolete base class methods, and must not be implemented or used.
809
813 void configurePlugin(Dims const*, int32_t, Dims const*, int32_t, DataType const*, DataType const*, bool const*,
814 bool const*, PluginFormat, int32_t) noexcept final
815 {
816 }
817
821 bool supportsFormat(DataType, PluginFormat) const noexcept final
822 {
823 return false;
824 }
825};
826
//!
//! \enum PluginFieldType
//!
//! \brief The possible field types for custom layer.
//!
enum class PluginFieldType : int32_t
{
    //! FP16 field type.
    kFLOAT16 = 0,
    //! FP32 field type.
    kFLOAT32 = 1,
    //! FP64 field type.
    kFLOAT64 = 2,
    //! INT8 field type.
    kINT8 = 3,
    //! INT16 field type.
    kINT16 = 4,
    //! INT32 field type.
    kINT32 = 5,
    //! char field type.
    kCHAR = 6,
    //! nvinfer1::Dims field type.
    kDIMS = 7,
    //! Unknown field type.
    kUNKNOWN = 8,
    //! BF16 field type.
    kBF16 = 9,
    //! INT64 field type.
    kINT64 = 10,
    //! FP8 field type.
    kFP8 = 11,
};
859
868{
869public:
873 void const* data;
877 int32_t length;
878
879 PluginField(AsciiChar const* const name_ = nullptr, void const* const data_ = nullptr,
880 PluginFieldType const type_ = PluginFieldType::kUNKNOWN, int32_t const length_ = 0) noexcept
881 : name(name_)
882 , data(data_)
883 , type(type_)
884 , length(length_)
885 {
886 }
887};
888
895{
897 int32_t nbFields;
900};
901
//!
//! \enum PluginCapabilityType
//!
//! \brief Enumerates the different capability types an IPluginV3 object may have.
//!
enum class PluginCapabilityType : int32_t
{
    //! Core capability. Every IPluginV3 object must have this.
    kCORE = 0,
    //! Build capability. IPluginV3 objects provided to TensorRT build phase must have this.
    kBUILD = 1,
    //! Runtime capability. IPluginV3 objects provided to TensorRT build and execution phases must have this.
    kRUNTIME = 2
};
916
//!
//! \enum TensorRTPhase
//!
//! \brief Indicates a phase of operation of TensorRT.
//!
enum class TensorRTPhase : int32_t
{
    //! Build phase.
    kBUILD = 0,
    //! Runtime phase.
    kRUNTIME = 1
};
929
930namespace v_1_0
931{
933{
934public:
935 ~IPluginCreatorInterface() noexcept override = default;
936
937protected:
941 IPluginCreatorInterface& operator=(IPluginCreatorInterface const&) & = default;
943};
944
946{
947public:
960 virtual AsciiChar const* getPluginName() const noexcept = 0;
961
974 virtual AsciiChar const* getPluginVersion() const noexcept = 0;
975
987 virtual PluginFieldCollection const* getFieldNames() noexcept = 0;
988
1001 virtual IPluginV2* createPlugin(AsciiChar const* name, PluginFieldCollection const* fc) noexcept = 0;
1002
1018 virtual IPluginV2* deserializePlugin(AsciiChar const* name, void const* serialData, size_t serialLength) noexcept
1019 = 0;
1020
1035 virtual void setPluginNamespace(AsciiChar const* pluginNamespace) noexcept = 0;
1036
1049 virtual AsciiChar const* getPluginNamespace() const noexcept = 0;
1050
1051 IPluginCreator() = default;
1052 ~IPluginCreator() override = default;
1053
1054protected:
1055 // @cond SuppressDoxyWarnings
1056 IPluginCreator(IPluginCreator const&) = default;
1057 IPluginCreator(IPluginCreator&&) = default;
1058 IPluginCreator& operator=(IPluginCreator const&) & = default;
1059 IPluginCreator& operator=(IPluginCreator&&) & = default;
1060 // @endcond
1061public:
1065 InterfaceInfo getInterfaceInfo() const noexcept override
1066 {
1067 return InterfaceInfo{"PLUGIN CREATOR_V1", 1, 0};
1068 }
1069};
1070} // namespace v_1_0
1071
1080
1092
1093} // namespace nvinfer1
1094
1095#endif // NV_INFER_RUNTIME_PLUGIN_H
#define NV_TENSORRT_VERSION
Definition: NvInferRuntimeBase.h:93
#define TRT_DEPRECATED
Definition: NvInferRuntimeBase.h:45
Definition: NvInferRuntimeBase.h:201
Plugin class for user-implemented layers.
Definition: NvInferRuntimePlugin.h:464
virtual TRT_DEPRECATED bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept=0
Return true if the plugin can use an input tensor that is broadcast across batch without replication.
~IPluginV2Ext() override=default
void configureWithFormat(Dims const *, int32_t, Dims const *, int32_t, DataType, PluginFormat, int32_t) noexcept override
Derived classes must not implement this. In a C++11 API it would be override final.
Definition: NvInferRuntimePlugin.h:694
IPluginV2Ext * clone() const noexcept override=0
Clone the plugin object. This copies over internal plugin parameters as well and returns a new plugin...
virtual void configurePlugin(Dims const *inputDims, int32_t nbInputs, Dims const *outputDims, int32_t nbOutputs, DataType const *inputTypes, DataType const *outputTypes, bool const *inputIsBroadcast, bool const *outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept=0
Configure the layer with input and output data types.
virtual void detachFromContext() noexcept
Detach the plugin object from its execution context.
Definition: NvInferRuntimePlugin.h:641
virtual TRT_DEPRECATED bool isOutputBroadcastAcrossBatch(int32_t outputIndex, bool const *inputIsBroadcasted, int32_t nbInputs) const noexcept=0
Return true if the output tensor is broadcast across a batch.
virtual void attachToContext(cudnnContext *, cublasContext *, IGpuAllocator *) noexcept
Attach the plugin object to an execution context and grant the plugin the access to some context reso...
Definition: NvInferRuntimePlugin.h:623
virtual nvinfer1::DataType getOutputDataType(int32_t index, nvinfer1::DataType const *inputTypes, int32_t nbInputs) const noexcept=0
Return the DataType of the plugin output at the requested index.
Plugin class for user-implemented layers.
Definition: NvInferRuntimePlugin.h:128
virtual AsciiChar const * getPluginType() const noexcept=0
Return the plugin type. Should match the plugin name returned by the corresponding plugin creator.
virtual int32_t getTensorRTVersion() const noexcept
Return the API version with which this plugin was built.
Definition: NvInferRuntimePlugin.h:142
Plugin class for user-implemented layers.
Definition: NvInferRuntimePlugin.h:713
int32_t getTensorRTVersion() const noexcept override
Return the API version with which this plugin was built. The upper byte is reserved by TensorRT and i...
Definition: NvInferRuntimePlugin.h:801
virtual void configurePlugin(PluginTensorDesc const *in, int32_t nbInput, PluginTensorDesc const *out, int32_t nbOutput) noexcept=0
Configure the layer.
virtual bool supportsFormatCombination(int32_t pos, PluginTensorDesc const *inOut, int32_t nbInputs, int32_t nbOutputs) const noexcept=0
Return true if plugin supports the format and datatype for the input/output indexed by pos.
An Interface class for version control.
Definition: NvInferRuntimeBase.h:399
Version information associated with a TRT interface.
Definition: NvInferRuntimeBase.h:364
Structure containing plugin attribute field names and associated data. This information can be parsed ...
Definition: NvInferRuntimePlugin.h:868
AsciiChar const * name
Plugin field attribute name.
Definition: NvInferRuntimePlugin.h:871
PluginField(AsciiChar const *const name_=nullptr, void const *const data_=nullptr, PluginFieldType const type_=PluginFieldType::kUNKNOWN, int32_t const length_=0) noexcept
Definition: NvInferRuntimePlugin.h:879
void const * data
Plugin field attribute data.
Definition: NvInferRuntimePlugin.h:873
int32_t length
Number of data entries in the Plugin attribute.
Definition: NvInferRuntimePlugin.h:877
PluginFieldType type
Plugin field attribute type.
Definition: NvInferRuntimePlugin.h:875
Definition: NvInferRuntimeBase.h:468
Definition: NvInferRuntimePlugin.h:946
virtual AsciiChar const * getPluginName() const noexcept=0
Return the plugin name.
Definition: NvInferRuntimePlugin.h:933
~IPluginCreatorInterface() noexcept override=default
The TensorRT API version 1 namespace.
PluginFieldType
The possible field types for custom layer.
Definition: NvInferRuntimePlugin.h:833
@ kUNKNOWN
Unknown field type.
@ kFLOAT32
FP32 field type.
@ kCHAR
char field type.
@ kINT16
INT16 field type.
@ kDIMS
nvinfer1::Dims field type.
@ kFLOAT64
FP64 field type.
@ kFLOAT16
FP16 field type.
PluginCreatorVersion
Enum to identify version of the plugin creator.
Definition: NvInferRuntimePlugin.h:105
@ kV1_PYTHON
IPluginCreator-based Python plugin creators.
v_1_0::IPluginCreator IPluginCreator
Definition: NvInferRuntimePlugin.h:1091
PluginCapabilityType
Enumerates the different capability types an IPluginV3 object may have.
Definition: NvInferRuntimePlugin.h:908
@ kBUILD
Build capability. IPluginV3 objects provided to TensorRT build phase must have this.
@ kRUNTIME
Runtime capability. IPluginV3 objects provided to TensorRT build and execution phases must have this.
@ kCORE
Core capability. Every IPluginV3 object must have this.
char_t AsciiChar
Definition: NvInferRuntimeBase.h:107
TensorRTPhase
Indicates a phase of operation of TensorRT.
Definition: NvInferRuntimePlugin.h:923
@ kV2_DYNAMICEXT
IPluginV2DynamicExt.
@ kV2_IOEXT
IPluginV2IOExt.
@ kV2_EXT
IPluginV2Ext.
@ kV2_DYNAMICEXT_PYTHON
IPluginV2DynamicExt-based Python plugins.
DataType
The type of weights and tensors.
Definition: NvInferRuntimeBase.h:135
@ kINT64
Signed 64-bit integer type.
@ kINT32
Signed 32-bit integer format.
TensorFormat PluginFormat
PluginFormat is reserved for backward compatibility.
Definition: NvInferRuntimePlugin.h:48
v_1_0::IPluginCreatorInterface IPluginCreatorInterface
Definition: NvInferRuntimePlugin.h:1079
@ kINT8
Enable Int8 layer selection, with FP32 fallback, and with FP16 fallback if kFP16 is also specified.
TensorFormat
Format of the input/output tensors.
Definition: NvInferRuntimeBase.h:249
Definition of plugin versions.
Plugin field collection struct.
Definition: NvInferRuntimePlugin.h:895
PluginField const * fields
Pointer to PluginField entries.
Definition: NvInferRuntimePlugin.h:899
int32_t nbFields
Number of PluginField entries.
Definition: NvInferRuntimePlugin.h:897
Fields that a plugin might see for an input or output.
Definition: NvInferRuntimePlugin.h:67
DataType type
Definition: NvInferRuntimePlugin.h:71
Dims dims
Dimensions.
Definition: NvInferRuntimePlugin.h:69
TensorFormat format
Tensor format.
Definition: NvInferRuntimePlugin.h:73
float scale
Scale for INT8 data type.
Definition: NvInferRuntimePlugin.h:75

  Copyright © 2024 NVIDIA Corporation
  Privacy Policy | Manage My Privacy | Do Not Sell or Share My Data | Terms of Service | Accessibility | Corporate Policies | Product Security | Contact