tensorrt.plugin.aot_impl
- tensorrt.plugin.aot_impl(plugin_id: str) → Callable
Wraps a function to define an Ahead-of-Time (AOT) implementation for a plugin already registered through trt.plugin.register.
This API is only intended to be used as a decorator. The decorated function is not required to have type hints for input arguments or return value; however, any type hints specified will be validated against the trt.plugin.register signature for consistency.
The schema for the function is as follows:

.. code-block:: text

    (inp0: TensorDesc, inp1: TensorDesc, ..., attr0: SupportedAttrType, attr1: SupportedAttrType, outputs: Tuple[TensorDesc], tactic: Optional[int]) -> Tuple[str, str, KernelLaunchParams, SymExprs]
- Input tensors are passed first, each described by a TensorDesc.
- Plugin attributes are declared next. Not all attributes included in trt.plugin.register must be specified here; they may be a subset. Note that plugin attributes are not serialized into the engine when using an AOT implementation.
- tactic is an optional argument. If the plugin uses custom tactics, it must be specified in order to receive the tactic value for the current execution of the plugin.
- Parameters:
plugin_id – The ID for the plugin in the form “{namespace}::{name}”, which must match that used during trt.plugin.register
- Returns:
  The decorated function must return a tuple consisting of:
  - kernel_name: The name of the kernel.
  - compiled_kernel: The compiled form of the kernel. Presently, only PTX is supported.
  - launch_params: The launch parameters for the kernel.
  - extra_args: Symbolic expressions for scalar inputs to the kernel, located after the tensor inputs and before the tensor outputs.
Implementation of an elementwise plugin with an OpenAI Triton kernel

.. code-block:: python

    from typing import Tuple, Union

    import tensorrt as trt
    import tensorrt.plugin as trtp
    import triton
    import triton.language as tl

    @triton.jit
    def add_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        tl.store(y_ptr + offsets, x + 1, mask=mask)

    @trtp.register("my::elemwise_add_plugin")
    def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]:
        return inp0.like()

    @trtp.aot_impl("my::elemwise_add_plugin")
    def add_plugin_aot_impl(
        inp0: trtp.TensorDesc, block_size: int, outputs: Tuple[trtp.TensorDesc], tactic: int
    ) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:

        type_str = "fp32" if inp0.dtype == trt.float32 else "fp16"

        src = triton.compiler.ASTSource(
            fn=add_kernel,
            signature=f"*{type_str},i32,*{type_str}",
            constants={
                "BLOCK_SIZE": block_size,
            },
        )

        compiled_kernel = triton.compile(src)

        N = inp0.shape_expr.numel()
        launch_params = trtp.KernelLaunchParams()

        # grid dims
        launch_params.grid_x = trtp.cdiv(N, block_size)
        # block dims
        launch_params.block_x = compiled_kernel.metadata.num_warps * 32
        # shared memory
        launch_params.shared_mem = compiled_kernel.metadata.shared

        extra_args = trtp.SymIntExprs(1)
        extra_args[0] = trtp.SymInt32(N)

        return compiled_kernel.metadata.name, compiled_kernel.asm["ptx"], launch_params, extra_args
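The launch parameters computed in the example are ordinary ceiling-division grid sizing. As a standalone illustration (plain Python, no TensorRT or Triton required; the local `cdiv` helper here is a stand-in mirroring what trtp.cdiv computes, and the element count and warp count are made-up values for demonstration):

.. code-block:: python

    def cdiv(a: int, b: int) -> int:
        # Ceiling division: number of blocks needed to cover `a` elements
        # with `b` elements per block (mirrors trtp.cdiv).
        return -(-a // b)

    # Hypothetical sizes: a tensor of 1,000,000 elements, BLOCK_SIZE = 256.
    N, block_size = 1_000_000, 256
    grid_x = cdiv(N, block_size)   # 1-D grid: 3907 blocks
    num_warps = 4                  # a common Triton default
    block_x = num_warps * 32       # 128 threads per block

In the actual AOT implementation these quantities are symbolic expressions evaluated by TensorRT at runtime rather than concrete Python integers.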
See also: