L4T Multimedia API Reference

32.3.1 Release

 All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
trt_inference.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * * Redistributions of source code must retain the above copyright
8  * notice, this list of conditions and the following disclaimer.
9  * * Redistributions in binary form must reproduce the above copyright
10  * notice, this list of conditions and the following disclaimer in the
11  * documentation and/or other materials provided with the distribution.
12  * * Neither the name of NVIDIA CORPORATION nor the names of its
13  * contributors may be used to endorse or promote products derived
14  * from this software without specific prior written permission.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
17  * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
20  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
24  * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27  */
28 #ifndef TRT_INFERENCE_H_
29 #define TRT_INFERENCE_H_
30 
31 #include <fstream>
32 #include <queue>
33 #include "NvInfer.h"
34 #include "NvCaffeParser.h"
35 #include "opencv2/video/tracking.hpp"
36 #include "opencv2/imgproc/imgproc.hpp"
37 #include "opencv2/highgui/highgui.hpp"
38 #include <opencv2/objdetect/objdetect.hpp>
39 using namespace nvinfer1;
40 using namespace nvcaffeparser1;
41 using namespace std;
42 
43 // Model index values. The entries of gModelNetAttr below are commented in this same order (presumably the values setModelIndex() expects -- confirm in the .cpp).
44 #define GOOGLENET_SINGLE_CLASS 0
45 #define GOOGLENET_THREE_CLASS 1
46 #define RESNET_THREE_CLASS 2
47 
48 class Logger;
49 
50 class Profiler;
51 
52 class TRT_Context
53 {
54 public:
55  // Network-related parameters (read-only accessors, all declared const)
56  int getNetWidth() const;
57 
58  int getNetHeight() const;
59 
60  uint32_t getBatchSize() const; // NOTE(review): returns uint32_t, but the batch_size member is declared int -- confirm the conversion is intended
61 
62  int getChannel() const;
63 
64  int getModelClassCnt() const;
65 
66  void* getScales() const;
67 
68  void* getOffsets() const;
69 
70  // The buffer is allocated inside TRT_Context;
71  // this interface is exposed so callers can input data. It returns a reference to the pointer, not a copy.
72  void*& getBuffer(const int& index);
73 
74  float*& getInputBuf();
75 
76  uint32_t getNumTrtInstances() const;
77 
78  // mode: 0 = fp16, 1 = fp32, 2 = int8 (same values as the private Mode_type enum)
79  void setMode(const int& mode);
80 
81  void setBatchSize(const uint32_t& batchsize);
82 
83  void setDumpResult(const bool& dump_result);
84 
85  void setTrtProfilerEnabled(const bool& enable_trt_profiler);
86 
87  int getFilterNum() const; // NOTE(review): getter returns int while the setter below takes unsigned int
88  void setFilterNum(const unsigned int& filter_num);
89 
90  TRT_Context();
91 
92  void setModelIndex(int modelIndex); // presumably one of the model index macros (GOOGLENET_SINGLE_CLASS, ...) -- confirm in the .cpp
93 
94  void buildTrtContext(const string& deployfile,
95  const string& modelfile, bool bUseCPUBuf = false); // bUseCPUBuf defaults to false
96 
97  void doInference(
98  queue< vector<cv::Rect> >* rectList_queue,
99  float *input = NULL); // input defaults to NULL
100 
101  void destroyTrtContext(bool bUseCPUBuf = false); // NOTE(review): presumably must receive the same bUseCPUBuf value given to buildTrtContext() -- confirm
102 
103  ~TRT_Context();
104 
105 private: // Data members. Only trtModelStream has an in-class initializer; no other default value is visible in this header.
106  int net_width;
107  int net_height;
108  int filter_num;
109  void **buffers;
110  float *input_buf;
111  float *output_cov_buf;
112  float *output_bbox_buf;
113  void* offset_gpu;
114  void* scales_gpu;
115  float helnet_scale[4];
116  IRuntime *runtime;
117  ICudaEngine *engine;
118  IExecutionContext *context;
119  uint32_t *pResultArray;
120  int channel; // input file's channel
121  int num_bindings;
122  int trtinstance_num; // inference channel num
123  int batch_size;
124  int mode; // presumably 0 fp16 / 1 fp32 / 2 int8, per the comment on setMode() -- confirm
125  bool dump_result;
126  ofstream fstream; // NOTE(review): member name is the same as std::fstream, which `using namespace std` makes visible in this header
127  bool enable_trt_profiler;
128  IHostMemory *trtModelStream{nullptr}; // the only member initialized in-class
129  vector<string> outputs;
130  string result_file;
131  Logger *pLogger; // Logger is only forward-declared in this header
132  Profiler *pProfiler; // Profiler is only forward-declared in this header
133  int frame_num;
134  uint64_t elapsed_frame_num;
135  uint64_t elapsed_time;
136  int inputIndex;
137  int outputIndex;
138  int outputIndexBBOX;
139  DimsCHW inputDims;
140  DimsCHW outputDims;
141  DimsCHW outputDimsBBOX;
142  size_t inputSize;
143  size_t outputSize;
144  size_t outputSizeBBOX;
145 
146  struct { // Per-model attribute record (anonymous type). Fields are filled positionally below, so their order must not change.
147  const int classCnt;
148  float THRESHOLD[3];
149  const char *INPUT_BLOB_NAME;
150  const char *OUTPUT_BLOB_NAME;
151  const char *OUTPUT_BBOX_NAME;
152  const int STRIDE;
153  const int WORKSPACE_SIZE;
154  int offsets[3];
155  float input_scale[3];
156  float bbox_output_scales[4];
157  const int ParseFunc_ID;
158  } *g_pModelNetAttr, gModelNetAttr[4] = { // NOTE(review): 4 slots but only 3 initializers follow; the 4th element is initialized from an empty initializer (zeros / null pointers)
159  {
160  // GOOGLENET_SINGLE_CLASS
161  1, // classCnt
162  {0.8, 0, 0}, // THRESHOLD
163  "data", // INPUT_BLOB_NAME
164  "coverage", // OUTPUT_BLOB_NAME
165  "bboxes", // OUTPUT_BBOX_NAME
166  4, // STRIDE
167  450 * 1024 * 1024, // WORKSPACE_SIZE
168  {0, 0, 0}, // offsets
169  {1.0f, 1.0f, 1.0f}, // input_scale
170  {1, 1, 1, 1}, // bbox_output_scales
171  0 // ParseFunc_ID
172  },
173 
174  {
175  // GOOGLENET_THREE_CLASS
176  3, // classCnt
177  {0.6, 0.6, 1.0}, // THRESHOLD: People, Motorbike, Car
178  "data", // INPUT_BLOB_NAME
179  "Layer16_cov", // OUTPUT_BLOB_NAME
180  "Layer16_bbox", // OUTPUT_BBOX_NAME
181  16, // STRIDE
182  110 * 1024 * 1024, // WORKSPACE_SIZE
183  {124, 117, 104}, // offsets
184  {1.0f, 1.0f, 1.0f}, // input_scale
185  {-640, -368, 640, 368}, // bbox_output_scales
186  0 // ParseFunc_ID
187  },
188 
189  {
190  // RESNET_THREE_CLASS
191  4, // classCnt -- NOTE(review): 4, although the macro name says THREE_CLASS and THRESHOLD holds only 3 values; confirm
192  {0.1, 0.1, 0.1}, // THRESHOLD: People, Motorbike, Car
193  "data", // INPUT_BLOB_NAME
194  "Layer7_cov", // OUTPUT_BLOB_NAME
195  "Layer7_bbox", // OUTPUT_BBOX_NAME
196  16, // STRIDE
197  110 * 1024 * 1024, // WORKSPACE_SIZE
198  {0, 0, 0}, // offsets
199  {0.0039215697906911373, 0.0039215697906911373, 0.0039215697906911373}, // input_scale (approximately 1/255)
200  {-640, -368, 640, 368}, // bbox_output_scales
201  1 // ParseFunc_ID -- presumably selects ParseResnet10Bbox() instead of parseBbox(); confirm in the .cpp
202  },
203  };
204  enum Mode_type{ // Values match the "0 fp16 1 fp32 2 int8" comment on setMode()
205  MODE_FP16 = 0,
206  MODE_FP32 = 1,
207  MODE_INT8 = 2
208  };
209  int parseNet(const string& deployfile); // Private helpers from here on; only declarations appear in this header.
210  void parseBbox(vector<cv::Rect>* rectList, int batch_th); // same signature as ParseResnet10Bbox(); presumably chosen per model via ParseFunc_ID -- confirm
211  void ParseResnet10Bbox(vector<cv::Rect>* rectList, int batch_th);
212  void allocateMemory(bool bUseCPUBuf); // takes the same flag name as buildTrtContext()
213  void releaseMemory(bool bUseCPUBuf); // takes the same flag name as destroyTrtContext()
214  void caffeToTRTModel(const string& deployfile, const string& modelfile); // same deployfile/modelfile parameters as buildTrtContext()
215 };
216 
217 #endif
const int STRIDE
const char * INPUT_BLOB_NAME
const int ParseFunc_ID
const char * OUTPUT_BBOX_NAME
const char * OUTPUT_BLOB_NAME
const int WORKSPACE_SIZE
const int classCnt