Python tensorflow.python.framework.graph_util.convert_variables_to_constants() Examples
The following are 30 code examples of tensorflow.python.framework.graph_util.convert_variables_to_constants().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow.python.framework.graph_util, or try the search function.
Example #1
Source File: freeze_graph.py From MobileFaceNet_Tensorflow with Apache License 2.0 | 6 votes |
def freeze_graph_def(sess, input_graph_def, output_node_names):
    """Freeze ``input_graph_def`` by replacing whitelisted variables with constants.

    Before conversion, several ops are rewritten in place:

    * ``RefSwitch`` becomes ``Switch``, and every input whose name contains
      ``'moving_'`` is redirected to ``<input>/read``.
    * ``AssignSub`` / ``AssignAdd`` become ``Sub`` / ``Add`` and lose their
      ``use_locking`` attribute.

    NOTE(review): these rewrites look like the common workaround for freezing
    graphs that contain batch-norm moving-average updates. Confirm.

    Args:
        sess: Session holding the current variable values.
        input_graph_def: GraphDef to freeze. Its nodes are mutated in place.
        output_node_names: Comma-separated string of output node names.

    Returns:
        The GraphDef produced by ``graph_util.convert_variables_to_constants``.
    """
    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # ``range`` (not the Python 2-only ``xrange``) so this runs on Python 3.
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    # Only variables whose names appear in this list are converted to constants.
    whitelist_names = [
        node.name
        for node in input_graph_def.node
        if node.name.startswith(('MobileFaceNet', 'embeddings'))
    ]

    # Replace the whitelisted variables with constants of the same values.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names.split(","),
        variable_names_whitelist=whitelist_names)
    return output_graph_def
Example #2
Source File: model.py From delta with Apache License 2.0 | 6 votes |
def frozen_graph_to_pb(outputs, frozen_graph_pb_path, sess, graph=None):
    """Freeze graph to a pb file.

    The graph's variables are replaced by constants (keeping what is needed
    for ``outputs``), and the resulting GraphDef is serialized to
    ``frozen_graph_pb_path``.

    Args:
        outputs: list of output node names.
        frozen_graph_pb_path: destination path of the binary GraphDef.
        sess: session holding the variable values.
        graph: graph to freeze. When omitted, the default graph is used.

    Raises:
        ValueError: if ``outputs`` is not a list.
    """
    if not isinstance(outputs, list):
        raise ValueError("Frozen graph: outputs must be list of output node name")

    source_graph = tf.get_default_graph() if graph is None else graph
    graph_def = source_graph.as_graph_def()
    logging.info("Frozen graph: len of input graph nodes: {}".format(
        len(graph_def.node)))

    # Built-in TF helper that turns variables into constants.
    frozen_def = graph_util.convert_variables_to_constants(sess, graph_def, outputs)
    logging.info("Frozen graph: len of output graph nodes: {}".format(
        len(frozen_def.node)))  # pylint: disable=no-member

    with tf.gfile.GFile(frozen_graph_pb_path, "wb") as out_file:
        out_file.write(frozen_def.SerializeToString())
Example #3
Source File: savevariables.py From nextitnet with MIT License | 6 votes |
def save_mode_pb(pb_file_path):
    """Build a toy graph ``op_to_store = x * y + b``, freeze it and save it.

    The frozen GraphDef is written to ``pb_file_path``. The live graph is then
    evaluated once with ``x=2, y=4`` and the result is printed.

    Args:
        pb_file_path: destination of the frozen GraphDef. Missing parent
            directories are created.
    """
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(2, name='b')
    xy = tf.multiply(x, y)
    # The output op needs an explicit name so it can be listed as an output node.
    op = tf.add(xy, b, name='op_to_store')

    # ``with`` guarantees the session is closed, even if saving fails.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        path = os.path.dirname(os.path.abspath(pb_file_path))
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(path, exist_ok=True)

        # convert_variables_to_constants takes a list of output node names;
        # several outputs may be given.
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ['op_to_store'])
        with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
            f.write(constant_graph.SerializeToString())

        # Quick check: evaluate the graph once.
        feed_dict = {x: 2, y: 4}
        print(sess.run(op, feed_dict))
Example #4
Source File: session_bundle_test.py From deep_image_model with Apache License 2.0 | 6 votes |
def setUp(self):
    """Export a variable-free MetaGraphDef for the tests to load.

    A graph computing ``y = w * x - 7.0`` (with ``w = 3.0``) is built. Its
    variable is frozen to a constant, and the result is exported together with
    a ``"meta"`` collection entry to
    ``<temp>/no_vars/<constants.META_GRAPH_DEF_FILENAME>``.
    """
    self.base_path = os.path.join(tf.test.get_temp_dir(), "no_vars")
    if not os.path.exists(self.base_path):
        os.mkdir(self.base_path)

    # Create a simple graph with a variable, then convert variables to
    # constants and export the graph.
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, name="x")
        w = tf.Variable(3.0)
        # tf.sub was removed in TF 1.0 in favor of tf.subtract (same op).
        y = tf.subtract(w * x, 7.0, name="y")  # pylint: disable=unused-variable
        tf.add_to_collection("meta", "this is meta")

        with self.test_session(graph=g) as session:
            tf.global_variables_initializer().run()
            new_graph_def = graph_util.convert_variables_to_constants(
                session, g.as_graph_def(), ["y"])

        filename = os.path.join(self.base_path, constants.META_GRAPH_DEF_FILENAME)
        tf.train.export_meta_graph(
            filename, graph_def=new_graph_def, collection_list=["meta"])
Example #5
Source File: freeze_graph.py From MobileFaceNet_TF with Apache License 2.0 | 6 votes |
def freeze_graph_def(sess, input_graph_def, output_node_names):
    """Freeze ``input_graph_def`` by replacing whitelisted variables with constants.

    Before conversion, several ops are rewritten in place:

    * ``RefSwitch`` becomes ``Switch``, and every input whose name contains
      ``'moving_'`` is redirected to ``<input>/read``.
    * ``AssignSub`` / ``AssignAdd`` become ``Sub`` / ``Add`` and lose their
      ``use_locking`` attribute.

    NOTE(review): these rewrites look like the common workaround for freezing
    graphs that contain batch-norm moving-average updates. Confirm.

    Args:
        sess: Session holding the current variable values.
        input_graph_def: GraphDef to freeze. Its nodes are mutated in place.
        output_node_names: Comma-separated string of output node names.

    Returns:
        The GraphDef produced by ``graph_util.convert_variables_to_constants``.
    """
    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # ``range`` (not the Python 2-only ``xrange``) so this runs on Python 3.
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    # Only variables whose names appear in this list are converted to constants.
    whitelist_names = [
        node.name
        for node in input_graph_def.node
        if node.name.startswith(('MobileFaceNet', 'embeddings'))
    ]

    # Replace the whitelisted variables with constants of the same values.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names.split(","),
        variable_names_whitelist=whitelist_names)
    return output_graph_def
Example #6
Source File: keras_to_tf.py From FSA-Net with Apache License 2.0 | 6 votes |
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Freeze the state of ``session`` into a GraphDef with constants.

    Args:
        session: session whose graph and variable values are frozen.
        keep_var_names: variable op names to leave as variables. Every other
            global variable is converted to a constant.
        output_names: output node names to keep. The op names of all global
            variables are appended to a copy, so the caller's sequence is
            never modified. Any iterable (list, tuple) is accepted.
        clear_devices: when true, the ``device`` field of every node is cleared
            before conversion.

    Returns:
        The frozen GraphDef.
    """
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(
            set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        # Copy first: ``+=`` on the caller's list would silently mutate it
        # (and would fail outright on a tuple).
        output_names = list(output_names or [])
        output_names += [v.op.name for v in tf.global_variables()]
        # Graph -> GraphDef ProtoBuf
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph
Example #7
Source File: tf_utils.py From nlp_research with MIT License | 6 votes |
def write_pb(checkpoint_path, pb_path, output_nodes):
    """Freeze the latest checkpoint found in ``checkpoint_path`` into ``pb_path``.

    Args:
        checkpoint_path: directory searched with ``tf.train.latest_checkpoint``.
        pb_path: destination of the serialized, frozen GraphDef.
        output_nodes: list of output node names to keep (several are allowed).

    Raises:
        FileNotFoundError: if ``checkpoint_path`` contains no checkpoint.
    """
    checkpoint_file = tf.train.latest_checkpoint(checkpoint_path)
    if checkpoint_file is None:
        # latest_checkpoint returns None when nothing is found. Without this
        # check the failure would point at a confusing "None.meta" file.
        raise FileNotFoundError(
            "No checkpoint found in {}".format(checkpoint_path))

    # ``with`` closes the session even if restoring or freezing fails.
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
        saver.restore(sess, checkpoint_file)
        graph = tf.get_default_graph()  # the default graph
        input_graph_def = graph.as_graph_def()  # serialized form of the current graph
        constant_graph = graph_util.convert_variables_to_constants(
            sess, input_graph_def,  # equivalent to sess.graph_def
            output_nodes)

    # Write the serialized PB file.
    with tf.gfile.FastGFile(pb_path, mode='wb') as f:
        f.write(constant_graph.SerializeToString())
Example #8
Source File: freeze.py From honk with MIT License | 6 votes |
def main(_):
    """Build the inference graph, load checkpoint weights, then freeze and save.

    All settings come from ``FLAGS``. The frozen graph keeps the
    ``labels_softmax`` output and is written in binary form to
    ``FLAGS.output_file``.
    """
    # Create the model and load its weights.
    sess = tf.InteractiveSession()
    graph_settings = (
        FLAGS.wanted_words,
        FLAGS.sample_rate,
        FLAGS.clip_duration_ms,
        FLAGS.clip_stride_ms,
        FLAGS.window_size_ms,
        FLAGS.window_stride_ms,
        FLAGS.dct_coefficient_count,
        FLAGS.model_architecture,
    )
    create_inference_graph(*graph_settings)
    models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)

    # Bake every variable into an inline constant, then write the graph out.
    frozen = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['labels_softmax'])
    output_file = FLAGS.output_file
    tf.train.write_graph(
        frozen,
        os.path.dirname(output_file),
        os.path.basename(output_file),
        as_text=False,
    )
    tf.logging.info('Saved frozen graph to %s', output_file)
Example #9
Source File: convert.py From tf-encrypted with Apache License 2.0 | 6 votes |
def export_cnn() -> None:
    """Build a tiny conv -> sigmoid -> relu graph, freeze it and write ``./cnn.pb``.

    The graph takes a float32 placeholder of shape (1, 1, 3, 3) in NCHW
    layout and convolves it with an all-ones kernel of shape (3, 3, 1, 1).
    The result is exposed under the node name ``"output"``. The frozen
    GraphDef, with training nodes removed, is written in binary form.
    """
    # Named ``inputs``/``kernel`` rather than ``input``/``filter`` so the
    # builtins of the same name are not shadowed.
    inputs = tf.placeholder(tf.float32, shape=(1, 1, 3, 3))
    kernel = tf.constant(np.ones((3, 3, 1, 1)), dtype=tf.float32)
    x = tf.nn.conv2d(inputs, kernel, (1, 1, 1, 1), "SAME", data_format="NCHW")
    x = tf.nn.sigmoid(x)
    x = tf.nn.relu(x)

    pred_node_names = ["output"]
    tf.identity(x, name=pred_node_names[0])

    with tf.Session() as sess:
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), pred_node_names
        )

    frozen = graph_util.remove_training_nodes(constant_graph)

    output = "cnn.pb"
    graph_io.write_graph(frozen, ".", output, as_text=False)
Example #10
Source File: convert_test.py From tf-encrypted with Apache License 2.0 | 6 votes |
def export(x: tf.Tensor, filename: str, sess=None):
    """Freeze the graph producing ``x`` and write it to ``./<filename>``.

    ``x`` is exposed under the node name ``"output"``. Variables are replaced
    by constants and training-only nodes are removed.

    Args:
        x: tensor exported as the graph output.
        filename: output file name, relative to the current directory.
        sess: session holding the variable values. When omitted, a temporary
            session is created and always closed before returning.

    Returns:
        The path returned by ``graph_io.write_graph``.
    """
    should_close = sess is None
    if should_close:
        sess = tf.Session()

    try:
        pred_node_names = ["output"]
        tf.identity(x, name=pred_node_names[0])

        graph = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), pred_node_names
        )
        graph = graph_util.remove_training_nodes(graph)

        path = graph_io.write_graph(graph, ".", filename, as_text=False)
    finally:
        # Close only a session created here, and do so even if freezing or
        # writing raised.
        if should_close:
            sess.close()

    return path
Example #11
Source File: session_bundle_test.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def setUp(self):
    """Write a MetaGraphDef whose only variable has been frozen to a constant.

    The graph computes ``y = w * x - 7.0`` with ``w = 3.0``. After freezing,
    it is exported, along with a ``"meta"`` collection entry, into
    ``self.base_path``.
    """
    self.base_path = os.path.join(test.get_temp_dir(), "no_vars")
    if not os.path.exists(self.base_path):
        os.mkdir(self.base_path)

    graph = ops.Graph()
    with graph.as_default() as g:
        # A simple graph with one variable, which is frozen below.
        x = array_ops.placeholder(dtypes.float32, name="x")
        w = variables.Variable(3.0)
        y = math_ops.subtract(w * x, 7.0, name="y")  # pylint: disable=unused-variable
        ops.add_to_collection("meta", "this is meta")

        with self.test_session(graph=g) as session:
            variables.global_variables_initializer().run()
            frozen_graph_def = graph_util.convert_variables_to_constants(
                session, g.as_graph_def(), ["y"])

        export_path = os.path.join(self.base_path, constants.META_GRAPH_DEF_FILENAME)
        saver.export_meta_graph(
            export_path, graph_def=frozen_graph_def, collection_list=["meta"])
Example #12
Source File: private_model.py From tf-encrypted with Apache License 2.0 | 6 votes |
def secure_model(model, **kwargs):
    """Secure a plaintext model from the current session.

    The Keras session's graph is frozen down to ``model``'s outputs and
    written to ``_TMPDIR/model.pb``. It is then reloaded with ``load_graph``
    and converted with the tf-encrypted converter.

    Args:
        model: Keras model whose outputs select the nodes to keep.
        **kwargs: ``batch_size`` (default 1) is consumed here. Every other
            keyword is forwarded to ``tfe.convert.convert.Converter``.

    Returns:
        A ``PrivateModel`` wrapping the converted output.
    """
    session = K.get_session()
    graph_def_in = session.graph_def
    output_names = [node.op.name for node in model.outputs]
    frozen = graph_util.convert_variables_to_constants(
        session, graph_def_in, output_names
    )

    graph_fname = "model.pb"
    tf.train.write_graph(frozen, _TMPDIR, graph_fname, as_text=False)

    batch_size = kwargs.pop("batch_size", 1)
    graph_def, inputs = load_graph(
        os.path.join(_TMPDIR, graph_fname), batch_size=batch_size
    )

    converter = tfe.convert.convert.Converter(tfe.convert.registry(), **kwargs)
    y = converter.convert(remove_training_nodes(graph_def), "input-provider", inputs)
    return PrivateModel(y)
Example #13
Source File: image_classifier_tf.py From aiexamples with Apache License 2.0 | 6 votes |
def keras_to_tensorflow(keras_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
    """Freeze a Keras model into a binary ``.pb`` file.

    The i-th model output (0-based) is aliased by an identity node named
    ``out_prefix + str(i + 1)``, and those names are the nodes kept when
    freezing.

    Args:
        keras_model: Keras model. Its outputs are read from ``keras_model.outputs``.
        output_dir: destination directory, created (with parents) if missing.
        model_name: file name of the frozen graph inside ``output_dir``.
        out_prefix: prefix of the generated output node names.
        log_tensorboard: if true, also import the written graph for
            TensorBoard into ``output_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)

    out_nodes = []
    # Use ``outputs`` (always a list). For a single-output model ``output`` is
    # a bare tensor, so ``output[i]`` would slice that tensor instead of
    # selecting the i-th output.
    for i, output_tensor in enumerate(keras_model.outputs):
        node_name = out_prefix + str(i + 1)
        out_nodes.append(node_name)
        tf.identity(output_tensor, node_name)

    sess = K.get_session()
    init_graph = sess.graph.as_graph_def()
    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)

    if log_tensorboard:
        import_pb_to_tensorboard.import_to_tensorboard(
            os.path.join(output_dir, model_name), output_dir)
Example #14
Source File: freeze.py From adversarial_audio with MIT License | 6 votes |
def main(_):
    """Build the inference graph, load checkpoint weights, then freeze and save.

    All settings come from ``FLAGS``. The frozen graph keeps the
    ``labels_softmax`` output and is written in binary form to
    ``FLAGS.output_file``.
    """
    # Create the model and load its weights.
    sess = tf.InteractiveSession()
    graph_settings = (
        FLAGS.wanted_words,
        FLAGS.sample_rate,
        FLAGS.clip_duration_ms,
        FLAGS.clip_stride_ms,
        FLAGS.window_size_ms,
        FLAGS.window_stride_ms,
        FLAGS.dct_coefficient_count,
        FLAGS.model_architecture,
    )
    create_inference_graph(*graph_settings)
    models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)

    # Bake every variable into an inline constant, then write the graph out.
    frozen = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['labels_softmax'])
    output_file = FLAGS.output_file
    tf.train.write_graph(
        frozen,
        os.path.dirname(output_file),
        os.path.basename(output_file),
        as_text=False,
    )
    tf.logging.info('Saved frozen graph to %s', output_file)
Example #15
Source File: freeze.py From TF_SpeechRecoChallenge with Apache License 2.0 | 6 votes |
def main(_):
    """Build the inference graph, load checkpoint weights, then freeze and save.

    All settings (including ``resnet_size``) come from ``FLAGS``. The frozen
    graph keeps the ``labels_softmax`` output and is written in binary form
    to ``FLAGS.output_file``.
    """
    # Create the model and load its weights.
    sess = tf.InteractiveSession()
    graph_settings = (
        FLAGS.wanted_words,
        FLAGS.sample_rate,
        FLAGS.clip_duration_ms,
        FLAGS.clip_stride_ms,
        FLAGS.window_size_ms,
        FLAGS.window_stride_ms,
        FLAGS.dct_coefficient_count,
        FLAGS.resnet_size,
        FLAGS.model_architecture,
    )
    create_inference_graph(*graph_settings)
    models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)

    # Bake every variable into an inline constant, then write the graph out.
    frozen = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['labels_softmax'])
    output_file = FLAGS.output_file
    tf.train.write_graph(
        frozen,
        os.path.dirname(output_file),
        os.path.basename(output_file),
        as_text=False,
    )
    tf.logging.info('Saved frozen graph to %s', output_file)
Example #16
Source File: train.py From photonix with GNU Affero General Public License v3.0 | 5 votes |
def save_graph_to_file(sess, graph, graph_file_name):
    """Freeze ``graph`` using ``sess``'s variable values and write it to disk.

    The nodes needed for ``FLAGS.final_tensor_name`` are kept, with variables
    turned into constants. The result is written in binary form to
    ``graph_file_name``.
    """
    keep_nodes = [FLAGS.final_tensor_name]
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), keep_nodes)
    with gfile.FastGFile(graph_file_name, 'wb') as out_file:
        out_file.write(frozen_graph_def.SerializeToString())
Example #17
Source File: retrain.py From Elphas with Apache License 2.0 | 5 votes |
def save_graph_to_file(graph, graph_file_name, model_info, class_count):
    """Save the evaluation graph to file.

    A fresh evaluation session is built with ``build_eval_session``. Its graph
    is frozen down to ``FLAGS.final_tensor_name`` and written in binary form
    to ``graph_file_name``.

    NOTE(review): the original summary says a valid quantized graph is created
    if necessary. That presumably happens inside ``build_eval_session``.
    Confirm.

    Args:
        graph: unused. The evaluation session's graph is saved instead; the
            parameter is kept for interface compatibility.
        graph_file_name: destination path.
        model_info: forwarded to ``build_eval_session``.
        class_count: forwarded to ``build_eval_session``.
    """
    sess, _, _, _, _ = build_eval_session(model_info, class_count)
    try:
        eval_graph = sess.graph
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, eval_graph.as_graph_def(), [FLAGS.final_tensor_name])
    finally:
        # The session was created solely for this export; release it.
        sess.close()

    with gfile.FastGFile(graph_file_name, 'wb') as f:
        f.write(output_graph_def.SerializeToString())
Example #18
Source File: test_bn_dynamic.py From incubator-tvm with Apache License 2.0 | 5 votes |
def verify_fused_batch_norm(shape):
    """Check that TVM reproduces TensorFlow's ``fused_batch_norm`` result.

    A random input of ``shape`` is normalized with random scale and offset
    vectors of length ``shape[-1]``. TensorFlow produces the reference output.
    The frozen graph is then imported with ``relay.frontend.from_tensorflow``,
    built for each enabled target, run, and compared with atol=rtol=1e-3.
    """
    channels = shape[-1]
    g = tf.Graph()
    with g.as_default():
        input_tensor = tf.placeholder(tf.float32, shape=shape, name='input')
        alpha = tf.constant(np.random.rand(channels,), dtype=tf.float32, name='alpha')
        beta = tf.constant(np.random.rand(channels,), dtype=tf.float32, name='beta')
        bn = tf.nn.fused_batch_norm(x=input_tensor, offset=beta, scale=alpha, name='bn')
        out = tf.identity(bn[0], name='output')

    data = np.random.rand(*shape)
    with tf.Session(graph=out.graph) as sess:
        sess.run([tf.global_variables_initializer()])
        tf_out = sess.run(out, feed_dict={input_tensor: data})
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ['output'])

    for device in ["llvm"]:
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            continue

        mod, params = relay.frontend.from_tensorflow(constant_graph, outputs=['output'])
        with tvm.transform.PassContext(opt_level=3):
            graph, lib, params = relay.build(mod, target=device, params=params)

        # Imported here so it only runs once a target is actually available.
        from tvm.contrib import graph_runtime
        runtime = graph_runtime.create(graph, lib, ctx)
        runtime.set_input(**params)
        runtime.set_input('input', data)
        runtime.run()
        tvm_out = runtime.get_output(0)
        tvm.testing.assert_allclose(
            tvm_out.asnumpy(), tf_out.astype(tvm_out.dtype), atol=1e-3, rtol=1e-3)
Example #19
Source File: retrain.py From Image-classification-transfer-learning with Apache License 2.0 | 5 votes |
def save_graph_to_file(sess, graph, graph_file_name):
    """Freeze ``graph`` using ``sess``'s variable values and write it to disk.

    The nodes needed for ``FLAGS.final_tensor_name`` are kept, with variables
    turned into constants. The result is written in binary form to
    ``graph_file_name``.
    """
    keep_nodes = [FLAGS.final_tensor_name]
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), keep_nodes)
    with gfile.FastGFile(graph_file_name, 'wb') as out_file:
        out_file.write(frozen_graph_def.SerializeToString())
Example #20
Source File: freeze_graph.py From pynlp with MIT License | 5 votes |
def freeze_graph(output_graph, model_path=None,
                 output_node_names="accuracy/temp_sim,output/distance"):
    """Freeze a TensorFlow checkpoint into a PB model.

    The variable names stored in the checkpoint are printed first. Then the
    meta graph ``model_path + ".meta"`` is imported into the default graph,
    the weights are restored, and the graph is frozen down to
    ``output_node_names``.

    Args:
        output_graph: path where the frozen PB model is saved.
        model_path: checkpoint prefix (without ``.meta``). Defaults to the
            previously hard-coded checkpoint path.
        output_node_names: comma-separated output node names. Each name must
            be a node that exists in the original model.
    """
    if model_path is None:
        # Raw string: the old non-raw literal relied on invalid escapes such
        # as "\p" and "\c", which warn on modern Python (SyntaxWarning in 3.12+).
        model_path = r"F:\python_work\siamese-lstm-network\deep-siamese-text-similarity\atec_runs\1553238291\checkpoints\model-170000"

    from tensorflow.python import pywrap_tensorflow
    reader = pywrap_tensorflow.NewCheckpointReader(model_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        print("tensor_name: ", key)

    input_checkpoint = model_path + ".meta"
    saver = tf.train.import_meta_graph(input_checkpoint, clear_devices=False)
    graph = tf.get_default_graph()  # the default graph
    input_graph_def = graph.as_graph_def()  # serialized form of the current graph

    with tf.Session() as sess:
        saver.restore(sess, model_path)  # restore the weights into the graph
        # Persist the model by fixing the variable values as constants.
        output_graph_def = graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=input_graph_def,  # equivalent to sess.graph_def
            output_node_names=output_node_names.split(","))  # multiple outputs are comma-separated

        with tf.gfile.GFile(output_graph, "wb") as f:  # save the model
            f.write(output_graph_def.SerializeToString())  # serialize and write
        print("%d ops in the final graph." % len(output_graph_def.node))  # node count of the frozen graph
Example #21
Source File: util.py From pynlp with MIT License | 5 votes |
def freeze_graph(output_graph, model_path=None,
                 output_node_names="output/distance,accuracy/accuracy,accuracy/temp_sim"):
    """Freeze a TensorFlow checkpoint into a PB model.

    The meta graph ``model_path + ".meta"`` is imported into the default
    graph, the weights are restored, and the graph is frozen down to
    ``output_node_names``.

    Args:
        output_graph: path where the frozen PB model is saved.
        model_path: checkpoint prefix (without ``.meta``). Defaults to the
            previously hard-coded checkpoint path.
        output_node_names: comma-separated output node names. Each name must
            be a node that exists in the original model.
    """
    if model_path is None:
        # Raw string: the old non-raw literal relied on invalid escapes such
        # as "\p" and "\c", which warn on modern Python (SyntaxWarning in 3.12+).
        model_path = r"F:\python_work\siamese-lstm-network\deep-siamese-text-similarity\runs\1546075513\checkpoints\model-485000"

    input_checkpoint = model_path + ".meta"
    saver = tf.train.import_meta_graph(input_checkpoint, clear_devices=False)
    graph = tf.get_default_graph()  # the default graph
    input_graph_def = graph.as_graph_def()  # serialized form of the current graph

    with tf.Session() as sess:
        saver.restore(sess, model_path)  # restore the weights into the graph
        # Persist the model by fixing the variable values as constants.
        output_graph_def = graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=input_graph_def,  # equivalent to sess.graph_def
            output_node_names=output_node_names.split(","))  # multiple outputs are comma-separated

        with tf.gfile.GFile(output_graph, "wb") as f:  # save the model
            f.write(output_graph_def.SerializeToString())  # serialize and write
        print("%d ops in the final graph." % len(output_graph_def.node))  # node count of the frozen graph
Example #22
Source File: retrain.py From Face_ID with MIT License | 5 votes |
def save_graph_to_file(sess, graph, graph_file_name):
    """Freeze ``graph`` using ``sess``'s variable values and write it to disk.

    The nodes needed for ``FLAGS.final_tensor_name`` are kept, with variables
    turned into constants. The result is written in binary form to
    ``graph_file_name``.
    """
    keep_nodes = [FLAGS.final_tensor_name]
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), keep_nodes)
    with gfile.FastGFile(graph_file_name, 'wb') as out_file:
        out_file.write(frozen_graph_def.SerializeToString())
Example #23
Source File: freeze_graph.py From pynlp with MIT License | 5 votes |
def freeze_graph(output_graph, model_path=None,
                 output_node_names="composition/feed_forward/feed_foward_layer2/dense/Tanh"):
    """Freeze a TensorFlow checkpoint into a PB model.

    The variable names stored in the checkpoint are printed first. Then the
    meta graph ``model_path + ".meta"`` is imported into the default graph,
    the weights are restored, and the graph is frozen down to
    ``output_node_names``.

    Args:
        output_graph: path where the frozen PB model is saved.
        model_path: checkpoint prefix (without ``.meta``). Defaults to the
            previously hard-coded checkpoint path.
        output_node_names: comma-separated output node names. Each name must
            be a node that exists in the original model.
    """
    if model_path is None:
        # Raw string: the old non-raw literal relied on invalid escapes such
        # as "\p", "\E" and "\c", which warn on modern Python (SyntaxWarning in 3.12+).
        model_path = r"F:\python_work\siamese-lstm-network\ESIM\copy\checkpoint_dev_loss_0.3141.ckpt"

    from tensorflow.python import pywrap_tensorflow
    reader = pywrap_tensorflow.NewCheckpointReader(model_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        print("tensor_name: ", key)

    input_checkpoint = model_path + ".meta"
    saver = tf.train.import_meta_graph(input_checkpoint, clear_devices=False)
    graph = tf.get_default_graph()  # the default graph
    input_graph_def = graph.as_graph_def()  # serialized form of the current graph

    with tf.Session() as sess:
        saver.restore(sess, model_path)  # restore the weights into the graph
        # Persist the model by fixing the variable values as constants.
        output_graph_def = graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=input_graph_def,  # equivalent to sess.graph_def
            output_node_names=output_node_names.split(","))  # multiple outputs are comma-separated

        with tf.gfile.GFile(output_graph, "wb") as f:  # save the model
            f.write(output_graph_def.SerializeToString())  # serialize and write
        print("%d ops in the final graph." % len(output_graph_def.node))  # node count of the frozen graph
Example #24
Source File: util.py From VNet3D with MIT License | 5 votes |
def convertMetaModelToPbModel(meta_model, pb_model):
    """Convert a checkpoint into a frozen ``model.pb`` inside ``pb_model``.

    Every op name of the imported graph is printed (to help identify output
    nodes). The graph is then frozen down to ``Placeholder`` and
    ``output/Sigmoid`` and written in binary form.

    Args:
        meta_model: checkpoint prefix. ``meta_model + '.meta'`` must exist.
        pb_model: directory that receives ``model.pb``.
    """
    import os

    # Step 1: import the model metagraph (device placements cleared) into
    # the default graph.
    saver = tf.train.import_meta_graph(meta_model + '.meta', clear_devices=True)
    graph = tf.get_default_graph()

    # ``with`` closes the session even if restoring or freezing fails.
    with tf.Session() as sess:
        # Restore the variables.
        saver.restore(sess, meta_model)

        # Step 2: list op names so the output node can be identified.
        for op in graph.get_operations():
            print(op.name)

        # Step 3: replace variables with constants.
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,  # the session
            sess.graph_def,  # input_graph_def is used to look up the nodes
            ["Placeholder", "output/Sigmoid"])

    # Step 4: write the graph. os.path.join inserts the separator that plain
    # ``pb_model + './'`` omitted when pb_model had no trailing slash.
    output_fld = './'
    output_model_file = 'model.pb'
    graph_io.write_graph(output_graph_def, os.path.join(pb_model, output_fld),
                         output_model_file, as_text=False)
Example #25
Source File: freeze_graph.py From facenet-demo with MIT License | 5 votes |
def freeze_graph_def(sess, input_graph_def, output_node_names):
    """Freeze ``input_graph_def`` by replacing whitelisted variables with constants.

    Before conversion, several ops are rewritten in place:

    * ``RefSwitch`` becomes ``Switch``, and every input whose name contains
      ``'moving_'`` is redirected to ``<input>/read``.
    * ``AssignSub`` / ``AssignAdd`` become ``Sub`` / ``Add`` and lose their
      ``use_locking`` attribute.

    NOTE(review): these rewrites look like the common workaround for freezing
    graphs that contain batch-norm moving-average updates. Confirm.

    Args:
        sess: Session holding the current variable values.
        input_graph_def: GraphDef to freeze. Its nodes are mutated in place.
        output_node_names: Comma-separated string of output node names.

    Returns:
        The GraphDef produced by ``graph_util.convert_variables_to_constants``.
    """
    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # ``range`` (not the Python 2-only ``xrange``) so this runs on Python 3.
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    # Only variables whose names appear in this list are converted to constants.
    important_prefixes = ('InceptionResnet', 'embeddings', 'image_batch',
                          'label_batch', 'phase_train', 'Logits')
    whitelist_names = [
        node.name
        for node in input_graph_def.node
        if node.name.startswith(important_prefixes)
    ]

    # Replace the whitelisted variables with constants of the same values.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names.split(","),
        variable_names_whitelist=whitelist_names)
    return output_graph_def
Example #26
Source File: train.py From delta with Apache License 2.0 | 5 votes |
def to_graph_def(graph_path):
    """Build the model, freeze it, and write the frozen GraphDef to ``graph_path``.

    ``GraphSummary`` is constructed from the session graph, and its
    ``"outputs"`` entry supplies the output node names used for freezing.
    """
    with tf.Session() as sess:
        # Return value is unused. NOTE(review): presumably called for its
        # side effect of adding the model to the default graph. Confirm.
        create_model()
        tf.global_variables_initializer().run()

        summary = GraphSummary(graph_def=sess.graph_def)
        summary.Summary()
        output_names = summary["outputs"]
        frozen = graph_util.convert_variables_to_constants(
            sess, sess.graph_def, output_names)

        with open(graph_path, "wb") as out_file:
            out_file.write(frozen.SerializeToString())
Example #27
Source File: retrain.py From Attendace_management_system with MIT License | 5 votes |
def save_graph_to_file(sess, graph, graph_file_name):
    """Freeze ``graph`` using ``sess``'s variable values and write it to disk.

    The nodes needed for ``FLAGS.final_tensor_name`` are kept, with variables
    turned into constants. The result is written in binary form to
    ``graph_file_name``.
    """
    keep_nodes = [FLAGS.final_tensor_name]
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), keep_nodes)
    with gfile.FastGFile(graph_file_name, 'wb') as out_file:
        out_file.write(frozen_graph_def.SerializeToString())
Example #28
Source File: retrain.py From aiexamples with Apache License 2.0 | 5 votes |
def save_graph_to_file(sess, graph, graph_file_name):
    """Freeze ``graph`` using ``sess``'s variable values and write it to disk.

    The nodes needed for ``FLAGS.final_tensor_name`` are kept, with variables
    turned into constants. The result is written in binary form to
    ``graph_file_name``.
    """
    keep_nodes = [FLAGS.final_tensor_name]
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), keep_nodes)
    with gfile.FastGFile(graph_file_name, 'wb') as out_file:
        out_file.write(frozen_graph_def.SerializeToString())
Example #29
Source File: freeze_graph.py From uai-sdk with Apache License 2.0 | 5 votes |
def freeze_graph_def(sess, input_graph_def, output_node_names):
    """Freeze ``input_graph_def`` by replacing whitelisted variables with constants.

    Before conversion, several ops are rewritten in place:

    * ``RefSwitch`` becomes ``Switch``, and every input whose name contains
      ``'moving_'`` is redirected to ``<input>/read``.
    * ``AssignSub`` / ``AssignAdd`` become ``Sub`` / ``Add`` and lose their
      ``use_locking`` attribute.

    NOTE(review): these rewrites look like the common workaround for freezing
    graphs that contain batch-norm moving-average updates. Confirm.

    Args:
        sess: Session holding the current variable values.
        input_graph_def: GraphDef to freeze. Its nodes are mutated in place.
        output_node_names: Comma-separated string of output node names.

    Returns:
        The GraphDef produced by ``graph_util.convert_variables_to_constants``.
    """
    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # ``range`` (not the Python 2-only ``xrange``) so this runs on Python 3.
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    # Only variables whose names appear in this list are converted to constants.
    important_prefixes = ('InceptionResnet', 'embeddings', 'image_batch',
                          'label_batch', 'phase_train', 'Logits')
    whitelist_names = [
        node.name
        for node in input_graph_def.node
        if node.name.startswith(important_prefixes)
    ]

    # Replace the whitelisted variables with constants of the same values.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names.split(","),
        variable_names_whitelist=whitelist_names)
    return output_graph_def
Example #30
Source File: tf_retrain.py From image_recognition with MIT License | 5 votes |
def save_graph_to_file(sess, graph, graph_file_name):
    """Freeze ``graph`` using ``sess``'s variable values and write it to disk.

    The nodes needed for ``FLAGS.final_tensor_name`` are kept, with variables
    turned into constants. The result is written in binary form to
    ``graph_file_name``.
    """
    keep_nodes = [FLAGS.final_tensor_name]
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), keep_nodes)
    with gfile.FastGFile(graph_file_name, 'wb') as out_file:
        out_file.write(frozen_graph_def.SerializeToString())