Python chainer.iterators.MultiprocessIterator() Examples

The following are 30 code examples of chainer.iterators.MultiprocessIterator(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module chainer.iterators, or try the search function.
Example #1
Source File: cifar1.py    From imgclsmob with MIT License 6 votes vote down vote up
def get_val_data_iterator(dataset_name,
                          batch_size,
                          num_workers):
    """Build a one-pass validation iterator for a small image dataset.

    Args:
        dataset_name: one of "CIFAR10", "CIFAR100" or "SVHN".
        batch_size: number of examples per batch.
        num_workers: number of worker processes for the iterator.

    Returns:
        Tuple of (MultiprocessIterator, dataset length).

    Raises:
        Exception: if ``dataset_name`` is not recognized.
    """
    if dataset_name == "CIFAR10":
        val_dataset = cifar.get_cifar10()[1]
    elif dataset_name == "CIFAR100":
        val_dataset = cifar.get_cifar100()[1]
    elif dataset_name == "SVHN":
        val_dataset = svhn.get_svhn()[1]
    else:
        raise Exception('Unrecognized dataset: {}'.format(dataset_name))

    val_dataset_len = len(val_dataset)

    # Fixed order, single epoch: appropriate for validation.
    val_iterator = iterators.MultiprocessIterator(
        dataset=val_dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=False,
        n_processes=num_workers,
        shared_mem=300000000)

    return val_iterator, val_dataset_len
Example #2
Source File: cifar1.py    From imgclsmob with MIT License 6 votes vote down vote up
def get_data_iterators(batch_size,
                       num_workers):
    """Build train and validation iterators over preprocessed CIFAR data.

    Args:
        batch_size: number of examples per batch.
        num_workers: number of worker processes per iterator.

    Returns:
        Tuple of (train_iterator, val_iterator).

    NOTE(review): the training iterator uses ``repeat=False`` — presumably
    the caller drives epochs manually and resets it; confirm against the
    caller.
    """
    def _build(train):
        # Only the training split is shuffled; both are single-pass.
        return iterators.MultiprocessIterator(
            dataset=PreprocessedCIFARDataset(train=train),
            batch_size=batch_size,
            repeat=False,
            shuffle=train,
            n_processes=num_workers)

    return _build(True), _build(False)
Example #3
Source File: dataset_utils.py    From imgclsmob with MIT License 6 votes vote down vote up
def get_train_data_source(ds_metainfo,
                          batch_size,
                          num_workers):
    """Create the training data source described by ``ds_metainfo``.

    Returns:
        Dict with keys "iterator" (a shuffled one-pass
        MultiprocessIterator) and "ds_len" (dataset length).
    """
    transform = ds_metainfo.train_transform(ds_metainfo=ds_metainfo)
    dataset = ds_metainfo.dataset_class(
        root=ds_metainfo.root_dir_path,
        mode="train",
        transform=transform)
    # Let the metainfo object pick up attributes derived from the dataset.
    ds_metainfo.update_from_dataset(dataset)
    iterator = MultiprocessIterator(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        repeat=False,
        n_processes=num_workers,
        shared_mem=300000000)
    return {"iterator": iterator, "ds_len": len(dataset)}
Example #4
Source File: dataset_utils.py    From imgclsmob with MIT License 6 votes vote down vote up
def get_test_data_source(ds_metainfo,
                         batch_size,
                         num_workers):
    """Create the test data source described by ``ds_metainfo``.

    Returns:
        Dict with keys "iterator" (a fixed-order one-pass
        MultiprocessIterator) and "ds_len" (dataset length).
    """
    transform = ds_metainfo.test_transform(ds_metainfo=ds_metainfo)
    dataset = ds_metainfo.dataset_class(
        root=ds_metainfo.root_dir_path,
        mode="test",
        transform=transform)
    # Let the metainfo object pick up attributes derived from the dataset.
    ds_metainfo.update_from_dataset(dataset)
    iterator = MultiprocessIterator(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        repeat=False,
        n_processes=num_workers,
        shared_mem=300000000)
    return {"iterator": iterator, "ds_len": len(dataset)}
Example #5
Source File: imagenet1k1.py    From imgclsmob with MIT License 6 votes vote down vote up
def get_val_data_iterator(data_dir,
                          batch_size,
                          num_workers,
                          num_classes):
    """Build a one-pass iterator over the ImageNet-style ``val`` directory.

    Args:
        data_dir: dataset root containing a ``val`` subdirectory.
        batch_size: number of examples per batch.
        num_workers: number of worker processes for the iterator.
        num_classes: expected number of class subdirectories.

    Returns:
        Tuple of (MultiprocessIterator, dataset length).
    """
    val_dir_path = os.path.join(data_dir, 'val')
    val_dataset = DirectoryParsingLabelDataset(val_dir_path)
    val_dataset_len = len(val_dataset)
    # Sanity check: directory-derived class count must match expectation.
    assert len(directory_parsing_label_names(val_dir_path)) == num_classes

    val_iterator = iterators.MultiprocessIterator(
        dataset=val_dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=False,
        n_processes=num_workers,
        shared_mem=300000000)

    return val_iterator, val_dataset_len
Example #6
Source File: test_multiprocess_iterator.py    From chainer with MIT License 6 votes vote down vote up
def test_iterator_not_repeat_not_even(self):
        """One-shot iteration over a dataset not divisible by the batch size.

        With ``repeat=False`` and 5 items at batch size 2, the final batch
        holds the single leftover item and the next call raises
        ``StopIteration``.
        """
        dataset = [1, 2, 3, 4, 5]
        it = iterators.MultiprocessIterator(
            dataset, 2, repeat=False, **self.options)

        # epoch_detail tracks the fraction of the epoch consumed so far;
        # previous_epoch_detail is None before the first batch.
        self.assertAlmostEqual(it.epoch_detail, 0 / 5)
        self.assertIsNone(it.previous_epoch_detail)
        batch1 = it.next()
        self.assertAlmostEqual(it.epoch_detail, 2 / 5)
        self.assertAlmostEqual(it.previous_epoch_detail, 0 / 5)
        batch2 = it.next()
        self.assertAlmostEqual(it.epoch_detail, 4 / 5)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 5)
        batch3 = it.next()
        self.assertAlmostEqual(it.epoch_detail, 5 / 5)
        self.assertAlmostEqual(it.previous_epoch_detail, 4 / 5)
        self.assertRaises(StopIteration, it.next)

        # The last batch carries the odd leftover element; all three
        # batches together cover the dataset exactly once.
        self.assertEqual(len(batch3), 1)
        self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
Example #7
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_stalled_getitem(self):
        """A batch stalling beyond ``dataset_timeout`` must emit TimeoutWarning.

        ``StallingDataset`` (defined elsewhere in this test file) sleeps on
        its ``nth`` item; with timeout < sleep, the batch containing that
        item should trigger the warning but still return every element.
        """
        nth = self.nth
        batch_size = 2
        sleep = 0.5
        timeout = 0.1

        dataset = StallingDataset(nth, sleep)
        it = iterators.MultiprocessIterator(
            dataset, batch_size=batch_size, shuffle=False,
            dataset_timeout=timeout, repeat=False)

        # TimeoutWarning should be issued.
        warning_cls = iterators.MultiprocessIterator.TimeoutWarning
        data = []
        # No warning until the stalling batch
        for i in range(nth // batch_size):
            data.append(it.next())
        # Warning on the stalling batch
        with testing.assert_warns(warning_cls):
            data.append(it.next())
        # Retrieve data until the end
        while True:
            try:
                data.append(it.next())
            except StopIteration:
                break

        # All data must be retrieved, in the original (unshuffled) order.
        assert data == [
            dataset.data[i * batch_size: (i+1) * batch_size]
            for i in range((len(dataset) + batch_size - 1) // batch_size)]
Example #8
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_iterator_pickle_after_init(self):
        """An iterator pickled mid-epoch must resume from the same position."""
        dataset = [1, 2, 3, 4, 5, 6]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)

        self.assertEqual(it.epoch, 0)
        self.assertAlmostEqual(it.epoch_detail, 0 / 6)
        self.assertIsNone(it.previous_epoch_detail)
        batch1 = it.next()
        self.assertEqual(len(batch1), 2)
        self.assertIsInstance(batch1, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 2 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 0 / 6)
        batch2 = it.next()
        self.assertEqual(len(batch2), 2)
        self.assertIsInstance(batch2, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

        # Round-trip through pickle after two batches have been consumed.
        pickled_it = pickle.dumps(it)
        it = pickle.loads(pickled_it)

        # Epoch bookkeeping must survive the round-trip.
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

        # The restored iterator finishes the epoch with the remaining items.
        batch3 = it.next()
        self.assertEqual(len(batch3), 2)
        self.assertIsInstance(batch3, list)
        self.assertTrue(it.is_new_epoch)
        self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
        self.assertAlmostEqual(it.epoch_detail, 6 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 4 / 6)
Example #9
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_iterator_serialize(self):
        """Serialized iterator state must restore into a fresh iterator."""
        dataset = [1, 2, 3, 4, 5, 6]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)

        self.assertEqual(it.epoch, 0)
        self.assertAlmostEqual(it.epoch_detail, 0 / 6)
        self.assertIsNone(it.previous_epoch_detail)
        batch1 = it.next()
        self.assertEqual(len(batch1), 2)
        self.assertIsInstance(batch1, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 2 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 0 / 6)
        batch2 = it.next()
        self.assertEqual(len(batch2), 2)
        self.assertIsInstance(batch2, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

        # Serialize mid-epoch into a plain dict...
        target = dict()
        it.serialize(serializers.DictionarySerializer(target))

        # ...then deserialize into a brand-new iterator instance.
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)
        it.serialize(serializers.NpzDeserializer(target))
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

        # The restored iterator finishes the epoch with the remaining items.
        batch3 = it.next()
        self.assertEqual(len(batch3), 2)
        self.assertIsInstance(batch3, list)
        self.assertTrue(it.is_new_epoch)
        self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
        self.assertAlmostEqual(it.epoch_detail, 6 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, 4 / 6)
Example #10
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_iterator_repeat(self):
        """Repeating iteration: each epoch draws the dataset twice.

        NOTE(review): the epoch_detail assertions use sixths (… / 6) while
        the dataset holds only 3 items, and each "epoch" yields 6 elements
        ([1, 1, 2, 2, 3, 3]) — verify these fractions against the
        iterator's epoch accounting; they look copied from a 6-item
        variant of this test.
        """
        dataset = [1, 2, 3]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)
        for i in range(3):
            self.assertEqual(it.epoch, i)
            self.assertAlmostEqual(it.epoch_detail, i + 0 / 6)
            if i == 0:
                self.assertIsNone(it.previous_epoch_detail)
            else:
                self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
            batch1 = it.next()
            self.assertEqual(len(batch1), 2)
            self.assertIsInstance(batch1, list)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, i + 2 / 6)
            self.assertAlmostEqual(it.previous_epoch_detail, i + 0 / 6)
            batch2 = it.next()
            self.assertEqual(len(batch2), 2)
            self.assertIsInstance(batch2, list)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, i + 4 / 6)
            self.assertAlmostEqual(it.previous_epoch_detail, i + 2 / 6)
            batch3 = it.next()
            self.assertEqual(len(batch3), 2)
            self.assertIsInstance(batch3, list)
            self.assertTrue(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, i + 6 / 6)
            self.assertAlmostEqual(it.previous_epoch_detail, i + 4 / 6)

            # Three batches of two cover the 3-item dataset exactly twice.
            self.assertEqual(
                sorted(batch1 + batch2 + batch3), [1, 1, 2, 2, 3, 3])
Example #11
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_no_same_indices_order_sampler(self):
        """A no-repeat order sampler must yield unique indices per batch."""
        data = [1, 2, 3, 4, 5, 6]
        n_per_batch = 5

        sampler = _NoSameIndicesOrderSampler(n_per_batch)
        it = iterators.MultiprocessIterator(
            data, n_per_batch, order_sampler=sampler)
        for _ in range(5):
            drawn = it.next()
            # Every index within one batch must be distinct.
            self.assertEqual(len(numpy.unique(drawn)), n_per_batch)
Example #12
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_invalid_order_sampler(self):
        """An order sampler returning an invalid order must raise ValueError."""
        data = [1, 2, 3, 4, 5, 6]

        with self.assertRaises(ValueError):
            bad_it = iterators.MultiprocessIterator(
                data, 6, shuffle=None,
                order_sampler=_InvalidOrderSampler())
            # The error is raised when a batch is actually requested.
            bad_it.next()
Example #13
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_finalize_not_deadlock(self):
        """finalize() on a busy multi-process iterator must not deadlock."""
        data = numpy.ones((1000, 1000))
        iterator = iterators.MultiprocessIterator(data, 10, n_processes=4)
        for _ in range(10):
            iterator.next()

        # Run finalize() on a daemon thread so a hang cannot block teardown.
        finalizer = threading.Thread(target=iterator.finalize, daemon=True)
        finalizer.start()
        finalizer.join(5)

        # If the thread is still alive after 5 seconds, finalize deadlocked.
        self.assertFalse(finalizer.is_alive())
Example #14
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_reproduce_same_permutation(self):
        """Two iterators with identically-seeded samplers yield equal batches."""
        data = [1, 2, 3, 4, 5, 6]
        first = iterators.MultiprocessIterator(
            data, 6,
            order_sampler=iterators.ShuffleOrderSampler(
                numpy.random.RandomState(self._seed)))
        second = iterators.MultiprocessIterator(
            data, 6,
            order_sampler=iterators.ShuffleOrderSampler(
                numpy.random.RandomState(self._seed)))
        # Same seed => identical shuffled batches across several epochs.
        for _ in range(5):
            self.assertEqual(first.next(), second.next())
Example #15
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_reset_repeat(self):
        """reset() must cleanly restart a repeating iterator, repeatedly."""
        data = [1, 2, 3, 4]
        it = iterators.MultiprocessIterator(
            data, 2, repeat=True, **self.options)

        for _ in range(4):
            seen = []
            for _ in range(4):
                seen.extend(it.next())
            # Four batches of two cover the 4-item dataset exactly twice.
            self.assertEqual(sorted(seen), sorted(2 * data))
            it.reset()
Example #16
Source File: test_iterator_compatibility.py    From chainer with MIT License 5 votes vote down vote up
def test_iterator_compatibilty(self):
        """Serialized state must transfer between iterator implementations.

        Checks both directions between SerialIterator and
        MultiprocessIterator: consume two batches, serialize, restore into
        the other implementation, and finish the epoch there.

        NOTE(review): the method name misspells 'compatibility'; kept
        as-is to preserve the public test name.
        """
        dataset = [1, 2, 3, 4, 5, 6]

        iters = (
            lambda: iterators.SerialIterator(dataset, 2),
            lambda: iterators.MultiprocessIterator(dataset, 2, **self.options),
        )

        # Try serial -> multiprocess and multiprocess -> serial.
        for it_before, it_after in itertools.permutations(iters, 2):
            it = it_before()

            self.assertEqual(it.epoch, 0)
            self.assertAlmostEqual(it.epoch_detail, 0 / 6)
            batch1 = it.next()
            self.assertEqual(len(batch1), 2)
            self.assertIsInstance(batch1, list)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, 2 / 6)
            batch2 = it.next()
            self.assertEqual(len(batch2), 2)
            self.assertIsInstance(batch2, list)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, 4 / 6)

            # Serialize the first iterator's state into a plain dict.
            target = dict()
            it.serialize(serializers.DictionarySerializer(target))

            # Restore that state into the other implementation.
            it = it_after()
            it.serialize(serializers.NpzDeserializer(target))
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, 4 / 6)

            # The restored iterator finishes the epoch consistently.
            batch3 = it.next()
            self.assertEqual(len(batch3), 2)
            self.assertIsInstance(batch3, list)
            self.assertTrue(it.is_new_epoch)
            self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
            self.assertAlmostEqual(it.epoch_detail, 6 / 6)
Example #17
Source File: evaluator.py    From chainer with MIT License 5 votes vote down vote up
def __init__(self, iterator, target, converter=convert.concat_examples,
                 device=None, eval_hook=None, eval_func=None, **kwargs):
        """Initialize the evaluator.

        Args:
            iterator: a dataset iterator, or a dict mapping names to
                iterators (a bare iterator is wrapped as {'main': it}).
            target: a link to evaluate, or a dict of links (a bare link is
                wrapped as {'main': target}).
            converter: function concatenating examples into a minibatch.
            device: device spec to evaluate on (resolved via
                ``backend.get_device``), or None.
            eval_hook: optional callable invoked before evaluation.
            eval_func: optional callable used instead of the target's
                ``__call__``.
            **kwargs: supports ``progress_bar`` (bool, default False).
        """
        progress_bar, = argument.parse_kwargs(kwargs, ('progress_bar', False))

        if device is not None:
            device = backend.get_device(device)

        if isinstance(iterator, iterator_module.Iterator):
            iterator = {'main': iterator}
        self._iterators = iterator

        if isinstance(target, link.Link):
            target = {'main': target}
        self._targets = target

        self.converter = converter
        self.device = device
        self.eval_hook = eval_hook
        self.eval_func = eval_func

        self._progress_bar = progress_bar

        # Warn about repeating iterators: the evaluator sweeps until the
        # iterator stops, so repeat=True would loop forever.
        for key, it in six.iteritems(iterator):
            if (isinstance(it, (iterators.SerialIterator,
                                iterators.MultiprocessIterator,
                                iterators.MultithreadIterator)) and
                    getattr(it, 'repeat', False)):
                # BUG FIX: the original assigned only the first string
                # fragment to ``msg``; the remaining string literals were
                # standalone expression statements, so the warning text was
                # truncated to an unformatted '{}' and the .format(key)
                # result was discarded. Parenthesize the concatenation so
                # the whole formatted message is warned.
                msg = ('The `repeat` property of the iterator {} '
                       'is set to `True`. Typically, the evaluator sweeps '
                       'over iterators until they stop, '
                       'but as the property being `True`, this iterator '
                       'might not stop and evaluation could go into '
                       'an infinite loop. '
                       'We recommend to check the configuration '
                       'of iterators'.format(key))
                warnings.warn(msg)
Example #18
Source File: eval_imagenet.py    From chainercv with MIT License 5 votes vote down vote up
def main():
    """Command-line entry point: evaluate a convnet on ILSVRC2012 validation data."""
    parser = argparse.ArgumentParser(
        description='Evaluating convnet from ILSVRC2012 dataset')
    parser.add_argument('val', help='Path to root of the validation dataset')
    parser.add_argument('--model', choices=sorted(models.keys()))
    parser.add_argument('--pretrained-model')
    parser.add_argument('--dataset', choices=('imagenet',))
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--batchsize', type=int)
    parser.add_argument('--crop', choices=('center', '10'))
    parser.add_argument('--resnet-arch')
    args = parser.parse_args()

    # setup() (defined elsewhere) resolves dataset, metric function, model
    # and batch size from the CLI options.
    dataset, eval_, model, batchsize = setup(
        args.dataset, args.model, args.pretrained_model, args.batchsize,
        args.val, args.crop, args.resnet_arch)

    # Move the model to the requested GPU, if any (-1 means CPU).
    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()
        model.to_gpu()

    # Single-pass, fixed-order iterator: appropriate for evaluation.
    iterator = iterators.MultiprocessIterator(
        dataset, batchsize, repeat=False, shuffle=False,
        n_processes=6, shared_mem=300000000)

    print('Model has been prepared. Evaluation starts.')
    in_values, out_values, rest_values = apply_to_iterator(
        model.predict, iterator, hook=ProgressHook(len(dataset)))
    # Inputs are not needed for scoring; release them early.
    del in_values

    eval_(out_values, rest_values)
Example #19
Source File: main.py    From SPReID with MIT License 5 votes vote down vote up
def Evaluation():
    """Extract features over the evaluation split and write them to a CSV.

    Iterates once over the evaluation dataset, splits every minibatch
    across the model replicas in ``Model`` (one per device), collects the
    per-replica reporters, and finally concatenates the gathered
    'features' entries and saves them with ``np.savetxt``.
    """
    # Create the evaluation data iterator: single pass, fixed order.
    batch_tuple = MultiprocessIterator(
        DataChef.ReID10D(args, args.project_folder + '/evaluation_list/' + args.eval_split + '.txt',
                         image_size=args.scales_tr[0]),
        args.minibatch, n_prefetch=2, n_processes=args.nb_processes, shared_mem=20000000, repeat=False, shuffle=False)
    # Keep the log in history
    history = {args.dataset: {'features': []}}

    for dataBatch in batch_tuple:
        # BUG FIX: zip() returns a lazy, non-subscriptable iterator on
        # Python 3; materialize it so dataBatch[0] / dataBatch[1] work.
        dataBatch = list(zip(*dataBatch))
        # Prepare batch data: split images/labels evenly across replicas.
        IMG = np.array_split(np.array(dataBatch[0]), len(Model), axis=0)
        LBL = np.array_split(np.array(dataBatch[1]), len(Model), axis=0)
        # Forward pass on each device's shard.
        for device_id, img, lbl in zip(range(len(Model)), IMG, LBL):
            Model[device_id](img, lbl, args.dataset, train=False)
        # Aggregate reporters from all GPUs, clearing each as we go.
        reporters = []
        for i in range(len(Model)):
            reporters.append(Model[i].reporter)
            Model[i].reporter = {}  # clear reporter
        # History. BUG FIX: use setdefault so reporter keys other than
        # 'features' do not raise KeyError on their first occurrence.
        for reporter in reporters:
            for k in reporter[args.dataset].keys():
                history[args.dataset].setdefault(k, []).append(reporter[args.dataset][k])
    # Storing features to an output file.
    features = np.concatenate(history[args.dataset]['features'], axis=0)
    outfile = '%s/evaluation_features/%s_@%s_%s.csv' % (
        args.project_folder, args.dataset, args.checkpoint, args.eval_split)
    np.savetxt(outfile, features, delimiter=',', fmt='%0.12e')
Example #20
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_iterator_pickle_new(self):
        """A freshly constructed iterator must survive a pickle round-trip."""
        data = [1, 2, 3, 4, 5, 6]
        it = iterators.MultiprocessIterator(data, 2, **self.options)

        # Nothing consumed yet: epoch counters are at their initial state.
        self.assertEqual(it.epoch, 0)
        self.assertAlmostEqual(it.epoch_detail, 0 / 6)
        self.assertIsNone(it.previous_epoch_detail)
        # Round-trip through pickle; must not raise.
        it = pickle.loads(pickle.dumps(it))
Example #21
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_unsupported_reset_finalized(self):
        """reset() on a finalized iterator must raise NotImplementedError."""
        data = [1, 2, 3, 4]
        it = iterators.MultiprocessIterator(
            data, 2, repeat=False, **self.options)
        # Consume a couple of batches, then shut the iterator down.
        for _ in range(2):
            it.next()
        it.finalize()
        self.assertRaises(NotImplementedError, it.reset)
Example #22
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_reset(self):
        """A non-repeating iterator can be reset and replayed multiple times."""
        data = [1, 2, 3, 4, 5]
        it = iterators.MultiprocessIterator(
            data, 2, repeat=False, **self.options)

        for _ in range(4):
            seen = []
            for _ in range(3):
                seen.extend(it.next())
            # One pass covers the dataset exactly once.
            self.assertEqual(sorted(seen), data)
            # The exhausted iterator keeps raising StopIteration...
            self.assertRaises(StopIteration, it.next)
            self.assertRaises(StopIteration, it.next)
            # ...until it is reset.
            it.reset()
Example #23
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_copy_not_repeat(self):
        """A shallow copy iterates independently of the original iterator."""
        data = [1, 2, 3, 4, 5]
        it = iterators.MultiprocessIterator(
            data, 2, repeat=False, **self.options)
        clone = copy.copy(it)

        # Drain the original: one full pass, then StopIteration.
        seen = []
        for _ in range(3):
            seen.extend(it.next())
        self.assertEqual(sorted(seen), data)
        for _ in range(2):
            self.assertRaises(StopIteration, it.next)
        it = None  # drop the original; the copy must keep working

        # The copy replays the full pass from its own state.
        seen = []
        for _ in range(3):
            seen.extend(clone.next())
        self.assertEqual(sorted(seen), data)
        for _ in range(2):
            self.assertRaises(StopIteration, clone.next)
Example #24
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_iterator_shuffle_nondivisible(self):
        """Shuffling must reorder epochs when len(dataset) % batch_size != 0."""
        data = list(range(10))
        it = iterators.MultiprocessIterator(
            data, 3, **self.options)
        flat = []
        for _ in range(7):
            flat.extend(it.next())
        # The first two epochs' worth of items should come out in
        # different orders (probabilistic, but overwhelmingly likely).
        self.assertNotEqual(flat[0:10], flat[10:20])
Example #25
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_iterator_shuffle_divisible(self):
        """Two consecutive whole-dataset batches should be shuffled differently."""
        data = list(range(10))
        it = iterators.MultiprocessIterator(
            data, 10, **self.options)
        first = it.next()
        second = it.next()
        # Probabilistic, but overwhelmingly likely to differ.
        self.assertNotEqual(first, second)
Example #26
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_iterator_repeat_not_even(self):
        """Repeating over 5 items at batch size 2: 5 batches == dataset twice."""
        data = [1, 2, 3, 4, 5]
        it = iterators.MultiprocessIterator(data, 2, **self.options)

        seen = []
        for _ in range(5):
            seen.extend(it.next())
        self.assertEqual(sorted(seen), sorted(data * 2))
Example #27
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_iterator_dict_type(self):
        """Batches of dict examples keep dict type and array values intact."""
        dataset = [{i: numpy.zeros((10,)) + i} for i in range(6)]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)
        for i in range(3):
            self.assertEqual(it.epoch, i)
            self.assertAlmostEqual(it.epoch_detail, i)
            if i == 0:
                self.assertIsNone(it.previous_epoch_detail)
            else:
                self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
            batches = {}
            for j in range(3):
                batch = it.next()
                self.assertEqual(len(batch), 2)
                if j != 2:
                    self.assertFalse(it.is_new_epoch)
                else:
                    self.assertTrue(it.is_new_epoch)
                self.assertAlmostEqual(
                    it.epoch_detail, (3 * i + j + 1) * 2 / 6)
                self.assertAlmostEqual(
                    it.previous_epoch_detail, (3 * i + j) * 2 / 6)
                for x in batch:
                    # Each example must still be a single-key dict whose
                    # value is an ndarray.
                    self.assertIsInstance(x, dict)
                    k = tuple(x)[0]
                    v = x[k]
                    self.assertIsInstance(v, numpy.ndarray)
                    batches[k] = v

            # Over one epoch every example appears exactly once, with
            # values numerically unchanged.
            self.assertEqual(len(batches), len(dataset))
            for k, v in six.iteritems(batches):
                x = dataset[k][tuple(dataset[k])[0]]
                numpy.testing.assert_allclose(x, v)
Example #28
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_iterator_tuple_type(self):
        """Batches of tuple examples keep tuple type and array values intact."""
        dataset = [(i, numpy.zeros((10,)) + i) for i in range(6)]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)
        for i in range(3):
            self.assertEqual(it.epoch, i)
            self.assertAlmostEqual(it.epoch_detail, i)
            if i == 0:
                self.assertIsNone(it.previous_epoch_detail)
            else:
                self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
            batches = {}
            for j in range(3):
                batch = it.next()
                self.assertEqual(len(batch), 2)
                if j != 2:
                    self.assertFalse(it.is_new_epoch)
                else:
                    self.assertTrue(it.is_new_epoch)
                self.assertAlmostEqual(
                    it.epoch_detail, (3 * i + j + 1) * 2 / 6)
                self.assertAlmostEqual(
                    it.previous_epoch_detail, (3 * i + j) * 2 / 6)
                for x in batch:
                    # Each example must still be an (index, ndarray) tuple.
                    self.assertIsInstance(x, tuple)
                    self.assertIsInstance(x[1], numpy.ndarray)
                    batches[x[0]] = x[1]

            # Over one epoch every example appears exactly once, with
            # values numerically unchanged.
            self.assertEqual(len(batches), len(dataset))
            for k, v in six.iteritems(batches):
                numpy.testing.assert_allclose(dataset[k][1], v)
Example #29
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_iterator_list_type(self):
        """Batches of list examples keep list type and array values intact."""
        dataset = [[i, numpy.zeros((10,)) + i] for i in range(6)]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)
        for i in range(3):
            self.assertEqual(it.epoch, i)
            self.assertAlmostEqual(it.epoch_detail, i)
            if i == 0:
                self.assertIsNone(it.previous_epoch_detail)
            else:
                self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
            batches = {}
            for j in range(3):
                batch = it.next()
                self.assertEqual(len(batch), 2)
                if j != 2:
                    self.assertFalse(it.is_new_epoch)
                else:
                    self.assertTrue(it.is_new_epoch)
                self.assertAlmostEqual(
                    it.epoch_detail, (3 * i + j + 1) * 2 / 6)
                self.assertAlmostEqual(
                    it.previous_epoch_detail, (3 * i + j) * 2 / 6)
                for x in batch:
                    # Each example must still be an [index, ndarray] list.
                    self.assertIsInstance(x, list)
                    self.assertIsInstance(x[1], numpy.ndarray)
                    batches[x[0]] = x[1]

            # Over one epoch every example appears exactly once, with
            # values numerically unchanged.
            self.assertEqual(len(batches), len(dataset))
            for k, v in six.iteritems(batches):
                numpy.testing.assert_allclose(dataset[k][1], v)
Example #30
Source File: test_multiprocess_iterator.py    From chainer with MIT License 5 votes vote down vote up
def test_iterator_repeat(self):
        """Three epochs of repeated iteration over six items in batches of two.

        Checks epoch, epoch_detail and previous_epoch_detail after every
        batch, and that each epoch's three batches cover the dataset once.
        """
        dataset = [1, 2, 3, 4, 5, 6]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)
        for i in range(3):
            self.assertEqual(it.epoch, i)
            self.assertAlmostEqual(it.epoch_detail, i + 0 / 6)
            if i == 0:
                # No batch drawn yet on the very first epoch.
                self.assertIsNone(it.previous_epoch_detail)
            else:
                self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
            batch1 = it.next()
            self.assertEqual(len(batch1), 2)
            self.assertIsInstance(batch1, list)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, i + 2 / 6)
            self.assertAlmostEqual(it.previous_epoch_detail, i + 0 / 6)
            batch2 = it.next()
            self.assertEqual(len(batch2), 2)
            self.assertIsInstance(batch2, list)
            self.assertFalse(it.is_new_epoch)
            self.assertAlmostEqual(it.epoch_detail, i + 4 / 6)
            self.assertAlmostEqual(it.previous_epoch_detail, i + 2 / 6)
            batch3 = it.next()
            self.assertEqual(len(batch3), 2)
            self.assertIsInstance(batch3, list)
            self.assertTrue(it.is_new_epoch)
            self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
            self.assertAlmostEqual(it.epoch_detail, i + 6 / 6)
            self.assertAlmostEqual(it.previous_epoch_detail, i + 4 / 6)