Python tensorflow.scatter_add() Examples

The following are 23 code examples of tensorflow.scatter_add(), drawn from open-source projects. The source file, project, and license are noted above each example. You may also want to check out all available functions and classes of the tensorflow module.
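As a quick orientation before the examples: tf.scatter_add(ref, indices, updates) adds rows of updates into the rows of the variable ref selected by indices, in place, and repeated indices accumulate rather than overwrite. A minimal sketch, assuming TensorFlow 1.x (in TensorFlow 2.x the op lives under tf.compat.v1):

import tensorflow as tf

ref = tf.Variable([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
indices = tf.constant([0, 2, 0])  # index 0 appears twice
updates = tf.constant([[10.0, 10.0],
                       [20.0, 20.0],
                       [30.0, 30.0]])

# Repeated indices accumulate: row 0 receives both of its update rows.
add_op = tf.scatter_add(ref, indices, updates)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(add_op))  # [[41. 41.], [2. 2.], [23. 23.]]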
Example #1
Source File: optimizer.py    From Parser-v3 with Apache License 2.0
def sparse_moving_average(self, variable, unique_indices, accumulant, name='Accumulator', decay=.9):
    """"""
    
    accumulant = tf.clip_by_value(accumulant, -self.clip, self.clip)
    first_dim = variable.get_shape().as_list()[0]
    accumulator = self.get_accumulator(name, variable)
    indexed_accumulator = tf.gather(accumulator, unique_indices)
    iteration = self.get_accumulator('{}/iteration'.format(name), variable, shape=[first_dim, 1])
    indexed_iteration = tf.gather(iteration, unique_indices)
    iteration = tf.scatter_add(iteration, unique_indices, tf.ones_like(indexed_iteration))
    indexed_iteration = tf.gather(iteration, unique_indices)
    
    if decay < 1:
      current_indexed_decay = decay * (1-decay**(indexed_iteration-1)) / (1-decay**indexed_iteration)
    else:
      current_indexed_decay = (indexed_iteration-1) / indexed_iteration
    
    accumulator = tf.scatter_update(accumulator, unique_indices, current_indexed_decay*indexed_accumulator)
    accumulator = tf.scatter_add(accumulator, unique_indices, (1-current_indexed_decay)*accumulant)
    return accumulator, iteration
  
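A note on Example #1's decay schedule: with d_t = decay * (1 - decay**(t-1)) / (1 - decay**t), storing a_t = d_t * a_{t-1} + (1 - d_t) * x_t keeps the accumulator equal to a bias-corrected exponential moving average, i.e. a plain EMA divided by 1 - decay**t (the Adam-style correction). A quick NumPy check of that equivalence:

import numpy as np

decay, xs = 0.9, np.random.rand(20)

a = m = 0.0
for t, x in enumerate(xs, start=1):
    d = decay * (1 - decay ** (t - 1)) / (1 - decay ** t)
    a = d * a + (1 - d) * x                     # what the accumulator stores
    m = decay * m + (1 - decay) * x             # plain EMA
    assert np.isclose(a, m / (1 - decay ** t))  # bias-corrected EMA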
Example #2
Source File: optimize.py    From NJUNMT-tf with Apache License 2.0
def _collect_gradients(gradients, variables):
    """ Collects gradients.

    Args:
        gradients: A list of gradients.
        variables: A list of variables for collecting the gradients.

    Returns: A tf op.
    """
    ops = []
    for grad, var in zip(gradients, variables):
        if isinstance(grad, tf.Tensor):
            ops.append(tf.assign_add(var, grad))
        else:
            ops.append(tf.scatter_add(var, grad.indices, grad.values))
    return tf.group(*ops, name="collect_gradients") 
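The else branch above handles sparse gradients: differentiating through tf.gather (or an embedding lookup) yields a tf.IndexedSlices rather than a dense tensor, carrying values only for the rows that were actually touched. A minimal sketch of where such a gradient comes from and how scatter_add accumulates it, assuming TensorFlow 1.x:

import tensorflow as tf

emb = tf.Variable(tf.random_normal([100, 8]))  # embedding table
ids = tf.constant([3, 7, 3])
loss = tf.reduce_sum(tf.gather(emb, ids))

grad = tf.gradients(loss, emb)[0]
# grad is a tf.IndexedSlices: grad.indices == ids, grad.values has shape [3, 8].
acc = tf.Variable(tf.zeros([100, 8]), trainable=False)
collect = tf.scatter_add(acc, grad.indices, grad.values)  # touches only rows 3 and 7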
Example #3
Source File: scatter_ops_test.py    From deep_image_model with Apache License 2.0
def testRepeatIndicesAdd(self):
    self._VariableRankTests(tf.scatter_add, True) 
Example #4
Source File: memory.py    From multilabel-image-classification-tensorflow with MIT License
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                     batch_size, use_recent_idx, intended_output):
    """Function that creates all the update ops."""
    base_update_op = super(LSHMemory, self).make_update_op(
        upd_idxs, upd_keys, upd_vals,
        batch_size, use_recent_idx, intended_output)

    # compute hash slots to be updated
    hash_slot_idxs = self.get_hash_slots(upd_keys)

    # make updates
    update_ops = []
    with tf.control_dependencies([base_update_op]):
      for i, slot_idxs in enumerate(hash_slot_idxs):
        # for each slot, choose which entry to replace
        entry_idx = tf.random_uniform([batch_size],
                                      maxval=self.num_per_hash_slot,
                                      dtype=tf.int32)
        entry_mul = 1 - tf.one_hot(entry_idx, self.num_per_hash_slot,
                                   dtype=tf.int32)
        entry_add = (tf.expand_dims(upd_idxs, 1) *
                     tf.one_hot(entry_idx, self.num_per_hash_slot,
                                dtype=tf.int32))

        mul_op = tf.scatter_mul(self.hash_slots[i], slot_idxs, entry_mul)
        with tf.control_dependencies([mul_op]):
          add_op = tf.scatter_add(self.hash_slots[i], slot_idxs, entry_add)
          update_ops.append(add_op)

    return tf.group(*update_ops) 
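The scatter_mul / scatter_add pair in this example implements replace-one-entry: multiplying a row by 1 - one_hot(entry_idx) zeroes the chosen entry, then adding value * one_hot(entry_idx) writes the replacement. A standalone sketch of the same pattern with hypothetical shapes, assuming TensorFlow 1.x:

import tensorflow as tf

slots = tf.Variable([[5, 6, 7], [8, 9, 10]], dtype=tf.int32)
slot_idxs = tf.constant([0, 1])   # which row each update targets
entry_idx = tf.constant([2, 0])   # which entry within the row to replace
new_vals = tf.constant([99, 42])

mask = 1 - tf.one_hot(entry_idx, 3, dtype=tf.int32)
write = tf.expand_dims(new_vals, 1) * tf.one_hot(entry_idx, 3, dtype=tf.int32)

mul_op = tf.scatter_mul(slots, slot_idxs, mask)       # zero the chosen entries
with tf.control_dependencies([mul_op]):
    add_op = tf.scatter_add(slots, slot_idxs, write)  # write the new values

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(add_op))  # [[5, 6, 99], [42, 9, 10]]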
Example #5
Source File: utils.py    From transformer-aan with BSD 3-Clause "New" or "Revised" License
def collect_gradients(gradients, variables):
    ops = []

    for grad, var in zip(gradients, variables):
        if isinstance(grad, tf.Tensor):
            ops.append(tf.assign_add(var, grad))
        else:
            ops.append(tf.scatter_add(var, grad.indices, grad.values))

    return tf.group(*ops) 
Example #6
Source File: memory.py    From models with Apache License 2.0
(The make_update_op code here is identical to Example #4.)
Example #7
Source File: memory.py    From g-tensorflow-models with Apache License 2.0
(The make_update_op code here is identical to Example #4.)
Example #8
Source File: memory.py    From HumanRecognition with MIT License
(The make_update_op code here is identical to Example #4.)
Example #9
Source File: memory.py    From object_detection_with_tensorflow with MIT License
(The make_update_op code here is identical to Example #4.)
Example #10
Source File: cell_main.py    From amla with Apache License 2.0
def calc_entropy(self, inputs, scope):
        with tf.variable_scope(scope, reuse=True):
            maxtensor = tf.to_float(tf.size(inputs))

            bincount = tf.get_variable("bincount", [self.numbins])
            featuremapsum = tf.get_variable("featuremapsum", [1])
            featuremapcount = tf.get_variable("featuremapcount", [1])
            inputs = tf.Print(inputs, [inputs, tf.shape(
                inputs)], message="Framemap:", summarize=100)
            binnum = tf.to_int32(
                tf.floor((tf.reduce_sum(inputs) / maxtensor) * (self.numbins - 1)))
            tbincount = tf.scatter_add(
                bincount, binnum, tf.to_float(
                    tf.constant(1)))
            bincount = bincount.assign(tbincount)
            bincount = tf.Print(bincount,
                                [tf.count_nonzero(bincount)],
                                message="Non zero bins count:")

            tfeaturemapsum = tf.add(featuremapsum, tf.reduce_sum(inputs))
            featuremapsum = featuremapsum.assign(tfeaturemapsum)

            tfeaturemapcount = tf.add(featuremapcount, tf.to_float(tf.constant(1)))
            featuremapcount = featuremapcount.assign(tfeaturemapcount)

            meanactivation = tf.divide(featuremapsum, featuremapcount)
            pbin = tf.divide(tf.to_float(bincount), tf.to_float(featuremapcount))
            entropy = tf.multiply(pbin, tf.log(pbin))
            entropy = tf.where(
                tf.is_nan(entropy),
                tf.zeros_like(entropy),
                entropy)
            entropy = tf.reduce_sum(entropy)
            entropy = tf.Print(entropy, [entropy], message=": raw entropy: ")
            entropy = tf.multiply(entropy, tf.multiply(
                meanactivation, tf.constant(-1.0)))
            entropy = tf.Print(
                entropy, [
                    scope, entropy], message=": scaled entropy: ")
            return entropy 
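The tf.scatter_add(bincount, binnum, ...) call in this example is a single-bin histogram increment: binnum is a scalar index, so exactly one bin is bumped per call. A minimal sketch of the idiom, assuming TensorFlow 1.x:

import tensorflow as tf

numbins = 10
bincount = tf.Variable(tf.zeros([numbins]))
activation = tf.constant(0.73)  # stand-in for the normalized activation sum
binnum = tf.to_int32(tf.floor(activation * (numbins - 1)))

inc = tf.scatter_add(bincount, binnum, 1.0)  # bump a single bin

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(inc))  # bin 6 (0.73 * 9 floors to 6) now holds 1.0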
Example #11
Source File: memory.py    From object_detection_kitti with Apache License 2.0
(The make_update_op code here is identical to Example #4.)
Example #12
Source File: memory.py    From hands-detection with MIT License
(The make_update_op code here is identical to Example #4.)
Example #13
Source File: optimize.py    From Document-Transformer with BSD 3-Clause "New" or "Revised" License
def _collect_gradients(gradients, variables):
    ops = []

    for grad, var in zip(gradients, variables):
        if isinstance(grad, tf.Tensor):
            ops.append(tf.assign_add(var, grad))
        else:
            ops.append(tf.scatter_add(var, grad.indices, grad.values))

    return tf.group(*ops, name="collect_gradients") 
Example #14
Source File: memory.py    From DOTA_models with Apache License 2.0
(The make_update_op code here is identical to Example #4.)
Example #15
Source File: scatter_ops_test.py    From deep_image_model with Apache License 2.0
def testVariableRankAdd(self):
    self._VariableRankTests(tf.scatter_add) 
Example #16
Source File: scatter_ops_test.py    From deep_image_model with Apache License 2.0
def _NumpyAdd(ref, indices, updates):
  # Since numpy advanced assignment does not support repeated indices,
  # we run a simple loop to perform scatter_add.
  for i, indx in np.ndenumerate(indices):
    ref[indx] += updates[i] 
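The loop above is needed because NumPy's buffered fancy assignment (ref[indices] += updates) applies only the last update for a repeated index. np.add.at performs the same unbuffered accumulation as tf.scatter_add; a quick comparison:

import numpy as np

ref = np.zeros(4)
indices = np.array([1, 1, 3])
updates = np.array([1.0, 2.0, 5.0])

buffered = ref.copy()
buffered[indices] += updates      # index 1 keeps only the last update
np.add.at(ref, indices, updates)  # index 1 accumulates both updates

print(buffered)  # [0. 2. 0. 5.]
print(ref)       # [0. 3. 0. 5.]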
Example #17
Source File: embedding_ops_test.py    From deep_image_model with Apache License 2.0
def testWrongShape(self):
    # Indices and values mismatch.
    var = tf.Variable(tf.zeros(shape=[1024, 64, 64], dtype=tf.float32))
    indices = tf.placeholder(tf.int32, shape=[32])
    values = tf.placeholder(tf.float32, shape=[33, 64, 64])
    with self.assertRaises(ValueError):
      tf.scatter_add(var, indices, values)

    # Var and values mismatch.
    values = tf.placeholder(tf.float32, shape=[32, 64, 63])
    with self.assertRaises(ValueError):
      tf.scatter_add(var, indices, values) 
Example #18
Source File: embedding_ops_test.py    From deep_image_model with Apache License 2.0
def _TestCase(self, shape, indices, scatter_op=tf.scatter_add):
    """Run a random test case with the given shape and indices.

    Args:
      shape: Shape of the parameters array.
      indices: One-dimensional array of ints, the indices of the first dimension
               of the parameters to update.
      scatter_op: ScatterAdd or ScatterSub.
    """
    super(ScatterAddSubTest, self).setUp()
    with self.test_session(use_gpu=False):
      # Create a random parameter array of given shape
      p_init = np.random.rand(*shape).astype("f")
      # Create the shape of the update array. All dimensions except the last
      # match the parameter array, the last dimension equals the # of indices.
      vals_shape = [len(indices)] + shape[1:]
      vals_init = np.random.rand(*vals_shape).astype("f")
      v_i = [float(x) for x in vals_init.ravel()]
      p = tf.Variable(p_init)
      vals = tf.constant(v_i, shape=vals_shape, name="vals")
      ind = tf.constant(indices, dtype=tf.int32)
      p2 = scatter_op(p, ind, vals, name="updated_p")
      # p = init
      tf.global_variables_initializer().run()
      # p += vals
      result = p2.eval()
    # Compute the expected 'p' using numpy operations.
    for i, ind in enumerate(indices):
      if scatter_op == tf.scatter_add:
        p_init.reshape(shape[0], -1)[ind, :] += (
            vals_init.reshape(vals_shape[0], -1)[i, :])
      else:
        p_init.reshape(shape[0], -1)[ind, :] -= (
            vals_init.reshape(vals_shape[0], -1)[i, :])
    self.assertTrue(all((p_init == result).ravel())) 
Example #19
Source File: memory.py    From Gun-Detector with Apache License 2.0
(The make_update_op code here is identical to Example #4.)
Example #20
Source File: memory.py    From yolo_v2 with Apache License 2.0
(The make_update_op code here is identical to Example #4.)
Example #21
Source File: cycle.py    From zero with BSD 3-Clause "New" or "Revised" License
(The _collect_gradients code here is identical to Example #13.)
Example #22
Source File: loss_utils.py    From BERT with Apache License 2.0
def center_loss_v1(config, embedding, labels, **kargs):
	'''
	embedding dim : (batch_size, num_features)
	'''
	num_features = embedding.get_shape()[-1]
	with tf.variable_scope(config.scope+"_center_loss"):
		centroids = tf.get_variable('center',
						shape=[config.num_classes, num_features],
						dtype=tf.float32,
						initializer=tf.contrib.layers.xavier_initializer(),
						trainable=False)

		centroids_delta = tf.get_variable('centroidsUpdateTempVariable',
						shape=[config.num_classes, num_features],
						dtype=tf.float32,
						initializer=tf.zeros_initializer(),
						trainable=False)

		centroids_batch = tf.gather(centroids, labels)
		# cLoss = tf.nn.l2_loss(embedding - centroids_batch) / (batch_size) # Eq. 2
		
		# cLoss = tf.reduce_mean(tf.reduce_sum((embedding - centroids_batch)**2, axis=-1))
		cLoss = tf.reduce_sum((embedding - centroids_batch)**2, axis=-1)

		diff = centroids_batch - embedding

		delta_c_nominator = tf.scatter_add(centroids_delta, labels, diff)
		indices = tf.expand_dims(labels, -1)
		updates = tf.cast(tf.ones_like(labels), tf.float32)
		shape = tf.constant([config.num_classes])  # one count per class, not per feature

		labels_sum = tf.expand_dims(tf.scatter_nd(indices, updates, shape),-1)
		centroids = centroids.assign_sub(config.alpha * delta_c_nominator / (1.0 + labels_sum))

		centroids_delta = centroids_delta.assign(tf.zeros([config.num_classes, num_features]))

		return cLoss, centroids 
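For reference, the update implemented above is the standard center-loss centroid rule: each class centroid moves toward the mean of the embeddings assigned to it, c_j <- c_j - alpha * sum_{y_i = j}(c_j - x_i) / (1 + n_j), where n_j counts the batch examples with label j. A NumPy sketch of one update step, with hypothetical shapes:

import numpy as np

alpha, num_classes = 0.5, 3
centroids = np.zeros((num_classes, 2))
embedding = np.array([[1.0, 1.0], [3.0, 3.0], [10.0, 0.0]])
labels = np.array([0, 0, 2])

diff = centroids[labels] - embedding  # (batch, num_features)
delta = np.zeros_like(centroids)
np.add.at(delta, labels, diff)        # per-class sum, like scatter_add
counts = np.bincount(labels, minlength=num_classes)[:, None]

centroids -= alpha * delta / (1.0 + counts)
print(centroids)  # class 0 has moved toward the mean of its two embeddings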
Example #23
Source File: optimizers.py    From THUMT with BSD 3-Clause "New" or "Revised" License
def compute_gradients(self, loss, var_list=None,
                          gate_gradients=tf.train.Optimizer.GATE_OP,
                          aggregation_method=None,
                          colocate_gradients_with_ops=False,
                          grad_loss=None):
        grads_and_vars = self._optimizer.compute_gradients(loss, var_list,
            gate_gradients, aggregation_method, colocate_gradients_with_ops,
            grad_loss)

        grads, var_list = list(zip(*grads_and_vars))

        # Do not create extra variables when step is 1
        if self._step == 1:
            grads = [self._all_reduce(t) for t in grads]
            return list(zip(grads, var_list))

        first_var = min(var_list, key=lambda x: x.name)
        iter_var = self._create_non_slot_variable(
            initial_value=0 if self._step == 1 else 1, name="iter",
            colocate_with=first_var)

        new_grads = []

        for grad, var in zip(grads, var_list):
            grad_acc = self._zeros_slot(var, "grad_acc", self._name)

            if isinstance(grad, tf.IndexedSlices):
                grad_acc = tf.scatter_add(grad_acc, grad.indices, grad.values,
                                          use_locking=self._use_locking)
            else:
                grad_acc = tf.assign_add(grad_acc, grad,
                                         use_locking=self._use_locking)

            def _acc_grad():
                return grad_acc

            def _avg_grad():
                return self._all_reduce(grad_acc / self._step)

            grad = tf.cond(tf.equal(iter_var, 0), _avg_grad, _acc_grad)
            new_grads.append(grad)

        return list(zip(new_grads, var_list))
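Taken together, Example #23 accumulates raw gradients for self._step iterations (tf.scatter_add for IndexedSlices, tf.assign_add for dense tensors) and only averages and all-reduces on the last iteration of each cycle, simulating a batch self._step times larger. A plain-Python sketch of that control flow, with all_reduce as a hypothetical identity stand-in:

def run_cycle(grads, step, all_reduce=lambda g: g):
    """Accumulate per-iteration gradients; emit one averaged gradient per cycle."""
    grad_acc = 0.0
    for i, g in enumerate(grads):
        grad_acc += g                           # scatter_add / assign_add above
        if i == step - 1:                       # iter_var has wrapped to 0
            return all_reduce(grad_acc / step)  # the _avg_grad branch
    return grad_acc                             # the _acc_grad branch

print(run_cycle([1.0, 2.0, 3.0], step=3))  # 2.0, the mean of the three gradients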