tensorrt.plugin.aot_impl

tensorrt.plugin.aot_impl(plugin_id: str) → Callable

Wraps a function to define an Ahead-of-Time (AOT) implementation for a plugin already registered through trt.plugin.register.

This API is only intended to be used as a decorator. The decorated function is not required to have type hints for input arguments or return value; however, any type hints specified will be validated against the trt.plugin.register signature for consistency.

The schema of the decorated function is as follows:

(inp0: TensorDesc, inp1: TensorDesc, …, attr0: SupportedAttrType, attr1: SupportedAttrType, outputs: Tuple[TensorDesc], tactic: Optional[int]) -> Tuple[str, str, KernelLaunchParams, SymExprs]

  • Input tensors are passed first, each described by a TensorDesc.

  • Plugin attributes are declared next.
    • Not all attributes included in trt.plugin.register need to be specified here; they may be a subset.

    • NOTE: Plugin attributes are not serialized into the engine when using an AOT implementation.

  • tactic is an optional argument. If the plugin uses custom tactics, it must be included in the signature to receive the tactic value for the current execution of the plugin.
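Put together, the schema above gives the decorated function the following shape. This is an illustrative skeleton only, for a hypothetical plugin with one input tensor, one attribute (my_attr is a made-up name), and custom tactics; the full working example below shows a real implementation.

```python
# Hypothetical skeleton of an AOT implementation function: one input
# tensor, one attribute, and custom tactics. Argument names here are
# illustrative, not part of the API.
from typing import Optional, Tuple


def my_plugin_aot_impl(
    inp0,                   # TensorDesc describing the input tensor
    my_attr,                # attribute declared in trt.plugin.register
    outputs,                # Tuple[TensorDesc] describing the outputs
    tactic: Optional[int],  # only needed if the plugin uses custom tactics
) -> Tuple:
    # Must return (kernel_name, compiled_kernel, launch_params, extra_args)
    ...
```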

Parameters:

plugin_id – The ID for the plugin in the form “{namespace}::{name}”, which must match that used during trt.plugin.register

Returns:

A tuple of four values:

  • kernel_name: The name of the kernel.

  • compiled_kernel: The compiled form of the kernel. Presently, only PTX is supported.

  • launch_params: The launch parameters for the kernel.

  • extra_args: Symbolic expressions for any scalar inputs to the kernel, which are located after the tensor inputs and before the tensor outputs.
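In the example below, launch_params.grid_x is computed with trtp.cdiv, a ceiling division over symbolic shape expressions. Its arithmetic can be sketched in plain Python (cdiv here is a local stand-in operating on concrete integers, not the symbolic trtp.cdiv):

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division: the number of blocks of size `b` needed to
    # cover `a` elements (stand-in for trtp.cdiv on symbolic exprs).
    return -(-a // b)

# e.g. 1000 elements with BLOCK_SIZE = 256 require 4 blocks
grid_x = cdiv(1000, 256)
```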

Implementation of an elementwise plugin with an OpenAI Triton kernel
import tensorrt as trt
import tensorrt.plugin as trtp
import triton
import triton.language as tl
from typing import Tuple, Union

@triton.jit
def add_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
    # Elementwise kernel: y = x + 1
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x + 1, mask=mask)

@trtp.register("my::elemwise_add_plugin")
def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> trtp.TensorDesc:
    return inp0.like()

@trtp.aot_impl("my::elemwise_add_plugin")
def add_plugin_aot_impl(
    inp0: trtp.TensorDesc, block_size: int, outputs: Tuple[trtp.TensorDesc], tactic: int
) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:

    type_str = "fp32" if inp0.dtype == trt.float32 else "fp16"

    src = triton.compiler.ASTSource(
        fn=add_kernel,
        signature=f"*{type_str},i32,*{type_str}",
        constants={
            "BLOCK_SIZE": block_size,
        },
    )

    compiled_kernel = triton.compile(src)

    N = inp0.shape_expr.numel()
    launch_params = trtp.KernelLaunchParams()

    # grid dims
    launch_params.grid_x = trtp.cdiv(N, block_size)
    # block dims
    launch_params.block_x = compiled_kernel.metadata.num_warps * 32
    # shared memory
    launch_params.shared_mem = compiled_kernel.metadata.shared

    # One scalar kernel argument (n_elements), located after the tensor
    # input and before the tensor output in the kernel signature
    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(N)

    return compiled_kernel.metadata.name, compiled_kernel.asm["ptx"], launch_params, extra_args

See also: