org.apache.spark.ml.attribute.NominalAttribute Scala Examples

The following examples show how to use org.apache.spark.ml.attribute.NominalAttribute. They are drawn from open-source projects; each example's header names the source file, the project it comes from, and its license.
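Before working through them, here is a minimal, self-contained sketch of the pattern that nearly all of these examples rely on: build a NominalAttribute, attach it to a DataFrame column as metadata, then recover the label/index mapping from the column's StructField. This sketch assumes Spark 2.x; the column and value names are illustrative.

import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object NominalAttributeDemo extends App {
  val spark = SparkSession.builder().master("local[*]").appName("nominal-attr").getOrCreate()
  import spark.implicits._

  // Indexed categories: 0.0 -> "red", 1.0 -> "green", 2.0 -> "blue"
  val df = Seq(0.0, 1.0, 2.0, 1.0).toDF("color")

  // Attach the label/index mapping to the column as ML metadata
  val attr = NominalAttribute.defaultAttr.withName("color").withValues("red", "green", "blue")
  val withMeta = df.select(col("color").as("color", attr.toMetadata()))

  // Recover the mapping from the schema, exactly as the examples below do
  val recovered = Attribute.fromStructField(withMeta.schema("color")).asInstanceOf[NominalAttribute]
  println(recovered.values.get.mkString(", ")) // red, green, blue
}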
Example 1
Source File: RWrapperUtils.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.ml.r

import org.apache.spark.internal.Logging
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
import org.apache.spark.ml.feature.{RFormula, RFormulaModel}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.Dataset

private[r] object RWrapperUtils extends Logging {

  
  def getFeaturesAndLabels(
      rFormulaModel: RFormulaModel,
      data: Dataset[_]): (Array[String], Array[String]) = {
    val schema = rFormulaModel.transform(data).schema
    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
      .attributes.get
    val features = featureAttrs.map(_.name.get)
    val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
      .asInstanceOf[NominalAttribute]
    val labels = labelAttr.values.get
    (features, labels)
  }
} 
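getFeaturesAndLabels is private[r], but the same extraction works against any fitted RFormulaModel. A hedged sketch of the idea, assuming Spark 2.x and a DataFrame df with a string label column "species" (names are illustrative):

import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
import org.apache.spark.ml.feature.RFormula

val model = new RFormula().setFormula("species ~ .").fit(df)
val schema = model.transform(df).schema

// Feature names come from the assembled features column's AttributeGroup
val features = AttributeGroup.fromStructField(schema(model.getFeaturesCol))
  .attributes.get.map(_.name.get)

// Label values come from the indexed label column's NominalAttribute
val labels = Attribute.fromStructField(schema(model.getLabelCol))
  .asInstanceOf[NominalAttribute].values.get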
Example 2
Source File: QuantileDiscretizerSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.ml.feature

import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkContext, SparkFunSuite}

class QuantileDiscretizerSuite
  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  import org.apache.spark.ml.feature.QuantileDiscretizerSuite._

  test("Test quantile discretizer") {
    checkDiscretizedData(sc,
      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
      10,
      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
      Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))

    checkDiscretizedData(sc,
      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
      4,
      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
      Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))

    checkDiscretizedData(sc,
      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
      3,
      Array[Double](0, 1, 2, 2, 2, 2, 2, 2, 2),
      Array("-Infinity, 2.0", "2.0, 3.0", "3.0, Infinity"))

    checkDiscretizedData(sc,
      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
      2,
      Array[Double](0, 1, 1, 1, 1, 1, 1, 1, 1),
      Array("-Infinity, 2.0", "2.0, Infinity"))

  }

  test("Test getting splits") {
    val splitTestPoints = Array(
      Array[Double]() -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
      Array(Double.NegativeInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
      Array(Double.PositiveInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
      Array(Double.NegativeInfinity, Double.PositiveInfinity)
        -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
      Array(0.0) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
      Array(1.0) -> Array(Double.NegativeInfinity, 1, Double.PositiveInfinity),
      Array(0.0, 1.0) -> Array(Double.NegativeInfinity, 0, 1, Double.PositiveInfinity)
    )
    for ((ori, res) <- splitTestPoints) {
      assert(QuantileDiscretizer.getSplits(ori) === res, "Returned splits are invalid.")
    }
  }

  test("read/write") {
    val t = new QuantileDiscretizer()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setNumBuckets(6)
    testDefaultReadWrite(t)
  }
}

private object QuantileDiscretizerSuite extends SparkFunSuite {

  def checkDiscretizedData(
      sc: SparkContext,
      data: Array[Double],
      numBucket: Int,
      expectedResult: Array[Double],
      expectedAttrs: Array[String]): Unit = {
    val sqlCtx = SQLContext.getOrCreate(sc)
    import sqlCtx.implicits._

    val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")
    val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result")
      .setNumBuckets(numBucket)
    val result = discretizer.fit(df).transform(df)

    val transformedFeatures = result.select("result").collect()
      .map { case Row(transformedFeature: Double) => transformedFeature }
    val transformedAttrs = Attribute.fromStructField(result.schema("result"))
      .asInstanceOf[NominalAttribute].values.get

    assert(transformedFeatures === expectedResult,
      "Transformed features do not equal expected features.")
    assert(transformedAttrs === expectedAttrs,
      "Transformed attributes do not equal expected attributes.")
  }
} 
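Outside the test harness, the same round trip is just fit/transform; the output column carries a NominalAttribute whose values are the bucket ranges, as the helper above asserts. A sketch (the DataFrame df and its input column name are assumptions):

val discretizer = new QuantileDiscretizer()
  .setInputCol("hour")
  .setOutputCol("hourBucket")
  .setNumBuckets(3)
val bucketed = discretizer.fit(df).transform(df)

// The bucket ranges live in the output column's nominal metadata
val bucketAttr = Attribute.fromStructField(bucketed.schema("hourBucket"))
  .asInstanceOf[NominalAttribute]
bucketAttr.values.get.foreach(println) // e.g. "-Infinity, 2.0", "2.0, 3.0", "3.0, Infinity"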
Example 3
Source File: OneHotEncoderSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

class OneHotEncoderSuite
  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  def stringIndexed(): DataFrame = {
    val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
    val df = sqlContext.createDataFrame(data).toDF("id", "label")
    val indexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
      .fit(df)
    indexer.transform(df)
  }

  test("params") {
    ParamsSuite.checkParams(new OneHotEncoder)
  }

  test("OneHotEncoder dropLast = false") {
    val transformed = stringIndexed()
    val encoder = new OneHotEncoder()
      .setInputCol("labelIndex")
      .setOutputCol("labelVec")
      .setDropLast(false)
    val encoded = encoder.transform(transformed)

    val output = encoded.select("id", "labelVec").map { r =>
      val vec = r.getAs[Vector](1)
      (r.getInt(0), vec(0), vec(1), vec(2))
    }.collect().toSet
    // a -> 0, b -> 2, c -> 1
    val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
      (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0))
    assert(output === expected)
  }

  test("OneHotEncoder dropLast = true") {
    val transformed = stringIndexed()
    val encoder = new OneHotEncoder()
      .setInputCol("labelIndex")
      .setOutputCol("labelVec")
    val encoded = encoder.transform(transformed)

    val output = encoded.select("id", "labelVec").map { r =>
      val vec = r.getAs[Vector](1)
      (r.getInt(0), vec(0), vec(1))
    }.collect().toSet
    // a -> 0, b -> 2, c -> 1
    val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0),
      (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0))
    assert(output === expected)
  }

  test("input column with ML attribute") {
    val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
    val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size")
      .select(col("size").as("size", attr.toMetadata()))
    val encoder = new OneHotEncoder()
      .setInputCol("size")
      .setOutputCol("encoded")
    val output = encoder.transform(df)
    val group = AttributeGroup.fromStructField(output.schema("encoded"))
    assert(group.size === 2)
    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
  }

  test("input column without ML attribute") {
    val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index")
    val encoder = new OneHotEncoder()
      .setInputCol("index")
      .setOutputCol("encoded")
    val output = encoder.transform(df)
    val group = AttributeGroup.fromStructField(output.schema("encoded"))
    assert(group.size === 2)
    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
  }

  test("read/write") {
    val t = new OneHotEncoder()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setDropLast(false)
    testDefaultReadWrite(t)
  }
} 
Example 4
Source File: RWrapperUtils.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.ml.r

import org.apache.spark.internal.Logging
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
import org.apache.spark.ml.feature.{RFormula, RFormulaModel}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.Dataset

private[r] object RWrapperUtils extends Logging {

  
  def getFeaturesAndLabels(
      rFormulaModel: RFormulaModel,
      data: Dataset[_]): (Array[String], Array[String]) = {
    val schema = rFormulaModel.transform(data).schema
    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
      .attributes.get
    val features = featureAttrs.map(_.name.get)
    val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
      .asInstanceOf[NominalAttribute]
    val labels = labelAttr.values.get
    (features, labels)
  }
} 
Example 5
Source File: PredictionDeIndexer.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages.impl.preparators

import com.salesforce.op.UID
import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.binary.{BinaryEstimator, BinaryModel}
import com.salesforce.op.stages.impl.feature.{OpIndexToStringNoFilter, SaveOthersParams}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.sql.Dataset

import scala.util.{Failure, Success, Try}


  // NOTE: the enclosing class declaration was dropped when this excerpt was extracted;
  // given the imports and the use of $(unseenName), fitFn below belongs to a
  // BinaryEstimator[RealNN, RealNN, Text] subclass mixing in SaveOthersParams.
  override def fitFn(dataset: Dataset[(Option[Double], Option[Double])]): BinaryModel[RealNN, RealNN, Text] = {
    val colSchema = getInputSchema()(in1.name)
    val labels: Array[String] = Try(
      Attribute.fromStructField(colSchema).asInstanceOf[NominalAttribute].values.get
    ) match {
      case Success(l) => l
      case Failure(_) => throw new Error(s"The feature ${in1.name} does not contain" +
        s" any label/index mapping in its metadata")
    }

    new PredictionDeIndexerModel(labels, $(unseenName), operationName, uid)
  }
}

final class PredictionDeIndexerModel private[op]
(
  val labels: Array[String],
  val unseen: String,
  operationName: String,
  uid: String
) extends BinaryModel[RealNN, RealNN, Text](operationName = operationName, uid = uid) {

  def transformFn: (RealNN, RealNN) => Text = (response: RealNN, pred: RealNN) => {
    val idx = pred.value.get.toInt
    if (0 <= idx && idx < labels.length) labels(idx).toText
    else unseen.toText
  }

} 
Example 6
Source File: OpIndexToStringNoFilter.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.UID
import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.unary.UnaryTransformer
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param.StringArrayParam


  // NOTE: the enclosing class declaration was dropped when this excerpt was extracted;
  // given the imports and the use of $(labels) and $(unseenName), transformFn below
  // belongs to a unary RealNN => Text transformer exposing a StringArrayParam `labels`.
  override def transformFn: (RealNN) => Text = {
    (input: RealNN) => {
      val inputColSchema = getInputSchema()(in1.name)
      // If the labels array is empty use column metadata
      val lbls = $(labels)
      val unseen = $(unseenName)
      val values = if (!isDefined(labels) || lbls.isEmpty) {
        Attribute.fromStructField(inputColSchema)
          .asInstanceOf[NominalAttribute].values.get
      } else {
        lbls
      }
      val idx = input.value.get.toInt
      if (0 <= idx && idx < values.length) {
        values(idx).toText
      } else {
        unseen.toText
      }
    }
  }
}

object OpIndexToStringNoFilter {
  val unseenDefault: String = "UnseenIndex"
} 
Example 7
Source File: OpStringIndexerNoFilter.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.UID
import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.unary.{UnaryEstimator, UnaryModel}
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.sql.Dataset

import scala.reflect.runtime.universe.TypeTag


class OpStringIndexerNoFilter[I <: Text]
(
  uid: String = UID[OpStringIndexerNoFilter[I]]
)(implicit tti: TypeTag[I], ttiv: TypeTag[I#Value])
  extends UnaryEstimator[I, RealNN](operationName = "str2idx", uid = uid) with SaveOthersParams {

  setDefault(unseenName, OpStringIndexerNoFilter.UnseenNameDefault)

  def fitFn(data: Dataset[I#Value]): UnaryModel[I, RealNN] = {
    val unseen = $(unseenName)
    val counts = data.rdd.countByValue()
    val labels = counts.toSeq
      .sortBy { case (label, count) => (-count, label) }
      .map { case (label, _) => label }
      .toArray

    val otherPos = labels.length

    val cleanedLabels = labels.map(_.getOrElse("null")) :+ unseen
    val metadata = NominalAttribute.defaultAttr.withName(getOutputFeatureName).withValues(cleanedLabels).toMetadata()
    setMetadata(metadata)

    new OpStringIndexerNoFilterModel[I](labels, otherPos, operationName = operationName, uid = uid)
  }
}

final class OpStringIndexerNoFilterModel[I <: Text] private[op]
(
  val labels: Seq[Option[String]],
  val otherPos: Int,
  operationName: String,
  uid: String
)(implicit tti: TypeTag[I]) extends UnaryModel[I, RealNN](operationName = operationName, uid = uid) {

  private val labelsMap = labels.zipWithIndex.toMap
  def transformFn: I => RealNN = in => labelsMap.getOrElse(in.value, otherPos).toRealNN
}

object OpStringIndexerNoFilter {
  val UnseenNameDefault = "UnseenLabel"
} 
Example 8
Source File: StringIndexerSuite.scala    From iolap   with Apache License 2.0
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext

class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("params") {
    ParamsSuite.checkParams(new StringIndexer)
    val model = new StringIndexerModel("indexer", Array("a", "b"))
    ParamsSuite.checkParams(model)
  }

  test("StringIndexer") {
    val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
    val df = sqlContext.createDataFrame(data).toDF("id", "label")
    val indexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
      .fit(df)
    val transformed = indexer.transform(df)
    val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
      .asInstanceOf[NominalAttribute]
    assert(attr.values.get === Array("a", "c", "b"))
    val output = transformed.select("id", "labelIndex").map { r =>
      (r.getInt(0), r.getDouble(1))
    }.collect().toSet
    // a -> 0, b -> 2, c -> 1
    val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
    assert(output === expected)
  }

  test("StringIndexer with a numeric input column") {
    val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2)
    val df = sqlContext.createDataFrame(data).toDF("id", "label")
    val indexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
      .fit(df)
    val transformed = indexer.transform(df)
    val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
      .asInstanceOf[NominalAttribute]
    assert(attr.values.get === Array("100", "300", "200"))
    val output = transformed.select("id", "labelIndex").map { r =>
      (r.getInt(0), r.getDouble(1))
    }.collect().toSet
    // 100 -> 0, 200 -> 2, 300 -> 1
    val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
    assert(output === expected)
  }

  test("StringIndexerModel should keep silent if the input column does not exist.") {
    val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
      .setInputCol("label")
      .setOutputCol("labelIndex")
    val df = sqlContext.range(0L, 10L)
    assert(indexerModel.transform(df).eq(df))
  }
} 
Example 9
Source File: VectorAssemblerSuite.scala    From iolap   with Apache License 2.0
package org.apache.spark.ml.feature

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col

class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("params") {
    ParamsSuite.checkParams(new VectorAssembler)
  }

  test("assemble") {
    import org.apache.spark.ml.feature.VectorAssembler.assemble
    assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
    assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0)))
    val dv = Vectors.dense(2.0, 0.0)
    assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
    val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0))
    assert(assemble(0.0, dv, 1.0, sv) ===
      Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0)))
    for (v <- Seq(1, "a", null)) {
      intercept[SparkException](assemble(v))
      intercept[SparkException](assemble(1.0, v))
    }
  }

  test("assemble should compress vectors") {
    import org.apache.spark.ml.feature.VectorAssembler.assemble
    val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0))
    assert(v1.isInstanceOf[SparseVector])
    val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0)))
    assert(v2.isInstanceOf[DenseVector])
  }

  test("VectorAssembler") {
    val df = sqlContext.createDataFrame(Seq(
      (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L)
    )).toDF("id", "x", "y", "name", "z", "n")
    val assembler = new VectorAssembler()
      .setInputCols(Array("x", "y", "z", "n"))
      .setOutputCol("features")
    assembler.transform(df).select("features").collect().foreach {
      case Row(v: Vector) =>
        assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0)))
    }
  }

  test("ML attributes") {
    val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari")
    val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0)
    val user = new AttributeGroup("user", Array(
      NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"),
      NumericAttribute.defaultAttr.withName("salary")))
    val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0)))
    val df = sqlContext.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad")
      .select(
        col("browser").as("browser", browser.toMetadata()),
        col("hour").as("hour", hour.toMetadata()),
        col("count"), // "count" is an integer column without ML attribute
        col("user").as("user", user.toMetadata()),
        col("ad")) // "ad" is a vector column without ML attribute
    val assembler = new VectorAssembler()
      .setInputCols(Array("browser", "hour", "count", "user", "ad"))
      .setOutputCol("features")
    val output = assembler.transform(df)
    val schema = output.schema
    val features = AttributeGroup.fromStructField(schema("features"))
    assert(features.size === 7)
    val browserOut = features.getAttr(0)
    assert(browserOut === browser.withIndex(0).withName("browser"))
    val hourOut = features.getAttr(1)
    assert(hourOut === hour.withIndex(1).withName("hour"))
    val countOut = features.getAttr(2)
    assert(countOut === NumericAttribute.defaultAttr.withName("count").withIndex(2))
    val userGenderOut = features.getAttr(3)
    assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3))
    val userSalaryOut = features.getAttr(4)
    assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4))
    assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5))
    assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6))
  }
} 
Example 10
Source File: OneHotEncoderSuite.scala    From iolap   with Apache License 2.0
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {

  def stringIndexed(): DataFrame = {
    val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
    val df = sqlContext.createDataFrame(data).toDF("id", "label")
    val indexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
      .fit(df)
    indexer.transform(df)
  }

  test("params") {
    ParamsSuite.checkParams(new OneHotEncoder)
  }

  test("OneHotEncoder dropLast = false") {
    val transformed = stringIndexed()
    val encoder = new OneHotEncoder()
      .setInputCol("labelIndex")
      .setOutputCol("labelVec")
      .setDropLast(false)
    val encoded = encoder.transform(transformed)

    val output = encoded.select("id", "labelVec").map { r =>
      val vec = r.getAs[Vector](1)
      (r.getInt(0), vec(0), vec(1), vec(2))
    }.collect().toSet
    // a -> 0, b -> 2, c -> 1
    val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
      (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0))
    assert(output === expected)
  }

  test("OneHotEncoder dropLast = true") {
    val transformed = stringIndexed()
    val encoder = new OneHotEncoder()
      .setInputCol("labelIndex")
      .setOutputCol("labelVec")
    val encoded = encoder.transform(transformed)

    val output = encoded.select("id", "labelVec").map { r =>
      val vec = r.getAs[Vector](1)
      (r.getInt(0), vec(0), vec(1))
    }.collect().toSet
    // a -> 0, b -> 2, c -> 1
    val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0),
      (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0))
    assert(output === expected)
  }

  test("input column with ML attribute") {
    val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
    val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size")
      .select(col("size").as("size", attr.toMetadata()))
    val encoder = new OneHotEncoder()
      .setInputCol("size")
      .setOutputCol("encoded")
    val output = encoder.transform(df)
    val group = AttributeGroup.fromStructField(output.schema("encoded"))
    assert(group.size === 2)
    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0))
    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1))
  }

  test("input column without ML attribute") {
    val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index")
    val encoder = new OneHotEncoder()
      .setInputCol("index")
      .setOutputCol("encoded")
    val output = encoder.transform(df)
    val group = AttributeGroup.fromStructField(output.schema("encoded"))
    assert(group.size === 2)
    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0))
    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1))
  }
} 
Example 11
Source File: OneHotEncoderOp.scala    From mleap   with Apache License 2.0
package org.apache.spark.ml.bundle.ops.feature

import ml.combust.bundle.BundleContext
import ml.combust.bundle.dsl._
import ml.combust.bundle.op.OpModel
import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.bundle._
import org.apache.spark.ml.feature.OneHotEncoderModel
import org.apache.spark.sql.types.StructField

import scala.util.{Failure, Try}

object OneHotEncoderOp {
  def sizeForField(field: StructField): Int = {
    val attr = Attribute.fromStructField(field)

    (attr match {
      case nominal: NominalAttribute =>
        if (nominal.values.isDefined) {
          Try(nominal.values.get.length)
        } else if (nominal.numValues.isDefined) {
          Try(nominal.numValues.get)
        } else {
          Failure(new RuntimeException(s"invalid nominal value for field ${field.name}"))
        }
      case binary: BinaryAttribute =>
        Try(2)
      case _: NumericAttribute =>
        Failure(new RuntimeException(s"invalid numeric attribute for field ${field.name}"))
      case _ =>
        Failure(new RuntimeException(s"unsupported attribute for field ${field.name}")) // optimistic about unknown attributes
    }).get
  }
}

class OneHotEncoderOp extends SimpleSparkOp[OneHotEncoderModel] {
  override val Model: OpModel[SparkBundleContext, OneHotEncoderModel] = new OpModel[SparkBundleContext, OneHotEncoderModel] {
    override val klazz: Class[OneHotEncoderModel] = classOf[OneHotEncoderModel]

    override def opName: String = Bundle.BuiltinOps.feature.one_hot_encoder

    override def store(model: Model, obj: OneHotEncoderModel)
                      (implicit context: BundleContext[SparkBundleContext]): Model = {
      assert(context.context.dataset.isDefined, BundleHelper.sampleDataframeMessage(klazz))

      val df = context.context.dataset.get
      val categorySizes = obj.getInputCols.map { f ⇒ OneHotEncoderOp.sizeForField(df.schema(f)) }

      model.withValue("category_sizes", Value.intList(categorySizes))
        .withValue("drop_last", Value.boolean(obj.getDropLast))
        .withValue("handle_invalid", Value.string(obj.getHandleInvalid))

    }

    override def load(model: Model)
                     (implicit context: BundleContext[SparkBundleContext]): OneHotEncoderModel = {
      new OneHotEncoderModel(uid = "", categorySizes = model.value("category_sizes").getIntList.toArray)
          .setDropLast(model.value("drop_last").getBoolean)
          .setHandleInvalid(model.value("handle_invalid").getString)
    }
  }

  override def sparkLoad(uid: String, shape: NodeShape, model: OneHotEncoderModel): OneHotEncoderModel = {
    new OneHotEncoderModel(uid = uid, categorySizes = model.categorySizes)
      .setDropLast(model.getDropLast)
      .setHandleInvalid(model.getHandleInvalid)
  }

  override def sparkInputs(obj: OneHotEncoderModel): Seq[ParamSpec] = Seq(ParamSpec("input", obj.inputCols))

  override def sparkOutputs(obj: OneHotEncoderModel): Seq[ParamSpec] = Seq(ParamSpec("output", obj.outputCols))

} 
Example 12
Source File: StringToShortIndexerSpec.scala    From spark-ext   with Apache License 2.0
package org.apache.spark.ml.feature

import com.collective.TestSparkContext
import org.apache.spark.ml.attribute.{NominalAttribute, Attribute}
import org.scalatest.FlatSpec

class StringToShortIndexerSpec extends FlatSpec with TestSparkContext {

  "StringToShortIndexer" should "assign correct index for columns" in {
    val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
    val df = sqlContext.createDataFrame(data).toDF("id", "label")
    val indexer = new StringToShortIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
      .fit(df)

    val transformed = indexer.transform(df)
    val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
      .asInstanceOf[NominalAttribute]
    assert(attr.values.get === Array("a", "c", "b"))
    val output = transformed.select("id", "labelIndex").map { r =>
      (r.getInt(0), r.getShort(1))
    }.collect().toSet
    // a -> 0, b -> 2, c -> 1
    val expected = Set((0, 0), (1, 2), (2, 1), (3, 0), (4, 0), (5, 1))
    assert(output === expected)
  }

} 
Example 13
Source File: S2CellTransformer.scala    From spark-ext   with Apache License 2.0
package org.apache.spark.ml.feature

import com.google.common.geometry.{S2LatLng, S2CellId}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}


class S2CellTransformer(override val uid: String) extends Transformer {

  def this() = this(Identifiable.randomUID("S2CellTransformer"))

  // Input/Output column names

  val latCol: Param[String] = new Param[String](this, "latCol", "latitude column")

  val lonCol: Param[String] = new Param[String](this, "lonCol", "longitude column")

  val cellCol: Param[String] = new Param[String](this, "cellCol", "S2 Cell Id column")

  val level: Param[Int] = new IntParam(this, "level", "S2 Level [0, 30]",
    (i: Int) => ParamValidators.gtEq(0)(i) && ParamValidators.ltEq(30)(i))

  // Default parameters

  setDefault(
    latCol  -> "lat",
    lonCol  -> "lon",
    cellCol -> "cell",
    level   -> 10
  )

  def getLatCol: String = $(latCol)

  def getLonCol: String = $(lonCol)

  def getCellCol: String = $(cellCol)

  def getLevel: Int = $(level)

  def setLatCol(value: String): this.type = set(latCol, value)

  def setLonCol(value: String): this.type = set(lonCol, value)

  def setCellCol(value: String): this.type = set(cellCol, value)

  def setLevel(value: Int): this.type = set(level, value)

  override def transform(dataset: DataFrame): DataFrame = {
    val outputSchema = transformSchema(dataset.schema)
    val currentLevel = $(level)
    val t = udf { (lat: Double, lon: Double) =>
      val cellId = S2CellId.fromLatLng(S2LatLng.fromDegrees(lat, lon))
      cellId.parent(currentLevel).toToken
    }
    val metadata = outputSchema($(cellCol)).metadata
    dataset.select(col("*"), t(col($(latCol)), col($(lonCol))).as($(cellCol), metadata))
  }

  override def transformSchema(schema: StructType): StructType = {
    val latColumnName = $(latCol)
    val latDataType = schema(latColumnName).dataType
    require(latDataType == DoubleType,
      s"The latitude column $latColumnName must be Double type, " +
        s"but got $latDataType.")

    val lonColumnName = $(lonCol)
    val lonDataType = schema(lonColumnName).dataType
    require(lonDataType == DoubleType,
      s"The longitude column $lonColumnName must be Double type, " +
        s"but got $lonDataType.")

    val inputFields = schema.fields
    val outputColName = $(cellCol)
    require(inputFields.forall(_.name != outputColName),
      s"Output column $outputColName already exists.")

    val attr = NominalAttribute.defaultAttr.withName($(cellCol))
    val outputFields = inputFields :+ attr.toStructField()
    StructType(outputFields)
  }

  override def copy(extra: ParamMap): S2CellTransformer = defaultCopy(extra)
} 
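A short usage sketch for the transformer above. The coordinates, the available sqlContext implicits, and the Spark 1.x-era DataFrame API are all assumptions matching the surrounding code:

import sqlContext.implicits._

val df = Seq((40.7128, -74.0060), (51.5074, -0.1278)).toDF("lat", "lon")
val cells = new S2CellTransformer().setLevel(6).transform(df)
cells.show() // adds a nominal "cell" column holding each point's level-6 S2 cell token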
Example 14
Source File: StringToShortIndexer.scala    From spark-ext   with Apache License 2.0
package org.apache.spark.ml.feature

import org.apache.spark.SparkException
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashMap


class StringToShortIndexer(override val uid: String) extends Estimator[StringToShortIndexerModel]
with StringIndexerBase {

  def this() = this(Identifiable.randomUID("strShortIdx"))

  def setInputCol(value: String): this.type = set(inputCol, value)

  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def fit(dataset: DataFrame): StringToShortIndexerModel = {
    val counts = dataset.select(col($(inputCol)).cast(StringType))
      .map(_.getString(0))
      .countByValue()
    val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
    require(labels.length <= Short.MaxValue,
      s"Unique labels count (${labels.length}) should be less then Short.MaxValue (${Short.MaxValue})")
    copyValues(new StringToShortIndexerModel(uid, labels).setParent(this))
  }

  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  override def copy(extra: ParamMap): StringToShortIndexer = defaultCopy(extra)
}

class StringToShortIndexerModel (
  override val uid: String,
  val labels: Array[String]) extends Model[StringToShortIndexerModel] with StringIndexerBase {

  def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels)

  require(labels.length <= Short.MaxValue,
    s"Unique labels count (${labels.length}) should be less then Short.MaxValue (${Short.MaxValue})")

  private val labelToIndex: OpenHashMap[String, Short] = {
    val n = labels.length.toShort
    val map = new OpenHashMap[String, Short](n)
    var i: Short = 0
    while (i < n) {
      map.update(labels(i), i)
      i = (i + 1).toShort
    }
    map
  }

  def setInputCol(value: String): this.type = set(inputCol, value)

  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def transform(dataset: DataFrame): DataFrame = {
    if (!dataset.schema.fieldNames.contains($(inputCol))) {
      logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
        "Skip StringToShortIndexerModel.")
      return dataset
    }

    val indexer = udf { label: String =>
      if (labelToIndex.contains(label)) {
        labelToIndex(label)
      } else {
        // TODO: handle unseen labels
        throw new SparkException(s"Unseen label: $label.")
      }
    }
    val outputColName = $(outputCol)
    val metadata = NominalAttribute.defaultAttr
      .withName(outputColName).withValues(labels).toMetadata()
    dataset.select(col("*"),
      indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata))
  }

  override def transformSchema(schema: StructType): StructType = {
    if (schema.fieldNames.contains($(inputCol))) {
      validateAndTransformSchema(schema)
    } else {
      // If the input column does not exist during transformation, we skip StringToShortIndexerModel.
      schema
    }
  }

  override def copy(extra: ParamMap): StringToShortIndexerModel = {
    val copied = new StringToShortIndexerModel(uid, labels)
    copyValues(copied, extra).setParent(parent)
  }
} 
Example 15
Source File: TreeUtils.scala    From spark-sql-perf   with Apache License 2.0
package org.apache.spark.ml

import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.sql.DataFrame

object TreeUtils {
  
  def setMetadata(
      data: DataFrame,
      featuresColName: String,
      featureArity: Array[Int]): DataFrame = {
    val featuresAttributes = featureArity.zipWithIndex.map { case (arity: Int, feature: Int) =>
      if (arity > 0) {
        NominalAttribute.defaultAttr.withIndex(feature).withNumValues(arity)
      } else {
        NumericAttribute.defaultAttr.withIndex(feature)
      }
    }
    val featuresMetadata = new AttributeGroup("features", featuresAttributes).toMetadata()
    data.select(data(featuresColName).as(featuresColName, featuresMetadata))
  }
} 
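For example, to mark the second of three assembled features as categorical with four categories while leaving the others numeric, a call might look like this (the DataFrame and column name are assumptions):

// arity 0 => numeric feature; arity > 0 => nominal feature with that many values
val withMeta = TreeUtils.setMetadata(df, "features", Array(0, 4, 0))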
Example 16
Source File: RWrapperUtils.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.ml.r

import org.apache.spark.internal.Logging
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
import org.apache.spark.ml.feature.{RFormula, RFormulaModel}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.Dataset

private[r] object RWrapperUtils extends Logging {

  
  def getFeaturesAndLabels(
      rFormulaModel: RFormulaModel,
      data: Dataset[_]): (Array[String], Array[String]) = {
    val schema = rFormulaModel.transform(data).schema
    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
      .attributes.get
    val features = featureAttrs.map(_.name.get)
    val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
      .asInstanceOf[NominalAttribute]
    val labels = labelAttr.values.get
    (features, labels)
  }
} 
Example 17
Source File: OneVsRestOp.scala    From mleap   with Apache License 2.0
package org.apache.spark.ml.bundle.extension.ops.classification

import ml.combust.bundle.BundleContext
import ml.combust.bundle.dsl._
import ml.combust.bundle.op.OpModel
import ml.combust.bundle.serializer.ModelSerializer
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.bundle.{ParamSpec, SimpleParamSpec, SimpleSparkOp, SparkBundleContext}
import org.apache.spark.ml.classification.ClassificationModel
import org.apache.spark.ml.mleap.classification.OneVsRestModel


class OneVsRestOp extends SimpleSparkOp[OneVsRestModel] {
  override val Model: OpModel[SparkBundleContext, OneVsRestModel] = new OpModel[SparkBundleContext, OneVsRestModel] {
    override val klazz: Class[OneVsRestModel] = classOf[OneVsRestModel]

    override def opName: String = Bundle.BuiltinOps.classification.one_vs_rest

    override def store(model: Model, obj: OneVsRestModel)
                      (implicit context: BundleContext[SparkBundleContext]): Model = {
      var i = 0
      for (cModel <- obj.models) {
        val name = s"model$i"
        // .get fails fast if a sub-model cannot be serialized
        ModelSerializer(context.bundleContext(name)).write(cModel).get
        i = i + 1
      }

      model.withValue("num_classes", Value.long(obj.models.length))
        .withValue("num_features", Value.long(obj.models.head.numFeatures))
    }

    override def load(model: Model)
                     (implicit context: BundleContext[SparkBundleContext]): OneVsRestModel = {
      val numClasses = model.value("num_classes").getLong.toInt

      val models = (0 until numClasses).toArray.map {
        i => ModelSerializer(context.bundleContext(s"model$i")).read().get.asInstanceOf[ClassificationModel[_, _]]
      }

      val labelMetadata = NominalAttribute.defaultAttr.
        withName("prediction").
        withNumValues(models.length).
        toMetadata
      new OneVsRestModel(uid = "", models = models, labelMetadata = labelMetadata)
    }
  }

  override def sparkLoad(uid: String, shape: NodeShape, model: OneVsRestModel): OneVsRestModel = {
    val labelMetadata = NominalAttribute.defaultAttr.
      withName(shape.output("prediction").name).
      withNumValues(model.models.length).
      toMetadata
    new OneVsRestModel(uid = uid, models = model.models, labelMetadata = labelMetadata)
  }

  override def sparkInputs(obj: OneVsRestModel): Seq[ParamSpec] = {
    Seq("features" -> obj.featuresCol)
  }

  override def sparkOutputs(obj: OneVsRestModel): Seq[SimpleParamSpec] = {
    Seq("probability" -> obj.probabilityCol,
      "prediction" -> obj.predictionCol)
  }
} 
Example 18
Source File: OneVsRestOp.scala    From mleap   with Apache License 2.0
package org.apache.spark.ml.bundle.ops.classification

import ml.combust.bundle.BundleContext
import ml.combust.bundle.op.{OpModel, OpNode}
import ml.combust.bundle.serializer.ModelSerializer
import ml.combust.bundle.dsl._
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.bundle.{ParamSpec, SimpleParamSpec, SimpleSparkOp, SparkBundleContext}
import org.apache.spark.ml.classification.{ClassificationModel, OneVsRestModel}


class OneVsRestOp extends SimpleSparkOp[OneVsRestModel] {
  override val Model: OpModel[SparkBundleContext, OneVsRestModel] = new OpModel[SparkBundleContext, OneVsRestModel] {
    override val klazz: Class[OneVsRestModel] = classOf[OneVsRestModel]

    override def opName: String = Bundle.BuiltinOps.classification.one_vs_rest

    override def store(model: Model, obj: OneVsRestModel)
                      (implicit context: BundleContext[SparkBundleContext]): Model = {
      var i = 0
      for (cModel <- obj.models) {
        val name = s"model$i"
        ModelSerializer(context.bundleContext(name)).write(cModel).get
        i = i + 1
      }

      model.withValue("num_classes", Value.long(obj.models.length)).
        withValue("num_features", Value.long(obj.models.head.numFeatures))
    }

    override def load(model: Model)
                     (implicit context: BundleContext[SparkBundleContext]): OneVsRestModel = {
      val numClasses = model.value("num_classes").getLong.toInt

      val models = (0 until numClasses).toArray.map {
        i => ModelSerializer(context.bundleContext(s"model$i")).read().get.asInstanceOf[ClassificationModel[_, _]]
      }

      val labelMetadata = NominalAttribute.defaultAttr.
        withName("prediction").
        withNumValues(models.length).
        toMetadata
      new OneVsRestModel(uid = "", models = models, labelMetadata = labelMetadata)
    }
  }

  override def sparkLoad(uid: String, shape: NodeShape, model: OneVsRestModel): OneVsRestModel = {
    val labelMetadata = NominalAttribute.defaultAttr.
      withName(shape.output("prediction").name).
      withNumValues(model.models.length).
      toMetadata

    new OneVsRestModel(uid = uid,
      labelMetadata = labelMetadata,
      models = model.models)
  }

  override def sparkInputs(obj: OneVsRestModel): Seq[ParamSpec] = {
    Seq("features" -> obj.featuresCol)
  }

  override def sparkOutputs(obj: OneVsRestModel): Seq[SimpleParamSpec] = {
    Seq("raw_prediction" -> obj.rawPredictionCol, "prediction" -> obj.predictionCol)
  }
} 
Example 19
Source File: ReverseStringIndexerOp.scala    From mleap   with Apache License 2.0
package org.apache.spark.ml.bundle.ops.feature

import ml.combust.bundle.BundleContext
import ml.combust.bundle.op.OpModel
import ml.combust.bundle.dsl._
import ml.combust.mleap.core.types.{DataShape, ScalarShape}
import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.bundle._
import org.apache.spark.ml.feature.IndexToString
import org.apache.spark.sql.types.StructField
import ml.combust.mleap.runtime.types.BundleTypeConverters._

import scala.util.{Failure, Try}


object ReverseStringIndexerOp {
  def labelsForField(field: StructField): Array[String] = {
    val attr = Attribute.fromStructField(field)

    (attr match {
      case nominal: NominalAttribute =>
        if (nominal.values.isDefined) {
          Try(nominal.values.get)
        } else {
          Failure(new RuntimeException(s"invalid nominal value for field ${field.name}"))
        }
      case _: BinaryAttribute =>
        Failure(new RuntimeException(s"invalid binary attribute for field ${field.name}"))
      case _: NumericAttribute =>
        Failure(new RuntimeException(s"invalid numeric attribute for field ${field.name}"))
      case _ =>
        Failure(new RuntimeException(s"unsupported attribute for field ${field.name}")) // optimistic about unknown attributes
    }).get
  }
}

class ReverseStringIndexerOp extends SimpleSparkOp[IndexToString] {
  override val Model: OpModel[SparkBundleContext, IndexToString] = new OpModel[SparkBundleContext, IndexToString] {
    override val klazz: Class[IndexToString] = classOf[IndexToString]

    override def opName: String = Bundle.BuiltinOps.feature.reverse_string_indexer

    override def store(model: Model, obj: IndexToString)
                      (implicit context: BundleContext[SparkBundleContext]): Model = {
      val labels = obj.get(obj.labels).getOrElse {
        assert(context.context.dataset.isDefined, BundleHelper.sampleDataframeMessage(klazz))
        val df = context.context.dataset.get
        ReverseStringIndexerOp.labelsForField(df.schema(obj.getInputCol))
      }

      model.withValue("labels", Value.stringList(labels)).
        withValue("input_shape", Value.dataShape(ScalarShape(false)))
    }

    override def load(model: Model)
                     (implicit context: BundleContext[SparkBundleContext]): IndexToString = {
      model.getValue("input_shape").map(_.getDataShape: DataShape).foreach {
        shape =>
          require(shape.isScalar, "cannot deserialize non-scalar input to Spark IndexToString model")
      }
      
      new IndexToString(uid = "").setLabels(model.value("labels").getStringList.toArray)
    }
  }

  override def sparkLoad(uid: String, shape: NodeShape, model: IndexToString): IndexToString = {
    new IndexToString(uid = uid).setLabels(model.getLabels)
  }

  override def sparkInputs(obj: IndexToString): Seq[ParamSpec] = {
    Seq("input" -> obj.inputCol)
  }

  override def sparkOutputs(obj: IndexToString): Seq[SimpleParamSpec] = {
    Seq("output" -> obj.outputCol)
  }
} 
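The labelsForField helper above can also be exercised directly against any column that a StringIndexer has annotated with nominal metadata (the DataFrame and column name are assumptions):

val labels: Array[String] =
  ReverseStringIndexerOp.labelsForField(df.schema("labelIndex"))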
Example 20
Source File: InteractionOp.scala    From mleap   with Apache License 2.0
package org.apache.spark.ml.bundle.ops.feature

import ml.bundle.DataShape
import ml.combust.bundle.BundleContext
import ml.combust.bundle.dsl._
import ml.combust.bundle.op.{OpModel, OpNode}
import ml.combust.mleap.core.annotation.SparkCode
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
import org.apache.spark.ml.bundle._
import org.apache.spark.ml.feature.Interaction
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.mleap.TypeConverters._
import ml.combust.mleap.runtime.types.BundleTypeConverters._
import org.apache.spark.sql.types.{BooleanType, NumericType}


class InteractionOp extends SimpleSparkOp[Interaction] {
  override val Model: OpModel[SparkBundleContext, Interaction] = new OpModel[SparkBundleContext, Interaction] {
    override val klazz: Class[Interaction] = classOf[Interaction]

    override def opName: String = Bundle.BuiltinOps.feature.interaction

    override def store(model: Model, obj: Interaction)
                      (implicit context: BundleContext[SparkBundleContext]): Model = {
      assert(context.context.dataset.isDefined, BundleHelper.sampleDataframeMessage(klazz))

      val dataset = context.context.dataset.get
      val spec = buildSpec(obj.getInputCols, dataset)
      val inputShapes = obj.getInputCols.map(v => sparkToMleapDataShape(dataset.schema(v), dataset): DataShape)

      val m = model.withValue("num_inputs", Value.int(spec.length)).
        withValue("input_shapes", Value.dataShapeList(inputShapes))

      spec.zipWithIndex.foldLeft(m) {
        case (m2, (numFeatures, index)) => m2.withValue(s"num_features$index", Value.intList(numFeatures))
      }
    }

    override def load(model: Model)
                     (implicit context: BundleContext[SparkBundleContext]): Interaction = {
      // No need to do anything here, everything is handled through Spark meta data
      new Interaction()
    }

    @SparkCode(uri = "https://github.com/apache/spark/blob/branch-2.1/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala")
    private def buildSpec(inputCols: Array[String], dataset: DataFrame): Array[Array[Int]] = {
      def getNumFeatures(attr: Attribute): Int = {
        attr match {
          case nominal: NominalAttribute =>
            math.max(1, nominal.getNumValues.getOrElse(
              throw new IllegalArgumentException("Nominal features must have attr numValues defined.")))
          case _ =>
            1  // numeric feature
        }
      }

      inputCols.map(dataset.schema.apply).map { f =>
        f.dataType match {
          case _: NumericType | BooleanType =>
            Array(getNumFeatures(Attribute.fromStructField(f)))
          case _: VectorUDT =>
            val attrs = AttributeGroup.fromStructField(f).attributes.getOrElse(
              throw new IllegalArgumentException("Vector attributes must be defined for interaction."))
            attrs.map(getNumFeatures)
        }
      }
    }
  }

  override def sparkLoad(uid: String, shape: NodeShape, model: Interaction): Interaction = {
    new Interaction(uid = uid)
  }

  override def sparkInputs(obj: Interaction): Seq[ParamSpec] = {
    Seq("input" -> obj.inputCols)
  }

  override def sparkOutputs(obj: Interaction): Seq[SimpleParamSpec] = {
    Seq("output" -> obj.outputCol)
  }
}
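A closing note on buildSpec above: the feature-cross width comes entirely from the attribute metadata, so a nominal column with n values contributes n slots to the product while a plain numeric column contributes one. A hedged sketch (column names and metadata are assumptions, mirroring the VectorAssembler example earlier):

import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.feature.Interaction
import org.apache.spark.sql.functions.col

val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari")
val input = df.select(col("browser").as("browser", browser.toMetadata()), col("hour"))
val crossed = new Interaction()
  .setInputCols(Array("browser", "hour"))
  .setOutputCol("browserXhour")
  .transform(input) // yields a 3-element vector: one slot per browser value, scaled by hour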