Graph Surgeon

graphsurgeon allows you to transform TensorFlow graphs. Its capabilities are broadly divided into two categories: search and manipulation. Search functions allow you to find nodes in a TensorFlow graph. Manipulation functions allow you to modify, add, or remove nodes.

Node Creation

These functions allow you to create free-standing TensorFlow nodes, which can be used as stand-ins for plugins.

graphsurgeon.create_node(name, op=None, trt_plugin=False, **kwargs)

Creates a free-standing TensorFlow NodeDef with the specified properties.

Parameters
  • name (str) – The name of the node.

  • op (str) – The node’s operation.

Keyword Arguments
  • dtype (tensorflow.DType) – TensorFlow dtype.

  • shape (tuple(int)) – Iterable container (usually a tuple) describing the shape of a tensor.

  • inputs (list(tensorflow.NodeDef) or str) – Iterable container (usually a tuple) of input nodes or input node names. Supports mixed-type lists.

  • **kwargs (AttrName=Value) – Any additional fields that should be present in the node. Currently supports int, float, bool, list(int), list(float), str, and NumPy arrays. NumPy arrays are inserted into the “value” attribute of the node, which can be useful for creating constant nodes equivalent to those created by tensorflow.constant.

Returns

tensorflow.NodeDef
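The keyword-argument types listed above line up with fields of TensorFlow’s AttrValue protobuf (s, i, f, b, list, tensor). The sketch below is a hypothetical helper, not part of graphsurgeon, that shows which field each supported Python value type would most naturally map to. The name attr_field_for is an assumption, and graphsurgeon’s actual conversion code may differ.

```python
def attr_field_for(value):
    """Return the AttrValue field a kwarg value would plausibly map to (hypothetical)."""
    # bool must be checked before int, because bool is a subclass of int in Python.
    if isinstance(value, bool):
        return "b"
    if isinstance(value, int):
        return "i"
    if isinstance(value, float):
        return "f"
    if isinstance(value, str):
        return "s"
    if isinstance(value, list):
        # Note: an empty list falls into the list(int) branch.
        if all(isinstance(v, int) and not isinstance(v, bool) for v in value):
            return "list.i"
        if all(isinstance(v, float) for v in value):
            return "list.f"
        raise TypeError("only list(int) and list(float) are supported")
    # NumPy arrays expose __array__; per the text above they go into the "value" attribute.
    if hasattr(value, "__array__"):
        return "tensor"
    raise TypeError(f"unsupported attribute type: {type(value).__name__}")
```

For example, a bool flag maps to "b" rather than "i" only because of the ordering of the first two checks.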

graphsurgeon.create_plugin_node(name, op=None, **kwargs)

Creates a free-standing TensorFlow NodeDef with the specified properties. This is similar to create_node.

Parameters
  • name (str) – The name of the node.

  • op (str) – The node’s operation.

Keyword Arguments
  • dtype (tensorflow.DType) – TensorFlow dtype.

  • shape (tuple(int)) – Iterable container (usually a tuple) describing the shape of a tensor.

  • inputs (list(tensorflow.NodeDef) or str) – Iterable container (usually a tuple) of input nodes or input node names. Supports mixed-type lists.

  • **kwargs (AttrName=Value) – Any additional fields that should be present in the node. Currently supports int, float, bool, list(int), list(float) and str.

Returns

tensorflow.NodeDef

Static Graph

class graphsurgeon.StaticGraph(graphdef=None)

Acts as a thin wrapper for a read-only TensorFlow GraphDef. Supports indexing based on node name or index, as well as iteration over nodes using Python’s “for node in static_graph” syntax.

Parameters

graphdef (tensorflow.GraphDef/tensorflow.Graph OR graphsurgeon.StaticGraph/graphsurgeon.DynamicGraph OR str) – A TensorFlow GraphDef/Graph or a StaticGraph/DynamicGraph from which to construct this graph, or a string containing a path to a frozen model.
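The indexing and iteration behavior described for StaticGraph can be sketched in plain Python. This is a hypothetical stand-in, not graphsurgeon’s implementation: Node mimics only the name, op, and input fields of a NodeDef, and ReadOnlyGraph is an illustrative name.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Minimal stand-in for a TensorFlow NodeDef (name, op, input)."""
    name: str
    op: str
    input: list[str] = field(default_factory=list)


class ReadOnlyGraph:
    """Hypothetical read-only wrapper: index by position or node name, iterate over nodes."""

    def __init__(self, nodes):
        self._nodes = tuple(nodes)
        self._by_name = {node.name: node for node in self._nodes}

    def __getitem__(self, key):
        # Strings are looked up by node name; integers (and slices) by position.
        if isinstance(key, str):
            return self._by_name[key]
        return self._nodes[key]

    def __iter__(self):
        return iter(self._nodes)

    def __len__(self):
        return len(self._nodes)
```

Usage: with graph = ReadOnlyGraph([Node("x", "Placeholder"), Node("y", "Relu", ["x"])]), both graph["y"] and graph[1] return the Relu node, and "for node in graph" visits x and then y.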

node_outputs

A mapping of node names to their respective output nodes.

Type

dict(str, list(tensorflow.NodeDef))

node_map

A mapping of node names to their corresponding nodes.

Type

dict(str, tensorflow.NodeDef)

graph_outputs

A list of likely outputs of the graph.

Type

list(tensorflow.NodeDef)

graph_inputs

A list of likely inputs of the graph.

Type

list(tensorflow.NodeDef)
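One plausible way to derive node_map, node_outputs, graph_inputs, and graph_outputs from the nodes’ input lists is sketched below. It assumes that a “likely input” is a Placeholder node and a “likely output” is a node that no other node consumes; these heuristics are assumptions, and graphsurgeon’s may differ. Input strings follow TensorFlow’s NodeDef conventions: a trailing ":N" selects an output slot and a leading "^" marks a control input, so both are stripped to recover the producer’s name. The function names and the Node stand-in are hypothetical.

```python
from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    op: str
    input: list[str] = field(default_factory=list)


def producer_name(input_str):
    """Strip TensorFlow's control-input prefix ("^") and output-slot suffix (":N")."""
    return input_str.lstrip("^").split(":")[0]


def analyze(nodes):
    node_map = {node.name: node for node in nodes}
    node_outputs = defaultdict(list)
    for node in nodes:
        for inp in node.input:
            node_outputs[producer_name(inp)].append(node)
    # Assumed heuristics for "likely" graph inputs and outputs.
    graph_inputs = [n for n in nodes if n.op == "Placeholder"]
    graph_outputs = [n for n in nodes if not node_outputs.get(n.name)]
    return node_map, dict(node_outputs), graph_inputs, graph_outputs
```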

as_graph_def()

Returns this StaticGraph’s internal TensorFlow GraphDef.

Parameters

None

Returns

tensorflow.GraphDef

find_node_chains_by_op(chain)

Finds groups of nodes in this graph that match the specified sequence of ops. Returns a list of matching chains of nodes, with ordering preserved.

Parameters

chain (list(str)) – The sequence of ops to look for. Should be ordered with the input of the chain as the first element, and the output as the last.

Returns

list(list(tensorflow.NodeDef))
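Conceptually, chain matching can be done by walking forward from every node whose op matches the first element, following consumers whose ops match each subsequent element. The sketch below assumes that approach (a depth-first search over consumers); find_chains_by_op and the Node stand-in are hypothetical names, not graphsurgeon’s code.

```python
from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    op: str
    input: list[str] = field(default_factory=list)


def find_chains_by_op(nodes, chain):
    """Return every list of nodes whose ops match `chain`, ordered input -> output."""
    consumers = defaultdict(list)
    for node in nodes:
        # A set avoids registering the same consumer twice (e.g. Add(x, x)).
        for producer in {inp.lstrip("^").split(":")[0] for inp in node.input}:
            consumers[producer].append(node)

    matches = []

    def extend(path):
        if len(path) == len(chain):
            matches.append(list(path))
            return
        expected_op = chain[len(path)]
        for consumer in consumers[path[-1].name]:
            if consumer.op == expected_op:
                path.append(consumer)
                extend(path)
                path.pop()

    for node in nodes:
        if chain and node.op == chain[0]:
            extend([node])
    return matches
```

Usage: find_chains_by_op(nodes, ["Conv2D", "BiasAdd", "Relu"]) returns one list per Conv2D → BiasAdd → Relu sequence, with the Conv2D node first.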

find_node_inputs(node)

Finds input nodes of a given node.

Parameters

node (tensorflow.NodeDef) – The node whose inputs should be searched.

Returns

list(tensorflow.NodeDef)

find_node_inputs_by_name(node, name)

Finds input nodes of a given node based on their names.

Parameters
  • node (tensorflow.NodeDef) – The node whose inputs should be searched.

  • name (str OR list(str)) – The name to look for. Also accepts iterable containers (preferably a list) to search for multiple names in a single pass. Supports regular expressions.

Returns

list(tensorflow.NodeDef)
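Matching a node’s inputs against one or several name patterns in a single pass can be sketched as follows. The sketch assumes that patterns are applied with Python’s re.match (anchored at the start of the name) and that the ":N" and "^" decorations on input strings are ignored; graphsurgeon’s exact matching rules may differ, and find_inputs_by_name is a hypothetical name.

```python
import re
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    op: str
    input: list[str] = field(default_factory=list)


def find_inputs_by_name(node_map, node, name):
    """Return input nodes of `node` whose names match `name` (a pattern or list of patterns)."""
    patterns = [name] if isinstance(name, str) else list(name)
    compiled = [re.compile(p) for p in patterns]
    result = []
    for inp in node.input:
        producer = node_map[inp.lstrip("^").split(":")[0]]
        # Each input is tested once against all patterns (a single pass over the inputs).
        if any(p.match(producer.name) for p in compiled):
            result.append(producer)
    return result
```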

find_node_inputs_by_op(node, op)

Finds input nodes of a given node based on their ops.

Parameters
  • node (tensorflow.NodeDef) – The node whose inputs should be searched.

  • op (str OR list(str)) – The op to look for. Also accepts iterable containers (preferably a list) to search for multiple ops in a single pass.

Returns

list(tensorflow.NodeDef)

find_nodes_by_name(name)

Finds nodes in this graph based on their names.

Parameters

name (str OR list(str)) – The name to look for. Also accepts iterable containers (preferably a list) to search for multiple names in a single pass of the graph. Supports regular expressions.

Returns

list(tensorflow.NodeDef)

find_nodes_by_op(op)

Finds nodes in this graph based on their ops.

Parameters

op (str OR set(str)) – The op to look for. Also accepts iterable containers (preferably a set) to search for multiple ops in a single pass of the graph.

Returns

list(tensorflow.NodeDef)
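The preference for a set reflects Python’s average O(1) membership test for sets, versus O(n) for lists, which adds up when every node in a large graph is checked against the ops. A hypothetical sketch of such a single-pass filter (the function name and Node stand-in are illustrative, not graphsurgeon’s code):

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    op: str
    input: list[str] = field(default_factory=list)


def find_nodes_by_op(nodes, op):
    """Return nodes whose op is `op` (a str) or is contained in `op` (an iterable of str)."""
    ops = {op} if isinstance(op, str) else set(op)
    # One pass over the graph; each membership test is an average O(1) set lookup.
    return [node for node in nodes if node.op in ops]
```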

find_nodes_by_path(path)

Finds nodes in this graph based on their full paths. This will only match exact paths.

Parameters

path (str OR list(str)) – The path to look for. Also accepts iterable containers (preferably a list) to search for multiple paths in a single pass of the graph. Supports regular expressions.

Returns

list(tensorflow.NodeDef)

read(filename)

Reads a frozen protobuf file into this StaticGraph.

Parameters

filename (str) – Name of the protobuf file.

Returns

None

write(filename)

Writes the StaticGraph’s internal TensorFlow GraphDef into a frozen protobuf file.

Parameters

filename (str) – Name of the protobuf file to write.

Returns

None

write_tensorboard(logdir)

Writes the StaticGraph’s internal TensorFlow GraphDef into the specified directory, which can then be visualized in TensorBoard.

Parameters

logdir (str) – Name of the directory to write to.

Returns

None

Raises
  • Warning – TensorFlow emits the warning “Passing a GraphDef to the SummaryWriter is deprecated. Pass a Graph object instead, such as sess.graph.” This is a known warning, but there is currently no alternative, since TensorFlow cannot convert invalid GraphDefs back to Graphs.

Dynamic Graph (Inherits from StaticGraph)

class graphsurgeon.DynamicGraph(graphdef=None)

A subclass of StaticGraph that can both search and modify a TensorFlow GraphDef.

Parameters

graphdef (tensorflow.GraphDef/tensorflow.Graph OR graphsurgeon.StaticGraph/graphsurgeon.DynamicGraph OR str) – A TensorFlow GraphDef/Graph or a StaticGraph/DynamicGraph from which to construct this graph, or a string containing the path to a frozen model.

append(node)

Appends a node to this graph.

Parameters

node (tensorflow.NodeDef) – TensorFlow NodeDef to add to the graph.

Returns

None

collapse_namespaces(namespace_map, exclude_nodes=[], unique_inputs=True)

Collapses nodes in namespaces to single nodes specified by the user, except where those nodes are marked for exclusion.

Parameters
  • namespace_map (dict(str, tensorflow.NodeDef)) – A dictionary specifying namespaces and their corresponding plugin nodes. These plugin nodes are typically used to specify attributes of the custom plugin, while inputs and outputs are automatically deduced. Multiple namespaces can be collapsed into a single plugin node, and nested namespaces are collapsed into plugin nodes outside their parent namespaces.

  • exclude_nodes (list(tensorflow.NodeDef)) – Iterable container (usually a list) of nodes which should NOT be collapsed. These nodes will be present in the final graph as either inputs or outputs of the plugin nodes.

  • unique_inputs (bool) – Whether inputs to the collapsed node should be unique. If False, plugin nodes may have duplicate inputs. Defaults to True.

Returns

None
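A simplified, hypothetical sketch of collapsing one namespace follows. It assumes that a node belongs to a namespace when its name starts with the namespace followed by "/", that the plugin node’s inputs are the outside nodes feeding into the namespace, and that outside consumers are rewired to the plugin node. It also assumes plain input names (no ":N" slot suffixes or "^" control prefixes), and it omits exclude_nodes handling and nested namespaces. collapse_namespace and the Node stand-in are illustrative names; the sketch modifies the given nodes in place.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    op: str
    input: list[str] = field(default_factory=list)


def collapse_namespace(nodes, namespace, plugin, unique_inputs=True):
    """Replace every node under `namespace/` with the single node `plugin`."""
    prefix = namespace + "/"
    inside = {n.name for n in nodes if n.name.startswith(prefix)}

    # Inputs of the plugin: anything consumed inside the namespace that lives outside it.
    plugin.input = [
        inp for n in nodes if n.name in inside
        for inp in n.input if inp not in inside
    ]
    if unique_inputs:
        plugin.input = list(dict.fromkeys(plugin.input))  # de-duplicate, keep order

    result = []
    for n in nodes:
        if n.name in inside:
            continue
        # Outside consumers of collapsed nodes now read from the plugin instead.
        n.input = [plugin.name if inp in inside else inp for inp in n.input]
        result.append(n)
    result.append(plugin)
    return result
```

With unique_inputs=False, a namespace that reads the same outside node several times passes that node to the plugin several times, matching the duplicate-inputs behavior described above.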

extend(node_list)

Extends this graph’s nodes based on the provided list.

Parameters

node_list (list(tensorflow.NodeDef)) – List of TensorFlow NodeDefs to add to the graph.

Returns

None

forward_inputs(nodes)

Removes nodes from this graph and recursively forwards their inputs to their consumers, so that paths in the graph are preserved.

Warning: Nodes with control inputs are not removed, so as not to break the structure of the graph. If you need to forward these, remove their control inputs first.

Parameters

nodes (list(tensorflow.NodeDef)) – Iterable container (usually a list) of nodes which should be removed and whose inputs forwarded.

Returns

None
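Input forwarding can be sketched as follows: each reference to a removed node is replaced by that node’s own inputs, resolving recursively when those inputs were removed too. Mirroring the warning above, nodes with control inputs (recognized by TensorFlow’s "^" prefix) are left in place. The sketch assumes kept nodes refer to removed nodes only through data inputs; forward_inputs here is a hypothetical stand-in, not graphsurgeon’s implementation.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    op: str
    input: list[str] = field(default_factory=list)


def forward_inputs(nodes, to_remove):
    """Remove `to_remove` nodes and splice their inputs into their consumers."""
    by_name = {n.name: n for n in nodes}
    # Nodes with control inputs ("^name") are kept, as the warning above describes.
    removable = {n.name for n in to_remove
                 if not any(inp.startswith("^") for inp in n.input)}

    def resolve(inp):
        # Recursively replace a reference to a removed node with that node's inputs.
        name = inp.split(":")[0]
        if name not in removable:
            return [inp]
        return [r for upstream in by_name[name].input for r in resolve(upstream)]

    kept = [n for n in nodes if n.name not in removable]
    for n in kept:
        n.input = [r for inp in n.input for r in resolve(inp)]
    return kept
```

Usage: removing two chained Identity nodes between x and y leaves y reading directly from x.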

remove(nodes, remove_exclusive_dependencies=False)

Removes nodes from this graph. Does not forward inputs, so paths in the graph could be broken.

Parameters
  • nodes (list(tensorflow.NodeDef)) – Iterable container (usually a list) of nodes which should be removed.

  • remove_exclusive_dependencies (bool) – Whether to also remove dependencies exclusive to the nodes about to be removed. When set to True, all exclusive dependencies will be removed recursively, and the number of hanging nodes in the graph will remain constant. Defaults to False.

Returns

None
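A hypothetical sketch of remove_exclusive_dependencies: after the requested nodes are marked for removal, any producer whose every consumer is also being removed is marked too, repeating until nothing changes. This assumes a dependency is “exclusive” exactly when all of its consumers are removed; graphsurgeon’s precise rules may differ, and remove here is an illustrative stand-in. As in the description above, inputs are not forwarded, so surviving consumers may keep dangling references.

```python
from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    op: str
    input: list[str] = field(default_factory=list)


def remove(nodes, to_remove, remove_exclusive_dependencies=False):
    """Remove nodes without forwarding inputs; optionally prune exclusive dependencies."""
    by_name = {n.name: n for n in nodes}
    removed = {n.name for n in to_remove}
    if remove_exclusive_dependencies:
        consumers = defaultdict(set)
        for n in nodes:
            for inp in n.input:
                consumers[inp.lstrip("^").split(":")[0]].add(n.name)
        frontier = list(removed)
        while frontier:
            node = by_name[frontier.pop()]
            for inp in node.input:
                producer = inp.lstrip("^").split(":")[0]
                # A producer is exclusive once every one of its consumers is removed.
                if producer not in removed and consumers[producer] <= removed:
                    removed.add(producer)
                    frontier.append(producer)
    return [n for n in nodes if n.name not in removed]
```

Usage: removing a MatMul whose weight Const feeds nothing else also removes the Const, while a Placeholder that still feeds another node survives.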