Java Code Examples for org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator#hasNext()
The following examples show how to use
org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator#hasNext().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: TrainUtil.java From FancyBing with GNU General Public License v3.0 | 6 votes |
/**
 * Evaluates a {@link ComputationGraph} on every batch of {@code testData} and logs the results.
 * <p>
 * Output 0 is scored with a classification {@link Evaluation} (top-N), and output 1 with a
 * single-column {@link RegressionEvaluation}. Inference time (nanoseconds) is accumulated
 * around {@code output(...)} only. The iterator is reset after evaluation.
 *
 * @param model     the model to evaluate; must be a {@link ComputationGraph}
 * @param outputNum number of classes, passed to {@code createLabels}
 * @param testData  the evaluation data
 * @param topN      top-N value for the classification evaluation
 * @param batchSize nominal batch size; retained for API compatibility. The logged example count
 *                  now uses the actual number of rows in each batch, so a partial final batch
 *                  is no longer over-counted.
 * @return the classification accuracy of output 0
 * @throws ClassCastException if {@code model} is not a {@link ComputationGraph}
 */
public static double evaluate(Model model, int outputNum, MultiDataSetIterator testData, int topN, int batchSize) {
    log.info("Evaluate model....");
    ComputationGraph graph = (ComputationGraph) model;
    Evaluation clsEval = new Evaluation(createLabels(outputNum), topN);
    RegressionEvaluation valueRegEval1 = new RegressionEvaluation(1);
    long numExamples = 0;
    long consume = 0;
    while (testData.hasNext()) {
        MultiDataSet ds = testData.next();
        long begin = System.nanoTime();
        INDArray[] output = graph.output(false, ds.getFeatures());
        consume += System.nanoTime() - begin;
        clsEval.eval(ds.getLabels(0), output[0]);
        valueRegEval1.eval(ds.getLabels(1), output[1]);
        // Count real rows so a smaller final batch is not reported as a full one
        numExamples += ds.getFeatures(0).size(0);
    }
    String stats = clsEval.stats();
    // Trim the preamble up to the first "===" section; keep everything if the marker is absent
    // (the original substring(-1) threw StringIndexOutOfBoundsException in that case)
    int pos = stats.indexOf("===");
    stats = "\n" + (pos >= 0 ? stats.substring(pos) : stats);
    log.info(stats);
    log.info(valueRegEval1.stats());
    testData.reset();
    // Average inference time per example, in microseconds; 0 when no data was evaluated
    float average = numExamples == 0 ? 0f : (float) consume / numExamples / 1000;
    log.info("Evaluate time: " + consume + " count: " + numExamples + " average: " + average);
    return clsEval.accuracy();
}
Example 2
Source File: AbstractMultiDataSetNormalizer.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Fits the normalization statistics over everything an iterator produces.
 * <p>
 * The iterator is rewound first. Each {@code MultiDataSet} is folded into per-array builders via
 * {@code fitPartial}. Feature statistics are always built; label statistics only when
 * {@code isFitLabel()} is true.
 *
 * @param iterator source of the data to fit on
 */
public void fit(@NonNull MultiDataSetIterator iterator) {
    final List<S.Builder> featureBuilders = new ArrayList<>();
    final List<S.Builder> labelBuilders = new ArrayList<>();
    for (iterator.reset(); iterator.hasNext(); ) {
        fitPartial(iterator.next(), featureBuilders, labelBuilders);
    }
    featureStats = buildList(featureBuilders);
    if (isFitLabel()) {
        labelStats = buildList(labelBuilders);
    }
}
Example 3
Source File: AsyncMultiDataSetIterator.java From deeplearning4j with Apache License 2.0 | 6 votes |
public AsyncMultiDataSetIterator(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue, boolean useWorkspace, DataSetCallback callback, Integer deviceId) { if (queueSize < 2) queueSize = 2; this.callback = callback; this.buffer = queue; this.backedIterator = iterator; this.useWorkspaces = useWorkspace; this.prefetchSize = queueSize; this.workspaceId = "AMDSI_ITER-" + java.util.UUID.randomUUID().toString(); this.deviceId = deviceId; if (iterator.resetSupported() && !iterator.hasNext()) this.backedIterator.reset(); this.thread = new AsyncPrefetchThread(buffer, iterator, terminator, deviceId); thread.setDaemon(true); thread.start(); }
Example 4
Source File: AbstractMultiDataSetNormalizer.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Accumulates normalization statistics from every element of an iterator.
 * <p>
 * The iterator is reset before the pass. Label statistics are only materialized when
 * {@code isFitLabel()} returns true.
 *
 * @param iterator the data to iterate over
 */
public void fit(@NonNull MultiDataSetIterator iterator) {
    final List<S.Builder> perFeature = new ArrayList<>();
    final List<S.Builder> perLabel = new ArrayList<>();
    iterator.reset();
    while (iterator.hasNext()) {
        fitPartial(iterator.next(), perFeature, perLabel);
    }
    featureStats = buildList(perFeature);
    if (!isFitLabel()) {
        return;
    }
    labelStats = buildList(perLabel);
}
Example 5
Source File: ScoreUtil.java From deeplearning4j with Apache License 2.0 | 6 votes |
/** * Score based on the loss function * @param model the model to score with * @param testData the test data to score * @param average whether to average the score * for the whole batch or not * @return the score for the given test set */ public static double score(ComputationGraph model, MultiDataSetIterator testData, boolean average) { //TODO: do this properly taking into account division by N, L1/L2 etc double sumScore = 0.0; int totalExamples = 0; while (testData.hasNext()) { MultiDataSet ds = testData.next(); long numExamples = ds.getFeatures(0).size(0); sumScore += numExamples * model.score(ds); totalExamples += numExamples; } if (!average) return sumScore; return sumScore / totalExamples; }
Example 6
Source File: Main.java From twse-captcha-solver-dl4j with MIT License | 5 votes |
/**
 * Runs the 5-output captcha model over {@code iterator}, comparing decoded predictions to labels.
 * <p>
 * For each example, the argmax of each of the 5 outputs is mapped to a character, and the
 * concatenated string is compared with the decoded label string. Each comparison is logged.
 * The iterator is reset at the end and the totals are printed to stdout.
 *
 * @param model    the trained graph
 * @param iterator the validation data
 */
public static void modelPredict(ComputationGraph model, MultiDataSetIterator iterator) {
    final List<String> labelList = Arrays.asList(
            "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
            "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M",
            "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z");
    int sumCount = 0;
    int correctCount = 0;
    while (iterator.hasNext()) {
        MultiDataSet mds = iterator.next();
        INDArray[] output = model.output(mds.getFeatures());
        INDArray[] labels = mds.getLabels();
        // NOTE(review): capping at batchSize drops rows if a batch ever exceeds the field's value
        int dataNum = Math.min(batchSize, output[0].rows());
        for (int row = 0; row < dataNum; row++) {
            StringBuilder predicted = new StringBuilder();
            StringBuilder actual = new StringBuilder();
            for (int digit = 0; digit < 5; digit++) {
                predicted.append(labelList.get(Nd4j.argMax(output[digit].getRow(row), 1).getInt(0)));
                actual.append(labelList.get(Nd4j.argMax(labels[digit].getRow(row), 1).getInt(0)));
            }
            String peLabel = predicted.toString();
            String reLabel = actual.toString();
            boolean match = peLabel.equals(reLabel);
            if (match) {
                correctCount++;
            }
            sumCount++;
            logger.info("real image {} prediction {} status {}", reLabel, peLabel, match);
        }
    }
    iterator.reset();
    System.out.println("validate result : sum count =" + sumCount + " correct count=" + correctCount);
}
Example 7
Source File: MultiNormalizerHybrid.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Makes one full pass over an iterator, collecting statistics for input and output normalization.
 * <p>
 * The iterator is reset first. Builders are keyed by array index, and the final statistics are
 * produced with {@code buildAllStats}.
 *
 * @param iterator the iterator to collect statistics from
 */
@Override
public void fit(@NonNull MultiDataSetIterator iterator) {
    final Map<Integer, NormalizerStats.Builder> inputBuilders = new HashMap<>();
    final Map<Integer, NormalizerStats.Builder> outputBuilders = new HashMap<>();
    for (iterator.reset(); iterator.hasNext(); ) {
        fitPartial(iterator.next(), inputBuilders, outputBuilders);
    }
    inputStats = buildAllStats(inputBuilders);
    outputStats = buildAllStats(outputBuilders);
}
Example 8
Source File: MultiNormalizerHybrid.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Accumulates per-index normalization statistics for inputs and outputs over a whole iterator.
 * <p>
 * The iterator is reset before the pass.
 *
 * @param iterator the iterator to use for collecting statistics
 */
@Override
public void fit(@NonNull MultiDataSetIterator iterator) {
    final Map<Integer, NormalizerStats.Builder> byInput = new HashMap<>();
    final Map<Integer, NormalizerStats.Builder> byOutput = new HashMap<>();
    iterator.reset();
    while (iterator.hasNext()) {
        fitPartial(iterator.next(), byInput, byOutput);
    }
    inputStats = buildAllStats(byInput);
    outputStats = buildAllStats(byOutput);
}
Example 9
Source File: TestComputationGraphNetwork.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Checks that a small graph can be fitted on Iris in two ways: on the iterator directly, and on
 * one MultiDataSet at a time.
 */
@Test(timeout = 300000)
public void testIrisFitMultiDataSetIterator() throws Exception {
    RecordReader reader = new CSVRecordReader(0, ',');
    reader.initialize(new FileSplit(Resources.asFile("iris.txt")));
    MultiDataSetIterator trainIter = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("iris", reader)
            .addInput("iris", 0, 3)
            .addOutputOneHot("iris", 4, 3)
            .build();

    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .updater(new Sgd(0.1))
            .graphBuilder()
            .addInputs("in")
            .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in")
            .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(2).nOut(3).build(), "dense")
            .setOutputs("out")
            .build();

    ComputationGraph graph = new ComputationGraph(conf);
    graph.init();

    // Pass 1: fit directly on the iterator
    graph.fit(trainIter);

    // Pass 2: rewind the reader, rebuild the iterator, and fit batch by batch
    reader.reset();
    trainIter = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("iris", reader)
            .addInput("iris", 0, 3)
            .addOutputOneHot("iris", 4, 3)
            .build();
    while (trainIter.hasNext()) {
        graph.fit(trainIter.next());
    }
}
Example 10
Source File: RandomDataSetIteratorTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Verifies the batch count, array count, shapes, and value ranges produced by
 * RandomMultiDataSetIterator.
 */
@Test
public void testMDSI() {
    Nd4j.getRandom().setSeed(12345);
    MultiDataSetIterator it = new RandomMultiDataSetIterator.Builder(5)
            .addFeatures(new long[]{3, 4}, RandomMultiDataSetIterator.Values.INTEGER_0_100)
            .addFeatures(new long[]{3, 5}, RandomMultiDataSetIterator.Values.BINARY)
            .addLabels(new long[]{3, 6}, RandomMultiDataSetIterator.Values.ZEROS)
            .build();

    int batches = 0;
    while (it.hasNext()) {
        batches++;
        MultiDataSet mds = it.next();
        assertEquals(2, mds.numFeatureArrays());
        assertEquals(1, mds.numLabelsArrays());
        assertArrayEquals(new long[]{3, 4}, mds.getFeatures(0).shape());
        assertArrayEquals(new long[]{3, 5}, mds.getFeatures(1).shape());
        assertArrayEquals(new long[]{3, 6}, mds.getLabels(0).shape());

        // INTEGER_0_100 features: within [0, 100] and not degenerate
        double intMin = mds.getFeatures(0).minNumber().doubleValue();
        double intMax = mds.getFeatures(0).maxNumber().doubleValue();
        assertTrue(intMin >= 0 && intMax <= 100.0 && intMax > 2.0);

        // BINARY features: both 0 and 1 present
        double binMin = mds.getFeatures(1).minNumber().doubleValue();
        double binMax = mds.getFeatures(1).maxNumber().doubleValue();
        assertTrue(binMin == 0.0 && binMax == 1.0);

        // ZEROS labels sum to exactly zero
        assertEquals(0.0, mds.getLabels(0).sumNumber().doubleValue(), 0.0);
    }
    assertEquals(5, batches);
}
Example 11
Source File: LoaderIteratorTests.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Tests MultiDataSetLoaderIterator with and without a random number generator.
 * <p>
 * Without an RNG, the paths must be visited in list order. With an RNG, the order may be
 * shuffled, but every path must still be visited exactly once. Previously the RNG was created
 * and then discarded ({@code null} was passed), so the random branch never exercised shuffling.
 */
@Test
public void testMDSLoaderIter() {
    for (boolean r : new boolean[]{false, true}) {
        List<String> l = Arrays.asList("3", "0", "1");
        Random rng = r ? new Random(12345) : null;
        MultiDataSetIterator iter = new MultiDataSetLoaderIterator(l, rng, new Loader<MultiDataSet>() {
            @Override
            public MultiDataSet load(Source source) throws IOException {
                INDArray i = Nd4j.scalar(Integer.valueOf(source.getPath()));
                return new org.nd4j.linalg.dataset.MultiDataSet(i, i);
            }
        }, new LocalFileSourceFactory());

        int count = 0;
        int[] exp = {3, 0, 1};
        int[] seen = new int[exp.length];
        while (iter.hasNext()) {
            MultiDataSet ds = iter.next();
            int value = ds.getFeatures()[0].getInt(0);
            if (!r) {
                assertEquals(exp[count], value);
            }
            if (count < seen.length) {
                seen[count] = value;
            }
            count++;
        }
        assertEquals(3, count);
        // Order-independent check: every path was loaded exactly once, shuffled or not
        Arrays.sort(seen);
        assertEquals("[0, 1, 3]", Arrays.toString(seen));
        iter.reset();
        assertTrue(iter.hasNext());
    }
}
Example 12
Source File: ScoreFlatMapFunctionCGMultiDataSet.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public Iterator<Tuple2<Long, Double>> call(Iterator<MultiDataSet> dataSetIterator) throws Exception { if (!dataSetIterator.hasNext()) { return Collections.singletonList(new Tuple2<>(0L, 0.0)).iterator(); } MultiDataSetIterator iter = new IteratorMultiDataSetIterator(dataSetIterator, minibatchSize); //Does batching where appropriate ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(json)); network.init(); INDArray val = params.value().unsafeDuplication(); //.value() is shared by all executors on single machine -> OK, as params are not changed in score function if (val.length() != network.numParams(false)) throw new IllegalStateException( "Network did not have same number of parameters as the broadcast set parameters"); network.setParams(val); List<Tuple2<Long, Double>> out = new ArrayList<>(); while (iter.hasNext()) { MultiDataSet ds = iter.next(); double score = network.score(ds, false); long numExamples = ds.getFeatures(0).size(0); out.add(new Tuple2<>(numExamples, score * numExamples)); } Nd4j.getExecutioner().commit(); return out.iterator(); }
Example 13
Source File: TestSparkComputationGraph.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testBasic() throws Exception { JavaSparkContext sc = this.sc; RecordReader rr = new CSVRecordReader(0, ','); rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive())); MultiDataSetIterator iter = new RecordReaderMultiDataSetIterator.Builder(1).addReader("iris", rr) .addInput("iris", 0, 3).addOutputOneHot("iris", 4, 3).build(); List<MultiDataSet> list = new ArrayList<>(150); while (iter.hasNext()) list.add(iter.next()); ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder() .updater(new Sgd(0.1)) .graphBuilder().addInputs("in") .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in").addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3) .build(), "dense") .setOutputs("out").build(); ComputationGraph cg = new ComputationGraph(config); cg.init(); TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0); SparkComputationGraph scg = new SparkComputationGraph(sc, cg, tm); scg.setListeners(Collections.singleton((TrainingListener) new ScoreIterationListener(5))); JavaRDD<MultiDataSet> rdd = sc.parallelize(list); scg.fitMultiDataSet(rdd); //Try: fitting using DataSet DataSetIterator iris = new IrisDataSetIterator(1, 150); List<DataSet> list2 = new ArrayList<>(); while (iris.hasNext()) list2.add(iris.next()); JavaRDD<DataSet> rddDS = sc.parallelize(list2); scg.fit(rddDS); }