Python chainer.iterators.MultiprocessIterator() Examples
The following are 30
code examples of chainer.iterators.MultiprocessIterator().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module chainer.iterators, or try the search function.
Example #1
Source File: cifar1.py From imgclsmob with MIT License | 6 votes |
def get_val_data_iterator(dataset_name, batch_size, num_workers):
    """Build a non-shuffling validation iterator for a named dataset.

    Returns a tuple ``(val_iterator, val_dataset_len)``. Raises ``Exception``
    for an unrecognized dataset name.
    """
    dataset_loaders = {
        "CIFAR10": cifar.get_cifar10,
        "CIFAR100": cifar.get_cifar100,
        "SVHN": svhn.get_svhn,
    }
    if dataset_name not in dataset_loaders:
        raise Exception('Unrecognized dataset: {}'.format(dataset_name))
    # Each loader returns (train, test); validation uses the test split.
    _, val_dataset = dataset_loaders[dataset_name]()
    val_iterator = iterators.MultiprocessIterator(
        dataset=val_dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=False,
        n_processes=num_workers,
        shared_mem=300000000)
    return val_iterator, len(val_dataset)
Example #2
Source File: cifar1.py From imgclsmob with MIT License | 6 votes |
def get_data_iterators(batch_size, num_workers):
    """Build train/val iterators over the preprocessed CIFAR dataset.

    NOTE(review): the training iterator is created with ``repeat=False``,
    which makes it a single-pass iterator — confirm this is intended for
    training use.
    """
    def _make_iterator(train):
        # Only the training split is shuffled.
        return iterators.MultiprocessIterator(
            dataset=PreprocessedCIFARDataset(train=train),
            batch_size=batch_size,
            repeat=False,
            shuffle=train,
            n_processes=num_workers)

    return _make_iterator(True), _make_iterator(False)
Example #3
Source File: dataset_utils.py From imgclsmob with MIT License | 6 votes |
def get_train_data_source(ds_metainfo, batch_size, num_workers):
    """Build the training data source: a shuffling iterator plus its length.

    Returns a dict with keys ``"iterator"`` and ``"ds_len"``.
    """
    train_transform = ds_metainfo.train_transform(ds_metainfo=ds_metainfo)
    train_dataset = ds_metainfo.dataset_class(
        root=ds_metainfo.root_dir_path,
        mode="train",
        transform=train_transform)
    # Let the metainfo object record dataset-derived attributes.
    ds_metainfo.update_from_dataset(train_dataset)
    train_iterator = MultiprocessIterator(
        dataset=train_dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=True,
        n_processes=num_workers,
        shared_mem=300000000)
    return {
        "iterator": train_iterator,
        "ds_len": len(train_dataset),
    }
Example #4
Source File: dataset_utils.py From imgclsmob with MIT License | 6 votes |
def get_test_data_source(ds_metainfo, batch_size, num_workers):
    """Build the test data source: a non-shuffling iterator plus its length.

    Returns a dict with keys ``"iterator"`` and ``"ds_len"``.
    """
    test_transform = ds_metainfo.test_transform(ds_metainfo=ds_metainfo)
    test_dataset = ds_metainfo.dataset_class(
        root=ds_metainfo.root_dir_path,
        mode="test",
        transform=test_transform)
    # Let the metainfo object record dataset-derived attributes.
    ds_metainfo.update_from_dataset(test_dataset)
    test_iterator = MultiprocessIterator(
        dataset=test_dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=False,
        n_processes=num_workers,
        shared_mem=300000000)
    return {
        "iterator": test_iterator,
        "ds_len": len(test_dataset),
    }
Example #5
Source File: imagenet1k1.py From imgclsmob with MIT License | 6 votes |
def get_val_data_iterator(data_dir, batch_size, num_workers, num_classes):
    """Build a validation iterator over an ImageNet-style directory tree.

    Returns a tuple ``(val_iterator, val_dataset_len)``.
    """
    val_dir_path = os.path.join(data_dir, 'val')
    val_dataset = DirectoryParsingLabelDataset(val_dir_path)
    # Sanity-check that the directory layout yields the expected class count.
    assert len(directory_parsing_label_names(val_dir_path)) == num_classes
    val_iterator = iterators.MultiprocessIterator(
        dataset=val_dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=False,
        n_processes=num_workers,
        shared_mem=300000000)
    return val_iterator, len(val_dataset)
Example #6
Source File: test_multiprocess_iterator.py From chainer with MIT License | 6 votes |
def test_iterator_not_repeat_not_even(self):
    """Non-repeating iterator: 5 items, batch 2 -> final batch holds 1 item."""
    dataset = [1, 2, 3, 4, 5]
    it = iterators.MultiprocessIterator(
        dataset, 2, repeat=False, **self.options)
    self.assertAlmostEqual(it.epoch_detail, 0 / 5)
    self.assertIsNone(it.previous_epoch_detail)
    batches = []
    # (previously consumed, consumed after this batch) element counts.
    for prev, consumed in ((0, 2), (2, 4), (4, 5)):
        batches.append(it.next())
        self.assertAlmostEqual(it.epoch_detail, consumed / 5)
        self.assertAlmostEqual(it.previous_epoch_detail, prev / 5)
    self.assertRaises(StopIteration, it.next)
    # The odd element ends up alone in the last batch.
    self.assertEqual(len(batches[-1]), 1)
    self.assertEqual(sorted(sum(batches, [])), dataset)
Example #7
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_stalled_getitem(self):
    """A batch exceeding dataset_timeout warns but all data is still yielded."""
    nth = self.nth
    batch_size = 2
    dataset = StallingDataset(nth, 0.5)  # sleep (0.5s) > timeout (0.1s)
    it = iterators.MultiprocessIterator(
        dataset, batch_size=batch_size, shuffle=False,
        dataset_timeout=0.1, repeat=False)

    fetched = []
    # Batches before the stalling item come through without a warning.
    for _ in range(nth // batch_size):
        fetched.append(it.next())
    # The stalling batch must emit a TimeoutWarning.
    with testing.assert_warns(iterators.MultiprocessIterator.TimeoutWarning):
        fetched.append(it.next())
    # Drain whatever remains.
    while True:
        try:
            fetched.append(it.next())
        except StopIteration:
            break
    # Every element must have been retrieved, in order.
    n_batches = (len(dataset) + batch_size - 1) // batch_size
    assert fetched == [
        dataset.data[k * batch_size: (k + 1) * batch_size]
        for k in range(n_batches)]
Example #8
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_iterator_pickle_after_init(self):
    """Pickling mid-epoch preserves position and epoch bookkeeping."""
    dataset = [1, 2, 3, 4, 5, 6]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    self.assertEqual(it.epoch, 0)
    self.assertAlmostEqual(it.epoch_detail, 0 / 6)
    self.assertIsNone(it.previous_epoch_detail)

    def _take_and_check(detail, prev_detail, new_epoch):
        # Fetch one batch and verify size, type and epoch counters.
        batch = it.next()
        self.assertEqual(len(batch), 2)
        self.assertIsInstance(batch, list)
        self.assertEqual(it.is_new_epoch, new_epoch)
        self.assertAlmostEqual(it.epoch_detail, detail)
        self.assertAlmostEqual(it.previous_epoch_detail, prev_detail)
        return batch

    batch1 = _take_and_check(2 / 6, 0 / 6, False)
    batch2 = _take_and_check(4 / 6, 2 / 6, False)

    # Round-trip through pickle; state must be intact.
    it = pickle.loads(pickle.dumps(it))
    self.assertFalse(it.is_new_epoch)
    self.assertAlmostEqual(it.epoch_detail, 4 / 6)
    self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

    batch3 = _take_and_check(6 / 6, 4 / 6, True)
    self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
Example #9
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_iterator_serialize(self):
    """Serializing then deserializing into a fresh iterator preserves state."""
    dataset = [1, 2, 3, 4, 5, 6]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    self.assertEqual(it.epoch, 0)
    self.assertAlmostEqual(it.epoch_detail, 0 / 6)
    self.assertIsNone(it.previous_epoch_detail)

    def _take_and_check(detail, prev_detail, new_epoch):
        # Fetch one batch and verify size, type and epoch counters.
        batch = it.next()
        self.assertEqual(len(batch), 2)
        self.assertIsInstance(batch, list)
        self.assertEqual(it.is_new_epoch, new_epoch)
        self.assertAlmostEqual(it.epoch_detail, detail)
        self.assertAlmostEqual(it.previous_epoch_detail, prev_detail)
        return batch

    batch1 = _take_and_check(2 / 6, 0 / 6, False)
    batch2 = _take_and_check(4 / 6, 2 / 6, False)

    # Hand the state to a brand-new iterator via the serializer protocol.
    target = dict()
    it.serialize(serializers.DictionarySerializer(target))
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    it.serialize(serializers.NpzDeserializer(target))
    self.assertFalse(it.is_new_epoch)
    self.assertAlmostEqual(it.epoch_detail, 4 / 6)
    self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)

    batch3 = _take_and_check(6 / 6, 4 / 6, True)
    self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
Example #10
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_iterator_repeat(self):
    """A repeating iterator covers the dataset once per epoch, for 3 epochs.

    FIX: the original used ``dataset = [1, 2, 3]`` while every assertion was
    written for a 6-element dataset (epoch_detail denominators of 6, three
    2-element batches per epoch, ``is_new_epoch`` only on the third batch,
    and a final expectation of two copies of ``[1, 2, 3]``) — internally
    inconsistent, so the test could not pass. Restored the 6-element dataset
    that the assertions require, matching the sibling test of the same name.
    """
    dataset = [1, 2, 3, 4, 5, 6]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    for i in range(3):
        self.assertEqual(it.epoch, i)
        self.assertAlmostEqual(it.epoch_detail, i + 0 / 6)
        if i == 0:
            # No previous epoch before the first batch is ever drawn.
            self.assertIsNone(it.previous_epoch_detail)
        else:
            self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
        batch1 = it.next()
        self.assertEqual(len(batch1), 2)
        self.assertIsInstance(batch1, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, i + 2 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, i + 0 / 6)
        batch2 = it.next()
        self.assertEqual(len(batch2), 2)
        self.assertIsInstance(batch2, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, i + 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, i + 2 / 6)
        batch3 = it.next()
        self.assertEqual(len(batch3), 2)
        self.assertIsInstance(batch3, list)
        # The third batch completes the epoch.
        self.assertTrue(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, i + 6 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, i + 4 / 6)
        self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
Example #11
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_no_same_indices_order_sampler(self):
    """The custom order sampler must yield unique indices within each batch."""
    dataset = [1, 2, 3, 4, 5, 6]
    batchsize = 5
    sampler = _NoSameIndicesOrderSampler(batchsize)
    it = iterators.MultiprocessIterator(
        dataset, batchsize, order_sampler=sampler)
    for _ in range(5):
        # A batch with no duplicates has exactly batchsize unique values.
        self.assertEqual(len(numpy.unique(it.next())), batchsize)
Example #12
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_invalid_order_sampler(self):
    """An order sampler producing an invalid order raises ValueError."""
    dataset = [1, 2, 3, 4, 5, 6]
    with self.assertRaises(ValueError):
        iterators.MultiprocessIterator(
            dataset, 6, shuffle=None,
            order_sampler=_InvalidOrderSampler()).next()
Example #13
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_finalize_not_deadlock(self):
    """finalize() called from another thread must return, not deadlock."""
    dataset = numpy.ones((1000, 1000))
    it = iterators.MultiprocessIterator(dataset, 10, n_processes=4)
    for _ in range(10):
        it.next()
    finalizer = threading.Thread(target=it.finalize)
    finalizer.daemon = True
    finalizer.start()
    finalizer.join(5)  # allow up to 5 seconds for finalize to complete
    # If the thread is still alive, finalize() deadlocked.
    self.assertFalse(finalizer.is_alive())
Example #14
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_reproduce_same_permutation(self):
    """Two iterators seeded identically yield identical batch sequences."""
    dataset = [1, 2, 3, 4, 5, 6]

    def _seeded_iterator():
        # A fresh RandomState with the same seed per iterator.
        sampler = iterators.ShuffleOrderSampler(
            numpy.random.RandomState(self._seed))
        return iterators.MultiprocessIterator(
            dataset, 6, order_sampler=sampler)

    it1 = _seeded_iterator()
    it2 = _seeded_iterator()
    for _ in range(5):
        self.assertEqual(it1.next(), it2.next())
Example #15
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_reset_repeat(self):
    """reset() restarts a repeating iterator cleanly, trial after trial."""
    dataset = [1, 2, 3, 4]
    it = iterators.MultiprocessIterator(
        dataset, 2, repeat=True, **self.options)
    for _ in range(4):
        collected = []
        for _ in range(4):
            collected.extend(it.next())
        # Four batches of two cover the 4-element dataset exactly twice.
        self.assertEqual(sorted(collected), sorted(2 * dataset))
        it.reset()
Example #16
Source File: test_iterator_compatibility.py From chainer with MIT License | 5 votes |
def test_iterator_compatibilty(self):
    """Serialized state transfers between Serial and Multiprocess iterators.

    Runs both hand-over directions via permutations of the two factories.
    """
    dataset = [1, 2, 3, 4, 5, 6]
    factories = (
        lambda: iterators.SerialIterator(dataset, 2),
        lambda: iterators.MultiprocessIterator(dataset, 2, **self.options),
    )
    for make_before, make_after in itertools.permutations(factories, 2):
        it = make_before()
        self.assertEqual(it.epoch, 0)
        self.assertAlmostEqual(it.epoch_detail, 0 / 6)

        def _take_and_check(detail, new_epoch):
            batch = it.next()
            self.assertEqual(len(batch), 2)
            self.assertIsInstance(batch, list)
            self.assertEqual(it.is_new_epoch, new_epoch)
            self.assertAlmostEqual(it.epoch_detail, detail)
            return batch

        batch1 = _take_and_check(2 / 6, False)
        batch2 = _take_and_check(4 / 6, False)

        # Hand the state over to the other iterator implementation.
        target = dict()
        it.serialize(serializers.DictionarySerializer(target))
        it = make_after()
        it.serialize(serializers.NpzDeserializer(target))
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, 4 / 6)

        batch3 = _take_and_check(6 / 6, True)
        self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
Example #17
Source File: evaluator.py From chainer with MIT License | 5 votes |
def __init__(self, iterator, target, converter=convert.concat_examples,
             device=None, eval_hook=None, eval_func=None, **kwargs):
    """Initialize the evaluator.

    Args:
        iterator: An iterator, or a dict mapping names to iterators.
        target: A link, or a dict mapping names to links.
        converter: Callable that builds a minibatch from examples.
        device: Device spec to evaluate on, or None.
        eval_hook: Optional callable invoked before evaluation.
        eval_func: Optional callable applied to each minibatch.
        **kwargs: May contain ``progress_bar`` (bool, default False).
    """
    progress_bar, = argument.parse_kwargs(kwargs, ('progress_bar', False))

    if device is not None:
        device = backend.get_device(device)

    # Normalize both the iterator and the target to name-keyed dicts.
    if isinstance(iterator, iterator_module.Iterator):
        iterator = {'main': iterator}
    self._iterators = iterator

    if isinstance(target, link.Link):
        target = {'main': target}
    self._targets = target

    self.converter = converter
    self.device = device
    self.eval_hook = eval_hook
    self.eval_func = eval_func
    self._progress_bar = progress_bar

    # Evaluation sweeps each iterator until StopIteration, so a repeating
    # iterator would never terminate — warn about that configuration.
    known_iterator_types = (iterators.SerialIterator,
                            iterators.MultiprocessIterator,
                            iterators.MultithreadIterator)
    for key, it in six.iteritems(iterator):
        if (isinstance(it, known_iterator_types)
                and getattr(it, 'repeat', False)):
            warnings.warn(
                ('The `repeat` property of the iterator {} is set to '
                 '`True`. Typically, the evaluator sweeps over iterators '
                 'until they stop, but as the property being `True`, this '
                 'iterator might not stop and evaluation could go into '
                 'an infinite loop. We recommend to check the configuration '
                 'of iterators').format(key))
Example #18
Source File: eval_imagenet.py From chainercv with MIT License | 5 votes |
def main():
    """Evaluate a pretrained convnet on the ILSVRC2012 validation set."""
    parser = argparse.ArgumentParser(
        description='Evaluating convnet from ILSVRC2012 dataset')
    parser.add_argument('val', help='Path to root of the validation dataset')
    parser.add_argument('--model', choices=sorted(models.keys()))
    parser.add_argument('--pretrained-model')
    parser.add_argument('--dataset', choices=('imagenet',))
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--batchsize', type=int)
    parser.add_argument('--crop', choices=('center', '10'))
    parser.add_argument('--resnet-arch')
    args = parser.parse_args()

    dataset, eval_, model, batchsize = setup(
        args.dataset, args.model, args.pretrained_model, args.batchsize,
        args.val, args.crop, args.resnet_arch)

    # Move the model to the requested GPU, if any.
    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()
        model.to_gpu()

    iterator = iterators.MultiprocessIterator(
        dataset, batchsize, repeat=False, shuffle=False,
        n_processes=6, shared_mem=300000000)

    print('Model has been prepared. Evaluation starts.')
    in_values, out_values, rest_values = apply_to_iterator(
        model.predict, iterator, hook=ProgressHook(len(dataset)))
    # Inputs are not needed for scoring; release them early.
    del in_values

    eval_(out_values, rest_values)
Example #19
Source File: main.py From SPReID with MIT License | 5 votes |
def Evaluation():
    """Extract ReID features over the evaluation split and save them as CSV.

    Relies on module-level ``args``, ``DataChef`` and ``Model`` (one model
    replica per GPU). Features from all minibatches are concatenated and
    written to ``evaluation_features/<dataset>_@<checkpoint>_<split>.csv``.
    """
    # Create data generator.
    batch_tuple = MultiprocessIterator(
        DataChef.ReID10D(args, args.project_folder + '/evaluation_list/' + args.eval_split + '.txt', image_size=args.scales_tr[0]),
        args.minibatch, n_prefetch=2, n_processes=args.nb_processes,
        shared_mem=20000000, repeat=False, shuffle=False)
    # Keep the log in history.
    history = {args.dataset: {'features': []}}
    for dataBatch in batch_tuple:
        # FIX: materialize the transposed batch. On Python 3, ``zip`` returns
        # an iterator that is not subscriptable, so ``dataBatch[0]`` below
        # would raise TypeError; ``list(...)`` is a no-op change on Python 2.
        dataBatch = list(zip(*dataBatch))
        # Prepare batch data: one shard of images/labels per model replica.
        IMG = np.array_split(np.array(dataBatch[0]), len(Model), axis=0)
        LBL = np.array_split(np.array(dataBatch[1]), len(Model), axis=0)
        # Forward pass on each replica (inference mode).
        for device_id, img, lbl in zip(range(len(Model)), IMG, LBL):
            Model[device_id](img, lbl, args.dataset, train=False)
        # Aggregate reporters from all GPUs.
        reporters = []
        for i in range(len(Model)):
            reporters.append(Model[i].reporter)
            Model[i].reporter = {}  # clear reporter
        # Accumulate reported values into the history.
        # NOTE(review): assumes every reporter key already exists in
        # ``history[args.dataset]`` (only 'features' is pre-created) — confirm.
        for reporter in reporters:
            for k in reporter[args.dataset].keys():
                history[args.dataset][k].append(reporter[args.dataset][k])
    # Store features to an output file.
    features = np.concatenate(history[args.dataset]['features'], axis=0)
    outfile = '%s/evaluation_features/%s_@%s_%s.csv' % (
        args.project_folder, args.dataset, args.checkpoint, args.eval_split)
    np.savetxt(outfile, features, delimiter=',', fmt='%0.12e')
Example #20
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_iterator_pickle_new(self):
    """A freshly created iterator can be pickled and restored."""
    dataset = [1, 2, 3, 4, 5, 6]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    # Initial state: no epoch progress yet.
    self.assertEqual(it.epoch, 0)
    self.assertAlmostEqual(it.epoch_detail, 0 / 6)
    self.assertIsNone(it.previous_epoch_detail)
    # Round-trip through pickle before any batch is drawn.
    it = pickle.loads(pickle.dumps(it))
Example #21
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_unsupported_reset_finalized(self):
    """Calling reset() after finalize() raises NotImplementedError."""
    it = iterators.MultiprocessIterator(
        [1, 2, 3, 4], 2, repeat=False, **self.options)
    it.next()
    it.next()
    it.finalize()
    with self.assertRaises(NotImplementedError):
        it.reset()
Example #22
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_reset(self):
    """reset() restarts a non-repeating iterator for another full pass."""
    dataset = [1, 2, 3, 4, 5]
    it = iterators.MultiprocessIterator(
        dataset, 2, repeat=False, **self.options)
    for _ in range(4):
        seen = []
        for _ in range(3):
            seen.extend(it.next())
        self.assertEqual(sorted(seen), dataset)
        # Once exhausted, further calls must keep raising StopIteration.
        for _ in range(2):
            self.assertRaises(StopIteration, it.next)
        it.reset()
Example #23
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_copy_not_repeat(self):
    """A shallow copy iterates independently and outlives the original."""
    dataset = [1, 2, 3, 4, 5]
    it = iterators.MultiprocessIterator(
        dataset, 2, repeat=False, **self.options)
    copy_it = copy.copy(it)

    def _drain_and_check(iterator):
        # One full pass yields the whole dataset, then StopIteration.
        collected = []
        for _ in range(3):
            collected.extend(iterator.next())
        self.assertEqual(sorted(collected), dataset)
        for _ in range(2):
            self.assertRaises(StopIteration, iterator.next)

    _drain_and_check(it)
    it = None  # drop the original; the copy must still work
    _drain_and_check(copy_it)
Example #24
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_iterator_shuffle_nondivisible(self):
    """Shuffling changes the visit order across epochs (len % batch != 0)."""
    dataset = list(range(10))
    it = iterators.MultiprocessIterator(dataset, 3, **self.options)
    seen = []
    # 7 batches of 3 cover (just over) two epochs of 10 elements.
    for _ in range(7):
        seen.extend(it.next())
    self.assertNotEqual(seen[0:10], seen[10:20])
Example #25
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_iterator_shuffle_divisible(self):
    """Two consecutive whole-dataset batches come in different orders."""
    it = iterators.MultiprocessIterator(
        list(range(10)), 10, **self.options)
    self.assertNotEqual(it.next(), it.next())
Example #26
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_iterator_repeat_not_even(self):
    """Five batches of two cover a 5-element dataset exactly twice."""
    dataset = [1, 2, 3, 4, 5]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    seen = []
    for _ in range(5):
        seen.extend(it.next())
    self.assertEqual(sorted(seen), sorted(dataset * 2))
Example #27
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_iterator_dict_type(self):
    """Dict examples pass through the iterator unchanged, over 3 epochs."""
    dataset = [{i: numpy.zeros((10,)) + i} for i in range(6)]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    for i in range(3):
        self.assertEqual(it.epoch, i)
        self.assertAlmostEqual(it.epoch_detail, i)
        if i == 0:
            self.assertIsNone(it.previous_epoch_detail)
        else:
            self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
        collected = {}
        for j in range(3):
            batch = it.next()
            self.assertEqual(len(batch), 2)
            # Only the epoch-closing (third) batch flips is_new_epoch.
            self.assertEqual(it.is_new_epoch, j == 2)
            self.assertAlmostEqual(
                it.epoch_detail, (3 * i + j + 1) * 2 / 6)
            self.assertAlmostEqual(
                it.previous_epoch_detail, (3 * i + j) * 2 / 6)
            for example in batch:
                self.assertIsInstance(example, dict)
                key = tuple(example)[0]
                value = example[key]
                self.assertIsInstance(value, numpy.ndarray)
                collected[key] = value
        # Every example appeared exactly once this epoch, values intact.
        self.assertEqual(len(collected), len(dataset))
        for key, value in six.iteritems(collected):
            expected = dataset[key][tuple(dataset[key])[0]]
            numpy.testing.assert_allclose(expected, value)
Example #28
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_iterator_tuple_type(self):
    """Tuple examples pass through the iterator unchanged, over 3 epochs."""
    dataset = [(i, numpy.zeros((10,)) + i) for i in range(6)]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    for i in range(3):
        self.assertEqual(it.epoch, i)
        self.assertAlmostEqual(it.epoch_detail, i)
        if i == 0:
            self.assertIsNone(it.previous_epoch_detail)
        else:
            self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
        collected = {}
        for j in range(3):
            batch = it.next()
            self.assertEqual(len(batch), 2)
            # Only the epoch-closing (third) batch flips is_new_epoch.
            self.assertEqual(it.is_new_epoch, j == 2)
            self.assertAlmostEqual(
                it.epoch_detail, (3 * i + j + 1) * 2 / 6)
            self.assertAlmostEqual(
                it.previous_epoch_detail, (3 * i + j) * 2 / 6)
            for example in batch:
                self.assertIsInstance(example, tuple)
                self.assertIsInstance(example[1], numpy.ndarray)
                collected[example[0]] = example[1]
        # Every example appeared exactly once this epoch, values intact.
        self.assertEqual(len(collected), len(dataset))
        for key, value in six.iteritems(collected):
            numpy.testing.assert_allclose(dataset[key][1], value)
Example #29
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_iterator_list_type(self):
    """List examples pass through the iterator unchanged, over 3 epochs."""
    dataset = [[i, numpy.zeros((10,)) + i] for i in range(6)]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    for i in range(3):
        self.assertEqual(it.epoch, i)
        self.assertAlmostEqual(it.epoch_detail, i)
        if i == 0:
            self.assertIsNone(it.previous_epoch_detail)
        else:
            self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
        collected = {}
        for j in range(3):
            batch = it.next()
            self.assertEqual(len(batch), 2)
            # Only the epoch-closing (third) batch flips is_new_epoch.
            self.assertEqual(it.is_new_epoch, j == 2)
            self.assertAlmostEqual(
                it.epoch_detail, (3 * i + j + 1) * 2 / 6)
            self.assertAlmostEqual(
                it.previous_epoch_detail, (3 * i + j) * 2 / 6)
            for example in batch:
                self.assertIsInstance(example, list)
                self.assertIsInstance(example[1], numpy.ndarray)
                collected[example[0]] = example[1]
        # Every example appeared exactly once this epoch, values intact.
        self.assertEqual(len(collected), len(dataset))
        for key, value in six.iteritems(collected):
            numpy.testing.assert_allclose(dataset[key][1], value)
Example #30
Source File: test_multiprocess_iterator.py From chainer with MIT License | 5 votes |
def test_iterator_repeat(self):
    """A repeating iterator covers the dataset once per epoch, 3 epochs."""
    dataset = [1, 2, 3, 4, 5, 6]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    for i in range(3):
        self.assertEqual(it.epoch, i)
        self.assertAlmostEqual(it.epoch_detail, i + 0 / 6)
        if i == 0:
            self.assertIsNone(it.previous_epoch_detail)
        else:
            self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
        epoch_elements = []
        for j in range(3):
            batch = it.next()
            self.assertEqual(len(batch), 2)
            self.assertIsInstance(batch, list)
            # Only the epoch-closing (third) batch flips is_new_epoch.
            self.assertEqual(it.is_new_epoch, j == 2)
            self.assertAlmostEqual(it.epoch_detail, i + (j + 1) * 2 / 6)
            self.assertAlmostEqual(it.previous_epoch_detail, i + j * 2 / 6)
            epoch_elements.extend(batch)
        self.assertEqual(sorted(epoch_elements), dataset)