org.apache.spark.ml.param.shared.HasOutputCol Scala Examples

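The following examples show how to use org.apache.spark.ml.param.shared.HasOutputCol. Each example is labelled with the source file and project it was taken from; you can vote up the ones you like or vote down the ones you don't like. Note that many snippets omit the enclosing class declaration (and its parameter definitions), so they begin partway through a class body.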
Example 1
Source File: Binarizer.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.BinaryAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}


  /** Sets `outputCol`, the name of the column that receives the binarized values. */
  def setOutputCol(value: String): this.type = set(outputCol -> value)

  /**
   * Appends `outputCol`, holding 1.0 where `inputCol` is strictly greater than `threshold` and
   * 0.0 otherwise, tagged with binary-attribute ML metadata.
   */
  override def transform(dataset: DataFrame): DataFrame = {
    // Validate (and log) the schema first; the derived schema itself is not reused here.
    transformSchema(dataset.schema, logging = true)
    // Copy params into locals so the UDF closure captures plain values.
    val cutoff = $(threshold)
    val target = $(outputCol)
    val toBinary = udf { x: Double => if (x > cutoff) 1.0 else 0.0 }
    val binaryMeta = BinaryAttribute.defaultAttr.withName(target).toMetadata()
    val binarized = toBinary(col($(inputCol))).as(target, binaryMeta)
    dataset.select(col("*"), binarized)
  }

  /**
   * Checks that `inputCol` is a DoubleType column and that `outputCol` does not exist yet, then
   * returns `schema` with a binary-attribute output field appended.
   */
  override def transformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
    val target = $(outputCol)
    require(!schema.fieldNames.contains(target), s"Output column $target already exists.")
    val binaryField = BinaryAttribute.defaultAttr.withName(target).toStructField()
    StructType(schema.fields :+ binaryField)
  }

  /** Copies this transformer (same UID), with `extra` params applied via `defaultCopy`. */
  override def copy(extra: ParamMap): Binarizer = defaultCopy[Binarizer](extra)
} 
Example 2
Source File: Binarizer.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.BinaryAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}


  /** Sets `outputCol`, the name of the column that receives the binarized values. */
  def setOutputCol(value: String): this.type = set(outputCol -> value)

  /**
   * Appends `outputCol`, holding 1.0 where `inputCol` is strictly greater than `threshold` and
   * 0.0 otherwise, tagged with binary-attribute ML metadata.
   */
  override def transform(dataset: DataFrame): DataFrame = {
    // Validate (and log) the schema first; the derived schema itself is not reused here.
    transformSchema(dataset.schema, logging = true)
    // Copy params into locals so the UDF closure captures plain values.
    val cutoff = $(threshold)
    val target = $(outputCol)
    val toBinary = udf { x: Double => if (x > cutoff) 1.0 else 0.0 }
    val binaryMeta = BinaryAttribute.defaultAttr.withName(target).toMetadata()
    val binarized = toBinary(col($(inputCol))).as(target, binaryMeta)
    dataset.select(col("*"), binarized)
  }

  /**
   * Checks that `inputCol` is a DoubleType column and that `outputCol` does not exist yet, then
   * returns `schema` with a binary-attribute output field appended.
   */
  override def transformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
    val target = $(outputCol)
    require(!schema.fieldNames.contains(target), s"Output column $target already exists.")
    val binaryField = BinaryAttribute.defaultAttr.withName(target).toStructField()
    StructType(schema.fields :+ binaryField)
  }

  /** Copies this transformer (same UID), with `extra` params applied via `defaultCopy`. */
  override def copy(extra: ParamMap): Binarizer = defaultCopy[Binarizer](extra)
}

/** Reader for [[Binarizer]] instances persisted in the default params format. */
@Since("1.6.0")
object Binarizer extends DefaultParamsReadable[Binarizer] {

  /** Loads a previously saved [[Binarizer]] from `path`. */
  @Since("1.6.0")
  override def load(path: String): Binarizer = {
    super.load(path)
  }
}
Example 3
Source File: HashingTF.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StructType}


  /** Sets `numFeatures`, the size of the hashed feature vector. */
  def setNumFeatures(value: Int): this.type = set(numFeatures -> value)

  /**
   * Appends `outputCol` with the hashed term-frequency vector of each row's `inputCol` terms,
   * carrying the attribute-group metadata computed by `transformSchema`.
   */
  override def transform(dataset: DataFrame): DataFrame = {
    val expectedSchema = transformSchema(dataset.schema)
    val target = $(outputCol)
    val hasher = new feature.HashingTF($(numFeatures))
    val hashTerms = udf { terms: Seq[_] => hasher.transform(terms) }
    val features = hashTerms(col($(inputCol))).as(target, expectedSchema(target).metadata)
    dataset.select(col("*"), features)
  }

  /**
   * Requires `inputCol` to be an ArrayType column and appends an `outputCol` field described by
   * an AttributeGroup of `numFeatures` attributes.
   */
  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    val isArray = inputType match {
      case _: ArrayType => true
      case _ => false
    }
    require(isArray, s"The input column must be ArrayType, but got $inputType.")
    val outputGroup = new AttributeGroup($(outputCol), $(numFeatures))
    SchemaUtils.appendColumn(schema, outputGroup.toStructField())
  }

  /** Copies this transformer (same UID), with `extra` params applied via `defaultCopy`. */
  override def copy(extra: ParamMap): HashingTF = defaultCopy[HashingTF](extra)
}

/** Reader for [[HashingTF]] instances persisted in the default params format. */
@Since("1.6.0")
object HashingTF extends DefaultParamsReadable[HashingTF] {

  /** Loads a previously saved [[HashingTF]] from `path`. */
  @Since("1.6.0")
  override def load(path: String): HashingTF = {
    super.load(path)
  }
}
Example 4
Source File: Binarizer.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import scala.collection.mutable.ArrayBuilder

import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.BinaryAttribute
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._


  /** Sets `outputCol`, the name of the column that receives the binarized values. */
  @Since("1.4.0")
  def setOutputCol(value: String): this.type = set(outputCol -> value)

  /**
   * Appends `outputCol`, the binarized form of `inputCol`.
   *
   * Values strictly greater than `threshold` become 1.0, and all others (including NaN) become
   * 0.0. DoubleType input yields a double column; VectorUDT input yields a vector of the same
   * size.
   *
   * For vector input, the implicit zeros of a sparse vector are binarized as well.
   *  - If an implicit zero maps to 0.0 (0.0 is not above `threshold`), only the active entries
   *    can produce 1.0, so the result is built sparsely.
   *  - Otherwise (negative threshold), every implicit zero maps to 1.0, so the result starts
   *    from a dense all-ones array.
   */
  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    val outputSchema = transformSchema(dataset.schema, logging = true)
    val inputType = dataset.schema($(inputCol)).dataType
    val td = $(threshold)

    val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 }
    val binarizerVector = udf { (data: Vector) =>
      // Decide the layout by how an implicit zero itself binarizes.
      if (0.0 > td) {
        // Negative threshold: implicit zeros become 1.0. Clear only the active entries that are
        // not strictly above the threshold. `!(v > td)` keeps NaN mapping to 0.0, as in the
        // double path.
        val values = Array.fill(data.size)(1.0)
        data.foreachActive { (index, value) =>
          if (!(value > td)) values(index) = 0.0
        }
        Vectors.dense(values).compressed
      } else {
        // Implicit zeros stay 0.0, so only active entries above the threshold are recorded.
        val indices = ArrayBuilder.make[Int]
        val values = ArrayBuilder.make[Double]
        data.foreachActive { (index, value) =>
          if (value > td) {
            indices += index
            values += 1.0
          }
        }
        Vectors.sparse(data.size, indices.result(), values.result()).compressed
      }
    }

    val metadata = outputSchema($(outputCol)).metadata

    // transformSchema has already rejected every other input type.
    inputType match {
      case DoubleType =>
        dataset.select(col("*"), binarizerDouble(col($(inputCol))).as($(outputCol), metadata))
      case _: VectorUDT =>
        dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata))
    }
  }

  /**
   * Derives the output field from the type of `inputCol` and appends it to `schema`.
   * DoubleType input gets binary-attribute metadata; VectorUDT input stays a vector.
   *
   * @throws IllegalArgumentException if the input type is unsupported or `outputCol` exists
   */
  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    val target = $(outputCol)

    val field: StructField = inputType match {
      case DoubleType => BinaryAttribute.defaultAttr.withName(target).toStructField()
      case _: VectorUDT => StructField(target, new VectorUDT)
      case _ => throw new IllegalArgumentException(s"Data type $inputType is not supported.")
    }

    if (schema.fieldNames.contains(target)) {
      throw new IllegalArgumentException(s"Output column $target already exists.")
    }
    schema.add(field)
  }

  /** Copies this transformer (same UID), with `extra` params applied via `defaultCopy`. */
  @Since("1.4.1")
  override def copy(extra: ParamMap): Binarizer = defaultCopy[Binarizer](extra)
}

/** Reader for [[Binarizer]] instances persisted in the default params format. */
@Since("1.6.0")
object Binarizer extends DefaultParamsReadable[Binarizer] {

  /** Loads a previously saved [[Binarizer]] from `path`. */
  @Since("1.6.0")
  override def load(path: String): Binarizer = {
    super.load(path)
  }
}
Example 5
Source File: HashingTF.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StructType}


  /** Sets `binary`, which is forwarded to the mllib hasher's `setBinary`. */
  @Since("2.0.0")
  def setBinary(value: Boolean): this.type = set(binary -> value)

  /**
   * Appends `outputCol` with the hashed term-frequency vector of each row's `inputCol` terms,
   * carrying the attribute-group metadata computed by `transformSchema`.
   */
  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    val expectedSchema = transformSchema(dataset.schema)
    val target = $(outputCol)
    val hasher = new feature.HashingTF($(numFeatures)).setBinary($(binary))
    // TODO: Make the hashingTF.transform natively in ml framework to avoid extra conversion.
    // `asML` converts the mllib vector returned by the hasher into an ml.linalg vector.
    val hashTerms = udf { terms: Seq[_] => hasher.transform(terms).asML }
    val features = hashTerms(col($(inputCol))).as(target, expectedSchema(target).metadata)
    dataset.select(col("*"), features)
  }

  /**
   * Requires `inputCol` to be an ArrayType column and appends an `outputCol` field described by
   * an AttributeGroup of `numFeatures` attributes.
   */
  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    val isArray = inputType match {
      case _: ArrayType => true
      case _ => false
    }
    require(isArray, s"The input column must be ArrayType, but got $inputType.")
    val outputGroup = new AttributeGroup($(outputCol), $(numFeatures))
    SchemaUtils.appendColumn(schema, outputGroup.toStructField())
  }

  /** Copies this transformer (same UID), with `extra` params applied via `defaultCopy`. */
  @Since("1.4.1")
  override def copy(extra: ParamMap): HashingTF = defaultCopy[HashingTF](extra)
}

/** Reader for [[HashingTF]] instances persisted in the default params format. */
@Since("1.6.0")
object HashingTF extends DefaultParamsReadable[HashingTF] {

  /** Loads a previously saved [[HashingTF]] from `path`. */
  @Since("1.6.0")
  override def load(path: String): HashingTF = {
    super.load(path)
  }
}
Example 6
Source File: URLElimminator.scala    From pravda-ml   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.odkl.texts

import org.apache.lucene.analysis.standard.UAX29URLEmailTokenizer
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{StringType, StructType}


  /** Sets `inputCol`, the column whose text is filtered. */
  def setInputCol(value: String): this.type = set(inputCol -> value)

  /** Creates the transformer with a random UID prefixed by "URLEliminator". */
  def this() = {
    this(Identifiable.randomUID("URLEliminator"))
  }

  /**
   * Adds (or replaces) `outputCol` with `filterTextUDF` applied to `inputCol`.
   * NOTE(review): `filterTextUDF` is defined outside this view.
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    val target = $(outputCol)
    val filtered = filterTextUDF(dataset.col($(inputCol)))
    dataset.withColumn(target, filtered)
  }

  /** Copies this transformer (same UID), with `extra` params applied via `defaultCopy`. */
  override def copy(extra: ParamMap): Transformer = defaultCopy[Transformer](extra)

  /**
   * Leaves `schema` unchanged when the output overwrites the input column. Otherwise, adds a
   * nullable StringType `outputCol` field.
   */
  @DeveloperApi
  override def transformSchema(schema: StructType): StructType =
    if ($(inputCol) == $(outputCol)) schema
    else schema.add($(outputCol), StringType)
}

/** Reader for [[URLElimminator]] instances persisted in the default params format. */
object URLElimminator extends DefaultParamsReadable[URLElimminator] {

  /** Loads a previously saved [[URLElimminator]] from `path`. */
  override def load(path: String): URLElimminator = {
    super.load(path)
  }
}
Example 7
Source File: RandomProjectionsHasher.scala    From pravda-ml   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.odkl.texts

import java.util.Random

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.ml.linalg.{Matrices, SparseMatrix, Vector}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{LongType, StructType}


  /** Sets `dim`, the input vector dimensionality (otherwise read from column metadata). */
  def setDim(value: Long): this.type = set(dim -> value)


  /** Creates the hasher with a random UID prefixed by "randomProjectionsHasher". */
  def this() = {
    this(Identifiable.randomUID("randomProjectionsHasher"))
  }

  /**
   * Adds (or replaces) `outputCol` with a Long hash of each `inputCol` vector.
   *
   * The vector is multiplied by a seeded random sparse matrix with `basisSize` rows and one
   * column per input dimension. Bit `i` of the hash is set when the i-th projection is strictly
   * positive.
   *
   * The dimensionality is `dim` when set. Otherwise it is read from the AttributeGroup metadata
   * of `inputCol` (as written by e.g. OdklCountVectorizer).
   *
   * @throws IllegalArgumentException if `basisSize` exceeds the 64 bits of a Long, or if `dim`
   *                                  is unset and the input metadata records no vector size
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    val inputName = $(inputCol)
    val numBits = $(basisSize).toInt
    // Bits are packed with `v << i`. The JVM masks Long shift distances to 6 bits, so any
    // projection beyond the 64th would silently overwrite a low-order bit.
    require(numBits <= 64,
      s"basisSize must be at most 64 to fit in a Long hash, but got $numBits.")

    val dimensity =
      if (isSet(dim)) {
        $(dim).toInt
      } else {
        // No explicit dimension: fall back to the AttributeGroup size stored in the metadata.
        val size = AttributeGroup.fromStructField(dataset.schema(inputName)).size
        // AttributeGroup reports -1 when the metadata does not carry a size.
        require(size >= 0, s"Parameter dim is not set and the metadata of input column " +
          s"$inputName does not record the vector size.")
        size
      }

    // The matrix of random vectors used to construct the hash, broadcast to the executors.
    val projectionMatrix = dataset.sqlContext.sparkContext.broadcast(
      Matrices.sprandn(numBits, dimensity, $(sparsity), new Random($(seed))).asInstanceOf[SparseMatrix])

    val binHashSparseVectorColumn = udf((vector: Vector) => {
      projectionMatrix.value.multiply(vector).values
        .map(f => if (f > 0) 1L else 0L)
        .view.zipWithIndex
        .foldLeft(0L) { case (acc, (v, i)) => acc | (v << i) }
    })
    dataset.withColumn($(outputCol), binHashSparseVectorColumn(dataset.col(inputName)))
  }

  /** Copies this transformer (same UID), with `extra` params applied via `defaultCopy`. */
  override def copy(extra: ParamMap): Transformer = defaultCopy[Transformer](extra)

  /** Appends a LongType `outputCol` field (via `SchemaUtils.appendColumn`) to `schema`. */
  @DeveloperApi
  override def transformSchema(schema: StructType): StructType = {
    val hashColumn = $(outputCol)
    SchemaUtils.appendColumn(schema, hashColumn, LongType)
  }

} 
Example 8
Source File: RegexpReplaceTransformer.scala    From pravda-ml   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.odkl.texts

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StringType, StructType}


  /** Sets `inputCol`, the column the regular expression is applied to. */
  def setInputCol(value: String): this.type = set(inputCol -> value)

  /** Creates the transformer with a random UID prefixed by "RegexpReplaceTransformer". */
  def this() = {
    this(Identifiable.randomUID("RegexpReplaceTransformer"))
  }

  /**
   * Adds (or replaces) `outputCol` with `inputCol` after replacing every match of
   * `regexpPattern` with `regexpReplacement`.
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    val target = $(outputCol)
    val source = dataset.col($(inputCol))
    val replaced = regexp_replace(source, $(regexpPattern), $(regexpReplacement))
    dataset.withColumn(target, replaced)
  }
  /** Copies this transformer (same UID), with `extra` params applied via `defaultCopy`. */
  override def copy(extra: ParamMap): Transformer = defaultCopy[Transformer](extra)

  /**
   * Predicts the schema produced by [[transform]], with `outputCol` as a StringType field.
   *
   * `transform` uses `withColumn`, which replaces an existing column at its current position.
   * So when `outputCol` equals `inputCol` and that column exists, the field is swapped in place
   * instead of being moved to the end. In every other case the field is appended with
   * `SchemaUtils.appendColumn`.
   */
  @DeveloperApi
  override def transformSchema(schema: StructType): StructType = {
    val outputName = $(outputCol)
    if (($(inputCol) equals outputName) && schema.fieldNames.contains(outputName)) {
      import org.apache.spark.sql.types.StructField
      // nullable = false matches the field SchemaUtils.appendColumn builds in the other branch.
      val outputField = StructField(outputName, StringType, nullable = false)
      StructType(schema.fields.map(field => if (field.name equals outputName) outputField else field))
    } else {
      SchemaUtils.appendColumn(schema, outputName, StringType)
    }
  }

}

/** Reader for [[RegexpReplaceTransformer]] instances persisted in the default params format. */
object RegexpReplaceTransformer extends DefaultParamsReadable[RegexpReplaceTransformer] {

  /** Loads a previously saved [[RegexpReplaceTransformer]] from `path`. */
  override def load(path: String): RegexpReplaceTransformer = {
    super.load(path)
  }
}
Example 9
Source File: NGramExtractor.scala    From pravda-ml   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.odkl.texts

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamPair, ParamValidators, Params}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}


  /** Sets `outputCol`, the column that receives the extracted n-grams. */
  def setOutputCol(value: String): this.type = set(outputCol -> value)

  // Defaults: upperN = 2, lowerN = 1.
  setDefault(upperN -> 2, lowerN -> 1)

  /**
   * Adds (or replaces) `outputCol` with `NGramUtils.nGramFun` applied to the `inputCol` tokens,
   * using `lowerN` and `upperN` as bounds. NOTE(review): `nGramFun` is defined outside this view
   * (presumably an inclusive n-gram size range — confirm).
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    // Read params into locals so the UDF closure captures plain Ints.
    val minN = $(lowerN)
    val maxN = $(upperN)
    val extractNGrams = udf[Seq[String], Seq[String]] { tokens =>
      NGramUtils.nGramFun(tokens, minN, maxN)
    }
    dataset.withColumn($(outputCol), extractNGrams(dataset.col($(inputCol))))
  }


  /** Copies this transformer (same UID), with `extra` params applied via `defaultCopy`. */
  override def copy(extra: ParamMap): Transformer = defaultCopy[Transformer](extra)

  /**
   * Leaves `schema` unchanged when the output overwrites the input column. Otherwise, adds a
   * nullable `outputCol` field of type array<string>.
   */
  @DeveloperApi
  override def transformSchema(schema: StructType): StructType =
    if ($(inputCol) == $(outputCol)) schema
    else schema.add($(outputCol), ArrayType(StringType, true))
}
/** Reader for [[NGramExtractor]] instances persisted in the default params format. */
object NGramExtractor extends DefaultParamsReadable[NGramExtractor] {

  /** Loads a previously saved [[NGramExtractor]] from `path`. */
  override def load(path: String): NGramExtractor = {
    super.load(path)
  }
}
Example 10
Source File: LanguageDetectorTransformer.scala    From pravda-ml   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.odkl.texts

import com.google.common.base.Optional
import com.optimaize.langdetect.LanguageDetector
import com.optimaize.langdetect.i18n.LdLocale
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{StringType, StructType}

import scala.collection.Map


  /** Sets `outputCol`, the column that receives the detected language. */
  def setOutputCol(value: String): this.type = set(outputCol -> value)

  /** Creates the transformer with a random UID prefixed by "languageDetector". */
  def this() = {
    this(Identifiable.randomUID("languageDetector"))
  }

  /**
   * Adds (or replaces) `outputCol` with `languageDetection` applied to `inputCol`.
   * NOTE(review): `languageDetection` is defined outside this view.
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    val target = $(outputCol)
    val detected = languageDetection(dataset.col($(inputCol)))
    dataset.withColumn(target, detected)
  }

  /** Copies this transformer (same UID), with `extra` params applied via `defaultCopy`. */
  override def copy(extra: ParamMap): Transformer = defaultCopy[Transformer](extra)

  /** Appends a StringType `outputCol` field (via `SchemaUtils.appendColumn`) to `schema`. */
  @DeveloperApi
  override def transformSchema(schema: StructType): StructType = {
    val languageColumn = $(outputCol)
    SchemaUtils.appendColumn(schema, languageColumn, StringType)
  }

  /**
   * Holder for the LanguageDetector built from the built-in language profiles,
   * `minimalConfidence` and `languagePriors`.
   *
   * As a nested `object` it is initialised on first access, not when the transformer is
   * constructed. `@transient` excludes it from Java serialization of the enclosing transformer,
   * so a deserialized copy rebuilds the detector on first use.
   *
   * NOTE(review): presumably referenced by the `languageDetection` UDF defined outside this
   * view. The params are read only once, at initialisation, so later changes to
   * `minimalConfidence` or `languagePriors` on the same instance are not picked up. Confirm that
   * the params are set before first access.
   */
  @transient object languageDetectorWrapped extends Serializable {
    val languageDetector: LanguageDetector =
      LanguageDetectorUtils.buildLanguageDetector(
        LanguageDetectorUtils.readListLangsBuiltIn(),
        $(minimalConfidence),
        $(languagePriors).toMap)
  }

} 
Example 11
Source File: LanguageAwareAnalyzer.scala    From pravda-ml   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.odkl.texts

import org.apache.lucene.analysis.util.StopwordAnalyzerBase
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.shared.HasOutputCol
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}


  /** Sets `outputCol`, the column that receives the analyzed tokens. */
  def setOutputCol(value: String): this.type = set(outputCol -> value)

  /** Copies this transformer (same UID), with `extra` params applied via `defaultCopy`. */
  override def copy(extra: ParamMap): Transformer = defaultCopy[Transformer](extra)

  /** Creates the analyzer with a random UID prefixed by "languageAnalyzer". */
  def this() = {
    this(Identifiable.randomUID("languageAnalyzer"))
  }

  /**
   * Adds (or replaces) `outputCol` with `stemmTextUDF` applied to the language column
   * `inputColLang` and the text column `inputColText`.
   * NOTE(review): `stemmTextUDF` is defined outside this view.
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    // `withColumn` already returns a DataFrame, so no trailing `toDF` is needed.
    dataset.withColumn($(outputCol), stemmTextUDF(dataset.col($(inputColLang)), dataset.col($(inputColText))))
  }

  /**
   * Predicts the schema produced by [[transform]], with `outputCol` as an array<string> field.
   *
   * `transform` uses `withColumn`, which replaces an existing column at its current position.
   * So when `outputCol` equals `inputColText` and that column exists, the field is swapped in
   * place instead of being moved to the end. In every other case the field is appended with
   * `SchemaUtils.appendColumn`.
   */
  @DeveloperApi
  override def transformSchema(schema: StructType): StructType = {
    val outputName = $(outputCol)
    val outputType = ArrayType(StringType, true)
    if (($(inputColText) equals outputName) && schema.fieldNames.contains(outputName)) {
      import org.apache.spark.sql.types.StructField
      // nullable = false matches the field SchemaUtils.appendColumn builds in the other branch.
      val outputField = StructField(outputName, outputType, nullable = false)
      StructType(schema.fields.map(field => if (field.name equals outputName) outputField else field))
    } else {
      SchemaUtils.appendColumn(schema, outputName, outputType)
    }
  }

}

/** Reader for [[LanguageAwareAnalyzer]] instances persisted in the default params format. */
object LanguageAwareAnalyzer extends DefaultParamsReadable[LanguageAwareAnalyzer] {

  /** Loads a previously saved [[LanguageAwareAnalyzer]] from `path`. */
  override def load(path: String): LanguageAwareAnalyzer = {
    super.load(path)
  }
}
Example 12
Source File: Binarizer.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.BinaryAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}


  /** Sets `outputCol`, the name of the column that receives the binarized values. */
  def setOutputCol(value: String): this.type = set(outputCol -> value)

  /**
   * Appends `outputCol`, holding 1.0 where `inputCol` is strictly greater than `threshold` and
   * 0.0 otherwise, tagged with binary-attribute ML metadata.
   */
  override def transform(dataset: DataFrame): DataFrame = {
    // Validate (and log) the schema first; the derived schema itself is not reused here.
    transformSchema(dataset.schema, logging = true)
    // Copy params into locals so the UDF closure captures plain values.
    val cutoff = $(threshold)
    val target = $(outputCol)
    val toBinary = udf { x: Double => if (x > cutoff) 1.0 else 0.0 }
    val binaryMeta = BinaryAttribute.defaultAttr.withName(target).toMetadata()
    val binarized = toBinary(col($(inputCol))).as(target, binaryMeta)
    dataset.select(col("*"), binarized)
  }

  /**
   * Checks that `inputCol` is a DoubleType column and that `outputCol` does not exist yet, then
   * returns `schema` with a binary-attribute output field appended.
   */
  override def transformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
    val target = $(outputCol)
    require(!schema.fieldNames.contains(target), s"Output column $target already exists.")
    val binaryField = BinaryAttribute.defaultAttr.withName(target).toStructField()
    StructType(schema.fields :+ binaryField)
  }

  /** Copies this transformer (same UID), with `extra` params applied via `defaultCopy`. */
  override def copy(extra: ParamMap): Binarizer = defaultCopy[Binarizer](extra)
} 
Example 13
Source File: HashingTF.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StructType}


  /** Sets `numFeatures`, the size of the hashed feature vector. */
  def setNumFeatures(value: Int): this.type = set(numFeatures -> value)

  /**
   * Appends `outputCol` with the hashed term-frequency vector of each row's `inputCol` terms,
   * carrying the attribute-group metadata computed by `transformSchema`.
   */
  override def transform(dataset: DataFrame): DataFrame = {
    val expectedSchema = transformSchema(dataset.schema)
    val target = $(outputCol)
    val hasher = new feature.HashingTF($(numFeatures))
    val hashTerms = udf { terms: Seq[_] => hasher.transform(terms) }
    val features = hashTerms(col($(inputCol))).as(target, expectedSchema(target).metadata)
    dataset.select(col("*"), features)
  }

  /**
   * Requires `inputCol` to be an ArrayType column and appends an `outputCol` field described by
   * an AttributeGroup of `numFeatures` attributes.
   */
  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    val isArray = inputType match {
      case _: ArrayType => true
      case _ => false
    }
    require(isArray, s"The input column must be ArrayType, but got $inputType.")
    val outputGroup = new AttributeGroup($(outputCol), $(numFeatures))
    SchemaUtils.appendColumn(schema, outputGroup.toStructField())
  }

  /** Copies this transformer (same UID), with `extra` params applied via `defaultCopy`. */
  override def copy(extra: ParamMap): HashingTF = defaultCopy[HashingTF](extra)
} 
Example 14
Source File: HashingTF.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StructType}


  /** Sets `binary`, which is forwarded to the mllib hasher's `setBinary`. */
  @Since("2.0.0")
  def setBinary(value: Boolean): this.type = set(binary -> value)

  /**
   * Appends `outputCol` with the hashed term-frequency vector of each row's `inputCol` terms,
   * carrying the attribute-group metadata computed by `transformSchema`.
   */
  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    val expectedSchema = transformSchema(dataset.schema)
    val target = $(outputCol)
    val hasher = new feature.HashingTF($(numFeatures)).setBinary($(binary))
    // TODO: Make the hashingTF.transform natively in ml framework to avoid extra conversion.
    // `asML` converts the mllib vector returned by the hasher into an ml.linalg vector.
    val hashTerms = udf { terms: Seq[_] => hasher.transform(terms).asML }
    val features = hashTerms(col($(inputCol))).as(target, expectedSchema(target).metadata)
    dataset.select(col("*"), features)
  }

  /**
   * Requires `inputCol` to be an ArrayType column and appends an `outputCol` field described by
   * an AttributeGroup of `numFeatures` attributes.
   */
  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    val isArray = inputType match {
      case _: ArrayType => true
      case _ => false
    }
    require(isArray, s"The input column must be ArrayType, but got $inputType.")
    val outputGroup = new AttributeGroup($(outputCol), $(numFeatures))
    SchemaUtils.appendColumn(schema, outputGroup.toStructField())
  }

  /** Copies this transformer (same UID), with `extra` params applied via `defaultCopy`. */
  @Since("1.4.1")
  override def copy(extra: ParamMap): HashingTF = defaultCopy[HashingTF](extra)
}

/** Reader for [[HashingTF]] instances persisted in the default params format. */
@Since("1.6.0")
object HashingTF extends DefaultParamsReadable[HashingTF] {

  /** Loads a previously saved [[HashingTF]] from `path`. */
  @Since("1.6.0")
  override def load(path: String): HashingTF = {
    super.load(path)
  }
}
Example 15
Source File: HashingTF.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StructType}


  /** Sets `numFeatures`, the size of the hashed feature vector. */
  def setNumFeatures(value: Int): this.type = set(numFeatures -> value)

  /**
   * Appends `outputCol` with the hashed term-frequency vector of each row's `inputCol` terms,
   * carrying the attribute-group metadata computed by `transformSchema`.
   */
  override def transform(dataset: DataFrame): DataFrame = {
    val expectedSchema = transformSchema(dataset.schema)
    val target = $(outputCol)
    val hasher = new feature.HashingTF($(numFeatures))
    val hashTerms = udf { terms: Seq[_] => hasher.transform(terms) }
    val features = hashTerms(col($(inputCol))).as(target, expectedSchema(target).metadata)
    dataset.select(col("*"), features)
  }

  /**
   * Requires `inputCol` to be an ArrayType column and appends an `outputCol` field described by
   * an AttributeGroup of `numFeatures` attributes.
   */
  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    val isArray = inputType match {
      case _: ArrayType => true
      case _ => false
    }
    require(isArray, s"The input column must be ArrayType, but got $inputType.")
    val outputGroup = new AttributeGroup($(outputCol), $(numFeatures))
    SchemaUtils.appendColumn(schema, outputGroup.toStructField())
  }

  /** Copies this transformer (same UID), with `extra` params applied via `defaultCopy`. */
  override def copy(extra: ParamMap): HashingTF = defaultCopy[HashingTF](extra)
} 
Example 16
Source File: Binarizer.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import scala.collection.mutable.ArrayBuilder

import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.BinaryAttribute
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._


  /** Sets `outputCol`, the name of the column that receives the binarized values. */
  @Since("1.4.0")
  def setOutputCol(value: String): this.type = set(outputCol -> value)

  /**
   * Appends `outputCol`, the binarized form of `inputCol`.
   *
   * Values strictly greater than `threshold` become 1.0, and all others (including NaN) become
   * 0.0. DoubleType input yields a double column; VectorUDT input yields a vector of the same
   * size.
   *
   * For vector input, the implicit zeros of a sparse vector are binarized as well.
   *  - If an implicit zero maps to 0.0 (0.0 is not above `threshold`), only the active entries
   *    can produce 1.0, so the result is built sparsely.
   *  - Otherwise (negative threshold), every implicit zero maps to 1.0, so the result starts
   *    from a dense all-ones array.
   */
  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    val outputSchema = transformSchema(dataset.schema, logging = true)
    val inputType = dataset.schema($(inputCol)).dataType
    val td = $(threshold)

    val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 }
    val binarizerVector = udf { (data: Vector) =>
      // Decide the layout by how an implicit zero itself binarizes.
      if (0.0 > td) {
        // Negative threshold: implicit zeros become 1.0. Clear only the active entries that are
        // not strictly above the threshold. `!(v > td)` keeps NaN mapping to 0.0, as in the
        // double path.
        val values = Array.fill(data.size)(1.0)
        data.foreachActive { (index, value) =>
          if (!(value > td)) values(index) = 0.0
        }
        Vectors.dense(values).compressed
      } else {
        // Implicit zeros stay 0.0, so only active entries above the threshold are recorded.
        val indices = ArrayBuilder.make[Int]
        val values = ArrayBuilder.make[Double]
        data.foreachActive { (index, value) =>
          if (value > td) {
            indices += index
            values += 1.0
          }
        }
        Vectors.sparse(data.size, indices.result(), values.result()).compressed
      }
    }

    val metadata = outputSchema($(outputCol)).metadata

    // transformSchema has already rejected every other input type.
    inputType match {
      case DoubleType =>
        dataset.select(col("*"), binarizerDouble(col($(inputCol))).as($(outputCol), metadata))
      case _: VectorUDT =>
        dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata))
    }
  }

  /**
   * Derives the output field from the type of `inputCol` and appends it to `schema`.
   * DoubleType input gets binary-attribute metadata; VectorUDT input stays a vector.
   *
   * @throws IllegalArgumentException if the input type is unsupported or `outputCol` exists
   */
  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    val target = $(outputCol)

    val field: StructField = inputType match {
      case DoubleType => BinaryAttribute.defaultAttr.withName(target).toStructField()
      case _: VectorUDT => StructField(target, new VectorUDT)
      case _ => throw new IllegalArgumentException(s"Data type $inputType is not supported.")
    }

    if (schema.fieldNames.contains(target)) {
      throw new IllegalArgumentException(s"Output column $target already exists.")
    }
    schema.add(field)
  }

  /** Copies this transformer (same UID), with `extra` params applied via `defaultCopy`. */
  @Since("1.4.1")
  override def copy(extra: ParamMap): Binarizer = defaultCopy[Binarizer](extra)
}

/** Reader for [[Binarizer]] instances persisted in the default params format. */
@Since("1.6.0")
object Binarizer extends DefaultParamsReadable[Binarizer] {

  /** Loads a previously saved [[Binarizer]] from `path`. */
  @Since("1.6.0")
  override def load(path: String): Binarizer = {
    super.load(path)
  }
}
Example 17
Source File: HashingTF.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StructType}


  /** Sets `binary`, which is forwarded to the mllib hasher's `setBinary`. */
  @Since("2.0.0")
  def setBinary(value: Boolean): this.type = set(binary -> value)

  /**
   * Appends `outputCol` with the hashed term-frequency vector of each row's `inputCol` terms,
   * carrying the attribute-group metadata computed by `transformSchema`.
   */
  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    val expectedSchema = transformSchema(dataset.schema)
    val target = $(outputCol)
    val hasher = new feature.HashingTF($(numFeatures)).setBinary($(binary))
    // TODO: Make the hashingTF.transform natively in ml framework to avoid extra conversion.
    // `asML` converts the mllib vector returned by the hasher into an ml.linalg vector.
    val hashTerms = udf { terms: Seq[_] => hasher.transform(terms).asML }
    val features = hashTerms(col($(inputCol))).as(target, expectedSchema(target).metadata)
    dataset.select(col("*"), features)
  }

  /**
   * Requires `inputCol` to be an ArrayType column and appends an `outputCol` field described by
   * an AttributeGroup of `numFeatures` attributes.
   */
  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    val isArray = inputType match {
      case _: ArrayType => true
      case _ => false
    }
    require(isArray, s"The input column must be ArrayType, but got $inputType.")
    val outputGroup = new AttributeGroup($(outputCol), $(numFeatures))
    SchemaUtils.appendColumn(schema, outputGroup.toStructField())
  }

  /** Copies this transformer (same UID), with `extra` params applied via `defaultCopy`. */
  @Since("1.4.1")
  override def copy(extra: ParamMap): HashingTF = defaultCopy[HashingTF](extra)
}

/** Reader for [[HashingTF]] instances persisted in the default params format. */
@Since("1.6.0")
object HashingTF extends DefaultParamsReadable[HashingTF] {

  /** Loads a previously saved [[HashingTF]] from `path`. */
  @Since("1.6.0")
  override def load(path: String): HashingTF = {
    super.load(path)
  }
}
Example 18
Source File: Gather.scala    From spark-ext   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.HasOutputCol
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.ext.functions._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

/**
 * Parameters shared by [[Gather]]: key/value/output columns (from the mixed-in traits), plus
 * the primary-key columns and the aggregate applied to the value column.
 */
private[feature] trait GatherParams extends Params with HasKeyCol with HasValueCol with HasOutputCol {

  /** Columns identifying one output row; validated to be non-empty. */
  val primaryKeyCols: Param[Array[String]] = new StringArrayParam(this, "primaryKeyCols",
    "Primary key column names",
    ParamValidators.arrayLengthGt(0))

  /** Aggregate applied to `valueCol`; validated to be "sum" or "count". */
  val valueAgg: Param[String] = new Param[String](this, "valueAgg",
    "Aggregate function applied to valueCol: 'sum' or 'count'",
    ParamValidators.inArray(Array("sum", "count")))

  /** Current (or default) value of `primaryKeyCols`. */
  def getPrimaryKeyCols: Array[String] = getOrDefault(primaryKeyCols)

  /** Current (or default) value of `valueAgg`. */
  def getValueAgg: String = getOrDefault(valueAgg)
}


/**
 * Collapses rows that share the same primary key into a single row.
 *
 * Rows are grouped by `primaryKeyCols` plus `keyCol`, and `valueCol` is aggregated with
 * `valueAgg` ("sum" by default, or "count"). The aggregates are then collected per primary key
 * into `outputCol` as an array of `(keyCol, valueCol: Double)` structs.
 */
class Gather(override val uid: String) extends Transformer with GatherParams {

  def this() = this(Identifiable.randomUID("gather"))

  def setPrimaryKeyCols(value: String*): this.type = set(primaryKeyCols, value.toArray)

  def setKeyCol(value: String): this.type = set(keyCol, value)

  def setValueCol(value: String): this.type = set(valueCol, value)

  def setValueAgg(value: String): this.type = set(valueAgg, value)

  def setOutputCol(value: String): this.type = set(outputCol, value)

  setDefault(
    valueAgg -> "sum"
  )

  override def transform(dataset: DataFrame): DataFrame = {
    val outputSchema = transformSchema(dataset.schema)

    val pkCols = $(primaryKeyCols).map(col)

    val grouped = dataset.groupBy(pkCols :+ col($(keyCol)) : _*)
    // Temporary column name is prefixed with the uid so it cannot collide with user columns.
    val aggregateCol = s"${uid}_value_aggregate"
    // valueAgg is restricted to "sum"/"count" by its ParamValidators.inArray validator.
    val aggregated = $(valueAgg) match {
      case "sum"   => grouped.agg(sum($(valueCol))   as aggregateCol)
      case "count" => grouped.agg(count($(valueCol)) as aggregateCol)
    }

    val metadata = outputSchema($(outputCol)).metadata

    aggregated
      .groupBy(pkCols: _*)
      .agg(collectArray(struct(
          col($(keyCol)),
          col(aggregateCol).cast(DoubleType).as($(valueCol))
      )).as($(outputCol), metadata))
  }

  /**
   * Output schema: the primary key fields followed by a non-nullable array of
   * `(keyCol, valueCol: Double)` structs named `outputCol`.
   *
   * @throws IllegalArgumentException if the key column is neither numeric nor string, or the value
   *                                  column is not numeric (strings are allowed only with "count")
   */
  override def transformSchema(schema: StructType): StructType = {
    val valueFunName = $(valueAgg)

    val keyColName = $(keyCol)
    val keyColDataType = schema(keyColName).dataType
    keyColDataType match {
      case _: NumericType =>
      case _: StringType =>
      case other =>
        throw new IllegalArgumentException(s"Key column data type $other is not supported.")
    }

    val valueColName = $(valueCol)
    val valueColDataType = schema(valueColName).dataType
    valueColDataType match {
      case _: NumericType =>
      case _: StringType if valueFunName == "count" =>
      case other =>
        // Report the configured aggregate value, not the Param object itself.
        throw new IllegalArgumentException(
          s"Value data type $other is not supported with value aggregate $valueFunName.")
    }

    val pkFields = $(primaryKeyCols).map(schema.apply)
    val rollupType = StructType(Array(
      StructField($(keyCol), keyColDataType),
      StructField($(valueCol), DoubleType)
    ))
    val rollupField = StructField($(outputCol), ArrayType(rollupType), nullable = false)

    StructType(pkFields :+ rollupField)
  }

  /** Copies this transformer (same uid) with `extra` params applied. */
  override def copy(extra: ParamMap): Gather = defaultCopy(extra)

}
Example 19
Source File: MovingAverage.scala    From uberdata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml

import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.ml.linalg.{VectorUDT, Vectors}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types._


  /** Sets the name of the column that will hold the moving-average array. */
  def setOutputCol(value: String): this.type = set(outputCol -> value)

  // Default moving-average window: 3 samples.
  setDefault(windowSize, 3)

  /**
   * Replaces `inputCol` with its values minus the first `windowSize - 1` entries and appends the
   * simple moving average of the full input as `outputCol`.
   *
   * The trimmed input is written back in the same representation it was read in: a dense vector
   * for `VectorUDT` columns and a sequence for array columns. This keeps every row consistent with
   * [[transformSchema]], which leaves the input column's type unchanged. Previously, array inputs
   * were rewritten as vectors, which failed row conversion.
   *
   * @param dataSet data containing `inputCol` as a vector or an array of doubles
   * @return the transformed DataFrame
   */
  override def transform(dataSet: Dataset[_]): DataFrame = {
    val outputSchema = transformSchema(dataSet.schema)
    val inputColName = $(inputCol)
    val inputColIndex = dataSet.columns.indexOf(inputColName)
    val period = $(windowSize)
    // Decided once on the driver; the closure captures only these small local values.
    val isVectorInput = outputSchema(inputColName).dataType.isInstanceOf[VectorUDT]

    val maRdd = dataSet.toDF().rdd.map { row =>
      val values: Array[Double] =
        if (isVectorInput) row.getAs[org.apache.spark.ml.linalg.Vector](inputColName).toArray
        else row.getAs[Iterable[Double]](inputColName).toArray
      // Leading values without a complete window behind them are dropped.
      val trimmed = values.drop(period - 1)
      val rawValue: Any = if (isVectorInput) Vectors.dense(trimmed) else trimmed.toSeq
      val movingAverage = MovingAverageCalc.simpleMovingAverageArray(values, period)
      val (before, after) = row.toSeq.splitAt(inputColIndex)
      Row((((before :+ rawValue) ++ after.tail) :+ movingAverage): _*)
    }
    dataSet.sqlContext.createDataFrame(maRdd, outputSchema)
  }

  /** Appends a nullable `ArrayType(DoubleType)` field named `outputCol` to `schema`. */
  override def transformSchema(schema: StructType): StructType = {
    val movingAverageField = StructField($(outputCol), ArrayType(DoubleType))
    StructType(schema.fields :+ movingAverageField)
  }

  /** Copies this transformer (same uid) with `extra` params applied. */
  override def copy(extra: ParamMap): MovingAverage[T] = defaultCopy[MovingAverage[T]](extra)
}

object MovingAverageCalc {

  /**
   * Simple moving average of `values` over windows of `period` consecutive elements.
   *
   * The result has exactly `values.length - period + 1` elements, or none when fewer than `period`
   * values are given. This matches the input trimmed by `drop(period - 1)`. The previous
   * implementation emitted 0.0 placeholders for the first `period - 1` positions and stripped them
   * with `dropWhile(_ == 0d)`. That also removed genuine leading zero averages and shortened the
   * result.
   *
   * @param values input series
   * @param period window length; must be positive
   * @return the moving averages, one per complete window
   * @throws IllegalArgumentException if `period` is not positive
   */
  private[ml] def simpleMovingAverageArray(values: Array[Double], period: Int): Array[Double] = {
    require(period > 0, s"Moving average period must be positive, but got $period.")
    // sliding would yield a single short window when values.length < period, so guard it.
    if (values.length < period) Array.empty[Double]
    else values.sliding(period).map(_.sum / period).toArray
  }
}

object MovingAverage extends DefaultParamsReadable[MovingAverage[_]] {

  /** Loads a [[MovingAverage]] previously saved at `path` using the default params reader. */
  override def load(path: String): MovingAverage[_] = read.load(path)
}
Example 20
Source File: Binarizer.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import scala.collection.mutable.ArrayBuilder

import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.BinaryAttribute
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._


  /** Sets the name of the column that receives the binarized output. */
  @Since("1.4.0")
  def setOutputCol(value: String): this.type = set(outputCol -> value)

  /**
   * Maps each value of `inputCol` to 1.0 if it is strictly greater than `threshold`, else 0.0.
   *
   * Double columns are mapped element-wise. Vector columns are mapped entry-wise. With a
   * non-negative threshold the result stays sparse, because implicit zeros always map to 0.0.
   * With a negative threshold every implicit zero maps to 1.0. Sparse inputs are therefore
   * evaluated entry by entry in that case; `foreachActive` alone would skip those zeros.
   */
  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    val outputSchema = transformSchema(dataset.schema, logging = true)
    val schema = dataset.schema
    val inputType = schema($(inputCol)).dataType
    val td = $(threshold)

    val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 }
    val binarizerVector = udf { (data: Vector) =>
      if (td >= 0) {
        val indices = ArrayBuilder.make[Int]
        val values = ArrayBuilder.make[Double]

        data.foreachActive { (index, value) =>
          if (value > td) {
            indices += index
            values +=  1.0
          }
        }

        Vectors.sparse(data.size, indices.result(), values.result()).compressed
      } else {
        // Inactive (zero) entries exceed a negative threshold, so every entry must be visited.
        Vectors.dense(data.toArray.map(v => if (v > td) 1.0 else 0.0)).compressed
      }
    }

    val metadata = outputSchema($(outputCol)).metadata

    // transformSchema has already rejected any input type other than Double and Vector.
    inputType match {
      case DoubleType =>
        dataset.select(col("*"), binarizerDouble(col($(inputCol))).as($(outputCol), metadata))
      case _: VectorUDT =>
        dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata))
    }
  }

  /**
   * Appends the binarized output field: a binary attribute for Double input, or a vector
   * field for Vector input.
   *
   * @throws IllegalArgumentException for unsupported input types or an existing output column
   */
  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
    val outputColName = $(outputCol)

    val outCol: StructField = schema($(inputCol)).dataType match {
      case DoubleType => BinaryAttribute.defaultAttr.withName(outputColName).toStructField()
      case _: VectorUDT => StructField(outputColName, new VectorUDT)
      case inputType => throw new IllegalArgumentException(s"Data type $inputType is not supported.")
    }

    val outputAlreadyPresent = schema.fields.exists(_.name == outputColName)
    if (outputAlreadyPresent) {
      throw new IllegalArgumentException(s"Output column $outputColName already exists.")
    }
    schema.add(outCol)
  }

  /** Copies this binarizer (same uid) with `extra` params applied. */
  @Since("1.4.1")
  override def copy(extra: ParamMap): Binarizer = {
    defaultCopy[Binarizer](extra)
  }
}

@Since("1.6.0")
object Binarizer extends DefaultParamsReadable[Binarizer] {

  /** Loads a [[Binarizer]] previously saved at `path`, delegating to the default reader. */
  @Since("1.6.0")
  override def load(path: String): Binarizer = read.load(path)
}
Example 21
Source File: HashingTF.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StructType}


  /** Sets whether term frequencies are replaced by a 0/1 presence flag. */
  @Since("2.0.0")
  def setBinary(value: Boolean): this.type = set(binary -> value)

  /** Hashes each term sequence in `inputCol` into a term-frequency vector stored in `outputCol`. */
  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    val outputSchema = transformSchema(dataset.schema)
    val hasher = new feature.HashingTF($(numFeatures)).setBinary($(binary))
    // TODO: Make the hashingTF.transform natively in ml framework to avoid extra conversion.
    val hashTerms = udf { terms: Seq[_] => hasher.transform(terms).asML }
    val outputMetadata = outputSchema($(outputCol)).metadata
    val hashedColumn = hashTerms(col($(inputCol))).as($(outputCol), outputMetadata)
    dataset.select(col("*"), hashedColumn)
  }

  /**
   * Requires an array-typed input column and appends the output column, annotated with an
   * attribute group of `numFeatures` attributes.
   */
  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    val isArrayInput = inputType match {
      case _: ArrayType => true
      case _ => false
    }
    require(isArrayInput, s"The input column must be ArrayType, but got $inputType.")
    val featuresField = new AttributeGroup($(outputCol), $(numFeatures)).toStructField()
    SchemaUtils.appendColumn(schema, featuresField)
  }

  /** Copies this transformer (same uid) with `extra` params applied. */
  @Since("1.4.1")
  override def copy(extra: ParamMap): HashingTF = {
    defaultCopy[HashingTF](extra)
  }
}

@Since("1.6.0")
object HashingTF extends DefaultParamsReadable[HashingTF] {

  /** Loads a [[HashingTF]] previously saved at `path`, delegating to the default reader. */
  @Since("1.6.0")
  override def load(path: String): HashingTF = read.load(path)
}
Example 22
Source File: WordLengthFilter.scala    From mleap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.mleap.feature

import ml.combust.mleap.core.feature.WordLengthFilterModel
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params}
import org.apache.spark.ml.util._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset}



  /** Returns the current (or default) value of the `wordLength` param. */
  final def getWordLength: Int = getOrDefault(wordLength)
}

/**
 * Applies a [[WordLengthFilterModel]] built from the `wordLength` param to each word sequence in
 * `inputCol`. The filtered sequence is written to `outputCol`.
 */
class WordLengthFilter(override val uid: String) extends Transformer
  with WordLengthFilterParams
  with DefaultParamsWritable {

  val defaultLength = 3
  // Most recently used filter model. It is still a public var for backward compatibility, but
  // transform always rebuilds it from the current `wordLength` param.
  var model: WordLengthFilterModel = new WordLengthFilterModel(defaultLength) //Initialize with default filter length 3

  // The supplied model used to be silently discarded; keep it until the next transform.
  def this(model: WordLengthFilterModel) = {
    this(uid = Identifiable.randomUID("filter_words"))
    this.model = model
  }
  def this() = this(new WordLengthFilterModel)

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)
  def setWordLength(value: Int = defaultLength): this.type = set(wordLength, value)

  override def transform(dataset: Dataset[_]): DataFrame = {
    // Validate up front, consistent with transformSchema (array input, new output column).
    transformSchema(dataset.schema, logging = true)
    // Always derive the model from the current param. The old code refreshed it only when the
    // param differed from defaultLength, so setting the param back to 3 reused a stale model.
    val filterModel = new WordLengthFilterModel(getWordLength)
    model = filterModel
    // Close over the local model rather than `this.model` so the UDF does not capture `this`.
    val filterWordsUdf = udf {
      (words: Seq[String]) => filterModel(words)
    }

    dataset.withColumn($(outputCol), filterWordsUdf(dataset($(inputCol))))
  }

  override def copy(extra: ParamMap): Transformer =  defaultCopy(extra)

  /**
   * Requires an array input column and a not-yet-existing output column, then appends
   * `outputCol` as `ArrayType(StringType, true)`.
   */
  override def transformSchema(schema: StructType): StructType = {
    require(schema($(inputCol)).dataType.isInstanceOf[ArrayType],
      s"Input column must be of type ArrayType(StringType,true) but got ${schema($(inputCol)).dataType}")
    val inputFields = schema.fields

    require(!inputFields.exists(_.name == $(outputCol)),
      s"Output column ${$(outputCol)} already exists.")

    StructType(schema.fields :+ StructField($(outputCol), ArrayType(StringType, true)))

  }
}

object WordLengthFilter extends DefaultParamsReadable[WordLengthFilter] {

  /** Loads a [[WordLengthFilter]] previously saved at `path` using the default params reader. */
  override def load(path: String): WordLengthFilter = read.load(path)
}
Example 23
Source File: MathUnary.scala    From mleap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.mleap.feature

import ml.combust.mleap.core.feature.{MathUnaryModel, UnaryOperation}
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, Identifiable, MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}
import org.apache.spark.sql.functions.udf


    // Fully qualified class name recorded in (and checked against) the saved metadata.
    private val className: String = classOf[MathUnary].getName

    /**
     * Restores a [[MathUnary]] from `path`. Params come from the metadata; the unary operation
     * name comes from the first row of the `data` parquet directory.
     */
    override def load(path: String): MathUnary = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

      val operationName = sparkSession.read
        .parquet(new Path(path, "data").toString)
        .select("operation")
        .head()
        .getAs[String](0)

      val transformer =
        new MathUnary(metadata.uid, MathUnaryModel(UnaryOperation.forName(operationName)))
      metadata.getAndSetParams(transformer)
      transformer
    }
  }

} 
Example 24
Source File: StringMap.scala    From mleap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.mleap.feature

import ml.combust.mleap.core.feature.{HandleInvalid, StringMapModel}
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types._


    // Fully qualified class name recorded in (and checked against) the saved metadata.
    private val className: String = classOf[StringMap].getName

    /**
     * Restores a [[StringMap]] from `path`. Params come from the metadata; the label map,
     * invalid-handling mode and default value come from the first row of `data`.
     */
    override def load(path: String): StringMap = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

      val row = sparkSession.read
        .parquet(new Path(path, "data").toString)
        .select("labels", "handleInvalid", "defaultValue")
        .head()

      val stringMapModel = new StringMapModel(
        row.getAs[Map[String, Double]](0),
        handleInvalid = HandleInvalid.fromString(row.getAs[String](1)),
        defaultValue = row.getAs[Double](2))

      val transformer = new StringMap(metadata.uid, stringMapModel)
      metadata.getAndSetParams(transformer)
      transformer
    }
  }

} 
Example 25
Source File: Binarizer.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import scala.collection.mutable.ArrayBuilder

import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.BinaryAttribute
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._


  /** Sets the name of the column that receives the binarized output. */
  @Since("1.4.0")
  def setOutputCol(value: String): this.type = set(outputCol -> value)

  /**
   * Maps each value of `inputCol` to 1.0 if it is strictly greater than `threshold`, else 0.0.
   *
   * Double columns are mapped element-wise. Vector columns are mapped entry-wise. With a
   * non-negative threshold the result stays sparse, because implicit zeros always map to 0.0.
   * With a negative threshold every implicit zero maps to 1.0. Sparse inputs are therefore
   * evaluated entry by entry in that case; `foreachActive` alone would skip those zeros.
   */
  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    val outputSchema = transformSchema(dataset.schema, logging = true)
    val schema = dataset.schema
    val inputType = schema($(inputCol)).dataType
    val td = $(threshold)

    val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 }
    val binarizerVector = udf { (data: Vector) =>
      if (td >= 0) {
        val indices = ArrayBuilder.make[Int]
        val values = ArrayBuilder.make[Double]

        data.foreachActive { (index, value) =>
          if (value > td) {
            indices += index
            values +=  1.0
          }
        }

        Vectors.sparse(data.size, indices.result(), values.result()).compressed
      } else {
        // Inactive (zero) entries exceed a negative threshold, so every entry must be visited.
        Vectors.dense(data.toArray.map(v => if (v > td) 1.0 else 0.0)).compressed
      }
    }

    val metadata = outputSchema($(outputCol)).metadata

    // transformSchema has already rejected any input type other than Double and Vector.
    inputType match {
      case DoubleType =>
        dataset.select(col("*"), binarizerDouble(col($(inputCol))).as($(outputCol), metadata))
      case _: VectorUDT =>
        dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata))
    }
  }

  /**
   * Appends the binarized output field: a binary attribute for Double input, or a vector
   * field for Vector input.
   *
   * @throws IllegalArgumentException for unsupported input types or an existing output column
   */
  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
    val outputColName = $(outputCol)

    val outCol: StructField = schema($(inputCol)).dataType match {
      case DoubleType => BinaryAttribute.defaultAttr.withName(outputColName).toStructField()
      case _: VectorUDT => StructField(outputColName, new VectorUDT)
      case inputType => throw new IllegalArgumentException(s"Data type $inputType is not supported.")
    }

    val outputAlreadyPresent = schema.fields.exists(_.name == outputColName)
    if (outputAlreadyPresent) {
      throw new IllegalArgumentException(s"Output column $outputColName already exists.")
    }
    schema.add(outCol)
  }

  /** Copies this binarizer (same uid) with `extra` params applied. */
  @Since("1.4.1")
  override def copy(extra: ParamMap): Binarizer = {
    defaultCopy[Binarizer](extra)
  }
}

@Since("1.6.0")
object Binarizer extends DefaultParamsReadable[Binarizer] {

  /** Loads a [[Binarizer]] previously saved at `path`, delegating to the default reader. */
  @Since("1.6.0")
  override def load(path: String): Binarizer = read.load(path)
}