Python tensorflow.FIFOQueue() Examples
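The following are 30 code examples of `tensorflow.FIFOQueue()`. (`tf.FIFOQueue` is the TensorFlow 1.x spelling; in TensorFlow 2.x the class is available as `tf.queue.FIFOQueue` or `tf.compat.v1.FIFOQueue`.)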
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `tensorflow`, or try the search function.
Example #1
Source File: queue.py From AugmentedAutoencoder with MIT License | 6 votes |
def __init__(self, dataset, num_threads, queue_size, batch_size):
    """Build a two-component FIFO input queue for ``dataset``.

    Args:
      dataset: object exposing ``shape``; that shape is used as the
        per-element shape of both queue components.
      num_threads: stored on the instance (not used in this constructor).
      queue_size: capacity of the FIFO queue.
      batch_size: upper bound on elements returned by one dequeue.
    """
    self._dataset = dataset
    self._num_threads = num_threads
    self._queue_size = queue_size
    self._batch_size = batch_size

    datatypes = 2 * ['float32']
    shapes = 2 * [self._dataset.shape]
    # Leading None lets each enqueue_many call carry any number of elements.
    batch_shape = [None] + list(self._dataset.shape)

    # Exactly one placeholder per queue component. Multiplying this list by 2
    # would yield [p0, p1, p0, p1] -- four values for a two-component queue.
    self._placeholders = [
        tf.placeholder(dtype=tf.float32, shape=batch_shape),
        tf.placeholder(dtype=tf.float32, shape=batch_shape),
    ]
    self._queue = tf.FIFOQueue(self._queue_size, datatypes, shapes=shapes)
    # dequeue_up_to may return fewer than batch_size elements once the queue
    # has been closed.
    self.x, self.y = self._queue.dequeue_up_to(self._batch_size)
    self.enqueue_op = self._queue.enqueue_many(self._placeholders)
    self._coordinator = tf.train.Coordinator()
    self._threads = []
Example #2
Source File: reader_ops_test.py From deep_image_model with Apache License 2.0 | 6 votes |
def _testOneEpoch(self, files):
    """Read every line of ``files`` through a TextLineReader, then expect EOF.

    Keys must be ``"<file>:<1-based line number>"`` and values must equal
    ``self._LineText(file_index, line_index)``. After the closed filename
    queue is drained, one more read must fail with the "is closed" op error.
    """
    with self.test_session() as sess:
        line_reader = tf.TextLineReader(name="test_reader")
        filename_queue = tf.FIFOQueue(99, [tf.string], shapes=())
        read_key, read_value = line_reader.read(filename_queue)
        filename_queue.enqueue_many([files]).run()
        filename_queue.close().run()

        for file_idx in range(self._num_files):
            for line_idx in range(self._num_lines):
                got_key, got_value = sess.run([read_key, read_value])
                want_key = "%s:%d" % (files[file_idx], line_idx + 1)
                self.assertAllEqual(want_key, tf.compat.as_text(got_key))
                self.assertAllEqual(self._LineText(file_idx, line_idx), got_value)

        # Queue is closed and empty: the next read has nothing to consume.
        expected_error = ("is closed and has insufficient elements "
                          "\\(requested 1, current size 0\\)")
        with self.assertRaisesOpError(expected_error):
            sess.run([read_key, read_value])
Example #3
Source File: reader_ops_test.py From deep_image_model with Apache License 2.0 | 6 votes |
def testSkipHeaderLines(self):
    """With skip_header_lines=1 each file's first line is never returned.

    Reads therefore start at line 2 of every file (key ``"<file>:2"``) and
    yield ``self._num_lines - 1`` records per file before EOF.
    """
    files = self._CreateFiles()
    with self.test_session() as sess:
        line_reader = tf.TextLineReader(skip_header_lines=1, name="test_reader")
        filename_queue = tf.FIFOQueue(99, [tf.string], shapes=())
        read_key, read_value = line_reader.read(filename_queue)
        filename_queue.enqueue_many([files]).run()
        filename_queue.close().run()

        lines_after_header = self._num_lines - 1
        for file_idx in range(self._num_files):
            for line_idx in range(lines_after_header):
                got_key, got_value = sess.run([read_key, read_value])
                want_key = "%s:%d" % (files[file_idx], line_idx + 2)
                self.assertAllEqual(want_key, tf.compat.as_text(got_key))
                self.assertAllEqual(self._LineText(file_idx, line_idx + 1), got_value)

        expected_error = ("is closed and has insufficient elements "
                          "\\(requested 1, current size 0\\)")
        with self.assertRaisesOpError(expected_error):
            sess.run([read_key, read_value])
Example #4
Source File: reader_ops_test.py From deep_image_model with Apache License 2.0 | 6 votes |
def testOneEpoch(self):
    """Read each fixed-length record of every file exactly once, then expect EOF.

    Keys must be ``"<file>:<0-based record index>"`` and values must equal
    ``self._Record(file_index, record_index)``.
    """
    files = self._CreateFiles()
    with self.test_session() as sess:
        record_reader = tf.FixedLengthRecordReader(
            header_bytes=self._header_bytes,
            record_bytes=self._record_bytes,
            footer_bytes=self._footer_bytes,
            name="test_reader")
        filename_queue = tf.FIFOQueue(99, [tf.string], shapes=())
        read_key, read_value = record_reader.read(filename_queue)
        filename_queue.enqueue_many([files]).run()
        filename_queue.close().run()

        for file_idx in range(self._num_files):
            for rec_idx in range(self._num_records):
                got_key, got_value = sess.run([read_key, read_value])
                want_key = "%s:%d" % (files[file_idx], rec_idx)
                self.assertAllEqual(want_key, tf.compat.as_text(got_key))
                self.assertAllEqual(self._Record(file_idx, rec_idx), got_value)

        expected_error = ("is closed and has insufficient elements "
                          "\\(requested 1, current size 0\\)")
        with self.assertRaisesOpError(expected_error):
            sess.run([read_key, read_value])
Example #5
Source File: random_model.py From ConMask with MIT License | 6 votes |
def manual_eval_ops(self, device='/cpu:0'):
    """Build evaluation ops for the baseline random model.

    This is the baseline random model: it takes all the targets, randomly
    assigns values to them and then reports the result.

    :param device: NOTE(review): not referenced in the lines shown -- the ops
        below are pinned to '/cpu:0' regardless of this argument; confirm
        whether ``tf.device(device)`` was intended.
    :return: NOTE(review): no return statement appears in the lines shown;
        the remainder of the body is presumably elided.
    """
    # NOTE(review): "namual_evaluation" and 'pred_scorse_queue' look like
    # typos, but they are graph op names (renaming could break lookups by
    # name), so they are left unchanged.
    with tf.name_scope("namual_evaluation"):
        with tf.device('/cpu:0'):
            # head rel pair to evaluate (one row of two strings)
            ph_head_rel = tf.placeholder(tf.string, [1, 2], name='ph_head_rel')
            # tail targets to evaluate (one row, variable number of targets)
            ph_eval_targets = tf.placeholder(tf.string, [1, None], name='ph_eval_targets')
            # indices of true tail targets in ph_eval_targets. Mask these when calculating filtered mean rank
            ph_true_target_idx = tf.placeholder(tf.int32, [None], name='ph_true_target_idx')
            # indices of true targets in the evaluation set, we will return the ranks of these targets
            ph_test_target_idx = tf.placeholder(tf.int32, [None], name='ph_test_target_idx')
            # We put random numbers into the pred_scores_queue
            # (capacity 1000000, each element a float32 vector of shape [1]).
            pred_scores_queue = tf.FIFOQueue(1000000, dtypes=tf.float32, shapes=[[1]], name='pred_scorse_queue')
Example #6
Source File: reader_ops_test.py From deep_image_model with Apache License 2.0 | 6 votes |
def testOneEpoch(self):
    """Read every TFRecord of every file once, then expect EOF.

    Each key must start with ``"<file>:"`` and each value must equal
    ``self._Record(file_index, record_index)``.
    """
    files = self._CreateFiles()
    with self.test_session() as sess:
        record_reader = tf.TFRecordReader(name="test_reader")
        filename_queue = tf.FIFOQueue(99, [tf.string], shapes=())
        read_key, read_value = record_reader.read(filename_queue)
        filename_queue.enqueue_many([files]).run()
        filename_queue.close().run()

        for file_idx in range(self._num_files):
            key_prefix = "%s:" % files[file_idx]
            for rec_idx in range(self._num_records):
                got_key, got_value = sess.run([read_key, read_value])
                self.assertTrue(tf.compat.as_text(got_key).startswith(key_prefix))
                self.assertAllEqual(self._Record(file_idx, rec_idx), got_value)

        expected_error = ("is closed and has insufficient elements "
                          "\\(requested 1, current size 0\\)")
        with self.assertRaisesOpError(expected_error):
            sess.run([read_key, read_value])
Example #7
Source File: layers.py From PADME with MIT License | 6 votes |
def create_tensor(self, in_layers=None, **kwargs):
    """Build a named FIFOQueue fed by the outputs of ``in_layers``.

    Sets ``self.out_tensor`` to the enqueue op, ``self.out_tensors`` to the
    dequeue result and ``self.close_op`` to the queue's close op, and appends
    'queue', 'out_tensors' and 'close_op' to ``self._non_pickle_fields``.

    Args:
      in_layers: layers whose ``out_tensor`` values are enqueued; defaults to
        ``self.in_layers``.
      **kwargs: accepted for interface compatibility; unused.
    """
    # TODO(rbharath): Not sure if this layer can be called with __call__
    # meaningfully, so not going to support that functionality for now.
    layers = convert_to_layers(self.in_layers if in_layers is None else in_layers)
    self.dtypes = [layer.out_tensor.dtype for layer in layers]
    self.queue = tf.FIFOQueue(self.capacity, self.dtypes, names=self.names)

    # Enqueue by component name: each layer's output goes into the queue slot
    # keyed by that layer's name.
    enqueue_feed = {}
    for layer in layers:
        enqueue_feed[layer.name] = layer.out_tensor
    self.out_tensor = self.queue.enqueue(enqueue_feed)
    self.close_op = self.queue.close()
    self.out_tensors = self.queue.dequeue()
    self._non_pickle_fields += ['queue', 'out_tensors', 'close_op']
Example #8
Source File: video_avi_flow_saliency.py From self-supervision with BSD 3-Clause "New" or "Revised" License | 6 votes |
def __init__(self, path, batch_size=16, input_size=227, scale_factor=1.0, num_threads=10):
    """Set up a FIFO queue of (image, label) batches for the .avi files under ``path``.

    Args:
      path: directory searched with the pattern ``**/*.avi``.
      batch_size: leading dimension of both placeholders.
      input_size: spatial side length of the 3-channel image tensor.
      scale_factor: label side length is ``int(input_size * scale_factor)``.
      num_threads: stored on the instance (not used in this constructor).
    """
    self._path = path
    # Without recursive=True, "**" behaves like "*": this matches .avi files
    # exactly one directory level below ``path``.
    self._list_files = glob.glob(os.path.join(path, "**/*.avi"))
    self._batch_size = batch_size
    self._scale_factor = scale_factor
    self._image_size = input_size
    self._label_size = int(input_size * self._scale_factor)
    self._num_threads = num_threads
    self._coord = tf.train.Coordinator()

    image_side, label_side = self._image_size, self._label_size
    self._image_shape = [batch_size, image_side, image_side, 3]
    self._label_shape = [batch_size, label_side, label_side, 1]

    image_ph = tf.placeholder(tf.float32, self._image_shape, name='x')
    label_ph = tf.placeholder(tf.float32, self._label_shape, name='y')
    self._inputs = [image_ph, label_ph]
    component_dtypes = [ph.dtype for ph in self._inputs]
    component_shapes = [ph.get_shape() for ph in self._inputs]
    self._queue = tf.FIFOQueue(400, component_dtypes, component_shapes)
    self._enqueue_op = self._queue.enqueue(self._inputs)
    # cancel_pending_enqueues=True also cancels enqueues blocked on a full queue.
    self._queue_close_op = self._queue.close(cancel_pending_enqueues=True)
    self._threads = []
Example #9
Source File: cifar10_input_test.py From Gun-Detector with Apache License 2.0 | 6 votes |
def testSimple(self):
    """read_cifar10 yields the three written records in order, then OutOfRangeError."""
    labels = [9, 3, 0]
    records = [self._record(labels[0], 0, 128, 255),
               self._record(labels[1], 255, 0, 1),
               self._record(labels[2], 254, 255, 0)]
    contents = b"".join([record for record, _ in records])
    expected = [exp for _, exp in records]
    filename = os.path.join(self.get_temp_dir(), "cifar")
    # Close (and flush) the file before the reader opens it. A bare
    # open(...).write(...) leaks the handle and relies on refcounting to
    # close it in time.
    with open(filename, "wb") as f:
        f.write(contents)

    with self.test_session() as sess:
        q = tf.FIFOQueue(99, [tf.string], shapes=())
        q.enqueue([filename]).run()
        q.close().run()
        result = cifar10_input.read_cifar10(q)

        for i in range(3):
            key, label, uint8image = sess.run([
                result.key, result.label, result.uint8image])
            self.assertEqual("%s:%d" % (filename, i), tf.compat.as_text(key))
            self.assertEqual(labels[i], label)
            self.assertAllEqual(expected[i], uint8image)

        # Only one filename was enqueued and the queue is closed.
        with self.assertRaises(tf.errors.OutOfRangeError):
            sess.run([result.key, result.uint8image])
Example #10
Source File: reader_ops_test.py From deep_image_model with Apache License 2.0 | 6 votes |
def testOneEpoch(self):
    """Read every ZLIB-compressed TFRecord of every file once, then expect EOF."""
    files = self._CreateFiles()
    with self.test_session() as sess:
        zlib_options = tf.python_io.TFRecordOptions(
            compression_type=TFRecordCompressionType.ZLIB)
        record_reader = tf.TFRecordReader(name="test_reader", options=zlib_options)
        filename_queue = tf.FIFOQueue(99, [tf.string], shapes=())
        read_key, read_value = record_reader.read(filename_queue)
        filename_queue.enqueue_many([files]).run()
        filename_queue.close().run()

        for file_idx in range(self._num_files):
            key_prefix = "%s:" % files[file_idx]
            for rec_idx in range(self._num_records):
                got_key, got_value = sess.run([read_key, read_value])
                self.assertTrue(tf.compat.as_text(got_key).startswith(key_prefix))
                self.assertAllEqual(self._Record(file_idx, rec_idx), got_value)

        expected_error = ("is closed and has insufficient elements "
                          "\\(requested 1, current size 0\\)")
        with self.assertRaisesOpError(expected_error):
            sess.run([read_key, read_value])
Example #11
Source File: cifar10_input_test.py From ml with Apache License 2.0 | 6 votes |
def testSimple(self):
    """read_cifar10 yields the three written records in order, then OutOfRangeError."""
    labels = [9, 3, 0]
    records = [self._record(labels[0], 0, 128, 255),
               self._record(labels[1], 255, 0, 1),
               self._record(labels[2], 254, 255, 0)]
    contents = b"".join([record for record, _ in records])
    expected = [exp for _, exp in records]
    filename = os.path.join(self.get_temp_dir(), "cifar")
    # Close (and flush) the file before the reader opens it. A bare
    # open(...).write(...) leaks the handle and relies on refcounting to
    # close it in time.
    with open(filename, "wb") as f:
        f.write(contents)

    with self.test_session() as sess:
        q = tf.FIFOQueue(99, [tf.string], shapes=())
        q.enqueue([filename]).run()
        q.close().run()
        result = cifar10_input.read_cifar10(q)

        for i in range(3):
            key, label, uint8image = sess.run([
                result.key, result.label, result.uint8image])
            self.assertEqual("%s:%d" % (filename, i), tf.compat.as_text(key))
            self.assertEqual(labels[i], label)
            self.assertAllEqual(expected[i], uint8image)

        # Only one filename was enqueued and the queue is closed.
        with self.assertRaises(tf.errors.OutOfRangeError):
            sess.run([result.key, result.uint8image])
Example #12
Source File: video_avi_flow.py From self-supervision with BSD 3-Clause "New" or "Revised" License | 6 votes |
def __init__(self, files, batch_size=16, input_size=227, scale_factor=1.0, num_threads=10):
    """Set up a FIFO queue of (image, 2-channel label) batches for ``files``.

    Args:
      files: list of video file paths, stored as-is.
      batch_size: leading dimension of both placeholders.
      input_size: spatial side length of the 3-channel image tensor.
      scale_factor: label side length is ``int(input_size * scale_factor)``.
      num_threads: stored on the instance (not used in this constructor).
    """
    self._list_files = files
    self._batch_size = batch_size
    self._scale_factor = scale_factor
    self._image_size = input_size
    self._label_size = int(input_size * self._scale_factor)
    self._num_threads = num_threads
    self._coord = tf.train.Coordinator()

    image_side, label_side = self._image_size, self._label_size
    self._image_shape = [batch_size, image_side, image_side, 3]
    self._label_shape = [batch_size, label_side, label_side, 2]

    image_ph = tf.placeholder(tf.float32, self._image_shape, name='x')
    label_ph = tf.placeholder(tf.float32, self._label_shape, name='y')
    self._inputs = [image_ph, label_ph]
    component_dtypes = [ph.dtype for ph in self._inputs]
    component_shapes = [ph.get_shape() for ph in self._inputs]
    self._queue = tf.FIFOQueue(400, component_dtypes, component_shapes)
    self._enqueue_op = self._queue.enqueue(self._inputs)
    # cancel_pending_enqueues=True also cancels enqueues blocked on a full queue.
    self._queue_close_op = self._queue.close(cancel_pending_enqueues=True)
    self._threads = []
Example #13
Source File: input_pipeline.py From contextualLSTM with Apache License 2.0 | 6 votes |
def enqueue(sess, chunk_size=20):
    """Feed ``raw_data``/``raw_target`` into the input queue in chunks, forever.

    Each ``sess.run(enqueue_op, ...)`` feeds ``chunk_size`` consecutive rows.
    When a chunk runs past the end of the data it wraps around to the start,
    so the dataset is cycled indefinitely and this function never returns.

    ``raw_data``, ``raw_target``, ``enqueue_op``, ``queue_input_data`` and
    ``queue_input_target`` are not parameters; they are resolved from an
    enclosing/module scope.

    Args:
      sess: session used to run ``enqueue_op``.
      chunk_size: rows per enqueue (default 20, the previously hard-coded value).

    Raises:
      ValueError: if ``chunk_size`` is not in ``1..len(raw_data)``. Otherwise
        the wrap-around slice would yield short (or, for empty data, empty)
        chunks forever.
    """
    # ``num_rows`` rather than ``max`` so the builtin is not shadowed.
    num_rows = len(raw_data)
    if not 0 < chunk_size <= num_rows:
        raise ValueError("chunk_size must be in 1..%d, got %r" % (num_rows, chunk_size))
    under = 0
    while True:
        print("starting to write into queue")
        upper = under + chunk_size
        print("try to enqueue ", under, " to ", upper)
        if upper <= num_rows:
            curr_data = raw_data[under:upper]
            curr_target = raw_target[under:upper]
            under = upper
        else:
            # Wrap around: tail of the data followed by its head.
            rest = upper - num_rows
            curr_data = np.concatenate((raw_data[under:num_rows], raw_data[0:rest]))
            curr_target = np.concatenate((raw_target[under:num_rows], raw_target[0:rest]))
            under = rest
        sess.run(enqueue_op, feed_dict={queue_input_data: curr_data,
                                        queue_input_target: curr_target})
        print("added to the queue")

# start the threads for our FIFOQueue and batch
Example #14
Source File: control_flow_ops_py_test.py From deep_image_model with Apache License 2.0 | 6 votes |
def testWhileQueue_1(self):
    """A while_loop that enqueues its counter on every step fills the queue with 0..9."""
    with self.test_session():
        # Capacity -1: no upper bound on the number of queued elements.
        q = tf.FIFOQueue(-1, tf.int32)
        start = tf.constant(0)

        def keep_going(counter):
            return tf.less(counter, 10)

        def step(counter):
            incremented = tf.add(counter, 1)
            # The enqueue of the current value must run before the
            # incremented counter is produced.
            return control_flow_ops.with_dependencies(
                [q.enqueue((counter,))], incremented)

        final = tf.while_loop(keep_going, step, [start], parallel_iterations=1)
        self.assertEqual([10], final.eval())
        for expected_value in range(10):
            self.assertEqual([expected_value], q.dequeue().eval())
Example #15
Source File: reader_ops_test.py From deep_image_model with Apache License 2.0 | 6 votes |
def testMultipleEpochs(self):
    """Re-running the same enqueue op three times lets the reader see three epochs."""
    with self.test_session() as sess:
        reader = tf.IdentityReader("test_reader")
        queue = tf.FIFOQueue(99, [tf.string], shapes=())
        refill = queue.enqueue_many([["DD", "EE"]])
        key, value = reader.read(queue)

        for _epoch in range(3):
            refill.run()
            for work_item in (b"DD", b"EE"):
                self._ExpectRead(sess, key, value, work_item)

        queue.close().run()
        expected_error = ("is closed and has insufficient elements "
                          "\\(requested 1, current size 0\\)")
        with self.assertRaisesOpError(expected_error):
            sess.run([key, value])
Example #16
Source File: fifo_queue_test.py From deep_image_model with Apache License 2.0 | 6 votes |
def testConstructorWithShapes(self):
    """A FIFOQueue built with explicit shapes produces the expected NodeDef."""
    with tf.Graph().as_default():
        q = tf.FIFOQueue(5, (tf.int32, tf.float32),
                         shapes=(tf.TensorShape([1, 1, 2, 3]),
                                 tf.TensorShape([5, 8])),
                         name="Q")
        self.assertIsInstance(q.queue_ref, tf.Tensor)
        # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
        self.assertEqual(tf.string_ref, q.queue_ref.dtype)
        self.assertProtoEquals("""
          name:'Q' op:'FIFOQueue'
          attr { key: 'component_types' value { list {
            type: DT_INT32 type : DT_FLOAT
          } } }
          attr { key: 'shapes' value { list {
            shape { dim { size: 1 } dim { size: 1 } dim { size: 2 } dim { size: 3 } }
            shape { dim { size: 5 } dim { size: 8 } }
          } } }
          attr { key: 'capacity' value { i: 5 } }
          attr { key: 'container' value { s: '' } }
          attr { key: 'shared_name' value { s: '' } }
          """, q.queue_ref.op.node_def)
Example #17
Source File: speech_input.py From speechT with Apache License 2.0 | 6 votes |
def __init__(self, input_size, batch_size, data_generator_creator, max_steps=None):
    """Create placeholders and a CPU-pinned queue for (inputs, lengths, labels).

    Args:
      input_size: last dimension of ``inputs``; also passed to the base class.
      batch_size: fixed leading dimension of ``inputs`` and ``sequence_lengths``.
      data_generator_creator: stored as-is.
      max_steps: initial value of ``self.steps_left``.
    """
    super().__init__(input_size)
    self.batch_size = batch_size
    self.data_generator_creator = data_generator_creator
    self.steps_left = max_steps

    with tf.device("/cpu:0"):
        # inputs: [batch_size, max_time, input_size]; max_time is left open.
        self.inputs = tf.placeholder(tf.float32, [batch_size, None, input_size], name='inputs')
        self.sequence_lengths = tf.placeholder(tf.int32, [batch_size], name='sequence_lengths')
        self.labels = tf.sparse_placeholder(tf.int32, name='labels')

        self.queue = tf.FIFOQueue(capacity=100, dtypes=[tf.float32, tf.int32, tf.string])

        # Queues cannot hold sparse tensors, so labels are enqueued in
        # serialized (string) form.
        queued_components = [self.inputs,
                             self.sequence_lengths,
                             tf.serialize_many_sparse(self.labels)]
        self.enqueue_op = self.queue.enqueue(queued_components)
Example #18
Source File: video_jpeg_rolls_flow_saliency.py From self-supervision with BSD 3-Clause "New" or "Revised" License | 6 votes |
def __init__(self, path, root_path='', batch_size=16, input_size=227, num_threads=10):
    """Load a file list from ``path`` and set up an (image, label) FIFO queue.

    Args:
      path: text file with one entry per line (trailing newlines stripped).
      root_path: stored as-is.
      batch_size: leading dimension of both placeholders.
      input_size: spatial side length of images (3 channels) and labels (1 channel).
      num_threads: stored on the instance (not used in this constructor).
    """
    self._path = path
    self._root_path = root_path
    with open(path) as listing:
        self._list_files = [line.rstrip('\n') for line in listing]
    print('list_files', len(self._list_files))
    self._batch_size = batch_size
    self._input_size = input_size
    self._num_threads = num_threads
    self._coord = tf.train.Coordinator()

    self._base_shape = [batch_size, input_size, input_size]
    self._image_shape = self._base_shape + [3]
    self._label_shape = self._base_shape + [1]

    image_ph = tf.placeholder(tf.float32, self._image_shape, name='x')
    label_ph = tf.placeholder(tf.float32, self._label_shape, name='y')
    self._inputs = [image_ph, label_ph]
    component_dtypes = [ph.dtype for ph in self._inputs]
    component_shapes = [ph.get_shape() for ph in self._inputs]
    self._queue = tf.FIFOQueue(400, component_dtypes, component_shapes)
    self._enqueue_op = self._queue.enqueue(self._inputs)
    # cancel_pending_enqueues=True also cancels enqueues blocked on a full queue.
    self._queue_close_op = self._queue.close(cancel_pending_enqueues=True)
    self._threads = []
Example #19
Source File: cifar10_input_test.py From yolo_v2 with Apache License 2.0 | 6 votes |
def testSimple(self):
    """read_cifar10 yields the three written records in order, then OutOfRangeError."""
    labels = [9, 3, 0]
    records = [self._record(labels[0], 0, 128, 255),
               self._record(labels[1], 255, 0, 1),
               self._record(labels[2], 254, 255, 0)]
    contents = b"".join([record for record, _ in records])
    expected = [exp for _, exp in records]
    filename = os.path.join(self.get_temp_dir(), "cifar")
    # Close (and flush) the file before the reader opens it. A bare
    # open(...).write(...) leaks the handle and relies on refcounting to
    # close it in time.
    with open(filename, "wb") as f:
        f.write(contents)

    with self.test_session() as sess:
        q = tf.FIFOQueue(99, [tf.string], shapes=())
        q.enqueue([filename]).run()
        q.close().run()
        result = cifar10_input.read_cifar10(q)

        for i in range(3):
            key, label, uint8image = sess.run([
                result.key, result.label, result.uint8image])
            self.assertEqual("%s:%d" % (filename, i), tf.compat.as_text(key))
            self.assertEqual(labels[i], label)
            self.assertAllEqual(expected[i], uint8image)

        # Only one filename was enqueued and the queue is closed.
        with self.assertRaises(tf.errors.OutOfRangeError):
            sess.run([result.key, result.uint8image])
Example #20
Source File: fifo_queue_test.py From deep_image_model with Apache License 2.0 | 6 votes |
def testParallelEnqueue(self):
    """Ten producer threads each run one enqueue; one consumer then drains all values."""
    with self.test_session() as sess:
        q = tf.FIFOQueue(10, tf.float32)
        elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
        enqueue_ops = [q.enqueue((x,)) for x in elems]
        dequeued_t = q.dequeue()

        def run_op(op):
            sess.run(op)

        # One producer thread per element.
        producers = [self.checkedThread(target=run_op, args=(op,))
                     for op in enqueue_ops]
        for worker in producers:
            worker.start()
        for worker in producers:
            worker.join()

        # Drain from this single thread; arrival order is not asserted.
        results = [dequeued_t.eval() for _ in range(len(elems))]
        self.assertItemsEqual(elems, results)
Example #21
Source File: fifo_queue_test.py From deep_image_model with Apache License 2.0 | 6 votes |
def testParallelDequeue(self):
    """One thread fills the queue; ten consumer threads each dequeue one value."""
    with self.test_session() as sess:
        q = tf.FIFOQueue(10, tf.float32)
        elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
        enqueue_ops = [q.enqueue((x,)) for x in elems]
        dequeued_t = q.dequeue()

        # Fill the queue from this thread.
        for op in enqueue_ops:
            op.run()

        results = []

        def consume():
            results.append(sess.run(dequeued_t))

        # One consumer thread per enqueued element.
        consumers = [self.checkedThread(target=consume) for _ in enqueue_ops]
        for worker in consumers:
            worker.start()
        for worker in consumers:
            worker.join()
        self.assertItemsEqual(elems, results)
Example #22
Source File: fifo_queue_test.py From deep_image_model with Apache License 2.0 | 5 votes |
def testEnqueueHalf(self):
    """A float16 queue accepts a single-element enqueue of a Python float."""
    with self.test_session():
        half_queue = tf.FIFOQueue(10, tf.float16)
        half_queue.enqueue((10.0,)).run()
Example #23
Source File: fifo_queue_test.py From deep_image_model with Apache License 2.0 | 5 votes |
def testEnqueue(self):
    """A float32 queue accepts a single-element enqueue of a Python float."""
    with self.test_session():
        float_queue = tf.FIFOQueue(10, tf.float32)
        float_queue.enqueue((10.0,)).run()
Example #24
Source File: collect.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def __init__(self, batch_env):
    """Wrap ``batch_env`` and create the rollout memory queue ``self.speculum``."""
    super(_MemoryWrapper, self).__init__(batch_env)
    # Large capacity so the memory queue is effectively unbounded.
    capacity = 10000000
    # Transpose the per-field metadata tuples into columns (shapes, dtypes, ...).
    columns = list(zip(*_rollout_metadata(batch_env)))
    # The memory wrapper collects neither pdfs nor the value function, so
    # only the first 4 fields are queued.
    shapes = columns[0][:4]
    dtypes = columns[1][:4]
    self.speculum = tf.FIFOQueue(capacity, shapes=shapes, dtypes=dtypes)
    # TODO(piotrmilos): possibly retrieve the observation type for batch_env
    self._observ = tf.Variable(
        tf.zeros(batch_env.observ.shape, self.observ_dtype), trainable=False)
Example #25
Source File: train_test.py From Table-Detection-using-Deep-learning with BSD 3-Clause "New" or "Revised" License | 5 votes |
def get_dataset(self, dataset_type):
    """Mock of luminoth.datasets.datasets.get_dataset.

    Ignores ``dataset_type`` and returns a factory. The factory, whatever its
    argument, returns a ``build`` function that creates a 3-slot named
    FIFOQueue, registers a QueueRunner feeding it one synthetic example, and
    returns the queue's dequeue.
    """
    def dataset_class(arg2):
        def build():
            # (component name, dtype) for each queue slot.
            slots = (('image', tf.float32),
                     ('bboxes', tf.int32),
                     ('filename', tf.string))
            queue = tf.FIFOQueue(
                capacity=3,
                dtypes=[dtype for _, dtype in slots],
                names=[slot_name for slot_name, _ in slots],
                name='fifo_queue'
            )
            filename = tf.train.limit_epochs(
                [tf.cast('filename_test', tf.string)], num_epochs=2)
            data = {
                'image': tf.random_uniform([600, 800, 3], maxval=255),
                'bboxes': tf.constant([[0, 0, 30, 30, 0]]),
                'filename': filename,
            }
            # The same enqueue op is listed twice in the runner.
            runner = tf.train.QueueRunner(queue, [queue.enqueue(data)] * 2)
            tf.train.add_queue_runner(runner)
            return queue.dequeue()
        return build
    return dataset_class
Example #26
Source File: tf_queue.py From Re3 with GNU General Public License v3.0 | 5 votes |
def __init__(self, sess, placeholders, max_queue_size, max_queue_uses, use_random_order, batch_size):
    """Create a CPU-pinned FIFOQueue over ``placeholders`` and start its feeder thread.

    Args:
      sess: stored as ``self.sess``.
      placeholders: tensors fed via ``enqueue_many``; dim 0 of the first one
        becomes ``enqueue_batch_size`` and each placeholder's remaining dims
        form its queue element shape.
      max_queue_size: queue capacity (also the length of ``data_counts``).
      max_queue_uses: stored as-is.
      use_random_order: stored as-is.
      batch_size: number of elements per ``dequeue_many``.
    """
    self.sess = sess
    self.placeholders = placeholders
    self.max_queue_size = max_queue_size
    self.max_queue_uses = max_queue_uses
    self.data_buffer = []
    self.data_counts = np.zeros(max_queue_size)
    self.lock = threading.Lock()
    self.batch_size = batch_size
    first_shape = self.placeholders[0].get_shape().as_list()
    self.enqueue_batch_size = first_shape[0]
    self.use_random_order = use_random_order
    self.num_samples = 0

    with tf.device('/cpu:0'):
        element_dtypes = [ph.dtype for ph in self.placeholders]
        element_shapes = [ph.get_shape().as_list()[1:] for ph in self.placeholders]
        self.queue = tf.FIFOQueue(self.max_queue_size, element_dtypes,
                                  shapes=element_shapes)
        self.enqueue_op = self.queue.enqueue_many(self.placeholders)
        # Map each placeholder to its dequeued counterpart, by position.
        self.placeholder_outs = {}
        for idx, dequeued in enumerate(self.queue.dequeue_many(self.batch_size)):
            self.placeholder_outs[self.placeholders[idx]] = dequeued
        self.size = self.queue.size()

    # Background feeder; daemon so it does not keep the process alive.
    self.thread = threading.Thread(target=self.tf_enqueue_data)
    self.thread.daemon = True
    self.thread.start()
Example #27
Source File: benchmark_cnn.py From parallax with Apache License 2.0 | 5 votes |
def _benchmark_cnn(self):
    """Run cnn in benchmark mode. When forward_only on, it forwards CNN.

    Builds the model, a global-step increment op (``self.train_op``) that
    depends on all model fetches, and a local-variable init group. When real
    (non-synthetic) images are used, it also registers a QueueRunner, attached
    to a one-slot ``dummy_queue``, that runs ``image_producer_ops``.

    Returns:
      Dictionary containing training statistics (num_workers, num_steps,
      average_wall_time, images_per_sec).
      NOTE(review): no return statement appears in the lines shown; the rest
      of the body is presumably elided -- confirm against the full source.
    """
    (image_producer_ops, fetches) = self._build_model()
    # Group every fetch so one op can serve as the control dependency below.
    fetches_list = nest.flatten(list(fetches.values()))
    main_fetch_group = tf.group(*fetches_list)
    global_step = tf.train.get_global_step()
    with tf.device('/cpu:0'):
        # The step counter only advances after all model fetches have run.
        with tf.control_dependencies([main_fetch_group]):
            self.train_op = global_step.assign_add(1, use_locking=True)
    local_var_init_op = tf.local_variables_initializer()
    variable_mgr_init_ops = [local_var_init_op]
    local_var_init_op_group = tf.group(*variable_mgr_init_ops)
    if not self.use_synthetic_gpu_images:
        # NOTE(review): presumably the queue exists only so that a
        # QueueRunner can drive image_producer_ops; confirm.
        dummy_queue = tf.FIFOQueue(1, [tf.bool], shapes=[[]],
                                   name='dummy_queue',
                                   shared_name='dummy_queue')
        qr = tf.train.QueueRunner(dummy_queue, image_producer_ops)
        tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, qr)
Example #28
Source File: collect.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def __init__(self, batch_env):
    """Wrap ``batch_env`` and create the rollout memory queue ``self.speculum``."""
    super(_MemoryWrapper, self).__init__(batch_env)
    # Large capacity so the memory queue is effectively unbounded.
    capacity = 10000000
    # Transpose the per-field metadata tuples into columns (shapes, dtypes, ...).
    columns = list(zip(*_rollout_metadata(batch_env)))
    # The memory wrapper collects neither pdfs nor the value function, so
    # only the first 4 fields are queued.
    shapes = columns[0][:4]
    dtypes = columns[1][:4]
    self.speculum = tf.FIFOQueue(capacity, shapes=shapes, dtypes=dtypes)
    # TODO(piotrmilos): possibly retrieve the observation type for batch_env
    self._observ = tf.Variable(
        tf.zeros(batch_env.observ.shape, self.observ_dtype), trainable=False)
Example #29
Source File: ppo_learner.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def __init__(self, batch_env):
    """Wrap ``batch_env`` and create the rollout memory queue ``self.speculum``."""
    super(_MemoryWrapper, self).__init__(batch_env)
    # Large capacity so the memory queue is effectively unbounded.
    capacity = 10000000
    # Transpose the per-field metadata tuples into columns (shapes, dtypes, ...).
    columns = list(zip(*_rollout_metadata(batch_env)))
    # The memory wrapper collects neither pdfs nor the value function, so
    # only the first 4 fields are queued.
    shapes = columns[0][:4]
    dtypes = columns[1][:4]
    self.speculum = tf.FIFOQueue(capacity, shapes=shapes, dtypes=dtypes)
    # TODO(piotrmilos): possibly retrieve the observation type for batch_env
    self._observ = tf.Variable(
        tf.zeros(batch_env.observ.shape, self.observ_dtype), trainable=False)
Example #30
Source File: ppo_learner.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def __init__(self, batch_env):
    """Wrap ``batch_env`` and create the rollout memory queue ``self.speculum``."""
    super(_MemoryWrapper, self).__init__(batch_env)
    # Large capacity so the memory queue is effectively unbounded.
    capacity = 10000000
    # Transpose the per-field metadata tuples into columns (shapes, dtypes, ...).
    columns = list(zip(*_rollout_metadata(batch_env)))
    # The memory wrapper collects neither pdfs nor the value function, so
    # only the first 4 fields are queued.
    shapes = columns[0][:4]
    dtypes = columns[1][:4]
    self.speculum = tf.FIFOQueue(capacity, shapes=shapes, dtypes=dtypes)
    # TODO(piotrmilos): possibly retrieve the observation type for batch_env
    self._observ = tf.Variable(
        tf.zeros(batch_env.observ.shape, self.observ_dtype), trainable=False)