Python tensorflow.compat.v1.where() Examples
The following are 30 code examples of tensorflow.compat.v1.where().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow.compat.v1, or try the search function.
Example #1
Source File: learning_test.py From tf-slim with Apache License 2.0 | 6 votes |
def testTensorMultiplierOfGradient(self):
  """A tensor-valued multiplier scales the gradient only while its flag is set.

  With `multiplier_flag` True the multiplier is `self._multiplier`; after the
  flag is assigned False it becomes 1.0 and the gradient must come back
  unchanged.
  """
  grad = tf.constant(self._grad_vec, dtype=tf.float32)
  var = variables_lib.Variable(tf.zeros_like(grad))
  flag = variables_lib.Variable(True)
  multiplier = tf.where(flag, self._multiplier, 1.0)
  pair = (grad, var)
  multipliers_by_var = {var: multiplier}
  [scaled_pair] = learning.multiply_gradients([pair], multipliers_by_var)
  scaled_grad = scaled_pair[0]

  with self.cached_session() as sess:
    sess.run(variables_lib.global_variables_initializer())
    with_multiplier = sess.run(scaled_grad)
    sess.run(flag.assign(False))
    without_multiplier = sess.run(scaled_grad)

  np_testing.assert_almost_equal(with_multiplier, self._multiplied_grad_vec, 5)
  np_testing.assert_almost_equal(without_multiplier, self._grad_vec, 5)
Example #2
Source File: t2t_model.py From tensor2tensor with Apache License 2.0 | 6 votes |
def body(self, features):
  """Compute the targets' pre-logit activations from transformed inputs.

  Abstract hook: most `T2TModel` subclasses override this method.

  Args:
    features: dict mapping str to Tensor, each Tensor of shape
      [batch_size, ..., hidden_size]. It usually holds the keys `inputs`
      and `targets`.

  Returns:
    output: Tensor of pre-logit activations of shape
      [batch_size, ..., hidden_size].
    losses: a scalar loss, a list, a Tensor (to be averaged), or a dict of
      losses. When `losses` is a dict containing the key "training",
      losses["training"] is taken as the final training loss, `output` is
      treated as logits, and self.top and self.loss are skipped.

  Raises:
    NotImplementedError: always; subclasses must supply the implementation.
  """
  raise NotImplementedError("Abstract Method")
Example #3
Source File: expert_utils.py From tensor2tensor with Apache License 2.0 | 6 votes |
def __init__(self, pad_mask):
  """Record the positions of the non-padding entries of `pad_mask`.

  Args:
    pad_mask (tf.Tensor): padding reference of shape [batch_size, length] or
      [dim_origin] (dim_origin = batch_size * length). Non-zero positive
      values mark padding.
  """
  self.nonpad_ids = None
  self.dim_origin = None
  with tf.name_scope("pad_reduce/get_ids"):
    flat_mask = tf.reshape(pad_mask, [-1])  # Merge batch and length axes.
    # Non-padding entries are (numerically) zero. pad_mask is a float tensor,
    # so zero is tested as value < 1e-9; because it holds only positive
    # values, no tf.abs is needed.
    is_content = flat_mask < 1e-9
    self.nonpad_ids = tf.to_int32(tf.where(is_content))
    self.dim_origin = tf.shape(flat_mask)[:1]
Example #4
Source File: scheduled_sampling.py From tensor2tensor with Apache License 2.0 | 6 votes |
def _mix_tokens(p_sample, gold_targets, sampled_targets):
  """Randomly interleave sampled and gold tokens, position by position.

  Args:
    p_sample: float in [0, 1], the probability that a token is taken from
      `sampled_targets`. 0 yields all-gold, 1 yields all-sampled.
    gold_targets: Tensor of gold token IDs.
    sampled_targets: Tensor of sampled token IDs, same shape as
      `gold_targets`.

  Returns:
    Tensor shaped like `gold_targets` mixing tokens from both inputs.
  """
  shape = common_layers.shape_list(sampled_targets)
  take_sampled = tf.less(tf.random_uniform(shape), p_sample)
  return tf.where(take_sampled, sampled_targets, gold_targets)
Example #5
Source File: sample.py From gpt2-estimator with MIT License | 6 votes |
def top_k_logits(logits, k):
    """Keep the `k` largest logits of each row and push the rest to -1e10.

    Args:
        logits: score Tensor; the `values[:, -1]` indexing below assumes the
            scores are on axis 1 of a rank-2 tensor -- TODO confirm.
        k: number of logits to keep; 0 disables truncation. A Python 0 returns
            immediately, and a tensor-valued k is handled by `tf.cond`.

    Returns:
        Tensor shaped like `logits`.
    """
    if k == 0:
        # Static fast path: nothing to truncate.
        return logits

    def _truncate():
        top_values, _ = tf.nn.top_k(logits, k=k)
        # k-th largest value per row, as a column so it broadcasts per row.
        threshold = top_values[:, -1, tf.newaxis]
        below = logits < threshold
        return tf.where(
            below,
            tf.ones_like(logits, dtype=logits.dtype) * -1e10,
            logits,
        )

    return tf.cond(tf.equal(k, 0), lambda: logits, _truncate)
Example #6
Source File: quantization.py From tensor2tensor with Apache License 2.0 | 6 votes |
def _randomized_roundoff_to_bfloat16(x, noise, cand1, cand2):
  """Round x to cand1 or cand2 in an unbiased, randomized way.

  cand1 and cand2 have the same shape as x. Element-wise, they should be the
  two bfloat16 values closest to x, in either order, and they must differ
  from each other.

  Args:
    x: A float32 Tensor.
    noise: Tensor broadcastable to x's shape with random uniform values in
      [0.0, 1.0].
    cand1: A bfloat16 Tensor the same shape as x.
    cand2: A bfloat16 Tensor the same shape as x.

  Returns:
    A bfloat16 Tensor holding, element-wise, either cand1 or cand2.
  """
  c1 = tf.to_float(cand1)
  c2 = tf.to_float(cand2)
  # Fraction of the way from cand1 to cand2 at which x lies. Choosing cand2
  # when fraction > uniform noise happens with probability `fraction`, so the
  # expected result equals x.
  fraction = (x - c1) / (c2 - c1)
  return tf.where(tf.greater(fraction, noise), cand2, cand1)
Example #7
Source File: quantization.py From tensor2tensor with Apache License 2.0 | 6 votes |
def _to_bfloat16_unbiased(x, noise):
  """Convert a float32 to a bfloat16 using randomized roundoff.

  Args:
    x: A float32 Tensor.
    noise: a float32 Tensor with values in [0, 1), broadcastable to
      tf.shape(x).

  Returns:
    A bfloat16 Tensor: the randomly rounded magnitude of x (bfloat16, as
    produced by `_randomized_roundoff_to_bfloat16`) multiplied by the
    bfloat16 sign of x.
  """
  x_sign = tf.sign(x)
  # Work on |x|. The 1e-30 offset keeps the magnitude strictly positive; for
  # x == 0 the final multiplication by sign(x) == 0 yields 0 regardless of
  # which candidate is picked.
  x = x * x_sign + 1e-30
  cand1 = tf.to_bfloat16(x)
  cand1_f = tf.to_float(cand1)
  # This relies on the fact that for a positive bfloat16 b,
  # b * 1.005 gives you the next higher bfloat16 and b*0.995 gives you the
  # next lower one. Both 1.005 and 0.995 are ballpark estimation.
  cand2 = tf.to_bfloat16(
      tf.where(tf.greater(x, cand1_f), cand1_f * 1.005, cand1_f * 0.995))
  ret = _randomized_roundoff_to_bfloat16(x, noise, cand1, cand2)
  return ret * tf.to_bfloat16(x_sign)
Example #8
Source File: neural_gpu.py From tensor2tensor with Apache License 2.0 | 6 votes |
def neural_gpu_body(inputs, hparams, name=None):
  """The core Neural GPU.

  Folds a stack of convolutional GRU layers over axis 1 of `inputs`. At
  steps whose input slice is all zeros (padding), the state is left
  unchanged.

  Args:
    inputs: 4-D Tensor; axis 1 is the folded axis (moved to the front by the
      transpose below).
    hparams: hyperparameters supplying dropout, num_hidden_layers,
      kernel_height, kernel_width and hidden_size.
    name: optional variable scope name; defaults to "neural_gpu".

  Returns:
    The final fold state (the fold is initialised with `inputs` itself).
  """
  with tf.variable_scope(name, "neural_gpu"):

    def step(state, inp):  # pylint: disable=missing-docstring
      keep_prob = 1.0 - hparams.dropout
      out = tf.nn.dropout(state, keep_prob)
      kernel = (hparams.kernel_height, hparams.kernel_width)
      for layer_idx in range(hparams.num_hidden_layers):
        out = common_layers.conv_gru(
            out, kernel, hparams.hidden_size, name="cgru_%d" % layer_idx)
      # The modality zeroes out padding input, so a near-zero sum of |inp|
      # marks a padding step.
      is_pad = tf.less(tf.reduce_sum(tf.abs(inp), axis=[1, 2]), 0.00001)
      # Keep the previous state wherever the input is padding.
      return tf.where(is_pad, state, out)

    time_major = tf.transpose(inputs, [1, 0, 2, 3])
    return tf.foldl(
        step,
        time_major,
        initializer=inputs,
        parallel_iterations=1,
        swap_memory=True)
Example #9
Source File: ppo.py From tensor2tensor with Apache License 2.0 | 6 votes |
def _distributional_to_value(value_d, size, subscale, threshold):
  """Get a scalar value out of a value distribution in distributional RL.

  Args:
    value_d: logits over value buckets on the last axis.
    size: bucket count; bucket centres are (i + 0.5) * subscale for
      i in range(-size // 2, size // 2).
    subscale: bucket width.
    threshold: when non-zero, the lowest buckets whose cumulative probability
      is still below `threshold` are zeroed. The remaining mass is
      renormalised before the expectation is taken.

  Returns:
    Expected bucket value, reduced over the last axis.
  """
  half = size // 2
  centres = (tf.to_float(tf.range(-half, half)) + 0.5) * subscale
  probs = tf.nn.softmax(value_d)
  if threshold != 0.0:
    # cdf[..., i] is the probability that value <= i'th bucket value.
    cdf = tf.cumsum(probs, axis=-1)
    # Zero every lower bucket until the CDF reaches the threshold...
    probs = tf.where(cdf < threshold, tf.zeros_like(probs), probs)
    # ...then renormalise what is left.
    probs /= tf.reduce_sum(probs, axis=-1, keepdims=True)
  return tf.reduce_sum(probs * centres, axis=-1)
Example #10
Source File: expert_utils.py From tensor2tensor with Apache License 2.0 | 6 votes |
def __init__(self, num_experts, gates):
  """Create a SparseDispatcher.

  Args:
    num_experts: an integer.
    gates: a `Tensor` of shape `[batch_size, num_experts]`.

  Returns:
    a SparseDispatcher
  """
  self._gates = gates
  self._num_experts = num_experts
  # (expert, batch) coordinates of every strictly positive gate. Transposing
  # first makes tf.where list them grouped by expert.
  coords = tf.to_int32(tf.where(tf.transpose(gates) > 0))
  self._expert_index, self._batch_index = tf.unstack(coords, num=2, axis=1)
  # Number of examples routed to each expert.
  self._part_sizes_tensor = tf.reduce_sum(tf.to_int32(gates > 0), [0])
  # Gate values at those coordinates, read from the row-major flattening.
  flat_gates = tf.reshape(self._gates, [-1])
  flat_positions = self._batch_index * num_experts + self._expert_index
  self._nonzero_gates = tf.gather(flat_gates, flat_positions)
Example #11
Source File: robust_model.py From interval-bound-propagation with Apache License 2.0 | 6 votes |
def filter_correct_class(verifiable_obj, num_classes, labels, margin):
  """Filters out the objective when the target class contains the true label.

  Args:
    verifiable_obj: 2D tensor of shape (num_classes, batch_size) containing
      verifiable objectives.
    num_classes: number of target classes.
    labels: 1D tensor of shape (batch_size) containing the labels for each
      example in the batch.
    margin: Verifiable objective values for the correct class are forced to
      `-margin`, so large negative bounds are ignored when maximising.

  Returns:
    2D tensor of shape (num_classes, batch_size) containing the corrected
    verifiable objective values for each (class, example).
  """
  # Column of class ids, shape (num_classes, 1), broadcast against labels.
  class_ids = tf.expand_dims(tf.range(num_classes, dtype=labels.dtype), axis=1)
  is_other_class = tf.not_equal(class_ids, labels)
  return tf.where(is_other_class, verifiable_obj,
                  -margin * tf.ones_like(verifiable_obj))
Example #12
Source File: dataloader.py From Object_Detection_Tracking with Apache License 2.0 | 6 votes |
def resize_and_crop_boxes(self):
  """Resize boxes and crop them to the self._output dimension.

  Returns:
    (boxes, classes): the scaled, offset and clipped boxes, with the
    matching classes. Boxes whose coordinates sum to zero are dropped.
  """
  scaled = preprocessor.box_list_scale(
      preprocessor.box_list.BoxList(self._boxes),
      self._scaled_height, self._scaled_width).get()
  # Shift coordinates by the crop offset, laid out as (y, x, y, x).
  offset = tf.stack([self._crop_offset_y, self._crop_offset_x,
                     self._crop_offset_y, self._crop_offset_x,])
  scaled -= tf.cast(tf.reshape(offset, [1, 4]), tf.float32)
  clipped = self.clip_boxes(scaled)
  # Drop ground-truth boxes whose coordinates sum to zero; this is intended
  # to remove all-zero boxes.
  keep = tf.where(tf.not_equal(tf.reduce_sum(clipped, axis=1), 0))
  return tf.gather_nd(clipped, keep), tf.gather_nd(self._classes, keep)
Example #13
Source File: region_similarity_calculator.py From Object_Detection_Tracking with Apache License 2.0 | 6 votes |
def iou(boxlist1, boxlist2, scope=None):
  """Computes pairwise intersection-over-union between box collections.

  Args:
    boxlist1: BoxList holding N boxes
    boxlist2: BoxList holding M boxes
    scope: name scope.

  Returns:
    a tensor with shape [N, M] representing pairwise iou scores.
  """
  with tf.name_scope(scope, 'IOU'):
    inter = intersection(boxlist1, boxlist2)
    area_a = area(boxlist1)
    area_b = area(boxlist2)
    union = tf.expand_dims(area_a, 1) + tf.expand_dims(area_b, 0) - inter
    # Non-overlapping pairs are forced to exactly 0, so 0/0 (NaN) from
    # degenerate boxes is never returned.
    no_overlap = tf.equal(inter, 0.0)
    return tf.where(no_overlap, tf.zeros_like(inter), tf.truediv(inter, union))
Example #14
Source File: matcher.py From Object_Detection_Tracking with Apache License 2.0 | 6 votes |
def _match(self, similarity_matrix, **params): """Method to be overridden by implementations. Args: similarity_matrix: Float tensor of shape [N, M] with pairwise similarity where higher value means more similar. **params: Additional keyword arguments for specific implementations of the Matcher. Returns: match_results: Integer tensor of shape [M]: match_results[i]>=0 means that column i is matched to row match_results[i], match_results[i]=-1 means that the column is not matched. match_results[i]=-2 means that the column is ignored (usually this happens when there is a very weak match which one neither wants as positive nor negative example). """ pass
Example #15
Source File: dropout.py From lamb with Apache License 2.0 | 6 votes |
def _build(self, x, state):
  """Apply dropout whose keep mask is partially resampled on every call.

  Args:
    x: input Tensor.
    state: keep mask from the previous call; presumably the same shape as
      `x` -- TODO confirm.

  Returns:
    x * keep_mask / keep_prob * self._scaler. The new mask is stored in
    self._keep_mask.
  """
  prev_mask = state
  shape = tf.shape(x)
  # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob.
  fresh_mask = tf.floor(self._keep_prob +
                        tf.random_uniform(shape, dtype=x.dtype))
  # Each unit independently switches to the fresh mask with prob flip_prob.
  flip = tf.less(tf.random_uniform(shape, dtype=x.dtype), self._flip_prob)
  # KLUDGE(melisgl): The client has to pass the last keep_mask from
  # a batch to the next so the mask may end up next to some
  # recurrent cell state. This state is often zero at the beginning
  # and may be periodically zeroed (per example) during training.
  # While zeroing LSTM state is okay, zeroing the dropout mask is
  # not. So instead of forcing every client to deal with this common
  # (?) case, if an all zero mask is detected, then regenerate a
  # fresh mask. This is of course a major hack and won't help with
  # learnt initial states, for example.
  needs_init = tf.equal(tf.reduce_sum(prev_mask, 1, keepdims=True), 0.0)
  self._keep_mask = tf.where(tf.logical_or(flip, needs_init),
                             fresh_mask, prev_mask)
  self._time_step += 1
  return x * self._keep_mask / self._keep_prob * self._scaler
Example #16
Source File: spectral_ops.py From magenta with Apache License 2.0 | 6 votes |
def unwrap(p, discont=np.pi, axis=-1):
  """Unwrap a cyclical phase tensor.

  TensorFlow counterpart of `numpy.unwrap`. Consecutive differences along
  `axis` with magnitude of at least `discont` are corrected by multiples of
  2*pi, and the corrections are accumulated.

  Args:
    p: Phase tensor.
    discont: Float, size of the cyclic discontinuity. Differences whose
      magnitude is below this value are left uncorrected.
    axis: Axis along which to unwrap.

  Returns:
    unwrapped: Unwrapped tensor of same size as input.
  """
  dd = diff(p, axis=axis)
  # Wrap every difference into [-pi, pi).
  ddmod = tf.mod(dd + np.pi, 2.0 * np.pi) - np.pi
  # Positive jumps that landed exactly on -pi are mapped to +pi instead.
  idx = tf.logical_and(tf.equal(ddmod, -np.pi), tf.greater(dd, 0))
  ddmod = tf.where(idx, tf.ones_like(ddmod) * np.pi, ddmod)
  ph_correct = ddmod - dd
  # No correction for differences smaller than `discont`. The mask must be
  # applied to ph_correct (the value that is accumulated), otherwise
  # `discont` has no effect.
  idx = tf.less(tf.abs(dd), discont)
  ph_correct = tf.where(idx, tf.zeros_like(ph_correct), ph_correct)
  ph_cumsum = tf.cumsum(ph_correct, axis=axis)

  # The first element along `axis` needs no correction.
  shape = p.get_shape().as_list()
  shape[axis] = 1
  ph_cumsum = tf.concat([tf.zeros(shape, dtype=p.dtype), ph_cumsum], axis=axis)
  unwrapped = p + ph_cumsum
  return unwrapped
Example #17
Source File: common_video.py From tensor2tensor with Apache License 2.0 | 6 votes |
def _encode_gif(images, fps):
  """Encode numpy images into a gif string.

  Args:
    images: A 4-D `uint8` `np.array` (or a list of 3-D images) of shape
      `[time, height, width, channels]` where `channels` is 1 or 3.
    fps: frames per second of the animation.

  Returns:
    The encoded gif string, as returned by the writer's `finish()`.

  Raises:
    IOError: If the ffmpeg command returns an error.
  """
  gif_writer = WholeVideoWriter(fps)
  gif_writer.write_multi(images)
  return gif_writer.finish()
Example #18
Source File: common_video.py From tensor2tensor with Apache License 2.0 | 6 votes |
def scheduled_sample_prob(ground_truth_x, generated_x, batch_size,
                          scheduled_sample_var):
  """Probability-based scheduled sampling.

  For each batch element, one uniform draw in [0, 1) is compared with
  `scheduled_sample_var`. The element comes from `ground_truth_x` when the
  draw is <= the threshold, and from `generated_x` otherwise.

  Args:
    ground_truth_x: tensor of ground-truth data points.
    generated_x: tensor of generated data points.
    batch_size: batch size.
    scheduled_sample_var: probability of choosing from ground_truth.

  Returns:
    New batch with randomly selected data points.
  """
  draws = tf.random_uniform([batch_size])
  use_generated = draws > scheduled_sample_var
  return tf.where(use_generated, generated_x, ground_truth_x)
Example #19
Source File: discretization.py From tensor2tensor with Apache License 2.0 | 6 votes |
def vae(x, z_size, name=None):
  """Simple variational autoencoder without discretization.

  Args:
    x: Input to the bottleneck. The noise below is drawn with shape
      [shape[0], shape[1], 1, z_size], so x is presumably 4-D
      [batch, length, 1, depth] -- TODO confirm.
    z_size: Dimensionality of the latent (units of the mu / log_sigma
      projections).
    name: Name for the bottleneck scope.

  Returns:
    A tuple (z, kl_loss, mu, log_sigma):
      z: sampled latent, mu + exp(log_sigma / 2) * epsilon.
      kl_loss: scalar KL loss with free bits. The per-position KL (averaged
        over the latent axis) is reduced by z_size // 4 and clipped at 0
        before the mean is taken.
      mu: latent mean.
      log_sigma: used as the log-variance (the sampling std is
        exp(log_sigma / 2)).
  """
  with tf.variable_scope(name, default_name="vae"):
    mu = tf.layers.dense(x, z_size, name="mu")
    log_sigma = tf.layers.dense(x, z_size, name="log_sigma")
    shape = common_layers.shape_list(x)
    epsilon = tf.random_normal([shape[0], shape[1], 1, z_size])
    z = mu + tf.exp(log_sigma / 2) * epsilon
    # KL(N(mu, exp(log_sigma)) || N(0, 1)) = 0.5 * (e^ls - 1 + mu^2 - ls),
    # with expm1 for numerical accuracy near 0.
    kl = 0.5 * tf.reduce_mean(
        tf.expm1(log_sigma) + tf.square(mu) - log_sigma, axis=-1)
    free_bits = z_size // 4
    kl_loss = tf.reduce_mean(tf.maximum(kl - free_bits, 0.0))
    return z, kl_loss, mu, log_sigma
Example #20
Source File: utils.py From magenta with Apache License 2.0 | 6 votes |
def frequency_weighted_cost_mask(peak=10.0, hz_flat=1000, sr=16000, n_fft=512):
  """Calculate a mask that weights lower frequencies more heavily.

  Piecewise linear: the weight falls linearly from `peak` at 0 Hz to 1.0 at
  the bin just below the first FFT bin >= `hz_flat`, and stays at 1.0 from
  there on. Assumes magnitude is in log scale.

  Args:
    peak: Cost increase at 0 Hz.
    hz_flat: Hz at which cost increase is 0.
    sr: Sample rate.
    n_fft: FFT size.

  Returns:
    Constant float32 tensor of shape [1, int(n_fft / 2), 1] with the cost
    weighting. That is one fewer than the n_fft / 2 + 1 FFT bins;
    presumably the Nyquist bin is dropped by callers -- confirm.
  """
  num_bins = int(n_fft / 2)
  bin_freqs = librosa.core.fft_frequencies(sr=sr, n_fft=n_fft)
  cutoff = np.where(bin_freqs >= hz_flat)[0][0]
  ramp = np.linspace(peak, 1.0, cutoff)
  flat = np.ones(num_bins - cutoff)
  weights = np.concatenate([ramp, flat])
  return tf.constant(weights[np.newaxis, :, np.newaxis], dtype=tf.float32)


#---------------------------------------------------
# Neural Nets
#---------------------------------------------------
Example #21
Source File: common_layers.py From tensor2tensor with Apache License 2.0 | 6 votes |
def _select_top_k(logits, top_k):
  """Replace every logit except the top k highest values with -1e6.

  Ranks are taken along the last axis in descending order. A logit is kept
  when its rank is < k for its batch row. If k is -1, nothing is replaced.

  Args:
    logits: A `Tensor` of shape [batch_size, ..., vocab_size]
    top_k: integer vector of batch size; -1 means keep all logits.

  Returns:
    A `Tensor` with same shape as logits.
  """
  vocab_size = logits.shape[-1]

  top_k = tf.where(
      tf.not_equal(top_k, -1), top_k,
      tf.ones_like(top_k) * vocab_size)

  # ranks[..., j] is the position of logits[..., j] in descending order
  # (0 = largest). A single tf.argsort returns the indices of the sorted
  # elements, not their ranks, so comparing it with k keeps the wrong
  # entries; the argsort of a descending argsort gives the ranks.
  ranks = tf.argsort(tf.argsort(logits, direction="DESCENDING"))
  # Broadcast the per-batch k over all remaining axes.
  k = tf.reshape(top_k, [-1] + [1] * (len(logits.shape) - 1))
  return tf.where(ranks < k, logits, tf.ones_like(logits) * -1e6)
Example #22
Source File: target_assigner.py From Object_Detection_Tracking with Apache License 2.0 | 6 votes |
def _create_classification_targets(self, groundtruth_labels, match):
  """Build a classification target for every anchor.

  Each anchor gets the groundtruth label it is matched to by `match`.
  Anchors that are unmatched or ignored get self._unmatched_cls_target.

  Args:
    groundtruth_labels: a tensor of shape [num_gt_boxes, d_1, ... d_k] with
      labels for each of the ground_truth boxes. The subshape
      [d_1, ... d_k] can be empty (corresponding to scalar labels).
    match: a matcher.Match object that provides a matching between anchors
      and groundtruth boxes.

  Returns:
    a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k], where the
    subshape [d_1, ..., d_k] is compatible with groundtruth_labels which has
    shape [num_gt_boxes, d_1, d_2, ... d_k].
  """
  fill_value = self._unmatched_cls_target
  return match.gather_based_on_match(
      groundtruth_labels,
      unmatched_value=fill_value,
      ignored_value=fill_value)
Example #23
Source File: common_layers.py From tensor2tensor with Apache License 2.0 | 6 votes |
def shape_list(x):
  """Return list of dims, statically where possible.

  Args:
    x: anything convertible to a Tensor.

  Returns:
    If the rank is unknown, the dynamic shape tensor `tf.shape(x)`.
    Otherwise a list with one entry per axis: the static Python int where it
    is known, else the matching scalar slice of `tf.shape(x)`.
  """
  x = tf.convert_to_tensor(x)
  if x.get_shape().dims is None:
    # Unknown rank: only the dynamic shape is available.
    return tf.shape(x)
  static_dims = x.get_shape().as_list()
  dynamic_shape = tf.shape(x)
  return [dynamic_shape[axis] if dim is None else dim
          for axis, dim in enumerate(static_dims)]
Example #24
Source File: matcher.py From Object_Detection_Tracking with Apache License 2.0 | 6 votes |
def match(self, similarity_matrix, scope=None, **params):
  """Computes matches among row and column indices and returns the result.

  Matching uses the similarity matrix and optional arguments, delegated to
  the implementation-specific `_match`.

  Args:
    similarity_matrix: Float tensor of shape [N, M] with pairwise similarity
      where higher value means more similar.
    scope: Op scope name. Defaults to 'Match' if None.
    **params: Additional keyword arguments for specific implementations of
      the Matcher.

  Returns:
    A Match object with the results of matching.
  """
  with tf.name_scope(scope, 'Match', [similarity_matrix, params]):
    raw_results = self._match(similarity_matrix, **params)
    return Match(raw_results)
Example #25
Source File: glow_ops.py From tensor2tensor with Apache License 2.0 | 6 votes |
def scale_gaussian_prior(name, z, logscale_factor=3.0, trainable=True):
  """Returns N(s^i * z^i, std^i) where s^i and std^i are per-component.

  s^i is a learnable multiplier initialised to ones. std^i is
  exp(logscale_factor * log_scale^i), with log_scale initialised to zeros
  and optionally trainable.

  Args:
    name: variable scope.
    z: input tensor.
    logscale_factor: multiplier on log_scale; equivalent to scaling up its
      learning rate by logscale_factor.
    trainable: Whether or not std^i is learnt.

  Returns:
    A tfp.distributions.Normal with loc = latent_multiplier * z and
    scale = exp(log_scale * logscale_factor).
  """
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    shape = common_layers.shape_list(z)
    multiplier = tf.get_variable(
        "latent_multiplier", shape=shape, dtype=tf.float32,
        initializer=tf.ones_initializer())
    raw_log_scale = tf.get_variable(
        "log_scale_latent", shape=shape, dtype=tf.float32,
        initializer=tf.zeros_initializer(), trainable=trainable)
    scaled_log_scale = raw_log_scale * logscale_factor
    return tfp.distributions.Normal(
        loc=multiplier * z, scale=tf.exp(scaled_log_scale))
Example #26
Source File: metrics.py From magenta with Apache License 2.0 | 5 votes |
def accuracy_without_true_negatives(true_positives, false_positives,
                                    false_negatives):
  """Creates an op for calculating accuracy without true negatives.

  Computes tp / (tp + fp + fn) element-wise, and 0 where the denominator is
  not positive.

  Args:
    true_positives: A tensor representing true_positives.
    false_positives: A tensor representing false_positives.
    false_negatives: A tensor representing false_negatives.

  Returns:
    A tensor with the result of the calculation, shaped like the inputs.
  """
  denominator = true_positives + false_positives + false_negatives
  ratio = true_positives / denominator
  # tf.compat.v1.where needs `x` and `y` of identical shape, so a literal 0
  # only works for scalar inputs. zeros_like(ratio) also matches ratio's
  # dtype. The unselected tp / 0 values are never returned.
  return tf.where(tf.greater(denominator, 0), ratio, tf.zeros_like(ratio))
Example #27
Source File: sequence_example_lib.py From magenta with Apache License 2.0 | 5 votes |
def flatten_maybe_padded_sequences(maybe_padded_sequences, lengths=None):
  """Flattens the batch of sequences, removing padding (if applicable).

  Args:
    maybe_padded_sequences: A tensor of possibly padded sequences to flatten,
        sized `[N, M, ...]` where M = max(lengths).
    lengths: Optional length of each sequence, sized `[N]`. If None, assumes
        no padding.

  Returns:
     flatten_maybe_padded_sequences: The flattened sequence tensor, sized
         `[sum(lengths), ...]`.
  """
  def flatten_unpadded_sequences():
    # All sequences have equal length, so merging the first two axes is
    # enough.
    trailing_dims = maybe_padded_sequences.shape.as_list()[2:]
    return tf.reshape(maybe_padded_sequences, [-1] + trailing_dims)

  if lengths is None:
    return flatten_unpadded_sequences()

  def flatten_padded_sequences():
    # Gather only the (batch, time) positions inside each sequence's length.
    valid_positions = tf.where(tf.sequence_mask(lengths))
    return tf.gather_nd(maybe_padded_sequences, valid_positions)

  # When even the shortest sequence fills the time axis there is no padding.
  no_padding = tf.equal(tf.reduce_min(lengths),
                        tf.shape(maybe_padded_sequences)[1])
  return tf.cond(no_padding, flatten_unpadded_sequences,
                 flatten_padded_sequences)
Example #28
Source File: utils.py From magenta with Apache License 2.0 | 5 votes |
def inv_mu_law_numpy(x, mu=255.0):
  """A numpy implementation of inverse Mu-Law.

  Args:
    x: The Mu-Law samples to decode (anything np.array accepts).
    mu: The Mu used to encode these samples.

  Returns:
    float32 array with the decoded data; inputs equal to 0 decode to 0.
  """
  codes = np.array(x).astype(np.float32)
  # Shift by half a step and rescale into (-1, 1).
  scaled = (codes + 0.5) * 2. / (mu + 1)
  expanded = np.sign(scaled) / mu * ((1 + mu)**np.abs(scaled) - 1)
  # A code of exactly 0 decodes to 0 rather than to the +half-step value.
  return np.where(codes == 0, codes, expanded)
Example #29
Source File: benchmark_cnn.py From benchmarks with Apache License 2.0 | 5 votes |
def gradient_histogram_summary(self, avg_grads):
  """Create histogram of log values of all non-zero gradients.

  Args:
    avg_grads: iterable of (gradient, variable) pairs; only the gradients
      are used.
  """
  with tf.name_scope('log_gradients_summary'):
    flat_grads = [tf.reshape(grad, [-1]) for grad, _ in avg_grads]
    magnitudes = tf.abs(tf.concat(flat_grads, 0))
    # Exclude exact zeros, whose log would be -inf.
    nonzero_idx = tf.where(tf.not_equal(magnitudes, 0))
    log_magnitudes = tf.reshape(
        tf.log(tf.gather(magnitudes, nonzero_idx)), [-1])
    tf.summary.histogram('log_gradients', log_magnitudes)
Example #30
Source File: visualization.py From tensor2robot with Apache License 2.0 | 5 votes |
def get_softmax_viz(image, softmax, nrows=None):
  """Arrange softmax maps in a grid and superimpose them on the image.

  Each of the `num_points` channels of `softmax` is divided by its spatial
  maximum and resized to twice the map size. It is then combined with a
  greyscale copy of `image` as an HSV image (hue = map / 2 + 0.5,
  saturation = map, value = 0.7 * grey + 0.3). The per-point tiles are laid
  out in an nrows x ncols grid and converted to RGB.

  Args:
    image: image batch passed to tf.image.rgb_to_grayscale, so presumably
      [batch, height, width, 3] RGB -- TODO confirm.
    softmax: tensor whose dimensions are read from tf.shape(softmax) as
      [batch, height, width, num_points].
    nrows: optional number of grid rows. If None, it is the largest divisor
      of num_points that does not exceed sqrt(num_points), so the grid is as
      square as possible. NOTE(review): the grid reshape below assumes
      nrows * ncols == num_points, i.e. a caller-supplied nrows must divide
      num_points.

  Returns:
    RGB tensor of shape [batch, 2 * height * nrows, 2 * width * ncols, 3],
    where ncols = num_points / nrows cast to int32.
  """
  softmax_shape = tf.shape(softmax)
  batch_size = softmax_shape[0]
  # Each tile is twice the spatial size of the softmax map.
  target_height = softmax_shape[1] * 2
  target_width = softmax_shape[2] * 2
  num_points = softmax_shape[3]
  if nrows is None:
    # Find a number of rows such that the arrangement is as square as possible.
    num_points_float = tf.cast(num_points, tf.float32)
    nfsqrt = tf.cast(tf.floor(tf.sqrt(num_points_float)), tf.int32)
    # Candidate row counts 1..floor(sqrt(num_points)); keep exact divisors.
    divs = tf.range(1, nfsqrt + 1)
    remainders = tf.mod(num_points_float, tf.cast(divs, tf.float32))
    divs = tf.gather(divs, tf.where(tf.equal(remainders, 0)))
    # The largest such divisor gives the most square grid.
    nrows = tf.reduce_max(divs)
  ncols = tf.cast(num_points / nrows, tf.int32)
  nrows = tf.cast(nrows, tf.int32)
  # Normalize per channel
  img = softmax / tf.reduce_max(softmax, axis=[1, 2], keepdims=True)
  # Use softmax as hue and saturation and original image as value of HSV image.
  greyimg = tf.image.rgb_to_grayscale(image)
  greyimg = tf.image.resize_images(greyimg, [target_height, target_width])
  # One greyscale copy per point, with a trailing axis for the HSV concat.
  greyimg = tf.tile(greyimg, [1, 1, 1, num_points])
  greyimg = tf.reshape(greyimg,
                       [batch_size, target_height, target_width, num_points, 1])
  img = tf.image.resize_images(img, [target_height, target_width])
  img = tf.reshape(img, [batch_size, target_height, target_width, num_points, 1])
  # Last axis becomes (hue, saturation, value).
  img = tf.concat([img / 2.0 + 0.5, img, greyimg * 0.7 + 0.3], axis=4)
  # Rearrange the per-point tiles into an nrows x ncols grid (nrows tiles
  # down, ncols tiles across).
  img = tf.reshape(img, [batch_size, target_height, target_width, nrows, ncols,
                         3])
  img = tf.transpose(img, [0, 3, 1, 4, 2, 5])
  img = tf.reshape(img, [batch_size, target_height * nrows,
                         target_width * ncols, 3])
  img = tf.image.hsv_to_rgb(img)
  return img