Java Code Examples for org.nd4j.linalg.dataset.SplitTestAndTrain#getTrain()
The following examples show how to use
org.nd4j.linalg.dataset.SplitTestAndTrain#getTrain().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: IrisFileDataSource.java From FederatedAndroidTrainer with MIT License | 7 votes |
private void createDataSource() throws IOException, InterruptedException { //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing int numLinesToSkip = 0; String delimiter = ","; RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter); recordReader.initialize(new InputStreamInputSplit(dataFile)); //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network int labelIndex = 4; //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row int numClasses = 3; //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2 DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses); DataSet allData = iterator.next(); allData.shuffle(); SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80); //Use 80% of data for training trainingData = testAndTrain.getTrain(); testData = testAndTrain.getTest(); //We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance): DataNormalization normalizer = new NormalizerStandardize(); normalizer.fit(trainingData); //Collect the statistics (mean/stdev) from the training data. This does not modify the input data normalizer.transform(trainingData); //Apply normalization to the training data normalizer.transform(testData); //Apply normalization to the test data. This is using statistics calculated from the *training* set }
Example 2
Source File: DiabetesFileDataSource.java From FederatedAndroidTrainer with MIT License | 6 votes |
private void createDataSource() throws IOException, InterruptedException { //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing int numLinesToSkip = 0; String delimiter = ","; RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter); recordReader.initialize(new InputStreamInputSplit(dataFile)); //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network int labelIndex = 11; DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true); DataSet allData = iterator.next(); SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80); //Use 80% of data for training trainingData = testAndTrain.getTrain(); testData = testAndTrain.getTest(); //We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance): DataNormalization normalizer = new NormalizerStandardize(); normalizer.fit(trainingData); //Collect the statistics (mean/stdev) from the training data. This does not modify the input data normalizer.transform(trainingData); //Apply normalization to the training data normalizer.transform(testData); //Apply normalization to the test data. This is using statistics calculated from the *training* set }
Example 3
Source File: DeepLearning4J_CSV_Iris_Model.java From kafka-streams-machine-learning-examples with Apache License 2.0 | 4 votes |
public static void main(String[] args) throws Exception { // First: get the dataset using the record reader. CSVRecordReader handles // loading/parsing int numLinesToSkip = 0; char delimiter = ','; RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter); recordReader.initialize(new FileSplit(new ClassPathResource("DL4J_Resources/iris.txt").getFile())); // Second: the RecordReaderDataSetIterator handles conversion to DataSet // objects, ready for use in neural network int labelIndex = 4; // 5 values in each row of the iris.txt CSV: 4 input features followed by an // integer label (class) index. Labels are the 5th value (index 4) in each row int numClasses = 3; // 3 classes (types of iris flowers) in the iris data set. Classes have integer // values 0, 1 or 2 int batchSize = 150; // Iris data set: 150 examples total. We are loading all of them into one // DataSet (not recommended for large data sets) DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses); DataSet allData = iterator.next(); allData.shuffle(); SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65); // Use 65% of data for training DataSet trainingData = testAndTrain.getTrain(); DataSet testData = testAndTrain.getTest(); // We need to normalize our data. We'll use NormalizeStandardize (which gives us // mean 0, unit variance): DataNormalization normalizer = new NormalizerStandardize(); normalizer.fit(trainingData); // Collect the statistics (mean/stdev) from the training data. This does not // modify the input data normalizer.transform(trainingData); // Apply normalization to the training data normalizer.transform(testData); // Apply normalization to the test data. 
This is using statistics calculated // from the *training* set final int numInputs = 4; int outputNum = 3; long seed = 6; log.info("Build model...."); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).activation(Activation.TANH) .weightInit(WeightInit.XAVIER).updater(new Sgd(0.1)).l2(1e-4).list() .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3).build()) .layer(1, new DenseLayer.Builder().nIn(3).nOut(3).build()) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .activation(Activation.SOFTMAX).nIn(3).nOut(outputNum).build()) .build(); // run the model MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); model.setListeners(new ScoreIterationListener(100)); for (int i = 0; i < 1000; i++) { model.fit(trainingData); } // evaluate the model on the test set Evaluation eval = new Evaluation(3); INDArray input = testData.getFeatures(); INDArray output = model.output(input); System.out.println("INPUT:" + input.toString()); eval.eval(testData.getLabels(), output); log.info(eval.stats()); // Save the model File locationToSave = new File("src/main/resources/generatedModels/DL4J/DL4J_Iris_Model.zip"); // Where to save // the network. // Note: the file // is in .zip // format - can // be opened // externally boolean saveUpdater = true; // Updater: i.e., the state for Momentum, RMSProp, Adagrad etc. Save this if you // want to train your network more in the future // ModelSerializer.writeModel(model, locationToSave, saveUpdater); // Load the model MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(locationToSave); System.out.println("Saved and loaded parameters are equal: " + model.params().equals(restored.params())); System.out.println("Saved and loaded configurations are equal: " + model.getLayerWiseConfigurations().equals(restored.getLayerWiseConfigurations())); }
Example 4
Source File: EvalTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testIris() { // Network config MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(42) .updater(new Sgd(1e-6)).list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(2).activation(Activation.TANH) .weightInit(WeightInit.XAVIER).build()) .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).weightInit(WeightInit.XAVIER) .activation(Activation.SOFTMAX).build()) .build(); // Instantiate model MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); model.addListeners(new ScoreIterationListener(1)); // Train-test split DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSet next = iter.next(); next.shuffle(); SplitTestAndTrain trainTest = next.splitTestAndTrain(5, new Random(42)); // Train DataSet train = trainTest.getTrain(); train.normalizeZeroMeanZeroUnitVariance(); // Test DataSet test = trainTest.getTest(); test.normalizeZeroMeanZeroUnitVariance(); INDArray testFeature = test.getFeatures(); INDArray testLabel = test.getLabels(); // Fitting model model.fit(train); // Get predictions from test feature INDArray testPredictedLabel = model.output(testFeature); // Eval with class number org.nd4j.evaluation.classification.Evaluation eval = new org.nd4j.evaluation.classification.Evaluation(3); //// Specify class num here eval.eval(testLabel, testPredictedLabel); double eval1F1 = eval.f1(); double eval1Acc = eval.accuracy(); // Eval without class number org.nd4j.evaluation.classification.Evaluation eval2 = new org.nd4j.evaluation.classification.Evaluation(); //// No class num eval2.eval(testLabel, testPredictedLabel); double eval2F1 = eval2.f1(); double eval2Acc = eval2.accuracy(); //Assert the two implementations give same f1 and accuracy (since one batch) assertTrue(eval1F1 == eval2F1 && eval1Acc == eval2Acc); org.nd4j.evaluation.classification.Evaluation evalViaMethod = 
model.evaluate(new ListDataSetIterator<>(Collections.singletonList(test))); checkEvaluationEquality(eval, evalViaMethod); // System.out.println(eval.getConfusionMatrix().toString()); // System.out.println(eval.getConfusionMatrix().toCSV()); // System.out.println(eval.getConfusionMatrix().toHTML()); // System.out.println(eval.confusionToString()); eval.getConfusionMatrix().toString(); eval.getConfusionMatrix().toCSV(); eval.getConfusionMatrix().toHTML(); eval.confusionToString(); }
Example 5
Source File: IrisClassifier.java From tutorials with MIT License | 4 votes |
/**
 * Trains a small feed-forward classifier on the iris data set and prints evaluation statistics
 * for a held-out test split.
 *
 * <p>The data is shuffled with a fixed seed and split 65/35 <em>before</em> normalization. The
 * normalizer is fitted on the training part only and then applied to both parts, so no
 * statistics from the test examples leak into preprocessing.
 *
 * @param args not used
 * @throws IOException if the iris resource cannot be read
 * @throws InterruptedException if initializing the record reader is interrupted
 */
public static void main(String[] args) throws IOException, InterruptedException {
    DataSet allData;
    try (RecordReader recordReader = new CSVRecordReader(0, ',')) {
        recordReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
        // A batch size of 150 pulls the data into a single DataSet.
        DataSetIterator iterator =
                new RecordReaderDataSetIterator(recordReader, 150, FEATURES_COUNT, CLASSES_COUNT);
        allData = iterator.next();
    }
    allData.shuffle(42);

    // Split first: 65% of the examples for training, the rest for testing.
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);
    DataSet trainingData = testAndTrain.getTrain();
    DataSet testData = testAndTrain.getTest();

    // Fit the normalizer on the training part only, then apply the same statistics to both
    // parts. Fitting on the whole data set would let test-set mean/stdev influence training.
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);
    normalizer.transform(trainingData);
    normalizer.transform(testData);

    MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
            .iterations(1000)
            .activation(Activation.TANH)
            .weightInit(WeightInit.XAVIER)
            .regularization(true)
            .learningRate(0.1).l2(0.0001)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(3)
                    .build())
            .layer(1, new DenseLayer.Builder().nIn(3).nOut(3)
                    .build())
            .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .activation(Activation.SOFTMAX)
                    .nIn(3).nOut(CLASSES_COUNT).build())
            .backpropType(BackpropType.Standard).pretrain(false)
            .build();

    MultiLayerNetwork model = new MultiLayerNetwork(configuration);
    model.init();
    model.fit(trainingData);

    // Evaluate on the held-out test split.
    INDArray output = model.output(testData.getFeatures());
    Evaluation eval = new Evaluation(CLASSES_COUNT);
    eval.eval(testData.getLabels(), output);
    System.out.println(eval.stats());
}
Example 6
Source File: LearnIrisBackprop.java From aifh with Apache License 2.0 | 4 votes |
/** * The main method. * @param args Not used. */ public static void main(String[] args) { try { int seed = 43; double learningRate = 0.1; int splitTrainNum = (int) (150 * .75); int numInputs = 4; int numOutputs = 3; int numHiddenNodes = 50; // Setup training data. final InputStream istream = LearnIrisBackprop.class.getResourceAsStream("/iris.csv"); if( istream==null ) { System.out.println("Cannot access data set, make sure the resources are available."); System.exit(1); } final NormalizeDataSet ds = NormalizeDataSet.load(istream); final CategoryMap species = ds.encodeOneOfN(4); // species is column 4 istream.close(); DataSet next = ds.extractSupervised(0, 4, 4, 3); next.shuffle(); // Training and validation data split SplitTestAndTrain testAndTrain = next.splitTestAndTrain(splitTrainNum, new Random(seed)); DataSet trainSet = testAndTrain.getTrain(); DataSet validationSet = testAndTrain.getTest(); DataSetIterator trainSetIterator = new ListDataSetIterator(trainSet.asList(), trainSet.numExamples()); DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationSet.numExamples()); // Create neural network. MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(seed) .iterations(1) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .learningRate(learningRate) .updater(Updater.NESTEROVS).momentum(0.9) .list(2) .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes) .weightInit(WeightInit.XAVIER) .activation("relu") .build()) .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD) .weightInit(WeightInit.XAVIER) .activation("softmax") .nIn(numHiddenNodes).nOut(numOutputs).build()) .pretrain(false).backprop(true).build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); model.setListeners(new ScoreIterationListener(1)); // Define when we want to stop training. 
EarlyStoppingModelSaver saver = new InMemoryModelSaver(); EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder() .epochTerminationConditions(new MaxEpochsTerminationCondition(500)) //Max of 50 epochs .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(25)) .evaluateEveryNEpochs(1) .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) //Calculate test set score .modelSaver(saver) .build(); EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator); // Train and display result. EarlyStoppingResult result = trainer.fit(); System.out.println("Termination reason: " + result.getTerminationReason()); System.out.println("Termination details: " + result.getTerminationDetails()); System.out.println("Total epochs: " + result.getTotalEpochs()); System.out.println("Best epoch number: " + result.getBestModelEpoch()); System.out.println("Score at best epoch: " + result.getBestModelScore()); model = saver.getBestModel(); // Evaluate Evaluation eval = new Evaluation(numOutputs); validationSetIterator.reset(); for (int i = 0; i < validationSet.numExamples(); i++) { DataSet t = validationSet.get(i); INDArray features = t.getFeatureMatrix(); INDArray labels = t.getLabels(); INDArray predicted = model.output(features, false); System.out.println(features + ":Prediction("+findSpecies(labels,species) +"):Actual("+findSpecies(predicted,species)+")" + predicted ); eval.eval(labels, predicted); } //Print the evaluation statistics System.out.println(eval.stats()); } catch(Exception ex) { ex.printStackTrace(); } }