org.apache.spark.mllib.util.Saveable Scala Examples
The following examples show how to use org.apache.spark.mllib.util.Saveable.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
Example 1
Source file: KMeansModel.scala from the drizzle-spark project (Apache License 2.0) | 5 votes
package org.apache.spark.mllib.clustering

import scala.collection.JavaConverters._

import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession}

// NOTE(review): this excerpt omits the `class KMeansModel ... {` header and the class's other
// members (e.g. `clusterCenters` and `k`, which the code below references). Everything from
// here down to the first standalone closing brace belongs to that class body.

  /**
   * Returns the sum, over every point in `data`, of `KMeans.pointCost` evaluated against this
   * model's cluster centers (presumably the squared distance to the nearest center — confirm
   * against `KMeans.pointCost`).
   *
   * @param data points to evaluate
   * @return the summed cost
   */
  @Since("0.8.0")
  def computeCost(data: RDD[Vector]): Double = {
    val centersWithNorm = clusterCentersWithNorm
    // Ship the centers as a broadcast variable rather than capturing them in the task closure.
    val bcCentersWithNorm = data.context.broadcast(centersWithNorm)
    val cost = data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum()
    // `sum()` is an action, so the broadcast is no longer needed; release it now instead of
    // leaving the copies on executors until the driver-side object is garbage collected.
    bcCentersWithNorm.destroy()
    cost
  }

  /** Wraps each cluster center in a `VectorWithNorm`. */
  private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
    clusterCenters.map(new VectorWithNorm(_))

  /**
   * Saves this model to `path` in format version 1.0 by delegating to
   * `KMeansModel.SaveLoadV1_0.save`.
   */
  @Since("1.4.0")
  override def save(sc: SparkContext, path: String): Unit = {
    KMeansModel.SaveLoadV1_0.save(sc, this, path)
  }

  override protected def formatVersion: String = "1.0"
}

/** Loader for [[KMeansModel]] instances written by `KMeansModel.save`. */
@Since("1.4.0")
object KMeansModel extends Loader[KMeansModel] {

  /** Loads a model saved at `path` by delegating to `SaveLoadV1_0.load`. */
  @Since("1.4.0")
  override def load(sc: SparkContext, path: String): KMeansModel = {
    KMeansModel.SaveLoadV1_0.load(sc, path)
  }

  /** One persisted row: a center's index within `clusterCenters` and the center itself. */
  private case class Cluster(id: Int, point: Vector)

  private object Cluster {
    /** Rebuilds a [[Cluster]] from a row whose columns are (id, point), in that order. */
    def apply(r: Row): Cluster = {
      Cluster(r.getInt(0), r.getAs[Vector](1))
    }
  }

  /**
   * Version 1.0 on-disk format.
   * - Metadata: a single-partition text file at `Loader.metadataPath(path)` containing one
   *   JSON object with "class", "version" and "k".
   * - Data: one Parquet row per center at `Loader.dataPath(path)`.
   */
  private[clustering] object SaveLoadV1_0 {

    private val thisFormatVersion = "1.0"

    private[clustering] val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"

    /** Writes `model`'s metadata and cluster centers under `path`. */
    def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
      val metadata = compact(render(
        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
      // Record each center's position so `load` can restore the original ordering.
      val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) =>
        Cluster(id, point)
      }
      spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path))
    }

    /**
     * Reads a model written by [[save]].
     *
     * Checks the stored class name and format version, the Parquet schema, and that the number
     * of stored centers equals the recorded `k`. Centers are returned in their saved `id` order.
     *
     * @throws AssertionError if any of these checks fail
     */
    def load(sc: SparkContext, path: String): KMeansModel = {
      // Explicit type: inferred types on implicit definitions are deprecated.
      implicit val formats: Formats = DefaultFormats
      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
      assert(className == thisClassName,
        s"Expected class $thisClassName but found $className at $path")
      assert(formatVersion == thisFormatVersion,
        s"Expected format version $thisFormatVersion but found $formatVersion at $path")
      val k = (metadata \ "k").extract[Int]
      val centroids = spark.read.parquet(Loader.dataPath(path))
      Loader.checkSchema[Cluster](centroids.schema)
      val localCentroids = centroids.rdd.map(Cluster.apply).collect()
      assert(k == localCentroids.length,
        s"Expected $k cluster centers but found ${localCentroids.length} at $path")
      // Parquet does not guarantee row order; sort by the saved index.
      new KMeansModel(localCentroids.sortBy(_.id).map(_.point))
    }
  }
}