TensorRT Operators
Boilerplate
This boilerplate code provides a framework for running all of the operator examples. To make an example runnable, copy its code between the designated ‘Example Begin’ and ‘Example End’ comments below.
Boilerplate for All Operator Examples
import numpy as np
import math  # used by the PluginV2 example (example_plugin_v2.py)
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
# Output allocator for tensors whose shape is only known after execution
# (data-dependent shapes, for example the output of NonZero).
class OutputAllocator(trt.IOutputAllocator):
    def __init__(self, curr_size):
        trt.IOutputAllocator.__init__(self)
        self.curr_size = curr_size
        self.allocated_mem = None
        if curr_size > 0:
            self.allocated_mem = cuda.mem_alloc(curr_size)
        self.tensor_shape = None

    def reallocate_output(self, tensor_name, memory, size, alignment):
        assert size > 0
        if size > self.curr_size:
            self.allocated_mem = cuda.mem_alloc(size)
            self.curr_size = size  # remember the new capacity
        return int(self.allocated_mem)

    def notify_shape(self, tensor_name, shape):
        self.tensor_shape = shape

# Bundles the TensorRT objects and the per-example inputs, outputs, and
# expected values shared by all operator examples.
class Runner:
    def __init__(self, logger=trt.Logger(min_severity=trt.ILogger.Severity.INFO)):
        self.builder = trt.Builder(logger)
        self.network = self.builder.create_network(flags=0)
        self.config = self.builder.create_builder_config()
        self.runtime = trt.Runtime(logger)
        self.inputs = {}
        self.outputs = {}
        self.expected = {}
        self.results = {}
        self.logger = logger

def example(get_runner: Runner):
    network = get_runner.network
    inputs = get_runner.inputs
    outputs = get_runner.outputs
    expected = get_runner.expected
    # -------------------- Example Begin --------------------
    # Paste the code examples here
    # e.g. for Activation
    in1 = network.add_input("input1", dtype=trt.float32, shape=(2, 3))
    layer = network.add_activation(in1, type=trt.ActivationType.RELU)
    network.mark_output(layer.get_output(0))
    inputs[in1.name] = np.array([[-3.0, -2.0, -1.0], [0.0, 1.0, 2.0]])
    outputs[layer.get_output(0).name] = layer.get_output(0).shape
    expected[layer.get_output(0).name] = np.array([[0.0, 0.0, 0.0], [0.0, 1.0, 2.0]])
    # --------------------- Example End ---------------------
    return get_runner

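# run_example() drives the network defined in example(): it builds a
# serialized engine, allocates input/output memory with PyCUDA, executes the
# engine, and compares the results against the expected values.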
def run_example():
    example_runner = Runner()
    atol = 0.1
    network = example_runner.network
    inputs = example_runner.inputs
    outputs = example_runner.outputs
    expected = example_runner.expected
    builder = example_runner.builder
    config = example_runner.config
    runtime = example_runner.runtime
    results = example_runner.results

    def log_info(info):
        example_runner.logger.log(trt.ILogger.Severity.INFO, f"[Example] {info}")

    def log_error(info):
        example_runner.logger.log(trt.ILogger.Severity.ERROR, f"[Example] {info}")

    example_runner = example(example_runner)
    log_info("Building serialized network")
    serialized_engine = builder.build_serialized_network(network, config)
    assert serialized_engine is not None
    log_info("Creating engine")
    engine = runtime.deserialize_cuda_engine(serialized_engine)
    context = engine.create_execution_context()

    # Allocate host and device buffers
    in_mem = []
    out_mem = dict()
    output_allocators = dict()
    tensor_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
    for tensor in tensor_names:
        dtype = trt.nptype(engine.get_tensor_dtype(tensor))
        if engine.get_tensor_mode(tensor) == trt.TensorIOMode.INPUT:
            if engine.is_shape_inference_io(tensor):
                context.set_input_shape(tensor, inputs[tensor].shape)
                # Get the input memory address from the numpy object.
                input_address = inputs[tensor].ctypes.data
                context.set_tensor_address(tensor, input_address)
            else:
                # Handle input tensors
                context.set_input_shape(tensor, inputs[tensor].shape)
                input_buffer = np.ascontiguousarray(inputs[tensor], dtype=dtype)
                input_memory = cuda.mem_alloc(input_buffer.nbytes)
                context.set_tensor_address(tensor, int(input_memory))
                in_mem.append((input_memory, input_buffer))
        else:  # Handle output tensors
            # Check whether the output tensor has an unknown shape
            if trt.volume(context.get_tensor_shape(tensor)) < 0:
                # Set an output allocator for the output tensor with unknown shape.
                # Initialize it with size 0 so reallocate_output always allocates.
                output_allocator = OutputAllocator(0)
                context.set_output_allocator(tensor, output_allocator)
                output_allocators[tensor] = output_allocator
                # No need to initialize an output buffer or output memory here.
                out_mem[tensor] = None
            else:
                size = trt.volume(context.get_tensor_shape(tensor))
                output_buffer = cuda.pagelocked_empty(size, dtype)
                output_memory = cuda.mem_alloc(output_buffer.nbytes)
                context.set_tensor_address(tensor, int(output_memory))
                out_mem[tensor] = (output_buffer, output_memory)
    stream = cuda.Stream()
    # Transfer input data to the GPU.
    for input_memory, input_buffer in in_mem:
        cuda.memcpy_htod_async(input_memory, input_buffer, stream)

    # Run inference
    log_info("Running example")
    context.execute_async_v3(stream_handle=stream.handle)

    # Transfer prediction output from the GPU.
    for output in out_mem:
        output_mem = out_mem[output]
        if output_mem is None:
            # Must have been allocated by OutputAllocator.reallocate_output.
            assert output in output_allocators
            assert output_allocators[output].allocated_mem
            shape = output_allocators[output].tensor_shape
            assert shape is not None
            size = trt.volume(shape)
            dtype = trt.nptype(engine.get_tensor_dtype(output))
            output_buffer = cuda.pagelocked_empty(size, dtype)
            output_memory = context.get_tensor_address(output)
            output_mem = (output_buffer, output_memory)
            # Store the tensor-to-(buffer, memory) mapping.
            out_mem[output] = output_mem
        cuda.memcpy_dtoh_async(output_mem[0], output_mem[1], stream)
    log_info("Synchronizing with CUDA stream")
    stream.synchronize()
    log_info("Sync done")
    for output in out_mem:
        output_mem = out_mem[output][0]
        shape = outputs[output]
        # For outputs with data-dependent shapes, query the real shape
        # reported to the output allocator.
        if output in output_allocators:
            shape = output_allocators[output].tensor_shape
            assert shape is not None
            size = trt.volume(shape)
            output_mem = output_mem[:size]
        output_mem = output_mem.reshape(shape)
        results[output] = output_mem
log_info(f"Network inputs: {inputs}")
log_info(f"Inference results: {results}")
log_info(f"Expected results: {expected}")
# Check result
is_equal = {}
all_are_equal = True
for output in expected:
is_equal[output] = np.allclose(results[output], expected[output], atol=atol)
all_are_equal &= is_equal[output]
log_info(f"All results are expected: {all_are_equal}")
if all_are_equal is False:
for output in is_equal:
if is_equal[output] is False:
log_error(f"{output} mismatch:")
log_error(f"expected - content:{expected[output]}")
log_error(f"actual - content:{repr(results[output])}")
log_info("Example complete")
if __name__ == "__main__":
    run_example()
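Only the code between the Example Begin and Example End markers changes from one operator to the next. As a minimal sketch of the pattern (the shapes and values below are illustrative, not copied from the Unary example page), a Unary SQRT example pasted between the markers could look like this:

# Illustrative only: a Unary SQRT example in the same pattern as Activation.
in1 = network.add_input("input1", dtype=trt.float32, shape=(2, 2))
layer = network.add_unary(in1, op=trt.UnaryOperation.SQRT)
network.mark_output(layer.get_output(0))
inputs[in1.name] = np.array([[1.0, 4.0], [9.0, 16.0]])
outputs[layer.get_output(0).name] = layer.get_output(0).shape
expected[layer.get_output(0).name] = np.array([[1.0, 2.0], [3.0, 4.0]])

Running the script then builds the engine, executes it, and logs whether the inference results match the expected values. Example code is available for the following operators: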
- Activation
- Assertion
- Cast
- Concatenation
- Constant
- Convolution
- Deconvolution
- Dequantize
- Einsum
- ElementWise
- Fill
- Gather
- GridSample
- Identity
- If
- LRN
- Loop
- MatrixMultiply
- NMS
- NonZero
- Normalization
- OneHot
- Padding
- ParametricReLU
- PluginV2
- Pooling
- Quantize
- RaggedSoftMax
- Reduce
- Resize
- ReverseSequence
- Scale
- Scatter
- Select
- Shape
- Shuffle
- Slice
- SoftMax
- Squeeze
- TopK
- Unary
- Unsqueeze