26 #ifndef NVNEURAL_MODELPREPROCESSOR_H
27 #define NVNEURAL_MODELPREPROCESSOR_H
30 #include <Pugixml/pugixml.hpp>
32 #include <unordered_map>
33 #include <unordered_set>
36 namespace nvneural {
namespace detail {
// XmlModelPreprocessor: declarations for a preprocessor that works on
// pugixml-based XML model documents (the loaded model is held in m_model,
// a pugi::xml_document). The member list shows support for custom layers,
// cut layers, nested/external/pattern templates (m_innerTemplates,
// m_externTemplates, m_patternTemplates) and a small "fusing script"
// language with its own lexer, parser and interpreter.
//
// NOTE(review): this text looks like residue extracted from rendered
// documentation. The class braces, access specifiers, enum/struct bodies,
// and at least one member-function name (before the parameter list whose
// embedded number is 208) are missing, and every line carries an embedded
// original line number. It will not compile as-is; restore from the real
// header before making code changes.
//
// NOTE(review): std::string, std::vector, std::unique_ptr and std::uint16_t
// are used below, but only <unordered_map>, <unordered_set> and pugixml are
// visibly included (some include lines were dropped by the extraction);
// confirm <string>, <vector>, <memory> and <cstdint> are included.
44 class XmlModelPreprocessor
// Constructor. pParent defaults to nullptr; presumably it becomes m_pParent
// (parent preprocessor for nested models) -- TODO confirm in the definition.
// inlineModels defaults to false; its effect is not visible here.
47 explicit XmlModelPreprocessor(XmlModelPreprocessor* pParent =
nullptr,
bool inlineModels =
false);
// Destructor; declaration only, body not visible here.
48 ~XmlModelPreprocessor();
// Registers a custom layer given as a string (presumably XML text).
// m_customLayers (string -> string map) is likely where it is stored.
51 void addCustomLayer(
const std::string& layer);
// Result code returned by addFusingScript / addScopedFusingScript.
// Enumerators are not visible (embedded numbering jumps from 53 to 59).
53 enum class FusingRuleResult : std::uint16_t
// Adds a fusing script from source text; returns a FusingRuleResult.
59 FusingRuleResult addFusingScript(
const std::string& script);
// Same as addFusingScript, but also takes the TensorFormat currently in
// scope (presumably used to filter scope-restricted rules -- TODO confirm).
62 FusingRuleResult addScopedFusingScript(
const TensorFormat& currentScope,
const std::string& script);
// Enables/disables fusing-script debugging; presumably sets
// m_debugFusingScript (declared below, default false).
63 void setScriptDebug(
bool enable);
// Marks a layer name as a cut layer (m_cutLayers is a string set).
// NOTE(review): `const std::string name` copies the argument; the other
// string setters take `const std::string&`. Switching to a reference
// changes the declared signature, so it must be done together with the
// out-of-view definition.
65 void addCutLayer(
const std::string name);
// Loads a model from XML text. autoUpdateModelVersion defaults to true.
// Returns bool; presumably false on failure with details in getLastError().
67 bool loadModelText(
const std::string& text,
bool autoUpdateModelVersion =
true);
// Loads a model from a file path. autoUpdateModelVersion defaults to true.
// Returns bool; presumably false on failure with details in getLastError().
68 bool loadModelFile(
const std::string& fileName,
bool autoUpdateModelVersion =
true);
// Produces the processed model into the caller-supplied pugi document
// (non-const reference, so presumably written to). Returns bool.
69 bool generateLinked(pugi::xml_document& doc);
// Returns the current model serialized as a std::string (const member).
70 std::string serializeCurrent()
const;
// Const query returning bool; criteria not visible here.
72 bool supportedByEditor()
const;
// Static: adds a search path. It presumably appends to s_includeDirs,
// which findFile() would consult. This is state shared by all instances,
// and no synchronization is visible in this declaration.
74 static void pushIncludeDirectory(
const std::string& path);
// Read-only access to the output / input descriptor lists (ioDesc is
// declared in lines not shown). Presumably backed by m_outputDescriptors /
// m_inputDescriptors, which have the same element type.
82 const std::vector<ioDesc>& getOutputs()
const;
83 const std::vector<ioDesc>& getInputs()
const;
// Looks up an output descriptor by layer name and returns an int index;
// the not-found value is not visible here (commonly -1 -- TODO confirm).
85 int searchOutputDescIndex(
const std::string& layerName)
const;
// Static XML helpers operating on pugi nodes / layer type strings.
88 static void cloneXmlNode(pugi::xml_node& dest,
const pugi::xml_node& src);
89 static bool isCustomScripted(
const pugi::xml_node& layer);
90 static bool isInputTypeLayer(
const std::string& type);
// Static, const list of strings; contents defined elsewhere.
// NOTE(review): `static const` is the conventional specifier order.
92 const static std::vector<std::string> m_prototypeCode;
// Model-version helpers: currentModelVersion() is a const per-instance
// query, actualModelVersion() is static, and updateModelVersion() mutates.
// Presumably this compares the loaded model's version to the latest
// supported one -- TODO confirm.
94 int currentModelVersion()
const;
95 static int actualModelVersion();
96 void updateModelVersion();
// Returns the last recorded error text by value (presumably m_lastError).
98 std::string getLastError()
const;
// Token/node type codes for the fusing-script language; enumerators not
// shown except FATAL_ERROR. This is an unscoped enum, so both `FATAL_ERROR`
// and `TypeCode::FATAL_ERROR` (used below) name the same enumerator.
// NOTE(review): `uint8_t` is unqualified here while FusingRuleResult uses
// `std::uint16_t`; consider `std::uint8_t` for consistency.
105 enum TypeCode : uint8_t
// Member of a struct whose header line is missing (presumably `rulex`, the
// lexeme type used below); defaults to FATAL_ERROR.
116 TypeCode type = FATAL_ERROR;
// Splits fusing-script text into a vector of rulex lexemes.
119 std::vector<rulex> lexParser(
const std::string& text);
// Parse-tree node for fusing scripts: a type code (default FATAL_ERROR)
// and a vector of child nodes (params). Other members are not shown.
121 struct ruleSyntaxTreeNode
138 TypeCode type = TypeCode::FATAL_ERROR;
141 std::vector<ruleSyntaxTreeNode> params;
// Parser entry point plus one function per grammar construct. Each takes
// the lexeme vector and a shared `offset` by reference and returns a node,
// which suggests a recursive-descent parser (bodies not visible).
143 ruleSyntaxTreeNode syntaxParser(std::vector<rulex>& lexemes);
144 ruleSyntaxTreeNode syntaxParseLayer(std::vector<rulex>& lexemes,
size_t& offset);
145 ruleSyntaxTreeNode syntaxParseConditionOr(std::vector<rulex>& lexemes,
size_t& offset);
146 ruleSyntaxTreeNode syntaxParseInputs(std::vector<rulex>& lexemes,
size_t& offset);
147 ruleSyntaxTreeNode syntaxParseSelection(std::vector<rulex>& lexemes,
size_t& offset);
148 ruleSyntaxTreeNode syntaxParseConditionAnd(std::vector<rulex>& lexemes,
size_t& offset);
149 ruleSyntaxTreeNode syntaxParseComparison(std::vector<rulex>& lexemes,
size_t& offset);
150 ruleSyntaxTreeNode syntaxParseOperand(std::vector<rulex>& lexemes,
size_t& offset);
151 ruleSyntaxTreeNode syntaxParseRange(std::vector<rulex>& lexemes,
size_t& offset);
// Const error reporter for a failed parse node at a lexeme offset.
152 void syntaxError(
const ruleSyntaxTreeNode& errorResult,
size_t offset)
const;
// Fusing-script debug flag; defaults to false.
154 bool m_debugFusingScript =
false;
// Const debug printer for a parse tree at a given indent.
155 void printSyntaxTree(
const ruleSyntaxTreeNode& node,
int indent)
const;
// Parsed fusing scripts held by this instance.
157 std::vector<ruleSyntaxTreeNode> m_fusingScripts;
// Static include-directory list and a static file lookup helper.
159 static std::vector<std::string> s_includeDirs;
160 static std::string findFile(
const std::string& fileName);
// Raw pointer to a parent preprocessor. Presumably non-owning (owned
// children are held via std::unique_ptr in the template maps below).
162 XmlModelPreprocessor* m_pParent;
164 std::string m_fileName;
// Const check taking a file name; presumably detects an external template
// referencing a file already being processed up the parent chain -- TODO
// confirm.
165 bool checkExternalCollision(
const std::string& fileName)
const;
// Error state and setters: setError overloads (one also takes the
// offending XML node) and addError.
167 std::string m_lastError;
168 void setError(
const std::string& text);
169 void setError(
const std::string& text,
const pugi::xml_node& node);
170 void addError(
const std::string& text);
// Per-layer record: final name, list of input names, and two pugi node
// handles (the layer's XML node and a template node).
172 struct modelLayerNode
174 std::string finalName;
175 std::vector<std::string> inputList;
176 pugi::xml_node xmlNode;
177 pugi::xml_node templateNode;
// Helpers that resolve an input layer name, either from a raw string or
// from a node's input at a given index.
180 std::string inputLayerName(std::string input);
181 std::string inputLayerName(modelLayerNode &node,
size_t index);
// Model graph state: input names and descriptors, layer table keyed by
// name, output names, output link map, output descriptors, custom layers,
// cut layers, and extra XML nodes.
183 std::vector<std::string> m_inputs;
184 std::vector<ioDesc> m_inputDescriptors;
185 std::unordered_map<std::string, modelLayerNode> m_table;
186 std::vector<std::string> m_outputs;
187 std::unordered_map<std::string, std::string> m_outputLink;
188 std::vector<ioDesc> m_outputDescriptors;
189 std::unordered_map<std::string, std::string> m_customLayers;
190 std::unordered_set<std::string> m_cutLayers;
192 std::vector<pugi::xml_node> m_extraNodes;
// Fusing-script interpreter: applies a parsed rule to a layer. The
// interpret* functions take the rule tree plus accumulators (fused layer
// names, a running input index) by reference and return bool.
194 bool tryFusingLayer(modelLayerNode& layer, ruleSyntaxTreeNode& rule);
195 std::string collectFusingInputs(modelLayerNode& layer, std::vector<std::string>& fusingInputs);
196 bool interpretLayer(modelLayerNode& layer, ruleSyntaxTreeNode& rule, std::vector<std::string>& fused);
197 bool interpretCondition(modelLayerNode& layer, ruleSyntaxTreeNode& rule);
198 bool interpretInputs(modelLayerNode& layer, ruleSyntaxTreeNode& rule, std::vector<std::string>& fused,
size_t& in_index);
199 bool interpretSelection(modelLayerNode& layer, ruleSyntaxTreeNode& rule, std::vector<std::string>& fused,
size_t& in_index);
200 bool interpretParameter(modelLayerNode& layer, ruleSyntaxTreeNode& rule, rulex& op);
201 bool interpretCheckBreak(ruleSyntaxTreeNode& rule,
int count);
// Builds a model from the subgraph between the given inputs and outputs.
// NOTE(review): returns a raw pointer; ownership of the result is not
// visible here. If the caller owns it, std::unique_ptr would make that
// explicit.
203 XmlModelPreprocessor* getSubgraphAsModel(
204 const std::vector<std::string>& inputs,
205 const std::vector<std::string>& outputs);
// NOTE(review): the return type and name of the following declaration were
// lost in extraction (embedded line 206/207); only its parameter list
// remains.
208 const XmlModelPreprocessor* pTempModel,
209 std::unordered_map<std::string, std::string> &synonym,
210 const std::string& name,
211 pugi::xml_node& layer,
212 const std::string& input);
// The loaded XML model document.
214 pugi::xml_document m_model;
// Parses the network-model node for a given version. pugi::xml_node is a
// lightweight handle, so passing it by value is cheap.
216 bool parseModel(pugi::xml_node networkModel,
int version);
// Walks backward from `node`, filling `collection` and using
// `controlStack`; pStopList is an optional (pointer) list of names to stop
// at -- TODO confirm traversal semantics in the definition.
217 bool backWayCollector(
218 const std::string& node,
219 std::unordered_set<std::string>& collection,
220 std::vector<std::string>& controlStack,
221 const std::vector<std::string>* pStopList);
// Owned sub-preprocessors for inner, external and pattern templates,
// keyed by name.
223 std::unordered_map<std::string, std::unique_ptr<XmlModelPreprocessor>> m_innerTemplates;
224 std::unordered_map<std::string, std::unique_ptr<XmlModelPreprocessor>> m_externTemplates;
225 std::unordered_map<std::string, std::unique_ptr<XmlModelPreprocessor>> m_patternTemplates;
// Scope qualifiers that a fusing script can declare (enumerators not
// shown).
227 enum class FusingRuleLayoutScope
234 enum class FusingRuleDataTypeScope
// Extracts layout/data-type scope from the lexemes into the two out-params
// (lexemes passed by non-const reference, so possibly modified).
241 bool preprocessFusingScriptScope(std::vector<rulex>& lexemes, FusingRuleLayoutScope& layoutScope, FusingRuleDataTypeScope& dataTypeScope)
const;
// Validates and stores / validates a lexed fusing script.
// NOTE(review): checkAndAddFusingScript takes the vector by const value
// (a full copy), while checkFusingScript takes `const&`. Switching to a
// reference changes the signature and must be done together with the
// definition.
242 bool checkAndAddFusingScript(
const std::vector<rulex> lexemes);
243 bool checkFusingScript(
const std::vector<rulex>& lexemes);
Fundamental NvNeural data types are declared here.
- Float: 32-bit floating-point elements (`float`).
- Half: 16-bit floating-point elements (`__half`).
- Success: The operation succeeded (generic result).
- Failure: The operation failed (generic result).