L4T Multimedia API Reference

28.1 Release

 All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
trt_inference.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * * Redistributions of source code must retain the above copyright
8  * notice, this list of conditions and the following disclaimer.
9  * * Redistributions in binary form must reproduce the above copyright
10  * notice, this list of conditions and the following disclaimer in the
11  * documentation and/or other materials provided with the distribution.
12  * * Neither the name of NVIDIA CORPORATION nor the names of its
13  * contributors may be used to endorse or promote products derived
14  * from this software without specific prior written permission.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
17  * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
20  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
24  * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27  */
28 #ifndef TRT_INFERENCE_H_
29 #define TRT_INFERENCE_H_
30 
31 #include <fstream>
32 #include <queue>
33 #include "NvInfer.h"
34 #include "NvCaffeParser.h"
35 #include "opencv2/video/tracking.hpp"
36 #include "opencv2/imgproc/imgproc.hpp"
37 #include "opencv2/highgui/highgui.hpp"
38 #include <opencv2/objdetect/objdetect.hpp>
39 using namespace nvinfer1;
40 using namespace nvcaffeparser1;
41 using namespace std;
42 
43 // Model Index
44 #define GOOGLENET_SINGLE_CLASS 0
45 #define GOOGLENET_THREE_CLASS 1
46 #define HELNET_THREE_CLASS 2
47 
48 class Logger;
49 
50 class Profiler;
51 
52 class TRT_Context
53 {
54 public:
55  //net related parameter
56  int getNetWidth() const;
57 
58  int getNetHeight() const;
59 
60  uint32_t getBatchSize() const;
61 
62  int getChannel() const;
63 
64  int getModelClassCnt() const;
65 
66  // Buffers are allocated in TRT_Context;
67  // this interface is exposed for inputting data
68  void*& getBuffer(const int& index);
69 
70  float*& getInputBuf();
71 
72  uint32_t getNumTrtInstances() const;
73 
74  void setForcedFp32(const bool& forced_fp32);
75 
76  void setDumpResult(const bool& dump_result);
77 
78  void setTrtProfilerEnabled(const bool& enable_trt_profiler);
79 
80  int getFilterNum() const;
81  void setFilterNum(const unsigned int& filter_num);
82 
83  TRT_Context();
84 
85  void setModelIndex(int modelIndex);
86 
87  void buildTrtContext(const string& deployfile,
88  const string& modelfile, bool bUseCPUBuf = false);
89 
90  void doInference(
91  queue< vector<cv::Rect> >* rectList_queue,
92  float *input = NULL);
93 
94  void destroyTrtContext(bool bUseCPUBuf = false);
95 
96  ~TRT_Context();
97 
98 private:
99  int net_width;
100  int net_height;
101  int filter_num;
102  void **buffers;
103  float *input_buf;
104  float *output_cov_buf;
105  float *output_bbox_buf;
106  float helnet_scale[4];
107  IRuntime *runtime;
108  ICudaEngine *engine;
109  IExecutionContext *context;
110  uint32_t *pResultArray;
111  int channel; //input file's channel
112  int num_bindings;
113  int trtinstance_num; //inference channel num
114  int batch_size;
115  bool forced_fp32;
116  bool dump_result;
117  ofstream fstream;
118  bool enable_trt_profiler;
119  IHostMemory *trtModelStream{nullptr};
120  vector<string> outputs;
121  string result_file;
122  Logger *pLogger;
123  Profiler *pProfiler;
124  int frame_num;
125  uint64_t elapsed_frame_num;
126  uint64_t elapsed_time;
127  int inputIndex;
128  int outputIndex;
129  int outputIndexBBOX;
130  DimsCHW inputDims;
131  DimsCHW outputDims;
132  DimsCHW outputDimsBBOX;
133  size_t inputSize;
134  size_t outputSize;
135  size_t outputSizeBBOX;
136 
137  int parseNet(const string& deployfile);
138  void parseBbox(vector<cv::Rect>* rectList, int batch_th);
139  void allocateMemory(bool bUseCPUBuf);
140  void releaseMemory(bool bUseCPUBuf);
141  void caffeToTRTModel(const string& deployfile, const string& modelfile);
142 };
143 
144 #endif