Java Code Examples for org.nd4j.linalg.dataset.api.iterator.DataSetIterator#reset()
The following examples show how to use org.nd4j.linalg.dataset.api.iterator.DataSetIterator#reset(). Each example is taken from an open-source project; the source file and license are noted above each listing.
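All of these examples share one pattern: consume the iterator with hasNext()/next(), then call reset() to rewind it before the next pass. The sketch below shows that pattern in isolation, assuming deeplearning4j is on the classpath; the class name ResetPatternSketch is ours, everything else is the standard DL4J/ND4J API. Note the resetSupported() guard: some iterators (such as the stream-backed one in Example 9) cannot rewind and throw on reset().

import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class ResetPatternSketch {
    public static void main(String[] args) {
        // 150 iris examples in minibatches of 10 -> 15 batches per pass
        DataSetIterator iter = new IrisDataSetIterator(10, 150);

        for (int epoch = 0; epoch < 3; epoch++) {
            if (epoch > 0) {
                // Not every iterator can rewind; guard before calling reset()
                if (!iter.resetSupported()) {
                    throw new IllegalStateException("iterator cannot be reset for another epoch");
                }
                iter.reset(); // rewind to the first minibatch
            }
            int batches = 0;
            while (iter.hasNext()) {
                iter.next(); // a training loop would pass this minibatch to model.fit(...)
                batches++;
            }
            System.out.println("epoch " + epoch + ": consumed " + batches + " minibatches");
        }
    }
}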
Example 1
Source File: AbstractDataSetNormalizer.java From deeplearning4j with Apache License 2.0
/**
 * Fit the given model
 *
 * @param iterator for the data to iterate over
 */
@Override
public void fit(DataSetIterator iterator) {
    S.Builder featureNormBuilder = newBuilder();
    S.Builder labelNormBuilder = newBuilder();

    iterator.reset();
    while (iterator.hasNext()) {
        DataSet next = iterator.next();
        featureNormBuilder.addFeatures(next);
        if (fitLabels) {
            labelNormBuilder.addLabels(next);
        }
    }
    featureStats = (S) featureNormBuilder.build();
    if (fitLabels) {
        labelStats = (S) labelNormBuilder.build();
    }

    iterator.reset();
}
Example 2
Source File: CachingDataSetIteratorTest.java From deeplearning4j with Apache License 2.0
private void assertCachingDataSetIteratorHasAllTheData(int rows, int inputColumns, int outputColumns,
        DataSet dataSet, DataSetIterator it, CachingDataSetIterator cachedIt) {
    cachedIt.reset();
    it.reset();

    dataSet.setFeatures(Nd4j.zeros(rows, inputColumns));
    dataSet.setLabels(Nd4j.ones(rows, outputColumns));

    while (it.hasNext()) {
        assertTrue(cachedIt.hasNext());

        DataSet cachedDs = cachedIt.next();
        assertEquals(1000.0, cachedDs.getFeatures().sumNumber());
        assertEquals(0.0, cachedDs.getLabels().sumNumber());

        DataSet ds = it.next();
        assertEquals(0.0, ds.getFeatures().sumNumber());
        assertEquals(20.0, ds.getLabels().sumNumber());
    }

    assertFalse(cachedIt.hasNext());
    assertFalse(it.hasNext());
}
Example 3
Source File: CachingDataSetIteratorTest.java From nd4j with Apache License 2.0
private void assertPreProcessingGetsCached(int expectedNumberOfDataSets, DataSetIterator it,
        CachingDataSetIterator cachedIt, PreProcessor preProcessor) {
    assertSame(preProcessor, cachedIt.getPreProcessor());
    assertSame(preProcessor, it.getPreProcessor());

    cachedIt.reset();
    it.reset();

    while (cachedIt.hasNext()) {
        cachedIt.next();
    }
    assertEquals(expectedNumberOfDataSets, preProcessor.getCallCount());

    cachedIt.reset();
    it.reset();

    while (cachedIt.hasNext()) {
        cachedIt.next();
    }
    assertEquals(expectedNumberOfDataSets, preProcessor.getCallCount());
}
Example 4
Source File: TestAsyncIterator.java From deeplearning4j with Apache License 2.0
@Test
public void testInitializeNoNextIter() {
    DataSetIterator iter = new IrisDataSetIterator(10, 150);
    while (iter.hasNext())
        iter.next();

    DataSetIterator async = new AsyncDataSetIterator(iter, 2);

    assertFalse(iter.hasNext());
    assertFalse(async.hasNext());
    try {
        iter.next();
        fail("Should have thrown NoSuchElementException");
    } catch (Exception e) {
        //OK
    }

    async.reset();
    int count = 0;
    while (async.hasNext()) {
        async.next();
        count++;
    }
    assertEquals(150 / 10, count);
}
Example 5
Source File: EarlyTerminationDataSetIteratorTest.java From deeplearning4j with Apache License 2.0
@Test
public void testCallstoNextNotAllowed() throws IOException {
    int terminateAfter = 1;

    DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples);
    EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter);

    earlyEndIter.next(10);
    iter.reset();
    exception.expect(RuntimeException.class);
    earlyEndIter.next(10);
}
Example 6
Source File: TrainCifar10Model.java From Java-Machine-Learning-for-Computer-Vision with MIT License
private void testResults(ComputationGraph cifar10, DataSetIterator testIterator, int iEpoch, String modelName)
        throws IOException {
    if (iEpoch % TEST_INTERVAL == 0) {
        Evaluation eval = cifar10.evaluate(testIterator);
        log.info(eval.stats());
        testIterator.reset();
    }
    // TestModels.TestResult test = TestModels.test(cifar10, modelName);
    // log.info("Test Results >> " + test);
}
Example 7
Source File: MultiLayerTest.java From deeplearning4j with Apache License 2.0
@Test
public void testIterationCountAndPersistence() throws IOException {
    Nd4j.getRandom().setSeed(123);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER)
                    .activation(Activation.TANH).build())
            .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                    LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
                    .build())
            .build();

    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();

    DataSetIterator iter = new IrisDataSetIterator(50, 150);

    assertEquals(0, network.getLayerWiseConfigurations().getIterationCount());
    network.fit(iter);
    assertEquals(3, network.getLayerWiseConfigurations().getIterationCount());
    iter.reset();
    network.fit(iter);
    assertEquals(6, network.getLayerWiseConfigurations().getIterationCount());
    iter.reset();
    network.fit(iter.next());
    assertEquals(7, network.getLayerWiseConfigurations().getIterationCount());

    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    ModelSerializer.writeModel(network, baos, true);
    byte[] asBytes = baos.toByteArray();

    ByteArrayInputStream bais = new ByteArrayInputStream(asBytes);
    MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(bais, true);
    assertEquals(7, net.getLayerWiseConfigurations().getIterationCount());
}
Example 8
Source File: LinearModel.java From FederatedAndroidTrainer with MIT License
@Override
public void train(FederatedDataSet dataSource) {
    DataSet trainingData = (DataSet) dataSource.getNativeDataSet();

    List<DataSet> listDs = trainingData.asList();
    DataSetIterator iterator = new ListDataSetIterator(listDs, BATCH_SIZE);

    //Train the network on the full data set, and evaluate it periodically
    for (int i = 0; i < N_EPOCHS; i++) {
        iterator.reset();
        mNetwork.fit(iterator);
    }
}
Example 9
Source File: RecordReaderDataSetiteratorTest.java From deeplearning4j with Apache License 2.0
@Test
public void testReadingFromStream() throws Exception {

    for (boolean b : new boolean[]{false, true}) {
        int batchSize = 1;
        int labelIndex = 4;
        int numClasses = 3;
        InputStream dataFile = Resources.asStream("iris.txt");
        RecordReader recordReader = new CSVRecordReader(0, ',');
        recordReader.initialize(new InputStreamInputSplit(dataFile));

        assertTrue(recordReader.hasNext());
        assertFalse(recordReader.resetSupported());

        DataSetIterator iterator;
        if (b) {
            iterator = new RecordReaderDataSetIterator.Builder(recordReader, batchSize)
                    .classification(labelIndex, numClasses)
                    .build();
        } else {
            iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
        }
        assertFalse(iterator.resetSupported());

        int count = 0;
        while (iterator.hasNext()) {
            assertNotNull(iterator.next());
            count++;
        }

        assertEquals(150, count);

        try {
            iterator.reset();
            fail("Expected exception");
        } catch (Exception e) {
            //expected
        }
    }
}
Example 10
Source File: MultiRegression.java From dl4j-tutorials with MIT License
public static void main(String[] args) {

    //Generate the training data
    DataSetIterator iterator = getTrainingData(batchSize, rng);

    //Create the network
    int numInput = 2;
    int numOutputs = 1;
    MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
            .seed(seed)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .weightInit(WeightInit.XAVIER)
            .updater(new Sgd(learningRate))
            .list()
            .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                    .activation(Activation.IDENTITY)
                    .nIn(numInput).nOut(numOutputs).build())
            .pretrain(false).backprop(true).build()
    );
    net.init();
    net.setListeners(new ScoreIterationListener(1));

    for (int i = 0; i < nEpochs; i++) {
        iterator.reset();
        net.fit(iterator);
    }

    final INDArray input = Nd4j.create(new double[]{0.111111, 0.3333333333333}, new int[]{1, 2});
    INDArray out = net.output(input, false);
    System.out.println(out);
}
Example 11
Source File: MultiLayerTest.java From deeplearning4j with Apache License 2.0
@Test
public void testOutput() throws Exception {
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .weightInit(WeightInit.XAVIER).seed(12345L).list()
            .layer(0, new DenseLayer.Builder().nIn(784).nOut(50).activation(Activation.RELU).build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(50).nOut(10).build())
            .setInputType(InputType.convolutional(28, 28, 1)).build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    DataSetIterator fullData = new MnistDataSetIterator(1, 2);
    net.fit(fullData);

    fullData.reset();
    DataSet expectedSet = fullData.next(2);
    INDArray expectedOut = net.output(expectedSet.getFeatures(), false);

    fullData.reset();
    INDArray actualOut = net.output(fullData);

    assertEquals(expectedOut, actualOut);
}
Example 12
Source File: Dl4jMlpClassifier.java From wekaDeeplearning4j with GNU General Public License v3.0
/**
 * Get a peek at the features of the {@code iterator}'s first batch using the given instances.
 *
 * @return Features of the first batch
 */
protected INDArray getFirstBatchFeatures(Instances data) throws Exception {
    final DataSetIterator it = getDataSetIterator(data, CacheMode.NONE);
    if (!it.hasNext()) {
        throw new RuntimeException("Iterator was unexpectedly empty.");
    }
    final INDArray features = Utils.getNext(it).getFeatures();
    it.reset();
    return features;
}
Example 13
Source File: ManualTests.java From deeplearning4j with Apache License 2.0
@Test
public void testCNNActivations2() throws Exception {

    int nChannels = 1;
    int outputNum = 10;
    int batchSize = 64;
    int nEpochs = 10;
    int seed = 123;

    log.info("Load data....");
    DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
    DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

    log.info("Build model....");
    MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed)
            .l2(0.0005)
            .weightInit(WeightInit.XAVIER)
            .updater(new Nesterovs(0.01, 0.9)).list()
            .layer(0, new ConvolutionLayer.Builder(5, 5)
                    //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
                    .nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build())
            .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
                    .stride(2, 2).build())
            .layer(2, new ConvolutionLayer.Builder(5, 5)
                    //Note that nIn need not be specified in later layers
                    .stride(1, 1).nOut(50).activation(Activation.IDENTITY).build())
            .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
                    .stride(2, 2).build())
            .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build())
            .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .nOut(outputNum).activation(Activation.SOFTMAX).build())
            .setInputType(InputType.convolutional(28, 28, nChannels));

    MultiLayerConfiguration conf = builder.build();
    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();

    /*
    ParallelWrapper wrapper = new ParallelWrapper.Builder(model)
            .averagingFrequency(1)
            .prefetchBuffer(12)
            .workers(2)
            .reportScoreAfterAveraging(false)
            .useLegacyAveraging(false)
            .build();
    */

    log.info("Train model....");
    model.setListeners(new ConvolutionalIterationListener(1));

    //((NativeOpExecutioner) Nd4j.getExecutioner()).getLoop().setOmpNumThreads(8);

    long timeX = System.currentTimeMillis();
    // nEpochs = 2;

    for (int i = 0; i < nEpochs; i++) {
        long time1 = System.currentTimeMillis();
        model.fit(mnistTrain);
        //wrapper.fit(mnistTrain);
        long time2 = System.currentTimeMillis();

        log.info("*** Completed epoch {}, Time elapsed: {} ***", i, (time2 - time1));
    }
    long timeY = System.currentTimeMillis();

    log.info("Evaluate model....");
    Evaluation eval = new Evaluation(outputNum);
    while (mnistTest.hasNext()) {
        DataSet ds = mnistTest.next();
        INDArray output = model.output(ds.getFeatures(), false);
        eval.eval(ds.getLabels(), output);
    }
    log.info(eval.stats());
    mnistTest.reset();

    log.info("****************Example finished********************");
}
Example 14
Source File: EvaluationToolsTests.java From deeplearning4j with Apache License 2.0
@Test
public void testRocHtml() {

    DataSetIterator iter = new IrisDataSetIterator(150, 150);

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list()
            .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build())
            .layer(1, new OutputLayer.Builder().nIn(4).nOut(2).activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    NormalizerStandardize ns = new NormalizerStandardize();

    DataSet ds = iter.next();
    ns.fit(ds);
    ns.transform(ds);

    INDArray newLabels = Nd4j.create(150, 2);
    newLabels.getColumn(0).assign(ds.getLabels().getColumn(0));
    newLabels.getColumn(0).addi(ds.getLabels().getColumn(1));
    newLabels.getColumn(1).assign(ds.getLabels().getColumn(2));
    ds.setLabels(newLabels);

    for (int i = 0; i < 30; i++) {
        net.fit(ds);
    }

    for (int numSteps : new int[] {20, 0}) {
        ROC roc = new ROC(numSteps);
        iter.reset();

        INDArray f = ds.getFeatures();
        INDArray l = ds.getLabels();
        INDArray out = net.output(f);
        roc.eval(l, out);

        String str = EvaluationTools.rocChartToHtml(roc);
        // System.out.println(str);
    }
}
Example 15
Source File: RnnSequenceClassifier.java From wekaDeeplearning4j with GNU General Public License v3.0
/**
 * The method to use when making predictions for test instances.
 *
 * @param insts the instances to get predictions for
 * @return the class probability estimates (if the class is nominal) or the numeric predictions
 *     (if it is numeric)
 * @throws Exception if something goes wrong at prediction time
 */
@Override
public double[][] distributionsForInstances(Instances insts) throws Exception {
    log.info("Calc. dist for {} instances", insts.numInstances());
    // Do we only have a ZeroR model?
    if (zeroR != null) {
        return zeroR.distributionsForInstances(insts);
    }

    // Process input data to have the same filters applied as the training data
    insts = applyFilters(insts);

    // Get predictions
    final DataSetIterator it = getDataSetIterator(insts, CacheMode.NONE);
    double[][] preds = new double[insts.numInstances()][insts.numClasses()];

    if (it.resetSupported()) {
        it.reset();
    }

    int offset = 0;
    boolean next = it.hasNext();

    // Get predictions batch-wise
    while (next) {
        final DataSet ds = Utils.getNext(it);
        final INDArray features = ds.getFeatures();
        final INDArray labelsMask = ds.getLabelsMaskArray();
        INDArray lastTimeStepIndices;
        if (labelsMask != null) {
            lastTimeStepIndices = Nd4j.argMax(labelsMask, 1);
        } else {
            lastTimeStepIndices = Nd4j.zeros(features.size(0), 1);
        }
        INDArray predBatch = model.outputSingle(features);
        int currentBatchSize = (int) predBatch.size(0);
        for (int i = 0; i < currentBatchSize; i++) {
            int thisTimeSeriesLastIndex = lastTimeStepIndices.getInt(i);
            INDArray thisExampleProbabilities = predBatch.get(
                    NDArrayIndex.point(i),
                    NDArrayIndex.all(),
                    NDArrayIndex.point(thisTimeSeriesLastIndex));
            for (int j = 0; j < insts.numClasses(); j++) {
                preds[i + offset][j] = thisExampleProbabilities.getDouble(j);
            }
        }

        offset += currentBatchSize; // add batchsize as offset
        boolean iteratorHasInstancesLeft = offset < insts.numInstances();
        next = it.hasNext() || iteratorHasInstancesLeft;
    }

    // Fix classes
    for (int i = 0; i < preds.length; i++) {
        if (preds[i].length > 1) {
            weka.core.Utils.normalize(preds[i]);
        } else {
            // Rescale numeric classes with the computed coefficients in the initialization phase
            preds[i][0] = preds[i][0] * x1 + x0;
        }
    }
    return preds;
}
Example 16
Source File: ParallelInferenceTest.java From deeplearning4j with Apache License 2.0
protected void evalClassifcationMultipleThreads(@NonNull ParallelInference inf, @NonNull DataSetIterator iterator,
        int numThreads) throws Exception {
    DataSet ds = iterator.next();
    log.info("NumColumns: {}", ds.getLabels().columns());
    iterator.reset();
    Evaluation eval = new Evaluation(ds.getLabels().columns());
    final Queue<DataSet> dataSets = new LinkedBlockingQueue<>();
    final Queue<Pair<INDArray, INDArray>> outputs = new LinkedBlockingQueue<>();
    int cnt = 0;
    // first of all we'll build datasets
    while (iterator.hasNext() && cnt < 256) {
        ds = iterator.next();
        dataSets.add(ds);
        cnt++;
    }

    // now we'll build outputs in parallel
    Thread[] threads = new Thread[numThreads];
    for (int i = 0; i < numThreads; i++) {
        threads[i] = new Thread(new Runnable() {
            @Override
            public void run() {
                DataSet ds;
                while ((ds = dataSets.poll()) != null) {
                    INDArray output = inf.output(ds);
                    outputs.add(Pair.makePair(ds.getLabels(), output));
                }
            }
        });
    }

    for (int i = 0; i < numThreads; i++) {
        threads[i].start();
    }

    for (int i = 0; i < numThreads; i++) {
        threads[i].join();
    }

    // and now we'll evaluate in single thread once again
    Pair<INDArray, INDArray> output;
    while ((output = outputs.poll()) != null) {
        eval.eval(output.getFirst(), output.getSecond());
    }

    log.info(eval.stats());
}
Example 17
Source File: DL4JSentimentAnalysisExample.java From Java-for-Data-Science with MIT License
public static void main(String[] args) throws Exception {
    getModelData();

    System.out.println("Total memory = " + Runtime.getRuntime().totalMemory());

    int batchSize = 50;
    int vectorSize = 300;
    int nEpochs = 5;
    int truncateReviewsToLength = 300;

    MultiLayerConfiguration sentimentNN = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .updater(Updater.RMSPROP)
            .regularization(true).l2(1e-5)
            .weightInit(WeightInit.XAVIER)
            .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
            .gradientNormalizationThreshold(1.0)
            .learningRate(0.0018)
            .list()
            .layer(0, new GravesLSTM.Builder().nIn(vectorSize).nOut(200)
                    .activation("softsign").build())
            .layer(1, new RnnOutputLayer.Builder().activation("softmax")
                    .lossFunction(LossFunctions.LossFunction.MCXENT).nIn(200).nOut(2).build())
            .pretrain(false).backprop(true).build();

    MultiLayerNetwork net = new MultiLayerNetwork(sentimentNN);
    net.init();
    net.setListeners(new ScoreIterationListener(1));

    WordVectors wordVectors = WordVectorSerializer.loadGoogleModel(new File(GNEWS_VECTORS_PATH), true, false);
    DataSetIterator trainData = new AsyncDataSetIterator(
            new SentimentExampleIterator(EXTRACT_DATA_PATH, wordVectors, batchSize, truncateReviewsToLength, true), 1);
    DataSetIterator testData = new AsyncDataSetIterator(
            new SentimentExampleIterator(EXTRACT_DATA_PATH, wordVectors, 100, truncateReviewsToLength, false), 1);

    for (int i = 0; i < nEpochs; i++) {
        net.fit(trainData);
        trainData.reset();

        Evaluation evaluation = new Evaluation();
        while (testData.hasNext()) {
            DataSet t = testData.next();
            INDArray dataFeatures = t.getFeatureMatrix();
            INDArray dataLabels = t.getLabels();
            INDArray inMask = t.getFeaturesMaskArray();
            INDArray outMask = t.getLabelsMaskArray();
            INDArray predicted = net.output(dataFeatures, false, inMask, outMask);

            evaluation.evalTimeSeries(dataLabels, predicted, outMask);
        }
        testData.reset();
        System.out.println(evaluation.stats());
    }
}
Example 18
Source File: TrainCifar10Model.java From Java-Machine-Learning-for-Computer-Vision with MIT License
private void train() throws IOException {
    ZooModel zooModel = VGG16.builder().build();
    ComputationGraph vgg16 = (ComputationGraph) zooModel.initPretrained(PretrainedType.CIFAR10);
    log.info(vgg16.summary());

    IUpdater iUpdaterWithDefaultConfig = Updater.ADAM.getIUpdaterWithDefaultConfig();
    iUpdaterWithDefaultConfig.setLrAndSchedule(0.1, null);
    FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
            .seed(1234)
//            .weightInit(WeightInit.XAVIER)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .activation(Activation.RELU)
            .updater(iUpdaterWithDefaultConfig)
            .cudnnAlgoMode(ConvolutionLayer.AlgoMode.NO_WORKSPACE)
            .miniBatch(true)
            .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
            .trainingWorkspaceMode(WorkspaceMode.ENABLED)
            .pretrain(true)
            .backprop(true)
            .build();

    ComputationGraph cifar10 = new TransferLearning.GraphBuilder(vgg16)
            .setWorkspaceMode(WorkspaceMode.ENABLED)
            .fineTuneConfiguration(fineTuneConf)
            .setInputTypes(InputType.convolutionalFlat(ImageUtils.HEIGHT, ImageUtils.WIDTH, 3))
            .removeVertexAndConnections("dense_2_loss")
            .removeVertexAndConnections("dense_2")
            .removeVertexAndConnections("dense_1")
            .removeVertexAndConnections("dropout_1")
            .removeVertexAndConnections("embeddings")
            .removeVertexAndConnections("flatten_1")
            .addLayer("dense_1", new DenseLayer.Builder()
                    .nIn(4096)
                    .nOut(EMBEDDINGS)
                    .activation(Activation.RELU).build(), "block3_pool")
            .addVertex("embeddings", new L2NormalizeVertex(new int[]{}, 1e-12), "dense_1")
            .addLayer("lossLayer", new CenterLossOutputLayer.Builder()
                            .lossFunction(LossFunctions.LossFunction.SQUARED_LOSS)
                            .activation(Activation.SOFTMAX).nIn(EMBEDDINGS).nOut(NUM_POSSIBLE_LABELS)
                            .lambda(LAMBDA).alpha(0.9)
                            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).build(),
                    "embeddings")
            .setOutputs("lossLayer")
            .build();

    log.info(cifar10.summary());

    File rootDir = new File("CarTracking/train_from_video_" + NUM_POSSIBLE_LABELS);
    DataSetIterator dataSetIterator = ImageUtils.createDataSetIterator(rootDir, NUM_POSSIBLE_LABELS, BATCH_SIZE);
    DataSetIterator testIterator = ImageUtils.createDataSetIterator(rootDir, NUM_POSSIBLE_LABELS, BATCH_SIZE);
    cifar10.setListeners(new ScoreIterationListener(2));
    int iEpoch = I_EPOCH;
    while (iEpoch < EPOCH_TRAINING) {
        while (dataSetIterator.hasNext()) {
            DataSet trainMiniBatchData = null;
            try {
                trainMiniBatchData = dataSetIterator.next();
            } catch (Exception e) {
                e.printStackTrace();
            }
            cifar10.fit(trainMiniBatchData);
        }
        iEpoch++;

        String modelName = PREFIX + NUM_POSSIBLE_LABELS + "_epoch_data_e" + EMBEDDINGS + "_b" + BATCH_SIZE + "_"
                + iEpoch + ".zip";
        saveProgress(cifar10, iEpoch, modelName);
        testResults(cifar10, testIterator, iEpoch, modelName);
        dataSetIterator.reset();
        log.info("iEpoch = " + iEpoch);
    }
}
Example 19
Source File: TestSparkComputationGraph.java From deeplearning4j with Apache License 2.0
@Test(timeout = 60000L)
public void testEvaluationAndRoc() {
    for (int evalWorkers : new int[]{1, 4, 8}) {
        DataSetIterator iter = new IrisDataSetIterator(5, 150);

        //Make a 2-class version of iris:
        List<DataSet> l = new ArrayList<>();
        iter.reset();
        while (iter.hasNext()) {
            DataSet ds = iter.next();
            INDArray newL = Nd4j.create(ds.getLabels().size(0), 2);
            newL.putColumn(0, ds.getLabels().getColumn(0));
            newL.putColumn(1, ds.getLabels().getColumn(1));
            newL.getColumn(1).addi(ds.getLabels().getColumn(2));
            ds.setLabels(newL);
            l.add(ds);
        }

        iter = new ListDataSetIterator<>(l);

        ComputationGraph cg = getBasicNetIris2Class();

        Evaluation e = cg.evaluate(iter);
        ROC roc = cg.evaluateROC(iter, 32);

        SparkComputationGraph scg = new SparkComputationGraph(sc, cg, null);
        scg.setDefaultEvaluationWorkers(evalWorkers);

        JavaRDD<DataSet> rdd = sc.parallelize(l);
        rdd = rdd.repartition(20);

        Evaluation e2 = scg.evaluate(rdd);
        ROC roc2 = scg.evaluateROC(rdd);

        assertEquals(e2.accuracy(), e.accuracy(), 1e-3);
        assertEquals(e2.f1(), e.f1(), 1e-3);
        assertEquals(e2.getNumRowCounter(), e.getNumRowCounter(), 1e-3);
        assertEquals(e2.falseNegatives(), e.falseNegatives());
        assertEquals(e2.falsePositives(), e.falsePositives());
        assertEquals(e2.trueNegatives(), e.trueNegatives());
        assertEquals(e2.truePositives(), e.truePositives());
        assertEquals(e2.precision(), e.precision(), 1e-3);
        assertEquals(e2.recall(), e.recall(), 1e-3);
        assertEquals(e2.getConfusionMatrix(), e.getConfusionMatrix());

        assertEquals(roc.calculateAUC(), roc2.calculateAUC(), 1e-5);
        assertEquals(roc.calculateAUCPR(), roc2.calculateAUCPR(), 1e-5);
    }
}