org.nd4j.linalg.dataset.api.iterator.DataSetIterator Java Examples

The following examples show how to use org.nd4j.linalg.dataset.api.iterator.DataSetIterator. Each example names the open-source project and source file it was taken from, in the line above the code.
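Before the examples, a minimal sketch of the iterator contract they all rely on: mini-batches are pulled with hasNext()/next(), and reset() rewinds the iterator for another epoch when resetSupported() returns true. The IrisDataSetIterator and batch size below are illustrative choices, not taken from any particular example.

import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class DataSetIteratorBasics {
    public static void main(String[] args) {
        // 150 Iris examples, delivered in mini-batches of 10
        DataSetIterator iter = new IrisDataSetIterator(10, 150);
        int batches = 0;
        while (iter.hasNext()) {
            DataSet batch = iter.next(); // features + labels for one mini-batch
            batches++;
        }
        System.out.println("batches = " + batches); // 15
        if (iter.resetSupported()) {
            iter.reset(); // rewind for the next epoch
        }
    }
}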
Example #1
Source File: ScoreUtil.java    From deeplearning4j with Apache License 2.0
/**
 * Score the given test data with the given multi-layer network.
 *
 * @param model model to use
 * @param testData the test data to test with
 * @param average whether to average the score or not
 * @return the score for the given test data under the model
 */
public static double score(MultiLayerNetwork model, DataSetIterator testData, boolean average) {
    //TODO: do this properly taking into account division by N, L1/L2 etc
    double sumScore = 0.0;
    int totalExamples = 0;
    while (testData.hasNext()) {
        DataSet ds = testData.next();
        int numExamples = ds.numExamples();

        sumScore += numExamples * model.score(ds);
        totalExamples += numExamples;
    }

    if (!average)
        return sumScore;
    return sumScore / totalExamples;
}
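A sketch of a call site for score(...), assuming a trained MultiLayerNetwork named model is in scope and that this runs where MnistDataSetIterator's IOException can propagate; the iterator choice is illustrative.

DataSetIterator testIter = new MnistDataSetIterator(64, false, 12345); // test split, batch size 64
double avgLoss = ScoreUtil.score(model, testIter, true); // average = true -> sumScore / totalExamples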
 
Example #2
Source File: CnnTextFilesEmbeddingInstanceIteratorTest.java    From wekaDeeplearning4j with GNU General Public License v3.0
/**
 * Test getDataSetIterator
 */
@Test
public void testGetIteratorNominalClass() throws Exception {
  final Instances data = DatasetLoader.loadAngerMetaClassification();
  final int batchSize = 1;
  final DataSetIterator it = this.cteii.getDataSetIterator(data, SEED, batchSize);

  Set<Integer> labels = new HashSet<>();
  for (int i = 0; i < data.size(); i++) {
    Instance inst = data.get(i);

    int label = Integer.parseInt(inst.stringValue(data.classIndex()));
    final DataSet next = Utils.getNext(it);
    int itLabel = next.getLabels().argMax().getInt(0);
    Assert.assertEquals(label, itLabel);
    labels.add(label);
  }
  final Set<Integer> collect =
      it.getLabels().stream().map(s -> Double.valueOf(s).intValue()).collect(Collectors.toSet());
  Assert.assertEquals(2, labels.size());
  Assert.assertTrue(labels.containsAll(collect));
  Assert.assertTrue(collect.containsAll(labels));
}
 
Example #3
Source File: TestAsyncIterator.java    From deeplearning4j with Apache License 2.0
@Test
public void testInitializeNoNextIter() {

    DataSetIterator iter = new IrisDataSetIterator(10, 150);
    while (iter.hasNext())
        iter.next();

    DataSetIterator async = new AsyncDataSetIterator(iter, 2);

    assertFalse(iter.hasNext());
    assertFalse(async.hasNext());
    try {
        iter.next();
        fail("Should have thrown NoSuchElementException");
    } catch (Exception e) {
        //OK
    }

    async.reset();
    int count = 0;
    while (async.hasNext()) {
        async.next();
        count++;
    }
    assertEquals(150 / 10, count);
}
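For context, AsyncDataSetIterator prefetches batches from the wrapped iterator on a background thread; a minimal usage sketch with an illustrative queue depth (the test above uses a depth of 2 as well):

DataSetIterator base = new IrisDataSetIterator(10, 150);
DataSetIterator async = new AsyncDataSetIterator(base, 2); // prefetch up to 2 batches ahead

while (async.hasNext()) {
    DataSet ds = async.next(); // training thread consumes while the next batch loads in the background
}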
 
Example #4
Source File: RnnTextFilesEmbeddingInstanceIterator.java    From wekaDeeplearning4j with GNU General Public License v3.0
@Override
public DataSetIterator getDataSetIterator(Instances data, int seed, int batchSize)
    throws InvalidInputDataException, IOException {
  validate(data);
  initWordVectors();
  final LabeledSentenceProvider sentenceProvider = getSentenceProvider(data);
  return new RnnTextEmbeddingDataSetIterator(
      data,
      wordVectors,
      tokenizerFactory,
      tokenPreProcess,
      stopwords,
      sentenceProvider,
      batchSize,
      truncateLength);
}
 
Example #5
Source File: MultiRegression.java    From dl4j-tutorials with MIT License
private static DataSetIterator getTrainingData(int batchSize, Random rand) {
    double [] sum = new double[nSamples];
    double [] input1 = new double[nSamples];
    double [] input2 = new double[nSamples];
    for (int i= 0; i< nSamples; i++) {
        input1[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
        input2[i] =  MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
        sum[i] = input1[i] + input2[i];
    }
    INDArray inputNDArray1 = Nd4j.create(input1, new int[]{nSamples,1});
    INDArray inputNDArray2 = Nd4j.create(input2, new int[]{nSamples,1});
    INDArray inputNDArray = Nd4j.hstack(inputNDArray1,inputNDArray2);
    INDArray outPut = Nd4j.create(sum, new int[]{nSamples, 1});
    DataSet dataSet = new DataSet(inputNDArray, outPut);
    List<DataSet> listDs = dataSet.asList();

    return new ListDataSetIterator<>(listDs, batchSize);
}
 
Example #6
Source File: Dl4jMlpClassifier.java    From wekaDeeplearning4j with GNU General Public License v3.0
/**
 * Initialize early stopping with the given data
 *
 * @param data Data
 * @return Augmented data - if early stopping applies, return train set without validation set
 */
protected Instances initEarlyStopping(Instances data) throws Exception {
  // Split train/validation
  double valSplit = earlyStopping.getValidationSetPercentage();
  Instances trainData;
  Instances valData;
  if (useEarlyStopping()) {
    // Split in train and validation
    Instances[] insts = splitTrainVal(data, valSplit);
    trainData = insts[0];
    valData = insts[1];
    validateSplit(trainData, valData);
    DataSetIterator valIterator =
        getDataSetIterator(valData, cacheMode, "val");
    earlyStopping.init(valIterator);
  } else {
    // Keep the full data
    trainData = data;
  }

  return trainData;
}
 
Example #7
Source File: MultiLayerTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testPredict() throws Exception {

    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER).seed(12345L).list()
                    .layer(0, new DenseLayer.Builder().nIn(784).nOut(50).activation(Activation.RELU).build())
                    .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .activation(Activation.SOFTMAX).nIn(50).nOut(10).build())
                    .setInputType(InputType.convolutional(28, 28, 1)).build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    DataSetIterator ds = new MnistDataSetIterator(10, 10);
    net.fit(ds);

    DataSetIterator testDs = new MnistDataSetIterator(1, 1);
    DataSet testData = testDs.next();
    testData.setLabelNames(Arrays.asList("0", "1", "2", "3", "4", "5", "6", "7", "8", "9"));
    String actualLabels = testData.getLabelName(0);
    List<String> prediction = net.predict(testData);
    assertNotNull(actualLabels);
    assertNotNull(prediction.get(0));
}
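Note that predict(DataSet) returns label-name strings only because setLabelNames(...) was called above; a sketch of the index-based alternative:

// Alternative: raw class indices instead of label-name strings
int[] predictedClasses = net.predict(testData.getFeatures());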
 
Example #8
Source File: CnnWord2VecSentenceClassificationExample.java    From Java-Deep-Learning-Cookbook with MIT License
private static DataSetIterator getDataSetIterator(boolean isTraining, WordVectors wordVectors, int minibatchSize,
                                                  int maxSentenceLength, Random rng ){
    String path = FilenameUtils.concat(DATA_PATH, (isTraining ? "aclImdb/train/" : "aclImdb/test/"));
    String positiveBaseDir = FilenameUtils.concat(path, "pos");
    String negativeBaseDir = FilenameUtils.concat(path, "neg");

    File filePositive = new File(positiveBaseDir);
    File fileNegative = new File(negativeBaseDir);

    Map<String,List<File>> reviewFilesMap = new HashMap<>();
    reviewFilesMap.put("Positive", Arrays.asList(filePositive.listFiles()));
    reviewFilesMap.put("Negative", Arrays.asList(fileNegative.listFiles()));

    LabeledSentenceProvider sentenceProvider = new FileLabeledSentenceProvider(reviewFilesMap, rng);

    return new CnnSentenceDataSetIterator.Builder(CnnSentenceDataSetIterator.Format.CNN2D)
            .sentenceProvider(sentenceProvider)
            .wordVectors(wordVectors)
            .minibatchSize(minibatchSize)
            .maxSentenceLength(maxSentenceLength)
            .useNormalizedWordVectors(false)
            .build();
}
 
Example #9
Source File: ImageInstanceIteratorTest.java    From wekaDeeplearning4j with GNU General Public License v3.0
/**
 * Counts the number of iterations of an {@link ImageInstanceIterator}.
 *
 * @param data Instances to iterate
 * @param imgIter ImageInstanceIterator to be tested
 * @param seed Seed
 * @param batchsize Size of the batch which is returned in {@link DataSetIterator#next}
 * @return Number of iterations
 */
private int countIterations(
    Instances data, ImageInstanceIterator imgIter, int seed, int batchsize) throws Exception {
  DataSetIterator it = imgIter.getDataSetIterator(data, seed, batchsize);
  int count = 0;
  while (it.hasNext()) {
    count++;
    Utils.getNext(it); // consume the batch to advance the iterator
  }
  return count;
}
 
Example #10
Source File: HyperParameterTuningArbiterUiExample.java    From Java-Deep-Learning-Cookbook with MIT License
@Override
public Object testData() {
    try{
        DataSetIterator iterator = new RecordReaderDataSetIterator(dataPreprocess(),minibatchSize,labelIndex,numClasses);
        return dataSplit(iterator).getTestIterator();
    }
    catch(Exception e){
        throw new RuntimeException(e);
    }
}
 
Example #11
Source File: TestBasic.java    From deeplearning4j with Apache License 2.0
@Override
public DataSetIterator trainData(Map<String, Object> dataParameters) {
    try {
        if (dataParameters == null || dataParameters.isEmpty()) {
            return new MnistDataSetIterator(64, 10000, false, true, true, 123);
        }
        if (dataParameters.containsKey("batchsize")) {
            int b = (Integer) dataParameters.get("batchsize");
            return new MnistDataSetIterator(b, 10000, false, true, true, 123);
        }
        return new MnistDataSetIterator(64, 10000, false, true, true, 123);
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
 
Example #12
Source File: BatchNormalizationTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testBatchNorm() throws Exception {

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .updater(new Adam(1e-3))
            .activation(Activation.TANH)
            .list()
            .layer(new ConvolutionLayer.Builder().nOut(5).kernelSize(2, 2).build())
            .layer(new BatchNormalization())
            .layer(new ConvolutionLayer.Builder().nOut(5).kernelSize(2, 2).build())
            .layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nOut(10).build())
            .setInputType(InputType.convolutionalFlat(28, 28, 1))
            .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    DataSetIterator iter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, true, 12345), 10);

    net.fit(iter);

    MultiLayerNetwork net2 = new TransferLearning.Builder(net)
            .fineTuneConfiguration(FineTuneConfiguration.builder()
                    .updater(new AdaDelta())
                    .build())
            .removeOutputLayer()
            .addLayer(new BatchNormalization.Builder().nOut(3380).build())
            .addLayer(new OutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(3380).nOut(10).build())
            .build();

    net2.fit(iter);
}
 
Example #13
Source File: TestParallelEarlyStopping.java    From deeplearning4j with Apache License 2.0
@Test
public void testBadTuning() {
    //Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition

    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(new Sgd(1.0)) //Intentionally huge LR
                    .weightInit(WeightInit.XAVIER).list()
                    .layer(0, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));

    DataSetIterator irisIter = new IrisDataSetIterator(10, 150);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                    new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                                    .epochTerminationConditions(new MaxEpochsTerminationCondition(5000))
                                    .iterationTerminationConditions(
                                                    new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES),
                                                    new MaxScoreIterationTerminationCondition(10)) //Initial score is ~2.5
                                    .scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver)
                                    .build();

    IEarlyStoppingTrainer<MultiLayerNetwork> trainer =
                    new EarlyStoppingParallelTrainer<>(esConf, net, irisIter, null, 2, 2, 1);
    EarlyStoppingResult result = trainer.fit();

    assertTrue(result.getTotalEpochs() < 5);
    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition,
                    result.getTerminationReason());
    String expDetails = new MaxScoreIterationTerminationCondition(10).toString();
    assertEquals(expDetails, result.getTerminationDetails());

    assertTrue(result.getBestModelEpoch() <= 0);
    assertNotNull(result.getBestModel());
}
 
Example #14
Source File: HyperParameterTuning.java    From Java-Deep-Learning-Cookbook with MIT License
@Override
public Object trainData() {
    try{
        DataSetIterator iterator = new RecordReaderDataSetIterator(dataPreprocess(),minibatchSize,labelIndex,numClasses);
        return dataSplit(iterator).getTrainIterator();
    }
    catch(Exception e){
        throw new RuntimeException(e);
    }
}
 
Example #15
Source File: HyperParameterTuningArbiterUiExample.java    From Java-Deep-Learning-Cookbook with MIT License
public DataSetIteratorSplitter dataSplit(DataSetIterator iterator) throws IOException, InterruptedException {
    DataNormalization dataNormalization = new NormalizerStandardize();
    dataNormalization.fit(iterator);
    iterator.setPreProcessor(dataNormalization);
    DataSetIteratorSplitter splitter = new DataSetIteratorSplitter(iterator,1000,0.8);
    return splitter;
}
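A sketch of how the returned splitter is typically consumed; here 1000 is the total number of batches the splitter expects and 0.8 the train fraction, as configured in the method above.

DataSetIteratorSplitter splitter = dataSplit(iterator);
DataSetIterator trainIter = splitter.getTrainIterator(); // first ~80% of the 1000 batches
DataSetIterator testIter = splitter.getTestIterator();   // remaining ~20%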
 
Example #16
Source File: CnnTextEmbeddingInstanceIterator.java    From wekaDeeplearning4j with GNU General Public License v3.0
@Override
public DataSetIterator getDataSetIterator(Instances data, int seed, int batchSize) {
  initialize();
  LabeledSentenceProvider clsp = getSentenceProvider(data);
  return new CnnSentenceDataSetIterator.Builder()
      .stopwords(stopwords)
      .wordVectors(wordVectors)
      .tokenizerFactory(tokenizerFactory.getBackend())
      .sentenceProvider(clsp)
      .minibatchSize(batchSize)
      .maxSentenceLength(truncateLength)
      .useNormalizedWordVectors(false)
      .sentencesAlongHeight(true)
      .build();
}
 
Example #17
Source File: SparkADSI.java    From deeplearning4j with Apache License 2.0
public SparkADSI(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue, boolean useWorkspace,
                DataSetCallback callback, Integer deviceId) {
    this();

    if (queueSize < 2)
        queueSize = 2;

    this.deviceId = deviceId;
    this.callback = callback;
    this.useWorkspace = useWorkspace;
    this.buffer = queue;
    this.prefetchSize = queueSize;
    this.backedIterator = iterator;
    this.workspaceId = "SADSI_ITER-" + java.util.UUID.randomUUID().toString();

    if (iterator.resetSupported())
        this.backedIterator.reset();

    context = TaskContext.get();

    // Ensure the background prefetch thread gets the same thread -> device affinity as the master thread
    this.thread = new SparkPrefetchThread(buffer, iterator, terminator, null,
            Nd4j.getAffinityManager().getDeviceForCurrentThread());

    thread.setDaemon(true);
    thread.start();
}
 
Example #18
Source File: TestRecordReaders.java    From deeplearning4j with Apache License 2.0
@Test
public void testClassIndexOutsideOfRangeRRMDSI_MultipleReaders() {

    Collection<Collection<Collection<Writable>>> c1 = new ArrayList<>();
    Collection<Collection<Writable>> seq1 = new ArrayList<>();
    seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    c1.add(seq1);

    Collection<Collection<Writable>> seq2 = new ArrayList<>();
    seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    c1.add(seq2);

    Collection<Collection<Collection<Writable>>> c2 = new ArrayList<>();
    Collection<Collection<Writable>> seq1a = new ArrayList<>();
    seq1a.add(Arrays.<Writable>asList(new IntWritable(0)));
    seq1a.add(Arrays.<Writable>asList(new IntWritable(1)));
    c2.add(seq1a);

    Collection<Collection<Writable>> seq2a = new ArrayList<>();
    seq2a.add(Arrays.<Writable>asList(new IntWritable(0)));
    seq2a.add(Arrays.<Writable>asList(new IntWritable(2)));
    c2.add(seq2a);

    CollectionSequenceRecordReader csrr = new CollectionSequenceRecordReader(c1);
    CollectionSequenceRecordReader csrrLabels = new CollectionSequenceRecordReader(c2);
    DataSetIterator dsi = new SequenceRecordReaderDataSetIterator(csrr, csrrLabels, 2, 2);

    try {
        DataSet ds = dsi.next();
        fail("Expected exception");
    } catch (Exception e) {
        assertTrue(e.getMessage(), e.getMessage().contains("to one-hot"));
    }
}
 
Example #19
Source File: TestEarlyStopping.java    From deeplearning4j with Apache License 2.0
@Test
public void testClassificationScoreFunctionSimple() throws Exception {

    for(Evaluation.Metric metric : Evaluation.Metric.values()) {
        log.info("Metric: " + metric);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new DenseLayer.Builder().nIn(784).nOut(32).build())
                .layer(new OutputLayer.Builder().nIn(32).nOut(10).activation(Activation.SOFTMAX).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        DataSetIterator iter = new MnistDataSetIterator(32, false, 12345);

        List<DataSet> l = new ArrayList<>();
        for( int i=0; i<10; i++ ){
            DataSet ds = iter.next();
            l.add(ds);
        }

        iter = new ExistingDataSetIterator(l);

        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                        .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                        .iterationTerminationConditions(
                                new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                        .scoreCalculator(new ClassificationScoreCalculator(metric, iter)).modelSaver(saver)
                        .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, iter);
        EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();

        assertNotNull(result.getBestModel());
    }
}
 
Example #20
Source File: TestVertxUI.java    From deeplearning4j with Apache License 2.0
@Test
public void testUI_VAE() throws Exception {
    //Variational autoencoder - for unsupervised layerwise pretraining

    StatsStorage ss = new InMemoryStatsStorage();

    UIServer uiServer = UIServer.getInstance();
    uiServer.attach(ss);

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(new Sgd(1e-5))
                    .list().layer(0,
                                    new VariationalAutoencoder.Builder().nIn(4).nOut(3).encoderLayerSizes(10, 11)
                                                    .decoderLayerSizes(12, 13).weightInit(WeightInit.XAVIER)
                                                    .pzxActivationFunction(Activation.IDENTITY)
                                                    .reconstructionDistribution(
                                                                    new GaussianReconstructionDistribution())
                                                    .activation(Activation.LEAKYRELU).build())
                    .layer(1, new VariationalAutoencoder.Builder().nIn(3).nOut(3).encoderLayerSizes(7)
                                    .decoderLayerSizes(8).weightInit(WeightInit.XAVIER)
                                    .pzxActivationFunction(Activation.IDENTITY)
                                    .reconstructionDistribution(new GaussianReconstructionDistribution())
                                    .activation(Activation.LEAKYRELU).build())
                    .layer(2, new OutputLayer.Builder().nIn(3).nOut(3).build())
                    .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    net.setListeners(new StatsListener(ss), new ScoreIterationListener(1));

    DataSetIterator iter = new IrisDataSetIterator(150, 150);

    for (int i = 0; i < 50; i++) {
        net.fit(iter);
        Thread.sleep(100);
    }

}
 
Example #21
Source File: SingleRegression.java    From dl4j-tutorials with MIT License
private static DataSetIterator getTrainingData(int batchSize, Random rand) {
    /**
     * How to construct our training data:
     * the models used here are mostly supervised,
     * so the training set must contain features + labels
     * features -> x
     * labels -> y
     */
    double [] output = new double[nSamples];
    double [] input = new double[nSamples];
    // Randomly generate x between 0 and 3
    // and construct y = 0.5x + 0.1
    // a -> 0.5, b -> 0.1
    for (int i= 0; i< nSamples; i++) {
        input[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();

        output[i] = 0.5 * input[i] + 0.1;
    }

    /**
     * We have nSamples rows of data;
     * each row has a single x
     */
    INDArray inputNDArray = Nd4j.create(input, new int[]{nSamples,1});

    INDArray outPut = Nd4j.create(output, new int[]{nSamples, 1});

    /**
     * Build the dataset that is fed to the neural network.
     * DataSet wraps features + labels into a single class.
     */
    DataSet dataSet = new DataSet(inputNDArray, outPut);
    List<DataSet> listDs = dataSet.asList();

    return new ListDataSetIterator<>(listDs, batchSize);
}
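A hypothetical consumer of getTrainingData(...), to make the shapes concrete; batch size and seed are illustrative:

DataSetIterator trainIter = getTrainingData(100, new Random(42));
while (trainIter.hasNext()) {
    DataSet batch = trainIter.next();
    // batch.getFeatures() is a [batchSize, 1] column of x values,
    // batch.getLabels() the matching y = 0.5x + 0.1 column
}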
 
Example #22
Source File: TestEarlyStoppingCompGraph.java    From deeplearning4j with Apache License 2.0
@Test
public void testListeners() {
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in")
                    .addLayer("0", new OutputLayer.Builder().nIn(4).nOut(3)
                            .activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
                    .setOutputs("0").build();
    ComputationGraph net = new ComputationGraph(conf);
    net.setListeners(new ScoreIterationListener(1));

    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>()
                    .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                    .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                    .scoreCalculator(new DataSetLossCalculatorCG(irisIter, true)).modelSaver(saver).build();

    LoggingEarlyStoppingListener listener = new LoggingEarlyStoppingListener();

    IEarlyStoppingTrainer trainer = new EarlyStoppingGraphTrainer(esConf, net, irisIter, listener);

    trainer.fit();

    assertEquals(1, listener.onStartCallCount);
    assertEquals(5, listener.onEpochCallCount);
    assertEquals(1, listener.onCompletionCallCount);
}
 
Example #23
Source File: TestModels.java    From Java-Machine-Learning-for-Computer-Vision with MIT License
private static void showTrainingPrecision(ComputationGraph vgg16, String classesNumber) throws IOException {
    File[] carTrackings = new File("CarTracking").listFiles();
    for (File carTracking : carTrackings) {
        if (carTracking.getName().contains(classesNumber)) {
            DataSetIterator dataSetIterator = ImageUtils.createDataSetIterator(carTracking,
                    Integer.parseInt(classesNumber), 64);
            Evaluation eval = vgg16.evaluate(dataSetIterator);
            log.info(eval.stats());
        }
    }
}
 
Example #24
Source File: TestScoreFunctions.java    From deeplearning4j with Apache License 2.0
@Override
public Object testData(Map<String, Object> dataParameters) {
    try {
        DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
        iter.setPreProcessor(new PreProc(rocType));
        return iter;
    } catch (IOException e){
        throw new RuntimeException(e);
    }
}
 
Example #25
Source File: DataSetIteratorSplitter.java    From deeplearning4j with Apache License 2.0
public List<DataSetIterator> getIterators() {
    List<DataSetIterator> retVal = new ArrayList<>();
    int partN = 0;
    int bottom = 0;
    for (final int split : splits) {
        ScrollableDataSetIterator partIterator =
                new ScrollableDataSetIterator(partN++, backedIterator, counter, resetPending, firstTrain,
                        new int[]{bottom, split});
        bottom += split;
        retVal.add(partIterator);
    }
    return retVal;
}
 
Example #26
Source File: TestEarlyStopping.java    From deeplearning4j with Apache License 2.0
@Test
public void testEarlyStoppingListeners() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER).list()
            .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                    .activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);

    TestListener tl = new TestListener();
    net.setListeners(tl);

    DataSetIterator irisIter = new IrisDataSetIterator(50, 150);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
            new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                    .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                    .iterationTerminationConditions(
                            new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                    .scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver)
                    .build();

    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, irisIter);

    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();

    assertEquals(5, tl.countEpochStart);
    assertEquals(5, tl.countEpochEnd);
    assertEquals(5 * 150/50, tl.iterCount);

    assertEquals(4, tl.maxEpochStart);
    assertEquals(4, tl.maxEpochEnd);
}
 
Example #27
Source File: TestFailureListener.java    From deeplearning4j with Apache License 2.0
@Ignore
@Test
public void testFailureRandom_OR(){

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .updater(new Adam(1e-4))
            .list()
            .layer(0, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    String username = System.getProperty("user.name");
    assertNotNull(username);
    assertFalse(username.isEmpty());

    net.setListeners(new FailureTestingListener(
            FailureTestingListener.FailureMode.SYSTEM_EXIT_1,
            new FailureTestingListener.Or(
                    new FailureTestingListener.IterationEpochTrigger(false, 10000),
                    new FailureTestingListener.RandomProb(FailureTestingListener.CallType.ANY, 0.02))
            ));

    DataSetIterator iter = new IrisDataSetIterator(5,150);

    net.fit(iter);
}
 
Example #28
Source File: LinearModel.java    From FederatedAndroidTrainer with MIT License
@Override
public void train(FederatedDataSet dataSource) {
    DataSet trainingData = (DataSet) dataSource.getNativeDataSet();
    List<DataSet> listDs = trainingData.asList();
    DataSetIterator iterator = new ListDataSetIterator<>(listDs, BATCH_SIZE);

    //Train the network on the full data set, and evaluate it periodically
    for (int i = 0; i < N_EPOCHS; i++) {
        iterator.reset();
        mNetwork.fit(iterator);
    }
}