Python tensorflow.FixedLenFeature() Examples
The following are 30 code examples of tensorflow.FixedLenFeature().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
Example #1
Source File: dataset.py From DNA-GAN with MIT License | 7 votes |
def parse_fn(self, serialized_example):
    """Parse one serialized example into a resized image and attribute values.

    Reads ``self.feature_list`` (names of the extra int64 features to parse)
    and ``self.height`` / ``self.width`` (the resize target).

    Args:
        serialized_example: value handed to ``tf.parse_single_example``.

    Returns:
        Tuple ``(image, feature_val_list)``: the image after
        ``tf.image.resize_images`` and a list with one float32 tensor per
        name in ``self.feature_list``.
    """
    spec = {
        'image/id_name': tf.FixedLenFeature([], tf.string),
        'image/height': tf.FixedLenFeature([], tf.int64),
        'image/width': tf.FixedLenFeature([], tf.int64),
        'image/encoded': tf.FixedLenFeature([], tf.string),
    }
    # Every attribute listed on the instance is parsed as a scalar int64.
    spec.update(
        {attr: tf.FixedLenFeature([], tf.int64) for attr in self.feature_list})

    parsed = tf.parse_single_example(serialized_example, features=spec)

    pixels = tf.decode_raw(parsed['image/encoded'], tf.uint8)
    stored_height = tf.cast(parsed['image/height'], tf.int32)
    stored_width = tf.cast(parsed['image/width'], tf.int32)
    # The flat byte buffer is restored to the stored size with 3 channels
    # before being resized to the configured size.
    pixels = tf.reshape(pixels, [stored_height, stored_width, 3])
    resized = tf.image.resize_images(pixels, size=[self.height, self.width])

    attr_values = [
        tf.cast(parsed[attr], tf.float32) for attr in self.feature_list
    ]
    return resized, attr_values
Example #2
Source File: readers.py From Youtube-8M-WILLOW with Apache License 2.0 | 6 votes |
def prepare_serialized_examples(self, serialized_examples):
    """Parse a batch of serialized examples into ids, features, labels, weights.

    Reads ``self.feature_names``, ``self.feature_sizes`` and
    ``self.num_classes``.

    Args:
        serialized_examples: value handed to ``tf.parse_example``.

    Returns:
        A 4-tuple: the parsed "video_id" tensor, the named features
        concatenated along axis 1, the indicator labels produced by
        ``tf.sparse_to_indicator``, and a vector of ones whose length is the
        first dimension of ``serialized_examples``.
    """
    name_count = len(self.feature_names)
    assert name_count > 0, "self.feature_names is empty!"
    assert len(self.feature_names) == len(self.feature_sizes), \
        "length of feature_names (={}) != length of feature_sizes (={})".format(
            len(self.feature_names), len(self.feature_sizes))

    # Mapping from proto field name to the type it is parsed as.
    spec = {
        "video_id": tf.FixedLenFeature([], tf.string),
        "labels": tf.VarLenFeature(tf.int64),
    }
    for position, field in enumerate(self.feature_names):
        spec[field] = tf.FixedLenFeature(
            [self.feature_sizes[position]], tf.float32)

    parsed = tf.parse_example(serialized_examples, features=spec)

    indicator = tf.sparse_to_indicator(parsed["labels"], self.num_classes)
    indicator.set_shape([None, self.num_classes])

    per_field = [parsed[field] for field in self.feature_names]
    merged = tf.concat(per_field, 1)

    example_count = tf.shape(serialized_examples)[0]
    return parsed["video_id"], merged, indicator, tf.ones([example_count])
Example #3
Source File: reader.py From CapsLayer with Apache License 2.0 | 6 votes |
def parse_fun(serialized_example):
    """ Parse one serialized example into a dict of image and label tensors.

    The image bytes are decoded as float32, flattened to
    ``height * width * depth`` elements, given the static shape ``[28 * 28 * 1]``
    and scaled by ``1 / 255``. The label is cast to int32.

    Returns a dict with keys 'images' and 'labels'.
    """
    spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'depth': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=spec)

    rows = tf.cast(parsed['height'], tf.int32)
    cols = tf.cast(parsed['width'], tf.int32)
    channels = tf.cast(parsed['depth'], tf.int32)

    pixels = tf.decode_raw(parsed['image'], tf.float32)
    pixels = tf.reshape(pixels, shape=[rows * cols * channels])
    # The dynamic reshape loses the static size, so pin it explicitly.
    pixels.set_shape([28 * 28 * 1])
    pixels = tf.cast(pixels, tf.float32) * (1. / 255)

    target = tf.cast(parsed['label'], tf.int32)
    return {'images': pixels, 'labels': target}
Example #4
Source File: gym_problems.py From fine-lm with MIT License | 6 votes |
def extra_reading_spec(self):
    """Additional data fields to store on disk and their decoders.

    Returns:
        Tuple ``(data_fields, decoders)``. Both dicts are keyed by
        "frame_number", "action" and "reward"; the first maps to a length-1
        int64 ``tf.FixedLenFeature``, the second to a slim ``Tensor`` decoder
        reading the key of the same name.
    """
    # TODO(piotrmilos): shouldn't done be included here?
    field_names = ("frame_number", "action", "reward")
    data_fields = {
        field: tf.FixedLenFeature([1], tf.int64) for field in field_names
    }
    decoders = {
        field: tf.contrib.slim.tfexample_decoder.Tensor(tensor_key=field)
        for field in field_names
    }
    return data_fields, decoders
Example #5
Source File: read_tfrecord.py From 2019-CCF-BDCI-OCR-MCZJ-OCR-IdentificationIDElement with MIT License | 6 votes |
def _extract_features_batch(self, serialized_batch):
    """Parse a serialized batch into image, label and image-path tensors.

    Args:
        serialized_batch: value handed to ``tf.parse_example``.

    Returns:
        Tuple ``(images, labels, imagepaths)``: float32 images reshaped to
        ``[batch, height, -1, CFG.ARCH.INPUT_CHANNELS]``, the variable-length
        labels cast to int32, and the parsed 'imagepaths' strings.
    """
    spec = {
        'images': tf.FixedLenFeature([], tf.string),
        'imagepaths': tf.FixedLenFeature([], tf.string),
        'labels': tf.VarLenFeature(tf.int64),
    }
    parsed = tf.parse_example(serialized_batch, features=spec)

    batch_dim = parsed['images'].shape[0]
    pixels = tf.decode_raw(parsed['images'], tf.uint8)
    # INPUT_SIZE unpacks as (width, height); only the height is needed because
    # the width axis is left for reshape to infer.
    _, target_height = tuple(CFG.ARCH.INPUT_SIZE)
    pixels = tf.cast(x=pixels, dtype=tf.float32)
    pixels = tf.reshape(
        pixels, [batch_dim, target_height, -1, CFG.ARCH.INPUT_CHANNELS])

    label_ids = tf.cast(parsed['labels'], tf.int32)
    paths = parsed['imagepaths']
    return pixels, label_ids, paths
Example #6
Source File: reader.py From CapsLayer with Apache License 2.0 | 6 votes |
def parse_fun(serialized_example):
    """ Parse one serialized example into a dict of image and label tensors.

    The image bytes are decoded as float32, flattened to
    ``height * width * depth`` elements, given the static shape ``[28 * 28 * 1]``
    and scaled by ``1 / 255``. The label is cast to int32.

    Returns a dict with keys 'images' and 'labels'.
    """
    spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'depth': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=spec)

    rows = tf.cast(parsed['height'], tf.int32)
    cols = tf.cast(parsed['width'], tf.int32)
    channels = tf.cast(parsed['depth'], tf.int32)

    pixels = tf.decode_raw(parsed['image'], tf.float32)
    pixels = tf.reshape(pixels, shape=[rows * cols * channels])
    # The dynamic reshape loses the static size, so pin it explicitly.
    pixels.set_shape([28 * 28 * 1])
    pixels = tf.cast(pixels, tf.float32) * (1. / 255)

    target = tf.cast(parsed['label'], tf.int32)
    return {'images': pixels, 'labels': target}
Example #7
Source File: train.py From centernet_tensorflow_wilderface_voc with MIT License | 6 votes |
def parse_color_data(example_proto):
    """Parse one example into a normalized 3-channel image and its label.

    Args:
        example_proto: value handed to ``tf.parse_single_example``.

    Returns:
        Tuple ``(img, label)``: the image reshaped to
        ``[height, width, 3]`` and mapped to ``value / 255 - 0.5`` as float32,
        and the 'label' bytes decoded as float32.
    """
    spec = {
        "img_raw": tf.FixedLenFeature([], tf.string),
        "label": tf.FixedLenFeature([], tf.string),
        "width": tf.FixedLenFeature([], tf.int64),
        "height": tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(example_proto, spec)

    pixels = tf.decode_raw(parsed["img_raw"], tf.uint8)
    stored_shape = [parsed["height"], parsed["width"], 3]
    pixels = tf.reshape(pixels, stored_shape)
    pixels = tf.cast(pixels, tf.float32) * (1. / 255.) - 0.5

    target = tf.decode_raw(parsed["label"], tf.float32)
    return pixels, target
Example #8
Source File: loader.py From SketchCNN with MIT License | 6 votes |
def _read_raw(self):
    """Read raw data from TFRecord.

    Creates a new ``tf.TFRecordReader``, stores it on ``self._reader`` and
    reads one record from ``self._queue``. The 'block' feature is passed to
    ``decode_block`` with ``tensor_size=self._raw_size``.

    Returns:
        ``(input_raw, label_raw)``, or ``(input_raw, label_raw, name)`` when
        ``self._with_key`` is truthy.
    """
    self._reader = tf.TFRecordReader()
    _, record = self._reader.read(self._queue)

    spec = {
        'name': tf.FixedLenFeature([], tf.string),
        'block': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(record, features=spec)
    input_raw, label_raw = decode_block(parsed['block'],
                                        tensor_size=self._raw_size)

    if not self._with_key:
        return input_raw, label_raw
    return input_raw, label_raw, parsed['name']
Example #9
Source File: video_utils.py From fine-lm with MIT License | 6 votes |
def example_reading_spec(self):
    """Return the on-disk fields and their decoders for a video frame.

    The base spec holds the encoded image and its format; both dicts are then
    extended with whatever ``self.extra_reading_spec`` provides.

    Returns:
        Tuple ``(data_fields, data_items_to_decoders)``.
    """
    # NOTE(review): extra_reading_spec is read as an attribute, not called, so
    # it is presumably a property on this class — confirm.
    extra_fields, extra_decoders = self.extra_reading_spec

    data_fields = {}
    data_fields["image/encoded"] = tf.FixedLenFeature((), tf.string)
    data_fields["image/format"] = tf.FixedLenFeature((), tf.string)
    data_fields.update(extra_fields)

    frame_decoder = tf.contrib.slim.tfexample_decoder.Image(
        image_key="image/encoded",
        format_key="image/format",
        shape=[self.frame_height, self.frame_width, self.num_channels],
        channels=self.num_channels)
    data_items_to_decoders = {"frame": frame_decoder}
    data_items_to_decoders.update(extra_decoders)

    return data_fields, data_items_to_decoders
Example #10
Source File: vfn_train.py From view-finding-network with GNU General Public License v3.0 | 6 votes |
def read_and_decode_aug(filename_queue):
    """Read one record and return two randomly augmented image halves.

    The 'image_raw' bytes are reshaped to ``[227, 227, 6]``, randomly flipped
    left/right, rescaled, and then jittered in brightness and contrast.

    Args:
        filename_queue: queue handed to ``tf.TFRecordReader().read``.

    Returns:
        The result of splitting the image into 2 parts along axis 2.
    """
    record_reader = tf.TFRecordReader()
    _, record = record_reader.read(filename_queue)

    # No default is given, so 'image_raw' must be present in every record.
    spec = {'image_raw': tf.FixedLenFeature([], tf.string)}
    parsed = tf.parse_single_example(record, features=spec)

    pixels = tf.decode_raw(parsed['image_raw'], tf.uint8)
    stacked = tf.reshape(pixels, [227, 227, 6])
    stacked = tf.image.random_flip_left_right(stacked)

    # Map [0, 255] bytes to floats in [-0.5, 0.5].
    stacked = tf.cast(stacked, tf.float32) * (1. / 255) - 0.5

    stacked = tf.image.random_brightness(stacked, 0.01)
    stacked = tf.image.random_contrast(stacked, 0.95, 1.05)

    # Cut the channel axis (3rd dimension) into two equal parts.
    return tf.split(stacked, 2, 2)
Example #11
Source File: vfn_train.py From view-finding-network with GNU General Public License v3.0 | 6 votes |
def read_and_decode(filename_queue):
    """Read one record and return the two image halves, without augmentation.

    Args:
        filename_queue: queue handed to ``tf.TFRecordReader().read``.

    Returns:
        The result of splitting the ``[227, 227, 6]`` image into 2 parts
        along axis 2.
    """
    record_reader = tf.TFRecordReader()
    _, record = record_reader.read(filename_queue)

    # No default is given, so 'image_raw' must be present in every record.
    spec = {'image_raw': tf.FixedLenFeature([], tf.string)}
    parsed = tf.parse_single_example(record, features=spec)

    pixels = tf.decode_raw(parsed['image_raw'], tf.uint8)
    stacked = tf.reshape(pixels, [227, 227, 6])

    # Map [0, 255] bytes to floats in [-0.5, 0.5].
    stacked = tf.cast(stacked, tf.float32) * (1. / 255) - 0.5

    # Cut the channel axis (3rd dimension) into two equal parts.
    return tf.split(stacked, 2, 2)
Example #12
Source File: inputs.py From ffn with Apache License 2.0 | 6 votes |
def load_patch_coordinates_from_filename_queue(filename_queue):
    """Loads coordinates and volume names from filename queue.

    Records are read with GZIP compression options.

    Args:
      filename_queue: Tensorflow queue created from create_filename_queue()

    Returns:
      Tuple of coordinates (shape `[1, 3]`) and volume name (shape `[1]`)
      tensors.
    """
    gzip_options = tf.python_io.TFRecordOptions(
        tf.python_io.TFRecordCompressionType.GZIP)
    record_reader = tf.TFRecordReader(options=gzip_options)
    # The record keys are not needed, only the serialized protos.
    _, protos = record_reader.read(filename_queue)

    spec = {
        'center': tf.FixedLenFeature(shape=[1, 3], dtype=tf.int64),
        'label_volume_name': tf.FixedLenFeature(shape=[1], dtype=tf.string),
    }
    parsed = tf.parse_single_example(protos, features=spec)
    return parsed['center'], parsed['label_volume_name']
Example #13
Source File: model.py From cloudml-dist-mnist-example with Apache License 2.0 | 6 votes |
def read_and_decode(filename_queue):
    """Read one record and return a scaled flat image with its int32 label.

    Args:
        filename_queue: queue handed to ``tf.TFRecordReader().read``.

    Returns:
        Tuple ``(image, label)``: the 'image_raw' bytes with static shape
        ``[784]`` as float32 multiplied by ``1 / 255``, and the 'label'
        feature cast to int32.
    """
    record_reader = tf.TFRecordReader()
    _, record = record_reader.read(filename_queue)

    spec = {
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(record, features=spec)

    pixels = tf.decode_raw(parsed['image_raw'], tf.uint8)
    pixels.set_shape([784])
    scaled = tf.cast(pixels, tf.float32) * (1. / 255)

    target = tf.cast(parsed['label'], tf.int32)
    return scaled, target
Example #14
Source File: tf_schema_utils_test.py From spotify-tensorflow with Apache License 2.0 | 5 votes |
def test_round_trip(self):
    """A feature spec converted to a schema and back must compare equal."""
    int_scalar = tf.FixedLenFeature(shape=[], dtype=tf.int64)
    float_scalar = tf.FixedLenFeature(shape=[], dtype=tf.float32)
    original_spec = {
        "scalar_feature_1": int_scalar,
        "scalar_feature_2": tf.FixedLenFeature(shape=[], dtype=tf.int64),
        "scalar_feature_3": float_scalar,
        "varlen_feature_1": tf.VarLenFeature(dtype=tf.float32),
        "varlen_feature_2": tf.VarLenFeature(dtype=tf.string),
        "1d_vector_feature": tf.FixedLenFeature(shape=[1], dtype=tf.string),
        "2d_vector_feature": tf.FixedLenFeature(shape=[2, 2],
                                                dtype=tf.float32),
        "sparse_feature": tf.SparseFeature("idx", "value", tf.float32, 10),
    }

    schema = feature_spec_to_schema(original_spec)
    recovered_spec = schema_to_feature_spec(schema)

    self.assertEqual(recovered_spec, original_spec)
Example #15
Source File: test_data2.py From basenji with Apache License 2.0 | 5 votes |
def parse_proto(example_protos):
    """Parse serialized examples into genome, sequence and target tensors.

    Args:
        example_protos: value handed to ``tf.parse_example``.

    Returns:
        Dict with keys 'genome' (parsed int64 feature), 'sequence' (bytes
        decoded as uint8) and 'target' (bytes decoded as float16).
    """
    spec = {
        'genome': tf.FixedLenFeature([1], tf.int64),
        'sequence': tf.FixedLenFeature([], tf.string),
        'target': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_example(example_protos, features=spec)

    genome_id = parsed['genome']
    sequence = tf.decode_raw(parsed['sequence'], tf.uint8)
    target = tf.decode_raw(parsed['target'], tf.float16)

    return {'genome': genome_id, 'sequence': sequence, 'target': target}
Example #16
Source File: reader.py From CycleGAN-TensorFlow with MIT License | 5 votes |
def feed(self):
    """
    Build the queue-based input pipeline and return a shuffled image batch.

    One record is read from ``self.tfrecords_file`` through ``self.reader``,
    its 'image/encoded_image' feature is JPEG-decoded with 3 channels, passed
    through ``self._preprocess`` and batched with ``tf.train.shuffle_batch``.
    An image summary named '_input' is also registered.

    Returns:
      images: 4D tensor [batch_size, image_width, image_height, image_depth]
        (layout as stated by the original docstring; the actual spatial
        size depends on ``self._preprocess`` -- TODO confirm)
    """
    with tf.name_scope(self.name):
        filename_queue = tf.train.string_input_producer([self.tfrecords_file])

        # The read goes through the reader held on the instance. A local
        # ``tf.TFRecordReader()`` used to be created here as well but was
        # never used, so it only added a dead op to the graph.
        _, serialized_example = self.reader.read(filename_queue)
        features = tf.parse_single_example(
            serialized_example,
            features={
                'image/file_name': tf.FixedLenFeature([], tf.string),
                'image/encoded_image': tf.FixedLenFeature([], tf.string),
            })

        image_buffer = features['image/encoded_image']
        image = tf.image.decode_jpeg(image_buffer, channels=3)
        image = self._preprocess(image)
        images = tf.train.shuffle_batch(
            [image],
            batch_size=self.batch_size,
            num_threads=self.num_threads,
            capacity=self.min_queue_examples + 3 * self.batch_size,
            min_after_dequeue=self.min_queue_examples)

        tf.summary.image('_input', images)
    return images
Example #17
Source File: test_data.py From basenji with Apache License 2.0 | 5 votes |
def parse_proto(example_protos):
    """Parse serialized examples into decoded input and output tensors.

    Feature names come from ``tfrecord_util.TFR_INPUT`` and
    ``tfrecord_util.TFR_OUTPUT``; the same names key the returned dict.

    Returns:
        Dict mapping the input name to the bytes decoded as uint8 and the
        output name to the bytes decoded as float16.
    """
    input_key = tfrecord_util.TFR_INPUT
    output_key = tfrecord_util.TFR_OUTPUT

    spec = {
        input_key: tf.FixedLenFeature([], tf.string),
        output_key: tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_example(example_protos, features=spec)

    sequence = tf.decode_raw(parsed[input_key], tf.uint8)
    target = tf.decode_raw(parsed[output_key], tf.float16)

    return {input_key: sequence, output_key: target}
Example #18
Source File: tfr_qc.py From basenji with Apache License 2.0 | 5 votes |
def parse_proto(example_protos):
    """Parse serialized examples into sequence and targets tensors.

    The 'genome' feature is part of the parsing spec but is not returned.

    Returns:
        Dict with keys 'sequence' (bytes decoded as uint8) and 'targets'
        (bytes decoded as float16).
    """
    spec = {
        'genome': tf.FixedLenFeature([1], tf.int64),
        'sequence': tf.FixedLenFeature([], tf.string),
        'target': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_example(example_protos, features=spec)

    sequence = tf.decode_raw(parsed['sequence'], tf.uint8)
    target = tf.decode_raw(parsed['target'], tf.float16)

    return {'sequence': sequence, 'targets': target}


################################################################################
# __main__
################################################################################
Example #19
Source File: tfr_bw.py From basenji with Apache License 2.0 | 5 votes |
def parse_proto(example_protos):
    """Parse serialized examples into genome, sequence and target tensors.

    Returns:
        Dict with keys 'genome' (parsed int64 feature), 'sequence' (bytes
        decoded as uint8) and 'target' (bytes decoded as float16).
    """
    spec = {
        'genome': tf.FixedLenFeature([1], tf.int64),
        'sequence': tf.FixedLenFeature([], tf.string),
        'target': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_example(example_protos, features=spec)

    genome_id = parsed['genome']
    sequence = tf.decode_raw(parsed['sequence'], tf.uint8)
    target = tf.decode_raw(parsed['target'], tf.float16)

    return {'genome': genome_id, 'sequence': sequence, 'target': target}


################################################################################
# __main__
################################################################################
Example #20
Source File: test_augment.py From basenji with Apache License 2.0 | 5 votes |
def make_data_op(tfr_pattern, seq_length, target_length):
    """Build a one-shot iterator op over the TFRecords matching a pattern.

    Files matching ``tfr_pattern`` are read as ZLIB-compressed TFRecords,
    batched one example at a time and parsed into float32 tensors.

    Args:
        tfr_pattern: file pattern handed to ``tf.data.Dataset.list_files``.
        seq_length: length used to reshape the input to ``[1, seq_length, -1]``.
        target_length: length used to reshape the targets to
            ``(1, target_length, -1)``.

    Returns:
        The ``get_next()`` op of the iterator; it yields a dict with keys
        'sequence', 'label' and 'na'.
    """
    dataset = tf.data.Dataset.list_files(tfr_pattern)

    def file_to_records(filename):
        """Open one file as a ZLIB-compressed TFRecord dataset."""
        return tf.data.TFRecordDataset(filename, compression_type='ZLIB')

    dataset = dataset.flat_map(file_to_records)

    def parse_proto(example_protos):
        """Decode one batch of protos into sequence, label and NA mask."""
        features = {
            tfrecord_util.TFR_INPUT: tf.FixedLenFeature([], tf.string),
            tfrecord_util.TFR_OUTPUT: tf.FixedLenFeature([], tf.string)
        }
        parsed_features = tf.parse_example(example_protos, features=features)

        seq = tf.decode_raw(parsed_features[tfrecord_util.TFR_INPUT], tf.uint8)
        seq = tf.reshape(seq, [1, seq_length, -1])
        seq = tf.cast(seq, tf.float32)

        targets = tf.decode_raw(parsed_features[tfrecord_util.TFR_OUTPUT],
                                tf.float16)
        targets = tf.reshape(targets, (1, target_length, -1))
        targets = tf.cast(targets, tf.float32)

        # All-False mask shaped like the targets without their last axis.
        na = tf.zeros(targets.shape[:-1], dtype=tf.bool)

        return {'sequence': seq, 'label': targets, 'na': na}

    dataset = dataset.batch(1)
    dataset = dataset.map(parse_proto)

    iterator = dataset.make_one_shot_iterator()
    try:
        next_op = iterator.get_next()
    except tf.errors.OutOfRangeError:
        # This is a plain function, so there is no ``self``; report the
        # pattern argument (the old ``self.tfr_pattern`` raised NameError).
        # NOTE(review): in graph mode OutOfRangeError presumably surfaces when
        # the op is run, not here -- confirm this handler is reachable.
        print('TFRecord pattern %s is empty' % tfr_pattern, file=sys.stderr)
        sys.exit(1)

    return next_op


################################################################################
# __main__
################################################################################
Example #21
Source File: dataset.py From rcan-tensorflow with MIT License | 5 votes |
def parse_tfr_tf(record):
    """Parse one record and reshape its raw bytes to the stored shape.

    Args:
        record: value handed to ``tf.parse_single_example``.

    Returns:
        The 'data' bytes decoded as uint8 and reshaped to the 3-element
        'shape' feature stored alongside them.
    """
    spec = {
        'shape': tf.FixedLenFeature([3], tf.int64),
        'data': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(record, features=spec)
    flat = tf.decode_raw(parsed['data'], tf.uint8)
    return tf.reshape(flat, parsed['shape'])
Example #22
Source File: reader.py From CapsLayer with Apache License 2.0 | 5 votes |
def parse_fun(serialized_example):
    """ Parse one serialized example into a dict of image and label tensors.

    The image bytes are decoded as float32, flattened to ``32 * 32 * 3``
    elements and scaled by ``1 / 255``. The label is cast to int32.

    Returns a dict with keys 'images' and 'labels'.
    """
    spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=spec)

    pixels = tf.decode_raw(parsed['image'], tf.float32)
    pixels = tf.reshape(pixels, shape=[32 * 32 * 3])
    pixels.set_shape([32 * 32 * 3])
    pixels = tf.cast(pixels, tf.float32) * (1. / 255)

    target = tf.cast(parsed['label'], tf.int32)
    return {'images': pixels, 'labels': target}
Example #23
Source File: example_decoders_test.py From spotify-tensorflow with Apache License 2.0 | 5 votes |
def test_example_with_feature_spec_decoder(self):
    """Decoding ``self.example_str`` with a feature spec yields the known JSON."""
    sparse = tf.SparseFeature("sparse_feature_idx", "sparse_feature_value",
                              tf.float32, 10)
    spec = {
        "scalar_feature_1": tf.FixedLenFeature(shape=[], dtype=tf.int64),
        "scalar_feature_2": tf.FixedLenFeature(shape=[], dtype=tf.int64),
        "scalar_feature_3": tf.FixedLenFeature(shape=[], dtype=tf.float32),
        "varlen_feature_1": tf.VarLenFeature(dtype=tf.float32),
        "varlen_feature_2": tf.VarLenFeature(dtype=tf.string),
        "1d_vector_feature": tf.FixedLenFeature(shape=[1], dtype=tf.string),
        "2d_vector_feature": tf.FixedLenFeature(shape=[2, 2],
                                                dtype=tf.float32),
        "sparse_feature": sparse,
    }

    decoder = ExampleWithFeatureSpecDecoder(spec)
    decoded = json.loads(decoder.to_json(self.example_str))

    # The sparse feature shows up as its separate index and value lists.
    wanted = {
        "scalar_feature_1": 12,
        "scalar_feature_2": 12,
        "scalar_feature_3": 1.0,
        "varlen_feature_1": [89.0],
        "1d_vector_feature": ["this is a ,text"],
        "2d_vector_feature": [[1.0, 2.0], [3.0, 4.0]],
        "varlen_feature_2": ["female"],
        "sparse_feature_idx": [1, 4],
        "sparse_feature_value": [12.0, 20.0],
    }
    self.assertEqual(decoded, wanted)
Example #24
Source File: dataset_test.py From spotify-tensorflow with Apache License 2.0 | 5 votes |
def _write_test_data():
    """Write one two-field int64 example plus its schema as test data.

    Returns:
        Whatever ``DataUtil.write_test_data`` returns for the built examples
        and schema.
    """
    spec = {
        "f1": tf.FixedLenFeature((), tf.int64),
        "f2": tf.FixedLenFeature((), tf.int64),
    }
    schema = feature_spec_to_schema(spec)

    rows = [{"f1": 1, "f2": 2}]
    example_proto = []
    for row in rows:
        # Each value becomes a single-element int64 list feature.
        encoded = {}
        for key, number in row.items():
            int_list = feature_pb2.Int64List(value=[number])
            encoded[key] = feature_pb2.Feature(int64_list=int_list)
        example_proto.append(
            example_pb2.Example(features=feature_pb2.Features(feature=encoded)))

    return DataUtil.write_test_data(example_proto, schema)
Example #25
Source File: tf_io_pipline_tools.py From lanenet-lane-detection with Apache License 2.0 | 5 votes |
def decode(serialized_example):
    """
    Parses the ground-truth image and its binary and instance label images
    from the given `serialized_example`
    :param serialized_example: value handed to tf.parse_single_example
    :return: tuple (gt_image, gt_binary_image, gt_instance_image); each is
        decoded as uint8 and reshaped to RESIZE_IMAGE_HEIGHT x
        RESIZE_IMAGE_WIDTH with 3, 1 and 1 channels respectively
    """
    # No defaults are given, so all three keys must be present.
    spec = {
        'gt_image_raw': tf.FixedLenFeature([], tf.string),
        'gt_binary_image_raw': tf.FixedLenFeature([], tf.string),
        'gt_instance_image_raw': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized_example, features=spec)

    def _decode_plane(feature_key, channels):
        """Decode one raw byte feature into an image with `channels` planes."""
        plane_shape = tf.stack(
            [RESIZE_IMAGE_HEIGHT, RESIZE_IMAGE_WIDTH, channels])
        flat = tf.decode_raw(parsed[feature_key], tf.uint8)
        return tf.reshape(flat, plane_shape)

    gt_image = _decode_plane('gt_image_raw', 3)
    gt_binary_image = _decode_plane('gt_binary_image_raw', 1)
    gt_instance_image = _decode_plane('gt_instance_image_raw', 1)

    return gt_image, gt_binary_image, gt_instance_image
Example #26
Source File: dataset.py From spotify-tensorflow with Apache License 2.0 | 5 votes |
def parse_schema(cls, schema_path):
    # type: (str) -> Tuple[Dict[str, Union[tf.FixedLenFeature, tf.VarLenFeature, tf.SparseFeature]], Schema]  # noqa: E501
    """
    Load a tf.metadata Schema file and derive its TensorFlow Feature Spec.

    :param schema_path: tf.metadata Schema path
    :return: tuple of (feature spec, parsed Schema)
    """
    parsed_schema = parse_schema_file(schema_path)
    derived_spec = schema_to_feature_spec(parsed_schema)
    return derived_spec, parsed_schema
Example #27
Source File: cifar10_tf.py From deep_architect with MIT License | 5 votes |
def parser(self, serialized_example):
    """Parses a single tf.Example into image and label tensors.

    Returns:
        Tuple ``(image, label)``: the image as float32 after moving the depth
        axis last, and the label cast to int32.
    """
    # Image dimensions come from the module-level DEPTH / HEIGHT / WIDTH.
    # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
    # input format.
    spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=spec)

    flat = tf.decode_raw(parsed['image'], tf.uint8)
    flat.set_shape([DEPTH * HEIGHT * WIDTH])

    # [depth * height * width] -> [depth, height, width] -> depth axis last.
    depth_first = tf.reshape(flat, [DEPTH, HEIGHT, WIDTH])
    depth_last = tf.transpose(depth_first, [1, 2, 0])
    image = tf.cast(depth_last, tf.float32)

    label = tf.cast(parsed['label'], tf.int32)
    return image, label
Example #28
Source File: dataset.py From spotify-tensorflow with Apache License 2.0 | 5 votes |
def parse_schema_from_stats(cls, stats_path):
    # type: (str) -> Tuple[Dict[str, Union[tf.FixedLenFeature, tf.VarLenFeature, tf.SparseFeature]], Schema]  # noqa: E501
    """
    Infer a tf.metadata Schema from a tf.metadata DatasetFeatureStatisticsList
    and derive its TensorFlow Feature Spec.

    :param stats_path: tf.metadata DatasetFeatureStatisticsList path
    :return: tuple of (feature spec, inferred Schema)
    """
    # Imported here rather than at module level, as in the original.
    import tensorflow_data_validation as tfdv

    inferred_schema = tfdv.infer_schema(tfdv.load_statistics(stats_path))
    return schema_to_feature_spec(inferred_schema), inferred_schema
Example #29
Source File: dataset.py From spotify-tensorflow with Apache License 2.0 | 5 votes |
def _examples(cls,
              file_pattern,  # type: str
              schema_path=None,  # type: str
              feature_spec=None,  # type: Dict[str, Union[tf.FixedLenFeature, tf.VarLenFeature, tf.SparseFeature]]  # noqa: E501
              default_value=0,  # type: float
              compression_type=None,  # type: str
              batch_size=128,  # type: int
              shuffle=True,  # type: bool
              num_epochs=1,  # type: int
              shuffle_buffer_size=10000,  # type: int
              shuffle_seed=None,  # type: int
              prefetch_buffer_size=1,  # type: int
              reader_num_threads=1,  # type: int
              parser_num_threads=2,  # type: int
              sloppy_ordering=False,  # type: bool
              drop_final_batch=False  # type: bool
              ):
    # type: (...) -> Iterator[pd.DataFrame]
    """Yield one pandas DataFrame per batch of parsed examples.

    Every argument is forwarded unchanged to ``Datasets.dict._examples``;
    each item that call yields is wrapped in ``pd.DataFrame(data=...)``.
    ``Datasets._assert_eager("DataFrame")`` runs first, when iteration starts.
    """
    Datasets._assert_eager("DataFrame")

    forwarded = dict(
        file_pattern=file_pattern,
        schema_path=schema_path,
        default_value=default_value,
        feature_spec=feature_spec,
        compression_type=compression_type,
        batch_size=batch_size,
        shuffle=shuffle,
        num_epochs=num_epochs,
        shuffle_buffer_size=shuffle_buffer_size,
        shuffle_seed=shuffle_seed,
        prefetch_buffer_size=prefetch_buffer_size,
        reader_num_threads=reader_num_threads,
        parser_num_threads=parser_num_threads,
        sloppy_ordering=sloppy_ordering,
        drop_final_batch=drop_final_batch,
    )
    batches = Datasets.dict._examples(**forwarded)
    for batch in batches:
        yield pd.DataFrame(data=batch)
Example #30
Source File: dataset.py From spotify-tensorflow with Apache License 2.0 | 5 votes |
def _examples(cls,
              file_pattern,  # type: str
              schema_path=None,  # type: str
              feature_spec=None,  # type: Dict[str, Union[tf.FixedLenFeature, tf.VarLenFeature, tf.SparseFeature]]  # noqa: E501
              compression_type=None,  # type: str
              batch_size=128,  # type: int
              shuffle=True,  # type: bool
              num_epochs=1,  # type: int
              shuffle_buffer_size=10000,  # type: int
              shuffle_seed=None,  # type: int
              prefetch_buffer_size=1,  # type: int
              reader_num_threads=1,  # type: int
              parser_num_threads=2,  # type: int
              sloppy_ordering=False,  # type: bool
              drop_final_batch=False  # type: bool
              ):
    # type: (...) -> tf.data.Dataset
    """Build a batched features dataset for the files matching a pattern.

    When ``schema_path`` is truthy the feature spec is taken from
    ``cls.parse_schema(schema_path)`` and replaces any ``feature_spec``
    argument. The remaining arguments are forwarded to
    ``make_batched_features_dataset``.
    """
    if schema_path:
        feature_spec, _ = cls.parse_schema(schema_path)

    logger.debug("Will parse features from: `%s`, using features spec: %s",
                 file_pattern,
                 str(feature_spec))

    # Imported here rather than at module level, as in the original.
    from tensorflow.contrib.data import make_batched_features_dataset

    # The reader only receives arguments when a compression type is set.
    if compression_type:
        reader_args = [compression_type]
    else:
        reader_args = None

    options = dict(
        batch_size=batch_size,
        features=feature_spec,
        reader_args=reader_args,
        num_epochs=num_epochs,
        shuffle=shuffle,
        shuffle_buffer_size=shuffle_buffer_size,
        shuffle_seed=shuffle_seed,
        prefetch_buffer_size=prefetch_buffer_size,
        reader_num_threads=reader_num_threads,
        parser_num_threads=parser_num_threads,
        sloppy_ordering=sloppy_ordering,
        drop_final_batch=drop_final_batch,
    )
    return make_batched_features_dataset(file_pattern, **options)