org.datavec.api.split.NumberedFileInputSplit Java Examples
The following examples show how to use
org.datavec.api.split.NumberedFileInputSplit.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may also check out the related API usage on the sidebar.
Example #1
Source File: SameDiffRNNTestCases.java From deeplearning4j with Apache License 2.0 | 6 votes |
protected MultiDataSetIterator getTrainingDataUnnormalized() throws Exception { int miniBatchSize = 10; int numLabelClasses = 6; File featuresDirTrain = Files.createTempDir(); File labelsDirTrain = Files.createTempDir(); Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/features/", featuresDirTrain); Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/labels/", labelsDirTrain); SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(); trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449)); SequenceRecordReader trainLabels = new CSVSequenceRecordReader(); trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449)); DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(trainData); return iter; }
Example #2
Source File: JacksonRecordReaderTest.java From DataVec with Apache License 2.0 | 6 votes |
@Test public void testReadingJson() throws Exception { //Load 3 values from 3 JSON files //stricture: a:value, b:value, c:x:value, c:y:value //And we want to load only a:value, b:value and c:x:value //For first JSON file: all values are present //For second JSON file: b:value is missing //For third JSON file: c:x:value is missing ClassPathResource cpr = new ClassPathResource("json/json_test_0.txt"); String path = cpr.getFile().getAbsolutePath().replace("0", "%d"); InputSplit is = new NumberedFileInputSplit(path, 0, 2); RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory())); rr.initialize(is); testJacksonRecordReader(rr); }
Example #3
Source File: RNNTestCases.java From deeplearning4j with Apache License 2.0 | 6 votes |
protected MultiDataSetIterator getTrainingDataUnnormalized() throws Exception { int miniBatchSize = 10; int numLabelClasses = 6; File featuresDirTrain = Files.createTempDir(); File labelsDirTrain = Files.createTempDir(); new ClassPathResource("dl4j-integration-tests/data/uci_seq/train/features/").copyDirectory(featuresDirTrain); new ClassPathResource("dl4j-integration-tests/data/uci_seq/train/labels/").copyDirectory(labelsDirTrain); SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(); trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449)); SequenceRecordReader trainLabels = new CSVSequenceRecordReader(); trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449)); DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(trainData); return iter; }
Example #4
Source File: CSVSequenceRecordReaderTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void testCsvSeqAndNumberedFileSplit() throws Exception { File baseDir = tempDir.newFolder(); //Simple sanity check unit test for (int i = 0; i < 3; i++) { new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir); } //Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator ClassPathResource resource = new ClassPathResource("csvsequence_0.txt"); String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath(); SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); while(featureReader.hasNext()){ featureReader.nextSequence(); } }
Example #5
Source File: CSVSequenceRecordReaderTest.java From DataVec with Apache License 2.0 | 6 votes |
@Test public void testCsvSeqAndNumberedFileSplit() throws Exception { File baseDir = tempDir.newFolder(); //Simple sanity check unit test for (int i = 0; i < 3; i++) { new org.nd4j.linalg.io.ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir); } //Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator org.nd4j.linalg.io.ClassPathResource resource = new org.nd4j.linalg.io.ClassPathResource("csvsequence_0.txt"); String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath(); SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); while(featureReader.hasNext()){ featureReader.nextSequence(); } }
Example #6
Source File: JacksonRecordReaderTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void testReadingJson() throws Exception { //Load 3 values from 3 JSON files //stricture: a:value, b:value, c:x:value, c:y:value //And we want to load only a:value, b:value and c:x:value //For first JSON file: all values are present //For second JSON file: b:value is missing //For third JSON file: c:x:value is missing ClassPathResource cpr = new ClassPathResource("datavec-api/json/"); File f = testDir.newFolder(); cpr.copyDirectory(f); String path = new File(f, "json_test_%d.txt").getAbsolutePath(); InputSplit is = new NumberedFileInputSplit(path, 0, 2); RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory())); rr.initialize(is); testJacksonRecordReader(rr); }
Example #7
Source File: JacksonRecordReaderTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void testReadingYaml() throws Exception { //Exact same information as JSON format, but in YAML format ClassPathResource cpr = new ClassPathResource("datavec-api/yaml/"); File f = testDir.newFolder(); cpr.copyDirectory(f); String path = new File(f, "yaml_test_%d.txt").getAbsolutePath(); InputSplit is = new NumberedFileInputSplit(path, 0, 2); RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new YAMLFactory())); rr.initialize(is); testJacksonRecordReader(rr); }
Example #8
Source File: JacksonRecordReaderTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void testReadingXml() throws Exception { //Exact same information as JSON format, but in XML format ClassPathResource cpr = new ClassPathResource("datavec-api/xml/"); File f = testDir.newFolder(); cpr.copyDirectory(f); String path = new File(f, "xml_test_%d.txt").getAbsolutePath(); InputSplit is = new NumberedFileInputSplit(path, 0, 2); RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new XmlFactory())); rr.initialize(is); testJacksonRecordReader(rr); }
Example #9
Source File: JacksonRecordReaderTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testAppendingLabelsMetaData() throws Exception { ClassPathResource cpr = new ClassPathResource("datavec-api/json/"); File f = testDir.newFolder(); cpr.copyDirectory(f); String path = new File(f, "json_test_%d.txt").getAbsolutePath(); InputSplit is = new NumberedFileInputSplit(path, 0, 2); //Insert at the end: RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen()); rr.initialize(is); List<List<Writable>> out = new ArrayList<>(); while (rr.hasNext()) { out.add(rr.next()); } assertEquals(3, out.size()); rr.reset(); List<List<Writable>> out2 = new ArrayList<>(); List<Record> outRecord = new ArrayList<>(); List<RecordMetaData> meta = new ArrayList<>(); while (rr.hasNext()) { Record r = rr.nextRecord(); out2.add(r.getRecord()); outRecord.add(r); meta.add(r.getMetaData()); } assertEquals(out, out2); List<Record> fromMeta = rr.loadFromMetaData(meta); assertEquals(outRecord, fromMeta); }
Example #10
Source File: SameDiffRNNTestCases.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public MultiDataSetIterator getEvaluationTestData() throws Exception { int miniBatchSize = 10; int numLabelClasses = 6; // File featuresDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/features/").getFile(); // File labelsDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/labels/").getFile(); File featuresDirTest = Files.createTempDir(); File labelsDirTest = Files.createTempDir(); Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/features/", featuresDirTest); Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/labels/", labelsDirTest); SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(); trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149)); SequenceRecordReader trainLabels = new CSVSequenceRecordReader(); trainLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149)); DataSetIterator testData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(testData); MultiDataSetPreProcessor pp = multiDataSet -> { INDArray l = multiDataSet.getLabels(0); l = l.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(l.size(2) - 1)); multiDataSet.setLabels(0, l); multiDataSet.setLabelsMaskArray(0, null); }; iter.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), pp)); return iter; }
Example #11
Source File: RNNTestCases.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public MultiDataSetIterator getEvaluationTestData() throws Exception { int miniBatchSize = 10; int numLabelClasses = 6; // File featuresDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/features/").getFile(); // File labelsDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/labels/").getFile(); File featuresDirTest = Files.createTempDir(); File labelsDirTest = Files.createTempDir(); new ClassPathResource("dl4j-integration-tests/data/uci_seq/test/features/").copyDirectory(featuresDirTest); new ClassPathResource("dl4j-integration-tests/data/uci_seq/test/labels/").copyDirectory(labelsDirTest); SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(); trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149)); SequenceRecordReader trainLabels = new CSVSequenceRecordReader(); trainLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149)); DataSetIterator testData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(testData); MultiDataSetPreProcessor pp = multiDataSet -> { INDArray l = multiDataSet.getLabels(0); l = l.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(l.size(2)-1)); multiDataSet.setLabels(0, l); multiDataSet.setLabelsMaskArray(0, null); }; iter.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(),pp)); return iter; }
Example #12
Source File: RecordReaderMultiDataSetIteratorTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testSplittingCSVSequenceMeta() throws Exception { //Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep "csvSequencelables_i.txt" // as standard one-hot output //need to manually extract File rootDir = temporaryFolder.newFolder(); for (int i = 0; i < 3; i++) { new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir); } String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt"); SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ","); featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); RecordReaderMultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1) .addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2) .addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build(); srrmdsi.setCollectMetaData(true); int count = 0; while (srrmdsi.hasNext()) { MultiDataSet mds = srrmdsi.next(); MultiDataSet fromMeta = srrmdsi.loadFromMetaData(mds.getExampleMetaData(RecordMetaData.class)); assertEquals(mds, fromMeta); count++; } assertEquals(3, count); }
Example #13
Source File: RecordReaderDataSetiteratorTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testSequenceRecordReaderReset() throws Exception { File rootDir = temporaryFolder.newFolder(); //need to manually extract for (int i = 0; i < 3; i++) { FileUtils.copyFile(Resources.asFile(String.format("csvsequence_%d.txt", i)), new File(rootDir, String.format("csvsequence_%d.txt", i))); FileUtils.copyFile(Resources.asFile(String.format("csvsequencelabels_%d.txt", i)), new File(rootDir, String.format("csvsequencelabels_%d.txt", i))); } String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt"); SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false); assertEquals(3, iter.inputColumns()); assertEquals(4, iter.totalOutcomes()); int nResets = 5; for (int i = 0; i < nResets; i++) { iter.reset(); int count = 0; while (iter.hasNext()) { DataSet ds = iter.next(); INDArray features = ds.getFeatures(); INDArray labels = ds.getLabels(); assertArrayEquals(new long[] {1, 3, 4}, features.shape()); assertArrayEquals(new long[] {1, 4, 4}, labels.shape()); count++; } assertEquals(3, count); } }
Example #14
Source File: RecordReaderDataSetiteratorTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testSequenceRecordReaderMeta() throws Exception { File rootDir = temporaryFolder.newFolder(); //need to manually extract for (int i = 0; i < 3; i++) { FileUtils.copyFile(Resources.asFile(String.format("csvsequence_%d.txt", i)), new File(rootDir, String.format("csvsequence_%d.txt", i))); FileUtils.copyFile(Resources.asFile(String.format("csvsequencelabels_%d.txt", i)), new File(rootDir, String.format("csvsequencelabels_%d.txt", i))); } String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt"); SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false); iter.setCollectMetaData(true); assertEquals(3, iter.inputColumns()); assertEquals(4, iter.totalOutcomes()); while (iter.hasNext()) { DataSet ds = iter.next(); List<RecordMetaData> meta = ds.getExampleMetaData(RecordMetaData.class); DataSet fromMeta = iter.loadFromMetaData(meta); assertEquals(ds, fromMeta); } }
Example #15
Source File: RegexRecordReaderTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testRegexSequenceRecordReaderMeta() throws Exception { String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)"; ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/"); File f = testDir.newFolder(); cpr.copyDirectory(f); String path = new File(f, "logtestfile%d.txt").getAbsolutePath(); InputSplit is = new NumberedFileInputSplit(path, 0, 1); SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1); rr.initialize(is); List<List<List<Writable>>> out = new ArrayList<>(); while (rr.hasNext()) { out.add(rr.sequenceRecord()); } assertEquals(2, out.size()); List<List<List<Writable>>> out2 = new ArrayList<>(); List<SequenceRecord> out3 = new ArrayList<>(); List<RecordMetaData> meta = new ArrayList<>(); rr.reset(); while (rr.hasNext()) { SequenceRecord seqr = rr.nextSequence(); out2.add(seqr.getSequenceRecord()); out3.add(seqr); meta.add(seqr.getMetaData()); } List<SequenceRecord> fromMeta = rr.loadSequenceFromMetaData(meta); assertEquals(out, out2); assertEquals(out3, fromMeta); }
Example #16
Source File: RegexRecordReaderTest.java From DataVec with Apache License 2.0 | 5 votes |
@Test public void testRegexSequenceRecordReaderMeta() throws Exception { String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)"; String path = new ClassPathResource("/logtestdata/logtestfile0.txt").getFile().toURI().toString(); path = path.replace("0", "%d"); InputSplit is = new NumberedFileInputSplit(path, 0, 1); SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1); rr.initialize(is); List<List<List<Writable>>> out = new ArrayList<>(); while (rr.hasNext()) { out.add(rr.sequenceRecord()); } assertEquals(2, out.size()); List<List<List<Writable>>> out2 = new ArrayList<>(); List<SequenceRecord> out3 = new ArrayList<>(); List<RecordMetaData> meta = new ArrayList<>(); rr.reset(); while (rr.hasNext()) { SequenceRecord seqr = rr.nextSequence(); out2.add(seqr.getSequenceRecord()); out3.add(seqr); meta.add(seqr.getMetaData()); } List<SequenceRecord> fromMeta = rr.loadSequenceFromMetaData(meta); assertEquals(out, out2); assertEquals(out3, fromMeta); }
Example #17
Source File: RegexRecordReaderTest.java From DataVec with Apache License 2.0 | 5 votes |
@Test public void testRegexSequenceRecordReader() throws Exception { String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)"; String path = new ClassPathResource("/logtestdata/logtestfile0.txt").getFile().toURI().toString(); path = path.replace("0", "%d"); InputSplit is = new NumberedFileInputSplit(path, 0, 1); SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1); rr.initialize(is); List<List<Writable>> exp0 = new ArrayList<>(); exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), new Text("First entry message!"))); exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), new Text("Second entry message!"))); exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), new Text("Third entry message!"))); List<List<Writable>> exp1 = new ArrayList<>(); exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"), new Text("First entry message!"))); exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"), new Text("Second entry message!"))); exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"), new Text("Third entry message!"))); assertEquals(exp0, rr.sequenceRecord()); assertEquals(exp1, rr.sequenceRecord()); assertFalse(rr.hasNext()); //Test resetting: rr.reset(); assertEquals(exp0, rr.sequenceRecord()); assertEquals(exp1, rr.sequenceRecord()); assertFalse(rr.hasNext()); }
Example #18
Source File: JacksonRecordReaderTest.java From DataVec with Apache License 2.0 | 5 votes |
@Test public void testAppendingLabelsMetaData() throws Exception { ClassPathResource cpr = new ClassPathResource("json/json_test_0.txt"); String path = cpr.getFile().getAbsolutePath().replace("0", "%d"); InputSplit is = new NumberedFileInputSplit(path, 0, 2); //Insert at the end: RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen()); rr.initialize(is); List<List<Writable>> out = new ArrayList<>(); while (rr.hasNext()) { out.add(rr.next()); } assertEquals(3, out.size()); rr.reset(); List<List<Writable>> out2 = new ArrayList<>(); List<Record> outRecord = new ArrayList<>(); List<RecordMetaData> meta = new ArrayList<>(); while (rr.hasNext()) { Record r = rr.nextRecord(); out2.add(r.getRecord()); outRecord.add(r); meta.add(r.getMetaData()); } assertEquals(out, out2); List<Record> fromMeta = rr.loadFromMetaData(meta); assertEquals(outRecord, fromMeta); }
Example #19
Source File: JacksonRecordReaderTest.java From DataVec with Apache License 2.0 | 5 votes |
@Test public void testReadingYaml() throws Exception { //Exact same information as JSON format, but in YAML format ClassPathResource cpr = new ClassPathResource("yaml/yaml_test_0.txt"); String path = cpr.getFile().getAbsolutePath().replace("0", "%d"); InputSplit is = new NumberedFileInputSplit(path, 0, 2); RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new YAMLFactory())); rr.initialize(is); testJacksonRecordReader(rr); }
Example #20
Source File: JacksonRecordReaderTest.java From DataVec with Apache License 2.0 | 5 votes |
@Test public void testReadingXml() throws Exception { //Exact same information as JSON format, but in XML format ClassPathResource cpr = new ClassPathResource("xml/xml_test_0.txt"); String path = cpr.getFile().getAbsolutePath().replace("0", "%d"); InputSplit is = new NumberedFileInputSplit(path, 0, 2); RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new XmlFactory())); rr.initialize(is); testJacksonRecordReader(rr); }
Example #21
Source File: RecordReaderMultiDataSetIteratorTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs
 * (columns 0-1 and column 2); keep "csvsequencelabels_i.txt" as standard one-hot output.
 * Compares RecordReaderMultiDataSetIterator output against the equivalent
 * SequenceRecordReaderDataSetIterator output, batch by batch.
 */
@Test public void testSplittingCSVSequence() throws Exception {
    //need to manually extract the numbered files into a real folder
    File rootDir = temporaryFolder.newFolder();
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
    }
    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");

    //Reference iterator: plain SequenceRecordReaderDataSetIterator over the same files
    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReaderDataSetIterator iter =
                    new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);

    //Iterator under test: same files, but features split into two inputs (cols 0-1 and col 2)
    SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
    featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2)
                    .addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build();

    while (iter.hasNext()) {
        DataSet ds = iter.next();
        INDArray fds = ds.getFeatures();
        INDArray lds = ds.getLabels();

        MultiDataSet mds = srrmdsi.next();
        //Two feature inputs, one label output, and no masks (all sequences same length)
        assertEquals(2, mds.getFeatures().length);
        assertEquals(1, mds.getLabels().length);
        assertNull(mds.getFeaturesMaskArrays());
        assertNull(mds.getLabelsMaskArrays());
        INDArray[] fmds = mds.getFeatures();
        INDArray[] lmds = mds.getLabels();
        assertNotNull(fmds);
        assertNotNull(lmds);
        for (int i = 0; i < fmds.length; i++)
            assertNotNull(fmds[i]);
        for (int i = 0; i < lmds.length; i++)
            assertNotNull(lmds[i]);

        //Expected inputs: columns [0,1] and column [2] of the reference features
        //(interval(..., true) is inclusive of the upper bound)
        INDArray expIn1 = fds.get(all(), NDArrayIndex.interval(0, 1, true), all());
        INDArray expIn2 = fds.get(all(), NDArrayIndex.interval(2, 2, true), all());

        assertEquals(expIn1, fmds[0]);
        assertEquals(expIn2, fmds[1]);
        assertEquals(lds, lmds[0]);
    }
    //Both iterators must be exhausted together
    assertFalse(srrmdsi.hasNext());
}
Example #22
Source File: LstmTimeSeriesExample.java From Java-Deep-Learning-Cookbook with MIT License | 4 votes |
/**
 * Trains and evaluates an LSTM mortality-prediction model on the PhysioNet
 * time-series data (files 0..3199 for training, 3200..3999 for test).
 *
 * @throws FileNotFoundException if the FEATURE_DIR/LABEL_DIR placeholders were not replaced
 */
public static void main(String[] args) throws IOException, InterruptedException {
    // BUG FIX: the label placeholder was "{PATH-TO-PHYSIONET-LABELS" (missing the
    // closing '}'), so the guard could never match the actual placeholder value.
    if (FEATURE_DIR.equals("{PATH-TO-PHYSIONET-FEATURES}") || LABEL_DIR.equals("{PATH-TO-PHYSIONET-LABELS}")) {
        System.out.println("Please provide proper directory path in place of: PATH-TO-PHYSIONET-FEATURES && PATH-TO-PHYSIONET-LABELS");
        throw new FileNotFoundException();
    }

    // Training data: numbered CSV sequence files 0..3199.
    SequenceRecordReader trainFeaturesReader = new CSVSequenceRecordReader(1, ",");
    trainFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR + "/%d.csv", 0, 3199));
    SequenceRecordReader trainLabelsReader = new CSVSequenceRecordReader();
    trainLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR + "/%d.csv", 0, 3199));
    DataSetIterator trainDataSetIterator = new SequenceRecordReaderDataSetIterator(trainFeaturesReader,
            trainLabelsReader, 100, 2, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // Test data: numbered CSV sequence files 3200..3999.
    SequenceRecordReader testFeaturesReader = new CSVSequenceRecordReader(1, ",");
    testFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR + "/%d.csv", 3200, 3999));
    SequenceRecordReader testLabelsReader = new CSVSequenceRecordReader();
    testLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR + "/%d.csv", 3200, 3999));
    DataSetIterator testDataSetIterator = new SequenceRecordReaderDataSetIterator(testFeaturesReader,
            testLabelsReader, 100, 2, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // LSTM (86 inputs -> 200 units) followed by a 2-class softmax RNN output layer.
    ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()
            .seed(RANDOM_SEED)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .weightInit(WeightInit.XAVIER)
            .updater(new Adam())
            .dropOut(0.9)
            .graphBuilder()
            .addInputs("trainFeatures")
            .setOutputs("predictMortality")
            .addLayer("L1", new LSTM.Builder()
                    .nIn(86)
                    .nOut(200)
                    .forgetGateBiasInit(1)
                    .activation(Activation.TANH)
                    .build(), "trainFeatures")
            .addLayer("predictMortality", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX)
                    .nIn(200).nOut(2).build(), "L1")
            .build();

    ComputationGraph model = new ComputationGraph(configuration);
    // Single epoch of training (reset left in place for easy multi-epoch runs).
    for (int i = 0; i < 1; i++) {
        model.fit(trainDataSetIterator);
        trainDataSetIterator.reset();
    }

    // Evaluate with ROC (100 threshold steps) over the held-out test set.
    ROC evaluation = new ROC(100);
    while (testDataSetIterator.hasNext()) {
        DataSet batch = testDataSetIterator.next();
        INDArray[] output = model.output(batch.getFeatures());
        evaluation.evalTimeSeries(batch.getLabels(), output[0]);
    }
    System.out.println(evaluation.calculateAUC());
    System.out.println(evaluation.stats());
}
Example #23
Source File: NumberedFileInputSplitExample.java From Java-Deep-Learning-Cookbook with MIT License | 4 votes |
public static void main(String[] args) { NumberedFileInputSplit numberedFileInputSplit = new NumberedFileInputSplit("numberedfiles/file%d.txt",1,4); numberedFileInputSplit.locationsIterator().forEachRemaining(System.out::println); }
Example #24
Source File: LstmTimeSeriesExample.java From Java-Deep-Learning-Cookbook with MIT License | 4 votes |
/**
 * Trains and evaluates an LSTM mortality-prediction model on the PhysioNet
 * time-series data (files 0..3199 for training, 3200..3999 for test).
 *
 * @throws FileNotFoundException if the FEATURE_DIR/LABEL_DIR placeholders were not replaced
 */
public static void main(String[] args) throws IOException, InterruptedException {
    // BUG FIX: the label placeholder was "{PATH-TO-PHYSIONET-LABELS" (missing the
    // closing '}'), so the guard could never match the actual placeholder value.
    if (FEATURE_DIR.equals("{PATH-TO-PHYSIONET-FEATURES}") || LABEL_DIR.equals("{PATH-TO-PHYSIONET-LABELS}")) {
        System.out.println("Please provide proper directory path in place of: PATH-TO-PHYSIONET-FEATURES && PATH-TO-PHYSIONET-LABELS");
        throw new FileNotFoundException();
    }

    // Training data: numbered CSV sequence files 0..3199.
    SequenceRecordReader trainFeaturesReader = new CSVSequenceRecordReader(1, ",");
    trainFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR + "/%d.csv", 0, 3199));
    SequenceRecordReader trainLabelsReader = new CSVSequenceRecordReader();
    trainLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR + "/%d.csv", 0, 3199));
    DataSetIterator trainDataSetIterator = new SequenceRecordReaderDataSetIterator(trainFeaturesReader,
            trainLabelsReader, 100, 2, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // Test data: numbered CSV sequence files 3200..3999.
    SequenceRecordReader testFeaturesReader = new CSVSequenceRecordReader(1, ",");
    testFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR + "/%d.csv", 3200, 3999));
    SequenceRecordReader testLabelsReader = new CSVSequenceRecordReader();
    testLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR + "/%d.csv", 3200, 3999));
    DataSetIterator testDataSetIterator = new SequenceRecordReaderDataSetIterator(testFeaturesReader,
            testLabelsReader, 100, 2, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // LSTM (86 inputs -> 200 units) followed by a 2-class softmax RNN output layer.
    ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()
            .seed(RANDOM_SEED)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .weightInit(WeightInit.XAVIER)
            .updater(new Adam())
            .dropOut(0.9)
            .graphBuilder()
            .addInputs("trainFeatures")
            .setOutputs("predictMortality")
            .addLayer("L1", new LSTM.Builder()
                    .nIn(86)
                    .nOut(200)
                    .forgetGateBiasInit(1)
                    .activation(Activation.TANH)
                    .build(), "trainFeatures")
            .addLayer("predictMortality", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX)
                    .nIn(200).nOut(2).build(), "L1")
            .build();

    ComputationGraph model = new ComputationGraph(configuration);
    // Single epoch of training (reset left in place for easy multi-epoch runs).
    for (int i = 0; i < 1; i++) {
        model.fit(trainDataSetIterator);
        trainDataSetIterator.reset();
    }

    // Evaluate with ROC (100 threshold steps) over the held-out test set.
    ROC evaluation = new ROC(100);
    while (testDataSetIterator.hasNext()) {
        DataSet batch = testDataSetIterator.next();
        INDArray[] output = model.output(batch.getFeatures());
        evaluation.evalTimeSeries(batch.getLabels(), output[0]);
    }
    System.out.println(evaluation.calculateAUC());
    System.out.println(evaluation.stats());
}
Example #25
Source File: TestDataVecDataSetFunctions.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Spark-vs-local equivalence test for {@code DataVecSequencePairDataSetFunction}:
 * pairs of CSV sequence files are combined into a Hadoop sequence file, read back and
 * mapped to DataSets on Spark, then the same DataSets are built locally via
 * {@code SequenceRecordReaderDataSetIterator}. Both paths must yield the same 3 examples
 * (element order may differ between the Spark and local lists).
 */
@Test
public void testDataVecSequencePairDataSetFunction() throws Exception {
    JavaSparkContext sc = getContext();

    // Copy the bundled CSV sequence resources into a fresh temp folder.
    File f = testDir.newFolder();
    ClassPathResource cpr = new ClassPathResource("dl4j-spark/csvsequence/");
    cpr.copyDirectory(f);
    String path = f.getAbsolutePath() + "/*";

    // Combine each file with itself (same path twice) keyed by filename, then persist
    // the pairs as a Hadoop sequence file.
    PathToKeyConverter pathConverter = new PathToKeyConverterFilename();
    JavaPairRDD<Text, BytesPairWritable> toWrite = DataVecSparkUtil.combineFilesForSequenceFile(sc, path, path, pathConverter);

    Path p = testDir.newFolder("dl4j_testSeqPairFn").toPath();
    p.toFile().deleteOnExit();
    String outPath = p.toString() + "/out";
    new File(outPath).deleteOnExit();
    toWrite.saveAsNewAPIHadoopFile(outPath, Text.class, BytesPairWritable.class, SequenceFileOutputFormat.class);

    //Load from sequence file:
    JavaPairRDD<Text, BytesPairWritable> fromSeq = sc.sequenceFile(outPath, Text.class, BytesPairWritable.class);

    // Deserialize each byte pair back into (features, labels) writables.
    SequenceRecordReader srr1 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader srr2 = new CSVSequenceRecordReader(1, ",");
    PairSequenceRecordReaderBytesFunction psrbf = new PairSequenceRecordReaderBytesFunction(srr1, srr2);

    JavaRDD<Tuple2<List<List<Writable>>, List<List<Writable>>>> writables = fromSeq.map(psrbf);

    //Map to DataSet:
    DataVecSequencePairDataSetFunction pairFn = new DataVecSequencePairDataSetFunction();
    JavaRDD<DataSet> data = writables.map(pairFn);
    List<DataSet> sparkData = data.collect();

    //Now: do the same thing locally (SequenceRecordReaderDataSetIterator) and compare
    String featuresPath = FilenameUtils.concat(f.getAbsolutePath(), "csvsequence_%d.txt");

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    // Intentionally the same files for labels as features (regression=true below),
    // mirroring the "path, path" pairing on the Spark side.
    labelReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));

    SequenceRecordReaderDataSetIterator iter =
            new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, -1, true);

    List<DataSet> localData = new ArrayList<>(3);
    while (iter.hasNext())
        localData.add(iter.next());

    assertEquals(3, sparkData.size());
    assertEquals(3, localData.size());

    for (int i = 0; i < 3; i++) {
        //Check shapes etc. data sets order may differ for spark vs. local
        DataSet dsSpark = sparkData.get(i);
        DataSet dsLocal = localData.get(i);

        // Equal-length sequences -> no masking expected on the Spark side.
        assertNull(dsSpark.getFeaturesMaskArray());
        assertNull(dsSpark.getLabelsMaskArray());

        INDArray fSpark = dsSpark.getFeatures();
        INDArray fLocal = dsLocal.getFeatures();
        INDArray lSpark = dsSpark.getLabels();
        INDArray lLocal = dsLocal.getLabels();

        val s = new long[] {1, 3, 4}; //1 example, 3 values, 4 time steps (original comment said "3 time steps", but the shape literal is clearly 4)
        assertArrayEquals(s, fSpark.shape());
        assertArrayEquals(s, fLocal.shape());
        assertArrayEquals(s, lSpark.shape());
        assertArrayEquals(s, lLocal.shape());
    }

    //Check that results are the same (order notwithstanding)
    boolean[] found = new boolean[3];
    for (int i = 0; i < 3; i++) {
        int foundIndex = -1;
        DataSet ds = sparkData.get(i);
        for (int j = 0; j < 3; j++) {
            if (ds.equals(localData.get(j))) {
                if (foundIndex != -1)
                    fail(); //Already found this value -> suggests this spark value equals two or more of local version? (Shouldn't happen)
                foundIndex = j;
                if (found[foundIndex])
                    fail(); //One of the other spark values was equal to this one -> suggests duplicates in Spark list
                found[foundIndex] = true; //mark this one as seen before
            }
        }
    }
    int count = 0;
    for (boolean b : found)
        if (b)
            count++;
    assertEquals(3, count); //Expect all 3 and exactly 3 pairwise matches between spark and local versions
}
Example #26
Source File: RecordReaderMultiDataSetIteratorTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Verifies that {@code RecordReaderMultiDataSetIterator} metadata round-trips correctly
 * for variable-length time series: for both ALIGN_START and ALIGN_END alignment modes,
 * re-loading each MultiDataSet from its collected {@code RecordMetaData} must reproduce
 * the original MultiDataSet exactly. Expects 3 examples (files 0..2).
 */
@Test
public void testVariableLengthTSMeta() throws Exception {
    //need to manually extract
    File rootDir = temporaryFolder.newFolder();
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
    }
    //Set up SequenceRecordReaderDataSetIterators for comparison
    // "Short" labels are shorter than the feature sequences -> alignment mode matters.
    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabelsShort_%d.txt");

    //Set up
    SequenceRecordReader featureReader3 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader3 = new CSVSequenceRecordReader(1, ",");
    featureReader3.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader3.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    // Separate reader instances for the second iterator: readers are stateful and
    // cannot be shared between iterators.
    SequenceRecordReader featureReader4 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader4 = new CSVSequenceRecordReader(1, ",");
    featureReader4.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader4.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1)
            .addSequenceReader("in", featureReader3).addSequenceReader("out", labelReader3).addInput("in")
            .addOutputOneHot("out", 0, 4)
            .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START).build();

    RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1)
            .addSequenceReader("in", featureReader4).addSequenceReader("out", labelReader4).addInput("in")
            .addOutputOneHot("out", 0, 4)
            .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END).build();

    // Enable metadata collection so loadFromMetaData can be exercised below.
    rrmdsiStart.setCollectMetaData(true);
    rrmdsiEnd.setCollectMetaData(true);

    int count = 0;
    while (rrmdsiStart.hasNext()) {
        MultiDataSet mdsStart = rrmdsiStart.next();
        MultiDataSet mdsEnd = rrmdsiEnd.next();

        // Round-trip: reload each example from its own metadata and compare.
        MultiDataSet mdsStartFromMeta =
                rrmdsiStart.loadFromMetaData(mdsStart.getExampleMetaData(RecordMetaData.class));
        MultiDataSet mdsEndFromMeta =
                rrmdsiEnd.loadFromMetaData(mdsEnd.getExampleMetaData(RecordMetaData.class));

        assertEquals(mdsStart, mdsStartFromMeta);
        assertEquals(mdsEnd, mdsEndFromMeta);

        count++;
    }
    assertFalse(rrmdsiStart.hasNext());
    assertFalse(rrmdsiEnd.hasNext());
    assertEquals(3, count); // one example per file 0..2
}
Example #27
Source File: RecordReaderMultiDataSetIteratorTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Compares {@code RecordReaderMultiDataSetIterator} against the reference
 * {@code SequenceRecordReaderDataSetIterator} on variable-length time series
 * (features longer than labels), in lock-step for both ALIGN_START and ALIGN_END:
 * features, labels and labels-mask must match between the two iterator types.
 */
@Test
public void testVariableLengthTS() throws Exception {
    //need to manually extract
    File rootDir = temporaryFolder.newFolder();
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
    }
    // "Short" labels are shorter than the feature sequences -> masking + alignment matter.
    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabelsShort_%d.txt");

    //Set up SequenceRecordReaderDataSetIterators for comparison
    // Each iterator needs its own reader instances (readers are stateful).
    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
    featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReaderDataSetIterator iterAlignStart = new SequenceRecordReaderDataSetIterator(featureReader,
            labelReader, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_START);

    SequenceRecordReaderDataSetIterator iterAlignEnd = new SequenceRecordReaderDataSetIterator(featureReader2,
            labelReader2, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    //Set up
    SequenceRecordReader featureReader3 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader3 = new CSVSequenceRecordReader(1, ",");
    featureReader3.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader3.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReader featureReader4 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader4 = new CSVSequenceRecordReader(1, ",");
    featureReader4.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader4.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1)
            .addSequenceReader("in", featureReader3).addSequenceReader("out", labelReader3).addInput("in")
            .addOutputOneHot("out", 0, 4)
            .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START).build();

    RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1)
            .addSequenceReader("in", featureReader4).addSequenceReader("out", labelReader4).addInput("in")
            .addOutputOneHot("out", 0, 4)
            .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END).build();

    // Advance all four iterators in lock-step and compare per example.
    while (iterAlignStart.hasNext()) {
        DataSet dsStart = iterAlignStart.next();
        DataSet dsEnd = iterAlignEnd.next();
        MultiDataSet mdsStart = rrmdsiStart.next();
        MultiDataSet mdsEnd = rrmdsiEnd.next();

        assertEquals(1, mdsStart.getFeatures().length);
        assertEquals(1, mdsStart.getLabels().length);
        //assertEquals(1, mdsStart.getFeaturesMaskArrays().length); //Features data is always longer -> don't need mask arrays for it
        assertEquals(1, mdsStart.getLabelsMaskArrays().length);

        assertEquals(1, mdsEnd.getFeatures().length);
        assertEquals(1, mdsEnd.getLabels().length);
        //assertEquals(1, mdsEnd.getFeaturesMaskArrays().length);
        assertEquals(1, mdsEnd.getLabelsMaskArrays().length);

        // The multi-dataset iterator output must match the reference iterator output.
        assertEquals(dsStart.getFeatures(), mdsStart.getFeatures(0));
        assertEquals(dsStart.getLabels(), mdsStart.getLabels(0));
        assertEquals(dsStart.getLabelsMaskArray(), mdsStart.getLabelsMaskArray(0));

        assertEquals(dsEnd.getFeatures(), mdsEnd.getFeatures(0));
        assertEquals(dsEnd.getLabels(), mdsEnd.getLabels(0));
        assertEquals(dsEnd.getLabelsMaskArray(), mdsEnd.getLabelsMaskArray(0));
    }
    assertFalse(rrmdsiStart.hasNext());
    assertFalse(rrmdsiEnd.hasNext());
}
Example #28
Source File: JacksonRecordReaderTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testAppendingLabels() throws Exception { ClassPathResource cpr = new ClassPathResource("datavec-api/json/"); File f = testDir.newFolder(); cpr.copyDirectory(f); String path = new File(f, "json_test_%d.txt").getAbsolutePath(); InputSplit is = new NumberedFileInputSplit(path, 0, 2); //Insert at the end: RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen()); rr.initialize(is); List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"), new IntWritable(0)); assertEquals(exp0, rr.next()); List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"), new IntWritable(1)); assertEquals(exp1, rr.next()); List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"), new IntWritable(2)); assertEquals(exp2, rr.next()); //Insert at position 0: rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen(), 0); rr.initialize(is); exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"), new Text("cxValue0")); assertEquals(exp0, rr.next()); exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1")); assertEquals(exp1, rr.next()); exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX")); assertEquals(exp2, rr.next()); }
Example #29
Source File: JacksonRecordReaderTest.java From DataVec with Apache License 2.0 | 4 votes |
@Test public void testAppendingLabels() throws Exception { ClassPathResource cpr = new ClassPathResource("json/json_test_0.txt"); String path = cpr.getFile().getAbsolutePath().replace("0", "%d"); InputSplit is = new NumberedFileInputSplit(path, 0, 2); //Insert at the end: RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen()); rr.initialize(is); List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"), new IntWritable(0)); assertEquals(exp0, rr.next()); List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"), new IntWritable(1)); assertEquals(exp1, rr.next()); List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"), new IntWritable(2)); assertEquals(exp2, rr.next()); //Insert at position 0: rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen(), 0); rr.initialize(is); exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"), new Text("cxValue0")); assertEquals(exp0, rr.next()); exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1")); assertEquals(exp1, rr.next()); exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX")); assertEquals(exp2, rr.next()); }
Example #30
Source File: RecordReaderDataSetiteratorTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Regression-mode test of {@code SequenceRecordReaderDataSetIterator} with a single
 * reader split column-wise: column 0 becomes the feature, columns 1-2 the regression
 * targets. Verifies shapes, exact feature/label values per file, and reset behavior.
 */
@Test
public void testSequenceRecordReaderMultiRegression() throws Exception {
    File rootDir = temporaryFolder.newFolder();
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        FileUtils.copyFile(Resources.asFile(String.format("csvsequence_%d.txt", i)),
                new File(rootDir, String.format("csvsequence_%d.txt", i)));
    }

    String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");

    SequenceRecordReader reader = new CSVSequenceRecordReader(1, ",");
    reader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));

    // Regression mode: 1 example per minibatch; labels span 2 columns starting at index 1.
    SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(reader, 1, 2, 1, true);

    assertEquals(1, iter.inputColumns());
    assertEquals(2, iter.totalOutcomes());

    List<DataSet> dsList = new ArrayList<>();
    while (iter.hasNext()) {
        dsList.add(iter.next());
    }

    assertEquals(3, dsList.size()); //3 files
    for (int i = 0; i < 3; i++) {
        DataSet ds = dsList.get(i);
        INDArray features = ds.getFeatures();
        INDArray labels = ds.getLabels();
        assertArrayEquals(new long[] {1, 1, 4}, features.shape()); //1 examples, 1 values, 4 time steps
        assertArrayEquals(new long[] {1, 2, 4}, labels.shape());

        // Drop the minibatch dimension and transpose to [timeSteps, columns] for comparison.
        INDArray f2d = features.get(point(0), all(), all()).transpose();
        INDArray l2d = labels.get(point(0), all(), all()).transpose();

        // Expected values encode the file index in the hundreds digit (file 1 -> 1xx, etc.).
        switch (i){
            case 0:
                assertEquals(Nd4j.create(new double[]{0,10,20,30}, new int[]{4,1}).castTo(DataType.FLOAT), f2d);
                assertEquals(Nd4j.create(new double[][]{{1,2}, {11,12}, {21,22}, {31,32}}).castTo(DataType.FLOAT), l2d);
                break;
            case 1:
                assertEquals(Nd4j.create(new double[]{100,110,120,130}, new int[]{4,1}).castTo(DataType.FLOAT), f2d);
                assertEquals(Nd4j.create(new double[][]{{101,102}, {111,112}, {121,122}, {131,132}}).castTo(DataType.FLOAT), l2d);
                break;
            case 2:
                assertEquals(Nd4j.create(new double[]{200,210,220,230}, new int[]{4,1}).castTo(DataType.FLOAT), f2d);
                assertEquals(Nd4j.create(new double[][]{{201,202}, {211,212}, {221,222}, {231,232}}).castTo(DataType.FLOAT), l2d);
                break;
            default:
                throw new RuntimeException(); // unreachable: loop covers exactly 3 files
        }
    }

    // Reset and re-consume: iterator must yield the same number of examples again.
    iter.reset();
    int count = 0;
    while (iter.hasNext()) {
        iter.next();
        count++;
    }
    assertEquals(3, count);
}