Python data_utils.get_input_fn() Examples
The following are 4 code examples of data_utils.get_input_fn(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module data_utils, or try the search function.
Example #1
Source File: train.py From embedding-as-service with MIT License | 6 votes
def get_input_fn(split):
  """doc."""
  assert split == "train"
  batch_size = FLAGS.train_batch_size

  input_fn, record_info_dict = data_utils.get_input_fn(
      tfrecord_dir=FLAGS.record_info_dir,
      split=split,
      bsz_per_host=batch_size // FLAGS.num_hosts,
      seq_len=FLAGS.seq_len,
      reuse_len=FLAGS.reuse_len,
      bi_data=FLAGS.bi_data,
      num_hosts=FLAGS.num_hosts,
      num_core_per_host=FLAGS.num_core_per_host,
      perm_size=FLAGS.perm_size,
      mask_alpha=FLAGS.mask_alpha,
      mask_beta=FLAGS.mask_beta,
      uncased=FLAGS.uncased,
      num_passes=FLAGS.num_passes,
      use_bfloat16=FLAGS.use_bfloat16,
      num_predict=FLAGS.num_predict)

  return input_fn, record_info_dict
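The returned pair follows the usual TF1 Estimator contract: input_fn takes a params dict carrying a "batch_size" entry and builds a tf.data.Dataset, while record_info_dict describes the preprocessed TFRecords (Examples #3 and #4 below read its "num_batch" entry). A minimal sketch of consuming the pair by hand, assuming that contract holds for data_utils.get_input_fn (worth verifying against data_utils.py):

# Illustrative only; relies on FLAGS and the get_input_fn wrapper above.
input_fn, record_info_dict = get_input_fn("train")

# Metadata about the generated TFRecords, e.g. batches per epoch.
print("num_batch:", record_info_dict["num_batch"])

# TPUEstimator normally supplies params itself; here we mimic a
# single-host call with the per-host batch size.
dataset = input_fn(params={
    "batch_size": FLAGS.train_batch_size // FLAGS.num_hosts})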
Example #2
Source File: train.py From xlnet with Apache License 2.0 | 6 votes
def get_input_fn(split):
  """doc."""
  assert split == "train"
  batch_size = FLAGS.train_batch_size

  input_fn, record_info_dict = data_utils.get_input_fn(
      tfrecord_dir=FLAGS.record_info_dir,
      split=split,
      bsz_per_host=batch_size // FLAGS.num_hosts,
      seq_len=FLAGS.seq_len,
      reuse_len=FLAGS.reuse_len,
      bi_data=FLAGS.bi_data,
      num_hosts=FLAGS.num_hosts,
      num_core_per_host=FLAGS.num_core_per_host,
      perm_size=FLAGS.perm_size,
      mask_alpha=FLAGS.mask_alpha,
      mask_beta=FLAGS.mask_beta,
      uncased=FLAGS.uncased,
      num_passes=FLAGS.num_passes,
      use_bfloat16=FLAGS.use_bfloat16,
      num_predict=FLAGS.num_predict)

  return input_fn, record_info_dict
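Examples #1 and #2 are byte-for-byte identical; the embedding-as-service copy appears to vendor the original xlnet training script. The one piece of arithmetic both rely on is the host-level batch split: the global train_batch_size is divided by num_hosts before reaching data_utils.get_input_fn. A toy check with hypothetical flag values:

# Hypothetical flag values, for illustration only.
train_batch_size = 128   # global batch across all hosts
num_hosts = 4
num_core_per_host = 8

bsz_per_host = train_batch_size // num_hosts       # 32 examples per host
bsz_per_core = bsz_per_host // num_core_per_host   # 4 examples per core

# Integer division silently drops any remainder, so the global batch
# should divide evenly across hosts (and cores) to avoid a shrunken batch.
assert bsz_per_host * num_hosts == train_batch_size
assert bsz_per_core * num_core_per_host == bsz_per_host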
Example #3
Source File: train.py From embedding-as-service with MIT License | 4 votes
def main(unused_argv):
  del unused_argv  # Unused

  tf.logging.set_verbosity(tf.logging.INFO)

  assert FLAGS.seq_len > 0
  assert FLAGS.perm_size > 0

  FLAGS.n_token = data_utils.VOCAB_SIZE
  tf.logging.info("n_token {}".format(FLAGS.n_token))

  if not tf.gfile.Exists(FLAGS.model_dir):
    tf.gfile.MakeDirs(FLAGS.model_dir)

  # Get train input function
  train_input_fn, train_record_info_dict = get_input_fn("train")
  tf.logging.info("num of batches {}".format(
      train_record_info_dict["num_batch"]))

  # Get train cache function
  train_cache_fn = get_cache_fn(FLAGS.mem_len)

  ##### Get model function
  model_fn = get_model_fn()

  ##### Create TPUEstimator
  # TPU Configuration
  run_config = model_utils.configure_tpu(FLAGS)

  # TPU Estimator
  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      train_cache_fn=train_cache_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      params={"track_mean": FLAGS.track_mean},
      train_batch_size=FLAGS.train_batch_size,
      eval_on_tpu=FLAGS.use_tpu)

  #### Training
  estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)
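main() leans on two helpers defined elsewhere in the same train.py, get_cache_fn and get_model_fn. The cache function initializes the Transformer-XL-style memory that XLNet carries between segments. The sketch below is a plausible reconstruction, not the verbatim source: it assumes one zero tensor of shape [mem_len, batch_size, d_model] per layer, and that FLAGS.n_layer, FLAGS.d_model, and FLAGS.use_bfloat16 exist alongside the flags already referenced above.

def get_cache_fn(mem_len):
  """Hypothetical reconstruction of the cache helper used in main()."""
  # Keep the memory in bfloat16 when training with use_bfloat16,
  # mirroring the flag passed to data_utils.get_input_fn above.
  tf_float = tf.bfloat16 if FLAGS.use_bfloat16 else tf.float32

  def cache_fn(batch_size):
    mems = []
    if mem_len > 0:
      for _ in range(FLAGS.n_layer):
        # One zero-initialized memory tensor per transformer layer.
        mems.append(
            tf.zeros([mem_len, batch_size, FLAGS.d_model], dtype=tf_float))
    return mems

  return cache_fn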
Example #4
Source File: train.py From xlnet with Apache License 2.0 | 4 votes
def main(unused_argv):
  del unused_argv  # Unused

  tf.logging.set_verbosity(tf.logging.INFO)

  assert FLAGS.seq_len > 0
  assert FLAGS.perm_size > 0

  FLAGS.n_token = data_utils.VOCAB_SIZE
  tf.logging.info("n_token {}".format(FLAGS.n_token))

  if not tf.gfile.Exists(FLAGS.model_dir):
    tf.gfile.MakeDirs(FLAGS.model_dir)

  # Get train input function
  train_input_fn, train_record_info_dict = get_input_fn("train")
  tf.logging.info("num of batches {}".format(
      train_record_info_dict["num_batch"]))

  # Get train cache function
  train_cache_fn = get_cache_fn(FLAGS.mem_len)

  ##### Get model function
  model_fn = get_model_fn()

  ##### Create TPUEstimator
  # TPU Configuration
  run_config = model_utils.configure_tpu(FLAGS)

  # TPU Estimator
  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      train_cache_fn=train_cache_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      params={"track_mean": FLAGS.track_mean},
      train_batch_size=FLAGS.train_batch_size,
      eval_on_tpu=FLAGS.use_tpu)

  #### Training
  estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)
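Everything in main() is driven by FLAGS, so running the script is mostly flag plumbing. A hypothetical smoke-test configuration, with flag names taken from the examples above and placeholder values (adjust to your data and hardware):

# Placeholder values, for illustration only.
FLAGS.record_info_dir = "/path/to/tfrecords"  # output of the preprocessing step
FLAGS.model_dir = "/tmp/xlnet_pretrain"
FLAGS.seq_len = 512
FLAGS.reuse_len = 256
FLAGS.perm_size = 256
FLAGS.mem_len = 384
FLAGS.train_batch_size = 32
FLAGS.num_hosts = 1
FLAGS.num_core_per_host = 1
FLAGS.train_steps = 1000
FLAGS.use_tpu = False  # fall back to CPU/GPU execution of the TPUEstimator

# Normally invoked via tf.app.run(main); called directly here for brevity.
main(None)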