org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator Java Examples

The following examples show how to use org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator. You can vote each example up or down, and follow the link above it to the original project or source file. Related API usage is listed in the sidebar.
Example #1
Source File: MultiRegression.java    From dl4j-tutorials with MIT License 6 votes vote down vote up
/**
 * Builds a training iterator for learning the sum of two inputs.
 *
 * <p>Draws {@code nSamples} input pairs from {@code MIN_RANGE + (MAX_RANGE - MIN_RANGE) * u}
 * with {@code u = rand.nextDouble()}, labels each pair with {@code x1 + x2}, and wraps the
 * resulting one-example data sets in a list-backed iterator.
 *
 * @param batchSize number of examples returned per {@code next()} call
 * @param rand source of randomness for the sampled inputs
 * @return an iterator over the generated examples
 */
private static DataSetIterator getTrainingData(int batchSize, Random rand) {
    final double span = MAX_RANGE - MIN_RANGE;
    double[] firstInputs = new double[nSamples];
    double[] secondInputs = new double[nSamples];
    double[] targets = new double[nSamples];

    int idx = 0;
    while (idx < nSamples) {
        double a = MIN_RANGE + span * rand.nextDouble();
        double b = MIN_RANGE + span * rand.nextDouble();
        firstInputs[idx] = a;
        secondInputs[idx] = b;
        targets[idx] = a + b;
        idx++;
    }

    // Two single-column feature matrices placed side by side -> nSamples x 2 features.
    INDArray features = Nd4j.hstack(
            Nd4j.create(firstInputs, new int[]{nSamples, 1}),
            Nd4j.create(secondInputs, new int[]{nSamples, 1}));
    INDArray labels = Nd4j.create(targets, new int[]{nSamples, 1});

    List<DataSet> examples = new DataSet(features, labels).asList();
    return new ListDataSetIterator(examples, batchSize);
}
 
Example #2
Source File: DL4JSequenceRecommender.java    From inception with Apache License 2.0 5 votes vote down vote up
/**
 * Trains a freshly configured network on {@code aTrainingData}.
 *
 * <p>First vectorizing all sentences and then passing them to the model would consume huge
 * amounts of memory. Instead, each sentence is vectorized just before it is added to a
 * mini-batch of up to {@code traits.getBatchSize()} sentences, and each mini-batch is fitted
 * immediately. At most {@code traits.getTrainingSetSizeLimit()} sentences are used per epoch.
 *
 * @param aTrainingData the training samples
 * @param aTagset tag mapping passed through to {@code vectorize}
 * @return the trained network
 * @throws IOException declared by this method; presumably raised by network creation or
 *         vectorization — TODO confirm
 * @throws IllegalStateException if the configured batch size is less than 1, which would
 *         otherwise make the batching loop spin forever on empty batches
 */
private MultiLayerNetwork train(List<Sample> aTrainingData, Object2IntMap<String> aTagset)
    throws IOException
{
    // Configure the neural network
    MultiLayerNetwork model = createConfiguredNetwork(traits, wordVectors.dimensions());

    final int limit = traits.getTrainingSetSizeLimit();
    final int batchSize = traits.getBatchSize();
    if (batchSize < 1) {
        throw new IllegalStateException("Batch size must be at least 1 but was " + batchSize);
    }

    for (int epoch = 0; epoch < traits.getnEpochs(); epoch++) {
        int sentNum = 0;
        Iterator<Sample> sampleIterator = aTrainingData.iterator();
        // End the epoch when the data is exhausted or the size limit is reached. Checking the
        // limit here (rather than after fitting) also avoids fitting an empty batch when the
        // limit is <= 0.
        while (sampleIterator.hasNext() && sentNum < limit) {
            List<DataSet> batch = new ArrayList<>();
            while (sampleIterator.hasNext() && batch.size() < batchSize && sentNum < limit) {
                Sample sample = sampleIterator.next();
                DataSet trainingData = vectorize(asList(sample), aTagset, true);
                batch.add(trainingData);
                sentNum++;
            }

            model.fit(new ListDataSetIterator<DataSet>(batch, batch.size()));
            log.trace("Epoch {}: processed {} of {} sentences", epoch, sentNum,
                    aTrainingData.size());
        }
    }

    return model;
}
 
Example #3
Source File: LinearModel.java    From FederatedAndroidTrainer with MIT License 5 votes vote down vote up
/**
 * Trains the network with {@code N_EPOCHS} full passes over the supplied data.
 *
 * <p>No evaluation happens here; the iterator is rewound before every pass so that each
 * epoch sees the complete data set.
 *
 * @param dataSource wrapper whose native data set is cast to a DL4J {@link DataSet}
 */
@Override
public void train(FederatedDataSet dataSource) {
    final DataSetIterator batches =
            new ListDataSetIterator(((DataSet) dataSource.getNativeDataSet()).asList(), BATCH_SIZE);

    for (int epoch = 0; epoch < N_EPOCHS; epoch++) {
        batches.reset();
        mNetwork.fit(batches);
    }
}
 
Example #4
Source File: LinearModel.java    From FederatedAndroidTrainer with MIT License 5 votes vote down vote up
/**
 * Evaluates the network on the supplied data.
 *
 * @param federatedDataSet wrapper whose native data set is cast to a DL4J {@link DataSet}
 * @return the text summary produced by the evaluation's {@code stats()} method
 */
@Override
public String evaluate(FederatedDataSet federatedDataSet) {
    List<DataSet> examples = ((DataSet) federatedDataSet.getNativeDataSet()).asList();
    return mNetwork.evaluate(new ListDataSetIterator(examples, BATCH_SIZE)).stats();
}
 
Example #5
Source File: MNISTModel.java    From FederatedAndroidTrainer with MIT License 5 votes vote down vote up
/**
 * Evaluates the model on the supplied data batch by batch and returns the summary.
 *
 * @param federatedDataSet wrapper whose native data set is cast to a DL4J {@link DataSet}
 * @return the text produced by {@link Evaluation#stats()}
 */
@Override
public String evaluate(FederatedDataSet federatedDataSet) {
    List<DataSet> examples = ((DataSet) federatedDataSet.getNativeDataSet()).asList();
    DataSetIterator batches = new ListDataSetIterator(examples, BATCH_SIZE);

    // A single Evaluation accumulates results across all batches; OUTPUT_NUM is the number
    // of classes (the original comment says 10 — presumably OUTPUT_NUM == 10).
    Evaluation eval = new Evaluation(OUTPUT_NUM);
    while (batches.hasNext()) {
        DataSet batch = batches.next();
        INDArray prediction = model.output(batch.getFeatureMatrix());
        eval.eval(batch.getLabels(), prediction);
    }
    return eval.stats();
}
 
Example #6
Source File: MNISTModel.java    From FederatedAndroidTrainer with MIT License 5 votes vote down vote up
/**
 * Trains the model with {@code N_EPOCHS} full passes over the supplied data.
 *
 * @param federatedDataSet wrapper whose native data set is cast to a DL4J {@link DataSet}
 */
@Override
public void train(FederatedDataSet federatedDataSet) {
    DataSet trainingData = (DataSet) federatedDataSet.getNativeDataSet();
    List<DataSet> listDs = trainingData.asList();
    DataSetIterator mnistTrain = new ListDataSetIterator(listDs, BATCH_SIZE);
    for (int i = 0; i < N_EPOCHS; i++) {
        // Rewind explicitly so every epoch sees the full data set. Without this, every epoch
        // after the first depends on fit() resetting an already-exhausted iterator itself.
        mnistTrain.reset();
        model.fit(mnistTrain);
    }
}
 
Example #7
Source File: SingleRegression.java    From dl4j-tutorials with MIT License 5 votes vote down vote up
/**
 * Builds the training iterator for a single-variable linear regression.
 *
 * <p>The models used here are trained with supervised learning, so every training example
 * needs both a feature (x) and a label (y). Each x is drawn from
 * {@code MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble()} (the original notes describe
 * this as "between 0 and 3"), and its label is {@code y = 0.5x + 0.1}, i.e. slope a = 0.5 and
 * intercept b = 0.1.
 *
 * @param batchSize number of examples returned per {@code next()} call
 * @param rand source of randomness for the sampled x values
 * @return an iterator over {@code nSamples} single-example data sets
 */
private static DataSetIterator getTrainingData(int batchSize, Random rand) {
    double[] xs = new double[nSamples];
    double[] ys = new double[nSamples];
    for (int k = 0; k < xs.length; k++) {
        double x = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
        xs[k] = x;
        ys[k] = 0.5 * x + 0.1;
    }

    // nSamples rows, each holding exactly one x (feature) and one y (label).
    INDArray features = Nd4j.create(xs, new int[]{nSamples, 1});
    INDArray labels = Nd4j.create(ys, new int[]{nSamples, 1});

    // DataSet bundles features + labels into one object; asList() splits it per example.
    List<DataSet> examples = new DataSet(features, labels).asList();
    return new ListDataSetIterator(examples, batchSize);
}
 
Example #8
Source File: RegressionMathFunctions.java    From dl4j-tutorials with MIT License 5 votes vote down vote up
/**
 * Creates a shuffled training iterator by evaluating {@code function} at the points {@code x}.
 *
 * @param x X values
 * @param function function whose values at {@code x} become the labels
 * @param batchSize batch size (number of examples for every call of DataSetIterator.next())
 * @param rng random number generator used to shuffle the examples (for repeatability)
 * @return an iterator over the single-example data sets, in shuffled order
 */
private static DataSetIterator getTrainingData(final INDArray x, final MathFunction function, final int batchSize, final Random rng) {
    final List<DataSet> examples = new DataSet(x, function.getFunctionValues(x)).asList();
    Collections.shuffle(examples, rng);
    return new ListDataSetIterator(examples, batchSize);
}
 
Example #9
Source File: TestCompGraphCNN.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Builds a tiny fixed data set and returns it as one batch.
 *
 * <p>There are five examples. Each has a 1 x (32*32*3) feature row and a 1 x 10 label row;
 * example {@code k} has label column {@code k} set to 1.0.
 *
 * @return an iterator yielding all five examples in a single batch
 */
protected static DataSetIterator getDS() {
    final int count = 5;
    List<DataSet> examples = new ArrayList<>(count);
    for (int k = 0; k < count; k++) {
        INDArray features = Nd4j.create(1, 32 * 32 * 3);
        INDArray oneHot = Nd4j.create(1, 10);
        oneHot.putScalar(k, 1.0);
        examples.add(new DataSet(features, oneHot));
    }
    return new ListDataSetIterator(examples, count);
}
 
Example #10
Source File: NormalizeUciData.java    From SKIL_Examples with Apache License 2.0 4 votes vote down vote up
/**
 * Downloads the raw data, splits it into training and test partitions, fits a normalizer on
 * the training partition only, and writes both normalized partitions as CSV.
 *
 * <p>Each non-blank input line must contain {@code numFeatures} whitespace-separated numbers.
 * A line's class label is its position divided by {@code rowsPerClass}, so the input is
 * assumed to be grouped into consecutive blocks of 100 rows per class — TODO confirm against
 * the data source.
 *
 * @throws IllegalStateException if the downloaded data does not have the expected shape
 * @throws Exception if downloading or writing fails
 */
public void run() throws Exception {
    // Expected layout: 600 rows of 60 values, i.e. 6 labels (0..5) of 100 rows each.
    final int numFeatures = 60;
    final int rowsPerClass = 100;
    final int numTrain = 450;
    final int numTest = 150;
    final int numRows = numTrain + numTest;

    File trainingOutputFile = new File(trainOutputPath);
    File testOutputFile = new File(testOutputPath);

    if (trainingOutputFile.exists() || testOutputFile.exists()) {
        System.out.println(String.format("Warning: overwriting output files (%s, %s)", trainOutputPath, testOutputPath));

        trainingOutputFile.delete();
        testOutputFile.delete();
    }

    System.out.format("downloading from %s\n", downloadUrl);
    System.out.format("writing training output to %s\n", trainOutputPath);
    System.out.format("writing testing output to %s\n", testOutputPath);

    URL url = new URL(downloadUrl);
    String data = IOUtils.toString(url);

    // ArrayList rather than LinkedList: rows are read back by index below.
    List<INDArray> arrays = new ArrayList<INDArray>();
    List<Integer> labels = new ArrayList<Integer>();

    for (String rawLine : data.split("\n")) {
        // Trim so an indented line does not yield an empty leading token from split("\\s+")
        // (which parseDouble rejects), and so CRLF line endings are tolerated.
        String line = rawLine.trim();
        if (line.isEmpty()) {
            continue; // e.g. a trailing blank line
        }

        String[] cols = line.split("\\s+");
        if (cols.length != numFeatures) {
            throw new IllegalStateException(String.format(
                    "Expected %d values per line but data row %d has %d",
                    numFeatures, arrays.size(), cols.length));
        }

        INDArray array = Nd4j.zeros(1, numFeatures);
        for (int j = 0; j < cols.length; j++) {
            array.putScalar(0, j, Double.parseDouble(cols[j]));
        }

        labels.add(arrays.size() / rowsPerClass);
        arrays.add(array);
    }

    if (arrays.size() != numRows) {
        throw new IllegalStateException(String.format(
                "Expected %d data rows but found %d", numRows, arrays.size()));
    }

    // Shuffle with **known** seed. Both lists have the same size and get identically seeded
    // Randoms, so Collections.shuffle applies the same permutation to each and every feature
    // row stays paired with its label.
    Collections.shuffle(arrays, new Random(12345));
    Collections.shuffle(labels, new Random(12345));

    INDArray trainData = Nd4j.zeros(numTrain, numFeatures);
    INDArray testData = Nd4j.zeros(numTest, numFeatures);

    for (int i = 0; i < numRows; i++) {
        INDArray arr = arrays.get(i);

        if (i < numTrain) { // Training
            trainData.putRow(i, arr);
        } else { // Test
            testData.putRow(i - numTrain, arr);
        }
    }

    // The features double as labels in these DataSets; the class labels are passed to toCsv
    // separately.
    DataSet trainDs = new DataSet(trainData, trainData);
    DataSetIterator trainIt = new ListDataSetIterator(trainDs.asList());

    DataSet testDs = new DataSet(testData, testData);
    DataSetIterator testIt = new ListDataSetIterator(testDs.asList());

    // Fit normalizer on training data only!
    DataNormalization normalizer = dataNormalizer.getNormalizer();
    normalizer.fit(trainIt);
    // fit() iterates trainIt; rewind so the CSV export below starts at the first batch
    // (harmless if fit() or toCsv() already resets it).
    trainIt.reset();

    // Print out basic summary stats
    switch (normalizer.getType()) {
        case STANDARDIZE:
            System.out.format("Normalizer - Standardize:\n  mean=%s\n  std= %s\n",
                    ((NormalizerStandardize)normalizer).getMean(),
                    ((NormalizerStandardize)normalizer).getStd());
            break;
        default:
            // No summary is printed for other normalizer types.
            break;
    }

    // Use same normalizer for both
    trainIt.setPreProcessor(normalizer);
    testIt.setPreProcessor(normalizer);

    String trainOutput = toCsv(trainIt, labels.subList(0, numTrain), new int[]{1, numFeatures});
    String testOutput = toCsv(testIt, labels.subList(numTrain, numRows), new int[]{1, numFeatures});

    FileUtils.write(trainingOutputFile, trainOutput);
    System.out.format("wrote normalized training file to %s\n", trainingOutputFile);

    FileUtils.write(testOutputFile, testOutput);
    System.out.format("wrote normalized test file to %s\n", testOutputFile);
}
 
Example #11
Source File: EvalTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test
    public void testIris() {
        // Small network: 4 inputs -> 2 tanh hidden units -> 3-way softmax output (MCXENT loss).
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                        .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
                        .seed(42)
                        .updater(new Sgd(1e-6))
                        .list()
                        .layer(0, new DenseLayer.Builder()
                                        .nIn(4).nOut(2)
                                        .activation(Activation.TANH)
                                        .weightInit(WeightInit.XAVIER)
                                        .build())
                        .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                                        LossFunctions.LossFunction.MCXENT)
                                        .nIn(2).nOut(3)
                                        .weightInit(WeightInit.XAVIER)
                                        .activation(Activation.SOFTMAX)
                                        .build())
                        .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.addListeners(new ScoreIterationListener(1));

        // Pull all 150 Iris examples as one batch, shuffle them, then split train/test.
        DataSet all = new IrisDataSetIterator(150, 150).next();
        all.shuffle();
        SplitTestAndTrain split = all.splitTestAndTrain(5, new Random(42));

        DataSet train = split.getTrain();
        train.normalizeZeroMeanZeroUnitVariance();

        // NOTE(review): the test split is normalized with its own statistics, not the
        // training split's; fine for this equality check, not for measuring accuracy.
        DataSet test = split.getTest();
        test.normalizeZeroMeanZeroUnitVariance();
        INDArray testFeature = test.getFeatures();
        INDArray testLabel = test.getLabels();

        model.fit(train);
        INDArray testPredictedLabel = model.output(testFeature);

        // Same predictions evaluated twice: once with an explicit class count, once without.
        org.nd4j.evaluation.classification.Evaluation withClassCount =
                        new org.nd4j.evaluation.classification.Evaluation(3);
        withClassCount.eval(testLabel, testPredictedLabel);
        org.nd4j.evaluation.classification.Evaluation inferredClassCount =
                        new org.nd4j.evaluation.classification.Evaluation();
        inferredClassCount.eval(testLabel, testPredictedLabel);

        // A single batch was evaluated, so both must agree exactly.
        assertTrue(withClassCount.f1() == inferredClassCount.f1()
                        && withClassCount.accuracy() == inferredClassCount.accuracy());

        // Evaluating through the model/iterator path must match the manual evaluation.
        org.nd4j.evaluation.classification.Evaluation viaModel =
                        model.evaluate(new ListDataSetIterator<>(Collections.singletonList(test)));
        checkEvaluationEquality(withClassCount, viaModel);

        // Exercise the confusion-matrix renderers; their output is intentionally discarded.
        withClassCount.getConfusionMatrix().toString();
        withClassCount.getConfusionMatrix().toCSV();
        withClassCount.getConfusionMatrix().toHTML();
        withClassCount.confusionToString();
    }
 
Example #12
Source File: TestSparkComputationGraph.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test(timeout = 60000L)
public void testEvaluationAndRoc() {
    for (int evalWorkers : new int[]{1, 4, 8}) {
        // Collapse Iris into a 2-class problem: column 0 is kept, columns 1 and 2 are summed.
        List<DataSet> twoClass = new ArrayList<>();
        DataSetIterator irisIter = new IrisDataSetIterator(5, 150);
        irisIter.reset();
        while (irisIter.hasNext()) {
            DataSet ds = irisIter.next();
            INDArray oldLabels = ds.getLabels();
            INDArray newL = Nd4j.create(oldLabels.size(0), 2);
            newL.putColumn(0, oldLabels.getColumn(0));
            newL.putColumn(1, oldLabels.getColumn(1));
            newL.getColumn(1).addi(oldLabels.getColumn(2));
            ds.setLabels(newL);
            twoClass.add(ds);
        }

        // Local (single-JVM) reference results.
        DataSetIterator iter = new ListDataSetIterator<>(twoClass);
        ComputationGraph cg = getBasicNetIris2Class();
        Evaluation e = cg.evaluate(iter);
        ROC roc = cg.evaluateROC(iter, 32);

        // Distributed results over the same data, spread across 20 partitions.
        SparkComputationGraph scg = new SparkComputationGraph(sc, cg, null);
        scg.setDefaultEvaluationWorkers(evalWorkers);
        JavaRDD<DataSet> rdd = sc.parallelize(twoClass).repartition(20);
        Evaluation e2 = scg.evaluate(rdd);
        ROC roc2 = scg.evaluateROC(rdd);

        // Spark and local evaluation must agree.
        assertEquals(e2.accuracy(), e.accuracy(), 1e-3);
        assertEquals(e2.f1(), e.f1(), 1e-3);
        assertEquals(e2.getNumRowCounter(), e.getNumRowCounter(), 1e-3);
        assertEquals(e2.falseNegatives(), e.falseNegatives());
        assertEquals(e2.falsePositives(), e.falsePositives());
        assertEquals(e2.trueNegatives(), e.trueNegatives());
        assertEquals(e2.truePositives(), e.truePositives());
        assertEquals(e2.precision(), e.precision(), 1e-3);
        assertEquals(e2.recall(), e.recall(), 1e-3);
        assertEquals(e2.getConfusionMatrix(), e.getConfusionMatrix());

        assertEquals(roc.calculateAUC(), roc2.calculateAUC(), 1e-5);
        assertEquals(roc.calculateAUCPR(), roc2.calculateAUCPR(), 1e-5);
    }
}
 
Example #13
Source File: LearnDigitsBackprop.java    From aifh with Apache License 2.0 4 votes vote down vote up
/**
 * The main method. Trains an MNIST digit classifier with early stopping and prints evaluation
 * statistics for the best model found.
 *
 * <p>The network has one ReLU hidden layer (with {@code dropOut(0.50)} configured) and a
 * softmax output layer. Early stopping uses
 * {@code ScoreImprovementEpochTerminationCondition(5)}, scored by the loss on the validation
 * set.
 *
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int batchSize = 500;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        // NOTE(review): getNumRows() is the image height (see numInputs below), not the number
        // of validation examples; it only sets the batch size used for scoring — confirm intent.
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        System.out.println(trainingSet.get(0).getFeatures().size(1));
        System.out.println(validationSet.get(0).getFeatures().size(1));

        int numInputs = trainingReader.getNumCols()*trainingReader.getNumRows();
        int numOutputs = 10;
        int numHiddenNodes = 200;

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .regularization(true).dropOut(0.50)
                .list(2)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .weightInit(WeightInit.XAVIER)
                        .activation("relu")
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();


        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true))     //Calculate test set score
                .modelSaver(saver)
                .build();
        // Train the network built above rather than a fresh one created from conf, so the
        // ScoreIterationListener attached to it actually reports scores during training.
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, model, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate batch by batch; Evaluation accumulates counts, so the result matches
        // evaluating one example at a time.
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();
        while (validationSetIterator.hasNext()) {
            DataSet batch = validationSetIterator.next();
            INDArray predicted = model.output(batch.getFeatureMatrix(), false);
            eval.eval(batch.getLabels(), predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch(Exception ex) {
        ex.printStackTrace();
    }

}
 
Example #14
Source File: LearnDigitsDropout.java    From aifh with Apache License 2.0 4 votes vote down vote up
/**
 * The main method. Trains an MNIST digit classifier that uses dropout, with early stopping,
 * and prints evaluation statistics for the best model found.
 *
 * <p>The network has one ReLU hidden layer and a softmax output layer. Early stopping uses
 * {@code ScoreImprovementEpochTerminationCondition(5)}, scored by the loss on the validation
 * set.
 *
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int batchSize = 500;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        // NOTE(review): getNumRows() is the image height (see numInputs below), not the number
        // of validation examples; it only sets the batch size used for scoring — confirm intent.
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        System.out.println(trainingSet.get(0).getFeatures().size(1));
        System.out.println(validationSet.get(0).getFeatures().size(1));

        int numInputs = trainingReader.getNumCols()*trainingReader.getNumRows();
        int numOutputs = 10;
        int numHiddenNodes = 100;

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9)
                // Dropout is the technique this example demonstrates; without these two calls
                // the network is trained with no dropout at all.
                .regularization(true).dropOut(0.50)
                .list(2)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .weightInit(WeightInit.XAVIER)
                        .activation("relu")
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();


        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true))     //Calculate test set score
                .modelSaver(saver)
                .build();
        // Train the network built above rather than a fresh one created from conf, so the
        // ScoreIterationListener attached to it actually reports scores during training.
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, model, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate batch by batch; Evaluation accumulates counts, so the result matches
        // evaluating one example at a time.
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();
        while (validationSetIterator.hasNext()) {
            DataSet batch = validationSetIterator.next();
            INDArray predicted = model.output(batch.getFeatureMatrix(), false);
            eval.eval(batch.getLabels(), predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch(Exception ex) {
        ex.printStackTrace();
    }

}
 
Example #15
Source File: LearnIrisBackprop.java    From aifh with Apache License 2.0 4 votes vote down vote up
/**
 * The main method. Trains an Iris species classifier with early stopping, then prints each
 * validation example with its predicted and actual species, followed by evaluation statistics.
 *
 * <p>Training stops after at most 500 epochs, or earlier via
 * {@code ScoreImprovementEpochTerminationCondition(25)}, scored by the validation-set loss.
 *
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 0.1;
        int splitTrainNum = (int) (150 * .75);

        int numInputs = 4;
        int numOutputs = 3;
        int numHiddenNodes = 50;

        // Setup training data. try-with-resources closes the stream even if loading fails.
        final NormalizeDataSet ds;
        final CategoryMap species;
        try (InputStream istream = LearnIrisBackprop.class.getResourceAsStream("/iris.csv")) {
            if( istream==null ) {
                System.out.println("Cannot access data set, make sure the resources are available.");
                System.exit(1);
                return; // not reached; lets the compiler see ds/species as definitely assigned
            }
            ds = NormalizeDataSet.load(istream);
            species = ds.encodeOneOfN(4); // species is column 4
        }

        DataSet next = ds.extractSupervised(0, 4, 4, 3);
        next.shuffle();

        // Training and validation data split
        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(splitTrainNum, new Random(seed));
        DataSet trainSet = testAndTrain.getTrain();
        DataSet validationSet = testAndTrain.getTest();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainSet.asList(), trainSet.numExamples());

        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationSet.numExamples());

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list(2)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .weightInit(WeightInit.XAVIER)
                        .activation("relu")
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();


        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training. Both conditions go in ONE call: calling
        // epochTerminationConditions(...) a second time replaces the first set of conditions.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(
                        new MaxEpochsTerminationCondition(500), //Max of 500 epochs
                        new ScoreImprovementEpochTerminationCondition(25))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true))     //Calculate test set score
                .modelSaver(saver)
                .build();
        // Train the network built above rather than a fresh one created from conf, so the
        // ScoreIterationListener attached to it actually reports scores during training.
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, model, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate one example at a time so each prediction can be printed.
        Evaluation eval = new Evaluation(numOutputs);

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            // The network output is the prediction; the labels are the actual species.
            System.out.println(features + ":Prediction("+findSpecies(predicted,species)
                    +"):Actual("+findSpecies(labels,species)+")" + predicted );
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch(Exception ex) {
        ex.printStackTrace();
    }
}