Go to the documentation of this file.
50 #ifndef NV_INFER_RUNTIME_COMMON_H
51 #define NV_INFER_RUNTIME_COMMON_H
56 #include <cuda_runtime_api.h>
59 #if __cplusplus >= 201402L
60 #define TRT_DEPRECATED [[deprecated]]
62 #define TRT_DEPRECATED_ENUM
64 #define TRT_DEPRECATED_ENUM TRT_DEPRECATED
67 #define TRT_DEPRECATED_API __declspec(dllexport)
69 #define TRT_DEPRECATED_API [[deprecated]] __attribute__((visibility("default")))
73 #define TRT_DEPRECATED
74 #define TRT_DEPRECATED_ENUM
75 #define TRT_DEPRECATED_API __declspec(dllexport)
77 #define TRT_DEPRECATED __attribute__((deprecated))
78 #define TRT_DEPRECATED_ENUM
79 #define TRT_DEPRECATED_API __attribute__((deprecated, visibility("default")))
84 #ifdef TENSORRT_BUILD_LIB
86 #define TENSORRTAPI __declspec(dllexport)
88 #define TENSORRTAPI __attribute__((visibility("default")))
105 struct cublasContext;
110 #define NV_TENSORRT_VERSION nvinfer1::kNV_TENSORRT_VERSION_IMPL
119 static constexpr int32_t kNV_TENSORRT_VERSION_IMPL
135 template <
typename T>
140 template <
typename T>
143 return impl::EnumMaxImpl<T>::kVALUE;
175 static constexpr int32_t kVALUE = 5;
354 static constexpr int32_t kVALUE = 12;
419 return NV_TENSORRT_VERSION;
542 virtual int32_t
enqueue(int32_t batchSize,
void const* const* inputs,
void* const* outputs,
void* workspace,
543 cudaStream_t stream) noexcept
561 virtual
void serialize(
void* buffer) const noexcept = 0;
566 virtual
void destroy() noexcept = 0;
630 int32_t index,
nvinfer1::DataType const* inputTypes, int32_t nbInputs)
const noexcept = 0;
642 virtual bool isOutputBroadcastAcrossBatch(
643 int32_t outputIndex,
bool const* inputIsBroadcasted, int32_t nbInputs)
const noexcept = 0;
658 virtual bool canBroadcastInputAcrossBatch(int32_t inputIndex)
const noexcept = 0;
689 virtual void configurePlugin(
Dims const* inputDims, int32_t nbInputs,
Dims const* outputDims, int32_t nbOutputs,
690 DataType const* inputTypes,
DataType const* outputTypes,
bool const* inputIsBroadcast,
691 bool const* outputIsBroadcast,
PluginFormat floatFormat, int32_t maxBatchSize) noexcept = 0;
750 | (static_cast<uint32_t>(NV_TENSORRT_VERSION) & 0xFFFFFFU));
785 virtual void configurePlugin(
820 virtual bool supportsFormatCombination(
821 int32_t pos,
PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs)
const noexcept = 0;
846 | (static_cast<uint32_t>(NV_TENSORRT_VERSION) & 0xFFFFFFU));
953 return NV_TENSORRT_VERSION;
962 virtual AsciiChar const* getPluginName() const noexcept = 0;
986 virtual
IPluginV2* deserializePlugin(
AsciiChar const* name,
void const* serialData,
size_t serialLength) noexcept = 0;
1050 virtual IPluginCreator*
const* getPluginCreatorList(int32_t*
const numCreators)
const noexcept = 0;
1060 AsciiChar const*
const pluginNamespace =
"") noexcept
1087 virtual void setErrorRecorder(
IErrorRecorder*
const recorder) noexcept = 0;
1113 virtual bool deregisterCreator(
IPluginCreator const& creator) noexcept = 0;
1127 static constexpr int32_t kVALUE = 1;
1131 using AllocatorFlags = uint32_t;
1158 virtual void* allocate(uint64_t
const size, uint64_t
const alignment, AllocatorFlags
const flags) noexcept = 0;
1174 TRT_DEPRECATED virtual void free(
void*
const memory) noexcept = 0;
1284 virtual void log(Severity severity,
AsciiChar const* msg) noexcept = 0;
1305 static constexpr int32_t kVALUE = 5;
1412 static constexpr int32_t kVALUE = 11;
1450 static constexpr
size_t kMAX_DESC_LENGTH = 127U;
1475 virtual int32_t getNbErrors() const noexcept = 0;
1489 virtual
ErrorCode getErrorCode(int32_t errorIdx) const noexcept = 0;
1505 virtual
ErrorDesc getErrorDesc(int32_t errorIdx) const noexcept = 0;
1516 virtual
bool hasOverflowed() const noexcept = 0;
1527 virtual
void clear() noexcept = 0;
1559 virtual
RefCount incRefCount() noexcept = 0;
1572 virtual
RefCount decRefCount() noexcept = 0;
1591 #endif // NV_INFER_RUNTIME_COMMON_H
char_t AsciiChar
AsciiChar is the type used by TensorRT to represent valid ASCII characters.
Definition: NvInferRuntimeCommon.h:125
char char_t
char_t is the type used by TensorRT to represent all valid characters.
Definition: NvInferRuntimeCommon.h:123
Signed 32-bit integer format.
virtual AsciiChar const * getPluginType() const noexcept=0
Return the plugin type. Should match the plugin name returned by the corresponding plugin creator.
virtual Dims getOutputDimensions(int32_t index, Dims const *inputs, int32_t nbInputDims) noexcept=0
Get the dimension of an output tensor.
ErrorCode
Error codes that can be returned by TensorRT during execution.
Definition: NvInferRuntimeCommon.h:1314
float scale
Scale for INT8 data type.
Definition: NvInferRuntimeCommon.h:377
nvinfer1::Dims field type.
8-bit boolean. 0 = false, 1 = true, other values undefined.
Declaration of EnumMaxImpl struct to store maximum number of elements in an enumeration type.
Definition: NvInferRuntimeCommon.h:136
virtual void attachToContext(cudnnContext *, cublasContext *, IGpuAllocator *) noexcept
Attach the plugin object to an execution context and grant the plugin the access to some context reso...
Definition: NvInferRuntimeCommon.h:713
Plugin class for user-implemented layers.
Definition: NvInferRuntimeCommon.h:409
Definition: NvInferRuntimeCommon.h:189
virtual IPluginV2 * clone() const noexcept=0
Clone the plugin object. This copies over internal plugin parameters and returns a new plugin object ...
virtual AsciiChar const * getPluginVersion() const noexcept=0
Return the plugin version. Should match the plugin version returned by the corresponding plugin creat...
virtual void setPluginNamespace(AsciiChar const *pluginNamespace) noexcept=0
Set the namespace that this plugin object belongs to. Ideally, all plugin objects from the same plugi...
int32_t RefCount
Definition: NvInferRuntimeCommon.h:1455
IEEE 16-bit floating-point format.
Enable Int8 layer selection, with FP32 fallback, and with FP16 fallback if kFP16 is also specified.
TensorFormat
Format of the input/output tensors.
Definition: NvInferRuntimeCommon.h:220
Plugin creator class for user implemented layers.
Definition: NvInferRuntimeCommon.h:945
Single registration point for all plugins in an application. It is used to find plugin implementation...
Definition: NvInferRuntimeCommon.h:1034
PluginFieldType
Definition: NvInferRuntimeCommon.h:868
Application-implemented logging interface for the builder, refitter and runtime.
Definition: NvInferRuntimeCommon.h:1256
virtual size_t getWorkspaceSize(int32_t maxBatchSize) const noexcept=0
Find the workspace size required by the layer.
virtual void * reallocate(void *, uint64_t, uint64_t) noexcept
Definition: NvInferRuntimeCommon.h:1212
int32_t d[MAX_DIMS]
The extent of each dimension.
Definition: NvInferRuntimeCommon.h:197
virtual AsciiChar const * getPluginNamespace() const noexcept=0
Return the namespace of the plugin object.
int32_t length
Number of data entries in the Plugin attribute.
Definition: NvInferRuntimeCommon.h:916
PluginFieldType type
Plugin field attribute type.
Definition: NvInferRuntimeCommon.h:912
DataType type
Definition: NvInferRuntimeCommon.h:373
void configureWithFormat(Dims const *, int32_t, Dims const *, int32_t, DataType, PluginFormat, int32_t) noexcept override
Derived classes should not implement this. In a C++11 API it would be override final.
Definition: NvInferRuntimeCommon.h:756
The TensorRT API version 1 namespace.
TensorFormat format
Tensor format.
Definition: NvInferRuntimeCommon.h:375
#define NV_TENSORRT_MINOR
TensorRT minor version.
Definition: NvInferVersion.h:60
int32_t nbDims
The rank (number of dimensions).
Definition: NvInferRuntimeCommon.h:195
char const * ErrorDesc
Definition: NvInferRuntimeCommon.h:1445
virtual void detachFromContext() noexcept
Detach the plugin object from its execution context.
Definition: NvInferRuntimeCommon.h:723
Definition of plugin versions.
Perform the normal matrix multiplication in the first recurrent layer.
int32_t nbFields
Number of PluginField entries.
Definition: NvInferRuntimeCommon.h:932
virtual int32_t getNbOutputs() const noexcept=0
Get the number of outputs from the layer.
DataType
The type of weights and tensors.
Definition: NvInferRuntimeCommon.h:150
int32_t getTensorRTVersion() const noexcept override
Return the API version with which this plugin was built. The upper byte is reserved by TensorRT and i...
Definition: NvInferRuntimeCommon.h:843
Severity
Definition: NvInferRuntimeCommon.h:1264
constexpr int32_t EnumMax() noexcept
Maximum number of elements in an enumeration type.
Definition: NvInferRuntimeCommon.h:141
virtual int32_t getTensorRTVersion() const noexcept
Return the API version with which this plugin was built.
Definition: NvInferRuntimeCommon.h:417
Plugin class for user-implemented layers.
Definition: NvInferRuntimeCommon.h:771
virtual int32_t enqueue(int32_t batchSize, void const *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept=0
Execute the layer.
Reference counted application-implemented error reporting interface for TensorRT objects.
Definition: NvInferRuntimeCommon.h:1439
TensorRT may call realloc() on this allocation.
virtual void configureWithFormat(Dims const *inputDims, int32_t nbInputs, Dims const *outputDims, int32_t nbOutputs, DataType type, PluginFormat format, int32_t maxBatchSize) noexcept=0
Configure the layer.
virtual size_t getSerializationSize() const noexcept=0
Find the size of the serialization buffer required.
virtual void serialize(void *buffer) const noexcept=0
Serialize the layer.
Plugin class for user-implemented layers.
Definition: NvInferRuntimeCommon.h:618
static constexpr int32_t MAX_DIMS
The maximum rank (number of dimensions) supported for a tensor.
Definition: NvInferRuntimeCommon.h:193
32-bit floating point format.
AsciiChar const * name
Plugin field attribute name.
Definition: NvInferRuntimeCommon.h:903
PluginField const * fields
Pointer to PluginField entries.
Definition: NvInferRuntimeCommon.h:934
Dims dims
Dimensions.
Definition: NvInferRuntimeCommon.h:371
Fields that a plugin might see for an input or output.
Definition: NvInferRuntimeCommon.h:368
virtual int32_t initialize() noexcept=0
Initialize the layer for execution. This is called when the engine is created.
#define NV_TENSORRT_MAJOR
TensorRT major version.
Definition: NvInferVersion.h:59
TensorFormat PluginFormat
PluginFormat is reserved for backward compatibility.
Definition: NvInferRuntimeCommon.h:345
virtual void destroy() noexcept=0
Destroy the plugin object. This will be called when the network, builder or engine is destroyed.
virtual bool supportsFormat(DataType type, PluginFormat format) const noexcept=0
Check format support.
virtual void terminate() noexcept=0
Release resources acquired during plugin layer initialization. This is called when the engine is dest...
AllocatorFlag
Definition: NvInferRuntimeCommon.h:1116
Application-implemented class for controlling allocation on the GPU.
Definition: NvInferRuntimeCommon.h:1138
virtual bool deallocate(void *const memory) noexcept
Definition: NvInferRuntimeCommon.h:1233
#define NV_TENSORRT_PATCH
TensorRT patch version.
Definition: NvInferVersion.h:61
Plugin field collection struct.
Definition: NvInferRuntimeCommon.h:929
Structure containing plugin attribute field names and associated data. This information can be parsed ...
Definition: NvInferRuntimeCommon.h:897
virtual int32_t getTensorRTVersion() const noexcept
Return the version of the API the plugin creator was compiled with.
Definition: NvInferRuntimeCommon.h:951
#define TRT_DEPRECATED
Items that are marked as deprecated will be removed in a future release.
Definition: NvInferRuntimeCommon.h:77
int32_t getInferLibVersion() noexcept
Return the library version number.
void const * data
Plugin field attribute data.
Definition: NvInferRuntimeCommon.h:907