Python object_detection.protos.input_reader_pb2.InputReader() Examples

The following are 30 code examples of object_detection.protos.input_reader_pb2.InputReader(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module object_detection.protos.input_reader_pb2, or try the search function.
Example #1
Source File: eval.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def get_configs_from_pipeline_file():
  """Reads evaluation configuration from a pipeline_pb2.TrainEvalPipelineConfig.

  Reads evaluation config from file specified by pipeline_config_path flag.
  When the eval_training_data flag is set, the training input reader is
  returned so that evaluation runs over the training data; the evaluation
  settings themselves always come from the pipeline's eval_config.

  Returns:
    model_config: a model_pb2.DetectionModel
    eval_config: a eval_pb2.EvalConfig
    input_config: a input_reader_pb2.InputReader
  """
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
    text_format.Merge(f.read(), pipeline_config)

  model_config = pipeline_config.model
  # eval_config must always be the EvalConfig: handing back train_config here
  # would give the caller a TrainConfig where an EvalConfig is documented.
  eval_config = pipeline_config.eval_config
  if FLAGS.eval_training_data:
    input_config = pipeline_config.train_input_reader
  else:
    input_config = pipeline_config.eval_input_reader

  return model_config, eval_config, input_config
Example #2
Source File: dataset_builder_test.py    From Person-Detection-and-Tracking with MIT License 6 votes vote down vote up
def test_build_tf_record_input_reader_with_additional_channels(self):
    """Checks the batched image shape when additional channels are requested.

    The record comes from self.create_tf_record(has_additional_channels=True)
    and the dataset is built with batch_size=2 and num_additional_channels=2.
    """
    tf_record_path = self.create_tf_record(has_additional_channels=True)

    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    tensor_dict = dataset_util.make_initializable_iterator(
        dataset_builder.build(
            input_reader_proto, batch_size=2,
            num_additional_channels=2)).get_next()

    sv = tf.train.Supervisor(logdir=self.get_temp_dir())
    with sv.prepare_or_wait_for_session() as sess:
      sv.start_queue_runners(sess)
      output_dict = sess.run(tensor_dict)

    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual((2, 4, 5, 5),
                     output_dict[fields.InputDataFields.image].shape)
Example #3
Source File: eval.py    From garbage-object-detection-tensorflow with MIT License 6 votes vote down vote up
def get_configs_from_pipeline_file():
  """Reads evaluation configuration from a pipeline_pb2.TrainEvalPipelineConfig.

  Reads evaluation config from file specified by pipeline_config_path flag.
  If the eval_training_data flag is set, evaluation is pointed at the
  training input reader; the eval_config is returned unchanged either way.

  Returns:
    model_config: a model_pb2.DetectionModel
    eval_config: a eval_pb2.EvalConfig
    input_config: a input_reader_pb2.InputReader
  """
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
    text_format.Merge(f.read(), pipeline_config)

  model_config = pipeline_config.model
  eval_config = pipeline_config.eval_config
  # Evaluating on training data means swapping the *input reader*, not the
  # eval settings: train_config is a TrainConfig, not an EvalConfig.
  if FLAGS.eval_training_data:
    input_config = pipeline_config.train_input_reader
  else:
    input_config = pipeline_config.eval_input_reader

  return model_config, eval_config, input_config
Example #4
Source File: dataset_builder_test.py    From vehicle_counting_tensorflow with MIT License 6 votes vote down vote up
def test_sample_one_of_n_shards(self):
    """Checks that sample_1_of_n_examples: 2 yields every other example.

    A record created with num_examples=4 is read with batch_size=1; the first
    two fetched source ids are expected to be '0' and '2'.
    """
    tf_record_path = self.create_tf_record(num_examples=4)

    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      sample_1_of_n_examples: 2
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    tensor_dict = dataset_builder.make_initializable_iterator(
        dataset_builder.build(input_reader_proto, batch_size=1)).get_next()

    with tf.train.MonitoredSession() as sess:
      output_dict = sess.run(tensor_dict)
      self.assertAllEqual(['0'], output_dict[fields.InputDataFields.source_id])
      output_dict = sess.run(tensor_dict)
      # Same array comparison as above; the deprecated assertEquals alias
      # only worked here through single-element array truthiness.
      self.assertAllEqual(['2'], output_dict[fields.InputDataFields.source_id])
Example #5
Source File: dataset_builder_test.py    From vehicle_counting_tensorflow with MIT License 6 votes vote down vote up
def test_sample_all_data(self):
    """Checks that sample_1_of_n_examples: 1 keeps every example.

    A record created with num_examples=2 is read with batch_size=1; the two
    fetched source ids are expected to be '0' and '1'.
    """
    tf_record_path = self.create_tf_record(num_examples=2)

    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      sample_1_of_n_examples: 1
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    tensor_dict = dataset_builder.make_initializable_iterator(
        dataset_builder.build(input_reader_proto, batch_size=1)).get_next()

    with tf.train.MonitoredSession() as sess:
      output_dict = sess.run(tensor_dict)
      self.assertAllEqual(['0'], output_dict[fields.InputDataFields.source_id])
      output_dict = sess.run(tensor_dict)
      # Array comparison, consistent with the first assertion; replaces the
      # deprecated assertEquals alias.
      self.assertAllEqual(['1'], output_dict[fields.InputDataFields.source_id])
Example #6
Source File: dataset_builder_test.py    From Person-Detection-and-Tracking with MIT License 6 votes vote down vote up
def test_build_tf_record_input_reader_and_load_instance_masks(self):
    """Checks the instance-mask shape when load_instance_masks is enabled.

    Builds a dataset with batch_size=1 from the record returned by
    self.create_tf_record() and asserts the shape of the fetched masks.
    """
    record_path = self.create_tf_record()

    text_proto = """
      shuffle: false
      num_readers: 1
      load_instance_masks: true
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(record_path)
    reader_config = input_reader_pb2.InputReader()
    text_format.Merge(text_proto, reader_config)
    dataset = dataset_builder.build(reader_config, batch_size=1)
    next_element = dataset_util.make_initializable_iterator(dataset).get_next()

    supervisor = tf.train.Supervisor(logdir=self.get_temp_dir())
    with supervisor.prepare_or_wait_for_session() as sess:
      supervisor.start_queue_runners(sess)
      fetched = sess.run(next_element)
    masks = fetched[fields.InputDataFields.groundtruth_instance_masks]
    self.assertAllEqual((1, 1, 4, 5), masks.shape)
Example #7
Source File: train.py    From garbage-object-detection-tensorflow with MIT License 6 votes vote down vote up
def get_configs_from_pipeline_file():
  """Loads the training pieces of a pipeline_pb2.TrainEvalPipelineConfig.

  The pipeline is parsed from the file named by the pipeline_config_path flag.

  Returns:
    model_config: model_pb2.DetectionModel
    train_config: train_pb2.TrainConfig
    input_config: input_reader_pb2.InputReader
  """
  pipeline = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as config_file:
    config_text = config_file.read()
  text_format.Merge(config_text, pipeline)

  return pipeline.model, pipeline.train_config, pipeline.train_input_reader
Example #8
Source File: eval.py    From HereIsWally with MIT License 6 votes vote down vote up
def get_configs_from_pipeline_file():
  """Reads evaluation configuration from a pipeline_pb2.TrainEvalPipelineConfig.

  Reads evaluation config from file specified by pipeline_config_path flag.
  The eval_training_data flag selects which input reader is returned (the
  training one when set, the evaluation one otherwise); eval_config is always
  the pipeline's eval_config.

  Returns:
    model_config: a model_pb2.DetectionModel
    eval_config: a eval_pb2.EvalConfig
    input_config: a input_reader_pb2.InputReader
  """
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
    text_format.Merge(f.read(), pipeline_config)

  model_config = pipeline_config.model
  # Never substitute train_config for eval_config: they are different
  # message types, and the documented return type is an EvalConfig.
  eval_config = pipeline_config.eval_config
  if FLAGS.eval_training_data:
    input_config = pipeline_config.train_input_reader
  else:
    input_config = pipeline_config.eval_input_reader

  return model_config, eval_config, input_config
Example #9
Source File: config_util.py    From Person-Detection-and-Tracking with MIT License 6 votes vote down vote up
def _update_input_path(input_config, input_path):
  """Updates input configuration to reflect a new input path.

  The input_config object is updated in place, and hence not returned. Any
  paths already present on the TFRecord reader are cleared first.

  Args:
    input_config: A input_reader_pb2.InputReader.
    input_path: A path to data, or a list or tuple of paths.

  Raises:
    TypeError: if input reader type is not `tf_record_input_reader`.
  """
  if input_config.WhichOneof("input_reader") != "tf_record_input_reader":
    raise TypeError("Input reader type must be `tf_record_input_reader`.")

  tf_record_reader = input_config.tf_record_input_reader
  tf_record_reader.ClearField("input_path")
  # Tuples are treated like lists so that a sequence of paths is never
  # appended as one (invalid) element.
  if isinstance(input_path, (list, tuple)):
    tf_record_reader.input_path.extend(input_path)
  else:
    tf_record_reader.input_path.append(input_path)
Example #10
Source File: config_util.py    From ros_people_object_detection_tensorflow with Apache License 2.0 6 votes vote down vote up
def _update_input_path(input_config, input_path):
  """Updates input configuration to reflect a new input path.

  The input_config object is updated in place, and hence not returned. The
  existing `input_path` entries are cleared before the new ones are added.

  Args:
    input_config: A input_reader_pb2.InputReader.
    input_path: A path to data, or a list or tuple of paths.

  Raises:
    TypeError: if input reader type is not `tf_record_input_reader`.
  """
  input_reader_type = input_config.WhichOneof("input_reader")
  if input_reader_type != "tf_record_input_reader":
    raise TypeError("Input reader type must be `tf_record_input_reader`.")

  reader = input_config.tf_record_input_reader
  reader.ClearField("input_path")
  # Accept tuples as well as lists; otherwise a tuple of paths would be
  # appended as a single element.
  if isinstance(input_path, (list, tuple)):
    reader.input_path.extend(input_path)
  else:
    reader.input_path.append(input_path)
Example #11
Source File: train.py    From HereIsWally with MIT License 6 votes vote down vote up
def get_configs_from_pipeline_file():
  """Parses the pipeline file and returns its training-related configs.

  The file location is taken from the pipeline_config_path flag and its
  contents are merged into a pipeline_pb2.TrainEvalPipelineConfig.

  Returns:
    model_config: model_pb2.DetectionModel
    train_config: train_pb2.TrainConfig
    input_config: input_reader_pb2.InputReader
  """
  parsed_pipeline = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as pipeline_file:
    text_format.Merge(pipeline_file.read(), parsed_pipeline)

  return (parsed_pipeline.model,
          parsed_pipeline.train_config,
          parsed_pipeline.train_input_reader)
Example #12
Source File: config_util.py    From vehicle_counting_tensorflow with MIT License 6 votes vote down vote up
def _update_tf_record_input_path(input_config, input_path):
  """Updates input configuration to reflect a new input path.

  The input_config object is updated in place, and hence not returned. Any
  paths already present on the TFRecord reader are cleared first.

  Args:
    input_config: A input_reader_pb2.InputReader.
    input_path: A path to data, or a list or tuple of paths.

  Raises:
    TypeError: if input reader type is not `tf_record_input_reader`.
  """
  if input_config.WhichOneof("input_reader") != "tf_record_input_reader":
    raise TypeError("Input reader type must be `tf_record_input_reader`.")

  tf_record_reader = input_config.tf_record_input_reader
  tf_record_reader.ClearField("input_path")
  # Tuples are handled like lists so a sequence of paths is never appended
  # as a single element.
  if isinstance(input_path, (list, tuple)):
    tf_record_reader.input_path.extend(input_path)
  else:
    tf_record_reader.input_path.append(input_path)
Example #13
Source File: train.py    From object_detector_app with MIT License 6 votes vote down vote up
def get_configs_from_pipeline_file():
  """Loads the training pieces of a pipeline_pb2.TrainEvalPipelineConfig.

  The pipeline is parsed from the file named by the pipeline_config_path flag.

  Returns:
    model_config: model_pb2.DetectionModel
    train_config: train_pb2.TrainConfig
    input_config: input_reader_pb2.InputReader
  """
  pipeline = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as config_file:
    config_text = config_file.read()
  text_format.Merge(config_text, pipeline)

  return pipeline.model, pipeline.train_config, pipeline.train_input_reader
Example #14
Source File: config_util.py    From Traffic-Rule-Violation-Detection-System with MIT License 6 votes vote down vote up
def _update_input_path(input_config, input_path):
  """Updates input configuration to reflect a new input path.

  The input_config object is updated in place, and hence not returned. The
  existing `input_path` entries are cleared before the new ones are added.

  Args:
    input_config: A input_reader_pb2.InputReader.
    input_path: A path to data, or a list or tuple of paths.

  Raises:
    TypeError: if input reader type is not `tf_record_input_reader`.
  """
  input_reader_type = input_config.WhichOneof("input_reader")
  if input_reader_type != "tf_record_input_reader":
    raise TypeError("Input reader type must be `tf_record_input_reader`.")

  reader = input_config.tf_record_input_reader
  reader.ClearField("input_path")
  # Accept tuples as well as lists; otherwise a tuple of paths would be
  # appended as a single element.
  if isinstance(input_path, (list, tuple)):
    reader.input_path.extend(input_path)
  else:
    reader.input_path.append(input_path)
Example #15
Source File: dataset_builder_test.py    From ros_people_object_detection_tensorflow with Apache License 2.0 6 votes vote down vote up
def test_build_tf_record_input_reader_and_load_instance_masks(self):
    """Verifies the fetched mask shape with load_instance_masks set to true.

    The dataset is built with batch_size=1 from the record returned by
    self.create_tf_record().
    """
    path_to_record = self.create_tf_record()

    reader_text = """
      shuffle: false
      num_readers: 1
      load_instance_masks: true
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(path_to_record)
    reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(reader_text, reader_proto)
    built_dataset = dataset_builder.build(reader_proto, batch_size=1)
    iterator = dataset_util.make_initializable_iterator(built_dataset)
    tensors = iterator.get_next()

    sv = tf.train.Supervisor(logdir=self.get_temp_dir())
    with sv.prepare_or_wait_for_session() as sess:
      sv.start_queue_runners(sess)
      results = sess.run(tensors)
    mask_shape = results[fields.InputDataFields.groundtruth_instance_masks].shape
    self.assertAllEqual((1, 1, 4, 5), mask_shape)
Example #16
Source File: config_util.py    From yolo_v2 with Apache License 2.0 6 votes vote down vote up
def _update_input_path(input_config, input_path):
  """Updates input configuration to reflect a new input path.

  The input_config object is updated in place, and hence not returned. Any
  paths already present on the TFRecord reader are cleared first.

  Args:
    input_config: A input_reader_pb2.InputReader.
    input_path: A path to data, or a list or tuple of paths.

  Raises:
    TypeError: if input reader type is not `tf_record_input_reader`.
  """
  if input_config.WhichOneof("input_reader") != "tf_record_input_reader":
    raise TypeError("Input reader type must be `tf_record_input_reader`.")

  tf_record_reader = input_config.tf_record_input_reader
  tf_record_reader.ClearField("input_path")
  # Tuples are treated like lists so that a sequence of paths is never
  # appended as one (invalid) element.
  if isinstance(input_path, (list, tuple)):
    tf_record_reader.input_path.extend(input_path)
  else:
    tf_record_reader.input_path.append(input_path)
Example #17
Source File: eval.py    From object_detector_app with MIT License 6 votes vote down vote up
def get_configs_from_pipeline_file():
  """Reads evaluation configuration from a pipeline_pb2.TrainEvalPipelineConfig.

  Reads evaluation config from file specified by pipeline_config_path flag.
  When the eval_training_data flag is set, the training input reader is
  returned so that evaluation runs over the training data; the evaluation
  settings themselves always come from the pipeline's eval_config.

  Returns:
    model_config: a model_pb2.DetectionModel
    eval_config: a eval_pb2.EvalConfig
    input_config: a input_reader_pb2.InputReader
  """
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
    text_format.Merge(f.read(), pipeline_config)

  model_config = pipeline_config.model
  # eval_config must always be the EvalConfig: handing back train_config here
  # would give the caller a TrainConfig where an EvalConfig is documented.
  eval_config = pipeline_config.eval_config
  if FLAGS.eval_training_data:
    input_config = pipeline_config.train_input_reader
  else:
    input_config = pipeline_config.eval_input_reader

  return model_config, eval_config, input_config
Example #18
Source File: train.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def get_configs_from_pipeline_file():
  """Parses the pipeline file and returns its training-related configs.

  The file location is taken from the pipeline_config_path flag and its
  contents are merged into a pipeline_pb2.TrainEvalPipelineConfig.

  Returns:
    model_config: model_pb2.DetectionModel
    train_config: train_pb2.TrainConfig
    input_config: input_reader_pb2.InputReader
  """
  parsed_pipeline = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as pipeline_file:
    text_format.Merge(pipeline_file.read(), parsed_pipeline)

  return (parsed_pipeline.model,
          parsed_pipeline.train_config,
          parsed_pipeline.train_input_reader)
Example #19
Source File: dataset_builder_test.py    From Person-Detection-and-Tracking with MIT License 5 votes vote down vote up
def test_build_tf_record_input_reader(self):
    """Checks the batched tensors produced by a plain TFRecord input reader.

    The dataset is built with batch_size=1 from the record returned by
    self.create_tf_record(). Instance masks must be absent, and the image,
    class and box outputs must have the expected shapes and values.
    """
    tf_record_path = self.create_tf_record()

    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    tensor_dict = dataset_util.make_initializable_iterator(
        dataset_builder.build(input_reader_proto, batch_size=1)).get_next()

    sv = tf.train.Supervisor(logdir=self.get_temp_dir())
    with sv.prepare_or_wait_for_session() as sess:
      sv.start_queue_runners(sess)
      output_dict = sess.run(tensor_dict)

    # assertNotIn / assertEqual replace assertTrue(... not in ...) and the
    # deprecated assertEquals alias.
    self.assertNotIn(
        fields.InputDataFields.groundtruth_instance_masks, output_dict)
    self.assertEqual((1, 4, 5, 3),
                     output_dict[fields.InputDataFields.image].shape)
    self.assertAllEqual([[2]],
                        output_dict[fields.InputDataFields.groundtruth_classes])
    self.assertEqual(
        (1, 1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
    self.assertAllEqual(
        [0.0, 0.0, 1.0, 1.0],
        output_dict[fields.InputDataFields.groundtruth_boxes][0][0])
Example #20
Source File: input_reader_builder_test.py    From Traffic-Rule-Violation-Detection-System with MIT License 5 votes vote down vote up
def test_build_tf_record_input_reader(self):
    """Checks the tensors produced by input_reader_builder for a TFRecord.

    The reader is built from the record returned by self.create_tf_record().
    Instance masks must be absent, and the image, class and box outputs must
    have the expected shapes and values.
    """
    tf_record_path = self.create_tf_record()

    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    tensor_dict = input_reader_builder.build(input_reader_proto)

    sv = tf.train.Supervisor(logdir=self.get_temp_dir())
    with sv.prepare_or_wait_for_session() as sess:
      sv.start_queue_runners(sess)
      output_dict = sess.run(tensor_dict)

    self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
                     output_dict)
    # Shapes are plain tuples, so assertEqual (not the deprecated
    # assertEquals alias) is sufficient.
    self.assertEqual(
        (4, 5, 3), output_dict[fields.InputDataFields.image].shape)
    # The classes output is an array; compare it element-wise.
    self.assertAllEqual(
        [2], output_dict[fields.InputDataFields.groundtruth_classes])
    self.assertEqual(
        (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
    self.assertAllEqual(
        [0.0, 0.0, 1.0, 1.0],
        output_dict[fields.InputDataFields.groundtruth_boxes][0])
Example #21
Source File: dataset_util_test.py    From Traffic-Rule-Violation-Detection-System with MIT License 5 votes vote down vote up
def test_read_dataset_single_epoch(self):
    """Checks that a one-epoch read yields one partial batch and then ends."""
    reader_config = input_reader_pb2.InputReader()
    reader_config.shuffle = False
    reader_config.num_readers = 1
    reader_config.num_epochs = 1

    input_paths = [self._path_template % '0']
    next_batch = self._get_dataset_next(
        input_paths, reader_config, batch_size=30)
    with self.test_session() as sess:
      # The only batch holds whatever is available; a further fetch must fail.
      first_batch = sess.run(next_batch)
      self.assertAllEqual(first_batch, [[1, 10]])
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(next_batch)
Example #22
Source File: dataset_util_test.py    From Traffic-Rule-Violation-Detection-System with MIT License 5 votes vote down vote up
def test_read_dataset(self):
    """Checks the first 20-element batch read with one reader, unshuffled."""
    reader_config = input_reader_pb2.InputReader()
    reader_config.shuffle = False
    reader_config.num_readers = 1

    input_paths = [self._path_template % '*']
    next_batch = self._get_dataset_next(
        input_paths, reader_config, batch_size=20)
    expected_batch = [[1, 10, 2, 20, 3, 30, 4, 40, 5, 50,
                       1, 10, 2, 20, 3, 30, 4, 40, 5, 50]]
    with self.test_session() as sess:
      fetched_batch = sess.run(next_batch)
      self.assertAllEqual(fetched_batch, expected_batch)
Example #23
Source File: input_reader_builder_test.py    From garbage-object-detection-tensorflow with MIT License 5 votes vote down vote up
def test_build_tf_record_input_reader(self):
    """Checks the tensors produced by input_reader_builder for a TFRecord.

    The reader is built from the record returned by self.create_tf_record();
    the image, class and box outputs must have the expected shapes and values.
    """
    tf_record_path = self.create_tf_record()

    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    tensor_dict = input_reader_builder.build(input_reader_proto)

    sv = tf.train.Supervisor(logdir=self.get_temp_dir())
    with sv.prepare_or_wait_for_session() as sess:
      sv.start_queue_runners(sess)
      output_dict = sess.run(tensor_dict)

    # Shapes are plain tuples, so assertEqual (not the deprecated
    # assertEquals alias) is sufficient.
    self.assertEqual(
        (4, 5, 3), output_dict[fields.InputDataFields.image].shape)
    # The classes output is an array; compare it element-wise.
    self.assertAllEqual(
        [2], output_dict[fields.InputDataFields.groundtruth_classes])
    self.assertEqual(
        (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
    self.assertAllEqual(
        [0.0, 0.0, 1.0, 1.0],
        output_dict[fields.InputDataFields.groundtruth_boxes][0])
Example #24
Source File: dataset_util_test.py    From Person-Detection-and-Tracking with MIT License 5 votes vote down vote up
def test_read_dataset(self):
    """Reads one unshuffled batch of 20 with a single reader and checks it."""
    cfg = input_reader_pb2.InputReader()
    cfg.shuffle = False
    cfg.num_readers = 1

    glob_paths = [self._path_template % '*']
    batch_tensor = self._get_dataset_next(glob_paths, cfg, batch_size=20)
    expected = [[1, 10, 2, 20, 3, 30, 4, 40, 5, 50,
                 1, 10, 2, 20, 3, 30, 4, 40, 5, 50]]
    with self.test_session() as sess:
      actual = sess.run(batch_tensor)
      self.assertAllEqual(actual, expected)
Example #25
Source File: input_reader_builder_test.py    From Person-Detection-and-Tracking with MIT License 5 votes vote down vote up
def test_build_tf_record_input_reader_and_load_instance_masks(self):
    """Checks input_reader_builder output with load_instance_masks enabled.

    The reader is built from the record returned by self.create_tf_record();
    the image, class, box and instance-mask outputs must have the expected
    shapes and values.
    """
    tf_record_path = self.create_tf_record()

    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      load_instance_masks: true
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    tensor_dict = input_reader_builder.build(input_reader_proto)

    sv = tf.train.Supervisor(logdir=self.get_temp_dir())
    with sv.prepare_or_wait_for_session() as sess:
      sv.start_queue_runners(sess)
      output_dict = sess.run(tensor_dict)

    # Shapes are plain tuples, so assertEqual (not the deprecated
    # assertEquals alias) is sufficient.
    self.assertEqual(
        (4, 5, 3), output_dict[fields.InputDataFields.image].shape)
    # The classes output is an array; compare it element-wise.
    self.assertAllEqual(
        [2], output_dict[fields.InputDataFields.groundtruth_classes])
    self.assertEqual(
        (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
    self.assertAllEqual(
        [0.0, 0.0, 1.0, 1.0],
        output_dict[fields.InputDataFields.groundtruth_boxes][0])
    self.assertAllEqual(
        (1, 4, 5),
        output_dict[fields.InputDataFields.groundtruth_instance_masks].shape)
Example #26
Source File: dataset_util_test.py    From Person-Detection-and-Tracking with MIT License 5 votes vote down vote up
def test_reduce_num_reader(self):
    """Checks the first batch when num_readers is set to 10, unshuffled."""
    reader_config = input_reader_pb2.InputReader()
    reader_config.shuffle = False
    reader_config.num_readers = 10

    input_paths = [self._path_template % '*']
    next_batch = self._get_dataset_next(
        input_paths, reader_config, batch_size=20)
    expected_batch = [[1, 10, 2, 20, 3, 30, 4, 40, 5, 50,
                       1, 10, 2, 20, 3, 30, 4, 40, 5, 50]]
    with self.test_session() as sess:
      fetched_batch = sess.run(next_batch)
      self.assertAllEqual(fetched_batch, expected_batch)
Example #27
Source File: input_reader_builder_test.py    From Person-Detection-and-Tracking with MIT License 5 votes vote down vote up
def test_build_tf_record_input_reader(self):
    """Checks the tensors produced by input_reader_builder for a TFRecord.

    The reader is built from the record returned by self.create_tf_record().
    Instance masks must be absent, and the image, class and box outputs must
    have the expected shapes and values.
    """
    tf_record_path = self.create_tf_record()

    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    tensor_dict = input_reader_builder.build(input_reader_proto)

    sv = tf.train.Supervisor(logdir=self.get_temp_dir())
    with sv.prepare_or_wait_for_session() as sess:
      sv.start_queue_runners(sess)
      output_dict = sess.run(tensor_dict)

    self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
                     output_dict)
    # Shapes are plain tuples, so assertEqual (not the deprecated
    # assertEquals alias) is sufficient.
    self.assertEqual(
        (4, 5, 3), output_dict[fields.InputDataFields.image].shape)
    # The classes output is an array; compare it element-wise.
    self.assertAllEqual(
        [2], output_dict[fields.InputDataFields.groundtruth_classes])
    self.assertEqual(
        (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
    self.assertAllEqual(
        [0.0, 0.0, 1.0, 1.0],
        output_dict[fields.InputDataFields.groundtruth_boxes][0])
Example #28
Source File: dataset_util_test.py    From Person-Detection-and-Tracking with MIT License 5 votes vote down vote up
def test_enable_shuffle(self):
    """Checks that a shuffled read differs from the unshuffled ordering."""
    reader_config = input_reader_pb2.InputReader()
    reader_config.shuffle = True
    reader_config.num_readers = 1

    input_paths = [self._shuffle_path_template % '*']
    next_batch = self._get_dataset_next(
        input_paths, reader_config, batch_size=10)
    unshuffled_order = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

    with self.test_session() as sess:
      fetched_batch = sess.run(next_batch)
      position_differs = np.not_equal(fetched_batch, unshuffled_order)
      self.assertTrue(np.any(position_differs))
Example #29
Source File: dataset_util_test.py    From Person-Detection-and-Tracking with MIT License 5 votes vote down vote up
def test_disable_shuffle_(self):
    """Checks that an unshuffled read returns the expected fixed ordering."""
    reader_config = input_reader_pb2.InputReader()
    reader_config.shuffle = False
    reader_config.num_readers = 1

    input_paths = [self._shuffle_path_template % '*']
    next_batch = self._get_dataset_next(
        input_paths, reader_config, batch_size=10)
    unshuffled_order = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

    with self.test_session() as sess:
      fetched_batch = sess.run(next_batch)
      self.assertAllEqual(fetched_batch, [unshuffled_order])
Example #30
Source File: dataset_util_test.py    From Person-Detection-and-Tracking with MIT License 5 votes vote down vote up
def test_read_dataset_single_epoch(self):
    """Reads a single epoch: one partial batch, then an out-of-range error."""
    cfg = input_reader_pb2.InputReader()
    cfg.shuffle = False
    cfg.num_readers = 1
    cfg.num_epochs = 1

    single_file = [self._path_template % '0']
    batch_tensor = self._get_dataset_next(single_file, cfg, batch_size=30)
    with self.test_session() as sess:
      # The first fetch returns everything available; the next one must fail.
      only_batch = sess.run(batch_tensor)
      self.assertAllEqual(only_batch, [[1, 10]])
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(batch_tensor)