TensorRT 10.0.0
NvInferRuntimePlugin.h
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef NV_INFER_RUNTIME_PLUGIN_H
19#define NV_INFER_RUNTIME_PLUGIN_H
20
21#include "NvInferRuntimeBase.h"
22
32
38namespace nvinfer1
39{
40
47
//! Flag bit OR-ed into PluginVersion and PluginCreatorVersion enumerator values to mark Python-based
//! plugins and Python-based plugin creators.
static constexpr int32_t kPLUGIN_VERSION_PYTHON_BIT{0x40};
52
65{
73 float scale;
74};
75
83enum class PluginVersion : uint8_t
84{
86 kV2 = 0,
88 kV2_EXT = 1,
90 kV2_IOEXT = 2,
94 kV2_DYNAMICEXT_PYTHON = kPLUGIN_VERSION_PYTHON_BIT | 3
95};
96
102enum class PluginCreatorVersion : int32_t
103{
105 kV1 = 0,
107 kV1_PYTHON = kPLUGIN_VERSION_PYTHON_BIT
108};
109
// IPluginV2: plugin class for user-implemented layers.
// NOTE(review): this is a scraped documentation listing. The Doxygen comments and the `class ... IPluginV2`
// header line (somewhere in source lines 110-125) are not shown; the class name is taken from the
// constructor/destructor names below.
126{
127public:
// Return the API version with which this plugin was built. The default returns NV_TENSORRT_VERSION unmodified.
140 virtual int32_t getTensorRTVersion() const noexcept
141 {
142 return NV_TENSORRT_VERSION;
143 }
144
// Return the plugin type. Should match the plugin name returned by the corresponding plugin creator.
158 virtual AsciiChar const* getPluginType() const noexcept = 0;
159
// Return the plugin version string.
173 virtual AsciiChar const* getPluginVersion() const noexcept = 0;
174
// Return the number of output tensors produced by this plugin.
188 virtual int32_t getNbOutputs() const noexcept = 0;
189
// Return the dimensions of output `index`, given the `nbInputDims` input dimensions pointed to by `inputs`.
213 virtual Dims getOutputDimensions(int32_t index, Dims const* inputs, int32_t nbInputDims) noexcept = 0;
214
// Return true if the plugin supports data type `type` combined with tensor format `format`.
238 virtual bool supportsFormat(DataType type, PluginFormat format) const noexcept = 0;
239
// Configure the plugin with its input/output dimensions, the selected type and format, and the maximum batch size.
272 virtual void configureWithFormat(Dims const* inputDims, int32_t nbInputs, Dims const* outputDims, int32_t nbOutputs,
273 DataType type, PluginFormat format, int32_t maxBatchSize) noexcept
274 = 0;
275
// Initialize the plugin for execution. Returns an int32_t status code.
// NOTE(review): presumably 0 means success; confirm against the TensorRT API reference.
287 virtual int32_t initialize() noexcept = 0;
288
// Counterpart of initialize(). Presumably releases what initialize() acquired; confirm the exact lifecycle.
302 virtual void terminate() noexcept = 0;
303
// Return the scratch workspace size required for `maxBatchSize` (presumably in bytes; confirm).
321 virtual size_t getWorkspaceSize(int32_t maxBatchSize) const noexcept = 0;
322
// Execute the layer for `batchSize` on CUDA stream `stream`: read `inputs`, write `outputs`, and use
// `workspace` as scratch memory. Returns an int32_t status code (presumably 0 on success; confirm).
344 virtual int32_t enqueue(int32_t batchSize, void const* const* inputs, void* const* outputs, void* workspace,
345 cudaStream_t stream) noexcept
346 = 0;
347
// Return the size of the buffer that serialize() writes into (presumably in bytes).
358 virtual size_t getSerializationSize() const noexcept = 0;
359
// Serialize the plugin into `buffer`. Presumably the caller sizes `buffer` with getSerializationSize().
373 virtual void serialize(void* buffer) const noexcept = 0;
374
// Destroy this plugin object.
// NOTE(review): presumably TensorRT calls this instead of `delete`; confirm.
383 virtual void destroy() noexcept = 0;
384
// Clone the plugin object and return the new copy.
402 virtual IPluginV2* clone() const noexcept = 0;
403
// Set the namespace this plugin object belongs to.
418 virtual void setPluginNamespace(AsciiChar const* pluginNamespace) noexcept = 0;
419
// Return the namespace this plugin object belongs to.
431 virtual AsciiChar const* getPluginNamespace() const noexcept = 0;
432
433 // @cond SuppressDoxyWarnings
434 IPluginV2() = default;
435 virtual ~IPluginV2() noexcept = default;
436// @endcond
437
// Copy/move operations are protected, so only derived classes can invoke them (presumably to prevent slicing
// through an IPluginV2 reference). The `&` ref-qualifier restricts assignment to lvalues.
438protected:
439// @cond SuppressDoxyWarnings
440 IPluginV2(IPluginV2 const&) = default;
441 IPluginV2(IPluginV2&&) = default;
442 IPluginV2& operator=(IPluginV2 const&) & = default;
443 IPluginV2& operator=(IPluginV2&&) & = default;
444// @endcond
445};
446
// IPluginV2Ext: plugin class for user-implemented layers (extended interface).
// NOTE(review): the `class ... IPluginV2Ext` header line is missing from this scraped listing (source lines
// 447-461 are not shown). The class presumably derives from IPluginV2, because clone(), getTensorRTVersion()
// and configureWithFormat() below override IPluginV2's virtuals.
462{
463public:
// NOTE(review): the opening line of this declaration was dropped by the listing. Per the page's symbol index,
// the full declaration is:
//   virtual nvinfer1::DataType getOutputDataType(int32_t index, nvinfer1::DataType const* inputTypes,
//       int32_t nbInputs) const noexcept = 0;
// Return the DataType of the plugin output at the requested index.
488 int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept
489 = 0;
490
// NOTE(review): the opening line of this declaration was also dropped. Per the page's symbol index:
//   TRT_DEPRECATED virtual bool isOutputBroadcastAcrossBatch(int32_t outputIndex,
//       bool const* inputIsBroadcasted, int32_t nbInputs) const noexcept = 0;
// Return true if the output tensor is broadcast across a batch.
512 int32_t outputIndex, bool const* inputIsBroadcasted, int32_t nbInputs) const noexcept
513 = 0;
514
// Deprecated (TRT_DEPRECATED). Return true if the plugin can use an input tensor that is broadcast across the
// batch without replication.
540 TRT_DEPRECATED virtual bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept = 0;
541
// Configure the layer with its input/output dimensions, input/output data types, per-tensor broadcast flags,
// float format and maximum batch size.
579 virtual void configurePlugin(Dims const* inputDims, int32_t nbInputs, Dims const* outputDims, int32_t nbOutputs,
580 DataType const* inputTypes, DataType const* outputTypes, bool const* inputIsBroadcast,
581 bool const* outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept
582 = 0;
583
584 IPluginV2Ext() = default;
585 ~IPluginV2Ext() override = default;
586
// Attach the plugin object to an execution context and grant it access to some context resources: a cuDNN
// context, a cuBLAS context and a GPU allocator. The default implementation ignores them and does nothing.
621 virtual void attachToContext(
622 cudnnContext* /*cudnn*/, cublasContext* /*cublas*/, IGpuAllocator* /*allocator*/) noexcept
623 {
624 }
625
// Detach the plugin object from its execution context. The default implementation does nothing.
639 virtual void detachFromContext() noexcept {}
640
// Clone the plugin object, copying its internal plugin parameters. The return type is narrowed
// (covariant) to IPluginV2Ext*.
655 IPluginV2Ext* clone() const noexcept override = 0;
656
// Copy/move operations are protected (usable only by derived classes) and lvalue-ref-qualified.
657protected:
658 // @cond SuppressDoxyWarnings
659 IPluginV2Ext(IPluginV2Ext const&) = default;
660 IPluginV2Ext(IPluginV2Ext&&) = default;
661 IPluginV2Ext& operator=(IPluginV2Ext const&) & = default;
662 IPluginV2Ext& operator=(IPluginV2Ext&&) & = default;
663// @endcond
664
// Report the API version. The upper byte, reserved by TensorRT, holds PluginVersion::kV2_EXT; the lower
// 24 bits hold the low 24 bits of NV_TENSORRT_VERSION. The unsigned casts keep the shift and mask well-defined.
680 int32_t getTensorRTVersion() const noexcept override
681 {
682 return static_cast<int32_t>((static_cast<uint32_t>(PluginVersion::kV2_EXT) << 24U)
683 | (static_cast<uint32_t>(NV_TENSORRT_VERSION) & 0xFFFFFFU));
684 }
685
// Obsolete IPluginV2 entry point, overridden here as a protected no-op. Derived classes must not implement this.
692 void configureWithFormat(Dims const* /*inputDims*/, int32_t /*nbInputs*/, Dims const* /*outputDims*/,
693 int32_t /*nbOutputs*/, DataType /*type*/, PluginFormat /*format*/, int32_t /*maxBatchSize*/) noexcept override
694 {
695 }
696};
697
// IPluginV2IOExt: plugin class for user-implemented layers, configured through PluginTensorDesc.
// NOTE(review): the `class ... IPluginV2IOExt` header line is missing from this scraped listing (source lines
// 698-708 are not shown). The class presumably derives from IPluginV2Ext, because the private `final`
// overrides below match IPluginV2Ext's configurePlugin() and IPluginV2's supportsFormat().
709{
710public:
// Configure the layer, given `nbInput` input descriptors in `in` and `nbOutput` output descriptors in `out`.
728 virtual void configurePlugin(
729 PluginTensorDesc const* in, int32_t nbInput, PluginTensorDesc const* out, int32_t nbOutput) noexcept
730 = 0;
731
// NOTE(review): the opening line of this declaration was dropped by the listing. Per the page's symbol index,
// the full declaration is:
//   virtual bool supportsFormatCombination(int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs,
//       int32_t nbOutputs) const noexcept = 0;
// Return true if the plugin supports the format and data type for the input/output indexed by `pos`.
770 int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) const noexcept
771 = 0;
772
773 // @cond SuppressDoxyWarnings
774 IPluginV2IOExt() = default;
775 ~IPluginV2IOExt() override = default;
776// @endcond
777
// Copy/move operations are protected (usable only by derived classes) and lvalue-ref-qualified.
778protected:
779// @cond SuppressDoxyWarnings
780 IPluginV2IOExt(IPluginV2IOExt const&) = default;
781 IPluginV2IOExt(IPluginV2IOExt&&) = default;
782 IPluginV2IOExt& operator=(IPluginV2IOExt const&) & = default;
783 IPluginV2IOExt& operator=(IPluginV2IOExt&&) & = default;
784// @endcond
785
// Report the API version. The upper byte, reserved by TensorRT, holds PluginVersion::kV2_IOEXT; the lower
// 24 bits hold the low 24 bits of NV_TENSORRT_VERSION.
797 int32_t getTensorRTVersion() const noexcept override
798 {
799 return static_cast<int32_t>((static_cast<uint32_t>(PluginVersion::kV2_IOEXT) << 24U)
800 | (static_cast<uint32_t>(NV_TENSORRT_VERSION) & 0xFFFFFFU));
801 }
802
// The obsolete base-class entry points below are private and `final`, so subclasses can neither override
// them nor call them through this type.
803private:
804 // Following are obsolete base class methods, and must not be implemented or used.
805
// Legacy Dims-based configurePlugin(): a no-op that is superseded by the PluginTensorDesc overload above.
809 void configurePlugin(Dims const*, int32_t, Dims const*, int32_t, DataType const*, DataType const*, bool const*,
810 bool const*, PluginFormat, int32_t) noexcept final
811 {
812 }
813
// Legacy supportsFormat(): always returns false; supportsFormatCombination() is used instead.
817 bool supportsFormat(DataType, PluginFormat) const noexcept final
818 {
819 return false;
820 }
821};
822
//!
//! \enum PluginFieldType
//!
//! \brief The possible field types for custom layer.
//!
enum class PluginFieldType : int32_t
{
    kFLOAT16 = 0, //!< FP16 field type.
    kFLOAT32 = 1, //!< FP32 field type.
    kFLOAT64 = 2, //!< FP64 field type.
    kINT8 = 3,    //!< INT8 field type.
    kINT16 = 4,   //!< INT16 field type.
    kINT32 = 5,   //!< INT32 field type.
    kCHAR = 6,    //!< char field type.
    kDIMS = 7,    //!< nvinfer1::Dims field type.
    kUNKNOWN = 8, //!< Unknown field type.
    kBF16 = 9,    //!< BF16 field type.
    kINT64 = 10,  //!< INT64 field type.
    kFP8 = 11     //!< FP8 field type.
};
855
864{
865public:
869 void const* data;
873 int32_t length;
874
875 PluginField(AsciiChar const* const name_ = nullptr, void const* const data_ = nullptr,
876 PluginFieldType const type_ = PluginFieldType::kUNKNOWN, int32_t const length_ = 0) noexcept
877 : name(name_)
878 , data(data_)
879 , type(type_)
880 , length(length_)
881 {
882 }
883};
884
891{
893 int32_t nbFields;
896};
897
//!
//! \enum PluginCapabilityType
//!
//! \brief Enumerates the different capability types an IPluginV3 object may have.
//!
enum class PluginCapabilityType : int32_t
{
    kCORE = 0,   //!< Core capability. Every IPluginV3 object must have this.
    kBUILD = 1,  //!< Build capability. IPluginV3 objects provided to the TensorRT build phase must have this.
    kRUNTIME = 2 //!< Runtime capability. IPluginV3 objects provided to the TensorRT build and execution phases must have this.
};
912
//!
//! \enum TensorRTPhase
//!
//! \brief Indicates a phase of operation of TensorRT.
//!
enum class TensorRTPhase : int32_t
{
    kBUILD = 0,  //!< Build phase of TensorRT.
    kRUNTIME = 1 //!< Execution phase of TensorRT.
};
925
// Version 1.0 of the plugin-creator interfaces.
926namespace v_1_0
927{
// IPluginCreatorInterface (name taken from its destructor below).
// NOTE(review): the `class ... IPluginCreatorInterface` header line (source line 928) is missing from this
// scraped listing. Its `override` destructor implies a polymorphic base, presumably the versioned-interface
// class declared in NvInferRuntimeBase.h (not shown); confirm.
929{
930public:
931 ~IPluginCreatorInterface() noexcept override = default;
932
// NOTE(review): source lines 934-936 and 938 are not shown in this listing; presumably they hold the
// remaining protected special members (default/copy/move constructors, move assignment).
933protected:
937 IPluginCreatorInterface& operator=(IPluginCreatorInterface const&) & = default;
939};
940
// IPluginCreator: version 1 of the plugin creator interface (see getInterfaceInfo() below).
// NOTE(review): the `class ... IPluginCreator` header line (source line 941) is missing from this listing.
// It presumably derives from IPluginCreatorInterface, given its `override` destructor and getInterfaceInfo();
// confirm.
942{
943public:
// Return the plugin name.
956 virtual AsciiChar const* getPluginName() const noexcept = 0;
957
// Return the plugin version.
970 virtual AsciiChar const* getPluginVersion() const noexcept = 0;
971
// Return the PluginFieldCollection describing the fields of the plugins this creator makes.
983 virtual PluginFieldCollection const* getFieldNames() noexcept = 0;
984
// Create a plugin object named `name` from the fields in `fc`.
997 virtual IPluginV2* createPlugin(AsciiChar const* name, PluginFieldCollection const* fc) noexcept = 0;
998
// Recreate a plugin object named `name` from the serialized data at `serialData`, whose length is
// `serialLength` (presumably in bytes).
1014 virtual IPluginV2* deserializePlugin(AsciiChar const* name, void const* serialData, size_t serialLength) noexcept
1015 = 0;
1016
// Set the namespace of the plugin creator.
1031 virtual void setPluginNamespace(AsciiChar const* pluginNamespace) noexcept = 0;
1032
// Return the namespace of the plugin creator.
1045 virtual AsciiChar const* getPluginNamespace() const noexcept = 0;
1046
1047 IPluginCreator() = default;
1048 ~IPluginCreator() override = default;
1049
// Copy/move operations are protected (usable only by derived classes) and lvalue-ref-qualified.
1050protected:
1051 // @cond SuppressDoxyWarnings
1052 IPluginCreator(IPluginCreator const&) = default;
1053 IPluginCreator(IPluginCreator&&) = default;
1054 IPluginCreator& operator=(IPluginCreator const&) & = default;
1055 IPluginCreator& operator=(IPluginCreator&&) & = default;
1056 // @endcond
1057public:
// Report this interface's identity as InterfaceInfo{"PLUGIN CREATOR_V1", 1, 0}.
// NOTE(review): the fields are presumably (kind string, major version, minor version); confirm against InterfaceInfo.
1061 InterfaceInfo getInterfaceInfo() const noexcept override
1062 {
1063 return InterfaceInfo{"PLUGIN CREATOR_V1", 1, 0};
1064 }
1065};
1066} // namespace v_1_0
1067
1076
1085
1086} // namespace nvinfer1
1087
1088#endif // NV_INFER_RUNTIME_PLUGIN_H
#define NV_TENSORRT_VERSION
Definition: NvInferRuntimeBase.h:87
#define TRT_DEPRECATED
Definition: NvInferRuntimeBase.h:45
Definition: NvInferRuntimeBase.h:195
Plugin class for user-implemented layers.
Definition: NvInferRuntimePlugin.h:462
virtual TRT_DEPRECATED bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept=0
Return true if the plugin can use an input tensor that is broadcast across batch without replication.
~IPluginV2Ext() override=default
void configureWithFormat(Dims const *, int32_t, Dims const *, int32_t, DataType, PluginFormat, int32_t) noexcept override
Derived classes must not implement this. In a C++11 API it would be declared `override final`.
Definition: NvInferRuntimePlugin.h:692
IPluginV2Ext * clone() const noexcept override=0
Clone the plugin object. This copies over internal plugin parameters as well and returns a new plugin...
virtual void configurePlugin(Dims const *inputDims, int32_t nbInputs, Dims const *outputDims, int32_t nbOutputs, DataType const *inputTypes, DataType const *outputTypes, bool const *inputIsBroadcast, bool const *outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept=0
Configure the layer with input and output data types.
virtual void detachFromContext() noexcept
Detach the plugin object from its execution context.
Definition: NvInferRuntimePlugin.h:639
virtual TRT_DEPRECATED bool isOutputBroadcastAcrossBatch(int32_t outputIndex, bool const *inputIsBroadcasted, int32_t nbInputs) const noexcept=0
Return true if the output tensor is broadcast across a batch.
virtual void attachToContext(cudnnContext *, cublasContext *, IGpuAllocator *) noexcept
Attach the plugin object to an execution context and grant the plugin access to some context resources.
Definition: NvInferRuntimePlugin.h:621
virtual nvinfer1::DataType getOutputDataType(int32_t index, nvinfer1::DataType const *inputTypes, int32_t nbInputs) const noexcept=0
Return the DataType of the plugin output at the requested index.
Plugin class for user-implemented layers.
Definition: NvInferRuntimePlugin.h:126
virtual AsciiChar const * getPluginType() const noexcept=0
Return the plugin type. Should match the plugin name returned by the corresponding plugin creator.
virtual int32_t getTensorRTVersion() const noexcept
Return the API version with which this plugin was built.
Definition: NvInferRuntimePlugin.h:140
Plugin class for user-implemented layers.
Definition: NvInferRuntimePlugin.h:709
int32_t getTensorRTVersion() const noexcept override
Return the API version with which this plugin was built. The upper byte is reserved by TensorRT and i...
Definition: NvInferRuntimePlugin.h:797
virtual void configurePlugin(PluginTensorDesc const *in, int32_t nbInput, PluginTensorDesc const *out, int32_t nbOutput) noexcept=0
Configure the layer.
virtual bool supportsFormatCombination(int32_t pos, PluginTensorDesc const *inOut, int32_t nbInputs, int32_t nbOutputs) const noexcept=0
Return true if plugin supports the format and datatype for the input/output indexed by pos.
An Interface class for version control.
Definition: NvInferRuntimeBase.h:393
Version information associated with a TRT interface.
Definition: NvInferRuntimeBase.h:358
Structure containing plugin attribute field names and associated data. This information can be parsed ...
Definition: NvInferRuntimePlugin.h:864
AsciiChar const * name
Plugin field attribute name.
Definition: NvInferRuntimePlugin.h:867
PluginField(AsciiChar const *const name_=nullptr, void const *const data_=nullptr, PluginFieldType const type_=PluginFieldType::kUNKNOWN, int32_t const length_=0) noexcept
Definition: NvInferRuntimePlugin.h:875
void const * data
Plugin field attribute data.
Definition: NvInferRuntimePlugin.h:869
int32_t length
Number of data entries in the Plugin attribute.
Definition: NvInferRuntimePlugin.h:873
PluginFieldType type
Plugin field attribute type.
Definition: NvInferRuntimePlugin.h:871
Definition: NvInferRuntimeBase.h:462
Definition: NvInferRuntimePlugin.h:942
virtual AsciiChar const * getPluginName() const noexcept=0
Return the plugin name.
Definition: NvInferRuntimePlugin.h:929
~IPluginCreatorInterface() noexcept override=default
The TensorRT API version 1 namespace.
PluginFieldType
The possible field types for custom layer.
Definition: NvInferRuntimePlugin.h:829
@ kUNKNOWN
Unknown field type.
@ kFLOAT32
FP32 field type.
@ kCHAR
char field type.
@ kINT16
INT16 field type.
@ kDIMS
nvinfer1::Dims field type.
@ kFLOAT64
FP64 field type.
@ kFLOAT16
FP16 field type.
PluginCreatorVersion
Enum to identify version of the plugin creator.
Definition: NvInferRuntimePlugin.h:103
@ kV1_PYTHON
IPluginCreator-based Python plugin creators.
v_1_0::IPluginCreator IPluginCreator
Definition: NvInferRuntimePlugin.h:1084
PluginCapabilityType
Enumerates the different capability types an IPluginV3 object may have.
Definition: NvInferRuntimePlugin.h:904
@ kBUILD
Build capability. IPluginV3 objects provided to the TensorRT build phase must have this.
@ kRUNTIME
Runtime capability. IPluginV3 objects provided to the TensorRT build and execution phases must have this.
@ kCORE
Core capability. Every IPluginV3 object must have this.
char_t AsciiChar
Definition: NvInferRuntimeBase.h:101
TensorRTPhase
Indicates a phase of operation of TensorRT.
Definition: NvInferRuntimePlugin.h:919
@ kV2_DYNAMICEXT
IPluginV2DynamicExt.
@ kV2_IOEXT
IPluginV2IOExt.
@ kV2_EXT
IPluginV2Ext.
@ kV2_DYNAMICEXT_PYTHON
IPluginV2DynamicExt-based Python plugins.
DataType
The type of weights and tensors.
Definition: NvInferRuntimeBase.h:129
@ kINT64
Signed 64-bit integer type.
@ kINT32
Signed 32-bit integer format.
TensorFormat PluginFormat
PluginFormat is reserved for backward compatibility.
Definition: NvInferRuntimePlugin.h:46
v_1_0::IPluginCreatorInterface IPluginCreatorInterface
Definition: NvInferRuntimePlugin.h:1075
@ kINT8
Enable INT8 layer selection, with FP32 fallback (and FP16 fallback if kFP16 is also specified).
TensorFormat
Format of the input/output tensors.
Definition: NvInferRuntimeBase.h:243
Definition of plugin versions.
Plugin field collection struct.
Definition: NvInferRuntimePlugin.h:891
PluginField const * fields
Pointer to PluginField entries.
Definition: NvInferRuntimePlugin.h:895
int32_t nbFields
Number of PluginField entries.
Definition: NvInferRuntimePlugin.h:893
Fields that a plugin might see for an input or output.
Definition: NvInferRuntimePlugin.h:65
DataType type
Definition: NvInferRuntimePlugin.h:69
Dims dims
Dimensions.
Definition: NvInferRuntimePlugin.h:67
TensorFormat format
Tensor format.
Definition: NvInferRuntimePlugin.h:71
float scale
Scale for INT8 data type.
Definition: NvInferRuntimePlugin.h:73

  Copyright © 2024 NVIDIA Corporation
  Privacy Policy | Manage My Privacy | Do Not Sell or Share My Data | Terms of Service | Accessibility | Corporate Policies | Product Security | Contact