26 #ifndef NVNEURAL_SCRIPTENGINE_H
27 #define NVNEURAL_SCRIPTENGINE_H
// Callback interface driven by ScriptEngine::interpretParameter() while it walks
// a parsed parameter expression (a SynParser::Node tree). T is the type of the
// value being built; ScriptEngine implements this interface for both
// std::string and AnyOperand (see its base-class list).
//
// Every callback returns void and works on the accumulated value `v` through a
// non-const reference.
40 class ScriptInterpretInterface
44 virtual ~ScriptInterpretInterface()
// Set v from the text of a real-number value (a literal lexem or a FLOAT-typed
// script parameter).
48 virtual void setReal(T& v,
const std::string& str) = 0;
// Set v from the text of an integer value (a literal lexem, an INTEGER-typed
// script parameter, or one x/y/z component of a DIMENSION parameter).
49 virtual void setInteger(T& v,
const std::string& str) = 0;
// Bind the lexem text str to the buffer argument at position argIndex.
// interpretParameter() only calls this when argIndex >= 0 and call == true.
50 virtual void setBufferPointer(T& v,
const std::string& str,
int argIndex) = 0;
// Negate v. NOTE(review): the call site is not visible in this listing;
// presumably used for UNAR (unary minus) nodes.
52 virtual void neg(T& v) = 0;
// Combine v with the right-hand operand op for the "+", "-", "*" and "/"
// operators of ADDSUB / MULDIV nodes. Operands are folded left to right.
53 virtual void add(T& v, T& op) = 0;
54 virtual void sub(T& v, T& op) = 0;
55 virtual void mul(T& v, T& op) = 0;
56 virtual void div(T& v, T& op) = 0;
// Set v from the user option named str. setRealOption is used when the
// option's recorded type is "float"; setIntegerOption presumably covers
// every other type.
58 virtual void setRealOption(T& v,
const std::string& str) = 0;
59 virtual void setIntegerOption(T& v,
const std::string& str) = 0;
// Apply the function named str to the already-evaluated argument held in v.
// Used for CALL nodes with exactly one argument, when call == true.
61 virtual void transform(T& v,
const std::string& str) = 0;
// Activation references: setActAlhpaLeft handles "alpha" / "alpha_left",
// setActAlhpaRight handles "alpha_right", and setActType is presumably the
// fallback for the activation itself.
// NOTE(review): "Alhpa" is a misspelling of "Alpha"; renaming it would break
// every implementer, so the spelling is kept.
63 virtual void setActType(T& v) = 0;
64 virtual void setActAlhpaLeft(T& v) = 0;
65 virtual void setActAlhpaRight(T& v) = 0;
// Called for a two-part "weights" reference; str is the second identifier.
67 virtual void setWeightsPointer(T& v,
const std::string& str) = 0;
// Input references. index is the input index, or -1 when the reference
// carried no explicit index. For setInputDimValue, str is the second
// identifier of the reference.
69 virtual void setInputPointer(T& v,
int index) = 0;
70 virtual void setInputDimValue(T& v,
int index,
const std::string& str) = 0;
// Output references. name is an entry of ScriptEngine's output list, or an
// empty string for the generic "out" reference. For setOutputDimValue, str is
// the second identifier of the reference.
72 virtual void setOutputPointer(T& v,
const std::string& name) = 0;
73 virtual void setOutputDimValue(T& v,
const std::string& name,
const std::string& str) = 0;
// Construction input for ScriptEngine (taken by its explicit constructor).
// All fields are plain string containers; their parsing happens inside
// ScriptEngine, which is not visible in this listing.
76 struct ScriptEngineParameters
// Script source strings keyed by name.
// NOTE(review): the key set is not visible here. The cudaExists() /
// cubinExists() / dxmlExists() / csExists() accessors suggest one entry per
// backend; confirm against the constructor.
81 std::map<std::string, std::string> code;
// Text of the output-dimension (NCHW) calculation. Presumably parsed into
// ScriptEngine::dimSyntree(); confirm.
83 std::string dimNCHWCalculation;
// Forward-evaluation call text keyed by id (see ScriptEngine::fwdEvalCall(id)).
85 std::map<std::string, std::string> fwdEvalCall;
// Raw script parameter values keyed by parameter name. ScriptEngine also keeps
// a parsed Lexem per parameter (parametersParsed()).
87 std::map<std::string, std::string> parameters;
// Option legend strings, in declaration order.
88 std::vector<std::string> optionLegends;
// Names of the outputs. Scripts can reference outputs by these names.
89 std::vector<std::string> outputList;
// Weights definitions keyed by weights name.
92 std::map<std::string, std::string> weights;
// Weights names in a fixed order. NOTE(review): presumably declaration order;
// confirm.
93 std::vector<std::string> orderedWeightsNames;
// Holds a parsed layer script: source code, parameters, options, weights and
// outputs. It implements ScriptInterpretInterface for two value types:
// std::string (presumably used for text generation via generateParameter) and
// AnyOperand.
96 class ScriptEngine :
public ScriptInterpretInterface<std::string>,
public ScriptInterpretInterface<AnyOperand>
99 explicit ScriptEngine(
const ScriptEngineParameters& engineParameters);
100 virtual ~ScriptEngine()
override
// Accessors for state built from ScriptEngineParameters. Their definitions
// are not visible in this listing; most return const references to members.
104 const std::map<std::string, std::string>& code()
const;
105 const std::vector<std::string>& outputList()
const;
106 const std::map<std::string, std::string>& weights()
const;
107 const std::vector<std::string>& orderedWeightsNames()
const;
109 const std::map<std::string, std::string>& parameters()
const;
110 const std::map<std::string, Lexem>& parametersParsed()
const;
111 const std::vector<std::string>& optionLegends()
const;
// Mutable access: callers may change option values through this reference.
112 std::map<std::string, std::string>& optionValues();
113 const std::string& optionType(
const std::string& name)
const;
// Per-id lookups (fwd syntax tree, entry point, operands, stepping, eval call).
// NOTE(review): behavior for an unknown id is not visible here.
115 const SynParser::Node& fwdSyntree(
const std::string&
id)
const;
116 const std::string& entry(
const std::string&
id)
const;
117 const std::vector<AnyOperand>& ops(
const std::string&
id)
const;
118 const TensorDimension& stepping(
const std::string&
id)
const;
119 const std::string& fwdEvalCall(
const std::string&
id)
const;
121 const std::string& dimNchwCalculation()
const;
122 const std::vector<SynParser::Node>& dimSyntree()
const;
// Mutable access to buffer argument bytes keyed by int. Presumably the key is
// the argument index used by setBufferPointer(); confirm.
124 std::map<int, std::vector<uint8_t>>& bufferArgs();
126 bool cudaExists()
const;
127 bool cubinExists()
const;
128 bool dxmlExists()
const;
129 bool csExists()
const;
// True once interpretParameter() has set m_actInternal for an activation
// reference.
131 bool actInternal()
const;
// Recursively evaluates the parsed parameter expression `node`, building a
// value of type T through the callbacks of `pInterpret`.
//
// @param pInterpret  callback sink that builds the value.
// @param node        root of the expression subtree to evaluate.
// @param call        when false, the callbacks for buffer pointers, CALL
//                    transforms, activation alphas, outputs and FLOAT
//                    parameters are not invoked.
// @param argIndex    argument position passed to setBufferPointer(); -1 (the
//                    default) disables buffer binding.
// @return the value accumulated for `node`.
//
// Not const: it sets m_actInternal, and its std::map::operator[] lookups can
// insert default entries into m_optionType and m_parametersParsed.
// NOTE(review): argIndex is not forwarded to the recursive calls, so only a
// top-level lexem can ever reach setBufferPointer().
134 T interpretParameter(ScriptInterpretInterface<T>* pInterpret,
const SynParser::Node& node,
bool call,
int argIndex = -1)
138 switch (node.lex.type)
// Numeric literal lexems are forwarded as their text.
143 pInterpret->setReal(rv, node.lex.value);
147 pInterpret->setInteger(rv, node.lex.value);
// A lexem is bound as a buffer pointer only when generating a call at a
// known argument position.
150 if (argIndex >= 0 && call)
152 pInterpret->setBufferPointer(rv, node.lex.value, argIndex);
// Compound nodes are distinguished by node.lex.value.
// UNAR: the operand is child[1] (child[0] presumably holds the operator).
156 if (node.lex.value ==
"UNAR")
158 rv = interpretParameter<T>(pInterpret, node.child[1], call);
161 else if (node.lex.value ==
"ADDSUB" || node.lex.value ==
"MULDIV")
// Children alternate operand, operator, operand, ...; fold left to right.
// NOTE(review): an operator other than + - * / is silently ignored (op is
// evaluated but not applied).
163 rv = interpretParameter<T>(pInterpret, node.child[0], call);
164 for (
size_t i = 1; i < node.child.size(); i += 2)
166 T op = interpretParameter<T>(pInterpret, node.child[i + 1], call);
167 if (node.child[i].lex.value ==
"+")
168 pInterpret->add(rv, op);
169 else if (node.child[i].lex.value ==
"-")
170 pInterpret->sub(rv, op);
171 else if (node.child[i].lex.value ==
"*")
172 pInterpret->mul(rv, op);
173 else if (node.child[i].lex.value ==
"/")
174 pInterpret->div(rv, op);
177 else if (node.lex.value ==
"OPTION")
// child[0] holds the option name. NOTE(review): operator[] inserts an empty
// type for an unknown option, which then takes the integer branch; a find()
// lookup would avoid mutating m_optionType.
181 if (m_optionType[node.child[0].lex.value] ==
"float")
183 pInterpret->setRealOption(rv, node.child[0].lex.value);
187 pInterpret->setIntegerOption(rv, node.child[0].lex.value);
191 else if (node.lex.value ==
"CALL")
// CALL: child[0] is the function name, child[1] its single argument.
193 if (node.child.size() == 2 && call)
195 rv = interpretParameter<T>(pInterpret, node.child[1], call);
196 pInterpret->transform(rv, node.child[0].lex.value);
199 else if (node.lex.value ==
"VAR")
// VAR: child[0] is the first identifier node (child[0].child[0] = name,
// optional child[0].child[1] = numeric index); the optional child[1] is a
// second identifier.
201 bool dual_id = node.child.size() > 1;
202 std::string
id = node.child[0].child[0].lex.value;
203 std::string second_id;
205 second_id = node.child[1].child[0].lex.value;
206 bool has_index = node.child[0].child.size() > 1;
// std::atoi yields 0 on malformed text; it does not report errors.
207 int index = has_index ? std::atoi(node.child[0].child[1].lex.value.c_str()) : 0;
// Activation references (the enclosing id test is not visible here).
213 if ((second_id ==
"alpha" || second_id ==
"alpha_left") && call)
214 pInterpret->setActAlhpaLeft(rv);
215 else if (second_id ==
"alpha_right" && call)
216 pInterpret->setActAlhpaRight(rv);
220 pInterpret->setActType(rv);
// Flag read back through actInternal().
222 m_actInternal =
true;
224 else if (
id ==
"weights" && dual_id)
226 pInterpret->setWeightsPointer(rv, second_id);
// Input references (the enclosing id test is not visible here). An explicit
// index or second identifier selects an indexed input pointer or a dimension
// value; otherwise input -1 is used.
230 if (dual_id || has_index)
238 pInterpret->setInputPointer(rv, index);
243 pInterpret->setInputDimValue(rv, index, second_id);
248 pInterpret->setInputDimValue(rv, -1, second_id);
253 pInterpret->setInputPointer(rv, -1);
// Outputs referenced by their declared name.
256 else if (indexOf(m_outputList,
id) >= 0 && call)
260 pInterpret->setOutputPointer(rv,
id);
264 pInterpret->setOutputDimValue(rv,
id, second_id);
// The generic "out" reference maps to the unnamed ("") output.
267 else if (
id ==
"out" && call)
271 pInterpret->setOutputPointer(rv,
"");
// out[index]: resolves to the index-th output name when it is in range,
// otherwise (presumably) to the unnamed output.
275 if (index <
int(m_outputList.size()))
277 pInterpret->setOutputDimValue(rv, m_outputList[index], second_id);
282 pInterpret->setOutputDimValue(rv,
"", second_id);
// Script parameters. Presence is checked in m_parameters, but the parsed value
// is read from m_parametersParsed via operator[].
// NOTE(review): if the two maps ever diverge, operator[] silently inserts a
// default Lexem.
285 else if (m_parameters.find(
id) != m_parameters.end())
287 if (m_parametersParsed[
id].type == Lexem::INTEGER)
289 pInterpret->setInteger(rv, m_parametersParsed[
id].value);
291 else if (m_parametersParsed[
id].type == Lexem::FLOAT && call)
293 pInterpret->setReal(rv, m_parametersParsed[
id].value);
295 else if (m_parametersParsed[
id].type == Lexem::DIMENSION && dual_id)
// DIMENSION values are 'x'-separated components; the second identifier
// selects x, y or z.
297 std::vector<std::string> dim =
split(m_parametersParsed[
id].value,
'x',
false);
// Loop body not visible here; presumably pads dim to three components so
// dim[0..2] below are valid.
298 for (
size_t i = dim.size(); i < 3; i++)
302 if (second_id ==
"x")
304 pInterpret->setInteger(rv, dim[0]);
306 else if (second_id ==
"y")
308 pInterpret->setInteger(rv, dim[1]);
310 else if (second_id ==
"z")
312 pInterpret->setInteger(rv, dim[2]);
// ScriptInterpretInterface<AnyOperand> implementation.
325 void setReal(AnyOperand& v,
const std::string& )
override;
326 void setInteger(AnyOperand& v,
const std::string& )
override;
327 void setBufferPointer(AnyOperand& v,
const std::string& str,
int argIndex)
override;
328 void neg(AnyOperand& v)
override;
329 void add(AnyOperand& v, AnyOperand& op)
override;
330 void sub(AnyOperand& v, AnyOperand& op)
override;
331 void mul(AnyOperand& v, AnyOperand& op)
override;
332 void div(AnyOperand& v, AnyOperand& op)
override;
333 void setRealOption(AnyOperand& v,
const std::string& )
override;
334 void setIntegerOption(AnyOperand& v,
const std::string& )
override;
335 void transform(AnyOperand& v,
const std::string& str)
override;
336 void setActType(AnyOperand& v)
override;
337 void setActAlhpaLeft(AnyOperand& v)
override;
338 void setActAlhpaRight(AnyOperand& v)
override;
339 void setWeightsPointer(AnyOperand& v,
const std::string& str)
override;
340 void setInputPointer(AnyOperand& v,
int index)
override;
341 void setInputDimValue(AnyOperand& v,
int index,
const std::string& str)
override;
342 void setOutputPointer(AnyOperand& v,
const std::string& name)
override;
343 void setOutputDimValue(AnyOperand& v,
const std::string& name,
const std::string& str)
override;
// Renders an expression node as a std::string.
// NOTE(review): the definition is not visible. Presumably it drives
// interpretParameter<std::string> with this engine and stores pointerCast in
// m_pointerCast; confirm the meaning of inPlugin.
345 std::string generateParameter(
const SynParser::Node& node,
bool call,
const std::string& pointerCast,
bool inPlugin,
int argIndex);
// ScriptInterpretInterface<std::string> implementation.
347 void setReal(std::string& v,
const std::string& str)
override;
348 void setInteger(std::string& v,
const std::string& str)
override;
349 void setBufferPointer(std::string& v,
const std::string& str,
int argIndex)
override;
350 void neg(std::string& v)
override;
351 void add(std::string& v, std::string& op)
override;
352 void sub(std::string& v, std::string& op)
override;
353 void mul(std::string& v, std::string& op)
override;
354 void div(std::string& v, std::string& op)
override;
355 void setRealOption(std::string& v,
const std::string& str)
override;
356 void setIntegerOption(std::string& v,
const std::string& str)
override;
357 void transform(std::string& v,
const std::string& str)
override;
358 void setActType(std::string& v)
override;
359 void setActAlhpaLeft(std::string& v)
override;
360 void setActAlhpaRight(std::string& v)
override;
361 void setWeightsPointer(std::string& v,
const std::string& str)
override;
362 void setInputPointer(std::string& v,
int index)
override;
363 void setInputDimValue(std::string& v,
int index,
const std::string& str)
override;
364 void setOutputPointer(std::string& v,
const std::string& name)
override;
365 void setOutputDimValue(std::string& v,
const std::string& name,
const std::string& str)
override;
// Pointer-cast text; presumably the pointerCast argument of generateParameter.
372 std::string m_pointerCast;
// Set by interpretParameter() for activation references; see actInternal().
377 bool m_actInternal =
false;
// Backing flags for cudaExists() / cubinExists() / dxmlExists() / csExists().
379 bool m_cudaCodeExists =
false;
380 bool m_cubinCodeExists =
false;
381 bool m_dxmlCodeExists =
false;
382 bool m_csCodeExists =
false;
// NOTE(review): no in-class initializer; it must be set by the constructor.
384 int m_numberOfInputs;
386 std::map<std::string, std::string> m_code;
388 std::string m_dimNCHWCalculation;
389 std::map<std::string, std::string> m_fwdEvalCall;
391 std::map<std::string, std::string> m_parameters;
392 std::vector<std::string> m_optionLegends;
// Option name -> declared type; interpretParameter() compares it to "float".
393 std::map<std::string, std::string> m_optionType;
394 std::map<std::string, std::string> m_optionValues;
// Parameter name -> parsed Lexem (type INTEGER / FLOAT / DIMENSION is used by
// interpretParameter()).
396 std::map<std::string, Lexem> m_parametersParsed;
398 std::map<std::string, SynParser::Node> m_fwdSyntreeAll;
399 std::vector<SynParser::Node> m_dimSyntree;
// NOTE(review): per-axis N/C/H/W flags; their writers and readers are not
// visible here.
402 bool m_genCheckN =
false;
403 bool m_genCheckC =
false;
404 bool m_genCheckH =
false;
405 bool m_genCheckW =
false;
// Output names, searched by interpretParameter() for named and indexed outputs.
407 std::vector<std::string> m_outputList;
409 std::map<std::string, std::vector<AnyOperand> > m_opsAll;
410 std::map<std::string, std::string> m_entryAll;
411 std::map<int, std::vector<uint8_t>> m_bufferArgs;
413 std::map<std::string, std::string> m_weights;
414 std::map<std::string, SynParser::Node> m_weightsSyntree;
415 std::vector<std::string> m_orderedWeightsNames;
417 std::map<std::string, TensorDimension> m_steppingValueAll;
std::vector< std::string > split(std::string strToSplit, char delimiter, bool skip_empty=false)
Splits a string by a delimiter, returns a vector of the split strings.
Definition: CoreHelpers.h:155
Fundamental NvNeural data types are declared here.
Internal helper classes for parsing script declarations.
Internal helper classes for parsing script declarations.
Internal helper classes for parsing script declarations.