Python tensorflow.python.ops.math_ops.greater_equal() Examples

The following are 30 code examples of tensorflow.python.ops.math_ops.greater_equal(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.python.ops.math_ops, or try the search function.
Example #1
Source File: backend.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License 5 votes vote down vote up
def greater_equal(x, y):
  """Element-wise truth value of (x >= y).

  Thin wrapper: both operands are forwarded, unchanged and in the same order,
  to `math_ops.greater_equal`.

  Arguments:
      x: Tensor or variable (left-hand operand).
      y: Tensor or variable (right-hand operand).

  Returns:
      A bool tensor.
  """
  return math_ops.greater_equal(x, y)
Example #2
Source File: math_grad.py    From deep_image_model with Apache License 2.0 5 votes vote down vote up
def _MaximumGrad(op, grad):
  """Returns grad*(x >= y, x < y) with type of grad.

  The selector handed to `_MaximumMinimumGrad` is `math_ops.greater_equal`,
  i.e. `x >= y` (ties included), not the strict `x > y`.

  NOTE(review): the `x < y` half assumes `_MaximumMinimumGrad` routes the
  complement of the selector to `y`; that helper is defined elsewhere, so
  confirm against its implementation.
  """
  return _MaximumMinimumGrad(op, grad, math_ops.greater_equal)
Example #3
Source File: check_ops.py    From lambda-packs with MIT License 5 votes vote down vote up
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
                         name=None):
  """Build an op asserting that `x >= y` holds for every element.

  Typical use is as a control dependency:

  ```python
  with tf.control_dependencies([tf.assert_greater_equal(x, y)]):
    output = tf.reduce_sum(x)
  ```

  The check passes when `x[i] >= y[i]` for every pair of (possibly
  broadcast) elements `x[i]`, `y[i]`. Two empty inputs pass trivially.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data:  Tensors to print when the condition is False.  When omitted, a
      default list is built from `message`, a description of the condition,
      and `x` and `y` labelled with their names.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.  Only used when
      `data` is not supplied.
    name: A name for this operation (optional).  Defaults to
      "assert_greater_equal"

  Returns:
    Op that raises `InvalidArgumentError` if `x >= y` is False.
  """
  if not message:
    message = ''
  with ops.name_scope(name, 'assert_greater_equal', [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')
    if data is None:
      # The two adjacent literals are joined into a single string before `%`
      # is applied, so the header and the `x (...)` label form one entry.
      header_and_x_label = ('Condition x >= y did not hold element-wise:'
                            'x (%s) = ' % x.name)
      y_label = 'y (%s) = ' % y.name
      data = [message, header_and_x_label, x, y_label, y]
    elementwise_ge = math_ops.greater_equal(x, y)
    all_ge = math_ops.reduce_all(elementwise_ge)
    return control_flow_ops.Assert(all_ge, data, summarize=summarize)
Example #4
Source File: embedding_ops.py    From deep_image_model with Apache License 2.0 5 votes vote down vote up
def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Drop entries with a negative id or, when weights are given, weight <= 0.

  Args:
    sparse_ids: Sparse ids; only entries whose value is `>= 0` are kept.
    sparse_weights: Optional sparse weights, or `None`. When present, an entry
      is kept only if its weight is also `> 0`, and the weights are filtered
      with the same mask as the ids.

  Returns:
    A `(sparse_ids, sparse_weights)` pair after filtering; `sparse_weights`
    is returned unchanged when it is `None`.
  """
  has_weights = sparse_weights is not None
  keep_mask = math_ops.greater_equal(sparse_ids.values, 0)
  if has_weights:
    positive_weight = math_ops.greater(sparse_weights.values, 0)
    keep_mask = math_ops.logical_and(keep_mask, positive_weight)
  pruned_ids = sparse_ops.sparse_retain(sparse_ids, keep_mask)
  if not has_weights:
    return pruned_ids, sparse_weights
  pruned_weights = sparse_ops.sparse_retain(sparse_weights, keep_mask)
  return pruned_ids, pruned_weights
Example #5
Source File: target_column.py    From deep_image_model with Apache License 2.0 5 votes vote down vote up
def _accuracy_at_threshold(threshold):
  """Create an accuracy metric function that thresholds predictions first.

  Args:
    threshold: Cut-off compared against predictions with `>=`.

  Returns:
    A function `(predictions, labels, weights=None)` that converts
    `predictions >= threshold` to float and passes the result, together with
    `labels` and `weights`, to `metrics_lib.streaming_accuracy`.
  """

  def _accuracy_metric(predictions, labels, weights=None):
    is_positive = math_ops.greater_equal(predictions, threshold)
    binarized = math_ops.to_float(is_positive)
    return metrics_lib.streaming_accuracy(
        predictions=binarized, labels=labels, weights=weights)

  return _accuracy_metric
Example #6
Source File: metrics.py    From ULTRA with Apache License 2.0 5 votes vote down vote up
def mean_reciprocal_rank(labels, predictions, weights=None, name=None):
    """Computes mean reciprocal rank (MRR).

    For every list, items are ordered by prediction score, labels `>= 1.0`
    are treated as relevant, and the largest `1 / rank` among the relevant
    items is taken. The per-list values are multiplied by `weights` and
    averaged.

    Args:
      labels: A `Tensor` of the same shape as `predictions`. A value >= 1 means a
        relevant example.
      predictions: A `Tensor` with shape [batch_size, list_size]. Each value is
        the ranking score of the corresponding example.
      weights: A `Tensor` of the same shape of predictions or [batch_size, 1]. The
        former case is per-example and the latter case is per-list.
      name: A string used as the name for this metric.

    Returns:
      A metric for the weighted mean reciprocal rank of the batch.
    """
    scope_values = (labels, predictions, weights)
    with ops.name_scope(name, 'mean_reciprocal_rank', scope_values):
        prediction_shape = array_ops.shape(predictions)
        _, list_size = array_ops.unstack(prediction_shape)
        # The list size is passed in the cutoff position of the helper.
        labels, predictions, weights, topn = _prepare_and_validate_params(
            labels, predictions, weights, list_size)
        [sorted_labels] = utils.sort_by_scores(
            predictions, [labels], topn=topn)
        # Graded labels collapse to binary relevance: 1.0 when label >= 1.0.
        is_relevant = math_ops.greater_equal(sorted_labels, 1.0)
        relevance = math_ops.to_float(is_relevant)
        ranks = math_ops.to_float(math_ops.range(1, topn + 1))
        reciprocal_rank = 1.0 / ranks
        # Reduced over axis 1 with keepdims, i.e. one value per list.
        mrr = math_ops.reduce_max(
            relevance * reciprocal_rank, axis=1, keepdims=True)
        weighted_mrr = mrr * array_ops.ones_like(weights) * weights
        return math_ops.reduce_mean(weighted_mrr)
Example #7
Source File: metrics.py    From ULTRA with Apache License 2.0 5 votes vote down vote up
def expected_reciprocal_rank(
        labels, predictions, weights=None, topn=None, name=None):
    """Computes expected reciprocal rank (ERR).

    Labels are sorted by prediction score and mapped to a graded relevance
    `(2^label - 1) / 2^MAX_LABEL`. Each position contributes
    `relevance * non_rel * (1 / rank) * weight`, where `non_rel` is the
    cumulative product of `(1 - relevance)` along the list divided by the
    position's own `(1 - relevance)`. Contributions are summed per list and
    averaged over the batch.

    Args:
      labels: A `Tensor` of the same shape as `predictions`. A value >= 1 means a
        relevant example.
      predictions: A `Tensor` with shape [batch_size, list_size]. Each value is
        the ranking score of the corresponding example.
      weights: A `Tensor` of the same shape of predictions or [batch_size, 1]. The
        former case is per-example and the latter case is per-list.
      topn: A cutoff for how many examples to consider for this metric.
      name: A string used as the name for this metric.

    Returns:
      A metric for the weighted expected reciprocal rank of the batch.
    """
    scope_values = (labels, predictions, weights)
    with ops.name_scope(name, 'expected_reciprocal_rank', scope_values):
        labels, predictions, weights, topn = _prepare_and_validate_params(
            labels, predictions, weights, topn)
        sorted_labels, sorted_weights = utils.sort_by_scores(
            predictions, [labels, weights], topn=topn)
        sorted_shape = array_ops.shape(sorted_labels)
        _, list_size = array_ops.unstack(sorted_shape)

        # Graded relevance: (2^label - 1) / 2^MAX_LABEL.
        gain = math_ops.pow(2.0, sorted_labels) - 1
        max_gain = math_ops.pow(2.0, RankingMetricKey.MAX_LABEL)
        relevance = gain / max_gain
        # Cumulative product of (1 - relevance) along the list, divided by the
        # current position's own factor.
        running_non_rel = tf.math.cumprod(1.0 - relevance, axis=1)
        non_rel = running_non_rel / (1.0 - relevance)
        ranks = math_ops.to_float(math_ops.range(1, list_size + 1))
        reciprocal_rank = 1.0 / ranks
        # Zero out positions whose 1/rank is below 1/(topn + 1).
        cutoff = 1.0 / (topn + 1)
        within_cutoff = math_ops.greater_equal(reciprocal_rank, cutoff)
        mask = math_ops.to_float(within_cutoff)
        reciprocal_rank = reciprocal_rank * mask
        # Summed over axis 1 with keepdims, i.e. one value per list.
        per_position = relevance * non_rel * reciprocal_rank * sorted_weights
        err = math_ops.reduce_sum(per_position, axis=1, keepdims=True)
        return math_ops.reduce_mean(err)
Example #8
Source File: metrics.py    From ULTRA with Apache License 2.0 5 votes vote down vote up
def precision(labels, predictions, weights=None, topn=None, name=None):
    """Computes precision as weighted average of relevant examples.

    Per list, precision is the weighted count of relevant items (label
    `>= 1.0`) among the sorted items divided by the total weight of those
    items, using `_safe_div`. Lists are then combined with per-list weights
    derived from the unsorted labels and weights.

    Args:
      labels: A `Tensor` of the same shape as `predictions`. A value >= 1 means a
        relevant example.
      predictions: A `Tensor` with shape [batch_size, list_size]. Each value is
        the ranking score of the corresponding example.
      weights: A `Tensor` of the same shape of predictions or [batch_size, 1]. The
        former case is per-example and the latter case is per-list.
      topn: A cutoff for how many examples to consider for this metric.
      name: A string used as the name for this metric.

    Returns:
      A metric for the weighted precision of the batch.
    """
    scope_values = (labels, predictions, weights)
    with ops.name_scope(name, 'precision', scope_values):
        labels, predictions, weights, topn = _prepare_and_validate_params(
            labels, predictions, weights, topn)
        sorted_labels, sorted_weights = utils.sort_by_scores(
            predictions, [labels, weights], topn=topn)
        # Binary relevance: 1.0 when label >= 1.0, else 0.0.
        is_relevant = math_ops.greater_equal(sorted_labels, 1.0)
        relevance = math_ops.to_float(is_relevant)
        weighted_hits = math_ops.reduce_sum(
            relevance * sorted_weights, 1, keepdims=True)
        weight_total = math_ops.reduce_sum(
            array_ops.ones_like(relevance) * sorted_weights, 1, keepdims=True)
        per_list_precision = _safe_div(weighted_hits, weight_total)
        # The per-list weights use the whole (unsorted) list so that a list
        # with no relevant example inside topn does not end up with weight 0.
        full_list_relevance = math_ops.to_float(
            math_ops.greater_equal(labels, 1.0))
        per_list_weights = _per_example_weights_to_per_list_weights(
            weights, full_list_relevance)
        return math_ops.reduce_mean(per_list_precision * per_list_weights)
Example #9
Source File: sparse_optimizers.py    From rigl with Apache License 2.0 5 votes vote down vote up
def is_mask_update_iter(self, global_step, last_update_step):
    """Build the op that tells whether the current step is a mask update step.

    A step qualifies when it is at or after `_begin_step`, is at or before
    `_end_step` (or `_end_step` is negative, meaning updates never stop), and
    at least `_frequency` steps have passed since `last_update_step`.

    Side effects: `_begin_step`, `_end_step` and `_frequency` are re-assigned
    on `self` after being cast to the dtype of `global_step`, and
    `self.drop_fraction` is set from `self.get_drop_fraction`.

    Args:
      global_step: tf.Variable(int), current training step.
      last_update_step: tf.Variable(int), holding the last iteration the mask
        is updated. Used to determine whether current iteration is a mask
        update step.

    Returns:
      The boolean result of `math_ops.logical_and` indicating whether the
      current iteration is a mask update step.
    """
    step_dtype = global_step.dtype
    self._begin_step = math_ops.cast(self._begin_step, step_dtype)
    self._end_step = math_ops.cast(self._end_step, step_dtype)
    self._frequency = math_ops.cast(self._frequency, step_dtype)

    has_started = math_ops.greater_equal(global_step, self._begin_step)
    not_past_end = math_ops.less_equal(global_step, self._end_step)
    # A negative _end_step means the mask keeps being updated, at the given
    # frequency, until training ends.
    never_ends = math_ops.less(self._end_step, 0)
    before_stop = math_ops.logical_or(not_past_end, never_ends)
    in_update_range = math_ops.logical_and(has_started, before_stop)

    next_due_step = math_ops.add(last_update_step, self._frequency)
    is_due = math_ops.less_equal(next_due_step, global_step)

    update_now = math_ops.logical_and(in_update_range, is_due)
    self.drop_fraction = self.get_drop_fraction(global_step, update_now)
    return update_now
Example #10
Source File: test_forward.py    From incubator-tvm with Apache License 2.0 5 votes vote down vote up
def _test_greater_equal(data):
    """Run one iteration of the greater_equal test.

    Forwards `math_ops.greater_equal` and `data` to `_test_elemwise` and
    returns whatever that helper returns. The expected layout of `data` is
    determined by `_test_elemwise`, which is defined elsewhere.
    """
    return _test_elemwise(math_ops.greater_equal, data)
#######################################################################
# Less
# ---- 
Example #11
Source File: test_forward.py    From incubator-tvm with Apache License 2.0 5 votes vote down vote up
def test_forward_rel_ops():
    """Run `_test_forward_rel_op` for each relational op on two 3x3 arrays.

    The ops are exercised in the order: less, greater, less_equal,
    greater_equal, equal, not_equal.
    """
    lhs = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    rhs = np.array([[9, 8, 7], [6, 5, 4], [3, 2, 1]])
    relational_ops = (
        math_ops.less,
        math_ops.greater,
        math_ops.less_equal,
        math_ops.greater_equal,
        math_ops.equal,
        math_ops.not_equal,
    )
    for rel_op in relational_ops:
        _test_forward_rel_op([lhs, rhs], rel_op)

#######################################################################
# ExpandDims
# ---------- 
Example #12
Source File: head.py    From deep_image_model with Apache License 2.0 5 votes vote down vote up
def _accuracy_at_threshold(threshold):
  """Create an accuracy metric function that thresholds predictions first.

  Args:
    threshold: Cut-off compared against predictions with `>=`.

  Returns:
    A function `(predictions, labels, weights=None)` that converts
    `predictions >= threshold` to float and passes the result, together with
    `labels` and `weights`, to `metrics_lib.streaming_accuracy`.
  """

  def _accuracy_metric(predictions, labels, weights=None):
    is_positive = math_ops.greater_equal(predictions, threshold)
    binarized = math_ops.to_float(is_positive)
    return metrics_lib.streaming_accuracy(
        predictions=binarized, labels=labels, weights=weights)

  return _accuracy_metric
Example #13
Source File: head.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License 5 votes vote down vote up
def _accuracy_at_threshold(labels, predictions, weights, threshold, name=None):
  """Accuracy of `predictions >= threshold` against `labels`.

  Args:
    labels: Passed through to `metrics_lib.accuracy`.
    predictions: Values compared against `threshold` with `>=`; the boolean
      result is converted to float before computing accuracy.
    weights: Passed through to `metrics_lib.accuracy`.
    threshold: Cut-off applied to `predictions`; also used to build the
      default scope name `'accuracy_at_<threshold>'`.
    name: Optional name for the name scope.

  Returns:
    The result of `metrics_lib.accuracy`, named after the opened scope.
  """
  default_name = 'accuracy_at_%s' % threshold
  scope_values = (predictions, labels, weights, threshold)
  with ops.name_scope(name, default_name, scope_values) as scope:
    is_positive = math_ops.greater_equal(predictions, threshold)
    binarized = math_ops.to_float(is_positive)
    return metrics_lib.accuracy(
        labels=labels, predictions=binarized, weights=weights, name=scope)
Example #14
Source File: check_ops.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License 5 votes vote down vote up
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
                         name=None):
  """Build an op asserting that `x >= y` holds for every element.

  Typical use is as a control dependency:

  ```python
  with tf.control_dependencies([tf.assert_greater_equal(x, y)]):
    output = tf.reduce_sum(x)
  ```

  The check passes when `x[i] >= y[i]` for every pair of (possibly
  broadcast) elements `x[i]`, `y[i]`. Two empty inputs pass trivially.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data:  Tensors to print when the condition is False.  When omitted, a
      default list is built from `message`, a description of the condition,
      and `x` and `y` labelled with their names.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.  Only used when
      `data` is not supplied.
    name: A name for this operation (optional).  Defaults to
      "assert_greater_equal"

  Returns:
    Op that raises `InvalidArgumentError` if `x >= y` is False.
  """
  if not message:
    message = ''
  with ops.name_scope(name, 'assert_greater_equal', [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')
    if data is None:
      # The two adjacent literals are joined into a single string before `%`
      # is applied, so the header and the `x (...)` label form one entry.
      header_and_x_label = ('Condition x >= y did not hold element-wise:'
                            'x (%s) = ' % x.name)
      y_label = 'y (%s) = ' % y.name
      data = [message, header_and_x_label, x, y_label, y]
    elementwise_ge = math_ops.greater_equal(x, y)
    all_ge = math_ops.reduce_all(elementwise_ge)
    return control_flow_ops.Assert(all_ge, data, summarize=summarize)
Example #15
Source File: feature_column.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License 5 votes vote down vote up
def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Drop entries with a negative id or, when weights are given, weight <= 0.

  Args:
    sparse_ids: Sparse ids; only entries whose value is `>= 0` are kept.
    sparse_weights: Optional sparse weights, or `None`. When present, an entry
      is kept only if its weight is also `> 0`, and the weights are filtered
      with the same mask as the ids.

  Returns:
    A `(sparse_ids, sparse_weights)` pair after filtering; `sparse_weights`
    is returned unchanged when it is `None`.
  """
  has_weights = sparse_weights is not None
  keep_mask = math_ops.greater_equal(sparse_ids.values, 0)
  if has_weights:
    positive_weight = math_ops.greater(sparse_weights.values, 0)
    keep_mask = math_ops.logical_and(keep_mask, positive_weight)
  pruned_ids = sparse_ops.sparse_retain(sparse_ids, keep_mask)
  if not has_weights:
    return pruned_ids, sparse_weights
  pruned_weights = sparse_ops.sparse_retain(sparse_weights, keep_mask)
  return pruned_ids, pruned_weights
Example #16
Source File: check_ops.py    From keras-lambda with MIT License 5 votes vote down vote up
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
                         name=None):
  """Build an op asserting that `x >= y` holds for every element.

  Typical use is as a control dependency:

  ```python
  with tf.control_dependencies([tf.assert_greater_equal(x, y)]):
    output = tf.reduce_sum(x)
  ```

  The check passes when `x[i] >= y[i]` for every pair of (possibly
  broadcast) elements `x[i]`, `y[i]`. Two empty inputs pass trivially.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data:  Tensors to print when the condition is False.  When omitted, a
      default list is built from `message`, a description of the condition,
      and the name and value of `x` and of `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.  Only used when
      `data` is not supplied.
    name: A name for this operation (optional).  Defaults to
      "assert_greater_equal"

  Returns:
    Op that raises `InvalidArgumentError` if `x >= y` is False.
  """
  if not message:
    message = ''
  with ops.name_scope(name, 'assert_greater_equal', [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')
    if data is None:
      x_report = ['Condition x >= y did not hold element-wise: x = ',
                  x.name, x]
      y_report = ['y = ', y.name, y]
      data = [message] + x_report + y_report
    elementwise_ge = math_ops.greater_equal(x, y)
    all_ge = math_ops.reduce_all(elementwise_ge)
    return control_flow_ops.Assert(all_ge, data, summarize=summarize)
Example #17
Source File: math_grad.py    From keras-lambda with MIT License 5 votes vote down vote up
def _MaximumGrad(op, grad):
  """Returns grad*(x >= y, x < y) with type of grad.

  The selector handed to `_MaximumMinimumGrad` is `math_ops.greater_equal`,
  i.e. `x >= y` (ties included), not the strict `x > y`.

  NOTE(review): the `x < y` half assumes `_MaximumMinimumGrad` routes the
  complement of the selector to `y`; that helper is defined elsewhere, so
  confirm against its implementation.
  """
  return _MaximumMinimumGrad(op, grad, math_ops.greater_equal)
Example #18
Source File: head.py    From keras-lambda with MIT License 5 votes vote down vote up
def _accuracy_at_threshold(threshold):
  """Create an accuracy metric function that thresholds predictions first.

  Args:
    threshold: Cut-off compared against predictions with `>=`.

  Returns:
    A function `(predictions, labels, weights=None)` that converts
    `predictions >= threshold` to float and passes the result, together with
    `labels` and `weights`, to `metrics_lib.streaming_accuracy`.
  """

  def _accuracy_metric(predictions, labels, weights=None):
    is_positive = math_ops.greater_equal(predictions, threshold)
    binarized = math_ops.to_float(is_positive)
    return metrics_lib.streaming_accuracy(
        predictions=binarized, labels=labels, weights=weights)

  return _accuracy_metric
Example #19
Source File: embedding_ops.py    From keras-lambda with MIT License 5 votes vote down vote up
def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Drop entries with a negative id or, when weights are given, weight <= 0.

  Args:
    sparse_ids: Sparse ids; only entries whose value is `>= 0` are kept.
    sparse_weights: Optional sparse weights, or `None`. When present, an entry
      is kept only if its weight is also `> 0`, and the weights are filtered
      with the same mask as the ids.

  Returns:
    A `(sparse_ids, sparse_weights)` pair after filtering; `sparse_weights`
    is returned unchanged when it is `None`.
  """
  has_weights = sparse_weights is not None
  keep_mask = math_ops.greater_equal(sparse_ids.values, 0)
  if has_weights:
    positive_weight = math_ops.greater(sparse_weights.values, 0)
    keep_mask = math_ops.logical_and(keep_mask, positive_weight)
  pruned_ids = sparse_ops.sparse_retain(sparse_ids, keep_mask)
  if not has_weights:
    return pruned_ids, sparse_weights
  pruned_weights = sparse_ops.sparse_retain(sparse_weights, keep_mask)
  return pruned_ids, pruned_weights
Example #20
Source File: target_column.py    From keras-lambda with MIT License 5 votes vote down vote up
def _accuracy_at_threshold(threshold):
  """Create an accuracy metric function that thresholds predictions first.

  Args:
    threshold: Cut-off compared against predictions with `>=`.

  Returns:
    A function `(predictions, labels, weights=None)` that converts
    `predictions >= threshold` to float and passes the result, together with
    `labels` and `weights`, to `metric_ops.streaming_accuracy`.
  """

  def _accuracy_metric(predictions, labels, weights=None):
    is_positive = math_ops.greater_equal(predictions, threshold)
    binarized = math_ops.to_float(is_positive)
    return metric_ops.streaming_accuracy(
        predictions=binarized, labels=labels, weights=weights)

  return _accuracy_metric
Example #21
Source File: core.py    From keras-lambda with MIT License 5 votes vote down vote up
def __ge__(self, other):
    """Implement `self >= other` by delegating to `greater_equal`.

    `greater_equal` is resolved from the enclosing module; its result is
    returned as is.
    """
    return greater_equal(self, other)
Example #22
Source File: core_test.py    From keras-lambda with MIT License 5 votes vote down vote up
def setUp(self):
    """Prepare broadcast tensors and the table of binary ops under test.

    Each entry of `self.ops` is a tuple of
    `(name, operator function or None, math_ops function, core function)`.
    """
    super(CoreBinaryOpsTest, self).setUp()

    x_probs_tensor = self.x_probs_lt.tensor
    x_broadcast_shape = [self.x_size, 1, self.probs_size]
    self.x_probs_broadcast_tensor = array_ops.reshape(
        x_probs_tensor, x_broadcast_shape)

    channel_probs_tensor = self.channel_probs_lt.tensor
    channel_broadcast_shape = [1, self.channel_size, self.probs_size]
    self.channel_probs_broadcast_tensor = array_ops.reshape(
        channel_probs_tensor, channel_broadcast_shape)

    arithmetic_ops = [
        ('add', operator.add, math_ops.add, core.add),
        ('sub', operator.sub, math_ops.subtract, core.sub),
        ('mul', operator.mul, math_ops.multiply, core.mul),
        ('div', operator.truediv, math_ops.div, core.div),
        ('mod', operator.mod, math_ops.mod, core.mod),
        ('pow', operator.pow, math_ops.pow, core.pow_function),
    ]
    # == and != are not element-wise for tf.Tensor, so they shouldn't be
    # elementwise for LabeledTensor, either; hence no Python operator is
    # listed for 'equal' and 'not_equal'.
    comparison_ops = [
        ('equal', None, math_ops.equal, core.equal),
        ('less', operator.lt, math_ops.less, core.less),
        ('less_equal', operator.le, math_ops.less_equal, core.less_equal),
        ('not_equal', None, math_ops.not_equal, core.not_equal),
        ('greater', operator.gt, math_ops.greater, core.greater),
        ('greater_equal', operator.ge, math_ops.greater_equal,
         core.greater_equal),
    ]
    self.ops = arithmetic_ops + comparison_ops

    self.test_lt_1 = self.x_probs_lt
    self.test_lt_2 = self.channel_probs_lt
    self.test_lt_1_broadcast = self.x_probs_broadcast_tensor
    self.test_lt_2_broadcast = self.channel_probs_broadcast_tensor
    self.broadcast_axes = [self.a0, self.a1, self.a3]
Example #23
Source File: head.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def _accuracy_at_threshold(threshold):
  """Create an accuracy metric function that thresholds predictions first.

  Args:
    threshold: Cut-off compared against predictions with `>=`.

  Returns:
    A function `(predictions, labels, weights=None)` that converts
    `predictions >= threshold` to float and passes the result, together with
    `labels` and `weights`, to `metrics_lib.streaming_accuracy`.
  """

  def _accuracy_metric(predictions, labels, weights=None):
    is_positive = math_ops.greater_equal(predictions, threshold)
    binarized = math_ops.to_float(is_positive)
    return metrics_lib.streaming_accuracy(
        predictions=binarized, labels=labels, weights=weights)

  return _accuracy_metric
Example #24
Source File: embedding_ops.py    From lambda-packs with MIT License 5 votes vote down vote up
def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Drop entries with a negative id or, when weights are given, weight <= 0.

  Args:
    sparse_ids: Sparse ids; only entries whose value is `>= 0` are kept.
    sparse_weights: Optional sparse weights, or `None`. When present, an entry
      is kept only if its weight is also `> 0`, and the weights are filtered
      with the same mask as the ids.

  Returns:
    A `(sparse_ids, sparse_weights)` pair after filtering; `sparse_weights`
    is returned unchanged when it is `None`.
  """
  has_weights = sparse_weights is not None
  keep_mask = math_ops.greater_equal(sparse_ids.values, 0)
  if has_weights:
    positive_weight = math_ops.greater(sparse_weights.values, 0)
    keep_mask = math_ops.logical_and(keep_mask, positive_weight)
  pruned_ids = sparse_ops.sparse_retain(sparse_ids, keep_mask)
  if not has_weights:
    return pruned_ids, sparse_weights
  pruned_weights = sparse_ops.sparse_retain(sparse_weights, keep_mask)
  return pruned_ids, pruned_weights
Example #25
Source File: target_column.py    From lambda-packs with MIT License 5 votes vote down vote up
def _accuracy_at_threshold(threshold):
  """Create an accuracy metric function that thresholds predictions first.

  Args:
    threshold: Cut-off compared against predictions with `>=`.

  Returns:
    A function `(predictions, labels, weights=None)` that converts
    `predictions >= threshold` to float and passes the result, together with
    `labels` and `weights`, to `metric_ops.streaming_accuracy`.
  """

  def _accuracy_metric(predictions, labels, weights=None):
    is_positive = math_ops.greater_equal(predictions, threshold)
    binarized = math_ops.to_float(is_positive)
    return metric_ops.streaming_accuracy(
        predictions=binarized, labels=labels, weights=weights)

  return _accuracy_metric
Example #26
Source File: core.py    From lambda-packs with MIT License 5 votes vote down vote up
def __ge__(self, other):
    """Implement `self >= other` by delegating to `greater_equal`.

    `greater_equal` is resolved from the enclosing module; its result is
    returned as is.
    """
    return greater_equal(self, other)
Example #27
Source File: check_ops.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
                         name=None):
  """Build an op asserting that `x >= y` holds for every element.

  Typical use is as a control dependency:

  ```python
  with tf.control_dependencies([tf.assert_greater_equal(x, y)]):
    output = tf.reduce_sum(x)
  ```

  The check passes when `x[i] >= y[i]` for every pair of (possibly
  broadcast) elements `x[i]`, `y[i]`. Two empty inputs pass trivially.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data:  Tensors to print when the condition is False.  When omitted, a
      default list is built from `message`, a description of the condition,
      and the name and value of `x` and of `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.  Only used when
      `data` is not supplied.
    name: A name for this operation (optional).  Defaults to
      "assert_greater_equal"

  Returns:
    Op that raises `InvalidArgumentError` if `x >= y` is False.
  """
  if not message:
    message = ''
  with ops.name_scope(name, 'assert_greater_equal', [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')
    if data is None:
      x_report = ['Condition x >= y did not hold element-wise: x = ',
                  x.name, x]
      y_report = ['y = ', y.name, y]
      data = [message] + x_report + y_report
    elementwise_ge = math_ops.greater_equal(x, y)
    all_ge = math_ops.reduce_all(elementwise_ge)
    return control_flow_ops.Assert(all_ge, data, summarize=summarize)
Example #28
Source File: head.py    From lambda-packs with MIT License 5 votes vote down vote up
def _streaming_accuracy_at_threshold(predictions, labels, weights, threshold):
  """Streaming accuracy of `predictions >= threshold` against `labels`.

  Args:
    predictions: Values compared against `threshold` with `>=`; the boolean
      result is converted to float before computing accuracy.
    labels: Passed through to `metrics_lib.streaming_accuracy`.
    weights: Passed through to `metrics_lib.streaming_accuracy`.
    threshold: Cut-off applied to `predictions`.

  Returns:
    The result of `metrics_lib.streaming_accuracy`.
  """
  is_positive = math_ops.greater_equal(predictions, threshold)
  binarized = math_ops.to_float(is_positive)
  return metrics_lib.streaming_accuracy(
      predictions=binarized, labels=labels, weights=weights)
Example #29
Source File: math_grad.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def _MaximumGrad(op, grad):
  """Returns grad*(x >= y, x < y) with type of grad.

  The selector handed to `_MaximumMinimumGrad` is `math_ops.greater_equal`,
  i.e. `x >= y` (ties included), not the strict `x > y`.

  NOTE(review): the `x < y` half assumes `_MaximumMinimumGrad` routes the
  complement of the selector to `y`; that helper is defined elsewhere, so
  confirm against its implementation.
  """
  return _MaximumMinimumGrad(op, grad, math_ops.greater_equal)
Example #30
Source File: feature_column.py    From lambda-packs with MIT License 5 votes vote down vote up
def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Drop entries with a negative id or, when weights are given, weight <= 0.

  Args:
    sparse_ids: Sparse ids; only entries whose value is `>= 0` are kept.
    sparse_weights: Optional sparse weights, or `None`. When present, an entry
      is kept only if its weight is also `> 0`, and the weights are filtered
      with the same mask as the ids.

  Returns:
    A `(sparse_ids, sparse_weights)` pair after filtering; `sparse_weights`
    is returned unchanged when it is `None`.
  """
  has_weights = sparse_weights is not None
  keep_mask = math_ops.greater_equal(sparse_ids.values, 0)
  if has_weights:
    positive_weight = math_ops.greater(sparse_weights.values, 0)
    keep_mask = math_ops.logical_and(keep_mask, positive_weight)
  pruned_ids = sparse_ops.sparse_retain(sparse_ids, keep_mask)
  if not has_weights:
    return pruned_ids, sparse_weights
  pruned_weights = sparse_ops.sparse_retain(sparse_weights, keep_mask)
  return pruned_ids, pruned_weights