Python fuel.schemes.SequentialScheme() Examples

The following are 26 code examples of fuel.schemes.SequentialScheme(). You can go to the original project or source file by following the links above each example, or check out all other available functions and classes of the fuel.schemes module.
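
Before the per-project examples, here is a minimal sketch of the scheme in isolation: SequentialScheme(examples, batch_size) visits examples in order, and its request iterator yields lists of indices, ending with a short batch when batch_size does not divide the number of examples.

from fuel.schemes import SequentialScheme

scheme = SequentialScheme(examples=7, batch_size=3)
for request in scheme.get_request_iterator():
    print(request)
# [0, 1, 2]
# [3, 4, 5]
# [6]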
Example #1
Source File: test_mnist.py    From fuel with MIT License
def test_mnist_train():
    skip_if_not_available(datasets=['mnist.hdf5'])

    dataset = MNIST(('train',), load_in_memory=False)
    handle = dataset.open()
    data, labels = dataset.get_data(handle, slice(0, 10))
    assert data.dtype == 'uint8'
    assert data.shape == (10, 1, 28, 28)
    assert labels.shape == (10, 1)
    known = numpy.array([0, 0, 0, 0, 0, 0, 0, 0, 30, 36, 94, 154, 170, 253,
                         253, 253, 253, 253, 225, 172, 253, 242, 195,  64, 0,
                         0, 0, 0])
    assert_allclose(data[0][0][6], known)
    assert labels[0][0] == 5
    assert dataset.num_examples == 60000
    dataset.close(handle)

    stream = DataStream.default_stream(
        dataset, iteration_scheme=SequentialScheme(10, 10))
    data = next(stream.get_epoch_iterator())[0]
    assert data.min() >= 0.0 and data.max() <= 1.0
    assert data.dtype == config.floatX 
Example #2
Source File: wsj0.py    From DaNet-Tensorflow with MIT License
def epoch(self, subset, batch_size, shuffle=False):
    dataset = self.subset[subset]
    handle = dataset.open()
    dset_size = self.h5file.attrs['split'][
        dict(train=0, valid=1, test=2)[subset]][3]
    # Pad the index list to a multiple of batch_size by wrapping around,
    # so every batch SequentialScheme emits is full-sized.
    indices = np.arange(
        ((dset_size + batch_size - 1) // batch_size) * batch_size)
    indices %= dset_size
    if shuffle:
        np.random.shuffle(indices)
    req_itor = SequentialScheme(
        examples=indices, batch_size=batch_size).get_request_iterator()
    for req in req_itor:
        data_pt = dataset.get_data(handle, req)
        # Zero-pad every spectrogram in the batch to the longest one.
        max_len = max(map(len, data_pt[0]))
        spectra_li = [utils.random_zeropad(
            x, max_len - len(x), axis=-2)
            for x in data_pt[0]]
        spectra = np.stack(spectra_li)
        yield (spectra,)
    dataset.close(handle)
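
utils.random_zeropad above is project-specific. A plausible stand-in, assuming it splits the requested zero-padding randomly between the two ends of the given axis (the actual DaNet-Tensorflow helper may differ):

import numpy as np

def random_zeropad(x, pad_amt, axis=-2):
    # Hypothetical reconstruction: distribute `pad_amt` zeros randomly
    # between the start and end of `axis`, so all padded arrays end up
    # with the same length along that axis.
    before = np.random.randint(pad_amt + 1)
    pad_width = [(0, 0)] * x.ndim
    pad_width[axis] = (before, pad_amt - before)
    return np.pad(x, pad_width, mode='constant')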
Example #3
Source File: test_mnist.py    From fuel with MIT License
def test_mnist_test():
    skip_if_not_available(datasets=['mnist.hdf5'])

    dataset = MNIST(('test',), load_in_memory=False)
    handle = dataset.open()
    data, labels = dataset.get_data(handle, slice(0, 10))
    assert data.dtype == 'uint8'
    assert data.shape == (10, 1, 28, 28)
    assert labels.shape == (10, 1)
    known = numpy.array([0, 0, 0, 0, 0, 0, 84, 185, 159, 151, 60, 36, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
    assert_allclose(data[0][0][7], known)
    assert labels[0][0] == 7
    assert dataset.num_examples == 10000
    dataset.close(handle)

    stream = DataStream.default_stream(
        dataset, iteration_scheme=SequentialScheme(10, 10))
    data = next(stream.get_epoch_iterator())[0]
    assert data.min() >= 0.0 and data.max() <= 1.0
    assert data.dtype == config.floatX 
Example #4
Source File: load.py    From iGAN with MIT License
def load_imgs_seq(ntrain=None, ntest=None, batch_size=128, data_file=None):
    t = time()
    print('LOADING DATASET...')
    path = os.path.join(data_file)
    tr_data = H5PYDataset(path, which_sets=('train',))
    te_data = H5PYDataset(path, which_sets=('test',))

    if ntrain is None:
        ntrain = tr_data.num_examples
    if ntest is None:
        ntest = te_data.num_examples

    tr_scheme = SequentialScheme(examples=ntrain, batch_size=batch_size)
    tr_stream = DataStream(tr_data, iteration_scheme=tr_scheme)

    te_scheme = SequentialScheme(examples=ntest, batch_size=batch_size)
    te_stream = DataStream(te_data, iteration_scheme=te_scheme)

    print('name = %s, ntrain = %d, ntest = %d' % (data_file, ntrain, ntest))
    print('%.2f seconds to load data' % (time() - t))

    return tr_data, te_data, tr_stream, te_stream, ntrain, ntest 
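
A hypothetical use of the loader, assuming an HDF5 file laid out for H5PYDataset with 'train' and 'test' splits:

tr_data, te_data, tr_stream, te_stream, ntrain, ntest = load_imgs_seq(
    batch_size=128, data_file='imgs.hdf5')
for batch in tr_stream.get_epoch_iterator():
    break  # `batch` is a tuple of arrays, up to 128 examples, in file order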
Example #5
Source File: load.py    From dcgan_code with MIT License
def faces(ntrain=None, nval=None, ntest=None, batch_size=128):
    path = os.path.join(data_dir, 'faces_364293_128px.hdf5')
    tr_data = H5PYDataset(path, which_sets=('train',))
    te_data = H5PYDataset(path, which_sets=('test',))

    if ntrain is None:
        ntrain = tr_data.num_examples
    if ntest is None:
        ntest = te_data.num_examples
    if nval is None:
        nval = te_data.num_examples

    tr_scheme = ShuffledScheme(examples=ntrain, batch_size=batch_size)
    tr_stream = DataStream(tr_data, iteration_scheme=tr_scheme)

    te_scheme = SequentialScheme(examples=ntest, batch_size=batch_size)
    te_stream = DataStream(te_data, iteration_scheme=te_scheme)

    val_scheme = SequentialScheme(examples=nval, batch_size=batch_size)
    val_stream = DataStream(tr_data, iteration_scheme=val_scheme)
    return tr_data, te_data, tr_stream, val_stream, te_stream 
Example #6
Source File: test_mnist.py    From attention-lvcsr with MIT License
def test_mnist_train():
    skip_if_not_available(datasets=['mnist.hdf5'])

    dataset = MNIST(('train',), load_in_memory=False)
    handle = dataset.open()
    data, labels = dataset.get_data(handle, slice(0, 10))
    assert data.dtype == 'uint8'
    assert data.shape == (10, 1, 28, 28)
    assert labels.shape == (10, 1)
    known = numpy.array([0, 0, 0, 0, 0, 0, 0, 0, 30, 36, 94, 154, 170, 253,
                         253, 253, 253, 253, 225, 172, 253, 242, 195,  64, 0,
                         0, 0, 0])
    assert_allclose(data[0][0][6], known)
    assert labels[0][0] == 5
    assert dataset.num_examples == 60000
    dataset.close(handle)

    stream = DataStream.default_stream(
        dataset, iteration_scheme=SequentialScheme(10, 10))
    data = next(stream.get_epoch_iterator())[0]
    assert data.min() >= 0.0 and data.max() <= 1.0
    assert data.dtype == config.floatX 
Example #7
Source File: test_mnist.py    From attention-lvcsr with MIT License
def test_mnist_test():
    skip_if_not_available(datasets=['mnist.hdf5'])

    dataset = MNIST(('test',), load_in_memory=False)
    handle = dataset.open()
    data, labels = dataset.get_data(handle, slice(0, 10))
    assert data.dtype == 'uint8'
    assert data.shape == (10, 1, 28, 28)
    assert labels.shape == (10, 1)
    known = numpy.array([0, 0, 0, 0, 0, 0, 84, 185, 159, 151, 60, 36, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
    assert_allclose(data[0][0][7], known)
    assert labels[0][0] == 7
    assert dataset.num_examples == 10000
    dataset.close(handle)

    stream = DataStream.default_stream(
        dataset, iteration_scheme=SequentialScheme(10, 10))
    data = next(stream.get_epoch_iterator())[0]
    assert data.min() >= 0.0 and data.max() <= 1.0
    assert data.dtype == config.floatX 
Example #8
Source File: test_streams.py    From fuel with MIT License
def test_axis_labels_on_produces_batches(self):
    dataset = IndexableDataset(numpy.eye(2))
    axis_labels = {'data': ('batch', 'features')}
    dataset.axis_labels = axis_labels
    stream = DataStream(dataset, iteration_scheme=SequentialScheme(2, 2))
    assert_equal(stream.axis_labels, axis_labels)
Example #9
Source File: test_predict.py    From blocks-extras with MIT License
def test_predict():
    tempfile_path = os.path.join(gettempdir(), 'test_predict.npz')

    # set up mock datastream
    source = [[1], [2], [3], [4]]
    dataset = IndexableDataset(OrderedDict([('input', source)]))
    scheme = SequentialScheme(dataset.num_examples, batch_size=2)
    data_stream = DataStream(dataset, iteration_scheme=scheme)

    # simulate small "network" that increments the input by 1
    input_tensor = tensor.matrix('input')
    output_tensor = input_tensor + 1
    output_tensor.name = 'output_tensor'

    main_loop = MockMainLoop(extensions=[
        PredictDataStream(data_stream=data_stream,
                          variables=[output_tensor],
                          path=tempfile_path,
                          after_training=True),
        FinishAfter(after_n_epochs=1)
    ])
    main_loop.run()

    # assert resulting prediction is saved
    prediction = numpy.load(tempfile_path)
    assert numpy.all(prediction[output_tensor.name] == numpy.array(source) + 1)

    try:
        os.remove(tempfile_path)
    except OSError:
        pass
Example #10
Source File: test_dogs_vs_cats.py    From fuel with MIT License
def _test_dataset():
    train = DogsVsCats(('train',))
    assert train.num_examples == 25000
    assert_raises(ValueError, DogsVsCats, ('valid',))

    test = DogsVsCats(('test',))
    stream = DataStream.default_stream(
        test, iteration_scheme=SequentialScheme(10, 10))
    data = next(stream.get_epoch_iterator())[0][0]
    assert data.dtype.kind == 'f' 
Example #11
Source File: test_sequences.py    From fuel with MIT License
def test_ngram_stream_raises_error_on_batch_stream():
    sentences = [list(numpy.random.randint(10, size=sentence_length))
                 for sentence_length in [3, 5, 7]]
    stream = DataStream(
        IndexableDataset(sentences), iteration_scheme=SequentialScheme(3, 1))
    assert_raises(ValueError, NGrams, 4, stream) 
Example #12
Source File: test_hdf5.py    From fuel with MIT License
def test_data_stream_pickling(self):
    stream = DataStream(H5PYDataset(self.h5file, which_sets=('train',)),
                        iteration_scheme=SequentialScheme(100, 10))
    cPickle.loads(cPickle.dumps(stream))
    stream.close()
Example #13
Source File: test_cifar10.py    From fuel with MIT License
def test_cifar10():
    train = CIFAR10(('train',), load_in_memory=False)
    assert train.num_examples == 50000
    handle = train.open()
    features, targets = train.get_data(handle, slice(49990, 50000))
    assert features.shape == (10, 3, 32, 32)
    assert targets.shape == (10, 1)
    train.close(handle)

    test = CIFAR10(('test',), load_in_memory=False)
    handle = test.open()
    features, targets = test.get_data(handle, slice(0, 10))
    assert features.shape == (10, 3, 32, 32)
    assert targets.shape == (10, 1)
    assert features.dtype == numpy.uint8
    assert targets.dtype == numpy.uint8
    test.close(handle)

    stream = DataStream.default_stream(
        test, iteration_scheme=SequentialScheme(10, 10))
    data = next(stream.get_epoch_iterator())[0]
    assert data.min() >= 0.0 and data.max() <= 1.0
    assert data.dtype == config.floatX

    assert_raises(ValueError, CIFAR10, ('valid',))

    assert_raises(ValueError, CIFAR10,
                  ('train',), subset=slice(50000, 60000)) 
Example #14
Source File: test_server.py    From fuel with MIT License
def get_stream():
    return DataStream(
        MNIST(('train',)), iteration_scheme=SequentialScheme(1500, 500)) 
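
This stream is meant to be served to a separate training process. A sketch of the usual wiring, assuming Fuel's start_server/ServerDataStream pair:

from fuel.server import start_server

# Producer process: serve batches over ZeroMQ (default port 5557).
start_server(get_stream())

# Consumer process, elsewhere:
#     from fuel.streams import ServerDataStream
#     stream = ServerDataStream(('features', 'targets'),
#                               produces_examples=False)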
Example #15
Source File: test_cifar100.py    From fuel with MIT License
def test_cifar100():
    train = CIFAR100(('train',), load_in_memory=False)
    assert train.num_examples == 50000
    handle = train.open()
    coarse_labels, features, fine_labels = train.get_data(handle,
                                                          slice(49990, 50000))

    assert features.shape == (10, 3, 32, 32)
    assert coarse_labels.shape == (10, 1)
    assert fine_labels.shape == (10, 1)
    train.close(handle)

    test = CIFAR100(('test',), load_in_memory=False)
    handle = test.open()
    coarse_labels, features, fine_labels = test.get_data(handle,
                                                         slice(0, 10))

    assert features.shape == (10, 3, 32, 32)
    assert coarse_labels.shape == (10, 1)
    assert fine_labels.shape == (10, 1)

    assert features.dtype == numpy.uint8
    assert coarse_labels.dtype == numpy.uint8
    assert fine_labels.dtype == numpy.uint8

    test.close(handle)

    stream = DataStream.default_stream(
        test, iteration_scheme=SequentialScheme(10, 10))
    data = next(stream.get_epoch_iterator())[1]

    assert data.min() >= 0.0 and data.max() <= 1.0
    assert data.dtype == config.floatX

    assert_raises(ValueError, CIFAR100, ('valid',)) 
Example #16
Source File: test_aggregation.py    From attention-lvcsr with MIT License
def test_mean_aggregator():
    num_examples = 4
    batch_size = 2

    features = numpy.array([[0, 3],
                           [2, 9],
                           [2, 4],
                           [5, 1]], dtype=theano.config.floatX)

    dataset = IndexableDataset(OrderedDict([('features', features)]))

    data_stream = DataStream(dataset,
                             iteration_scheme=SequentialScheme(num_examples,
                                                               batch_size))

    x = tensor.matrix('features')
    y = (x**2).mean(axis=0)
    y.name = 'y'
    z = y.sum()
    z.name = 'z'

    y.tag.aggregation_scheme = Mean(y, 1.)
    z.tag.aggregation_scheme = Mean(z, 1.)

    assert_allclose(DatasetEvaluator([y]).evaluate(data_stream)['y'],
                    numpy.array([8.25, 26.75], dtype=theano.config.floatX))
    assert_allclose(DatasetEvaluator([z]).evaluate(data_stream)['z'],
                    numpy.array([35], dtype=theano.config.floatX)) 
Example #17
Source File: dataset.py    From kerosene with MIT License
def fuel_data_to_list(fuel_data, shuffle):
    if shuffle:
        scheme = ShuffledScheme(fuel_data.num_examples, fuel_data.num_examples)
    else:
        scheme = SequentialScheme(fuel_data.num_examples, fuel_data.num_examples)
    fuel_data_stream = DataStream.default_stream(fuel_data, iteration_scheme=scheme)
    return next(fuel_data_stream.get_epoch_iterator()) 
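
Because the scheme's batch size equals num_examples, a single epoch yields the whole split at once; a hypothetical call with MNIST:

from fuel.datasets import MNIST

# One tuple of arrays covering the entire training split.
features, targets = fuel_data_to_list(MNIST(('train',)), shuffle=False)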
Example #18
Source File: test_hdf5.py    From attention-lvcsr with MIT License
def test_data_stream_pickling(self):
    stream = DataStream(H5PYDataset(self.h5file, which_sets=('train',)),
                        iteration_scheme=SequentialScheme(100, 10))
    cPickle.loads(cPickle.dumps(stream))
    stream.close()
Example #19
Source File: test_cifar10.py    From attention-lvcsr with MIT License
def test_cifar10():
    train = CIFAR10(('train',), load_in_memory=False)
    assert train.num_examples == 50000
    handle = train.open()
    features, targets = train.get_data(handle, slice(49990, 50000))
    assert features.shape == (10, 3, 32, 32)
    assert targets.shape == (10, 1)
    train.close(handle)

    test = CIFAR10(('test',), load_in_memory=False)
    handle = test.open()
    features, targets = test.get_data(handle, slice(0, 10))
    assert features.shape == (10, 3, 32, 32)
    assert targets.shape == (10, 1)
    assert features.dtype == numpy.uint8
    assert targets.dtype == numpy.uint8
    test.close(handle)

    stream = DataStream.default_stream(
        test, iteration_scheme=SequentialScheme(10, 10))
    data = next(stream.get_epoch_iterator())[0]
    assert data.min() >= 0.0 and data.max() <= 1.0
    assert data.dtype == config.floatX

    assert_raises(ValueError, CIFAR10, ('valid',))

    assert_raises(ValueError, CIFAR10,
                  ('train',), subset=slice(50000, 60000)) 
Example #20
Source File: test_text.py    From attention-lvcsr with MIT License
def test_ngram_stream_raises_error_on_batch_stream():
    sentences = [list(numpy.random.randint(10, size=sentence_length))
                 for sentence_length in [3, 5, 7]]
    stream = DataStream(
        IndexableDataset(sentences), iteration_scheme=SequentialScheme(3, 1))
    assert_raises(ValueError, NGrams, 4, stream) 
Example #21
Source File: test_serialization.py    From attention-lvcsr with MIT License
def test_in_memory():
    skip_if_not_available(datasets=['mnist.hdf5'])
    # Load MNIST and get two batches
    mnist = MNIST(('train',), load_in_memory=True)
    data_stream = DataStream(mnist, iteration_scheme=SequentialScheme(
        examples=mnist.num_examples, batch_size=256))
    epoch = data_stream.get_epoch_iterator()
    for i, (features, targets) in enumerate(epoch):
        if i == 1:
            break
    handle = mnist.open()
    known_features, _ = mnist.get_data(handle, slice(256, 512))
    mnist.close(handle)
    assert numpy.all(features == known_features)

    # Pickle the epoch and make sure that the data wasn't dumped
    with tempfile.NamedTemporaryFile(delete=False) as f:
        filename = f.name
        cPickle.dump(epoch, f)
    assert os.path.getsize(filename) < 1024 * 1024  # Less than 1MB

    # Reload the epoch and make sure that the state was maintained
    del epoch
    with open(filename, 'rb') as f:
        epoch = cPickle.load(f)
    features, targets = next(epoch)
    handle = mnist.open()
    known_features, _ = mnist.get_data(handle, slice(512, 768))
    mnist.close(handle)
    assert numpy.all(features == known_features) 
Example #22
Source File: test_cifar100.py    From attention-lvcsr with MIT License
def test_cifar100():
    train = CIFAR100(('train',), load_in_memory=False)
    assert train.num_examples == 50000
    handle = train.open()
    coarse_labels, features, fine_labels = train.get_data(handle,
                                                          slice(49990, 50000))

    assert features.shape == (10, 3, 32, 32)
    assert coarse_labels.shape == (10, 1)
    assert fine_labels.shape == (10, 1)
    train.close(handle)

    test = CIFAR100(('test',), load_in_memory=False)
    handle = test.open()
    coarse_labels, features, fine_labels = test.get_data(handle,
                                                         slice(0, 10))

    assert features.shape == (10, 3, 32, 32)
    assert coarse_labels.shape == (10, 1)
    assert fine_labels.shape == (10, 1)

    assert features.dtype == numpy.uint8
    assert coarse_labels.dtype == numpy.uint8
    assert fine_labels.dtype == numpy.uint8

    test.close(handle)

    stream = DataStream.default_stream(
        test, iteration_scheme=SequentialScheme(10, 10))
    data = next(stream.get_epoch_iterator())[1]

    assert data.min() >= 0.0 and data.max() <= 1.0
    assert data.dtype == config.floatX

    assert_raises(ValueError, CIFAR100, ('valid',)) 
Example #23
Source File: test_streams.py    From attention-lvcsr with MIT License
def test_axis_labels_on_produces_batches(self):
    dataset = IndexableDataset(numpy.eye(2))
    axis_labels = {'data': ('batch', 'features')}
    dataset.axis_labels = axis_labels
    stream = DataStream(dataset, iteration_scheme=SequentialScheme(2, 2))
    assert_equal(stream.axis_labels, axis_labels)
Example #24
Source File: fuel_helper.py    From plat with MIT License
def create_streams(train_set, valid_set, test_set, training_batch_size,
                   monitoring_batch_size):
    """Creates data streams for training and monitoring.

    Parameters
    ----------
    train_set : :class:`fuel.datasets.Dataset`
        Training set.
    valid_set : :class:`fuel.datasets.Dataset`
        Validation set.
    test_set : :class:`fuel.datasets.Dataset`
        Test set.
    training_batch_size : int
        Batch size for training.
    monitoring_batch_size : int
        Batch size for monitoring.

    Returns
    -------
    rval : tuple of data streams
        Data streams for the main loop, the training set monitor,
        the validation set monitor and the test set monitor.

    """
    main_loop_stream = DataStream.default_stream(
        dataset=train_set,
        iteration_scheme=ShuffledScheme(
            train_set.num_examples, training_batch_size))
    train_monitor_stream = DataStream.default_stream(
        dataset=train_set,
        iteration_scheme=ShuffledScheme(
            train_set.num_examples, monitoring_batch_size))
    valid_monitor_stream = DataStream.default_stream(
        dataset=valid_set,
        iteration_scheme=SequentialScheme(
            valid_set.num_examples, monitoring_batch_size))
    test_monitor_stream = DataStream.default_stream(
        dataset=test_set,
        iteration_scheme=SequentialScheme(
            test_set.num_examples, monitoring_batch_size))

    return (main_loop_stream, train_monitor_stream, valid_monitor_stream,
            test_monitor_stream) 
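
A hypothetical wiring with MNIST standing in for all three splits (MNIST ships no validation split, so the test set is reused here):

from fuel.datasets import MNIST

train_set = MNIST(('train',))
test_set = MNIST(('test',))
streams = create_streams(train_set, test_set, test_set,
                         training_batch_size=128, monitoring_batch_size=500)
main_loop_stream, train_monitor, valid_monitor, test_monitor = streams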
Example #25
Source File: __init__.py    From blocks-examples with MIT License
def main(save_to, num_epochs):
    mlp = MLP([Tanh(), Softmax()], [784, 100, 10],
              weights_init=IsotropicGaussian(0.01),
              biases_init=Constant(0))
    mlp.initialize()
    x = tensor.matrix('features')
    y = tensor.lmatrix('targets')
    probs = mlp.apply(x)
    cost = CategoricalCrossEntropy().apply(y.flatten(), probs)
    error_rate = MisclassificationRate().apply(y.flatten(), probs)

    cg = ComputationGraph([cost])
    W1, W2 = VariableFilter(roles=[WEIGHT])(cg.variables)
    cost = cost + .00005 * (W1 ** 2).sum() + .00005 * (W2 ** 2).sum()
    cost.name = 'final_cost'

    mnist_train = MNIST(("train",))
    mnist_test = MNIST(("test",))

    algorithm = GradientDescent(
        cost=cost, parameters=cg.parameters,
        step_rule=Scale(learning_rate=0.1))
    extensions = [Timing(),
                  FinishAfter(after_n_epochs=num_epochs),
                  DataStreamMonitoring(
                      [cost, error_rate],
                      Flatten(
                          DataStream.default_stream(
                              mnist_test,
                              iteration_scheme=SequentialScheme(
                                  mnist_test.num_examples, 500)),
                          which_sources=('features',)),
                      prefix="test"),
                  TrainingDataMonitoring(
                      [cost, error_rate,
                       aggregation.mean(algorithm.total_gradient_norm)],
                      prefix="train",
                      after_epoch=True),
                  Checkpoint(save_to),
                  Printing()]

    if BLOCKS_EXTRAS_AVAILABLE:
        extensions.append(Plot(
            'MNIST example',
            channels=[
                ['test_final_cost',
                 'test_misclassificationrate_apply_error_rate'],
                ['train_total_gradient_norm']]))

    main_loop = MainLoop(
        algorithm,
        Flatten(
            DataStream.default_stream(
                mnist_train,
                iteration_scheme=SequentialScheme(
                    mnist_train.num_examples, 50)),
            which_sources=('features',)),
        model=Model(cost),
        extensions=extensions)

    main_loop.run() 
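
A hypothetical invocation, training for two epochs and checkpointing to a local file:

main(save_to='mnist_mlp.tar', num_epochs=2)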
Example #26
Source File: data_provider.py    From deep_metric_learning with MIT License
def get_streams(batch_size=50, dataset='cars196', method='n_pairs_mc',
                crop_size=224, load_in_memory=False):
    '''
    args:
        batch_size (int):
            number of examples per batch
        dataset (str):
            specify the dataset from 'cars196', 'cub200_2011', 'products'.
        method (str or fuel.schemes.IterationScheme):
            batch construction method. Specify 'n_pairs_mc', 'clustering', or
            a subclass of IterationScheme whose constructor has the
            signature `__init__(self, batch_size, dataset)`.
        crop_size (int or tuple of ints):
            height and width of the cropped image.
        load_in_memory (bool):
            whether to load the dataset into memory.
    '''
    dataset_class = get_dataset_class(dataset)
    dataset_train = dataset_class(['train'], load_in_memory=load_in_memory)
    dataset_test = dataset_class(['test'], load_in_memory=load_in_memory)

    if not isinstance(crop_size, tuple):
        crop_size = (crop_size, crop_size)

    if method == 'n_pairs_mc':
        labels = dataset_class(
            ['train'], sources=['targets'], load_in_memory=True).data_sources
        scheme = NPairLossScheme(labels, batch_size)
    elif method == 'clustering':
        scheme = EpochwiseShuffledInfiniteScheme(
            dataset_train.num_examples, batch_size)
    elif issubclass(method, IterationScheme):
        scheme = method(batch_size, dataset=dataset_train)
    else:
        raise ValueError("`method` must be 'n_pairs_mc' or 'clustering' "
                         "or subclass of IterationScheme.")
    stream = DataStream(dataset_train, iteration_scheme=scheme)
    stream_train = RandomFixedSizeCrop(stream, which_sources=('images',),
                                       random_lr_flip=True,
                                       window_shape=crop_size)

    stream_train_eval = RandomFixedSizeCrop(DataStream(
        dataset_train, iteration_scheme=SequentialScheme(
            dataset_train.num_examples, batch_size)),
        which_sources=('images',), center_crop=True, window_shape=crop_size)
    stream_test = RandomFixedSizeCrop(DataStream(
        dataset_test, iteration_scheme=SequentialScheme(
            dataset_test.num_examples, batch_size)),
        which_sources=('images',), center_crop=True, window_shape=crop_size)

    return stream_train, stream_train_eval, stream_test
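
A hypothetical call producing n-pairs minibatches for cars196, with 224x224 random crops for training and center crops for evaluation:

stream_train, stream_train_eval, stream_test = get_streams(
    batch_size=50, dataset='cars196', method='n_pairs_mc', crop_size=224)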