Writing custom operators with TensorRT Python plugins

TensorRT (TRT) offers a wide array of built-in operators that fit most use cases. However, you may wish to introduce a custom operator for a variety of reasons, including:

  • Supporting an entirely new operator that TRT does not support out-of-the-box

  • Adding an operator whose behavior differs slightly from that of an operator TRT already supports

TRT provides support for custom operators through plugins. This guide shows how to implement the Python functions that define a plugin's behavior, and how to wrap them so that the plugin can be added to a network as a custom operator.

Composition of a plugin

At a minimum, a plugin definition needs two functions, each wrapped with a decorator provided by the tensorrt.plugin module:

  1. tensorrt.plugin.register(): The decorated function returns the shape and type characteristics of the output tensors. Its signature also defines the input tensors and any attributes the plugin needs.

  2. tensorrt.plugin.impl(): The decorated function performs the plugin computation.

Optionally, if a plugin supports multiple data type/tensor layout combinations for its I/O, or multiple backends (e.g., different kernels), it can take advantage of TensorRT’s autotuning capabilities to pick the most performant configuration of the plugin on the target platform. To enable autotuning, define a function wrapped with tensorrt.plugin.autotune().

Let us consider a few examples to understand the specifics of these functions.

Example: Circular padding plugin

Circular padding is useful for ops like circular convolution. The following image shows how the original image (red) is circularly padded once (green) and twice (blue):

The plugin shall have the following characteristics:

  • Input: \(N\)-dimensional input

  • Attribute(s): pads, a 1-D array of length \(m\), where \(m\) is even and \(m/2 \le N\). pads gives the amount of padding to apply before and after each of the last \(m/2\) dimensions of the input tensor, starting with the last dimension.

  • Output: the padded tensor, which has the same data type as the input. Dimensions not covered by pads are unchanged; for \(i = 0, \ldots, m/2 - 1\), the output dimension \(d^{\text{out}}_{N-i-1}\) (the \((i+1)\)-th dimension counting from the end) is given in terms of the corresponding input dimension \(d^{\text{in}}_{N-i-1}\) by:

\[ d^{\text{out}}_{N-i-1} = d^{\text{in}}_{N-i-1} + \text{pads}_{2i} + \text{pads}_{2i + 1} \]
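As a concrete check of this formula, the plain-Python sketch below computes the output shape from an input shape and a pads sequence. The helper name circ_pad_out_shape is hypothetical and is not part of tensorrt.plugin. It only evaluates the equation above.

```python
from typing import List, Sequence, Tuple


def circ_pad_out_shape(in_shape: Sequence[int], pads: Sequence[int]) -> Tuple[int, ...]:
    """Output shape of circular padding: d_out[N-i-1] = d_in[N-i-1] + pads[2i] + pads[2i+1]."""
    n, m = len(in_shape), len(pads)
    if m % 2 != 0 or m // 2 > n:
        raise ValueError("pads must have an even length m with m/2 <= N")
    out: List[int] = list(in_shape)
    for i in range(m // 2):
        # pads[2i] is added before and pads[2i+1] after the (i+1)-th-to-last dimension
        out[n - i - 1] += pads[2 * i] + pads[2 * i + 1]
    return tuple(out)


# 4-D input: pad the last dim by 1 on each side, the second-to-last by 2 on each side
print(circ_pad_out_shape((10, 3, 32, 32), (1, 1, 2, 2)))  # (10, 3, 36, 34)
```

The register function that follows performs the same arithmetic, but on symbolic shape expressions rather than concrete integers.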

Let’s capture this information in the tensorrt.plugin.register function:

import numpy as np
import numpy.typing as npt
import tensorrt.plugin as trtp

@trtp.register("example::circ_pad_plugin")
def circ_pad_plugin_desc(
    inp0: trtp.TensorDesc, pads: npt.NDArray[np.int32]
) -> trtp.TensorDesc:
    ndim = inp0.ndim
    # Start from a descriptor with the same shape and dtype as the input
    out_desc = inp0.like()

    # Grow the (i+1)-th-to-last dimension by pads[2i] (before) + pads[2i+1] (after)
    for i in range(np.size(pads) // 2):
        out_desc.shape_expr[ndim - i - 1] += int(
            pads[i * 2] + pads[i * 2 + 1]
        )

    return out_desc

The argument example::circ_pad_plugin defines the namespace (“example”) and name (“circ_pad_plugin”) of the plugin. Input arguments annotated with tensorrt.plugin.TensorDesc denote the input tensors; all others are interpreted as plugin attributes. Supported attribute types are:

  • int, float, str, bool, bytes. Lists/tuples of these types are not supported.

  • 1-D NumPy arrays of the following types: int8, int16, int32, int64, float16, float32, float64, bool. These must be annotated with numpy.typing.NDArray[dtype], where dtype is the expected NumPy data type.

The return annotation, tensorrt.plugin.TensorDesc, indicates that the function returns a descriptor for a single output. To construct it, we start with inp0.like(), which returns a tensor descriptor with the same shape and type characteristics as inp0. Because the output shape differs from the input shape, we then update the symbolic shape expressions exposed through tensorrt.plugin.TensorDesc.shape_expr, adding the padding to each affected dimension.

Now let’s define the computation function, decorated with tensorrt.plugin.impl(). For simplicity, let’s leverage PyTorch’s torch.nn.functional.pad to do the computation.

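from typing import Tuple

import numpy as np
import numpy.typing as npt
import tensorrt.plugin as trtp
import torch

@trtp.impl("example::circ_pad_plugin")
def circ_pad_plugin_impl(
    inp0: trtp.Tensor,
    pads: npt.NDArray[np.int32],
    outputs: Tuple[trtp.Tensor],
    stream: int
) -> None:
    # Zero-copy torch views of the TRT-owned device buffers
    inp_t = torch.as_tensor(inp0, device="cuda")
    out_t = torch.as_tensor(outputs[0], device="cuda")

    # Enqueue the work on the CUDA stream provided by TRT
    with torch.cuda.stream(torch.cuda.ExternalStream(stream)):
        out = torch.nn.functional.pad(inp_t, pads.tolist(), mode="circular")
        out_t.copy_(out)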

Note that the decorated function receives a tensorrt.plugin.Tensor for each input and output. Unlike a TensorDesc, a Tensor references an underlying data buffer, which is directly accessible through tensorrt.plugin.Tensor.data_ptr(). When working with PyTorch or OpenAI Triton kernels, it is easier to use torch.as_tensor() to construct, without copying, a torch.Tensor that views the same memory as the tensorrt.plugin.Tensor.
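When validating the plugin, an independent reference implementation can help. The NumPy sketch below assumes that pads uses the same ordering as torch.nn.functional.pad (the pair for the last dimension comes first). It uses numpy.pad with mode="wrap", which fills each pad with values from the opposite edge, i.e. circular padding. circ_pad_reference is a hypothetical test helper and is not part of TensorRT.

```python
import numpy as np


def circ_pad_reference(x: np.ndarray, pads) -> np.ndarray:
    """Circular padding with pads ordered as in torch.nn.functional.pad (last dim first)."""
    pad_width = [(0, 0)] * x.ndim
    for i in range(len(pads) // 2):
        # pads[2i] and pads[2i+1] pad the (i+1)-th-to-last dimension
        pad_width[x.ndim - i - 1] = (int(pads[2 * i]), int(pads[2 * i + 1]))
    # mode="wrap" takes the padding values from the opposite edge (circular padding)
    return np.pad(x, pad_width, mode="wrap")


x = np.arange(6, dtype=np.float32).reshape(2, 3)
print(circ_pad_reference(x, np.array([1, 1], dtype=np.int32)))
```

Once the plugin's output has been copied back to the host, it could be compared against this reference with numpy.testing.assert_allclose.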

Picking the most performant plugin configuration: Autotuning

Let’s assume the plugin is able to support both FP32 and FP16 I/O, in linear tensor layouts (tensor formats). If performance is key, and it is uncertain whether FP32 or FP16 would perform better on the target platform, we can define a function decorated with tensorrt.plugin.autotune().

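from typing import List, Tuple

import numpy as np
import numpy.typing as npt
import tensorrt.plugin as trtp

@trtp.autotune("example::circ_pad_plugin")
def circ_pad_plugin_autotune(
    inp0: trtp.TensorDesc,
    pads: npt.NDArray[np.int32],
    outputs: Tuple[trtp.TensorDesc],
) -> List[trtp.AutoTuneCombination]:
    # Input and output: both FP32 or both FP16, each in the LINEAR format
    return [trtp.AutoTuneCombination("FP32|FP16, FP32|FP16", "LINEAR")]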

The decorated function must return a list of tensorrt.plugin.AutoTuneCombination objects. Here we define a single combination, AutoTuneCombination("FP32|FP16, FP32|FP16", "LINEAR"). It indicates that the input and output must be either both FP32 or both FP16, and that both use the linear format. When an autotune function is defined, TRT may execute the plugin (that is, invoke its impl function) with random inputs for each type/format combination during the engine build.
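The following hypothetical pure-Python helper makes the pairing rule concrete. It expands an I/O type string as the paragraph above describes: the k-th alternative at each position is combined with the k-th alternative at every other position. It illustrates only the documented semantics for this case. It is not TensorRT's parser, and it deliberately rejects strings whose positions list different numbers of alternatives.

```python
from typing import List, Tuple


def expand_io_types(spec: str) -> List[Tuple[str, ...]]:
    """Pair the k-th alternative at every I/O position (illustrative helper only)."""
    positions = [p.strip().split("|") for p in spec.split(",")]
    if len({len(alts) for alts in positions}) != 1:
        raise ValueError("this sketch needs the same number of alternatives at every position")
    # zip(*positions) yields (pos0[k], pos1[k], ...) for each alternative index k
    return list(zip(*positions))


print(expand_io_types("FP32|FP16, FP32|FP16"))  # [('FP32', 'FP32'), ('FP16', 'FP16')]
```

Note that the mixed combination ("FP32", "FP16") does not appear. This matches "either both FP32 or both FP16".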

Adding the plugin to a TensorRT network

Now that we have defined the plugin, it can be added to a TensorRT network. Depending on your workflow, you may need to add the plugin directly to a tensorrt.INetworkDefinition through the TRT Python API. Alternatively, you may want the TRT ONNX parser to add it when loading an ONNX graph that contains the custom operator.

Adding the plugin using TRT Python APIs

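The tensorrt.INetworkDefinition.add_plugin() API can be used to add a plugin to a network definition (an instance of tensorrt.INetworkDefinition). In the following snippet, network is such an instance, x is an array whose shape gives the input shape, and pads is the padding attribute: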

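import tensorrt as trt
import tensorrt.plugin  # the plugin must already be registered via tensorrt.plugin.register
input_tensor = network.add_input(name="x", dtype=trt.DataType.FLOAT, shape=x.shape)
plugin_layer = network.add_plugin(trt.plugin.op.example.circ_pad_plugin(input_tensor, pads=pads))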

Note that the registered plugin can be found in tensorrt.plugin.op under its registered namespace and name (here, tensorrt.plugin.op.example.circ_pad_plugin).

Loading an ONNX model with the custom operator

It is also possible to load an ONNX model containing a custom op node that should be executed by your TRT plugin. For the TRT ONNX parser to map your plugin to the ONNX node(s) of interest, ensure that:

  • The op property of the ONNX node is exactly the same as your plugin name.

  • The node contains a string attribute called plugin_namespace with the namespace of your plugin.

For example, if you use ONNX GraphSurgeon, the custom op node can be constructed as follows:

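import numpy as np
import onnx
import onnx_graphsurgeon as gs

inp_shape = [10, 3, 32, 32]  # example input shape
pads = np.array([1, 1, 1, 1], dtype=np.int32)  # example padding attribute

var_x = gs.Variable(name="x", shape=inp_shape, dtype=np.float32)
var_y = gs.Variable(name="y", dtype=np.float32)

circ_pad_node = gs.Node(
    name="circ_pad_plugin",
    op="circ_pad_plugin",  # must match the plugin name exactly
    inputs=[var_x],
    outputs=[var_y],
    attrs={"pads": pads, "plugin_namespace": "example"},  # plugin namespace
)

# Wrap the node in a graph and save it as an ONNX model
graph = gs.Graph(nodes=[circ_pad_node], inputs=[var_x], outputs=[var_y])
onnx.save(gs.export_onnx(graph), "circ_pad.onnx")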
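The two conditions above reduce to a string comparison against the identifier passed to tensorrt.plugin.register(). The following hypothetical pure-Python helper (not part of TensorRT or ONNX GraphSurgeon) makes that mapping explicit. It splits the "namespace::name" string and checks a node's op and attributes against it.

```python
from typing import Any, Dict, Tuple


def split_plugin_id(plugin_id: str) -> Tuple[str, str]:
    """Split a "namespace::name" identifier like the one given to the register decorator."""
    namespace, sep, name = plugin_id.partition("::")
    if not sep or not namespace or not name:
        raise ValueError(f"expected 'namespace::name', got {plugin_id!r}")
    return namespace, name


def onnx_node_matches(op: str, attrs: Dict[str, Any], plugin_id: str) -> bool:
    """True if op equals the plugin name and plugin_namespace equals its namespace."""
    namespace, name = split_plugin_id(plugin_id)
    return op == name and attrs.get("plugin_namespace") == namespace


attrs = {"pads": [1, 1, 1, 1], "plugin_namespace": "example"}
print(onnx_node_matches("circ_pad_plugin", attrs, "example::circ_pad_plugin"))  # True
```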

Advanced usage

Example: Operators with data-dependent output shapes - Non-zero

Non-zero is an operation that finds the indices of the non-zero elements of its input tensor. Its output therefore has a data-dependent shape (DDS), so the output shape cannot be computed from the input shapes alone.

To handle DDS, the extent of each data-dependent output dimension must be expressed in terms of a size tensor. A size tensor is a scalar that communicates two values to TRT for that dimension: an upper bound and an autotune value, both expressed in terms of the input shapes. The TRT engine build may be optimized for the autotune value, but at runtime the dimension may take any extent up to the upper bound.

In this example, we consider a 2D input tensor inp0. The output will be an \(N \times 2\) tensor (a set of \(N\) 2D indices), where \(N\) is the number of non-zero elements. In the worst case every element is non-zero, so the upper bound can be expressed as upper_bound = inp0.shape_expr[0] * inp0.shape_expr[1].

If we expect, on average, half of the input elements to be zero, a size tensor can be constructed with half the upper bound as the autotune value:

st = trt.plugin.size_tensor(opt = upper_bound // 2, upper_bound = upper_bound)

Now we’re ready to construct the output shape. st.expr() returns a shape expression for the size tensor, so a tensor descriptor for the output shape can be constructed as trt.plugin.from_shape_expr((st.expr(), 2), dtype=trt.int32). TRT requires that any size tensors also be made outputs of the plugin. Putting things together, we arrive at the following:

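from typing import Tuple

import tensorrt as trt
import tensorrt.plugin as trtp

@trtp.register("example::non_zero_plugin")
def non_zero_plugin_desc(
    inp0: trtp.TensorDesc,
) -> Tuple[trtp.TensorDesc, trtp.TensorDesc]:
    # Worst case: every element of the 2D input is non-zero
    upper_bound = inp0.shape_expr[0] * inp0.shape_expr[1]
    # Autotune value of half the elements; the size tensor is also a plugin output
    st = trtp.size_tensor(upper_bound // 2, upper_bound)
    return trtp.from_shape_expr((st.expr(), 2), dtype=trt.int32), st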
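The register function above only describes the outputs. At runtime, the impl must fill both the index tensor and the size tensor. The NumPy sketch below illustrates what those outputs should hold, with numpy.argwhere standing in for the non-zero kernel. non_zero_reference is hypothetical. The dtype TRT expects for the size-tensor output is not covered in this guide, so the sketch returns the count as a plain int.

```python
import numpy as np


def non_zero_reference(x: np.ndarray):
    """Return (indices, count): the contents of the index output and the size-tensor output."""
    if x.ndim != 2:
        raise ValueError("this sketch mirrors the 2-D plugin above")
    indices = np.argwhere(x).astype(np.int32)  # shape (N, 2), in row-major order
    return indices, indices.shape[0]  # N is the runtime extent reported by the size tensor


x = np.array([[0, 3, 0], [4, 0, 5]], dtype=np.float32)
upper_bound = x.shape[0] * x.shape[1]  # every element non-zero
opt = upper_bound // 2  # autotune value: about half the elements are zero
indices, count = non_zero_reference(x)
print(indices.tolist(), count, opt, upper_bound)  # [[0, 1], [1, 0], [1, 2]] 3 3 6
```

Here the actual count happens to equal the autotune value, and it is always bounded by upper_bound.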

Example: Implementing in-place custom ops with I/O aliasing

In-place computations can be implemented with TRT plugins through aliased I/O. An input that needs to be modified in place is represented by an input-output pair in which the output is aliased to the input. For example, consider a simple elementwise addition plugin. An in-place version can be declared as follows:

import tensorrt.plugin as trtp

@trtp.register("sample::elemwise_add_plugin_")
def add_plugin_desc_(inp0: trtp.TensorDesc) -> trtp.TensorDesc:
    return inp0.aliased()

Note that inp0.aliased() produces an output TensorDesc that is aliased to inp0.
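An aliased output refers to the same memory as its input, so the impl writes its result into that buffer instead of into separate storage. The NumPy analogy below is only an illustration. add_inplace and its delta parameter are hypothetical (the register function above takes only inp0), and in a real plugin TRT manages the device buffers.

```python
import numpy as np


def add_inplace(buf: np.ndarray, delta: float) -> np.ndarray:
    """Hypothetical in-place elementwise add: the result overwrites the input buffer."""
    np.add(buf, delta, out=buf)  # no new allocation; the input memory is updated
    return buf  # the "output" aliases the input


x = np.ones(4, dtype=np.float32)
y = add_inplace(x, 2.0)
print(y, y is x, np.shares_memory(x, y))  # [3. 3. 3. 3.] True True
```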

Example: Plugins with multiple backends: Using custom tactics

There may be multiple kernels or libraries that can perform the plugin computation for the same I/O data type/tensor layout combination, with no way to know in advance which backend will be fastest on the target platform. Such alternate backends are termed tactics of the plugin. To let TRT determine the fastest tactic, you can leverage the autotune function.
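Conceptually, choosing a tactic means timing interchangeable implementations on the same input and keeping the fastest. The sketch below imitates that idea on the CPU with two NumPy implementations of circular padding along the last axis. It illustrates the concept only and says nothing about how TensorRT measures tactics internally. All function names are hypothetical, and the tactic IDs 1 and 2 mirror the TORCH and TRITON values of the Tactic enum in the next example.

```python
import time
from typing import Callable, Dict

import numpy as np


def pad_wrap_numpy(x: np.ndarray, p: int) -> np.ndarray:
    """Candidate 1: circular padding of the last axis via np.pad."""
    return np.pad(x, [(0, 0)] * (x.ndim - 1) + [(p, p)], mode="wrap")


def pad_wrap_concat(x: np.ndarray, p: int) -> np.ndarray:
    """Candidate 2: same result built from slices (requires 1 <= p <= x.shape[-1])."""
    if not 1 <= p <= x.shape[-1]:
        raise ValueError("pad must be between 1 and the size of the last axis")
    return np.concatenate([x[..., -p:], x, x[..., :p]], axis=-1)


def pick_fastest(tactics: Dict[int, Callable], x: np.ndarray, p: int, reps: int = 20) -> int:
    """Time every candidate on the same input and return the id of the fastest one."""
    timings = {}
    for tactic_id, fn in tactics.items():
        start = time.perf_counter()
        for _ in range(reps):
            fn(x, p)
        timings[tactic_id] = time.perf_counter() - start
    return min(timings, key=timings.get)


tactics = {1: pad_wrap_numpy, 2: pad_wrap_concat}
x = np.random.rand(64, 256).astype(np.float32)
print("fastest tactic:", pick_fastest(tactics, x, 3))
```

The selection is only meaningful because both candidates produce identical results. Tactics of a plugin must be interchangeable in the same way.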

Let’s assume that we have an OpenAI Triton kernel circ_pad_kernel that can perform the circular padding operation from the example above. We can ask TRT to pick the faster of the OpenAI Triton kernel and torch.nn.functional.pad as follows:

from enum import IntEnum
from typing import List, Tuple

import numpy as np
import numpy.typing as npt
import tensorrt.plugin as trtp

class Tactic(IntEnum):
    TORCH = 1
    TRITON = 2

@trtp.autotune("example::circ_pad_plugin")
def circ_pad_plugin_autotune(
    inp0: trtp.TensorDesc,
    pads: npt.NDArray[np.int32],
    outputs: Tuple[trtp.TensorDesc],
) -> List[trtp.AutoTuneCombination]:
    c = trtp.AutoTuneCombination()
    # Positions 0 (input) and 1 (output) may be FP32 or FP16
    c.pos([0, 1], "FP32|FP16")
    # Ask TRT to time both backends
    c.tactics([int(Tactic.TORCH), int(Tactic.TRITON)])
    return [c]

Note that we’re using another way of constructing a tensorrt.plugin.AutoTuneCombination here: pos(...) populates the type/format information, and tactics(...) specifies the tactics.

Now, the impl function can take an additional int argument called tactic and use it to select the appropriate backend:

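from typing import Tuple

import numpy as np
import numpy.typing as npt
import tensorrt.plugin as trtp
import torch

@trtp.impl("example::circ_pad_plugin")
def circ_pad_plugin_impl(
    inp0: trtp.Tensor,
    pads: npt.NDArray[np.int32],
    outputs: Tuple[trtp.Tensor],
    stream: int,
    tactic: int
) -> None:
    inp_t = torch.as_tensor(inp0, device="cuda")
    out_t = torch.as_tensor(outputs[0], device="cuda")

    # Run either backend on the CUDA stream provided by TRT
    with torch.cuda.stream(torch.cuda.ExternalStream(stream)):
        if tactic == Tactic.TORCH:
            out = torch.nn.functional.pad(inp_t, pads.tolist(), mode="circular")
            out_t.copy_(out)
        elif tactic == Tactic.TRITON:
            N = inp0.ndim
            # Per-dimension padding laid out as [before_0, after_0, before_1, after_1, ...]
            all_pads = np.zeros((N * 2,), dtype=np.int32)
            out_dims = list(inp0.shape)
            for i in range(np.size(pads) // 2):
                d = N - i - 1  # pads lists the last dimension first
                out_dims[d] += int(pads[i * 2] + pads[i * 2 + 1])
                all_pads[d * 2] = pads[i * 2]
                all_pads[d * 2 + 1] = pads[i * 2 + 1]

            # One program instance per block of block_size output elements
            block_size = 256
            num_blocks = tuple(
                [int((np.prod(out_dims) + block_size - 1) // block_size)]
            )

            # The remaining kernel arguments are elided here
            circ_pad_kernel[num_blocks](inp_t, ..., BLOCK_SIZE=block_size)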
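The body of circ_pad_kernel is not shown in this guide. Suppose, as a hypothetical, that it assigns one work item to each output element, which is what sizing the grid as ceil(numel(out_dims) / block_size) suggests. Its per-element logic can then be emulated in NumPy: unravel the flat output index, shift it by the leading pad, and wrap it around the input extent with a modulo. circ_pad_flat is a hypothetical sketch, not the actual kernel. It expects per-dimension padding in [before_0, after_0, before_1, after_1, ...] order.

```python
import numpy as np


def circ_pad_flat(x: np.ndarray, all_pads, block_size: int = 256):
    """Emulate a one-work-item-per-output-element circular pad (hypothetical kernel logic).

    all_pads holds per-dimension padding as [before_0, after_0, before_1, after_1, ...].
    """
    before = np.asarray(all_pads[0::2], dtype=np.int64)
    after = np.asarray(all_pads[1::2], dtype=np.int64)
    out_shape = tuple(int(d) for d in np.asarray(x.shape) + before + after)
    numel = int(np.prod(out_shape))
    num_blocks = (numel + block_size - 1) // block_size  # ceil-division, as for the launch grid
    flat = np.arange(num_blocks * block_size)
    flat = flat[flat < numel]  # lanes past the end of the last block are masked off
    out_idx = np.unravel_index(flat, out_shape)  # one coordinate array per output dimension
    # Shift by the leading pad and wrap around the input extent
    in_idx = tuple((out_idx[d] - before[d]) % x.shape[d] for d in range(x.ndim))
    return x[in_idx].reshape(out_shape), num_blocks


x = np.arange(6, dtype=np.float32).reshape(2, 3)
out, num_blocks = circ_pad_flat(x, [1, 0, 1, 1])
print(out.shape, num_blocks)  # (3, 5) 1
```

The masking step corresponds to the bounds check a block-based kernel typically needs, because the last block may extend past the end of the output.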

FAQ

  1. What happened to plugin creators, plugin registration and plugin libraries?

You may have this question if you have previously worked with class-based TRT plugins. For decorator-based plugins, the tensorrt.plugin module handles registration automatically, as long as a function decorated with tensorrt.plugin.register() is defined.

To simulate the experience of plugin libraries, you may define plugins under a common namespace in a separate Python module, and then import that module when the plugin “library” needs to be loaded.

  2. How is the serialization/deserialization of plugin attributes handled?

The tensorrt.plugin module automatically serializes any plugin attributes included in the tensorrt.plugin.impl function signature into the TRT engine. When that engine is loaded, those attributes are deserialized and passed back to the tensorrt.plugin.impl function.

Note

If the plugin computation requires only a subset of the plugin attributes included in tensorrt.plugin.register, include only that subset in the tensorrt.plugin.impl function signature. This avoids unnecessary serialization of data and therefore produces leaner TRT engines.

Limitations

  • TRT plugins defined in Python result in TRT engines that, when executed, rely on the availability of a Python environment with those plugin definitions. To build Python-independent TRT engines, it is recommended to use the TRT C++ plugin interfaces.

  • This guide describes plugin implementations through the tensorrt.plugin module, which supports some of the most common use cases for plugin usage. For more advanced use cases, such as passing shape inputs, it is recommended to define plugins as direct implementations of the tensorrt.IPluginV3 interface.