Java Code Examples for org.datavec.api.records.reader.SequenceRecordReader#hasNext()

The following examples show how to use org.datavec.api.records.reader.SequenceRecordReader#hasNext(). Each example is taken from an open source project; the source file, project, and license are noted above each listing.
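Before the examples, here is a minimal sketch of the iteration pattern they all share: initialize a reader with an InputSplit, then loop on hasNext() and consume one sequence per call to sequenceRecord(). The file name data.csv is a hypothetical stand-in for real data.

import java.io.File;
import java.util.List;

import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;

public class HasNextSketch {
    public static void main(String[] args) throws Exception {
        //Skip 0 header lines, comma-delimited; data.csv is a hypothetical file
        SequenceRecordReader reader = new CSVSequenceRecordReader(0, ",");
        reader.initialize(new FileSplit(new File("data.csv")));

        //hasNext() reports whether another sequence remains; sequenceRecord()
        //returns one sequence as a list of time steps, each step a list of
        //Writable column values
        while (reader.hasNext()) {
            List<List<Writable>> sequence = reader.sequenceRecord();
            System.out.println("Read sequence with " + sequence.size() + " steps");
        }
        reader.close();
    }
}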
Example 1
Source File: CSVSequenceRecordReaderTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testCsvSeqAndNumberedFileSplit() throws Exception {
    File baseDir = tempDir.newFolder();
    //Simple sanity check unit test
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir);
    }

    //Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
    ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
    String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath();

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));

    while(featureReader.hasNext()){
        featureReader.nextSequence();
    }

}
 
Example 2
Source File: CSVSequenceRecordReaderTest.java    From DataVec with Apache License 2.0
@Test
public void testCsvSeqAndNumberedFileSplit() throws Exception {
    File baseDir = tempDir.newFolder();
    //Simple sanity check unit test
    for (int i = 0; i < 3; i++) {
        new org.nd4j.linalg.io.ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir);
    }

    //Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
    org.nd4j.linalg.io.ClassPathResource resource = new org.nd4j.linalg.io.ClassPathResource("csvsequence_0.txt");
    String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath();

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));

    while(featureReader.hasNext()){
        featureReader.nextSequence();
    }

}
 
Example 3
Source File: CSVVariableSlidingWindowRecordReaderTest.java    From DataVec with Apache License 2.0
@Test
public void testCSVVariableSlidingWindowRecordReader() throws Exception {
    int maxLinesPerSequence = 3;

    SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence);
    seqRR.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));

    CSVRecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));

    int count = 0;
    while (seqRR.hasNext()) {
        List<List<Writable>> next = seqRR.sequenceRecord();

        if(count==maxLinesPerSequence-1) {
            LinkedList<List<Writable>> expected = new LinkedList<>();
            for (int i = 0; i < maxLinesPerSequence; i++) {
                expected.addFirst(rr.next());
            }
            assertEquals(expected, next);

        }
        if(count==maxLinesPerSequence) {
            assertEquals(maxLinesPerSequence, next.size());
        }
        if(count==0) { // first seq should be length 1
            assertEquals(1, next.size());
        }
        if(count>151) { // last seq should be length 1
            assertEquals(1, next.size());
        }

        count++;
    }

    assertEquals(152, count);   //One window per row read (150) plus two shrinking windows as the reader drains = 152
}
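For intuition about the 152 expected above, the stand-alone simulation below (an assumption about the windowing arithmetic, not the DataVec implementation itself) emits one window per row as the window fills and slides over 150 rows, then two shrinking windows as the queue drains, giving 150 + 2 = 152.

import java.util.LinkedList;

public class SlidingWindowCountSketch {
    public static void main(String[] args) {
        int rows = 150, maxWindow = 3, count = 0;
        LinkedList<Integer> window = new LinkedList<>();
        for (int r = 0; r < rows; r++) {
            window.addLast(r);                  //add the new row
            if (window.size() > maxWindow)
                window.removeFirst();           //slide once the window is full
            count++;                            //one sequence emitted per row
        }
        while (window.size() > 1) {
            window.removeFirst();               //drain: emit shrinking tail windows
            count++;
        }
        System.out.println(count);              //prints 152 = 150 + (maxWindow - 1)
    }
}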
 
Example 4
Source File: CSVLineSequenceRecordReaderTest.java    From deeplearning4j with Apache License 2.0
@Test
public void test() throws Exception {

    File f = testDir.newFolder();
    File source = new File(f, "temp.csv");
    String str = "a,b,c\n1,2,3,4";
    FileUtils.writeStringToFile(source, str, StandardCharsets.UTF_8);

    SequenceRecordReader rr = new CSVLineSequenceRecordReader();
    rr.initialize(new FileSplit(source));

    List<List<Writable>> exp0 = Arrays.asList(
            Collections.<Writable>singletonList(new Text("a")),
            Collections.<Writable>singletonList(new Text("b")),
            Collections.<Writable>singletonList(new Text("c")));

    List<List<Writable>> exp1 = Arrays.asList(
            Collections.<Writable>singletonList(new Text("1")),
            Collections.<Writable>singletonList(new Text("2")),
            Collections.<Writable>singletonList(new Text("3")),
            Collections.<Writable>singletonList(new Text("4")));

    for( int i=0; i<3; i++ ) {
        int count = 0;
        while (rr.hasNext()) {
            List<List<Writable>> next = rr.sequenceRecord();
            if (count++ == 0) {
                assertEquals(exp0, next);
            } else {
                assertEquals(exp1, next);
            }
        }

        assertEquals(2, count);

        rr.reset();
    }
}
 
Example 5
Source File: RegexRecordReaderTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testRegexSequenceRecordReaderMeta() throws Exception {
    String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";

    ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/");
    File f = testDir.newFolder();
    cpr.copyDirectory(f);
    String path = new File(f, "logtestfile%d.txt").getAbsolutePath();

    InputSplit is = new NumberedFileInputSplit(path, 0, 1);

    SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
    rr.initialize(is);

    List<List<List<Writable>>> out = new ArrayList<>();
    while (rr.hasNext()) {
        out.add(rr.sequenceRecord());
    }

    assertEquals(2, out.size());
    List<List<List<Writable>>> out2 = new ArrayList<>();
    List<SequenceRecord> out3 = new ArrayList<>();
    List<RecordMetaData> meta = new ArrayList<>();
    rr.reset();
    while (rr.hasNext()) {
        SequenceRecord seqr = rr.nextSequence();
        out2.add(seqr.getSequenceRecord());
        out3.add(seqr);
        meta.add(seqr.getMetaData());
    }

    List<SequenceRecord> fromMeta = rr.loadSequenceFromMetaData(meta);

    assertEquals(out, out2);
    assertEquals(out3, fromMeta);
}
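To make the regex concrete, this sketch splits a made-up log line (an assumed sample, not taken from the actual test data) into the four capture groups that RegexSequenceRecordReader turns into the columns of one time step:

import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class RegexGroupSketch {
    public static void main(String[] args) {
        String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
        //Made-up line in the expected format: timestamp, id, level, message
        String line = "2016-01-01 23:59:59.000 1 DEBUG First entry message!";

        Matcher m = Pattern.compile(regex).matcher(line);
        if (m.matches()) {
            //Groups 1..4 become the four Writable columns of one time step
            for (int g = 1; g <= m.groupCount(); g++) {
                System.out.println("Column " + g + ": " + m.group(g));
            }
        }
    }
}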
 
Example 6
Source File: CSVNLinesSequenceRecordReaderTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testCSVNlinesSequenceRecordReaderMetaData() throws Exception {
    int nLinesPerSequence = 10;

    SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
    seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));

    CSVRecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));

    List<List<List<Writable>>> out = new ArrayList<>();
    while (seqRR.hasNext()) {
        List<List<Writable>> next = seqRR.sequenceRecord();
        out.add(next);
    }

    seqRR.reset();
    List<List<List<Writable>>> out2 = new ArrayList<>();
    List<SequenceRecord> out3 = new ArrayList<>();
    List<RecordMetaData> meta = new ArrayList<>();
    while (seqRR.hasNext()) {
        SequenceRecord seq = seqRR.nextSequence();
        out2.add(seq.getSequenceRecord());
        meta.add(seq.getMetaData());
        out3.add(seq);
    }

    assertEquals(out, out2);

    List<SequenceRecord> out4 = seqRR.loadSequenceFromMetaData(meta);
    assertEquals(out3, out4);
}
 
Example 7
Source File: CSVNLinesSequenceRecordReaderTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testCSVNLinesSequenceRecordReader() throws Exception {
    int nLinesPerSequence = 10;

    SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
    seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));

    CSVRecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));

    int count = 0;
    while (seqRR.hasNext()) {
        List<List<Writable>> next = seqRR.sequenceRecord();

        List<List<Writable>> expected = new ArrayList<>();
        for (int i = 0; i < nLinesPerSequence; i++) {
            expected.add(rr.next());
        }

        assertEquals(10, next.size());
        assertEquals(expected, next);

        count++;
    }

    assertEquals(150 / nLinesPerSequence, count);
}
 
Example 8
Source File: CSVVariableSlidingWindowRecordReaderTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testCSVVariableSlidingWindowRecordReader() throws Exception {
    int maxLinesPerSequence = 3;

    SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence);
    seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));

    CSVRecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));

    int count = 0;
    while (seqRR.hasNext()) {
        List<List<Writable>> next = seqRR.sequenceRecord();

        if(count==maxLinesPerSequence-1) {
            LinkedList<List<Writable>> expected = new LinkedList<>();
            for (int i = 0; i < maxLinesPerSequence; i++) {
                expected.addFirst(rr.next());
            }
            assertEquals(expected, next);

        }
        if(count==maxLinesPerSequence) {
            assertEquals(maxLinesPerSequence, next.size());
        }
        if(count==0) { // first seq should be length 1
            assertEquals(1, next.size());
        }
        if(count>151) { // last seq should be length 1
            assertEquals(1, next.size());
        }

        count++;
    }

    assertEquals(152, count);   //One window per row read (150) plus two shrinking windows as the reader drains = 152
}
 
Example 9
Source File: RecordReaderConverter.java    From deeplearning4j with Apache License 2.0
/**
 * Write all sequences from the specified sequence record reader to the specified sequence record writer.
 * Closes the sequence record writer on completion, if closeOnCompletion is true.
 *
 * @param reader Sequence record reader (source of data)
 * @param writer Sequence record writer (location to write data)
 * @param closeOnCompletion if true: close the record writer once complete, via {@link SequenceRecordWriter#close()}
 * @throws IOException If underlying reader/writer throws an exception
 */
public static void convert(SequenceRecordReader reader, SequenceRecordWriter writer, boolean closeOnCompletion) throws IOException {

    if(!reader.hasNext()){
        throw new UnsupportedOperationException("Cannot convert SequenceRecordReader: reader has no next element");
    }

    while(reader.hasNext()){
        writer.write(reader.sequenceRecord());
    }

    if(closeOnCompletion){
        writer.close();
    }
}
 
Example 10
Source File: AnalyzeLocal.java    From deeplearning4j with Apache License 2.0
/**
 * Get the set of unique values from the specified column of a sequence
 *
 * @param columnName      Name of the column to get unique values from
 * @param schema          Data schema
 * @param sequenceData    Sequence data to get unique values from
 * @return Set of unique values found in the specified column
 */
public static Set<Writable> getUniqueSequence(String columnName, Schema schema,
                                               SequenceRecordReader sequenceData) {
    int colIdx = schema.getIndexOfColumn(columnName);
    Set<Writable> unique = new HashSet<>();
    while(sequenceData.hasNext()){
        List<List<Writable>> next = sequenceData.sequenceRecord();
        for(List<Writable> step : next){
            unique.add(step.get(colIdx));
        }
    }
    return unique;
}
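A usage sketch for the helper above, with hedged details: the column names, the file sequences.csv, and the AnalyzeLocal import path (assumed from the datavec-local module) are all illustrative assumptions.

import java.io.File;
import java.util.Set;

import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.AnalyzeLocal;

public class UniqueSequenceSketch {
    public static void main(String[] args) throws Exception {
        //Hypothetical schema: four double features plus an integer label column
        Schema schema = new Schema.Builder()
                .addColumnsDouble("f0", "f1", "f2", "f3")
                .addColumnInteger("label")
                .build();

        SequenceRecordReader reader = new CSVSequenceRecordReader(0, ",");
        reader.initialize(new FileSplit(new File("sequences.csv")));    //hypothetical file

        //Collect the distinct values of the "label" column across all time steps
        Set<Writable> uniqueLabels = AnalyzeLocal.getUniqueSequence("label", schema, reader);
        System.out.println("Unique labels: " + uniqueLabels);
    }
}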
 
Example 11
Source File: CSVNLinesSequenceRecordReaderTest.java    From DataVec with Apache License 2.0
@Test
public void testCSVNlinesSequenceRecordReaderMetaData() throws Exception {
    int nLinesPerSequence = 10;

    SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
    seqRR.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));

    CSVRecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));

    List<List<List<Writable>>> out = new ArrayList<>();
    while (seqRR.hasNext()) {
        List<List<Writable>> next = seqRR.sequenceRecord();
        out.add(next);
    }

    seqRR.reset();
    List<List<List<Writable>>> out2 = new ArrayList<>();
    List<SequenceRecord> out3 = new ArrayList<>();
    List<RecordMetaData> meta = new ArrayList<>();
    while (seqRR.hasNext()) {
        SequenceRecord seq = seqRR.nextSequence();
        out2.add(seq.getSequenceRecord());
        meta.add(seq.getMetaData());
        out3.add(seq);
    }

    assertEquals(out, out2);

    List<SequenceRecord> out4 = seqRR.loadSequenceFromMetaData(meta);
    assertEquals(out3, out4);
}
 
Example 12
Source File: CSVNLinesSequenceRecordReaderTest.java    From DataVec with Apache License 2.0
@Test
public void testCSVNLinesSequenceRecordReader() throws Exception {
    int nLinesPerSequence = 10;

    SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
    seqRR.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));

    CSVRecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));

    int count = 0;
    while (seqRR.hasNext()) {
        List<List<Writable>> next = seqRR.sequenceRecord();

        List<List<Writable>> expected = new ArrayList<>();
        for (int i = 0; i < nLinesPerSequence; i++) {
            expected.add(rr.next());
        }

        assertEquals(10, next.size());
        assertEquals(expected, next);

        count++;
    }

    assertEquals(150 / nLinesPerSequence, count);
}
 
Example 13
Source File: RecordReaderConverter.java    From DataVec with Apache License 2.0
/**
 * Write all sequences from the specified sequence record reader to the specified sequence record writer.
 * Closes the sequence record writer on completion, if closeOnCompletion is true.
 *
 * @param reader Sequence record reader (source of data)
 * @param writer Sequence record writer (location to write data)
 * @param closeOnCompletion if true: close the record writer once complete, via {@link SequenceRecordWriter#close()}
 * @throws IOException If underlying reader/writer throws an exception
 */
public static void convert(SequenceRecordReader reader, SequenceRecordWriter writer, boolean closeOnCompletion) throws IOException {

    if(!reader.hasNext()){
        throw new UnsupportedOperationException("Cannot convert SequenceRecordReader: reader has no next element");
    }

    while(reader.hasNext()){
        writer.write(reader.sequenceRecord());
    }

    if(closeOnCompletion){
        writer.close();
    }
}
 
Example 14
Source File: TestPairSequenceRecordReaderBytesFunction.java    From deeplearning4j with Apache License 2.0
@Test
public void test() throws Exception {
    //Goal: combine separate files together into a hadoop sequence file, for later parsing by a SequenceRecordReader
    //For example: use to combine input and labels data from separate files for training a RNN
    JavaSparkContext sc = getContext();

    File f = testDir.newFolder();
    new ClassPathResource("datavec-spark/video/").copyDirectory(f);
    String path = f.getAbsolutePath() + "/*";

    PathToKeyConverter pathConverter = new PathToKeyConverterFilename();
    JavaPairRDD<Text, BytesPairWritable> toWrite =
                    DataVecSparkUtil.combineFilesForSequenceFile(sc, path, path, pathConverter);

    Path p = Files.createTempDirectory("dl4j_rrbytesPairOut");
    p.toFile().deleteOnExit();
    String outPath = p.toString() + "/out";
    new File(outPath).deleteOnExit();
    toWrite.saveAsNewAPIHadoopFile(outPath, Text.class, BytesPairWritable.class, SequenceFileOutputFormat.class);

    //Load back into memory:
    JavaPairRDD<Text, BytesPairWritable> fromSeq = sc.sequenceFile(outPath, Text.class, BytesPairWritable.class);

    SequenceRecordReader srr1 = getReader();
    SequenceRecordReader srr2 = getReader();
    PairSequenceRecordReaderBytesFunction psrbf = new PairSequenceRecordReaderBytesFunction(srr1, srr2);

    JavaRDD<Tuple2<List<List<Writable>>, List<List<Writable>>>> writables = fromSeq.map(psrbf);
    List<Tuple2<List<List<Writable>>, List<List<Writable>>>> fromSequenceFile = writables.collect();

    //Load manually (single copy) and compare:
    InputSplit is = new FileSplit(f, new String[] {"mp4"}, true);
    SequenceRecordReader srr = getReader();
    srr.initialize(is);

    List<List<List<Writable>>> list = new ArrayList<>(4);
    while (srr.hasNext()) {
        list.add(srr.sequenceRecord());
    }

    assertEquals(4, list.size());
    assertEquals(4, fromSequenceFile.size());

    boolean[] found = new boolean[4];
    for (int i = 0; i < 4; i++) {
        int foundIndex = -1;
        Tuple2<List<List<Writable>>, List<List<Writable>>> tuple2 = fromSequenceFile.get(i);
        List<List<Writable>> seq1 = tuple2._1();
        List<List<Writable>> seq2 = tuple2._2();
        assertEquals(seq1, seq2);

        for (int j = 0; j < 4; j++) {
            if (seq1.equals(list.get(j))) {
                if (foundIndex != -1)
                    fail(); //Already found this value -> suggests this spark value equals two or more of local version? (Shouldn't happen)
                foundIndex = j;
                if (found[foundIndex])
                    fail(); //One of the other spark values was equal to this one -> suggests duplicates in Spark list
                found[foundIndex] = true; //mark this one as seen before
            }
        }
    }
    int count = 0;
    for (boolean b : found)
        if (b)
            count++;
    assertEquals(4, count); //Expect all 4 and exactly 4 pairwise matches between spark and local versions

}
 
Example 15
Source File: TestSequenceRecordReaderBytesFunction.java    From DataVec with Apache License 2.0
@Test
public void testRecordReaderBytesFunction() throws Exception {

    //Local file path
    ClassPathResource cpr = new ClassPathResource("/video/shapes_0.mp4");
    String path = cpr.getFile().getAbsolutePath();
    String folder = path.substring(0, path.length() - 12);
    path = folder + "*";

    //Load binary data from local file system, convert to a sequence file:
    //Load and convert
    JavaPairRDD<String, PortableDataStream> origData = sc.binaryFiles(path);
    JavaPairRDD<Text, BytesWritable> filesAsBytes = origData.mapToPair(new FilesAsBytesFunction());
    //Write the sequence file:
    Path p = Files.createTempDirectory("dl4j_rrbytesTest");
    p.toFile().deleteOnExit();
    String outPath = p.toString() + "/out";
    filesAsBytes.saveAsNewAPIHadoopFile(outPath, Text.class, BytesWritable.class, SequenceFileOutputFormat.class);

    //Load data from sequence file, parse via SequenceRecordReader:
    JavaPairRDD<Text, BytesWritable> fromSeqFile = sc.sequenceFile(outPath, Text.class, BytesWritable.class);
    SequenceRecordReader seqRR = new CodecRecordReader();
    Configuration conf = new Configuration();
    conf.set(CodecRecordReader.RAVEL, "true");
    conf.set(CodecRecordReader.START_FRAME, "0");
    conf.set(CodecRecordReader.TOTAL_FRAMES, "25");
    conf.set(CodecRecordReader.ROWS, "64");
    conf.set(CodecRecordReader.COLUMNS, "64");
    Configuration confCopy = new Configuration(conf);
    seqRR.setConf(conf);
    JavaRDD<List<List<Writable>>> dataVecData = fromSeqFile.map(new SequenceRecordReaderBytesFunction(seqRR));



    //Next: do the same thing locally, and compare the results
    InputSplit is = new FileSplit(new File(folder), new String[] {"mp4"}, true);
    SequenceRecordReader srr = new CodecRecordReader();
    srr.initialize(is);
    srr.setConf(confCopy);

    List<List<List<Writable>>> list = new ArrayList<>(4);
    while (srr.hasNext()) {
        list.add(srr.sequenceRecord());
    }
    assertEquals(4, list.size());

    List<List<List<Writable>>> fromSequenceFile = dataVecData.collect();

    assertEquals(4, list.size());
    assertEquals(4, fromSequenceFile.size());

    boolean[] found = new boolean[4];
    for (int i = 0; i < 4; i++) {
        int foundIndex = -1;
        List<List<Writable>> collection = fromSequenceFile.get(i);
        for (int j = 0; j < 4; j++) {
            if (collection.equals(list.get(j))) {
                if (foundIndex != -1)
                    fail(); //Already found this value -> suggests this spark value equals two or more of local version? (Shouldn't happen)
                foundIndex = j;
                if (found[foundIndex])
                    fail(); //One of the other spark values was equal to this one -> suggests duplicates in Spark list
                found[foundIndex] = true; //mark this one as seen before
            }
        }
    }
    int count = 0;
    for (boolean b : found)
        if (b)
            count++;
    assertEquals(4, count); //Expect all 4 and exactly 4 pairwise matches between spark and local versions
}
 
Example 16
Source File: TestPairSequenceRecordReaderBytesFunction.java    From DataVec with Apache License 2.0
@Test
public void test() throws Exception {
    //Goal: combine separate files together into a hadoop sequence file, for later parsing by a SequenceRecordReader
    //For example: use to combine input and labels data from separate files for training a RNN
    JavaSparkContext sc = getContext();

    ClassPathResource cpr = new ClassPathResource("/video/shapes_0.mp4");
    String path = cpr.getFile().getAbsolutePath();
    String folder = path.substring(0, path.length() - 12);
    path = folder + "*";

    PathToKeyConverter pathConverter = new PathToKeyConverterFilename();
    JavaPairRDD<Text, BytesPairWritable> toWrite =
                    DataVecSparkUtil.combineFilesForSequenceFile(sc, path, path, pathConverter);

    Path p = Files.createTempDirectory("dl4j_rrbytesPairOut");
    p.toFile().deleteOnExit();
    String outPath = p.toString() + "/out";
    new File(outPath).deleteOnExit();
    toWrite.saveAsNewAPIHadoopFile(outPath, Text.class, BytesPairWritable.class, SequenceFileOutputFormat.class);

    //Load back into memory:
    JavaPairRDD<Text, BytesPairWritable> fromSeq = sc.sequenceFile(outPath, Text.class, BytesPairWritable.class);

    SequenceRecordReader srr1 = getReader();
    SequenceRecordReader srr2 = getReader();
    PairSequenceRecordReaderBytesFunction psrbf = new PairSequenceRecordReaderBytesFunction(srr1, srr2);

    JavaRDD<Tuple2<List<List<Writable>>, List<List<Writable>>>> writables = fromSeq.map(psrbf);
    List<Tuple2<List<List<Writable>>, List<List<Writable>>>> fromSequenceFile = writables.collect();

    //Load manually (single copy) and compare:
    InputSplit is = new FileSplit(new File(folder), new String[] {"mp4"}, true);
    SequenceRecordReader srr = getReader();
    srr.initialize(is);

    List<List<List<Writable>>> list = new ArrayList<>(4);
    while (srr.hasNext()) {
        list.add(srr.sequenceRecord());
    }

    assertEquals(4, list.size());
    assertEquals(4, fromSequenceFile.size());

    boolean[] found = new boolean[4];
    for (int i = 0; i < 4; i++) {
        int foundIndex = -1;
        Tuple2<List<List<Writable>>, List<List<Writable>>> tuple2 = fromSequenceFile.get(i);
        List<List<Writable>> seq1 = tuple2._1();
        List<List<Writable>> seq2 = tuple2._2();
        assertEquals(seq1, seq2);

        for (int j = 0; j < 4; j++) {
            if (seq1.equals(list.get(j))) {
                if (foundIndex != -1)
                    fail(); //Already found this value -> suggests this spark value equals two or more of local version? (Shouldn't happen)
                foundIndex = j;
                if (found[foundIndex])
                    fail(); //One of the other spark values was equal to this one -> suggests duplicates in Spark list
                found[foundIndex] = true; //mark this one as seen before
            }
        }
    }
    int count = 0;
    for (boolean b : found)
        if (b)
            count++;
    assertEquals(4, count); //Expect all 4 and exactly 4 pairwise matches between spark and local versions

}
 
Example 17
Source File: TestSequenceRecordReaderBytesFunction.java    From deeplearning4j with Apache License 2.0
@Test
public void testRecordReaderBytesFunction() throws Exception {

    //Local file path
    File f = testDir.newFolder();
    new ClassPathResource("datavec-spark/video/").copyDirectory(f);
    String path = f.getAbsolutePath() + "/*";

    //Load binary data from local file system, convert to a sequence file:
    //Load and convert
    JavaPairRDD<String, PortableDataStream> origData = sc.binaryFiles(path);
    JavaPairRDD<Text, BytesWritable> filesAsBytes = origData.mapToPair(new FilesAsBytesFunction());
    //Write the sequence file:
    Path p = Files.createTempDirectory("dl4j_rrbytesTest");
    p.toFile().deleteOnExit();
    String outPath = p.toString() + "/out";
    filesAsBytes.saveAsNewAPIHadoopFile(outPath, Text.class, BytesWritable.class, SequenceFileOutputFormat.class);

    //Load data from sequence file, parse via SequenceRecordReader:
    JavaPairRDD<Text, BytesWritable> fromSeqFile = sc.sequenceFile(outPath, Text.class, BytesWritable.class);
    SequenceRecordReader seqRR = new CodecRecordReader();
    Configuration conf = new Configuration();
    conf.set(CodecRecordReader.RAVEL, "true");
    conf.set(CodecRecordReader.START_FRAME, "0");
    conf.set(CodecRecordReader.TOTAL_FRAMES, "25");
    conf.set(CodecRecordReader.ROWS, "64");
    conf.set(CodecRecordReader.COLUMNS, "64");
    Configuration confCopy = new Configuration(conf);
    seqRR.setConf(conf);
    JavaRDD<List<List<Writable>>> dataVecData = fromSeqFile.map(new SequenceRecordReaderBytesFunction(seqRR));



    //Next: do the same thing locally, and compare the results
    InputSplit is = new FileSplit(f, new String[] {"mp4"}, true);
    SequenceRecordReader srr = new CodecRecordReader();
    srr.initialize(is);
    srr.setConf(confCopy);

    List<List<List<Writable>>> list = new ArrayList<>(4);
    while (srr.hasNext()) {
        list.add(srr.sequenceRecord());
    }
    assertEquals(4, list.size());

    List<List<List<Writable>>> fromSequenceFile = dataVecData.collect();

    assertEquals(4, list.size());
    assertEquals(4, fromSequenceFile.size());

    boolean[] found = new boolean[4];
    for (int i = 0; i < 4; i++) {
        int foundIndex = -1;
        List<List<Writable>> collection = fromSequenceFile.get(i);
        for (int j = 0; j < 4; j++) {
            if (collection.equals(list.get(j))) {
                if (foundIndex != -1)
                    fail(); //Already found this value -> suggests this spark value equals two or more of local version? (Shouldn't happen)
                foundIndex = j;
                if (found[foundIndex])
                    fail(); //One of the other spark values was equal to this one -> suggests duplicates in Spark list
                found[foundIndex] = true; //mark this one as seen before
            }
        }
    }
    int count = 0;
    for (boolean b : found)
        if (b)
            count++;
    assertEquals(4, count); //Expect all 4 and exactly 4 pairwise matches between spark and local versions
}
 
Example 18
Source File: CSVVariableSlidingWindowRecordReaderTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testCSVVariableSlidingWindowRecordReaderStride() throws Exception {
    int maxLinesPerSequence = 3;
    int stride = 2;

    SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence, stride);
    seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));

    CSVRecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));

    int count = 0;
    while (seqRR.hasNext()) {
        List<List<Writable>> next = seqRR.sequenceRecord();

        if(count==maxLinesPerSequence-1) {
            LinkedList<List<Writable>> expected = new LinkedList<>();
            for(int s = 0; s < stride; s++) {
                expected = new LinkedList<>();
                for (int i = 0; i < maxLinesPerSequence; i++) {
                    expected.addFirst(rr.next());
                }
            }
            assertEquals(expected, next);

        }
        if(count==maxLinesPerSequence) {
            assertEquals(maxLinesPerSequence, next.size());
        }
        if(count==0) { // first seq should be length 2
            assertEquals(2, next.size());
        }
        if(count>151) { // last seq should be length 1
            assertEquals(1, next.size());
        }

        count++;
    }

    assertEquals(76, count);    //With stride 2, every second window is emitted: 152 / 2 = 76
}
 
Example 19
Source File: TestCollectionRecordReaders.java    From deeplearning4j with Apache License 2.0
@Test
public void testCollectionSequenceRecordReader() throws Exception {

    List<List<List<Writable>>> listOfSequences = new ArrayList<>();

    List<List<Writable>> sequence1 = new ArrayList<>();
    sequence1.add(Arrays.asList((Writable) new IntWritable(0), new IntWritable(1)));
    sequence1.add(Arrays.asList((Writable) new IntWritable(2), new IntWritable(3)));
    listOfSequences.add(sequence1);

    List<List<Writable>> sequence2 = new ArrayList<>();
    sequence2.add(Arrays.asList((Writable) new IntWritable(4), new IntWritable(5)));
    sequence2.add(Arrays.asList((Writable) new IntWritable(6), new IntWritable(7)));
    listOfSequences.add(sequence2);

    SequenceRecordReader seqRR = new CollectionSequenceRecordReader(listOfSequences);
    assertTrue(seqRR.hasNext());

    assertEquals(sequence1, seqRR.sequenceRecord());
    assertEquals(sequence2, seqRR.sequenceRecord());
    assertFalse(seqRR.hasNext());

    seqRR.reset();
    assertEquals(sequence1, seqRR.sequenceRecord());
    assertEquals(sequence2, seqRR.sequenceRecord());
    assertFalse(seqRR.hasNext());

    //Test metadata:
    seqRR.reset();
    List<List<List<Writable>>> out2 = new ArrayList<>();
    List<SequenceRecord> seq = new ArrayList<>();
    List<RecordMetaData> meta = new ArrayList<>();

    while (seqRR.hasNext()) {
        SequenceRecord r = seqRR.nextSequence();
        out2.add(r.getSequenceRecord());
        seq.add(r);
        meta.add(r.getMetaData());
    }

    assertEquals(listOfSequences, out2);

    List<SequenceRecord> fromMeta = seqRR.loadSequenceFromMetaData(meta);
    assertEquals(seq, fromMeta);
}
 
Example 20
Source File: CSVVariableSlidingWindowRecordReaderTest.java    From DataVec with Apache License 2.0
@Test
public void testCSVVariableSlidingWindowRecordReaderStride() throws Exception {
    int maxLinesPerSequence = 3;
    int stride = 2;

    SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence, stride);
    seqRR.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));

    CSVRecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));

    int count = 0;
    while (seqRR.hasNext()) {
        List<List<Writable>> next = seqRR.sequenceRecord();

        if(count==maxLinesPerSequence-1) {
            LinkedList<List<Writable>> expected = new LinkedList<>();
            for(int s = 0; s < stride; s++) {
                expected = new LinkedList<>();
                for (int i = 0; i < maxLinesPerSequence; i++) {
                    expected.addFirst(rr.next());
                }
            }
            assertEquals(expected, next);

        }
        if(count==maxLinesPerSequence) {
            assertEquals(maxLinesPerSequence, next.size());
        }
        if(count==0) { // first seq should be length 2
            assertEquals(2, next.size());
        }
        if(count>151) { // last seq should be length 1
            assertEquals(1, next.size());
        }

        count++;
    }

    assertEquals(76, count);    //With stride 2, every second window is emitted: 152 / 2 = 76
}