Java Code Examples for org.nd4j.evaluation.classification.ROC#calculateAUC()

The following examples show how to use org.nd4j.evaluation.classification.ROC#calculateAUC(). You can vote up the examples you like or vote down the ones you don't, and go to the original project or source file by following the links above each example. You can also find related API usage in the sidebar.
Example 1
Source File: ROCScoreCalculator.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Extracts the final score from a completed ROC-family evaluation.
 *
 * <p>{@code type} selects which evaluation class {@code eval} is cast to, and {@code metric}
 * selects between area under the ROC curve (AUC) and area under the precision-recall curve.
 *
 * @param eval the finished evaluation; must be an instance of the class matching {@code type}
 *             ({@link ROC}, {@link ROCBinary} or {@link ROCMultiClass}), otherwise the cast fails
 * @return the AUC or AUPRC value; for BINARY and MULTICLASS the average over outputs/classes
 * @throws IllegalStateException if {@code type} is not a recognised value
 */
@Override
protected double finalScore(IEvaluation eval) {
    switch (type){
        case ROC:
            ROC r = (ROC)eval;
            return metric == Metric.AUC ? r.calculateAUC() : r.calculateAUCPR();
        case BINARY:
            ROCBinary r2 = (ROCBinary) eval;
            // The non-AUC metric must use the precision-recall variant; both arms previously
            // computed the ROC AUC, so the PR metric was never actually reported here.
            return metric == Metric.AUC ? r2.calculateAverageAuc() : r2.calculateAverageAUCPR();
        case MULTICLASS:
            ROCMultiClass r3 = (ROCMultiClass)eval;
            return metric == Metric.AUC ? r3.calculateAverageAUC() : r3.calculateAverageAUCPR();
        default:
            throw new IllegalStateException("Unknown type: " + type);
    }
}
 
Example 2
Source File: ROCScoreFunction.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Evaluates {@code net} on {@code iterator} and returns the configured ROC-based score.
 *
 * <p>{@code type} picks the evaluation flavour (single-output ROC, per-output ROCBinary, or
 * one-vs-all ROCMultiClass), and {@code metric} picks AUC versus AUPRC.
 *
 * @param net      the network to evaluate
 * @param iterator the data to evaluate on
 * @return the AUC or AUPRC value; for BINARY and MULTICLASS the average over outputs/classes
 * @throws IllegalStateException if {@code type} is not a recognised value
 */
@Override
public double score(MultiLayerNetwork net, DataSetIterator iterator) {
    switch (type){
        case ROC:
            ROC r = net.evaluateROC(iterator);
            return metric == Metric.AUC ? r.calculateAUC() : r.calculateAUCPR();
        case BINARY:
            // doEvaluation returns the evaluations it was given; only one was passed in
            ROCBinary r2 = net.doEvaluation(iterator, new ROCBinary())[0];
            return metric == Metric.AUC ? r2.calculateAverageAuc() : r2.calculateAverageAUCPR();
        case MULTICLASS:
            ROCMultiClass r3 = net.evaluateROCMultiClass(iterator);
            return metric == Metric.AUC ? r3.calculateAverageAUC() : r3.calculateAverageAUCPR();
        default:
            // IllegalStateException (a RuntimeException subclass) describes a bad configured
            // type more precisely than a bare RuntimeException, and matches finalScore.
            throw new IllegalStateException("Unknown type: " + type);
    }
}
 
Example 3
Source File: ROCScoreFunction.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Evaluates the computation graph {@code net} on {@code iterator} and returns the configured
 * ROC-based score.
 *
 * <p>{@code type} picks the evaluation flavour (single-output ROC, per-output ROCBinary, or
 * one-vs-all ROCMultiClass), and {@code metric} picks AUC versus AUPRC.
 *
 * @param net      the computation graph to evaluate
 * @param iterator the multi-dataset data to evaluate on
 * @return the AUC or AUPRC value; for BINARY and MULTICLASS the average over outputs/classes
 * @throws IllegalStateException if {@code type} is not a recognised value
 */
@Override
public double score(ComputationGraph net, MultiDataSetIterator iterator) {
    switch (type){
        case ROC:
            ROC r = net.evaluateROC(iterator);
            return metric == Metric.AUC ? r.calculateAUC() : r.calculateAUCPR();
        case BINARY:
            // doEvaluation returns the evaluations it was given; only one was passed in
            ROCBinary r2 = net.doEvaluation(iterator, new ROCBinary())[0];
            return metric == Metric.AUC ? r2.calculateAverageAuc() : r2.calculateAverageAUCPR();
        case MULTICLASS:
            // NOTE(review): second argument presumably selects output index 0 of the graph — confirm
            ROCMultiClass r3 = net.evaluateROCMultiClass(iterator, 0);
            return metric == Metric.AUC ? r3.calculateAverageAUC() : r3.calculateAverageAUCPR();
        default:
            // IllegalStateException (a RuntimeException subclass) describes a bad configured
            // type more precisely than a bare RuntimeException, and matches finalScore.
            throw new IllegalStateException("Unknown type: " + type);
    }
}
 
Example 4
Source File: ROCTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testRocBasicSingleClass() {
    // Single output column: one sigmoid-style probability per example.
    // Every value is offset by +0.001 so float/double rounding never lands exactly on a threshold.
    double[] probabilities = {0.001, 0.101, 0.201, 0.301, 0.401, 0.501, 0.601, 0.701, 0.801, 0.901};
    double[] labelValues = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1};

    INDArray predictions = Nd4j.create(probabilities, new int[] {10, 1});
    INDArray actual = Nd4j.create(labelValues, new int[] {10, 1});

    ROC roc = new ROC(10);
    roc.eval(actual, predictions);

    RocCurve curve = roc.getRocCurve();
    final int nThresholds = 11; // threshold 0 plus 10 steps
    assertEquals(nThresholds, curve.getThreshold().length);

    for (int step = 0; step < nThresholds; step++) {
        double threshold = step / 10.0;
        assertEquals(threshold, curve.getThreshold(step), 1e-5);

        // expFPR / expTPR are expected-rate lookups keyed by threshold, defined outside this method
        double expectedFpr = expFPR.get(threshold);
        assertEquals(expectedFpr, curve.getFalsePositiveRate(step), 1e-5);

        double expectedTpr = expTPR.get(threshold);
        assertEquals(expectedTpr, curve.getTruePositiveRate(step), 1e-5);
    }

    // Labels are perfectly separated by the predicted probability, so AUC is 1.0
    assertEquals(1.0, roc.calculateAUC(), 1e-6);
}
 
Example 5
Source File: ROCTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testCompareRocAndRocMultiClass() {
    Nd4j.getRandom().setSeed(12345);

    // With exactly two classes, ROC's AUC must equal ROCMultiClass's AUC for class index 1
    int nExamples = 200;
    INDArray predictions = Nd4j.rand(nExamples, 2);
    predictions.diviColumnVector(predictions.sum(1)); // normalise each row to sum to 1

    // Random one-hot labels
    INDArray labels = Nd4j.create(nExamples, 2);
    Random rng = new Random(12345);
    for (int row = 0; row < nExamples; row++) {
        int cls = rng.nextInt(2);
        labels.putScalar(row, cls, 1.0);
    }

    int[] stepOptions = {30, 0}; // 0 steps = exact mode
    for (int numSteps : stepOptions) {
        ROC binaryRoc = new ROC(numSteps);
        binaryRoc.eval(labels, predictions);

        ROCMultiClass multiRoc = new ROCMultiClass(numSteps);
        multiRoc.eval(labels, predictions);

        double aucBinary = binaryRoc.calculateAUC();
        double aucClass1 = multiRoc.calculateAUC(1);
        assertEquals(aucBinary, aucClass1, 1e-6);
    }
}
 
Example 6
Source File: ROCSerializer.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Writes the fields of a {@link ROC} evaluation to JSON.
 *
 * <p>AUC/AUPRC are written only when at least one example was evaluated; the ROC and PR curves
 * are additionally written only in exact mode, since thresholded curves can be recomputed from
 * the stored counts.
 */
@Override
public void serialize(ROC roc, JsonGenerator jsonGenerator, SerializerProvider serializerProvider)
                throws IOException {
    final boolean hasData = roc.getExampleCount() != 0;
    final boolean exact = roc.isExact();

    if (exact && hasData) {
        // Exact mode does not serialize the original predictions (potentially huge), so force the
        // AUC and AUPRC to be computed now; the stored values are all a deserialized copy will have.
        roc.calculateAUC();
        roc.calculateAUCPR();
    }

    jsonGenerator.writeNumberField("thresholdSteps", roc.getThresholdSteps());
    jsonGenerator.writeNumberField("countActualPositive", roc.getCountActualPositive());
    jsonGenerator.writeNumberField("countActualNegative", roc.getCountActualNegative());
    jsonGenerator.writeObjectField("counts", roc.getCounts());

    if (hasData) {
        jsonGenerator.writeNumberField("auc", roc.calculateAUC());
        jsonGenerator.writeNumberField("auprc", roc.calculateAUCPR());
        if (exact) {
            // Curves are stored only for exact mode; thresholded curves are redundant with counts
            jsonGenerator.writeObjectField("rocCurve", roc.getRocCurve());
            jsonGenerator.writeObjectField("prCurve", roc.getPrecisionRecallCurve());
        }
    }

    jsonGenerator.writeBooleanField("isExact", exact);
    jsonGenerator.writeNumberField("exampleCount", roc.getExampleCount());
    jsonGenerator.writeBooleanField("rocRemoveRedundantPts", roc.isRocRemoveRedundantPts());
}
 
Example 7
Source File: ROCBinaryTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test
public void testROCBinary() {
    // Compare ROCBinary (per-column evaluation) against a standalone ROC evaluated on each column,
    // across global default dtypes, label/prediction dtypes, and thresholded vs exact modes.
    // The global default dtypes are restored in the finally block.

    DataType dtypeBefore = Nd4j.defaultFloatingPointType();
    // First ROCBinary (and its stats string) seen for each threshold mode; later runs must match
    ROCBinary first30 = null;
    ROCBinary first0 = null;
    String sFirst30 = null;
    String sFirst0 = null;
    try {
        for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) {
            // Non-floating global dtypes (INT) fall back to DOUBLE as the default floating-point type
            Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? globalDtype : DataType.DOUBLE);
            for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
                String msg = "globalDtype=" + globalDtype + ", labelPredictionsDtype=" + lpDtype;

                int nExamples = 50;
                int nOut = 4;
                long[] shape = {nExamples, nOut};

                for (int thresholdSteps : new int[]{30, 0}) { //0 == exact

                    // Same seed before each draw so labels/predictions are identical across dtypes
                    Nd4j.getRandom().setSeed(12345);
                    INDArray labels =
                            Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(DataType.DOUBLE, shape), 0.5)).castTo(lpDtype);

                    Nd4j.getRandom().setSeed(12345);
                    INDArray predicted = Nd4j.rand(DataType.DOUBLE, shape).castTo(lpDtype);

                    ROCBinary rb = new ROCBinary(thresholdSteps);

                    // Two passes, each on a fresh ROCBinary instance
                    for (int xe = 0; xe < 2; xe++) {
                        rb.eval(labels, predicted);

                        // FP16 needs a looser tolerance
                        double eps = lpDtype == DataType.HALF ? 1e-2 : 1e-6;
                        for (int i = 0; i < nOut; i++) {
                            INDArray lCol = labels.getColumn(i, true);
                            INDArray pCol = predicted.getColumn(i, true);

                            ROC r = new ROC(thresholdSteps);
                            r.eval(lCol, pCol);

                            double aucExp = r.calculateAUC();
                            double auc = rb.calculateAUC(i);
                            assertEquals(msg, aucExp, auc, eps);

                            // Expected value first, so failure messages report the values correctly
                            long apExp = r.getCountActualPositive();
                            long ap = rb.getCountActualPositive(i);
                            assertEquals(msg, apExp, ap);

                            long anExp = r.getCountActualNegative();
                            long an = rb.getCountActualNegative(i);
                            assertEquals(msg, anExp, an);

                            PrecisionRecallCurve pExp = r.getPrecisionRecallCurve();
                            PrecisionRecallCurve p = rb.getPrecisionRecallCurve(i);
                            assertEquals(msg, pExp, p);
                        }

                        String s = rb.stats();

                        if(thresholdSteps == 0){
                            if(first0 == null) {
                                first0 = rb;
                                sFirst0 = s;
                            } else if(lpDtype != DataType.HALF) {   //Precision issues with FP16
                                assertEquals(msg, sFirst0, s);
                                assertEquals(msg, first0, rb);
                            }
                        } else {
                            if(first30 == null) {
                                first30 = rb;
                                sFirst30 = s;
                            } else if(lpDtype != DataType.HALF) {   //Precision issues with FP16
                                assertEquals(msg, sFirst30, s);
                                assertEquals(msg, first30, rb);
                            }
                        }

                        rb = new ROCBinary(thresholdSteps);
                    }
                }
            }
        }
    } finally {
        Nd4j.setDefaultDataTypes(dtypeBefore, dtypeBefore);
    }
}
 
Example 8
Source File: ROCTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test
public void testRocBasic() {
    // Two output columns per row: a probability distribution over the two classes (softmax-style).
    // Values are offset by 0.001 to avoid float/double rounding issues at threshold boundaries.
    double[][] probs = {
                    {1.0, 0.001}, {0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401},
                    {0.499, 0.501}, {0.399, 0.601}, {0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}};
    double[][] oneHot = {
                    {1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0},
                    {0, 1}, {0, 1}, {0, 1}, {0, 1}, {0, 1}};

    INDArray predictions = Nd4j.create(probs);
    INDArray actual = Nd4j.create(oneHot);

    ROC roc = new ROC(10);
    roc.eval(actual, predictions);

    RocCurve curve = roc.getRocCurve();
    final int nThresholds = 11; // threshold 0 plus 10 steps
    assertEquals(nThresholds, curve.getThreshold().length);

    for (int step = 0; step < nThresholds; step++) {
        double threshold = step / 10.0;
        assertEquals(threshold, curve.getThreshold(step), 1e-5);

        // expFPR / expTPR are expected-rate lookups keyed by threshold, defined outside this method
        double expectedFpr = expFPR.get(threshold);
        assertEquals(expectedFpr, curve.getFalsePositiveRate(step), 1e-5);

        double expectedTpr = expTPR.get(threshold);
        assertEquals(expectedTpr, curve.getTruePositiveRate(step), 1e-5);
    }

    // Classes are perfectly separated, so AUC is 1.0 — and must still be after reset + re-eval
    double auc = roc.calculateAUC();
    assertEquals(1.0, auc, 1e-6);

    roc.reset();
    roc.eval(actual, predictions);
    auc = roc.calculateAUC();
    assertEquals(1.0, auc, 1e-6);
}
 
Example 9
Source File: ROCTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test
public void testRoc() {
    // Unlike the basic tests, no single threshold separates the classes here: one negative
    // (0.601) scores above one positive (0.501), so the ROC curve is imperfect.

    INDArray labels = Nd4j.create(new double[][] {{0, 1}, {0, 1}, {1, 0}, {1, 0}, {1, 0}});

    INDArray prediction = Nd4j.create(new double[][] {{0.199, 0.801}, {0.499, 0.501}, {0.399, 0.601},
                    {0.799, 0.201}, {0.899, 0.101}});

    Map<Double, Double> expTPR = new HashMap<>();
    double totalPositives = 2.0;
    expTPR.put(0.0, 2.0 / totalPositives); //All predicted class 1 -> 2 true positives / 2 total positives
    expTPR.put(0.1, 2.0 / totalPositives);
    expTPR.put(0.2, 2.0 / totalPositives);
    expTPR.put(0.3, 2.0 / totalPositives);
    expTPR.put(0.4, 2.0 / totalPositives);
    expTPR.put(0.5, 2.0 / totalPositives);
    expTPR.put(0.6, 1.0 / totalPositives); //At threshold of 0.6, only 1 of 2 positives are predicted positive
    expTPR.put(0.7, 1.0 / totalPositives);
    expTPR.put(0.8, 1.0 / totalPositives);
    expTPR.put(0.9, 0.0 / totalPositives); //At threshold of 0.9, 0 of 2 positives are predicted positive
    expTPR.put(1.0, 0.0 / totalPositives);

    Map<Double, Double> expFPR = new HashMap<>();
    double totalNegatives = 3.0;
    expFPR.put(0.0, 3.0 / totalNegatives); //All predicted class 1 -> 3 false positives / 3 total negatives
    expFPR.put(0.1, 3.0 / totalNegatives);
    expFPR.put(0.2, 2.0 / totalNegatives); //At threshold of 0.2: 1 true negative, 2 false positives
    expFPR.put(0.3, 1.0 / totalNegatives); //At threshold of 0.3: 2 true negative, 1 false positive
    expFPR.put(0.4, 1.0 / totalNegatives);
    expFPR.put(0.5, 1.0 / totalNegatives);
    expFPR.put(0.6, 1.0 / totalNegatives);
    expFPR.put(0.7, 0.0 / totalNegatives); //At threshold of 0.7: 3 true negatives, 0 false positives
    expFPR.put(0.8, 0.0 / totalNegatives);
    expFPR.put(0.9, 0.0 / totalNegatives);
    expFPR.put(1.0, 0.0 / totalNegatives);

    // Expected confusion counts at thresholds 0.0, 0.1, ..., 1.0 (5 examples in total)
    int[] expTPs = new int[] {2, 2, 2, 2, 2, 2, 1, 1, 1, 0, 0};
    int[] expFPs = new int[] {3, 3, 2, 1, 1, 1, 1, 0, 0, 0, 0};
    int[] expFNs = new int[11];
    int[] expTNs = new int[11];
    for (int i = 0; i < 11; i++) {
        expFNs[i] = (int) totalPositives - expTPs[i];
        expTNs[i] = 5 - expTPs[i] - expFPs[i] - expFNs[i];
    }

    ROC roc = new ROC(10);
    roc.eval(labels, prediction);

    RocCurve rocCurve = roc.getRocCurve();

    assertEquals(11, rocCurve.getThreshold().length);
    assertEquals(11, rocCurve.getFpr().length);
    assertEquals(11, rocCurve.getTpr().length);

    for (int i = 0; i < 11; i++) {
        double expThreshold = i / 10.0;
        assertEquals(expThreshold, rocCurve.getThreshold(i), 1e-5);

        double efpr = expFPR.get(expThreshold);
        double afpr = rocCurve.getFalsePositiveRate(i);
        assertEquals(efpr, afpr, 1e-5);

        double etpr = expTPR.get(expThreshold);
        double atpr = rocCurve.getTruePositiveRate(i);
        assertEquals(etpr, atpr, 1e-5);
    }

    //AUC: expected values are based on plotting the ROC curve and manually calculating the area
    double expAUC = 0.5 * 1.0 / 3.0 + (1 - 1 / 3.0) * 1.0;
    double actAUC = roc.calculateAUC();

    assertEquals(expAUC, actAUC, 1e-6);

    PrecisionRecallCurve prc = roc.getPrecisionRecallCurve();
    for (int i = 0; i < 11; i++) {
        PrecisionRecallCurve.Confusion c = prc.getConfusionMatrixAtThreshold(i * 0.1);
        assertEquals(expTPs[i], c.getTpCount());
        assertEquals(expFPs[i], c.getFpCount());
        // Check false negatives; the FP count was previously asserted twice here,
        // so expFNs was computed but never verified.
        assertEquals(expFNs[i], c.getFnCount());
        assertEquals(expTNs[i], c.getTnCount());
    }
}
 
Example 10
Source File: ROCTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test
public void testROCMerging() {
    final int nArrays = 10;
    final int minibatch = 64;
    final int nROCs = 3;

    // 0 steps: exact mode, 20 steps: thresholded mode
    for (int steps : new int[] {0, 20}) {
        Nd4j.getRandom().setSeed(12345);
        Random r = new Random(12345);

        // Shards: each minibatch is evaluated by exactly one of these, round-robin
        List<ROC> shards = new ArrayList<>();
        while (shards.size() < nROCs) {
            shards.add(new ROC(steps));
        }

        // Reference: sees every minibatch
        ROC single = new ROC(steps);
        for (int batch = 0; batch < nArrays; batch++) {
            INDArray p = Nd4j.rand(minibatch, 2);
            p.diviColumnVector(p.sum(1)); // rows sum to 1

            INDArray l = Nd4j.zeros(minibatch, 2);
            for (int row = 0; row < minibatch; row++) {
                l.putScalar(row, r.nextInt(2), 1.0);
            }

            single.eval(l, p);
            shards.get(batch % shards.size()).eval(l, p);
        }

        // Fold all remaining shards into the first; the result must equal the reference
        ROC merged = shards.get(0);
        for (ROC shard : shards.subList(1, shards.size())) {
            merged.merge(shard);
        }

        double singleAUC = single.calculateAUC();
        assertTrue(singleAUC >= 0.0 && singleAUC <= 1.0);
        assertEquals(singleAUC, merged.calculateAUC(), 1e-6);

        assertEquals(single.getRocCurve(), merged.getRocCurve());
    }
}
 
Example 11
Source File: ROCTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test
public void testROCMerging2() {
    final int nArrays = 10;
    final int minibatch = 64;
    final int exactAllocBlockSize = 10;
    final int nROCs = 3;
    final int steps = 0;  // exact mode

    Nd4j.getRandom().setSeed(12345);
    Random r = new Random(12345);

    // Shards use a small exact-mode allocation block size, so they must grow their storage
    // while evaluating; the reference ROC uses the default constructor settings.
    List<ROC> shards = new ArrayList<>();
    while (shards.size() < nROCs) {
        shards.add(new ROC(steps, true, exactAllocBlockSize));
    }

    ROC single = new ROC(steps);
    for (int batch = 0; batch < nArrays; batch++) {
        INDArray p = Nd4j.rand(minibatch, 2);
        p.diviColumnVector(p.sum(1)); // rows sum to 1

        INDArray l = Nd4j.zeros(minibatch, 2);
        for (int row = 0; row < minibatch; row++) {
            l.putScalar(row, r.nextInt(2), 1.0);
        }

        single.eval(l, p);
        shards.get(batch % shards.size()).eval(l, p);
    }

    // Fold all remaining shards into the first; the result must equal the reference
    ROC merged = shards.get(0);
    for (ROC shard : shards.subList(1, shards.size())) {
        merged.merge(shard);
    }

    double singleAUC = single.calculateAUC();
    assertTrue(singleAUC >= 0.0 && singleAUC <= 1.0);
    assertEquals(singleAUC, merged.calculateAUC(), 1e-6);

    assertEquals(single.getRocCurve(), merged.getRocCurve());
}
 
Example 12
Source File: ROCTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test
public void testRocMerge(){
    Nd4j.getRandom().setSeed(12345);

    ROC roc = new ROC();   // evaluates every minibatch
    ROC roc1 = new ROC();  // even-indexed minibatches only
    ROC roc2 = new ROC();  // odd-indexed minibatches only

    final int nOut = 2;

    Random r = new Random(12345);
    for (int batch = 0; batch < 10; batch++) {
        INDArray labels = Nd4j.zeros(3, nOut);
        for (int row = 0; row < 3; row++) {
            labels.putScalar(row, r.nextInt(nOut), 1.0);
        }
        INDArray out = Nd4j.rand(3, nOut);
        out.diviColumnVector(out.sum(1)); // rows sum to 1

        roc.eval(labels, out);
        ROC half = (batch % 2 == 0) ? roc1 : roc2;
        half.eval(labels, out);
    }

    // Compute metrics on both halves BEFORE merging; the merged values must still match the full
    // evaluation. NOTE(review): presumably guards against stale cached AUC/AUPRC values — confirm.
    roc1.calculateAUC();
    roc1.calculateAUCPR();
    roc2.calculateAUC();
    roc2.calculateAUCPR();

    roc1.merge(roc2);

    double aucExp = roc.calculateAUC();
    double auprcExp = roc.calculateAUCPR();

    double aucAct = roc1.calculateAUC();
    double auprcAct = roc1.calculateAUCPR();

    assertEquals(aucExp, aucAct, 1e-6);
    assertEquals(auprcExp, auprcAct, 1e-6);
}