org.apache.spark.mllib.util.DataValidators Scala Examples
The following examples show how to use org.apache.spark.mllib.util.DataValidators.
You can vote up the examples you like or vote down the ones you don't,
and follow the links above each example to reach the original project or source file.
Example 1
Source file: LogisticRegressionModel.scala, from the keystone project (Apache License 2.0) · 5 votes
package keystoneml.nodes.learning

import breeze.linalg.Vector
import org.apache.spark.mllib.classification.{LogisticRegressionModel => MLlibLRM}
import org.apache.spark.mllib.linalg.{Vector => MLlibVector}
import org.apache.spark.mllib.optimization.{SquaredL2Updater, LogisticGradient, LBFGS}
import org.apache.spark.mllib.regression.{GeneralizedLinearAlgorithm, LabeledPoint}
import org.apache.spark.mllib.util.DataValidators
import org.apache.spark.rdd.RDD
import keystoneml.utils.MLlibUtils.breezeVectorToMLlib
import keystoneml.workflow.{LabelEstimator, Transformer}

import scala.reflect.ClassTag

/**
 * Logistic regression trainer built on MLlib's [[GeneralizedLinearAlgorithm]], optimized with
 * L-BFGS and a squared-L2 updater.
 *
 * With `numClasses == 2` the default (binary) [[LogisticGradient]] is kept and a plain binary
 * MLlib model is produced. With `numClasses > 2` the gradient is switched to
 * `LogisticGradient(numClasses)` and the model is created with `numOfLinearPredictor + 1`
 * classes.
 *
 * @param numClasses       number of label classes; must be greater than 1
 * @param numFeaturesValue feature dimensionality, copied into the inherited `numFeatures`
 *                         (NOTE(review): a negative value presumably lets the base class infer
 *                         it from the data — confirm against the Spark version in use)
 * @throws IllegalArgumentException if `numClasses <= 1`
 */
private[this] class LogisticRegressionWithLBFGS(numClasses: Int, numFeaturesValue: Int)
    extends GeneralizedLinearAlgorithm[MLlibLRM] with Serializable {

  // Validate before any other setup so a bad argument fails fast with an actionable message.
  require(numClasses > 1, s"numClasses must be greater than 1 but was $numClasses")

  this.numFeatures = numFeaturesValue

  // Typed as LBFGS (not the base Optimizer) so callers can use the LBFGS-specific setters.
  override val optimizer: LBFGS = new LBFGS(new LogisticGradient, new SquaredL2Updater)

  // MLlib only runs these when data validation is enabled on the trainer.
  override protected val validators = List(multiLabelValidator)

  // k classes are modelled with k - 1 linear predictors, matching MLlib's LogisticGradient.
  numOfLinearPredictor = numClasses - 1
  if (numClasses > 2) {
    optimizer.setGradient(new LogisticGradient(numClasses))
  }

  /**
   * Label check for the training data.
   *
   * Uses multi-class label validation when there is more than one linear predictor, and binary
   * label validation otherwise. `numOfLinearPredictor` is read when the returned function is
   * invoked, not when it is created, so the later assignment above is what it sees.
   */
  private def multiLabelValidator: RDD[LabeledPoint] => Boolean = { data =>
    if (numOfLinearPredictor > 1) {
      DataValidators.multiLabelValidator(numOfLinearPredictor + 1)(data)
    } else {
      DataValidators.binaryLabelValidator(data)
    }
  }

  /**
   * Wraps the optimized weights and intercept in an MLlib logistic regression model.
   *
   * The binary case uses the two-argument constructor. The multinomial case also passes the
   * feature count and the number of classes (`numOfLinearPredictor + 1`).
   */
  override protected def createModel(weights: MLlibVector, intercept: Double): MLlibLRM = {
    if (numOfLinearPredictor == 1) {
      new MLlibLRM(weights, intercept)
    } else {
      new MLlibLRM(weights, intercept, numFeatures, numOfLinearPredictor + 1)
    }
  }
}

  // NOTE(review): the declaration enclosing this method is not part of this excerpt.
  // It must supply `T`, `numClasses`, `numFeatures`, `numIters` and `regParam`, and it is
  // presumably a LabelEstimator. The method and the closing brace below are kept as scraped.
  override def fit(in: RDD[T], labels: RDD[Int]): LogisticRegressionModel[T] = {
    // RDD.zip requires both RDDs to have the same partitioning and per-partition element counts.
    val labeledPoints = labels.zip(in).map(x => LabeledPoint(x._1, breezeVectorToMLlib(x._2)))
    val trainer = new LogisticRegressionWithLBFGS(numClasses, numFeatures)
    // Validation is disabled, so the trainer's label validators are not run for this call.
    trainer.setValidateData(false).optimizer.setNumIterations(numIters).setRegParam(regParam)
    val model = trainer.run(labeledPoints)
    new LogisticRegressionModel(model)
  }
}