Python preprocessing.shuffle_tf_examples() Examples

The following are 11 code examples of `preprocessing.shuffle_tf_examples()`. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module `preprocessing`, or try the search function. Note that these copies were damaged when the page was extracted: some lines and call arguments are missing, so several examples do not parse as shown.
Example #1
Source file: training_results_v0.5 (Apache License 2.0), 5 votes
def test_serialize_round_trip_no_parse(self):
    """Check that shuffled, rewritten tf.Examples round-trip unchanged.

    Writes 10 random examples to a temporary file, reads them back through
    ``preprocessing.shuffle_tf_examples`` in batches of 4, rewrites the
    flattened batches to a second temporary file with ``serialize=False``,
    and asserts that both files decode to the same data up to ordering.
    """
    raw_data = self.create_random_data(10)
    tfexamples = list(map(preprocessing.make_tf_example, *zip(*raw_data)))

    with tempfile.NamedTemporaryFile() as start_file, \
            tempfile.NamedTemporaryFile() as rewritten_file:
        preprocessing.write_tf_examples(start_file.name, tfexamples)
        # We want to test that the rewritten, shuffled file contains correctly
        # serialized tf.Examples.
        batch_size = 4
        # Shuffle the file written above. With an empty file list there is
        # presumably nothing to batch, and the count check below cannot pass.
        batches = list(preprocessing.shuffle_tf_examples(
            batch_size, [start_file.name]))
        # 2 batches of 4, 1 incomplete batch of 2.
        self.assertEqual(len(batches), 3)

        # concatenate list of lists into one list
        all_batches = list(itertools.chain.from_iterable(batches))

        # Write the flattened examples once. Rewriting the same list once per
        # batch adds nothing, assuming write_tf_examples replaces the file's
        # contents (which the round-trip comparison below relies on anyway).
        # serialize=False: the shuffled records are presumably already
        # serialized, which is what this "no_parse" test exercises.
        preprocessing.write_tf_examples(
            rewritten_file.name, all_batches, serialize=False)

        original_data = self.extract_data(start_file.name)
        recovered_data = self.extract_data(rewritten_file.name)

    # stuff is shuffled, so sort before checking equality
    def sort_key(nparray_tuple):
        return nparray_tuple[2]
    original_data = sorted(original_data, key=sort_key)
    recovered_data = sorted(recovered_data, key=sort_key)

    self.assertEqualData(original_data, recovered_data)
Example #2
Source file: training_results_v0.5 (Apache License 2.0), 5 votes
def test_serialize_round_trip_no_parse(self):
    """Check that shuffled, rewritten tf.Examples round-trip unchanged.

    Writes 10 random examples to a temporary file, reads them back through
    ``preprocessing.shuffle_tf_examples`` in batches of 4, rewrites the
    flattened batches to a second temporary file with ``serialize=False``,
    and asserts that both files decode to the same data up to ordering.
    """
    raw_data = self.create_random_data(10)
    tfexamples = list(map(preprocessing.make_tf_example, *zip(*raw_data)))

    with tempfile.NamedTemporaryFile() as start_file, \
            tempfile.NamedTemporaryFile() as rewritten_file:
        preprocessing.write_tf_examples(start_file.name, tfexamples)
        # We want to test that the rewritten, shuffled file contains correctly
        # serialized tf.Examples.
        batch_size = 4
        # Shuffle the file written above. With an empty file list there is
        # presumably nothing to batch, and the count check below cannot pass.
        batches = list(preprocessing.shuffle_tf_examples(
            batch_size, [start_file.name]))
        # 2 batches of 4, 1 incomplete batch of 2.
        self.assertEqual(len(batches), 3)

        # concatenate list of lists into one list
        all_batches = list(itertools.chain.from_iterable(batches))

        # Write the flattened examples once. Rewriting the same list once per
        # batch adds nothing, assuming write_tf_examples replaces the file's
        # contents (which the round-trip comparison below relies on anyway).
        # serialize=False: the shuffled records are presumably already
        # serialized, which is what this "no_parse" test exercises.
        preprocessing.write_tf_examples(
            rewritten_file.name, all_batches, serialize=False)

        original_data = self.extract_data(start_file.name)
        recovered_data = self.extract_data(rewritten_file.name)

    # stuff is shuffled, so sort before checking equality
    def sort_key(nparray_tuple):
        return nparray_tuple[2]
    original_data = sorted(original_data, key=sort_key)
    recovered_data = sorted(recovered_data, key=sort_key)

    self.assertEqualData(original_data, recovered_data)
Example #3
Source file: Gun-Detector (Apache License 2.0), 5 votes
def test_serialize_round_trip_no_parse(self):
  """Check that shuffled, rewritten tf.Examples round-trip unchanged.

  Writes 10 random examples to a temporary file, reads them back through
  ``preprocessing.shuffle_tf_examples`` in batches of 4, rewrites the
  flattened batches to a second temporary file with ``serialize=False``,
  and asserts that both files decode to the same data up to ordering.
  """
  raw_data = self.create_random_data(10)
  tfexamples = list(map(preprocessing.make_tf_example, *zip(*raw_data)))

  with tempfile.NamedTemporaryFile() as start_file, \
      tempfile.NamedTemporaryFile() as rewritten_file:
    preprocessing.write_tf_examples(start_file.name, tfexamples)
    # We want to test that the rewritten, shuffled file contains correctly
    # serialized tf.Examples.
    # First positional argument of shuffle_tf_examples, matching the
    # (shuffle_buffer_size, examples_per_chunk, files) call in gather().
    shuffle_buffer_size = 1000
    batch_size = 4
    # Shuffle the file written above. With an empty file list there is
    # presumably nothing to batch, and the count check below cannot pass.
    batches = list(preprocessing.shuffle_tf_examples(
        shuffle_buffer_size, batch_size, [start_file.name]))
    # 2 batches of 4, 1 incomplete batch of 2.
    self.assertEqual(len(batches), 3)

    # concatenate list of lists into one list
    all_batches = list(itertools.chain.from_iterable(batches))

    # Write the flattened examples once. Rewriting the same list once per
    # batch adds nothing, assuming write_tf_examples replaces the file's
    # contents (which the round-trip comparison below relies on anyway).
    # serialize=False: the shuffled records are presumably already
    # serialized, which is what this "no_parse" test exercises.
    preprocessing.write_tf_examples(
        rewritten_file.name, all_batches, serialize=False)

    original_data = self.extract_data(start_file.name)
    recovered_data = self.extract_data(rewritten_file.name)

  # stuff is shuffled, so sort before checking equality
  def sort_key(nparray_tuple):
    return nparray_tuple[2]
  original_data = sorted(original_data, key=sort_key)
  recovered_data = sorted(recovered_data, key=sort_key)

  self.assertEqualData(original_data, recovered_data)
Example #4
Source file: g-tensorflow-models (Apache License 2.0), 5 votes
def test_serialize_round_trip_no_parse(self):
  """Check that shuffled, rewritten tf.Examples round-trip unchanged.

  Writes 10 random examples to a temporary file, reads them back through
  ``preprocessing.shuffle_tf_examples`` in batches of 4, rewrites the
  flattened batches to a second temporary file with ``serialize=False``,
  and asserts that both files decode to the same data up to ordering.
  """
  raw_data = self.create_random_data(10)
  tfexamples = list(map(preprocessing.make_tf_example, *zip(*raw_data)))

  with tempfile.NamedTemporaryFile() as start_file, \
      tempfile.NamedTemporaryFile() as rewritten_file:
    preprocessing.write_tf_examples(start_file.name, tfexamples)
    # We want to test that the rewritten, shuffled file contains correctly
    # serialized tf.Examples.
    # First positional argument of shuffle_tf_examples, matching the
    # (shuffle_buffer_size, examples_per_chunk, files) call in gather().
    shuffle_buffer_size = 1000
    batch_size = 4
    # Shuffle the file written above. With an empty file list there is
    # presumably nothing to batch, and the count check below cannot pass.
    batches = list(preprocessing.shuffle_tf_examples(
        shuffle_buffer_size, batch_size, [start_file.name]))
    # 2 batches of 4, 1 incomplete batch of 2.
    self.assertEqual(len(batches), 3)

    # concatenate list of lists into one list
    all_batches = list(itertools.chain.from_iterable(batches))

    # Write the flattened examples once. Rewriting the same list once per
    # batch adds nothing, assuming write_tf_examples replaces the file's
    # contents (which the round-trip comparison below relies on anyway).
    # serialize=False: the shuffled records are presumably already
    # serialized, which is what this "no_parse" test exercises.
    preprocessing.write_tf_examples(
        rewritten_file.name, all_batches, serialize=False)

    original_data = self.extract_data(start_file.name)
    recovered_data = self.extract_data(rewritten_file.name)

  # stuff is shuffled, so sort before checking equality
  def sort_key(nparray_tuple):
    return nparray_tuple[2]
  original_data = sorted(original_data, key=sort_key)
  recovered_data = sorted(recovered_data, key=sort_key)

  self.assertEqualData(original_data, recovered_data)
Example #5
Source file: multilabel-image-classification-tensorflow (MIT License), 5 votes
def test_serialize_round_trip_no_parse(self):
  """Check that shuffled, rewritten tf.Examples round-trip unchanged.

  Writes 10 random examples to a temporary file, reads them back through
  ``preprocessing.shuffle_tf_examples`` in batches of 4, rewrites the
  flattened batches to a second temporary file with ``serialize=False``,
  and asserts that both files decode to the same data up to ordering.
  """
  raw_data = self.create_random_data(10)
  tfexamples = list(map(preprocessing.make_tf_example, *zip(*raw_data)))

  with tempfile.NamedTemporaryFile() as start_file, \
      tempfile.NamedTemporaryFile() as rewritten_file:
    preprocessing.write_tf_examples(start_file.name, tfexamples)
    # We want to test that the rewritten, shuffled file contains correctly
    # serialized tf.Examples.
    # First positional argument of shuffle_tf_examples, matching the
    # (shuffle_buffer_size, examples_per_chunk, files) call in gather().
    shuffle_buffer_size = 1000
    batch_size = 4
    # Shuffle the file written above. With an empty file list there is
    # presumably nothing to batch, and the count check below cannot pass.
    batches = list(preprocessing.shuffle_tf_examples(
        shuffle_buffer_size, batch_size, [start_file.name]))
    # 2 batches of 4, 1 incomplete batch of 2.
    self.assertEqual(len(batches), 3)

    # concatenate list of lists into one list
    all_batches = list(itertools.chain.from_iterable(batches))

    # Write the flattened examples once. Rewriting the same list once per
    # batch adds nothing, assuming write_tf_examples replaces the file's
    # contents (which the round-trip comparison below relies on anyway).
    # serialize=False: the shuffled records are presumably already
    # serialized, which is what this "no_parse" test exercises.
    preprocessing.write_tf_examples(
        rewritten_file.name, all_batches, serialize=False)

    original_data = self.extract_data(start_file.name)
    recovered_data = self.extract_data(rewritten_file.name)

  # stuff is shuffled, so sort before checking equality
  def sort_key(nparray_tuple):
    return nparray_tuple[2]
  original_data = sorted(original_data, key=sort_key)
  recovered_data = sorted(recovered_data, key=sort_key)

  self.assertEqualData(original_data, recovered_data)
Example #6
Source file: training_results_v0.5 (Apache License 2.0), 4 votes
def gather(
        input_directory: 'where to look for games'='data/selfplay/',
        output_directory: 'where to put collected games'='data/training_chunks/',
        examples_per_record: 'how many tf.examples to gather in each chunk'=EXAMPLES_PER_RECORD):
    """Collect self-play tfrecords of recent models into shuffled training chunks.

    Looks at the last 50 model directories (in sorted order) under
    `input_directory` and globs each one for '*.tfrecord.zz' files. It skips
    models whose files are all already recorded in
    `output_directory`/meta.txt. For each remaining model,
    `preprocessing.shuffle_tf_examples` yields batches of
    `examples_per_record` examples, and batch i is written to
    `output_directory`/<model>-<i>.tfrecord.zz.

    NOTE(review): this copy was damaged when it was scraped and does not parse
    as-is. Lines that are visibly missing are flagged inline below. They were
    not reconstructed because part of their content (e.g. the meta.txt
    format) cannot be recovered from what is shown.
    """
    # Only the last 50 entries of the sorted directory listing are considered.
    models = [model_dir.strip('/')
              for model_dir in sorted(gfile.ListDirectory(input_directory))[-50:]]
    with timer("Finding existing tfrecords..."):
        # Map each model directory name to the list of its tfrecord files.
        model_gamedata = {
            model: gfile.Glob(
                os.path.join(input_directory, model, '*.tfrecord.zz'))
            for model in models
        # NOTE(review): the closing '}' of this dict comprehension is missing.
    print("Found %d models" % len(models))
    for model_name, record_files in sorted(model_gamedata.items()):
        print("    %s: %s files" % (model_name, len(record_files)))

    # meta.txt lists the input files that earlier runs have already chunked.
    meta_file = os.path.join(output_directory, 'meta.txt')
    # NOTE(review): a 'try:' line is missing here; the 'except' below has no
    # matching 'try'.
        with gfile.GFile(meta_file, 'r') as f:
            # NOTE(review): the argument to set(...) and its closing ')' are
            # missing. The parsing must match what is written to meta_file at
            # the end of this function, which is also missing from this copy.
            already_processed = set(
    except tf.errors.NotFoundError:
        # No meta file yet: treat nothing as processed.
        already_processed = set()

    num_already_processed = len(already_processed)

    for model_name, record_files in sorted(model_gamedata.items()):
        # Skip models whose record files were all chunked by an earlier run.
        # NOTE(review): the body of this 'if' is missing; Example #9 uses
        # 'continue' at this point.
        if set(record_files) <= already_processed:
        print("Gathering files for %s:" % model_name)
        for i, example_batch in enumerate(
                tqdm(preprocessing.shuffle_tf_examples(examples_per_record, record_files))):
            output_record = os.path.join(output_directory,
                                         '{}-{}.tfrecord.zz'.format(model_name, str(i)))
            # NOTE(review): the head of the call below is missing. It is
            # presumably 'preprocessing.write_tf_examples(', which other
            # examples call with an output filename first.
                output_record, example_batch, serialize=False)
        # NOTE(review): nothing here adds record_files to already_processed,
        # so the "Processed %s new files" count below is always 0 as shown.
        # Example #9 calls already_processed.update(record_files) at this point.

    print("Processed %s new files" %
          (len(already_processed) - num_already_processed))
    # NOTE(review): the body of this 'with' is missing. It presumably writes
    # already_processed back to meta_file for the next run.
    with gfile.GFile(meta_file, 'w') as f:
Example #7
Source file: training_results_v0.5 (Apache License 2.0), 4 votes
def gather(
        input_directory: 'where to look for games'='data/selfplay/',
        output_directory: 'where to put collected games'='data/training_chunks/',
        examples_per_record: 'how many tf.examples to gather in each chunk'=EXAMPLES_PER_RECORD):
    """Collect self-play tfrecords of recent models into shuffled training chunks.

    Looks at the last 50 model directories (in sorted order) under
    `input_directory` and globs each one for '*.tfrecord.zz' files. It skips
    models whose files are all already recorded in
    `output_directory`/meta.txt. For each remaining model,
    `preprocessing.shuffle_tf_examples` yields batches of
    `examples_per_record` examples, and batch i is written to
    `output_directory`/<model>-<i>.tfrecord.zz.

    NOTE(review): this copy was damaged when it was scraped and does not parse
    as-is. Lines that are visibly missing are flagged inline below. They were
    not reconstructed because part of their content (e.g. the meta.txt
    format) cannot be recovered from what is shown.
    """
    # Only the last 50 entries of the sorted directory listing are considered.
    models = [model_dir.strip('/')
              for model_dir in sorted(gfile.ListDirectory(input_directory))[-50:]]
    with timer("Finding existing tfrecords..."):
        # Map each model directory name to the list of its tfrecord files.
        model_gamedata = {
            model: gfile.Glob(
                os.path.join(input_directory, model, '*.tfrecord.zz'))
            for model in models
        # NOTE(review): the closing '}' of this dict comprehension is missing.
    print("Found %d models" % len(models))
    for model_name, record_files in sorted(model_gamedata.items()):
        print("    %s: %s files" % (model_name, len(record_files)))

    # meta.txt lists the input files that earlier runs have already chunked.
    meta_file = os.path.join(output_directory, 'meta.txt')
    # NOTE(review): a 'try:' line is missing here; the 'except' below has no
    # matching 'try'.
        with gfile.GFile(meta_file, 'r') as f:
            # NOTE(review): the argument to set(...) and its closing ')' are
            # missing. The parsing must match what is written to meta_file at
            # the end of this function, which is also missing from this copy.
            already_processed = set(
    except tf.errors.NotFoundError:
        # No meta file yet: treat nothing as processed.
        already_processed = set()

    num_already_processed = len(already_processed)

    for model_name, record_files in sorted(model_gamedata.items()):
        # Skip models whose record files were all chunked by an earlier run.
        # NOTE(review): the body of this 'if' is missing; Example #9 uses
        # 'continue' at this point.
        if set(record_files) <= already_processed:
        print("Gathering files for %s:" % model_name)
        for i, example_batch in enumerate(
                tqdm(preprocessing.shuffle_tf_examples(examples_per_record, record_files))):
            output_record = os.path.join(output_directory,
                                         '{}-{}.tfrecord.zz'.format(model_name, str(i)))
            # NOTE(review): the head of the call below is missing. It is
            # presumably 'preprocessing.write_tf_examples(', which other
            # examples call with an output filename first.
                output_record, example_batch, serialize=False)
        # NOTE(review): nothing here adds record_files to already_processed,
        # so the "Processed %s new files" count below is always 0 as shown.
        # Example #9 calls already_processed.update(record_files) at this point.

    print("Processed %s new files" %
          (len(already_processed) - num_already_processed))
    # NOTE(review): the body of this 'with' is missing. It presumably writes
    # already_processed back to meta_file for the next run.
    with gfile.GFile(meta_file, 'w') as f:
Example #8
Source file: Gun-Detector (Apache License 2.0), 4 votes
def gather(selfplay_dir, training_chunk_dir, params):
  """Gather selfplay data into large training chunks.

  Args:
    selfplay_dir: Where to look for games. Set as 'base_dir/data/selfplay/'.
    training_chunk_dir: Where to put the collected games.
    params: An object of hyperparameters for the model. This function reads
      its `gather_generation`, `shuffle_buffer_size` and `examples_per_chunk`
      attributes.

  NOTE(review): this copy was damaged when it was scraped and does not parse
  as-is. Lines that are visibly missing are flagged inline below. They were
  not reconstructed because part of their content (e.g. the meta.txt format)
  cannot be recovered from what is shown.
  """
  # Check the selfplay data from the last `params.gather_generation` model
  # directories in sorted order.
  sorted_model_dirs = sorted(tf.gfile.ListDirectory(selfplay_dir))
  models = [model_dir.strip('/')
            for model_dir in sorted_model_dirs[-params.gather_generation:]]

  with utils.logged_timer('Finding existing tfrecords...'):
    # Map each model directory name to the list of its tfrecord files.
    model_gamedata = {
        model: tf.gfile.Glob(
            os.path.join(selfplay_dir, model, '*'+_TF_RECORD_SUFFIX))
        for model in models
    # NOTE(review): the closing '}' of this dict comprehension is missing.
  print('Found {} models'.format(len(models)))
  for model_name, record_files in sorted(model_gamedata.items()):
    print('    {}: {} files'.format(model_name, len(record_files)))

  # meta.txt lists the input files that earlier runs have already chunked.
  meta_file = os.path.join(training_chunk_dir, 'meta.txt')
  # NOTE(review): a 'try:' line is missing here; the 'except' below has no
  # matching 'try'.
    with tf.gfile.GFile(meta_file, 'r') as f:
      # NOTE(review): the argument to set(...) and its closing ')' are
      # missing. The parsing must match what is written to meta_file at the
      # end of this function, which is also missing from this copy.
      already_processed = set(
  except tf.errors.NotFoundError:
    # No meta file yet: treat nothing as processed.
    already_processed = set()

  num_already_processed = len(already_processed)

  for model_name, record_files in sorted(model_gamedata.items()):
    # Skip models whose record files were all chunked by an earlier run.
    # NOTE(review): the body of this 'if' is missing; Example #9 uses
    # 'continue' at this point.
    if set(record_files) <= already_processed:
    print('Gathering files from {}:'.format(model_name))
    # Shuffled batches of this model's examples. Judging by the argument
    # names, the first argument is the shuffle buffer size and the second is
    # the number of examples per chunk.
    tf_examples = preprocessing.shuffle_tf_examples(
        params.shuffle_buffer_size, params.examples_per_chunk, record_files)
    # Write batch i to a file named <model_name>-<i><_TF_RECORD_SUFFIX>.
    # (This loop does not use a tqdm progress meter.)
    for i, example_batch in enumerate(tf_examples):
      # NOTE(review): the directory argument of this os.path.join call
      # (presumably training_chunk_dir) is missing.
      output_record = os.path.join(
          ('{}-{}'+_TF_RECORD_SUFFIX).format(model_name, str(i)))
      # NOTE(review): the head of the call below is missing. It is presumably
      # 'preprocessing.write_tf_examples(', which other examples call with an
      # output filename first.
          output_record, example_batch, serialize=False)
    # NOTE(review): nothing here adds record_files to already_processed, so
    # the 'Processed {} new files' count below is always 0 as shown.
    # Example #9 calls already_processed.update(record_files) at this point.

  print('Processed {} new files'.format(
      len(already_processed) - num_already_processed))
  # NOTE(review): the body of this 'with' is missing. It presumably writes
  # already_processed back to meta_file for the next run.
  with tf.gfile.GFile(meta_file, 'w') as f:
Example #9
Source file: Python-Reinforcement-Learning-Projects (MIT License), 4 votes
def aggregate():"Gathering game results")
    # Collect self-play records of the last 50 model directories (in sorted
    # order) under PATHS.SELFPLAY_DIR into shuffled chunks of
    # GLOBAL_PARAMETER_STORE.EXAMPLES_PER_RECORD examples, written to
    # PATHS.TRAINING_CHUNK_DIR. Models whose files are all listed in
    # meta.txt are skipped.
    #
    # NOTE(review): this copy was damaged when it was scraped and does not
    # parse as-is. The line above appears to be 'def aggregate():' fused with
    # the arguments of a call whose head is missing (presumably a logging or
    # print call). The same damage recurs on several lines below, and other
    # lines are missing outright. They are flagged inline rather than
    # reconstructed.

    # Make sure both working directories exist before listing or writing.
    os.makedirs(PATHS.TRAINING_CHUNK_DIR, exist_ok=True)
    os.makedirs(PATHS.SELFPLAY_DIR, exist_ok=True)
    # Only the last 50 entries of the sorted directory listing are considered.
    models = [model_dir.strip('/')
              for model_dir in sorted(gfile.ListDirectory(PATHS.SELFPLAY_DIR))[-50:]]

    with timer("Finding existing tfrecords..."):
        # Map each model directory name to the list of its '*.zz' files.
        model_gamedata = {
            model: gfile.Glob(
                os.path.join(PATHS.SELFPLAY_DIR, model, '*.zz'))
            for model in models
        # NOTE(review): each of the next two lines is fused with the arguments
        # of a logging/print call whose head is missing.
        }"Found %d models" % len(models))
    for model_name, record_files in sorted(model_gamedata.items()):"    %s: %s files" % (model_name, len(record_files)))

    # meta.txt lists the input files that earlier runs have already chunked.
    meta_file = os.path.join(PATHS.TRAINING_CHUNK_DIR, 'meta.txt')
    # NOTE(review): a 'try:' line is missing here; the 'except' below has no
    # matching 'try'.
        with gfile.GFile(meta_file, 'r') as f:
            # NOTE(review): the argument to set(...) and its closing ')' are
            # missing. The parsing must match what is written to meta_file at
            # the end of this function, which is also missing from this copy.
            already_processed = set(
    except tf.errors.NotFoundError:
        # No meta file yet: treat nothing as processed.
        already_processed = set()

    num_already_processed = len(already_processed)

    for model_name, record_files in sorted(model_gamedata.items()):
        # Skip models whose record files were all chunked by an earlier run.
        if set(record_files) <= already_processed:
            # NOTE(review): 'continue' below is fused with the arguments of a
            # logging/print call whose head is missing.
            continue"Gathering files for %s:" % model_name)
        for i, example_batch in enumerate(
                tqdm(preprocessing.shuffle_tf_examples(GLOBAL_PARAMETER_STORE.EXAMPLES_PER_RECORD, record_files))):
            output_record = os.path.join(PATHS.TRAINING_CHUNK_DIR,
                                         '{}-{}.tfrecord.zz'.format(model_name, str(i)))
            # NOTE(review): the head of the call below is missing. It is
            # presumably 'preprocessing.write_tf_examples(', which other
            # examples call with an output filename first.
                output_record, example_batch, serialize=False)
        # Mark this model's files as processed so the next run skips them.
        # NOTE(review): this line is fused with the arguments of the
        # "Processed %s new files" logging/print call, whose head is missing.
        already_processed.update(record_files)"Processed %s new files" %
          (len(already_processed) - num_already_processed))
    # NOTE(review): the body of this 'with' is missing. It presumably writes
    # already_processed back to meta_file for the next run.
    with gfile.GFile(meta_file, 'w') as f:
Example #10
Source file: g-tensorflow-models (Apache License 2.0), 4 votes
def gather(selfplay_dir, training_chunk_dir, params):
  """Gather selfplay data into large training chunks.

  Args:
    selfplay_dir: Where to look for games. Set as 'base_dir/data/selfplay/'.
    training_chunk_dir: Where to put the collected games.
    params: A MiniGoParams instance of hyperparameters for the model. This
      function reads its `gather_generation`, `shuffle_buffer_size` and
      `examples_per_chunk` attributes.

  NOTE(review): this copy was damaged when it was scraped and does not parse
  as-is. Lines that are visibly missing are flagged inline below. They were
  not reconstructed because part of their content (e.g. the meta.txt format)
  cannot be recovered from what is shown.
  """
  # Check the selfplay data from the last `params.gather_generation` model
  # directories in sorted order.
  sorted_model_dirs = sorted(tf.gfile.ListDirectory(selfplay_dir))
  models = [model_dir.strip('/')
            for model_dir in sorted_model_dirs[-params.gather_generation:]]

  with utils.logged_timer('Finding existing tfrecords...'):
    # Map each model directory name to the list of its tfrecord files.
    model_gamedata = {
        model: tf.gfile.Glob(
            os.path.join(selfplay_dir, model, '*'+_TF_RECORD_SUFFIX))
        for model in models
    # NOTE(review): the closing '}' of this dict comprehension is missing.
  print('Found {} models'.format(len(models)))
  for model_name, record_files in sorted(model_gamedata.items()):
    print('    {}: {} files'.format(model_name, len(record_files)))

  # meta.txt lists the input files that earlier runs have already chunked.
  meta_file = os.path.join(training_chunk_dir, 'meta.txt')
  # NOTE(review): a 'try:' line is missing here; the 'except' below has no
  # matching 'try'.
    with tf.gfile.GFile(meta_file, 'r') as f:
      # NOTE(review): the argument to set(...) and its closing ')' are
      # missing. The parsing must match what is written to meta_file at the
      # end of this function, which is also missing from this copy.
      already_processed = set(
  except tf.errors.NotFoundError:
    # No meta file yet: treat nothing as processed.
    already_processed = set()

  num_already_processed = len(already_processed)

  for model_name, record_files in sorted(model_gamedata.items()):
    # Skip models whose record files were all chunked by an earlier run.
    # NOTE(review): the body of this 'if' is missing; Example #9 uses
    # 'continue' at this point.
    if set(record_files) <= already_processed:
    print('Gathering files from {}:'.format(model_name))
    # Shuffled batches of this model's examples. Judging by the argument
    # names, the first argument is the shuffle buffer size and the second is
    # the number of examples per chunk.
    tf_examples = preprocessing.shuffle_tf_examples(
        params.shuffle_buffer_size, params.examples_per_chunk, record_files)
    # Write batch i to a file named <model_name>-<i><_TF_RECORD_SUFFIX>.
    # (This loop does not use a tqdm progress meter.)
    for i, example_batch in enumerate(tf_examples):
      # NOTE(review): the directory argument of this os.path.join call
      # (presumably training_chunk_dir) is missing.
      output_record = os.path.join(
          ('{}-{}'+_TF_RECORD_SUFFIX).format(model_name, str(i)))
      # NOTE(review): the head of the call below is missing. It is presumably
      # 'preprocessing.write_tf_examples(', which other examples call with an
      # output filename first.
          output_record, example_batch, serialize=False)
    # NOTE(review): nothing here adds record_files to already_processed, so
    # the 'Processed {} new files' count below is always 0 as shown.
    # Example #9 calls already_processed.update(record_files) at this point.

  print('Processed {} new files'.format(
      len(already_processed) - num_already_processed))
  # NOTE(review): the body of this 'with' is missing. It presumably writes
  # already_processed back to meta_file for the next run.
  with tf.gfile.GFile(meta_file, 'w') as f:
Example #11
Source file: multilabel-image-classification-tensorflow (MIT License), 4 votes
def gather(selfplay_dir, training_chunk_dir, params):
  """Gather selfplay data into large training chunks.

  Args:
    selfplay_dir: Where to look for games. Set as 'base_dir/data/selfplay/'.
    training_chunk_dir: Where to put the collected games.
    params: A MiniGoParams instance of hyperparameters for the model. This
      function reads its `gather_generation`, `shuffle_buffer_size` and
      `examples_per_chunk` attributes.

  NOTE(review): this copy was damaged when it was scraped and does not parse
  as-is. Lines that are visibly missing are flagged inline below. They were
  not reconstructed because part of their content (e.g. the meta.txt format)
  cannot be recovered from what is shown.
  """
  # Check the selfplay data from the last `params.gather_generation` model
  # directories in sorted order.
  sorted_model_dirs = sorted(tf.gfile.ListDirectory(selfplay_dir))
  models = [model_dir.strip('/')
            for model_dir in sorted_model_dirs[-params.gather_generation:]]

  with utils.logged_timer('Finding existing tfrecords...'):
    # Map each model directory name to the list of its tfrecord files.
    model_gamedata = {
        model: tf.gfile.Glob(
            os.path.join(selfplay_dir, model, '*'+_TF_RECORD_SUFFIX))
        for model in models
    # NOTE(review): the closing '}' of this dict comprehension is missing.
  print('Found {} models'.format(len(models)))
  for model_name, record_files in sorted(model_gamedata.items()):
    print('    {}: {} files'.format(model_name, len(record_files)))

  # meta.txt lists the input files that earlier runs have already chunked.
  meta_file = os.path.join(training_chunk_dir, 'meta.txt')
  # NOTE(review): a 'try:' line is missing here; the 'except' below has no
  # matching 'try'.
    with tf.gfile.GFile(meta_file, 'r') as f:
      # NOTE(review): the argument to set(...) and its closing ')' are
      # missing. The parsing must match what is written to meta_file at the
      # end of this function, which is also missing from this copy.
      already_processed = set(
  except tf.errors.NotFoundError:
    # No meta file yet: treat nothing as processed.
    already_processed = set()

  num_already_processed = len(already_processed)

  for model_name, record_files in sorted(model_gamedata.items()):
    # Skip models whose record files were all chunked by an earlier run.
    # NOTE(review): the body of this 'if' is missing; Example #9 uses
    # 'continue' at this point.
    if set(record_files) <= already_processed:
    print('Gathering files from {}:'.format(model_name))
    # Shuffled batches of this model's examples. Judging by the argument
    # names, the first argument is the shuffle buffer size and the second is
    # the number of examples per chunk.
    tf_examples = preprocessing.shuffle_tf_examples(
        params.shuffle_buffer_size, params.examples_per_chunk, record_files)
    # Write batch i to a file named <model_name>-<i><_TF_RECORD_SUFFIX>.
    # (This loop does not use a tqdm progress meter.)
    for i, example_batch in enumerate(tf_examples):
      # NOTE(review): the directory argument of this os.path.join call
      # (presumably training_chunk_dir) is missing.
      output_record = os.path.join(
          ('{}-{}'+_TF_RECORD_SUFFIX).format(model_name, str(i)))
      # NOTE(review): the head of the call below is missing. It is presumably
      # 'preprocessing.write_tf_examples(', which other examples call with an
      # output filename first.
          output_record, example_batch, serialize=False)
    # NOTE(review): nothing here adds record_files to already_processed, so
    # the 'Processed {} new files' count below is always 0 as shown.
    # Example #9 calls already_processed.update(record_files) at this point.

  print('Processed {} new files'.format(
      len(already_processed) - num_already_processed))
  # NOTE(review): the body of this 'with' is missing. It presumably writes
  # already_processed back to meta_file for the next run.
  with tf.gfile.GFile(meta_file, 'w') as f: