Java Code Examples for org.nd4j.evaluation.classification.ROC#calculateAUC()
The following examples show how to use
org.nd4j.evaluation.classification.ROC#calculateAUC().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: ROCScoreCalculator.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Extracts the final score from a completed ROC-family evaluation.
 *
 * <p>{@code type} selects which evaluation class {@code eval} is cast to; {@code metric} selects
 * between the area under the ROC curve ({@code Metric.AUC}) and the area under the
 * precision-recall curve (any other metric).
 *
 * @param eval the evaluation to score; assumed to be a ROC, ROCBinary or ROCMultiClass instance
 *             matching {@code type} (a mismatch results in a ClassCastException)
 * @return the requested AUC / AUPRC value (the "average" variants for ROCBinary and ROCMultiClass)
 * @throws IllegalStateException if {@code type} is not one of ROC, BINARY or MULTICLASS
 */
@Override
protected double finalScore(IEvaluation eval) {
    switch (type) {
        case ROC:
            ROC r = (ROC) eval;
            return metric == Metric.AUC ? r.calculateAUC() : r.calculateAUCPR();
        case BINARY:
            ROCBinary r2 = (ROCBinary) eval;
            // Both branches previously called calculateAverageAuc(), so an AUPRC request
            // silently returned the AUC for ROCBinary evaluations.
            return metric == Metric.AUC ? r2.calculateAverageAuc() : r2.calculateAverageAUCPR();
        case MULTICLASS:
            ROCMultiClass r3 = (ROCMultiClass) eval;
            return metric == Metric.AUC ? r3.calculateAverageAUC() : r3.calculateAverageAUCPR();
        default:
            throw new IllegalStateException("Unknown type: " + type);
    }
}
Example 2
Source File: ROCScoreFunction.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Scores a MultiLayerNetwork by running the configured ROC-family evaluation over
 * {@code iterator}.
 *
 * <p>If {@code metric} is {@code Metric.AUC}, the area under the ROC curve is returned;
 * otherwise the area under the precision-recall curve is returned.
 *
 * @param net      the network to evaluate
 * @param iterator the data to evaluate on
 * @return AUC or AUPRC ("average" variants for BINARY and MULTICLASS)
 * @throws RuntimeException if {@code type} is not recognized
 */
@Override
public double score(MultiLayerNetwork net, DataSetIterator iterator) {
    switch (type) {
        case ROC: {
            ROC single = net.evaluateROC(iterator);
            if (metric == Metric.AUC) {
                return single.calculateAUC();
            }
            return single.calculateAUCPR();
        }
        case BINARY: {
            ROCBinary perOutput = net.doEvaluation(iterator, new ROCBinary())[0];
            if (metric == Metric.AUC) {
                return perOutput.calculateAverageAuc();
            }
            return perOutput.calculateAverageAUCPR();
        }
        case MULTICLASS: {
            ROCMultiClass perClass = net.evaluateROCMultiClass(iterator);
            if (metric == Metric.AUC) {
                return perClass.calculateAverageAUC();
            }
            return perClass.calculateAverageAUCPR();
        }
        default:
            throw new RuntimeException("Unknown type: " + type);
    }
}
Example 3
Source File: ROCScoreFunction.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Scores a ComputationGraph by running the configured ROC-family evaluation over
 * {@code iterator}.
 *
 * <p>If {@code metric} is {@code Metric.AUC}, the area under the ROC curve is returned;
 * otherwise the area under the precision-recall curve is returned. The MULTICLASS case
 * evaluates output index 0 of the graph.
 *
 * @param net      the computation graph to evaluate
 * @param iterator the data to evaluate on
 * @return AUC or AUPRC ("average" variants for BINARY and MULTICLASS)
 * @throws RuntimeException if {@code type} is not recognized
 */
@Override
public double score(ComputationGraph net, MultiDataSetIterator iterator) {
    switch (type) {
        case ROC: {
            ROC single = net.evaluateROC(iterator);
            if (metric == Metric.AUC) {
                return single.calculateAUC();
            }
            return single.calculateAUCPR();
        }
        case BINARY: {
            ROCBinary perOutput = net.doEvaluation(iterator, new ROCBinary())[0];
            if (metric == Metric.AUC) {
                return perOutput.calculateAverageAuc();
            }
            return perOutput.calculateAverageAUCPR();
        }
        case MULTICLASS: {
            ROCMultiClass perClass = net.evaluateROCMultiClass(iterator, 0);
            if (metric == Metric.AUC) {
                return perClass.calculateAverageAUC();
            }
            return perClass.calculateAverageAUCPR();
        }
        default:
            throw new RuntimeException("Unknown type: " + type);
    }
}
Example 4
Source File: ROCTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testRocBasicSingleClass() { //1 output here - single probability value (sigmoid) //add 0.001 to avoid numerical/rounding issues (float vs. double, etc) INDArray predictions = Nd4j.create(new double[] {0.001, 0.101, 0.201, 0.301, 0.401, 0.501, 0.601, 0.701, 0.801, 0.901}, new int[] {10, 1}); INDArray actual = Nd4j.create(new double[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}, new int[] {10, 1}); ROC roc = new ROC(10); roc.eval(actual, predictions); RocCurve rocCurve = roc.getRocCurve(); assertEquals(11, rocCurve.getThreshold().length); //0 + 10 steps for (int i = 0; i < 11; i++) { double expThreshold = i / 10.0; assertEquals(expThreshold, rocCurve.getThreshold(i), 1e-5); // System.out.println("t=" + expThreshold + "\t" + v.getFalsePositiveRate() + "\t" + v.getTruePositiveRate()); double efpr = expFPR.get(expThreshold); double afpr = rocCurve.getFalsePositiveRate(i); assertEquals(efpr, afpr, 1e-5); double etpr = expTPR.get(expThreshold); double atpr = rocCurve.getTruePositiveRate(i); assertEquals(etpr, atpr, 1e-5); } //Expect AUC == 1.0 here double auc = roc.calculateAUC(); assertEquals(1.0, auc, 1e-6); }
Example 5
Source File: ROCTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testCompareRocAndRocMultiClass() { Nd4j.getRandom().setSeed(12345); //For 2 class case: ROC and Multi-class ROC should be the same... int nExamples = 200; INDArray predictions = Nd4j.rand(nExamples, 2); INDArray tempSum = predictions.sum(1); predictions.diviColumnVector(tempSum); INDArray labels = Nd4j.create(nExamples, 2); Random r = new Random(12345); for (int i = 0; i < nExamples; i++) { labels.putScalar(i, r.nextInt(2), 1.0); } for (int numSteps : new int[] {30, 0}) { //Steps = 0: exact ROC roc = new ROC(numSteps); roc.eval(labels, predictions); ROCMultiClass rocMultiClass = new ROCMultiClass(numSteps); rocMultiClass.eval(labels, predictions); double auc = roc.calculateAUC(); double auc1 = rocMultiClass.calculateAUC(1); assertEquals(auc, auc1, 1e-6); } }
Example 6
Source File: ROCSerializer.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public void serialize(ROC roc, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException { boolean empty = roc.getExampleCount() == 0; if (roc.isExact() && !empty) { //For exact ROC implementation: force AUC and AUPRC calculation, so result can be stored in JSON, such //that we have them once deserialized. //Due to potentially huge size, exact mode doesn't store the original predictions in JSON roc.calculateAUC(); roc.calculateAUCPR(); } jsonGenerator.writeNumberField("thresholdSteps", roc.getThresholdSteps()); jsonGenerator.writeNumberField("countActualPositive", roc.getCountActualPositive()); jsonGenerator.writeNumberField("countActualNegative", roc.getCountActualNegative()); jsonGenerator.writeObjectField("counts", roc.getCounts()); if(!empty) { jsonGenerator.writeNumberField("auc", roc.calculateAUC()); jsonGenerator.writeNumberField("auprc", roc.calculateAUCPR()); } if (roc.isExact() && !empty) { //Store ROC and PR curves only for exact mode... they are redundant + can be calculated again for thresholded mode jsonGenerator.writeObjectField("rocCurve", roc.getRocCurve()); jsonGenerator.writeObjectField("prCurve", roc.getPrecisionRecallCurve()); } jsonGenerator.writeBooleanField("isExact", roc.isExact()); jsonGenerator.writeNumberField("exampleCount", roc.getExampleCount()); jsonGenerator.writeBooleanField("rocRemoveRedundantPts", roc.isRocRemoveRedundantPts()); }
Example 7
Source File: ROCBinaryTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testROCBinary() { //Compare ROCBinary to ROC class DataType dtypeBefore = Nd4j.defaultFloatingPointType(); ROCBinary first30 = null; ROCBinary first0 = null; String sFirst30 = null; String sFirst0 = null; try { for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) { // for (DataType globalDtype : new DataType[]{DataType.HALF}) { Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? globalDtype : DataType.DOUBLE); for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { String msg = "globalDtype=" + globalDtype + ", labelPredictionsDtype=" + lpDtype; int nExamples = 50; int nOut = 4; long[] shape = {nExamples, nOut}; for (int thresholdSteps : new int[]{30, 0}) { //0 == exact Nd4j.getRandom().setSeed(12345); INDArray labels = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(DataType.DOUBLE, shape), 0.5)).castTo(lpDtype); Nd4j.getRandom().setSeed(12345); INDArray predicted = Nd4j.rand(DataType.DOUBLE, shape).castTo(lpDtype); ROCBinary rb = new ROCBinary(thresholdSteps); for (int xe = 0; xe < 2; xe++) { rb.eval(labels, predicted); //System.out.println(rb.stats()); double eps = lpDtype == DataType.HALF ? 
1e-2 : 1e-6; for (int i = 0; i < nOut; i++) { INDArray lCol = labels.getColumn(i, true); INDArray pCol = predicted.getColumn(i, true); ROC r = new ROC(thresholdSteps); r.eval(lCol, pCol); double aucExp = r.calculateAUC(); double auc = rb.calculateAUC(i); assertEquals(msg, aucExp, auc, eps); long apExp = r.getCountActualPositive(); long ap = rb.getCountActualPositive(i); assertEquals(msg, ap, apExp); long anExp = r.getCountActualNegative(); long an = rb.getCountActualNegative(i); assertEquals(anExp, an); PrecisionRecallCurve pExp = r.getPrecisionRecallCurve(); PrecisionRecallCurve p = rb.getPrecisionRecallCurve(i); assertEquals(msg, pExp, p); } String s = rb.stats(); if(thresholdSteps == 0){ if(first0 == null) { first0 = rb; sFirst0 = s; } else if(lpDtype != DataType.HALF) { //Precision issues with FP16 assertEquals(msg, sFirst0, s); assertEquals(first0, rb); } } else { if(first30 == null) { first30 = rb; sFirst30 = s; } else if(lpDtype != DataType.HALF) { //Precision issues with FP16 assertEquals(msg, sFirst30, s); assertEquals(first30, rb); } } // rb.reset(); rb = new ROCBinary(thresholdSteps); } } } } } finally { Nd4j.setDefaultDataTypes(dtypeBefore, dtypeBefore); } }
Example 8
Source File: ROCTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testRocBasic() { //2 outputs here - probability distribution over classes (softmax) INDArray predictions = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc) {0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601}, {0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}}); INDArray actual = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1}, {0, 1}, {0, 1}}); ROC roc = new ROC(10); roc.eval(actual, predictions); RocCurve rocCurve = roc.getRocCurve(); assertEquals(11, rocCurve.getThreshold().length); //0 + 10 steps for (int i = 0; i < 11; i++) { double expThreshold = i / 10.0; assertEquals(expThreshold, rocCurve.getThreshold(i), 1e-5); // System.out.println("t=" + expThreshold + "\t" + v.getFalsePositiveRate() + "\t" + v.getTruePositiveRate()); double efpr = expFPR.get(expThreshold); double afpr = rocCurve.getFalsePositiveRate(i); assertEquals(efpr, afpr, 1e-5); double etpr = expTPR.get(expThreshold); double atpr = rocCurve.getTruePositiveRate(i); assertEquals(etpr, atpr, 1e-5); } //Expect AUC == 1.0 here double auc = roc.calculateAUC(); assertEquals(1.0, auc, 1e-6); // testing reset now roc.reset(); roc.eval(actual, predictions); auc = roc.calculateAUC(); assertEquals(1.0, auc, 1e-6); }
Example 9
Source File: ROCTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testRoc() { //Previous tests allowed for a perfect classifier with right threshold... INDArray labels = Nd4j.create(new double[][] {{0, 1}, {0, 1}, {1, 0}, {1, 0}, {1, 0}}); INDArray prediction = Nd4j.create(new double[][] {{0.199, 0.801}, {0.499, 0.501}, {0.399, 0.601}, {0.799, 0.201}, {0.899, 0.101}}); Map<Double, Double> expTPR = new HashMap<>(); double totalPositives = 2.0; expTPR.put(0.0, 2.0 / totalPositives); //All predicted class 1 -> 2 true positives / 2 total positives expTPR.put(0.1, 2.0 / totalPositives); expTPR.put(0.2, 2.0 / totalPositives); expTPR.put(0.3, 2.0 / totalPositives); expTPR.put(0.4, 2.0 / totalPositives); expTPR.put(0.5, 2.0 / totalPositives); expTPR.put(0.6, 1.0 / totalPositives); //At threshold of 0.6, only 1 of 2 positives are predicted positive expTPR.put(0.7, 1.0 / totalPositives); expTPR.put(0.8, 1.0 / totalPositives); expTPR.put(0.9, 0.0 / totalPositives); //At threshold of 0.9, 0 of 2 positives are predicted positive expTPR.put(1.0, 0.0 / totalPositives); Map<Double, Double> expFPR = new HashMap<>(); double totalNegatives = 3.0; expFPR.put(0.0, 3.0 / totalNegatives); //All predicted class 1 -> 3 false positives / 3 total negatives expFPR.put(0.1, 3.0 / totalNegatives); expFPR.put(0.2, 2.0 / totalNegatives); //At threshold of 0.2: 1 true negative, 2 false positives expFPR.put(0.3, 1.0 / totalNegatives); //At threshold of 0.3: 2 true negative, 1 false positive expFPR.put(0.4, 1.0 / totalNegatives); expFPR.put(0.5, 1.0 / totalNegatives); expFPR.put(0.6, 1.0 / totalNegatives); expFPR.put(0.7, 0.0 / totalNegatives); //At threshold of 0.7: 3 true negatives, 0 false positives expFPR.put(0.8, 0.0 / totalNegatives); expFPR.put(0.9, 0.0 / totalNegatives); expFPR.put(1.0, 0.0 / totalNegatives); int[] expTPs = new int[] {2, 2, 2, 2, 2, 2, 1, 1, 1, 0, 0}; int[] expFPs = new int[] {3, 3, 2, 1, 1, 1, 1, 0, 0, 0, 0}; int[] expFNs = new int[11]; int[] expTNs = new int[11]; for (int i = 0; i < 11; i++) { expFNs[i] = (int) 
totalPositives - expTPs[i]; expTNs[i] = 5 - expTPs[i] - expFPs[i] - expFNs[i]; } ROC roc = new ROC(10); roc.eval(labels, prediction); RocCurve rocCurve = roc.getRocCurve(); assertEquals(11, rocCurve.getThreshold().length); assertEquals(11, rocCurve.getFpr().length); assertEquals(11, rocCurve.getTpr().length); for (int i = 0; i < 11; i++) { double expThreshold = i / 10.0; assertEquals(expThreshold, rocCurve.getThreshold(i), 1e-5); double efpr = expFPR.get(expThreshold); double afpr = rocCurve.getFalsePositiveRate(i); assertEquals(efpr, afpr, 1e-5); double etpr = expTPR.get(expThreshold); double atpr = rocCurve.getTruePositiveRate(i); assertEquals(etpr, atpr, 1e-5); } //AUC: expected values are based on plotting the ROC curve and manually calculating the area double expAUC = 0.5 * 1.0 / 3.0 + (1 - 1 / 3.0) * 1.0; double actAUC = roc.calculateAUC(); assertEquals(expAUC, actAUC, 1e-6); PrecisionRecallCurve prc = roc.getPrecisionRecallCurve(); for (int i = 0; i < 11; i++) { PrecisionRecallCurve.Confusion c = prc.getConfusionMatrixAtThreshold(i * 0.1); assertEquals(expTPs[i], c.getTpCount()); assertEquals(expFPs[i], c.getFpCount()); assertEquals(expFPs[i], c.getFpCount()); assertEquals(expTNs[i], c.getTnCount()); } }
Example 10
Source File: ROCTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testROCMerging() { int nArrays = 10; int minibatch = 64; int nROCs = 3; for (int steps : new int[] {0, 20}) { //0 steps: exact, 20 steps: thresholded Nd4j.getRandom().setSeed(12345); Random r = new Random(12345); List<ROC> rocList = new ArrayList<>(); for (int i = 0; i < nROCs; i++) { rocList.add(new ROC(steps)); } ROC single = new ROC(steps); for (int i = 0; i < nArrays; i++) { INDArray p = Nd4j.rand(minibatch, 2); p.diviColumnVector(p.sum(1)); INDArray l = Nd4j.zeros(minibatch, 2); for (int j = 0; j < minibatch; j++) { l.putScalar(j, r.nextInt(2), 1.0); } single.eval(l, p); ROC other = rocList.get(i % rocList.size()); other.eval(l, p); } ROC first = rocList.get(0); for (int i = 1; i < nROCs; i++) { first.merge(rocList.get(i)); } double singleAUC = single.calculateAUC(); assertTrue(singleAUC >= 0.0 && singleAUC <= 1.0); assertEquals(singleAUC, first.calculateAUC(), 1e-6); assertEquals(single.getRocCurve(), first.getRocCurve()); } }
Example 11
Source File: ROCTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testROCMerging2() { int nArrays = 10; int minibatch = 64; int exactAllocBlockSize = 10; int nROCs = 3; int steps = 0; //Exact Nd4j.getRandom().setSeed(12345); Random r = new Random(12345); List<ROC> rocList = new ArrayList<>(); for (int i = 0; i < nROCs; i++) { rocList.add(new ROC(steps, true, exactAllocBlockSize)); } ROC single = new ROC(steps); for (int i = 0; i < nArrays; i++) { INDArray p = Nd4j.rand(minibatch, 2); p.diviColumnVector(p.sum(1)); INDArray l = Nd4j.zeros(minibatch, 2); for (int j = 0; j < minibatch; j++) { l.putScalar(j, r.nextInt(2), 1.0); } single.eval(l, p); ROC other = rocList.get(i % rocList.size()); other.eval(l, p); } ROC first = rocList.get(0); for (int i = 1; i < nROCs; i++) { first.merge(rocList.get(i)); } double singleAUC = single.calculateAUC(); assertTrue(singleAUC >= 0.0 && singleAUC <= 1.0); assertEquals(singleAUC, first.calculateAUC(), 1e-6); assertEquals(single.getRocCurve(), first.getRocCurve()); }
Example 12
Source File: ROCTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Two ROC instances are each fed alternate minibatches. Both compute AUC/AUPRC first and are
 * then merged.
 *
 * <p>The merged result must report the same AUC and AUPRC (within 1e-6) as a single ROC fed
 * every minibatch.
 */
@Test
public void testRocMerge() {
    Nd4j.getRandom().setSeed(12345);

    ROC roc = new ROC();
    ROC roc1 = new ROC();
    ROC roc2 = new ROC();

    final int nOut = 2;
    Random rng = new Random(12345);
    for (int batch = 0; batch < 10; batch++) {
        INDArray labels = Nd4j.zeros(3, nOut);
        for (int row = 0; row < 3; row++) {
            labels.putScalar(row, rng.nextInt(nOut), 1.0);
        }
        INDArray out = Nd4j.rand(3, nOut);
        out.diviColumnVector(out.sum(1));

        roc.eval(labels, out);
        // Even minibatches go to roc1, odd ones to roc2
        ROC half = (batch % 2 == 0) ? roc1 : roc2;
        half.eval(labels, out);
    }

    // Metrics are computed on both halves before merging; the merged results must still match
    roc1.calculateAUC();
    roc1.calculateAUCPR();
    roc2.calculateAUC();
    roc2.calculateAUCPR();

    roc1.merge(roc2);

    double aucExp = roc.calculateAUC();
    double auprc = roc.calculateAUCPR();

    double aucAct = roc1.calculateAUC();
    double auprcAct = roc1.calculateAUCPR();

    assertEquals(aucExp, aucAct, 1e-6);
    assertEquals(auprc, auprcAct, 1e-6);
}