org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader Java Examples

The following examples show how to use org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: CSVSequenceRecordReaderTest.java    From DataVec with Apache License 2.0 6 votes vote down vote up
@Test
public void testReset() throws Exception {
    // Verify that reset() fully rewinds the reader: every pass must yield the
    // same 3 sequences of 4 time steps each.
    CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
    seqReader.initialize(new TestInputSplit());

    final int numResetCycles = 5;
    for (int cycle = 0; cycle < numResetCycles; cycle++) {
        seqReader.reset();

        int sequencesSeen = 0;
        while (seqReader.hasNext()) {
            List<List<Writable>> sequence = seqReader.sequenceRecord();
            // 5 file lines minus the 1 skipped header line = 4 time steps
            assertEquals(4, sequence.size());

            int stepsSeen = 0;
            for (Iterator<List<Writable>> it = sequence.iterator(); it.hasNext(); ) {
                it.next();
                stepsSeen++;
            }
            assertEquals(4, stepsSeen);
            sequencesSeen++;
        }
        assertEquals(3, sequencesSeen);
    }
}
 
Example #2
Source File: CSVSequenceRecordReaderTest.java    From DataVec with Apache License 2.0 6 votes vote down vote up
@Test
public void testCsvSeqAndNumberedFileSplit() throws Exception {
    // Sanity check: a CSVSequenceRecordReader backed by a NumberedFileInputSplit
    // can iterate over all numbered sequence files without error.
    File baseDir = tempDir.newFolder();
    // Extract the three numbered test resources into the temporary folder
    for (int i = 0; i < 3; i++) {
        new org.nd4j.linalg.io.ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir);
    }

    // The %d placeholder is expanded to 0..2 by NumberedFileInputSplit below
    String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath();

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));

    // Consume every sequence; any read failure will throw and fail the test
    while (featureReader.hasNext()) {
        featureReader.nextSequence();
    }
}
 
Example #3
Source File: SameDiffRNNTestCases.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
protected MultiDataSetIterator getTrainingDataUnnormalized() throws Exception {
    int miniBatchSize = 10;
    int numLabelClasses = 6;

    File featuresDirTrain = Files.createTempDir();
    File labelsDirTrain = Files.createTempDir();
    Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/features/", featuresDirTrain);
    Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/labels/", labelsDirTrain);

    SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
    trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
    SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
    trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));

    DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses,
            false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(trainData);

    return iter;
}
 
Example #4
Source File: CSVSequenceRecordReaderTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testReset() throws Exception {
    // reset() must restore the reader so that all sequences can be consumed again,
    // with identical counts on every attempt.
    CSVSequenceRecordReader reader = new CSVSequenceRecordReader(1, ",");
    reader.initialize(new TestInputSplit());

    for (int attempt = 0; attempt < 5; attempt++) {
        reader.reset();

        int seqCount = 0;
        while (reader.hasNext()) {
            List<List<Writable>> seq = reader.sequenceRecord();
            // 5 file lines minus the 1 skipped header line = 4 time steps
            assertEquals(4, seq.size());

            int steps = 0;
            for (List<Writable> ignored : seq) {
                steps++;
            }
            seqCount++;
            assertEquals(4, steps);
        }
        assertEquals(3, seqCount);
    }
}
 
Example #5
Source File: CSVSequenceRecordReaderTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testCsvSeqAndNumberedFileSplit() throws Exception {
    // Simple sanity check: the reader iterates all numbered files without error.
    File baseDir = tempDir.newFolder();
    // Extract the three numbered test resources into the temporary folder
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir);
    }

    // The %d placeholder is expanded to 0..2 by NumberedFileInputSplit below
    String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath();

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));

    // Consume every sequence; any read failure will throw and fail the test
    while (featureReader.hasNext()) {
        featureReader.nextSequence();
    }
}
 
Example #6
Source File: RNNTestCases.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
protected MultiDataSetIterator getTrainingDataUnnormalized() throws Exception {
    int miniBatchSize = 10;
    int numLabelClasses = 6;

    File featuresDirTrain = Files.createTempDir();
    File labelsDirTrain = Files.createTempDir();
    new ClassPathResource("dl4j-integration-tests/data/uci_seq/train/features/").copyDirectory(featuresDirTrain);
    new ClassPathResource("dl4j-integration-tests/data/uci_seq/train/labels/").copyDirectory(labelsDirTrain);

    SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
    trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
    SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
    trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));

    DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses,
            false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(trainData);
    return iter;
}
 
Example #7
Source File: RNNTestCases.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Override
        public MultiDataSetIterator getEvaluationTestData() throws Exception {
            // Builds the UCI evaluation iterator: CSV sequence features/labels,
            // normalized, with labels reduced to the final time step.
            int miniBatchSize = 10;
            int numLabelClasses = 6;

            // Copy the test-set CSV sequences into temp dirs for NumberedFileInputSplit
            File featuresDirTest = Files.createTempDir();
            File labelsDirTest = Files.createTempDir();
            new ClassPathResource("dl4j-integration-tests/data/uci_seq/test/features/").copyDirectory(featuresDirTest);
            new ClassPathResource("dl4j-integration-tests/data/uci_seq/test/labels/").copyDirectory(labelsDirTest);

            SequenceRecordReader testFeatures = new CSVSequenceRecordReader();
            testFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
            SequenceRecordReader testLabels = new CSVSequenceRecordReader();
            testLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149));

            DataSetIterator testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels, miniBatchSize, numLabelClasses,
                    false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

            MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(testData);

            // Keep only the last time step of the labels and drop the label mask,
            // turning the sequence labels into a single-step classification target.
            MultiDataSetPreProcessor pp = multiDataSet -> {
                INDArray l = multiDataSet.getLabels(0);
                l = l.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(l.size(2) - 1));
                multiDataSet.setLabels(0, l);
                multiDataSet.setLabelsMaskArray(0, null);
            };

            iter.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), pp));

            return iter;
        }
 
Example #8
Source File: RecordReaderMultiDataSetIteratorTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testSplittingCSVSequenceMeta() throws Exception {
    //Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep "csvSequencelables_i.txt"
    // as standard one-hot output
    //need to manually extract
    File rootDir = temporaryFolder.newFolder();
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
    }

    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");

    // Readers feeding the multi-data-set iterator under test
    SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
    featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    // Split feature columns [0,1] and [2] into two inputs; one-hot the label column
    RecordReaderMultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2)
                    .addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build();

    srrmdsi.setCollectMetaData(true);

    // Every MultiDataSet must round-trip through its metadata unchanged
    int count = 0;
    while (srrmdsi.hasNext()) {
        MultiDataSet mds = srrmdsi.next();
        MultiDataSet fromMeta = srrmdsi.loadFromMetaData(mds.getExampleMetaData(RecordMetaData.class));
        assertEquals(mds, fromMeta);
        count++;
    }
    assertEquals(3, count);
}
 
Example #9
Source File: RecordReaderDataSetiteratorTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test(expected = ZeroLengthSequenceException.class)
public void testSequenceRecordReaderTwoReadersWithEmptyLabelSequenceThrows() throws Exception {
    // An empty label sequence file must trigger ZeroLengthSequenceException on next().
    SequenceRecordReader features = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labels = new CSVSequenceRecordReader(1, ",");

    features.initialize(new FileSplit(Resources.asFile("csvsequence_0.txt")));
    labels.initialize(new FileSplit(Resources.asFile("empty.txt")));

    new SequenceRecordReaderDataSetIterator(features, labels, 1, -1, true).next();
}
 
Example #10
Source File: RecordReaderDataSetiteratorTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test(expected = ZeroLengthSequenceException.class)
public void testSequenceRecordReaderTwoReadersWithEmptyFeatureSequenceThrows() throws Exception {
    // An empty feature sequence file must trigger ZeroLengthSequenceException on next().
    SequenceRecordReader features = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labels = new CSVSequenceRecordReader(1, ",");

    File emptyFeatures = Resources.asFile("empty.txt");
    features.initialize(new FileSplit(emptyFeatures));
    labels.initialize(new FileSplit(Resources.asFile("csvsequencelabels_0.txt")));

    new SequenceRecordReaderDataSetIterator(features, labels, 1, -1, true).next();
}
 
Example #11
Source File: RecordReaderDataSetiteratorTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test(expected = ZeroLengthSequenceException.class)
public void testSequenceRecordReaderSingleReaderWithEmptySequenceThrows() throws Exception {
    // A single reader over an empty file must throw when the iterator is advanced.
    SequenceRecordReader emptyReader = new CSVSequenceRecordReader(1, ",");
    File emptyFile = Resources.asFile("empty.txt");
    emptyReader.initialize(new FileSplit(emptyFile));

    new SequenceRecordReaderDataSetIterator(emptyReader, 1, -1, 1, true).next();
}
 
Example #12
Source File: RecordReaderDataSetiteratorTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testSequenceRecordReaderReset() throws Exception {
    // Verifies that the iterator can be reset and re-traversed repeatedly,
    // yielding the same number of data sets with identical shapes on every pass.
    File rootDir = temporaryFolder.newFolder();
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        FileUtils.copyFile(Resources.asFile(String.format("csvsequence_%d.txt", i)), new File(rootDir, String.format("csvsequence_%d.txt", i)));
        FileUtils.copyFile(Resources.asFile(String.format("csvsequencelabels_%d.txt", i)), new File(rootDir, String.format("csvsequencelabels_%d.txt", i)));
    }
    // The %d placeholder is expanded to 0..2 by NumberedFileInputSplit below
    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReaderDataSetIterator iter =
                    new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);

    assertEquals(3, iter.inputColumns());
    assertEquals(4, iter.totalOutcomes());

    int nResets = 5;
    for (int i = 0; i < nResets; i++) {
        iter.reset();
        int count = 0;
        while (iter.hasNext()) {
            DataSet ds = iter.next();
            INDArray features = ds.getFeatures();
            INDArray labels = ds.getLabels();
            // Shapes are [miniBatch, columns, timeSteps]
            assertArrayEquals(new long[] {1, 3, 4}, features.shape());
            assertArrayEquals(new long[] {1, 4, 4}, labels.shape());
            count++;
        }
        assertEquals(3, count); // one data set per file pair
    }
}
 
Example #13
Source File: RecordReaderDataSetiteratorTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testSequenceRecordReaderMeta() throws Exception {
    // Data sets reloaded via loadFromMetaData(...) must equal those produced by next().
    File rootDir = temporaryFolder.newFolder();
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        FileUtils.copyFile(Resources.asFile(String.format("csvsequence_%d.txt", i)), new File(rootDir, String.format("csvsequence_%d.txt", i)));
        FileUtils.copyFile(Resources.asFile(String.format("csvsequencelabels_%d.txt", i)), new File(rootDir, String.format("csvsequencelabels_%d.txt", i)));
    }
    // The %d placeholder is expanded to 0..2 by NumberedFileInputSplit below
    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReaderDataSetIterator iter =
                    new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);

    // Required so each DataSet carries RecordMetaData for later reloading
    iter.setCollectMetaData(true);

    assertEquals(3, iter.inputColumns());
    assertEquals(4, iter.totalOutcomes());

    while (iter.hasNext()) {
        DataSet ds = iter.next();
        List<RecordMetaData> meta = ds.getExampleMetaData(RecordMetaData.class);
        DataSet fromMeta = iter.loadFromMetaData(meta);

        assertEquals(ds, fromMeta);
    }
}
 
Example #14
Source File: CSVSequenceRecordReaderTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testMetaData() throws Exception {
    // Sequences reloaded via loadSequenceFromMetaData(...) must match both the raw
    // sequenceRecord() output and the SequenceRecord objects from nextSequence().
    CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
    seqReader.initialize(new TestInputSplit());

    // Pass 1: read raw sequences with sequenceRecord()
    List<List<List<Writable>>> direct = new ArrayList<>();
    while (seqReader.hasNext()) {
        List<List<Writable>> sequence = seqReader.sequenceRecord();
        // 5 file lines minus the 1 skipped header line = 4 time steps
        assertEquals(4, sequence.size());

        int steps = 0;
        for (List<Writable> ignored : sequence) {
            steps++;
        }
        assertEquals(4, steps);

        direct.add(sequence);
    }

    // Pass 2: read SequenceRecord objects, collecting their metadata
    List<SequenceRecord> viaNext = new ArrayList<>();
    List<RecordMetaData> meta = new ArrayList<>();
    seqReader.reset();
    while (seqReader.hasNext()) {
        SequenceRecord sr = seqReader.nextSequence();
        viaNext.add(sr);
        meta.add(sr.getMetaData());
    }
    assertEquals(3, viaNext.size());

    // Pass 3: reload from metadata and compare all three views
    List<SequenceRecord> reloaded = seqReader.loadSequenceFromMetaData(meta);
    for (int i = 0; i < 3; i++) {
        assertEquals(direct.get(i), viaNext.get(i).getSequenceRecord());
        assertEquals(direct.get(i), reloaded.get(i).getSequenceRecord());
    }
}
 
Example #15
Source File: CSVSequenceRecordReaderTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
public void test() throws Exception {
    // Each value in the test data encodes its own position:
    // value = 100 * sequenceIndex + 10 * lineIndex + columnIndex.
    CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
    seqReader.initialize(new TestInputSplit());

    int seqIdx = 0;
    while (seqReader.hasNext()) {
        List<List<Writable>> sequence = seqReader.sequenceRecord();
        // 5 file lines minus the 1 skipped header line = 4 time steps
        assertEquals(4, sequence.size());

        for (int lineIdx = 0; lineIdx < sequence.size(); lineIdx++) {
            List<Writable> timeStep = sequence.get(lineIdx);
            assertEquals(3, timeStep.size()); // 3 columns per time step
            for (int colIdx = 0; colIdx < timeStep.size(); colIdx++) {
                int expected = 100 * seqIdx + 10 * lineIdx + colIdx;
                assertEquals(String.valueOf(expected), timeStep.get(colIdx).toString());
            }
        }
        seqIdx++;
    }
}
 
Example #16
Source File: CSVSequenceRecordReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
@Test
public void testMetaData() throws Exception {
    // Sequences reloaded via loadSequenceFromMetaData(...) must match both the raw
    // sequenceRecord() output and the SequenceRecord objects from nextSequence().
    CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
    seqReader.initialize(new TestInputSplit());

    // Pass 1: collect raw sequences
    List<List<List<Writable>>> l = new ArrayList<>();
    while (seqReader.hasNext()) {
        List<List<Writable>> sequence = seqReader.sequenceRecord();
        assertEquals(4, sequence.size()); // 5 file lines minus 1 skipped header line

        Iterator<List<Writable>> timeStepIter = sequence.iterator();
        int lineCount = 0;
        while (timeStepIter.hasNext()) {
            timeStepIter.next();
            lineCount++;
        }
        assertEquals(4, lineCount);

        l.add(sequence);
    }

    // Pass 2: collect SequenceRecord objects plus their metadata
    List<SequenceRecord> l2 = new ArrayList<>();
    List<RecordMetaData> meta = new ArrayList<>();
    seqReader.reset();
    while (seqReader.hasNext()) {
        SequenceRecord sr = seqReader.nextSequence();
        l2.add(sr);
        meta.add(sr.getMetaData());
    }
    assertEquals(3, l2.size());

    // Pass 3: reload from metadata and compare all three views
    List<SequenceRecord> fromMeta = seqReader.loadSequenceFromMetaData(meta);
    for (int i = 0; i < 3; i++) {
        assertEquals(l.get(i), l2.get(i).getSequenceRecord());
        assertEquals(l.get(i), fromMeta.get(i).getSequenceRecord());
    }
}
 
Example #17
Source File: CSVSequenceRecordReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
@Test
public void test() throws Exception {
    // Reads all sequences and checks every value against its positional encoding:
    // value = 100 * sequenceIndex + 10 * lineIndex + columnIndex.

    CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
    seqReader.initialize(new TestInputSplit());

    int sequenceCount = 0;
    while (seqReader.hasNext()) {
        List<List<Writable>> sequence = seqReader.sequenceRecord();
        assertEquals(4, sequence.size()); // 5 file lines minus 1 skipped header line

        Iterator<List<Writable>> timeStepIter = sequence.iterator();
        int lineCount = 0;
        while (timeStepIter.hasNext()) {
            List<Writable> timeStep = timeStepIter.next();
            assertEquals(3, timeStep.size()); // 3 columns per time step
            Iterator<Writable> lineIter = timeStep.iterator();
            int countInLine = 0;
            while (lineIter.hasNext()) {
                Writable entry = lineIter.next();
                // Expected value encodes the (sequence, line, column) position
                int expValue = 100 * sequenceCount + 10 * lineCount + countInLine;
                assertEquals(String.valueOf(expValue), entry.toString());
                countInLine++;
            }
            lineCount++;
        }
        sequenceCount++;
    }
}
 
Example #18
Source File: SameDiffRNNTestCases.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Override
        public MultiDataSetIterator getEvaluationTestData() throws Exception {
            // Builds the UCI evaluation iterator: CSV sequence features/labels,
            // normalized, with labels reduced to the final time step.
            int miniBatchSize = 10;
            int numLabelClasses = 6;

            // Copy the test-set CSV sequences into temp dirs for NumberedFileInputSplit
            File featuresDirTest = Files.createTempDir();
            File labelsDirTest = Files.createTempDir();
            Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/features/", featuresDirTest);
            Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/labels/", labelsDirTest);

            SequenceRecordReader testFeatures = new CSVSequenceRecordReader();
            testFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
            SequenceRecordReader testLabels = new CSVSequenceRecordReader();
            testLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149));

            DataSetIterator testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels, miniBatchSize, numLabelClasses,
                    false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

            MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(testData);

            // Keep only the last time step of the labels and drop the label mask,
            // turning the sequence labels into a single-step classification target.
            MultiDataSetPreProcessor pp = multiDataSet -> {
                INDArray l = multiDataSet.getLabels(0);
                l = l.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(l.size(2) - 1));
                multiDataSet.setLabels(0, l);
                multiDataSet.setLabelsMaskArray(0, null);
            };

            iter.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), pp));

            return iter;
        }
 
Example #19
Source File: TestDataVecDataSetFunctions.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test
public void testDataVecSequencePairDataSetFunction() throws Exception {
    // Round-trips CSV sequences through a Hadoop sequence file and
    // DataVecSequencePairDataSetFunction on Spark, then verifies the resulting
    // DataSets match a local SequenceRecordReaderDataSetIterator (order-independent).
    JavaSparkContext sc = getContext();

    File f = testDir.newFolder();
    ClassPathResource cpr = new ClassPathResource("dl4j-spark/csvsequence/");
    cpr.copyDirectory(f);
    String path = f.getAbsolutePath() + "/*";

    // Pair each file with itself (same path for both sides), keyed by filename
    PathToKeyConverter pathConverter = new PathToKeyConverterFilename();
    JavaPairRDD<Text, BytesPairWritable> toWrite =
                    DataVecSparkUtil.combineFilesForSequenceFile(sc, path, path, pathConverter);

    Path p = testDir.newFolder("dl4j_testSeqPairFn").toPath();
    p.toFile().deleteOnExit();
    String outPath = p.toString() + "/out";
    new File(outPath).deleteOnExit();
    toWrite.saveAsNewAPIHadoopFile(outPath, Text.class, BytesPairWritable.class, SequenceFileOutputFormat.class);

    //Load from sequence file:
    JavaPairRDD<Text, BytesPairWritable> fromSeq = sc.sequenceFile(outPath, Text.class, BytesPairWritable.class);

    // Parse the byte pairs back into (features, labels) writable sequences
    SequenceRecordReader srr1 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader srr2 = new CSVSequenceRecordReader(1, ",");
    PairSequenceRecordReaderBytesFunction psrbf = new PairSequenceRecordReaderBytesFunction(srr1, srr2);
    JavaRDD<Tuple2<List<List<Writable>>, List<List<Writable>>>> writables = fromSeq.map(psrbf);

    //Map to DataSet:
    DataVecSequencePairDataSetFunction pairFn = new DataVecSequencePairDataSetFunction();
    JavaRDD<DataSet> data = writables.map(pairFn);
    List<DataSet> sparkData = data.collect();


    //Now: do the same thing locally (SequenceRecordReaderDataSetIterator) and compare
    String featuresPath = FilenameUtils.concat(f.getAbsolutePath(), "csvsequence_%d.txt");

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    // Same path for features and labels: locally we also pair each file with itself
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));

    SequenceRecordReaderDataSetIterator iter =
                    new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, -1, true);

    List<DataSet> localData = new ArrayList<>(3);
    while (iter.hasNext())
        localData.add(iter.next());

    assertEquals(3, sparkData.size());
    assertEquals(3, localData.size());

    for (int i = 0; i < 3; i++) {
        //Check shapes etc. data sets order may differ for spark vs. local
        DataSet dsSpark = sparkData.get(i);
        DataSet dsLocal = localData.get(i);

        assertNull(dsSpark.getFeaturesMaskArray());
        assertNull(dsSpark.getLabelsMaskArray());

        INDArray fSpark = dsSpark.getFeatures();
        INDArray fLocal = dsLocal.getFeatures();
        INDArray lSpark = dsSpark.getLabels();
        INDArray lLocal = dsLocal.getLabels();

        val s = new long[] {1, 3, 4}; //1 example, 3 values, 4 time steps
        assertArrayEquals(s, fSpark.shape());
        assertArrayEquals(s, fLocal.shape());
        assertArrayEquals(s, lSpark.shape());
        assertArrayEquals(s, lLocal.shape());
    }


    //Check that results are the same (order not withstanding)
    // Require an exact 1:1 pairwise matching between spark and local data sets
    boolean[] found = new boolean[3];
    for (int i = 0; i < 3; i++) {
        int foundIndex = -1;
        DataSet ds = sparkData.get(i);
        for (int j = 0; j < 3; j++) {
            if (ds.equals(localData.get(j))) {
                if (foundIndex != -1)
                    fail(); //Already found this value -> suggests this spark value equals two or more of local version? (Shouldn't happen)
                foundIndex = j;
                if (found[foundIndex])
                    fail(); //One of the other spark values was equal to this one -> suggests duplicates in Spark list
                found[foundIndex] = true; //mark this one as seen before
            }
        }
    }
    int count = 0;
    for (boolean b : found)
        if (b)
            count++;
    assertEquals(3, count); //Expect all 3 and exactly 3 pairwise matches between spark and local versions
}
 
Example #20
Source File: RecordReaderDataSetiteratorTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test
public void testSequenceRecordReaderMultiRegression() throws Exception {
    // Single-reader regression mode: per the assertions below, each time step is
    // split into 1 input column and 2 regression target columns; exact values are
    // verified for each of the three files, then reset() is re-checked.
    File rootDir = temporaryFolder.newFolder();
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        FileUtils.copyFile(Resources.asFile(String.format("csvsequence_%d.txt", i)), new File(rootDir, String.format("csvsequence_%d.txt", i)));
    }
    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");

    SequenceRecordReader reader = new CSVSequenceRecordReader(1, ",");
    reader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));

    SequenceRecordReaderDataSetIterator iter =
            new SequenceRecordReaderDataSetIterator(reader, 1, 2, 1, true);

    assertEquals(1, iter.inputColumns());
    assertEquals(2, iter.totalOutcomes());

    List<DataSet> dsList = new ArrayList<>();
    while (iter.hasNext()) {
        dsList.add(iter.next());
    }

    assertEquals(3, dsList.size()); //3 files
    for (int i = 0; i < 3; i++) {
        DataSet ds = dsList.get(i);
        INDArray features = ds.getFeatures();
        INDArray labels = ds.getLabels();
        assertArrayEquals(new long[] {1, 1, 4}, features.shape()); //1 examples, 1 values, 4 time steps
        assertArrayEquals(new long[] {1, 2, 4}, labels.shape());

        // Transpose [values, timeSteps] -> [timeSteps, values] for easy comparison
        INDArray f2d = features.get(point(0), all(), all()).transpose();
        INDArray l2d = labels.get(point(0), all(), all()).transpose();

        // Test data values encode their position: 100*file + 10*line + column
        switch (i){
            case 0:
                assertEquals(Nd4j.create(new double[]{0,10,20,30}, new int[]{4,1}).castTo(DataType.FLOAT), f2d);
                assertEquals(Nd4j.create(new double[][]{{1,2}, {11,12}, {21,22}, {31,32}}).castTo(DataType.FLOAT), l2d);
                break;
            case 1:
                assertEquals(Nd4j.create(new double[]{100,110,120,130}, new int[]{4,1}).castTo(DataType.FLOAT), f2d);
                assertEquals(Nd4j.create(new double[][]{{101,102}, {111,112}, {121,122}, {131,132}}).castTo(DataType.FLOAT), l2d);
                break;
            case 2:
                assertEquals(Nd4j.create(new double[]{200,210,220,230}, new int[]{4,1}).castTo(DataType.FLOAT), f2d);
                assertEquals(Nd4j.create(new double[][]{{201,202}, {211,212}, {221,222}, {231,232}}).castTo(DataType.FLOAT), l2d);
                break;
            default:
                throw new RuntimeException();
        }
    }


    // After reset, the iterator must produce the same number of data sets again
    iter.reset();
    int count = 0;
    while (iter.hasNext()) {
        iter.next();
        count++;
    }
    assertEquals(3, count);
}
 
Example #21
Source File: UciSequenceDataFetcher.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Override
public CSVSequenceRecordReader getRecordReader(long rngSeed, int[] shape, DataSetType set, ImageTransform transform) {
    // The shape and transform parameters are ignored by this implementation;
    // delegate to the (rngSeed, set) overload.
    return getRecordReader(rngSeed, set);
}
 
Example #22
Source File: RecordReaderMultiDataSetIteratorTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Verifies that MultiDataSets produced with collected metadata can be reloaded from that
 * metadata and compare equal, for both ALIGN_START and ALIGN_END sequence alignment.
 */
@Test
public void testVariableLengthTSMeta() throws Exception {
    // Resources must be manually extracted into a temporary directory first
    File rootDir = temporaryFolder.newFolder();
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
    }

    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabelsShort_%d.txt");

    // Reader pair backing the ALIGN_START iterator
    SequenceRecordReader startFeatures = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader startLabels = new CSVSequenceRecordReader(1, ",");
    startFeatures.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    startLabels.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    // Reader pair backing the ALIGN_END iterator
    SequenceRecordReader endFeatures = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader endLabels = new CSVSequenceRecordReader(1, ",");
    endFeatures.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    endLabels.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("in", startFeatures).addSequenceReader("out", startLabels).addInput("in")
                    .addOutputOneHot("out", 0, 4)
                    .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START).build();

    RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("in", endFeatures).addSequenceReader("out", endLabels).addInput("in")
                    .addOutputOneHot("out", 0, 4)
                    .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END).build();

    rrmdsiStart.setCollectMetaData(true);
    rrmdsiEnd.setCollectMetaData(true);

    // Each MultiDataSet must round-trip through its own metadata unchanged
    int count = 0;
    while (rrmdsiStart.hasNext()) {
        MultiDataSet mdsStart = rrmdsiStart.next();
        MultiDataSet mdsEnd = rrmdsiEnd.next();

        assertEquals(mdsStart, rrmdsiStart.loadFromMetaData(mdsStart.getExampleMetaData(RecordMetaData.class)));
        assertEquals(mdsEnd, rrmdsiEnd.loadFromMetaData(mdsEnd.getExampleMetaData(RecordMetaData.class)));

        count++;
    }
    assertFalse(rrmdsiStart.hasNext());
    assertFalse(rrmdsiEnd.hasNext());
    assertEquals(3, count);
}
 
Example #23
Source File: RecordReaderMultiDataSetIteratorTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Compares {@code RecordReaderMultiDataSetIterator} against the reference
 * {@code SequenceRecordReaderDataSetIterator} for variable-length sequences
 * (labels are shorter than features), under both ALIGN_START and ALIGN_END
 * alignment. Features, labels, and labels mask arrays must match pairwise.
 * NOTE: the two iterator families are consumed in lockstep inside the loop,
 * so statement order is significant.
 */
@Test
public void testVariableLengthTS() throws Exception {
    //need to manually extract
    File rootDir = temporaryFolder.newFolder();
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
    }

    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabelsShort_%d.txt");

    //Set up SequenceRecordReaderDataSetIterators for comparison
    //Reference readers: one pair per alignment mode (readers are stateful, so each
    //iterator needs its own pair)

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
    featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReaderDataSetIterator iterAlignStart = new SequenceRecordReaderDataSetIterator(featureReader,
                    labelReader, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_START);

    SequenceRecordReaderDataSetIterator iterAlignEnd = new SequenceRecordReaderDataSetIterator(featureReader2,
                    labelReader2, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);


    //Set up the multi-dataset iterators under test (fresh reader pairs again)
    SequenceRecordReader featureReader3 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader3 = new CSVSequenceRecordReader(1, ",");
    featureReader3.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader3.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReader featureReader4 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader4 = new CSVSequenceRecordReader(1, ",");
    featureReader4.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader4.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("in", featureReader3).addSequenceReader("out", labelReader3).addInput("in")
                    .addOutputOneHot("out", 0, 4)
                    .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START).build();

    RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("in", featureReader4).addSequenceReader("out", labelReader4).addInput("in")
                    .addOutputOneHot("out", 0, 4)
                    .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END).build();


    //Consume reference and under-test iterators in lockstep and compare contents
    while (iterAlignStart.hasNext()) {
        DataSet dsStart = iterAlignStart.next();
        DataSet dsEnd = iterAlignEnd.next();

        MultiDataSet mdsStart = rrmdsiStart.next();
        MultiDataSet mdsEnd = rrmdsiEnd.next();

        assertEquals(1, mdsStart.getFeatures().length);
        assertEquals(1, mdsStart.getLabels().length);
        //assertEquals(1, mdsStart.getFeaturesMaskArrays().length); //Features data is always longer -> don't need mask arrays for it
        assertEquals(1, mdsStart.getLabelsMaskArrays().length);

        assertEquals(1, mdsEnd.getFeatures().length);
        assertEquals(1, mdsEnd.getLabels().length);
        //assertEquals(1, mdsEnd.getFeaturesMaskArrays().length);
        assertEquals(1, mdsEnd.getLabelsMaskArrays().length);


        assertEquals(dsStart.getFeatures(), mdsStart.getFeatures(0));
        assertEquals(dsStart.getLabels(), mdsStart.getLabels(0));
        assertEquals(dsStart.getLabelsMaskArray(), mdsStart.getLabelsMaskArray(0));

        assertEquals(dsEnd.getFeatures(), mdsEnd.getFeatures(0));
        assertEquals(dsEnd.getLabels(), mdsEnd.getLabels(0));
        assertEquals(dsEnd.getLabelsMaskArray(), mdsEnd.getLabelsMaskArray(0));
    }
    assertFalse(rrmdsiStart.hasNext());
    assertFalse(rrmdsiEnd.hasNext());
}
 
Example #24
Source File: RecordReaderMultiDataSetIteratorTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Splits "csvsequence_i.txt" into two separate inputs (columns [0,1] and [2]) via
 * RecordReaderMultiDataSetIterator, keeping "csvsequencelabels_i.txt" as a standard
 * one-hot output, and checks the result against a plain reference iterator.
 */
@Test
public void testSplittingCSVSequence() throws Exception {
    // Resources must be manually extracted into a temporary directory first
    File rootDir = temporaryFolder.newFolder();
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
    }

    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");

    // Reference iterator: unsplit features, one-hot labels
    SequenceRecordReader refFeatureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader refLabelReader = new CSVSequenceRecordReader(1, ",");
    refFeatureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    refLabelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReaderDataSetIterator iter =
                    new SequenceRecordReaderDataSetIterator(refFeatureReader, refLabelReader, 1, 4, false);

    // Iterator under test: same data, but features split into two separate inputs
    SequenceRecordReader splitFeatureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader splitLabelReader = new CSVSequenceRecordReader(1, ",");
    splitFeatureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    splitLabelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("seq1", splitFeatureReader).addSequenceReader("seq2", splitLabelReader)
                    .addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build();

    // Consume both iterators in lockstep and compare every mini-batch
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        INDArray fds = ds.getFeatures();
        INDArray lds = ds.getLabels();

        MultiDataSet mds = srrmdsi.next();
        assertEquals(2, mds.getFeatures().length);
        assertEquals(1, mds.getLabels().length);
        assertNull(mds.getFeaturesMaskArrays());
        assertNull(mds.getLabelsMaskArrays());
        INDArray[] fmds = mds.getFeatures();
        INDArray[] lmds = mds.getLabels();

        assertNotNull(fmds);
        assertNotNull(lmds);
        for (INDArray featureArr : fmds)
            assertNotNull(featureArr);
        for (INDArray labelArr : lmds)
            assertNotNull(labelArr);

        // Expected inputs are column ranges sliced from the reference features
        INDArray expIn1 = fds.get(all(), NDArrayIndex.interval(0, 1, true), all());
        INDArray expIn2 = fds.get(all(), NDArrayIndex.interval(2, 2, true), all());

        assertEquals(expIn1, fmds[0]);
        assertEquals(expIn2, fmds[1]);
        assertEquals(lds, lmds[0]);
    }
    assertFalse(srrmdsi.hasNext());
}
 
Example #25
Source File: LstmTimeSeriesExample.java    From Java-Deep-Learning-Cookbook with MIT License 4 votes vote down vote up
/**
 * Trains an LSTM computation graph on the PhysioNet mortality time-series data and
 * evaluates it with ROC/AUC. Expects {@code FEATURE_DIR}/{@code LABEL_DIR} to point at
 * directories of numbered CSV sequence files ({@code 0.csv} .. {@code 3999.csv}).
 *
 * @throws IOException          if the CSV sequence files cannot be read
 * @throws InterruptedException if record-reader initialization is interrupted
 */
public static void main(String[] args) throws IOException, InterruptedException {
    // BUGFIX: the label placeholder previously lacked the closing '}'
    // ("{PATH-TO-PHYSIONET-LABELS"), so an unconfigured LABEL_DIR slipped past this guard.
    if(FEATURE_DIR.equals("{PATH-TO-PHYSIONET-FEATURES}") || LABEL_DIR.equals("{PATH-TO-PHYSIONET-LABELS}")){
        System.out.println("Please provide proper directory path in place of: PATH-TO-PHYSIONET-FEATURES && PATH-TO-PHYSIONET-LABELS");
        throw new FileNotFoundException("FEATURE_DIR/LABEL_DIR placeholders have not been replaced with real paths");
    }
    // Training data: files 0..3199; skip 1 header line in the feature CSVs
    SequenceRecordReader trainFeaturesReader = new CSVSequenceRecordReader(1, ",");
    trainFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR+"/%d.csv",0,3199));
    SequenceRecordReader trainLabelsReader = new CSVSequenceRecordReader();
    trainLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR+"/%d.csv",0,3199));
    DataSetIterator trainDataSetIterator = new SequenceRecordReaderDataSetIterator(trainFeaturesReader,trainLabelsReader,100,2,false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // Test data: files 3200..3999
    SequenceRecordReader testFeaturesReader = new CSVSequenceRecordReader(1, ",");
    testFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR+"/%d.csv",3200,3999));
    SequenceRecordReader testLabelsReader = new CSVSequenceRecordReader();
    testLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR+"/%d.csv",3200,3999));
    DataSetIterator testDataSetIterator = new SequenceRecordReaderDataSetIterator(testFeaturesReader,testLabelsReader,100,2,false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // 86 input features -> LSTM(200) -> 2-class softmax RNN output (mortality yes/no)
    ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()
                                                    .seed(RANDOM_SEED)
                                                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                                                    .weightInit(WeightInit.XAVIER)
                                                    .updater(new Adam())
                                                    .dropOut(0.9)
                                                    .graphBuilder()
                                                    .addInputs("trainFeatures")
                                                    .setOutputs("predictMortality")
                                                    .addLayer("L1", new LSTM.Builder()
                                                                                   .nIn(86)
                                                                                    .nOut(200)
                                                                                    .forgetGateBiasInit(1)
                                                                                    .activation(Activation.TANH)
                                                                                    .build(),"trainFeatures")
                                                    .addLayer("predictMortality", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                                                                        .activation(Activation.SOFTMAX)
                                                                                        .nIn(200).nOut(2).build(),"L1")
                                                    .build();

    ComputationGraph model = new ComputationGraph(configuration);
    model.init(); // explicit init; guarded internally, so safe even if fit() would auto-init

    // Single training epoch (increase the bound for more epochs)
    final int numEpochs = 1;
    for (int epoch = 0; epoch < numEpochs; epoch++) {
       model.fit(trainDataSetIterator);
       trainDataSetIterator.reset();
    }
    // Evaluate with ROC (100 threshold steps) over the held-out test set
    ROC evaluation = new ROC(100);
    while (testDataSetIterator.hasNext()) {
        DataSet batch = testDataSetIterator.next();
        INDArray[] output = model.output(batch.getFeatures());
        evaluation.evalTimeSeries(batch.getLabels(), output[0]);
    }

    System.out.println(evaluation.calculateAUC());
    System.out.println(evaluation.stats());
}
 
Example #26
Source File: RecordReaderDataSetiteratorTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Reads three paired feature/label CSV sequence files through
 * {@code SequenceRecordReaderDataSetIterator} and checks shapes plus exact
 * feature and one-hot label contents against hand-built expected arrays.
 */
@Test
public void testSequenceRecordReader() throws Exception {
    File rootDir = temporaryFolder.newFolder();
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        FileUtils.copyFile(Resources.asFile(String.format("csvsequence_%d.txt", i)), new File(rootDir, String.format("csvsequence_%d.txt", i)));
        FileUtils.copyFile(Resources.asFile(String.format("csvsequencelabels_%d.txt", i)), new File(rootDir, String.format("csvsequencelabels_%d.txt", i)));
    }
    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReaderDataSetIterator iter =
                    new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);

    assertEquals(3, iter.inputColumns());
    assertEquals(4, iter.totalOutcomes());

    // Collect everything first, then verify shapes and contents
    List<DataSet> dsList = new ArrayList<>();
    while (iter.hasNext()) {
        dsList.add(iter.next());
    }

    assertEquals(3, dsList.size()); //3 files
    for (int i = 0; i < 3; i++) {
        DataSet ds = dsList.get(i);
        INDArray features = ds.getFeatures();
        INDArray labels = ds.getLabels();
        assertEquals(1, features.size(0)); //1 example in mini-batch
        assertEquals(1, labels.size(0));
        assertEquals(3, features.size(1)); //3 values per line/time step
        assertEquals(4, labels.size(1)); //1 value per line, but 4 possible values -> one-hot vector
        assertEquals(4, features.size(2)); //sequence length = 4
        assertEquals(4, labels.size(2));
    }

    //Check features vs. expected:
    //Each tensorAlongDimension(t, 1) assigns the 3 feature values for one time step t
    INDArray expF0 = Nd4j.create(1, 3, 4);
    expF0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {0, 1, 2}));
    expF0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {10, 11, 12}));
    expF0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {20, 21, 22}));
    expF0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {30, 31, 32}));
    assertEquals(dsList.get(0).getFeatures(), expF0);

    INDArray expF1 = Nd4j.create(1, 3, 4);
    expF1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {100, 101, 102}));
    expF1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {110, 111, 112}));
    expF1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {120, 121, 122}));
    expF1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {130, 131, 132}));
    assertEquals(dsList.get(1).getFeatures(), expF1);

    INDArray expF2 = Nd4j.create(1, 3, 4);
    expF2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {200, 201, 202}));
    expF2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {210, 211, 212}));
    expF2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {220, 221, 222}));
    expF2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {230, 231, 232}));
    assertEquals(dsList.get(2).getFeatures(), expF2);

    //Check labels vs. expected:
    //Each time step's label is a one-hot vector over the 4 classes
    INDArray expL0 = Nd4j.create(1, 4, 4);
    expL0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {1, 0, 0, 0}));
    expL0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {0, 1, 0, 0}));
    expL0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {0, 0, 1, 0}));
    expL0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {0, 0, 0, 1}));
    assertEquals(dsList.get(0).getLabels(), expL0);

    INDArray expL1 = Nd4j.create(1, 4, 4);
    expL1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {0, 0, 0, 1}));
    expL1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {0, 0, 1, 0}));
    expL1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {0, 1, 0, 0}));
    expL1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {1, 0, 0, 0}));
    assertEquals(dsList.get(1).getLabels(), expL1);

    INDArray expL2 = Nd4j.create(1, 4, 4);
    expL2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {0, 1, 0, 0}));
    expL2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {1, 0, 0, 0}));
    expL2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {0, 0, 0, 1}));
    expL2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {0, 0, 1, 0}));
    assertEquals(dsList.get(2).getLabels(), expL2);
}
 
Example #27
Source File: TestStreamInputSplit.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Verifies that a seeded {@code StreamInputSplit} yields a different sequence order
 * on each of the first three passes (reset reshuffles), while still producing all
 * three sequences every time.
 */
@Test
public void testShuffle() throws Exception {
    File dir = testDir.newFolder();
    File f1 = new File(dir, "file1.txt");
    File f2 = new File(dir, "file2.txt");
    File f3 = new File(dir, "file3.txt");

    FileUtils.writeStringToFile(f1, "a,b,c", StandardCharsets.UTF_8);
    FileUtils.writeStringToFile(f2, "1,2,3", StandardCharsets.UTF_8);
    FileUtils.writeStringToFile(f3, "x,y,z", StandardCharsets.UTF_8);

    List<URI> uris = Arrays.asList(f1.toURI(), f2.toURI(), f3.toURI());

    CSVSequenceRecordReader rr = new CSVSequenceRecordReader();

    TestStreamFunction fn = new TestStreamFunction();
    InputSplit is = new StreamInputSplit(uris, fn, new Random(12345));
    rr.initialize(is);

    // Read all sequences three times, resetting between passes
    List<List<List<Writable>>> firstPass = drainSequences(rr);
    rr.reset();
    List<List<List<Writable>>> secondPass = drainSequences(rr);
    rr.reset();
    List<List<List<Writable>>> thirdPass = drainSequences(rr);

    assertEquals(3, firstPass.size());
    assertEquals(3, secondPass.size());
    assertEquals(3, thirdPass.size());

    // With this RNG seed, the shuffle order differs across the first three resets
    assertNotEquals(firstPass, secondPass);
    assertNotEquals(secondPass, thirdPass);
    assertNotEquals(firstPass, thirdPass);
}

/** Reads every remaining sequence from the reader into a list. */
private static List<List<List<Writable>>> drainSequences(CSVSequenceRecordReader rr) {
    List<List<List<Writable>>> sequences = new ArrayList<>();
    while (rr.hasNext()) {
        sequences.add(rr.sequenceRecord());
    }
    return sequences;
}
 
Example #28
Source File: TestStreamInputSplit.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Reads two small CSV files through a {@code StreamInputSplit} with a custom
 * stream-loading function, checks the exact sequence contents, confirms the custom
 * function (not the default) was used, and verifies reset support.
 */
@Test
public void testCsvSequenceSimple() throws Exception {

    File dir = testDir.newFolder();
    File f1 = new File(dir, "file1.txt");
    File f2 = new File(dir, "file2.txt");

    FileUtils.writeStringToFile(f1, "a,b,c\nd,e,f", StandardCharsets.UTF_8);
    FileUtils.writeStringToFile(f2, "1,2,3", StandardCharsets.UTF_8);

    List<URI> uris = Arrays.asList(f1.toURI(), f2.toURI());

    CSVSequenceRecordReader rr = new CSVSequenceRecordReader();

    TestStreamFunction fn = new TestStreamFunction();
    InputSplit is = new StreamInputSplit(uris, fn);
    rr.initialize(is);

    // Expected: file1 -> 2-step sequence, file2 -> 1-step sequence
    List<List<List<Writable>>> expected = new ArrayList<>();
    expected.add(Arrays.asList(
            Arrays.<Writable>asList(new Text("a"), new Text("b"), new Text("c")),
            Arrays.<Writable>asList(new Text("d"), new Text("e"), new Text("f"))));
    expected.add(Arrays.asList(
            Arrays.<Writable>asList(new Text("1"), new Text("2"), new Text("3"))));

    List<List<List<Writable>>> actual = new ArrayList<>();
    while (rr.hasNext()) {
        actual.add(rr.sequenceRecord());
    }

    assertEquals(expected, actual);

    // Check that the specified stream loading function was used, not the default:
    assertEquals(uris, fn.calledWithUris);

    // After a reset, both sequences should be readable again
    rr.reset();
    int sequenceCount = 0;
    while (rr.hasNext()) {
        rr.sequenceRecord();
        sequenceCount++;
    }
    assertEquals(2, sequenceCount);
}
 
Example #29
Source File: FileBatchRecordReaderTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * FileBatch round-trip: 10 CSV sequence files (3 lines each) are packed into a
 * {@code FileBatch} and read back through {@code FileBatchSequenceRecordReader};
 * contents must match across three reset cycles.
 */
@Test
public void testCsvSequence() throws Exception {
    //CSV sequence - 3 lines per file, 10 files
    File baseDir = testDir.newFolder();

    List<File> fileList = new ArrayList<>();
    for (int i = 0; i < 10; i++) {
        // Lines are "file_<i>,<i>,<j>" for j = 0..2, joined without a trailing newline
        List<String> lines = new ArrayList<>();
        for (int j = 0; j < 3; j++) {
            lines.add("file_" + i + "," + i + "," + j);
        }
        File f = new File(baseDir, "origFile" + i + ".csv");
        FileUtils.writeStringToFile(f, String.join("\n", lines), StandardCharsets.UTF_8);
        fileList.add(f);
    }

    FileBatch fb = FileBatch.forFiles(fileList);
    SequenceRecordReader rr = new CSVSequenceRecordReader();
    FileBatchSequenceRecordReader fbrr = new FileBatchSequenceRecordReader(rr, fb);


    // Three full passes, resetting between them
    for (int pass = 0; pass < 3; pass++) {
        for (int i = 0; i < 10; i++) {
            assertTrue(fbrr.hasNext());
            List<List<Writable>> sequence = fbrr.sequenceRecord();
            assertEquals(3, sequence.size());
            int stepIdx = 0;
            for (List<Writable> step : sequence) {
                assertEquals("file_" + i, step.get(0).toString());
                assertEquals(String.valueOf(i), step.get(1).toString());
                assertEquals(String.valueOf(stepIdx++), step.get(2).toString());
            }
        }
        assertFalse(fbrr.hasNext());
        assertTrue(fbrr.resetSupported());
        fbrr.reset();
    }
}
 
Example #30
Source File: LstmTimeSeriesExample.java    From Java-Deep-Learning-Cookbook with MIT License 4 votes vote down vote up
public static void main(String[] args) throws IOException, InterruptedException {
    if(FEATURE_DIR.equals("{PATH-TO-PHYSIONET-FEATURES}") || LABEL_DIR.equals("{PATH-TO-PHYSIONET-LABELS")){
        System.out.println("Please provide proper directory path in place of: PATH-TO-PHYSIONET-FEATURES && PATH-TO-PHYSIONET-LABELS");
        throw new FileNotFoundException();
    }
    SequenceRecordReader trainFeaturesReader = new CSVSequenceRecordReader(1, ",");
    trainFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR+"/%d.csv",0,3199));
    SequenceRecordReader trainLabelsReader = new CSVSequenceRecordReader();
    trainLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR+"/%d.csv",0,3199));
    DataSetIterator trainDataSetIterator = new SequenceRecordReaderDataSetIterator(trainFeaturesReader,trainLabelsReader,100,2,false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    SequenceRecordReader testFeaturesReader = new CSVSequenceRecordReader(1, ",");
    testFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR+"/%d.csv",3200,3999));
    SequenceRecordReader testLabelsReader = new CSVSequenceRecordReader();
    testLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR+"/%d.csv",3200,3999));
    DataSetIterator testDataSetIterator = new SequenceRecordReaderDataSetIterator(testFeaturesReader,testLabelsReader,100,2,false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()
                                                    .seed(RANDOM_SEED)
                                                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                                                    .weightInit(WeightInit.XAVIER)
                                                    .updater(new Adam())
                                                    .dropOut(0.9)
                                                    .graphBuilder()
                                                    .addInputs("trainFeatures")
                                                    .setOutputs("predictMortality")
                                                    .addLayer("L1", new LSTM.Builder()
                                                                                   .nIn(86)
                                                                                    .nOut(200)
                                                                                    .forgetGateBiasInit(1)
                                                                                    .activation(Activation.TANH)
                                                                                    .build(),"trainFeatures")
                                                    .addLayer("predictMortality", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                                                                        .activation(Activation.SOFTMAX)
                                                                                        .nIn(200).nOut(2).build(),"L1")
                                                    .build();

    ComputationGraph model = new ComputationGraph(configuration);

    for(int i=0;i<1;i++){
       model.fit(trainDataSetIterator);
       trainDataSetIterator.reset();
    }
    ROC evaluation = new ROC(100);
    while (testDataSetIterator.hasNext()) {
        DataSet batch = testDataSetIterator.next();
        INDArray[] output = model.output(batch.getFeatures());
        evaluation.evalTimeSeries(batch.getLabels(), output[0]);
    }
    
    System.out.println(evaluation.calculateAUC());
    System.out.println(evaluation.stats());
}