org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator Java Examples
The following examples show how to use
org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: ImageInstanceIterator.java From wekaDeeplearning4j with GNU General Public License v3.0 | 7 votes |
/** * This method returns the iterator. Scales all intensity values: it divides them by 255. * * @param data the dataset to use * @param seed the seed for the random number generator * @param batchSize the batch size to use * @return the iterator */ @Override public DataSetIterator getDataSetIterator(Instances data, int seed, int batchSize) throws Exception { batchSize = Math.min(data.numInstances(), batchSize); validate(data); ImageRecordReader reader = getImageRecordReader(data); // Required for supporting channels-last models (currently only EfficientNet) if (getChannelsLast()) reader.setNchw_channels_first(false); final int labelIndex = 1; // Use explicit label index position final int numPossibleLabels = data.numClasses(); DataSetIterator tmpIter = new RecordReaderDataSetIterator(reader, batchSize, labelIndex, numPossibleLabels); DataNormalization scaler = new ImagePreProcessingScaler(0, 1); scaler.fit(tmpIter); tmpIter.setPreProcessor(scaler); return tmpIter; }
Example #2
Source File: IrisFileDataSource.java From FederatedAndroidTrainer with MIT License | 7 votes |
private void createDataSource() throws IOException, InterruptedException { //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing int numLinesToSkip = 0; String delimiter = ","; RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter); recordReader.initialize(new InputStreamInputSplit(dataFile)); //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network int labelIndex = 4; //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row int numClasses = 3; //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2 DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses); DataSet allData = iterator.next(); allData.shuffle(); SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80); //Use 80% of data for training trainingData = testAndTrain.getTrain(); testData = testAndTrain.getTest(); //We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance): DataNormalization normalizer = new NormalizerStandardize(); normalizer.fit(trainingData); //Collect the statistics (mean/stdev) from the training data. This does not modify the input data normalizer.transform(trainingData); //Apply normalization to the training data normalizer.transform(testData); //Apply normalization to the test data. This is using statistics calculated from the *training* set }
Example #3
Source File: StringToDataSetExportFunction.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Converts the buffered records into a {@code DataSet} and writes it to a new file under
 * {@code outputDir}, then clears the buffer. Does nothing when the buffer is empty, or when
 * it holds fewer than {@code batchSize} records and this is not the final record.
 *
 * @param list        buffered records; cleared after a successful export
 * @param finalRecord whether the last record has been seen, forcing a partial batch out
 * @throws Exception if building the dataset or writing the file fails
 */
private void processBatchIfRequired(List<List<Writable>> list, boolean finalRecord) throws Exception {
    boolean incompleteBatch = list.size() < batchSize;
    if (list.isEmpty() || (incompleteBatch && !finalRecord)) {
        return;
    }

    RecordReader bufferedReader = new CollectionRecordReader(list);
    RecordReaderDataSetIterator batchIter = new RecordReaderDataSetIterator(
            bufferedReader, null, batchSize, labelIndex, labelIndex, numPossibleLabels, -1, regression);
    DataSet batch = batchIter.next();

    // Each exported file gets a unique name built from the uid and a running counter.
    String filename = "dataset_" + uid + "_" + (outputCount++) + ".bin";
    URI target = new URI(outputDir.getPath() + "/" + filename);

    Configuration hadoopConf;
    if (conf == null) {
        hadoopConf = DefaultHadoopConfig.get();
    } else {
        hadoopConf = conf.getValue().getConfiguration();
    }

    FileSystem fs = FileSystem.get(target, hadoopConf);
    try (FSDataOutputStream stream = fs.create(new Path(target))) {
        batch.save(stream);
    }

    list.clear();
}
Example #4
Source File: DataSetIteratorTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Compares, batch by batch, the data read from the "mnist_first_200.txt" CSV resource
 * (features divided by 255) with the batches produced by {@code MnistDataSetIterator}.
 */
@Test
public void testMnist() throws Exception {
    ClassPathResource csvResource = new ClassPathResource("mnist_first_200.txt");
    CSVRecordReader csvReader = new CSVRecordReader(0, ',');
    csvReader.initialize(new FileSplit(csvResource.getTempFileFromArchive()));

    RecordReaderDataSetIterator expectedIter = new RecordReaderDataSetIterator(csvReader, 10, 0, 10);
    MnistDataSetIterator actualIter = new MnistDataSetIterator(10, 200, false, true, false, 0);

    while (expectedIter.hasNext()) {
        DataSet expectedBatch = expectedIter.next();
        DataSet actualBatch = actualIter.next();

        // The CSV features are divided by 255 in place before being compared.
        INDArray expectedFeatures = expectedBatch.getFeatures();
        expectedFeatures.divi(255);
        INDArray expectedLabels = expectedBatch.getLabels();

        INDArray actualFeatures = actualBatch.getFeatures();
        INDArray actualLabels = actualBatch.getLabels();

        // Cast to the expected data type so the comparison is made on matching types.
        assertEquals(expectedFeatures, actualFeatures.castTo(expectedFeatures.dataType()));
        assertEquals(expectedLabels, actualLabels.castTo(expectedLabels.dataType()));
    }

    // Both iterators must be exhausted at the same time.
    assertFalse(actualIter.hasNext());
}
Example #5
Source File: MultipleEpochsIteratorTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Takes 20 examples from the iris data, wraps them in a {@code MultipleEpochsIterator} for
 * two epochs, and checks that every batch of 10 is non-null and that the epoch counter ends
 * at the configured value.
 */
@Test
public void testLoadBatchDataSet() throws Exception {
    int epochs = 2;

    RecordReader irisReader = new CSVRecordReader();
    irisReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator irisIter = new RecordReaderDataSetIterator(irisReader, 150, 4, 3);

    DataSet firstTwenty = irisIter.next(20);
    assertEquals(20, firstTwenty.getFeatures().size(0));

    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, firstTwenty);
    while (multiIter.hasNext()) {
        DataSet batch = multiIter.next(10);
        assertNotNull(batch);
        assertEquals(10, batch.numExamples(), 0.0);
    }

    assertEquals(epochs, multiIter.epochs);
}
Example #6
Source File: MultipleEpochsIteratorTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Takes 50 examples from the iris data, iterates them through a
 * {@code MultipleEpochsIterator} for three epochs, and checks that exactly one 50-example
 * batch is returned per epoch.
 */
@Test
public void testLoadFullDataSet() throws Exception {
    int epochs = 3;

    RecordReader irisReader = new CSVRecordReader();
    irisReader.initialize(new FileSplit(Resources.asFile("iris.txt")));
    DataSetIterator irisIter = new RecordReaderDataSetIterator(irisReader, 150);

    DataSet firstFifty = irisIter.next(50);
    assertEquals(50, firstFifty.getFeatures().size(0));

    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, firstFifty);
    assertTrue(multiIter.hasNext());

    // Count how many batches come back; one is expected per epoch.
    int batchesSeen = 0;
    while (multiIter.hasNext()) {
        DataSet batch = multiIter.next();
        assertNotNull(batch);
        assertEquals(50, batch.numExamples(), 0);
        batchesSeen++;
    }

    assertEquals(epochs, batchesSeen);
    assertEquals(epochs, multiIter.epochs);
}
Example #7
Source File: MultipleEpochsIteratorTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Wraps an iris {@code RecordReaderDataSetIterator} in a {@code MultipleEpochsIterator} for
 * three epochs and checks that every returned batch is non-null and that the epoch counter
 * ends at the configured value.
 */
@Test
public void testNextAndReset() throws Exception {
    int epochs = 3;

    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(Resources.asFile("iris.txt")));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);

    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, iter);
    assertTrue(multiIter.hasNext());

    while (multiIter.hasNext()) {
        DataSet path = multiIter.next();
        // assertNotNull matches the sibling tests and reports a clearer failure
        // than assertFalse(path == null).
        assertNotNull(path);
    }

    assertEquals(epochs, multiIter.epochs);
}
Example #8
Source File: ConvolutionLayerSetupTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Copies the "lfwtest/" resources to a temporary folder, builds an image iterator over
 * them, and verifies that after setting a 28x28x3 convolutional input type the layer at
 * index 3 of the configuration has {@code nIn} equal to 6.
 */
@Test
public void testLRN() throws Exception {
    List<String> labels = new ArrayList<>(Arrays.asList("Zico", "Ziwang_Xu"));

    File tempDir = testDir.newFolder();
    new ClassPathResource("lfwtest/").copyDirectory(tempDir);
    String rootPath = tempDir.getAbsolutePath();

    RecordReader imageReader = new ImageRecordReader(28, 28, 3);
    imageReader.initialize(new FileSplit(new File(rootPath)));
    DataSetIterator imageIter = new RecordReaderDataSetIterator(imageReader, 10, 1, labels.size());
    labels.remove("lfwtest");

    NeuralNetConfiguration.ListBuilder listBuilder =
            (NeuralNetConfiguration.ListBuilder) incompleteLRN();
    listBuilder.setInputType(InputType.convolutional(28, 28, 3));

    MultiLayerConfiguration configuration = listBuilder.build();
    ConvolutionLayer convLayer = (ConvolutionLayer) configuration.getConf(3).getLayer();
    assertEquals(6, convLayer.getNIn());
}
Example #9
Source File: TestRecordReaders.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Feeds a record whose class index (2) is outside the configured number of classes (2) and
 * expects the iterator to fail with a message containing "to one-hot".
 */
@Test
public void testClassIndexOutsideOfRangeRRDSI() {
    Collection<Collection<Writable>> records = new ArrayList<>();
    records.add(Arrays.<Writable>asList(new DoubleWritable(0.5), new IntWritable(0)));
    // Class index 2 is not valid when only 2 classes (0 and 1) are declared below.
    records.add(Arrays.<Writable>asList(new DoubleWritable(1.0), new IntWritable(2)));

    CollectionRecordReader collectionReader = new CollectionRecordReader(records);
    RecordReaderDataSetIterator dataSetIter = new RecordReaderDataSetIterator(collectionReader, 2, 1, 2);

    try {
        dataSetIter.next();
        fail("Expected exception");
    } catch (Exception ex) {
        String message = ex.getMessage();
        assertTrue(message, message.contains("to one-hot"));
    }
}
Example #10
Source File: ImageInstanceIterator.java From wekaDeeplearning4j with GNU General Public License v3.0 | 6 votes |
/** * This method returns the iterator. Scales all intensity values: it divides them by 255. * * @param data the dataset to use * @param seed the seed for the random number generator * @param batchSize the batch size to use * @return the iterator */ @Override public DataSetIterator getDataSetIterator(Instances data, int seed, int batchSize) throws Exception { batchSize = Math.min(data.numInstances(), batchSize); validate(data); ImageRecordReader reader = getImageRecordReader(data); // Required for supporting channels-last models (currently only EfficientNet) if (getChannelsLast()) reader.setNchw_channels_first(false); final int labelIndex = 1; // Use explicit label index position final int numPossibleLabels = data.numClasses(); DataSetIterator tmpIter = new RecordReaderDataSetIterator(reader, batchSize, labelIndex, numPossibleLabels); DataNormalization scaler = new ImagePreProcessingScaler(0, 1); scaler.fit(tmpIter); tmpIter.setPreProcessor(scaler); return tmpIter; }
Example #11
Source File: DiabetesFileDataSource.java From FederatedAndroidTrainer with MIT License | 6 votes |
private void createDataSource() throws IOException, InterruptedException { //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing int numLinesToSkip = 0; String delimiter = ","; RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter); recordReader.initialize(new InputStreamInputSplit(dataFile)); //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network int labelIndex = 11; DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true); DataSet allData = iterator.next(); SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80); //Use 80% of data for training trainingData = testAndTrain.getTrain(); testData = testAndTrain.getTest(); //We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance): DataNormalization normalizer = new NormalizerStandardize(); normalizer.fit(trainingData); //Collect the statistics (mean/stdev) from the training data. This does not modify the input data normalizer.transform(trainingData); //Apply normalization to the training data normalizer.transform(testData); //Apply normalization to the test data. This is using statistics calculated from the *training* set }
Example #12
Source File: DataSetIteratorHelper.java From Java-Deep-Learning-Cookbook with MIT License | 5 votes |
/**
 * Builds a classification iterator over "Churn_Modelling.csv", attaches a standardizing
 * normalizer fitted on that iterator, and wraps it in a {@code DataSetIteratorSplitter}.
 *
 * @return the splitter over the normalized iterator
 * @throws IOException          if the CSV resource cannot be read
 * @throws InterruptedException if reader creation is interrupted
 */
private static DataSetIteratorSplitter createDataSetSplitter() throws IOException, InterruptedException {
    final RecordReader churnReader =
            DataSetIteratorHelper.generateReader(new ClassPathResource("Churn_Modelling.csv").getFile());

    final RecordReaderDataSetIterator.Builder iteratorBuilder =
            new RecordReaderDataSetIterator.Builder(churnReader, batchSize);
    final DataSetIterator churnIterator = iteratorBuilder.classification(labelIndex, numClasses).build();

    final DataNormalization standardizer = new NormalizerStandardize();
    standardizer.fit(churnIterator);
    churnIterator.setPreProcessor(standardizer);

    // NOTE(review): 1250 and 0.8 are presumably the total batch count and the train ratio;
    // confirm against DataSetIteratorSplitter.
    return new DataSetIteratorSplitter(churnIterator, 1250, 0.8);
}
Example #13
Source File: RecordReaderFileBatchLoader.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public DataSet load(Source source) throws IOException { FileBatch fb = FileBatch.readFromZip(source.getInputStream()); //Wrap file batch in RecordReader //Create RecordReaderDataSetIterator //Return dataset RecordReader rr = new FileBatchRecordReader(recordReader, fb); RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(rr, null, batchSize, labelIndexFrom, labelIndexTo, numPossibleLabels, -1, regression); if (preProcessor != null) { iter.setPreProcessor(preProcessor); } DataSet ds = iter.next(); return ds; }
Example #14
Source File: ModelGenerator.java From arabic-characters-recognition with Apache License 2.0 | 5 votes |
/**
 * Creates a dataset iterator over the CSV file at the given path.
 *
 * @param csvFileClasspath path of the CSV file; it is opened with {@code new File(...)}
 * @param BATCH_SIZE       batch size passed to the iterator
 * @param LABEL_INDEX      index of the label column
 * @param numClasses       number of possible label classes
 * @return the iterator over the CSV records
 * @throws IOException          if the reader cannot be initialized
 * @throws InterruptedException if initialization of the reader is interrupted
 */
private static DataSetIterator readCSVDataset(String csvFileClasspath, int BATCH_SIZE, int LABEL_INDEX, int numClasses)
        throws IOException, InterruptedException {
    File csvFile = new File(csvFileClasspath);
    RecordReader csvReader = new CSVRecordReader();
    csvReader.initialize(new FileSplit(csvFile));
    return new RecordReaderDataSetIterator(csvReader, BATCH_SIZE, LABEL_INDEX, numClasses);
}
Example #15
Source File: DL4JMLModel.java From neo4j-ml-procedures with Apache License 2.0 | 5 votes |
@Override protected Object doPredict(List<String> line) { try { ListStringSplit input = new ListStringSplit(Collections.singletonList(line)); ListStringRecordReader rr = new ListStringRecordReader(); rr.initialize(input); DataSetIterator iterator = new RecordReaderDataSetIterator(rr, 1); DataSet ds = iterator.next(); INDArray prediction = model.output(ds.getFeatures()); DataType outputType = types.get(this.output); switch (outputType) { case _float : return prediction.getDouble(0); case _class: { int numClasses = 2; double max = 0; int maxIndex = -1; for (int i=0;i<numClasses;i++) { if (prediction.getDouble(i) > max) {maxIndex = i; max = prediction.getDouble(i);} } return maxIndex; // return prediction.getInt(0,1); // numberOfClasses } default: throw new IllegalArgumentException("Output type not yet supported "+outputType); } } catch (Exception e) { throw new RuntimeException(e); } }
Example #16
Source File: DataStorage.java From Java-Machine-Learning-for-Computer-Vision with MIT License | 5 votes |
/**
 * Creates an image dataset iterator over the given split, with a {@code VGG16ImagePreProcessor}
 * attached.
 *
 * @param sample the input split the image reader is initialized with
 * @return the iterator over the images in {@code sample}
 * @throws IOException if the image reader cannot be initialized
 */
default DataSetIterator getDataSetIterator(InputSplit sample) throws IOException {
    ImageRecordReader reader = new ImageRecordReader(HEIGHT, WIDTH, CHANNELS, LABEL_GENERATOR_MAKER);
    reader.initialize(sample);

    // The label index is passed as 1.
    DataSetIterator imageIterator =
            new RecordReaderDataSetIterator(reader, BATCH_SIZE, 1, NUM_POSSIBLE_LABELS);
    imageIterator.setPreProcessor(new VGG16ImagePreProcessor());
    return imageIterator;
}
Example #17
Source File: HyperParameterTuning.java From Java-Deep-Learning-Cookbook with MIT License | 5 votes |
/**
 * Builds the dataset iterator from the pre-processed records and returns the test iterator
 * of the resulting split.
 *
 * @return the test iterator obtained from {@code dataSplit}
 * @throws RuntimeException wrapping any exception raised while building the iterator
 */
@Override
public Object testData() {
    try {
        DataSetIterator iterator =
                new RecordReaderDataSetIterator(dataPreprocess(), minibatchSize, labelIndex, numClasses);
        return dataSplit(iterator).getTestIterator();
    } catch (Exception e) {
        // Keep the original exception as the cause so the failure can be diagnosed.
        throw new RuntimeException(e);
    }
}
Example #18
Source File: ImageClassifierAPI.java From Java-Deep-Learning-Cookbook with MIT License | 5 votes |
public static INDArray generateOutput(File inputFile, String modelFileLocation) throws IOException, InterruptedException { //retrieve the saved model final File modelFile = new File(modelFileLocation); final MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelFile); final RecordReader imageRecordReader = generateReader(inputFile); final ImagePreProcessingScaler normalizerStandardize = ModelSerializer.restoreNormalizerFromFile(modelFile); final DataSetIterator dataSetIterator = new RecordReaderDataSetIterator.Builder(imageRecordReader,1).build(); normalizerStandardize.fit(dataSetIterator); dataSetIterator.setPreProcessor(normalizerStandardize); return model.output(dataSetIterator); }
Example #19
Source File: CustomerRetentionPredictionApi.java From Java-Deep-Learning-Cookbook with MIT License | 5 votes |
public static INDArray generateOutput(File inputFile, String modelFilePath) throws IOException, InterruptedException { final File modelFile = new File(modelFilePath); final MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(modelFile); final RecordReader recordReader = generateReader(inputFile); //final INDArray array = RecordConverter.toArray(recordReader.next()); final NormalizerStandardize normalizerStandardize = ModelSerializer.restoreNormalizerFromFile(modelFile); //normalizerStandardize.transform(array); final DataSetIterator dataSetIterator = new RecordReaderDataSetIterator.Builder(recordReader,1).build(); normalizerStandardize.fit(dataSetIterator); dataSetIterator.setPreProcessor(normalizerStandardize); return network.output(dataSetIterator); }
Example #20
Source File: HyperParameterTuningArbiterUiExample.java From Java-Deep-Learning-Cookbook with MIT License | 5 votes |
/**
 * Builds the dataset iterator from the pre-processed records and returns the training
 * iterator of the resulting split.
 *
 * @return the train iterator obtained from {@code dataSplit}
 * @throws RuntimeException wrapping any exception raised while building the iterator
 */
@Override
public Object trainData() {
    try {
        DataSetIterator iterator =
                new RecordReaderDataSetIterator(dataPreprocess(), minibatchSize, labelIndex, numClasses);
        // Training data must come from the train side of the split, not the test side.
        // NOTE(review): assumes dataSplit returns a DataSetIteratorSplitter, which exposes
        // getTrainIterator() - confirm against its declaration.
        return dataSplit(iterator).getTrainIterator();
    } catch (Exception e) {
        // Keep the original exception as the cause so the failure can be diagnosed.
        throw new RuntimeException(e);
    }
}
Example #21
Source File: HyperParameterTuningArbiterUiExample.java From Java-Deep-Learning-Cookbook with MIT License | 5 votes |
/**
 * Builds the dataset iterator from the pre-processed records and returns the test iterator
 * of the resulting split.
 *
 * @return the test iterator obtained from {@code dataSplit}
 * @throws RuntimeException wrapping any exception raised while building the iterator
 */
@Override
public Object testData() {
    try {
        DataSetIterator iterator =
                new RecordReaderDataSetIterator(dataPreprocess(), minibatchSize, labelIndex, numClasses);
        return dataSplit(iterator).getTestIterator();
    } catch (Exception e) {
        // Keep the original exception as the cause so the failure can be diagnosed.
        throw new RuntimeException(e);
    }
}
Example #22
Source File: HyperParameterTuning.java From Java-Deep-Learning-Cookbook with MIT License | 5 votes |
/**
 * Builds the dataset iterator from the pre-processed records and returns the training
 * iterator of the resulting split.
 *
 * @return the train iterator obtained from {@code dataSplit}
 * @throws RuntimeException wrapping any exception raised while building the iterator
 */
@Override
public Object trainData() {
    try {
        DataSetIterator iterator =
                new RecordReaderDataSetIterator(dataPreprocess(), minibatchSize, labelIndex, numClasses);
        // Training data must come from the train side of the split, not the test side.
        // NOTE(review): assumes dataSplit returns a DataSetIteratorSplitter, which exposes
        // getTrainIterator() - confirm against its declaration.
        return dataSplit(iterator).getTrainIterator();
    } catch (Exception e) {
        // Keep the original exception as the cause so the failure can be diagnosed.
        throw new RuntimeException(e);
    }
}
Example #23
Source File: HyperParameterTuning.java From Java-Deep-Learning-Cookbook with MIT License | 5 votes |
/**
 * Builds the dataset iterator from the pre-processed records and returns the test iterator
 * of the resulting split.
 *
 * @return the test iterator obtained from {@code dataSplit}
 * @throws RuntimeException wrapping any exception raised while building the iterator
 */
@Override
public Object testData() {
    try {
        DataSetIterator iterator =
                new RecordReaderDataSetIterator(dataPreprocess(), minibatchSize, labelIndex, numClasses);
        return dataSplit(iterator).getTestIterator();
    } catch (Exception e) {
        // Keep the original exception as the cause so the failure can be diagnosed.
        throw new RuntimeException(e);
    }
}
Example #24
Source File: ImageClassifierAPI.java From Java-Deep-Learning-Cookbook with MIT License | 5 votes |
public static INDArray generateOutput(File inputFile, String modelFileLocation) throws IOException, InterruptedException { //retrieve the saved model final File modelFile = new File(modelFileLocation); final MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelFile); final RecordReader imageRecordReader = generateReader(inputFile); final ImagePreProcessingScaler normalizerStandardize = ModelSerializer.restoreNormalizerFromFile(modelFile); final DataSetIterator dataSetIterator = new RecordReaderDataSetIterator.Builder(imageRecordReader,1).build(); normalizerStandardize.fit(dataSetIterator); dataSetIterator.setPreProcessor(normalizerStandardize); return model.output(dataSetIterator); }
Example #25
Source File: DataSetIteratorHelper.java From Java-Deep-Learning-Cookbook with MIT License | 5 votes |
/**
 * Creates the splitter used to divide the "Churn_Modelling.csv" data: a classification
 * iterator is built over the CSV, a standardizing normalizer is fitted on it and attached as
 * pre-processor, and the iterator is wrapped in a {@code DataSetIteratorSplitter}.
 *
 * @return the splitter over the normalized iterator
 * @throws IOException          if the CSV resource cannot be read
 * @throws InterruptedException if reader creation is interrupted
 */
private static DataSetIteratorSplitter createDataSetSplitter() throws IOException, InterruptedException {
    final File csvFile = new ClassPathResource("Churn_Modelling.csv").getFile();
    final RecordReader csvReader = DataSetIteratorHelper.generateReader(csvFile);

    final DataSetIterator iterator = new RecordReaderDataSetIterator.Builder(csvReader, batchSize)
            .classification(labelIndex, numClasses)
            .build();

    final DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(iterator);
    iterator.setPreProcessor(normalizer);

    // NOTE(review): 1250 and 0.8 are presumably the total batch count and the train ratio;
    // confirm against DataSetIteratorSplitter.
    return new DataSetIteratorSplitter(iterator, 1250, 0.8);
}
Example #26
Source File: CustomerRetentionPredictionApi.java From Java-Deep-Learning-Cookbook with MIT License | 5 votes |
public static INDArray generateOutput(File inputFile, String modelFilePath) throws IOException, InterruptedException { final File modelFile = new File(modelFilePath); final MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(modelFile); final RecordReader recordReader = generateReader(inputFile); //final INDArray array = RecordConverter.toArray(recordReader.next()); final NormalizerStandardize normalizerStandardize = ModelSerializer.restoreNormalizerFromFile(modelFile); //normalizerStandardize.transform(array); final DataSetIterator dataSetIterator = new RecordReaderDataSetIterator.Builder(recordReader,1).build(); normalizerStandardize.fit(dataSetIterator); dataSetIterator.setPreProcessor(normalizerStandardize); return network.output(dataSetIterator); }
Example #27
Source File: HyperParameterTuningArbiterUiExample.java From Java-Deep-Learning-Cookbook with MIT License | 5 votes |
/**
 * Builds the dataset iterator from the pre-processed records and returns the training
 * iterator of the resulting split.
 *
 * @return the train iterator obtained from {@code dataSplit}
 * @throws RuntimeException wrapping any exception raised while building the iterator
 */
@Override
public Object trainData() {
    try {
        DataSetIterator iterator =
                new RecordReaderDataSetIterator(dataPreprocess(), minibatchSize, labelIndex, numClasses);
        // Training data must come from the train side of the split, not the test side.
        // NOTE(review): assumes dataSplit returns a DataSetIteratorSplitter, which exposes
        // getTrainIterator() - confirm against its declaration.
        return dataSplit(iterator).getTrainIterator();
    } catch (Exception e) {
        // Keep the original exception as the cause so the failure can be diagnosed.
        throw new RuntimeException(e);
    }
}
Example #28
Source File: HyperParameterTuningArbiterUiExample.java From Java-Deep-Learning-Cookbook with MIT License | 5 votes |
/**
 * Builds the dataset iterator from the pre-processed records and returns the test iterator
 * of the resulting split.
 *
 * @return the test iterator obtained from {@code dataSplit}
 * @throws RuntimeException wrapping any exception raised while building the iterator
 */
@Override
public Object testData() {
    try {
        DataSetIterator iterator =
                new RecordReaderDataSetIterator(dataPreprocess(), minibatchSize, labelIndex, numClasses);
        return dataSplit(iterator).getTestIterator();
    } catch (Exception e) {
        // Keep the original exception as the cause so the failure can be diagnosed.
        throw new RuntimeException(e);
    }
}
Example #29
Source File: HyperParameterTuning.java From Java-Deep-Learning-Cookbook with MIT License | 5 votes |
/**
 * Builds the dataset iterator from the pre-processed records and returns the training
 * iterator of the resulting split.
 *
 * @return the train iterator obtained from {@code dataSplit}
 * @throws RuntimeException wrapping any exception raised while building the iterator
 */
@Override
public Object trainData() {
    try {
        DataSetIterator iterator =
                new RecordReaderDataSetIterator(dataPreprocess(), minibatchSize, labelIndex, numClasses);
        // Training data must come from the train side of the split, not the test side.
        // NOTE(review): assumes dataSplit returns a DataSetIteratorSplitter, which exposes
        // getTrainIterator() - confirm against its declaration.
        return dataSplit(iterator).getTrainIterator();
    } catch (Exception e) {
        // Keep the original exception as the cause so the failure can be diagnosed.
        throw new RuntimeException(e);
    }
}
Example #30
Source File: ImageUtils.java From Java-Machine-Learning-for-Computer-Vision with MIT License | 5 votes |
/**
 * Creates an image dataset iterator over the given file or directory, with a
 * {@code CifarImagePreProcessor} attached.
 *
 * @param sample    the file the image reader's {@code FileSplit} is created from
 * @param numLabels number of possible labels
 * @param batchSize batch size passed to the iterator
 * @return the iterator over the images in {@code sample}
 * @throws IOException if the image reader cannot be initialized
 */
public static DataSetIterator createDataSetIterator(File sample, int numLabels, int batchSize) throws IOException {
    ImageRecordReader reader = new ImageRecordReader(HEIGHT, WIDTH, CHANNELS, LABEL_GENERATOR_MAKER);
    FileSplit split = new FileSplit(sample);
    reader.initialize(split);

    // The label index is passed as 1.
    DataSetIterator imageIterator = new RecordReaderDataSetIterator(reader, batchSize, 1, numLabels);
    imageIterator.setPreProcessor(new CifarImagePreProcessor());
    return imageIterator;
}