org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator Java Examples

The following examples show how to use org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: TestRecordReaders.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Checks that a single-reader {@link SequenceRecordReaderDataSetIterator} fails with a
 * descriptive error when a label value cannot be one-hot encoded.
 * <p>
 * Two sequences of two time steps are fed in; each step holds one double followed by one
 * int. The final step of the second sequence carries the int 2, while the iterator is
 * created with the arguments (2, 2, 1) -- presumably minibatch size 2, 2 label classes and
 * label column 1, so 2 would be out of range. The first call to {@code next()} must throw
 * an exception whose message contains "to one-hot".
 */
@Test
public void testClassIndexOutsideOfRangeRRMDSI() {
    // First sequence: labels 0 and 1.
    Collection<Collection<Writable>> firstSequence = new ArrayList<>();
    firstSequence.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0)));
    firstSequence.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(1)));

    // Second sequence: labels 0 and 2 -- the 2 is the value expected to trigger the failure.
    Collection<Collection<Writable>> secondSequence = new ArrayList<>();
    secondSequence.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0)));
    secondSequence.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(2)));

    Collection<Collection<Collection<Writable>>> sequences = new ArrayList<>();
    sequences.add(firstSequence);
    sequences.add(secondSequence);

    CollectionSequenceRecordReader reader = new CollectionSequenceRecordReader(sequences);
    DataSetIterator iterator = new SequenceRecordReaderDataSetIterator(reader, 2, 2, 1);

    try {
        iterator.next();
        fail("Expected exception");
    } catch (Exception expected) {
        String message = expected.getMessage();
        assertTrue(message, message.contains("to one-hot"));
    }
}
 
Example #2
Source File: RNNTestCases.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Builds the raw (not normalized) training iterator over the UCI sequence data set.
 * <p>
 * Feature and label CSV directories are copied from the classpath into two new temp
 * directories, files {@code 0.csv} .. {@code 449.csv} of each are read as sequences, and the
 * two readers are paired with {@code ALIGN_END} alignment.
 *
 * @return the paired feature/label data wrapped as a {@link MultiDataSetIterator}
 * @throws Exception if copying the resources or initializing the readers fails
 */
protected MultiDataSetIterator getTrainingDataUnnormalized() throws Exception {
    final int batchSize = 10;
    final int classCount = 6;

    // Stage the classpath resources on disk so they can be addressed by numbered file path.
    File featureDir = Files.createTempDir();
    File labelDir = Files.createTempDir();
    new ClassPathResource("dl4j-integration-tests/data/uci_seq/train/features/").copyDirectory(featureDir);
    new ClassPathResource("dl4j-integration-tests/data/uci_seq/train/labels/").copyDirectory(labelDir);

    String featurePattern = featureDir.getAbsolutePath() + "/%d.csv";
    String labelPattern = labelDir.getAbsolutePath() + "/%d.csv";

    SequenceRecordReader featureReader = new CSVSequenceRecordReader();
    featureReader.initialize(new NumberedFileInputSplit(featurePattern, 0, 449));
    SequenceRecordReader labelReader = new CSVSequenceRecordReader();
    labelReader.initialize(new NumberedFileInputSplit(labelPattern, 0, 449));

    DataSetIterator paired = new SequenceRecordReaderDataSetIterator(
            featureReader, labelReader, batchSize, classCount,
            false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    return new MultiDataSetIteratorAdapter(paired);
}
 
Example #3
Source File: SameDiffRNNTestCases.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Builds the raw (not normalized) training iterator over the UCI sequence data set.
 * <p>
 * Feature and label CSV directories are copied via {@code Resources.copyDirectory} into two
 * new temp directories, files {@code 0.csv} .. {@code 449.csv} of each are read as
 * sequences, and the two readers are paired with {@code ALIGN_END} alignment.
 *
 * @return the paired feature/label data wrapped as a {@link MultiDataSetIterator}
 * @throws Exception if copying the resources or initializing the readers fails
 */
protected MultiDataSetIterator getTrainingDataUnnormalized() throws Exception {
    final int batchSize = 10;
    final int classCount = 6;

    // Stage the resources on disk so they can be addressed by numbered file path.
    File featureDir = Files.createTempDir();
    File labelDir = Files.createTempDir();
    Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/features/", featureDir);
    Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/labels/", labelDir);

    String featurePattern = featureDir.getAbsolutePath() + "/%d.csv";
    String labelPattern = labelDir.getAbsolutePath() + "/%d.csv";

    SequenceRecordReader featureReader = new CSVSequenceRecordReader();
    featureReader.initialize(new NumberedFileInputSplit(featurePattern, 0, 449));
    SequenceRecordReader labelReader = new CSVSequenceRecordReader();
    labelReader.initialize(new NumberedFileInputSplit(labelPattern, 0, 449));

    DataSetIterator paired = new SequenceRecordReaderDataSetIterator(
            featureReader, labelReader, batchSize, classCount,
            false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    return new MultiDataSetIteratorAdapter(paired);
}
 
Example #4
Source File: TestRecordReaders.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that a two-reader {@link SequenceRecordReaderDataSetIterator} (separate feature and
 * label readers) fails with a descriptive error when a label value cannot be one-hot encoded.
 * <p>
 * The feature reader supplies two sequences of two single-double steps. The label reader
 * supplies matching sequences of single-int steps, the last of which is 2. The iterator is
 * created with the trailing arguments (2, 2) -- presumably minibatch size 2 and 2 label
 * classes, so 2 would be out of range. The first call to {@code next()} must throw an
 * exception whose message contains "to one-hot".
 */
@Test
public void testClassIndexOutsideOfRangeRRMDSI_MultipleReaders() {
    // Features: two sequences, two time steps each, one double per step.
    Collection<Collection<Writable>> featureSeqA = new ArrayList<>();
    featureSeqA.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    featureSeqA.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));

    Collection<Collection<Writable>> featureSeqB = new ArrayList<>();
    featureSeqB.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    featureSeqB.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));

    Collection<Collection<Collection<Writable>>> features = new ArrayList<>();
    features.add(featureSeqA);
    features.add(featureSeqB);

    // Labels: same layout, one int per step; the trailing 2 is the offending value.
    Collection<Collection<Writable>> labelSeqA = new ArrayList<>();
    labelSeqA.add(Arrays.<Writable>asList(new IntWritable(0)));
    labelSeqA.add(Arrays.<Writable>asList(new IntWritable(1)));

    Collection<Collection<Writable>> labelSeqB = new ArrayList<>();
    labelSeqB.add(Arrays.<Writable>asList(new IntWritable(0)));
    labelSeqB.add(Arrays.<Writable>asList(new IntWritable(2)));

    Collection<Collection<Collection<Writable>>> labels = new ArrayList<>();
    labels.add(labelSeqA);
    labels.add(labelSeqB);

    CollectionSequenceRecordReader featureReader = new CollectionSequenceRecordReader(features);
    CollectionSequenceRecordReader labelReader = new CollectionSequenceRecordReader(labels);
    DataSetIterator iterator = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 2, 2);

    try {
        iterator.next();
        fail("Expected exception");
    } catch (Exception expected) {
        String message = expected.getMessage();
        assertTrue(message, message.contains("to one-hot"));
    }
}
 
Example #5
Source File: EvalTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Smoke test: evaluating a truncated-BPTT network on a sequence whose feature and label
 * lengths differ must complete without throwing.
 * <p>
 * The single example has 30 feature steps of three floats but only one label step, paired
 * with {@code ALIGN_END}. The network is configured with tBPTT forward/backward length 10,
 * shorter than the 30 feature steps. There are no assertions beyond
 * {@code net.evaluate(...)} returning normally.
 */
@Test
public void testEvalSplitting2(){
    // 30 references to the same 3-value time step.
    List<Writable> timeStep = Arrays.<Writable>asList(new FloatWritable(0), new FloatWritable(0), new FloatWritable(0));
    List<List<Writable>> featureSequence = new ArrayList<>(Collections.nCopies(30, timeStep));

    // A single label step holding one float.
    List<List<Writable>> labelSequence =
            Collections.singletonList(Collections.<Writable>singletonList(new FloatWritable(0)));

    SequenceRecordReader featureReader = new CollectionSequenceRecordReader(Collections.singletonList(featureSequence));
    SequenceRecordReader labelReader = new CollectionSequenceRecordReader(Collections.singletonList(labelSequence));

    DataSetIterator evalData = new SequenceRecordReaderDataSetIterator(
            featureReader, labelReader, 1, -1, true,
            SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    LSTM recurrentLayer = new LSTM.Builder()
            .activation(Activation.TANH)
            .nIn(3)
            .nOut(3)
            .build();
    RnnOutputLayer outputLayer = new RnnOutputLayer.Builder()
            .activation(Activation.SIGMOID)
            .lossFunction(LossFunctions.LossFunction.XENT)
            .nIn(3)
            .nOut(1)
            .build();

    MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
            .seed(123)
            .list()
            .layer(0, recurrentLayer)
            .layer(1, outputLayer)
            .backpropType(BackpropType.TruncatedBPTT)
            .tBPTTForwardLength(10)
            .tBPTTBackwardLength(10)
            .build();

    MultiLayerNetwork network = new MultiLayerNetwork(configuration);
    network.init();

    network.evaluate(evalData);
}
 
Example #6
Source File: RNNTestCases.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Override
        public MultiDataSetIterator getEvaluationTestData() throws Exception {
            // Builds the evaluation iterator over the UCI sequence *test* split: files
            // 0.csv .. 149.csv of the features and labels directories, paired with ALIGN_END.
            // A composite pre-processor then applies getNormalizer() followed by a step that
            // reduces the labels to a single slice (see below).
            final int batchSize = 10;
            final int classCount = 6;

            // Stage the classpath resources on disk so they can be addressed by numbered file path.
            File featureDir = Files.createTempDir();
            File labelDir = Files.createTempDir();
            new ClassPathResource("dl4j-integration-tests/data/uci_seq/test/features/").copyDirectory(featureDir);
            new ClassPathResource("dl4j-integration-tests/data/uci_seq/test/labels/").copyDirectory(labelDir);

            String featurePattern = featureDir.getAbsolutePath() + "/%d.csv";
            String labelPattern = labelDir.getAbsolutePath() + "/%d.csv";

            SequenceRecordReader featureReader = new CSVSequenceRecordReader();
            featureReader.initialize(new NumberedFileInputSplit(featurePattern, 0, 149));
            SequenceRecordReader labelReader = new CSVSequenceRecordReader();
            labelReader.initialize(new NumberedFileInputSplit(labelPattern, 0, 149));

            DataSetIterator paired = new SequenceRecordReaderDataSetIterator(
                    featureReader, labelReader, batchSize, classCount,
                    false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

            MultiDataSetIterator result = new MultiDataSetIteratorAdapter(paired);

            // Keep only the last index along dimension 2 of the first label array
            // (presumably the final time step) and drop the label mask.
            MultiDataSetPreProcessor keepLastSlice = mds -> {
                INDArray allLabels = mds.getLabels(0);
                INDArray lastSlice = allLabels.get(
                        NDArrayIndex.all(),
                        NDArrayIndex.all(),
                        NDArrayIndex.point(allLabels.size(2)-1));
                mds.setLabels(0, lastSlice);
                mds.setLabelsMaskArray(0, null);
            };

            result.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), keepLastSlice));

            return result;
        }
 
Example #7
Source File: SameDiffRNNTestCases.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Override
        public MultiDataSetIterator getEvaluationTestData() throws Exception {
            // Builds the evaluation iterator over the UCI sequence *test* split: files
            // 0.csv .. 149.csv of the features and labels directories, paired with ALIGN_END.
            // A composite pre-processor then applies getNormalizer() followed by a step that
            // reduces the labels to a single slice (see below).
            final int batchSize = 10;
            final int classCount = 6;

            // Stage the resources on disk so they can be addressed by numbered file path.
            File featureDir = Files.createTempDir();
            File labelDir = Files.createTempDir();
            Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/features/", featureDir);
            Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/labels/", labelDir);

            String featurePattern = featureDir.getAbsolutePath() + "/%d.csv";
            String labelPattern = labelDir.getAbsolutePath() + "/%d.csv";

            SequenceRecordReader featureReader = new CSVSequenceRecordReader();
            featureReader.initialize(new NumberedFileInputSplit(featurePattern, 0, 149));
            SequenceRecordReader labelReader = new CSVSequenceRecordReader();
            labelReader.initialize(new NumberedFileInputSplit(labelPattern, 0, 149));

            DataSetIterator paired = new SequenceRecordReaderDataSetIterator(
                    featureReader, labelReader, batchSize, classCount,
                    false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

            MultiDataSetIterator result = new MultiDataSetIteratorAdapter(paired);

            // Keep only the last index along dimension 2 of the first label array
            // (presumably the final time step) and drop the label mask.
            MultiDataSetPreProcessor keepLastSlice = mds -> {
                INDArray allLabels = mds.getLabels(0);
                INDArray lastSlice = allLabels.get(
                        NDArrayIndex.all(),
                        NDArrayIndex.all(),
                        NDArrayIndex.point(allLabels.size(2) - 1));
                mds.setLabels(0, lastSlice);
                mds.setLabelsMaskArray(0, null);
            };

            result.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), keepLastSlice));

            return result;
        }
 
Example #8
Source File: LstmTimeSeriesExample.java    From Java-Deep-Learning-Cookbook with MIT License 4 votes vote down vote up
/**
 * Trains an LSTM classifier on numbered CSV sequence files and prints ROC statistics.
 * <p>
 * Files {@code 0.csv} .. {@code 3199.csv} under {@code FEATURE_DIR} / {@code LABEL_DIR} form
 * the training set and {@code 3200.csv} .. {@code 3999.csv} the test set. Features and labels
 * are paired with {@code ALIGN_END}. The model is a one-LSTM-layer computation graph with a
 * two-way softmax output; it is fitted for a single pass and then evaluated batch by batch.
 *
 * @param args unused
 * @throws IOException           if a reader cannot be initialized; a
 *                               {@link FileNotFoundException} is thrown when either directory
 *                               constant still holds its template placeholder
 * @throws InterruptedException  if reader initialization is interrupted
 */
public static void main(String[] args) throws IOException, InterruptedException {
    // Refuse to run while a directory constant is still the unedited template placeholder.
    // The label check uses startsWith so that it matches the placeholder both with and
    // without its closing brace: the original literal lacked the '}' that the feature
    // placeholder has, so a braced "{PATH-TO-PHYSIONET-LABELS}" constant slipped through.
    if (FEATURE_DIR.equals("{PATH-TO-PHYSIONET-FEATURES}") || LABEL_DIR.startsWith("{PATH-TO-PHYSIONET-LABELS")) {
        final String hint = "Please provide proper directory path in place of: PATH-TO-PHYSIONET-FEATURES && PATH-TO-PHYSIONET-LABELS";
        System.out.println(hint);
        throw new FileNotFoundException(hint);
    }

    // Training split: files 0..3199. The feature reader is built with (1, ",") --
    // presumably "skip one header line, comma-delimited"; the label reader uses defaults.
    SequenceRecordReader trainFeaturesReader = new CSVSequenceRecordReader(1, ",");
    trainFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR + "/%d.csv", 0, 3199));
    SequenceRecordReader trainLabelsReader = new CSVSequenceRecordReader();
    trainLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR + "/%d.csv", 0, 3199));
    DataSetIterator trainDataSetIterator = new SequenceRecordReaderDataSetIterator(
            trainFeaturesReader, trainLabelsReader, 100, 2, false,
            SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // Test split: files 3200..3999, read the same way.
    SequenceRecordReader testFeaturesReader = new CSVSequenceRecordReader(1, ",");
    testFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR + "/%d.csv", 3200, 3999));
    SequenceRecordReader testLabelsReader = new CSVSequenceRecordReader();
    testLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR + "/%d.csv", 3200, 3999));
    DataSetIterator testDataSetIterator = new SequenceRecordReaderDataSetIterator(
            testFeaturesReader, testLabelsReader, 100, 2, false,
            SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // Graph: input "trainFeatures" -> LSTM "L1" (86 -> 200) -> softmax output "predictMortality" (200 -> 2).
    ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()
            .seed(RANDOM_SEED)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .weightInit(WeightInit.XAVIER)
            .updater(new Adam())
            .dropOut(0.9)
            .graphBuilder()
            .addInputs("trainFeatures")
            .setOutputs("predictMortality")
            .addLayer("L1", new LSTM.Builder()
                    .nIn(86)
                    .nOut(200)
                    .forgetGateBiasInit(1)
                    .activation(Activation.TANH)
                    .build(), "trainFeatures")
            .addLayer("predictMortality", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX)
                    .nIn(200).nOut(2).build(), "L1")
            .build();

    ComputationGraph model = new ComputationGraph(configuration);

    // Single training pass; the loop bound is the place to raise the epoch count.
    for (int i = 0; i < 1; i++) {
        model.fit(trainDataSetIterator);
        trainDataSetIterator.reset();
    }

    // Accumulate ROC statistics over every test batch (ROC built with argument 100).
    ROC evaluation = new ROC(100);
    while (testDataSetIterator.hasNext()) {
        DataSet batch = testDataSetIterator.next();
        INDArray[] output = model.output(batch.getFeatures());
        evaluation.evalTimeSeries(batch.getLabels(), output[0]);
    }

    System.out.println(evaluation.calculateAUC());
    System.out.println(evaluation.stats());
}
 
Example #9
Source File: LstmTimeSeriesExample.java    From Java-Deep-Learning-Cookbook with MIT License 4 votes vote down vote up
/**
 * Trains an LSTM classifier on numbered CSV sequence files and prints ROC statistics.
 * <p>
 * Files {@code 0.csv} .. {@code 3199.csv} under {@code FEATURE_DIR} / {@code LABEL_DIR} form
 * the training set and {@code 3200.csv} .. {@code 3999.csv} the test set. Features and labels
 * are paired with {@code ALIGN_END}. The model is a one-LSTM-layer computation graph with a
 * two-way softmax output; it is fitted for a single pass and then evaluated batch by batch.
 *
 * @param args unused
 * @throws IOException           if a reader cannot be initialized; a
 *                               {@link FileNotFoundException} is thrown when either directory
 *                               constant still holds its template placeholder
 * @throws InterruptedException  if reader initialization is interrupted
 */
public static void main(String[] args) throws IOException, InterruptedException {
    // Refuse to run while a directory constant is still the unedited template placeholder.
    // The label check uses startsWith so that it matches the placeholder both with and
    // without its closing brace: the original literal lacked the '}' that the feature
    // placeholder has, so a braced "{PATH-TO-PHYSIONET-LABELS}" constant slipped through.
    if (FEATURE_DIR.equals("{PATH-TO-PHYSIONET-FEATURES}") || LABEL_DIR.startsWith("{PATH-TO-PHYSIONET-LABELS")) {
        final String hint = "Please provide proper directory path in place of: PATH-TO-PHYSIONET-FEATURES && PATH-TO-PHYSIONET-LABELS";
        System.out.println(hint);
        throw new FileNotFoundException(hint);
    }

    // Training split: files 0..3199. The feature reader is built with (1, ",") --
    // presumably "skip one header line, comma-delimited"; the label reader uses defaults.
    SequenceRecordReader trainFeaturesReader = new CSVSequenceRecordReader(1, ",");
    trainFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR + "/%d.csv", 0, 3199));
    SequenceRecordReader trainLabelsReader = new CSVSequenceRecordReader();
    trainLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR + "/%d.csv", 0, 3199));
    DataSetIterator trainDataSetIterator = new SequenceRecordReaderDataSetIterator(
            trainFeaturesReader, trainLabelsReader, 100, 2, false,
            SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // Test split: files 3200..3999, read the same way.
    SequenceRecordReader testFeaturesReader = new CSVSequenceRecordReader(1, ",");
    testFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR + "/%d.csv", 3200, 3999));
    SequenceRecordReader testLabelsReader = new CSVSequenceRecordReader();
    testLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR + "/%d.csv", 3200, 3999));
    DataSetIterator testDataSetIterator = new SequenceRecordReaderDataSetIterator(
            testFeaturesReader, testLabelsReader, 100, 2, false,
            SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // Graph: input "trainFeatures" -> LSTM "L1" (86 -> 200) -> softmax output "predictMortality" (200 -> 2).
    ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()
            .seed(RANDOM_SEED)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .weightInit(WeightInit.XAVIER)
            .updater(new Adam())
            .dropOut(0.9)
            .graphBuilder()
            .addInputs("trainFeatures")
            .setOutputs("predictMortality")
            .addLayer("L1", new LSTM.Builder()
                    .nIn(86)
                    .nOut(200)
                    .forgetGateBiasInit(1)
                    .activation(Activation.TANH)
                    .build(), "trainFeatures")
            .addLayer("predictMortality", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX)
                    .nIn(200).nOut(2).build(), "L1")
            .build();

    ComputationGraph model = new ComputationGraph(configuration);

    // Single training pass; the loop bound is the place to raise the epoch count.
    for (int i = 0; i < 1; i++) {
        model.fit(trainDataSetIterator);
        trainDataSetIterator.reset();
    }

    // Accumulate ROC statistics over every test batch (ROC built with argument 100).
    ROC evaluation = new ROC(100);
    while (testDataSetIterator.hasNext()) {
        DataSet batch = testDataSetIterator.next();
        INDArray[] output = model.output(batch.getFeatures());
        evaluation.evalTimeSeries(batch.getLabels(), output[0]);
    }

    System.out.println(evaluation.calculateAUC());
    System.out.println(evaluation.stats());
}
 
Example #10
Source File: TestDataVecDataSetFunctions.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Checks that the Spark pipeline (sequence file -> {@code PairSequenceRecordReaderBytesFunction}
 * -> {@code DataVecSequencePairDataSetFunction}) yields the same data sets as a local
 * {@link SequenceRecordReaderDataSetIterator} reading the same CSV files.
 * <p>
 * The same directory is used for both features and labels. Three data sets are expected on
 * each side; every feature/label array must have shape {1, 3, 4} with no masks on the Spark
 * side, and each Spark data set must equal exactly one local data set.
 *
 * @throws Exception if staging the resources, running Spark, or reading the files fails
 */
@Test
public void testDataVecSequencePairDataSetFunction() throws Exception {
    JavaSparkContext sparkContext = getContext();

    // Stage the CSV sequence resources in a scratch folder.
    File csvDir = testDir.newFolder();
    new ClassPathResource("dl4j-spark/csvsequence/").copyDirectory(csvDir);
    String globPath = csvDir.getAbsolutePath() + "/*";

    // Combine the files into (key, feature-bytes/label-bytes) pairs; the same path serves as both sources.
    PathToKeyConverter keyConverter = new PathToKeyConverterFilename();
    JavaPairRDD<Text, BytesPairWritable> pairsToWrite =
                    DataVecSparkUtil.combineFilesForSequenceFile(sparkContext, globPath, globPath, keyConverter);

    Path seqFileDir = testDir.newFolder("dl4j_testSeqPairFn").toPath();
    seqFileDir.toFile().deleteOnExit();
    String outPath = seqFileDir.toString() + "/out";
    new File(outPath).deleteOnExit();
    pairsToWrite.saveAsNewAPIHadoopFile(outPath, Text.class, BytesPairWritable.class, SequenceFileOutputFormat.class);

    // Read the sequence file back and parse both halves of every pair as CSV sequences.
    JavaPairRDD<Text, BytesPairWritable> pairsRead =
                    sparkContext.sequenceFile(outPath, Text.class, BytesPairWritable.class);

    SequenceRecordReader sparkFeatureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader sparkLabelReader = new CSVSequenceRecordReader(1, ",");
    PairSequenceRecordReaderBytesFunction bytesToWritables =
                    new PairSequenceRecordReaderBytesFunction(sparkFeatureReader, sparkLabelReader);
    JavaRDD<Tuple2<List<List<Writable>>, List<List<Writable>>>> writablePairs = pairsRead.map(bytesToWritables);

    // Convert the writable pairs to data sets on the Spark side.
    DataVecSequencePairDataSetFunction toDataSet = new DataVecSequencePairDataSetFunction();
    List<DataSet> sparkData = writablePairs.map(toDataSet).collect();

    // Local reference: read the same three files with SequenceRecordReaderDataSetIterator.
    String filePattern = FilenameUtils.concat(csvDir.getAbsolutePath(), "csvsequence_%d.txt");

    SequenceRecordReader localFeatureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader localLabelReader = new CSVSequenceRecordReader(1, ",");
    localFeatureReader.initialize(new NumberedFileInputSplit(filePattern, 0, 2));
    localLabelReader.initialize(new NumberedFileInputSplit(filePattern, 0, 2));

    SequenceRecordReaderDataSetIterator localIterator =
                    new SequenceRecordReaderDataSetIterator(localFeatureReader, localLabelReader, 1, -1, true);

    List<DataSet> localData = new ArrayList<>(3);
    while (localIterator.hasNext()) {
        localData.add(localIterator.next());
    }

    assertEquals(3, sparkData.size());
    assertEquals(3, localData.size());

    // Shape checks only: the two lists may be ordered differently, so contents are compared afterwards.
    long[] expectedShape = new long[] {1, 3, 4}; //1 example, 3 values, 4 time steps
    for (int idx = 0; idx < 3; idx++) {
        DataSet fromSpark = sparkData.get(idx);
        DataSet fromLocal = localData.get(idx);

        assertNull(fromSpark.getFeaturesMaskArray());
        assertNull(fromSpark.getLabelsMaskArray());

        assertArrayEquals(expectedShape, fromSpark.getFeatures().shape());
        assertArrayEquals(expectedShape, fromLocal.getFeatures().shape());
        assertArrayEquals(expectedShape, fromSpark.getLabels().shape());
        assertArrayEquals(expectedShape, fromLocal.getLabels().shape());
    }

    // Content check, independent of order: every Spark data set must match exactly one local one.
    boolean[] localMatched = new boolean[3];
    for (int sparkIdx = 0; sparkIdx < 3; sparkIdx++) {
        DataSet candidate = sparkData.get(sparkIdx);
        int matchedAt = -1;
        for (int localIdx = 0; localIdx < 3; localIdx++) {
            if (!candidate.equals(localData.get(localIdx))) {
                continue;
            }
            if (matchedAt != -1) {
                fail(); //This Spark value equals two or more local values (shouldn't happen)
            }
            matchedAt = localIdx;
            if (localMatched[matchedAt]) {
                fail(); //An earlier Spark value already matched this local one -> duplicates in the Spark list
            }
            localMatched[matchedAt] = true;
        }
    }

    int matchCount = 0;
    for (int k = 0; k < localMatched.length; k++) {
        if (localMatched[k]) {
            matchCount++;
        }
    }
    assertEquals(3, matchCount); //Expect all 3 and exactly 3 pairwise matches between spark and local versions
}