org.apache.spark.SparkFunSuite Scala Examples
The following examples show how to use org.apache.spark.SparkFunSuite. All of the examples are drawn from the BigDatalog project and are distributed under the Apache License 2.0.
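Before the individual examples, here is a minimal sketch of the pattern most of them share: a test class extends SparkFunSuite (Spark's ScalaTest base class; it is package-private, which is why every suite here lives under an org.apache.spark package) and, when it needs a SparkContext, mixes in a helper trait such as MLlibTestSparkContext that exposes a local context as sc. The package name, suite name, and test body below are hypothetical placeholders, not code taken from any of the examples.

package org.apache.spark.example

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext

// Minimal sketch, assuming the MLlib test helpers used throughout the examples below.
// MLlibTestSparkContext starts a local SparkContext before the suite runs and stops it
// afterwards, which is why the tests can refer to `sc` directly.
class ExampleSuite extends SparkFunSuite with MLlibTestSparkContext {
  test("parallelize and sum") {
    val rdd = sc.parallelize(1 to 10, 2)
    assert(rdd.sum() === 55.0)
  }
}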
Example 1
Source File: MulticlassMetricsSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.mllib.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Matrices import org.apache.spark.mllib.util.MLlibTestSparkContext class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Multiclass evaluation metrics") { val confusionMatrix = Matrices.dense(3, 3, Array(2, 1, 0, 1, 3, 0, 1, 0, 1)) val labels = Array(0.0, 1.0, 2.0) val predictionAndLabels = sc.parallelize( Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2) val metrics = new MulticlassMetrics(predictionAndLabels) val delta = 0.0000001 val fpRate0 = 1.0 / (9 - 4) val fpRate1 = 1.0 / (9 - 4) val fpRate2 = 1.0 / (9 - 1) val precision0 = 2.0 / (2 + 1) val precision1 = 3.0 / (3 + 1) val precision2 = 1.0 / (1 + 1) val recall0 = 2.0 / (2 + 2) val recall1 = 3.0 / (3 + 1) val recall2 = 1.0 / (1 + 0) val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0) val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1) val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2) val f2measure0 = (1 + 2 * 2) * precision0 * recall0 / (2 * 2 * precision0 + recall0) val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray)) assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta) assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta) assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta) assert(math.abs(metrics.precision(0.0) - precision0) < delta) assert(math.abs(metrics.precision(1.0) - precision1) < delta) assert(math.abs(metrics.precision(2.0) - precision2) < delta) assert(math.abs(metrics.recall(0.0) - recall0) < delta) assert(math.abs(metrics.recall(1.0) - recall1) < delta) assert(math.abs(metrics.recall(2.0) - recall2) < delta) assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta) assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta) assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta) assert(math.abs(metrics.fMeasure(0.0, 2.0) - f2measure0) < delta) assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta) assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta) assert(math.abs(metrics.recall - (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta) assert(math.abs(metrics.recall - metrics.precision) < delta) assert(math.abs(metrics.recall - metrics.fMeasure) < delta) assert(math.abs(metrics.recall - metrics.weightedRecall) < delta) assert(math.abs(metrics.weightedFalsePositiveRate - ((4.0 / 9) * fpRate0 + (4.0 / 9) * fpRate1 + (1.0 / 9) * fpRate2)) < delta) assert(math.abs(metrics.weightedPrecision - ((4.0 / 9) * precision0 + (4.0 / 9) * precision1 + (1.0 / 9) * precision2)) < delta) assert(math.abs(metrics.weightedRecall - ((4.0 / 9) * recall0 + (4.0 / 9) * recall1 + (1.0 / 9) * recall2)) < delta) assert(math.abs(metrics.weightedFMeasure - ((4.0 / 9) * f1measure0 + (4.0 / 9) * f1measure1 + (1.0 / 9) * f1measure2)) < delta) assert(math.abs(metrics.weightedFMeasure(2.0) - ((4.0 / 9) * f2measure0 + (4.0 / 9) * f2measure1 + (1.0 / 9) * f2measure2)) < delta) assert(metrics.labels.sameElements(labels)) } }
Example 2
Source File: PythonMLLibAPISuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.mllib.api.python import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.recommendation.Rating class PythonMLLibAPISuite extends SparkFunSuite { SerDe.initialize() test("pickle vector") { val vectors = Seq( Vectors.dense(Array.empty[Double]), Vectors.dense(0.0), Vectors.dense(0.0, -2.0), Vectors.sparse(0, Array.empty[Int], Array.empty[Double]), Vectors.sparse(1, Array.empty[Int], Array.empty[Double]), Vectors.sparse(2, Array(1), Array(-2.0))) vectors.foreach { v => val u = SerDe.loads(SerDe.dumps(v)) assert(u.getClass === v.getClass) assert(u === v) } } test("pickle labeled point") { val points = Seq( LabeledPoint(0.0, Vectors.dense(Array.empty[Double])), LabeledPoint(1.0, Vectors.dense(0.0)), LabeledPoint(-0.5, Vectors.dense(0.0, -2.0)), LabeledPoint(0.0, Vectors.sparse(0, Array.empty[Int], Array.empty[Double])), LabeledPoint(1.0, Vectors.sparse(1, Array.empty[Int], Array.empty[Double])), LabeledPoint(-0.5, Vectors.sparse(2, Array(1), Array(-2.0)))) points.foreach { p => val q = SerDe.loads(SerDe.dumps(p)).asInstanceOf[LabeledPoint] assert(q.label === p.label) assert(q.features.getClass === p.features.getClass) assert(q.features === p.features) } } test("pickle double") { for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue, Double.NaN)) { val deser = SerDe.loads(SerDe.dumps(x.asInstanceOf[AnyRef])).asInstanceOf[Double] // We use `equals` here for comparison because we cannot use `==` for NaN assert(x.equals(deser)) } } test("pickle matrix") { val values = Array[Double](0, 1.2, 3, 4.56, 7, 8) val matrix = Matrices.dense(2, 3, values) val nm = SerDe.loads(SerDe.dumps(matrix)).asInstanceOf[DenseMatrix] assert(matrix === nm) // Test conversion for empty matrix val empty = Array[Double]() val emptyMatrix = Matrices.dense(0, 0, empty) val ne = SerDe.loads(SerDe.dumps(emptyMatrix)).asInstanceOf[DenseMatrix] assert(emptyMatrix == ne) val sm = new SparseMatrix(3, 2, Array(0, 1, 3), Array(1, 0, 2), Array(0.9, 1.2, 3.4)) val nsm = SerDe.loads(SerDe.dumps(sm)).asInstanceOf[SparseMatrix] assert(sm.toArray === nsm.toArray) val smt = new SparseMatrix( 3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9), isTransposed = true) val nsmt = SerDe.loads(SerDe.dumps(smt)).asInstanceOf[SparseMatrix] assert(smt.toArray === nsmt.toArray) } test("pickle rating") { val rat = new Rating(1, 2, 3.0) val rat2 = SerDe.loads(SerDe.dumps(rat)).asInstanceOf[Rating] assert(rat == rat2) // Test name of class only occur once val rats = (1 to 10).map(x => new Rating(x, x + 1, x + 3.0)).toArray val bytes = SerDe.dumps(rats) assert(bytes.toString.split("Rating").length == 1) assert(bytes.length / 10 < 25) // 25 bytes per rating } }
Example 3
Source File: FPTreeSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.mllib.fpm import scala.language.existentials import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("add transaction") { val tree = new FPTree[String] .add(Seq("a", "b", "c")) .add(Seq("a", "b", "y")) .add(Seq("b")) assert(tree.root.children.size == 2) assert(tree.root.children.contains("a")) assert(tree.root.children("a").item.equals("a")) assert(tree.root.children("a").count == 2) assert(tree.root.children.contains("b")) assert(tree.root.children("b").item.equals("b")) assert(tree.root.children("b").count == 1) var child = tree.root.children("a") assert(child.children.size == 1) assert(child.children.contains("b")) assert(child.children("b").item.equals("b")) assert(child.children("b").count == 2) child = child.children("b") assert(child.children.size == 2) assert(child.children.contains("c")) assert(child.children.contains("y")) assert(child.children("c").item.equals("c")) assert(child.children("y").item.equals("y")) assert(child.children("c").count == 1) assert(child.children("y").count == 1) } test("merge tree") { val tree1 = new FPTree[String] .add(Seq("a", "b", "c")) .add(Seq("a", "b", "y")) .add(Seq("b")) val tree2 = new FPTree[String] .add(Seq("a", "b")) .add(Seq("a", "b", "c")) .add(Seq("a", "b", "c", "d")) .add(Seq("a", "x")) .add(Seq("a", "x", "y")) .add(Seq("c", "n")) .add(Seq("c", "m")) val tree3 = tree1.merge(tree2) assert(tree3.root.children.size == 3) assert(tree3.root.children("a").count == 7) assert(tree3.root.children("b").count == 1) assert(tree3.root.children("c").count == 2) val child1 = tree3.root.children("a") assert(child1.children.size == 2) assert(child1.children("b").count == 5) assert(child1.children("x").count == 2) val child2 = child1.children("b") assert(child2.children.size == 2) assert(child2.children("y").count == 1) assert(child2.children("c").count == 3) val child3 = child2.children("c") assert(child3.children.size == 1) assert(child3.children("d").count == 1) val child4 = child1.children("x") assert(child4.children.size == 1) assert(child4.children("y").count == 1) val child5 = tree3.root.children("c") assert(child5.children.size == 2) assert(child5.children("n").count == 1) assert(child5.children("m").count == 1) } test("extract freq itemsets") { val tree = new FPTree[String] .add(Seq("a", "b", "c")) .add(Seq("a", "b", "y")) .add(Seq("a", "b")) .add(Seq("a")) .add(Seq("b")) .add(Seq("b", "n")) val freqItemsets = tree.extract(3L).map { case (items, count) => (items.toSet, count) }.toSet val expected = Set( (Set("a"), 4L), (Set("b"), 5L), (Set("a", "b"), 3L)) assert(freqItemsets === expected) } }
Example 4
Source File: AssociationRulesSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.mllib.fpm import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext { test("association rules using String type") { val freqItemsets = sc.parallelize(Seq( (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L), (Set("r"), 3L), (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L), (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L), (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L), (Set("t", "y", "x"), 3L), (Set("t", "y", "x", "z"), 3L) ).map { case (items, freq) => new FPGrowth.FreqItemset(items.toArray, freq) }) val ar = new AssociationRules() val results1 = ar .setMinConfidence(0.9) .run(freqItemsets) .collect() assert(results1.size === 23) assert(results1.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23) val results2 = ar .setMinConfidence(0) .run(freqItemsets) .collect() assert(results2.size === 30) assert(results2.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23) } }
Example 5
Source File: KernelDensitySuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.mllib.stat import org.apache.commons.math3.distribution.NormalDistribution import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext { test("kernel density single sample") { val rdd = sc.parallelize(Array(5.0)) val evaluationPoints = Array(5.0, 6.0) val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints) val normal = new NormalDistribution(5.0, 3.0) val acceptableErr = 1e-6 assert(math.abs(densities(0) - normal.density(5.0)) < acceptableErr) assert(math.abs(densities(1) - normal.density(6.0)) < acceptableErr) } test("kernel density multiple samples") { val rdd = sc.parallelize(Array(5.0, 10.0)) val evaluationPoints = Array(5.0, 6.0) val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints) val normal1 = new NormalDistribution(5.0, 3.0) val normal2 = new NormalDistribution(10.0, 3.0) val acceptableErr = 1e-6 assert(math.abs( densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2) < acceptableErr) assert(math.abs( densities(1) - (normal1.density(6.0) + normal2.density(6.0)) / 2) < acceptableErr) } }
Example 6
Source File: MultivariateGaussianSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.mllib.stat.distribution import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{ Vectors, Matrices } import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext { test("univariate") { val x1 = Vectors.dense(0.0) val x2 = Vectors.dense(1.5) val mu = Vectors.dense(0.0) val sigma1 = Matrices.dense(1, 1, Array(1.0)) val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5) val sigma2 = Matrices.dense(1, 1, Array(4.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5) } test("multivariate") { val x1 = Vectors.dense(0.0, 0.0) val x2 = Vectors.dense(1.0, 1.0) val mu = Vectors.dense(0.0, 0.0) val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0)) val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5) val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5) } test("multivariate degenerate") { val x1 = Vectors.dense(0.0, 0.0) val x2 = Vectors.dense(1.0, 1.0) val mu = Vectors.dense(0.0, 0.0) val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0)) val dist = new MultivariateGaussian(mu, sigma) assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5) assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5) } test("SPARK-11302") { val x = Vectors.dense(629, 640, 1.7188, 618.19) val mu = Vectors.dense( 1055.3910505836575, 1070.489299610895, 1.39020554474708, 1040.5907503867697) val sigma = Matrices.dense(4, 4, Array( 166769.00466698944, 169336.6705268059, 12.820670788921873, 164243.93314092053, 169336.6705268059, 172041.5670061245, 21.62590020524533, 166678.01075856484, 12.820670788921873, 21.62590020524533, 0.872524191943962, 4.283255814732373, 164243.93314092053, 166678.01075856484, 4.283255814732373, 161848.9196719207)) val dist = new MultivariateGaussian(mu, sigma) // Agrees with R's dmvnorm: 7.154782e-05 assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9) } }
Example 7
Source File: KMeansPMMLModelExportSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.mllib.pmml.export import org.dmg.pmml.ClusteringModel import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors class KMeansPMMLModelExportSuite extends SparkFunSuite { test("KMeansPMMLModelExport generate PMML format") { val clusterCenters = Array( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0)) val kmeansModel = new KMeansModel(clusterCenters) val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel) // assert that the PMML format is as expected assert(modelExport.isInstanceOf[PMMLModelExport]) val pmml = modelExport.asInstanceOf[PMMLModelExport].getPmml assert(pmml.getHeader.getDescription === "k-means clustering") // check that the number of fields match the single vector size assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size) // This verify that there is a model attached to the pmml object and the model is a clustering // one. It also verifies that the pmml model has the same number of clusters of the spark model. val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel] assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length) } }
Example 8
Source File: PMMLModelExportFactorySuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.mllib.pmml.export import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel} import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} import org.apache.spark.mllib.util.LinearDataGenerator class PMMLModelExportFactorySuite extends SparkFunSuite { test("PMMLModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") { val clusterCenters = Array( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0)) val kmeansModel = new KMeansModel(clusterCenters) val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel) assert(modelExport.isInstanceOf[KMeansPMMLModelExport]) } test("PMMLModelExportFactory create GeneralizedLinearPMMLModelExport when passing a " + "LinearRegressionModel, RidgeRegressionModel or LassoModel") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) val linearRegressionModel = new LinearRegressionModel(linearInput(0).features, linearInput(0).label) val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel) assert(linearModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) val ridgeRegressionModel = new RidgeRegressionModel(linearInput(0).features, linearInput(0).label) val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel) assert(ridgeModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label) val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel) assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) } test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport " + "when passing a LogisticRegressionModel or SVMModel") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) val logisticRegressionModel = new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) val logisticRegressionModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) assert(logisticRegressionModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) assert(svmModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) } test("PMMLModelExportFactory throw IllegalArgumentException " + "when passing a Multinomial Logistic Regression") { val multiclassLogisticRegressionModel = new LogisticRegressionModel( weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, numFeatures = 2, numClasses = 3) intercept[IllegalArgumentException] { PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel) } } test("PMMLModelExportFactory throw IllegalArgumentException when passing an unsupported model") { val invalidModel = new Object intercept[IllegalArgumentException] { PMMLModelExportFactory.createPMMLModelExport(invalidModel) } } }
Example 9
Source File: NumericParserSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.mllib.util import org.apache.spark.{SparkException, SparkFunSuite} class NumericParserSuite extends SparkFunSuite { test("parser") { val s = "((1.0,2e3),-4,[5e-6,7.0E8],+9)" val parsed = NumericParser.parse(s).asInstanceOf[Seq[_]] assert(parsed(0).asInstanceOf[Seq[_]] === Seq(1.0, 2.0e3)) assert(parsed(1).asInstanceOf[Double] === -4.0) assert(parsed(2).asInstanceOf[Array[Double]] === Array(5.0e-6, 7.0e8)) assert(parsed(3).asInstanceOf[Double] === 9.0) val malformatted = Seq("a", "[1,,]", "0.123.4", "1 2", "3+4") malformatted.foreach { s => intercept[SparkException] { NumericParser.parse(s) throw new RuntimeException(s"Didn't detect malformatted string $s.") } } } test("parser with whitespaces") { val s = "(0.0, [1.0, 2.0])" val parsed = NumericParser.parse(s).asInstanceOf[Seq[_]] assert(parsed(0).asInstanceOf[Double] === 0.0) assert(parsed(1).asInstanceOf[Array[Double]] === Array(1.0, 2.0)) } }
Example 10
Source File: BreezeMatrixConversionSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.mllib.linalg import breeze.linalg.{DenseMatrix => BDM, CSCMatrix => BSM} import org.apache.spark.SparkFunSuite class BreezeMatrixConversionSuite extends SparkFunSuite { test("dense matrix to breeze") { val mat = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) val breeze = mat.toBreeze.asInstanceOf[BDM[Double]] assert(breeze.rows === mat.numRows) assert(breeze.cols === mat.numCols) assert(breeze.data.eq(mat.asInstanceOf[DenseMatrix].values), "should not copy data") } test("dense breeze matrix to matrix") { val breeze = new BDM[Double](3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) val mat = Matrices.fromBreeze(breeze).asInstanceOf[DenseMatrix] assert(mat.numRows === breeze.rows) assert(mat.numCols === breeze.cols) assert(mat.values.eq(breeze.data), "should not copy data") // transposed matrix val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[DenseMatrix] assert(matTransposed.numRows === breeze.cols) assert(matTransposed.numCols === breeze.rows) assert(matTransposed.values.eq(breeze.data), "should not copy data") } test("sparse matrix to breeze") { val values = Array(1.0, 2.0, 4.0, 5.0) val colPtrs = Array(0, 2, 4) val rowIndices = Array(1, 2, 1, 2) val mat = Matrices.sparse(3, 2, colPtrs, rowIndices, values) val breeze = mat.toBreeze.asInstanceOf[BSM[Double]] assert(breeze.rows === mat.numRows) assert(breeze.cols === mat.numCols) assert(breeze.data.eq(mat.asInstanceOf[SparseMatrix].values), "should not copy data") } test("sparse breeze matrix to sparse matrix") { val values = Array(1.0, 2.0, 4.0, 5.0) val colPtrs = Array(0, 2, 4) val rowIndices = Array(1, 2, 1, 2) val breeze = new BSM[Double](values, 3, 2, colPtrs, rowIndices) val mat = Matrices.fromBreeze(breeze).asInstanceOf[SparseMatrix] assert(mat.numRows === breeze.rows) assert(mat.numCols === breeze.cols) assert(mat.values.eq(breeze.data), "should not copy data") val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[SparseMatrix] assert(matTransposed.numRows === breeze.cols) assert(matTransposed.numCols === breeze.rows) assert(!matTransposed.values.eq(breeze.data), "has to copy data") } }
Example 11
Source File: CoordinateMatrixSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.linalg.Vectors class CoordinateMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 5 val n = 4 var mat: CoordinateMatrix = _ override def beforeAll() { super.beforeAll() val entries = sc.parallelize(Seq( (0, 0, 1.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 2, 5.0), (2, 3, 6.0), (3, 0, 7.0), (3, 3, 8.0), (4, 1, 9.0)), 3).map { case (i, j, value) => MatrixEntry(i, j, value) } mat = new CoordinateMatrix(entries) } test("size") { assert(mat.numRows() === m) assert(mat.numCols() === n) } test("empty entries") { val entries = sc.parallelize(Seq[MatrixEntry](), 1) val emptyMat = new CoordinateMatrix(entries) intercept[RuntimeException] { emptyMat.numCols() } intercept[RuntimeException] { emptyMat.numRows() } } test("toBreeze") { val expected = BDM( (1.0, 2.0, 0.0, 0.0), (0.0, 3.0, 4.0, 0.0), (0.0, 0.0, 5.0, 6.0), (7.0, 0.0, 0.0, 8.0), (0.0, 9.0, 0.0, 0.0)) assert(mat.toBreeze() === expected) } test("transpose") { val transposed = mat.transpose() assert(mat.toBreeze().t === transposed.toBreeze()) } test("toIndexedRowMatrix") { val indexedRowMatrix = mat.toIndexedRowMatrix() val expected = BDM( (1.0, 2.0, 0.0, 0.0), (0.0, 3.0, 4.0, 0.0), (0.0, 0.0, 5.0, 6.0), (7.0, 0.0, 0.0, 8.0), (0.0, 9.0, 0.0, 0.0)) assert(indexedRowMatrix.toBreeze() === expected) } test("toRowMatrix") { val rowMatrix = mat.toRowMatrix() val rows = rowMatrix.rows.collect().toSet val expected = Set( Vectors.dense(1.0, 2.0, 0.0, 0.0), Vectors.dense(0.0, 3.0, 4.0, 0.0), Vectors.dense(0.0, 0.0, 5.0, 6.0), Vectors.dense(7.0, 0.0, 0.0, 8.0), Vectors.dense(0.0, 9.0, 0.0, 0.0)) assert(rows === expected) } test("toBlockMatrix") { val blockMat = mat.toBlockMatrix(2, 2) assert(blockMat.numRows() === m) assert(blockMat.numCols() === n) assert(blockMat.toBreeze() === mat.toBreeze()) intercept[IllegalArgumentException] { mat.toBlockMatrix(-1, 2) } intercept[IllegalArgumentException] { mat.toBlockMatrix(2, 0) } } }
Example 12
Source File: BreezeVectorConversionSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.mllib.linalg import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} import org.apache.spark.SparkFunSuite class BreezeVectorConversionSuite extends SparkFunSuite { val arr = Array(0.1, 0.2, 0.3, 0.4) val n = 20 val indices = Array(0, 3, 5, 10, 13) val values = Array(0.1, 0.5, 0.3, -0.8, -1.0) test("dense to breeze") { val vec = Vectors.dense(arr) assert(vec.toBreeze === new BDV[Double](arr)) } test("sparse to breeze") { val vec = Vectors.sparse(n, indices, values) assert(vec.toBreeze === new BSV[Double](indices, values, n)) } test("dense breeze to vector") { val breeze = new BDV[Double](arr) val vec = Vectors.fromBreeze(breeze).asInstanceOf[DenseVector] assert(vec.size === arr.length) assert(vec.values.eq(arr), "should not copy data") } test("sparse breeze to vector") { val breeze = new BSV[Double](indices, values, n) val vec = Vectors.fromBreeze(breeze).asInstanceOf[SparseVector] assert(vec.size === n) assert(vec.indices.eq(indices), "should not copy data") assert(vec.values.eq(values), "should not copy data") } test("sparse breeze with partially-used arrays to vector") { val activeSize = 3 val breeze = new BSV[Double](indices, values, activeSize, n) val vec = Vectors.fromBreeze(breeze).asInstanceOf[SparseVector] assert(vec.size === n) assert(vec.indices === indices.slice(0, activeSize)) assert(vec.values === values.slice(0, activeSize)) } }
Example 13
Source File: LabeledPointSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.mllib.regression import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors class LabeledPointSuite extends SparkFunSuite { test("parse labeled points") { val points = Seq( LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), LabeledPoint(0.0, Vectors.sparse(2, Array(1), Array(-1.0)))) points.foreach { p => assert(p === LabeledPoint.parse(p.toString)) } } test("parse labeled points with whitespaces") { val point = LabeledPoint.parse("(0.0, [1.0, 2.0])") assert(point === LabeledPoint(0.0, Vectors.dense(1.0, 2.0))) } test("parse labeled points with v0.9 format") { val point = LabeledPoint.parse("1.0,1.0 0.0 -2.0") assert(point === LabeledPoint(1.0, Vectors.dense(1.0, 0.0, -2.0))) } }
Example 14
Source File: RidgeRegressionSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.mllib.regression import scala.util.Random import org.jblas.DoubleMatrix import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.util.Utils private object RidgeRegressionSuite { val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) } class RidgeRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = { predictions.zip(input).map { case (prediction, expected) => (prediction - expected.label) * (prediction - expected.label) }.reduceLeft(_ + _) / predictions.size } test("ridge regression can help avoid overfitting") { // For small number of examples and large variance of error distribution, // ridge regression should give smaller generalization error that linear regression. val numExamples = 50 val numFeatures = 20 org.jblas.util.Random.seed(42) // Pick weights as random values distributed uniformly in [-0.5, 0.5] val w = DoubleMatrix.rand(numFeatures, 1).subi(0.5) // Use half of data for training and other half for validation val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2 * numExamples, 42, 10.0) val testData = data.take(numExamples) val validationData = data.takeRight(numExamples) val testRDD = sc.parallelize(testData, 2).cache() val validationRDD = sc.parallelize(validationData, 2).cache() // First run without regularization. val linearReg = new LinearRegressionWithSGD() linearReg.optimizer.setNumIterations(200) .setStepSize(1.0) val linearModel = linearReg.run(testRDD) val linearErr = predictionError( linearModel.predict(validationRDD.map(_.features)).collect(), validationData) val ridgeReg = new RidgeRegressionWithSGD() ridgeReg.optimizer.setNumIterations(200) .setRegParam(0.1) .setStepSize(1.0) val ridgeModel = ridgeReg.run(testRDD) val ridgeErr = predictionError( ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData) // Ridge validation error should be lower than linear regression. assert(ridgeErr < linearErr, "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")") } test("model save/load") { val model = RidgeRegressionSuite.model val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString // Save model, load it back, and compare. try { model.save(sc, path) val sameModel = RidgeRegressionModel.load(sc, path) assert(model.weights == sameModel.weights) assert(model.intercept == sameModel.intercept) } finally { Utils.deleteRecursively(tempDir) } } } class RidgeRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 val n = 200000 val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) => val random = new Random(idx) iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble())))) }.cache() // If we serialize data directly in the task closure, the size of the serialized task would be // greater than 1MB and hence Spark would throw an error. val model = RidgeRegressionWithSGD.train(points, 2) val predictions = model.predict(points.map(_.features)) } }
Example 15
Source File: MLPairRDDFunctionsSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.mllib.rdd import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ class MLPairRDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("topByKey") { val topMap = sc.parallelize(Array((1, 7), (1, 3), (1, 6), (1, 1), (1, 2), (3, 2), (3, 7), (5, 1), (3, 5)), 2) .topByKey(5) .collectAsMap() assert(topMap.size === 3) assert(topMap(1) === Array(7, 6, 3, 2, 1)) assert(topMap(3) === Array(7, 5, 2)) assert(topMap(5) === Array(1)) } }
Example 16
Source File: RDDFunctionsSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.mllib.rdd import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.RDDFunctions._ class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("sliding") { val data = 0 until 6 for (numPartitions <- 1 to 8) { val rdd = sc.parallelize(data, numPartitions) for (windowSize <- 1 to 6) { for (step <- 1 to 3) { val sliding = rdd.sliding(windowSize, step).collect().map(_.toList).toList val expected = data.sliding(windowSize, step) .map(_.toList).toList.filter(l => l.size == windowSize) assert(sliding === expected) } } assert(rdd.sliding(7).collect().isEmpty, "Should return an empty RDD if the window size is greater than the number of items.") } } test("sliding with empty partitions") { val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7)) val rdd = sc.parallelize(data, data.length).flatMap(s => s) assert(rdd.partitions.length === data.length) val sliding = rdd.sliding(3).collect().toSeq.map(_.toSeq) val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq) assert(sliding === expected) } }
Example 17
Source File: TwitterStreamSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.streaming.twitter import org.scalatest.BeforeAndAfter import twitter4j.Status import twitter4j.auth.{NullAuthorization, Authorization} import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream class TwitterStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { val batchDuration = Seconds(1) private val master: String = "local[2]" private val framework: String = this.getClass.getSimpleName test("twitter input stream") { val ssc = new StreamingContext(master, framework, batchDuration) val filters = Seq("filter1", "filter2") val authorization: Authorization = NullAuthorization.getInstance() // tests the API, does not actually test data receiving val test1: ReceiverInputDStream[Status] = TwitterUtils.createStream(ssc, None) val test2: ReceiverInputDStream[Status] = TwitterUtils.createStream(ssc, None, filters) val test3: ReceiverInputDStream[Status] = TwitterUtils.createStream(ssc, None, filters, StorageLevel.MEMORY_AND_DISK_SER_2) val test4: ReceiverInputDStream[Status] = TwitterUtils.createStream(ssc, Some(authorization)) val test5: ReceiverInputDStream[Status] = TwitterUtils.createStream(ssc, Some(authorization), filters) val test6: ReceiverInputDStream[Status] = TwitterUtils.createStream( ssc, Some(authorization), filters, StorageLevel.MEMORY_AND_DISK_SER_2) // Note that actually testing the data receiving is hard as authentication keys are // necessary for accessing Twitter live stream ssc.stop() } }
Example 18
Source File: FlumeStreamSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.streaming.flume import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.base.Charsets import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory import org.jboss.netty.handler.codec.compression._ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite") var ssc: StreamingContext = null test("flume input stream") { testFlumeStream(testCompression = false) } test("flume input compressed stream") { testFlumeStream(testCompression = true) } private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { override def newChannel(pipeline: ChannelPipeline): SocketChannel = { val encoder = new ZlibEncoder(compressionLevel) pipeline.addFirst("deflater", encoder) pipeline.addFirst("inflater", new ZlibDecoder()) super.newChannel(pipeline) } } }
Example 19
Source File: ZeroMQStreamSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.streaming.zeromq import akka.actor.SupervisorStrategy import akka.util.ByteString import akka.zeromq.Subscribe import org.apache.spark.SparkFunSuite import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream class ZeroMQStreamSuite extends SparkFunSuite { val batchDuration = Seconds(1) private val master: String = "local[2]" private val framework: String = this.getClass.getSimpleName test("zeromq input stream") { val ssc = new StreamingContext(master, framework, batchDuration) val publishUrl = "abc" val subscribe = new Subscribe(null.asInstanceOf[ByteString]) val bytesToObjects = (bytes: Seq[ByteString]) => null.asInstanceOf[Iterator[String]] // tests the API, does not actually test data receiving val test1: ReceiverInputDStream[String] = ZeroMQUtils.createStream(ssc, publishUrl, subscribe, bytesToObjects) val test2: ReceiverInputDStream[String] = ZeroMQUtils.createStream( ssc, publishUrl, subscribe, bytesToObjects, StorageLevel.MEMORY_AND_DISK_SER_2) val test3: ReceiverInputDStream[String] = ZeroMQUtils.createStream( ssc, publishUrl, subscribe, bytesToObjects, StorageLevel.MEMORY_AND_DISK_SER_2, SupervisorStrategy.defaultStrategy) // TODO: Actually test data receiving ssc.stop() } }
Example 20
Source File: MQTTStreamSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.streaming.mqtt import scala.concurrent.duration._ import scala.language.postfixOps import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter { private val batchDuration = Milliseconds(500) private val master = "local[2]" private val framework = this.getClass.getSimpleName private val topic = "def" private var ssc: StreamingContext = _ private var mqttTestUtils: MQTTTestUtils = _ before { ssc = new StreamingContext(master, framework, batchDuration) mqttTestUtils = new MQTTTestUtils mqttTestUtils.setup() } after { if (ssc != null) { ssc.stop() ssc = null } if (mqttTestUtils != null) { mqttTestUtils.teardown() mqttTestUtils = null } } test("mqtt input stream") { val sendMessage = "MQTT demo for spark streaming" val receiveStream = MQTTUtils.createStream(ssc, "tcp://" + mqttTestUtils.brokerUri, topic, StorageLevel.MEMORY_ONLY) @volatile var receiveMessage: List[String] = List() receiveStream.foreachRDD { rdd => if (rdd.collect.length > 0) { receiveMessage = receiveMessage ::: List(rdd.first) receiveMessage } } ssc.start() // Retry it because we don't know when the receiver will start. eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { mqttTestUtils.publishData(topic, sendMessage) assert(sendMessage.equals(receiveMessage(0))) } ssc.stop() } }
Example 21
Source File: KafkaStreamSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.streaming.kafka import scala.collection.mutable import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random import kafka.serializer.StringDecoder import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll { private var ssc: StreamingContext = _ private var kafkaTestUtils: KafkaTestUtils = _ override def beforeAll(): Unit = { kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() } override def afterAll(): Unit = { if (ssc != null) { ssc.stop() ssc = null } if (kafkaTestUtils != null) { kafkaTestUtils.teardown() kafkaTestUtils = null } } test("Kafka input stream") { val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName) ssc = new StreamingContext(sparkConf, Milliseconds(500)) val topic = "topic1" val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) kafkaTestUtils.createTopic(topic) kafkaTestUtils.sendMessages(topic, sent) val kafkaParams = Map("zookeeper.connect" -> kafkaTestUtils.zkAddress, "group.id" -> s"test-consumer-${Random.nextInt(10000)}", "auto.offset.reset" -> "smallest") val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY) val result = new mutable.HashMap[String, Long]() with mutable.SynchronizedMap[String, Long] stream.map(_._2).countByValue().foreachRDD { r => val ret = r.collect() ret.toMap.foreach { kv => val count = result.getOrElseUpdate(kv._1, 0) + kv._2 result.put(kv._1, count) } } ssc.start() eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { assert(sent === result) } } }
Example 22
Source File: KafkaClusterSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.streaming.kafka import scala.util.Random import kafka.common.TopicAndPartition import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite class KafkaClusterSuite extends SparkFunSuite with BeforeAndAfterAll { private val topic = "kcsuitetopic" + Random.nextInt(10000) private val topicAndPartition = TopicAndPartition(topic, 0) private var kc: KafkaCluster = null private var kafkaTestUtils: KafkaTestUtils = _ override def beforeAll() { kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() kafkaTestUtils.createTopic(topic) kafkaTestUtils.sendMessages(topic, Map("a" -> 1)) kc = new KafkaCluster(Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress)) } override def afterAll() { if (kafkaTestUtils != null) { kafkaTestUtils.teardown() kafkaTestUtils = null } } test("metadata apis") { val leader = kc.findLeaders(Set(topicAndPartition)).right.get(topicAndPartition) val leaderAddress = s"${leader._1}:${leader._2}" assert(leaderAddress === kafkaTestUtils.brokerAddress, "didn't get leader") val parts = kc.getPartitions(Set(topic)).right.get assert(parts(topicAndPartition), "didn't get partitions") val err = kc.getPartitions(Set(topic + "BAD")) assert(err.isLeft, "getPartitions for a nonexistant topic should be an error") } test("leader offset apis") { val earliest = kc.getEarliestLeaderOffsets(Set(topicAndPartition)).right.get assert(earliest(topicAndPartition).offset === 0, "didn't get earliest") val latest = kc.getLatestLeaderOffsets(Set(topicAndPartition)).right.get assert(latest(topicAndPartition).offset === 1, "didn't get latest") } test("consumer offset apis") { val group = "kcsuitegroup" + Random.nextInt(10000) val offset = Random.nextInt(10000) val set = kc.setConsumerOffsets(group, Map(topicAndPartition -> offset)) assert(set.isRight, "didn't set consumer offsets") val get = kc.getConsumerOffsets(group, Set(topicAndPartition)).right.get assert(get(topicAndPartition) === offset, "didn't get consumer offsets") } }
Example 23
Source File: ConcurrentHiveSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.sql.hive.execution import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.hive.test.TestHiveContext import org.scalatest.BeforeAndAfterAll class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll { ignore("multiple instances not supported") { test("Multiple Hive Instances") { (1 to 10).map { i => val conf = new SparkConf() conf.set("spark.ui.enabled", "false") val ts = new TestHiveContext(new SparkContext("local", s"TestSQLContext$i", conf)) ts.executeSql("SHOW TABLES").toRdd.collect() ts.executeSql("SELECT * FROM src").toRdd.collect() ts.executeSql("SHOW TABLES").toRdd.collect() } } } }
Example 24
Source File: FiltersSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.sql.hive.client import java.util.Collections import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.serde.serdeConstants import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ class FiltersSuite extends SparkFunSuite with Logging { private val shim = new Shim_v0_13 private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test") private val varCharCol = new FieldSchema() varCharCol.setName("varchar") varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME) testTable.setPartCols(Collections.singletonList(varCharCol)) filterTest("string filter", (a("stringcol", StringType) > Literal("test")) :: Nil, "stringcol > \"test\"") filterTest("string filter backwards", (Literal("test") > a("stringcol", StringType)) :: Nil, "\"test\" > stringcol") filterTest("int filter", (a("intcol", IntegerType) === Literal(1)) :: Nil, "intcol = 1") filterTest("int filter backwards", (Literal(1) === a("intcol", IntegerType)) :: Nil, "1 = intcol") filterTest("int and string filter", (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", IntegerType)) :: Nil, "1 = intcol and \"a\" = strcol") filterTest("skip varchar", (Literal("") === a("varchar", StringType)) :: Nil, "") private def filterTest(name: String, filters: Seq[Expression], result: String) = { test(name){ val converted = shim.convertFilters(testTable, filters) if (converted != result) { fail( s"Expected filters ${filters.mkString(",")} to convert to '$result' but got '$converted'") } } } private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)() }
Example 25
Source File: ClasspathDependenciesSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.sql.hive import java.net.URL import org.apache.spark.SparkFunSuite class ClasspathDependenciesSuite extends SparkFunSuite { private val classloader = this.getClass.getClassLoader private def assertLoads(classname: String): Unit = { val resourceURL: URL = Option(findResource(classname)).getOrElse { fail(s"Class $classname not found as ${resourceName(classname)}") } logInfo(s"Class $classname at $resourceURL") classloader.loadClass(classname) } private def assertLoads(classes: String*): Unit = { classes.foreach(assertLoads) } private def findResource(classname: String): URL = { val resource = resourceName(classname) classloader.getResource(resource) } private def resourceName(classname: String): String = { classname.replace(".", "/") + ".class" } private def assertClassNotFound(classname: String): Unit = { Option(findResource(classname)).foreach { resourceURL => fail(s"Class $classname found at $resourceURL") } intercept[ClassNotFoundException] { classloader.loadClass(classname) } } private def assertClassNotFound(classes: String*): Unit = { classes.foreach(assertClassNotFound) } private val KRYO = "com.esotericsoftware.kryo.Kryo" private val SPARK_HIVE = "org.apache.hive." private val SPARK_SHADED = "org.spark-project.hive.shaded." test("shaded Protobuf") { assertLoads(SPARK_SHADED + "com.google.protobuf.ServiceException") } test("hive-common") { assertLoads("org.apache.hadoop.hive.conf.HiveConf") } test("hive-exec") { assertLoads("org.apache.hadoop.hive.ql.CommandNeedRetryException") } private val STD_INSTANTIATOR = "org.objenesis.strategy.StdInstantiatorStrategy" test("unshaded kryo") { assertLoads(KRYO, STD_INSTANTIATOR) } test("Forbidden Dependencies") { assertClassNotFound( SPARK_HIVE + KRYO, SPARK_SHADED + KRYO, "org.apache.hive." + KRYO, "com.esotericsoftware.shaded." + STD_INSTANTIATOR, SPARK_HIVE + "com.esotericsoftware.shaded." + STD_INSTANTIATOR, "org.apache.hive.com.esotericsoftware.shaded." + STD_INSTANTIATOR ) } test("parquet-hadoop-bundle") { assertLoads( "parquet.hadoop.ParquetOutputFormat", "parquet.hadoop.ParquetInputFormat" ) } }
Example 26
Source File: RandomDataGeneratorSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.types._ class RandomDataGeneratorSuite extends SparkFunSuite { def testRandomDataGeneration(dataType: DataType, nullable: Boolean = true): Unit = { val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType) val generator = RandomDataGenerator.forType(dataType, nullable, Some(33)).getOrElse { fail(s"Random data generator was not defined for $dataType") } if (nullable) { assert(Iterator.fill(100)(generator()).contains(null)) } else { assert(Iterator.fill(100)(generator()).forall(_ != null)) } for (_ <- 1 to 10) { val generatedValue = generator() toCatalyst(generatedValue) } } // Basic types: for ( dataType <- DataTypeTestUtils.atomicTypes; nullable <- Seq(true, false) if !dataType.isInstanceOf[DecimalType]) { test(s"$dataType (nullable=$nullable)") { testRandomDataGeneration(dataType) } } for ( arrayType <- DataTypeTestUtils.atomicArrayTypes if RandomDataGenerator.forType(arrayType.elementType, arrayType.containsNull).isDefined ) { test(s"$arrayType") { testRandomDataGeneration(arrayType) } } val atomicTypesWithDataGenerators = DataTypeTestUtils.atomicTypes.filter(RandomDataGenerator.forType(_).isDefined) // Complex types: for ( keyType <- atomicTypesWithDataGenerators; valueType <- atomicTypesWithDataGenerators // Scala's BigDecimal.hashCode can lead to OutOfMemoryError on Scala 2.10 (see SI-6173) and // Spark can hit NumberFormatException errors when converting certain BigDecimals (SPARK-8802). // For these reasons, we don't support generation of maps with decimal keys. if !keyType.isInstanceOf[DecimalType] ) { val mapType = MapType(keyType, valueType) test(s"$mapType") { testRandomDataGeneration(mapType) } } for ( colOneType <- atomicTypesWithDataGenerators; colTwoType <- atomicTypesWithDataGenerators ) { val structType = StructType(StructField("a", colOneType) :: StructField("b", colTwoType) :: Nil) test(s"$structType") { testRandomDataGeneration(structType) } } }
Example 27
Source File: LogicalPlanSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ class LogicalPlanSuite extends SparkFunSuite { private var invocationCount = 0 private val function: PartialFunction[LogicalPlan, LogicalPlan] = { case p: Project => invocationCount += 1 p } private val testRelation = LocalRelation() test("resolveOperator runs on operators") { invocationCount = 0 val plan = Project(Nil, testRelation) plan resolveOperators function assert(invocationCount === 1) } test("resolveOperator runs on operators recursively") { invocationCount = 0 val plan = Project(Nil, Project(Nil, testRelation)) plan resolveOperators function assert(invocationCount === 2) } test("resolveOperator skips all ready resolved plans") { invocationCount = 0 val plan = Project(Nil, Project(Nil, testRelation)) plan.foreach(_.setAnalyzed()) plan resolveOperators function assert(invocationCount === 0) } test("resolveOperator skips partially resolved plans") { invocationCount = 0 val plan1 = Project(Nil, testRelation) val plan2 = Project(Nil, plan1) plan1.foreach(_.setAnalyzed()) plan2 resolveOperators function assert(invocationCount === 1) } }
Example 28
Source File: SameResultSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.util._ class SameResultSuite extends SparkFunSuite { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int) def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = { val aAnalyzed = a.analyze val bAnalyzed = b.analyze if (aAnalyzed.sameResult(bAnalyzed) != result) { val comparison = sideBySide(aAnalyzed.toString, bAnalyzed.toString).mkString("\n") fail(s"Plans should return sameResult = $result\n$comparison") } } test("relations") { assertSameResult(testRelation, testRelation2) } test("projections") { assertSameResult(testRelation.select('a), testRelation2.select('a)) assertSameResult(testRelation.select('b), testRelation2.select('b)) assertSameResult(testRelation.select('a, 'b), testRelation2.select('a, 'b)) assertSameResult(testRelation.select('b, 'a), testRelation2.select('b, 'a)) assertSameResult(testRelation, testRelation2.select('a), result = false) assertSameResult(testRelation.select('b, 'a), testRelation2.select('a, 'b), result = false) } test("filters") { assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b)) } test("sorts") { assertSameResult(testRelation.orderBy('a.asc), testRelation2.orderBy('a.asc)) } }
Example 29
Source File: NondeterministicSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper { test("MonotonicallyIncreasingID") { checkEvaluation(MonotonicallyIncreasingID(), 0L) } test("SparkPartitionID") { checkEvaluation(SparkPartitionID(), 0) } test("InputFileName") { checkEvaluation(InputFileName(), "") } }
Example 30
Source File: NullFunctionsSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { def testAllTypes(testFunc: (Any, DataType) => Unit): Unit = { testFunc(false, BooleanType) testFunc(1.toByte, ByteType) testFunc(1.toShort, ShortType) testFunc(1, IntegerType) testFunc(1L, LongType) testFunc(1.0F, FloatType) testFunc(1.0, DoubleType) testFunc(Decimal(1.5), DecimalType(2, 1)) testFunc(new java.sql.Date(10), DateType) testFunc(new java.sql.Timestamp(10), TimestampType) testFunc("abcd", StringType) } test("isnull and isnotnull") { testAllTypes { (value: Any, tpe: DataType) => checkEvaluation(IsNull(Literal.create(value, tpe)), false) checkEvaluation(IsNotNull(Literal.create(value, tpe)), true) checkEvaluation(IsNull(Literal.create(null, tpe)), true) checkEvaluation(IsNotNull(Literal.create(null, tpe)), false) } } test("IsNaN") { checkEvaluation(IsNaN(Literal(Double.NaN)), true) checkEvaluation(IsNaN(Literal(Float.NaN)), true) checkEvaluation(IsNaN(Literal(math.log(-3))), true) checkEvaluation(IsNaN(Literal.create(null, DoubleType)), false) checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false) checkEvaluation(IsNaN(Literal(Float.MaxValue)), false) checkEvaluation(IsNaN(Literal(5.5f)), false) } test("nanvl") { checkEvaluation(NaNvl(Literal(5.0), Literal.create(null, DoubleType)), 5.0) checkEvaluation(NaNvl(Literal.create(null, DoubleType), Literal(5.0)), null) checkEvaluation(NaNvl(Literal.create(null, DoubleType), Literal(Double.NaN)), null) checkEvaluation(NaNvl(Literal(Double.NaN), Literal(5.0)), 5.0) checkEvaluation(NaNvl(Literal(Double.NaN), Literal.create(null, DoubleType)), null) assert(NaNvl(Literal(Double.NaN), Literal(Double.NaN)). eval(EmptyRow).asInstanceOf[Double].isNaN) } test("coalesce") { testAllTypes { (value: Any, tpe: DataType) => val lit = Literal.create(value, tpe) val nullLit = Literal.create(null, tpe) checkEvaluation(Coalesce(Seq(nullLit)), null) checkEvaluation(Coalesce(Seq(lit)), value) checkEvaluation(Coalesce(Seq(nullLit, lit)), value) checkEvaluation(Coalesce(Seq(nullLit, lit, lit)), value) checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value) } } test("AtLeastNNonNulls") { val mix = Seq(Literal("x"), Literal.create(null, StringType), Literal.create(null, DoubleType), Literal(Double.NaN), Literal(5f)) val nanOnly = Seq(Literal("x"), Literal(10.0), Literal(Float.NaN), Literal(math.log(-2)), Literal(Double.MaxValue)) val nullOnly = Seq(Literal("x"), Literal.create(null, DoubleType), Literal.create(null, DecimalType.USER_DEFAULT), Literal(Float.MaxValue), Literal(false)) checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow) checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow) checkEvaluation(AtLeastNNonNulls(3, nanOnly), true, EmptyRow) checkEvaluation(AtLeastNNonNulls(4, nanOnly), false, EmptyRow) checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow) checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow) } }
Example 31
Source File: DecimalExpressionSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.{LongType, DecimalType, Decimal} class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { test("UnscaledValue") { val d1 = Decimal("10.1") checkEvaluation(UnscaledValue(Literal(d1)), 101L) val d2 = Decimal(101, 3, 1) checkEvaluation(UnscaledValue(Literal(d2)), 101L) checkEvaluation(UnscaledValue(Literal.create(null, DecimalType(2, 1))), null) } test("MakeDecimal") { checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1")) checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null) } test("PromotePrecision") { val d1 = Decimal("10.1") checkEvaluation(PromotePrecision(Literal(d1)), d1) val d2 = Decimal(101, 3, 1) checkEvaluation(PromotePrecision(Literal(d2)), d2) checkEvaluation(PromotePrecision(Literal.create(null, DecimalType(2, 1))), null) } test("CheckOverflow") { val d1 = Decimal("10.1") checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0)), Decimal("10")) checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1)), d1) checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2)), d1) checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3)), null) val d2 = Decimal(101, 3, 1) checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0)), Decimal("10")) checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1)), d2) checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2)), d2) checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3)), null) checkEvaluation(CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2)), null) } }
Example 32
Source File: CollectionFunctionsSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Array and Map Size") { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType)) checkEvaluation(Size(a0), 3) checkEvaluation(Size(a1), 0) checkEvaluation(Size(a2), 2) val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType)) val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType)) checkEvaluation(Size(m0), 2) checkEvaluation(Size(m1), 0) checkEvaluation(Size(m2), 1) checkEvaluation(Literal.create(null, MapType(StringType, StringType)), null) checkEvaluation(Literal.create(null, ArrayType(StringType)), null) } test("Sort Array") { val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType)) val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType)) val a4 = Literal.create(Seq(null, null), ArrayType(NullType)) checkEvaluation(new SortArray(a0), Seq(1, 2, 3)) checkEvaluation(new SortArray(a1), Seq[Integer]()) checkEvaluation(new SortArray(a2), Seq("a", "b")) checkEvaluation(new SortArray(a3), Seq(null, "a", "b")) checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3)) checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b")) checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b")) checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1)) checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a")) checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null)) checkEvaluation(Literal.create(null, ArrayType(StringType)), null) checkEvaluation(new SortArray(a4), Seq(null, null)) val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS) checkEvaluation(new SortArray(arrayStruct), Seq(create_row(1), create_row(2))) } test("Array contains") { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) val a2 = Literal.create(Seq(null), ArrayType(LongType)) val a3 = Literal.create(null, ArrayType(StringType)) checkEvaluation(ArrayContains(a0, Literal(1)), true) checkEvaluation(ArrayContains(a0, Literal(0)), false) checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null) checkEvaluation(ArrayContains(a1, Literal("")), true) checkEvaluation(ArrayContains(a1, Literal("a")), null) checkEvaluation(ArrayContains(a1, Literal.create(null, StringType)), null) checkEvaluation(ArrayContains(a2, Literal(1L)), null) checkEvaluation(ArrayContains(a2, Literal.create(null, LongType)), null) checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } }
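The same expressions back the DataFrame-level helpers; a minimal sketch, assuming a hypothetical DataFrame df with an array<int> column "xs" (both are assumptions for illustration):

import org.apache.spark.sql.functions._

// size, sort_array and array_contains correspond to the Size, SortArray
// and ArrayContains expressions tested above
val summary = df.select(
  size(col("xs")).as("n"),
  sort_array(col("xs")).as("ascending"),
  sort_array(col("xs"), false).as("descending"),
  array_contains(col("xs"), 1).as("has_one"))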
Example 33
Source File: RandomSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { test("random") { checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- 0.001) checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001) } test("SPARK-9127 codegen with long seed") { checkDoubleEvaluation(Rand(5419823303878592871L), 0.2304755080444375 +- 0.001) checkDoubleEvaluation(Randn(5419823303878592871L), -1.2824262718225607 +- 0.001) } }
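A seeded Rand/Randn is deterministic for a given seed and partition, which is what makes the exact expected values above checkable. A short usage sketch with the public functions API (df is a hypothetical DataFrame; concrete values depend on partitioning):

import org.apache.spark.sql.functions.{rand, randn}

// Seeded generators: rerunning the same plan with the same seed and the
// same partitioning reproduces the same column of values.
val withNoise = df.select(rand(30).as("u"), randn(30).as("g"))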
Example 34
Source File: MiscFunctionsSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType} class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("md5") { checkEvaluation(Md5(Literal("ABC".getBytes)), "902fbdd2b1df0c4f70b4a5d23525e932") checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), "6ac1e56bc78f031059be7be854522c4c") checkEvaluation(Md5(Literal.create(null, BinaryType)), null) checkConsistencyBetweenInterpretedAndCodegen(Md5, BinaryType) } test("sha1") { checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8") checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), "5d211bad8f4ee70e16c7d343a838fc344a1ed961") checkEvaluation(Sha1(Literal.create(null, BinaryType)), null) checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709") checkConsistencyBetweenInterpretedAndCodegen(Sha1, BinaryType) } test("sha2") { checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC")) checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)), DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6))) // unsupported bit length checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null) checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null) checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null) checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null) } test("crc32") { checkEvaluation(Crc32(Literal("ABC".getBytes)), 2743272264L) checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), 2180413220L) checkEvaluation(Crc32(Literal.create(null, BinaryType)), null) checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) } }
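These digest expressions are also exposed through org.apache.spark.sql.functions; a minimal sketch, assuming a hypothetical DataFrame df with a binary column "payload":

import org.apache.spark.sql.functions._

val digests = df.select(
  md5(col("payload")).as("md5"),
  sha1(col("payload")).as("sha1"),
  sha2(col("payload"), 256).as("sha256"),  // supported bit lengths include 224, 256, 384 and 512
  crc32(col("payload")).as("crc32"))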
Example 35
Source File: AttributeSetSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.IntegerType class AttributeSetSuite extends SparkFunSuite { val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1)) val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1)) val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3)) val aSet = AttributeSet(aLower :: Nil) val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2)) val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2)) val bSet = AttributeSet(bUpper :: Nil) val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil) test("sanity check") { assert(aUpper != aLower) assert(bUpper != bLower) } test("checks by id not name") { assert(aSet.contains(aUpper) === true) assert(aSet.contains(aLower) === true) assert(aSet.contains(fakeA) === false) assert(aSet.contains(bUpper) === false) assert(aSet.contains(bLower) === false) } test("++ preserves AttributeSet") { assert((aSet ++ bSet).contains(aUpper) === true) assert((aSet ++ bSet).contains(aLower) === true) } test("extracts all references references") { val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil) assert(addSet.contains(aUpper)) assert(addSet.contains(aLower)) assert(addSet.contains(bUpper)) assert(addSet.contains(bLower)) } test("dedups attributes") { assert(AttributeSet(aUpper :: aLower :: Nil).size === 1) } test("subset") { assert(aSet.subsetOf(aAndBSet) === true) assert(aAndBSet.subsetOf(aSet) === false) } test("equality") { assert(aSet != aAndBSet) assert(aAndBSet != aSet) assert(aSet != bSet) assert(bSet != aSet) assert(aSet == aSet) assert(aSet == AttributeSet(aUpper :: Nil)) } }
Example 36
Source File: CodeFormatterSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite class CodeFormatterSuite extends SparkFunSuite { def testCase(name: String)(input: String)(expected: String): Unit = { test(name) { assert(CodeFormatter.format(input).trim === expected.trim) } } testCase("basic example") { """class A { |blahblah; |}""".stripMargin }{ """ |class A { |  blahblah; |} """.stripMargin } }
Example 37
Source File: GenerateUnsafeRowJoinerSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions.codegen import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.types._ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { private val fixed = Seq(IntegerType) private val variable = Seq(IntegerType, StringType) test("simple fixed width types") { testConcat(0, 0, fixed) testConcat(0, 1, fixed) testConcat(1, 0, fixed) testConcat(64, 0, fixed) testConcat(0, 64, fixed) testConcat(64, 64, fixed) } test("randomized fix width types") { for (i <- 0 until 20) { testConcatOnce(Random.nextInt(100), Random.nextInt(100), fixed) } } test("simple variable width types") { testConcat(0, 0, variable) testConcat(0, 1, variable) testConcat(1, 0, variable) testConcat(64, 0, variable) testConcat(0, 64, variable) testConcat(64, 64, variable) } test("randomized variable width types") { for (i <- 0 until 10) { testConcatOnce(Random.nextInt(100), Random.nextInt(100), variable) } } private def testConcat(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]): Unit = { for (i <- 0 until 10) { testConcatOnce(numFields1, numFields2, candidateTypes) } } private def testConcatOnce(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]) { info(s"schema size $numFields1, $numFields2") val schema1 = RandomDataGenerator.randomSchema(numFields1, candidateTypes) val schema2 = RandomDataGenerator.randomSchema(numFields2, candidateTypes) // Create the converters needed to convert from external row to internal row and to UnsafeRows. val internalConverter1 = CatalystTypeConverters.createToCatalystConverter(schema1) val internalConverter2 = CatalystTypeConverters.createToCatalystConverter(schema2) val converter1 = UnsafeProjection.create(schema1) val converter2 = UnsafeProjection.create(schema2) // Create the input rows, convert them into UnsafeRows. val extRow1 = RandomDataGenerator.forType(schema1, nullable = false).get.apply() val extRow2 = RandomDataGenerator.forType(schema2, nullable = false).get.apply() val row1 = converter1.apply(internalConverter1.apply(extRow1).asInstanceOf[InternalRow]) val row2 = converter2.apply(internalConverter2.apply(extRow2).asInstanceOf[InternalRow]) // Run the joiner. val mergedSchema = StructType(schema1 ++ schema2) val concater = GenerateUnsafeRowJoiner.create(schema1, schema2) val output = concater.join(row1, row2) // Test everything equals ... for (i <- mergedSchema.indices) { if (i < schema1.size) { assert(output.isNullAt(i) === row1.isNullAt(i)) if (!output.isNullAt(i)) { assert(output.get(i, mergedSchema(i).dataType) === row1.get(i, mergedSchema(i).dataType)) } } else { assert(output.isNullAt(i) === row2.isNullAt(i - schema1.size)) if (!output.isNullAt(i)) { assert(output.get(i, mergedSchema(i).dataType) === row2.get(i - schema1.size, mergedSchema(i).dataType)) } } } } }
Example 38
Source File: EncoderErrorMessageSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.encoders import scala.reflect.ClassTag import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Encoders class NonEncodable(i: Int) case class ComplexNonEncodable1(name1: NonEncodable) case class ComplexNonEncodable2(name2: ComplexNonEncodable1) case class ComplexNonEncodable3(name3: Option[NonEncodable]) case class ComplexNonEncodable4(name4: Array[NonEncodable]) case class ComplexNonEncodable5(name5: Option[Array[NonEncodable]]) class EncoderErrorMessageSuite extends SparkFunSuite { // Note: we also test error messages for encoders for private classes in JavaDatasetSuite. // That is done in Java because Scala cannot create truly private classes. test("primitive types in encoders using Kryo serialization") { intercept[UnsupportedOperationException] { Encoders.kryo[Int] } intercept[UnsupportedOperationException] { Encoders.kryo[Long] } intercept[UnsupportedOperationException] { Encoders.kryo[Char] } } test("primitive types in encoders using Java serialization") { intercept[UnsupportedOperationException] { Encoders.javaSerialization[Int] } intercept[UnsupportedOperationException] { Encoders.javaSerialization[Long] } intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] } } test("nice error message for missing encoder") { val errorMsg1 = intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable1]).getMessage assert(errorMsg1.contains( s"""root class: "${clsName[ComplexNonEncodable1]}"""")) assert(errorMsg1.contains( s"""field (class: "${clsName[NonEncodable]}", name: "name1")""")) val errorMsg2 = intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable2]).getMessage assert(errorMsg2.contains( s"""root class: "${clsName[ComplexNonEncodable2]}"""")) assert(errorMsg2.contains( s"""field (class: "${clsName[ComplexNonEncodable1]}", name: "name2")""")) assert(errorMsg1.contains( s"""field (class: "${clsName[NonEncodable]}", name: "name1")""")) val errorMsg3 = intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable3]).getMessage assert(errorMsg3.contains( s"""root class: "${clsName[ComplexNonEncodable3]}"""")) assert(errorMsg3.contains( s"""field (class: "scala.Option", name: "name3")""")) assert(errorMsg3.contains( s"""option value class: "${clsName[NonEncodable]}"""")) val errorMsg4 = intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable4]).getMessage assert(errorMsg4.contains( s"""root class: "${clsName[ComplexNonEncodable4]}"""")) assert(errorMsg4.contains( s"""field (class: "scala.Array", name: "name4")""")) assert(errorMsg4.contains( s"""array element class: "${clsName[NonEncodable]}"""")) val errorMsg5 = intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable5]).getMessage assert(errorMsg5.contains( s"""root class: "${clsName[ComplexNonEncodable5]}"""")) assert(errorMsg5.contains( s"""field (class: "scala.Option", name: "name5")""")) assert(errorMsg5.contains( s"""option value class: "scala.Array"""")) assert(errorMsg5.contains( s"""array element class: "${clsName[NonEncodable]}"""")) } private def clsName[T : ClassTag]: String = implicitly[ClassTag[T]].runtimeClass.getName }
Example 39
Source File: RuleExecutorSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.trees import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal} import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} class RuleExecutorSuite extends SparkFunSuite { object DecrementLiterals extends Rule[Expression] { def apply(e: Expression): Expression = e transform { case IntegerLiteral(i) if i > 0 => Literal(i - 1) } } test("only once") { object ApplyOnce extends RuleExecutor[Expression] { val batches = Batch("once", Once, DecrementLiterals) :: Nil } assert(ApplyOnce.execute(Literal(10)) === Literal(9)) } test("to fixed point") { object ToFixedPoint extends RuleExecutor[Expression] { val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil } assert(ToFixedPoint.execute(Literal(10)) === Literal(0)) } test("to maxIterations") { object ToFixedPoint extends RuleExecutor[Expression] { val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil } assert(ToFixedPoint.execute(Literal(100)) === Literal(90)) } }
Example 40
Source File: PartitioningSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning} class PartitioningSuite extends SparkFunSuite { test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") { val expressions = Seq(Literal(2), Literal(3)) // Consider two HashPartitionings that have the same _set_ of hash expressions but which are // created with different orderings of those expressions: val partitioningA = HashPartitioning(expressions, 100) val partitioningB = HashPartitioning(expressions.reverse, 100) // These partitionings are not considered equal: assert(partitioningA != partitioningB) // However, they both satisfy the same clustered distribution: val distribution = ClusteredDistribution(expressions) assert(partitioningA.satisfies(distribution)) assert(partitioningB.satisfies(distribution)) // These partitionings compute different hashcodes for the same input row: def computeHashCode(partitioning: HashPartitioning): Int = { val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty) hashExprProj.apply(InternalRow.empty).hashCode() } assert(computeHashCode(partitioningA) != computeHashCode(partitioningB)) // Thus, these partitionings are incompatible: assert(!partitioningA.compatibleWith(partitioningB)) assert(!partitioningB.compatibleWith(partitioningA)) assert(!partitioningA.guarantees(partitioningB)) assert(!partitioningB.guarantees(partitioningA)) // Just to be sure that we haven't cheated by having these methods always return false, // check that identical partitionings are still compatible with and guarantee each other: assert(partitioningA === partitioningA) assert(partitioningA.guarantees(partitioningA)) assert(partitioningA.compatibleWith(partitioningA)) } }
Example 41
Source File: NumberConverterSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.util import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.NumberConverter.convert import org.apache.spark.unsafe.types.UTF8String class NumberConverterSuite extends SparkFunSuite { private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = { assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) === UTF8String.fromString(expected)) } test("convert") { checkConv("3", 10, 2, "11") checkConv("-15", 10, -16, "-F") checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1") checkConv("big", 36, 16, "3A48") checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF") checkConv("11abc", 10, 16, "B") } }
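The same base conversion is exposed as the SQL/DataFrame function conv; a small sketch, assuming a live SQLContext named sqlContext (an assumption, not part of the test), with expected strings matching the checkConv cases above:

import org.apache.spark.sql.functions.{conv, lit}

// conv(value, fromBase, toBase) mirrors NumberConverter.convert:
// a negative target base yields signed output, a positive one wraps to unsigned.
val converted = sqlContext.range(1).select(
  conv(lit("3"), 10, 2).as("binary"),         // "11"
  conv(lit("-15"), 10, -16).as("signedHex"),  // "-F"
  conv(lit("big"), 36, 16).as("base36ToHex")  // "3A48"
)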
Example 42
Source File: MetadataSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.util import org.json4s.jackson.JsonMethods.parse import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.{MetadataBuilder, Metadata} class MetadataSuite extends SparkFunSuite { val baseMetadata = new MetadataBuilder() .putString("purpose", "ml") .putBoolean("isBase", true) .build() val summary = new MetadataBuilder() .putLong("numFeatures", 10L) .build() val age = new MetadataBuilder() .putString("name", "age") .putLong("index", 1L) .putBoolean("categorical", false) .putDouble("average", 45.0) .build() val gender = new MetadataBuilder() .putString("name", "gender") .putLong("index", 5) .putBoolean("categorical", true) .putStringArray("categories", Array("male", "female")) .build() val metadata = new MetadataBuilder() .withMetadata(baseMetadata) .putBoolean("isBase", false) // overwrite an existing key .putMetadata("summary", summary) .putLongArray("long[]", Array(0L, 1L)) .putDoubleArray("double[]", Array(3.0, 4.0)) .putBooleanArray("boolean[]", Array(true, false)) .putMetadataArray("features", Array(age, gender)) .build() test("metadata builder and getters") { assert(age.contains("summary") === false) assert(age.contains("index") === true) assert(age.getLong("index") === 1L) assert(age.contains("average") === true) assert(age.getDouble("average") === 45.0) assert(age.contains("categorical") === true) assert(age.getBoolean("categorical") === false) assert(age.contains("name") === true) assert(age.getString("name") === "age") assert(metadata.contains("purpose") === true) assert(metadata.getString("purpose") === "ml") assert(metadata.contains("isBase") === true) assert(metadata.getBoolean("isBase") === false) assert(metadata.contains("summary") === true) assert(metadata.getMetadata("summary") === summary) assert(metadata.contains("long[]") === true) assert(metadata.getLongArray("long[]").toSeq === Seq(0L, 1L)) assert(metadata.contains("double[]") === true) assert(metadata.getDoubleArray("double[]").toSeq === Seq(3.0, 4.0)) assert(metadata.contains("boolean[]") === true) assert(metadata.getBooleanArray("boolean[]").toSeq === Seq(true, false)) assert(gender.contains("categories") === true) assert(gender.getStringArray("categories").toSeq === Seq("male", "female")) assert(metadata.contains("features") === true) assert(metadata.getMetadataArray("features").toSeq === Seq(age, gender)) } test("metadata json conversion") { val json = metadata.json withClue("toJson must produce a valid JSON string") { parse(json) } val parsed = Metadata.fromJson(json) assert(parsed === metadata) assert(parsed.## === metadata.##) } }
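Metadata built this way is typically attached to a StructField, which is where the JSON round-trip above matters in practice; a minimal sketch using only the public types API:

import org.apache.spark.sql.types._

val ageMeta = new MetadataBuilder()
  .putString("name", "age")
  .putBoolean("categorical", false)
  .build()

// Attach the metadata to a schema field; it is carried along with the schema.
val schema = StructType(Seq(StructField("age", IntegerType, nullable = true, ageMeta)))
assert(schema("age").metadata.getString("name") == "age")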
Example 43
Source File: StringUtilsSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.util import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.StringUtils._ class StringUtilsSuite extends SparkFunSuite { test("escapeLikeRegex") { assert(escapeLikeRegex("abdef") === "(?s)\\Qa\\E\\Qb\\E\\Qd\\E\\Qe\\E\\Qf\\E") assert(escapeLikeRegex("a\\__b") === "(?s)\\Qa\\E_.\\Qb\\E") assert(escapeLikeRegex("a_%b") === "(?s)\\Qa\\E..*\\Qb\\E") assert(escapeLikeRegex("a%\\%b") === "(?s)\\Qa\\E.*%\\Qb\\E") assert(escapeLikeRegex("a%") === "(?s)\\Qa\\E.*") assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E") assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E") } }
Example 44
Source File: CatalystTypeConvertersSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.types._ class CatalystTypeConvertersSuite extends SparkFunSuite { private val simpleTypes: Seq[DataType] = Seq( StringType, DateType, BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType.USER_DEFAULT) test("null handling in rows") { val schema = StructType(simpleTypes.map(t => StructField(t.getClass.getName, t))) val convertToCatalyst = CatalystTypeConverters.createToCatalystConverter(schema) val convertToScala = CatalystTypeConverters.createToScalaConverter(schema) val scalaRow = Row.fromSeq(Seq.fill(simpleTypes.length)(null)) assert(convertToScala(convertToCatalyst(scalaRow)) === scalaRow) } test("null handling for individual values") { for (dataType <- simpleTypes) { assert(CatalystTypeConverters.createToScalaConverter(dataType)(null) === null) } } test("option handling in convertToCatalyst") { // convertToCatalyst doesn't handle unboxing from Options. This is inconsistent with // createToCatalystConverter but it may not actually matter as this is only called internally // in a handful of places where we don't expect to receive Options. assert(CatalystTypeConverters.convertToCatalyst(Some(123)) === Some(123)) } test("option handling in createToCatalystConverter") { assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123) } }
Example 45
Source File: SQLContextSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql import org.apache.spark.{SharedSparkContext, SparkFunSuite} class SQLContextSuite extends SparkFunSuite with SharedSparkContext{ test("getOrCreate instantiates SQLContext") { val sqlContext = SQLContext.getOrCreate(sc) assert(sqlContext != null, "SQLContext.getOrCreate returned null") assert(SQLContext.getOrCreate(sc).eq(sqlContext), "SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate") } test("getOrCreate return the original SQLContext") { val sqlContext = SQLContext.getOrCreate(sc) val newSession = sqlContext.newSession() assert(SQLContext.getOrCreate(sc).eq(sqlContext), "SQLContext.getOrCreate after explicitly created SQLContext did not return the context") SQLContext.setActive(newSession) assert(SQLContext.getOrCreate(sc).eq(newSession), "SQLContext.getOrCreate after explicitly setActive() did not return the active context") } test("Sessions of SQLContext") { val sqlContext = SQLContext.getOrCreate(sc) val session1 = sqlContext.newSession() val session2 = sqlContext.newSession() // all have the default configurations val key = SQLConf.SHUFFLE_PARTITIONS.key assert(session1.getConf(key) === session2.getConf(key)) session1.setConf(key, "1") session2.setConf(key, "2") assert(session1.getConf(key) === "1") assert(session2.getConf(key) === "2") // temporary table should not be shared val df = session1.range(10) df.registerTempTable("test1") assert(session1.tableNames().contains("test1")) assert(!session2.tableNames().contains("test1")) // UDF should not be shared def myadd(a: Int, b: Int): Int = a + b session1.udf.register[Int, Int, Int]("myadd", myadd) session1.sql("select myadd(1, 2)").explain() intercept[AnalysisException] { session2.sql("select myadd(1, 2)").explain() } } test("SPARK-13390: createDataFrame(java.util.List[_],Class[_]) NotSerializableException") { val rows = new java.util.ArrayList[IntJavaBean]() rows.add(new IntJavaBean(1)) val sqlContext = SQLContext.getOrCreate(sc) // Without the fix for SPARK-13390, this will throw NotSerializableException sqlContext.createDataFrame(rows, classOf[IntJavaBean]).groupBy("int").count().collect() } } class IntJavaBean(private var i: Int) extends Serializable { def getInt(): Int = i def setInt(i: Int): Unit = { this.i = i } }
Example 46
Source File: LocalNodeTest.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution.local import org.apache.spark.SparkFunSuite import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference} import org.apache.spark.sql.types.{IntegerType, StringType} class LocalNodeTest extends SparkFunSuite { protected val conf: SQLConf = new SQLConf protected val kvIntAttributes = Seq( AttributeReference("k", IntegerType)(), AttributeReference("v", IntegerType)()) protected val joinNameAttributes = Seq( AttributeReference("id1", IntegerType)(), AttributeReference("name", StringType)()) protected val joinNicknameAttributes = Seq( AttributeReference("id2", IntegerType)(), AttributeReference("nickname", StringType)()) protected def resolveExpressions( expressions: Seq[Expression], localNode: LocalNode): Seq[Expression] = { require(localNode.expressions.forall(_.resolved)) val inputMap = localNode.output.map { a => (a.name, a) }.toMap expressions.map { expression => expression.transformUp { case UnresolvedAttribute(Seq(u)) => inputMap.getOrElse(u, sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) } } } }
Example 47
Source File: CoGroupedIteratorSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper { test("basic") { val leftInput = Seq(create_row(1, "a"), create_row(1, "b"), create_row(2, "c")).iterator val rightInput = Seq(create_row(1, 2L), create_row(2, 3L), create_row(3, 4L)).iterator val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string)) val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long)) val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int)) val result = cogrouped.map { case (key, leftData, rightData) => assert(key.numFields == 1) (key.getInt(0), leftData.toSeq, rightData.toSeq) }.toSeq assert(result == (1, Seq(create_row(1, "a"), create_row(1, "b")), Seq(create_row(1, 2L))) :: (2, Seq(create_row(2, "c")), Seq(create_row(2, 3L))) :: (3, Seq.empty, Seq(create_row(3, 4L))) :: Nil ) } test("SPARK-11393: respect the fact that GroupedIterator.hasNext is not idempotent") { val leftInput = Seq(create_row(2, "a")).iterator val rightInput = Seq(create_row(1, 2L)).iterator val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string)) val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long)) val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int)) val result = cogrouped.map { case (key, leftData, rightData) => assert(key.numFields == 1) (key.getInt(0), leftData.toSeq, rightData.toSeq) }.toSeq assert(result == (1, Seq.empty, Seq(create_row(1, 2L))) :: (2, Seq(create_row(2, "a")), Seq.empty) :: Nil ) } }
Example 48
Source File: SQLExecutionSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution import java.util.Properties import scala.collection.parallel.CompositeThrowable import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.SQLContext class SQLExecutionSuite extends SparkFunSuite { test("concurrent query execution (SPARK-10548)") { // Try to reproduce the issue with the old SparkContext val conf = new SparkConf() .setMaster("local[*]") .setAppName("test") val badSparkContext = new BadSparkContext(conf) try { testConcurrentQueryExecution(badSparkContext) fail("unable to reproduce SPARK-10548") } catch { case e: IllegalArgumentException => assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY)) } finally { badSparkContext.stop() } // Verify that the issue is fixed with the latest SparkContext val goodSparkContext = new SparkContext(conf) try { testConcurrentQueryExecution(goodSparkContext) } finally { goodSparkContext.stop() } } private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) { protected[spark] override val localProperties = new InheritableThreadLocal[Properties] { override protected def childValue(parent: Properties): Properties = new Properties(parent) override protected def initialValue(): Properties = new Properties() } }
Example 49
Source File: NullableColumnBuilderSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} import org.apache.spark.sql.types._ class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType]) extends BasicColumnBuilder[JvmType](new NoopColumnStats, columnType) with NullableColumnBuilder object TestNullableColumnBuilder { def apply[JvmType](columnType: ColumnType[JvmType], initialSize: Int = 0) : TestNullableColumnBuilder[JvmType] = { val builder = new TestNullableColumnBuilder(columnType) builder.initialize(initialSize) builder } } class NullableColumnBuilderSuite extends SparkFunSuite { import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ Seq( BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10), STRUCT(StructType(StructField("a", StringType) :: Nil)), ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType))) .foreach { testNullableColumnBuilder(_) } def testNullableColumnBuilder[JvmType]( columnType: ColumnType[JvmType]): Unit = { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") val dataType = columnType.dataType val proj = UnsafeProjection.create(Array[DataType](dataType)) val converter = CatalystTypeConverters.createToScalaConverter(dataType) test(s"$typeName column builder: empty column") { val columnBuilder = TestNullableColumnBuilder(columnType) val buffer = columnBuilder.build() assertResult(0, "Wrong null count")(buffer.getInt()) assert(!buffer.hasRemaining) } test(s"$typeName column builder: buffer size auto growth") { val columnBuilder = TestNullableColumnBuilder(columnType) val randomRow = makeRandomRow(columnType) (0 until 4).foreach { _ => columnBuilder.appendFrom(proj(randomRow), 0) } val buffer = columnBuilder.build() assertResult(0, "Wrong null count")(buffer.getInt()) } test(s"$typeName column builder: null values") { val columnBuilder = TestNullableColumnBuilder(columnType) val randomRow = makeRandomRow(columnType) val nullRow = makeNullRow(1) (0 until 4).foreach { _ => columnBuilder.appendFrom(proj(randomRow), 0) columnBuilder.appendFrom(proj(nullRow), 0) } val buffer = columnBuilder.build() assertResult(4, "Wrong null count")(buffer.getInt()) // For null positions (1 to 7 by 2).foreach(assertResult(_, "Wrong null position")(buffer.getInt())) // For non-null values val actual = new GenericMutableRow(new Array[Any](1)) (0 until 4).foreach { _ => columnType.extract(buffer, actual, 0) assert(converter(actual.get(0, dataType)) === converter(randomRow.get(0, dataType)), "Extracted value didn't equal to the original one") } assert(!buffer.hasRemaining) } } }
Example 50
Source File: ColumnStatsSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.types._ class ColumnStatsSuite extends SparkFunSuite { testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0)) testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0)) testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0)) testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0)) testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0)) testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, createRow(Double.MaxValue, Double.MinValue, 0)) testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0)) testDecimalColumnStats(createRow(null, null, 0)) def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], initialStatistics: GenericInternalRow): Unit = { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { case (actual, expected) => assert(actual === expected) } } test(s"$columnStatsName: non-empty") { import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ val columnStats = columnStatsClass.newInstance() val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType]) val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) assertResult(10, "Wrong null count")(stats.values(2)) assertResult(20, "Wrong row count")(stats.values(3)) assertResult(stats.values(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum } } } def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats]( initialStatistics: GenericInternalRow): Unit = { val columnStatsName = classOf[DecimalColumnStats].getSimpleName val columnType = COMPACT_DECIMAL(15, 10) test(s"$columnStatsName: empty") { val columnStats = new DecimalColumnStats(15, 10) columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { case (actual, expected) => assert(actual === expected) } } test(s"$columnStatsName: non-empty") { import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ val columnStats = new DecimalColumnStats(15, 10) val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType]) val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) assertResult(10, "Wrong null count")(stats.values(2)) assertResult(20, "Wrong row count")(stats.values(3)) assertResult(stats.values(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum } } } }
Example 51
Source File: NullableColumnAccessorSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution.columnar import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} import org.apache.spark.sql.types._ class TestNullableColumnAccessor[JvmType]( buffer: ByteBuffer, columnType: ColumnType[JvmType]) extends BasicColumnAccessor(buffer, columnType) with NullableColumnAccessor object TestNullableColumnAccessor { def apply[JvmType](buffer: ByteBuffer, columnType: ColumnType[JvmType]) : TestNullableColumnAccessor[JvmType] = { new TestNullableColumnAccessor(buffer, columnType) } } class NullableColumnAccessorSuite extends SparkFunSuite { import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ Seq( NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10), STRUCT(StructType(StructField("a", StringType) :: Nil)), ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType))) .foreach { testNullableColumnAccessor(_) } def testNullableColumnAccessor[JvmType]( columnType: ColumnType[JvmType]): Unit = { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") val nullRow = makeNullRow(1) test(s"Nullable $typeName column accessor: empty column") { val builder = TestNullableColumnBuilder(columnType) val accessor = TestNullableColumnAccessor(builder.build(), columnType) assert(!accessor.hasNext) } test(s"Nullable $typeName column accessor: access null values") { val builder = TestNullableColumnBuilder(columnType) val randomRow = makeRandomRow(columnType) val proj = UnsafeProjection.create(Array[DataType](columnType.dataType)) (0 until 4).foreach { _ => builder.appendFrom(proj(randomRow), 0) builder.appendFrom(proj(nullRow), 0) } val accessor = TestNullableColumnAccessor(builder.build(), columnType) val row = new GenericMutableRow(1) val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType) (0 until 4).foreach { _ => assert(accessor.hasNext) accessor.extractTo(row, 0) assert(converter(row.get(0, columnType.dataType)) === converter(randomRow.get(0, columnType.dataType))) assert(accessor.hasNext) accessor.extractTo(row, 0) assert(row.isNullAt(0)) } assert(!accessor.hasNext) } } }
Example 52
Source File: GroupedIteratorSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.types.{LongType, StringType, IntegerType, StructType} class GroupedIteratorSuite extends SparkFunSuite { test("basic") { val schema = new StructType().add("i", IntegerType).add("s", StringType) val encoder = RowEncoder(schema) val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c")) val grouped = GroupedIterator(input.iterator.map(encoder.toRow), Seq('i.int.at(0)), schema.toAttributes) val result = grouped.map { case (key, data) => assert(key.numFields == 1) key.getInt(0) -> data.map(encoder.fromRow).toSeq }.toSeq assert(result == 1 -> Seq(input(0), input(1)) :: 2 -> Seq(input(2)) :: Nil) } test("group by 2 columns") { val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType) val encoder = RowEncoder(schema) val input = Seq( Row(1, 2L, "a"), Row(1, 2L, "b"), Row(1, 3L, "c"), Row(2, 1L, "d"), Row(3, 2L, "e")) val grouped = GroupedIterator(input.iterator.map(encoder.toRow), Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes) val result = grouped.map { case (key, data) => assert(key.numFields == 2) (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq) }.toSeq assert(result == (1, 2L, Seq(input(0), input(1))) :: (1, 3L, Seq(input(2))) :: (2, 1L, Seq(input(3))) :: (3, 2L, Seq(input(4))) :: Nil) } test("do nothing to the value iterator") { val schema = new StructType().add("i", IntegerType).add("s", StringType) val encoder = RowEncoder(schema) val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c")) val grouped = GroupedIterator(input.iterator.map(encoder.toRow), Seq('i.int.at(0)), schema.toAttributes) assert(grouped.length == 2) } }
Example 53
Source File: RowSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class RowSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ test("create row") { val expected = new GenericMutableRow(4) expected.setInt(0, 2147483647) expected.update(1, UTF8String.fromString("this is a string")) expected.setBoolean(2, false) expected.setNullAt(3) val actual1 = Row(2147483647, "this is a string", false, null) assert(expected.numFields === actual1.size) assert(expected.getInt(0) === actual1.getInt(0)) assert(expected.getString(1) === actual1.getString(1)) assert(expected.getBoolean(2) === actual1.getBoolean(2)) assert(expected.isNullAt(3) === actual1.isNullAt(3)) val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null)) assert(expected.numFields === actual2.size) assert(expected.getInt(0) === actual2.getInt(0)) assert(expected.getString(1) === actual2.getString(1)) assert(expected.getBoolean(2) === actual2.getBoolean(2)) assert(expected.isNullAt(3) === actual2.isNullAt(3)) } test("SpecificMutableRow.update with null") { val row = new SpecificMutableRow(Seq(IntegerType)) row(0) = null assert(row.isNullAt(0)) } test("serialize w/ kryo") { val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() val serializer = new SparkSqlSerializer(sparkContext.getConf) val instance = serializer.newInstance() val ser = instance.serialize(row) val de = instance.deserialize(ser).asInstanceOf[Row] assert(de === row) } test("get values by field name on Row created via .toDF") { val row = Seq((1, Seq(1))).toDF("a", "b").first() assert(row.getAs[Int]("a") === 1) assert(row.getAs[Seq[Int]]("b") === Seq(1)) intercept[IllegalArgumentException]{ row.getAs[Int]("c") } } test("float NaN == NaN") { val r1 = Row(Float.NaN) val r2 = Row(Float.NaN) assert(r1 === r2) } test("double NaN == NaN") { val r1 = Row(Double.NaN) val r2 = Row(Double.NaN) assert(r1 === r2) } test("equals and hashCode") { val r1 = Row("Hello") val r2 = Row("Hello") assert(r1 === r2) assert(r1.hashCode() === r2.hashCode()) val r3 = Row("World") assert(r3.hashCode() != r1.hashCode()) } }
Example 54
Source File: ResolvedDataSourceSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.sources import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.datasources.ResolvedDataSource class ResolvedDataSourceSuite extends SparkFunSuite { test("jdbc") { assert( ResolvedDataSource.lookupDataSource("jdbc") === classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) assert( ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.jdbc") === classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) assert( ResolvedDataSource.lookupDataSource("org.apache.spark.sql.jdbc") === classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) } test("json") { assert( ResolvedDataSource.lookupDataSource("json") === classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) assert( ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.json") === classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) assert( ResolvedDataSource.lookupDataSource("org.apache.spark.sql.json") === classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) } test("parquet") { assert( ResolvedDataSource.lookupDataSource("parquet") === classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) assert( ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.parquet") === classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) assert( ResolvedDataSource.lookupDataSource("org.apache.spark.sql.parquet") === classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) } test("error message for unknown data sources") { val error1 = intercept[ClassNotFoundException] { ResolvedDataSource.lookupDataSource("avro") } assert(error1.getMessage.contains("spark-packages")) val error2 = intercept[ClassNotFoundException] { ResolvedDataSource.lookupDataSource("com.databricks.spark.avro") } assert(error2.getMessage.contains("spark-packages")) val error3 = intercept[ClassNotFoundException] { ResolvedDataSource.lookupDataSource("asfdwefasdfasdf") } assert(error3.getMessage.contains("spark-packages")) } }
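In user code the same resolution happens behind DataFrameReader.format, so a short name and the fully-qualified provider class are interchangeable; a minimal sketch, assuming a SQLContext named sqlContext and a placeholder path (both assumptions):

// Both of these resolve to the same built-in JSON data source.
val viaShortName = sqlContext.read.format("json").load("/tmp/example.json")
val viaClassName = sqlContext.read
  .format("org.apache.spark.sql.execution.datasources.json.DefaultSource")
  .load("/tmp/example.json")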
Example 55
Source File: RateLimiterSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.streaming.receiver import org.apache.spark.SparkConf import org.apache.spark.SparkFunSuite class RateLimiterSuite extends SparkFunSuite { test("rate limiter initializes even without a maxRate set") { val conf = new SparkConf() val rateLimiter = new RateLimiter(conf){} rateLimiter.updateRate(105) assert(rateLimiter.getCurrentLimit == 105) } test("rate limiter updates when below maxRate") { val conf = new SparkConf().set("spark.streaming.receiver.maxRate", "110") val rateLimiter = new RateLimiter(conf){} rateLimiter.updateRate(105) assert(rateLimiter.getCurrentLimit == 105) } test("rate limiter stays below maxRate despite large updates") { val conf = new SparkConf().set("spark.streaming.receiver.maxRate", "100") val rateLimiter = new RateLimiter(conf){} rateLimiter.updateRate(105) assert(rateLimiter.getCurrentLimit === 100) } }
Example 56
Source File: FailureSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.streaming import java.io.File import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkFunSuite, Logging} import org.apache.spark.util.Utils class FailureSuite extends SparkFunSuite with BeforeAndAfter with Logging { private val batchDuration: Duration = Milliseconds(1000) private val numBatches = 30 private var directory: File = null before { directory = Utils.createTempDir() } after { if (directory != null) { Utils.deleteRecursively(directory) } StreamingContext.getActive().foreach { _.stop() } } test("multiple failures with map") { MasterFailureTest.testMap(directory.getAbsolutePath, numBatches, batchDuration) } test("multiple failures with updateStateByKey") { MasterFailureTest.testUpdateStateByKey(directory.getAbsolutePath, numBatches, batchDuration) } }
Example 57
Source File: UIUtilsSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.streaming.ui import java.util.TimeZone import java.util.concurrent.TimeUnit import org.scalatest.Matchers import org.apache.spark.SparkFunSuite class UIUtilsSuite extends SparkFunSuite with Matchers{ test("shortTimeUnitString") { assert("ns" === UIUtils.shortTimeUnitString(TimeUnit.NANOSECONDS)) assert("us" === UIUtils.shortTimeUnitString(TimeUnit.MICROSECONDS)) assert("ms" === UIUtils.shortTimeUnitString(TimeUnit.MILLISECONDS)) assert("sec" === UIUtils.shortTimeUnitString(TimeUnit.SECONDS)) assert("min" === UIUtils.shortTimeUnitString(TimeUnit.MINUTES)) assert("hrs" === UIUtils.shortTimeUnitString(TimeUnit.HOURS)) assert("days" === UIUtils.shortTimeUnitString(TimeUnit.DAYS)) } test("normalizeDuration") { verifyNormalizedTime(900, TimeUnit.MILLISECONDS, 900) verifyNormalizedTime(1.0, TimeUnit.SECONDS, 1000) verifyNormalizedTime(1.0, TimeUnit.MINUTES, 60 * 1000) verifyNormalizedTime(1.0, TimeUnit.HOURS, 60 * 60 * 1000) verifyNormalizedTime(1.0, TimeUnit.DAYS, 24 * 60 * 60 * 1000) } private def verifyNormalizedTime( expectedTime: Double, expectedUnit: TimeUnit, input: Long): Unit = { val (time, unit) = UIUtils.normalizeDuration(input) time should be (expectedTime +- 1E-6) unit should be (expectedUnit) } test("convertToTimeUnit") { verifyConvertToTimeUnit(60.0 * 1000 * 1000 * 1000, 60 * 1000, TimeUnit.NANOSECONDS) verifyConvertToTimeUnit(60.0 * 1000 * 1000, 60 * 1000, TimeUnit.MICROSECONDS) verifyConvertToTimeUnit(60 * 1000, 60 * 1000, TimeUnit.MILLISECONDS) verifyConvertToTimeUnit(60, 60 * 1000, TimeUnit.SECONDS) verifyConvertToTimeUnit(1, 60 * 1000, TimeUnit.MINUTES) verifyConvertToTimeUnit(1.0 / 60, 60 * 1000, TimeUnit.HOURS) verifyConvertToTimeUnit(1.0 / 60 / 24, 60 * 1000, TimeUnit.DAYS) } private def verifyConvertToTimeUnit( expectedTime: Double, milliseconds: Long, unit: TimeUnit): Unit = { val convertedTime = UIUtils.convertToTimeUnit(milliseconds, unit) convertedTime should be (expectedTime +- 1E-6) } test("formatBatchTime") { val tzForTest = TimeZone.getTimeZone("America/Los_Angeles") val batchTime = 1431637480452L // Thu May 14 14:04:40 PDT 2015 assert("2015/05/14 14:04:40" === UIUtils.formatBatchTime(batchTime, 1000, timezone = tzForTest)) assert("2015/05/14 14:04:40.452" === UIUtils.formatBatchTime(batchTime, 999, timezone = tzForTest)) assert("14:04:40" === UIUtils.formatBatchTime(batchTime, 1000, false, timezone = tzForTest)) assert("14:04:40.452" === UIUtils.formatBatchTime(batchTime, 999, false, timezone = tzForTest)) } }
Example 58
Source File: InputInfoTrackerSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.streaming.scheduler import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.streaming.{Time, Duration, StreamingContext} class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter { private var ssc: StreamingContext = _ before { val conf = new SparkConf().setMaster("local[2]").setAppName("DirectStreamTacker") if (ssc == null) { ssc = new StreamingContext(conf, Duration(1000)) } } after { if (ssc != null) { ssc.stop() ssc = null } } test("test report and get InputInfo from InputInfoTracker") { val inputInfoTracker = new InputInfoTracker(ssc) val streamId1 = 0 val streamId2 = 1 val time = Time(0L) val inputInfo1 = StreamInputInfo(streamId1, 100L) val inputInfo2 = StreamInputInfo(streamId2, 300L) inputInfoTracker.reportInfo(time, inputInfo1) inputInfoTracker.reportInfo(time, inputInfo2) val batchTimeToInputInfos = inputInfoTracker.getInfo(time) assert(batchTimeToInputInfos.size == 2) assert(batchTimeToInputInfos.keys === Set(streamId1, streamId2)) assert(batchTimeToInputInfos(streamId1) === inputInfo1) assert(batchTimeToInputInfos(streamId2) === inputInfo2) assert(inputInfoTracker.getInfo(time)(streamId1) === inputInfo1) } test("test cleanup InputInfo from InputInfoTracker") { val inputInfoTracker = new InputInfoTracker(ssc) val streamId1 = 0 val inputInfo1 = StreamInputInfo(streamId1, 100L) val inputInfo2 = StreamInputInfo(streamId1, 300L) inputInfoTracker.reportInfo(Time(0), inputInfo1) inputInfoTracker.reportInfo(Time(1), inputInfo2) inputInfoTracker.cleanup(Time(0)) assert(inputInfoTracker.getInfo(Time(0))(streamId1) === inputInfo1) assert(inputInfoTracker.getInfo(Time(1))(streamId1) === inputInfo2) inputInfoTracker.cleanup(Time(1)) assert(inputInfoTracker.getInfo(Time(0)).get(streamId1) === None) assert(inputInfoTracker.getInfo(Time(1))(streamId1) === inputInfo2) } }
Example 59
Source File: RecurringTimerSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.streaming.util import scala.collection.mutable import scala.concurrent.duration._ import org.scalatest.PrivateMethodTester import org.scalatest.concurrent.Eventually._ import org.apache.spark.SparkFunSuite import org.apache.spark.util.ManualClock class RecurringTimerSuite extends SparkFunSuite with PrivateMethodTester { test("basic") { val clock = new ManualClock() val results = new mutable.ArrayBuffer[Long]() with mutable.SynchronizedBuffer[Long] val timer = new RecurringTimer(clock, 100, time => { results += time }, "RecurringTimerSuite-basic") timer.start(0) eventually(timeout(10.seconds), interval(10.millis)) { assert(results === Seq(0L)) } clock.advance(100) eventually(timeout(10.seconds), interval(10.millis)) { assert(results === Seq(0L, 100L)) } clock.advance(200) eventually(timeout(10.seconds), interval(10.millis)) { assert(results === Seq(0L, 100L, 200L, 300L)) } assert(timer.stop(interruptTimer = true) === 300L) } test("SPARK-10224: call 'callback' after stopping") { val clock = new ManualClock() val results = new mutable.ArrayBuffer[Long]() with mutable.SynchronizedBuffer[Long] val timer = new RecurringTimer(clock, 100, time => { results += time }, "RecurringTimerSuite-SPARK-10224") timer.start(0) eventually(timeout(10.seconds), interval(10.millis)) { assert(results === Seq(0L)) } @volatile var lastTime = -1L // Now RecurringTimer is waiting for the next interval val thread = new Thread { override def run(): Unit = { lastTime = timer.stop(interruptTimer = false) } } thread.start() val stopped = PrivateMethod[RecurringTimer]('stopped) // Make sure the `stopped` field has been changed eventually(timeout(10.seconds), interval(10.millis)) { assert(timer.invokePrivate(stopped()) === true) } clock.advance(200) // When RecurringTimer is awake from clock.waitTillTime, it will call `callback` once. // Then it will find `stopped` is true and exit the loop, but it should call `callback` again // before exiting its internal thread. thread.join() assert(results === Seq(0L, 100L, 200L)) assert(lastTime === 200L) } }
Example 60
Source File: RateLimitedOutputStreamSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.streaming.util import java.io.ByteArrayOutputStream import java.util.concurrent.TimeUnit._ import org.apache.spark.SparkFunSuite class RateLimitedOutputStreamSuite extends SparkFunSuite { private def benchmark[U](f: => U): Long = { val start = System.nanoTime f System.nanoTime - start } test("write") { val underlying = new ByteArrayOutputStream val data = "X" * 41000 val stream = new RateLimitedOutputStream(underlying, desiredBytesPerSec = 10000) val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } val seconds = SECONDS.convert(elapsedNs, NANOSECONDS) assert(seconds >= 4, s"Seconds value ($seconds) is less than 4.") assert(seconds <= 30, s"Took more than 30 seconds ($seconds) to write data.") assert(underlying.toString("UTF-8") === data) } }
Example 61
Source File: NettyRpcHandlerSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.rpc.netty import java.net.InetSocketAddress import java.nio.ByteBuffer import io.netty.channel.Channel import org.mockito.Mockito._ import org.mockito.Matchers._ import org.apache.spark.SparkFunSuite import org.apache.spark.network.client.{TransportResponseHandler, TransportClient} import org.apache.spark.network.server.StreamManager import org.apache.spark.rpc._ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) val sm = mock(classOf[StreamManager]) when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any())) .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null)) test("receive") { val dispatcher = mock(classOf[Dispatcher]) val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm) val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) nettyRpcHandler.receive(client, null, null) verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) } test("connectionTerminated") { val dispatcher = mock(classOf[Dispatcher]) val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm) val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) nettyRpcHandler.receive(client, null, null) when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) nettyRpcHandler.connectionTerminated(client) verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) verify(dispatcher, times(1)).postToAll( RemoteProcessDisconnected(RpcAddress("localhost", 40000))) } }
Example 62
Source File: RpcAddressSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.rpc import org.apache.spark.{SparkException, SparkFunSuite} class RpcAddressSuite extends SparkFunSuite { test("hostPort") { val address = RpcAddress("1.2.3.4", 1234) assert(address.host == "1.2.3.4") assert(address.port == 1234) assert(address.hostPort == "1.2.3.4:1234") } test("fromSparkURL") { val address = RpcAddress.fromSparkURL("spark://1.2.3.4:1234") assert(address.host == "1.2.3.4") assert(address.port == 1234) } test("fromSparkURL: a typo url") { val e = intercept[SparkException] { RpcAddress.fromSparkURL("spark://1.2. 3.4:1234") } assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage) } test("fromSparkURL: invalid scheme") { val e = intercept[SparkException] { RpcAddress.fromSparkURL("invalid://1.2.3.4:1234") } assert("Invalid master URL: invalid://1.2.3.4:1234" === e.getMessage) } test("toSparkURL") { val address = RpcAddress("1.2.3.4", 1234) assert(address.toSparkURL == "spark://1.2.3.4:1234") } }
Example 63
Source File: NettyBlockTransferServiceSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.network.netty import org.apache.spark.network.BlockDataManager import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.mockito.Mockito.mock import org.scalatest._ class NettyBlockTransferServiceSuite extends SparkFunSuite with BeforeAndAfterEach with ShouldMatchers { private var service0: NettyBlockTransferService = _ private var service1: NettyBlockTransferService = _ override def afterEach() { if (service0 != null) { service0.close() service0 = null } if (service1 != null) { service1.close() service1 = null } } test("can bind to a random port") { service0 = createService(port = 0) service0.port should not be 0 } test("can bind to two random ports") { service0 = createService(port = 0) service1 = createService(port = 0) service0.port should not be service1.port } test("can bind to a specific port") { val port = 17634 service0 = createService(port) service0.port should be >= port service0.port should be <= (port + 10) // avoid testing equality in case of simultaneous tests } test("can bind to a specific port twice and the second increments") { val port = 17634 service0 = createService(port) service1 = createService(port) service0.port should be >= port service0.port should be <= (port + 10) service1.port should be (service0.port + 1) } private def createService(port: Int): NettyBlockTransferService = { val conf = new SparkConf() .set("spark.app.id", s"test-${getClass.getName}") .set("spark.blockManager.port", port.toString) val securityManager = new SecurityManager(conf) val blockDataManager = mock(classOf[BlockDataManager]) val service = new NettyBlockTransferService(conf, securityManager, numCores = 1) service.init(blockDataManager) service } }
Example 64
Source File: PythonBroadcastSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.api.python import scala.io.Source import java.io.{PrintWriter, File} import org.scalatest.Matchers import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.Utils // This test suite uses SharedSparkContext because we need a SparkEnv in order to deserialize // a PythonBroadcast: class PythonBroadcastSuite extends SparkFunSuite with Matchers with SharedSparkContext { test("PythonBroadcast can be serialized with Kryo (SPARK-4882)") { val tempDir = Utils.createTempDir() val broadcastedString = "Hello, world!" def assertBroadcastIsValid(broadcast: PythonBroadcast): Unit = { val source = Source.fromFile(broadcast.path) val contents = source.mkString source.close() contents should be (broadcastedString) } try { val broadcastDataFile: File = { val file = new File(tempDir, "broadcastData") val printWriter = new PrintWriter(file) printWriter.write(broadcastedString) printWriter.close() file } val broadcast = new PythonBroadcast(broadcastDataFile.getAbsolutePath) assertBroadcastIsValid(broadcast) val conf = new SparkConf().set("spark.kryo.registrationRequired", "true") val deserializedBroadcast = Utils.clone[PythonBroadcast](broadcast, new KryoSerializer(conf).newInstance()) assertBroadcastIsValid(deserializedBroadcast) } finally { Utils.deleteRecursively(tempDir) } } }
Example 65
Source File: PythonRDDSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.api.python import java.io.{ByteArrayOutputStream, DataOutputStream} import org.apache.spark.SparkFunSuite class PythonRDDSuite extends SparkFunSuite { test("Writing large strings to the worker") { val input: List[String] = List("a"*100000) val buffer = new DataOutputStream(new ByteArrayOutputStream) PythonRDD.writeIteratorToStream(input.iterator, buffer) } test("Handle nulls gracefully") { val buffer = new DataOutputStream(new ByteArrayOutputStream) // Should not have NPE when write an Iterator with null in it // The correctness will be tested in Python PythonRDD.writeIteratorToStream(Iterator("a", null), buffer) PythonRDD.writeIteratorToStream(Iterator(null, "a"), buffer) PythonRDD.writeIteratorToStream(Iterator("a".getBytes, null), buffer) PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes), buffer) PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), buffer) PythonRDD.writeIteratorToStream( Iterator((null, null), ("a".getBytes, null), (null, "b".getBytes)), buffer) } }
Example 66
Source File: SerDeUtilSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.api.python import org.apache.spark.{SharedSparkContext, SparkFunSuite} class SerDeUtilSuite extends SparkFunSuite with SharedSparkContext { test("Converting an empty pair RDD to python does not throw an exception (SPARK-5441)") { val emptyRdd = sc.makeRDD(Seq[(Any, Any)]()) SerDeUtil.pairRDDToPython(emptyRdd, 10) } test("Converting an empty python RDD to pair RDD does not throw an exception (SPARK-5441)") { val emptyRdd = sc.makeRDD(Seq[(Any, Any)]()) val javaRdd = emptyRdd.toJavaRDD() val pythonRdd = SerDeUtil.javaToPython(javaRdd) SerDeUtil.pythonToPairRDD(pythonRdd, false) } }
Example 67
Source File: PythonRunnerSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy import org.apache.spark.SparkFunSuite import org.apache.spark.util.Utils class PythonRunnerSuite extends SparkFunSuite { // Test formatting a single path to be added to the PYTHONPATH test("format path") { assert(PythonRunner.formatPath("spark.py") === "spark.py") assert(PythonRunner.formatPath("file:/spark.py") === "/spark.py") assert(PythonRunner.formatPath("file:///spark.py") === "/spark.py") assert(PythonRunner.formatPath("local:/spark.py") === "/spark.py") assert(PythonRunner.formatPath("local:///spark.py") === "/spark.py") if (Utils.isWindows) { assert(PythonRunner.formatPath("file:/C:/a/b/spark.py", testWindows = true) === "C:/a/b/spark.py") assert(PythonRunner.formatPath("C:\\a\\b\\spark.py", testWindows = true) === "C:/a/b/spark.py") assert(PythonRunner.formatPath("C:\\a b\\spark.py", testWindows = true) === "C:/a b/spark.py") } intercept[IllegalArgumentException] { PythonRunner.formatPath("one:two") } intercept[IllegalArgumentException] { PythonRunner.formatPath("hdfs:s3:xtremeFS") } intercept[IllegalArgumentException] { PythonRunner.formatPath("hdfs:/path/to/some.py") } } // Test formatting multiple comma-separated paths to be added to the PYTHONPATH test("format paths") { assert(PythonRunner.formatPaths("spark.py") === Array("spark.py")) assert(PythonRunner.formatPaths("file:/spark.py") === Array("/spark.py")) assert(PythonRunner.formatPaths("file:/app.py,local:/spark.py") === Array("/app.py", "/spark.py")) assert(PythonRunner.formatPaths("me.py,file:/you.py,local:/we.py") === Array("me.py", "/you.py", "/we.py")) if (Utils.isWindows) { assert(PythonRunner.formatPaths("C:\\a\\b\\spark.py", testWindows = true) === Array("C:/a/b/spark.py")) assert(PythonRunner.formatPaths("C:\\free.py,pie.py", testWindows = true) === Array("C:/free.py", "pie.py")) assert(PythonRunner.formatPaths("lovely.py,C:\\free.py,file:/d:/fry.py", testWindows = true) === Array("lovely.py", "C:/free.py", "d:/fry.py")) } intercept[IllegalArgumentException] { PythonRunner.formatPaths("one:two,three") } intercept[IllegalArgumentException] { PythonRunner.formatPaths("two,three,four:five:six") } intercept[IllegalArgumentException] { PythonRunner.formatPaths("hdfs:/some.py,foo.py") } intercept[IllegalArgumentException] { PythonRunner.formatPaths("foo.py,hdfs:/some.py") } } }
Example 68
Source File: LogUrlsStandaloneSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy import java.net.URL import scala.collection.mutable import scala.io.Source import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.{SparkListenerExecutorAdded, SparkListener} import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.util.SparkConfWithEnv class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { private val WAIT_TIMEOUT_MILLIS = 10000 test("verify that correct log urls get propagated from workers") { sc = new SparkContext("local-cluster[2,1,1024]", "test") val listener = new SaveExecutorInfo sc.addSparkListener(listener) // Trigger a job so that executors get added sc.parallelize(1 to 100, 4).map(_.toString).count() sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.addedExecutorInfos.values.foreach { info => assert(info.logUrlMap.nonEmpty) // Browse to each URL to check that it's valid info.logUrlMap.foreach { case (logType, logUrl) => val html = Source.fromURL(logUrl).mkString assert(html.contains(s"$logType log page")) } } } test("verify that log urls reflect SPARK_PUBLIC_DNS (SPARK-6175)") { val SPARK_PUBLIC_DNS = "public_dns" val conf = new SparkConfWithEnv(Map("SPARK_PUBLIC_DNS" -> SPARK_PUBLIC_DNS)).set( "spark.extraListeners", classOf[SaveExecutorInfo].getName) sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) // Trigger a job so that executors get added sc.parallelize(1 to 100, 4).map(_.toString).count() sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo] assert(listeners.size === 1) val listener = listeners(0) listener.addedExecutorInfos.values.foreach { info => assert(info.logUrlMap.nonEmpty) info.logUrlMap.values.foreach { logUrl => assert(new URL(logUrl).getHost === SPARK_PUBLIC_DNS) } } } } private[spark] class SaveExecutorInfo extends SparkListener { val addedExecutorInfos = mutable.Map[String, ExecutorInfo]() override def onExecutorAdded(executor: SparkListenerExecutorAdded) { addedExecutorInfos(executor.executorId) = executor.executorInfo } }
Example 69
Source File: ClientSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy import org.scalatest.Matchers import org.apache.spark.SparkFunSuite class ClientSuite extends SparkFunSuite with Matchers { test("correctly validates driver jar URL's") { ClientArguments.isValidJarUrl("http://someHost:8080/foo.jar") should be (true) ClientArguments.isValidJarUrl("https://someHost:8080/foo.jar") should be (true) // file scheme with authority and path is valid. ClientArguments.isValidJarUrl("file://somehost/path/to/a/jarFile.jar") should be (true) // file scheme without path is not valid. // In this case, jarFile.jar is recognized as authority. ClientArguments.isValidJarUrl("file://jarFile.jar") should be (false) // file scheme without authority but with triple slash is valid. ClientArguments.isValidJarUrl("file:///some/path/to/a/jarFile.jar") should be (true) ClientArguments.isValidJarUrl("hdfs://someHost:1234/foo.jar") should be (true) ClientArguments.isValidJarUrl("hdfs://someHost:1234/foo") should be (false) ClientArguments.isValidJarUrl("/missing/a/protocol/jarfile.jar") should be (false) ClientArguments.isValidJarUrl("not-even-a-path.jar") should be (false) // This URI doesn't have authority and path. ClientArguments.isValidJarUrl("hdfs:someHost:1234/jarfile.jar") should be (false) // Invalid syntax. ClientArguments.isValidJarUrl("hdfs:") should be (false) } }
Example 70
Source File: MasterWebUISuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.master.ui import java.util.Date import scala.io.Source import scala.language.postfixOps import org.json4s.jackson.JsonMethods._ import org.json4s.JsonAST.{JNothing, JString, JInt} import org.mockito.Mockito.{mock, when} import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SecurityManager, SparkFunSuite} import org.apache.spark.deploy.DeployMessages.MasterStateResponse import org.apache.spark.deploy.DeployTestUtils._ import org.apache.spark.deploy.master._ import org.apache.spark.rpc.RpcEnv class MasterWebUISuite extends SparkFunSuite with BeforeAndAfter { val masterPage = mock(classOf[MasterPage]) val master = { val conf = new SparkConf val securityMgr = new SecurityManager(conf) val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityMgr) val master = new Master(rpcEnv, rpcEnv.address, 0, securityMgr, conf) master } val masterWebUI = new MasterWebUI(master, 0, customMasterPage = Some(masterPage)) before { masterWebUI.bind() } after { masterWebUI.stop() } test("list applications") { val worker = createWorkerInfo() val appDesc = createAppDesc() // use new start date so it isn't filtered by UI val activeApp = new ApplicationInfo( new Date().getTime, "id", appDesc, new Date(), null, Int.MaxValue) activeApp.addExecutor(worker, 2) val workers = Array[WorkerInfo](worker) val activeApps = Array(activeApp) val completedApps = Array[ApplicationInfo]() val activeDrivers = Array[DriverInfo]() val completedDrivers = Array[DriverInfo]() val stateResponse = new MasterStateResponse( "host", 8080, None, workers, activeApps, completedApps, activeDrivers, completedDrivers, RecoveryState.ALIVE) when(masterPage.getMasterState).thenReturn(stateResponse) val resultJson = Source.fromURL( s"http://localhost:${masterWebUI.boundPort}/api/v1/applications") .mkString val parsedJson = parse(resultJson) val firstApp = parsedJson(0) assert(firstApp \ "id" === JString(activeApp.id)) assert(firstApp \ "name" === JString(activeApp.desc.name)) assert(firstApp \ "coresGranted" === JInt(2)) assert(firstApp \ "maxCores" === JInt(4)) assert(firstApp \ "memoryPerExecutorMB" === JInt(1234)) assert(firstApp \ "coresPerExecutor" === JNothing) } }
Example 71
Source File: WorkerWatcherSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.worker import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.SecurityManager import org.apache.spark.rpc.{RpcAddress, RpcEnv} class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher shuts down on valid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234)) assert(workerWatcher.isShutDown) rpcEnv.shutdown() } test("WorkerWatcher stays alive on invalid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") val otherRpcAddress = RpcAddress("4.3.2.1", 1234) val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) workerWatcher.onDisconnected(otherRpcAddress) assert(!workerWatcher.isShutDown) rpcEnv.shutdown() } }
Example 72
Source File: WorkerArgumentsTest.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.worker import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.SparkConfWithEnv class WorkerArgumentsTest extends SparkFunSuite { test("Memory can't be set to 0 when cmd line args leave off M or G") { val conf = new SparkConf val args = Array("-m", "10000", "spark://localhost:0000 ") intercept[IllegalStateException] { new WorkerArguments(args, conf) } } test("Memory can't be set to 0 when SPARK_WORKER_MEMORY env property leaves off M or G") { val args = Array("spark://localhost:0000 ") val conf = new SparkConfWithEnv(Map("SPARK_WORKER_MEMORY" -> "50000")) intercept[IllegalStateException] { new WorkerArguments(args, conf) } } test("Memory correctly set when SPARK_WORKER_MEMORY env property appends G") { val args = Array("spark://localhost:0000 ") val conf = new SparkConfWithEnv(Map("SPARK_WORKER_MEMORY" -> "5G")) val workerArgs = new WorkerArguments(args, conf) assert(workerArgs.memory === 5120) } test("Memory correctly set from args with M appended to memory value") { val conf = new SparkConf val args = Array("-m", "10000M", "spark://localhost:0000 ") val workerArgs = new WorkerArguments(args, conf) assert(workerArgs.memory === 10000) } }
Example 73
Source File: ExecutorRunnerTest.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.worker import java.io.File import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} class ExecutorRunnerTest extends SparkFunSuite { test("command includes appId") { val appId = "12345-worker321-9876" val conf = new SparkConf val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) val appDesc = new ApplicationDescription("app name", Some(8), 500, Command("foo", Seq(appId), Map(), Seq(), Seq(), Seq()), "appUiUrl") val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", 123, "publicAddr", new File(sparkHome), new File("ooga"), "blah", conf, Seq("localDir"), ExecutorState.RUNNING) val builder = CommandUtils.buildProcessBuilder( appDesc.command, new SecurityManager(conf), 512, sparkHome, er.substituteVariables) val builderCommand = builder.command() assert(builderCommand.get(builderCommand.size() - 1) === appId) } }
Example 74
Source File: CommandUtilsSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.worker import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.Command import org.apache.spark.util.Utils import org.scalatest.{Matchers, PrivateMethodTester} class CommandUtilsSuite extends SparkFunSuite with Matchers with PrivateMethodTester { test("set libraryPath correctly") { val appId = "12345-worker321-9876" val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) val cmd = new Command("mainClass", Seq(), Map(), Seq(), Seq("libraryPathToB"), Seq()) val builder = CommandUtils.buildProcessBuilder( cmd, new SecurityManager(new SparkConf), 512, sparkHome, t => t) val libraryPath = Utils.libraryPathEnvName val env = builder.environment env.keySet should contain(libraryPath) assert(env.get(libraryPath).startsWith("libraryPathToB")) } test("auth secret shouldn't appear in java opts") { val buildLocalCommand = PrivateMethod[Command]('buildLocalCommand) val conf = new SparkConf val secret = "This is the secret sauce" // set auth secret conf.set(SecurityManager.SPARK_AUTH_SECRET_CONF, secret) val command = new Command("mainClass", Seq(), Map(), Seq(), Seq("lib"), Seq("-D" + SecurityManager.SPARK_AUTH_SECRET_CONF + "=" + secret)) // auth is not set var cmd = CommandUtils invokePrivate buildLocalCommand( command, new SecurityManager(conf), (t: String) => t, Seq(), Map()) assert(!cmd.javaOpts.exists(_.startsWith("-D" + SecurityManager.SPARK_AUTH_SECRET_CONF))) assert(!cmd.environment.contains(SecurityManager.ENV_AUTH_SECRET)) // auth is set to false conf.set(SecurityManager.SPARK_AUTH_CONF, "false") cmd = CommandUtils invokePrivate buildLocalCommand( command, new SecurityManager(conf), (t: String) => t, Seq(), Map()) assert(!cmd.javaOpts.exists(_.startsWith("-D" + SecurityManager.SPARK_AUTH_SECRET_CONF))) assert(!cmd.environment.contains(SecurityManager.ENV_AUTH_SECRET)) // auth is set to true conf.set(SecurityManager.SPARK_AUTH_CONF, "true") cmd = CommandUtils invokePrivate buildLocalCommand( command, new SecurityManager(conf), (t: String) => t, Seq(), Map()) assert(!cmd.javaOpts.exists(_.startsWith("-D" + SecurityManager.SPARK_AUTH_SECRET_CONF))) assert(cmd.environment(SecurityManager.ENV_AUTH_SECRET) === secret) } }
Example 75
Source File: LogPageSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.worker.ui import java.io.{File, FileWriter} import org.mockito.Mockito.{mock, when} import org.scalatest.PrivateMethodTester import org.apache.spark.SparkFunSuite class LogPageSuite extends SparkFunSuite with PrivateMethodTester { test("get logs simple") { val webui = mock(classOf[WorkerWebUI]) val tmpDir = new File(sys.props("java.io.tmpdir")) val workDir = new File(tmpDir, "work-dir") workDir.mkdir() when(webui.workDir).thenReturn(workDir) val logPage = new LogPage(webui) // Prepare some fake log files to read later val out = "some stdout here" val err = "some stderr here" val tmpOut = new File(workDir, "stdout") val tmpErr = new File(workDir, "stderr") val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory val tmpOutBad = new File(tmpDir, "stdout") val tmpRand = new File(workDir, "random") write(tmpOut, out) write(tmpErr, err) write(tmpOutBad, out) write(tmpErrBad, err) write(tmpRand, "1 6 4 5 2 7 8") // Get the logs. All log types other than "stderr" or "stdout" will be rejected val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog) val (stdout, _, _, _) = logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100) val (stderr, _, _, _) = logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100) val (error1, _, _, _) = logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100) val (error2, _, _, _) = logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100) // These files exist, but live outside the working directory val (error3, _, _, _) = logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100) val (error4, _, _, _) = logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100) assert(stdout === out) assert(stderr === err) assert(error1.startsWith("Error: Log type must be one of ")) assert(error2.startsWith("Error: Log type must be one of ")) assert(error3.startsWith("Error: invalid log directory")) assert(error4.startsWith("Error: invalid log directory")) } private def write(f: File, s: String): Unit = { val writer = new FileWriter(f) try { writer.write(s) } finally { writer.close() } } }
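RecurringTimerSuite, CommandUtilsSuite and LogPageSuite all reach into private members with ScalaTest's PrivateMethodTester. A minimal sketch of that pattern, with a hypothetical Doubler class standing in for the Spark classes under test:

import org.scalatest.PrivateMethodTester

// Hypothetical class, purely for illustration.
class Doubler {
  private def double(x: Int): Int = x * 2
}

object PrivateMethodSketch extends PrivateMethodTester {
  def main(args: Array[String]): Unit = {
    // Name the private method by symbol, then invoke it on an instance.
    val double = PrivateMethod[Int]('double)
    val doubler = new Doubler
    val result = doubler invokePrivate double(21)
    assert(result == 42)
  }
}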
Example 76
Source File: PagedTableSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.ui import scala.xml.Node import org.apache.spark.SparkFunSuite class PagedDataSourceSuite extends SparkFunSuite { test("basic") { val dataSource1 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) assert(dataSource1.pageData(1) === PageData(3, (1 to 2))) val dataSource2 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) assert(dataSource2.pageData(2) === PageData(3, (3 to 4))) val dataSource3 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) assert(dataSource3.pageData(3) === PageData(3, Seq(5))) val dataSource4 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) val e1 = intercept[IndexOutOfBoundsException] { dataSource4.pageData(4) } assert(e1.getMessage === "Page 4 is out of range. Please select a page number between 1 and 3.") val dataSource5 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) val e2 = intercept[IndexOutOfBoundsException] { dataSource5.pageData(0) } assert(e2.getMessage === "Page 0 is out of range. Please select a page number between 1 and 3.") } } class PagedTableSuite extends SparkFunSuite { test("pageNavigation") { // Create a fake PagedTable to test pageNavigation val pagedTable = new PagedTable[Int] { override def tableId: String = "" override def tableCssClass: String = "" override def dataSource: PagedDataSource[Int] = null override def pageLink(page: Int): String = page.toString override def headers: Seq[Node] = Nil override def row(t: Int): Seq[Node] = Nil override def goButtonJavascriptFunction: (String, String) = ("", "") } assert(pagedTable.pageNavigation(1, 10, 1) === Nil) assert( (pagedTable.pageNavigation(1, 10, 2).head \\ "li").map(_.text.trim) === Seq("1", "2", ">")) assert( (pagedTable.pageNavigation(2, 10, 2).head \\ "li").map(_.text.trim) === Seq("<", "1", "2")) assert((pagedTable.pageNavigation(1, 10, 100).head \\ "li").map(_.text.trim) === (1 to 10).map(_.toString) ++ Seq(">", ">>")) assert((pagedTable.pageNavigation(2, 10, 100).head \\ "li").map(_.text.trim) === Seq("<") ++ (1 to 10).map(_.toString) ++ Seq(">", ">>")) assert((pagedTable.pageNavigation(100, 10, 100).head \\ "li").map(_.text.trim) === Seq("<<", "<") ++ (91 to 100).map(_.toString)) assert((pagedTable.pageNavigation(99, 10, 100).head \\ "li").map(_.text.trim) === Seq("<<", "<") ++ (91 to 100).map(_.toString) ++ Seq(">")) assert((pagedTable.pageNavigation(11, 10, 100).head \\ "li").map(_.text.trim) === Seq("<<", "<") ++ (11 to 20).map(_.toString) ++ Seq(">", ">>")) assert((pagedTable.pageNavigation(93, 10, 97).head \\ "li").map(_.text.trim) === Seq("<<", "<") ++ (91 to 97).map(_.toString) ++ Seq(">")) } } private[spark] class SeqPagedDataSource[T](seq: Seq[T], pageSize: Int) extends PagedDataSource[T](pageSize) { override protected def dataSize: Int = seq.size override protected def sliceData(from: Int, to: Int): Seq[T] = seq.slice(from, to) }
Example 77
Source File: UISuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.ui import java.net.ServerSocket import scala.io.Source import scala.util.{Failure, Success, Try} import org.eclipse.jetty.servlet.ServletContextHandler import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.LocalSparkContext._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} class UISuite extends SparkFunSuite { private def newSparkContext(): SparkContext = { val conf = new SparkConf() .setMaster("local") .setAppName("test") .set("spark.ui.enabled", "true") val sc = new SparkContext(conf) assert(sc.ui.isDefined) sc } ignore("basic ui visibility") { withSpark(newSparkContext()) { sc => // test if the ui is visible, and all the expected tabs are visible eventually(timeout(10 seconds), interval(50 milliseconds)) { val html = Source.fromURL(sc.ui.get.appUIAddress).mkString assert(!html.contains("random data that should not be present")) assert(html.toLowerCase.contains("stages")) assert(html.toLowerCase.contains("storage")) assert(html.toLowerCase.contains("environment")) assert(html.toLowerCase.contains("executors")) } } } ignore("visibility at localhost:4040") { withSpark(newSparkContext()) { sc => // test if visible from http://localhost:4040 eventually(timeout(10 seconds), interval(50 milliseconds)) { val html = Source.fromURL("http://localhost:4040").mkString assert(html.toLowerCase.contains("stages")) } } } test("jetty selects different port under contention") { val server = new ServerSocket(0) val startPort = server.getLocalPort val serverInfo1 = JettyUtils.startJettyServer( "0.0.0.0", startPort, Seq[ServletContextHandler](), new SparkConf) val serverInfo2 = JettyUtils.startJettyServer( "0.0.0.0", startPort, Seq[ServletContextHandler](), new SparkConf) // Allow some wiggle room in case ports on the machine are under contention val boundPort1 = serverInfo1.boundPort val boundPort2 = serverInfo2.boundPort assert(boundPort1 != startPort) assert(boundPort2 != startPort) assert(boundPort1 != boundPort2) serverInfo1.server.stop() serverInfo2.server.stop() server.close() } test("jetty binds to port 0 correctly") { val serverInfo = JettyUtils.startJettyServer( "0.0.0.0", 0, Seq[ServletContextHandler](), new SparkConf) val server = serverInfo.server val boundPort = serverInfo.boundPort assert(server.getState === "STARTED") assert(boundPort != 0) Try { new ServerSocket(boundPort) } match { case Success(s) => fail("Port %s doesn't seem used by jetty server".format(boundPort)) case Failure(e) => } } test("verify appUIAddress contains the scheme") { withSpark(newSparkContext()) { sc => val ui = sc.ui.get val uiAddress = ui.appUIAddress val uiHostPort = ui.appUIHostPort assert(uiAddress.equals("http://" + uiHostPort)) } } test("verify appUIAddress contains the port") { withSpark(newSparkContext()) { sc => val ui = sc.ui.get val splitUIAddress = ui.appUIAddress.split(':') val boundPort = ui.boundPort assert(splitUIAddress(2).toInt == boundPort) } } }
Example 78
Source File: UIUtilsSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.ui import scala.xml.Elem import org.apache.spark.SparkFunSuite class UIUtilsSuite extends SparkFunSuite { import UIUtils._ test("makeDescription") { verify( """test <a href="/link"> text </a>""", <span class="description-input">test <a href="/link"> text </a></span>, "Correctly formatted text with only anchors and relative links should generate HTML" ) verify( """test <a href="/link" text </a>""", <span class="description-input">{"""test <a href="/link" text </a>"""}</span>, "Badly formatted text should make the description be treated as a streaming instead of HTML" ) verify( """test <a href="link"> text </a>""", <span class="description-input">{"""test <a href="link"> text </a>"""}</span>, "Non-relative links should make the description be treated as a string instead of HTML" ) verify( """test<a><img></img></a>""", <span class="description-input">{"""test<a><img></img></a>"""}</span>, "Non-anchor elements should make the description be treated as a string instead of HTML" ) verify( """test <a href="/link"> text </a>""", <span class="description-input">test <a href="base/link"> text </a></span>, baseUrl = "base", errorMsg = "Base URL should be prepended to html links" ) } test("SPARK-11906: Progress bar should not overflow because of speculative tasks") { val generated = makeProgressBar(2, 3, 0, 0, 4).head.child.filter(_.label == "div") val expected = Seq( <div class="bar bar-completed" style="width: 75.0%"></div>, <div class="bar bar-running" style="width: 25.0%"></div> ) assert(generated.sameElements(expected), s"\nRunning progress bar should round down\n\nExpected:\n$expected\nGenerated:\n$generated") } test("decodeURLParameter (SPARK-12708: Sorting task error in Stages Page when yarn mode.)") { val encoded1 = "%252F" val decoded1 = "/" val encoded2 = "%253Cdriver%253E" val decoded2 = "<driver>" assert(decoded1 === decodeURLParameter(encoded1)) assert(decoded2 === decodeURLParameter(encoded2)) // verify that no affect to decoded URL. assert(decoded1 === decodeURLParameter(decoded1)) assert(decoded2 === decodeURLParameter(decoded2)) } private def verify( desc: String, expected: Elem, errorMsg: String = "", baseUrl: String = ""): Unit = { val generated = makeDescription(desc, baseUrl) assert(generated.sameElements(expected), s"\n$errorMsg\n\nExpected:\n$expected\nGenerated:\n$generated") } }
Example 79
Source File: GenericAvroSerializerSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.serializer import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import java.nio.ByteBuffer import com.esotericsoftware.kryo.io.{Output, Input} import org.apache.avro.{SchemaBuilder, Schema} import org.apache.avro.generic.GenericData.Record import org.apache.spark.{SparkFunSuite, SharedSparkContext} class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") val schema : Schema = SchemaBuilder .record("testRecord").fields() .requiredString("data") .endRecord() val record = new Record(schema) record.put("data", "test data") test("schema compression and decompression") { val genericSer = new GenericAvroSerializer(conf.getAvroSchema) assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema)))) } test("record serialization and deserialization") { val genericSer = new GenericAvroSerializer(conf.getAvroSchema) val outputStream = new ByteArrayOutputStream() val output = new Output(outputStream) genericSer.serializeDatum(record, output) output.flush() output.close() val input = new Input(new ByteArrayInputStream(outputStream.toByteArray)) assert(genericSer.deserializeDatum(input) === record) } test("uses schema fingerprint to decrease message size") { val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema) val output = new Output(new ByteArrayOutputStream()) val beginningNormalPosition = output.total() genericSerFull.serializeDatum(record, output) output.flush() val normalLength = output.total - beginningNormalPosition conf.registerAvroSchemas(schema) val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema) val beginningFingerprintPosition = output.total() genericSerFinger.serializeDatum(record, output) val fingerprintLength = output.total - beginningFingerprintPosition assert(fingerprintLength < normalLength) } test("caches previously seen schemas") { val genericSer = new GenericAvroSerializer(conf.getAvroSchema) val compressedSchema = genericSer.compress(schema) val decompressedSchema = genericSer.decompress(ByteBuffer.wrap(compressedSchema)) assert(compressedSchema.eq(genericSer.compress(schema))) assert(decompressedSchema.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema)))) } }
Example 80
Source File: KryoSerializerResizableOutputSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.serializer import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.SparkContext import org.apache.spark.LocalSparkContext import org.apache.spark.SparkException class KryoSerializerResizableOutputSuite extends SparkFunSuite { // trial and error showed this will not serialize with 1mb buffer val x = (1 to 400000).toArray test("kryo without resizable output buffer should fail on large array") { val conf = new SparkConf(false) conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryoserializer.buffer", "1m") conf.set("spark.kryoserializer.buffer.max", "1m") val sc = new SparkContext("local", "test", conf) intercept[SparkException](sc.parallelize(x).collect()) LocalSparkContext.stop(sc) } test("kryo with resizable output buffer should succeed on large array") { val conf = new SparkConf(false) conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryoserializer.buffer", "1m") conf.set("spark.kryoserializer.buffer.max", "2m") val sc = new SparkContext("local", "test", conf) assert(sc.parallelize(x).collect() === x) LocalSparkContext.stop(sc) } }
Example 81
Source File: JavaSerializerSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.serializer import org.apache.spark.{SparkConf, SparkFunSuite} class JavaSerializerSuite extends SparkFunSuite { test("JavaSerializer instances are serializable") { val serializer = new JavaSerializer(new SparkConf()) val instance = serializer.newInstance() instance.deserialize[JavaSerializer](instance.serialize(serializer)) } test("Deserialize object containing a primitive Class as attribute") { val serializer = new JavaSerializer(new SparkConf()) val instance = serializer.newInstance() instance.deserialize[JavaSerializer](instance.serialize(new ContainsPrimitiveClass())) } } private class ContainsPrimitiveClass extends Serializable { val intClass = classOf[Int] val longClass = classOf[Long] val shortClass = classOf[Short] val charClass = classOf[Char] val doubleClass = classOf[Double] val floatClass = classOf[Float] val booleanClass = classOf[Boolean] val byteClass = classOf[Byte] val voidClass = classOf[Void] }
Example 82
Source File: ProactiveClosureSerializationSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.serializer import org.apache.spark.{SharedSparkContext, SparkException, SparkFunSuite} import org.apache.spark.rdd.RDD class UnserializableClass { def op[T](x: T): String = x.toString def pred[T](x: T): Boolean = x.toString.length % 2 == 0 } class ProactiveClosureSerializationSuite extends SparkFunSuite with SharedSparkContext { def fixture: (RDD[String], UnserializableClass) = { (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass) } test("throws expected serialization exceptions on actions") { val (data, uc) = fixture val ex = intercept[SparkException] { data.map(uc.op(_)).count() } assert(ex.getMessage.contains("Task not serializable")) } // There is probably a cleaner way to eliminate boilerplate here, but we're // iterating over a map from transformation names to functions that perform that // transformation on a given RDD, creating one test case for each for (transformation <- Map("map" -> xmap _, "flatMap" -> xflatMap _, "filter" -> xfilter _, "mapPartitions" -> xmapPartitions _, "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _)) { val (name, xf) = transformation test(s"$name transformations throw proactive serialization exceptions") { val (data, uc) = fixture val ex = intercept[SparkException] { xf(data, uc) } assert(ex.getMessage.contains("Task not serializable"), s"RDD.$name doesn't proactively throw NotSerializableException") } } private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] = x.map(y => uc.op(y)) private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] = x.flatMap(y => Seq(uc.op(y))) private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] = x.filter(y => uc.pred(y)) private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitions(_.map(y => uc.op(y))) private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y))) }
Example 83
Source File: CoarseGrainedSchedulerBackendSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.scheduler import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.util.{SerializableBuffer, AkkaUtils} class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext { test("serialized task larger than akka frame size") { val conf = new SparkConf conf.set("spark.akka.frameSize", "1") conf.set("spark.default.parallelism", "1") sc = new SparkContext("local-cluster[2, 1, 1024]", "test", conf) val frameSize = AkkaUtils.maxFrameSizeBytes(sc.conf) val buffer = new SerializableBuffer(java.nio.ByteBuffer.allocate(2 * frameSize)) val larger = sc.parallelize(Seq(buffer)) val thrown = intercept[SparkException] { larger.collect() } assert(thrown.getMessage.contains("using broadcast variables for large values")) val smaller = sc.parallelize(1 to 4).collect() assert(smaller.size === 4) } }
Example 84
Source File: MesosClusterSchedulerSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.scheduler.mesos import java.util.Date import org.scalatest.mock.MockitoSugar import org.apache.spark.deploy.Command import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.scheduler.cluster.mesos._ import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite} class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { private val command = new Command("mainClass", Seq("arg"), null, null, null, null) test("can queue drivers") { val conf = new SparkConf() conf.setMaster("mesos://localhost:5050") conf.setAppName("spark mesos") val scheduler = new MesosClusterScheduler( new BlackHoleMesosClusterPersistenceEngineFactory, conf) { override def start(): Unit = { ready = true } } scheduler.start() val response = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 1000, 1, true, command, Map[String, String](), "s1", new Date())) assert(response.success) val response2 = scheduler.submitDriver(new MesosDriverDescription( "d1", "jar", 1000, 1, true, command, Map[String, String](), "s2", new Date())) assert(response2.success) val state = scheduler.getSchedulerState() val queuedDrivers = state.queuedDrivers.toList assert(queuedDrivers(0).submissionId == response.submissionId) assert(queuedDrivers(1).submissionId == response2.submissionId) } test("can kill queued drivers") { val conf = new SparkConf() conf.setMaster("mesos://localhost:5050") conf.setAppName("spark mesos") val scheduler = new MesosClusterScheduler( new BlackHoleMesosClusterPersistenceEngineFactory, conf) { override def start(): Unit = { ready = true } } scheduler.start() val response = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 1000, 1, true, command, Map[String, String](), "s1", new Date())) assert(response.success) val killResponse = scheduler.killDriver(response.submissionId) assert(killResponse.success) val state = scheduler.getSchedulerState() assert(state.queuedDrivers.isEmpty) } }
Example 85
Source File: MesosTaskLaunchDataSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.scheduler.cluster.mesos import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite class MesosTaskLaunchDataSuite extends SparkFunSuite { test("serialize and deserialize data must be same") { val serializedTask = ByteBuffer.allocate(40) (Range(100, 110).map(serializedTask.putInt(_))) serializedTask.rewind val attemptNumber = 100 val byteString = MesosTaskLaunchData(serializedTask, attemptNumber).toByteString serializedTask.rewind val mesosTaskLaunchData = MesosTaskLaunchData.fromByteString(byteString) assert(mesosTaskLaunchData.attemptNumber == attemptNumber) assert(mesosTaskLaunchData.serializedTask.equals(serializedTask)) } }
Example 86
Source File: SparkListenerWithClusterSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.scheduler import scala.collection.mutable import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.cluster.ExecutorInfo class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfter with BeforeAndAfterAll { val WAIT_TIMEOUT_MILLIS = 10000 before { sc = new SparkContext("local-cluster[2,1,1024]", "SparkListenerSuite") } test("SparkListener sends executor added message") { val listener = new SaveExecutorInfo sc.addSparkListener(listener) // This test will check if the number of executors received by "SparkListener" is same as the // number of all executors, so we need to wait until all executors are up sc.jobProgressListener.waitUntilExecutorsUp(2, 60000) val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.map(_.toString) rdd2.setName("Target RDD") rdd2.count() sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(listener.addedExecutorInfo.size == 2) assert(listener.addedExecutorInfo("0").totalCores == 1) assert(listener.addedExecutorInfo("1").totalCores == 1) } private class SaveExecutorInfo extends SparkListener { val addedExecutorInfo = mutable.Map[String, ExecutorInfo]() override def onExecutorAdded(executor: SparkListenerExecutorAdded) { addedExecutorInfo(executor.executorId) = executor.executorInfo } } }
Example 87
Source File: OutputCommitCoordinatorIntegrationSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.scheduler import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext} import org.scalatest.concurrent.Timeouts import org.scalatest.time.{Span, Seconds} import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext, SparkFunSuite, TaskContext} import org.apache.spark.util.Utils class OutputCommitCoordinatorIntegrationSuite extends SparkFunSuite with LocalSparkContext with Timeouts { override def beforeAll(): Unit = { super.beforeAll() val conf = new SparkConf() .set("master", "local[2,4]") .set("spark.speculation", "true") .set("spark.hadoop.mapred.output.committer.class", classOf[ThrowExceptionOnFirstAttemptOutputCommitter].getCanonicalName) sc = new SparkContext("local[2, 4]", "test", conf) } test("exception thrown in OutputCommitter.commitTask()") { // Regression test for SPARK-10381 failAfter(Span(60, Seconds)) { val tempDir = Utils.createTempDir() try { sc.parallelize(1 to 4, 2).map(_.toString).saveAsTextFile(tempDir.getAbsolutePath + "/out") } finally { Utils.deleteRecursively(tempDir) } } } } private class ThrowExceptionOnFirstAttemptOutputCommitter extends FileOutputCommitter { override def commitTask(context: TaskAttemptContext): Unit = { val ctx = TaskContext.get() if (ctx.attemptNumber < 1) { throw new java.io.FileNotFoundException("Intentional exception") } super.commitTask(context) } }
Example 88
Source File: CompletionIteratorSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.util import org.apache.spark.SparkFunSuite class CompletionIteratorSuite extends SparkFunSuite { test("basic test") { var numTimesCompleted = 0 val iter = List(1, 2, 3).iterator val completionIter = CompletionIterator[Int, Iterator[Int]](iter, { numTimesCompleted += 1 }) assert(completionIter.hasNext) assert(completionIter.next() === 1) assert(numTimesCompleted === 0) assert(completionIter.hasNext) assert(completionIter.next() === 2) assert(numTimesCompleted === 0) assert(completionIter.hasNext) assert(completionIter.next() === 3) assert(numTimesCompleted === 0) assert(!completionIter.hasNext) assert(numTimesCompleted === 1) // SPARK-4264: Calling hasNext should not trigger the completion callback again. assert(!completionIter.hasNext) assert(numTimesCompleted === 1) } }
Example 89
Source File: NextIteratorSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.util import java.util.NoSuchElementException import scala.collection.mutable.Buffer import org.scalatest.Matchers import org.apache.spark.SparkFunSuite class NextIteratorSuite extends SparkFunSuite with Matchers { test("one iteration") { val i = new StubIterator(Buffer(1)) i.hasNext should be (true) i.next should be (1) i.hasNext should be (false) intercept[NoSuchElementException] { i.next() } } test("two iterations") { val i = new StubIterator(Buffer(1, 2)) i.hasNext should be (true) i.next should be (1) i.hasNext should be (true) i.next should be (2) i.hasNext should be (false) intercept[NoSuchElementException] { i.next() } } test("empty iteration") { val i = new StubIterator(Buffer()) i.hasNext should be (false) intercept[NoSuchElementException] { i.next() } } test("close is called once for empty iterations") { val i = new StubIterator(Buffer()) i.hasNext should be (false) i.hasNext should be (false) i.closeCalled should be (1) } test("close is called once for non-empty iterations") { val i = new StubIterator(Buffer(1, 2)) i.next should be (1) i.next should be (2) // close isn't called until we check for the next element i.closeCalled should be (0) i.hasNext should be (false) i.closeCalled should be (1) i.hasNext should be (false) i.closeCalled should be (1) } class StubIterator(ints: Buffer[Int]) extends NextIterator[Int] { var closeCalled = 0 override def getNext(): Int = { if (ints.size == 0) { finished = true 0 } else { ints.remove(0) } } override def close() { closeCalled += 1 } } }
Example 90
Source File: PrimitiveVectorSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.util.collection import org.apache.spark.SparkFunSuite import org.apache.spark.util.SizeEstimator class PrimitiveVectorSuite extends SparkFunSuite { test("primitive value") { val vector = new PrimitiveVector[Int] for (i <- 0 until 1000) { vector += i assert(vector(i) === i) } assert(vector.size === 1000) assert(vector.size == vector.length) intercept[IllegalArgumentException] { vector(1000) } for (i <- 0 until 1000) { assert(vector(i) == i) } } test("non-primitive value") { val vector = new PrimitiveVector[String] for (i <- 0 until 1000) { vector += i.toString assert(vector(i) === i.toString) } assert(vector.size === 1000) assert(vector.size == vector.length) intercept[IllegalArgumentException] { vector(1000) } for (i <- 0 until 1000) { assert(vector(i) == i.toString) } } test("ideal growth") { val vector = new PrimitiveVector[Long](initialSize = 1) vector += 1 for (i <- 1 until 1024) { vector += i assert(vector.size === i + 1) assert(vector.capacity === Integer.highestOneBit(i) * 2) } assert(vector.capacity === 1024) vector += 1024 assert(vector.capacity === 2048) } test("ideal size") { val vector = new PrimitiveVector[Long](8192) for (i <- 0 until 8192) { vector += i } assert(vector.size === 8192) assert(vector.capacity === 8192) val actualSize = SizeEstimator.estimate(vector) val expectedSize = 8192 * 8 // Make sure we are not allocating a significant amount of memory beyond our expected. // Due to specialization wonkiness, we need to ensure we don't have 2 copies of the array. assert(actualSize < expectedSize * 1.1) } test("resizing") { val vector = new PrimitiveVector[Long] for (i <- 0 until 4097) { vector += i } assert(vector.size === 4097) assert(vector.capacity === 8192) vector.trim() assert(vector.size === 4097) assert(vector.capacity === 4097) vector.resize(5000) assert(vector.size === 4097) assert(vector.capacity === 5000) vector.resize(4000) assert(vector.size === 4000) assert(vector.capacity === 4000) vector.resize(5000) assert(vector.size === 4000) assert(vector.capacity === 5000) for (i <- 0 until 4000) { assert(vector(i) == i) } intercept[IllegalArgumentException] { vector(4000) } } }
Example 91
Source File: CompactBufferSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.util.collection import org.apache.spark.SparkFunSuite class CompactBufferSuite extends SparkFunSuite { test("empty buffer") { val b = new CompactBuffer[Int] assert(b.size === 0) assert(b.iterator.toList === Nil) assert(b.size === 0) assert(b.iterator.toList === Nil) intercept[IndexOutOfBoundsException] { b(0) } intercept[IndexOutOfBoundsException] { b(1) } intercept[IndexOutOfBoundsException] { b(2) } intercept[IndexOutOfBoundsException] { b(-1) } } test("basic inserts") { val b = new CompactBuffer[Int] assert(b.size === 0) assert(b.iterator.toList === Nil) for (i <- 0 until 1000) { b += i assert(b.size === i + 1) assert(b(i) === i) } assert(b.iterator.toList === (0 until 1000).toList) assert(b.iterator.toList === (0 until 1000).toList) assert(b.size === 1000) } test("adding sequences") { val b = new CompactBuffer[Int] assert(b.size === 0) assert(b.iterator.toList === Nil) // Add some simple lists and iterators b ++= List(0) assert(b.size === 1) assert(b.iterator.toList === List(0)) b ++= Iterator(1) assert(b.size === 2) assert(b.iterator.toList === List(0, 1)) b ++= List(2) assert(b.size === 3) assert(b.iterator.toList === List(0, 1, 2)) b ++= Iterator(3, 4, 5, 6, 7, 8, 9) assert(b.size === 10) assert(b.iterator.toList === (0 until 10).toList) // Add CompactBuffers val b2 = new CompactBuffer[Int] b2 ++= 0 until 10 b ++= b2 assert(b.iterator.toList === (1 to 2).flatMap(i => 0 until 10).toList) b ++= b2 assert(b.iterator.toList === (1 to 3).flatMap(i => 0 until 10).toList) b ++= b2 assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList) // Add some small CompactBuffers as well val b3 = new CompactBuffer[Int] b ++= b3 assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList) b3 += 0 b ++= b3 assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList ++ List(0)) b3 += 1 b ++= b3 assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList ++ List(0, 0, 1)) b3 += 2 b ++= b3 assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList ++ List(0, 0, 1, 0, 1, 2)) } test("adding the same buffer to itself") { val b = new CompactBuffer[Int] assert(b.size === 0) assert(b.iterator.toList === Nil) b += 1 assert(b.toList === List(1)) for (j <- 1 until 8) { b ++= b assert(b.size === (1 << j)) assert(b.iterator.toList === (1 to (1 << j)).map(i => 1).toList) } } }
Example 92
Source File: PrefixComparatorsSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.util.collection.unsafe.sort import com.google.common.primitives.UnsignedBytes import org.scalatest.prop.PropertyChecks import org.apache.spark.SparkFunSuite import org.apache.spark.unsafe.types.UTF8String class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { test("String prefix comparator") { def testPrefixComparison(s1: String, s2: String): Unit = { val utf8string1 = UTF8String.fromString(s1) val utf8string2 = UTF8String.fromString(s2) val s1Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string1) val s2Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string2) val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix) val cmp = UnsignedBytes.lexicographicalComparator().compare( utf8string1.getBytes.take(8), utf8string2.getBytes.take(8)) assert( (prefixComparisonResult == 0 && cmp == 0) || (prefixComparisonResult < 0 && s1.compareTo(s2) < 0) || (prefixComparisonResult > 0 && s1.compareTo(s2) > 0)) } // scalastyle:off val regressionTests = Table( ("s1", "s2"), ("abc", "世界"), ("你好", "世界"), ("你好123", "你好122") ) // scalastyle:on forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) } forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } } test("Binary prefix comparator") { def compareBinary(x: Array[Byte], y: Array[Byte]): Int = { for (i <- 0 until x.length; if i < y.length) { val res = x(i).compare(y(i)) if (res != 0) return res } x.length - y.length } def testPrefixComparison(x: Array[Byte], y: Array[Byte]): Unit = { val s1Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(x) val s2Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(y) val prefixComparisonResult = PrefixComparators.BINARY.compare(s1Prefix, s2Prefix) assert( (prefixComparisonResult == 0) || (prefixComparisonResult < 0 && compareBinary(x, y) < 0) || (prefixComparisonResult > 0 && compareBinary(x, y) > 0)) } // scalastyle:off val regressionTests = Table( ("s1", "s2"), ("abc", "世界"), ("你好", "世界"), ("你好123", "你好122") ) // scalastyle:on forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8")) } forAll { (s1: String, s2: String) => testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8")) } } test("double prefix comparator handles NaNs properly") { val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L) val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL) assert(nan1.isNaN) assert(nan2.isNaN) val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1) val nan2Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan2) assert(nan1Prefix === nan2Prefix) val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue) assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1) } }
Example 93
Source File: ResetSystemProperties.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.util import java.util.Properties import org.apache.commons.lang3.SerializationUtils import org.scalatest.{BeforeAndAfterEach, Suite} import org.apache.spark.SparkFunSuite private[spark] trait ResetSystemProperties extends BeforeAndAfterEach { this: Suite => var oldProperties: Properties = null override def beforeEach(): Unit = { // we need SerializationUtils.clone instead of `new Properties(System.getProperties()` because // the later way of creating a copy does not copy the properties but it initializes a new // Properties object with the given properties as defaults. They are not recognized at all // by standard Scala wrapper over Java Properties then. oldProperties = SerializationUtils.clone(System.getProperties) super.beforeEach() } override def afterEach(): Unit = { try { super.afterEach() } finally { System.setProperties(oldProperties) oldProperties = null } } }
Example 94
Source File: SamplingUtilsSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.util.random import scala.util.Random import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} import org.apache.spark.SparkFunSuite class SamplingUtilsSuite extends SparkFunSuite { test("reservoirSampleAndCount") { val input = Seq.fill(100)(Random.nextInt()) // input size < k val (sample1, count1) = SamplingUtils.reservoirSampleAndCount(input.iterator, 150) assert(count1 === 100) assert(input === sample1.toSeq) // input size == k val (sample2, count2) = SamplingUtils.reservoirSampleAndCount(input.iterator, 100) assert(count2 === 100) assert(input === sample2.toSeq) // input size > k val (sample3, count3) = SamplingUtils.reservoirSampleAndCount(input.iterator, 10) assert(count3 === 100) assert(sample3.length === 10) } test("computeFraction") { // test that the computed fraction guarantees enough data points // in the sample with a failure rate <= 0.0001 val n = 100000 for (s <- 1 to 15) { val frac = SamplingUtils.computeFractionForSampleSize(s, n, true) val poisson = new PoissonDistribution(frac * n) assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") } for (s <- List(20, 100, 1000)) { val frac = SamplingUtils.computeFractionForSampleSize(s, n, true) val poisson = new PoissonDistribution(frac * n) assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") } for (s <- List(1, 10, 100, 1000)) { val frac = SamplingUtils.computeFractionForSampleSize(s, n, false) val binomial = new BinomialDistribution(n, frac) assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low") } } }
Example 95
Source File: XORShiftRandomSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.util.random

import org.scalatest.Matchers
import org.apache.commons.math3.stat.inference.ChiSquareTest
import org.apache.spark.SparkFunSuite
import org.apache.spark.util.Utils.times
import scala.language.reflectiveCalls

class XORShiftRandomSuite extends SparkFunSuite with Matchers {

  private def fixture = new {
    val seed = 1L
    val xorRand = new XORShiftRandom(seed)
    val hundMil = 1e8.toInt
  }

  // (reconstruction: the excerpt kept only the final chi-square assertion of this test)
  test("XORShift generates valid random numbers") {
    val f = fixture
    val numBins = 10
    val numRows = 5
    val bins = Array.ofDim[Long](numRows, numBins)
    // bucket the low-order digit of 100 million draws per row; a sound generator fills the
    // buckets evenly, so the chi-square test should not reject uniformity at the 5% level
    for (r <- 0 until numRows) {
      times(f.hundMil) { bins(r)(math.abs(f.xorRand.nextInt) % numBins) += 1 }
    }
    val chiTest = new ChiSquareTest
    assert(chiTest.chiSquareTest(bins, 0.05) === false)
  }

  test ("XORShift with zero seed") {
    val random = new XORShiftRandom(0L)
    assert(random.nextInt() != 0)
  }

  test ("hashSeed has random bits throughout") {
    val totalBitCount = (0 until 10).map { seed =>
      val hashed = XORShiftRandom.hashSeed(seed)
      val bitCount = java.lang.Long.bitCount(hashed)
      // make sure we have roughly equal numbers of 0s and 1s. Mostly just check that we
      // don't have all 0s or 1s in the high bits
      bitCount should be > 20
      bitCount should be < 44
      bitCount
    }.sum
    // and over all the seeds, very close to equal numbers of 0s & 1s
    totalBitCount should be > (32 * 10 - 30)
    totalBitCount should be < (32 * 10 + 30)
  }
}
Example 96
Source File: DistributionSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.util import org.scalatest.Matchers import org.apache.spark.SparkFunSuite class DistributionSuite extends SparkFunSuite with Matchers { test("summary") { val d = new Distribution((1 to 100).toArray.map{_.toDouble}) val stats = d.statCounter stats.count should be (100) stats.mean should be (50.5) stats.sum should be (50 * 101) val quantiles = d.getQuantiles() quantiles(0) should be (1) quantiles(1) should be (26) quantiles(2) should be (51) quantiles(3) should be (76) quantiles(4) should be (100) } }
Example 97
Source File: VectorSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.util import scala.util.Random import org.apache.spark.SparkFunSuite @deprecated("suppress compile time deprecation warning", "1.0.0") class VectorSuite extends SparkFunSuite { def verifyVector(vector: Vector, expectedLength: Int): Unit = { assert(vector.length == expectedLength) assert(vector.elements.min > 0.0) assert(vector.elements.max < 1.0) } test("random with default random number generator") { val vector100 = Vector.random(100) verifyVector(vector100, 100) } test("random with given random number generator") { val vector100 = Vector.random(100, new Random(100)) verifyVector(vector100, 100) } }
Example 98
Source File: ByteArrayChunkOutputStreamSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.util.io import scala.util.Random import org.apache.spark.SparkFunSuite class ByteArrayChunkOutputStreamSuite extends SparkFunSuite { test("empty output") { val o = new ByteArrayChunkOutputStream(1024) assert(o.toArrays.length === 0) } test("write a single byte") { val o = new ByteArrayChunkOutputStream(1024) o.write(10) assert(o.toArrays.length === 1) assert(o.toArrays.head.toSeq === Seq(10.toByte)) } test("write a single near boundary") { val o = new ByteArrayChunkOutputStream(10) o.write(new Array[Byte](9)) o.write(99) assert(o.toArrays.length === 1) assert(o.toArrays.head(9) === 99.toByte) } test("write a single at boundary") { val o = new ByteArrayChunkOutputStream(10) o.write(new Array[Byte](10)) o.write(99) assert(o.toArrays.length === 2) assert(o.toArrays(1).length === 1) assert(o.toArrays(1)(0) === 99.toByte) } test("single chunk output") { val ref = new Array[Byte](8) Random.nextBytes(ref) val o = new ByteArrayChunkOutputStream(10) o.write(ref) val arrays = o.toArrays assert(arrays.length === 1) assert(arrays.head.length === ref.length) assert(arrays.head.toSeq === ref.toSeq) } test("single chunk output at boundary size") { val ref = new Array[Byte](10) Random.nextBytes(ref) val o = new ByteArrayChunkOutputStream(10) o.write(ref) val arrays = o.toArrays assert(arrays.length === 1) assert(arrays.head.length === ref.length) assert(arrays.head.toSeq === ref.toSeq) } test("multiple chunk output") { val ref = new Array[Byte](26) Random.nextBytes(ref) val o = new ByteArrayChunkOutputStream(10) o.write(ref) val arrays = o.toArrays assert(arrays.length === 3) assert(arrays(0).length === 10) assert(arrays(1).length === 10) assert(arrays(2).length === 6) assert(arrays(0).toSeq === ref.slice(0, 10)) assert(arrays(1).toSeq === ref.slice(10, 20)) assert(arrays(2).toSeq === ref.slice(20, 26)) } test("multiple chunk output at boundary size") { val ref = new Array[Byte](30) Random.nextBytes(ref) val o = new ByteArrayChunkOutputStream(10) o.write(ref) val arrays = o.toArrays assert(arrays.length === 3) assert(arrays(0).length === 10) assert(arrays(1).length === 10) assert(arrays(2).length === 10) assert(arrays(0).toSeq === ref.slice(0, 10)) assert(arrays(1).toSeq === ref.slice(10, 20)) assert(arrays(2).toSeq === ref.slice(20, 30)) } }
Example 99
Source File: DiskBlockManagerSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.storage import java.io.{File, FileWriter} import scala.language.reflectiveCalls import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.Utils class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll { private val testConf = new SparkConf(false) private var rootDir0: File = _ private var rootDir1: File = _ private var rootDirs: String = _ val blockManager = mock(classOf[BlockManager]) when(blockManager.conf).thenReturn(testConf) var diskBlockManager: DiskBlockManager = _ override def beforeAll() { super.beforeAll() rootDir0 = Utils.createTempDir() rootDir1 = Utils.createTempDir() rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath } override def afterAll() { super.afterAll() Utils.deleteRecursively(rootDir0) Utils.deleteRecursively(rootDir1) } override def beforeEach() { val conf = testConf.clone conf.set("spark.local.dir", rootDirs) diskBlockManager = new DiskBlockManager(blockManager, conf) } override def afterEach() { diskBlockManager.stop() } test("basic block creation") { val blockId = new TestBlockId("test") val newFile = diskBlockManager.getFile(blockId) writeToFile(newFile, 10) assert(diskBlockManager.containsBlock(blockId)) newFile.delete() assert(!diskBlockManager.containsBlock(blockId)) } test("enumerating blocks") { val ids = (1 to 100).map(i => TestBlockId("test_" + i)) val files = ids.map(id => diskBlockManager.getFile(id)) files.foreach(file => writeToFile(file, 10)) assert(diskBlockManager.getAllBlocks.toSet === ids.toSet) } def writeToFile(file: File, numBytes: Int) { val writer = new FileWriter(file, true) for (i <- 0 until numBytes) writer.write(i) writer.close() } }
Example 100
Source File: BlockIdSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.storage import org.apache.spark.SparkFunSuite class BlockIdSuite extends SparkFunSuite { def assertSame(id1: BlockId, id2: BlockId) { assert(id1.name === id2.name) assert(id1.hashCode === id2.hashCode) assert(id1 === id2) } def assertDifferent(id1: BlockId, id2: BlockId) { assert(id1.name != id2.name) assert(id1.hashCode != id2.hashCode) assert(id1 != id2) } test("test-bad-deserialization") { try { // Try to deserialize an invalid block id. BlockId("myblock") fail() } catch { case e: IllegalStateException => // OK case _: Throwable => fail() } } test("rdd") { val id = RDDBlockId(1, 2) assertSame(id, RDDBlockId(1, 2)) assertDifferent(id, RDDBlockId(1, 1)) assert(id.name === "rdd_1_2") assert(id.asRDDId.get.rddId === 1) assert(id.asRDDId.get.splitIndex === 2) assert(id.isRDD) assertSame(id, BlockId(id.toString)) } test("shuffle") { val id = ShuffleBlockId(1, 2, 3) assertSame(id, ShuffleBlockId(1, 2, 3)) assertDifferent(id, ShuffleBlockId(3, 2, 3)) assert(id.name === "shuffle_1_2_3") assert(id.asRDDId === None) assert(id.shuffleId === 1) assert(id.mapId === 2) assert(id.reduceId === 3) assert(id.isShuffle) assertSame(id, BlockId(id.toString)) } test("broadcast") { val id = BroadcastBlockId(42) assertSame(id, BroadcastBlockId(42)) assertDifferent(id, BroadcastBlockId(123)) assert(id.name === "broadcast_42") assert(id.asRDDId === None) assert(id.broadcastId === 42) assert(id.isBroadcast) assertSame(id, BlockId(id.toString)) } test("taskresult") { val id = TaskResultBlockId(60) assertSame(id, TaskResultBlockId(60)) assertDifferent(id, TaskResultBlockId(61)) assert(id.name === "taskresult_60") assert(id.asRDDId === None) assert(id.taskId === 60) assert(!id.isRDD) assertSame(id, BlockId(id.toString)) } test("stream") { val id = StreamBlockId(1, 100) assertSame(id, StreamBlockId(1, 100)) assertDifferent(id, StreamBlockId(2, 101)) assert(id.name === "input-1-100") assert(id.asRDDId === None) assert(id.streamId === 1) assert(id.uniqueId === 100) assert(!id.isBroadcast) assertSame(id, BlockId(id.toString)) } test("test") { val id = TestBlockId("abc") assertSame(id, TestBlockId("abc")) assertDifferent(id, TestBlockId("ab")) assert(id.name === "test_abc") assert(id.asRDDId === None) assert(id.id === "abc") assert(!id.isShuffle) assertSame(id, BlockId(id.toString)) } }
Example 101
Source File: LocalDirsSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.storage import java.io.File import org.apache.spark.util.Utils import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.SparkConfWithEnv class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { before { Utils.clearLocalRootDirs() } test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") { // Regression test for SPARK-2974 assert(!new File("/NONEXISTENT_DIR").exists()) val conf = new SparkConf(false) .set("spark.local.dir", s"/NONEXISTENT_PATH,${System.getProperty("java.io.tmpdir")}") assert(new File(Utils.getLocalDir(conf)).exists()) } test("SPARK_LOCAL_DIRS override also affects driver") { // Regression test for SPARK-2975 assert(!new File("/NONEXISTENT_DIR").exists()) // spark.local.dir only contains invalid directories, but that's not a problem since // SPARK_LOCAL_DIRS will override it on both the driver and workers: val conf = new SparkConfWithEnv(Map("SPARK_LOCAL_DIRS" -> System.getProperty("java.io.tmpdir"))) .set("spark.local.dir", "/NONEXISTENT_PATH") assert(new File(Utils.getLocalDir(conf)).exists()) } }
Example 102
Source File: PartitionwiseSampledRDDSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.rdd import org.apache.spark.{SharedSparkContext, SparkFunSuite} import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, RandomSampler} class MockSampler extends RandomSampler[Long, Long] { private var s: Long = _ override def setSeed(seed: Long) { s = seed } override def sample(items: Iterator[Long]): Iterator[Long] = { Iterator(s) } override def clone: MockSampler = new MockSampler } class PartitionwiseSampledRDDSuite extends SparkFunSuite with SharedSparkContext { test("seed distribution") { val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2) val sampler = new MockSampler val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, false, 0L) assert(sample.distinct().count == 2, "Seeds must be different.") } test("concurrency") { // SPARK-2251: zip with self computes each partition twice. // We want to make sure there are no concurrency issues. val rdd = sc.parallelize(0 until 111, 10) for (sampler <- Seq(new BernoulliSampler[Int](0.5), new PoissonSampler[Int](0.5))) { val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler, true) sampled.zip(sampled).count() } } }
Example 103
Source File: PartitionPruningRDDSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.rdd import org.apache.spark.{Partition, SharedSparkContext, SparkFunSuite, TaskContext} class PartitionPruningRDDSuite extends SparkFunSuite with SharedSparkContext { test("Pruned Partitions inherit locality prefs correctly") { val rdd = new RDD[Int](sc, Nil) { override protected def getPartitions = { Array[Partition]( new TestPartition(0, 1), new TestPartition(1, 1), new TestPartition(2, 1)) } def compute(split: Partition, context: TaskContext) = { Iterator() } } val prunedRDD = PartitionPruningRDD.create(rdd, _ == 2) assert(prunedRDD.partitions.length == 1) val p = prunedRDD.partitions(0) assert(p.index == 0) assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2) } test("Pruned Partitions can be unioned ") { val rdd = new RDD[Int](sc, Nil) { override protected def getPartitions = { Array[Partition]( new TestPartition(0, 4), new TestPartition(1, 5), new TestPartition(2, 6)) } def compute(split: Partition, context: TaskContext) = { List(split.asInstanceOf[TestPartition].testValue).iterator } } val prunedRDD1 = PartitionPruningRDD.create(rdd, _ == 0) val prunedRDD2 = PartitionPruningRDD.create(rdd, _ == 2) val merged = prunedRDD1 ++ prunedRDD2 assert(merged.count() == 2) val take = merged.take(2) assert(take.apply(0) == 4) assert(take.apply(1) == 6) } } class TestPartition(i: Int, value: Int) extends Partition with Serializable { def index: Int = i def testValue: Int = this.value }
Example 104
Source File: JdbcRDDSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.rdd import java.sql._ import org.scalatest.BeforeAndAfter import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} import org.apache.spark.util.Utils class JdbcRDDSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { before { Utils.classForName("org.apache.derby.jdbc.EmbeddedDriver") val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true") try { try { val create = conn.createStatement create.execute(""" CREATE TABLE FOO( ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), DATA INTEGER )""") create.close() val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)") (1 to 100).foreach { i => insert.setInt(1, i * 2) insert.executeUpdate } insert.close() } catch { case e: SQLException if e.getSQLState == "X0Y32" => // table exists } try { val create = conn.createStatement create.execute("CREATE TABLE BIGINT_TEST(ID BIGINT NOT NULL, DATA INTEGER)") create.close() val insert = conn.prepareStatement("INSERT INTO BIGINT_TEST VALUES(?,?)") (1 to 100).foreach { i => insert.setLong(1, 100000000000000000L + 4000000000000000L * i) insert.setInt(2, i) insert.executeUpdate } insert.close() } catch { case e: SQLException if e.getSQLState == "X0Y32" => // table exists } } finally { conn.close() } } test("basic functionality") { sc = new SparkContext("local", "test") val rdd = new JdbcRDD( sc, () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") }, "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?", 1, 100, 3, (r: ResultSet) => { r.getInt(1) } ).cache() assert(rdd.count === 100) assert(rdd.reduce(_ + _) === 10100) } test("large id overflow") { sc = new SparkContext("local", "test") val rdd = new JdbcRDD( sc, () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") }, "SELECT DATA FROM BIGINT_TEST WHERE ? <= ID AND ID <= ?", 1131544775L, 567279358897692673L, 20, (r: ResultSet) => { r.getInt(1) } ).cache() assert(rdd.count === 100) assert(rdd.reduce(_ + _) === 5050) } after { try { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;shutdown=true") } catch { case se: SQLException if se.getSQLState == "08006" => // Normal single database shutdown // https://db.apache.org/derby/docs/10.2/ref/rrefexcept71493.html } } }
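The two '?' placeholders in the query are how JdbcRDD splits the work: each of the numPartitions tasks binds its own sub-range of [lowerBound, upperBound] into them. A minimal sketch, reusing the imports and SparkContext of the example above and pointing at a hypothetical table and URL (both made up):

val names = new JdbcRDD(
  sc,
  () => DriverManager.getConnection("jdbc:derby:target/ExampleDb"),
  "SELECT NAME FROM USERS WHERE ? <= ID AND ID <= ?",
  1,                // lower bound of the partitioning column
  1000,             // upper bound of the partitioning column
  4,                // number of partitions, i.e. concurrent queries
  (r: ResultSet) => r.getString(1))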
Example 105
Source File: ZippedPartitionsSuite.scala From BigDatalog with Apache License 2.0 | 5 votes |
package org.apache.spark.rdd import org.apache.spark.{SharedSparkContext, SparkFunSuite} object ZippedPartitionsSuite { def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]) : Iterator[Int] = { Iterator(i.toArray.size, s.toArray.size, d.toArray.size) } } class ZippedPartitionsSuite extends SparkFunSuite with SharedSparkContext { test("print sizes") { val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2) val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2) val data3 = sc.makeRDD(Array(1.0, 2.0), 2) val zippedRDD = data1.zipPartitions(data2, data3)(ZippedPartitionsSuite.procZippedData) val obtainedSizes = zippedRDD.collect() val expectedSizes = Array(2, 3, 1, 2, 3, 1) assert(obtainedSizes.size == 6) assert(obtainedSizes.zip(expectedSizes).forall(x => x._1 == x._2)) } }
Example 106
Source File: HiveTestTrait.scala From cloud-integration with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.sources import java.io.File import com.cloudera.spark.cloud.ObjectStoreConfigurations import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.{SparkSession, SQLContext, SQLImplicits} import org.apache.spark.sql.hive.test.TestHiveContext import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils trait HiveTestTrait extends SparkFunSuite with BeforeAndAfterAll { // override protected val enableAutoThreadAudit = false protected var hiveContext: HiveInstanceForTests = _ protected var spark: SparkSession = _ protected override def beforeAll(): Unit = { super.beforeAll() // set up spark and hive context hiveContext = new HiveInstanceForTests() spark = hiveContext.sparkSession } protected override def afterAll(): Unit = { try { SparkSession.clearActiveSession() if (hiveContext != null) { hiveContext.reset() hiveContext = null } if (spark != null) { spark.close() spark = null } } finally { super.afterAll() } } } class HiveInstanceForTests extends TestHiveContext( new SparkContext( System.getProperty("spark.sql.test.master", "local[1]"), "TestSQLContext", new SparkConf() .setAll(ObjectStoreConfigurations.RW_TEST_OPTIONS) .set("spark.sql.warehouse.dir", TestSetup.makeWarehouseDir().toURI.getPath) ) ) { } object TestSetup { def makeWarehouseDir(): File = { val warehouseDir = Utils.createTempDir(namePrefix = "warehouse") warehouseDir.delete() warehouseDir } }
Example 107
Source File: PulsarConfigUpdaterSuite.scala From pulsar-spark with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.pulsar import org.apache.spark.SparkFunSuite import org.scalatest.BeforeAndAfterEach class PulsarConfigUpdaterSuite extends SparkFunSuite with BeforeAndAfterEach { private val testModule = "testModule" private val testKey = "testKey" private val testValue = "testValue" private val otherTestValue = "otherTestValue" test("set should always set value") { val params = Map.empty[String, String] val updatedParams = PulsarConfigUpdater(testModule, params) .set(testKey, testValue) .build() assert(updatedParams.size() === 1) assert(updatedParams.get(testKey) === testValue) } test("setIfUnset without existing key should set value") { val params = Map.empty[String, String] val updatedParams = PulsarConfigUpdater(testModule, params) .setIfUnset(testKey, testValue) .build() assert(updatedParams.size() === 1) assert(updatedParams.get(testKey) === testValue) } test("setIfUnset with existing key should not set value") { val params = Map[String, String](testKey -> testValue) val updatedParams = PulsarConfigUpdater(testModule, params) .setIfUnset(testKey, otherTestValue) .build() assert(updatedParams.size() === 1) assert(updatedParams.get(testKey) === testValue) } }
Example 108
Source File: MergeClauseSuite.scala From spark-acid with Apache License 2.0 | 5 votes |
package com.qubole.spark.hiveacid.merge import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{AnalysisException, functions} class MergeClauseSuite extends SparkFunSuite { def insertClause(addCondition : Boolean = true): MergeWhenNotInsert = { if (addCondition) { MergeWhenNotInsert(Some(functions.expr("x > 2").expr), Seq(functions.col("x").expr, functions.col("y").expr)) } else { MergeWhenNotInsert(None, Seq(functions.col("x").expr, functions.col("y").expr)) } } def updateClause(addCondition : Boolean = true): MergeWhenUpdateClause = { if (addCondition) { val updateCondition = Some(functions.expr("a > 2").expr) MergeWhenUpdateClause(updateCondition, Map("b" -> functions.lit(3).expr), isStar = false) } else { MergeWhenUpdateClause(None, Map("b" -> functions.lit(3).expr), isStar = false) } } def deleteClause(addCondition : Boolean = true): MergeWhenDelete = { if (addCondition) { MergeWhenDelete(Some(functions.expr("a < 1").expr)) } else { MergeWhenDelete(None) } } test("Validate MergeClauses") { val clauses = Seq(insertClause(), updateClause(), deleteClause()) MergeWhenClause.validate(clauses) } test("Invalid MergeClause cases") { val invalidMerge = "MERGE Validation Error: " //empty clauses checkInvalidMergeClause(invalidMerge + MergeWhenClause.atleastOneClauseError, Seq()) // multi update or insert clauses val multiUpdateClauses = Seq(updateClause(), updateClause(), insertClause()) checkInvalidMergeClause(invalidMerge + MergeWhenClause.justOneClausePerTypeError, multiUpdateClauses) // multi match clauses with first clause without condition val invalidMultiMatch = Seq(updateClause(false), deleteClause()) checkInvalidMergeClause(invalidMerge + MergeWhenClause.matchClauseConditionError, invalidMultiMatch) // invalid Update Clause val invalidUpdateClause = MergeWhenUpdateClause(None, Map(), isStar = false) val thrown = intercept[IllegalArgumentException] { MergeWhenClause.validate(Seq(invalidUpdateClause)) } assert(thrown.getMessage === "UPDATE Clause in MERGE should have one or more SET Values") } private def checkInvalidMergeClause(invalidMessage: String, multiUpdateClauses: Seq[MergeWhenClause]) = { val thrown = intercept[AnalysisException] { MergeWhenClause.validate(multiUpdateClauses) } assert(thrown.message === invalidMessage) } }
Example 109
Source File: ClickThroughRatePrediction.scala From click-through-rate-prediction with Apache License 2.0 | 5 votes |
package org.apache.spark.examples.kaggle import org.apache.spark.SparkFunSuite import org.apache.spark.util.MLlibTestSparkContext class ClickThroughRatePredictionSuite extends SparkFunSuite with MLlibTestSparkContext { test("run") { // Logger.getLogger("org").setLevel(Level.OFF) // Logger.getLogger("akka").setLevel(Level.OFF) val trainPath = this.getClass.getResource("/train.part-10000").getPath val testPath = this.getClass.getResource("/test.part-10000").getPath val resultPath = "./tmp/result/" ClickThroughRatePrediction.run(sc, sqlContext, trainPath, testPath, resultPath) } }
Example 110
Source File: LabeledPointSuite.scala From sona with Apache License 2.0 | 5 votes |
package com.tencent.angel.sona.ml.feature import org.apache.spark.linalg.Vectors import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.KryoSerializer class LabeledPointSuite extends SparkFunSuite { test("Kryo class register") { val conf = new SparkConf(false) conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.registerKryoClasses( Array(classOf[scala.collection.mutable.WrappedArray.ofRef[_]], classOf[LabeledPoint])) // conf.set("spark.kryo.registrationRequired", "true") val ser = new KryoSerializer(conf).newInstance() val labeled1 = LabeledPoint(1.0, Vectors.dense(Array(1.0, 2.0))) val labeled2 = LabeledPoint(1.0, Vectors.sparse(10, Array(5, 7), Array(1.0, 2.0))) Seq(labeled1, labeled2).foreach { l => val l2 = ser.deserialize[LabeledPoint](ser.serialize(l)) assert(l === l2) } } }
Example 111
Source File: InstanceSuite.scala From sona with Apache License 2.0 | 5 votes |
package com.tencent.angel.sona.ml.feature import org.apache.spark.linalg.Vectors import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.KryoSerializer class InstanceSuite extends SparkFunSuite{ test("Kryo class register") { val conf = new SparkConf(false) conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.registerKryoClasses( Array(classOf[scala.collection.mutable.WrappedArray.ofRef[_]], classOf[Instance])) // conf.set("spark.kryo.registrationRequired", "true") val ser = new KryoSerializer(conf).newInstance() val instance1 = Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)) val instance2 = Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse) Seq(instance1, instance2).foreach { i => val i2 = ser.deserialize[Instance](ser.serialize(i)) assert(i === i2) } val oInstance1 = OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)) val oInstance2 = OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0).toSparse) Seq(oInstance1, oInstance2).foreach { o => val o2 = ser.deserialize[OffsetInstance](ser.serialize(o)) assert(o === o2) } } }
Example 112
Source File: LibFFMRelationSuite.scala From sona with Apache License 2.0 | 5 votes |
package com.tencent.angel.sona.ml.source.libffm import java.io.File import java.nio.charset.StandardCharsets import com.google.common.io.Files import org.apache.spark.SparkFunSuite import com.tencent.angel.sona.ml.util.MLlibTestSparkContext import org.apache.spark.util.SparkUtil class LibFFMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { // Path for dataset var path: String = _ override def beforeAll(): Unit = { super.beforeAll() val lines0 = """ |1 0:1:1.0 1:3:2.0 2:5:3.0 |0 """.stripMargin val lines1 = """ |0 0:2:4.0 1:4:5.0 2:6:6.0 """.stripMargin val dir = SparkUtil.createTempDir() val succ = new File(dir, "_SUCCESS") val file0 = new File(dir, "part-00000") val file1 = new File(dir, "part-00001") Files.write("", succ, StandardCharsets.UTF_8) Files.write(lines0, file0, StandardCharsets.UTF_8) Files.write(lines1, file1, StandardCharsets.UTF_8) path = dir.getPath } override def afterAll(): Unit = { try { val prefix = "C:\\Users\\fitzwang\\AppData\\Local\\Temp\\" if (path.startsWith(prefix)) { SparkUtil.deleteRecursively(new File(path)) } } finally { super.afterAll() } } test("ffmIO"){ val df = spark.read.format("libffm").load(path) val metadata = df.schema(1).metadata val fieldSet = MetaSummary.getFieldSet(metadata) println(fieldSet.mkString("[", ",", "]")) val keyFieldMap = MetaSummary.getKeyFieldMap(metadata) println(keyFieldMap.mkString("[", ",", "]")) df.write.format("libffm").save("temp.libffm") } test("read_ffm"){ val df = spark.read.format("libffm").load(path) val metadata = df.schema(1).metadata val fieldSet = MetaSummary.getFieldSet(metadata) println(fieldSet.mkString("[", ",", "]")) val keyFieldMap = MetaSummary.getKeyFieldMap(metadata) println(keyFieldMap.mkString("[", ",", "]")) } }
Example 113
Source File: AngelTestUtils.scala From sona with Apache License 2.0 | 5 votes |
package com.tencent.angel.sona.ml.util import com.tencent.angel.sona.core.DriverContext import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.sql.{DataFrameReader, SparkSession} class AngelTestUtils extends SparkFunSuite { protected var spark: SparkSession = _ protected var libsvm: DataFrameReader = _ protected var dummy: DataFrameReader = _ protected var sparkConf: SparkConf = _ protected var driverCtx: DriverContext = _ protected override def beforeAll(): Unit = { super.beforeAll() spark = SparkSession.builder() .master("local[2]") .appName("AngelClassification") .getOrCreate() libsvm = spark.read.format("libsvmex") dummy = spark.read.format("dummy") sparkConf = spark.sparkContext.getConf driverCtx = DriverContext.get(sparkConf) driverCtx.startAngelAndPSAgent() } protected override def afterAll(): Unit = { super.afterAll() driverCtx.stopAngelAndPSAgent() } }
Example 114
Source File: PgWireProtocolSuite.scala From spark-sql-server with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.server.service.postgresql.protocol.v3 import java.nio.ByteBuffer import java.nio.charset.StandardCharsets import java.sql.SQLException import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String class PgWireProtocolSuite extends SparkFunSuite { val conf = new SQLConf() test("DataRow") { val v3Protocol = new PgWireProtocol(65536) val row = new GenericInternalRow(2) row.update(0, 8) row.update(1, UTF8String.fromString("abcdefghij")) val schema = StructType.fromDDL("a INT, b STRING") val rowConverters = PgRowConverters(conf, schema, Seq(true, false)) val data = v3Protocol.DataRow(row, rowConverters) val bytes = ByteBuffer.wrap(data) assert(bytes.get() === 'D'.toByte) assert(bytes.getInt === 28) assert(bytes.getShort === 2) assert(bytes.getInt === 4) assert(bytes.getInt === 8) assert(bytes.getInt === 10) assert(data.slice(19, 30) === "abcdefghij".getBytes(StandardCharsets.UTF_8)) } test("Fails when message buffer overflowed") { val v3Protocol = new PgWireProtocol(4) val row = new GenericInternalRow(1) row.update(0, UTF8String.fromString("abcdefghijk")) val schema = StructType.fromDDL("a STRING") val rowConverters = PgRowConverters(conf, schema, Seq(false)) val errMsg = intercept[SQLException] { v3Protocol.DataRow(row, rowConverters) }.getMessage assert(errMsg.contains( "Cannot generate a V3 protocol message because buffer is not enough for the message. " + "To avoid this exception, you might set higher value at " + "'spark.sql.server.messageBufferSizeInBytes'") ) } }
Example 115
Source File: ExtensionBuilderSuite.scala From spark-sql-server with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.server import java.net.URL import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.util.Utils class ExtensionBuilderSuite extends SparkFunSuite with BeforeAndAfterAll { var _sqlContext: SQLContext = _ // TODO: This method works only in Java8 private def addJarInClassPath(jarURLString: String): Unit = { // val cl = ClassLoader.getSystemClassLoader val cl = Utils.getSparkClassLoader val clazz = cl.getClass val method = clazz.getSuperclass.getDeclaredMethod("addURL", Seq(classOf[URL]): _*) method.setAccessible(true) method.invoke(cl, Seq[Object](new URL(jarURLString)): _*) } protected override def beforeAll(): Unit = { super.beforeAll() // Adds a jar for an extension builder val jarPath = "src/test/resources/extensions_2.12_3.0.0-preview2_0.1.7-spark3.0-SNAPSHOT.jar" val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath" // sqlContext.sparkContext.addJar(jarURL) addJarInClassPath(jarURL) val conf = new SparkConf(loadDefaults = true) .setMaster("local[1]") .setAppName("spark-sql-server-test") .set("spark.sql.server.extensions.builder", "org.apache.spark.ExtensionBuilderExample") _sqlContext = SQLServerEnv.newSQLContext(conf) } protected override def afterAll(): Unit = { try { super.afterAll() } finally { try { if (_sqlContext != null) { _sqlContext.sparkContext.stop() _sqlContext = null } } finally { SparkSession.clearActiveSession() SparkSession.clearDefaultSession() } } } test("user-defined optimizer rules") { val rules = Seq("org.apache.spark.catalyst.EmptyRule1", "org.apache.spark.catalyst.EmptyRule2") val optimizerRuleNames = _sqlContext.sessionState.optimizer .extendedOperatorOptimizationRules.map(_.ruleName) rules.foreach { expectedRuleName => assert(optimizerRuleNames.contains(expectedRuleName)) } } }
Example 116
Source File: NormalizerSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext {

  @transient var data: Array[Vector] = _
  @transient var dataFrame: DataFrame = _
  @transient var normalizer: Normalizer = _
  @transient var l1Normalized: Array[Vector] = _
  @transient var l2Normalized: Array[Vector] = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    data = Array(
      Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
      Vectors.dense(0.0, 0.0, 0.0),
      Vectors.dense(0.6, -1.1, -3.0),
      Vectors.sparse(3, Seq((1, 0.91), (2, 3.2))),
      Vectors.sparse(3, Seq((0, 5.7), (1, 0.72), (2, 2.7))),
      Vectors.sparse(3, Seq())
    )
    l1Normalized = Array(
      Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))),
      Vectors.dense(0.0, 0.0, 0.0),
      Vectors.dense(0.12765957, -0.23404255, -0.63829787),
      Vectors.sparse(3, Seq((1, 0.22141119), (2, 0.7785888))),
      Vectors.dense(0.625, 0.07894737, 0.29605263),
      Vectors.sparse(3, Seq())
    )
    l2Normalized = Array(
      Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))),
      Vectors.dense(0.0, 0.0, 0.0),
      Vectors.dense(0.184549876, -0.3383414, -0.922749378),
      Vectors.sparse(3, Seq((1, 0.27352993), (2, 0.96186349))),
      Vectors.dense(0.897906166, 0.113419726, 0.42532397),
      Vectors.sparse(3, Seq())
    )
    val sqlContext = new SQLContext(sc)
    dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData))
    normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized_features")
  }

  // collect the normalized vectors from the result
  def collectResult(result: DataFrame): Array[Vector] = {
    result.select("normalized_features").collect().map {
      case Row(features: Vector) => features
    }
  }

  // assert that the vector type is preserved
  def assertTypeOfVector(lhs: Array[Vector], rhs: Array[Vector]): Unit = {
    assert((lhs, rhs).zipped.forall {
      case (v1: DenseVector, v2: DenseVector) => true
      case (v1: SparseVector, v2: SparseVector) => true
      case _ => false
    }, "The vector type should be preserved after normalization.")
  }

  // assert that the values match
  def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = {
    assert((lhs, rhs).zipped.forall { (vector1, vector2) =>
      vector1 ~== vector2 absTol 1E-5
    }, "The vector value is not correct after normalization.")
  }

  // normalization with the default parameter (L2 norm)
  test("Normalization with default parameter") {
    // transform() maps the input DataFrame to another DataFrame
    normalizer.transform(dataFrame).show()
    val result = collectResult(normalizer.transform(dataFrame))
    assertTypeOfVector(data, result)
    assertValues(result, l2Normalized)
  }

  // normalization with an explicit setter (L1 norm)
  test("Normalization with setter") {
    normalizer.setP(1)
    // transform() maps the input DataFrame to another DataFrame
    normalizer.transform(dataFrame).show()
    val result = collectResult(normalizer.transform(dataFrame))
    assertTypeOfVector(data, result)
    assertValues(result, l1Normalized)
  }
}

private object NormalizerSuite {
  case class FeatureData(features: Vector)
}
Example 117
Source File: AttributeGroupSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.ml.attribute import org.apache.spark.SparkFunSuite class AttributeGroupSuite extends SparkFunSuite { test("attribute group") {//属性分组 val attrs = Array( NumericAttribute.defaultAttr, NominalAttribute.defaultAttr, BinaryAttribute.defaultAttr.withIndex(0), NumericAttribute.defaultAttr.withName("age").withSparsity(0.8), NominalAttribute.defaultAttr.withName("size").withValues("small", "medium", "large"), BinaryAttribute.defaultAttr.withName("clicked").withValues("no", "yes"), NumericAttribute.defaultAttr, NumericAttribute.defaultAttr) val group = new AttributeGroup("user", attrs) assert(group.size === 8) assert(group.name === "user") assert(group(0) === NumericAttribute.defaultAttr.withIndex(0)) assert(group(2) === BinaryAttribute.defaultAttr.withIndex(2)) assert(group.indexOf("age") === 3) assert(group.indexOf("size") === 4) assert(group.indexOf("clicked") === 5) assert(!group.hasAttr("abc")) intercept[NoSuchElementException] { group("abc") } assert(group === AttributeGroup.fromMetadata(group.toMetadataImpl, group.name)) assert(group === AttributeGroup.fromStructField(group.toStructField())) } test("attribute group without attributes") {//没有属性的属性组 val group0 = new AttributeGroup("user", 10) assert(group0.name === "user") assert(group0.numAttributes === Some(10)) assert(group0.size === 10) assert(group0.attributes.isEmpty) assert(group0 === AttributeGroup.fromMetadata(group0.toMetadataImpl, group0.name)) assert(group0 === AttributeGroup.fromStructField(group0.toStructField())) val group1 = new AttributeGroup("item") assert(group1.name === "item") assert(group1.numAttributes.isEmpty) assert(group1.attributes.isEmpty) assert(group1.size === -1) } }
Example 118
Source File: RegressionEvaluatorSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.ml.evaluation

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._

class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("params") {
    ParamsSuite.checkParams(new RegressionEvaluator)
  }

  // Regression evaluator with default params
  test("Regression Evaluator: default params") {
    // NOTE: the excerpt omitted the training data; it is regenerated here with
    // LinearDataGenerator, which the original imports for exactly this purpose
    val dataset = sqlContext.createDataFrame(
      sc.parallelize(LinearDataGenerator.generateLinearInput(
        6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))

    val trainer = new LinearRegression
    // fit() trains on the DataFrame and returns a Transformer (the fitted model)
    val model = trainer.fit(dataset)
    // transform() maps the input DataFrame to another DataFrame with a prediction column
    val predictions = model.transform(dataset)
    predictions.collect()

    // default metric = rmse; the root mean squared error measures the spread of the residuals
    val evaluator = new RegressionEvaluator()
    println("==MetricName=" + evaluator.getMetricName + "=LabelCol=" + evaluator.getLabelCol +
      "=PredictionCol=" + evaluator.getPredictionCol)
    // prints ==MetricName=rmse=LabelCol=label=PredictionCol=prediction
    assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001)

    // r2 score: the coefficient of determination measures how well the model fits the data
    evaluator.setMetricName("r2")
    assert(evaluator.evaluate(predictions) ~== 0.9998196 absTol 0.001)

    // mae: the mean absolute error is the average absolute deviation of the predictions
    // from the observed values
    evaluator.setMetricName("mae")
    assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001)
  }
}
Example 119
Source File: ParamGridBuilderSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.ml.tuning

import scala.collection.mutable

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.{ParamMap, TestParams}

class ParamGridBuilderSuite extends SparkFunSuite {

  val solver = new TestParams()
  import solver.{inputCol, maxIter}

  test("param grid builder") {
    def validateGrid(maps: Array[ParamMap], expected: mutable.Set[(Int, String)]): Unit = {
      assert(maps.size === expected.size)
      maps.foreach { m => // m is a ParamMap
        // e.g. (10, input0), (10, input1)
        val tuple = (m(maxIter), m(inputCol))
        assert(expected.contains(tuple))
        expected.remove(tuple)
      }
      assert(expected.isEmpty)
    }

    // addGrid registers the candidate values to search over;
    // ParamGridBuilder builds the grid of candidate parameters
    // (e.g. logistic regression's regParam)
    val maps0 = new ParamGridBuilder()
      .baseOn(maxIter -> 10)
      .addGrid(inputCol, Array("input0", "input1"))
      .build()
    // expected values
    val expected0 = mutable.Set(
      (10, "input0"), (10, "input1"))
    validateGrid(maps0, expected0)

    val maps1 = new ParamGridBuilder()
      .baseOn(ParamMap(maxIter -> 5, inputCol -> "input")) // will be overwritten
      .addGrid(maxIter, Array(10, 20)) // overrides the base value
      .addGrid(inputCol, Array("input0", "input1"))
      .build()
    val expected1 = mutable.Set(
      (10, "input0"), (20, "input0"), (10, "input1"), (20, "input1"))
    validateGrid(maps1, expected1)
  }
}
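In practice the grid built here is handed to a tuning estimator such as CrossValidator, which fits the pipeline once per ParamMap and keeps the best model. A loose sketch under the assumption that `training` is a DataFrame with "label" and "features" columns (the estimator, parameters and values are placeholders):

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LogisticRegression().setMaxIter(10)
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5))
  .build()
val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(grid)
  .setNumFolds(3)
val cvModel = cv.fit(training) // cvModel.bestModel holds the winning ParamMap's model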
Example 120
Source File: ProbabilisticClassifierSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.ml.classification

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}

final class TestProbabilisticClassificationModel(
    override val uid: String,
    override val numClasses: Int)
  extends ProbabilisticClassificationModel[Vector, TestProbabilisticClassificationModel] {

  override def copy(extra: org.apache.spark.ml.param.ParamMap): this.type = defaultCopy(extra)

  override protected def predictRaw(input: Vector): Vector = {
    input
  }

  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
    rawPrediction
  }

  def friendlyPredict(input: Vector): Double = {
    predict(input)
  }
}

// Suite for the probabilistic classifier
class ProbabilisticClassifierSuite extends SparkFunSuite {

  test("test thresholding") {
    val thresholds = Array(0.5, 0.2)
    // In binary classification the thresholds lie in [0, 1]: class 1 is predicted when its
    // estimated probability exceeds its threshold, otherwise class 0
    val testModel = new TestProbabilisticClassificationModel("myuid", 2).setThresholds(thresholds)
    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0)
    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0)
  }

  // thresholding is not required
  test("test thresholding not required") {
    val testModel = new TestProbabilisticClassificationModel("myuid", 2)
    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
  }
}
Example 121
Source File: ANNSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.ml.ann

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._

class ANNSuite extends SparkFunSuite with MLlibTestSparkContext {

  // TODO: test for weights comparison with Weka MLP
  // An ANN with a sigmoid output learns the XOR function using the LBFGS optimizer
  test("ANN with Sigmoid learns XOR function with LBFGS optimizer") {
    val inputs = Array(
      Array(0.0, 0.0),
      Array(0.0, 1.0),
      Array(1.0, 0.0),
      Array(1.0, 1.0)
    )
    val outputs = Array(0.0, 1.0, 1.0, 0.0)
    val data = inputs.zip(outputs).map { case (features, label) =>
      (Vectors.dense(features), Vectors.dense(label))
    }
    val rddData = sc.parallelize(data, 1)
    val hiddenLayersTopology = Array(5)
    val dataSample = rddData.first()
    val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size
    val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false)
    val initialWeights = FeedForwardModel(topology, 23124).weights()
    val trainer = new FeedForwardTrainer(topology, 2, 1)
    // initialWeights is the starting point of the optimization; the default is the zero vector
    trainer.setWeights(initialWeights)
    trainer.LBFGSOptimizer.setNumIterations(20)
    val model = trainer.train(rddData)
    val predictionAndLabels = rddData.map { case (input, label) =>
      (model.predict(input)(0), label(0))
    }.collect()
    predictionAndLabels.foreach { case (p, l) =>
      assert(math.round(p) === l)
    }
  }

  // An ANN with a softmax output learns the XOR function with a 2-bit output
  // and the batch gradient-descent optimizer
  test("ANN with SoftMax learns XOR function with 2-bit output and batch GD optimizer") {
    val inputs = Array(
      Array(0.0, 0.0),
      Array(0.0, 1.0),
      Array(1.0, 0.0),
      Array(1.0, 1.0)
    )
    val outputs = Array(
      Array(1.0, 0.0),
      Array(0.0, 1.0),
      Array(0.0, 1.0),
      Array(1.0, 0.0)
    )
    val data = inputs.zip(outputs).map { case (features, label) =>
      (Vectors.dense(features), Vectors.dense(label))
    }
    val rddData = sc.parallelize(data, 1)
    val hiddenLayersTopology = Array(5)
    val dataSample = rddData.first()
    val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size
    val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false)
    val initialWeights = FeedForwardModel(topology, 23124).weights()
    val trainer = new FeedForwardTrainer(topology, 2, 2)
    // SGD: stochastic gradient descent
    trainer.SGDOptimizer.setNumIterations(2000)
    // initialWeights is the starting point of the optimization; the default is the zero vector
    trainer.setWeights(initialWeights)
    val model = trainer.train(rddData)
    val predictionAndLabels = rddData.map { case (input, label) =>
      (model.predict(input), label)
    }.collect()
    predictionAndLabels.foreach { case (p, l) =>
      assert(p ~== l absTol 0.5)
    }
  }
}
Example 122
Source File: ChiSqSelectorSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext

// Feature extraction and transformation: chi-squared selection (ChiSqSelector)
// on sparse and dense vectors
// (the class declaration was missing from the excerpt and is restored from the file name)
class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("ChiSqSelector transform test (sparse & dense vector)") {
    // labeled discrete data
    val labeledDiscreteData = sc.parallelize(
      // a LabeledPoint is a local vector, dense or sparse, associated with a label
      Seq(LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
        LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))),
        LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))),
        LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2)
    // the expected data after filtering
    val preFilteredData =
      Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))),
        LabeledPoint(1.0, Vectors.dense(Array(6.0))),
        LabeledPoint(1.0, Vectors.dense(Array(8.0))),
        LabeledPoint(2.0, Vectors.dense(Array(5.0))))
    // fit() trains on the data and returns a transformer (the fitted selector)
    val model = new ChiSqSelector(1).fit(labeledDiscreteData)
    val filteredData = labeledDiscreteData.map { lp =>
      // transform() maps the features to the selected subset
      LabeledPoint(lp.label, model.transform(lp.features))
    }.collect().toSet
    assert(filteredData == preFilteredData)
  }
}
Example 123
Source File: ElementwiseProductSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._

class ElementwiseProductSuite extends SparkFunSuite with MLlibTestSparkContext {

  // the elementwise (Hadamard) product should properly apply the scaling vector to a dense data set
  test("elementwise (hadamard) product should properly apply vector to dense data set") {
    val denseData = Array(
      Vectors.dense(1.0, 4.0, 1.9, -9.0)
    )
    val scalingVec = Vectors.dense(2.0, 0.5, 0.0, 0.25)
    val transformer = new ElementwiseProduct(scalingVec)
    // transform() maps the input RDD to another RDD of scaled vectors
    val transformedData = transformer.transform(sc.makeRDD(denseData))
    val transformedVecs = transformedData.collect()
    val transformedVec = transformedVecs(0)
    val expectedVec = Vectors.dense(2.0, 2.0, 0.0, -2.25)
    assert(transformedVec ~== expectedVec absTol 1E-5,
      s"Expected transformed vector $expectedVec but found $transformedVec")
  }

  // the elementwise (Hadamard) product should properly apply the scaling vector to a sparse data set
  test("elementwise (hadamard) product should properly apply vector to sparse data set") {
    val sparseData = Array(
      Vectors.sparse(3, Seq((1, -1.0), (2, -3.0)))
    )
    val dataRDD = sc.parallelize(sparseData, 3)
    val scalingVec = Vectors.dense(1.0, 0.0, 0.5)
    val transformer = new ElementwiseProduct(scalingVec)
    val data2 = sparseData.map(transformer.transform)
    // the batch transform and the per-vector transform give the same result
    val data2RDD = transformer.transform(dataRDD)
    assert((sparseData, data2, data2RDD.collect()).zipped.forall {
      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
      case _ => false
    }, "The vector type should be preserved after hadamard product")
    assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
    assert(data2(0) ~== Vectors.sparse(3, Seq((1, 0.0), (2, -1.5))) absTol 1E-5)
  }
}
Example 124
Source File: PCASuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.mllib.util.MLlibTestSparkContext

class PCASuite extends SparkFunSuite with MLlibTestSparkContext {

  private val data = Array(
    Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
    Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
    Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
  )

  private lazy val dataRDD = sc.parallelize(data, 2)

  // correct computation using the PCA wrapper
  test("Correct computing use a PCA wrapper") {
    val k = dataRDD.count().toInt
    // fit() trains on the data and returns a transformer (the fitted PCA model)
    val pca = new PCA(k).fit(dataRDD)
    // wrap the data as a distributed row matrix
    val mat = new RowMatrix(dataRDD)
    // compute the principal components, reducing the dimensionality to k
    val pc = mat.computePrincipalComponents(k)
    // apply the PCA transformation; transform() maps the input RDD to another RDD
    val pca_transform = pca.transform(dataRDD).collect()
    // multiply the row matrix by the principal components
    val mat_multiply = mat.multiply(pc).rows.collect()
    assert(pca_transform.toSet === mat_multiply.toSet)
  }
}
Example 125
Source File: HashingTFSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext

// HashingTF hashes terms to feature indices and returns the term frequencies
class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("hashing tf on a single doc") {
    val hashingTF = new HashingTF(1000)
    val doc = "a a b b c d".split(" ")
    val n = hashingTF.numFeatures
    // term frequencies
    val termFreqs = Seq(
      (hashingTF.indexOf("a"), 2.0),
      (hashingTF.indexOf("b"), 2.0),
      (hashingTF.indexOf("c"), 1.0),
      (hashingTF.indexOf("d"), 1.0))
    // termFreqs: Seq[(Int, Double)] = List((97,2.0), (98,2.0), (99,1.0), (100,1.0))
    assert(termFreqs.map(_._1).forall(i => i >= 0 && i < n),
      "index must be in range [0, #features)")
    assert(termFreqs.map(_._1).toSet.size === 4, "expecting perfect hashing")
    val expected = Vectors.sparse(n, termFreqs)
    // transform maps each input document to a term-frequency Vector
    assert(hashingTF.transform(doc) === expected)
  }

  test("hashing tf on an RDD") {
    val hashingTF = new HashingTF
    val localDocs: Seq[Seq[String]] = Seq(
      "a a b b b c d".split(" "),
      "a b c d a b c".split(" "),
      "c b a c b a a".split(" "))
    val docs = sc.parallelize(localDocs, 2)
    assert(hashingTF.transform(docs).collect().toSet === localDocs.map(hashingTF.transform).toSet)
  }
}
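HashingTF usually feeds into IDF to produce TF-IDF weights. A brief sketch of that two-step pipeline on the same kind of tokenized documents, assuming an active SparkContext `sc` (the document contents are arbitrary):

import org.apache.spark.mllib.feature.{HashingTF, IDF}

val docs = sc.parallelize(Seq(
  "a a b b b c d".split(" ").toSeq,
  "a b c d a b c".split(" ").toSeq))
val hashingTF = new HashingTF(1000)
val tf = hashingTF.transform(docs)  // RDD[Vector] of raw term frequencies
tf.cache()                          // IDF makes two passes, so cache the TF vectors
val idfModel = new IDF().fit(tf)    // learn inverse document frequencies
val tfidf = idfModel.transform(tf)  // RDD[Vector] of TF-IDF weights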
Example 126
Source File: ImpuritySuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.tree

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator}
import org.apache.spark.mllib.util.MLlibTestSparkContext

class ImpuritySuite extends SparkFunSuite with MLlibTestSparkContext {

  test("Gini impurity does not support negative labels") {
    val gini = new GiniAggregator(2)
    intercept[IllegalArgumentException] {
      gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
    }
  }

  test("Entropy does not support negative labels") {
    val entropy = new EntropyAggregator(2)
    intercept[IllegalArgumentException] {
      entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
    }
  }
}
Example 127
Source File: BaggedPointSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.tree.impl

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.tree.EnsembleTestHelper
import org.apache.spark.mllib.util.MLlibTestSparkContext

class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("BaggedPoint RDD: without subsampling") {
    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
    val rdd = sc.parallelize(arr)
    val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42)
    baggedRDD.collect().foreach { baggedPoint =>
      assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1)
    }
  }

  test("BaggedPoint RDD: with subsampling with replacement (fraction = 1.0)") {
    val numSubsamples = 100
    val (expectedMean, expectedStddev) = (1.0, 1.0)
    val seeds = Array(123, 5354, 230, 349867, 23987)
    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
    val rdd = sc.parallelize(arr)
    seeds.foreach { seed =>
      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed)
      val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
      // epsilon is the tolerance threshold for the comparison
      EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
        expectedStddev, epsilon = 0.01)
    }
  }

  test("BaggedPoint RDD: with subsampling with replacement (fraction = 0.5)") {
    val numSubsamples = 100
    val subsample = 0.5
    // the standard deviation of a Poisson(0.5) count is math.sqrt(0.5)
    val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample))
    val seeds = Array(123, 5354, 230, 349867, 23987)
    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
    val rdd = sc.parallelize(arr)
    seeds.foreach { seed =>
      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed)
      val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
      EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
        expectedStddev, epsilon = 0.01)
    }
  }

  test("BaggedPoint RDD: with subsampling without replacement (fraction = 1.0)") {
    val numSubsamples = 100
    val (expectedMean, expectedStddev) = (1.0, 0)
    val seeds = Array(123, 5354, 230, 349867, 23987)
    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
    val rdd = sc.parallelize(arr)
    seeds.foreach { seed =>
      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed)
      val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
      EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
        expectedStddev, epsilon = 0.01)
    }
  }

  test("BaggedPoint RDD: with subsampling without replacement (fraction = 0.5)") {
    val numSubsamples = 100
    val subsample = 0.5
    // the standard deviation of a Bernoulli(p) indicator is math.sqrt(p * (1 - p))
    val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample * (1 - subsample)))
    val seeds = Array(123, 5354, 230, 349867, 23987)
    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
    val rdd = sc.parallelize(arr)
    seeds.foreach { seed =>
      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed)
      val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
      EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
        expectedStddev, epsilon = 0.01)
    }
  }
}
Example 128
Source File: MatrixFactorizationModelSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.recommendation

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils

class MatrixFactorizationModelSuite extends SparkFunSuite with MLlibTestSparkContext {

  // NOTE: the excerpt omitted this setup; the factors below are reconstructed so that
  // predict(0, 2) = 1.0 * 5.0 + 2.0 * 6.0 = 17.0 and predict(1, 2) = 3.0 * 5.0 + 4.0 * 6.0 = 39.0,
  // matching the assertions in the tests
  val rank = 2
  var userFeatures: RDD[(Int, Array[Double])] = _
  var prodFeatures: RDD[(Int, Array[Double])] = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    userFeatures = sc.parallelize(Seq((0, Array(1.0, 2.0)), (1, Array(3.0, 4.0))))
    prodFeatures = sc.parallelize(Seq((2, Array(5.0, 6.0))))
    // print the product factors for inspection
    sqlContext.createDataFrame(prodFeatures).show()
  }

  test("constructor") {
    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    // predicted score for (user id, product id)
    println("========" + model.predict(0, 2)) // 17.0
    assert(model.predict(0, 2) ~== 17.0 relTol 1e-14)

    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(1, userFeatures, prodFeatures)
    }
    // user features with the wrong rank
    val userFeatures1 = sc.parallelize(Seq((0, Array(1.0)), (1, Array(3.0))))
    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(rank, userFeatures1, prodFeatures)
    }
    // product features with the wrong rank
    val prodFeatures1 = sc.parallelize(Seq((2, Array(5.0))))
    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(rank, userFeatures, prodFeatures1)
    }
  }

  test("save/load") {
    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    val tempDir = Utils.createTempDir()
    val path = tempDir.toURI.toString
    def collect(features: RDD[(Int, Array[Double])]): Set[(Int, Seq[Double])] = {
      features.mapValues(_.toSeq).collect().toSet
    }
    try {
      model.save(sc, path)
      val newModel = MatrixFactorizationModel.load(sc, path)
      assert(newModel.rank === rank)
      // user features
      assert(collect(newModel.userFeatures) === collect(userFeatures))
      // product features
      assert(collect(newModel.productFeatures) === collect(prodFeatures))
    } finally {
      Utils.deleteRecursively(tempDir)
    }
  }

  test("batch predict API recommendProductsForUsers") {
    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    val topK = 10
    // recommend up to topK products for every user
    val recommendations = model.recommendProductsForUsers(topK).collectAsMap()
    assert(recommendations(0)(0).rating ~== 17.0 relTol 1e-14)
    assert(recommendations(1)(0).rating ~== 39.0 relTol 1e-14)
  }

  test("batch predict API recommendUsersForProducts") {
    // userFeatures are the user factors and prodFeatures the product factors; rank is the
    // number of latent factors (more is usually better, commonly 10 to 200)
    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    val topK = 10
    // recommend up to topK users for every product
    val recommendations = model.recommendUsersForProducts(topK).collectAsMap()
    assert(recommendations(2)(0).user == 1)
    assert(recommendations(2)(0).rating ~== 39.0 relTol 1e-14)
    assert(recommendations(2)(1).user == 0)
    assert(recommendations(2)(1).rating ~== 17.0 relTol 1e-14)
  }
}
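Outside of tests, a MatrixFactorizationModel is normally produced by ALS rather than built from hand-written factors. A small sketch, assuming an active SparkContext `sc` (the ratings and hyperparameters are made up):

import org.apache.spark.mllib.recommendation.{ALS, Rating}

val ratings = sc.parallelize(Seq(
  Rating(0, 2, 3.0), Rating(1, 2, 4.0), Rating(0, 3, 1.0)))
// train with rank 2, 10 iterations and regularization 0.01
val model = ALS.train(ratings, 2, 10, 0.01)
val score = model.predict(0, 2)           // predicted rating for user 0, product 2
val top = model.recommendProducts(0, 5)   // top 5 products for user 0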
Example 129
Source File: AreaUnderCurveSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.evaluation

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._

class AreaUnderCurveSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("auc computation") {
    val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
    val auc = 4.0
    // 1E-5 means 1 times 10 to the power of -5, i.e. 0.00001
    assert(AreaUnderCurve.of(curve) ~== auc absTol 1E-5)
    val rddCurve = sc.parallelize(curve, 2)
    assert(AreaUnderCurve.of(rddCurve) ~== auc absTol 1E-5)
  }

  test("auc of an empty curve") {
    val curve = Seq.empty[(Double, Double)]
    assert(AreaUnderCurve.of(curve) ~== 0.0 absTol 1E-5)
    val rddCurve = sc.parallelize(curve, 2)
    assert(AreaUnderCurve.of(rddCurve) ~== 0.0 absTol 1E-5)
  }

  test("auc of a curve with a single point") {
    val curve = Seq((1.0, 1.0))
    assert(AreaUnderCurve.of(curve) ~== 0.0 absTol 1E-5)
    val rddCurve = sc.parallelize(curve, 2)
    assert(AreaUnderCurve.of(rddCurve) ~== 0.0 absTol 1E-5)
  }
}
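AreaUnderCurve is an internal helper; the public entry point for AUC is BinaryClassificationMetrics, which builds the ROC and precision-recall curves from (score, label) pairs. A short sketch with made-up scores, assuming an active SparkContext `sc`:

import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

val scoreAndLabels = sc.parallelize(Seq(
  (0.9, 1.0), (0.8, 1.0), (0.4, 0.0), (0.2, 1.0), (0.1, 0.0)))
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
val auROC = metrics.areaUnderROC() // area under the ROC curve
val auPR = metrics.areaUnderPR()   // area under the precision-recall curve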
Example 130
Source File: PythonMLLibAPISuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.api.python import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.recommendation.Rating class PythonMLLibAPISuite extends SparkFunSuite { SerDe.initialize() test("pickle vector") { val vectors = Seq( Vectors.dense(Array.empty[Double]), Vectors.dense(0.0), Vectors.dense(0.0, -2.0), Vectors.sparse(0, Array.empty[Int], Array.empty[Double]), Vectors.sparse(1, Array.empty[Int], Array.empty[Double]), Vectors.sparse(2, Array(1), Array(-2.0))) vectors.foreach { v => val u = SerDe.loads(SerDe.dumps(v)) assert(u.getClass === v.getClass) assert(u === v) } } test("pickle labeled point") { val points = Seq( LabeledPoint(0.0, Vectors.dense(Array.empty[Double])), LabeledPoint(1.0, Vectors.dense(0.0)), LabeledPoint(-0.5, Vectors.dense(0.0, -2.0)), LabeledPoint(0.0, Vectors.sparse(0, Array.empty[Int], Array.empty[Double])), LabeledPoint(1.0, Vectors.sparse(1, Array.empty[Int], Array.empty[Double])), LabeledPoint(-0.5, Vectors.sparse(2, Array(1), Array(-2.0)))) points.foreach { p => val q = SerDe.loads(SerDe.dumps(p)).asInstanceOf[LabeledPoint] assert(q.label === p.label) assert(q.features.getClass === p.features.getClass) assert(q.features === p.features) } } test("pickle double") { for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue, Double.NaN)) { val deser = SerDe.loads(SerDe.dumps(x.asInstanceOf[AnyRef])).asInstanceOf[Double] // We use `equals` here for comparison because we cannot use `==` for NaN assert(x.equals(deser)) } } test("pickle matrix") { val values = Array[Double](0, 1.2, 3, 4.56, 7, 8) val matrix = Matrices.dense(2, 3, values) val nm = SerDe.loads(SerDe.dumps(matrix)).asInstanceOf[DenseMatrix] assert(matrix === nm) // Test conversion for empty matrix val empty = Array[Double]() val emptyMatrix = Matrices.dense(0, 0, empty) val ne = SerDe.loads(SerDe.dumps(emptyMatrix)).asInstanceOf[DenseMatrix] assert(emptyMatrix == ne) val sm = new SparseMatrix(3, 2, Array(0, 1, 3), Array(1, 0, 2), Array(0.9, 1.2, 3.4)) val nsm = SerDe.loads(SerDe.dumps(sm)).asInstanceOf[SparseMatrix] assert(sm.toArray === nsm.toArray) val smt = new SparseMatrix( 3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9), isTransposed = true) val nsmt = SerDe.loads(SerDe.dumps(smt)).asInstanceOf[SparseMatrix] assert(smt.toArray === nsmt.toArray) } test("pickle rating") { val rat = new Rating(1, 2, 3.0) val rat2 = SerDe.loads(SerDe.dumps(rat)).asInstanceOf[Rating] assert(rat == rat2) // Test name of class only occur once val rats = (1 to 10).map(x => new Rating(x, x + 1, x + 3.0)).toArray val bytes = SerDe.dumps(rats) assert(bytes.toString.split("Rating").length == 1) assert(bytes.length / 10 < 25) // 25 bytes per rating } }
Example 131
Source File: FPTreeSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.fpm import scala.language.existentials import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("add transaction") {//增加转换 val tree = new FPTree[String] .add(Seq("a", "b", "c")) .add(Seq("a", "b", "y")) .add(Seq("b")) assert(tree.root.children.size == 2) assert(tree.root.children.contains("a")) assert(tree.root.children("a").item.equals("a")) assert(tree.root.children("a").count == 2) assert(tree.root.children.contains("b")) assert(tree.root.children("b").item.equals("b")) assert(tree.root.children("b").count == 1) var child = tree.root.children("a") assert(child.children.size == 1) assert(child.children.contains("b")) assert(child.children("b").item.equals("b")) assert(child.children("b").count == 2) child = child.children("b") assert(child.children.size == 2) assert(child.children.contains("c")) assert(child.children.contains("y")) assert(child.children("c").item.equals("c")) assert(child.children("y").item.equals("y")) assert(child.children("c").count == 1) assert(child.children("y").count == 1) } test("merge tree") {//合并树 val tree1 = new FPTree[String] .add(Seq("a", "b", "c")) .add(Seq("a", "b", "y")) .add(Seq("b")) val tree2 = new FPTree[String] .add(Seq("a", "b")) .add(Seq("a", "b", "c")) .add(Seq("a", "b", "c", "d")) .add(Seq("a", "x")) .add(Seq("a", "x", "y")) .add(Seq("c", "n")) .add(Seq("c", "m")) val tree3 = tree1.merge(tree2) assert(tree3.root.children.size == 3) assert(tree3.root.children("a").count == 7) assert(tree3.root.children("b").count == 1) assert(tree3.root.children("c").count == 2) val child1 = tree3.root.children("a") assert(child1.children.size == 2) assert(child1.children("b").count == 5) assert(child1.children("x").count == 2) val child2 = child1.children("b") assert(child2.children.size == 2) assert(child2.children("y").count == 1) assert(child2.children("c").count == 3) val child3 = child2.children("c") assert(child3.children.size == 1) assert(child3.children("d").count == 1) val child4 = child1.children("x") assert(child4.children.size == 1) assert(child4.children("y").count == 1) val child5 = tree3.root.children("c") assert(child5.children.size == 2) assert(child5.children("n").count == 1) assert(child5.children("m").count == 1) } test("extract freq itemsets") {//频繁项集的提取物 val tree = new FPTree[String] .add(Seq("a", "b", "c")) .add(Seq("a", "b", "y")) .add(Seq("a", "b")) .add(Seq("a")) .add(Seq("b")) .add(Seq("b", "n")) val freqItemsets = tree.extract(3L).map { case (items, count) => (items.toSet, count) }.toSet val expected = Set( (Set("a"), 4L), (Set("b"), 5L), (Set("a", "b"), 3L)) assert(freqItemsets === expected) } }
Example 132
Source File: AssociationRulesSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.fpm import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext // frequent pattern mining - association rules class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext { test("association rules using String type") { // association rules over String-typed items val freqItemsets = sc.parallelize(Seq( // frequent itemsets (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L), (Set("r"), 3L), (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L), (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L), (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L), (Set("t", "y", "x"), 3L), (Set("t", "y", "x", "z"), 3L) ).map { case (items, freq) => new FPGrowth.FreqItemset(items.toArray, freq) }) val ar = new AssociationRules() val results1 = ar .setMinConfidence(0.9) .run(freqItemsets) .collect() val results2 = ar .setMinConfidence(0) .run(freqItemsets) .collect() assert(results2.size === 30) // math.abs returns the absolute value assert(results2.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23) } }
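The frequent itemsets fed to AssociationRules above are built by hand; in practice they usually come from FPGrowth, whose model can also generate the rules directly (Spark 1.5+). A small illustrative sketch with toy transactions, assuming a live SparkContext:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.fpm.FPGrowth

def fpGrowthSketch(sc: SparkContext): Unit = {
  val transactions = sc.parallelize(Seq(
    Array("a", "b", "c"), Array("a", "b"), Array("a", "c"), Array("b", "c")))
  val model = new FPGrowth().setMinSupport(0.5).setNumPartitions(2).run(transactions)
  // rules whose confidence is at least 0.6
  model.generateAssociationRules(0.6).collect().foreach { rule =>
    println(s"${rule.antecedent.mkString(",")} => ${rule.consequent.mkString(",")} " +
      s"(confidence ${rule.confidence})")
  }
}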
Example 133
Source File: KernelDensitySuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.stat import org.apache.commons.math3.distribution.NormalDistribution import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext { test("kernel density single sample") { // kernel density from a single sample val rdd = sc.parallelize(Array(5.0)) val evaluationPoints = Array(5.0, 6.0) val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints) val normal = new NormalDistribution(5.0, 3.0) val acceptableErr = 1e-6 // math.abs returns the absolute value assert(math.abs(densities(0) - normal.density(5.0)) < acceptableErr) assert(math.abs(densities(1) - normal.density(6.0)) < acceptableErr) } test("kernel density multiple samples") { // kernel density from multiple samples val rdd = sc.parallelize(Array(5.0, 10.0)) val evaluationPoints = Array(5.0, 6.0) val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints) val normal1 = new NormalDistribution(5.0, 3.0) val normal2 = new NormalDistribution(10.0, 3.0) val acceptableErr = 1e-6 // math.abs returns the absolute value assert(math.abs( densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2) < acceptableErr) assert(math.abs( densities(1) - (normal1.density(6.0) + normal2.density(6.0)) / 2) < acceptableErr) } }
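KernelDensity is a public API, so the same call pattern applies to any RDD[Double] sample outside the test harness. A minimal sketch with toy data, assuming a live SparkContext:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.stat.KernelDensity

def kdeSketch(sc: SparkContext): Unit = {
  val sample = sc.parallelize(Seq(1.0, 1.5, 2.0, 5.0, 5.5))
  val kd = new KernelDensity().setSample(sample).setBandwidth(1.0)
  // Gaussian-kernel density estimates at the requested evaluation points
  val densities = kd.estimate(Array(1.0, 3.0, 5.0))
  println(densities.mkString(", "))
}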
Example 134
Source File: MultivariateGaussianSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.stat.distribution import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{ Vectors, Matrices } import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext { test("univariate") {//单变量 val x1 = Vectors.dense(0.0) val x2 = Vectors.dense(1.5) val mu = Vectors.dense(0.0) //密集矩阵 val sigma1 = Matrices.dense(1, 1, Array(1.0)) //多元高斯 val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5) val sigma2 = Matrices.dense(1, 1, Array(4.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5) } test("multivariate") {//多变量 val x1 = Vectors.dense(0.0, 0.0)//创建密集向量 val x2 = Vectors.dense(1.0, 1.0)//创建密集向量 val mu = Vectors.dense(0.0, 0.0)//创建密集向量 val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0)) val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5) val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5) } test("multivariate degenerate") {//多元退化 val x1 = Vectors.dense(0.0, 0.0) val x2 = Vectors.dense(1.0, 1.0) val mu = Vectors.dense(0.0, 0.0) val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0)) val dist = new MultivariateGaussian(mu, sigma) assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5) assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5) } test("SPARK-11302") { val x = Vectors.dense(629, 640, 1.7188, 618.19) val mu = Vectors.dense( 1055.3910505836575, 1070.489299610895, 1.39020554474708, 1040.5907503867697) val sigma = Matrices.dense(4, 4, Array( 166769.00466698944, 169336.6705268059, 12.820670788921873, 164243.93314092053, 169336.6705268059, 172041.5670061245, 21.62590020524533, 166678.01075856484, 12.820670788921873, 21.62590020524533, 0.872524191943962, 4.283255814732373, 164243.93314092053, 166678.01075856484, 4.283255814732373, 161848.9196719207)) val dist = new MultivariateGaussian(mu, sigma) // Agrees with R's dmvnorm: 7.154782e-05 assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9) } }
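MultivariateGaussian can be used directly; the first assertion above is the standard normal density at 0, that is 1/sqrt(2*pi), roughly 0.39894. A minimal sketch:

import org.apache.spark.mllib.linalg.{Matrices, Vectors}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian

// The standard univariate normal expressed as a 1-dimensional multivariate Gaussian.
val dist = new MultivariateGaussian(Vectors.dense(0.0), Matrices.dense(1, 1, Array(1.0)))
println(dist.pdf(Vectors.dense(0.0)))     // ~0.39894
println(dist.logpdf(Vectors.dense(0.0)))  // log of the value above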
Example 135
Source File: KMeansPMMLModelExportSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.pmml.export import org.dmg.pmml.ClusteringModel import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors class KMeansPMMLModelExportSuite extends SparkFunSuite { test("KMeansPMMLModelExport generate PMML format") { val clusterCenters = Array( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0)) val kmeansModel = new KMeansModel(clusterCenters) val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel) // assert that the PMML format is as expected assert(modelExport.isInstanceOf[PMMLModelExport]) val pmml = modelExport.asInstanceOf[PMMLModelExport].getPmml assert(pmml.getHeader.getDescription === "k-means clustering") // check that the number of fields matches the single vector size // clusterCenters are the cluster centers assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size) // This verifies that there is a model attached to the pmml object and that the model is a clustering // one. It also verifies that the pmml model has the same number of clusters as the spark model. val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel] assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length) } }
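User code does not normally call PMMLModelExportFactory itself; the export is reached through the PMMLExportable trait mixed into KMeansModel and the linear models. A hedged sketch, assuming the no-argument toPMML() overload that returns the PMML XML as a String:

import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors

// Assumes KMeansModel exposes toPMML() via PMMLExportable (Spark 1.4+).
val model = new KMeansModel(Array(Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0)))
val pmmlXml: String = model.toPMML()
println(pmmlXml.take(200))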
Example 136
Source File: PMMLModelExportFactorySuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.pmml.export import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel} import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} import org.apache.spark.mllib.util.LinearDataGenerator class PMMLModelExportFactorySuite extends SparkFunSuite { test("PMMLModelExportFactory throw IllegalArgumentException when passing a multinomial logistic regression model") { val multiclassLogisticRegressionModel = new LogisticRegressionModel( weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, // numClasses: the number of classes numFeatures = 2, numClasses = 3) intercept[IllegalArgumentException] { PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel) } } test("PMMLModelExportFactory throw IllegalArgumentException when passing an unsupported model") { val invalidModel = new Object intercept[IllegalArgumentException] { PMMLModelExportFactory.createPMMLModelExport(invalidModel) } } }
Example 137
Source File: NumericParserSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.util import org.apache.spark.{SparkException, SparkFunSuite} class NumericParserSuite extends SparkFunSuite { test("parser") { // parsing val s = "((1.0,2e3),-4,[5e-6,7.0E8],+9)" val parsed = NumericParser.parse(s).asInstanceOf[Seq[_]] assert(parsed(0).asInstanceOf[Seq[_]] === Seq(1.0, 2.0e3)) assert(parsed(1).asInstanceOf[Double] === -4.0) assert(parsed(2).asInstanceOf[Array[Double]] === Array(5.0e-6, 7.0e8)) assert(parsed(3).asInstanceOf[Double] === 9.0) val malformatted = Seq("a", "[1,,]", "0.123.4", "1 2", "3+4") malformatted.foreach { s => intercept[SparkException] { NumericParser.parse(s) throw new RuntimeException(s"Didn't detect malformatted string $s.") } } } test("parser with whitespaces") { // parsing with whitespace val s = "(0.0, [1.0, 2.0])" // numeric parsing val parsed = NumericParser.parse(s).asInstanceOf[Seq[_]] assert(parsed(0).asInstanceOf[Double] === 0.0) assert(parsed(1).asInstanceOf[Array[Double]] === Array(1.0, 2.0)) } }
Example 138
Source File: BreezeMatrixConversionSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.linalg import breeze.linalg.{DenseMatrix => BDM, CSCMatrix => BSM} import org.apache.spark.SparkFunSuite class BreezeMatrixConversionSuite extends SparkFunSuite { test("dense matrix to breeze") {//稠密矩阵转换breeze矩阵 val mat = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) val breeze = mat.toBreeze.asInstanceOf[BDM[Double]] assert(breeze.rows === mat.numRows) assert(breeze.cols === mat.numCols) assert(breeze.data.eq(mat.asInstanceOf[DenseMatrix].values), "should not copy data") } test("dense breeze matrix to matrix") {//稠密breeze矩阵转换矩阵 val breeze = new BDM[Double](3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) val mat = Matrices.fromBreeze(breeze).asInstanceOf[DenseMatrix] assert(mat.numRows === breeze.rows) assert(mat.numCols === breeze.cols) assert(mat.values.eq(breeze.data), "should not copy data") // transposed matrix 转置矩阵 val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[DenseMatrix] assert(matTransposed.numRows === breeze.cols) assert(matTransposed.numCols === breeze.rows) assert(matTransposed.values.eq(breeze.data), "should not copy data") } test("sparse matrix to breeze") {//稀疏矩阵转换breeze矩阵 val values = Array(1.0, 2.0, 4.0, 5.0) val colPtrs = Array(0, 2, 4) val rowIndices = Array(1, 2, 1, 2) val mat = Matrices.sparse(3, 2, colPtrs, rowIndices, values) val breeze = mat.toBreeze.asInstanceOf[BSM[Double]] assert(breeze.rows === mat.numRows) assert(breeze.cols === mat.numCols) assert(breeze.data.eq(mat.asInstanceOf[SparseMatrix].values), "should not copy data") } test("sparse breeze matrix to sparse matrix") {//稀疏breeze矩阵转换稀疏矩阵 val values = Array(1.0, 2.0, 4.0, 5.0) val colPtrs = Array(0, 2, 4) val rowIndices = Array(1, 2, 1, 2) val breeze = new BSM[Double](values, 3, 2, colPtrs, rowIndices) val mat = Matrices.fromBreeze(breeze).asInstanceOf[SparseMatrix] assert(mat.numRows === breeze.rows) assert(mat.numCols === breeze.cols) assert(mat.values.eq(breeze.data), "should not copy data") val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[SparseMatrix] assert(matTransposed.numRows === breeze.cols) assert(matTransposed.numCols === breeze.rows) assert(!matTransposed.values.eq(breeze.data), "has to copy data") } }
Example 139
Source File: CoordinateMatrixSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.linalg.Vectors val blockMat = mat.toBlockMatrix(2, 2) assert(blockMat.numRows() === m) assert(blockMat.numCols() === n) assert(blockMat.toBreeze() === mat.toBreeze()) intercept[IllegalArgumentException] { mat.toBlockMatrix(-1, 2) } intercept[IllegalArgumentException] { mat.toBlockMatrix(2, 0) } } }
Example 140
Source File: BreezeVectorConversionSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.linalg import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} import org.apache.spark.SparkFunSuite class BreezeVectorConversionSuite extends SparkFunSuite { val arr = Array(0.1, 0.2, 0.3, 0.4) val n = 20 val indices = Array(0, 3, 5, 10, 13) val values = Array(0.1, 0.5, 0.3, -0.8, -1.0) test("dense to breeze") {//密集矩阵转换breeze矩阵 val vec = Vectors.dense(arr) assert(vec.toBreeze === new BDV[Double](arr)) } test("sparse to breeze") {//稀疏矩阵转换breeze矩阵 val vec = Vectors.sparse(n, indices, values) assert(vec.toBreeze === new BSV[Double](indices, values, n)) } test("dense breeze to vector") {//密集breeze矩阵转换向量 val breeze = new BDV[Double](arr) val vec = Vectors.fromBreeze(breeze).asInstanceOf[DenseVector] assert(vec.size === arr.length) assert(vec.values.eq(arr), "should not copy data") } test("sparse breeze to vector") {//稀疏breeze矩阵转换向量 val breeze = new BSV[Double](indices, values, n) val vec = Vectors.fromBreeze(breeze).asInstanceOf[SparseVector] assert(vec.size === n) assert(vec.indices.eq(indices), "should not copy data") assert(vec.values.eq(values), "should not copy data") } test("sparse breeze with partially-used arrays to vector") { val activeSize = 3 val breeze = new BSV[Double](indices, values, activeSize, n) val vec = Vectors.fromBreeze(breeze).asInstanceOf[SparseVector] assert(vec.size === n) assert(vec.indices === indices.slice(0, activeSize)) assert(vec.values === values.slice(0, activeSize)) } }
Example 141
Source File: LabeledPointSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.regression import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors class LabeledPointSuite extends SparkFunSuite { test("parse labeled points") { val points = Seq( LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), LabeledPoint(0.0, Vectors.sparse(2, Array(1), Array(-1.0)))) points.foreach { p => println(p + "||||" + LabeledPoint.parse(p.toString)) assert(p === LabeledPoint.parse(p.toString)) } } test("parse labeled points with whitespaces") { // parse labeled points containing whitespace // a LabeledPoint pairs a label with a local vector, which may be dense or sparse val point = LabeledPoint.parse("(0.0, [1.0, 2.0])") assert(point === LabeledPoint(0.0, Vectors.dense(1.0, 2.0))) } test("parse labeled points with v0.9 format") { // the old v0.9 format: label, comma, then space-separated dense values val point = LabeledPoint.parse("1.0,1.0 0.0 -2.0") assert(point === LabeledPoint(1.0, Vectors.dense(1.0, 0.0, -2.0))) } }
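LabeledPoint.parse is the public inverse of LabeledPoint.toString, so points can be round-tripped through their text form. A minimal sketch:

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val p = LabeledPoint(1.0, Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0)))
val s = p.toString                    // e.g. "(1.0,(3,[0,2],[1.0,-2.0]))"
assert(LabeledPoint.parse(s) == p)    // the round trip recovers the original point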
Example 142
Source File: MLPairRDDFunctionsSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.rdd import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ class MLPairRDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("topByKey") { val topMap = sc.parallelize(Array((1, 7), (1, 3), (1, 6), (1, 1), (1, 2), (3, 2), (3, 7), (5, 1), (3, 5)), 2) .topByKey(5) // collect the top values per key into a Map .collectAsMap() assert(topMap.size === 3) assert(topMap(1) === Array(7, 6, 3, 2, 1)) assert(topMap(3) === Array(7, 5, 2)) assert(topMap(5) === Array(1)) } }
Example 143
Source File: RDDFunctionsSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.mllib.rdd import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.RDDFunctions._ class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("sliding") { // sliding windows over an RDD val data = 0 until 6 for (numPartitions <- 1 to 8) { val rdd = sc.parallelize(data, numPartitions) for (windowSize <- 1 to 6) { val sliding = rdd.sliding(windowSize).collect().map(_.toList).toList val expected = data.sliding(windowSize).map(_.toList).toList assert(sliding === expected) } assert(rdd.sliding(7).collect().isEmpty, "Should return an empty RDD if the window size is greater than the number of items.") } } test("sliding with empty partitions") { // sliding windows across empty partitions val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7)) // Array(1, 2, 3, 4, 5, 6, 7) val rdd = sc.parallelize(data, data.length).flatMap(s => s) //data.length = 5 assert(rdd.partitions.size === data.length) // apply a sliding window of size 3 over the data val sliding = rdd.sliding(3).collect().toSeq.map(_.toSeq) //expected: Seq[Seq[Int]] = Stream(List(1, 2, 3), ?) val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq) assert(sliding === expected) } }
Example 144
Source File: TwitterStreamSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.streaming.twitter import org.scalatest.BeforeAndAfter import twitter4j.Status import twitter4j.auth.{NullAuthorization, Authorization} import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream class TwitterStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { val batchDuration = Seconds(1) private val master: String = "local[2]" private val framework: String = this.getClass.getSimpleName test("twitter input stream") { val ssc = new StreamingContext(master, framework, batchDuration) val filters = Seq("filter1", "filter2") val authorization: Authorization = NullAuthorization.getInstance() // tests the API, does not actually test data receiving val test1: ReceiverInputDStream[Status] = TwitterUtils.createStream(ssc, None) val test2: ReceiverInputDStream[Status] = TwitterUtils.createStream(ssc, None, filters) val test3: ReceiverInputDStream[Status] = TwitterUtils.createStream(ssc, None, filters, StorageLevel.MEMORY_AND_DISK_SER_2) val test4: ReceiverInputDStream[Status] = TwitterUtils.createStream(ssc, Some(authorization)) val test5: ReceiverInputDStream[Status] = TwitterUtils.createStream(ssc, Some(authorization), filters) val test6: ReceiverInputDStream[Status] = TwitterUtils.createStream( ssc, Some(authorization), filters, StorageLevel.MEMORY_AND_DISK_SER_2) // Note that actually testing the data receiving is hard as authentication keys are // necessary for accessing Twitter live stream ssc.stop() } }
Example 145
Source File: FlumePollingStreamSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.streaming.flume import java.net.InetSocketAddress import scala.collection.JavaConversions._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.base.Charsets.UTF_8 import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} import org.apache.spark.util.{ManualClock, Utils} private def testMultipleTimes(test: () => Unit): Unit = { var testPassed = false var attempt = 0 while (!testPassed && attempt < maxAttempts) { try { test() testPassed = true } catch { case e: Exception if Utils.isBindCollision(e) => logWarning("Exception when running flume polling test: " + e) attempt += 1 } } assert(testPassed, s"Test failed after $attempt attempts!") } private def testFlumePolling(): Unit = { try { val port = utils.startSingleSink() writeAndVerify(Seq(port)) utils.assertChannelsAreEmpty() } finally { utils.close() } } private def testFlumePollingMultipleHost(): Unit = { try { val ports = utils.startMultipleSinks() writeAndVerify(ports) utils.assertChannelsAreEmpty() } finally { utils.close() } } def writeAndVerify(sinkPorts: Seq[Int]): Unit = { // Set up the streaming context and input streams //设置流上下文和输入流 val ssc = new StreamingContext(conf, batchDuration) val addresses = sinkPorts.map(port => new InetSocketAddress("localhost", port)) val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, utils.eventsPerBatch, 5) val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] with SynchronizedBuffer[Seq[SparkFlumeEvent]] val outputStream = new TestOutputStream(flumeStream, outputBuffer) outputStream.register() ssc.start() try { utils.sendDatAndEnsureAllDataHasBeenReceived() val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] clock.advance(batchDuration.milliseconds) // The eventually is required to ensure that all data in the batch has been processed. //最终需要确保批处理中的所有数据已被处理 eventually(timeout(10 seconds), interval(100 milliseconds)) { val flattenOutputBuffer = outputBuffer.flatten val headers = flattenOutputBuffer.map(_.event.getHeaders.map { case kv => (kv._1.toString, kv._2.toString) }).map(mapAsJavaMap) val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8)) utils.assertOutput(headers, bodies) } } finally { ssc.stop() } } }
Example 146
Source File: FlumeStreamSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.streaming.flume import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.base.Charsets import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory import org.jboss.netty.handler.codec.compression._ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { override def newChannel(pipeline: ChannelPipeline): SocketChannel = { val encoder = new ZlibEncoder(compressionLevel) pipeline.addFirst("deflater", encoder) pipeline.addFirst("inflater", new ZlibDecoder()) super.newChannel(pipeline) } } }
Example 147
Source File: ZeroMQStreamSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.streaming.zeromq import akka.actor.SupervisorStrategy import akka.util.ByteString import akka.zeromq.Subscribe import org.apache.spark.SparkFunSuite import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream class ZeroMQStreamSuite extends SparkFunSuite { val batchDuration = Seconds(1) private val master: String = "local[2]" private val framework: String = this.getClass.getSimpleName test("zeromq input stream") { val ssc = new StreamingContext(master, framework, batchDuration) val publishUrl = "abc" val subscribe = new Subscribe(null.asInstanceOf[ByteString]) val bytesToObjects = (bytes: Seq[ByteString]) => null.asInstanceOf[Iterator[String]] // tests the API, does not actually test data receiving val test1: ReceiverInputDStream[String] = ZeroMQUtils.createStream(ssc, publishUrl, subscribe, bytesToObjects) val test2: ReceiverInputDStream[String] = ZeroMQUtils.createStream( ssc, publishUrl, subscribe, bytesToObjects, StorageLevel.MEMORY_AND_DISK_SER_2) val test3: ReceiverInputDStream[String] = ZeroMQUtils.createStream( ssc, publishUrl, subscribe, bytesToObjects, StorageLevel.MEMORY_AND_DISK_SER_2, SupervisorStrategy.defaultStrategy) // TODO: Actually test data receiving ssc.stop() } }
Example 148
Source File: MQTTStreamSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.streaming.mqtt import scala.concurrent.duration._ import scala.language.postfixOps import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter { private val batchDuration = Milliseconds(500) private val master = "local[2]" private val framework = this.getClass.getSimpleName private val topic = "def" private var ssc: StreamingContext = _ private var mqttTestUtils: MQTTTestUtils = _ before { ssc = new StreamingContext(master, framework, batchDuration) mqttTestUtils = new MQTTTestUtils mqttTestUtils.setup() } after { if (ssc != null) { ssc.stop() ssc = null } if (mqttTestUtils != null) { mqttTestUtils.teardown() mqttTestUtils = null } } test("mqtt input stream") { val sendMessage = "MQTT demo for spark streaming" val receiveStream = MQTTUtils.createStream(ssc, "tcp://" + mqttTestUtils.brokerUri, topic, StorageLevel.MEMORY_ONLY) @volatile var receiveMessage: List[String] = List() receiveStream.foreachRDD { rdd => if (rdd.collect.length > 0) { receiveMessage = receiveMessage ::: List(rdd.first) receiveMessage } } ssc.start() // Retry it because we don't know when the receiver will start. eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { mqttTestUtils.publishData(topic, sendMessage) assert(sendMessage.equals(receiveMessage(0))) } ssc.stop() } }
Example 149
Source File: KafkaStreamSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.streaming.kafka import scala.collection.mutable import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random import kafka.serializer.StringDecoder import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll { private var ssc: StreamingContext = _ private var kafkaTestUtils: KafkaTestUtils = _ override def beforeAll(): Unit = { kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() } override def afterAll(): Unit = { if (ssc != null) { ssc.stop() ssc = null } if (kafkaTestUtils != null) { kafkaTestUtils.teardown() kafkaTestUtils = null } } test("Kafka input stream") {//Kafka输入流 val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName) ssc = new StreamingContext(sparkConf, Milliseconds(500)) val topic = "topic1" val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) kafkaTestUtils.createTopic(topic) kafkaTestUtils.sendMessages(topic, sent) val kafkaParams = Map("zookeeper.connect" -> kafkaTestUtils.zkAddress, "group.id" -> s"test-consumer-${Random.nextInt(10000)}", "auto.offset.reset" -> "smallest") val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY) val result = new mutable.HashMap[String, Long]() with mutable.SynchronizedMap[String, Long] stream.map(_._2).countByValue().foreachRDD { r => val ret = r.collect() ret.toMap.foreach { kv => val count = result.getOrElseUpdate(kv._1, 0) + kv._2 result.put(kv._1, count) } } ssc.start() eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { assert(sent === result) } } }
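For reference, the entry point used above is KafkaUtils.createStream, whose simplest overload takes a ZooKeeper quorum, a consumer group id and a map from topic to receiver thread count. A sketch with placeholder addresses and topic names (not runnable without a Kafka/ZooKeeper setup):

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka.KafkaUtils

val conf = new SparkConf().setMaster("local[4]").setAppName("kafka-sketch")
val ssc = new StreamingContext(conf, Seconds(2))
// "localhost:2181", "my-group" and "topic1" are placeholder values
val messages = KafkaUtils.createStream(ssc, "localhost:2181", "my-group", Map("topic1" -> 1))
messages.map(_._2).print()   // the stream of message values
// ssc.start(); ssc.awaitTermination()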
Example 150
Source File: KafkaClusterSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.streaming.kafka import scala.util.Random import kafka.common.TopicAndPartition import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite class KafkaClusterSuite extends SparkFunSuite with BeforeAndAfterAll { private val topic = "kcsuitetopic" + Random.nextInt(10000) private val topicAndPartition = TopicAndPartition(topic, 0) // private var kc: KafkaCluster = null private var kafkaTestUtils: KafkaTestUtils = _ override def beforeAll() { kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() kafkaTestUtils.createTopic(topic) kafkaTestUtils.sendMessages(topic, Map("a" -> 1)) kc = new KafkaCluster(Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress)) } override def afterAll() { if (kafkaTestUtils != null) { kafkaTestUtils.teardown() kafkaTestUtils = null } } test("metadata apis") {//元数据API val leader = kc.findLeaders(Set(topicAndPartition)).right.get(topicAndPartition) val leaderAddress = s"${leader._1}:${leader._2}" assert(leaderAddress === kafkaTestUtils.brokerAddress, "didn't get leader") val parts = kc.getPartitions(Set(topic)).right.get assert(parts(topicAndPartition), "didn't get partitions") val err = kc.getPartitions(Set(topic + "BAD")) assert(err.isLeft, "getPartitions for a nonexistant topic should be an error") } test("leader offset apis") {//指挥者偏移API val earliest = kc.getEarliestLeaderOffsets(Set(topicAndPartition)).right.get assert(earliest(topicAndPartition).offset === 0, "didn't get earliest") val latest = kc.getLatestLeaderOffsets(Set(topicAndPartition)).right.get assert(latest(topicAndPartition).offset === 1, "didn't get latest") } test("consumer offset apis") {//消费者偏移API val group = "kcsuitegroup" + Random.nextInt(10000) val offset = Random.nextInt(10000) val set = kc.setConsumerOffsets(group, Map(topicAndPartition -> offset)) assert(set.isRight, "didn't set consumer offsets") val get = kc.getConsumerOffsets(group, Set(topicAndPartition)).right.get assert(get(topicAndPartition) === offset, "didn't get consumer offsets") } }
Example 151
Source File: OrcTest.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.hive.orc import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.test.SQLTestUtils private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite => protected override def _sqlContext: SQLContext = org.apache.spark.sql.hive.test.TestHive protected val sqlContext = _sqlContext import sqlContext.implicits._ import sqlContext.sparkContext protected def withOrcTable[T <: Product: ClassTag: TypeTag] (data: Seq[T], tableName: String) (f: => Unit): Unit = { withOrcDataFrame(data) { df => sqlContext.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } } protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { data.toDF().write.mode(SaveMode.Overwrite).orc(path.getCanonicalPath) } protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( df: DataFrame, path: File): Unit = { df.write.mode(SaveMode.Overwrite).orc(path.getCanonicalPath) } }
Example 152
Source File: ConcurrentHiveSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.hive.execution import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.hive.test.TestHiveContext import org.scalatest.BeforeAndAfterAll class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll { ignore("multiple instances not supported") { // multiple instances are not supported test("Multiple Hive Instances") { (1 to 10).map { i => val ts = new TestHiveContext(new SparkContext("local", s"TestSQLContext$i", new SparkConf())) ts.executeSql("SHOW TABLES").toRdd.collect() ts.executeSql("SELECT * FROM src").toRdd.collect() ts.executeSql("SHOW TABLES").toRdd.collect() } } } }
Example 153
Source File: FiltersSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.hive.client import scala.collection.JavaConversions._ import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.serde.serdeConstants import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ class FiltersSuite extends SparkFunSuite with Logging { private val shim = new Shim_v0_13 private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test") private val varCharCol = new FieldSchema() varCharCol.setName("varchar") varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME) testTable.setPartCols(varCharCol :: Nil) //字符串过滤器 filterTest("string filter", (a("stringcol", StringType) > Literal("test")) :: Nil, "stringcol > \"test\"") //字符串过滤器向后 filterTest("string filter backwards", (Literal("test") > a("stringcol", StringType)) :: Nil, "\"test\" > stringcol") //int过滤器 filterTest("int filter", (a("intcol", IntegerType) === Literal(1)) :: Nil, "intcol = 1") //int向后过滤 filterTest("int filter backwards", (Literal(1) === a("intcol", IntegerType)) :: Nil, "1 = intcol") filterTest("int and string filter", (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", IntegerType)) :: Nil, "1 = intcol and \"a\" = strcol") filterTest("skip varchar", (Literal("") === a("varchar", StringType)) :: Nil, "") private def filterTest(name: String, filters: Seq[Expression], result: String) = { test(name){ val converted = shim.convertFilters(testTable, filters) if (converted != result) { fail( s"Expected filters ${filters.mkString(",")} to convert to '$result' but got '$converted'") } } } private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)() }
Example 154
Source File: SerializationSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.hive import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer class SerializationSuite extends SparkFunSuite { // HiveContext should be serializable test("[SPARK-5840] HiveContext should be serializable") { val hiveContext = org.apache.spark.sql.hive.test.TestHive hiveContext.hiveconf val serializer = new JavaSerializer(new SparkConf()).newInstance() val bytes = serializer.serialize(hiveContext) val deSer = serializer.deserialize[AnyRef](bytes) } }
Example 155
Source File: ClasspathDependenciesSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.hive import java.net.URL import org.apache.spark.SparkFunSuite class ClasspathDependenciesSuite extends SparkFunSuite { //ClassLoader就是用来动态加载class文件到内存当中用 private val classloader = this.getClass.getClassLoader private def assertLoads(classname: String): Unit = { val resourceURL: URL = Option(findResource(classname)).getOrElse { fail(s"Class $classname not found as ${resourceName(classname)}") } logInfo(s"Class $classname at $resourceURL") classloader.loadClass(classname) } private def assertLoads(classes: String*): Unit = { classes.foreach(assertLoads) } private def findResource(classname: String): URL = { val resource = resourceName(classname) classloader.getResource(resource) } private def resourceName(classname: String): String = { classname.replace(".", "/") + ".class" } private def assertClassNotFound(classname: String): Unit = { Option(findResource(classname)).foreach { resourceURL => fail(s"Class $classname found at $resourceURL") } intercept[ClassNotFoundException] { classloader.loadClass(classname) } } private def assertClassNotFound(classes: String*): Unit = { classes.foreach(assertClassNotFound) } private val KRYO = "com.esotericsoftware.kryo.Kryo" private val SPARK_HIVE = "org.apache.hive." private val SPARK_SHADED = "org.spark-project.hive.shaded." test("shaded Protobuf") { assertLoads(SPARK_SHADED + "com.google.protobuf.ServiceException") } test("hive-common") { assertLoads("org.apache.hadoop.hive.conf.HiveConf") } test("hive-exec") { assertLoads("org.apache.hadoop.hive.ql.CommandNeedRetryException") } private val STD_INSTANTIATOR = "org.objenesis.strategy.StdInstantiatorStrategy" test("unshaded kryo") { assertLoads(KRYO, STD_INSTANTIATOR) } test("Forbidden Dependencies") { assertClassNotFound( SPARK_HIVE + KRYO, SPARK_SHADED + KRYO, "org.apache.hive." + KRYO, "com.esotericsoftware.shaded." + STD_INSTANTIATOR, SPARK_HIVE + "com.esotericsoftware.shaded." + STD_INSTANTIATOR, "org.apache.hive.com.esotericsoftware.shaded." + STD_INSTANTIATOR ) } test("parquet-hadoop-bundle") { assertLoads( "parquet.hadoop.ParquetOutputFormat", "parquet.hadoop.ParquetInputFormat" ) } }
Example 156
Source File: CommitFailureTestRelationSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.SQLContext import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { override def _sqlContext: SQLContext = TestHive private val sqlContext = _sqlContext // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. //提交任务时,“CommitFailureTestSource”会为测试目的引发异常 val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName //commitTask()失败应该回退到abortTask() test("SPARK-7684: commitTask() failure should fallback to abortTask()") { withTempPath { file => // Here we coalesce partition number to 1 to ensure that only a single task is issued. This // prevents race condition happened when FileOutputCommitter tries to remove the `_temporary` // directory while committing/aborting the job. See SPARK-8513 for more details. //这里我们将分区号合并为1,以确保只发出一个任务, 这个防止当FileOutputCommitter尝试删除`_temporary`时发生竞争条件 //目录提交/中止作业, 有关详细信息,请参阅SPARK-8513 val df = sqlContext.range(0, 10).coalesce(1) intercept[SparkException] { df.write.format(dataSourceName).save(file.getCanonicalPath) } val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf) assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary"))) } } }
Example 157
Source File: RandomDataGeneratorSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.types._ def testRandomDataGeneration(dataType: DataType, nullable: Boolean = true): Unit = { val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType) val generator = RandomDataGenerator.forType(dataType, nullable, Some(33)).getOrElse { fail(s"Random data generator was not defined for $dataType") } if (nullable) { assert(Iterator.fill(100)(generator()).contains(null)) } else { assert(Iterator.fill(100)(generator()).forall(_ != null)) } for (_ <- 1 to 10) { val generatedValue = generator() toCatalyst(generatedValue) } } // Basic types: for ( dataType <- DataTypeTestUtils.atomicTypes; nullable <- Seq(true, false) if !dataType.isInstanceOf[DecimalType]) { test(s"$dataType (nullable=$nullable)") { testRandomDataGeneration(dataType) } } for ( arrayType <- DataTypeTestUtils.atomicArrayTypes if RandomDataGenerator.forType(arrayType.elementType, arrayType.containsNull).isDefined ) { test(s"$arrayType") { testRandomDataGeneration(arrayType) } } val atomicTypesWithDataGenerators = DataTypeTestUtils.atomicTypes.filter(RandomDataGenerator.forType(_).isDefined) // Complex types: for ( keyType <- atomicTypesWithDataGenerators; valueType <- atomicTypesWithDataGenerators // Scala's BigDecimal.hashCode can lead to OutOfMemoryError on Scala 2.10 (see SI-6173) and // Spark can hit NumberFormatException errors when converting certain BigDecimals (SPARK-8802). // For these reasons, we don't support generation of maps with decimal keys. if !keyType.isInstanceOf[DecimalType] ) { val mapType = MapType(keyType, valueType) test(s"$mapType") { testRandomDataGeneration(mapType) } } for ( colOneType <- atomicTypesWithDataGenerators; colTwoType <- atomicTypesWithDataGenerators ) { val structType = StructType(StructField("a", colOneType) :: StructField("b", colTwoType) :: Nil) test(s"$structType") { testRandomDataGeneration(structType) } } }
Example 158
Source File: SqlParserSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.logical.Command private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command { override def output: Seq[Attribute] = Seq.empty override def children: Seq[LogicalPlan] = Seq.empty } private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser { protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST") override protected lazy val start: Parser[LogicalPlan] = set private lazy val set: Parser[LogicalPlan] = EXECUTE ~> ident ^^ { case fileName => TestCommand(fileName) } } private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser { protected val EXECUTE = Keyword("EXECUTE") override protected lazy val start: Parser[LogicalPlan] = set private lazy val set: Parser[LogicalPlan] = EXECUTE ~> ident ^^ { case fileName => TestCommand(fileName) } } class SqlParserSuite extends SparkFunSuite { test("test long keyword") { val parser = new SuperLongKeywordTestParser assert(TestCommand("NotRealCommand") === parser.parse("ThisIsASuperLongKeyWordTest NotRealCommand")) } test("test case insensitive") { val parser = new CaseInsensitiveTestParser assert(TestCommand("NotRealCommand") === parser.parse("EXECUTE NotRealCommand")) assert(TestCommand("NotRealCommand") === parser.parse("execute NotRealCommand")) assert(TestCommand("NotRealCommand") === parser.parse("exEcute NotRealCommand")) } }
Example 159
Source File: LogicalPlanSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ class LogicalPlanSuite extends SparkFunSuite { private var invocationCount = 0 private val function: PartialFunction[LogicalPlan, LogicalPlan] = { case p: Project => invocationCount += 1 p } private val testRelation = LocalRelation() test("resolveOperator runs on operators") { invocationCount = 0 val plan = Project(Nil, testRelation) plan resolveOperators function assert(invocationCount === 1) } test("resolveOperator runs on operators recursively") { invocationCount = 0 val plan = Project(Nil, Project(Nil, testRelation)) plan resolveOperators function assert(invocationCount === 2) } test("resolveOperator skips all ready resolved plans") { invocationCount = 0 val plan = Project(Nil, Project(Nil, testRelation)) plan.foreach(_.setAnalyzed()) plan resolveOperators function assert(invocationCount === 0) } test("resolveOperator skips partially resolved plans") { invocationCount = 0 val plan1 = Project(Nil, testRelation) val plan2 = Project(Nil, plan1) plan1.foreach(_.setAnalyzed()) plan2 resolveOperators function assert(invocationCount === 1) } }
Example 160
Source File: SameResultSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.util._ class SameResultSuite extends SparkFunSuite { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int) def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = { val aAnalyzed = a.analyze val bAnalyzed = b.analyze if (aAnalyzed.sameResult(bAnalyzed) != result) { val comparison = sideBySide(aAnalyzed.toString, bAnalyzed.toString).mkString("\n") fail(s"Plans should return sameResult = $result\n$comparison") } } test("relations") { assertSameResult(testRelation, testRelation2) } test("projections") { assertSameResult(testRelation.select('a), testRelation2.select('a)) assertSameResult(testRelation.select('b), testRelation2.select('b)) assertSameResult(testRelation.select('a, 'b), testRelation2.select('a, 'b)) assertSameResult(testRelation.select('b, 'a), testRelation2.select('b, 'a)) assertSameResult(testRelation, testRelation2.select('a), result = false) assertSameResult(testRelation.select('b, 'a), testRelation2.select('a, 'b), result = false) } test("filters") { assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b)) } test("sorts") { assertSameResult(testRelation.orderBy('a.asc), testRelation2.orderBy('a.asc)) } }
Example 161
Source File: NondeterministicSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper { test("MonotonicallyIncreasingID") { checkEvaluation(MonotonicallyIncreasingID(), 0L) } test("SparkPartitionID") { checkEvaluation(SparkPartitionID(), 0) } test("InputFileName") { checkEvaluation(InputFileName(), "") } }
Example 162
Source File: NullFunctionsSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ test("coalesce") { testAllTypes { (value: Any, tpe: DataType) => val lit = Literal.create(value, tpe) val nullLit = Literal.create(null, tpe) checkEvaluation(Coalesce(Seq(nullLit)), null) checkEvaluation(Coalesce(Seq(lit)), value) checkEvaluation(Coalesce(Seq(nullLit, lit)), value) checkEvaluation(Coalesce(Seq(nullLit, lit, lit)), value) checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value) } } test("AtLeastNNonNulls") { val mix = Seq(Literal("x"), Literal.create(null, StringType), Literal.create(null, DoubleType), Literal(Double.NaN), Literal(5f)) val nanOnly = Seq(Literal("x"), Literal(10.0), Literal(Float.NaN), Literal(math.log(-2)), Literal(Double.MaxValue)) val nullOnly = Seq(Literal("x"), Literal.create(null, DoubleType), Literal.create(null, DecimalType.USER_DEFAULT), Literal(Float.MaxValue), Literal(false)) checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow) checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow) checkEvaluation(AtLeastNNonNulls(3, nanOnly), true, EmptyRow) checkEvaluation(AtLeastNNonNulls(4, nanOnly), false, EmptyRow) checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow) checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow) } }
Example 163
Source File: DecimalExpressionSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.{LongType, DecimalType, Decimal} class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { // unscaled value test("UnscaledValue") { val d1 = Decimal("10.1") checkEvaluation(UnscaledValue(Literal(d1)), 101L) val d2 = Decimal(101, 3, 1) checkEvaluation(UnscaledValue(Literal(d2)), 101L) checkEvaluation(UnscaledValue(Literal.create(null, DecimalType(2, 1))), null) } // build a decimal from an unscaled long test("MakeDecimal") { checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1")) checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null) } // promote precision test("PromotePrecision") { val d1 = Decimal("10.1") checkEvaluation(PromotePrecision(Literal(d1)), d1) val d2 = Decimal(101, 3, 1) checkEvaluation(PromotePrecision(Literal(d2)), d2) checkEvaluation(PromotePrecision(Literal.create(null, DecimalType(2, 1))), null) } // check for overflow test("CheckOverflow") { val d1 = Decimal("10.1") checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0)), Decimal("10")) checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1)), d1) checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2)), d1) checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3)), null) val d2 = Decimal(101, 3, 1) checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0)), Decimal("10")) checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1)), d2) checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2)), d2) checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3)), null) checkEvaluation(CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2)), null) } }
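Decimal itself is easy to exercise outside the expression layer: Decimal("10.1") and Decimal(101, 3, 1) denote the same number, the former inferring precision 3 and scale 1 from the string. A minimal sketch:

import org.apache.spark.sql.types.Decimal

val d1 = Decimal("10.1")            // precision 3, scale 1, unscaled value 101
val d2 = Decimal(101, 3, 1)         // the same number built from an unscaled Long
println((d1.precision, d1.scale))   // (3,1)
assert(d1.toBigDecimal == d2.toBigDecimal)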
Example 164
Source File: LiteralExpressionSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { test("null") { checkEvaluation(Literal.create(null, BooleanType), null) checkEvaluation(Literal.create(null, ByteType), null) checkEvaluation(Literal.create(null, ShortType), null) checkEvaluation(Literal.create(null, IntegerType), null) checkEvaluation(Literal.create(null, LongType), null) checkEvaluation(Literal.create(null, FloatType), null) checkEvaluation(Literal.create(null, LongType), null) checkEvaluation(Literal.create(null, StringType), null) checkEvaluation(Literal.create(null, BinaryType), null) checkEvaluation(Literal.create(null, DecimalType.USER_DEFAULT), null) checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null) checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null) checkEvaluation(Literal.create(null, StructType(Seq.empty)), null) } test("boolean literals") { checkEvaluation(Literal(true), true) checkEvaluation(Literal(false), false) } test("int literals") { List(0, 1, Int.MinValue, Int.MaxValue).foreach { d => checkEvaluation(Literal(d), d) checkEvaluation(Literal(d.toLong), d.toLong) checkEvaluation(Literal(d.toShort), d.toShort) checkEvaluation(Literal(d.toByte), d.toByte) } checkEvaluation(Literal(Long.MinValue), Long.MinValue) checkEvaluation(Literal(Long.MaxValue), Long.MaxValue) } test("double literals") { List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { d => checkEvaluation(Literal(d), d) checkEvaluation(Literal(d.toFloat), d.toFloat) } checkEvaluation(Literal(Double.MinValue), Double.MinValue) checkEvaluation(Literal(Double.MaxValue), Double.MaxValue) checkEvaluation(Literal(Float.MinValue), Float.MinValue) checkEvaluation(Literal(Float.MaxValue), Float.MaxValue) } test("string literals") { checkEvaluation(Literal(""), "") checkEvaluation(Literal("test"), "test") checkEvaluation(Literal("\0"), "\0") } test("sum two literals") { checkEvaluation(Add(Literal(1), Literal(1)), 2) } test("binary literals") { checkEvaluation(Literal.create(new Array[Byte](0), BinaryType), new Array[Byte](0)) checkEvaluation(Literal.create(new Array[Byte](2), BinaryType), new Array[Byte](2)) } test("decimal") { List(-0.0001, 0.0, 0.001, 1.2, 1.1111, 5).foreach { d => checkEvaluation(Literal(Decimal(d)), Decimal(d)) checkEvaluation(Literal(Decimal(d.toInt)), Decimal(d.toInt)) checkEvaluation(Literal(Decimal(d.toLong)), Decimal(d.toLong)) checkEvaluation(Literal(Decimal((d * 1000L).toLong, 10, 3)), Decimal((d * 1000L).toLong, 10, 3)) checkEvaluation(Literal(BigDecimal(d.toString)), Decimal(d)) checkEvaluation(Literal(new java.math.BigDecimal(d.toString)), Decimal(d)) } } // TODO(davies): add tests for ArrayType, MapType and StructType }
Example 165
Source File: CollectionFunctionsSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { //数组和映射大小 test("Array and Map Size") { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType)) checkEvaluation(Size(a0), 3) checkEvaluation(Size(a1), 0) checkEvaluation(Size(a2), 2) val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType)) val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType)) checkEvaluation(Size(m0), 2) checkEvaluation(Size(m1), 0) checkEvaluation(Size(m2), 1) checkEvaluation(Literal.create(null, MapType(StringType, StringType)), null) checkEvaluation(Literal.create(null, ArrayType(StringType)), null) } //数组排序 test("Sort Array") { val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType)) val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType)) checkEvaluation(new SortArray(a0), Seq(1, 2, 3)) checkEvaluation(new SortArray(a1), Seq[Integer]()) checkEvaluation(new SortArray(a2), Seq("a", "b")) checkEvaluation(new SortArray(a3), Seq(null, "a", "b")) checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3)) checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b")) checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b")) checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1)) checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a")) checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null)) checkEvaluation(Literal.create(null, ArrayType(StringType)), null) } //数组包含 test("Array contains") { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) val a2 = Literal.create(Seq(null), ArrayType(LongType)) val a3 = Literal.create(null, ArrayType(StringType)) checkEvaluation(ArrayContains(a0, Literal(1)), true) checkEvaluation(ArrayContains(a0, Literal(0)), false) checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null) checkEvaluation(ArrayContains(a1, Literal("")), true) checkEvaluation(ArrayContains(a1, Literal("a")), null) checkEvaluation(ArrayContains(a1, Literal.create(null, StringType)), null) checkEvaluation(ArrayContains(a2, Literal(1L)), null) checkEvaluation(ArrayContains(a2, Literal.create(null, LongType)), null) checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } }
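The same operations are exposed on DataFrames through size, sort_array and array_contains in org.apache.spark.sql.functions (Spark 1.5). A short sketch, assuming an existing SQLContext:

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.{array_contains, size, sort_array}

def collectionFunctionsSketch(sqlContext: SQLContext): Unit = {
  import sqlContext.implicits._
  // one row with an integer array column and a string array column
  val df = Seq((Seq(3, 1, 2), Seq("b", "a"))).toDF("nums", "letters")
  df.select(size($"nums"), sort_array($"nums"), array_contains($"letters", "a")).show()
}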
Example 166
Source File: RandomSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions

import org.scalatest.Matchers._

import org.apache.spark.SparkFunSuite

// Random expression test suite
class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("random") {
    checkDoubleEvaluation(Rand(30), 0.7363714192755834 +- 0.001)
    checkDoubleEvaluation(Randn(30), 0.5181478766595276 +- 0.001)
  }

  // Codegen with a long seed
  test("SPARK-9127 codegen with long seed") {
    checkDoubleEvaluation(Rand(5419823303878592871L), 0.4061913198963727 +- 0.001)
    checkDoubleEvaluation(Randn(5419823303878592871L), -0.24417152005343168 +- 0.001)
  }
}
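The same seeded expressions are reachable from the DataFrame API through org.apache.spark.sql.functions. A small sketch; df is an assumed existing DataFrame:

import org.apache.spark.sql.functions.{rand, randn}

// a fixed seed makes the generated columns deterministic for a given partitioning
df.select(rand(30), randn(30)).show()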
Example 167
Source File: MiscFunctionsSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType} class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("md5") { checkEvaluation(Md5(Literal("ABC".getBytes)), "902fbdd2b1df0c4f70b4a5d23525e932") checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), "6ac1e56bc78f031059be7be854522c4c") checkEvaluation(Md5(Literal.create(null, BinaryType)), null) checkConsistencyBetweenInterpretedAndCodegen(Md5, BinaryType) } test("sha1") { checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8") checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), "5d211bad8f4ee70e16c7d343a838fc344a1ed961") checkEvaluation(Sha1(Literal.create(null, BinaryType)), null) checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709") checkConsistencyBetweenInterpretedAndCodegen(Sha1, BinaryType) } test("sha2") { checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC")) checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)), DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6))) // unsupported bit length checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null) checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null) checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null) checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null) } test("crc32") { checkEvaluation(Crc32(Literal("ABC".getBytes)), 2743272264L) checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), 2180413220L) checkEvaluation(Crc32(Literal.create(null, BinaryType)), null) checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) } }
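The md5, sha1, sha2 and crc32 expressions checked here also exist as DataFrame functions. A sketch; df and its binary (or string) column "payload" are assumptions:

import org.apache.spark.sql.functions.{col, crc32, md5, sha1, sha2}

df.select(md5(col("payload")), sha1(col("payload")), sha2(col("payload"), 256), crc32(col("payload"))).show()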
Example 168
Source File: AttributeSetSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.IntegerType class AttributeSetSuite extends SparkFunSuite { val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1)) val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1)) val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3)) val aSet = AttributeSet(aLower :: Nil) val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2)) val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2)) val bSet = AttributeSet(bUpper :: Nil) val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil) test("sanity check") { assert(aUpper != aLower) assert(bUpper != bLower) } //按ID检查而不是名称 test("checks by id not name") { assert(aSet.contains(aUpper) === true) assert(aSet.contains(aLower) === true) assert(aSet.contains(fakeA) === false) assert(aSet.contains(bUpper) === false) assert(aSet.contains(bLower) === false) } test("++ preserves AttributeSet") { assert((aSet ++ bSet).contains(aUpper) === true) assert((aSet ++ bSet).contains(aLower) === true) } test("extracts all references references") { val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil) assert(addSet.contains(aUpper)) assert(addSet.contains(aLower)) assert(addSet.contains(bUpper)) assert(addSet.contains(bLower)) } test("dedups attributes") { assert(AttributeSet(aUpper :: aLower :: Nil).size === 1) } test("subset") { assert(aSet.subsetOf(aAndBSet) === true) assert(aAndBSet.subsetOf(aSet) === false) } test("equality") { assert(aSet != aAndBSet) assert(aAndBSet != aSet) assert(aSet != bSet) assert(bSet != aSet) assert(aSet == aSet) assert(aSet == AttributeSet(aUpper :: Nil)) } }
Example 169
Source File: GenerateUnsafeRowJoinerSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions.codegen import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.types._ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { private val fixed = Seq(IntegerType) private val variable = Seq(IntegerType, StringType) //简单的固定宽度类型 test("simple fixed width types") { testConcat(0, 0, fixed) testConcat(0, 1, fixed) testConcat(1, 0, fixed) testConcat(64, 0, fixed) testConcat(0, 64, fixed) testConcat(64, 64, fixed) } //随机化的固定宽度类型 test("randomized fix width types") { for (i <- 0 until 20) { testConcatOnce(Random.nextInt(100), Random.nextInt(100), fixed) } } //简单变量宽度类型 test("simple variable width types") { testConcat(0, 0, variable) testConcat(0, 1, variable) testConcat(1, 0, variable) testConcat(64, 0, variable) testConcat(0, 64, variable) testConcat(64, 64, variable) } //随机变量宽度类型 test("randomized variable width types") { for (i <- 0 until 10) { testConcatOnce(Random.nextInt(100), Random.nextInt(100), variable) } } private def testConcat(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]): Unit = { for (i <- 0 until 10) { testConcatOnce(numFields1, numFields2, candidateTypes) } } private def testConcatOnce(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]) { info(s"schema size $numFields1, $numFields2") val schema1 = RandomDataGenerator.randomSchema(numFields1, candidateTypes) val schema2 = RandomDataGenerator.randomSchema(numFields2, candidateTypes) // Create the converters needed to convert from external row to internal row and to UnsafeRows. //创建从外部行转换为内部行和UnsafeRows所需的转换器 val internalConverter1 = CatalystTypeConverters.createToCatalystConverter(schema1) val internalConverter2 = CatalystTypeConverters.createToCatalystConverter(schema2) val converter1 = UnsafeProjection.create(schema1) val converter2 = UnsafeProjection.create(schema2) // Create the input rows, convert them into UnsafeRows. //创建输入行,将它们转换成UnsafeRows val extRow1 = RandomDataGenerator.forType(schema1, nullable = false).get.apply() val extRow2 = RandomDataGenerator.forType(schema2, nullable = false).get.apply() val row1 = converter1.apply(internalConverter1.apply(extRow1).asInstanceOf[InternalRow]) val row2 = converter2.apply(internalConverter2.apply(extRow2).asInstanceOf[InternalRow]) // Run the joiner. val mergedSchema = StructType(schema1 ++ schema2) val concater = GenerateUnsafeRowJoiner.create(schema1, schema2) val output = concater.join(row1, row2) // Test everything equals ... for (i <- mergedSchema.indices) { if (i < schema1.size) { assert(output.isNullAt(i) === row1.isNullAt(i)) if (!output.isNullAt(i)) { assert(output.get(i, mergedSchema(i).dataType) === row1.get(i, mergedSchema(i).dataType)) } } else { assert(output.isNullAt(i) === row2.isNullAt(i - schema1.size)) if (!output.isNullAt(i)) { assert(output.get(i, mergedSchema(i).dataType) === row2.get(i - schema1.size, mergedSchema(i).dataType)) } } } } }
Example 170
Source File: GeneratedProjectionSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class GeneratedProjectionSuite extends SparkFunSuite { //在更宽的桌子上产生预测 test("generated projections on wider table") { val N = 1000 val wideRow1 = new GenericInternalRow((1 to N).toArray[Any]) val schema1 = StructType((1 to N).map(i => StructField("", IntegerType))) val wideRow2 = new GenericInternalRow( (1 to N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) val schema2 = StructType((1 to N).map(i => StructField("", StringType))) val joined = new JoinedRow(wideRow1, wideRow2) val joinedSchema = StructType(schema1 ++ schema2) val nested = new JoinedRow(InternalRow(joined, joined), joined) val nestedSchema = StructType( Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema) // test generated UnsafeProjection val unsafeProj = UnsafeProjection.create(nestedSchema) val unsafe: UnsafeRow = unsafeProj(nested) (0 until N).foreach { i => val s = UTF8String.fromString((i + 1).toString) assert(i + 1 === unsafe.getInt(i + 2)) assert(s === unsafe.getUTF8String(i + 2 + N)) assert(i + 1 === unsafe.getStruct(0, N * 2).getInt(i)) assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N)) assert(i + 1 === unsafe.getStruct(1, N * 2).getInt(i)) assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N)) } // test generated SafeProjection val safeProj = FromUnsafeProjection(nestedSchema) val result = safeProj(unsafe) // Can't compare GenericInternalRow with JoinedRow directly (0 until N).foreach { i => val r = i + 1 val s = UTF8String.fromString((i + 1).toString) assert(r === result.getInt(i + 2)) assert(s === result.getUTF8String(i + 2 + N)) assert(r === result.getStruct(0, N * 2).getInt(i)) assert(s === result.getStruct(0, N * 2).getUTF8String(i + N)) assert(r === result.getStruct(1, N * 2).getInt(i)) assert(s === result.getStruct(1, N * 2).getUTF8String(i + N)) } // test generated MutableProjection val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) => BoundReference(i, f.dataType, true) } val mutableProj = GenerateMutableProjection.generate(exprs)() val row1 = mutableProj(result) assert(result === row1) val row2 = mutableProj(result) assert(result === row2) } test("generated unsafe projection with array of binary") { val row = InternalRow( Array[Byte](1, 2), new GenericArrayData(Array(Array[Byte](1, 2), null, Array[Byte](3, 4)))) val fields = (BinaryType :: ArrayType(BinaryType) :: Nil).toArray[DataType] val unsafeProj = UnsafeProjection.create(fields) val unsafeRow: UnsafeRow = unsafeProj(row) assert(java.util.Arrays.equals(unsafeRow.getBinary(0), Array[Byte](1, 2))) assert(java.util.Arrays.equals(unsafeRow.getArray(1).getBinary(0), Array[Byte](1, 2))) assert(unsafeRow.getArray(1).isNullAt(1)) assert(unsafeRow.getArray(1).getBinary(1) === null) assert(java.util.Arrays.equals(unsafeRow.getArray(1).getBinary(2), Array[Byte](3, 4))) val safeProj = FromUnsafeProjection(fields) val row2 = safeProj(unsafeRow) assert(row2 === row) } }
Example 171
Source File: RuleExecutorSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.trees import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal} import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} class RuleExecutorSuite extends SparkFunSuite { object DecrementLiterals extends Rule[Expression] { def apply(e: Expression): Expression = e transform { case IntegerLiteral(i) if i > 0 => Literal(i - 1) } } test("only once") { object ApplyOnce extends RuleExecutor[Expression] { val batches = Batch("once", Once, DecrementLiterals) :: Nil } assert(ApplyOnce.execute(Literal(10)) === Literal(9)) } test("to fixed point") { object ToFixedPoint extends RuleExecutor[Expression] { val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil } assert(ToFixedPoint.execute(Literal(10)) === Literal(0)) } test("to maxIterations") { object ToFixedPoint extends RuleExecutor[Expression] { val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil } assert(ToFixedPoint.execute(Literal(100)) === Literal(90)) } }
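A rule executor can also chain several batches with different strategies. The following is only a sketch that reuses the DecrementLiterals rule defined in the suite above; the batch names are made up:

object TwoPhase extends RuleExecutor[Expression] {
  val batches =
    Batch("prepare", Once, DecrementLiterals) ::
    Batch("simplify", FixedPoint(50), DecrementLiterals) :: Nil
}

TwoPhase.execute(Literal(10))  // one decrement from "prepare", then driven to Literal(0) by "simplify"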
Example 172
Source File: NumberConverterSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.NumberConverter.convert
import org.apache.spark.unsafe.types.UTF8String

class NumberConverterSuite extends SparkFunSuite {

  private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = {
    assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) ===
      UTF8String.fromString(expected))
  }

  // Base conversion
  test("convert") {
    checkConv("3", 10, 2, "11")
    checkConv("-15", 10, -16, "-F")
    checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1")
    checkConv("big", 36, 16, "3A48")
    checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF")
    checkConv("11abc", 10, 16, "B")
  }
}
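The same conversion backs the conv DataFrame function. A sketch; df and its string column "n" are assumptions:

import org.apache.spark.sql.functions.{col, conv}

df.select(conv(col("n"), 10, 16)).show()  // e.g. "15" becomes "F"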
Example 173
Source File: MetadataSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.util import org.json4s.jackson.JsonMethods.parse import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.{MetadataBuilder, Metadata} class MetadataSuite extends SparkFunSuite { val baseMetadata = new MetadataBuilder() .putString("purpose", "ml") .putBoolean("isBase", true) .build() val summary = new MetadataBuilder() .putLong("numFeatures", 10L) .build() val age = new MetadataBuilder() .putString("name", "age") .putLong("index", 1L) .putBoolean("categorical", false) .putDouble("average", 45.0) .build() val gender = new MetadataBuilder() .putString("name", "gender") .putLong("index", 5) .putBoolean("categorical", true) .putStringArray("categories", Array("male", "female")) .build() val metadata = new MetadataBuilder() .withMetadata(baseMetadata) .putBoolean("isBase", false) // overwrite an existing key .putMetadata("summary", summary) .putLongArray("long[]", Array(0L, 1L)) .putDoubleArray("double[]", Array(3.0, 4.0)) .putBooleanArray("boolean[]", Array(true, false)) .putMetadataArray("features", Array(age, gender)) .build() //元数据构建器和getter test("metadata builder and getters") { assert(age.contains("summary") === false) assert(age.contains("index") === true) assert(age.getLong("index") === 1L) assert(age.contains("average") === true) assert(age.getDouble("average") === 45.0) assert(age.contains("categorical") === true) assert(age.getBoolean("categorical") === false) assert(age.contains("name") === true) assert(age.getString("name") === "age") assert(metadata.contains("purpose") === true) assert(metadata.getString("purpose") === "ml") assert(metadata.contains("isBase") === true) assert(metadata.getBoolean("isBase") === false) assert(metadata.contains("summary") === true) assert(metadata.getMetadata("summary") === summary) assert(metadata.contains("long[]") === true) assert(metadata.getLongArray("long[]").toSeq === Seq(0L, 1L)) assert(metadata.contains("double[]") === true) assert(metadata.getDoubleArray("double[]").toSeq === Seq(3.0, 4.0)) assert(metadata.contains("boolean[]") === true) assert(metadata.getBooleanArray("boolean[]").toSeq === Seq(true, false)) assert(gender.contains("categories") === true) assert(gender.getStringArray("categories").toSeq === Seq("male", "female")) assert(metadata.contains("features") === true) assert(metadata.getMetadataArray("features").toSeq === Seq(age, gender)) } //元数据的JSON转换 test("metadata json conversion") { val json = metadata.json withClue("toJson must produce a valid JSON string") { parse(json) } val parsed = Metadata.fromJson(json) assert(parsed === metadata) assert(parsed.## === metadata.##) } }
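Metadata is usually attached to a StructField and travels with the schema. A minimal sketch that reuses the age metadata built in the suite above:

import org.apache.spark.sql.types._

val field = StructField("age", IntegerType, nullable = true, age)
val schema = StructType(Seq(field))

assert(schema("age").metadata.getLong("index") == 1L)
assert(schema("age").metadata.getDouble("average") == 45.0)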
Example 174
Source File: StringUtilsSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.StringUtils._

class StringUtilsSuite extends SparkFunSuite {

  // Escape a SQL LIKE pattern into a regular expression
  test("escapeLikeRegex") {
    assert(escapeLikeRegex("abdef") === "(?s)\\Qa\\E\\Qb\\E\\Qd\\E\\Qe\\E\\Qf\\E")
    assert(escapeLikeRegex("a\\__b") === "(?s)\\Qa\\E_.\\Qb\\E")
    assert(escapeLikeRegex("a_%b") === "(?s)\\Qa\\E..*\\Qb\\E")
    assert(escapeLikeRegex("a%\\%b") === "(?s)\\Qa\\E.*%\\Qb\\E")
    assert(escapeLikeRegex("a%") === "(?s)\\Qa\\E.*")
    assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E")
    assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E")
  }
}
Example 175
Source File: CatalystTypeConvertersSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.types._ class CatalystTypeConvertersSuite extends SparkFunSuite { private val simpleTypes: Seq[DataType] = Seq( StringType, DateType, BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType.USER_DEFAULT) //行空处理 test("null handling in rows") { val schema = StructType(simpleTypes.map(t => StructField(t.getClass.getName, t))) val convertToCatalyst = CatalystTypeConverters.createToCatalystConverter(schema) val convertToScala = CatalystTypeConverters.createToScalaConverter(schema) val scalaRow = Row.fromSeq(Seq.fill(simpleTypes.length)(null)) assert(convertToScala(convertToCatalyst(scalaRow)) === scalaRow) } //个别值的空处理 test("null handling for individual values") { for (dataType <- simpleTypes) { assert(CatalystTypeConverters.createToScalaConverter(dataType)(null) === null) } } //convertToCatalyst中的选项处理 test("option handling in convertToCatalyst") { // convertToCatalyst doesn't handle unboxing from Options. This is inconsistent with // createToCatalystConverter but it may not actually matter as this is only called internally // in a handful of places where we don't expect to receive Options. assert(CatalystTypeConverters.convertToCatalyst(Some(123)) === Some(123)) } test("option handling in createToCatalystConverter") { assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123) } }
Example 176
Source File: SQLContextSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test.SharedSQLContext //SQL上下文测试套件 class SQLContextSuite extends SparkFunSuite with SharedSQLContext { override def afterAll(): Unit = { try { SQLContext.setLastInstantiatedContext(ctx) } finally { super.afterAll() } } test("getOrCreate instantiates SQLContext") {//获取或创建实例化SQL上下文 SQLContext.clearLastInstantiatedContext() val sqlContext = SQLContext.getOrCreate(ctx.sparkContext) assert(sqlContext != null, "SQLContext.getOrCreate returned null") assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), "SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate") } test("getOrCreate gets last explicitly instantiated SQLContext") {//获得或创造获取最后的显式实例化SQL上下文 SQLContext.clearLastInstantiatedContext() val sqlContext = new SQLContext(ctx.sparkContext) assert(SQLContext.getOrCreate(ctx.sparkContext) != null, "SQLContext.getOrCreate after explicitly created SQLContext returned null") assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), "SQLContext.getOrCreate after explicitly created SQLContext did not return the context") } }
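Outside the test harness, the getOrCreate pattern looks like the following sketch; the local SparkContext and app name are placeholders:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("getOrCreate-demo"))
val sqlContext = SQLContext.getOrCreate(sc)       // creates a context on the first call ...
assert(SQLContext.getOrCreate(sc) eq sqlContext)  // ... and returns the same instance afterwards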
Example 177
Source File: SQLExecutionSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.execution import java.util.Properties import scala.collection.parallel.CompositeThrowable import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.SQLContext //SQL执行测试套件 class SQLExecutionSuite extends SparkFunSuite { //并发查询执行 test("concurrent query execution (SPARK-10548)") { // Try to reproduce the issue with the old SparkContext //尝试旧sparkcontext再现问题 val conf = new SparkConf() .setMaster("local[*]") .setAppName("test") val badSparkContext = new BadSparkContext(conf) try { testConcurrentQueryExecution(badSparkContext) fail("unable to reproduce SPARK-10548") } catch { case e: IllegalArgumentException => assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY)) } finally { badSparkContext.stop() } // Verify that the issue is fixed with the latest SparkContext //验证问题是固定的最新sparkcontext val goodSparkContext = new SparkContext(conf) try { testConcurrentQueryExecution(goodSparkContext) } finally { goodSparkContext.stop() } } private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) { protected[spark] override val localProperties = new InheritableThreadLocal[Properties] { override protected def childValue(parent: Properties): Properties = new Properties(parent) override protected def initialValue(): Properties = new Properties() } }
Example 178
Source File: NullableColumnBuilderSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ //测试空列生成器 class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType]) extends BasicColumnBuilder[JvmType](new NoopColumnStats, columnType) with NullableColumnBuilder object TestNullableColumnBuilder { def apply[JvmType](columnType: ColumnType[JvmType], initialSize: Int = 0) : TestNullableColumnBuilder[JvmType] = { val builder = new TestNullableColumnBuilder(columnType) builder.initialize(initialSize) builder } } //空列生成器套件 class NullableColumnBuilderSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType))) .foreach { testNullableColumnBuilder(_) } def testNullableColumnBuilder[JvmType]( columnType: ColumnType[JvmType]): Unit = { //stripSuffix去掉<string>字串中结尾的字符 val typeName = columnType.getClass.getSimpleName.stripSuffix("$") test(s"$typeName column builder: empty column") {//空列 val columnBuilder = TestNullableColumnBuilder(columnType) val buffer = columnBuilder.build() assertResult(columnType.typeId, "Wrong column type ID")(buffer.getInt()) assertResult(0, "Wrong null count")(buffer.getInt()) assert(!buffer.hasRemaining) } test(s"$typeName column builder: buffer size auto growth") {//缓存大小自动增长 val columnBuilder = TestNullableColumnBuilder(columnType) val randomRow = makeRandomRow(columnType) (0 until 4).foreach { _ => columnBuilder.appendFrom(randomRow, 0) } val buffer = columnBuilder.build() assertResult(columnType.typeId, "Wrong column type ID")(buffer.getInt()) assertResult(0, "Wrong null count")(buffer.getInt()) } //列生成器:空值 test(s"$typeName column builder: null values") { val columnBuilder = TestNullableColumnBuilder(columnType) val randomRow = makeRandomRow(columnType) val nullRow = makeNullRow(1) (0 until 4).foreach { _ => columnBuilder.appendFrom(randomRow, 0) columnBuilder.appendFrom(nullRow, 0) } val buffer = columnBuilder.build() assertResult(columnType.typeId, "Wrong column type ID")(buffer.getInt()) assertResult(4, "Wrong null count")(buffer.getInt()) // For null positions 空位置 (1 to 7 by 2).foreach(assertResult(_, "Wrong null position")(buffer.getInt())) // For non-null values 对于非空值 (0 until 4).foreach { _ => val actual = if (columnType.isInstanceOf[GENERIC]) { SparkSqlSerializer.deserialize[Any](columnType.extract(buffer).asInstanceOf[Array[Byte]]) } else { columnType.extract(buffer) } assert(actual === randomRow.get(0, columnType.dataType), "Extracted value didn't equal to the original one") } assert(!buffer.hasRemaining) } } }
Example 179
Source File: ColumnStatsSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.types._ //列统计测试套件 class ColumnStatsSuite extends SparkFunSuite { testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0)) testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0)) testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0)) testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0)) testColumnStats(classOf[DateColumnStats], DATE, createRow(Int.MaxValue, Int.MinValue, 0)) testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0)) testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, createRow(Long.MaxValue, Long.MinValue, 0)) testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, createRow(Double.MaxValue, Double.MinValue, 0)) testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0)) testDecimalColumnStats(createRow(null, null, 0)) def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray) //测试列统计 def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], initialStatistics: GenericInternalRow): Unit = { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { case (actual, expected) => assert(actual === expected) } } test(s"$columnStatsName: non-empty") {//非空 import org.apache.spark.sql.columnar.ColumnarTestUtils._ val columnStats = columnStatsClass.newInstance() val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType]) val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) assertResult(10, "Wrong null count")(stats.values(2)) assertResult(20, "Wrong row count")(stats.values(3)) assertResult(stats.values(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum } } } //测试十进制列统计 def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats]( initialStatistics: GenericInternalRow): Unit = { val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName val columnType = FIXED_DECIMAL(15, 10) test(s"$columnStatsName: empty") { val columnStats = new FixedDecimalColumnStats(15, 10) columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { case (actual, expected) => assert(actual === expected) } } test(s"$columnStatsName: non-empty") {//非空 import org.apache.spark.sql.columnar.ColumnarTestUtils._ val columnStats = new FixedDecimalColumnStats(15, 10) val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType]) val ordering = 
columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) assertResult(10, "Wrong null count")(stats.values(2)) assertResult(20, "Wrong row count")(stats.values(3)) assertResult(stats.values(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum } } } }
Example 180
Source File: NullableColumnAccessorSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.columnar import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.types.{StringType, ArrayType, DataType} //试验可为空的列的访问 class TestNullableColumnAccessor[JvmType]( buffer: ByteBuffer, columnType: ColumnType[JvmType]) extends BasicColumnAccessor(buffer, columnType) with NullableColumnAccessor //试验可为空的列的访问 object TestNullableColumnAccessor { def apply[JvmType](buffer: ByteBuffer, columnType: ColumnType[JvmType]) : TestNullableColumnAccessor[JvmType] = { // Skips the column type ID buffer.getInt() new TestNullableColumnAccessor(buffer, columnType) } } //空列存取器套件 class NullableColumnAccessorSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType))) .foreach { testNullableColumnAccessor(_) } //试验可为空的列的访问 def testNullableColumnAccessor[JvmType]( columnType: ColumnType[JvmType]): Unit = { //stripSuffix去掉<string>字串中结尾的字符 val typeName = columnType.getClass.getSimpleName.stripSuffix("$") val nullRow = makeNullRow(1) //空值 test(s"Nullable $typeName column accessor: empty column") { val builder = TestNullableColumnBuilder(columnType) val accessor = TestNullableColumnAccessor(builder.build(), columnType) assert(!accessor.hasNext) } //访问空值 test(s"Nullable $typeName column accessor: access null values") { val builder = TestNullableColumnBuilder(columnType) val randomRow = makeRandomRow(columnType) (0 until 4).foreach { _ => builder.appendFrom(randomRow, 0) builder.appendFrom(nullRow, 0) } val accessor = TestNullableColumnAccessor(builder.build(), columnType) val row = new GenericMutableRow(1) (0 until 4).foreach { _ => assert(accessor.hasNext) accessor.extractTo(row, 0) assert(row.get(0, columnType.dataType) === randomRow.get(0, columnType.dataType)) assert(accessor.hasNext) accessor.extractTo(row, 0) assert(row.isNullAt(0)) } assert(!accessor.hasNext) } } }
Example 181
Source File: ResolvedDataSourceSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.sql.sources import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.datasources.ResolvedDataSource //解析数据源测试套件 class ResolvedDataSourceSuite extends SparkFunSuite { test("jdbc") { assert( //解析JDBC数据源 ResolvedDataSource.lookupDataSource("jdbc") === //默认使用jdbc.DefaultSource类型 classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) assert( ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.jdbc") === classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) assert( ResolvedDataSource.lookupDataSource("org.apache.spark.sql.jdbc") === classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) } test("json") { assert( //解析json数据源 ResolvedDataSource.lookupDataSource("json") === //默认使用json.DefaultSource类型 classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) assert( ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.json") === classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) assert( ResolvedDataSource.lookupDataSource("org.apache.spark.sql.json") === classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) } test("parquet") { assert( //解析parquet数据源 ResolvedDataSource.lookupDataSource("parquet") === //默认使用parquet.DefaultSource类型 classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) assert( ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.parquet") === classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) assert( ResolvedDataSource.lookupDataSource("org.apache.spark.sql.parquet") === classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) } }
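The short names resolved above are exactly what DataFrameReader.format accepts. A sketch; sqlContext and the paths are placeholders:

val people = sqlContext.read.format("json").load("/path/to/people.json")
val events = sqlContext.read.format("parquet").load("/path/to/events.parquet")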
Example 182
Source File: FailureSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.streaming import java.io.File import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkFunSuite, Logging} import org.apache.spark.util.Utils class FailureSuite extends SparkFunSuite with BeforeAndAfter with Logging { private val batchDuration: Duration = Milliseconds(1000) private val numBatches = 30 private var directory: File = null before { directory = Utils.createTempDir() } after { if (directory != null) { //删除临时目录 Utils.deleteRecursively(directory) } //停止所有活动实时流 StreamingContext.getActive().foreach { _.stop() } } //多次失败map test("multiple failures with map") { MasterFailureTest.testMap(directory.getAbsolutePath, numBatches, batchDuration) } //多次失败updateStateByKey test("multiple failures with updateStateByKey") { MasterFailureTest.testUpdateStateByKey(directory.getAbsolutePath, numBatches, batchDuration) } }
Example 183
Source File: RateLimitedOutputStreamSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.streaming.util

import java.io.ByteArrayOutputStream
import java.util.concurrent.TimeUnit._

import org.apache.spark.SparkFunSuite

class RateLimitedOutputStreamSuite extends SparkFunSuite {

  private def benchmark[U](f: => U): Long = {
    val start = System.nanoTime
    f
    System.nanoTime - start
  }

  test("write") {
    val underlying = new ByteArrayOutputStream
    val data = "X" * 41000
    // desiredBytesPerSec: the target number of bytes written per second
    val stream = new RateLimitedOutputStream(underlying, desiredBytesPerSec = 10000)
    val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) }

    val seconds = SECONDS.convert(elapsedNs, NANOSECONDS)
    assert(seconds >= 4, s"Seconds value ($seconds) is less than 4.")
    assert(seconds <= 30, s"Took more than 30 seconds ($seconds) to write data.")
    assert(underlying.toString("UTF-8") === data)
  }
}
Example 184
Source File: RpcAddressSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.rpc

import org.apache.spark.{SparkException, SparkFunSuite}

class RpcAddressSuite extends SparkFunSuite {

  test("hostPort") {
    val address = RpcAddress("1.2.3.4", 1234)
    assert(address.host == "1.2.3.4")
    assert(address.port == 1234)
    assert(address.hostPort == "1.2.3.4:1234")
  }

  test("fromSparkURL") {
    val address = RpcAddress.fromSparkURL("spark://1.2.3.4:1234")
    assert(address.host == "1.2.3.4")
    assert(address.port == 1234)
  }

  test("fromSparkURL: a typo url") {
    val e = intercept[SparkException] {
      RpcAddress.fromSparkURL("spark://1.2. 3.4:1234") // stray space inside the host
    }
    assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage)
  }

  test("fromSparkURL: invalid scheme") {
    val e = intercept[SparkException] {
      RpcAddress.fromSparkURL("invalid://1.2.3.4:1234")
    }
    assert("Invalid master URL: invalid://1.2.3.4:1234" === e.getMessage)
  }

  test("toSparkURL") {
    val address = RpcAddress("1.2.3.4", 1234)
    assert(address.toSparkURL == "spark://1.2.3.4:1234")
  }
}
Example 185
Source File: NettyBlockTransferServiceSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.network.netty import org.apache.spark.network.BlockDataManager import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.mockito.Mockito.mock import org.scalatest._ class NettyBlockTransferServiceSuite extends SparkFunSuite with BeforeAndAfterEach with ShouldMatchers { private var service0: NettyBlockTransferService = _ private var service1: NettyBlockTransferService = _ override def afterEach() { if (service0 != null) { service0.close() service0 = null } if (service1 != null) { service1.close() service1 = null } } test("can bind to a random port") {//可以绑定到一个随机端口 service0 = createService(port = 0) service0.port should not be 0 } test("can bind to two random ports") {//可以绑定到两个随机端口 service0 = createService(port = 0) service1 = createService(port = 0) service0.port should not be service1.port } test("can bind to a specific port") {//可以绑定到一个特定的端口 val port = 17634 service0 = createService(port) service0.port should be >= port //在同时测试的情况下避免测试平等 service0.port should be <= (port + 10) // avoid testing equality in case of simultaneous tests } //可以绑定到一个特定的端口两次和第二个增量 test("can bind to a specific port twice and the second increments") { val port = 17634 service0 = createService(port) service1 = createService(port) service0.port should be >= port service0.port should be <= (port + 10) service1.port should be (service0.port + 1) } private def createService(port: Int): NettyBlockTransferService = { val conf = new SparkConf() .set("spark.app.id", s"test-${getClass.getName}") .set("spark.blockManager.port", port.toString) val securityManager = new SecurityManager(conf) val blockDataManager = mock(classOf[BlockDataManager]) val service = new NettyBlockTransferService(conf, securityManager, numCores = 1) service.init(blockDataManager) service } }
Example 186
Source File: SortShuffleWriterSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.shuffle.sort import org.mockito.Mockito._ import org.apache.spark.{Aggregator, SparkConf, SparkFunSuite} class SortShuffleWriterSuite extends SparkFunSuite { import SortShuffleWriter._ test("conditions for bypassing merge-sort") { val conf = new SparkConf(loadDefaults = false) val agg = mock(classOf[Aggregator[_, _, _]], RETURNS_SMART_NULLS) val ord = implicitly[Ordering[Int]] // Numbers of partitions that are above and below the default bypassMergeThreshold / //那上面和下面的默认bypassmergethreshold分区数 val FEW_PARTITIONS = 50 val MANY_PARTITIONS = 10000 // Shuffles with no ordering or aggregator: should bypass unless # of partitions is high //没有排序或聚合Shuffle assert(shouldBypassMergeSort(conf, FEW_PARTITIONS, None, None)) assert(!shouldBypassMergeSort(conf, MANY_PARTITIONS, None, None)) // Shuffles with an ordering or aggregator: should not bypass even if they have few partitions //Shuffle的排序或聚合, assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, None, Some(ord))) assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, Some(agg), None)) } }
Example 187
Source File: PythonBroadcastSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.api.python import scala.io.Source import java.io.{PrintWriter, File} import org.scalatest.Matchers import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.Utils // This test suite uses SharedSparkContext because we need a SparkEnv in order to deserialize // a PythonBroadcast: class PythonBroadcastSuite extends SparkFunSuite with Matchers with SharedSparkContext { test("PythonBroadcast can be serialized with Kryo (SPARK-4882)") { val tempDir = Utils.createTempDir() val broadcastedString = "Hello, world!" def assertBroadcastIsValid(broadcast: PythonBroadcast): Unit = { val source = Source.fromFile(broadcast.path) val contents = source.mkString source.close() contents should be (broadcastedString) } try { val broadcastDataFile: File = { val file = new File(tempDir, "broadcastData") val printWriter = new PrintWriter(file) printWriter.write(broadcastedString) printWriter.close() file } val broadcast = new PythonBroadcast(broadcastDataFile.getAbsolutePath) assertBroadcastIsValid(broadcast) val conf = new SparkConf().set("spark.kryo.registrationRequired", "true") val deserializedBroadcast = Utils.clone[PythonBroadcast](broadcast, new KryoSerializer(conf).newInstance()) assertBroadcastIsValid(deserializedBroadcast) } finally { Utils.deleteRecursively(tempDir) } } }
Example 188
Source File: PythonRDDSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.api.python

import java.io.{ByteArrayOutputStream, DataOutputStream}

import org.apache.spark.SparkFunSuite

class PythonRDDSuite extends SparkFunSuite {

  test("Writing large strings to the worker") {
    val input: List[String] = List("a" * 100000)
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    PythonRDD.writeIteratorToStream(input.iterator, buffer)
  }

  test("Handle nulls gracefully") {
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    // Should not throw an NPE when writing an Iterator that contains null.
    // The correctness will be tested in Python.
    PythonRDD.writeIteratorToStream(Iterator("a", null), buffer)
    PythonRDD.writeIteratorToStream(Iterator(null, "a"), buffer)
    PythonRDD.writeIteratorToStream(Iterator("a".getBytes, null), buffer)
    PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes), buffer)
    PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), buffer)
    PythonRDD.writeIteratorToStream(
      Iterator((null, null), ("a".getBytes, null), (null, "b".getBytes)), buffer)
  }
}
Example 189
Source File: SerDeUtilSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.api.python

import org.apache.spark.{SharedSparkContext, SparkFunSuite}

class SerDeUtilSuite extends SparkFunSuite with SharedSparkContext {

  test("Converting an empty pair RDD to python does not throw an exception (SPARK-5441)") {
    val emptyRdd = sc.makeRDD(Seq[(Any, Any)]())
    SerDeUtil.pairRDDToPython(emptyRdd, 10)
  }

  test("Converting an empty python RDD to pair RDD does not throw an exception (SPARK-5441)") {
    val emptyRdd = sc.makeRDD(Seq[(Any, Any)]())
    val javaRdd = emptyRdd.toJavaRDD()
    val pythonRdd = SerDeUtil.javaToPython(javaRdd)
    SerDeUtil.pythonToPairRDD(pythonRdd, false)
  }
}
Example 190
Source File: PythonRunnerSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy import org.apache.spark.SparkFunSuite import org.apache.spark.util.Utils class PythonRunnerSuite extends SparkFunSuite { // Test formatting a single path to be added to the PYTHONPATH //测试格式化要添加到PYTHONPATH的单个路径 test("format path") { assert(PythonRunner.formatPath("spark.py") === "spark.py") assert(PythonRunner.formatPath("file:/spark.py") === "/spark.py") assert(PythonRunner.formatPath("file:///spark.py") === "/spark.py") assert(PythonRunner.formatPath("local:/spark.py") === "/spark.py") assert(PythonRunner.formatPath("local:///spark.py") === "/spark.py") if (Utils.isWindows) { assert(PythonRunner.formatPath("file:/C:/a/b/spark.py", testWindows = true) === "C:/a/b/spark.py") assert(PythonRunner.formatPath("C:\\a\\b\\spark.py", testWindows = true) === "C:/a/b/spark.py") assert(PythonRunner.formatPath("C:\\a b\\spark.py", testWindows = true) === "C:/a b/spark.py") } intercept[IllegalArgumentException] { PythonRunner.formatPath("one:two") } intercept[IllegalArgumentException] { PythonRunner.formatPath("hdfs:s3:xtremeFS") } intercept[IllegalArgumentException] { PythonRunner.formatPath("hdfs:/path/to/some.py") } } // Test formatting multiple comma-separated paths to be added to the PYTHONPATH test("format paths") { assert(PythonRunner.formatPaths("spark.py") === Array("spark.py")) assert(PythonRunner.formatPaths("file:/spark.py") === Array("/spark.py")) assert(PythonRunner.formatPaths("file:/app.py,local:/spark.py") === Array("/app.py", "/spark.py")) assert(PythonRunner.formatPaths("me.py,file:/you.py,local:/we.py") === Array("me.py", "/you.py", "/we.py")) if (Utils.isWindows) { assert(PythonRunner.formatPaths("C:\\a\\b\\spark.py", testWindows = true) === Array("C:/a/b/spark.py")) assert(PythonRunner.formatPaths("C:\\free.py,pie.py", testWindows = true) === Array("C:/free.py", "pie.py")) assert(PythonRunner.formatPaths("lovely.py,C:\\free.py,file:/d:/fry.py", testWindows = true) === Array("lovely.py", "C:/free.py", "d:/fry.py")) } intercept[IllegalArgumentException] { PythonRunner.formatPaths("one:two,three") } intercept[IllegalArgumentException] { PythonRunner.formatPaths("two,three,four:five:six") } intercept[IllegalArgumentException] { PythonRunner.formatPaths("hdfs:/some.py,foo.py") } intercept[IllegalArgumentException] { PythonRunner.formatPaths("foo.py,hdfs:/some.py") } } }
Example 191
Source File: WorkerWatcherSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.worker import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.SecurityManager import org.apache.spark.rpc.{RpcAddress, RpcEnv} class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher shuts down on valid disassociation") {//工作节点关闭有效分离 val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234)) assert(workerWatcher.isShutDown) rpcEnv.shutdown() } test("WorkerWatcher stays alive on invalid disassociation") {//无效断开连接 val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") val otherRpcAddress = RpcAddress("4.3.2.1", 1234) val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) workerWatcher.onDisconnected(otherRpcAddress) assert(!workerWatcher.isShutDown) rpcEnv.shutdown() } }
Example 192
Source File: WorkerArgumentsTest.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.worker import org.apache.spark.{SparkConf, SparkFunSuite} class WorkerArgumentsTest extends SparkFunSuite { //内存不能被设置为0时,命令行参数离开M或G test("Memory can't be set to 0 when cmd line args leave off M or G") { val conf = new SparkConf val args = Array("-m", "10000", "spark://localhost:0000 ") intercept[IllegalStateException] { new WorkerArguments(args, conf) } } //内存不能被设置为0时,spark_worker_memory env参数离开M或G test("Memory can't be set to 0 when SPARK_WORKER_MEMORY env property leaves off M or G") { val args = Array("spark://localhost:0000 ") class MySparkConf extends SparkConf(false) { override def getenv(name: String): String = { if (name == "SPARK_WORKER_MEMORY") "50000" else super.getenv(name) } override def clone: SparkConf = { new MySparkConf().setAll(getAll) } } val conf = new MySparkConf() intercept[IllegalStateException] { new WorkerArguments(args, conf) } } //当SPARK_WORKER_MEMORY env属性追加G时,内存正确设置 test("Memory correctly set when SPARK_WORKER_MEMORY env property appends G") { val args = Array("spark://localhost:0000 ") class MySparkConf extends SparkConf(false) { override def getenv(name: String): String = { if (name == "SPARK_WORKER_MEMORY") "5G" else super.getenv(name) } override def clone: SparkConf = { new MySparkConf().setAll(getAll) } } val conf = new MySparkConf() val workerArgs = new WorkerArguments(args, conf) assert(workerArgs.memory === 5120) } //从附加到内存值的M的args正确设置内存 test("Memory correctly set from args with M appended to memory value") { val conf = new SparkConf val args = Array("-m", "10000M", "spark://localhost:0000 ") val workerArgs = new WorkerArguments(args, conf) assert(workerArgs.memory === 10000) } }
Example 193
Source File: ExecutorRunnerTest.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.deploy.worker import java.io.File import scala.collection.JavaConversions._ import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} //Worker 通过持有 ExecutorRunner 对象来控制 CoarseGrainedExecutorBackend 的启停 class ExecutorRunnerTest extends SparkFunSuite { test("command includes appId") {//命令包括AppID val appId = "12345-worker321-9876" val conf = new SparkConf //System.getenv()和System.getProperties()的区别 //System.getenv() 返回系统环境变量值 设置系统环境变量:当前登录用户主目录下的".bashrc"文件中可以设置系统环境变量 //System.getProperties() 返回Java进程变量值 通过命令行参数的"-D"选项 val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) val appDesc = new ApplicationDescription("app name", Some(8), 500, Command("foo", Seq(appId), Map(), Seq(), Seq(), Seq()), "appUiUrl") //Worker 通过持有 ExecutorRunner 对象来控制 CoarseGrainedExecutorBackend 的启停 val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", 123, "publicAddr", new File(sparkHome), new File("ooga"), "blah", conf, Seq("localDir"), ExecutorState.RUNNING) val builder = CommandUtils.buildProcessBuilder( appDesc.command, new SecurityManager(conf), 512, sparkHome, er.substituteVariables) assert(builder.command().last === appId) } }
Example 194
Source File: PagedTableSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.ui import scala.xml.Node import org.apache.spark.SparkFunSuite class PagedDataSourceSuite extends SparkFunSuite { test("basic") { val dataSource1 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) assert(dataSource1.pageData(1) === PageData(3, (1 to 2))) val dataSource2 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) assert(dataSource2.pageData(2) === PageData(3, (3 to 4))) val dataSource3 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) assert(dataSource3.pageData(3) === PageData(3, Seq(5))) val dataSource4 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) val e1 = intercept[IndexOutOfBoundsException] { dataSource4.pageData(4) } assert(e1.getMessage === "Page 4 is out of range. Please select a page number between 1 and 3.") val dataSource5 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) val e2 = intercept[IndexOutOfBoundsException] { dataSource5.pageData(0) } assert(e2.getMessage === "Page 0 is out of range. Please select a page number between 1 and 3.") } } class PagedTableSuite extends SparkFunSuite { test("pageNavigation") {//页面导航 // Create a fake PagedTable to test pageNavigation //创建一个假的PagedTable来测试pageNavigation val pagedTable = new PagedTable[Int] { override def tableId: String = "" override def tableCssClass: String = "" override def dataSource: PagedDataSource[Int] = null override def pageLink(page: Int): String = page.toString override def headers: Seq[Node] = Nil override def row(t: Int): Seq[Node] = Nil override def goButtonJavascriptFunction: (String, String) = ("", "") } assert(pagedTable.pageNavigation(1, 10, 1) === Nil) assert( (pagedTable.pageNavigation(1, 10, 2).head \\ "li").map(_.text.trim) === Seq("1", "2", ">")) assert( (pagedTable.pageNavigation(2, 10, 2).head \\ "li").map(_.text.trim) === Seq("<", "1", "2")) assert((pagedTable.pageNavigation(1, 10, 100).head \\ "li").map(_.text.trim) === (1 to 10).map(_.toString) ++ Seq(">", ">>")) assert((pagedTable.pageNavigation(2, 10, 100).head \\ "li").map(_.text.trim) === Seq("<") ++ (1 to 10).map(_.toString) ++ Seq(">", ">>")) assert((pagedTable.pageNavigation(100, 10, 100).head \\ "li").map(_.text.trim) === Seq("<<", "<") ++ (91 to 100).map(_.toString)) assert((pagedTable.pageNavigation(99, 10, 100).head \\ "li").map(_.text.trim) === Seq("<<", "<") ++ (91 to 100).map(_.toString) ++ Seq(">")) assert((pagedTable.pageNavigation(11, 10, 100).head \\ "li").map(_.text.trim) === Seq("<<", "<") ++ (11 to 20).map(_.toString) ++ Seq(">", ">>")) assert((pagedTable.pageNavigation(93, 10, 97).head \\ "li").map(_.text.trim) === Seq("<<", "<") ++ (91 to 97).map(_.toString) ++ Seq(">")) } } private[spark] class SeqPagedDataSource[T](seq: Seq[T], pageSize: Int) extends PagedDataSource[T](pageSize) { override protected def dataSize: Int = seq.size override protected def sliceData(from: Int, to: Int): Seq[T] = seq.slice(from, to) }
Example 195
Source File: UIUtilsSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.ui import scala.xml.Elem import org.apache.spark.SparkFunSuite class UIUtilsSuite extends SparkFunSuite { import UIUtils._ test("makeDescription") {//标记描述 verify( """test <a href="/link"> text </a>""", <span class="description-input">test <a href="/link"> text </a></span>, "Correctly formatted text with only anchors and relative links should generate HTML" ) verify( """test <a href="/link" text </a>""", <span class="description-input">{"""test <a href="/link" text </a>"""}</span>, "Badly formatted text should make the description be treated as a streaming instead of HTML" ) verify( """test <a href="link"> text </a>""", <span class="description-input">{"""test <a href="link"> text </a>"""}</span>, "Non-relative links should make the description be treated as a string instead of HTML" ) verify( """test<a><img></img></a>""", <span class="description-input">{"""test<a><img></img></a>"""}</span>, "Non-anchor elements should make the description be treated as a string instead of HTML" ) verify( """test <a href="/link"> text </a>""", <span class="description-input">test <a href="base/link"> text </a></span>, baseUrl = "base", errorMsg = "Base URL should be prepended to html links" ) } private def verify( desc: String, expected: Elem, errorMsg: String = "", baseUrl: String = ""): Unit = { val generated = makeDescription(desc, baseUrl) assert(generated.sameElements(expected), s"\n$errorMsg\n\nExpected:\n$expected\nGenerated:\n$generated") } }
Example 196
Source File: GenericAvroSerializerSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.serializer import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import java.nio.ByteBuffer import com.esotericsoftware.kryo.io.{Output, Input} import org.apache.avro.{SchemaBuilder, Schema} import org.apache.avro.generic.GenericData.Record import org.apache.spark.{SparkFunSuite, SharedSparkContext} class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") val schema : Schema = SchemaBuilder .record("testRecord").fields() .requiredString("data") .endRecord() val record = new Record(schema) record.put("data", "test data") test("schema compression and decompression") {//模式压缩与解压缩 val genericSer = new GenericAvroSerializer(conf.getAvroSchema) assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema)))) } test("record serialization and deserialization") {//记录序列化和反序列化 val genericSer = new GenericAvroSerializer(conf.getAvroSchema) val outputStream = new ByteArrayOutputStream() val output = new Output(outputStream) genericSer.serializeDatum(record, output) output.flush() output.close() val input = new Input(new ByteArrayInputStream(outputStream.toByteArray)) assert(genericSer.deserializeDatum(input) === record) } //使用模式指纹以减少信息大小 test("uses schema fingerprint to decrease message size") { val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema) val output = new Output(new ByteArrayOutputStream()) val beginningNormalPosition = output.total() genericSerFull.serializeDatum(record, output) output.flush() val normalLength = output.total - beginningNormalPosition conf.registerAvroSchemas(schema) val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema) val beginningFingerprintPosition = output.total() genericSerFinger.serializeDatum(record, output) val fingerprintLength = output.total - beginningFingerprintPosition assert(fingerprintLength < normalLength) } test("caches previously seen schemas") {//缓存之前模式 val genericSer = new GenericAvroSerializer(conf.getAvroSchema) val compressedSchema = genericSer.compress(schema) val decompressedScheam = genericSer.decompress(ByteBuffer.wrap(compressedSchema)) assert(compressedSchema.eq(genericSer.compress(schema))) assert(decompressedScheam.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema)))) } }
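Registering schemas up front is what enables the fingerprint optimisation asserted above. A configuration sketch, assuming only the SparkConf and SchemaBuilder APIs used in the suite:

import org.apache.avro.SchemaBuilder
import org.apache.spark.SparkConf

val schema = SchemaBuilder.record("testRecord").fields().requiredString("data").endRecord()
val conf = new SparkConf()
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .registerAvroSchemas(schema)  // Kryo can then ship a small schema fingerprint instead of the full schema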
Example 197
Source File: KryoSerializerResizableOutputSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.serializer import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.SparkContext import org.apache.spark.LocalSparkContext import org.apache.spark.SparkException class KryoSerializerResizableOutputSuite extends SparkFunSuite { // trial and error showed this will not serialize with 1mb buffer //试验和错误不会序列化使用1MB的缓冲 val x = (1 to 400000).toArray //kryo不可调整大小的输出缓冲区,应该在大数组失败 test("kryo without resizable output buffer should fail on large array") { val conf = new SparkConf(false) conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryoserializer.buffer", "1m") conf.set("spark.kryoserializer.buffer.max", "1m") val sc = new SparkContext("local", "test", conf) intercept[SparkException](sc.parallelize(x).collect()) LocalSparkContext.stop(sc) } //kryo不可调整大小的输出缓冲区,应该在大数组成功 test("kryo with resizable output buffer should succeed on large array") { val conf = new SparkConf(false) conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryoserializer.buffer", "1m") conf.set("spark.kryoserializer.buffer.max", "2m") val sc = new SparkContext("local", "test", conf) assert(sc.parallelize(x).collect() === x) LocalSparkContext.stop(sc) } }
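In application code the relevant knobs are the two buffer settings exercised by this test. A typical configuration sketch; the 64m ceiling is an arbitrary example value:

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .set("spark.kryoserializer.buffer", "1m")       // initial serialization buffer size
  .set("spark.kryoserializer.buffer.max", "64m")  // ceiling the buffer may grow to before failing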
Example 198
Source File: ProactiveClosureSerializationSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.serializer

import org.apache.spark.{SharedSparkContext, SparkException, SparkFunSuite}
import org.apache.spark.rdd.RDD

class UnserializableClass {
  def op[T](x: T): String = x.toString
  def pred[T](x: T): Boolean = x.toString.length % 2 == 0
}

class ProactiveClosureSerializationSuite extends SparkFunSuite with SharedSparkContext {

  def fixture: (RDD[String], UnserializableClass) = {
    (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass)
  }

  // Actions should throw the expected serialization exceptions
  test("throws expected serialization exceptions on actions") {
    val (data, uc) = fixture
    val ex = intercept[SparkException] {
      data.map(uc.op(_)).count()
    }
    assert(ex.getMessage.contains("Task not serializable"))
  }

  // There is probably a cleaner way to eliminate boilerplate here, but we're
  // iterating over a map from transformation names to functions that perform that
  // transformation on a given RDD, creating one test case for each
  for (transformation <-
      Map("map" -> xmap _,
          "flatMap" -> xflatMap _,
          "filter" -> xfilter _,
          "mapPartitions" -> xmapPartitions _,
          "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _)) {
    val (name, xf) = transformation

    test(s"$name transformations throw proactive serialization exceptions") {
      val (data, uc) = fixture
      val ex = intercept[SparkException] {
        xf(data, uc)
      }
      assert(ex.getMessage.contains("Task not serializable"),
        s"RDD.$name doesn't proactively throw NotSerializableException")
    }
  }

  private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] =
    x.map(y => uc.op(y))

  private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
    x.flatMap(y => Seq(uc.op(y)))

  private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] =
    x.filter(y => uc.pred(y))

  private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
    x.mapPartitions(_.map(y => uc.op(y)))

  private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
    x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y)))
}
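The helpers above deliberately capture an instance of UnserializableClass inside each closure so that the serializer rejects the task proactively. A minimal sketch of the usual fix, with illustrative names (StringOps and safeMap are not part of the suite): make the helper Serializable, or copy what the closure needs into a local serializable value.

import org.apache.spark.rdd.RDD

// Serializable helper: the closure below captures only this small, serializable object.
class StringOps extends Serializable {
  def op(x: String): String = x.toUpperCase
}

def safeMap(data: RDD[String]): RDD[String] = {
  val ops = new StringOps   // local val, shipped safely with the closure
  data.map(y => ops.op(y))
}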
Example 199
Source File: MapStatusSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.scheduler

import org.apache.spark.storage.BlockManagerId
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer

import scala.util.Random

class MapStatusSuite extends SparkFunSuite {

  // Compressed sizes
  test("compressSize") {
    assert(MapStatus.compressSize(0L) === 0)
    assert(MapStatus.compressSize(1L) === 1)
    assert(MapStatus.compressSize(2L) === 8)
    assert(MapStatus.compressSize(10L) === 25)
    assert((MapStatus.compressSize(1000000L) & 0xFF) === 145)
    assert((MapStatus.compressSize(1000000000L) & 0xFF) === 218)
    // This last size is bigger than we can encode in a byte, so check that we just return 255
    assert((MapStatus.compressSize(1000000000000000000L) & 0xFF) === 255)
  }

  // Decompressed sizes
  test("decompressSize") {
    assert(MapStatus.decompressSize(0) === 0)
    for (size <- Seq(2L, 10L, 100L, 50000L, 1000000L, 1000000000L)) {
      val size2 = MapStatus.decompressSize(MapStatus.compressSize(size))
      assert(size2 >= 0.99 * size && size2 <= 1.11 * size,
        "size " + size + " decompressed to " + size2 + ", which is out of range")
    }
  }

  // MapStatus should never report a non-empty block's size as 0
  test("MapStatus should never report non-empty blocks' sizes as 0") {
    import Math._
    for (
      numSizes <- Seq(1, 10, 100, 1000, 10000);
      mean <- Seq(0L, 100L, 10000L, Int.MaxValue.toLong);
      stddev <- Seq(0.0, 0.01, 0.5, 1.0)
    ) {
      val sizes = Array.fill[Long](numSizes)(abs(round(Random.nextGaussian() * stddev)) + mean)
      val status = MapStatus(BlockManagerId("a", "b", 10), sizes)
      val status1 = compressAndDecompressMapStatus(status)
      for (i <- 0 until numSizes) {
        if (sizes(i) != 0) {
          val failureMessage = s"Failed with $numSizes sizes with mean=$mean, stddev=$stddev"
          assert(status.getSizeForBlock(i) !== 0, failureMessage)
          assert(status1.getSizeForBlock(i) !== 0, failureMessage)
        }
      }
    }
  }

  // Large tasks should use HighlyCompressedMapStatus
  test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) {
    val sizes = Array.fill[Long](2001)(150L)
    val status = MapStatus(null, sizes)
    assert(status.isInstanceOf[HighlyCompressedMapStatus])
    assert(status.getSizeForBlock(10) === 150L)
    assert(status.getSizeForBlock(50) === 150L)
    assert(status.getSizeForBlock(99) === 150L)
    assert(status.getSizeForBlock(2000) === 150L)
  }

  // HighlyCompressedMapStatus: the estimated size should be the average non-empty block size
  test("HighlyCompressedMapStatus: estimated size should be the average non-empty block size") {
    val sizes = Array.tabulate[Long](3000) { i => i.toLong }
    val avg = sizes.sum / sizes.filter(_ != 0).length
    val loc = BlockManagerId("a", "b", 10)
    val status = MapStatus(loc, sizes)
    val status1 = compressAndDecompressMapStatus(status)
    assert(status1.isInstanceOf[HighlyCompressedMapStatus])
    assert(status1.location == loc)
    for (i <- 0 until 3000) {
      val estimate = status1.getSizeForBlock(i)
      if (sizes(i) > 0) {
        assert(estimate === avg)
      }
    }
  }

  def compressAndDecompressMapStatus(status: MapStatus): MapStatus = {
    val ser = new JavaSerializer(new SparkConf)
    val buf = ser.newInstance().serialize(status)
    ser.newInstance().deserialize[MapStatus](buf)
  }
}
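The compressSize/decompressSize assertions follow from the log-scale encoding the suite exercises: a block size is stored as a single byte n with size roughly equal to 1.1^n, capped at 255. A standalone sketch of that encoding, written here for illustration only (the real helpers live on org.apache.spark.scheduler.MapStatus and are package-private):

object SizeEncoding {
  private val LOG_BASE = 1.1

  // Encode a block size into one byte on a log-1.1 scale, capped at 255.
  def compress(size: Long): Byte = {
    if (size == 0) {
      0
    } else if (size <= 1L) {
      1
    } else {
      math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
    }
  }

  // Decode back to an approximate size; the error is bounded by the 1.1 growth factor,
  // which is why the decompressSize test accepts anything in [0.99 * size, 1.11 * size].
  def decompress(compressed: Byte): Long = {
    if (compressed == 0) 0L else math.pow(LOG_BASE, compressed & 0xFF).toLong
  }
}

For example, compress(2L) = ceil(log(2) / log(1.1)) = 8 and compress(10L) = 25, matching the values asserted in the compressSize test above.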
Example 200
Source File: CoarseGrainedSchedulerBackendSuite.scala From spark1.52 with Apache License 2.0 | 5 votes |
package org.apache.spark.scheduler

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite}
import org.apache.spark.util.{SerializableBuffer, AkkaUtils}

class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext {

  // A serialized task larger than the Akka frame size should be rejected
  ignore("serialized task larger than akka frame size") {
    val conf = new SparkConf
    // spark.akka.frameSize is the size, in MB, of messages exchanged between the driver
    // and the executors; a larger value lets the driver accept larger computation results
    conf.set("spark.akka.frameSize", "1")
    // Set the default parallelism
    conf.set("spark.default.parallelism", "1")
    //sc = new SparkContext("local-cluster[2, 1, 1024]", "test", conf)
    sc = new SparkContext("local[*]", "test", conf)
    // Maximum frame size Akka will transfer, in bytes (spark.akka.frameSize converted from MB)
    val frameSize = AkkaUtils.maxFrameSizeBytes(sc.conf)
    // ByteBuffer.allocate reserves a buffer that can then be read and written;
    // allocating twice the frame size guarantees the serialized task exceeds the limit
    val buffer = new SerializableBuffer(java.nio.ByteBuffer.allocate(2 * frameSize))
    val larger = sc.parallelize(Seq(buffer))
    val thrown = intercept[SparkException] {
      larger.collect()
    }
    // The exception should suggest using broadcast variables for large values
    assert(thrown.getMessage.contains("using broadcast variables for large values"))
    val smaller = sc.parallelize(1 to 4).collect()
    assert(smaller.size === 4)
  }
}
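The assertion on the exception message points at the standard remedy: ship large read-only data to executors as a broadcast variable rather than inside every serialized task. A minimal sketch with illustrative names (demoSc, lookupTable), not part of the suite:

import org.apache.spark.{SparkConf, SparkContext}

val demoSc = new SparkContext(
  new SparkConf().setMaster("local[*]").setAppName("broadcast-demo"))

// Broadcast once; each executor keeps a read-only copy instead of receiving the map
// again with every task it runs.
val lookupTable = (1 to 1000000).map(i => i -> i.toString).toMap
val broadcastTable = demoSc.broadcast(lookupTable)

val resolved = demoSc.parallelize(1 to 100)
  .map(i => broadcastTable.value.getOrElse(i, "missing"))
resolved.collect()
demoSc.stop()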