API Migration Guide#

This section highlights the TensorRT API modifications. If you are unfamiliar with these changes, refer to our sample code for clarification.

Python#

Python API Changes#

Allocating Buffers and Using a Name-Based Engine API

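Before (deprecated binding-based API):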
def allocate_buffers(self, engine):
    '''
    Allocates all buffers required for an engine, i.e., host/device inputs/outputs.
    '''
    inputs = []
    outputs = []
    bindings = []
    stream = cuda.Stream()

    # binding is the name of input/output
    for binding in engine:
        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
        dtype = trt.nptype(engine.get_binding_dtype(binding))

        # Allocate host and device buffers
        host_mem = cuda.pagelocked_empty(size, dtype) # page-locked memory buffer (won't swap to disk)
        device_mem = cuda.mem_alloc(host_mem.nbytes)

        # Append the device buffer address to device bindings.
        # When cast to int, it's a linear index into the context's memory (like memory address).
        bindings.append(int(device_mem))

        # Append to the appropriate input/output list.
        if engine.binding_is_input(binding):
            inputs.append(self.HostDeviceMem(host_mem, device_mem))
        else:
            outputs.append(self.HostDeviceMem(host_mem, device_mem))

    return inputs, outputs, bindings, stream
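
After (name-based tensor API):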
def allocate_buffers(self, engine):
    '''
    Allocates all buffers required for an engine, i.e., host/device inputs/outputs.
    '''
    inputs = []
    outputs = []
    bindings = []
    stream = cuda.Stream()

    for i in range(engine.num_io_tensors):
        tensor_name = engine.get_tensor_name(i)
        size = trt.volume(engine.get_tensor_shape(tensor_name))
        dtype = trt.nptype(engine.get_tensor_dtype(tensor_name))

        # Allocate host and device buffers
        host_mem = cuda.pagelocked_empty(size, dtype) # page-locked memory buffer (won't swap to disk)
        device_mem = cuda.mem_alloc(host_mem.nbytes)

        # Append the device buffer address to device bindings.
        # When cast to int, it's a linear index into the context's memory (like memory address).
        bindings.append(int(device_mem))

        # Append to the appropriate input/output list.
        if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
            inputs.append(self.HostDeviceMem(host_mem, device_mem))
        else:
            outputs.append(self.HostDeviceMem(host_mem, device_mem))

    return inputs, outputs, bindings, stream

Transition from enqueueV2 to enqueueV3 for Python

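Before (execute_async_v2):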
# Allocate device memory for inputs.
d_inputs = [cuda.mem_alloc(input_nbytes) for binding in range(input_num)]

# Allocate device memory for outputs.
h_output = cuda.pagelocked_empty(output_nbytes, dtype=np.float32)
d_output = cuda.mem_alloc(h_output.nbytes)

# Transfer data from host to device.
cuda.memcpy_htod_async(d_inputs[0], input_a, stream)
cuda.memcpy_htod_async(d_inputs[1], input_b, stream)
cuda.memcpy_htod_async(d_inputs[2], input_c, stream)

# Run inference
context.execute_async_v2(bindings=[int(d_inp) for d_inp in d_inputs] + [int(d_output)], stream_handle=stream.handle)

# Synchronize the stream
stream.synchronize()
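
After (execute_async_v3 with set_tensor_address):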
# Allocate device memory for inputs.
d_inputs = [cuda.mem_alloc(input_nbytes) for binding in range(input_num)]

# Allocate device memory for outputs.
h_output = cuda.pagelocked_empty(output_nbytes, dtype=np.float32)
d_output = cuda.mem_alloc(h_output.nbytes)

# Transfer data from host to device.
cuda.memcpy_htod_async(d_inputs[0], input_a, stream)
cuda.memcpy_htod_async(d_inputs[1], input_b, stream)
cuda.memcpy_htod_async(d_inputs[2], input_c, stream)

# Setup tensor address
bindings = [int(d_inputs[i]) for i in range(3)] + [int(d_output)]

for i in range(engine.num_io_tensors):
    context.set_tensor_address(engine.get_tensor_name(i), bindings[i])

# Run inference
context.execute_async_v3(stream_handle=stream.handle)

# Synchronize the stream
stream.synchronize()

Engine Building: Use Only build_serialized_network

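Before (with the build_engine fallback):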
engine_bytes = None
try:
    engine_bytes = self.builder.build_serialized_network(self.network, self.config)
except AttributeError:
    engine = self.builder.build_engine(self.network, self.config)
    engine_bytes = engine.serialize()
    del engine
assert engine_bytes
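
After (build_serialized_network only):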
engine_bytes = self.builder.build_serialized_network(self.network, self.config)
if engine_bytes is None:
    log.error("Failed to create engine")
    sys.exit(1)

Added Python APIs#

  • APILanguage

  • ExecutionContextAllocationStrategy

  • IGpuAsyncAllocator

  • InterfaceInfo

  • IPluginResource

  • IPluginV3

  • IStreamReader

  • IVersionedInterface

  • ICudaEngine.is_debug_tensor()

  • ICudaEngine.minimum_weight_streaming_budget

  • ICudaEngine.streamable_weights_size

  • ICudaEngine.weight_streaming_budget

  • IExecutionContext.get_debug_listener()

  • IExecutionContext.get_debug_state()

  • IExecutionContext.set_all_tensors_debug_state()

  • IExecutionContext.set_debug_listener()

  • IExecutionContext.set_tensor_debug_state()

  • IExecutionContext.update_device_memory_size_for_shapes()

  • IGpuAllocator.allocate_async()

  • IGpuAllocator.deallocate_async()

  • INetworkDefinition.add_plugin_v3()

  • INetworkDefinition.is_debug_tensor()

  • INetworkDefinition.mark_debug()

  • INetworkDefinition.unmark_debug()

  • IPluginRegistry.acquire_plugin_resource()

  • IPluginRegistry.all_creators

  • IPluginRegistry.deregister_creator()

  • IPluginRegistry.get_creator()

  • IPluginRegistry.register_creator()

  • IPluginRegistry.release_plugin_resource()

Removed Python APIs#

Removed Python APIs and their Suggested Superseded API#

  • BuilderFlag.ENABLE_TACTIC_HEURISTIC → Builder optimization level 2
  • BuilderFlag.STRICT_TYPES → Use all three flags: BuilderFlag.DIRECT_IO, BuilderFlag.PREFER_PRECISION_CONSTRAINTS, and BuilderFlag.REJECT_EMPTY_ALGORITHMS
  • EngineCapability.DEFAULT → EngineCapability.STANDARD
  • EngineCapability.SAFE_DLA → EngineCapability.DLA_STANDALONE
  • EngineCapability.SAFE_GPU → EngineCapability.SAFETY
  • IAlgorithmIOInfo.tensor_format → The strides, data type, and vectorization information are sufficient to identify tensor formats uniquely.
  • IBuilder.max_batch_size → Implicit batch is no longer supported.
  • IBuilderConfig.max_workspace_size → IBuilderConfig.set_memory_pool_limit() and IBuilderConfig.get_memory_pool_limit() with MemoryPoolType.WORKSPACE
  • IBuilderConfig.min_timing_iterations → IBuilderConfig.avg_timing_iterations
  • ICudaEngine.binding_is_input() → ICudaEngine.get_tensor_mode()
  • ICudaEngine.get_binding_bytes_per_component() → ICudaEngine.get_tensor_bytes_per_component()
  • ICudaEngine.get_binding_components_per_element() → ICudaEngine.get_tensor_components_per_element()
  • ICudaEngine.get_binding_dtype() → ICudaEngine.get_tensor_dtype()
  • ICudaEngine.get_binding_format() → ICudaEngine.get_tensor_format()
  • ICudaEngine.get_binding_format_desc() → ICudaEngine.get_tensor_format_desc()
  • ICudaEngine.get_binding_index() → No name-based equivalent replacement
  • ICudaEngine.get_binding_name() → No name-based equivalent replacement
  • ICudaEngine.get_binding_shape() → ICudaEngine.get_tensor_shape()
  • ICudaEngine.get_binding_vectorized_dim() → ICudaEngine.get_tensor_vectorized_dim()
  • ICudaEngine.get_location() → ITensor.location
  • ICudaEngine.get_profile_shape() → ICudaEngine.get_tensor_profile_shape()
  • ICudaEngine.get_profile_shape_input() → ICudaEngine.get_tensor_profile_values()
  • ICudaEngine.has_implicit_batch_dimension() → Implicit batch is no longer supported
  • ICudaEngine.is_execution_binding() → No name-based equivalent replacement
  • ICudaEngine.is_shape_binding() → ICudaEngine.is_shape_inference_io()
  • ICudaEngine.max_batch_size() → Implicit batch is no longer supported
  • ICudaEngine.num_bindings() → ICudaEngine.num_io_tensors()
  • IExecutionContext.get_binding_shape() → IExecutionContext.get_tensor_shape()
  • IExecutionContext.get_strides() → IExecutionContext.get_tensor_strides()
  • IExecutionContext.set_binding_shape() → IExecutionContext.set_input_shape()
  • IFullyConnectedLayer → IMatrixMultiplyLayer
  • INetworkDefinition.add_convolution() → INetworkDefinition.add_convolution_nd()
  • INetworkDefinition.add_deconvolution() → INetworkDefinition.add_deconvolution_nd()
  • INetworkDefinition.add_fully_connected() → INetworkDefinition.add_matrix_multiply()
  • INetworkDefinition.add_padding() → INetworkDefinition.add_padding_nd()
  • INetworkDefinition.add_pooling() → INetworkDefinition.add_pooling_nd()
  • INetworkDefinition.add_rnn_v2() → INetworkDefinition.add_loop()
  • INetworkDefinition.has_explicit_precision → Explicit precision support is removed in 10.0
  • INetworkDefinition.has_implicit_batch_dimension → Implicit batch is no longer supported
  • IRNNv2Layer → ILoop
  • NetworkDefinitionCreationFlag.EXPLICIT_BATCH and NetworkDefinitionCreationFlag.EXPLICIT_PRECISION → Support is removed in 10.0
  • PaddingMode.CAFFE_ROUND_DOWN and PaddingMode.CAFFE_ROUND_UP → Caffe is not supported since 9.0
  • PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805 → External tactics are always disabled for core code
  • PreviewFeature.FASTER_DYNAMIC_SHAPES_0805 → This flag is on by default
  • ProfilingVerbosity.DEFAULT → ProfilingVerbosity.LAYER_NAMES_ONLY
  • ProfilingVerbosity.VERBOSE → ProfilingVerbosity.DETAILED
  • ResizeMode → Use InterpolationMode. The alias was removed.
  • SampleMode.DEFAULT → SampleMode.STRICT_BOUNDS
  • SliceMode → Use SampleMode. The alias was removed.

C++#

C++ API Changes#

Transition from enqueueV2 to enqueueV3 for C++

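Before (enqueueV2):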
// Create RAII buffer manager object.
samplesCommon::BufferManager buffers(mEngine);

auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
if (!context)
{
    return false;
}

// Pick a random digit to try to infer.
srand(time(NULL));
int32_t const digit = rand() % 10;

// Read the input data into the managed buffers.
// There should be just 1 input tensor.
ASSERT(mParams.inputTensorNames.size() == 1);

if (!processInput(buffers, mParams.inputTensorNames[0], digit))
{
    return false;
}
// Create a CUDA stream to execute this inference.
cudaStream_t stream;
CHECK(cudaStreamCreate(&stream));

// Asynchronously copy data from host input buffers to device input
buffers.copyInputToDeviceAsync(stream);

// Asynchronously enqueue the inference work
if (!context->enqueueV2(buffers.getDeviceBindings().data(), stream, nullptr))
{
    return false;
}
// Asynchronously copy data from device output buffers to host output buffers.
buffers.copyOutputToHostAsync(stream);

// Wait for the work in the stream to complete.
CHECK(cudaStreamSynchronize(stream));

// Release stream.
CHECK(cudaStreamDestroy(stream));
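
After (enqueueV3 with setTensorAddress):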
// Create RAII buffer manager object.
samplesCommon::BufferManager buffers(mEngine);

auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
if (!context)
{
    return false;
}

for (int32_t i = 0, e = mEngine->getNbIOTensors(); i < e; i++)
{
    auto const name = mEngine->getIOTensorName(i);
    context->setTensorAddress(name, buffers.getDeviceBuffer(name));
}

// Pick a random digit to try to infer.
srand(time(NULL));
int32_t const digit = rand() % 10;

// Read the input data into the managed buffers.
// There should be just 1 input tensor.
ASSERT(mParams.inputTensorNames.size() == 1);

if (!processInput(buffers, mParams.inputTensorNames[0], digit))
{
    return false;
}
// Create a CUDA stream to execute this inference.
cudaStream_t stream;
CHECK(cudaStreamCreate(&stream));

// Asynchronously copy data from host input buffers to device input
buffers.copyInputToDeviceAsync(stream);

// Asynchronously enqueue the inference work
if (!context->enqueueV3(stream))
{
    return false;
}

// Asynchronously copy data from device output buffers to host output buffers.
buffers.copyOutputToHostAsync(stream);

// Wait for the work in the stream to complete.
CHECK(cudaStreamSynchronize(stream));

// Release stream.
CHECK(cudaStreamDestroy(stream));

64-Bit Dimension Changes#

The dimensions held by Dims changed from int32_t to int64_t. However, TensorRT 10.0 generally rejects networks that use dimensions exceeding the range of int32_t. The tensor returned by IShapeLayer is now of type DataType::kINT64; use ICastLayer to cast the result to a tensor of type DataType::kINT32 if 32-bit dimensions are required.
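For example, code that consumes a shape tensor as 32-bit values now needs an explicit cast. A minimal sketch, assuming an existing INetworkDefinition and input tensor from your build code (the helper name is illustrative):

#include <NvInfer.h>

// Hypothetical helper: exposes a tensor's shape as INT32 for code that still expects 32-bit dimensions.
nvinfer1::ITensor* addShapeAsInt32(nvinfer1::INetworkDefinition& network, nvinfer1::ITensor& input)
{
    // In TensorRT 10.x, IShapeLayer produces a tensor of type DataType::kINT64.
    nvinfer1::ITensor* shape64 = network.addShape(input)->getOutput(0);

    // Cast back to DataType::kINT32 where 32-bit dimensions are required.
    nvinfer1::ICastLayer* cast = network.addCast(*shape64, nvinfer1::DataType::kINT32);
    return cast->getOutput(0);
}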

Inspect code that bitwise copies to and from Dims to ensure it is correct for int64_t dimensions.

Added C++ APIs#

  • ActivationType::kGELU_ERF

  • ActivationType::kGELU_TANH

  • BuilderFlag::kREFIT_IDENTICAL

  • BuilderFlag::kSTRIP_PLAN

  • BuilderFlag::kWEIGHT_STREAMING

  • DataType::kINT4

  • LayerType::kPLUGIN_V3

  • APILanguage

  • Dims64

  • ExecutionContextAllocationStrategy

  • IGpuAsyncAllocator

  • InterfaceInfo

  • IPluginResource

  • IPluginV3

  • IStreamReader

  • IVersionedInterface

  • getInferLibBuildVersion

  • getInferLibMajorVersion

  • getInferLibMinorVersion

  • getInferLibPatchVersion

  • ICudaEngine::createRefitter

  • ICudaEngine::getMinimumWeightStreamingBudget

  • ICudaEngine::getStreamableWeightsSize

  • ICudaEngine::getWeightStreamingBudget

  • ICudaEngine::isDebugTensor

  • ICudaEngine::setWeightStreamingBudget

  • IExecutionContext::getDebugListener

  • IExecutionContext::getTensorDebugState

  • IExecutionContext::setAllTensorsDebugState

  • IExecutionContext::setDebugListener

  • IExecutionContext::setOutputTensorAddress

  • IExecutionContext::setTensorDebugState

  • IExecutionContext::updateDeviceMemorySizeForShapes

  • IGpuAllocator::allocateAsync

  • IGpuAllocator::deallocateAsync

  • INetworkDefinition::addPluginV3

  • INetworkDefinition::isDebugTensor

  • INetworkDefinition::markDebug

  • INetworkDefinition::unmarkDebug

  • IPluginRegistry::acquirePluginResource

  • IPluginRegistry::deregisterCreator

  • IPluginRegistry::getAllCreators

  • IPluginRegistry::getCreator

  • IPluginRegistry::registerCreator

  • IPluginRegistry::releasePluginResource

Removed C++ APIs#

Removed C++ APIs and their Suggested Superseded API#

  • BuilderFlag::kENABLE_TACTIC_HEURISTIC → Builder optimization level 2
  • BuilderFlag::kSTRICT_TYPES → Use all three flags: BuilderFlag::kREJECT_EMPTY_ALGORITHMS, BuilderFlag::kDIRECT_IO, and BuilderFlag::kPREFER_PRECISION_CONSTRAINTS

Note: When enum members are removed (this applies to all enums in this list), the remaining values are renumbered sequentially.

  • EngineCapability::kDEFAULT → EngineCapability::kSTANDARD
  • EngineCapability::kSAFE_DLA → EngineCapability::kDLA_STANDALONE
  • EngineCapability::kSAFE_GPU → EngineCapability::kSAFETY
  • IAlgorithm::getAlgorithmIOInfo() → IAlgorithm::getAlgorithmIOInfoByIndex()
  • IAlgorithmIOInfo::getTensorFormat() → The strides, data type, and vectorization information are sufficient to identify tensor formats uniquely.
  • IBuilder::buildEngineWithConfig() → IBuilder::buildSerializedNetwork()
  • IBuilder::destroy() → delete the object
  • IBuilder::getMaxBatchSize() → Implicit batch is no longer supported
  • IBuilder::setMaxBatchSize() → Implicit batch is no longer supported
  • IBuilderConfig::destroy() → delete the object
  • IBuilderConfig::getMaxWorkspaceSize() → IBuilderConfig::getMemoryPoolLimit() with MemoryPoolType::kWORKSPACE
  • IBuilderConfig::getMinTimingIterations() → IBuilderConfig::getAvgTimingIterations()
  • IBuilderConfig::setMaxWorkspaceSize() → IBuilderConfig::setMemoryPoolLimit() with MemoryPoolType::kWORKSPACE
  • IBuilderConfig::setMinTimingIterations() → IBuilderConfig::setAvgTimingIterations()
  • IConvolutionLayer::getDilation() → IConvolutionLayer::getDilationNd()
  • IConvolutionLayer::getKernelSize() → IConvolutionLayer::getKernelSizeNd()
  • IConvolutionLayer::getPadding() → IConvolutionLayer::getPaddingNd()
  • IConvolutionLayer::getStride() → IConvolutionLayer::getStrideNd()
  • IConvolutionLayer::setDilation() → IConvolutionLayer::setDilationNd()
  • IConvolutionLayer::setKernelSize() → IConvolutionLayer::setKernelSizeNd()
  • IConvolutionLayer::setPadding() → IConvolutionLayer::setPaddingNd()
  • IConvolutionLayer::setStride() → IConvolutionLayer::setStrideNd()
  • ICudaEngine::bindingIsInput() → ICudaEngine::getTensorIOMode()
  • ICudaEngine::destroy() → delete the object
  • ICudaEngine::getBindingBytesPerComponent() → ICudaEngine::getTensorBytesPerComponent()
  • ICudaEngine::getBindingComponentsPerElement() → ICudaEngine::getTensorComponentsPerElement()
  • ICudaEngine::getBindingDataType() → ICudaEngine::getTensorDataType()
  • ICudaEngine::getBindingDimensions() → ICudaEngine::getTensorShape()
  • ICudaEngine::getBindingFormat() → ICudaEngine::getTensorFormat()
  • ICudaEngine::getBindingFormatDesc() → ICudaEngine::getTensorFormatDesc()
  • ICudaEngine::getBindingIndex() → Name-based methods
  • ICudaEngine::getBindingName() → Name-based methods
  • ICudaEngine::getBindingVectorizedDim() → ICudaEngine::getTensorVectorizedDim()
  • ICudaEngine::getLocation() → ITensor::getLocation()
  • ICudaEngine::getMaxBatchSize() → Implicit batch is no longer supported
  • ICudaEngine::getNbBindings() → ICudaEngine::getNbIOTensors()
  • ICudaEngine::getProfileDimensions() → ICudaEngine::getProfileShape()
  • ICudaEngine::getProfileShapeValues() → ICudaEngine::getShapeValues()
  • ICudaEngine::hasImplicitBatchDimension() → Implicit batch is no longer supported
  • ICudaEngine::isExecutionBinding() → No name-based equivalent replacement
  • ICudaEngine::isShapeBinding() → ICudaEngine::isShapeInferenceIO()
  • IDeconvolutionLayer::getKernelSize() → IDeconvolutionLayer::getKernelSizeNd()
  • IDeconvolutionLayer::getPadding() → IDeconvolutionLayer::getPaddingNd()
  • IDeconvolutionLayer::getStride() → IDeconvolutionLayer::getStrideNd()
  • IDeconvolutionLayer::setKernelSize() → IDeconvolutionLayer::setKernelSizeNd()
  • IDeconvolutionLayer::setPadding() → IDeconvolutionLayer::setPaddingNd()
  • IDeconvolutionLayer::setStride() → IDeconvolutionLayer::setStrideNd()
  • IExecutionContext::destroy() → delete the object
  • IExecutionContext::enqueue() → IExecutionContext::enqueueV3()
  • IExecutionContext::enqueueV2() → IExecutionContext::enqueueV3()
  • IExecutionContext::execute() → IExecutionContext::executeV2()
  • IExecutionContext::getBindingDimensions() → IExecutionContext::getTensorShape()
  • IExecutionContext::getShapeBinding() → IExecutionContext::getTensorAddress() or getOutputTensorAddress()
  • IExecutionContext::getStrides() → IExecutionContext::getTensorStrides()
  • IExecutionContext::setBindingDimensions() → IExecutionContext::setInputShape()
  • IExecutionContext::setInputShapeBinding() → IExecutionContext::setInputTensorAddress() or setTensorAddress()
  • IExecutionContext::setOptimizationProfile() → IExecutionContext::setOptimizationProfileAsync()
  • IFullyConnectedLayer → IMatrixMultiplyLayer
  • IGpuAllocator::free() → IGpuAllocator::deallocate()
  • IHostMemory::destroy() → delete the object
  • INetworkDefinition::addConvolution() → INetworkDefinition::addConvolutionNd()
  • INetworkDefinition::addDeconvolution() → INetworkDefinition::addDeconvolutionNd()
  • INetworkDefinition::addFullyConnected() → INetworkDefinition::addMatrixMultiply()
  • INetworkDefinition::addPadding() → INetworkDefinition::addPaddingNd()
  • INetworkDefinition::addPooling() → INetworkDefinition::addPoolingNd()
  • INetworkDefinition::addRNNv2() → INetworkDefinition::addLoop()
  • INetworkDefinition::destroy() → delete the object
  • INetworkDefinition::hasExplicitPrecision() → Explicit precision support is removed in 10.0
  • INetworkDefinition::hasImplicitBatchDimension() → Implicit batch support is removed
  • IOnnxConfig::destroy() → delete the object
  • IPaddingLayer::getPostPadding() → IPaddingLayer::getPostPaddingNd()
  • IPaddingLayer::getPrePadding() → IPaddingLayer::getPrePaddingNd()
  • IPaddingLayer::setPostPadding() → IPaddingLayer::setPostPaddingNd()
  • IPaddingLayer::setPrePadding() → IPaddingLayer::setPrePaddingNd()
  • IPoolingLayer::getPadding() → IPoolingLayer::getPaddingNd()
  • IPoolingLayer::getStride() → IPoolingLayer::getStrideNd()
  • IPoolingLayer::getWindowSize() → IPoolingLayer::getWindowSizeNd()
  • IPoolingLayer::setPadding() → IPoolingLayer::setPaddingNd()
  • IPoolingLayer::setStride() → IPoolingLayer::setStrideNd()
  • IPoolingLayer::setWindowSize() → IPoolingLayer::setWindowSizeNd()
  • IRefitter::destroy() → delete the object
  • IResizeLayer::getAlignCorners() → IResizeLayer::getAlignCornersNd()
  • IResizeLayer::setAlignCorners() → IResizeLayer::setAlignCornersNd()
  • IRuntime::deserializeCudaEngine(void const* blob, std::size_t size, IPluginFactory* pluginFactory) → Use deserializeCudaEngine with two parameters
  • IRuntime::destroy() → delete the object
  • kNV_TENSORRT_VERSION_IMPL → #define NV_TENSORRT_VERSION_INT(major, minor, patch) ((major) * 10000L + (minor) * 100L + (patch) * 1L)

Note: The TensorRT version encoding was changed to accommodate a two-digit minor version.

  • NetworkDefinitionCreationFlag::kEXPLICIT_BATCH and NetworkDefinitionCreationFlag::kEXPLICIT_PRECISION → Support is removed in 10.0
  • NV_TENSORRT_SONAME_MAJOR → NV_TENSORRT_MAJOR
  • NV_TENSORRT_SONAME_MINOR → NV_TENSORRT_MINOR
  • NV_TENSORRT_SONAME_PATCH → NV_TENSORRT_PATCH
  • PaddingMode::kCAFFE_ROUND_DOWN and PaddingMode::kCAFFE_ROUND_UP → Caffe is not supported since 9.0
  • PreviewFeature::kDISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805 → External tactics are always disabled for core code
  • PreviewFeature::kFASTER_DYNAMIC_SHAPES_0805 → This flag is on by default
  • ProfilingVerbosity::kDEFAULT → ProfilingVerbosity::kLAYER_NAMES_ONLY
  • ProfilingVerbosity::kVERBOSE → ProfilingVerbosity::kDETAILED
  • ResizeMode → Use InterpolationMode. The alias was removed.
  • RNNDirection, RNNGateType, RNNInputMode, and RNNOperation → RNN-related data structures are removed
  • SampleMode::kDEFAULT → SampleMode::kSTRICT_BOUNDS
  • SliceMode → Use SampleMode. The alias was removed.

Removed C++ Plugins#

Removed C++ Plugins and their Suggested Superseded Plugin#

  • createAnchorGeneratorPlugin() → GridAnchorPluginCreator::createPlugin()
  • createBatchedNMSPlugin() → BatchedNMSPluginCreator::createPlugin()
  • createInstanceNormalizationPlugin() → InstanceNormalizationPluginCreator::createPlugin()
  • createNMSPlugin() → NMSPluginCreator::createPlugin()
  • createNormalizePlugin() → NormalizePluginCreator::createPlugin()
  • createPriorBoxPlugin() → PriorBoxPluginCreator::createPlugin()
  • createRegionPlugin() → RegionPluginCreator::createPlugin()
  • createReorgPlugin() → ReorgPluginCreator::createPlugin()
  • createRPNROIPlugin() → RPROIPluginCreator::createPlugin()
  • createSplitPlugin() → INetworkDefinition::addSlice()
  • struct Quadruple → Related plugins are removed

Removed Safety C++ APIs#

Removed Safety C++ APIs and their Suggested Superseded Safety API#

  • safe::ICudaEngine::bindingIsInput() → safe::ICudaEngine::tensorIOMode()
  • safe::ICudaEngine::getBindingBytesPerComponent() → safe::ICudaEngine::getTensorBytesPerComponent()
  • safe::ICudaEngine::getBindingComponentsPerElement() → safe::ICudaEngine::getTensorComponentsPerElement()
  • safe::ICudaEngine::getBindingDataType() → safe::ICudaEngine::getTensorDataType()
  • safe::ICudaEngine::getBindingDimensions() → safe::ICudaEngine::getTensorShape()
  • safe::ICudaEngine::getBindingIndex() → safe:: name-based methods
  • safe::ICudaEngine::getBindingName() → safe:: name-based methods
  • safe::ICudaEngine::getBindingVectorizedDim() → safe::ICudaEngine::getTensorVectorizedDim()
  • safe::ICudaEngine::getNbBindings() → safe::ICudaEngine::getNbIOTensors()
  • safe::ICudaEngine::getBindingFormat() → safe::ICudaEngine::getTensorFormat()
  • safe::IExecutionContext::enqueueV2() → safe::IExecutionContext::enqueueV3()
  • safe::IExecutionContext::getStrides() → safe::IExecutionContext::getTensorStrides()

trtexec#

trtexec Flag Changes#

Changes to the --workspace and --minTiming Flags

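Before (with --workspace and --minTiming):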
trtexec \
    --onnx=/path/to/model.onnx \
    --saveEngine=/path/to/engine.trt \
    --optShapes=input:$INPUT_SHAPE \
    --avgTiming=1 \
    --workspace=1024 \
    --minTiming=1
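
After (with --memPoolSize):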
trtexec \
    --onnx=/path/to/model.onnx \
    --saveEngine=/path/to/engine.trt \
    --optShapes=input:$INPUT_SHAPE \
    --avgTiming=1 \
    --memPoolSize=workspace:1024

Removed trtexec Flags#

Removed trtexec Flags and their Suggested Superseded Flag#

  • --minTiming → --avgTiming
  • --preview=features options disableExternalTacticSourcesForCore0805 and fasterDynamicShapes0805 → N/A
  • --workspace=N → --memPoolSize=poolspec

Deprecated trtexec Flags#

  • --buildOnly

  • --explicitPrecision

  • --heuristic

  • --nvtxMode