TensorRT 8.6.0
NvInferRuntimePlugin.h
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 *
5 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 * property and proprietary rights in and to this material, related
7 * documentation and any modifications thereto. Any use, reproduction,
8 * disclosure or distribution of this material and related documentation
9 * without an express license agreement from NVIDIA CORPORATION or
10 * its affiliates is strictly prohibited.
11 */
12
13#ifndef NV_INFER_RUNTIME_PLUGIN_H
14#define NV_INFER_RUNTIME_PLUGIN_H
15
16#include "NvInferRuntimeBase.h"
17
27
33namespace nvinfer1
34{
35
42
// Body of the struct defined at header line 54, which this listing's index describes as
// "Fields that a plugin might see for an input or output."
// NOTE(review): presumably this is PluginTensorDesc (the type taken by
// IPluginV2IOExt::configurePlugin) -- confirm against the original header.
// NOTE(review): the struct header line and three members are missing from this extracted
// listing. The index places `Dims dims` at header line 56, `DataType type` at line 58 and
// `TensorFormat format` at line 60; only `scale` survived the extraction.
54{
// Scale for INT8 data type.
62 float scale;
63};
64
// Definition of plugin versions.
// IPluginV2Ext and IPluginV2IOExt shift one of these values into bits 24-31 of the value
// returned by their getTensorRTVersion() overrides.
71enum class PluginVersion : uint8_t
72{
// NOTE(review): no description for kV2 is present in this listing; presumably it denotes
// the base IPluginV2 interface -- confirm.
74 kV2 = 0,
// IPluginV2Ext.
76 kV2_EXT = 1,
// IPluginV2IOExt.
78 kV2_IOEXT = 2,
// NOTE(review): the index of this listing also names an enumerator kV2_DYNAMICEXT
// ("IPluginV2DynamicExt."), but its declaration (header lines 79-80) is missing here.
// Its numeric value cannot be determined from this listing.
81};
82
// Plugin class for user-implemented layers.
// NOTE(review): the class header line (header line 97) is missing from this extracted
// listing; the constructors below show the class is named IPluginV2.
// Every member function is noexcept. All except getTensorRTVersion() are pure virtual and
// must be supplied by the implementer.
98{
99public:
// Return the API version with which this plugin was built.
// The default implementation returns the NV_TENSORRT_VERSION macro unchanged.
110 virtual int32_t getTensorRTVersion() const noexcept
111 {
112 return NV_TENSORRT_VERSION;
113 }
114
// Return the plugin type. Should match the plugin name returned by the corresponding
// plugin creator.
127 virtual AsciiChar const* getPluginType() const noexcept = 0;
128
// Returns a C string. Presumably the plugin's version string, which should match the
// creator's getPluginVersion() -- confirm against the original header.
141 virtual AsciiChar const* getPluginVersion() const noexcept = 0;
142
// Presumably the number of output tensors the layer produces -- confirm.
156 virtual int32_t getNbOutputs() const noexcept = 0;
157
// Returns a Dims for output `index`, given an array `inputs` of `nbInputDims` input
// dimensions. Not const-qualified, unlike most queries in this class.
177 virtual Dims getOutputDimensions(int32_t index, Dims const* inputs, int32_t nbInputDims) noexcept = 0;
178
// Returns whether the given (type, format) pair is supported.
// IPluginV2IOExt overrides this as a private `final` stub that always returns false.
201 virtual bool supportsFormat(DataType type, PluginFormat format) const noexcept = 0;
202
// Receives the input/output dimension arrays, their counts, a single data type and
// format, and the maximum batch size.
// IPluginV2Ext overrides this with an empty body and documents that derived classes
// should not implement it.
234 virtual void configureWithFormat(Dims const* inputDims, int32_t nbInputs, Dims const* outputDims, int32_t nbOutputs,
235 DataType type, PluginFormat format, int32_t maxBatchSize) noexcept
236 = 0;
237
// Returns an int32_t status. NOTE(review): the meaning of the return value (presumably
// 0 on success) is not shown in this listing -- confirm.
249 virtual int32_t initialize() noexcept = 0;
250
// Presumably the counterpart of initialize() -- confirm.
263 virtual void terminate() noexcept = 0;
264
// Returns a size in bytes (size_t) as a function of the maximum batch size.
// Presumably the size of the `workspace` buffer later passed to enqueue() -- confirm.
279 virtual size_t getWorkspaceSize(int32_t maxBatchSize) const noexcept = 0;
280
// Takes the batch size, arrays of input and output pointers, a workspace pointer and
// the cudaStream_t to use. Returns an int32_t status.
// NOTE(review): whether the pointers are device pointers, and the meaning of the
// return value, are not established by this listing -- confirm.
297 virtual int32_t enqueue(int32_t batchSize, void const* const* inputs, void* const* outputs, void* workspace,
298 cudaStream_t stream) noexcept
299 = 0;
300
// Presumably the number of bytes serialize() writes -- confirm.
311 virtual size_t getSerializationSize() const noexcept = 0;
312
// Writes into a caller-provided buffer; the buffer size is not passed in.
// Presumably it must be at least getSerializationSize() bytes -- confirm.
326 virtual void serialize(void* buffer) const noexcept = 0;
327
// NOTE(review): presumably releases the plugin object itself -- confirm.
336 virtual void destroy() noexcept = 0;
337
// Returns a pointer to a new IPluginV2. IPluginV2Ext narrows the return type to
// IPluginV2Ext* (covariant override).
// NOTE(review): ownership of the returned object is not shown here -- confirm.
352 virtual IPluginV2* clone() const noexcept = 0;
353
// Setter/getter pair for the plugin's namespace string.
368 virtual void setPluginNamespace(AsciiChar const* pluginNamespace) noexcept = 0;
369
378 virtual AsciiChar const* getPluginNamespace() const noexcept = 0;
379
// Public default constructor and virtual destructor.
380 // @cond SuppressDoxyWarnings
381 IPluginV2() = default;
382 virtual ~IPluginV2() noexcept = default;
383// @endcond
384
385protected:
// Copy and move operations are protected, so only derived classes can use them.
// The assignment operators are lvalue-ref-qualified (`&`) and therefore cannot be
// applied to temporaries.
386// @cond SuppressDoxyWarnings
387 IPluginV2(IPluginV2 const&) = default;
388 IPluginV2(IPluginV2&&) = default;
389 IPluginV2& operator=(IPluginV2 const&) & = default;
390 IPluginV2& operator=(IPluginV2&&) & = default;
391// @endcond
392};
393
// Plugin class for user-implemented layers.
// NOTE(review): the class header line (header line 407) is missing from this extracted
// listing. The constructors below show the class is named IPluginV2Ext, and the
// `override` specifiers show it derives from a class declaring getTensorRTVersion(),
// configureWithFormat() and clone() (IPluginV2 in this file).
408{
409public:
// Return the DataType of the plugin output at the requested index.
// NOTE(review): the first line of this declaration is missing from the listing. The
// index gives the full signature as
//   virtual nvinfer1::DataType getOutputDataType(int32_t index,
//       nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept = 0
426 int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept
427 = 0;
428
// Return true if output tensor is broadcast across a batch.
// NOTE(review): the first line of this declaration is missing from the listing. The
// index gives the full signature as
//   virtual bool isOutputBroadcastAcrossBatch(int32_t outputIndex,
//       bool const* inputIsBroadcasted, int32_t nbInputs) const noexcept = 0
445 int32_t outputIndex, bool const* inputIsBroadcasted, int32_t nbInputs) const noexcept
446 = 0;
447
// Return true if plugin can use input that is broadcast across batch without replication.
466 virtual bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept = 0;
467
// Configure the layer with input and output data types.
// Unlike the base-class configureWithFormat(), this takes per-tensor type arrays and
// per-tensor broadcast flags in addition to the dimensions, format and max batch size.
503 virtual void configurePlugin(Dims const* inputDims, int32_t nbInputs, Dims const* outputDims, int32_t nbOutputs,
504 DataType const* inputTypes, DataType const* outputTypes, bool const* inputIsBroadcast,
505 bool const* outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept
506 = 0;
507
508 IPluginV2Ext() = default;
509 ~IPluginV2Ext() override = default;
510
// Attach the plugin object to an execution context and grant the plugin access to some
// context resources (the index text is truncated after this point).
// The default implementation ignores all three arguments and does nothing.
534 virtual void attachToContext(
535 cudnnContext* /*cudnn*/, cublasContext* /*cublas*/, IGpuAllocator* /*allocator*/) noexcept
536 {
537 }
538
// Detach the plugin object from its execution context.
// The default implementation does nothing.
552 virtual void detachFromContext() noexcept {}
553
// Clone the plugin object. This copies over internal plugin parameters as well and
// returns a new plugin object (the index text is truncated after this point).
// The return type is narrowed to IPluginV2Ext* relative to the base-class declaration.
566 IPluginV2Ext* clone() const noexcept override = 0;
567
568protected:
// Copy and move operations are protected; the assignment operators are
// lvalue-ref-qualified.
569 // @cond SuppressDoxyWarnings
570 IPluginV2Ext(IPluginV2Ext const&) = default;
571 IPluginV2Ext(IPluginV2Ext&&) = default;
572 IPluginV2Ext& operator=(IPluginV2Ext const&) & = default;
573 IPluginV2Ext& operator=(IPluginV2Ext&&) & = default;
574// @endcond
575
// Returns a packed version word: PluginVersion::kV2_EXT in bits 24-31 and the low
// 24 bits of NV_TENSORRT_VERSION in bits 0-23. The arithmetic is done in uint32_t and
// cast back to int32_t at the end.
// Declared in the protected section, so it is not callable through an IPluginV2Ext
// object from outside the class hierarchy.
587 int32_t getTensorRTVersion() const noexcept override
588 {
589 return static_cast<int32_t>((static_cast<uint32_t>(PluginVersion::kV2_EXT) << 24U)
590 | (static_cast<uint32_t>(NV_TENSORRT_VERSION) & 0xFFFFFFU));
591 }
592
// Derived classes should not implement this. In a C++11 API it would be override final.
// Deliberately empty: all arguments are ignored.
596 void configureWithFormat(Dims const* /*inputDims*/, int32_t /*nbInputs*/, Dims const* /*outputDims*/,
597 int32_t /*nbOutputs*/, DataType /*type*/, PluginFormat /*format*/, int32_t /*maxBatchSize*/) noexcept override
598 {
599 }
600};
601
// Plugin class for user-implemented layers.
// NOTE(review): the class header line (header line 611) is missing from this extracted
// listing. The constructors below show the class is named IPluginV2IOExt; the `final`
// overrides at the bottom match the signatures declared in IPluginV2Ext and IPluginV2.
612{
613public:
// Configure the layer.
// Takes arrays of PluginTensorDesc for the inputs and outputs together with their counts.
631 virtual void configurePlugin(
632 PluginTensorDesc const* in, int32_t nbInput, PluginTensorDesc const* out, int32_t nbOutput) noexcept
633 = 0;
634
// Return true if plugin supports the format and datatype for the input/output indexed
// by pos.
// NOTE(review): the first line of this declaration is missing from the listing. The
// index gives the full signature as
//   virtual bool supportsFormatCombination(int32_t pos, PluginTensorDesc const* inOut,
//       int32_t nbInputs, int32_t nbOutputs) const noexcept = 0
673 int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) const noexcept
674 = 0;
675
676 // @cond SuppressDoxyWarnings
677 IPluginV2IOExt() = default;
678 ~IPluginV2IOExt() override = default;
679// @endcond
680
681protected:
// Copy and move operations are protected; the assignment operators are
// lvalue-ref-qualified.
682// @cond SuppressDoxyWarnings
683 IPluginV2IOExt(IPluginV2IOExt const&) = default;
684 IPluginV2IOExt(IPluginV2IOExt&&) = default;
685 IPluginV2IOExt& operator=(IPluginV2IOExt const&) & = default;
686 IPluginV2IOExt& operator=(IPluginV2IOExt&&) & = default;
687// @endcond
688
// Return the API version with which this plugin was built. The upper byte is reserved
// by TensorRT (the index text is truncated after this point).
// Concretely: PluginVersion::kV2_IOEXT goes in bits 24-31 and the low 24 bits of
// NV_TENSORRT_VERSION go in bits 0-23.
700 int32_t getTensorRTVersion() const noexcept override
701 {
702 return static_cast<int32_t>((static_cast<uint32_t>(PluginVersion::kV2_IOEXT) << 24U)
703 | (static_cast<uint32_t>(NV_TENSORRT_VERSION) & 0xFFFFFFU));
704 }
705
706private:
707 // Following are obsolete base class methods, and must not be implemented or used.
708
// Private and `final`, so further-derived classes can neither call nor override it.
// The body is empty.
709 void configurePlugin(Dims const*, int32_t, Dims const*, int32_t, DataType const*, DataType const*, bool const*,
710 bool const*, PluginFormat, int32_t) noexcept final
711 {
712 }
713
// Private and `final`; always returns false. The public replacement declared in this
// class is supportsFormatCombination().
714 bool supportsFormat(DataType, PluginFormat) const noexcept final
715 {
716 return false;
717 }
718};
719
724
// The possible field types for custom layer.
// Used as the `type` of a PluginField; PluginField's constructor defaults it to kUNKNOWN.
725enum class PluginFieldType : int32_t
726{
// FP16 field type.
728 kFLOAT16 = 0,
// FP32 field type.
730 kFLOAT32 = 1,
// FP64 field type.
732 kFLOAT64 = 2,
// NOTE(review): presumably "INT8 field type". The index of this listing attaches an
// unrelated description to kINT8 ("Enable Int8 layer selection ..."), which appears to
// belong to a different enum -- confirm against the original header.
734 kINT8 = 3,
// INT16 field type.
736 kINT16 = 4,
// NOTE(review): presumably "INT32 field type". The index of this listing attaches
// "Signed 32-bit integer format." to kINT32, which appears to describe a different
// enum's enumerator -- confirm against the original header.
738 kINT32 = 5,
// char field type.
740 kCHAR = 6,
// nvinfer1::Dims field type.
742 kDIMS = 7,
// Unknown field type.
744 kUNKNOWN = 8
745};
746
// Structure containing plugin attribute field names and associated data.
// NOTE(review): the class header line (header line 754) and two member declarations are
// missing from this extracted listing. The index places `AsciiChar const* name`
// ("Plugin field attribute name.") at header line 760 and `PluginFieldType type`
// ("Plugin field attribute type.") at header line 769. The constructor's
// member-initializer list below depends on both.
755{
756public:
// Plugin field attribute data.
// Untyped pointer; presumably interpreted according to `type` -- confirm.
764 void const* data;
// Number of data entries in the Plugin attribute.
773 int32_t length;
774
// Every parameter has a default (nullptr / nullptr / kUNKNOWN / 0), so this also serves
// as the default constructor. It only stores the four arguments into the corresponding
// members; nothing is copied or validated.
775 PluginField(AsciiChar const* const name_ = nullptr, void const* const data_ = nullptr,
776 PluginFieldType const type_ = PluginFieldType::kUNKNOWN, int32_t const length_ = 0) noexcept
777 : name(name_)
778 , data(data_)
779 , type(type_)
780 , length(length_)
781 {
782 }
783};
784
// Plugin field collection struct.
// NOTE(review): the struct header line (header line 786) and the member
// `PluginField const* fields` ("Pointer to PluginField entries.", header line 791) are
// missing from this extracted listing. The type name PluginFieldCollection is taken from
// its use in IPluginCreator::getFieldNames() / createPlugin().
787{
// Number of PluginField entries.
789 int32_t nbFields;
792};
793
801
// Plugin creator class for user implemented layers.
// NOTE(review): the class header line (header line 802) is missing from this extracted
// listing; the constructors below show the class is named IPluginCreator.
// Every member function is noexcept. All except getTensorRTVersion() are pure virtual.
803{
804public:
// Return the version of the API the plugin creator was compiled with.
// The default implementation returns the NV_TENSORRT_VERSION macro unchanged.
812 virtual int32_t getTensorRTVersion() const noexcept
813 {
814 return NV_TENSORRT_VERSION;
815 }
816
// Return the plugin name.
// IPluginV2::getPluginType() is documented as needing to match this value.
829 virtual AsciiChar const* getPluginName() const noexcept = 0;
830
// Returns a C string; presumably the plugin version string -- confirm.
843 virtual AsciiChar const* getPluginVersion() const noexcept = 0;
844
// Returns a pointer to a PluginFieldCollection. Presumably describes the fields that
// createPlugin() accepts -- confirm. Not const-qualified.
855 virtual PluginFieldCollection const* getFieldNames() noexcept = 0;
856
// Returns an IPluginV2* built from a name and a field collection.
// NOTE(review): ownership of the returned plugin and the result on failure (presumably
// nullptr) are not shown in this listing -- confirm.
866 virtual IPluginV2* createPlugin(AsciiChar const* name, PluginFieldCollection const* fc) noexcept = 0;
867
// Returns an IPluginV2* built from a name and a serialized byte range
// (`serialData`, `serialLength`). Presumably the inverse of IPluginV2::serialize()
// -- confirm.
877 virtual IPluginV2* deserializePlugin(AsciiChar const* name, void const* serialData, size_t serialLength) noexcept
878 = 0;
879
// Setter/getter pair for the creator's namespace string.
892 virtual void setPluginNamespace(AsciiChar const* pluginNamespace) noexcept = 0;
893
906 virtual AsciiChar const* getPluginNamespace() const noexcept = 0;
907
908 IPluginCreator() = default;
909 virtual ~IPluginCreator() = default;
910
911protected:
// Copy and move operations are protected; the assignment operators are
// lvalue-ref-qualified.
912// @cond SuppressDoxyWarnings
913 IPluginCreator(IPluginCreator const&) = default;
914 IPluginCreator(IPluginCreator&&) = default;
915 IPluginCreator& operator=(IPluginCreator const&) & = default;
916 IPluginCreator& operator=(IPluginCreator&&) & = default;
917 // @endcond
918};
919
920} // namespace nvinfer1
921
922#endif // NV_INFER_RUNTIME_PLUGIN_H
#define NV_TENSORRT_VERSION
Definition: NvInferRuntimeBase.h:76
#define TRT_DEPRECATED
Definition: NvInferRuntimeBase.h:40
Definition: NvInferRuntimeBase.h:179
Application-implemented class for controlling allocation on the GPU.
Definition: NvInferRuntimeBase.h:367
Plugin creator class for user implemented layers.
Definition: NvInferRuntimePlugin.h:803
virtual int32_t getTensorRTVersion() const noexcept
Return the version of the API the plugin creator was compiled with.
Definition: NvInferRuntimePlugin.h:812
virtual AsciiChar const * getPluginName() const noexcept=0
Return the plugin name.
Plugin class for user-implemented layers.
Definition: NvInferRuntimePlugin.h:408
~IPluginV2Ext() override=default
virtual bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept=0
Return true if plugin can use input that is broadcast across batch without replication.
void configureWithFormat(Dims const *, int32_t, Dims const *, int32_t, DataType, PluginFormat, int32_t) noexcept override
Derived classes should not implement this. In a C++11 API it would be override final.
Definition: NvInferRuntimePlugin.h:596
virtual bool isOutputBroadcastAcrossBatch(int32_t outputIndex, bool const *inputIsBroadcasted, int32_t nbInputs) const noexcept=0
Return true if output tensor is broadcast across a batch.
IPluginV2Ext * clone() const noexcept override=0
Clone the plugin object. This copies over internal plugin parameters as well and returns a new plugin...
virtual void configurePlugin(Dims const *inputDims, int32_t nbInputs, Dims const *outputDims, int32_t nbOutputs, DataType const *inputTypes, DataType const *outputTypes, bool const *inputIsBroadcast, bool const *outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept=0
Configure the layer with input and output data types.
virtual void detachFromContext() noexcept
Detach the plugin object from its execution context.
Definition: NvInferRuntimePlugin.h:552
virtual void attachToContext(cudnnContext *, cublasContext *, IGpuAllocator *) noexcept
Attach the plugin object to an execution context and grant the plugin the access to some context reso...
Definition: NvInferRuntimePlugin.h:534
virtual nvinfer1::DataType getOutputDataType(int32_t index, nvinfer1::DataType const *inputTypes, int32_t nbInputs) const noexcept=0
Return the DataType of the plugin output at the requested index.
Plugin class for user-implemented layers.
Definition: NvInferRuntimePlugin.h:98
virtual AsciiChar const * getPluginType() const noexcept=0
Return the plugin type. Should match the plugin name returned by the corresponding plugin creator.
virtual int32_t getTensorRTVersion() const noexcept
Return the API version with which this plugin was built.
Definition: NvInferRuntimePlugin.h:110
Plugin class for user-implemented layers.
Definition: NvInferRuntimePlugin.h:612
int32_t getTensorRTVersion() const noexcept override
Return the API version with which this plugin was built. The upper byte is reserved by TensorRT and i...
Definition: NvInferRuntimePlugin.h:700
virtual void configurePlugin(PluginTensorDesc const *in, int32_t nbInput, PluginTensorDesc const *out, int32_t nbOutput) noexcept=0
Configure the layer.
virtual bool supportsFormatCombination(int32_t pos, PluginTensorDesc const *inOut, int32_t nbInputs, int32_t nbOutputs) const noexcept=0
Return true if plugin supports the format and datatype for the input/output indexed by pos.
Structure containing plugin attribute field names and associated data. This information can be parsed ...
Definition: NvInferRuntimePlugin.h:755
AsciiChar const * name
Plugin field attribute name.
Definition: NvInferRuntimePlugin.h:760
PluginField(AsciiChar const *const name_=nullptr, void const *const data_=nullptr, PluginFieldType const type_=PluginFieldType::kUNKNOWN, int32_t const length_=0) noexcept
Definition: NvInferRuntimePlugin.h:775
void const * data
Plugin field attribute data.
Definition: NvInferRuntimePlugin.h:764
int32_t length
Number of data entries in the Plugin attribute.
Definition: NvInferRuntimePlugin.h:773
PluginFieldType type
Plugin field attribute type.
Definition: NvInferRuntimePlugin.h:769
The TensorRT API version 1 namespace.
PluginFieldType
The possible field types for custom layer.
Definition: NvInferRuntimePlugin.h:726
@ kUNKNOWN
Unknown field type.
@ kFLOAT32
FP32 field type.
@ kCHAR
char field type.
@ kINT16
INT16 field type.
@ kDIMS
nvinfer1::Dims field type.
@ kFLOAT64
FP64 field type.
@ kFLOAT16
FP16 field type.
char_t AsciiChar
Definition: NvInferRuntimeBase.h:94
@ kV2_DYNAMICEXT
IPluginV2DynamicExt.
@ kV2_IOEXT
IPluginV2IOExt.
@ kV2_EXT
IPluginV2Ext.
DataType
The type of weights and tensors.
Definition: NvInferRuntimeBase.h:120
@ kINT32
Signed 32-bit integer format.
TensorFormat PluginFormat
PluginFormat is reserved for backward compatibility.
Definition: NvInferRuntimePlugin.h:41
@ kINT8
Enable Int8 layer selection, with FP32 fallback, and with FP16 fallback if kFP16 is also specified.
TensorFormat
Format of the input/output tensors.
Definition: NvInferRuntimeBase.h:209
Definition of plugin versions.
Plugin field collection struct.
Definition: NvInferRuntimePlugin.h:787
PluginField const * fields
Pointer to PluginField entries.
Definition: NvInferRuntimePlugin.h:791
int32_t nbFields
Number of PluginField entries.
Definition: NvInferRuntimePlugin.h:789
Fields that a plugin might see for an input or output.
Definition: NvInferRuntimePlugin.h:54
DataType type
Definition: NvInferRuntimePlugin.h:58
Dims dims
Dimensions.
Definition: NvInferRuntimePlugin.h:56
TensorFormat format
Tensor format.
Definition: NvInferRuntimePlugin.h:60
float scale
Scale for INT8 data type.
Definition: NvInferRuntimePlugin.h:62