tensorrt.plugin.impl
tensorrt.plugin.impl(plugin_id: str) -> Callable
Wraps a function to define an implementation for a plugin already registered through trt.plugin.register.
This API is only intended to be used as a decorator. The decorated function is not required to have type hints for input arguments or return value; however, any type hints specified will be validated against the trt.plugin.register signature for consistency.
The schema for the function is as follows:
(inp0: Tensor, inp1: Tensor, ..., attr0: SupportedAttrType, attr1: SupportedAttrType, outputs: Tuple[Tensor], stream: int, tactic: Optional[int]) -> None
- Input tensors are passed first, each described by a Tensor.
- Plugin attributes are declared next. Not every attribute included in trt.plugin.register must be specified here; they may be a subset. Included attributes are serialized into the TRT engine, so only attributes the plugin actually needs to perform inference (within the body of trt.plugin.impl) should be included.
- tactic is an optional argument. If the plugin uses custom tactics, it must be included in the signature to receive the tactic value chosen for the current execution of the plugin.
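As a toy illustration of the dispatch pattern (plain Python, not the TensorRT API), an implementation that opts into custom tactics simply accepts a tactic value and selects a code path from it:

```python
from typing import Optional

def toy_impl(data: list, tactic: Optional[int] = None) -> list:
    """Toy sketch: branch on the tactic chosen for this execution."""
    if tactic == 1:
        # hypothetical "fast path" variant
        return [x + 1 for x in data]
    # default path
    out = []
    for x in data:
        out.append(x + 1)
    return out

print(toy_impl([1, 2, 3], tactic=1))  # [2, 3, 4]
```

Both paths compute the same result here; in a real plugin, each tactic would dispatch to a different kernel.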
- Parameters
    plugin_id – The ID of the plugin in the form “{namespace}::{name}”, which must match the ID used during trt.plugin.register.
```python
import tensorrt.plugin as trtp
import torch
import triton
import triton.language as tl
from typing import Tuple

@triton.jit
def add_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x + 1, mask=mask)

@trtp.register("my::add_plugin")
def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]:
    return inp0.like()

@trtp.impl("my::add_plugin")
def add_plugin_impl(inp0: trtp.Tensor, block_size: int, outputs: Tuple[trtp.Tensor], stream: int) -> None:
    n = inp0.numel()
    inp0_t = torch.as_tensor(inp0, device="cuda")
    out_t = torch.as_tensor(outputs[0], device="cuda")

    add_kernel[(triton.cdiv(n, block_size),)](inp0_t, out_t, n, BLOCK_SIZE=block_size)
```
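Independent of TensorRT and Triton, the blocked, masked indexing that add_kernel performs can be emulated in plain Python to see what each program instance computes (a toy sketch; names are illustrative):

```python
def add_one_blocked(x, block_size):
    """Emulate the kernel's launch: each 'program' pid handles one block of
    block_size elements, masking offsets that fall past the end of the data."""
    n = len(x)
    out = [0] * n
    num_programs = -(-n // block_size)  # ceiling division, like triton.cdiv
    for pid in range(num_programs):
        for i in range(block_size):
            off = pid * block_size + i
            if off < n:  # corresponds to mask = offsets < n_elements
                out[off] = x[off] + 1
    return out

print(add_one_blocked([1, 2, 3, 4, 5], block_size=2))  # [2, 3, 4, 5, 6]
```

The final program instance covers only one valid element; the mask is what keeps the out-of-range store from happening, exactly as in the Triton kernel above.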