org.nd4j.linalg.dataset.api.DataSet Java Examples
The following examples show how to use
org.nd4j.linalg.dataset.api.DataSet.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: ImageFlatteningDataSetPreProcessor.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Override public void preProcess(DataSet toPreProcess) { INDArray input = toPreProcess.getFeatures(); if (input.rank() == 2) return; //No op: should usually never happen in a properly configured data pipeline //Assume input is standard rank 4 activations - i.e., CNN image data //First: we require input to be in c order. But c order (as declared in array order) isn't enough; also need strides to be correct if (input.ordering() != 'c' || !Shape.strideDescendingCAscendingF(input)) input = input.dup('c'); val inShape = input.shape(); //[miniBatch,depthOut,outH,outW] val outShape = new long[] {inShape[0], inShape[1] * inShape[2] * inShape[3]}; INDArray reshaped = input.reshape('c', outShape); toPreProcess.setFeatures(reshaped); }
Example #2
Source File: ImageFlatteningDataSetPreProcessor.java From nd4j with Apache License 2.0 | 6 votes |
@Override public void preProcess(DataSet toPreProcess) { INDArray input = toPreProcess.getFeatures(); if (input.rank() == 2) return; //No op: should usually never happen in a properly configured data pipeline //Assume input is standard rank 4 activations - i.e., CNN image data //First: we require input to be in c order. But c order (as declared in array order) isn't enough; also need strides to be correct if (input.ordering() != 'c' || !Shape.strideDescendingCAscendingF(input)) input = input.dup('c'); val inShape = input.shape(); //[miniBatch,depthOut,outH,outW] val outShape = new long[] {inShape[0], inShape[1] * inShape[2] * inShape[3]}; INDArray reshaped = input.reshape('c', outShape); toPreProcess.setFeatures(reshaped); }
Example #3
Source File: TestRecordReaders.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * A label index outside the configured number of classes must be rejected when one-hot encoding.
 */
@Test
public void testClassIndexOutsideOfRangeRRDSI() {
    Collection<Collection<Writable>> c = new ArrayList<>();
    c.add(Arrays.<Writable>asList(new DoubleWritable(0.5), new IntWritable(0)));
    // NOTE(review): presumably out of range for the 2 classes configured below (valid indices 0 and 1)
    c.add(Arrays.<Writable>asList(new DoubleWritable(1.0), new IntWritable(2)));
    CollectionRecordReader crr = new CollectionRecordReader(c);
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(crr, 2, 1, 2);

    try {
        iter.next();
        fail("Expected exception");
    } catch (Exception e) {
        // Null-safe: a message-less exception should fail the assertion, not raise an NPE inside it
        String msg = e.getMessage();
        assertTrue("Unexpected exception: " + e, msg != null && msg.contains("to one-hot"));
    }
}
Example #4
Source File: ParameterServerTrainer.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Override public void feedDataSet(@NonNull DataSet dataSet, long time) { // FIXME: this is wrong, and should be fixed. Training should happen within run() loop if (getModel() instanceof ComputationGraph) { ComputationGraph computationGraph = (ComputationGraph) getModel(); computationGraph.fit(dataSet); } else { MultiLayerNetwork multiLayerNetwork = (MultiLayerNetwork) getModel(); log.info("Calling fit on multi layer network"); multiLayerNetwork.fit(dataSet); } log.info("About to send params in"); //send the updated params parameterServerClient.pushNDArray(getModel().params()); log.info("Sent params"); }
Example #5
Source File: DoubleDQNTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void when_isTerminal_expect_rewardValueAtIdx0() { // Assemble when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]); List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { { add(builtTransition(buildObservation(new double[]{1.1, 2.2}), 0, 1.0, true, buildObservation(new double[]{11.0, 22.0}))); } }; DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.compute(transitions); // Assert INDArray evaluatedQValues = result.getLabels(); assertEquals(1.0, evaluatedQValues.getDouble(0, 0), 0.0001); assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001); }
Example #6
Source File: ComputationGraphUtil.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Wraps a single-input/single-output {@link DataSet} as a {@link MultiDataSet}.
 *
 * <p>Each of features, labels, feature mask and label mask becomes a one-element array, or
 * {@code null} when the corresponding DataSet entry is {@code null}. Example metadata is carried over.
 *
 * @param dataSet the DataSet to convert
 * @return an equivalent MultiDataSet
 */
public static MultiDataSet toMultiDataSet(DataSet dataSet) {
    INDArray features = dataSet.getFeatures();
    INDArray labels = dataSet.getLabels();
    INDArray featuresMask = dataSet.getFeaturesMaskArray();
    INDArray labelsMask = dataSet.getLabelsMaskArray();
    List<Serializable> metaData = dataSet.getExampleMetaData();

    org.nd4j.linalg.dataset.MultiDataSet result = new org.nd4j.linalg.dataset.MultiDataSet(
            features != null ? new INDArray[] {features} : null,
            labels != null ? new INDArray[] {labels} : null,
            featuresMask != null ? new INDArray[] {featuresMask} : null,
            labelsMask != null ? new INDArray[] {labelsMask} : null);
    result.setExampleMetaData(metaData);
    return result;
}
Example #7
Source File: RecordConverter.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Builds one record per example: the example's feature values (via {@code toRecord}) followed by
 * each of its label values as a {@link DoubleWritable}.
 *
 * @param dataSet dataset to convert
 * @return list of records, one per example
 */
private static List<List<Writable>> getRegressionWritableMatrix(DataSet dataSet) {
    INDArray features = dataSet.getFeatures();
    INDArray labels = dataSet.getLabels();
    int numExamples = dataSet.numExamples();

    List<List<Writable>> writableMatrix = new ArrayList<>(numExamples);
    for (int i = 0; i < numExamples; i++) {
        List<Writable> writables = toRecord(features.getRow(i));
        INDArray labelRow = labels.getRow(i);
        // length() is correct whether getRow returns a [1, n] row vector or a rank-1 [n] vector;
        // shape()[1] throws for the rank-1 case.
        for (int j = 0; j < labelRow.length(); j++) {
            writables.add(new DoubleWritable(labelRow.getDouble(j)));
        }
        writableMatrix.add(writables);
    }
    return writableMatrix;
}
Example #8
Source File: CropAndResizeDataSetPreProcessor.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Replaces the features of {@code dataSet} with the output of the native "crop_and_resize" op,
 * driven by this pre-processor's {@code boxes}, {@code indices}, {@code resize} and {@code method}
 * fields. Empty datasets are left untouched.
 *
 * NOTE: The data format must be NHWC
 *
 * @param dataSet dataset to pre-process; must not be null
 */
@Override
public void preProcess(DataSet dataSet) {
    Preconditions.checkNotNull(dataSet, "Encountered null dataSet");

    if (!dataSet.isEmpty()) {
        INDArray features = dataSet.getFeatures();
        // Output buffer with the configured resized shape and the input's data type
        INDArray resized = Nd4j.create(LongShapeDescriptor.fromShape(resizedShape, features.dataType()), false);

        CustomOp cropAndResize = DynamicCustomOp.builder("crop_and_resize")
                .addInputs(features, boxes, indices, resize)
                .addIntegerArguments(method)
                .addOutputs(resized)
                .build();
        Nd4j.getExecutioner().exec(cropAndResize);

        dataSet.setFeatures(resized);
    }
}
Example #9
Source File: CompositeDataSetPreProcessor.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Applies every configured pre-processor to {@code dataSet} in order. When
 * {@code stopOnEmptyDataSet} is set, processing stops as soon as the dataset is (or becomes) empty.
 *
 * @param dataSet dataset to pre-process; must not be null
 */
@Override
public void preProcess(DataSet dataSet) {
    Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
    if (stopOnEmptyDataSet && dataSet.isEmpty()) {
        return;
    }

    for (DataSetPreProcessor preProcessor : preProcessors) {
        preProcessor.preProcess(dataSet);
        boolean becameEmpty = stopOnEmptyDataSet && dataSet.isEmpty();
        if (becameEmpty) {
            break;
        }
    }
}
Example #10
Source File: DefaultCallback.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Requests that every non-null array of {@code dataSet} (features, labels, feature mask, label mask)
 * be located on the device via the affinity manager. A null dataset is ignored.
 *
 * @param dataSet dataset whose arrays should be moved; may be null
 */
@Override
public void call(DataSet dataSet) {
    if (dataSet == null) {
        return;
    }

    INDArray[] arrays = {
            dataSet.getFeatures(),
            dataSet.getLabels(),
            dataSet.getFeaturesMaskArray(),
            dataSet.getLabelsMaskArray()
    };
    for (INDArray array : arrays) {
        if (array != null) {
            Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.DEVICE);
        }
    }
}
Example #11
Source File: RecordConverter.java From DataVec with Apache License 2.0 | 5 votes |
/**
 * Converts a DataSet into a list of records (a "matrix" of writables), dispatching on whether the
 * dataset is classed as a classification dataset.
 *
 * @param dataSet the DataSet to convert
 * @return the matrix for the records
 */
public static List<List<Writable>> toRecords(DataSet dataSet) {
    return isClassificationDataSet(dataSet)
            ? getClassificationWritableMatrix(dataSet)
            : getRegressionWritableMatrix(dataSet);
}
Example #12
Source File: UnderSamplingByMaskingPreProcessor.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Replaces the label mask of {@code toPreProcess} with one adjusted by {@code adjustMasks} for the
 * configured minority label and target minority distribution.
 *
 * @param toPreProcess dataset whose label mask is replaced
 */
@Override
public void preProcess(DataSet toPreProcess) {
    INDArray adjustedMask = adjustMasks(toPreProcess.getLabels(), toPreProcess.getLabelsMaskArray(),
            minorityLabel, targetMinorityDist);
    toPreProcess.setLabelsMaskArray(adjustedMask);
}
Example #13
Source File: JointParallelDataSetIteratorTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Checks RELOCATE handling with two sources of unequal length: the expected sequence of
 * example values is verified for every dataset, along with the total counts.
 *
 * @throws Exception on iterator failure
 */
@Test
public void testJointIterator3() throws Exception {
    DataSetIterator iteratorA = new SimpleVariableGenerator(119, 200, 32, 100, 10);
    DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10);

    JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.RELOCATE)
            .addSourceIterator(iteratorA)
            .addSourceIterator(iteratorB)
            .build();

    int cnt = 0;
    int example = 0;
    while (jpdsi.hasNext()) {
        DataSet ds = jpdsi.next();
        String failMsg = "Failed on iteration " + cnt;
        assertNotNull(failMsg, ds);
        assertEquals(failMsg, (double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001);
        assertEquals(failMsg, (double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001);

        cnt++;
        // Expected progression: for the first 200 datasets each example value appears twice,
        // afterwards the example value advances on every dataset.
        if (cnt >= 200 || cnt % 2 == 0) {
            example++;
        }
    }

    assertEquals(300, cnt);
    assertEquals(200, example);
}
Example #14
Source File: EvaluationToolsTests.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testRocMultiToHtml() throws Exception { DataSetIterator iter = new IrisDataSetIterator(150, 150); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build()).layer(1, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); NormalizerStandardize ns = new NormalizerStandardize(); DataSet ds = iter.next(); ns.fit(ds); ns.transform(ds); for (int i = 0; i < 30; i++) { net.fit(ds); } for (int numSteps : new int[] {20, 0}) { ROCMultiClass roc = new ROCMultiClass(numSteps); iter.reset(); INDArray f = ds.getFeatures(); INDArray l = ds.getLabels(); INDArray out = net.output(f); roc.eval(l, out); String str = EvaluationTools.rocChartToHtml(roc, Arrays.asList("setosa", "versicolor", "virginica")); // System.out.println(str); } }
Example #15
Source File: JointParallelDataSetIteratorTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
/** * Simple test, checking datasets alignment. They all should have the same data for the same cycle * * * @throws Exception */ @Test public void testJointIterator1() throws Exception { DataSetIterator iteratorA = new SimpleVariableGenerator(119, 100, 32, 100, 10); DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10); JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.STOP_EVERYONE) .addSourceIterator(iteratorA).addSourceIterator(iteratorB).build(); int cnt = 0; int example = 0; while (jpdsi.hasNext()) { DataSet ds = jpdsi.next(); assertNotNull("Failed on iteration " + cnt, ds); // ds.detach(); //ds.migrate(); assertEquals("Failed on iteration " + cnt, (double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001); assertEquals("Failed on iteration " + cnt, (double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001); cnt++; if (cnt % 2 == 0) example++; } assertEquals(100, example); assertEquals(200, cnt); }
Example #16
Source File: RPTreeTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Every leaf of every tree in the forest must hold at most {@code getMaxSize()} indices.
 */
@Test
public void testRpTreeMaxNodes() throws Exception {
    DataSetIterator mnist = new MnistDataSetIterator(150, 150);
    RPForest rpForest = new RPForest(4, 4, "euclidean");
    DataSet d = mnist.next();

    NormalizerStandardize normalizerStandardize = new NormalizerStandardize();
    normalizerStandardize.fit(d);
    // Apply the fitted statistics (previously the normalizer was fit but never used)
    normalizerStandardize.transform(d);

    rpForest.fit(d.getFeatures());
    for (RPTree tree : rpForest.getTrees()) {
        for (RPNode node : tree.getLeaves()) {
            assertTrue(node.getIndices().size() <= rpForest.getMaxSize());
        }
    }
}
Example #17
Source File: DummyBlockDataSetIteratorTests.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Requests blocks of up to 3 datasets and checks the block sizes (3, 3, then 2) and that the
 * collected datasets carry sequential feature/label values.
 */
@Test
public void testBlock_1() throws Exception {
    val simpleIterator = new SimpleVariableGenerator(123, 8, 3, 3, 3);
    val iterator = new DummyBlockDataSetIterator(simpleIterator);
    assertTrue(iterator.hasAnything());

    val list = new ArrayList<DataSet>(8);
    for (int expectedSize : new int[] {3, 3, 2}) {
        val datasets = iterator.next(3);
        assertNotNull(datasets);
        assertEquals(expectedSize, datasets.length);
        list.addAll(Arrays.asList(datasets));
    }

    for (int e = 0; e < list.size(); e++) {
        val dataset = list.get(e);
        assertEquals(e, (int) dataset.getFeatures().getDouble(0));
        assertEquals(e + 0.5, dataset.getLabels().getDouble(0), 1e-3);
    }
}
Example #18
Source File: AbstractDataSetNormalizer.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Computes normalization statistics from this single dataset only. Feature statistics are always
 * recomputed; label statistics are recomputed only when {@code isFitLabel()} is true and are
 * otherwise left as they were.
 *
 * @param dataSet the dataset to compute on
 */
@Override
public void fit(DataSet dataSet) {
    featureStats = (S) newBuilder().addFeatures(dataSet).build();
    if (!isFitLabel()) {
        return;
    }
    labelStats = (S) newBuilder().addLabels(dataSet).build();
}
Example #19
Source File: DummyBlockDataSetIterator.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Pulls up to {@code maxDatasets} datasets from the wrapped iterator.
 *
 * @param maxDatasets maximum number of datasets to return
 * @return the datasets fetched; shorter than {@code maxDatasets} if the iterator ran out
 */
@Override
public DataSet[] next(int maxDatasets) {
    val collected = new ArrayList<DataSet>(maxDatasets);
    while (iterator.hasNext() && collected.size() < maxDatasets) {
        collected.add(iterator.next());
    }
    return collected.toArray(new DataSet[0]);
}
Example #20
Source File: LossLayer.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Return predicted label names
 *
 * @param dataSet to predict
 * @return the predicted label names, in the same order as the class indices returned by
 *         {@code predict(dataSet.getFeatures())}
 */
@Override
public List<String> predict(DataSet dataSet) {
    int[] intRet = predict(dataSet.getFeatures());
    List<String> ret = new ArrayList<>(intRet.length);
    for (int i : intRet) {
        // Append: i is the predicted class index, not the output position. ret.add(i, ...) would
        // insert at position i, reordering results or throwing IndexOutOfBoundsException.
        ret.add(dataSet.getLabelName(i));
    }
    return ret;
}
Example #21
Source File: StandardDQNTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void when_batchHasMoreThanOne_expect_everySampleEvaluated() { // Assemble List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { { add(buildTransition(buildObservation(new double[]{1.1, 2.2}), 0, 1.0, false, buildObservation(new double[]{11.0, 22.0}))); add(buildTransition(buildObservation(new double[]{3.3, 4.4}), 1, 2.0, false, buildObservation(new double[]{33.0, 44.0}))); add(buildTransition(buildObservation(new double[]{5.5, 6.6}), 0, 3.0, true, buildObservation(new double[]{55.0, 66.0}))); } }; StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.compute(transitions); // Assert INDArray evaluatedQValues = result.getLabels(); assertEquals((1.0 + 0.5 * 22.0), evaluatedQValues.getDouble(0, 0), 0.0001); assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001); assertEquals(3.3, evaluatedQValues.getDouble(1, 0), 0.0001); assertEquals((2.0 + 0.5 * 44.0), evaluatedQValues.getDouble(1, 1), 0.0001); assertEquals(3.0, evaluatedQValues.getDouble(2, 0), 0.0001); // terminal: reward only assertEquals(6.6, evaluatedQValues.getDouble(2, 1), 0.0001); }
Example #22
Source File: TestRecordReaders.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * With separate feature and label sequence readers, a label index outside the configured number of
 * classes must be rejected when one-hot encoding.
 */
@Test
public void testClassIndexOutsideOfRangeRRMDSI_MultipleReaders() {
    // Feature sequences
    Collection<Collection<Collection<Writable>>> c1 = new ArrayList<>();
    Collection<Collection<Writable>> seq1 = new ArrayList<>();
    seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    c1.add(seq1);
    Collection<Collection<Writable>> seq2 = new ArrayList<>();
    seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    c1.add(seq2);

    // Label sequences; the final label (2) is presumably out of range for the 2 classes configured below
    Collection<Collection<Collection<Writable>>> c2 = new ArrayList<>();
    Collection<Collection<Writable>> seq1a = new ArrayList<>();
    seq1a.add(Arrays.<Writable>asList(new IntWritable(0)));
    seq1a.add(Arrays.<Writable>asList(new IntWritable(1)));
    c2.add(seq1a);
    Collection<Collection<Writable>> seq2a = new ArrayList<>();
    seq2a.add(Arrays.<Writable>asList(new IntWritable(0)));
    seq2a.add(Arrays.<Writable>asList(new IntWritable(2)));
    c2.add(seq2a);

    CollectionSequenceRecordReader csrr = new CollectionSequenceRecordReader(c1);
    CollectionSequenceRecordReader csrrLabels = new CollectionSequenceRecordReader(c2);
    DataSetIterator dsi = new SequenceRecordReaderDataSetIterator(csrr, csrrLabels, 2, 2);

    try {
        dsi.next();
        fail("Expected exception");
    } catch (Exception e) {
        // Null-safe: a message-less exception should fail the assertion, not raise an NPE inside it
        String msg = e.getMessage();
        assertTrue("Unexpected exception: " + e, msg != null && msg.contains("to one-hot"));
    }
}
Example #23
Source File: RGBtoGrayscaleDataSetPreProcessor.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public void preProcess(DataSet dataSet) { Preconditions.checkNotNull(dataSet, "Encountered null dataSet"); if(dataSet.isEmpty()) { return; } INDArray originalFeatures = dataSet.getFeatures(); long[] originalShape = originalFeatures.shape(); // result shape is NHW INDArray result = Nd4j.create(originalShape[0], originalShape[2], originalShape[3]); for(long n = 0, numExamples = originalShape[0]; n < numExamples; ++n) { // Extract channels INDArray itemFeatures = originalFeatures.slice(n, 0); // shape is CHW INDArray R = itemFeatures.slice(0, 0); // shape is HW INDArray G = itemFeatures.slice(1, 0); INDArray B = itemFeatures.slice(2, 0); // Convert R.muli(RED_RATIO); G.muli(GREEN_RATIO); B.muli(BLUE_RATIO); R.addi(G).addi(B); result.putSlice((int)n, R); } dataSet.setFeatures(result); }
Example #24
Source File: RPTreeTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * After fitting a random-projection forest on min-max scaled MNIST features, querying with one of
 * the first ten training points must return that point as its nearest neighbour.
 */
@Test
public void testFindSelf() throws Exception {
    DataSetIterator mnist = new MnistDataSetIterator(100, 6000);
    NormalizerMinMaxScaler minMaxNormalizer = new NormalizerMinMaxScaler(0, 1);
    minMaxNormalizer.fit(mnist);

    DataSet d = mnist.next();
    minMaxNormalizer.transform(d.getFeatures());

    RPForest rpForest = new RPForest(100, 100, "euclidean");
    rpForest.fit(d.getFeatures());

    for (int row = 0; row < 10; row++) {
        INDArray neighbours = rpForest.queryAll(d.getFeatures().slice(row), 10);
        assertEquals(row, neighbours.getInt(0));
    }
}
Example #25
Source File: UnderSamplingByMaskingPreProcessor.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Replaces the label mask of {@code toPreProcess} with one adjusted by {@code adjustMasks} for the
 * configured minority label and target minority distribution.
 *
 * @param toPreProcess dataset whose label mask is replaced
 */
@Override
public void preProcess(DataSet toPreProcess) {
    INDArray adjustedMask = adjustMasks(toPreProcess.getLabels(), toPreProcess.getLabelsMaskArray(),
            minorityLabel, targetMinorityDist);
    toPreProcess.setLabelsMaskArray(adjustedMask);
}
Example #26
Source File: BaseOutputLayer.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Return predicted label names
 *
 * @param dataSet to predict
 * @return the predicted label names, in the same order as the class indices returned by
 *         {@code predict(dataSet.getFeatures())}
 */
@Override
public List<String> predict(DataSet dataSet) {
    int[] intRet = predict(dataSet.getFeatures());
    List<String> ret = new ArrayList<>(intRet.length);
    for (int i : intRet) {
        // Append: i is the predicted class index, not the output position. ret.add(i, ...) would
        // insert at position i, reordering results or throwing IndexOutOfBoundsException.
        ret.add(dataSet.getLabelName(i));
    }
    return ret;
}
Example #27
Source File: SharedTrainingWorker.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Not supported: this worker does not train through per-minibatch calls.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public SharedTrainingResult processMinibatch(DataSet dataSet, MultiLayerNetwork network, boolean isLast) {
    /*
        We're not really going to use this method for training.
        Partitions will be mapped to ParallelWorker threads dynamically, wrt thread/device affinity.
        So plan is simple: we're going to use individual partitions to feed main worker
     */
    throw new UnsupportedOperationException(
            "SharedTrainingWorker does not support processMinibatch: partitions feed ParallelWorker threads instead");
}
Example #28
Source File: TransformProcessTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Wraps a single integer as a one-element feature array in a DataSet with no labels.
 *
 * @param input value to wrap; unboxed, so it must not be null
 * @return DataSet with the value as its only feature
 */
@Override
public DataSet transform(Integer input) {
    double[] values = {input};
    return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(values), null);
}
Example #29
Source File: AbstractDataSetNormalizer.java From nd4j with Apache License 2.0 | 4 votes |
/**
 * Undoes a previous transform on {@code data}. Features are reverted first, then labels, each
 * together with its mask array.
 *
 * @param data the dataset to revert back
 */
@Override
public void revert(DataSet data) {
    revertFeatures(
            data.getFeatures(),
            data.getFeaturesMaskArray());
    revertLabels(
            data.getLabels(),
            data.getLabelsMaskArray());
}
Example #30
Source File: SameDiffOutputLayer.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Label-name prediction is not available for this layer.
 *
 * @param dataSet ignored
 * @return never returns normally
 * @throws UnsupportedOperationException always
 */
@Override
public List<String> predict(DataSet dataSet) {
    throw new UnsupportedOperationException(
            "Not supported");
}