Python dataloader.DataLoader() Examples

The following are 7 code examples of dataloader.DataLoader(). You can go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module dataloader, or try the search function.
Example #1
Source File: cli.py    From QANet_dureader with MIT License
def train(args):
    logger = logging.getLogger("QANet")
    logger.info("====== training ======")

    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)

    dataloader = DataLoader(args.max_p_num, args.max_p_len, args.max_q_len, args.max_ch_len,
                            args.train_files, args.dev_files)

    logger.info('Converting text into ids...')
    dataloader.convert_to_ids(vocab)

    logger.info('Initialize the model...')
    model = Model(vocab, args)

    logger.info('Training the model...')
    model.train(dataloader, args.epochs, args.batch_size, save_dir=args.model_dir, save_prefix=args.algo, dropout=args.dropout)

    logger.info('====== Done with model training! ======') 
Example #2
Source File: cli.py    From QANet_dureader with MIT License
def evaluate(args):
    logger = logging.getLogger("QANet")
    logger.info("====== evaluating ======")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)

    assert len(args.dev_files) > 0, 'No dev files are provided.'
    dataloader = DataLoader(args.max_p_num, args.max_p_len, args.max_q_len, args.max_ch_len, dev_files=args.dev_files)

    logger.info('Converting text into ids...')
    dataloader.convert_to_ids(vocab)

    logger.info('Restoring the model...')
    model = Model(vocab, args)
    model.restore(args.model_dir, args.algo)
    logger.info('Evaluating the model on dev set...')
    dev_batches = dataloader.next_batch('dev', args.batch_size, vocab.get_word_id(vocab.pad_token), vocab.get_char_id(vocab.pad_token), shuffle=False)

    dev_loss, dev_bleu_rouge = model.evaluate(
        dev_batches, result_dir=args.result_dir, result_prefix='dev.predicted')

    logger.info('Loss on dev set: {}'.format(dev_loss))
    logger.info('Result on dev set: {}'.format(dev_bleu_rouge))
    logger.info('Predicted answers are saved to {}'.format(os.path.join(args.result_dir))) 
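As a rough sketch of what the loader above provides, the dev batches returned by next_batch() can also be consumed directly. This continues from the dataloader and vocab objects built in the example above; the batch size of 32 and the assumption that each yielded item is one padded dev batch are illustrative only, since the batch internals are not shown on this page.

pad_id = vocab.get_word_id(vocab.pad_token)
char_pad_id = vocab.get_char_id(vocab.pad_token)

n = 0
for n, batch in enumerate(dataloader.next_batch('dev', 32, pad_id, char_pad_id, shuffle=False), 1):
    pass  # each `batch` is assumed to be one padded dev batch; its fields are not documented here
print('iterated {} dev batches'.format(n))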
Example #3
Source File: cli.py    From QANet_dureader with MIT License
def predict(args):
    logger = logging.getLogger("QANet")

    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)

    assert len(args.test_files) > 0, 'No test files are provided.'
    dataloader = DataLoader(args.max_p_num, args.max_p_len, args.max_q_len, args.max_ch_len,
                            test_files=args.test_files)

    logger.info('Converting text into ids...')
    dataloader.convert_to_ids(vocab)
    logger.info('Restoring the model...')

    model = Model(vocab, args)
    model.restore(args.model_dir, args.algo)
    logger.info('Predicting answers for test set...')
    test_batches = dataloader.next_batch('test', args.batch_size, vocab.get_word_id(vocab.pad_token), vocab.get_char_id(vocab.pad_token), shuffle=False)

    model.evaluate(test_batches,
                      result_dir=args.result_dir, result_prefix='test.predicted') 
Example #4
Source File: main.py    From Fast-SRGAN with MIT License
def main():
    # Parse the CLI arguments.
    args = parser.parse_args()

    # Create a directory for saving trained models.
    if not os.path.exists('models'):
        os.makedirs('models')

    # Create the tensorflow dataset.
    ds = DataLoader(args.image_dir, args.hr_size).dataset(args.batch_size)

    # Initialize the GAN object.
    gan = FastSRGAN(args)

    # Define the directory for saving the pretraining loss tensorboard summary.
    pretrain_summary_writer = tf.summary.create_file_writer('logs/pretrain')

    # Run pre-training.
    pretrain_generator(gan, ds, pretrain_summary_writer)

    # Define the directory for saving the SRGAN training tensorboard summary.
    train_summary_writer = tf.summary.create_file_writer('logs/train')

    # Run training.
    for _ in range(args.epochs):
        train(gan, ds, args.save_iter, train_summary_writer) 
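The comment in main() says that DataLoader(...).dataset(...) builds the TensorFlow dataset that pretraining and training iterate over. Under that assumption, a quick sanity check can peek at one batch from the ds object built above; the internal structure of each batch (for example low/high-resolution image pairs) is not shown on this page, so only the shapes are printed.

import tensorflow as tf

# Inspect the shapes of a single batch from the dataset built in main() above.
for batch in ds.take(1):
    print(tf.nest.map_structure(lambda t: t.shape, batch))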
Example #5
Source File: example.py    From fastNLP with Apache License 2.0
def test(model_dict, using_cuda=True):
    if using_cuda:
        net = Net().cuda()
    else:
        net = Net()
    net.load_state_dict(torch.load(model_dict))
    dataset = dataloader.DataLoader("test_set.pkl", batch_size=1, using_cuda=using_cuda)
    count = 0
    for i, batch in enumerate(dataset):
        X = batch["feature"]
        y = batch["class"]
        y_pred, _ = net(X)
        p, idx = torch.max(y_pred.data, dim=1)
        count += torch.sum(torch.eq(idx.cpu(), y.data.cpu()))
    print("accuracy: %f"%(count / dataset.num)) 
Example #6
Source File: test.py    From cubenet with MIT License
def test(args):
    print('...Building inputs')
    tf.reset_default_graph()

    print('...Connecting data io and preprocessing')
    with tf.device("/cpu:0"):
        with tf.name_scope("IO"):
            test_data = DataLoader(args.test_file, 'test', args.batch_size,
                                    args.height, args.jitter, shuffle=False)
            args.n_classes = test_data.n_classes
            args.data_size = test_data.data_size
            print("Found {} test examples".format(args.data_size))

            test_iterator = test_data.data.make_initializable_iterator()
            test_inputs, test_targets = test_iterator.get_next()
            test_inputs.set_shape([args.batch_size, args.height, args.width, args.depth, 1])
            test_init_op = test_iterator.make_initializer(test_data.data)
    
    # Outputs
    print('...Constructing model')
    with tf.get_default_graph().as_default(): 
        with tf.variable_scope("model", reuse=False):
            model = GVGG(test_inputs, False, args)
            test_logits = model.pred_logits
            test_preds = tf.nn.softmax(test_logits)

            # Prediction loss
            print("...Building metrics")
            preds = tf.to_int32(tf.argmax(test_preds, 1))
            test_accuracy = tf.contrib.metrics.accuracy(preds, test_targets)
            # HACK: Rotation averaging is brittle.
            preds_rot = tf.to_int32(tf.argmax(tf.reduce_mean(test_preds, 0)))
            test_targets_rot = test_targets[0]
            test_accuracy_rot = tf.contrib.metrics.accuracy(preds_rot, test_targets_rot)
    
    with tf.Session() as sess:
        # Load pretrained model, ignoring final layer
        print('...Restore variables')
        tf.global_variables_initializer().run()
        restorer = tf.train.Saver()
        model_path = tf.train.latest_checkpoint(args.save_dir)
        restorer.restore(sess, model_path)

        accuracies = []
        accuracies_rotavg = []
        print("...Testing")

        sess.run([test_init_op])
        for i in range(args.data_size // args.batch_size):
            tacc, tacc_rotavg = sess.run([test_accuracy, test_accuracy_rot])

            accuracies.append(tacc)
            accuracies_rotavg.append(tacc_rotavg)

            sys.stdout.write("[{} | {}] Running acc: {:0.4f}, Running rot acc: {:0.4f}\r".format(i*args.batch_size, args.data_size, np.mean(accuracies), np.mean(accuracies_rotavg)))
            sys.stdout.flush()
            
        print()
        print("Test accuracy: {:04f}".format(np.mean(accuracies)))
        print("Test accuracy rot avg: {:04f}".format(np.mean(accuracies_rotavg)))
        print() 
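The rotation-averaged metric above takes the mean of the softmax outputs over the batch dimension and compares its argmax against the first target, which only makes sense if each batch holds rotated copies of a single example; that assumption is inferred from the code and the HACK comment, not stated on this page. A tiny numeric illustration of the averaging step:

import numpy as np

# Softmax outputs for three rotated copies of one example (made-up numbers).
probs = np.array([[0.6, 0.4],
                  [0.3, 0.7],
                  [0.2, 0.8]])
pred_rot = int(np.argmax(probs.mean(axis=0)))  # average over copies, then argmax -> class 1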
Example #7
Source File: cli.py    From QANet_dureader with MIT License
def prepro(args):
    logger = logging.getLogger("QANet")
    logger.info("====== preprocessing ======")
    logger.info('Checking the data files...')
    for data_path in args.train_files + args.dev_files + args.test_files:
        assert os.path.exists(data_path), '{} file does not exist.'.format(data_path)

    logger.info('Preparing the directories...')
    for dir_path in [args.vocab_dir, args.model_dir, args.result_dir, args.summary_dir]:
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)

    logger.info('Building vocabulary...')
    dataloader = DataLoader(args.max_p_num, args.max_p_len, args.max_q_len, args.max_ch_len,
                            args.train_files, args.dev_files, args.test_files)

    vocab = Vocab(lower=True)
    for word in dataloader.word_iter('train'):
        vocab.add_word(word)
        for ch in word:
            vocab.add_char(ch)

    unfiltered_vocab_size = vocab.word_size()
    vocab.filter_words_by_cnt(min_cnt=2)
    filtered_num = unfiltered_vocab_size - vocab.word_size()
    logger.info('After filtering {} tokens, the final vocab size is {}, char size is {}'.format(
        filtered_num, vocab.word_size(), vocab.char_size()))

    unfiltered_vocab_char_size = vocab.char_size()
    vocab.filter_chars_by_cnt(min_cnt=2)
    filtered_char_num = unfiltered_vocab_char_size - vocab.char_size()
    logger.info('After filtering {} chars, the final char vocab size is {}'.format(
        filtered_char_num, vocab.char_size()))

    logger.info('Assigning embeddings...')
    if args.pretrained_word_path is not None:
        vocab.load_pretrained_word_embeddings(args.pretrained_word_path)
    else:
        vocab.randomly_init_word_embeddings(args.word_embed_size)
    
    if args.pretrained_char_path is not None:
        vocab.load_pretrained_char_embeddings(args.pretrained_char_path)
    else:
        vocab.randomly_init_char_embeddings(args.char_embed_size)

    logger.info('Saving vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'wb') as fout:
        pickle.dump(vocab, fout)

    logger.info('====== Done with preparing! ======')
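
Examples #1, #2, #3, and #7 all come from the same cli.py, and prepro() pickles the vocab.data file that the other three functions load, so a plausible end-to-end flow looks like the sketch below. How args is parsed is not shown on this page, so parse_args() is only a hypothetical stand-in for the project's own argument parser.

args = parse_args()  # hypothetical; built from the project's own CLI definition
prepro(args)         # build the vocab and pickle it into args.vocab_dir
train(args)          # train the model and save checkpoints under args.model_dir
evaluate(args)       # restore the model and score the dev set
predict(args)        # restore the model and write test-set predictions to args.result_dir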