org.apache.spark.ml.feature.Normalizer Scala Examples

The following examples show how to use org.apache.spark.ml.feature.Normalizer. They are drawn from several open-source projects; the project and source file are noted in the header above each example.
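Normalizer is a Transformer that rescales each Vector row of a DataFrame to unit p-norm; unlike StandardScaler it needs no fitting, since the norm is computed per row. As a minimal sketch of the per-row computation (an illustration, not Spark's own implementation; it assumes spark-mllib on the classpath):

import org.apache.spark.ml.linalg.{Vector, Vectors}

// ||v||_p = (sum_i |v_i|^p)^(1/p); for p = infinity, the largest absolute entry.
def pNorm(v: Vector, p: Double): Double =
  if (p.isPosInfinity) v.toArray.map(math.abs).max
  else math.pow(v.toArray.map(x => math.pow(math.abs(x), p)).sum, 1.0 / p)

// Divide each component by the norm; a zero vector is returned unchanged here
// to avoid dividing by zero.
def normalize(v: Vector, p: Double): Vector = {
  val norm = pNorm(v, p)
  if (norm == 0.0) v else Vectors.dense(v.toArray.map(_ / norm))
}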
Example 1
Source File: NormalizerExample.scala    From drizzle-spark   with Apache License 2.0
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.feature.Normalizer
import org.apache.spark.ml.linalg.Vectors
// $example off$
import org.apache.spark.sql.SparkSession

object NormalizerExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder
      .appName("NormalizerExample")
      .getOrCreate()

    // $example on$
    val dataFrame = spark.createDataFrame(Seq(
      (0, Vectors.dense(1.0, 0.5, -1.0)),
      (1, Vectors.dense(2.0, 1.0, 1.0)),
      (2, Vectors.dense(4.0, 10.0, 2.0))
    )).toDF("id", "features")

    // Normalize each Vector using $L^1$ norm.
    val normalizer = new Normalizer()
      .setInputCol("features")
      .setOutputCol("normFeatures")
      .setP(1.0)

    val l1NormData = normalizer.transform(dataFrame)
    println("Normalized using L^1 norm")
    l1NormData.show()

    // Normalize each Vector using $L^\infty$ norm.
    val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity)
    println("Normalized using L^inf norm")
    lInfNormData.show()
    // $example off$

    spark.stop()
  }
}
// scalastyle:on println 
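For the toy data above, the L^1 pass divides each row by the sum of its absolute values and the L^inf pass by its largest absolute entry, so normFeatures comes out as:

  id 0: (1.0, 0.5, -1.0)  ->  L^1: (0.4, 0.2, -0.4)      L^inf: (1.0, 0.5, -1.0)
  id 1: (2.0, 1.0, 1.0)   ->  L^1: (0.5, 0.25, 0.25)     L^inf: (1.0, 0.5, 0.5)
  id 2: (4.0, 10.0, 2.0)  ->  L^1: (0.25, 0.625, 0.125)  L^inf: (0.4, 1.0, 0.2)

(The same NormalizerExample.scala ships, byte for byte, with the sparkoscope, multi-tenancy-spark, and Spark-2.3.1 projects.)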
Example 2
Source File: LocalNormalizer.scala    From spark-ml-serving   with Apache License 2.0
package io.hydrosphere.spark_ml_serving.preprocessors

import io.hydrosphere.spark_ml_serving.TypedTransformerConverter
import io.hydrosphere.spark_ml_serving.common.utils.DataUtils._
import io.hydrosphere.spark_ml_serving.common._
import org.apache.spark.ml.feature.Normalizer
import org.apache.spark.ml.linalg.Vector

class LocalNormalizer(override val sparkTransformer: Normalizer)
  extends LocalTransformer[Normalizer] {
  override def transform(localData: LocalData): LocalData = {
    localData.column(sparkTransformer.getInputCol) match {
      case Some(column) =>
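        // createTransformFunc (the Vector => Vector normalization function) is
        // protected in Spark's public API, hence the reflective lookup below.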
        val method = classOf[Normalizer].getMethod("createTransformFunc")
        val newData = column.data.mapToMlVectors.map { vector =>
          method.invoke(sparkTransformer).asInstanceOf[Vector => Vector](vector).toList
        }
        localData.withColumn(LocalDataColumn(sparkTransformer.getOutputCol, newData))
      case None => localData
    }
  }
}

object LocalNormalizer
  extends SimpleModelLoader[Normalizer]
  with TypedTransformerConverter[Normalizer] {

  override def build(metadata: Metadata, data: LocalData): Normalizer = {
    new Normalizer(metadata.uid)
      .setInputCol(metadata.paramMap("inputCol").asInstanceOf[String])
      .setOutputCol(metadata.paramMap("outputCol").asInstanceOf[String])
      .setP(metadata.paramMap("p").toString.toDouble)
  }

  override implicit def toLocal(transformer: Normalizer) =
    new LocalNormalizer(transformer)
} 
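The same reflective trick works standalone, which is handy for unit-testing the serving path. A quick sketch (assuming spark-mllib on the classpath; no SparkSession is needed, since Normalizer transforms one vector at a time):

import org.apache.spark.ml.feature.Normalizer
import org.apache.spark.ml.linalg.{Vector, Vectors}

val normalizer = new Normalizer("norm").setInputCol("features").setOutputCol("out").setP(1.0)
val fn = classOf[Normalizer]
  .getMethod("createTransformFunc")
  .invoke(normalizer)
  .asInstanceOf[Vector => Vector]
println(fn(Vectors.dense(1.0, 0.5, -1.0)))  // [0.4, 0.2, -0.4]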
Example 3
Source File: NormalizerOp.scala    From mleap   with Apache License 2.0
package org.apache.spark.ml.bundle.ops.feature

import ml.combust.bundle.BundleContext
import ml.combust.bundle.dsl._
import ml.combust.bundle.op.{OpModel, OpNode}
import ml.combust.mleap.core.types.TensorShape
import org.apache.spark.ml.bundle.{ParamSpec, SimpleParamSpec, SimpleSparkOp, SparkBundleContext}
import org.apache.spark.ml.feature.Normalizer
import org.apache.spark.sql.mleap.TypeConverters.sparkToMleapDataShape


class NormalizerOp extends SimpleSparkOp[Normalizer] {
  override val Model: OpModel[SparkBundleContext, Normalizer] = new OpModel[SparkBundleContext, Normalizer] {
    override val klazz: Class[Normalizer] = classOf[Normalizer]

    override def opName: String = Bundle.BuiltinOps.feature.normalizer

    override def store(model: Model, obj: Normalizer)
                      (implicit context: BundleContext[SparkBundleContext]): Model = {
      val dataset = context.context.dataset.get
      val inputShape = sparkToMleapDataShape(dataset.schema(obj.getInputCol), dataset).asInstanceOf[TensorShape]

      model.withValue("p_norm", Value.double(obj.getP))
        .withValue("input_size", Value.int(inputShape.dimensions.get.head))
    }

    override def load(model: Model)
                     (implicit context: BundleContext[SparkBundleContext]): Normalizer = {
      new Normalizer(uid = "").setP(model.value("p_norm").getDouble)
    }
  }

  override def sparkLoad(uid: String, shape: NodeShape, model: Normalizer): Normalizer = {
    new Normalizer(uid = uid).setP(model.getP)
  }

  override def sparkInputs(obj: Normalizer): Seq[ParamSpec] = {
    Seq("input" -> obj.inputCol)
  }

  override def sparkOutputs(obj: Normalizer): Seq[SimpleParamSpec] = {
    Seq("output" -> obj.outputCol)
  }
} 
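Note the asymmetry between store and load: store persists both the p-norm and the input vector size (read off the bound dataset's schema via sparkToMleapDataShape), while load reconstructs the Spark transformer from p_norm alone; input_size is presumably consumed by the MLeap-side model for shape checking rather than by Spark.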
Example 4
Source File: NormalizerParitySpec.scala    From mleap   with Apache License 2.0
package org.apache.spark.ml.parity.feature

import org.apache.spark.ml.parity.SparkParityBase
import org.apache.spark.ml.feature.{Normalizer, VectorAssembler}
import org.apache.spark.ml.{Pipeline, Transformer}
import org.apache.spark.sql.DataFrame


class NormalizerParitySpec extends SparkParityBase {
  override val dataset: DataFrame = baseDataset.select("dti", "loan_amount")
  override val sparkTransformer: Transformer = new Pipeline().setStages(Array(new VectorAssembler().
    setInputCols(Array("dti", "loan_amount")).
    setOutputCol("features"),
    new Normalizer().
      setP(3d).
      setInputCol("features").
      setOutputCol("scaled_features"))).fit(dataset)
} 
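The spec exercises a non-default p = 3 through a two-stage pipeline. A standalone sketch of the same stages on a toy frame (the dti and loan_amount column names follow the spec; spark is assumed to be an existing SparkSession):

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{Normalizer, VectorAssembler}

val df = spark.createDataFrame(Seq((0.3, 1000.0), (0.1, 500.0))).toDF("dti", "loan_amount")
val pipeline = new Pipeline().setStages(Array(
  new VectorAssembler().setInputCols(Array("dti", "loan_amount")).setOutputCol("features"),
  new Normalizer().setP(3.0).setInputCol("features").setOutputCol("scaled_features")))
pipeline.fit(df).transform(df).show(truncate = false)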
Example 5
Source File: L9-17MLCrossValidation.scala    From prosparkstreaming   with Apache License 2.0
package org.apress.prospark

import scala.reflect.runtime.universe

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.Normalizer
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.RandomForestRegressor
import org.apache.spark.ml.tuning.CrossValidator
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.apache.spark.sql.SQLContext
import org.apache.spark.streaming.Seconds
import org.apache.spark.streaming.StreamingContext

object MLCrossValidationApp {

  case class Activity(label: Double,
    accelXHand: Double, accelYHand: Double, accelZHand: Double,
    accelXChest: Double, accelYChest: Double, accelZChest: Double,
    accelXAnkle: Double, accelYAnkle: Double, accelZAnkle: Double)

  def main(args: Array[String]) {
    if (args.length != 4) {
      System.err.println(
        "Usage: MLCrossValidationApp <appname> <batchInterval> <hostname> <port>")
      System.exit(1)
    }
    val Seq(appName, batchInterval, hostname, port) = args.toSeq

    val conf = new SparkConf()
      .setAppName(appName)
      .setJars(SparkContext.jarOfClass(this.getClass).toSeq)

    val ssc = new StreamingContext(conf, Seconds(batchInterval.toInt))

    val sqlC = new SQLContext(ssc.sparkContext)
    import sqlC.implicits._

    val substream = ssc.socketTextStream(hostname, port.toInt)
      .filter(!_.contains("NaN"))
      .map(_.split(" "))
      .filter(f => f(1) == "4" || f(1) == "5")
      .map(f => Array(f(1), f(4), f(5), f(6), f(20), f(21), f(22), f(36), f(37), f(38)))
      .map(f => f.map(v => v.toDouble))
      .foreachRDD(rdd => {
        if (!rdd.isEmpty) {
          val accelerometer = rdd.map(x => Activity(x(0), x(1), x(2), x(3), x(4), x(5), x(6), x(7), x(8), x(9))).toDF()
          val split = accelerometer.randomSplit(Array(0.3, 0.7))
          val test = split(0)
          val train = split(1)

          val assembler = new VectorAssembler()
            .setInputCols(Array(
              "accelXHand", "accelYHand", "accelZHand",
              "accelXChest", "accelYChest", "accelZChest",
              "accelXAnkle", "accelYAnkle", "accelZAnkle"))
            .setOutputCol("vectors")
          val normalizer = new Normalizer()
            .setInputCol(assembler.getOutputCol)
            .setOutputCol("features")
          val regressor = new RandomForestRegressor()

          val pipeline = new Pipeline()
            .setStages(Array(assembler, normalizer, regressor))

          val validator = new CrossValidator()
            .setEstimator(pipeline)
            .setEvaluator(new RegressionEvaluator)
          val pGrid = new ParamGridBuilder()
            .addGrid(normalizer.p, Array(1.0, 5.0, 10.0))
            .addGrid(regressor.numTrees, Array(10, 50, 100))
            .build()
          validator.setEstimatorParamMaps(pGrid)
          validator.setNumFolds(5)

          val bestModel = validator.fit(train)
          val prediction = bestModel.transform(test)
          prediction.show()
        }
      })

    ssc.start()
    ssc.awaitTermination()
  }

} 
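The parameter grid above yields 3 x 3 = 9 candidate pipelines, and with 5 folds the cross-validator fits 45 pipelines (plus one final refit of the best candidate on the full training split) for every micro-batch RDD, so this is an expensive per-batch operation.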
Example 6
Source File: L9-15MLPipeline.scala    From prosparkstreaming   with Apache License 2.0
package org.apress.prospark

import scala.reflect.runtime.universe
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.Normalizer
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.RandomForestRegressor
import org.apache.spark.sql.SQLContext
import org.apache.spark.streaming.Seconds
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.ml.param.ParamMap

object MLPipelineApp {

  case class Activity(label: Double,
    accelXHand: Double, accelYHand: Double, accelZHand: Double,
    accelXChest: Double, accelYChest: Double, accelZChest: Double,
    accelXAnkle: Double, accelYAnkle: Double, accelZAnkle: Double)

  def main(args: Array[String]) {
    if (args.length != 4) {
      System.err.println(
        "Usage: MLPipelineApp <appname> <batchInterval> <hostname> <port>")
      System.exit(1)
    }
    val Seq(appName, batchInterval, hostname, port) = args.toSeq

    val conf = new SparkConf()
      .setAppName(appName)
      .setJars(SparkContext.jarOfClass(this.getClass).toSeq)

    val ssc = new StreamingContext(conf, Seconds(batchInterval.toInt))

    val sqlC = new SQLContext(ssc.sparkContext)
    import sqlC.implicits._

    val substream = ssc.socketTextStream(hostname, port.toInt)
      .filter(!_.contains("NaN"))
      .map(_.split(" "))
      .filter(f => f(1) == "4" || f(1) == "5")
      .map(f => Array(f(1), f(4), f(5), f(6), f(20), f(21), f(22), f(36), f(37), f(38)))
      .map(f => f.map(v => v.toDouble))
      .foreachRDD(rdd => {
        if (!rdd.isEmpty) {
          val accelerometer = rdd.map(x => Activity(x(0), x(1), x(2), x(3), x(4), x(5), x(6), x(7), x(8), x(9))).toDF()
          val split = accelerometer.randomSplit(Array(0.3, 0.7))
          val test = split(0)
          val train = split(1)

          val assembler = new VectorAssembler()
            .setInputCols(Array(
              "accelXHand", "accelYHand", "accelZHand",
              "accelXChest", "accelYChest", "accelZChest",
              "accelXAnkle", "accelYAnkle", "accelZAnkle"))
            .setOutputCol("vectors")
          val normalizer = new Normalizer()
            .setInputCol(assembler.getOutputCol)
            .setOutputCol("features")
          val regressor = new RandomForestRegressor()

          val pipeline = new Pipeline()
            .setStages(Array(assembler, normalizer, regressor))
          val pMap = ParamMap(normalizer.p -> 1.0)
          val model = pipeline.fit(train, pMap)
          val prediction = model.transform(test)
          prediction.show()
        }
      })

    ssc.start()
    ssc.awaitTermination()
  }

} 
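Passing ParamMap(normalizer.p -> 1.0) to fit overrides p for that fit only, leaving the stage's own default untouched; calling normalizer.setP(1.0) before assembling the pipeline would bake the value in instead. The former keeps one pipeline definition reusable across several norms, as Example 5's parameter grid does at larger scale.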
Example 7
Source File: OpTransformerWrapperTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages.sparkwrappers.specific

import com.salesforce.op.features.types._
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.feature.{Normalizer, StopWordsRemover}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpTransformerWrapperTest extends FlatSpec with TestSparkContext {

  val (testData, featureVector) = TestFeatureBuilder(
    Seq[MultiPickList](
      Set("I", "saw", "the", "red", "balloon").toMultiPickList,
      Set("Mary", "had", "a", "little", "lamb").toMultiPickList
    )
  )

  val (testDataNorm, _, _) = TestFeatureBuilder("label", "features",
    Seq[(Real, OPVector)](
      0.0.toReal -> Vectors.dense(1.0, 0.5, -1.0).toOPVector,
      1.0.toReal -> Vectors.dense(2.0, 1.0, 1.0).toOPVector,
      2.0.toReal -> Vectors.dense(4.0, 10.0, 2.0).toOPVector
    )
  )
  val (targetDataNorm, targetLabelNorm, featureVectorNorm) = TestFeatureBuilder("label", "features",
    Seq[(Real, OPVector)](
      0.0.toReal -> Vectors.dense(0.4, 0.2, -0.4).toOPVector,
      1.0.toReal -> Vectors.dense(0.5, 0.25, 0.25).toOPVector,
      2.0.toReal -> Vectors.dense(0.25, 0.625, 0.125).toOPVector
    )
  )

  Spec[OpTransformerWrapper[_, _, _]] should "remove stop words with caseSensitivity=true" in {
    val remover = new StopWordsRemover().setCaseSensitive(true)
    val swFilter =
      new OpTransformerWrapper[MultiPickList, MultiPickList, StopWordsRemover](remover).setInput(featureVector)
    val output = swFilter.transform(testData)

    output.collect(swFilter.getOutput()) shouldBe Array(
      Seq("I", "saw", "red", "balloon").toMultiPickList,
      Seq("Mary", "little", "lamb").toMultiPickList
    )
  }

  it should "should properly normalize each feature vector instance with non-default norm of 1" in {
    val baseNormalizer = new Normalizer().setP(1.0)
    val normalizer =
      new OpTransformerWrapper[OPVector, OPVector, Normalizer](baseNormalizer).setInput(featureVectorNorm)
    val output = normalizer.transform(testDataNorm)

    val sumSqDist = validateDataframeDoubleColumn(output, normalizer.getOutput().name, targetDataNorm, "features")
    assert(sumSqDist <= 1E-6, "==> the sum of squared distances between actual and expected should be below tolerance.")
  }

  def validateDataframeDoubleColumn(
    normalizedFeatureDF: DataFrame, normalizedFeatureName: String, targetFeatureDF: DataFrame, targetColumnName: String
  ): Double = {
    val sqDistUdf = udf { (leftColVec: Vector, rightColVec: Vector) => Vectors.sqdist(leftColVec, rightColVec) }

    val targetColRename = "targetFeatures"
    val renamedTargetDF = targetFeatureDF.withColumnRenamed(targetColumnName, targetColRename)
    val joinedDF = normalizedFeatureDF.join(renamedTargetDF, Seq("label"))

    // compute sum of squared distances between expected and actual
    val finalDF = joinedDF.withColumn("sqDist", sqDistUdf(joinedDF(normalizedFeatureName), joinedDF(targetColRename)))
    val sumSqDist: Double = finalDF.agg(sum(finalDF("sqDist"))).first().getDouble(0)
    sumSqDist
  }
} 
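The expected vectors in targetDataNorm are exactly the L^1-normalized inputs from testDataNorm (for instance 1.0 / (1.0 + 0.5 + 1.0) = 0.4), matching the worked values shown after Example 1, so the test amounts to a parity check between the wrapped and unwrapped transformer.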
Example 8
Source File: NormalizerExample.scala    From BigDatalog   with Apache License 2.0
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.feature.Normalizer
// $example off$
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

object NormalizerExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("NormalizerExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // $example on$
    val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

    // Normalize each Vector using $L^1$ norm.
    val normalizer = new Normalizer()
      .setInputCol("features")
      .setOutputCol("normFeatures")
      .setP(1.0)

    val l1NormData = normalizer.transform(dataFrame)
    l1NormData.show()

    // Normalize each Vector using $L^\infty$ norm.
    val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity)
    lInfNormData.show()
    // $example off$
    sc.stop()
  }
}
// scalastyle:on println 
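This variant predates SparkSession: it wires up a Spark 1.x-style SparkConf, SparkContext and SQLContext, and reads its vectors from a libsvm file instead of building an inline DataFrame, but the two transform calls are the same as in Example 1.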
Example 9
Source File: Normalizer.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.feature

import com.ibm.aardpfark.pfa.document.{PFABuilder, PFADocument}
import com.ibm.aardpfark.pfa.expression._
import com.ibm.aardpfark.spark.ml.PFATransformer
import org.apache.avro.SchemaBuilder

import org.apache.spark.ml.feature.Normalizer


class PFANormalizer(override val sparkTransformer: Normalizer) extends PFATransformer {
  import com.ibm.aardpfark.pfa.dsl._

  private val inputCol = sparkTransformer.getInputCol
  private val outputCol = sparkTransformer.getOutputCol
  private val inputExpr = StringExpr(s"input.${inputCol}")

  private val p = sparkTransformer.getP

  override def inputSchema = {
    SchemaBuilder.record(withUid(inputBaseName)).fields()
      .name(inputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  override def outputSchema = {
    SchemaBuilder.record(withUid(outputBaseName)).fields()
      .name(outputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  private def absPow(p: Double) = FunctionDef[Double, Double](
    Seq("x"),
    Seq(core.pow(m.abs("x"), p))
  )

  private val sq = FunctionDef[Double, Double](
    Seq("x"),
    Seq(core.pow("x", 2.0))
  )

  private val absVal = FunctionDef[Double, Double](
    Seq("x"),
    Seq(m.abs("x"))
  )

  override def action: PFAExpression = {
    val fn = p match {
      case 1.0 =>
        a.sum(a.map(inputExpr, absVal))
      case 2.0 =>
        m.sqrt(a.sum(a.map(inputExpr, sq)))
      case Double.PositiveInfinity =>
        a.max(a.map(inputExpr, absVal))
      case _ =>
        core.pow(a.sum(a.map(inputExpr, absPow(p))), 1.0 / p)
    }
    val norm = Let("norm", fn)
    val invNorm = core.div(1.0, norm.ref)
    val scale = la.scale(inputExpr, invNorm)
    Action(
      norm,
      NewRecord(outputSchema, Map(outputCol -> scale))
    )
  }

  override def pfa: PFADocument = {
    PFABuilder()
      .withName(sparkTransformer.uid)
      .withMetadata(getMetadata)
      .withInput(inputSchema)
      .withOutput(outputSchema)
      //.withFunction(pow(p))
      .withAction(action)
      .pfa
  }
} 
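The action above special-cases p = 1, p = 2 and p = infinity, and otherwise falls back to the general (sum_i |x_i|^p)^(1/p) formula. A quick plain-Scala check (illustration only) that the general formula collapses to the p = 2 shortcut:

val xs = Seq(1.0, 0.5, -1.0)
val general = math.pow(xs.map(x => math.pow(math.abs(x), 2.0)).sum, 1.0 / 2.0)
val l2 = math.sqrt(xs.map(x => x * x).sum)
assert(math.abs(general - l2) < 1e-12)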
Example 10
Source File: NormalizerSuite.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.feature

import com.ibm.aardpfark.pfa.{ScalerResult, Result, SparkFeaturePFASuiteBase}
import org.apache.spark.ml.feature.Normalizer
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

class NormalizerSuite extends SparkFeaturePFASuiteBase[ScalerResult] {

  implicit val enc = ExpressionEncoder[Vector]()

  val inputPath = "data/sample_lda_libsvm_data.txt"
  val dataset = spark.read.format("libsvm").load(inputPath)

  val scaler = new Normalizer()
    .setInputCol("features")
    .setOutputCol("scaled")

  override val sparkTransformer = scaler

  val result = scaler.transform(dataset)
  override val input = withColumnAsArray(result, scaler.getInputCol).toJSON.collect()
  override val expectedOutput = withColumnAsArray(result, scaler.getOutputCol).toJSON.collect()

  test("Normalizer with P = 1") {
    val sparkTransformer = scaler.setP(1.0)
    val result = sparkTransformer.transform(dataset)
    val expectedOutput = withColumnAsArray(result, scaler.getOutputCol).toJSON.collect()
    parityTest(sparkTransformer, input, expectedOutput)
  }

  test("Normalizer with P = positive infinity"){
    val sparkTransformer = scaler.setP(Double.PositiveInfinity)
    val result = sparkTransformer.transform(dataset)
    val expectedOutput = withColumnAsArray(result, scaler.getOutputCol).toJSON.collect()
    parityTest(sparkTransformer, input, expectedOutput)
  }

  test("Normalizer with P = 3") {
    val sparkTransformer = scaler.setP(3.0)
    val result = sparkTransformer.transform(dataset)
    val expectedOutput = withColumnAsArray(result, scaler.getOutputCol).toJSON.collect()
    parityTest(sparkTransformer, input, expectedOutput)
  }

}
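Since the suite never calls setP before the base parity run, that run exercises Normalizer's default of p = 2.0; the three named tests then cover the L^1, L^inf and L^3 cases. Note that each setP call mutates the shared scaler instance, which is safe here only because input (the raw feature column) does not depend on p and each test recomputes its own expectedOutput.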