Python tensorflow.global_variables() Examples
The following are 30 code examples of tensorflow.global_variables().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
Example #1
Source File: main.py From mac-network with Apache License 2.0 | 6 votes |
def setSavers(model):
    """Build the checkpoint savers used during training.

    Returns a dict with three entries:
      "saver":       a Saver over the default variable set.
      "subsetSaver": a Saver over the global variables whose name contains any
                     of the substrings in config.varSubset (empty allowed),
                     or None when config.saveSubset is off.
      "emaSaver":    a Saver over model.emaDict, or None when config.useEMA
                     is off.
    Every saver keeps at most config.weightsToKeep checkpoints.
    """
    keep = config.weightsToKeep
    saver = tf.train.Saver(max_to_keep = keep)

    subsetSaver = None
    if config.saveSubset:
        def isRelevant(var):
            # plain substring match against each configured name fragment
            return any(s in var.name for s in config.varSubset)
        relevantVars = list(filter(isRelevant, tf.global_variables()))
        subsetSaver = tf.train.Saver(relevantVars, max_to_keep = keep, allow_empty = True)

    emaSaver = None
    if config.useEMA:
        emaSaver = tf.train.Saver(model.emaDict, max_to_keep = keep)

    return {"saver": saver, "subsetSaver": subsetSaver, "emaSaver": emaSaver}

################################### restore / initialize weights ##################################
# Restores weights of specified / last epoch if on restore mode.
# Otherwise, initializes weights.
Example #2
Source File: utils_tf.py From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License | 6 votes |
def initialize_uninitialized_global_variables(sess):
    """
    Initialize only those global variables that `sess` reports as not yet
    initialized; variables that already hold a value are left untouched.

    :param sess: the TensorFlow session used both to query initialization
                 status and to run the initializer
    :return: None
    """
    global_vars = tf.global_variables()
    # One is_variable_initialized op per variable, all evaluated in one run call.
    init_flags = sess.run([tf.is_variable_initialized(v) for v in global_vars])
    pending = [v for v, done in zip(global_vars, init_flags) if not done]
    if pending:
        sess.run(tf.variables_initializer(pending))
Example #3
Source File: help.py From Traffic_sign_detection_YOLO with MIT License | 6 votes |
def to_darknet(self):
    """Copy the current values of the graph's variables back into self.darknet.

    Each global variable name is parsed as "<layer index>-<rest>": the integer
    before the first '-' selects a layer in self.darknet.layers, and the part
    of <rest> after its last '/' is the key written in that layer's `w` dict.
    Every entry in each layer's `h` dict is then reset to None.

    Returns:
        self.darknet, updated in place.
    """
    darknet_ckpt = self.darknet
    with self.graph.as_default():
        for var in tf.global_variables():
            op_name = var.name.split(':')[0]
            pieces = op_name.split('-')
            layer_idx = int(pieces[0])
            weight_key = pieces[1].split('/')[-1]
            layer = darknet_ckpt.layers[layer_idx]
            layer.w[weight_key] = var.eval(self.sess)
    for layer in darknet_ckpt.layers:
        for ph in layer.h:
            layer.h[ph] = None
    return darknet_ckpt
Example #4
Source File: help.py From Traffic-Signs-and-Object-Detection with GNU General Public License v3.0 | 6 votes |
def to_darknet(self):
    """Copy the current values of the graph's variables back into self.darknet.

    Each global variable name is parsed as "<layer index>-<rest>": the integer
    before the first '-' selects a layer in self.darknet.layers, and the part
    of <rest> after its last '/' is the key written in that layer's `w` dict.
    Every entry in each layer's `h` dict is then reset to None.

    Returns:
        self.darknet, updated in place.
    """
    darknet_ckpt = self.darknet
    with self.graph.as_default():
        for var in tf.global_variables():
            op_name = var.name.split(':')[0]
            pieces = op_name.split('-')
            layer_idx = int(pieces[0])
            weight_key = pieces[1].split('/')[-1]
            layer = darknet_ckpt.layers[layer_idx]
            layer.w[weight_key] = var.eval(self.sess)
    for layer in darknet_ckpt.layers:
        for ph in layer.h:
            layer.h[ph] = None
    return darknet_ckpt
Example #5
Source File: tfutil.py From disentangling_conditional_gans with MIT License | 6 votes |
def init_uninited_vars(vars=None):
    """Run the initializer of every variable in `vars` that is not yet initialized.

    Args:
        vars: iterable of tf.Variables to consider; defaults to
              tf.global_variables(). (The name shadows the builtin but is kept
              because callers may pass it by keyword.)

    A variable is tested only when the default graph has no
    "<var>/IsVariableInitialized" tensor yet; otherwise it is skipped. The test
    op is created under the variable's own absolute name scope, which is the
    name the lookup above searches for on later calls.
    """
    if vars is None:
        vars = tf.global_variables()
    candidates, checks = [], []
    with tf.control_dependencies(None): # ignore surrounding control_dependencies
        graph = tf.get_default_graph()
        for var in vars:
            assert is_tf_expression(var)
            probe_name = var.name.replace(':0', '/IsVariableInitialized:0')
            try:
                graph.get_tensor_by_name(probe_name)
            except KeyError:
                # Op does not exist => variable may be uninitialized.
                candidates.append(var)
                with absolute_name_scope(var.name.split(':')[0]):
                    checks.append(tf.is_variable_initialized(var))
    flags = run(checks)
    run([var.initializer for var, inited in zip(candidates, flags) if not inited])

#----------------------------------------------------------------------------
# Set the values of given tf.Variables.
# Equivalent to the following, but more efficient and does not bloat the tf graph:
# tfutil.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])
Example #6
Source File: faster_rcnn_meta_arch.py From vehicle_counting_tensorflow with MIT License | 6 votes |
def restore_from_classification_checkpoint_fn(
    self,
    first_stage_feature_extractor_scope,
    second_stage_feature_extractor_scope):
  """Returns a map of variables to load from a foreign checkpoint.

  Args:
    first_stage_feature_extractor_scope: A scope name for the first stage
      feature extractor.
    second_stage_feature_extractor_scope: A scope name for the second stage
      feature extractor.

  Returns:
    A dict mapping variable names (to load from a checkpoint) to variables in
    the model graph. Each key is the variable's op name with its leading
    "<scope>/" removed.
  """
  variables_to_restore = {}
  for variable in tf.global_variables():
    for scope_name in [first_stage_feature_extractor_scope,
                       second_stage_feature_extractor_scope]:
      prefix = scope_name + '/'
      # Match on "<scope>/" so a sibling scope that merely shares the prefix
      # (e.g. "<scope>_1/...") is not picked up, and strip only the leading
      # scope; str.replace would also delete later occurrences in the name.
      if variable.op.name.startswith(prefix):
        var_name = variable.op.name[len(prefix):]
        variables_to_restore[var_name] = variable
  return variables_to_restore
Example #7
Source File: ssd_meta_arch.py From vehicle_counting_tensorflow with MIT License | 6 votes |
def restore_from_classification_checkpoint_fn(self, feature_extractor_scope):
  """Returns a map of variables to load from a foreign checkpoint.

  Args:
    feature_extractor_scope: A scope name for the feature extractor.

  Returns:
    A dict mapping variable names (to load from a checkpoint) to variables in
    the model graph. Each key is the variable's op name with the leading
    "<feature_extractor_scope>/" removed.
  """
  prefix = feature_extractor_scope + '/'
  variables_to_restore = {}
  for variable in tf.global_variables():
    var_name = variable.op.name
    if var_name.startswith(prefix):
      # Strip only the leading scope; str.replace would also delete any later
      # occurrence of "<scope>/" inside the name.
      var_name = var_name[len(prefix):]
      variables_to_restore[var_name] = variable
  return variables_to_restore
Example #8
Source File: utility.py From soccer-matlab with BSD 2-Clause "Simplified" License | 6 votes |
def define_saver(exclude=None):
  """Create a saver for the variables we want to checkpoint.

  Args:
    exclude: List of regexes to match variable names to exclude. Patterns are
      applied with re.match, i.e. anchored at the start of the name.

  Returns:
    Saver object that keeps a checkpoint every 5 hours.
  """
  patterns = [re.compile(regex) for regex in (exclude or [])]

  def _is_excluded(variable):
    return any(pattern.match(variable.name) for pattern in patterns)

  variables = [v for v in tf.global_variables() if not _is_excluded(v)]
  return tf.train.Saver(variables, keep_checkpoint_every_n_hours=5)
Example #9
Source File: ssd_pnasnet_feature_extractor.py From vehicle_counting_tensorflow with MIT License | 6 votes |
def restore_from_classification_checkpoint_fn(self, feature_extractor_scope):
  """Returns a map of variables to load from a foreign checkpoint.

  Note that this overrides the default implementation in
  ssd_meta_arch.SSDFeatureExtractor which does not work for PNASNet
  checkpoints.

  Args:
    feature_extractor_scope: A scope name for the first stage feature
      extractor.

  Returns:
    A dict mapping variable names (to load from a checkpoint) to variables in
    the model graph. Each key is the variable's op name with the leading
    "<feature_extractor_scope>/" removed and "/ExponentialMovingAverage"
    appended (presumably selecting the moving-average copies stored in the
    checkpoint; confirm against the checkpoint contents).
  """
  prefix = feature_extractor_scope + '/'
  variables_to_restore = {}
  for variable in tf.global_variables():
    op_name = variable.op.name
    # Match on "<scope>/" so sibling scopes sharing the prefix are excluded,
    # and strip only the leading scope rather than every occurrence.
    if op_name.startswith(prefix):
      var_name = op_name[len(prefix):] + '/ExponentialMovingAverage'
      variables_to_restore[var_name] = variable
  return variables_to_restore
Example #10
Source File: utility.py From soccer-matlab with BSD 2-Clause "Simplified" License | 6 votes |
def define_saver(exclude=None):
  """Create a saver for the variables we want to checkpoint.

  Args:
    exclude: List of regexes to match variable names to exclude. Patterns are
      applied with re.match, i.e. anchored at the start of the name.

  Returns:
    Saver object that keeps a checkpoint every 5 hours.
  """
  patterns = [re.compile(regex) for regex in (exclude or [])]

  def _is_excluded(variable):
    return any(pattern.match(variable.name) for pattern in patterns)

  variables = [v for v in tf.global_variables() if not _is_excluded(v)]
  return tf.train.Saver(variables, keep_checkpoint_every_n_hours=5)
Example #11
Source File: sequenceNet.py From deep-summarization with MIT License | 6 votes |
def _start_session(self):
    """
    Starts the Tensorflow Session: initializes all global variables, creates
    self.saver over them and, when self.checkpointer reports a previous
    checkpoint, restores the variables from it.

    :return: None
    """
    self.sess.run(tf.global_variables_initializer())
    # initialize the saver node
    self.saver = tf.train.Saver(tf.global_variables())
    # get the latest checkpoint
    last_checkpoint_path = self.checkpointer.get_last_checkpoint()
    if last_checkpoint_path is not None:
        # Call form of print: same output on Python 2 (single argument) and,
        # unlike the statement form, valid syntax on Python 3.
        print('Previous saved tensorflow objects found... Extracting...')
        # restore the tensorflow variables
        self.saver.restore(self.sess, last_checkpoint_path)
        print('Extraction Complete. Moving Forward....')
Example #12
Source File: base_model.py From hierarchical_loc with BSD 3-Clause "New" or "Revised" License | 6 votes |
def _checkpoint_var_search(self, checkpoint_path):
    """Match the model's variables against the variables stored in a checkpoint.

    Candidates are tf.model_variables() (populated by tf.slim layers), falling
    back to tf.global_variables() when that collection is empty. Names are
    compared without the ":<index>" suffix of the variable name.

    Returns:
        A tuple (restored, found, missing, shape_conflicts):
          restored:        list of variables present in the checkpoint with
                           an identical shape.
          found:           sorted names present in both, minus shape conflicts.
          missing:         sorted model names absent from the checkpoint.
          shape_conflicts: sorted names present in both but with different shapes.
    """
    reader = tf.train.NewCheckpointReader(checkpoint_path)
    saved_shapes = reader.get_variable_to_shape_map()

    model_vars = tf.model_variables()  # Used by tf.slim layers
    if not model_vars:
        model_vars = tf.global_variables()  # Fallback when slim is not used
    # Resolve names to variables directly. The previous
    # tf.get_variable(name) lookup under reuse=True only finds variables
    # created through tf.get_variable, so it failed for tf.Variable-created
    # ones that the global_variables() fallback can return.
    vars_by_name = {v.name.split(':')[0]: v for v in model_vars}

    model_names = set(vars_by_name)
    checkpoint_names = set(saved_shapes)
    found_names = model_names & checkpoint_names
    missing_names = model_names - checkpoint_names

    shape_conflicts = set()
    restored = []
    for name in sorted(found_names):
        var = vars_by_name[name]
        if var.get_shape().as_list() == saved_shapes[name]:
            restored.append(var)
        else:
            shape_conflicts.add(name)
    found_names -= shape_conflicts
    return (restored, sorted(found_names), sorted(missing_names),
            sorted(shape_conflicts))
Example #13
Source File: pos_tagger.py From deepnlp with MIT License | 6 votes |
def _init_pos_model(self, session):
    """Create POS Tagger model and initialize with random or load parameters in session.

    The model is built (once) inside the variable scope returned by
    get_model_var_scope with reuse=tf.AUTO_REUSE. If files matching
    "<ckpt_path>.data*" exist, only the variables whose name has that scope as
    a complete path component are restored. Otherwise all global variables
    are initialized.
    """
    # initialize config
    config_dict = load_config(self.model_config_path)
    config = get_config(config_dict, self.name)
    config.batch_size = 1
    config.num_steps = 1 # iterator one token per time
    model_var_scope = get_model_var_scope(self.var_scope, self.name)
    print ("NOTICE: Input POS Model Var Scope Name '%s'" % model_var_scope)
    # Check if self.model already exist
    if self.model is None:
        # reuse must be passed by keyword: the second positional parameter of
        # tf.variable_scope is default_name, so passing tf.AUTO_REUSE
        # positionally never turned variable reuse on.
        with tf.variable_scope(model_var_scope, reuse=tf.AUTO_REUSE):
            self.model = pos_model.POSTagger(is_training=False, config=config) # save object after is_training
    # Load Specific .data* ckpt file
    if glob.glob(self.ckpt_path + '.data*'): # file exist with pattern: 'pos.ckpt.data*'
        print("NOTICE: Loading model parameters from %s" % self.ckpt_path)
        all_vars = tf.global_variables()
        model_vars = [k for k in all_vars if model_var_scope in k.name.split("/")]
        tf.train.Saver(model_vars).restore(session, self.ckpt_path)
    else:
        print("NOTICE: Model not found, Try to run method: deepnlp.download(module='pos', name='%s')" % self.name)
        print("NOTICE: Created with fresh parameters.")
        # NOTE(review): this initializes every global variable in the graph,
        # not just this model's; confirm no other model shares the session.
        session.run(tf.global_variables_initializer())
Example #14
Source File: run_summarization.py From TransferRL with MIT License | 6 votes |
def convert_to_reinforce_model(self):
    """Load non-reinforce checkpoint, add initialized extra variables for reinforce, and save as new checkpoint.

    All variables are first initialized from scratch. The variables whose names
    contain neither "reinforce" nor "Adagrad" are then restored via
    util.load_ckpt. The full variable set is saved as "<checkpoint>_rl_init".
    Finally the process exits.
    """
    import sys  # sys.exit: the bare exit() builtin is added by `site` for interactive use only

    tf.logging.info("converting non-reinforce model to reinforce model..")

    # initialize an entire reinforce model from scratch
    sess = tf.Session(config=util.get_config())
    try:
        print("initializing everything...")
        sess.run(tf.global_variables_initializer())

        # load all non-reinforce weights from checkpoint
        saver = tf.train.Saver([v for v in tf.global_variables() if "reinforce" not in v.name and "Adagrad" not in v.name])
        print("restoring non-reinforce variables...")
        curr_ckpt = util.load_ckpt(saver, sess)
        print("restored.")

        # save this model and quit
        new_fname = curr_ckpt + '_rl_init'
        print("saving model to %s..." % (new_fname))
        new_saver = tf.train.Saver() # this one will save all variables that now exist
        new_saver.save(sess, new_fname)
        print("saved.")
    finally:
        sess.close()
    sys.exit()
Example #15
Source File: transformer_test.py From fine-lm with MIT License | 6 votes |
def testVarNames(self):
  """TransformerScorer (PREDICT/infer and EVAL) must create the same variable names as Transformer (EVAL)."""

  def _var_names(mode, model_cls, use_infer):
    # Build the model in a fresh graph and return its global variable names.
    with tf.Graph().as_default():
      model, features = get_model(mode=mode, model_cls=model_cls)
      if use_infer:
        _ = model.infer(features)
      else:
        _ = model(features)
      return [v.name for v in tf.global_variables()]

  scorer_vars = _var_names(
      tf.estimator.ModeKeys.PREDICT, transformer.TransformerScorer, True)
  scorer_eval_vars = _var_names(
      tf.estimator.ModeKeys.EVAL, transformer.TransformerScorer, False)
  transformer_vars = _var_names(
      tf.estimator.ModeKeys.EVAL, transformer.Transformer, False)

  self.assertEqual(sorted(scorer_vars), sorted(transformer_vars))
  self.assertEqual(sorted(scorer_eval_vars), sorted(transformer_vars))
Example #16
Source File: common_layers.py From fine-lm with MIT License | 6 votes |
def underlying_variable(t):
  """Find the underlying tf.Variable object.

  Args:
    t: a Tensor

  Returns:
    tf.Variable.
  """
  t = underlying_variable_ref(t)
  assert t is not None
  graph = tf.get_default_graph()
  # Lazily attach a name -> variable index to the graph and extend it with the
  # global variables created since the last call. This relies on new
  # variables only ever being appended to the global collection.
  if not hasattr(graph, "var_index"):
    graph.var_index = {}
  index = graph.var_index
  already_indexed = len(index)
  index.update((v.name, v) for v in tf.global_variables()[already_indexed:])
  return index[t.name]
Example #17
Source File: run_summarization.py From TransferRL with MIT License | 6 votes |
def convert_to_coverage_model(self):
    """Load non-coverage checkpoint, add initialized extra variables for coverage, and save as new checkpoint.

    All variables are first initialized from scratch. The variables whose names
    contain neither "coverage" nor "Adagrad" are then restored via
    util.load_ckpt. The full variable set is saved as "<checkpoint>_cov_init".
    Finally the process exits.
    """
    import sys  # sys.exit: the bare exit() builtin is added by `site` for interactive use only

    tf.logging.info("converting non-coverage model to coverage model..")

    # initialize an entire coverage model from scratch
    sess = tf.Session(config=util.get_config())
    try:
        print("initializing everything...")
        sess.run(tf.global_variables_initializer())

        # load all non-coverage weights from checkpoint
        saver = tf.train.Saver([v for v in tf.global_variables() if "coverage" not in v.name and "Adagrad" not in v.name])
        print("restoring non-coverage variables...")
        curr_ckpt = util.load_ckpt(saver, sess)
        print("restored.")

        # save this model and quit
        new_fname = curr_ckpt + '_cov_init'
        print("saving model to %s..." % (new_fname))
        new_saver = tf.train.Saver() # this one will save all variables that now exist
        new_saver.save(sess, new_fname)
        print("saved.")
    finally:
        sess.close()
    sys.exit()
Example #18
Source File: convert_weights.py From Tensorflow-YOLOv3 with MIT License | 6 votes |
def main(tiny):
    """Build YOLOv3 (or YOLOv3-tiny), load Darknet weights into it and save a TF checkpoint.

    Args:
        tiny: when truthy, use YOLOv3_tiny, ./weights/yolov3-tiny.weights and
              ./weights/model-tiny.ckpt; otherwise YOLOv3,
              ./weights/yolov3.weights and ./weights/model.ckpt.
    """
    model_cls = YOLOv3_tiny if tiny else YOLOv3
    model = model_cls(n_classes=80, iou_threshold=0.5, confidence_threshold=0.5)
    inputs = tf.placeholder(tf.float32, [1, 416, 416, 3])
    model(inputs)

    model_vars = tf.global_variables(scope=model.scope)
    if tiny:
        loader, weights_file = load_weights_tiny, './weights/yolov3-tiny.weights'
    else:
        loader, weights_file = load_weights, './weights/yolov3.weights'
    assign_ops = loader(model_vars, weights_file)

    saver = tf.train.Saver(tf.global_variables(scope=model.scope))
    with tf.Session() as sess:
        save_path = './weights/model-tiny.ckpt' if tiny else './weights/model.ckpt'
        sess.run(assign_ops)
        saver.save(sess, save_path)
        print("Model Saved at \"" + save_path + "\"")
Example #19
Source File: help.py From Automatic-Identification-and-Counting-of-Blood-Cells with GNU General Public License v3.0 | 6 votes |
def to_darknet(self):
    """Copy the current values of the graph's variables back into self.darknet.

    Each global variable name is parsed as "<layer index>-<rest>": the integer
    before the first '-' selects a layer in self.darknet.layers, and the part
    of <rest> after its last '/' is the key written in that layer's `w` dict.
    Every entry in each layer's `h` dict is then reset to None.

    Returns:
        self.darknet, updated in place.
    """
    darknet_ckpt = self.darknet
    with self.graph.as_default():
        for var in tf.global_variables():
            op_name = var.name.split(':')[0]
            pieces = op_name.split('-')
            layer_idx = int(pieces[0])
            weight_key = pieces[1].split('/')[-1]
            layer = darknet_ckpt.layers[layer_idx]
            layer.w[weight_key] = var.eval(self.sess)
    for layer in darknet_ckpt.layers:
        for ph in layer.h:
            layer.h[ph] = None
    return darknet_ckpt
Example #20
Source File: trainer_test.py From vehicle_counting_tensorflow with MIT License | 5 votes |
def restore_map(self, fine_tune_checkpoint_type='detection'):
  """Returns a map of variables to load from a foreign checkpoint.

  Args:
    fine_tune_checkpoint_type: whether to restore from a full detection
      checkpoint (with compatible variable names) or to restore from a
      classification checkpoint for initialization prior to training.
      Valid values: `detection`, `classification`. Default 'detection'.
      This implementation ignores it.

  Returns:
    A dict mapping variable names to variables: every global variable, keyed
    by its op name.
  """
  var_map = {}
  for var in tf.global_variables():
    var_map[var.op.name] = var
  return var_map
Example #21
Source File: pos_model_bilstm_crf.py From deepnlp with MIT License | 5 votes |
def __init__(self, is_training, config):
    """Build the BiLSTM-CRF tagging graph.

    The graph consists of input/target placeholders, the embedding, the model
    (via _bilstm_crf_model), a gradient-clipped SGD train op, a learning-rate
    update op and a Saver.

    Args:
        is_training: stored as self.is_training; not otherwise read in this
            constructor.
        config: provides batch_size, num_steps, crf_layer, hidden_size,
            vocab_size and max_grad_norm (read here) and is passed whole to
            _bilstm_crf_model.
    """
    self.batch_size = batch_size = config.batch_size
    self.num_steps = num_steps = config.num_steps
    self.is_training = is_training
    self.crf_layer = config.crf_layer # if the model has the final CRF decoding layer
    size = config.hidden_size
    vocab_size = config.vocab_size
    # Define input and target tensors, both int32 of shape [batch_size, num_steps]
    self._input_data = tf.placeholder(tf.int32, [batch_size, num_steps])
    self._targets = tf.placeholder(tf.int32, [batch_size, num_steps])
    # Embedding table [vocab_size, hidden_size] and its lookup are pinned to the CPU
    with tf.device("/cpu:0"):
        embedding = tf.get_variable("embedding", [vocab_size, size], dtype=data_type())
        inputs = tf.nn.embedding_lookup(embedding, self._input_data)
    # BiLSTM CRF model
    self._cost, self._logits, self._transition_params = _bilstm_crf_model(inputs, self._targets, config)
    # Gradients and SGD update operation for training the model.
    # The learning rate lives in a non-trainable variable so _lr_update can change it at run time.
    self._lr = tf.Variable(0.0, trainable=False)
    tvars = tf.trainable_variables()
    # Clip all trainable-variable gradients by their global norm
    grads, _ = tf.clip_by_global_norm(tf.gradients(self._cost, tvars), config.max_grad_norm)
    optimizer = tf.train.GradientDescentOptimizer(self._lr) # vanilla SGD
    self._train_op = optimizer.apply_gradients(
        zip(grads, tvars),
        global_step=tf.contrib.framework.get_or_create_global_step())
    # Feed a new value into _new_lr and run _lr_update to assign it to _lr
    self._new_lr = tf.placeholder(data_type(), shape=[], name="new_learning_rate")
    self._lr_update = tf.assign(self._lr, self._new_lr)
    # Created last so it covers every global variable that exists at this point
    self.saver = tf.train.Saver(tf.global_variables())
Example #22
Source File: pos_model_bilstm.py From deepnlp with MIT License | 5 votes |
def __init__(self, is_training, config):
    """Build the (Bi)LSTM tagging graph.

    The graph consists of input/target placeholders, the embedding, the model,
    a gradient-clipped SGD train op, a learning-rate update op and a Saver.

    Args:
        is_training: stored as self.is_training; not otherwise read in this
            constructor.
        config: provides batch_size, num_steps, hidden_size, vocab_size,
            bi_direction and max_grad_norm (read here) and is passed whole to
            the model builder.
    """
    self.batch_size = batch_size = config.batch_size
    self.num_steps = num_steps = config.num_steps
    self.is_training = is_training
    size = config.hidden_size
    vocab_size = config.vocab_size
    # Define input and target tensors, both int32 of shape [batch_size, num_steps]
    self._input_data = tf.placeholder(tf.int32, [batch_size, num_steps])
    self._targets = tf.placeholder(tf.int32, [batch_size, num_steps])
    # Embedding table [vocab_size, hidden_size] and its lookup are pinned to the CPU
    with tf.device("/cpu:0"):
        embedding = tf.get_variable("embedding", [vocab_size, size], dtype=data_type())
        inputs = tf.nn.embedding_lookup(embedding, self._input_data)
    # config.bi_direction selects the model; only the unidirectional LSTM
    # branch sets self._final_state and self._initial_state.
    if (config.bi_direction): # BiLSTM
        self._cost, self._logits = _bilstm_model(inputs, self._targets, config)
    else: # LSTM
        self._cost, self._logits, self._final_state, self._initial_state = _lstm_model(inputs, self._targets, config)
    # Gradients and SGD update operation for training the model.
    # The learning rate lives in a non-trainable variable so _lr_update can change it at run time.
    self._lr = tf.Variable(0.0, trainable=False)
    tvars = tf.trainable_variables()
    # Clip all trainable-variable gradients by their global norm
    grads, _ = tf.clip_by_global_norm(tf.gradients(self._cost, tvars), config.max_grad_norm)
    optimizer = tf.train.GradientDescentOptimizer(self._lr) # vanilla SGD
    self._train_op = optimizer.apply_gradients(
        zip(grads, tvars),
        global_step=tf.contrib.framework.get_or_create_global_step())
    # Feed a new value into _new_lr and run _lr_update to assign it to _lr
    self._new_lr = tf.placeholder(data_type(), shape=[], name="new_learning_rate")
    self._lr_update = tf.assign(self._lr, self._new_lr)
    # Created last so it covers every global variable that exists at this point
    self.saver = tf.train.Saver(tf.global_variables())
Example #23
Source File: ner_tagger.py From deepnlp with MIT License | 5 votes |
def _init_ner_model(self, session):
    """Create ner Tagger model and initialize or load parameters in session.

    The model is built (once) inside the variable scope returned by
    get_model_var_scope with reuse=tf.AUTO_REUSE. If files matching
    "<ckpt_path>.data*" exist, only the variables whose name has that scope as
    a complete path component are restored. Otherwise all global variables
    are initialized.

    Raises:
        ValueError: if no configuration exists for self.name.
    """
    # initialize config
    config_dict = load_config(self.model_config_path)
    config = get_config(config_dict, self.name)
    if config is None:
        print ("WARNING: Input model name %s has no configuration..." % self.name)
        # Execution used to continue and crash with an AttributeError on the
        # next line; fail with a descriptive error instead.
        raise ValueError("Input model name %s has no configuration" % self.name)
    config.batch_size = 1
    config.num_steps = 1 # iterator one token per time
    model_var_scope = get_model_var_scope(self.var_scope, self.name)
    print ("NOTICE: Input NER Model Var Scope Name '%s'" % model_var_scope)
    # Check if self.model already exist
    if self.model is None:
        with tf.variable_scope(model_var_scope, reuse = tf.AUTO_REUSE):
            # NOTE(review): the POS tagger builds its model with
            # is_training=False for inference; confirm True is intended here.
            self.model = ner_model.NERTagger(is_training=True, config=config) # save object after is_training
    # update model parameters
    if glob.glob(self.ckpt_path + '.data*'): # file exist with pattern: 'ner.ckpt.data*'
        print("NOTICE: Loading model parameters from %s" % self.ckpt_path)
        all_vars = tf.global_variables()
        model_vars = [k for k in all_vars if model_var_scope in k.name.split("/")] # e.g. ner_var_scope_zh
        tf.train.Saver(model_vars).restore(session, self.ckpt_path)
    else:
        print("NOTICE: Model not found, Try to run method: deepnlp.download(module='ner', name='%s')" % self.name)
        print("NOTICE: Created with fresh parameters.")
        session.run(tf.global_variables_initializer())
Example #24
Source File: build.py From Traffic-Signs-and-Object-Detection with GNU General Public License v3.0 | 5 votes |
def setup_meta_ops(self):
    """Create the session and, depending on FLAGS, the train op, summaries and saver.

    With a positive min(FLAGS.gpu, 1.) the session uses that GPU memory
    fraction with soft placement; otherwise it runs on CPU only. Device
    placement logging is off. All global variables are initialized. If
    self.ntrain is falsy the method returns right after that. Otherwise it
    creates a Saver keeping FLAGS.keep checkpoints, loads a checkpoint when
    FLAGS.load != 0, and adds the graph to the summary writer when
    FLAGS.summary is set.
    """
    cfg = {'allow_soft_placement': False, 'log_device_placement': False}
    gpu_fraction = min(self.FLAGS.gpu, 1.)
    if gpu_fraction <= 0.0:
        self.say('Running entirely on CPU')
        cfg['device_count'] = {'GPU': 0}
    else:
        self.say('GPU mode with {} usage'.format(gpu_fraction))
        cfg['gpu_options'] = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)
        cfg['allow_soft_placement'] = True

    if self.FLAGS.train:
        self.build_train_op()

    if self.FLAGS.summary:
        self.summary_op = tf.summary.merge_all()
        self.writer = tf.summary.FileWriter(self.FLAGS.summary + 'train')

    self.sess = tf.Session(config=tf.ConfigProto(**cfg))
    self.sess.run(tf.global_variables_initializer())

    if not self.ntrain:
        return

    self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=self.FLAGS.keep)
    if self.FLAGS.load != 0:
        self.load_from_ckpt()

    if self.FLAGS.summary:
        self.writer.add_graph(self.sess.graph)
Example #25
Source File: checkpoint.py From vadnet with GNU Lesser General Public License v3.0 | 5 votes |
def print_graph():
    """Print each global variable's "name:shape" line followed by its value.

    Uses Variable.eval() without an explicit session, so a default session
    must be active.
    """
    # Loop name avoids shadowing the builtin `vars`.
    for var in tf.global_variables():
        # Evaluate once and reuse the result for both prints instead of
        # running the same read op twice.
        value = var.eval()
        print('{}:{}'.format(var.name, value.shape))
        print(value)
Example #26
Source File: faster_rcnn_meta_arch.py From object_detector_app with MIT License | 5 votes |
def restore_fn(self, checkpoint_path, from_detection_checkpoint=True):
  """Returns callable for loading a checkpoint into the tensorflow graph.

  Args:
    checkpoint_path: path to checkpoint to restore.
    from_detection_checkpoint: whether to restore from a detection checkpoint
      (with compatible variable names) or to restore from a classification
      checkpoint for initialization prior to training. Note that when
      from_detection_checkpoint=True, the current implementation only
      supports restoration from an (exactly) identical model (with exception
      of the num_classes parameter).

  Returns:
    a callable which takes a tf.Session as input and loads a checkpoint when
    run.
  """
  if not from_detection_checkpoint:
    return self._feature_extractor.restore_from_classification_checkpoint_fn(
        checkpoint_path,
        self.first_stage_feature_extractor_scope,
        self.second_stage_feature_extractor_scope)

  candidates = tf.global_variables()
  # NOTE(review): this also creates a global step if none exists yet; whether
  # it survives the scope filter below depends on its name — confirm.
  candidates.append(slim.get_or_create_global_step())
  # Only load feature extractor variables to be consistent with loading from
  # a classification checkpoint.
  feature_scopes = [self.first_stage_feature_extractor_scope,
                    self.second_stage_feature_extractor_scope]
  saver = tf.train.Saver(tf.contrib.framework.filter_variables(
      candidates, include_patterns=feature_scopes))

  def restore(sess):
    saver.restore(sess, checkpoint_path)
  return restore
Example #27
Source File: faster_rcnn_meta_arch.py From object_detector_app with MIT License | 5 votes |
def restore_from_classification_checkpoint_fn(
    self,
    checkpoint_path,
    first_stage_feature_extractor_scope,
    second_stage_feature_extractor_scope):
  """Returns callable for loading a checkpoint into the tensorflow graph.

  Args:
    checkpoint_path: path to checkpoint to restore.
    first_stage_feature_extractor_scope: A scope name for the first stage
      feature extractor.
    second_stage_feature_extractor_scope: A scope name for the second stage
      feature extractor.

  Returns:
    a callable which takes a tf.Session as input and loads a checkpoint when
    run. Variables are looked up in the checkpoint by their op name with the
    leading "<scope>/" removed.
  """
  variables_to_restore = {}
  for variable in tf.global_variables():
    for scope_name in [first_stage_feature_extractor_scope,
                       second_stage_feature_extractor_scope]:
      prefix = scope_name + '/'
      # Match on "<scope>/" so sibling scopes sharing the prefix are excluded,
      # and strip only the leading scope; str.replace would also delete later
      # occurrences of "<scope>/" inside the name.
      if variable.op.name.startswith(prefix):
        var_name = variable.op.name[len(prefix):]
        variables_to_restore[var_name] = variable
  # Presumably drops entries not present in the checkpoint (per the helper's
  # name) — TODO confirm its exact filtering rules.
  variables_to_restore = (
      variables_helper.get_variables_available_in_checkpoint(
          variables_to_restore, checkpoint_path))
  saver = tf.train.Saver(variables_to_restore)
  def restore(sess):
    saver.restore(sess, checkpoint_path)
  return restore
Example #28
Source File: build.py From Automatic-Identification-and-Counting-of-Blood-Cells with GNU General Public License v3.0 | 5 votes |
def setup_meta_ops(self):
    """Create the session and, depending on FLAGS, the train op, summaries and saver.

    With a positive min(FLAGS.gpu, 1.) the session uses that GPU memory
    fraction with soft placement; otherwise it runs on CPU only. Device
    placement logging is off. All global variables are initialized. If
    self.ntrain is falsy the method returns right after that. Otherwise it
    creates a Saver keeping FLAGS.keep checkpoints, loads a checkpoint when
    FLAGS.load != 0, and adds the graph to the summary writer when
    FLAGS.summary is set.
    """
    cfg = {'allow_soft_placement': False, 'log_device_placement': False}
    gpu_fraction = min(self.FLAGS.gpu, 1.)
    if gpu_fraction <= 0.0:
        self.say('Running entirely on CPU')
        cfg['device_count'] = {'GPU': 0}
    else:
        self.say('GPU mode with {} usage'.format(gpu_fraction))
        cfg['gpu_options'] = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)
        cfg['allow_soft_placement'] = True

    if self.FLAGS.train:
        self.build_train_op()

    if self.FLAGS.summary:
        self.summary_op = tf.summary.merge_all()
        self.writer = tf.summary.FileWriter(self.FLAGS.summary + 'train')

    self.sess = tf.Session(config=tf.ConfigProto(**cfg))
    self.sess.run(tf.global_variables_initializer())

    if not self.ntrain:
        return

    self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=self.FLAGS.keep)
    if self.FLAGS.load != 0:
        self.load_from_ckpt()

    if self.FLAGS.summary:
        self.writer.add_graph(self.sess.graph)
Example #29
Source File: checkpoint.py From vadnet with GNU Lesser General Public License v3.0 | 5 votes |
def get_var_from_graph(name:str):
    """Return the first global variable whose `.name` equals `name` exactly, or None if none does."""
    matches = (var for var in tf.global_variables() if var.name == name)
    return next(matches, None)
Example #30
Source File: ner_model_bilstm_crf.py From deepnlp with MIT License | 5 votes |
def __init__(self, is_training, config):
    """Build the BiLSTM-CRF tagging graph.

    The graph consists of input/target placeholders, the embedding, the model
    (via _bilstm_crf_model), a gradient-clipped SGD train op, a learning-rate
    update op and a Saver.

    Args:
        is_training: stored as self.is_training; not otherwise read in this
            constructor.
        config: provides batch_size, num_steps, crf_layer, hidden_size,
            vocab_size and max_grad_norm (read here) and is passed whole to
            _bilstm_crf_model.
    """
    self.batch_size = batch_size = config.batch_size
    self.num_steps = num_steps = config.num_steps
    self.is_training = is_training
    self.crf_layer = config.crf_layer # if the model has the final CRF decoding layer
    size = config.hidden_size
    vocab_size = config.vocab_size
    # Define input and target tensors, both int32 of shape [batch_size, num_steps]
    self._input_data = tf.placeholder(tf.int32, [batch_size, num_steps])
    self._targets = tf.placeholder(tf.int32, [batch_size, num_steps])
    # Embedding table [vocab_size, hidden_size] and its lookup are pinned to the CPU
    with tf.device("/cpu:0"):
        embedding = tf.get_variable("embedding", [vocab_size, size], dtype=data_type())
        inputs = tf.nn.embedding_lookup(embedding, self._input_data)
    # BiLSTM CRF model
    self._cost, self._logits, self._transition_params = _bilstm_crf_model(inputs, self._targets, config)
    # Gradients and SGD update operation for training the model.
    # The learning rate lives in a non-trainable variable so _lr_update can change it at run time.
    self._lr = tf.Variable(0.0, trainable=False)
    tvars = tf.trainable_variables()
    # Clip all trainable-variable gradients by their global norm
    grads, _ = tf.clip_by_global_norm(tf.gradients(self._cost, tvars), config.max_grad_norm)
    optimizer = tf.train.GradientDescentOptimizer(self._lr)
    self._train_op = optimizer.apply_gradients(
        zip(grads, tvars),
        global_step=tf.contrib.framework.get_or_create_global_step())
    # Feed a new value into _new_lr and run _lr_update to assign it to _lr
    self._new_lr = tf.placeholder(data_type(), shape=[], name="new_learning_rate")
    self._lr_update = tf.assign(self._lr, self._new_lr)
    # Created last so it covers every global variable that exists at this point
    self.saver = tf.train.Saver(tf.global_variables())