Java Code Examples for org.datavec.api.records.reader.SequenceRecordReader#initialize()

The following examples show how to use org.datavec.api.records.reader.SequenceRecordReader#initialize(). You can vote up the examples you like or vote down the ones you don't, and you can go to the original project or source file by following the links above each example. Related API usage is listed in the sidebar.
Example 1
Source File: CodecReaderTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Reads a video as a single sequence with a {@link CodecRecordReader}, then checks that the
 * same sequence is returned after a reset and when reloaded from its record metadata.
 */
@Test
public void testCodecReaderMeta() throws Exception {
    File video = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();

    SequenceRecordReader codecReader = new CodecRecordReader();

    Configuration codecConf = new Configuration();
    codecConf.set(CodecRecordReader.RAVEL, "true");
    codecConf.set(CodecRecordReader.START_FRAME, "160");
    codecConf.set(CodecRecordReader.TOTAL_FRAMES, "500");
    codecConf.set(CodecRecordReader.ROWS, "80");
    codecConf.set(CodecRecordReader.COLUMNS, "46");

    // The reader is initialized on the file first; the codec settings are applied afterwards.
    codecReader.initialize(new FileSplit(video));
    codecReader.setConf(codecConf);
    assertTrue(codecReader.hasNext());

    List<List<Writable>> frames = codecReader.sequenceRecord();
    assertEquals(500, frames.size()); //500 frames

    // Second pass: same data, this time together with its metadata
    codecReader.reset();
    SequenceRecord withMeta = codecReader.nextSequence();
    assertEquals(frames, withMeta.getSequenceRecord());

    RecordMetaData metaData = withMeta.getMetaData();
    String uri = metaData.getURI().toString();
    assertTrue(uri.endsWith(video.getName()));

    // Third pass: reload purely from the metadata
    SequenceRecord reloaded = codecReader.loadSequenceFromMetaData(metaData);
    assertEquals(withMeta, reloaded);
}
 
Example 2
Source File: CodecReaderTest.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * Reads a video with a {@link CodecRecordReader} in "ravel" mode and checks that the first
 * time step holds exactly one writable of length 80 * 46 * 3.
 */
@Test
public void testCodecReader() throws Exception {
    File video = new ClassPathResource("fire_lowres.mp4").getFile();

    SequenceRecordReader codecReader = new CodecRecordReader();

    Configuration codecConf = new Configuration();
    codecConf.set(CodecRecordReader.RAVEL, "true");
    codecConf.set(CodecRecordReader.START_FRAME, "160");
    codecConf.set(CodecRecordReader.TOTAL_FRAMES, "500");
    codecConf.set(CodecRecordReader.ROWS, "80");
    codecConf.set(CodecRecordReader.COLUMNS, "46");

    codecReader.initialize(new FileSplit(video));
    codecReader.setConf(codecConf);
    assertTrue(codecReader.hasNext());

    List<List<Writable>> frames = codecReader.sequenceRecord();
    List<Writable> firstFrame = frames.iterator().next();

    //Expected size: 80x46x3
    assertEquals(1, firstFrame.size());
    ArrayWritable flattened = (ArrayWritable) firstFrame.iterator().next();
    assertEquals(80 * 46 * 3, flattened.length());
}
 
Example 3
Source File: CSVSequenceRecordReaderTest.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * Sanity check: a {@link CSVSequenceRecordReader} initialized with a
 * {@link NumberedFileInputSplit} can be iterated to the end without throwing.
 * No assertions are made on the content of the sequences.
 */
@Test
public void testCsvSeqAndNumberedFileSplit() throws Exception {
    File baseDir = tempDir.newFolder();
    //Make csvsequence_0.txt .. csvsequence_2.txt available under baseDir
    for (int i = 0; i < 3; i++) {
        new org.nd4j.linalg.io.ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir);
    }

    //Path pattern for the numbered split; indices 0 and 2 are passed as its bounds below
    String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath();

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));

    //Drain the reader: the test passes if every sequence loads without an exception
    while (featureReader.hasNext()) {
        featureReader.nextSequence();
    }
}
 
Example 4
Source File: CSVSequenceRecordReaderTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Sanity check: a {@link CSVSequenceRecordReader} initialized with a
 * {@link NumberedFileInputSplit} can be iterated to the end without throwing.
 * No assertions are made on the content of the sequences.
 */
@Test
public void testCsvSeqAndNumberedFileSplit() throws Exception {
    File baseDir = tempDir.newFolder();
    //Make csvsequence_0.txt .. csvsequence_2.txt available under baseDir
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir);
    }

    //Path pattern for the numbered split; indices 0 and 2 are passed as its bounds below
    String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath();

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));

    //Drain the reader: the test passes if every sequence loads without an exception
    while (featureReader.hasNext()) {
        featureReader.nextSequence();
    }
}
 
Example 5
Source File: CSVLineSequenceRecordReaderTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Writes a two-line CSV file and checks that {@link CSVLineSequenceRecordReader} returns each
 * line as its own sequence with one single-value step per CSV field, on three consecutive
 * passes separated by {@code reset()}.
 */
@Test
public void test() throws Exception {
    File dir = testDir.newFolder();
    File csv = new File(dir, "temp.csv");
    FileUtils.writeStringToFile(csv, "a,b,c\n1,2,3,4", StandardCharsets.UTF_8);

    SequenceRecordReader reader = new CSVLineSequenceRecordReader();
    reader.initialize(new FileSplit(csv));

    // Expected sequence for the line "a,b,c"
    List<List<Writable>> firstLine = Arrays.asList(
            Collections.<Writable>singletonList(new Text("a")),
            Collections.<Writable>singletonList(new Text("b")),
            Collections.<Writable>singletonList(new Text("c")));

    // Expected sequence for the line "1,2,3,4"
    List<List<Writable>> secondLine = Arrays.asList(
            Collections.<Writable>singletonList(new Text("1")),
            Collections.<Writable>singletonList(new Text("2")),
            Collections.<Writable>singletonList(new Text("3")),
            Collections.<Writable>singletonList(new Text("4")));

    for (int pass = 0; pass < 3; pass++) {
        int seen = 0;
        while (reader.hasNext()) {
            List<List<Writable>> actual = reader.sequenceRecord();
            List<List<Writable>> expected = (seen == 0) ? firstLine : secondLine;
            seen++;
            assertEquals(expected, actual);
        }

        assertEquals(2, seen);

        reader.reset();
    }
}
 
Example 6
Source File: CodecReaderTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that {@link NativeCodecRecordReader} yields the same sequence whether the video is
 * read through an initialized {@link FileSplit} or passed directly as a {@link DataInputStream}.
 */
@Ignore
@Test
public void testNativeViaDataInputStream() throws Exception {

    File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();
    SequenceRecordReader reader = new NativeCodecRecordReader();
    Configuration conf = new Configuration();
    conf.set(CodecRecordReader.RAVEL, "true");
    conf.set(CodecRecordReader.START_FRAME, "160");
    conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
    conf.set(CodecRecordReader.ROWS, "80");
    conf.set(CodecRecordReader.COLUMNS, "46");

    // Separate copy of the settings for the second, stream-based reader
    Configuration conf2 = new Configuration(conf);

    reader.initialize(new FileSplit(file));
    reader.setConf(conf);
    assertTrue(reader.hasNext());
    List<List<Writable>> expected = reader.sequenceRecord();


    // The second reader is never initialized with a split; it only receives the raw stream
    SequenceRecordReader reader2 = new NativeCodecRecordReader();
    reader2.setConf(conf2);

    // try-with-resources so the underlying file handle is released even if reading fails
    List<List<Writable>> actual;
    try (DataInputStream dataInputStream = new DataInputStream(new FileInputStream(file))) {
        actual = reader2.sequenceRecord(null, dataInputStream);
    }

    assertEquals(expected, actual);
}
 
Example 7
Source File: RegexRecordReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Reads two numbered log files with a {@link RegexSequenceRecordReader} and checks that
 * {@code sequenceRecord()}, {@code nextSequence()} and {@code loadSequenceFromMetaData(...)}
 * all return the same sequences.
 */
@Test
public void testRegexSequenceRecordReaderMeta() throws Exception {
    String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";

    String path = new ClassPathResource("/logtestdata/logtestfile0.txt").getFile().toURI().toString();
    //Turn only the file index into a placeholder. A blanket replace("0", "%d") would also
    //rewrite any '0' in parent directory names or in URI escapes such as %20.
    path = path.replace("logtestfile0.txt", "logtestfile%d.txt");
    InputSplit is = new NumberedFileInputSplit(path, 0, 1);

    SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
    rr.initialize(is);

    //First pass: plain sequence data
    List<List<List<Writable>>> out = new ArrayList<>();
    while (rr.hasNext()) {
        out.add(rr.sequenceRecord());
    }

    assertEquals(2, out.size());

    //Second pass: sequence data together with metadata
    List<List<List<Writable>>> out2 = new ArrayList<>();
    List<SequenceRecord> out3 = new ArrayList<>();
    List<RecordMetaData> meta = new ArrayList<>();
    rr.reset();
    while (rr.hasNext()) {
        SequenceRecord seqr = rr.nextSequence();
        out2.add(seqr.getSequenceRecord());
        out3.add(seqr);
        meta.add(seqr.getMetaData());
    }

    //Reload every sequence from its metadata alone
    List<SequenceRecord> fromMeta = rr.loadSequenceFromMetaData(meta);

    assertEquals(out, out2);
    assertEquals(out3, fromMeta);
}
 
Example 8
Source File: SameDiffRNNTestCases.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Override
        public MultiDataSetIterator getEvaluationTestData() throws Exception {
            // Builds the evaluation iterator: copies the test features/labels to temp
            // directories, reads them as CSV sequences, and attaches a pre-processor that
            // applies getNormalizer() and then keeps only the last index along label dimension 2.
            int miniBatchSize = 10;
            int numLabelClasses = 6;

            File featuresDirTest = Files.createTempDir();
            Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/features/", featuresDirTest);
            File labelsDirTest = Files.createTempDir();
            Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/labels/", labelsDirTest);

            String featuresPattern = featuresDirTest.getAbsolutePath() + "/%d.csv";
            String labelsPattern = labelsDirTest.getAbsolutePath() + "/%d.csv";

            SequenceRecordReader featureReader = new CSVSequenceRecordReader();
            featureReader.initialize(new NumberedFileInputSplit(featuresPattern, 0, 149));
            SequenceRecordReader labelReader = new CSVSequenceRecordReader();
            labelReader.initialize(new NumberedFileInputSplit(labelsPattern, 0, 149));

            DataSetIterator testData = new SequenceRecordReaderDataSetIterator(featureReader, labelReader,
                    miniBatchSize, numLabelClasses, false,
                    SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

            MultiDataSetIterator evalIterator = new MultiDataSetIteratorAdapter(testData);

            // Replace the labels by their slice at the last index of dimension 2
            // (presumably the time axis - confirm) and drop the label mask.
            MultiDataSetPreProcessor lastStepLabels = mds -> {
                INDArray labels = mds.getLabels(0);
                INDArray lastStep = labels.get(NDArrayIndex.all(), NDArrayIndex.all(),
                        NDArrayIndex.point(labels.size(2) - 1));
                mds.setLabels(0, lastStep);
                mds.setLabelsMaskArray(0, null);
            };

            evalIterator.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), lastStepLabels));

            return evalIterator;
        }
 
Example 9
Source File: CSVNLinesSequenceRecordReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that a {@link CSVNLinesSequenceRecordReader} returns the same sequences from
 * {@code sequenceRecord()}, {@code nextSequence()} and {@code loadSequenceFromMetaData(...)}.
 */
@Test
public void testCSVNlinesSequenceRecordReaderMetaData() throws Exception {
    int nLinesPerSequence = 10;

    SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
    seqRR.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));

    //First pass: plain sequence data
    List<List<List<Writable>>> out = new ArrayList<>();
    while (seqRR.hasNext()) {
        List<List<Writable>> next = seqRR.sequenceRecord();
        out.add(next);
    }

    //Second pass: sequence data together with metadata
    seqRR.reset();
    List<List<List<Writable>>> out2 = new ArrayList<>();
    List<SequenceRecord> out3 = new ArrayList<>();
    List<RecordMetaData> meta = new ArrayList<>();
    while (seqRR.hasNext()) {
        SequenceRecord seq = seqRR.nextSequence();
        out2.add(seq.getSequenceRecord());
        meta.add(seq.getMetaData());
        out3.add(seq);
    }

    assertEquals(out, out2);

    //Reload every sequence from its metadata alone
    List<SequenceRecord> out4 = seqRR.loadSequenceFromMetaData(meta);
    assertEquals(out3, out4);
}
 
Example 10
Source File: CSVNLinesSequenceRecordReaderTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that a {@link CSVNLinesSequenceRecordReader} returns the same sequences from
 * {@code sequenceRecord()}, {@code nextSequence()} and {@code loadSequenceFromMetaData(...)}.
 */
@Test
public void testCSVNlinesSequenceRecordReaderMetaData() throws Exception {
    int nLinesPerSequence = 10;

    SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
    seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));

    //First pass: plain sequence data
    List<List<List<Writable>>> out = new ArrayList<>();
    while (seqRR.hasNext()) {
        List<List<Writable>> next = seqRR.sequenceRecord();
        out.add(next);
    }

    //Second pass: sequence data together with metadata
    seqRR.reset();
    List<List<List<Writable>>> out2 = new ArrayList<>();
    List<SequenceRecord> out3 = new ArrayList<>();
    List<RecordMetaData> meta = new ArrayList<>();
    while (seqRR.hasNext()) {
        SequenceRecord seq = seqRR.nextSequence();
        out2.add(seq.getSequenceRecord());
        meta.add(seq.getMetaData());
        out3.add(seq);
    }

    assertEquals(out, out2);

    //Reload every sequence from its metadata alone
    List<SequenceRecord> out4 = seqRR.loadSequenceFromMetaData(meta);
    assertEquals(out3, out4);
}
 
Example 11
Source File: CSVVariableSlidingWindowRecordReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Iterates a {@link CSVVariableSlidingWindowRecordReader} over iris.dat and spot-checks the
 * window at selected positions against a plain {@link CSVRecordReader} over the same file.
 */
@Test
public void testCSVVariableSlidingWindowRecordReader() throws Exception {
    int maxLinesPerSequence = 3;

    SequenceRecordReader windowReader = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence);
    windowReader.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));

    // Line-by-line reader used to build the expected window content
    CSVRecordReader lineReader = new CSVRecordReader();
    lineReader.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));

    int count = 0;
    while (windowReader.hasNext()) {
        List<List<Writable>> window = windowReader.sequenceRecord();

        if (count == 0) { // first seq should be length 1
            assertEquals(1, window.size());
        }
        if (count == maxLinesPerSequence - 1) {
            // Expected window: the first maxLinesPerSequence lines, most recent line first
            LinkedList<List<Writable>> expected = new LinkedList<>();
            int linesRead = 0;
            while (linesRead < maxLinesPerSequence) {
                expected.addFirst(lineReader.next());
                linesRead++;
            }
            assertEquals(expected, window);
        }
        if (count == maxLinesPerSequence) {
            assertEquals(maxLinesPerSequence, window.size());
        }
        // NOTE(review): count never exceeds 151 inside this loop when the final assertion
        // (152 iterations) holds, so this check appears never to run - confirm intent.
        if (count > 151) { // last seq should be length 1
            assertEquals(1, window.size());
        }

        count++;
    }

    assertEquals(152, count);
}
 
Example 12
Source File: RecordReaderMultiDataSetIteratorTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Idea: take CSV sequences and split "csvsequence_i.txt" into two separate inputs, keeping
 * "csvsequencelabels_i.txt" as a standard one-hot output. Each {@link MultiDataSet} produced
 * by the iterator must equal the one reloaded from its own metadata.
 */
@Test
public void testSplittingCSVSequenceMeta() throws Exception {
    //need to manually extract
    File rootDir = temporaryFolder.newFolder();
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
    }

    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    //"seq1" feeds two inputs (index ranges 0-1 and 2-2); "seq2" feeds the one-hot output
    RecordReaderMultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("seq1", featureReader).addSequenceReader("seq2", labelReader)
                    .addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build();

    //Metadata collection is switched on before iterating so mds.getExampleMetaData(...) can be used below
    srrmdsi.setCollectMetaData(true);

    int count = 0;
    while (srrmdsi.hasNext()) {
        MultiDataSet mds = srrmdsi.next();
        MultiDataSet fromMeta = srrmdsi.loadFromMetaData(mds.getExampleMetaData(RecordMetaData.class));
        assertEquals(mds, fromMeta);
        count++;
    }
    assertEquals(3, count);
}
 
Example 13
Source File: CodecReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Reads a video as a single sequence with a {@link NativeCodecRecordReader}, then checks that
 * the same sequence is returned after a reset and when reloaded from its record metadata.
 */
@Ignore
@Test
public void testNativeCodecReaderMeta() throws Exception {
    File video = new ClassPathResource("fire_lowres.mp4").getFile();

    SequenceRecordReader nativeReader = new NativeCodecRecordReader();

    Configuration codecConf = new Configuration();
    codecConf.set(CodecRecordReader.RAVEL, "true");
    codecConf.set(CodecRecordReader.START_FRAME, "160");
    codecConf.set(CodecRecordReader.TOTAL_FRAMES, "500");
    codecConf.set(CodecRecordReader.ROWS, "80");
    codecConf.set(CodecRecordReader.COLUMNS, "46");

    nativeReader.initialize(new FileSplit(video));
    nativeReader.setConf(codecConf);
    assertTrue(nativeReader.hasNext());

    List<List<Writable>> frames = nativeReader.sequenceRecord();
    assertEquals(500, frames.size()); //500 frames

    // Second pass: same data, this time together with its metadata
    nativeReader.reset();
    SequenceRecord withMeta = nativeReader.nextSequence();
    assertEquals(frames, withMeta.getSequenceRecord());

    RecordMetaData metaData = withMeta.getMetaData();
    String uri = metaData.getURI().toString();
    assertTrue(uri.endsWith("fire_lowres.mp4"));

    // Third pass: reload purely from the metadata
    SequenceRecord reloaded = nativeReader.loadSequenceFromMetaData(metaData);
    assertEquals(withMeta, reloaded);
}
 
Example 14
Source File: CodecReaderTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Reads a video as a single sequence with a {@link NativeCodecRecordReader}, then checks that
 * the same sequence is returned after a reset and when reloaded from its record metadata.
 */
@Ignore
@Test
public void testNativeCodecReaderMeta() throws Exception {
    File video = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile();

    SequenceRecordReader nativeReader = new NativeCodecRecordReader();

    Configuration codecConf = new Configuration();
    codecConf.set(CodecRecordReader.RAVEL, "true");
    codecConf.set(CodecRecordReader.START_FRAME, "160");
    codecConf.set(CodecRecordReader.TOTAL_FRAMES, "500");
    codecConf.set(CodecRecordReader.ROWS, "80");
    codecConf.set(CodecRecordReader.COLUMNS, "46");

    nativeReader.initialize(new FileSplit(video));
    nativeReader.setConf(codecConf);
    assertTrue(nativeReader.hasNext());

    List<List<Writable>> frames = nativeReader.sequenceRecord();
    assertEquals(500, frames.size()); //500 frames

    // Second pass: same data, this time together with its metadata
    nativeReader.reset();
    SequenceRecord withMeta = nativeReader.nextSequence();
    assertEquals(frames, withMeta.getSequenceRecord());

    RecordMetaData metaData = withMeta.getMetaData();
    String uri = metaData.getURI().toString();
    assertTrue(uri.endsWith("fire_lowres.mp4"));

    // Third pass: reload purely from the metadata
    SequenceRecord reloaded = nativeReader.loadSequenceFromMetaData(metaData);
    assertEquals(withMeta, reloaded);
}
 
Example 15
Source File: TestPairSequenceRecordReaderBytesFunction.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Goal: combine separate files together into a hadoop sequence file, for later parsing by a
 * SequenceRecordReader (for example, to combine input and labels data from separate files for
 * training a RNN). The pairs read back through Spark are compared with sequences loaded
 * locally from the same directory.
 */
@Test
public void test() throws Exception {
    JavaSparkContext sc = getContext();

    File videoDir = testDir.newFolder();
    new ClassPathResource("datavec-spark/video/").copyDirectory(videoDir);
    String glob = videoDir.getAbsolutePath() + "/*";

    // The same glob is used for both sides of each pair
    PathToKeyConverter keyConverter = new PathToKeyConverterFilename();
    JavaPairRDD<Text, BytesPairWritable> toWrite =
                    DataVecSparkUtil.combineFilesForSequenceFile(sc, glob, glob, keyConverter);

    Path tempDir = Files.createTempDirectory("dl4j_rrbytesPairOut");
    tempDir.toFile().deleteOnExit();
    String outPath = tempDir.toString() + "/out";
    new File(outPath).deleteOnExit();
    toWrite.saveAsNewAPIHadoopFile(outPath, Text.class, BytesPairWritable.class, SequenceFileOutputFormat.class);

    //Load back into memory:
    JavaPairRDD<Text, BytesPairWritable> fromSeq = sc.sequenceFile(outPath, Text.class, BytesPairWritable.class);

    SequenceRecordReader firstReader = getReader();
    SequenceRecordReader secondReader = getReader();
    PairSequenceRecordReaderBytesFunction pairFunction =
                    new PairSequenceRecordReaderBytesFunction(firstReader, secondReader);

    JavaRDD<Tuple2<List<List<Writable>>, List<List<Writable>>>> pairsRdd = fromSeq.map(pairFunction);
    List<Tuple2<List<List<Writable>>, List<List<Writable>>>> fromSequenceFile = pairsRdd.collect();

    //Load manually (single copy) and compare:
    InputSplit localSplit = new FileSplit(videoDir, new String[] {"mp4"}, true);
    SequenceRecordReader localReader = getReader();
    localReader.initialize(localSplit);

    List<List<List<Writable>>> local = new ArrayList<>(4);
    while (localReader.hasNext()) {
        local.add(localReader.sequenceRecord());
    }

    assertEquals(4, local.size());
    assertEquals(4, fromSequenceFile.size());

    boolean[] matched = new boolean[4];
    for (int i = 0; i < 4; i++) {
        Tuple2<List<List<Writable>>, List<List<Writable>>> pair = fromSequenceFile.get(i);
        List<List<Writable>> left = pair._1();
        List<List<Writable>> right = pair._2();
        assertEquals(left, right);

        int matchIndex = -1;
        for (int j = 0; j < 4; j++) {
            if (!left.equals(local.get(j))) {
                continue;
            }
            if (matchIndex != -1) {
                fail(); //Already found this value -> suggests this spark value equals two or more of local version? (Shouldn't happen)
            }
            matchIndex = j;
            if (matched[matchIndex]) {
                fail(); //One of the other spark values was equal to this one -> suggests duplicates in Spark list
            }
            matched[matchIndex] = true; //mark this one as seen before
        }
    }

    int matchCount = 0;
    for (boolean wasMatched : matched) {
        if (wasMatched) {
            matchCount++;
        }
    }
    assertEquals(4, matchCount); //Expect all 4 and exactly 4 pairwise matches between spark and local versions

}
 
Example 16
Source File: RecordReaderMultiDataSetIteratorTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Builds two {@link RecordReaderMultiDataSetIterator}s over the same features and the
 * "short" label files, one with ALIGN_START and one with ALIGN_END, and checks that each
 * {@link MultiDataSet} equals the one reloaded from its own metadata.
 */
@Test
public void testVariableLengthTSMeta() throws Exception {
    //need to manually extract
    File rootDir = temporaryFolder.newFolder();
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
    }

    String rootPath = rootDir.getAbsolutePath();
    String featuresPath = FilenameUtils.concat(rootPath, "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootPath, "csvsequencelabelsShort_%d.txt");

    //Readers for the ALIGN_START iterator
    SequenceRecordReader startFeatures = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader startLabels = new CSVSequenceRecordReader(1, ",");
    startFeatures.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    startLabels.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    //Readers for the ALIGN_END iterator
    SequenceRecordReader endFeatures = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader endLabels = new CSVSequenceRecordReader(1, ",");
    endFeatures.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    endLabels.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("in", startFeatures).addSequenceReader("out", startLabels).addInput("in")
                    .addOutputOneHot("out", 0, 4)
                    .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START).build();

    RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("in", endFeatures).addSequenceReader("out", endLabels).addInput("in")
                    .addOutputOneHot("out", 0, 4)
                    .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END).build();

    rrmdsiStart.setCollectMetaData(true);
    rrmdsiEnd.setCollectMetaData(true);

    int batches = 0;
    while (rrmdsiStart.hasNext()) {
        MultiDataSet mdsStart = rrmdsiStart.next();
        MultiDataSet mdsEnd = rrmdsiEnd.next();

        List<RecordMetaData> startMeta = mdsStart.getExampleMetaData(RecordMetaData.class);
        MultiDataSet mdsStartFromMeta = rrmdsiStart.loadFromMetaData(startMeta);
        List<RecordMetaData> endMeta = mdsEnd.getExampleMetaData(RecordMetaData.class);
        MultiDataSet mdsEndFromMeta = rrmdsiEnd.loadFromMetaData(endMeta);

        assertEquals(mdsStart, mdsStartFromMeta);
        assertEquals(mdsEnd, mdsEndFromMeta);

        batches++;
    }
    assertFalse(rrmdsiStart.hasNext());
    assertFalse(rrmdsiEnd.hasNext());
    assertEquals(3, batches);
}
 
Example 17
Source File: CSVMultiSequenceRecordReaderTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Checks {@link CSVMultiSequenceRecordReader} in PAD mode for three different sequence
 * separators: the expected sequences show "PAD" filling the positions where the first
 * file section has fewer values.
 */
@Test
public void testPadding() throws Exception {

    // Separator line written to the file, and the regex handed to the reader to match it
    String[] separators = {"", "---", "&"};
    String[] separatorRegexes = {"^$", "---", "&"};

    for (int i = 0; i < separators.length; i++) {
        String seqSep = separators[i];
        String seqSepRegex = separatorRegexes[i];

        File f = testDir.newFile();
        FileUtils.writeStringToFile(f, "a,b\n1\nx\n" + seqSep + "\nA\nB\nC", StandardCharsets.UTF_8);

        SequenceRecordReader reader = new CSVMultiSequenceRecordReader(seqSepRegex,
                CSVMultiSequenceRecordReader.Mode.PAD, new Text("PAD"));
        reader.initialize(new FileSplit(f));

        List<Writable> firstStep = Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x"));
        List<Writable> paddedStep = Arrays.<Writable>asList(new Text("b"), new Text("PAD"), new Text("PAD"));
        List<List<Writable>> exp0 = Arrays.asList(firstStep, paddedStep);

        List<Writable> onlyStep = Arrays.<Writable>asList(new Text("A"), new Text("B"), new Text("C"));
        List<List<Writable>> exp1 = Collections.singletonList(onlyStep);

        assertEquals(exp0, reader.sequenceRecord());
        assertEquals(exp1, reader.sequenceRecord());
        assertFalse(reader.hasNext());

        // Same expectations after a reset
        reader.reset();

        assertEquals(exp0, reader.sequenceRecord());
        assertEquals(exp1, reader.sequenceRecord());
        assertFalse(reader.hasNext());
    }
}
 
Example 18
Source File: CSVVariableSlidingWindowRecordReaderTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Iterates a {@link CSVVariableSlidingWindowRecordReader} with a stride of 2 over iris.dat and
 * spot-checks the window at selected positions against a plain {@link CSVRecordReader}.
 */
@Test
public void testCSVVariableSlidingWindowRecordReaderStride() throws Exception {
    int maxLinesPerSequence = 3;
    int stride = 2;

    SequenceRecordReader windowReader = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence, stride);
    windowReader.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));

    // Line-by-line reader used to build the expected window content
    CSVRecordReader lineReader = new CSVRecordReader();
    lineReader.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));

    int count = 0;
    while (windowReader.hasNext()) {
        List<List<Writable>> window = windowReader.sequenceRecord();

        if (count == 0) { // first seq should be length 2
            assertEquals(2, window.size());
        }
        if (count == maxLinesPerSequence - 1) {
            // Reads stride * maxLinesPerSequence lines in total; only the last group of
            // maxLinesPerSequence lines (most recent first) is kept as the expected window.
            LinkedList<List<Writable>> expected = new LinkedList<>();
            int groupsRead = 0;
            while (groupsRead < stride) {
                expected = new LinkedList<>();
                for (int line = 0; line < maxLinesPerSequence; line++) {
                    expected.addFirst(lineReader.next());
                }
                groupsRead++;
            }
            assertEquals(expected, window);
        }
        if (count == maxLinesPerSequence) {
            assertEquals(maxLinesPerSequence, window.size());
        }
        // NOTE(review): count never exceeds 75 inside this loop when the final assertion
        // (76 iterations) holds, so this check appears never to run - confirm intent.
        if (count > 151) { // last seq should be length 1
            assertEquals(1, window.size());
        }

        count++;
    }

    assertEquals(76, count);
}
 
Example 19
Source File: CSVMultiSequenceRecordReaderTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Checks {@link CSVMultiSequenceRecordReader} in EQUAL_LENGTH mode for three different
 * sequence separators, before and after a {@code reset()}.
 */
@Test
public void testEqualLength() throws Exception {

    // Separator line written to the file, and the regex handed to the reader to match it
    String[] separators = {"", "---", "&"};
    String[] separatorRegexes = {"^$", "---", "&"};

    for (int i = 0; i < separators.length; i++) {
        String seqSep = separators[i];
        String seqSepRegex = separatorRegexes[i];

        File f = testDir.newFile();
        FileUtils.writeStringToFile(f, "a,b\n1,2\nx,y\n" + seqSep + "\nA\nB\nC", StandardCharsets.UTF_8);

        SequenceRecordReader reader = new CSVMultiSequenceRecordReader(seqSepRegex,
                CSVMultiSequenceRecordReader.Mode.EQUAL_LENGTH);
        reader.initialize(new FileSplit(f));

        List<Writable> firstStep = Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x"));
        List<Writable> secondStep = Arrays.<Writable>asList(new Text("b"), new Text("2"), new Text("y"));
        List<List<Writable>> exp0 = Arrays.asList(firstStep, secondStep);

        List<Writable> onlyStep = Arrays.<Writable>asList(new Text("A"), new Text("B"), new Text("C"));
        List<List<Writable>> exp1 = Collections.singletonList(onlyStep);

        assertEquals(exp0, reader.sequenceRecord());
        assertEquals(exp1, reader.sequenceRecord());
        assertFalse(reader.hasNext());

        // Same expectations after a reset
        reader.reset();

        assertEquals(exp0, reader.sequenceRecord());
        assertEquals(exp1, reader.sequenceRecord());
        assertFalse(reader.hasNext());
    }
}
 
Example 20
Source File: TestSequenceRecordReaderBytesFunction.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Writes the video files to a Hadoop sequence file via Spark, parses them back through
 * {@link SequenceRecordReaderBytesFunction}, and checks that the result matches the sequences
 * obtained by reading the same files locally with a {@link CodecRecordReader}.
 */
@Test
public void testRecordReaderBytesFunction() throws Exception {

    //Local file path
    File videoDir = testDir.newFolder();
    new ClassPathResource("datavec-spark/video/").copyDirectory(videoDir);
    String glob = videoDir.getAbsolutePath() + "/*";

    //Load binary data from local file system, convert to a sequence file:
    //Load and convert
    JavaPairRDD<String, PortableDataStream> origData = sc.binaryFiles(glob);
    JavaPairRDD<Text, BytesWritable> filesAsBytes = origData.mapToPair(new FilesAsBytesFunction());
    //Write the sequence file:
    Path tempDir = Files.createTempDirectory("dl4j_rrbytesTest");
    tempDir.toFile().deleteOnExit();
    String outPath = tempDir.toString() + "/out";
    filesAsBytes.saveAsNewAPIHadoopFile(outPath, Text.class, BytesWritable.class, SequenceFileOutputFormat.class);

    //Load data from sequence file, parse via SequenceRecordReader:
    JavaPairRDD<Text, BytesWritable> fromSeqFile = sc.sequenceFile(outPath, Text.class, BytesWritable.class);
    SequenceRecordReader sparkReader = new CodecRecordReader();
    Configuration conf = new Configuration();
    conf.set(CodecRecordReader.RAVEL, "true");
    conf.set(CodecRecordReader.START_FRAME, "0");
    conf.set(CodecRecordReader.TOTAL_FRAMES, "25");
    conf.set(CodecRecordReader.ROWS, "64");
    conf.set(CodecRecordReader.COLUMNS, "64");
    //Independent copy of the settings for the local reader below
    Configuration confCopy = new Configuration(conf);
    sparkReader.setConf(conf);
    JavaRDD<List<List<Writable>>> dataVecData = fromSeqFile.map(new SequenceRecordReaderBytesFunction(sparkReader));

    //Next: do the same thing locally, and compare the results
    InputSplit localSplit = new FileSplit(videoDir, new String[] {"mp4"}, true);
    SequenceRecordReader localReader = new CodecRecordReader();
    localReader.initialize(localSplit);
    localReader.setConf(confCopy);

    List<List<List<Writable>>> local = new ArrayList<>(4);
    while (localReader.hasNext()) {
        local.add(localReader.sequenceRecord());
    }
    assertEquals(4, local.size());

    List<List<List<Writable>>> fromSequenceFile = dataVecData.collect();

    assertEquals(4, local.size());
    assertEquals(4, fromSequenceFile.size());

    boolean[] matched = new boolean[4];
    for (int i = 0; i < 4; i++) {
        List<List<Writable>> sparkSequence = fromSequenceFile.get(i);

        int matchIndex = -1;
        for (int j = 0; j < 4; j++) {
            if (!sparkSequence.equals(local.get(j))) {
                continue;
            }
            if (matchIndex != -1) {
                fail(); //Already found this value -> suggests this spark value equals two or more of local version? (Shouldn't happen)
            }
            matchIndex = j;
            if (matched[matchIndex]) {
                fail(); //One of the other spark values was equal to this one -> suggests duplicates in Spark list
            }
            matched[matchIndex] = true; //mark this one as seen before
        }
    }

    int matchCount = 0;
    for (boolean wasMatched : matched) {
        if (wasMatched) {
            matchCount++;
        }
    }
    assertEquals(4, matchCount); //Expect all 4 and exactly 4 pairwise matches between spark and local versions
}