Python tensorflow.compat.v2.gather_nd() Examples

The following are 9 code examples of tensorflow.compat.v2.gather_nd(), collected from open-source projects. The source file, project, and license are noted above each example. You may also want to check out all available functions/classes of the module tensorflow.compat.v2.
Example #1
Source File: util.py    From language with Apache License 2.0
def labels_of_top_ranked_predictions_in_batch(labels, predictions):
  """Applying tf.metrics.mean to this gives precision at 1.

  Args:
    labels: minibatch of dense 0/1 labels, shape [batch_size, num_classes]
    predictions: minibatch of predictions of the same shape

  Returns:
    one-dimensional tensor top_labels, where top_labels[i]=1.0 iff the
    top-scoring prediction for batch element i has label 1.0
  """
  indices_of_top_preds = tf.cast(tf.argmax(input=predictions, axis=1), tf.int32)
  batch_size = tf.reduce_sum(input_tensor=tf.ones_like(indices_of_top_preds))
  row_indices = tf.range(batch_size)
  thresholded_labels = tf.where(labels > 0.0, tf.ones_like(labels),
                                tf.zeros_like(labels))
  label_indices_to_gather = tf.transpose(
      a=tf.stack([row_indices, indices_of_top_preds]))
  return tf.gather_nd(thresholded_labels, label_indices_to_gather) 
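
Below is a minimal usage sketch, not part of the original source file, that exercises the helper on a toy minibatch:

import tensorflow.compat.v2 as tf

labels = tf.constant([[0., 1.], [1., 0.]])           # [batch_size=2, num_classes=2]
predictions = tf.constant([[0.2, 0.8], [0.3, 0.7]])
top_labels = labels_of_top_ranked_predictions_in_batch(labels, predictions)
# Row 0: argmax is class 1, which has label 1.0 -> 1.0.
# Row 1: argmax is class 1, which has label 0.0 -> 0.0.
# top_labels == [1., 0.], so tf.metrics.mean gives precision at 1 of 0.5.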
Example #2
Source File: piecewise.py    From tf-quant-finance with Apache License 2.0
def _piecewise_constant_function(x, jump_locations, values,
                                 batch_rank, side='left'):
  """Computes value of the piecewise constant function."""
  # Initializer already verified that `jump_locations` and `values` have the
  # same shape.
  batch_shape = jump_locations.shape.as_list()[:-1]
  # Check that the batch shape of `x` is the same as of `jump_locations` and
  # `values`
  batch_shape_x = x.shape.as_list()[:batch_rank]
  if batch_shape_x != batch_shape:
    raise ValueError('Batch shape of `x` is {1} but should be {0}'.format(
        batch_shape, batch_shape_x))
  if x.shape.as_list()[:batch_rank]:
    no_batch_shape = False
  else:
    no_batch_shape = True
    x = tf.expand_dims(x, 0)
  # Expand batch size to one if there is no batch shape
  if not batch_shape:
    jump_locations = tf.expand_dims(jump_locations, 0)
    values = tf.expand_dims(values, 0)
  indices = tf.searchsorted(jump_locations, x, side=side)
  index_matrix = _prepare_index_matrix(
      indices.shape.as_list()[:-1], indices.shape.as_list()[-1], indices.dtype)
  indices_nd = tf.concat(
      [index_matrix, tf.expand_dims(indices, -1)], -1)
  res = tf.gather_nd(values, indices_nd)
  if no_batch_shape:
    return tf.squeeze(res, 0)
  else:
    return res 
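
The core pattern here is pairing tf.searchsorted with tf.gather_nd so that each batch row gathers from its own row of `values`. A standalone sketch, with the batch rank hard-coded to 1 for illustration:

import tensorflow.compat.v2 as tf

jump_locations = tf.constant([[1.0, 2.0], [0.5, 1.5]])    # shape [2, 2]
values = tf.constant([[10., 20., 30.], [-1., -2., -3.]])  # shape [2, 3]
x = tf.constant([[0.0, 1.5], [1.0, 2.0]])                 # shape [2, 2]

indices = tf.searchsorted(jump_locations, x, side='left')  # [[0, 1], [1, 2]]
rows = tf.broadcast_to(tf.range(2)[:, tf.newaxis], [2, 2])
indices_nd = tf.stack([rows, indices], axis=-1)            # shape [2, 2, 2]
print(tf.gather_nd(values, indices_nd))                    # [[10., 20.], [-2., -3.]]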
Example #3
Source File: piecewise.py    From tf-quant-finance with Apache License 2.0
def _prepare_index_matrix(batch_shape, num_points, dtype):
  """Prepares index matrix for index argument of `tf.gather_nd`."""
  batch_shape_reverse = batch_shape.copy()
  batch_shape_reverse.reverse()
  index_matrix = tf.constant(
      np.flip(np.transpose(np.indices(batch_shape_reverse)), -1),
      dtype=dtype)
  batch_rank = len(batch_shape)
  # Broadcast index matrix to the shape of
  # `batch_shape + [num_points] + [batch_rank]`.
  broadcasted_shape = batch_shape + [num_points] + [batch_rank]
  index_matrix = tf.expand_dims(index_matrix, -2) + tf.zeros(
      tf.TensorShape(broadcasted_shape), dtype=dtype)
  return index_matrix 
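
As a quick illustration (not taken from the source; it assumes tf is imported as in the example), the matrix for batch_shape=[2] and num_points=3 simply repeats each row index along the points dimension:

index_matrix = _prepare_index_matrix([2], 3, tf.int32)  # shape [2, 3, 1]
# [[[0], [0], [0]],
#  [[1], [1], [1]]]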
Example #4
Source File: utils.py    From tf-quant-finance with Apache License 2.0
def prepare_indices(indices):
  """Prepares `tf.searchsorted` output for index argument of `tf.gather_nd`.

  Creates an index matrix that can be used along with `tf.gather_nd`.

  #### Example
  indices = tf.constant([[[1, 2, 3], [2, 3, 4]]])
  index_matrix = utils.prepare_indices(indices)
  # Outputs a tensor of shape [1, 2, 3, 2]
  # [[[[0, 0], [0, 0], [0, 0]], [[0, 1], [0, 1], [0, 1]]]]
  # The index matrix can be concatenated with the indices in order to obtain
  # the `tf.gather_nd` selection matrix:
  tf.concat([index_matrix, tf.expand_dims(indices, axis=-1)], axis=-1)
  # Outputs
  # [[[[0, 0, 1], [0, 0, 2], [0, 0, 3]],
  #   [[0, 1, 2], [0, 1, 3], [0, 1, 4]]]]

  Args:
    indices: A `Tensor` of any shape and dtype.

  Returns:
    A `Tensor` of the same dtype as `indices` and shape
    `indices.shape + [indices.shape.rank - 1]`.
  """
  batch_shape = indices.shape.as_list()[:-1]
  num_points = indices.shape.as_list()[-1]
  batch_shape_reverse = indices.shape.as_list()[:-1]
  batch_shape_reverse.reverse()
  index_matrix = tf.constant(
      np.flip(np.transpose(np.indices(batch_shape_reverse)), -1),
      dtype=indices.dtype)
  batch_rank = len(batch_shape)
  # Broadcast index matrix to the shape of
  # `batch_shape + [num_points] + [batch_rank]`
  broadcasted_shape = batch_shape + [num_points] + [batch_rank]
  index_matrix = tf.expand_dims(index_matrix, -2) + tf.zeros(
      broadcasted_shape, dtype=indices.dtype)
  return index_matrix 
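
A hedged end-to-end sketch, not from the project, that runs the docstring example through tf.gather_nd:

import numpy as np
import tensorflow.compat.v2 as tf

values = tf.reshape(tf.range(10.), [1, 2, 5])    # rows [0..4] and [5..9]
indices = tf.constant([[[1, 2, 3], [2, 3, 4]]])  # shape [1, 2, 3]
index_matrix = prepare_indices(indices)          # shape [1, 2, 3, 2]
nd_indices = tf.concat([index_matrix, tf.expand_dims(indices, -1)], axis=-1)
print(tf.gather_nd(values, nd_indices))          # [[[1., 2., 3.], [7., 8., 9.]]]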
Example #5
Source File: extensions.py    From trax with Apache License 2.0
def sort_key_val(keys, values, dimension=-1):
  """Sorts keys along a dimension and applies same permutation to values.

  Args:
    keys: an array. The dtype must be comparable numbers (integers and reals).
    values: an array with the same shape as `keys`.
    dimension: an `int`. The dimension along which to sort.

  Returns:
    Permuted keys and values.
  """
  keys = tf_np.asarray(keys)
  values = tf_np.asarray(values)
  rank = keys.data.shape.ndims
  if rank is None:
    rank = values.data.shape.ndims
  if rank is None:
    # We need to know the rank because tf.gather requires batch_dims to be `int`
    raise ValueError("The rank of either keys or values must be known, but "
                     "both are unknown (i.e. their shapes are both None).")
  if dimension in (-1, rank - 1):

    def maybe_swapaxes(a):
      return a
  else:

    def maybe_swapaxes(a):
      return tf_np.swapaxes(a, dimension, -1)

  # We need to swap axes because tf.gather (and tf.gather_nd) supports
  # batch_dims on the left but not on the right.
  # TODO(wangpeng): Investigate whether we should do swapaxes or moveaxis.
  keys = maybe_swapaxes(keys)
  values = maybe_swapaxes(values)
  idxs = tf_np.argsort(keys)
  idxs = idxs.data

  # Using tf.gather rather than np.take because the former supports batch_dims
  def gather(a):
    return tf_np.asarray(tf.gather(a.data, idxs, batch_dims=rank - 1))

  keys = gather(keys)
  values = gather(values)
  keys = maybe_swapaxes(keys)
  values = maybe_swapaxes(values)
  return keys, values


# Use int64 instead of int32 to avoid TF's "int32 problem" 
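
The batch_dims trick that the gather above relies on can be seen in isolation; a hedged sketch in plain TF:

import tensorflow.compat.v2 as tf

keys = tf.constant([[3, 1, 2], [9, 7, 8]])
idxs = tf.argsort(keys)                     # [[1, 2, 0], [1, 2, 0]]
# With batch_dims=1, each row of `idxs` permutes only its own row of `keys`.
print(tf.gather(keys, idxs, batch_dims=1))  # [[1, 2, 3], [7, 8, 9]]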
Example #6
Source File: network.py    From ranking with Apache License 2.0
def compute_logits(self,
                     context_features=None,
                     example_features=None,
                     training=None,
                     mask=None):
    """Scores context and examples to return a score per document.

    Args:
      context_features: (dict) context feature names to 2D tensors of shape
        [batch_size, feature_dims].
      example_features: (dict) example feature names to 3D tensors of shape
        [batch_size, list_size, feature_dims].
      training: (bool) whether in train or inference mode.
      mask: (tf.Tensor) Mask is a tensor of shape [batch_size, list_size], which
        is True for a valid example and False for an invalid one. If mask is None,
        all entries are valid.

    Returns:
      (tf.Tensor) A score tensor of shape [batch_size, list_size].
    """
    tensor = next(six.itervalues(example_features))
    batch_size = tf.shape(tensor)[0]
    list_size = tf.shape(tensor)[1]
    if mask is None:
      mask = tf.ones(shape=[batch_size, list_size], dtype=tf.bool)
    nd_indices, nd_mask = utils.padded_nd_indices(is_valid=mask)

    # Expand query features to be of [batch_size, list_size, ...].
    large_batch_context_features = {}
    for name, tensor in six.iteritems(context_features):
      x = tf.expand_dims(input=tensor, axis=1)
      x = tf.gather(x, tf.zeros([list_size], tf.int32), axis=1)
      large_batch_context_features[name] = utils.reshape_first_ndims(
          x, 2, [batch_size * list_size])

    large_batch_example_features = {}
    for name, tensor in six.iteritems(example_features):
      # Replace invalid example features with valid ones.
      padded_tensor = tf.gather_nd(tensor, nd_indices)
      large_batch_example_features[name] = utils.reshape_first_ndims(
          padded_tensor, 2, [batch_size * list_size])

    # Get scores for large batch.
    scores = self.score(
        context_features=large_batch_context_features,
        example_features=large_batch_example_features,
        training=training)
    logits = tf.reshape(
        scores, shape=[batch_size, list_size])

    # Apply nd_mask to zero out invalid entries.
    logits = tf.where(nd_mask, logits, tf.zeros_like(logits))
    return logits 
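
The role of `nd_indices` above is to re-point invalid list slots at valid examples before scoring. A hedged, plain-TF sketch of that padding idea, hard-coding the first list entry as the valid fallback:

import tensorflow.compat.v2 as tf

features = tf.constant([[[1.], [2.], [0.]]])  # [batch=1, list_size=3, dim=1]
mask = tf.constant([[True, True, False]])

# Invalid slots point back at column 0; valid slots point at themselves.
cols = tf.where(mask, tf.range(3)[tf.newaxis, :], 0)
rows = tf.zeros_like(cols)
nd_indices = tf.stack([rows, cols], axis=-1)  # [[[0, 0], [0, 1], [0, 0]]]
print(tf.gather_nd(features, nd_indices))     # [[[1.], [2.], [1.]]]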
Example #7
Source File: piecewise.py    From tf-quant-finance with Apache License 2.0
def _piecewise_constant_integrate(x1, x2, jump_locations, values, batch_rank):
  """Integrates piecewise constant function between `x1` and `x2`."""
  # Initializer already verified that `jump_locations` and `values` have the
  # same shape.
  # Expand batch size to one if there is no batch shape.
  if x1.shape.as_list()[:batch_rank]:
    no_batch_shape = False
  else:
    no_batch_shape = True
    x1 = tf.expand_dims(x1, 0)
    x2 = tf.expand_dims(x2, 0)
  if not jump_locations.shape.as_list()[:-1]:
    jump_locations = tf.expand_dims(jump_locations, 0)
    values = tf.expand_dims(values, 0)
    batch_rank += 1

  # Compute the index matrix that is later used for `tf.gather_nd`.
  index_matrix = _prepare_index_matrix(
      x1.shape.as_list()[:-1], x1.shape.as_list()[-1], tf.int32)
  # Compute integral values at the jump locations starting from the first jump
  # location.
  event_shape = values.shape.as_list()[(batch_rank + 1):]
  num_data_points = values.shape.as_list()[batch_rank]
  diff = jump_locations[..., 1:] - jump_locations[..., :-1]
  # Broadcast `diff` to the shape of
  # `batch_shape + [num_data_points - 2] + [1] * sample_rank`.
  for _ in event_shape:
    diff = tf.expand_dims(diff, -1)
  slice_indices = batch_rank * [slice(None)]
  slice_indices += [slice(1, num_data_points - 1)]
  integrals = tf.cumsum(values[slice_indices] * diff, batch_rank)
  # Pad integrals with zero values on left and right.
  batch_shape = integrals.shape.as_list()[:batch_rank]
  zeros = tf.zeros(batch_shape + [1] + event_shape, dtype=integrals.dtype)
  integrals = tf.concat([zeros, integrals, zeros], axis=batch_rank)
  # Get the function values and jump locations at the integration end points.
  value1, jump_location1, indices_nd1 = _get_indices_and_values(
      x1, index_matrix, jump_locations, values, 'left', batch_rank)
  value2, jump_location2, indices_nd2 = _get_indices_and_values(
      x2, index_matrix, jump_locations, values, 'right', batch_rank)
  integrals1 = tf.gather_nd(integrals, indices_nd1)
  integrals2 = tf.gather_nd(integrals, indices_nd2)
  # Broadcast `x1`, `x2`, `jump_location1`, `jump_location2` to the shape
  # `batch_shape + [num_points] + [1] * sample_rank`.
  for _ in event_shape:
    x1 = tf.expand_dims(x1, -1)
    x2 = tf.expand_dims(x2, -1)
    jump_location1 = tf.expand_dims(jump_location1, -1)
    jump_location2 = tf.expand_dims(jump_location2, -1)
  # Compute the value of the integral.
  res = ((jump_location1 - x1) * value1
         + (x2 - jump_location2) * value2
         + integrals2 - integrals1)
  if no_batch_shape:
    return tf.squeeze(res, 0)
  else:
    return res 
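
These private helpers back the public PiecewiseConstantFunc class exported by the same module; the usage below is a hedged sketch based on the project's documented API:

import tensorflow.compat.v2 as tf
import tf_quant_finance as tff

f = tff.math.piecewise.PiecewiseConstantFunc(
    jump_locations=[0.1, 10.0], values=[3.0, 4.0, 5.0], dtype=tf.float64)
x1 = tf.constant([0.0], dtype=tf.float64)
x2 = tf.constant([1.0], dtype=tf.float64)
# Integral over [0, 1]: 0.1 * 3 + 0.9 * 4 = 3.9.
print(f.integrate(x1, x2))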
Example #8
Source File: piecewise.py    From tf-quant-finance with Apache License 2.0
def _get_indices_and_values(x, index_matrix, jump_locations, values, side,
                            batch_rank):
  """Computes values and jump locations of the piecewise constant function.

  Given `jump_locations` and the `values` on the corresponding segments of the
  piecewise constant function, the function identifies the nearest jump to `x`
  from the right or left (which is determined by the `side` argument) and the
  corresponding value of the piecewise constant function at `x`.

  Args:
    x: A real `Tensor` of shape `batch_shape + [num_points]`. Points at which
      the function has to be evaluated.
    index_matrix: An `int32` `Tensor` of shape
      `batch_shape + [num_points] + [len(batch_shape)]` such that if
      `batch_shape = [i1, .., in]`, then for all `j1, ..., jn, l`,
      `index_matrix[j1,..,jn, l] = [j1, ..., jn]`.
    jump_locations: A `Tensor` of the same `dtype` as `x` and shape
      `batch_shape + [num_jump_points]`. The locations where the function
      changes its values. Note that the values are expected to be ordered
      along the last dimension.
    values: A `Tensor` of the same `dtype` as `x` and shape
      `batch_shape + [num_jump_points + 1]`. Defines `values[..., i]` on the
      interval `(jump_locations[..., i - 1], jump_locations[..., i])`.
    side: A Python string. Whether the function is left- or right-continuous.
      The corresponding values for `side` are `left` and `right`.
    batch_rank: A Python scalar stating the batch rank of `x`.

  Returns:
    A tuple of three `Tensor`s. The first two have the same `dtype` as `x` and
    shapes `batch_shape + [num_points] + event_shape` and
    `batch_shape + [num_points]`; the third is an integer `Tensor` of shape
    `batch_shape + [num_points] + [len(batch_shape) + 1]`. The `Tensor`s
    correspond to the values, jump locations at `x`, and the corresponding
    indices used to obtain jump locations via `tf.gather_nd`.
  """
  indices = tf.searchsorted(jump_locations, x, side=side)
  num_data_points = tf.shape(values)[batch_rank] - 2
  if side == 'right':
    indices_jump = indices - 1
    indices_jump = tf.maximum(indices_jump, 0)
  else:
    indices_jump = tf.minimum(indices, num_data_points)
  indices_nd = tf.concat(
      [index_matrix, tf.expand_dims(indices, -1)], -1)
  indices_jump_nd = tf.concat(
      [index_matrix, tf.expand_dims(indices_jump, -1)], -1)
  value = tf.gather_nd(values, indices_nd)
  jump_location = tf.gather_nd(jump_locations, indices_jump_nd)
  return value, jump_location, indices_jump_nd 
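
The `side` argument matters only when an evaluation point sits exactly on a jump location; a hedged mini-example:

import tensorflow.compat.v2 as tf

jump_locations = tf.constant([[1.0, 2.0]])
x = tf.constant([[2.0]])
print(tf.searchsorted(jump_locations, x, side='left'))   # [[1]]
print(tf.searchsorted(jump_locations, x, side='right'))  # [[2]]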
Example #9
Source File: base_agent.py    From valan with Apache License 2.0
def _get_reset_state(self, observation, done, default_state):
    """Resets the state wherever marked in `done` tensor.

    Consider the following example with num_timesteps=2, batch_size=3,
    state_size=1:
      default_state (batch_size, state_size) = [[5.], [5.], [5.]]
      done (num_timesteps, batch_size) = [[True, True, False],
                                          [False, True, False]]
      observation (num_timesteps, batch_size, 1) = [[[1.], [2.], [3.]],
                                                    [[4.], [5.], [6.]]]
      self.get_initial_state implements `observation + 10`.
    then the returned tensor will be of shape (num_timesteps, batch_size,
    state_size) and its value will be:
      [[[11.], [12.], [0.]],
       [[0.],  [15.], [0.]]]
    where state values are replaced by a call to `self.get_initial_state` wherever
    done=True. Note that the state values where done=False are set to zeros and
    are expected not to be used by the caller.

    Args:
      observation: A nested structure with individual tensors that have first
        two dimensions equal to [num_timesteps, batch_size].
      done: A boolean tensor of shape  [num_timesteps, batch_size].
      default_state: A tensor or nested structure with individual tensors that
        have first dimension equal to batch_size and no time dimension.

    Returns:
      A structure similar to `default_state` except that all tensors in the
      returned structure have an additional leading dimension equal to
      num_timesteps.
    """
    reset_indices = tf.compat.v1.where(tf.equal(done, True))

    def _get_reset_state_indices():
      reset_indices_obs = tf.nest.map_structure(
          lambda t: tf.gather_nd(t, reset_indices), observation)
      # shape: [num_indices_to_reset, ...]
      reset_indices_state = self.get_initial_state(
          reset_indices_obs, batch_size=tf.shape(reset_indices)[0])
      # Scatter tensors in `reset_indices_state` to shape: [num_timesteps,
      # batch_size, ...]
      return tf.nest.map_structure(
          lambda reset_tensor: tf.scatter_nd(  
              indices=reset_indices,
              updates=reset_tensor,
              shape=done.shape.as_list() + reset_tensor.shape.as_list()[1:]),
          reset_indices_state)

    # A minor optimization wherein if all elements in `done` are False, we
    # simply return a structure with zeros tensors of correct shape.
    return tf.cond(
        tf.greater(tf.size(reset_indices), 0),
        _get_reset_state_indices,
        lambda: tf.nest.map_structure(  
            lambda t: tf.zeros(         
                shape=done.shape.as_list() + t.shape.as_list()[1:],
                dtype=t.dtype),
            default_state))
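
A hedged standalone sketch, not from valan, of the tf.where + tf.scatter_nd round trip used above:

import tensorflow.compat.v2 as tf

done = tf.constant([[True, False], [False, True]])
reset_indices = tf.where(done)        # [[0, 0], [1, 1]], dtype int64
updates = tf.constant([11.0, 15.0])   # one reset state per True entry
print(tf.scatter_nd(reset_indices, updates, shape=done.shape.as_list()))
# [[11.,  0.],
#  [ 0., 15.]]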