org.apache.spark.ml.attribute.BinaryAttribute Scala Examples
The following examples show how to use org.apache.spark.ml.attribute.BinaryAttribute. Each example notes the open-source project it was taken from and that project's license.
Example 1
Source File: Binarizer.scala (from drizzle-spark, Apache License 2.0)
```scala
package org.apache.spark.ml.feature

import scala.collection.mutable.ArrayBuilder

import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.BinaryAttribute
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

// NOTE: excerpt only; the Binarizer class declaration and its threshold,
// inputCol, and related members are not shown here.

  @Since("1.4.0")
  def setOutputCol(value: String): this.type = set(outputCol, value)

  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    val outputSchema = transformSchema(dataset.schema, logging = true)
    val schema = dataset.schema
    val inputType = schema($(inputCol)).dataType
    val td = $(threshold)

    // Binarize a single Double column: 1.0 above the threshold, 0.0 otherwise.
    val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 }

    // Binarize a Vector column sparsely: keep only the indices whose value
    // exceeds the threshold, each mapped to 1.0.
    val binarizerVector = udf { (data: Vector) =>
      val indices = ArrayBuilder.make[Int]
      val values = ArrayBuilder.make[Double]

      data.foreachActive { (index, value) =>
        if (value > td) {
          indices += index
          values += 1.0
        }
      }

      Vectors.sparse(data.size, indices.result(), values.result()).compressed
    }

    val metadata = outputSchema($(outputCol)).metadata

    inputType match {
      case DoubleType =>
        dataset.select(col("*"), binarizerDouble(col($(inputCol))).as($(outputCol), metadata))
      case _: VectorUDT =>
        dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata))
    }
  }

  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    val outputColName = $(outputCol)

    val outCol: StructField = inputType match {
      case DoubleType =>
        // A binarized Double column is marked with BinaryAttribute metadata.
        BinaryAttribute.defaultAttr.withName(outputColName).toStructField()
      case _: VectorUDT =>
        StructField(outputColName, new VectorUDT)
      case _ =>
        throw new IllegalArgumentException(s"Data type $inputType is not supported.")
    }

    if (schema.fieldNames.contains(outputColName)) {
      throw new IllegalArgumentException(s"Output column $outputColName already exists.")
    }
    StructType(schema.fields :+ outCol)
  }

  @Since("1.4.1")
  override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
}

@Since("1.6.0")
object Binarizer extends DefaultParamsReadable[Binarizer] {

  @Since("1.6.0")
  override def load(path: String): Binarizer = super.load(path)
}
```
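For context, here is a minimal usage sketch (not part of the source file above; the column name, threshold, and session setup are illustrative) showing how the transformer is typically applied and how the BinaryAttribute metadata can be read back from the output column:

```scala
import org.apache.spark.ml.attribute.Attribute
import org.apache.spark.ml.feature.Binarizer
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("binarizer-demo").getOrCreate()
import spark.implicits._

val df = Seq(0.1, 0.8, 0.3).toDF("feature")

val binarizer = new Binarizer()
  .setInputCol("feature")
  .setOutputCol("binarized")
  .setThreshold(0.5)

val out = binarizer.transform(df)

// transformSchema attached BinaryAttribute metadata to the output column,
// so downstream stages can recover it from the schema.
println(Attribute.fromStructField(out.schema("binarized")))

spark.stop()
```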
Example 2
Source File: OneHotEncoderOp.scala (from mleap, Apache License 2.0)
```scala
package org.apache.spark.ml.bundle.ops.feature

import ml.combust.bundle.BundleContext
import ml.combust.bundle.dsl._
import ml.combust.bundle.op.OpModel
import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.bundle._
import org.apache.spark.ml.feature.OneHotEncoderModel
import org.apache.spark.sql.types.StructField

import scala.util.{Failure, Try}

object OneHotEncoderOp {

  // Derive the category count for a column from its ML attribute metadata.
  def sizeForField(field: StructField): Int = {
    val attr = Attribute.fromStructField(field)

    (attr match {
      case nominal: NominalAttribute =>
        if (nominal.values.isDefined) {
          Try(nominal.values.get.length)
        } else if (nominal.numValues.isDefined) {
          Try(nominal.numValues.get)
        } else {
          Failure(new RuntimeException(s"invalid nominal value for field ${field.name}"))
        }
      case binary: BinaryAttribute =>
        // A binary attribute always has exactly two categories.
        Try(2)
      case _: NumericAttribute =>
        Failure(new RuntimeException(s"invalid numeric attribute for field ${field.name}"))
      case _ =>
        Failure(new RuntimeException(s"unsupported attribute for field ${field.name}")) // optimistic about unknown attributes
    }).get
  }
}

class OneHotEncoderOp extends SimpleSparkOp[OneHotEncoderModel] {

  override val Model: OpModel[SparkBundleContext, OneHotEncoderModel] =
    new OpModel[SparkBundleContext, OneHotEncoderModel] {
      override val klazz: Class[OneHotEncoderModel] = classOf[OneHotEncoderModel]

      override def opName: String = Bundle.BuiltinOps.feature.one_hot_encoder

      override def store(model: Model, obj: OneHotEncoderModel)
                        (implicit context: BundleContext[SparkBundleContext]): Model = {
        assert(context.context.dataset.isDefined, BundleHelper.sampleDataframeMessage(klazz))

        val df = context.context.dataset.get
        val categorySizes = obj.getInputCols.map { f => OneHotEncoderOp.sizeForField(df.schema(f)) }

        model.withValue("category_sizes", Value.intList(categorySizes))
          .withValue("drop_last", Value.boolean(obj.getDropLast))
          .withValue("handle_invalid", Value.string(obj.getHandleInvalid))
      }

      override def load(model: Model)
                       (implicit context: BundleContext[SparkBundleContext]): OneHotEncoderModel = {
        new OneHotEncoderModel(uid = "", categorySizes = model.value("category_sizes").getIntList.toArray)
          .setDropLast(model.value("drop_last").getBoolean)
          .setHandleInvalid(model.value("handle_invalid").getString)
      }
    }

  override def sparkLoad(uid: String, shape: NodeShape, model: OneHotEncoderModel): OneHotEncoderModel = {
    new OneHotEncoderModel(uid = uid, categorySizes = model.categorySizes)
      .setDropLast(model.getDropLast)
      .setHandleInvalid(model.getHandleInvalid)
  }

  override def sparkInputs(obj: OneHotEncoderModel): Seq[ParamSpec] = Seq(ParamSpec("input", obj.inputCols))

  override def sparkOutputs(obj: OneHotEncoderModel): Seq[ParamSpec] = Seq(ParamSpec("output", obj.outputCols))
}
```
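To make the attribute handling concrete, here is an illustrative sketch (the field names are hypothetical, not from the mleap sources) of the kind of StructField metadata that sizeForField expects:

```scala
import org.apache.spark.ml.attribute.{BinaryAttribute, NominalAttribute}

// Hypothetical fields, built the same way Spark ML stages annotate columns.
val nominalField = NominalAttribute.defaultAttr
  .withName("color")
  .withValues("red", "green", "blue")
  .toStructField()

val binaryField = BinaryAttribute.defaultAttr
  .withName("clicked")
  .toStructField()

// sizeForField reads the attribute metadata back out of each field:
// a nominal attribute reports its value count, a binary attribute is always 2.
assert(OneHotEncoderOp.sizeForField(nominalField) == 3)
assert(OneHotEncoderOp.sizeForField(binaryField) == 2)
```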
Example 3
Source File: ReverseStringIndexerOp.scala (from mleap, Apache License 2.0)
```scala
package org.apache.spark.ml.bundle.ops.feature

import ml.combust.bundle.BundleContext
import ml.combust.bundle.op.OpModel
import ml.combust.bundle.dsl._
import ml.combust.mleap.core.types.{DataShape, ScalarShape}
import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.bundle._
import org.apache.spark.ml.feature.IndexToString
import org.apache.spark.sql.types.StructField
import ml.combust.mleap.runtime.types.BundleTypeConverters._

import scala.util.{Failure, Try}

object ReverseStringIndexerOp {

  // Recover the label array for a column from its ML attribute metadata.
  // Only nominal attributes with explicit values are supported.
  def labelsForField(field: StructField): Array[String] = {
    val attr = Attribute.fromStructField(field)

    (attr match {
      case nominal: NominalAttribute =>
        if (nominal.values.isDefined) {
          Try(nominal.values.get)
        } else {
          Failure(new RuntimeException(s"invalid nominal value for field ${field.name}"))
        }
      case _: BinaryAttribute =>
        Failure(new RuntimeException(s"invalid binary attribute for field ${field.name}"))
      case _: NumericAttribute =>
        Failure(new RuntimeException(s"invalid numeric attribute for field ${field.name}"))
      case _ =>
        Failure(new RuntimeException(s"unsupported attribute for field ${field.name}")) // optimistic about unknown attributes
    }).get
  }
}

class ReverseStringIndexerOp extends SimpleSparkOp[IndexToString] {

  override val Model: OpModel[SparkBundleContext, IndexToString] =
    new OpModel[SparkBundleContext, IndexToString] {
      override val klazz: Class[IndexToString] = classOf[IndexToString]

      override def opName: String = Bundle.BuiltinOps.feature.reverse_string_indexer

      override def store(model: Model, obj: IndexToString)
                        (implicit context: BundleContext[SparkBundleContext]): Model = {
        val labels = obj.get(obj.labels).getOrElse {
          assert(context.context.dataset.isDefined, BundleHelper.sampleDataframeMessage(klazz))
          val df = context.context.dataset.get
          ReverseStringIndexerOp.labelsForField(df.schema(obj.getInputCol))
        }

        model.withValue("labels", Value.stringList(labels))
          .withValue("input_shape", Value.dataShape(ScalarShape(false)))
      }

      override def load(model: Model)
                       (implicit context: BundleContext[SparkBundleContext]): IndexToString = {
        model.getValue("input_shape").map(_.getDataShape: DataShape).foreach { shape =>
          require(shape.isScalar, "cannot deserialize non-scalar input to Spark IndexToString model")
        }

        new IndexToString(uid = "").setLabels(model.value("labels").getStringList.toArray)
      }
    }

  override def sparkLoad(uid: String, shape: NodeShape, model: IndexToString): IndexToString = {
    new IndexToString(uid = uid).setLabels(model.getLabels)
  }

  override def sparkInputs(obj: IndexToString): Seq[ParamSpec] = {
    Seq("input" -> obj.inputCol)
  }

  override def sparkOutputs(obj: IndexToString): Seq[SimpleParamSpec] = {
    Seq("output" -> obj.outputCol)
  }
}
```
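Again as an illustrative sketch (field names hypothetical): labelsForField succeeds only for nominal attributes that carry explicit values, and fails for binary or numeric attributes, since those carry no label strings to reverse-map:

```scala
import org.apache.spark.ml.attribute.{BinaryAttribute, NominalAttribute}

import scala.util.Try

// A nominal field with explicit values yields its label array.
val sizeField = NominalAttribute.defaultAttr
  .withName("size")
  .withValues("small", "medium", "large")
  .toStructField()
assert(ReverseStringIndexerOp.labelsForField(sizeField)
  .sameElements(Array("small", "medium", "large")))

// A binary field has no label strings, so labelsForField throws.
val flagField = BinaryAttribute.defaultAttr.withName("flag").toStructField()
assert(Try(ReverseStringIndexerOp.labelsForField(flagField)).isFailure)
```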
Example 4
Source File: Binarizer.scala (from sparkoscope, Apache License 2.0)
The code is identical to the drizzle-spark version shown in Example 1.
Example 5
Source File: Binarizer.scala (from multi-tenancy-spark, Apache License 2.0)
The code is identical to the drizzle-spark version shown in Example 1.
Example 6
Source File: Binarizer.scala (from iolap, Apache License 2.0)
```scala
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.BinaryAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}

// NOTE: excerpt only; the Binarizer class declaration and its threshold,
// inputCol, and related members are not shown here.

  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def transform(dataset: DataFrame): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    val td = $(threshold)
    val binarizer = udf { in: Double => if (in > td) 1.0 else 0.0 }
    val outputColName = $(outputCol)

    // Attach BinaryAttribute metadata to the new column so downstream
    // stages can recognize it as a 0/1 feature.
    val metadata = BinaryAttribute.defaultAttr.withName(outputColName).toMetadata()
    dataset.select(col("*"), binarizer(col($(inputCol))).as(outputColName, metadata))
  }

  override def transformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)

    val inputFields = schema.fields
    val outputColName = $(outputCol)

    require(inputFields.forall(_.name != outputColName),
      s"Output column $outputColName already exists.")

    val attr = BinaryAttribute.defaultAttr.withName(outputColName)
    val outputFields = inputFields :+ attr.toStructField()
    StructType(outputFields)
  }

  override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
}
```
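Unlike the drizzle-spark version in Example 1, this older transformer handles only a DoubleType input column; there is no vector branch. A small sketch (the column name is hypothetical) of round-tripping the BinaryAttribute metadata it writes:

```scala
import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute}
import org.apache.spark.sql.types.StructField

// Build the same metadata the transformer attaches to its output column...
val field: StructField = BinaryAttribute.defaultAttr.withName("binarized").toStructField()

// ...and read it back, as a downstream stage would.
Attribute.fromStructField(field) match {
  case b: BinaryAttribute => println(s"binary column: ${b.name.getOrElse("?")}")
  case other              => println(s"unexpected attribute: $other")
}
```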
Example 7
Source File: OneHotEncoderSuite.scala (from iolap, Apache License 2.0)
```scala
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {

  def stringIndexed(): DataFrame = {
    val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
    val df = sqlContext.createDataFrame(data).toDF("id", "label")
    val indexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
      .fit(df)
    indexer.transform(df)
  }

  test("params") {
    ParamsSuite.checkParams(new OneHotEncoder)
  }

  test("OneHotEncoder dropLast = false") {
    val transformed = stringIndexed()
    val encoder = new OneHotEncoder()
      .setInputCol("labelIndex")
      .setOutputCol("labelVec")
      .setDropLast(false)
    val encoded = encoder.transform(transformed)

    val output = encoded.select("id", "labelVec").map { r =>
      val vec = r.getAs[Vector](1)
      (r.getInt(0), vec(0), vec(1), vec(2))
    }.collect().toSet
    // a -> 0, b -> 2, c -> 1
    val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
      (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0))
    assert(output === expected)
  }

  test("OneHotEncoder dropLast = true") {
    val transformed = stringIndexed()
    val encoder = new OneHotEncoder()
      .setInputCol("labelIndex")
      .setOutputCol("labelVec")
    val encoded = encoder.transform(transformed)

    val output = encoded.select("id", "labelVec").map { r =>
      val vec = r.getAs[Vector](1)
      (r.getInt(0), vec(0), vec(1))
    }.collect().toSet
    // a -> 0, b -> 2, c -> 1
    val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0),
      (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0))
    assert(output === expected)
  }

  test("input column with ML attribute") {
    val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
    val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size")
      .select(col("size").as("size", attr.toMetadata()))
    val encoder = new OneHotEncoder()
      .setInputCol("size")
      .setOutputCol("encoded")
    val output = encoder.transform(df)
    val group = AttributeGroup.fromStructField(output.schema("encoded"))
    assert(group.size === 2)
    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0))
    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1))
  }

  test("input column without ML attribute") {
    val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index")
    val encoder = new OneHotEncoder()
      .setInputCol("index")
      .setOutputCol("encoded")
    val output = encoder.transform(df)
    val group = AttributeGroup.fromStructField(output.schema("encoded"))
    assert(group.size === 2)
    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0))
    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1))
  }
}
```
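Worth spelling out what the dropLast tests verify: with three categories and the default dropLast = true, the encoder emits a two-element vector, and the last category ("b", indexed 2 here) is represented by the all-zeros row. A quick sketch of checking the resulting attribute-group size, reusing stringIndexed() from the suite above:

```scala
val encodedDefault = new OneHotEncoder()
  .setInputCol("labelIndex")
  .setOutputCol("labelVec")
  .transform(stringIndexed())

// dropLast = true collapses the 3 categories into 2 binary attributes.
assert(AttributeGroup.fromStructField(encodedDefault.schema("labelVec")).size == 2)
```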
Example 8
Source File: Binarizer.scala (from spark1.52, Apache License 2.0)
The code is identical to the iolap version shown in Example 6.
Example 9
Source File: Binarizer.scala (from Spark-2.3.1, Apache License 2.0)
The code is identical to the drizzle-spark version shown in Example 1.
Example 10
Source File: Binarizer.scala (from BigDatalog, Apache License 2.0)
```scala
package org.apache.spark.ml.feature

import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.BinaryAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}

// NOTE: excerpt only; the Binarizer class declaration and its threshold,
// inputCol, and related members are not shown here.

  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def transform(dataset: DataFrame): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    val td = $(threshold)
    val binarizer = udf { in: Double => if (in > td) 1.0 else 0.0 }
    val outputColName = $(outputCol)
    val metadata = BinaryAttribute.defaultAttr.withName(outputColName).toMetadata()
    dataset.select(col("*"), binarizer(col($(inputCol))).as(outputColName, metadata))
  }

  override def transformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)

    val inputFields = schema.fields
    val outputColName = $(outputCol)

    require(inputFields.forall(_.name != outputColName),
      s"Output column $outputColName already exists.")

    val attr = BinaryAttribute.defaultAttr.withName(outputColName)
    val outputFields = inputFields :+ attr.toStructField()
    StructType(outputFields)
  }

  override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
}

@Since("1.6.0")
object Binarizer extends DefaultParamsReadable[Binarizer] {

  @Since("1.6.0")
  override def load(path: String): Binarizer = super.load(path)
}
```
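This variant differs from Example 6 mainly in its companion object, which wires up params-only persistence. A minimal save/load sketch (the path is hypothetical, and this assumes the elided class declaration mixes in DefaultParamsWritable, as Spark 1.6's Binarizer does):

```scala
val binarizer = new Binarizer()
  .setInputCol("feature")
  .setOutputCol("binarized")
  .setThreshold(0.5)

// Persist only the params, then restore them through the companion object.
binarizer.write.overwrite().save("/tmp/binarizer-demo")
val restored = Binarizer.load("/tmp/binarizer-demo")
assert(restored.getThreshold == 0.5)
```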
Example 11
Source File: OneHotEncoderSuite.scala (from BigDatalog, Apache License 2.0)
```scala
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  def stringIndexed(): DataFrame = {
    val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
    val df = sqlContext.createDataFrame(data).toDF("id", "label")
    val indexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
      .fit(df)
    indexer.transform(df)
  }

  test("params") {
    ParamsSuite.checkParams(new OneHotEncoder)
  }

  test("OneHotEncoder dropLast = false") {
    val transformed = stringIndexed()
    val encoder = new OneHotEncoder()
      .setInputCol("labelIndex")
      .setOutputCol("labelVec")
      .setDropLast(false)
    val encoded = encoder.transform(transformed)

    val output = encoded.select("id", "labelVec").map { r =>
      val vec = r.getAs[Vector](1)
      (r.getInt(0), vec(0), vec(1), vec(2))
    }.collect().toSet
    // a -> 0, b -> 2, c -> 1
    val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
      (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0))
    assert(output === expected)
  }

  test("OneHotEncoder dropLast = true") {
    val transformed = stringIndexed()
    val encoder = new OneHotEncoder()
      .setInputCol("labelIndex")
      .setOutputCol("labelVec")
    val encoded = encoder.transform(transformed)

    val output = encoded.select("id", "labelVec").map { r =>
      val vec = r.getAs[Vector](1)
      (r.getInt(0), vec(0), vec(1))
    }.collect().toSet
    // a -> 0, b -> 2, c -> 1
    val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0),
      (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0))
    assert(output === expected)
  }

  test("input column with ML attribute") {
    val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
    val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size")
      .select(col("size").as("size", attr.toMetadata()))
    val encoder = new OneHotEncoder()
      .setInputCol("size")
      .setOutputCol("encoded")
    val output = encoder.transform(df)
    val group = AttributeGroup.fromStructField(output.schema("encoded"))
    assert(group.size === 2)
    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
  }

  test("input column without ML attribute") {
    val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index")
    val encoder = new OneHotEncoder()
      .setInputCol("index")
      .setOutputCol("encoded")
    val output = encoder.transform(df)
    val group = AttributeGroup.fromStructField(output.schema("encoded"))
    assert(group.size === 2)
    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
  }

  test("read/write") {
    val t = new OneHotEncoder()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setDropLast(false)
    testDefaultReadWrite(t)
  }
}
```