TensorRT  7.1.3.0
nvinfer1::IPluginV2DynamicExt Class Reference [abstract]

#include <NvInferRuntime.h>

Inheritance diagram for nvinfer1::IPluginV2DynamicExt:
nvinfer1::IPluginV2DynamicExt → nvinfer1::IPluginV2Ext → nvinfer1::IPluginV2

Public Member Functions

virtual IPluginV2DynamicExt * clone () const _TENSORRT_OVERRIDE=0
 Clone the plugin object. This copies over internal plugin parameters as well and returns a new plugin object with these parameters. If the source plugin is pre-configured with configurePlugin(), the returned object should also be pre-configured. The returned object should allow attachToContext() with a new execution context. Cloned plugin objects can share the same per-engine immutable resource (e.g. weights) with the source object (e.g. via ref-counting) to avoid duplication.
 
virtual DimsExprs getOutputDimensions (int outputIndex, const DimsExprs *inputs, int nbInputs, IExprBuilder &exprBuilder)=0
 Get expressions for computing dimensions of an output tensor from dimensions of the input tensors. More...
 
virtual bool supportsFormatCombination (int pos, const PluginTensorDesc *inOut, int nbInputs, int nbOutputs)=0
 Return true if plugin supports the format and datatype for the input/output indexed by pos. More...
 
virtual void configurePlugin (const DynamicPluginTensorDesc *in, int nbInputs, const DynamicPluginTensorDesc *out, int nbOutputs)=0
 Configure the layer. More...
 
virtual size_t getWorkspaceSize (const PluginTensorDesc *inputs, int nbInputs, const PluginTensorDesc *outputs, int nbOutputs) const =0
 Find the workspace size required by the layer. More...
 
virtual int enqueue (const PluginTensorDesc *inputDesc, const PluginTensorDesc *outputDesc, const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream)=0
 Execute the layer. More...
 
- Public Member Functions inherited from nvinfer1::IPluginV2Ext
virtual nvinfer1::DataType getOutputDataType (int index, const nvinfer1::DataType *inputTypes, int nbInputs) const =0
 Return the DataType of the plugin output at the requested index. The default behavior should be to return the type of the first input, or DataType::kFLOAT if the layer has no inputs. The returned data type must have a format that is supported by the plugin. More...
 
virtual void attachToContext (cudnnContext *, cublasContext *, IGpuAllocator *)
 Attach the plugin object to an execution context and grant the plugin the access to some context resource. More...
 
virtual void detachFromContext ()
 Detach the plugin object from its execution context. More...
 
- Public Member Functions inherited from nvinfer1::IPluginV2
virtual const char * getPluginType () const =0
 Return the plugin type. Should match the plugin name returned by the corresponding plugin creator.
 
virtual const char * getPluginVersion () const =0
 Return the plugin version. Should match the plugin version returned by the corresponding plugin creator.
 
virtual int getNbOutputs () const =0
 Get the number of outputs from the layer. More...
 
virtual int initialize ()=0
 Initialize the layer for execution. This is called when the engine is created. More...
 
virtual void terminate ()=0
 Release resources acquired during plugin layer initialization. This is called when the engine is destroyed. More...
 
virtual size_t getSerializationSize () const =0
 Find the size of the serialization buffer required. More...
 
virtual void serialize (void *buffer) const =0
 Serialize the layer. More...
 
virtual void destroy ()=0
 Destroy the plugin object. This will be called when the network, builder or engine is destroyed.
 
virtual void setPluginNamespace (const char *pluginNamespace)=0
 Set the namespace that this plugin object belongs to. Ideally, all plugin objects from the same plugin library should have the same namespace.
 
virtual const char * getPluginNamespace () const =0
 Return the namespace of the plugin object.
 

Static Public Attributes

static constexpr int kFORMAT_COMBINATION_LIMIT = 100
 

Protected Member Functions

int getTensorRTVersion () const _TENSORRT_OVERRIDE
 Return the API version with which this plugin was built. More...
 
TRT_DEPRECATED Dims getOutputDimensions (int, const Dims *, int) _TENSORRT_FINAL
 Derived classes should not implement this. In a C++11 API it would be override final. More...
 
TRT_DEPRECATED bool isOutputBroadcastAcrossBatch (int, const bool *, int) const _TENSORRT_FINAL
 Derived classes should not implement this. In a C++11 API it would be override final. More...
 
TRT_DEPRECATED bool canBroadcastInputAcrossBatch (int) const _TENSORRT_FINAL
 Derived classes should not implement this. In a C++11 API it would be override final. More...
 
TRT_DEPRECATED bool supportsFormat (DataType, PluginFormat) const _TENSORRT_FINAL
 Derived classes should not implement this. In a C++11 API it would be override final. More...
 
TRT_DEPRECATED void configurePlugin (const Dims *, int, const Dims *, int, const DataType *, const DataType *, const bool *, const bool *, PluginFormat, int) _TENSORRT_FINAL
 Derived classes should not implement this. In a C++11 API it would be override final. More...
 
TRT_DEPRECATED size_t getWorkspaceSize (int) const _TENSORRT_FINAL
 Derived classes should not implement this. In a C++11 API it would be override final. More...
 
TRT_DEPRECATED int enqueue (int, const void *const *, void **, void *, cudaStream_t) _TENSORRT_FINAL
 Derived classes should not implement this. In a C++11 API it would be override final. More...
 
- Protected Member Functions inherited from nvinfer1::IPluginV2Ext
int getTensorRTVersion () const _TENSORRT_OVERRIDE
 Return the API version with which this plugin was built. The upper byte is reserved by TensorRT and is used to differentiate this from IPluginV2. More...
 
void configureWithFormat (const Dims *, int, const Dims *, int, DataType, PluginFormat, int) _TENSORRT_OVERRIDE
 Derived classes should not implement this. In a C++11 API it would be override final.
 

Detailed Description

Similar to IPluginV2Ext, but with support for dynamic shapes.

Clients should override the public methods, including the following inherited methods:

virtual int getNbOutputs() const TRTNOEXCEPT = 0;
virtual nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRTNOEXCEPT = 0;
virtual size_t getSerializationSize() const TRTNOEXCEPT = 0;
virtual void serialize(void* buffer) const TRTNOEXCEPT = 0;
virtual void destroy() TRTNOEXCEPT = 0;
virtual void setPluginNamespace(const char* pluginNamespace) TRTNOEXCEPT = 0;
virtual const char* getPluginNamespace() const TRTNOEXCEPT = 0;

For getOutputDataType, the inputTypes will always be DataType::kFLOAT or DataType::kINT32, and the returned type is canonicalized to DataType::kFLOAT if it is DataType::kHALF or DataType::kINT8. Details about the floating-point precision are elicited later by the supportsFormatCombination method.
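
As a rough illustration of this canonicalization rule, the following sketch uses a stand-in DataType enum rather than the real TensorRT headers; canonicalize and getOutputDataTypeSketch are hypothetical names, not part of the API:

```cpp
#include <cassert>

// Stand-in for nvinfer1::DataType (illustration only, not the real header).
enum class DataType { kFLOAT, kHALF, kINT8, kINT32 };

// Sketch of the canonicalization described above: a plugin may return kHALF
// or kINT8 from getOutputDataType, but the type is treated as kFLOAT at this
// stage; the exact precision is chosen later via supportsFormatCombination.
DataType canonicalize(DataType t)
{
    if (t == DataType::kHALF || t == DataType::kINT8)
        return DataType::kFLOAT;
    return t;
}

// Typical plugin behavior: echo the type of the first input,
// or kFLOAT if the layer has no inputs.
DataType getOutputDataTypeSketch(int /*index*/, const DataType* inputTypes, int nbInputs)
{
    return nbInputs > 0 ? inputTypes[0] : DataType::kFLOAT;
}
```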

Member Function Documentation

◆ canBroadcastInputAcrossBatch()

TRT_DEPRECATED bool nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch ( int  ) const
inline protected virtual

Derived classes should not implement this. In a C++11 API it would be override final.

With dynamic shapes, there is no implicit batch dimension to broadcast across.

Implements nvinfer1::IPluginV2Ext.

◆ configurePlugin() [1/2]

TRT_DEPRECATED void nvinfer1::IPluginV2DynamicExt::configurePlugin ( const Dims *  ,
int  ,
const Dims *  ,
int  ,
const DataType *  ,
const DataType *  ,
const bool *  ,
const bool *  ,
PluginFormat  ,
int   
)
inline protected virtual

Derived classes should not implement this. In a C++11 API it would be override final.

This method is not used because tensors with dynamic shapes do not have an implicit batch dimension, input dimensions might be variable, and outputs might have different floating-point formats.

Instead, derived classes should override the overload of configurePlugin that takes pointers to DynamicPluginTensorDesc.

Implements nvinfer1::IPluginV2Ext.

◆ configurePlugin() [2/2]

virtual void nvinfer1::IPluginV2DynamicExt::configurePlugin ( const DynamicPluginTensorDesc *  in,
int  nbInputs,
const DynamicPluginTensorDesc *  out,
int  nbOutputs 
)
pure virtual

Configure the layer.

This function is called by the builder prior to initialize(). It provides an opportunity for the layer to make algorithm choices on the basis of bounds on the input and output tensors, and the target value.

This function is also called once when the resource requirements are changed based on the optimization profiles.

Parameters
    in  The input tensor attributes that are used for configuration.
    nbInputs  Number of input tensors.
    out  The output tensor attributes that are used for configuration.
    nbOutputs  Number of output tensors.
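
A sketch of a typical override follows, using simplified stand-ins for Dims, PluginTensorDesc, and DynamicPluginTensorDesc (the field names desc, min, and max follow the real structs, but these are not the actual headers, and MyPluginState is hypothetical): the plugin records the worst-case input volume implied by the optimization-profile bounds so later calls can size buffers.

```cpp
#include <cassert>
#include <cstdint>

// Simplified stand-ins for the TensorRT structs (illustration only).
struct Dims { int nbDims; int d[8]; };
struct PluginTensorDesc { Dims dims; };
struct DynamicPluginTensorDesc
{
    PluginTensorDesc desc; // dims may contain -1 for runtime dimensions
    Dims min;              // lower bounds from the optimization profile
    Dims max;              // upper bounds from the optimization profile
};

static int64_t volume(const Dims& d)
{
    int64_t v = 1;
    for (int i = 0; i < d.nbDims; ++i) v *= d.d[i];
    return v;
}

// Hypothetical plugin state: cache the worst-case element count of input 0
// so that getWorkspaceSize()/initialize() can allocate for the largest
// shape the optimization profile allows.
struct MyPluginState { int64_t maxInputVolume = 0; };

void configurePluginSketch(MyPluginState& s, const DynamicPluginTensorDesc* in, int nbInputs)
{
    if (nbInputs > 0)
        s.maxInputVolume = volume(in[0].max);
}
```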

◆ enqueue() [1/2]

virtual int nvinfer1::IPluginV2DynamicExt::enqueue ( const PluginTensorDesc *  inputDesc,
const PluginTensorDesc *  outputDesc,
const void *const *  inputs,
void *const *  outputs,
void *  workspace,
cudaStream_t  stream 
)
pure virtual

Execute the layer.

Parameters
    inputDesc  How to interpret the memory for the input tensors.
    outputDesc  How to interpret the memory for the output tensors.
    inputs  The memory for the input tensors.
    outputs  The memory for the output tensors.
    workspace  Workspace for execution.
    stream  The stream in which to execute the kernels.
Returns
0 for success, else non-zero (which will cause engine termination).
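
Unlike the deprecated implicit-batch enqueue, the actual runtime shapes arrive through inputDesc, so sizes must be computed per invocation. A minimal CPU sketch of that pattern for an identity plugin, with simplified stand-in structs in place of the real headers and a plain memcpy standing in for the CUDA kernel a real plugin would launch on stream:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Stand-ins for the real TensorRT structs (illustration only).
struct Dims { int nbDims; int d[8]; };
struct PluginTensorDesc { Dims dims; };

static int64_t volume(const Dims& d)
{
    int64_t v = 1;
    for (int i = 0; i < d.nbDims; ++i) v *= d.d[i];
    return v;
}

// Sketch of an identity plugin's enqueue: the element count is derived from
// the *runtime* dimensions in inputDesc, not from values cached at build
// time. A real plugin would launch a CUDA kernel here; this CPU sketch
// just copies the input to the output.
int enqueueSketch(const PluginTensorDesc* inputDesc, const void* const* inputs,
                  void* const* outputs)
{
    int64_t const count = volume(inputDesc[0].dims);
    std::memcpy(outputs[0], inputs[0], count * sizeof(float));
    return 0; // 0 for success; non-zero terminates the engine
}
```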

◆ enqueue() [2/2]

TRT_DEPRECATED int nvinfer1::IPluginV2DynamicExt::enqueue ( int  ,
const void *const *  ,
void **  ,
void *  ,
cudaStream_t   
)
inline protected virtual

Derived classes should not implement this. In a C++11 API it would be override final.

This method is not used because tensors with dynamic shapes can have different sizes in different execution contexts.

Instead, derived classes should override the overload of enqueue that takes pointers to PluginTensorDesc.

Implements nvinfer1::IPluginV2.

◆ getOutputDimensions() [1/2]

virtual DimsExprs nvinfer1::IPluginV2DynamicExt::getOutputDimensions ( int  outputIndex,
const DimsExprs *  inputs,
int  nbInputs,
IExprBuilder &  exprBuilder 
)
pure virtual

Get expressions for computing dimensions of an output tensor from dimensions of the input tensors.

Parameters
    outputIndex  The index of the output tensor.
    inputs  Expressions for dimensions of the input tensors.
    nbInputs  The number of input tensors.
    exprBuilder  Object for generating new expressions.

This function is called by the implementations of IBuilder during analysis of the network.

Example #1: A plugin has a single output that transposes the last two dimensions of the plugin's single input. The body of the override of getOutputDimensions can be:

DimsExprs output(inputs[0]);
std::swap(output.d[output.nbDims-1], output.d[output.nbDims-2]);
return output;

Example #2: A plugin concatenates its two inputs along the first dimension. The body of the override of getOutputDimensions can be:

DimsExprs output(inputs[0]);
output.d[0] = exprBuilder.operation(DimensionOperation::kSUM, inputs[0].d[0], inputs[1].d[0]);
return output;
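
Both example bodies can be exercised with a toy constant-folding stand-in for the expression machinery. In the real API, DimsExprs holds IDimensionExpr pointers and IExprBuilder builds symbolic expressions; in this hypothetical sketch each "expression" is a plain int so the results can be checked directly:

```cpp
#include <algorithm>
#include <cassert>

// Toy stand-in: each dimension "expression" is just an int.
struct DimsExprs { int nbDims; int d[8]; };

struct ExprBuilderSketch
{
    // Stand-in for exprBuilder.operation(DimensionOperation::kSUM, a, b).
    int sum(int a, int b) const { return a + b; }
};

// Example #1: transpose the last two dimensions of the single input.
DimsExprs transposeLastTwo(const DimsExprs* inputs)
{
    DimsExprs output(inputs[0]);
    std::swap(output.d[output.nbDims - 1], output.d[output.nbDims - 2]);
    return output;
}

// Example #2: concatenate two inputs along the first dimension.
DimsExprs concatFirstDim(const DimsExprs* inputs, const ExprBuilderSketch& b)
{
    DimsExprs output(inputs[0]);
    output.d[0] = b.sum(inputs[0].d[0], inputs[1].d[0]);
    return output;
}
```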

◆ getOutputDimensions() [2/2]

TRT_DEPRECATED Dims nvinfer1::IPluginV2DynamicExt::getOutputDimensions ( int  ,
const Dims *  ,
int   
)
inline protected virtual

Derived classes should not implement this. In a C++11 API it would be override final.

Instead, derived classes should override the overload of getOutputDimensions that returns DimsExprs.

Implements nvinfer1::IPluginV2.

◆ getTensorRTVersion()

int nvinfer1::IPluginV2DynamicExt::getTensorRTVersion ( ) const
inline protected virtual

Return the API version with which this plugin was built.

Do not override this method as it is used by the TensorRT library to maintain backwards-compatibility with plugins.

Reimplemented from nvinfer1::IPluginV2.

◆ getWorkspaceSize() [1/2]

virtual size_t nvinfer1::IPluginV2DynamicExt::getWorkspaceSize ( const PluginTensorDesc *  inputs,
int  nbInputs,
const PluginTensorDesc *  outputs,
int  nbOutputs 
) const
pure virtual

Find the workspace size required by the layer.

This function is called after the plugin is configured, and possibly during execution. The result should be a sufficient workspace size to deal with inputs and outputs of the given size or any smaller problem.

Returns
The workspace size.

◆ getWorkspaceSize() [2/2]

TRT_DEPRECATED size_t nvinfer1::IPluginV2DynamicExt::getWorkspaceSize ( int  ) const
inline protected virtual

Derived classes should not implement this. In a C++11 API it would be override final.

This method is not used because tensors with dynamic shapes do not have an implicit batch dimension, and the other dimensions might not be build-time constants.

Instead, derived classes should override the overload of getWorkspaceSize that takes pointers to PluginTensorDesc. The arguments to that overload provide maximum bounds on all dimensions.

Implements nvinfer1::IPluginV2.

◆ isOutputBroadcastAcrossBatch()

TRT_DEPRECATED bool nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch ( int  ,
const bool *  ,
int   
) const
inline protected virtual

Derived classes should not implement this. In a C++11 API it would be override final.

With dynamic shapes, there is no implicit batch dimension to broadcast across.

Implements nvinfer1::IPluginV2Ext.

◆ supportsFormat()

TRT_DEPRECATED bool nvinfer1::IPluginV2DynamicExt::supportsFormat ( DataType  ,
PluginFormat   
) const
inline protected virtual

Derived classes should not implement this. In a C++11 API it would be override final.

This method is not used because it does not allow a plugin to specify mixed formats.

Instead, derived classes should override supportsFormatCombination, which allows plugins to express mixed formats.

Implements nvinfer1::IPluginV2.

◆ supportsFormatCombination()

virtual bool nvinfer1::IPluginV2DynamicExt::supportsFormatCombination ( int  pos,
const PluginTensorDesc *  inOut,
int  nbInputs,
int  nbOutputs 
)
pure virtual

Return true if plugin supports the format and datatype for the input/output indexed by pos.

For this method, inputs are numbered 0..(nbInputs-1) and outputs are numbered nbInputs..(nbInputs+nbOutputs-1). Using this numbering, pos is an index into inOut, where 0 <= pos < nbInputs+nbOutputs.

TensorRT invokes this method to ask if the input/output indexed by pos supports the format/datatype specified by inOut[pos].format and inOut[pos].type. The override should return true if the format/datatype at inOut[pos] is supported by the plugin. If support is conditional on other input/output formats/datatypes, the plugin can make its result conditional on the formats/datatypes in inOut[0..pos-1], which will be set to values that the plugin supports. The override should not inspect inOut[pos+1..nbInputs+nbOutputs-1], which will have invalid values. In other words, the decision for pos must be based on inOut[0..pos] only.

Some examples:

  • A definition for a plugin that supports only FP16 NCHW:
      return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kHALF;
    
  • A definition for a plugin that supports only FP16 NCHW for its two inputs, and FP32 NCHW for its single output:
      return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == (pos < 2 ? DataType::kHALF : DataType::kFLOAT);
    
  • A definition for a "polymorphic" plugin with two inputs and one output that supports any format or type, but the inputs and output must have the same format and type:
      return pos == 0 || (inOut[pos].format == inOut[0].format && inOut[pos].type == inOut[0].type);
    

Warning: TensorRT will stop asking for formats once it finds kFORMAT_COMBINATION_LIMIT combinations.
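
The "polymorphic" example can be exercised directly with simplified stand-in types (not the real headers; the function name is hypothetical). Note that the decision at each pos inspects only inOut[0..pos], as the contract requires:

```cpp
#include <cassert>

// Stand-ins for the real TensorRT types (illustration only).
enum class DataType { kFLOAT, kHALF, kINT8, kINT32 };
enum class TensorFormat { kLINEAR, kCHW4 };
struct PluginTensorDesc { TensorFormat format; DataType type; };

// The "polymorphic" example: two inputs, one output, any format/type,
// but every tensor must match the first input. The result at each pos
// depends only on inOut[0] and inOut[pos].
bool supportsFormatCombinationSketch(int pos, const PluginTensorDesc* inOut,
                                     int /*nbInputs*/, int /*nbOutputs*/)
{
    return pos == 0
        || (inOut[pos].format == inOut[0].format && inOut[pos].type == inOut[0].type);
}
```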

Member Data Documentation

◆ kFORMAT_COMBINATION_LIMIT

constexpr int nvinfer1::IPluginV2DynamicExt::kFORMAT_COMBINATION_LIMIT = 100
static constexpr

Limit on number of format combinations accepted.


The documentation for this class was generated from the following file:
NvInferRuntime.h