Python tensorflow.python.pywrap_tensorflow.NewCheckpointReader() Examples
The following are 30 code examples of tensorflow.python.pywrap_tensorflow.NewCheckpointReader().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow.python.pywrap_tensorflow, or try the search function.
Example #1
Source File: tf_policy_network.py From crossgap_il_rl with GNU General Public License v2.0 | 6 votes |
def resort_para_form_checkpoint(self, _ckpt_name_vec, graph, sess):
    """Copy matching variable values from one or more checkpoints into `graph`.

    Every variable ``<key>`` stored in each checkpoint is looked up in `graph`
    as the tensor ``"<key>:0"`` and assigned the checkpointed value in `sess`.
    Variables that cannot be restored (absent from the graph, incompatible
    value, ...) are skipped silently, so partial restores are allowed.

    Args:
        _ckpt_name_vec: a checkpoint path, or a list/tuple of checkpoint paths.
        graph: graph that holds the target variables.
        sess: session used to run the assignment ops.
    """
    # Accept a single path as well as a sequence of paths.
    if isinstance(_ckpt_name_vec, (list, tuple)):
        ckpt_name_vec = list(_ckpt_name_vec)
    else:
        ckpt_name_vec = [_ckpt_name_vec]

    with tf.name_scope("restore"):
        for ckpt_name in ckpt_name_vec:
            print("===== Restore data from %s =====" % ckpt_name)
            reader = pywrap_tensorflow.NewCheckpointReader(ckpt_name)
            for key in reader.get_variable_to_shape_map():
                try:
                    tensor = graph.get_tensor_by_name(key + ":0")
                    sess.run(tf.assign(tensor, reader.get_tensor(key)))
                except Exception:  # pylint: disable=broad-except
                    # Best-effort restore: skip variables that cannot be
                    # restored. Catching Exception (not a bare `except:`)
                    # keeps KeyboardInterrupt/SystemExit working.
                    continue
Example #2
Source File: inspect_checkpoint.py From tf.fashionAI with Apache License 2.0 | 6 votes |
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
    """Print the contents of a checkpoint.

    Args:
        file_name: checkpoint path (prefix) to open.
        tensor_name: name of a single tensor to dump. When it is falsy (and
            `all_tensors` is falsy too) the reader's debug string is printed.
        all_tensors: if truthy, print every stored name followed by its value.

    Errors are reported on stdout instead of being raised.
    """
    try:
        ckpt_reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        if not all_tensors:
            if tensor_name:
                print("tensor_name: ", tensor_name)
                print(ckpt_reader.get_tensor(tensor_name))
            else:
                print(ckpt_reader.debug_string().decode("utf-8"))
            return
        for name in ckpt_reader.get_variable_to_shape_map():
            print("tensor_name: ", name)
            print(ckpt_reader.get_tensor(name))
    except Exception as err:  # pylint: disable=broad-except
        message = str(err)
        print(message)
        if "corrupted compressed block contents" in message:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
Example #3
Source File: tf_policy_network.py From crossgap_il_rl with GNU General Public License v2.0 | 6 votes |
def resore_form_rl_net(self, ckpt_name, graph, sess):
    """Restore RL-policy weights stored under ``pi/pi/...`` into `graph`.

    Only checkpoint variables under ``pi/pi/net/`` or
    ``pi/pi/Trajectory_follower_mlp_net/`` are restored. The first two path
    components are dropped to obtain the target tensor name. Every stored key
    is printed, and failed assignments are reported and skipped.
    """
    print("Restore form RL net")
    print("===== Prase data from %s =====" % ckpt_name)
    net_prefix = 'pi/pi'
    wanted_prefixes = ('%s/net/' % net_prefix,
                       '%s/Trajectory_follower_mlp_net/' % net_prefix)
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_name)
    for ckpt_key in reader.get_variable_to_shape_map():
        print(ckpt_key)
        if not str(ckpt_key).startswith(wanted_prefixes):
            continue
        # Keep everything after the second '/', i.e. strip the "pi/pi/" scope.
        key = ckpt_key.split('/', 2)[2] + ":0"
        try:
            tensor = graph.get_tensor_by_name(key)
            sess.run(tf.assign(tensor, reader.get_tensor(ckpt_key)))
        except Exception as e:
            print(key, " can not be restored, e= ", str(e))
Example #4
Source File: checkpint_inspect.py From SSD.TensorFlow with Apache License 2.0 | 6 votes |
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
    """Print the contents of a checkpoint.

    Args:
        file_name: checkpoint path (prefix) to open.
        tensor_name: name of a single tensor to dump. When it is falsy (and
            `all_tensors` is falsy too) the reader's debug string is printed.
        all_tensors: if truthy, print every stored name followed by its value.

    Errors are reported on stdout instead of being raised.
    """
    try:
        ckpt_reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        if not all_tensors:
            if tensor_name:
                print("tensor_name: ", tensor_name)
                print(ckpt_reader.get_tensor(tensor_name))
            else:
                print(ckpt_reader.debug_string().decode("utf-8"))
            return
        for name in ckpt_reader.get_variable_to_shape_map():
            print("tensor_name: ", name)
            print(ckpt_reader.get_tensor(name))
    except Exception as err:  # pylint: disable=broad-except
        message = str(err)
        print(message)
        if "corrupted compressed block contents" in message:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
Example #5
Source File: checkpint_inspect.py From inference with Apache License 2.0 | 6 votes |
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
    """Print the contents of a checkpoint.

    Args:
        file_name: checkpoint path (prefix) to open.
        tensor_name: name of a single tensor to dump. When it is falsy (and
            `all_tensors` is falsy too) the reader's debug string is printed.
        all_tensors: if truthy, print every stored name followed by its value.

    Errors are reported on stdout instead of being raised.
    """
    try:
        ckpt_reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        if not all_tensors:
            if tensor_name:
                print("tensor_name: ", tensor_name)
                print(ckpt_reader.get_tensor(tensor_name))
            else:
                print(ckpt_reader.debug_string().decode("utf-8"))
            return
        for name in ckpt_reader.get_variable_to_shape_map():
            print("tensor_name: ", name)
            print(ckpt_reader.get_tensor(name))
    except Exception as err:  # pylint: disable=broad-except
        message = str(err)
        print(message)
        if "corrupted compressed block contents" in message:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
Example #6
Source File: npz_file_to_checkpoint.py From will-people-like-your-image with GNU Lesser General Public License v3.0 | 6 votes |
def create_model_from_npz_file(npz, model, target):
    """Create a TensorFlow checkpoint from variables stored in an npz structure.

    Only the npz variables accepted by ``key_contained_in_map`` against the
    variable names of the `target` checkpoint are turned into ``tf.Variable``s.
    All global variables are then initialized and saved.

    Args:
        npz: path to the npz structure holding the variable values.
        model: path at which the resulting checkpoint is saved.
        target: checkpoint whose variable names define which npz entries
            are kept.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(target)
    target_map = reader.get_variable_to_shape_map()
    npz_variables = variables_dictionary_from_npz_file(npz)
    for key in npz_variables:
        if key_contained_in_map(key, target_map):
            # tf.Variable registers itself in the default graph, so the Saver
            # below picks it up. The former `exec(name + " = val")` only
            # assigned into a throw-away locals dict and had no effect.
            tf.Variable(npz_variables[key], name=key)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        save_path = saver.save(sess, model)
    print("Model saved in file: %s" % save_path)
Example #7
Source File: SSD512.py From Object-Detection-API-Tensorflow with MIT License | 5 votes |
def __init__(self, config, data_provider):
    """Store the configuration, then build the SSD512 graph and session.

    Args:
        config: dict of settings. Keys read here: 'mode' ('train' or 'test'),
            'data_format' ('channels_first' or 'channels_last'),
            'num_classes', 'weight_decay', 'keep_prob',
            'nms_score_threshold', 'nms_max_boxes', 'nms_iou_threshold',
            'pretraining_weight' (checkpoint opened with NewCheckpointReader)
            and, in train mode only, 'batch_size'.
        data_provider: dict. In train mode the keys 'num_train', 'num_val',
            'train_generator' (an (initializer, iterator) pair) and
            'val_generator' (such a pair, or None) are read.
    """
    assert config['mode'] in ['train', 'test']
    assert config['data_format'] in ['channels_first', 'channels_last']
    self.config = config
    self.data_provider = data_provider

    side = 512
    self.input_size = side
    self.data_shape = ([side, side, 3]
                       if config['data_format'] == 'channels_last'
                       else [3, side, side])
    # One extra class id on top of config['num_classes'] (presumably background).
    self.num_classes = config['num_classes'] + 1
    self.weight_decay = config['weight_decay']
    # Stored as 1 - keep_prob.
    self.prob = 1. - config['keep_prob']
    self.data_format = config['data_format']
    self.mode = config['mode']
    # Outside training the batch size is fixed to 1.
    self.batch_size = config['batch_size'] if self.mode == 'train' else 1
    self.nms_score_threshold = config['nms_score_threshold']
    self.nms_max_boxes = config['nms_max_boxes']
    self.nms_iou_threshold = config['nms_iou_threshold']
    self.reader = wrap.NewCheckpointReader(config['pretraining_weight'])

    if self.mode == 'train':
        self.num_train = data_provider['num_train']
        self.num_val = data_provider['num_val']
        self.train_generator = data_provider['train_generator']
        self.train_initializer, self.train_iterator = self.train_generator
        val_generator = data_provider['val_generator']
        if val_generator is not None:
            self.val_generator = val_generator
            self.val_initializer, self.val_iterator = val_generator

    self.global_step = tf.get_variable(name='global_step',
                                       initializer=tf.constant(0),
                                       trainable=False)
    self.is_training = True

    # Build inputs, network and saver; summaries only when training.
    self._define_inputs()
    self._build_graph()
    self._create_saver()
    if self.mode == 'train':
        self._create_summary()
    self._init_session()
Example #8
Source File: npz_file_to_checkpoint.py From will-people-like-your-image with GNU Lesser General Public License v3.0 | 5 votes |
def print_variables_from_stored_model(graph_path):
    """Print the name of every tensor stored in a TensorFlow checkpoint.

    Args:
        graph_path: path (prefix) of the stored checkpoint.
    """
    shape_map = pywrap_tensorflow.NewCheckpointReader(
        graph_path).get_variable_to_shape_map()
    for tensor_name in shape_map:
        print("tensor_name: ", tensor_name)
Example #9
Source File: freeze_graph.py From pynlp with MIT License | 5 votes |
def freeze_graph(output_graph,
                 model_path=r"F:\python_work\siamese-lstm-network\deep-siamese-text-similarity\atec_runs\1553238291\checkpoints\model-170000",
                 output_node_names="accuracy/temp_sim,output/distance"):
    """Freeze a checkpointed graph into a single serialized GraphDef (.pb).

    Args:
        output_graph: path of the .pb file to write.
        model_path: checkpoint prefix to restore; its ``.meta`` file must be
            ``model_path + ".meta"``. Defaults to the original hard-coded
            checkpoint (now a raw string, so the Windows backslashes are no
            longer invalid escape sequences).
        output_node_names: comma-separated names of the output nodes. They
            must exist in the original graph.
    """
    from tensorflow.python import pywrap_tensorflow

    # Debugging aid: list the variables stored in the checkpoint.
    reader = pywrap_tensorflow.NewCheckpointReader(model_path)
    for key in reader.get_variable_to_shape_map():
        print("tensor_name: ", key)

    input_checkpoint = model_path + ".meta"
    saver = tf.train.import_meta_graph(input_checkpoint, clear_devices=False)
    graph = tf.get_default_graph()
    # Serialized GraphDef of the current (default) graph.
    input_graph_def = graph.as_graph_def()

    with tf.Session() as sess:
        # Restore the graph's variable values from the checkpoint.
        saver.restore(sess, model_path)
        # Replace variables by constants holding their restored values.
        output_graph_def = graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=input_graph_def,
            output_node_names=output_node_names.split(","))
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
Example #10
Source File: train.py From GeetChinese_crack with MIT License | 5 votes |
def get_variables_in_checkpoint_file(self, file_name):
    """Return the {variable name: shape} map of checkpoint `file_name`.

    On failure the error is printed (with a SNAPPY hint when relevant) and
    None is returned.
    """
    try:
        return pywrap_tensorflow.NewCheckpointReader(
            file_name).get_variable_to_shape_map()
    except Exception as err:  # pylint: disable=broad-except
        text = str(err)
        print(text)
        if "corrupted compressed block contents" in text:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
    return None
Example #11
Source File: RefineDet.py From Object-Detection-API-Tensorflow with MIT License | 5 votes |
def __init__(self, config, data_provider):
    """Store the configuration, then build the RefineDet graph and session.

    Args:
        config: dict of settings. Keys read here: 'mode' ('train' or 'test'),
            'data_format' ('channels_first' or 'channels_last'),
            'input_size', 'num_classes', 'weight_decay', 'keep_prob',
            'nms_score_threshold', 'nms_max_boxes', 'nms_iou_threshold',
            'pretraining_weight' (checkpoint opened with NewCheckpointReader)
            and, in train mode only, 'batch_size'.
        data_provider: dict. In train mode the keys 'num_train', 'num_val',
            'train_generator' (an (initializer, iterator) pair) and
            'val_generator' (such a pair, or None) are read.
    """
    assert config['mode'] in ['train', 'test']
    assert config['data_format'] in ['channels_first', 'channels_last']
    self.config = config
    self.data_provider = data_provider

    side = config['input_size']
    self.input_size = side
    self.data_shape = ([side, side, 3]
                       if config['data_format'] == 'channels_last'
                       else [3, side, side])
    # One extra class id on top of config['num_classes'] (presumably background).
    self.num_classes = config['num_classes'] + 1
    self.weight_decay = config['weight_decay']
    # Stored as 1 - keep_prob.
    self.prob = 1. - config['keep_prob']
    self.data_format = config['data_format']
    self.mode = config['mode']
    # Outside training the batch size is fixed to 1.
    self.batch_size = config['batch_size'] if self.mode == 'train' else 1
    self.anchor_ratios = [0.5, 1.0, 2.0]
    self.num_anchors = len(self.anchor_ratios)
    self.nms_score_threshold = config['nms_score_threshold']
    self.nms_max_boxes = config['nms_max_boxes']
    self.nms_iou_threshold = config['nms_iou_threshold']
    self.reader = wrap.NewCheckpointReader(config['pretraining_weight'])

    if self.mode == 'train':
        self.num_train = data_provider['num_train']
        self.num_val = data_provider['num_val']
        self.train_generator = data_provider['train_generator']
        self.train_initializer, self.train_iterator = self.train_generator
        val_generator = data_provider['val_generator']
        if val_generator is not None:
            self.val_generator = val_generator
            self.val_initializer, self.val_iterator = val_generator

    self.global_step = tf.get_variable(name='global_step',
                                       initializer=tf.constant(0),
                                       trainable=False)
    self.is_training = True

    # Build inputs, network and saver; summaries only when training.
    self._define_inputs()
    self._build_graph()
    self._create_saver()
    if self.mode == 'train':
        self._create_summary()
    self._init_session()
Example #12
Source File: saver.py From lighttrack with MIT License | 5 votes |
def get_variables_in_checkpoint_file(file_name):
    """Return the {variable name: shape} map of checkpoint `file_name`.

    On failure the error is printed (with a SNAPPY hint when relevant) and
    None is returned.
    """
    try:
        ckpt_reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        shapes = ckpt_reader.get_variable_to_shape_map()
    except Exception as err:  # pylint: disable=broad-except
        text = str(err)
        print(text)
        if "corrupted compressed block contents" in text:
            print(
                "It's likely that your checkpoint file has been compressed "
                "with SNAPPY.")
        return None
    return shapes
Example #13
Source File: SSD300.py From Object-Detection-API-Tensorflow with MIT License | 5 votes |
def __init__(self, config, data_provider):
    """Store the configuration, then build the SSD300 graph and session.

    Args:
        config: dict of settings. Keys read here: 'mode' ('train' or 'test'),
            'data_format' ('channels_first' or 'channels_last'),
            'num_classes', 'weight_decay', 'keep_prob',
            'nms_score_threshold', 'nms_max_boxes', 'nms_iou_threshold',
            'pretraining_weight' (checkpoint opened with NewCheckpointReader)
            and, in train mode only, 'batch_size'.
        data_provider: dict. In train mode the keys 'num_train', 'num_val',
            'train_generator' (an (initializer, iterator) pair) and
            'val_generator' (such a pair, or None) are read.
    """
    assert config['mode'] in ['train', 'test']
    assert config['data_format'] in ['channels_first', 'channels_last']
    self.config = config
    self.data_provider = data_provider

    side = 300
    self.input_size = side
    self.data_shape = ([side, side, 3]
                       if config['data_format'] == 'channels_last'
                       else [3, side, side])
    # One extra class id on top of config['num_classes'] (presumably background).
    self.num_classes = config['num_classes'] + 1
    self.weight_decay = config['weight_decay']
    # Stored as 1 - keep_prob.
    self.prob = 1. - config['keep_prob']
    self.data_format = config['data_format']
    self.mode = config['mode']
    # Outside training the batch size is fixed to 1.
    self.batch_size = config['batch_size'] if self.mode == 'train' else 1
    self.nms_score_threshold = config['nms_score_threshold']
    self.nms_max_boxes = config['nms_max_boxes']
    self.nms_iou_threshold = config['nms_iou_threshold']
    self.reader = wrap.NewCheckpointReader(config['pretraining_weight'])

    if self.mode == 'train':
        self.num_train = data_provider['num_train']
        self.num_val = data_provider['num_val']
        self.train_generator = data_provider['train_generator']
        self.train_initializer, self.train_iterator = self.train_generator
        val_generator = data_provider['val_generator']
        if val_generator is not None:
            self.val_generator = val_generator
            self.val_initializer, self.val_iterator = val_generator

    self.global_step = tf.get_variable(name='global_step',
                                       initializer=tf.constant(0),
                                       trainable=False)
    self.is_training = True

    # Build inputs, network and saver; summaries only when training.
    self._define_inputs()
    self._build_graph()
    self._create_saver()
    if self.mode == 'train':
        self._create_summary()
    self._init_session()
Example #14
Source File: saver.py From PoseFix_RELEASE with MIT License | 5 votes |
def get_variables_in_checkpoint_file(file_name):
    """Open checkpoint `file_name` and return ``(reader, {name: shape})``.

    On failure the error is printed (with a SNAPPY hint when relevant) and
    None is returned instead of the tuple.
    """
    try:
        ckpt_reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        shapes = ckpt_reader.get_variable_to_shape_map()
    except Exception as err:  # pylint: disable=broad-except
        text = str(err)
        print(text)
        if "corrupted compressed block contents" in text:
            print(
                "It's likely that your checkpoint file has been compressed "
                "with SNAPPY.")
        return None
    return ckpt_reader, shapes
Example #15
Source File: checkpint_inspect.py From SSD.TensorFlow with Apache License 2.0 | 5 votes |
def print_all_tensors_name(file_name):
    """Print the name of every variable stored in checkpoint `file_name`.

    Errors are printed (with a SNAPPY hint when relevant) instead of raised.
    """
    try:
        names = pywrap_tensorflow.NewCheckpointReader(
            file_name).get_variable_to_shape_map()
        for name in names:
            print(name)
    except Exception as err:  # pylint: disable=broad-except
        text = str(err)
        print(text)
        if "corrupted compressed block contents" in text:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
Example #16
Source File: npz_file_to_checkpoint.py From will-people-like-your-image with GNU Lesser General Public License v3.0 | 5 votes |
def check_adam(model):
    """Check whether a checkpoint still contains an Adam optimizer variable.

    Only names containing 'Adam_1' are checked. The first match is printed.

    Args:
        model: checkpoint path (prefix) to inspect.

    Returns:
        True if such a variable was found, False otherwise.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(model)
    for key in reader.get_variable_to_shape_map():
        if 'Adam_1' in key:
            # Leading space fixes the run-together "<key>contains Adam." output.
            print(key + " contains Adam.")
            return True
    return False
Example #17
Source File: inspect_cp.py From NJUNMT-tf with Apache License 2.0 | 5 votes |
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors,
                                     all_tensor_names):
    """Print tensors stored in a checkpoint file.

    With `all_tensors` or `all_tensor_names`, every stored name is printed in
    sorted order (followed by its value when `all_tensors` is set). Otherwise,
    the named tensor is printed if `tensor_name` is given, and the reader's
    debug string (names and shapes) is printed if it is not.

    Args:
        file_name: name of the checkpoint file.
        tensor_name: name of the tensor in the checkpoint file to print.
        all_tensors: whether to print all tensors (names and values).
        all_tensor_names: whether to print all tensor names.

    Errors are printed, with hints for SNAPPY compression and for V2
    checkpoints passed with a file extension instead of the prefix.
    """
    try:
        ckpt_reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        if all_tensors or all_tensor_names:
            for name in sorted(ckpt_reader.get_variable_to_shape_map()):
                print("tensor_name: ", name)
                if all_tensors:
                    print(ckpt_reader.get_tensor(name))
        elif tensor_name:
            print("tensor_name: ", tensor_name)
            print(ckpt_reader.get_tensor(tensor_name))
        else:
            print(ckpt_reader.debug_string().decode("utf-8"))
    except Exception as err:  # pylint: disable=broad-except
        message = str(err)
        print(message)
        if "corrupted compressed block contents" in message:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
        if "Data loss" in message and any(
                suffix in file_name for suffix in [".index", ".meta", ".data"]):
            # Suggest the prefix: drop the last '.'-separated component.
            proposed_file = ".".join(file_name.split(".")[0:-1])
            v2_file_error_template = """
It's likely that this is a V2 checkpoint and you need to provide the filename
*prefix*.  Try removing the '.' and extension.  Try:
inspect checkpoint --file_name = {}"""
            print(v2_file_error_template.format(proposed_file))
Example #18
Source File: inspect_checkpoint.py From MobileNet with Apache License 2.0 | 5 votes |
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
    """Print tensors stored in a checkpoint file.

    If `all_tensors` is set, every stored name is printed with its value.
    Otherwise, the named tensor is printed if `tensor_name` is given, and the
    reader's debug string (names and shapes) is printed if it is not.

    Args:
        file_name: name of the checkpoint file.
        tensor_name: name of the tensor in the checkpoint file to print.
        all_tensors: whether to print all tensors.

    Errors are reported on stdout instead of being raised.
    """
    try:
        ckpt_reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        if not all_tensors:
            if tensor_name:
                print("tensor_name: ", tensor_name)
                print(ckpt_reader.get_tensor(tensor_name))
            else:
                print(ckpt_reader.debug_string().decode("utf-8"))
            return
        for name in ckpt_reader.get_variable_to_shape_map():
            print("tensor_name: ", name)
            print(ckpt_reader.get_tensor(name))
    except Exception as err:  # pylint: disable=broad-except
        message = str(err)
        print(message)
        if "corrupted compressed block contents" in message:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
Example #19
Source File: base_network.py From VAE-GAN with MIT License | 5 votes |
def load_pretrained_model_weights(self, sess, cfg, network_name, only_bottom=True):
    """Copy weights of another trained network into this network's variables.

    The latest checkpoint is read from
    ``<assets dir>/<trainer params["checkpoint dir"]>/checkpoint``, where the
    checkpoint dir defaults to "checkpoint". Every stored variable whose name
    starts with `network_name` is re-rooted under ``self.name`` and assigned
    in `sess`.

    Args:
        sess: session used to run the assignments.
        cfg: config identifier passed to ``get_config``.
        network_name: scope prefix of the checkpoint variables to copy.
        only_bottom: if True, skip variables whose name contains 'fc'
            (presumably the fully-connected head).

    Returns:
        True once the assignments have run.
    """
    config_file = get_config(cfg)
    asset_filepath = config_file['assets dir']
    ckpt_path = os.path.join(
        asset_filepath,
        config_file["trainer params"].get("checkpoint dir", "checkpoint"))

    ckpt_name = ''
    with open(os.path.join(ckpt_path, 'checkpoint'), 'r', encoding='utf-8') as infile:
        for line in infile:
            if line.startswith('model_checkpoint_path'):
                # Line format: model_checkpoint_path: "<name>". Parse around
                # the ':' and strip whitespace/quotes. Fixed-offset slicing
                # dropped a character when the line had no trailing newline.
                ckpt_name = line.split(':', 1)[1].strip().strip('"')

    checkpoint_path = os.path.join(ckpt_path, ckpt_name)
    print("Load checkpoint : ", checkpoint_path)
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()

    # Graph variables keyed by name without the ":<index>" output suffix.
    var_dict = {var.name.split(':')[0]: var for var in self.all_vars}
    assign_list = []
    for key in var_to_shape_map:
        if not key.startswith(network_name):
            continue
        if only_bottom and 'fc' in key:
            continue
        # Swap the checkpoint's scope prefix for this network's name.
        var_name = self.name + '/' + key[len(network_name) + 1:]
        assign_list.append(tf.assign(var_dict[var_name], reader.get_tensor(key)))

    sess.run(tf.group(assign_list))
    return True
Example #20
Source File: tf_policy_network.py From crossgap_il_rl with GNU General Public License v2.0 | 5 votes |
def prase_checkpoint_data(self, checkpoint_name):
    """Print a header, then the name of every variable in `checkpoint_name`."""
    print("===== Prase data from %s =====" % checkpoint_name)
    names = pywrap_tensorflow.NewCheckpointReader(
        checkpoint_name).get_variable_to_shape_map()
    for name in names:
        print("tensor_name: ", name)
Example #21
Source File: tf_policy_network.py From crossgap_il_rl with GNU General Public License v2.0 | 5 votes |
def resort_para_form_checkpoint(prefix, graph, sess, ckpt_name_vec=None):
    """Restore variables from the planning/PID checkpoints into `graph`.

    The operations of `graph` are first dumped to "full_structure.txt" as a
    debugging aid. Then, for each checkpoint, every stored variable ``<key>``
    is assigned to the graph tensor ``<prefix><key>:0``. Variables that
    cannot be restored are skipped silently.

    Args:
        prefix: string prepended verbatim to each checkpoint key to form the
            target tensor name.
        graph: graph holding the target variables.
        sess: session used to run the assignment ops.
        ckpt_name_vec: checkpoint paths to read. Defaults to the planning-net
            and PID-net checkpoints previously hard-coded here.
    """
    if ckpt_name_vec is None:
        ckpt_name_vec = ["./tf_net/planning_net/tf_saver_107840000.ckpt",
                         "./tf_net/pid_net/tf_saver_109330000.ckpt"]

    print("=========")
    # `with` closes the file even if writing raises.
    with open("full_structure.txt", "w") as structure_file:
        structure_file.write(str(graph.get_operations()))
    print("=========")

    with tf.name_scope("restore"):
        for ckpt_name in ckpt_name_vec:
            print("===== Restore data from %s =====" % ckpt_name)
            reader = pywrap_tensorflow.NewCheckpointReader(ckpt_name)
            for _key in reader.get_variable_to_shape_map():
                key = prefix + _key + ":0"
                try:
                    tensor = graph.get_tensor_by_name(key)
                    sess.run(tf.assign(tensor, reader.get_tensor(_key)))
                except Exception:  # pylint: disable=broad-except
                    # Best-effort restore: silently skip unrestorable variables.
                    continue
Example #22
Source File: tf_rapid_trajectory.py From crossgap_il_rl with GNU General Public License v2.0 | 5 votes |
def resort_para_form_checkpoint(prefix, graph, sess, ckpt_name):
    """Restore the variables of checkpoint `ckpt_name` into `graph`.

    The operations of `graph` are first dumped to "full_structure.txt" as a
    debugging aid. Every stored variable ``<key>`` is then assigned to the
    graph tensor ``<prefix><key>:0``. Variables that cannot be restored are
    reported on stdout and skipped.

    Args:
        prefix: string prepended verbatim to each checkpoint key to form the
            target tensor name.
        graph: graph holding the target variables.
        sess: session used to run the assignment ops.
        ckpt_name: checkpoint path (prefix) to read.
    """
    from tensorflow.python import pywrap_tensorflow

    print("=========")
    # `with` closes the file even if writing raises.
    with open("full_structure.txt", "w") as structure_file:
        structure_file.write(str(graph.get_operations()))
    print("=========")

    with tf.name_scope("restore"):
        print("===== Restore data from %s =====" % ckpt_name)
        reader = pywrap_tensorflow.NewCheckpointReader(ckpt_name)
        for _key in reader.get_variable_to_shape_map():
            key = prefix + _key + ":0"
            try:
                tensor = graph.get_tensor_by_name(key)
                sess.run(tf.assign(tensor, reader.get_tensor(_key)))
            except Exception as e:  # pylint: disable=broad-except
                # Best-effort restore: report the failure and move on.
                print(key, " can not be restored, e= ", str(e))
Example #23
Source File: checkpint_inspect.py From inference with Apache License 2.0 | 5 votes |
def print_all_tensors_name(file_name):
    """Print the name of every variable stored in checkpoint `file_name`.

    Errors are printed (with a SNAPPY hint when relevant) instead of raised.
    """
    try:
        names = pywrap_tensorflow.NewCheckpointReader(
            file_name).get_variable_to_shape_map()
        for name in names:
            print(name)
    except Exception as err:  # pylint: disable=broad-except
        text = str(err)
        print(text)
        if "corrupted compressed block contents" in text:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
Example #24
Source File: convert_from_depre.py From tf-faster-rcnn with MIT License | 5 votes |
def get_variables_in_checkpoint_file(file_name):
    """Return the {variable name: shape} map of checkpoint `file_name`.

    On failure the error is printed (with a SNAPPY hint when relevant) and
    None is returned.
    """
    try:
        return pywrap_tensorflow.NewCheckpointReader(
            file_name).get_variable_to_shape_map()
    except Exception as err:  # pylint: disable=broad-except
        text = str(err)
        print(text)
        if "corrupted compressed block contents" in text:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
    return None
Example #25
Source File: train_val.py From tf-faster-rcnn with MIT License | 5 votes |
def get_variables_in_checkpoint_file(self, file_name):
    """Return the {variable name: shape} map of checkpoint `file_name`.

    On failure the error is printed (with a SNAPPY hint when relevant) and
    None is returned.
    """
    try:
        ckpt_reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        shapes = ckpt_reader.get_variable_to_shape_map()
    except Exception as err:  # pylint: disable=broad-except
        text = str(err)
        print(text)
        if "corrupted compressed block contents" in text:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
        return None
    return shapes
Example #26
Source File: average_model.py From sequencing with MIT License | 5 votes |
def average_ckpt(checkpoint_from_paths, checkpoint_to_path):
    """Average the variables of several checkpoints into a new checkpoint.

    Every variable found in any source checkpoint is averaged element-wise
    over the checkpoints that contain it. The mean is cast back to the
    variable's stored dtype, so an integer variable such as ``global_step``
    stays an integer and can still be restored into an integer variable.

    Args:
        checkpoint_from_paths: iterable of source checkpoint paths to read.
        checkpoint_to_path: path of the checkpoint to write.

    Raises:
        ValueError: if `checkpoint_from_paths` is empty.
    """
    checkpoint_from_paths = list(checkpoint_from_paths)
    if not checkpoint_from_paths:
        raise ValueError('average_ckpt needs at least one source checkpoint.')

    with ops.Graph().as_default():
        # var_name -> list of values, one per checkpoint containing it.
        new_variable_map = defaultdict(list)
        for checkpoint_from_path in checkpoint_from_paths:
            logging.info('Reading checkpoint_from_path %s' % checkpoint_from_path)
            reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_from_path)
            for var_name in reader.get_variable_to_shape_map():
                new_variable_map[var_name].append(reader.get_tensor(var_name))

        variable_map = {}
        # Iterate over the union of all checkpoints' variables, not only
        # those of the last checkpoint read.
        for var_name, tensors in new_variable_map.items():
            mean = reduce(lambda x, y: x + y, tensors) / len(tensors)
            # True division promotes integer tensors to float64; restore dtype.
            mean = mean.astype(tensors[0].dtype)
            variable_map[var_name] = variables.Variable(mean, name=var_name)
        print(variable_map)

        saver = saver_lib.Saver(variable_map)
        with session.Session() as sess:
            sess.run(variables.global_variables_initializer())
            logging.info('Writing checkpoint_to_path %s' % checkpoint_to_path)
            saver.save(sess, checkpoint_to_path)

    logging.info('Summary:')
    logging.info(' Converted %d variable name(s).' % len(new_variable_map))
Example #27
Source File: saver.py From tf-cpn with MIT License | 5 votes |
def get_variables_in_checkpoint_file(file_name):
    """Return the {variable name: shape} map of checkpoint `file_name`.

    On failure the error is printed (with a SNAPPY hint when relevant) and
    None is returned.
    """
    try:
        ckpt_reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        shapes = ckpt_reader.get_variable_to_shape_map()
    except Exception as err:  # pylint: disable=broad-except
        text = str(err)
        print(text)
        if "corrupted compressed block contents" in text:
            print(
                "It's likely that your checkpoint file has been compressed "
                "with SNAPPY.")
        return None
    return shapes
Example #28
Source File: plugin.py From deep_image_model with Apache License 2.0 | 5 votes |
def _get_reader_for_run(self, run):
    """Return the cached checkpoint reader for `run`, creating it on first use.

    The reader is None when the run's config has no `model_checkpoint_path`
    or the checkpoint cannot be opened (a warning is logged). The result,
    including None, is cached in ``self.readers``.
    """
    if run in self.readers:
        return self.readers[run]

    checkpoint_path = self._configs[run].model_checkpoint_path
    reader = None
    if checkpoint_path:
        try:
            reader = NewCheckpointReader(checkpoint_path)
        except Exception:  # pylint: disable=broad-except
            logging.warning('Failed reading %s', checkpoint_path)
    self.readers[run] = reader
    return reader
Example #29
Source File: inspect_checkpoint.py From tf.fashionAI with Apache License 2.0 | 5 votes |
def print_all_tensors_name(file_name):
    """Print the name of every variable stored in checkpoint `file_name`.

    Errors are printed (with a SNAPPY hint when relevant) instead of raised.
    """
    try:
        names = pywrap_tensorflow.NewCheckpointReader(
            file_name).get_variable_to_shape_map()
        for name in names:
            print(name)
    except Exception as err:  # pylint: disable=broad-except
        text = str(err)
        print(text)
        if "corrupted compressed block contents" in text:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
Example #30
Source File: RefineDet.py From RefineDet-tensorflow with MIT License | 5 votes |
def __init__(self, config, data_provider):
    """Store the configuration, then build the RefineDet graph and session.

    Args:
        config: dict of settings. Keys read here: 'mode' ('train' or 'test'),
            'data_format' ('channels_first' or 'channels_last'),
            'input_size', 'num_classes', 'weight_decay', 'keep_prob',
            'nms_score_threshold', 'nms_max_boxes', 'nms_iou_threshold',
            'pretraining_weight' (checkpoint opened with NewCheckpointReader)
            and, in train mode only, 'batch_size'.
        data_provider: dict. In train mode the keys 'num_train', 'num_val',
            'train_generator' (an (initializer, iterator) pair) and
            'val_generator' (such a pair, or None) are read.
    """
    assert config['mode'] in ['train', 'test']
    assert config['data_format'] in ['channels_first', 'channels_last']
    self.config = config
    self.data_provider = data_provider

    side = config['input_size']
    self.input_size = side
    self.data_shape = ([side, side, 3]
                       if config['data_format'] == 'channels_last'
                       else [3, side, side])
    # One extra class id on top of config['num_classes'] (presumably background).
    self.num_classes = config['num_classes'] + 1
    self.weight_decay = config['weight_decay']
    # Stored as 1 - keep_prob.
    self.prob = 1. - config['keep_prob']
    self.data_format = config['data_format']
    self.mode = config['mode']
    # Outside training the batch size is fixed to 1.
    self.batch_size = config['batch_size'] if self.mode == 'train' else 1
    self.anchor_ratios = [0.5, 1.0, 2.0]
    self.num_anchors = len(self.anchor_ratios)
    self.nms_score_threshold = config['nms_score_threshold']
    self.nms_max_boxes = config['nms_max_boxes']
    self.nms_iou_threshold = config['nms_iou_threshold']
    self.reader = wrap.NewCheckpointReader(config['pretraining_weight'])

    if self.mode == 'train':
        self.num_train = data_provider['num_train']
        self.num_val = data_provider['num_val']
        self.train_generator = data_provider['train_generator']
        self.train_initializer, self.train_iterator = self.train_generator
        val_generator = data_provider['val_generator']
        if val_generator is not None:
            self.val_generator = val_generator
            self.val_initializer, self.val_iterator = val_generator

    self.global_step = tf.get_variable(name='global_step',
                                       initializer=tf.constant(0),
                                       trainable=False)
    self.is_training = True

    # Build inputs, network and saver; summaries only when training.
    self._define_inputs()
    self._build_graph()
    self._create_saver()
    if self.mode == 'train':
        self._create_summary()
    self._init_session()