Python tensorflow.compat.v2.stack() Examples

The following are 30 code examples of tensorflow.compat.v2.stack(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.compat.v2, or try the search function.
Example #1
Source File: input_pipeline.py    From models with Apache License 2.0 6 votes vote down vote up
def process_multidoc_dataset(dataset, batch_size, params):
  """Decodes, selects and batches the features of a multi-doc dataset.

  Args:
    dataset: A `tf.data.Dataset` of records accepted by `decode_record`
      (presumably serialized examples — confirm against the caller).
    batch_size: Number of examples per batch; a trailing partial batch is
      dropped.
    params: Configuration passed to `multidoc_parse_spec`; its `passage_list`
      supplies the per-passage feature-name suffixes.

  Returns:
    A batched dataset of dicts holding "target_ids" (taken from
    "input_ids_a") plus, for every feature in the parse spec's feature list,
    the per-passage tensors stacked along a new leading axis.
  """
  autotune = tf.data.experimental.AUTOTUNE
  name_to_features, feature_list = multidoc_parse_spec(params)

  def _decode(record):
    return decode_record(record, name_to_features)

  def _select_data_from_record(record):
    """Keeps pretraining features and stacks each one across passages."""
    selected = {"target_ids": record["input_ids_a"]}
    selected.update({
        feature: tf.stack(
            [record["%s_%s" % (feature, i)] for i in params.passage_list])
        for feature in feature_list
    })
    return selected

  return (dataset
          .map(_decode, num_parallel_calls=autotune)
          .map(_select_data_from_record, num_parallel_calls=autotune)
          .batch(batch_size, drop_remainder=True))
Example #2
Source File: utils.py    From valan with Apache License 2.0 6 votes vote down vote up
def stack_nested_tensors(list_of_nests):
  """Stacks corresponding leaves of several identically structured nests.

  Args:
    list_of_nests: A list of nests whose leaves are tensors (or numpy arrays);
      all nests share the same structure and matching leaf shapes.

  Returns:
    One nest with the common structure in which every leaf is `tf.stack` of
    the corresponding leaves, i.e. it gains a new leading dimension.
  """

  def _stack_leaves(*leaves):
    return tf.stack(list(map(tf.convert_to_tensor, leaves)))

  return tf.nest.map_structure(_stack_leaves, *list_of_nests)
Example #3
Source File: gradient_test.py    From tf-quant-finance with Apache License 2.0 6 votes vote down vote up
def test_forward_unconnected_gradient(self):
    """Forward gradient of an output that ignores `t` is all zeros."""
    t = tf.range(1, 3, dtype=tf.float32)  # Shape [2]
    zeros = tf.zeros([2], dtype=t.dtype)
    # `func` never reads its argument, so `t` is unconnected to the output.
    func = lambda t: tf.stack([zeros] * 3, axis=0)  # Shape [3, 2]
    expected_result = [[0.0] * 2 for _ in range(3)]
    zero_policy = tf.UnconnectedGradients.ZERO

    with self.subTest("EagerExecution"):
      eager_grad = self.evaluate(
          tff.math.fwd_gradient(func, t, unconnected_gradients=zero_policy))
      self.assertEqual(eager_grad.shape, (3, 2))
      np.testing.assert_allclose(eager_grad, expected_result)

    with self.subTest("GraphExecution"):

      @tf.function
      def grad_computation():
        return tff.math.fwd_gradient(
            func(t), t, unconnected_gradients=zero_policy)

      graph_grad = self.evaluate(grad_computation())
      self.assertEqual(graph_grad.shape, (3, 2))
      np.testing.assert_allclose(graph_grad, expected_result)
Example #4
Source File: util.py    From language with Apache License 2.0 6 votes vote down vote up
def labels_of_top_ranked_predictions_in_batch(labels, predictions):
  """Returns, per batch row, the label at that row's top-scoring prediction.

  Applying tf.metrics.mean to the result gives precision at 1.

  Args:
    labels: minibatch of dense 0/1 labels, shape [batch_size rows, num_classes]
    predictions: minibatch of predictions of the same shape

  Returns:
    one-dimension tensor top_labels, where top_labels[i]=1.0 iff the
    top-scoring prediction for batch element i has a positive label
    (any label > 0 is treated as 1.0).
  """
  # Column of the highest-scoring prediction in every row, directly as int32.
  indices_of_top_preds = tf.argmax(
      input=predictions, axis=1, output_type=tf.int32)
  # Read the batch size from the shape rather than summing a tensor of ones.
  row_indices = tf.range(tf.shape(indices_of_top_preds)[0])
  # Strictly positive labels become 1, everything else 0.
  thresholded_labels = tf.where(labels > 0.0, tf.ones_like(labels),
                                tf.zeros_like(labels))
  # [batch_size, 2] (row, column) pairs; stacking on axis 1 replaces the
  # stack-then-transpose.
  label_indices_to_gather = tf.stack([row_indices, indices_of_top_preds],
                                     axis=1)
  return tf.gather_nd(thresholded_labels, label_indices_to_gather)
Example #5
Source File: gradient_test.py    From tf-quant-finance with Apache License 2.0 6 votes vote down vote up
def test_backward_unconnected_gradient(self):
    """Backward gradient through an output that ignores `t` is all zeros."""
    t = tf.range(1, 3, dtype=tf.float32)  # Shape [2]
    zeros = tf.zeros([2], dtype=t.dtype)
    expected_result = [0.0] * 2
    # `func` never reads its argument, so `t` is unconnected to the output.
    func = lambda t: tf.stack([zeros] * 3, axis=0)  # Shape [3, 2]
    zero_policy = tf.UnconnectedGradients.ZERO

    with self.subTest("EagerExecution"):
      eager_grad = self.evaluate(
          tff.math.gradients(func, t, unconnected_gradients=zero_policy))
      self.assertEqual(eager_grad.shape, (2,))
      np.testing.assert_allclose(eager_grad, expected_result)

    with self.subTest("GraphExecution"):

      @tf.function
      def grad_computation():
        return tff.math.gradients(
            func(t), t, unconnected_gradients=zero_policy)

      graph_grad = self.evaluate(grad_computation())
      self.assertEqual(graph_grad.shape, (2,))
      np.testing.assert_allclose(graph_grad, expected_result)
Example #6
Source File: imagenet2012_corrupted.py    From datasets with Apache License 2.0 6 votes vote down vote up
def _decode_and_center_crop(image_bytes):
  """Decodes a padded center crop of a JPEG and resizes it.

  The crop is a square whose side is
  `_IMAGE_SIZE / (_IMAGE_SIZE + _CROP_PADDING)` times the shorter image side,
  centered in the image. Only that window is decoded; it is then bicubically
  resized to `_IMAGE_SIZE` x `_IMAGE_SIZE` and cast to int32.

  Args:
    image_bytes: Tensor holding the encoded JPEG bytes.

  Returns:
    An int32 image tensor of shape [_IMAGE_SIZE, _IMAGE_SIZE, 3].
  """
  jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
  height, width = jpeg_shape[0], jpeg_shape[1]

  scale = _IMAGE_SIZE / (_IMAGE_SIZE + _CROP_PADDING)
  shorter_side = tf.cast(tf.minimum(height, width), tf.float32)
  crop_size = tf.cast(scale * shorter_side, tf.int32)

  # Offsets round up when the leftover margin is odd.
  top = (height - crop_size + 1) // 2
  left = (width - crop_size + 1) // 2
  crop_window = tf.stack([top, left, crop_size, crop_size])

  cropped = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
  # resize works on a batch, so wrap the single image and unwrap afterwards.
  resized = tf.image.resize([cropped], [_IMAGE_SIZE, _IMAGE_SIZE],
                            method=tf.image.ResizeMethod.BICUBIC)[0]
  return tf.cast(resized, tf.int32)
Example #7
Source File: date_tensor.py    From tf-quant-finance with Apache License 2.0 6 votes vote down vote up
def to_tensor(self):
    """Packs the dates into a single Tensor.

    The result has shape `date_tensor.shape() + (3,)`; its last dimension
    holds year, month and day, in that order.

    Handy when dates are the final output of a graph-mode computation: a
    `tf.function` can return `date_tensor.to_tensor()`, and with
    `tf.compat.v1.Session` one can call `session.run(date_tensor.to_tensor())`.

    Returns:
      A Tensor of shape `date_tensor.shape() + (3,)`.

    #### Example

    ```python
    dates = tff.datetime.dates_from_tuples([(2019, 1, 25), (2020, 3, 2)])
    dates.to_tensor()  # tf.Tensor with contents [[2019, 1, 25], [2020, 3, 2]].
    ```
    """
    year_month_day = (self.year(), self.month(), self.day())
    return tf.stack(year_month_day, axis=-1)
Example #8
Source File: extensions.py    From trax with Apache License 2.0 6 votes vote down vote up
def _key2seed(a):
  """Turns an RNG key into an RNG seed.

  Args:
    a: an RNG key, an ndarray of shape [] and dtype `np.int64`.

  Returns:
    an RNG seed, a tensor of shape [2] and dtype `tf.int32`.
  """

  def _split_int64(value):
    """Splits a scalar int64 into [low 32 bits, high 32 bits] as int32."""
    as_uint64 = tf.cast(value, tf.uint64)
    low = tf.cast(as_uint64, tf.uint32)
    high = tf.cast(
        tf.bitwise.right_shift(as_uint64, tf.constant(32, tf.uint64)),
        tf.uint32)
    return tf.stack([tf.cast(half, tf.int32) for half in (low, high)])

  return _split_int64(a.data)
Example #9
Source File: hull_white_test.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def test_mean_variance_correlation_generic_2d(self):
    """Tests model with generic parameters in 2 dimensions."""

    def corr_matrix_fn(t):
      # Rows [1, 0.5 t] and its reversal along axis 0, stacked on the last axis.
      ones = tf.ones_like(t)
      first_row = tf.stack([ones, 0.5 * t], axis=-1)
      second_row = tf.reverse(first_row, [0])
      return tf.stack([first_row, second_row], axis=-1)

    num_samples = 50000
    for dtype in (tf.float32, tf.float64):
      # Mean reversion without a batch dimension.
      mean_reversion = tff.math.piecewise.PiecewiseConstantFunc(
          [0.1, 2.0], values=3 * [self.mean_reversion], dtype=dtype)
      # Volatility with a batch dimension.
      volatility = tff.math.piecewise.PiecewiseConstantFunc(
          [[0.1, 0.2, 0.5], [0.1, 2.0, 3.0]],
          values=[4 * [self.volatility[0]], 4 * [self.volatility[1]]],
          dtype=dtype)
      process = tff.models.hull_white.VectorHullWhiteModel(
          dim=2,
          mean_reversion=mean_reversion,
          volatility=volatility,
          corr_matrix=corr_matrix_fn,
          initial_discount_rate_fn=self.instant_forward_rate_2d_fn,
          dtype=dtype)
      times = [0.1, 0.5]
      paths = process.sample_paths(
          times,
          num_samples=num_samples,
          random_type=tff.math.random.RandomType.SOBOL,
          skip=100000,
          time_step=0.01)
      self.assertEqual(paths.dtype, dtype)
      self.assertAllEqual(paths.shape, [num_samples, 2, 2])
      # Keep only the path values at the terminal time.
      terminal = self.evaluate(paths)[:, -1, :]
      final_time = times[-1]
      self.assertAllClose(np.mean(terminal, axis=0),
                          self.true_mean(final_time), rtol=1e-3, atol=1e-3)
      self.assertAllClose(np.var(terminal, axis=0),
                          self.true_var(final_time), rtol=1e-3, atol=1e-3)
Example #10
Source File: utils.py    From valan with Apache License 2.0 5 votes vote down vote up
def parallel_conv2d(inputs, filters, strides, padding):
  """Applies each filter in the batch of filters to each input.

  tf.nn.conv2d only supports applying the same filter on a batch of inputs.
  This function provides a similar interface, but allowing a batch of filters,
  a different one for each input.

  In the below definitions, B is the batch size, H and W are spatial input or
  output dimensions (overloaded between input and output), C1 is the input
  number of channels, C2 is output number of channels, KHxKW is the
  convolutional kernel spatial size.

  Args:
    inputs: BxHxWxC1 tensor - batch of input "images"
    filters: BxKHxKWxC1xC2 tensor - batch of convolutional kernels
    strides: See tf.nn.conv2d arg: strides
    padding: See tf.nn.conv2d arg: padding

  Returns:
    Tensor of shape BxHxWxC2
  """
  # Python-level loop over the batch, so the batch size must be static.
  batch_size = inputs.shape[0]

  # Convolve each single-example slice with its own kernel. Every result keeps
  # a batch dimension of size 1, so concatenating along axis 0 rebuilds the
  # batch directly, with no extra stacked axis to squeeze away.
  output = tf.concat(
      [tf.nn.conv2d(inputs[i:i + 1], filters[i], strides, padding)
       for i in range(batch_size)],
      axis=0)
  # Output should have same batch size and spatial dimensions as input, but
  # the number of channels is determined by the convolution filter.
  # NOTE(review): equal spatial dims presumably only hold for stride 1 with
  # 'SAME' padding — confirm what `assert_shape` enforces for other settings.
  assert_shape((batch_size, inputs.shape[1], inputs.shape[2], filters.shape[4]),
               output.shape)
  return output
Example #11
Source File: utils.py    From valan with Apache License 2.0 5 votes vote down vote up
def get_first_true_column(x):
  """Transforms `x` into a tensor which has all elements set to False except the first True in the column.

  If x is [[True, False, False],
           [True, False, False],
           [False, True, False],
           [False, True, True]]
  the output should be
          [[True, False, False],
           [False, False, False],
           [False, True, False],
           [False, False, True]
          ]
  A column that contains no True at all stays entirely False.

  Args:
    x: A bool tensor with shape [num_steps, batch_size]

  Returns:
    A bool tensor with the same shape.
  """
  x = tf.convert_to_tensor(x, dtype=tf.bool)
  # Running count of Trues down each column (along num_steps). Within a
  # column, the first True is the only True entry whose running count is 1;
  # the Falses after it also have count 1 but are masked out by `x` itself.
  # This leaves all-False columns untouched (segment_min would report 0 for
  # such an empty segment and wrongly mark step 0). It also needs no static
  # shape.
  true_counts = tf.cumsum(tf.cast(x, tf.int32), axis=0)
  return tf.logical_and(x, tf.equal(true_counts, 1))
Example #12
Source File: policy_loader_test.py    From agents with Apache License 2.0 5 votes vote down vote up
def call(self, observation, step_type=None, network_state=(), training=False):
    """Emits 4 identical logits, each `observation + self.var`."""
    del step_type, network_state, training  # Unused.
    # The action spec is BoundedTensorSpec(min=0, max=3), so the network must
    # produce one logit per action, i.e. 4 of them.
    logit = observation + self.var
    return tf.stack([logit] * 4, axis=-1), ()
Example #13
Source File: date_tensor.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def __repr__(self):
    """Returns the shape and, when executing eagerly, the date contents."""
    header = "DateTensor: shape={}".format(self.shape)
    if not tf.executing_eagerly():
      return header
    # Contents are only materialized in eager mode, via .numpy().
    fields = [self._years.numpy(), self._months.numpy(), self._days.numpy()]
    contents_np = np.stack(fields, axis=-1)
    return header + ", contents={}".format(repr(contents_np))
Example #14
Source File: tensor_wrapper.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def stack(cls, tensor_wrappers, axis=0):
    """See tf.stack."""
    cls._validate_tensor_types(tensor_wrappers, "stack")

    def _stack_op(tensors):
      return tf.stack(tensors, axis)

    return cls._apply_sequence_to_tensor_op(_stack_op, tensor_wrappers)
Example #15
Source File: date_utils_test.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def test_ordinal_to_year_month_day(self):
    """Ordinals of the test dates decode back to their (year, month, day)."""
    date_tuples = test_data.test_dates
    ordinal_list = [
        datetime.date(y, m, d).toordinal() for y, m, d in date_tuples]
    ordinals = np.array(ordinal_list, dtype=np.int32)
    years, months, days = date_utils.ordinal_to_year_month_day(ordinals)
    self.assertAllEqual(date_tuples,
                        tf.stack((years, months, days), axis=1))
Example #16
Source File: swap_curve_fit.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def _create_curve_building_tensors(float_leg_start_times,
                                   float_leg_end_times,
                                   fixed_leg_end_times,
                                   pv_settlement_times):
  """Builds the flattened per-cashflow tensors used for curve construction.

  Instruments are indexed by position in `float_leg_start_times`; the other
  arguments are indexed the same way.

  Returns:
    A `CurveFittingVars` with:
      expiry_times: for each instrument, the maximum of the last float leg end
        time and the last fixed leg end time, stacked along axis 0.
      calc_groups_float / calc_groups_fixed: tensors shaped like each
        instrument's float start times / fixed end times, filled with the
        instrument index, then concatenated along axis 0.
      settle_times_float / settle_times_fixed: the same layout, filled with
        the instrument's `pv_settlement_times` entry.
  """
  instruments = range(len(float_leg_start_times))

  expiry_times = tf.stack(
      [tf.math.maximum(float_leg_end_times[i][-1], fixed_leg_end_times[i][-1])
       for i in instruments],
      axis=0)

  def _concat_fills(leg_times, fill_values):
    """Concatenates tf.fill(shape of leg_times[i], fill_values[i]) over i."""
    return tf.concat(
        [tf.fill(tf.shape(leg_times[i]), fill_values[i]) for i in instruments],
        axis=0)

  instrument_ids = list(instruments)
  return CurveFittingVars(
      expiry_times=expiry_times,
      calc_groups_float=_concat_fills(float_leg_start_times, instrument_ids),
      calc_groups_fixed=_concat_fills(fixed_leg_end_times, instrument_ids),
      settle_times_float=_concat_fills(float_leg_start_times,
                                       pv_settlement_times),
      settle_times_fixed=_concat_fills(fixed_leg_end_times,
                                       pv_settlement_times))
Example #17
Source File: swap_curve_fit.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def _initialize_instrument_weights(float_times, fixed_times, dtype):
  """Computes default initial per-instrument weights for the optimization.

  For each instrument the weight is `min(1, max(1 / T_float, 1 / T_fixed))`,
  where `T_float` / `T_fixed` are the last entries of that instrument's float
  and fixed times.

  Args:
    float_times: Sequence with one entry per instrument; only the last element
      of each entry is used.
    fixed_times: Same as `float_times`, for the fixed legs.
    dtype: dtype of the constant `1` used in the computation.

  Returns:
    A list of per-instrument weight tensors (the result of `tf.unstack`).
  """
  # The former `weights = tf.ones(...)` initialization was dead code: it was
  # overwritten before use.
  one = tf.ones([], dtype=dtype)
  float_times_last = tf.stack([times[-1] for times in float_times])
  fixed_times_last = tf.stack([times[-1] for times in fixed_times])
  weights = tf.maximum(one / float_times_last, one / fixed_times_last)
  weights = tf.minimum(one, weights)
  return tf.unstack(weights, name='instrument_weights')
Example #18
Source File: generic_ito_process.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def _coord_grid_to_mesh_grid(coord_grid):
  """Turns per-dimension coordinate tensors into a grid of points.

  Args:
    coord_grid: Sequence of coordinate tensors (presumably one 1-D tensor per
      dimension).

  Returns:
    With a single coordinate tensor, that tensor with a trailing axis of size
    1. Otherwise, the `tf.meshgrid(..., indexing='ij')` outputs stacked along
    a new last axis.
  """
  if len(coord_grid) != 1:
    mesh = tf.meshgrid(*coord_grid, indexing='ij')
    return tf.stack(values=mesh, axis=-1)
  return tf.expand_dims(coord_grid[0], -1)
Example #19
Source File: univariate_geometric_brownian_motion.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def _coord_grid_to_mesh_grid(coord_grid):
  """Turns per-dimension coordinate tensors into a grid of points.

  Args:
    coord_grid: Sequence of coordinate tensors (presumably one 1-D tensor per
      dimension).

  Returns:
    With a single coordinate tensor, that tensor with a trailing axis of size
    1. Otherwise, the `tf.meshgrid(..., indexing="ij")` outputs stacked along
    a new last axis.
  """
  if len(coord_grid) != 1:
    mesh = tf.meshgrid(*coord_grid, indexing="ij")
    return tf.stack(values=mesh, axis=-1)
  return tf.expand_dims(coord_grid[0], -1)
Example #20
Source File: multivariate_geometric_brownian_motion.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def _coord_grid_to_mesh_grid(coord_grid):
  """Turns per-dimension coordinate tensors into a grid of points.

  Args:
    coord_grid: Sequence of coordinate tensors (presumably one 1-D tensor per
      dimension).

  Returns:
    With a single coordinate tensor, that tensor with a trailing axis of size
    1. Otherwise, the `tf.meshgrid(..., indexing="ij")` outputs stacked along
    a new last axis.
  """
  if len(coord_grid) != 1:
    mesh = tf.meshgrid(*coord_grid, indexing="ij")
    return tf.stack(values=mesh, axis=-1)
  return tf.expand_dims(coord_grid[0], -1)
Example #21
Source File: cubic_interpolation_test.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def test_spline_broadcast_batch(self, optimize_for_tpu):
    """Tests batch shape of spline and interpolation are broadcasted."""
    dtype = np.float64
    x_data = np.array([np.linspace(-5.0, 5.0, num=11),
                       np.linspace(0.0, 10.0, num=11)])
    y_data = 1.0 / (2.0 + x_data**2)
    spline = tff.math.interpolation.cubic.build_spline(
        tf.stack(x_data, axis=0), y_data)

    def _interpolate(x_values):
      return tff.math.interpolation.cubic.interpolate(
          x_values, spline, optimize_for_tpu=optimize_for_tpu, dtype=dtype)

    result_1 = _interpolate(tf.constant([[[-1.2, 0.0, 0.3]]], dtype=dtype))
    result_2 = _interpolate(tf.constant([-1.2, 0.0, 0.3], dtype=dtype))
    expected_2 = np.array([[0.29131469, 0.5, 0.4779499],
                           [0.5, 0.5, 0.45159077]], dtype=dtype)
    # Same values with an extra leading batch dimension of size 1.
    expected_1 = expected_2[np.newaxis]
    with self.subTest("BroadcastData"):
      self.assertAllClose(result_1, expected_1)
    with self.subTest("BroadcastValues"):
      self.assertAllClose(result_2, expected_2)
Example #22
Source File: conjugate_gradient_test.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def test_data_fitting(self):
    """Tests MLE estimation for a simple geometric GLM.

    Synthetic data is drawn from a seeded numpy RNG, and the minimizer of
    `neg_log_likelihood` is compared against hard-coded `expected_argmin`.
    The order of the np.random calls below therefore must not change.
    """
    n, dim = 100, 3
    dtype = tf.float64
    np.random.seed(234095)
    # Binary design matrix of shape [dim, n].
    x = np.random.choice([0, 1], size=[dim, n])
    s = 0.01 * np.sum(x, 0)  # Shape [n].
    # Per-sample success probability passed to np.random.geometric.
    p = 1. / (1 + np.exp(-s))
    y = np.random.geometric(p)  # Shape [n].
    x_data = tf.convert_to_tensor(value=x, dtype=dtype)
    y_data = tf.expand_dims(tf.convert_to_tensor(value=y, dtype=dtype), -1)  # [n, 1]

    def neg_log_likelihood(state):
      # Objective being minimized; `state` presumably has shape [dim], matching
      # `start_point` below.
      state_ext = tf.expand_dims(state, 0)  # [1, dim]
      linear_part = tf.matmul(state_ext, x_data)  # [1, n]
      # Stacking with zeros makes reduce_logsumexp compute log(1 + exp(.)).
      linear_part_ex = tf.stack([tf.zeros_like(linear_part), linear_part],
                                axis=0)
      term1 = tf.squeeze(
          tf.matmul(tf.reduce_logsumexp(linear_part_ex, axis=0), y_data), -1)
      # Quadratic term on the state minus the summed linear predictor.
      term2 = (0.5 * tf.reduce_sum(state_ext * state_ext, axis=-1) -
               tf.reduce_sum(linear_part, axis=-1))
      return tf.squeeze(term1 + term2)

    self._check_algorithm(
        func=neg_log_likelihood,
        start_point=np.ones(shape=[dim]),
        expected_argmin=[-0.020460034354, 0.171708568111, 0.021200423717])
Example #23
Source File: diff_ops_test.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def test_diffs_differentiable(self):
    """Tests that the diffs op is differentiable."""
    x = tf.constant(2.0)
    powers = tf.stack([x, x * x, x * x * x], axis=0)  # [x, x^2, x^3]
    diffs = math.diff(powers)  # [x, x^2 - x, x^3 - x^2]

    np.testing.assert_array_equal(self.evaluate(diffs), [2., 2., 4.])

    # TF gradients add up the components of the jacobian: the sum of
    # [1, 2x - 1, 3x^2 - 2x] at x = 2 is 12.
    total_grad = self.evaluate(tf.gradients(diffs, x)[0])
    self.assertEqual(total_grad, 12.0)
Example #24
Source File: gradient_test.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def test_backward_gradient(self):
    """Backward gradient sums the derivatives of t, t^2 and t^3."""
    t = tf.range(1, 3, dtype=tf.float32)  # Shape [2]
    func = lambda t: tf.stack([t, t**2, t**3], axis=0)  # Shape [3, 2]
    expected = [6., 17.]  # 1 + 2t + 3t^2 at t = 1 and t = 2.

    with self.subTest("EagerExecution"):
      eager_grad = self.evaluate(tff.math.gradients(func, t))
      self.assertEqual(eager_grad.shape, (2,))
      np.testing.assert_allclose(eager_grad, expected)

    with self.subTest("GraphExecution"):

      @tf.function
      def grad_computation():
        return tff.math.gradients(func(t), t)

      graph_grad = self.evaluate(grad_computation())
      self.assertEqual(graph_grad.shape, (2,))
      np.testing.assert_allclose(graph_grad, expected)
Example #25
Source File: gradient_test.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def test_forward_gradient(self):
    """Forward gradient gives per-element derivatives of t, t^2 and t^3."""
    t = tf.range(1, 3, dtype=tf.float32)  # Shape [2]
    func = lambda t: tf.stack([t, t**2, t**3], axis=0)  # Shape [3, 2]
    # Rows hold 1, 2t and 3t^2 evaluated at t = 1 and t = 2.
    expected = [[1., 1.], [2., 4.], [3., 12.]]

    with self.subTest("EagerExecution"):
      eager_grad = self.evaluate(tff.math.fwd_gradient(func, t))
      self.assertEqual(eager_grad.shape, (3, 2))
      np.testing.assert_allclose(eager_grad, expected)

    with self.subTest("GraphExecution"):

      @tf.function
      def grad_computation():
        return tff.math.fwd_gradient(func(t), t)

      graph_grad = self.evaluate(grad_computation())
      self.assertEqual(graph_grad.shape, (3, 2))
      np.testing.assert_allclose(graph_grad, expected)
Example #26
Source File: arrays.py    From trax with Apache License 2.0 5 votes vote down vote up
def convert_sharded_tensor_to_eager_tensor(value, *args, **kwargs):
  """Stacks the component tensors in `value.tensors` into a single tensor.

  Extra positional and keyword arguments are accepted and ignored
  (presumably to fit a conversion-function signature — confirm with the
  registration site).
  """
  del args, kwargs  # Unused.
  # TODO(nareshmodi): Consider a collective op to gather the tensors from the
  # various devices for performance reasons.
  shards = value.tensors
  return tf.stack(shards)
Example #27
Source File: array_ops.py    From trax with Apache License 2.0 5 votes vote down vote up
def stack(arrays, axis=0):
  """Stacks `arrays` along a new `axis`.

  Inputs first go through `_promote_dtype` (presumably to a common dtype).
  Any `arrays_lib.ndarray` is unwrapped to its `.data`, and the `tf.stack`
  result is wrapped with `asarray`.
  """
  promoted = _promote_dtype(*arrays)  # pylint: disable=protected-access
  tensors = []
  for arr in promoted:
    tensors.append(arr.data if isinstance(arr, arrays_lib.ndarray) else arr)
  return asarray(tf.stack(tensors, axis))
Example #28
Source File: extensions.py    From trax with Apache License 2.0 5 votes vote down vote up
def convert_sharded_tensor_to_eager_tensor(value, *args, **kwargs):
  """Stacks the component tensors in `value.tensors` into a single tensor.

  Extra positional and keyword arguments are accepted and ignored
  (presumably to fit a conversion-function signature — confirm with the
  registration site).
  """
  del args, kwargs  # Unused.
  # TODO(nareshmodi): Consider a collective op to gather the tensors from the
  # various devices for performance reasons.
  shards = value.tensors
  return tf.stack(shards)
Example #29
Source File: base_agent.py    From valan with Apache License 2.0 4 votes vote down vote up
def call(self, env_output, neck_state):
    """Runs the entire episode given time-major tensors.

    Torso and head are applied through `utils.batch_apply`. The neck is
    unrolled one timestep at a time in a Python loop, threading `neck_state`
    through the steps and swapping in the reset state wherever `done` is True
    for that step.

    Args:
      env_output: An `EnvOutput` tuple with following expectations:
        reward - Unused
        done - A boolean tensor of shape  [num_timesteps, batch_size].
        observation - A nested structure with individual tensors that have first
          two dimensions equal to [num_timesteps, batch_size]
        info - Unused
      neck_state: A tensor or nested structure with individual tensors that have
        first dimension equal to batch_size and no time dimension.

    Returns:
      A tuple `(head_output, neck_state)`. `head_output` is an `AgentOutput`
        tuple with individual tensors whose first two dimensions are
        [num_timesteps, batch_size]. `neck_state` is the neck state after the
        last timestep.
    """
    unused_reward, done, observation, unused_info = env_output
    # Record the current (dynamic) number of timesteps and batch size.
    self._current_num_timesteps = tf.shape(done)[0]
    self._current_batch_size = tf.shape(done)[1]

    torso_output = utils.batch_apply(self._torso, observation)
    # shape: [num_timesteps, batch_size, ...], where the trailing dimensions are
    # same as trailing dimensions of `neck_state`.
    reset_state = self._get_reset_state(observation, done, neck_state)
    neck_output_list = []
    # Unroll over time; `d` is the [batch_size] done mask of one timestep.
    for timestep, d in enumerate(tf.unstack(done)):
      neck_input = utils.get_row_nested_tensor(torso_output, timestep)
      # If the episode ended, the neck state should be reset before the next
      # step.
      curr_timestep_reset_state = utils.get_row_nested_tensor(
          reset_state, timestep)
      # Where `d` is True take the reset state, otherwise keep the carried
      # state. The lambda captures `d` but runs within this iteration, so late
      # binding is not an issue. NOTE(review): compat.v1 `where` is presumably
      # used because it accepts a vector condition selecting whole rows of
      # higher-rank state tensors — confirm.
      neck_state = tf.nest.map_structure(
          lambda reset_state, state: tf.compat.v1.where(d, reset_state, state),
          curr_timestep_reset_state, neck_state)
      neck_output, neck_state = self._neck(neck_input, neck_state)
      neck_output_list.append(neck_output)

    # Re-stack the per-step neck outputs leaf by leaf along a new leading time
    # axis before applying the head.
    head_input = tf.nest.map_structure(lambda *tensors: tf.stack(tensors),
                                       *neck_output_list)
    head_output = utils.batch_apply(self._head, head_input)
    assert isinstance(head_output, common.AgentOutput)
    return head_output, neck_state
Example #30
Source File: array_ops.py    From trax with Apache License 2.0 4 votes vote down vote up
def moveaxis(a, source, destination):  # pylint: disable=missing-docstring
  """Moves axes of `a` to new positions; the other axes keep their order.

  Args:
    a: Array-like input; converted with `asarray`.
    source: int or sequence of ints, the original positions of the axes to
      move. Negative values count from the end.
    destination: int or sequence of ints, the destination positions, one per
      entry of `source`. Negative values count from the end.

  Returns:
    The transposed array, wrapped with `utils.tensor_to_ndarray`.

  Raises:
    ValueError: if `source` and `destination` have different numbers of
      entries. Out-of-range axes are not validated up front.
  """
  a = asarray(a).data

  if isinstance(source, int):
    source = (source,)
  if isinstance(destination, int):
    destination = (destination,)
  source = tuple(source)
  destination = tuple(destination)

  # Checked after the int -> tuple normalization so that e.g.
  # moveaxis(a, 0, 0) is not mistaken for "no axes given". The no-op result is
  # wrapped the same way as every other result.
  if not source and not destination:
    return utils.tensor_to_ndarray(a)
  if len(source) != len(destination):
    # zip() below would otherwise silently drop the unmatched axes.
    raise ValueError(
        '`source` and `destination` must have the same number of elements, '
        'got %s and %s' % (source, destination))

  a_rank = utils._maybe_static(tf.rank(a))  # pylint: disable=protected-access

  def _correct_axis(axis, rank):
    if axis < 0:
      return axis + rank
    return axis

  source = tuple(_correct_axis(axis, a_rank) for axis in source)
  destination = tuple(_correct_axis(axis, a_rank) for axis in destination)

  if a.shape.rank is not None:
    # Static rank: build the permutation in Python. Start from the axes that
    # stay put, then insert each moved axis at its destination (ascending
    # destination order keeps earlier inserts valid).
    perm = [i for i in range(a_rank) if i not in source]
    for dest, src in sorted(zip(destination, source)):
      assert dest <= len(perm)
      perm.insert(dest, src)
  else:
    # Dynamic rank: build the permutation with tensor ops.
    r = tf.range(a_rank)

    def _remove_indices(a, b):
      """Remove indices (`b`) from `a`."""
      items = tf.unstack(tf.sort(tf.stack(b)), num=len(b))

      i = 0
      result = []

      for item in items:
        result.append(a[i:item])
        i = item + 1

      result.append(a[i:])

      return tf.concat(result, 0)

    minus_sources = _remove_indices(r, source)
    minus_dest = _remove_indices(r, destination)

    # Place the untouched axes in the non-destination slots, then scatter the
    # moved axes into their destination slots.
    perm = tf.scatter_nd(tf.expand_dims(minus_dest, 1), minus_sources, [a_rank])
    perm = tf.tensor_scatter_nd_update(perm, tf.expand_dims(destination, 1),
                                       source)
  a = tf.transpose(a, perm)

  return utils.tensor_to_ndarray(a)