Adding Custom Layers Using the C++ API
There are four steps to ensure that TensorRT properly recognizes your plugin:
1. Implement a plugin class by deriving from one of TensorRT’s plugin base classes. Currently, the only recommended one is IPluginV3.
2. Implement a plugin creator class tied to your plugin class by deriving from one of TensorRT’s plugin creator base classes. Currently, the only recommended one is IPluginCreatorV3One.
3. Register an instance of the plugin creator class with TensorRT’s plugin registry.
4. Add an instance of the plugin class to a TensorRT network, either directly through TensorRT’s network APIs or by loading an ONNX model with the TensorRT ONNX parser APIs.
The following sections explore each of these steps in detail.
Implementing a Plugin Class
You can implement a custom layer by deriving from one of TensorRT’s plugin base classes. Starting in TensorRT 10.0, IPluginV3 is the only recommended plugin interface; the others are deprecated. Therefore, this section mostly describes plugin implementation using IPluginV3. Refer to the Migrating V2 Plugins to IPluginV3 section for how to migrate plugins that implement the V2 plugin interfaces to IPluginV3.
IPluginV3 is a wrapper for a set of capability interfaces that define three capabilities: core, build, and runtime.
Core capability: Refers to plugin attributes and behaviors common to both the build and runtime phases of a plugin’s lifetime.
Build capability: Refers to plugin attributes and behaviors that the plugin must exhibit for the TensorRT builder.
Runtime capability: Refers to plugin attributes and behaviors that the plugin must exhibit for it to be executable, either during auto-tuning in the TensorRT build phase or inference in the TensorRT runtime phase.
IPluginV3OneCore (C++, Python), IPluginV3OneBuild (C++, Python), and IPluginV3OneRuntime (C++, Python) are the base classes that an IPluginV3 plugin must implement to expose the core, build, and runtime capabilities, respectively. If I/O aliasing is required, IPluginV3OneBuildV2 (C++, Python) can be used as the build capability; it contains a superset of the functionality in IPluginV3OneBuild.
Implementing a Plugin Creator Class
To use a plugin in a network, you must first register it with TensorRT’s PluginRegistry (C++, Python). Rather than registering the plugin directly, you register an instance of a factory class for the plugin, derived from a child class of IPluginCreatorInterface (C++, Python). The plugin creator class also provides other information about the plugin: its name, version, and plugin field parameters.
IPluginCreatorV3One is the factory class for IPluginV3. Its createPlugin() method has the following signature in C++ and Python, respectively:
IPluginV3* createPlugin(AsciiChar const* name, PluginFieldCollection const* fc, TensorRTPhase phase)
create_plugin(self: trt.IPluginCreatorV3, name: str, field_collection: trt.PluginFieldCollection, phase: trt.TensorRTPhase) -> trt.IPluginV3
IPluginCreatorV3One::createPlugin() can be called to create a plugin instance in either the build phase of TensorRT or the runtime phase of TensorRT, which is communicated by the phase argument of type TensorRTPhase (C++, Python).
The returned IPluginV3 object must meet the following requirements:
- In both phases, it must have a valid core capability.
- In the build phase, it must have both a build and a runtime capability.
- In the runtime phase, it must have a runtime capability. A build capability is not required and is ignored.
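The phase rules above can be summarized as a small predicate. This is a minimal sketch in plain C++ using hypothetical stand-in types (Phase, Capabilities, meetsPhaseRequirements), not TensorRT's TensorRTPhase or its capability interfaces; it only encodes which capabilities the returned object needs in each phase.

```cpp
#include <cassert>

// Hypothetical stand-ins that mirror the rules above; NOT TensorRT types.
enum class Phase
{
    kBUILD,
    kRUNTIME
};

struct Capabilities
{
    bool core{false};
    bool build{false};
    bool runtime{false};
};

// Returns true if an object exposing `caps` satisfies the requirements for `phase`.
bool meetsPhaseRequirements(Capabilities const& caps, Phase phase)
{
    // Core and runtime capabilities are required in both phases.
    if (!caps.core || !caps.runtime)
    {
        return false;
    }
    // A build capability is required only in the build phase;
    // in the runtime phase it is optional and ignored.
    return phase == Phase::kRUNTIME || caps.build;
}
```

For example, an object with core and runtime capabilities but no build capability is acceptable at runtime but not during the build.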
Registering a Plugin Creator with the Plugin Registry
There are two ways that you can register plugin creators with the registry:
- Statically register by calling REGISTER_TENSORRT_PLUGIN. REGISTER_TENSORRT_PLUGIN always registers the creator under the default namespace (“”).
- Dynamically register by creating an entry point similar to initLibNvInferPlugins and calling registerCreator on the plugin registry. This is preferred over static registration because it allows plugins to be registered under a unique namespace, which avoids name collisions at build time between different plugin libraries.
During serialization, the TensorRT engine internally stores the plugin name, plugin version, and namespace (if it exists) for all plugins, along with any plugin fields in the PluginFieldCollection returned by IPluginV3OneRuntime::getFieldsToSerialize(). During deserialization, TensorRT looks up a plugin creator with the same plugin name, version, and namespace in the plugin registry and invokes IPluginCreatorV3One::createPlugin() on it, passing the serialized PluginFieldCollection back as the fc argument.
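To illustrate why a unique namespace prevents collisions, here is a toy registry keyed by (name, version, namespace), the same triple TensorRT matches on during deserialization. ToyRegistry and CreatorKey are hypothetical and are not the IPluginRegistry API; the mapped string merely stands in for a creator object.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <tuple>

// Hypothetical key mirroring the lookup triple; "" is the default namespace.
struct CreatorKey
{
    std::string name;
    std::string version;
    std::string ns;

    bool operator<(CreatorKey const& other) const
    {
        return std::tie(name, version, ns) < std::tie(other.name, other.version, other.ns);
    }
};

// Hypothetical toy registry; NOT the TensorRT plugin registry.
class ToyRegistry
{
public:
    // Returns false if a creator with the same (name, version, namespace) exists.
    bool registerCreator(CreatorKey const& key, std::string const& libraryId)
    {
        return mCreators.emplace(key, libraryId).second;
    }

    // Mirrors deserialization: only an exact (name, version, namespace) match is found.
    std::string const* find(CreatorKey const& key) const
    {
        auto it = mCreators.find(key);
        return it == mCreators.end() ? nullptr : &it->second;
    }

private:
    std::map<CreatorKey, std::string> mCreators;
};
```

Two libraries that both register PadPlugin version 1 in the default namespace collide, whereas a second library that uses its own namespace registers successfully and is found only through an exact match.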
Adding a Plugin Instance to a TensorRT Network
You can add a plugin to the TensorRT network using addPluginV3(), which creates a network layer with the given plugin.
For example, you can add a plugin layer to your network as follows:
// Look up the plugin creator in the registry and
// cast it to the appropriate child class of IPluginCreatorInterface
auto creator = static_cast<IPluginCreatorV3One*>(getPluginRegistry()->getCreator(pluginName, pluginVersion, pluginNamespace));
PluginFieldCollection const* pluginFC = creator->getFieldNames();
// Populate the field parameters for the plugin layer
// (parseAndFillFields is a user-provided helper, not a TensorRT API)
PluginFieldCollection* pluginData = parseAndFillFields(pluginFC, layerFields);
// Create the plugin object using the layerName and the plugin metadata for use by the TensorRT builder
IPluginV3* pluginObj = creator->createPlugin(layerName, pluginData, TensorRTPhase::kBUILD);
// Add the plugin to the TensorRT network; addPluginV3 takes the plugin by reference
auto layer = network->addPluginV3(inputs.data(), static_cast<int32_t>(inputs.size()), shapeInputs.data(), static_cast<int32_t>(shapeInputs.size()), *pluginObj);
// ... (build the rest of the network and serialize the engine)
// Delete the plugin object
delete pluginObj;
// ... (free the allocated pluginData)
The createPlugin method creates a new plugin object on the heap and returns a pointer to it. As shown in the example, make sure you delete pluginObj to avoid a memory leak.
When the engine is deleted, the engine destroys any clones of the plugin object created during the build. You are responsible for ensuring the plugin object you created is freed after it is added to the network.
Note
Serialize only the plugin parameters that the plugin needs to function correctly at runtime; build-time parameters can be omitted.
If you are an automotive safety user, you must call getSafePluginRegistry() instead of getPluginRegistry(). You must also use the macro REGISTER_SAFE_TENSORRT_PLUGIN instead of REGISTER_TENSORRT_PLUGIN. Refer to the NVIDIA TensorRT Safety Production Guide for DriveOS for any safety-related activities.
Example: Adding a Custom Layer with Dynamic Shapes Using C++
Imagine that a custom layer is needed for a padding-like operation where each image in an input batch must be reshaped to 32 x 32. The input tensor X would be of shape (B, C, H, W), and the output Y would be of shape (B, C, 32, 32). To accomplish this, a TensorRT plugin can be written using the IPluginV3 interface; let us call it PadPlugin.
Since an IPluginV3 plugin must possess multiple capabilities, each defined by a separate interface, you could implement a plugin using the principle of composition or multiple inheritance. However, a multiple inheritance approach is easier for most use cases, particularly when coupling build and runtime capabilities in a single class is tolerable.
Using multiple inheritance, PadPlugin can be implemented as follows:
class PadPlugin : public IPluginV3, public IPluginV3OneCore, public IPluginV3OneBuild, public IPluginV3OneRuntime
{
// ... override the inherited virtual methods
};
The override of IPluginV3::getCapabilityInterface must return pointers to the individual capability interfaces. For each PluginCapabilityType, it is imperative to cast through the corresponding capability interface to remove ambiguity for the compiler.
IPluginCapability* PadPlugin::getCapabilityInterface(PluginCapabilityType type) noexcept
{
// All plugin interface methods are noexcept and care should be
// taken not to throw exceptions across the API boundary. It is
// recommended to catch any exceptions and return a value that
// appropriately represents the error status.
try
{
if (type == PluginCapabilityType::kBUILD)
{
return static_cast<IPluginV3OneBuild*>(this);
}
if (type == PluginCapabilityType::kRUNTIME)
{
return static_cast<IPluginV3OneRuntime*>(this);
}
ASSERT(type == PluginCapabilityType::kCORE);
return static_cast<IPluginV3OneCore*>(this);
}
catch(...)
{
// log error
}
return nullptr;
}
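The reason for casting through each capability interface is ordinary C++ multiple inheritance: the plugin object contains one subobject of the common capability base per capability interface, so converting this directly to that base is ambiguous. The following standalone sketch reproduces the pattern with hypothetical classes (Capability, CoreCap, BuildCap, RuntimeCap, ToyPlugin) rather than the TensorRT interfaces.

```cpp
#include <cassert>

// Hypothetical common base, playing the role of a shared capability base class.
struct Capability
{
    virtual ~Capability() = default;
    virtual int kind() const = 0;
};

// Hypothetical capability interfaces, each with its own Capability subobject.
struct CoreCap : Capability
{
    int kind() const override { return 0; }
};
struct BuildCap : Capability
{
    int kind() const override { return 1; }
};
struct RuntimeCap : Capability
{
    int kind() const override { return 2; }
};

enum class CapType
{
    kCORE,
    kBUILD,
    kRUNTIME
};

class ToyPlugin : public CoreCap, public BuildCap, public RuntimeCap
{
public:
    Capability* get(CapType type)
    {
        // `Capability* p = this;` would not compile: ToyPlugin contains three
        // Capability subobjects, so the conversion is ambiguous. Casting through
        // an intermediate class first selects exactly one of them.
        switch (type)
        {
        case CapType::kBUILD: return static_cast<BuildCap*>(this);
        case CapType::kRUNTIME: return static_cast<RuntimeCap*>(this);
        case CapType::kCORE: return static_cast<CoreCap*>(this);
        }
        return nullptr;
    }
};
```

Each returned pointer refers to a different subobject, so a virtual call through it dispatches to the matching capability's implementation.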
The methods that are of importance in this particular example are:
- INetworkDefinition::addPluginV3
- IPluginV3OneBuild::getNbOutputs
- IPluginV3OneBuild::getOutputDataTypes
- IPluginV3OneBuild::getOutputShapes
- IPluginV3OneBuild::supportsFormatCombination
- IPluginV3OneBuild::configurePlugin
- IPluginV3OneRuntime::onShapeChange
- IPluginV3OneRuntime::enqueue
INetworkDefinition::addPluginV3 (C++, Python) can add the plugin to the network.
std::vector<ITensor*> inputs{X};
auto pluginLayer = network->addPluginV3(inputs.data(), inputs.size(), nullptr, 0, *plugin);
You can communicate that there is a single plugin output by overriding IPluginV3OneBuild::getNbOutputs.
int32_t PadPlugin::getNbOutputs() const noexcept
{
return 1;
}
The output will have the same data type as the input, which can be communicated in the override of IPluginV3OneBuild::getOutputDataTypes.
int32_t PadPlugin::getOutputDataTypes(
DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept
{
outputTypes[0] = inputTypes[0];
return 0;
}
The override for getOutputShapes returns symbolic expressions for the output dimensions in terms of the input dimensions, except in the case of data-dependent output shapes, which are covered later in Example: Adding a Custom Layer with Data-Dependent and Shape Input-Dependent Shapes Using C++. In the current example, the first two dimensions of the output equal the first two dimensions of the input, respectively, and the last two dimensions are constants, each equal to 32. The IExprBuilder passed into getOutputShapes can be used to define constant symbolic expressions.
int32_t PadPlugin::getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs, int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept
{
outputs[0].nbDims = 4;
// first two output dims are equal to the first two input dims
outputs[0].d[0] = inputs[0].d[0];
outputs[0].d[1] = inputs[0].d[1];
// The last two output dims are equal to 32
outputs[0].d[2] = exprBuilder.constant(32);
outputs[0].d[3] = exprBuilder.constant(32);
return 0;
}
TensorRT uses supportsFormatCombination to ask whether the plugin accepts a given type and format combination for a connection at a given position pos and given formats/types for lesser-indexed connections. The interface indexes the inputs/outputs uniformly as connections, starting at 0 for the first input, then the rest of the inputs in order, followed by numbering the outputs. In the example, the input is connection 0, and the output is connection 1.
For the sake of simplicity, the example supports only linear formats and FP32 types.
bool PadPlugin::supportsFormatCombination(
int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
{
assert(0 <= pos && pos < 2);
return inOut[pos].desc.format == PluginFormat::kLINEAR && inOut[pos].desc.type == DataType::kFLOAT;
}
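The connection numbering can be expressed as a small decoding function. decodeConnection is a hypothetical helper, not part of TensorRT; given nbInputs, it maps a connection index pos to either an input index or an output index.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical result type: which side a connection is on and its index there.
struct Connection
{
    bool isInput;
    int32_t index;
};

// Inputs occupy positions [0, nbInputs); outputs follow in order.
Connection decodeConnection(int32_t pos, int32_t nbInputs)
{
    if (pos < nbInputs)
    {
        return {true, pos};
    }
    return {false, pos - nbInputs};
}
```

For PadPlugin (one input, one output), pos 0 decodes to input 0 and pos 1 to output 0.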
TensorRT invokes two methods to allow the plugin to make any configuration choices before enqueue(), both during auto-tuning (in the engine build phase) and when the engine is being executed (in the runtime phase).
- IPluginV3OneBuild::configurePlugin: Called when a plugin is being prepared for profiling (auto-tuning) but not for any specific input size. The min and max values of the DynamicPluginTensorDesc are the bounds on the tensor shape, and opt is the shape used for auto-tuning. The desc.dims field corresponds to the dimensions of the plugin specified at network creation, including any wildcards (-1) for dynamic dimensions.
- IPluginV3OneRuntime::onShapeChange: Called during both the build phase and the runtime phase before enqueue() to communicate the input and output shapes for the subsequent enqueue(). The output PluginTensorDesc contains wildcards (-1) for any data-dependent dimensions specified through getOutputShapes().
This plugin does not need configurePlugin and onShapeChange to do anything, so they are no-ops:
int32_t PadPlugin::configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept
{
return 0;
}
int32_t PadPlugin::onShapeChange(PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out, int32_t nbOutputs) noexcept
{
return 0;
}
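Although PadPlugin leaves both methods empty, the relationship between the information they receive can be sketched. ToyDynamicDesc is a hypothetical stand-in for the fields a DynamicPluginTensorDesc carries (desc.dims, min, opt, max), not the TensorRT struct; the function checks that a concrete shape, of the kind onShapeChange receives, agrees with every static dimension declared at network creation and lies within [min, max].

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for per-tensor configurePlugin() information.
struct ToyDynamicDesc
{
    std::vector<int64_t> declared; // dims at network creation, -1 = dynamic
    std::vector<int64_t> min;      // lower bound of each dimension
    std::vector<int64_t> opt;      // shape used for auto-tuning
    std::vector<int64_t> max;      // upper bound of each dimension
};

// True if `shape` matches every static declared dimension and is within bounds.
bool shapeIsAdmissible(ToyDynamicDesc const& desc, std::vector<int64_t> const& shape)
{
    if (shape.size() != desc.declared.size())
    {
        return false;
    }
    for (std::size_t i = 0; i < shape.size(); ++i)
    {
        bool const staticMismatch = desc.declared[i] != -1 && desc.declared[i] != shape[i];
        bool const outOfBounds = shape[i] < desc.min[i] || shape[i] > desc.max[i];
        if (staticMismatch || outOfBounds)
        {
            return false;
        }
    }
    return true;
}
```

With a declared input of (-1, 3, -1, -1), bounds (1, 3, 8, 8) to (8, 3, 64, 64), and opt (4, 3, 32, 32), a runtime shape of (2, 3, 40, 20) is admissible, while (2, 4, 40, 20) violates the static channel dimension.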
Finally, the override PadPlugin::enqueue has to do the work. Since shapes are dynamic, enqueue is handed a PluginTensorDesc that describes each input and output’s dimensions, type, and format.
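int32_t enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, void const* const* inputs,
    void* const* outputs, void* workspace, cudaStream_t stream) noexcept override
{
    // Launch work on `stream` that populates outputs[0] from inputs[0],
    // reading the actual (dynamic) input shape from inputDesc[0].dims.
    // Return 0 on success and a nonzero status code on failure.
    return 0;
}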
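As a sketch of what enqueue has to compute, here is a CPU reference in plain C++ that a CUDA kernel launched on stream could be tested against. The padding rule is an assumption, since the text only says each image becomes 32 x 32: each H x W plane is copied into the top-left corner of a zero-filled 32 x 32 plane and cropped when H or W exceeds 32. padTo32Reference is a hypothetical name; the layout is linear NCHW FP32, matching supportsFormatCombination above.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference for PadPlugin (ASSUMED rule: top-left copy,
// zero padding, cropping when a spatial dimension exceeds 32).
std::vector<float> padTo32Reference(std::vector<float> const& x, int64_t B, int64_t C, int64_t H, int64_t W)
{
    constexpr int64_t kOut = 32;
    std::vector<float> y(static_cast<std::size_t>(B * C * kOut * kOut), 0.0F);
    int64_t const rows = std::min(H, kOut);
    int64_t const cols = std::min(W, kOut);
    // Each (batch, channel) pair is one H x W plane in linear NCHW layout.
    for (int64_t plane = 0; plane < B * C; ++plane)
    {
        for (int64_t h = 0; h < rows; ++h)
        {
            for (int64_t w = 0; w < cols; ++w)
            {
                y[static_cast<std::size_t>((plane * kOut + h) * kOut + w)]
                    = x[static_cast<std::size_t>((plane * H + h) * W + w)];
            }
        }
    }
    return y;
}
```

In a real enqueue, B, C, H, and W would come from inputDesc[0].dims, and the kernel would write to outputs[0] on stream.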
Example: Adding a Custom Layer with Data-Dependent and Shape Input-Dependent Shapes Using C++
This section shows an example of a plugin with data-dependent and shape-input-dependent shapes. Note that data-dependent output shapes and adding shape inputs to a plugin are new features not present in V2 plugins.
Data-dependent Shapes (DDS): The shape of a plugin output could depend on the values of the input tensors.
Shape inputs: A plugin can accept shape inputs in addition to device tensor inputs. Shape inputs are visible to the plugin only as arguments to IPluginV3OneBuild::getOutputShapes(); therefore, their sole purpose is to aid the plugin in performing output shape calculations.
For example, BarPlugin is a plugin with one device input X, one shape input S, and an output Y, where:
- The first dimension of Y depends on the value of S.
- The second dimension of Y is static.
- The third dimension of Y depends on the shape of X.
- The fourth dimension of Y is data-dependent.
Similar to PadPlugin in the prior example, BarPlugin uses multiple inheritance.
To add the plugin to the network, INetworkDefinition::addPluginV3 (C++, Python) can be used similarly. After the device tensor inputs, addPluginV3 takes two additional arguments to specify the shape tensor inputs.
std::vector<ITensor*> inputs{X};
std::vector<ITensor*> shapeInputs{S};
auto pluginLayer = network->addPluginV3(inputs.data(), inputs.size(), shapeInputs.data(), shapeInputs.size(), *plugin);
Note
The TensorRT ONNX parser provides a built-in feature to pass shape inputs to custom ops supported by IPluginV3-based plugins. The indices of the inputs to be interpreted as shape inputs must be indicated by a node attribute named tensorrt_plugin_shape_input_indices as a list of integers. For example, if the custom op has four inputs and the second and fourth inputs should be passed as shape inputs to the plugin, add a node attribute named tensorrt_plugin_shape_input_indices of type onnx.AttributeProto.ints containing the value [1, 3].
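The attribute's effect on a node's inputs can be illustrated with a small partitioning function. splitByShapeIndices is hypothetical and not part of the ONNX parser; it assumes zero-based indices and that each group keeps the node's original input order.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical result: node input indices split into the two plugin input kinds.
struct SplitInputs
{
    std::vector<int64_t> deviceInputs;
    std::vector<int64_t> shapeInputs;
};

// Indices listed in `shapeInputIndices` become shape inputs; all others stay
// device inputs. Both groups preserve the node's input order (an assumption).
SplitInputs splitByShapeIndices(int64_t nbNodeInputs, std::vector<int64_t> const& shapeInputIndices)
{
    SplitInputs split;
    for (int64_t i = 0; i < nbNodeInputs; ++i)
    {
        bool const isShape
            = std::find(shapeInputIndices.begin(), shapeInputIndices.end(), i) != shapeInputIndices.end();
        (isShape ? split.shapeInputs : split.deviceInputs).push_back(i);
    }
    return split;
}
```

For the four-input example with the value [1, 3], inputs 0 and 2 remain device inputs and inputs 1 and 3 become shape inputs.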
In the override for getOutputShapes, plugins must declare both the position and the bounds of each data-dependent dimension of each output tensor. The bounds can be expressed using a special output called a size tensor.
A size tensor is a scalar of either INT32 or INT64 data type, expressed through a value for auto-tuning and an upper bound; these values can either be constants or computed in terms of device input shapes or shape input values using IExprBuilder.
In this case, there is a single data-dependent dimension, which can be represented using one size tensor. Note that any size tensor needed to express a data-dependent dimension counts as an output of the plugin; therefore, the plugin has two outputs in total.
int32_t getNbOutputs() const noexcept override
{
return 2;
}
Assume output Y has the same type as the device input X and that the data-dependent dimension size fits in INT32 (so the size tensor has type DataType::kINT32). Then BarPlugin expresses the output data types like this:
int32_t getOutputDataTypes(
DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept override
{
outputTypes[0] = inputTypes[0];
outputTypes[1] = DataType::kINT32;
return 0;
}
The method getOutputShapes can build symbolic output shape expressions using the IExprBuilder passed to it. In what follows, note that size tensors must be explicitly declared 0D.
int32_t BarPlugin::getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs, int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept
{
outputs[0].nbDims = 4;
// The first output dimension depends on the value of S.
// The value of S is encoded as fictitious dimensions.
outputs[0].d[0] = shapeInputs[0].d[0];
// The third output dimension depends on the shape of X
outputs[0].d[2] = inputs[0].d[0];
// The second output dimension is static
outputs[0].d[1] = exprBuilder.constant(3);
auto upperBound = exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[2], *inputs[0].d[3]);
auto optValue = exprBuilder.operation(DimensionOperation::kFLOOR_DIV, *upperBound, *exprBuilder.constant(2));
// output at index 1 is a size tensor
outputs[1].nbDims = 0; // size tensors must be declared as 0-D
auto sizeTensor = exprBuilder.declareSizeTensor(1, *optValue, *upperBound);
// The fourth output dimension is data-dependent
outputs[0].d[3] = sizeTensor;
return 0;
}
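Plugging concrete numbers into BarPlugin's expressions shows what the builder is told. This sketch evaluates them for hypothetical sample values: X has shape (5, 7, 6, 9) and the value read from S is 11, so Y has shape (11, 3, 5, k), where the size tensor k has an upper bound of 6 * 9 = 54 and an auto-tuning value of floor(54 / 2) = 27. evaluateBarShape and BarShapeBounds are hypothetical helpers, not IExprBuilder APIs.

```cpp
#include <cassert>
#include <cstdint>

// Concrete values of BarPlugin's symbolic output expressions (hypothetical helper).
struct BarShapeBounds
{
    int64_t d0;        // value of shape input S (shapeInputs[0].d[0])
    int64_t d1;        // static: 3
    int64_t d2;        // X.d[0]
    int64_t sizeOpt;   // auto-tuning value of the size tensor: floor(H * W / 2)
    int64_t sizeUpper; // upper bound of the size tensor: H * W
};

// xDims is X's 4-D shape; sValue is the value encoded in S.
BarShapeBounds evaluateBarShape(int64_t const xDims[4], int64_t sValue)
{
    int64_t const upper = xDims[2] * xDims[3];
    // kFLOOR_DIV on non-negative values matches integer division.
    return {sValue, 3, xDims[0], upper / 2, upper};
}
```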
The override of supportsFormatCombination imposes the following conditions:
- The device input X must have DataType::kFLOAT or DataType::kHALF.
- The output Y must have the same type as X.
- The size tensor output has the type DataType::kINT32.
Note
Shape inputs passed to the plugin through addPluginV3 (C++, Python) only appear as arguments to getOutputShapes() and are not counted or included among plugin inputs in any other plugin interface method.
bool BarPlugin::supportsFormatCombination(
int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
{
assert(0 <= pos && pos < 3);
auto const* in = inOut;
auto const* out = inOut + nbInputs;
bool typeOk{false};
switch (pos)
{
case 0: typeOk = in[0].desc.type == DataType::kFLOAT || in[0].desc.type == DataType::kHALF; break;
case 1: typeOk = out[0].desc.type == in[0].desc.type; break;
case 2: typeOk = out[1].desc.type == DataType::kINT32; break;
}
return inOut[pos].desc.format == PluginFormat::kLINEAR && typeOk;
}
The local variables in and out here allow inspecting inOut by input or output number instead of connection number.
Important
The override inspects the format/type for a connection with an index less than pos but must never inspect the format/type for a connection with an index greater than pos. The example uses case 1 to check connection 1 against connection 0 and not case 0 to check connection 0 against connection 1.
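The ordering rule can be demonstrated with a toy search that fixes connections in increasing pos order, so that when position pos is queried, only positions 0 through pos hold chosen values. The Type enum, barSupports, and enumerate are hypothetical, only types are modeled (linear format is assumed), and this is not TensorRT's actual format-selection algorithm; it merely models the contract imposed on supportsFormatCombination.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <initializer_list>
#include <vector>

// Hypothetical toy type; NOT TensorRT's DataType.
enum class Type
{
    kFLOAT,
    kHALF,
    kINT32
};

// BarPlugin-style rule for 1 input (X) and 2 outputs (Y, size tensor).
// It only reads entries at index <= pos, as required above.
bool barSupports(int32_t pos, std::array<Type, 3> const& inOut)
{
    switch (pos)
    {
    case 0: return inOut[0] == Type::kFLOAT || inOut[0] == Type::kHALF;
    case 1: return inOut[1] == inOut[0];
    case 2: return inOut[2] == Type::kINT32;
    default: return false;
    }
}

// Toy depth-first enumeration: entries after `pos` may hold stale values from
// earlier iterations, which is harmless because barSupports never reads them.
void enumerate(int32_t pos, std::array<Type, 3>& inOut, std::vector<std::array<Type, 3>>& accepted)
{
    if (pos == 3)
    {
        accepted.push_back(inOut);
        return;
    }
    for (Type t : {Type::kFLOAT, Type::kHALF, Type::kINT32})
    {
        inOut[pos] = t;
        if (barSupports(pos, inOut))
        {
            enumerate(pos + 1, inOut, accepted);
        }
    }
}
```

The search accepts exactly two combinations: (FP32, FP32, INT32) and (FP16, FP16, INT32).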
configurePlugin and onShapeChange would be no-ops here, too. Note, however, that in onShapeChange, the output’s PluginTensorDesc contains a wildcard (-1) for the data-dependent dimension.
Implementing enqueue with data-dependent output shapes is not much different from the static or dynamic shape cases. As with any other output, for an output with a data-dependent dimension, the output buffer passed to enqueue is guaranteed to be large enough to hold the corresponding output tensor (based on the upper bound specified through getOutputShapes). In addition, enqueue must write the actual extent of the data-dependent dimension to the corresponding size tensor output.
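A data-dependent enqueue can be sketched on the CPU. The operation here (keeping the nonzero elements of a flat input) is a hypothetical stand-in, since the text does not say what BarPlugin computes. The output buffer is sized from the upper bound declared in getOutputShapes, and the plugin reports the number of elements it actually produced through the size tensor output. In a real plugin, both writes would happen on the GPU on stream, and the size tensor would live in device memory; compactNonzeroReference is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical CPU sketch of an enqueue() with a data-dependent output.
// `output` holds at least `upperBound` elements; `sizeTensor` receives the
// actual extent of the data-dependent dimension.
void compactNonzeroReference(
    float const* input, int32_t count, float* output, int32_t upperBound, int32_t* sizeTensor)
{
    int32_t written = 0;
    for (int32_t i = 0; i < count && written < upperBound; ++i)
    {
        if (input[i] != 0.0F)
        {
            output[written++] = input[i];
        }
    }
    *sizeTensor = written; // report the actual data-dependent extent
}
```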
Example: Adding a Custom Layer with INT8 I/O Support Using C++
PoolPlugin is a plugin demonstrating how to add INT8 I/O for a custom pooling layer using IPluginV3. PoolPlugin uses multiple inheritance to derive from IPluginV3, IPluginV3OneCore, IPluginV3OneBuild, and IPluginV3OneRuntime, similar to the PadPlugin and BarPlugin examples above.
The main methods that affect INT8 I/O are:
- supportsFormatCombination
- configurePlugin
The override for supportsFormatCombination must indicate which INT8 I/O combinations are allowed. This override is similar to the one in Example: Adding a Custom Layer with Dynamic Shapes Using C++. In this example, the supported I/O tensor format is linear CHW with an FP32, FP16, BF16, FP8, or INT8 data type, and the input and output tensors must have the same data type.
bool PoolPlugin::supportsFormatCombination(
int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
{
assert(nbInputs == 1 && nbOutputs == 1 && pos < nbInputs + nbOutputs);
bool condition = inOut[pos].desc.format == PluginFormat::kLINEAR;
condition &= (inOut[pos].desc.type == DataType::kFLOAT ||
inOut[pos].desc.type == DataType::kHALF ||
inOut[pos].desc.type == DataType::kBF16 ||
inOut[pos].desc.type == DataType::kFP8 ||
inOut[pos].desc.type == DataType::kINT8);
condition &= inOut[pos].desc.type == inOut[0].desc.type;
return condition;
}
Important
If INT8 calibration must be used with a network with INT8 I/O plugins, the plugin must support FP32 I/O, as TensorRT uses FP32 to calibrate the graph.
If the FP32 I/O variant is not supported or INT8 calibration is not used, all required INT8 I/O tensor scales must be set explicitly.
Calibration cannot determine the dynamic range of a plugin’s internal tensors. Plugins that operate on quantized data must calculate their dynamic range for internal tensors.
A plugin can be designed to accept FP8 and INT8 I/O types, although note that in TensorRT 9.0, the builder does not allow networks that mix INT8 and FP8.
Information communicated by TensorRT through configurePlugin or onShapeChange can be used to obtain information about the pooling parameters and the input and output scales. These can be stored as member variables, serialized, and then deserialized to be used during inference.
int32_t PoolPlugin::configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept
{
    // ... (other configuration)
    // Pooling geometry from the NCHW input and output shapes
    mPoolingParams.mC = in[0].desc.dims.d[1];
    mPoolingParams.mH = in[0].desc.dims.d[2];
    mPoolingParams.mW = in[0].desc.dims.d[3];
    mPoolingParams.mP = out[0].desc.dims.d[2];
    mPoolingParams.mQ = out[0].desc.dims.d[3];
    // Per-tensor scales; -1.0F is stored when no non-negative scale is present
    mInHostScale = in[0].desc.scale >= 0.0F ? in[0].desc.scale : -1.0F;
    mOutHostScale = out[0].desc.scale >= 0.0F ? out[0].desc.scale : -1.0F;
    return 0;
}
In this example, the per-tensor INT8 I/O scales are obtained from PluginTensorDesc::scale.
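As an illustration of how stored scales such as mInHostScale and mOutHostScale might be used, here is a CPU sketch of INT8 average pooling over one window. It assumes symmetric per-tensor quantization (real value = q * scale) with round-half-to-even and saturation to [-128, 127]; quantize and averagePoolInt8 are hypothetical helpers, not TensorRT APIs.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// ASSUMED symmetric quantization: q = clamp(roundHalfToEven(value / scale), -128, 127).
int8_t quantize(float value, float scale)
{
    float const q = std::nearbyint(value / scale); // default rounding mode: ties to even
    return static_cast<int8_t>(std::clamp(q, -128.0F, 127.0F));
}

// Average of one pooling window: dequantize with the input scale, average in
// FP32, then requantize with the output scale.
int8_t averagePoolInt8(std::vector<int8_t> const& window, float inScale, float outScale)
{
    float sum = 0.0F;
    for (int8_t q : window)
    {
        sum += static_cast<float>(q) * inScale;
    }
    return quantize(sum / static_cast<float>(window.size()), outScale);
}
```

Dequantizing with the input scale and requantizing with the output scale is necessary because the input and output tensors can have different scales.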