Python tensorflow.python.framework.graph_util.extract_sub_graph() Examples
The following are 14 code examples of tensorflow.python.framework.graph_util.extract_sub_graph().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow.python.framework.graph_util, or try the search function.
Example #1
Source File: strip_unused_lib.py From auto-alt-text-lambda-api with MIT License | 5 votes |
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
    """Prune a GraphDef down to what the requested outputs depend on.

    Each node named in `input_node_names` is swapped for a fresh
    `Placeholder` that carries no inputs of its own, so whatever used to feed
    it becomes unreachable and is dropped by `extract_sub_graph()`.

    Args:
      input_graph_def: A graph with nodes we want to prune.
      input_node_names: A list of the nodes we use as inputs.
      output_node_names: A list of the output nodes.
      placeholder_type_enum: The AttrValue enum for the placeholder data type,
        or a list that specifies one value per input node name.

    Returns:
      A GraphDef with all unnecessary ops removed.
    """
    per_input_types = isinstance(placeholder_type_enum, list)

    rewritten_graph = graph_pb2.GraphDef()
    for source in input_graph_def.node:
        if source.name not in input_node_names:
            # Not an input: carry the node over untouched (as a copy).
            rewritten_graph.node.extend([copy.deepcopy(source)])
            continue

        if per_input_types:
            # The dtype list runs parallel to input_node_names.
            dtype_enum = placeholder_type_enum[
                input_node_names.index(source.name)]
        else:
            dtype_enum = placeholder_type_enum

        stand_in = node_def_pb2.NodeDef()
        stand_in.op = "Placeholder"
        stand_in.name = source.name
        stand_in.attr["dtype"].CopyFrom(
            attr_value_pb2.AttrValue(type=dtype_enum))
        if "_output_shapes" in source.attr:
            # Preserve the recorded output shapes on the replacement node.
            stand_in.attr["_output_shapes"].CopyFrom(
                source.attr["_output_shapes"])
        rewritten_graph.node.extend([stand_in])

    return graph_util.extract_sub_graph(rewritten_graph, output_node_names)
Example #2
Source File: strip_unused_lib.py From deep_image_model with Apache License 2.0 | 5 votes |
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
    """Prune a GraphDef down to what the requested outputs depend on.

    Nodes listed in `input_node_names` are replaced by input-less
    `Placeholder` nodes first, so that everything feeding them is stripped
    out by `extract_sub_graph()` along with any other unreachable node.

    Args:
      input_graph_def: A graph with nodes we want to prune.
      input_node_names: A list of the nodes we use as inputs.
      output_node_names: A list of the output nodes.
      placeholder_type_enum: The AttrValue enum for the placeholder data type.

    Returns:
      A GraphDef with all unnecessary ops removed.
    """
    rewritten_graph = tf.GraphDef()
    for source in input_graph_def.node:
        if source.name not in input_node_names:
            # Not an input: carry the node over untouched (as a copy).
            rewritten_graph.node.extend([copy.deepcopy(source)])
            continue

        stand_in = tf.NodeDef()
        stand_in.op = "Placeholder"
        stand_in.name = source.name
        stand_in.attr["dtype"].CopyFrom(
            tf.AttrValue(type=placeholder_type_enum))
        if "_output_shapes" in source.attr:
            # Preserve the recorded output shapes on the replacement node.
            stand_in.attr["_output_shapes"].CopyFrom(
                source.attr["_output_shapes"])
        rewritten_graph.node.extend([stand_in])

    return graph_util.extract_sub_graph(rewritten_graph, output_node_names)
Example #3
Source File: graph_util_test.py From deep_image_model with Apache License 2.0 | 5 votes |
def testExtractSubGraph(self):
    # Nodes are added in this order; each entry is (name, inputs).
    node_specs = (
        ("n1", ["n5"]),
        # Takes the first output of n1 as its input.
        ("n2", ["n1:0"]),
        # Control input: not consumed by the kernel, it only enforces
        # execution order between the nodes.
        ("n3", ["^n2"]),
        # No inputs, and nothing else refers to it.
        ("n4", []),
        # Together with n1 this forms a loop, which is fine in a graph.
        ("n5", ["n1"]),
    )

    graph_def = tf.GraphDef()
    for node_name, node_inputs in node_specs:
        added = graph_def.node.add()
        added.name = node_name
        if node_inputs:
            added.input.extend(node_inputs)

    sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])

    # The extracted nodes are expected in this order.
    for position, expected_name in enumerate(["n1", "n2", "n3", "n5"]):
        self.assertEqual(expected_name, sub_graph.node[position].name)
Example #4
Source File: quantize_graph.py From deep_image_model with Apache License 2.0 | 5 votes |
def remove_dead_nodes(self, output_names):
    """Drop nodes from `self.output_graph` that inference no longer needs.

    Args:
      output_names: Node names handed to `graph_util.extract_sub_graph` as
        the nodes the pruned graph must still be able to produce.
    """
    self.output_graph = graph_util.extract_sub_graph(self.output_graph,
                                                     output_names)
Example #5
Source File: quantize_graph.py From tensorflow-for-poets-2 with Apache License 2.0 | 5 votes |
def remove_dead_nodes(self, output_names):
    """Replace `self.output_graph` with its sub-graph for `output_names`.

    Nodes that are no longer needed for inference are left out of the
    result returned by `graph_util.extract_sub_graph`.

    Args:
      output_names: Names of the nodes the extracted sub-graph is built for.
    """
    pruned_graph = graph_util.extract_sub_graph(self.output_graph,
                                                output_names)
    self.output_graph = pruned_graph
Example #6
Source File: quantize_graph.py From MobileNet with Apache License 2.0 | 5 votes |
def remove_dead_nodes(self, output_names):
    """Drop nodes from `self.output_graph` that inference no longer needs.

    Args:
      output_names: Node names handed to `graph_util.extract_sub_graph` as
        the nodes the pruned graph must still be able to produce.
    """
    self.output_graph = graph_util.extract_sub_graph(self.output_graph,
                                                     output_names)
Example #7
Source File: quantize_graph.py From sketch-to-react-native with MIT License | 5 votes |
def remove_dead_nodes(self, output_names):
    """Replace `self.output_graph` with its sub-graph for `output_names`.

    Nodes that are no longer needed for inference are left out of the
    result returned by `graph_util.extract_sub_graph`.

    Args:
      output_names: Names of the nodes the extracted sub-graph is built for.
    """
    pruned_graph = graph_util.extract_sub_graph(self.output_graph,
                                                output_names)
    self.output_graph = pruned_graph
Example #8
Source File: graph_xform.py From tensorlang with Apache License 2.0 | 5 votes |
def strip_meta_graph(meta_graph_def, node_names, var_names):
    """Shrink `meta_graph_def` in place to the requested nodes.

    The graph is reduced to the sub-graph for `node_names` plus the
    initializers of the variables named in `var_names`, and serialized
    while-contexts whose pivot name prefixes none of the kept names are
    deleted from the "while_context" collection.

    Args:
      meta_graph_def: The MetaGraphDef to rewrite; both its `collection_def`
        and its `graph_def` are modified.
      node_names: Names of the nodes to keep. The caller's sequence is
        copied, not mutated.
      var_names: Variable names whose initializer nodes are kept as well.
    """
    node_names = node_names[:]
    collections = meta_graph_def.collection_def

    # Look for matching variable names and initializers and keep them too.
    parsed_var = variable_pb2.VariableDef()
    for collection_key in ("variables", "trainable_variables"):
        for serialized_var in collections[collection_key].bytes_list.value:
            parsed_var.ParseFromString(serialized_var)
            # TODO(adamb) Should remove variable from collection when its
            # name is not in var_names.
            if parsed_var.variable_name in var_names:
                node_names.append(parsed_var.initializer_name)

    parsed_context = control_flow_pb2.WhileContextDef()
    while_contexts = collections["while_context"].bytes_list.value
    # Walk from the back so a deletion never shifts an index still to visit.
    for position in reversed(range(len(while_contexts))):
        parsed_context.ParseFromString(while_contexts[position])
        pivot_name = parsed_context.pivot_name
        if not any(kept.startswith(pivot_name) for kept in node_names):
            del while_contexts[position]

    full_graph = meta_graph_def.graph_def
    eprint("only keeping", node_names, "from",
           [n.name for n in full_graph.node])
    pruned_graph = graph_util.extract_sub_graph(full_graph, node_names)
    meta_graph_def.graph_def.CopyFrom(pruned_graph)
Example #9
Source File: quantize_graph.py From pokemon-mini with Apache License 2.0 | 5 votes |
def remove_dead_nodes(self, output_names):
    """Drop nodes from `self.output_graph` that inference no longer needs.

    Args:
      output_names: Node names handed to `graph_util.extract_sub_graph` as
        the nodes the pruned graph must still be able to produce.
    """
    self.output_graph = graph_util.extract_sub_graph(self.output_graph,
                                                     output_names)
Example #10
Source File: quantize_graph.py From AudioNet with MIT License | 5 votes |
def remove_dead_nodes(self, output_names):
    """Replace `self.output_graph` with its sub-graph for `output_names`.

    Nodes that are no longer needed for inference are left out of the
    result returned by `graph_util.extract_sub_graph`.

    Args:
      output_names: Names of the nodes the extracted sub-graph is built for.
    """
    pruned_graph = graph_util.extract_sub_graph(self.output_graph,
                                                output_names)
    self.output_graph = pruned_graph
Example #11
Source File: strip_unused_lib.py From keras-lambda with MIT License | 5 votes |
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
    """Return a copy of the graph without the ops the outputs never use.

    Args:
      input_graph_def: A graph with nodes we want to prune.
      input_node_names: A list of the nodes we use as inputs.
      output_node_names: A list of the output nodes.
      placeholder_type_enum: The AttrValue enum for the placeholder data type,
        or a list that specifies one value per input node name.

    Returns:
      A GraphDef with all unnecessary ops removed.
    """

    def _placeholder_for(source):
        """Build the input-less Placeholder that stands in for `source`."""
        if isinstance(placeholder_type_enum, list):
            # One dtype per input name, matched up by list position.
            dtype_enum = placeholder_type_enum[
                input_node_names.index(source.name)]
        else:
            dtype_enum = placeholder_type_enum

        stand_in = node_def_pb2.NodeDef()
        stand_in.op = "Placeholder"
        stand_in.name = source.name
        stand_in.attr["dtype"].CopyFrom(
            attr_value_pb2.AttrValue(type=dtype_enum))
        if "_output_shapes" in source.attr:
            stand_in.attr["_output_shapes"].CopyFrom(
                source.attr["_output_shapes"])
        return stand_in

    # Overridden inputs become placeholders, so the nodes that only fed them
    # are automatically stripped out by extract_sub_graph().
    inputs_replaced = graph_pb2.GraphDef()
    for source in input_graph_def.node:
        if source.name in input_node_names:
            replacement = _placeholder_for(source)
        else:
            replacement = copy.deepcopy(source)
        inputs_replaced.node.extend([replacement])

    return graph_util.extract_sub_graph(inputs_replaced, output_node_names)
Example #12
Source File: strip_unused_lib.py From lambda-packs with MIT License | 4 votes |
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
    """Removes unused nodes from a GraphDef.

    Args:
      input_graph_def: A graph with nodes we want to prune.
      input_node_names: A list of the nodes we use as inputs.
      output_node_names: A list of the output nodes.
      placeholder_type_enum: The AttrValue enum for the placeholder data type,
        or a list that specifies one value per input node name.

    Returns:
      A `GraphDef` with all unnecessary ops removed.

    Raises:
      ValueError: If any element in `input_node_names` refers to a tensor
        instead of an operation.
      KeyError: If any element in `input_node_names` is not found in the graph.
    """
    for name in input_node_names:
        if ":" in name:
            raise ValueError("Name '%s' appears to refer to a Tensor, "
                             "not a Operation." % name)

    # Position of each input name, keeping the first occurrence so the result
    # matches `input_node_names.index(name)`. A dict lookup per graph node
    # replaces a list membership test plus a second linear `index` scan.
    input_positions = {}
    for position, name in enumerate(input_node_names):
        input_positions.setdefault(name, position)

    per_input_types = isinstance(placeholder_type_enum, list)

    # Here we replace the nodes we're going to override as inputs with
    # placeholders so that any unused nodes that are inputs to them are
    # automatically stripped out by extract_sub_graph().
    not_found = set(input_positions)
    inputs_replaced_graph_def = graph_pb2.GraphDef()
    for node in input_graph_def.node:
        position = input_positions.get(node.name)
        if position is None:
            inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
            continue

        # `discard`, not `remove`: if the graph holds two nodes with the same
        # input name, the second one must not raise a bare KeyError that looks
        # like the "input not found" error documented above.
        not_found.discard(node.name)

        placeholder_node = node_def_pb2.NodeDef()
        placeholder_node.op = "Placeholder"
        placeholder_node.name = node.name
        if per_input_types:
            dtype_enum = placeholder_type_enum[position]
        else:
            dtype_enum = placeholder_type_enum
        placeholder_node.attr["dtype"].CopyFrom(
            attr_value_pb2.AttrValue(type=dtype_enum))
        if "_output_shapes" in node.attr:
            placeholder_node.attr["_output_shapes"].CopyFrom(
                node.attr["_output_shapes"])
        inputs_replaced_graph_def.node.extend([placeholder_node])

    if not_found:
        raise KeyError("The following input nodes were not found: %s\n" %
                       not_found)

    output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                    output_node_names)
    return output_graph_def
Example #13
Source File: strip_unused.py From TensorFlow_DCIGN with MIT License | 4 votes |
def strip_unused(input_graph, input_binary, output_graph, input_node_names,
                 output_node_names, placeholder_type_enum):
    """Strip unused nodes from a graph file and write the result to disk.

    Args:
      input_graph: Path of the GraphDef file to read.
      input_binary: If truthy the file is parsed as a serialized binary
        proto, otherwise as a text-format proto.
      output_graph: Path the pruned GraphDef is written to, serialized as
        binary.
      input_node_names: Comma-separated names of the nodes to replace with
        placeholders.
      output_node_names: Comma-separated names of the nodes to keep
        reachable.
      placeholder_type_enum: The AttrValue enum used as the dtype of every
        placeholder.

    Returns:
      -1 if the input file does not exist or no output node names were
      given; otherwise None.
    """
    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    source_graph = tf.GraphDef()
    read_mode = "rb" if input_binary else "r"
    with tf.gfile.FastGFile(input_graph, read_mode) as graph_file:
        contents = graph_file.read()
    if input_binary:
        source_graph.ParseFromString(contents)
    else:
        text_format.Merge(contents, source_graph)

    # Overridden inputs become placeholders, so the nodes that only fed them
    # are automatically stripped out by extract_sub_graph().
    overridden_names = input_node_names.split(",")
    rewritten_graph = tf.GraphDef()
    for source in source_graph.node:
        if source.name not in overridden_names:
            rewritten_graph.node.extend([copy.deepcopy(source)])
            continue

        stand_in = tf.NodeDef()
        stand_in.op = "Placeholder"
        stand_in.name = source.name
        stand_in.attr["dtype"].CopyFrom(
            tf.AttrValue(type=placeholder_type_enum))
        rewritten_graph.node.extend([stand_in])

    pruned_graph = graph_util.extract_sub_graph(rewritten_graph,
                                                output_node_names.split(","))

    with tf.gfile.GFile(output_graph, "wb") as out_file:
        out_file.write(pruned_graph.SerializeToString())
    print("%d ops in the final graph." % len(pruned_graph.node))
Example #14
Source File: strip_unused_lib.py From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License | 4 votes |
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
    """Removes unused nodes from a GraphDef.

    Args:
      input_graph_def: A graph with nodes we want to prune.
      input_node_names: A list of the nodes we use as inputs.
      output_node_names: A list of the output nodes.
      placeholder_type_enum: The AttrValue enum for the placeholder data type,
        or a list that specifies one value per input node name.

    Returns:
      A `GraphDef` with all unnecessary ops removed.

    Raises:
      ValueError: If any element in `input_node_names` refers to a tensor
        instead of an operation.
      KeyError: If any element in `input_node_names` is not found in the graph.
    """
    for name in input_node_names:
        if ":" in name:
            raise ValueError("Name '%s' appears to refer to a Tensor, "
                             "not a Operation." % name)

    # Position of each input name, keeping the first occurrence so the result
    # matches `input_node_names.index(name)`. A dict lookup per graph node
    # replaces a list membership test plus a second linear `index` scan.
    input_positions = {}
    for position, name in enumerate(input_node_names):
        input_positions.setdefault(name, position)

    per_input_types = isinstance(placeholder_type_enum, list)

    # Here we replace the nodes we're going to override as inputs with
    # placeholders so that any unused nodes that are inputs to them are
    # automatically stripped out by extract_sub_graph().
    not_found = set(input_positions)
    inputs_replaced_graph_def = graph_pb2.GraphDef()
    for node in input_graph_def.node:
        position = input_positions.get(node.name)
        if position is None:
            inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
            continue

        # `discard`, not `remove`: if the graph holds two nodes with the same
        # input name, the second one must not raise a bare KeyError that looks
        # like the "input not found" error documented above.
        not_found.discard(node.name)

        placeholder_node = node_def_pb2.NodeDef()
        placeholder_node.op = "Placeholder"
        placeholder_node.name = node.name
        if per_input_types:
            dtype_enum = placeholder_type_enum[position]
        else:
            dtype_enum = placeholder_type_enum
        placeholder_node.attr["dtype"].CopyFrom(
            attr_value_pb2.AttrValue(type=dtype_enum))
        if "_output_shapes" in node.attr:
            placeholder_node.attr["_output_shapes"].CopyFrom(
                node.attr["_output_shapes"])
        inputs_replaced_graph_def.node.extend([placeholder_node])

    if not_found:
        raise KeyError("The following input nodes were not found: %s\n" %
                       not_found)

    output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                    output_node_names)
    return output_graph_def