Python tensorflow.core.framework.graph_pb2.GraphDef() Examples
The following are 30 code examples of tensorflow.core.framework.graph_pb2.GraphDef().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow.core.framework.graph_pb2, or try the search function.
Example #1
Source File: utils.py From speech_separation with MIT License | 7 votes |
def load_graph(graph_path, tensorboard=False, **kwargs):
    '''Load a serialized binary ``GraphDef`` (.pb file) into a new graph.

    :param graph_path: path of the .pb file containing a binary GraphDef.
    :param tensorboard: if True, also write the graph to the "log/" directory
        so it can be viewed in TensorBoard.
    :param kwargs: accepted for backward compatibility; not used.
    :return: a new tensorflow graph with the GraphDef imported under its
        original node names (no name prefix).
    '''
    with gfile.FastGFile(graph_path, 'rb') as f:
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        # name="" keeps the original node names instead of adding a prefix.
        tf.import_graph_def(graph_def, name="")
    if tensorboard:
        writer = tf.summary.FileWriter("log/")
        try:
            writer.add_graph(graph)
        finally:
            # Close (and thereby flush) the writer so the graph event is
            # actually written to the event file.
            writer.close()
    return graph
Example #2
Source File: tf.py From training_results_v0.6 with Apache License 2.0 | 6 votes |
def ProcessGraphDefParam(graph_def):
    """Type-checks and possibly canonicalizes `graph_def`.

    Parameters
    ----------
    graph_def : Obj
        tensorflow graph definition.

    Returns
    -------
    graph_def : Obj
        tensorflow graph definition. The argument itself if it already is a
        ``graph_pb2.GraphDef``, otherwise a new ``GraphDef`` into which the
        argument was merged.

    Raises
    ------
    TypeError
        If `graph_def` cannot be merged into a ``GraphDef``.
    """
    if not isinstance(graph_def, graph_pb2.GraphDef):
        # `graph_def` could be a dynamically-created message, so try a duck-typed
        # approach
        try:
            old_graph_def = graph_def
            graph_def = graph_pb2.GraphDef()
            graph_def.MergeFrom(old_graph_def)
        except TypeError as err:
            # Keep the original failure as the cause for easier debugging.
            raise TypeError('graph_def must be a GraphDef proto.') from err
    return graph_def
Example #3
Source File: print_selective_registration_header_test.py From deep_image_model with Apache License 2.0 | 6 votes |
def testAll(self):
    """With default_ops='all' the header unconditionally registers everything."""
    default_ops = 'all'
    parsed_graphs = [
        text_format.Parse(graph_text, graph_pb2.GraphDef())
        for graph_text in (GRAPH_DEF_TXT, GRAPH_DEF_TXT_2)
    ]
    graph_paths = self.WriteGraphFiles(parsed_graphs)
    ops_and_kernels = print_selective_registration_header.get_ops_and_kernels(
        'rawproto', graph_paths, default_ops)
    header = print_selective_registration_header.get_header(
        ops_and_kernels, default_ops)
    expected = [
        '#ifndef OPS_TO_REGISTER',
        '#define OPS_TO_REGISTER',
        '#define SHOULD_REGISTER_OP(op) true',
        '#define SHOULD_REGISTER_OP_KERNEL(clz) true',
        '#define SHOULD_REGISTER_OP_GRADIENT true',
        '#endif',
    ]
    self.assertListEqual(expected, header.split('\n'))
Example #4
Source File: event_accumulator.py From lambda-packs with MIT License | 6 votes |
def Graph(self):
    """Deserialize the graph held by this accumulator into a `GraphDef`.

    The serialized bytes in ``self._graph`` are parsed into a fresh proto.
    NOTE(review): the original docstring says a metagraph's graph is returned
    when no graph was stored directly; presumably that is handled wherever
    ``self._graph`` is populated. Confirm.

    Raises:
      ValueError: If ``self._graph`` is None (no graph for this run).

    Returns:
      The `graph_def` proto.
    """
    graph = graph_pb2.GraphDef()
    if self._graph is None:
        raise ValueError('There is no graph in this EventAccumulator')
    graph.ParseFromString(self._graph)
    return graph
Example #5
Source File: event_accumulator.py From deep_image_model with Apache License 2.0 | 6 votes |
def Graph(self):
    """Deserialize the graph held by this accumulator into a `GraphDef`.

    The serialized bytes in ``self._graph`` are parsed into a fresh proto.
    NOTE(review): the original docstring says a metagraph's graph is returned
    when no graph was stored directly; presumably that is handled wherever
    ``self._graph`` is populated. Confirm.

    Raises:
      ValueError: If ``self._graph`` is None (no graph for this run).

    Returns:
      The `graph_def` proto.
    """
    graph = graph_pb2.GraphDef()
    if self._graph is None:
        raise ValueError('There is no graph in this EventAccumulator')
    graph.ParseFromString(self._graph)
    return graph
Example #6
Source File: meta_graph_test.py From deep_image_model with Apache License 2.0 | 6 votes |
def testStrippedOpListRecursiveFunctions(self):
    # The function module can't express recursion, so build the cycle by hand:
    # A calls B, and B calls Const and A.
    graph = graph_pb2.GraphDef()
    fn_a = graph.library.function.add()
    fn_b = graph.library.function.add()
    fn_a.signature.name = "A"
    fn_b.signature.name = "B"
    fn_a.node.add().op = "B"
    for callee in ("Const", "A"):
        fn_b.node.add().op = callee
    # The top-level graph only references A.
    graph.node.add().op = "A"
    # Only Const is expected to survive in the stripped op list.
    op_list = tf.contrib.util.stripped_op_list_for_graph(graph)
    self.assertEqual(["Const"], [op.name for op in op_list.op])
Example #7
Source File: ops.py From lambda-packs with MIT License | 6 votes |
def as_graph_def(self, from_version=None, add_shapes=False):
    """Return this graph as a serialized `GraphDef` protocol buffer.

    The result can be imported into another `Graph` (using
    @{tf.import_graph_def}) or used with the
    [C++ Session API](../../api_docs/cc/index.md).

    This method is thread-safe.

    Args:
      from_version: Optional. When set, the returned `GraphDef` contains only
        the nodes added to this graph since its `version` property had this
        value.
      add_shapes: If true, adds an "_output_shapes" list attr to each node
        with the inferred shapes of each of its outputs.

    Returns:
      A [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
      protocol buffer: the first element of the pair returned by
      ``self._as_graph_def``.

    Raises:
      ValueError: If the `graph_def` would be too large.
    """
    graph_def, _ = self._as_graph_def(from_version, add_shapes)
    return graph_def
Example #8
Source File: ops.py From deep_image_model with Apache License 2.0 | 6 votes |
def get_stats_for_node_def(graph, node, statistic_type):
    """Compute a statistic for `node` via the statistics-function registry.

    The registry key is "<node.op>,<statistic_type>". If a function is
    registered under that key, it is called as ``fn(graph, node)`` and its
    result returned. Otherwise an empty ``OpStats(statistic_type)`` is
    returned.

    Args:
      graph: A Graph object that's been set up with the node's graph.
      node: A NodeDef describing the operator.
      statistic_type: A string identifying the statistic we're interested in.

    Returns:
      An OpStats object containing information about resource usage.
      A LookupError raised by the statistics function itself also yields the
      empty OpStats, because the call sits inside the same try block.
    """
    registry_key = ",".join((node.op, statistic_type))
    try:
        return _stats_registry.lookup(registry_key)(graph, node)
    except LookupError:
        return OpStats(statistic_type)
Example #9
Source File: ops.py From lambda-packs with MIT License | 6 votes |
def get_stats_for_node_def(graph, node, statistic_type):
    """Compute a statistic for `node` via the statistics-function registry.

    The registry key is "<node.op>,<statistic_type>". If a function is
    registered under that key, it is called as ``fn(graph, node)`` and its
    result returned. Otherwise an empty ``OpStats(statistic_type)`` is
    returned.

    Args:
      graph: A Graph object that's been set up with the node's graph.
      node: A NodeDef describing the operator.
      statistic_type: A string identifying the statistic we're interested in.

    Returns:
      An OpStats object containing information about resource usage.
      A LookupError raised by the statistics function itself also yields the
      empty OpStats, because the call sits inside the same try block.
    """
    registry_key = ",".join((node.op, statistic_type))
    try:
        return _stats_registry.lookup(registry_key)(graph, node)
    except LookupError:
        return OpStats(statistic_type)
Example #10
Source File: all_models_to_tensorboard.py From realtime_object_detection with MIT License | 6 votes |
def create_tfevent_from_pb(model, optimized=False):
    """Write a TensorBoard event file containing a model's frozen graph.

    Args:
        model: name of a model directory under ``ROOT_DIR + '/models/'``.
        optimized: if True, read ``optimized_inference_graph.pb`` and write to
            ``log_opt/``; otherwise read ``frozen_inference_graph.pb`` and
            write to ``log/`` (both inside the model directory).
    """
    print("> creating tfevent of model: {}".format(model))
    model_dir = ROOT_DIR + '/models/{}/'.format(model)
    if optimized:
        model_path = model_dir + 'optimized_inference_graph.pb'
        log_dir = model_dir + 'log_opt/'
    else:
        model_path = model_dir + 'frozen_inference_graph.pb'
        log_dir = model_dir + 'log/'

    with session.Session(graph=ops.Graph()) as sess:
        with gfile.FastGFile(model_path, "rb") as f:
            graph_def = graph_pb2.GraphDef()
            graph_def.ParseFromString(f.read())
            importer.import_graph_def(graph_def)
        pb_visual_writer = summary.FileWriter(log_dir)
        try:
            pb_visual_writer.add_graph(sess.graph)
        finally:
            # Close (and thereby flush) the writer so the event file is
            # complete before the user is told to open it.
            pb_visual_writer.close()
        print("> Model {} Imported. \nVisualize by running: "
              "tensorboard --logdir={}".format(model_path, log_dir))

# Gather all Model Names in models/
Example #11
Source File: tensorflow.py From blueoil with Apache License 2.0 | 6 votes |
def read(self, pb_path: str) -> Graph:
    """Read TF file and load model.

    Args:
        pb_path (str): Path to TF file (binary GraphDef).

    Returns:
        Model: Loaded model. If the file cannot be opened or read, a message
        is printed and the graph is built from an empty GraphDef.
    """
    # load tensorflow model
    graph_def = graph_pb2.GraphDef()
    try:
        # `with` guarantees the file is closed even if parsing raises.
        with open(path.abspath(pb_path), "rb") as f:
            graph_def.ParseFromString(f.read())
    except IOError:
        # NOTE(review): despite the message, nothing is created here; the
        # empty GraphDef is passed on to the importer as-is.
        print("Could not open file. Creating a new one.")

    # import graph
    graph = Importer.make_graph(graph_def)
    return graph
Example #12
Source File: print_selective_registration_header_test.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def testAll(self):
    """With default_ops='all' the header unconditionally registers everything."""
    default_ops = 'all'
    sources = [GRAPH_DEF_TXT, GRAPH_DEF_TXT_2]
    graphs = [text_format.Parse(src, graph_pb2.GraphDef()) for src in sources]
    written = self.WriteGraphFiles(graphs)
    ops_and_kernels = print_selective_registration_header.get_ops_and_kernels(
        'rawproto', written, default_ops)
    header = print_selective_registration_header.get_header(
        ops_and_kernels, default_ops)
    wanted_lines = [
        '#ifndef OPS_TO_REGISTER',
        '#define OPS_TO_REGISTER',
        '#define SHOULD_REGISTER_OP(op) true',
        '#define SHOULD_REGISTER_OP_KERNEL(clz) true',
        '#define SHOULD_REGISTER_OP_GRADIENT true',
        '#endif',
    ]
    self.assertListEqual(wanted_lines, header.split('\n'))
Example #13
Source File: graph_to_dot.py From deep_image_model with Apache License 2.0 | 6 votes |
def main(unused_args):
    """Convert a GraphDef file into a Graphviz DOT file.

    Reads ``FLAGS.graph`` (binary when ``FLAGS.input_binary`` is set, text
    format otherwise). For every node it writes a DOT vertex labelled with the
    node's op, plus one edge per input, into ``FLAGS.dot_output``.

    Returns:
      -1 if the input file does not exist, otherwise None.
    """
    if not gfile.Exists(FLAGS.graph):
        print("Input graph file '" + FLAGS.graph + "' does not exist!")
        return -1

    graph = graph_pb2.GraphDef()
    # Binary protos must be read as bytes; text-format protos as text.
    with open(FLAGS.graph, "rb" if FLAGS.input_binary else "r") as f:
        if FLAGS.input_binary:
            graph.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), graph)

    # print() writes text, so the DOT file must be opened in text mode.
    with open(FLAGS.dot_output, "w") as f:
        print("digraph graphname {", file=f)
        for node in graph.node:
            output_name = node.name
            print("  \"" + output_name + "\" [label=\"" + node.op + "\"];", file=f)
            for input_full_name in node.input:
                parts = input_full_name.split(":")
                # Drop the "^" marker that prefixes control-dependency inputs.
                input_name = re.sub(r"^\^", "", parts[0])
                print("  \"" + input_name + "\" -> \"" + output_name + "\";", file=f)
        print("}", file=f)
    print("Created DOT file '" + FLAGS.dot_output + "'.")
Example #14
Source File: selective_registration_header_lib.py From lambda-packs with MIT License | 6 votes |
def get_header(graphs,
               proto_fileformat='rawproto',
               default_ops='NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'):
    """Build the header used with TensorFlow SELECTIVE_REGISTRATION.

    Args:
      graphs: paths of the GraphDef files to include.
      proto_fileformat: format of those files, 'textproto' or 'rawproto'
        (the default).
      default_ops: comma-separated operator:kernel pairs that are always
        included, or 'all' to include every operator and kernel.
        Default: 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'.

    Returns:
      The text to write as ops_to_register.h, or the integer 1 (after printing
      an error) when no ops/kernels could be collected.
    """
    ops_and_kernels = get_ops_and_kernels(proto_fileformat, graphs, default_ops)
    if ops_and_kernels:
        include_everything = default_ops == 'all'
        return get_header_from_ops_and_kernels(ops_and_kernels,
                                               include_everything)
    print('Error reading graph!')
    return 1
Example #15
Source File: event_accumulator.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def Graph(self):
    """Deserialize the graph held by this accumulator into a `GraphDef`.

    The serialized bytes in ``self._graph`` are parsed into a fresh proto.
    NOTE(review): the original docstring says a metagraph's graph is returned
    when no graph was stored directly; presumably that is handled wherever
    ``self._graph`` is populated. Confirm.

    Raises:
      ValueError: If ``self._graph`` is None (no graph for this run).

    Returns:
      The `graph_def` proto.
    """
    graph = graph_pb2.GraphDef()
    if self._graph is None:
        raise ValueError('There is no graph in this EventAccumulator')
    graph.ParseFromString(self._graph)
    return graph
Example #16
Source File: ops.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def as_graph_def(self, from_version=None, add_shapes=False):
    """Return this graph as a serialized `GraphDef` protocol buffer.

    The result can be imported into another `Graph` (using
    [`import_graph_def()`](#import_graph_def)) or used with the
    [C++ Session API](../../api_docs/cc/index.md).

    This method is thread-safe.

    Args:
      from_version: Optional. When set, the returned `GraphDef` contains only
        the nodes added to this graph since its `version` property had this
        value.
      add_shapes: If true, adds an "_output_shapes" list attr to each node
        with the inferred shapes of each of its outputs.

    Returns:
      A [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
      protocol buffer: the first element of the pair returned by
      ``self._as_graph_def``.

    Raises:
      ValueError: If the `graph_def` would be too large.
    """
    graph_def, _ = self._as_graph_def(from_version, add_shapes)
    return graph_def
Example #17
Source File: backend_tf.py From inference with Apache License 2.0 | 6 votes |
def load(self, model_path, inputs=None, outputs=None):
    """Load a frozen (binary) GraphDef and open a session on it.

    Args:
      model_path: path of the binary GraphDef file.
      inputs: input tensor names (required; the graph carries no such metadata).
      outputs: output tensor names (required).

    Returns:
      self, so calls can be chained.

    Raises:
      ValueError: if `inputs` or `outputs` is empty or None.
    """
    # there is no input/output meta data in the graph so it needs to come from config.
    if not inputs:
        raise ValueError("BackendTensorflow needs inputs")
    if not outputs:
        raise ValueError("BackendTensorflow needs outputs")
    self.outputs = outputs
    self.inputs = inputs

    # TODO: support checkpoint and saved_model formats?
    graph_def = graph_pb2.GraphDef()
    with open(model_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    # import_graph_def() returns None when return_elements is not given, so
    # import into a dedicated graph and hand that graph to the session instead
    # of relying on the process-wide default graph.
    g = tf.Graph()
    with g.as_default():
        tf.compat.v1.import_graph_def(graph_def, name='')
    self.sess = tf.compat.v1.Session(graph=g)
    return self
Example #18
Source File: test_util_test.py From deep_image_model with Apache License 2.0 | 5 votes |
def testAssertProtoEqualsStr(self):
    """assertProtoEquals accepts both a text proto and a proto as expected."""
    graph_str = "node { name: 'w1' op: 'params' }"
    graph_def = graph_pb2.GraphDef()
    text_format.Merge(graph_str, graph_def)
    # First the string-based comparison, then the proto-to-proto one.
    for expected in (graph_str, graph_def):
        self.assertProtoEquals(expected, graph_def)
Example #19
Source File: meta_graph.py From deep_image_model with Apache License 2.0 | 5 votes |
def stripped_op_list_for_graph(graph_def):
    """Collect the registered OpDefs for every op a graph uses.

    This computes the `stripped_op_list` field of `MetaGraphDef` and similar
    protos. A consumer can pass it to the C++ function
    `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

    Args:
      graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

    Returns:
      An `OpList` of the registered ops used by the graph, sorted by name.

    Raises:
      ValueError: If a used op is neither registered nor one of the internal
        function ops ("_Arg", "_Retval", "_ListToArray", "_ArrayToList").
    """
    # Python mirror of the C++ StrippedOpListForGraph; the two op registries
    # can differ, so the logic is duplicated rather than wrapped via swig.
    # TODO(irving): Support taking graphs directly.
    used_ops = ops_used_by_graph_def(graph_def)
    registered_ops = op_def_registry.get_registered_ops()

    # Ops emitted internally by functions are never registered; tolerate them.
    # TODO(irving): Do something better here.
    internal_function_ops = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")
    for op_name in used_ops:
        if op_name in registered_ops or op_name in internal_function_ops:
            continue
        raise ValueError("Op %s is used by the graph, but is not registered" % op_name)

    known = [name for name in sorted(used_ops) if name in registered_ops]
    return op_def_pb2.OpList(op=[registered_ops[name] for name in known])
Example #20
Source File: quantize_graph.py From tensorflow-for-poets-2 with Apache License 2.0 | 5 votes |
def apply_final_node_renames(self):
    """Applies node renames in self.final_node_renames to self.output_graph.

    Every node's own name, and every input that refers to a renamed node, is
    rewritten. The nodes are then re-added to a fresh ``GraphDef`` through
    ``self.add_output_graph_node``.

    Returns:
      The rebuilt ``self.output_graph``.
    """
    old_graph = self.output_graph
    self.output_graph = graph_pb2.GraphDef()
    renames = self.final_node_renames
    for node in old_graph.node:
        node.name = renames.get(node.name, node.name)
        for index, input_name in enumerate(node.input):
            node_name = node_name_from_input(input_name)
            if node_name not in renames:
                continue
            if input_name.startswith("^"):
                # Control-dependency input. Assuming node_name_from_input strips
                # the "^", slicing len(node_name) characters off the full name
                # would be misaligned by one. Keep the marker and add no port.
                node.input[index] = "^" + renames[node_name]
            else:
                input_full_name = ensure_tensor_name_has_port(input_name)
                # Keep the ":<port>" suffix, swapping only the node-name part.
                node.input[index] = "%s%s" % (renames[node_name],
                                              input_full_name[len(node_name):])
        self.add_output_graph_node(node)
    return self.output_graph
Example #21
Source File: test_tfcoreml.py From MMdnn with MIT License | 5 votes |
def _test_coreml_model_image_input(self, tf_model_path, coreml_model,
                                   input_tensor_name, output_tensor_name,
                                   img_size, useCPUOnly=False):
    """Test single image input conversions.

    tf_model_path - path of the TF model (binary GraphDef)
    coreml_model - converted CoreML model
    input_tensor_name - the input image tensor name
    output_tensor_name - the output tensor name
    img_size - size the test image is resized to (img_size x img_size)
    useCPUOnly - forwarded to coreml_model.predict
    """
    img_np, img = _load_image(TEST_IMAGE, resize_to=(img_size, img_size))
    img_tf = np.expand_dims(img_np, axis=0)
    # Per-channel affine preprocessing with this test's scale/bias attributes.
    img_tf[:, :, :, 0] = self.image_scale * img_tf[:, :, :, 0] + self.red_bias
    img_tf[:, :, :, 1] = self.image_scale * img_tf[:, :, :, 1] + self.green_bias
    img_tf[:, :, :, 2] = self.image_scale * img_tf[:, :, :, 2] + self.blue_bias

    # evaluate the TF model
    tf.reset_default_graph()
    graph_def = graph_pb2.GraphDef()
    with open(tf_model_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    # import_graph_def() returns None here, so import into an explicit graph.
    # The default "import/" name prefix is kept, as the lookups below expect.
    g = tf.Graph()
    with g.as_default():
        tf.import_graph_def(graph_def)
    with tf.Session(graph=g) as sess:
        image_input_tensor = sess.graph.get_tensor_by_name('import/' + input_tensor_name)
        output = sess.graph.get_tensor_by_name('import/' + output_tensor_name)
        tf_out = sess.run(output, feed_dict={image_input_tensor: img_tf})
    if len(tf_out.shape) == 4:
        # Move the last axis to position 1 (e.g. NHWC -> NCHW).
        tf_out = np.transpose(tf_out, (0, 3, 1, 2))
    tf_out_flatten = tf_out.flatten()

    # evaluate CoreML
    coreml_input_name = input_tensor_name.replace(':', '__').replace('/', '__')
    coreml_output_name = output_tensor_name.replace(':', '__').replace('/', '__')
    coreml_input = {coreml_input_name: img}

    # Test the default CoreML evaluation
    coreml_out = coreml_model.predict(coreml_input, useCPUOnly=useCPUOnly)[coreml_output_name]
    coreml_out_flatten = coreml_out.flatten()
    self._compare_tf_coreml_outputs(tf_out_flatten, coreml_out_flatten)
Example #22
Source File: test_tfcoreml.py From MMdnn with MIT License | 5 votes |
def _test_coreml_model_image_input(tf_model_path, coreml_model,
                                   input_tensor_name, output_tensor_name,
                                   img_size, useCPUOnly=False):
    """Test single image input conversions.

    tf_model_path - path of the TF model (binary GraphDef)
    coreml_model - converted CoreML model
    input_tensor_name - the input image tensor name
    output_tensor_name - the output tensor name
    img_size - size the test image is resized to (img_size x img_size)
    useCPUOnly - forwarded to coreml_model.predict
    """
    img_np, img = _load_image(TEST_IMAGE, resize_to=(img_size, img_size))
    img_tf = np.expand_dims(img_np, axis=0)
    # Map each channel from [0, 255] to [-1, 1].
    img_tf[:, :, :, 0] = 2.0 / 255 * img_tf[:, :, :, 0] - 1
    img_tf[:, :, :, 1] = 2.0 / 255 * img_tf[:, :, :, 1] - 1
    img_tf[:, :, :, 2] = 2.0 / 255 * img_tf[:, :, :, 2] - 1

    # evaluate the TF model
    tf.reset_default_graph()
    graph_def = graph_pb2.GraphDef()
    with open(tf_model_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    # import_graph_def() returns None here, so import into an explicit graph.
    # The default "import/" name prefix is kept, as the lookups below expect.
    g = tf.Graph()
    with g.as_default():
        tf.import_graph_def(graph_def)
    with tf.Session(graph=g) as sess:
        image_input_tensor = sess.graph.get_tensor_by_name('import/' + input_tensor_name)
        output = sess.graph.get_tensor_by_name('import/' + output_tensor_name)
        tf_out = sess.run(output, feed_dict={image_input_tensor: img_tf})
    if len(tf_out.shape) == 4:
        # Move the last axis to position 1 (e.g. NHWC -> NCHW).
        tf_out = np.transpose(tf_out, (0, 3, 1, 2))
    tf_out_flatten = tf_out.flatten()

    # evaluate CoreML
    coreml_input_name = input_tensor_name.replace(':', '__').replace('/', '__')
    coreml_output_name = output_tensor_name.replace(':', '__').replace('/', '__')
    coreml_input = {coreml_input_name: img}

    # Test the default CoreML evaluation
    coreml_out = coreml_model.predict(coreml_input, useCPUOnly=useCPUOnly)[coreml_output_name]
    coreml_out_flatten = coreml_out.flatten()
    print(coreml_out_flatten)
    # NOTE(review): the TF/CoreML comparison is disabled, so this helper only
    # prints the CoreML output and asserts nothing. Re-enable once
    # compare_tf_coreml_outputs is available here.
    # compare_tf_coreml_outputs(tf_out_flatten, coreml_out_flatten)
Example #23
Source File: print_selective_registration_header.py From deep_image_model with Apache License 2.0 | 5 votes |
def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
    """Gets the ops and kernels needed from the model files.

    Args:
      proto_fileformat: 'rawproto' for binary GraphDef files, 'textproto' for
        text-format ones.
      proto_files: iterable of GraphDef file paths.
      default_ops_str: comma-separated "op:kernel" pairs that are always added,
        or 'all', in which case none are added here.

    Returns:
      A sorted list of unique (op, kernel_class) string tuples.
    """
    ops = set()
    for proto_file in proto_files:
        tf.logging.info('Loading proto file %s', proto_file)
        # Load GraphDef. Binary protos must be read in binary mode; `with`
        # closes the file handle.
        mode = 'rb' if proto_fileformat == 'rawproto' else 'r'
        with tf.gfile.GFile(proto_file, mode) as f:
            file_data = f.read()
        if proto_fileformat == 'rawproto':
            graph_def = graph_pb2.GraphDef.FromString(file_data)
        else:
            assert proto_fileformat == 'textproto'
            graph_def = text_format.Parse(file_data, graph_pb2.GraphDef())

        # Find all ops and kernels used by the graph.
        for node_def in graph_def.node:
            if not node_def.device:
                # Kernel lookup below needs a device; default to CPU.
                node_def.device = '/cpu:0'
            kernel_class = pywrap_tensorflow.TryFindKernelClass(
                node_def.SerializeToString())
            if kernel_class:
                ops.add((str(node_def.op), kernel_class.decode('utf-8')))
            else:
                print(
                    'Warning: no kernel found for op %s' % node_def.op,
                    file=sys.stderr)

    # Add default ops.
    if default_ops_str != 'all':
        for s in default_ops_str.split(','):
            op, kernel = s.split(':')
            ops.add((op, kernel))

    return sorted(ops)
Example #24
Source File: ops.py From auto-alt-text-lambda-api with MIT License | 5 votes |
def graph_def_versions(self):
    """Return the `VersionDef` describing this graph's GraphDef versions.

    The meaning of each version field is documented in
    [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto).

    Returns:
      A `VersionDef` (the value stored in ``self._graph_def_versions``).
    """
    versions = self._graph_def_versions
    return versions
Example #25
Source File: control_flow_ops_test.py From deep_image_model with Apache License 2.0 | 5 votes |
def _StripGraph(self, gd):
    """Return a new GraphDef whose nodes are the stripped copies of `gd`'s nodes.

    Each node is passed through ``self._StripNode``. Per the original intent,
    only node.name, node.op, node.input and node.device are kept.
    """
    stripped_nodes = [self._StripNode(nd) for nd in gd.node]
    return graph_pb2.GraphDef(node=stripped_nodes)
Example #26
Source File: tf.py From training_results_v0.6 with Apache License 2.0 | 5 votes |
def get_workload(model_path):
    """ Import workload from frozen protobuf

    Parameters
    ----------
    model_path: str
        model_path on remote repository to download from (relative to the
        dmlc web-data tensorflow/models directory).

    Returns
    -------
    graph_def: graphdef
        graph_def parsed from the downloaded frozen graph.
    """
    repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/'
    model_name = os.path.basename(model_path)
    # URLs always use "/" separators; os.path.join would use the host OS
    # separator (e.g. "\\" on Windows). repo_base already ends with "/".
    model_url = repo_base + model_path

    from mxnet.gluon.utils import download

    temp = util.tempdir()
    try:
        path_model = temp.relpath(model_name)
        download(model_url, path_model)

        # Creates graph from saved graph_def.pb.
        with tf.gfile.FastGFile(path_model, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        # NOTE(review): the return value was never used; the call is kept for
        # its side effect of importing into the current default graph, which
        # callers may rely on.
        tf.import_graph_def(graph_def, name='')
    finally:
        # Remove the downloaded temp files even if download/parse fails.
        temp.remove()
    return graph_def

#######################################################################
# PTB LSTMBlockCell Model
# -----------------------
Example #27
Source File: load_graph_nms_v2.py From PythonPilot with Apache License 2.0 | 5 votes |
def load_frozen_graph_without_split(self):
    """Load the frozen graph at ``self.cfg['model_path']`` into a fresh default graph.

    Before import, nodes whose name contains 'BatchMultiClassNonMaxSuppression'
    are pinned to '/device:CPU:0' and every other node to '/device:GPU:0'.

    Returns:
      The default graph after the import.
    """
    model_path = self.cfg['model_path']
    tf.reset_default_graph()

    graph_def = tf.GraphDef()
    with tf.gfile.GFile(model_path, 'rb') as fid:
        graph_def.ParseFromString(fid.read())

    # force CPU device placement for NMS ops
    for node in graph_def.node:
        node.device = ('/device:CPU:0'
                       if 'BatchMultiClassNonMaxSuppression' in node.name
                       else '/device:GPU:0')

    tf.import_graph_def(graph_def, name='')
    return tf.get_default_graph()
Example #28
Source File: meta_graph.py From auto-alt-text-lambda-api with MIT License | 5 votes |
def stripped_op_list_for_graph(graph_def):
    """Collect the registered OpDefs for every op a graph uses.

    This computes the `stripped_op_list` field of `MetaGraphDef` and similar
    protos. A consumer can pass it to the C++ function
    `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

    Args:
      graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

    Returns:
      An `OpList` of the registered ops used by the graph, sorted by name.

    Raises:
      ValueError: If a used op is neither registered nor one of the internal
        function ops ("_Arg", "_Retval", "_ListToArray", "_ArrayToList").
    """
    # Python mirror of the C++ StrippedOpListForGraph; the two op registries
    # can differ, so the logic is duplicated rather than wrapped via swig.
    # TODO(irving): Support taking graphs directly.
    used_ops = ops_used_by_graph_def(graph_def)
    registered_ops = op_def_registry.get_registered_ops()

    # Ops emitted internally by functions are never registered; tolerate them.
    # TODO(irving): Do something better here.
    internal_function_ops = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")
    for op_name in used_ops:
        if op_name in registered_ops or op_name in internal_function_ops:
            continue
        raise ValueError("Op %s is used by the graph, but is not registered" % op_name)

    known = [name for name in sorted(used_ops) if name in registered_ops]
    return op_def_pb2.OpList(op=[registered_ops[name] for name in known])
Example #29
Source File: meta_graph.py From auto-alt-text-lambda-api with MIT License | 5 votes |
def _read_file(filename):
    """Reads a file containing `GraphDef` and returns the protocol buffer.

    The content is first parsed as a binary proto. If that fails, it is
    parsed as a UTF-8 text-format proto.

    Args:
      filename: `graph_def` filename including the path.

    Returns:
      A `GraphDef` protocol buffer.

    Raises:
      IOError: If the file doesn't exist, or cannot be successfully parsed.
    """
    graph_def = graph_pb2.GraphDef()
    if not file_io.file_exists(filename):
        raise IOError("File %s does not exist." % filename)
    # First try to read it as a binary file.
    file_content = file_io.read_file_to_string(filename)
    try:
        graph_def.ParseFromString(file_content)
        return graph_def
    except Exception:  # pylint: disable=broad-except
        pass

    # Next try to read it as a text file. Start from a clean message in case
    # the failed binary parse left partial data behind.
    graph_def.Clear()
    try:
        # NOTE(review): read_file_to_string may return bytes or str depending
        # on the TF/Python version; only bytes need decoding.
        if isinstance(file_content, bytes):
            file_content = file_content.decode("utf-8")
        text_format.Merge(file_content, graph_def)
    except (text_format.ParseError, UnicodeDecodeError) as e:
        raise IOError("Cannot parse file %s: %s." % (filename, str(e)))
    return graph_def
Example #30
Source File: freeze_graph.py From deepnlp with MIT License | 5 votes |
def _parse_input_graph_proto(input_graph, input_binary):
    """Parse a TensorFlow graph file into a `GraphDef` proto.

    Args:
      input_graph: path of the graph file.
      input_binary: True if the file is a binary proto, False for text format.

    Returns:
      The parsed `GraphDef`, or -1 (after printing a message) if `input_graph`
      does not exist.
    """
    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    input_graph_def = graph_pb2.GraphDef()
    with gfile.FastGFile(input_graph, "rb" if input_binary else "r") as f:
        contents = f.read()
    if input_binary:
        input_graph_def.ParseFromString(contents)
    else:
        text_format.Merge(contents, input_graph_def)
    return input_graph_def