org.apache.spark.ml.util.MLReader Scala Examples

The following examples show how to use org.apache.spark.ml.util.MLReader. Each example is taken from an open-source project; the source file, originating project, and license are noted above each listing.
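MLReader is the loading half of Spark ML's persistence API: an MLWritable writes an object out to a directory, and the matching MLReadable companion exposes an MLReader that rebuilds it from that directory. As a quick orientation before the examples, here is a minimal round-trip sketch; trainedModel and the save path are assumptions, not taken from any example below.

import org.apache.spark.ml.classification.LogisticRegressionModel

// Any MLWritable can be persisted to a directory...
trainedModel.write.overwrite().save("/tmp/lr-model")

// ...and the companion MLReadable exposes an MLReader[T] whose load()
// rebuilds the object from that directory.
val restored: LogisticRegressionModel =
  LogisticRegressionModel.read.load("/tmp/lr-model")
// LogisticRegressionModel.load("/tmp/lr-model") is shorthand for the same call.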
Example 1
Source File: MultilayerPerceptronClassifierWrapper.scala    From drizzle-spark    with Apache License 2.0
package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{DataFrame, Dataset}

private[r] class MultilayerPerceptronClassifierWrapper private (
    val pipeline: PipelineModel,
    val labelCount: Long,
    val layers: Array[Int],
    val weights: Array[Double]
  ) extends MLWritable {

  def transform(dataset: Dataset[_]): DataFrame = {
    pipeline.transform(dataset)
  }

  
  // Writes this wrapper using the MLWriter defined in the companion object below.
  override def write: MLWriter =
    new MultilayerPerceptronClassifierWrapper.MultilayerPerceptronClassifierWrapperWriter(this)
}

private[r] object MultilayerPerceptronClassifierWrapper
  extends MLReadable[MultilayerPerceptronClassifierWrapper] {

  override def read: MLReader[MultilayerPerceptronClassifierWrapper] =
    new MultilayerPerceptronClassifierWrapperReader

  override def load(path: String): MultilayerPerceptronClassifierWrapper = super.load(path)

  class MultilayerPerceptronClassifierWrapperReader
    extends MLReader[MultilayerPerceptronClassifierWrapper] {

    override def load(path: String): MultilayerPerceptronClassifierWrapper = {
      implicit val format = DefaultFormats
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
      val rMetadata = parse(rMetadataStr)
      val labelCount = (rMetadata \ "labelCount").extract[Long]
      val layers = (rMetadata \ "layers").extract[Array[Int]]
      val weights = (rMetadata \ "weights").extract[Array[Double]]

      val pipeline = PipelineModel.load(pipelinePath)
      new MultilayerPerceptronClassifierWrapper(pipeline, labelCount, layers, weights)
    }
  }

  class MultilayerPerceptronClassifierWrapperWriter(instance: MultilayerPerceptronClassifierWrapper)
    extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      val rMetadata = ("class" -> instance.getClass.getName) ~
        ("labelCount" -> instance.labelCount) ~
        ("layers" -> instance.layers.toSeq) ~
        ("weights" -> instance.weights.toArray.toSeq)
      val rMetadataJson: String = compact(render(rMetadata))
      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)

      instance.pipeline.save(pipelinePath)
    }
  }
} 
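With the wrapper's reader and writer living in its companion object, loading goes through read. A hedged usage sketch; modelPath and testData are hypothetical names, not part of the example:

// Hypothetical directory where the SparkR model was previously saved.
val modelPath = "/tmp/spark-r/mlp-wrapper"

// read returns the MultilayerPerceptronClassifierWrapperReader defined above; its load()
// parses the rMetadata JSON and reloads the persisted PipelineModel.
val wrapper = MultilayerPerceptronClassifierWrapper.read.load(modelPath)
val predictions = wrapper.transform(testData)  // testData: an existing Dataset[_]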
Example 2
Source File: RWrappers.scala    From drizzle-spark    with Apache License 2.0
package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkException
import org.apache.spark.ml.util.MLReader


private[r] object RWrappers extends MLReader[Object] {

  override def load(path: String): Object = {
    implicit val format = DefaultFormats
    val rMetadataPath = new Path(path, "rMetadata").toString
    val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
    val rMetadata = parse(rMetadataStr)
    val className = (rMetadata \ "class").extract[String]
    className match {
      case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path)
      case "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper" =>
        AFTSurvivalRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper" =>
        GeneralizedLinearRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.KMeansWrapper" =>
        KMeansWrapper.load(path)
      case "org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper" =>
        MultilayerPerceptronClassifierWrapper.load(path)
      case "org.apache.spark.ml.r.LDAWrapper" =>
        LDAWrapper.load(path)
      case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
        IsotonicRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>
        GaussianMixtureWrapper.load(path)
      case "org.apache.spark.ml.r.ALSWrapper" =>
        ALSWrapper.load(path)
      case "org.apache.spark.ml.r.LogisticRegressionWrapper" =>
        LogisticRegressionWrapper.load(path)
      case _ =>
        throw new SparkException(s"SparkR read.ml does not support load $className")
    }
  }
} 
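RWrappers is the single entry point behind SparkR's read.ml: it reads the one-line JSON stored under rMetadata by whichever wrapper saved the model, extracts its class field, and forwards to that wrapper's loader. A hedged illustration; the JSON values are made up, but the fields match what the writer in Example 1 emits:

// Example contents of <path>/rMetadata/part-00000 for an MLP model (illustrative values only):
//   {"class":"org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper",
//    "labelCount":3,"layers":[4,5,3],"weights":[0.1,-0.2,0.3]}

// RWrappers.load parses that line and delegates to the named wrapper's load():
val model = RWrappers.load("/tmp/spark-r/some-model")  // hypothetical path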
Example 3
Source File: MathUnary.scala    From mleap    with Apache License 2.0
package org.apache.spark.ml.mleap.feature

import ml.combust.mleap.core.feature.{MathUnaryModel, UnaryOperation}
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, Identifiable, MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}
import org.apache.spark.sql.functions.udf


// The MathUnary transformer class itself and its MLWriter are omitted from this
// excerpt; only the companion object's MLReader is shown.
object MathUnary extends MLReadable[MathUnary] {

  override def read: MLReader[MathUnary] = new MathUnaryReader

  override def load(path: String): MathUnary = super.load(path)

  private class MathUnaryReader extends MLReader[MathUnary] {

    private val className = classOf[MathUnary].getName

    override def load(path: String): MathUnary = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

      val dataPath = new Path(path, "data").toString

      val data = sparkSession.read.parquet(dataPath).select("operation").head()
      val operation = data.getAs[String](0)

      val model = MathUnaryModel(UnaryOperation.forName(operation))
      val transformer = new MathUnary(metadata.uid, model)

      metadata.getAndSetParams(transformer)
      transformer
    }
  }

} 
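MathUnaryReader.load above expects the standard params metadata plus a data directory holding a Parquet file with a single operation column. The matching writer is not part of this excerpt; below is a minimal sketch of what it might look like, assuming MathUnary exposes its MathUnaryModel as model and that UnaryOperation carries a name (mleap's actual MathUnaryWriter may differ). It must sit under an org.apache.spark.ml package so the private[ml] DefaultParamsWriter is visible.

import org.apache.hadoop.fs.Path
import org.apache.spark.ml.util.{DefaultParamsWriter, MLWriter}

// Sketch only: persist params metadata plus the single "operation" value
// that MathUnaryReader.load reads back from <path>/data.
class MathUnaryWriterSketch(instance: MathUnary) extends MLWriter {
  override protected def saveImpl(path: String): Unit = {
    DefaultParamsWriter.saveMetadata(instance, path, sc)
    val dataPath = new Path(path, "data").toString
    val operation = instance.model.operation.name  // assumed accessors, see note above
    sparkSession.createDataFrame(Seq(Tuple1(operation))).toDF("operation")
      .repartition(1).write.parquet(dataPath)
  }
}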
Example 4
Source File: RWrappers.scala    From sparkoscope    with Apache License 2.0
package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkException
import org.apache.spark.ml.util.MLReader


private[r] object RWrappers extends MLReader[Object] {

  override def load(path: String): Object = {
    implicit val format = DefaultFormats
    val rMetadataPath = new Path(path, "rMetadata").toString
    val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
    val rMetadata = parse(rMetadataStr)
    val className = (rMetadata \ "class").extract[String]
    className match {
      case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path)
      case "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper" =>
        AFTSurvivalRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper" =>
        GeneralizedLinearRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.KMeansWrapper" =>
        KMeansWrapper.load(path)
      case "org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper" =>
        MultilayerPerceptronClassifierWrapper.load(path)
      case "org.apache.spark.ml.r.LDAWrapper" =>
        LDAWrapper.load(path)
      case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
        IsotonicRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>
        GaussianMixtureWrapper.load(path)
      case "org.apache.spark.ml.r.ALSWrapper" =>
        ALSWrapper.load(path)
      case "org.apache.spark.ml.r.LogisticRegressionWrapper" =>
        LogisticRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.RandomForestRegressorWrapper" =>
        RandomForestRegressorWrapper.load(path)
      case "org.apache.spark.ml.r.RandomForestClassifierWrapper" =>
        RandomForestClassifierWrapper.load(path)
      case "org.apache.spark.ml.r.GBTRegressorWrapper" =>
        GBTRegressorWrapper.load(path)
      case "org.apache.spark.ml.r.GBTClassifierWrapper" =>
        GBTClassifierWrapper.load(path)
      case _ =>
        throw new SparkException(s"SparkR read.ml does not support load $className")
    }
  }
} 
Example 5
Source File: MultilayerPerceptronClassifierWrapper.scala    From sparkoscope    with Apache License 2.0
package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier}
import org.apache.spark.ml.feature.{IndexToString, RFormula}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.r.RWrapperUtils._
import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{DataFrame, Dataset}

private[r] class MultilayerPerceptronClassifierWrapper private (
    val pipeline: PipelineModel
  ) extends MLWritable {

  import MultilayerPerceptronClassifierWrapper._

  val mlpModel: MultilayerPerceptronClassificationModel =
    pipeline.stages(1).asInstanceOf[MultilayerPerceptronClassificationModel]

  val weights: Array[Double] = mlpModel.weights.toArray
  val layers: Array[Int] = mlpModel.layers

  def transform(dataset: Dataset[_]): DataFrame = {
    pipeline.transform(dataset)
      .drop(mlpModel.getFeaturesCol)
      .drop(mlpModel.getLabelCol)
      .drop(PREDICTED_LABEL_INDEX_COL)
  }

  
  // Writes this wrapper using the MLWriter defined in the companion object below.
  override def write: MLWriter =
    new MultilayerPerceptronClassifierWrapper.MultilayerPerceptronClassifierWrapperWriter(this)
}

private[r] object MultilayerPerceptronClassifierWrapper
  extends MLReadable[MultilayerPerceptronClassifierWrapper] {

  // Referenced by transform() above; the fit() helper from the full source file is omitted here.
  val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"

  override def read: MLReader[MultilayerPerceptronClassifierWrapper] =
    new MultilayerPerceptronClassifierWrapperReader

  override def load(path: String): MultilayerPerceptronClassifierWrapper = super.load(path)

  class MultilayerPerceptronClassifierWrapperReader
    extends MLReader[MultilayerPerceptronClassifierWrapper] {

    override def load(path: String): MultilayerPerceptronClassifierWrapper = {
      implicit val format = DefaultFormats
      val pipelinePath = new Path(path, "pipeline").toString

      val pipeline = PipelineModel.load(pipelinePath)
      new MultilayerPerceptronClassifierWrapper(pipeline)
    }
  }

  class MultilayerPerceptronClassifierWrapperWriter(instance: MultilayerPerceptronClassifierWrapper)
    extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      val rMetadata = "class" -> instance.getClass.getName
      val rMetadataJson: String = compact(render(rMetadata))
      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)

      instance.pipeline.save(pipelinePath)
    }
  }
} 
Example 6
Source File: RWrappers.scala    From multi-tenancy-spark    with Apache License 2.0
package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkException
import org.apache.spark.ml.util.MLReader


private[r] object RWrappers extends MLReader[Object] {

  override def load(path: String): Object = {
    implicit val format = DefaultFormats
    val rMetadataPath = new Path(path, "rMetadata").toString
    val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
    val rMetadata = parse(rMetadataStr)
    val className = (rMetadata \ "class").extract[String]
    className match {
      case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path)
      case "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper" =>
        AFTSurvivalRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper" =>
        GeneralizedLinearRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.KMeansWrapper" =>
        KMeansWrapper.load(path)
      case "org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper" =>
        MultilayerPerceptronClassifierWrapper.load(path)
      case "org.apache.spark.ml.r.LDAWrapper" =>
        LDAWrapper.load(path)
      case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
        IsotonicRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>
        GaussianMixtureWrapper.load(path)
      case "org.apache.spark.ml.r.ALSWrapper" =>
        ALSWrapper.load(path)
      case "org.apache.spark.ml.r.LogisticRegressionWrapper" =>
        LogisticRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.RandomForestRegressorWrapper" =>
        RandomForestRegressorWrapper.load(path)
      case "org.apache.spark.ml.r.RandomForestClassifierWrapper" =>
        RandomForestClassifierWrapper.load(path)
      case "org.apache.spark.ml.r.GBTRegressorWrapper" =>
        GBTRegressorWrapper.load(path)
      case "org.apache.spark.ml.r.GBTClassifierWrapper" =>
        GBTClassifierWrapper.load(path)
      case _ =>
        throw new SparkException(s"SparkR read.ml does not support load $className")
    }
  }
} 
Example 7
Source File: MultilayerPerceptronClassifierWrapper.scala    From multi-tenancy-spark    with Apache License 2.0
package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier}
import org.apache.spark.ml.feature.{IndexToString, RFormula}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.r.RWrapperUtils._
import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{DataFrame, Dataset}

private[r] class MultilayerPerceptronClassifierWrapper private (
    val pipeline: PipelineModel
  ) extends MLWritable {

  import MultilayerPerceptronClassifierWrapper._

  val mlpModel: MultilayerPerceptronClassificationModel =
    pipeline.stages(1).asInstanceOf[MultilayerPerceptronClassificationModel]

  val weights: Array[Double] = mlpModel.weights.toArray
  val layers: Array[Int] = mlpModel.layers

  def transform(dataset: Dataset[_]): DataFrame = {
    pipeline.transform(dataset)
      .drop(mlpModel.getFeaturesCol)
      .drop(mlpModel.getLabelCol)
      .drop(PREDICTED_LABEL_INDEX_COL)
  }

  
  // Writes this wrapper using the MLWriter defined in the companion object below.
  override def write: MLWriter =
    new MultilayerPerceptronClassifierWrapper.MultilayerPerceptronClassifierWrapperWriter(this)
}

private[r] object MultilayerPerceptronClassifierWrapper
  extends MLReadable[MultilayerPerceptronClassifierWrapper] {

  // Referenced by transform() above; the fit() helper from the full source file is omitted here.
  val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"

  override def read: MLReader[MultilayerPerceptronClassifierWrapper] =
    new MultilayerPerceptronClassifierWrapperReader

  override def load(path: String): MultilayerPerceptronClassifierWrapper = super.load(path)

  class MultilayerPerceptronClassifierWrapperReader
    extends MLReader[MultilayerPerceptronClassifierWrapper] {

    override def load(path: String): MultilayerPerceptronClassifierWrapper = {
      implicit val format = DefaultFormats
      val pipelinePath = new Path(path, "pipeline").toString

      val pipeline = PipelineModel.load(pipelinePath)
      new MultilayerPerceptronClassifierWrapper(pipeline)
    }
  }

  class MultilayerPerceptronClassifierWrapperWriter(instance: MultilayerPerceptronClassifierWrapper)
    extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      val rMetadata = "class" -> instance.getClass.getName
      val rMetadataJson: String = compact(render(rMetadata))
      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)

      instance.pipeline.save(pipelinePath)
    }
  }
} 
Example 8
Source File: SerializableSparkModel.scala    From seahorse    with Apache License 2.0
package ai.deepsense.deeplang.doperables.serialization

import org.apache.spark.ml.Model
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

import ai.deepsense.sparkutils.ML

class SerializableSparkModel[M <: Model[M]](val sparkModel: M)
  extends ML.Model[SerializableSparkModel[M]]
  with MLWritable {

  override def copy(extra: ParamMap): SerializableSparkModel[M] =
    new SerializableSparkModel(sparkModel.copy(extra))

  override def write: MLWriter = {
    sparkModel match {
      case w: MLWritable => w.write
      case _ => new DefaultMLWriter(this)
    }
  }

  override def transformDF(dataset: DataFrame): DataFrame = sparkModel.transform(dataset)

  override def transformSchema(schema: StructType): StructType = sparkModel.transformSchema(schema)

  override val uid: String = "dc7178fe-b209-44f5-8a74-d3c4dafa0fae"
}

// This class may seem unused, but it is used reflectively by spark deserialization mechanism
object SerializableSparkModel extends MLReadable[SerializableSparkModel[_]] {
  override def read: MLReader[SerializableSparkModel[_]] = {
    new DefaultMLReader[SerializableSparkModel[_]]()
  }
} 
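A hedged usage sketch for this wrapper, assuming a fitted Spark model is at hand and that DefaultMLWriter/DefaultMLReader (project-specific classes not shown here) persist to and load from a directory path as their names suggest:

// Wrap a fitted model so it always has a writer, even when the underlying
// Spark model is not itself MLWritable.
val wrapped = new SerializableSparkModel(fittedKMeansModel)  // fittedKMeansModel: a concrete Model, e.g. KMeansModel
wrapped.write.save("/tmp/serializable-model")                // hypothetical path

// Loading goes through the companion object's MLReader, which Spark can also
// invoke reflectively during deserialization (as the comment above notes).
val reloaded = SerializableSparkModel.read.load("/tmp/serializable-model")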
Example 9
Source File: SparkStageParam.scala    From TransmogrifAI    with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages

import com.salesforce.op.stages.sparkwrappers.generic.SparkWrapperParams
import org.apache.hadoop.fs.Path
import org.apache.spark.ml.PipelineStage
import org.apache.spark.ml.param.{Param, ParamPair, Params}
import org.apache.spark.ml.util.{Identifiable, MLReader, MLWritable}
import org.apache.spark.util.SparkUtils
import org.json4s.JsonAST.{JObject, JValue}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, parse, render}
import org.json4s.{DefaultFormats, Formats, JString}

class SparkStageParam[S <: PipelineStage with Params]
(
  parent: String,
  name: String,
  doc: String,
  isValid: Option[S] => Boolean
) extends Param[Option[S]](parent, name, doc, isValid) {

  import SparkStageParam._

  // Remembers where the wrapped stage was saved; jsonDecode below assigns it.
  // Other members of the original class (e.g. jsonEncode) are omitted from this excerpt.
  var savePath: Option[String] = None

  override def jsonDecode(jsonStr: String): Option[S] = {
    val json = parse(jsonStr)
    val uid = (json \ "uid").extractOpt[String]
    val path = (json \ "path").extractOpt[String]

    path -> uid match {
      case (None, _) | (_, None) | (_, Some(NoUID)) =>
        savePath = None
        None
      case (Some(p), Some(stageUid)) =>
        savePath = Option(p)
        val stagePath = new Path(p, stageUid).toString
        val className = (json \ "className").extract[String]
        val cls = SparkUtils.classForName(className)
        val stage = cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath)
        Option(stage).map(_.asInstanceOf[S])
    }
  }
}

object SparkStageParam {
  implicit val formats: Formats = DefaultFormats
  val NoClass = ""
  val NoUID = ""

  def updateParamsMetadataWithPath(jValue: JValue, path: String): JValue = jValue match {
    case JObject(pairs) => JObject(
      pairs.map {
        case (SparkWrapperParams.SparkStageParamName, j) =>
          SparkWrapperParams.SparkStageParamName -> j.merge(JObject("path" -> JString(path)))
        case param => param
      }
    )
    case j => throw new IllegalArgumentException(s"Cannot recognize JSON Spark params metadata: $j")
  }

} 
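jsonDecode above expects JSON carrying the stage's className, uid, and the directory path it was saved under (updateParamsMetadataWithPath patches that path into the params metadata). The encoding side is not shown in this excerpt; a hedged sketch follows, with jsonEncodeSketch as an illustrative name rather than TransmogrifAI's actual jsonEncode:

import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

// Sketch only: emit the three fields that jsonDecode above reads back.
def jsonEncodeSketch(stage: Option[PipelineStage], savedPath: Option[String]): String = {
  val json =
    ("className" -> stage.map(_.getClass.getName).getOrElse(SparkStageParam.NoClass)) ~
      ("uid" -> stage.map(_.uid).getOrElse(SparkStageParam.NoUID)) ~
      ("path" -> savedPath.getOrElse(""))
  compact(render(json))
}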
Example 10
Source File: RWrappers.scala    From Spark-2.3.1    with Apache License 2.0
package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkException
import org.apache.spark.ml.util.MLReader


private[r] object RWrappers extends MLReader[Object] {

  override def load(path: String): Object = {
    implicit val format = DefaultFormats
    val rMetadataPath = new Path(path, "rMetadata").toString
    val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
    val rMetadata = parse(rMetadataStr)
    val className = (rMetadata \ "class").extract[String]
    className match {
      case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path)
      case "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper" =>
        AFTSurvivalRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper" =>
        GeneralizedLinearRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.KMeansWrapper" =>
        KMeansWrapper.load(path)
      case "org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper" =>
        MultilayerPerceptronClassifierWrapper.load(path)
      case "org.apache.spark.ml.r.LDAWrapper" =>
        LDAWrapper.load(path)
      case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
        IsotonicRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>
        GaussianMixtureWrapper.load(path)
      case "org.apache.spark.ml.r.ALSWrapper" =>
        ALSWrapper.load(path)
      case "org.apache.spark.ml.r.LogisticRegressionWrapper" =>
        LogisticRegressionWrapper.load(path)
      case "org.apache.spark.ml.r.RandomForestRegressorWrapper" =>
        RandomForestRegressorWrapper.load(path)
      case "org.apache.spark.ml.r.RandomForestClassifierWrapper" =>
        RandomForestClassifierWrapper.load(path)
      case "org.apache.spark.ml.r.DecisionTreeRegressorWrapper" =>
        DecisionTreeRegressorWrapper.load(path)
      case "org.apache.spark.ml.r.DecisionTreeClassifierWrapper" =>
        DecisionTreeClassifierWrapper.load(path)
      case "org.apache.spark.ml.r.GBTRegressorWrapper" =>
        GBTRegressorWrapper.load(path)
      case "org.apache.spark.ml.r.GBTClassifierWrapper" =>
        GBTClassifierWrapper.load(path)
      case "org.apache.spark.ml.r.BisectingKMeansWrapper" =>
        BisectingKMeansWrapper.load(path)
      case "org.apache.spark.ml.r.LinearSVCWrapper" =>
        LinearSVCWrapper.load(path)
      case "org.apache.spark.ml.r.FPGrowthWrapper" =>
        FPGrowthWrapper.load(path)
      case _ =>
        throw new SparkException(s"SparkR read.ml does not support load $className")
    }
  }
} 
Example 11
Source File: MultilayerPerceptronClassifierWrapper.scala    From Spark-2.3.1    with Apache License 2.0
package org.apache.spark.ml.r

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier}
import org.apache.spark.ml.feature.{IndexToString, RFormula}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.r.RWrapperUtils._
import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{DataFrame, Dataset}

private[r] class MultilayerPerceptronClassifierWrapper private (
    val pipeline: PipelineModel
  ) extends MLWritable {

  import MultilayerPerceptronClassifierWrapper._

  private val mlpModel: MultilayerPerceptronClassificationModel =
    pipeline.stages(1).asInstanceOf[MultilayerPerceptronClassificationModel]

  lazy val weights: Array[Double] = mlpModel.weights.toArray
  lazy val layers: Array[Int] = mlpModel.layers

  def transform(dataset: Dataset[_]): DataFrame = {
    pipeline.transform(dataset)
      .drop(mlpModel.getFeaturesCol)
      .drop(mlpModel.getLabelCol)
      .drop(PREDICTED_LABEL_INDEX_COL)
  }

  
  // Writes this wrapper using the MLWriter defined in the companion object below.
  override def write: MLWriter =
    new MultilayerPerceptronClassifierWrapper.MultilayerPerceptronClassifierWrapperWriter(this)
}

private[r] object MultilayerPerceptronClassifierWrapper
  extends MLReadable[MultilayerPerceptronClassifierWrapper] {

  // Referenced by transform() above; the fit() helper from the full source file is omitted here.
  val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"

  override def read: MLReader[MultilayerPerceptronClassifierWrapper] =
    new MultilayerPerceptronClassifierWrapperReader

  override def load(path: String): MultilayerPerceptronClassifierWrapper = super.load(path)

  class MultilayerPerceptronClassifierWrapperReader
    extends MLReader[MultilayerPerceptronClassifierWrapper] {

    override def load(path: String): MultilayerPerceptronClassifierWrapper = {
      implicit val format = DefaultFormats
      val pipelinePath = new Path(path, "pipeline").toString

      val pipeline = PipelineModel.load(pipelinePath)
      new MultilayerPerceptronClassifierWrapper(pipeline)
    }
  }

  class MultilayerPerceptronClassifierWrapperWriter(instance: MultilayerPerceptronClassifierWrapper)
    extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      val rMetadata = "class" -> instance.getClass.getName
      val rMetadataJson: String = compact(render(rMetadata))
      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)

      instance.pipeline.save(pipelinePath)
    }
  }
} 
Example 12
Source File: SerializableSparkModel.scala    From seahorse-workflow-executor    with Apache License 2.0
package io.deepsense.deeplang.doperables.serialization

import org.apache.spark.ml.Model
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

import io.deepsense.sparkutils.ML

class SerializableSparkModel[M <: Model[M]](val sparkModel: M)
  extends ML.Model[SerializableSparkModel[M]]
  with MLWritable {

  override def copy(extra: ParamMap): SerializableSparkModel[M] =
    new SerializableSparkModel(sparkModel.copy(extra))

  override def write: MLWriter = {
    sparkModel match {
      case w: MLWritable => w.write
      case _ => new DefaultMLWriter(this)
    }
  }

  override def transformDF(dataset: DataFrame): DataFrame = sparkModel.transform(dataset)

  override def transformSchema(schema: StructType): StructType = sparkModel.transformSchema(schema)

  override val uid: String = "dc7178fe-b209-44f5-8a74-d3c4dafa0fae"
}

// This class may seem unused, but it is used reflectively by spark deserialization mechanism
object SerializableSparkModel extends MLReadable[SerializableSparkModel[_]] {
  override def read: MLReader[SerializableSparkModel[_]] = {
    new DefaultMLReader[SerializableSparkModel[_]]()
  }
}