Java Code Examples for org.nd4j.linalg.dataset.DataSet#asList()

The following examples show how to use org.nd4j.linalg.dataset.DataSet#asList(). Each example is taken from an open-source project; the source file and license are noted above each listing.
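
As a quick orientation before the project examples: asList() splits a multi-example DataSet into a list of single-example DataSets, which is then typically wrapped in a ListDataSetIterator for minibatch training. The snippet below is an illustrative sketch, not taken from any of the projects on this page; it assumes a DL4J/ND4J dependency on the classpath and the org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator class used by the examples that follow.

import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Arrays;
import java.util.List;

public class AsListSketch {
    public static void main(String[] args) {
        // Build a small DataSet: 4 examples, 2 features and 1 label each.
        INDArray features = Nd4j.rand(4, 2);
        INDArray labels = Nd4j.rand(4, 1);
        DataSet all = new DataSet(features, labels);

        // asList() splits the DataSet into one single-example DataSet per row.
        List<DataSet> examples = all.asList();
        System.out.println("Number of examples: " + examples.size()); // prints 4

        // Wrap the list in a ListDataSetIterator to iterate in minibatches of 2.
        DataSetIterator iterator = new ListDataSetIterator(examples, 2);
        while (iterator.hasNext()) {
            DataSet batch = iterator.next();
            System.out.println("Batch feature shape: " + Arrays.toString(batch.getFeatures().shape()));
        }
    }
}
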
Example 1
Source File: MultiRegression.java    From dl4j-tutorials with MIT License
private static DataSetIterator getTrainingData(int batchSize, Random rand) {
    double [] sum = new double[nSamples];
    double [] input1 = new double[nSamples];
    double [] input2 = new double[nSamples];
    for (int i = 0; i < nSamples; i++) {
        input1[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
        input2[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
        sum[i] = input1[i] + input2[i];
    }
    INDArray inputNDArray1 = Nd4j.create(input1, new int[]{nSamples,1});
    INDArray inputNDArray2 = Nd4j.create(input2, new int[]{nSamples,1});
    INDArray inputNDArray = Nd4j.hstack(inputNDArray1,inputNDArray2);
    INDArray outPut = Nd4j.create(sum, new int[]{nSamples, 1});
    DataSet dataSet = new DataSet(inputNDArray, outPut);
    List<DataSet> listDs = dataSet.asList();

    return new ListDataSetIterator(listDs,batchSize);
}
 
Example 2
Source File: TestSparkMultiLayerParameterAveraging.java    From deeplearning4j with Apache License 2.0
@Test
public void testRunIteration() {

    DataSet dataSet = new IrisDataSetIterator(5, 5).next();
    List<DataSet> list = dataSet.asList();
    JavaRDD<DataSet> data = sc.parallelize(list);

    SparkDl4jMultiLayer sparkNetCopy = new SparkDl4jMultiLayer(sc, getBasicConf(),
                    new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));
    MultiLayerNetwork networkCopy = sparkNetCopy.fit(data);

    INDArray expectedParams = networkCopy.params();

    SparkDl4jMultiLayer sparkNet = getBasicNetwork();
    MultiLayerNetwork network = sparkNet.fit(data);
    INDArray actualParams = network.params();

    assertEquals(expectedParams.size(1), actualParams.size(1));
}
 
Example 3
Source File: LinearModel.java    From FederatedAndroidTrainer with MIT License
@Override
public String evaluate(FederatedDataSet federatedDataSet) {
    DataSet testData = (DataSet) federatedDataSet.getNativeDataSet();
    List<DataSet> listDs = testData.asList();
    DataSetIterator iterator = new ListDataSetIterator(listDs, BATCH_SIZE);

    return mNetwork.evaluate(iterator).stats();
}
 
Example 4
Source File: MNISTModel.java    From FederatedAndroidTrainer with MIT License
@Override
public String evaluate(FederatedDataSet federatedDataSet) {
    DataSet testData = (DataSet) federatedDataSet.getNativeDataSet();
    List<DataSet> listDs = testData.asList();
    DataSetIterator iterator = new ListDataSetIterator(listDs, BATCH_SIZE);

    Evaluation eval = new Evaluation(OUTPUT_NUM); //create an evaluation object with 10 possible classes
    while (iterator.hasNext()) {
        DataSet next = iterator.next();
        INDArray output = model.output(next.getFeatureMatrix()); //get the networks prediction
        eval.eval(next.getLabels(), output); //check the prediction against the true class
    }

    return eval.stats();
}
 
Example 5
Source File: MNISTModel.java    From FederatedAndroidTrainer with MIT License
@Override
public void train(FederatedDataSet federatedDataSet) {
    DataSet trainingData = (DataSet) federatedDataSet.getNativeDataSet();
    List<DataSet> listDs = trainingData.asList();
    DataSetIterator mnistTrain = new ListDataSetIterator(listDs, BATCH_SIZE);
    for (int i = 0; i < N_EPOCHS; i++) {
        model.fit(mnistTrain);
    }
}
 
Example 6
Source File: SingleRegression.java    From dl4j-tutorials with MIT License
private static DataSetIterator getTrainingData(int batchSize, Random rand) {
    /**
     * How to construct the training data:
     * the models here are supervised, so the training set
     * must contain features + labels.
     * features -> x
     * labels   -> y
     */
    double [] output = new double[nSamples];
    double [] input = new double[nSamples];
    // Randomly generate x between 0 and 3
    // and construct y = 0.5x + 0.1
    // (a -> 0.5, b -> 0.1)
    for (int i = 0; i < nSamples; i++) {
        input[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();

        output[i] = 0.5 * input[i] + 0.1;
    }

    /**
     * We have nSamples rows of data;
     * each row has a single x value.
     */
    INDArray inputNDArray = Nd4j.create(input, new int[]{nSamples,1});

    INDArray outPut = Nd4j.create(output, new int[]{nSamples, 1});

    /**
     * Build the dataset that is fed to the neural network.
     * DataSet wraps the features + labels into a single class.
     */
    DataSet dataSet = new DataSet(inputNDArray, outPut);
    List<DataSet> listDs = dataSet.asList();

    return new ListDataSetIterator(listDs,batchSize);
}
 
Example 7
Source File: RegressionMathFunctions.java    From dl4j-tutorials with MIT License
/** Create a DataSetIterator for training
 * @param x X values
 * @param function Function to evaluate
 * @param batchSize Batch size (number of examples for every call of DataSetIterator.next())
 * @param rng Random number generator (for repeatability)
 */
private static DataSetIterator getTrainingData(final INDArray x, final MathFunction function, final int batchSize, final Random rng) {
    final INDArray y = function.getFunctionValues(x);
    final DataSet allData = new DataSet(x,y);

    final List<DataSet> list = allData.asList();
    Collections.shuffle(list,rng);
    return new ListDataSetIterator(list,batchSize);
}
 
Example 8
Source File: LinearModel.java    From FederatedAndroidTrainer with MIT License
@Override
public void train(FederatedDataSet dataSource) {
    DataSet trainingData = (DataSet) dataSource.getNativeDataSet();
    List<DataSet> listDs = trainingData.asList();
    DataSetIterator iterator = new ListDataSetIterator(listDs, BATCH_SIZE);

    //Train the network on the full data set, and evaluate in periodically
    for (int i = 0; i < N_EPOCHS; i++) {
        iterator.reset();
        mNetwork.fit(iterator);
    }
}
 
Example 9
Source File: MLLIbUtilTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testMlLibTest() {
    DataSet dataSet = new IrisDataSetIterator(150, 150).next();
    List<DataSet> list = dataSet.asList();
    JavaRDD<DataSet> data = sc.parallelize(list);
    JavaRDD<LabeledPoint> mllLibData = MLLibUtil.fromDataSet(sc, data);
}
 
Example 10
Source File: LearnIrisBackprop.java    From aifh with Apache License 2.0
/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 0.1;
        int splitTrainNum = (int) (150 * .75);

        int numInputs = 4;
        int numOutputs = 3;
        int numHiddenNodes = 50;

        // Setup training data.
        final InputStream istream = LearnIrisBackprop.class.getResourceAsStream("/iris.csv");
        if( istream==null ) {
            System.out.println("Cannot access data set, make sure the resources are available.");
            System.exit(1);
        }
        final NormalizeDataSet ds = NormalizeDataSet.load(istream);
        final CategoryMap species = ds.encodeOneOfN(4); // species is column 4
        istream.close();

        DataSet next = ds.extractSupervised(0, 4, 4, 3);
        next.shuffle();

        // Training and validation data split
        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(splitTrainNum, new Random(seed));
        DataSet trainSet = testAndTrain.getTrain();
        DataSet validationSet = testAndTrain.getTest();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainSet.asList(), trainSet.numExamples());

        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationSet.numExamples());

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list(2)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .weightInit(WeightInit.XAVIER)
                        .activation("relu")
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();


        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(new MaxEpochsTerminationCondition(500), //Max of 500 epochs, or...
                        new ScoreImprovementEpochTerminationCondition(25)) //...stop if no score improvement for 25 epochs
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true))     //Calculate validation set score
                .modelSaver(saver)
                .build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            System.out.println(features + ":Prediction("+findSpecies(labels,species)
                    +"):Actual("+findSpecies(predicted,species)+")" + predicted );
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch(Exception ex) {
        ex.printStackTrace();
    }
}
 
Example 11
Source File: LearnDigitsDropout.java    From aifh with Apache License 2.0
/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int nEpochs = 50;
        int batchSize = 500;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        System.out.println(trainingSet.get(0).getFeatures().size(1));
        System.out.println(validationSet.get(0).getFeatures().size(1));

        int numInputs = trainingReader.getNumCols()*trainingReader.getNumRows();
        int numOutputs = 10;
        int numHiddenNodes = 100;

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list(2)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .weightInit(WeightInit.XAVIER)
                        .activation("relu")
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();


        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                //.epochTerminationConditions(new MaxEpochsTerminationCondition(10))
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true))     //Calculate validation set score
                .modelSaver(saver)
                .build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch(Exception ex) {
        ex.printStackTrace();
    }

}
 
Example 12
Source File: LearnDigitsBackprop.java    From aifh with Apache License 2.0
/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int nEpochs = 50;
        int batchSize = 500;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        System.out.println(trainingSet.get(0).getFeatures().size(1));
        System.out.println(validationSet.get(0).getFeatures().size(1));

        int numInputs = trainingReader.getNumCols()*trainingReader.getNumRows();
        int numOutputs = 10;
        int numHiddenNodes = 200;

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .regularization(true).dropOut(0.50)
                .list(2)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .weightInit(WeightInit.XAVIER)
                        .activation("relu")
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();


        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                //.epochTerminationConditions(new MaxEpochsTerminationCondition(10))
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true))     //Calculate validation set score
                .modelSaver(saver)
                .build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch(Exception ex) {
        ex.printStackTrace();
    }

}
 
Example 13
Source File: GradientSharingTrainingTest.java    From deeplearning4j with Apache License 2.0
@Test @Ignore //AB https://github.com/eclipse/deeplearning4j/issues/8985
public void differentNetsTrainingTest() throws Exception {
    int batch = 3;

    File temp = testDir.newFolder();
    DataSet ds = new IrisDataSetIterator(150, 150).next();
    List<DataSet> list = ds.asList();
    Collections.shuffle(list, new Random(12345));
    int pos = 0;
    int dsCount = 0;
    while (pos < list.size()) {
        List<DataSet> l2 = new ArrayList<>();
        for (int i = 0; i < 3 && pos < list.size(); i++) {
            l2.add(list.get(pos++));
        }
        DataSet d = DataSet.merge(l2);
        File f = new File(temp, dsCount++ + ".bin");
        d.save(f);
    }

    INDArray last = null;
    INDArray lastDup = null;
    for (int i = 0; i < 2; i++) {
        System.out.println("||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||");
        log.info("Starting: {}", i);

        MultiLayerConfiguration conf;
        if (i == 0) {
            conf = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)
                    .seed(12345)
                    .list()
                    .layer(new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
        } else {
            conf = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)
                    .seed(12345)
                    .list()
                    .layer(new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build())
                    .layer(new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
        }
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();


        //TODO this probably won't work everywhere...
        String controller = Inet4Address.getLocalHost().getHostAddress();
        String networkMask = controller.substring(0, controller.lastIndexOf('.')) + ".0" + "/16";

        VoidConfiguration voidConfiguration = VoidConfiguration.builder()
                .unicastPort(40123) // Should be open for IN/OUT communications on all Spark nodes
                .networkMask(networkMask) // Local network mask
                .controllerAddress(controller)
                .build();
        TrainingMaster tm = new SharedTrainingMaster.Builder(voidConfiguration, 2, new FixedThresholdAlgorithm(1e-4), batch)
                .rngSeed(12345)
                .collectTrainingStats(false)
                .batchSizePerWorker(batch) // Minibatch size for each worker
                .workersPerNode(2) // Workers per node
                .build();


        SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, net, tm);

        //System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));

        String fitPath = "file:///" + temp.getAbsolutePath().replaceAll("\\\\", "/");
        INDArray paramsBefore = net.params().dup();
        for( int j=0; j<3; j++ ) {
            sparkNet.fit(fitPath);
        }

        INDArray paramsAfter = net.params();
        assertNotEquals(paramsBefore, paramsAfter);

        //Also check we don't have any issues
        if(i == 0) {
            last = sparkNet.getNetwork().params();
            lastDup = last.dup();
        } else {
            assertEquals(lastDup, last);
        }
    }
}
 
Example 14
Source File: FileDataSetIterator.java    From deeplearning4j with Apache License 2.0
@Override
protected List<DataSet> split(DataSet toSplit) {
    return toSplit.asList();
}
 
Example 15
Source File: TestDataSetIterator.java    From deeplearning4j with Apache License 2.0
public TestDataSetIterator(DataSet dataset, int batch) {
    this(dataset.asList(), batch);
}
 
Example 16
Source File: TestDataSetIterator.java    From nd4j with Apache License 2.0
public TestDataSetIterator(DataSet dataset, int batch) {
    this(dataset.asList(), batch);
}
 
Example 17
Source File: NormalizeUciData.java    From SKIL_Examples with Apache License 2.0
public void run() throws Exception {
    File trainingOutputFile = new File(trainOutputPath);
    File testOutputFile = new File(testOutputPath);

    if (trainingOutputFile.exists() || testOutputFile.exists()) {
        System.out.println(String.format("Warning: overwriting output files (%s, %s)", trainOutputPath, testOutputPath));

        trainingOutputFile.delete();
        testOutputFile.delete();
    }

    System.out.format("downloading from %s\n", downloadUrl);
    System.out.format("writing training output to %s\n", trainOutputPath);
    System.out.format("writing testing output to %s\n", testOutputPath);

    URL url = new URL(downloadUrl);
    String data = IOUtils.toString(url);
    String[] lines = data.split("\n");
    List<INDArray> arrays = new LinkedList<INDArray>();
    List<Integer> labels = new LinkedList<Integer>();

    for (int i=0; i<lines.length; i++) {
        String line = lines[i];
        String[] cols = line.split("\\s+");

        int label = i / 100;
        INDArray array = Nd4j.zeros(1, 60);

        for (int j=0; j<cols.length; j++) {
            Double d = Double.parseDouble(cols[j]);
            array.putScalar(0, j, d);
        }

        arrays.add(array);
        labels.add(label);
    }

    // Shuffle with **known** seed
    Collections.shuffle(arrays, new Random(12345));
    Collections.shuffle(labels, new Random(12345));

    INDArray trainData = Nd4j.zeros(450, 60);
    INDArray testData = Nd4j.zeros(150, 60);

    for (int i=0; i<arrays.size(); i++) {
        INDArray arr = arrays.get(i);

        if (i < 450) { // Training
            trainData.putRow(i, arr);
        } else { // Test
            testData.putRow(i-450, arr);
        }
    }

    DataSet trainDs = new DataSet(trainData, trainData);
    DataSetIterator trainIt = new ListDataSetIterator(trainDs.asList());

    DataSet testDs = new DataSet(testData, testData);
    DataSetIterator testIt = new ListDataSetIterator(testDs.asList());

    // Fit normalizer on training data only!
    DataNormalization normalizer = dataNormalizer.getNormalizer();
    normalizer.fit(trainIt);

    // Print out basic summary stats
    switch (normalizer.getType()) {
        case STANDARDIZE:
            System.out.format("Normalizer - Standardize:\n  mean=%s\n  std= %s\n",
                    ((NormalizerStandardize)normalizer).getMean(),
                    ((NormalizerStandardize)normalizer).getStd());
    }

    // Use same normalizer for both
    trainIt.setPreProcessor(normalizer);
    testIt.setPreProcessor(normalizer);

    String trainOutput = toCsv(trainIt, labels.subList(0, 450), new int[]{1, 60});
    String testOutput = toCsv(testIt, labels.subList(450, 600), new int[]{1, 60});

    FileUtils.write(trainingOutputFile, trainOutput);
    System.out.format("wrote normalized training file to %s\n", trainingOutputFile);

    FileUtils.write(testOutputFile, testOutput);
    System.out.format("wrote normalized test file to %s\n", testOutputFile);

}