Java Code Examples for org.nd4j.linalg.dataset.api.DataSet#getLabels()
The following examples show how to use
`org.nd4j.linalg.dataset.api.DataSet#getLabels()`.
You can vote up the examples you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may also check out related API usage in the sidebar.
Example 1
Source File: StandardDQNTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void when_isTerminal_expect_rewardValueAtIdx0() { // Assemble List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { { add(buildTransition(buildObservation(new double[]{1.1, 2.2}), 0, 1.0, true, buildObservation(new double[]{11.0, 22.0}))); } }; StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.compute(transitions); // Assert INDArray evaluatedQValues = result.getLabels(); assertEquals(1.0, evaluatedQValues.getDouble(0, 0), 0.0001); assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001); }
Example 2
Source File: StandardDQNTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() { // Assemble List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { { add(buildTransition(buildObservation(new double[]{1.1, 2.2}), 0, 1.0, false, buildObservation(new double[]{11.0, 22.0}))); } }; StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.compute(transitions); // Assert INDArray evaluatedQValues = result.getLabels(); assertEquals(1.0 + 0.5 * 22.0, evaluatedQValues.getDouble(0, 0), 0.0001); assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001); }
Example 3
Source File: DoubleDQNTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void when_isTerminal_expect_rewardValueAtIdx0() { // Assemble when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]); List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { { add(builtTransition(buildObservation(new double[]{1.1, 2.2}), 0, 1.0, true, buildObservation(new double[]{11.0, 22.0}))); } }; DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.compute(transitions); // Assert INDArray evaluatedQValues = result.getLabels(); assertEquals(1.0, evaluatedQValues.getDouble(0, 0), 0.0001); assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001); }
Example 4
Source File: DoubleDQNTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() { // Assemble when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0)); List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { { add(builtTransition(buildObservation(new double[]{1.1, 2.2}), 0, 1.0, false, buildObservation(new double[]{11.0, 22.0}))); } }; DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.compute(transitions); // Assert INDArray evaluatedQValues = result.getLabels(); assertEquals(1.0 + 0.5 * -22.0, evaluatedQValues.getDouble(0, 0), 0.0001); assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001); }
Example 5
Source File: DefaultCallback.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Asks the affinity manager to ensure every non-null array held by {@code dataSet}
 * (features, labels, features mask, labels mask) is located on the DEVICE.
 * A null {@code dataSet} is ignored.
 */
@Override
public void call(DataSet dataSet) {
    if (dataSet == null) {
        return;
    }

    AffinityManager manager = Nd4j.getAffinityManager();
    AffinityManager.Location target = AffinityManager.Location.DEVICE;

    if (dataSet.getFeatures() != null) {
        manager.ensureLocation(dataSet.getFeatures(), target);
    }
    if (dataSet.getLabels() != null) {
        manager.ensureLocation(dataSet.getLabels(), target);
    }
    if (dataSet.getFeaturesMaskArray() != null) {
        manager.ensureLocation(dataSet.getFeaturesMaskArray(), target);
    }
    if (dataSet.getLabelsMaskArray() != null) {
        manager.ensureLocation(dataSet.getLabelsMaskArray(), target);
    }
}
Example 6
Source File: ComputationGraphUtil.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Converts a {@link DataSet} into the equivalent single-input, single-output MultiDataSet.
 *
 * <p>Features, labels and both mask arrays are each wrapped in a one-element array; any of
 * them that is null stays null. The example metadata is copied across by reference.
 *
 * @param dataSet the data set to convert
 * @return a new {@code org.nd4j.linalg.dataset.MultiDataSet} holding the same arrays
 */
public static MultiDataSet toMultiDataSet(DataSet dataSet) {
    INDArray[] features = wrapOrNull(dataSet.getFeatures());
    INDArray[] labels = wrapOrNull(dataSet.getLabels());
    INDArray[] featureMasks = wrapOrNull(dataSet.getFeaturesMaskArray());
    INDArray[] labelMasks = wrapOrNull(dataSet.getLabelsMaskArray());

    org.nd4j.linalg.dataset.MultiDataSet converted =
            new org.nd4j.linalg.dataset.MultiDataSet(features, labels, featureMasks, labelMasks);
    converted.setExampleMetaData(dataSet.getExampleMetaData());
    return converted;
}

/** Returns {@code array} wrapped in a one-element array, or null when {@code array} is null. */
private static INDArray[] wrapOrNull(INDArray array) {
    if (array == null) {
        return null;
    }
    return new INDArray[] {array};
}
Example 7
Source File: UnderSamplingByMaskingPreProcessor.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Replaces the labels mask of {@code toPreProcess} with the mask returned by
 * {@code adjustMasks(labels, currentLabelsMask, minorityLabel, targetMinorityDist)}.
 * (Per the class name, this presumably under-samples the majority class by masking — confirm
 * against adjustMasks.)
 */
@Override
public void preProcess(DataSet toPreProcess) {
    toPreProcess.setLabelsMaskArray(
            adjustMasks(toPreProcess.getLabels(), toPreProcess.getLabelsMaskArray(),
                    minorityLabel, targetMinorityDist));
}
Example 8
Source File: StandardDQNTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void when_batchHasMoreThanOne_expect_everySampleEvaluated() { // Assemble List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { { add(buildTransition(buildObservation(new double[]{1.1, 2.2}), 0, 1.0, false, buildObservation(new double[]{11.0, 22.0}))); add(buildTransition(buildObservation(new double[]{3.3, 4.4}), 1, 2.0, false, buildObservation(new double[]{33.0, 44.0}))); add(buildTransition(buildObservation(new double[]{5.5, 6.6}), 0, 3.0, true, buildObservation(new double[]{55.0, 66.0}))); } }; StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.compute(transitions); // Assert INDArray evaluatedQValues = result.getLabels(); assertEquals((1.0 + 0.5 * 22.0), evaluatedQValues.getDouble(0, 0), 0.0001); assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001); assertEquals(3.3, evaluatedQValues.getDouble(1, 0), 0.0001); assertEquals((2.0 + 0.5 * 44.0), evaluatedQValues.getDouble(1, 1), 0.0001); assertEquals(3.0, evaluatedQValues.getDouble(2, 0), 0.0001); // terminal: reward only assertEquals(6.6, evaluatedQValues.getDouble(2, 1), 0.0001); }
Example 9
Source File: DoubleDQNTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void when_batchHasMoreThanOne_expect_everySampleEvaluated() { // Assemble when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0)); List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { { add(builtTransition(buildObservation(new double[]{1.1, 2.2}), 0, 1.0, false, buildObservation(new double[]{11.0, 22.0}))); add(builtTransition(buildObservation(new double[]{3.3, 4.4}), 1, 2.0, false, buildObservation(new double[]{33.0, 44.0}))); add(builtTransition(buildObservation(new double[]{5.5, 6.6}), 0, 3.0, true, buildObservation(new double[]{55.0, 66.0}))); } }; DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.compute(transitions); // Assert INDArray evaluatedQValues = result.getLabels(); assertEquals(1.0 + 0.5 * -22.0, evaluatedQValues.getDouble(0, 0), 0.0001); assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001); assertEquals(3.3, evaluatedQValues.getDouble(1, 0), 0.0001); assertEquals(2.0 + 0.5 * -44.0, evaluatedQValues.getDouble(1, 1), 0.0001); assertEquals(3.0, evaluatedQValues.getDouble(2, 0), 0.0001); // terminal: reward only assertEquals(6.6, evaluatedQValues.getDouble(2, 1), 0.0001); }
Example 10
Source File: UnderSamplingByMaskingPreProcessor.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Replaces the labels mask of {@code toPreProcess} with the mask returned by
 * {@code adjustMasks(labels, currentLabelsMask, minorityLabel, targetMinorityDist)}.
 * (Per the class name, this presumably under-samples the majority class by masking — confirm
 * against adjustMasks.)
 */
@Override
public void preProcess(DataSet toPreProcess) {
    toPreProcess.setLabelsMaskArray(
            adjustMasks(toPreProcess.getLabels(), toPreProcess.getLabelsMaskArray(),
                    minorityLabel, targetMinorityDist));
}
Example 11
Source File: LabelLastTimeStepPreProcessor.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public void preProcess(DataSet toPreProcess) { INDArray label3d = toPreProcess.getLabels(); Preconditions.checkState(label3d.rank() == 3, "LabelLastTimeStepPreProcessor expects rank 3 labels, got rank %s labels with shape %ndShape", label3d.rank(), label3d); INDArray lMask = toPreProcess.getLabelsMaskArray(); //If no mask: assume that examples for each minibatch are all same length INDArray labels2d; if(lMask == null){ labels2d = label3d.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(label3d.size(2)-1)).dup(); } else { //Use the label mask to work out the last time step... INDArray lastIndex = BooleanIndexing.lastIndex(lMask, Conditions.greaterThan(0), 1); long[] idxs = lastIndex.data().asLong(); //Now, extract out: labels2d = Nd4j.create(DataType.FLOAT, label3d.size(0), label3d.size(1)); //Now, get and assign the corresponding subsets of 3d activations: for (int i = 0; i < idxs.length; i++) { long lastStepIdx = idxs[i]; Preconditions.checkState(lastStepIdx >= 0, "Invalid last time step index: example %s in minibatch is entirely masked out" + " (label mask is all 0s, meaning no label data is present for this example)", i); //TODO can optimize using reshape + pullRows labels2d.putRow(i, label3d.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(lastStepIdx))); } } toPreProcess.setLabels(labels2d); toPreProcess.setLabelsMaskArray(null); //Remove label mask if present }
Example 12
Source File: ImagePreProcessingScaler.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Applies {@link #preProcess(INDArray)} to the features of {@code toPreProcess}.
 * When {@code fitLabels} is set and labels are present, it also applies it to the labels.
 */
@Override
public void preProcess(DataSet toPreProcess) {
    preProcess(toPreProcess.getFeatures());

    if (!fitLabels) {
        return;
    }
    INDArray labels = toPreProcess.getLabels();
    if (labels != null) {
        preProcess(labels);
    }
}
Example 13
Source File: EvaluationToolsTests.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testRocMultiToHtml() throws Exception { DataSetIterator iter = new IrisDataSetIterator(150, 150); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build()).layer(1, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); NormalizerStandardize ns = new NormalizerStandardize(); DataSet ds = iter.next(); ns.fit(ds); ns.transform(ds); for (int i = 0; i < 30; i++) { net.fit(ds); } for (int numSteps : new int[] {20, 0}) { ROCMultiClass roc = new ROCMultiClass(numSteps); iter.reset(); INDArray f = ds.getFeatures(); INDArray l = ds.getLabels(); INDArray out = net.output(f); roc.eval(l, out); String str = EvaluationTools.rocChartToHtml(roc, Arrays.asList("setosa", "versicolor", "virginica")); // System.out.println(str); } }
Example 14
Source File: RecordConverter.java From DataVec with Apache License 2.0 | 4 votes |
/**
 * Heuristically decides whether the labels of {@code dataSet} are one-hot class labels.
 *
 * <p>Returns true only for a non-empty rank-2 label matrix that meets all of these:
 * <ul>
 *   <li>it has more than one column;</li>
 *   <li>no entry is negative;</li>
 *   <li>every row has a maximum of exactly 1 and a sum of exactly 1.</li>
 * </ul>
 * Together these conditions imply each row is one-hot.
 *
 * <p>The previous check compared only the total sum of all labels with the number of
 * examples. It therefore also accepted non-one-hot matrices, such as regression targets,
 * whose entries happened to add up to that count. It also failed on null or rank-1 labels
 * instead of returning false.
 *
 * @param dataSet the data set to inspect
 * @return true if the labels look like one-hot classification labels
 */
private static boolean isClassificationDataSet(DataSet dataSet) {
    INDArray labels = dataSet.getLabels();
    if (labels == null || labels.rank() != 2 || labels.size(0) == 0 || labels.size(1) <= 1) {
        return false;
    }
    if (labels.minNumber().doubleValue() < 0.0) {
        return false;
    }
    INDArray rowSums = labels.sum(1);
    INDArray rowMaxes = labels.max(1);
    return rowSums.minNumber().doubleValue() == 1.0
            && rowSums.maxNumber().doubleValue() == 1.0
            && rowMaxes.minNumber().doubleValue() == 1.0
            && rowMaxes.maxNumber().doubleValue() == 1.0;
}
Example 15
Source File: RecordConverter.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Heuristically decides whether the labels of {@code dataSet} are one-hot class labels.
 *
 * <p>Returns true only for a non-empty rank-2 label matrix that meets all of these:
 * <ul>
 *   <li>it has more than one column;</li>
 *   <li>no entry is negative;</li>
 *   <li>every row has a maximum of exactly 1 and a sum of exactly 1.</li>
 * </ul>
 * Together these conditions imply each row is one-hot.
 *
 * <p>The previous check compared only the total sum of all labels with the number of
 * examples. It therefore also accepted non-one-hot matrices, such as regression targets,
 * whose entries happened to add up to that count. It also failed on null or rank-1 labels
 * instead of returning false.
 *
 * @param dataSet the data set to inspect
 * @return true if the labels look like one-hot classification labels
 */
private static boolean isClassificationDataSet(DataSet dataSet) {
    INDArray labels = dataSet.getLabels();
    if (labels == null || labels.rank() != 2 || labels.size(0) == 0 || labels.size(1) <= 1) {
        return false;
    }
    if (labels.minNumber().doubleValue() < 0.0) {
        return false;
    }
    INDArray rowSums = labels.sum(1);
    INDArray rowMaxes = labels.max(1);
    return rowSums.minNumber().doubleValue() == 1.0
            && rowSums.maxNumber().doubleValue() == 1.0
            && rowMaxes.minNumber().doubleValue() == 1.0
            && rowMaxes.maxNumber().doubleValue() == 1.0;
}
Example 16
Source File: EvaluationToolsTests.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testRocHtml() { DataSetIterator iter = new IrisDataSetIterator(150, 150); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build()).layer(1, new OutputLayer.Builder().nIn(4).nOut(2).activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); NormalizerStandardize ns = new NormalizerStandardize(); DataSet ds = iter.next(); ns.fit(ds); ns.transform(ds); INDArray newLabels = Nd4j.create(150, 2); newLabels.getColumn(0).assign(ds.getLabels().getColumn(0)); newLabels.getColumn(0).addi(ds.getLabels().getColumn(1)); newLabels.getColumn(1).assign(ds.getLabels().getColumn(2)); ds.setLabels(newLabels); for (int i = 0; i < 30; i++) { net.fit(ds); } for (int numSteps : new int[] {20, 0}) { ROC roc = new ROC(numSteps); iter.reset(); INDArray f = ds.getFeatures(); INDArray l = ds.getLabels(); INDArray out = net.output(f); roc.eval(l, out); String str = EvaluationTools.rocChartToHtml(roc); // System.out.println(str); } }