Python tensorflow.VarLenFeature() Examples
The following are 30 code examples of tensorflow.VarLenFeature().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
Example #1
Source File: readers.py From youtube-8m with Apache License 2.0 | 6 votes |
def prepare_reader(self, filename_queue, batch_size=1024):
    """Build the ops that read and parse a batch of serialized examples.

    Args:
        filename_queue: queue passed to ``TFRecordReader.read_up_to``.
        batch_size: maximum number of records requested per read.

    Returns:
        A 4-tuple ``(video_ids, features, labels, ones)``: the ``"video_id"``
        feature, the tensors named in ``self.feature_names`` concatenated
        along axis 1, the sparse ``"labels"`` feature converted to an
        indicator over ``self.num_classes`` classes, and a 1-D tensor of ones
        with one entry per serialized example.
    """
    record_reader = tf.TFRecordReader()
    _, serialized_examples = record_reader.read_up_to(filename_queue, batch_size)

    # Every dense feature needs both a name and a size.
    num_features = len(self.feature_names)
    assert num_features > 0, "self.feature_names is empty!"
    assert len(self.feature_names) == len(self.feature_sizes), (
        "length of feature_names (={}) != length of feature_sizes (={})".format(
            len(self.feature_names), len(self.feature_sizes)))

    # Map each proto field to the type it should be parsed as.
    feature_map = {
        "video_id": tf.FixedLenFeature([], tf.string),
        "labels": tf.VarLenFeature(tf.int64),
    }
    for position, feature_name in enumerate(self.feature_names):
        feature_map[feature_name] = tf.FixedLenFeature(
            [self.feature_sizes[position]], tf.float32)

    parsed = tf.parse_example(serialized_examples, features=feature_map)

    labels = tf.sparse_to_indicator(parsed["labels"], self.num_classes)
    labels.set_shape([None, self.num_classes])

    dense_columns = [parsed[feature_name] for feature_name in self.feature_names]
    concatenated_features = tf.concat(dense_columns, 1)

    example_count = tf.shape(serialized_examples)[0]
    return parsed["video_id"], concatenated_features, labels, tf.ones([example_count])
Example #2
Source File: dataset_test.py From spotify-tensorflow with Apache License 2.0 | 6 votes |
def _write_test_data():
    """Build four small int64 examples and pass them to ``DataUtil.write_test_data``.

    Each row holds values for features ``f0``, ``f1`` and ``f2``; a ``None``
    entry means that feature is omitted from the example.

    Returns:
        Whatever ``DataUtil.write_test_data`` returns for the examples and schema.
    """
    schema = feature_spec_to_schema({
        name: tf.VarLenFeature(dtype=tf.int64) for name in ("f0", "f1", "f2")
    })

    rows = [
        [1, 4, None],
        [2, None, None],
        [3, 5, None],
        [None, None, None],
    ]

    def _row_to_example(row):
        # Only positions holding a value become features of the example.
        feature = {}
        for index, value in enumerate(row):
            if value is None:
                continue
            feature["f" + str(index)] = feature_pb2.Feature(
                int64_list=feature_pb2.Int64List(value=[value]))
        return example_pb2.Example(features=feature_pb2.Features(feature=feature))

    examples = [_row_to_example(row) for row in rows]
    return DataUtil.write_test_data(examples, schema)
Example #3
Source File: movielens.py From cloudml-samples with Apache License 2.0 | 6 votes |
def _make_schema(columns, types, default_values):
    """Input schema definition.

    Args:
        columns: column names for fields appearing in input.
        types: column types for fields appearing in input; a list-valued
            entry marks a variable-length column whose dtype is its first
            element.
        default_values: default values for fields appearing in input.

    Returns:
        The result of ``dataset_schema.from_feature_spec`` applied to a
        dictionary mapping each column name to its ``*Feature`` spec.
    """
    assert len(columns) == len(types)
    assert len(columns) == len(default_values)

    def _feature_for(column_type, default):
        # List-typed columns are variable length; all others are scalars
        # with a default value.
        if isinstance(column_type, list):
            return tf.VarLenFeature(dtype=column_type[0])
        return tf.FixedLenFeature(shape=[], dtype=column_type, default_value=default)

    feature_spec = {
        column: _feature_for(column_type, default)
        for column, column_type, default in zip(columns, types, default_values)
    }
    return dataset_schema.from_feature_spec(feature_spec)
Example #4
Source File: word2vec.py From tensorflow_nlp with Apache License 2.0 | 6 votes |
def load_tfrecord(self):
    """Build the input ops that read batched sentences from the TFRecord file.

    Reads ``<train_dir>/output.tfrecord`` for ``epochs_to_train`` epochs,
    parses the variable-length int64 ``'sentence'`` feature of each record
    and batches ``BATCH_LINES`` records at a time.

    Returns:
        The ``values`` of the batched sparse ``'sentence'`` feature.
    """
    options = self._options
    record_files = glob(options.train_dir + '/output.tfrecord')
    filename_queue = tf.train.string_input_producer(
        record_files, num_epochs=options.epochs_to_train)

    _, serialized_record = tf.TFRecordReader().read(filename_queue)
    parsed_line = tf.parse_single_example(
        serialized_record,
        features={'sentence': tf.VarLenFeature(tf.int64)})

    batched = tf.train.batch(
        parsed_line,
        batch_size=BATCH_LINES,
        capacity=PRELOAD_LINES,
        num_threads=options.io_threads)
    return batched['sentence'].values
Example #5
Source File: read_tfrecord.py From 2019-CCF-BDCI-OCR-MCZJ-OCR-IdentificationIDElement with MIT License | 6 votes |
def _extract_features_batch(self, serialized_batch):
    """Parse a batch of serialized examples into images, labels and paths.

    Args:
        serialized_batch: batch of serialized examples for ``tf.parse_example``.

    Returns:
        A tuple ``(images, labels, imagepaths)``. ``images`` is the raw
        ``'images'`` bytes decoded as uint8, cast to float32 and reshaped to
        ``[batch, height, -1, CFG.ARCH.INPUT_CHANNELS]``; ``labels`` is the
        sparse ``'labels'`` feature cast to int32.
    """
    feature_spec = {
        'images': tf.FixedLenFeature([], tf.string),
        'imagepaths': tf.FixedLenFeature([], tf.string),
        'labels': tf.VarLenFeature(tf.int64),
    }
    parsed = tf.parse_example(serialized_batch, features=feature_spec)

    batch_dim = parsed['images'].shape[0]
    raw_pixels = tf.decode_raw(parsed['images'], tf.uint8)
    # Only the height is used below; the width is inferred by the reshape (-1).
    _unused_width, height = tuple(CFG.ARCH.INPUT_SIZE)
    # NOTE: pixel values are left unnormalised (the original had a disabled
    # "divide by 128, subtract 1" step here).
    float_pixels = tf.cast(x=raw_pixels, dtype=tf.float32)
    images = tf.reshape(
        float_pixels, [batch_dim, height, -1, CFG.ARCH.INPUT_CHANNELS])

    labels = tf.cast(parsed['labels'], tf.int32)
    return images, labels, parsed['imagepaths']
Example #6
Source File: readtf.py From udacity-driving-reader with Apache License 2.0 | 6 votes |
def example_parser(example_serialized):
    """Parse one serialized example into image bytes, timestamp and steering data.

    Args:
        example_serialized: a single serialized example.

    Returns:
        A tuple ``(encoded_image, image_timestamp, steering_angles,
        steering_timestamps)``.
    """
    # NOTE: the 'gps/lat', 'gps/long' and 'gps/timestamp' features are
    # currently disabled and therefore not part of the spec.
    feature_spec = {
        'image/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
        'image/timestamp': tf.FixedLenFeature([], dtype=tf.int64, default_value=-1),
        'steer/angle': tf.FixedLenFeature(
            [2], dtype=tf.float32, default_value=[0.0, 0.0]),
        'steer/timestamp': tf.FixedLenFeature(
            [2], dtype=tf.int64, default_value=[-1, -1]),
    }
    parsed = tf.parse_single_example(example_serialized, feature_spec)

    image_timestamp = tf.cast(parsed['image/timestamp'], dtype=tf.int64)
    return (
        parsed['image/encoded'],
        image_timestamp,
        parsed['steer/angle'],
        parsed['steer/timestamp'],
    )
Example #7
Source File: readers.py From Youtube-8M-WILLOW with Apache License 2.0 | 6 votes |
def prepare_serialized_examples(self, serialized_examples):
    """Parse a batch of serialized examples into ids, features and labels.

    Args:
        serialized_examples: batch of serialized examples for
            ``tf.parse_example``.

    Returns:
        A 4-tuple ``(video_ids, features, labels, ones)``: the ``"video_id"``
        feature, the tensors named in ``self.feature_names`` concatenated
        along axis 1, the sparse ``"labels"`` feature converted to an
        indicator over ``self.num_classes`` classes, and a 1-D tensor of ones
        with one entry per serialized example.
    """
    # Every dense feature needs both a name and a size.
    num_features = len(self.feature_names)
    assert num_features > 0, "self.feature_names is empty!"
    assert len(self.feature_names) == len(self.feature_sizes), (
        "length of feature_names (={}) != length of feature_sizes (={})".format(
            len(self.feature_names), len(self.feature_sizes)))

    # Map each proto field to the type it should be parsed as.
    feature_map = {
        "video_id": tf.FixedLenFeature([], tf.string),
        "labels": tf.VarLenFeature(tf.int64),
    }
    for position, feature_name in enumerate(self.feature_names):
        feature_map[feature_name] = tf.FixedLenFeature(
            [self.feature_sizes[position]], tf.float32)

    parsed = tf.parse_example(serialized_examples, features=feature_map)

    labels = tf.sparse_to_indicator(parsed["labels"], self.num_classes)
    labels.set_shape([None, self.num_classes])

    dense_columns = [parsed[feature_name] for feature_name in self.feature_names]
    concatenated_features = tf.concat(dense_columns, 1)

    example_count = tf.shape(serialized_examples)[0]
    return parsed["video_id"], concatenated_features, labels, tf.ones([example_count])
Example #8
Source File: ilsvrc12_dataset.py From tf-hrnet with BSD 3-Clause "New" or "Revised" License | 5 votes |
def parse_example_proto(example_serialized):
    """Parse image buffer, label, and bounding box from the serialized data.

    Args:
    * example_serialized: serialized example data

    Returns:
    * image_buffer: the 'image/encoded' feature
    * label: label tensor (not one-hot), cast to int32
    * bbox: bounding box tensor with coordinates ordered (ymin, xmin, ymax, xmax)
    """
    bbox_prefix = 'image/object/bbox/'
    coordinate_names = ['xmin', 'ymin', 'xmax', 'ymax']

    # Fixed-length image/label fields plus one variable-length float feature
    # per bounding-box coordinate.
    feature_map = {
        'image/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
        'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64, default_value=-1),
        'image/class/text': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
    }
    for coordinate in coordinate_names:
        feature_map[bbox_prefix + coordinate] = tf.VarLenFeature(dtype=tf.float32)

    parsed = tf.parse_single_example(example_serialized, feature_map)

    label = tf.cast(parsed['image/class/label'], dtype=tf.int32)

    # Give every coordinate a leading axis so the four can be stacked by concat.
    rows = {
        coordinate: tf.expand_dims(parsed[bbox_prefix + coordinate].values, 0)
        for coordinate in coordinate_names
    }

    # The boxes are emitted in (y, x) order rather than (x, y).
    stacked = tf.concat(
        axis=0,
        values=[rows['ymin'], rows['xmin'], rows['ymax'], rows['xmax']])
    bbox = tf.transpose(tf.expand_dims(stacked, 0), [0, 2, 1])

    return parsed['image/encoded'], label, bbox
Example #9
Source File: generator_utils.py From BERT with Apache License 2.0 | 5 votes |
def tfrecord_iterator(filenames, gzipped=False, example_spec=None):
    """Yields records from TFRecord files.

    Args:
        filenames: list of str, TFRecord filenames to read from.
        gzipped: bool, whether the TFRecord files are gzip-encoded.
        example_spec: dict mapping feature name to
            tf.VarLenFeature/tf.FixedLenFeature; if provided, each record is
            parsed as a tensorflow.Example proto.

    Yields:
        Records (or parsed Examples, if example_spec is provided) from files.
    """
    with tf.Graph().as_default():

        def _open_file(path):
            return tf.data.TFRecordDataset(
                path,
                compression_type=tf.constant("GZIP") if gzipped else None,
                buffer_size=16 * 1000 * 1000)

        def _decode(serialized):
            return tf.parse_single_example(serialized, example_spec)

        records = tf.data.Dataset.from_tensor_slices(filenames)
        records = records.flat_map(_open_file)
        if example_spec:
            records = records.map(_decode, num_parallel_calls=32)
        records = records.prefetch(100)
        next_record = records.make_one_shot_iterator().get_next()

        with tf.Session() as session:
            while True:
                try:
                    record = session.run(next_record)
                    yield record
                except tf.errors.OutOfRangeError:
                    # The dataset is exhausted; stop iterating.
                    break
Example #10
Source File: imagenet_main.py From yolo_v2 with Apache License 2.0 | 5 votes |
def record_parser(value, is_training):
    """Parse an ImageNet record from `value`.

    Args:
        value: a single serialized example.
        is_training: forwarded to ``vgg_preprocessing.preprocess_image``.

    Returns:
        A tuple ``(image, one_hot_label)`` where the label is one-hot encoded
        over ``_LABEL_CLASSES`` classes.
    """
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/class/label': tf.FixedLenFeature(
            [], dtype=tf.int64, default_value=-1),
        'image/class/text': tf.FixedLenFeature(
            [], dtype=tf.string, default_value=''),
        'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/class/label': tf.VarLenFeature(dtype=tf.int64),
    }
    parsed = tf.parse_single_example(value, keys_to_features)

    encoded_image = tf.reshape(parsed['image/encoded'], shape=[])
    decoded = tf.image.decode_image(encoded_image, _NUM_CHANNELS)
    as_float = tf.image.convert_image_dtype(decoded, dtype=tf.float32)
    image = vgg_preprocessing.preprocess_image(
        image=as_float,
        output_height=_DEFAULT_IMAGE_SIZE,
        output_width=_DEFAULT_IMAGE_SIZE,
        is_training=is_training)

    scalar_label = tf.reshape(parsed['image/class/label'], shape=[])
    label = tf.cast(scalar_label, dtype=tf.int32)

    return image, tf.one_hot(label, _LABEL_CLASSES)
Example #11
Source File: TrainLSP.py From deeppose with GNU General Public License v3.0 | 5 votes |
def read_and_decode(filename_queue):
    """Read one record from the queue and decode it into a label and an image.

    Args:
        filename_queue: queue passed to ``TFRecordReader.read``.

    Returns:
        A tuple ``(label, image_float)``: the ``'label'`` feature cast to
        int32, and the image reshaped to ``[input_size, input_size,
        input_depth]`` with values rescaled from [0, 255] to [-0.5, 0.5].
    """
    _, serialized_example = tf.TFRecordReader().read(filename_queue)

    # Both fields have a known length, so fixed-length features are enough;
    # tf.VarLenFeature would be the alternative otherwise.
    feature_spec = {
        'label': tf.FixedLenFeature([LSPGlobals.TotalLabels], tf.int64),
        'image_raw': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    # Turn the raw bytes back into a flat uint8 vector of known length.
    flat_pixels = tf.decode_raw(parsed['image_raw'], tf.uint8)
    flat_pixels.set_shape([LSPGlobals.TotalImageBytes])
    image = tf.reshape(
        flat_pixels, [FLAGS.input_size, FLAGS.input_size, FLAGS.input_depth])

    # Convert from [0, 255] -> [-0.5, 0.5] floats.
    image_float = tf.cast(image, tf.float32) * (1. / 255) - 0.5

    # The labels are parsed as int64 and handed back as int32.
    label = tf.cast(parsed['label'], tf.int32)

    return label, image_float
Example #12
Source File: text_problems.py From BERT with Apache License 2.0 | 5 votes |
def example_reading_spec(self):
    """Extend the parent reading spec with a variable-length ``"context"`` field.

    Returns:
        A tuple ``(data_fields, data_items_to_decoders)`` taken from the
        parent class, with ``data_fields["context"]`` set to an int64
        ``tf.VarLenFeature``.
    """
    parent = super(QuestionAndContext2TextProblem, self)
    fields, decoders = parent.example_reading_spec()
    fields["context"] = tf.VarLenFeature(tf.int64)
    return (fields, decoders)
Example #13
Source File: input_ops.py From yolo_v2 with Apache License 2.0 | 5 votes |
def parse_example_batch(serialized):
    """Parses a batch of tf.Example protos.

    Args:
        serialized: A 1-D string Tensor; a batch of serialized tf.Example protos.

    Returns:
        encode: A SentenceBatch of encode sentences.
        decode_pre: A SentenceBatch of "previous" sentences to decode.
        decode_post: A SentenceBatch of "post" sentences to decode.
    """
    output_names = ("encode", "decode_pre", "decode_post")
    feature_spec = {
        name: tf.VarLenFeature(dtype=tf.int64) for name in output_names
    }
    features = tf.parse_example(serialized, features=feature_spec)

    def _sparse_to_batch(sparse):
        # Densifying pads the shorter sentences with zeroes.
        ids = tf.sparse_tensor_to_dense(sparse)
        # The mask is 1 wherever the sparse tensor holds a value.
        ones = tf.ones_like(sparse.values, dtype=tf.int32)
        mask = tf.sparse_to_dense(sparse.indices, sparse.dense_shape, ones)
        return SentenceBatch(ids=ids, mask=mask)

    batches = []
    for name in output_names:
        batches.append(_sparse_to_batch(features[name]))
    return tuple(batches)
Example #14
Source File: swivel.py From yolo_v2 with Apache License 2.0 | 5 votes |
def _count_matrix_input(self, filenames, submatrix_rows, submatrix_cols):
    """Creates ops that read submatrix shards from disk.

    Note that ``filenames`` is shuffled in place before being queued.

    Args:
        filenames: list of shard files; each file is read whole as one
            serialized example.
        submatrix_rows: number of rows in a submatrix.
        submatrix_cols: number of columns in a submatrix.

    Returns:
        A tuple ``(global_row, global_col, count)`` where ``count`` is the
        dense ``[submatrix_rows, submatrix_cols]`` tensor built from the
        sparse row/column/value features.
    """
    random.shuffle(filenames)
    shard_queue = tf.train.string_input_producer(filenames)
    _, serialized_example = tf.WholeFileReader().read(shard_queue)

    feature_spec = {
        'global_row': tf.FixedLenFeature([submatrix_rows], dtype=tf.int64),
        'global_col': tf.FixedLenFeature([submatrix_cols], dtype=tf.int64),
        'sparse_local_row': tf.VarLenFeature(dtype=tf.int64),
        'sparse_local_col': tf.VarLenFeature(dtype=tf.int64),
        'sparse_value': tf.VarLenFeature(dtype=tf.float32),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    local_rows = parsed['sparse_local_row'].values
    local_cols = parsed['sparse_local_col'].values
    values = parsed['sparse_value'].values

    # Pair every local row with its local column: one [row, col] per value.
    row_column = tf.expand_dims(local_rows, 1)
    col_column = tf.expand_dims(local_cols, 1)
    sparse_indices = tf.concat(axis=1, values=[row_column, col_column])

    count = tf.sparse_to_dense(
        sparse_indices, [submatrix_rows, submatrix_cols], values)

    return parsed['global_row'], parsed['global_col'], count
Example #15
Source File: data_reader_test.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def example_reading_spec(self):
    """Return the feature spec for ``inputs``, ``targets`` and ``floats``.

    Returns:
        A tuple ``(data_fields, data_items_to_decoders)``. All three fields
        are variable length (int64, int64 and float32 respectively) and the
        decoders entry is ``None``.
    """
    fields = {}
    fields["inputs"] = tf.VarLenFeature(tf.int64)
    fields["targets"] = tf.VarLenFeature(tf.int64)
    fields["floats"] = tf.VarLenFeature(tf.float32)
    return (fields, None)
Example #16
Source File: fsns.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def example_reading_spec(self):
    """Extend the parent reading spec so targets come from the unpadded label.

    Returns:
        A tuple ``(data_fields, data_items_to_decoders)`` from the parent
        class, with an int64 variable-length ``"image/unpadded_label"`` field
        added and the ``"targets"`` decoder reading from that field.
    """
    label_key = "image/unpadded_label"
    parent = super(ImageFSNS, self)
    fields, decoders = parent.example_reading_spec()

    fields[label_key] = tf.VarLenFeature(tf.int64)
    targets_decoder = tf.contrib.slim.tfexample_decoder.Tensor(label_key)
    decoders["targets"] = targets_decoder
    return fields, decoders
Example #17
Source File: data.py From dket with GNU General Public License v3.0 | 5 votes |
def parse(serialized):
    """Parse a serialized string into tensors.

    Arguments:
        serialized: a single serialized example (like the one returned from
            the `encode()` method).

    Returns:
        a tuple of 4 tensors:
            `words`: 1D tensor of shape [sentence_length].
            `sentence_length`: 0D tensor (i.e. scalar) representing the
                sentence length.
            `formula`: 1D tensor of shape [formula_length].
            `formula_length`: a 0D tensor (i.e. scalar) representing the
                formula length.
    """
    scalar_int = tf.FixedLenFeature([], tf.int64)
    feature_spec = {
        SENTENCE_LENGTH_KEY: scalar_int,
        FORMULA_LENGTH_KEY: tf.FixedLenFeature([], tf.int64),
        WORDS_KEY: tf.VarLenFeature(tf.int64),
        FORMULA_KEY: tf.VarLenFeature(tf.int64),
    }
    parsed = tf.parse_single_example(
        serialized=serialized, features=feature_spec)

    # The two sequences are parsed as sparse tensors and densified here.
    words = tf.sparse_tensor_to_dense(parsed[WORDS_KEY])
    formula = tf.sparse_tensor_to_dense(parsed[FORMULA_KEY])

    return (
        words,
        parsed[SENTENCE_LENGTH_KEY],
        formula,
        parsed[FORMULA_LENGTH_KEY],
    )
Example #18
Source File: imagenet_input.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def dataset_parser(self, value):
    """Parses an image and its label from a serialized ResNet-50 TFExample.

    Args:
        value: serialized string containing an ImageNet TFExample.

    Returns:
        Returns a tuple of (image, label) from the TFExample.
    """
    bbox_spec = {
        'image/object/bbox/' + coordinate: tf.VarLenFeature(dtype=tf.float32)
        for coordinate in ('xmin', 'ymin', 'xmax', 'ymax')
    }
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, ''),
        'image/format': tf.FixedLenFeature((), tf.string, 'jpeg'),
        'image/class/label': tf.FixedLenFeature([], tf.int64, -1),
        'image/class/text': tf.FixedLenFeature([], tf.string, ''),
    }
    keys_to_features.update(bbox_spec)
    keys_to_features['image/object/class/label'] = tf.VarLenFeature(
        dtype=tf.int64)

    parsed = tf.parse_single_example(value, keys_to_features)

    image = self.image_preprocessing_fn(
        image_bytes=tf.reshape(parsed['image/encoded'], shape=[]),
        is_training=self.is_training,
        image_size=self.image_size,
        use_bfloat16=self.use_bfloat16)

    # Subtract one so that labels are in [0, 1000).
    raw_label = tf.reshape(parsed['image/class/label'], shape=[])
    label = tf.cast(raw_label, dtype=tf.int32) - 1

    return image, label
Example #19
Source File: gene_expression.py From BERT with Apache License 2.0 | 5 votes |
def example_reading_spec(self):
    """Return the feature spec for int64 ``inputs`` and float32 ``targets``.

    Returns:
        A tuple ``(data_fields, data_items_to_decoders)``; both fields are
        variable length and the decoders entry is ``None``.
    """
    fields = {}
    fields["inputs"] = tf.VarLenFeature(tf.int64)
    fields["targets"] = tf.VarLenFeature(tf.float32)
    return (fields, None)
Example #20
Source File: text_problems.py From BERT with Apache License 2.0 | 5 votes |
def example_reading_spec(self):
    """Return the feature spec for variable-length inputs and a single target.

    Returns:
        A tuple ``(data_fields, data_items_to_decoders)``. ``"inputs"`` is a
        variable-length int64 feature, ``"targets"`` is a fixed-length int64
        feature of shape ``[1]``, and the decoders entry is ``None``.
    """
    fields = {}
    fields["inputs"] = tf.VarLenFeature(tf.int64)
    fields["targets"] = tf.FixedLenFeature([1], tf.int64)
    return (fields, None)
Example #21
Source File: hico.py From CVTron with Apache License 2.0 | 5 votes |
def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
    """Build the slim ``Dataset`` for one split of the HICO TFRecords.

    Args:
        split_name: name of the split; must be a key of ``_SPLITS_TO_SIZES``.
        dataset_dir: directory containing the TFRecord files.
        file_pattern: filename pattern with a ``%s`` slot for the split name;
            ``_FILE_PATTERN`` is used when this is falsy.
        reader: reader class to use; ``tf.TFRecordReader`` when ``None``.

    Returns:
        A ``slim.dataset.Dataset`` describing the requested split.

    Raises:
        ValueError: if ``split_name`` is not a recognised split.
    """
    if split_name not in _SPLITS_TO_SIZES:
        raise ValueError('split name %s was not recognized.' % split_name)

    pattern = file_pattern or _FILE_PATTERN
    data_sources = os.path.join(dataset_dir, pattern % split_name)

    # None is accepted in the signature so that dataset_factory can rely on
    # the default reader.
    record_reader = tf.TFRecordReader if reader is None else reader

    # Features in HICO TFRecords. The 'image/class/object' and
    # 'image/class/verb' features (and their handlers) are currently disabled.
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/class/label': tf.FixedLenFeature([], tf.string),
    }
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }
    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)

    labels_to_names = read_label_file(_LABELS_FILENAME)

    return slim.dataset.Dataset(
        data_sources=data_sources,
        reader=record_reader,
        decoder=decoder,
        num_samples=_SPLITS_TO_SIZES[split_name],
        items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
        num_classes=_NUM_CLASSES,
        labels_to_names=labels_to_names)
Example #22
Source File: speech_recognition.py From BERT with Apache License 2.0 | 5 votes |
def example_reading_spec(self):
    """Return the feature spec for float32 ``waveforms`` and int64 ``targets``.

    Returns:
        A tuple ``(data_fields, data_items_to_decoders)``; both fields are
        variable length and the decoders entry is ``None``.
    """
    fields = {}
    fields["waveforms"] = tf.VarLenFeature(tf.float32)
    fields["targets"] = tf.VarLenFeature(tf.int64)
    return fields, None
Example #23
Source File: problem.py From BERT with Apache License 2.0 | 5 votes |
def example_reading_spec(self):
    """Define how data is serialized to file and read back.

    Returns:
        data_fields: A dictionary mapping data names to its feature type;
            here ``"inputs"`` and ``"targets"``, both variable-length int64.
        data_items_to_decoders: A dictionary mapping data names to TF Example
            decoders, to be used when reading back TF examples from disk;
            ``None`` here.
    """
    fields = {
        name: tf.VarLenFeature(tf.int64) for name in ("inputs", "targets")
    }
    return (fields, None)
Example #24
Source File: image_utils.py From BERT with Apache License 2.0 | 5 votes |
def example_reading_spec(self):
    """Extend the parent reading spec so targets come from the class label.

    Returns:
        A tuple ``(data_fields, data_items_to_decoders)`` from the parent
        class, with an int64 variable-length ``"image/class/label"`` field
        added and the ``"targets"`` decoder reading from that field.
    """
    label_key = "image/class/label"
    parent = super(Image2TextProblem, self)
    fields, decoders = parent.example_reading_spec()

    fields[label_key] = tf.VarLenFeature(tf.int64)
    targets_decoder = tf.contrib.slim.tfexample_decoder.Tensor(label_key)
    decoders["targets"] = targets_decoder
    return fields, decoders
Example #25
Source File: timeseries.py From BERT with Apache License 2.0 | 5 votes |
def example_reading_spec(self):
    """Return the feature spec for float32 ``inputs`` and ``targets``.

    Returns:
        A tuple ``(data_fields, data_items_to_decoders)``; both fields are
        variable length and the decoders entry is ``None``.
    """
    fields = {
        name: tf.VarLenFeature(tf.float32) for name in ("inputs", "targets")
    }
    return (fields, None)
Example #26
Source File: problem_hparams.py From BERT with Apache License 2.0 | 5 votes |
def example_reading_spec(self):
    """Return the feature spec for inputs, targets and audio metadata.

    Returns:
        A tuple ``(data_fields, None)``. ``"inputs"`` and ``"targets"`` are
        variable-length int64 features; ``"audio/sample_count"`` and
        ``"audio/sample_width"`` are scalar int64 features.
    """
    fields = {}
    fields["inputs"] = tf.VarLenFeature(tf.int64)
    fields["audio/sample_count"] = tf.FixedLenFeature((), tf.int64)
    fields["audio/sample_width"] = tf.FixedLenFeature((), tf.int64)
    fields["targets"] = tf.VarLenFeature(tf.int64)
    return fields, None
Example #27
Source File: wikisum.py From BERT with Apache License 2.0 | 5 votes |
def example_reading_spec(self):
    """Return the feature spec for inputs, targets and section boundaries.

    Returns:
        A tuple ``(data_fields, data_items_to_decoders)``; all three fields
        are variable-length int64 features and the decoders entry is ``None``.
    """
    field_names = ("inputs", "targets", "section_boundaries")
    fields = {name: tf.VarLenFeature(tf.int64) for name in field_names}
    return (fields, None)
Example #28
Source File: translate.py From BERT with Apache License 2.0 | 5 votes |
def example_reading_spec(self):
    """Return a reading spec whose targets are read from ``dist_targets``.

    Returns:
        A tuple ``(data_fields, data_items_to_decoders)``. ``data_fields``
        always holds ``"dist_targets"`` and additionally ``"inputs"`` when
        ``self.has_inputs`` is truthy; both are variable-length int64.
    """
    fields = {"dist_targets": tf.VarLenFeature(tf.int64)}
    if self.has_inputs:
        fields["inputs"] = tf.VarLenFeature(tf.int64)

    # hack: ignoring true targets and putting dist_targets in targets
    tensor_decoder = tf.contrib.slim.tfexample_decoder.Tensor
    decoders = {
        "inputs": tensor_decoder("inputs"),
        "targets": tensor_decoder("dist_targets"),
    }
    return (fields, decoders)
Example #29
Source File: fsns.py From BERT with Apache License 2.0 | 5 votes |
def example_reading_spec(self):
    """Extend the parent reading spec so targets come from the unpadded label.

    Returns:
        A tuple ``(data_fields, data_items_to_decoders)`` from the parent
        class, with an int64 variable-length ``"image/unpadded_label"`` field
        added and the ``"targets"`` decoder reading from that field.
    """
    label_key = "image/unpadded_label"
    parent = super(ImageFSNS, self)
    fields, decoders = parent.example_reading_spec()

    fields[label_key] = tf.VarLenFeature(tf.int64)
    targets_decoder = tf.contrib.slim.tfexample_decoder.Tensor(label_key)
    decoders["targets"] = targets_decoder
    return fields, decoders
Example #30
Source File: input_ops.py From DOTA_models with Apache License 2.0 | 5 votes |
def parse_example_batch(serialized):
    """Parses a batch of tf.Example protos.

    Args:
        serialized: A 1-D string Tensor; a batch of serialized tf.Example protos.

    Returns:
        encode: A SentenceBatch of encode sentences.
        decode_pre: A SentenceBatch of "previous" sentences to decode.
        decode_post: A SentenceBatch of "post" sentences to decode.
    """
    feature_spec = {
        "encode": tf.VarLenFeature(dtype=tf.int64),
        "decode_pre": tf.VarLenFeature(dtype=tf.int64),
        "decode_post": tf.VarLenFeature(dtype=tf.int64),
    }
    parsed = tf.parse_example(serialized, features=feature_spec)

    def _to_sentence_batch(sparse):
        # Densifying pads the shorter sentences with zeroes.
        dense_ids = tf.sparse_tensor_to_dense(sparse)
        # The mask is 1 wherever the sparse tensor holds a value.
        presence = tf.sparse_to_dense(
            sparse.indices,
            sparse.dense_shape,
            tf.ones_like(sparse.values, dtype=tf.int32))
        return SentenceBatch(ids=dense_ids, mask=presence)

    encode = _to_sentence_batch(parsed["encode"])
    decode_pre = _to_sentence_batch(parsed["decode_pre"])
    decode_post = _to_sentence_batch(parsed["decode_post"])
    return (encode, decode_pre, decode_post)