26 #ifndef NVNEURAL_MODELPREPROCESSOR_H
27 #define NVNEURAL_MODELPREPROCESSOR_H
30 #include <Pugixml/pugixml.hpp>
32 #include <unordered_map>
33 #include <unordered_set>
36 namespace nvneural {
namespace detail {
44 class XmlModelPreprocessor
47 explicit XmlModelPreprocessor(XmlModelPreprocessor* pParent =
nullptr,
bool inlineModels =
false);
48 ~XmlModelPreprocessor();
51 void addCustomLayer(
const std::string& layer);
53 enum class FusingRuleResult : std::uint16_t
59 FusingRuleResult addFusingScript(
const std::string& script);
62 FusingRuleResult addScopedFusingScript(
const TensorFormat& currentScope,
const std::string& script);
63 void setScriptDebug(
bool enable);
65 void addCutLayer(
const std::string name);
67 bool loadModelText(
const std::string& text,
bool autoUpdateModelVersion =
true);
68 bool loadModelFile(
const std::string& fileName,
bool autoUpdateModelVersion =
true);
69 bool generateLinked(pugi::xml_document& doc);
70 std::string serializeCurrent()
const;
72 bool supportedByEditor()
const;
74 static void pushIncludeDirectory(
const std::string& path);
82 const std::vector<ioDesc>& getOutputs()
const;
83 const std::vector<ioDesc>& getInputs()
const;
85 int searchOutputDescIndex(
const std::string& layerName)
const;
88 static void cloneXmlNode(pugi::xml_node& dest,
const pugi::xml_node& src);
89 static bool isCustomScripted(
const pugi::xml_node& layer);
90 static bool isInputTypeLayer(
const std::string& type);
92 const static std::vector<std::string> m_prototypeCode;
94 int currentModelVersion()
const;
95 static int actualModelVersion();
96 void updateModelVersion();
98 std::string getLastError()
const;
100 std::unordered_map<std::string, std::vector<std::string>> getInnerInlinedTemplates()
const;
101 std::unordered_map<std::string, std::string> getSynonyms()
const {
return m_synonyms; }
108 enum TypeCode : uint8_t
119 TypeCode type = FATAL_ERROR;
122 std::vector<rulex> lexParser(
const std::string& text);
124 struct ruleSyntaxTreeNode
141 TypeCode type = TypeCode::FATAL_ERROR;
144 std::vector<ruleSyntaxTreeNode> params;
146 ruleSyntaxTreeNode syntaxParser(std::vector<rulex>& lexemes);
147 ruleSyntaxTreeNode syntaxParseLayer(std::vector<rulex>& lexemes,
size_t& offset);
148 ruleSyntaxTreeNode syntaxParseConditionOr(std::vector<rulex>& lexemes,
size_t& offset);
149 ruleSyntaxTreeNode syntaxParseInputs(std::vector<rulex>& lexemes,
size_t& offset);
150 ruleSyntaxTreeNode syntaxParseSelection(std::vector<rulex>& lexemes,
size_t& offset);
151 ruleSyntaxTreeNode syntaxParseConditionAnd(std::vector<rulex>& lexemes,
size_t& offset);
152 ruleSyntaxTreeNode syntaxParseComparison(std::vector<rulex>& lexemes,
size_t& offset);
153 ruleSyntaxTreeNode syntaxParseOperand(std::vector<rulex>& lexemes,
size_t& offset);
154 ruleSyntaxTreeNode syntaxParseRange(std::vector<rulex>& lexemes,
size_t& offset);
155 void syntaxError(
const ruleSyntaxTreeNode& errorResult,
size_t offset)
const;
157 bool m_debugFusingScript =
false;
158 void printSyntaxTree(
const ruleSyntaxTreeNode& node,
int indent)
const;
160 std::vector<ruleSyntaxTreeNode> m_fusingScripts;
162 static std::vector<std::string> s_includeDirs;
163 static std::string findFile(
const std::string& fileName);
165 XmlModelPreprocessor* m_pParent;
167 std::string m_fileName;
168 bool checkExternalCollision(
const std::string& fileName)
const;
170 std::string m_lastError;
171 void setError(
const std::string& text);
172 void setError(
const std::string& text,
const pugi::xml_node& node);
173 void addError(
const std::string& text);
175 struct modelLayerNode
177 std::string finalName;
178 std::vector<std::string> inputList;
179 pugi::xml_node xmlNode;
180 pugi::xml_node templateNode;
183 std::string inputLayerName(std::string input);
184 std::string inputLayerName(modelLayerNode &node,
size_t index);
186 std::vector<std::string> m_inputs;
187 std::vector<ioDesc> m_inputDescriptors;
188 std::unordered_map<std::string, modelLayerNode> m_table;
189 std::vector<std::string> m_outputs;
190 std::unordered_map<std::string, std::string> m_outputLink;
191 std::vector<ioDesc> m_outputDescriptors;
192 std::unordered_map<std::string, std::string> m_customLayers;
193 std::unordered_set<std::string> m_cutLayers;
194 std::unordered_map<std::string, std::vector<std::string>> m_innerInlinedTemplateLayers;
195 std::unordered_map<std::string, std::string> m_synonyms;
197 std::vector<pugi::xml_node> m_extraNodes;
199 bool tryFusingLayer(modelLayerNode& layer, ruleSyntaxTreeNode& rule);
200 std::string collectFusingInputs(modelLayerNode& layer, std::vector<std::string>& fusingInputs);
201 bool interpretLayer(modelLayerNode& layer, ruleSyntaxTreeNode& rule, std::vector<std::string>& fused);
202 bool interpretCondition(modelLayerNode& layer, ruleSyntaxTreeNode& rule);
203 bool interpretInputs(modelLayerNode& layer, ruleSyntaxTreeNode& rule, std::vector<std::string>& fused,
size_t& in_index);
204 bool interpretSelection(modelLayerNode& layer, ruleSyntaxTreeNode& rule, std::vector<std::string>& fused,
size_t& in_index);
205 bool interpretParameter(modelLayerNode& layer, ruleSyntaxTreeNode& rule, rulex& op);
206 bool interpretCheckBreak(ruleSyntaxTreeNode& rule,
int count);
208 XmlModelPreprocessor* getSubgraphAsModel(
209 const std::vector<std::string>& inputs,
210 const std::vector<std::string>& outputs);
213 pugi::xml_node& network,
214 const XmlModelPreprocessor* pTempModel,
215 std::unordered_map<std::string, std::string> &synonym,
216 const std::string& name,
217 pugi::xml_node& layer,
218 const std::string& input,
219 bool isInnerTemplate);
221 pugi::xml_document m_model;
223 bool parseModel(pugi::xml_node networkModel,
int version);
224 bool backWayCollector(
225 const std::string& node,
226 std::unordered_set<std::string>& collection,
227 std::vector<std::string>& controlStack,
228 const std::vector<std::string>* pStopList);
230 std::unordered_map<std::string, std::unique_ptr<XmlModelPreprocessor>> m_innerTemplates;
231 std::unordered_map<std::string, std::unique_ptr<XmlModelPreprocessor>> m_externTemplates;
232 std::unordered_map<std::string, std::unique_ptr<XmlModelPreprocessor>> m_patternTemplates;
234 enum class FusingRuleLayoutScope
241 enum class FusingRuleDataTypeScope
248 bool preprocessFusingScriptScope(std::vector<rulex>& lexemes, FusingRuleLayoutScope& layoutScope, FusingRuleDataTypeScope& dataTypeScope)
const;
249 bool checkAndAddFusingScript(
const std::vector<rulex> lexemes);
250 bool checkFusingScript(
const std::vector<rulex>& lexemes);
Fundamental NvNeural data types are declared here.
@ Float
32-bit floating point elements (float)
@ Half
16-bit floating point elements (__half)
@ Success
Operation succeeded. Generic result.
@ Failure
Operation failed. Generic result.