Graph Surgeon
graphsurgeon allows you to transform TensorFlow graphs. Its capabilities are broadly divided into two categories: search and manipulation. Search functions allow you to find nodes in a TensorFlow graph. Manipulation functions allow you to modify, add, or remove nodes.
Node Creation
These functions allow you to create free-standing TensorFlow nodes, which can be used as stand-ins for plugins.
- graphsurgeon.create_node(name, op=None, trt_plugin=False, **kwargs)
 Creates a free-standing TensorFlow NodeDef with the specified properties.
- Parameters
 name (str) – The name of the node.
op (str) – The node’s operation.
- Keyword Arguments
 dtype (tensorflow.DType) – TensorFlow dtype.
shape (tuple(int)) – Iterable container (usually a tuple) describing the shape of a tensor.
inputs (list(tensorflow.NodeDef) or str) – Iterable container (usually a tuple) of input nodes or input node names. Supports mixed-type lists.
**kwargs (AttrName=Value) – Any additional fields that should be present in the node. Currently supports int, float, bool, list(int), list(float), str and NumPy arrays. NumPy arrays will be inserted into the “value” attribute of the node - this can be useful for creating constant nodes equivalent to those created by tensorflow.constant.
- Returns
 tensorflow.NodeDef
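The inputs keyword argument accepts a mix of nodes and node names. The following sketch models that normalization in plain Python so that it runs without TensorFlow. ToyNode and toy_create_node are hypothetical stand-ins for tensorflow.NodeDef and graphsurgeon.create_node, and attributes are simplified to a plain dict rather than TensorFlow attribute protos. It illustrates the documented behavior only and is not graphsurgeon’s implementation.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional, Union


@dataclass
class ToyNode:
    """Hypothetical stand-in for tensorflow.NodeDef."""
    name: str
    op: Optional[str] = None
    inputs: List[str] = field(default_factory=list)
    attrs: Dict[str, Any] = field(default_factory=dict)


def toy_create_node(name: str, op: Optional[str] = None,
                    inputs: Iterable[Union[ToyNode, str]] = (),
                    **kwargs: Any) -> ToyNode:
    """Hypothetical analogue of create_node: mixed inputs become names, kwargs become attrs."""
    # Nodes contribute their name; strings are taken as node names directly.
    input_names = [inp.name if isinstance(inp, ToyNode) else inp for inp in inputs]
    # Every remaining keyword argument is stored as a node attribute.
    return ToyNode(name=name, op=op, inputs=input_names, attrs=dict(kwargs))


# Usage: a placeholder node, then a plugin-like node fed by a node object and a name string.
image = toy_create_node("image", op="Placeholder", shape=(1, 3, 300, 300))
plugin = toy_create_node("my_plugin", op="MyPlugin", inputs=[image, "priors"], topK=100)
print(plugin.inputs)  # ['image', 'priors']
print(plugin.attrs)   # {'topK': 100}
```

The op name "MyPlugin" and the topK attribute are made up for the example.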
- graphsurgeon.create_plugin_node(name, op=None, **kwargs)
 Creates a free-standing TensorFlow NodeDef with the specified properties. This is similar to create_node.
- Parameters
 name (str) – The name of the node.
op (str) – The node’s operation.
dtype (tensorflow.DType) – TensorFlow dtype.
shape (tuple(int)) – Iterable container (usually a tuple) describing the shape of a tensor.
inputs (list(tensorflow.NodeDef) or str) – Iterable container (usually a tuple) of input nodes or input node names. Supports mixed-type lists.
**kwargs (AttrName=Value) – Any additional fields that should be present in the node. Currently supports int, float, bool, list(int), list(float) and str.
- Returns
 tensorflow.NodeDef
Static Graph
- class graphsurgeon.StaticGraph(graphdef=None)
 Acts as a thin wrapper for a read-only TensorFlow GraphDef. Supports indexing based on node name/index as well as iteration over nodes using Python’s “for node in static_graph” syntax.
- Parameters
 graphdef (tensorflow.GraphDef/tensorflow.Graph OR graphsurgeon.StaticGraph/graphsurgeon.DynamicGraph OR str) – A TensorFlow GraphDef/Graph or a StaticGraph/DynamicGraph from which to construct this graph, or a string containing a path to a frozen model.
- node_outputs
 A mapping of node names to their respective output nodes.
- Type
 dict(str, list(tensorflow.NodeDef))
- node_map
 A mapping of node names to their corresponding nodes.
- Type
 dict(str, tensorflow.NodeDef)
- graph_outputs
 A list of likely outputs of the graph.
- Type
 list(tensorflow.NodeDef)
- graph_inputs
 A list of likely inputs of the graph.
- Type
 list(tensorflow.NodeDef)
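To make node_map, node_outputs, graph_inputs and graph_outputs concrete, here is a plain-Python toy that derives all four from a list of nodes. The heuristics for the “likely” inputs and outputs (Placeholder ops, and nodes that nothing consumes) are assumptions chosen for illustration; this reference does not say how graphsurgeon decides. Input strings follow TensorFlow’s convention, where a leading “^” marks a control input and a “:N” suffix selects an output index. ToyNode, base_name and build_maps are hypothetical names.

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class ToyNode:
    name: str
    op: str
    inputs: List[str] = field(default_factory=list)


def base_name(input_str: str) -> str:
    # "^name" marks a control input and "name:N" selects output N; both refer to node "name".
    return input_str.lstrip("^").split(":")[0]


def build_maps(nodes: List[ToyNode]) -> Tuple[Dict[str, ToyNode], Dict[str, List[ToyNode]],
                                               List[ToyNode], List[ToyNode]]:
    node_map = {node.name: node for node in nodes}
    node_outputs: Dict[str, List[ToyNode]] = defaultdict(list)
    for node in nodes:
        for inp in node.inputs:
            node_outputs[base_name(inp)].append(node)
    # Assumed heuristics: Placeholders are likely inputs; unconsumed nodes are likely outputs.
    graph_inputs = [n for n in nodes if n.op == "Placeholder"]
    graph_outputs = [n for n in nodes if not node_outputs.get(n.name)]
    return node_map, dict(node_outputs), graph_inputs, graph_outputs


nodes = [
    ToyNode("x", "Placeholder"),
    ToyNode("w", "Const"),
    ToyNode("mm", "MatMul", ["x", "w"]),
    ToyNode("out", "Identity", ["mm:0"]),
]
node_map, node_outputs, graph_inputs, graph_outputs = build_maps(nodes)
print([n.name for n in node_outputs["mm"]])  # ['out']
print([n.name for n in graph_outputs])       # ['out']
```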
- as_graph_def()
 Returns this StaticGraph’s internal TensorFlow GraphDef.
- Parameters
 None
- Returns
 tensorflow.GraphDef
- find_node_chains_by_op(chain)
 Finds groups of nodes in this graph that match the specified sequence of ops. Returns a list of matching chains of nodes, with ordering preserved.
- Parameters
 chain (list(str)) – The sequence of ops to look for. Should be ordered with the input of the chain as the first element, and the output as the last.
- Returns
 list(list(tensorflow.NodeDef))
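The chain search can be sketched as a walk from each node whose op matches the first element, following consumers whose ops match the following elements. This toy assumes a chain matches when each node has the expected op and consumes the previous node in the chain; it enumerates every branch, so one start node can yield several chains. Whether graphsurgeon treats branching the same way is not stated here. ToyNode and find_chains_by_op are hypothetical names.

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyNode:
    name: str
    op: str
    inputs: List[str] = field(default_factory=list)


def find_chains_by_op(nodes: List[ToyNode], chain: List[str]) -> List[List[ToyNode]]:
    # Map each node name to the nodes that consume it (ignoring "^" and ":N" decorations).
    consumers: Dict[str, List[ToyNode]] = defaultdict(list)
    for node in nodes:
        for inp in node.inputs:
            consumers[inp.lstrip("^").split(":")[0]].append(node)

    matches: List[List[ToyNode]] = []
    if not chain:
        return matches

    def extend(path: List[ToyNode]) -> None:
        if len(path) == len(chain):
            matches.append(list(path))
            return
        # Continue only through consumers whose op is the next one in the chain.
        for nxt in consumers[path[-1].name]:
            if nxt.op == chain[len(path)]:
                path.append(nxt)
                extend(path)
                path.pop()

    for node in nodes:
        if node.op == chain[0]:
            extend([node])
    return matches


graph = [
    ToyNode("x", "Placeholder"),
    ToyNode("c", "Conv2D", ["x"]),
    ToyNode("b", "BiasAdd", ["c"]),
    ToyNode("r", "Relu", ["b"]),
    ToyNode("c2", "Conv2D", ["r"]),
    ToyNode("b2", "BiasAdd", ["c2"]),
]
print([[n.name for n in ch] for ch in find_chains_by_op(graph, ["Conv2D", "BiasAdd", "Relu"])])
# [['c', 'b', 'r']]
```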
- find_node_inputs(node)
 Finds input nodes of a given node.
- Parameters
 node (tensorflow.NodeDef) – The node in which to perform the search.
- Returns
 list(tensorflow.NodeDef)
- find_node_inputs_by_name(node, name)
 Finds input nodes of a given node based on their names.
- Parameters
 node (tensorflow.NodeDef) – The node in which to perform the search.
name (str OR list(str)) – The name to look for. Also accepts iterable containers (preferably a list) to search for multiple names in a single pass. Supports regular expressions.
- Returns
 list(tensorflow.NodeDef)
- find_node_inputs_by_op(node, op)
 Finds input nodes of a given node based on their ops.
- Parameters
 node (tensorflow.NodeDef) – The node in which to perform the search.
op (str OR list(str)) – The op to look for. Also accepts iterable containers (preferably a list) to search for multiple ops in a single pass.
- Returns
 list(tensorflow.NodeDef)
- find_nodes_by_name(name)
 Finds nodes in this graph based on their names.
- Parameters
 name (str OR list(str)) – The name to look for. Also accepts iterable containers (preferably a list) to search for multiple names in a single pass of the graph. Supports regular expressions.
- Returns
 list(tensorflow.NodeDef)
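A minimal sketch of a single-pass, multi-pattern name search. It assumes each pattern is applied with Python’s re.match, that is, anchored at the start of the node name but not at the end; graphsurgeon’s exact matching rule is not specified in this reference. ToyNode and find_by_name are hypothetical names.

```python
import re
from dataclasses import dataclass
from typing import Iterable, List, Union


@dataclass
class ToyNode:
    name: str
    op: str


def find_by_name(nodes: List[ToyNode], name: Union[str, Iterable[str]]) -> List[ToyNode]:
    # Accept a single pattern or a container of patterns.
    patterns = [name] if isinstance(name, str) else list(name)
    compiled = [re.compile(p) for p in patterns]
    # Single pass: keep a node if any pattern matches at the start of its name (assumed re.match).
    return [node for node in nodes if any(rx.match(node.name) for rx in compiled)]


nodes = [
    ToyNode("Preprocessor/sub", "Sub"),
    ToyNode("Preprocessor/mul", "Mul"),
    ToyNode("FeatureExtractor/conv", "Conv2D"),
    ToyNode("detection_out", "NMS"),
]
print([n.name for n in find_by_name(nodes, "Preprocessor/.*")])
# ['Preprocessor/sub', 'Preprocessor/mul']
```

Because every pattern is tested against each node as it is visited, passing a list searches for all names in one traversal rather than one traversal per name.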
- find_nodes_by_op(op)
 Finds nodes in this graph based on their ops.
- Parameters
 op (str OR set(str)) – The op to look for. Also accepts iterable containers (preferably a set) to search for multiple ops in a single pass of the graph.
- Returns
 list(tensorflow.NodeDef)
- find_nodes_by_path(path)
 Finds nodes in this graph based on their full paths. This will only match exact paths.
- Parameters
 path (str OR list(str)) – The path to look for. Also accepts iterable containers (preferably a list) to search for multiple paths in a single pass of the graph. Supports regular expressions.
- Returns
 list(tensorflow.NodeDef)
- read(filename)
 Reads a frozen protobuf file into this StaticGraph.
- Parameters
 filename (str) – Name of the protobuf file.
- Returns
 None
- write(filename)
 Writes the StaticGraph’s internal TensorFlow GraphDef into a frozen protobuf file.
- Parameters
 filename (str) – Name of the protobuf file to write.
- Returns
 None
- write_tensorboard(logdir)
 Writes the StaticGraph’s internal TensorFlow GraphDef into the specified directory, which can then be visualized in TensorBoard.
- Parameters
 logdir (str) – Name of the directory to write.
- Returns
 None
- Raises
 Warning – “Passing a GraphDef to the SummaryWriter is deprecated. Pass a Graph object instead, such as sess.graph.”
This warning is expected. There is currently no alternative, because TensorFlow cannot convert invalid GraphDefs back to Graphs.
Dynamic Graph (Inherits from StaticGraph)
- class graphsurgeon.DynamicGraph(graphdef=None)
 A sub-class of StaticGraph that can search and modify a TensorFlow GraphDef.
- Parameters
 graphdef (tensorflow.GraphDef/tensorflow.Graph OR graphsurgeon.StaticGraph/graphsurgeon.DynamicGraph OR str) – A TensorFlow GraphDef/Graph or a StaticGraph/DynamicGraph from which to construct this graph, or a string containing the path to a frozen model.
- append(node)
 Appends a node to this graph.
- Parameters
 node (tensorflow.NodeDef) – TensorFlow NodeDef to add to the graph.
- Returns
 None
- collapse_namespaces(namespace_map, exclude_nodes=[], unique_inputs=True)
 Collapses nodes in namespaces to single nodes specified by the user, except where those nodes are marked for exclusion.
- Parameters
 namespace_map (dict(str, tensorflow.NodeDef)) – A dictionary specifying namespaces and their corresponding plugin nodes. These plugin nodes are typically used to specify attributes of the custom plugin, while inputs and outputs are automatically deduced. Multiple namespaces can be collapsed into a single plugin node, and nested namespaces are collapsed into plugin nodes outside their parent namespaces.
exclude_nodes (list(tensorflow.NodeDef)) – Iterable container (usually a list) of nodes which should NOT be collapsed. These nodes will be present in the final graph as either inputs or outputs of the plugin nodes.
unique_inputs (bool) – Whether inputs to the collapsed node should be unique. If this is False, plugin nodes may have duplicate inputs.
- Returns
 None
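The collapsing behavior can be modeled in plain Python as follows. This is a simplified sketch, not graphsurgeon’s algorithm: namespaces are matched in dictionary order, nested namespaces are not handled specially, input strings are plain node names (no “:N” or “^” decorations), and each plugin takes the position of the first node it replaces. Edges that cross into a namespace become plugin inputs, and references from outside are redirected to the plugin. ToyNode, in_namespace and collapse are hypothetical names; the function returns copies and leaves its arguments unchanged.

```python
from dataclasses import dataclass, field
from typing import Dict, Iterable, List


@dataclass
class ToyNode:
    name: str
    op: str
    inputs: List[str] = field(default_factory=list)


def in_namespace(name: str, namespace: str) -> bool:
    return name == namespace or name.startswith(namespace + "/")


def collapse(nodes: List[ToyNode], namespace_map: Dict[str, ToyNode],
             exclude_nodes: Iterable[ToyNode] = (), unique_inputs: bool = True) -> List[ToyNode]:
    excluded = {n.name for n in exclude_nodes}
    owner: Dict[str, str] = {}        # collapsed node name -> plugin name
    plugins: Dict[str, ToyNode] = {}  # plugin name -> fresh copy that collects inputs
    for node in nodes:
        if node.name in excluded:
            continue
        for namespace, plugin in namespace_map.items():
            if in_namespace(node.name, namespace):
                owner[node.name] = plugin.name
                plugins.setdefault(plugin.name, ToyNode(plugin.name, plugin.op))
                break

    result: List[ToyNode] = []
    emitted = set()
    for node in nodes:
        if node.name not in owner:
            # Surviving node: references to collapsed nodes now point at the plugin.
            result.append(ToyNode(node.name, node.op, [owner.get(i, i) for i in node.inputs]))
            continue
        plugin = plugins[owner[node.name]]
        for inp in node.inputs:
            src = owner.get(inp, inp)
            if src == plugin.name:
                continue  # edge inside the collapsed region
            if unique_inputs and src in plugin.inputs:
                continue
            plugin.inputs.append(src)
        if plugin.name not in emitted:
            # The plugin takes the position of the first node it replaces.
            emitted.add(plugin.name)
            result.append(plugin)
    return result


graph = [
    ToyNode("input", "Placeholder"),
    ToyNode("Prior/a", "Const"),
    ToyNode("Prior/b", "Mul", ["Prior/a", "input"]),
    ToyNode("Prior/c", "Add", ["Prior/b", "input"]),
    ToyNode("head", "Identity", ["Prior/c"]),
]
prior = ToyNode("prior_plugin", "CustomPlugin")  # hypothetical plugin op
print([(n.name, n.inputs) for n in collapse(graph, {"Prior": prior})])
# [('input', []), ('prior_plugin', ['input']), ('head', ['prior_plugin'])]
```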
- extend(node_list)
 Extends this graph’s nodes based on the provided list.
- Parameters
 node_list (list(tensorflow.NodeDef)) – List of TensorFlow NodeDefs to add to the graph.
- Returns
 None
- forward_inputs(nodes)
 Removes nodes from this graph. Recursively forwards inputs, such that paths in the graph are preserved.
Warning: Nodes with control inputs are not removed, so as not to break the structure of the graph. If you need to forward these, remove their control inputs first.
- Parameters
 nodes (list(tensorflow.NodeDef)) – Iterable container (usually a list) of nodes which should be removed and whose inputs forwarded.
- Returns
 None
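Input forwarding can be sketched as replacing every reference to a removed node with that node’s own inputs, recursively, so that a chain of removed nodes collapses onto its first surviving ancestor. The toy mirrors the documented warning by keeping any node that has control inputs. It assumes plain node-name inputs (no “:N” suffixes), and ToyNode and forward_inputs are hypothetical names rather than the DynamicGraph method itself.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyNode:
    name: str
    op: str
    inputs: List[str] = field(default_factory=list)


def forward_inputs(nodes: List[ToyNode], to_forward: List[ToyNode]) -> List[ToyNode]:
    by_name: Dict[str, ToyNode] = {n.name: n for n in nodes}
    # Mirror the documented warning: nodes with control inputs ("^name") are not removed.
    removable = {n.name for n in to_forward
                 if not any(i.startswith("^") for i in by_name[n.name].inputs)}

    def resolve(name: str) -> List[str]:
        # A removed node is replaced by its own inputs, recursively.
        if name not in removable:
            return [name]
        forwarded: List[str] = []
        for inp in by_name[name].inputs:
            forwarded.extend(resolve(inp))
        return forwarded

    result: List[ToyNode] = []
    for node in nodes:
        if node.name in removable:
            continue
        new_inputs: List[str] = []
        for inp in node.inputs:
            new_inputs.extend(resolve(inp))
        result.append(ToyNode(node.name, node.op, new_inputs))
    return result


graph = [
    ToyNode("x", "Placeholder"),
    ToyNode("id1", "Identity", ["x"]),
    ToyNode("id2", "Identity", ["id1"]),
    ToyNode("relu", "Relu", ["id2"]),
]
print([(n.name, n.inputs) for n in forward_inputs(graph, [graph[1], graph[2]])])
# [('x', []), ('relu', ['x'])]
```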
- remove(nodes, remove_exclusive_dependencies=False)
 Removes nodes from this graph. Does not forward inputs, so paths in the graph could be broken.
- Parameters
 nodes (list(tensorflow.NodeDef)) – Iterable container (usually a list) of nodes which should be removed.
remove_exclusive_dependencies (bool) – Whether to also remove dependencies exclusive to the nodes about to be removed. When set to True, all exclusive dependencies will be removed recursively, and the number of hanging nodes in the graph will remain constant. Defaults to False.
- Returns
 None
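The remove_exclusive_dependencies option can be modeled as follows: after marking the requested nodes, repeatedly mark any input whose consumers are all already marked. Nodes that still feed a surviving node are kept, which is why the number of hanging nodes does not change. This is a sketch under the assumption that “exclusive” means “consumed only by nodes being removed”; ToyNode, base_name and remove are hypothetical names.

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Set


@dataclass
class ToyNode:
    name: str
    op: str
    inputs: List[str] = field(default_factory=list)


def base_name(input_str: str) -> str:
    # Strip TensorFlow's "^" control prefix and ":N" output suffix.
    return input_str.lstrip("^").split(":")[0]


def remove(nodes: List[ToyNode], to_remove: List[ToyNode],
           remove_exclusive_dependencies: bool = False) -> List[ToyNode]:
    by_name = {n.name: n for n in nodes}
    doomed: Set[str] = {n.name for n in to_remove}
    if remove_exclusive_dependencies:
        consumers: Dict[str, Set[str]] = defaultdict(set)
        for node in nodes:
            for inp in node.inputs:
                consumers[base_name(inp)].add(node.name)
        stack = list(doomed)
        while stack:
            for inp in by_name[stack.pop()].inputs:
                dep = base_name(inp)
                # Exclusive dependency: every consumer of dep is already being removed.
                if dep in by_name and dep not in doomed and consumers[dep] <= doomed:
                    doomed.add(dep)
                    stack.append(dep)
    # Inputs are not forwarded, so survivors may still reference removed nodes.
    return [n for n in nodes if n.name not in doomed]


graph = [
    ToyNode("x", "Placeholder"),
    ToyNode("w", "Const"),
    ToyNode("mm", "MatMul", ["x", "w"]),
    ToyNode("b", "Const"),
    ToyNode("add", "Add", ["mm", "b"]),
    ToyNode("out", "Identity", ["add"]),
    ToyNode("other", "Relu", ["x"]),
]
print([n.name for n in remove(graph, [graph[4]], remove_exclusive_dependencies=True)])
# ['x', 'out', 'other']
```

Here x survives because "other" still consumes it, while w, mm and b fed only the removed Add node.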