Python tensorflow.get_collection_ref() Examples
The following are 30 code examples of tensorflow.get_collection_ref(), drawn from open-source projects. Each example is listed with its source file, the project it comes from, and that project's license. You may also want to check out all available functions and classes of the tensorflow module.
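Before the examples, it is worth noting the difference between tf.get_collection_ref() and tf.get_collection() in TensorFlow 1.x: get_collection_ref() returns the graph's live, mutable list for a collection key, while get_collection() returns a copy. Mutating the reference therefore changes graph state, which is the behavior most of the examples below rely on. A minimal sketch:

import tensorflow as tf  # TensorFlow 1.x

with tf.Graph().as_default():
    v = tf.get_variable("v", shape=[1])

    copied = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    live = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)

    copied.remove(v)  # Mutates only the local copy; v is still trainable.
    assert v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

    live.remove(v)    # Mutates the graph's own collection; v is now frozen.
    assert v not in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)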
Example #1
Source File: ptb_word_lm.py From g-tensorflow-models with Apache License 2.0
def import_ops(self):
    """Imports ops from collections."""
    if self._is_training:
        self._train_op = tf.get_collection_ref("train_op")[0]
        self._lr = tf.get_collection_ref("lr")[0]
        self._new_lr = tf.get_collection_ref("new_lr")[0]
        self._lr_update = tf.get_collection_ref("lr_update")[0]
        rnn_params = tf.get_collection_ref("rnn_params")
        if self._cell and rnn_params:
            params_saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
                self._cell,
                self._cell.params_to_canonical,
                self._cell.canonical_to_params,
                rnn_params,
                base_variable_scope="Model/RNN")
            tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, params_saveable)
    self._cost = tf.get_collection_ref(util.with_prefix(self._name, "cost"))[0]
    num_replicas = FLAGS.num_gpus if self._name == "Train" else 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, num_replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, num_replicas)
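The import_ops pattern above (which recurs in several projects below, all vendoring TensorFlow's PTB tutorial) is one half of an export/import round trip: ops are stored in named collections before the metagraph is exported, then looked up again by key after tf.train.import_meta_graph(). A minimal sketch of that round trip, using a hypothetical file path and collection key:

import tensorflow as tf  # TensorFlow 1.x

# Export side: stash an op in a collection, then export the metagraph.
with tf.Graph().as_default():
    loss = tf.constant(0.0, name="loss")
    tf.add_to_collection("cost", loss)
    tf.train.export_meta_graph("/tmp/ptb_example.meta")

# Import side: collections survive the round trip, so the op can be
# recovered by collection key instead of by tensor name.
with tf.Graph().as_default():
    tf.train.import_meta_graph("/tmp/ptb_example.meta")
    cost = tf.get_collection_ref("cost")[0]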
Example #2
Source File: graph_search_test.py From kfac with Apache License 2.0
def test_tied_weights_untied_bias_registered_bias(self):
    """Tests that ambiguity in the graph raises a value error.

    Graph search will find several possible registrations for tensors.
    In this case, registering b_1 as a linked variable will result in an
    error, because an ambiguity will remain on the other branch of the graph.
    """
    with tf.Graph().as_default():
        tensor_dict = _build_model()

        layer_collection = lc.LayerCollection()
        layer_collection.register_squared_error_loss(tensor_dict['out_0'])
        layer_collection.register_squared_error_loss(tensor_dict['out_1'])
        layer_collection.define_linked_parameters((tensor_dict['b_1']))

        with self.assertRaises(gs.AmbiguousRegistrationError):
            gs.register_layers(
                layer_collection,
                tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))
Example #3
Source File: ptb_word_lm.py From Live-feed-object-device-identification-using-Tensorflow-and-OpenCV with Apache License 2.0
def import_ops(self):
    """Imports ops from collections."""
    if self._is_training:
        self._train_op = tf.get_collection_ref("train_op")[0]
        self._lr = tf.get_collection_ref("lr")[0]
        self._new_lr = tf.get_collection_ref("new_lr")[0]
        self._lr_update = tf.get_collection_ref("lr_update")[0]
        rnn_params = tf.get_collection_ref("rnn_params")
        if self._cell and rnn_params:
            params_saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
                self._cell,
                self._cell.params_to_canonical,
                self._cell.canonical_to_params,
                rnn_params,
                base_variable_scope="Model/RNN")
            tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, params_saveable)
    self._cost = tf.get_collection_ref(util.with_prefix(self._name, "cost"))[0]
    num_replicas = FLAGS.num_gpus if self._name == "Train" else 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, num_replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, num_replicas)
Example #4
Source File: ptb_word_lm.py From object_detection_kitti with Apache License 2.0
def import_ops(self):
    """Imports ops from collections."""
    if self._is_training:
        self._train_op = tf.get_collection_ref("train_op")[0]
        self._lr = tf.get_collection_ref("lr")[0]
        self._new_lr = tf.get_collection_ref("new_lr")[0]
        self._lr_update = tf.get_collection_ref("lr_update")[0]
        rnn_params = tf.get_collection_ref("rnn_params")
        if self._cell and rnn_params:
            params_saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
                self._cell,
                self._cell.params_to_canonical,
                self._cell.canonical_to_params,
                rnn_params,
                base_variable_scope="Model/RNN")
            tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, params_saveable)
    self._cost = tf.get_collection_ref(util.with_prefix(self._name, "cost"))[0]
    num_replicas = FLAGS.num_gpus if self._name == "Train" else 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, num_replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, num_replicas)
Example #5
Source File: block_base.py From object_detection_kitti with Apache License 2.0
def MarkAsNonTrainable(self):
    """Marks all the variables of this block as non-trainable.

    All the variables owned directly or indirectly (through subblocks) are
    marked as non-trainable. This function, along with CheckpointInitOp, can
    be used to load a pretrained model that makes up only one part of the
    whole graph.
    """
    assert self._called

    all_variables = self.VariableList()
    collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
    for v in all_variables:
        if v in collection:
            collection.remove(v)
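A usage note: TF1 optimizers read the TRAINABLE_VARIABLES collection when minimize() is called, so a block has to be frozen before the train op is built. A hypothetical sketch, where pretrained_block and loss stand in for real objects:

pretrained_block.MarkAsNonTrainable()  # Must happen before minimize().
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)  # Skips the frozen variables.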
Example #6
Source File: config.py From tensorflow-tbcnn with MIT License
def initialize_tbcnn_weights(clz):
    clz.initialize_embedding_weights()
    # Don't train We.
    tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES).remove(clz.get('We'))
    clz.create_variable('Wcomb1', (hyper.word_dim, hyper.word_dim),
                        tf.random_uniform_initializer(-.2, .2))
    clz.create_variable('Wcomb2', (hyper.word_dim, hyper.word_dim),
                        tf.random_uniform_initializer(-.2, .2))
    clz.create_variable('Wconvt', (hyper.word_dim, hyper.conv_dim),
                        tf.random_uniform_initializer(-.2, .2))
    clz.create_variable('Wconvl', (hyper.word_dim, hyper.conv_dim),
                        tf.random_uniform_initializer(-.2, .2))
    clz.create_variable('Wconvr', (hyper.word_dim, hyper.conv_dim),
                        tf.random_uniform_initializer(-.2, .2))
    clz.create_variable('Bconv', (hyper.conv_dim,),
                        tf.random_uniform_initializer(-.2, .2))
    clz.create_variable('FC1/weight', (hyper.conv_dim, hyper.fc_dim),
                        tf.random_uniform_initializer(-.2, .2))
    clz.create_variable('FC1/bias', (hyper.fc_dim,),
                        tf.random_uniform_initializer(-.2, .2))
    clz.create_variable('FC2/weight', (hyper.fc_dim, hyper.output_dim),
                        tf.random_uniform_initializer(-.2, .2))
    clz.create_variable('FC2/bias', (hyper.output_dim,),
                        tf.random_uniform_initializer(-.2, .2))
Example #7
Source File: graph_search_test.py From kfac with Apache License 2.0
def mixed_usage_test(self):
    """Tests that graph search raises an error on mixed-type usage of tensors.

    Tensors can be reused in various locations in the TensorFlow graph. This
    occurs regularly in the case of recurrent models or models with parallel
    graphs. However, the tensors must be used for the same operation in each
    location, or graph search should raise an error.
    """
    with tf.Graph().as_default():
        w = tf.get_variable('W', [10, 10])
        x = tf.placeholder(tf.float32, shape=(32, 10))
        y = tf.placeholder(tf.float32, shape=(32, 10, 10))

        out_0 = tf.matmul(x, w)  # pylint: disable=unused-variable
        out_1 = y + w  # pylint: disable=unused-variable

        layer_collection = lc.LayerCollection()
        with self.assertRaises(ValueError) as cm:
            gs.register_layers(
                layer_collection,
                tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))
        self.assertIn('mixed record types', str(cm.exception))
Example #8
Source File: ptb_word_lm.py From yolo_v2 with Apache License 2.0
def import_ops(self):
    """Imports ops from collections."""
    if self._is_training:
        self._train_op = tf.get_collection_ref("train_op")[0]
        self._lr = tf.get_collection_ref("lr")[0]
        self._new_lr = tf.get_collection_ref("new_lr")[0]
        self._lr_update = tf.get_collection_ref("lr_update")[0]
        rnn_params = tf.get_collection_ref("rnn_params")
        if self._cell and rnn_params:
            params_saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
                self._cell,
                self._cell.params_to_canonical,
                self._cell.canonical_to_params,
                rnn_params,
                base_variable_scope="Model/RNN")
            tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, params_saveable)
    self._cost = tf.get_collection_ref(util.with_prefix(self._name, "cost"))[0]
    num_replicas = FLAGS.num_gpus if self._name == "Train" else 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, num_replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, num_replicas)
Example #9
Source File: bayesian_rnn.py From BayesianRecurrentNN with MIT License
def import_ops(self):
    """Imports ops from collections."""
    if self._is_training:
        self._train_op = tf.get_collection_ref("train_op")[0]
        self._lr = tf.get_collection_ref("lr")[0]
        self._new_lr = tf.get_collection_ref("new_lr")[0]
        self._lr_update = tf.get_collection_ref("lr_update")[0]
        rnn_params = tf.get_collection_ref("rnn_params")
        if self._cell and rnn_params:
            params_saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
                self._cell,
                self._cell.params_to_canonical,
                self._cell.canonical_to_params,
                rnn_params,
                base_variable_scope="Model/RNN")
            tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, params_saveable)
    self._cost = tf.get_collection_ref(tf_util.with_prefix(self._name, "cost"))[0]
    self._kl_div = tf.get_collection_ref(tf_util.with_prefix(self._name, "kl_div"))[0]
    num_replicas = 1
    self._initial_state = tf_util.import_state_tuples(
        self._initial_state, self._initial_state_name, num_replicas)
    self._final_state = tf_util.import_state_tuples(
        self._final_state, self._final_state_name, num_replicas)
Example #10
Source File: ptb_word_lm.py From object_detection_with_tensorflow with MIT License
def import_ops(self):
    """Imports ops from collections."""
    if self._is_training:
        self._train_op = tf.get_collection_ref("train_op")[0]
        self._lr = tf.get_collection_ref("lr")[0]
        self._new_lr = tf.get_collection_ref("new_lr")[0]
        self._lr_update = tf.get_collection_ref("lr_update")[0]
        rnn_params = tf.get_collection_ref("rnn_params")
        if self._cell and rnn_params:
            params_saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
                self._cell,
                self._cell.params_to_canonical,
                self._cell.canonical_to_params,
                rnn_params,
                base_variable_scope="Model/RNN")
            tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, params_saveable)
    self._cost = tf.get_collection_ref(util.with_prefix(self._name, "cost"))[0]
    num_replicas = FLAGS.num_gpus if self._name == "Train" else 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, num_replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, num_replicas)
Example #11
Source File: ptb_word_lm.py From Gun-Detector with Apache License 2.0
def import_ops(self):
    """Imports ops from collections."""
    if self._is_training:
        self._train_op = tf.get_collection_ref("train_op")[0]
        self._lr = tf.get_collection_ref("lr")[0]
        self._new_lr = tf.get_collection_ref("new_lr")[0]
        self._lr_update = tf.get_collection_ref("lr_update")[0]
        rnn_params = tf.get_collection_ref("rnn_params")
        if self._cell and rnn_params:
            params_saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
                self._cell,
                self._cell.params_to_canonical,
                self._cell.canonical_to_params,
                rnn_params,
                base_variable_scope="Model/RNN")
            tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, params_saveable)
    self._cost = tf.get_collection_ref(util.with_prefix(self._name, "cost"))[0]
    num_replicas = FLAGS.num_gpus if self._name == "Train" else 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, num_replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, num_replicas)
Example #12
Source File: block_base.py From DOTA_models with Apache License 2.0
def MarkAsNonTrainable(self):
    """Marks all the variables of this block as non-trainable.

    All the variables owned directly or indirectly (through subblocks) are
    marked as non-trainable. This function, along with CheckpointInitOp, can
    be used to load a pretrained model that makes up only one part of the
    whole graph.
    """
    assert self._called

    all_variables = self.VariableList()
    collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
    for v in all_variables:
        if v in collection:
            collection.remove(v)
Example #13
Source File: context.py From texar with Apache License 2.0
def global_mode():
    """Returns the Tensor of global mode.

    This is a placeholder with default value of
    :tf_main:`tf.estimator.ModeKeys.TRAIN <estimator/ModeKeys>`.

    Example:

    .. code-block:: python

        mode = session.run(global_mode())
        # mode == tf.estimator.ModeKeys.TRAIN

        mode = session.run(
            global_mode(),
            feed_dict={global_mode(): tf.estimator.ModeKeys.PREDICT})
        # mode == tf.estimator.ModeKeys.PREDICT
    """
    mode = tf.get_collection_ref(_GLOBAL_MODE_KEY)
    if len(mode) < 1:
        # mode_tensor = tf.placeholder(tf.string, name="global_mode")
        mode_tensor = tf.placeholder_with_default(
            input=tf.estimator.ModeKeys.TRAIN,
            shape=(),
            name="global_mode")
        # mode_tensor = tf.constant(
        #     value=tf.estimator.ModeKeys.TRAIN,
        #     dtype=tf.string,
        #     name="global_mode")
        mode.append(mode_tensor)
    return mode[0]
Example #14
Source File: common.py From ternarynet with Apache License 2.0
def restore_collection(backup):
    for k, v in six.iteritems(backup):
        del tf.get_collection_ref(k)[:]
        tf.get_collection_ref(k).extend(v)
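restore_collection() pairs naturally with a snapshot helper. A sketch of what such a counterpart could look like (tensorpack defines a similar backup_collection(); this is not necessarily its exact implementation):

def backup_collection(keys):
    backup = {}
    for k in keys:
        # tf.get_collection returns a copy, so the snapshot will not alias
        # the live list that restore_collection later overwrites.
        backup[k] = list(tf.get_collection(k))
    return backup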
Example #15
Source File: util.py From g-tensorflow-models with Apache License 2.0
def import_state_tuples(state_tuples, name, num_replicas):
    restored = []
    for i in range(len(state_tuples) * num_replicas):
        c = tf.get_collection_ref(name)[2 * i + 0]
        h = tf.get_collection_ref(name)[2 * i + 1]
        restored.append(tf.contrib.rnn.LSTMStateTuple(c, h))
    return tuple(restored)
Example #16
Source File: models.py From g-tensorflow-models with Apache License 2.0
def __init__(self,
             state_size,
             num_timesteps,
             sigma_min=1e-5,
             dtype=tf.float32,
             random_seed=None,
             graph_collection_name="R_TILDE_VARS"):
    self.dtype = dtype
    self.sigma_min = sigma_min
    initializers = {"w": tf.truncated_normal_initializer(seed=random_seed),
                    "b": tf.zeros_initializer}
    self.graph_collection_name = graph_collection_name

    def custom_getter(getter, *args, **kwargs):
        out = getter(*args, **kwargs)
        ref = tf.get_collection_ref(self.graph_collection_name)
        if out not in ref:
            ref.append(out)
        return out

    self.fns = [
        snt.Linear(output_size=2 * state_size,
                   initializers=initializers,
                   name="r_tilde_%d" % t,
                   custom_getter=custom_getter)
        for t in xrange(num_timesteps)
    ]
Example #17
Source File: util.py From object_detection_with_tensorflow with MIT License
def import_state_tuples(state_tuples, name, num_replicas):
    restored = []
    for i in range(len(state_tuples) * num_replicas):
        c = tf.get_collection_ref(name)[2 * i + 0]
        h = tf.get_collection_ref(name)[2 * i + 1]
        restored.append(tf.contrib.rnn.LSTMStateTuple(c, h))
    return tuple(restored)
Example #18
Source File: tf_utils.py From rltf with MIT License
def normalize(x, training, momentum=0.0):
    """Normalizes a tensor along the batch dimension.

    Normalization is done using the statistics of the current batch (in
    training mode) or based on the running mean and variance (in inference
    mode).

    Args:
        x: tf.Tensor, shape.ndims == 2. Input tensor.
        training: tf.Tensor or bool. Whether to return the output in training
            mode (normalized with statistics of the current batch) or in
            inference mode (normalized with moving statistics).
        momentum: float. Momentum for the moving average.
    """
    assert x.shape.ndims == 2

    kwargs = dict(axis=-1, center=False, scale=False, trainable=True,
                  training=training, momentum=momentum)
    ops = tf.get_collection_ref(tf.GraphKeys.UPDATE_OPS)
    i = len(ops)
    x = tf.layers.batch_normalization(x, **kwargs)

    # Get the batch norm update ops and remove them from the global list.
    update_ops = ops[i:]
    del ops[i:]

    # Update the moving mean and variance before returning the output.
    with tf.control_dependencies(update_ops):
        x = tf.identity(x)
    return x
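For contrast, the conventional TF1 batch-norm recipe leaves the update ops in the global collection and makes the train op depend on them; normalize() above inlines that dependency so callers do not have to. The usual pattern, for reference (loss is a stand-in for a real model loss):

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)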
Example #19
Source File: common.py From ternarynet with Apache License 2.0
def clear_collection(keys):
    for k in keys:
        del tf.get_collection_ref(k)[:]
Example #20
Source File: model_lib.py From g-tensorflow-models with Apache License 2.0
def filter_trainable_variables(trainable_scopes):
    """Keeps only trainable variables which are prefixed with given scopes.

    This function removes all variables which are not prefixed with the given
    trainable_scopes from the collection of trainable variables. Useful during
    network fine-tuning, when you only need to train a subset of variables.

    Args:
        trainable_scopes: either a list of trainable scopes or a string with a
            comma-separated list of trainable scopes.
    """
    if not trainable_scopes:
        return
    if isinstance(trainable_scopes, six.string_types):
        trainable_scopes = [scope.strip() for scope in trainable_scopes.split(',')]
    trainable_scopes = {scope for scope in trainable_scopes if scope}
    if not trainable_scopes:
        return
    trainable_collection = tf.get_collection_ref(
        tf.GraphKeys.TRAINABLE_VARIABLES)
    non_trainable_vars = [
        v for v in trainable_collection
        if not any([v.op.name.startswith(s) for s in trainable_scopes])
    ]
    for v in non_trainable_vars:
        trainable_collection.remove(v)
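A hypothetical call for fine-tuning, keeping only two scopes trainable; everything outside these prefixes is removed from TRAINABLE_VARIABLES, so a later optimizer.minimize() leaves it untouched:

filter_trainable_variables('resnet/block4, resnet/logits')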
Example #21
Source File: common.py From VDAIC2017 with MIT License
def restore_collection(backup):
    for k, v in six.iteritems(backup):
        del tf.get_collection_ref(k)[:]
        tf.get_collection_ref(k).extend(v)
Example #22
Source File: model_lib.py From adversarial-logit-pairing-analysis with Apache License 2.0
def filter_trainable_variables(trainable_scopes):
    """Keeps only trainable variables which are prefixed with given scopes.

    This function removes all variables which are not prefixed with the given
    trainable_scopes from the collection of trainable variables. Useful during
    network fine-tuning, when you only need to train a subset of variables.

    Args:
        trainable_scopes: either a list of trainable scopes or a string with a
            comma-separated list of trainable scopes.
    """
    if not trainable_scopes:
        return
    if isinstance(trainable_scopes, six.string_types):
        trainable_scopes = [scope.strip() for scope in trainable_scopes.split(',')]
    trainable_scopes = {scope for scope in trainable_scopes if scope}
    if not trainable_scopes:
        return
    trainable_collection = tf.get_collection_ref(
        tf.GraphKeys.TRAINABLE_VARIABLES)
    non_trainable_vars = [
        v for v in trainable_collection
        if not any([v.op.name.startswith(s) for s in trainable_scopes])
    ]
    for v in non_trainable_vars:
        trainable_collection.remove(v)
Example #23
Source File: common.py From VDAIC2017 with MIT License
def clear_collection(keys):
    for k in keys:
        del tf.get_collection_ref(k)[:]
Example #24
Source File: tf_util.py From BayesianRecurrentNN with MIT License
def import_state_tuples(state_tuples, name, num_replicas):
    restored = []
    for i in range(len(state_tuples) * num_replicas):
        c = tf.get_collection_ref(name)[2 * i + 0]
        h = tf.get_collection_ref(name)[2 * i + 1]
        restored.append(tf.contrib.rnn.LSTMStateTuple(c, h))
    return tuple(restored)
Example #25
Source File: model.py From aapm_thoracic_challenge with MIT License
def __init__(self, sess, checkpoint_dir, log_dir, training_paths, testing_paths,
             roi, im_size, nclass, batch_size=1, layers=3, features_root=32,
             conv_size=3, dropout=0.5, testing_gt_available=True,
             loss_type='cross_entropy', class_weights=None):
    self.sess = sess
    self.checkpoint_dir = checkpoint_dir
    self.log_dir = log_dir
    self.training_paths = training_paths
    self.testing_paths = testing_paths
    self.testing_gt_available = testing_gt_available
    self.nclass = nclass
    self.im_size = im_size
    self.roi = roi  # (roi_order, roi_name)
    self.batch_size = batch_size
    self.layers = layers
    self.features_root = features_root
    self.conv_size = conv_size
    self.dropout = dropout
    self.loss_type = loss_type
    self.class_weights = class_weights

    self.build_model()

    self.saver = tf.train.Saver(tf.trainable_variables() +
                                tf.get_collection_ref('bn_collections'))
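Note that 'bn_collections' is a project-defined collection key: for the Saver above to pick up the batch-norm moving statistics, the model-building code must have added them to that collection, roughly along these lines (a sketch with hypothetical variable names):

tf.add_to_collection('bn_collections', moving_mean)
tf.add_to_collection('bn_collections', moving_variance)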
Example #26
Source File: common.py From Distributed-BA3C with Apache License 2.0
def clear_collection(keys):
    for k in keys:
        del tf.get_collection_ref(k)[:]
Example #27
Source File: common.py From Distributed-BA3C with Apache License 2.0
def restore_collection(backup):
    for k, v in six.iteritems(backup):
        del tf.get_collection_ref(k)[:]
        tf.get_collection_ref(k).extend(v)
Example #28
Source File: utils.py From dynamic-training-bench with Mozilla Public License 2.0
def variables_to_save(add_list=None):
    """Returns a list of variables to save.

    add_list variables are always added to the list.

    Args:
        add_list: a list of variables

    Returns:
        list: list of tensors to save
    """
    if add_list is None:
        add_list = []
    return tf.trainable_variables() + tf.get_collection_ref(
        REQUIRED_NON_TRAINABLES) + add_list + training_process_variables()
Example #29
Source File: transfer_elmo_model.py From delta with Apache License 2.0
def transfer_elmo_model(vocab_file, options_file, weight_file,
                        token_embedding_file, output_elmo_model):
    dump_token_embeddings(
        vocab_file, options_file, weight_file, token_embedding_file)
    logging.info("finish dump_token_embeddings")

    tf.reset_default_graph()
    with tf.Session(graph=tf.Graph()) as sess:
        bilm = BidirectionalLanguageModel(
            options_file,
            weight_file,
            use_character_inputs=False,
            embedding_weight_file=token_embedding_file)
        input_x = tf.placeholder(tf.int32, shape=[None, None], name='input_x')
        train_embeddings_op = bilm(input_x)
        input_x_elmo_op = weight_layers(
            'output', train_embeddings_op, l2_coef=0.0)['weighted_op']
        input_x_elmo = tf.identity(input_x_elmo_op, name="input_x_elmo")
        logging.info("input_x_elmo shape: {}".format(input_x_elmo))

        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.save(sess, output_elmo_model)
        logging.info("finish saving!")

        all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
        for v in all_variables:
            logging.info("variable name: {}".format(v.name))
Example #30
Source File: util.py From vae-seq with Apache License 2.0
def dynamic_hparam(key, value):
    """Returns a memoized, non-constant Tensor that allows feeding."""
    collection = tf.get_collection_ref("HPARAMS_" + key)
    if len(collection) > 1:
        raise ValueError("Dynamic hparams collection should contain one item.")
    if not collection:
        with tf.name_scope(""):
            default_value = tf.convert_to_tensor(value, name=key + "_default")
            tensor = tf.placeholder_with_default(
                default_value, default_value.get_shape(), name=key)
        collection.append(tensor)
    return collection[0]
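A hypothetical usage sketch: the first call creates a feedable placeholder with a default value, and later calls with the same key return the same memoized tensor.

batch_size = dynamic_hparam("batch_size", 32)
assert batch_size is dynamic_hparam("batch_size", 32)  # Memoized.

with tf.Session() as sess:
    print(sess.run(batch_size))                              # 32, the default.
    print(sess.run(batch_size, feed_dict={batch_size: 64}))  # 64, fed.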