Python fuel.streams.DataStream() Examples

The following are 30 code examples of fuel.streams.DataStream(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module fuel.streams, or try the search function.
Example #1
Source File: load.py (from iGAN, MIT License; 6 votes)
def load_imgs_seq(ntrain=None, ntest=None, batch_size=128, data_file=None):
    """Load the 'train'/'test' splits of an HDF5 file as sequential batch streams.

    Both streams iterate in order (SequentialScheme), so repeated epochs see
    the examples in the same sequence.

    Args:
        ntrain: number of training examples to iterate over. Defaults to the
            size of the 'train' split and is capped at it.
        ntest: same as ``ntrain`` for the 'test' split.
        batch_size: number of examples per batch for both streams.
        data_file: path to the HDF5 file holding the 'train' and 'test' sets.

    Returns:
        Tuple ``(tr_data, te_data, tr_stream, te_stream, ntrain, ntest)``,
        where ``ntrain``/``ntest`` are the sizes actually used.
    """
    t = time()
    print('LOADING DATASET...')
    path = os.path.join(data_file)
    tr_data = H5PYDataset(path, which_sets=('train',))
    te_data = H5PYDataset(path, which_sets=('test',))

    # Cap requested sizes at what each split actually holds. Otherwise the
    # scheme would request indices past the end of the split, and the counts
    # printed and returned below would be wrong.
    if ntrain is None:
        ntrain = tr_data.num_examples
    else:
        ntrain = min(ntrain, tr_data.num_examples)
    if ntest is None:
        ntest = te_data.num_examples
    else:
        ntest = min(ntest, te_data.num_examples)

    tr_scheme = SequentialScheme(examples=ntrain, batch_size=batch_size)
    tr_stream = DataStream(tr_data, iteration_scheme=tr_scheme)

    te_scheme = SequentialScheme(examples=ntest, batch_size=batch_size)
    te_stream = DataStream(te_data, iteration_scheme=te_scheme)

    print('name = %s, ntrain = %d, ntest = %d' % (data_file, ntrain, ntest))
    print('%.2f seconds to load data' % (time() - t))

    return tr_data, te_data, tr_stream, te_stream, ntrain, ntest
Example #2
Source File: timit.py (from CTC-LSTM, Apache License 2.0; 6 votes)
def setup_datastream(path, batch_size, sort_batch_count, valid=False):
    """Build a shuffled, size-sorted, padded batch stream over TIMIT arrays.

    Three ``.npy`` files are read from ``path``: the ``valid_*`` set when
    ``valid`` is true, otherwise the ``train_*`` set.

    Args:
        path: directory containing the ``.npy`` files.
        batch_size: examples per final (padded) batch.
        sort_batch_count: how many batches' worth of examples are gathered
            and sorted together before being re-batched.
        valid: select the validation files instead of the training files.

    Returns:
        Tuple ``(ds, stream)``: the IndexableDataset and the padded stream.
    """
    if valid:
        filenames = ('valid_x_raw.npy', 'valid_phn.npy', 'valid_seq_to_phn.npy')
    else:
        filenames = ('train_x_raw.npy', 'train_phn.npy', 'train_seq_to_phn.npy')
    raw_inputs, phn_table, seq_bounds = [
        numpy.load(os.path.join(path, name)) for name in filenames]

    # Each row of seq_bounds gives the [row[0], row[1]) slice of phn_table
    # belonging to one sequence; column 2 of that slice is its target.
    targets = [phn_table[bounds[0]:bounds[1], 2] for bounds in seq_bounds]

    ds = IndexableDataset({'input': raw_inputs, 'output': targets})
    n_examples = len(raw_inputs)
    stream = DataStream(ds, iteration_scheme=ShuffledExampleScheme(n_examples))

    # Gather sort_batch_count batches of examples and reorder them with the
    # key from _balanced_batch_helper on the 'input' source (presumably
    # sequence length; confirm). Then split them back into single examples.
    group_size = batch_size * sort_batch_count
    stream = Batch(stream, iteration_scheme=ConstantScheme(group_size))
    sort_key = _balanced_batch_helper(stream.sources.index('input'))
    stream = Unpack(Mapping(stream, SortMapping(sort_key)))

    final_scheme = ConstantScheme(batch_size, num_examples=n_examples)
    stream = Padding(Batch(stream, iteration_scheme=final_scheme),
                     mask_sources=['input', 'output'])
    return ds, stream
Example #3
Source File: stream.py (from dl4mt-multi, BSD 3-Clause "New" or "Revised" License; 6 votes)
def get_dev_streams(config):
    """Set up development-set streams where a validation set is configured.

    For every ``cg`` in ``config['cgs']`` that also appears in
    ``config['val_sets']``, this:

    1. loads the pickled source vocabulary for that cg's encoder;
    2. overrides the special-token ids;
    3. wraps the validation text file in an (unbatched) DataStream.

    Args:
        config: mapping providing 'cgs', 'src_vocabs', 'unk_id',
            'src_eos_idxs' and, optionally, 'val_sets'.

    Returns:
        dict mapping each cg with a validation set to its DataStream.
    """
    dev_streams = {}
    for cg in config['cgs']:
        if 'val_sets' not in config or cg not in config['val_sets']:
            continue
        logger.info('Building development stream for cg:[{}]'.format(cg))
        eid = p_(cg)[0]
        dev_file = config['val_sets'][cg]

        # Get dictionary and fix EOS. The context manager closes the
        # vocabulary file promptly. Pickles are binary data, hence 'rb'.
        with open(config['src_vocabs'][eid], 'rb') as vocab_file:
            dictionary = cPickle.load(vocab_file)
        dictionary['<S>'] = 0
        dictionary['<UNK>'] = config['unk_id']
        dictionary['</S>'] = config['src_eos_idxs'][eid]

        # Get as a text file and convert it into a stream. The third
        # positional argument is presumably TextFile's bos_token (None = no
        # BOS token); confirm against TextFile's signature.
        dev_dataset = TextFile([dev_file], dictionary, None)
        dev_streams[cg] = DataStream(dev_dataset)
    return dev_streams
Example #4
Source File: data.py (from DeepMind-Teaching-Machines-to-Read-and-Comprehend, MIT License; 6 votes)
def setup_datastream(path, vocab_file, config):
    """Build the padded QA batch stream with batches of similar sizes.

    Args:
        path: dataset location, passed to both QADataset and QAIterator.
        vocab_file: vocabulary file passed to QADataset.
        config: object providing n_entities, concat_ctx_and_question,
            concat_question_before, shuffle_questions, batch_size and
            sort_batch_count.

    Returns:
        Tuple ``(ds, stream)``: the QADataset and the final padded stream.
    """
    concat = config.concat_ctx_and_question
    ds = QADataset(path, vocab_file, config.n_entities, need_sep_token=concat)
    stream = DataStream(
        ds, iteration_scheme=QAIterator(path, shuffle=config.shuffle_questions))

    if concat:
        # Presumably joins context and question around the '<SEP>' token id.
        stream = ConcatCtxAndQuestion(stream, config.concat_question_before,
                                      ds.reverse_vocab['<SEP>'])

    # Sort sets of multiple batches to make batches of similar sizes: gather
    # batch_size * sort_batch_count examples, reorder them with the
    # _balanced_batch_helper key, then unpack back into single examples.
    group_scheme = ConstantScheme(config.batch_size * config.sort_batch_count)
    stream = Batch(stream, iteration_scheme=group_scheme)
    sort_source = 'question' if concat else 'context'
    sort_key = _balanced_batch_helper(stream.sources.index(sort_source))
    stream = Unpack(Mapping(stream, SortMapping(sort_key)))

    stream = Batch(stream, iteration_scheme=ConstantScheme(config.batch_size))
    stream = Padding(stream,
                     mask_sources=['context', 'question', 'candidates'],
                     mask_dtype='int32')
    return ds, stream
Example #5
Source File: load.py (from dcgan_code, MIT License; 6 votes)
def faces(ntrain=None, nval=None, ntest=None, batch_size=128):
    """Build train/validation/test streams over the 128px faces HDF5 file.

    The file is looked up under the module-level ``data_dir``.

    - Training stream: shuffled.
    - Test and validation streams: sequential.
    - Validation stream: reads from the *training* split.

    Returns:
        Tuple ``(tr_data, te_data, tr_stream, val_stream, te_stream)``.
    """
    path = os.path.join(data_dir, 'faces_364293_128px.hdf5')
    tr_data = H5PYDataset(path, which_sets=('train',))
    te_data = H5PYDataset(path, which_sets=('test',))

    ntrain = tr_data.num_examples if ntrain is None else ntrain
    ntest = te_data.num_examples if ntest is None else ntest
    # NOTE(review): nval defaults to the *test* split size, yet val_stream
    # below iterates tr_data. Confirm this pairing is intended.
    nval = te_data.num_examples if nval is None else nval

    tr_stream = DataStream(
        tr_data,
        iteration_scheme=ShuffledScheme(examples=ntrain, batch_size=batch_size))
    te_stream = DataStream(
        te_data,
        iteration_scheme=SequentialScheme(examples=ntest, batch_size=batch_size))
    val_stream = DataStream(
        tr_data,
        iteration_scheme=SequentialScheme(examples=nval, batch_size=batch_size))
    return tr_data, te_data, tr_stream, val_stream, te_stream
Example #6
Source File: load.py (from iGAN, MIT License; 6 votes)
def load_imgs(ntrain=None, ntest=None, batch_size=128, data_file=None):
    """Load the 'train'/'test' splits of an HDF5 file as shuffled batch streams.

    ``ntrain``/``ntest`` default to the full split sizes and are otherwise
    capped at them. Progress and timing are printed to stdout.

    Returns:
        Tuple ``(tr_data, te_data, tr_stream, te_stream, ntrain, ntest)``.
    """
    start = time()
    print('LOADING DATASET...')
    path = os.path.join(data_file)
    tr_data = H5PYDataset(path, which_sets=('train',))
    te_data = H5PYDataset(path, which_sets=('test',))

    tr_total = tr_data.num_examples
    ntrain = tr_total if ntrain is None else min(ntrain, tr_total)
    te_total = te_data.num_examples
    ntest = te_total if ntest is None else min(ntest, te_total)
    print('name = %s, ntrain = %d, ntest = %d' % (data_file, ntrain, ntest))

    tr_stream = DataStream(
        tr_data,
        iteration_scheme=ShuffledScheme(examples=ntrain, batch_size=batch_size))
    te_stream = DataStream(
        te_data,
        iteration_scheme=ShuffledScheme(examples=ntest, batch_size=batch_size))
    print('%.2f secs to load data' % (time() - start))
    return tr_data, te_data, tr_stream, te_stream, ntrain, ntest
Example #7
Source File: test_transformers.py (from attention-lvcsr, MIT License; 5 votes)
def setUp(self):
        """Create pair-batched streams over ten integers (range and ndarray)."""
        def batched(values):
            # Wrap the values in an example stream grouped into batches of 2.
            return Batch(DataStream(IterableDataset(values)),
                         iteration_scheme=ConstantScheme(2))

        self.stream = batched(range(10))
        self.stream_np = batched(numpy.arange(10))
Example #8
Source File: test_streams.py (from fuel, MIT License; 5 votes)
def test_axis_labels_on_produces_examples(self):
        # For an example-wise stream, the expected labels lack the leading
        # 'batch' axis that the dataset declares.
        self.dataset.axis_labels = {'data': ('batch', 'features')}
        expected = {'data': ('features',)}
        assert_equal(DataStream(self.dataset).axis_labels, expected)
Example #9
Source File: test_streams.py (from fuel, MIT License; 5 votes)
def test_no_axis_labels(self):
        # self.dataset (built elsewhere, presumably in setUp) is expected to
        # carry no axis labels, so the stream should report None.
        assert DataStream(self.dataset).axis_labels is None
Example #10
Source File: test_streams.py (from fuel, MIT License; 5 votes)
def test_sources_setter(self):
        # Assigning to `sources` should be reflected when reading it back.
        selected = ('features',)
        stream = DataStream(self.dataset)
        stream.sources = selected
        assert_equal(stream.sources, selected)
Example #11
Source File: test_datasets.py (from fuel, MIT License; 5 votes)
def test_sources_selection():
    """An IterableDataset yields all sources by default, or only those selected."""
    features = [5, 6, 7, 1]
    targets = [1, 0, 1, 1]

    # Both sources, in the OrderedDict's insertion order.
    both = OrderedDict([('features', features), ('targets', targets)])
    full_stream = DataStream(IterableDataset(both))
    assert list(full_stream.get_epoch_iterator()) == list(zip(features, targets))

    # Restricting to 'targets' yields 1-tuples of targets only.
    only_targets = DataStream(IterableDataset(
        {'features': features, 'targets': targets},
        sources=('targets',)))
    assert list(only_targets.get_epoch_iterator()) == list(zip(targets))
Example #12
Source File: test_streams.py (from fuel, MIT License; 5 votes)
def test_axis_labels_on_produces_batches(self):
        # A batch-wise stream should pass the dataset's labels through as-is.
        dataset = IndexableDataset(numpy.eye(2))
        labels = {'data': ('batch', 'features')}
        dataset.axis_labels = labels
        batch_scheme = SequentialScheme(2, 2)
        stream = DataStream(dataset, iteration_scheme=batch_scheme)
        assert_equal(stream.axis_labels, labels)
Example #13
Source File: test_text.py (from attention-lvcsr, MIT License; 5 votes)
def test_text():
    """Check TextFile streams at word and character level.

    Covers:
    - the epoch length;
    - pickling an epoch iterator midway and resuming it;
    - character-level encoding with <S>/</S> tokens.

    The two temporary input files are removed afterwards, even when an
    assertion fails.
    """
    import os

    paths = []
    try:
        # Test word level and epochs.
        with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
            paths.append(f.name)
            f.write("This is a sentence\n")
            f.write("This another one")
        with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
            paths.append(f.name)
            f.write("More sentences\n")
            f.write("The last one")
        sentences1, sentences2 = paths

        dictionary = {'<UNK>': 0, '</S>': 1, 'this': 2, 'a': 3, 'one': 4}
        text_data = TextFile(files=[sentences1, sentences2],
                             dictionary=dictionary, bos_token=None,
                             preprocess=lower)
        stream = DataStream(text_data)
        epoch = stream.get_epoch_iterator()
        assert len(list(epoch)) == 4

        # Advance a fresh epoch by three sentences, pickle it, and check that
        # the unpickled copy resumes at the same position.
        epoch = stream.get_epoch_iterator()
        for _ in zip(range(3), epoch):
            pass
        buf = BytesIO()
        cPickle.dump(epoch, buf)
        sentence = next(epoch)
        buf.seek(0)
        epoch = cPickle.load(buf)
        assert next(epoch) == sentence
        assert_raises(StopIteration, next, epoch)

        # Test character level.
        dictionary = dict([(chr(ord('a') + i), i) for i in range(26)] +
                          [(' ', 26)] + [('<S>', 27)] +
                          [('</S>', 28)] + [('<UNK>', 29)])
        text_data = TextFile(files=[sentences1, sentences2],
                             dictionary=dictionary, preprocess=lower,
                             level="character")
        sentence = next(DataStream(text_data).get_epoch_iterator())[0]
        assert sentence[:3] == [27, 19, 7]
        assert sentence[-3:] == [2, 4, 28]
    finally:
        # The files were created with delete=False, so remove them ourselves.
        for path in paths:
            os.remove(path)
Example #14
Source File: test_serialization.py (from attention-lvcsr, MIT License; 5 votes)
def test_in_memory():
    """Pickling an epoch over in-memory MNIST must keep its state, not its data.

    Steps:
    1. Read two batches.
    2. Pickle the epoch iterator to a temporary file, which must stay under
       1MB (i.e. the data wasn't dumped).
    3. Reload it and check that the next batch is examples 512-768.

    The temporary file is removed afterwards.
    """
    skip_if_not_available(datasets=['mnist.hdf5'])
    # Load MNIST and get two batches
    mnist = MNIST(('train',), load_in_memory=True)
    data_stream = DataStream(mnist, iteration_scheme=SequentialScheme(
        examples=mnist.num_examples, batch_size=256))
    epoch = data_stream.get_epoch_iterator()
    for i, (features, targets) in enumerate(epoch):
        if i == 1:
            break
    handle = mnist.open()
    known_features, _ = mnist.get_data(handle, slice(256, 512))
    mnist.close(handle)
    assert numpy.all(features == known_features)

    # Pickle the epoch and make sure that the data wasn't dumped.
    # delete=False keeps the file around for reopening, so it must be removed
    # explicitly once the reload is done.
    tmp = tempfile.NamedTemporaryFile(delete=False)
    filename = tmp.name
    try:
        with tmp:
            cPickle.dump(epoch, tmp)
        assert os.path.getsize(filename) < 1024 * 1024  # Less than 1MB

        # Reload the epoch and make sure that the state was maintained
        del epoch
        with open(filename, 'rb') as f:
            epoch = cPickle.load(f)
    finally:
        os.remove(filename)
    features, targets = next(epoch)
    handle = mnist.open()
    known_features, _ = mnist.get_data(handle, slice(512, 768))
    mnist.close(handle)
    assert numpy.all(features == known_features)
Example #15
Source File: test_server.py (from attention-lvcsr, MIT License; 5 votes)
def get_stream():
    """Return a stream over the first 1500 MNIST training examples, batches of 500."""
    train_set = MNIST(('train',))
    return DataStream(train_set, iteration_scheme=SequentialScheme(1500, 500))
Example #16
Source File: test_transformers.py (from attention-lvcsr, MIT License; 5 votes)
def test_batchwise(self):
        # Two examples in batches of two: a single batch, which ToBytes is
        # expected to decode into self.string_data.
        batch_scheme = SequentialScheme(2, 2)
        stream = DataStream(dataset=self.dataset, iteration_scheme=batch_scheme)
        decoded = [batch for batch, in ToBytes(stream).get_epoch_iterator()]
        assert_equal([self.string_data], decoded)
Example #17
Source File: test_transformers.py (from attention-lvcsr, MIT License; 5 votes)
def test_examplewise(self):
        # Example-wise iteration: each decoded item should match the
        # corresponding entry of self.string_data.
        example_scheme = SequentialExampleScheme(2)
        stream = DataStream(dataset=self.dataset, iteration_scheme=example_scheme)
        decoded = [item for item, in ToBytes(stream).get_epoch_iterator()]
        assert_equal(self.string_data, decoded)
Example #18
Source File: test_transformers.py (from fuel, MIT License; 5 votes)
def test_mapping(self):
        def double(example):
            # Double every element of the first (only) source; return a 1-tuple.
            return ([2 * value for value in example[0]],)

        transformer = Mapping(DataStream(IterableDataset(self.data)), double)
        expected = list(zip([[2, 4, 6], [4, 6, 2], [6, 4, 2]]))
        assert_equal(list(transformer.get_epoch_iterator()), expected)
Example #19
Source File: test_transformers.py (from attention-lvcsr, MIT License; 5 votes)
def setUp(self):
        """Build a 2-per-batch stream over sources 'X'/'y' and a Rename of it."""
        sources = OrderedDict([('X', numpy.ones((4, 2, 2))),
                               ('y', numpy.array([0, 1, 0, 1]))])
        labels = {'X': ('batch', 'width', 'height'),
                  'y': ('batch',)}
        dataset = IndexableDataset(sources, axis_labels=labels)
        self.stream = DataStream(dataset,
                                 iteration_scheme=SequentialScheme(4, 2))
        renames = {'X': 'features', 'y': 'targets'}
        self.transformer = Rename(self.stream, renames)
Example #20
Source File: test_transformers.py (from attention-lvcsr, MIT License; 5 votes)
def setUp(self):
        """Wrap 0..99 in a Mapping that adds one to each example."""
        def increment(example):
            return (example[0] + 1,)

        self.transformer = Mapping(DataStream(IterableDataset(range(100))),
                                   increment)
Example #21
Source File: test_transformers.py (from attention-lvcsr, MIT License; 5 votes)
def test_value_error_on_different_stream_output_type(self):
        # The extra stream is batch-wise; self.streams (built elsewhere) are
        # presumably example-wise, so Merge is expected to raise ValueError.
        batch_scheme = SequentialScheme(2, 2)
        spanish_stream = DataStream(IndexableDataset(['Hola mundo!']),
                                    iteration_scheme=batch_scheme)
        all_streams = self.streams + (spanish_stream,)
        names = ('english', 'french', 'spanish')
        assert_raises(ValueError, Merge, all_streams, names)
Example #22
Source File: test_transformers.py (from attention-lvcsr, MIT License; 5 votes)
def setUp(self):
        """Build English/French streams, example-wise and pair-batched, and their Merges."""
        def example_stream(sentences):
            return DataStream(IterableDataset(sentences))

        def batch_stream(sentences):
            return Batch(example_stream(sentences),
                         iteration_scheme=ConstantScheme(2))

        self.streams = (example_stream(['Hello world!']),
                        example_stream(['Bonjour le monde!']))
        self.batch_streams = (batch_stream(['Hello world!', 'Hi!']),
                              batch_stream(['Bonjour le monde!', 'Salut!']))
        self.transformer = Merge(self.streams, ('english', 'french'))
        self.batch_transformer = Merge(self.batch_streams, ('english', 'french'))
Example #23
Source File: test_transformers.py (from attention-lvcsr, MIT License; 5 votes)
def test_value_error_on_example_stream(self):
        # The test expects Padding to reject an unbatched (example-wise)
        # stream with ValueError.
        sources = dict(features=[[1], [2, 3]], targets=[[4, 5, 6], [7]])
        example_stream = DataStream(IterableDataset(sources))
        assert_raises(ValueError, Padding, example_stream)
Example #24
Source File: test_transformers.py (from attention-lvcsr, MIT License; 5 votes)
def test_mask_sources(self):
        # Two sources, but only 'features' is masked: three arrays per batch.
        sources = OrderedDict([('features', [[1], [2, 3]]),
                               ('targets', [[4, 5, 6], [7]])])
        batches = Batch(DataStream(IterableDataset(sources)), ConstantScheme(2))
        transformer = Padding(batches, mask_sources=('features',))
        first_batch = next(transformer.get_epoch_iterator())
        assert_equal(len(first_batch), 3)
Example #25
Source File: test_transformers.py (from attention-lvcsr, MIT License; 5 votes)
def test_mask_dtype(self):
        # Element 1 of the padded batch is checked as a mask; it should carry
        # the requested mask dtype.
        sources = dict(features=[[1], [2, 3]], targets=[[4, 5, 6], [7]])
        batches = Batch(DataStream(IterableDataset(sources)), ConstantScheme(2))
        transformer = Padding(batches, mask_dtype='uint8')
        first_batch = next(transformer.get_epoch_iterator())
        assert_equal(str(first_batch[1].dtype), 'uint8')
Example #26
Source File: test_transformers.py (from attention-lvcsr, MIT License; 5 votes)
def test_two_sources(self):
        # With default masking, each of the two sources is expected to gain
        # a mask: four arrays per batch.
        sources = dict(features=[[1], [2, 3]], targets=[[4, 5, 6], [7]])
        batches = Batch(DataStream(IterableDataset(sources)), ConstantScheme(2))
        first_batch = next(Padding(batches).get_epoch_iterator())
        assert len(first_batch) == 4
Example #27
Source File: test_transformers.py (from attention-lvcsr, MIT License; 5 votes)
def test_2d_sequences_error_on_unequal_shapes(self):
        # The trailing dimensions differ (4 vs 3), so padding the batch is
        # expected to fail with ValueError when the first batch is drawn.
        sequences = [numpy.ones((3, 4)), 2 * numpy.ones((2, 3))]
        stream = Batch(DataStream(IterableDataset(sequences)), ConstantScheme(2))
        epoch = Padding(stream).get_epoch_iterator()
        assert_raises(ValueError, next, epoch)
Example #28
Source File: test_transformers.py (from attention-lvcsr, MIT License; 5 votes)
def test_2d_sequences(self):
        # Sequences of length 3 and 2 (4 features each) are padded to length
        # 3. The mask marks the padded last step of the second one with 0.
        sequences = [numpy.ones((3, 4)), 2 * numpy.ones((2, 4))]
        stream = Batch(DataStream(IterableDataset(sequences)), ConstantScheme(2))
        data, mask = next(Padding(stream).get_epoch_iterator())
        assert data.shape == (2, 3, 4)
        assert (data[0, :, :] == 1).all()
        assert (data[1, :2, :] == 2).all()
        expected_mask = numpy.array([[1, 1, 1], [1, 1, 0]])
        assert (mask == expected_mask).all()
Example #29
Source File: test_transformers.py (from attention-lvcsr, MIT License; 5 votes)
def test_value_error_on_example_stream(self):
        # The test expects Unpack to reject an unbatched (example-wise)
        # stream with ValueError.
        sources = dict(features=[[1], [2, 3]],
                       targets=[[4, 5, 6], [7]])
        example_stream = DataStream(IterableDataset(sources))
        assert_raises(ValueError, Unpack, example_stream)
Example #30
Source File: test_transformers.py (from fuel, MIT License; 5 votes)
def test_mapping_accepts_list_or_dict(self):
        # Passing mapping_accepts=int (neither list nor dict) is expected to
        # make the Mapping constructor raise ValueError.
        def mapping(d):
            # Double every element of the first source. The trailing comma
            # makes the return value a 1-tuple.
            return [2 * i for i in d[0]],
        stream = DataStream(IterableDataset(self.data))
        # The lambda defers construction so assert_raises can catch the error.
        assert_raises(ValueError,
                      lambda: Mapping(stream, mapping, mapping_accepts=int))