org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator Java Examples
The following examples show how to use org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator, a DataSetIterator implementation that serves an in-memory List<DataSet> in fixed-size mini-batches. Each example is drawn from an open-source project; the source file and license are noted above it.
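Most of the examples below share a single pattern: build a DataSet from feature and label arrays, split it into single-example DataSets with asList(), and wrap the resulting list in a ListDataSetIterator with a chosen batch size. Here is a minimal, self-contained sketch of that pattern; the class name, toy shapes, and batch size of 2 are illustrative choices, not taken from any of the projects below.

import java.util.Arrays;
import java.util.List;

import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

public class ListDataSetIteratorSketch {
    public static void main(String[] args) {
        // Four toy examples: three features and one regression target per row.
        INDArray features = Nd4j.rand(4, 3);
        INDArray labels = Nd4j.rand(4, 1);

        // asList() splits the DataSet into one single-example DataSet per row.
        List<DataSet> examples = new DataSet(features, labels).asList();

        // Serve the list two examples per next() call; reset() rewinds it for another epoch.
        DataSetIterator iter = new ListDataSetIterator<>(examples, 2);
        while (iter.hasNext()) {
            DataSet batch = iter.next();
            System.out.println("batch features: " + Arrays.toString(batch.getFeatures().shape()));
        }
    }
}

Since the whole list lives in memory, this iterator suits small data sets and tests; the larger examples below (e.g. MNIST) use it with modest batch sizes for the same reason.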
Example #1
Source File: MultiRegression.java From dl4j-tutorials with MIT License
private static DataSetIterator getTrainingData(int batchSize, Random rand) {
    double[] sum = new double[nSamples];
    double[] input1 = new double[nSamples];
    double[] input2 = new double[nSamples];
    for (int i = 0; i < nSamples; i++) {
        input1[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
        input2[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
        sum[i] = input1[i] + input2[i];
    }
    INDArray inputNDArray1 = Nd4j.create(input1, new int[]{nSamples, 1});
    INDArray inputNDArray2 = Nd4j.create(input2, new int[]{nSamples, 1});
    INDArray inputNDArray = Nd4j.hstack(inputNDArray1, inputNDArray2);
    INDArray outPut = Nd4j.create(sum, new int[]{nSamples, 1});
    DataSet dataSet = new DataSet(inputNDArray, outPut);
    List<DataSet> listDs = dataSet.asList();
    return new ListDataSetIterator(listDs, batchSize);
}
Example #2
Source File: DL4JSequenceRecommender.java From inception with Apache License 2.0
private MultiLayerNetwork train(List<Sample> aTrainingData, Object2IntMap<String> aTagset)
        throws IOException {
    // Configure the neural network
    MultiLayerNetwork model = createConfiguredNetwork(traits, wordVectors.dimensions());

    final int limit = traits.getTrainingSetSizeLimit();
    final int batchSize = traits.getBatchSize();

    // First vectorizing all sentences and then passing them to the model would consume
    // huge amounts of memory. Thus, every sentence is vectorized and then immediately
    // passed on to the model.
    nextEpoch: for (int epoch = 0; epoch < traits.getnEpochs(); epoch++) {
        int sentNum = 0;
        Iterator<Sample> sampleIterator = aTrainingData.iterator();
        while (sampleIterator.hasNext()) {
            List<DataSet> batch = new ArrayList<>();
            while (sampleIterator.hasNext() && batch.size() < batchSize && sentNum < limit) {
                Sample sample = sampleIterator.next();
                DataSet trainingData = vectorize(asList(sample), aTagset, true);
                batch.add(trainingData);
                sentNum++;
            }

            model.fit(new ListDataSetIterator<DataSet>(batch, batch.size()));
            log.trace("Epoch {}: processed {} of {} sentences", epoch, sentNum,
                    aTrainingData.size());

            if (sentNum >= limit) {
                continue nextEpoch;
            }
        }
    }

    return model;
}
Example #3
Source File: LinearModel.java From FederatedAndroidTrainer with MIT License
@Override
public void train(FederatedDataSet dataSource) {
    DataSet trainingData = (DataSet) dataSource.getNativeDataSet();

    List<DataSet> listDs = trainingData.asList();
    DataSetIterator iterator = new ListDataSetIterator(listDs, BATCH_SIZE);

    // Train the network on the full data set, and evaluate it periodically
    for (int i = 0; i < N_EPOCHS; i++) {
        iterator.reset();
        mNetwork.fit(iterator);
    }
}
Example #4
Source File: LinearModel.java From FederatedAndroidTrainer with MIT License
@Override
public String evaluate(FederatedDataSet federatedDataSet) {
    DataSet testData = (DataSet) federatedDataSet.getNativeDataSet();
    List<DataSet> listDs = testData.asList();
    DataSetIterator iterator = new ListDataSetIterator(listDs, BATCH_SIZE);
    return mNetwork.evaluate(iterator).stats();
}
Example #5
Source File: MNISTModel.java From FederatedAndroidTrainer with MIT License
@Override
public String evaluate(FederatedDataSet federatedDataSet) {
    DataSet testData = (DataSet) federatedDataSet.getNativeDataSet();
    List<DataSet> listDs = testData.asList();
    DataSetIterator iterator = new ListDataSetIterator(listDs, BATCH_SIZE);

    Evaluation eval = new Evaluation(OUTPUT_NUM); // create an evaluation object with 10 possible classes
    while (iterator.hasNext()) {
        DataSet next = iterator.next();
        INDArray output = model.output(next.getFeatureMatrix()); // get the network's prediction
        eval.eval(next.getLabels(), output); // check the prediction against the true class
    }
    return eval.stats();
}
Example #6
Source File: MNISTModel.java From FederatedAndroidTrainer with MIT License
@Override
public void train(FederatedDataSet federatedDataSet) {
    DataSet trainingData = (DataSet) federatedDataSet.getNativeDataSet();
    List<DataSet> listDs = trainingData.asList();
    DataSetIterator mnistTrain = new ListDataSetIterator(listDs, BATCH_SIZE);

    for (int i = 0; i < N_EPOCHS; i++) {
        model.fit(mnistTrain);
    }
}
Example #7
Source File: SingleRegression.java From dl4j-tutorials with MIT License
private static DataSetIterator getTrainingData(int batchSize, Random rand) {
    /*
     * How to construct our training data: the models here are supervised,
     * so the training set must contain features + labels.
     * Features -> x, labels -> y.
     */
    double[] output = new double[nSamples];
    double[] input = new double[nSamples];
    // Randomly generate x values between 0 and 3,
    // and construct y = 0.5x + 0.1
    // a -> 0.5, b -> 0.1
    for (int i = 0; i < nSamples; i++) {
        input[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
        output[i] = 0.5 * input[i] + 0.1;
    }
    /*
     * We have nSamples rows of data; each row has a single x.
     */
    INDArray inputNDArray = Nd4j.create(input, new int[]{nSamples, 1});
    INDArray outPut = Nd4j.create(output, new int[]{nSamples, 1});
    /*
     * Build the data set fed to the neural network.
     * DataSet wraps features + labels into a single class.
     */
    DataSet dataSet = new DataSet(inputNDArray, outPut);
    List<DataSet> listDs = dataSet.asList();
    return new ListDataSetIterator(listDs, batchSize);
}
Example #8
Source File: RegressionMathFunctions.java From dl4j-tutorials with MIT License
/**
 * Create a DataSetIterator for training.
 * @param x X values
 * @param function Function to evaluate
 * @param batchSize Batch size (number of examples for every call of DataSetIterator.next())
 * @param rng Random number generator (for repeatability)
 */
private static DataSetIterator getTrainingData(final INDArray x, final MathFunction function,
                                               final int batchSize, final Random rng) {
    final INDArray y = function.getFunctionValues(x);
    final DataSet allData = new DataSet(x, y);

    final List<DataSet> list = allData.asList();
    Collections.shuffle(list, rng);
    return new ListDataSetIterator(list, batchSize);
}
Example #9
Source File: TestCompGraphCNN.java From deeplearning4j with Apache License 2.0
protected static DataSetIterator getDS() {
    List<DataSet> list = new ArrayList<>(5);
    for (int i = 0; i < 5; i++) {
        INDArray f = Nd4j.create(1, 32 * 32 * 3);
        INDArray l = Nd4j.create(1, 10);
        l.putScalar(i, 1.0);
        list.add(new DataSet(f, l));
    }
    return new ListDataSetIterator(list, 5);
}
Example #10
Source File: NormalizeUciData.java From SKIL_Examples with Apache License 2.0
public void run() throws Exception {
    File trainingOutputFile = new File(trainOutputPath);
    File testOutputFile = new File(testOutputPath);

    if (trainingOutputFile.exists() || testOutputFile.exists()) {
        System.out.println(String.format("Warning: overwriting output files (%s, %s)",
                trainOutputPath, testOutputPath));
        trainingOutputFile.delete();
        testOutputFile.delete();
    }

    System.out.format("downloading from %s\n", downloadUrl);
    System.out.format("writing training output to %s\n", trainOutputPath);
    System.out.format("writing testing output to %s\n", testOutputPath);

    URL url = new URL(downloadUrl);
    String data = IOUtils.toString(url);
    String[] lines = data.split("\n");

    List<INDArray> arrays = new LinkedList<INDArray>();
    List<Integer> labels = new LinkedList<Integer>();

    for (int i = 0; i < lines.length; i++) {
        String line = lines[i];
        String[] cols = line.split("\\s+");

        int label = i / 100;
        INDArray array = Nd4j.zeros(1, 60);

        for (int j = 0; j < cols.length; j++) {
            Double d = Double.parseDouble(cols[j]);
            array.putScalar(0, j, d);
        }

        arrays.add(array);
        labels.add(label);
    }

    // Shuffle with **known** seed
    Collections.shuffle(arrays, new Random(12345));
    Collections.shuffle(labels, new Random(12345));

    INDArray trainData = Nd4j.zeros(450, 60);
    INDArray testData = Nd4j.zeros(150, 60);

    for (int i = 0; i < arrays.size(); i++) {
        INDArray arr = arrays.get(i);

        if (i < 450) { // Training
            trainData.putRow(i, arr);
        } else { // Test
            testData.putRow(i - 450, arr);
        }
    }

    DataSet trainDs = new DataSet(trainData, trainData);
    DataSetIterator trainIt = new ListDataSetIterator(trainDs.asList());

    DataSet testDs = new DataSet(testData, testData);
    DataSetIterator testIt = new ListDataSetIterator(testDs.asList());

    // Fit normalizer on training data only!
    DataNormalization normalizer = dataNormalizer.getNormalizer();
    normalizer.fit(trainIt);

    // Print out basic summary stats
    switch (normalizer.getType()) {
        case STANDARDIZE:
            System.out.format("Normalizer - Standardize:\n  mean=%s\n  std= %s\n",
                    ((NormalizerStandardize) normalizer).getMean(),
                    ((NormalizerStandardize) normalizer).getStd());
    }

    // Use the same normalizer for both
    trainIt.setPreProcessor(normalizer);
    testIt.setPreProcessor(normalizer);

    String trainOutput = toCsv(trainIt, labels.subList(0, 450), new int[]{1, 60});
    String testOutput = toCsv(testIt, labels.subList(450, 600), new int[]{1, 60});

    FileUtils.write(trainingOutputFile, trainOutput);
    System.out.format("wrote normalized training file to %s\n", trainingOutputFile);

    FileUtils.write(testOutputFile, testOutput);
    System.out.format("wrote normalized test file to %s\n", testOutputFile);
}
Example #11
Source File: EvalTest.java From deeplearning4j with Apache License 2.0
@Test
public void testIris() {
    // Network config
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(42)
            .updater(new Sgd(1e-6)).list()
            .layer(0, new DenseLayer.Builder().nIn(4).nOut(2).activation(Activation.TANH)
                    .weightInit(WeightInit.XAVIER).build())
            .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                    LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).weightInit(WeightInit.XAVIER)
                    .activation(Activation.SOFTMAX).build())
            .build();

    // Instantiate model
    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    model.addListeners(new ScoreIterationListener(1));

    // Train-test split
    DataSetIterator iter = new IrisDataSetIterator(150, 150);
    DataSet next = iter.next();
    next.shuffle();
    SplitTestAndTrain trainTest = next.splitTestAndTrain(5, new Random(42));

    // Train
    DataSet train = trainTest.getTrain();
    train.normalizeZeroMeanZeroUnitVariance();

    // Test
    DataSet test = trainTest.getTest();
    test.normalizeZeroMeanZeroUnitVariance();
    INDArray testFeature = test.getFeatures();
    INDArray testLabel = test.getLabels();

    // Fitting model
    model.fit(train);

    // Get predictions from test feature
    INDArray testPredictedLabel = model.output(testFeature);

    // Eval with class number
    org.nd4j.evaluation.classification.Evaluation eval =
            new org.nd4j.evaluation.classification.Evaluation(3); // specify class num here
    eval.eval(testLabel, testPredictedLabel);
    double eval1F1 = eval.f1();
    double eval1Acc = eval.accuracy();

    // Eval without class number
    org.nd4j.evaluation.classification.Evaluation eval2 =
            new org.nd4j.evaluation.classification.Evaluation(); // no class num
    eval2.eval(testLabel, testPredictedLabel);
    double eval2F1 = eval2.f1();
    double eval2Acc = eval2.accuracy();

    // Assert the two implementations give same f1 and accuracy (since one batch)
    assertTrue(eval1F1 == eval2F1 && eval1Acc == eval2Acc);

    org.nd4j.evaluation.classification.Evaluation evalViaMethod =
            model.evaluate(new ListDataSetIterator<>(Collections.singletonList(test)));
    checkEvaluationEquality(eval, evalViaMethod);

    // System.out.println(eval.getConfusionMatrix().toString());
    // System.out.println(eval.getConfusionMatrix().toCSV());
    // System.out.println(eval.getConfusionMatrix().toHTML());
    // System.out.println(eval.confusionToString());
    eval.getConfusionMatrix().toString();
    eval.getConfusionMatrix().toCSV();
    eval.getConfusionMatrix().toHTML();
    eval.confusionToString();
}
Example #12
Source File: TestSparkComputationGraph.java From deeplearning4j with Apache License 2.0
@Test(timeout = 60000L)
public void testEvaluationAndRoc() {
    for (int evalWorkers : new int[]{1, 4, 8}) {
        DataSetIterator iter = new IrisDataSetIterator(5, 150);

        // Make a 2-class version of iris:
        List<DataSet> l = new ArrayList<>();
        iter.reset();
        while (iter.hasNext()) {
            DataSet ds = iter.next();
            INDArray newL = Nd4j.create(ds.getLabels().size(0), 2);
            newL.putColumn(0, ds.getLabels().getColumn(0));
            newL.putColumn(1, ds.getLabels().getColumn(1));
            newL.getColumn(1).addi(ds.getLabels().getColumn(2));
            ds.setLabels(newL);
            l.add(ds);
        }

        iter = new ListDataSetIterator<>(l);

        ComputationGraph cg = getBasicNetIris2Class();

        Evaluation e = cg.evaluate(iter);
        ROC roc = cg.evaluateROC(iter, 32);

        SparkComputationGraph scg = new SparkComputationGraph(sc, cg, null);
        scg.setDefaultEvaluationWorkers(evalWorkers);

        JavaRDD<DataSet> rdd = sc.parallelize(l);
        rdd = rdd.repartition(20);

        Evaluation e2 = scg.evaluate(rdd);
        ROC roc2 = scg.evaluateROC(rdd);

        assertEquals(e2.accuracy(), e.accuracy(), 1e-3);
        assertEquals(e2.f1(), e.f1(), 1e-3);
        assertEquals(e2.getNumRowCounter(), e.getNumRowCounter(), 1e-3);
        assertEquals(e2.falseNegatives(), e.falseNegatives());
        assertEquals(e2.falsePositives(), e.falsePositives());
        assertEquals(e2.trueNegatives(), e.trueNegatives());
        assertEquals(e2.truePositives(), e.truePositives());
        assertEquals(e2.precision(), e.precision(), 1e-3);
        assertEquals(e2.recall(), e.recall(), 1e-3);
        assertEquals(e2.getConfusionMatrix(), e.getConfusionMatrix());

        assertEquals(roc.calculateAUC(), roc2.calculateAUC(), 1e-5);
        assertEquals(roc.calculateAUCPR(), roc2.calculateAUCPR(), 1e-5);
    }
}
Example #13
Source File: LearnDigitsBackprop.java From aifh with Apache License 2.0
/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int nEpochs = 50;
        int batchSize = 500;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(),
                validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        System.out.println(trainingSet.get(0).getFeatures().size(1));
        System.out.println(validationSet.get(0).getFeatures().size(1));

        int numInputs = trainingReader.getNumCols() * trainingReader.getNumRows();
        int numOutputs = 10;
        int numHiddenNodes = 200;

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .regularization(true).dropOut(0.50)
                .list(2)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .weightInit(WeightInit.XAVIER)
                        .activation("relu")
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                //.epochTerminationConditions(new MaxEpochsTerminationCondition(10))
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) // Calculate test set score
                .modelSaver(saver)
                .build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }

        // Print the evaluation statistics
        System.out.println(eval.stats());
    } catch (Exception ex) {
        ex.printStackTrace();
    }
}
Example #14
Source File: LearnDigitsDropout.java From aifh with Apache License 2.0
/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int nEpochs = 50;
        int batchSize = 500;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(),
                validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        System.out.println(trainingSet.get(0).getFeatures().size(1));
        System.out.println(validationSet.get(0).getFeatures().size(1));

        int numInputs = trainingReader.getNumCols() * trainingReader.getNumRows();
        int numOutputs = 10;
        int numHiddenNodes = 100;

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list(2)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .weightInit(WeightInit.XAVIER)
                        .activation("relu")
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                //.epochTerminationConditions(new MaxEpochsTerminationCondition(10))
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) // Calculate test set score
                .modelSaver(saver)
                .build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }

        // Print the evaluation statistics
        System.out.println(eval.stats());
    } catch (Exception ex) {
        ex.printStackTrace();
    }
}
Example #15
Source File: LearnIrisBackprop.java From aifh with Apache License 2.0
/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 0.1;
        int splitTrainNum = (int) (150 * .75);
        int numInputs = 4;
        int numOutputs = 3;
        int numHiddenNodes = 50;

        // Setup training data.
        final InputStream istream = LearnIrisBackprop.class.getResourceAsStream("/iris.csv");
        if (istream == null) {
            System.out.println("Cannot access data set, make sure the resources are available.");
            System.exit(1);
        }
        final NormalizeDataSet ds = NormalizeDataSet.load(istream);
        final CategoryMap species = ds.encodeOneOfN(4); // species is column 4
        istream.close();

        DataSet next = ds.extractSupervised(0, 4, 4, 3);
        next.shuffle();

        // Training and validation data split
        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(splitTrainNum, new Random(seed));
        DataSet trainSet = testAndTrain.getTrain();
        DataSet validationSet = testAndTrain.getTest();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainSet.asList(),
                trainSet.numExamples());
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(),
                validationSet.numExamples());

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list(2)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .weightInit(WeightInit.XAVIER)
                        .activation("relu")
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(new MaxEpochsTerminationCondition(500)) // Max of 500 epochs
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(25))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) // Calculate test set score
                .modelSaver(saver)
                .build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            System.out.println(features + ":Prediction(" + findSpecies(predicted, species)
                    + "):Actual(" + findSpecies(labels, species) + ")" + predicted);
            eval.eval(labels, predicted);
        }

        // Print the evaluation statistics
        System.out.println(eval.stats());
    } catch (Exception ex) {
        ex.printStackTrace();
    }
}