Python data_utils.Dataset() Examples
The following are 5 code examples of data_utils.Dataset().
You may also want to check out all available functions and classes of the data_utils module.
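Most of the examples below follow the same pattern: load a Vocabulary from a token file, then point Dataset at a glob of tokenized text shards, passing deterministic=True when a reproducible read order matters. A minimal sketch, assuming data_utils exposes both classes (the paths here are placeholders, not taken from any of the projects below):

from data_utils import Vocabulary, Dataset

# Placeholder paths -- substitute your own vocabulary file and data shards.
vocab = Vocabulary.from_file("vocab.txt")
train_data = Dataset(vocab, "data/train-shards/*")
eval_data = Dataset(vocab, "data/heldout-shards/*", deterministic=True)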
Example #1
Source File: single_lm_train.py From lm with MIT License
def main(_):
    hps = LM.get_default_hparams().parse(FLAGS.hpconfig)
    hps.num_gpus = FLAGS.num_gpus
    vocab = Vocabulary.from_file("1b_word_vocab.txt")

    if FLAGS.mode == "train":
        hps.batch_size = 256
        dataset = Dataset(vocab, FLAGS.datadir + "/training-monolingual.tokenized.shuffled/*")
        run_train(dataset, hps, FLAGS.logdir + "/train", ps_device="/gpu:0")
    elif FLAGS.mode.startswith("eval_"):
        # Evaluate on either the training shards or a single heldout shard.
        if FLAGS.mode.startswith("eval_train"):
            data_dir = FLAGS.datadir + "/training-monolingual.tokenized.shuffled/*"
        else:
            data_dir = FLAGS.datadir + "/heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050"
        dataset = Dataset(vocab, data_dir, deterministic=True)
        run_eval(dataset, hps, FLAGS.logdir, FLAGS.mode, FLAGS.eval_steps)
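Note that only the evaluation path passes deterministic=True, presumably so that perplexity is measured over a reproducible pass of the heldout shard, while training reads the shuffled shards in arbitrary order.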
Example #2
Source File: data_utils_test.py From lm with MIT License
def test_dataset(self):
    vocab = Vocabulary.from_file("testdata/test_vocab.txt")
    dataset = Dataset(vocab, "testdata/*")

    def generator():
        for i in range(1, 10):
            yield [0] + list(range(1, i + 1)) + [0]

    # Token frequencies produced by the raw generator...
    counts = [0] * 10
    for seq in generator():
        for v in seq:
            counts[v] += 1

    # ...must survive batching through the dataset's iterator.
    counts2 = [0] * 10
    for x, y, w in dataset._iterate(generator(), 2, 4):
        for v in x.ravel():
            counts2[v] += 1

    for i in range(1, 10):
        self.assertEqual(counts[i], counts2[i], "Mismatch at i=%d" % i)
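The test drives the private batching helper directly. Judging from the call site alone (the signature below is inferred, not documented), dataset._iterate(sentences, batch_size, num_steps) yields numpy arrays of shape [batch_size, num_steps]; a sketch of what the loop above consumes:

# Inferred from the call above: batch_size=2, num_steps=4.
# The roles of y and w are assumptions (targets and padding weights).
for x, y, w in dataset._iterate(generator(), 2, 4):
    print(x.shape)  # expected: (2, 4) array of input token ids
    print(y.shape)  # expected: (2, 4) array of next-token targets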
Example #3
Source File: data_utils_test.py From f-lm with MIT License
def test_dataset(self):
    vocab = Vocabulary.from_file("testdata/test_vocab.txt")
    dataset = Dataset(vocab, "testdata/*")

    def generator():
        for i in range(1, 10):
            yield [0] + list(range(1, i + 1)) + [0]

    counts = [0] * 10
    for seq in generator():
        for v in seq:
            counts[v] += 1

    counts2 = [0] * 10
    for x, y in dataset._iterate(generator(), 2, 4):
        for v in x.ravel():
            counts2[v] += 1

    for i in range(1, 10):
        self.assertEqual(
            counts[i], counts2[i],
            "Mismatch at i=%d. counts[i]=%s, counts2[i]=%s" % (i, counts[i], counts2[i]))
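This is the f-lm fork of the same test. Aside from the more informative failure message, the visible difference from Example #2 is that this version of _iterate yields (x, y) pairs rather than (x, y, w) triples.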
Example #4
Source File: single_lm_train.py From f-lm with MIT License
def main(_): """ Start either train or eval. Note hardcoded parts of path for training and eval data """ hps = LM.get_default_hparams().parse(FLAGS.hpconfig) hps._set("num_gpus", FLAGS.num_gpus) print('*****HYPER PARAMETERS*****') print(hps) print('**************************') vocab = Vocabulary.from_file(os.path.join(FLAGS.datadir, "1b_word_vocab.txt")) if FLAGS.mode == "train": #hps.batch_size = 256 dataset = Dataset(vocab, os.path.join(FLAGS.datadir, "training-monolingual.tokenized.shuffled/*")) run_train(dataset, hps, os.path.join(FLAGS.logdir, "train"), ps_device="/gpu:0") elif FLAGS.mode.startswith("eval_"): if FLAGS.mode.startswith("eval_train"): data_dir = os.path.join(FLAGS.datadir, "training-monolingual.tokenized.shuffled/*") elif FLAGS.mode.startswith("eval_full"): data_dir = os.path.join(FLAGS.datadir, "heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050") else: data_dir = os.path.join(FLAGS.datadir, "heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050") dataset = Dataset(vocab, data_dir, deterministic=True) run_eval(dataset, hps, FLAGS.logdir, FLAGS.mode, FLAGS.eval_steps) elif FLAGS.mode.startswith("infer"): data_dir = os.path.join(FLAGS.datadir, "heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050") dataset = Dataset(vocab, data_dir, deterministic=True) run_infer(dataset, hps, FLAGS.logdir, FLAGS.mode, vocab)
Example #5
Source File: main.py From dgm_latent_bow with MIT License
def main():
    # configuration
    config = Config()
    config.parse_arg(FLAGS)
    config.setup_path()
    config.print_arg()

    # dataset
    if config.dataset == 'wikibio':
        dset = DatasetTable2text(config)
        dset.load()
        config.key_size = len(dset.key2id)
    else:
        dset = Dataset(config)
        dset.build()
    config.vocab_size = len(dset.word2id)
    config.dec_start_id = dset.word2id["_GOO"]
    config.dec_end_id = dset.word2id["_EOS"]
    config.pad_id = dset.pad_id
    config.stop_words = dset.stop_words

    # model
    if config.model_name == "seq2seq":
        if config.dataset == 'wikibio':
            Model = Seq2seqData2text
        else:
            Model = Seq2seq
    elif config.model_name == "bow_seq2seq":
        Model = BowSeq2seq
    elif config.model_name == "vae":
        Model = Vae
    elif config.model_name == "hierarchical_vae":
        Model = Hierarchical_Vae
    elif config.model_name == "latent_bow":
        if config.dataset == 'wikibio':
            Model = LatentBowData2text
        else:
            Model = LatentBow
    elif config.model_name == "lm":
        Model = LM
    else:
        msg = "the model name should be in ['seq2seq', 'bow_seq2seq', 'vae', 'hierarchical_vae', 'latent_bow', 'lm'], "
        msg += "current name: %s" % config.model_name
        raise Exception(msg)

    model = Model(config)
    with tf.variable_scope(config.model_name):
        model.build()

    # controller
    controller = Controller(config)
    if config.model_name != "lm":
        if "lm" in controller.eval_metrics_list:
            controller.build_lm(LM, config)
    controller.train(model, dset)
    return
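As a design note, the long if/elif ladder above is a hand-written name-to-class dispatch. A hypothetical table-driven equivalent (not from dgm_latent_bow, and omitting the wikibio-specific variants for brevity) would be:

# Hypothetical refactor of the dispatch above -- not part of the original project.
MODELS = {
    "seq2seq": Seq2seq,
    "bow_seq2seq": BowSeq2seq,
    "vae": Vae,
    "hierarchical_vae": Hierarchical_Vae,
    "latent_bow": LatentBow,
    "lm": LM,
}
try:
    Model = MODELS[config.model_name]
except KeyError:
    raise Exception("the model name should be in %s, current name: %s"
                    % (sorted(MODELS), config.model_name))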