Refitting an Engine
TensorRT can refit an engine with new weights without having to rebuild it. However, the option to do so must be specified when building:
...
config->setFlag(BuilderFlag::kREFIT);
builder->buildSerializedNetwork(network, config);
Later, you can create a Refitter object:
ICudaEngine* engine = ...;
IRefitter* refitter = createInferRefitter(*engine, gLogger);
Then, update the weights. For example, to update a set of weights named Conv Layer Kernel Weight:
Weights newWeights = ...;
refitter->setNamedWeights("Conv Layer Kernel Weight", newWeights);
The new weights must have the same count as the original weights used to build the engine. setNamedWeights returns false if something goes wrong, such as an incorrect weights name or a change in the weights count.
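The two failure cases just described can be illustrated with a plain-Python sketch. It models a refitter that records the element count of each named set of weights at build time, then rejects an unknown name or a changed count. ToyRefitter is a hypothetical stand-in written for this example. It is not part of the TensorRT API, and the real refitter may check more than this.

```python
class ToyRefitter:
    """Hypothetical stand-in for a refitter; NOT the TensorRT API."""

    def __init__(self, built_counts):
        # Maps each refittable weights name to the element count used at build time.
        self.built_counts = dict(built_counts)
        self.pending = {}

    def set_named_weights(self, name, values):
        # Mirrors the documented failure cases: unknown name or changed count.
        if name not in self.built_counts:
            return False
        if len(values) != self.built_counts[name]:
            return False
        self.pending[name] = values
        return True


refitter = ToyRefitter({"Conv Layer Kernel Weight": 4, "Conv Layer Bias Weight": 2})
print(refitter.set_named_weights("Conv Layer Kernel Weight", [0.1, 0.2, 0.3, 0.4]))  # True
print(refitter.set_named_weights("Conv Layer Kernel Weights", [0.1, 0.2, 0.3, 0.4]))  # False: wrong name
print(refitter.set_named_weights("Conv Layer Bias Weight", [0.5]))  # False: count changed
```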
You can use INetworkDefinition::setWeightsName() to name weights at build time; the ONNX parser uses this API to associate the weights with the names used in the ONNX model. Otherwise, TensorRT names the weights internally based on the related layer names and weight roles.
You can also pass GPU weights to the refitter using:
Weights newBiasWeights = ...;
refitter->setNamedWeights("Conv Layer Bias Weight", newBiasWeights, TensorLocation::kDEVICE);
Because of how the engine is optimized, if you change some weights, you might have to supply some other weights too. The interface can tell you which additional weights must be supplied.
This typically requires two calls to IRefitter::getMissingWeights: the first to get the number of weights objects that must be supplied, and the second to get their names.
int32_t const n = refitter->getMissingWeights(0, nullptr);
std::vector<char const*> weightsNames(n);
refitter->getMissingWeights(n, weightsNames.data());
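The two-call sequence is the usual C-style idiom of asking for the size first and then filling a caller-allocated buffer. The plain-Python sketch below models that control flow. get_missing_weights here is a hypothetical function written for illustration, not the TensorRT binding, and the weights names are made up.

```python
def get_missing_weights(missing, size, out):
    """Hypothetical C-style query (not the TensorRT binding).

    Returns the total number of missing weights names and copies at most
    `size` of them into the caller-allocated list `out` (if given).
    """
    if out is not None:
        for i, name in enumerate(missing[:size]):
            out[i] = name
    return len(missing)


missing = ["conv1 kernel", "conv1 bias"]  # hypothetical weights names

# First call: pass no buffer, only learn how many names there are.
n = get_missing_weights(missing, 0, None)

# Second call: allocate exactly n slots and let the callee fill them.
names = [None] * n
get_missing_weights(missing, n, names)
print(n, names)
```

Sizing the buffer from the first call means the second call never writes past the end of the caller's allocation. This matters in the C++ API, where the buffer is a raw array of pointers.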
You can supply the missing weights in any order:
for (int32_t i = 0; i < n; ++i)
    refitter->setNamedWeights(weightsNames[i], Weights{...});
The set of missing weights returned is complete: supplying only those weights never creates a need for further weights.
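One way to see why the list is complete is to assume the builder fused some weights into disjoint groups, for example a convolution kernel and bias folded together with a following scale. This is an illustrative model, not TensorRT's internal algorithm. Under that model, changing any member of a group requires resupplying the whole group. Because groups do not overlap, supplying the reported names cannot pull in any further group. The group contents below are hypothetical.

```python
def missing_weights(groups, supplied):
    """Illustrative model: `groups` is a list of disjoint sets of weights names
    that were fused together; every member of a touched group must be resupplied."""
    needed = set()
    for group in groups:
        if group & supplied:
            needed |= group
    return needed - supplied


# Hypothetical fusion: a conv kernel/bias folded with a following scale layer.
groups = [{"conv.kernel", "conv.bias", "scale.shift"}, {"fc.kernel"}]

first = missing_weights(groups, {"conv.kernel"})
print(sorted(first))

# Supplying exactly the reported names leaves nothing else missing.
print(missing_weights(groups, {"conv.kernel"} | first))
```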
After all the weights have been provided, you can update the engine:
bool success = refitter->refitCudaEngine();
assert(success);
If refitCudaEngine returns false, check the log for a diagnostic; the problem may be weights that are still missing. There is also an asynchronous version, refitCudaEngineAsync, which takes a CUDA stream parameter.
You can also update the weights memory directly and then call refitCudaEngine or refitCudaEngineAsync again in a later iteration. If the weights pointers need to change, call setNamedWeights to override the previous setting. Call unsetNamedWeights to unset previously set weights so that they are not used in later refits; after that, it is safe to release those weights.
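The following plain-Python sketch models the lifecycle just described. The refitter keeps a reference to caller-owned memory rather than a copy, reads that memory's current contents on each refit, and stops reading it once the name is unset. BufferRefitter is a hypothetical class written for illustration. TensorRT's refitter works with raw host or device pointers rather than Python lists.

```python
class BufferRefitter:
    """Hypothetical stand-in that keeps a reference to caller-owned weights
    memory instead of copying it. NOT the TensorRT API."""

    def __init__(self):
        self.bound = {}   # name -> caller-owned buffer (referenced, not copied)
        self.engine = {}  # name -> values baked into the toy "engine"

    def set_named_weights(self, name, buffer):
        self.bound[name] = buffer  # overrides any earlier binding for this name

    def unset_named_weights(self, name):
        self.bound.pop(name, None)  # buffer is no longer read; safe to free

    def refit(self):
        for name, buffer in self.bound.items():
            self.engine[name] = list(buffer)  # read the buffer's current contents
        return True


weights = [1.0, 2.0]
r = BufferRefitter()
r.set_named_weights("fc.kernel", weights)
r.refit()

weights[0] = 10.0  # update the memory in place, without calling set_named_weights again
r.refit()
print(r.engine["fc.kernel"])

r.unset_named_weights("fc.kernel")
```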
After refitting is done, you can then delete the refitter:
delete refitter;
The engine behaves as if it had been built from a network updated with the new weights. After refitting, any previously created execution context can continue to be used.
To view all refittable weights in an engine, use refitter->getAllWeights(...), which is similar to how getMissingWeights was used above.
Weight-Stripping
When refit is enabled, all the constant weights in the network can be updated after the engine is built. However, this flexibility has a potential runtime cost: because refittable weights cannot be constant-folded, the builder cannot perform some optimizations.
This cost is unavoidable when you do not know the refit weights at build time. However, in some scenarios, you know the weights in advance. For example, you might use TensorRT as one of multiple back ends to execute an ONNX model and want to avoid storing an additional copy of weights in the TensorRT plan.
The weight-stripping build configuration enables this scenario. When it is enabled, TensorRT enables refit only for constant weights that do not affect the builder’s ability to optimize, producing an engine with the same runtime performance as a non-refittable engine. Those weights are then omitted from the serialized engine, resulting in a small plan file that can be refitted at runtime using the weights from the ONNX model.
The trtexec tool provides the --stripWeights flag for building a weight-stripped engine. For more information, refer to the trtexec section.
The following steps show how to refit the weights for weight-stripped engines. When working with ONNX models, the ONNX parser library can perform the refit automatically. For more information, refer to the Refitting a Weight-Stripped Engine Directly from ONNX section.
Set the corresponding builder flag to enable the weight-stripped build. The kSTRIP_PLAN flag works with either kREFIT or kREFIT_IDENTICAL and defaults to the latter. The kREFIT_IDENTICAL flag instructs the TensorRT builder to optimize under the assumption that the engine will be refitted with weights identical to those provided at build time. The kSTRIP_PLAN flag minimizes plan size by stripping out the refittable weights.
...
config->setFlag(BuilderFlag::kSTRIP_PLAN);
config->setFlag(BuilderFlag::kREFIT_IDENTICAL);
builder->buildSerializedNetwork(network, config);

config.flags |= 1 << int(trt.BuilderFlag.STRIP_PLAN)
config.flags |= 1 << int(trt.BuilderFlag.REFIT_IDENTICAL)
builder.build_serialized_network(network, config)
After the engine is built, save the plan file and include it in your application's installer.
On the client side, when you launch the network for the first time, update all the weights in the engine. Because all the weights were removed from the engine plan, use the getAllWeights API to list them.
int32_t const n = refitter->getAllWeights(0, nullptr);
std::vector<char const*> weightsNames(n);
refitter->getAllWeights(n, weightsNames.data());

all_weights = refitter.get_all_weights()
Update the weights one by one.
for (int32_t i = 0; i < n; ++i)
    refitter->setNamedWeights(weightsNames[i], Weights{...});

# weights maps each weights name to its new values
for name in all_weights:
    refitter.set_named_weights(name, weights[name])
Save the full engine plan file.
auto serializationConfig = SampleUniquePtr<nvinfer1::ISerializationConfig>(cudaEngine->createSerializationConfig());
auto serializationFlag = serializationConfig->getFlags();
serializationFlag &= ~(1 << static_cast<uint32_t>(nvinfer1::SerializationFlag::kEXCLUDE_WEIGHTS));
serializationConfig->setFlags(serializationFlag);
auto hostMemory = SampleUniquePtr<nvinfer1::IHostMemory>(cudaEngine->serializeWithConfig(*serializationConfig));

serialization_config = engine.create_serialization_config()
serialization_config.flags &= ~(1 << int(trt.SerializationFlag.EXCLUDE_WEIGHTS))
binary = engine.serialize_with_config(serialization_config)
The application can now use the new full engine plan file for future inference. Note that when a weight-stripped engine is serialized without the kEXCLUDE_WEIGHTS flag, the resulting serialized engine is not refittable by default. Setting the kINCLUDE_REFIT flag keeps the serialized engine refittable.
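The Python snippets in these steps treat flags as integer bitmasks. Each enum value is used as a bit position, so 1 << int(flag) selects that flag's bit. |= turns the bit on, and &= ~(...) turns it off while leaving every other bit alone. The sketch below reproduces that arithmetic with two hypothetical IntEnum classes. Their values are made up for illustration and are not TensorRT's actual BuilderFlag or SerializationFlag values.

```python
from enum import IntEnum


class ToyBuilderFlag(IntEnum):
    # Hypothetical bit positions; NOT TensorRT's real BuilderFlag values.
    STRIP_PLAN = 0
    REFIT_IDENTICAL = 1


class ToySerializationFlag(IntEnum):
    # Hypothetical bit positions; NOT TensorRT's real SerializationFlag values.
    EXCLUDE_WEIGHTS = 0
    INCLUDE_REFIT = 1


def set_flag(flags, flag):
    # Turn on the bit at position int(flag); all other bits are kept.
    return flags | (1 << int(flag))


def clear_flag(flags, flag):
    # Turn off the bit at position int(flag) by AND-ing with the inverted mask.
    return flags & ~(1 << int(flag))


def has_flag(flags, flag):
    return bool(flags & (1 << int(flag)))


# Build side: same arithmetic as the two `config.flags |= 1 << int(...)` lines.
builder_flags = 0
builder_flags = set_flag(builder_flags, ToyBuilderFlag.STRIP_PLAN)
builder_flags = set_flag(builder_flags, ToyBuilderFlag.REFIT_IDENTICAL)

# Serialization side: suppose EXCLUDE_WEIGHTS starts out set; clear it so the
# weights are written, and set INCLUDE_REFIT so the plan stays refittable.
serialization_flags = set_flag(0, ToySerializationFlag.EXCLUDE_WEIGHTS)
serialization_flags = clear_flag(serialization_flags, ToySerializationFlag.EXCLUDE_WEIGHTS)
serialization_flags = set_flag(serialization_flags, ToySerializationFlag.INCLUDE_REFIT)

print(bin(builder_flags), bin(serialization_flags))
```

Using |= and &= ~ rather than plain assignment preserves flags that are already set. For builder flags, the Python config object also offers set_flag, as used in the fine-grained refit examples.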
Refitting a Weight-Stripped Engine Directly from ONNX
When working with weight-stripped engines created from ONNX models, the refit process can be done automatically with the IParserRefitter class from the ONNX parser library. The following steps show how to create the class and run the refit process.
Create your engine as described in Weight-Stripping, and create an IRefitter object.
IRefitter* refitter = createInferRefitter(*engine, gLogger);

refitter = trt.Refitter(engine, TRT_LOGGER)
Create an IParserRefitter object.
IParserRefitter* parserRefitter = createParserRefitter(*refitter, gLogger);

parser_refitter = trt.OnnxParserRefitter(refitter, TRT_LOGGER)
Call the refitFromFile() function of the IParserRefitter. Ensure that the ONNX model is identical to the one used to create the weight-stripped engine. This function returns true if all the stripped weights are found in the ONNX model; otherwise, it returns false.
bool result = parserRefitter->refitFromFile("path_to_onnx_model");

result = parser_refitter.refit_from_file("path_to_onnx_model")
Call the refitCudaEngine() function of the IRefitter to complete the refit process.
bool refitSuccess = refitter->refitCudaEngine();

refit_success = refitter.refit_cuda_engine()
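refitFromFile() reports success only when every stripped weight is found in the ONNX model. If you supply weights from another source, you can run a similar completeness check before refitting. The sketch below is plain Python. plan_refit and the name-to-values dictionary are hypothetical: they stand in for the names a refitter reports (for example, through getAllWeights) and for weights loaded from your own storage.

```python
def plan_refit(all_names, available):
    """Pair each refittable name with new values and list names with no source.

    `all_names` stands in for the names a refitter reports; `available` is a
    hypothetical name -> values mapping.
    """
    to_set = {name: available[name] for name in all_names if name in available}
    missing = [name for name in all_names if name not in available]
    return to_set, missing


all_names = ["conv1.weight", "conv1.bias", "fc.weight"]      # hypothetical names
available = {"conv1.weight": [0.1] * 9, "conv1.bias": [0.0]}  # fc.weight has no source

to_set, missing = plan_refit(all_names, available)
if missing:
    print("cannot refit, missing:", missing)
else:
    print("all", len(to_set), "weights available")
```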
Weight-Stripping with the Lean Runtime
You can also use the lean runtime to further reduce the package size of a weight-stripped engine. The lean runtime is the same runtime used by version-compatible engines. Its original purpose is to let you generate a TensorRT engine with version X and load it with an application built with version Y. The lean runtime library is relatively small, approximately 40 MiB. Therefore, software vendors building on top of TensorRT only need to ship the weightless engine along with the 40 MiB lean runtime when the weights are already available on the target customer machine.
The recommended approach to build the engine is as follows:
builderConfig.setFlag(BuilderFlag::kVERSION_COMPATIBLE);
builderConfig.setFlag(BuilderFlag::kEXCLUDE_LEAN_RUNTIME);
builderConfig.setFlag(BuilderFlag::kSTRIP_PLAN);
IHostMemory* plan = builder->buildSerializedNetwork(network, builderConfig);

builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
builder_config.set_flag(trt.BuilderFlag.EXCLUDE_LEAN_RUNTIME)
builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN)

plan = builder.build_serialized_network(network, builder_config)
Load the lean runtime from its shared library path, and use the returned runtime to deserialize the engine:
IRuntime* leanRuntime = runtime->loadRuntime("your_lean_runtime_full_path");

lean_runtime = runtime.load_runtime("your_lean_runtime_full_path")
For more information about the lean runtime, refer to the Version Compatibility section.
Fine-Grained Refit Build
When you use the kREFIT builder flag, all weights are marked as refittable. This is useful when you cannot easily distinguish between trainable and untrainable weights. However, marking all weights as refittable can reduce performance, because some optimizations are not possible when weights are refittable.
For example, with the GELU expression, TensorRT can encode all GELU coefficients in a single CUDA kernel. However, if all coefficients are marked as refittable, TensorRT can no longer fuse the Conv-GELU operations into a single kernel.
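For context, the widely used tanh approximation of GELU is 0.5 x (1 + tanh(sqrt(2/pi) (x + 0.044715 x^3))). When a network expresses it with elementwise layers, constants such as sqrt(2/pi) and 0.044715 appear as constant weights, and these are the kind of coefficients meant here. It is an assumption that a given model uses this form rather than the exact erf form. The sketch below evaluates both in plain Python.

```python
import math

# Constants of the tanh-based GELU approximation. In a network built from
# elementwise ops they appear as constant weights; if they are marked
# refittable, they cannot be folded into a fused kernel.
SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)
CUBIC_COEFF = 0.044715


def gelu_tanh(x):
    inner = SQRT_2_OVER_PI * (x + CUBIC_COEFF * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


def gelu_erf(x):
    # Exact GELU, for comparison.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


for x in (-2.0, 0.0, 1.0):
    print(x, round(gelu_tanh(x), 4), round(gelu_erf(x), 4))
```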
To address this, TensorRT provides the fine-grained refit API. This API gives you precise control over which weights are marked as refittable, allowing for more efficient optimization.
Here is an example of marking weights as refittable in the INetworkDefinition:
...
// weights must be the Weights object already used by the conv1 layer
network->setWeightsName(weights, "conv1_filter");
network->markWeightsRefittable("conv1_filter");
assert(network->areWeightsMarkedRefittable("conv1_filter"));

...
network.set_weights_name(conv_filter, "conv1_filter")
network.mark_weights_refittable("conv1_filter")
assert network.are_weights_marked_refittable("conv1_filter")
Then, set the fine-grained refit flag in the builder configuration:
...
config->setFlag(BuilderFlag::kREFIT_INDIVIDUAL);
builder->buildSerializedNetwork(network, config);

...
config.set_flag(trt.BuilderFlag.REFIT_INDIVIDUAL)
builder.build_serialized_network(network, config)
The remaining refit code follows the same steps as the workflow for refitting all weights.
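The difference between the two build modes can be summarized with a small plain-Python model. With kREFIT, every weight is refittable. With kREFIT_INDIVIDUAL, only the weights marked through markWeightsRefittable are. This is a simplification for illustration, and the weights names are hypothetical. It does not reproduce how the builder actually records refittability.

```python
def refittable_weights(all_names, marked, mode):
    """Illustrative model of which weights end up refittable.

    mode "REFIT": every weight; mode "REFIT_INDIVIDUAL": only the marked ones.
    """
    if mode == "REFIT":
        return set(all_names)
    if mode == "REFIT_INDIVIDUAL":
        return set(marked) & set(all_names)
    raise ValueError(f"unknown mode: {mode}")


names = ["conv1_filter", "conv1_bias", "gelu_coeff"]  # hypothetical names
marked = {"conv1_filter"}

print(sorted(refittable_weights(names, marked, "REFIT")))
print(sorted(refittable_weights(names, marked, "REFIT_INDIVIDUAL")))
```

In this model, the unmarked names (here, conv1_bias and gelu_coeff) stay ordinary constants that the builder is free to fold. That freedom is what allows fusions such as the Conv-GELU example above.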
Stripping Weights with Fine-Grained Refit Build
The fine-grained refit build also works with the weight-stripping flag. To use them together, enable both builder flags after marking the necessary weights as refittable.
Here is an example:
...
config->setFlag(BuilderFlag::kSTRIP_PLAN);
config->setFlag(BuilderFlag::kREFIT_INDIVIDUAL);
builder->buildSerializedNetwork(network, config);

config.flags |= 1 << int(trt.BuilderFlag.STRIP_PLAN)
config.flags |= 1 << int(trt.BuilderFlag.REFIT_INDIVIDUAL)
builder.build_serialized_network(network, config)
The remaining refit and inference code is the same as in the Weight-Stripping section.