org.apache.spark.ml.classification.OneVsRest Scala Examples
The following examples show how to use org.apache.spark.ml.classification.OneVsRest.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
Example 1
Source File: OneVsRestParitySpec.scala From mleap with Apache License 2.0 | 5 votes |
package org.apache.spark.ml.parity.classification

import org.apache.spark.ml.{Pipeline, Transformer}
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.parity.SparkParityBase
import org.apache.spark.sql.DataFrame

/**
 * Parity spec for a [[OneVsRest]] pipeline wrapping [[LogisticRegression]].
 *
 * Provides the three members this spec overrides from [[SparkParityBase]]:
 * the input dataset, the fitted Spark pipeline, and the set of param names
 * listed in `unserializedParams`.
 */
class OneVsRestParitySpec extends SparkParityBase {

  /** Only the two columns the pipeline reads are kept from the base dataset. */
  override val dataset: DataFrame = baseDataset.select("fico_score_group_fnl", "dti")

  /**
   * Pipeline fitted on [[dataset]]:
   * string-index the FICO group, assemble the feature vector, then train
   * one-vs-rest logistic regression with the indexed group as the label.
   */
  override val sparkTransformer: Transformer = {
    // Stage 1: categorical FICO group -> numeric index.
    val ficoIndexer = new StringIndexer()
      .setInputCol("fico_score_group_fnl")
      .setOutputCol("fico_index")

    // Stage 2: pack the index and the dti column into a single feature vector.
    // Note the label column ("fico_index") is also one of the features.
    val featureAssembler = new VectorAssembler()
      .setInputCols(Array("fico_index", "dti"))
      .setOutputCol("features")

    // Stage 3: the one-vs-rest classifier under test.
    val oneVsRest = new OneVsRest()
      .setClassifier(new LogisticRegression())
      .setLabelCol("fico_index")
      .setFeaturesCol("features")
      .setPredictionCol("prediction")

    new Pipeline()
      .setStages(Array(ficoIndexer, featureAssembler, oneVsRest))
      .fit(dataset)
  }

  // NOTE(review): presumably these param names are skipped when the base spec
  // compares serialized and original params — confirm against SparkParityBase.
  override val unserializedParams = Set("stringOrderType", "classifier", "labelCol")
}
Example 2
Source File: Iris.scala From spark-gp with Apache License 2.0 | 5 votes |
package org.apache.spark.ml.classification.examples

import org.apache.spark.ml.classification.{GaussianProcessClassifier, OneVsRest}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.SparkSession

/**
 * Cross-validates a one-vs-rest [[GaussianProcessClassifier]] on the iris CSV
 * and prints the resulting average accuracy metrics.
 *
 * Uses an explicit `main` instead of `extends App`: Spark's documentation warns
 * that subclasses of `scala.App` may not work correctly, because values defined
 * in the object body (here the label map used inside the RDD closure) are
 * initialized lazily through `DelayedInit`.
 */
object Iris {

  /**
   * Entry point.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val name = "Iris"
    val spark = SparkSession.builder().appName(name).master("local[4]").getOrCreate()

    try {
      import spark.sqlContext.implicits._

      // Class name -> numeric label. A class name missing from this map makes
      // the lookup below throw NoSuchElementException.
      val name2indx = Map("Iris-versicolor" -> 0, "Iris-setosa" -> 1, "Iris-virginica" -> 2)

      // The CSV is read without a header, so columns are addressed by the
      // default names _c0.._c4: four numeric features followed by the class name.
      val dataset = spark.read.format("csv").load("data/iris.csv").rdd.map { row =>
        val features = Vectors.dense(
          Array("_c0", "_c1", "_c2", "_c3").map(col => row.getAs[String](col).toDouble))
        val label = name2indx(row.getAs[String]("_c4"))
        LabeledPoint(label, features)
      }.toDF

      val gp = new GaussianProcessClassifier().setDatasetSizeForExpert(20).setActiveSetSize(30)
      val ovr = new OneVsRest().setClassifier(gp)

      // No parameters are added to the grid; cross-validation is used here only
      // to obtain a 10-fold accuracy estimate.
      val cv = new CrossValidator()
        .setEstimator(ovr)
        .setEvaluator(new MulticlassClassificationEvaluator().setMetricName("accuracy"))
        .setEstimatorParamMaps(new ParamGridBuilder().build())
        .setNumFolds(10)

      println("Accuracy: " + cv.fit(dataset).avgMetrics.toList)
    } finally {
      // Release the local Spark context even when loading or fitting fails.
      spark.stop()
    }
  }
}
Example 3
Source File: OnevsRest.scala From Apache-Spark-2x-Machine-Learning-Cookbook with MIT License | 5 votes |
package spark.ml.cookbook.chapter5

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

/**
 * Trains a one-vs-rest logistic-regression classifier on a libsvm-formatted
 * dataset, shows the predictions for a held-out split, and prints the accuracy.
 */
object OnevsRest {

  /**
   * Entry point.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    import org.apache.log4j.Logger
    import org.apache.log4j.Level

    // Silence everything below ERROR from the "org" and "akka" loggers.
    Logger.getLogger("org").setLevel(Level.ERROR)
    Logger.getLogger("akka").setLevel(Level.ERROR)

    // NOTE(review): the app name "MLP" looks carried over from another recipe;
    // it is left unchanged here.
    val spark = SparkSession
      .builder
      .master("local[*]")
      .appName("MLP")
      .config("spark.sql.warehouse.dir", ".")
      .getOrCreate()

    try {
      val data = spark.read.format("libsvm")
        .load("../data/sparkml2/chapter5/iris.scale.txt")

      data.show(false)

      // The seed is time-based, so the 80/20 split differs on every run.
      val Array(train, test) =
        data.randomSplit(Array(0.8, 0.2), seed = System.currentTimeMillis())

      // logistic regression classifier
      val lrc = new LogisticRegression()
        .setMaxIter(15)
        .setTol(1E-3)
        .setFitIntercept(true)

      val ovr = new OneVsRest().setClassifier(lrc)

      val ovrModel = ovr.fit(train)

      val predictions = ovrModel.transform(test)
      predictions.show(false)

      val eval = new MulticlassClassificationEvaluator()
        .setMetricName("accuracy")

      // Compute the accuracy on the test split once and reuse the value;
      // calling evaluate a second time would re-run the whole evaluation.
      val accuracy = eval.evaluate(predictions)
      println("Accuracy: " + accuracy)
    } finally {
      // Release the local Spark context even when a step above fails.
      spark.stop()
    }
  }
}