Graph Surgeon
graphsurgeon allows you to transform TensorFlow graphs. Its capabilities are broadly divided into two categories: search and manipulation. Search functions allow you to find nodes in a TensorFlow graph. Manipulation functions allow you to modify, add, or remove nodes.
Node Creation
These functions allow you to create free-standing TensorFlow nodes, which can be used as stand-ins for plugins.
- graphsurgeon.create_node(name, op=None, trt_plugin=False, **kwargs)
Creates a free-standing TensorFlow NodeDef with the specified properties.
- Parameters
name (str) – The name of the node.
op (str) – The node’s operation.
- Keyword Arguments
dtype (tensorflow.DType) – TensorFlow dtype.
shape (tuple(int)) – Iterable container (usually a tuple) describing the shape of a tensor.
inputs (list(tensorflow.NodeDef) or str) – Iterable container (usually a tuple) of input nodes or input node names. Supports mixed-type lists.
**kwargs (AttrName=Value) – Any additional fields that should be present in the node. Currently supports int, float, bool, list(int), list(float), str, and NumPy arrays. NumPy arrays are inserted into the “value” attribute of the node, which can be useful for creating constant nodes equivalent to those created by tensorflow.constant.
- Returns
tensorflow.NodeDef
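The “Supports mixed-type lists” note on inputs means that node objects and node names can be combined in one container. The sketch below illustrates that idea in plain Python. It assumes a hypothetical Node dataclass that mirrors only the name, op, and input fields of tensorflow.NodeDef, plus a hypothetical make_node helper; it is not graphsurgeon’s implementation, and the real create_node returns an actual tensorflow.NodeDef.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for tensorflow.NodeDef (name, op and input only)."""
    name: str
    op: str = ""
    input: list = field(default_factory=list)


def make_node(name, op=None, inputs=()):
    """Hypothetical helper: build a node whose inputs mix Node objects and names."""
    if isinstance(inputs, (str, Node)):
        # A single input is accepted as well as a container of inputs.
        inputs = [inputs]
    # NodeDefs store inputs as names, so node objects are reduced to their names.
    input_names = [i.name if isinstance(i, Node) else i for i in inputs]
    return Node(name=name, op=op or "", input=input_names)


const = Node("weights", "Const")
plugin = make_node("my_plugin", "MyPluginOp", inputs=[const, "input:0"])
print(plugin.input)  # ['weights', 'input:0']
```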
- graphsurgeon.create_plugin_node(name, op=None, **kwargs)
Creates a free-standing TensorFlow NodeDef with the specified properties. This is similar to create_node.
- Parameters
name (str) – The name of the node.
op (str) – The node’s operation.
- Keyword Arguments
dtype (tensorflow.DType) – TensorFlow dtype.
shape (tuple(int)) – Iterable container (usually a tuple) describing the shape of a tensor.
inputs (list(tensorflow.NodeDef) or str) – Iterable container (usually a tuple) of input nodes or input node names. Supports mixed-type lists.
**kwargs (AttrName=Value) – Any additional fields that should be present in the node. Currently supports int, float, bool, list(int), list(float) and str.
- Returns
tensorflow.NodeDef
Static Graph
- class graphsurgeon.StaticGraph(graphdef=None)
Acts as a thin wrapper for a read-only TensorFlow GraphDef. Supports indexing based on node name/index as well as iteration over nodes using Python’s “for node in static_graph” syntax.
- Parameters
graphdef (tensorflow.GraphDef/tensorflow.Graph OR graphsurgeon.StaticGraph/graphsurgeon.DynamicGraph OR str) – A TensorFlow GraphDef/Graph or a StaticGraph from which to construct this graph, or a string containing a path to a frozen model.
- node_outputs
A mapping of node names to their respective output nodes.
- Type
dict(str, list(tensorflow.NodeDef))
- node_map
A mapping of node names to their corresponding nodes.
- Type
dict(str, tensorflow.NodeDef)
- graph_outputs
A list of likely outputs of the graph.
- Type
list(tensorflow.NodeDef)
- graph_inputs
A list of likely inputs of the graph.
- Type
list(tensorflow.NodeDef)
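All four attributes above can be derived from the node list of a GraphDef. The following sketch shows one way to compute them, assuming a hypothetical Node stand-in for tensorflow.NodeDef. The heuristics for “likely” inputs (Placeholder ops) and outputs (nodes with no consumers) are assumptions made for illustration, not a description of how graphsurgeon decides.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for tensorflow.NodeDef (name, op and input only)."""
    name: str
    op: str = ""
    input: list = field(default_factory=list)


def source_name(inp):
    """Strip a control-input "^" prefix and an output-index ":N" suffix."""
    return inp.lstrip("^").split(":")[0]


def build_maps(nodes):
    """Compute node_map, node_outputs, graph_inputs and graph_outputs."""
    node_map = {n.name: n for n in nodes}
    node_outputs = {n.name: [] for n in nodes}
    for n in nodes:
        for inp in n.input:
            src = source_name(inp)
            if src in node_outputs and n not in node_outputs[src]:
                node_outputs[src].append(n)
    # Assumed heuristics: placeholders feed the graph, unconsumed nodes end it.
    graph_inputs = [n for n in nodes if n.op == "Placeholder"]
    graph_outputs = [n for n in nodes if not node_outputs[n.name]]
    return node_map, node_outputs, graph_inputs, graph_outputs


nodes = [
    Node("x", "Placeholder"),
    Node("w", "Const"),
    Node("mm", "MatMul", ["x", "w"]),
    Node("relu", "Relu", ["mm:0"]),
]
node_map, node_outputs, graph_inputs, graph_outputs = build_maps(nodes)
print([n.name for n in graph_outputs])  # ['relu']
```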
- as_graph_def()
Returns this StaticGraph’s internal TensorFlow GraphDef.
- Parameters
None
- Returns
tensorflow.GraphDef
- find_node_chains_by_op(chain)
Finds groups of nodes in this graph that match the specified sequence of ops. Returns a list of matching chains of nodes, with ordering preserved.
- Parameters
chain (list(str)) – The sequence of ops to look for. Should be ordered with the input of the chain as the first element, and the output as the last.
- Returns
list(list(tensorflow.NodeDef))
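A chain search can be pictured as a depth-first walk along consumer edges that matches one op per step. The sketch below is an illustrative reimplementation under that assumption, using a hypothetical Node stand-in for tensorflow.NodeDef; graphsurgeon’s own traversal may differ, for example in how it treats control inputs, which this sketch ignores.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for tensorflow.NodeDef (name, op and input only)."""
    name: str
    op: str = ""
    input: list = field(default_factory=list)


def source_name(inp):
    """Strip a control-input "^" prefix and an output-index ":N" suffix."""
    return inp.lstrip("^").split(":")[0]


def find_chains_by_op(nodes, chain):
    """Return node lists whose ops match `chain`, ordered from input to output."""
    if not chain:
        return []
    consumers = {n.name: [] for n in nodes}
    for n in nodes:
        for inp in n.input:
            if inp.startswith("^"):
                continue  # control inputs carry no data, so they never extend a chain
            src = source_name(inp)
            if src in consumers and n not in consumers[src]:
                consumers[src].append(n)

    def extend(path):
        # Depth-first: try every consumer whose op matches the next chain entry.
        if len(path) == len(chain):
            return [path]
        matches = []
        for nxt in consumers[path[-1].name]:
            if nxt.op == chain[len(path)]:
                matches.extend(extend(path + [nxt]))
        return matches

    return [match for n in nodes if n.op == chain[0] for match in extend([n])]


nodes = [
    Node("in", "Placeholder"),
    Node("c1", "Conv2D", ["in"]),
    Node("b1", "BiasAdd", ["c1"]),
    Node("r1", "Relu", ["b1"]),
    Node("c2", "Conv2D", ["in"]),
    Node("r2", "Relu", ["c2"]),
]
for match in find_chains_by_op(nodes, ["Conv2D", "BiasAdd", "Relu"]):
    print([n.name for n in match])  # ['c1', 'b1', 'r1']
```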
- find_node_inputs(node)
Finds input nodes of a given node.
- Parameters
node (tensorflow.NodeDef) – The node in which to perform the search.
- Returns
list(tensorflow.NodeDef)
- find_node_inputs_by_name(node, name)
Finds input nodes of a given node based on their names.
- Parameters
node (tensorflow.NodeDef) – The node in which to perform the search.
name (str OR list(str)) – The name to look for. Also accepts iterable containers (preferably a list) to search for multiple names in a single pass. Supports regular expressions.
- Returns
list(tensorflow.NodeDef)
- find_node_inputs_by_op(node, op)
Finds input nodes of a given node based on their ops.
- Parameters
node (tensorflow.NodeDef) – The node in which to perform the search.
op (str OR list(str)) – The op to look for. Also accepts iterable containers (preferably a list) to search for multiple ops in a single pass.
- Returns
list(tensorflow.NodeDef)
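The three input-search methods above share one step: resolving a node’s input names to the nodes that produce them. A minimal sketch follows, assuming a hypothetical Node stand-in for tensorflow.NodeDef and a name-to-node dictionary like node_map. Name patterns are matched here with re.fullmatch; whether graphsurgeon anchors its regular expressions the same way is an assumption.

```python
import re
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for tensorflow.NodeDef (name, op and input only)."""
    name: str
    op: str = ""
    input: list = field(default_factory=list)


def source_name(inp):
    """Strip a control-input "^" prefix and an output-index ":N" suffix."""
    return inp.lstrip("^").split(":")[0]


def find_node_inputs(node_map, node):
    """Resolve each input name of `node` to the node that produces it."""
    return [node_map[source_name(i)] for i in node.input if source_name(i) in node_map]


def find_node_inputs_by_name(node_map, node, name):
    """Keep inputs whose full name matches any of the given regex patterns."""
    patterns = [name] if isinstance(name, str) else list(name)
    return [
        inp for inp in find_node_inputs(node_map, node)
        if any(re.fullmatch(p, inp.name) for p in patterns)
    ]


def find_node_inputs_by_op(node_map, node, op):
    """Keep inputs whose op is one of the requested ops."""
    ops = {op} if isinstance(op, str) else set(op)
    return [inp for inp in find_node_inputs(node_map, node) if inp.op in ops]


x = Node("input/x", "Placeholder")
w = Node("weights/w", "Const")
mm = Node("matmul", "MatMul", ["input/x", "weights/w:0"])
node_map = {n.name: n for n in (x, w, mm)}
print([n.name for n in find_node_inputs_by_name(node_map, mm, "weights/.*")])  # ['weights/w']
print([n.name for n in find_node_inputs_by_op(node_map, mm, ["Placeholder"])])  # ['input/x']
```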
- find_nodes_by_name(name)
Finds nodes in this graph based on their names.
- Parameters
name (str OR list(str)) – The name to look for. Also accepts iterable containers (preferably a list) to search for multiple names in a single pass of the graph. Supports regular expressions.
- Returns
list(tensorflow.NodeDef)
- find_nodes_by_op(op)
Finds nodes in this graph based on their ops.
- Parameters
op (str OR set(str)) – The op to look for. Also accepts iterable containers (preferably a set) to search for multiple ops in a single pass of the graph.
- Returns
list(tensorflow.NodeDef)
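The preference for sets follows from how a single-pass search works: each node’s op is checked for membership in the requested collection, and set membership is a constant-time check on average. A minimal sketch, assuming a hypothetical Node class that mirrors the name and op fields of tensorflow.NodeDef (not graphsurgeon’s code):

```python
from dataclasses import dataclass


@dataclass
class Node:
    """Hypothetical stand-in for tensorflow.NodeDef (name and op only)."""
    name: str
    op: str = ""


def find_nodes_by_op(nodes, op):
    """One pass over `nodes`, keeping those whose op is in the requested set."""
    # Wrap a lone string: set("Conv2D") would be a set of single characters.
    ops = {op} if isinstance(op, str) else set(op)
    return [n for n in nodes if n.op in ops]


nodes = [Node("a", "Conv2D"), Node("b", "Relu"), Node("c", "Relu6"), Node("d", "Add")]
print([n.name for n in find_nodes_by_op(nodes, {"Relu", "Relu6"})])  # ['b', 'c']
```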
- find_nodes_by_path(path)
Finds nodes in this graph based on their full paths. This will only match exact paths.
- Parameters
path (str OR list(str)) – The path to look for. Also accepts iterable containers (preferably a list) to search for multiple paths in a single pass of the graph. Supports regular expressions.
- Returns
list(tensorflow.NodeDef)
- read(filename)
Reads a frozen protobuf file into this StaticGraph.
- Parameters
filename (str) – Name of the protobuf file.
- Returns
None
- write(filename)
Writes the StaticGraph’s internal TensorFlow GraphDef into a frozen protobuf file.
- Parameters
filename (str) – Name of the protobuf file to write.
- Returns
None
- write_tensorboard(logdir)
Writes the StaticGraph’s internal TensorFlow GraphDef into the specified directory, which can then be visualized in TensorBoard.
- Parameters
logdir (str) – Name of the directory to write.
- Returns
None
- Raises
Warning – TensorFlow emits the warning “Passing a GraphDef to the SummaryWriter is deprecated. Pass a Graph object instead, such as sess.graph.”
This is a known warning, but there is currently no alternative, since TensorFlow cannot convert invalid GraphDefs back to Graphs.
Dynamic Graph (Inherits from StaticGraph)
- class graphsurgeon.DynamicGraph(graphdef=None)
A sub-class of StaticGraph that can search and modify a TensorFlow GraphDef.
- Parameters
graphdef (tensorflow.GraphDef/tensorflow.Graph OR graphsurgeon.StaticGraph/graphsurgeon.DynamicGraph OR str) – A TensorFlow GraphDef/Graph or a StaticGraph/DynamicGraph from which to construct this graph, or a string containing the path to a frozen model.
- append(node)
Appends a node to this graph.
- Parameters
node (tensorflow.NodeDef) – TensorFlow NodeDef to add to the graph.
- Returns
None
- collapse_namespaces(namespace_map, exclude_nodes=[], unique_inputs=True)
Collapses nodes in namespaces to single nodes specified by the user, except where those nodes are marked for exclusion.
- Parameters
namespace_map (dict(str, tensorflow.NodeDef)) – A dictionary specifying namespaces and their corresponding plugin nodes. These plugin nodes are typically used to specify attributes of the custom plugin, while inputs and outputs are automatically deduced. Multiple namespaces can be collapsed into a single plugin node, and nested namespaces are collapsed into plugin nodes outside their parent namespaces.
exclude_nodes (list(tensorflow.NodeDef)) – Iterable container (usually a list) of nodes which should NOT be collapsed. These nodes will be present in the final graph as either inputs or outputs of the plugin nodes.
unique_inputs (bool) – Whether inputs to the collapsed node should be unique. If this is false, plugin nodes may have duplicate inputs.
- Returns
None
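To make the input/output deduction concrete, the following sketch collapses a single namespace into one plugin node: inputs produced outside the namespace become the plugin’s inputs, and outside consumers of namespace nodes are rewired to the plugin. It is a simplified illustration assuming a hypothetical Node stand-in for tensorflow.NodeDef and a hypothetical collapse_namespace helper. It does not model exclude_nodes, nested namespaces, or mapping several namespaces to one plugin, and it drops any “:N” output-index suffix when rewiring consumers.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for tensorflow.NodeDef (name, op and input only)."""
    name: str
    op: str = ""
    input: list = field(default_factory=list)


def source_name(inp):
    """Strip a control-input "^" prefix and an output-index ":N" suffix."""
    return inp.lstrip("^").split(":")[0]


def collapse_namespace(nodes, namespace, plugin, unique_inputs=True):
    """Replace all nodes under `namespace/` with `plugin` (mutates nodes in place)."""
    prefix = namespace.rstrip("/") + "/"
    inside = {n.name for n in nodes if n.name.startswith(prefix)}

    # Inputs: whatever namespace nodes read from outside the namespace.
    plugin_inputs = []
    for n in nodes:
        if n.name not in inside:
            continue
        for inp in n.input:
            if source_name(inp) in inside:
                continue
            if unique_inputs and inp in plugin_inputs:
                continue
            plugin_inputs.append(inp)
    plugin.input = plugin_inputs

    def rewire(inp):
        if source_name(inp) not in inside:
            return inp
        # Keep control dependencies as control dependencies; drop any ":N" index.
        return ("^" if inp.startswith("^") else "") + plugin.name

    # Outputs: outside consumers of namespace nodes now read from the plugin.
    outside = [n for n in nodes if n.name not in inside]
    for n in outside:
        n.input = [rewire(i) for i in n.input]
    return outside + [plugin]


nodes = [
    Node("in", "Placeholder"),
    Node("ns/a", "Mul", ["in", "in"]),
    Node("ns/b", "Add", ["ns/a", "in"]),
    Node("out", "Relu", ["ns/b"]),
]
plugin = Node("my_plugin", "MyPluginOp")
collapsed = collapse_namespace(nodes, "ns", plugin)
print([n.name for n in collapsed])  # ['in', 'out', 'my_plugin']
print(plugin.input, collapsed[1].input)  # ['in'] ['my_plugin']
```

Setting unique_inputs=False in this sketch keeps every outside read, so the same graph would give the plugin the inputs ['in', 'in', 'in'].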
- extend(node_list)
Extends this graph’s nodes based on the provided list.
- Parameters
node_list (list(tensorflow.NodeDef)) – List of TensorFlow NodeDefs to add to the graph.
- Returns
None
- forward_inputs(nodes)
Removes nodes from this graph. Recursively forwards inputs, such that paths in the graph are preserved.
Warning: Nodes with control inputs are not removed, so as not to break the structure of the graph. If you need to forward these, remove their control inputs first.
- Parameters
nodes (list(tensorflow.NodeDef)) – Iterable container (usually a list) of nodes which should be removed and whose inputs should be forwarded.
- Returns
None
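The input forwarding described above can be sketched as splicing each removed node’s inputs into its consumers’ input lists, recursing when a forwarded input is itself being removed. This is an illustrative model assuming a hypothetical Node stand-in for tensorflow.NodeDef. Forwarding all of a removed node’s inputs (rather than, say, only the first) is an assumption; it matters only for removed nodes with more than one input.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for tensorflow.NodeDef (name, op and input only)."""
    name: str
    op: str = ""
    input: list = field(default_factory=list)


def source_name(inp):
    """Strip a control-input "^" prefix and an output-index ":N" suffix."""
    return inp.lstrip("^").split(":")[0]


def forward_inputs(nodes, to_remove):
    """Remove `to_remove` from `nodes`, splicing their inputs into consumers."""
    # Mirror the documented caveat: nodes with control inputs are kept.
    removed = {
        n.name: n for n in to_remove
        if not any(i.startswith("^") for i in n.input)
    }

    def resolve(inp):
        """Follow removed producers back to surviving ones, recursively."""
        src = source_name(inp)
        if src not in removed:
            return [inp]
        forwarded = [f for i in removed[src].input for f in resolve(i)]
        if inp.startswith("^"):
            # A control dependency on a removed node stays a control dependency.
            forwarded = ["^" + source_name(f) for f in forwarded]
        return forwarded

    kept = [n for n in nodes if n.name not in removed]
    for n in kept:
        n.input = [f for i in n.input for f in resolve(i)]
    return kept


nodes = [
    Node("x", "Placeholder"),
    Node("id1", "Identity", ["x"]),
    Node("id2", "Identity", ["id1"]),
    Node("relu", "Relu", ["id2"]),
]
kept = forward_inputs(nodes, [nodes[1], nodes[2]])
print([n.name for n in kept], kept[-1].input)  # ['x', 'relu'] ['x']
```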
- remove(nodes, remove_exclusive_dependencies=False)
Removes nodes from this graph. Does not forward inputs, so paths in the graph could be broken.
- Parameters
nodes (list(tensorflow.NodeDef)) – Iterable container (usually a list) of nodes which should be removed.
remove_exclusive_dependencies (bool) – Whether to also remove dependencies exclusive to the nodes about to be removed. When set to True, all exclusive dependencies will be removed recursively, and the number of hanging nodes in the graph will remain constant. Defaults to False.
- Returns
None
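The remove_exclusive_dependencies option can be read as a repeated rule: an input is also removed when every node that consumes it is being removed, and the rule is reapplied to that input’s own inputs. A sketch of that rule follows, assuming a hypothetical Node stand-in for tensorflow.NodeDef; it is an illustration, not graphsurgeon’s implementation.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for tensorflow.NodeDef (name, op and input only)."""
    name: str
    op: str = ""
    input: list = field(default_factory=list)


def source_name(inp):
    """Strip a control-input "^" prefix and an output-index ":N" suffix."""
    return inp.lstrip("^").split(":")[0]


def remove(nodes, to_remove, remove_exclusive_dependencies=False):
    """Drop `to_remove` without rewiring; optionally drop exclusive dependencies."""
    by_name = {n.name: n for n in nodes}
    doomed = {n.name for n in to_remove}
    if remove_exclusive_dependencies:
        consumers = {n.name: set() for n in nodes}
        for n in nodes:
            for inp in n.input:
                src = source_name(inp)
                if src in consumers:
                    consumers[src].add(n.name)
        frontier = list(doomed)
        while frontier:
            node = by_name.get(frontier.pop())
            if node is None:
                continue
            for inp in node.input:
                src = source_name(inp)
                # Exclusive: every consumer of `src` is already being removed.
                if src in by_name and src not in doomed and consumers[src] <= doomed:
                    doomed.add(src)
                    frontier.append(src)
    return [n for n in nodes if n.name not in doomed]


nodes = [
    Node("x", "Placeholder"),
    Node("a", "Mul", ["x"]),
    Node("b", "Relu", ["a"]),
    Node("c", "Relu", ["x"]),
]
print([n.name for n in remove(nodes, [nodes[2]])])  # ['x', 'a', 'c']
print([n.name for n in remove(nodes, [nodes[2]], True)])  # ['x', 'c']
```

Here "a" feeds only the removed "b", so it goes too, while "x" survives because "c" still consumes it.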