org.nd4j.evaluation.classification.ROC Java Examples

The following examples show how to use org.nd4j.evaluation.classification.ROC. Each example is labelled with its source file, originating project and license; refer to the original project for the full context of each snippet.
Example #1
Source File: ROCScoreFunction.java, from deeplearning4j (Apache License 2.0)
/**
 * Scores a {@link ComputationGraph} by running the configured ROC-family evaluation over {@code iterator}.
 *
 * @param net      network to evaluate
 * @param iterator data to evaluate on
 * @return the AUC when {@code metric == Metric.AUC}, otherwise the AUPRC (the "average" variants for the
 *         BINARY and MULTICLASS types)
 */
@Override
public double score(ComputationGraph net, MultiDataSetIterator iterator) {
    switch (type) {
        case ROC: {
            ROC roc = net.evaluateROC(iterator);
            if (metric == Metric.AUC) {
                return roc.calculateAUC();
            }
            return roc.calculateAUCPR();
        }
        case BINARY: {
            ROCBinary binary = net.doEvaluation(iterator, new ROCBinary())[0];
            if (metric == Metric.AUC) {
                return binary.calculateAverageAuc();
            }
            return binary.calculateAverageAUCPR();
        }
        case MULTICLASS: {
            ROCMultiClass multiClass = net.evaluateROCMultiClass(iterator, 0);
            if (metric == Metric.AUC) {
                return multiClass.calculateAverageAUC();
            }
            return multiClass.calculateAverageAUCPR();
        }
        default:
            throw new RuntimeException("Unknown type: " + type);
    }
}
 
Example #2
Source File: ROCScoreFunction.java, from deeplearning4j (Apache License 2.0)
/**
 * Scores a {@link MultiLayerNetwork} by running the configured ROC-family evaluation over {@code iterator}.
 *
 * @param net      network to evaluate
 * @param iterator data to evaluate on
 * @return the AUC when {@code metric == Metric.AUC}, otherwise the AUPRC (the "average" variants for the
 *         BINARY and MULTICLASS types)
 */
@Override
public double score(MultiLayerNetwork net, DataSetIterator iterator) {
    final boolean useAuc = metric == Metric.AUC;
    switch (type) {
        case ROC:
            ROC roc = net.evaluateROC(iterator);
            return useAuc ? roc.calculateAUC() : roc.calculateAUCPR();
        case BINARY:
            ROCBinary[] binaryResults = net.doEvaluation(iterator, new ROCBinary());
            ROCBinary binary = binaryResults[0];
            return useAuc ? binary.calculateAverageAuc() : binary.calculateAverageAUCPR();
        case MULTICLASS:
            ROCMultiClass multiClass = net.evaluateROCMultiClass(iterator);
            return useAuc ? multiClass.calculateAverageAUC() : multiClass.calculateAverageAUCPR();
        default:
            throw new RuntimeException("Unknown type: " + type);
    }
}
 
Example #3
Source File: ROCScoreCalculator.java, from deeplearning4j (Apache License 2.0)
/**
 * Reduces a populated evaluation to a single score for the configured ROC {@code type}.
 *
 * @param eval evaluation created by {@link #newEval()} and filled with data; its runtime class must match
 *             {@code type} (ROC, ROCBinary or ROCMultiClass)
 * @return the AUC when {@code metric == Metric.AUC}, otherwise the AUPRC (the "average" variants for the
 *         BINARY and MULTICLASS types)
 * @throws IllegalStateException if {@code type} is not a recognised ROC type
 */
@Override
protected double finalScore(IEvaluation eval) {
    switch (type){
        case ROC:
            ROC r = (ROC)eval;
            return metric == Metric.AUC ? r.calculateAUC() : r.calculateAUCPR();
        case BINARY:
            ROCBinary r2 = (ROCBinary) eval;
            //The AUPRC branch must call calculateAverageAUCPR(); returning calculateAverageAuc() on both
            //branches made the AUPRC metric silently report AUC for ROCBinary
            return metric == Metric.AUC ? r2.calculateAverageAuc() : r2.calculateAverageAUCPR();
        case MULTICLASS:
            ROCMultiClass r3 = (ROCMultiClass)eval;
            return metric == Metric.AUC ? r3.calculateAverageAUC() : r3.calculateAverageAUCPR();
        default:
            throw new IllegalStateException("Unknown type: " + type);
    }
}
 
Example #4
Source File: ROCArraySerializer.java, from deeplearning4j (Apache License 2.0)
/**
 * Writes {@code rocs} as a JSON array. Each element becomes its own JSON object whose first field is
 * {@code "@class"} (the {@link ROC} class name), followed by whatever the {@code serializer} field writes
 * for that element.
 */
@Override
public void serialize(ROC[] rocs, JsonGenerator jsonGenerator, SerializerProvider serializerProvider)
                throws IOException, JsonProcessingException {
    final String rocClassName = ROC.class.getName();
    jsonGenerator.writeStartArray();
    for (int idx = 0; idx < rocs.length; idx++) {
        ROC current = rocs[idx];
        jsonGenerator.writeStartObject();
        jsonGenerator.writeStringField("@class", rocClassName);
        serializer.serialize(current, jsonGenerator, serializerProvider);
        jsonGenerator.writeEndObject();
    }
    jsonGenerator.writeEndArray();
}
 
Example #5
Source File: EvaluationTools.java, from deeplearning4j (Apache License 2.0)
/**
 * Renders the ROC chart and the precision-recall charts of a {@link ROC} into one stand-alone HTML page.
 *
 * @param roc ROC evaluation to render
 * @return the complete HTML page, as a String
 */
public static String rocChartToHtml(ROC roc) {
    final RocCurve curve = roc.getRocCurve();
    final Component rocChart = getRocFromPoints(ROC_TITLE, curve,
                    roc.getCountActualPositive(), roc.getCountActualNegative(),
                    roc.calculateAUC(), roc.calculateAUCPR());
    final Component prCharts = getPRCharts(PR_TITLE, PR_THRESHOLD_TITLE, roc.getPrecisionRecallCurve());
    return StaticPageUtil.renderHTML(rocChart, prCharts);
}
 
Example #6
Source File: MultiLayerNetwork.java, from deeplearning4j (Apache License 2.0)
/**
 * Evaluates this network, which must be a binary classifier, with a {@link ROC} evaluation.
 *
 * @param iterator          data to evaluate on
 * @param rocThresholdSteps number of threshold steps for the {@link ROC} instance - see that class for details
 * @return the populated ROC evaluation for the given dataset
 */
@SuppressWarnings("unchecked")
public <T extends ROC> T evaluateROC(DataSetIterator iterator, int rocThresholdSteps) {
    Layer outputLayer = getOutputLayer();
    boolean validateConfig = getLayerWiseConfigurations().isValidateOutputLayerConfig();
    if (validateConfig) {
        OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROC.class);
    }
    org.deeplearning4j.eval.ROC rocEval = new org.deeplearning4j.eval.ROC(rocThresholdSteps);
    //Unchecked: the caller chooses T; the instance returned is always org.deeplearning4j.eval.ROC
    return (T) doEvaluation(iterator, rocEval)[0];
}
 
Example #7
Source File: ROCScoreCalculator.java, from deeplearning4j (Apache License 2.0)
/**
 * Creates a fresh, empty evaluation instance matching the configured ROC {@code type}.
 *
 * @return a new ROC, ROCBinary or ROCMultiClass built with its no-argument constructor
 * @throws IllegalStateException if {@code type} is not a recognised ROC type
 */
@Override
protected IEvaluation newEval() {
    final IEvaluation created;
    switch (type) {
        case ROC:
            created = new ROC();
            break;
        case BINARY:
            created = new ROCBinary();
            break;
        case MULTICLASS:
            created = new ROCMultiClass();
            break;
        default:
            throw new IllegalStateException("Unknown type: " + type);
    }
    return created;
}
 
Example #8
Source File: ROCSerializer.java, from deeplearning4j (Apache License 2.0)
/**
 * Serializes a {@link ROC} together with polymorphic type information. It writes the type prefix produced by
 * {@code typeSer}, then the plain fields via {@link #serialize}, then the matching type suffix.
 */
@Override
public void serializeWithType(ROC value, JsonGenerator gen, SerializerProvider serializers, TypeSerializer typeSer)
                throws IOException {
    //Order matters: the prefix and suffix must bracket the field output in the stream
    //NOTE(review): writeTypePrefixForObject/writeTypeSuffixForObject appear to be deprecated in newer Jackson 2.x
    //releases in favour of writeTypePrefix/writeTypeSuffix - confirm against the Jackson version in use
    typeSer.writeTypePrefixForObject(value, gen);
    serialize(value, gen, serializers);
    typeSer.writeTypeSuffixForObject(value, gen);
}
 
Example #9
Source File: ROCSerializer.java, from deeplearning4j (Apache License 2.0)
/**
 * Writes the fields of a {@link ROC} to {@code jsonGenerator}. Only fields are written; this method emits no
 * object start/end markers.
 *
 * <p>Always written: thresholdSteps, countActualPositive, countActualNegative, counts, isExact, exampleCount
 * and rocRemoveRedundantPts. Written only when exampleCount != 0: auc and auprc. Written only in exact mode
 * with exampleCount != 0: rocCurve and prCurve.
 */
@Override
public void serialize(ROC roc, JsonGenerator jsonGenerator, SerializerProvider serializerProvider)
                throws IOException {
    boolean empty = roc.getExampleCount() == 0;

    if (roc.isExact() && !empty) {
        //For exact ROC implementation: force AUC and AUPRC calculation, so result can be stored in JSON, such
        //that we have them once deserialized.
        //Due to potentially huge size, exact mode doesn't store the original predictions in JSON
        //NOTE(review): the "auc"/"auprc" writes below call these again whenever !empty, which already covers this
        //case. This pre-computation looks redundant unless computing first affects a later getter - confirm.
        roc.calculateAUC();
        roc.calculateAUCPR();
    }
    jsonGenerator.writeNumberField("thresholdSteps", roc.getThresholdSteps());
    jsonGenerator.writeNumberField("countActualPositive", roc.getCountActualPositive());
    jsonGenerator.writeNumberField("countActualNegative", roc.getCountActualNegative());
    jsonGenerator.writeObjectField("counts", roc.getCounts());
    //AUC values are skipped for an evaluation that has seen no examples
    if(!empty) {
        jsonGenerator.writeNumberField("auc", roc.calculateAUC());
        jsonGenerator.writeNumberField("auprc", roc.calculateAUCPR());
    }
    if (roc.isExact() && !empty) {
        //Store ROC and PR curves only for exact mode... they are redundant + can be calculated again for thresholded mode
        jsonGenerator.writeObjectField("rocCurve", roc.getRocCurve());
        jsonGenerator.writeObjectField("prCurve", roc.getPrecisionRecallCurve());
    }
    jsonGenerator.writeBooleanField("isExact", roc.isExact());
    jsonGenerator.writeNumberField("exampleCount", roc.getExampleCount());
    jsonGenerator.writeBooleanField("rocRemoveRedundantPts", roc.isRocRemoveRedundantPts());
}
 
Example #10
Source File: ROCTest.java, from deeplearning4j (Apache License 2.0)
@Test
public void testPrecisionRecallCurveConfusion() {
    //Sanity check: values calculated from the confusion matrix should match the PR curve values

    //Fixed seed so that any failure is reproducible
    Nd4j.getRandom().setSeed(12345);

    for (boolean removeRedundantPts : new boolean[] {true, false}) {
        ROC r = new ROC(0, removeRedundantPts);

        INDArray labels = Nd4j.getExecutioner()
                        .exec(new BernoulliDistribution(Nd4j.createUninitialized(DataType.DOUBLE,100, 1), 0.5));
        INDArray probs = Nd4j.rand(100, 1);

        r.eval(labels, probs);

        PrecisionRecallCurve prc = r.getPrecisionRecallCurve();
        int nPoints = prc.numPoints();

        for (int i = 0; i < nPoints; i++) {
            PrecisionRecallCurve.Confusion c = prc.getConfusionMatrixAtPoint(i);
            PrecisionRecallCurve.Point p = c.getPoint();

            int tp = c.getTpCount();
            int fp = c.getFpCount();
            int fn = c.getFnCount();

            double prec = tp / (double) (tp + fp);
            double rec = tp / (double) (tp + fn);

            //Handle edge cases: with no predicted positives, precision is taken as 1.0
            if (tp == 0 && fp == 0) {
                prec = 1.0;
            }

            //Expected = value recomputed from the confusion matrix; actual = value reported by the curve
            String msg = "removeRedundantPts=" + removeRedundantPts + ", point=" + i;
            assertEquals(msg, prec, p.getPrecision(), 1e-6);
            assertEquals(msg, rec, p.getRecall(), 1e-6);
        }
    }
}
 
Example #11
Source File: ROCTest.java, from deeplearning4j (Apache License 2.0)
@Test
public void testCompareRocAndRocMultiClass() {
    Nd4j.getRandom().setSeed(12345);

    //With exactly 2 classes, the ROC AUC must equal the ROCMultiClass AUC reported for class 1
    final int nExamples = 200;
    INDArray predictions = Nd4j.rand(nExamples, 2);
    predictions.diviColumnVector(predictions.sum(1)); //normalise each row to sum to 1

    INDArray labels = Nd4j.create(nExamples, 2);
    Random rng = new Random(12345);
    for (int row = 0; row < nExamples; row++) {
        int cls = rng.nextInt(2);
        labels.putScalar(row, cls, 1.0);
    }

    int[] stepConfigs = {30, 0}; //Steps = 0: exact
    for (int numSteps : stepConfigs) {
        ROC roc = new ROC(numSteps);
        roc.eval(labels, predictions);

        ROCMultiClass rocMultiClass = new ROCMultiClass(numSteps);
        rocMultiClass.eval(labels, predictions);

        assertEquals(roc.calculateAUC(), rocMultiClass.calculateAUC(1), 1e-6);
    }
}
 
Example #12
Source File: ROCTest.java, from deeplearning4j (Apache License 2.0)
@Test
public void testRocBasicSingleClass() {
    //Single output column: one probability value per example (sigmoid-style output)

    //Values are offset by 0.001 to avoid numerical/rounding issues (float vs. double, etc)
    double[] probValues = {0.001, 0.101, 0.201, 0.301, 0.401, 0.501, 0.601, 0.701, 0.801, 0.901};
    double[] labelValues = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1};
    INDArray predictions = Nd4j.create(probValues, new int[] {10, 1});
    INDArray actual = Nd4j.create(labelValues, new int[] {10, 1});

    ROC roc = new ROC(10);
    roc.eval(actual, predictions);

    RocCurve rocCurve = roc.getRocCurve();

    //10 threshold steps -> 11 thresholds: 0.0, 0.1, ..., 1.0
    assertEquals(11, rocCurve.getThreshold().length);
    for (int step = 0; step <= 10; step++) {
        double threshold = step / 10.0;
        assertEquals(threshold, rocCurve.getThreshold(step), 1e-5);

        double expectedFpr = expFPR.get(threshold);
        assertEquals(expectedFpr, rocCurve.getFalsePositiveRate(step), 1e-5);

        double expectedTpr = expTPR.get(threshold);
        assertEquals(expectedTpr, rocCurve.getTruePositiveRate(step), 1e-5);
    }

    //Classes are perfectly separated by the predictions, so AUC == 1.0
    assertEquals(1.0, roc.calculateAUC(), 1e-6);
}
 
Example #13
Source File: ROCTest.java, from deeplearning4j (Apache License 2.0)
@Test
public void testROCMerging2() {
    final int nArrays = 10;
    final int minibatch = 64;
    final int exactAllocBlockSize = 10;
    final int nROCs = 3;
    final int steps = 0;  //Exact

    Nd4j.getRandom().setSeed(12345);
    Random rng = new Random(12345);

    //Shards: minibatches are distributed round-robin across these, then merged
    List<ROC> shards = new ArrayList<>();
    for (int s = 0; s < nROCs; s++) {
        shards.add(new ROC(steps, true, exactAllocBlockSize));
    }

    //Reference: a single ROC that sees every minibatch
    ROC single = new ROC(steps);
    for (int batch = 0; batch < nArrays; batch++) {
        INDArray probs = Nd4j.rand(minibatch, 2);
        probs.diviColumnVector(probs.sum(1));

        INDArray labels = Nd4j.zeros(minibatch, 2);
        for (int row = 0; row < minibatch; row++) {
            labels.putScalar(row, rng.nextInt(2), 1.0);
        }

        single.eval(labels, probs);
        shards.get(batch % shards.size()).eval(labels, probs);
    }

    ROC merged = shards.get(0);
    for (ROC other : shards.subList(1, nROCs)) {
        merged.merge(other);
    }

    double singleAUC = single.calculateAUC();
    assertTrue(singleAUC >= 0.0 && singleAUC <= 1.0);
    assertEquals(singleAUC, merged.calculateAUC(), 1e-6);

    assertEquals(single.getRocCurve(), merged.getRocCurve());
}
 
Example #14
Source File: LstmTimeSeriesExample.java, from Java-Deep-Learning-Cookbook (MIT License)
/**
 * Trains a single-LSTM-layer ComputationGraph on the PhysioNet sequence CSVs and prints ROC results for a held-out
 * split. Files 0-3199 under FEATURE_DIR/LABEL_DIR are used for training, and files 3200-3999 for testing.
 *
 * @param args unused
 * @throws FileNotFoundException if FEATURE_DIR or LABEL_DIR still hold their placeholder values
 * @throws IOException           if reading the CSV files fails
 * @throws InterruptedException  as declared by record-reader initialization
 */
public static void main(String[] args) throws IOException, InterruptedException {
    //startsWith() detects the placeholders with or without their closing '}'. The previous equals() check for
    //LABEL_DIR compared against "{PATH-TO-PHYSIONET-LABELS" (no '}'), so an unedited labels path slipped through.
    if (FEATURE_DIR.startsWith("{PATH-TO-PHYSIONET-FEATURES") || LABEL_DIR.startsWith("{PATH-TO-PHYSIONET-LABELS")) {
        String message = "Please provide proper directory path in place of: PATH-TO-PHYSIONET-FEATURES && PATH-TO-PHYSIONET-LABELS";
        System.out.println(message);
        throw new FileNotFoundException(message);
    }

    final int batchSize = 100;
    final int numClasses = 2;
    final int numEpochs = 1;

    //Feature CSVs have one header line to skip (first constructor argument); label CSVs are read with defaults
    SequenceRecordReader trainFeaturesReader = new CSVSequenceRecordReader(1, ",");
    trainFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR+"/%d.csv",0,3199));
    SequenceRecordReader trainLabelsReader = new CSVSequenceRecordReader();
    trainLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR+"/%d.csv",0,3199));
    DataSetIterator trainDataSetIterator = new SequenceRecordReaderDataSetIterator(trainFeaturesReader,trainLabelsReader,batchSize,numClasses,false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    SequenceRecordReader testFeaturesReader = new CSVSequenceRecordReader(1, ",");
    testFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR+"/%d.csv",3200,3999));
    SequenceRecordReader testLabelsReader = new CSVSequenceRecordReader();
    testLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR+"/%d.csv",3200,3999));
    DataSetIterator testDataSetIterator = new SequenceRecordReaderDataSetIterator(testFeaturesReader,testLabelsReader,batchSize,numClasses,false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()
                                                    .seed(RANDOM_SEED)
                                                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                                                    .weightInit(WeightInit.XAVIER)
                                                    .updater(new Adam())
                                                    .dropOut(0.9)
                                                    .graphBuilder()
                                                    .addInputs("trainFeatures")
                                                    .setOutputs("predictMortality")
                                                    .addLayer("L1", new LSTM.Builder()
                                                                                    .nIn(86)
                                                                                    .nOut(200)
                                                                                    .forgetGateBiasInit(1)
                                                                                    .activation(Activation.TANH)
                                                                                    .build(),"trainFeatures")
                                                    .addLayer("predictMortality", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                                                                        .activation(Activation.SOFTMAX)
                                                                                        .nIn(200).nOut(numClasses).build(),"L1")
                                                    .build();

    ComputationGraph model = new ComputationGraph(configuration);
    //Initialise parameters explicitly before training (standard DL4J construction pattern)
    model.init();

    for (int epoch = 0; epoch < numEpochs; epoch++) {
        model.fit(trainDataSetIterator);
        trainDataSetIterator.reset();
    }

    //Evaluate per time step on the test split; the ROC is constructed with 100 threshold steps
    ROC evaluation = new ROC(100);
    while (testDataSetIterator.hasNext()) {
        DataSet batch = testDataSetIterator.next();
        INDArray[] output = model.output(batch.getFeatures());
        evaluation.evalTimeSeries(batch.getLabels(), output[0]);
    }

    System.out.println(evaluation.calculateAUC());
    System.out.println(evaluation.stats());
}
 
Example #15
Source File: ROCBinaryTest.java, from deeplearning4j (Apache License 2.0)
@Test
public void testROCBinary() {
    //Compare ROCBinary to ROC class, for every combination of global default dtype and label/prediction dtype

    //Capture BOTH global defaults so that the finally block restores exactly the prior state.
    //Restoring the FP type into both slots could leave the non-FP default changed after this test.
    DataType defaultDtypeBefore = Nd4j.dataType();
    DataType fpDtypeBefore = Nd4j.defaultFloatingPointType();
    ROCBinary first30 = null;
    ROCBinary first0 = null;
    String sFirst30 = null;
    String sFirst0 = null;
    try {
        for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) {
            //Non-FP globals (INT) use DOUBLE as the floating-point default
            Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? globalDtype : DataType.DOUBLE);
            for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
                String msg = "globalDtype=" + globalDtype + ", labelPredictionsDtype=" + lpDtype;

                int nExamples = 50;
                int nOut = 4;
                long[] shape = {nExamples, nOut};

                for (int thresholdSteps : new int[]{30, 0}) { //0 == exact

                    Nd4j.getRandom().setSeed(12345);
                    INDArray labels =
                            Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(DataType.DOUBLE, shape), 0.5)).castTo(lpDtype);

                    Nd4j.getRandom().setSeed(12345);
                    INDArray predicted = Nd4j.rand(DataType.DOUBLE, shape).castTo(lpDtype);

                    ROCBinary rb = new ROCBinary(thresholdSteps);

                    //Two passes. rb is replaced by a fresh instance at the end of each pass, and both passes
                    //must give the same results.
                    for (int xe = 0; xe < 2; xe++) {
                        rb.eval(labels, predicted);

                        double eps = lpDtype == DataType.HALF ? 1e-2 : 1e-6;
                        for (int i = 0; i < nOut; i++) {
                            //Reference: a plain ROC evaluated on column i alone
                            INDArray lCol = labels.getColumn(i, true);
                            INDArray pCol = predicted.getColumn(i, true);

                            ROC r = new ROC(thresholdSteps);
                            r.eval(lCol, pCol);

                            double aucExp = r.calculateAUC();
                            double auc = rb.calculateAUC(i);
                            assertEquals(msg, aucExp, auc, eps);

                            //Argument order is (message, expected, actual)
                            long apExp = r.getCountActualPositive();
                            long ap = rb.getCountActualPositive(i);
                            assertEquals(msg, apExp, ap);

                            long anExp = r.getCountActualNegative();
                            long an = rb.getCountActualNegative(i);
                            assertEquals(msg, anExp, an);

                            PrecisionRecallCurve pExp = r.getPrecisionRecallCurve();
                            PrecisionRecallCurve p = rb.getPrecisionRecallCurve(i);
                            assertEquals(msg, pExp, p);
                        }

                        String s = rb.stats();

                        //The first result seen for each threshold setting is the reference that every later
                        //dtype combination must reproduce exactly
                        if(thresholdSteps == 0){
                            if(first0 == null) {
                                first0 = rb;
                                sFirst0 = s;
                            } else if(lpDtype != DataType.HALF) {   //Precision issues with FP16
                                assertEquals(msg, sFirst0, s);
                                assertEquals(msg, first0, rb);
                            }
                        } else {
                            if(first30 == null) {
                                first30 = rb;
                                sFirst30 = s;
                            } else if(lpDtype != DataType.HALF) {   //Precision issues with FP16
                                assertEquals(msg, sFirst30, s);
                                assertEquals(msg, first30, rb);
                            }
                        }

                        rb = new ROCBinary(thresholdSteps);
                    }
                }
            }
        }
    } finally {
        Nd4j.setDefaultDataTypes(defaultDtypeBefore, fpDtypeBefore);
    }
}
 
Example #16
Source File: TestSparkComputationGraph.java, from deeplearning4j (Apache License 2.0)
@Test
public void testEvaluationAndRocMDS() {
    int[] workerCounts = {1, 4, 8};
    for (int evalWorkers : workerCounts) {

        DataSetIterator irisIter = new IrisDataSetIterator(5, 150);

        //Collapse iris to 2 classes (column 0 kept; columns 1 and 2 summed into column 1), as MultiDataSets
        List<MultiDataSet> data = new ArrayList<>();
        irisIter.reset();
        while (irisIter.hasNext()) {
            DataSet ds = irisIter.next();
            INDArray origLabels = ds.getLabels();
            INDArray twoClass = Nd4j.create(origLabels.size(0), 2);
            twoClass.putColumn(0, origLabels.getColumn(0));
            twoClass.putColumn(1, origLabels.getColumn(1));
            twoClass.getColumn(1).addi(origLabels.getColumn(2));

            data.add(new org.nd4j.linalg.dataset.MultiDataSet(ds.getFeatures(), twoClass));
        }

        MultiDataSetIterator mdsIter = new IteratorMultiDataSetIterator(data.iterator(), 5);

        ComputationGraph cg = getBasicNetIris2Class();

        //Local evaluation: the reference results
        IEvaluation[] localResults = cg.doEvaluation(mdsIter, new Evaluation(), new ROC(32));
        Evaluation e = (Evaluation) localResults[0];
        ROC roc = (ROC) localResults[1];

        //Distributed evaluation of the same data and network
        SparkComputationGraph scg = new SparkComputationGraph(sc, cg, null);
        scg.setDefaultEvaluationWorkers(evalWorkers);

        JavaRDD<MultiDataSet> rdd = sc.parallelize(data).repartition(20);

        IEvaluation[] sparkResults = scg.doEvaluationMDS(rdd, 5, new Evaluation(), new ROC(32));
        Evaluation e2 = (Evaluation) sparkResults[0];
        ROC roc2 = (ROC) sparkResults[1];

        assertEquals(e2.accuracy(), e.accuracy(), 1e-3);
        assertEquals(e2.f1(), e.f1(), 1e-3);
        assertEquals(e2.getNumRowCounter(), e.getNumRowCounter(), 1e-3);
        assertEquals(e2.falseNegatives(), e.falseNegatives());
        assertEquals(e2.falsePositives(), e.falsePositives());
        assertEquals(e2.trueNegatives(), e.trueNegatives());
        assertEquals(e2.truePositives(), e.truePositives());
        assertEquals(e2.precision(), e.precision(), 1e-3);
        assertEquals(e2.recall(), e.recall(), 1e-3);
        assertEquals(e2.getConfusionMatrix(), e.getConfusionMatrix());

        assertEquals(roc.calculateAUC(), roc2.calculateAUC(), 1e-5);
        assertEquals(roc.calculateAUCPR(), roc2.calculateAUCPR(), 1e-5);
    }
}
 
Example #17
Source File: TestSparkComputationGraph.java, from deeplearning4j (Apache License 2.0)
@Test(timeout = 60000L)
public void testEvaluationAndRoc() {
    for (int evalWorkers : new int[] {1, 4, 8}) {
        DataSetIterator irisSource = new IrisDataSetIterator(5, 150);

        //Collapse iris to 2 classes: column 0 kept; columns 1 and 2 summed into column 1
        List<DataSet> data = new ArrayList<>();
        irisSource.reset();
        while (irisSource.hasNext()) {
            DataSet ds = irisSource.next();
            INDArray oldLabels = ds.getLabels();
            INDArray newLabels = Nd4j.create(oldLabels.size(0), 2);
            newLabels.putColumn(0, oldLabels.getColumn(0));
            newLabels.putColumn(1, oldLabels.getColumn(1));
            newLabels.getColumn(1).addi(oldLabels.getColumn(2));
            ds.setLabels(newLabels);
            data.add(ds);
        }

        DataSetIterator listIter = new ListDataSetIterator<>(data);

        ComputationGraph cg = getBasicNetIris2Class();

        //Local evaluation: the reference results
        Evaluation e = cg.evaluate(listIter);
        ROC roc = cg.evaluateROC(listIter, 32);

        //Distributed evaluation of the same data and network
        SparkComputationGraph scg = new SparkComputationGraph(sc, cg, null);
        scg.setDefaultEvaluationWorkers(evalWorkers);

        JavaRDD<DataSet> rdd = sc.parallelize(data).repartition(20);

        Evaluation e2 = scg.evaluate(rdd);
        ROC roc2 = scg.evaluateROC(rdd);

        assertEquals(e2.accuracy(), e.accuracy(), 1e-3);
        assertEquals(e2.f1(), e.f1(), 1e-3);
        assertEquals(e2.getNumRowCounter(), e.getNumRowCounter(), 1e-3);
        assertEquals(e2.falseNegatives(), e.falseNegatives());
        assertEquals(e2.falsePositives(), e.falsePositives());
        assertEquals(e2.trueNegatives(), e.trueNegatives());
        assertEquals(e2.truePositives(), e.truePositives());
        assertEquals(e2.precision(), e.precision(), 1e-3);
        assertEquals(e2.recall(), e.recall(), 1e-3);
        assertEquals(e2.getConfusionMatrix(), e.getConfusionMatrix());

        assertEquals(roc.calculateAUC(), roc2.calculateAUC(), 1e-5);
        assertEquals(roc.calculateAUCPR(), roc2.calculateAUCPR(), 1e-5);
    }
}
 
Example #18
Source File: TestSparkMultiLayerParameterAveraging.java, from deeplearning4j (Apache License 2.0)
@Test
public void testROC() {
    //Accumulates a ROC locally from the network's predictions, then checks that SparkDl4jMultiLayer's
    //distributed ROC over the same data gives the same AUC and ROC curve

    int nArrays = 100;
    int minibatch = 64;
    int steps = 20;
    int nIn = 5;
    int nOut = 2;
    int layerSize = 10;

    MultiLayerConfiguration conf =
                    new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list()
                                    .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(layerSize).build())
                                    .layer(1, new OutputLayer.Builder().nIn(layerSize).nOut(nOut)
                                                    .activation(Activation.SOFTMAX).lossFunction(
                                                                    LossFunctions.LossFunction.MCXENT)
                                                    .build())
                                    .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();


    Nd4j.getRandom().setSeed(12345);
    Random r = new Random(12345);

    ROC local = new ROC(steps);
    List<DataSet> dsList = new ArrayList<>();
    for (int i = 0; i < nArrays; i++) {
        INDArray features = Nd4j.rand(minibatch, nIn);

        INDArray p = net.output(features);

        INDArray l = Nd4j.zeros(minibatch, 2);
        for (int j = 0; j < minibatch; j++) {
            l.putScalar(j, r.nextInt(2), 1.0);
        }

        local.eval(l, p);

        dsList.add(new DataSet(features, l));
    }


    SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, net, null);
    JavaRDD<DataSet> rdd = sc.parallelize(dsList);

    ROC sparkROC = sparkNet.evaluateROC(rdd, steps, 32);

    //Expected = local result, actual = Spark result (comparing sparkROC with itself could never fail)
    assertEquals(local.calculateAUC(), sparkROC.calculateAUC(), 1e-6);

    assertEquals(local.getRocCurve(), sparkROC.getRocCurve());
}
 
Example #19
Source File: NewInstanceTest.java, from deeplearning4j (Apache License 2.0)
@Test
public void testNewInstances() {
    //Each evaluation and its newInstance() copy are fed identical data; their JSON and stats output must then match
    Nd4j.getRandom().setSeed(12345);

    Evaluation evaluation = new Evaluation();
    EvaluationBinary evaluationBinary = new EvaluationBinary();
    ROC roc = new ROC(2);
    ROCBinary roc2 = new ROCBinary(2);
    ROCMultiClass roc3 = new ROCMultiClass(2);
    RegressionEvaluation regressionEvaluation = new RegressionEvaluation();
    EvaluationCalibration ec = new EvaluationCalibration();


    IEvaluation[] arr = new IEvaluation[] {evaluation, evaluationBinary, roc, roc2, roc3, regressionEvaluation, ec};

    //Multi-class data: one-hot labels and row-normalised probabilities
    INDArray evalLabel1 = Nd4j.create(10, 3);
    for (int i = 0; i < 10; i++) {
        evalLabel1.putScalar(i, i % 3, 1.0);
    }
    INDArray evalProb1 = Nd4j.rand(10, 3);
    evalProb1.diviColumnVector(evalProb1.sum(1));

    evaluation.eval(evalLabel1, evalProb1);
    roc3.eval(evalLabel1, evalProb1);
    ec.eval(evalLabel1, evalProb1);

    //Multi-label binary data: independent Bernoulli labels per column
    INDArray evalLabel2 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 3), 0.5));
    INDArray evalProb2 = Nd4j.rand(10, 3);
    evaluationBinary.eval(evalLabel2, evalProb2);
    roc2.eval(evalLabel2, evalProb2);

    //Single-column binary data
    INDArray evalLabel3 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 1), 0.5));
    INDArray evalProb3 = Nd4j.rand(10, 1);
    roc.eval(evalLabel3, evalProb3);

    INDArray reg1 = Nd4j.rand(10, 3);
    INDArray reg2 = Nd4j.rand(10, 3);

    regressionEvaluation.eval(reg1, reg2);

    Evaluation evaluation2 = evaluation.newInstance();
    EvaluationBinary evaluationBinary2 = evaluationBinary.newInstance();
    ROC roc_2 = roc.newInstance();
    ROCBinary roc22 = roc2.newInstance();
    ROCMultiClass roc32 = roc3.newInstance();
    RegressionEvaluation regressionEvaluation2 = regressionEvaluation.newInstance();
    EvaluationCalibration ec2 = ec.newInstance();

    IEvaluation[] arr2 = new IEvaluation[] {evaluation2, evaluationBinary2, roc_2, roc22, roc32, regressionEvaluation2, ec2};

    evaluation2.eval(evalLabel1, evalProb1);
    roc32.eval(evalLabel1, evalProb1);
    ec2.eval(evalLabel1, evalProb1);

    evaluationBinary2.eval(evalLabel2, evalProb2);
    roc22.eval(evalLabel2, evalProb2);

    roc_2.eval(evalLabel3, evalProb3);

    regressionEvaluation2.eval(reg1, reg2);

    for (int i = 0 ; i < arr.length ; i++) {

        IEvaluation e = arr[i];
        IEvaluation e2 = arr2[i];
        //Name the evaluation class in the message so a failure identifies which one diverged
        String name = e.getClass().getSimpleName();
        assertEquals("Json not equal " + name, e.toJson(), e2.toJson());
        assertEquals("Stats not equal " + name, e.stats(), e2.stats());
    }
}
 
Example #20
Source File: ROCTest.java, from deeplearning4j (Apache License 2.0)
@Test
public void testRocBasic() {
    //Two output columns per example: a probability distribution over the classes (softmax-style output)
    //Values are offset by 0.001 to avoid numerical/rounding issues (float vs. double, etc)
    double[][] probs = {
                    {1.0, 0.001}, {0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401},
                    {0.499, 0.501}, {0.399, 0.601}, {0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}};
    double[][] oneHot = {
                    {1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1}, {0, 1}, {0, 1}};
    INDArray predictions = Nd4j.create(probs);
    INDArray actual = Nd4j.create(oneHot);

    ROC roc = new ROC(10);
    roc.eval(actual, predictions);

    RocCurve rocCurve = roc.getRocCurve();

    //10 threshold steps -> 11 thresholds: 0.0, 0.1, ..., 1.0
    assertEquals(11, rocCurve.getThreshold().length);
    for (int step = 0; step <= 10; step++) {
        double threshold = step / 10.0;
        assertEquals(threshold, rocCurve.getThreshold(step), 1e-5);

        double expectedFpr = expFPR.get(threshold);
        assertEquals(expectedFpr, rocCurve.getFalsePositiveRate(step), 1e-5);

        double expectedTpr = expTPR.get(threshold);
        assertEquals(expectedTpr, rocCurve.getTruePositiveRate(step), 1e-5);
    }

    //Classes are perfectly separated by the predictions, so AUC == 1.0
    assertEquals(1.0, roc.calculateAUC(), 1e-6);

    //After reset(), re-evaluating the same data must give the same AUC again
    roc.reset();
    roc.eval(actual, predictions);
    assertEquals(1.0, roc.calculateAUC(), 1e-6);
}
 
Example #21
Source File: MultiLayerNetwork.java, from deeplearning4j (Apache License 2.0)
/**
 * Evaluates this network with a {@link ROC} evaluation, delegating to
 * {@link #evaluateROC(DataSetIterator, int)} with {@code rocThresholdSteps = 0}.
 *
 * @param iterator data to evaluate on
 * @return ROC evaluation on the given dataset
 * @deprecated To be removed - use {@link #evaluateROC(DataSetIterator, int)} to enforce selection of appropriate ROC/threshold configuration
 */
@Deprecated
public <T extends ROC> T evaluateROC(DataSetIterator iterator){
    final int defaultThresholdSteps = 0;
    return this.<T>evaluateROC(iterator, defaultThresholdSteps);
}
 
Example #22
Source File: ROCTest.java, from deeplearning4j (Apache License 2.0)
@Test
public void testRoc() {
    //Previous tests allowed for a perfect classifier with right threshold...
    //Here: 5 examples, 2 labelled class 1 (positive) and 3 labelled class 0 (negative); no single
    //threshold separates them perfectly.

    INDArray labels = Nd4j.create(new double[][] {{0, 1}, {0, 1}, {1, 0}, {1, 0}, {1, 0}});

    INDArray prediction = Nd4j.create(new double[][] {{0.199, 0.801}, {0.499, 0.501}, {0.399, 0.601},
                    {0.799, 0.201}, {0.899, 0.101}});

    Map<Double, Double> expTPR = new HashMap<>();
    double totalPositives = 2.0;
    expTPR.put(0.0, 2.0 / totalPositives); //All predicted class 1 -> 2 true positives / 2 total positives
    expTPR.put(0.1, 2.0 / totalPositives);
    expTPR.put(0.2, 2.0 / totalPositives);
    expTPR.put(0.3, 2.0 / totalPositives);
    expTPR.put(0.4, 2.0 / totalPositives);
    expTPR.put(0.5, 2.0 / totalPositives);
    expTPR.put(0.6, 1.0 / totalPositives); //At threshold of 0.6, only 1 of 2 positives are predicted positive
    expTPR.put(0.7, 1.0 / totalPositives);
    expTPR.put(0.8, 1.0 / totalPositives);
    expTPR.put(0.9, 0.0 / totalPositives); //At threshold of 0.9, 0 of 2 positives are predicted positive
    expTPR.put(1.0, 0.0 / totalPositives);

    Map<Double, Double> expFPR = new HashMap<>();
    double totalNegatives = 3.0;
    expFPR.put(0.0, 3.0 / totalNegatives); //All predicted class 1 -> 3 false positives / 3 total negatives
    expFPR.put(0.1, 3.0 / totalNegatives);
    expFPR.put(0.2, 2.0 / totalNegatives); //At threshold of 0.2: 1 true negative, 2 false positives
    expFPR.put(0.3, 1.0 / totalNegatives); //At threshold of 0.3: 2 true negative, 1 false positive
    expFPR.put(0.4, 1.0 / totalNegatives);
    expFPR.put(0.5, 1.0 / totalNegatives);
    expFPR.put(0.6, 1.0 / totalNegatives);
    expFPR.put(0.7, 0.0 / totalNegatives); //At threshold of 0.7: 3 true negatives, 0 false positives
    expFPR.put(0.8, 0.0 / totalNegatives);
    expFPR.put(0.9, 0.0 / totalNegatives);
    expFPR.put(1.0, 0.0 / totalNegatives);

    //Expected confusion-matrix counts at thresholds 0.0, 0.1, ..., 1.0
    int[] expTPs = new int[] {2, 2, 2, 2, 2, 2, 1, 1, 1, 0, 0};
    int[] expFPs = new int[] {3, 3, 2, 1, 1, 1, 1, 0, 0, 0, 0};
    int[] expFNs = new int[11];
    int[] expTNs = new int[11];
    int totalExamples = (int) (totalPositives + totalNegatives);
    for (int i = 0; i < 11; i++) {
        expFNs[i] = (int) totalPositives - expTPs[i];
        expTNs[i] = totalExamples - expTPs[i] - expFPs[i] - expFNs[i];
    }

    ROC roc = new ROC(10);
    roc.eval(labels, prediction);

    RocCurve rocCurve = roc.getRocCurve();

    //10 threshold steps -> 11 points (0.0 to 1.0 inclusive)
    assertEquals(11, rocCurve.getThreshold().length);
    assertEquals(11, rocCurve.getFpr().length);
    assertEquals(11, rocCurve.getTpr().length);

    for (int i = 0; i < 11; i++) {
        double expThreshold = i / 10.0;
        assertEquals(expThreshold, rocCurve.getThreshold(i), 1e-5);

        double efpr = expFPR.get(expThreshold);
        double afpr = rocCurve.getFalsePositiveRate(i);
        assertEquals(efpr, afpr, 1e-5);

        double etpr = expTPR.get(expThreshold);
        double atpr = rocCurve.getTruePositiveRate(i);
        assertEquals(etpr, atpr, 1e-5);
    }

    //AUC: expected values are based on plotting the ROC curve and manually calculating the area
    double expAUC = 0.5 * 1.0 / 3.0 + (1 - 1 / 3.0) * 1.0;
    double actAUC = roc.calculateAUC();

    assertEquals(expAUC, actAUC, 1e-6);

    //Check all four confusion-matrix cells at every threshold
    PrecisionRecallCurve prc = roc.getPrecisionRecallCurve();
    for (int i = 0; i < 11; i++) {
        PrecisionRecallCurve.Confusion c = prc.getConfusionMatrixAtThreshold(i * 0.1);
        assertEquals(expTPs[i], c.getTpCount());
        assertEquals(expFPs[i], c.getFpCount());
        assertEquals(expFNs[i], c.getFnCount());
        assertEquals(expTNs[i], c.getTnCount());
    }
}
 
Example #23
Source File: LstmTimeSeriesExample.java    From Java-Deep-Learning-Cookbook with MIT License 4 votes vote down vote up
public static void main(String[] args) throws IOException, InterruptedException {
    // Refuse to run while the placeholder paths are still in place. startsWith is used so that a
    // placeholder with or without its closing '}' is detected.
    if(FEATURE_DIR.startsWith("{PATH-TO-PHYSIONET-FEATURES") || LABEL_DIR.startsWith("{PATH-TO-PHYSIONET-LABELS")){
        String message = "Please provide proper directory path in place of: PATH-TO-PHYSIONET-FEATURES && PATH-TO-PHYSIONET-LABELS";
        System.out.println(message);
        throw new FileNotFoundException(message);
    }

    // Files 0..3199 are used for training, 3200..3999 for testing. The feature reader skips 1 line
    // (presumably a CSV header). Iterators: minibatch size 100, 2 label classes, classification,
    // labels aligned to the end of each sequence (ALIGN_END).
    SequenceRecordReader trainFeaturesReader = new CSVSequenceRecordReader(1, ",");
    trainFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR+"/%d.csv",0,3199));
    SequenceRecordReader trainLabelsReader = new CSVSequenceRecordReader();
    trainLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR+"/%d.csv",0,3199));
    DataSetIterator trainDataSetIterator = new SequenceRecordReaderDataSetIterator(trainFeaturesReader,trainLabelsReader,100,2,false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    SequenceRecordReader testFeaturesReader = new CSVSequenceRecordReader(1, ",");
    testFeaturesReader.initialize(new NumberedFileInputSplit(FEATURE_DIR+"/%d.csv",3200,3999));
    SequenceRecordReader testLabelsReader = new CSVSequenceRecordReader();
    testLabelsReader.initialize(new NumberedFileInputSplit(LABEL_DIR+"/%d.csv",3200,3999));
    DataSetIterator testDataSetIterator = new SequenceRecordReaderDataSetIterator(testFeaturesReader,testLabelsReader,100,2,false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    // Single LSTM layer (86 inputs -> 200 units) feeding a 2-class softmax RNN output layer.
    ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()
                                                    .seed(RANDOM_SEED)
                                                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                                                    .weightInit(WeightInit.XAVIER)
                                                    .updater(new Adam())
                                                    .dropOut(0.9)
                                                    .graphBuilder()
                                                    .addInputs("trainFeatures")
                                                    .setOutputs("predictMortality")
                                                    .addLayer("L1", new LSTM.Builder()
                                                                                    .nIn(86)
                                                                                    .nOut(200)
                                                                                    .forgetGateBiasInit(1)
                                                                                    .activation(Activation.TANH)
                                                                                    .build(),"trainFeatures")
                                                    .addLayer("predictMortality", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                                                                        .activation(Activation.SOFTMAX)
                                                                                        .nIn(200).nOut(2).build(),"L1")
                                                    .build();

    ComputationGraph model = new ComputationGraph(configuration);

    // Number of training epochs: 1
    for(int i=0;i<1;i++){
       model.fit(trainDataSetIterator);
       trainDataSetIterator.reset();
    }

    ROC evaluation = new ROC(100);
    while (testDataSetIterator.hasNext()) {
        DataSet batch = testDataSetIterator.next();
        // ALIGN_END pads shorter sequences. Pass the feature mask (when present) so padding does not
        // drive the LSTM, and the label mask so padded time steps (all-zero labels) are excluded from
        // the ROC statistics instead of being counted as negatives.
        INDArray featuresMask = batch.getFeaturesMaskArray();
        INDArray[] output = featuresMask == null
                        ? model.output(batch.getFeatures())
                        : model.output(false, new INDArray[]{batch.getFeatures()}, new INDArray[]{featuresMask});
        evaluation.evalTimeSeries(batch.getLabels(), output[0], batch.getLabelsMaskArray());
    }

    System.out.println(evaluation.calculateAUC());
    System.out.println(evaluation.stats());
}
 
Example #24
Source File: ROCTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test
public void testRocTimeSeriesNoMasking() {
    //Same data as the first test, additionally laid out as 2 time series of 5 steps each;
    //ROC on the 3d time series data must match ROC on the flat 2d data.

    //2 outputs per example - probability distribution over classes (softmax).
    //The 0.001 offsets avoid numerical/rounding issues (float vs. double, etc)
    INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001},
                    {0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
                    {0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});

    INDArray actual2d = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1},
                    {0, 1}, {0, 1}});

    //3d layout is [example, class, time]: rows [5*ts, 5*ts + 5) of the 2d data become time series ts
    int seriesLength = 5;
    INDArray predictions3d = Nd4j.create(2, 2, seriesLength);
    INDArray labels3d = Nd4j.create(2, 2, seriesLength);
    for (int ts = 0; ts < 2; ts++) {
        int start = ts * seriesLength;
        int end = start + seriesLength;

        INDArray predView = predictions3d.get(NDArrayIndex.point(ts), NDArrayIndex.all(), NDArrayIndex.all())
                        .transpose();
        assertArrayEquals(new long[] {5, 2}, predView.shape());
        predView.assign(predictions2d.get(NDArrayIndex.interval(start, end), NDArrayIndex.all()));

        INDArray labelView = labels3d.get(NDArrayIndex.point(ts), NDArrayIndex.all(), NDArrayIndex.all())
                        .transpose();
        assertArrayEquals(new long[] {5, 2}, labelView.shape());
        labelView.assign(actual2d.get(NDArrayIndex.interval(start, end), NDArrayIndex.all()));
    }

    //0 threshold steps: exact ROC calculation
    for (int steps : new int[] {10, 0}) {
        ROC expected = new ROC(steps);
        expected.eval(actual2d, predictions2d);

        ROC actual = new ROC(steps);
        actual.evalTimeSeries(labels3d, predictions3d);

        assertEquals(expected.calculateAUC(), actual.calculateAUC(), 1e-6);
        assertEquals(expected.calculateAUCPR(), actual.calculateAUCPR(), 1e-6);
        assertEquals(expected.getRocCurve(), actual.getRocCurve());
    }
}
 
Example #25
Source File: ROCTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Test
public void testRocTimeSeriesMasking() {
    //Checks that masked time series ROC evaluation matches 2d evaluation over only the unmasked steps.

    //2 outputs here - probability distribution over classes (softmax)
    INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
                    {0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
                    {0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});

    INDArray actual2d = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1},
                    {0, 1}, {0, 1}});


    //Create time series data... first time series: length 4 (last 2 steps left as zeros and masked out).
    //Second time series: length 6
    INDArray predictions3d = Nd4j.create(2, 2, 6);
    INDArray tad = predictions3d.tensorAlongDimension(0, 1, 2).transpose();
    tad.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all())
                    .assign(predictions2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));

    tad = predictions3d.tensorAlongDimension(1, 1, 2).transpose();
    tad.assign(predictions2d.get(NDArrayIndex.interval(4, 10), NDArrayIndex.all()));


    INDArray labels3d = Nd4j.create(2, 2, 6);
    tad = labels3d.tensorAlongDimension(0, 1, 2).transpose();
    tad.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all())
                    .assign(actual2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));

    tad = labels3d.tensorAlongDimension(1, 1, 2).transpose();
    tad.assign(actual2d.get(NDArrayIndex.interval(4, 10), NDArrayIndex.all()));


    //Mask shape [example, time]: 1 = include the step, 0 = exclude it
    INDArray mask = Nd4j.zeros(2, 6);
    mask.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 4)).assign(1);
    mask.get(NDArrayIndex.point(1), NDArrayIndex.all()).assign(1);


    for (int steps : new int[] {20, 0}) { //0 steps: exact
        ROC rocExp = new ROC(steps);
        rocExp.eval(actual2d, predictions2d);

        ROC rocAct = new ROC(steps);
        rocAct.evalTimeSeries(labels3d, predictions3d, mask);

        assertEquals(rocExp.calculateAUC(), rocAct.calculateAUC(), 1e-6);
        //Masking must also be respected by the precision-recall curve, as checked in the unmasked test
        assertEquals(rocExp.calculateAUCPR(), rocAct.calculateAUCPR(), 1e-6);

        assertEquals(rocExp.getRocCurve(), rocAct.getRocCurve());
    }
}
 
Example #26
Source File: ROCTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Checks that ROCBinary (and, for a single channel, ROC) evaluated directly on 4d segmentation-style
 * arrays gives the same result as evaluating each spatial position's channel vector as a 2d row, and
 * that setAxis selects the channel dimension when the arrays are stored in other layouts.
 */
@Test
public void testSegmentationBinary(){
    for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
        Nd4j.getRandom().setSeed(12345);
        int mb = 3;
        int h = 3;
        int w = 2;

        //NCHW
        //Random 0/1 labels and random predictions of identical shape
        INDArray labels = Nd4j.create(DataType.FLOAT, mb, c, h, w);
        Nd4j.exec(new BernoulliDistribution(labels, 0.5));

        INDArray predictions = Nd4j.rand(DataType.FLOAT, mb, c, h, w);

        //e2d: reference built row-by-row from 2d slices; e4d: evaluated directly on the 4d arrays
        ROCBinary e2d = new ROCBinary();
        ROCBinary e4d = new ROCBinary();

        //r2d: single-class ROC reference, only populated when c == 1
        ROC r2d = new ROC();
        e4d.eval(labels, predictions);

        //Feed every (example, y, x) position's channel vector as a [1, c] row
        for (int i = 0; i < mb; i++) {
            for (int j = 0; j < h; j++) {
                for (int k = 0; k < w; k++) {
                    INDArray rowLabel = labels.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(j), NDArrayIndex.point(k));
                    INDArray rowPredictions = predictions.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(j), NDArrayIndex.point(k));
                    rowLabel = rowLabel.reshape(1, rowLabel.length());
                    rowPredictions = rowPredictions.reshape(1, rowLabel.length());

                    e2d.eval(rowLabel, rowPredictions);
                    if(c == 1){
                        r2d.eval(rowLabel, rowPredictions);
                    }
                }
            }
        }

        assertEquals(e2d, e4d);

        if(c == 1){
            ROC r4d = new ROC();
            r4d.eval(labels, predictions);
            assertEquals(r2d, r4d);
        }


        //NHWC, etc
        //Re-evaluate with the data permuted into other layouts. In each case the loop index i equals
        //the position of the channel dimension, so it is passed straight to setAxis.
        INDArray lOrig = labels;
        INDArray fOrig = predictions;
        for (int i = 0; i < 4; i++) {
            switch (i) {
                case 0:
                    //CNHW - Never really used
                    labels = lOrig.permute(1, 0, 2, 3).dup();
                    predictions = fOrig.permute(1, 0, 2, 3).dup();
                    break;
                case 1:
                    //NCHW
                    labels = lOrig;
                    predictions = fOrig;
                    break;
                case 2:
                    //NHCW - Never really used...
                    labels = lOrig.permute(0, 2, 1, 3).dup();
                    predictions = fOrig.permute(0, 2, 1, 3).dup();
                    break;
                case 3:
                    //NHWC
                    labels = lOrig.permute(0, 2, 3, 1).dup();
                    predictions = fOrig.permute(0, 2, 3, 1).dup();
                    break;
                default:
                    throw new RuntimeException();
            }

            ROCBinary e = new ROCBinary();
            e.setAxis(i);

            e.eval(labels, predictions);
            assertEquals(e2d, e);

            if(c == 1){
                ROC r2 = new ROC();
                r2.setAxis(i);
                r2.eval(labels, predictions);
                assertEquals(r2d, r2);
            }
        }
    }
}
 
Example #27
Source File: ROCTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Checks that merging two ROC instances, each fed alternating minibatches, reproduces the AUC and AUCPR of
 * a single ROC fed every minibatch - even when AUC/AUCPR were already calculated on the parts before merging.
 */
@Test
public void testRocMerge(){
    Nd4j.getRandom().setSeed(12345);

    ROC all = new ROC();
    ROC evenBatches = new ROC();
    ROC oddBatches = new ROC();

    int nOut = 2;

    Random rng = new Random(12345);
    int batch = 0;
    while (batch < 10) {
        //3 examples, one-hot labels over nOut classes
        INDArray labels = Nd4j.zeros(3, nOut);
        for (int row = 0; row < 3; row++) {
            labels.putScalar(row, rng.nextInt(nOut), 1.0);
        }
        //Random outputs normalized so each row sums to 1
        INDArray out = Nd4j.rand(3, nOut);
        out.diviColumnVector(out.sum(1));

        all.eval(labels, out);
        ROC part = (batch % 2 == 0) ? evenBatches : oddBatches;
        part.eval(labels, out);
        batch++;
    }

    //Deliberately compute metrics on the parts before merging; results are intentionally discarded
    evenBatches.calculateAUC();
    evenBatches.calculateAUCPR();
    oddBatches.calculateAUC();
    oddBatches.calculateAUCPR();

    evenBatches.merge(oddBatches);

    double expectedAuc = all.calculateAUC();
    double expectedAuprc = all.calculateAUCPR();

    double mergedAuc = evenBatches.calculateAUC();
    double mergedAuprc = evenBatches.calculateAUCPR();

    assertEquals(expectedAuc, mergedAuc, 1e-6);
    assertEquals(expectedAuprc, mergedAuprc, 1e-6);
}
 
Example #28
Source File: ROCTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Checks that ROC instances fed round-robin subsets of the data, once merged, match a single ROC that saw
 * everything - for both exact (0 steps) and thresholded (20 steps) modes.
 */
@Test
public void testROCMerging() {
    int nArrays = 10;
    int minibatch = 64;
    int nROCs = 3;

    for (int steps : new int[] {0, 20}) {
        Nd4j.getRandom().setSeed(12345);
        Random rng = new Random(12345);

        List<ROC> parts = new ArrayList<>();
        for (int p = 0; p < nROCs; p++) {
            parts.add(new ROC(steps));
        }

        ROC single = new ROC(steps);
        for (int a = 0; a < nArrays; a++) {
            //Random 2-class probabilities (rows sum to 1) and random one-hot labels
            INDArray probs = Nd4j.rand(minibatch, 2);
            probs.diviColumnVector(probs.sum(1));

            INDArray oneHot = Nd4j.zeros(minibatch, 2);
            for (int row = 0; row < minibatch; row++) {
                oneHot.putScalar(row, rng.nextInt(2), 1.0);
            }

            single.eval(oneHot, probs);
            parts.get(a % nROCs).eval(oneHot, probs);
        }

        //Fold every other part into the first one, in list order
        ROC merged = parts.get(0);
        for (ROC other : parts.subList(1, nROCs)) {
            merged.merge(other);
        }

        double singleAUC = single.calculateAUC();
        assertTrue(singleAUC >= 0.0 && singleAUC <= 1.0);
        assertEquals(singleAUC, merged.calculateAUC(), 1e-6);

        assertEquals(single.getRocCurve(), merged.getRocCurve());
    }
}
 
Example #29
Source File: ROCTest.java    From deeplearning4j with Apache License 2.0 3 votes vote down vote up
@Test
public void rocExactEdgeCaseReallocation() {

    //Exact mode (0 threshold steps) with an allocation block size of 50, then evaluate a 100-length
    //array in one call - more examples than one block holds, so the exact-mode storage must be reallocated

    ROC roc = new ROC(0, true, 50);

    //eval(labels, predictions): labels are binary (all positive here), predictions are random probabilities
    roc.eval(Nd4j.ones(100, 1), Nd4j.rand(100, 1));

}
 
Example #30
Source File: EvaluationTools.java    From deeplearning4j with Apache License 2.0 2 votes vote down vote up
/**
 * Given a {@link ROC} chart, export the ROC chart and precision vs. recall charts to a stand-alone HTML file.
 * The file is written as UTF-8, independent of the platform default charset.
 *
 * @param roc  ROC to export
 * @param file File to export to
 * @throws IOException if the file cannot be written
 */
public static void exportRocChartsToHtmlFile(ROC roc, File file) throws IOException {
    String rocAsHtml = rocChartToHtml(roc);
    //Explicit charset: the two-argument writeStringToFile overload is deprecated and uses the platform default
    FileUtils.writeStringToFile(file, rocAsHtml, java.nio.charset.StandardCharsets.UTF_8);
}