org.apache.spark.SparkFunSuite Scala Examples

The following examples show how to use org.apache.spark.SparkFunSuite. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
Example 1
Source file: MulticlassMetricsSuite.scala (from BigDatalog, Apache License 2.0; 5 votes)
package org.apache.spark.mllib.evaluation

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Matrices
import org.apache.spark.mllib.util.MLlibTestSparkContext

/**
 * Verifies per-label and aggregate [[MulticlassMetrics]] against values worked out by hand from
 * a fixed set of 9 (prediction, label) pairs.
 */
class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
  test("Multiclass evaluation metrics") {
    /*
     * Confusion matrix for 3-class classification with 9 instances in total. Rows are true
     * labels and columns are predictions; `Matrices.dense` takes its values column by column.
     * |2|1|1| true class0 (4 instances)
     * |1|3|0| true class1 (4 instances)
     * |0|0|1| true class2 (1 instance)
     */
    val confusionMatrix = Matrices.dense(3, 3, Array(2, 1, 0, 1, 3, 0, 1, 0, 1))
    val labels = Array(0.0, 1.0, 2.0)
    // Each pair is (prediction, label).
    val predictionAndLabels = sc.parallelize(
      Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
        (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2)
    val metrics = new MulticlassMetrics(predictionAndLabels)
    val delta = 0.0000001
    // False positive rate per label: false positives / (all instances - instances of that label).
    val fpRate0 = 1.0 / (9 - 4)
    val fpRate1 = 1.0 / (9 - 4)
    val fpRate2 = 1.0 / (9 - 1)
    // Precision per label: true positives / (true positives + false positives).
    val precision0 = 2.0 / (2 + 1)
    val precision1 = 3.0 / (3 + 1)
    val precision2 = 1.0 / (1 + 1)
    // Recall per label: true positives / (true positives + false negatives).
    val recall0 = 2.0 / (2 + 2)
    val recall1 = 3.0 / (3 + 1)
    val recall2 = 1.0 / (1 + 0)
    // F-measure with beta = 1: harmonic mean of precision and recall.
    val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
    val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
    val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
    // F-measure with beta = 2: (1 + beta^2) * P * R / (beta^2 * P + R).
    val f2measure0 = (1 + 2 * 2) * precision0 * recall0 / (2 * 2 * precision0 + recall0)
    val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1)
    val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2)

    assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray))
    assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta)
    assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta)
    assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta)
    assert(math.abs(metrics.precision(0.0) - precision0) < delta)
    assert(math.abs(metrics.precision(1.0) - precision1) < delta)
    assert(math.abs(metrics.precision(2.0) - precision2) < delta)
    assert(math.abs(metrics.recall(0.0) - recall0) < delta)
    assert(math.abs(metrics.recall(1.0) - recall1) < delta)
    assert(math.abs(metrics.recall(2.0) - recall2) < delta)
    assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta)
    assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta)
    assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta)
    assert(math.abs(metrics.fMeasure(0.0, 2.0) - f2measure0) < delta)
    assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta)
    assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta)

    // Overall recall = correctly predicted / all instances = 6 / 9; the suite expects overall
    // precision, F-measure and weighted recall to coincide with it.
    assert(math.abs(metrics.recall -
      (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta)
    assert(math.abs(metrics.recall - metrics.precision) < delta)
    assert(math.abs(metrics.recall - metrics.fMeasure) < delta)
    assert(math.abs(metrics.recall - metrics.weightedRecall) < delta)
    // Weighted metrics weight each label by its share of instances: 4/9, 4/9 and 1/9.
    assert(math.abs(metrics.weightedFalsePositiveRate -
      ((4.0 / 9) * fpRate0 + (4.0 / 9) * fpRate1 + (1.0 / 9) * fpRate2)) < delta)
    assert(math.abs(metrics.weightedPrecision -
      ((4.0 / 9) * precision0 + (4.0 / 9) * precision1 + (1.0 / 9) * precision2)) < delta)
    assert(math.abs(metrics.weightedRecall -
      ((4.0 / 9) * recall0 + (4.0 / 9) * recall1 + (1.0 / 9) * recall2)) < delta)
    assert(math.abs(metrics.weightedFMeasure -
      ((4.0 / 9) * f1measure0 + (4.0 / 9) * f1measure1 + (1.0 / 9) * f1measure2)) < delta)
    assert(math.abs(metrics.weightedFMeasure(2.0) -
      ((4.0 / 9) * f2measure0 + (4.0 / 9) * f2measure1 + (1.0 / 9) * f2measure2)) < delta)
    assert(metrics.labels.sameElements(labels))
  }
}
Example 2
Source file: PythonMLLibAPISuite.scala (from BigDatalog, Apache License 2.0; 5 votes)
package org.apache.spark.mllib.api.python

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.recommendation.Rating

/**
 * Round-trip tests for [[SerDe]]: each value is pickled with `SerDe.dumps`, unpickled with
 * `SerDe.loads`, and compared with the original.
 */
class PythonMLLibAPISuite extends SparkFunSuite {

  // Runs once in the constructor, before any test below pickles or unpickles a value.
  SerDe.initialize()

  test("pickle vector") {
    val vectors = Seq(
      Vectors.dense(Array.empty[Double]),
      Vectors.dense(0.0),
      Vectors.dense(0.0, -2.0),
      Vectors.sparse(0, Array.empty[Int], Array.empty[Double]),
      Vectors.sparse(1, Array.empty[Int], Array.empty[Double]),
      Vectors.sparse(2, Array(1), Array(-2.0)))
    vectors.foreach { v =>
      val u = SerDe.loads(SerDe.dumps(v))
      // The concrete class (dense vs. sparse) must survive the round trip, not just the values.
      assert(u.getClass === v.getClass)
      assert(u === v)
    }
  }

  test("pickle labeled point") {
    val points = Seq(
      LabeledPoint(0.0, Vectors.dense(Array.empty[Double])),
      LabeledPoint(1.0, Vectors.dense(0.0)),
      LabeledPoint(-0.5, Vectors.dense(0.0, -2.0)),
      LabeledPoint(0.0, Vectors.sparse(0, Array.empty[Int], Array.empty[Double])),
      LabeledPoint(1.0, Vectors.sparse(1, Array.empty[Int], Array.empty[Double])),
      LabeledPoint(-0.5, Vectors.sparse(2, Array(1), Array(-2.0))))
    points.foreach { p =>
      val q = SerDe.loads(SerDe.dumps(p)).asInstanceOf[LabeledPoint]
      assert(q.label === p.label)
      assert(q.features.getClass === p.features.getClass)
      assert(q.features === p.features)
    }
  }

  test("pickle double") {
    for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue, Double.NaN)) {
      val deser = SerDe.loads(SerDe.dumps(x.asInstanceOf[AnyRef])).asInstanceOf[Double]
      // We use `equals` here for comparison because we cannot use `==` for NaN
      assert(x.equals(deser))
    }
  }

  test("pickle matrix") {
    val values = Array[Double](0, 1.2, 3, 4.56, 7, 8)
    val matrix = Matrices.dense(2, 3, values)
    val nm = SerDe.loads(SerDe.dumps(matrix)).asInstanceOf[DenseMatrix]
    assert(matrix === nm)

    // Test conversion for empty matrix
    val empty = Array[Double]()
    val emptyMatrix = Matrices.dense(0, 0, empty)
    val ne = SerDe.loads(SerDe.dumps(emptyMatrix)).asInstanceOf[DenseMatrix]
    assert(emptyMatrix === ne)

    val sm = new SparseMatrix(3, 2, Array(0, 1, 3), Array(1, 0, 2), Array(0.9, 1.2, 3.4))
    val nsm = SerDe.loads(SerDe.dumps(sm)).asInstanceOf[SparseMatrix]
    assert(sm.toArray === nsm.toArray)

    val smt = new SparseMatrix(
      3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9),
      isTransposed = true)
    val nsmt = SerDe.loads(SerDe.dumps(smt)).asInstanceOf[SparseMatrix]
    assert(smt.toArray === nsmt.toArray)
  }

  test("pickle rating") {
    val rat = new Rating(1, 2, 3.0)
    val rat2 = SerDe.loads(SerDe.dumps(rat)).asInstanceOf[Rating]
    assert(rat === rat2)

    // Test name of class only occur once
    val rats = (1 to 10).map(x => new Rating(x, x + 1, x + 3.0)).toArray
    val bytes = SerDe.dumps(rats)
    // Search the raw pickle bytes. ISO-8859-1 maps every byte to exactly one char, so a search
    // over the decoded string is a byte-level search. (`bytes.toString` only yields
    // "[B@<hash>", which made the previous `split` check pass without inspecting the pickle.)
    val pickled = new String(bytes, java.nio.charset.StandardCharsets.ISO_8859_1)
    assert("Rating".r.findAllMatchIn(pickled).size == 1)
    assert(bytes.length / 10 < 25) //  25 bytes per rating
  }
}
Example 3
Source file: FPTreeSuite.scala (from BigDatalog, Apache License 2.0; 5 votes)
package org.apache.spark.mllib.fpm

import scala.language.existentials

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext

/** Tests for the local [[FPTree]]: inserting transactions, merging trees, extracting itemsets. */
class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("add transaction") {
    // Inserting "a b c", "a b y" and "b" should share the common "a b" prefix:
    //   root -> a(2) -> b(2) -> { c(1), y(1) }
    //   root -> b(1)
    val tree = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("b"))

    assert(tree.root.children.size == 2)
    assert(tree.root.children.contains("a"))
    assert(tree.root.children("a").item.equals("a"))
    assert(tree.root.children("a").count == 2)
    assert(tree.root.children.contains("b"))
    assert(tree.root.children("b").item.equals("b"))
    assert(tree.root.children("b").count == 1)
    // Walk down the "a" branch one level at a time.
    var child = tree.root.children("a")
    assert(child.children.size == 1)
    assert(child.children.contains("b"))
    assert(child.children("b").item.equals("b"))
    assert(child.children("b").count == 2)
    child = child.children("b")
    assert(child.children.size == 2)
    assert(child.children.contains("c"))
    assert(child.children.contains("y"))
    assert(child.children("c").item.equals("c"))
    assert(child.children("y").item.equals("y"))
    assert(child.children("c").count == 1)
    assert(child.children("y").count == 1)
  }

  test("merge tree") {
    val tree1 = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("b"))

    val tree2 = new FPTree[String]
      .add(Seq("a", "b"))
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "c", "d"))
      .add(Seq("a", "x"))
      .add(Seq("a", "x", "y"))
      .add(Seq("c", "n"))
      .add(Seq("c", "m"))

    // Counts along shared paths add up in the merged tree:
    //   root -> a(7) -> b(5) -> { y(1), c(3) -> d(1) }
    //                -> x(2) -> y(1)
    //   root -> b(1)
    //   root -> c(2) -> { n(1), m(1) }
    val tree3 = tree1.merge(tree2)

    assert(tree3.root.children.size == 3)
    assert(tree3.root.children("a").count == 7)
    assert(tree3.root.children("b").count == 1)
    assert(tree3.root.children("c").count == 2)
    val child1 = tree3.root.children("a")
    assert(child1.children.size == 2)
    assert(child1.children("b").count == 5)
    assert(child1.children("x").count == 2)
    val child2 = child1.children("b")
    assert(child2.children.size == 2)
    assert(child2.children("y").count == 1)
    assert(child2.children("c").count == 3)
    val child3 = child2.children("c")
    assert(child3.children.size == 1)
    assert(child3.children("d").count == 1)
    val child4 = child1.children("x")
    assert(child4.children.size == 1)
    assert(child4.children("y").count == 1)
    val child5 = tree3.root.children("c")
    assert(child5.children.size == 2)
    assert(child5.children("n").count == 1)
    assert(child5.children("m").count == 1)
  }

  test("extract freq itemsets") {
    val tree = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("a", "b"))
      .add(Seq("a"))
      .add(Seq("b"))
      .add(Seq("b", "n"))

    // With threshold 3L only {a} (4), {b} (5) and {a, b} (3) are expected; items that occur
    // once (c, y, n) are absent. Item order within an itemset is not significant, so each
    // result is compared as a Set.
    val freqItemsets = tree.extract(3L).map { case (items, count) =>
      (items.toSet, count)
    }.toSet
    val expected = Set(
      (Set("a"), 4L),
      (Set("b"), 5L),
      (Set("a", "b"), 3L))
    assert(freqItemsets === expected)
  }
}
Example 4
Source file: AssociationRulesSuite.scala (from BigDatalog, Apache License 2.0; 5 votes)
package org.apache.spark.mllib.fpm

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext

/** Checks [[AssociationRules]] rule generation on a fixed collection of frequent itemsets. */
class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("association rules using String type") {
    val freqItemsets = sc.parallelize(Seq(
      (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
      (Set("r"), 3L),
      (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L),
      (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L),
      (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L),
      (Set("t", "y", "x"), 3L),
      (Set("t", "y", "x", "z"), 3L)
    ).map {
      case (items, freq) => new FPGrowth.FreqItemset(items.toArray, freq)
    })

    val ar = new AssociationRules()

    // With single-item consequents (as the total of 30 below implies), every itemset of size
    // k >= 2 yields k rules: 7 * 2 + 4 * 3 + 1 * 4 = 30. Every itemset of size >= 2 has
    // frequency 3, so a rule has confidence 1.0 unless its antecedent is {z} (3/5, 3 rules)
    // or {x} (3/4, 4 rules). That leaves 30 - 7 = 23 rules with confidence 1.0.
    val results1 = ar
      .setMinConfidence(0.9)
      .run(freqItemsets)
      .collect()

    assert(results1.size === 23)
    assert(results1.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)

    // With no confidence threshold, every candidate rule is kept.
    val results2 = ar
      .setMinConfidence(0)
      .run(freqItemsets)
      .collect()

    assert(results2.size === 30)
    assert(results2.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
  }
}
Example 5
Source file: KernelDensitySuite.scala (from BigDatalog, Apache License 2.0; 5 votes)
package org.apache.spark.mllib.stat

import org.apache.commons.math3.distribution.NormalDistribution

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext

/** Compares [[KernelDensity]] estimates with Gaussian densities computed by commons-math. */
class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext {
  test("kernel density single sample") {
    val acceptableErr = 1e-6
    val evaluationPoints = Array(5.0, 6.0)
    val densities = new KernelDensity()
      .setSample(sc.parallelize(Array(5.0)))
      .setBandwidth(3.0)
      .estimate(evaluationPoints)
    // One sample at 5.0 with bandwidth 3.0: the estimate should match a normal distribution
    // with mean 5.0 and standard deviation 3.0.
    val normal = new NormalDistribution(5.0, 3.0)
    for ((x, i) <- evaluationPoints.zipWithIndex) {
      assert(math.abs(densities(i) - normal.density(x)) < acceptableErr)
    }
  }

  test("kernel density multiple samples") {
    val acceptableErr = 1e-6
    val evaluationPoints = Array(5.0, 6.0)
    val densities = new KernelDensity()
      .setSample(sc.parallelize(Array(5.0, 10.0)))
      .setBandwidth(3.0)
      .estimate(evaluationPoints)
    // Two samples: the estimate should be the average of one Gaussian per sample.
    val components = Seq(new NormalDistribution(5.0, 3.0), new NormalDistribution(10.0, 3.0))
    for ((x, i) <- evaluationPoints.zipWithIndex) {
      val expected = components.map(_.density(x)).sum / 2
      assert(math.abs(densities(i) - expected) < acceptableErr)
    }
  }
}
Example 6
Source file: MultivariateGaussianSuite.scala (from BigDatalog, Apache License 2.0; 5 votes)
package org.apache.spark.mllib.stat.distribution

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{ Vectors, Matrices }
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._

/** Checks [[MultivariateGaussian]] pdf values against reference densities. */
class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext {

  /** Absolute tolerance shared by the reference densities in the first three tests. */
  private val tol = 1E-5

  test("univariate") {
    val origin = Vectors.dense(0.0)
    val point = Vectors.dense(1.5)
    val mean = Vectors.dense(0.0)

    // Variance 1.
    val unitVariance = new MultivariateGaussian(mean, Matrices.dense(1, 1, Array(1.0)))
    assert(unitVariance.pdf(origin) ~== 0.39894 absTol tol)
    assert(unitVariance.pdf(point) ~== 0.12952 absTol tol)

    // Variance 4.
    val varianceFour = new MultivariateGaussian(mean, Matrices.dense(1, 1, Array(4.0)))
    assert(varianceFour.pdf(origin) ~== 0.19947 absTol tol)
    assert(varianceFour.pdf(point) ~== 0.15057 absTol tol)
  }

  test("multivariate") {
    val origin = Vectors.dense(0.0, 0.0)
    val ones = Vectors.dense(1.0, 1.0)
    val mean = Vectors.dense(0.0, 0.0)

    // Identity covariance.
    val identityCov = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))
    val standard = new MultivariateGaussian(mean, identityCov)
    assert(standard.pdf(origin) ~== 0.15915 absTol tol)
    assert(standard.pdf(ones) ~== 0.05855 absTol tol)

    // Correlated covariance [[4, -1], [-1, 2]].
    val correlatedCov = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0))
    val correlated = new MultivariateGaussian(mean, correlatedCov)
    assert(correlated.pdf(origin) ~== 0.060155 absTol tol)
    assert(correlated.pdf(ones) ~== 0.033971 absTol tol)
  }

  test("multivariate degenerate") {
    val origin = Vectors.dense(0.0, 0.0)
    val ones = Vectors.dense(1.0, 1.0)
    val mean = Vectors.dense(0.0, 0.0)

    // The all-ones covariance has determinant 0, i.e. it is singular.
    val singularCov = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0))
    val degenerate = new MultivariateGaussian(mean, singularCov)
    assert(degenerate.pdf(origin) ~== 0.11254 absTol tol)
    assert(degenerate.pdf(ones) ~== 0.068259 absTol tol)
  }

  test("SPARK-11302") {
    val point = Vectors.dense(629, 640, 1.7188, 618.19)
    val mean = Vectors.dense(
      1055.3910505836575, 1070.489299610895, 1.39020554474708, 1040.5907503867697)
    val covariance = Matrices.dense(4, 4, Array(
      166769.00466698944, 169336.6705268059, 12.820670788921873, 164243.93314092053,
      169336.6705268059, 172041.5670061245, 21.62590020524533, 166678.01075856484,
      12.820670788921873, 21.62590020524533, 0.872524191943962, 4.283255814732373,
      164243.93314092053, 166678.01075856484, 4.283255814732373, 161848.9196719207))
    val gaussian = new MultivariateGaussian(mean, covariance)
    // Agrees with R's dmvnorm: 7.154782e-05
    assert(gaussian.pdf(point) ~== 7.154782224045512E-5 absTol 1E-9)
  }

}
Example 7
Source file: KMeansPMMLModelExportSuite.scala (from BigDatalog, Apache License 2.0; 5 votes)
package org.apache.spark.mllib.pmml.export

import org.dmg.pmml.ClusteringModel

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors

/** Checks the PMML document produced when exporting a [[KMeansModel]]. */
class KMeansPMMLModelExportSuite extends SparkFunSuite {

  test("KMeansPMMLModelExport generate PMML format") {
    val centers = Array(
      Vectors.dense(1.0, 2.0, 6.0),
      Vectors.dense(1.0, 3.0, 0.0),
      Vectors.dense(1.0, 4.0, 6.0))
    val exported = PMMLModelExportFactory.createPMMLModelExport(new KMeansModel(centers))

    // The factory result must be a PMMLModelExport before its PMML document can be inspected.
    assert(exported.isInstanceOf[PMMLModelExport])
    val pmml = exported.asInstanceOf[PMMLModelExport].getPmml
    assert(pmml.getHeader.getDescription === "k-means clustering")

    // The data dictionary holds one field per coordinate of a cluster center.
    val dimension = centers.head.size
    assert(pmml.getDataDictionary.getNumberOfFields === dimension)

    // The first attached model must be a clustering model with one cluster per center.
    val clustering = pmml.getModels.get(0).asInstanceOf[ClusteringModel]
    assert(clustering.getNumberOfClusters === centers.length)
  }

}
Example 8
Source file: PMMLModelExportFactorySuite.scala (from BigDatalog, Apache License 2.0; 5 votes)
package org.apache.spark.mllib.pmml.export

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel}
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
import org.apache.spark.mllib.util.LinearDataGenerator

/** Checks which exporter [[PMMLModelExportFactory]] picks for each kind of model. */
class PMMLModelExportFactorySuite extends SparkFunSuite {

  /** One deterministic (seed 17) generated point whose features/label seed the linear models. */
  private def samplePoint() =
    LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17).head

  test("PMMLModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") {
    val kmeansModel = new KMeansModel(Array(
      Vectors.dense(1.0, 2.0, 6.0),
      Vectors.dense(1.0, 3.0, 0.0),
      Vectors.dense(1.0, 4.0, 6.0)))

    val exported = PMMLModelExportFactory.createPMMLModelExport(kmeansModel)
    assert(exported.isInstanceOf[KMeansPMMLModelExport])
  }

  test("PMMLModelExportFactory create GeneralizedLinearPMMLModelExport when passing a "
    + "LinearRegressionModel, RidgeRegressionModel or LassoModel") {
    val point = samplePoint()
    val regressionModels = Seq[AnyRef](
      new LinearRegressionModel(point.features, point.label),
      new RidgeRegressionModel(point.features, point.label),
      new LassoModel(point.features, point.label))

    for (model <- regressionModels) {
      val exported = PMMLModelExportFactory.createPMMLModelExport(model)
      assert(exported.isInstanceOf[GeneralizedLinearPMMLModelExport])
    }
  }

  test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport "
    + "when passing a LogisticRegressionModel or SVMModel") {
    val point = samplePoint()
    val classifiers = Seq[AnyRef](
      new LogisticRegressionModel(point.features, point.label),
      new SVMModel(point.features, point.label))

    for (model <- classifiers) {
      val exported = PMMLModelExportFactory.createPMMLModelExport(model)
      assert(exported.isInstanceOf[BinaryClassificationPMMLModelExport])
    }
  }

  test("PMMLModelExportFactory throw IllegalArgumentException "
    + "when passing a Multinomial Logistic Regression") {
    // 3 classes over 2 features: 4 weights = (numClasses - 1) * numFeatures.
    val multiclassLogisticRegressionModel = new LogisticRegressionModel(
      weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0,
      numFeatures = 2, numClasses = 3)

    intercept[IllegalArgumentException] {
      PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel)
    }
  }

  test("PMMLModelExportFactory throw IllegalArgumentException when passing an unsupported model") {
    intercept[IllegalArgumentException] {
      PMMLModelExportFactory.createPMMLModelExport(new Object)
    }
  }
}
Example 9
Source file: NumericParserSuite.scala (from BigDatalog, Apache License 2.0; 5 votes)
package org.apache.spark.mllib.util

import org.apache.spark.{SparkException, SparkFunSuite}

/** Tests for [[NumericParser]]: nested tuples/arrays of numbers and malformed input. */
class NumericParserSuite extends SparkFunSuite {

  test("parser") {
    val input = "((1.0,2e3),-4,[5e-6,7.0E8],+9)"
    val tokens = NumericParser.parse(input).asInstanceOf[Seq[_]]
    // Parenthesized groups come back as Seq, bracketed groups as Array[Double].
    assert(tokens(0).asInstanceOf[Seq[_]] === Seq(1.0, 2.0e3))
    assert(tokens(1).asInstanceOf[Double] === -4.0)
    assert(tokens(2).asInstanceOf[Array[Double]] === Array(5.0e-6, 7.0e8))
    assert(tokens(3).asInstanceOf[Double] === 9.0)

    // Each malformed string must raise SparkException; if parse returns normally, the
    // RuntimeException makes `intercept` fail with a message naming the input.
    for (s <- Seq("a", "[1,,]", "0.123.4", "1 2", "3+4")) {
      intercept[SparkException] {
        NumericParser.parse(s)
        throw new RuntimeException(s"Didn't detect malformatted string $s.")
      }
    }
  }

  test("parser with whitespaces") {
    val tokens = NumericParser.parse("(0.0, [1.0, 2.0])").asInstanceOf[Seq[_]]
    assert(tokens(0).asInstanceOf[Double] === 0.0)
    assert(tokens(1).asInstanceOf[Array[Double]] === Array(1.0, 2.0))
  }
}
Example 10
Source file: BreezeMatrixConversionSuite.scala (from BigDatalog, Apache License 2.0; 5 votes)
package org.apache.spark.mllib.linalg

import breeze.linalg.{DenseMatrix => BDM, CSCMatrix => BSM}

import org.apache.spark.SparkFunSuite

/** Checks conversions between MLlib matrices and Breeze matrices, including buffer sharing. */
class BreezeMatrixConversionSuite extends SparkFunSuite {

  /** Fresh column-major values for a 3 x 2 dense matrix: 0.0 through 5.0. */
  private def denseData(): Array[Double] = Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)

  /** Fresh (values, colPtrs, rowIndices) for a 3 x 2 CSC matrix with entries in rows 1 and 2. */
  private def sparseParts(): (Array[Double], Array[Int], Array[Int]) =
    (Array(1.0, 2.0, 4.0, 5.0), Array(0, 2, 4), Array(1, 2, 1, 2))

  test("dense matrix to breeze") {
    val mat = Matrices.dense(3, 2, denseData())
    val bdm = mat.toBreeze.asInstanceOf[BDM[Double]]
    assert(bdm.rows === mat.numRows)
    assert(bdm.cols === mat.numCols)
    assert(bdm.data.eq(mat.asInstanceOf[DenseMatrix].values), "should not copy data")
  }

  test("dense breeze matrix to matrix") {
    val bdm = new BDM[Double](3, 2, denseData())
    val mat = Matrices.fromBreeze(bdm).asInstanceOf[DenseMatrix]
    assert(mat.numRows === bdm.rows)
    assert(mat.numCols === bdm.cols)
    assert(mat.values.eq(bdm.data), "should not copy data")

    // The transposed dense view must also be converted without copying its buffer.
    val flipped = Matrices.fromBreeze(bdm.t).asInstanceOf[DenseMatrix]
    assert(flipped.numRows === bdm.cols)
    assert(flipped.numCols === bdm.rows)
    assert(flipped.values.eq(bdm.data), "should not copy data")
  }

  test("sparse matrix to breeze") {
    val (values, colPtrs, rowIndices) = sparseParts()
    val mat = Matrices.sparse(3, 2, colPtrs, rowIndices, values)
    val bsm = mat.toBreeze.asInstanceOf[BSM[Double]]
    assert(bsm.rows === mat.numRows)
    assert(bsm.cols === mat.numCols)
    assert(bsm.data.eq(mat.asInstanceOf[SparseMatrix].values), "should not copy data")
  }

  test("sparse breeze matrix to sparse matrix") {
    val (values, colPtrs, rowIndices) = sparseParts()
    val bsm = new BSM[Double](values, 3, 2, colPtrs, rowIndices)
    val mat = Matrices.fromBreeze(bsm).asInstanceOf[SparseMatrix]
    assert(mat.numRows === bsm.rows)
    assert(mat.numCols === bsm.cols)
    assert(mat.values.eq(bsm.data), "should not copy data")

    // Transposing the sparse Breeze matrix, by contrast, is expected to produce a new buffer.
    val flipped = Matrices.fromBreeze(bsm.t).asInstanceOf[SparseMatrix]
    assert(flipped.numRows === bsm.cols)
    assert(flipped.numCols === bsm.rows)
    assert(!flipped.values.eq(bsm.data), "has to copy data")
  }
}
Example 11
Source file: CoordinateMatrixSuite.scala (from BigDatalog, Apache License 2.0; 5 votes)
package org.apache.spark.mllib.linalg.distributed

import breeze.linalg.{DenseMatrix => BDM}

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.linalg.Vectors

/** Tests for [[CoordinateMatrix]] built from a fixed 5 x 4 set of non-zero entries. */
class CoordinateMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {

  val m = 5
  val n = 4
  var mat: CoordinateMatrix = _

  /** Dense (Breeze) form of the entries loaded into `mat` by `beforeAll`. */
  private def expectedDense: BDM[Double] = BDM(
    (1.0, 2.0, 0.0, 0.0),
    (0.0, 3.0, 4.0, 0.0),
    (0.0, 0.0, 5.0, 6.0),
    (7.0, 0.0, 0.0, 8.0),
    (0.0, 9.0, 0.0, 0.0))

  // Explicit `: Unit =` replaces the deprecated procedure syntax `def beforeAll() {`.
  override def beforeAll(): Unit = {
    super.beforeAll()
    val entries = sc.parallelize(Seq(
      (0, 0, 1.0),
      (0, 1, 2.0),
      (1, 1, 3.0),
      (1, 2, 4.0),
      (2, 2, 5.0),
      (2, 3, 6.0),
      (3, 0, 7.0),
      (3, 3, 8.0),
      (4, 1, 9.0)), 3).map { case (i, j, value) =>
      MatrixEntry(i, j, value)
    }
    mat = new CoordinateMatrix(entries)
  }

  test("size") {
    assert(mat.numRows() === m)
    assert(mat.numCols() === n)
  }

  test("empty entries") {
    // The dimensions of a matrix with no entries cannot be inferred.
    val entries = sc.parallelize(Seq[MatrixEntry](), 1)
    val emptyMat = new CoordinateMatrix(entries)
    intercept[RuntimeException] {
      emptyMat.numCols()
    }
    intercept[RuntimeException] {
      emptyMat.numRows()
    }
  }

  test("toBreeze") {
    assert(mat.toBreeze() === expectedDense)
  }

  test("transpose") {
    val transposed = mat.transpose()
    assert(mat.toBreeze().t === transposed.toBreeze())
  }

  test("toIndexedRowMatrix") {
    val indexedRowMatrix = mat.toIndexedRowMatrix()
    assert(indexedRowMatrix.toBreeze() === expectedDense)
  }

  test("toRowMatrix") {
    // Row order is not checked, so rows are compared as a set.
    val rowMatrix = mat.toRowMatrix()
    val rows = rowMatrix.rows.collect().toSet
    val expected = Set(
      Vectors.dense(1.0, 2.0, 0.0, 0.0),
      Vectors.dense(0.0, 3.0, 4.0, 0.0),
      Vectors.dense(0.0, 0.0, 5.0, 6.0),
      Vectors.dense(7.0, 0.0, 0.0, 8.0),
      Vectors.dense(0.0, 9.0, 0.0, 0.0))
    assert(rows === expected)
  }

  test("toBlockMatrix") {
    val blockMat = mat.toBlockMatrix(2, 2)
    assert(blockMat.numRows() === m)
    assert(blockMat.numCols() === n)
    assert(blockMat.toBreeze() === mat.toBreeze())

    // Non-positive block sizes are rejected.
    intercept[IllegalArgumentException] {
      mat.toBlockMatrix(-1, 2)
    }
    intercept[IllegalArgumentException] {
      mat.toBlockMatrix(2, 0)
    }
  }
}
Example 12
Source file: BreezeVectorConversionSuite.scala (from BigDatalog, Apache License 2.0; 5 votes)
package org.apache.spark.mllib.linalg

import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}

import org.apache.spark.SparkFunSuite


/** Checks conversions between MLlib vectors and Breeze vectors, including buffer sharing. */
class BreezeVectorConversionSuite extends SparkFunSuite {

  // Shared fixtures: dense data, plus the size/indices/values of a sparse vector.
  val arr = Array(0.1, 0.2, 0.3, 0.4)
  val n = 20
  val indices = Array(0, 3, 5, 10, 13)
  val values = Array(0.1, 0.5, 0.3, -0.8, -1.0)

  test("dense to breeze") {
    assert(Vectors.dense(arr).toBreeze === new BDV[Double](arr))
  }

  test("sparse to breeze") {
    assert(Vectors.sparse(n, indices, values).toBreeze === new BSV[Double](indices, values, n))
  }

  test("dense breeze to vector") {
    val converted = Vectors.fromBreeze(new BDV[Double](arr)).asInstanceOf[DenseVector]
    assert(converted.size === arr.length)
    assert(converted.values.eq(arr), "should not copy data")
  }

  test("sparse breeze to vector") {
    val converted = Vectors.fromBreeze(new BSV[Double](indices, values, n))
      .asInstanceOf[SparseVector]
    assert(converted.size === n)
    assert(converted.indices.eq(indices), "should not copy data")
    assert(converted.values.eq(values), "should not copy data")
  }

  test("sparse breeze with partially-used arrays to vector") {
    // Only the first `used` slots of the backing arrays are active in the Breeze vector.
    val used = 3
    val converted = Vectors.fromBreeze(new BSV[Double](indices, values, used, n))
      .asInstanceOf[SparseVector]
    assert(converted.size === n)
    assert(converted.indices === indices.take(used))
    assert(converted.values === values.take(used))
  }
}
Example 13
Source file: LabeledPointSuite.scala (from BigDatalog, Apache License 2.0; 5 votes)
package org.apache.spark.mllib.regression

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors

/** Tests for parsing [[LabeledPoint]] from its string forms. */
class LabeledPointSuite extends SparkFunSuite {

  test("parse labeled points") {
    // toString followed by parse must reproduce the point, for dense and sparse features.
    val dense = LabeledPoint(1.0, Vectors.dense(1.0, 0.0))
    val sparse = LabeledPoint(0.0, Vectors.sparse(2, Array(1), Array(-1.0)))
    for (p <- Seq(dense, sparse)) {
      assert(p === LabeledPoint.parse(p.toString))
    }
  }

  test("parse labeled points with whitespaces") {
    val parsed = LabeledPoint.parse("(0.0, [1.0, 2.0])")
    val expected = LabeledPoint(0.0, Vectors.dense(1.0, 2.0))
    assert(parsed === expected)
  }

  test("parse labeled points with v0.9 format") {
    // Legacy form: "label,feature feature ...".
    val parsed = LabeledPoint.parse("1.0,1.0 0.0 -2.0")
    val expected = LabeledPoint(1.0, Vectors.dense(1.0, 0.0, -2.0))
    assert(parsed === expected)
  }
}
Example 14
Source file: RidgeRegressionSuite.scala (from BigDatalog, Apache License 2.0; 5 votes)
package org.apache.spark.mllib.regression

import scala.util.Random

import org.jblas.DoubleMatrix

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
  MLlibTestSparkContext}
import org.apache.spark.util.Utils

private object RidgeRegressionSuite {

  /** Fixed three-weight model (intercept 0.5) round-tripped by the "model save/load" test. */
  val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
}

/** Tests for ridge regression trained with SGD, and for ridge model persistence. */
class RidgeRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {

  /** Mean squared error between `predictions` and the labels of the paired `input` points. */
  def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = {
    val squaredErrors = predictions.zip(input).map { case (predicted, point) =>
      val residual = predicted - point.label
      residual * residual
    }
    squaredErrors.reduceLeft(_ + _) / predictions.size
  }

  test("ridge regression can help avoid overfitting") {

    // With few examples and noisy labels, the regularized model is expected to generalize
    // better (lower validation error) than plain linear regression.

    val numExamples = 50
    val numFeatures = 20

    // Seeded jblas RNG; subtracting 0.5 shifts the random weights into [-0.5, 0.5].
    org.jblas.util.Random.seed(42)
    val w = DoubleMatrix.rand(numFeatures, 1).subi(0.5)

    // The first half of the generated data trains both models; the second half validates them.
    val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2 * numExamples, 42, 10.0)
    val testData = data.take(numExamples)
    val validationData = data.takeRight(numExamples)

    val testRDD = sc.parallelize(testData, 2).cache()
    val validationRDD = sc.parallelize(validationData, 2).cache()

    // Mean squared error of a trained model on the held-out validation points.
    def validationError(model: GeneralizedLinearModel): Double =
      predictionError(model.predict(validationRDD.map(_.features)).collect(), validationData)

    // Baseline without regularization.
    val linearReg = new LinearRegressionWithSGD()
    linearReg.optimizer
      .setNumIterations(200)
      .setStepSize(1.0)
    val linearErr = validationError(linearReg.run(testRDD))

    val ridgeReg = new RidgeRegressionWithSGD()
    ridgeReg.optimizer
      .setNumIterations(200)
      .setRegParam(0.1)
      .setStepSize(1.0)
    val ridgeErr = validationError(ridgeReg.run(testRDD))

    assert(ridgeErr < linearErr,
      "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
  }

  test("model save/load") {
    val model = RidgeRegressionSuite.model
    val tempDir = Utils.createTempDir()
    val path = tempDir.toURI.toString

    // Round-trip through disk; the temp dir is removed even when an assertion fails.
    try {
      model.save(sc, path)
      val loaded = RidgeRegressionModel.load(sc, path)
      assert(model.weights == loaded.weights)
      assert(model.intercept == loaded.intercept)
    } finally {
      Utils.deleteRecursively(tempDir)
    }
  }
}

class RidgeRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext {

  // Checks that neither training nor prediction ships the data inside task closures: the m
  // n-dimensional points are generated inside the partitions, never materialized on the driver.
  test("task size should be small in both training and prediction") {
    val m = 4
    val n = 200000
    val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
      val random = new Random(idx)
      iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
    }.cache()
    // If we serialize data directly in the task closure, the size of the serialized task would be
    // greater than 1MB and hence Spark would throw an error.
    val model = RidgeRegressionWithSGD.train(points, 2)
    // predict(RDD) only builds a lazy RDD; an action is required so that the prediction tasks are
    // actually serialized and executed. Without it the "prediction" half of this test is a no-op.
    val predictions = model.predict(points.map(_.features))
    assert(predictions.count() === m.toLong)
  }
}
Example 15
Source File: MLPairRDDFunctionsSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.rdd

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.rdd.MLPairRDDFunctions._

class MLPairRDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext {
  test("topByKey") {
    // Key 1 has exactly five values, keys 3 and 5 have fewer; the order below matches the
    // original input sequence.
    val pairs = Array(
      (1, 7), (1, 3), (1, 6), (1, 1), (1, 2),
      (3, 2), (3, 7),
      (5, 1),
      (3, 5))
    val topMap = sc.parallelize(pairs, 2).topByKey(5).collectAsMap()

    // Each key keeps at most five values, ordered from largest to smallest.
    assert(topMap.size === 3)
    assert(topMap(1) === Array(7, 6, 3, 2, 1))
    assert(topMap(3) === Array(7, 5, 2))
    assert(topMap(5) === Array(1))
  }
}
Example 16
Source File: RDDFunctionsSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.rdd

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.rdd.RDDFunctions._

class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext {

  // RDD.sliding must agree with Scala's local Iterator.sliding, restricted to full-size windows,
  // for every partitioning, window size and step in the grid below.
  test("sliding") {
    val data = 0 until 6
    for (numPartitions <- 1 to 8) {
      val rdd = sc.parallelize(data, numPartitions)
      for {
        windowSize <- 1 to 6
        step <- 1 to 3
      } {
        val sliding = rdd.sliding(windowSize, step).collect().map(_.toList).toList
        val expected = data.sliding(windowSize, step)
          .map(_.toList)
          .filter(_.size == windowSize)
          .toList
        assert(sliding === expected)
      }
      assert(rdd.sliding(7).collect().isEmpty,
        "Should return an empty RDD if the window size is greater than the number of items.")
    }
  }

  // Windows must span partitions correctly even when some partitions contribute no elements.
  test("sliding with empty partitions") {
    val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7))
    // Presumably one inner Seq per partition, so two partitions are empty after flattening.
    val rdd = sc.parallelize(data, data.length).flatMap(xs => xs)
    assert(rdd.partitions.length === data.length)
    val sliding = rdd.sliding(3).collect().toSeq.map(_.toSeq)
    val expected = data.flatten.sliding(3).toSeq.map(_.toSeq)
    assert(sliding === expected)
  }
}
Example 17
Source File: TwitterStreamSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.twitter


import org.scalatest.BeforeAndAfter
import twitter4j.Status
import twitter4j.auth.{NullAuthorization, Authorization}

import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream

class TwitterStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging {

  val batchDuration = Seconds(1)

  private val master: String = "local[2]"

  private val framework: String = this.getClass.getSimpleName

  // Constructs a stream through every TwitterUtils.createStream overload (API check only).
  test("twitter input stream") {
    val ssc = new StreamingContext(master, framework, batchDuration)
    // Stop the context even if one of the overloads throws; otherwise the SparkContext it owns
    // stays active and can interfere with later suites running in the same JVM.
    try {
      val filters = Seq("filter1", "filter2")
      val authorization: Authorization = NullAuthorization.getInstance()

      // tests the API, does not actually test data receiving
      val test1: ReceiverInputDStream[Status] = TwitterUtils.createStream(ssc, None)
      val test2: ReceiverInputDStream[Status] =
        TwitterUtils.createStream(ssc, None, filters)
      val test3: ReceiverInputDStream[Status] =
        TwitterUtils.createStream(ssc, None, filters, StorageLevel.MEMORY_AND_DISK_SER_2)
      val test4: ReceiverInputDStream[Status] =
        TwitterUtils.createStream(ssc, Some(authorization))
      val test5: ReceiverInputDStream[Status] =
        TwitterUtils.createStream(ssc, Some(authorization), filters)
      val test6: ReceiverInputDStream[Status] = TwitterUtils.createStream(
        ssc, Some(authorization), filters, StorageLevel.MEMORY_AND_DISK_SER_2)

      // Note that actually testing the data receiving is hard as authentication keys are
      // necessary for accessing Twitter live stream
    } finally {
      ssc.stop()
    }
  }
}
Example 18
Source File: FlumeStreamSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.flume

import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import scala.concurrent.duration._
import scala.language.postfixOps

import com.google.common.base.Charsets
import org.jboss.netty.channel.ChannelPipeline
import org.jboss.netty.channel.socket.SocketChannel
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
import org.jboss.netty.handler.codec.compression._
import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._

import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream}

/**
 * Runs the Flume input stream test with and without transport compression.
 *
 * NOTE(review): `testFlumeStream` is called below but is not defined anywhere in the visible
 * code, and nothing visible stops `ssc`. Presumably both live in code stripped from this
 * excerpt; confirm.
 */
class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging {
  // Local-mode configuration with 4 threads; presumably consumed by testFlumeStream (not shown).
  val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite")
  // Streaming context under test; null until something assigns it.
  var ssc: StreamingContext = null

  test("flume input stream") {
    testFlumeStream(testCompression = false)
  }

  test("flume input compressed stream") {
    testFlumeStream(testCompression = true)
  }

  /**
   * Netty client channel factory that adds zlib handlers to every new channel's pipeline.
   * Both handlers are inserted with `addFirst`, so "inflater" (the ZlibDecoder) ends up ahead
   * of "deflater" (the ZlibEncoder) at the head of the pipeline.
   *
   * @param compressionLevel zlib compression level handed to the ZlibEncoder
   */
  private class CompressionChannelFactory(compressionLevel: Int)
    extends NioClientSocketChannelFactory {

    override def newChannel(pipeline: ChannelPipeline): SocketChannel = {
      val encoder = new ZlibEncoder(compressionLevel)
      pipeline.addFirst("deflater", encoder)
      pipeline.addFirst("inflater", new ZlibDecoder())
      super.newChannel(pipeline)
    }
  }
}
Example 19
Source File: ZeroMQStreamSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.zeromq

import akka.actor.SupervisorStrategy
import akka.util.ByteString
import akka.zeromq.Subscribe

import org.apache.spark.SparkFunSuite
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.dstream.ReceiverInputDStream

class ZeroMQStreamSuite extends SparkFunSuite {

  val batchDuration = Seconds(1)

  private val master: String = "local[2]"

  private val framework: String = this.getClass.getSimpleName

  // Constructs a stream through each ZeroMQUtils.createStream overload (API check only).
  test("zeromq input stream") {
    val ssc = new StreamingContext(master, framework, batchDuration)
    // Stop the context even if constructing one of the streams throws, so its SparkContext
    // does not stay active after this test.
    try {
      val publishUrl = "abc"
      val subscribe = new Subscribe(null.asInstanceOf[ByteString])
      val bytesToObjects = (bytes: Seq[ByteString]) => null.asInstanceOf[Iterator[String]]

      // tests the API, does not actually test data receiving
      val test1: ReceiverInputDStream[String] =
        ZeroMQUtils.createStream(ssc, publishUrl, subscribe, bytesToObjects)
      val test2: ReceiverInputDStream[String] = ZeroMQUtils.createStream(
        ssc, publishUrl, subscribe, bytesToObjects, StorageLevel.MEMORY_AND_DISK_SER_2)
      val test3: ReceiverInputDStream[String] = ZeroMQUtils.createStream(
        ssc, publishUrl, subscribe, bytesToObjects,
        StorageLevel.MEMORY_AND_DISK_SER_2, SupervisorStrategy.defaultStrategy)

      // TODO: Actually test data receiving
    } finally {
      ssc.stop()
    }
  }
}
Example 20
Source File: MQTTStreamSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.mqtt

import scala.concurrent.duration._
import scala.language.postfixOps

import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext}

class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter {

  private val batchDuration = Milliseconds(500)
  private val master = "local[2]"
  private val framework = this.getClass.getSimpleName
  private val topic = "def"

  private var ssc: StreamingContext = _
  private var mqttTestUtils: MQTTTestUtils = _

  // Fresh streaming context and MQTT test fixture for every test.
  before {
    ssc = new StreamingContext(master, framework, batchDuration)
    mqttTestUtils = new MQTTTestUtils
    mqttTestUtils.setup()
  }

  // Releases resources in reverse order of creation; the null checks let this run safely even
  // if `before` failed partway through.
  after {
    if (ssc != null) {
      ssc.stop()
      ssc = null
    }
    if (mqttTestUtils != null) {
      mqttTestUtils.teardown()
      mqttTestUtils = null
    }
  }

  test("mqtt input stream") {
    val sendMessage = "MQTT demo for spark streaming"
    val receiveStream = MQTTUtils.createStream(ssc, "tcp://" + mqttTestUtils.brokerUri, topic,
      StorageLevel.MEMORY_ONLY)

    // First element of every non-empty batch, in arrival order. Updated from the streaming
    // output operation and read by `eventually` below, hence @volatile.
    @volatile var receiveMessage: List[String] = List()
    receiveStream.foreachRDD { rdd =>
      // Collect once and reuse it rather than running one job for the emptiness check and a
      // second one to fetch the first element.
      val batch = rdd.collect()
      if (batch.nonEmpty) {
        receiveMessage = receiveMessage :+ batch.head
      }
    }

    ssc.start()

    // Retry it because we don't know when the receiver will start. Comparing via headOption
    // reports a real assertion failure while nothing has arrived yet, instead of an
    // IndexOutOfBoundsException from indexing an empty list.
    eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
      mqttTestUtils.publishData(topic, sendMessage)
      assert(receiveMessage.headOption === Some(sendMessage))
    }
    ssc.stop()
  }
}
Example 21
Source File: KafkaStreamSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.kafka

import scala.collection.mutable
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.Random

import kafka.serializer.StringDecoder
import org.scalatest.BeforeAndAfterAll
import org.scalatest.concurrent.Eventually

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext}

class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll {
  private var ssc: StreamingContext = _
  private var kafkaTestUtils: KafkaTestUtils = _

  // Starts the Kafka test fixture once for the whole suite.
  override def beforeAll(): Unit = {
    kafkaTestUtils = new KafkaTestUtils
    kafkaTestUtils.setup()
  }

  // Stops the streaming context (if a test created one), then tears down the Kafka fixture.
  // The teardown is in `finally` so the fixture is released even when stopping the context
  // throws.
  override def afterAll(): Unit = {
    try {
      if (ssc != null) {
        ssc.stop()
        ssc = null
      }
    } finally {
      if (kafkaTestUtils != null) {
        kafkaTestUtils.teardown()
        kafkaTestUtils = null
      }
    }
  }

  test("Kafka input stream") {
    val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
    ssc = new StreamingContext(sparkConf, Milliseconds(500))
    val topic = "topic1"
    val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
    kafkaTestUtils.createTopic(topic)
    kafkaTestUtils.sendMessages(topic, sent)

    // A random consumer group starting from the smallest offset reads everything sent above.
    val kafkaParams = Map("zookeeper.connect" -> kafkaTestUtils.zkAddress,
      "group.id" -> s"test-consumer-${Random.nextInt(10000)}",
      "auto.offset.reset" -> "smallest")

    val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
      ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY)
    // Running per-message counts accumulated across batches; presumably synchronized because
    // the output operation and the `eventually` check below run on different threads.
    val result = new mutable.HashMap[String, Long]() with mutable.SynchronizedMap[String, Long]
    stream.map(_._2).countByValue().foreachRDD { r =>
      val ret = r.collect()
      ret.toMap.foreach { kv =>
        val count = result.getOrElseUpdate(kv._1, 0) + kv._2
        result.put(kv._1, count)
      }
    }

    ssc.start()

    eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
      assert(sent === result)
    }
  }
}
Example 22
Source File: KafkaClusterSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.kafka

import scala.util.Random

import kafka.common.TopicAndPartition
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.SparkFunSuite

class KafkaClusterSuite extends SparkFunSuite with BeforeAndAfterAll {
  private val topic = "kcsuitetopic" + Random.nextInt(10000)
  private val topicAndPartition = TopicAndPartition(topic, 0)
  private var kc: KafkaCluster = null

  private var kafkaTestUtils: KafkaTestUtils = _

  // Starts the Kafka fixture, creates the topic and writes a single message to it, so the
  // offset tests below see an earliest offset of 0 and a latest offset of 1.
  override def beforeAll(): Unit = {
    kafkaTestUtils = new KafkaTestUtils
    kafkaTestUtils.setup()

    kafkaTestUtils.createTopic(topic)
    kafkaTestUtils.sendMessages(topic, Map("a" -> 1))
    kc = new KafkaCluster(Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress))
  }

  override def afterAll(): Unit = {
    if (kafkaTestUtils != null) {
      kafkaTestUtils.teardown()
      kafkaTestUtils = null
    }
  }

  // Unwraps a KafkaCluster result. On a Left it fails the test with the underlying errors;
  // `.right.get` would instead throw NoSuchElementException and discard them.
  private def rightOrFail[A](result: Either[Any, A], what: String): A = result match {
    case Right(value) => value
    case Left(err) => fail(s"$what failed: $err")
  }

  test("metadata apis") {
    val leaders = rightOrFail(kc.findLeaders(Set(topicAndPartition)), "findLeaders")
    val leader = leaders(topicAndPartition)
    val leaderAddress = s"${leader._1}:${leader._2}"
    assert(leaderAddress === kafkaTestUtils.brokerAddress, "didn't get leader")

    val parts = rightOrFail(kc.getPartitions(Set(topic)), "getPartitions")
    assert(parts(topicAndPartition), "didn't get partitions")

    val err = kc.getPartitions(Set(topic + "BAD"))
    assert(err.isLeft, "getPartitions for a nonexistent topic should be an error")
  }

  test("leader offset apis") {
    val earliest =
      rightOrFail(kc.getEarliestLeaderOffsets(Set(topicAndPartition)), "getEarliestLeaderOffsets")
    assert(earliest(topicAndPartition).offset === 0, "didn't get earliest")

    val latest =
      rightOrFail(kc.getLatestLeaderOffsets(Set(topicAndPartition)), "getLatestLeaderOffsets")
    assert(latest(topicAndPartition).offset === 1, "didn't get latest")
  }

  test("consumer offset apis") {
    val group = "kcsuitegroup" + Random.nextInt(10000)

    val offset = Random.nextInt(10000)

    val set = kc.setConsumerOffsets(group, Map(topicAndPartition -> offset))
    assert(set.isRight, "didn't set consumer offsets")

    val get = rightOrFail(kc.getConsumerOffsets(group, Set(topicAndPartition)), "getConsumerOffsets")
    assert(get(topicAndPartition) === offset, "didn't get consumer offsets")
  }
}
Example 23
Source File: ConcurrentHiveSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.sql.hive.test.TestHiveContext
import org.scalatest.BeforeAndAfterAll

class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll {
  // Creates ten TestHiveContexts side by side and runs a few queries on each. Kept ignored
  // because, per the test name, multiple instances are not supported.
  // The body is registered directly under `ignore`. ScalaTest cannot register a nested
  // `test(...)` from inside a running test body.
  ignore("multiple instances not supported") {
    var started = List.empty[SparkContext]
    try {
      (1 to 10).foreach { i =>
        val conf = new SparkConf()
        conf.set("spark.ui.enabled", "false")
        val sc = new SparkContext("local", s"TestSQLContext$i", conf)
        started = sc :: started
        val ts = new TestHiveContext(sc)
        ts.executeSql("SHOW TABLES").toRdd.collect()
        ts.executeSql("SELECT * FROM src").toRdd.collect()
        ts.executeSql("SHOW TABLES").toRdd.collect()
      }
    } finally {
      // Keep every context alive until the end (the point is for them to coexist), then stop
      // all of them so none leak into later suites.
      started.foreach(_.stop())
    }
  }
}
Example 24
Source File: FiltersSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.client

import java.util.Collections

import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hadoop.hive.serde.serdeConstants

import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._


class FiltersSuite extends SparkFunSuite with Logging {
  private val shim = new Shim_v0_13

  // Table "default.test" whose only partition column is a VARCHAR column named "varchar".
  private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test")
  private val varCharCol = new FieldSchema()
  varCharCol.setName("varchar")
  varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME)
  testTable.setPartCols(Collections.singletonList(varCharCol))

  // (test name, Catalyst filters to convert, expected Hive metastore filter string)
  private val filterCases: Seq[(String, Seq[Expression], String)] = Seq(
    ("string filter",
      (a("stringcol", StringType) > Literal("test")) :: Nil,
      "stringcol > \"test\""),
    ("string filter backwards",
      (Literal("test") > a("stringcol", StringType)) :: Nil,
      "\"test\" > stringcol"),
    ("int filter",
      (a("intcol", IntegerType) === Literal(1)) :: Nil,
      "intcol = 1"),
    ("int filter backwards",
      (Literal(1) === a("intcol", IntegerType)) :: Nil,
      "1 = intcol"),
    ("int and string filter",
      (Literal(1) === a("intcol", IntegerType)) ::
        (Literal("a") === a("strcol", IntegerType)) :: Nil,
      "1 = intcol and \"a\" = strcol"),
    ("skip varchar",
      (Literal("") === a("varchar", StringType)) :: Nil,
      ""))

  filterCases.foreach { case (name, filters, expected) => filterTest(name, filters, expected) }

  // Registers a test asserting that the shim renders `filters` on `testTable` exactly as `result`.
  private def filterTest(name: String, filters: Seq[Expression], result: String): Unit =
    test(name) {
      val converted = shim.convertFilters(testTable, filters)
      val matches = converted == result
      if (!matches) {
        fail(
          s"Expected filters ${filters.mkString(",")} to convert to '$result' but got '$converted'")
      }
    }

  // Shorthand for an unresolved-id attribute reference of the given name and type.
  private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)()
}
Example 25
Source File: ClasspathDependenciesSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import java.net.URL

import org.apache.spark.SparkFunSuite


class ClasspathDependenciesSuite extends SparkFunSuite {
  private val classloader = this.getClass.getClassLoader

  private val KRYO = "com.esotericsoftware.kryo.Kryo"
  private val STD_INSTANTIATOR = "org.objenesis.strategy.StdInstantiatorStrategy"

  private val SPARK_HIVE = "org.apache.hive."
  private val SPARK_SHADED = "org.spark-project.hive.shaded."

  // Path of the .class resource for a fully-qualified class name, e.g. "a.B" -> "a/B.class".
  private def resourceName(classname: String): String =
    classname.replace('.', '/') + ".class"

  // URL of the class file on this suite's classloader, or null when it is not visible.
  private def findResource(classname: String): URL =
    classloader.getResource(resourceName(classname))

  // Fails unless the class file is on the classpath and the class itself can be loaded.
  private def assertLoads(classname: String): Unit = {
    val resourceURL: URL = findResource(classname) match {
      case null => fail(s"Class $classname not found as ${resourceName(classname)}")
      case url => url
    }
    logInfo(s"Class $classname at $resourceURL")
    classloader.loadClass(classname)
  }

  private def assertLoads(classes: String*): Unit =
    for (c <- classes) assertLoads(c)

  // Fails if the class file is visible, then requires loadClass to throw ClassNotFoundException.
  private def assertClassNotFound(classname: String): Unit = {
    val resourceURL = findResource(classname)
    if (resourceURL != null) {
      fail(s"Class $classname found at $resourceURL")
    }
    intercept[ClassNotFoundException] {
      classloader.loadClass(classname)
    }
  }

  private def assertClassNotFound(classes: String*): Unit =
    for (c <- classes) assertClassNotFound(c)

  test("shaded Protobuf") {
    assertLoads(SPARK_SHADED + "com.google.protobuf.ServiceException")
  }

  test("hive-common") {
    assertLoads("org.apache.hadoop.hive.conf.HiveConf")
  }

  test("hive-exec") {
    assertLoads("org.apache.hadoop.hive.ql.CommandNeedRetryException")
  }

  test("unshaded kryo") {
    assertLoads(KRYO, STD_INSTANTIATOR)
  }

  test("Forbidden Dependencies") {
    assertClassNotFound(
      SPARK_HIVE + KRYO,
      SPARK_SHADED + KRYO,
      "org.apache.hive." + KRYO,
      "com.esotericsoftware.shaded." + STD_INSTANTIATOR,
      SPARK_HIVE + "com.esotericsoftware.shaded." + STD_INSTANTIATOR,
      "org.apache.hive.com.esotericsoftware.shaded." + STD_INSTANTIATOR
    )
  }

  test("parquet-hadoop-bundle") {
    assertLoads(
      "parquet.hadoop.ParquetOutputFormat",
      "parquet.hadoop.ParquetInputFormat"
    )
  }
}
Example 26
Source File: RandomDataGeneratorSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types._


  // Checks, for `dataType`:
  //  - a RandomDataGenerator exists (built with the fixed Some(33), presumably a seed);
  //  - 100 draws contain a null when `nullable`, and no nulls otherwise;
  //  - ten generated values convert to Catalyst's representation without throwing.
  // NOTE(review): the enclosing `class ... extends SparkFunSuite {` line is not present in this
  // excerpt even though its closing brace is; confirm the declaration was not lost.
  def testRandomDataGeneration(dataType: DataType, nullable: Boolean = true): Unit = {
    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType)
    val generator = RandomDataGenerator.forType(dataType, nullable, Some(33)) match {
      case Some(gen) => gen
      case None => fail(s"Random data generator was not defined for $dataType")
    }
    val samples = Iterator.fill(100)(generator())
    if (nullable) assert(samples.contains(null)) else assert(samples.forall(_ != null))
    (1 to 10).foreach { _ => toCatalyst(generator()) }
  }

  // Basic types: one test per (non-decimal atomic type, nullability) pair. The loop's
  // `nullable` must be forwarded to the check; otherwise the "nullable=false" variants would
  // silently run with the default nullable = true and never verify the null-free generator.
  for (
    dataType <- DataTypeTestUtils.atomicTypes;
    nullable <- Seq(true, false)
    if !dataType.isInstanceOf[DecimalType]) {
    test(s"$dataType (nullable=$nullable)") {
      testRandomDataGeneration(dataType, nullable)
    }
  }

  // Array types whose element type has a data generator.
  DataTypeTestUtils.atomicArrayTypes
    .withFilter(arrayType =>
      RandomDataGenerator.forType(arrayType.elementType, arrayType.containsNull).isDefined)
    .foreach { arrayType =>
      test(s"$arrayType") {
        testRandomDataGeneration(arrayType)
      }
    }

  // Atomic types for which RandomDataGenerator can produce values; reused by the map and
  // struct tests below.
  val atomicTypesWithDataGenerators =
    DataTypeTestUtils.atomicTypes.filter(t => RandomDataGenerator.forType(t).isDefined)

  // Complex types: one map test per (key, value) pair of generator-backed atomic types.
  // Scala's BigDecimal.hashCode can lead to OutOfMemoryError on Scala 2.10 (see SI-6173) and
  // Spark can hit NumberFormatException errors when converting certain BigDecimals (SPARK-8802).
  // For these reasons, we don't support generation of maps with decimal keys.
  atomicTypesWithDataGenerators.foreach { keyType =>
    atomicTypesWithDataGenerators
      .withFilter(_ => !keyType.isInstanceOf[DecimalType])
      .foreach { valueType =>
        val mapType = MapType(keyType, valueType)
        test(s"$mapType") {
          testRandomDataGeneration(mapType)
        }
      }
  }

  // Two-column structs over every (first, second) pair of generator-backed atomic types.
  atomicTypesWithDataGenerators.foreach { colOneType =>
    atomicTypesWithDataGenerators.foreach { colTwoType =>
      val structType =
        StructType(StructField("a", colOneType) :: StructField("b", colTwoType) :: Nil)
      test(s"$structType") {
        testRandomDataGeneration(structType)
      }
    }
  }

} 
Example 27
Source File: LogicalPlanSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util._


class LogicalPlanSuite extends SparkFunSuite {
  // Number of times `function` has fired since the last reset.
  private var invocationCount = 0
  // Identity rule on Project nodes that only counts how often it is applied.
  private val function: PartialFunction[LogicalPlan, LogicalPlan] = {
    case p: Project =>
      invocationCount += 1
      p
  }

  private val testRelation = LocalRelation()

  // Resets the counter, applies `function` to `plan` via resolveOperators and returns how many
  // Project nodes it fired on.
  private def countResolutions(plan: LogicalPlan): Int = {
    invocationCount = 0
    plan resolveOperators function
    invocationCount
  }

  test("resolveOperator runs on operators") {
    assert(countResolutions(Project(Nil, testRelation)) === 1)
  }

  test("resolveOperator runs on operators recursively") {
    assert(countResolutions(Project(Nil, Project(Nil, testRelation))) === 2)
  }

  test("resolveOperator skips all ready resolved plans") {
    val plan = Project(Nil, Project(Nil, testRelation))
    plan.foreach(_.setAnalyzed())
    assert(countResolutions(plan) === 0)
  }

  test("resolveOperator skips partially resolved plans") {
    val inner = Project(Nil, testRelation)
    val outer = Project(Nil, inner)
    inner.foreach(_.setAnalyzed())
    assert(countResolutions(outer) === 1)
  }
}
Example 28
Source File: SameResultSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.util._


class SameResultSuite extends SparkFunSuite {
  // Two independently built relations with identical column names and types.
  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
  val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int)

  // Analyzes both plans and fails, printing them side by side, unless `sameResult` between
  // them equals `result`.
  def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = {
    val (left, right) = (a.analyze, b.analyze)
    val actual = left.sameResult(right)
    if (actual != result) {
      val comparison = sideBySide(left.toString, right.toString).mkString("\n")
      fail(s"Plans should return sameResult = $result\n$comparison")
    }
  }

  test("relations") {
    assertSameResult(testRelation, testRelation2)
  }

  test("projections") {
    assertSameResult(testRelation.select('a), testRelation2.select('a))
    assertSameResult(testRelation.select('b), testRelation2.select('b))
    assertSameResult(testRelation.select('a, 'b), testRelation2.select('a, 'b))
    assertSameResult(testRelation.select('b, 'a), testRelation2.select('b, 'a))

    // Differing projection lists (or none at all) must not be considered equivalent.
    assertSameResult(testRelation, testRelation2.select('a), result = false)
    assertSameResult(testRelation.select('b, 'a), testRelation2.select('a, 'b), result = false)
  }

  test("filters") {
    assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b))
  }

  test("sorts") {
    assertSameResult(testRelation.orderBy('a.asc), testRelation2.orderBy('a.asc))
  }
}
Example 29
Source File: NondeterministicSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite

// Evaluates context-dependent expressions through ExpressionEvalHelper.checkEvaluation and
// checks the value each produces under that helper's evaluation setup.
class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper {
  // Expects 0, presumably because the helper evaluates a fresh expression in partition 0 with
  // no prior rows; TODO confirm against ExpressionEvalHelper.
  test("MonotonicallyIncreasingID") {
    checkEvaluation(MonotonicallyIncreasingID(), 0L)
  }

  // Expects partition id 0 in the helper's evaluation context.
  test("SparkPartitionID") {
    checkEvaluation(SparkPartitionID(), 0)
  }

  // Expects the empty string, presumably because no input file is being read here.
  test("InputFileName") {
    checkEvaluation(InputFileName(), "")
  }
}
Example 30
Source File: NullFunctionsSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._

class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  // Invokes `testFunc` once per supported Catalyst type with a non-null sample value.
  def testAllTypes(testFunc: (Any, DataType) => Unit): Unit = {
    val samples: Seq[(Any, DataType)] = Seq(
      (false, BooleanType),
      (1.toByte, ByteType),
      (1.toShort, ShortType),
      (1, IntegerType),
      (1L, LongType),
      (1.0F, FloatType),
      (1.0, DoubleType),
      (Decimal(1.5), DecimalType(2, 1)),
      (new java.sql.Date(10), DateType),
      (new java.sql.Timestamp(10), TimestampType),
      ("abcd", StringType))
    samples.foreach { case (value, tpe) => testFunc(value, tpe) }
  }

  test("isnull and isnotnull") {
    testAllTypes { (value, tpe) =>
      val lit = Literal.create(value, tpe)
      val nullLit = Literal.create(null, tpe)
      checkEvaluation(IsNull(lit), false)
      checkEvaluation(IsNotNull(lit), true)
      checkEvaluation(IsNull(nullLit), true)
      checkEvaluation(IsNotNull(nullLit), false)
    }
  }

  test("IsNaN") {
    checkEvaluation(IsNaN(Literal(Double.NaN)), true)
    checkEvaluation(IsNaN(Literal(Float.NaN)), true)
    checkEvaluation(IsNaN(Literal(math.log(-3))), true)
    checkEvaluation(IsNaN(Literal.create(null, DoubleType)), false)
    checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false)
    checkEvaluation(IsNaN(Literal(Float.MaxValue)), false)
    checkEvaluation(IsNaN(Literal(5.5f)), false)
  }

  test("nanvl") {
    checkEvaluation(NaNvl(Literal(5.0), Literal.create(null, DoubleType)), 5.0)
    checkEvaluation(NaNvl(Literal.create(null, DoubleType), Literal(5.0)), null)
    checkEvaluation(NaNvl(Literal.create(null, DoubleType), Literal(Double.NaN)), null)
    checkEvaluation(NaNvl(Literal(Double.NaN), Literal(5.0)), 5.0)
    checkEvaluation(NaNvl(Literal(Double.NaN), Literal.create(null, DoubleType)), null)
    // NaN != NaN, so the NaN-for-NaN case is checked with isNaN instead of checkEvaluation.
    val nanForNan = NaNvl(Literal(Double.NaN), Literal(Double.NaN)).eval(EmptyRow)
    assert(nanForNan.asInstanceOf[Double].isNaN)
  }

  test("coalesce") {
    testAllTypes { (value, tpe) =>
      val lit = Literal.create(value, tpe)
      val nullLit = Literal.create(null, tpe)
      checkEvaluation(Coalesce(Seq(nullLit)), null)
      checkEvaluation(Coalesce(Seq(lit)), value)
      checkEvaluation(Coalesce(Seq(nullLit, lit)), value)
      checkEvaluation(Coalesce(Seq(nullLit, lit, lit)), value)
      checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value)
    }
  }

  // Nulls and NaNs both count as missing: `mix` has 2 qualifying values, while `nanOnly` and
  // `nullOnly` have 3 each.
  test("AtLeastNNonNulls") {
    val mix = Seq(Literal("x"),
      Literal.create(null, StringType),
      Literal.create(null, DoubleType),
      Literal(Double.NaN),
      Literal(5f))

    val nanOnly = Seq(Literal("x"),
      Literal(10.0),
      Literal(Float.NaN),
      Literal(math.log(-2)),
      Literal(Double.MaxValue))

    val nullOnly = Seq(Literal("x"),
      Literal.create(null, DoubleType),
      Literal.create(null, DecimalType.USER_DEFAULT),
      Literal(Float.MaxValue),
      Literal(false))

    Seq(
      (2, mix, true), (3, mix, false),
      (3, nanOnly, true), (4, nanOnly, false),
      (3, nullOnly, true), (4, nullOnly, false)
    ).foreach { case (n, children, expected) =>
      checkEvaluation(AtLeastNNonNulls(n, children), expected, EmptyRow)
    }
  }
}
Example 31
Source File: DecimalExpressionSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{LongType, DecimalType, Decimal}


class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {

  // The value 10.1 built two ways: parsed from a string, and from unscaled 101 with precision 3
  // and scale 1. A def, so every test gets fresh Decimal instances.
  private def tenPointOneVariants: Seq[Decimal] = Seq(Decimal("10.1"), Decimal(101, 3, 1))

  test("UnscaledValue") {
    tenPointOneVariants.foreach { d =>
      checkEvaluation(UnscaledValue(Literal(d)), 101L)
    }
    checkEvaluation(UnscaledValue(Literal.create(null, DecimalType(2, 1))), null)
  }

  test("MakeDecimal") {
    checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
    checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
  }

  test("PromotePrecision") {
    tenPointOneVariants.foreach { d =>
      checkEvaluation(PromotePrecision(Literal(d)), d)
    }
    checkEvaluation(PromotePrecision(Literal.create(null, DecimalType(2, 1))), null)
  }

  // Scale 0 rounds to 10, scales 1-2 keep the value, and scale 3 needs precision 5, so
  // DecimalType(4, 3) overflows to null.
  test("CheckOverflow") {
    tenPointOneVariants.foreach { d =>
      checkEvaluation(CheckOverflow(Literal(d), DecimalType(4, 0)), Decimal("10"))
      checkEvaluation(CheckOverflow(Literal(d), DecimalType(4, 1)), d)
      checkEvaluation(CheckOverflow(Literal(d), DecimalType(4, 2)), d)
      checkEvaluation(CheckOverflow(Literal(d), DecimalType(4, 3)), null)
    }

    checkEvaluation(CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2)), null)
  }

}
Example 32
Source File: CollectionFunctionsSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._


/**
 * Evaluation tests for the collection expressions [[Size]], [[SortArray]] and [[ArrayContains]].
 */
class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("Array and Map Size") {
    val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
    val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
    val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType))

    checkEvaluation(Size(a0), 3)
    checkEvaluation(Size(a1), 0)
    checkEvaluation(Size(a2), 2)

    val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType))
    val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
    val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType))

    checkEvaluation(Size(m0), 2)
    checkEvaluation(Size(m1), 0)
    checkEvaluation(Size(m2), 1)

    // Null collections: wrap the null literal in Size so the expression under test is
    // exercised, not just the literal itself.
    checkEvaluation(Size(Literal.create(null, MapType(StringType, StringType))), null)
    checkEvaluation(Size(Literal.create(null, ArrayType(StringType))), null)
  }

  test("Sort Array") {
    val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
    val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
    val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType))
    val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType))
    val a4 = Literal.create(Seq(null, null), ArrayType(NullType))

    // The single-argument constructor sorts ascending; nulls sort first ascending, last descending.
    checkEvaluation(new SortArray(a0), Seq(1, 2, 3))
    checkEvaluation(new SortArray(a1), Seq[Integer]())
    checkEvaluation(new SortArray(a2), Seq("a", "b"))
    checkEvaluation(new SortArray(a3), Seq(null, "a", "b"))
    checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3))
    checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]())
    checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b"))
    checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b"))
    checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1))
    checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]())
    checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a"))
    checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null))

    // A null array input: evaluate it through SortArray rather than as a bare literal.
    checkEvaluation(new SortArray(Literal.create(null, ArrayType(StringType))), null)
    checkEvaluation(new SortArray(a4), Seq(null, null))

    val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
    val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS)

    checkEvaluation(new SortArray(arrayStruct), Seq(create_row(1), create_row(2)))
  }

  test("Array contains") {
    val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
    val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
    val a2 = Literal.create(Seq(null), ArrayType(LongType))
    val a3 = Literal.create(null, ArrayType(StringType))

    checkEvaluation(ArrayContains(a0, Literal(1)), true)
    checkEvaluation(ArrayContains(a0, Literal(0)), false)
    checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null)

    // A miss in an array that contains a null is unknown (null), not false.
    checkEvaluation(ArrayContains(a1, Literal("")), true)
    checkEvaluation(ArrayContains(a1, Literal("a")), null)
    checkEvaluation(ArrayContains(a1, Literal.create(null, StringType)), null)

    checkEvaluation(ArrayContains(a2, Literal(1L)), null)
    checkEvaluation(ArrayContains(a2, Literal.create(null, LongType)), null)

    checkEvaluation(ArrayContains(a3, Literal("")), null)
    checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
  }
}
Example 33
Source File: RandomSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.scalatest.Matchers._

import org.apache.spark.SparkFunSuite

/** Seeded [[Rand]] / [[Randn]] must reproduce known values (within a 0.001 tolerance). */
class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("random") {
    val seed = 30L
    checkDoubleEvaluation(Rand(seed), 0.31429268272540556 +- 0.001)
    checkDoubleEvaluation(Randn(seed), -0.4798519469521663 +- 0.001)
  }

  test("SPARK-9127 codegen with long seed") {
    // A seed beyond Int range, as a regression check for the generated code.
    val seed = 5419823303878592871L
    checkDoubleEvaluation(Rand(seed), 0.2304755080444375 +- 0.001)
    checkDoubleEvaluation(Randn(seed), -1.2824262718225607 +- 0.001)
  }
}
Example 34
Source File: MiscFunctionsSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.codec.digest.DigestUtils

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType}

/** Evaluation tests for the digest/checksum expressions md5, sha1, sha2 and crc32. */
class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("md5") {
    checkEvaluation(Md5(Literal("ABC".getBytes)), "902fbdd2b1df0c4f70b4a5d23525e932")
    checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
      "6ac1e56bc78f031059be7be854522c4c")
    checkEvaluation(Md5(Literal.create(null, BinaryType)), null)
    checkConsistencyBetweenInterpretedAndCodegen(Md5, BinaryType)
  }

  test("sha1") {
    checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")
    checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
      "5d211bad8f4ee70e16c7d343a838fc344a1ed961")
    checkEvaluation(Sha1(Literal.create(null, BinaryType)), null)
    checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709")
    checkConsistencyBetweenInterpretedAndCodegen(Sha1, BinaryType)
  }

  test("sha2") {
    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC"))
    checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)),
      DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6)))
    // Unsupported bit length: use a non-null input so the null result is caused by the
    // bit length, not by null propagation.
    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(1024)), null)
    // Null input and/or null bit length propagate to a null result.
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null)
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null)
    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null)
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null)
  }

  test("crc32") {
    checkEvaluation(Crc32(Literal("ABC".getBytes)), 2743272264L)
    checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
      2180413220L)
    checkEvaluation(Crc32(Literal.create(null, BinaryType)), null)
    checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
  }
}
Example 35
Source File: AttributeSetSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.IntegerType

/**
 * Tests for [[AttributeSet]]. Membership is decided by expression id, so references that
 * differ only in name case (same ExprId) are interchangeable. `fakeA` shares a name with
 * `aLower` but not its id.
 */
class AttributeSetSuite extends SparkFunSuite {

  val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
  val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
  val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
  val aSet = AttributeSet(aLower :: Nil)

  val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
  val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
  val bSet = AttributeSet(bUpper :: Nil)

  val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)

  test("sanity check") {
    // The references themselves are unequal even though their ids match.
    assert(aUpper != aLower)
    assert(bUpper != bLower)
  }

  test("checks by id not name") {
    Seq(aUpper, aLower).foreach(attr => assert(aSet.contains(attr) === true))
    Seq(fakeA, bUpper, bLower).foreach(attr => assert(aSet.contains(attr) === false))
  }

  test("++ preserves AttributeSet")  {
    val union = aSet ++ bSet
    assert(union.contains(aUpper) === true)
    assert(union.contains(aLower) === true)
  }

  test("extracts all references references") {
    val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()) :: Nil)
    Seq(aUpper, aLower, bUpper, bLower).foreach(attr => assert(addSet.contains(attr)))
  }

  test("dedups attributes") {
    val deduped = AttributeSet(aUpper :: aLower :: Nil)
    assert(deduped.size === 1)
  }

  test("subset") {
    assert(aSet.subsetOf(aAndBSet) === true)
    assert(aAndBSet.subsetOf(aSet) === false)
  }

  test("equality") {
    // Inequality must hold in both directions.
    Seq(aSet -> aAndBSet, aSet -> bSet).foreach { case (left, right) =>
      assert(left != right)
      assert(right != left)
    }

    assert(aSet == aSet)
    assert(aSet == AttributeSet(aUpper :: Nil))
  }
}
Example 36
Source File: CodeFormatterSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.SparkFunSuite


/** Tests for [[CodeFormatter]], which indents and line-numbers generated source code. */
class CodeFormatterSuite extends SparkFunSuite {

  /**
   * Registers a test called `name` asserting that formatting `input` yields `expected`.
   * Leading and trailing whitespace on both sides is ignored.
   */
  def testCase(name: String)(input: String)(expected: String): Unit = {
    test(name) {
      assert(CodeFormatter.format(input).trim === expected.trim)
    }
  }

  // The expected text previously held a garbled fragment ("c)") that could never match this
  // input; it is restored to the formatted form of the three-line class.
  testCase("basic example") {
    """class A {
      |blahblah;
      |}""".stripMargin
  }{
    """
      |/* 001 */ class A {
      |/* 002 */   blahblah;
      |/* 003 */ }
    """.stripMargin
  }
}
Example 37
Source File: GenerateUnsafeRowJoinerSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions.codegen

import scala.util.Random

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types._


/**
 * Tests for [[GenerateUnsafeRowJoiner]]: joining two UnsafeRows must yield a row whose fields
 * are the fields of the first row followed by those of the second.
 */
class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {

  private val fixed = Seq(IntegerType)
  private val variable = Seq(IntegerType, StringType)

  test("simple fixed width types") {
    testConcat(0, 0, fixed)
    testConcat(0, 1, fixed)
    testConcat(1, 0, fixed)
    testConcat(64, 0, fixed)
    testConcat(0, 64, fixed)
    testConcat(64, 64, fixed)
  }

  test("randomized fix width types") {
    (0 until 20).foreach { _ =>
      testConcatOnce(Random.nextInt(100), Random.nextInt(100), fixed)
    }
  }

  test("simple variable width types") {
    testConcat(0, 0, variable)
    testConcat(0, 1, variable)
    testConcat(1, 0, variable)
    testConcat(64, 0, variable)
    testConcat(0, 64, variable)
    testConcat(64, 64, variable)
  }

  test("randomized variable width types") {
    (0 until 10).foreach { _ =>
      testConcatOnce(Random.nextInt(100), Random.nextInt(100), variable)
    }
  }

  /** Runs [[testConcatOnce]] ten times, each with freshly generated random schemas. */
  private def testConcat(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]): Unit = {
    (0 until 10).foreach { _ =>
      testConcatOnce(numFields1, numFields2, candidateTypes)
    }
  }

  /**
   * Generates two random schemas of `numFields1` / `numFields2` fields drawn from
   * `candidateTypes`, builds one random UnsafeRow for each, joins them, and checks every
   * output field against the field it came from.
   */
  private def testConcatOnce(
      numFields1: Int,
      numFields2: Int,
      candidateTypes: Seq[DataType]): Unit = {
    info(s"schema size $numFields1, $numFields2")
    val schema1 = RandomDataGenerator.randomSchema(numFields1, candidateTypes)
    val schema2 = RandomDataGenerator.randomSchema(numFields2, candidateTypes)

    // Create the converters needed to convert from external row to internal row and to UnsafeRows.
    val internalConverter1 = CatalystTypeConverters.createToCatalystConverter(schema1)
    val internalConverter2 = CatalystTypeConverters.createToCatalystConverter(schema2)
    val converter1 = UnsafeProjection.create(schema1)
    val converter2 = UnsafeProjection.create(schema2)

    // Create the input rows, convert them into UnsafeRows.
    val extRow1 = RandomDataGenerator.forType(schema1, nullable = false).get.apply()
    val extRow2 = RandomDataGenerator.forType(schema2, nullable = false).get.apply()
    val row1 = converter1.apply(internalConverter1.apply(extRow1).asInstanceOf[InternalRow])
    val row2 = converter2.apply(internalConverter2.apply(extRow2).asInstanceOf[InternalRow])

    // Run the joiner.
    val mergedSchema = StructType(schema1 ++ schema2)
    val joiner = GenerateUnsafeRowJoiner.create(schema1, schema2)
    val output = joiner.join(row1, row2)

    // Output field i comes from row1 at i, or from row2 at i - schema1.size.
    mergedSchema.zipWithIndex.foreach { case (field, i) =>
      val (source, j) = if (i < schema1.size) (row1, i) else (row2, i - schema1.size)
      assert(output.isNullAt(i) === source.isNullAt(j))
      if (!output.isNullAt(i)) {
        assert(output.get(i, field.dataType) === source.get(j, field.dataType))
      }
    }
  }

}
Example 38
Source File: EncoderErrorMessageSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.encoders

import scala.reflect.ClassTag

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Encoders

// Fixture: a plain (non-case) class. EncoderErrorMessageSuite expects deriving an encoder that
// reaches it to fail with UnsupportedOperationException.
class NonEncodable(i: Int)

// NonEncodable as a direct field.
case class ComplexNonEncodable1(name1: NonEncodable)

// NonEncodable nested one level deeper, through ComplexNonEncodable1.
case class ComplexNonEncodable2(name2: ComplexNonEncodable1)

// NonEncodable as an Option's value type.
case class ComplexNonEncodable3(name3: Option[NonEncodable])

// NonEncodable as an Array's element type.
case class ComplexNonEncodable4(name4: Array[NonEncodable])

// NonEncodable inside an Array inside an Option.
case class ComplexNonEncodable5(name5: Option[Array[NonEncodable]])

/**
 * Checks the exceptions raised when an encoder cannot be built. For unsupported nested types,
 * the message must describe the full path from the root class to the offending type.
 */
class EncoderErrorMessageSuite extends SparkFunSuite {

  // Note: we also test error messages for encoders for private classes in JavaDatasetSuite.
  // That is done in Java because Scala cannot create truly private classes.

  test("primitive types in encoders using Kryo serialization") {
    intercept[UnsupportedOperationException] { Encoders.kryo[Int] }
    intercept[UnsupportedOperationException] { Encoders.kryo[Long] }
    intercept[UnsupportedOperationException] { Encoders.kryo[Char] }
  }

  test("primitive types in encoders using Java serialization") {
    intercept[UnsupportedOperationException] { Encoders.javaSerialization[Int] }
    intercept[UnsupportedOperationException] { Encoders.javaSerialization[Long] }
    intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] }
  }

  test("nice error message for missing encoder") {
    val errorMsg1 =
      intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable1]).getMessage
    assert(errorMsg1.contains(
      s"""root class: "${clsName[ComplexNonEncodable1]}""""))
    assert(errorMsg1.contains(
      s"""field (class: "${clsName[NonEncodable]}", name: "name1")"""))

    val errorMsg2 =
      intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable2]).getMessage
    assert(errorMsg2.contains(
      s"""root class: "${clsName[ComplexNonEncodable2]}""""))
    assert(errorMsg2.contains(
      s"""field (class: "${clsName[ComplexNonEncodable1]}", name: "name2")"""))
    // The nested field of ComplexNonEncodable1 must also be in errorMsg2's path (this
    // previously re-checked errorMsg1 by mistake).
    assert(errorMsg2.contains(
      s"""field (class: "${clsName[NonEncodable]}", name: "name1")"""))

    val errorMsg3 =
      intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable3]).getMessage
    assert(errorMsg3.contains(
      s"""root class: "${clsName[ComplexNonEncodable3]}""""))
    assert(errorMsg3.contains(
      s"""field (class: "scala.Option", name: "name3")"""))
    assert(errorMsg3.contains(
      s"""option value class: "${clsName[NonEncodable]}""""))

    val errorMsg4 =
      intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable4]).getMessage
    assert(errorMsg4.contains(
      s"""root class: "${clsName[ComplexNonEncodable4]}""""))
    assert(errorMsg4.contains(
      s"""field (class: "scala.Array", name: "name4")"""))
    assert(errorMsg4.contains(
      s"""array element class: "${clsName[NonEncodable]}""""))

    val errorMsg5 =
      intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable5]).getMessage
    assert(errorMsg5.contains(
      s"""root class: "${clsName[ComplexNonEncodable5]}""""))
    assert(errorMsg5.contains(
      s"""field (class: "scala.Option", name: "name5")"""))
    assert(errorMsg5.contains(
      s"""option value class: "scala.Array""""))
    assert(errorMsg5.contains(
      s"""array element class: "${clsName[NonEncodable]}""""))
  }

  /** Runtime (JVM) class name of `T`, as it appears in encoder error messages. */
  private def clsName[T : ClassTag]: String = implicitly[ClassTag[T]].runtimeClass.getName
}
Example 39
Source File: RuleExecutorSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.trees

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal}
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}

/** Checks how [[RuleExecutor]] applies a batch under the `Once` and `FixedPoint` strategies. */
class RuleExecutorSuite extends SparkFunSuite {

  /** Rewrites each integer literal `n > 0` to `n - 1`; everything else is left untouched. */
  object DecrementLiterals extends Rule[Expression] {
    def apply(e: Expression): Expression = e.transform {
      case IntegerLiteral(n) if n > 0 => Literal(n - 1)
    }
  }

  test("only once") {
    object ApplyOnce extends RuleExecutor[Expression] {
      val batches = Seq(Batch("once", Once, DecrementLiterals))
    }

    // A single application: 10 -> 9.
    assert(ApplyOnce.execute(Literal(10)) === Literal(9))
  }

  test("to fixed point") {
    object ToFixedPoint extends RuleExecutor[Expression] {
      val batches = Seq(Batch("fixedPoint", FixedPoint(100), DecrementLiterals))
    }

    // 100 iterations are enough to reach 0, where the rule stops changing the plan.
    assert(ToFixedPoint.execute(Literal(10)) === Literal(0))
  }

  test("to maxIterations") {
    object ToMaxIterations extends RuleExecutor[Expression] {
      val batches = Seq(Batch("fixedPoint", FixedPoint(10), DecrementLiterals))
    }

    // Capped at 10 iterations, so 100 only gets down to 90.
    assert(ToMaxIterations.execute(Literal(100)) === Literal(90))
  }
}
Example 40
Source File: PartitioningSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning}

/** Compatibility and guarantee semantics of [[HashPartitioning]]. */
class PartitioningSuite extends SparkFunSuite {
  test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") {
    // Hash code of the row produced by projecting the partitioning's expressions, in order.
    def computeHashCode(partitioning: HashPartitioning): Int = {
      val projection = new InterpretedMutableProjection(partitioning.expressions, Seq.empty)
      projection(InternalRow.empty).hashCode()
    }

    val expressions = Seq(Literal(2), Literal(3))
    // Same _set_ of hash expressions, different orderings.
    val forward = HashPartitioning(expressions, 100)
    val reversed = HashPartitioning(expressions.reverse, 100)

    // Not equal, yet both satisfy the same clustered distribution.
    assert(forward != reversed)
    val distribution = ClusteredDistribution(expressions)
    assert(forward.satisfies(distribution))
    assert(reversed.satisfies(distribution))

    // They hash the same input row differently...
    assert(computeHashCode(forward) != computeHashCode(reversed))
    // ...so neither may be treated as compatible with, or guaranteeing, the other.
    assert(!forward.compatibleWith(reversed))
    assert(!reversed.compatibleWith(forward))
    assert(!forward.guarantees(reversed))
    assert(!reversed.guarantees(forward))

    // Guard against the methods trivially returning false: a partitioning must be compatible
    // with, and guarantee, itself.
    assert(forward === forward)
    assert(forward.guarantees(forward))
    assert(forward.compatibleWith(forward))
  }
}
Example 41
Source File: NumberConverterSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.NumberConverter.convert
import org.apache.spark.unsafe.types.UTF8String

/** Tests for [[NumberConverter.convert]] (base conversion of numeric strings). */
class NumberConverterSuite extends SparkFunSuite {

  /** Asserts that `n`, read in `fromBase`, renders as `expected` in `toBase`. */
  private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = {
    val inputBytes = UTF8String.fromString(n).getBytes
    val expectedUtf8 = UTF8String.fromString(expected)
    assert(convert(inputBytes, fromBase, toBase) === expectedUtf8)
  }

  test("convert") {
    // (input, fromBase, toBase, expected)
    Seq(
      ("3", 10, 2, "11"),
      // a negative target base keeps the sign; a positive one gives the 64-bit unsigned form
      ("-15", 10, -16, "-F"),
      ("-15", 10, 16, "FFFFFFFFFFFFFFF1"),
      ("big", 36, 16, "3A48"),
      ("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF"),
      // only the leading "11" is converted
      ("11abc", 10, 16, "B")
    ).foreach { case (n, from, to, expected) => checkConv(n, from, to, expected) }
  }

}
Example 42
Source File: MetadataSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import org.json4s.jackson.JsonMethods.parse

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{MetadataBuilder, Metadata}

/** Tests for [[MetadataBuilder]] / [[Metadata]]: typed getters and the JSON round trip. */
class MetadataSuite extends SparkFunSuite {

  val baseMetadata = new MetadataBuilder()
    .putString("purpose", "ml")
    .putBoolean("isBase", true)
    .build()

  val summary = new MetadataBuilder()
    .putLong("numFeatures", 10L)
    .build()

  val age = new MetadataBuilder()
    .putString("name", "age")
    .putLong("index", 1L)
    .putBoolean("categorical", false)
    .putDouble("average", 45.0)
    .build()

  val gender = new MetadataBuilder()
    .putString("name", "gender")
    .putLong("index", 5)
    .putBoolean("categorical", true)
    .putStringArray("categories", Array("male", "female"))
    .build()

  // Starts from baseMetadata, overrides "isBase", and adds nested metadata plus every array kind.
  val metadata = new MetadataBuilder()
    .withMetadata(baseMetadata)
    .putBoolean("isBase", false) // overwrite an existing key
    .putMetadata("summary", summary)
    .putLongArray("long[]", Array(0L, 1L))
    .putDoubleArray("double[]", Array(3.0, 4.0))
    .putBooleanArray("boolean[]", Array(true, false))
    .putMetadataArray("features", Array(age, gender))
    .build()

  test("metadata builder and getters") {
    // Scalar entries of `age`.
    assert(age.contains("summary") === false)
    Seq("index", "average", "categorical", "name").foreach { key =>
      assert(age.contains(key) === true)
    }
    assert(age.getLong("index") === 1L)
    assert(age.getDouble("average") === 45.0)
    assert(age.getBoolean("categorical") === false)
    assert(age.getString("name") === "age")

    // Inherited, overwritten, nested and array entries of `metadata`.
    Seq("purpose", "isBase", "summary", "long[]", "double[]", "boolean[]", "features").foreach {
      key => assert(metadata.contains(key) === true)
    }
    assert(metadata.getString("purpose") === "ml")
    assert(metadata.getBoolean("isBase") === false)
    assert(metadata.getMetadata("summary") === summary)
    assert(metadata.getLongArray("long[]").toSeq === Seq(0L, 1L))
    assert(metadata.getDoubleArray("double[]").toSeq === Seq(3.0, 4.0))
    assert(metadata.getBooleanArray("boolean[]").toSeq === Seq(true, false))
    assert(metadata.getMetadataArray("features").toSeq === Seq(age, gender))

    // String array entry of `gender`.
    assert(gender.contains("categories") === true)
    assert(gender.getStringArray("categories").toSeq === Seq("male", "female"))
  }

  test("metadata json conversion") {
    val json = metadata.json
    withClue("toJson must produce a valid JSON string") {
      parse(json)
    }
    // The round trip must preserve both equality and hash code.
    val roundTripped = Metadata.fromJson(json)
    assert(roundTripped === metadata)
    assert(roundTripped.## === metadata.##)
  }
}
Example 43
Source File: StringUtilsSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.StringUtils._

/** Tests for [[StringUtils.escapeLikeRegex]] (SQL LIKE pattern -> Java regex). */
class StringUtilsSuite extends SparkFunSuite {

  test("escapeLikeRegex") {
    // Literal characters are \Q..\E-quoted, `_` becomes `.`, `%` becomes `.*`, and a
    // backslash makes the following wildcard literal.
    val cases = Seq(
      "abdef" -> "(?s)\\Qa\\E\\Qb\\E\\Qd\\E\\Qe\\E\\Qf\\E",
      "a\\__b" -> "(?s)\\Qa\\E_.\\Qb\\E",
      "a_%b" -> "(?s)\\Qa\\E..*\\Qb\\E",
      "a%\\%b" -> "(?s)\\Qa\\E.*%\\Qb\\E",
      "a%" -> "(?s)\\Qa\\E.*",
      "**" -> "(?s)\\Q*\\E\\Q*\\E",
      "a_b" -> "(?s)\\Qa\\E.\\Qb\\E")

    for ((likePattern, expectedRegex) <- cases) {
      assert(escapeLikeRegex(likePattern) === expectedRegex)
    }
  }
}
Example 44
Source File: CatalystTypeConvertersSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

/** Null and Option handling of [[CatalystTypeConverters]] for the simple atomic types. */
class CatalystTypeConvertersSuite extends SparkFunSuite {

  private val simpleTypes: Seq[DataType] = Seq(
    StringType, DateType, BooleanType,
    ByteType, ShortType, IntegerType, LongType,
    FloatType, DoubleType,
    DecimalType.SYSTEM_DEFAULT, DecimalType.USER_DEFAULT)

  test("null handling in rows") {
    // One field per simple type, named after the type's class.
    val fields = simpleTypes.map(dt => StructField(dt.getClass.getName, dt))
    val schema = StructType(fields)
    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
    val toScala = CatalystTypeConverters.createToScalaConverter(schema)

    // An all-null row must survive the Scala -> Catalyst -> Scala round trip unchanged.
    val allNulls = Row.fromSeq(Seq.fill(simpleTypes.length)(null))
    assert(toScala(toCatalyst(allNulls)) === allNulls)
  }

  test("null handling for individual values") {
    simpleTypes.foreach { dt =>
      assert(CatalystTypeConverters.createToScalaConverter(dt)(null) === null)
    }
  }

  test("option handling in convertToCatalyst") {
    // convertToCatalyst doesn't handle unboxing from Options. This is inconsistent with
    // createToCatalystConverter but it may not actually matter as this is only called internally
    // in a handful of places where we don't expect to receive Options.
    assert(CatalystTypeConverters.convertToCatalyst(Some(123)) === Some(123))
  }

  test("option handling in createToCatalystConverter") {
    // Unlike convertToCatalyst, the typed converter unwraps the Option.
    val toCatalystInt = CatalystTypeConverters.createToCatalystConverter(IntegerType)
    assert(toCatalystInt(Some(123)) === 123)
  }
}
Example 45
Source File: SQLContextSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.{SharedSparkContext, SparkFunSuite}

/** Tests for SQLContext.getOrCreate, the active-context override, and session isolation. */
class SQLContextSuite extends SparkFunSuite with SharedSparkContext {

  test("getOrCreate instantiates SQLContext") {
    val sqlContext = SQLContext.getOrCreate(sc)
    assert(sqlContext != null, "SQLContext.getOrCreate returned null")
    assert(SQLContext.getOrCreate(sc).eq(sqlContext),
      "SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate")
  }

  test("getOrCreate return the original SQLContext") {
    val sqlContext = SQLContext.getOrCreate(sc)
    val newSession = sqlContext.newSession()
    assert(SQLContext.getOrCreate(sc).eq(sqlContext),
      "SQLContext.getOrCreate after explicitly created SQLContext did not return the context")
    SQLContext.setActive(newSession)
    try {
      assert(SQLContext.getOrCreate(sc).eq(newSession),
        "SQLContext.getOrCreate after explicitly setActive() did not return the active context")
    } finally {
      // setActive changes what getOrCreate returns (asserted above). Undo it so later tests
      // sharing this SparkContext don't silently receive `newSession`.
      SQLContext.clearActive()
    }
  }

  test("Sessions of SQLContext") {
    val sqlContext = SQLContext.getOrCreate(sc)
    val session1 = sqlContext.newSession()
    val session2 = sqlContext.newSession()

    // all have the default configurations
    val key = SQLConf.SHUFFLE_PARTITIONS.key
    assert(session1.getConf(key) === session2.getConf(key))
    session1.setConf(key, "1")
    session2.setConf(key, "2")
    assert(session1.getConf(key) === "1")
    assert(session2.getConf(key) === "2")

    // temporary table should not be shared
    val df = session1.range(10)
    df.registerTempTable("test1")
    assert(session1.tableNames().contains("test1"))
    assert(!session2.tableNames().contains("test1"))

    // UDF should not be shared
    def myadd(a: Int, b: Int): Int = a + b
    session1.udf.register[Int, Int, Int]("myadd", myadd)
    session1.sql("select myadd(1, 2)").explain()
    intercept[AnalysisException] {
      session2.sql("select myadd(1, 2)").explain()
    }
  }

  test("SPARK-13390: createDataFrame(java.util.List[_],Class[_]) NotSerializableException") {
    val rows = new java.util.ArrayList[IntJavaBean]()
    rows.add(new IntJavaBean(1))
    val sqlContext = SQLContext.getOrCreate(sc)
    // Without the fix for SPARK-13390, this will throw NotSerializableException
    sqlContext.createDataFrame(rows, classOf[IntJavaBean]).groupBy("int").count().collect()
  }
}

/**
 * Minimal JavaBean with one `int` property exposed through `getInt()` / `setInt(i)`. It is the
 * bean class passed to `createDataFrame(java.util.List[_], Class[_])` in the SPARK-13390 test,
 * which groups by the column named "int".
 */
class IntJavaBean(private var i: Int) extends Serializable {

  // Bean-style getter; the parentheses are kept, matching the Java getter shape.
  def getInt(): Int = i

  def setInt(i: Int): Unit = {
    this.i = i
  }
} 
Example 46
Source File: LocalNodeTest.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.local

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType}


/**
 * Base class for local-node tests. Provides a shared [[SQLConf]], a few reusable attribute
 * lists, and a helper that binds unresolved attributes to a node's output by name.
 */
class LocalNodeTest extends SparkFunSuite {

  protected val conf: SQLConf = new SQLConf

  // (k: int, v: int)
  protected val kvIntAttributes = Seq(
    AttributeReference("k", IntegerType)(),
    AttributeReference("v", IntegerType)())

  // (id1: int, name: string)
  protected val joinNameAttributes = Seq(
    AttributeReference("id1", IntegerType)(),
    AttributeReference("name", StringType)())

  // (id2: int, nickname: string)
  protected val joinNicknameAttributes = Seq(
    AttributeReference("id2", IntegerType)(),
    AttributeReference("nickname", StringType)())

  /**
   * Replaces every single-part [[UnresolvedAttribute]] in `expressions` with the attribute of
   * the same name from `localNode.output`. Throws (via `sys.error`) when a name is not found.
   * Requires all of `localNode`'s own expressions to be resolved already.
   */
  protected def resolveExpressions(
      expressions: Seq[Expression],
      localNode: LocalNode): Seq[Expression] = {
    require(localNode.expressions.forall(_.resolved))
    val attributesByName = localNode.output.map(attr => attr.name -> attr).toMap
    expressions.map(_.transformUp {
      case UnresolvedAttribute(Seq(name)) =>
        attributesByName.getOrElse(name,
          sys.error(s"Invalid Test: Cannot resolve $name given input $attributesByName"))
    })
  }

}
Example 47
Source File: CoGroupedIteratorSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper

/**
 * Tests for [[CoGroupedIterator]]: two [[GroupedIterator]]s keyed on their first (int) column are
 * merged into (key, left group, right group) triples. A key present on only one side gets an
 * empty group on the other.
 */
class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("basic") {
    val leftInput = Seq(create_row(1, "a"), create_row(1, "b"), create_row(2, "c")).iterator
    val rightInput = Seq(create_row(1, 2L), create_row(2, 3L), create_row(3, 4L)).iterator
    val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string))
    val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long))
    val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int))

    // NOTE(review): Iterator.toSeq may be lazy (a Stream on Scala 2.10/2.11), so the `==` below
    // may be what actually drives these iterators. Keep this shape when editing.
    val result = cogrouped.map {
      case (key, leftData, rightData) =>
        assert(key.numFields == 1)
        (key.getInt(0), leftData.toSeq, rightData.toSeq)
    }.toSeq
    // Key 3 exists only on the right, so its left group is empty.
    assert(result ==
      (1,
        Seq(create_row(1, "a"), create_row(1, "b")),
        Seq(create_row(1, 2L))) ::
      (2,
        Seq(create_row(2, "c")),
        Seq(create_row(2, 3L))) ::
      (3,
        Seq.empty,
        Seq(create_row(3, 4L))) ::
      Nil
    )
  }

  // Disjoint keys on each side; the regression is about GroupedIterator.hasNext having side
  // effects, so CoGroupedIterator must not call it more often than it can tolerate.
  test("SPARK-11393: respect the fact that GroupedIterator.hasNext is not idempotent") {
    val leftInput = Seq(create_row(2, "a")).iterator
    val rightInput = Seq(create_row(1, 2L)).iterator
    val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string))
    val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long))
    val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int))

    val result = cogrouped.map {
      case (key, leftData, rightData) =>
        assert(key.numFields == 1)
        (key.getInt(0), leftData.toSeq, rightData.toSeq)
    }.toSeq

    // Keys come out in ascending order, each with an empty group on the side that lacks it.
    assert(result ==
      (1,
        Seq.empty,
        Seq(create_row(1, 2L))) ::
      (2,
        Seq(create_row(2, "a")),
        Seq.empty) ::
      Nil
    )
  }
} 
Example 48
Source File: SQLExecutionSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import java.util.Properties

import scala.collection.parallel.CompositeThrowable

import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.sql.SQLContext

class SQLExecutionSuite extends SparkFunSuite {

  test("concurrent query execution (SPARK-10548)") {
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("test")

    // Phase 1: BadSparkContext gives child threads a Properties backed by the parent's
    // (`new Properties(parent)`), so the concurrent query is expected to fail with an
    // IllegalArgumentException that mentions the execution-id key.
    val buggyContext = new BadSparkContext(sparkConf)
    try {
      testConcurrentQueryExecution(buggyContext)
      fail("unable to reproduce SPARK-10548")
    } catch {
      case expected: IllegalArgumentException =>
        assert(expected.getMessage.contains(SQLExecution.EXECUTION_ID_KEY))
    } finally {
      buggyContext.stop()
    }

    // Phase 2: the same scenario on a regular SparkContext must succeed.
    val fixedContext = new SparkContext(sparkConf)
    try testConcurrentQueryExecution(fixedContext) finally fixedContext.stop()
  }

  
/**
 * A SparkContext whose thread-local properties are inherited by child threads as a
 * `Properties` that falls back to the parent's instance (its defaults), rather than as a copy.
 * Later changes in the parent therefore stay visible to the child.
 */
private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) {
  protected[spark] override val localProperties = new InheritableThreadLocal[Properties] {
    // Each thread without an inherited value starts from an empty Properties.
    override protected def initialValue(): Properties = new Properties()

    // `parent` becomes the defaults of the child's Properties: lookups fall through to it.
    override protected def childValue(parent: Properties): Properties = {
      new Properties(parent)
    }
  }
}
Example 49
Source File: NullableColumnBuilderSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.columnar

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow}
import org.apache.spark.sql.types._

/**
 * A [[BasicColumnBuilder]] with [[NullableColumnBuilder]] mixed in, backed by a
 * `NoopColumnStats` instance, for testing nullable column building in isolation.
 */
class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType])
  extends BasicColumnBuilder[JvmType](new NoopColumnStats, columnType)
  with NullableColumnBuilder

object TestNullableColumnBuilder {

  /**
   * Creates a builder for `columnType` and initializes it before returning it.
   *
   * @param columnType  column type the builder will append
   * @param initialSize value passed to `initialize` (defaults to 0)
   */
  def apply[JvmType](
      columnType: ColumnType[JvmType],
      initialSize: Int = 0): TestNullableColumnBuilder[JvmType] = {
    val columnBuilder = new TestNullableColumnBuilder[JvmType](columnType)
    columnBuilder.initialize(initialSize)
    columnBuilder
  }
}

/** Tests for [[NullableColumnBuilder]] across every supported column type. */
class NullableColumnBuilderSuite extends SparkFunSuite {
  import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._

  // Register the same group of tests once for each of these column types.
  Seq(
    BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
    STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
    STRUCT(StructType(StructField("a", StringType) :: Nil)),
    ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType))
  ).foreach(testNullableColumnBuilder(_))

  /**
   * Registers the empty-column, buffer-growth and null-value tests for `columnType`.
   *
   * @param columnType the column type under test; its simple class name prefixes the test names
   */
  def testNullableColumnBuilder[JvmType](
      columnType: ColumnType[JvmType]): Unit = {

    val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
    val dataType = columnType.dataType
    val proj = UnsafeProjection.create(Array[DataType](dataType))
    val converter = CatalystTypeConverters.createToScalaConverter(dataType)

    test(s"$typeName column builder: empty column") {
      val buffer = TestNullableColumnBuilder(columnType).build()

      // An empty column holds just a zero null count and nothing else.
      assertResult(0, "Wrong null count")(buffer.getInt())
      assert(!buffer.hasRemaining)
    }

    test(s"$typeName column builder: buffer size auto growth") {
      val columnBuilder = TestNullableColumnBuilder(columnType)
      val randomRow = makeRandomRow(columnType)

      // The builder starts with initial size 0, so every append has to grow the buffer.
      for (_ <- 0 until 4) {
        columnBuilder.appendFrom(proj(randomRow), 0)
      }

      assertResult(0, "Wrong null count")(columnBuilder.build().getInt())
    }

    test(s"$typeName column builder: null values") {
      val columnBuilder = TestNullableColumnBuilder(columnType)
      val randomRow = makeRandomRow(columnType)
      val nullRow = makeNullRow(1)

      // Values land at even positions (0, 2, 4, 6) and nulls at odd ones (1, 3, 5, 7).
      for (_ <- 0 until 4) {
        columnBuilder.appendFrom(proj(randomRow), 0)
        columnBuilder.appendFrom(proj(nullRow), 0)
      }

      val buffer = columnBuilder.build()

      // The buffer begins with the null count, followed by the position of each null...
      assertResult(4, "Wrong null count")(buffer.getInt())
      for (nullPos <- 1 to 7 by 2) {
        assertResult(nullPos, "Wrong null position")(buffer.getInt())
      }

      // ...and then holds exactly the four non-null values.
      val extracted = new GenericMutableRow(new Array[Any](1))
      for (_ <- 0 until 4) {
        columnType.extract(buffer, extracted, 0)
        assert(converter(extracted.get(0, dataType)) === converter(randomRow.get(0, dataType)),
          "Extracted value didn't equal to the original one")
      }

      assert(!buffer.hasRemaining)
    }
  }
}
Example 50
Source File: ColumnStatsSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.columnar

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types._

/**
 * Tests for the per-type [[ColumnStats]] collectors.
 *
 * For each collector it checks two things:
 *  - the statistics reported before any value is gathered;
 *  - the lower bound, upper bound, null count, row count and size in bytes after gathering
 *    10 random values and 10 nulls.
 */
class ColumnStatsSuite extends SparkFunSuite {
  // Expected initial statistics are (lower bound, upper bound, null count).
  testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0))
  testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0))
  testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0))
  testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0))
  testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0))
  testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0))
  testColumnStats(classOf[DoubleColumnStats], DOUBLE,
    createRow(Double.MaxValue, Double.MinValue, 0))
  testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0))
  testDecimalColumnStats(createRow(null, null, 0))

  /** Wraps `values` in a [[GenericInternalRow]]. */
  def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray)

  /**
   * Registers "empty" and "non-empty" tests for a [[ColumnStats]] implementation.
   *
   * @param columnStatsClass  collector class; instantiated through its public no-arg constructor
   * @param columnType        column type used to generate random rows
   * @param initialStatistics expected leading statistics before any value is gathered
   */
  def testColumnStats[T <: AtomicType, U <: ColumnStats](
      columnStatsClass: Class[U],
      columnType: NativeColumnType[T],
      initialStatistics: GenericInternalRow): Unit = {

    val columnStatsName = columnStatsClass.getSimpleName

    test(s"$columnStatsName: empty") {
      // Class.newInstance() is deprecated: it can rethrow undeclared checked exceptions.
      // Go through the public no-arg constructor instead.
      val columnStats = columnStatsClass.getConstructor().newInstance()
      // zip stops at the shorter side, so only the values in initialStatistics are compared.
      columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach {
        case (actual, expected) => assert(actual === expected)
      }
    }

    test(s"$columnStatsName: non-empty") {
      import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._

      val columnStats = columnStatsClass.getConstructor().newInstance()
      // 10 random rows followed by 10 null rows.
      val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
      rows.foreach(columnStats.gatherStats(_, 0))

      val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType])
      val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
      val stats = columnStats.collectedStatistics

      assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0))
      assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1))
      assertResult(10, "Wrong null count")(stats.values(2))
      assertResult(20, "Wrong row count")(stats.values(3))
      // Expected size: 4 bytes per null plus the actual size of every non-null value.
      // The computed value is the *expected* argument, matching the assertions above.
      val expectedSizeInBytes = rows.map { row =>
        if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
      }.sum
      assertResult(expectedSizeInBytes, "Wrong size in bytes")(stats.values(4))
    }
  }

  /**
   * Registers "empty" and "non-empty" tests for [[DecimalColumnStats]] with precision 15, scale 10.
   *
   * NOTE(review): `T` and `U` do not appear in the parameter list. `T` is only used in the casts
   * below.
   *
   * @param initialStatistics expected leading statistics before any value is gathered
   */
  def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](
      initialStatistics: GenericInternalRow): Unit = {

    val columnStatsName = classOf[DecimalColumnStats].getSimpleName
    val columnType = COMPACT_DECIMAL(15, 10)

    test(s"$columnStatsName: empty") {
      val columnStats = new DecimalColumnStats(15, 10)
      // zip stops at the shorter side, so only the values in initialStatistics are compared.
      columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach {
        case (actual, expected) => assert(actual === expected)
      }
    }

    test(s"$columnStatsName: non-empty") {
      import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._

      val columnStats = new DecimalColumnStats(15, 10)
      // 10 random rows followed by 10 null rows.
      val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
      rows.foreach(columnStats.gatherStats(_, 0))

      val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType])
      val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
      val stats = columnStats.collectedStatistics

      assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0))
      assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1))
      assertResult(10, "Wrong null count")(stats.values(2))
      assertResult(20, "Wrong row count")(stats.values(3))
      // Expected size: 4 bytes per null plus the actual size of every non-null value.
      // The computed value is the *expected* argument, matching the assertions above.
      val expectedSizeInBytes = rows.map { row =>
        if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
      }.sum
      assertResult(expectedSizeInBytes, "Wrong size in bytes")(stats.values(4))
    }
  }
}
Example 51
Source File: NullableColumnAccessorSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.columnar

import java.nio.ByteBuffer

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow}
import org.apache.spark.sql.types._

/** A [[BasicColumnAccessor]] over `buffer` with [[NullableColumnAccessor]] mixed in, for tests. */
class TestNullableColumnAccessor[JvmType](
    buffer: ByteBuffer,
    columnType: ColumnType[JvmType])
  extends BasicColumnAccessor(buffer, columnType)
  with NullableColumnAccessor

object TestNullableColumnAccessor {

  /**
   * Wraps `buffer` in a nullable accessor for `columnType`.
   *
   * @param buffer     buffer to read from (in this file, the output of a column builder's `build()`)
   * @param columnType column type used to decode the values
   */
  def apply[JvmType](
      buffer: ByteBuffer,
      columnType: ColumnType[JvmType]): TestNullableColumnAccessor[JvmType] =
    new TestNullableColumnAccessor[JvmType](buffer, columnType)
}

/** Tests for [[NullableColumnAccessor]] across every supported column type, including NULL. */
class NullableColumnAccessorSuite extends SparkFunSuite {
  import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._

  // Register the accessor tests once for each of these column types.
  Seq(
    NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
    STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
    STRUCT(StructType(StructField("a", StringType) :: Nil)),
    ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType))
  ).foreach(testNullableColumnAccessor(_))

  /**
   * Registers the empty-column and null-value accessor tests for `columnType`.
   *
   * @param columnType the column type under test; its simple class name is used in the test names
   */
  def testNullableColumnAccessor[JvmType](
      columnType: ColumnType[JvmType]): Unit = {

    val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
    val nullRow = makeNullRow(1)

    test(s"Nullable $typeName column accessor: empty column") {
      val emptyBuffer = TestNullableColumnBuilder(columnType).build()
      assert(!TestNullableColumnAccessor(emptyBuffer, columnType).hasNext)
    }

    test(s"Nullable $typeName column accessor: access null values") {
      val dataType = columnType.dataType
      val builder = TestNullableColumnBuilder(columnType)
      val randomRow = makeRandomRow(columnType)
      val proj = UnsafeProjection.create(Array[DataType](dataType))

      // Alternate a random value and a null, four times.
      for (_ <- 0 until 4) {
        builder.appendFrom(proj(randomRow), 0)
        builder.appendFrom(proj(nullRow), 0)
      }

      val accessor = TestNullableColumnAccessor(builder.build(), columnType)
      val row = new GenericMutableRow(1)
      val converter = CatalystTypeConverters.createToScalaConverter(dataType)

      // Reading back must reproduce the same value / null alternation.
      for (_ <- 0 until 4) {
        assert(accessor.hasNext)
        accessor.extractTo(row, 0)
        assert(converter(row.get(0, dataType))
          === converter(randomRow.get(0, dataType)))

        assert(accessor.hasNext)
        accessor.extractTo(row, 0)
        assert(row.isNullAt(0))
      }

      assert(!accessor.hasNext)
    }
  }
}
Example 52
Source File: GroupedIteratorSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{LongType, StringType, IntegerType, StructType}

/** Tests for [[GroupedIterator]], which yields (grouping key, rows of that group) pairs. */
class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val internalRows = input.iterator.map(encoder.toRow)
    val grouped = GroupedIterator(internalRows, Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map { case (key, data) =>
      assert(key.numFields == 1)
      key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result == Seq(1 -> Seq(input(0), input(1)), 2 -> Seq(input(2))))
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))
    val internalRows = input.iterator.map(encoder.toRow)
    val grouped =
      GroupedIterator(internalRows, Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map { case (key, data) =>
      assert(key.numFields == 2)
      (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    // Grouping is on the (i, l) pair, so (1, 2L) and (1, 3L) are separate groups.
    val expected = Seq(
      (1, 2L, Seq(input(0), input(1))),
      (1, 3L, Seq(input(2))),
      (2, 1L, Seq(input(3))),
      (3, 2L, Seq(input(4))))
    assert(result == expected)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val internalRows = input.iterator.map(encoder.toRow)
    val grouped = GroupedIterator(internalRows, Seq('i.int.at(0)), schema.toAttributes)

    // Counting groups without touching any group's rows must still see both keys.
    assert(grouped.length == 2)
  }
}
Example 53
Source File: RowSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/** Tests for [[Row]] construction, field access, serialization and equality. */
class RowSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  /** Asserts that the external `actual` row exposes the same four fields as `expected`. */
  private def assertMatchesInternalRow(expected: GenericMutableRow, actual: Row): Unit = {
    assert(expected.numFields === actual.size)
    assert(expected.getInt(0) === actual.getInt(0))
    assert(expected.getString(1) === actual.getString(1))
    assert(expected.getBoolean(2) === actual.getBoolean(2))
    assert(expected.isNullAt(3) === actual.isNullAt(3))
  }

  test("create row") {
    val expected = new GenericMutableRow(4)
    expected.setInt(0, 2147483647)
    expected.update(1, UTF8String.fromString("this is a string"))
    expected.setBoolean(2, false)
    expected.setNullAt(3)

    // Both the varargs factory and fromSeq must produce the same fields.
    assertMatchesInternalRow(expected, Row(2147483647, "this is a string", false, null))
    assertMatchesInternalRow(
      expected, Row.fromSeq(Seq(2147483647, "this is a string", false, null)))
  }

  test("SpecificMutableRow.update with null") {
    val mutableRow = new SpecificMutableRow(Seq(IntegerType))
    mutableRow(0) = null
    assert(mutableRow.isNullAt(0))
  }

  test("serialize w/ kryo") {
    val original = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first()
    val kryo = new SparkSqlSerializer(sparkContext.getConf).newInstance()
    val roundTripped = kryo.deserialize(kryo.serialize(original)).asInstanceOf[Row]
    assert(roundTripped === original)
  }

  test("get values by field name on Row created via .toDF") {
    val first = Seq((1, Seq(1))).toDF("a", "b").first()
    assert(first.getAs[Int]("a") === 1)
    assert(first.getAs[Seq[Int]]("b") === Seq(1))

    // Asking for a field name that does not exist is rejected.
    intercept[IllegalArgumentException] {
      first.getAs[Int]("c")
    }
  }

  test("float NaN == NaN") {
    assert(Row(Float.NaN) === Row(Float.NaN))
  }

  test("double NaN == NaN") {
    assert(Row(Double.NaN) === Row(Double.NaN))
  }

  test("equals and hashCode") {
    val hello = Row("Hello")
    val helloAgain = Row("Hello")
    assert(hello === helloAgain)
    assert(hello.hashCode() === helloAgain.hashCode())
    val world = Row("World")
    assert(world.hashCode() != hello.hashCode())
  }
}
Example 54
Source File: ResolvedDataSourceSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.datasources.ResolvedDataSource

/** Tests that data source short names and package names resolve to the right provider classes. */
class ResolvedDataSourceSuite extends SparkFunSuite {

  /** Asserts that each of `names`, looked up in order, resolves to `expected`. */
  private def assertAllResolveTo(expected: Class[_])(names: String*): Unit = {
    for (name <- names) {
      assert(ResolvedDataSource.lookupDataSource(name) === expected)
    }
  }

  test("jdbc") {
    val jdbc = classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]
    assertAllResolveTo(jdbc)(
      "jdbc",
      "org.apache.spark.sql.execution.datasources.jdbc",
      "org.apache.spark.sql.jdbc")
  }

  test("json") {
    val json = classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]
    assertAllResolveTo(json)(
      "json",
      "org.apache.spark.sql.execution.datasources.json",
      "org.apache.spark.sql.json")
  }

  test("parquet") {
    val parquet = classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]
    assertAllResolveTo(parquet)(
      "parquet",
      "org.apache.spark.sql.execution.datasources.parquet",
      "org.apache.spark.sql.parquet")
  }

  test("error message for unknown data sources") {
    // Every unknown source must fail with a message that mentions spark-packages.
    for (unknown <- Seq("avro", "com.databricks.spark.avro", "asfdwefasdfasdf")) {
      val error = intercept[ClassNotFoundException] {
        ResolvedDataSource.lookupDataSource(unknown)
      }
      assert(error.getMessage.contains("spark-packages"))
    }
  }
}
Example 55
Source File: RateLimiterSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.receiver

import org.apache.spark.SparkConf
import org.apache.spark.SparkFunSuite


/** Tests for how [[RateLimiter]] applies rate updates relative to the configured maxRate. */
class RateLimiterSuite extends SparkFunSuite {

  /** Builds a concrete [[RateLimiter]] from `conf` and requests `newRate` on it. */
  private def limiterUpdatedTo(conf: SparkConf, newRate: Long): RateLimiter = {
    val limiter = new RateLimiter(conf) {}
    limiter.updateRate(newRate)
    limiter
  }

  test("rate limiter initializes even without a maxRate set") {
    val limiter = limiterUpdatedTo(new SparkConf(), 105)
    assert(limiter.getCurrentLimit == 105)
  }

  test("rate limiter updates when below maxRate") {
    val conf = new SparkConf().set("spark.streaming.receiver.maxRate", "110")
    assert(limiterUpdatedTo(conf, 105).getCurrentLimit == 105)
  }

  test("rate limiter stays below maxRate despite large updates") {
    val conf = new SparkConf().set("spark.streaming.receiver.maxRate", "100")
    // The requested 105 exceeds maxRate, so the limit is expected to stay at 100.
    assert(limiterUpdatedTo(conf, 105).getCurrentLimit === 100)
  }
}
Example 56
Source File: FailureSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming

import java.io.File

import org.scalatest.BeforeAndAfter

import org.apache.spark.{SparkFunSuite, Logging}
import org.apache.spark.util.Utils


/** Runs the [[MasterFailureTest]] scenarios against a fresh temporary directory per test. */
class FailureSuite extends SparkFunSuite with BeforeAndAfter with Logging {

  private val batchDuration: Duration = Milliseconds(1000)
  private val numBatches = 30
  private var directory: File = null

  before {
    directory = Utils.createTempDir()
  }

  after {
    Option(directory).foreach(Utils.deleteRecursively)
    // Stop any StreamingContext a test left active so it does not leak into the next test.
    StreamingContext.getActive().foreach(activeContext => activeContext.stop())
  }

  test("multiple failures with map") {
    val dirPath = directory.getAbsolutePath
    MasterFailureTest.testMap(dirPath, numBatches, batchDuration)
  }

  test("multiple failures with updateStateByKey") {
    val dirPath = directory.getAbsolutePath
    MasterFailureTest.testUpdateStateByKey(dirPath, numBatches, batchDuration)
  }
}
Example 57
Source File: UIUtilsSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.ui

import java.util.TimeZone
import java.util.concurrent.TimeUnit

import org.scalatest.Matchers

import org.apache.spark.SparkFunSuite

/** Tests for the streaming UI's time-unit formatting helpers in [[UIUtils]]. */
class UIUtilsSuite extends SparkFunSuite with Matchers {

  test("shortTimeUnitString") {
    val expectedAbbreviations = Seq(
      TimeUnit.NANOSECONDS -> "ns",
      TimeUnit.MICROSECONDS -> "us",
      TimeUnit.MILLISECONDS -> "ms",
      TimeUnit.SECONDS -> "sec",
      TimeUnit.MINUTES -> "min",
      TimeUnit.HOURS -> "hrs",
      TimeUnit.DAYS -> "days")
    for ((unit, abbreviation) <- expectedAbbreviations) {
      assert(abbreviation === UIUtils.shortTimeUnitString(unit))
    }
  }

  test("normalizeDuration") {
    // (input in milliseconds, expected normalized value, expected unit)
    Seq(
      (900L, 900.0, TimeUnit.MILLISECONDS),
      (1000L, 1.0, TimeUnit.SECONDS),
      (60L * 1000, 1.0, TimeUnit.MINUTES),
      (60L * 60 * 1000, 1.0, TimeUnit.HOURS),
      (24L * 60 * 60 * 1000, 1.0, TimeUnit.DAYS)
    ).foreach { case (millis, expectedTime, expectedUnit) =>
      verifyNormalizedTime(expectedTime, expectedUnit, millis)
    }
  }

  /** Checks that `input` milliseconds normalize to `expectedTime` (within 1e-6) `expectedUnit`. */
  private def verifyNormalizedTime(
      expectedTime: Double, expectedUnit: TimeUnit, input: Long): Unit = {
    val (normalizedTime, normalizedUnit) = UIUtils.normalizeDuration(input)
    normalizedTime should be (expectedTime +- 1E-6)
    normalizedUnit should be (expectedUnit)
  }

  test("convertToTimeUnit") {
    // Converts one minute into every unit.
    val oneMinuteMs = 60 * 1000
    Seq(
      TimeUnit.NANOSECONDS -> 60.0 * 1000 * 1000 * 1000,
      TimeUnit.MICROSECONDS -> 60.0 * 1000 * 1000,
      TimeUnit.MILLISECONDS -> 60.0 * 1000,
      TimeUnit.SECONDS -> 60.0,
      TimeUnit.MINUTES -> 1.0,
      TimeUnit.HOURS -> 1.0 / 60,
      TimeUnit.DAYS -> 1.0 / 60 / 24
    ).foreach { case (unit, expectedTime) =>
      verifyConvertToTimeUnit(expectedTime, oneMinuteMs, unit)
    }
  }

  /** Checks that `milliseconds` expressed in `unit` equals `expectedTime` within 1e-6. */
  private def verifyConvertToTimeUnit(
      expectedTime: Double, milliseconds: Long, unit: TimeUnit): Unit = {
    UIUtils.convertToTimeUnit(milliseconds, unit) should be (expectedTime +- 1E-6)
  }

  test("formatBatchTime") {
    val laTimeZone = TimeZone.getTimeZone("America/Los_Angeles")
    val batchTime = 1431637480452L // Thu May 14 14:04:40 PDT 2015
    assert("2015/05/14 14:04:40" === UIUtils.formatBatchTime(batchTime, 1000, timezone = laTimeZone))
    assert("2015/05/14 14:04:40.452" ===
      UIUtils.formatBatchTime(batchTime, 999, timezone = laTimeZone))
    assert("14:04:40" === UIUtils.formatBatchTime(batchTime, 1000, false, timezone = laTimeZone))
    assert("14:04:40.452" === UIUtils.formatBatchTime(batchTime, 999, false, timezone = laTimeZone))
  }
}
Example 58
Source File: InputInfoTrackerSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.scheduler

import org.scalatest.BeforeAndAfter

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.streaming.{Time, Duration, StreamingContext}

/** Tests for reporting, reading and cleaning up per-batch input info in [[InputInfoTracker]]. */
class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter {

  private var ssc: StreamingContext = _

  before {
    val conf = new SparkConf().setMaster("local[2]").setAppName("DirectStreamTacker")
    if (ssc == null) ssc = new StreamingContext(conf, Duration(1000))
  }

  after {
    if (ssc != null) {
      ssc.stop()
      ssc = null
    }
  }

  test("test report and get InputInfo from InputInfoTracker") {
    val tracker = new InputInfoTracker(ssc)

    val firstStream = 0
    val secondStream = 1
    val batchTime = Time(0L)
    val firstInfo = StreamInputInfo(firstStream, 100L)
    val secondInfo = StreamInputInfo(secondStream, 300L)
    Seq(firstInfo, secondInfo).foreach(tracker.reportInfo(batchTime, _))

    val infosByStream = tracker.getInfo(batchTime)
    assert(infosByStream.size == 2)
    assert(infosByStream.keys === Set(firstStream, secondStream))
    assert(infosByStream(firstStream) === firstInfo)
    assert(infosByStream(secondStream) === secondInfo)
    assert(tracker.getInfo(batchTime)(firstStream) === firstInfo)
  }

  test("test cleanup InputInfo from InputInfoTracker") {
    val tracker = new InputInfoTracker(ssc)

    val streamId = 0
    val olderInfo = StreamInputInfo(streamId, 100L)
    val newerInfo = StreamInputInfo(streamId, 300L)
    tracker.reportInfo(Time(0), olderInfo)
    tracker.reportInfo(Time(1), newerInfo)

    // After cleanup(Time(0)) both batches are still readable...
    tracker.cleanup(Time(0))
    assert(tracker.getInfo(Time(0))(streamId) === olderInfo)
    assert(tracker.getInfo(Time(1))(streamId) === newerInfo)

    // ...while cleanup(Time(1)) drops the Time(0) batch and keeps Time(1).
    tracker.cleanup(Time(1))
    assert(tracker.getInfo(Time(0)).get(streamId) === None)
    assert(tracker.getInfo(Time(1))(streamId) === newerInfo)
  }
}
Example 59
Source File: RecurringTimerSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.util

import scala.collection.mutable
import scala.concurrent.duration._

import org.scalatest.PrivateMethodTester
import org.scalatest.concurrent.Eventually._

import org.apache.spark.SparkFunSuite
import org.apache.spark.util.ManualClock

/** Tests for [[RecurringTimer]], driven deterministically by a [[ManualClock]]. */
class RecurringTimerSuite extends SparkFunSuite with PrivateMethodTester {

  test("basic") {
    val clock = new ManualClock()
    // Appended to by the timer's callback and read by the test thread, hence synchronized.
    val results = new mutable.ArrayBuffer[Long]() with mutable.SynchronizedBuffer[Long]
    val timer = new RecurringTimer(clock, 100, time => {
      results += time
    }, "RecurringTimerSuite-basic")
    timer.start(0)
    eventually(timeout(10.seconds), interval(10.millis)) {
      assert(results === Seq(0L))
    }
    clock.advance(100)
    eventually(timeout(10.seconds), interval(10.millis)) {
      assert(results === Seq(0L, 100L))
    }
    // Advancing two periods at once is expected to fire the callback for both 200 and 300.
    clock.advance(200)
    eventually(timeout(10.seconds), interval(10.millis)) {
      assert(results === Seq(0L, 100L, 200L, 300L))
    }
    assert(timer.stop(interruptTimer = true) === 300L)
  }

  test("SPARK-10224: call 'callback' after stopping") {
    val clock = new ManualClock()
    // Appended to by the timer's callback and read by the test thread, hence synchronized.
    val results = new mutable.ArrayBuffer[Long]() with mutable.SynchronizedBuffer[Long]
    val timer = new RecurringTimer(clock, 100, time => {
      results += time
    }, "RecurringTimerSuite-SPARK-10224")
    timer.start(0)
    eventually(timeout(10.seconds), interval(10.millis)) {
      assert(results === Seq(0L))
    }
    @volatile var lastTime = -1L
    // Now RecurringTimer is waiting for the next interval
    val thread = new Thread {
      override def run(): Unit = {
        lastTime = timer.stop(interruptTimer = false)
      }
    }
    thread.start()
    val stopped = PrivateMethod[RecurringTimer]('stopped)
    // Make sure the `stopped` field has been changed
    eventually(timeout(10.seconds), interval(10.millis)) {
      assert(timer.invokePrivate(stopped()) === true)
    }
    clock.advance(200)
    // When RecurringTimer is awake from clock.waitTillTime, it will call `callback` once.
    // Then it will find `stopped` is true and exit the loop, but it should call `callback` again
    // before exiting its internal thread.
    // Bound the wait: if stop() never returns (a regression), fail this test instead of
    // hanging the whole suite on an unbounded join().
    thread.join(10.seconds.toMillis)
    assert(!thread.isAlive, "timer.stop() did not return within 10 seconds")
    assert(results === Seq(0L, 100L, 200L))
    assert(lastTime === 200L)
  }
}
Example 60
Source File: RateLimitedOutputStreamSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.util

import java.io.ByteArrayOutputStream
import java.util.concurrent.TimeUnit._

import org.apache.spark.SparkFunSuite

/** Tests that [[RateLimitedOutputStream]] throttles writes while passing data through intact. */
class RateLimitedOutputStreamSuite extends SparkFunSuite {

  /** Evaluates `f` once and returns the elapsed time in nanoseconds (via `System.nanoTime`). */
  private def benchmark[U](f: => U): Long = {
    val startNs = System.nanoTime
    f
    val endNs = System.nanoTime
    endNs - startNs
  }

  test("write") {
    val sink = new ByteArrayOutputStream
    val payload = "X" * 41000
    val stream = new RateLimitedOutputStream(sink, desiredBytesPerSec = 10000)
    val elapsedNs = benchmark { stream.write(payload.getBytes("UTF-8")) }

    // At 10000 bytes/sec, writing 41000 bytes is expected to take at least 4 whole seconds.
    val seconds = SECONDS.convert(elapsedNs, NANOSECONDS)
    assert(seconds >= 4, s"Seconds value ($seconds) is less than 4.")
    assert(seconds <= 30, s"Took more than 30 seconds ($seconds) to write data.")
    assert(sink.toString("UTF-8") === payload)
  }
}
Example 61
Source File: NettyRpcHandlerSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rpc.netty

import java.net.InetSocketAddress
import java.nio.ByteBuffer

import io.netty.channel.Channel
import org.mockito.Mockito._
import org.mockito.Matchers._

import org.apache.spark.SparkFunSuite
import org.apache.spark.network.client.{TransportResponseHandler, TransportClient}
import org.apache.spark.network.server.StreamManager
import org.apache.spark.rpc._

/** Tests that [[NettyRpcHandler]] posts connect/disconnect events to the [[Dispatcher]]. */
class NettyRpcHandlerSuite extends SparkFunSuite {

  val env = mock(classOf[NettyRpcEnv])
  val sm = mock(classOf[StreamManager])
  when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any()))
    .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null))

  /**
   * Creates a [[TransportClient]] over a mocked [[Channel]] whose `remoteAddress()` is stubbed
   * to localhost:40000.
   */
  private def newClient(): TransportClient = {
    val channel = mock(classOf[Channel])
    val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))
    when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
    client
  }

  test("receive") {
    val dispatcher = mock(classOf[Dispatcher])
    val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm)

    nettyRpcHandler.receive(newClient(), null, null)

    verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000)))
  }

  test("connectionTerminated") {
    val dispatcher = mock(classOf[Dispatcher])
    val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm)

    // The same client connects and then terminates. The Mockito stub on its channel's remote
    // address stays in effect, so it does not need to be registered a second time.
    val client = newClient()
    nettyRpcHandler.receive(client, null, null)
    nettyRpcHandler.connectionTerminated(client)

    verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000)))
    verify(dispatcher, times(1)).postToAll(
      RemoteProcessDisconnected(RpcAddress("localhost", 40000)))
  }

}
Example 62
Source File: RpcAddressSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rpc

import org.apache.spark.{SparkException, SparkFunSuite}

class RpcAddressSuite extends SparkFunSuite {

  /** Checks that `address` carries the host and port shared by the cases below. */
  private def assertHostAndPort(address: RpcAddress): Unit = {
    assert(address.host == "1.2.3.4")
    assert(address.port == 1234)
  }

  /** Checks that parsing `url` fails with the "Invalid master URL" error for that URL. */
  private def assertRejected(url: String): Unit = {
    val e = intercept[SparkException] {
      RpcAddress.fromSparkURL(url)
    }
    assert(s"Invalid master URL: $url" === e.getMessage)
  }

  test("hostPort") {
    val address = RpcAddress("1.2.3.4", 1234)
    assertHostAndPort(address)
    assert(address.hostPort == "1.2.3.4:1234")
  }

  test("fromSparkURL") {
    assertHostAndPort(RpcAddress.fromSparkURL("spark://1.2.3.4:1234"))
  }

  test("fromSparkURL: a typo url") {
    assertRejected("spark://1.2. 3.4:1234")
  }

  test("fromSparkURL: invalid scheme") {
    assertRejected("invalid://1.2.3.4:1234")
  }

  test("toSparkURL") {
    val sparkUrl = RpcAddress("1.2.3.4", 1234).toSparkURL
    assert(sparkUrl == "spark://1.2.3.4:1234")
  }
}
Example 63
Source File: NettyBlockTransferServiceSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.network.netty

import org.apache.spark.network.BlockDataManager
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.mockito.Mockito.mock
import org.scalatest._

class NettyBlockTransferServiceSuite
  extends SparkFunSuite
  with BeforeAndAfterEach
  with ShouldMatchers {

  private var service0: NettyBlockTransferService = _
  private var service1: NettyBlockTransferService = _

  /**
   * Closes whichever services the test created, then always delegates to `super.afterEach()`
   * so teardown from any other stackable `BeforeAndAfterEach` trait mixed into this suite
   * still runs (even if closing a service throws).
   */
  override def afterEach(): Unit = {
    try {
      if (service0 != null) {
        service0.close()
        service0 = null
      }

      if (service1 != null) {
        service1.close()
        service1 = null
      }
    } finally {
      super.afterEach()
    }
  }

  test("can bind to a random port") {
    service0 = createService(port = 0)
    service0.port should not be 0
  }

  test("can bind to two random ports") {
    service0 = createService(port = 0)
    service1 = createService(port = 0)
    service0.port should not be service1.port
  }

  test("can bind to a specific port") {
    val port = 17634
    service0 = createService(port)
    service0.port should be >= port
    service0.port should be <= (port + 10) // avoid testing equality in case of simultaneous tests
  }

  test("can bind to a specific port twice and the second increments") {
    val port = 17634
    service0 = createService(port)
    service1 = createService(port)
    service0.port should be >= port
    service0.port should be <= (port + 10)
    service1.port should be (service0.port + 1)
  }

  /**
   * Creates and initializes a block transfer service configured with
   * `spark.blockManager.port = port`. Port 0 lets the service pick any port.
   */
  private def createService(port: Int): NettyBlockTransferService = {
    val conf = new SparkConf()
      .set("spark.app.id", s"test-${getClass.getName}")
      .set("spark.blockManager.port", port.toString)
    val securityManager = new SecurityManager(conf)
    val blockDataManager = mock(classOf[BlockDataManager])
    val service = new NettyBlockTransferService(conf, securityManager, numCores = 1)
    service.init(blockDataManager)
    service
  }
}
Example 64
Source File: PythonBroadcastSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.api.python

import scala.io.Source

import java.io.{PrintWriter, File}

import org.scalatest.Matchers

import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite}
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.util.Utils

// This test suite uses SharedSparkContext because we need a SparkEnv in order to deserialize
// a PythonBroadcast:
class PythonBroadcastSuite extends SparkFunSuite with Matchers with SharedSparkContext {
  test("PythonBroadcast can be serialized with Kryo (SPARK-4882)") {
    val tempDir = Utils.createTempDir()
    val broadcastedString = "Hello, world!"
    // Reads the file that `broadcast.path` points at and checks it holds exactly
    // `broadcastedString`. The source is closed in `finally` so a failed read cannot leak
    // the file handle.
    def assertBroadcastIsValid(broadcast: PythonBroadcast): Unit = {
      val source = Source.fromFile(broadcast.path)
      val contents = try source.mkString finally source.close()
      contents should be (broadcastedString)
    }
    try {
      val broadcastDataFile: File = {
        val file = new File(tempDir, "broadcastData")
        val printWriter = new PrintWriter(file)
        // Release the file handle even if writing fails, so the temp dir can be removed.
        try {
          printWriter.write(broadcastedString)
        } finally {
          printWriter.close()
        }
        file
      }
      val broadcast = new PythonBroadcast(broadcastDataFile.getAbsolutePath)
      assertBroadcastIsValid(broadcast)
      val conf = new SparkConf().set("spark.kryo.registrationRequired", "true")
      val deserializedBroadcast =
        Utils.clone[PythonBroadcast](broadcast, new KryoSerializer(conf).newInstance())
      assertBroadcastIsValid(deserializedBroadcast)
    } finally {
      Utils.deleteRecursively(tempDir)
    }
  }
}
Example 65
Source File: PythonRDDSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.api.python

import java.io.{ByteArrayOutputStream, DataOutputStream}

import org.apache.spark.SparkFunSuite

class PythonRDDSuite extends SparkFunSuite {

  /** A data stream over an in-memory byte buffer; these tests never read the bytes back. */
  private def newBuffer(): DataOutputStream = new DataOutputStream(new ByteArrayOutputStream)

  test("Writing large strings to the worker") {
    val largeString = "a" * 100000
    val input: List[String] = List(largeString)
    val buffer = newBuffer()
    PythonRDD.writeIteratorToStream(input.iterator, buffer)
  }

  test("Handle nulls gracefully") {
    val buffer = newBuffer()
    // Should not have NPE when write an Iterator with null in it
    // The correctness will be tested in Python
    val iteratorsWithNulls: Seq[Iterator[Any]] = Seq(
      Iterator("a", null),
      Iterator(null, "a"),
      Iterator("a".getBytes, null),
      Iterator(null, "a".getBytes),
      Iterator((null, null), ("a", null), (null, "b")),
      Iterator((null, null), ("a".getBytes, null), (null, "b".getBytes)))
    iteratorsWithNulls.foreach { iter =>
      PythonRDD.writeIteratorToStream(iter, buffer)
    }
  }
}
Example 66
Source File: SerDeUtilSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.api.python

import org.apache.spark.{SharedSparkContext, SparkFunSuite}

class SerDeUtilSuite extends SparkFunSuite with SharedSparkContext {

  // Both SPARK-5441 regression tests start from an RDD of pairs with no elements.
  private def emptyPairRdd() = sc.makeRDD(Seq.empty[(Any, Any)])

  test("Converting an empty pair RDD to python does not throw an exception (SPARK-5441)") {
    SerDeUtil.pairRDDToPython(emptyPairRdd(), 10)
  }

  test("Converting an empty python RDD to pair RDD does not throw an exception (SPARK-5441)") {
    val javaRdd = emptyPairRdd().toJavaRDD()
    SerDeUtil.pythonToPairRDD(SerDeUtil.javaToPython(javaRdd), false)
  }
}
Example 67
Source File: PythonRunnerSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy

import org.apache.spark.SparkFunSuite
import org.apache.spark.util.Utils

class PythonRunnerSuite extends SparkFunSuite {

  // Test formatting a single path to be added to the PYTHONPATH
  test("format path") {
    Seq(
      "spark.py" -> "spark.py",
      "file:/spark.py" -> "/spark.py",
      "file:///spark.py" -> "/spark.py",
      "local:/spark.py" -> "/spark.py",
      "local:///spark.py" -> "/spark.py"
    ).foreach { case (input, expected) =>
      assert(PythonRunner.formatPath(input) === expected)
    }

    if (Utils.isWindows) {
      Seq(
        "file:/C:/a/b/spark.py" -> "C:/a/b/spark.py",
        "C:\\a\\b\\spark.py" -> "C:/a/b/spark.py",
        "C:\\a b\\spark.py" -> "C:/a b/spark.py"
      ).foreach { case (input, expected) =>
        assert(PythonRunner.formatPath(input, testWindows = true) === expected)
      }
    }

    Seq("one:two", "hdfs:s3:xtremeFS", "hdfs:/path/to/some.py").foreach { rejected =>
      intercept[IllegalArgumentException] { PythonRunner.formatPath(rejected) }
    }
  }

  // Test formatting multiple comma-separated paths to be added to the PYTHONPATH
  test("format paths") {
    Seq(
      "spark.py" -> Array("spark.py"),
      "file:/spark.py" -> Array("/spark.py"),
      "file:/app.py,local:/spark.py" -> Array("/app.py", "/spark.py"),
      "me.py,file:/you.py,local:/we.py" -> Array("me.py", "/you.py", "/we.py")
    ).foreach { case (input, expected) =>
      assert(PythonRunner.formatPaths(input) === expected)
    }

    if (Utils.isWindows) {
      Seq(
        "C:\\a\\b\\spark.py" -> Array("C:/a/b/spark.py"),
        "C:\\free.py,pie.py" -> Array("C:/free.py", "pie.py"),
        "lovely.py,C:\\free.py,file:/d:/fry.py" -> Array("lovely.py", "C:/free.py", "d:/fry.py")
      ).foreach { case (input, expected) =>
        assert(PythonRunner.formatPaths(input, testWindows = true) === expected)
      }
    }

    Seq(
      "one:two,three",
      "two,three,four:five:six",
      "hdfs:/some.py,foo.py",
      "foo.py,hdfs:/some.py"
    ).foreach { rejected =>
      intercept[IllegalArgumentException] { PythonRunner.formatPaths(rejected) }
    }
  }
}
Example 68
Source File: LogUrlsStandaloneSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy

import java.net.URL

import scala.collection.mutable
import scala.io.Source

import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.scheduler.{SparkListenerExecutorAdded, SparkListener}
import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.util.SparkConfWithEnv

class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext {

  /** Maximum time, in milliseconds, to wait for the listener bus to drain its event queue. */
  private val WAIT_TIMEOUT_MILLIS = 10000

  /** Fetches `url` as a string, closing the underlying stream even if reading fails. */
  private def readUrl(url: String): String = {
    val source = Source.fromURL(url)
    try source.mkString finally source.close()
  }

  test("verify that correct log urls get propagated from workers") {
    sc = new SparkContext("local-cluster[2,1,1024]", "test")

    val listener = new SaveExecutorInfo
    sc.addSparkListener(listener)

    // Trigger a job so that executors get added
    sc.parallelize(1 to 100, 4).map(_.toString).count()

    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    listener.addedExecutorInfos.values.foreach { info =>
      assert(info.logUrlMap.nonEmpty)
      // Browse to each URL to check that it's valid
      info.logUrlMap.foreach { case (logType, logUrl) =>
        val html = readUrl(logUrl)
        assert(html.contains(s"$logType log page"))
      }
    }
  }

  test("verify that log urls reflect SPARK_PUBLIC_DNS (SPARK-6175)") {
    val SPARK_PUBLIC_DNS = "public_dns"
    val conf = new SparkConfWithEnv(Map("SPARK_PUBLIC_DNS" -> SPARK_PUBLIC_DNS)).set(
      "spark.extraListeners", classOf[SaveExecutorInfo].getName)
    sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)

    // Trigger a job so that executors get added
    sc.parallelize(1 to 100, 4).map(_.toString).count()

    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo]
    assert(listeners.size === 1)
    val listener = listeners(0)
    // The listener was registered through spark.extraListeners when the context was built, so
    // it should have recorded the executor(s) that ran the job. This keeps the loop below
    // from passing vacuously.
    assert(listener.addedExecutorInfos.nonEmpty)
    listener.addedExecutorInfos.values.foreach { info =>
      assert(info.logUrlMap.nonEmpty)
      info.logUrlMap.values.foreach { logUrl =>
        assert(new URL(logUrl).getHost === SPARK_PUBLIC_DNS)
      }
    }
  }
}

/** Listener that records the [[ExecutorInfo]] of every added executor, keyed by executor id. */
private[spark] class SaveExecutorInfo extends SparkListener {
  val addedExecutorInfos: mutable.Map[String, ExecutorInfo] = mutable.Map.empty

  override def onExecutorAdded(executor: SparkListenerExecutorAdded): Unit = {
    addedExecutorInfos += executor.executorId -> executor.executorInfo
  }
}
Example 69
Source File: ClientSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy

import org.scalatest.Matchers

import org.apache.spark.SparkFunSuite

class ClientSuite extends SparkFunSuite with Matchers {
  test("correctly validates driver jar URL's") {
    // Each URL is paired with whether it should be accepted as a driver jar URL.
    val expectations = Seq(
      "http://someHost:8080/foo.jar" -> true,
      "https://someHost:8080/foo.jar" -> true,
      // file scheme with authority and path is valid.
      "file://somehost/path/to/a/jarFile.jar" -> true,
      // file scheme without path is not valid.
      // In this case, jarFile.jar is recognized as authority.
      "file://jarFile.jar" -> false,
      // file scheme without authority but with triple slash is valid.
      "file:///some/path/to/a/jarFile.jar" -> true,
      "hdfs://someHost:1234/foo.jar" -> true,
      "hdfs://someHost:1234/foo" -> false,
      "/missing/a/protocol/jarfile.jar" -> false,
      "not-even-a-path.jar" -> false,
      // This URI doesn't have authority and path.
      "hdfs:someHost:1234/jarfile.jar" -> false,
      // Invalid syntax.
      "hdfs:" -> false)

    expectations.foreach { case (url, valid) =>
      ClientArguments.isValidJarUrl(url) should be (valid)
    }
  }
}
Example 70
Source File: MasterWebUISuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.master.ui

import java.util.Date

import scala.io.Source
import scala.language.postfixOps

import org.json4s.jackson.JsonMethods._
import org.json4s.JsonAST.{JNothing, JString, JInt}
import org.mockito.Mockito.{mock, when}
import org.scalatest.BeforeAndAfter

import org.apache.spark.{SparkConf, SecurityManager, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.MasterStateResponse
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.RpcEnv


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfter
  with org.scalatest.BeforeAndAfterAll {

  val masterPage = mock(classOf[MasterPage])
  // The RpcEnv backing `master` is kept as a field so `afterAll` can shut it down;
  // otherwise nothing in this suite ever shuts it down.
  private val masterConf = new SparkConf
  private val masterSecurityMgr = new SecurityManager(masterConf)
  private val masterRpcEnv =
    RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, masterConf, masterSecurityMgr)
  val master = new Master(masterRpcEnv, masterRpcEnv.address, 0, masterSecurityMgr, masterConf)
  val masterWebUI = new MasterWebUI(master, 0, customMasterPage = Some(masterPage))

  before {
    masterWebUI.bind()
  }

  after {
    masterWebUI.stop()
  }

  /** Shuts down the master's RpcEnv, then runs any inherited suite-level teardown. */
  override def afterAll(): Unit = {
    try {
      masterRpcEnv.shutdown()
    } finally {
      super.afterAll()
    }
  }

  test("list applications") {
    val worker = createWorkerInfo()
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "id", appDesc, new Date(), null, Int.MaxValue)
    activeApp.addExecutor(worker, 2)

    val workers = Array[WorkerInfo](worker)
    val activeApps = Array(activeApp)
    val completedApps = Array[ApplicationInfo]()
    val activeDrivers = Array[DriverInfo]()
    val completedDrivers = Array[DriverInfo]()
    val stateResponse = new MasterStateResponse(
      "host", 8080, None, workers, activeApps, completedApps,
      activeDrivers, completedDrivers, RecoveryState.ALIVE)

    when(masterPage.getMasterState).thenReturn(stateResponse)

    // Close the HTTP stream even if reading the response fails.
    val source = Source.fromURL(
      s"http://localhost:${masterWebUI.boundPort}/api/v1/applications")
    val resultJson = try source.mkString finally source.close()
    val parsedJson = parse(resultJson)
    val firstApp = parsedJson(0)

    assert(firstApp \ "id" === JString(activeApp.id))
    assert(firstApp \ "name" === JString(activeApp.desc.name))
    assert(firstApp \ "coresGranted" === JInt(2))
    assert(firstApp \ "maxCores" === JInt(4))
    assert(firstApp \ "memoryPerExecutorMB" === JInt(1234))
    assert(firstApp \ "coresPerExecutor" === JNothing)
  }

}
Example 71
Source File: WorkerWatcherSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.worker

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.SecurityManager
import org.apache.spark.rpc.{RpcAddress, RpcEnv}

class WorkerWatcherSuite extends SparkFunSuite {

  /**
   * Runs `body` against a freshly created RpcEnv and shuts the env down afterwards, even when
   * an assertion inside `body` fails. The env asks for port 0 (an ephemeral port) rather than
   * a fixed one that another process or a leaked env might still hold.
   */
  private def withRpcEnv(body: RpcEnv => Unit): Unit = {
    val conf = new SparkConf()
    val rpcEnv = RpcEnv.create("test", "localhost", 0, conf, new SecurityManager(conf))
    try {
      body(rpcEnv)
    } finally {
      rpcEnv.shutdown()
    }
  }

  test("WorkerWatcher shuts down on valid disassociation") {
    withRpcEnv { rpcEnv =>
      val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker")
      val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true)
      rpcEnv.setupEndpoint("worker-watcher", workerWatcher)
      workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234))
      assert(workerWatcher.isShutDown)
    }
  }

  test("WorkerWatcher stays alive on invalid disassociation") {
    withRpcEnv { rpcEnv =>
      val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker")
      val otherRpcAddress = RpcAddress("4.3.2.1", 1234)
      val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true)
      rpcEnv.setupEndpoint("worker-watcher", workerWatcher)
      workerWatcher.onDisconnected(otherRpcAddress)
      assert(!workerWatcher.isShutDown)
    }
  }
}
Example 72
Source File: WorkerArgumentsTest.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.worker

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.SparkConfWithEnv

class WorkerArgumentsTest extends SparkFunSuite {

  // Master URL passed as the positional argument in every case (trailing spaces included).
  private val masterUrl = "spark://localhost:0000  "

  test("Memory can't be set to 0 when cmd line args leave off M or G") {
    val conf = new SparkConf
    val args = Array("-m", "10000", masterUrl)
    intercept[IllegalStateException] {
      new WorkerArguments(args, conf)
    }
  }


  test("Memory can't be set to 0 when SPARK_WORKER_MEMORY env property leaves off M or G") {
    val args = Array(masterUrl)
    val conf = new SparkConfWithEnv(Map("SPARK_WORKER_MEMORY" -> "50000"))
    intercept[IllegalStateException] {
      new WorkerArguments(args, conf)
    }
  }

  test("Memory correctly set when SPARK_WORKER_MEMORY env property appends G") {
    val args = Array(masterUrl)
    val conf = new SparkConfWithEnv(Map("SPARK_WORKER_MEMORY" -> "5G"))
    assert(new WorkerArguments(args, conf).memory === 5120)
  }

  test("Memory correctly set from args with M appended to memory value") {
    val conf = new SparkConf
    val args = Array("-m", "10000M", masterUrl)
    assert(new WorkerArguments(args, conf).memory === 10000)
  }

}
Example 73
Source File: ExecutorRunnerTest.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.worker

import java.io.File

import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState}
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}

class ExecutorRunnerTest extends SparkFunSuite {
  test("command includes appId") {
    val appId = "12345-worker321-9876"
    val conf = new SparkConf
    val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
    // The appId is the command's only argument.
    val command = Command("foo", Seq(appId), Map(), Seq(), Seq(), Seq())
    val appDesc = new ApplicationDescription("app name", Some(8), 500, command, "appUiUrl")
    val runner = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", 123,
      "publicAddr", new File(sparkHome), new File("ooga"), "blah", conf, Seq("localDir"),
      ExecutorState.RUNNING)
    val builder = CommandUtils.buildProcessBuilder(
      appDesc.command, new SecurityManager(conf), 512, sparkHome, runner.substituteVariables)
    // The built command line is expected to end with that appId argument.
    val commandLine = builder.command()
    val lastArgument = commandLine.get(commandLine.size() - 1)
    assert(lastArgument === appId)
  }
}
Example 74
Source File: CommandUtilsSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.worker

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.Command
import org.apache.spark.util.Utils
import org.scalatest.{Matchers, PrivateMethodTester}

class CommandUtilsSuite extends SparkFunSuite with Matchers with PrivateMethodTester {

  test("set libraryPath correctly") {
    val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
    val cmd = new Command("mainClass", Seq(), Map(), Seq(), Seq("libraryPathToB"), Seq())
    val builder = CommandUtils.buildProcessBuilder(
      cmd, new SecurityManager(new SparkConf), 512, sparkHome, t => t)
    val libraryPath = Utils.libraryPathEnvName
    val env = builder.environment
    env.keySet should contain(libraryPath)
    assert(env.get(libraryPath).startsWith("libraryPathToB"))
  }

  test("auth secret shouldn't appear in java opts") {
    val buildLocalCommand = PrivateMethod[Command]('buildLocalCommand)
    val conf = new SparkConf
    val secret = "This is the secret sauce"
    // Prefix of the java option that would leak the secret on the command line.
    val secretOptPrefix = "-D" + SecurityManager.SPARK_AUTH_SECRET_CONF
    // set auth secret
    conf.set(SecurityManager.SPARK_AUTH_SECRET_CONF, secret)
    val command = new Command("mainClass", Seq(), Map(), Seq(), Seq("lib"),
      Seq(secretOptPrefix + "=" + secret))

    // Builds the local command with a SecurityManager created on each call, so every call
    // sees `conf` as mutated by the steps below.
    def buildWithCurrentConf(): Command = CommandUtils invokePrivate buildLocalCommand(
      command, new SecurityManager(conf), (t: String) => t, Seq(), Map())

    // auth is not set
    val unsetCmd = buildWithCurrentConf()
    assert(!unsetCmd.javaOpts.exists(_.startsWith(secretOptPrefix)))
    assert(!unsetCmd.environment.contains(SecurityManager.ENV_AUTH_SECRET))

    // auth is set to false
    conf.set(SecurityManager.SPARK_AUTH_CONF, "false")
    val disabledCmd = buildWithCurrentConf()
    assert(!disabledCmd.javaOpts.exists(_.startsWith(secretOptPrefix)))
    assert(!disabledCmd.environment.contains(SecurityManager.ENV_AUTH_SECRET))

    // auth is set to true
    conf.set(SecurityManager.SPARK_AUTH_CONF, "true")
    val enabledCmd = buildWithCurrentConf()
    assert(!enabledCmd.javaOpts.exists(_.startsWith(secretOptPrefix)))
    assert(enabledCmd.environment(SecurityManager.ENV_AUTH_SECRET) === secret)
  }
}
Example 75
Source File: LogPageSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.worker.ui

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}
import org.scalatest.PrivateMethodTester

import org.apache.spark.SparkFunSuite

class LogPageSuite extends SparkFunSuite with PrivateMethodTester {

  test("get logs simple") {
    val webui = mock(classOf[WorkerWebUI])
    // Use a private temp directory instead of java.io.tmpdir itself, so the fixture files
    // below cannot clobber or be clobbered by other processes, and are removed afterwards.
    val tmpDir = java.nio.file.Files.createTempDirectory("log-page-suite").toFile
    try {
      val workDir = new File(tmpDir, "work-dir")
      workDir.mkdir()
      when(webui.workDir).thenReturn(workDir)
      val logPage = new LogPage(webui)

      // Prepare some fake log files to read later
      val out = "some stdout here"
      val err = "some stderr here"
      val tmpOut = new File(workDir, "stdout")
      val tmpErr = new File(workDir, "stderr")
      val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory
      val tmpOutBad = new File(tmpDir, "stdout")
      val tmpRand = new File(workDir, "random")
      write(tmpOut, out)
      write(tmpErr, err)
      write(tmpOutBad, out)
      write(tmpErrBad, err)
      write(tmpRand, "1 6 4 5 2 7 8")

      // Get the logs. All log types other than "stderr" or "stdout" will be rejected
      val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog)
      val (stdout, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100)
      val (stderr, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100)
      val (error1, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100)
      val (error2, _, _, _) =
        logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100)
      // These files exist, but live outside the working directory
      val (error3, _, _, _) =
        logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100)
      val (error4, _, _, _) =
        logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100)
      assert(stdout === out)
      assert(stderr === err)
      assert(error1.startsWith("Error: Log type must be one of "))
      assert(error2.startsWith("Error: Log type must be one of "))
      assert(error3.startsWith("Error: invalid log directory"))
      assert(error4.startsWith("Error: invalid log directory"))
    } finally {
      deleteRecursively(tmpDir)
    }
  }

  /** Writes `s` to `f`, replacing any previous contents, and always closes the writer. */
  private def write(f: File, s: String): Unit = {
    val writer = new FileWriter(f)
    try {
      writer.write(s)
    } finally {
      writer.close()
    }
  }

  /** Deletes `f` and, when it is a directory, everything beneath it (best effort). */
  private def deleteRecursively(f: File): Unit = {
    // listFiles() returns null for non-directories, hence the Option wrapper.
    Option(f.listFiles()).foreach(_.foreach(deleteRecursively))
    f.delete()
  }

}
Example 76
Source File: PagedTableSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ui

import scala.xml.Node

import org.apache.spark.SparkFunSuite

class PagedDataSourceSuite extends SparkFunSuite {

  /** A fresh source over 1..5 with two items per page, i.e. three pages. */
  private def newDataSource(): SeqPagedDataSource[Int] =
    new SeqPagedDataSource[Int](1 to 5, pageSize = 2)

  test("basic") {
    // In-range pages: every result reports 3 total pages plus that page's slice.
    val expectedPages = Seq(
      1 -> PageData(3, (1 to 2)),
      2 -> PageData(3, (3 to 4)),
      3 -> PageData(3, Seq(5)))
    expectedPages.foreach { case (page, expected) =>
      val dataSource = newDataSource()
      assert(dataSource.pageData(page) === expected)
    }

    // Pages past either end are rejected with a descriptive message.
    Seq(4, 0).foreach { page =>
      val dataSource = newDataSource()
      val e = intercept[IndexOutOfBoundsException] {
        dataSource.pageData(page)
      }
      assert(e.getMessage ===
        s"Page $page is out of range. Please select a page number between 1 and 3.")
    }
  }
}

class PagedTableSuite extends SparkFunSuite {
  test("pageNavigation") {
    // Create a fake PagedTable to test pageNavigation
    val pagedTable = new PagedTable[Int] {
      override def tableId: String = ""

      override def tableCssClass: String = ""

      override def dataSource: PagedDataSource[Int] = null

      override def pageLink(page: Int): String = page.toString

      override def headers: Seq[Node] = Nil

      override def row(t: Int): Seq[Node] = Nil

      override def goButtonJavascriptFunction: (String, String) = ("", "")
    }

    // Trimmed text of every <li> in the navigation rendered for the given arguments.
    def navItems(page: Int, pageSize: Int, totalPages: Int): Seq[String] =
      (pagedTable.pageNavigation(page, pageSize, totalPages).head \\ "li").map(_.text.trim)

    def pageLabels(pages: Range): Seq[String] = pages.map(_.toString)

    assert(pagedTable.pageNavigation(1, 10, 1) === Nil)
    assert(navItems(1, 10, 2) === Seq("1", "2", ">"))
    assert(navItems(2, 10, 2) === Seq("<", "1", "2"))

    assert(navItems(1, 10, 100) === pageLabels(1 to 10) ++ Seq(">", ">>"))
    assert(navItems(2, 10, 100) === Seq("<") ++ pageLabels(1 to 10) ++ Seq(">", ">>"))

    assert(navItems(100, 10, 100) === Seq("<<", "<") ++ pageLabels(91 to 100))
    assert(navItems(99, 10, 100) === Seq("<<", "<") ++ pageLabels(91 to 100) ++ Seq(">"))

    assert(navItems(11, 10, 100) === Seq("<<", "<") ++ pageLabels(11 to 20) ++ Seq(">", ">>"))
    assert(navItems(93, 10, 97) === Seq("<<", "<") ++ pageLabels(91 to 97) ++ Seq(">"))
  }
}

/**
 * Test [[PagedDataSource]] backed by an in-memory `Seq`; `pageSize` is handed straight to the
 * superclass, and each requested slice is taken directly from `seq`.
 */
private[spark] class SeqPagedDataSource[T](seq: Seq[T], pageSize: Int)
  extends PagedDataSource[T](pageSize) {

  override protected def dataSize: Int = seq.length

  override protected def sliceData(from: Int, to: Int): Seq[T] = {
    seq.slice(from, to)
  }
}
Example 77
Source File: UISuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ui

import java.net.ServerSocket

import scala.io.Source
import scala.util.{Failure, Success, Try}

import org.eclipse.jetty.servlet.ServletContextHandler
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

import org.apache.spark.LocalSparkContext._
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}

class UISuite extends SparkFunSuite {

  /**
   * Creates a local SparkContext with the web UI enabled, asserting that the UI exists so
   * callers can use `sc.ui.get`.
   */
  private def newSparkContext(): SparkContext = {
    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("test")
      .set("spark.ui.enabled", "true")
    val sc = new SparkContext(conf)
    assert(sc.ui.isDefined)
    sc
  }

  ignore("basic ui visibility") {
    withSpark(newSparkContext()) { sc =>
      // test if the ui is visible, and all the expected tabs are visible
      eventually(timeout(10 seconds), interval(50 milliseconds)) {
        val html = Source.fromURL(sc.ui.get.appUIAddress).mkString
        assert(!html.contains("random data that should not be present"))
        assert(html.toLowerCase.contains("stages"))
        assert(html.toLowerCase.contains("storage"))
        assert(html.toLowerCase.contains("environment"))
        assert(html.toLowerCase.contains("executors"))
      }
    }
  }

  ignore("visibility at localhost:4040") {
    withSpark(newSparkContext()) { sc =>
      // test if visible from http://localhost:4040
      eventually(timeout(10 seconds), interval(50 milliseconds)) {
        val html = Source.fromURL("http://localhost:4040").mkString
        assert(html.toLowerCase.contains("stages"))
      }
    }
  }

  test("jetty selects different port under contention") {
    // Every resource acquired below is released in its own `finally`, so a failed
    // assertion cannot leave sockets or Jetty servers running into later tests.
    val server = new ServerSocket(0)
    try {
      val startPort = server.getLocalPort
      val serverInfo1 = JettyUtils.startJettyServer(
        "0.0.0.0", startPort, Seq[ServletContextHandler](), new SparkConf)
      try {
        val serverInfo2 = JettyUtils.startJettyServer(
          "0.0.0.0", startPort, Seq[ServletContextHandler](), new SparkConf)
        try {
          // Allow some wiggle room in case ports on the machine are under contention
          val boundPort1 = serverInfo1.boundPort
          val boundPort2 = serverInfo2.boundPort
          assert(boundPort1 != startPort)
          assert(boundPort2 != startPort)
          assert(boundPort1 != boundPort2)
        } finally {
          serverInfo2.server.stop()
        }
      } finally {
        serverInfo1.server.stop()
      }
    } finally {
      server.close()
    }
  }

  test("jetty binds to port 0 correctly") {
    val serverInfo = JettyUtils.startJettyServer(
      "0.0.0.0", 0, Seq[ServletContextHandler](), new SparkConf)
    try {
      val server = serverInfo.server
      val boundPort = serverInfo.boundPort
      assert(server.getState === "STARTED")
      assert(boundPort != 0)
      Try { new ServerSocket(boundPort) } match {
        case Success(s) =>
          // Release the port we unexpectedly managed to bind before failing.
          s.close()
          fail("Port %s doesn't seem used by jetty server".format(boundPort))
        case Failure(_) => // expected: the port is already held by Jetty
      }
    } finally {
      serverInfo.server.stop()
    }
  }

  test("verify appUIAddress contains the scheme") {
    withSpark(newSparkContext()) { sc =>
      val ui = sc.ui.get
      val uiAddress = ui.appUIAddress
      val uiHostPort = ui.appUIHostPort
      assert(uiAddress.equals("http://" + uiHostPort))
    }
  }

  test("verify appUIAddress contains the port") {
    withSpark(newSparkContext()) { sc =>
      val ui = sc.ui.get
      val splitUIAddress = ui.appUIAddress.split(':')
      val boundPort = ui.boundPort
      assert(splitUIAddress(2).toInt == boundPort)
    }
  }
}
Example 78
Source File: UIUtilsSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ui

import scala.xml.Elem

import org.apache.spark.SparkFunSuite

class UIUtilsSuite extends SparkFunSuite {
  import UIUtils._

  test("makeDescription") {
    verify(
      """test <a href="/link"> text </a>""",
      <span class="description-input">test <a href="/link"> text </a></span>,
      "Correctly formatted text with only anchors and relative links should generate HTML"
    )

    verify(
      """test <a href="/link" text </a>""",
      <span class="description-input">{"""test <a href="/link" text </a>"""}</span>,
      "Badly formatted text should make the description be treated as a string instead of HTML"
    )

    verify(
      """test <a href="link"> text </a>""",
      <span class="description-input">{"""test <a href="link"> text </a>"""}</span>,
      "Non-relative links should make the description be treated as a string instead of HTML"
    )

    verify(
      """test<a><img></img></a>""",
      <span class="description-input">{"""test<a><img></img></a>"""}</span>,
      "Non-anchor elements should make the description be treated as a string instead of HTML"
    )

    verify(
      """test <a href="/link"> text </a>""",
      <span class="description-input">test <a href="base/link"> text </a></span>,
      baseUrl = "base",
      errorMsg = "Base URL should be prepended to html links"
    )
  }

  test("SPARK-11906: Progress bar should not overflow because of speculative tasks") {
    val generated = makeProgressBar(2, 3, 0, 0, 4).head.child.filter(_.label == "div")
    val expected = Seq(
      <div class="bar bar-completed" style="width: 75.0%"></div>,
      <div class="bar bar-running" style="width: 25.0%"></div>
    )
    assert(generated.sameElements(expected),
      s"\nRunning progress bar should round down\n\nExpected:\n$expected\nGenerated:\n$generated")
  }

  test("decodeURLParameter (SPARK-12708: Sorting task error in Stages Page when yarn mode.)") {
    val encoded1 = "%252F"
    val decoded1 = "/"
    val encoded2 = "%253Cdriver%253E"
    val decoded2 = "<driver>"

    assert(decoded1 === decodeURLParameter(encoded1))
    assert(decoded2 === decodeURLParameter(encoded2))

    // verify that no affect to decoded URL.
    assert(decoded1 === decodeURLParameter(decoded1))
    assert(decoded2 === decodeURLParameter(decoded2))
  }

  /**
   * Asserts that `makeDescription(desc, baseUrl)` produces exactly `expected`.
   *
   * @param desc     raw description text passed to makeDescription
   * @param expected the XML element expected back
   * @param errorMsg message shown on mismatch
   * @param baseUrl  base URL passed through to makeDescription
   */
  private def verify(
      desc: String, expected: Elem, errorMsg: String = "", baseUrl: String = ""): Unit = {
    val generated = makeDescription(desc, baseUrl)
    assert(generated.sameElements(expected),
      s"\n$errorMsg\n\nExpected:\n$expected\nGenerated:\n$generated")
  }
}
Example 79
Source File: GenericAvroSerializerSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.serializer

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.ByteBuffer

import com.esotericsoftware.kryo.io.{Output, Input}
import org.apache.avro.{SchemaBuilder, Schema}
import org.apache.avro.generic.GenericData.Record

import org.apache.spark.{SparkFunSuite, SharedSparkContext}

class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
  conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

  val schema : Schema = SchemaBuilder
    .record("testRecord").fields()
    .requiredString("data")
    .endRecord()
  val record = new Record(schema)
  record.put("data", "test data")

  test("schema compression and decompression") {
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
    assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema))))
  }

  test("record serialization and deserialization") {
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)

    val outputStream = new ByteArrayOutputStream()
    val output = new Output(outputStream)
    genericSer.serializeDatum(record, output)
    output.flush()
    output.close()

    val input = new Input(new ByteArrayInputStream(outputStream.toByteArray))
    assert(genericSer.deserializeDatum(input) === record)
  }

  test("uses schema fingerprint to decrease message size") {
    val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema)

    val output = new Output(new ByteArrayOutputStream())

    val beginningNormalPosition = output.total()
    genericSerFull.serializeDatum(record, output)
    output.flush()
    val normalLength = output.total - beginningNormalPosition

    // Register the schema on a copy so the suite-wide `conf` is not mutated for other tests.
    val fingerprintConf = conf.clone
    fingerprintConf.registerAvroSchemas(schema)
    val genericSerFinger = new GenericAvroSerializer(fingerprintConf.getAvroSchema)
    val beginningFingerprintPosition = output.total()
    genericSerFinger.serializeDatum(record, output)
    val fingerprintLength = output.total - beginningFingerprintPosition
    output.close()

    assert(fingerprintLength < normalLength)
  }

  test("caches previously seen schemas") {
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
    val compressedSchema = genericSer.compress(schema)
    val decompressedSchema = genericSer.decompress(ByteBuffer.wrap(compressedSchema))

    // Reference equality: repeated calls must return the cached objects.
    assert(compressedSchema.eq(genericSer.compress(schema)))
    assert(decompressedSchema.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema))))
  }
}
Example 80
Source File: KryoSerializerResizableOutputSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.serializer

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.SparkContext
import org.apache.spark.LocalSparkContext
import org.apache.spark.SparkException


class KryoSerializerResizableOutputSuite extends SparkFunSuite {

  // trial and error showed this will not serialize with 1mb buffer
  val x = (1 to 400000).toArray

  /**
   * Runs `body` against a local Kryo-backed SparkContext whose initial buffer is 1m and whose
   * maximum buffer is `maxBufferSize`. The context is stopped even if `body` fails, so a failing
   * test cannot leave an active SparkContext behind for later suites.
   */
  private def withKryoContext(maxBufferSize: String)(body: SparkContext => Unit): Unit = {
    val conf = new SparkConf(false)
    conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    conf.set("spark.kryoserializer.buffer", "1m")
    conf.set("spark.kryoserializer.buffer.max", maxBufferSize)
    val sc = new SparkContext("local", "test", conf)
    try {
      body(sc)
    } finally {
      LocalSparkContext.stop(sc)
    }
  }

  test("kryo without resizable output buffer should fail on large array") {
    withKryoContext("1m") { sc =>
      intercept[SparkException](sc.parallelize(x).collect())
    }
  }

  test("kryo with resizable output buffer should succeed on large array") {
    withKryoContext("2m") { sc =>
      assert(sc.parallelize(x).collect() === x)
    }
  }
}
Example 81
Source File: JavaSerializerSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.serializer

import org.apache.spark.{SparkConf, SparkFunSuite}

class JavaSerializerSuite extends SparkFunSuite {
  test("JavaSerializer instances are serializable") {
    val serializer = new JavaSerializer(new SparkConf())
    val instance = serializer.newInstance()
    instance.deserialize[JavaSerializer](instance.serialize(serializer))
  }

  test("Deserialize object containing a primitive Class as attribute") {
    val serializer = new JavaSerializer(new SparkConf())
    val instance = serializer.newInstance()
    // Deserialize with the type that was actually serialized, and check that every
    // Class-valued field survives the round trip.
    val restored = instance.deserialize[ContainsPrimitiveClass](
      instance.serialize(new ContainsPrimitiveClass()))
    assert(restored.intClass === classOf[Int])
    assert(restored.longClass === classOf[Long])
    assert(restored.shortClass === classOf[Short])
    assert(restored.charClass === classOf[Char])
    assert(restored.doubleClass === classOf[Double])
    assert(restored.floatClass === classOf[Float])
    assert(restored.booleanClass === classOf[Boolean])
    assert(restored.byteClass === classOf[Byte])
    assert(restored.voidClass === classOf[Void])
  }
}

/** Serializable holder whose fields are Class objects, mostly for primitive types. */
private class ContainsPrimitiveClass extends Serializable {
  val intClass: Class[Int] = classOf[Int]
  val longClass: Class[Long] = classOf[Long]
  val shortClass: Class[Short] = classOf[Short]
  val charClass: Class[Char] = classOf[Char]
  val doubleClass: Class[Double] = classOf[Double]
  val floatClass: Class[Float] = classOf[Float]
  val booleanClass: Class[Boolean] = classOf[Boolean]
  val byteClass: Class[Byte] = classOf[Byte]
  val voidClass: Class[Void] = classOf[Void]
}
Example 82
Source File: ProactiveClosureSerializationSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.serializer

import org.apache.spark.{SharedSparkContext, SparkException, SparkFunSuite}
import org.apache.spark.rdd.RDD


/** Deliberately not Serializable; closures capturing an instance cannot be shipped to tasks. */
class UnserializableClass {
  /** String form of `x`. */
  def op[T](x: T): String = {
    val rendered = x.toString
    rendered
  }

  /** True when the string form of `x` has an even number of characters. */
  def pred[T](x: T): Boolean = {
    val len = x.toString.length
    len % 2 == 0
  }
}

class ProactiveClosureSerializationSuite extends SparkFunSuite with SharedSparkContext {

  /** A fresh RDD of the strings "0".."999" paired with an unserializable helper object. */
  def fixture: (RDD[String], UnserializableClass) = {
    val strings = sc.parallelize(0 until 1000).map(_.toString)
    (strings, new UnserializableClass)
  }

  test("throws expected serialization exceptions on actions") {
    val (data, uc) = fixture
    val thrown = intercept[SparkException] {
      data.map(uc.op(_)).count()
    }
    assert(thrown.getMessage.contains("Task not serializable"))
  }

  // One test is registered per entry below. Each function applies a transformation whose
  // closure captures the unserializable object; the failure should surface when the
  // transformation is defined, before any action runs.
  Map(
    "map" -> xmap _,
    "flatMap" -> xflatMap _,
    "filter" -> xfilter _,
    "mapPartitions" -> xmapPartitions _,
    "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _
  ).foreach { case (name, xf) =>
    test(s"$name transformations throw proactive serialization exceptions") {
      val (data, uc) = fixture
      val thrown = intercept[SparkException] {
        xf(data, uc)
      }
      assert(thrown.getMessage.contains("Task not serializable"),
        s"RDD.$name doesn't proactively throw NotSerializableException")
    }
  }

  private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] = {
    x.map(y => uc.op(y))
  }

  private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] = {
    x.flatMap(y => Seq(uc.op(y)))
  }

  private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] = {
    x.filter(y => uc.pred(y))
  }

  private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] = {
    x.mapPartitions(_.map(y => uc.op(y)))
  }

  private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] = {
    x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y)))
  }

}
Example 83
Source File: CoarseGrainedSchedulerBackendSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite}
import org.apache.spark.util.{SerializableBuffer, AkkaUtils}

class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext {

  test("serialized task larger than akka frame size") {
    val conf = new SparkConf()
      .set("spark.akka.frameSize", "1")
      .set("spark.default.parallelism", "1")
    sc = new SparkContext("local-cluster[2, 1, 1024]", "test", conf)

    // A single element twice the configured frame size is expected to be rejected.
    val maxFrameBytes = AkkaUtils.maxFrameSizeBytes(sc.conf)
    val oversizedElement =
      new SerializableBuffer(java.nio.ByteBuffer.allocate(2 * maxFrameBytes))
    val oversizedRdd = sc.parallelize(Seq(oversizedElement))
    val error = intercept[SparkException] {
      oversizedRdd.collect()
    }
    assert(error.getMessage.contains("using broadcast variables for large values"))

    // Small jobs must still succeed on the same context afterwards.
    val small = sc.parallelize(1 to 4).collect()
    assert(small.size === 4)
  }

}
Example 84
Source File: MesosClusterSchedulerSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler.mesos

import java.util.Date

import org.scalatest.mock.MockitoSugar

import org.apache.spark.deploy.Command
import org.apache.spark.deploy.mesos.MesosDriverDescription
import org.apache.spark.scheduler.cluster.mesos._
import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite}


class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar {

  private val command = new Command("mainClass", Seq("arg"), null, null, null, null)

  /**
   * Builds a scheduler backed by BlackHoleMesosClusterPersistenceEngineFactory whose start()
   * only sets `ready`, so no Mesos connection is made.
   */
  private def newScheduler(): MesosClusterScheduler = {
    val conf = new SparkConf()
    conf.setMaster("mesos://localhost:5050")
    conf.setAppName("spark mesos")
    new MesosClusterScheduler(new BlackHoleMesosClusterPersistenceEngineFactory, conf) {
      override def start(): Unit = { ready = true }
    }
  }

  /** Description of driver "d1" submitted under `submissionId`. */
  private def driverDescription(submissionId: String): MesosDriverDescription =
    new MesosDriverDescription("d1", "jar", 1000, 1, true,
      command, Map[String, String](), submissionId, new Date())

  test("can queue drivers") {
    val scheduler = newScheduler()
    scheduler.start()
    val first = scheduler.submitDriver(driverDescription("s1"))
    assert(first.success)
    val second = scheduler.submitDriver(driverDescription("s2"))
    assert(second.success)
    val queued = scheduler.getSchedulerState().queuedDrivers.toList
    assert(queued(0).submissionId == first.submissionId)
    assert(queued(1).submissionId == second.submissionId)
  }

  test("can kill queued drivers") {
    val scheduler = newScheduler()
    scheduler.start()
    val submitted = scheduler.submitDriver(driverDescription("s1"))
    assert(submitted.success)
    val killed = scheduler.killDriver(submitted.submissionId)
    assert(killed.success)
    assert(scheduler.getSchedulerState().queuedDrivers.isEmpty)
  }
}
Example 85
Source File: MesosTaskLaunchDataSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler.cluster.mesos

import java.nio.ByteBuffer

import org.apache.spark.SparkFunSuite

class MesosTaskLaunchDataSuite extends SparkFunSuite {
  test("serialize and deserialize data must be same") {
    // Range(100, 110) yields 10 ints of 4 bytes each, filling the 40-byte buffer.
    val serializedTask = ByteBuffer.allocate(40)
    Range(100, 110).foreach(serializedTask.putInt(_))
    serializedTask.rewind()
    val attemptNumber = 100
    val byteString = MesosTaskLaunchData(serializedTask, attemptNumber).toByteString
    // Reset position so the comparison below covers the whole buffer.
    serializedTask.rewind()
    val mesosTaskLaunchData = MesosTaskLaunchData.fromByteString(byteString)
    assert(mesosTaskLaunchData.attemptNumber == attemptNumber)
    assert(mesosTaskLaunchData.serializedTask.equals(serializedTask))
  }
}
Example 86
Source File: SparkListenerWithClusterSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import scala.collection.mutable

import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}

import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.scheduler.cluster.ExecutorInfo

/**
 * Unit tests for SparkListener that require a local cluster.
 */
class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext
  with BeforeAndAfter with BeforeAndAfterAll {

  /** Maximum time, in milliseconds, to wait for the listener bus to drain. */
  val WAIT_TIMEOUT_MILLIS = 10000

  before {
    sc = new SparkContext("local-cluster[2,1,1024]", "SparkListenerSuite")
  }

  test("SparkListener sends executor added message") {
    val listener = new SaveExecutorInfo
    sc.addSparkListener(listener)

    // This test will check if the number of executors received by "SparkListener" is same as the
    // number of all executors, so we need to wait until all executors are up
    sc.jobProgressListener.waitUntilExecutorsUp(2, 60000)

    val rdd1 = sc.parallelize(1 to 100, 4)
    val rdd2 = rdd1.map(_.toString)
    rdd2.setName("Target RDD")
    rdd2.count()

    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    assert(listener.addedExecutorInfo.size == 2)
    assert(listener.addedExecutorInfo("0").totalCores == 1)
    assert(listener.addedExecutorInfo("1").totalCores == 1)
  }

  /** Records the ExecutorInfo of every executor-added event, keyed by executor id. */
  private class SaveExecutorInfo extends SparkListener {
    val addedExecutorInfo = mutable.Map[String, ExecutorInfo]()

    override def onExecutorAdded(executor: SparkListenerExecutorAdded): Unit = {
      addedExecutorInfo(executor.executorId) = executor.executorInfo
    }
  }
}
Example 87
Source File: OutputCommitCoordinatorIntegrationSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext}
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.{Span, Seconds}

import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext, SparkFunSuite, TaskContext}
import org.apache.spark.util.Utils


class OutputCommitCoordinatorIntegrationSuite
  extends SparkFunSuite
  with LocalSparkContext
  with Timeouts {

  override def beforeAll(): Unit = {
    super.beforeAll()
    // The master URL is passed to the SparkContext constructor below ("master" is not a Spark
    // configuration key, so it is not set on the conf). Speculation is enabled and the Hadoop
    // output committer is replaced by one that fails the first commit attempt.
    val conf = new SparkConf()
      .set("spark.speculation", "true")
      .set("spark.hadoop.mapred.output.committer.class",
        classOf[ThrowExceptionOnFirstAttemptOutputCommitter].getCanonicalName)
    sc = new SparkContext("local[2, 4]", "test", conf)
  }

  test("exception thrown in OutputCommitter.commitTask()") {
    // Regression test for SPARK-10381
    failAfter(Span(60, Seconds)) {
      val tempDir = Utils.createTempDir()
      try {
        sc.parallelize(1 to 4, 2).map(_.toString).saveAsTextFile(tempDir.getAbsolutePath + "/out")
      } finally {
        Utils.deleteRecursively(tempDir)
      }
    }
  }
}

/** Output committer that throws on attempts numbered below 1 and delegates otherwise. */
private class ThrowExceptionOnFirstAttemptOutputCommitter extends FileOutputCommitter {
  override def commitTask(context: TaskAttemptContext): Unit = {
    val attempt = TaskContext.get().attemptNumber
    if (attempt >= 1) {
      super.commitTask(context)
    } else {
      throw new java.io.FileNotFoundException("Intentional exception")
    }
  }
}
Example 88
Source File: CompletionIteratorSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util

import org.apache.spark.SparkFunSuite

class CompletionIteratorSuite extends SparkFunSuite {
  test("basic test") {
    var completions = 0
    val completionIter = CompletionIterator[Int, Iterator[Int]](
      List(1, 2, 3).iterator, { completions += 1 })

    // While elements remain, the completion callback must not have fired.
    for (expected <- Seq(1, 2, 3)) {
      assert(completionIter.hasNext)
      assert(completionIter.next() === expected)
      assert(completions === 0)
    }

    // Exhaustion triggers the callback exactly once.
    assert(!completionIter.hasNext)
    assert(completions === 1)

    // SPARK-4264: Calling hasNext should not trigger the completion callback again.
    assert(!completionIter.hasNext)
    assert(completions === 1)
  }
}
Example 89
Source File: NextIteratorSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util

import java.util.NoSuchElementException

import scala.collection.mutable.Buffer

import org.scalatest.Matchers

import org.apache.spark.SparkFunSuite

class NextIteratorSuite extends SparkFunSuite with Matchers {
  test("one iteration") {
    val i = new StubIterator(Buffer(1))
    i.hasNext should be (true)
    i.next() should be (1)
    i.hasNext should be (false)
    intercept[NoSuchElementException] { i.next() }
  }

  test("two iterations") {
    val i = new StubIterator(Buffer(1, 2))
    i.hasNext should be (true)
    i.next() should be (1)
    i.hasNext should be (true)
    i.next() should be (2)
    i.hasNext should be (false)
    intercept[NoSuchElementException] { i.next() }
  }

  test("empty iteration") {
    val i = new StubIterator(Buffer())
    i.hasNext should be (false)
    intercept[NoSuchElementException] { i.next() }
  }

  test("close is called once for empty iterations") {
    val i = new StubIterator(Buffer())
    i.hasNext should be (false)
    i.hasNext should be (false)
    i.closeCalled should be (1)
  }

  test("close is called once for non-empty iterations") {
    val i = new StubIterator(Buffer(1, 2))
    i.next() should be (1)
    i.next() should be (2)
    // close isn't called until we check for the next element
    i.closeCalled should be (0)
    i.hasNext should be (false)
    i.closeCalled should be (1)
    i.hasNext should be (false)
    i.closeCalled should be (1)
  }

  /**
   * NextIterator that yields (and removes) the elements of `ints` in order, sets `finished`
   * once the buffer is empty, and counts how many times `close` is invoked.
   */
  class StubIterator(ints: Buffer[Int]) extends NextIterator[Int] {
    var closeCalled = 0

    override def getNext(): Int = {
      if (ints.isEmpty) {
        finished = true
        0 // placeholder value; ignored because `finished` is set
      } else {
        ints.remove(0)
      }
    }

    override def close(): Unit = {
      closeCalled += 1
    }
  }
}
Example 90
Source File: PrimitiveVectorSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util.collection

import org.apache.spark.SparkFunSuite
import org.apache.spark.util.SizeEstimator

class PrimitiveVectorSuite extends SparkFunSuite {

  test("primitive value") {
    val vector = new PrimitiveVector[Int]

    (0 until 1000).foreach { idx =>
      vector += idx
      assert(vector(idx) === idx)
    }

    assert(vector.size === 1000)
    assert(vector.size == vector.length)
    // Reading past the last appended element is rejected.
    intercept[IllegalArgumentException] {
      vector(1000)
    }

    (0 until 1000).foreach { idx => assert(vector(idx) == idx) }
  }

  test("non-primitive value") {
    val vector = new PrimitiveVector[String]

    (0 until 1000).foreach { idx =>
      vector += idx.toString
      assert(vector(idx) === idx.toString)
    }

    assert(vector.size === 1000)
    assert(vector.size == vector.length)
    intercept[IllegalArgumentException] {
      vector(1000)
    }

    (0 until 1000).foreach { idx => assert(vector(idx) == idx.toString) }
  }

  test("ideal growth") {
    val vector = new PrimitiveVector[Long](initialSize = 1)
    vector += 1
    // Capacity should double exactly when full: next power of two above the element count.
    (1 until 1024).foreach { n =>
      vector += n
      assert(vector.size === n + 1)
      assert(vector.capacity === Integer.highestOneBit(n) * 2)
    }
    assert(vector.capacity === 1024)
    vector += 1024
    assert(vector.capacity === 2048)
  }

  test("ideal size") {
    val vector = new PrimitiveVector[Long](8192)
    (0 until 8192).foreach { n => vector += n }
    assert(vector.size === 8192)
    assert(vector.capacity === 8192)
    val actualSize = SizeEstimator.estimate(vector)
    val expectedSize = 8192 * 8
    // Make sure we are not allocating a significant amount of memory beyond our expected.
    // Due to specialization wonkiness, we need to ensure we don't have 2 copies of the array.
    assert(actualSize < expectedSize * 1.1)
  }

  test("resizing") {
    val vector = new PrimitiveVector[Long]
    (0 until 4097).foreach { n => vector += n }
    assert(vector.size === 4097)
    assert(vector.capacity === 8192)

    // trim() shrinks capacity to the current size.
    vector.trim()
    assert(vector.size === 4097)
    assert(vector.capacity === 4097)

    // Growing keeps all elements; shrinking below size truncates them.
    vector.resize(5000)
    assert(vector.size === 4097)
    assert(vector.capacity === 5000)
    vector.resize(4000)
    assert(vector.size === 4000)
    assert(vector.capacity === 4000)
    vector.resize(5000)
    assert(vector.size === 4000)
    assert(vector.capacity === 5000)

    (0 until 4000).foreach { n => assert(vector(n) == n) }
    intercept[IllegalArgumentException] {
      vector(4000)
    }
  }
}
Example 91
Source File: CompactBufferSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util.collection

import org.apache.spark.SparkFunSuite

class CompactBufferSuite extends SparkFunSuite {

  /** `copies` back-to-back repetitions of 0 until 10, as a List. */
  private def repeatedDigits(copies: Int): List[Int] = List.fill(copies)(0 until 10).flatten

  test("empty buffer") {
    val b = new CompactBuffer[Int]
    assert(b.size === 0)
    assert(b.iterator.toList === Nil)
    assert(b.size === 0)
    assert(b.iterator.toList === Nil)
    Seq(0, 1, 2, -1).foreach { idx =>
      intercept[IndexOutOfBoundsException] { b(idx) }
    }
  }

  test("basic inserts") {
    val b = new CompactBuffer[Int]
    assert(b.size === 0)
    assert(b.iterator.toList === Nil)
    (0 until 1000).foreach { n =>
      b += n
      assert(b.size === n + 1)
      assert(b(n) === n)
    }
    val expected = (0 until 1000).toList
    assert(b.iterator.toList === expected)
    assert(b.iterator.toList === expected)
    assert(b.size === 1000)
  }

  test("adding sequences") {
    val b = new CompactBuffer[Int]
    assert(b.size === 0)
    assert(b.iterator.toList === Nil)

    // Add some simple lists and iterators
    b ++= List(0)
    assert(b.size === 1)
    assert(b.iterator.toList === List(0))
    b ++= Iterator(1)
    assert(b.size === 2)
    assert(b.iterator.toList === List(0, 1))
    b ++= List(2)
    assert(b.size === 3)
    assert(b.iterator.toList === List(0, 1, 2))
    b ++= Iterator(3, 4, 5, 6, 7, 8, 9)
    assert(b.size === 10)
    assert(b.iterator.toList === (0 until 10).toList)

    // Add CompactBuffers
    val b2 = new CompactBuffer[Int]
    b2 ++= 0 until 10
    b ++= b2
    assert(b.iterator.toList === repeatedDigits(2))
    b ++= b2
    assert(b.iterator.toList === repeatedDigits(3))
    b ++= b2
    assert(b.iterator.toList === repeatedDigits(4))

    // Add some small CompactBuffers as well
    val b3 = new CompactBuffer[Int]
    b ++= b3
    assert(b.iterator.toList === repeatedDigits(4))
    b3 += 0
    b ++= b3
    assert(b.iterator.toList === repeatedDigits(4) ++ List(0))
    b3 += 1
    b ++= b3
    assert(b.iterator.toList === repeatedDigits(4) ++ List(0, 0, 1))
    b3 += 2
    b ++= b3
    assert(b.iterator.toList === repeatedDigits(4) ++ List(0, 0, 1, 0, 1, 2))
  }

  test("adding the same buffer to itself") {
    val b = new CompactBuffer[Int]
    assert(b.size === 0)
    assert(b.iterator.toList === Nil)
    b += 1
    assert(b.toList === List(1))
    // Each self-append doubles the number of 1s.
    (1 until 8).foreach { j =>
      b ++= b
      val expectedSize = 1 << j
      assert(b.size === expectedSize)
      assert(b.iterator.toList === List.fill(expectedSize)(1))
    }
  }
}
Example 92
Source File: PrefixComparatorsSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util.collection.unsafe.sort

import com.google.common.primitives.UnsignedBytes
import org.scalatest.prop.PropertyChecks
import org.apache.spark.SparkFunSuite
import org.apache.spark.unsafe.types.UTF8String

class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {

  test("String prefix comparator") {

    // The prefix order must agree with the full string order whenever it is decisive;
    // equal prefixes must coincide with equal leading 8 UTF-8 bytes.
    def testPrefixComparison(s1: String, s2: String): Unit = {
      val utf8string1 = UTF8String.fromString(s1)
      val utf8string2 = UTF8String.fromString(s2)
      val s1Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string1)
      val s2Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string2)
      val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix)

      val cmp = UnsignedBytes.lexicographicalComparator().compare(
        utf8string1.getBytes.take(8), utf8string2.getBytes.take(8))

      assert(
        (prefixComparisonResult == 0 && cmp == 0) ||
        (prefixComparisonResult < 0 && s1.compareTo(s2) < 0) ||
        (prefixComparisonResult > 0 && s1.compareTo(s2) > 0))
    }

    // scalastyle:off
    val regressionTests = Table(
      ("s1", "s2"),
      ("abc", "世界"),
      ("你好", "世界"),
      ("你好123", "你好122")
    )
    // scalastyle:on

    forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
    forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
  }

  test("Binary prefix comparator") {

    // Lexicographic comparison of two byte arrays, comparing bytes as (signed) Scala Bytes;
    // when one array is a prefix of the other, the shorter one sorts first.
    // Written without `return` to avoid a nonlocal return from inside a closure.
    def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
      x.iterator.zip(y.iterator)
        .map { case (a, b) => a.compare(b) }
        .find(_ != 0)
        .getOrElse(x.length - y.length)
    }

    // A decisive prefix comparison must agree with the full byte-wise comparison.
    def testPrefixComparison(x: Array[Byte], y: Array[Byte]): Unit = {
      val s1Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(x)
      val s2Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(y)
      val prefixComparisonResult =
        PrefixComparators.BINARY.compare(s1Prefix, s2Prefix)
      assert(
        (prefixComparisonResult == 0) ||
        (prefixComparisonResult < 0 && compareBinary(x, y) < 0) ||
        (prefixComparisonResult > 0 && compareBinary(x, y) > 0))
    }

    // scalastyle:off
    val regressionTests = Table(
      ("s1", "s2"),
      ("abc", "世界"),
      ("你好", "世界"),
      ("你好123", "你好122")
    )
    // scalastyle:on

    forAll (regressionTests) { (s1: String, s2: String) =>
      testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8"))
    }
    forAll { (s1: String, s2: String) =>
      testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8"))
    }
  }

  test("double prefix comparator handles NaNs properly") {
    // Two NaNs with different bit patterns.
    val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
    val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
    assert(nan1.isNaN)
    assert(nan2.isNaN)
    val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1)
    val nan2Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan2)
    assert(nan1Prefix === nan2Prefix)
    val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue)
    assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1)
  }

}
Example 93
Source File: ResetSystemProperties.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util

import java.util.Properties

import org.apache.commons.lang3.SerializationUtils
import org.scalatest.{BeforeAndAfterEach, Suite}

import org.apache.spark.SparkFunSuite


/**
 * Mixin that snapshots the JVM system properties before each test and restores that
 * snapshot after the test, even if the rest of the afterEach chain throws.
 */
private[spark] trait ResetSystemProperties extends BeforeAndAfterEach { this: Suite =>
  // Copy of System.getProperties taken in beforeEach; reset to null once restored.
  var oldProperties: Properties = null

  override def beforeEach(): Unit = {
    // We need SerializationUtils.clone instead of `new Properties(System.getProperties())`
    // because the latter way of creating a copy does not copy the properties: it initializes
    // a new Properties object with the given properties as defaults, and those defaults are
    // not recognized at all by the standard Scala wrapper over Java Properties.
    oldProperties = SerializationUtils.clone(System.getProperties)
    super.beforeEach()
  }

  override def afterEach(): Unit = {
    try {
      super.afterEach()
    } finally {
      // Restore in `finally` so properties are reset even if super.afterEach() throws.
      System.setProperties(oldProperties)
      oldProperties = null
    }
  }
}
Example 94
Source File: SamplingUtilsSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util.random

import scala.util.Random

import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}

import org.apache.spark.SparkFunSuite

class SamplingUtilsSuite extends SparkFunSuite {

  test("reservoirSampleAndCount") {
    val input = Seq.fill(100)(Random.nextInt())

    // input size < k
    val (sample1, count1) = SamplingUtils.reservoirSampleAndCount(input.iterator, 150)
    assert(count1 === 100)
    assert(input === sample1.toSeq)

    // input size == k
    val (sample2, count2) = SamplingUtils.reservoirSampleAndCount(input.iterator, 100)
    assert(count2 === 100)
    assert(input === sample2.toSeq)

    // input size > k
    val (sample3, count3) = SamplingUtils.reservoirSampleAndCount(input.iterator, 10)
    assert(count3 === 100)
    assert(sample3.length === 10)
  }

  test("computeFraction") {
    // test that the computed fraction guarantees enough data points
    // in the sample with a failure rate <= 0.0001
    val n = 100000

    // With replacement: the sample count is modeled as Poisson(frac * n).
    for (s <- (1 to 15) ++ List(20, 100, 1000)) {
      val frac = SamplingUtils.computeFractionForSampleSize(s, n, true)
      val poisson = new PoissonDistribution(frac * n)
      assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
    }
    // Without replacement: the sample count is modeled as Binomial(n, frac).
    // inverseCumulativeProbability already returns a number of successes out of n trials,
    // so it is compared directly with s; scaling it by n would make the check vacuous.
    for (s <- List(1, 10, 100, 1000)) {
      val frac = SamplingUtils.computeFractionForSampleSize(s, n, false)
      val binomial = new BinomialDistribution(n, frac)
      assert(binomial.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
    }
  }
}
Example 95
Source File: XORShiftRandomSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util.random

import org.scalatest.Matchers

import org.apache.commons.math3.stat.inference.ChiSquareTest

import org.apache.spark.SparkFunSuite
import org.apache.spark.util.Utils.times

import scala.language.reflectiveCalls

class XORShiftRandomSuite extends SparkFunSuite with Matchers {

  /** Deterministic inputs shared by the randomness tests. */
  private class Fixture {
    val seed = 1L
    val xorRand = new XORShiftRandom(seed)
    val hundMil = 1e8.toInt
  }

  // A named class keeps member access statically typed (a structural `new { ... }` needs
  // reflection). Each call returns a fresh generator built from the same seed.
  private def fixture: Fixture = new Fixture

  /*
   * This test is based on a chi-squared test for randomness.
   */
  test ("XORShift generates valid random numbers") {

    val f = fixture

    val numBins = 10 // create 10 bins
    val numRows = 5 // create 5 rows
    val bins = Array.ofDim[Long](numRows, numBins)

    // populate bins based on modulus of the random number for each row
    for (r <- 0 until numRows) {
      times(f.hundMil) { bins(r)(math.abs(f.xorRand.nextInt) % numBins) += 1 }
    }

    /*
     * Perform the chi square test on the 5 rows of randomly generated numbers evenly divided into
     * 10 bins. chiSquareTest returns true iff the null hypothesis (that the classifications
     * represented by the counts in the columns of the input 2-way table are independent of the
     * rows) can be rejected with 100 * (1 - alpha) percent confidence, where alpha is
     * prespecified as 0.05.
     */
    val chiTest = new ChiSquareTest
    assert(chiTest.chiSquareTest(bins, 0.05) === false)
  }

  test ("XORShift with zero seed") {
    val random = new XORShiftRandom(0L)
    assert(random.nextInt() != 0)
  }

  test ("hashSeed has random bits throughout") {
    val totalBitCount = (0 until 10).map { seed =>
      val hashed = XORShiftRandom.hashSeed(seed)
      val bitCount = java.lang.Long.bitCount(hashed)
      // make sure we have roughly equal numbers of 0s and 1s.  Mostly just check that we
      // don't have all 0s or 1s in the high bits
      bitCount should be > 20
      bitCount should be < 44
      bitCount
    }.sum
    // and over all the seeds, very close to equal numbers of 0s & 1s
    totalBitCount should be > (32 * 10 - 30)
    totalBitCount should be < (32 * 10 + 30)
  }
}
Example 96
Source File: DistributionSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util

import org.scalatest.Matchers

import org.apache.spark.SparkFunSuite



class DistributionSuite extends SparkFunSuite with Matchers {
  test("summary") {
    // The values 1.0, 2.0, ..., 100.0.
    val values = (1 to 100).map(_.toDouble).toArray
    val distribution = new Distribution(values)

    val summary = distribution.statCounter
    summary.count should be (100)
    summary.mean should be (50.5)
    summary.sum should be (50 * 101)

    // getQuantiles() called with no arguments; expected values are listed by position.
    val quantiles = distribution.getQuantiles()
    Seq(1, 26, 51, 76, 100).zipWithIndex.foreach { case (expected, idx) =>
      quantiles(idx) should be (expected)
    }
  }
}
Example 97
Source File: VectorSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util

import scala.util.Random

import org.apache.spark.SparkFunSuite


@deprecated("suppress compile time deprecation warning", "1.0.0")
class VectorSuite extends SparkFunSuite {

  /** Asserts `vector` has `expectedLength` entries, each strictly inside (0.0, 1.0). */
  def verifyVector(vector: Vector, expectedLength: Int): Unit = {
    assert(vector.length == expectedLength)
    val values = vector.elements
    assert(values.min > 0.0)
    assert(values.max < 1.0)
  }

  test("random with default random number generator") {
    verifyVector(Vector.random(100), 100)
  }

  test("random with given random number generator") {
    // An explicitly seeded generator makes this case reproducible.
    val seeded = new Random(100)
    verifyVector(Vector.random(100, seeded), 100)
  }
}
Example 98
Source File: ByteArrayChunkOutputStreamSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util.io

import scala.util.Random

import org.apache.spark.SparkFunSuite


class ByteArrayChunkOutputStreamSuite extends SparkFunSuite {

  /** Returns a new array of `size` random bytes. */
  private def randomBytes(size: Int): Array[Byte] = {
    val bytes = new Array[Byte](size)
    Random.nextBytes(bytes)
    bytes
  }

  /**
   * Asserts that `out.toArrays` holds exactly the slices of `ref` given by the half-open
   * `[from, until)` ranges, in order. It checks the chunk count first, then every chunk
   * length, then every chunk's contents.
   */
  private def assertChunks(
      out: ByteArrayChunkOutputStream,
      ref: Array[Byte],
      ranges: Seq[(Int, Int)]): Unit = {
    val chunks = out.toArrays
    assert(chunks.length === ranges.length)
    for (((from, until), idx) <- ranges.zipWithIndex) {
      assert(chunks(idx).length === until - from)
    }
    for (((from, until), idx) <- ranges.zipWithIndex) {
      assert(chunks(idx).toSeq === ref.slice(from, until).toSeq)
    }
  }

  test("empty output") {
    val out = new ByteArrayChunkOutputStream(1024)
    assert(out.toArrays.length === 0)
  }

  test("write a single byte") {
    val out = new ByteArrayChunkOutputStream(1024)
    out.write(10)
    val chunks = out.toArrays
    assert(chunks.length === 1)
    assert(chunks.head.toSeq === Seq(10.toByte))
  }

  test("write a single near boundary") {
    val out = new ByteArrayChunkOutputStream(10)
    out.write(new Array[Byte](9))
    out.write(99)
    val chunks = out.toArrays
    assert(chunks.length === 1)
    assert(chunks.head(9) === 99.toByte)
  }

  test("write a single at boundary") {
    val out = new ByteArrayChunkOutputStream(10)
    out.write(new Array[Byte](10))
    out.write(99)
    val chunks = out.toArrays
    assert(chunks.length === 2)
    assert(chunks(1).length === 1)
    assert(chunks(1)(0) === 99.toByte)
  }

  test("single chunk output") {
    val ref = randomBytes(8)
    val out = new ByteArrayChunkOutputStream(10)
    out.write(ref)
    assertChunks(out, ref, Seq((0, 8)))
  }

  test("single chunk output at boundary size") {
    val ref = randomBytes(10)
    val out = new ByteArrayChunkOutputStream(10)
    out.write(ref)
    assertChunks(out, ref, Seq((0, 10)))
  }

  test("multiple chunk output") {
    val ref = randomBytes(26)
    val out = new ByteArrayChunkOutputStream(10)
    out.write(ref)
    assertChunks(out, ref, Seq((0, 10), (10, 20), (20, 26)))
  }

  test("multiple chunk output at boundary size") {
    val ref = randomBytes(30)
    val out = new ByteArrayChunkOutputStream(10)
    out.write(ref)
    assertChunks(out, ref, Seq((0, 10), (10, 20), (20, 30)))
  }
}
Example 99
Source File: DiskBlockManagerSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import java.io.{File, FileWriter}

import scala.language.reflectiveCalls

import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.Utils

class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
  private val testConf = new SparkConf(false)
  private var rootDir0: File = _
  private var rootDir1: File = _
  private var rootDirs: String = _

  val blockManager = mock(classOf[BlockManager])
  when(blockManager.conf).thenReturn(testConf)
  var diskBlockManager: DiskBlockManager = _

  /** Creates the two temp root dirs shared by every test in this suite. */
  override def beforeAll(): Unit = {
    super.beforeAll()
    rootDir0 = Utils.createTempDir()
    rootDir1 = Utils.createTempDir()
    rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
  }

  /** Deletes the root dirs, then always lets the base traits finish their own teardown. */
  override def afterAll(): Unit = {
    try {
      Utils.deleteRecursively(rootDir0)
      Utils.deleteRecursively(rootDir1)
    } finally {
      super.afterAll()
    }
  }

  /** Builds a fresh DiskBlockManager whose spark.local.dir points at both root dirs. */
  override def beforeEach(): Unit = {
    super.beforeEach()
    val conf = testConf.clone
    conf.set("spark.local.dir", rootDirs)
    diskBlockManager = new DiskBlockManager(blockManager, conf)
  }

  /** Stops the per-test DiskBlockManager; the base teardown runs even if stop() throws. */
  override def afterEach(): Unit = {
    try {
      diskBlockManager.stop()
    } finally {
      super.afterEach()
    }
  }

  test("basic block creation") {
    val blockId = new TestBlockId("test")
    val newFile = diskBlockManager.getFile(blockId)
    writeToFile(newFile, 10)
    assert(diskBlockManager.containsBlock(blockId))
    newFile.delete()
    assert(!diskBlockManager.containsBlock(blockId))
  }

  test("enumerating blocks") {
    val ids = (1 to 100).map(i => TestBlockId("test_" + i))
    val files = ids.map(id => diskBlockManager.getFile(id))
    files.foreach(file => writeToFile(file, 10))
    assert(diskBlockManager.getAllBlocks.toSet === ids.toSet)
  }

  /**
   * Appends `numBytes` characters (codes 0 until numBytes) to `file`. The writer is closed even
   * if a write fails. NOTE(review): FileWriter encodes chars, so this is one byte per char only
   * for small codes such as the values used here.
   */
  def writeToFile(file: File, numBytes: Int): Unit = {
    val writer = new FileWriter(file, true)
    try {
      for (i <- 0 until numBytes) writer.write(i)
    } finally {
      writer.close()
    }
  }
}
Example 100
Source File: BlockIdSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import org.apache.spark.SparkFunSuite

class BlockIdSuite extends SparkFunSuite {

  /** Asserts that two block ids are equal and agree on name and hash code. */
  def assertSame(id1: BlockId, id2: BlockId): Unit = {
    assert(id1.name === id2.name)
    assert(id1.hashCode === id2.hashCode)
    assert(id1 === id2)
  }

  /** Asserts that two block ids differ in name, hash code and equality. */
  def assertDifferent(id1: BlockId, id2: BlockId): Unit = {
    assert(id1.name != id2.name)
    assert(id1.hashCode != id2.hashCode)
    assert(id1 != id2)
  }

  test("test-bad-deserialization") {
    // Deserializing an invalid block id must fail with IllegalStateException. `intercept`
    // fails with a message naming the exception actually thrown (or that none was thrown),
    // instead of a bare fail() from a catch-all Throwable handler.
    intercept[IllegalStateException] {
      BlockId("myblock")
    }
  }

  test("rdd") {
    val id = RDDBlockId(1, 2)
    assertSame(id, RDDBlockId(1, 2))
    assertDifferent(id, RDDBlockId(1, 1))
    assert(id.name === "rdd_1_2")
    assert(id.asRDDId.get.rddId === 1)
    assert(id.asRDDId.get.splitIndex === 2)
    assert(id.isRDD)
    assertSame(id, BlockId(id.toString))
  }

  test("shuffle") {
    val id = ShuffleBlockId(1, 2, 3)
    assertSame(id, ShuffleBlockId(1, 2, 3))
    assertDifferent(id, ShuffleBlockId(3, 2, 3))
    assert(id.name === "shuffle_1_2_3")
    assert(id.asRDDId === None)
    assert(id.shuffleId === 1)
    assert(id.mapId === 2)
    assert(id.reduceId === 3)
    assert(id.isShuffle)
    assertSame(id, BlockId(id.toString))
  }

  test("broadcast") {
    val id = BroadcastBlockId(42)
    assertSame(id, BroadcastBlockId(42))
    assertDifferent(id, BroadcastBlockId(123))
    assert(id.name === "broadcast_42")
    assert(id.asRDDId === None)
    assert(id.broadcastId === 42)
    assert(id.isBroadcast)
    assertSame(id, BlockId(id.toString))
  }

  test("taskresult") {
    val id = TaskResultBlockId(60)
    assertSame(id, TaskResultBlockId(60))
    assertDifferent(id, TaskResultBlockId(61))
    assert(id.name === "taskresult_60")
    assert(id.asRDDId === None)
    assert(id.taskId === 60)
    assert(!id.isRDD)
    assertSame(id, BlockId(id.toString))
  }

  test("stream") {
    val id = StreamBlockId(1, 100)
    assertSame(id, StreamBlockId(1, 100))
    assertDifferent(id, StreamBlockId(2, 101))
    assert(id.name === "input-1-100")
    assert(id.asRDDId === None)
    assert(id.streamId === 1)
    assert(id.uniqueId === 100)
    assert(!id.isBroadcast)
    assertSame(id, BlockId(id.toString))
  }

  test("test") {
    val id = TestBlockId("abc")
    assertSame(id, TestBlockId("abc"))
    assertDifferent(id, TestBlockId("ab"))
    assert(id.name === "test_abc")
    assert(id.asRDDId === None)
    assert(id.id === "abc")
    assert(!id.isShuffle)
    assertSame(id, BlockId(id.toString))
  }
}
Example 101
Source File: LocalDirsSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import java.io.File

import org.apache.spark.util.Utils
import org.scalatest.BeforeAndAfter

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.SparkConfWithEnv


class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter {

  /** A path that must not exist on the test machine; used as an invalid local dir. */
  private val NonexistentPath = "/NONEXISTENT_PATH"

  before {
    Utils.clearLocalRootDirs()
  }

  after {
    // Clear again so local-dir state computed by these tests (presumably cached by Utils)
    // does not carry over into other suites.
    Utils.clearLocalRootDirs()
  }

  test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") {
    // Regression test for SPARK-2974
    // Verify the premise on the same path that is handed to spark.local.dir below.
    assert(!new File(NonexistentPath).exists())
    val conf = new SparkConf(false)
      .set("spark.local.dir", s"$NonexistentPath,${System.getProperty("java.io.tmpdir")}")
    assert(new File(Utils.getLocalDir(conf)).exists())
  }

  test("SPARK_LOCAL_DIRS override also affects driver") {
    // Regression test for SPARK-2975
    assert(!new File(NonexistentPath).exists())
    // spark.local.dir only contains invalid directories, but that's not a problem since
    // SPARK_LOCAL_DIRS will override it on both the driver and workers:
    val conf = new SparkConfWithEnv(Map("SPARK_LOCAL_DIRS" -> System.getProperty("java.io.tmpdir")))
      .set("spark.local.dir", NonexistentPath)
    assert(new File(Utils.getLocalDir(conf)).exists())
  }

}
Example 102
Source File: PartitionwiseSampledRDDSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import org.apache.spark.{SharedSparkContext, SparkFunSuite}
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, RandomSampler}


/** A RandomSampler stub whose "sample" is simply the seed it was last given. */
class MockSampler extends RandomSampler[Long, Long] {

  // Seed from the most recent setSeed call (the Long default 0 until then).
  private var currentSeed: Long = _

  override def setSeed(seed: Long): Unit = currentSeed = seed

  /** Ignores `items` and yields exactly one element: the current seed. */
  override def sample(items: Iterator[Long]): Iterator[Long] = Iterator.single(currentSeed)

  /** Returns a brand-new sampler; the current seed is deliberately not copied. */
  override def clone: MockSampler = new MockSampler
}

class PartitionwiseSampledRDDSuite extends SparkFunSuite with SharedSparkContext {

  test("seed distribution") {
    // MockSampler emits only its seed, so distinct outputs correspond to distinct seeds
    // handed to the samplers of the two partitions.
    val input = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2)
    val sampled = new PartitionwiseSampledRDD[Long, Long](input, new MockSampler, false, 0L)
    val distinctSeeds = sampled.distinct().count
    assert(distinctSeeds == 2, "Seeds must be different.")
  }

  test("concurrency") {
    // SPARK-2251: zip with self computes each partition twice.
    // We want to make sure there are no concurrency issues.
    val input = sc.parallelize(0 until 111, 10)
    val samplers = Seq(new BernoulliSampler[Int](0.5), new PoissonSampler[Int](0.5))
    samplers.foreach { sampler =>
      val sampled = new PartitionwiseSampledRDD[Int, Int](input, sampler, true)
      sampled.zip(sampled).count()
    }
  }
}
Example 103
Source File: PartitionPruningRDDSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import org.apache.spark.{Partition, SharedSparkContext, SparkFunSuite, TaskContext}

class PartitionPruningRDDSuite extends SparkFunSuite with SharedSparkContext {

  test("Pruned Partitions inherit locality prefs correctly") {
    // Three partitions that compute nothing; only the partition metadata matters here.
    val parent = new RDD[Int](sc, Nil) {
      override protected def getPartitions =
        Array.tabulate[Partition](3)(idx => new TestPartition(idx, 1))

      def compute(split: Partition, context: TaskContext) = Iterator()
    }
    val pruned = PartitionPruningRDD.create(parent, _ == 2)
    assert(pruned.partitions.length == 1)
    // The surviving partition is renumbered to 0 but still refers to parent split 2.
    val onlyPartition = pruned.partitions(0)
    assert(onlyPartition.index == 0)
    assert(onlyPartition.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2)
  }


  test("Pruned Partitions can be unioned ") {
    // Partition idx (0..2) yields a single element: its test value idx + 4.
    val parent = new RDD[Int](sc, Nil) {
      override protected def getPartitions =
        Array.tabulate[Partition](3)(idx => new TestPartition(idx, idx + 4))

      def compute(split: Partition, context: TaskContext) =
        Iterator.single(split.asInstanceOf[TestPartition].testValue)
    }
    val keepFirst = PartitionPruningRDD.create(parent, _ == 0)
    val keepLast = PartitionPruningRDD.create(parent, _ == 2)

    val merged = keepFirst ++ keepLast
    assert(merged.count() == 2)
    val firstTwo = merged.take(2)
    assert(firstTwo(0) == 4)
    assert(firstTwo(1) == 6)
  }
}

/** A minimal Partition carrying an arbitrary integer payload for tests. */
class TestPartition(i: Int, value: Int) extends Partition with Serializable {
  override def index: Int = i

  /** The integer payload supplied at construction. */
  def testValue: Int = value
}
Example 104
Source File: JdbcRDDSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.sql._

import org.scalatest.BeforeAndAfter

import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.util.Utils

class JdbcRDDSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {

  /** SQLState Derby reports when CREATE TABLE targets a table that already exists. */
  private val TableExistsSqlState = "X0Y32"

  /**
   * Executes `ddl` and, only if that succeeds, runs `populate`. A "table exists" failure from
   * either step is ignored, so an existing database is reused as-is. The DDL statement is
   * closed even when execution throws.
   */
  private def createTableIfAbsent(conn: Connection, ddl: String)(populate: => Unit): Unit = {
    try {
      val create = conn.createStatement
      try {
        create.execute(ddl)
      } finally {
        create.close()
      }
      populate
    } catch {
      case e: SQLException if e.getSQLState == TableExistsSqlState =>
      // table exists
    }
  }

  /**
   * Prepares `sql` once, and for every `i` in `rows` binds parameters via `bind` and executes
   * the update. The prepared statement is always closed.
   */
  private def insertRows(conn: Connection, sql: String, rows: Range)(
      bind: (PreparedStatement, Int) => Unit): Unit = {
    val insert = conn.prepareStatement(sql)
    try {
      rows.foreach { i =>
        bind(insert, i)
        insert.executeUpdate
      }
    } finally {
      insert.close()
    }
  }

  before {
    Utils.classForName("org.apache.derby.jdbc.EmbeddedDriver")
    val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true")
    try {
      createTableIfAbsent(conn, """
          CREATE TABLE FOO(
            ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1),
            DATA INTEGER
          )""") {
        insertRows(conn, "INSERT INTO FOO(DATA) VALUES(?)", 1 to 100) { (insert, i) =>
          insert.setInt(1, i * 2)
        }
      }

      createTableIfAbsent(conn, "CREATE TABLE BIGINT_TEST(ID BIGINT NOT NULL, DATA INTEGER)") {
        insertRows(conn, "INSERT INTO BIGINT_TEST VALUES(?,?)", 1 to 100) { (insert, i) =>
          insert.setLong(1, 100000000000000000L +  4000000000000000L * i)
          insert.setInt(2, i)
        }
      }
    } finally {
      conn.close()
    }
  }

  test("basic functionality") {
    sc = new SparkContext("local", "test")
    val rdd = new JdbcRDD(
      sc,
      () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") },
      "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
      1, 100, 3,
      (r: ResultSet) => { r.getInt(1) } ).cache()

    assert(rdd.count === 100)
    assert(rdd.reduce(_ + _) === 10100)
  }

  test("large id overflow") {
    sc = new SparkContext("local", "test")
    val rdd = new JdbcRDD(
      sc,
      () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") },
      "SELECT DATA FROM BIGINT_TEST WHERE ? <= ID AND ID <= ?",
      1131544775L, 567279358897692673L, 20,
      (r: ResultSet) => { r.getInt(1) } ).cache()
    assert(rdd.count === 100)
    assert(rdd.reduce(_ + _) === 5050)
  }

  after {
    try {
      DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;shutdown=true")
    } catch {
      case se: SQLException if se.getSQLState == "08006" =>
        // Normal single database shutdown
        // https://db.apache.org/derby/docs/10.2/ref/rrefexcept71493.html
    }
  }
}
Example 105
Source File: ZippedPartitionsSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import org.apache.spark.{SharedSparkContext, SparkFunSuite}

object ZippedPartitionsSuite {
  /**
   * Consumes the three zipped partition iterators and returns their element counts, in
   * argument order (ints, strings, doubles).
   */
  def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]): Iterator[Int] = {
    val counts = Seq(i, s, d).map(_.length)
    counts.iterator
  }
}

class ZippedPartitionsSuite extends SparkFunSuite with SharedSparkContext {
  test("print sizes") {
    val ints = sc.makeRDD(Array(1, 2, 3, 4), 2)
    val strings = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2)
    val doubles = sc.makeRDD(Array(1.0, 2.0), 2)

    // Each of the 2 partitions contributes (intCount, stringCount, doubleCount).
    val sizes = ints.zipPartitions(strings, doubles)(ZippedPartitionsSuite.procZippedData).collect()

    assert(sizes.size == 6)
    val expected = Array(2, 3, 1, 2, 3, 1)
    assert(sizes.indices.forall(idx => sizes(idx) == expected(idx)))
  }
}
Example 106
Source File: HiveTestTrait.scala    From cloud-integration   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import java.io.File

import com.cloudera.spark.cloud.ObjectStoreConfigurations
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.sql.{SparkSession, SQLContext, SQLImplicits}
import org.apache.spark.sql.hive.test.TestHiveContext
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils


trait HiveTestTrait extends SparkFunSuite with BeforeAndAfterAll {
//  override protected val enableAutoThreadAudit = false
  protected var hiveContext: HiveInstanceForTests = _
  protected var spark: SparkSession = _

  /** Creates the Hive test context and exposes its session as `spark`. */
  protected override def beforeAll(): Unit = {
    super.beforeAll()
    val context = new HiveInstanceForTests()
    hiveContext = context
    spark = context.sparkSession
  }

  /**
   * Clears the active session, resets the Hive context and closes the session, dropping each
   * reference once it is cleaned up. The base teardown always runs.
   */
  protected override def afterAll(): Unit = {
    try {
      SparkSession.clearActiveSession()

      Option(hiveContext).foreach { ctx =>
        ctx.reset()
        hiveContext = null
      }
      Option(spark).foreach { session =>
        session.close()
        spark = null
      }
    } finally {
      super.afterAll()
    }
  }

}

/**
 * A TestHiveContext over a new SparkContext. The context uses the object-store RW test options
 * and points spark.sql.warehouse.dir at a fresh path from TestSetup.makeWarehouseDir().
 */
class HiveInstanceForTests
  extends TestHiveContext({
    // The master is local[1] unless the spark.sql.test.master system property overrides it.
    val master = System.getProperty("spark.sql.test.master", "local[1]")
    val conf = new SparkConf()
      .setAll(ObjectStoreConfigurations.RW_TEST_OPTIONS)
      .set("spark.sql.warehouse.dir", TestSetup.makeWarehouseDir().toURI.getPath)
    new SparkContext(master, "TestSQLContext", conf)
  })




object TestSetup {

  /**
   * Returns a unique path for a Hive warehouse. A temp directory with the prefix "warehouse"
   * is created to reserve the name and then deleted, so normally only the path remains.
   */
  def makeWarehouseDir(): File = {
    val reserved = Utils.createTempDir(namePrefix = "warehouse")
    // The result of delete() is intentionally ignored.
    reserved.delete()
    reserved
  }
}
Example 107
Source File: PulsarConfigUpdaterSuite.scala    From pulsar-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.pulsar

import org.apache.spark.SparkFunSuite
import org.scalatest.BeforeAndAfterEach

class PulsarConfigUpdaterSuite extends SparkFunSuite with BeforeAndAfterEach {
  private val testModule = "testModule"
  private val testKey = "testKey"
  private val testValue = "testValue"
  private val otherTestValue = "otherTestValue"

  test("set should always set value") {
    // Starting from no parameters, set() must add the key.
    val updated = PulsarConfigUpdater(testModule, Map.empty[String, String])
      .set(testKey, testValue)
      .build()

    assert(updated.size() === 1)
    assert(updated.get(testKey) === testValue)
  }

  test("setIfUnset without existing key should set value") {
    // With the key absent, setIfUnset() behaves like set().
    val updated = PulsarConfigUpdater(testModule, Map.empty[String, String])
      .setIfUnset(testKey, testValue)
      .build()

    assert(updated.size() === 1)
    assert(updated.get(testKey) === testValue)
  }

  test("setIfUnset with existing key should not set value") {
    // The pre-existing value must survive setIfUnset().
    val updated = PulsarConfigUpdater(testModule, Map[String, String](testKey -> testValue))
      .setIfUnset(testKey, otherTestValue)
      .build()

    assert(updated.size() === 1)
    assert(updated.get(testKey) === testValue)
  }

}
Example 108
Source File: MergeClauseSuite.scala    From spark-acid   with Apache License 2.0 5 votes vote down vote up
package com.qubole.spark.hiveacid.merge

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{AnalysisException, functions}

class MergeClauseSuite extends SparkFunSuite {

  /** An INSERT clause over columns x and y, conditioned on `x > 2` when `addCondition`. */
  def insertClause(addCondition : Boolean = true): MergeWhenNotInsert = {
    val condition = if (addCondition) Some(functions.expr("x > 2").expr) else None
    MergeWhenNotInsert(condition, Seq(functions.col("x").expr, functions.col("y").expr))
  }

  /** An UPDATE clause setting b = 3, conditioned on `a > 2` when `addCondition`. */
  def updateClause(addCondition : Boolean = true): MergeWhenUpdateClause = {
    val condition = if (addCondition) Some(functions.expr("a > 2").expr) else None
    MergeWhenUpdateClause(condition, Map("b" -> functions.lit(3).expr), isStar = false)
  }

  /** A DELETE clause, conditioned on `a < 1` when `addCondition`. */
  def deleteClause(addCondition : Boolean = true): MergeWhenDelete = {
    val condition = if (addCondition) Some(functions.expr("a < 1").expr) else None
    MergeWhenDelete(condition)
  }

  test("Validate MergeClauses") {
    MergeWhenClause.validate(Seq(insertClause(), updateClause(), deleteClause()))
  }

  test("Invalid MergeClause cases") {
    val invalidMerge = "MERGE Validation Error: "

    // empty clauses
    checkInvalidMergeClause(invalidMerge + MergeWhenClause.atleastOneClauseError, Seq())

    // multi update or insert clauses
    checkInvalidMergeClause(invalidMerge + MergeWhenClause.justOneClausePerTypeError,
      Seq(updateClause(), updateClause(), insertClause()))

    // multi match clauses with first clause without condition
    checkInvalidMergeClause(invalidMerge + MergeWhenClause.matchClauseConditionError,
      Seq(updateClause(false), deleteClause()))

    // invalid Update Clause: no SET values at all
    val emptyUpdate = MergeWhenUpdateClause(None, Map(), isStar = false)
    val thrown = intercept[IllegalArgumentException] {
      MergeWhenClause.validate(Seq(emptyUpdate))
    }
    assert(thrown.getMessage === "UPDATE Clause in MERGE should have one or more SET Values")
  }

  /** Asserts that validating `clauses` throws an AnalysisException with exactly `expectedMessage`. */
  private def checkInvalidMergeClause(expectedMessage: String, clauses: Seq[MergeWhenClause]): Unit = {
    val thrown = intercept[AnalysisException] {
      MergeWhenClause.validate(clauses)
    }
    assert(thrown.message === expectedMessage)
  }
}
Example 109
Source File: ClickThroughRatePrediction.scala    From click-through-rate-prediction   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.examples.kaggle

import org.apache.spark.SparkFunSuite
import org.apache.spark.util.MLlibTestSparkContext

class ClickThroughRatePredictionSuite extends SparkFunSuite with MLlibTestSparkContext {

  /** URL path of the classpath test resource `name`. */
  private def resourcePath(name: String): String = getClass.getResource(name).getPath

  test("run") {
    //    Logger.getLogger("org").setLevel(Level.OFF)
    //    Logger.getLogger("akka").setLevel(Level.OFF)

    // NOTE(review): results are written to a fixed relative directory that this test never
    // cleans up; consider a temp dir once the output layout run() expects is confirmed.
    ClickThroughRatePrediction.run(
      sc,
      sqlContext,
      resourcePath("/train.part-10000"),
      resourcePath("/test.part-10000"),
      "./tmp/result/")
  }
}
Example 110
Source File: LabeledPointSuite.scala    From sona   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.sona.ml.feature

import org.apache.spark.linalg.Vectors
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.KryoSerializer

class LabeledPointSuite extends SparkFunSuite {

  test("Kryo class register") {
    val conf = new SparkConf(false)
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .registerKryoClasses(Array(
        classOf[scala.collection.mutable.WrappedArray.ofRef[_]],
        classOf[LabeledPoint]))
//    conf.set("spark.kryo.registrationRequired", "true")

    val serializer = new KryoSerializer(conf).newInstance()

    // Round-trip a dense-featured and a sparse-featured point through Kryo.
    val points = Seq(
      LabeledPoint(1.0, Vectors.dense(Array(1.0, 2.0))),
      LabeledPoint(1.0, Vectors.sparse(10, Array(5, 7), Array(1.0, 2.0))))

    for (point <- points) {
      val roundTripped = serializer.deserialize[LabeledPoint](serializer.serialize(point))
      assert(point === roundTripped)
    }
  }
}
Example 111
Source File: InstanceSuite.scala    From sona   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.sona.ml.feature

import org.apache.spark.linalg.Vectors
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.KryoSerializer

class InstanceSuite extends SparkFunSuite {

  test("Kryo class register") {
    val conf = new SparkConf(false)
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .registerKryoClasses(Array(
        classOf[scala.collection.mutable.WrappedArray.ofRef[_]],
        classOf[Instance]))
//    conf.set("spark.kryo.registrationRequired", "true")

    val serializer = new KryoSerializer(conf).newInstance()

    // Dense- and sparse-featured instances must survive a Kryo round trip unchanged.
    val instances = Seq(
      Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)),
      Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse))
    for (original <- instances) {
      val restored = serializer.deserialize[Instance](serializer.serialize(original))
      assert(original === restored)
    }

    // Same check for the offset variant.
    val offsetInstances = Seq(
      OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)),
      OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0).toSparse))
    for (original <- offsetInstances) {
      val restored = serializer.deserialize[OffsetInstance](serializer.serialize(original))
      assert(original === restored)
    }
  }
}
Example 112
Source File: LibFFMRelationSuite.scala    From sona   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.sona.ml.source.libffm

import java.io.File
import java.nio.charset.StandardCharsets

import com.google.common.io.Files
import org.apache.spark.SparkFunSuite
import com.tencent.angel.sona.ml.util.MLlibTestSparkContext
import org.apache.spark.util.SparkUtil

class LibFFMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
  // Path for dataset
  var path: String = _

  /** Writes a small two-file libffm dataset (plus a _SUCCESS marker) into a fresh temp dir. */
  override def beforeAll(): Unit = {
    super.beforeAll()
    val lines0 =
      """
        |1 0:1:1.0 1:3:2.0 2:5:3.0
        |0
      """.stripMargin
    val lines1 =
      """
        |0 0:2:4.0 1:4:5.0 2:6:6.0
      """.stripMargin
    val dir = SparkUtil.createTempDir()
    val succ = new File(dir, "_SUCCESS")
    val file0 = new File(dir, "part-00000")
    val file1 = new File(dir, "part-00001")
    Files.write("", succ, StandardCharsets.UTF_8)
    Files.write(lines0, file0, StandardCharsets.UTF_8)
    Files.write(lines1, file1, StandardCharsets.UTF_8)
    path = dir.getPath
  }

  /**
   * Removes the dataset directory created in beforeAll(), then always runs the base teardown.
   * The directory came from createTempDir, so it is deleted on every machine rather than only
   * under one developer's hard-coded Windows temp prefix.
   */
  override def afterAll(): Unit = {
    try {
      if (path != null) {
        SparkUtil.deleteRecursively(new File(path))
      }
    } finally {
      super.afterAll()
    }
  }

  test("ffmIO"){
    val df = spark.read.format("libffm").load(path)
    val metadata = df.schema(1).metadata

    val fieldSet = MetaSummary.getFieldSet(metadata)
    println(fieldSet.mkString("[", ",", "]"))

    val keyFieldMap = MetaSummary.getKeyFieldMap(metadata)
    println(keyFieldMap.mkString("[", ",", "]"))

    // Save to a not-yet-existing path inside a throwaway temp dir, not "temp.libffm" in the
    // working directory. That path was never removed, and the writer's default save mode
    // fails when the target already exists, which broke re-runs.
    val outputDir = SparkUtil.createTempDir()
    try {
      df.write.format("libffm").save(new File(outputDir, "temp.libffm").getPath)
    } finally {
      SparkUtil.deleteRecursively(outputDir)
    }
  }

  test("read_ffm"){
    val df = spark.read.format("libffm").load(path)
    val metadata = df.schema(1).metadata

    val fieldSet = MetaSummary.getFieldSet(metadata)
    println(fieldSet.mkString("[", ",", "]"))

    val keyFieldMap = MetaSummary.getKeyFieldMap(metadata)
    println(keyFieldMap.mkString("[", ",", "]"))
  }

}
Example 113
Source File: AngelTestUtils.scala    From sona   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.sona.ml.util

import com.tencent.angel.sona.core.DriverContext
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.sql.{DataFrameReader, SparkSession}

class AngelTestUtils extends SparkFunSuite {
  protected var spark: SparkSession = _
  protected var libsvm: DataFrameReader = _
  protected var dummy: DataFrameReader = _
  protected var sparkConf: SparkConf = _
  protected var driverCtx: DriverContext = _

  /** Starts a local[2] session, creates the readers, and starts the Angel PS agent. */
  protected override def beforeAll(): Unit = {
    super.beforeAll()
    spark = SparkSession.builder()
      .master("local[2]")
      .appName("AngelClassification")
      .getOrCreate()

    libsvm = spark.read.format("libsvmex")
    dummy = spark.read.format("dummy")
    sparkConf = spark.sparkContext.getConf

    driverCtx = DriverContext.get(sparkConf)
    driverCtx.startAngelAndPSAgent()
  }

  /**
   * Tears down in reverse order of setup: Angel first, then the base class. The Angel step is
   * skipped if beforeAll() failed before `driverCtx` was assigned, and the base teardown runs
   * even if stopping Angel throws.
   */
  protected override def afterAll(): Unit = {
    try {
      if (driverCtx != null) {
        driverCtx.stopAngelAndPSAgent()
      }
    } finally {
      super.afterAll()
    }
  }
}
Example 114
Source File: PgWireProtocolSuite.scala    From spark-sql-server   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.server.service.postgresql.protocol.v3

import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.sql.SQLException

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String

/** Tests for the PostgreSQL v3 wire-protocol message encoder [[PgWireProtocol]]. */
class PgWireProtocolSuite extends SparkFunSuite {

  val conf = new SQLConf()

  test("DataRow") {
    val v3Protocol = new PgWireProtocol(65536)
    val row = new GenericInternalRow(2)
    row.update(0, 8)
    row.update(1, UTF8String.fromString("abcdefghij"))
    val schema = StructType.fromDDL("a INT, b STRING")
    // NOTE(review): Seq(true, false) looks like per-column binary-format flags — confirm.
    val rowConverters = PgRowConverters(conf, schema, Seq(true, false))
    val data = v3Protocol.DataRow(row, rowConverters)
    val expectedText = "abcdefghij".getBytes(StandardCharsets.UTF_8)

    // Walk the DataRow message field by field.
    val bytes = ByteBuffer.wrap(data)
    assert(bytes.get() === 'D'.toByte)           // message type
    assert(bytes.getInt === 28)                  // message length (excludes the type byte)
    assert(bytes.getShort === 2)                 // number of columns
    assert(bytes.getInt === 4)                   // column 0: value length
    assert(bytes.getInt === 8)                   // column 0: 4-byte INT value
    assert(bytes.getInt === expectedText.length) // column 1: value length
    val text = new Array[Byte](expectedText.length)
    bytes.get(text)
    assert(text === expectedText)                // column 1: string bytes
    // The message ends right after the last column (1 + 28 = 29 bytes in total).
    assert(!bytes.hasRemaining)
  }

  test("Fails when message buffer overflowed") {
    // A 4-byte buffer cannot hold even the header of a DataRow carrying an 11-byte string.
    val v3Protocol = new PgWireProtocol(4)
    val row = new GenericInternalRow(1)
    row.update(0, UTF8String.fromString("abcdefghijk"))
    val schema = StructType.fromDDL("a STRING")
    val rowConverters = PgRowConverters(conf, schema, Seq(false))
    val errMsg = intercept[SQLException] {
      v3Protocol.DataRow(row, rowConverters)
    }.getMessage
    val expectedMsg =
      "Cannot generate a V3 protocol message because buffer is not enough for the message. " +
        "To avoid this exception, you might set higher value at " +
        "'spark.sql.server.messageBufferSizeInBytes'"
    assert(errMsg.contains(expectedMsg))
  }
}
Example 115
Source File: ExtensionBuilderSuite.scala    From spark-sql-server   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.server

import java.net.URL

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.util.Utils

/**
 * Checks that the extension builder configured through `spark.sql.server.extensions.builder`
 * (loaded from a test jar added to the class path at runtime) injects its optimizer rules.
 */
class ExtensionBuilderSuite extends SparkFunSuite with BeforeAndAfterAll {

  var _sqlContext: SQLContext = _

  /**
   * Appends `jarURLString` to Spark's class loader by reflectively calling the `addURL`
   * method declared on the loader's superclass.
   *
   * TODO: This method works only in Java8 (it assumes the loader's superclass declares
   * `addURL`, as `URLClassLoader` does).
   */
  private def addJarInClassPath(jarURLString: String): Unit = {
    val cl = Utils.getSparkClassLoader
    val addUrl = cl.getClass.getSuperclass.getDeclaredMethod("addURL", classOf[URL])
    addUrl.setAccessible(true)
    addUrl.invoke(cl, new URL(jarURLString))
  }

  protected override def beforeAll(): Unit = {
    super.beforeAll()

    // Adds a jar for an extension builder. Building the URL with File.toURI gives a
    // well-formed file URL on every platform (no hand-written "file://" prefix or separators).
    val jarPath = "src/test/resources/extensions_2.12_3.0.0-preview2_0.1.7-spark3.0-SNAPSHOT.jar"
    val jarURL = new java.io.File(System.getProperty("user.dir"), jarPath).toURI.toString
    addJarInClassPath(jarURL)

    val conf = new SparkConf(loadDefaults = true)
      .setMaster("local[1]")
      .setAppName("spark-sql-server-test")
      .set("spark.sql.server.extensions.builder", "org.apache.spark.ExtensionBuilderExample")
    _sqlContext = SQLServerEnv.newSQLContext(conf)
  }

  protected override def afterAll(): Unit = {
    // Tear down in reverse order of setup: stop this suite's context and clear the session
    // state first, then always run the parent teardown.
    try {
      try {
        if (_sqlContext != null) {
          _sqlContext.sparkContext.stop()
          _sqlContext = null
        }
      } finally {
        SparkSession.clearActiveSession()
        SparkSession.clearDefaultSession()
      }
    } finally {
      super.afterAll()
    }
  }

  test("user-defined optimizer rules") {
    val rules = Seq("org.apache.spark.catalyst.EmptyRule1", "org.apache.spark.catalyst.EmptyRule2")
    val optimizerRuleNames = _sqlContext.sessionState.optimizer
      .extendedOperatorOptimizationRules.map(_.ruleName)
    rules.foreach { expectedRuleName =>
      assert(optimizerRuleNames.contains(expectedRuleName))
    }
  }
}
Example 116
Source File: NormalizerSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}


/** Tests for the ml [[Normalizer]] transformer using L2 (default) and L1 norms. */
class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext {

  @transient var data: Array[Vector] = _
  @transient var dataFrame: DataFrame = _
  @transient var normalizer: Normalizer = _
  @transient var l1Normalized: Array[Vector] = _
  @transient var l2Normalized: Array[Vector] = _

  override def beforeAll(): Unit = {
    super.beforeAll()

    // Mix of sparse and dense inputs, including all-zero and empty vectors.
    data = Array(
      Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
      Vectors.dense(0.0, 0.0, 0.0),
      Vectors.dense(0.6, -1.1, -3.0),
      Vectors.sparse(3, Seq((1, 0.91), (2, 3.2))),
      Vectors.sparse(3, Seq((0, 5.7), (1, 0.72), (2, 2.7))),
      Vectors.sparse(3, Seq())
    )
    l1Normalized = Array(
      Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))),
      Vectors.dense(0.0, 0.0, 0.0),
      Vectors.dense(0.12765957, -0.23404255, -0.63829787),
      Vectors.sparse(3, Seq((1, 0.22141119), (2, 0.7785888))),
      Vectors.dense(0.625, 0.07894737, 0.29605263),
      Vectors.sparse(3, Seq())
    )
    l2Normalized = Array(
      Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))),
      Vectors.dense(0.0, 0.0, 0.0),
      Vectors.dense(0.184549876, -0.3383414, -0.922749378),
      Vectors.sparse(3, Seq((1, 0.27352993), (2, 0.96186349))),
      Vectors.dense(0.897906166, 0.113419726, 0.42532397),
      Vectors.sparse(3, Seq())
    )

    val sqlContext = new SQLContext(sc)
    dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData))
    normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized_features")
  }

  /** Collects the `normalized_features` column of `result` into a local array. */
  def collectResult(result: DataFrame): Array[Vector] = {
    result.select("normalized_features").collect().map {
      case Row(features: Vector) => features
    }
  }

  /**
   * Asserts that `lhs` and `rhs` have the same length and that each pair of vectors uses the
   * same representation (both dense or both sparse). The length check is explicit because
   * `zipped` silently stops at the shorter array.
   */
  def assertTypeOfVector(lhs: Array[Vector], rhs: Array[Vector]): Unit = {
    assert(lhs.length === rhs.length,
      "The number of vectors should be preserved after normalization.")
    assert((lhs, rhs).zipped.forall {
      case (v1: DenseVector, v2: DenseVector) => true
      case (v1: SparseVector, v2: SparseVector) => true
      case _ => false
    }, "The vector type should be preserved after normalization.")
  }

  /** Asserts equal lengths and element-wise approximate equality (absolute tolerance 1e-5). */
  def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = {
    assert(lhs.length === rhs.length,
      "The number of vectors is not correct after normalization.")
    assert((lhs, rhs).zipped.forall { (vector1, vector2) =>
      vector1 ~== vector2 absTol 1E-5
    }, "The vector value is not correct after normalization.")
  }

  test("Normalization with default parameter") { // default parameters: L2 normalization
    val transformed = normalizer.transform(dataFrame)
    transformed.show()
    val result = collectResult(transformed)

    assertTypeOfVector(data, result)

    assertValues(result, l2Normalized)
  }

  test("Normalization with setter") { // p set explicitly to 1: L1 normalization
    // Use a dedicated normalizer instead of mutating the shared fixture, so the
    // default-parameter test does not depend on test execution order.
    val l1Normalizer = new Normalizer()
      .setInputCol("features")
      .setOutputCol("normalized_features")
      .setP(1)
    val transformed = l1Normalizer.transform(dataFrame)
    transformed.show()
    val result = collectResult(transformed)

    assertTypeOfVector(data, result)

    assertValues(result, l1Normalized)
  }
}

private object NormalizerSuite {
  /** Single-column row type used to build the input DataFrame. */
  case class FeatureData(features: Vector)
}
Example 117
Source File: AttributeGroupSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.attribute

import org.apache.spark.SparkFunSuite

/** Tests for [[AttributeGroup]]: construction, lookup by index and name, and round trips. */
class AttributeGroupSuite extends SparkFunSuite {

  test("attribute group") { // group built from explicit attributes
    val userAttrs = Array(
      NumericAttribute.defaultAttr,
      NominalAttribute.defaultAttr,
      BinaryAttribute.defaultAttr.withIndex(0),
      NumericAttribute.defaultAttr.withName("age").withSparsity(0.8),
      NominalAttribute.defaultAttr.withName("size").withValues("small", "medium", "large"),
      BinaryAttribute.defaultAttr.withName("clicked").withValues("no", "yes"),
      NumericAttribute.defaultAttr,
      NumericAttribute.defaultAttr)
    val userGroup = new AttributeGroup("user", userAttrs)

    assert(userGroup.size === 8)
    assert(userGroup.name === "user")
    // Each attribute ends up indexed by its position in the group: the binary attribute was
    // created with index 0 but is expected back with index 2.
    assert(userGroup(0) === NumericAttribute.defaultAttr.withIndex(0))
    assert(userGroup(2) === BinaryAttribute.defaultAttr.withIndex(2))

    // Named attributes resolve to their positions.
    for ((attrName, expectedIdx) <- Seq("age" -> 3, "size" -> 4, "clicked" -> 5)) {
      assert(userGroup.indexOf(attrName) === expectedIdx)
    }
    assert(!userGroup.hasAttr("abc"))
    intercept[NoSuchElementException] {
      userGroup("abc")
    }

    // The group survives round trips through metadata and through a StructField.
    assert(userGroup === AttributeGroup.fromMetadata(userGroup.toMetadataImpl, userGroup.name))
    assert(userGroup === AttributeGroup.fromStructField(userGroup.toStructField()))
  }

  test("attribute group without attributes") { // group without per-attribute details
    // Known size, but no individual attribute metadata.
    val sizedGroup = new AttributeGroup("user", 10)
    assert(sizedGroup.name === "user")
    assert(sizedGroup.numAttributes === Some(10))
    assert(sizedGroup.size === 10)
    assert(sizedGroup.attributes.isEmpty)
    assert(sizedGroup === AttributeGroup.fromMetadata(sizedGroup.toMetadataImpl, sizedGroup.name))
    assert(sizedGroup === AttributeGroup.fromStructField(sizedGroup.toStructField()))

    // Neither size nor attributes known: size is reported as -1.
    val unsizedGroup = new AttributeGroup("item")
    assert(unsizedGroup.name === "item")
    assert(unsizedGroup.numAttributes.isEmpty)
    assert(unsizedGroup.attributes.isEmpty)
    assert(unsizedGroup.size === -1)
  }
}
Example 118
Source File: RegressionEvaluatorSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.evaluation

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._

/** Tests for [[RegressionEvaluator]] metrics (rmse, r2, mae) on a linear regression fit. */
class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("params") {
    ParamsSuite.checkParams(new RegressionEvaluator)
  }

  test("Regression Evaluator: default params") { // evaluator with default params
    // `dataset` was referenced but never defined. Rebuild the synthetic linear data with
    // LinearDataGenerator.generateLinearInput(intercept = 6.3, weights = (4.7, 7.2),
    // xMean = (0.9, -1.3), xVariance = (0.7, 1.2), nPoints = 100, seed = 42, eps = 0.1),
    // in 2 partitions.
    val dataset = sqlContext.createDataFrame(
      sc.parallelize(LinearDataGenerator.generateLinearInput(
        6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))

    val trainer = new LinearRegression
    // fit() trains an estimator on the DataFrame and returns a model (a Transformer).
    val model = trainer.fit(dataset)
    // transform() appends the prediction column to the DataFrame.
    val predictions = model.transform(dataset)

    // The default metric is rmse (root mean squared error).
    val evaluator = new RegressionEvaluator()
    println(s"==MetricName=${evaluator.getMetricName}=LabelCol=${evaluator.getLabelCol}" +
      s"=PredictionCol=${evaluator.getPredictionCol}")
    assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001)

    // r2: coefficient of determination (goodness of fit).
    evaluator.setMetricName("r2")
    assert(evaluator.evaluate(predictions) ~== 0.9998196 absTol 0.001)

    // mae: mean absolute error.
    evaluator.setMetricName("mae")
    assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001)
  }
}
Example 119
Source File: ParamGridBuilderSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.tuning

import scala.collection.mutable

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.{ParamMap, TestParams}

/** Tests for [[ParamGridBuilder]]: base values combined with grids, and grid overrides. */
class ParamGridBuilderSuite extends SparkFunSuite {

  val solver = new TestParams()
  import solver.{inputCol, maxIter}

  test("param grid builder") { // parameter grid builder
    /**
     * Asserts that `maps` holds exactly the (maxIter, inputCol) combinations in `expected`.
     * Matched combinations are removed from `expected` as they are found.
     */
    def validateGrid(maps: Array[ParamMap], expected: mutable.Set[(Int, String)]): Unit = {
      assert(maps.size === expected.size)
      for (paramMap <- maps) {
        val combo = (paramMap(maxIter), paramMap(inputCol))
        assert(expected.contains(combo))
        expected -= combo
      }
      assert(expected.isEmpty)
    }

    // A fixed base value for maxIter, crossed with a grid over inputCol.
    val fixedIterGrid = new ParamGridBuilder()
      .baseOn(maxIter -> 10)
      .addGrid(inputCol, Array("input0", "input1"))
      .build()
    validateGrid(fixedIterGrid, mutable.Set((10, "input0"), (10, "input1")))

    // Grid values replace whatever the base ParamMap sets for the same params.
    val overriddenGrid = new ParamGridBuilder()
      .baseOn(ParamMap(maxIter -> 5, inputCol -> "input")) // will be overwritten
      .addGrid(maxIter, Array(10, 20))
      .addGrid(inputCol, Array("input0", "input1"))
      .build()
    validateGrid(overriddenGrid,
      mutable.Set((10, "input0"), (20, "input0"), (10, "input1"), (20, "input1")))
  }
}
Example 120
Source File: ProbabilisticClassifierSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.classification

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}

/**
 * Minimal probabilistic classification model for tests: the input vector serves directly as
 * both the raw prediction and the probability vector.
 */
final class TestProbabilisticClassificationModel(
    override val uid: String,
    override val numClasses: Int)
  extends ProbabilisticClassificationModel[Vector, TestProbabilisticClassificationModel] {

  override def copy(extra: org.apache.spark.ml.param.ParamMap): this.type = defaultCopy(extra)

  // Identity: the features are the raw prediction.
  override protected def predictRaw(input: Vector): Vector = input

  // Identity: raw predictions are used as probabilities unchanged.
  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = rawPrediction

  /** Exposes the protected `predict` to tests. */
  def friendlyPredict(input: Vector): Double = predict(input)
}

// Probabilistic classifier test suite
class ProbabilisticClassifierSuite extends SparkFunSuite {

  test("test thresholding") { // per-class thresholds change the prediction
    // Judging by the expected results, each class probability is divided by its threshold and
    // the largest ratio wins:
    //   (1.0, 1.0) / (0.5, 0.2) = (2.0, 5.0) -> class 1
    //   (1.0, 0.2) / (0.5, 0.2) = (2.0, 1.0) -> class 0
    val model = new TestProbabilisticClassificationModel("myuid", 2)
      .setThresholds(Array(0.5, 0.2))
    assert(model.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0)
    assert(model.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0)
  }

  test("test thresholding not required") { // prediction works without thresholds
    val unthresholded = new TestProbabilisticClassificationModel("myuid", 2)
    assert(unthresholded.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
  }
}
Example 121
Source File: ANNSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.ann

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._


/** Tests that small feed-forward networks can learn the XOR function. */
class ANNSuite extends SparkFunSuite with MLlibTestSparkContext {

  /**
   * Pairs the four XOR input points (00, 01, 10, 11) with the given label rows, in that order,
   * and distributes the pairs in a single partition. Builds fresh arrays on every call.
   */
  private def xorTrainingData(labelRows: Array[Array[Double]]) = {
    val xorInputs = Array(
      Array(0.0, 0.0),
      Array(0.0, 1.0),
      Array(1.0, 0.0),
      Array(1.0, 1.0))
    val pairs = xorInputs.zip(labelRows).map { case (features, label) =>
      (Vectors.dense(features), Vectors.dense(label))
    }
    sc.parallelize(pairs, 1)
  }

  // TODO: test for weights comparison with Weka MLP
  // ANN with a sigmoid output layer learns XOR using the L-BFGS optimizer.
  test("ANN with Sigmoid learns XOR function with LBFGS optimizer") {
    val trainingData = xorTrainingData(Array(Array(0.0), Array(1.0), Array(1.0), Array(0.0)))
    val sample = trainingData.first()
    // Layer sizes: input size, one hidden layer of 5 units, output size.
    val layerSizes = sample._1.size +: Array(5) :+ sample._2.size
    // softmax = false: sigmoid output layer.
    val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false)
    val seededWeights = FeedForwardModel(topology, 23124).weights()
    val trainer = new FeedForwardTrainer(topology, 2, 1)
    // Start from deterministic, seeded initial weights.
    trainer.setWeights(seededWeights)
    trainer.LBFGSOptimizer.setNumIterations(20)
    val model = trainer.train(trainingData)
    val predictionAndLabels = trainingData.map { case (input, label) =>
      (model.predict(input)(0), label(0))
    }.collect()
    predictionAndLabels.foreach { case (p, l) =>
      assert(math.round(p) === l)
    }
  }

  // ANN with a softmax output layer learns XOR with a 2-bit (one-hot) output using gradient
  // descent.
  test("ANN with SoftMax learns XOR function with 2-bit output and batch GD optimizer") {
    // One-hot XOR result: (1, 0) encodes 0 and (0, 1) encodes 1.
    val trainingData = xorTrainingData(Array(
      Array(1.0, 0.0),
      Array(0.0, 1.0),
      Array(0.0, 1.0),
      Array(1.0, 0.0)))
    val sample = trainingData.first()
    val layerSizes = sample._1.size +: Array(5) :+ sample._2.size
    // softmax = true: the original passed `false` (sigmoid output), so this test never
    // exercised the softmax layer its name claims to test.
    val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, true)
    val seededWeights = FeedForwardModel(topology, 23124).weights()
    val trainer = new FeedForwardTrainer(topology, 2, 2)
    // Gradient-descent optimizer (the "batch GD" of the test name; assumes the default
    // mini-batch fraction covers the whole data set — confirm).
    trainer.SGDOptimizer.setNumIterations(2000)
    trainer.setWeights(seededWeights)
    val model = trainer.train(trainingData)
    val predictionAndLabels = trainingData.map { case (input, label) =>
      (model.predict(input), label)
    }.collect()
    predictionAndLabels.foreach { case (p, l) =>
      assert(p ~== l absTol 0.5)
    }
  }
}
Example 122
Source File: ChiSqSelectorSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext

  // Tests for the mllib chi-squared feature selector (ChiSqSelector).
class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {

  // Feature extraction and transformation: ChiSqSelector on sparse and dense vectors.
  // The class declaration was missing, which left this test at top level with an unmatched
  // closing brace; it is restored above.
  test("ChiSqSelector transform test (sparse & dense vector)") {
    // Labeled discrete data. A LabeledPoint pairs a label with a local (sparse or dense)
    // feature vector.
    val labeledDiscreteData = sc.parallelize(
      Seq(LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
        LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))),
        LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))),
        LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2)
    // Expected result: only the single selected feature (index 2 of each input) remains.
    val preFilteredData =
      Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))),
        LabeledPoint(1.0, Vectors.dense(Array(6.0))),
        LabeledPoint(1.0, Vectors.dense(Array(8.0))),
        LabeledPoint(2.0, Vectors.dense(Array(5.0))))
    // fit() selects the top-1 feature from the labeled RDD.
    val model = new ChiSqSelector(1).fit(labeledDiscreteData)
    val filteredData = labeledDiscreteData.map { lp =>
      // transform() projects a feature vector onto the selected features.
      LabeledPoint(lp.label, model.transform(lp.features))
    }.collect().toSet
    assert(filteredData == preFilteredData)
  }
}
Example 123
Source File: ElementwiseProductSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._

/** Tests for [[ElementwiseProduct]] (Hadamard product with a fixed scaling vector). */
class ElementwiseProductSuite extends SparkFunSuite with MLlibTestSparkContext {
  // The product applied to a dense data set scales each component.
  test("elementwise (hadamard) product should properly apply vector to dense data set") {
    val scalingVec = Vectors.dense(2.0, 0.5, 0.0, 0.25)
    val hadamard = new ElementwiseProduct(scalingVec)
    val denseInput = Array(Vectors.dense(1.0, 4.0, 1.9, -9.0))
    // There is a single input vector, so the first transformed row is the one to check.
    val transformedVec = hadamard.transform(sc.makeRDD(denseInput)).collect()(0)
    val expectedVec = Vectors.dense(2.0, 2.0, 0.0, -2.25)
    assert(transformedVec ~== expectedVec absTol 1E-5,
      s"Expected transformed vector $expectedVec but found $transformedVec")
  }

  // The product applied to a sparse data set keeps vectors sparse, and the local and RDD paths
  // agree.
  test("elementwise (hadamard) product should properly apply vector to sparse data set") {
    val sparseData = Array(Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))))
    val hadamard = new ElementwiseProduct(Vectors.dense(1.0, 0.0, 0.5))
    val localResult = sparseData.map(hadamard.transform)
    val rddResult = hadamard.transform(sc.parallelize(sparseData, 3)).collect()

    val sameKind = (sparseData, localResult, rddResult).zipped.forall {
      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
      case _ => false
    }
    assert(sameKind, "The vector type should be preserved after hadamard product")

    assert((localResult, rddResult).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
    assert(localResult(0) ~== Vectors.sparse(3, Seq((1, 0.0), (2, -1.5))) absTol 1E-5)
  }
}
Example 124
Source File: PCASuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.mllib.util.MLlibTestSparkContext

/** Checks that the [[PCA]] feature wrapper agrees with PCA computed directly on a RowMatrix. */
class PCASuite extends SparkFunSuite with MLlibTestSparkContext {

  private val data = Array(
    Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
    Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
    Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
  )

  private lazy val dataRDD = sc.parallelize(data, 2)

  test("Correct computing use a PCA wrapper") { // the PCA wrapper computes correct results
    // Keep as many components as there are rows.
    val numComponents = dataRDD.count().toInt
    val pcaModel = new PCA(numComponents).fit(dataRDD)

    // Reference: principal components computed on a distributed RowMatrix.
    val rowMatrix = new RowMatrix(dataRDD)
    val components = rowMatrix.computePrincipalComponents(numComponents)

    val viaWrapper = pcaModel.transform(dataRDD).collect()
    val viaMatrix = rowMatrix.multiply(components).rows.collect()
    assert(viaWrapper.toSet === viaMatrix.toSet)
  }
}
Example 125
Source File: HashingTFSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
  // HashingTF maps terms to feature indices by hashing and returns term frequencies.


class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("hashing tf on a single doc") { // hashing TF on a single document
    val tf = new HashingTF(1000)
    val tokens = "a a b b c d".split(" ")
    val numFeatures = tf.numFeatures
    // Expected (hashed index, count) pairs for the document above.
    val expectedCounts = Seq("a" -> 2.0, "b" -> 2.0, "c" -> 1.0, "d" -> 1.0).map {
      case (term, count) => (tf.indexOf(term), count)
    }
    val buckets = expectedCounts.map(_._1)
    assert(buckets.forall(i => i >= 0 && i < numFeatures),
      "index must be in range [0, #features)")
    // The four distinct terms must land in four distinct buckets.
    assert(buckets.toSet.size === 4, "expecting perfect hashing")
    // transform maps a document to its sparse term-frequency vector.
    assert(tf.transform(tokens) === Vectors.sparse(numFeatures, expectedCounts))
  }

  test("hashing tf on an RDD") { // hashing TF on an RDD matches the local result
    val tf = new HashingTF
    val corpus: Seq[Seq[String]] = Seq(
      "a a b b b c d".split(" "),
      "a b c d a b c".split(" "),
      "c b a c b a a".split(" "))
    val distributed = tf.transform(sc.parallelize(corpus, 2)).collect().toSet
    val local = corpus.map(tf.transform).toSet
    assert(distributed === local)
  }
}
Example 126
Source File: ImpuritySuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.tree

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator}
import org.apache.spark.mllib.util.MLlibTestSparkContext


/** Impurity aggregators must reject negative class labels. */
class ImpuritySuite extends SparkFunSuite with MLlibTestSparkContext {

  /** Stats buffer shared in shape by both tests; a fresh array on every call. */
  private def statsBuffer = Array(0.0, 1.0, 2.0)

  test("Gini impurity does not support negative labels") { // Gini rejects negative labels
    val giniAgg = new GiniAggregator(2)
    intercept[IllegalArgumentException] {
      giniAgg.update(statsBuffer, 0, -1, 0.0)
    }
  }

  test("Entropy does not support negative labels") { // entropy rejects negative labels
    val entropyAgg = new EntropyAggregator(2)
    intercept[IllegalArgumentException] {
      entropyAgg.update(statsBuffer, 0, -1, 0.0)
    }
  }
}
Example 127
Source File: BaggedPointSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.tree.impl

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.tree.EnsembleTestHelper
import org.apache.spark.mllib.util.MLlibTestSparkContext


/** Tests for [[BaggedPoint.convertToBaggedRDD]] subsample weights. */
class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext  {

  /** Seeds used by every randomized subsampling test. */
  private val subsampleSeeds = Array(123, 5354, 230, 349867, 23987)

  /**
   * Bags 1000 ordered single-feature points into 100 subsamples for each seed in
   * `subsampleSeeds`. The collected subsample weights are then passed to
   * EnsembleTestHelper.testRandomArrays, which checks them against the expected mean and
   * standard deviation (epsilon = 0.01).
   */
  private def checkSubsampleStats(
      fraction: Double,
      withReplacement: Boolean,
      expectedMean: Double,
      expectedStddev: Double): Unit = {
    val numSubsamples = 100
    val points = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000))
    subsampleSeeds.foreach { seed =>
      val bagged =
        BaggedPoint.convertToBaggedRDD(points, fraction, numSubsamples, withReplacement, seed)
      val weights: Array[Array[Double]] = bagged.map(_.subsampleWeights).collect()
      EnsembleTestHelper.testRandomArrays(weights, numSubsamples, expectedMean,
        expectedStddev, epsilon = 0.01)
    }
  }

  test("BaggedPoint RDD: without subsampling") {
    val points = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000))
    val bagged = BaggedPoint.convertToBaggedRDD(points, 1.0, 1, false, 42)
    // One subsample, fraction 1.0, no replacement: every point has exactly one weight, 1.
    for (bp <- bagged.collect()) {
      assert(bp.subsampleWeights.size == 1 && bp.subsampleWeights(0) == 1)
    }
  }

  test("BaggedPoint RDD: with subsampling with replacement (fraction = 1.0)") {
    checkSubsampleStats(1.0, withReplacement = true, expectedMean = 1.0, expectedStddev = 1.0)
  }

  test("BaggedPoint RDD: with subsampling with replacement (fraction = 0.5)") {
    val fraction = 0.5
    // With replacement: mean = fraction, stddev = sqrt(fraction).
    checkSubsampleStats(fraction, withReplacement = true,
      expectedMean = fraction, expectedStddev = math.sqrt(fraction))
  }

  test("BaggedPoint RDD: with subsampling without replacement (fraction = 1.0)") {
    // Without replacement at fraction 1.0, every weight is exactly 1 (no spread).
    checkSubsampleStats(1.0, withReplacement = false, expectedMean = 1.0, expectedStddev = 0.0)
  }

  test("BaggedPoint RDD: with subsampling without replacement (fraction = 0.5)") {
    val fraction = 0.5
    // Without replacement: mean = fraction, stddev = sqrt(fraction * (1 - fraction)).
    checkSubsampleStats(fraction, withReplacement = false,
      expectedMean = fraction, expectedStddev = math.sqrt(fraction * (1 - fraction)))
  }
}
Example 128
Source File: MatrixFactorizationModelSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.recommendation

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils

    sqlContext.createDataFrame(prodFeatures).show()
  }

  test("constructor") { // constructor and its argument validation
    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    // Predicted score for user 0 and product 2.
    println("========" + model.predict(0, 2))
    assert(model.predict(0, 2) ~== 17.0 relTol 1e-14)

    // Presumably rejected because rank 1 does not match the fixture feature length — confirm.
    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(1, userFeatures, prodFeatures)
    }

    // User feature vectors of length 1 are rejected together with `rank`.
    val shortUserFeatures = sc.parallelize(Seq((0, Array(1.0)), (1, Array(3.0))))
    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(rank, shortUserFeatures, prodFeatures)
    }

    // Product feature vectors of length 1 are rejected together with `rank`.
    val shortProdFeatures = sc.parallelize(Seq((2, Array(5.0))))
    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(rank, userFeatures, shortProdFeatures)
    }
  }

  test("save/load") { // save/load round trip
    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    val tempDir = Utils.createTempDir()
    val path = tempDir.toURI.toString
    // Arrays compare by reference, so convert each feature array to a Seq before comparing.
    def asComparable(features: RDD[(Int, Array[Double])]): Set[(Int, Seq[Double])] =
      features.mapValues(_.toSeq).collect().toSet

    try {
      model.save(sc, path)
      val reloaded = MatrixFactorizationModel.load(sc, path)
      assert(reloaded.rank === rank)
      assert(asComparable(reloaded.userFeatures) === asComparable(userFeatures))
      assert(asComparable(reloaded.productFeatures) === asComparable(prodFeatures))
    } finally {
      Utils.deleteRecursively(tempDir)
    }
  }

  test("batch predict API recommendProductsForUsers") { // batch top-K products for all users
    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    // Up to 10 recommended products per user, keyed by user id.
    val topProductsByUser = model.recommendProductsForUsers(10).collectAsMap()

    // Best-rated product score for user 0 and user 1.
    assert(topProductsByUser(0)(0).rating ~== 17.0 relTol 1e-14)
    assert(topProductsByUser(1)(0).rating ~== 39.0 relTol 1e-14)
  }

  test("batch predict API recommendUsersForProducts") {
    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    // Up to 10 recommended users per product, keyed by product id.
    val topUsersByProduct = model.recommendUsersForProducts(10).collectAsMap()

    // For product 2, user 1 ranks first and user 0 second.
    val forProduct2 = topUsersByProduct(2)
    assert(forProduct2(0).user == 1)
    assert(forProduct2(0).rating ~== 39.0 relTol 1e-14)
    assert(forProduct2(1).user == 0)
    assert(forProduct2(1).rating ~== 17.0 relTol 1e-14)
  }
} 
Example 129
Source File: AreaUnderCurveSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.evaluation

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._

/** Tests for [[AreaUnderCurve]] on local collections and on RDDs. */
class AreaUnderCurveSuite extends SparkFunSuite with MLlibTestSparkContext {

  /**
   * Asserts that the AUC of `curve`, computed both locally and from a 2-partition RDD, equals
   * `expected` within an absolute tolerance of 1e-5 (0.00001).
   */
  private def assertAuc(curve: Seq[(Double, Double)], expected: Double): Unit = {
    assert(AreaUnderCurve.of(curve) ~== expected absTol 1E-5)
    assert(AreaUnderCurve.of(sc.parallelize(curve, 2)) ~== expected absTol 1E-5)
  }

  test("auc computation") { // AUC of a piecewise-linear curve
    // Trapezoids: 0.5 + 2.0 + 1.5 = 4.0
    assertAuc(Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0)), 4.0)
  }

  test("auc of an empty curve") { // an empty curve has zero area
    assertAuc(Seq.empty[(Double, Double)], 0.0)
  }

  test("auc of a curve with a single point") { // a single point has zero area
    assertAuc(Seq((1.0, 1.0)), 0.0)
  }
}
Example 130
Source File: PythonMLLibAPISuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.api.python

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.recommendation.Rating

class PythonMLLibAPISuite extends SparkFunSuite {

  // Must run before any SerDe round trip below (presumably registers the MLlib
  // picklers/constructors; see SerDe.initialize).
  SerDe.initialize()

  test("pickle vector") {
    // Dense and sparse vectors, including empty ones, must round-trip with the same class.
    val vectors = Seq(
      Vectors.dense(Array.empty[Double]),
      Vectors.dense(0.0),
      Vectors.dense(0.0, -2.0),
      Vectors.sparse(0, Array.empty[Int], Array.empty[Double]),
      Vectors.sparse(1, Array.empty[Int], Array.empty[Double]),
      Vectors.sparse(2, Array(1), Array(-2.0)))
    vectors.foreach { v =>
      val u = SerDe.loads(SerDe.dumps(v))
      assert(u.getClass === v.getClass)
      assert(u === v)
    }
  }

  test("pickle labeled point") {
    // Labels and both feature representations must survive a round trip.
    val points = Seq(
      LabeledPoint(0.0, Vectors.dense(Array.empty[Double])),
      LabeledPoint(1.0, Vectors.dense(0.0)),
      LabeledPoint(-0.5, Vectors.dense(0.0, -2.0)),
      LabeledPoint(0.0, Vectors.sparse(0, Array.empty[Int], Array.empty[Double])),
      LabeledPoint(1.0, Vectors.sparse(1, Array.empty[Int], Array.empty[Double])),
      LabeledPoint(-0.5, Vectors.sparse(2, Array(1), Array(-2.0))))
    points.foreach { p =>
      val q = SerDe.loads(SerDe.dumps(p)).asInstanceOf[LabeledPoint]
      assert(q.label === p.label)
      assert(q.features.getClass === p.features.getClass)
      assert(q.features === p.features)
    }
  }

  test("pickle double") {
    for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue, Double.NaN)) {
      val deser = SerDe.loads(SerDe.dumps(x.asInstanceOf[AnyRef])).asInstanceOf[Double]
      // We use `equals` here for comparison because we cannot use `==` for NaN
      assert(x.equals(deser))
    }
  }

  test("pickle matrix") {
    val values = Array[Double](0, 1.2, 3, 4.56, 7, 8)
    val matrix = Matrices.dense(2, 3, values)
    val nm = SerDe.loads(SerDe.dumps(matrix)).asInstanceOf[DenseMatrix]
    assert(matrix === nm)

    // Test conversion for empty matrix
    val empty = Array[Double]()
    val emptyMatrix = Matrices.dense(0, 0, empty)
    val ne = SerDe.loads(SerDe.dumps(emptyMatrix)).asInstanceOf[DenseMatrix]
    assert(emptyMatrix === ne)

    val sm = new SparseMatrix(3, 2, Array(0, 1, 3), Array(1, 0, 2), Array(0.9, 1.2, 3.4))
    val nsm = SerDe.loads(SerDe.dumps(sm)).asInstanceOf[SparseMatrix]
    assert(sm.toArray === nsm.toArray)

    val smt = new SparseMatrix(
      3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9),
      isTransposed = true)
    val nsmt = SerDe.loads(SerDe.dumps(smt)).asInstanceOf[SparseMatrix]
    assert(smt.toArray === nsmt.toArray)
  }

  test("pickle rating") {
    val rat = new Rating(1, 2, 3.0)
    val rat2 = SerDe.loads(SerDe.dumps(rat)).asInstanceOf[Rating]
    assert(rat == rat2)

    // When many ratings are pickled together, the class name must be written only once.
    val rats = (1 to 10).map(x => new Rating(x, x + 1, x + 3.0)).toArray
    val bytes = SerDe.dumps(rats)
    // `bytes.toString` is the array's identity string ("[B@..."), not its contents, which
    // made the previous check vacuous. ISO-8859-1 maps every byte to exactly one char, so
    // decoding is lossless.
    val pickled = new String(bytes, "ISO-8859-1")
    assert("Rating".r.findAllMatchIn(pickled).size == 1)
    assert(bytes.length / 10 < 25) // fewer than 25 bytes per rating
  }
}
Example 131
Source File: FPTreeSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.fpm

import scala.language.existentials

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext

class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("add transaction") { // adding transactions builds a prefix tree
    val tree = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("b"))

    // Root level: "a" (shared by two transactions) and "b" (one transaction).
    val top = tree.root.children
    assert(top.size == 2)
    assert(top.contains("a"))
    assert(top("a").item.equals("a"))
    assert(top("a").count == 2)
    assert(top.contains("b"))
    assert(top("b").item.equals("b"))
    assert(top("b").count == 1)

    // Under "a": a single "b" node with count 2.
    val underA = top("a").children
    assert(underA.size == 1)
    assert(underA.contains("b"))
    assert(underA("b").item.equals("b"))
    assert(underA("b").count == 2)

    // Under "a" -> "b": the two diverging suffixes "c" and "y".
    val underAB = underA("b").children
    assert(underAB.size == 2)
    assert(underAB.contains("c"))
    assert(underAB.contains("y"))
    assert(underAB("c").item.equals("c"))
    assert(underAB("y").item.equals("y"))
    assert(underAB("c").count == 1)
    assert(underAB("y").count == 1)
  }

  test("merge tree") { // merging two trees adds up node counts
    val tree1 = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("b"))

    val tree2 = new FPTree[String]
      .add(Seq("a", "b"))
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "c", "d"))
      .add(Seq("a", "x"))
      .add(Seq("a", "x", "y"))
      .add(Seq("c", "n"))
      .add(Seq("c", "m"))

    val merged = tree1.merge(tree2)

    val top = merged.root.children
    assert(top.size == 3)
    assert(top("a").count == 7)
    assert(top("b").count == 1)
    assert(top("c").count == 2)

    val nodeA = top("a")
    assert(nodeA.children.size == 2)
    assert(nodeA.children("b").count == 5)
    assert(nodeA.children("x").count == 2)

    val nodeAB = nodeA.children("b")
    assert(nodeAB.children.size == 2)
    assert(nodeAB.children("y").count == 1)
    assert(nodeAB.children("c").count == 3)

    val nodeABC = nodeAB.children("c")
    assert(nodeABC.children.size == 1)
    assert(nodeABC.children("d").count == 1)

    val nodeAX = nodeA.children("x")
    assert(nodeAX.children.size == 1)
    assert(nodeAX.children("y").count == 1)

    val nodeC = top("c")
    assert(nodeC.children.size == 2)
    assert(nodeC.children("n").count == 1)
    assert(nodeC.children("m").count == 1)
  }

  test("extract freq itemsets") { // extracting itemsets with count >= 3
    val tree = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("a", "b"))
      .add(Seq("a"))
      .add(Seq("b"))
      .add(Seq("b", "n"))

    // Compare as sets so the order of items and of itemsets does not matter.
    val freqItemsets = tree.extract(3L).map(found => (found._1.toSet, found._2)).toSet
    val expected = Set(
      (Set("a"), 4L),
      (Set("b"), 5L),
      (Set("a", "b"), 3L))
    assert(freqItemsets === expected)
  }
}
Example 132
Source File: AssociationRulesSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.fpm

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
// Frequent pattern mining: association rules
class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("association rules using String type") { // association rules over String items
    // Frequent itemsets with their counts. Single items: s:3, z:5, x:4, t:3, y:3, r:3.
    // Every multi-item itemset has count 3.
    val freqItemsets = sc.parallelize(Seq(
      (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
      (Set("r"), 3L),
      (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L),
      (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L),
      (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L),
      (Set("t", "y", "x"), 3L),
      (Set("t", "y", "x", "z"), 3L)
    ).map {
      case (items, freq) => new FPGrowth.FreqItemset(items.toArray, freq)
    })
    val ar = new AssociationRules()

    // With minConfidence 0.9, only rules whose antecedent also has count 3 survive, and each of
    // them has confidence 3 / 3 = 1.0. Antecedents {x} (count 4) and {z} (count 5) give
    // confidences 0.75 and 0.6, so those rules are dropped.
    val results1 = ar
      .setMinConfidence(0.9)
      .run(freqItemsets)
      .collect()

    assert(results1.size === 23)
    // math.abs gives the absolute difference from a confidence of exactly 1.0
    assert(results1.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)

    // With minConfidence 0, every generated rule is kept. The expected 30 corresponds to
    // single-item consequents: 7 two-item itemsets * 2 + 4 three-item itemsets * 3
    // + 1 four-item itemset * 4.
    val results2 = ar
      .setMinConfidence(0.0)
      .run(freqItemsets)
      .collect()

    assert(results2.size === 30)
    // The 23 confidence-1.0 rules from above are still present.
    assert(results2.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
  }
}
Example 133
Source File: KernelDensitySuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.stat

import org.apache.commons.math3.distribution.NormalDistribution

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext

class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext {
  test("kernel density single sample") { // KDE from a single sample
    val samples = sc.parallelize(Array(5.0))
    val points = Array(5.0, 6.0)
    val estimates = new KernelDensity()
      .setSample(samples)
      .setBandwidth(3.0)
      .estimate(points)
    // Expected: one Gaussian centred on the sample, with the bandwidth as its std. deviation.
    val kernel = new NormalDistribution(5.0, 3.0)
    val tolerance = 1e-6
    assert(math.abs(estimates(0) - kernel.density(5.0)) < tolerance)
    assert(math.abs(estimates(1) - kernel.density(6.0)) < tolerance)
  }

  test("kernel density multiple samples") { // KDE from two samples
    val samples = sc.parallelize(Array(5.0, 10.0))
    val points = Array(5.0, 6.0)
    val estimates = new KernelDensity()
      .setSample(samples)
      .setBandwidth(3.0)
      .estimate(points)
    // Expected: the average of one Gaussian kernel per sample.
    val aroundFive = new NormalDistribution(5.0, 3.0)
    val aroundTen = new NormalDistribution(10.0, 3.0)
    val tolerance = 1e-6
    val mixture = (x: Double) => (aroundFive.density(x) + aroundTen.density(x)) / 2
    assert(math.abs(estimates(0) - mixture(5.0)) < tolerance)
    assert(math.abs(estimates(1) - mixture(6.0)) < tolerance)
  }
}
Example 134
Source File: MultivariateGaussianSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.stat.distribution

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{ Vectors, Matrices }
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._

class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext {
  test("univariate") { // one-dimensional Gaussian
    val atMean = Vectors.dense(0.0)
    val offMean = Vectors.dense(1.5)
    val mean = Vectors.dense(0.0)

    // Covariance [1.0]: the standard normal.
    val unitVariance = new MultivariateGaussian(mean, Matrices.dense(1, 1, Array(1.0)))
    assert(unitVariance.pdf(atMean) ~== 0.39894 absTol 1E-5)
    assert(unitVariance.pdf(offMean) ~== 0.12952 absTol 1E-5)

    // Covariance [4.0]: a wider distribution.
    val wider = new MultivariateGaussian(mean, Matrices.dense(1, 1, Array(4.0)))
    assert(wider.pdf(atMean) ~== 0.19947 absTol 1E-5)
    assert(wider.pdf(offMean) ~== 0.15057 absTol 1E-5)
  }

  test("multivariate") { // two-dimensional Gaussian
    val origin = Vectors.dense(0.0, 0.0)
    val ones = Vectors.dense(1.0, 1.0)
    val mean = Vectors.dense(0.0, 0.0)

    // Identity covariance.
    val identityCov = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))
    val independent = new MultivariateGaussian(mean, identityCov)
    assert(independent.pdf(origin) ~== 0.15915 absTol 1E-5)
    assert(independent.pdf(ones) ~== 0.05855 absTol 1E-5)

    // Non-diagonal covariance with negatively correlated components.
    val correlatedCov = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0))
    val correlated = new MultivariateGaussian(mean, correlatedCov)
    assert(correlated.pdf(origin) ~== 0.060155 absTol 1E-5)
    assert(correlated.pdf(ones) ~== 0.033971 absTol 1E-5)
  }

  test("multivariate degenerate") { // singular (all-ones) covariance matrix
    val origin = Vectors.dense(0.0, 0.0)
    val ones = Vectors.dense(1.0, 1.0)
    val mean = Vectors.dense(0.0, 0.0)

    val singularCov = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0))
    val degenerate = new MultivariateGaussian(mean, singularCov)
    assert(degenerate.pdf(origin) ~== 0.11254 absTol 1E-5)
    assert(degenerate.pdf(ones) ~== 0.068259 absTol 1E-5)
  }

  test("SPARK-11302") {
    // Regression test: a nearly-singular 4x4 covariance must give R's dmvnorm result.
    val point = Vectors.dense(629, 640, 1.7188, 618.19)
    val mean = Vectors.dense(
      1055.3910505836575, 1070.489299610895, 1.39020554474708, 1040.5907503867697)
    val covariance = Matrices.dense(4, 4, Array(
      166769.00466698944, 169336.6705268059, 12.820670788921873, 164243.93314092053,
      169336.6705268059, 172041.5670061245, 21.62590020524533, 166678.01075856484,
      12.820670788921873, 21.62590020524533, 0.872524191943962, 4.283255814732373,
      164243.93314092053, 166678.01075856484, 4.283255814732373, 161848.9196719207))
    val gaussian = new MultivariateGaussian(mean, covariance)
    // Agrees with R's dmvnorm: 7.154782e-05
    assert(gaussian.pdf(point) ~== 7.154782224045512E-5 absTol 1E-9)
  }

}
Example 135
Source File: KMeansPMMLModelExportSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.pmml.export

import org.dmg.pmml.ClusteringModel

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors

class KMeansPMMLModelExportSuite extends SparkFunSuite {

  test("KMeansPMMLModelExport generate PMML format") {
    // Three 3-dimensional cluster centers.
    val centers = Array(
      Vectors.dense(1.0, 2.0, 6.0),
      Vectors.dense(1.0, 3.0, 0.0),
      Vectors.dense(1.0, 4.0, 6.0))
    val exporter = PMMLModelExportFactory.createPMMLModelExport(new KMeansModel(centers))

    // The factory must yield a PMMLModelExport whose header describes a k-means model.
    assert(exporter.isInstanceOf[PMMLModelExport])
    val pmml = exporter.asInstanceOf[PMMLModelExport].getPmml
    assert(pmml.getHeader.getDescription === "k-means clustering")
    // One data-dictionary field per dimension of the cluster centers.
    assert(pmml.getDataDictionary.getNumberOfFields === centers(0).size)
    // The first attached model must be a ClusteringModel with as many clusters as the
    // Spark model has centers.
    val clustering = pmml.getModels.get(0).asInstanceOf[ClusteringModel]
    assert(clustering.getNumberOfClusters === centers.length)
  }

}
Example 136
Source File: PMMLModelExportFactorySuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.pmml.export

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel}
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
import org.apache.spark.mllib.util.LinearDataGenerator

    val multiclassLogisticRegressionModel = new LogisticRegressionModel(
      weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0,
      //numClasses 分类数
      numFeatures = 2, numClasses = 3)

    intercept[IllegalArgumentException] {
      PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel)
    }
  }

  test("PMMLModelExportFactory throw IllegalArgumentException when passing an unsupported model") {
    // A plain Object is not a model type the factory can export, so it must be rejected.
    val notAModel = new Object
    intercept[IllegalArgumentException] {
      PMMLModelExportFactory.createPMMLModelExport(notAModel)
    }
  }
} 
Example 137
Source File: NumericParserSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.util

import org.apache.spark.{SparkException, SparkFunSuite}

class NumericParserSuite extends SparkFunSuite {

  test("parser") { // parsing a nested numeric expression
    // Parenthesised groups parse to Seq, bracketed lists to Array[Double], bare numbers to Double.
    val input = "((1.0,2e3),-4,[5e-6,7.0E8],+9)"
    val parsed = NumericParser.parse(input).asInstanceOf[Seq[_]]
    assert(parsed(0).asInstanceOf[Seq[_]] === Seq(1.0, 2.0e3))
    assert(parsed(1).asInstanceOf[Double] === -4.0)
    assert(parsed(2).asInstanceOf[Array[Double]] === Array(5.0e-6, 7.0e8))
    assert(parsed(3).asInstanceOf[Double] === 9.0)

    // Each malformed input must make the parser throw a SparkException. If parsing succeeds,
    // the RuntimeException below is thrown instead, so `intercept` fails with a clear message.
    Seq("a", "[1,,]", "0.123.4", "1 2", "3+4").foreach { bad =>
      intercept[SparkException] {
        NumericParser.parse(bad)
        throw new RuntimeException(s"Didn't detect malformatted string $bad.")
      }
    }
  }

  test("parser with whitespaces") { // whitespace between tokens is ignored
    val parsed = NumericParser.parse("(0.0, [1.0, 2.0])").asInstanceOf[Seq[_]]
    assert(parsed(0).asInstanceOf[Double] === 0.0)
    assert(parsed(1).asInstanceOf[Array[Double]] === Array(1.0, 2.0))
  }
}
Example 138
Source File: BreezeMatrixConversionSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.linalg

import breeze.linalg.{DenseMatrix => BDM, CSCMatrix => BSM}

import org.apache.spark.SparkFunSuite

class BreezeMatrixConversionSuite extends SparkFunSuite {

  test("dense matrix to breeze") { // dense MLlib matrix -> Breeze matrix
    val dense = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0))
    val converted = dense.toBreeze.asInstanceOf[BDM[Double]]
    assert(converted.rows === dense.numRows)
    assert(converted.cols === dense.numCols)
    // The conversion must share the backing array rather than copy it.
    assert(converted.data.eq(dense.asInstanceOf[DenseMatrix].values), "should not copy data")
  }

  test("dense breeze matrix to matrix") { // dense Breeze matrix -> MLlib matrix
    val source = new BDM[Double](3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0))
    val converted = Matrices.fromBreeze(source).asInstanceOf[DenseMatrix]
    assert(converted.numRows === source.rows)
    assert(converted.numCols === source.cols)
    assert(converted.values.eq(source.data), "should not copy data")
    // A transposed Breeze view also converts without copying; only the dimensions swap.
    val transposed = Matrices.fromBreeze(source.t).asInstanceOf[DenseMatrix]
    assert(transposed.numRows === source.cols)
    assert(transposed.numCols === source.rows)
    assert(transposed.values.eq(source.data), "should not copy data")
  }

  test("sparse matrix to breeze") { // sparse MLlib matrix -> Breeze CSC matrix
    val values = Array(1.0, 2.0, 4.0, 5.0)
    val colPtrs = Array(0, 2, 4)
    val rowIndices = Array(1, 2, 1, 2)
    val sparse = Matrices.sparse(3, 2, colPtrs, rowIndices, values)
    val converted = sparse.toBreeze.asInstanceOf[BSM[Double]]
    assert(converted.rows === sparse.numRows)
    assert(converted.cols === sparse.numCols)
    assert(converted.data.eq(sparse.asInstanceOf[SparseMatrix].values), "should not copy data")
  }

  test("sparse breeze matrix to sparse matrix") { // Breeze CSC matrix -> sparse MLlib matrix
    val values = Array(1.0, 2.0, 4.0, 5.0)
    val colPtrs = Array(0, 2, 4)
    val rowIndices = Array(1, 2, 1, 2)
    val source = new BSM[Double](values, 3, 2, colPtrs, rowIndices)
    val converted = Matrices.fromBreeze(source).asInstanceOf[SparseMatrix]
    assert(converted.numRows === source.rows)
    assert(converted.numCols === source.cols)
    assert(converted.values.eq(source.data), "should not copy data")
    // Transposing a CSC matrix needs a new layout, so here the data must be copied.
    val transposed = Matrices.fromBreeze(source.t).asInstanceOf[SparseMatrix]
    assert(transposed.numRows === source.cols)
    assert(transposed.numCols === source.rows)
    assert(!transposed.values.eq(source.data), "has to copy data")
  }
}
Example 139
Source File: CoordinateMatrixSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.linalg.distributed

import breeze.linalg.{DenseMatrix => BDM}

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.linalg.Vectors

    val blockMat = mat.toBlockMatrix(2, 2)
    assert(blockMat.numRows() === m)
    assert(blockMat.numCols() === n)
    assert(blockMat.toBreeze() === mat.toBreeze())

    intercept[IllegalArgumentException] {
      mat.toBlockMatrix(-1, 2)
    }
    intercept[IllegalArgumentException] {
      mat.toBlockMatrix(2, 0)
    }
  }
} 
Example 140
Source File: BreezeVectorConversionSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.linalg

import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}

import org.apache.spark.SparkFunSuite


class BreezeVectorConversionSuite extends SparkFunSuite {

  // Shared fixtures: a dense array, plus a sparse vector of size `n` with the given
  // `indices` and `values`.
  val arr = Array(0.1, 0.2, 0.3, 0.4)
  val n = 20
  val indices = Array(0, 3, 5, 10, 13)
  val values = Array(0.1, 0.5, 0.3, -0.8, -1.0)

  test("dense to breeze") { // dense MLlib vector -> Breeze vector
    assert(Vectors.dense(arr).toBreeze === new BDV[Double](arr))
  }

  test("sparse to breeze") { // sparse MLlib vector -> Breeze sparse vector
    assert(Vectors.sparse(n, indices, values).toBreeze === new BSV[Double](indices, values, n))
  }

  test("dense breeze to vector") { // dense Breeze vector -> MLlib vector
    val converted = Vectors.fromBreeze(new BDV[Double](arr)).asInstanceOf[DenseVector]
    assert(converted.size === arr.length)
    assert(converted.values.eq(arr), "should not copy data")
  }

  test("sparse breeze to vector") { // sparse Breeze vector -> MLlib vector
    val converted = Vectors.fromBreeze(new BSV[Double](indices, values, n))
      .asInstanceOf[SparseVector]
    assert(converted.size === n)
    assert(converted.indices.eq(indices), "should not copy data")
    assert(converted.values.eq(values), "should not copy data")
  }

  test("sparse breeze with partially-used arrays to vector") {
    // Only the first `used` index/value pairs are active, so the result must hold just that
    // prefix of each array.
    val used = 3
    val source = new BSV[Double](indices, values, used, n)
    val converted = Vectors.fromBreeze(source).asInstanceOf[SparseVector]
    assert(converted.size === n)
    assert(converted.indices === indices.slice(0, used))
    assert(converted.values === values.slice(0, used))
  }
}
Example 141
Source File: LabeledPointSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.regression

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors

class LabeledPointSuite extends SparkFunSuite {
  
        println(p+"||||"+LabeledPoint.parse(p.toString))
        assert(p === LabeledPoint.parse(p.toString))}
    }
  }

  test("parse labeled points with whitespaces") { // parsing tolerates whitespace
    // Spaces after the separators must not affect the parsed label or dense features.
    val parsed = LabeledPoint.parse("(0.0, [1.0, 2.0])")
    val expected = LabeledPoint(0.0, Vectors.dense(1.0, 2.0))
    assert(parsed === expected)
  }

  test("parse labeled points with v0.9 format") { // legacy v0.9 text format
    // v0.9 layout is "<label>,<space-separated feature values>": here the label is 1.0 and the
    // features form the dense vector (1.0, 0.0, -2.0).
    val parsed = LabeledPoint.parse("1.0,1.0 0.0 -2.0")
    val expected = LabeledPoint(1.0, Vectors.dense(1.0, 0.0, -2.0))
    assert(parsed === expected)
  }
} 
Example 142
Source File: MLPairRDDFunctionsSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.rdd

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.rdd.MLPairRDDFunctions._

class MLPairRDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext {
  test("topByKey") {
    // Key 1 has five values, key 3 has three and key 5 has one, spread over 2 partitions.
    val pairs = Array((1, 7), (1, 3), (1, 6), (1, 1), (1, 2), (3, 2), (3, 7), (5, 1), (3, 5))
    // Keep at most the 5 largest values per key (largest first), then collect into a map.
    val topMap = sc.parallelize(pairs, 2).topByKey(5).collectAsMap()

    assert(topMap.size === 3)
    assert(topMap(1) === Array(7, 6, 3, 2, 1))
    assert(topMap(3) === Array(7, 5, 2))
    assert(topMap(5) === Array(1))
  }
}
Example 143
Source File: RDDFunctionsSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.rdd

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.rdd.RDDFunctions._

class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("sliding") { // sliding windows over an RDD
    val data = 0 until 6
    // RDD.sliding must match Scala's local sliding for every partitioning and window size.
    (1 to 8).foreach { numPartitions =>
      val rdd = sc.parallelize(data, numPartitions)
      (1 to 6).foreach { windowSize =>
        val actual = rdd.sliding(windowSize).collect().map(_.toList).toList
        val expected = data.sliding(windowSize).map(_.toList).toList
        assert(actual === expected)
      }
      // A window larger than the data yields no windows at all.
      assert(rdd.sliding(7).collect().isEmpty,
        "Should return an empty RDD if the window size is greater than the number of items.")
    }
  }

  test("sliding with empty partitions") { // windows must span across empty partitions
    val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7))
    // One partition per inner Seq (5 in total, two of them empty); flattened contents 1 to 7.
    val rdd = sc.parallelize(data, data.length).flatMap(chunk => chunk)
    assert(rdd.partitions.size === data.length)

    val actual = rdd.sliding(3).collect().toSeq.map(_.toSeq)
    val expected = data.flatMap(chunk => chunk).sliding(3).toSeq.map(_.toSeq)
    assert(actual === expected)
  }
}
Example 144
Source File: TwitterStreamSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.twitter


import org.scalatest.BeforeAndAfter
import twitter4j.Status
import twitter4j.auth.{NullAuthorization, Authorization}

import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream

class TwitterStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging {

  val batchDuration = Seconds(1)

  private val master: String = "local[2]"

  private val framework: String = this.getClass.getSimpleName

  test("twitter input stream") {
    val ssc = new StreamingContext(master, framework, batchDuration)
    // Stop the context even if creating a stream throws; otherwise its SparkContext would
    // stay alive and interfere with later suites.
    try {
      val filters = Seq("filter1", "filter2")
      val authorization: Authorization = NullAuthorization.getInstance()

      // tests the API, does not actually test data receiving: each createStream variant below
      // must compile and return a ReceiverInputDStream[Status]
      val test1: ReceiverInputDStream[Status] = TwitterUtils.createStream(ssc, None)
      val test2: ReceiverInputDStream[Status] =
        TwitterUtils.createStream(ssc, None, filters)
      val test3: ReceiverInputDStream[Status] =
        TwitterUtils.createStream(ssc, None, filters, StorageLevel.MEMORY_AND_DISK_SER_2)
      val test4: ReceiverInputDStream[Status] =
        TwitterUtils.createStream(ssc, Some(authorization))
      val test5: ReceiverInputDStream[Status] =
        TwitterUtils.createStream(ssc, Some(authorization), filters)
      val test6: ReceiverInputDStream[Status] = TwitterUtils.createStream(
        ssc, Some(authorization), filters, StorageLevel.MEMORY_AND_DISK_SER_2)

      // Note that actually testing the data receiving is hard as authentication keys are
      // necessary for accessing Twitter live stream
    } finally {
      ssc.stop()
    }
  }
}
Example 145
Source File: FlumePollingStreamSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.flume

import java.net.InetSocketAddress

import scala.collection.JavaConversions._
import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
import scala.concurrent.duration._
import scala.language.postfixOps

import com.google.common.base.Charsets.UTF_8
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually._

import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext}
import org.apache.spark.util.{ManualClock, Utils}

  /**
   * Runs `test` until it completes without throwing, making at most `maxAttempts` attempts.
   * Only failures that `Utils.isBindCollision` recognises are logged and retried; any other
   * exception propagates immediately. Fails if no attempt succeeded.
   */
  private def testMultipleTimes(test: () => Unit): Unit = {
    var succeeded = false
    var failures = 0
    while (failures < maxAttempts && !succeeded) {
      try {
        test()
        succeeded = true
      } catch {
        case e: Exception if Utils.isBindCollision(e) =>
          logWarning("Exception when running flume polling test: " + e)
          failures += 1
      }
    }
    assert(succeeded, s"Test failed after $failures attempts!")
  }

  /**
   * Starts a single sink via `utils`, runs `writeAndVerify` against its port, then asserts
   * the channels are empty. `utils.close()` always runs, even if the sink fails to start.
   */
  private def testFlumePolling(): Unit = {
    try {
      writeAndVerify(Seq(utils.startSingleSink()))
      utils.assertChannelsAreEmpty()
    } finally {
      utils.close()
    }
  }

  /**
   * Starts several sinks via `utils`, runs `writeAndVerify` against all of their ports, then
   * asserts the channels are empty. `utils.close()` always runs, even if startup fails.
   */
  private def testFlumePollingMultipleHost(): Unit = {
    try {
      writeAndVerify(utils.startMultipleSinks())
      utils.assertChannelsAreEmpty()
    } finally {
      utils.close()
    }
  }

  /**
   * Polls the Flume sinks listening on `sinkPorts` (on localhost) through a Flume polling
   * stream, sends data through `utils`, advances the manual clock by one batch, and waits up to
   * 10 seconds until the collected events' headers and bodies pass `utils.assertOutput`.
   * Once started, the StreamingContext is stopped in a `finally` block.
   *
   * `conf`, `batchDuration` and `utils` are members of the enclosing suite.
   */
  def writeAndVerify(sinkPorts: Seq[Int]): Unit = {
    // Set up the streaming context and input streams
    val ssc = new StreamingContext(conf, batchDuration)
    val addresses = sinkPorts.map(port => new InetSocketAddress("localhost", port))
    // NOTE(review): the last two arguments presumably are the max batch size and the
    // parallelism of the polling stream; confirm against FlumeUtils.createPollingStream.
    val flumeStream: ReceiverInputDStream[SparkFlumeEvent] =
      FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK,
        utils.eventsPerBatch, 5)
    // Synchronized because the output stream appends to it while the test thread reads it.
    val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]]
      with SynchronizedBuffer[Seq[SparkFlumeEvent]]
    val outputStream = new TestOutputStream(flumeStream, outputBuffer)
    outputStream.register()

    ssc.start()
    try {
      utils.sendDatAndEnsureAllDataHasBeenReceived()
      // The context is assumed to run with a ManualClock, so batches only fire when advanced.
      val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
      clock.advance(batchDuration.milliseconds)

      // The eventually is required to ensure that all data in the batch has been processed.
      eventually(timeout(10 seconds), interval(100 milliseconds)) {
        val flattenOutputBuffer = outputBuffer.flatten
        // Convert each event's header map to String keys/values and back to a java.util.Map
        // (relies on the implicit JavaConversions imported by the file).
        val headers = flattenOutputBuffer.map(_.event.getHeaders.map {
          case kv => (kv._1.toString, kv._2.toString)
        }).map(mapAsJavaMap)
        val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8))
        utils.assertOutput(headers, bodies)
      }
    } finally {
      ssc.stop()
    }
  }

} 
Example 146
Source File: FlumeStreamSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.flume

import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import scala.concurrent.duration._
import scala.language.postfixOps

import com.google.common.base.Charsets
import org.jboss.netty.channel.ChannelPipeline
import org.jboss.netty.channel.socket.SocketChannel
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
import org.jboss.netty.handler.codec.compression._
import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._

import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream}

  /**
   * Netty client channel factory that adds zlib compression to every new channel. It puts a
   * "deflater" (ZlibEncoder at `compressionLevel`) and then an "inflater" (ZlibDecoder) at the
   * front of the pipeline before creating the channel.
   */
  private class CompressionChannelFactory(compressionLevel: Int)
    extends NioClientSocketChannelFactory {

    override def newChannel(pipeline: ChannelPipeline): SocketChannel = {
      pipeline.addFirst("deflater", new ZlibEncoder(compressionLevel))
      // Added second with addFirst, so the inflater ends up ahead of the deflater.
      pipeline.addFirst("inflater", new ZlibDecoder())
      super.newChannel(pipeline)
    }
  }
} 
Example 147
Source File: ZeroMQStreamSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.zeromq

import akka.actor.SupervisorStrategy
import akka.util.ByteString
import akka.zeromq.Subscribe

import org.apache.spark.SparkFunSuite
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.dstream.ReceiverInputDStream

class ZeroMQStreamSuite extends SparkFunSuite {

  val batchDuration = Seconds(1)

  private val master: String = "local[2]"

  private val framework: String = this.getClass.getSimpleName

  test("zeromq input stream") {
    val ssc = new StreamingContext(master, framework, batchDuration)
    // Stop the context even if creating a stream throws; otherwise its SparkContext would
    // stay alive and interfere with later suites.
    try {
      val publishUrl = "abc"
      val subscribe = new Subscribe(null.asInstanceOf[ByteString])
      val bytesToObjects = (bytes: Seq[ByteString]) => null.asInstanceOf[Iterator[String]]

      // tests the API, does not actually test data receiving: each createStream variant below
      // must compile and return a ReceiverInputDStream[String]
      val test1: ReceiverInputDStream[String] =
        ZeroMQUtils.createStream(ssc, publishUrl, subscribe, bytesToObjects)
      val test2: ReceiverInputDStream[String] = ZeroMQUtils.createStream(
        ssc, publishUrl, subscribe, bytesToObjects, StorageLevel.MEMORY_AND_DISK_SER_2)
      val test3: ReceiverInputDStream[String] = ZeroMQUtils.createStream(
        ssc, publishUrl, subscribe, bytesToObjects,
        StorageLevel.MEMORY_AND_DISK_SER_2, SupervisorStrategy.defaultStrategy)

      // TODO: Actually test data receiving
    } finally {
      ssc.stop()
    }
  }
}
Example 148
Source File: MQTTStreamSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.mqtt

import scala.concurrent.duration._
import scala.language.postfixOps

import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext}

class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter {

  private val batchDuration = Milliseconds(500)
  private val master = "local[2]"
  private val framework = this.getClass.getSimpleName
  private val topic = "def"

  private var ssc: StreamingContext = _
  private var mqttTestUtils: MQTTTestUtils = _

  // Fresh streaming context and MQTT test broker for every test.
  before {
    ssc = new StreamingContext(master, framework, batchDuration)
    mqttTestUtils = new MQTTTestUtils
    mqttTestUtils.setup()
  }

  // Stops the context and tears down the broker; this is the single place `ssc` is stopped.
  after {
    if (ssc != null) {
      ssc.stop()
      ssc = null
    }
    if (mqttTestUtils != null) {
      mqttTestUtils.teardown()
      mqttTestUtils = null
    }
  }

  test("mqtt input stream") {
    val sendMessage = "MQTT demo for spark streaming"
    val receiveStream = MQTTUtils.createStream(ssc, "tcp://" + mqttTestUtils.brokerUri, topic,
      StorageLevel.MEMORY_ONLY)

    // Shared between the foreachRDD callback and the `eventually` polling block below.
    @volatile var receiveMessage: List[String] = List()
    receiveStream.foreachRDD { rdd =>
      // take(1) fetches at most one record in a single job, instead of collecting the whole
      // batch to the driver and then running a second job for `first`.
      rdd.take(1).foreach { message =>
        receiveMessage = receiveMessage :+ message
      }
    }

    ssc.start()

    // Retry it because we don't know when the receiver will start.
    eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
      mqttTestUtils.publishData(topic, sendMessage)
      // headOption yields a readable failure instead of IndexOutOfBounds while still empty.
      assert(receiveMessage.headOption === Some(sendMessage))
    }
    // `ssc` is stopped by the `after` hook.
  }
}
Example 149
Source File: KafkaStreamSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.kafka

import scala.collection.mutable
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.Random

import kafka.serializer.StringDecoder
import org.scalatest.BeforeAndAfterAll
import org.scalatest.concurrent.Eventually

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext}

class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll {
  private var ssc: StreamingContext = _
  private var kafkaTestUtils: KafkaTestUtils = _

  /** Starts the Kafka test utilities once for the whole suite. */
  override def beforeAll(): Unit = {
    kafkaTestUtils = new KafkaTestUtils
    kafkaTestUtils.setup()
  }

  /** Stops the streaming context a test may have created, then tears Kafka down. */
  override def afterAll(): Unit = {
    Option(ssc).foreach { context =>
      context.stop()
      ssc = null
    }

    Option(kafkaTestUtils).foreach { utils =>
      utils.teardown()
      kafkaTestUtils = null
    }
  }

  test("Kafka input stream") { // Kafka input stream
    val sparkConf = new SparkConf()
      .setMaster("local[4]")
      .setAppName(this.getClass.getSimpleName)
    ssc = new StreamingContext(sparkConf, Milliseconds(500))

    // Publish a known multiset of messages; the stream must count exactly these.
    val topic = "topic1"
    val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
    kafkaTestUtils.createTopic(topic)
    kafkaTestUtils.sendMessages(topic, sent)

    val kafkaParams = Map(
      "zookeeper.connect" -> kafkaTestUtils.zkAddress,
      "group.id" -> s"test-consumer-${Random.nextInt(10000)}",
      "auto.offset.reset" -> "smallest")

    val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
      ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY)

    // Running per-message totals accumulated across batches.
    val result = new mutable.HashMap[String, Long]() with mutable.SynchronizedMap[String, Long]
    stream.map(_._2).countByValue().foreachRDD { batchCounts =>
      batchCounts.collect().toMap.foreach { case (message, n) =>
        result.put(message, result.getOrElseUpdate(message, 0) + n)
      }
    }

    ssc.start()

    eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
      assert(sent === result)
    }
  }
}
Example 150
Source File: KafkaClusterSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.kafka

import scala.util.Random

import kafka.common.TopicAndPartition
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.SparkFunSuite

class KafkaClusterSuite extends SparkFunSuite with BeforeAndAfterAll {
  private val topic = s"kcsuitetopic${Random.nextInt(10000)}"
  private val topicAndPartition = TopicAndPartition(topic, 0)
  // Created in beforeAll, once the test broker is running.
  private var kc: KafkaCluster = null

  private var kafkaTestUtils: KafkaTestUtils = _

  /** Starts Kafka, writes one message to `topic`, and connects a KafkaCluster to the broker. */
  override def beforeAll(): Unit = {
    kafkaTestUtils = new KafkaTestUtils
    kafkaTestUtils.setup()

    kafkaTestUtils.createTopic(topic)
    kafkaTestUtils.sendMessages(topic, Map("a" -> 1))
    kc = new KafkaCluster(Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress))
  }

  /** Tears down the Kafka test utilities, if they were started. */
  override def afterAll(): Unit = {
    Option(kafkaTestUtils).foreach { utils =>
      utils.teardown()
      kafkaTestUtils = null
    }
  }

  test("metadata apis") { // metadata APIs
    val (leaderHost, leaderPort) = kc.findLeaders(Set(topicAndPartition)).right.get(topicAndPartition)
    assert(s"$leaderHost:$leaderPort" === kafkaTestUtils.brokerAddress, "didn't get leader")

    val partitions = kc.getPartitions(Set(topic)).right.get
    assert(partitions.contains(topicAndPartition), "didn't get partitions")

    val missing = kc.getPartitions(Set(topic + "BAD"))
    assert(missing.isLeft, "getPartitions for a nonexistant topic should be an error")
  }

  test("leader offset apis") { // leader offset APIs
    // Exactly one message was written in beforeAll, so offsets span [0, 1).
    val earliest = kc.getEarliestLeaderOffsets(Set(topicAndPartition)).right.get
    assert(earliest(topicAndPartition).offset === 0, "didn't get earliest")

    val latest = kc.getLatestLeaderOffsets(Set(topicAndPartition)).right.get
    assert(latest(topicAndPartition).offset === 1, "didn't get latest")
  }

  test("consumer offset apis") { // consumer offset APIs
    val group = s"kcsuitegroup${Random.nextInt(10000)}"
    val offset = Random.nextInt(10000)

    // Round-trip: whatever offset is committed for the group must be read back unchanged.
    val committed = kc.setConsumerOffsets(group, Map(topicAndPartition -> offset))
    assert(committed.isRight, "didn't set consumer offsets")

    val fetched = kc.getConsumerOffsets(group, Set(topicAndPartition)).right.get
    assert(fetched(topicAndPartition) === offset, "didn't get consumer offsets")
  }
}
Example 151
Source File: OrcTest.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.orc

import java.io.File

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
import org.apache.spark.sql.test.SQLTestUtils

private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite =>
  protected override def _sqlContext: SQLContext = org.apache.spark.sql.hive.test.TestHive
  protected val sqlContext = _sqlContext
  import sqlContext.implicits._
  import sqlContext.sparkContext

  /**
   * Writes `data` as ORC under a temporary path from `withTempPath` and passes that path to `f`.
   * Cleanup of the path is left to `withTempPath` (SQLTestUtils).
   */
  protected def withOrcFile[T <: Product: ClassTag: TypeTag]
      (data: Seq[T])
      (f: String => Unit): Unit = {
    withTempPath { file =>
      sparkContext.parallelize(data).toDF().write.orc(file.getCanonicalPath)
      f(file.getCanonicalPath)
    }
  }

  /** Writes `data` to a temporary ORC file and hands `f` that file read back as a DataFrame. */
  protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag]
      (data: Seq[T])
      (f: DataFrame => Unit): Unit = {
    withOrcFile(data)(path => f(sqlContext.read.orc(path)))
  }

  /**
   * Writes `data` to ORC, registers it as temporary table `tableName`, and runs `f`
   * while that table exists (scoped by `withTempTable`).
   */
  protected def withOrcTable[T <: Product: ClassTag: TypeTag]
      (data: Seq[T], tableName: String)
      (f: => Unit): Unit = {
    withOrcDataFrame(data) { df =>
      sqlContext.registerDataFrameAsTable(df, tableName)
      withTempTable(tableName)(f)
    }
  }

  /** Writes `data` as ORC to `path`, overwriting anything already there. */
  protected def makeOrcFile[T <: Product: ClassTag: TypeTag](
      data: Seq[T], path: File): Unit = {
    data.toDF().write.mode(SaveMode.Overwrite).orc(path.getCanonicalPath)
  }

  /**
   * Writes `df` as ORC to `path`, overwriting anything already there.
   * `T` is unused here; it is kept so existing call sites keep compiling.
   */
  protected def makeOrcFile[T <: Product: ClassTag: TypeTag](
      df: DataFrame, path: File): Unit = {
    df.write.mode(SaveMode.Overwrite).orc(path.getCanonicalPath)
  }
}
Example 152
Source File: ConcurrentHiveSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.sql.hive.test.TestHiveContext
import org.scalatest.BeforeAndAfterAll

class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll {
  // Ignored: multiple Hive instances in one JVM are not supported.
  // The body is registered directly with `ignore`. Nesting a `test(...)` call inside a test
  // body is rejected by ScalaTest at run time, so the nested form could never pass if enabled.
  ignore("multiple instances not supported") {
    (1 to 10).foreach { i =>
      val ts =
        new TestHiveContext(new SparkContext("local", s"TestSQLContext$i", new SparkConf()))
      ts.executeSql("SHOW TABLES").toRdd.collect()
      ts.executeSql("SELECT * FROM src").toRdd.collect()
      ts.executeSql("SHOW TABLES").toRdd.collect()
    }
  }
}
Example 153
Source File: FiltersSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.client

import scala.collection.JavaConversions._

import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hadoop.hive.serde.serdeConstants

import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._


class FiltersSuite extends SparkFunSuite with Logging {
  private val shim = new Shim_v0_13

  // Hive table with a single VARCHAR partition column named "varchar"; the "skip varchar"
  // case expects a filter on that column to produce no pushed-down predicate.
  private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test")
  private val varCharCol = new FieldSchema()
  varCharCol.setName("varchar")
  varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME)
  testTable.setPartCols(varCharCol :: Nil)

  // string filter
  filterTest("string filter",
    (a("stringcol", StringType) > Literal("test")) :: Nil,
    "stringcol > \"test\"")
  // string filter with the literal on the left
  filterTest("string filter backwards",
    (Literal("test") > a("stringcol", StringType)) :: Nil,
    "\"test\" > stringcol")
  // int filter
  filterTest("int filter",
    (a("intcol", IntegerType) === Literal(1)) :: Nil,
    "intcol = 1")
  // int filter with the literal on the left
  filterTest("int filter backwards",
    (Literal(1) === a("intcol", IntegerType)) :: Nil,
    "1 = intcol")

  // Multiple filters are joined with "and". `strcol` is compared with a string literal, so the
  // attribute is declared as StringType to keep the test data type-consistent.
  filterTest("int and string filter",
    (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", StringType)) :: Nil,
    "1 = intcol and \"a\" = strcol")

  filterTest("skip varchar",
    (Literal("") === a("varchar", StringType)) :: Nil,
    "")

  /** Registers a test asserting that `filters` convert to exactly the Hive string `result`. */
  private def filterTest(name: String, filters: Seq[Expression], result: String): Unit = {
    test(name) {
      val converted = shim.convertFilters(testTable, filters)
      if (converted != result) {
        fail(
          s"Expected filters ${filters.mkString(",")} to convert to '$result' but got '$converted'")
      }
    }
  }

  /** Shorthand for a fresh AttributeReference named `name` of type `dataType`. */
  private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)()
}
Example 154
Source File: SerializationSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer

class SerializationSuite extends SparkFunSuite {
  // HiveContext must survive a Java serialization round trip without throwing.
  test("[SPARK-5840] HiveContext should be serializable") {
    val hiveContext = org.apache.spark.sql.hive.test.TestHive
    // Touch hiveconf before serializing; presumably this forces its lazy initialization so the
    // round trip covers a fully set-up context — TODO confirm.
    hiveContext.hiveconf
    val javaSerializer = new JavaSerializer(new SparkConf()).newInstance()
    val serialized = javaSerializer.serialize(hiveContext)
    javaSerializer.deserialize[AnyRef](serialized)
  }
}
Example 155
Source File: ClasspathDependenciesSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import java.net.URL

import org.apache.spark.SparkFunSuite


class ClasspathDependenciesSuite extends SparkFunSuite {
  // Class loader used both to locate `.class` resources and to load classes.
  private val classloader = this.getClass.getClassLoader

  private val KRYO = "com.esotericsoftware.kryo.Kryo"
  private val SPARK_HIVE = "org.apache.hive."
  private val SPARK_SHADED = "org.spark-project.hive.shaded."
  private val STD_INSTANTIATOR = "org.objenesis.strategy.StdInstantiatorStrategy"

  /** Classpath resource path for a class name, e.g. `a.b.C` becomes `a/b/C.class`. */
  private def resourceName(classname: String): String =
    s"${classname.replace('.', '/')}.class"

  /** URL of the class's `.class` resource, or null when the loader cannot find it. */
  private def findResource(classname: String): URL =
    classloader.getResource(resourceName(classname))

  /** Fails unless `classname` is present as a resource AND can be loaded. */
  private def assertLoads(classname: String): Unit = {
    Option(findResource(classname)) match {
      case Some(resourceURL) =>
        logInfo(s"Class $classname at $resourceURL")
        classloader.loadClass(classname)
      case None =>
        fail(s"Class $classname not found as ${resourceName(classname)}")
    }
  }

  private def assertLoads(classes: String*): Unit = {
    for (name <- classes) assertLoads(name)
  }

  /** Fails if `classname` exists as a resource, and requires loading it to throw. */
  private def assertClassNotFound(classname: String): Unit = {
    for (resourceURL <- Option(findResource(classname))) {
      fail(s"Class $classname found at $resourceURL")
    }
    intercept[ClassNotFoundException](classloader.loadClass(classname))
  }

  private def assertClassNotFound(classes: String*): Unit = {
    for (name <- classes) assertClassNotFound(name)
  }

  test("shaded Protobuf") {
    assertLoads(SPARK_SHADED + "com.google.protobuf.ServiceException")
  }

  test("hive-common") {
    assertLoads("org.apache.hadoop.hive.conf.HiveConf")
  }

  test("hive-exec") {
    assertLoads("org.apache.hadoop.hive.ql.CommandNeedRetryException")
  }

  test("unshaded kryo") {
    assertLoads(KRYO, STD_INSTANTIATOR)
  }

  // NOTE(review): the literal "org.apache.hive." entries below repeat the SPARK_HIVE entries.
  test("Forbidden Dependencies") {
    assertClassNotFound(
      SPARK_HIVE + KRYO,
      SPARK_SHADED + KRYO,
      "org.apache.hive." + KRYO,
      "com.esotericsoftware.shaded." + STD_INSTANTIATOR,
      SPARK_HIVE + "com.esotericsoftware.shaded." + STD_INSTANTIATOR,
      "org.apache.hive.com.esotericsoftware.shaded." + STD_INSTANTIATOR
    )
  }

  test("parquet-hadoop-bundle") {
    assertLoads(
      "parquet.hadoop.ParquetOutputFormat",
      "parquet.hadoop.ParquetInputFormat"
    )
  }
}
Example 156
Source File: CommitFailureTestRelationSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils


class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils {
  override def _sqlContext: SQLContext = TestHive
  private val sqlContext = _sqlContext

  // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose.
  val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName

  // A failing commitTask() must fall back to abortTask(), leaving no `_temporary` directory.
  test("SPARK-7684: commitTask() failure should fallback to abortTask()") {
    withTempPath { file =>
      val outputPath = file.getCanonicalPath

      // Coalesce to a single partition so exactly one task runs. This avoids a race where
      // FileOutputCommitter removes `_temporary` while the job is being committed/aborted.
      // See SPARK-8513 for details.
      val df = sqlContext.range(0, 10).coalesce(1)
      intercept[SparkException] {
        df.write.format(dataSourceName).save(outputPath)
      }

      val fs = new Path(outputPath).getFileSystem(SparkHadoopUtil.get.conf)
      assert(!fs.exists(new Path(outputPath, "_temporary")))
    }
  }
}
Example 157
Source File: RandomDataGeneratorSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types._


  /**
   * Checks the seeded random generator for `dataType`.
   *
   * With `nullable`, at least one null must appear in 100 samples; without it, none may.
   * Ten more samples must then convert to Catalyst values without error.
   */
  def testRandomDataGeneration(dataType: DataType, nullable: Boolean = true): Unit = {
    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType)
    val generator = RandomDataGenerator.forType(dataType, nullable, Some(33)).getOrElse {
      fail(s"Random data generator was not defined for $dataType")
    }
    val samples = Iterator.fill(100)(generator())
    if (nullable) assert(samples.contains(null))
    else assert(samples.forall(_ != null))

    (1 to 10).foreach { _ =>
      toCatalyst(generator())
    }
  }

  // Basic types:
  // Every non-decimal atomic type is tested both with and without nulls. `nullable` must be
  // forwarded to the helper; otherwise the "nullable=false" variant re-runs the nullable case.
  for {
    dataType <- DataTypeTestUtils.atomicTypes
    if !dataType.isInstanceOf[DecimalType]
    nullable <- Seq(true, false)
  } {
    test(s"$dataType (nullable=$nullable)") {
      testRandomDataGeneration(dataType, nullable)
    }
  }

  // Array types whose element type (with the same containsNull) has a data generator.
  DataTypeTestUtils.atomicArrayTypes
    .withFilter(t => RandomDataGenerator.forType(t.elementType, t.containsNull).isDefined)
    .foreach { arrayType =>
      test(s"$arrayType") {
        testRandomDataGeneration(arrayType)
      }
    }

  // Atomic types for which RandomDataGenerator can build a generator (default arguments).
  val atomicTypesWithDataGenerators =
    DataTypeTestUtils.atomicTypes.filter(t => RandomDataGenerator.forType(t).isDefined)

  // Complex types:
  // Map keys never use DecimalType. Scala's BigDecimal.hashCode can lead to OutOfMemoryError on
  // Scala 2.10 (see SI-6173), and Spark can hit NumberFormatException errors when converting
  // certain BigDecimals (SPARK-8802).
  for {
    keyType <- atomicTypesWithDataGenerators
    if !keyType.isInstanceOf[DecimalType]
    valueType <- atomicTypesWithDataGenerators
  } {
    val mapType = MapType(keyType, valueType)
    test(s"$mapType") {
      testRandomDataGeneration(mapType)
    }
  }

  // Two-column structs over every pair of generatable atomic types.
  for {
    colOneType <- atomicTypesWithDataGenerators
    colTwoType <- atomicTypesWithDataGenerators
  } {
    val structType = StructType(List(StructField("a", colOneType), StructField("b", colTwoType)))
    test(s"$structType") {
      testRandomDataGeneration(structType)
    }
  }

} 
Example 158
Source File: SqlParserSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.Command

/** Leaf no-op command that records the identifier a keyword test parser matched. */
private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command {
  override def output: Seq[Attribute] = Nil
  override def children: Seq[LogicalPlan] = Nil
}

/** Parses `THISISASUPERLONGKEYWORDTEST <ident>` into a [[TestCommand]] holding the identifier. */
private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser {
  protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST")

  override protected lazy val start: Parser[LogicalPlan] = set

  // <keyword> <identifier>  =>  TestCommand(identifier)
  private lazy val set: Parser[LogicalPlan] =
    (EXECUTE ~> ident) ^^ (name => TestCommand(name))
}

/** Parses `EXECUTE <ident>` into a [[TestCommand]] holding the identifier. */
private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser {
  protected val EXECUTE = Keyword("EXECUTE")

  override protected lazy val start: Parser[LogicalPlan] = set

  // <keyword> <identifier>  =>  TestCommand(identifier)
  private lazy val set: Parser[LogicalPlan] =
    (EXECUTE ~> ident) ^^ (name => TestCommand(name))
}

class SqlParserSuite extends SparkFunSuite {

  // A very long keyword, written in mixed case, must still be recognized.
  test("test long keyword") {
    val parser = new SuperLongKeywordTestParser
    assert(TestCommand("NotRealCommand") ===
      parser.parse("ThisIsASuperLongKeyWordTest NotRealCommand"))
  }

  // Keywords match regardless of letter case.
  test("test case insensitive") {
    val parser = new CaseInsensitiveTestParser
    for (keyword <- Seq("EXECUTE", "execute", "exEcute")) {
      assert(TestCommand("NotRealCommand") === parser.parse(s"$keyword NotRealCommand"))
    }
  }
}
Example 159
Source File: LogicalPlanSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util._


class LogicalPlanSuite extends SparkFunSuite {
  // How many times `function` has fired since the last reset.
  private var invocationCount = 0

  // Identity rule on Project nodes that only counts its applications.
  private val function: PartialFunction[LogicalPlan, LogicalPlan] = {
    case p: Project =>
      invocationCount += 1
      p
  }

  private val testRelation = LocalRelation()

  /** Resets the counter, runs `resolveOperators(function)` on `plan`, returns the count. */
  private def invocationsOn(plan: LogicalPlan): Int = {
    invocationCount = 0
    plan.resolveOperators(function)
    invocationCount
  }

  test("resolveOperator runs on operators") {
    assert(invocationsOn(Project(Nil, testRelation)) === 1)
  }

  test("resolveOperator runs on operators recursively") {
    assert(invocationsOn(Project(Nil, Project(Nil, testRelation))) === 2)
  }

  test("resolveOperator skips all ready resolved plans") {
    val plan = Project(Nil, Project(Nil, testRelation))
    plan.foreach(_.setAnalyzed())
    assert(invocationsOn(plan) === 0)
  }

  test("resolveOperator skips partially resolved plans") {
    val inner = Project(Nil, testRelation)
    val outer = Project(Nil, inner)
    inner.foreach(_.setAnalyzed())
    assert(invocationsOn(outer) === 1)
  }
}
Example 160
Source File: SameResultSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.util._


class SameResultSuite extends SparkFunSuite {
  // Two structurally identical but separately constructed relations.
  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
  val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int)

  /**
   * Analyzes both plans and fails, printing them side by side, unless
   * `sameResult` between them equals `result`.
   */
  def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = {
    val (left, right) = (a.analyze, b.analyze)
    if (left.sameResult(right) != result) {
      fail(s"Plans should return sameResult = $result\n" +
        sideBySide(left.toString, right.toString).mkString("\n"))
    }
  }

  test("relations") {
    assertSameResult(testRelation, testRelation2)
  }

  test("projections") {
    // Identical projections over equivalent relations.
    assertSameResult(testRelation.select('a), testRelation2.select('a))
    assertSameResult(testRelation.select('b), testRelation2.select('b))
    assertSameResult(testRelation.select('a, 'b), testRelation2.select('a, 'b))
    assertSameResult(testRelation.select('b, 'a), testRelation2.select('b, 'a))

    // A missing projection or a different column order must not count as the same result.
    assertSameResult(testRelation, testRelation2.select('a), result = false)
    assertSameResult(testRelation.select('b, 'a), testRelation2.select('a, 'b), result = false)
  }

  test("filters") {
    assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b))
  }

  test("sorts") {
    assertSameResult(testRelation.orderBy('a.asc), testRelation2.orderBy('a.asc))
  }
}
Example 161
Source File: NondeterministicSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite

// Checks the value each nondeterministic expression yields when evaluated through
// ExpressionEvalHelper.checkEvaluation. Each expression is constructed inline in the
// checkEvaluation call.
class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper {
  // Expected to evaluate to 0L here.
  test("MonotonicallyIncreasingID") {
    checkEvaluation(MonotonicallyIncreasingID(), 0L)
  }

  // Expected to evaluate to partition id 0 here.
  test("SparkPartitionID") {
    checkEvaluation(SparkPartitionID(), 0)
  }

  // Expected to evaluate to the empty string here (no input file is being read).
  test("InputFileName") {
    checkEvaluation(InputFileName(), "")
  }
}
Example 162
Source File: NullFunctionsSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._


  // For every supported type: coalesce of only nulls is null; otherwise the first
  // non-null argument's value is returned.
  test("coalesce") {
    testAllTypes { (value: Any, tpe: DataType) =>
      val lit = Literal.create(value, tpe)
      val nullLit = Literal.create(null, tpe)
      checkEvaluation(Coalesce(Seq(nullLit)), null)
      for (children <- Seq(
          Seq(lit),
          Seq(nullLit, lit),
          Seq(nullLit, lit, lit),
          Seq(nullLit, nullLit, lit))) {
        checkEvaluation(Coalesce(children), value)
      }
    }
  }

  // AtLeastNNonNulls(n, exprs) is true iff at least n inputs are neither null nor NaN.
  // The expectations show NaN is counted like null.
  test("AtLeastNNonNulls") {
    // 2 countable values: "x" and 5f (two nulls plus one NaN).
    val mix = Seq(Literal("x"),
      Literal.create(null, StringType),
      Literal.create(null, DoubleType),
      Literal(Double.NaN),
      Literal(5f))

    // 3 countable values: "x", 10.0, Double.MaxValue (Float.NaN and log(-2) are NaN).
    val nanOnly = Seq(Literal("x"),
      Literal(10.0),
      Literal(Float.NaN),
      Literal(math.log(-2)),
      Literal(Double.MaxValue))

    // 3 countable values: "x", Float.MaxValue, false (two typed nulls).
    val nullOnly = Seq(Literal("x"),
      Literal.create(null, DoubleType),
      Literal.create(null, DecimalType.USER_DEFAULT),
      Literal(Float.MaxValue),
      Literal(false))

    val cases = Seq(
      (2, mix, true),
      (3, mix, false),
      (3, nanOnly, true),
      (4, nanOnly, false),
      (3, nullOnly, true),
      (4, nullOnly, false))
    for ((n, children, expected) <- cases) {
      checkEvaluation(AtLeastNNonNulls(n, children), expected, EmptyRow)
    }
  }
} 
Example 163
Source File: DecimalExpressionSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{LongType, DecimalType, Decimal}


class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
  // Each test runs over two Decimals the tests treat as 10.1: one parsed from a string, and one
  // built from unscaled value 101 with precision 3 and scale 1.

  // Unscaled value
  test("UnscaledValue") {
    for (d <- Seq(Decimal("10.1"), Decimal(101, 3, 1))) {
      checkEvaluation(UnscaledValue(Literal(d)), 101L)
    }
    checkEvaluation(UnscaledValue(Literal.create(null, DecimalType(2, 1))), null)
  }

  // Build a decimal from an unscaled long
  test("MakeDecimal") {
    checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
    checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
  }

  // Promote precision (value unchanged)
  test("PromotePrecision") {
    for (d <- Seq(Decimal("10.1"), Decimal(101, 3, 1))) {
      checkEvaluation(PromotePrecision(Literal(d)), d)
    }
    checkEvaluation(PromotePrecision(Literal.create(null, DecimalType(2, 1))), null)
  }

  // Check overflow: 10.1 rounds to 10 at scale 0, fits at scales 1-2, and overflows
  // (yields null) in DecimalType(4, 3).
  test("CheckOverflow") {
    for (d <- Seq(Decimal("10.1"), Decimal(101, 3, 1))) {
      checkEvaluation(CheckOverflow(Literal(d), DecimalType(4, 0)), Decimal("10"))
      checkEvaluation(CheckOverflow(Literal(d), DecimalType(4, 1)), d)
      checkEvaluation(CheckOverflow(Literal(d), DecimalType(4, 2)), d)
      checkEvaluation(CheckOverflow(Literal(d), DecimalType(4, 3)), null)
    }

    checkEvaluation(CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2)), null)
  }

}
Example 164
Source File: LiteralExpressionSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._


class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {

  // A null literal of every type evaluates to null. DoubleType is included here; it was
  // previously skipped because LongType was listed twice.
  test("null") {
    val types = Seq[DataType](
      BooleanType,
      ByteType,
      ShortType,
      IntegerType,
      LongType,
      FloatType,
      DoubleType,
      StringType,
      BinaryType,
      DecimalType.USER_DEFAULT,
      ArrayType(ByteType, true),
      MapType(StringType, IntegerType),
      StructType(Seq.empty))
    types.foreach { dataType =>
      checkEvaluation(Literal.create(null, dataType), null)
    }
  }

  test("boolean literals") {
    checkEvaluation(Literal(true), true)
    checkEvaluation(Literal(false), false)
  }

  // Integral literals round-trip, including when narrowed to long/short/byte.
  test("int literals") {
    List(0, 1, Int.MinValue, Int.MaxValue).foreach { d =>
      checkEvaluation(Literal(d), d)
      checkEvaluation(Literal(d.toLong), d.toLong)
      checkEvaluation(Literal(d.toShort), d.toShort)
      checkEvaluation(Literal(d.toByte), d.toByte)
    }
    checkEvaluation(Literal(Long.MinValue), Long.MinValue)
    checkEvaluation(Literal(Long.MaxValue), Long.MaxValue)
  }

  // Floating-point literals, including signed zero, infinities and type extremes.
  test("double literals") {
    List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { d =>
      checkEvaluation(Literal(d), d)
      checkEvaluation(Literal(d.toFloat), d.toFloat)
    }
    checkEvaluation(Literal(Double.MinValue), Double.MinValue)
    checkEvaluation(Literal(Double.MaxValue), Double.MaxValue)
    checkEvaluation(Literal(Float.MinValue), Float.MinValue)
    checkEvaluation(Literal(Float.MaxValue), Float.MaxValue)
  }

  test("string literals") {
    checkEvaluation(Literal(""), "")
    checkEvaluation(Literal("test"), "test")
    checkEvaluation(Literal("\0"), "\0")
  }

  test("sum two literals") {
    checkEvaluation(Add(Literal(1), Literal(1)), 2)
  }

  test("binary literals") {
    checkEvaluation(Literal.create(new Array[Byte](0), BinaryType), new Array[Byte](0))
    checkEvaluation(Literal.create(new Array[Byte](2), BinaryType), new Array[Byte](2))
  }

  // Decimal literals built from Double, Int, Long, unscaled long, and Scala/Java BigDecimal.
  test("decimal") {
    List(-0.0001, 0.0, 0.001, 1.2, 1.1111, 5).foreach { d =>
      checkEvaluation(Literal(Decimal(d)), Decimal(d))
      checkEvaluation(Literal(Decimal(d.toInt)), Decimal(d.toInt))
      checkEvaluation(Literal(Decimal(d.toLong)), Decimal(d.toLong))
      checkEvaluation(Literal(Decimal((d * 1000L).toLong, 10, 3)),
        Decimal((d * 1000L).toLong, 10, 3))
      checkEvaluation(Literal(BigDecimal(d.toString)), Decimal(d))
      checkEvaluation(Literal(new java.math.BigDecimal(d.toString)), Decimal(d))
    }
  }

  // TODO(davies): add tests for ArrayType, MapType and StructType
}
Example 165
Source File: CollectionFunctionsSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._


class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
  // Array and map size
  test("Array and Map Size") {
    val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
    val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
    val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType))

    checkEvaluation(Size(a0), 3)
    checkEvaluation(Size(a1), 0)
    checkEvaluation(Size(a2), 2)

    val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType))
    val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
    val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType))

    checkEvaluation(Size(m0), 2)
    checkEvaluation(Size(m1), 0)
    checkEvaluation(Size(m2), 1)

    // Size of a null collection is null. Wrap the literal in Size; evaluating the bare
    // literal would not exercise Size at all.
    checkEvaluation(Size(Literal.create(null, MapType(StringType, StringType))), null)
    checkEvaluation(Size(Literal.create(null, ArrayType(StringType))), null)
  }

  // Sort array: ascending by default; nulls sort first ascending and last descending.
  test("Sort Array") {
    val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
    val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
    val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType))
    val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType))

    checkEvaluation(new SortArray(a0), Seq(1, 2, 3))
    checkEvaluation(new SortArray(a1), Seq[Integer]())
    checkEvaluation(new SortArray(a2), Seq("a", "b"))
    checkEvaluation(new SortArray(a3), Seq(null, "a", "b"))
    checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3))
    checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]())
    checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b"))
    checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b"))
    checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1))
    checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]())
    checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a"))
    checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null))

    // Sorting a null array yields null; wrap the literal so SortArray is actually evaluated.
    checkEvaluation(new SortArray(Literal.create(null, ArrayType(StringType))), null)
  }

  // Array contains: null when the array or the value is null, or when there is no match but
  // the array holds a null element.
  test("Array contains") {
    val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
    val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
    val a2 = Literal.create(Seq(null), ArrayType(LongType))
    val a3 = Literal.create(null, ArrayType(StringType))

    checkEvaluation(ArrayContains(a0, Literal(1)), true)
    checkEvaluation(ArrayContains(a0, Literal(0)), false)
    checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null)

    checkEvaluation(ArrayContains(a1, Literal("")), true)
    checkEvaluation(ArrayContains(a1, Literal("a")), null)
    checkEvaluation(ArrayContains(a1, Literal.create(null, StringType)), null)

    checkEvaluation(ArrayContains(a2, Literal(1L)), null)
    checkEvaluation(ArrayContains(a2, Literal.create(null, LongType)), null)

    checkEvaluation(ArrayContains(a3, Literal("")), null)
    checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
  }
}
Example 166
Source File: RandomSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.scalatest.Matchers._

import org.apache.spark.SparkFunSuite
// Random expression test suite
/**
 * Checks the seeded `Rand` and `Randn` expressions against previously recorded outputs for
 * fixed seeds.
 */
class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {

  /** Absolute tolerance used for every comparison in this suite. */
  private val tolerance = 0.001

  test("random") {
    checkDoubleEvaluation(Rand(30), 0.7363714192755834 +- tolerance)
    checkDoubleEvaluation(Randn(30), 0.5181478766595276 +- tolerance)
  }

  // Code generation with a seed that does not fit in an Int.
  test("SPARK-9127 codegen with long seed") {
    val longSeed = 5419823303878592871L
    checkDoubleEvaluation(Rand(longSeed), 0.4061913198963727 +- tolerance)
    checkDoubleEvaluation(Randn(longSeed), -0.24417152005343168 +- tolerance)
  }
}
Example 167
Source File: MiscFunctionsSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.codec.digest.DigestUtils

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType}

/**
 * Tests for the hashing / checksum expressions `Md5`, `Sha1`, `Sha2` and `Crc32`.
 * Every expression must evaluate to null when any of its inputs is null.
 */
class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("md5") {
    checkEvaluation(Md5(Literal("ABC".getBytes)), "902fbdd2b1df0c4f70b4a5d23525e932")
    checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
      "6ac1e56bc78f031059be7be854522c4c")
    checkEvaluation(Md5(Literal.create(null, BinaryType)), null)
    checkConsistencyBetweenInterpretedAndCodegen(Md5, BinaryType)
  }

  test("sha1") {
    checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")
    checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
      "5d211bad8f4ee70e16c7d343a838fc344a1ed961")
    checkEvaluation(Sha1(Literal.create(null, BinaryType)), null)
    // SHA-1 of the empty input.
    checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709")
    checkConsistencyBetweenInterpretedAndCodegen(Sha1, BinaryType)
  }

  test("sha2") {
    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC"))
    checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)),
      DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6)))
    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(512)), DigestUtils.sha512Hex("ABC"))
    // A bit length of 0 is treated as 256.
    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(0)), DigestUtils.sha256Hex("ABC"))
    // Unsupported bit length: a non-null input must still evaluate to null. (A null input
    // would return null for any bit length, so it cannot exercise this path on its own.)
    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(1024)), null)
    // A null input and/or a null bit length always evaluates to null.
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null)
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null)
    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null)
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null)
  }

  test("crc32") {
    checkEvaluation(Crc32(Literal("ABC".getBytes)), 2743272264L)
    checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
      2180413220L)
    checkEvaluation(Crc32(Literal.create(null, BinaryType)), null)
    checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
  }
}
Example 168
Source File: AttributeSetSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.IntegerType

/**
 * Tests for [[AttributeSet]]: membership, union, subset and equality are decided by an
 * attribute's expression ID, not by its name.
 */
class AttributeSetSuite extends SparkFunSuite {

  // `aUpper` / `aLower` differ only in the case of their name and share ExprId(1);
  // `fakeA` has the same name as `aLower` but a different id.
  val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
  val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
  val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
  val aSet = AttributeSet(aLower :: Nil)

  val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
  val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
  val bSet = AttributeSet(bUpper :: Nil)

  val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)

  test("sanity check") {
    // The references themselves are distinct objects that compare unequal.
    assert(aUpper != aLower)
    assert(bUpper != bLower)
  }

  // Membership is checked by id, not by name.
  test("checks by id not name") {
    assert(aSet.contains(aUpper) === true)
    assert(aSet.contains(aLower) === true)
    assert(aSet.contains(fakeA) === false)

    assert(aSet.contains(bUpper) === false)
    assert(aSet.contains(bLower) === false)
  }

  test("++ preserves AttributeSet")  {
    val union = aSet ++ bSet
    assert(union.contains(aUpper) === true)
    assert(union.contains(aLower) === true)
    // Attributes that only come from the right-hand operand must be present as well.
    assert(union.contains(bUpper) === true)
    assert(union.contains(bLower) === true)
  }

  test("extracts all references") {
    // References are collected from nested expressions, including under an Alias.
    val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil)
    assert(addSet.contains(aUpper))
    assert(addSet.contains(aLower))
    assert(addSet.contains(bUpper))
    assert(addSet.contains(bLower))
  }

  test("dedups attributes") {
    assert(AttributeSet(aUpper :: aLower :: Nil).size === 1)
  }

  test("subset") {
    assert(aSet.subsetOf(aAndBSet) === true)
    assert(aAndBSet.subsetOf(aSet) === false)
  }

  test("equality") {
    assert(aSet != aAndBSet)
    assert(aAndBSet != aSet)
    assert(aSet != bSet)
    assert(bSet != aSet)

    assert(aSet == aSet)
    assert(aSet == AttributeSet(aUpper :: Nil))
  }
}
Example 169
Source File: GenerateUnsafeRowJoinerSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions.codegen

import scala.util.Random

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types._


/**
 * Tests for [[GenerateUnsafeRowJoiner]]: joining two UnsafeRows built from randomly generated
 * schemas must yield a row whose fields (and null flags) equal those of the first input
 * followed by those of the second.
 */
class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {

  private val fixed = Seq(IntegerType)
  private val variable = Seq(IntegerType, StringType)

  // Simple fixed-width types
  test("simple fixed width types") {
    testConcat(0, 0, fixed)
    testConcat(0, 1, fixed)
    testConcat(1, 0, fixed)
    testConcat(64, 0, fixed)
    testConcat(0, 64, fixed)
    testConcat(64, 64, fixed)
  }

  // Randomized fixed-width types
  test("randomized fixed width types") {
    (0 until 20).foreach { _ =>
      testConcatOnce(Random.nextInt(100), Random.nextInt(100), fixed)
    }
  }

  // Simple variable-width types
  test("simple variable width types") {
    testConcat(0, 0, variable)
    testConcat(0, 1, variable)
    testConcat(1, 0, variable)
    testConcat(64, 0, variable)
    testConcat(0, 64, variable)
    testConcat(64, 64, variable)
  }

  // Randomized variable-width types
  test("randomized variable width types") {
    (0 until 10).foreach { _ =>
      testConcatOnce(Random.nextInt(100), Random.nextInt(100), variable)
    }
  }

  /** Runs [[testConcatOnce]] ten times; each run draws fresh random schemas and rows. */
  private def testConcat(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]): Unit = {
    (0 until 10).foreach { _ =>
      testConcatOnce(numFields1, numFields2, candidateTypes)
    }
  }

  /**
   * Builds two random schemas with `numFields1` / `numFields2` fields drawn from
   * `candidateTypes`, generates one random row for each, joins their UnsafeRow forms and
   * checks every field of the result against its source row.
   */
  private def testConcatOnce(
      numFields1: Int,
      numFields2: Int,
      candidateTypes: Seq[DataType]): Unit = {
    info(s"schema size $numFields1, $numFields2")
    val schema1 = RandomDataGenerator.randomSchema(numFields1, candidateTypes)
    val schema2 = RandomDataGenerator.randomSchema(numFields2, candidateTypes)

    // Create the converters needed to convert from external row to internal row and to UnsafeRows.
    val internalConverter1 = CatalystTypeConverters.createToCatalystConverter(schema1)
    val internalConverter2 = CatalystTypeConverters.createToCatalystConverter(schema2)
    val converter1 = UnsafeProjection.create(schema1)
    val converter2 = UnsafeProjection.create(schema2)

    // Create the input rows, convert them into UnsafeRows.
    // Fail with a descriptive message (instead of a bare NoSuchElementException) if no
    // generator exists for a schema.
    val extRow1 = RandomDataGenerator.forType(schema1, nullable = false)
      .getOrElse(fail(s"No random data generator for schema $schema1"))
      .apply()
    val extRow2 = RandomDataGenerator.forType(schema2, nullable = false)
      .getOrElse(fail(s"No random data generator for schema $schema2"))
      .apply()
    val row1 = converter1.apply(internalConverter1.apply(extRow1).asInstanceOf[InternalRow])
    val row2 = converter2.apply(internalConverter2.apply(extRow2).asInstanceOf[InternalRow])

    // Run the joiner.
    val mergedSchema = StructType(schema1 ++ schema2)
    val concater = GenerateUnsafeRowJoiner.create(schema1, schema2)
    val output = concater.join(row1, row2)

    // Every output field must match its source: the fields of row1 come first, then row2's.
    for (i <- mergedSchema.indices) {
      val dataType = mergedSchema(i).dataType
      val (source, sourceIndex) =
        if (i < schema1.size) (row1, i) else (row2, i - schema1.size)
      assert(output.isNullAt(i) === source.isNullAt(sourceIndex))
      if (!output.isNullAt(i)) {
        assert(output.get(i, dataType) === source.get(sourceIndex, dataType))
      }
    }
  }
}
Example 170
Source File: GeneratedProjectionSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String


/**
 * Tests for the generated unsafe, safe and mutable projections on a wide nested row, and for
 * the unsafe projection of an array of binary values.
 */
class GeneratedProjectionSuite extends SparkFunSuite {

  /**
   * Checks the layout built in "generated projections on wider table": ordinals 0 and 1 hold
   * the same struct (`width` ints 1..width, then their `width` string forms), and ordinals
   * from 2 onwards repeat those `2 * width` values as flat columns.
   */
  private def checkWideNestedRow(row: InternalRow, width: Int): Unit = {
    (0 until width).foreach { i =>
      val expectedInt = i + 1
      val expectedString = UTF8String.fromString((i + 1).toString)
      assert(expectedInt === row.getInt(i + 2))
      assert(expectedString === row.getUTF8String(i + 2 + width))
      assert(expectedInt === row.getStruct(0, width * 2).getInt(i))
      assert(expectedString === row.getStruct(0, width * 2).getUTF8String(i + width))
      assert(expectedInt === row.getStruct(1, width * 2).getInt(i))
      assert(expectedString === row.getStruct(1, width * 2).getUTF8String(i + width))
    }
  }

  // Generated projections on a wider table
  test("generated projections on wider table") {
    val width = 1000
    val intColumns = new GenericInternalRow((1 to width).toArray[Any])
    val intSchema = StructType((1 to width).map(_ => StructField("", IntegerType)))
    val stringColumns = new GenericInternalRow(
      (1 to width).map(i => UTF8String.fromString(i.toString)).toArray[Any])
    val stringSchema = StructType((1 to width).map(_ => StructField("", StringType)))
    val joined = new JoinedRow(intColumns, stringColumns)
    val joinedSchema = StructType(intSchema ++ stringSchema)
    val nested = new JoinedRow(InternalRow(joined, joined), joined)
    val nestedSchema = StructType(
      Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema)

    // test generated UnsafeProjection
    val toUnsafe = UnsafeProjection.create(nestedSchema)
    val unsafe: UnsafeRow = toUnsafe(nested)
    checkWideNestedRow(unsafe, width)

    // test generated SafeProjection
    // Can't compare GenericInternalRow with JoinedRow directly
    val toSafe = FromUnsafeProjection(nestedSchema)
    val result = toSafe(unsafe)
    checkWideNestedRow(result, width)

    // test generated MutableProjection, applied twice to the same input
    val boundColumns = nestedSchema.fields.zipWithIndex.map { case (field, ordinal) =>
      BoundReference(ordinal, field.dataType, true)
    }
    val mutableProj = GenerateMutableProjection.generate(boundColumns)()
    assert(result === mutableProj(result))
    assert(result === mutableProj(result))
  }

  test("generated unsafe projection with array of binary") {
    val row = InternalRow(
      Array[Byte](1, 2),
      new GenericArrayData(Array(Array[Byte](1, 2), null, Array[Byte](3, 4))))
    val fields = Array[DataType](BinaryType, ArrayType(BinaryType))

    val toUnsafe = UnsafeProjection.create(fields)
    val unsafeRow: UnsafeRow = toUnsafe(row)
    val binaries = unsafeRow.getArray(1)
    assert(java.util.Arrays.equals(unsafeRow.getBinary(0), Array[Byte](1, 2)))
    assert(java.util.Arrays.equals(binaries.getBinary(0), Array[Byte](1, 2)))
    // The null element stays null after conversion.
    assert(binaries.isNullAt(1))
    assert(binaries.getBinary(1) === null)
    assert(java.util.Arrays.equals(binaries.getBinary(2), Array[Byte](3, 4)))

    // Converting back must reproduce the original row.
    val toSafe = FromUnsafeProjection(fields)
    assert(toSafe(unsafeRow) === row)
  }
}
Example 171
Source File: RuleExecutorSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.trees

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal}
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}

/**
 * Tests for [[RuleExecutor]] batch strategies: `Once`, `FixedPoint` run until nothing changes,
 * and `FixedPoint` stopped by its iteration limit.
 */
class RuleExecutorSuite extends SparkFunSuite {

  /** Rewrites each positive integer literal `n` to `n - 1`; other expressions are unchanged. */
  object DecrementLiterals extends Rule[Expression] {
    def apply(e: Expression): Expression = e.transform {
      case IntegerLiteral(value) if value > 0 => Literal(value - 1)
    }
  }

  test("only once") {
    object SinglePass extends RuleExecutor[Expression] {
      val batches = Seq(Batch("once", Once, DecrementLiterals))
    }

    // A single application decrements exactly once.
    val decremented = SinglePass.execute(Literal(10))
    assert(decremented === Literal(9))
  }

  test("to fixed point") {
    object UntilStable extends RuleExecutor[Expression] {
      val batches = Seq(Batch("fixedPoint", FixedPoint(100), DecrementLiterals))
    }

    // 100 iterations are enough to reach 0, where the rule stops changing the literal.
    val decremented = UntilStable.execute(Literal(10))
    assert(decremented === Literal(0))
  }

  test("to maxIterations") {
    object Capped extends RuleExecutor[Expression] {
      val batches = Seq(Batch("fixedPoint", FixedPoint(10), DecrementLiterals))
    }

    // The iteration cap (10) is hit long before a fixed point would be reached.
    val decremented = Capped.execute(Literal(100))
    assert(decremented === Literal(90))
  }
}
Example 172
Source File: NumberConverterSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.NumberConverter.convert
import org.apache.spark.unsafe.types.UTF8String


/** Tests for the base conversion performed by [[NumberConverter.convert]]. */
class NumberConverterSuite extends SparkFunSuite {

  /** Converts `n` from `fromBase` to `toBase` and compares the result with `expected`. */
  private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = {
    val converted = convert(UTF8String.fromString(n).getBytes, fromBase, toBase)
    assert(converted === UTF8String.fromString(expected))
  }

  // Conversion
  test("convert") {
    val cases = Seq(
      ("3", 10, 2, "11"),
      // A negative target base yields a signed result...
      ("-15", 10, -16, "-F"),
      // ...while a positive one prints the 64-bit two's-complement pattern.
      ("-15", 10, 16, "FFFFFFFFFFFFFFF1"),
      ("big", 36, 16, "3A48"),
      // An overflowing input ends up as all ones.
      ("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF"),
      // Digits are read only up to the first character invalid in the source base.
      ("11abc", 10, 16, "B"))

    cases.foreach { case (n, fromBase, toBase, expected) =>
      checkConv(n, fromBase, toBase, expected)
    }
  }

}
Example 173
Source File: MetadataSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import org.json4s.jackson.JsonMethods.parse

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{MetadataBuilder, Metadata}

/**
 * Tests for [[Metadata]] and [[MetadataBuilder]]: typed getters, overwriting keys, nested
 * metadata, array values, and the JSON round trip.
 */
class MetadataSuite extends SparkFunSuite {

  val baseMetadata = new MetadataBuilder()
    .putString("purpose", "ml")
    .putBoolean("isBase", true)
    .build()

  val summary = new MetadataBuilder()
    .putLong("numFeatures", 10L)
    .build()

  val age = new MetadataBuilder()
    .putString("name", "age")
    .putLong("index", 1L)
    .putBoolean("categorical", false)
    .putDouble("average", 45.0)
    .build()

  val gender = new MetadataBuilder()
    .putString("name", "gender")
    .putLong("index", 5L)
    .putBoolean("categorical", true)
    .putStringArray("categories", Array("male", "female"))
    .build()

  /** Exercises every kind of entry: copied base entries, an overwrite, nesting and arrays. */
  val metadata = new MetadataBuilder()
    .withMetadata(baseMetadata)
    .putBoolean("isBase", false) // overwrite an existing key
    .putMetadata("summary", summary)
    .putLongArray("long[]", Array(0L, 1L))
    .putDoubleArray("double[]", Array(3.0, 4.0))
    .putBooleanArray("boolean[]", Array(true, false))
    .putMetadataArray("features", Array(age, gender))
    .build()

  // Metadata builder and getters
  test("metadata builder and getters") {
    // A leaf metadata object: presence checks and scalar getters.
    assert(!age.contains("summary"))
    assert(age.contains("index"))
    assert(age.getLong("index") === 1L)
    assert(age.contains("average"))
    assert(age.getDouble("average") === 45.0)
    assert(age.contains("categorical"))
    assert(!age.getBoolean("categorical"))
    assert(age.contains("name"))
    assert(age.getString("name") === "age")

    // Entries copied from `baseMetadata`, with "isBase" overwritten.
    assert(metadata.contains("purpose"))
    assert(metadata.getString("purpose") === "ml")
    assert(metadata.contains("isBase"))
    assert(!metadata.getBoolean("isBase"))

    // Nested metadata and array-valued entries.
    assert(metadata.contains("summary"))
    assert(metadata.getMetadata("summary") === summary)
    assert(metadata.contains("long[]"))
    assert(metadata.getLongArray("long[]").toSeq === Seq(0L, 1L))
    assert(metadata.contains("double[]"))
    assert(metadata.getDoubleArray("double[]").toSeq === Seq(3.0, 4.0))
    assert(metadata.contains("boolean[]"))
    assert(metadata.getBooleanArray("boolean[]").toSeq === Seq(true, false))
    assert(gender.contains("categories"))
    assert(gender.getStringArray("categories").toSeq === Seq("male", "female"))
    assert(metadata.contains("features"))
    assert(metadata.getMetadataArray("features").toSeq === Seq(age, gender))
  }

  // JSON conversion of metadata
  test("metadata json conversion") {
    val serialized = metadata.json
    withClue("toJson must produce a valid JSON string") {
      parse(serialized)
    }
    val restored = Metadata.fromJson(serialized)
    assert(restored === metadata)
    // Equal values must also agree on their hash codes.
    assert(restored.## === metadata.##)
  }
}
Example 174
Source File: StringUtilsSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.StringUtils._

/** Tests for [[StringUtils.escapeLikeRegex]], which turns a SQL LIKE pattern into a regex. */
class StringUtilsSuite extends SparkFunSuite {

  // Escaping LIKE patterns into regular expressions
  test("escapeLikeRegex") {
    val expectations = Seq(
      // Ordinary characters are each quoted with \Q...\E.
      "abdef" -> "(?s)\\Qa\\E\\Qb\\E\\Qd\\E\\Qe\\E\\Qf\\E",
      // A backslash makes the following wildcard literal.
      "a\\__b" -> "(?s)\\Qa\\E_.\\Qb\\E",
      // `_` becomes `.` and `%` becomes `.*`.
      "a_%b" -> "(?s)\\Qa\\E..*\\Qb\\E",
      "a%\\%b" -> "(?s)\\Qa\\E.*%\\Qb\\E",
      "a%" -> "(?s)\\Qa\\E.*",
      // Regex metacharacters in the input are quoted.
      "**" -> "(?s)\\Q*\\E\\Q*\\E",
      "a_b" -> "(?s)\\Qa\\E.\\Qb\\E")

    expectations.foreach { case (likePattern, expectedRegex) =>
      assert(escapeLikeRegex(likePattern) === expectedRegex)
    }
  }
}
Example 175
Source File: CatalystTypeConvertersSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

/** Tests for [[CatalystTypeConverters]], focusing on null and `Option` handling. */
class CatalystTypeConvertersSuite extends SparkFunSuite {

  private val simpleTypes: Seq[DataType] = Seq(
    StringType,
    DateType,
    BooleanType,
    ByteType,
    ShortType,
    IntegerType,
    LongType,
    FloatType,
    DoubleType,
    DecimalType.SYSTEM_DEFAULT,
    DecimalType.USER_DEFAULT)

  // Null handling in rows
  test("null handling in rows") {
    val schema = StructType(simpleTypes.map { dataType =>
      StructField(dataType.getClass.getName, dataType)
    })
    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
    val toScala = CatalystTypeConverters.createToScalaConverter(schema)

    // A row with one null per column must survive the Scala -> Catalyst -> Scala round trip.
    val allNullRow = Row.fromSeq(Seq.fill(simpleTypes.size)(null))
    val roundTripped = toScala(toCatalyst(allNullRow))
    assert(roundTripped === allNullRow)
  }

  // Null handling for individual values
  test("null handling for individual values") {
    simpleTypes.foreach { dataType =>
      val toScala = CatalystTypeConverters.createToScalaConverter(dataType)
      assert(toScala(null) === null)
    }
  }

  // Option handling in convertToCatalyst
  test("option handling in convertToCatalyst") {
    // convertToCatalyst doesn't handle unboxing from Options. This is inconsistent with
    // createToCatalystConverter but it may not actually matter as this is only called internally
    // in a handful of places where we don't expect to receive Options.
    val converted = CatalystTypeConverters.convertToCatalyst(Some(123))
    assert(converted === Some(123))
  }

  test("option handling in createToCatalystConverter") {
    // Unlike convertToCatalyst, the created converter unwraps the Option.
    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(IntegerType)
    assert(toCatalyst(Some(123)) === 123)
  }
}
Example 176
Source File: SQLContextSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext
// SQLContext test suite
/**
 * Tests for `SQLContext.getOrCreate`: it returns the last instantiated SQLContext, and it
 * creates (and then keeps returning) one when none is recorded.
 */
class SQLContextSuite extends SparkFunSuite with SharedSQLContext {

  override def afterAll(): Unit = {
    // The tests below clear the last instantiated context; restore the shared `ctx`
    // before the normal teardown runs.
    try SQLContext.setLastInstantiatedContext(ctx)
    finally super.afterAll()
  }

  // getOrCreate instantiates a SQLContext when none is recorded
  test("getOrCreate instantiates SQLContext") {
    SQLContext.clearLastInstantiatedContext()
    val sparkContext = ctx.sparkContext
    val created = SQLContext.getOrCreate(sparkContext)
    assert(created != null, "SQLContext.getOrCreate returned null")
    assert(SQLContext.getOrCreate(sparkContext).eq(created),
      "SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate")
  }

  // getOrCreate returns the last explicitly instantiated SQLContext
  test("getOrCreate gets last explicitly instantiated SQLContext") {
    SQLContext.clearLastInstantiatedContext()
    val sparkContext = ctx.sparkContext
    val explicitlyCreated = new SQLContext(sparkContext)
    assert(SQLContext.getOrCreate(sparkContext) != null,
      "SQLContext.getOrCreate after explicitly created SQLContext returned null")
    assert(SQLContext.getOrCreate(sparkContext).eq(explicitlyCreated),
      "SQLContext.getOrCreate after explicitly created SQLContext did not return the context")
  }
}
Example 177
Source File: SQLExecutionSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import java.util.Properties

import scala.collection.parallel.CompositeThrowable

import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.sql.SQLContext
// SQL execution test suite.
// NOTE(review): this excerpt is truncated. The `testConcurrentQueryExecution` helper called
// below and the closing brace of `SQLExecutionSuite` are not present here; confirm against the
// original source file.
/** Regression test for SPARK-10548 (concurrent query execution). */
class SQLExecutionSuite extends SparkFunSuite {
  // Concurrent query execution (SPARK-10548).
  test("concurrent query execution (SPARK-10548)") {
    // Try to reproduce the issue with the old SparkContext
    // (`BadSparkContext`, which overrides how local properties are inherited by child threads).
    val conf = new SparkConf()
      .setMaster("local[*]")
      .setAppName("test")
    val badSparkContext = new BadSparkContext(conf)
    try {
      testConcurrentQueryExecution(badSparkContext)
      // Reaching this line means the old behavior did not reproduce. `fail` throws a
      // TestFailedException, which the IllegalArgumentException handler below does not catch.
      fail("unable to reproduce SPARK-10548")
    } catch {
      // The expected failure is an IllegalArgumentException mentioning the execution ID key.
      case e: IllegalArgumentException =>
        assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY))
    } finally {
      // Always stop the first context (presumably needed before the second SparkContext
      // below can be created).
      badSparkContext.stop()
    }

    // Verify that the issue is fixed with the latest SparkContext:
    // the same workload must complete without throwing.
    val goodSparkContext = new SparkContext(conf)
    try {
      testConcurrentQueryExecution(goodSparkContext)
    } finally {
      goodSparkContext.stop()
    }
  }

  
// NOTE(review): in the upstream project this is presumably a top-level class declared after
// `SQLExecutionSuite` is closed; in this excerpt it appears before that class's closing brace.
/**
 * A SparkContext that reproduces the behavior SPARK-10548 guards against. It plays the
 * "old SparkContext" role in the concurrent query execution test.
 *
 * Child threads get `new Properties(parent)`: a fresh `Properties` whose defaults are the
 * parent's own object, not a copy of it. Changes the parent thread makes later therefore stay
 * visible to the child through those defaults.
 */
private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) {
  protected[spark] override val localProperties = new InheritableThreadLocal[Properties] {
    // Child value shares the parent's Properties as its defaults (a live view, not a copy).
    override protected def childValue(parent: Properties): Properties = new Properties(parent)
    override protected def initialValue(): Properties = new Properties()
  }
}
Example 178
Source File: NullableColumnBuilderSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.columnar

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.types._
// Nullable column builder used by the tests below.
/** A [[BasicColumnBuilder]] with no-op column statistics and [[NullableColumnBuilder]] mixed in. */
class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType])
  extends BasicColumnBuilder[JvmType](new NoopColumnStats, columnType) with NullableColumnBuilder

object TestNullableColumnBuilder {
  /** Creates a builder for `columnType` and initializes it with `initialSize`. */
  def apply[JvmType](
      columnType: ColumnType[JvmType],
      initialSize: Int = 0): TestNullableColumnBuilder[JvmType] = {
    val columnBuilder = new TestNullableColumnBuilder(columnType)
    columnBuilder.initialize(initialSize)
    columnBuilder
  }
}
// Nullable column builder suite.
/** Checks the buffer layout produced by a nullable column builder for each column type. */
class NullableColumnBuilderSuite extends SparkFunSuite {
  import ColumnarTestUtils._

  Seq(
    BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE,
    STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType))
  ).foreach(columnType => testNullableColumnBuilder(columnType))

  /**
   * Registers three tests for `columnType`: an empty column, a column that grows past its
   * default initial size, and a column that alternates non-null and null values.
   */
  def testNullableColumnBuilder[JvmType](
      columnType: ColumnType[JvmType]): Unit = {
    // Classes of Scala objects end in `$`; strip it to get a readable test name.
    val typeName = columnType.getClass.getSimpleName.stripSuffix("$")

    // Empty column
    test(s"$typeName column builder: empty column") {
      val buffer = TestNullableColumnBuilder(columnType).build()

      assertResult(columnType.typeId, "Wrong column type ID")(buffer.getInt())
      assertResult(0, "Wrong null count")(buffer.getInt())
      assert(!buffer.hasRemaining)
    }

    // Buffer size grows automatically
    test(s"$typeName column builder: buffer size auto growth") {
      val builder = TestNullableColumnBuilder(columnType)
      val nonNullRow = makeRandomRow(columnType)

      for (_ <- 0 until 4) {
        builder.appendFrom(nonNullRow, 0)
      }
      val buffer = builder.build()

      assertResult(columnType.typeId, "Wrong column type ID")(buffer.getInt())
      assertResult(0, "Wrong null count")(buffer.getInt())
    }

    // Null values
    test(s"$typeName column builder: null values") {
      val builder = TestNullableColumnBuilder(columnType)
      val nonNullRow = makeRandomRow(columnType)
      val nullRow = makeNullRow(1)

      // Alternate non-null and null values, so the nulls land at positions 1, 3, 5 and 7.
      for (_ <- 0 until 4) {
        builder.appendFrom(nonNullRow, 0)
        builder.appendFrom(nullRow, 0)
      }
      val buffer = builder.build()

      // Layout read back below: type ID, null count, null positions, then non-null values.
      assertResult(columnType.typeId, "Wrong column type ID")(buffer.getInt())
      assertResult(4, "Wrong null count")(buffer.getInt())

      // Null positions
      for (position <- 1 to 7 by 2) {
        assertResult(position, "Wrong null position")(buffer.getInt())
      }

      // Non-null values
      for (_ <- 0 until 4) {
        val extracted = columnType.extract(buffer)
        // GENERIC columns hold serialized bytes; deserialize them before comparing.
        val actual = columnType match {
          case _: GENERIC =>
            SparkSqlSerializer.deserialize[Any](extracted.asInstanceOf[Array[Byte]])
          case _ =>
            extracted
        }

        assert(actual === nonNullRow.get(0, columnType.dataType),
          "Extracted value didn't equal to the original one")
      }

      assert(!buffer.hasRemaining)
    }
  }
}
Example 179
Source File: ColumnStatsSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.columnar

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types._
// Column statistics test suite.
/**
 * Tests for the per-type [[ColumnStats]] collectors. Each collector is checked twice: its
 * statistics before any row is gathered, and its lower/upper bound, null count, row count and
 * size in bytes after gathering 10 random rows plus 10 null rows.
 */
class ColumnStatsSuite extends SparkFunSuite {
  testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0))
  testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0))
  testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0))
  testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0))
  testColumnStats(classOf[DateColumnStats], DATE, createRow(Int.MaxValue, Int.MinValue, 0))
  testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0))
  testColumnStats(classOf[TimestampColumnStats], TIMESTAMP,
    createRow(Long.MaxValue, Long.MinValue, 0))
  testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0))
  testColumnStats(classOf[DoubleColumnStats], DOUBLE,
    createRow(Double.MaxValue, Double.MinValue, 0))
  testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0))
  testDecimalColumnStats(createRow(null, null, 0))

  def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray)

  /**
   * Registers the "empty" and "non-empty" tests for a [[ColumnStats]] class that has a
   * no-argument constructor.
   *
   * @param columnStatsClass collector under test; a new instance is created reflectively per test
   * @param columnType column type whose random values are fed to the collector
   * @param initialStatistics expected statistics of a collector that has gathered no rows
   */
  def testColumnStats[T <: AtomicType, U <: ColumnStats](
      columnStatsClass: Class[U],
      columnType: NativeColumnType[T],
      initialStatistics: GenericInternalRow): Unit = {
    registerColumnStatsTests(
      columnStatsClass.getSimpleName,
      () => columnStatsClass.newInstance(),
      columnType,
      initialStatistics)
  }

  /**
   * Same as [[testColumnStats]] for [[FixedDecimalColumnStats]], whose constructor takes a
   * precision and scale (here 15 and 10). The type parameters are unused; they are kept only
   * so existing call sites stay source-compatible.
   */
  def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](
      initialStatistics: GenericInternalRow): Unit = {
    registerColumnStatsTests(
      classOf[FixedDecimalColumnStats].getSimpleName,
      () => new FixedDecimalColumnStats(15, 10),
      FIXED_DECIMAL(15, 10),
      initialStatistics)
  }

  /** Shared implementation of [[testColumnStats]] and [[testDecimalColumnStats]]. */
  private def registerColumnStatsTests[T <: AtomicType](
      columnStatsName: String,
      newColumnStats: () => ColumnStats,
      columnType: NativeColumnType[T],
      initialStatistics: GenericInternalRow): Unit = {

    test(s"$columnStatsName: empty") {
      val columnStats = newColumnStats()
      columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach {
        case (actual, expected) => assert(actual === expected)
      }
    }

    test(s"$columnStatsName: non-empty") {
      import org.apache.spark.sql.columnar.ColumnarTestUtils._

      val columnStats = newColumnStats()
      val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
      rows.foreach(columnStats.gatherStats(_, 0))

      val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType])
      val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
      val stats = columnStats.collectedStatistics
      // A null cell is counted as 4 bytes (presumably its Int null-position entry).
      val expectedSizeInBytes = rows.map { row =>
        if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
      }.sum

      assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0))
      assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1))
      assertResult(10, "Wrong null count")(stats.values(2))
      assertResult(20, "Wrong row count")(stats.values(3))
      // Expected first, actual second, matching the other assertResult calls.
      assertResult(expectedSizeInBytes, "Wrong size in bytes")(stats.values(4))
    }
  }
}
Example 180
Source File: NullableColumnAccessorSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.columnar

import java.nio.ByteBuffer

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.types.{StringType, ArrayType, DataType}
// Nullable column accessor used by the tests below.
/** A [[BasicColumnAccessor]] with [[NullableColumnAccessor]] mixed in. */
class TestNullableColumnAccessor[JvmType](buffer: ByteBuffer, columnType: ColumnType[JvmType])
  extends BasicColumnAccessor(buffer, columnType) with NullableColumnAccessor

// Factory for nullable column accessors.
object TestNullableColumnAccessor {
  /**
   * Wraps `buffer` in an accessor. The leading Int column type ID is consumed first, which
   * advances `buffer`'s position past it.
   */
  def apply[JvmType](
      buffer: ByteBuffer,
      columnType: ColumnType[JvmType]): TestNullableColumnAccessor[JvmType] = {
    buffer.getInt() // discard the column type ID
    new TestNullableColumnAccessor(buffer, columnType)
  }
}
// Nullable column accessor suite.
/** Reads columns produced by a nullable builder back through a nullable accessor, per type. */
class NullableColumnAccessorSuite extends SparkFunSuite {
  import ColumnarTestUtils._

  Seq(
    BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE,
    STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType))
  ).foreach(columnType => testNullableColumnAccessor(columnType))

  /** Registers the "empty column" and "access null values" tests for `columnType`. */
  def testNullableColumnAccessor[JvmType](
      columnType: ColumnType[JvmType]): Unit = {
    // Classes of Scala objects end in `$`; strip it to get a readable test name.
    val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
    val nullRow = makeNullRow(1)

    // Empty column
    test(s"Nullable $typeName column accessor: empty column") {
      val emptyColumn = TestNullableColumnBuilder(columnType).build()
      val accessor = TestNullableColumnAccessor(emptyColumn, columnType)
      assert(!accessor.hasNext)
    }

    // Accessing null values
    test(s"Nullable $typeName column accessor: access null values") {
      val builder = TestNullableColumnBuilder(columnType)
      val nonNullRow = makeRandomRow(columnType)

      for (_ <- 0 until 4) {
        builder.appendFrom(nonNullRow, 0)
        builder.appendFrom(nullRow, 0)
      }

      val accessor = TestNullableColumnAccessor(builder.build(), columnType)
      val target = new GenericMutableRow(1)
      val expected = nonNullRow.get(0, columnType.dataType)

      // Values must come back in insertion order: non-null, null, non-null, null, ...
      for (_ <- 0 until 4) {
        assert(accessor.hasNext)
        accessor.extractTo(target, 0)
        assert(target.get(0, columnType.dataType) === expected)

        assert(accessor.hasNext)
        accessor.extractTo(target, 0)
        assert(target.isNullAt(0))
      }

      assert(!accessor.hasNext)
    }
  }
}
Example 181
Source File: ResolvedDataSourceSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.datasources.ResolvedDataSource
  //解析数据源测试套件
class ResolvedDataSourceSuite extends SparkFunSuite {

  test("jdbc") {
    assert(
        //解析JDBC数据源
      ResolvedDataSource.lookupDataSource("jdbc") ===
        //默认使用jdbc.DefaultSource类型
      classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource])
    assert(
      ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.jdbc") ===
      classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource])
    assert(
      ResolvedDataSource.lookupDataSource("org.apache.spark.sql.jdbc") ===
        classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource])
  }

  test("json") {
    assert(
        //解析json数据源
      ResolvedDataSource.lookupDataSource("json") ===
         //默认使用json.DefaultSource类型
      classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource])
    assert(
      ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.json") ===
        classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource])
    assert(
      ResolvedDataSource.lookupDataSource("org.apache.spark.sql.json") ===
        classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource])
  }

  test("parquet") {
    assert(
        //解析parquet数据源
      ResolvedDataSource.lookupDataSource("parquet") ===
        //默认使用parquet.DefaultSource类型
      classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource])
    assert(
      ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.parquet") ===
        classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource])
    assert(
      ResolvedDataSource.lookupDataSource("org.apache.spark.sql.parquet") ===
        classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource])
  }
} 
Example 182
Source File: FailureSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming

import java.io.File

import org.scalatest.BeforeAndAfter

import org.apache.spark.{SparkFunSuite, Logging}
import org.apache.spark.util.Utils


/** Runs MasterFailureTest scenarios (map and updateStateByKey) in a fresh temp directory. */
class FailureSuite extends SparkFunSuite with BeforeAndAfter with Logging {

  private val batchDuration: Duration = Milliseconds(1000)
  private val numBatches = 30
  // Temporary directory handed to MasterFailureTest; recreated before every test.
  private var directory: File = null

  before {
    directory = Utils.createTempDir()
  }

  after {
    // Delete the temp directory if one was created, then stop any still-active StreamingContext.
    Option(directory).foreach(dir => Utils.deleteRecursively(dir))
    StreamingContext.getActive().foreach(context => context.stop())
  }

  test("multiple failures with map") {
    val dirPath = directory.getAbsolutePath
    MasterFailureTest.testMap(dirPath, numBatches, batchDuration)
  }

  test("multiple failures with updateStateByKey") {
    val dirPath = directory.getAbsolutePath
    MasterFailureTest.testUpdateStateByKey(dirPath, numBatches, batchDuration)
  }
}
Example 183
Source File: RateLimitedOutputStreamSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.util

import java.io.ByteArrayOutputStream
import java.util.concurrent.TimeUnit._

import org.apache.spark.SparkFunSuite

/** Verifies that RateLimitedOutputStream throttles writes while passing data through intact. */
class RateLimitedOutputStreamSuite extends SparkFunSuite {

  /** Evaluates `body` once and returns the elapsed wall-clock time in nanoseconds. */
  private def benchmark[U](body: => U): Long = {
    val startedAt = System.nanoTime
    body
    System.nanoTime - startedAt
  }

  test("write") {
    val sink = new ByteArrayOutputStream
    val payload = "X" * 41000
    // At 10,000 bytes/s, writing 41,000 bytes should take at least 4 seconds.
    val limited = new RateLimitedOutputStream(sink, desiredBytesPerSec = 10000)
    val elapsedNs = benchmark { limited.write(payload.getBytes("UTF-8")) }

    val seconds = SECONDS.convert(elapsedNs, NANOSECONDS)
    assert(seconds >= 4, s"Seconds value ($seconds) is less than 4.")
    assert(seconds <= 30, s"Took more than 30 seconds ($seconds) to write data.")
    assert(sink.toString("UTF-8") === payload)
  }
}
Example 184
Source File: RpcAddressSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rpc

import org.apache.spark.{SparkException, SparkFunSuite}

/**
 * Tests RpcAddress construction, parsing of spark:// URLs (valid, malformed, and wrong scheme),
 * and formatting back to a Spark URL.
 *
 * All comparisons use ScalaTest's `===` so that failures report both sides.
 */
class RpcAddressSuite extends SparkFunSuite {

  test("hostPort") {
    val address = RpcAddress("1.2.3.4", 1234)
    assert(address.host === "1.2.3.4")
    assert(address.port === 1234)
    assert(address.hostPort === "1.2.3.4:1234")
  }

  test("fromSparkURL") {
    val address = RpcAddress.fromSparkURL("spark://1.2.3.4:1234")
    assert(address.host === "1.2.3.4")
    assert(address.port === 1234)
  }

  test("fromSparkURL: a typo url") {
    val e = intercept[SparkException] {
      RpcAddress.fromSparkURL("spark://1.2. 3.4:1234") // note the space inside the host
    }
    assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage)
  }

  test("fromSparkURL: invalid scheme") {
    val e = intercept[SparkException] {
      RpcAddress.fromSparkURL("invalid://1.2.3.4:1234")
    }
    assert("Invalid master URL: invalid://1.2.3.4:1234" === e.getMessage)
  }

  test("toSparkURL") {
    val address = RpcAddress("1.2.3.4", 1234)
    assert(address.toSparkURL === "spark://1.2.3.4:1234")
  }
}
Example 185
Source File: NettyBlockTransferServiceSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.network.netty

import org.apache.spark.network.BlockDataManager
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.mockito.Mockito.mock
import org.scalatest._

/** Checks how NettyBlockTransferService binds to the port requested via spark.blockManager.port. */
class NettyBlockTransferServiceSuite
  extends SparkFunSuite
  with BeforeAndAfterEach
  with ShouldMatchers {

  private var service0: NettyBlockTransferService = _
  private var service1: NettyBlockTransferService = _

  /**
   * Closes any service a test created. Always delegates to `super.afterEach()` in a `finally`,
   * so per-test cleanup from other stacked BeforeAndAfterEach traits still runs.
   */
  override def afterEach(): Unit = {
    try {
      if (service0 != null) {
        service0.close()
        service0 = null
      }
      if (service1 != null) {
        service1.close()
        service1 = null
      }
    } finally {
      super.afterEach()
    }
  }

  test("can bind to a random port") {
    service0 = createService(port = 0)
    service0.port should not be 0
  }

  test("can bind to two random ports") {
    service0 = createService(port = 0)
    service1 = createService(port = 0)
    service0.port should not be service1.port
  }

  test("can bind to a specific port") {
    val port = 17634
    service0 = createService(port)
    service0.port should be >= port
    // Accept a small range rather than exact equality in case of simultaneous tests.
    service0.port should be <= (port + 10)
  }

  // Requesting the same port twice: the second service is expected to land on the next port.
  test("can bind to a specific port twice and the second increments") {
    val port = 17634
    service0 = createService(port)
    service1 = createService(port)
    service0.port should be >= port
    service0.port should be <= (port + 10)
    service1.port should be (service0.port + 1)
  }

  /** Creates and initializes a service requesting `port`, backed by a mocked BlockDataManager. */
  private def createService(port: Int): NettyBlockTransferService = {
    val conf = new SparkConf()
      .set("spark.app.id", s"test-${getClass.getName}")
      .set("spark.blockManager.port", port.toString)
    val securityManager = new SecurityManager(conf)
    val blockDataManager = mock(classOf[BlockDataManager])
    val service = new NettyBlockTransferService(conf, securityManager, numCores = 1)
    service.init(blockDataManager)
    service
  }
}
Example 186
Source File: SortShuffleWriterSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.shuffle.sort

import org.mockito.Mockito._

import org.apache.spark.{Aggregator, SparkConf, SparkFunSuite}

/** Tests SortShuffleWriter.shouldBypassMergeSort under the default configuration. */
class SortShuffleWriterSuite extends SparkFunSuite {

  import SortShuffleWriter._

  test("conditions for bypassing merge-sort") {
    val conf = new SparkConf(loadDefaults = false)
    val aggregator = mock(classOf[Aggregator[_, _, _]], RETURNS_SMART_NULLS)
    val keyOrdering = implicitly[Ordering[Int]]

    // Partition counts below and above the default bypassMergeThreshold, respectively.
    val fewPartitions = 50
    val manyPartitions = 10000

    // No ordering and no aggregator: bypass unless the number of partitions is high.
    assert(shouldBypassMergeSort(conf, fewPartitions, None, None))
    assert(!shouldBypassMergeSort(conf, manyPartitions, None, None))

    // An ordering or an aggregator prevents bypassing, even with few partitions.
    assert(!shouldBypassMergeSort(conf, fewPartitions, None, Some(keyOrdering)))
    assert(!shouldBypassMergeSort(conf, fewPartitions, Some(aggregator), None))
  }
}
Example 187
Source File: PythonBroadcastSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.api.python

import scala.io.Source

import java.io.{PrintWriter, File}

import org.scalatest.Matchers

import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite}
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.util.Utils

// This test suite uses SharedSparkContext because we need a SparkEnv in order to deserialize
// a PythonBroadcast:
class PythonBroadcastSuite extends SparkFunSuite with Matchers with SharedSparkContext {
  test("PythonBroadcast can be serialized with Kryo (SPARK-4882)") {
    val tempDir = Utils.createTempDir()
    val broadcastedString = "Hello, world!"

    // Reads the broadcast's backing file and checks it contains exactly `broadcastedString`.
    // The Source is closed even if reading fails.
    def assertBroadcastIsValid(broadcast: PythonBroadcast): Unit = {
      val source = Source.fromFile(broadcast.path)
      val contents = try source.mkString finally source.close()
      contents should be (broadcastedString)
    }

    try {
      val broadcastDataFile: File = {
        val file = new File(tempDir, "broadcastData")
        val printWriter = new PrintWriter(file)
        // Close the writer even if the write fails, so the file handle is not leaked.
        try printWriter.write(broadcastedString) finally printWriter.close()
        file
      }
      val broadcast = new PythonBroadcast(broadcastDataFile.getAbsolutePath)
      assertBroadcastIsValid(broadcast)
      // registrationRequired makes Kryo fail if PythonBroadcast is not registered.
      val conf = new SparkConf().set("spark.kryo.registrationRequired", "true")
      val deserializedBroadcast =
        Utils.clone[PythonBroadcast](broadcast, new KryoSerializer(conf).newInstance())
      assertBroadcastIsValid(deserializedBroadcast)
    } finally {
      Utils.deleteRecursively(tempDir)
    }
  }
}
Example 188
Source File: PythonRDDSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.api.python

import java.io.{ByteArrayOutputStream, DataOutputStream}

import org.apache.spark.SparkFunSuite

/** Smoke tests for PythonRDD.writeIteratorToStream with large strings and null elements. */
class PythonRDDSuite extends SparkFunSuite {

  test("Writing large strings to the worker") {
    val hugeString = "a" * 100000
    val out = new DataOutputStream(new ByteArrayOutputStream)
    PythonRDD.writeIteratorToStream(List(hugeString).iterator, out)
  }

  // Writing iterators that mix nulls with strings, byte arrays, and pairs must not throw an NPE.
  // Whether the written bytes are correct is checked on the Python side.
  test("Handle nulls gracefully") {
    val out = new DataOutputStream(new ByteArrayOutputStream)
    // Plain strings.
    PythonRDD.writeIteratorToStream(Iterator("a", null), out)
    PythonRDD.writeIteratorToStream(Iterator(null, "a"), out)
    // Byte arrays.
    PythonRDD.writeIteratorToStream(Iterator("a".getBytes, null), out)
    PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes), out)
    // Pairs with null keys and/or values.
    PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), out)
    PythonRDD.writeIteratorToStream(
      Iterator((null, null), ("a".getBytes, null), (null, "b".getBytes)), out)
  }
}
Example 189
Source File: SerDeUtilSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.api.python

import org.apache.spark.{SharedSparkContext, SparkFunSuite}

/** Regression tests for SPARK-5441: converting empty RDDs to and from Python must not throw. */
class SerDeUtilSuite extends SparkFunSuite with SharedSparkContext {

  test("Converting an empty pair RDD to python does not throw an exception (SPARK-5441)") {
    val emptyPairs = sc.makeRDD(Seq.empty[(Any, Any)])
    SerDeUtil.pairRDDToPython(emptyPairs, 10)
  }

  test("Converting an empty python RDD to pair RDD does not throw an exception (SPARK-5441)") {
    val emptyPairs = sc.makeRDD(Seq.empty[(Any, Any)])
    val pickled = SerDeUtil.javaToPython(emptyPairs.toJavaRDD())
    SerDeUtil.pythonToPairRDD(pickled, false)
  }
}
Example 190
Source File: PythonRunnerSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy

import org.apache.spark.SparkFunSuite
import org.apache.spark.util.Utils

/** Tests how PythonRunner normalizes user-supplied paths into PYTHONPATH entries. */
class PythonRunnerSuite extends SparkFunSuite {

  // Test formatting a single path to be added to the PYTHONPATH
  test("format path") {
    Seq(
      "spark.py" -> "spark.py",
      "file:/spark.py" -> "/spark.py",
      "file:///spark.py" -> "/spark.py",
      "local:/spark.py" -> "/spark.py",
      "local:///spark.py" -> "/spark.py"
    ).foreach { case (path, expected) =>
      assert(PythonRunner.formatPath(path) === expected)
    }
    if (Utils.isWindows) {
      Seq(
        "file:/C:/a/b/spark.py" -> "C:/a/b/spark.py",
        "C:\\a\\b\\spark.py" -> "C:/a/b/spark.py",
        "C:\\a b\\spark.py" -> "C:/a b/spark.py"
      ).foreach { case (path, expected) =>
        assert(PythonRunner.formatPath(path, testWindows = true) === expected)
      }
    }
    // Paths with other schemes, or several colons, are rejected.
    Seq("one:two", "hdfs:s3:xtremeFS", "hdfs:/path/to/some.py").foreach { path =>
      intercept[IllegalArgumentException] { PythonRunner.formatPath(path) }
    }
  }

  // Test formatting multiple comma-separated paths to be added to the PYTHONPATH
  test("format paths") {
    Seq(
      "spark.py" -> Array("spark.py"),
      "file:/spark.py" -> Array("/spark.py"),
      "file:/app.py,local:/spark.py" -> Array("/app.py", "/spark.py"),
      "me.py,file:/you.py,local:/we.py" -> Array("me.py", "/you.py", "/we.py")
    ).foreach { case (paths, expected) =>
      assert(PythonRunner.formatPaths(paths) === expected)
    }
    if (Utils.isWindows) {
      Seq(
        "C:\\a\\b\\spark.py" -> Array("C:/a/b/spark.py"),
        "C:\\free.py,pie.py" -> Array("C:/free.py", "pie.py"),
        "lovely.py,C:\\free.py,file:/d:/fry.py" -> Array("lovely.py", "C:/free.py", "d:/fry.py")
      ).foreach { case (paths, expected) =>
        assert(PythonRunner.formatPaths(paths, testWindows = true) === expected)
      }
    }
    // A single invalid entry anywhere in the list makes the whole call fail.
    Seq(
      "one:two,three",
      "two,three,four:five:six",
      "hdfs:/some.py,foo.py",
      "foo.py,hdfs:/some.py"
    ).foreach { paths =>
      intercept[IllegalArgumentException] { PythonRunner.formatPaths(paths) }
    }
  }
}
Example 191
Source File: WorkerWatcherSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.worker

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.SecurityManager
import org.apache.spark.rpc.{RpcAddress, RpcEnv}

/** Tests that WorkerWatcher shuts down only when its own worker disconnects. */
class WorkerWatcherSuite extends SparkFunSuite {

  /**
   * Creates an RpcEnv on localhost:12345. Registers a WorkerWatcher in testing mode as
   * "worker-watcher", watching the worker URI at 1.2.3.4:1234. Then runs `body` with that
   * watcher.
   *
   * The RpcEnv is shut down in a `finally`, so a failing assertion cannot leak it. Leaking it
   * could also leak the fixed port that the next test creates its RpcEnv on.
   */
  private def withWorkerWatcher(body: WorkerWatcher => Unit): Unit = {
    val conf = new SparkConf()
    val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf))
    try {
      val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker")
      val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl)
      workerWatcher.setTesting(testing = true)
      rpcEnv.setupEndpoint("worker-watcher", workerWatcher)
      body(workerWatcher)
    } finally {
      rpcEnv.shutdown()
    }
  }

  test("WorkerWatcher shuts down on valid disassociation") {
    withWorkerWatcher { workerWatcher =>
      workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234))
      assert(workerWatcher.isShutDown)
    }
  }

  test("WorkerWatcher stays alive on invalid disassociation") {
    withWorkerWatcher { workerWatcher =>
      val otherRpcAddress = RpcAddress("4.3.2.1", 1234)
      workerWatcher.onDisconnected(otherRpcAddress)
      assert(!workerWatcher.isShutDown)
    }
  }
}
Example 192
Source File: WorkerArgumentsTest.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.worker

import org.apache.spark.{SparkConf, SparkFunSuite}


/** Tests how WorkerArguments parses worker memory from the command line and the environment. */
class WorkerArgumentsTest extends SparkFunSuite {

  // Master URL argument used by every test (the trailing spaces are part of the literal).
  private val masterUrl = "spark://localhost:0000  "

  /**
   * SparkConf that skips loading defaults and reports `workerMemory` for the
   * SPARK_WORKER_MEMORY environment variable, deferring every other lookup to SparkConf.
   * `clone` returns another instance with the same override and the same settings.
   */
  private class WorkerMemoryEnvConf(workerMemory: String) extends SparkConf(false) {
    override def getenv(name: String): String =
      if (name == "SPARK_WORKER_MEMORY") workerMemory else super.getenv(name)

    override def clone: SparkConf = new WorkerMemoryEnvConf(workerMemory).setAll(getAll)
  }

  // A bare number (no M or G suffix) passed with -m must be rejected.
  test("Memory can't be set to 0 when cmd line args leave off M or G") {
    val conf = new SparkConf
    intercept[IllegalStateException] {
      new WorkerArguments(Array("-m", "10000", masterUrl), conf)
    }
  }

  // A bare number (no M or G suffix) in SPARK_WORKER_MEMORY must be rejected.
  test("Memory can't be set to 0 when SPARK_WORKER_MEMORY env property leaves off M or G") {
    val conf = new WorkerMemoryEnvConf("50000")
    intercept[IllegalStateException] {
      new WorkerArguments(Array(masterUrl), conf)
    }
  }

  // "5G" from the environment is parsed as 5120 (megabytes).
  test("Memory correctly set when SPARK_WORKER_MEMORY env property appends G") {
    val workerArgs = new WorkerArguments(Array(masterUrl), new WorkerMemoryEnvConf("5G"))
    assert(workerArgs.memory === 5120)
  }

  // "10000M" on the command line is parsed as 10000 (megabytes).
  test("Memory correctly set from args with M appended to memory value") {
    val workerArgs = new WorkerArguments(Array("-m", "10000M", masterUrl), new SparkConf)
    assert(workerArgs.memory === 10000)
  }
}
Example 193
Source File: ExecutorRunnerTest.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.worker

import java.io.File

import scala.collection.JavaConversions._

import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState}
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
// A Worker starts and stops CoarseGrainedExecutorBackend through the ExecutorRunner it holds;
// this suite checks the command line an ExecutorRunner builds.
class ExecutorRunnerTest extends SparkFunSuite {
  test("command includes appId") {
    val appId = "12345-worker321-9876"
    val conf = new SparkConf
    // spark.test.home is read from JVM system properties (e.g. set with -D), not from the
    // process environment.
    val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
    val command = Command("foo", Seq(appId), Map(), Seq(), Seq(), Seq())
    val appDesc = new ApplicationDescription("app name", Some(8), 500, command, "appUiUrl")

    val runner = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", 123,
      "publicAddr", new File(sparkHome), new File("ooga"), "blah", conf, Seq("localDir"),
      ExecutorState.RUNNING)
    val builder = CommandUtils.buildProcessBuilder(
      appDesc.command, new SecurityManager(conf), 512, sparkHome, runner.substituteVariables)

    // The appId argument must appear as the final element of the built command.
    assert(builder.command().last === appId)
  }
}
Example 194
Source File: PagedTableSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ui

import scala.xml.Node

import org.apache.spark.SparkFunSuite

/** Tests page slicing and out-of-range page errors of PagedDataSource via SeqPagedDataSource. */
class PagedDataSourceSuite extends SparkFunSuite {

  test("basic") {
    // Five elements at two per page yield three pages: [1, 2], [3, 4], [5].
    // A fresh data source is built for every check.
    def newDataSource(): SeqPagedDataSource[Int] =
      new SeqPagedDataSource[Int](1 to 5, pageSize = 2)

    assert(newDataSource().pageData(1) === PageData(3, (1 to 2)))
    assert(newDataSource().pageData(2) === PageData(3, (3 to 4)))
    assert(newDataSource().pageData(3) === PageData(3, Seq(5)))

    val pastLastPage = intercept[IndexOutOfBoundsException] {
      newDataSource().pageData(4)
    }
    assert(pastLastPage.getMessage ===
      "Page 4 is out of range. Please select a page number between 1 and 3.")

    val beforeFirstPage = intercept[IndexOutOfBoundsException] {
      newDataSource().pageData(0)
    }
    assert(beforeFirstPage.getMessage ===
      "Page 0 is out of range. Please select a page number between 1 and 3.")
  }
}

/** Tests the page-navigation links rendered by PagedTable.pageNavigation. */
class PagedTableSuite extends SparkFunSuite {
  test("pageNavigation") {
    // Minimal PagedTable stub: only pageNavigation is exercised, so every other member
    // returns an empty or null placeholder.
    val pagedTable = new PagedTable[Int] {
      override def tableId: String = ""
      override def tableCssClass: String = ""
      override def dataSource: PagedDataSource[Int] = null
      override def pageLink(page: Int): String = page.toString
      override def headers: Seq[Node] = Nil
      override def row(t: Int): Seq[Node] = Nil
      override def goButtonJavascriptFunction: (String, String) = ("", "")
    }

    // Trimmed text of each <li> in the navigation for `page` (out of `totalPages`), passing 10
    // as the middle argument exactly as every call in this test does.
    def labels(page: Int, totalPages: Int): Seq[String] =
      (pagedTable.pageNavigation(page, 10, totalPages).head \\ "li").map(_.text.trim)
    def numbers(range: Range): Seq[String] = range.map(_.toString)

    assert(pagedTable.pageNavigation(1, 10, 1) === Nil)
    assert(labels(1, 2) === Seq("1", "2", ">"))
    assert(labels(2, 2) === Seq("<", "1", "2"))

    assert(labels(1, 100) === numbers(1 to 10) ++ Seq(">", ">>"))
    assert(labels(2, 100) === Seq("<") ++ numbers(1 to 10) ++ Seq(">", ">>"))

    assert(labels(100, 100) === Seq("<<", "<") ++ numbers(91 to 100))
    assert(labels(99, 100) === Seq("<<", "<") ++ numbers(91 to 100) ++ Seq(">"))

    assert(labels(11, 100) === Seq("<<", "<") ++ numbers(11 to 20) ++ Seq(">", ">>"))
    assert(labels(93, 97) === Seq("<<", "<") ++ numbers(91 to 97) ++ Seq(">"))
  }
}

/**
 * In-memory [[PagedDataSource]] backed by a fixed sequence, used by the paging tests.
 *
 * @param seq      the elements to page through
 * @param pageSize number of elements per page, forwarded to PagedDataSource
 */
private[spark] class SeqPagedDataSource[T](seq: Seq[T], pageSize: Int)
  extends PagedDataSource[T](pageSize) {

  /** Number of elements available for paging. */
  override protected def dataSize: Int = seq.length

  /** Elements at indices `from` (inclusive) until `to` (exclusive), as `Seq.slice` defines. */
  override protected def sliceData(from: Int, to: Int): Seq[T] = {
    val page = seq.slice(from, to)
    page
  }
}
Example 195
Source File: UIUtilsSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ui

import scala.xml.Elem

import org.apache.spark.SparkFunSuite

/** Tests UIUtils.makeDescription's HTML-vs-plain-text handling of job/stage descriptions. */
class UIUtilsSuite extends SparkFunSuite {
  import UIUtils._

  test("makeDescription") {
    verify(
      """test <a href="/link"> text </a>""",
      <span class="description-input">test <a href="/link"> text </a></span>,
      "Correctly formatted text with only anchors and relative links should generate HTML"
    )

    verify(
      """test <a href="/link" text </a>""",
      <span class="description-input">{"""test <a href="/link" text </a>"""}</span>,
      "Badly formatted text should make the description be treated as a string instead of HTML"
    )

    verify(
      """test <a href="link"> text </a>""",
      <span class="description-input">{"""test <a href="link"> text </a>"""}</span>,
      "Non-relative links should make the description be treated as a string instead of HTML"
    )

    verify(
      """test<a><img></img></a>""",
      <span class="description-input">{"""test<a><img></img></a>"""}</span>,
      "Non-anchor elements should make the description be treated as a string instead of HTML"
    )

    verify(
      """test <a href="/link"> text </a>""",
      <span class="description-input">test <a href="base/link"> text </a></span>,
      baseUrl = "base",
      errorMsg = "Base URL should be prepended to html links"
    )
  }

  /**
   * Asserts that `makeDescription(desc, baseUrl)` yields exactly the nodes in `expected`.
   * On failure, reports `errorMsg` together with the expected and generated markup.
   */
  private def verify(
      desc: String, expected: Elem, errorMsg: String = "", baseUrl: String = ""): Unit = {
    val generated = makeDescription(desc, baseUrl)
    assert(generated.sameElements(expected),
      s"\n$errorMsg\n\nExpected:\n$expected\nGenerated:\n$generated")
  }
}
Example 196
Source File: GenericAvroSerializerSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.serializer

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.ByteBuffer

import com.esotericsoftware.kryo.io.{Output, Input}
import org.apache.avro.{SchemaBuilder, Schema}
import org.apache.avro.generic.GenericData.Record

import org.apache.spark.{SparkFunSuite, SharedSparkContext}

/**
 * Tests for GenericAvroSerializer:
 *  - schema compression round-trips
 *  - datum round-trips
 *  - smaller output once the schema is registered
 *  - reuse of cached schema instances
 */
class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
  conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

  // One-field record schema and a matching record shared by every test.
  val schema : Schema = SchemaBuilder
    .record("testRecord").fields()
    .requiredString("data")
    .endRecord()
  val record = new Record(schema)
  record.put("data", "test data")

  test("schema compression and decompression") {
    val serializer = new GenericAvroSerializer(conf.getAvroSchema)
    val compressed = serializer.compress(schema)
    assert(schema === serializer.decompress(ByteBuffer.wrap(compressed)))
  }

  test("record serialization and deserialization") {
    val serializer = new GenericAvroSerializer(conf.getAvroSchema)

    val bytes = new ByteArrayOutputStream()
    val out = new Output(bytes)
    serializer.serializeDatum(record, out)
    out.flush()
    out.close()

    val in = new Input(new ByteArrayInputStream(bytes.toByteArray))
    assert(serializer.deserializeDatum(in) === record)
  }

  // Once the schema is registered on the conf, serializing the same record must use fewer bytes.
  test("uses schema fingerprint to decrease message size") {
    val fullSchemaSerializer = new GenericAvroSerializer(conf.getAvroSchema)

    val out = new Output(new ByteArrayOutputStream())

    val normalStart = out.total()
    fullSchemaSerializer.serializeDatum(record, out)
    out.flush()
    val normalLength = out.total() - normalStart

    conf.registerAvroSchemas(schema)
    val fingerprintSerializer = new GenericAvroSerializer(conf.getAvroSchema)
    val fingerprintStart = out.total()
    fingerprintSerializer.serializeDatum(record, out)
    val fingerprintLength = out.total() - fingerprintStart

    assert(fingerprintLength < normalLength)
  }

  test("caches previously seen schemas") {
    val serializer = new GenericAvroSerializer(conf.getAvroSchema)
    val compressedSchema = serializer.compress(schema)
    val decompressedSchema = serializer.decompress(ByteBuffer.wrap(compressedSchema))

    // Reference equality: repeated calls must hand back the very same cached objects.
    assert(compressedSchema.eq(serializer.compress(schema)))
    assert(decompressedSchema.eq(serializer.decompress(ByteBuffer.wrap(compressedSchema))))
  }
}
Example 197
Source File: KryoSerializerResizableOutputSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.serializer

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.SparkContext
import org.apache.spark.LocalSparkContext
import org.apache.spark.SparkException


/**
 * Checks that collecting a large array with Kryo fails when the output buffer cannot grow
 * beyond 1m, and succeeds when spark.kryoserializer.buffer.max allows it to grow to 2m.
 */
class KryoSerializerResizableOutputSuite extends SparkFunSuite {

  // trial and error showed this will not serialize with 1mb buffer
  val x = (1 to 400000).toArray

  /**
   * Runs `body` against a local SparkContext that uses Kryo with a 1m initial buffer and
   * `maxBufferSize` as spark.kryoserializer.buffer.max.
   *
   * The context is always stopped in a `finally`. Otherwise a failing assertion would leave it
   * running, where it could interfere with SparkContexts created later in the same JVM.
   */
  private def withKryoContext(maxBufferSize: String)(body: SparkContext => Unit): Unit = {
    val conf = new SparkConf(false)
    conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    conf.set("spark.kryoserializer.buffer", "1m")
    conf.set("spark.kryoserializer.buffer.max", maxBufferSize)
    val sc = new SparkContext("local", "test", conf)
    try {
      body(sc)
    } finally {
      LocalSparkContext.stop(sc)
    }
  }

  // Max buffer equals the initial 1m buffer, so the large array cannot be serialized.
  test("kryo without resizable output buffer should fail on large array") {
    withKryoContext(maxBufferSize = "1m") { sc =>
      intercept[SparkException](sc.parallelize(x).collect())
    }
  }

  // Max buffer of 2m lets the buffer grow enough for the large array.
  test("kryo with resizable output buffer should succeed on large array") {
    withKryoContext(maxBufferSize = "2m") { sc =>
      assert(sc.parallelize(x).collect() === x)
    }
  }
}
Example 198
Source File: ProactiveClosureSerializationSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.serializer

import org.apache.spark.{SharedSparkContext, SparkException, SparkFunSuite}
import org.apache.spark.rdd.RDD


/**
 * Helper that does not extend Serializable. Closures capturing an instance therefore cannot be
 * serialized, which the suite below relies on.
 */
class UnserializableClass {

  /** Returns `x.toString`. */
  def op[T](x: T): String = {
    val rendered = x.toString
    rendered
  }

  /** True when `x.toString` has an even number of characters. */
  def pred[T](x: T): Boolean = (x.toString.length & 1) == 0
}

/**
 * Checks that closures capturing an [[UnserializableClass]] fail with "Task not serializable".
 * This is checked for an action, and for each listed transformation when it is applied (no
 * action runs).
 */
class ProactiveClosureSerializationSuite extends SparkFunSuite with SharedSparkContext {

  /** An RDD of the strings "0" to "999" plus a fresh UnserializableClass to capture. */
  def fixture: (RDD[String], UnserializableClass) = {
    val strings = sc.parallelize(0 until 1000).map(_.toString)
    (strings, new UnserializableClass)
  }

  // Running an action on a closure that captures the unserializable object must fail.
  test("throws expected serialization exceptions on actions") {
    val (data, uc) = fixture
    val thrown = intercept[SparkException] {
      data.map(uc.op(_)).count()
    }
    assert(thrown.getMessage.contains("Task not serializable"))
  }

  // One test per transformation. Each entry maps a transformation name to a helper that applies
  // that transformation with a closure capturing the unserializable object.
  Map(
    "map" -> xmap _,
    "flatMap" -> xflatMap _,
    "filter" -> xfilter _,
    "mapPartitions" -> xmapPartitions _,
    "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _
  ).foreach { case (name, applyTransformation) =>
    test(s"$name transformations throw proactive serialization exceptions") {
      val (data, uc) = fixture
      val thrown = intercept[SparkException] {
        applyTransformation(data, uc)
      }
      assert(thrown.getMessage.contains("Task not serializable"),
        s"RDD.$name doesn't proactively throw NotSerializableException")
    }
  }

  private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] =
    x.map(s => uc.op(s))

  private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
    x.flatMap(s => Seq(uc.op(s)))

  private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] =
    x.filter(s => uc.pred(s))

  private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
    x.mapPartitions(iter => iter.map(s => uc.op(s)))

  private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
    x.mapPartitionsWithIndex((_, iter) => iter.map(s => uc.op(s)))

}
Example 199
Source File: MapStatusSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import org.apache.spark.storage.BlockManagerId

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer

import scala.util.Random

class MapStatusSuite extends SparkFunSuite {

  // compressSize encodes a block size into a single byte (the low 8 bits are checked below).
  test("compressSize") {
    assert(MapStatus.compressSize(0L) === 0)
    assert(MapStatus.compressSize(1L) === 1)
    assert(MapStatus.compressSize(2L) === 8)
    assert(MapStatus.compressSize(10L) === 25)
    assert((MapStatus.compressSize(1000000L) & 0xFF) === 145)
    assert((MapStatus.compressSize(1000000000L) & 0xFF) === 218)
    // This last size is bigger than we can encode in a byte, so check that we just return 255.
    assert((MapStatus.compressSize(1000000000000000000L) & 0xFF) === 255)
  }

  // Compression is lossy: a round trip must land within [0.99 * size, 1.11 * size].
  test("decompressSize") {
    assert(MapStatus.decompressSize(0) === 0)
    for (size <- Seq(2L, 10L, 100L, 50000L, 1000000L, 1000000000L)) {
      val size2 = MapStatus.decompressSize(MapStatus.compressSize(size))
      assert(size2 >= 0.99 * size && size2 <= 1.11 * size,
        "size " + size + " decompressed to " + size2 + ", which is out of range")
    }
  }

  // A non-empty block must never be reported as size 0. This must hold both for the status
  // built in memory and for one that went through a serialization round trip.
  test("MapStatus should never report non-empty blocks' sizes as 0") {
    for {
      numSizes <- Seq(1, 10, 100, 1000, 10000)
      mean <- Seq(0L, 100L, 10000L, Int.MaxValue.toLong)
      stddev <- Seq(0.0, 0.01, 0.5, 1.0)
    } {
      // Sizes are mean plus the absolute value of rounded Gaussian noise scaled by stddev.
      val sizes = Array.fill[Long](numSizes) {
        math.abs(math.round(Random.nextGaussian() * stddev)) + mean
      }
      val status = MapStatus(BlockManagerId("a", "b", 10), sizes)
      val status1 = compressAndDecompressMapStatus(status)
      for (i <- 0 until numSizes if sizes(i) != 0) {
        val failureMessage = s"Failed with $numSizes sizes with mean=$mean, stddev=$stddev"
        assert(status.getSizeForBlock(i) !== 0, failureMessage)
        assert(status1.getSizeForBlock(i) !== 0, failureMessage)
      }
    }
  }

  // Tasks with many output blocks (2001 here) should use the highly compressed
  // representation. 2001 is presumably just above MapStatus's switch-over threshold;
  // confirm against MapStatus.apply.
  test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) {
    val sizes = Array.fill[Long](2001)(150L)
    val status = MapStatus(null, sizes)
    assert(status.isInstanceOf[HighlyCompressedMapStatus])
    assert(status.getSizeForBlock(10) === 150L)
    assert(status.getSizeForBlock(50) === 150L)
    assert(status.getSizeForBlock(99) === 150L)
    assert(status.getSizeForBlock(2000) === 150L)
  }

  // HighlyCompressedMapStatus: every non-empty block's estimated size should equal the
  // average size over the non-empty blocks.
  test("HighlyCompressedMapStatus: estimated size should be the average non-empty block size") {
    val sizes = Array.tabulate[Long](3000) { i => i.toLong }
    // Zero-sized blocks add nothing to the sum and are excluded from the divisor.
    val avg = sizes.sum / sizes.count(_ != 0)
    val loc = BlockManagerId("a", "b", 10)
    val status = MapStatus(loc, sizes)
    val status1 = compressAndDecompressMapStatus(status)
    assert(status1.isInstanceOf[HighlyCompressedMapStatus])
    assert(status1.location === loc)
    for (i <- 0 until 3000 if sizes(i) > 0) {
      assert(status1.getSizeForBlock(i) === avg)
    }
  }

  /**
   * Serializes `status` with a JavaSerializer and deserializes it back.
   *
   * Used to check that a MapStatus survives the round trip.
   */
  def compressAndDecompressMapStatus(status: MapStatus): MapStatus = {
    val ser = new JavaSerializer(new SparkConf)
    val buf = ser.newInstance().serialize(status)
    ser.newInstance().deserialize[MapStatus](buf)
  }
}
Example 200
Source File: CoarseGrainedSchedulerBackendSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite}
import org.apache.spark.util.{SerializableBuffer, AkkaUtils}

class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext {

  // A task whose serialized form exceeds the Akka frame size should be rejected with a message
  // suggesting broadcast variables. The context should remain usable afterwards.
  // NOTE(review): this test is ignored. The commented-out master below used local-cluster,
  // and it was switched to local[*], which presumably does not schedule through
  // CoarseGrainedSchedulerBackend. The frame-size check may therefore never fire; confirm
  // before re-enabling.
  ignore("serialized task larger than akka frame size") {
    val conf = new SparkConf
    // Maximum driver <-> executor message size, in MB.
    conf.set("spark.akka.frameSize", "1")
    // Default parallelism of 1, presumably so the RDD below becomes a single task.
    conf.set("spark.default.parallelism", "1")
    // sc = new SparkContext("local-cluster[2, 1, 1024]", "test", conf)
    sc = new SparkContext("local[*]", "test", conf)
    // Frame size in bytes derived from the setting above (presumably 1 MiB = 1048576 bytes).
    val frameSize = AkkaUtils.maxFrameSizeBytes(sc.conf)
    // Wrap a buffer twice the frame size, so the serialized task cannot fit in one frame.
    val buffer = new SerializableBuffer(java.nio.ByteBuffer.allocate(2 * frameSize))

    val larger = sc.parallelize(Seq(buffer))
    val thrown = intercept[SparkException] {
      larger.collect()
    }
    // The error should point the user at broadcast variables for large values.
    assert(thrown.getMessage.contains("using broadcast variables for large values"))
    // The context should still run ordinary jobs after the failure.
    val smaller = sc.parallelize(1 to 4).collect()
    assert(smaller.size === 4)
  }

}