Python tensorflow.tables_initializer() Examples
The following are 30 code examples of `tensorflow.tables_initializer()`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `tensorflow`, or try the search function.
Example #1
Source File: model_test.py From yolo_v2 with Apache License 2.0 | 6 votes |
def test_create_summaries_is_runnable(self):
    """Smoke test: the OCR model's summary ops can be evaluated.

    Builds the model on the fake fixture images/labels, creates summaries in
    inference mode (``is_training=False``), initializes global variables,
    local variables and lookup tables, then runs the summaries once.
    """
    ocr_model = self.create_model()
    data = data_provider.InputEndpoints(
        images=self.fake_images,
        images_orig=self.fake_images,
        labels=self.fake_labels,
        labels_one_hot=slim.one_hot_encoding(self.fake_labels,
                                             self.num_char_classes))
    endpoints = ocr_model.create_base(images=self.fake_images,
                                      labels_one_hot=None)
    fake_charset = create_fake_charset(self.num_char_classes)
    summary_ops = ocr_model.create_summaries(
        data, endpoints, fake_charset, is_training=False)
    with self.test_session() as sess:
        init_ops = (tf.global_variables_initializer(),
                    tf.local_variables_initializer(),
                    tf.tables_initializer())
        for init_op in init_ops:
            sess.run(init_op)
        sess.run(summary_ops)  # only checks that the summaries are runnable
Example #2
Source File: estimator_test.py From training_results_v0.5 with Apache License 2.0 | 6 votes |
def testTrainInputFn(self):
    """Build the TRAIN-mode input_fn and log every field of its first batch."""
    parser = argparse.ArgumentParser()
    nmt.add_arguments(parser)
    flags, _ = parser.parse_known_args()
    update_flags(flags, "input_fn_test")
    hparams = nmt.extend_hparams(nmt.create_hparams(flags))
    with self.test_session() as sess:
        input_fn = make_input_fn(hparams, tf.contrib.learn.ModeKeys.TRAIN)
        dataset = input_fn({})
        sess.run(tf.tables_initializer())
        iterator = dataset.make_initializable_iterator()
        sess.run(iterator.initializer)
        features = sess.run(iterator.get_next())
        for key in ("source", "target_input", "target_output",
                    "source_sequence_length", "target_sequence_length"):
            tf.logging.info(key + ": %s", features[key])
Example #3
Source File: utils.py From TransE-Knowledge-Graph-Embedding with MIT License | 6 votes |
def load_model(sess, ckpt):
    """Initialize everything in ``sess`` and restore the latest checkpoint.

    Args:
      sess: session whose graph already contains the model.
      ckpt: directory searched with ``tf.train.latest_checkpoint``.

    Raises:
      ValueError: if no checkpoint is found in ``ckpt``.
    """
    with sess.as_default(), sess.graph.as_default():
        sess.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer(),
                  tf.tables_initializer()])
        # Locate the saved model; fail loudly if there is none.
        ckpt_path = tf.train.latest_checkpoint(ckpt)
        if not ckpt_path:
            raise ValueError("No checkpoint found in {}".format(ckpt))
        print("Loading saved model: " + ckpt_path)
        tf.train.Saver().restore(sess, ckpt_path)
Example #4
Source File: test_case.py From Person-Detection-and-Tracking with MIT License | 6 votes |
def execute_cpu(self, graph_fn, inputs):
    """Constructs the graph, executes it on CPU and returns the result.

    Args:
      graph_fn: a callable that constructs the tensorflow graph to test. The
        arguments of this function should correspond to `inputs`.
      inputs: a list of numpy arrays to feed input to the computation graph.

    Returns:
      A list of numpy arrays or a scalar returned from executing the tensorflow
      graph. A single-element list/tuple result is unwrapped to its element.
    """
    with self.test_session(graph=tf.Graph()) as sess:
        placeholders = [tf.placeholder_with_default(v, v.shape) for v in inputs]
        results = graph_fn(*placeholders)
        sess.run([tf.global_variables_initializer(),
                  tf.tables_initializer(),
                  tf.local_variables_initializer()])
        materialized_results = sess.run(
            results, feed_dict=dict(zip(placeholders, inputs)))
        # Check the container type BEFORE calling len(): when graph_fn returns a
        # single scalar tensor, sess.run yields a 0-d numpy value and len()
        # would raise TypeError.
        if (isinstance(materialized_results, (list, tuple))
                and len(materialized_results) == 1):
            materialized_results = materialized_results[0]
    return materialized_results
Example #5
Source File: estimator_test.py From training_results_v0.5 with Apache License 2.0 | 6 votes |
def testTrainInputFn(self):
    """Build the TRAIN-mode input_fn and log every field of its first batch."""
    parser = argparse.ArgumentParser()
    nmt.add_arguments(parser)
    flags, _ = parser.parse_known_args()
    update_flags(flags, "input_fn_test")
    hparams = nmt.extend_hparams(nmt.create_hparams(flags))
    with self.test_session() as sess:
        input_fn = make_input_fn(hparams, tf.contrib.learn.ModeKeys.TRAIN)
        dataset = input_fn({})
        sess.run(tf.tables_initializer())
        iterator = dataset.make_initializable_iterator()
        sess.run(iterator.initializer)
        features = sess.run(iterator.get_next())
        for key in ("source", "target_input", "target_output",
                    "source_sequence_length", "target_sequence_length"):
            tf.logging.info(key + ": %s", features[key])
Example #6
Source File: model_test.py From DOTA_models with Apache License 2.0 | 6 votes |
def test_create_summaries_is_runnable(self):
    """Smoke test: the OCR model's summary ops can be evaluated.

    Builds the model on the fake fixture images/labels, creates summaries in
    inference mode (``is_training=False``), initializes global variables,
    local variables and lookup tables, then runs the summaries once.
    """
    ocr_model = self.create_model()
    data = data_provider.InputEndpoints(
        images=self.fake_images,
        images_orig=self.fake_images,
        labels=self.fake_labels,
        labels_one_hot=slim.one_hot_encoding(self.fake_labels,
                                             self.num_char_classes))
    endpoints = ocr_model.create_base(images=self.fake_images,
                                      labels_one_hot=None)
    fake_charset = create_fake_charset(self.num_char_classes)
    summary_ops = ocr_model.create_summaries(
        data, endpoints, fake_charset, is_training=False)
    with self.test_session() as sess:
        init_ops = (tf.global_variables_initializer(),
                    tf.local_variables_initializer(),
                    tf.tables_initializer())
        for init_op in init_ops:
            sess.run(init_op)
        sess.run(summary_ops)  # only checks that the summaries are runnable
Example #7
Source File: utils.py From realmix with Apache License 2.0 | 6 votes |
def make_set_filter_fn(elements):
    """Build a membership-test function backed by a TensorFlow hash table.

    The table is created uninitialized; run
    ``sess.run(tf.tables_initializer())`` before evaluating its output.

    Args:
      elements: A list of non-Tensor elements forming the set.

    Returns:
      A function that takes one tensor ``x`` and returns a boolean tensor that
      is True where ``x`` (cast to int32) is a member of the set.
    """
    # Every key maps to 1; missing keys fall back to default_value=0.
    ones = tf.tile([1], [len(elements)])
    initializer = tf.contrib.lookup.KeyValueTensorInitializer(elements, ones)
    table = tf.contrib.lookup.HashTable(initializer, default_value=0)

    def is_member(x):
        found = table.lookup(tf.dtypes.cast(x, tf.int32))
        return tf.equal(found, 1)

    return is_member
Example #8
Source File: test_case.py From ros_people_object_detection_tensorflow with Apache License 2.0 | 6 votes |
def execute_cpu(self, graph_fn, inputs):
    """Constructs the graph, executes it on CPU and returns the result.

    Args:
      graph_fn: a callable that constructs the tensorflow graph to test. The
        arguments of this function should correspond to `inputs`.
      inputs: a list of numpy arrays to feed input to the computation graph.

    Returns:
      A list of numpy arrays or a scalar returned from executing the tensorflow
      graph. A single-element list/tuple result is unwrapped to its element.
    """
    with self.test_session(graph=tf.Graph()) as sess:
        placeholders = [tf.placeholder_with_default(v, v.shape) for v in inputs]
        results = graph_fn(*placeholders)
        sess.run([tf.global_variables_initializer(),
                  tf.tables_initializer(),
                  tf.local_variables_initializer()])
        materialized_results = sess.run(
            results, feed_dict=dict(zip(placeholders, inputs)))
        # Check the container type BEFORE calling len(): when graph_fn returns a
        # single scalar tensor, sess.run yields a 0-d numpy value and len()
        # would raise TypeError.
        if (isinstance(materialized_results, (list, tuple))
                and len(materialized_results) == 1):
            materialized_results = materialized_results[0]
    return materialized_results
Example #9
Source File: test_case.py From Traffic-Rule-Violation-Detection-System with MIT License | 6 votes |
def execute_cpu(self, graph_fn, inputs):
    """Constructs the graph, executes it on CPU and returns the result.

    Args:
      graph_fn: a callable that constructs the tensorflow graph to test. The
        arguments of this function should correspond to `inputs`.
      inputs: a list of numpy arrays to feed input to the computation graph.

    Returns:
      A list of numpy arrays or a scalar returned from executing the tensorflow
      graph. A single-element list/tuple result is unwrapped to its element.
    """
    with self.test_session(graph=tf.Graph()) as sess:
        placeholders = [tf.placeholder_with_default(v, v.shape) for v in inputs]
        results = graph_fn(*placeholders)
        sess.run([tf.global_variables_initializer(),
                  tf.tables_initializer(),
                  tf.local_variables_initializer()])
        materialized_results = sess.run(
            results, feed_dict=dict(zip(placeholders, inputs)))
        # Only unwrap real single-element sequences: len() on a 0-d numpy value
        # raises TypeError, and [0] on a length-1 ndarray or a 1-key dict would
        # silently drop a dimension or raise KeyError.
        if (isinstance(materialized_results, (list, tuple))
                and len(materialized_results) == 1):
            materialized_results = materialized_results[0]
    return materialized_results
Example #10
Source File: test_case.py From Traffic-Rule-Violation-Detection-System with MIT License | 6 votes |
def execute_tpu(self, graph_fn, inputs):
    """Constructs the graph, executes it on TPU and returns the result.

    Args:
      graph_fn: a callable that constructs the tensorflow graph to test. The
        arguments of this function should correspond to `inputs`.
      inputs: a list of numpy arrays to feed input to the computation graph.

    Returns:
      A list of numpy arrays or a scalar returned from executing the tensorflow
      graph. A single-element list/tuple result is unwrapped to its element.
    """
    with self.test_session(graph=tf.Graph()) as sess:
        placeholders = [tf.placeholder_with_default(v, v.shape) for v in inputs]
        tpu_computation = tpu.rewrite(graph_fn, placeholders)
        sess.run(tpu.initialize_system())
        try:
            sess.run([tf.global_variables_initializer(),
                      tf.tables_initializer(),
                      tf.local_variables_initializer()])
            materialized_results = sess.run(
                tpu_computation, feed_dict=dict(zip(placeholders, inputs)))
        finally:
            # Run the shutdown even when initialization or the computation
            # raises; previously a failure skipped it entirely.
            sess.run(tpu.shutdown_system())
        # Only unwrap real single-element sequences: len() on a 0-d numpy value
        # raises TypeError, and [0] on a length-1 ndarray or a 1-key dict would
        # silently drop a dimension or raise KeyError.
        if (isinstance(materialized_results, (list, tuple))
                and len(materialized_results) == 1):
            materialized_results = materialized_results[0]
    return materialized_results
Example #11
Source File: model_helper.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def create_or_load_model(model, model_dir, session, name):
    """Create translation model and initialize or load parameters in session.

    When ``model_dir`` holds a checkpoint, the model is restored through
    ``load_model``. Otherwise global variables and lookup tables are freshly
    initialized and the elapsed time is reported.

    Returns:
      A ``(model, global_step)`` pair, where ``global_step`` is the value of
      ``model.global_step`` in ``session``.
    """
    latest_ckpt = tf.train.latest_checkpoint(model_dir)
    if not latest_ckpt:
        start_time = time.time()
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())
        elapsed = time.time() - start_time
        utils.print_out(" created %s model with fresh parameters, time %.2fs" %
                        (name, elapsed))
    else:
        model = load_model(model, latest_ckpt, session, name)

    return model, model.global_step.eval(session=session)
Example #12
Source File: estimator.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def _get_tgt_sos_eos_id(hparams):
    """Look up the target-vocabulary ids of the SOS and EOS tokens.

    The vocab tables and lookup ops are built in a private ``tf.Graph``. This
    keeps the helper from adding ops to the caller's default graph, and
    ``tf.tables_initializer()`` only touches the tables created here.

    Args:
      hparams: object providing ``src_vocab_file``, ``tgt_vocab_file``,
        ``share_vocab``, ``sos`` and ``eos``.

    Returns:
      ``(tgt_sos_id, tgt_eos_id)`` as evaluated int32 values.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
        _, tgt_vocab_table = vocab_utils.create_vocab_tables(
            hparams.src_vocab_file, hparams.tgt_vocab_file, hparams.share_vocab)
        sos_id_tensor = tf.cast(
            tgt_vocab_table.lookup(tf.constant(hparams.sos)), tf.int32)
        eos_id_tensor = tf.cast(
            tgt_vocab_table.lookup(tf.constant(hparams.eos)), tf.int32)
        sess.run(tf.tables_initializer())
        tgt_sos_id, tgt_eos_id = sess.run([sos_id_tensor, eos_id_tensor])
    return tgt_sos_id, tgt_eos_id
Example #13
Source File: iterator_utils_test.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def testGetInferIterator(self):
    """Infer iterator truncates to src_max_len and reports source lengths.

    Per the expected values: "c c a" stays three tokens, "c a" is padded with
    the eos id (3), and the lengths are [3, 2] for the first batch of two.
    """
    vocab_words = tf.constant(["a", "b", "c", "eos", "sos"])
    src_vocab_table = lookup_ops.index_table_from_tensor(vocab_words)
    src_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["c c a", "c a", "d", "f e a g"]))
    hparams = tf.contrib.training.HParams(random_seed=3, eos="eos", sos="sos")
    dataset = iterator_utils.get_infer_iterator(
        src_dataset=src_dataset,
        src_vocab_table=src_vocab_table,
        batch_size=2,
        eos=hparams.eos,
        src_max_len=3)
    table_initializer = tf.tables_initializer()
    iterator = dataset.make_initializable_iterator()
    next_batch = iterator.get_next()
    with self.test_session() as sess:
        sess.run(table_initializer)
        sess.run(iterator.initializer)
        features = sess.run(next_batch)
        expected_source = [[2, 2, 0],   # c c a
                           [2, 0, 3]]   # c a eos
        self.assertAllEqual(expected_source, features["source"])
        self.assertAllEqual([3, 2], features["source_sequence_length"])
Example #14
Source File: estimator.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def _get_tgt_sos_eos_id(hparams):
    """Look up the target-vocabulary ids of the SOS and EOS tokens.

    The vocab tables and lookup ops are built in a private ``tf.Graph``. This
    keeps the helper from adding ops to the caller's default graph, and
    ``tf.tables_initializer()`` only touches the tables created here.

    Args:
      hparams: object providing ``src_vocab_file``, ``tgt_vocab_file``,
        ``share_vocab``, ``sos`` and ``eos``.

    Returns:
      ``(tgt_sos_id, tgt_eos_id)`` as evaluated int32 values.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
        _, tgt_vocab_table = vocab_utils.create_vocab_tables(
            hparams.src_vocab_file, hparams.tgt_vocab_file, hparams.share_vocab)
        sos_id_tensor = tf.cast(
            tgt_vocab_table.lookup(tf.constant(hparams.sos)), tf.int32)
        eos_id_tensor = tf.cast(
            tgt_vocab_table.lookup(tf.constant(hparams.eos)), tf.int32)
        sess.run(tf.tables_initializer())
        tgt_sos_id, tgt_eos_id = sess.run([sos_id_tensor, eos_id_tensor])
    return tgt_sos_id, tgt_eos_id
Example #15
Source File: estimator.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def _convert_ids_to_strings(tgt_vocab_file, ids):
    """Convert prediction ids to words.

    A fresh private ``tf.Graph`` is used on every call. Previously each call
    added a new table and lookup ops to the process-wide default graph, so
    the graph grew with every conversion.

    Args:
      tgt_vocab_file: vocabulary file used to build the id -> string table.
      ids: array-like of integer ids; ids not in the vocabulary map to
        ``vocab_utils.UNK``.

    Returns:
      The looked-up strings, as returned by ``Session.run``.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
        reverse_target_vocab_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=vocab_utils.UNK)
        sess.run(tf.tables_initializer())
        translations = sess.run(
            reverse_target_vocab_table.lookup(
                tf.to_int64(tf.convert_to_tensor(np.asarray(ids)))))
    return translations
Example #16
Source File: process.py From DeepRNN with MIT License | 5 votes |
def load_model(model, ckpt_path, session, name):
    """Load model from a checkpoint.

    A missing checkpoint (``tf.errors.NotFoundError``) is reported and
    tolerated; the model is then returned with whatever values the session
    already holds. The "loaded" message is printed only when the restore
    actually succeeded.

    Returns:
      The same ``model`` object that was passed in.
    """
    try:
        model.saver.restore(session, ckpt_path)
    except tf.errors.NotFoundError as e:
        utils.print_out("Can't load checkpoint")
        utils.print_out("%s" % str(e))
    else:
        utils.print_out(" loaded %s model parameters from %s" % (name, ckpt_path))
    # NOTE(review): lookup tables are presumably not part of the checkpoint
    # restored by the Saver, which is why other loaders run
    # tf.tables_initializer() after restoring. Confirm whether this model needs it.
    return model
Example #17
Source File: iterator_utils_test.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def testGetInferIterator(self):
    """Infer iterator truncates to src_max_len and reports source lengths.

    Per the expected values: "c c a" stays three tokens, "c a" is padded with
    the eos id (3), and the lengths are [3, 2] for the first batch of two.
    """
    vocab_words = tf.constant(["a", "b", "c", "eos", "sos"])
    src_vocab_table = lookup_ops.index_table_from_tensor(vocab_words)
    src_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["c c a", "c a", "d", "f e a g"]))
    hparams = tf.contrib.training.HParams(random_seed=3, eos="eos", sos="sos")
    dataset = iterator_utils.get_infer_iterator(
        src_dataset=src_dataset,
        src_vocab_table=src_vocab_table,
        batch_size=2,
        eos=hparams.eos,
        src_max_len=3)
    table_initializer = tf.tables_initializer()
    iterator = dataset.make_initializable_iterator()
    next_batch = iterator.get_next()
    with self.test_session() as sess:
        sess.run(table_initializer)
        sess.run(iterator.initializer)
        features = sess.run(next_batch)
        expected_source = [[2, 2, 0],   # c c a
                           [2, 0, 3]]   # c a eos
        self.assertAllEqual(expected_source, features["source"])
        self.assertAllEqual([3, 2], features["source_sequence_length"])
Example #18
Source File: iterator_utils_test.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def testGetInferIterator(self):
    """Infer iterator (no max length) batches, pads and reports lengths.

    Per the expected values: "c a" is padded with the eos id (3), and the
    first batch of two has source lengths [3, 2].
    """
    vocab_words = tf.constant(["a", "b", "c", "eos", "sos"])
    src_vocab_table = lookup_ops.index_table_from_tensor(vocab_words)
    src_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["c c a", "c a", "d", "f e a g"]))
    hparams = tf.contrib.training.HParams(random_seed=3, eos="eos", sos="sos")
    dataset = iterator_utils.get_infer_iterator(
        src_dataset=src_dataset,
        src_vocab_table=src_vocab_table,
        batch_size=2,
        eos=hparams.eos)
    table_initializer = tf.tables_initializer()
    iterator = dataset.make_initializable_iterator()
    next_batch = iterator.get_next()
    with self.test_session() as sess:
        sess.run(table_initializer)
        sess.run(iterator.initializer)
        features = sess.run(next_batch)
        expected_source = [[2, 2, 0],   # c c a
                           [2, 0, 3]]   # c a eos
        self.assertAllEqual(expected_source, features["source"])
        self.assertAllEqual([3, 2], features["source_sequence_length"])
Example #19
Source File: vocab.py From DualRL with MIT License | 5 votes |
def test_vocab():
    """Check that the TF lookup table from load_vocab agrees with load_vocab_dict.

    Loads the global vocab file both as a lookup table and as a Python dict,
    then asserts that every dict key maps to the same id through the table.
    """
    import os

    import tensorflow as tf

    from common_options import load_common_arguments

    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    # Load global vocab
    args = load_common_arguments()
    global_vocab, global_vocab_size = load_vocab(args.global_vocab_file)
    vocab, vocab_size = load_vocab_dict(args.global_vocab_file)
    assert global_vocab_size == vocab_size

    # Materialize as lists: on Python 3, dict views are not indexable and
    # cannot be passed to tf.convert_to_tensor.
    keys = list(vocab.keys())
    values = list(vocab.values())
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())
        looked_up = sess.run(global_vocab.lookup(tf.convert_to_tensor(keys)))
    assert len(looked_up) == len(values)
    for expected, actual in zip(values, looked_up):
        assert expected == actual
Example #20
Source File: tf_example_decoder_test.py From Traffic-Rule-Violation-Detection-System with MIT License | 5 votes |
def testDecodeObjectLabelNoText(self):
    """Numeric class labels decode unchanged when the example has no class text."""
    image = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
    encoded_jpeg = self._EncodeImage(image)
    bbox_classes = [1, 2]
    feature_map = {
        'image/encoded': self._BytesFeature(encoded_jpeg),
        'image/format': self._BytesFeature('jpeg'),
        'image/object/class/label': self._Int64Feature(bbox_classes),
    }
    example = tf.train.Example(
        features=tf.train.Features(feature=feature_map)).SerializeToString()

    label_map_string = """ item { id:1 name:'cat' } item { id:2 name:'dog' } """
    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
    with tf.gfile.Open(label_map_path, 'wb') as f:
        f.write(label_map_string)

    decoder = tf_example_decoder.TfExampleDecoder(
        label_map_proto_file=label_map_path)
    tensor_dict = decoder.decode(tf.convert_to_tensor(example))
    classes_key = fields.InputDataFields.groundtruth_classes
    self.assertAllEqual(tensor_dict[classes_key].get_shape().as_list(), [None])

    init = tf.tables_initializer()
    with self.test_session() as sess:
        sess.run(init)
        tensor_dict = sess.run(tensor_dict)
    self.assertAllEqual(bbox_classes, tensor_dict[classes_key])
Example #21
Source File: model_test.py From yolo_v2 with Apache License 2.0 | 5 votes |
def test_predicted_text_has_correct_shape_w_charset(self):
    """With a charset, predicted_text has one entry per batch item of length seq_length."""
    charset = create_fake_charset(self.num_char_classes)
    ocr_model = self.create_model(charset=charset)
    with self.test_session() as sess:
        endpoints_tf = ocr_model.create_base(images=self.fake_images,
                                             labels_one_hot=None)
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())
        predicted = sess.run(endpoints_tf).predicted_text
        self.assertEqual(predicted.shape, (self.batch_size,))
        self.assertEqual(len(predicted[0]), self.seq_length)
Example #22
Source File: process.py From DeepRNN with MIT License | 5 votes |
def create_or_load_model(model, model_dir, session, name):
    """Create translation model and initialize or load parameters in session.

    Args:
      model: a namedtuple-like wrapper whose ``model`` field holds the model
        (``_replace`` is used to swap that field).
      model_dir: directory searched with ``tf.train.latest_checkpoint``.
      session: session the parameters are initialized or restored into.
      name: label used in log messages.

    Returns:
      ``(model, global_step, epoch_num)``, with the step and epoch read from
      ``model.model.global_step`` and ``model.model.epoch_num``.
    """
    latest_ckpt = tf.train.latest_checkpoint(model_dir)
    if latest_ckpt:
        # namedtuple._replace returns a NEW tuple. The previous code discarded
        # it, so whatever load_model returned was never stored.
        model = model._replace(
            model=load_model(model.model, latest_ckpt, session, name))
        utils.print_out("checkpoint found, load checkpoint\n %s" % latest_ckpt)
    else:
        utils.print_out(" checkpoint not found in %s" % (model_dir))
        utils.print_out(" created %s model with fresh parameters" % (name))
        session.run(tf.global_variables_initializer())
        # NOTE(review): lookup tables are not initialized here (the call was
        # commented out in the original). Confirm the model does not use them.
    global_step = model.model.global_step.eval(session=session)
    epoch_num = model.model.epoch_num.eval(session=session)
    return model, global_step, epoch_num
Example #23
Source File: train.py From TransE-Knowledge-Graph-Embedding with MIT License | 5 votes |
def train():
    """Train the module-level ``model`` for ``FLAGS.max_epoch`` epochs.

    Uses the module-level ``model``, ``iterator`` and ``FLAGS``. Summaries go
    to ``./summary`` and checkpoints to ``FLAGS.model_dir/model.ckpt``.
    """
    with tf.Session() as sess:
        init_ops = [tf.global_variables_initializer(),
                    tf.local_variables_initializer(),
                    tf.tables_initializer()]
        sess.run(init_ops)
        writer = tf.summary.FileWriter("summary", sess.graph)  # also logs the graph
        summary_step = 0  # counts batches across epochs for add_summary
        try:
            for epoch in range(FLAGS.max_epoch):
                sess.run(iterator.initializer)
                # NOTE(review): runs one uncounted training step and saves a
                # checkpoint before the epoch loop starts. Confirm this is intended.
                model.train(sess)
                if not os.path.exists(FLAGS.model_dir):
                    os.mkdir(FLAGS.model_dir)
                save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
                model.save(sess, save_path)
                print('-----Start training-----')
                epoch_loss = 0.0
                step = 0
                while True:
                    try:
                        batch_loss, _, summary = model.train(sess)
                    except tf.errors.OutOfRangeError:
                        # An epoch may end before any batch was counted; avoid
                        # the ZeroDivisionError the plain division raised.
                        avg_loss = epoch_loss / step if step else float('nan')
                        print('-----Finish training an epoch avg epoch loss={}-----'.format(avg_loss))
                        break
                    epoch_loss += batch_loss
                    step += 1
                    summary_step += 1
                    # Pass a step so TensorBoard plots points in order.
                    writer.add_summary(summary, summary_step)
                    # show train batch metrics
                    if step % FLAGS.stats_per_steps == 0:
                        time_str = datetime.datetime.now().isoformat()
                        print('{}\tepoch {:2d}\tstep {:3d}\ttrain loss={:.6f}'.format(
                            time_str, epoch + 1, step, batch_loss))
                if (epoch + 1) % FLAGS.save_per_epochs == 0:
                    if not os.path.exists(FLAGS.model_dir):
                        os.mkdir(FLAGS.model_dir)
                    save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
                    model.save(sess, save_path)
                    print("Epoch {}, saved checkpoint to {}".format(epoch + 1, save_path))
        finally:
            # Flush pending events and release the event file.
            writer.close()
Example #24
Source File: tfr2wav.py From vqvae-speech with MIT License | 5 votes |
def main(_):
    """Decode every utterance of the TFRecord dataset into a .wav file.

    Each file is written to ``args.output_dir/<name>.wav``. One
    ``"<name>, <num_samples>"`` line per file is appended to ``vctk.csv``.
    """
    tf.gfile.MkDir(args.output_dir)
    data = ByteWavWholeReader(
        speaker_list=txt2list(args.speaker_list),
        filenames=tf.gfile.Glob(args.file_pattern),
        num_epoch=1)
    XNOM = data.f[0]
    XWAV = tf.expand_dims(mu_law_decode(data.x[0, :]), -1)
    XBIN = tf.contrib.ffmpeg.encode_audio(XWAV, 'wav', 16000)
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=sess_config) as sess:
        sess.run(tf.tables_initializer())
        sess.run(data.iterator.initializer)
        fetch = {'xbin': XBIN, 'xwav': XWAV, 'wav_name': XNOM}  # loop-invariant
        # Context manager guarantees the CSV is closed even if an error other
        # than end-of-data escapes; `csv_file` also avoids shadowing stdlib csv.
        with open('vctk.csv', 'w') as csv_file:
            counter = 1
            while True:
                try:
                    result = sess.run(fetch)
                except tf.errors.OutOfRangeError:
                    print('\nEpoch complete')
                    break
                wav_name = result['wav_name'].decode('utf8')
                print('\rFile {:05d}: Processing {}'.format(counter, wav_name), end='')
                csv_file.write('{}, {:d}\n'.format(wav_name, len(result['xwav'])))
                filename = os.path.join(args.output_dir, wav_name) + '.wav'
                with open(filename, 'wb') as fp:
                    fp.write(result['xbin'])
                counter += 1
            print()
Example #25
Source File: model_helper.py From nslt with Apache License 2.0 | 5 votes |
def create_or_load_model(model, model_dir, session, name):
    """Create translation model and initialize or load parameters in session.

    When ``model_dir`` holds a checkpoint, the model is restored through
    ``load_model``. Otherwise global variables and lookup tables are freshly
    initialized and the elapsed time is reported.

    Returns:
      A ``(model, global_step)`` pair, where ``global_step`` is the value of
      ``model.global_step`` in ``session``.
    """
    latest_ckpt = tf.train.latest_checkpoint(model_dir)
    if not latest_ckpt:
        start_time = time.time()
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())
        elapsed = time.time() - start_time
        utils.print_out(" created %s model with fresh parameters, time %.2fs" %
                        (name, elapsed))
    else:
        model = load_model(model, latest_ckpt, session, name)

    return model, model.global_step.eval(session=session)
Example #26
Source File: model_helper.py From nslt with Apache License 2.0 | 5 votes |
def load_model(model, ckpt, session, name):
    """Restore ``model`` from checkpoint ``ckpt``, then initialize lookup tables.

    Tables are initialized after the restore, and the total time is logged.

    Returns:
      The same ``model`` object that was passed in.
    """
    t0 = time.time()
    model.saver.restore(session, ckpt)
    session.run(tf.tables_initializer())
    elapsed = time.time() - t0
    utils.print_out(" loaded %s model parameters from %s, time %.2fs" %
                    (name, ckpt, elapsed))
    return model
Example #27
Source File: nmt.py From DualRL with MIT License | 5 votes |
def create_model(sess, args, src_vocab_size, tgt_vocab_size, src_vocab_rev, tgt_vocab_rev,
                 mode=constants.TRAIN, reuse=None, load_pretrained_model=False,
                 direction="", model_save_dir=None):
    """Build an NMT model under a direction-specific variable scope.

    Source/target embeddings are created as ``src/embedding`` and
    ``dst/embedding`` inside ``constants.NMT_VAR_SCOPE + direction``.

    If ``load_pretrained_model`` is set, weights are restored from
    ``model_save_dir`` (default ``args.nmt_model_save_dir``). The path is
    adjusted to include ``direction``, either by swapping in the reversed
    direction or by appending it. Otherwise, fresh variables are initialized
    unless ``reuse`` is set.

    Raises:
      ValueError: if restoring from ``model_save_dir`` fails and no checkpoint
        can be found in it either.
    """
    sess.run(tf.tables_initializer())
    with tf.variable_scope(constants.NMT_VAR_SCOPE + direction, reuse=reuse):
        with tf.variable_scope("src"):
            src_emb = tf.get_variable("embedding", shape=[src_vocab_size, args.emb_dim])
        with tf.variable_scope("dst"):
            tgt_emb = tf.get_variable("embedding", shape=[tgt_vocab_size, args.emb_dim])
        model = NMT(mode, args.__dict__, src_vocab_size, tgt_vocab_size, src_emb, tgt_emb,
                    src_vocab_rev, tgt_vocab_rev, direction)

    if load_pretrained_model:
        if model_save_dir is None:
            model_save_dir = args.nmt_model_save_dir
        if direction not in model_save_dir:
            reversed_direction = direction[::-1]
            if reversed_direction in model_save_dir:
                # Plain substring swap. re.sub would interpret regex
                # metacharacters in the direction string.
                model_save_dir = model_save_dir.replace(reversed_direction, direction)
            else:
                model_save_dir = os.path.join(model_save_dir, direction)
        print(model_save_dir)
        try:
            print("Loading nmt model from", model_save_dir)
            model.saver.restore(sess, model_save_dir)
        except Exception as e:
            # Best-effort fallback: model_save_dir may be a directory rather
            # than a checkpoint prefix, so retry with its latest checkpoint.
            print("Error! Loading nmt model from", model_save_dir, e)
            latest_ckpt = tf.train.latest_checkpoint(model_save_dir)
            print("Again! Loading nmt model from", latest_ckpt)
            if latest_ckpt is None:
                raise ValueError("No checkpoint found in {}".format(model_save_dir))
            model.saver.restore(sess, latest_ckpt)
    elif reuse is None:
        print("Creating model with new parameters.")
        sess.run(tf.global_variables_initializer())
    else:
        print("Reuse parameters.")
    return model
Example #28
Source File: dataset_util_test.py From Person-Detection-and-Tracking with MIT License | 5 votes |
def test_make_initializable_iterator_with_hashTable(self):
    """Hash-table lookups inside Dataset.map work once tables are initialized.

    Keys 1, 0, -1 map to -1, 0, 1 respectively. Any other key maps to the
    default value 100.
    """
    keys = [1, 0, -1]
    values = list(reversed(keys))
    dataset = tf.data.Dataset.from_tensor_slices([[1, 2, -1, 5]])
    table = tf.contrib.lookup.HashTable(
        initializer=tf.contrib.lookup.KeyValueTensorInitializer(keys=keys,
                                                                values=values),
        default_value=100)
    mapped = dataset.map(table.lookup)
    data = dataset_util.make_initializable_iterator(mapped).get_next()
    init = tf.tables_initializer()
    with self.test_session() as sess:
        sess.run(init)
        self.assertAllEqual(sess.run(data), [-1, 100, 1, 100])
Example #29
Source File: tf_example_decoder_test.py From Person-Detection-and-Tracking with MIT License | 5 votes |
def testDecodeObjectLabelWithMapping(self):
    """Class-name text is mapped to ids via the label map ('cat'->3, 'dog'->1)."""
    image = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
    encoded_jpeg = self._EncodeImage(image)
    bbox_classes_text = ['cat', 'dog']
    feature_map = {
        'image/encoded': self._BytesFeature(encoded_jpeg),
        'image/format': self._BytesFeature('jpeg'),
        'image/object/class/text': self._BytesFeature(bbox_classes_text),
    }
    example = tf.train.Example(
        features=tf.train.Features(feature=feature_map)).SerializeToString()

    label_map_string = """ item { id:3 name:'cat' } item { id:1 name:'dog' } """
    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
    with tf.gfile.Open(label_map_path, 'wb') as f:
        f.write(label_map_string)

    decoder = tf_example_decoder.TfExampleDecoder(
        label_map_proto_file=label_map_path)
    tensor_dict = decoder.decode(tf.convert_to_tensor(example))
    classes_key = fields.InputDataFields.groundtruth_classes
    self.assertAllEqual(tensor_dict[classes_key].get_shape().as_list(), [None])

    with self.test_session() as sess:
        sess.run(tf.tables_initializer())
        tensor_dict = sess.run(tensor_dict)
    self.assertAllEqual([3, 1], tensor_dict[classes_key])
Example #30
Source File: tf_example_decoder_test.py From Person-Detection-and-Tracking with MIT License | 5 votes |
def testDecodeObjectLabelUnrecognizedName(self):
    """A class name absent from the label map ('cheetah') decodes to -1."""
    image = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
    encoded_jpeg = self._EncodeImage(image)
    bbox_classes_text = ['cat', 'cheetah']
    feature_map = {
        'image/encoded': self._BytesFeature(encoded_jpeg),
        'image/format': self._BytesFeature('jpeg'),
        'image/object/class/text': self._BytesFeature(bbox_classes_text),
    }
    example = tf.train.Example(
        features=tf.train.Features(feature=feature_map)).SerializeToString()

    label_map_string = """ item { id:2 name:'cat' } item { id:1 name:'dog' } """
    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
    with tf.gfile.Open(label_map_path, 'wb') as f:
        f.write(label_map_string)

    decoder = tf_example_decoder.TfExampleDecoder(
        label_map_proto_file=label_map_path)
    tensor_dict = decoder.decode(tf.convert_to_tensor(example))
    classes_key = fields.InputDataFields.groundtruth_classes
    self.assertAllEqual(tensor_dict[classes_key].get_shape().as_list(), [None])

    with self.test_session() as sess:
        sess.run(tf.tables_initializer())
        tensor_dict = sess.run(tensor_dict)
    self.assertAllEqual([2, -1], tensor_dict[classes_key])