TensorRT 8.2.1
NvInferRuntimeCommon.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 1993-2021 NVIDIA Corporation. All rights reserved.
3 *
4 * NOTICE TO LICENSEE:
5 *
6 * This source code and/or documentation ("Licensed Deliverables") are
7 * subject to NVIDIA intellectual property rights under U.S. and
8 * international Copyright laws.
9 *
10 * These Licensed Deliverables contained herein is PROPRIETARY and
11 * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12 * conditions of a form of NVIDIA software license agreement by and
13 * between NVIDIA and Licensee ("License Agreement") or electronically
14 * accepted by Licensee. Notwithstanding any terms or conditions to
15 * the contrary in the License Agreement, reproduction or disclosure
16 * of the Licensed Deliverables to any third party without the express
17 * written consent of NVIDIA is prohibited.
18 *
19 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20 * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21 * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22 * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23 * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24 * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25 * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27 * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28 * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29 * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30 * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31 * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32 * OF THESE LICENSED DELIVERABLES.
33 *
34 * U.S. Government End Users. These Licensed Deliverables are a
35 * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36 * 1995), consisting of "commercial computer software" and "commercial
37 * computer software documentation" as such terms are used in 48
38 * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39 * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40 * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41 * U.S. Government End Users acquire the Licensed Deliverables with
42 * only those rights set forth herein.
43 *
44 * Any use of the Licensed Deliverables in individual and commercial
45 * software must include, in the user documentation and internal
46 * comments to the code, the above Disclaimer and U.S. Government End
47 * Users Notice.
48 */
49
50#ifndef NV_INFER_RUNTIME_COMMON_H
51#define NV_INFER_RUNTIME_COMMON_H
52
53#include "NvInferVersion.h"
54#include <cstddef>
55#include <cstdint>
56#include <cuda_runtime_api.h>
57
//! Portability macros for marking deprecated API items:
//!   TRT_DEPRECATED      - deprecation attribute for functions/classes/members.
//!   TRT_DEPRECATED_ENUM - deprecation attribute for enumerators (may expand to nothing).
//!   TRT_DEPRECATED_API  - deprecation attribute combined with symbol export/visibility.
59#if __cplusplus >= 201402L
// C++14 or newer: use the standard [[deprecated]] attribute.
60#define TRT_DEPRECATED [[deprecated]]
// Enumerator deprecation is disabled when __GNUC__ < 6. An undefined __GNUC__ evaluates to 0 inside #if,
// so compilers that do not define it also take the empty branch.
// NOTE(review): clang commonly reports __GNUC__ == 4, so clang users likely get no enumerator
// deprecation warnings either — confirm this is intended.
61#if __GNUC__ < 6
62#define TRT_DEPRECATED_ENUM
63#else
64#define TRT_DEPRECATED_ENUM TRT_DEPRECATED
65#endif
66#ifdef _MSC_VER
// MSVC: only dllexport is applied; no deprecation diagnostic is attached in this branch.
67#define TRT_DEPRECATED_API __declspec(dllexport)
68#else
69#define TRT_DEPRECATED_API [[deprecated]] __attribute__((visibility("default")))
70#endif
71#else
// Pre-C++14 branch.
// NOTE(review): MSVC reports __cplusplus as 199711L unless /Zc:__cplusplus is passed, so MSVC builds
// presumably land here and get an empty TRT_DEPRECATED — confirm.
72#ifdef _MSC_VER
73#define TRT_DEPRECATED
74#define TRT_DEPRECATED_ENUM
75#define TRT_DEPRECATED_API __declspec(dllexport)
76#else
// GNU-compatible compilers: use the __attribute__ spelling, which does not require C++14.
// NOTE(review): unlike TENSORRTAPI, TRT_DEPRECATED_API applies export/default visibility unconditionally
// (it is not gated on TENSORRT_BUILD_LIB) — confirm intended.
77#define TRT_DEPRECATED __attribute__((deprecated))
78#define TRT_DEPRECATED_ENUM
79#define TRT_DEPRECATED_API __attribute__((deprecated, visibility("default")))
80#endif
81#endif
82
//! TENSORRTAPI marks symbols exported from the TensorRT library. It expands to an export/visibility
//! attribute only while the library itself is built (TENSORRT_BUILD_LIB defined); for consumers it is empty.
84#ifdef TENSORRT_BUILD_LIB
85#ifdef _MSC_VER
86#define TENSORRTAPI __declspec(dllexport)
87#else
88#define TENSORRTAPI __attribute__((visibility("default")))
89#endif
90#else
// NOTE(review): consumers on MSVC get no __declspec(dllimport) here — presumably relying on
// import-library linkage; confirm.
91#define TENSORRTAPI
92#endif
// TRTNOEXCEPT expands to nothing; presumably retained so code written against older headers that used
// it still compiles — confirm.
93#define TRTNOEXCEPT
99
100// forward declare some CUDA types to avoid an include dependency
101
// Opaque cuBLAS and cuDNN handle types, declared with C linkage. This header only passes pointers to them
// (IPluginV2Ext::attachToContext), so their full definitions are not required here.
102extern "C"
103{
105 struct cublasContext;
107 struct cudnnContext;
108}
109
//! Public spelling of the header version constant nvinfer1::kNV_TENSORRT_VERSION_IMPL.
//! Returned by the default getTensorRTVersion() of IPluginV2 and IPluginCreator, and packed into the
//! lower 24 bits of the IPluginV2Ext / IPluginV2IOExt overrides.
110#define NV_TENSORRT_VERSION nvinfer1::kNV_TENSORRT_VERSION_IMPL
116namespace nvinfer1
117{
118
119static constexpr int32_t kNV_TENSORRT_VERSION_IMPL
120 = (NV_TENSORRT_MAJOR * 1000) + (NV_TENSORRT_MINOR * 100) + NV_TENSORRT_PATCH; // major, minor, patch
121
123using char_t = char;
126
128class IErrorRecorder;
130class IGpuAllocator;
131
132namespace impl
133{
135template <typename T>
137} // namespace impl
138
140template <typename T>
141constexpr int32_t EnumMax() noexcept
142{
144}
145
//! \enum DataType
//! \brief Element type of weights and tensors.
//!
//! Enumerator values are fixed and must not be renumbered.
enum class DataType : int32_t
{
    kFLOAT = 0, //!< 32-bit floating point format.
    kHALF = 1,  //!< IEEE 16-bit floating-point format.
    kINT8 = 2,  //!< 8-bit integer representing a quantized floating-point value.
    kINT32 = 3, //!< Signed 32-bit integer format.
    kBOOL = 4   //!< 8-bit boolean: 0 = false, 1 = true, other values undefined.
};
167
168namespace impl
169{
171template <>
173{
174 // Declaration of kVALUE that represents maximum number of elements in DataType enum
175 static constexpr int32_t kVALUE = 5;
176};
177} // namespace impl
178
190{
191public:
193 static constexpr int32_t MAX_DIMS{8};
195 int32_t nbDims;
197 int32_t d[MAX_DIMS];
198};
199
//! Dims is the dimension type used throughout the API. It is an alias of Dims32, which holds a rank
//! (nbDims) plus up to MAX_DIMS (8) extents in d[].
205using Dims = Dims32;
206
//! \enum TensorFormat
//! \brief Format (memory layout) of input/output tensors.
//!
//! Enumerator values are fixed and must not be renumbered. The enumerator count (12) is recorded in
//! the corresponding impl::EnumMaxImpl specialization and must be kept in sync with this list.
//! PluginFormat is an alias of this type, reserved for backward compatibility.
enum class TensorFormat : int32_t
{
    kLINEAR = 0,
    kCHW2 = 1,
    kHWC8 = 2,
    kCHW4 = 3,
    kCHW16 = 4,
    kCHW32 = 5,
    kDHWC8 = 6,
    kCDHW32 = 7,
    kHWC = 8,
    kDLA_LINEAR = 9,
    kDLA_HWC4 = 10,
    kHWC16 = 11
};
339
346
347namespace impl
348{
350template <>
352{
354 // coverity[autosar_cpp14_m0_1_4_violation] Approved RFD: https://jirasw.nvidia.com/browse/TID-489
355 static constexpr int32_t kVALUE = 12;
356};
357} // namespace impl
358
370{
378 float scale;
379};
380
//! \enum PluginVersion
//! \brief Identifies which plugin interface a plugin implements.
//!
//! IPluginV2Ext and IPluginV2IOExt shift this value into the upper byte of their getTensorRTVersion()
//! result, which fits the 8-bit underlying type.
enum class PluginVersion : uint8_t
{
    kV2 = 0,           //!< IPluginV2.
    kV2_EXT = 1,       //!< IPluginV2Ext.
    kV2_IOEXT = 2,     //!< IPluginV2IOExt.
    kV2_DYNAMICEXT = 3 //!< IPluginV2DynamicExt.
};
398
411{
412public:
423 virtual int32_t getTensorRTVersion() const noexcept
424 {
425 return NV_TENSORRT_VERSION;
426 }
427
440 virtual AsciiChar const* getPluginType() const noexcept = 0;
441
454 virtual AsciiChar const* getPluginVersion() const noexcept = 0;
455
469 virtual int32_t getNbOutputs() const noexcept = 0;
470
486 virtual Dims getOutputDimensions(int32_t index, Dims const* inputs, int32_t nbInputDims) noexcept = 0;
487
510 virtual bool supportsFormat(DataType type, PluginFormat format) const noexcept = 0;
511
543 virtual void configureWithFormat(Dims const* inputDims, int32_t nbInputs, Dims const* outputDims, int32_t nbOutputs,
544 DataType type, PluginFormat format, int32_t maxBatchSize) noexcept
545 = 0;
546
558 virtual int32_t initialize() noexcept = 0;
559
572 virtual void terminate() noexcept = 0;
573
588 virtual size_t getWorkspaceSize(int32_t maxBatchSize) const noexcept = 0;
589
606 virtual int32_t enqueue(int32_t batchSize, void const* const* inputs, void* const* outputs, void* workspace,
607 cudaStream_t stream) noexcept
608 = 0;
609
620 virtual size_t getSerializationSize() const noexcept = 0;
621
635 virtual void serialize(void* buffer) const noexcept = 0;
636
645 virtual void destroy() noexcept = 0;
646
661 virtual IPluginV2* clone() const noexcept = 0;
662
677 virtual void setPluginNamespace(AsciiChar const* pluginNamespace) noexcept = 0;
678
687 virtual AsciiChar const* getPluginNamespace() const noexcept = 0;
688
689 // @cond SuppressDoxyWarnings
690 IPluginV2() = default;
691 virtual ~IPluginV2() noexcept = default;
692// @endcond
693
694protected:
695// @cond SuppressDoxyWarnings
696 IPluginV2(IPluginV2 const&) = default;
697 IPluginV2(IPluginV2&&) = default;
698 IPluginV2& operator=(IPluginV2 const&) & = default;
699 IPluginV2& operator=(IPluginV2&&) & = default;
700// @endcond
701};
702
714{
715public:
732 int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept
733 = 0;
734
751 int32_t outputIndex, bool const* inputIsBroadcasted, int32_t nbInputs) const noexcept
752 = 0;
753
772 virtual bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept = 0;
773
808 virtual void configurePlugin(Dims const* inputDims, int32_t nbInputs, Dims const* outputDims, int32_t nbOutputs,
809 DataType const* inputTypes, DataType const* outputTypes, bool const* inputIsBroadcast,
810 bool const* outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept
811 = 0;
812
813 IPluginV2Ext() = default;
814 ~IPluginV2Ext() override = default;
815
839 virtual void attachToContext(
840 cudnnContext* /*cudnn*/, cublasContext* /*cublas*/, IGpuAllocator* /*allocator*/) noexcept
841 {
842 }
843
857 virtual void detachFromContext() noexcept {}
858
871 IPluginV2Ext* clone() const noexcept override = 0;
872
873protected:
874 // @cond SuppressDoxyWarnings
875 IPluginV2Ext(IPluginV2Ext const&) = default;
876 IPluginV2Ext(IPluginV2Ext&&) = default;
877 IPluginV2Ext& operator=(IPluginV2Ext const&) & = default;
878 IPluginV2Ext& operator=(IPluginV2Ext&&) & = default;
879// @endcond
880
892 int32_t getTensorRTVersion() const noexcept override
893 {
894 return static_cast<int32_t>((static_cast<uint32_t>(PluginVersion::kV2_EXT) << 24U)
895 | (static_cast<uint32_t>(NV_TENSORRT_VERSION) & 0xFFFFFFU));
896 }
897
901 void configureWithFormat(Dims const* /*inputDims*/, int32_t /*nbInputs*/, Dims const* /*outputDims*/,
902 int32_t /*nbOutputs*/, DataType /*type*/, PluginFormat /*format*/, int32_t /*maxBatchSize*/) noexcept override
903 {
904 }
905};
906
917{
918public:
936 virtual void configurePlugin(
937 PluginTensorDesc const* in, int32_t nbInput, PluginTensorDesc const* out, int32_t nbOutput) noexcept
938 = 0;
939
978 int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) const noexcept
979 = 0;
980
981 // @cond SuppressDoxyWarnings
982 IPluginV2IOExt() = default;
983 ~IPluginV2IOExt() override = default;
984// @endcond
985
986protected:
987// @cond SuppressDoxyWarnings
988 IPluginV2IOExt(IPluginV2IOExt const&) = default;
989 IPluginV2IOExt(IPluginV2IOExt&&) = default;
990 IPluginV2IOExt& operator=(IPluginV2IOExt const&) & = default;
991 IPluginV2IOExt& operator=(IPluginV2IOExt&&) & = default;
992// @endcond
993
1005 int32_t getTensorRTVersion() const noexcept override
1006 {
1007 return static_cast<int32_t>((static_cast<uint32_t>(PluginVersion::kV2_IOEXT) << 24U)
1008 | (static_cast<uint32_t>(NV_TENSORRT_VERSION) & 0xFFFFFFU));
1009 }
1010
1011private:
1012 // Following are obsolete base class methods, and must not be implemented or used.
1013
1014 void configurePlugin(Dims const*, int32_t, Dims const*, int32_t, DataType const*, DataType const*, bool const*,
1015 bool const*, PluginFormat, int32_t) noexcept final
1016 {
1017 }
1018
1019 bool supportsFormat(DataType, PluginFormat) const noexcept final
1020 {
1021 return false;
1022 }
1023};
1024
1029
//! \enum PluginFieldType
//! \brief Type of the data carried by a PluginField (its `type` member).
enum class PluginFieldType : int32_t
{
    kFLOAT16 = 0, //!< FP16 field type.
    kFLOAT32 = 1, //!< FP32 field type.
    kFLOAT64 = 2, //!< FP64 field type.
    kINT8 = 3,    //!< INT8 field type.
    kINT16 = 4,   //!< INT16 field type.
    kINT32 = 5,   //!< INT32 field type.
    kCHAR = 6,    //!< char field type.
    kDIMS = 7,    //!< nvinfer1::Dims field type.
    kUNKNOWN = 8  //!< Unknown field type.
};
1051
1060{
1061public:
1069 void const* data;
1078 int32_t length;
1079
1080 PluginField(AsciiChar const* const name_ = nullptr, void const* const data_ = nullptr,
1081 PluginFieldType const type_ = PluginFieldType::kUNKNOWN, int32_t const length_ = 0) noexcept
1082 : name(name_)
1083 , data(data_)
1084 , type(type_)
1085 , length(length_)
1086 {
1087 }
1088};
1089
1092{
1094 int32_t nbFields;
1097};
1098
1106
1108{
1109public:
1117 virtual int32_t getTensorRTVersion() const noexcept
1118 {
1119 return NV_TENSORRT_VERSION;
1120 }
1121
1134 virtual AsciiChar const* getPluginName() const noexcept = 0;
1135
1148 virtual AsciiChar const* getPluginVersion() const noexcept = 0;
1149
1160 virtual PluginFieldCollection const* getFieldNames() noexcept = 0;
1161
1171 virtual IPluginV2* createPlugin(AsciiChar const* name, PluginFieldCollection const* fc) noexcept = 0;
1172
1182 virtual IPluginV2* deserializePlugin(AsciiChar const* name, void const* serialData, size_t serialLength) noexcept
1183 = 0;
1184
1197 virtual void setPluginNamespace(AsciiChar const* pluginNamespace) noexcept = 0;
1198
1211 virtual AsciiChar const* getPluginNamespace() const noexcept = 0;
1212
1213 IPluginCreator() = default;
1214 virtual ~IPluginCreator() = default;
1215
1216protected:
1217// @cond SuppressDoxyWarnings
1218 IPluginCreator(IPluginCreator const&) = default;
1219 IPluginCreator(IPluginCreator&&) = default;
1220 IPluginCreator& operator=(IPluginCreator const&) & = default;
1221 IPluginCreator& operator=(IPluginCreator&&) & = default;
1222// @endcond
1223};
1224
1242
1244{
1245public:
1257 virtual bool registerCreator(IPluginCreator& creator, AsciiChar const* const pluginNamespace) noexcept = 0;
1258
1267 virtual IPluginCreator* const* getPluginCreatorList(int32_t* const numCreators) const noexcept = 0;
1268
1280 virtual IPluginCreator* getPluginCreator(AsciiChar const* const pluginName, AsciiChar const* const pluginVersion,
1281 AsciiChar const* const pluginNamespace = "") noexcept
1282 = 0;
1283
1284 // @cond SuppressDoxyWarnings
1285 IPluginRegistry() = default;
1286 IPluginRegistry(IPluginRegistry const&) = delete;
1287 IPluginRegistry(IPluginRegistry&&) = delete;
1288 IPluginRegistry& operator=(IPluginRegistry const&) & = delete;
1289 IPluginRegistry& operator=(IPluginRegistry&&) & = delete;
1290// @endcond
1291
1292protected:
1293 virtual ~IPluginRegistry() noexcept = default;
1294
1295public:
1305 //
1312 virtual void setErrorRecorder(IErrorRecorder* const recorder) noexcept = 0;
1313
1329 virtual IErrorRecorder* getErrorRecorder() const noexcept = 0;
1330
1346 virtual bool deregisterCreator(IPluginCreator const& creator) noexcept = 0;
1347};
1348
//! \enum AllocatorFlag
//! \brief Allocation properties. Presumably combined into the AllocatorFlags value passed to
//! IGpuAllocator::allocate() — confirm the bit encoding.
enum class AllocatorFlag : int32_t
{
    kRESIZABLE = 0 //!< TensorRT may call realloc() on this allocation.
};
1353
1354namespace impl
1355{
1357template <>
1359{
1360 static constexpr int32_t kVALUE = 1;
1361};
1362} // namespace impl
1363
//! Type of the `flags` argument of IGpuAllocator::allocate().
//! NOTE(review): presumably a bitmask of AllocatorFlag bits — confirm the bit encoding.
1364using AllocatorFlags = uint32_t;
1365
1372{
1373public:
1395 virtual void* allocate(uint64_t const size, uint64_t const alignment, AllocatorFlags const flags) noexcept = 0;
1396
1415 TRT_DEPRECATED virtual void free(void* const memory) noexcept = 0;
1416
1421 virtual ~IGpuAllocator() = default;
1422 IGpuAllocator() = default;
1423
1457 virtual void* reallocate(void* /*baseAddr*/, uint64_t /*alignment*/, uint64_t /*newSize*/) noexcept
1458 {
1459 return nullptr;
1460 }
1461
1482 virtual bool deallocate(void* const memory) noexcept
1483 {
1484 this->free(memory);
1485 return true;
1486 }
1487
1488protected:
1489// @cond SuppressDoxyWarnings
1490 IGpuAllocator(IGpuAllocator const&) = default;
1491 IGpuAllocator(IGpuAllocator&&) = default;
1492 IGpuAllocator& operator=(IGpuAllocator const&) & = default;
1493 IGpuAllocator& operator=(IGpuAllocator&&) & = default;
1494// @endcond
1495};
1496
1506{
1507public:
1513 enum class Severity : int32_t
1514 {
1516 kINTERNAL_ERROR = 0,
1518 kERROR = 1,
1520 kWARNING = 2,
1522 kINFO = 3,
1524 kVERBOSE = 4,
1525 };
1526
1537 virtual void log(Severity severity, AsciiChar const* msg) noexcept = 0;
1538
1539 ILogger() = default;
1540 virtual ~ILogger() = default;
1541
1542protected:
1543// @cond SuppressDoxyWarnings
1544 ILogger(ILogger const&) = default;
1545 ILogger(ILogger&&) = default;
1546 ILogger& operator=(ILogger const&) & = default;
1547 ILogger& operator=(ILogger&&) & = default;
1548// @endcond
1549};
1550
1551namespace impl
1552{
1554template <>
1555struct EnumMaxImpl<ILogger::Severity>
1556{
1558 static constexpr int32_t kVALUE = 5;
1559};
1560} // namespace impl
1561
1567enum class ErrorCode : int32_t
1568{
1572 kSUCCESS = 0,
1573
1578
1583 kINTERNAL_ERROR = 2,
1584
1590
1598 kINVALID_CONFIG = 4,
1599
1606
1612
1620
1629
1642 kINVALID_STATE = 9,
1643
1654 kUNSUPPORTED_STATE = 10,
1655
1656};
1657
1658namespace impl
1659{
1661template <>
1663{
1665 static constexpr int32_t kVALUE = 11;
1666};
1667} // namespace impl
1668
1693{
1694public:
1698 using ErrorDesc = char const*;
1699
1703 // coverity[autosar_cpp14_m0_1_4_violation] Approved RFD: https://jirasw.nvidia.com/browse/TID-489
1704 static constexpr size_t kMAX_DESC_LENGTH{127U};
1705
1709 using RefCount = int32_t;
1710
1711 IErrorRecorder() = default;
1712 virtual ~IErrorRecorder() noexcept = default;
1713
1714 // Public API used to retrieve information from the error recorder.
1715
1734 virtual int32_t getNbErrors() const noexcept = 0;
1735
1753 virtual ErrorCode getErrorCode(int32_t errorIdx) const noexcept = 0;
1754
1774 virtual ErrorDesc getErrorDesc(int32_t errorIdx) const noexcept = 0;
1775
1790 virtual bool hasOverflowed() const noexcept = 0;
1791
1806 virtual void clear() noexcept = 0;
1807
1808 // API used by TensorRT to report Error information to the application.
1809
1830 virtual bool reportError(ErrorCode val, ErrorDesc desc) noexcept = 0;
1831
1848 virtual RefCount incRefCount() noexcept = 0;
1849
1866 virtual RefCount decRefCount() noexcept = 0;
1867
1868protected:
1869 // @cond SuppressDoxyWarnings
1870 IErrorRecorder(IErrorRecorder const&) = default;
1871 IErrorRecorder(IErrorRecorder&&) = default;
1872 IErrorRecorder& operator=(IErrorRecorder const&) & = default;
1873 IErrorRecorder& operator=(IErrorRecorder&&) & = default;
1874 // @endcond
1875}; // class IErrorRecorder
1876} // namespace nvinfer1
1877
//! Return the version number of the TensorRT library, declared with C linkage.
//! NOTE(review): presumably uses the same encoding as NV_TENSORRT_VERSION, so callers can detect a
//! header/library mismatch by comparing the two — confirm.
1883extern "C" TENSORRTAPI int32_t getInferLibVersion() noexcept;
1884
1885#endif // NV_INFER_RUNTIME_COMMON_H
int32_t getInferLibVersion() noexcept
Return the library version number.
#define TRT_DEPRECATED
Items that are marked as deprecated will be removed in a future release.
Definition: NvInferRuntimeCommon.h:77
#define NV_TENSORRT_MINOR
TensorRT minor version.
Definition: NvInferVersion.h:60
#define NV_TENSORRT_MAJOR
TensorRT major version.
Definition: NvInferVersion.h:59
#define NV_TENSORRT_PATCH
TensorRT patch version.
Definition: NvInferVersion.h:61
Definition: NvInferRuntimeCommon.h:190
int32_t nbDims
The rank (number of dimensions).
Definition: NvInferRuntimeCommon.h:195
static constexpr int32_t MAX_DIMS
The maximum rank (number of dimensions) supported for a tensor.
Definition: NvInferRuntimeCommon.h:193
int32_t d[MAX_DIMS]
The extent of each dimension.
Definition: NvInferRuntimeCommon.h:197
Reference counted application-implemented error reporting interface for TensorRT objects.
Definition: NvInferRuntimeCommon.h:1693
char const * ErrorDesc
Definition: NvInferRuntimeCommon.h:1698
int32_t RefCount
Definition: NvInferRuntimeCommon.h:1709
Application-implemented class for controlling allocation on the GPU.
Definition: NvInferRuntimeCommon.h:1372
virtual bool deallocate(void *const memory) noexcept
Definition: NvInferRuntimeCommon.h:1482
virtual void * reallocate(void *, uint64_t, uint64_t) noexcept
Definition: NvInferRuntimeCommon.h:1457
virtual ~IGpuAllocator()=default
virtual void * allocate(uint64_t const size, uint64_t const alignment, AllocatorFlags const flags) noexcept=0
virtual TRT_DEPRECATED void free(void *const memory) noexcept=0
Application-implemented logging interface for the builder, refitter and runtime.
Definition: NvInferRuntimeCommon.h:1506
Severity
Definition: NvInferRuntimeCommon.h:1514
virtual void log(Severity severity, AsciiChar const *msg) noexcept=0
Plugin creator class for user implemented layers.
Definition: NvInferRuntimeCommon.h:1108
virtual int32_t getTensorRTVersion() const noexcept
Return the version of the API the plugin creator was compiled with.
Definition: NvInferRuntimeCommon.h:1117
virtual AsciiChar const * getPluginName() const noexcept=0
Return the plugin name.
Single registration point for all plugins in an application. It is used to find plugin implementation...
Definition: NvInferRuntimeCommon.h:1244
virtual bool registerCreator(IPluginCreator &creator, AsciiChar const *const pluginNamespace) noexcept=0
Register a plugin creator. Returns false if one with the same type is already registered.
virtual IPluginCreator * getPluginCreator(AsciiChar const *const pluginName, AsciiChar const *const pluginVersion, AsciiChar const *const pluginNamespace="") noexcept=0
Return plugin creator based on plugin name, version, and namespace associated with plugin during netw...
virtual IPluginCreator *const * getPluginCreatorList(int32_t *const numCreators) const noexcept=0
Return all the registered plugin creators and the number of registered plugin creators....
Plugin class for user-implemented layers.
Definition: NvInferRuntimeCommon.h:714
virtual bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept=0
Return true if plugin can use input that is broadcast across batch without replication.
void configureWithFormat(Dims const *, int32_t, Dims const *, int32_t, DataType, PluginFormat, int32_t) noexcept override
Derived classes should not implement this. In a C++11 API it would be override final.
Definition: NvInferRuntimeCommon.h:901
virtual bool isOutputBroadcastAcrossBatch(int32_t outputIndex, bool const *inputIsBroadcasted, int32_t nbInputs) const noexcept=0
Return true if output tensor is broadcast across a batch.
IPluginV2Ext * clone() const noexcept override=0
Clone the plugin object. This copies over internal plugin parameters as well and returns a new plugin...
virtual void configurePlugin(Dims const *inputDims, int32_t nbInputs, Dims const *outputDims, int32_t nbOutputs, DataType const *inputTypes, DataType const *outputTypes, bool const *inputIsBroadcast, bool const *outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept=0
Configure the layer with input and output data types.
virtual void detachFromContext() noexcept
Detach the plugin object from its execution context.
Definition: NvInferRuntimeCommon.h:857
virtual void attachToContext(cudnnContext *, cublasContext *, IGpuAllocator *) noexcept
Attach the plugin object to an execution context and grant the plugin the access to some context reso...
Definition: NvInferRuntimeCommon.h:839
virtual nvinfer1::DataType getOutputDataType(int32_t index, nvinfer1::DataType const *inputTypes, int32_t nbInputs) const noexcept=0
Return the DataType of the plugin output at the requested index.
Plugin class for user-implemented layers.
Definition: NvInferRuntimeCommon.h:411
virtual AsciiChar const * getPluginType() const noexcept=0
Return the plugin type. Should match the plugin name returned by the corresponding plugin creator.
virtual void terminate() noexcept=0
Release resources acquired during plugin layer initialization. This is called when the engine is dest...
virtual int32_t getTensorRTVersion() const noexcept
Return the API version with which this plugin was built.
Definition: NvInferRuntimeCommon.h:423
virtual void setPluginNamespace(AsciiChar const *pluginNamespace) noexcept=0
Set the namespace that this plugin object belongs to. Ideally, all plugin objects from the same plugi...
virtual void configureWithFormat(Dims const *inputDims, int32_t nbInputs, Dims const *outputDims, int32_t nbOutputs, DataType type, PluginFormat format, int32_t maxBatchSize) noexcept=0
Configure the layer.
virtual void serialize(void *buffer) const noexcept=0
Serialize the layer.
virtual void destroy() noexcept=0
Destroy the plugin object. This will be called when the network, builder or engine is destroyed.
virtual AsciiChar const * getPluginVersion() const noexcept=0
Return the plugin version. Should match the plugin version returned by the corresponding plugin creat...
virtual Dims getOutputDimensions(int32_t index, Dims const *inputs, int32_t nbInputDims) noexcept=0
Get the dimension of an output tensor.
virtual IPluginV2 * clone() const noexcept=0
Clone the plugin object. This copies over internal plugin parameters and returns a new plugin object ...
virtual int32_t enqueue(int32_t batchSize, void const *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept=0
Execute the layer.
virtual size_t getSerializationSize() const noexcept=0
Find the size of the serialization buffer required.
virtual size_t getWorkspaceSize(int32_t maxBatchSize) const noexcept=0
Find the workspace size required by the layer.
virtual int32_t getNbOutputs() const noexcept=0
Get the number of outputs from the layer.
virtual AsciiChar const * getPluginNamespace() const noexcept=0
Return the namespace of the plugin object.
virtual int32_t initialize() noexcept=0
Initialize the layer for execution. This is called when the engine is created.
virtual bool supportsFormat(DataType type, PluginFormat format) const noexcept=0
Check format support.
Plugin class for user-implemented layers.
Definition: NvInferRuntimeCommon.h:917
int32_t getTensorRTVersion() const noexcept override
Return the API version with which this plugin was built. The upper byte is reserved by TensorRT and i...
Definition: NvInferRuntimeCommon.h:1005
virtual void configurePlugin(PluginTensorDesc const *in, int32_t nbInput, PluginTensorDesc const *out, int32_t nbOutput) noexcept=0
Configure the layer.
virtual bool supportsFormatCombination(int32_t pos, PluginTensorDesc const *inOut, int32_t nbInputs, int32_t nbOutputs) const noexcept=0
Return true if plugin supports the format and datatype for the input/output indexed by pos.
Structure containing plugin attribute field names and associated data. This information can be parsed ...
Definition: NvInferRuntimeCommon.h:1060
AsciiChar const * name
Plugin field attribute name.
Definition: NvInferRuntimeCommon.h:1065
void const * data
Plugin field attribute data.
Definition: NvInferRuntimeCommon.h:1069
int32_t length
Number of data entries in the Plugin attribute.
Definition: NvInferRuntimeCommon.h:1078
PluginFieldType type
Plugin field attribute type.
Definition: NvInferRuntimeCommon.h:1074
The TensorRT API version 1 namespace.
ErrorCode
Error codes that can be returned by TensorRT during execution.
Definition: NvInferRuntimeCommon.h:1568
PluginFieldType
Definition: NvInferRuntimeCommon.h:1031
@ kUNKNOWN
Unknown field type.
@ kFLOAT32
FP32 field type.
@ kCHAR
char field type.
@ kINT16
INT16 field type.
@ kDIMS
nvinfer1::Dims field type.
@ kFLOAT64
FP64 field type.
@ kFLOAT16
FP16 field type.
char_t AsciiChar
AsciiChar is the type used by TensorRT to represent valid ASCII characters.
Definition: NvInferRuntimeCommon.h:125
char char_t
char_t is the type used by TensorRT to represent all valid characters.
Definition: NvInferRuntimeCommon.h:123
@ kV2_DYNAMICEXT
IPluginV2DynamicExt.
@ kV2_IOEXT
IPluginV2IOExt.
@ kV2_EXT
IPluginV2Ext.
DataType
The type of weights and tensors.
Definition: NvInferRuntimeCommon.h:151
@ kFLOAT
32-bit floating point format.
@ kBOOL
8-bit boolean. 0 = false, 1 = true, other values undefined.
@ kHALF
IEEE 16-bit floating-point format.
@ kINT8
8-bit integer representing a quantized floating-point value.
@ kINT32
Signed 32-bit integer format.
TensorFormat PluginFormat
PluginFormat is reserved for backward compatibility.
Definition: NvInferRuntimeCommon.h:345
@ kINT8
Enable INT8 layer selection, with FP32 fallback (or FP16 fallback if kFP16 is also specified).
TensorFormat
Format of the input/output tensors.
Definition: NvInferRuntimeCommon.h:221
constexpr int32_t EnumMax() noexcept
Maximum number of elements in an enumeration type.
Definition: NvInferRuntimeCommon.h:141
AllocatorFlag
Definition: NvInferRuntimeCommon.h:1350
@ kRESIZABLE
TensorRT may call realloc() on this allocation.
Definition of plugin versions.
Plugin field collection struct.
Definition: NvInferRuntimeCommon.h:1092
PluginField const * fields
Pointer to PluginField entries.
Definition: NvInferRuntimeCommon.h:1096
int32_t nbFields
Number of PluginField entries.
Definition: NvInferRuntimeCommon.h:1094
Fields that a plugin might see for an input or output.
Definition: NvInferRuntimeCommon.h:370
DataType type
Definition: NvInferRuntimeCommon.h:374
Dims dims
Dimensions.
Definition: NvInferRuntimeCommon.h:372
TensorFormat format
Tensor format.
Definition: NvInferRuntimeCommon.h:376
float scale
Scale for INT8 data type.
Definition: NvInferRuntimeCommon.h:378
Declaration of EnumMaxImpl struct to store maximum number of elements in an enumeration type.
Definition: NvInferRuntimeCommon.h:136