org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader Java Examples
The following examples show how to use
org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader.
You can vote up the examples you like or vote down the ones you don't like.
You can also go to the original project or source file by following the link above each example, and check out the related API usage on the sidebar.
Example #1
Source File: CSVSequenceRecordReaderTest.java From DataVec with Apache License 2.0 | 6 votes |
@Test
public void testReset() throws Exception {
    // Verifies that reset() returns the reader to the start of the data: five
    // consecutive reset-then-iterate passes must all see the same 3 sequences.
    // Reader skips 1 header line per file; values are comma-delimited.
    CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
    seqReader.initialize(new TestInputSplit());
    int nTests = 5;
    for (int i = 0; i < nTests; i++) {
        seqReader.reset();
        int sequenceCount = 0;
        while (seqReader.hasNext()) {
            List<List<Writable>> sequence = seqReader.sequenceRecord();
            assertEquals(4, sequence.size()); // 4 data lines remain after the 1 header line is skipped
            Iterator<List<Writable>> timeStepIter = sequence.iterator();
            int lineCount = 0;
            while (timeStepIter.hasNext()) {
                timeStepIter.next();
                lineCount++;
            }
            sequenceCount++;
            assertEquals(4, lineCount); // iterator yields exactly the 4 time steps
        }
        assertEquals(3, sequenceCount); // TestInputSplit supplies 3 sequence files
    }
}
Example #2
Source File: CSVSequenceRecordReaderTest.java From DataVec with Apache License 2.0 | 6 votes |
@Test public void testCsvSeqAndNumberedFileSplit() throws Exception { File baseDir = tempDir.newFolder(); //Simple sanity check unit test for (int i = 0; i < 3; i++) { new org.nd4j.linalg.io.ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir); } //Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator org.nd4j.linalg.io.ClassPathResource resource = new org.nd4j.linalg.io.ClassPathResource("csvsequence_0.txt"); String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath(); SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); while(featureReader.hasNext()){ featureReader.nextSequence(); } }
Example #3
Source File: SameDiffRNNTestCases.java From deeplearning4j with Apache License 2.0 | 6 votes |
protected MultiDataSetIterator getTrainingDataUnnormalized() throws Exception {
    // Builds the (unnormalized) training iterator over the UCI sequence data set.
    final int miniBatchSize = 10;
    final int numLabelClasses = 6;

    // Copy the training features/labels out of the test resources into temp dirs
    File featuresDirTrain = Files.createTempDir();
    File labelsDirTrain = Files.createTempDir();
    Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/features/", featuresDirTrain);
    Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/labels/", labelsDirTrain);

    // One sequence per numbered CSV file, indices 0..449
    SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
    trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
    SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
    trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));

    DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels,
                    miniBatchSize, numLabelClasses, false,
                    SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // Adapt to the MultiDataSetIterator API expected by the test-case framework
    return new MultiDataSetIteratorAdapter(trainData);
}
Example #4
Source File: CSVSequenceRecordReaderTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void testReset() throws Exception { CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ","); seqReader.initialize(new TestInputSplit()); int nTests = 5; for (int i = 0; i < nTests; i++) { seqReader.reset(); int sequenceCount = 0; while (seqReader.hasNext()) { List<List<Writable>> sequence = seqReader.sequenceRecord(); assertEquals(4, sequence.size()); //4 lines, plus 1 header line Iterator<List<Writable>> timeStepIter = sequence.iterator(); int lineCount = 0; while (timeStepIter.hasNext()) { timeStepIter.next(); lineCount++; } sequenceCount++; assertEquals(4, lineCount); } assertEquals(3, sequenceCount); } }
Example #5
Source File: CSVSequenceRecordReaderTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void testCsvSeqAndNumberedFileSplit() throws Exception { File baseDir = tempDir.newFolder(); //Simple sanity check unit test for (int i = 0; i < 3; i++) { new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir); } //Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator ClassPathResource resource = new ClassPathResource("csvsequence_0.txt"); String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath(); SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); while(featureReader.hasNext()){ featureReader.nextSequence(); } }
Example #6
Source File: RNNTestCases.java From deeplearning4j with Apache License 2.0 | 6 votes |
protected MultiDataSetIterator getTrainingDataUnnormalized() throws Exception {
    // Builds the (unnormalized) training iterator over the UCI sequence data set.
    final int miniBatchSize = 10;
    final int numLabelClasses = 6;

    // Extract the training features/labels from the classpath into temp dirs
    File featuresDirTrain = Files.createTempDir();
    File labelsDirTrain = Files.createTempDir();
    new ClassPathResource("dl4j-integration-tests/data/uci_seq/train/features/").copyDirectory(featuresDirTrain);
    new ClassPathResource("dl4j-integration-tests/data/uci_seq/train/labels/").copyDirectory(labelsDirTrain);

    // One sequence per numbered CSV file, indices 0..449
    SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
    trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
    SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
    trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));

    DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels,
                    miniBatchSize, numLabelClasses, false,
                    SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // Adapt to the MultiDataSetIterator API expected by the test-case framework
    return new MultiDataSetIteratorAdapter(trainData);
}
Example #7
Source File: RNNTestCases.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public MultiDataSetIterator getEvaluationTestData() throws Exception { int miniBatchSize = 10; int numLabelClasses = 6; // File featuresDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/features/").getFile(); // File labelsDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/labels/").getFile(); File featuresDirTest = Files.createTempDir(); File labelsDirTest = Files.createTempDir(); new ClassPathResource("dl4j-integration-tests/data/uci_seq/test/features/").copyDirectory(featuresDirTest); new ClassPathResource("dl4j-integration-tests/data/uci_seq/test/labels/").copyDirectory(labelsDirTest); SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(); trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149)); SequenceRecordReader trainLabels = new CSVSequenceRecordReader(); trainLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149)); DataSetIterator testData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(testData); MultiDataSetPreProcessor pp = multiDataSet -> { INDArray l = multiDataSet.getLabels(0); l = l.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(l.size(2)-1)); multiDataSet.setLabels(0, l); multiDataSet.setLabelsMaskArray(0, null); }; iter.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(),pp)); return iter; }
Example #8
Source File: RecordReaderMultiDataSetIteratorTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testSplittingCSVSequenceMeta() throws Exception { //Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep "csvSequencelables_i.txt" // as standard one-hot output //need to manually extract File rootDir = temporaryFolder.newFolder(); for (int i = 0; i < 3; i++) { new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir); } String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt"); SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ","); featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); RecordReaderMultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1) .addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2) .addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build(); srrmdsi.setCollectMetaData(true); int count = 0; while (srrmdsi.hasNext()) { MultiDataSet mds = srrmdsi.next(); MultiDataSet fromMeta = srrmdsi.loadFromMetaData(mds.getExampleMetaData(RecordMetaData.class)); assertEquals(mds, fromMeta); count++; } assertEquals(3, count); }
Example #9
Source File: RecordReaderDataSetiteratorTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test(expected = ZeroLengthSequenceException.class)
public void testSequenceRecordReaderTwoReadersWithEmptyLabelSequenceThrows() throws Exception {
    // A zero-length label sequence must make the iterator throw on next().
    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");

    featureReader.initialize(new FileSplit(Resources.asFile("csvsequence_0.txt")));
    labelReader.initialize(new FileSplit(Resources.asFile("empty.txt")));

    new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, -1, true).next();
}
Example #10
Source File: RecordReaderDataSetiteratorTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test(expected = ZeroLengthSequenceException.class)
public void testSequenceRecordReaderTwoReadersWithEmptyFeatureSequenceThrows() throws Exception {
    // A zero-length feature sequence must make the iterator throw on next().
    SequenceRecordReader features = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labels = new CSVSequenceRecordReader(1, ",");

    FileSplit emptySplit = new FileSplit(Resources.asFile("empty.txt"));
    features.initialize(emptySplit);
    labels.initialize(new FileSplit(Resources.asFile("csvsequencelabels_0.txt")));

    new SequenceRecordReaderDataSetIterator(features, labels, 1, -1, true).next();
}
Example #11
Source File: RecordReaderDataSetiteratorTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test(expected = ZeroLengthSequenceException.class)
public void testSequenceRecordReaderSingleReaderWithEmptySequenceThrows() throws Exception {
    // An empty input file yields a zero-length sequence, which must be rejected.
    File emptyFile = Resources.asFile("empty.txt");
    SequenceRecordReader reader = new CSVSequenceRecordReader(1, ",");
    reader.initialize(new FileSplit(emptyFile));

    SequenceRecordReaderDataSetIterator it = new SequenceRecordReaderDataSetIterator(reader, 1, -1, 1, true);
    it.next(); // expected to throw ZeroLengthSequenceException
}
Example #12
Source File: RecordReaderDataSetiteratorTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test
public void testSequenceRecordReaderReset() throws Exception {
    // Verifies that the DataSet iterator can be fully re-iterated after reset():
    // five reset-then-iterate passes must all produce 3 DataSets with identical shapes.
    File rootDir = temporaryFolder.newFolder();
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        FileUtils.copyFile(Resources.asFile(String.format("csvsequence_%d.txt", i)),
                        new File(rootDir, String.format("csvsequence_%d.txt", i)));
        FileUtils.copyFile(Resources.asFile(String.format("csvsequencelabels_%d.txt", i)),
                        new File(rootDir, String.format("csvsequencelabels_%d.txt", i)));
    }
    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReaderDataSetIterator iter =
                    new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);

    assertEquals(3, iter.inputColumns());
    assertEquals(4, iter.totalOutcomes());

    int nResets = 5;
    for (int i = 0; i < nResets; i++) {
        iter.reset();
        int count = 0;
        while (iter.hasNext()) {
            DataSet ds = iter.next();
            INDArray features = ds.getFeatures();
            INDArray labels = ds.getLabels();
            // Shapes: [miniBatch, values, timeSteps]
            assertArrayEquals(new long[] {1, 3, 4}, features.shape());
            assertArrayEquals(new long[] {1, 4, 4}, labels.shape());
            count++;
        }
        assertEquals(3, count); // one DataSet per sequence file
    }
}
Example #13
Source File: RecordReaderDataSetiteratorTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test
public void testSequenceRecordReaderMeta() throws Exception {
    // Verifies that examples re-loaded via loadFromMetaData() are equal to the
    // examples produced by normal iteration.
    File rootDir = temporaryFolder.newFolder();
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        FileUtils.copyFile(Resources.asFile(String.format("csvsequence_%d.txt", i)),
                        new File(rootDir, String.format("csvsequence_%d.txt", i)));
        FileUtils.copyFile(Resources.asFile(String.format("csvsequencelabels_%d.txt", i)),
                        new File(rootDir, String.format("csvsequencelabels_%d.txt", i)));
    }
    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReaderDataSetIterator iter =
                    new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);
    iter.setCollectMetaData(true); // required so getExampleMetaData()/loadFromMetaData() work

    assertEquals(3, iter.inputColumns());
    assertEquals(4, iter.totalOutcomes());

    while (iter.hasNext()) {
        DataSet ds = iter.next();
        List<RecordMetaData> meta = ds.getExampleMetaData(RecordMetaData.class);
        DataSet fromMeta = iter.loadFromMetaData(meta);
        assertEquals(ds, fromMeta); // metadata round-trip must reproduce the example exactly
    }
}
Example #14
Source File: CSVSequenceRecordReaderTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test
public void testMetaData() throws Exception {
    // Verifies sequence metadata: sequences re-loaded via loadSequenceFromMetaData()
    // must equal those returned by normal iteration.
    CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
    seqReader.initialize(new TestInputSplit());

    // First pass: collect the raw sequences
    List<List<List<Writable>>> l = new ArrayList<>();
    while (seqReader.hasNext()) {
        List<List<Writable>> sequence = seqReader.sequenceRecord();
        assertEquals(4, sequence.size()); // 4 data lines remain after the 1 header line is skipped
        Iterator<List<Writable>> timeStepIter = sequence.iterator();
        int lineCount = 0;
        while (timeStepIter.hasNext()) {
            timeStepIter.next();
            lineCount++;
        }
        assertEquals(4, lineCount);
        l.add(sequence);
    }

    // Second pass: collect SequenceRecords together with their metadata
    List<SequenceRecord> l2 = new ArrayList<>();
    List<RecordMetaData> meta = new ArrayList<>();
    seqReader.reset();
    while (seqReader.hasNext()) {
        SequenceRecord sr = seqReader.nextSequence();
        l2.add(sr);
        meta.add(sr.getMetaData());
    }
    assertEquals(3, l2.size());

    // Metadata round-trip must reproduce exactly the same sequences
    List<SequenceRecord> fromMeta = seqReader.loadSequenceFromMetaData(meta);
    for (int i = 0; i < 3; i++) {
        assertEquals(l.get(i), l2.get(i).getSequenceRecord());
        assertEquals(l.get(i), fromMeta.get(i).getSequenceRecord());
    }
}
Example #15
Source File: CSVSequenceRecordReaderTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test
public void test() throws Exception {
    // Basic read test. Each value in the test data encodes its own position:
    // 100*sequenceIndex + 10*lineIndex + columnIndex.
    CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
    seqReader.initialize(new TestInputSplit());
    int sequenceCount = 0;
    while (seqReader.hasNext()) {
        List<List<Writable>> sequence = seqReader.sequenceRecord();
        assertEquals(4, sequence.size()); // 4 data lines remain after the 1 header line is skipped
        Iterator<List<Writable>> timeStepIter = sequence.iterator();
        int lineCount = 0;
        while (timeStepIter.hasNext()) {
            List<Writable> timeStep = timeStepIter.next();
            assertEquals(3, timeStep.size()); // 3 comma-separated values per line
            Iterator<Writable> lineIter = timeStep.iterator();
            int countInLine = 0;
            while (lineIter.hasNext()) {
                Writable entry = lineIter.next();
                int expValue = 100 * sequenceCount + 10 * lineCount + countInLine;
                assertEquals(String.valueOf(expValue), entry.toString());
                countInLine++;
            }
            lineCount++;
        }
        sequenceCount++;
    }
}
Example #16
Source File: CSVSequenceRecordReaderTest.java From DataVec with Apache License 2.0 | 5 votes |
@Test public void testMetaData() throws Exception { CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ","); seqReader.initialize(new TestInputSplit()); List<List<List<Writable>>> l = new ArrayList<>(); while (seqReader.hasNext()) { List<List<Writable>> sequence = seqReader.sequenceRecord(); assertEquals(4, sequence.size()); //4 lines, plus 1 header line Iterator<List<Writable>> timeStepIter = sequence.iterator(); int lineCount = 0; while (timeStepIter.hasNext()) { timeStepIter.next(); lineCount++; } assertEquals(4, lineCount); l.add(sequence); } List<SequenceRecord> l2 = new ArrayList<>(); List<RecordMetaData> meta = new ArrayList<>(); seqReader.reset(); while (seqReader.hasNext()) { SequenceRecord sr = seqReader.nextSequence(); l2.add(sr); meta.add(sr.getMetaData()); } assertEquals(3, l2.size()); List<SequenceRecord> fromMeta = seqReader.loadSequenceFromMetaData(meta); for (int i = 0; i < 3; i++) { assertEquals(l.get(i), l2.get(i).getSequenceRecord()); assertEquals(l.get(i), fromMeta.get(i).getSequenceRecord()); } }
Example #17
Source File: CSVSequenceRecordReaderTest.java From DataVec with Apache License 2.0 | 5 votes |
@Test public void test() throws Exception { CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ","); seqReader.initialize(new TestInputSplit()); int sequenceCount = 0; while (seqReader.hasNext()) { List<List<Writable>> sequence = seqReader.sequenceRecord(); assertEquals(4, sequence.size()); //4 lines, plus 1 header line Iterator<List<Writable>> timeStepIter = sequence.iterator(); int lineCount = 0; while (timeStepIter.hasNext()) { List<Writable> timeStep = timeStepIter.next(); assertEquals(3, timeStep.size()); Iterator<Writable> lineIter = timeStep.iterator(); int countInLine = 0; while (lineIter.hasNext()) { Writable entry = lineIter.next(); int expValue = 100 * sequenceCount + 10 * lineCount + countInLine; assertEquals(String.valueOf(expValue), entry.toString()); countInLine++; } lineCount++; } sequenceCount++; } }
Example #18
Source File: SameDiffRNNTestCases.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public MultiDataSetIterator getEvaluationTestData() throws Exception { int miniBatchSize = 10; int numLabelClasses = 6; // File featuresDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/features/").getFile(); // File labelsDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/labels/").getFile(); File featuresDirTest = Files.createTempDir(); File labelsDirTest = Files.createTempDir(); Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/features/", featuresDirTest); Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/labels/", labelsDirTest); SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(); trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149)); SequenceRecordReader trainLabels = new CSVSequenceRecordReader(); trainLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149)); DataSetIterator testData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(testData); MultiDataSetPreProcessor pp = multiDataSet -> { INDArray l = multiDataSet.getLabels(0); l = l.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(l.size(2) - 1)); multiDataSet.setLabels(0, l); multiDataSet.setLabelsMaskArray(0, null); }; iter.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), pp)); return iter; }
Example #19
Source File: TestDataVecDataSetFunctions.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test
public void testDataVecSequencePairDataSetFunction() throws Exception {
    // Round-trips CSV sequence pairs through a Hadoop SequenceFile on Spark, maps them
    // to DataSets, and checks the results match a local SequenceRecordReaderDataSetIterator
    // (order-independent comparison, since Spark may reorder partitions).
    JavaSparkContext sc = getContext();

    File f = testDir.newFolder();
    ClassPathResource cpr = new ClassPathResource("dl4j-spark/csvsequence/");
    cpr.copyDirectory(f);
    String path = f.getAbsolutePath() + "/*";

    // Pair each file with itself (same path twice) keyed by filename
    PathToKeyConverter pathConverter = new PathToKeyConverterFilename();
    JavaPairRDD<Text, BytesPairWritable> toWrite =
                    DataVecSparkUtil.combineFilesForSequenceFile(sc, path, path, pathConverter);

    Path p = testDir.newFolder("dl4j_testSeqPairFn").toPath();
    p.toFile().deleteOnExit();
    String outPath = p.toString() + "/out";
    new File(outPath).deleteOnExit();
    toWrite.saveAsNewAPIHadoopFile(outPath, Text.class, BytesPairWritable.class, SequenceFileOutputFormat.class);

    //Load from sequence file:
    JavaPairRDD<Text, BytesPairWritable> fromSeq = sc.sequenceFile(outPath, Text.class, BytesPairWritable.class);

    SequenceRecordReader srr1 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader srr2 = new CSVSequenceRecordReader(1, ",");
    PairSequenceRecordReaderBytesFunction psrbf = new PairSequenceRecordReaderBytesFunction(srr1, srr2);

    JavaRDD<Tuple2<List<List<Writable>>, List<List<Writable>>>> writables = fromSeq.map(psrbf);

    //Map to DataSet:
    DataVecSequencePairDataSetFunction pairFn = new DataVecSequencePairDataSetFunction();
    JavaRDD<DataSet> data = writables.map(pairFn);
    List<DataSet> sparkData = data.collect();

    //Now: do the same thing locally (SequenceRecordReaderDataSetIterator) and compare
    // NOTE(review): labelReader is initialized with featuresPath — labels are read from the
    // same files as features, matching the (path, path) pairing above.
    String featuresPath = FilenameUtils.concat(f.getAbsolutePath(), "csvsequence_%d.txt");

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));

    SequenceRecordReaderDataSetIterator iter =
                    new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, -1, true);

    List<DataSet> localData = new ArrayList<>(3);
    while (iter.hasNext())
        localData.add(iter.next());

    assertEquals(3, sparkData.size());
    assertEquals(3, localData.size());

    for (int i = 0; i < 3; i++) {
        //Check shapes etc. data sets order may differ for spark vs. local
        DataSet dsSpark = sparkData.get(i);
        DataSet dsLocal = localData.get(i);

        assertNull(dsSpark.getFeaturesMaskArray());
        assertNull(dsSpark.getLabelsMaskArray());

        INDArray fSpark = dsSpark.getFeatures();
        INDArray fLocal = dsLocal.getFeatures();
        INDArray lSpark = dsSpark.getLabels();
        INDArray lLocal = dsLocal.getLabels();

        val s = new long[] {1, 3, 4}; //1 example, 3 values, 4 time steps (fixed: comment previously said 3 time steps)
        assertArrayEquals(s, fSpark.shape());
        assertArrayEquals(s, fLocal.shape());
        assertArrayEquals(s, lSpark.shape());
        assertArrayEquals(s, lLocal.shape());
    }

    //Check that results are the same (order not withstanding)
    boolean[] found = new boolean[3];
    for (int i = 0; i < 3; i++) {
        int foundIndex = -1;
        DataSet ds = sparkData.get(i);
        for (int j = 0; j < 3; j++) {
            if (ds.equals(localData.get(j))) {
                if (foundIndex != -1)
                    fail(); //Already found this value -> suggests this spark value equals two or more of local version? (Shouldn't happen)
                foundIndex = j;
                if (found[foundIndex])
                    fail(); //One of the other spark values was equal to this one -> suggests duplicates in Spark list
                found[foundIndex] = true; //mark this one as seen before
            }
        }
    }
    int count = 0;
    for (boolean b : found)
        if (b)
            count++;
    assertEquals(3, count); //Expect all 3 and exactly 3 pairwise matches between spark and local versions
}
Example #20
Source File: RecordReaderDataSetiteratorTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test
public void testSequenceRecordReaderMultiRegression() throws Exception {
    // Single reader providing both features and regression targets: per the assertions
    // below, each line is split into 1 input column and 2 regression label columns.
    File rootDir = temporaryFolder.newFolder();
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        FileUtils.copyFile(Resources.asFile(String.format("csvsequence_%d.txt", i)),
                        new File(rootDir, String.format("csvsequence_%d.txt", i)));
    }
    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");

    SequenceRecordReader reader = new CSVSequenceRecordReader(1, ",");
    reader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    SequenceRecordReaderDataSetIterator iter =
                    new SequenceRecordReaderDataSetIterator(reader, 1, 2, 1, true); // regression = true

    assertEquals(1, iter.inputColumns());
    assertEquals(2, iter.totalOutcomes());

    List<DataSet> dsList = new ArrayList<>();
    while (iter.hasNext()) {
        dsList.add(iter.next());
    }

    assertEquals(3, dsList.size()); //3 files
    for (int i = 0; i < 3; i++) {
        DataSet ds = dsList.get(i);
        INDArray features = ds.getFeatures();
        INDArray labels = ds.getLabels();
        assertArrayEquals(new long[] {1, 1, 4}, features.shape()); //1 example, 1 value, 4 time steps
        assertArrayEquals(new long[] {1, 2, 4}, labels.shape());

        // Drop the example dimension and transpose to [timeSteps, values] for comparison
        INDArray f2d = features.get(point(0), all(), all()).transpose();
        INDArray l2d = labels.get(point(0), all(), all()).transpose();

        // Expected values encode their position in the source files (see sibling tests)
        switch (i){
            case 0:
                assertEquals(Nd4j.create(new double[]{0,10,20,30}, new int[]{4,1}).castTo(DataType.FLOAT), f2d);
                assertEquals(Nd4j.create(new double[][]{{1,2}, {11,12}, {21,22}, {31,32}}).castTo(DataType.FLOAT), l2d);
                break;
            case 1:
                assertEquals(Nd4j.create(new double[]{100,110,120,130}, new int[]{4,1}).castTo(DataType.FLOAT), f2d);
                assertEquals(Nd4j.create(new double[][]{{101,102}, {111,112}, {121,122}, {131,132}}).castTo(DataType.FLOAT), l2d);
                break;
            case 2:
                assertEquals(Nd4j.create(new double[]{200,210,220,230}, new int[]{4,1}).castTo(DataType.FLOAT), f2d);
                assertEquals(Nd4j.create(new double[][]{{201,202}, {211,212}, {221,222}, {231,232}}).castTo(DataType.FLOAT), l2d);
                break;
            default:
                throw new RuntimeException();
        }
    }

    // reset() must allow a full second pass
    iter.reset();
    int count = 0;
    while (iter.hasNext()) {
        iter.next();
        count++;
    }
    assertEquals(3, count);
}
Example #21
Source File: UciSequenceDataFetcher.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Override
public CSVSequenceRecordReader getRecordReader(long rngSeed, int[] shape, DataSetType set, ImageTransform transform) {
    // The UCI sequence data is CSV, so the image-oriented shape and transform
    // arguments are not applicable; delegate to the (rngSeed, set) overload.
    return getRecordReader(rngSeed, set);
}
Example #22
Source File: RecordReaderMultiDataSetIteratorTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test
public void testVariableLengthTSMeta() throws Exception {
    // Variable-length time series: the "Short" label sequences are a different length than
    // the feature sequences, exercising ALIGN_START vs ALIGN_END. Verifies that examples
    // re-loaded via loadFromMetaData() equal the normally-iterated examples for both modes.
    //need to manually extract
    File rootDir = temporaryFolder.newFolder();
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
    }
    //Set up SequenceRecordReaderDataSetIterators for comparison
    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabelsShort_%d.txt");

    //Set up
    SequenceRecordReader featureReader3 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader3 = new CSVSequenceRecordReader(1, ",");
    featureReader3.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader3.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReader featureReader4 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader4 = new CSVSequenceRecordReader(1, ",");
    featureReader4.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader4.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("in", featureReader3).addSequenceReader("out", labelReader3).addInput("in")
                    .addOutputOneHot("out", 0, 4)
                    .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START).build();

    RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("in", featureReader4).addSequenceReader("out", labelReader4).addInput("in")
                    .addOutputOneHot("out", 0, 4)
                    .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END).build();

    rrmdsiStart.setCollectMetaData(true); // required for loadFromMetaData()
    rrmdsiEnd.setCollectMetaData(true);

    int count = 0;
    while (rrmdsiStart.hasNext()) {
        MultiDataSet mdsStart = rrmdsiStart.next();
        MultiDataSet mdsEnd = rrmdsiEnd.next();

        MultiDataSet mdsStartFromMeta =
                        rrmdsiStart.loadFromMetaData(mdsStart.getExampleMetaData(RecordMetaData.class));
        MultiDataSet mdsEndFromMeta = rrmdsiEnd.loadFromMetaData(mdsEnd.getExampleMetaData(RecordMetaData.class));

        assertEquals(mdsStart, mdsStartFromMeta);
        assertEquals(mdsEnd, mdsEndFromMeta);

        count++;
    }
    assertFalse(rrmdsiStart.hasNext());
    assertFalse(rrmdsiEnd.hasNext()); // both iterators must be exhausted together
    assertEquals(3, count);
}
Example #23
Source File: RecordReaderMultiDataSetIteratorTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Verifies that RecordReaderMultiDataSetIterator produces the same features, labels and
 * label masks as SequenceRecordReaderDataSetIterator for variable-length time series,
 * under both ALIGN_START and ALIGN_END sequence alignment. The two iterator families
 * are advanced in lockstep and compared example-by-example.
 */
@Test
public void testVariableLengthTS() throws Exception {
    //need to manually extract
    File rootDir = temporaryFolder.newFolder();
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
    }

    // "Short" labels are shorter than the feature sequences -> labels need mask arrays
    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabelsShort_%d.txt");

    //Set up SequenceRecordReaderDataSetIterators for comparison
    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
    featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReaderDataSetIterator iterAlignStart = new SequenceRecordReaderDataSetIterator(featureReader,
            labelReader, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_START);

    SequenceRecordReaderDataSetIterator iterAlignEnd = new SequenceRecordReaderDataSetIterator(featureReader2,
            labelReader2, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    //Set up
    SequenceRecordReader featureReader3 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader3 = new CSVSequenceRecordReader(1, ",");
    featureReader3.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader3.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReader featureReader4 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader4 = new CSVSequenceRecordReader(1, ",");
    featureReader4.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader4.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1)
            .addSequenceReader("in", featureReader3).addSequenceReader("out", labelReader3).addInput("in")
            .addOutputOneHot("out", 0, 4)
            .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START).build();

    RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1)
            .addSequenceReader("in", featureReader4).addSequenceReader("out", labelReader4).addInput("in")
            .addOutputOneHot("out", 0, 4)
            .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END).build();

    // Advance both iterator families in lockstep and compare example-by-example
    while (iterAlignStart.hasNext()) {
        DataSet dsStart = iterAlignStart.next();
        DataSet dsEnd = iterAlignEnd.next();

        MultiDataSet mdsStart = rrmdsiStart.next();
        MultiDataSet mdsEnd = rrmdsiEnd.next();

        assertEquals(1, mdsStart.getFeatures().length);
        assertEquals(1, mdsStart.getLabels().length);
        //assertEquals(1, mdsStart.getFeaturesMaskArrays().length); //Features data is always longer -> don't need mask arrays for it
        assertEquals(1, mdsStart.getLabelsMaskArrays().length);

        assertEquals(1, mdsEnd.getFeatures().length);
        assertEquals(1, mdsEnd.getLabels().length);
        //assertEquals(1, mdsEnd.getFeaturesMaskArrays().length);
        assertEquals(1, mdsEnd.getLabelsMaskArrays().length);

        assertEquals(dsStart.getFeatures(), mdsStart.getFeatures(0));
        assertEquals(dsStart.getLabels(), mdsStart.getLabels(0));
        assertEquals(dsStart.getLabelsMaskArray(), mdsStart.getLabelsMaskArray(0));

        assertEquals(dsEnd.getFeatures(), mdsEnd.getFeatures(0));
        assertEquals(dsEnd.getLabels(), mdsEnd.getLabels(0));
        assertEquals(dsEnd.getLabelsMaskArray(), mdsEnd.getLabelsMaskArray(0));
    }
    assertFalse(rrmdsiStart.hasNext());
    assertFalse(rrmdsiEnd.hasNext());
}
Example #24
Source File: RecordReaderMultiDataSetIteratorTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testSplittingCSVSequence() throws Exception { //Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep "csvSequencelables_i.txt" // as standard one-hot output //need to manually extract File rootDir = temporaryFolder.newFolder(); for (int i = 0; i < 3; i++) { new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir); } String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt"); SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false); SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ","); featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1) .addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2) .addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build(); while (iter.hasNext()) { DataSet ds = iter.next(); INDArray fds = ds.getFeatures(); INDArray lds = ds.getLabels(); MultiDataSet mds = srrmdsi.next(); assertEquals(2, mds.getFeatures().length); assertEquals(1, mds.getLabels().length); 
assertNull(mds.getFeaturesMaskArrays()); assertNull(mds.getLabelsMaskArrays()); INDArray[] fmds = mds.getFeatures(); INDArray[] lmds = mds.getLabels(); assertNotNull(fmds); assertNotNull(lmds); for (int i = 0; i < fmds.length; i++) assertNotNull(fmds[i]); for (int i = 0; i < lmds.length; i++) assertNotNull(lmds[i]); INDArray expIn1 = fds.get(all(), NDArrayIndex.interval(0, 1, true), all()); INDArray expIn2 = fds.get(all(), NDArrayIndex.interval(2, 2, true), all()); assertEquals(expIn1, fmds[0]); assertEquals(expIn2, fmds[1]); assertEquals(lds, lmds[0]); } assertFalse(srrmdsi.hasNext()); }
Example #25
Source File: LstmTimeSeriesExample.java From Java-Deep-Learning-Cookbook with MIT License | 4 votes |
/**
 * Trains a single-layer LSTM on the Physionet mortality CSV sequences and evaluates it
 * with ROC/AUC. Numbered files 0..3199 are used for training and 3200..3999 for testing;
 * feature files have a one-line header, label files have none.
 *
 * @throws IOException          if the CSV sequence files cannot be read
 * @throws InterruptedException if record-reader initialization is interrupted
 */
public static void main(String[] args) throws IOException, InterruptedException {
    // Guard against unedited placeholder paths. Bug fix: the label placeholder was missing its
    // closing '}' ("{PATH-TO-PHYSIONET-LABELS"), so an unedited LABEL_DIR was never detected.
    if(FEATURE_DIR.equals("{PATH-TO-PHYSIONET-FEATURES}") || LABEL_DIR.equals("{PATH-TO-PHYSIONET-LABELS}")){
        System.out.println("Please provide proper directory path in place of: PATH-TO-PHYSIONET-FEATURES && PATH-TO-PHYSIONET-LABELS");
        throw new FileNotFoundException();
    }

    // Training data: mini-batch 100, 2 classes, labels aligned to the end of each sequence
    SequenceRecordReader trainFeaturesReader = new CSVSequenceRecordReader(1, ",");
    trainFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR+"/%d.csv",0,3199));
    SequenceRecordReader trainLabelsReader = new CSVSequenceRecordReader();
    trainLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR+"/%d.csv",0,3199));
    DataSetIterator trainDataSetIterator = new SequenceRecordReaderDataSetIterator(trainFeaturesReader,trainLabelsReader,100,2,false,
            SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // Held-out test data
    SequenceRecordReader testFeaturesReader = new CSVSequenceRecordReader(1, ",");
    testFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR+"/%d.csv",3200,3999));
    SequenceRecordReader testLabelsReader = new CSVSequenceRecordReader();
    testLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR+"/%d.csv",3200,3999));
    DataSetIterator testDataSetIterator = new SequenceRecordReaderDataSetIterator(testFeaturesReader,testLabelsReader,100,2,false,
            SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // One LSTM layer (86 inputs -> 200 units) feeding a 2-class softmax RNN output layer
    ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()
            .seed(RANDOM_SEED)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .weightInit(WeightInit.XAVIER)
            .updater(new Adam())
            .dropOut(0.9)
            .graphBuilder()
            .addInputs("trainFeatures")
            .setOutputs("predictMortality")
            .addLayer("L1", new LSTM.Builder()
                    .nIn(86)
                    .nOut(200)
                    .forgetGateBiasInit(1)
                    .activation(Activation.TANH)
                    .build(),"trainFeatures")
            .addLayer("predictMortality", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX)
                    .nIn(200).nOut(2).build(),"L1")
            .build();

    ComputationGraph model = new ComputationGraph(configuration);

    // Single epoch; the reset keeps the iterator reusable if the epoch count is raised
    for(int i=0;i<1;i++){
        model.fit(trainDataSetIterator);
        trainDataSetIterator.reset();
    }

    // Evaluate on held-out sequences with ROC at 100 threshold steps
    ROC evaluation = new ROC(100);
    while (testDataSetIterator.hasNext()) {
        DataSet batch = testDataSetIterator.next();
        INDArray[] output = model.output(batch.getFeatures());
        evaluation.evalTimeSeries(batch.getLabels(), output[0]);
    }
    System.out.println(evaluation.calculateAUC());
    System.out.println(evaluation.stats());
}
Example #26
Source File: RecordReaderDataSetiteratorTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * End-to-end check of SequenceRecordReaderDataSetIterator over three CSV sequence files:
 * verifies iterator metadata (input columns / outcomes), the [miniBatch, channels, time]
 * shape of each DataSet, and the exact feature values and one-hot label values per file.
 */
@Test
public void testSequenceRecordReader() throws Exception {
    File rootDir = temporaryFolder.newFolder();
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        FileUtils.copyFile(Resources.asFile(String.format("csvsequence_%d.txt", i)),
                new File(rootDir, String.format("csvsequence_%d.txt", i)));
        FileUtils.copyFile(Resources.asFile(String.format("csvsequencelabels_%d.txt", i)),
                new File(rootDir, String.format("csvsequencelabels_%d.txt", i)));
    }
    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReaderDataSetIterator iter =
            new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);

    assertEquals(3, iter.inputColumns());
    assertEquals(4, iter.totalOutcomes());

    List<DataSet> dsList = new ArrayList<>();
    while (iter.hasNext()) {
        dsList.add(iter.next());
    }

    assertEquals(3, dsList.size()); //3 files
    for (int i = 0; i < 3; i++) {
        DataSet ds = dsList.get(i);
        INDArray features = ds.getFeatures();
        INDArray labels = ds.getLabels();
        assertEquals(1, features.size(0)); //1 example in mini-batch
        assertEquals(1, labels.size(0));
        assertEquals(3, features.size(1)); //3 values per line/time step
        assertEquals(4, labels.size(1)); //1 value per line, but 4 possible values -> one-hot vector
        assertEquals(4, features.size(2)); //sequence length = 4
        assertEquals(4, labels.size(2));
    }

    //Check features vs. expected:
    INDArray expF0 = Nd4j.create(1, 3, 4);
    expF0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {0, 1, 2}));
    expF0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {10, 11, 12}));
    expF0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {20, 21, 22}));
    expF0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {30, 31, 32}));
    assertEquals(dsList.get(0).getFeatures(), expF0);

    INDArray expF1 = Nd4j.create(1, 3, 4);
    expF1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {100, 101, 102}));
    expF1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {110, 111, 112}));
    expF1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {120, 121, 122}));
    expF1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {130, 131, 132}));
    assertEquals(dsList.get(1).getFeatures(), expF1);

    INDArray expF2 = Nd4j.create(1, 3, 4);
    expF2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {200, 201, 202}));
    expF2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {210, 211, 212}));
    expF2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {220, 221, 222}));
    expF2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {230, 231, 232}));
    assertEquals(dsList.get(2).getFeatures(), expF2);

    //Check labels vs. expected:
    INDArray expL0 = Nd4j.create(1, 4, 4);
    expL0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {1, 0, 0, 0}));
    expL0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {0, 1, 0, 0}));
    expL0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {0, 0, 1, 0}));
    expL0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {0, 0, 0, 1}));
    assertEquals(dsList.get(0).getLabels(), expL0);

    INDArray expL1 = Nd4j.create(1, 4, 4);
    expL1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {0, 0, 0, 1}));
    expL1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {0, 0, 1, 0}));
    expL1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {0, 1, 0, 0}));
    expL1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {1, 0, 0, 0}));
    assertEquals(dsList.get(1).getLabels(), expL1);

    INDArray expL2 = Nd4j.create(1, 4, 4);
    expL2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {0, 1, 0, 0}));
    expL2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {1, 0, 0, 0}));
    expL2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {0, 0, 0, 1}));
    expL2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {0, 0, 1, 0}));
    assertEquals(dsList.get(2).getLabels(), expL2);
}
Example #27
Source File: TestStreamInputSplit.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testShuffle() throws Exception { File dir = testDir.newFolder(); File f1 = new File(dir, "file1.txt"); File f2 = new File(dir, "file2.txt"); File f3 = new File(dir, "file3.txt"); FileUtils.writeStringToFile(f1, "a,b,c", StandardCharsets.UTF_8); FileUtils.writeStringToFile(f2, "1,2,3", StandardCharsets.UTF_8); FileUtils.writeStringToFile(f3, "x,y,z", StandardCharsets.UTF_8); List<URI> uris = Arrays.asList(f1.toURI(), f2.toURI(), f3.toURI()); CSVSequenceRecordReader rr = new CSVSequenceRecordReader(); TestStreamFunction fn = new TestStreamFunction(); InputSplit is = new StreamInputSplit(uris, fn, new Random(12345)); rr.initialize(is); List<List<List<Writable>>> act = new ArrayList<>(); while (rr.hasNext()) { act.add(rr.sequenceRecord()); } rr.reset(); List<List<List<Writable>>> act2 = new ArrayList<>(); while (rr.hasNext()) { act2.add(rr.sequenceRecord()); } rr.reset(); List<List<List<Writable>>> act3 = new ArrayList<>(); while (rr.hasNext()) { act3.add(rr.sequenceRecord()); } assertEquals(3, act.size()); assertEquals(3, act2.size()); assertEquals(3, act3.size()); /* System.out.println(act); System.out.println("---------"); System.out.println(act2); System.out.println("---------"); System.out.println(act3); */ //Check not the same. With this RNG seed, results are different for first 3 resets assertNotEquals(act, act2); assertNotEquals(act2, act3); assertNotEquals(act, act3); }
Example #28
Source File: TestStreamInputSplit.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testCsvSequenceSimple() throws Exception { File dir = testDir.newFolder(); File f1 = new File(dir, "file1.txt"); File f2 = new File(dir, "file2.txt"); FileUtils.writeStringToFile(f1, "a,b,c\nd,e,f", StandardCharsets.UTF_8); FileUtils.writeStringToFile(f2, "1,2,3", StandardCharsets.UTF_8); List<URI> uris = Arrays.asList(f1.toURI(), f2.toURI()); CSVSequenceRecordReader rr = new CSVSequenceRecordReader(); TestStreamFunction fn = new TestStreamFunction(); InputSplit is = new StreamInputSplit(uris, fn); rr.initialize(is); List<List<List<Writable>>> exp = new ArrayList<>(); exp.add(Arrays.asList( Arrays.<Writable>asList(new Text("a"), new Text("b"), new Text("c")), Arrays.<Writable>asList(new Text("d"), new Text("e"), new Text("f")))); exp.add(Arrays.asList( Arrays.<Writable>asList(new Text("1"), new Text("2"), new Text("3")))); List<List<List<Writable>>> act = new ArrayList<>(); while (rr.hasNext()) { act.add(rr.sequenceRecord()); } assertEquals(exp, act); //Check that the specified stream loading function was used, not the default: assertEquals(uris, fn.calledWithUris); rr.reset(); int count = 0; while(rr.hasNext()) { count++; rr.sequenceRecord(); } assertEquals(2, count); }
Example #29
Source File: FileBatchRecordReaderTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testCsvSequence() throws Exception { //CSV sequence - 3 lines per file, 10 files File baseDir = testDir.newFolder(); List<File> fileList = new ArrayList<>(); for( int i=0; i<10; i++ ){ StringBuilder sb = new StringBuilder(); for( int j=0; j<3; j++ ){ if(j > 0) sb.append("\n"); sb.append("file_" + i + "," + i + "," + j); } File f = new File(baseDir, "origFile" + i + ".csv"); FileUtils.writeStringToFile(f, sb.toString(), StandardCharsets.UTF_8); fileList.add(f); } FileBatch fb = FileBatch.forFiles(fileList); SequenceRecordReader rr = new CSVSequenceRecordReader(); FileBatchSequenceRecordReader fbrr = new FileBatchSequenceRecordReader(rr, fb); for( int test=0; test<3; test++) { for (int i = 0; i < 10; i++) { assertTrue(fbrr.hasNext()); List<List<Writable>> next = fbrr.sequenceRecord(); assertEquals(3, next.size()); int count = 0; for(List<Writable> step : next ){ String s1 = "file_" + i; assertEquals(s1, step.get(0).toString()); assertEquals(String.valueOf(i), step.get(1).toString()); assertEquals(String.valueOf(count++), step.get(2).toString()); } } assertFalse(fbrr.hasNext()); assertTrue(fbrr.resetSupported()); fbrr.reset(); } }
Example #30
Source File: LstmTimeSeriesExample.java From Java-Deep-Learning-Cookbook with MIT License | 4 votes |
/**
 * Trains a single-layer LSTM on the Physionet mortality CSV sequences and evaluates it
 * with ROC/AUC. Numbered files 0..3199 are used for training and 3200..3999 for testing;
 * feature files have a one-line header, label files have none.
 *
 * @throws IOException          if the CSV sequence files cannot be read
 * @throws InterruptedException if record-reader initialization is interrupted
 */
public static void main(String[] args) throws IOException, InterruptedException {
    // Guard against unedited placeholder paths. Bug fix: the label placeholder was missing its
    // closing '}' ("{PATH-TO-PHYSIONET-LABELS"), so an unedited LABEL_DIR was never detected.
    if(FEATURE_DIR.equals("{PATH-TO-PHYSIONET-FEATURES}") || LABEL_DIR.equals("{PATH-TO-PHYSIONET-LABELS}")){
        System.out.println("Please provide proper directory path in place of: PATH-TO-PHYSIONET-FEATURES && PATH-TO-PHYSIONET-LABELS");
        throw new FileNotFoundException();
    }

    // Training data: mini-batch 100, 2 classes, labels aligned to the end of each sequence
    SequenceRecordReader trainFeaturesReader = new CSVSequenceRecordReader(1, ",");
    trainFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR+"/%d.csv",0,3199));
    SequenceRecordReader trainLabelsReader = new CSVSequenceRecordReader();
    trainLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR+"/%d.csv",0,3199));
    DataSetIterator trainDataSetIterator = new SequenceRecordReaderDataSetIterator(trainFeaturesReader,trainLabelsReader,100,2,false,
            SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // Held-out test data
    SequenceRecordReader testFeaturesReader = new CSVSequenceRecordReader(1, ",");
    testFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR+"/%d.csv",3200,3999));
    SequenceRecordReader testLabelsReader = new CSVSequenceRecordReader();
    testLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR+"/%d.csv",3200,3999));
    DataSetIterator testDataSetIterator = new SequenceRecordReaderDataSetIterator(testFeaturesReader,testLabelsReader,100,2,false,
            SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // One LSTM layer (86 inputs -> 200 units) feeding a 2-class softmax RNN output layer
    ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()
            .seed(RANDOM_SEED)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .weightInit(WeightInit.XAVIER)
            .updater(new Adam())
            .dropOut(0.9)
            .graphBuilder()
            .addInputs("trainFeatures")
            .setOutputs("predictMortality")
            .addLayer("L1", new LSTM.Builder()
                    .nIn(86)
                    .nOut(200)
                    .forgetGateBiasInit(1)
                    .activation(Activation.TANH)
                    .build(),"trainFeatures")
            .addLayer("predictMortality", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX)
                    .nIn(200).nOut(2).build(),"L1")
            .build();

    ComputationGraph model = new ComputationGraph(configuration);

    // Single epoch; the reset keeps the iterator reusable if the epoch count is raised
    for(int i=0;i<1;i++){
        model.fit(trainDataSetIterator);
        trainDataSetIterator.reset();
    }

    // Evaluate on held-out sequences with ROC at 100 threshold steps
    ROC evaluation = new ROC(100);
    while (testDataSetIterator.hasNext()) {
        DataSet batch = testDataSetIterator.next();
        INDArray[] output = model.output(batch.getFeatures());
        evaluation.evalTimeSeries(batch.getLabels(), output[0]);
    }
    System.out.println(evaluation.calculateAUC());
    System.out.println(evaluation.stats());
}