org.apache.spark.ml.util.MLReader Scala Examples
The following examples show how to use org.apache.spark.ml.util.MLReader.
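All of the examples below follow the same pattern: a companion object mixes in MLReadable, returns a private MLReader subclass from read, and the reader's load reconstructs the instance from files under the given path. A minimal sketch of that pattern (MyModel and MyModelReader are hypothetical names used only for illustration, not part of Spark):

  import org.apache.spark.ml.util.{MLReadable, MLReader}

  // Hypothetical model type, for illustration only.
  class MyModel

  object MyModel extends MLReadable[MyModel] {
    override def read: MLReader[MyModel] = new MyModelReader
    override def load(path: String): MyModel = super.load(path)

    private class MyModelReader extends MLReader[MyModel] {
      override def load(path: String): MyModel = {
        // `sc` and `sparkSession` are inherited from MLReader and can be
        // used to read metadata and data files under `path`.
        ???
      }
    }
  }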
Example 1
Source File: MultilayerPerceptronClassifierWrapper.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{DataFrame, Dataset}

private[r] class MultilayerPerceptronClassifierWrapper private (
    val pipeline: PipelineModel,
    val labelCount: Long,
    val layers: Array[Int],
    val weights: Array[Double]
  ) extends MLWritable {

  def transform(dataset: Dataset[_]): DataFrame = {
    pipeline.transform(dataset)
  }

  // Restored so the class satisfies MLWritable; the excerpt omitted it.
  override def write: MLWriter =
    new MultilayerPerceptronClassifierWrapper.MultilayerPerceptronClassifierWrapperWriter(this)
}

// Companion object restored: `read`/`load` and the reader/writer classes
// belong here, not inside the wrapper class (the `fit` helper is elided).
private[r] object MultilayerPerceptronClassifierWrapper
  extends MLReadable[MultilayerPerceptronClassifierWrapper] {

  override def read: MLReader[MultilayerPerceptronClassifierWrapper] =
    new MultilayerPerceptronClassifierWrapperReader

  override def load(path: String): MultilayerPerceptronClassifierWrapper = super.load(path)

  class MultilayerPerceptronClassifierWrapperReader
    extends MLReader[MultilayerPerceptronClassifierWrapper] {

    override def load(path: String): MultilayerPerceptronClassifierWrapper = {
      implicit val format = DefaultFormats
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
      val rMetadata = parse(rMetadataStr)
      val labelCount = (rMetadata \ "labelCount").extract[Long]
      val layers = (rMetadata \ "layers").extract[Array[Int]]
      val weights = (rMetadata \ "weights").extract[Array[Double]]

      val pipeline = PipelineModel.load(pipelinePath)
      new MultilayerPerceptronClassifierWrapper(pipeline, labelCount, layers, weights)
    }
  }

  class MultilayerPerceptronClassifierWrapperWriter(instance: MultilayerPerceptronClassifierWrapper)
    extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      val rMetadata = ("class" -> instance.getClass.getName) ~
        ("labelCount" -> instance.labelCount) ~
        ("layers" -> instance.layers.toSeq) ~
        ("weights" -> instance.weights.toArray.toSeq)
      val rMetadataJson: String = compact(render(rMetadata))
      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)

      instance.pipeline.save(pipelinePath)
    }
  }
}
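A hedged round-trip sketch for this wrapper (the path is illustrative, and `wrapper` stands for an instance produced by the elided `fit` helper): save goes through the writer's saveImpl, load goes through the companion reader.

  // Persist via MLWritable.save, then restore via the companion reader.
  wrapper.save("/tmp/mlp-wrapper")
  val restored = MultilayerPerceptronClassifierWrapper.load("/tmp/mlp-wrapper")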
Example 2
Source File: RWrappers.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkException
import org.apache.spark.ml.util.MLReader

private[r] object RWrappers extends MLReader[Object] {

  override def load(path: String): Object = {
    implicit val format = DefaultFormats
    val rMetadataPath = new Path(path, "rMetadata").toString
    val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
    val rMetadata = parse(rMetadataStr)
    val className = (rMetadata \ "class").extract[String]
    className match {
      case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path)
      case "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper" =>
        AFTSurvivalRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper" =>
        GeneralizedLinearRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.KMeansWrapper" =>
        KMeansWrapper.load(path)
      case "org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper" =>
        MultilayerPerceptronClassifierWrapper.load(path)
      case "org.apache.spark.ml.r.LDAWrapper" =>
        LDAWrapper.load(path)
      case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
        IsotonicRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>
        GaussianMixtureWrapper.load(path)
      case "org.apache.spark.ml.r.ALSWrapper" =>
        ALSWrapper.load(path)
      case "org.apache.spark.ml.r.LogisticRegressionWrapper" =>
        LogisticRegressionWrapper.load(path)
      case _ =>
        throw new SparkException(s"SparkR read.ml does not support load $className")
    }
  }
}
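RWrappers is the single entry point SparkR's read.ml relies on: it reads the rMetadata JSON written next to each saved model, extracts the "class" field, and delegates to that wrapper's own loader. A hedged usage sketch (the path is illustrative):

  // Dispatches on the "class" field stored in <path>/rMetadata.
  val model: Object = RWrappers.load("/tmp/spark-r-model")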
Example 3
Source File: MathUnary.scala From mleap with Apache License 2.0
package org.apache.spark.ml.mleap.feature

import ml.combust.mleap.core.feature.{MathUnaryModel, UnaryOperation}
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, Identifiable, MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}
import org.apache.spark.sql.functions.udf

// The MathUnary Transformer class and its writer are elided in this excerpt;
// only the reader survives. The enclosing companion object is restored here
// so the excerpt is structurally complete.
object MathUnary extends MLReadable[MathUnary] {

  override def read: MLReader[MathUnary] = new MathUnaryReader

  override def load(path: String): MathUnary = super.load(path)

  private class MathUnaryReader extends MLReader[MathUnary] {

    // Checked against metadata when loading the model.
    private val className = classOf[MathUnary].getName

    override def load(path: String): MathUnary = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val dataPath = new Path(path, "data").toString

      val data = sparkSession.read.parquet(dataPath).select("operation").head()
      val operation = data.getAs[String](0)
      val model = MathUnaryModel(UnaryOperation.forName(operation))
      val transformer = new MathUnary(metadata.uid, model)

      metadata.getAndSetParams(transformer)
      transformer
    }
  }
}
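For context, a hedged usage sketch: a persisted stage stores its params via DefaultParamsWriter plus a one-row parquet file holding the operation name, so loading it back goes through the companion object above. The path is illustrative and assumes a MathUnary stage was previously saved there:

  // Restores the operation name from <path>/data and the params from metadata.
  val stage: MathUnary = MathUnary.load("/tmp/math-unary-stage")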
Example 4
Source File: RWrappers.scala From sparkoscope with Apache License 2.0
package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkException
import org.apache.spark.ml.util.MLReader

private[r] object RWrappers extends MLReader[Object] {

  override def load(path: String): Object = {
    implicit val format = DefaultFormats
    val rMetadataPath = new Path(path, "rMetadata").toString
    val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
    val rMetadata = parse(rMetadataStr)
    val className = (rMetadata \ "class").extract[String]
    className match {
      case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path)
      case "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper" =>
        AFTSurvivalRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper" =>
        GeneralizedLinearRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.KMeansWrapper" =>
        KMeansWrapper.load(path)
      case "org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper" =>
        MultilayerPerceptronClassifierWrapper.load(path)
      case "org.apache.spark.ml.r.LDAWrapper" =>
        LDAWrapper.load(path)
      case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
        IsotonicRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>
        GaussianMixtureWrapper.load(path)
      case "org.apache.spark.ml.r.ALSWrapper" =>
        ALSWrapper.load(path)
      case "org.apache.spark.ml.r.LogisticRegressionWrapper" =>
        LogisticRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.RandomForestRegressorWrapper" =>
        RandomForestRegressorWrapper.load(path)
      case "org.apache.spark.ml.r.RandomForestClassifierWrapper" =>
        RandomForestClassifierWrapper.load(path)
      case "org.apache.spark.ml.r.GBTRegressorWrapper" =>
        GBTRegressorWrapper.load(path)
      case "org.apache.spark.ml.r.GBTClassifierWrapper" =>
        GBTClassifierWrapper.load(path)
      case _ =>
        throw new SparkException(s"SparkR read.ml does not support load $className")
    }
  }
}
Example 5
Source File: MultilayerPerceptronClassifierWrapper.scala From sparkoscope with Apache License 2.0
package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier}
import org.apache.spark.ml.feature.{IndexToString, RFormula}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.r.RWrapperUtils._
import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{DataFrame, Dataset}

private[r] class MultilayerPerceptronClassifierWrapper private (
    val pipeline: PipelineModel
  ) extends MLWritable {

  import MultilayerPerceptronClassifierWrapper._

  val mlpModel: MultilayerPerceptronClassificationModel =
    pipeline.stages(1).asInstanceOf[MultilayerPerceptronClassificationModel]

  val weights: Array[Double] = mlpModel.weights.toArray
  val layers: Array[Int] = mlpModel.layers

  def transform(dataset: Dataset[_]): DataFrame = {
    pipeline.transform(dataset)
      .drop(mlpModel.getFeaturesCol)
      .drop(mlpModel.getLabelCol)
      .drop(PREDICTED_LABEL_INDEX_COL)
  }

  // Restored so the class satisfies MLWritable; the excerpt omitted it.
  override def write: MLWriter =
    new MultilayerPerceptronClassifierWrapper.MultilayerPerceptronClassifierWrapperWriter(this)
}

// Companion object restored: it supplies the PREDICTED_LABEL_* constants used
// above and hosts `read`/`load` and the reader/writer (the `fit` helper is elided).
private[r] object MultilayerPerceptronClassifierWrapper
  extends MLReadable[MultilayerPerceptronClassifierWrapper] {

  val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
  val PREDICTED_LABEL_COL = "prediction"

  override def read: MLReader[MultilayerPerceptronClassifierWrapper] =
    new MultilayerPerceptronClassifierWrapperReader

  override def load(path: String): MultilayerPerceptronClassifierWrapper = super.load(path)

  class MultilayerPerceptronClassifierWrapperReader
    extends MLReader[MultilayerPerceptronClassifierWrapper] {

    override def load(path: String): MultilayerPerceptronClassifierWrapper = {
      implicit val format = DefaultFormats
      val pipelinePath = new Path(path, "pipeline").toString

      val pipeline = PipelineModel.load(pipelinePath)
      new MultilayerPerceptronClassifierWrapper(pipeline)
    }
  }

  class MultilayerPerceptronClassifierWrapperWriter(instance: MultilayerPerceptronClassifierWrapper)
    extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      val rMetadata = "class" -> instance.getClass.getName
      val rMetadataJson: String = compact(render(rMetadata))
      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)

      instance.pipeline.save(pipelinePath)
    }
  }
}
Example 6
Source File: RWrappers.scala From multi-tenancy-spark with Apache License 2.0
package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkException
import org.apache.spark.ml.util.MLReader

private[r] object RWrappers extends MLReader[Object] {

  override def load(path: String): Object = {
    implicit val format = DefaultFormats
    val rMetadataPath = new Path(path, "rMetadata").toString
    val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
    val rMetadata = parse(rMetadataStr)
    val className = (rMetadata \ "class").extract[String]
    className match {
      case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path)
      case "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper" =>
        AFTSurvivalRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper" =>
        GeneralizedLinearRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.KMeansWrapper" =>
        KMeansWrapper.load(path)
      case "org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper" =>
        MultilayerPerceptronClassifierWrapper.load(path)
      case "org.apache.spark.ml.r.LDAWrapper" =>
        LDAWrapper.load(path)
      case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
        IsotonicRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>
        GaussianMixtureWrapper.load(path)
      case "org.apache.spark.ml.r.ALSWrapper" =>
        ALSWrapper.load(path)
      case "org.apache.spark.ml.r.LogisticRegressionWrapper" =>
        LogisticRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.RandomForestRegressorWrapper" =>
        RandomForestRegressorWrapper.load(path)
      case "org.apache.spark.ml.r.RandomForestClassifierWrapper" =>
        RandomForestClassifierWrapper.load(path)
      case "org.apache.spark.ml.r.GBTRegressorWrapper" =>
        GBTRegressorWrapper.load(path)
      case "org.apache.spark.ml.r.GBTClassifierWrapper" =>
        GBTClassifierWrapper.load(path)
      case _ =>
        throw new SparkException(s"SparkR read.ml does not support load $className")
    }
  }
}
Example 7
Source File: MultilayerPerceptronClassifierWrapper.scala From multi-tenancy-spark with Apache License 2.0
package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier}
import org.apache.spark.ml.feature.{IndexToString, RFormula}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.r.RWrapperUtils._
import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{DataFrame, Dataset}

private[r] class MultilayerPerceptronClassifierWrapper private (
    val pipeline: PipelineModel
  ) extends MLWritable {

  import MultilayerPerceptronClassifierWrapper._

  val mlpModel: MultilayerPerceptronClassificationModel =
    pipeline.stages(1).asInstanceOf[MultilayerPerceptronClassificationModel]

  val weights: Array[Double] = mlpModel.weights.toArray
  val layers: Array[Int] = mlpModel.layers

  def transform(dataset: Dataset[_]): DataFrame = {
    pipeline.transform(dataset)
      .drop(mlpModel.getFeaturesCol)
      .drop(mlpModel.getLabelCol)
      .drop(PREDICTED_LABEL_INDEX_COL)
  }

  // Restored so the class satisfies MLWritable; the excerpt omitted it.
  override def write: MLWriter =
    new MultilayerPerceptronClassifierWrapper.MultilayerPerceptronClassifierWrapperWriter(this)
}

// Companion object restored: it supplies the PREDICTED_LABEL_* constants used
// above and hosts `read`/`load` and the reader/writer (the `fit` helper is elided).
private[r] object MultilayerPerceptronClassifierWrapper
  extends MLReadable[MultilayerPerceptronClassifierWrapper] {

  val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
  val PREDICTED_LABEL_COL = "prediction"

  override def read: MLReader[MultilayerPerceptronClassifierWrapper] =
    new MultilayerPerceptronClassifierWrapperReader

  override def load(path: String): MultilayerPerceptronClassifierWrapper = super.load(path)

  class MultilayerPerceptronClassifierWrapperReader
    extends MLReader[MultilayerPerceptronClassifierWrapper] {

    override def load(path: String): MultilayerPerceptronClassifierWrapper = {
      implicit val format = DefaultFormats
      val pipelinePath = new Path(path, "pipeline").toString

      val pipeline = PipelineModel.load(pipelinePath)
      new MultilayerPerceptronClassifierWrapper(pipeline)
    }
  }

  class MultilayerPerceptronClassifierWrapperWriter(instance: MultilayerPerceptronClassifierWrapper)
    extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      val rMetadata = "class" -> instance.getClass.getName
      val rMetadataJson: String = compact(render(rMetadata))
      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)

      instance.pipeline.save(pipelinePath)
    }
  }
}
Example 8
Source File: SerializableSparkModel.scala From seahorse with Apache License 2.0
package ai.deepsense.deeplang.doperables.serialization

import org.apache.spark.ml.Model
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

import ai.deepsense.sparkutils.ML

class SerializableSparkModel[M <: Model[M]](val sparkModel: M)
  extends ML.Model[SerializableSparkModel[M]]
  with MLWritable {

  override def copy(extra: ParamMap): SerializableSparkModel[M] =
    new SerializableSparkModel(sparkModel.copy(extra))

  override def write: MLWriter = {
    sparkModel match {
      case w: MLWritable => w.write
      case _ => new DefaultMLWriter(this)
    }
  }

  override def transformDF(dataset: DataFrame): DataFrame = sparkModel.transform(dataset)

  override def transformSchema(schema: StructType): StructType =
    sparkModel.transformSchema(schema)

  override val uid: String = "dc7178fe-b209-44f5-8a74-d3c4dafa0fae"
}

// This object may seem unused, but it is used reflectively by Spark's
// deserialization mechanism.
object SerializableSparkModel extends MLReadable[SerializableSparkModel[_]] {
  override def read: MLReader[SerializableSparkModel[_]] = {
    new DefaultMLReader[SerializableSparkModel[_]]()
  }
}
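A hedged sketch of why the write fallback matters: models that already implement MLWritable delegate to their own writer, while anything else falls back to DefaultMLWriter, so save/load stays uniform. LogisticRegressionModel is just an illustrative model type:

  import org.apache.spark.ml.classification.LogisticRegressionModel

  // Wrapping keeps persistence uniform regardless of the underlying model.
  def wrap(model: LogisticRegressionModel): SerializableSparkModel[LogisticRegressionModel] =
    new SerializableSparkModel(model)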
Example 9
Source File: SparkStageParam.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages

import com.salesforce.op.stages.sparkwrappers.generic.SparkWrapperParams
import org.apache.hadoop.fs.Path
import org.apache.spark.ml.PipelineStage
import org.apache.spark.ml.param.{Param, ParamPair, Params}
import org.apache.spark.ml.util.{Identifiable, MLReader, MLWritable}
import org.apache.spark.util.SparkUtils
import org.json4s.JsonAST.{JObject, JValue}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, parse, render}
import org.json4s.{DefaultFormats, Formats, JString}

class SparkStageParam[S <: PipelineStage with Params]
(
  parent: String,
  name: String,
  doc: String,
  isValid: Option[S] => Boolean
) extends Param[Option[S]](parent, name, doc, isValid) {

  import SparkStageParam._

  // Restored from the full source: the save path is recorded as a side effect
  // of encoding/decoding (the jsonEncode side is elided in this excerpt).
  var savePath: Option[String] = None

  override def jsonDecode(jsonStr: String): Option[S] = {
    val json = parse(jsonStr)
    val uid = (json \ "uid").extractOpt[String]
    val path = (json \ "path").extractOpt[String]

    path -> uid match {
      case (None, _) | (_, None) | (_, Some(NoUID)) =>
        savePath = None
        None
      case (Some(p), Some(stageUid)) =>
        savePath = Option(p)
        val stagePath = new Path(p, stageUid).toString
        val className = (json \ "className").extract[String]
        val cls = SparkUtils.classForName(className)
        val stage = cls.getMethod("read").invoke(null)
          .asInstanceOf[MLReader[PipelineStage]].load(stagePath)
        Option(stage).map(_.asInstanceOf[S])
    }
  }
}

object SparkStageParam {
  implicit val formats: Formats = DefaultFormats
  val NoClass = ""
  val NoUID = ""

  def updateParamsMetadataWithPath(jValue: JValue, path: String): JValue = jValue match {
    case JObject(pairs) => JObject(
      pairs.map {
        case (SparkWrapperParams.SparkStageParamName, j) =>
          SparkWrapperParams.SparkStageParamName -> j.merge(JObject("path" -> JString(path)))
        case param => param
      }
    )
    case j => throw new IllegalArgumentException(s"Cannot recognize JSON Spark params metadata: $j")
  }
}
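A hedged sketch of the JSON shape jsonDecode expects: when a path and uid were recorded, it loads the stage from <path>/<uid> via the class's static read method; when either is missing (NoUID is the empty string), it resolves to None. Field values here are illustrative:

  val param = new SparkStageParam[PipelineStage](
    parent = "myStage", name = "sparkStage", doc = "wrapped stage", isValid = _ => true)

  // Empty uid, so decoding yields None without touching the filesystem.
  val none = param.jsonDecode("""{"className":"","uid":"","path":"/models/run-7"}""")
  assert(none.isEmpty)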
Example 10
Source File: RWrappers.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkException
import org.apache.spark.ml.util.MLReader

private[r] object RWrappers extends MLReader[Object] {

  override def load(path: String): Object = {
    implicit val format = DefaultFormats
    val rMetadataPath = new Path(path, "rMetadata").toString
    val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
    val rMetadata = parse(rMetadataStr)
    val className = (rMetadata \ "class").extract[String]
    className match {
      case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path)
      case "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper" =>
        AFTSurvivalRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper" =>
        GeneralizedLinearRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.KMeansWrapper" =>
        KMeansWrapper.load(path)
      case "org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper" =>
        MultilayerPerceptronClassifierWrapper.load(path)
      case "org.apache.spark.ml.r.LDAWrapper" =>
        LDAWrapper.load(path)
      case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
        IsotonicRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>
        GaussianMixtureWrapper.load(path)
      case "org.apache.spark.ml.r.ALSWrapper" =>
        ALSWrapper.load(path)
      case "org.apache.spark.ml.r.LogisticRegressionWrapper" =>
        LogisticRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.RandomForestRegressorWrapper" =>
        RandomForestRegressorWrapper.load(path)
      case "org.apache.spark.ml.r.RandomForestClassifierWrapper" =>
        RandomForestClassifierWrapper.load(path)
      case "org.apache.spark.ml.r.DecisionTreeRegressorWrapper" =>
        DecisionTreeRegressorWrapper.load(path)
      case "org.apache.spark.ml.r.DecisionTreeClassifierWrapper" =>
        DecisionTreeClassifierWrapper.load(path)
      case "org.apache.spark.ml.r.GBTRegressorWrapper" =>
        GBTRegressorWrapper.load(path)
      case "org.apache.spark.ml.r.GBTClassifierWrapper" =>
        GBTClassifierWrapper.load(path)
      case "org.apache.spark.ml.r.BisectingKMeansWrapper" =>
        BisectingKMeansWrapper.load(path)
      case "org.apache.spark.ml.r.LinearSVCWrapper" =>
        LinearSVCWrapper.load(path)
      case "org.apache.spark.ml.r.FPGrowthWrapper" =>
        FPGrowthWrapper.load(path)
      case _ =>
        throw new SparkException(s"SparkR read.ml does not support load $className")
    }
  }
}
Example 11
Source File: MultilayerPerceptronClassifierWrapper.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier}
import org.apache.spark.ml.feature.{IndexToString, RFormula}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.r.RWrapperUtils._
import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{DataFrame, Dataset}

private[r] class MultilayerPerceptronClassifierWrapper private (
    val pipeline: PipelineModel
  ) extends MLWritable {

  import MultilayerPerceptronClassifierWrapper._

  private val mlpModel: MultilayerPerceptronClassificationModel =
    pipeline.stages(1).asInstanceOf[MultilayerPerceptronClassificationModel]

  lazy val weights: Array[Double] = mlpModel.weights.toArray
  lazy val layers: Array[Int] = mlpModel.layers

  def transform(dataset: Dataset[_]): DataFrame = {
    pipeline.transform(dataset)
      .drop(mlpModel.getFeaturesCol)
      .drop(mlpModel.getLabelCol)
      .drop(PREDICTED_LABEL_INDEX_COL)
  }

  // Restored so the class satisfies MLWritable; the excerpt omitted it.
  override def write: MLWriter =
    new MultilayerPerceptronClassifierWrapper.MultilayerPerceptronClassifierWrapperWriter(this)
}

// Companion object restored: it supplies the PREDICTED_LABEL_* constants used
// above and hosts `read`/`load` and the reader/writer (the `fit` helper is elided).
private[r] object MultilayerPerceptronClassifierWrapper
  extends MLReadable[MultilayerPerceptronClassifierWrapper] {

  val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
  val PREDICTED_LABEL_COL = "prediction"

  override def read: MLReader[MultilayerPerceptronClassifierWrapper] =
    new MultilayerPerceptronClassifierWrapperReader

  override def load(path: String): MultilayerPerceptronClassifierWrapper = super.load(path)

  class MultilayerPerceptronClassifierWrapperReader
    extends MLReader[MultilayerPerceptronClassifierWrapper] {

    override def load(path: String): MultilayerPerceptronClassifierWrapper = {
      implicit val format = DefaultFormats
      val pipelinePath = new Path(path, "pipeline").toString

      val pipeline = PipelineModel.load(pipelinePath)
      new MultilayerPerceptronClassifierWrapper(pipeline)
    }
  }

  class MultilayerPerceptronClassifierWrapperWriter(instance: MultilayerPerceptronClassifierWrapper)
    extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      val rMetadata = "class" -> instance.getClass.getName
      val rMetadataJson: String = compact(render(rMetadata))
      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)

      instance.pipeline.save(pipelinePath)
    }
  }
}
Example 12
Source File: SerializableSparkModel.scala From seahorse-workflow-executor with Apache License 2.0
package io.deepsense.deeplang.doperables.serialization

import org.apache.spark.ml.Model
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

import io.deepsense.sparkutils.ML

class SerializableSparkModel[M <: Model[M]](val sparkModel: M)
  extends ML.Model[SerializableSparkModel[M]]
  with MLWritable {

  override def copy(extra: ParamMap): SerializableSparkModel[M] =
    new SerializableSparkModel(sparkModel.copy(extra))

  override def write: MLWriter = {
    sparkModel match {
      case w: MLWritable => w.write
      case _ => new DefaultMLWriter(this)
    }
  }

  override def transformDF(dataset: DataFrame): DataFrame = sparkModel.transform(dataset)

  override def transformSchema(schema: StructType): StructType =
    sparkModel.transformSchema(schema)

  override val uid: String = "dc7178fe-b209-44f5-8a74-d3c4dafa0fae"
}

// This object may seem unused, but it is used reflectively by Spark's
// deserialization mechanism.
object SerializableSparkModel extends MLReadable[SerializableSparkModel[_]] {
  override def read: MLReader[SerializableSparkModel[_]] = {
    new DefaultMLReader[SerializableSparkModel[_]]()
  }
}