26 #ifndef NVNEURAL_MODELPREPROCESSOR_H
27 #define NVNEURAL_MODELPREPROCESSOR_H
30 #include <Pugixml/pugixml.hpp>
32 #include <unordered_map>
33 #include <unordered_set>
36 namespace nvneural {
namespace detail {
// NOTE(review): this text is a lossy extraction of the original header.
// Listing line numbers are fused into the start of many lines (e.g. "44", "47").
// The class's opening brace, access specifiers and closing "};" are not present.
// Restore them from the original header before compiling.
// NOTE(review): this class uses std::string, std::vector, std::shared_ptr /
// std::unique_ptr and std::uint16_t. However, <string>, <vector>, <memory> and
// <cstdint> are not among the includes visible in this listing; confirm they
// arrive via the include lines that were omitted.
/// Preprocessor for XML network models. Members declared in this class cover:
/// loading a model from text or file (loadModelText / loadModelFile),
/// registering layer-fusing scripts (addFusingScript / addScopedFusingScript),
/// and producing a linked pugixml document (generateLinked).
44 class XmlModelPreprocessor
/// Constructs a preprocessor.
/// @param pParent       Optional parent preprocessor; defaults to nullptr.
///                      NOTE(review): presumably stored in m_pParent as a
///                      non-owning pointer. Confirm against the definition.
/// @param inlineModels  Defaults to false. Its exact effect is not visible here
///                      (the name suggests referenced models get inlined). TODO confirm.
47 explicit XmlModelPreprocessor(XmlModelPreprocessor* pParent =
nullptr,
bool inlineModels =
false);
/// Destructor, defined out of line.
/// NOTE(review): because a destructor is user-declared, no move constructor or
/// move assignment is implicitly declared. The pugi::xml_document member
/// (m_model) is non-copyable in pugixml, so the implicit copy operations are
/// deleted as well. Unless copy/move members are declared in a part of the
/// header not shown here, instances are neither copyable nor movable.
48 ~XmlModelPreprocessor();
/// Registers a custom layer description given as a string.
/// NOTE(review): presumably recorded in m_customLayers. Confirm in the definition.
51 void addCustomLayer(
const std::string& layer);
/// Outcome of adding a fusing script; the underlying type is std::uint16_t.
/// NOTE(review): the enumerators and braces of this enum are missing from this
/// listing. Restore them from the original header.
53 enum class FusingRuleResult : std::uint16_t
/// Parses a layer-fusing script and adds it to this preprocessor.
/// @param script  Fusing-script source text.
/// @return        A FusingRuleResult describing the outcome.
59 FusingRuleResult addFusingScript(
const std::string& script);
/// Adds a layer-fusing script together with a TensorFormat scope.
/// @param currentScope  Tensor format the script is evaluated against.
/// @param script        Fusing-script source text.
/// @return              A FusingRuleResult describing the outcome.
/// NOTE(review): presumably the script's declared layout / data-type scope is
/// compared with currentScope, and the script is skipped on mismatch. Confirm.
62 FusingRuleResult addScopedFusingScript(
const TensorFormat& currentScope,
const std::string& script);
/// Enables or disables debug output for fusing-script processing.
/// NOTE(review): presumably assigns m_debugFusingScript.
63 void setScriptDebug(
bool enable);
/// Marks a layer, by name, as a "cut" layer.
/// NOTE(review): presumably inserted into m_cutLayers.
/// NOTE(review): the parameter is a by-value `const std::string`, which copies
/// the argument on every call. `const std::string&` would avoid the copy and
/// match the rest of this class. Change it together with the out-of-line
/// definition so the signatures still agree.
65 void addCutLayer(
const std::string name);
/// Loads a model from XML text held in memory.
/// @param text                    Model XML text.
/// @param autoUpdateModelVersion  Defaults to true. Presumably triggers
///                                updateModelVersion() for older models. TODO confirm.
/// @return true on success. NOTE(review): presumably failure details are then
///         available from getLastError(). Confirm.
67 bool loadModelText(
const std::string& text,
bool autoUpdateModelVersion =
true);
/// Loads a model from an XML file.
/// @param fileName                Path of the model file.
/// @param autoUpdateModelVersion  Defaults to true; same role as in loadModelText.
/// @return true on success.
68 bool loadModelFile(
const std::string& fileName,
bool autoUpdateModelVersion =
true);
/// Produces the linked model into @p doc, which is taken by non-const
/// reference as an output.
/// @return true on success.
69 bool generateLinked(pugi::xml_document& doc);
/// Returns the current model serialized to a string (presumably XML text).
70 std::string serializeCurrent()
const;
/// Reports whether the current model is supported by the editor.
/// The criteria are not visible in this header.
72 bool supportedByEditor()
const;
/// Static: adds @p path to a directory list shared by all instances.
/// NOTE(review): presumably appends to s_includeDirs, which findFile() searches.
/// No synchronization is visible here, so concurrent calls may not be safe
/// unless the definition adds locking.
74 static void pushIncludeDirectory(
const std::string& path);
// NOTE(review): the `ioDesc` type used below is not declared anywhere in this
// listing. It presumably lives in the omitted original lines 76-82.
/// Returns the model's output descriptors by const reference.
/// The reference refers to storage owned by this object (presumably
/// m_outputDescriptors), so it is only valid while this object is alive.
83 const std::vector<ioDesc>& getOutputs()
const;
/// Returns the model's input descriptors by const reference
/// (presumably m_inputDescriptors). The same lifetime caveat as getOutputs() applies.
84 const std::vector<ioDesc>& getInputs()
const;
/// Looks up an output descriptor by layer name and returns its index.
/// NOTE(review): the "not found" return value is not visible here (commonly -1). Confirm.
86 int searchOutputDescIndex(
const std::string& layerName)
const;
/// Static: copies @p src into @p dest (presumably a deep copy of the subtree).
89 static void cloneXmlNode(pugi::xml_node& dest,
const pugi::xml_node& src);
/// Static: tests whether @p layer is a custom-scripted layer.
90 static bool isCustomScripted(
const pugi::xml_node& layer);
/// Static: tests whether the layer type name @p type denotes an input-type layer.
91 static bool isInputTypeLayer(
const std::string& type);
/// Static, read-only list of strings (contents defined out of line).
/// NOTE(review): `static const` is the conventional specifier order. This
/// static also carries the `m_` prefix, while the other static member in this
/// class (s_includeDirs) uses `s_`.
93 const static std::vector<std::string> m_prototypeCode;
/// Model-format version of the currently loaded model.
95 int currentModelVersion()
const;
/// Static: the model-format version this preprocessor targets (presumably the latest).
96 static int actualModelVersion();
/// Upgrades the loaded model's version (presumably up to actualModelVersion()).
97 void updateModelVersion();
/// Returns a copy of the last recorded error text (presumably m_lastError).
99 std::string getLastError()
const;
/// Returns, by value, a map from string keys to lists of strings describing
/// inlined inner templates.
101 std::unordered_map<std::string, std::vector<std::string>> getInnerInlinedTemplates()
const;
/// Returns a copy of m_synonyms.
/// NOTE(review): the whole map is copied on every call. Callers that query it
/// repeatedly should keep the result in a local variable.
102 std::unordered_map<std::string, std::string> getSynonyms()
const {
return m_synonyms; }
/// Returns, by value, a map of inner templates.
/// Copying the shared_ptr values bumps their reference counts, and the returned
/// pointers share ownership with this object's entries (presumably m_innerTemplates).
104 std::unordered_map<std::string, std::shared_ptr<XmlModelPreprocessor>> getInnerTemplates()
const;
// ---- Fusing-script lexer / parser ----
/// Type codes used by lexemes and syntax-tree nodes.
/// NOTE(review): the enumerators and the enum's braces are missing from this
/// listing; at least FATAL_ERROR exists, since it is referenced below.
/// `uint8_t` is used unqualified here, whereas FusingRuleResult uses
/// `std::uint16_t`. Prefer `std::uint8_t` for consistency and portability.
111 enum TypeCode : uint8_t
// NOTE(review): this member appears to belong to the `rulex` lexeme struct.
// That struct's declaration line and its other members are missing from this
// listing.
122 TypeCode type = FATAL_ERROR;
/// Splits fusing-script @p text into a sequence of lexemes.
125 std::vector<rulex> lexParser(
const std::string& text);
/// Node of a parsed fusing-rule syntax tree: a type code (default FATAL_ERROR)
/// plus child nodes in `params`.
/// NOTE(review): the struct's braces and any other members are missing here.
127 struct ruleSyntaxTreeNode
144 TypeCode type = TypeCode::FATAL_ERROR;
147 std::vector<ruleSyntaxTreeNode> params;
/// Builds a syntax tree from @p lexemes.
149 ruleSyntaxTreeNode syntaxParser(std::vector<rulex>& lexemes);
// Per-construct parse helpers (Layer, ConditionOr, Inputs, Selection,
// ConditionAnd, Comparison, Operand, Range). Each takes the lexeme vector and a
// cursor `offset` by reference; the cursor is presumably advanced past the
// tokens the helper consumes. The Or / And / Comparison / Operand naming
// suggests a precedence-layered recursive-descent parser. TODO confirm against
// the definitions.
150 ruleSyntaxTreeNode syntaxParseLayer(std::vector<rulex>& lexemes,
size_t& offset);
151 ruleSyntaxTreeNode syntaxParseConditionOr(std::vector<rulex>& lexemes,
size_t& offset);
152 ruleSyntaxTreeNode syntaxParseInputs(std::vector<rulex>& lexemes,
size_t& offset);
153 ruleSyntaxTreeNode syntaxParseSelection(std::vector<rulex>& lexemes,
size_t& offset);
154 ruleSyntaxTreeNode syntaxParseConditionAnd(std::vector<rulex>& lexemes,
size_t& offset);
155 ruleSyntaxTreeNode syntaxParseComparison(std::vector<rulex>& lexemes,
size_t& offset);
156 ruleSyntaxTreeNode syntaxParseOperand(std::vector<rulex>& lexemes,
size_t& offset);
157 ruleSyntaxTreeNode syntaxParseRange(std::vector<rulex>& lexemes,
size_t& offset);
/// Reports the syntax error described by @p errorResult at position @p offset
/// (presumably a lexeme index).
158 void syntaxError(
const ruleSyntaxTreeNode& errorResult,
size_t offset)
const;
/// Debug flag for fusing-script processing; false by default.
/// Presumably set by setScriptDebug().
160 bool m_debugFusingScript =
false;
/// Prints @p node and its children; @p indent presumably sets the indentation depth.
161 void printSyntaxTree(
const ruleSyntaxTreeNode& node,
int indent)
const;
/// Parsed fusing-rule syntax trees.
163 std::vector<ruleSyntaxTreeNode> m_fusingScripts;
/// Static search-directory list shared by every instance. It needs an
/// out-of-class definition in a source file.
165 static std::vector<std::string> s_includeDirs;
/// Static: resolves @p fileName to a path, presumably by searching s_includeDirs.
166 static std::string findFile(
const std::string& fileName);
/// Parent preprocessor. This is a raw pointer, so the type expresses no
/// ownership; presumably the parent outlives this object.
/// NOTE(review): there is no default member initializer. If a constructor ever
/// leaves it unset, its value is indeterminate. Consider `= nullptr`.
168 XmlModelPreprocessor* m_pParent;
/// Name of the model file (presumably the one passed to loadModelFile()).
170 std::string m_fileName;
/// Checks @p fileName for a collision with an external reference.
/// The exact criterion is not visible here. TODO confirm.
171 bool checkExternalCollision(
const std::string& fileName)
const;
/// Text of the most recent error.
173 std::string m_lastError;
/// Records an error message.
/// NOTE(review): the names suggest that setError replaces the stored error and
/// addError appends to it. Confirm against the definitions.
174 void setError(
const std::string& text);
/// Records an error message associated with an XML node (presumably to add
/// location or context information).
175 void setError(
const std::string& text,
const pugi::xml_node& node);
176 void addError(
const std::string& text);
/// Record for a template layer that took part in a fusion.
/// NOTE(review): this struct's braces and some members (listing lines 179-180,
/// 183-184) are missing.
/// NOTE(review): `fuseIndex` has no default member initializer, unlike `isFused`
/// in modelLayerNode. A default-initialized instance leaves it indeterminate.
/// Consider `{}`.
178 struct fusedTemplateLayerNode
181 pugi::xml_node templateNode;
182 std::string templateInstanceScope;
185 std::size_t fuseIndex;
/// Entry in the model's layer table (the value type of m_table): final layer
/// name, input list, XML/template node handles, template scope, and fusion state.
/// NOTE(review): the struct's braces and some members (listing lines 189,
/// 194-196, 199) are missing.
188 struct modelLayerNode
190 std::string finalName;
191 std::vector<std::string> inputList;
192 pugi::xml_node xmlNode;
193 pugi::xml_node templateNode;
197 std::string templateInstanceScope;
198 bool isFused =
false;
200 std::vector<fusedTemplateLayerNode> fusedTemplateLayerNodes;
/// Resolves an input reference string to a layer name.
/// The string is taken by value; presumably it is modified locally.
203 std::string inputLayerName(std::string input);
/// Resolves the @p index-th input of @p node to a layer name
/// (presumably node.inputList[index]).
204 std::string inputLayerName(modelLayerNode &node,
size_t index);
/// Names of the model's inputs (presumably layer names).
206 std::vector<std::string> m_inputs;
/// Input descriptors (presumably returned by getInputs()).
207 std::vector<ioDesc> m_inputDescriptors;
/// Layer table keyed by string (presumably layer name).
208 std::unordered_map<std::string, modelLayerNode> m_table;
/// Names of the model's outputs.
209 std::vector<std::string> m_outputs;
/// String-to-string output link map. The direction of the mapping is not
/// visible here. TODO confirm.
210 std::unordered_map<std::string, std::string> m_outputLink;
/// Output descriptors (presumably returned by getOutputs()).
211 std::vector<ioDesc> m_outputDescriptors;
/// Custom layer descriptions (presumably filled by addCustomLayer()).
212 std::unordered_map<std::string, std::string> m_customLayers;
/// Names of layers marked as cut (presumably filled by addCutLayer()).
213 std::unordered_set<std::string> m_cutLayers;
/// Layers belonging to inlined inner templates, keyed by string.
214 std::unordered_map<std::string, std::vector<std::string>> m_innerInlinedTemplateLayers;
/// Name synonyms; getSynonyms() returns a copy of this map.
215 std::unordered_map<std::string, std::string> m_synonyms;
/// Additional XML node handles.
/// NOTE(review): pugi::xml_node is a non-owning handle. These stay valid only
/// while their owning document is alive and the nodes have not been removed.
217 std::vector<pugi::xml_node> m_extraNodes;
// ---- Fusing-rule interpreter ----
/// Attempts to apply the parsed fusing @p rule starting at @p layer.
/// @return presumably true when a fusion was performed. TODO confirm.
219 bool tryFusingLayer(modelLayerNode& layer, ruleSyntaxTreeNode& rule);
/// Fills @p fusingInputs with inputs of @p layer and returns a string
/// (presumably a layer name). The exact contract is not visible here.
220 std::string collectFusingInputs(modelLayerNode& layer, std::vector<std::string>& fusingInputs);
// Interpreter helpers, one per syntax-tree construct. The `fused` out-vector
// presumably accumulates the names of matched layers, and `in_index` is a
// cursor into the layer's input list.
221 bool interpretLayer(modelLayerNode& layer, ruleSyntaxTreeNode& rule, std::vector<std::string>& fused);
222 bool interpretCondition(modelLayerNode& layer, ruleSyntaxTreeNode& rule);
223 bool interpretInputs(modelLayerNode& layer, ruleSyntaxTreeNode& rule, std::vector<std::string>& fused,
size_t& in_index);
224 bool interpretSelection(modelLayerNode& layer, ruleSyntaxTreeNode& rule, std::vector<std::string>& fused,
size_t& in_index);
/// Evaluates a parameter reference of @p layer for @p rule, writing into @p op.
225 bool interpretParameter(modelLayerNode& layer, ruleSyntaxTreeNode& rule, rulex& op);
/// Checks a break / termination condition of @p rule for @p count.
/// The semantics are not visible here. TODO confirm.
226 bool interpretCheckBreak(ruleSyntaxTreeNode& rule,
int count);
/// Builds a model covering the subgraph between @p inputs and @p outputs.
/// NOTE(review): the result is a raw pointer, so ownership is not expressed in
/// the type. If the caller must delete it, returning std::unique_ptr would make
/// that explicit. Confirm against the definition and its callers.
228 XmlModelPreprocessor* getSubgraphAsModel(
229 const std::vector<std::string>& inputs,
230 const std::vector<std::string>& outputs);
// NOTE(review): the lines below are the parameter list of a member function
// whose return type and name (original listing line 232) are missing from this
// extraction. Restore that line from the original header.
233 pugi::xml_node& network,
234 const XmlModelPreprocessor* pTempModel,
235 std::unordered_map<std::string, std::string> &synonym,
236 const std::string& name,
237 pugi::xml_node& layer,
238 const std::string& input,
239 bool isInnerTemplate);
/// XML document owned by this preprocessor.
/// NOTE(review): pugixml node handles stored elsewhere in this class
/// (presumably including modelLayerNode::xmlNode) become invalid once this
/// document is destroyed or reloaded.
241 pugi::xml_document m_model;
/// Parses the network element @p networkModel, whose format version is @p version.
/// pugi::xml_node is a lightweight handle, so passing it by value is cheap.
243 bool parseModel(pugi::xml_node networkModel,
int version);
/// Walks the graph starting from @p node and adds names to @p collection, using
/// @p controlStack during the walk. @p pStopList is a pointer and may therefore
/// be null (presumably meaning "no stop list").
/// The name suggests a traversal backward through layer inputs. TODO confirm.
244 bool backWayCollector(
245 const std::string& node,
246 std::unordered_set<std::string>& collection,
247 std::vector<std::string>& controlStack,
248 const std::vector<std::string>* pStopList);
/// Inner templates. They are held by shared_ptr, and getInnerTemplates()
/// returns copies that share ownership.
250 std::unordered_map<std::string, std::shared_ptr<XmlModelPreprocessor>> m_innerTemplates;
/// External templates, exclusively owned by this object.
251 std::unordered_map<std::string, std::unique_ptr<XmlModelPreprocessor>> m_externTemplates;
/// Pattern templates, exclusively owned by this object.
252 std::unordered_map<std::string, std::unique_ptr<XmlModelPreprocessor>> m_patternTemplates;
/// Tensor-layout scope a fusing script may declare.
/// NOTE(review): the enumerators and braces are missing from this listing.
254 enum class FusingRuleLayoutScope
/// Data-type scope a fusing script may declare.
/// NOTE(review): the enumerators and braces are missing from this listing.
261 enum class FusingRuleDataTypeScope
/// Extracts a script's declared layout and data-type scope from @p lexemes
/// into the two out-parameters. @p lexemes is non-const, so it is presumably
/// modified (e.g. scope tokens removed). TODO confirm.
268 bool preprocessFusingScriptScope(std::vector<rulex>& lexemes, FusingRuleLayoutScope& layoutScope, FusingRuleDataTypeScope& dataTypeScope)
const;
/// Validates @p lexemes and, presumably on success, adds the parsed script.
/// NOTE(review): the vector is taken by value as `const`, which copies it on
/// every call. checkFusingScript below takes `const std::vector<rulex>&`.
/// Change it together with the out-of-line definition.
269 bool checkAndAddFusingScript(
const std::vector<rulex> lexemes);
/// Validates a fusing script given as @p lexemes.
270 bool checkFusingScript(
const std::vector<rulex>& lexemes);
/*
 * Fundamental NvNeural data types are declared here.
 *   Float   - 32-bit floating-point elements (float)
 *   Half    - 16-bit floating-point elements (__half)
 *   Success - Operation succeeded. Generic result.
 *   Failure - Operation failed. Generic result.
 */