Java Code Examples for org.apache.spark.mllib.classification.LogisticRegressionModel#load()
The following examples show how to use
org.apache.spark.mllib.classification.LogisticRegressionModel#load().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: JavaLogisticRegressionWithLBFGSExample.java From SparkDemo with MIT License | 5 votes |
public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("JavaLogisticRegressionWithLBFGSExample"); SparkContext sc = new SparkContext(conf); // $example on$ String path = "data/mllib/sample_libsvm_data.txt"; JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); // Split initial RDD into two... [60% training data, 40% testing data]. JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); JavaRDD<LabeledPoint> training = splits[0].cache(); JavaRDD<LabeledPoint> test = splits[1]; // Run training algorithm to build the model. final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() .setNumClasses(10) .run(training.rdd()); // Compute raw scores on the test set. JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map( new Function<LabeledPoint, Tuple2<Object, Object>>() { public Tuple2<Object, Object> call(LabeledPoint p) { Double prediction = model.predict(p.features()); return new Tuple2<Object, Object>(prediction, p.label()); } } ); // Get evaluation metrics. MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); double accuracy = metrics.accuracy(); System.out.println("Accuracy = " + accuracy); // Save and load model model.save(sc, "target/tmp/javaLogisticRegressionWithLBFGSModel"); LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "target/tmp/javaLogisticRegressionWithLBFGSModel"); // $example off$ sc.stop(); }
Example 2
Source File: JavaMulticlassClassificationMetricsExample.java From SparkDemo with MIT License | 4 votes |
public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("Multi class Classification Metrics Example"); SparkContext sc = new SparkContext(conf); // $example on$ String path = "data/mllib/sample_multiclass_classification_data.txt"; JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); // Split initial RDD into two... [60% training data, 40% testing data]. JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L); JavaRDD<LabeledPoint> training = splits[0].cache(); JavaRDD<LabeledPoint> test = splits[1]; // Run training algorithm to build the model. final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() .setNumClasses(3) .run(training.rdd()); // Compute raw scores on the test set. JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map( new Function<LabeledPoint, Tuple2<Object, Object>>() { public Tuple2<Object, Object> call(LabeledPoint p) { Double prediction = model.predict(p.features()); return new Tuple2<Object, Object>(prediction, p.label()); } } ); // Get evaluation metrics. 
MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); // Confusion matrix Matrix confusion = metrics.confusionMatrix(); System.out.println("Confusion matrix: \n" + confusion); // Overall statistics System.out.println("Accuracy = " + metrics.accuracy()); // Stats by labels for (int i = 0; i < metrics.labels().length; i++) { System.out.format("Class %f precision = %f\n", metrics.labels()[i],metrics.precision( metrics.labels()[i])); System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall( metrics.labels()[i])); System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure( metrics.labels()[i])); } //Weighted stats System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); // Save and load model model.save(sc, "target/tmp/LogisticRegressionModel"); LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "target/tmp/LogisticRegressionModel"); // $example off$ }
Example 3
Source File: JavaBinaryClassificationMetricsExample.java From SparkDemo with MIT License | 4 votes |
public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("Java Binary Classification Metrics Example"); SparkContext sc = new SparkContext(conf); // $example on$ String path = "data/mllib/sample_binary_classification_data.txt"; JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); // Split initial RDD into two... [60% training data, 40% testing data]. JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L); JavaRDD<LabeledPoint> training = splits[0].cache(); JavaRDD<LabeledPoint> test = splits[1]; // Run training algorithm to build the model. final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() .setNumClasses(2) .run(training.rdd()); // Clear the prediction threshold so the model will return probabilities model.clearThreshold(); // Compute raw scores on the test set. JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map( new Function<LabeledPoint, Tuple2<Object, Object>>() { @Override public Tuple2<Object, Object> call(LabeledPoint p) { Double prediction = model.predict(p.features()); return new Tuple2<Object, Object>(prediction, p.label()); } } ); // Get evaluation metrics. 
BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); // Precision by threshold JavaRDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold().toJavaRDD(); System.out.println("Precision by threshold: " + precision.collect()); // Recall by threshold JavaRDD<Tuple2<Object, Object>> recall = metrics.recallByThreshold().toJavaRDD(); System.out.println("Recall by threshold: " + recall.collect()); // F Score by threshold JavaRDD<Tuple2<Object, Object>> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); System.out.println("F1 Score by threshold: " + f1Score.collect()); JavaRDD<Tuple2<Object, Object>> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); System.out.println("F2 Score by threshold: " + f2Score.collect()); // Precision-recall curve JavaRDD<Tuple2<Object, Object>> prc = metrics.pr().toJavaRDD(); System.out.println("Precision-recall curve: " + prc.collect()); // Thresholds JavaRDD<Double> thresholds = precision.map( new Function<Tuple2<Object, Object>, Double>() { @Override public Double call(Tuple2<Object, Object> t) { return new Double(t._1().toString()); } } ); // ROC Curve JavaRDD<Tuple2<Object, Object>> roc = metrics.roc().toJavaRDD(); System.out.println("ROC curve: " + roc.collect()); // AUPRC System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); // AUROC System.out.println("Area under ROC = " + metrics.areaUnderROC()); // Save and load model model.save(sc, "target/tmp/LogisticRegressionModel"); LogisticRegressionModel.load(sc, "target/tmp/LogisticRegressionModel"); // $example off$ }