Python tensorflow.Example() Examples
The following are 30 code examples of tensorflow.Example(). You can go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
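Before the examples, here is a minimal sketch of the full tf.train.Example round trip — building a proto from typed feature lists, serializing it, and parsing it back. It uses the TF 1.x API to match the examples below; the feature names and shapes are made up for illustration.

import tensorflow as tf

# Build: each value lives in an Int64List, FloatList, or BytesList.
features = {
    "ids": tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 2, 3])),
    "score": tf.train.Feature(float_list=tf.train.FloatList(value=[0.9])),
}
example = tf.train.Example(features=tf.train.Features(feature=features))
serialized = example.SerializeToString()

# Parse back with a matching feature spec.
spec = {
    "ids": tf.FixedLenFeature([3], tf.int64),
    "score": tf.FixedLenFeature([1], tf.float32),
}
parsed = tf.parse_single_example(serialized, spec)  # dict of Tensors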
Example #1
Source File: run_squad.py From BERT with Apache License 2.0 | 6 votes |
def process_feature(self, feature):
  """Write an InputFeature to the TFRecordWriter as a tf.train.Example."""
  self.num_features += 1

  def create_int_feature(values):
    feature = tf.train.Feature(
        int64_list=tf.train.Int64List(value=list(values)))
    return feature

  features = collections.OrderedDict()
  features["unique_ids"] = create_int_feature([feature.unique_id])
  features["input_ids"] = create_int_feature(feature.input_ids)
  features["input_mask"] = create_int_feature(feature.input_mask)
  features["segment_ids"] = create_int_feature(feature.segment_ids)

  if self.is_training:
    features["start_positions"] = create_int_feature([feature.start_position])
    features["end_positions"] = create_int_feature([feature.end_position])

  tf_example = tf.train.Example(features=tf.train.Features(feature=features))
  self._writer.write(tf_example.SerializeToString())
Example #2
Source File: run_bert_open_qa_eval.py From XQA with MIT License | 6 votes |
def process_feature(self, feature):
  """Write an InputFeature to the TFRecordWriter as a tf.train.Example."""
  self.num_features += 1

  def create_int_feature(values):
    feature = tf.train.Feature(
        int64_list=tf.train.Int64List(value=list(values)))
    return feature

  features = collections.OrderedDict()
  features["unique_ids"] = create_int_feature([feature.unique_id])
  features["input_ids"] = create_int_feature(feature.input_ids)
  features["input_mask"] = create_int_feature(feature.input_mask)
  features["segment_ids"] = create_int_feature(feature.segment_ids)

  tf_example = tf.train.Example(features=tf.train.Features(feature=features))
  self._writer.write(tf_example.SerializeToString())
Example #3
Source File: preprocessor.py From imitation-learning with MIT License | 6 votes |
def write_tfrecord_file(output_filepath, some_h5_files):
  """Write tf.Examples given a list of h5_files.

  Args:
    output_filepath: str
    some_h5_files: List[str]
  """
  tf_record_options = tf.python_io.TFRecordOptions(
      tf.python_io.TFRecordCompressionType.GZIP)
  writer = tf.python_io.TFRecordWriter(output_filepath, options=tf_record_options)

  # Read a batch of h5 files
  for f in some_h5_files:
    tf_examples = list(read_h5_file(f))  # type: List[tf.Example]

    # Serialize to string
    tf_example_strs = map(lambda ex: ex.SerializeToString(), tf_examples)

    # Write
    for example_str in tf_example_strs:
      writer.write(example_str)

  writer.close()
Example #4
Source File: generator_utils.py From BERT with Apache License 2.0 | 6 votes |
def to_example(dictionary):
  """Helper: build tf.Example from (string -> int/float/str list) dictionary."""
  features = {}
  for (k, v) in six.iteritems(dictionary):
    if not v:
      raise ValueError("Empty generated field: %s" % str((k, v)))
    if isinstance(v[0], six.integer_types):
      features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=v))
    elif isinstance(v[0], float):
      features[k] = tf.train.Feature(float_list=tf.train.FloatList(value=v))
    elif isinstance(v[0], six.string_types):
      if not six.PY2:  # Convert in python 3.
        v = [bytes(x, "utf-8") for x in v]
      features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v))
    elif isinstance(v[0], bytes):
      features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v))
    else:
      raise ValueError("Value for %s is not a recognized type; v: %s type: %s" %
                       (k, str(v[0]), str(type(v[0]))))
  return tf.train.Example(features=tf.train.Features(feature=features))
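For orientation, a hypothetical call to to_example — the keys and values here are invented for illustration; each value must be a non-empty list of ints, floats, strings, or bytes:

example = to_example({"inputs": [1, 2, 3], "score": [0.5], "label": ["spam"]})
print(example.features.feature["inputs"].int64_list.value)  # [1, 2, 3]
serialized = example.SerializeToString()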
Example #5
Source File: run_classifier_v2.py From wsdm19cup with MIT License | 6 votes |
def file_based_convert_examples_to_features(
    examples, label_list, max_seq_length, tokenizer, output_file):
  """Convert a set of `InputExample`s to a TFRecord file."""

  writer = tf.python_io.TFRecordWriter(output_file)

  for (ex_index, example) in enumerate(examples):
    if ex_index % 10000 == 0:
      tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

    feature = convert_single_example(ex_index, example, label_list,
                                     max_seq_length, tokenizer)

    def create_int_feature(values):
      f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
      return f

    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(feature.input_ids)
    features["input_mask"] = create_int_feature(feature.input_mask)
    features["segment_ids"] = create_int_feature(feature.segment_ids)
    features["label_ids"] = create_int_feature([feature.label_id])

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
Example #6
Source File: unittest_utils.py From DOTA_models with Apache License 2.0 | 6 votes |
def create_serialized_example(name_to_values):
  """Creates a tf.Example proto using a dictionary.

  It automatically detects the type of values and defines a corresponding
  feature.

  Args:
    name_to_values: A dictionary.

  Returns:
    tf.Example proto.
  """
  example = tf.train.Example()
  for name, values in name_to_values.items():
    feature = example.features.feature[name]
    if isinstance(values[0], str):
      add = feature.bytes_list.value.extend
    elif isinstance(values[0], float):
      # The proto field is `float_list`; the original source used
      # `float32_list`, which does not exist on tf.train.Feature.
      add = feature.float_list.value.extend
    elif isinstance(values[0], int):
      add = feature.int64_list.value.extend
    else:
      raise AssertionError('Unsupported type: %s' % type(values[0]))
    add(values)
  return example.SerializeToString()
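As a sanity check, the serialized proto can be parsed back with the protobuf FromString classmethod — a hypothetical round trip, not part of the original file:

serialized = create_serialized_example({'labels': [1, 2, 3]})
parsed = tf.train.Example.FromString(serialized)
print(parsed.features.feature['labels'].int64_list.value)  # [1, 2, 3]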
Example #7
Source File: batcher.py From TransferRL with MIT License | 6 votes |
def fill_example_queue(self):
  """Reads data from file and processes into Examples which are then placed into the example queue."""
  input_gen = self.text_generator(
      data.example_generator(self._data_path, self._single_pass))
  while True:
    try:
      # Read the next example from file. article and abstract are both strings.
      (article, abstract) = next(input_gen)  # next() also works in Python 3, unlike .next()
    except StopIteration:  # if there are no more examples:
      tf.logging.info("The example generator for this example queue filling thread has exhausted data.")
      if self._single_pass:
        tf.logging.info("single_pass mode is on, so we've finished reading dataset. This thread is stopping.")
        self._finished_reading = True
        break
      else:
        raise Exception("single_pass mode is off but the example generator is out of data; error.")

    # Use the <s> and </s> tags in abstract to get a list of sentences.
    abstract_sentences = [sent.strip() for sent in data.abstract2sents(abstract)]
    if not abstract_sentences:  # skip examples with an empty abstract
      continue
    example = Example(article, abstract_sentences, self._vocab, self._hps)  # Process into an Example.
    self._example_queue.put(example)  # place the Example in the example queue.
Example #8
Source File: run_squad.py From Extending-Google-BERT-as-Question-and-Answering-model-and-Chatbot with Apache License 2.0 | 6 votes |
def process_feature(self, feature):
  """Write an InputFeature to the TFRecordWriter as a tf.train.Example."""
  self.num_features += 1

  def create_int_feature(values):
    feature = tf.train.Feature(
        int64_list=tf.train.Int64List(value=list(values)))
    return feature

  features = collections.OrderedDict()
  features["unique_ids"] = create_int_feature([feature.unique_id])
  features["input_ids"] = create_int_feature(feature.input_ids)
  features["input_mask"] = create_int_feature(feature.input_mask)
  features["segment_ids"] = create_int_feature(feature.segment_ids)

  if self.is_training:
    features["start_positions"] = create_int_feature([feature.start_position])
    features["end_positions"] = create_int_feature([feature.end_position])
    impossible = 0
    if feature.is_impossible:
      impossible = 1
    features["is_impossible"] = create_int_feature([impossible])

  tf_example = tf.train.Example(features=tf.train.Features(feature=features))
  self._writer.write(tf_example.SerializeToString())
Example #9
Source File: run_classifier.py From wsdm19cup with MIT License | 6 votes |
def file_based_convert_examples_to_features(
    examples, label_list, max_seq_length, tokenizer, output_file):
  """Convert a set of `InputExample`s to a TFRecord file."""

  writer = tf.python_io.TFRecordWriter(output_file)

  for (ex_index, example) in enumerate(examples):
    if ex_index % 10000 == 0:
      tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

    feature = convert_single_example(ex_index, example, label_list,
                                     max_seq_length, tokenizer)

    def create_int_feature(values):
      f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
      return f

    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(feature.input_ids)
    features["input_mask"] = create_int_feature(feature.input_mask)
    features["segment_ids"] = create_int_feature(feature.segment_ids)
    features["label_ids"] = create_int_feature([feature.label_id])

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
Example #10
Source File: generator_utils.py From fine-lm with MIT License | 6 votes |
def to_example(dictionary):
  """Helper: build tf.Example from (string -> int/float/str list) dictionary."""
  features = {}
  for (k, v) in six.iteritems(dictionary):
    if not v:
      raise ValueError("Empty generated field: %s" % str((k, v)))
    if isinstance(v[0], six.integer_types):
      features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=v))
    elif isinstance(v[0], float):
      features[k] = tf.train.Feature(float_list=tf.train.FloatList(value=v))
    elif isinstance(v[0], six.string_types):
      if not six.PY2:  # Convert in python 3.
        v = [bytes(x, "utf-8") for x in v]
      features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v))
    elif isinstance(v[0], bytes):
      features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v))
    else:
      raise ValueError("Value for %s is not a recognized type; v: %s type: %s" %
                       (k, str(v[0]), str(type(v[0]))))
  return tf.train.Example(features=tf.train.Features(feature=features))
Example #11
Source File: audio_records.py From Tensorflow-Audio-Classification with Apache License 2.0 | 6 votes |
def encodes_example(feature, label):
  """Encodes to TF Example

  Args:
    feature: feature to encode
    label: label corresponding to feature

  Returns:
    tf.Example object
  """
  def _bytes_feature(value):
    """Creates a TensorFlow Record Feature with value as a byte array."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

  def _int64_feature(value):
    """Creates a TensorFlow Record Feature with value as a 64 bit integer."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

  features = {AUDIO_FEATURE_NAME: _bytes_feature(feature.tobytes()),
              AUDIO_LABEL_NAME: _int64_feature(label)}
  return tf.train.Example(features=tf.train.Features(feature=features))
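A sketch of how this might be called, assuming feature is a NumPy array (AUDIO_FEATURE_NAME and AUDIO_LABEL_NAME are constants defined elsewhere in that module; the shape here is invented):

import numpy as np

log_mel = np.random.rand(96, 64).astype(np.float32)  # stand-in audio features
example = encodes_example(log_mel, label=3)
# The array is stored as raw bytes, so a reader must know the shape and dtype,
# e.g. np.frombuffer(raw, dtype=np.float32).reshape(96, 64).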
Example #12
Source File: batcher.py From RLSeq2Seq with MIT License | 6 votes |
def fill_example_queue(self):
  """Reads data from file and processes into Examples which are then placed into the example queue."""
  input_gen = self.text_generator(
      data.example_generator(self._data_path, self._single_pass))
  while True:
    try:
      # Read the next example from file. article and abstract are both strings.
      (article, abstract) = next(input_gen)  # next() also works in Python 3, unlike .next()
    except StopIteration:  # if there are no more examples:
      tf.logging.info("The example generator for this example queue filling thread has exhausted data.")
      if self._single_pass:
        tf.logging.info("single_pass mode is on, so we've finished reading dataset. This thread is stopping.")
        self._finished_reading = True
        break
      else:
        raise Exception("single_pass mode is off but the example generator is out of data; error.")

    # Use the <s> and </s> tags in abstract to get a list of sentences.
    abstract_sentences = [sent.strip() for sent in data.abstract2sents(abstract)]
    example = Example(article, abstract_sentences, self._vocab, self._hps)  # Process into an Example.
    self._example_queue.put(example)  # place the Example in the example queue.
Example #13
Source File: eval_oqmrc.py From BERT with Apache License 2.0 | 5 votes |
def _decode_record(record, name_to_features):
  """Decodes a record to a TensorFlow example."""
  example = tf.parse_single_example(record, name_to_features)

  # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
  # So cast all int64 to int32.
  for name in list(example.keys()):
    t = example[name]
    if t.dtype == tf.int64:
      t = tf.to_int32(t)
    example[name] = t

  for name in ["input_ids", "input_mask", "segment_ids"]:
    example[name] = tf.reshape(example[name], [-1, max_seq_length])

  return example
Example #14
Source File: batcher.py From RLSeq2Seq with MIT License | 5 votes |
def fill_batch_queue(self):
  """Takes Examples out of example queue, sorts them by encoder sequence length,
  processes into Batches and places them in the batch queue.

  In decode mode, makes batches that each contain a single example repeated.
  """
  while True:
    if self._hps.mode != 'decode':
      # Get bucketing_cache_size-many batches of Examples into a list, then sort
      inputs = []
      for _ in range(self._hps.batch_size * self._bucketing_cache_size):
        inputs.append(self._example_queue.get())
      inputs = sorted(inputs, key=lambda inp: inp.enc_len)  # sort by length of encoder sequence

      # Group the sorted Examples into batches, optionally shuffle the batches,
      # and place in the batch queue.
      batches = []
      for i in range(0, len(inputs), self._hps.batch_size):
        batches.append(inputs[i:i + self._hps.batch_size])
      if not self._single_pass:
        shuffle(batches)
      for b in batches:  # each b is a list of Example objects
        self._batch_queue.put(Batch(b, self._hps, self._vocab))
    else:  # beam search decode mode
      ex = self._example_queue.get()
      b = [ex for _ in range(self._hps.batch_size)]
      self._batch_queue.put(Batch(b, self._hps, self._vocab))
Example #15
Source File: test_oqmrc_final.py From BERT with Apache License 2.0 | 5 votes |
def _decode_record(record, name_to_features):
  """Decodes a record to a TensorFlow example."""
  example = tf.parse_single_example(record, name_to_features)

  # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
  # So cast all int64 to int32.
  for name in list(example.keys()):
    t = example[name]
    if t.dtype == tf.int64:
      t = tf.to_int32(t)
    example[name] = t

  for name in ["input_ids", "input_mask", "segment_ids"]:
    example[name] = tf.reshape(example[name], [-1, max_seq_length])

  return example
Example #16
Source File: test_wsdm.py From BERT with Apache License 2.0 | 5 votes |
def _decode_record(record, name_to_features):
  """Decodes a record to a TensorFlow example."""
  example = tf.parse_single_example(record, name_to_features)

  # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
  # So cast all int64 to int32.
  for name in list(example.keys()):
    t = example[name]
    if t.dtype == tf.int64:
      t = tf.to_int32(t)
    example[name] = t

  return example
Example #17
Source File: official_wsdm_order.py From BERT with Apache License 2.0 | 5 votes |
def _decode_record(record, name_to_features):
  """Decodes a record to a TensorFlow example."""
  example = tf.parse_single_example(record, name_to_features)

  # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
  # So cast all int64 to int32.
  for name in list(example.keys()):
    t = example[name]
    if t.dtype == tf.int64:
      t = tf.to_int32(t)
    example[name] = t

  return example
Example #18
Source File: eval_wsdm_vib_order.py From BERT with Apache License 2.0 | 5 votes |
def _decode_record(record, name_to_features):
  """Decodes a record to a TensorFlow example."""
  example = tf.parse_single_example(record, name_to_features)

  # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
  # So cast all int64 to int32.
  for name in list(example.keys()):
    t = example[name]
    if t.dtype == tf.int64:
      t = tf.to_int32(t)
    example[name] = t

  return example
Example #19
Source File: run_classifier.py From BERT-Classification-Tutorial with Apache License 2.0 | 5 votes |
def file_based_convert_examples_to_features(
    examples, label_list, max_seq_length, tokenizer, output_file):
  """Convert a set of `InputExample`s to a TFRecord file."""

  writer = tf.python_io.TFRecordWriter(output_file)

  label_map = {}
  for (i, label) in enumerate(sorted(label_list)):
    label_map[label] = i

  for (ex_index, example) in enumerate(examples):
    if ex_index % 10000 == 0:
      tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

    feature = convert_single_example(ex_index, example, label_map,
                                     max_seq_length, tokenizer)

    def create_int_feature(values):
      f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
      return f

    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(feature.input_ids)
    features["input_mask"] = create_int_feature(feature.input_mask)
    features["segment_ids"] = create_int_feature(feature.segment_ids)
    features["label_ids"] = create_int_feature([feature.label_id])

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())

  return label_map
Example #20
Source File: eval_oqmrc_test.py From BERT with Apache License 2.0 | 5 votes |
def _decode_record(record, name_to_features):
  """Decodes a record to a TensorFlow example."""
  example = tf.parse_single_example(record, name_to_features)

  # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
  # So cast all int64 to int32.
  for name in list(example.keys()):
    t = example[name]
    if t.dtype == tf.int64:
      t = tf.to_int32(t)
    example[name] = t

  for name in ["input_ids", "input_mask", "segment_ids"]:
    example[name] = tf.reshape(example[name], [-1, max_seq_length])

  return example
Example #21
Source File: official_oqmrc_test.py From BERT with Apache License 2.0 | 5 votes |
def _decode_record(record, name_to_features):
  """Decodes a record to a TensorFlow example."""
  example = tf.parse_single_example(record, name_to_features)

  # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
  # So cast all int64 to int32.
  for name in list(example.keys()):
    t = example[name]
    if t.dtype == tf.int64:
      t = tf.to_int32(t)
    example[name] = t

  for name in ["input_ids", "input_mask", "segment_ids"]:
    example[name] = tf.reshape(example[name], [-1, max_seq_length])

  return example
Example #22
Source File: test_oqmrc.py From BERT with Apache License 2.0 | 5 votes |
def _decode_record(record, name_to_features):
  """Decodes a record to a TensorFlow example."""
  example = tf.parse_single_example(record, name_to_features)

  # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
  # So cast all int64 to int32.
  for name in list(example.keys()):
    t = example[name]
    if t.dtype == tf.int64:
      t = tf.to_int32(t)
    example[name] = t

  for name in ["input_ids", "input_mask", "segment_ids"]:
    example[name] = tf.reshape(example[name], [-1, max_seq_length])

  return example
Example #23
Source File: generator_utils.py From BERT with Apache License 2.0 | 5 votes |
def tfrecord_iterator(filenames, gzipped=False, example_spec=None):
  """Yields records from TFRecord files.

  Args:
    filenames: list<str>, list of TFRecord filenames to read from.
    gzipped: bool, whether the TFRecord files are gzip-encoded.
    example_spec: dict<str feature name, tf.VarLenFeature/tf.FixedLenFeature>,
      if provided, will parse each record as a tensorflow.Example proto.

  Yields:
    Records (or parsed Examples, if example_spec is provided) from files.
  """
  with tf.Graph().as_default():
    dataset = tf.data.Dataset.from_tensor_slices(filenames)

    def _load_records(filename):
      return tf.data.TFRecordDataset(
          filename,
          compression_type=tf.constant("GZIP") if gzipped else None,
          buffer_size=16 * 1000 * 1000)

    dataset = dataset.flat_map(_load_records)

    def _parse_example(ex_ser):
      return tf.parse_single_example(ex_ser, example_spec)

    if example_spec:
      dataset = dataset.map(_parse_example, num_parallel_calls=32)
    dataset = dataset.prefetch(100)
    record_it = dataset.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
      while True:
        try:
          ex = sess.run(record_it)
          yield ex
        except tf.errors.OutOfRangeError:
          break
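A hypothetical way to drive this iterator (the filename and feature spec are invented; with a VarLenFeature spec each parsed field comes back as a SparseTensorValue):

spec = {
    "inputs": tf.VarLenFeature(tf.int64),
    "targets": tf.VarLenFeature(tf.int64),
}
for parsed in tfrecord_iterator(["train-00000-of-00010"], example_spec=spec):
  print(parsed["inputs"].values)  # int64 ids for one record
  break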
Example #24
Source File: run_pretraining.py From clinicalBERT with MIT License | 5 votes |
def _decode_record(record, name_to_features):
  """Decodes a record to a TensorFlow example."""
  example = tf.parse_single_example(record, name_to_features)

  # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
  # So cast all int64 to int32.
  for name in list(example.keys()):
    t = example[name]
    if t.dtype == tf.int64:
      t = tf.to_int32(t)
    example[name] = t

  return example
Example #25
Source File: distributed_tf_data_utils.py From BERT with Apache License 2.0 | 5 votes |
def _decode_record(record, name_to_features):
  """Decodes a record to a TensorFlow example.

  name_to_features = {
      "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
      "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64),
      "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
      "masked_lm_positions": tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
      "masked_lm_ids": tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
      "masked_lm_weights": tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
      "next_sentence_labels": tf.FixedLenFeature([1], tf.int64),
  }
  """
  example = tf.parse_single_example(record, name_to_features)

  # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
  # So cast all int64 to int32.
  for name in list(example.keys()):
    t = example[name]
    if t.dtype == tf.int64:
      t = tf.to_int32(t)
    example[name] = t

  return example
Example #26
Source File: batcher.py From RLSeq2Seq with MIT License | 5 votes |
def __init__(self, example_list, hps, vocab):
  """Turns the example_list into a Batch object.

  Args:
    example_list: List of Example objects
    hps: hyperparameters
    vocab: Vocabulary object
  """
  self.pad_id = vocab.word2id(data.PAD_TOKEN)  # id of the PAD token used to pad sequences
  self.init_encoder_seq(example_list, hps)  # initialize the input to the encoder
  self.init_decoder_seq(example_list, hps)  # initialize the input and targets for the decoder
  self.store_orig_strings(example_list)  # store the original strings
Example #27
Source File: create_pet_tf_record.py From ros_people_object_detection_tensorflow with Apache License 2.0 | 5 votes |
def get_class_name_from_filename(file_name):
  """Gets the class name from a file.

  Args:
    file_name: The file name to get the class name from.
      ie. "american_pit_bull_terrier_105.jpg"

  Returns:
    The class name, e.g. "american_pit_bull_terrier". (The original
    docstring said "The converted tf.Example", which does not match what
    the function returns.)
  """
  match = re.match(r'([A-Za-z_]+)(_[0-9]+\.jpg)', file_name, re.I)
  return match.groups()[0]
Example #28
Source File: run_classifier_v3.py From wsdm19cup with MIT License | 5 votes |
def file_based_convert_examples_to_features(
    examples, label_list, max_seq_length, tokenizer, output_file):
  """Convert a set of `InputExample`s to a TFRecord file."""

  writer = tf.python_io.TFRecordWriter(output_file)

  for (ex_index, example) in enumerate(examples):
    if ex_index % 10000 == 0:
      tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

    feature = convert_single_example(ex_index, example, label_list,
                                     max_seq_length, tokenizer)

    def create_int_feature(values):
      f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
      return f

    def create_float_feature(values):
      f = tf.train.Feature(float_list=tf.train.FloatList(value=values))
      return f

    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(feature.input_ids)
    features["input_mask"] = create_int_feature(feature.input_mask)
    features["segment_ids"] = create_int_feature(feature.segment_ids)
    features["label_ids"] = create_int_feature([feature.label_id])
    features["input_extra_nums"] = create_float_feature(feature.input_extra_nums)

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
Example #29
Source File: run_clmrc.py From Cross-Lingual-MRC with Apache License 2.0 | 5 votes |
def process_feature(self, feature):
  """Write an InputFeature to the TFRecordWriter as a tf.train.Example."""
  self.num_features += 1

  def create_int_feature(values):
    feature = tf.train.Feature(
        int64_list=tf.train.Int64List(value=list(values)))
    return feature

  features = collections.OrderedDict()
  features["unique_ids"] = create_int_feature([feature.unique_id])
  features["input_ids"] = create_int_feature(feature.input_ids)
  features["input_mask"] = create_int_feature(feature.input_mask)
  features["segment_ids"] = create_int_feature(feature.segment_ids)
  features["input_span_mask"] = create_int_feature(feature.input_span_mask)
  features["source_input_span_mask"] = create_int_feature(feature.source_input_span_mask)
  features["source_input_ids"] = create_int_feature(feature.source_input_ids)
  features["source_input_mask"] = create_int_feature(feature.source_input_mask)
  features["source_segment_ids"] = create_int_feature(feature.source_segment_ids)

  # The training-only guard is commented out in the source, so these
  # position features are always written:
  # if self.is_training:
  features["start_positions"] = create_int_feature([feature.start_position])
  features["end_positions"] = create_int_feature([feature.end_position])
  features["source_start_positions"] = create_int_feature([feature.source_start_position])
  features["source_end_positions"] = create_int_feature([feature.source_end_position])
  features["output_span_mask"] = create_int_feature(feature.output_span_mask)
  features["source_output_span_mask"] = create_int_feature(feature.source_output_span_mask)

  tf_example = tf.train.Example(features=tf.train.Features(feature=features))
  self._writer.write(tf_example.SerializeToString())
Example #30
Source File: run_bert_open_qa_eval.py From XQA with MIT License | 5 votes |
def input_fn_builder(input_file, seq_length, drop_remainder):
  """Creates an `input_fn` closure to be passed to TPUEstimator."""

  name_to_features = {
      "unique_ids": tf.FixedLenFeature([], tf.int64),
      "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
      "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
  }

  def _decode_record(record, name_to_features):
    """Decodes a record to a TensorFlow example."""
    example = tf.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in list(example.keys()):
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.to_int32(t)
      example[name] = t

    return example

  def input_fn(params):
    """The actual input function."""
    batch_size = params["batch_size"]

    # For training, we want a lot of parallel reading and shuffling.
    # For eval, we want no shuffling and parallel reading doesn't matter.
    d = tf.data.TFRecordDataset(input_file)
    d = d.apply(
        tf.contrib.data.map_and_batch(
            lambda record: _decode_record(record, name_to_features),
            batch_size=batch_size,
            drop_remainder=drop_remainder))

    return d

  return input_fn
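Outside of TPUEstimator, the returned closure can be exercised directly — a hypothetical call, with the record path and sequence length invented:

input_fn = input_fn_builder("eval.tf_record", seq_length=384, drop_remainder=False)
dataset = input_fn({"batch_size": 8})  # params is normally supplied by TPUEstimator
features = dataset.make_one_shot_iterator().get_next()  # dict of int32 Tensors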