28 #ifndef TRT_INFERENCE_H_
29 #define TRT_INFERENCE_H_
34 #include "NvCaffeParser.h"
35 #include "opencv2/video/tracking.hpp"
36 #include "opencv2/imgproc/imgproc.hpp"
37 #include "opencv2/highgui/highgui.hpp"
38 #include <opencv2/objdetect/objdetect.hpp>
39 using namespace nvinfer1;
40 using namespace nvcaffeparser1;
// Model-type identifiers.
// Kept as preprocessor macros rather than constexpr values: a constexpr
// constant would not be usable inside #if conditions.
// NOTE(review): presumably these are the values accepted by setModelIndex();
// confirm against the implementation file.
#define GOOGLENET_SINGLE_CLASS 0
#define GOOGLENET_THREE_CLASS  1
#define HELNET_THREE_CLASS     2
56 int getNetWidth()
const;
58 int getNetHeight()
const;
60 uint32_t getBatchSize()
const;
62 int getChannel()
const;
64 int getModelClassCnt()
const;
68 void*& getBuffer(
const int& index);
70 float*& getInputBuf();
72 uint32_t getNumTrtInstances()
const;
// ---- Behaviour toggles ----
// NOTE(review): a bool is cheaper to pass by value than by const reference;
// changing the signature must be done together with the implementation file.

/// Force FP32 inference (presumably bypassing reduced-precision modes — confirm).
void setForcedFp32(const bool& forced_fp32);

/// Enable or disable dumping of inference results.
void setDumpResult(const bool& dump_result);

/// Enable or disable the TensorRT profiler.
void setTrtProfilerEnabled(const bool& enable_trt_profiler);
80 int getFilterNum()
const;
81 void setFilterNum(
const unsigned int& filter_num);
/// Select which model variant to use.
/// NOTE(review): presumably one of GOOGLENET_SINGLE_CLASS,
/// GOOGLENET_THREE_CLASS or HELNET_THREE_CLASS — confirm.
void setModelIndex(int modelIndex);
87 void buildTrtContext(
const string& deployfile,
88 const string& modelfile,
bool bUseCPUBuf =
false);
91 queue< vector<cv::Rect> >* rectList_queue,
/// Tear down the context created by buildTrtContext().
/// @param bUseCPUBuf  Defaults to false; presumably should match the value
///                    passed to buildTrtContext() — confirm.
void destroyTrtContext(bool bUseCPUBuf = false);
104 float *output_cov_buf;
105 float *output_bbox_buf;
106 float helnet_scale[4];
109 IExecutionContext *context;
110 uint32_t *pResultArray;
118 bool enable_trt_profiler;
119 IHostMemory *trtModelStream{
nullptr};
120 vector<string> outputs;
125 uint64_t elapsed_frame_num;
126 uint64_t elapsed_time;
132 DimsCHW outputDimsBBOX;
135 size_t outputSizeBBOX;
137 int parseNet(
const string& deployfile);
138 void parseBbox(vector<cv::Rect>* rectList,
int batch_th);
139 void allocateMemory(
bool bUseCPUBuf);
140 void releaseMemory(
bool bUseCPUBuf);
141 void caffeToTRTModel(
const string& deployfile,
const string& modelfile);