Python tensorflow.scatter_update() Examples
The following are 30 code examples of tensorflow.scatter_update(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
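Before the project examples, a minimal, self-contained sketch of what tf.scatter_update(ref, indices, updates) does may be helpful: it overwrites whole rows of a mutable variable at the given indices and leaves the other rows untouched. The sketch below assumes the TensorFlow 1.x graph API (in TensorFlow 2.x the op lives at tf.compat.v1.scatter_update, or you can call Variable.scatter_update) and is not taken from any of the projects listed here.

# Minimal sketch, assuming TensorFlow 1.x; not from any project below.
import tensorflow as tf

var = tf.Variable([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
# Overwrite rows 0 and 2 with new values; row 1 is left untouched.
update_op = tf.scatter_update(var, [0, 2], [[9.0, 9.0], [7.0, 7.0]])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(update_op))  # [[9. 9.] [2. 2.] [7. 7.]]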
Example #1
Source File: yellowfin.py From MobileNet with Apache License 2.0 | 6 votes |
def curvature_range(self):
    # set up the curvature window
    self._curv_win = tf.Variable(
        np.zeros([self._curv_win_width, ]), dtype=tf.float32,
        name="curv_win", trainable=False)
    self._curv_win = tf.scatter_update(
        self._curv_win, self._global_step % self._curv_win_width,
        self._grad_norm_squared)
    # note here the iterations start from iteration 0
    valid_window = tf.slice(
        self._curv_win, tf.constant([0, ]),
        tf.expand_dims(tf.minimum(tf.constant(self._curv_win_width),
                                  self._global_step + 1), dim=0))
    self._h_min_t = tf.reduce_min(valid_window)
    self._h_max_t = tf.reduce_max(valid_window)

    curv_range_ops = []
    with tf.control_dependencies([self._h_min_t, self._h_max_t]):
        avg_op = self._moving_averager.apply([self._h_min_t, self._h_max_t])
        with tf.control_dependencies([avg_op]):
            self._h_min = tf.identity(self._moving_averager.average(self._h_min_t))
            self._h_max = tf.identity(self._moving_averager.average(self._h_max_t))
    curv_range_ops.append(avg_op)
    return curv_range_ops
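The core trick in Example #1 is using tf.scatter_update as the write operation of a ring buffer: each step overwrites exactly one slot at index global_step % window_size. A stripped-down sketch of just that pattern follows; the names window_size, curv_win, and new_value are illustrative, not from the YellowFin source, and TensorFlow 1.x is assumed.

# Hypothetical ring-buffer write, modeled on Example #1 (TF 1.x assumed).
import numpy as np
import tensorflow as tf

window_size = 20
step = tf.train.get_or_create_global_step()
curv_win = tf.Variable(np.zeros([window_size]), dtype=tf.float32,
                       trainable=False, name="curv_win")
new_value = tf.constant(1.0)  # stands in for the squared gradient norm

# Each call overwrites exactly one slot, cycling through the window.
write_op = tf.scatter_update(curv_win, step % window_size, new_value)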
Example #2
Source File: in_graph_batch_env.py From soccer-matlab with BSD 2-Clause "Simplified" License | 6 votes |
def reset(self, indices=None):
    """Reset the batch of environments.

    Args:
      indices: The batch indices of the environments to reset; defaults to all.

    Returns:
      Batch tensor of the new observations.
    """
    if indices is None:
        indices = tf.range(len(self._batch_env))
    observ_dtype = self._parse_dtype(self._batch_env.observation_space)
    observ = tf.py_func(
        self._batch_env.reset, [indices], observ_dtype, name='reset')
    observ = tf.check_numerics(observ, 'observ')
    reward = tf.zeros_like(indices, tf.float32)
    done = tf.zeros_like(indices, tf.bool)
    with tf.control_dependencies([
            tf.scatter_update(self._observ, indices, observ),
            tf.scatter_update(self._reward, indices, reward),
            tf.scatter_update(self._done, indices, done)]):
        return tf.identity(observ)
Example #3
Source File: optimizer.py From Parser-v3 with Apache License 2.0 | 6 votes |
def sparse_moving_average(self, variable, unique_indices, accumulant,
                          name='Accumulator', decay=.9):
    """"""
    accumulant = tf.clip_by_value(accumulant, -self.clip, self.clip)
    first_dim = variable.get_shape().as_list()[0]
    accumulator = self.get_accumulator(name, variable)
    indexed_accumulator = tf.gather(accumulator, unique_indices)
    iteration = self.get_accumulator('{}/iteration'.format(name), variable,
                                     shape=[first_dim, 1])
    indexed_iteration = tf.gather(iteration, unique_indices)
    iteration = tf.scatter_add(iteration, unique_indices,
                               tf.ones_like(indexed_iteration))
    indexed_iteration = tf.gather(iteration, unique_indices)

    if decay < 1:
        current_indexed_decay = decay * (1 - decay**(indexed_iteration - 1)) / (1 - decay**indexed_iteration)
    else:
        current_indexed_decay = (indexed_iteration - 1) / indexed_iteration

    accumulator = tf.scatter_update(accumulator, unique_indices,
                                    current_indexed_decay * indexed_accumulator)
    accumulator = tf.scatter_add(accumulator, unique_indices,
                                 (1 - current_indexed_decay) * accumulant)
    return accumulator, iteration
#=============================================================
Example #4
Source File: utility.py From soccer-matlab with BSD 2-Clause "Simplified" License | 6 votes |
def reinit_nested_vars(variables, indices=None):
    """Reset all variables in a nested tuple to zeros.

    Args:
      variables: Nested tuple or list of variables.
      indices: Batch indices to reset, defaults to all.

    Returns:
      Operation.
    """
    if isinstance(variables, (tuple, list)):
        return tf.group(*[
            reinit_nested_vars(variable, indices)
            for variable in variables])
    if indices is None:
        return variables.assign(tf.zeros_like(variables))
    else:
        zeros = tf.zeros([tf.shape(indices)[0]] + variables.shape[1:].as_list())
        return tf.scatter_update(variables, indices, zeros)
Example #5
Source File: in_graph_batch_env.py From soccer-matlab with BSD 2-Clause "Simplified" License | 6 votes |
def reset(self, indices=None):
    """Reset the batch of environments.

    Args:
      indices: The batch indices of the environments to reset; defaults to all.

    Returns:
      Batch tensor of the new observations.
    """
    if indices is None:
        indices = tf.range(len(self._batch_env))
    observ_dtype = self._parse_dtype(self._batch_env.observation_space)
    observ = tf.py_func(
        self._batch_env.reset, [indices], observ_dtype, name='reset')
    observ = tf.check_numerics(observ, 'observ')
    reward = tf.zeros_like(indices, tf.float32)
    done = tf.zeros_like(indices, tf.bool)
    with tf.control_dependencies([
            tf.scatter_update(self._observ, indices, observ),
            tf.scatter_update(self._reward, indices, reward),
            tf.scatter_update(self._done, indices, done)]):
        return tf.identity(observ)
Example #6
Source File: utility.py From soccer-matlab with BSD 2-Clause "Simplified" License | 6 votes |
def assign_nested_vars(variables, tensors, indices=None):
    """Assign tensors to matching nested tuple of variables.

    Args:
      variables: Nested tuple or list of variables to update.
      tensors: Nested tuple or list of tensors to assign.
      indices: Batch indices to assign to; default to all.

    Returns:
      Operation.
    """
    if isinstance(variables, (tuple, list)):
        return tf.group(*[
            assign_nested_vars(variable, tensor)
            for variable, tensor in zip(variables, tensors)])
    if indices is None:
        return variables.assign(tensors)
    else:
        return tf.scatter_update(variables, indices, tensors)
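Examples #4 and #6 both use the same idiom: fall back to Variable.assign when every row is targeted, and use tf.scatter_update when only a subset of batch indices is. A minimal sketch of that distinction, with illustrative variable names and TensorFlow 1.x assumed, not taken from either project:

# assign vs. scatter_update (TF 1.x assumed; names are illustrative).
import tensorflow as tf

state = tf.Variable(tf.zeros([4, 3]))
full_value = tf.ones([4, 3])
indices = tf.constant([0, 2])
partial_value = tf.ones([2, 3]) * 5.0

reset_all = state.assign(full_value)                           # every row
reset_some = tf.scatter_update(state, indices, partial_value)  # only rows 0 and 2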
Example #7
Source File: memory.py From yolo_v2 with Apache License 2.0 | 6 votes |
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                   batch_size, use_recent_idx, intended_output):
    """Function that creates all the update ops."""
    mem_age_incr = self.mem_age.assign_add(tf.ones([self.memory_size],
                                                   dtype=tf.float32))
    with tf.control_dependencies([mem_age_incr]):
        mem_age_upd = tf.scatter_update(
            self.mem_age, upd_idxs, tf.zeros([batch_size], dtype=tf.float32))

    mem_key_upd = tf.scatter_update(
        self.mem_keys, upd_idxs, upd_keys)
    mem_val_upd = tf.scatter_update(
        self.mem_vals, upd_idxs, upd_vals)

    if use_recent_idx:
        recent_idx_upd = tf.scatter_update(
            self.recent_idx, intended_output, upd_idxs)
    else:
        recent_idx_upd = tf.group()

    return tf.group(mem_age_upd, mem_key_upd, mem_val_upd, recent_idx_upd)
Example #8
Source File: tf_atari_wrappers.py From training_results_v0.5 with Apache License 2.0 | 6 votes |
def _reset_non_empty(self, indices):
    # pylint: disable=protected-access
    new_values = self._batch_env._reset_non_empty(indices)
    # pylint: enable=protected-access
    initial_frames = getattr(self._batch_env, "history_observations", None)
    if initial_frames is not None:
        # Using history buffer frames for initialization, if they are available.
        with tf.control_dependencies([new_values]):
            # Transpose to [batch, height, width, history, channels] and merge
            # history and channels into one dimension.
            initial_frames = tf.transpose(initial_frames, [0, 2, 3, 1, 4])
            initial_frames = tf.reshape(initial_frames,
                                        (len(self),) + self.observ_shape)
    else:
        inx = tf.concat(
            [
                tf.ones(tf.size(tf.shape(new_values)), dtype=tf.int64)[:-1],
                [self.history]
            ],
            axis=0)
        initial_frames = tf.tile(new_values, inx)

    assign_op = tf.scatter_update(self._observ, indices, initial_frames)
    with tf.control_dependencies([assign_op]):
        return tf.gather(self.observ, indices)
Example #9
Source File: tf_atari_wrappers.py From training_results_v0.5 with Apache License 2.0 | 6 votes |
def _reset_non_empty(self, indices):
    # pylint: disable=protected-access
    new_values = self._batch_env._reset_non_empty(indices)
    # pylint: enable=protected-access
    initial_frames = getattr(self._batch_env, "history_observations", None)
    if initial_frames is not None:
        # Using history buffer frames for initialization, if they are available.
        with tf.control_dependencies([new_values]):
            # Transpose to [batch, height, width, history, channels] and merge
            # history and channels into one dimension.
            initial_frames = tf.transpose(initial_frames, [0, 2, 3, 1, 4])
            initial_frames = tf.reshape(initial_frames,
                                        (len(self),) + self.observ_shape)
    else:
        inx = tf.concat(
            [
                tf.ones(tf.size(tf.shape(new_values)), dtype=tf.int64)[:-1],
                [self.history]
            ],
            axis=0)
        initial_frames = tf.tile(new_values, inx)

    assign_op = tf.scatter_update(self._observ, indices, initial_frames)
    with tf.control_dependencies([assign_op]):
        return tf.gather(self.observ, indices)
Example #10
Source File: memory.py From Gun-Detector with Apache License 2.0 | 6 votes |
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                   batch_size, use_recent_idx, intended_output):
    """Function that creates all the update ops."""
    mem_age_incr = self.mem_age.assign_add(tf.ones([self.memory_size],
                                                   dtype=tf.float32))
    with tf.control_dependencies([mem_age_incr]):
        mem_age_upd = tf.scatter_update(
            self.mem_age, upd_idxs, tf.zeros([batch_size], dtype=tf.float32))

    mem_key_upd = tf.scatter_update(
        self.mem_keys, upd_idxs, upd_keys)
    mem_val_upd = tf.scatter_update(
        self.mem_vals, upd_idxs, upd_vals)

    if use_recent_idx:
        recent_idx_upd = tf.scatter_update(
            self.recent_idx, intended_output, upd_idxs)
    else:
        recent_idx_upd = tf.group()

    return tf.group(mem_age_upd, mem_key_upd, mem_val_upd, recent_idx_upd)
Example #11
Source File: control_flow_ops_py_test.py From deep_image_model with Apache License 2.0 | 6 votes |
def testWhileUpdateVariable_1(self):
    with self.test_session():
        select = tf.Variable([3.0, 4.0, 5.0])
        n = tf.constant(0)

        def loop_iterator(j):
            return tf.less(j, 3)

        def loop_body(j):
            ns = tf.scatter_update(select, j, 10.0)
            nj = tf.add(j, 1)
            op = control_flow_ops.group(ns)
            nj = control_flow_ops.with_dependencies([op], nj)
            return [nj]

        r = tf.while_loop(loop_iterator, loop_body, [n],
                          parallel_iterations=1)
        self.assertTrue(check_op_order(n.graph))
        tf.global_variables_initializer().run()
        self.assertEqual(3, r.eval())
        result = select.eval()
        self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
Example #12
Source File: control_flow_ops_py_test.py From deep_image_model with Apache License 2.0 | 6 votes |
def testWhileUpdateVariable_3(self):
    with self.test_session():
        select = tf.Variable([3.0, 4.0, 5.0])
        n = tf.constant(0)

        def loop_iterator(j, _):
            return tf.less(j, 3)

        def loop_body(j, _):
            ns = tf.scatter_update(select, j, 10.0)
            nj = tf.add(j, 1)
            return [nj, ns]

        r = tf.while_loop(loop_iterator, loop_body,
                          [n, tf.identity(select)],
                          parallel_iterations=1)
        tf.global_variables_initializer().run()
        result = r[1].eval()
        self.assertTrue(check_op_order(n.graph))
        self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)

# b/24814703
Example #13
Source File: topn.py From deep_image_model with Apache License 2.0 | 6 votes |
def insert(self, ids, scores):
    """Insert the ids and scores into the TopN."""
    with tf.control_dependencies(self.last_ops):
        scatter_op = tf.scatter_update(self.id_to_score, ids, scores)
        larger_scores = tf.greater(scores, self.sl_scores[0])

        def shortlist_insert():
            larger_ids = tf.boolean_mask(tf.to_int64(ids), larger_scores)
            larger_score_values = tf.boolean_mask(scores, larger_scores)
            shortlist_ids, new_ids, new_scores = self.ops.top_n_insert(
                self.sl_ids, self.sl_scores, larger_ids, larger_score_values)
            u1 = tf.scatter_update(self.sl_ids, shortlist_ids, new_ids)
            u2 = tf.scatter_update(self.sl_scores, shortlist_ids, new_scores)
            return tf.group(u1, u2)

        # We only need to insert into the shortlist if there are any
        # scores larger than the threshold.
        cond_op = tf.cond(
            tf.reduce_any(larger_scores), shortlist_insert, tf.no_op)
        with tf.control_dependencies([cond_op]):
            self.last_ops = [scatter_op, cond_op]
Example #14
Source File: topn.py From deep_image_model with Apache License 2.0 | 6 votes |
def remove(self, ids):
    """Remove the ids (and their associated scores) from the TopN."""
    with tf.control_dependencies(self.last_ops):
        scatter_op = tf.scatter_update(
            self.id_to_score, ids,
            tf.ones_like(ids, dtype=tf.float32) * tf.float32.min)
        # We assume that removed ids are almost always in the shortlist,
        # so it makes no sense to hide the Op behind a tf.cond
        shortlist_ids_to_remove, new_length = self.ops.top_n_remove(
            self.sl_ids, ids)
        # Note: tf.concat(axis, values) below is the pre-TF-1.0 argument
        # order; newer TensorFlow versions expect tf.concat(values, axis).
        u1 = tf.scatter_update(
            self.sl_ids, tf.concat(0, [[0], shortlist_ids_to_remove]),
            tf.concat(0, [new_length,
                          tf.ones_like(shortlist_ids_to_remove) * -1]))
        u2 = tf.scatter_update(
            self.sl_scores, shortlist_ids_to_remove,
            tf.float32.min * tf.ones_like(shortlist_ids_to_remove,
                                          dtype=tf.float32))
        self.last_ops = [scatter_op, u1, u2]
Example #15
Source File: in_graph_batch_env.py From planet with Apache License 2.0 | 6 votes |
def reset(self, indices=None):
    """Reset the batch of environments.

    Args:
      indices: The batch indices of the environments to reset; defaults to all.

    Returns:
      Batch tensor of the new observations.
    """
    if indices is None:
        indices = tf.range(len(self._batch_env))
    observ_dtype = self._parse_dtype(self._batch_env.observation_space)
    observ = tf.py_func(
        self._batch_env.reset, [indices], observ_dtype, name='reset')
    reward = tf.zeros_like(indices, tf.float32)
    done = tf.zeros_like(indices, tf.int32)
    with tf.control_dependencies([
            tf.scatter_update(self._observ, indices, observ),
            tf.scatter_update(self._reward, indices, reward),
            tf.scatter_update(self._done, indices, tf.to_int32(done))]):
        return tf.identity(observ)
Example #16
Source File: memory.py From soccer-matlab with BSD 2-Clause "Simplified" License | 6 votes |
def replace(self, episodes, length, rows=None):
    """Replace full episodes.

    Args:
      episodes: Tuple of transition quantities with batch and time dimensions.
      length: Batch of sequence lengths.
      rows: Episodes to replace, defaults to all.

    Returns:
      Operation.
    """
    rows = tf.range(self._capacity) if rows is None else rows
    assert rows.shape.ndims == 1
    assert_capacity = tf.assert_less(
        rows, self._capacity, message='capacity exceeded')
    with tf.control_dependencies([assert_capacity]):
        assert_max_length = tf.assert_less_equal(
            length, self._max_length, message='max length exceeded')
    replace_ops = []
    with tf.control_dependencies([assert_max_length]):
        for buffer_, elements in zip(self._buffers, episodes):
            replace_op = tf.scatter_update(buffer_, rows, elements)
            replace_ops.append(replace_op)
    with tf.control_dependencies(replace_ops):
        return tf.scatter_update(self._length, rows, length)
Example #17
Source File: yellowfin.py From MobileNet with Apache License 2.0 | 6 votes |
def get_mu_tensor(self):
    const_fact = self._dist_to_opt_avg**2 * self._h_min**2 / 2 / self._grad_var
    coef = tf.Variable([-1.0, 3.0, 0.0, 1.0], dtype=tf.float32,
                       name="cubic_solver_coef")
    coef = tf.scatter_update(coef, tf.constant(2), -(3 + const_fact))
    roots = tf.py_func(np.roots, [coef], Tout=tf.complex64, stateful=False)

    # filter out the correct root
    root_idx = tf.logical_and(
        tf.logical_and(tf.greater(tf.real(roots), tf.constant(0.0)),
                       tf.less(tf.real(roots), tf.constant(1.0))),
        tf.less(tf.abs(tf.imag(roots)), 1e-5))
    # in case there are two duplicated roots satisfying the above condition
    root = tf.reshape(
        tf.gather(tf.gather(roots, tf.where(root_idx)), tf.constant(0)),
        shape=[])
    tf.assert_equal(tf.size(root), tf.constant(1))

    dr = self._h_max / self._h_min
    mu = tf.maximum(tf.real(root)**2,
                    ((tf.sqrt(dr) - 1) / (tf.sqrt(dr) + 1))**2)
    return mu
Example #18
Source File: memory.py From batch-ppo with Apache License 2.0 | 6 votes |
def replace(self, episodes, length, rows=None):
    """Replace full episodes.

    Args:
      episodes: Tuple of transition quantities with batch and time dimensions.
      length: Batch of sequence lengths.
      rows: Episodes to replace, defaults to all.

    Returns:
      Operation.
    """
    rows = tf.range(self._capacity) if rows is None else rows
    assert rows.shape.ndims == 1
    assert_capacity = tf.assert_less(
        rows, self._capacity, message='capacity exceeded')
    with tf.control_dependencies([assert_capacity]):
        assert_max_length = tf.assert_less_equal(
            length, self._max_length, message='max length exceeded')
    with tf.control_dependencies([assert_max_length]):
        replace_ops = tools.nested.map(
            lambda var, val: tf.scatter_update(var, rows, val),
            self._buffers, episodes, flatten=True)
    with tf.control_dependencies(replace_ops):
        return tf.scatter_update(self._length, rows, length)
Example #19
Source File: in_graph_batch_env.py From batch-ppo with Apache License 2.0 | 6 votes |
def reset(self, indices=None):
    """Reset the batch of environments.

    Args:
      indices: The batch indices of the environments to reset; defaults to all.

    Returns:
      Batch tensor of the new observations.
    """
    if indices is None:
        indices = tf.range(len(self._batch_env))
    observ_dtype = self._parse_dtype(self._batch_env.observation_space)
    observ = tf.py_func(
        self._batch_env.reset, [indices], observ_dtype, name='reset')
    observ = tf.check_numerics(observ, 'observ')
    reward = tf.zeros_like(indices, tf.float32)
    done = tf.zeros_like(indices, tf.bool)
    with tf.control_dependencies([
            tf.scatter_update(self._observ, indices, observ),
            tf.scatter_update(self._reward, indices, reward),
            tf.scatter_update(self._done, indices, done)]):
        return tf.identity(observ)
Example #20
Source File: utility.py From batch-ppo with Apache License 2.0 | 6 votes |
def reinit_nested_vars(variables, indices=None):
    """Reset all variables in a nested tuple to zeros.

    Args:
      variables: Nested tuple or list of variables.
      indices: Batch indices to reset, defaults to all.

    Returns:
      Operation.
    """
    if isinstance(variables, (tuple, list)):
        return tf.group(*[
            reinit_nested_vars(variable, indices)
            for variable in variables])
    if indices is None:
        return variables.assign(tf.zeros_like(variables))
    else:
        zeros = tf.zeros([tf.shape(indices)[0]] + variables.shape[1:].as_list())
        return tf.scatter_update(variables, indices, zeros)
Example #21
Source File: latent_factor.py From openrec with Apache License 2.0 | 6 votes |
def censor_l2_norm_op(self, censor_id_list=None, max_norm=1):
    """Limit the norm of embeddings.

    Parameters
    ----------
    censor_id_list: list or Tensorflow tensor
        list of embeddings to censor (indexed by ids).
    max_norm: float, optional
        Maximum norm.

    Returns
    -------
    Tensorflow operator
        An operator for post-training execution.
    """
    embedding_gather = tf.gather(self._embedding, indices=censor_id_list)
    norm = tf.sqrt(tf.reduce_sum(tf.square(embedding_gather), axis=1,
                                 keep_dims=True))
    return tf.scatter_update(self._embedding, indices=censor_id_list,
                             updates=embedding_gather / tf.maximum(norm, max_norm))
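Example #21 is an instance of a common gather / transform / scatter_update round trip: read out the selected rows, rescale them, and write only those rows back. A minimal sketch of that pattern follows; the embedding shape, ids, and variable names are made up for illustration, TensorFlow 1.x is assumed, and the rescaling mirrors the expression used in Example #21.

# Illustrative gather -> transform -> scatter_update round trip (TF 1.x assumed).
import tensorflow as tf

embeddings = tf.Variable(tf.random_normal([1000, 64]))
ids = tf.constant([3, 42, 7])
max_norm = 1.0

rows = tf.gather(embeddings, ids)
norm = tf.sqrt(tf.reduce_sum(tf.square(rows), axis=1, keep_dims=True))
clipped = rows / tf.maximum(norm, max_norm)   # same rescaling as Example #21
censor_op = tf.scatter_update(embeddings, ids, clipped)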
Example #22
Source File: model.py From professional-services with Apache License 2.0 | 6 votes |
def _update_embedding_matrix(row_indices, rows, embedding_size, tft_output,
                             vocab_name):
    """Creates and maintains a lookup table of embeddings for inference.

    Args:
      row_indices: indices of rows of the lookup table to update.
      rows: the values to update the lookup table with.
      embedding_size: the size of the embedding.
      tft_output: a TFTransformOutput object.
      vocab_name: a tft vocabulary name.

    Returns:
      A num_items x embedding_size table of the latest embeddings with the
      given rows updated.
    """
    embedding = _get_embedding_matrix(embedding_size, tft_output, vocab_name)
    return tf.scatter_update(embedding, row_indices, rows)
Example #23
Source File: association.py From Deep-Association-Learning with MIT License | 6 votes |
def update_cross_anchor(cross_anchors, intra_anchors_n, intra_anchors_batch_n,
                        labels_cam, start_sign):
    # update cross-anchor
    for i in range(FLAGS.num_cams):
        # other_anchors: all the anchors under other cameras
        other_anchors_n = []
        [other_anchors_n.append(intra_anchors_n[x])
         for x in range(FLAGS.num_cams) if x is not i]
        other_anchors_n = tf.concat(other_anchors_n, 0)

        consistent, rank1_anchors = \
            cyclic_ranking(intra_anchors_batch_n[i], intra_anchors_n[i],
                           other_anchors_n, labels_cam[i], start_sign)

        # if the consistency fulfills, update by
        # merging with the best-matched rank1 anchors in another camera
        update = tf.where(consistent,
                          (intra_anchors_batch_n[i] + rank1_anchors) / 2,
                          intra_anchors_batch_n[i])

        # update the associate centers under each camera
        cross_anchors[i] = tf.scatter_update(cross_anchors[i], labels_cam[i], update)

    return cross_anchors
Example #24
Source File: memory.py From hands-detection with MIT License | 6 votes |
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                   batch_size, use_recent_idx, intended_output):
    """Function that creates all the update ops."""
    mem_age_incr = self.mem_age.assign_add(tf.ones([self.memory_size],
                                                   dtype=tf.float32))
    with tf.control_dependencies([mem_age_incr]):
        mem_age_upd = tf.scatter_update(
            self.mem_age, upd_idxs, tf.zeros([batch_size], dtype=tf.float32))

    mem_key_upd = tf.scatter_update(
        self.mem_keys, upd_idxs, upd_keys)
    mem_val_upd = tf.scatter_update(
        self.mem_vals, upd_idxs, upd_vals)

    if use_recent_idx:
        recent_idx_upd = tf.scatter_update(
            self.recent_idx, intended_output, upd_idxs)
    else:
        recent_idx_upd = tf.group()

    return tf.group(mem_age_upd, mem_key_upd, mem_val_upd, recent_idx_upd)
Example #25
Source File: tensorlfowapi.py From SSD_tensorflow_VOC with Apache License 2.0 | 6 votes |
def test_scatter_nd_2():
    gt_bboxes = tf.constant([[0, 0, 1, 2], [1, 0, 3, 4], [100, 100, 105, 102.5]])
    gt_labels = tf.constant([1, 2, 6])
    gt_anchors_labels = tf.Variable(
        [100, 100, 100, 100], trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES])
    gt_anchors_bboxes = tf.Variable(
        [[100, 100, 105, 105], [2, 1, 3, 3.5], [0, 0, 10, 10], [0.5, 0.5, 0.8, 1.5]],
        trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES],
        dtype=tf.float32)

    max_inds = [1, 0, 3]
    gt_anchors_labels = tf.scatter_update(gt_anchors_labels, max_inds, gt_labels)
    gt_anchors_bboxes = tf.scatter_update(gt_anchors_bboxes, max_inds, gt_bboxes)

    return gt_anchors_labels, gt_anchors_bboxes
Example #26
Source File: memory.py From object_detection_kitti with Apache License 2.0 | 6 votes |
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                   batch_size, use_recent_idx, intended_output):
    """Function that creates all the update ops."""
    mem_age_incr = self.mem_age.assign_add(tf.ones([self.memory_size],
                                                   dtype=tf.float32))
    with tf.control_dependencies([mem_age_incr]):
        mem_age_upd = tf.scatter_update(
            self.mem_age, upd_idxs, tf.zeros([batch_size], dtype=tf.float32))

    mem_key_upd = tf.scatter_update(
        self.mem_keys, upd_idxs, upd_keys)
    mem_val_upd = tf.scatter_update(
        self.mem_vals, upd_idxs, upd_vals)

    if use_recent_idx:
        recent_idx_upd = tf.scatter_update(
            self.recent_idx, intended_output, upd_idxs)
    else:
        recent_idx_upd = tf.group()

    return tf.group(mem_age_upd, mem_key_upd, mem_val_upd, recent_idx_upd)
Example #27
Source File: memory.py From object_detection_with_tensorflow with MIT License | 6 votes |
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                   batch_size, use_recent_idx, intended_output):
    """Function that creates all the update ops."""
    mem_age_incr = self.mem_age.assign_add(tf.ones([self.memory_size],
                                                   dtype=tf.float32))
    with tf.control_dependencies([mem_age_incr]):
        mem_age_upd = tf.scatter_update(
            self.mem_age, upd_idxs, tf.zeros([batch_size], dtype=tf.float32))

    mem_key_upd = tf.scatter_update(
        self.mem_keys, upd_idxs, upd_keys)
    mem_val_upd = tf.scatter_update(
        self.mem_vals, upd_idxs, upd_vals)

    if use_recent_idx:
        recent_idx_upd = tf.scatter_update(
            self.recent_idx, intended_output, upd_idxs)
    else:
        recent_idx_upd = tf.group()

    return tf.group(mem_age_upd, mem_key_upd, mem_val_upd, recent_idx_upd)
Example #28
Source File: memory.py From HumanRecognition with MIT License | 6 votes |
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                   batch_size, use_recent_idx, intended_output):
    """Function that creates all the update ops."""
    mem_age_incr = self.mem_age.assign_add(tf.ones([self.memory_size],
                                                   dtype=tf.float32))
    with tf.control_dependencies([mem_age_incr]):
        mem_age_upd = tf.scatter_update(
            self.mem_age, upd_idxs, tf.zeros([batch_size], dtype=tf.float32))

    mem_key_upd = tf.scatter_update(
        self.mem_keys, upd_idxs, upd_keys)
    mem_val_upd = tf.scatter_update(
        self.mem_vals, upd_idxs, upd_vals)

    if use_recent_idx:
        recent_idx_upd = tf.scatter_update(
            self.recent_idx, intended_output, upd_idxs)
    else:
        recent_idx_upd = tf.group()

    return tf.group(mem_age_upd, mem_key_upd, mem_val_upd, recent_idx_upd)
Example #29
Source File: memory.py From g-tensorflow-models with Apache License 2.0 | 6 votes |
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                   batch_size, use_recent_idx, intended_output):
    """Function that creates all the update ops."""
    mem_age_incr = self.mem_age.assign_add(tf.ones([self.memory_size],
                                                   dtype=tf.float32))
    with tf.control_dependencies([mem_age_incr]):
        mem_age_upd = tf.scatter_update(
            self.mem_age, upd_idxs, tf.zeros([batch_size], dtype=tf.float32))

    mem_key_upd = tf.scatter_update(
        self.mem_keys, upd_idxs, upd_keys)
    mem_val_upd = tf.scatter_update(
        self.mem_vals, upd_idxs, upd_vals)

    if use_recent_idx:
        recent_idx_upd = tf.scatter_update(
            self.recent_idx, intended_output, upd_idxs)
    else:
        recent_idx_upd = tf.group()

    return tf.group(mem_age_upd, mem_key_upd, mem_val_upd, recent_idx_upd)
Example #30
Source File: memory.py From models with Apache License 2.0 | 6 votes |
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                   batch_size, use_recent_idx, intended_output):
    """Function that creates all the update ops."""
    mem_age_incr = self.mem_age.assign_add(tf.ones([self.memory_size],
                                                   dtype=tf.float32))
    with tf.control_dependencies([mem_age_incr]):
        mem_age_upd = tf.scatter_update(
            self.mem_age, upd_idxs, tf.zeros([batch_size], dtype=tf.float32))

    mem_key_upd = tf.scatter_update(
        self.mem_keys, upd_idxs, upd_keys)
    mem_val_upd = tf.scatter_update(
        self.mem_vals, upd_idxs, upd_vals)

    if use_recent_idx:
        recent_idx_upd = tf.scatter_update(
            self.recent_idx, intended_output, upd_idxs)
    else:
        recent_idx_upd = tf.group()

    return tf.group(mem_age_upd, mem_key_upd, mem_val_upd, recent_idx_upd)