TensorRT 8.4.3
NvInferRuntimeCommon.h
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 *
5 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 * property and proprietary rights in and to this material, related
7 * documentation and any modifications thereto. Any use, reproduction,
8 * disclosure or distribution of this material and related documentation
9 * without an express license agreement from NVIDIA CORPORATION or
10 * its affiliates is strictly prohibited.
11 */
12
13#ifndef NV_INFER_RUNTIME_COMMON_H
14#define NV_INFER_RUNTIME_COMMON_H
15
16#include "NvInferVersion.h"
17#include <cstddef>
18#include <cstdint>
19#include <cuda_runtime_api.h>
20
21// Items that are marked as deprecated will be removed in a future release.
22#if __cplusplus >= 201402L // C++14 or newer: use the standard [[deprecated]] attribute.
23#define TRT_DEPRECATED [[deprecated]]
24#if __GNUC__ < 6 // An undefined __GNUC__ evaluates to 0 here, so non-GNU compilers also take this branch.
25#define TRT_DEPRECATED_ENUM // Expands to nothing. NOTE(review): presumably GCC < 6 rejects attributes on enumerators -- confirm.
26#else
27#define TRT_DEPRECATED_ENUM TRT_DEPRECATED
28#endif
29#ifdef _MSC_VER
30#define TRT_DEPRECATED_API __declspec(dllexport) // MSVC: export only; no deprecation marker is attached.
31#else
32#define TRT_DEPRECATED_API [[deprecated]] __attribute__((visibility("default"))) // Deprecated and exported with default visibility.
33#endif
34#else // Pre-C++14: fall back to compiler-specific spellings.
35#ifdef _MSC_VER
36#define TRT_DEPRECATED // Expands to nothing on MSVC in this mode.
37#define TRT_DEPRECATED_ENUM
38#define TRT_DEPRECATED_API __declspec(dllexport)
39#else
40#define TRT_DEPRECATED __attribute__((deprecated))
41#define TRT_DEPRECATED_ENUM
42#define TRT_DEPRECATED_API __attribute__((deprecated, visibility("default")))
43#endif
44#endif
45
46// Defines which symbols are exported
47#ifdef TENSORRT_BUILD_LIB // Symbols are exported only when TENSORRT_BUILD_LIB is defined.
48#ifdef _MSC_VER
49#define TENSORRTAPI __declspec(dllexport)
50#else
51#define TENSORRTAPI __attribute__((visibility("default")))
52#endif
53#else
54#define TENSORRTAPI // Expands to nothing when TENSORRT_BUILD_LIB is not defined.
55#endif
56#define TRTNOEXCEPT // Expands to nothing. NOTE(review): presumably kept for source compatibility -- confirm.
62
63// forward declare some CUDA types to avoid an include dependency
64
65extern "C"
66{
68 struct cublasContext; // Incomplete type; this header only passes it by pointer (see attachToContext).
70 struct cudnnContext; // Incomplete type; this header only passes it by pointer (see attachToContext).
71}
72
73#define NV_TENSORRT_VERSION nvinfer1::kNV_TENSORRT_VERSION_IMPL // Forwards to the constexpr version constant in namespace nvinfer1.
79namespace nvinfer1
80{
81
82static constexpr int32_t kNV_TENSORRT_VERSION_IMPL // Version packed as major * 1000 + minor * 100 + patch.
83 = (NV_TENSORRT_MAJOR * 1000) + (NV_TENSORRT_MINOR * 100) + NV_TENSORRT_PATCH; // major, minor, patch
84
86using char_t = char; //!< char_t is the type used by TensorRT to represent all valid characters.
88using AsciiChar = char_t; //!< AsciiChar is the type used by TensorRT to represent valid ASCII characters.
89
91class IErrorRecorder;
93class IGpuAllocator;
94
95namespace impl
96{
98template <typename T>
99struct EnumMaxImpl; //!< Declaration of EnumMaxImpl struct to store maximum number of elements in an enumeration type.
100} // namespace impl
101
103template <typename T>
104constexpr int32_t EnumMax() noexcept // Maximum number of elements in an enumeration type.
105{
106 return impl::EnumMaxImpl<T>::kVALUE; // Value comes from the per-enum EnumMaxImpl specialization.
107}
108
113enum class DataType : int32_t //!< The type of weights and tensors.
114{
116 kFLOAT = 0, //!< 32-bit floating point format.
117
119 kHALF = 1, //!< IEEE 16-bit floating-point format.
120
122 kINT8 = 2, //!< 8-bit integer representing a quantized floating-point value.
123
125 kINT32 = 3, //!< Signed 32-bit integer format.
126
128 kBOOL = 4 //!< 8-bit boolean. 0 = false, 1 = true, other values undefined.
129};
130
131namespace impl
132{
134template <>
135struct EnumMaxImpl<DataType> //!< Maximum number of elements in the DataType enum.
136{
137 // Declaration of kVALUE that represents maximum number of elements in DataType enum
138 static constexpr int32_t kVALUE = 5;
139};
140} // namespace impl
141
152class Dims32 //!< Holds a tensor rank together with the extent of each dimension.
153{
154public:
156 static constexpr int32_t MAX_DIMS{8}; //!< The maximum rank (number of dimensions) supported for a tensor.
158 int32_t nbDims; //!< The rank (number of dimensions).
160 int32_t d[MAX_DIMS]; //!< The extent of each dimension.
161};
162
168using Dims = Dims32; //!< Alias for Dims32.
169
182enum class TensorFormat : int32_t //!< Format of the input/output tensors.
183{
191 kLINEAR = 0,
192
199 kCHW2 = 1,
200
207 kHWC8 = 2,
208
224 kCHW4 = 3,
225
236 kCHW16 = 4,
237
247 kCHW32 = 5,
248
255 kDHWC8 = 6,
256
263 kCDHW32 = 7,
264
267 kHWC = 8,
268
277 kDLA_LINEAR = 9,
278
292 kDLA_HWC4 = 10,
293
300 kHWC16 = 11
301};
302
308using PluginFormat = TensorFormat; //!< PluginFormat is reserved for backward compatibility.
309
310namespace impl
311{
313template <>
314struct EnumMaxImpl<TensorFormat> //!< Maximum number of elements in the TensorFormat enum.
315{
317 static constexpr int32_t kVALUE = 12; // TensorFormat declares twelve enumerators, kLINEAR (0) through kHWC16 (11).
318};
319} // namespace impl
320
331struct PluginTensorDesc //!< Fields that a plugin might see for an input or output.
332{
334 Dims dims; //!< Dimensions.
336 DataType type; //!< Data type of the tensor.
338 TensorFormat format; //!< Tensor format.
340 float scale; //!< Scale for INT8 data type.
341};
342
349enum class PluginVersion : uint8_t //!< Definition of plugin versions.
350{
352 kV2 = 0, //!< IPluginV2.
354 kV2_EXT = 1, //!< IPluginV2Ext.
356 kV2_IOEXT = 2, //!< IPluginV2IOExt.
358 kV2_DYNAMICEXT = 3, //!< IPluginV2DynamicExt.
359};
360
372class IPluginV2 //!< Plugin class for user-implemented layers.
373{
374public:
385 virtual int32_t getTensorRTVersion() const noexcept // Return the API version with which this plugin was built.
386 {
387 return NV_TENSORRT_VERSION;
388 }
389
402 virtual AsciiChar const* getPluginType() const noexcept = 0; //!< Return the plugin type. Should match the plugin name returned by the corresponding plugin creator.
403
416 virtual AsciiChar const* getPluginVersion() const noexcept = 0; //!< Return the plugin version.
417
431 virtual int32_t getNbOutputs() const noexcept = 0; //!< Get the number of outputs from the layer.
432
452 virtual Dims getOutputDimensions(int32_t index, Dims const* inputs, int32_t nbInputDims) noexcept = 0; //!< Get the dimension of an output tensor.
453
476 virtual bool supportsFormat(DataType type, PluginFormat format) const noexcept = 0; //!< Check format support.
477
509 virtual void configureWithFormat(Dims const* inputDims, int32_t nbInputs, Dims const* outputDims, int32_t nbOutputs,
510 DataType type, PluginFormat format, int32_t maxBatchSize) noexcept
511 = 0; //!< Configure the layer.
512
524 virtual int32_t initialize() noexcept = 0; //!< Initialize the layer for execution. This is called when the engine is created.
525
538 virtual void terminate() noexcept = 0; //!< Release resources acquired during plugin layer initialization.
539
554 virtual size_t getWorkspaceSize(int32_t maxBatchSize) const noexcept = 0; //!< Find the workspace size required by the layer.
555
572 virtual int32_t enqueue(int32_t batchSize, void const* const* inputs, void* const* outputs, void* workspace,
573 cudaStream_t stream) noexcept
574 = 0; //!< Execute the layer.
575
586 virtual size_t getSerializationSize() const noexcept = 0; //!< Find the size of the serialization buffer required.
587
601 virtual void serialize(void* buffer) const noexcept = 0; //!< Serialize the layer.
602
611 virtual void destroy() noexcept = 0; //!< Destroy the plugin object. This will be called when the network, builder or engine is destroyed.
612
627 virtual IPluginV2* clone() const noexcept = 0; //!< Clone the plugin object, copying over internal plugin parameters.
628
643 virtual void setPluginNamespace(AsciiChar const* pluginNamespace) noexcept = 0; //!< Set the namespace that this plugin object belongs to.
644
653 virtual AsciiChar const* getPluginNamespace() const noexcept = 0; //!< Return the namespace of the plugin object.
654
655 // @cond SuppressDoxyWarnings
656 IPluginV2() = default;
657 virtual ~IPluginV2() noexcept = default;
658// @endcond
659
660protected:
661// @cond SuppressDoxyWarnings
662 IPluginV2(IPluginV2 const&) = default;
663 IPluginV2(IPluginV2&&) = default;
664 IPluginV2& operator=(IPluginV2 const&) & = default;
665 IPluginV2& operator=(IPluginV2&&) & = default;
666// @endcond
667};
668
679class IPluginV2Ext : public IPluginV2 //!< Plugin class for user-implemented layers.
680{
681public:
697 virtual nvinfer1::DataType getOutputDataType(
698 int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept
699 = 0; //!< Return the DataType of the plugin output at the requested index.
700
716 virtual bool isOutputBroadcastAcrossBatch(
717 int32_t outputIndex, bool const* inputIsBroadcasted, int32_t nbInputs) const noexcept
718 = 0; //!< Return true if output tensor is broadcast across a batch.
719
738 virtual bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept = 0; //!< Return true if plugin can use input that is broadcast across batch without replication.
739
774 virtual void configurePlugin(Dims const* inputDims, int32_t nbInputs, Dims const* outputDims, int32_t nbOutputs,
775 DataType const* inputTypes, DataType const* outputTypes, bool const* inputIsBroadcast,
776 bool const* outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept
777 = 0; //!< Configure the layer with input and output data types.
778
779 IPluginV2Ext() = default;
780 ~IPluginV2Ext() override = default;
781
805 virtual void attachToContext( // Attach the plugin object to an execution context; the default does nothing.
806 cudnnContext* /*cudnn*/, cublasContext* /*cublas*/, IGpuAllocator* /*allocator*/) noexcept
807 {
808 }
809
823 virtual void detachFromContext() noexcept {} // Detach the plugin object from its execution context; the default does nothing.
824
837 IPluginV2Ext* clone() const noexcept override = 0; //!< Clone the plugin object, copying over internal plugin parameters as well.
838
839protected:
840 // @cond SuppressDoxyWarnings
841 IPluginV2Ext(IPluginV2Ext const&) = default;
842 IPluginV2Ext(IPluginV2Ext&&) = default;
843 IPluginV2Ext& operator=(IPluginV2Ext const&) & = default;
844 IPluginV2Ext& operator=(IPluginV2Ext&&) & = default;
845// @endcond
846
858 int32_t getTensorRTVersion() const noexcept override // Packs PluginVersion::kV2_EXT into the top byte and NV_TENSORRT_VERSION into the low 24 bits.
859 {
860 return static_cast<int32_t>((static_cast<uint32_t>(PluginVersion::kV2_EXT) << 24U)
861 | (static_cast<uint32_t>(NV_TENSORRT_VERSION) & 0xFFFFFFU));
862 }
863
867 void configureWithFormat(Dims const* /*inputDims*/, int32_t /*nbInputs*/, Dims const* /*outputDims*/, // Derived classes should not implement this. In a C++11 API it would be override final.
868 int32_t /*nbOutputs*/, DataType /*type*/, PluginFormat /*format*/, int32_t /*maxBatchSize*/) noexcept override
869 {
870 }
871};
872
882class IPluginV2IOExt : public IPluginV2Ext //!< Plugin class for user-implemented layers.
883{
884public:
902 virtual void configurePlugin(
903 PluginTensorDesc const* in, int32_t nbInput, PluginTensorDesc const* out, int32_t nbOutput) noexcept
904 = 0; //!< Configure the layer.
905
943 virtual bool supportsFormatCombination(
944 int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) const noexcept
945 = 0; //!< Return true if plugin supports the format and datatype for the input/output indexed by pos.
946
947 // @cond SuppressDoxyWarnings
948 IPluginV2IOExt() = default;
949 ~IPluginV2IOExt() override = default;
950// @endcond
951
952protected:
953// @cond SuppressDoxyWarnings
954 IPluginV2IOExt(IPluginV2IOExt const&) = default;
955 IPluginV2IOExt(IPluginV2IOExt&&) = default;
956 IPluginV2IOExt& operator=(IPluginV2IOExt const&) & = default;
957 IPluginV2IOExt& operator=(IPluginV2IOExt&&) & = default;
958// @endcond
959
971 int32_t getTensorRTVersion() const noexcept override // Packs PluginVersion::kV2_IOEXT into the top byte and NV_TENSORRT_VERSION into the low 24 bits.
972 {
973 return static_cast<int32_t>((static_cast<uint32_t>(PluginVersion::kV2_IOEXT) << 24U)
974 | (static_cast<uint32_t>(NV_TENSORRT_VERSION) & 0xFFFFFFU));
975 }
976
977private:
978 // Following are obsolete base class methods, and must not be implemented or used.
979
980 void configurePlugin(Dims const*, int32_t, Dims const*, int32_t, DataType const*, DataType const*, bool const*,
981 bool const*, PluginFormat, int32_t) noexcept final
982 {
983 }
984
985 bool supportsFormat(DataType, PluginFormat) const noexcept final
986 {
987 return false;
988 }
989};
990
995
996enum class PluginFieldType : int32_t // Type tag for the data carried by a PluginField.
997{
999 kFLOAT16 = 0, //!< FP16 field type.
1001 kFLOAT32 = 1, //!< FP32 field type.
1003 kFLOAT64 = 2, //!< FP64 field type.
1005 kINT8 = 3, //!< INT8 field type.
1007 kINT16 = 4, //!< INT16 field type.
1009 kINT32 = 5, //!< INT32 field type.
1011 kCHAR = 6, //!< char field type.
1013 kDIMS = 7, //!< nvinfer1::Dims field type.
1015 kUNKNOWN = 8 //!< Unknown field type.
1016};
1017
1025class PluginField //!< Structure containing plugin attribute field names and associated data.
1026{
1027public:
1031 AsciiChar const* name; //!< Plugin field attribute name.
1035 void const* data; //!< Plugin field attribute data.
1040 PluginFieldType type; //!< Plugin field attribute type.
1044 int32_t length; //!< Number of data entries in the Plugin attribute.
1045
1046 PluginField(AsciiChar const* const name_ = nullptr, void const* const data_ = nullptr, // All arguments are optional; defaults give an empty field of unknown type.
1047 PluginFieldType const type_ = PluginFieldType::kUNKNOWN, int32_t const length_ = 0) noexcept
1048 : name(name_)
1049 , data(data_)
1050 , type(type_)
1051 , length(length_)
1052 {
1053 }
1054};
1055
1057struct PluginFieldCollection //!< Plugin field collection struct.
1058{
1060 int32_t nbFields; //!< Number of PluginField entries.
1062 PluginField const* fields; //!< Pointer to PluginField entries.
1063};
1064
1072
1073class IPluginCreator //!< Plugin creator class for user implemented layers.
1074{
1075public:
1083 virtual int32_t getTensorRTVersion() const noexcept // Return the version of the API the plugin creator was compiled with.
1084 {
1085 return NV_TENSORRT_VERSION;
1086 }
1087
1100 virtual AsciiChar const* getPluginName() const noexcept = 0; //!< Return the plugin name.
1101
1114 virtual AsciiChar const* getPluginVersion() const noexcept = 0; //!< Return the plugin version.
1115
1126 virtual PluginFieldCollection const* getFieldNames() noexcept = 0; //!< Return the collection of fields associated with this creator.
1127
1137 virtual IPluginV2* createPlugin(AsciiChar const* name, PluginFieldCollection const* fc) noexcept = 0; //!< Return a plugin object built from a name and a field collection.
1138
1148 virtual IPluginV2* deserializePlugin(AsciiChar const* name, void const* serialData, size_t serialLength) noexcept
1149 = 0; //!< Return a plugin object built from a name and a serialized buffer of serialLength bytes.
1150
1163 virtual void setPluginNamespace(AsciiChar const* pluginNamespace) noexcept = 0; //!< Set the namespace of the plugin creator.
1164
1177 virtual AsciiChar const* getPluginNamespace() const noexcept = 0; //!< Return the namespace of the plugin creator.
1178
1179 IPluginCreator() = default;
1180 virtual ~IPluginCreator() = default;
1181
1182protected:
1183// @cond SuppressDoxyWarnings
1184 IPluginCreator(IPluginCreator const&) = default;
1185 IPluginCreator(IPluginCreator&&) = default;
1186 IPluginCreator& operator=(IPluginCreator const&) & = default;
1187 IPluginCreator& operator=(IPluginCreator&&) & = default;
1188// @endcond
1189};
1190
1208
1209class IPluginRegistry //!< Single registration point for all plugins in an application.
1210{
1211public:
1223 virtual bool registerCreator(IPluginCreator& creator, AsciiChar const* const pluginNamespace) noexcept = 0; //!< Register a plugin creator. Returns false if one with same type is already registered.
1224
1233 virtual IPluginCreator* const* getPluginCreatorList(int32_t* const numCreators) const noexcept = 0; //!< Return all the registered plugin creators and the number of registered plugin creators.
1234
1246 virtual IPluginCreator* getPluginCreator(AsciiChar const* const pluginName, AsciiChar const* const pluginVersion,
1247 AsciiChar const* const pluginNamespace = "") noexcept
1248 = 0; //!< Return plugin creator based on plugin name, version, and namespace.
1249
1250 // @cond SuppressDoxyWarnings
1251 IPluginRegistry() = default;
1252 IPluginRegistry(IPluginRegistry const&) = delete;
1253 IPluginRegistry(IPluginRegistry&&) = delete;
1254 IPluginRegistry& operator=(IPluginRegistry const&) & = delete;
1255 IPluginRegistry& operator=(IPluginRegistry&&) & = delete;
1256// @endcond
1257
1258protected:
1259 virtual ~IPluginRegistry() noexcept = default; // Protected: the registry cannot be deleted through this interface by outside code.
1260
1261public:
1271 //
1278 virtual void setErrorRecorder(IErrorRecorder* const recorder) noexcept = 0; //!< Set the IErrorRecorder used by this registry.
1279
1295 virtual IErrorRecorder* getErrorRecorder() const noexcept = 0; //!< Return the IErrorRecorder currently assigned to this registry.
1296
1312 virtual bool deregisterCreator(IPluginCreator const& creator) noexcept = 0; //!< Deregister a plugin creator; the bool result reports the outcome.
1313};
1314
1315enum class AllocatorFlag : int32_t // Flags passed to IGpuAllocator::allocate via AllocatorFlags.
1316{
1317 kRESIZABLE = 0, //!< TensorRT may call realloc() on this allocation.
1318};
1319
1320namespace impl
1321{
1323template <>
1324struct EnumMaxImpl<AllocatorFlag> //!< Maximum number of elements in the AllocatorFlag enum.
1325{
1326 static constexpr int32_t kVALUE = 1; // AllocatorFlag declares a single enumerator, kRESIZABLE.
1327};
1328} // namespace impl
1329
1330using AllocatorFlags = uint32_t; // NOTE(review): presumably a bitmask whose bit positions are AllocatorFlag values -- confirm.
1331
1337class IGpuAllocator //!< Application-implemented class for controlling allocation on the GPU.
1338{
1339public:
1361 virtual void* allocate(uint64_t const size, uint64_t const alignment, AllocatorFlags const flags) noexcept = 0; //!< Allocate memory of the given size and alignment, subject to the given flags.
1362
1381 TRT_DEPRECATED virtual void free(void* const memory) noexcept = 0; //!< Deprecated release hook; deallocate() forwards to it by default.
1382
1387 virtual ~IGpuAllocator() = default;
1388 IGpuAllocator() = default;
1389
1423 virtual void* reallocate(void* /*baseAddr*/, uint64_t /*alignment*/, uint64_t /*newSize*/) noexcept // Default implementation ignores its arguments and returns nullptr.
1424 {
1425 return nullptr;
1426 }
1427
1448 virtual bool deallocate(void* const memory) noexcept // Default implementation forwards to free() and always returns true.
1449 {
1450 this->free(memory);
1451 return true;
1452 }
1453
1454protected:
1455// @cond SuppressDoxyWarnings
1456 IGpuAllocator(IGpuAllocator const&) = default;
1457 IGpuAllocator(IGpuAllocator&&) = default;
1458 IGpuAllocator& operator=(IGpuAllocator const&) & = default;
1459 IGpuAllocator& operator=(IGpuAllocator&&) & = default;
1460// @endcond
1461};
1462
1475class ILogger //!< Application-implemented logging interface for the builder, refitter and runtime.
1476{
1477public:
1483 enum class Severity : int32_t // Severity levels passed to log(); values run from 0 to 4.
1484 {
1486 kINTERNAL_ERROR = 0,
1488 kERROR = 1,
1490 kWARNING = 2,
1492 kINFO = 3,
1494 kVERBOSE = 4,
1495 };
1496
1509 virtual void log(Severity severity, AsciiChar const* msg) noexcept = 0; //!< Receive one message together with its severity.
1510
1511 ILogger() = default;
1512 virtual ~ILogger() = default;
1513
1514protected:
1515// @cond SuppressDoxyWarnings
1516 ILogger(ILogger const&) = default;
1517 ILogger(ILogger&&) = default;
1518 ILogger& operator=(ILogger const&) & = default;
1519 ILogger& operator=(ILogger&&) & = default;
1520// @endcond
1521};
1522
1523namespace impl
1524{
1526template <>
1527struct EnumMaxImpl<ILogger::Severity> //!< Maximum number of elements in the ILogger::Severity enum.
1528{
1530 static constexpr int32_t kVALUE = 5; // Severity declares five enumerators, kINTERNAL_ERROR (0) through kVERBOSE (4).
1531};
1532} // namespace impl
1533
1539enum class ErrorCode : int32_t
1540{
1544 kSUCCESS = 0,
1545
1550
1555 kINTERNAL_ERROR = 2,
1556
1562
1570 kINVALID_CONFIG = 4,
1571
1578
1584
1592
1601
1614 kINVALID_STATE = 9,
1615
1626 kUNSUPPORTED_STATE = 10,
1627
1628};
1629
1630namespace impl
1631{
1633template <>
1634struct EnumMaxImpl<ErrorCode> //!< Maximum number of elements in the ErrorCode enum.
1635{
1637 static constexpr int32_t kVALUE = 11; // ErrorCode values run from kSUCCESS (0) through kUNSUPPORTED_STATE (10).
1638};
1639} // namespace impl
1640
1665{
1666public:
1670 using ErrorDesc = char const*;
1671
1675 static constexpr size_t kMAX_DESC_LENGTH{127U};
1676
1680 using RefCount = int32_t;
1681
1682 IErrorRecorder() = default;
1683 virtual ~IErrorRecorder() noexcept = default;
1684
1685 // Public API used to retrieve information from the error recorder.
1686
1705 virtual int32_t getNbErrors() const noexcept = 0;
1706
1724 virtual ErrorCode getErrorCode(int32_t errorIdx) const noexcept = 0;
1725
1745 virtual ErrorDesc getErrorDesc(int32_t errorIdx) const noexcept = 0;
1746
1761 virtual bool hasOverflowed() const noexcept = 0;
1762
1777 virtual void clear() noexcept = 0;
1778
1779 // API used by TensorRT to report Error information to the application.
1780
1801 virtual bool reportError(ErrorCode val, ErrorDesc desc) noexcept = 0;
1802
1819 virtual RefCount incRefCount() noexcept = 0;
1820
1837 virtual RefCount decRefCount() noexcept = 0;
1838
1839protected:
1840 // @cond SuppressDoxyWarnings
1841 IErrorRecorder(IErrorRecorder const&) = default;
1842 IErrorRecorder(IErrorRecorder&&) = default;
1843 IErrorRecorder& operator=(IErrorRecorder const&) & = default;
1844 IErrorRecorder& operator=(IErrorRecorder&&) & = default;
1845 // @endcond
1846}; // class IErrorRecorder
1847} // namespace nvinfer1
1848
1854extern "C" TENSORRTAPI int32_t getInferLibVersion() noexcept; //!< Return the library version number.
1855
1856#endif // NV_INFER_RUNTIME_COMMON_H
#define TENSORRTAPI
Definition: NvInferRuntimeCommon.h:54
int32_t getInferLibVersion() noexcept
Return the library version number.
#define NV_TENSORRT_VERSION
Definition: NvInferRuntimeCommon.h:73
#define TRT_DEPRECATED
Definition: NvInferRuntimeCommon.h:40
#define NV_TENSORRT_MINOR
TensorRT minor version.
Definition: NvInferVersion.h:23
#define NV_TENSORRT_MAJOR
TensorRT major version.
Definition: NvInferVersion.h:22
#define NV_TENSORRT_PATCH
TensorRT patch version.
Definition: NvInferVersion.h:24
nvinfer1::Dims32
Definition: NvInferRuntimeCommon.h:153
int32_t nbDims
The rank (number of dimensions).
Definition: NvInferRuntimeCommon.h:158
static constexpr int32_t MAX_DIMS
The maximum rank (number of dimensions) supported for a tensor.
Definition: NvInferRuntimeCommon.h:156
int32_t d[MAX_DIMS]
The extent of each dimension.
Definition: NvInferRuntimeCommon.h:160
Reference counted application-implemented error reporting interface for TensorRT objects.
Definition: NvInferRuntimeCommon.h:1665
virtual ~IErrorRecorder() noexcept=default
char const * ErrorDesc
Definition: NvInferRuntimeCommon.h:1670
int32_t RefCount
Definition: NvInferRuntimeCommon.h:1680
Application-implemented class for controlling allocation on the GPU.
Definition: NvInferRuntimeCommon.h:1338
virtual bool deallocate(void *const memory) noexcept
Definition: NvInferRuntimeCommon.h:1448
virtual void * reallocate(void *, uint64_t, uint64_t) noexcept
Definition: NvInferRuntimeCommon.h:1423
virtual ~IGpuAllocator()=default
virtual void * allocate(uint64_t const size, uint64_t const alignment, AllocatorFlags const flags) noexcept=0
virtual TRT_DEPRECATED void free(void *const memory) noexcept=0
Application-implemented logging interface for the builder, refitter and runtime.
Definition: NvInferRuntimeCommon.h:1476
virtual ~ILogger()=default
Severity
Definition: NvInferRuntimeCommon.h:1484
virtual void log(Severity severity, AsciiChar const *msg) noexcept=0
Plugin creator class for user implemented layers.
Definition: NvInferRuntimeCommon.h:1074
virtual int32_t getTensorRTVersion() const noexcept
Return the version of the API the plugin creator was compiled with.
Definition: NvInferRuntimeCommon.h:1083
virtual AsciiChar const * getPluginName() const noexcept=0
Return the plugin name.
Single registration point for all plugins in an application. It is used to find plugin implementation...
Definition: NvInferRuntimeCommon.h:1210
virtual bool registerCreator(IPluginCreator &creator, AsciiChar const *const pluginNamespace) noexcept=0
Register a plugin creator. Returns false if one with same type is already registered.
virtual IPluginCreator * getPluginCreator(AsciiChar const *const pluginName, AsciiChar const *const pluginVersion, AsciiChar const *const pluginNamespace="") noexcept=0
Return plugin creator based on plugin name, version, and namespace associated with plugin during netw...
virtual IPluginCreator *const * getPluginCreatorList(int32_t *const numCreators) const noexcept=0
Return all the registered plugin creators and the number of registered plugin creators....
Plugin class for user-implemented layers.
Definition: NvInferRuntimeCommon.h:680
~IPluginV2Ext() override=default
virtual bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept=0
Return true if plugin can use input that is broadcast across batch without replication.
void configureWithFormat(Dims const *, int32_t, Dims const *, int32_t, DataType, PluginFormat, int32_t) noexcept override
Derived classes should not implement this. In a C++11 API it would be override final.
Definition: NvInferRuntimeCommon.h:867
virtual bool isOutputBroadcastAcrossBatch(int32_t outputIndex, bool const *inputIsBroadcasted, int32_t nbInputs) const noexcept=0
Return true if output tensor is broadcast across a batch.
IPluginV2Ext * clone() const noexcept override=0
Clone the plugin object. This copies over internal plugin parameters as well and returns a new plugin...
virtual void configurePlugin(Dims const *inputDims, int32_t nbInputs, Dims const *outputDims, int32_t nbOutputs, DataType const *inputTypes, DataType const *outputTypes, bool const *inputIsBroadcast, bool const *outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept=0
Configure the layer with input and output data types.
virtual void detachFromContext() noexcept
Detach the plugin object from its execution context.
Definition: NvInferRuntimeCommon.h:823
virtual void attachToContext(cudnnContext *, cublasContext *, IGpuAllocator *) noexcept
Attach the plugin object to an execution context and grant the plugin the access to some context reso...
Definition: NvInferRuntimeCommon.h:805
virtual nvinfer1::DataType getOutputDataType(int32_t index, nvinfer1::DataType const *inputTypes, int32_t nbInputs) const noexcept=0
Return the DataType of the plugin output at the requested index.
Plugin class for user-implemented layers.
Definition: NvInferRuntimeCommon.h:373
virtual AsciiChar const * getPluginType() const noexcept=0
Return the plugin type. Should match the plugin name returned by the corresponding plugin creator.
virtual void terminate() noexcept=0
Release resources acquired during plugin layer initialization. This is called when the engine is dest...
virtual int32_t getTensorRTVersion() const noexcept
Return the API version with which this plugin was built.
Definition: NvInferRuntimeCommon.h:385
virtual void setPluginNamespace(AsciiChar const *pluginNamespace) noexcept=0
Set the namespace that this plugin object belongs to. Ideally, all plugin objects from the same plugi...
virtual void configureWithFormat(Dims const *inputDims, int32_t nbInputs, Dims const *outputDims, int32_t nbOutputs, DataType type, PluginFormat format, int32_t maxBatchSize) noexcept=0
Configure the layer.
virtual void serialize(void *buffer) const noexcept=0
Serialize the layer.
virtual void destroy() noexcept=0
Destroy the plugin object. This will be called when the network, builder or engine is destroyed.
virtual AsciiChar const * getPluginVersion() const noexcept=0
Return the plugin version. Should match the plugin version returned by the corresponding plugin creat...
virtual Dims getOutputDimensions(int32_t index, Dims const *inputs, int32_t nbInputDims) noexcept=0
Get the dimension of an output tensor.
virtual IPluginV2 * clone() const noexcept=0
Clone the plugin object. This copies over internal plugin parameters and returns a new plugin object ...
virtual int32_t enqueue(int32_t batchSize, void const *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept=0
Execute the layer.
virtual size_t getSerializationSize() const noexcept=0
Find the size of the serialization buffer required.
virtual size_t getWorkspaceSize(int32_t maxBatchSize) const noexcept=0
Find the workspace size required by the layer.
virtual int32_t getNbOutputs() const noexcept=0
Get the number of outputs from the layer.
virtual AsciiChar const * getPluginNamespace() const noexcept=0
Return the namespace of the plugin object.
virtual int32_t initialize() noexcept=0
Initialize the layer for execution. This is called when the engine is created.
virtual bool supportsFormat(DataType type, PluginFormat format) const noexcept=0
Check format support.
Plugin class for user-implemented layers.
Definition: NvInferRuntimeCommon.h:883
int32_t getTensorRTVersion() const noexcept override
Return the API version with which this plugin was built. The upper byte is reserved by TensorRT and i...
Definition: NvInferRuntimeCommon.h:971
virtual void configurePlugin(PluginTensorDesc const *in, int32_t nbInput, PluginTensorDesc const *out, int32_t nbOutput) noexcept=0
Configure the layer.
virtual bool supportsFormatCombination(int32_t pos, PluginTensorDesc const *inOut, int32_t nbInputs, int32_t nbOutputs) const noexcept=0
Return true if plugin supports the format and datatype for the input/output indexed by pos.
Structure containing plugin attribute field names and associated data This information can be parsed ...
Definition: NvInferRuntimeCommon.h:1026
AsciiChar const * name
Plugin field attribute name.
Definition: NvInferRuntimeCommon.h:1031
PluginField(AsciiChar const *const name_=nullptr, void const *const data_=nullptr, PluginFieldType const type_=PluginFieldType::kUNKNOWN, int32_t const length_=0) noexcept
Definition: NvInferRuntimeCommon.h:1046
void const * data
Plugin field attribute data.
Definition: NvInferRuntimeCommon.h:1035
int32_t length
Number of data entries in the Plugin attribute.
Definition: NvInferRuntimeCommon.h:1044
PluginFieldType type
Plugin field attribute type.
Definition: NvInferRuntimeCommon.h:1040
The TensorRT API version 1 namespace.
ErrorCode
Error codes that can be returned by TensorRT during execution.
Definition: NvInferRuntimeCommon.h:1540
PluginFieldType
Definition: NvInferRuntimeCommon.h:997
@ kUNKNOWN
Unknown field type.
@ kFLOAT32
FP32 field type.
@ kCHAR
char field type.
@ kINT16
INT16 field type.
@ kDIMS
nvinfer1::Dims field type.
@ kFLOAT64
FP64 field type.
@ kFLOAT16
FP16 field type.
char_t AsciiChar
AsciiChar is the type used by TensorRT to represent valid ASCII characters.
Definition: NvInferRuntimeCommon.h:88
char char_t
char_t is the type used by TensorRT to represent all valid characters.
Definition: NvInferRuntimeCommon.h:86
@ kV2_DYNAMICEXT
IPluginV2DynamicExt.
@ kV2_IOEXT
IPluginV2IOExt.
@ kV2_EXT
IPluginV2Ext.
DataType
The type of weights and tensors.
Definition: NvInferRuntimeCommon.h:114
@ kFLOAT
32-bit floating point format.
@ kBOOL
8-bit boolean. 0 = false, 1 = true, other values undefined.
@ kHALF
IEEE 16-bit floating-point format.
@ kINT8
8-bit integer representing a quantized floating-point value.
@ kINT32
Signed 32-bit integer format.
TensorFormat PluginFormat
PluginFormat is reserved for backward compatibility.
Definition: NvInferRuntimeCommon.h:308
@ kINT8
Enable INT8 layer selection with FP32 fallback, plus FP16 fallback if kFP16 is also specified.
TensorFormat
Format of the input/output tensors.
Definition: NvInferRuntimeCommon.h:183
constexpr int32_t EnumMax() noexcept
Maximum number of elements in an enumeration type.
Definition: NvInferRuntimeCommon.h:104
AllocatorFlag
Definition: NvInferRuntimeCommon.h:1316
@ kRESIZABLE
TensorRT may call realloc() on this allocation.
uint32_t AllocatorFlags
Definition: NvInferRuntimeCommon.h:1330
Definition of plugin versions.
Plugin field collection struct.
Definition: NvInferRuntimeCommon.h:1058
PluginField const * fields
Pointer to PluginField entries.
Definition: NvInferRuntimeCommon.h:1062
int32_t nbFields
Number of PluginField entries.
Definition: NvInferRuntimeCommon.h:1060
Fields that a plugin might see for an input or output.
Definition: NvInferRuntimeCommon.h:332
DataType type
Definition: NvInferRuntimeCommon.h:336
Dims dims
Dimensions.
Definition: NvInferRuntimeCommon.h:334
TensorFormat format
Tensor format.
Definition: NvInferRuntimeCommon.h:338
float scale
Scale for INT8 data type.
Definition: NvInferRuntimeCommon.h:340
Declaration of EnumMaxImpl struct to store maximum number of elements in an enumeration type.
Definition: NvInferRuntimeCommon.h:99

  Copyright © 2024 NVIDIA Corporation
  Privacy Policy | Manage My Privacy | Do Not Sell or Share My Data | Terms of Service | Accessibility | Corporate Policies | Product Security | Contact