Python tensorflow.compat.v1.py_func() Examples
The following are 30 code examples of tensorflow.compat.v1.py_func(), collected from open-source projects. The original project and source file are noted above each example. You may also want to check out all available functions and classes of the tensorflow.compat.v1 module.
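As a quick orientation before the examples: tf.compat.v1.py_func(func, inp, Tout) wraps a plain Python function as an op in the TensorFlow graph. At session run time the input tensors arrive as NumPy arrays, func executes as ordinary Python, and its return values become the op's outputs with dtype(s) Tout. Because TensorFlow cannot inspect the Python function, the outputs carry no static shape, which is why many examples below call set_shape() on the result. A minimal sketch (the function and tensor names are illustrative, not taken from the projects below):

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

def _square_plus_one(x):
  # Executes as ordinary Python; x arrives as a NumPy array.
  return (x * x + 1.0).astype(np.float32)

inp = tf.placeholder(tf.float32, shape=[None])
out = tf.py_func(_square_plus_one, [inp], tf.float32)
out.set_shape(inp.get_shape())  # py_func outputs have no static shape.

with tf.Session() as sess:
  print(sess.run(out, feed_dict={inp: [1.0, 2.0]}))  # [2. 5.]

Note the main caveat, which several examples below work around: the wrapped function runs in the Python process that built the graph, so it does not serialize into a GraphDef and is unavailable on TPU (see Example #23).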
Example #1
Source File: visualization.py From tensor2robot with Apache License 2.0
def tf_put_text(imgs, texts, text_size=1, text_pos=(0, 30),
                text_color=(0, 0, 1)):
  """Adds text to an image tensor."""

  def _put_text(imgs, texts):
    """Python function that renders text onto an image."""
    result = np.empty_like(imgs)
    for i in range(imgs.shape[0]):
      text = texts[i]
      if isinstance(text, bytes):
        text = six.ensure_text(text)
      # You may need to adjust the text size and position.
      # If your images are in the [0, 255] range, replace (0, 0, 1) with
      # (0, 0, 255).
      result[i, :, :, :] = cv2.putText(
          imgs[i, :, :, :], str(text), text_pos, cv2.FONT_HERSHEY_COMPLEX,
          text_size, text_color, 1)
    return result

  return tf.py_func(_put_text, [imgs, texts], Tout=imgs.dtype)
Example #2
Source File: metrics.py From models with Apache License 2.0
def rouge_l_fscore(predictions, labels):
  """ROUGE scores computation between labels and predictions.

  This is an approximate ROUGE scoring method since we do not glue word
  pieces or decode the ids and tokenize the output.

  Args:
    predictions: tensor, model predictions
    labels: tensor, gold output.

  Returns:
    rouge_l_fscore: approx rouge-l f1 score.
  """
  outputs = tf.to_int32(tf.argmax(predictions, axis=-1))
  rouge_l_f_score = tf.py_func(rouge_l_sentence_level, (outputs, labels),
                               tf.float32)
  return rouge_l_f_score, tf.constant(1.0)
Example #3
Source File: metrics.py From models with Apache License 2.0
def rouge_2_fscore(logits, labels):
  """ROUGE-2 F1 score computation between labels and predictions.

  This is an approximate ROUGE scoring method since we do not glue word
  pieces or decode the ids and tokenize the output.

  Args:
    logits: tensor, model predictions
    labels: tensor, gold output.

  Returns:
    rouge2_fscore: approx rouge-2 f1 score.
  """
  predictions = tf.to_int32(tf.argmax(logits, axis=-1))
  # TODO: Look into removing use of py_func
  rouge_2_f_score = tf.py_func(rouge_n, (predictions, labels), tf.float32)
  return rouge_2_f_score, tf.constant(1.0)
Example #4
Source File: metrics.py From models with Apache License 2.0
def bleu_score(logits, labels):
  """Approximate BLEU score computation between labels and predictions.

  An approximate BLEU scoring method since we do not glue word pieces or
  decode the ids and tokenize the output. By default, we use ngram order of 4
  and use brevity penalty. Also, this does not have beam search.

  Args:
    logits: Tensor of size [batch_size, length_logits, vocab_size]
    labels: Tensor of size [batch_size, length_labels]

  Returns:
    bleu: float32 scalar, approx bleu score
  """
  predictions = tf.to_int32(tf.argmax(logits, axis=-1))
  # TODO: Look into removing use of py_func
  bleu = tf.py_func(compute_bleu, (labels, predictions), tf.float32)
  return bleu, tf.constant(1.0)
Example #5
Source File: tensor_utils.py From language with Apache License 2.0
def shaped_py_func(func, inputs, types, shapes, stateful=True, name=None):
  """Wrapper around tf.py_func that adds static shape information to the output.

  Args:
    func: Python function to call.
    inputs: List of input tensors.
    types: List of output tensor types.
    shapes: List of output tensor shapes.
    stateful: Whether or not the python function is stateful.
    name: Name of the op.

  Returns:
    output_tensors: List of output tensors.
  """
  output_tensors = tf.py_func(
      func=func, inp=inputs, Tout=types, stateful=stateful, name=name)
  for t, s in zip(output_tensors, shapes):
    t.set_shape(s)
  return output_tensors
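A hypothetical call, to show what the wrapper adds (the values and shapes here are made up for illustration):

values = tf.constant([1.0, 2.0, 3.0])
doubled, = shaped_py_func(
    func=lambda x: [2.0 * x],
    inputs=[values],
    types=[tf.float32],
    shapes=[[3]])
# Without the wrapper the static shape would be unknown; here
# doubled.get_shape() reports (3,), so downstream ops can rely on it.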
Example #6
Source File: common.py From language with Apache License 2.0
def print_text(tf_sequences, vocab, use_bpe=False, predict_mode=False):
  """Print text."""

  def _print_separator():
    if not predict_mode:
      tf.logging.info("=" * 80)

  print_ops = [tf.py_func(_print_separator, [], [])]
  for name, tf_sequence, tf_length, convert2txt in tf_sequences:

    def _do_print(n, sequence, lengths, to_txt):
      if to_txt:
        s = sequence[0][:lengths[0]]
        output = id2text(s, vocab, use_bpe=use_bpe)
      else:
        output = " ".join(sequence[0])
      if not predict_mode:
        tf.logging.info("%s: %s", n, output)

    with tf.control_dependencies(print_ops):
      print_ops.append(
          tf.py_func(_do_print, [name, tf_sequence, tf_length, convert2txt],
                     []))
  with tf.control_dependencies(print_ops):
    return tf.py_func(_print_separator, [], [])
Example #7
Source File: robust_model.py From interval-bound-propagation with Apache License 2.0
def parse(data_dict):
  """Parse dataset from _data_gen into the same format as sst_binary."""
  sentiment = data_dict['label']
  sentence = data_dict['sentence']
  dense_chars = tf.decode_raw(sentence, tf.uint8)
  dense_chars.set_shape((None,))
  chars = tfp.math.dense_to_sparse(dense_chars)
  if six.PY3:
    safe_chr = lambda c: '?' if c >= 128 else chr(c)
  else:
    safe_chr = chr
  to_char = np.vectorize(safe_chr)
  chars = tf.SparseTensor(
      indices=chars.indices,
      values=tf.py_func(to_char, [chars.values], tf.string),
      dense_shape=chars.dense_shape)
  return {'sentiment': sentiment, 'sentence': chars}
Example #8
Source File: anchors.py From Object_Detection_Tracking with Apache License 2.0
def generate_detections(self, cls_outputs, box_outputs, indices, classes,
                        image_id, image_scale, level_index, min_score_thresh,
                        max_boxes_to_draw, use_tf=False):
  if use_tf:
    return _generate_detections_tf(
        cls_outputs,
        box_outputs,
        self._anchors.boxes,
        indices,
        classes,
        image_id,
        image_scale,
        level_index,
        min_score_thresh=min_score_thresh,
        max_boxes_to_draw=max_boxes_to_draw)
  else:
    return tf.py_func(_generate_detections, [
        cls_outputs, box_outputs, self._anchors.boxes, indices, classes,
        image_id, image_scale, self._num_classes, level_index,
        #image_id, image_scale, self._target_classes, level_index,
    ], [tf.float32, tf.float32, tf.float32, tf.float32])
Example #9
Source File: data.py From magenta with Apache License 2.0
def transform_wav_data_op(wav_data_tensor, hparams, jitter_amount_sec):
  """Transforms audio for data augmentation. Only for training."""

  def transform_wav_data(wav_data):
    """Transforms with sox."""
    if jitter_amount_sec:
      wav_data = audio_io.jitter_wav_data(wav_data, hparams.sample_rate,
                                          jitter_amount_sec)
    wav_data = audio_transform.transform_wav_audio(wav_data, hparams)
    return [wav_data]

  return tf.py_func(
      transform_wav_data, [wav_data_tensor],
      tf.string,
      name='transform_wav_data_op')
Example #10
Source File: rouge.py From tensor2tensor with Apache License 2.0
def rouge_2_fscore(predictions, labels, **unused_kwargs):
  """ROUGE-2 F1 score computation between labels and predictions.

  This is an approximate ROUGE scoring method since we do not glue word
  pieces or decode the ids and tokenize the output.

  Args:
    predictions: tensor, model predictions
    labels: tensor, gold output.

  Returns:
    rouge2_fscore: approx rouge-2 f1 score.
  """
  outputs = tf.to_int32(tf.argmax(predictions, axis=-1))
  # Convert the outputs and labels to a [batch_size, input_length] tensor.
  outputs = tf.squeeze(outputs, axis=[-1, -2])
  labels = tf.squeeze(labels, axis=[-1, -2])
  rouge_2_f_score = tf.py_func(rouge_n, (outputs, labels), tf.float32)
  return rouge_2_f_score, tf.constant(1.0)
Example #11
Source File: bleu_hook.py From tensor2tensor with Apache License 2.0
def bleu_score(predictions, labels, **unused_kwargs):
  """BLEU score computation between labels and predictions.

  An approximate BLEU scoring method since we do not glue word pieces or
  decode the ids and tokenize the output. By default, we use ngram order of 4
  and use brevity penalty. Also, this does not have beam search.

  Args:
    predictions: tensor, model predictions
    labels: tensor, gold output.

  Returns:
    bleu: float32 scalar, approx bleu score
  """
  outputs = tf.to_int32(tf.argmax(predictions, axis=-1))
  # Convert the outputs and labels to a [batch_size, input_length] tensor.
  outputs = tf.squeeze(outputs, axis=[-1, -2])
  labels = tf.squeeze(labels, axis=[-1, -2])
  bleu = tf.py_func(compute_bleu, (labels, outputs), tf.float32)
  return bleu, tf.constant(1.0)
Example #12
Source File: decoding.py From tensor2tensor with Apache License 2.0
def make_input_fn_from_generator(gen):
  """Use py_func to yield elements from the given generator."""
  first_ex = six.next(gen)
  flattened = contrib.framework().nest.flatten(first_ex)
  types = [t.dtype for t in flattened]
  shapes = [[None] * len(t.shape) for t in flattened]
  first_ex_list = [first_ex]

  def py_func():
    if first_ex_list:
      example = first_ex_list.pop()
    else:
      example = six.next(gen)
    return contrib.framework().nest.flatten(example)

  def input_fn():
    flat_example = tf.py_func(py_func, [], types)
    _ = [t.set_shape(shape) for t, shape in zip(flat_example, shapes)]
    example = contrib.framework().nest.pack_sequence_as(first_ex, flat_example)
    return example

  return input_fn
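A hypothetical use, assuming a generator of feature dicts (the names are invented for illustration):

def examples():
  while True:
    yield {"inputs": np.zeros((5,), dtype=np.int64)}

input_fn = make_input_fn_from_generator(examples())
features = input_fn()
# features["inputs"] is an int64 tensor of shape (None,); each session.run
# of the graph pulls the next element from the generator.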
Example #13
Source File: visualization_utils.py From models with Apache License 2.0
def add_cdf_image_summary(values, name):
  """Adds a tf.summary.image for a CDF plot of the values.

  Normalizes `values` such that they sum to 1, plots the cumulative
  distribution function and creates a tf image summary.

  Args:
    values: a 1-D float32 tensor containing the values.
    name: name for the image summary.
  """

  def cdf_plot(values):
    """Numpy function to plot CDF."""
    normalized_values = values / np.sum(values)
    sorted_values = np.sort(normalized_values)
    cumulative_values = np.cumsum(sorted_values)
    fraction_of_examples = (
        np.arange(cumulative_values.size, dtype=np.float32) /
        cumulative_values.size)
    fig = plt.figure(frameon=False)
    ax = fig.add_subplot('111')
    ax.plot(fraction_of_examples, cumulative_values)
    ax.set_ylabel('cumulative normalized values')
    ax.set_xlabel('fraction of examples')
    fig.canvas.draw()
    width, height = fig.get_size_inches() * fig.get_dpi()
    image = np.fromstring(fig.canvas.tostring_rgb(), dtype='uint8').reshape(
        1, int(height), int(width), 3)
    return image

  cdf_plot = tf.py_func(cdf_plot, [values], tf.uint8)
  tf.summary.image(name, cdf_plot)
Example #14
Source File: search_utils.py From language with Apache License 2.0
def write_to_checkpoint(var_name, np_db, dtype, checkpoint_path):
  """Write np array to checkpoint."""
  with tf.Graph().as_default():
    init_value = tf.py_func(lambda: np_db, [], dtype, stateful=False)
    init_value.set_shape(np_db.shape)
    tf_db = tf.get_variable(var_name, initializer=init_value)
    saver = tf.train.Saver([tf_db])
    with tf.Session() as session:
      session.run(tf.global_variables_initializer())
      saver.save(session, checkpoint_path)
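Initializing the variable through a stateless py_func, rather than a tf.constant, keeps the array's contents out of the serialized GraphDef, which matters for large arrays that would otherwise push the graph toward the 2 GB protobuf limit. A hypothetical round trip, with made-up names and paths:

write_to_checkpoint("np_db", np.ones((4, 2), np.float32), tf.float32,
                    "/tmp/np_db.ckpt")
with tf.Graph().as_default():
  tf_db = tf.get_variable("np_db", shape=[4, 2], dtype=tf.float32)
  saver = tf.train.Saver([tf_db])
  with tf.Session() as session:
    saver.restore(session, "/tmp/np_db.ckpt")
    print(session.run(tf_db))  # the array written above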
Example #15
Source File: search_utils.py From language with Apache License 2.0
def write_ragged_to_checkpoint(var_name, sp_mat, checkpoint_path):
  """Write scipy CSR matrix to checkpoint for loading to ragged tensor."""
  data = sp_mat.data
  indices = sp_mat.indices
  rowsplits = sp_mat.indptr
  with tf.Graph().as_default():
    init_data = tf.py_func(
        lambda: data.astype(np.float32), [], tf.float32, stateful=False)
    init_data.set_shape(data.shape)
    init_indices = tf.py_func(
        lambda: indices.astype(np.int64), [], tf.int64, stateful=False)
    init_indices.set_shape(indices.shape)
    init_rowsplits = tf.py_func(
        lambda: rowsplits.astype(np.int64), [], tf.int64, stateful=False)
    init_rowsplits.set_shape(rowsplits.shape)
    tf_data = tf.get_variable(var_name + "_data", initializer=init_data)
    tf_indices = tf.get_variable(
        var_name + "_indices", initializer=init_indices)
    tf_rowsplits = tf.get_variable(
        var_name + "_rowsplits", initializer=init_rowsplits)
    saver = tf.train.Saver([tf_data, tf_indices, tf_rowsplits])
    with tf.Session() as session:
      session.run(tf.global_variables_initializer())
      saver.save(session, checkpoint_path)
  with tf.gfile.Open(checkpoint_path + ".info", "w") as f:
    f.write(str(sp_mat.shape[0]) + " " + str(sp_mat.getnnz()))
Example #16
Source File: search_utils.py From language with Apache License 2.0
def write_sparse_to_checkpoint(var_name, sp_mat, checkpoint_path):
  """Write scipy sparse CSR matrix to checkpoint."""
  sp_mat = sp_mat.tocoo()
  # Sort the indices lexicographically.
  sort_i = np.lexsort((sp_mat.col, sp_mat.row))
  indices = np.mat([sp_mat.row[sort_i], sp_mat.col[sort_i]]).transpose()
  data = sp_mat.data[sort_i]
  with tf.Graph().as_default():
    init_data = tf.py_func(
        lambda: data.astype(np.float32), [], tf.float32, stateful=False)
    init_data.set_shape(data.shape)
    init_indices = tf.py_func(
        lambda: indices.astype(np.int64), [], tf.int64, stateful=False)
    init_indices.set_shape(indices.shape)
    init_shape = tf.py_func(
        lambda: np.array(sp_mat.shape, dtype=np.int64), [], tf.int64,
        stateful=False)
    init_shape.set_shape([len(sp_mat.shape)])
    tf_data = tf.get_variable(var_name + "_data", initializer=init_data)
    tf_indices = tf.get_variable(
        var_name + "_indices", initializer=init_indices)
    tf_shape = tf.get_variable(var_name + "_shape", initializer=init_shape)
    saver = tf.train.Saver([tf_data, tf_indices, tf_shape])
    with tf.Session() as session:
      session.run(tf.global_variables_initializer())
      saver.save(session, checkpoint_path)
  with tf.gfile.Open(checkpoint_path + ".info", "w") as f:
    f.write(str(sp_mat.getnnz()))
Example #17
Source File: rouge_utils.py From language with Apache License 2.0
def rouge_l_metric(predictions, prediction_len, labels, label_len, vocab,
                   use_bpe=False):
  return tf.metrics.mean(
      tf.py_func(
          partial(rouge_l, vocab=vocab, use_bpe=use_bpe),
          [predictions, prediction_len, labels, label_len], tf.float32))
Example #18
Source File: simulated_batch_env.py From tensor2tensor with Apache License 2.0
def _reset_non_empty(self, indices):
  """Reset the batch of environments.

  Args:
    indices: The batch indices of the environments to reset; defaults to all.

  Returns:
    Batch tensor of the new observations.
  """
  reset_video_op = tf.cond(
      self._video_condition,
      lambda: tf.py_func(self._video_reset_writer, [], []),
      tf.no_op)
  with tf.control_dependencies([reset_video_op]):
    inc_op = tf.assign_add(self._episode_counter, 1)
    with tf.control_dependencies([self.history_buffer.reset(indices),
                                  inc_op]):
      initial_frame_dump_op = tf.cond(
          self._video_condition,
          lambda: tf.py_func(self._video_dump_frames,  # pylint: disable=g-long-lambda
                             [self.history_buffer.get_all_elements()], []),
          tf.no_op)
      observ_assign_op = self._observ.assign(
          self.history_buffer.get_all_elements()[:, -1, ...])
      with tf.control_dependencies([observ_assign_op,
                                    initial_frame_dump_op]):
        reset_model_op = tf.assign(self._reset_model, tf.constant(1.0))
        with tf.control_dependencies([reset_model_op]):
          return tf.gather(self._observ.read_value(), indices)
Example #19
Source File: ops.py From language with Apache License 2.0
def lowercase_op(string_tensor):
  """Lowercase an arbitrarily sized string tensor."""
  shape = tf.shape(string_tensor)
  lc = tf.py_func(_lowercase, [tf.reshape(string_tensor, [-1])], tf.string,
                  False)
  return tf.reshape(lc, shape)
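The _lowercase helper is defined elsewhere in ops.py and is not shown on this page; a plausible sketch of it, offered as an assumption rather than the original source:

def _lowercase(strings):
  # Assumed helper: `strings` is a 1-D NumPy array of bytes objects; py_func
  # converts the returned object array back into a tf.string tensor.
  return np.array([s.lower() for s in strings], dtype=object)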
Example #20
Source File: common_video.py From tensor2tensor with Apache License 2.0
def gif_summary(name, tensor, max_outputs=3, fps=10, collections=None,
                family=None):
  """Outputs a `Summary` protocol buffer with gif animations.

  Args:
    name: Name of the summary.
    tensor: A 5-D `uint8` `Tensor` of shape `[batch_size, time, height, width,
      channels]` where `channels` is 1 or 3.
    max_outputs: Max number of batch elements to generate gifs for.
    fps: frames per second of the animation.
    collections: Optional list of tf.GraphKeys. The collections to add the
      summary to. Defaults to [tf.GraphKeys.SUMMARIES].
    family: Optional; if provided, used as the prefix of the summary tag name,
      which controls the tab name used for display on Tensorboard.

  Returns:
    A scalar `Tensor` of type `string`. The serialized `Summary` protocol
    buffer.

  Raises:
    ValueError: if the given tensor has the wrong shape.
  """
  tensor = tf.convert_to_tensor(tensor)
  if len(tensor.get_shape()) != 5:
    raise ValueError("Assuming videos given as tensors in the format "
                     "[batch, time, height, width, channels] but got one "
                     "of shape: %s" % str(tensor.get_shape()))
  tensor = tf.cast(tensor, tf.uint8)
  if distribute_summary_op_util.skip_summary():
    return tf.constant("")
  with summary_op_util.summary_scope(
      name, family, values=[tensor]) as (tag, scope):
    val = tf.py_func(
        py_gif_summary, [tag, tensor, max_outputs, fps],
        tf.string,
        stateful=False,
        name=scope)
    summary_op_util.collect(val, collections, [tf.GraphKeys.SUMMARIES])
  return val
Example #21
Source File: visualization_utils.py From models with Apache License 2.0
def add_hist_image_summary(values, bins, name):
  """Adds a tf.summary.image for a histogram plot of the values.

  Plots the histogram of values and creates a tf image summary.

  Args:
    values: a 1-D float32 tensor containing the values.
    bins: bin edges which will be directly passed to np.histogram.
    name: name for the image summary.
  """

  def hist_plot(values, bins):
    """Numpy function to plot hist."""
    fig = plt.figure(frameon=False)
    ax = fig.add_subplot('111')
    y, x = np.histogram(values, bins=bins)
    ax.plot(x[:-1], y)
    ax.set_ylabel('count')
    ax.set_xlabel('value')
    fig.canvas.draw()
    width, height = fig.get_size_inches() * fig.get_dpi()
    image = np.fromstring(
        fig.canvas.tostring_rgb(), dtype='uint8').reshape(
            1, int(height), int(width), 3)
    return image

  hist_plot = tf.py_func(hist_plot, [values, bins], tf.uint8)
  tf.summary.image(name, hist_plot)
Example #22
Source File: object_detection_evaluation.py From models with Apache License 2.0
def get_estimator_eval_metric_ops(self, eval_dict):
  """Returns dict of metrics to use with `tf.estimator.EstimatorSpec`.

  Note that this must only be implemented if performing evaluation with a
  `tf.estimator.Estimator`.

  Args:
    eval_dict: A dictionary that holds tensors for evaluating an object
      detection model, returned from
      eval_util.result_dict_for_single_example(). It must contain
      standard_fields.InputDataFields.key.

  Returns:
    A dictionary of metric names to tuple of value_op and update_op that can
    be used as eval metric ops in `tf.estimator.EstimatorSpec`.
  """
  update_op = self.add_eval_dict(eval_dict)

  def first_value_func():
    self._metrics = self.evaluate()
    self.clear()
    return np.float32(self._metrics[self._metric_names[0]])

  def value_func_factory(metric_name):
    def value_func():
      return np.float32(self._metrics[metric_name])
    return value_func

  # Ensure that the metrics are only evaluated once.
  first_value_op = tf.py_func(first_value_func, [], tf.float32)
  eval_metric_ops = {self._metric_names[0]: (first_value_op, update_op)}
  with tf.control_dependencies([first_value_op]):
    for metric_name in self._metric_names[1:]:
      eval_metric_ops[metric_name] = (tf.py_func(
          value_func_factory(metric_name), [], np.float32), update_op)
  return eval_metric_ops
Example #23
Source File: metrics.py From models with Apache License 2.0
def get_eval_metrics(logits, labels, params):
  """Return dictionary of model evaluation metrics."""
  metrics = {
      "accuracy": _convert_to_eval_metric(padded_accuracy)(logits, labels),
      "accuracy_top5": _convert_to_eval_metric(padded_accuracy_top5)(
          logits, labels),
      "accuracy_per_sequence": _convert_to_eval_metric(
          padded_sequence_accuracy)(logits, labels),
      "neg_log_perplexity": _convert_to_eval_metric(padded_neg_log_perplexity)(
          logits, labels, params["vocab_size"]),
  }

  if not params["use_tpu"]:
    # TPU does not support tf.py_func
    metrics.update({
        "approx_bleu_score": _convert_to_eval_metric(
            bleu_score)(logits, labels),
        "rouge_2_fscore": _convert_to_eval_metric(
            rouge_2_fscore)(logits, labels),
        "rouge_L_fscore": _convert_to_eval_metric(
            rouge_l_fscore)(logits, labels),
    })

  # Prefix each of the metric names with "metrics/". This allows the metric
  # graphs to display under the "metrics" category in TensorBoard.
  metrics = {"metrics/%s" % k: v for k, v in six.iteritems(metrics)}
  return metrics
Example #24
Source File: simulated_batch_env.py From tensor2tensor with Apache License 2.0
def __init__(self, initial_frame_chooser, observ_shape, observ_dtype,
             num_initial_frames, batch_size):
  self.batch_size = batch_size
  self._observ_dtype = observ_dtype
  initial_shape = (batch_size, num_initial_frames) + observ_shape
  self._initial_frames = tf.py_func(
      initial_frame_chooser, [tf.constant(batch_size)], observ_dtype)
  self._initial_frames.set_shape(initial_shape)
  self._history_buff = tf.Variable(
      tf.zeros(initial_shape, observ_dtype), trainable=False)
Example #25
Source File: py_func_batch_env.py From tensor2tensor with Apache License 2.0
def _reset_non_empty(self, indices):
  """Reset the batch of environments.

  Args:
    indices: The batch indices of the environments to reset; defaults to all.

  Returns:
    Batch tensor of the new observations.
  """
  observ = tf.py_func(
      self._batch_env.reset, [indices], self.observ_dtype, name="reset")
  observ.set_shape(indices.get_shape().concatenate(self.observ_shape))
  with tf.control_dependencies([
      tf.scatter_update(self._observ, indices, observ)]):
    return tf.identity(observ)
Example #26
Source File: py_func_batch_env.py From tensor2tensor with Apache License 2.0
def simulate(self, action):
  """Step the batch of environments.

  The results of the step can be accessed from the variables defined below.

  Args:
    action: Tensor holding the batch of actions to apply.

  Returns:
    Operation.
  """
  with tf.name_scope("environment/simulate"):
    if action.dtype in (tf.float16, tf.float32, tf.float64):
      action = tf.check_numerics(action, "action")

    def step(action):
      step_response = self._batch_env.step(action)
      # Current env doesn't return `info`, but EnvProblem does.
      # TODO(afrozm): The proper way to do this is to make T2TGymEnv return
      # an empty info return value.
      if len(step_response) == 3:
        (observ, reward, done) = step_response
      else:
        (observ, reward, done, _) = step_response
      return (observ, reward.astype(np.float32), done)

    observ, reward, done = tf.py_func(
        step, [action],
        [self.observ_dtype, tf.float32, tf.bool], name="step")
    reward = tf.check_numerics(reward, "reward")
    reward.set_shape((len(self),))
    done.set_shape((len(self),))
    with tf.control_dependencies([self._observ.assign(observ)]):
      return tf.identity(reward), tf.identity(done)
Example #27
Source File: utils.py From magenta with Apache License 2.0
def tf_specgram(audio,
                n_fft=512,
                hop_length=None,
                mask=True,
                log_mag=True,
                re_im=False,
                dphase=True,
                mag_only=False):
  """Specgram tensorflow op (uses pyfunc)."""
  return tf.py_func(batch_specgram, [
      audio, n_fft, hop_length, mask, log_mag, re_im, dphase, mag_only
  ], tf.float32)
Example #28
Source File: sari_hook.py From tensor2tensor with Apache License 2.0
def get_sari(source_ids, prediction_ids, target_ids, max_gram_size=4):
  """Computes the SARI scores from the given source, prediction and targets.

  Args:
    source_ids: A 2D tf.Tensor of size (batch_size, sequence_length)
    prediction_ids: A 2D tf.Tensor of size (batch_size, sequence_length)
    target_ids: A 3D tf.Tensor of size (batch_size, number_of_targets,
      sequence_length)
    max_gram_size: int. largest n-gram size we care about (e.g. 3 for
      unigrams, bigrams, and trigrams)

  Returns:
    A 4-tuple of 1D float Tensors of size (batch_size) for the SARI score and
    the keep, addition and deletion scores.
  """

  def get_sari_numpy(source_ids, prediction_ids, target_ids):
    """Iterate over elements in the batch and call the SARI function."""
    sari_scores = []
    keep_scores = []
    add_scores = []
    deletion_scores = []
    # Iterate over elements in the batch.
    for source_ids_i, prediction_ids_i, target_ids_i in zip(
        source_ids, prediction_ids, target_ids):
      sari, keep, add, deletion = get_sari_score(
          source_ids_i, prediction_ids_i, target_ids_i, max_gram_size,
          BETA_FOR_SARI_DELETION_F_MEASURE)
      sari_scores.append(sari)
      keep_scores.append(keep)
      add_scores.append(add)
      deletion_scores.append(deletion)
    return (np.asarray(sari_scores), np.asarray(keep_scores),
            np.asarray(add_scores), np.asarray(deletion_scores))

  sari, keep, add, deletion = tf.py_func(
      get_sari_numpy, [source_ids, prediction_ids, target_ids],
      [tf.float64, tf.float64, tf.float64, tf.float64])
  return sari, keep, add, deletion
Example #29
Source File: data.py From magenta with Apache License 2.0
def wav_to_spec_op(wav_audio, hparams):
  """Return an op for converting wav audio to a spectrogram."""
  if hparams.spec_type == 'tflite_compat_mel':
    assert hparams.spec_log_amplitude
    spec = tflite_compat_mel(wav_audio, hparams=hparams)
  else:
    spec = tf.py_func(
        functools.partial(wav_to_spec, hparams=hparams), [wav_audio],
        tf.float32,
        name='wav_to_spec')
    spec.set_shape([None, hparams_frame_size(hparams)])
  return spec
Example #30
Source File: data.py From magenta with Apache License 2.0
def get_spectrogram_hash_op(spectrogram):
  """Calculate hash of the spectrogram."""

  def get_spectrogram_hash(spectrogram):
    # Compute a hash of the spectrogram, save it as an int64.
    # Uses adler because it's fast and will fit into an int (md5 is too
    # large).
    spectrogram_serialized = io.BytesIO()
    np.save(spectrogram_serialized, spectrogram)
    spectrogram_hash = np.int64(
        zlib.adler32(spectrogram_serialized.getvalue()))
    spectrogram_serialized.close()
    return spectrogram_hash

  spectrogram_hash = tf.py_func(
      get_spectrogram_hash, [spectrogram], tf.int64,
      name='get_spectrogram_hash')
  spectrogram_hash.set_shape([])
  return spectrogram_hash