Java Code Examples for org.nd4j.linalg.dataset.api.DataSet#getLabels()

The following examples show how to use org.nd4j.linalg.dataset.api.DataSet#getLabels(). Each example comes from an open-source project; the source file and project are listed above the code.
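Before the project-specific examples, here is a minimal self-contained sketch (not taken from any of the projects below; class and variable names are illustrative) that builds an in-memory DataSet and reads its labels back with getLabels():

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

public class GetLabelsSketch {
    public static void main(String[] args) {
        // Two examples with three features each, and one-hot labels over two classes.
        INDArray features = Nd4j.create(new double[][]{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}});
        INDArray labels = Nd4j.create(new double[][]{{1.0, 0.0}, {0.0, 1.0}});

        DataSet dataSet = new DataSet(features, labels);

        // getLabels() returns the label array stored in the DataSet (null if no labels were set).
        INDArray l = dataSet.getLabels();
        System.out.println("Labels shape: " + java.util.Arrays.toString(l.shape()));
    }
}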
Example 1
Source File: StandardDQNTest.java    From deeplearning4j with Apache License 2.0
@Test
public void when_isTerminal_expect_rewardValueAtIdx0() {

    // Assemble
    List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
        {
            add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
                    0, 1.0, true, buildObservation(new double[]{11.0, 22.0})));
        }
    };

    StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);

    // Act
    DataSet result = sut.compute(transitions);

    // Assert
    INDArray evaluatedQValues = result.getLabels();
    assertEquals(1.0, evaluatedQValues.getDouble(0, 0), 0.0001); // terminal transition: label for the taken action is the reward only
    assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);
}
 
Example 2
Source File: StandardDQNTest.java    From deeplearning4j with Apache License 2.0
@Test
public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {

    // Assemble
    List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
        {
            add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
                    0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
        }
    };

    StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);

    // Act
    DataSet result = sut.compute(transitions);

    // Assert
    INDArray evaluatedQValues = result.getLabels();
    assertEquals(1.0 + 0.5 * 22.0, evaluatedQValues.getDouble(0, 0), 0.0001); // non-terminal: reward + gamma (0.5) * max target Q
    assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);
}
 
Example 3
Source File: DoubleDQNTest.java    From deeplearning4j with Apache License 2.0
@Test
public void when_isTerminal_expect_rewardValueAtIdx0() {

    // Assemble
    when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);

    List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
        {
            add(builtTransition(buildObservation(new double[]{1.1, 2.2}),
                    0, 1.0, true, buildObservation(new double[]{11.0, 22.0})));
        }
    };

    DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);

    // Act
    DataSet result = sut.compute(transitions);

    // Assert
    INDArray evaluatedQValues = result.getLabels();
    assertEquals(1.0, evaluatedQValues.getDouble(0, 0), 0.0001); // terminal transition: label for the taken action is the reward only
    assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);
}
 
Example 4
Source File: DoubleDQNTest.java    From deeplearning4j with Apache License 2.0
@Test
public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {

    // Assemble
    when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0));

    List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
        {
            add(builtTransition(buildObservation(new double[]{1.1, 2.2}),
                    0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
        }
    };

    DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);

    // Act
    DataSet result = sut.compute(transitions);

    // Assert
    INDArray evaluatedQValues = result.getLabels();
    assertEquals(1.0 + 0.5 * -22.0, evaluatedQValues.getDouble(0, 0), 0.0001); // non-terminal (Double DQN): reward + gamma (0.5) * target Q of the action chosen by the online network
    assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);
}
 
Example 5
Source File: DefaultCallback.java    From deeplearning4j with Apache License 2.0
@Override
public void call(DataSet dataSet) {
    if (dataSet != null) {
        if (dataSet.getFeatures() != null)
            Nd4j.getAffinityManager().ensureLocation(dataSet.getFeatures(), AffinityManager.Location.DEVICE);

        if (dataSet.getLabels() != null)
            Nd4j.getAffinityManager().ensureLocation(dataSet.getLabels(), AffinityManager.Location.DEVICE);

        if (dataSet.getFeaturesMaskArray() != null)
            Nd4j.getAffinityManager().ensureLocation(dataSet.getFeaturesMaskArray(),
                            AffinityManager.Location.DEVICE);

        if (dataSet.getLabelsMaskArray() != null)
            Nd4j.getAffinityManager().ensureLocation(dataSet.getLabelsMaskArray(), AffinityManager.Location.DEVICE);
    }
}
 
Example 6
Source File: ComputationGraphUtil.java    From deeplearning4j with Apache License 2.0
/** Convert a DataSet to the equivalent MultiDataSet */
public static MultiDataSet toMultiDataSet(DataSet dataSet) {
    INDArray f = dataSet.getFeatures();
    INDArray l = dataSet.getLabels();
    INDArray fMask = dataSet.getFeaturesMaskArray();
    INDArray lMask = dataSet.getLabelsMaskArray();
    List<Serializable> meta = dataSet.getExampleMetaData();

    INDArray[] fNew = f == null ? null : new INDArray[] {f};
    INDArray[] lNew = l == null ? null : new INDArray[] {l};
    INDArray[] fMaskNew = (fMask != null ? new INDArray[] {fMask} : null);
    INDArray[] lMaskNew = (lMask != null ? new INDArray[] {lMask} : null);

    org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(fNew, lNew, fMaskNew, lMaskNew);
    mds.setExampleMetaData(meta);
    return mds;
}
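A rough usage sketch for the method above (assuming ComputationGraphUtil from deeplearning4j-nn is on the classpath; the arrays here are illustrative):

// Single-input/single-output DataSet: 10 examples, 4 features, 3 label columns.
DataSet ds = new DataSet(Nd4j.rand(10, 4), Nd4j.zeros(10, 3));

// Wrap it as a MultiDataSet; the single feature/label arrays become one-element arrays.
MultiDataSet mds = ComputationGraphUtil.toMultiDataSet(ds);
INDArray firstLabelArray = mds.getLabels(0);   // the same array ds.getLabels() returned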
 
Example 7
Source File: UnderSamplingByMaskingPreProcessor.java    From nd4j with Apache License 2.0
@Override
public void preProcess(DataSet toPreProcess) {
    INDArray label = toPreProcess.getLabels();
    INDArray labelMask = toPreProcess.getLabelsMaskArray();
    INDArray sampledMask = adjustMasks(label, labelMask, minorityLabel, targetMinorityDist);
    toPreProcess.setLabelsMaskArray(sampledMask);
}
 
Example 8
Source File: StandardDQNTest.java    From deeplearning4j with Apache License 2.0
@Test
public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {

    // Assemble
    List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
        {
            add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
                    0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
            add(buildTransition(buildObservation(new double[]{3.3, 4.4}),
                    1, 2.0, false, buildObservation(new double[]{33.0, 44.0})));
            add(buildTransition(buildObservation(new double[]{5.5, 6.6}),
                    0, 3.0, true, buildObservation(new double[]{55.0, 66.0})));
        }
    };

    StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);

    // Act
    DataSet result = sut.compute(transitions);

    // Assert
    INDArray evaluatedQValues = result.getLabels();
    assertEquals((1.0 + 0.5 * 22.0), evaluatedQValues.getDouble(0, 0), 0.0001);
    assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);

    assertEquals(3.3, evaluatedQValues.getDouble(1, 0), 0.0001);
    assertEquals((2.0 + 0.5 * 44.0), evaluatedQValues.getDouble(1, 1), 0.0001);

    assertEquals(3.0, evaluatedQValues.getDouble(2, 0), 0.0001); // terminal: reward only
    assertEquals(6.6, evaluatedQValues.getDouble(2, 1), 0.0001);

}
 
Example 9
Source File: DoubleDQNTest.java    From deeplearning4j with Apache License 2.0
@Test
public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {

    // Assemble
    when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0));

    List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
        {
            add(builtTransition(buildObservation(new double[]{1.1, 2.2}),
                    0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
            add(builtTransition(buildObservation(new double[]{3.3, 4.4}),
                    1, 2.0, false, buildObservation(new double[]{33.0, 44.0})));
            add(builtTransition(buildObservation(new double[]{5.5, 6.6}),
                    0, 3.0, true, buildObservation(new double[]{55.0, 66.0})));
        }
    };

    DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);

    // Act
    DataSet result = sut.compute(transitions);

    // Assert
    INDArray evaluatedQValues = result.getLabels();
    assertEquals(1.0 + 0.5 * -22.0, evaluatedQValues.getDouble(0, 0), 0.0001);
    assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);

    assertEquals(3.3, evaluatedQValues.getDouble(1, 0), 0.0001);
    assertEquals(2.0 + 0.5 * -44.0, evaluatedQValues.getDouble(1, 1), 0.0001);

    assertEquals(3.0, evaluatedQValues.getDouble(2, 0), 0.0001); // terminal: reward only
    assertEquals(6.6, evaluatedQValues.getDouble(2, 1), 0.0001);

}
 
Example 10
Source File: UnderSamplingByMaskingPreProcessor.java    From deeplearning4j with Apache License 2.0
@Override
public void preProcess(DataSet toPreProcess) {
    INDArray label = toPreProcess.getLabels();
    INDArray labelMask = toPreProcess.getLabelsMaskArray();
    INDArray sampledMask = adjustMasks(label, labelMask, minorityLabel, targetMinorityDist);
    toPreProcess.setLabelsMaskArray(sampledMask);
}
 
Example 11
Source File: LabelLastTimeStepPreProcessor.java    From deeplearning4j with Apache License 2.0
@Override
public void preProcess(DataSet toPreProcess) {

    INDArray label3d = toPreProcess.getLabels();
    Preconditions.checkState(label3d.rank() == 3, "LabelLastTimeStepPreProcessor expects rank 3 labels, got rank %s labels with shape %ndShape", label3d.rank(), label3d);

    INDArray lMask = toPreProcess.getLabelsMaskArray();
    //If no mask: assume that examples for each minibatch are all same length
    INDArray labels2d;
    if(lMask == null){
        labels2d = label3d.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(label3d.size(2)-1)).dup();
    } else {
        //Use the label mask to work out the last time step...
        INDArray lastIndex = BooleanIndexing.lastIndex(lMask, Conditions.greaterThan(0), 1);
        long[] idxs = lastIndex.data().asLong();

        //Now, extract out:
        labels2d = Nd4j.create(DataType.FLOAT, label3d.size(0), label3d.size(1));

        //Now, get and assign the corresponding subsets of 3d activations:
        for (int i = 0; i < idxs.length; i++) {
            long lastStepIdx = idxs[i];
            Preconditions.checkState(lastStepIdx >= 0, "Invalid last time step index: example %s in minibatch is entirely masked out" +
                    " (label mask is all 0s, meaning no label data is present for this example)", i);
            //TODO can optimize using reshape + pullRows
            labels2d.putRow(i, label3d.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(lastStepIdx)));
        }
    }

    toPreProcess.setLabels(labels2d);
    toPreProcess.setLabelsMaskArray(null);  //Remove label mask if present
}
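A brief usage sketch for this preprocessor (assuming it has a no-argument constructor, as the deeplearning4j source suggests; the arrays here are illustrative):

// Rank-3 time-series data: 2 examples, 4 features / 3 label classes, 5 time steps.
DataSet ds = new DataSet(Nd4j.rand(new int[]{2, 4, 5}), Nd4j.rand(new int[]{2, 3, 5}));

// Replace the rank-3 labels with the values at the last time step of each example.
new LabelLastTimeStepPreProcessor().preProcess(ds);
System.out.println(java.util.Arrays.toString(ds.getLabels().shape()));  // [2, 3]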
 
Example 12
Source File: ImagePreProcessingScaler.java    From deeplearning4j with Apache License 2.0
@Override
public void preProcess(DataSet toPreProcess) {
    INDArray features = toPreProcess.getFeatures();
    preProcess(features);
    if(fitLabels && toPreProcess.getLabels() != null){
        preProcess(toPreProcess.getLabels());
    }
}
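A short usage sketch (ImagePreProcessingScaler lives in org.nd4j.linalg.dataset.api.preprocessor; the array sizes here are illustrative):

// Scale raw pixel values in [0, 255] down to [0, 1] on a DataSet's features.
ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(0, 1);
DataSet imageBatch = new DataSet(Nd4j.rand(8, 3 * 32 * 32).muli(255), Nd4j.zeros(8, 10));
scaler.transform(imageBatch);   // features end up in [0, 1]; labels are left alone unless fitLabels is enabled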
 
Example 13
Source File: EvaluationToolsTests.java    From deeplearning4j with Apache License 2.0
@Test
public void testRocMultiToHtml() throws Exception {
    DataSetIterator iter = new IrisDataSetIterator(150, 150);

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build()).layer(1,
                                    new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
                                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    NormalizerStandardize ns = new NormalizerStandardize();
    DataSet ds = iter.next();
    ns.fit(ds);
    ns.transform(ds);

    for (int i = 0; i < 30; i++) {
        net.fit(ds);
    }

    for (int numSteps : new int[] {20, 0}) {
        ROCMultiClass roc = new ROCMultiClass(numSteps);
        iter.reset();

        INDArray f = ds.getFeatures();
        INDArray l = ds.getLabels();
        INDArray out = net.output(f);
        roc.eval(l, out);

        String str = EvaluationTools.rocChartToHtml(roc, Arrays.asList("setosa", "versicolor", "virginica"));
        //            System.out.println(str);
    }
}
 
Example 14
Source File: RecordConverter.java    From DataVec with Apache License 2.0
private static boolean isClassificationDataSet(DataSet dataSet) {
    INDArray labels = dataSet.getLabels();

    return labels.sum(0, 1).getInt(0) == dataSet.numExamples() && labels.shape()[1] > 1;
}
 
Example 15
Source File: RecordConverter.java    From deeplearning4j with Apache License 2.0
private static boolean isClassificationDataSet(DataSet dataSet) {
    INDArray labels = dataSet.getLabels();

    return labels.sum(0, 1).getInt(0) == dataSet.numExamples() && labels.shape()[1] > 1;
}
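The check above relies on one-hot classification labels: every row sums to 1, so summing over all rows and columns gives the number of examples. A small illustrative sketch (not from either project):

// Three examples, one-hot over two classes: the sum of all entries is 3 == numExamples(), and there is more than one label column.
INDArray oneHot = Nd4j.create(new double[][]{{1, 0}, {0, 1}, {1, 0}});
DataSet ds = new DataSet(Nd4j.rand(3, 5), oneHot);
boolean isClassification = oneHot.sum(0, 1).getInt(0) == ds.numExamples() && oneHot.shape()[1] > 1;
System.out.println(isClassification);   // true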
 
Example 16
Source File: EvaluationToolsTests.java    From deeplearning4j with Apache License 2.0
@Test
public void testRocHtml() {

    DataSetIterator iter = new IrisDataSetIterator(150, 150);

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build()).layer(1,
                                    new OutputLayer.Builder().nIn(4).nOut(2).activation(Activation.SOFTMAX)
                                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    NormalizerStandardize ns = new NormalizerStandardize();
    DataSet ds = iter.next();
    ns.fit(ds);
    ns.transform(ds);

    INDArray newLabels = Nd4j.create(150, 2);
    newLabels.getColumn(0).assign(ds.getLabels().getColumn(0));
    newLabels.getColumn(0).addi(ds.getLabels().getColumn(1));
    newLabels.getColumn(1).assign(ds.getLabels().getColumn(2));
    ds.setLabels(newLabels);

    for (int i = 0; i < 30; i++) {
        net.fit(ds);
    }

    for (int numSteps : new int[] {20, 0}) {
        ROC roc = new ROC(numSteps);
        iter.reset();

        INDArray f = ds.getFeatures();
        INDArray l = ds.getLabels();
        INDArray out = net.output(f);
        roc.eval(l, out);


        String str = EvaluationTools.rocChartToHtml(roc);
        //            System.out.println(str);
    }
}