Python tensorflow.compat.v2.stack() Examples
The following are 30 code examples of tensorflow.compat.v2.stack().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow.compat.v2, or try the search function.
Example #1
Source File: input_pipeline.py From models with Apache License 2.0 | 6 votes |
def process_multidoc_dataset(dataset, batch_size, params):
  """Decodes, selects pretraining features from, and batches a multi-doc dataset.

  Args:
    dataset: A `tf.data.Dataset` whose elements are passed to `decode_record`.
    batch_size: Number of examples per batch; a trailing partial batch is
      dropped.
    params: Configuration passed to `multidoc_parse_spec`; its `passage_list`
      gives the suffixes of the per-passage feature names.

  Returns:
    A batched `tf.data.Dataset` of feature dicts with a "target_ids" entry
    plus one stacked tensor per feature name from the parse spec.
  """
  parse_spec, feature_names = multidoc_parse_spec(params)
  autotune = tf.data.experimental.AUTOTUNE

  def _decode(record):
    return decode_record(record, parse_spec)

  def _select_data_from_record(record):
    """Keeps `input_ids_a` as the target and stacks per-passage features."""
    selected = {"target_ids": record["input_ids_a"]}
    for name in feature_names:
      per_passage = [record["%s_%s" % (name, idx)]
                     for idx in params.passage_list]
      selected[name] = tf.stack(per_passage)
    return selected

  decoded = dataset.map(_decode, num_parallel_calls=autotune)
  selected_ds = decoded.map(_select_data_from_record,
                            num_parallel_calls=autotune)
  return selected_ds.batch(batch_size, drop_remainder=True)
Example #2
Source File: utils.py From valan with Apache License 2.0 | 6 votes |
def stack_nested_tensors(list_of_nests):
  """Batches identically-structured nests by stacking matching leaves.

  Args:
    list_of_nests: A list of nested structures of tensors (or numpy arrays)
      that all share the same structure and per-leaf shapes.

  Returns:
    One nest with the shared structure, where each leaf is `tf.stack` of the
    corresponding leaves (converted to tensors) from `list_of_nests`.
  """
  def _stack_leaves(*leaves):
    as_tensors = list(map(tf.convert_to_tensor, leaves))
    return tf.stack(as_tensors)

  return tf.nest.map_structure(_stack_leaves, *list_of_nests)
Example #3
Source File: gradient_test.py From tf-quant-finance with Apache License 2.0 | 6 votes |
def test_forward_unconnected_gradient(self):
  """Forward gradient of an output independent of `t` is all zeros."""
  t = tf.range(1, 3, dtype=tf.float32)  # Shape [2]
  zeros = tf.zeros([2], dtype=t.dtype)
  unconnected = tf.UnconnectedGradients.ZERO
  expected_result = [[0.0, 0.0]] * 3

  def func(t):
    # Ignores its argument; shape [3, 2].
    return tf.stack([zeros] * 3, axis=0)

  with self.subTest("EagerExecution"):
    fwd_grad = self.evaluate(
        tff.math.fwd_gradient(func, t, unconnected_gradients=unconnected))
    self.assertEqual(fwd_grad.shape, (3, 2))
    np.testing.assert_allclose(fwd_grad, expected_result)

  with self.subTest("GraphExecution"):
    @tf.function
    def grad_computation():
      return tff.math.fwd_gradient(
          func(t), t, unconnected_gradients=unconnected)

    fwd_grad = self.evaluate(grad_computation())
    self.assertEqual(fwd_grad.shape, (3, 2))
    np.testing.assert_allclose(fwd_grad, expected_result)
Example #4
Source File: util.py From language with Apache License 2.0 | 6 votes |
def labels_of_top_ranked_predictions_in_batch(labels, predictions):
  """Looks up, for each row, the label at the top-scoring prediction.

  Applying tf.metrics.mean to the result gives precision at 1.

  Args:
    labels: minibatch of dense 0/1 labels, shape [batch_size rows, num_classes]
    predictions: minibatch of predictions of the same shape

  Returns:
    one-dimension tensor top_labels, where top_labels[i]=1.0 iff the
    top-scoring prediction for batch element i has label 1.0
  """
  top_pred_cols = tf.cast(tf.argmax(input=predictions, axis=1), tf.int32)
  num_rows = tf.reduce_sum(input_tensor=tf.ones_like(top_pred_cols))
  # Any positive label counts as 1.0; everything else as 0.0.
  binary_labels = tf.where(labels > 0.0,
                           tf.ones_like(labels),
                           tf.zeros_like(labels))
  # [batch_size, 2] (row, column) coordinates of each row's top prediction.
  coords = tf.stack([tf.range(num_rows), top_pred_cols], axis=1)
  return tf.gather_nd(binary_labels, coords)
Example #5
Source File: gradient_test.py From tf-quant-finance with Apache License 2.0 | 6 votes |
def test_backward_unconnected_gradient(self):
  """Backward gradient of an output independent of `t` is all zeros."""
  t = tf.range(1, 3, dtype=tf.float32)  # Shape [2]
  zeros = tf.zeros([2], dtype=t.dtype)
  unconnected = tf.UnconnectedGradients.ZERO
  expected_result = [0.0, 0.0]

  def func(t):
    # Ignores its argument; shape [3, 2].
    return tf.stack([zeros] * 3, axis=0)

  with self.subTest("EagerExecution"):
    backward_grad = self.evaluate(
        tff.math.gradients(func, t, unconnected_gradients=unconnected))
    self.assertEqual(backward_grad.shape, (2,))
    np.testing.assert_allclose(backward_grad, expected_result)

  with self.subTest("GraphExecution"):
    @tf.function
    def grad_computation():
      return tff.math.gradients(
          func(t), t, unconnected_gradients=unconnected)

    backward_grad = self.evaluate(grad_computation())
    self.assertEqual(backward_grad.shape, (2,))
    np.testing.assert_allclose(backward_grad, expected_result)
Example #6
Source File: imagenet2012_corrupted.py From datasets with Apache License 2.0 | 6 votes |
def _decode_and_center_crop(image_bytes):
  """Decodes a padded center crop of a JPEG and resizes it.

  The crop is a square whose side is the shorter image side scaled by
  _IMAGE_SIZE / (_IMAGE_SIZE + _CROP_PADDING). Only that window is decoded
  (3 channels). It is then bicubically resized to
  [_IMAGE_SIZE, _IMAGE_SIZE] and cast to int32.

  Args:
    image_bytes: Encoded JPEG bytes.

  Returns:
    An int32 image tensor of spatial size [_IMAGE_SIZE, _IMAGE_SIZE].
  """
  jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
  height, width = jpeg_shape[0], jpeg_shape[1]

  crop_fraction = _IMAGE_SIZE / (_IMAGE_SIZE + _CROP_PADDING)
  shorter_side = tf.cast(tf.minimum(height, width), tf.float32)
  crop_size = tf.cast(crop_fraction * shorter_side, tf.int32)

  top = ((height - crop_size) + 1) // 2
  left = ((width - crop_size) + 1) // 2
  window = tf.stack([top, left, crop_size, crop_size])

  cropped = tf.image.decode_and_crop_jpeg(image_bytes, window, channels=3)
  # tf.image.resize is given a batch of one image; take it back out.
  resized = tf.image.resize([cropped], [_IMAGE_SIZE, _IMAGE_SIZE],
                            method=tf.image.ResizeMethod.BICUBIC)[0]
  return tf.cast(resized, tf.int32)
Example #7
Source File: date_tensor.py From tf-quant-finance with Apache License 2.0 | 6 votes |
def to_tensor(self):
  """Returns the dates packed into one Tensor of shape `shape + (3,)`.

  The trailing dimension holds (year, month, day), in that order. This is
  handy when the dates are the final result of a graph-mode computation: a
  `tf.function` can return `date_tensor.to_tensor()`, or, with
  `tf.compat.v1.Session`, one can call
  `session.run(date_tensor.to_tensor())`.

  Returns:
    A Tensor of shape `date_tensor.shape() + (3,)`.

  #### Example

  ```python
  dates = tff.datetime.dates_from_tuples([(2019, 1, 25), (2020, 3, 2)])
  dates.to_tensor()  # tf.Tensor with contents [[2019, 1, 25], [2020, 3, 2]].
  ```
  """
  components = (self.year(), self.month(), self.day())
  return tf.stack(components, axis=-1)
Example #8
Source File: extensions.py From trax with Apache License 2.0 | 6 votes |
def _key2seed(a):
  """Converts an RNG key to an RNG seed.

  The key's 64 bits are split into two 32-bit halves:
  [cast-to-uint32 part, right-shifted-by-32 part].

  Args:
    a: an RNG key, an ndarray of shape [] and dtype `np.int64`.

  Returns:
    an RNG seed, a tensor of shape [2] and dtype `tf.int32`.
  """
  # Work on an unsigned value so the right shift below is unsigned.
  key = tf.cast(a.data, tf.uint64)
  first = tf.cast(key, tf.uint32)
  second = tf.cast(
      tf.bitwise.right_shift(key, tf.constant(32, tf.uint64)), tf.uint32)
  halves = [tf.cast(half, tf.int32) for half in (first, second)]
  return tf.stack(halves)
Example #9
Source File: hull_white_test.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def test_mean_variance_correlation_generic_2d(self):
  """Tests model with generic parameters in 2 dimensions."""
  for dtype in [tf.float32, tf.float64]:
    # Mean reversion without batch dimension
    mean_reversion = tff.math.piecewise.PiecewiseConstantFunc(
        [0.1, 2.0], values=3 * [self.mean_reversion], dtype=dtype)
    # Volatility with batch dimension
    volatility = tff.math.piecewise.PiecewiseConstantFunc(
        [[0.1, 0.2, 0.5], [0.1, 2.0, 3.0]],
        values=[4 * [self.volatility[0]], 4 * [self.volatility[1]]],
        dtype=dtype)

    def corr_matrix(t):
      """Returns [[1, 0.5 t], [0.5 t, 1]] for each element of `t`."""
      one = tf.ones_like(t)
      row1 = tf.stack([one, 0.5 * t], axis=-1)
      # Swap the two entries of each row (the last axis) to get [0.5 t, 1].
      # Reversing axis 0 only does this when `t` is a scalar; with a batched
      # `t` it would reverse the batch instead.
      row2 = tf.reverse(row1, [-1])
      return tf.stack([row1, row2], axis=-1)

    process = tff.models.hull_white.VectorHullWhiteModel(
        dim=2,
        mean_reversion=mean_reversion,
        volatility=volatility,
        corr_matrix=corr_matrix,
        initial_discount_rate_fn=self.instant_forward_rate_2d_fn,
        dtype=dtype)
    times = [0.1, 0.5]
    paths = process.sample_paths(
        times,
        num_samples=50000,
        random_type=tff.math.random.RandomType.SOBOL,
        skip=100000,
        time_step=0.01)
    self.assertEqual(paths.dtype, dtype)
    self.assertAllEqual(paths.shape, [50000, 2, 2])
    paths = self.evaluate(paths)
    paths = paths[:, -1, :]  # Extract paths values for the terminal time
    mean = np.mean(paths, axis=0)
    variance = np.var(paths, axis=0)
    self.assertAllClose(mean, self.true_mean(times[-1]),
                        rtol=1e-3, atol=1e-3)
    self.assertAllClose(variance, self.true_var(times[-1]),
                        rtol=1e-3, atol=1e-3)
Example #10
Source File: utils.py From valan with Apache License 2.0 | 5 votes |
def parallel_conv2d(inputs, filters, strides, padding):
  """Convolves each input in a batch with its own filter.

  tf.nn.conv2d applies one shared filter to a whole batch; here example i of
  `inputs` is convolved with `filters[i]`.

  Notation: B batch size; H, W spatial dims (input or output); C1 input
  channels; C2 output channels; KHxKW kernel spatial size.

  Args:
    inputs: BxHxWxC1 tensor - batch of input "images"
    filters: BxKHxKWxC1xC2 tensor - batch of convolutional kernels
    strides: See tf.nn.conv2d arg: strides
    padding: See tf.nn.conv2d arg: padding

  Returns:
    Tensor of shape BxHxWxC2
  """
  batch_size = inputs.shape[0]
  per_example = []
  for idx in range(batch_size):
    # Slicing idx:idx + 1 keeps a leading batch dim of 1 for tf.nn.conv2d.
    single_input = inputs[idx:idx + 1]
    per_example.append(
        tf.nn.conv2d(single_input, filters[idx], strides, padding))

  stacked = tf.stack(per_example, axis=0)
  # Each stacked slice carries its own batch dim of 1; drop it.
  assert stacked.shape[1] == 1, 'Each slice should have batch size of 1'
  output = stacked[:, 0, :, :, :]

  # NOTE(review): the expected shape reuses the input's H and W. This
  # presumably fails whenever strides/padding change the spatial size;
  # confirm the intended callers only use shape-preserving settings.
  expected_shape = (batch_size, inputs.shape[1], inputs.shape[2],
                    filters.shape[4])
  assert_shape(expected_shape, output.shape)
  return output
Example #11
Source File: utils.py From valan with Apache License 2.0 | 5 votes |
def get_first_true_column(x):
  """Keeps only the first True in each column of `x`; all else is False.

  If x is
  [[True, False, False],
   [True, False, False],
   [False, True, False],
   [False, True, True]]
  the output is
  [[True, False, False],
   [False, False, False],
   [False, True, False],
   [False, False, True]]

  A column containing no True stays all False.

  Args:
    x: A bool tensor with shape [num_steps, batch_size].

  Returns:
    A bool tensor with the same shape.
  """
  # Running count of Trues down each column (axis 0). The first True in a
  # column is the True entry where that count equals exactly 1. This works
  # for all-False columns and does not need a static shape.
  running_count = tf.cumsum(tf.cast(x, tf.int32), axis=0)
  return tf.logical_and(x, tf.equal(running_count, 1))
Example #12
Source File: policy_loader_test.py From agents with Apache License 2.0 | 5 votes |
def call(self, observation, step_type=None, network_state=(), training=False):
  """Emits 4 identical logits, each equal to `observation + self.var`.

  Returns:
    A tuple (logits stacked on a new last axis of size 4, empty state).
  """
  del step_type, network_state, training  # Unused.
  # The action spec is BoundedTensorSpec(min=0, max=3), i.e. 4 actions, so
  # the network emits one logit per action along the last axis.
  logit = observation + self.var
  return tf.stack([logit] * 4, axis=-1), ()
Example #13
Source File: date_tensor.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def __repr__(self):
  """Shows the shape, plus the (year, month, day) contents when eager."""
  summary = "DateTensor: shape={}".format(self.shape)
  if not tf.executing_eagerly():
    # Values are not available outside eager mode.
    return summary
  contents = np.stack(
      (self._years.numpy(), self._months.numpy(), self._days.numpy()),
      axis=-1)
  return summary + ", contents={}".format(repr(contents))
Example #14
Source File: tensor_wrapper.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def stack(cls, tensor_wrappers, axis=0):
  """Stacks wrapped tensors along a new `axis` (mirrors tf.stack)."""
  cls._validate_tensor_types(tensor_wrappers, "stack")

  def _stack_op(tensors):
    return tf.stack(tensors, axis)

  return cls._apply_sequence_to_tensor_op(_stack_op, tensor_wrappers)
Example #15
Source File: date_utils_test.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def test_ordinal_to_year_month_day(self):
  """Converting ordinals back to (year, month, day) recovers the test dates."""
  date_tuples = test_data.test_dates
  ordinal_list = [datetime.date(year, month, day).toordinal()
                  for year, month, day in date_tuples]
  ordinals = np.array(ordinal_list, dtype=np.int32)
  years, months, days = date_utils.ordinal_to_year_month_day(ordinals)
  self.assertAllEqual(date_tuples, tf.stack((years, months, days), axis=1))
Example #16
Source File: swap_curve_fit.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _create_curve_building_tensors(float_leg_start_times,
                                   float_leg_end_times,
                                   fixed_leg_end_times,
                                   pv_settlement_times):
  """Helper function to create tensors needed for curve construction.

  For instrument i:
    * expiry time = max(last float-leg end time, last fixed-leg end time);
    * calc groups label each float/fixed cashflow with i;
    * settle times repeat pv_settlement_times[i] once per cashflow.
  Per-instrument pieces are then stacked (expiries) or concatenated
  (everything else).

  Returns:
    A `CurveFittingVars` holding the five resulting tensors.
  """
  expiry_times = []
  calc_groups_float = []
  calc_groups_fixed = []
  settle_times_float = []
  settle_times_fixed = []

  for idx in range(len(float_leg_start_times)):
    float_shape = tf.shape(float_leg_start_times[idx])
    fixed_shape = tf.shape(fixed_leg_end_times[idx])
    settle_time = pv_settlement_times[idx]

    expiry_times.append(
        tf.math.maximum(float_leg_end_times[idx][-1],
                        fixed_leg_end_times[idx][-1]))
    calc_groups_float.append(tf.fill(float_shape, idx))
    calc_groups_fixed.append(tf.fill(fixed_shape, idx))
    settle_times_float.append(tf.fill(float_shape, settle_time))
    settle_times_fixed.append(tf.fill(fixed_shape, settle_time))

  return CurveFittingVars(
      expiry_times=tf.stack(expiry_times, axis=0),
      calc_groups_float=tf.concat(calc_groups_float, axis=0),
      calc_groups_fixed=tf.concat(calc_groups_fixed, axis=0),
      settle_times_float=tf.concat(settle_times_float, axis=0),
      settle_times_fixed=tf.concat(settle_times_fixed, axis=0))
Example #17
Source File: swap_curve_fit.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _initialize_instrument_weights(float_times, fixed_times, dtype):
  """Function to compute default initial weights for optimization.

  For each instrument the weight is
  min(1, max(1 / last_float_time, 1 / last_fixed_time)), where the "last"
  times are the final entries of that instrument's float and fixed schedules.

  Args:
    float_times: Sequence of per-instrument float-leg time tensors.
    fixed_times: Sequence of per-instrument fixed-leg time tensors.
    dtype: dtype of the returned weights.

  Returns:
    A list of scalar weight tensors, one per instrument (via `tf.unstack`).
  """
  # (A previous `tf.ones(len(float_times))` initial value was always
  # overwritten below, so it has been dropped.)
  one = tf.ones([], dtype=dtype)
  float_times_last = tf.stack([times[-1] for times in float_times])
  fixed_times_last = tf.stack([times[-1] for times in fixed_times])
  weights = tf.maximum(one / float_times_last, one / fixed_times_last)
  weights = tf.minimum(one, weights)
  return tf.unstack(weights, name='instrument_weights')
Example #18
Source File: generic_ito_process.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _coord_grid_to_mesh_grid(coord_grid):
  """Builds an 'ij'-indexed mesh from a list of coordinate grids.

  A single grid is returned with a trailing axis of size 1. Otherwise, the
  tf.meshgrid outputs are stacked along a new last axis.
  """
  if len(coord_grid) != 1:
    mesh = tf.meshgrid(*coord_grid, indexing='ij')
    return tf.stack(values=mesh, axis=-1)
  return tf.expand_dims(coord_grid[0], -1)
Example #19
Source File: univariate_geometric_brownian_motion.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _coord_grid_to_mesh_grid(coord_grid):
  """Builds an "ij"-indexed mesh from a list of coordinate grids.

  A single grid is returned with a trailing axis of size 1. Otherwise, the
  tf.meshgrid outputs are stacked along a new last axis.
  """
  if len(coord_grid) != 1:
    mesh = tf.meshgrid(*coord_grid, indexing="ij")
    return tf.stack(values=mesh, axis=-1)
  return tf.expand_dims(coord_grid[0], -1)
Example #20
Source File: multivariate_geometric_brownian_motion.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _coord_grid_to_mesh_grid(coord_grid):
  """Builds an "ij"-indexed mesh from a list of coordinate grids.

  A single grid is returned with a trailing axis of size 1. Otherwise, the
  tf.meshgrid outputs are stacked along a new last axis.
  """
  if len(coord_grid) != 1:
    mesh = tf.meshgrid(*coord_grid, indexing="ij")
    return tf.stack(values=mesh, axis=-1)
  return tf.expand_dims(coord_grid[0], -1)
Example #21
Source File: cubic_interpolation_test.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def test_spline_broadcast_batch(self, optimize_for_tpu):
  """Tests that spline and interpolation batch shapes are broadcast."""
  dtype = np.float64
  x_data = np.array([np.linspace(-5.0, 5.0, num=11),
                     np.linspace(0.0, 10.0, num=11)])
  y_data = 1.0 / (2.0 + x_data**2)
  x_data = tf.stack(x_data, axis=0)

  spline = tff.math.interpolation.cubic.build_spline(x_data, y_data)

  def _interpolate(x_value):
    return tff.math.interpolation.cubic.interpolate(
        x_value, spline, optimize_for_tpu=optimize_for_tpu, dtype=dtype)

  # Extra leading dims on the values vs. values that are plain 1-D.
  result_1 = _interpolate(tf.constant([[[-1.2, 0.0, 0.3]]], dtype=dtype))
  result_2 = _interpolate(tf.constant([-1.2, 0.0, 0.3], dtype=dtype))

  expected_rows = [[0.29131469, 0.5, 0.4779499],
                   [0.5, 0.5, 0.45159077]]
  expected_1 = np.array([expected_rows], dtype=dtype)
  expected_2 = np.array(expected_rows, dtype=dtype)

  with self.subTest("BroadcastData"):
    self.assertAllClose(result_1, expected_1)
  with self.subTest("BroadcastValues"):
    self.assertAllClose(result_2, expected_2)
Example #22
Source File: conjugate_gradient_test.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def test_data_fitting(self):
  """Tests MLE estimation for a simple geometric GLM."""
  n, dim = 100, 3
  dtype = tf.float64
  np.random.seed(234095)
  features = np.random.choice([0, 1], size=[dim, n])
  scaled = 0.01 * np.sum(features, 0)
  success_prob = 1. / (1 + np.exp(-scaled))
  counts = np.random.geometric(success_prob)

  x_data = tf.convert_to_tensor(value=features, dtype=dtype)
  y_data = tf.expand_dims(tf.convert_to_tensor(value=counts, dtype=dtype), -1)

  def neg_log_likelihood(state):
    state_row = tf.expand_dims(state, 0)
    linear_part = tf.matmul(state_row, x_data)
    # logsumexp over {0, linear_part} is log(1 + exp(linear_part)).
    paired = tf.stack([tf.zeros_like(linear_part), linear_part], axis=0)
    log_one_plus_exp = tf.reduce_logsumexp(paired, axis=0)
    term1 = tf.squeeze(tf.matmul(log_one_plus_exp, y_data), -1)
    term2 = (0.5 * tf.reduce_sum(state_row * state_row, axis=-1) -
             tf.reduce_sum(linear_part, axis=-1))
    return tf.squeeze(term1 + term2)

  self._check_algorithm(
      func=neg_log_likelihood,
      start_point=np.ones(shape=[dim]),
      expected_argmin=[-0.020460034354, 0.171708568111, 0.021200423717])
Example #23
Source File: diff_ops_test.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def test_diffs_differentiable(self):
  """Tests that the diffs op is differentiable."""
  x = tf.constant(2.0)
  powers = tf.stack([x, x * x, x * x * x], axis=0)

  # diff yields [x, x^2 - x, x^3 - x^2], i.e. [2, 2, 4] at x = 2.
  diffs = self.evaluate(math.diff(powers))
  np.testing.assert_array_equal(diffs, [2., 2., 4.])

  # tf.gradients sums the Jacobian rows: 1 + (2x - 1) + (3x^2 - 2x) = 12 at
  # x = 2.
  grad = self.evaluate(tf.gradients(math.diff(powers), x)[0])
  self.assertEqual(grad, 12.0)
Example #24
Source File: gradient_test.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def test_backward_gradient(self):
  """Backward gradient of [t, t^2, t^3] is 1 + 2t + 3t^2 per element."""
  t = tf.range(1, 3, dtype=tf.float32)  # Shape [2]
  expected = [6., 17.]

  def func(t):
    return tf.stack([t, t ** 2, t ** 3], axis=0)  # Shape [3, 2]

  with self.subTest("EagerExecution"):
    backward_grad = self.evaluate(tff.math.gradients(func, t))
    self.assertEqual(backward_grad.shape, (2,))
    np.testing.assert_allclose(backward_grad, expected)

  with self.subTest("GraphExecution"):
    @tf.function
    def grad_computation():
      return tff.math.gradients(func(t), t)

    backward_grad = self.evaluate(grad_computation())
    self.assertEqual(backward_grad.shape, (2,))
    np.testing.assert_allclose(backward_grad, expected)
Example #25
Source File: gradient_test.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def test_forward_gradient(self):
  """Forward gradient of [t, t^2, t^3] is [1, 2t, 3t^2] element-wise."""
  t = tf.range(1, 3, dtype=tf.float32)  # Shape [2]
  expected = [[1., 1.], [2., 4.], [3., 12.]]

  def func(t):
    return tf.stack([t, t ** 2, t ** 3], axis=0)  # Shape [3, 2]

  with self.subTest("EagerExecution"):
    fwd_grad = self.evaluate(tff.math.fwd_gradient(func, t))
    self.assertEqual(fwd_grad.shape, (3, 2))
    np.testing.assert_allclose(fwd_grad, expected)

  with self.subTest("GraphExecution"):
    @tf.function
    def grad_computation():
      return tff.math.fwd_gradient(func(t), t)

    fwd_grad = self.evaluate(grad_computation())
    self.assertEqual(fwd_grad.shape, (3, 2))
    np.testing.assert_allclose(fwd_grad, expected)
Example #26
Source File: arrays.py From trax with Apache License 2.0 | 5 votes |
def convert_sharded_tensor_to_eager_tensor(value, *args, **kwargs):
  """Stacks the component tensors of `value` (its `.tensors`) into one.

  Extra positional/keyword arguments are accepted and ignored (presumably to
  fit a tensor-conversion-function signature; confirm at registration).
  """
  del args, kwargs  # Unused.
  # TODO(nareshmodi): Consider a collective op to gather the tensors from the
  # various devices for performance reasons.
  components = value.tensors
  return tf.stack(components)
Example #27
Source File: array_ops.py From trax with Apache License 2.0 | 5 votes |
def stack(arrays, axis=0):
  """Stacks `arrays` along a new `axis` after `_promote_dtype`.

  Wrapped `arrays_lib.ndarray` inputs are unwrapped to their `.data` tensor
  first; the stacked result is re-wrapped via `asarray`.
  """
  promoted = _promote_dtype(*arrays)  # pylint: disable=protected-access
  raw_tensors = []
  for item in promoted:
    raw_tensors.append(
        item.data if isinstance(item, arrays_lib.ndarray) else item)
  return asarray(tf.stack(raw_tensors, axis))
Example #28
Source File: extensions.py From trax with Apache License 2.0 | 5 votes |
def convert_sharded_tensor_to_eager_tensor(value, *args, **kwargs):
  """Stacks the component tensors of `value` (its `.tensors`) into one.

  Extra positional/keyword arguments are accepted and ignored (presumably to
  fit a tensor-conversion-function signature; confirm at registration).
  """
  del args, kwargs  # Unused.
  # TODO(nareshmodi): Consider a collective op to gather the tensors from the
  # various devices for performance reasons.
  components = value.tensors
  return tf.stack(components)
Example #29
Source File: base_agent.py From valan with Apache License 2.0 | 4 votes |
def call(self, env_output, neck_state):
  """Runs the entire episode given time-major tensors.

  The torso and head are applied to all timesteps at once (via
  `utils.batch_apply`). The neck is stepped one timestep at a time, and its
  state is replaced by the reset state wherever `done` is True.

  Args:
    env_output: An `EnvOutput` tuple with following expectations:
      reward - Unused
      done - A boolean tensor of shape [num_timesteps, batch_size].
      observation - A nested structure with individual tensors that have first
        two dimensions equal to [num_timesteps, batch_size]
      info - Unused
    neck_state: A tensor or nested structure with individual tensors that
      have first dimension equal to batch_size and no time dimension.

  Returns:
    A tuple `(head_output, neck_state)`:
      head_output - An `AgentOutput` tuple with individual tensors that have
        first two dimensions equal to [num_timesteps, batch_size].
      neck_state - The neck state after the final timestep.
  """
  unused_reward, done, observation, unused_info = env_output
  # Add current time_step and batch_size.
  self._current_num_timesteps = tf.shape(done)[0]
  self._current_batch_size = tf.shape(done)[1]
  torso_output = utils.batch_apply(self._torso, observation)
  # shape: [num_timesteps, batch_size, ...], where the trailing dimensions are
  # same as trailing dimensions of `neck_state`.
  reset_state = self._get_reset_state(observation, done, neck_state)
  neck_output_list = []
  # tf.unstack splits `done` along the time axis; it needs num_timesteps to
  # be inferable from the static shape.
  for timestep, d in enumerate(tf.unstack(done)):
    neck_input = utils.get_row_nested_tensor(torso_output, timestep)
    # If the episode ended, the neck state should be reset before the next
    # step.
    curr_timestep_reset_state = utils.get_row_nested_tensor(
        reset_state, timestep)
    # The lambda is invoked immediately by map_structure, so capturing the
    # loop variable `d` here is safe.
    neck_state = tf.nest.map_structure(
        lambda reset_state, state: tf.compat.v1.where(d, reset_state, state),
        curr_timestep_reset_state, neck_state)
    neck_output, neck_state = self._neck(neck_input, neck_state)
    neck_output_list.append(neck_output)
  # Re-assemble per-timestep outputs into time-major tensors (new axis 0).
  head_input = tf.nest.map_structure(lambda *tensors: tf.stack(tensors),
                                     *neck_output_list)
  head_output = utils.batch_apply(self._head, head_input)
  assert isinstance(head_output, common.AgentOutput)
  return head_output, neck_state
Example #30
Source File: array_ops.py From trax with Apache License 2.0 | 4 votes |
def moveaxis(a, source, destination):  # pylint: disable=missing-docstring
  """Moves axes `source` of `a` to positions `destination`.

  `source` and `destination` may each be an int or a sequence of ints;
  negative axes are counted from the end. If the rank of `a` is statically
  known, the permutation is built in Python. Otherwise it is built with
  TF ops.

  NOTE(review): the original docstring said this raises ValueError if
  source/destination are not in (-ndim(a), ndim(a)); no such check exists in
  this body. Confirm whether validation happens elsewhere.

  Returns:
    The transposed array, converted via `utils.tensor_to_ndarray`. On the
    early-return path, `a` is returned exactly as passed in.
  """
  # NOTE(review): this also triggers for source=0 and destination=0 (ints),
  # in which case `a` is returned without `asarray` conversion.
  if not source and not destination:
    return a
  a = asarray(a).data

  if isinstance(source, int):
    source = (source,)
  if isinstance(destination, int):
    destination = (destination,)

  # Presumably a Python int when the rank is static, else a tensor.
  a_rank = utils._maybe_static(tf.rank(a))  # pylint: disable=protected-access

  def _correct_axis(axis, rank):
    # Map a negative axis to its non-negative equivalent.
    if axis < 0:
      return axis + rank
    return axis

  source = tuple(_correct_axis(axis, a_rank) for axis in source)
  destination = tuple(_correct_axis(axis, a_rank) for axis in destination)

  if a.shape.rank is not None:
    # Static rank: start from the axes that stay put, then insert each moved
    # axis at its destination in increasing destination order.
    perm = [i for i in range(a_rank) if i not in source]
    for dest, src in sorted(zip(destination, source)):
      assert dest <= len(perm)
      perm.insert(dest, src)
  else:
    r = tf.range(a_rank)

    def _remove_indices(a, b):
      """Remove indices (`b`) from `a`."""
      items = tf.unstack(tf.sort(tf.stack(b)), num=len(b))

      i = 0
      result = []

      for item in items:
        result.append(a[i:item])
        i = item + 1

      result.append(a[i:])

      return tf.concat(result, 0)

    minus_sources = _remove_indices(r, source)
    minus_dest = _remove_indices(r, destination)

    # Place the unmoved axes at the non-destination slots, then write each
    # moved source axis into its destination slot.
    perm = tf.scatter_nd(tf.expand_dims(minus_dest, 1), minus_sources, [a_rank])
    perm = tf.tensor_scatter_nd_update(perm, tf.expand_dims(destination, 1), source)
  a = tf.transpose(a, perm)

  return utils.tensor_to_ndarray(a)