org.apache.spark.mllib.evaluation.MulticlassMetrics Java Examples
The following examples show how to use
org.apache.spark.mllib.evaluation.MulticlassMetrics.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: JavaLogisticRegressionWithLBFGSExample.java From SparkDemo with MIT License | 5 votes |
public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("JavaLogisticRegressionWithLBFGSExample"); SparkContext sc = new SparkContext(conf); // $example on$ String path = "data/mllib/sample_libsvm_data.txt"; JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); // Split initial RDD into two... [60% training data, 40% testing data]. JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); JavaRDD<LabeledPoint> training = splits[0].cache(); JavaRDD<LabeledPoint> test = splits[1]; // Run training algorithm to build the model. final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() .setNumClasses(10) .run(training.rdd()); // Compute raw scores on the test set. JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map( new Function<LabeledPoint, Tuple2<Object, Object>>() { public Tuple2<Object, Object> call(LabeledPoint p) { Double prediction = model.predict(p.features()); return new Tuple2<Object, Object>(prediction, p.label()); } } ); // Get evaluation metrics. MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); double accuracy = metrics.accuracy(); System.out.println("Accuracy = " + accuracy); // Save and load model model.save(sc, "target/tmp/javaLogisticRegressionWithLBFGSModel"); LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "target/tmp/javaLogisticRegressionWithLBFGSModel"); // $example off$ sc.stop(); }
Example #2
Source File: MulticlassClassificationEvaluatorByClass.java From ambiverse-nlu with Apache License 2.0 | 4 votes |
/**
 * Evaluates the metric selected by {@code getMetricName()} on the prediction and label
 * columns of {@code dataset}.
 *
 * <p>Both columns are first checked to be of {@code DoubleType}. The per-class metrics
 * ("f1", "precision", "recall") are computed for the label equal to
 * {@code evaluationClass}; if no label matches, the first label is used. A metric name
 * that is not one of the handled cases yields {@code 0.0}.
 *
 * @param dataset data containing the prediction and label columns
 * @return the value of the selected metric
 */
@Override
public double evaluate(DataFrame dataset) {
  StructType schema = dataset.schema();
  SchemaUtils.checkColumnType(schema, this.getPredictionCol(), DataTypes.DoubleType, "");
  SchemaUtils.checkColumnType(schema, this.getLabelCol(), DataTypes.DoubleType, "");

  MulticlassMetrics metrics =
      new MulticlassMetrics(dataset.select(this.getPredictionCol(), this.getLabelCol()));

  // Find the position of the class under evaluation; stays at 0 when nothing matches.
  double[] labels = metrics.labels();
  int classIndex = 0;
  int position = 0;
  while (position < labels.length) {
    if (labels[position] == evaluationClass) {
      classIndex = position;
    }
    position++;
  }

  switch (getMetricName()) {
    case "f1":
      return metrics.fMeasure(labels[classIndex]);
    case "precision":
      return metrics.precision(labels[classIndex]);
    case "recall":
      return metrics.recall(labels[classIndex]);
    case "weightedPrecision":
      return metrics.weightedPrecision();
    case "weightedRecall":
      return metrics.weightedRecall();
    default:
      // Unrecognised metric names fall through to zero.
      return 0d;
  }
}
Example #3
Source File: TrainingSparkRunner.java From ambiverse-nlu with Apache License 2.0 | 4 votes |
private void multiClassEvaluation(DataFrame predictions, String output, TrainingSettings trainingSettings) throws IOException { FileSystem fs = FileSystem.get(new Configuration()); Path evalPath = new Path(output+"multiclass_evaluation_"+trainingSettings.getClassificationMethod()+".txt"); fs.delete(evalPath, true); FSDataOutputStream fsdos = fs.create(evalPath); MulticlassMetrics metrics = new MulticlassMetrics(predictions .select("prediction", "label")); // Confusion matrix Matrix confusion = metrics.confusionMatrix(); IOUtils.write("\nConfusion matrix: \n" + confusion, fsdos); // Overall statistics IOUtils.write("\nPrecision = " + metrics.precision(), fsdos); IOUtils.write("\nRecall = " + metrics.recall(), fsdos); IOUtils.write("\nF1 Score = " + metrics.fMeasure(), fsdos); IOUtils.write("\n\n", fsdos); // Stats by labels for (int i = 0; i < metrics.labels().length; i++) { IOUtils.write(String.format("Class %f precision = %f\n", metrics.labels()[i],metrics.precision(metrics.labels()[i])), fsdos); IOUtils.write(String.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])), fsdos); IOUtils.write(String.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure(metrics.labels()[i])), fsdos); System.out.format("Class %f precision = %f\n", metrics.labels()[i],metrics.precision(metrics.labels()[i])); System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure(metrics.labels()[i])); } //Weighted stats IOUtils.write("\nWeighted precision = "+metrics.weightedPrecision(), fsdos); IOUtils.write("\nWeighted recall = "+metrics.weightedRecall(), fsdos); IOUtils.write("\nWeighted F1 score ="+metrics.weightedFMeasure(), fsdos); IOUtils.write("\nWeighted false positive rate = " +metrics.weightedFalsePositiveRate(), fsdos); fsdos.flush(); IOUtils.closeQuietly(fsdos); }
Example #4
Source File: SparkMultiClassClassifier.java From mmtf-spark with Apache License 2.0 | 4 votes |
/** * Dataset must at least contain the following two columns: * label: the class labels * features: feature vector * @param data * @return map with metrics */ public Map<String,String> fit(Dataset<Row> data) { int classCount = (int)data.select(label).distinct().count(); StringIndexerModel labelIndexer = new StringIndexer() .setInputCol(label) .setOutputCol("indexedLabel") .fit(data); // Split the data into training and test sets (30% held out for testing) Dataset<Row>[] splits = data.randomSplit(new double[] {1.0-testFraction, testFraction}, seed); Dataset<Row> trainingData = splits[0]; Dataset<Row> testData = splits[1]; String[] labels = labelIndexer.labels(); System.out.println(); System.out.println("Class\tTrain\tTest"); for (String l: labels) { System.out.println(l + "\t" + trainingData.select(label).filter(label + " = '" + l + "'").count() + "\t" + testData.select(label).filter(label + " = '" + l + "'").count()); } // Set input columns predictor .setLabelCol("indexedLabel") .setFeaturesCol("features"); // Convert indexed labels back to original labels. IndexToString labelConverter = new IndexToString() .setInputCol("prediction") .setOutputCol("predictedLabel") .setLabels(labelIndexer.labels()); // Chain indexers and forest in a Pipeline Pipeline pipeline = new Pipeline() .setStages(new PipelineStage[] {labelIndexer, predictor, labelConverter}); // Train model. This also runs the indexers. PipelineModel model = pipeline.fit(trainingData); // Make predictions. 
Dataset<Row> predictions = model.transform(testData).cache(); // Display some sample predictions System.out.println(); System.out.println("Sample predictions: " + predictor.getClass().getSimpleName()); predictions.sample(false, 0.1, seed).show(25); predictions = predictions.withColumnRenamed(label, "stringLabel"); predictions = predictions.withColumnRenamed("indexedLabel", label); // collect metrics Dataset<Row> pred = predictions.select("prediction",label); Map<String,String> metrics = new LinkedHashMap<>(); metrics.put("Method", predictor.getClass().getSimpleName()); if (classCount == 2) { BinaryClassificationMetrics b = new BinaryClassificationMetrics(pred); metrics.put("AUC", Float.toString((float)b.areaUnderROC())); } MulticlassMetrics m = new MulticlassMetrics(pred); metrics.put("F", Float.toString((float)m.weightedFMeasure())); metrics.put("Accuracy", Float.toString((float)m.accuracy())); metrics.put("Precision", Float.toString((float)m.weightedPrecision())); metrics.put("Recall", Float.toString((float)m.weightedRecall())); metrics.put("False Positive Rate", Float.toString((float)m.weightedFalsePositiveRate())); metrics.put("True Positive Rate", Float.toString((float)m.weightedTruePositiveRate())); metrics.put("", "\nConfusion Matrix\n" + Arrays.toString(labels) +"\n" + m.confusionMatrix().toString()); return metrics; }
Example #5
Source File: JavaMulticlassClassificationMetricsExample.java From SparkDemo with MIT License | 4 votes |
public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("Multi class Classification Metrics Example"); SparkContext sc = new SparkContext(conf); // $example on$ String path = "data/mllib/sample_multiclass_classification_data.txt"; JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); // Split initial RDD into two... [60% training data, 40% testing data]. JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L); JavaRDD<LabeledPoint> training = splits[0].cache(); JavaRDD<LabeledPoint> test = splits[1]; // Run training algorithm to build the model. final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() .setNumClasses(3) .run(training.rdd()); // Compute raw scores on the test set. JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map( new Function<LabeledPoint, Tuple2<Object, Object>>() { public Tuple2<Object, Object> call(LabeledPoint p) { Double prediction = model.predict(p.features()); return new Tuple2<Object, Object>(prediction, p.label()); } } ); // Get evaluation metrics. 
MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); // Confusion matrix Matrix confusion = metrics.confusionMatrix(); System.out.println("Confusion matrix: \n" + confusion); // Overall statistics System.out.println("Accuracy = " + metrics.accuracy()); // Stats by labels for (int i = 0; i < metrics.labels().length; i++) { System.out.format("Class %f precision = %f\n", metrics.labels()[i],metrics.precision( metrics.labels()[i])); System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall( metrics.labels()[i])); System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure( metrics.labels()[i])); } //Weighted stats System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); // Save and load model model.save(sc, "target/tmp/LogisticRegressionModel"); LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "target/tmp/LogisticRegressionModel"); // $example off$ }