Python tensorflow.compat.v1.argmax() Examples
The following are 30 code examples of tensorflow.compat.v1.argmax(), drawn from open-source projects.
The original project and source file are noted above each example.
You may also want to check out all available functions/classes of the module tensorflow.compat.v1.
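As a quick, self-contained sketch (not taken from any of the projects below): tf.compat.v1.argmax returns the index of the largest entry along the given axis, with int64 output by default.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

logits = tf.constant([[0.1, 2.0, 0.3],
                      [4.0, 0.5, 0.6]])
predictions = tf.argmax(logits, axis=1)  # index of each row's maximum
with tf.Session() as sess:
    print(sess.run(predictions))  # [1 0], dtype int64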
Example #1
Source File: latent_layers.py From tensor2tensor with Apache License 2.0
def multinomial_sample(x, vocab_size=None, sampling_method="random",
                       temperature=1.0):
  """Multinomial sampling from a n-dimensional tensor.

  Args:
    x: Tensor of shape [..., vocab_size]. Parameterizes logits of multinomial.
    vocab_size: Number of classes in multinomial distribution.
    sampling_method: String, "random" or otherwise deterministic.
    temperature: Positive float.

  Returns:
    Tensor of shape [...].
  """
  vocab_size = vocab_size or common_layers.shape_list(x)[-1]
  if sampling_method == "random" and temperature > 0.0:
    samples = tf.multinomial(tf.reshape(x, [-1, vocab_size]) / temperature, 1)
  else:
    samples = tf.argmax(x, axis=-1)
  reshaped_samples = tf.reshape(samples, common_layers.shape_list(x)[:-1])
  return reshaped_samples
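The temperature > 0.0 guard exists because dividing logits by a temperature approaching zero concentrates the distribution on its argmax, so the deterministic branch simply calls tf.argmax directly. A standalone sketch of that relationship, with made-up logits:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

logits = tf.constant([[1.0, 2.0, 0.5]])
cold = tf.nn.softmax(logits / 0.01)  # near one-hot: sampling ~ argmax
warm = tf.nn.softmax(logits / 1.0)   # spread out: sampling stays random
with tf.Session() as sess:
    print(sess.run([cold, warm]))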
Example #2
Source File: autoencoders.py From tensor2tensor with Apache License 2.0
def gumbel_sample(self, reconstr_gan):
  hparams = self.hparams
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
  vocab_size = self._problem_hparams.vocab_size["targets"]
  if hasattr(self._hparams, "vocab_divisor"):
    vocab_size += (-vocab_size) % self._hparams.vocab_divisor
  reconstr_gan = tf.nn.log_softmax(reconstr_gan)
  if is_training and hparams.gumbel_temperature > 0.0:
    gumbel_samples = discretization.gumbel_sample(
        common_layers.shape_list(reconstr_gan))
    gumbel_samples *= hparams.gumbel_noise_factor
    reconstr_gan += gumbel_samples
    reconstr_sample = latent_layers.multinomial_sample(
        reconstr_gan, temperature=hparams.gumbel_temperature)
    reconstr_gan = tf.nn.softmax(reconstr_gan / hparams.gumbel_temperature)
  else:
    reconstr_sample = tf.argmax(reconstr_gan, axis=-1)
    reconstr_gan = tf.nn.softmax(reconstr_gan / 0.1)  # Sharpen a bit.
  # Use 1-hot forward, softmax backward.
  reconstr_hot = tf.one_hot(reconstr_sample, vocab_size)
  reconstr_gan += reconstr_hot - tf.stop_gradient(reconstr_gan)
  return reconstr_gan
Example #3
Source File: autoencoders.py From tensor2tensor with Apache License 2.0
def infer(self, features, *args, **kwargs):  # pylint: disable=arguments-differ
  """Produce predictions from the model by sampling."""
  del args, kwargs
  # Inputs and features preparation needed to handle edge cases.
  if not features:
    features = {}
  inputs_old = None
  if "inputs" in features and len(features["inputs"].shape) < 4:
    inputs_old = features["inputs"]
    features["inputs"] = tf.expand_dims(features["inputs"], 2)

  # Sample and decode.
  num_channels = self.num_channels
  if "targets" not in features:
    features["targets"] = tf.zeros(
        [self.hparams.batch_size, 1, 1, num_channels], dtype=tf.int32)
  logits, _ = self(features)  # pylint: disable=not-callable
  samples = tf.argmax(logits, axis=-1)

  # Restore inputs to not confuse Estimator in edge cases.
  if inputs_old is not None:
    features["inputs"] = inputs_old

  # Return samples.
  return samples
Example #4
Source File: modalities.py From tensor2tensor with Apache License 2.0
def image_top(body_output, targets, model_hparams, vocab_size):
  """Top transformation for images."""
  del targets  # unused arg
  # TODO(lukaszkaiser): is this a universal enough way to get channels?
  num_channels = model_hparams.problem.num_channels
  with tf.variable_scope("rgb_softmax"):
    body_output_shape = common_layers.shape_list(body_output)
    reshape_shape = body_output_shape[:3]
    reshape_shape.extend([num_channels, vocab_size])
    res = tf.layers.dense(body_output, vocab_size * num_channels)
    res = tf.reshape(res, reshape_shape)
    if not tf.get_variable_scope().reuse:
      res_argmax = tf.argmax(res, axis=-1)
      tf.summary.image(
          "result",
          common_layers.tpu_safe_image_summary(res_argmax),
          max_outputs=1)
    return res
Example #5
Source File: loss.py From interval-bound-propagation with Apache License 2.0
def _build_attack_loss(self, labels):
  """Build adversarial loss using PGD attack."""
  # PGD attack.
  if not self._attack:
    self._attack_accuracy = tf.constant(0.)
    self._attack_success = tf.constant(1.)
    self._attack_cross_entropy = tf.constant(0.)
    return
  if not isinstance(self._predictor.inputs, tf.Tensor):
    raise ValueError('Multiple inputs is not supported.')
  self._attack(self._predictor.inputs, labels)
  correct_examples = tf.equal(labels, tf.argmax(self._attack.logits, 1))
  self._attack_accuracy = tf.reduce_mean(
      tf.cast(correct_examples, tf.float32))
  self._attack_success = tf.reduce_mean(
      tf.cast(self._attack.success, tf.float32))
  if self._label_smoothing > 0:
    attack_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=self._one_hot_labels, logits=self._attack.logits)
  else:
    attack_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=self._attack.logits)
  self._attack_cross_entropy = tf.reduce_mean(attack_cross_entropy)
Example #6
Source File: loss.py From interval-bound-propagation with Apache License 2.0
def _build_nominal_loss(self, labels):
  """Build natural cross-entropy loss on clean data."""
  # Cross-entropy.
  nominal_logits = self._predictor.logits
  if self._label_smoothing > 0:
    num_classes = nominal_logits.shape[1].value
    one_hot_labels = tf.one_hot(labels, num_classes)
    smooth_positives = 1. - self._label_smoothing
    smooth_negatives = self._label_smoothing / num_classes
    one_hot_labels = one_hot_labels * smooth_positives + smooth_negatives
    nominal_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=one_hot_labels, logits=nominal_logits)
    self._one_hot_labels = one_hot_labels
  else:
    nominal_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=nominal_logits)
  self._cross_entropy = tf.reduce_mean(nominal_cross_entropy)
  # Accuracy.
  nominal_correct_examples = tf.equal(labels, tf.argmax(nominal_logits, 1))
  self._nominal_accuracy = tf.reduce_mean(
      tf.cast(nominal_correct_examples, tf.float32))
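Examples #5 and #6 share the same accuracy idiom: compare integer labels against tf.argmax of the logits, cast the booleans to float32, and take the mean. A minimal sketch with toy values (note tf.argmax returns int64, so the labels must match):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

labels = tf.constant([2, 0], dtype=tf.int64)
logits = tf.constant([[0.1, 0.2, 3.0],   # predicts class 2: correct
                      [0.5, 2.5, 0.1]])  # predicts class 1: wrong
accuracy = tf.reduce_mean(
    tf.cast(tf.equal(labels, tf.argmax(logits, 1)), tf.float32))
with tf.Session() as sess:
    print(sess.run(accuracy))  # 0.5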
Example #7
Source File: transformer_nat.py From tensor2tensor with Apache License 2.0
def vq_nearest_neighbor(x, hparams):
  """Find the nearest element in means to elements in x."""
  bottleneck_size = 2**hparams.bottleneck_bits
  means = hparams.means
  x_norm_sq = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
  means_norm_sq = tf.reduce_sum(tf.square(means), axis=-1, keepdims=True)
  scalar_prod = tf.matmul(x, means, transpose_b=True)
  dist = x_norm_sq + tf.transpose(means_norm_sq) - 2 * scalar_prod
  if hparams.bottleneck_kind == "em":
    x_means_idx = tf.multinomial(-dist, num_samples=hparams.num_samples)
    x_means_hot = tf.one_hot(
        x_means_idx, depth=bottleneck_size)
    x_means_hot = tf.reduce_mean(x_means_hot, axis=1)
  else:
    x_means_idx = tf.argmax(-dist, axis=-1)
    x_means_hot = tf.one_hot(x_means_idx, depth=bottleneck_size)
  x_means = tf.matmul(x_means_hot, means)
  e_loss = tf.reduce_mean(tf.squared_difference(x, tf.stop_gradient(x_means)))
  return x_means_hot, e_loss
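The dist tensor uses the expansion ||x - m||^2 = ||x||^2 + ||m||^2 - 2<x, m>, so the nearest codebook entry is tf.argmax(-dist): maximizing the negated distance is the same as minimizing the distance. The core of the trick on toy tensors:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

x = tf.constant([[0.0, 1.5]])       # one input vector
means = tf.constant([[0.0, 0.0],    # a two-entry toy codebook
                     [0.0, 2.0]])
x_sq = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
m_sq = tf.reduce_sum(tf.square(means), axis=-1, keepdims=True)
dist = x_sq + tf.transpose(m_sq) - 2 * tf.matmul(x, means, transpose_b=True)
nearest = tf.argmax(-dist, axis=-1)  # argmin(dist) via argmax(-dist)
with tf.Session() as sess:
    print(sess.run([dist, nearest]))  # [[2.25, 0.25]] -> entry 1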
Example #8
Source File: common_layers.py From tensor2tensor with Apache License 2.0
def argmax_with_score(logits, axis=None):
  """Argmax along with the value."""
  axis = axis or len(logits.get_shape()) - 1
  predictions = tf.argmax(logits, axis=axis)

  logits_shape = shape_list(logits)
  prefix_shape, vocab_size = logits_shape[:-1], logits_shape[-1]
  prefix_size = 1
  for d in prefix_shape:
    prefix_size *= d

  # Flatten to extract scores
  flat_logits = tf.reshape(logits, [prefix_size, vocab_size])
  flat_predictions = tf.reshape(predictions, [prefix_size])
  flat_indices = tf.stack(
      [tf.range(tf.to_int64(prefix_size)),
       tf.to_int64(flat_predictions)],
      axis=1)
  flat_scores = tf.gather_nd(flat_logits, flat_indices)

  # Unflatten
  scores = tf.reshape(flat_scores, prefix_shape)

  return predictions, scores
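The flatten-then-gather_nd step pairs each row index with its argmax column to pull out the winning logit. The same pattern in isolation, with illustrative values:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

logits = tf.constant([[0.1, 0.9],
                      [0.8, 0.2]])
preds = tf.argmax(logits, axis=-1)                        # [1, 0]
idx = tf.stack([tf.range(2, dtype=tf.int64), preds], axis=1)
scores = tf.gather_nd(logits, idx)                        # [0.9, 0.8]
with tf.Session() as sess:
    print(sess.run([preds, scores]))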
Example #9
Source File: slate_decomp_q_agent.py From recsim with Apache License 2.0
def _network_adapter(self, states, scope):
  self._validate_states(states)

  with tf.name_scope('network'):
    # Since we decompose the slate optimization into an item-level
    # optimization problem, the observation space is the user state
    # observation plus all documents' observations. In the Dopamine DQN agent
    # implementation, there is one head for each possible action value, which
    # is designed for computing the argmax operation in the action space.
    # In our implementation, we generate one output for each document.
    q_value_list = []
    for i in range(self._num_candidates):
      user = tf.squeeze(states[:, 0, :, :], axis=2)
      doc = tf.squeeze(states[:, i + 1, :, :], axis=2)
      q_value_list.append(self.network(user, doc, scope))
    q_values = tf.concat(q_value_list, axis=1)

  return dqn_agent.DQNNetworkType(q_values)
Example #10
Source File: common_layers.py From tensor2tensor with Apache License 2.0
def top_1_tpu(inputs):
  """find max and argmax over the last dimension.

  Works well on TPU

  Args:
    inputs: A tensor with shape [..., depth]

  Returns:
    values: a Tensor with shape [...]
    indices: a Tensor with shape [...]
  """
  inputs_max = tf.reduce_max(inputs, axis=-1, keepdims=True)
  mask = tf.to_int32(tf.equal(inputs_max, inputs))
  index = tf.range(tf.shape(inputs)[-1]) * mask
  return tf.squeeze(inputs_max, -1), tf.reduce_max(index, axis=-1)
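top_1_tpu avoids tf.argmax altogether: it masks the positions equal to the maximum and reduces over a masked index range, which lowers to cheap elementwise and reduction ops on TPU. One behavioral caveat worth knowing: on ties it returns the highest matching index, whereas tf.argmax typically returns the first (its tie order is not formally guaranteed). A quick check with a tied input:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

inputs = tf.constant([[3.0, 1.0, 3.0]])  # maximum tied at indices 0 and 2
inputs_max = tf.reduce_max(inputs, axis=-1, keepdims=True)
mask = tf.to_int32(tf.equal(inputs_max, inputs))
index = tf.range(tf.shape(inputs)[-1]) * mask
with tf.Session() as sess:
    print(sess.run([tf.reduce_max(index, axis=-1),  # [2]: last tied index
                    tf.argmax(inputs, axis=-1)]))   # typically [0]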
Example #11
Source File: metrics.py From tensor2tensor with Apache License 2.0
def roc_auc(logits, labels, weights_fn=None):
  """Calculate ROC AUC.

  Requires binary classes.

  Args:
    logits: Tensor of size [batch_size, 1, 1, num_classes]
    labels: Tensor of size [batch_size, 1, 1, num_classes]
    weights_fn: Function that takes in labels and weighs examples (unused)
  Returns:
    ROC AUC (scalar), weights
  """
  del weights_fn
  with tf.variable_scope("roc_auc", values=[logits, labels]):
    predictions = tf.argmax(logits, axis=-1)
    _, auc = tf.metrics.auc(labels, predictions, curve="ROC")
    return auc, tf.constant(1.0)
Example #12
Source File: metrics.py From tensor2tensor with Apache License 2.0
def sigmoid_precision_one_hot(logits, labels, weights_fn=None):
  """Calculate precision for a set, given one-hot labels and logits.

  Predictions are converted to one-hot,
  as predictions[example][arg-max(example)] = 1

  Args:
    logits: Tensor of size [batch-size, o=1, p=1, num-classes]
    labels: Tensor of size [batch-size, o=1, p=1, num-classes]
    weights_fn: Function that takes in labels and weighs examples (unused)
  Returns:
    precision (scalar), weights
  """
  with tf.variable_scope("sigmoid_precision_one_hot", values=[logits, labels]):
    del weights_fn
    num_classes = logits.shape[-1]
    predictions = tf.nn.sigmoid(logits)
    predictions = tf.argmax(predictions, -1)
    predictions = tf.one_hot(predictions, num_classes)
    _, precision = tf.metrics.precision(labels=labels, predictions=predictions)
    return precision, tf.constant(1.0)
Example #13
Source File: metrics.py From tensor2tensor with Apache License 2.0
def sigmoid_accuracy(logits, labels, weights_fn=None):
  """Calculate accuracy for a set, given integer labels and logits.

  Args:
    logits: Tensor of size [batch-size, o=1, p=1, num-classes]
    labels: Tensor of size [batch-size, o=1, p=1]
    weights_fn: Function that takes in labels and weighs examples (unused)
  Returns:
    accuracy (scalar), weights
  """
  with tf.variable_scope("sigmoid_accuracy", values=[logits, labels]):
    del weights_fn
    predictions = tf.nn.sigmoid(logits)
    predictions = tf.argmax(predictions, -1)
    _, accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)
    return accuracy, tf.constant(1.0)
Example #14
Source File: metrics.py From tensor2tensor with Apache License 2.0
def sigmoid_accuracy_one_hot(logits, labels, weights_fn=None):
  """Calculate accuracy for a set, given one-hot labels and logits.

  Args:
    logits: Tensor of size [batch-size, o=1, p=1, num-classes]
    labels: Tensor of size [batch-size, o=1, p=1, num-classes]
    weights_fn: Function that takes in labels and weighs examples (unused)
  Returns:
    accuracy (scalar), weights
  """
  with tf.variable_scope("sigmoid_accuracy_one_hot", values=[logits, labels]):
    del weights_fn
    predictions = tf.nn.sigmoid(logits)
    labels = tf.argmax(labels, -1)
    predictions = tf.argmax(predictions, -1)
    _, accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)
    return accuracy, tf.constant(1.0)
Example #15
Source File: metrics.py From tensor2tensor with Apache License 2.0
def image_summary(predictions, targets, hparams):
  """Reshapes predictions and passes it to tensorboard.

  Args:
    predictions : The predicted image (logits).
    targets : The ground truth.
    hparams: model hparams.

  Returns:
    summary_proto: containing the summary images.
    weights: A Tensor of zeros of the same shape as predictions.
  """
  del hparams
  results = tf.cast(tf.argmax(predictions, axis=-1), tf.uint8)
  gold = tf.cast(targets, tf.uint8)
  summary1 = tf.summary.image("prediction", results, max_outputs=2)
  summary2 = tf.summary.image("data", gold, max_outputs=2)
  summary = tf.summary.merge([summary1, summary2])
  return summary, tf.zeros_like(predictions)
Example #16
Source File: retrain.py From AudioNet with MIT License
def add_evaluation_step(result_tensor, ground_truth_tensor):
  """Inserts the operations we need to evaluate the accuracy of our results.

  Args:
    result_tensor: The new final node that produces results.
    ground_truth_tensor: The node we feed ground truth data into.

  Returns:
    Tuple of (evaluation step, prediction).
  """
  with tf.name_scope('accuracy'):
    with tf.name_scope('correct_prediction'):
      prediction = tf.argmax(result_tensor, 1)
      correct_prediction = tf.equal(
          prediction, tf.argmax(ground_truth_tensor, 1))
    with tf.name_scope('accuracy'):
      evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  tf.summary.scalar('accuracy', evaluation_step)
  return evaluation_step, prediction
Example #17
Source File: quantile_agent.py From batch_rl with Apache License 2.0
def _build_target_distribution(self):
  batch_size = tf.shape(self._replay.rewards)[0]
  # size of rewards: batch_size x 1
  rewards = self._replay.rewards[:, None]
  # size of tiled_support: batch_size x num_atoms

  is_terminal_multiplier = 1. - tf.cast(self._replay.terminals, tf.float32)
  # Incorporate terminal state to discount factor.
  # size of gamma_with_terminal: batch_size x 1
  gamma_with_terminal = self.cumulative_gamma * is_terminal_multiplier
  gamma_with_terminal = gamma_with_terminal[:, None]

  # size of next_qt_argmax: 1 x batch_size
  next_qt_argmax = tf.argmax(
      self._replay_next_target_net_outputs.q_values, axis=1)[:, None]
  batch_indices = tf.range(tf.to_int64(batch_size))[:, None]
  # size of next_qt_argmax: batch_size x 2
  batch_indexed_next_qt_argmax = tf.concat(
      [batch_indices, next_qt_argmax], axis=1)
  # size of next_logits (next quantiles): batch_size x num_atoms
  next_logits = tf.gather_nd(
      self._replay_next_target_net_outputs.logits,
      batch_indexed_next_qt_argmax)
  return rewards + gamma_with_terminal * next_logits
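The batch_indices/next_qt_argmax concatenation is the standard TF1 idiom for row-wise selection: build [row, argmax-column] index pairs and tf.gather_nd them. Reduced to its core, with toy Q-values instead of the agent's replay tensors:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

q_values = tf.constant([[1.0, 5.0],
                        [7.0, 2.0]])
rows = tf.range(2, dtype=tf.int64)[:, None]
cols = tf.argmax(q_values, axis=1)[:, None]
selected = tf.gather_nd(q_values, tf.concat([rows, cols], axis=1))
with tf.Session() as sess:
    print(sess.run(selected))  # [5. 7.]: each row's best Q-value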
Example #18
Source File: overfeat_test.py From models with Apache License 2.0
def testTrainEvalWithReuse(self):
  train_batch_size = 2
  eval_batch_size = 1
  train_height, train_width = 231, 231
  eval_height, eval_width = 281, 281
  num_classes = 1000
  with self.test_session():
    train_inputs = tf.random.uniform(
        (train_batch_size, train_height, train_width, 3))
    logits, _ = overfeat.overfeat(train_inputs)
    self.assertListEqual(logits.get_shape().as_list(),
                         [train_batch_size, num_classes])
    tf.get_variable_scope().reuse_variables()
    eval_inputs = tf.random.uniform(
        (eval_batch_size, eval_height, eval_width, 3))
    logits, _ = overfeat.overfeat(eval_inputs, is_training=False,
                                  spatial_squeeze=False)
    self.assertListEqual(logits.get_shape().as_list(),
                         [eval_batch_size, 2, 2, num_classes])
    logits = tf.reduce_mean(input_tensor=logits, axis=[1, 2])
    predictions = tf.argmax(input=logits, axis=1)
    self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
Example #19
Source File: full_slate_q_agent.py From recsim with Apache License 2.0
def _build_networks(self):
  with tf.name_scope('networks'):
    self._replay_net_outputs = self._network_adapter(self._replay.states,
                                                     'Online')
    self._replay_next_target_net_outputs = self._network_adapter(
        self._replay.states, 'Target')
    self._net_outputs = self._network_adapter(self.state_ph, 'Online')
    self._q_argmax = tf.argmax(input=self._net_outputs.q_values, axis=1)[0]
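The trailing [0] works because self.state_ph carries a single state at action-selection time, so the argmax over axis 1 is a length-1 vector. A stripped-down sketch of greedy action selection under that same single-state assumption:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

q_values = tf.placeholder(tf.float32, shape=[1, 4])  # one state, 4 actions
greedy_action = tf.argmax(q_values, axis=1)[0]
with tf.Session() as sess:
    print(sess.run(greedy_action,
                   {q_values: [[0.1, 0.7, 0.3, 0.2]]}))  # 1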
Example #20
Source File: model.py From ocrd_anybaseocr with Apache License 2.0
def mrcnn_class_loss_graph(target_class_ids, pred_class_logits,
                           active_class_ids):
    """Loss for the classifier head of Mask RCNN.

    target_class_ids: [batch, num_rois]. Integer class IDs. Uses zero
        padding to fill in the array.
    pred_class_logits: [batch, num_rois, num_classes]
    active_class_ids: [batch, num_classes]. Has a value of 1 for
        classes that are in the dataset of the image, and 0
        for classes that are not in the dataset.
    """
    # During model building, Keras calls this function with
    # target_class_ids of type float32. Unclear why. Cast it
    # to int to get around it.
    target_class_ids = tf.cast(target_class_ids, 'int64')

    # Find predictions of classes that are not in the dataset.
    pred_class_ids = tf.argmax(pred_class_logits, axis=2)
    # TODO: Update this line to work with batch > 1. Right now it assumes all
    #       images in a batch have the same active_class_ids
    pred_active = tf.gather(active_class_ids[0], pred_class_ids)

    # Loss
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=target_class_ids, logits=pred_class_logits)

    # Erase losses of predictions of classes that are not in the active
    # classes of the image.
    loss = loss * pred_active

    # Compute loss mean. Use only predictions that contribute
    # to the loss to get a correct mean.
    loss = tf.reduce_sum(loss) / tf.reduce_sum(pred_active)
    return loss
Example #21
Source File: resnet_train_eval.py From rigl with Apache License 2.0
def create_eval_metrics(labels, logits):
  """Creates the evaluation metrics for the model."""
  eval_metrics = {}
  label_keys = CLASSES
  predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)
  eval_metrics['eval_accuracy'] = tf.metrics.accuracy(
      labels=labels, predictions=predictions)
  if FLAGS.per_class_metrics:
    with tf.name_scope('class_level_summaries') as scope:
      for i in range(len(label_keys)):
        labels = tf.cast(labels, tf.int64)
        name = scope + '/' + label_keys[i]
        eval_metrics[('class_level_summaries/precision/' +
                      label_keys[i])] = tf.metrics.precision_at_k(
                          labels=labels,
                          predictions=logits,
                          class_id=i,
                          k=1,
                          name=name)
        eval_metrics[('class_level_summaries/recall/' +
                      label_keys[i])] = tf.metrics.recall_at_k(
                          labels=labels,
                          predictions=logits,
                          class_id=i,
                          k=1,
                          name=name)
  return eval_metrics
Example #22
Source File: i3d_test.py From models with Apache License 2.0
def testEvaluation(self):
  batch_size = 2
  num_frames = 64
  height, width = 224, 224
  num_classes = 1000

  eval_inputs = tf.random.uniform((batch_size, num_frames, height, width, 3))
  logits, _ = i3d.i3d(eval_inputs, num_classes, is_training=False)
  predictions = tf.argmax(input=logits, axis=1)

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    output = sess.run(predictions)
    self.assertEquals(output.shape, (batch_size,))
Example #23
Source File: overfeat_test.py From models with Apache License 2.0
def testEvaluation(self):
  batch_size = 2
  height, width = 231, 231
  num_classes = 1000
  with self.test_session():
    eval_inputs = tf.random.uniform((batch_size, height, width, 3))
    logits, _ = overfeat.overfeat(eval_inputs, is_training=False)
    self.assertListEqual(logits.get_shape().as_list(),
                         [batch_size, num_classes])
    predictions = tf.argmax(input=logits, axis=1)
    self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
Example #24
Source File: run_recurrent_model_boolq.py From language with Apache License 2.0
def evaluate():
  """Evaluate a model on the dev set."""
  sess = tf.Session()
  tf.logging.info("Building graph...")

  embeddings = load_embeddings()
  tf_data = load_batched_dataset(False, embeddings)
  it = tf_data.make_initializable_iterator()
  features, labels = it.get_next()

  logits = predict(False, embeddings, features["premise"],
                   features["hypothesis"])
  accuracy, update_ops = tf.metrics.accuracy(
      tf.argmax(logits, 1, output_type=tf.int32), tf.to_int32(labels))

  tf.logging.info("Running initializers...")
  checkpoint_file = FLAGS.checkpoint_file
  if checkpoint_file is not None:
    saver = tf.train.Saver(tf.trainable_variables())
    tf.logging.info("Restoring from checkpoint: " + checkpoint_file)
    saver.restore(sess, checkpoint_file)
  else:
    tf.logging.warning("No checkpoint given, evaling model with random weights")
    sess.run(tf.global_variables_initializer())
  sess.run(tf.local_variables_initializer())
  sess.run(tf.tables_initializer())
  sess.run(it.initializer)

  tf.logging.info("Starting loop....")
  while True:
    try:
      sess.run(update_ops)
    except tf.errors.OutOfRangeError:
      break
  tf.logging.info("Done")

  accuracy = sess.run(accuracy)
  print("Accuracy: %f" % accuracy)
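tf.metrics.accuracy is a streaming metric: the update op accumulates counts batch by batch and the accuracy tensor reads the running result, which is why the loop above runs update_ops until OutOfRangeError and only then fetches accuracy. A minimal sketch of that protocol:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

labels = tf.placeholder(tf.int32, [None])
preds = tf.placeholder(tf.int32, [None])
accuracy, update_op = tf.metrics.accuracy(labels, preds)
with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # metric counters are locals
    sess.run(update_op, {labels: [1, 0], preds: [1, 1]})
    sess.run(update_op, {labels: [1, 1], preds: [1, 1]})
    print(sess.run(accuracy))  # 0.75: 3 of 4 examples correct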
Example #25
Source File: robust_model.py From interval-bound-propagation with Apache License 2.0
def run_classification(self, inputs, labels, length):
  prediction = self.run_prediction(inputs, length)
  correct = tf.cast(tf.equal(labels, tf.argmax(prediction, 1)),
                    dtype=tf.float32)
  return correct
Example #26
Source File: attacks.py From interval-bound-propagation with Apache License 2.0
def find_worst_attack(self, objective_fn, adversarial_input, batch_size,
                      input_shape):
  """Returns the attack that maximizes objective_fn."""
  adversarial_objective = objective_fn(adversarial_input)
  adversarial_objective = tf.reshape(adversarial_objective, [-1, batch_size])
  adversarial_input = tf.reshape(adversarial_input,
                                 [-1, batch_size] + input_shape)
  i = tf.argmax(adversarial_objective, axis=0)
  j = tf.cast(tf.range(tf.shape(adversarial_objective)[1]), i.dtype)
  ij = tf.stack([i, j], axis=1)
  return tf.gather_nd(adversarial_input, ij)
Example #27
Source File: multi_network_dqn_agent.py From batch_rl with Apache License 2.0
def _build_networks(self):
  super(MultiNetworkDQNAgent, self)._build_networks()
  # q_argmax is only used for picking an action
  self._q_argmax_eval = tf.argmax(self._net_outputs.q_values, axis=1)[0]
  if self.use_deep_exploration:
    if self.transform_strategy.endswith('STOCHASTIC'):
      q_transform = atari_helpers.random_stochastic_matrix(
          self.num_networks, num_cols=1)
      self._q_episode_transform = tf.get_variable(
          trainable=False,
          dtype=tf.float32,
          shape=q_transform.get_shape().as_list(),
          name='q_episode_transform')
      self._update_episode_q_function = self._q_episode_transform.assign(
          q_transform)
      episode_q_function = tf.tensordot(
          self._net_outputs.unordered_q_networks,
          self._q_episode_transform, axes=[[2], [0]])
      self._q_argmax_train = tf.argmax(episode_q_function[:, :, 0], axis=1)[0]
    elif self.transform_strategy == 'IDENTITY':
      self._q_function_index = tf.Variable(
          initial_value=0,
          trainable=False,
          dtype=tf.int32,
          shape=(),
          name='q_head_episode')
      self._update_episode_q_function = self._q_function_index.assign(
          tf.random.uniform(
              shape=(), maxval=self.num_networks, dtype=tf.int32))
      q_function = self._net_outputs.unordered_q_networks[
          :, :, self._q_function_index]
      # This is only used for picking an action
      self._q_argmax_train = tf.argmax(q_function, axis=1)[0]
  else:
    self._q_argmax_train = self._q_argmax_eval
Example #28
Source File: lib_tfsampling.py From magenta with Apache License 2.0
def sample_with_temperature(logits, temperature):
  """Either argmax after softmax or random sample along the pitch axis.

  Args:
    logits: a Tensor of shape (batch, time, pitch, instrument).
    temperature: a float  0.0=argmax 1.0=random

  Returns:
    a Tensor of the same shape, with one_hots on the pitch dimension.
  """
  logits = tf.transpose(logits, [0, 1, 3, 2])
  pitch_range = tf.shape(logits)[-1]

  def sample_from_logits(logits):
    with tf.control_dependencies([tf.assert_greater(temperature, 0.0)]):
      logits = tf.identity(logits)
    reshaped_logits = (
        tf.reshape(logits, [-1, tf.shape(logits)[-1]]) / temperature)
    choices = tf.multinomial(reshaped_logits, 1)
    choices = tf.reshape(choices,
                         tf.shape(logits)[:logits.get_shape().ndims - 1])
    return choices

  choices = tf.cond(tf.equal(temperature, 0.0),
                    lambda: tf.argmax(tf.nn.softmax(logits), -1),
                    lambda: sample_from_logits(logits))
  samples_onehot = tf.one_hot(choices, pitch_range)
  return tf.transpose(samples_onehot, [0, 1, 3, 2])
Example #29
Source File: util.py From magenta with Apache License 2.0
def one_hot_to_embedding(one_hot, embedding_size=None):
  """Gets a dense embedding vector from a one-hot encoding."""
  num_tokens = int(one_hot.shape[1])
  label_id = tf.argmax(one_hot, axis=1)
  if embedding_size is None:
    embedding_size = get_default_embedding_size(num_tokens)
  embedding = tf.get_variable(
      'one_hot_embedding', [num_tokens, embedding_size], dtype=tf.float32)
  return tf.nn.embedding_lookup(embedding, label_id, name='token_to_embedding')
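Here tf.argmax simply inverts the one-hot encoding to recover integer ids, which tf.nn.embedding_lookup requires. The round trip in isolation:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

label_ids = tf.constant([2, 0])
one_hot = tf.one_hot(label_ids, depth=3)
recovered = tf.argmax(one_hot, axis=1)  # back to [2, 0]
with tf.Session() as sess:
    print(sess.run(recovered))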
Example #30
Source File: nasnet_test.py From benchmarks with Apache License 2.0
def testEvaluationMobileModel(self):
  batch_size = 2
  height, width = 224, 224
  num_classes = 1000
  with self.test_session() as sess:
    eval_inputs = tf.random_uniform((batch_size, height, width, 3))
    with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
      logits, _ = nasnet.build_nasnet_mobile(eval_inputs,
                                             num_classes,
                                             is_training=False)
    predictions = tf.argmax(logits, 1)
    sess.run(tf.global_variables_initializer())
    output = sess.run(predictions)
    self.assertEqual(output.shape, (batch_size,))