Java Code Examples for org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator#reset()
The following examples show how to use
org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator#reset().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: TrainUtil.java From FancyBing with GNU General Public License v3.0 | 6 votes |
/**
 * Evaluates a {@code ComputationGraph} on every batch of {@code testData}, logging
 * classification and regression statistics plus inference timing.
 *
 * <p>Output 0 is scored with a top-N classification {@code Evaluation}. Output 1 is scored with
 * a single-column {@code RegressionEvaluation}. The iterator is reset before this method
 * returns, even if evaluation throws.
 *
 * @param model     model to evaluate; cast to {@code ComputationGraph} when the first batch is
 *                  processed
 * @param outputNum number of classes passed to {@code createLabels}
 * @param testData  evaluation data; reset on exit
 * @param topN      N used for top-N accuracy
 * @param batchSize batch size used only for the reported example count and average time
 *                  (NOTE(review): a short final batch makes this an over-estimate)
 * @return classification accuracy of output 0
 */
public static double evaluate(Model model, int outputNum, MultiDataSetIterator testData, int topN, int batchSize) {
    log.info("Evaluate model....");
    Evaluation clsEval = new Evaluation(createLabels(outputNum), topN);
    RegressionEvaluation valueRegEval1 = new RegressionEvaluation(1);
    int count = 0;
    long consume = 0;
    try {
        while (testData.hasNext()) {
            MultiDataSet ds = testData.next();
            long begin = System.nanoTime();
            INDArray[] output = ((ComputationGraph) model).output(false, ds.getFeatures());
            // Only the forward pass is timed; evaluation bookkeeping is excluded.
            consume += System.nanoTime() - begin;
            clsEval.eval(ds.getLabels(0), output[0]);
            valueRegEval1.eval(ds.getLabels(1), output[1]);
            count++;
        }
    } finally {
        // Leave the iterator reusable for the next caller even if evaluation failed.
        testData.reset();
    }

    String stats = clsEval.stats();
    int pos = stats.indexOf("===");
    // Trim everything before the first "===" marker; if the marker is absent, log the
    // full text instead of failing with substring(-1).
    stats = "\n" + (pos >= 0 ? stats.substring(pos) : stats);
    log.info(stats);
    log.info(valueRegEval1.stats());

    long examples = (long) count * batchSize;
    // consume is in nanoseconds; dividing by 1000 yields microseconds per example.
    // With no batches this is 0/0 and prints NaN, as before.
    log.info("Evaluate time: " + consume + " count: " + examples + " average: " + ((float) consume / examples / 1000));
    return clsEval.accuracy();
}
Example 2
Source File: AbstractMultiDataSetNormalizer.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Walks the entire iterator once, feeding each multi-dataset into {@code fitPartial}, and then
 * builds the feature statistics. Label statistics are built only when {@code isFitLabel()} is
 * {@code true}.
 *
 * <p>The iterator is reset before the walk starts, so any batches the caller already consumed
 * are revisited.
 *
 * @param iterator data to fit on; must not be {@code null}
 */
public void fit(@NonNull MultiDataSetIterator iterator) {
    List<S.Builder> featureBuilders = new ArrayList<>();
    List<S.Builder> labelBuilders = new ArrayList<>();

    iterator.reset();
    while (iterator.hasNext()) {
        fitPartial(iterator.next(), featureBuilders, labelBuilders);
    }

    featureStats = buildList(featureBuilders);
    if (isFitLabel()) {
        labelStats = buildList(labelBuilders);
    }
}
Example 3
Source File: AbstractMultiDataSetNormalizer.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Walks the entire iterator once, feeding each multi-dataset into {@code fitPartial}, and then
 * builds the feature statistics. Label statistics are built only when {@code isFitLabel()} is
 * {@code true}.
 *
 * <p>The iterator is reset before the walk starts, so any batches the caller already consumed
 * are revisited.
 *
 * @param iterator data to fit on; must not be {@code null}
 */
public void fit(@NonNull MultiDataSetIterator iterator) {
    List<S.Builder> featureBuilders = new ArrayList<>();
    List<S.Builder> labelBuilders = new ArrayList<>();

    iterator.reset();
    while (iterator.hasNext()) {
        fitPartial(iterator.next(), featureBuilders, labelBuilders);
    }

    featureStats = buildList(featureBuilders);
    if (isFitLabel()) {
        labelStats = buildList(labelBuilders);
    }
}
Example 4
Source File: Main.java From twse-captcha-solver-dl4j with MIT License | 5 votes |
/**
 * Runs {@code model} on every batch of {@code iterator} and decodes the predicted and true
 * captcha text for each example. Each comparison is logged, followed by totals of examples
 * seen and examples decoded exactly right. The iterator is reset before returning, even if
 * prediction fails.
 *
 * <p>Each element of the model's output array (and of the label array) is treated as one
 * character position. Its argmax indexes into {@code labelList}. NOTE(review): this assumes the
 * column order of each output matches {@code labelList} ("0"-"9" then "A"-"Z"); confirm against
 * the training label encoding.
 *
 * @param model    trained network to run
 * @param iterator validation data; reset on exit
 */
public static void modelPredict(ComputationGraph model, MultiDataSetIterator iterator) {
    List<String> labelList = Arrays.asList(
            "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
            "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M",
            "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z");
    int sumCount = 0;
    int correctCount = 0;
    try {
        while (iterator.hasNext()) {
            MultiDataSet mds = iterator.next();
            INDArray[] output = model.output(mds.getFeatures());
            INDArray[] labels = mds.getLabels();
            // Use the real row count of this batch rather than a fixed batch size.
            // Capping at a static size silently skipped examples whenever a batch was larger.
            int dataNum = output[0].rows();
            for (int dataIndex = 0; dataIndex < dataNum; dataIndex++) {
                String peLabel = decodeCaptchaRow(output, dataIndex, labelList);
                String reLabel = decodeCaptchaRow(labels, dataIndex, labelList);
                boolean correct = peLabel.equals(reLabel);
                if (correct) {
                    correctCount++;
                }
                sumCount++;
                logger.info("real image {} prediction {} status {}", reLabel, peLabel, correct);
            }
        }
    } finally {
        iterator.reset();
    }
    logger.info("validate result : sum count ={} correct count={}", sumCount, correctCount);
}

/**
 * Builds a string from row {@code row} of each array in {@code perPosition}. For each array,
 * the label at that row's argmax is appended, in array order.
 */
private static String decodeCaptchaRow(INDArray[] perPosition, int row, List<String> labelList) {
    StringBuilder text = new StringBuilder(perPosition.length);
    for (INDArray position : perPosition) {
        text.append(labelList.get(Nd4j.argMax(position.getRow(row), 1).getInt(0)));
    }
    return text.toString();
}
Example 5
Source File: MultiNormalizerHybrid.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Collects per-input and per-output normalization statistics from every multi-dataset the
 * iterator yields. The iterator is rewound first so the whole dataset is visited.
 *
 * @param iterator source of the data used for statistics; must not be {@code null}
 */
@Override
public void fit(@NonNull MultiDataSetIterator iterator) {
    Map<Integer, NormalizerStats.Builder> inputBuilders = new HashMap<>();
    Map<Integer, NormalizerStats.Builder> outputBuilders = new HashMap<>();

    for (iterator.reset(); iterator.hasNext(); ) {
        fitPartial(iterator.next(), inputBuilders, outputBuilders);
    }

    inputStats = buildAllStats(inputBuilders);
    outputStats = buildAllStats(outputBuilders);
}
Example 6
Source File: MultiNormalizerHybrid.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Collects per-input and per-output normalization statistics from every multi-dataset the
 * iterator yields. The iterator is rewound first so the whole dataset is visited.
 *
 * @param iterator source of the data used for statistics; must not be {@code null}
 */
@Override
public void fit(@NonNull MultiDataSetIterator iterator) {
    Map<Integer, NormalizerStats.Builder> inputBuilders = new HashMap<>();
    Map<Integer, NormalizerStats.Builder> outputBuilders = new HashMap<>();

    for (iterator.reset(); iterator.hasNext(); ) {
        fitPartial(iterator.next(), inputBuilders, outputBuilders);
    }

    inputStats = buildAllStats(inputBuilders);
    outputStats = buildAllStats(outputBuilders);
}
Example 7
Source File: EarlyTerminationMultiDataSetIteratorTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Verifies that an early-termination wrapper limited to one minibatch rejects a second
 * {@code next(int)} call, even after the wrapped iterator has been reset.
 */
@Test
public void testCallstoNextNotAllowed() throws IOException {
    final int maxMinibatches = 1;

    MultiDataSetIterator underlying =
            new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
    EarlyTerminationMultiDataSetIterator limited =
            new EarlyTerminationMultiDataSetIterator(underlying, maxMinibatches);

    // Draw the single permitted minibatch.
    limited.next(10);

    // Only the wrapped iterator is reset; the wrapper itself is left untouched.
    underlying.reset();

    // Register the expectation immediately before the call that must fail,
    // so the earlier calls cannot accidentally satisfy it.
    exception.expect(RuntimeException.class);
    limited.next(10);
}
Example 8
Source File: LoaderIteratorTests.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Checks that {@code MultiDataSetLoaderIterator} loads each path exactly once, in list order
 * when no RNG is given. With a seeded RNG it may visit the paths in any order. In both modes it
 * must have data again after {@code reset()}.
 */
@Test
public void testMDSLoaderIter() {
    for (boolean r : new boolean[]{false, true}) {
        List<String> l = Arrays.asList("3", "0", "1");
        // The RNG must actually be handed to the iterator. Otherwise the r == true pass
        // repeats the unshuffled case and the shuffled path is never exercised.
        Random rng = r ? new Random(12345) : null;
        MultiDataSetIterator iter = new MultiDataSetLoaderIterator(l, rng, new Loader<MultiDataSet>() {
            @Override
            public MultiDataSet load(Source source) throws IOException {
                INDArray i = Nd4j.scalar(Integer.valueOf(source.getPath()));
                return new org.nd4j.linalg.dataset.MultiDataSet(i, i);
            }
        }, new LocalFileSourceFactory());

        int count = 0;
        int[] exp = {3, 0, 1};
        int[] seen = new int[exp.length];
        while (iter.hasNext()) {
            MultiDataSet ds = iter.next();
            // Guard the array accesses below: too many elements is a test failure, not an AIOOBE.
            assertTrue(count < exp.length);
            int value = ds.getFeatures()[0].getInt(0);
            if (!r) {
                assertEquals(exp[count], value);
            }
            seen[count] = value;
            count++;
        }
        assertEquals(3, count);

        // Shuffled or not, every path must have been visited exactly once.
        int[] sortedSeen = seen.clone();
        int[] sortedExp = exp.clone();
        Arrays.sort(sortedSeen);
        Arrays.sort(sortedExp);
        assertEquals(Arrays.toString(sortedExp), Arrays.toString(sortedSeen));

        iter.reset();
        assertTrue(iter.hasNext());
    }
}