18#ifndef NV_INFER_RUNTIME_PLUGIN_H
19#define NV_INFER_RUNTIME_PLUGIN_H
21#define NV_INFER_INTERNAL_INCLUDE_RUNTIME_BASE 1
23#undef NV_INFER_INTERNAL_INCLUDE_RUNTIME_BASE
53static constexpr int32_t kPLUGIN_VERSION_PYTHON_BIT = 0x40;
175 virtual
AsciiChar const* getPluginVersion() const noexcept = 0;
190 virtual int32_t getNbOutputs() const noexcept = 0;
215 virtual
Dims getOutputDimensions(int32_t index,
Dims const* inputs, int32_t nbInputDims) noexcept = 0;
274 virtual
void configureWithFormat(
Dims const* inputDims, int32_t nbInputs,
Dims const* outputDims, int32_t nbOutputs,
289 virtual int32_t initialize() noexcept = 0;
304 virtual
void terminate() noexcept = 0;
323 virtual
size_t getWorkspaceSize(int32_t maxBatchSize) const noexcept = 0;
346 virtual int32_t enqueue(int32_t batchSize,
void const* const* inputs,
void* const* outputs,
void* workspace,
347 cudaStream_t stream) noexcept
360 virtual
size_t getSerializationSize() const noexcept = 0;
375 virtual
void serialize(
void* buffer) const noexcept = 0;
385 virtual
void destroy() noexcept = 0;
420 virtual
void setPluginNamespace(
AsciiChar const* pluginNamespace) noexcept = 0;
433 virtual
AsciiChar const* getPluginNamespace() const noexcept = 0;
514 int32_t outputIndex,
bool const* inputIsBroadcasted, int32_t nbInputs)
const noexcept
582 DataType const* inputTypes,
DataType const* outputTypes,
bool const* inputIsBroadcast,
583 bool const* outputIsBroadcast,
PluginFormat floatFormat, int32_t maxBatchSize)
noexcept
682 int32_t getTensorRTVersion() const noexcept
override
774 int32_t pos,
PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs)
const noexcept
974 virtual
AsciiChar const* getPluginVersion() const noexcept = 0;
1018 virtual
IPluginV2* deserializePlugin(
AsciiChar const* name,
void const* serialData,
size_t serialLength) noexcept
1035 virtual
void setPluginNamespace(
AsciiChar const* pluginNamespace) noexcept = 0;
1049 virtual
AsciiChar const* getPluginNamespace() const noexcept = 0;
#define NV_TENSORRT_VERSION
Definition: NvInferRuntimeBase.h:93
#define TRT_DEPRECATED
Definition: NvInferRuntimeBase.h:45
Definition: NvInferRuntimeBase.h:201
Plugin class for user-implemented layers.
Definition: NvInferRuntimePlugin.h:464
virtual TRT_DEPRECATED bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept=0
Return true if the plugin can use an input tensor that is broadcast across batch without replication.
~IPluginV2Ext() override=default
void configureWithFormat(Dims const *, int32_t, Dims const *, int32_t, DataType, PluginFormat, int32_t) noexcept override
Derived classes must not implement this. In a C++11 API it would be override final.
Definition: NvInferRuntimePlugin.h:694
IPluginV2Ext * clone() const noexcept override=0
Clone the plugin object. This copies over internal plugin parameters as well and returns a new plugin...
virtual void configurePlugin(Dims const *inputDims, int32_t nbInputs, Dims const *outputDims, int32_t nbOutputs, DataType const *inputTypes, DataType const *outputTypes, bool const *inputIsBroadcast, bool const *outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept=0
Configure the layer with input and output data types.
virtual void detachFromContext() noexcept
Detach the plugin object from its execution context.
Definition: NvInferRuntimePlugin.h:641
virtual TRT_DEPRECATED bool isOutputBroadcastAcrossBatch(int32_t outputIndex, bool const *inputIsBroadcasted, int32_t nbInputs) const noexcept=0
Return true if the output tensor is broadcast across a batch.
virtual void attachToContext(cudnnContext *, cublasContext *, IGpuAllocator *) noexcept
Attach the plugin object to an execution context and grant the plugin the access to some context reso...
Definition: NvInferRuntimePlugin.h:623
virtual nvinfer1::DataType getOutputDataType(int32_t index, nvinfer1::DataType const *inputTypes, int32_t nbInputs) const noexcept=0
Return the DataType of the plugin output at the requested index.
Plugin class for user-implemented layers.
Definition: NvInferRuntimePlugin.h:128
virtual AsciiChar const * getPluginType() const noexcept=0
Return the plugin type. Should match the plugin name returned by the corresponding plugin creator.
virtual int32_t getTensorRTVersion() const noexcept
Return the API version with which this plugin was built.
Definition: NvInferRuntimePlugin.h:142
Plugin class for user-implemented layers.
Definition: NvInferRuntimePlugin.h:713
int32_t getTensorRTVersion() const noexcept override
Return the API version with which this plugin was built. The upper byte is reserved by TensorRT and i...
Definition: NvInferRuntimePlugin.h:801
virtual void configurePlugin(PluginTensorDesc const *in, int32_t nbInput, PluginTensorDesc const *out, int32_t nbOutput) noexcept=0
Configure the layer.
virtual bool supportsFormatCombination(int32_t pos, PluginTensorDesc const *inOut, int32_t nbInputs, int32_t nbOutputs) const noexcept=0
Return true if plugin supports the format and datatype for the input/output indexed by pos.
An Interface class for version control.
Definition: NvInferRuntimeBase.h:399
Version information associated with a TRT interface.
Definition: NvInferRuntimeBase.h:364
Structure containing plugin attribute field names and associated data. This information can be parsed ...
Definition: NvInferRuntimePlugin.h:868
AsciiChar const * name
Plugin field attribute name.
Definition: NvInferRuntimePlugin.h:871
PluginField(AsciiChar const *const name_=nullptr, void const *const data_=nullptr, PluginFieldType const type_=PluginFieldType::kUNKNOWN, int32_t const length_=0) noexcept
Definition: NvInferRuntimePlugin.h:879
void const * data
Plugin field attribute data.
Definition: NvInferRuntimePlugin.h:873
int32_t length
Number of data entries in the Plugin attribute.
Definition: NvInferRuntimePlugin.h:877
PluginFieldType type
Plugin field attribute type.
Definition: NvInferRuntimePlugin.h:875
Definition: NvInferRuntimeBase.h:468
Definition: NvInferRuntimePlugin.h:946
virtual AsciiChar const * getPluginName() const noexcept=0
Return the plugin name.
Definition: NvInferRuntimePlugin.h:933
~IPluginCreatorInterface() noexcept override=default
The TensorRT API version 1 namespace.
PluginFieldType
The possible field types for custom layer.
Definition: NvInferRuntimePlugin.h:833
@ kUNKNOWN
Unknown field type.
@ kFLOAT32
FP32 field type.
@ kINT16
INT16 field type.
@ kDIMS
nvinfer1::Dims field type.
@ kFLOAT64
FP64 field type.
@ kFLOAT16
FP16 field type.
PluginCreatorVersion
Enum to identify version of the plugin creator.
Definition: NvInferRuntimePlugin.h:105
@ kV1_PYTHON
IPluginCreator-based Python plugin creators.
v_1_0::IPluginCreator IPluginCreator
Definition: NvInferRuntimePlugin.h:1091
PluginCapabilityType
Enumerates the different capability types an IPluginV3 object may have.
Definition: NvInferRuntimePlugin.h:908
@ kBUILD
Build capability. IPluginV3 objects provided to TensorRT build phase must have this.
@ kRUNTIME
Runtime capability. IPluginV3 objects provided to TensorRT build and execution phases must have this.
@ kCORE
Core capability. Every IPluginV3 object must have this.
char_t AsciiChar
Definition: NvInferRuntimeBase.h:107
TensorRTPhase
Indicates a phase of operation of TensorRT.
Definition: NvInferRuntimePlugin.h:923
@ kV2_DYNAMICEXT
IPluginV2DynamicExt.
@ kV2_IOEXT
IPluginV2IOExt.
@ kV2_DYNAMICEXT_PYTHON
IPluginV2DynamicExt-based Python plugins.
DataType
The type of weights and tensors.
Definition: NvInferRuntimeBase.h:135
@ kINT64
Signed 64-bit integer type.
@ kINT32
Signed 32-bit integer type.
TensorFormat PluginFormat
PluginFormat is reserved for backward compatibility.
Definition: NvInferRuntimePlugin.h:48
v_1_0::IPluginCreatorInterface IPluginCreatorInterface
Definition: NvInferRuntimePlugin.h:1079
@ kINT8
Enable INT8 layer selection, with FP32 fallback (and FP16 fallback if kFP16 is also specified).
TensorFormat
Format of the input/output tensors.
Definition: NvInferRuntimeBase.h:249
Definition of plugin versions.
Plugin field collection struct.
Definition: NvInferRuntimePlugin.h:895
PluginField const * fields
Pointer to PluginField entries.
Definition: NvInferRuntimePlugin.h:899
int32_t nbFields
Number of PluginField entries.
Definition: NvInferRuntimePlugin.h:897
Fields that a plugin might see for an input or output.
Definition: NvInferRuntimePlugin.h:67
DataType type
Definition: NvInferRuntimePlugin.h:71
Dims dims
Dimensions.
Definition: NvInferRuntimePlugin.h:69
TensorFormat format
Tensor format.
Definition: NvInferRuntimePlugin.h:73
float scale
Scale for INT8 data type.
Definition: NvInferRuntimePlugin.h:75