Python data_utils.get_input_fn() Examples

The following are 4 code examples of data_utils.get_input_fn(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module data_utils, or try the search function.
Example #1
Source File: train.py    From embedding-as-service with MIT License 6 votes vote down vote up
def get_input_fn(split):
  """Create the training input function from the command-line FLAGS.

  Forwards the data-related FLAGS to `data_utils.get_input_fn`.

  Args:
    split: Name of the data split; only "train" is accepted.

  Returns:
    The `(input_fn, record_info_dict)` pair produced by
    `data_utils.get_input_fn`.
  """
  assert split == "train"
  global_bsz = FLAGS.train_batch_size

  # Resolve the builder first, then gather its keyword arguments.
  build_input_fn = data_utils.get_input_fn
  builder_kwargs = dict(
      tfrecord_dir=FLAGS.record_info_dir,
      split=split,
      # The global batch is floor-divided across the hosts.
      bsz_per_host=global_bsz // FLAGS.num_hosts,
      seq_len=FLAGS.seq_len,
      reuse_len=FLAGS.reuse_len,
      bi_data=FLAGS.bi_data,
      num_hosts=FLAGS.num_hosts,
      num_core_per_host=FLAGS.num_core_per_host,
      perm_size=FLAGS.perm_size,
      mask_alpha=FLAGS.mask_alpha,
      mask_beta=FLAGS.mask_beta,
      uncased=FLAGS.uncased,
      num_passes=FLAGS.num_passes,
      use_bfloat16=FLAGS.use_bfloat16,
      num_predict=FLAGS.num_predict,
  )

  fn, info = build_input_fn(**builder_kwargs)
  return fn, info
Example #2
Source File: train.py    From xlnet with Apache License 2.0 6 votes vote down vote up
def get_input_fn(split):
  """Create the training input function from the command-line FLAGS.

  Forwards the data-related FLAGS to `data_utils.get_input_fn`.

  Args:
    split: Name of the data split; only "train" is accepted.

  Returns:
    The `(input_fn, record_info_dict)` pair produced by
    `data_utils.get_input_fn`.
  """
  assert split == "train"
  global_bsz = FLAGS.train_batch_size

  # Resolve the builder first, then gather its keyword arguments.
  build_input_fn = data_utils.get_input_fn
  builder_kwargs = dict(
      tfrecord_dir=FLAGS.record_info_dir,
      split=split,
      # The global batch is floor-divided across the hosts.
      bsz_per_host=global_bsz // FLAGS.num_hosts,
      seq_len=FLAGS.seq_len,
      reuse_len=FLAGS.reuse_len,
      bi_data=FLAGS.bi_data,
      num_hosts=FLAGS.num_hosts,
      num_core_per_host=FLAGS.num_core_per_host,
      perm_size=FLAGS.perm_size,
      mask_alpha=FLAGS.mask_alpha,
      mask_beta=FLAGS.mask_beta,
      uncased=FLAGS.uncased,
      num_passes=FLAGS.num_passes,
      use_bfloat16=FLAGS.use_bfloat16,
      num_predict=FLAGS.num_predict,
  )

  fn, info = build_input_fn(**builder_kwargs)
  return fn, info
Example #3
Source File: train.py    From embedding-as-service with MIT License 4 votes vote down vote up
def main(unused_argv):
  """Run training: validate flags, build the pieces, train a TPUEstimator.

  Args:
    unused_argv: Positional command-line arguments; discarded immediately.
  """
  del unused_argv  # Unused

  log = tf.logging
  log.set_verbosity(log.INFO)

  # Both lengths must be strictly positive before any setup happens.
  assert FLAGS.seq_len > 0
  assert FLAGS.perm_size > 0

  # Store the vocabulary size on FLAGS and report it.
  FLAGS.n_token = data_utils.VOCAB_SIZE
  n_token_msg = "n_token {}".format(FLAGS.n_token)
  log.info(n_token_msg)

  # Create the model directory when it is not there yet.
  model_dir_exists = tf.gfile.Exists(FLAGS.model_dir)
  if not model_dir_exists:
    tf.gfile.MakeDirs(FLAGS.model_dir)

  # Training input function plus its record metadata.
  input_fn, record_info = get_input_fn("train")
  num_batch_msg = "num of batches {}".format(record_info["num_batch"])
  log.info(num_batch_msg)

  # Cache function, model function and TPU run configuration.
  cache_fn = get_cache_fn(FLAGS.mem_len)
  model_fn = get_model_fn()
  tpu_run_config = model_utils.configure_tpu(FLAGS)

  # Assemble the estimator from the components built above.
  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      train_cache_fn=cache_fn,
      use_tpu=FLAGS.use_tpu,
      config=tpu_run_config,
      params={"track_mean": FLAGS.track_mean},
      train_batch_size=FLAGS.train_batch_size,
      eval_on_tpu=FLAGS.use_tpu)

  # Train until FLAGS.train_steps is reached.
  estimator.train(input_fn=input_fn, max_steps=FLAGS.train_steps)
Example #4
Source File: train.py    From xlnet with Apache License 2.0 4 votes vote down vote up
def main(unused_argv):
  """Run training: validate flags, build the pieces, train a TPUEstimator.

  Args:
    unused_argv: Positional command-line arguments; discarded immediately.
  """
  del unused_argv  # Unused

  log = tf.logging
  log.set_verbosity(log.INFO)

  # Both lengths must be strictly positive before any setup happens.
  assert FLAGS.seq_len > 0
  assert FLAGS.perm_size > 0

  # Store the vocabulary size on FLAGS and report it.
  FLAGS.n_token = data_utils.VOCAB_SIZE
  n_token_msg = "n_token {}".format(FLAGS.n_token)
  log.info(n_token_msg)

  # Create the model directory when it is not there yet.
  model_dir_exists = tf.gfile.Exists(FLAGS.model_dir)
  if not model_dir_exists:
    tf.gfile.MakeDirs(FLAGS.model_dir)

  # Training input function plus its record metadata.
  input_fn, record_info = get_input_fn("train")
  num_batch_msg = "num of batches {}".format(record_info["num_batch"])
  log.info(num_batch_msg)

  # Cache function, model function and TPU run configuration.
  cache_fn = get_cache_fn(FLAGS.mem_len)
  model_fn = get_model_fn()
  tpu_run_config = model_utils.configure_tpu(FLAGS)

  # Assemble the estimator from the components built above.
  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      train_cache_fn=cache_fn,
      use_tpu=FLAGS.use_tpu,
      config=tpu_run_config,
      params={"track_mean": FLAGS.track_mean},
      train_batch_size=FLAGS.train_batch_size,
      eval_on_tpu=FLAGS.use_tpu)

  # Train until FLAGS.train_steps is reached.
  estimator.train(input_fn=input_fn, max_steps=FLAGS.train_steps)