TensorRT Plugins#

Custom TensorRT plugins for specialized operations.

Overview#

The TensorRT Plugins module provides custom TensorRT plugins for operations not natively supported by TensorRT:

  • DMRS - Demodulation Reference Signal (DMRS) generation and extraction

  • FFT - Fast Fourier Transform operations

  • Cholesky - Cholesky factorization and inversion

  • Plugin Management - Global plugin registry and management

Usage#

The TensorRT plugins are automatically registered and available for use in MLIR-TRT lowering. See the PUSCH channel estimation lowering tutorial for examples of using these plugins in practice.

Python API#

TensorRT plugin manager for the RAN package.

class ran.trt_plugins.manager.TensorRTPluginManager[source]#

Bases: object

Manager for RAN TensorRT plugins.

__init__() → None[source]#

Initialize the plugin manager.

create_plugin(
plugin_name: str,
fields: dict[str, Any] | None = None,
) → object | None[source]#

Create a plugin instance.

Parameters:
  • plugin_name – Name of the plugin to create (C++ TensorRT plugin name).

  • fields – Dictionary of plugin fields to pass to the plugin creator.

Return type:

Plugin instance if successful, None otherwise.

load_plugin_library() → bool[source]#

Load the TensorRT plugin library from RAN_TRT_PLUGIN_DSO_PATH.

The library path is obtained from the RAN_TRT_PLUGIN_DSO_PATH environment variable via config.get_ran_trt_plugin_dso_path(). When running tests via CMake targets, this environment variable is automatically set by CMake.

Return type:

True if library loaded successfully.

Raises:

ran.trt_plugins.manager.get_ran_trt_plugin_dso_path() → str[source]#

Get the RAN TensorRT plugin DSO path from environment variable.

This function returns the path to libran_trt_plugin.so from the RAN_TRT_PLUGIN_DSO_PATH environment variable. When running tests via CMake targets (py_ran_test, py_ran_wheel_test), this environment variable is automatically set by CMake.

For interactive/manual usage, it attempts to auto-load the path from the .env.python file.

Return type:

Absolute path to libran_trt_plugin.so

Raises:

ran.trt_plugins.manager.global_trt_plugin_manager_create_plugin(
name: str,
fields: dict[str, Any] | None = None,
) → object | None[source]#

Create a TensorRT plugin.

Parameters:
  • name – Name of the plugin to create.

  • fields – Dictionary of plugin fields to pass to the plugin creator.

Return type:

Plugin instance if successful, None otherwise.

ran.trt_plugins.manager.inspect_engine(
engine_path: Path,
verbose: bool = False,
) → None[source]#

Inspect a TensorRT engine file and display information.
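The path-resolution behavior documented for load_plugin_library and get_ran_trt_plugin_dso_path can be sketched in plain Python. This is an illustration, not the package's code; in particular, the docs above do not state which exception type is raised on a missing variable, so RuntimeError here is an assumption:

```python
import os

def resolve_plugin_dso_path() -> str:
    # Sketch of the documented lookup: the DSO path comes from the
    # RAN_TRT_PLUGIN_DSO_PATH environment variable (set automatically by
    # CMake for test targets, or loaded from .env.python interactively).
    path = os.environ.get("RAN_TRT_PLUGIN_DSO_PATH")
    if not path:
        # Assumed exception type; the docs only say the function raises.
        raise RuntimeError("RAN_TRT_PLUGIN_DSO_PATH is not set")
    return os.path.abspath(path)
```

The manager then hands the resolved path to TensorRT's plugin-library loader.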

C++ API Reference#

static constexpr int32_t ran::trt_plugin::DEFAULT_SEQUENCE_LENGTH = 42#

Default DMRS sequence length.

static constexpr int32_t ran::trt_plugin::DEFAULT_N_T = 14#

Default OFDM symbols per slot.
template<typename T, typename Default = std::nullopt_t>
T ran::trt_plugin::get_plugin_field(
const nvinfer1::PluginFieldCollection *fc,
const std::string_view field_name,
const nvinfer1::PluginFieldType expected_type,
Default &&default_value = std::nullopt,
)#

Extracts a field value from a PluginFieldCollection, with an optional default.

This template function provides a clean interface for parsing plugin fields with automatic error logging when required fields are missing.

Usage examples:

// With default value - no error if missing
auto size = get_plugin_field<std::int32_t>(fc, "fft_size", PluginFieldType::kINT32, 128);

// Without default - logs error if missing
auto size = get_plugin_field<std::int32_t>(fc, "fft_size", PluginFieldType::kINT32);

Note

If default_value is std::nullopt and field is not found, logs error and returns T{}

Note

If field data is nullptr or type mismatches, logs error and returns default/T{}

Template Parameters:
  • T – The type of the field value to extract

  • Default – The type of the default value (deduced, typically T or std::nullopt_t)

Parameters:
  • fc[in] Plugin field collection to search

  • field_name[in] Name of the field to find

  • expected_type[in] Expected TensorRT plugin field type

  • default_value[in] Optional default value; if std::nullopt, logs error when field is missing

Returns:

The field value if found, otherwise the default value or default-constructed T
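The lookup-with-default semantics above can be illustrated with a small Python stand-in, where a plain dict replaces the PluginFieldCollection; the names here are illustrative, not the real API:

```python
def get_field(fields, name, expected_type, default=None):
    # Mirrors the documented semantics: the value when present with the
    # expected type; otherwise the caller's default; otherwise a
    # default-constructed value (the analogue of returning T{}),
    # after logging an error.
    value = fields.get(name)
    if value is not None and isinstance(value, expected_type):
        return value
    if default is not None:
        return default
    print(f"error: required plugin field '{name}' missing or wrong type")
    return expected_type()  # analogue of returning T{}
```

For example, `get_field({}, "fft_size", int, 128)` mirrors the "with default" C++ call, while `get_field({}, "fft_size", int)` mirrors the "without default" call that logs an error.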

void ran::trt_plugin::launch_cholesky_factor_inv_kernel(
const float *input_real,
const float *input_imag,
std::int32_t matrix_size,
std::int32_t batch_size,
float *output_real,
float *output_imag,
void *workspace,
cudaStream_t stream,
bool is_complex = false,
)#

CUDA kernel launcher for Cholesky factorization and inversion (batched)

This function launches the CUDA kernel that performs Cholesky decomposition followed by matrix inversion using cuSOLVERDx library for multiple matrices in parallel.

Supports both real and complex data types:

  • For REAL data (is_complex=false):

    • input_real = the real input data

    • input_imag = nullptr (unused)

    • output_real = the real output data

    • output_imag = nullptr (unused)

  • For COMPLEX data (is_complex=true):

    • input_real = real part of complex input

    • input_imag = imaginary part of complex input

    • output_real = real part of complex output

    • output_imag = imaginary part of complex output

TensorRT doesn’t support complex types, so complex data is split into separate real and imaginary arrays at the interface level.

See also

cholesky_factor_inv_kernel for the actual CUDA kernel implementation

Parameters:
  • input_real[in] Real data (if is_complex=false) or real part (if is_complex=true)

  • input_imag[in] Imaginary part (if is_complex=true) or nullptr (if is_complex=false)

  • matrix_size[in] Size of each square matrix (n_ant)

  • batch_size[in] Total number of matrices (batch_size * n_prb)

  • output_real[out] Real output (if is_complex=false) or real part (if is_complex=true)

  • output_imag[out] Imaginary part (if is_complex=true) or nullptr (if is_complex=false)

  • workspace[in] Workspace memory for cuSOLVERDx computation

  • stream[in] CUDA stream for asynchronous execution

  • is_complex[in] false for real data, true for complex data
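As a host-side reference for what the batched kernel computes, a NumPy sketch of the launcher's split real/imag interface (an illustration of the math, not the cuSOLVERDx code path):

```python
import numpy as np

def cholesky_factor_inv_ref(input_real, input_imag, matrix_size, batch_size):
    # Rebuild the batched matrices from the split real/imag layout used by
    # the launcher (input_imag is None for real data).
    a = input_real.reshape(batch_size, matrix_size, matrix_size).astype(np.complex128)
    if input_imag is not None:
        a = a + 1j * input_imag.reshape(batch_size, matrix_size, matrix_size)
    out = np.empty_like(a)
    for b in range(batch_size):
        L = np.linalg.cholesky(a[b])       # A = L @ L^H
        L_inv = np.linalg.inv(L)
        out[b] = L_inv.conj().T @ L_inv    # A^{-1} = L^{-H} @ L^{-1}
    out_imag = out.imag.ravel() if input_imag is not None else None
    return out.real.ravel(), out_imag
```

For real data the imaginary arrays are simply absent (None), matching the nullptr convention in the launcher signature.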

void ran::trt_plugin::launch_dmrs_kernel(
const std::int32_t *input_params,
std::int32_t sequence_length,
std::int32_t n_t,
float *r_dmrs_ri_sym_cdm_sc,
std::int32_t *scr_seq_sym_ri_sc,
cudaStream_t stream,
)#

Launches DMRS sequence generation kernel

This function launches the CUDA kernel that generates DMRS sequences for all n_t symbols and both n_scid ports (0, 1) using scalar inputs. The kernel internally loops over n_t symbols and n_scid ports to produce two output tensors: complex DMRS values and binary gold sequence.

See also

dmrs_kernel for the actual CUDA kernel implementation

Note

Complex output shape: (2, n_t, 2, sequence_length/2) where dim0: [0]=real, [1]=imag

Note

Binary output shape: (n_t, 2, sequence_length)

Note

input_params is GPU memory containing [slot_number, n_dmrs_id]

Parameters:
  • input_params[in] GPU pointer to [slot_number, n_dmrs_id] array

  • sequence_length[in] Sequence length per port (compile-time constant)

  • n_t[in] Number of OFDM symbols per slot (compile-time constant)

  • r_dmrs_ri_sym_cdm_sc[out] Complex DMRS output (GPU memory, 2 x n_t x 2 x sequence_length/2)

  • scr_seq_sym_ri_sc[out] Binary gold sequence output (GPU memory, n_t x 2 x sequence_length)

  • stream[in] CUDA stream for asynchronous execution
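The bit generation underlying the DMRS values (described further under DMRSTrtPlugin::enqueue) is the standard 3GPP Gold-sequence construction. A minimal pure-Python sketch of that construction follows; the per-symbol/per-port c_init derivation and the mapping of bits to complex QPSK values are omitted:

```python
def gold_sequence(c_init: int, length: int) -> list[int]:
    # 3GPP pseudo-random (Gold) sequence: two 31-bit LFSR-generated
    # M-sequences x1 and x2, combined mod 2 after an Nc = 1600 offset.
    Nc = 1600
    total = Nc + length
    x1 = [0] * (total + 31)
    x2 = [0] * (total + 31)
    x1[0] = 1                    # fixed x1 initialization
    for i in range(31):          # x2 is seeded from c_init
        x2[i] = (c_init >> i) & 1
    for n in range(total):
        x1[n + 31] = (x1[n + 3] + x1[n]) % 2
        x2[n + 31] = (x2[n + 3] + x2[n + 2] + x2[n + 1] + x2[n]) % 2
    return [(x1[n + Nc] + x2[n + Nc]) % 2 for n in range(length)]
```

The kernel produces one such sequence per (symbol, port) pair, which is why the binary output carries the extra n_t and port dimensions.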

void ran::trt_plugin::launch_fft_kernel(
const float *input_real,
const float *input_imag,
std::int32_t fft_size,
std::int32_t batch_size,
float *output_real,
float *output_imag,
void *workspace,
cudaStream_t stream,
std::int32_t precision = 0,
std::int32_t fft_type = 0,
std::int32_t direction = 0,
std::int32_t ffts_per_block = 1,
std::int32_t elements_per_thread = FftTrtPluginParams::DEFAULT_ELEMENTS_PER_THREAD,
)#

CUDA kernel launcher for FFT computation (batched)

This function launches the CUDA kernel that performs FFT computation using cuFFTDx library for multiple input signals in parallel.

See also

cufft_kernel for the actual CUDA kernel implementation

Parameters:
  • input_real[in] Pointer to input real component array (GPU memory, batch_size * fft_size elements)

  • input_imag[in] Pointer to input imaginary component array (GPU memory, batch_size * fft_size elements)

  • fft_size[in] Size of the FFT to compute

  • batch_size[in] Number of input signals to process in parallel

  • output_real[out] Pointer to output real component buffer (GPU memory, batch_size * fft_size)

  • output_imag[out] Pointer to output imaginary component buffer (GPU memory, batch_size * fft_size)

  • workspace[in] Workspace memory for FFT computation

  • stream[in] CUDA stream for asynchronous execution

  • precision[in] Precision mode (0=float, 1=double)

  • fft_type[in] FFT type (0=C2C, 1=R2C, 2=C2R)

  • direction[in] FFT direction (0=forward, 1=inverse)

  • ffts_per_block[in] Number of FFTs per block

  • elements_per_thread[in] Number of elements processed per thread
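A NumPy sketch of the batched C2C path through this interface. Note one hedge: np.fft.ifft applies a 1/N normalization, whereas cuFFTDx's inverse transform is typically unnormalized, so the inverse branch here is a normalized stand-in rather than a bit-exact model of the kernel:

```python
import numpy as np

def fft_ref(input_real, input_imag, fft_size, batch_size, direction=0):
    # Recombine the split real/imag layout, run a batched FFT along the
    # last axis, and split the result again (direction: 0=forward, 1=inverse).
    x = (input_real + 1j * input_imag).reshape(batch_size, fft_size)
    y = np.fft.ifft(x, axis=1) if direction == 1 else np.fft.fft(x, axis=1)
    return y.real.ravel(), y.imag.ravel()
```

The split arrays exist only at the interface, for the same reason as in the Cholesky plugin: TensorRT has no complex tensor type.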

void ran::trt_plugin::launch_sequential_sum_kernel(
const float *input,
float *output,
int64_t size,
cudaStream_t stream,
)#

Launches CUDA kernel for sequential sum computation

Parameters:
  • input[in] Input array on device

  • output[out] Output array on device

  • size[in] Number of elements in arrays

  • stream[in] CUDA stream for kernel execution
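The documentation above does not spell out this kernel's exact semantics. Assuming "sequential" means a strictly ordered accumulation (which, unlike an order-varying parallel tree reduction, yields a deterministic floating-point result), a host-side sketch would be:

```python
def sequential_sum_ref(values):
    # Accumulate in strict index order; with floats this fixes the
    # rounding behavior, unlike a parallel reduction whose summation
    # order (and therefore rounding) can vary between runs.
    total = 0.0
    for v in values:
        total += v
    return total
```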

class CholeskyFactorInvPlugin : public ran::trt_plugin::TrtPluginBase<CholeskyFactorInvPlugin>#
#include <cholesky_factor_inv_trt_plugin.hpp>

TensorRT plugin for Cholesky decomposition and matrix inversion.

This plugin implements Cholesky decomposition followed by matrix inversion using NVIDIA’s cuSOLVERDx library for high-performance execution during TensorRT inference.

The plugin computes A^{-1}, where A is a positive definite matrix.

Method: Cholesky decomposition A = L*L^H, then solve L*L^H*X = I for X.

Public Functions

explicit CholeskyFactorInvPlugin(
std::string_view name,
std::string_view name_space = "",
std::int32_t matrix_size = DEFAULT_MATRIX_SIZE,
bool is_complex = false,
)#

Constructor with explicit matrix size and complex flag.

Creates a plugin instance with a specific matrix size and data type. This constructor is primarily used by the clone() method to correctly initialize new instances.

Parameters:
  • name[in] Plugin name identifier

  • name_space[in] Plugin namespace (defaults to empty)

  • matrix_size[in] Size of the square matrix (N for NxN matrix)

  • is_complex[in] Whether data is complex (true) or real (false)

~CholeskyFactorInvPlugin() override = default#

Virtual destructor.

CholeskyFactorInvPlugin(const CholeskyFactorInvPlugin&) = delete#
CholeskyFactorInvPlugin &operator=(
const CholeskyFactorInvPlugin&,
) = delete#
CholeskyFactorInvPlugin(CholeskyFactorInvPlugin&&) = delete#
CholeskyFactorInvPlugin &operator=(
CholeskyFactorInvPlugin&&,
) = delete#
nvinfer1::IPluginV3 *clone() noexcept override#

Creates a deep copy of the plugin instance.

Returns a new plugin instance with identical configuration and state. This is used by TensorRT for plugin cloning and resource management.

Note

The returned plugin is owned by the caller and must be deleted

Returns:

Pointer to a new plugin instance

std::int32_t getNbOutputs() const noexcept override#

Returns the number of output tensors.

Returns:

Number of output tensors produced by the plugin

std::int32_t getOutputDataTypes(
nvinfer1::DataType *output_types,
std::int32_t nb_outputs,
nvinfer1::DataType const *input_types,
std::int32_t nb_inputs,
) const noexcept override#

Determines the output data types based on input types.

Sets the output tensor data types. For this plugin, outputs use float type for inverted matrices.

Parameters:
  • output_types[out] Array to store output data types

  • nb_outputs[in] Number of output tensors

  • input_types[in] Array of input data types

  • nb_inputs[in] Number of input tensors

Returns:

0 on success, non-zero on failure

std::int32_t getOutputShapes(
nvinfer1::DimsExprs const *inputs,
std::int32_t nb_inputs,
nvinfer1::DimsExprs const *shape_inputs,
std::int32_t nb_shape_inputs,
nvinfer1::DimsExprs *outputs,
std::int32_t nb_outputs,
nvinfer1::IExprBuilder &expr_builder,
) noexcept override#

Computes output tensor shapes based on input shapes (batched)

Determines the output tensor dimensions using the expression builder for dynamic shape support. For batched processing, the output shape is [batch_size, n_prb, n_ant, n_ant] where batch_size and n_prb come from the input.

Note

Input shape: [batch_size, n_prb, n_ant, n_ant] - covariance matrices

Note

Output shape: [batch_size, n_prb, n_ant, n_ant] - inverted matrices

Parameters:
  • inputs[in] Array of input tensor shapes

  • nb_inputs[in] Number of input tensors

  • shape_inputs[in] Array of shape input tensors (unused)

  • nb_shape_inputs[in] Number of shape input tensors

  • outputs[out] Array to store output tensor shapes

  • nb_outputs[in] Number of output tensors

  • expr_builder[in] Expression builder for dynamic shape computation

Returns:

0 on success, non-zero on failure

bool supportsFormatCombination(
std::int32_t pos,
nvinfer1::DynamicPluginTensorDesc const *in_out,
std::int32_t nb_inputs,
std::int32_t nb_outputs,
) noexcept override#

Checks if the plugin supports the given format combination.

Validates that the input/output format combination is supported by the plugin implementation.

Parameters:
  • pos[in] Position in the input/output array to check

  • in_out[in] Array of input/output tensor descriptions

  • nb_inputs[in] Number of input tensors

  • nb_outputs[in] Number of output tensors

Returns:

true if the format combination is supported, false otherwise

std::int32_t enqueue(
nvinfer1::PluginTensorDesc const *input_desc,
nvinfer1::PluginTensorDesc const *output_desc,
void const *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream,
) noexcept override#

Executes the cuSOLVERDx Cholesky inversion kernel (batched)

This is the main execution method called by TensorRT during inference. It launches a CUDA kernel that performs Cholesky decomposition followed by matrix inversion using cuSOLVERDx for multiple matrices in parallel.

The kernel uses cuSOLVERDx’s POTRF and TRSM implementations:

  1. Takes covariance matrices of shape [batch_size, n_prb, n_ant, n_ant]

  2. Performs Cholesky decomposition: A = L*L^H

  3. Solves L*L^H*X = I to compute X = A^{-1}

  4. Outputs inverted matrices in the same shape

See also

launch_cholesky_factor_inv_kernel for the underlying CUDA implementation

Note

Input shape: [batch_size, n_prb, n_ant, n_ant] - covariance matrices

Note

Output shape: [batch_size, n_prb, n_ant, n_ant] - inverted matrices

Note

All data pointers must be valid GPU memory addresses

Parameters:
  • input_desc[in] Array of input tensor descriptions

  • output_desc[in] Array of output tensor descriptions

  • inputs[in] Array of input data pointers (GPU memory)

  • outputs[out] Array of output data pointers (GPU memory)

  • workspace[in] Workspace memory for cuSOLVERDx computation

  • stream[in] CUDA stream for asynchronous execution

Returns:

0 on success, -1 on failure

nvinfer1::PluginFieldCollection const *getFieldsToSerialize(
) noexcept override#

Returns the fields that should be serialized.

Specifies which plugin parameters should be saved when serializing the model. This plugin has no serializable fields.

Returns:

Pointer to an empty field collection

Public Static Attributes

static constexpr const char *PLUGIN_TYPE = "CholeskyFactorInv"#

Plugin type identifier.

static constexpr const char *PLUGIN_VERSION = "1"#

Plugin version string.

static constexpr std::int32_t DEFAULT_MATRIX_SIZE = 2#

Default matrix size.

class CholeskyFactorInvPluginCreator : public ran::trt_plugin::TrtPluginCreatorBase<CholeskyFactorInvPlugin>#
#include <cholesky_factor_inv_trt_plugin.hpp>

Plugin Creator for CholeskyFactorInvPlugin.

This class handles the creation and configuration of CholeskyFactorInvPlugin instances. It implements the TensorRT plugin creator interface and manages plugin field collection for parameter configuration.

The creator extracts matrix size parameters from the field collection and ensures proper plugin initialization with the correct configuration.

See also

CholeskyFactorInvPlugin for the main plugin implementation

See also

IPluginCreatorV3One for the base creator interface

Public Functions

explicit CholeskyFactorInvPluginCreator(std::string_view name_space)#

Constructor with required namespace.

Initializes the creator with the specified namespace and prepares the plugin field collection.

Parameters:

name_space[in] Plugin namespace

~CholeskyFactorInvPluginCreator() override = default#

Virtual destructor.

CholeskyFactorInvPluginCreator(
const CholeskyFactorInvPluginCreator&,
) = delete#
CholeskyFactorInvPluginCreator &operator=(
const CholeskyFactorInvPluginCreator&,
) = delete#
CholeskyFactorInvPluginCreator(
CholeskyFactorInvPluginCreator&&,
) = delete#
CholeskyFactorInvPluginCreator &operator=(
CholeskyFactorInvPluginCreator&&,
) = delete#
nvinfer1::IPluginV3 *createPlugin(
nvinfer1::AsciiChar const *name,
nvinfer1::PluginFieldCollection const *fc,
nvinfer1::TensorRTPhase phase,
) noexcept override#

Creates a new plugin instance.

Instantiates a CholeskyFactorInvPlugin with the specified name and configuration from the field collection.

Note

The returned plugin is owned by the caller and must be deleted

Parameters:
  • name[in] Name for the new plugin instance

  • fc[in] Field collection containing configuration parameters

  • phase[in] TensorRT phase (build or runtime)

Returns:

Pointer to the created plugin instance

class DMRSTrtPlugin : public ran::trt_plugin::TrtPluginBase<DMRSTrtPlugin>#
#include <dmrs_trt_plugin.hpp>

TensorRT plugin for DMRS generation.

This plugin implements the 3GPP DMRS generation algorithm using CUDA kernels for TensorRT inference. It takes slot number and DMRS ID as inputs and produces two outputs: complex DMRS values and binary gold sequences for all OFDM symbols and ports.

Input: [2] containing [slot_number, n_dmrs_id]

Output 0: (2, n_t, 2, sequence_length/2) - Complex DMRS values

Output 1: (n_t, 2, sequence_length) - Binary gold sequence

Public Functions

explicit DMRSTrtPlugin(
std::string_view name,
std::string_view name_space = "",
std::int32_t sequence_length = DEFAULT_SEQUENCE_LENGTH,
std::int32_t n_t = DEFAULT_N_T,
)#

Constructor with explicit sequence length and n_t.

Creates a plugin instance with a specific sequence length and number of OFDM symbols. This constructor is primarily used by the clone() method to correctly initialize new instances.

Parameters:
  • name[in] Plugin name identifier

  • name_space[in] Plugin namespace (defaults to empty)

  • sequence_length[in] Length of the DMRS sequence to generate

  • n_t[in] Number of OFDM symbols per slot

~DMRSTrtPlugin() override = default#

Virtual destructor.

DMRSTrtPlugin(const DMRSTrtPlugin&) = delete#
DMRSTrtPlugin &operator=(const DMRSTrtPlugin&) = delete#
DMRSTrtPlugin(DMRSTrtPlugin&&) = delete#
DMRSTrtPlugin &operator=(DMRSTrtPlugin&&) = delete#
nvinfer1::IPluginV3 *clone() noexcept override#

Creates a deep copy of the plugin instance.

Returns a new plugin instance with identical configuration and state. This is used by TensorRT for plugin cloning and resource management.

Note

The returned plugin is owned by the caller and must be deleted

Returns:

Pointer to a new plugin instance

std::int32_t getNbOutputs() const noexcept override#

Returns the number of output tensors.

This plugin produces two outputs: complex DMRS values and binary gold sequence.

Returns:

Number of output tensors produced by the plugin (always 2)

std::int32_t getOutputDataTypes(
nvinfer1::DataType *output_types,
std::int32_t nb_outputs,
nvinfer1::DataType const *input_types,
std::int32_t nb_inputs,
) const noexcept override#

Determines the output data types based on input types.

Sets the output tensor data types. This plugin produces two outputs: Output 0 is FLOAT32 (complex DMRS values), Output 1 is INT32 (binary sequence).

Parameters:
  • output_types[out] Array to store output data types

  • nb_outputs[in] Number of output tensors

  • input_types[in] Array of input data types

  • nb_inputs[in] Number of input tensors

Returns:

0 on success, non-zero on failure

std::int32_t getOutputShapes(
nvinfer1::DimsExprs const *inputs,
std::int32_t nb_inputs,
nvinfer1::DimsExprs const *shape_inputs,
std::int32_t nb_shape_inputs,
nvinfer1::DimsExprs *outputs,
std::int32_t nb_outputs,
nvinfer1::IExprBuilder &expr_builder,
) noexcept override#

Computes output tensor shapes based on input shapes.

Determines the output tensor dimensions using the expression builder for dynamic shape support. This plugin produces two output tensors from scalar input parameters.

Note

Input shape: [2] - contains [slot_number, n_dmrs_id] parameters

Note

Output 0 shape: (2, n_t, 2, sequence_length/2) - Complex DMRS values [real/imag, symbols, ports, subcarriers]

Note

Output 1 shape: (n_t, 2, sequence_length) - Binary gold sequence [symbols, ports, subcarriers]

Parameters:
  • inputs[in] Array of input tensor shapes

  • nb_inputs[in] Number of input tensors

  • shape_inputs[in] Array of shape input tensors (unused)

  • nb_shape_inputs[in] Number of shape input tensors

  • outputs[out] Array to store output tensor shapes

  • nb_outputs[in] Number of output tensors

  • expr_builder[in] Expression builder for dynamic shape computation

Returns:

0 on success, non-zero on failure

bool supportsFormatCombination(
std::int32_t pos,
nvinfer1::DynamicPluginTensorDesc const *in_out,
std::int32_t nb_inputs,
std::int32_t nb_outputs,
) noexcept override#

Checks if the plugin supports the given format combination.

Validates that the input/output format combination is supported by the plugin implementation.

Parameters:
  • pos[in] Position in the input/output array to check

  • in_out[in] Array of input/output tensor descriptions

  • nb_inputs[in] Number of input tensors

  • nb_outputs[in] Number of output tensors

Returns:

true if the format combination is supported, false otherwise

std::int32_t enqueue(
nvinfer1::PluginTensorDesc const *input_desc,
nvinfer1::PluginTensorDesc const *output_desc,
void const *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream,
) noexcept override#

Executes the DMRS sequence generation kernel.

This is the main execution method called by TensorRT during inference. It launches a CUDA kernel that generates pseudo-random DMRS sequences for all n_t OFDM symbols and both n_scid ports (0, 1) based on the input slot number and DMRS ID parameters.

The kernel implements the 3GPP DMRS sequence algorithm:

  1. Generates two M-sequences (x1 and x2) using linear feedback shift registers based on the computed c_init value

  2. Combines them using modulo-2 addition with a 1600-bit offset

  3. Outputs complex DMRS values and binary gold sequences for all symbols and ports

See also

launch_dmrs_kernel for the underlying CUDA implementation

Note

Input shape: [2] - contains [slot_number, n_dmrs_id] parameters

Note

Output 0 shape: (2, n_t, 2, sequence_length/2) - Complex DMRS values

Note

Output 1 shape: (n_t, 2, sequence_length) - Binary gold sequence

Note

All data pointers must be valid GPU memory addresses

Parameters:
  • input_desc[in] Array of input tensor descriptions

  • output_desc[in] Array of output tensor descriptions

  • inputs[in] Array of input data pointers (GPU memory)

  • outputs[out] Array of output data pointers (GPU memory)

  • workspace[in] Workspace memory (unused)

  • stream[in] CUDA stream for asynchronous execution

Returns:

0 on success, -1 on failure

nvinfer1::PluginFieldCollection const *getFieldsToSerialize(
) noexcept override#

Returns the fields that should be serialized.

Specifies which plugin parameters should be saved when serializing the model. This plugin serializes:

  • sequence_length: Length of DMRS sequences to generate

  • n_t: Number of OFDM symbols per slot

Returns:

Pointer to field collection containing sequence_length and n_t

Public Static Attributes

static constexpr const char *PLUGIN_TYPE = "DmrsTrt"#

Plugin type identifier.

static constexpr const char *PLUGIN_VERSION = "1"#

Plugin version string.

static constexpr std::int32_t DEFAULT_SEQUENCE_LENGTH = 42#

Default DMRS sequence length.

static constexpr std::int32_t DEFAULT_N_T = 14#

Default OFDM symbols per slot.

class DMRSTrtPluginCreator : public ran::trt_plugin::TrtPluginCreatorBase<DMRSTrtPlugin>#
#include <dmrs_trt_plugin.hpp>

Plugin Creator for DMRSTrtPlugin.

This class handles the creation and configuration of DMRSTrtPlugin instances. It implements the TensorRT plugin creator interface and manages plugin field collection for parameter configuration.

The creator extracts sequence length parameters from the field collection and ensures proper plugin initialization with the correct configuration.

See also

DMRSTrtPlugin for the main plugin implementation

See also

TrtPluginCreatorBase for the base creator interface

Public Functions

explicit DMRSTrtPluginCreator(std::string_view name_space)#

Constructor with required namespace.

Initializes the creator with the specified namespace and prepares the plugin field collection.

Parameters:

name_space[in] Plugin namespace

~DMRSTrtPluginCreator() override = default#

Virtual destructor.

DMRSTrtPluginCreator(const DMRSTrtPluginCreator&) = delete#
DMRSTrtPluginCreator &operator=(const DMRSTrtPluginCreator&) = delete#
DMRSTrtPluginCreator(DMRSTrtPluginCreator&&) = delete#
DMRSTrtPluginCreator &operator=(DMRSTrtPluginCreator&&) = delete#
nvinfer1::IPluginV3 *createPlugin(
nvinfer1::AsciiChar const *name,
nvinfer1::PluginFieldCollection const *fc,
nvinfer1::TensorRTPhase phase,
) noexcept override#

Creates a new plugin instance.

Instantiates a DMRSTrtPlugin with the specified name and configuration from the field collection.

Note

The returned plugin is owned by the caller and must be deleted

Parameters:
  • name[in] Name for the new plugin instance

  • fc[in] Field collection containing configuration parameters

  • phase[in] TensorRT phase (build or runtime)

Returns:

Pointer to the created plugin instance

class FftTrtPlugin : public ran::trt_plugin::TrtPluginBase<FftTrtPlugin>#
#include <fft_trt_plugin.hpp>

TensorRT plugin for FFT computation (batched)

This plugin implements FFT computation using NVIDIA’s cuFFTDx library for high-performance execution during TensorRT inference.

Public Functions

explicit FftTrtPlugin(
std::string_view name,
std::string_view name_space = "",
const FftTrtPluginParams &params = {},
)#

Constructor with FFT configuration parameters.

Creates a plugin instance with specified FFT configuration. Use designated initializers to override default values.

// All defaults
auto plugin = new FftTrtPlugin("fft1");

// Override FFT size only
auto plugin = new FftTrtPlugin("fft2", "", FftTrtPluginParams{.fft_size = 256});

// Override size and direction
auto plugin = new FftTrtPlugin("fft3", "", FftTrtPluginParams{
    .fft_size = 512,
    .direction = "inverse"
});

Parameters:
  • name[in] Plugin name identifier

  • name_space[in] Plugin namespace (defaults to empty)

  • params[in] FFT configuration parameters

~FftTrtPlugin() override = default#

Virtual destructor.

FftTrtPlugin(const FftTrtPlugin&) = delete#
FftTrtPlugin &operator=(const FftTrtPlugin&) = delete#
FftTrtPlugin(FftTrtPlugin&&) = delete#
FftTrtPlugin &operator=(FftTrtPlugin&&) = delete#
nvinfer1::IPluginV3 *clone() noexcept override#

Creates a deep copy of the plugin instance.

Returns a new plugin instance with identical configuration and state. This is used by TensorRT for plugin cloning and resource management.

Note

The returned plugin is owned by the caller and must be deleted

Returns:

Pointer to a new plugin instance

std::int32_t getNbOutputs() const noexcept override#

Returns the number of output tensors.

Returns:

Number of output tensors produced by the plugin

std::int32_t getOutputDataTypes(
nvinfer1::DataType *output_types,
std::int32_t nb_outputs,
nvinfer1::DataType const *input_types,
std::int32_t nb_inputs,
) const noexcept override#

Determines the output data types based on input types.

Sets the output tensor data types. For this plugin, outputs use complex float type for FFT results.

Parameters:
  • output_types[out] Array to store output data types

  • nb_outputs[in] Number of output tensors

  • input_types[in] Array of input data types

  • nb_inputs[in] Number of input tensors

Returns:

0 on success, non-zero on failure

std::int32_t getOutputShapes(
nvinfer1::DimsExprs const *inputs,
std::int32_t nb_inputs,
nvinfer1::DimsExprs const *shape_inputs,
std::int32_t nb_shape_inputs,
nvinfer1::DimsExprs *outputs,
std::int32_t nb_outputs,
nvinfer1::IExprBuilder &expr_builder,
) noexcept override#

Computes output tensor shapes based on input shapes (batched)

Determines the output tensor dimensions using the expression builder for dynamic shape support. For batched processing, the output shape is [batch_size, fft_size] where batch_size comes from the input.

Note

Input shape: [batch_size, fft_size] - complex input data

Note

Output shape: [batch_size, fft_size] - complex FFT results

Parameters:
  • inputs[in] Array of input tensor shapes

  • nb_inputs[in] Number of input tensors

  • shape_inputs[in] Array of shape input tensors (unused)

  • nb_shape_inputs[in] Number of shape input tensors

  • outputs[out] Array to store output tensor shapes

  • nb_outputs[in] Number of output tensors

  • expr_builder[in] Expression builder for dynamic shape computation

Returns:

0 on success, non-zero on failure

bool supportsFormatCombination(
std::int32_t pos,
nvinfer1::DynamicPluginTensorDesc const *in_out,
std::int32_t nb_inputs,
std::int32_t nb_outputs,
) noexcept override#

Checks if the plugin supports the given format combination.

Validates that the input/output format combination is supported by the plugin implementation.

Parameters:
  • pos[in] Position in the input/output array to check

  • in_out[in] Array of input/output tensor descriptions

  • nb_inputs[in] Number of input tensors

  • nb_outputs[in] Number of output tensors

Returns:

true if the format combination is supported, false otherwise

std::int32_t enqueue(
nvinfer1::PluginTensorDesc const *input_desc,
nvinfer1::PluginTensorDesc const *output_desc,
void const *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream,
) noexcept override#

Executes the cuFFTDx FFT computation kernel (batched)

This is the main execution method called by TensorRT during inference. It launches a CUDA kernel that performs FFT computation using cuFFTDx for multiple input signals in parallel.

The kernel uses cuFFTDx’s FFT implementation:

  1. Takes complex input data of shape [batch_size, fft_size]

  2. Performs FFT computation using cuFFTDx’s optimized FFT routines

  3. Outputs complex FFT results in the same shape

See also

launch_fft_kernel for the underlying CUDA implementation

Note

Input shape: [batch_size, fft_size] - complex input data

Note

Output shape: [batch_size, fft_size] - complex FFT results

Note

All data pointers must be valid GPU memory addresses

Parameters:
  • input_desc[in] Array of input tensor descriptions

  • output_desc[in] Array of output tensor descriptions

  • inputs[in] Array of input data pointers (GPU memory)

  • outputs[out] Array of output data pointers (GPU memory)

  • workspace[in] Workspace memory for FFT computation

  • stream[in] CUDA stream for asynchronous execution

Returns:

0 on success, -1 on failure

nvinfer1::PluginFieldCollection const *getFieldsToSerialize(
) noexcept override#

Returns the fields that should be serialized.

Specifies which plugin parameters should be saved when serializing the model. Serializes FFT size and direction.

Returns:

Pointer to field collection containing m_fft_size_ and m_direction_

Public Static Attributes

static constexpr const char *PLUGIN_TYPE = "FftTrt"#

Plugin type identifier.

static constexpr const char *PLUGIN_VERSION = "1"#

Plugin version string.

class FftTrtPluginCreator : public ran::trt_plugin::TrtPluginCreatorBase<FftTrtPlugin>#
#include <fft_trt_plugin.hpp>

Plugin Creator for FftTrtPlugin.

This class handles the creation and configuration of FftTrtPlugin instances. It implements the TensorRT plugin creator interface and manages plugin field collection for parameter configuration.

The creator extracts FFT size parameters from the field collection and ensures proper plugin initialization with the correct configuration.

See also

FftTrtPlugin for the main plugin implementation

See also

IPluginCreatorV3One for the base creator interface

Public Functions

explicit FftTrtPluginCreator(std::string_view name_space)#

Constructor with required namespace.

Initializes the creator with the specified namespace and prepares the plugin field collection.

Parameters:

name_space[in] Plugin namespace

~FftTrtPluginCreator() override = default#

Virtual destructor.

FftTrtPluginCreator(const FftTrtPluginCreator&) = delete#
FftTrtPluginCreator &operator=(const FftTrtPluginCreator&) = delete#
FftTrtPluginCreator(FftTrtPluginCreator&&) = delete#
FftTrtPluginCreator &operator=(FftTrtPluginCreator&&) = delete#
nvinfer1::IPluginV3 *createPlugin(
nvinfer1::AsciiChar const *name,
nvinfer1::PluginFieldCollection const *fc,
nvinfer1::TensorRTPhase phase,
) noexcept override#

Creates a new plugin instance.

Instantiates an FftTrtPlugin with the specified name and configuration from the field collection.

Note

The returned plugin is owned by the caller and must be deleted

Parameters:
  • name[in] Name for the new plugin instance

  • fc[in] Field collection containing configuration parameters

  • phase[in] TensorRT phase (build or runtime)

Returns:

Pointer to the created plugin instance

struct FftTrtPluginParams#
#include <fft_trt_plugin.hpp>

FFT plugin configuration parameters.

Configuration struct for FftTrtPlugin initialization. Use designated initializers to specify non-default values.

Public Members

std::int32_t fft_size = {DEFAULT_FFT_SIZE}#

Size of the FFT to compute.

std::string precision = {"float"}#

FFT precision (float, double)

std::string fft_type = {"c2c"}#

FFT type (c2c, r2c, c2r)

std::string direction = {"forward"}#

FFT direction (forward, inverse)

std::int32_t ffts_per_block = {DEFAULT_FFTS_PER_BLOCK}#

Number of FFTs per CUDA block.

std::int32_t elements_per_thread = {DEFAULT_ELEMENTS_PER_THREAD}#

Number of elements per thread.

Public Static Attributes

static constexpr std::int32_t DEFAULT_FFT_SIZE = 128#

Default FFT size.

static constexpr std::int32_t DEFAULT_FFTS_PER_BLOCK = 1#

Default FFTs per block.

static constexpr std::int32_t DEFAULT_ELEMENTS_PER_THREAD = 8#

Default elements per thread.

class SequentialSumPlugin : public ran::trt_plugin::TrtPluginBase<SequentialSumPlugin>#
#include <sequential_sum_plugin.hpp>

Sequential Sum Plugin - deliberately non-parallelizable operation

This plugin computes a sequential sum where each element depends on the previous: output[i] = input[i] + output[i-1]

The recurrence is intentionally sequential: each output element depends on the previous one, so the operation demonstrates a custom kernel whose data dependencies prevent it from benefiting from parallelization.

Public Functions

explicit SequentialSumPlugin(
std::string_view name,
std::string_view name_space = "",
)#

Constructor for plugin creation

Parameters:
  • name[in] Plugin instance name

  • name_space[in] Plugin namespace (defaults to empty)

~SequentialSumPlugin() override = default#

Destructor

SequentialSumPlugin(const SequentialSumPlugin&) = delete#
SequentialSumPlugin &operator=(const SequentialSumPlugin&) = delete#
SequentialSumPlugin(SequentialSumPlugin&&) = delete#
SequentialSumPlugin &operator=(SequentialSumPlugin&&) = delete#
nvinfer1::IPluginV3 *clone() noexcept override#

Creates a copy of the plugin

Returns:

New plugin instance

std::int32_t getOutputDataTypes(
nvinfer1::DataType *output_types,
std::int32_t nb_outputs,
nvinfer1::DataType const *input_types,
std::int32_t nb_inputs,
) const noexcept override#

Determines output data types based on input types

Parameters:
  • output_types[out] Array to store output data types

  • nb_outputs[in] Number of outputs

  • input_types[in] Array of input data types

  • nb_inputs[in] Number of inputs

Returns:

0 on success

std::int32_t getOutputShapes(
nvinfer1::DimsExprs const *inputs,
std::int32_t nb_inputs,
nvinfer1::DimsExprs const *shape_inputs,
std::int32_t nb_shape_inputs,
nvinfer1::DimsExprs *outputs,
std::int32_t nb_outputs,
nvinfer1::IExprBuilder &expr_builder,
) noexcept override#

Computes output shapes based on input shapes

Parameters:
  • inputs[in] Array of input tensor shapes

  • nb_inputs[in] Number of inputs

  • shape_inputs[in] Array of shape input tensors

  • nb_shape_inputs[in] Number of shape inputs

  • outputs[out] Array to store output shapes

  • nb_outputs[in] Number of outputs

  • expr_builder[inout] Expression builder for shape calculations

Returns:

0 on success

bool supportsFormatCombination(
std::int32_t pos,
nvinfer1::DynamicPluginTensorDesc const *in_out,
std::int32_t nb_inputs,
std::int32_t nb_outputs,
) noexcept override#

Checks if format combination is supported

Parameters:
  • pos[in] Position in input/output array to check

  • in_out[in] Array of input and output tensor descriptors

  • nb_inputs[in] Number of inputs

  • nb_outputs[in] Number of outputs

Returns:

true if format combination is supported

std::int32_t getNbOutputs() const noexcept override#

Returns the number of outputs

Returns:

Number of output tensors

std::int32_t enqueue(
nvinfer1::PluginTensorDesc const *input_desc,
nvinfer1::PluginTensorDesc const *output_desc,
void const *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream,
) noexcept override#

Executes the plugin operation

Parameters:
  • input_desc[in] Array of input tensor descriptors

  • output_desc[in] Array of output tensor descriptors

  • inputs[in] Array of input device buffers

  • outputs[out] Array of output device buffers

  • workspace[inout] Workspace memory pointer

  • stream[in] CUDA stream for kernel execution

Returns:

0 on success

nvinfer1::PluginFieldCollection const *getFieldsToSerialize(
) noexcept override#

Returns fields to be serialized

Returns:

Pointer to field collection for serialization

Public Static Attributes

static constexpr const char *PLUGIN_TYPE = "SequentialSum"#

Plugin type identifier.

static constexpr const char *PLUGIN_VERSION = "1"#

Plugin version string.

class SequentialSumPluginCreator : public ran::trt_plugin::TrtPluginCreatorBase<SequentialSumPlugin>#
#include <sequential_sum_plugin.hpp>

Plugin Creator for SequentialSumPlugin

Public Functions

explicit SequentialSumPluginCreator(std::string_view name_space)#

Constructor with required namespace.

Parameters:

name_space[in] Plugin namespace

~SequentialSumPluginCreator() override = default#
SequentialSumPluginCreator(const SequentialSumPluginCreator&) = delete#
SequentialSumPluginCreator &operator=(
const SequentialSumPluginCreator&,
) = delete#
SequentialSumPluginCreator(SequentialSumPluginCreator&&) = delete#
SequentialSumPluginCreator &operator=(
SequentialSumPluginCreator&&,
) = delete#
nvinfer1::IPluginV3 *createPlugin(
nvinfer1::AsciiChar const *name,
nvinfer1::PluginFieldCollection const *fc,
nvinfer1::TensorRTPhase phase,
) noexcept override#

Creates a new plugin instance

Parameters:
  • name[in] Plugin instance name

  • fc[in] Field collection containing plugin parameters

  • phase[in] TensorRT phase (build or runtime)

Returns:

New plugin instance or nullptr on failure

template<typename Derived>
class TrtPluginBase : public nvinfer1::IPluginV3, public nvinfer1::IPluginV3OneCore, public nvinfer1::IPluginV3OneBuild, public nvinfer1::IPluginV3OneRuntime#
#include <trt_plugin_base.hpp>

CRTP base class for TensorRT V3 plugins

Provides common implementations for IPluginV3 interface methods that are identical across all plugins. Uses CRTP pattern for zero runtime overhead.

Derived class requirements:

  • Must inherit from TrtPluginBase<DerivedClass>

  • Must define static constexpr members: PLUGIN_TYPE and PLUGIN_VERSION

  • Must implement plugin-specific methods (enqueue, getOutputShapes, etc.)

Template Parameters:

Derived – The derived plugin class (CRTP pattern)

Public Functions

TrtPluginBase(const TrtPluginBase&) = delete#
TrtPluginBase &operator=(const TrtPluginBase&) = delete#
TrtPluginBase(TrtPluginBase&&) = delete#
TrtPluginBase &operator=(TrtPluginBase&&) = delete#
~TrtPluginBase() override = default#

Virtual destructor for proper cleanup

inline nvinfer1::IPluginCapability *getCapabilityInterface(
nvinfer1::PluginCapabilityType type,
) noexcept override#

Returns capability interface for the requested type

This implementation is identical for all plugins and routes to the appropriate interface based on the capability type.

Parameters:

type[in] The capability type being requested

Returns:

Pointer to the capability interface, or nullptr if unsupported

inline nvinfer1::AsciiChar const *getPluginName(
) const noexcept override#

Returns plugin type name

Uses CRTP to access the derived class’s static PLUGIN_TYPE member.

Returns:

C-string containing the plugin type name

inline nvinfer1::AsciiChar const *getPluginVersion(
) const noexcept override#

Returns plugin version string

Uses CRTP to access the derived class’s static PLUGIN_VERSION member.

Returns:

C-string containing the plugin version

inline nvinfer1::AsciiChar const *getPluginNamespace(
) const noexcept override#

Returns plugin namespace

Returns:

C-string containing the plugin namespace

inline int32_t configurePlugin(
nvinfer1::DynamicPluginTensorDesc const *in,
int32_t nb_inputs,
nvinfer1::DynamicPluginTensorDesc const *out,
int32_t nb_outputs,
) noexcept override#

Configures the plugin for the given input/output configuration

Default implementation performs no configuration. Derived classes can override if they need custom configuration logic.

Parameters:
  • in[in] Array of input tensor descriptions

  • nb_inputs[in] Number of input tensors

  • out[in] Array of output tensor descriptions

  • nb_outputs[in] Number of output tensors

Returns:

0 on success

inline size_t getWorkspaceSize(
nvinfer1::DynamicPluginTensorDesc const *inputs,
int32_t nb_inputs,
nvinfer1::DynamicPluginTensorDesc const *outputs,
int32_t nb_outputs,
) const noexcept override#

Returns the workspace size required by the plugin

Default implementation returns 0 (no workspace needed). Derived classes can override if they require workspace memory.

Parameters:
  • inputs[in] Array of input tensor descriptions

  • nb_inputs[in] Number of input tensors

  • outputs[in] Array of output tensor descriptions

  • nb_outputs[in] Number of output tensors

Returns:

Required workspace size in bytes (default: 0)

inline int32_t onShapeChange(
nvinfer1::PluginTensorDesc const *in,
int32_t nb_inputs,
nvinfer1::PluginTensorDesc const *out,
int32_t nb_outputs,
) noexcept override#

Handles dynamic shape changes during runtime

Default implementation performs no special handling. Derived classes can override if they need custom shape change logic.

Parameters:
  • in[in] Array of input tensor descriptions

  • nb_inputs[in] Number of input tensors

  • out[in] Array of output tensor descriptions

  • nb_outputs[in] Number of output tensors

Returns:

0 on success

inline nvinfer1::IPluginV3 *attachToContext(
nvinfer1::IPluginResourceContext *context,
) noexcept override#

Attaches the plugin to a resource context

Default implementation creates a clone of the plugin for the new context. Derived classes can override if they need custom context handling.

Parameters:

context[in] Resource context provided by TensorRT

Returns:

Pointer to a new plugin instance for the context

template<typename PluginType>
class TrtPluginCreatorBase : public nvinfer1::IPluginCreatorV3One#
#include <trt_plugin_creator_base.hpp>

CRTP base class for TensorRT V3 plugin creators

Provides common implementations for IPluginCreatorV3One interface methods that are identical across all plugin creators.

Derived class requirements:

  • Must inherit from TrtPluginCreatorBase<PluginType>

  • PluginType must have static members: PLUGIN_TYPE and PLUGIN_VERSION

  • Must implement createPlugin() method

Template Parameters:

PluginType – The plugin class this creator creates

Public Functions

inline explicit TrtPluginCreatorBase(
const std::string_view name_space,
)#

Constructor with required namespace

Parameters:

name_space[in] Plugin namespace

~TrtPluginCreatorBase() override = default#

Virtual destructor for proper cleanup

TrtPluginCreatorBase(const TrtPluginCreatorBase&) = delete#
TrtPluginCreatorBase &operator=(const TrtPluginCreatorBase&) = delete#
TrtPluginCreatorBase(TrtPluginCreatorBase&&) = delete#
TrtPluginCreatorBase &operator=(TrtPluginCreatorBase&&) = delete#
inline nvinfer1::AsciiChar const *getPluginName(
) const noexcept override#

Returns plugin type name

Uses the PluginType’s static PLUGIN_TYPE member.

Returns:

C-string containing the plugin type name

inline nvinfer1::AsciiChar const *getPluginVersion(
) const noexcept override#

Returns plugin version string

Uses the PluginType’s static PLUGIN_VERSION member.

Returns:

C-string containing the plugin version

inline nvinfer1::AsciiChar const *getPluginNamespace(
) const noexcept override#

Returns plugin namespace

Returns:

C-string containing the plugin namespace

inline nvinfer1::PluginFieldCollection const *getFieldNames(
) noexcept override#

Returns the field collection

Returns:

Pointer to the plugin field collection