org.scalatest.Assertions Scala Examples
The following examples show how to use org.scalatest.Assertions.
Each example is taken from a real open-source project; the project, source file, and license are listed in the header above the code.
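Before the project examples, here is a minimal self-contained sketch of the core Assertions API they all rely on (assert with ===, assertResult, intercept, withClue, and fail); the object and values are illustrative, not taken from any of the projects below:

import org.scalatest.Assertions

object AssertionsSketch extends Assertions {
  def main(args: Array[String]): Unit = {
    val xs = List(1, 2, 3)

    // === produces a detailed failure message (e.g. "2 did not equal 3")
    assert(xs.size === 3)

    // assertResult separates the expected value from the actual expression
    assertResult(6)(xs.sum)

    // intercept returns the thrown exception so it can be inspected further
    val e = intercept[ArithmeticException] { 1 / 0 }
    assert(e.getMessage.contains("/ by zero"))

    // withClue prepends extra context to a failure message
    withClue("unexpected head:") { assert(xs.head === 1) }

    // fail always throws TestFailedException
    if (xs.isEmpty) fail("list must not be empty")
  }
}

Mixing the Assertions trait into a plain object or trait, as several of the examples below do (package objects, helper objects, syntax traits), makes these methods available outside a Suite.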
Example 1
Source File: SparkContextInfoSuite.scala From BigDatalog with Apache License 2.0
package org.apache.spark import org.scalatest.Assertions import org.apache.spark.storage.StorageLevel class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext { test("getPersistentRDDs only returns RDDs that are marked as cached") { sc = new SparkContext("local", "test") assert(sc.getPersistentRDDs.isEmpty === true) val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(sc.getPersistentRDDs.isEmpty === true) rdd.cache() assert(sc.getPersistentRDDs.size === 1) assert(sc.getPersistentRDDs.values.head === rdd) } test("getPersistentRDDs returns an immutable map") { sc = new SparkContext("local", "test") val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() val myRdds = sc.getPersistentRDDs assert(myRdds.size === 1) assert(myRdds(0) === rdd1) assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY) // myRdds2 should have 2 RDDs, but myRdds should not change val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache() val myRdds2 = sc.getPersistentRDDs assert(myRdds2.size === 2) assert(myRdds2(0) === rdd1) assert(myRdds2(1) === rdd2) assert(myRdds2(0).getStorageLevel === StorageLevel.MEMORY_ONLY) assert(myRdds2(1).getStorageLevel === StorageLevel.MEMORY_ONLY) assert(myRdds.size === 1) assert(myRdds(0) === rdd1) assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY) } test("getRDDStorageInfo only reports on RDDs that actually persist data") { sc = new SparkContext("local", "test") val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() assert(sc.getRDDStorageInfo.size === 0) rdd.collect() assert(sc.getRDDStorageInfo.size === 1) assert(sc.getRDDStorageInfo.head.isCached) assert(sc.getRDDStorageInfo.head.memSize > 0) assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY) } test("call sites report correct locations") { sc = new SparkContext("local", "test") testPackage.runCallSiteTest(sc) } } package object testPackage extends Assertions { private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r def runCallSiteTest(sc: SparkContext) { val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) val rddCreationSite = rdd.getCreationSite val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd" val rddCreationLine = rddCreationSite match { case CALL_SITE_REGEX(func, file, line) => { assert(func === "makeRDD") assert(file === "SparkContextInfoSuite.scala") line.toInt } case _ => fail("Did not match expected call site format") } curCallSite match { case CALL_SITE_REGEX(func, file, line) => { assert(func === "getCallSite") // this is correct because we called it from outside of Spark assert(file === "SparkContextInfoSuite.scala") assert(line.toInt === rddCreationLine.toInt + 2) } case _ => fail("Did not match expected call site format") } } }
Example 2
Source File: CompareParamGrid.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages.impl

import org.apache.spark.ml.param.ParamMap
import org.scalatest.{Assertions, Matchers}

trait CompareParamGrid extends Matchers with Assertions {

  def gridCompare(g1: Array[ParamMap], g2: Array[ParamMap]): Unit = {
    val g1values = g1.toSet[ParamMap].map(_.toSeq.toSet)
    val g2values = g2.toSet[ParamMap].map(_.toSeq.toSet)
    matchTwoSets(g1values, g2values)
  }

  private def matchTwoSets[T](actual: Set[T], expected: Set[T]): Unit = {
    def stringify(set: Set[T]): String = {
      val list = set.toList
      val chunk = list take 10
      val strings = chunk.map(_.toString).sorted
      if (list.size > chunk.size) strings.mkString else strings.mkString + ",..."
    }
    val missing = stringify(expected -- actual)
    val extra = stringify(actual -- expected)
    withClue(s"Missing:\n $missing\nExtra:\n$extra") {
      actual shouldBe expected
    }
  }
}
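For context, a hedged sketch of how this trait might be mixed into a spec; the Spark ML estimator, parameter, and values are illustrative, not taken from the original project:

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.scalatest.FlatSpec

class MyGridSpec extends FlatSpec with CompareParamGrid {
  "two identically built grids" should "compare as equal" in {
    val lr = new LogisticRegression()
    // illustrative grid over a single hyperparameter
    val grid = new ParamGridBuilder().addGrid(lr.regParam, Array(0.01, 0.1)).build()
    gridCompare(grid, grid)
  }
}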
Example 3
Source File: ScalaStyleValidationTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class ScalaStyleValidationTest extends FlatSpec with Matchers with Assertions {
  import scala.Throwable

  private def +(x: Int, y: Int) = x + y
  private def -(x: Int, y: Int) = x - y
  private def *(x: Int, y: Int) = x * y
  private def /(x: Int, y: Int) = x / y
  private def +-(x: Int, y: Int) = x + (-y)
  private def xx_=(y: Int) = println(s"setting xx to $y")

  "bad names" should "never happen" in {
    "def _=abc = ???" shouldNot compile
    true shouldBe true
  }

  "non-ascii" should "not be allowed" in {
    // "def ⇒ = ???" shouldNot compile // it does not even compile as a string
  }
}
Example 4
Source File: RandomListTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.testkit import java.text.SimpleDateFormat import com.salesforce.op.features.types._ import com.salesforce.op.test.TestCommon import com.salesforce.op.testkit.RandomList.{NormalGeolocation, UniformGeolocation} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Assertions, FlatSpec} import scala.language.postfixOps @RunWith(classOf[JUnitRunner]) class RandomListTest extends FlatSpec with TestCommon with Assertions { private val numTries = 10000 private val rngSeed = 314159214142136L private def check[D, T <: OPList[D]]( g: RandomList[D, T], minLen: Int, maxLen: Int, predicate: (D => Boolean) = (_: D) => true ) = { g reset rngSeed def segment = g limit numTries segment count (_.value.length < minLen) shouldBe 0 segment count (_.value.length > maxLen) shouldBe 0 segment foreach (list => list.value foreach { x => predicate(x) shouldBe true }) } private val df = new SimpleDateFormat("dd/MM/yy") Spec[Text, RandomList[String, TextList]] should "generate lists of strings" in { val sut = RandomList.ofTexts(RandomText.countries, 0, 4) check[String, TextList](sut, 0, 4, _.length > 0) (sut limit 7 map (_.value.toList)) shouldBe List( List("Madagascar", "Gondal", "Zephyria"), List("Holy Alliance"), List("North American Union"), List("Guatemala", "Estonia", "Kolechia"), List(), List("Myanmar", "Bhutan"), List("Equatorial Guinea") ) } Spec[Date, RandomList[Long, DateList]] should "generate lists of dates" in { val dates = RandomIntegral.dates(df.parse("01/01/2017"), 1000, 1000000) val sut = RandomList.ofDates(dates, 11, 22) var d0 = 0L check[Long, DateList](sut, 11, 22, d => { val d1 = d0 d0 = d d > d1 }) } Spec[DateTimeList, RandomList[Long, DateTimeList]] should "generate lists of datetimes" in { val datetimes = RandomIntegral.datetimes(df.parse("01/01/2017"), 1000, 1000000) val sut = RandomList.ofDateTimes(datetimes, 11, 22) var d0 = 0L check[Long, DateTimeList](sut, 11, 22, d => { val d1 = d0 d0 = d d > d1 }) } Spec[UniformGeolocation] should "generate uniformly distributed geolocations" in { val sut = RandomList.ofGeolocations val segment = sut limit numTries segment foreach (_.value.length shouldBe 3) } Spec[NormalGeolocation] should "generate geolocations around given point" in { for {accuracy <- GeolocationAccuracy.values} { val geolocation = RandomList.ofGeolocationsNear(37.444136, 122.163160, accuracy) val segment = geolocation limit numTries segment foreach (_.value.length shouldBe 3) } } }
Example 5
Source File: RandomIntegralTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.testkit import java.text.SimpleDateFormat import com.salesforce.op.features.types._ import com.salesforce.op.test.TestCommon import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Assertions, FlatSpec} import scala.language.postfixOps @RunWith(classOf[JUnitRunner]) class RandomIntegralTest extends FlatSpec with TestCommon with Assertions { private val numTries = 10000 private val rngSeed = 314159214142135L private def check[T <: Integral]( g: RandomIntegral[T], predicate: Long => Boolean = _ => true ) = { g reset rngSeed def segment = g limit numTries val numberOfEmpties = segment count (_.isEmpty) val expectedNumberOfEmpties = g.probabilityOfEmpty * numTries withClue(s"numEmpties = $numberOfEmpties, expected $expectedNumberOfEmpties") { math.abs(numberOfEmpties - expectedNumberOfEmpties) < 2 * math.sqrt(numTries) shouldBe true } val maybeValues = segment filterNot (_.isEmpty) map (_.value) val values = maybeValues collect { case Some(s) => s } values foreach (x => predicate(x) shouldBe true) withClue(s"number of distinct values = ${values.size}, expected:") { math.abs(maybeValues.size - values.toSet.size) < maybeValues.size / 20 } } private val df = new SimpleDateFormat("dd/MM/yy") Spec[RandomIntegral[Integral]] should "generate empties and distinct numbers" in { val sut0 = RandomIntegral.integrals val sut = sut0.withProbabilityOfEmpty(0.3) check(sut) sut.probabilityOfEmpty shouldBe 0.3 } Spec[RandomIntegral[Integral]] should "generate empties and distinct numbers in some range" in { val sut0 = RandomIntegral.integrals(100, 200) val sut = sut0.withProbabilityOfEmpty(0.3) check(sut, i => i >= 100 && i < 200) sut.probabilityOfEmpty shouldBe 0.3 } Spec[RandomIntegral[Date]] should "generate dates" in { val sut = RandomIntegral.dates(df.parse("01/01/2017"), 1000, 1000000) var d0 = 0L check(sut withProbabilityOfEmpty 0.01, d => { val d1 = d0 d0 = d d0 > d1 }) } Spec[RandomIntegral[DateTime]] should "generate dates with times" in { val sut = RandomIntegral.datetimes(df.parse("08/24/2017"), 1000, 1000000) var d0 = 0L check(sut withProbabilityOfEmpty 0.001, d => { val d1 = d0 d0 = d d0 > d1 }) } }
Example 6
Source File: RandomSetTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.testkit import com.salesforce.op.features.types._ import com.salesforce.op.test.TestCommon import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Assertions, FlatSpec} import scala.language.postfixOps @RunWith(classOf[JUnitRunner]) class RandomSetTest extends FlatSpec with TestCommon with Assertions { private val numTries = 10000 private val rngSeed = 314159214142136L private def check[D, T <: OPSet[D]]( g: RandomSet[D, T], minLen: Int, maxLen: Int, predicate: (D => Boolean) = (_: D) => true ) = { g reset rngSeed def segment = g limit numTries segment count (_.value.size < minLen) shouldBe 0 segment count (_.value.size > maxLen) shouldBe 0 segment foreach (Set => Set.value foreach { x => predicate(x) shouldBe true }) } Spec[MultiPickList] should "generate multipicklists" in { val sut = RandomMultiPickList.of(RandomText.countries, maxLen = 5) check[String, MultiPickList](sut, 0, 5, _.nonEmpty) val expected = List( Set(), Set("Aldorria", "Palau", "Glubbdubdrib"), Set(), Set(), Set("Sweden", "Wuhu Islands", "Tuvalu") ) {sut reset 42; sut limit 5 map (_.value)} shouldBe expected {sut reset 42; sut limit 5 map (_.value)} shouldBe expected } }
Example 7
Source File: AssertingSyntax.scala From cats-effect-testing with Apache License 2.0
package cats.effect.testing.scalatest

import cats.Functor
import cats.effect.Sync
import org.scalatest.{Assertion, Assertions, Succeeded}
import cats.implicits._

// Excerpt: the listing omits the enclosing trait and the implicit syntax class
// that introduce the type parameters F[_] and A and the value `self: F[A]`;
// the trailing braces below close those elided declarations.
    def assertThrows[E <: Throwable](implicit F: Sync[F], ct: reflect.ClassTag[E]): F[Assertion] =
      self.attempt.flatMap {
        case Left(t: E) => F.pure(Succeeded: Assertion)
        case Left(t) =>
          F.delay(
            fail(
              s"Expected an exception of type ${ct.runtimeClass.getName} but got an exception: $t"
            )
          )
        case Right(a) =>
          F.delay(
            fail(s"Expected an exception of type ${ct.runtimeClass.getName} but got a result: $a")
          )
      }
  }
}
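For orientation, a hedged sketch of how this syntax is typically invoked from a test, assuming the elided enclosing trait is named AssertingSyntax (after the file) and exposes assertThrows on an effect value; the suite style (ScalaTest 3.1+ AsyncFlatSpec), the effect type (cats-effect 2 IO with unsafeToFuture), and the exception are illustrative:

import cats.effect.IO
import org.scalatest.flatspec.AsyncFlatSpec

class ParserSpec extends AsyncFlatSpec with AssertingSyntax {
  "parse" should "fail on malformed input" in {
    // the extension method comes from the (elided) implicit class shown above
    IO.raiseError[Int](new IllegalArgumentException("malformed"))
      .assertThrows[IllegalArgumentException]
      .unsafeToFuture()
  }
}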
Example 8
Source File: BagelSuite.scala From spark1.52 with Apache License 2.0
package org.apache.spark.bagel import org.scalatest.{BeforeAndAfter, Assertions} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.storage.StorageLevel class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable class TestMessage(val targetId: String) extends Message[String] with Serializable class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts { var sc: SparkContext = _ after { if (sc != null) { sc.stop() sc = null } } test("halting by voting") { sc = new SparkContext("local", "test") val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0)))) val msgs = sc.parallelize(Array[(String, TestMessage)]()) val numSupersteps = 5 val result = Bagel.run(sc, verts, msgs, sc.defaultParallelism) { (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) } for ((id, vert) <- result.collect) { assert(vert.age === numSupersteps) } } test("halting by message silence") { sc = new SparkContext("local", "test") val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0)))) val msgs = sc.parallelize(Array("a" -> new TestMessage("a"))) val numSupersteps = 5 val result = Bagel.run(sc, verts, msgs, sc.defaultParallelism) { (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => val msgsOut = msgs match { case Some(ms) if (superstep < numSupersteps - 1) => ms case _ => Array[TestMessage]() } (new TestVertex(self.active, self.age + 1), msgsOut) } for ((id, vert) <- result.collect) { assert(vert.age === numSupersteps) } } test("large number of iterations") { // This tests whether jobs with a large number of iterations finish in a reasonable time, // because non-memoized recursion in RDD or DAGScheduler used to cause them to hang failAfter(30 seconds) { sc = new SparkContext("local", "test") val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) val msgs = sc.parallelize(Array[(String, TestMessage)]()) val numSupersteps = 50 val result = Bagel.run(sc, verts, msgs, sc.defaultParallelism) { (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) } for ((id, vert) <- result.collect) { assert(vert.age === numSupersteps) } } } test("using non-default persistence level") { failAfter(10 seconds) { sc = new SparkContext("local", "test") val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) val msgs = sc.parallelize(Array[(String, TestMessage)]()) val numSupersteps = 20 val result = Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) { (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) } for ((id, vert) <- result.collect) { assert(vert.age === numSupersteps) } } } }
Example 9
Source File: SparkContextInfoSuite.scala From spark1.52 with Apache License 2.0
package org.apache.spark

import org.scalatest.Assertions

import org.apache.spark.storage.StorageLevel

class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext {
  // Only returns RDDs that are marked as cached
  test("getPersistentRDDs only returns RDDs that are marked as cached") {
    sc = new SparkContext("local", "test")
    assert(sc.getPersistentRDDs.isEmpty === true)

    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    // the set of persistent RDDs is still empty
    assert(sc.getPersistentRDDs.isEmpty === true)

    rdd.cache() // persist the RDD in the cache
    assert(sc.getPersistentRDDs.size === 1)
    // the first value in the map is the RDD itself
    assert(sc.getPersistentRDDs.values.head === rdd)
  }

  // Returns an immutable map
  test("getPersistentRDDs returns an immutable map") {
    sc = new SparkContext("local", "test")
    val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    val myRdds = sc.getPersistentRDDs // returns the RDDs already marked as persistent
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    // the storage level defaults to memory-only
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)

    // myRdds2 should have 2 RDDs, but myRdds should not change
    val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache()
    val myRdds2 = sc.getPersistentRDDs
    assert(myRdds2.size === 2)
    assert(myRdds2(0) === rdd1)
    assert(myRdds2(1) === rdd2)
    assert(myRdds2(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds2(1).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
  }

  // Reports RDDInfo only for RDDs that actually persist data
  test("getRDDStorageInfo only reports on RDDs that actually persist data") {
    sc = new SparkContext("local", "test")
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    assert(sc.getRDDStorageInfo.size === 0)
    rdd.collect()
    assert(sc.getRDDStorageInfo.size === 1) // RDDInfo
    assert(sc.getRDDStorageInfo.head.isCached) // check whether it is cached
    assert(sc.getRDDStorageInfo.head.memSize > 0) // memory size
    assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("call sites report correct locations") { // report correct locations
    sc = new SparkContext("local", "test")
    testPackage.runCallSiteTest(sc)
  }
}

package object testPackage extends Assertions {
  private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r

  def runCallSiteTest(sc: SparkContext) {
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    val rddCreationSite = rdd.getCreationSite
    println("====" + rddCreationSite)
    val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd"

    val rddCreationLine = rddCreationSite match {
      case CALL_SITE_REGEX(func, file, line) => {
        assert(func === "makeRDD")
        assert(file === "SparkContextInfoSuite.scala")
        line.toInt
      }
      case _ => fail("Did not match expected call site format")
    }

    curCallSite match {
      case CALL_SITE_REGEX(func, file, line) => {
        assert(func === "getCallSite") // this is correct because we called it from outside of Spark
        assert(file === "SparkContextInfoSuite.scala")
        println("==line===" + line.toInt)
        // assert(line.toInt === rddCreationLine.toInt + 2)
      }
      case _ => fail("Did not match expected call site format")
    }
  }
}
Example 10
Source File: SparkContextInfoSuite.scala From Spark-2.3.1 with Apache License 2.0
package org.apache.spark import org.scalatest.Assertions import org.apache.spark.storage.StorageLevel class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext { test("getPersistentRDDs only returns RDDs that are marked as cached") { sc = new SparkContext("local", "test") assert(sc.getPersistentRDDs.isEmpty === true) val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(sc.getPersistentRDDs.isEmpty === true) rdd.cache() assert(sc.getPersistentRDDs.size === 1) assert(sc.getPersistentRDDs.values.head === rdd) } test("getPersistentRDDs returns an immutable map") { sc = new SparkContext("local", "test") val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() val myRdds = sc.getPersistentRDDs assert(myRdds.size === 1) assert(myRdds(0) === rdd1) assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY) // myRdds2 should have 2 RDDs, but myRdds should not change val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache() val myRdds2 = sc.getPersistentRDDs assert(myRdds2.size === 2) assert(myRdds2(0) === rdd1) assert(myRdds2(1) === rdd2) assert(myRdds2(0).getStorageLevel === StorageLevel.MEMORY_ONLY) assert(myRdds2(1).getStorageLevel === StorageLevel.MEMORY_ONLY) assert(myRdds.size === 1) assert(myRdds(0) === rdd1) assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY) } test("getRDDStorageInfo only reports on RDDs that actually persist data") { sc = new SparkContext("local", "test") val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() assert(sc.getRDDStorageInfo.size === 0) rdd.collect() assert(sc.getRDDStorageInfo.size === 1) assert(sc.getRDDStorageInfo.head.isCached) assert(sc.getRDDStorageInfo.head.memSize > 0) assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY) } test("call sites report correct locations") { sc = new SparkContext("local", "test") testPackage.runCallSiteTest(sc) } } package object testPackage extends Assertions { private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r def runCallSiteTest(sc: SparkContext) { val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) val rddCreationSite = rdd.getCreationSite val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd" val rddCreationLine = rddCreationSite match { case CALL_SITE_REGEX(func, file, line) => assert(func === "makeRDD") assert(file === "SparkContextInfoSuite.scala") line.toInt case _ => fail("Did not match expected call site format") } curCallSite match { case CALL_SITE_REGEX(func, file, line) => assert(func === "getCallSite") // this is correct because we called it from outside of Spark assert(file === "SparkContextInfoSuite.scala") assert(line.toInt === rddCreationLine.toInt + 2) case _ => fail("Did not match expected call site format") } } }
Example 11
Source File: OpCountVectorizerTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages.impl.feature import com.salesforce.op._ import com.salesforce.op.features.types._ import com.salesforce.op.test.TestOpVectorColumnType.{IndCol, IndVal} import com.salesforce.op.test.{TestFeatureBuilder, TestOpVectorMetadataBuilder, TestSparkContext} import com.salesforce.op.utils.spark.OpVectorMetadata import com.salesforce.op.utils.spark.RichDataset._ import org.apache.spark.ml.linalg.Vectors import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Assertions, FlatSpec, Matchers} @RunWith(classOf[JUnitRunner]) class OpCountVectorizerTest extends FlatSpec with TestSparkContext { val data = Seq[(Real, TextList)]( (Real(0), Seq("a", "b", "c").toTextList), (Real(1), Seq("a", "b", "b", "b", "a", "c").toTextList) ) lazy val (ds, f1, f2) = TestFeatureBuilder(data) lazy val expected = Array[(Real, OPVector)]( (Real(0), Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 1.0, 1.0)).toOPVector), (Real(1), Vectors.sparse(3, Array(0, 1, 2), Array(3.0, 2.0, 1.0)).toOPVector) ) val f2vec = new OpCountVectorizer().setInput(f2).setVocabSize(3).setMinDF(2) Spec[OpCountVectorizerTest] should "convert array of strings into count vector" in { val transformedData = f2vec.fit(ds).transform(ds) val output = f2vec.getOutput() transformedData.orderBy(f1.name).collect(f1, output) should contain theSameElementsInOrderAs expected } it should "return the a fitted vectorizer with the correct parameters" in { val fitted = f2vec.fit(ds) val vectorMetadata = fitted.getMetadata() val expectedMeta = TestOpVectorMetadataBuilder( f2vec, f2 -> List(IndVal(Some("b")), IndVal(Some("a")), IndVal(Some("c"))) ) // cannot just do equals because fitting is nondeterministic OpVectorMetadata(f2vec.getOutputFeatureName, vectorMetadata).columns should contain theSameElementsAs expectedMeta.columns } it should "convert array of strings into count vector (shortcut version)" in { val output = f2.countVec(minDF = 2, vocabSize = 3) val f2vec = output.originStage.asInstanceOf[OpCountVectorizer] val transformedData = f2vec.fit(ds).transform(ds) transformedData.orderBy(f1.name).collect(f1, output) should contain theSameElementsInOrderAs expected } }
Example 12
Source File: TimerImplTest.scala From datadog4s with MIT License
package com.avast.datadog4s.statsd import java.util.concurrent.TimeUnit import cats.effect.{ Clock, IO } import com.avast.datadog4s.api.Tag import com.avast.datadog4s.statsd.metric.TimerImpl import com.timgroup.statsd.{ StatsDClient => JStatsDClient } import org.mockito.scalatest.MockitoSugar import org.scalatest.{ Assertions, BeforeAndAfter } import org.scalatest.flatspec.AnyFlatSpec class TimerImplTest extends AnyFlatSpec with MockitoSugar with BeforeAndAfter with Assertions { trait Fixtures { val aspect: String = "metric" val sampleRate = 1.0 val statsD: JStatsDClient = mock[JStatsDClient] val clock: Clock[IO] = mock[Clock[IO]] val timer = new TimerImpl[IO](clock, statsD, aspect, sampleRate, Vector.empty) when(clock.monotonic(TimeUnit.NANOSECONDS)).thenReturn(IO.pure(10 * 1000 * 1000), IO.pure(30 * 1000 * 1000)) } "time F[A]" should "report success with label success:true" in new Fixtures { private val res = timer.time(IO.delay("hello world")).unsafeRunSync() verify(statsD, times(1)).recordExecutionTime(aspect, 20, sampleRate, Tag.of("success", "true")) assertResult(res)("hello world") } it should "report failure with label failure:true and exception name" in new Fixtures { private val res = timer.time(IO.raiseError(new NoSuchElementException("fail"))) assertThrows[NoSuchElementException](res.unsafeRunSync()) verify(statsD, times(1)) .recordExecutionTime( aspect, 20, sampleRate, Tag.of("exception", "java.util.NoSuchElementException"), Tag.of("success", "false") ) } }
Example 13
Source File: VideoDisplay.scala From jvm-toxcore-c with GNU General Public License v3.0
package im.tox.tox4j.av.callbacks.video import java.io.Closeable import im.tox.tox4j.av.data.{ Height, Width } import im.tox.tox4j.testing.autotest.AutoTestSuite.timed import org.scalatest.Assertions import scala.util.Try abstract class VideoDisplay[Parsed, Canvas] extends Assertions with Closeable { def width: Width def height: Height protected def canvas: Try[Canvas] protected def parse( y: Array[Byte], u: Array[Byte], v: Array[Byte], yStride: Int, uStride: Int, vStride: Int ): Parsed protected def displaySent(canvas: Canvas, frameNumber: Int, parsed: Parsed): Unit protected def displayReceived(canvas: Canvas, frameNumber: Int, parsed: Parsed): Unit final def displaySent(frameNumber: Int, y: Array[Byte], u: Array[Byte], v: Array[Byte]): Unit = { val width = this.width.value canvas.foreach(displaySent(_, frameNumber, parse(y, u, v, width, width / 2, width / 2))) } final def displayReceived( frameNumber: Int, y: Array[Byte], u: Array[Byte], v: Array[Byte], yStride: Int, uStride: Int, vStride: Int ): Option[(Int, Int)] = { canvas.toOption.map { canvas => val (parseTime, parsed) = timed(parse(y, u, v, yStride, uStride, vStride)) val displayTime = timed(displayReceived(canvas, frameNumber, parsed)) (parseTime, displayTime) } } }
Example 14
Source File: ToxCoreTestBase.scala From jvm-toxcore-c with GNU General Public License v3.0
package im.tox.tox4j import java.io.IOException import java.net.{ InetAddress, Socket } import java.util.Random import org.jetbrains.annotations.NotNull import org.scalatest.Assertions object ToxCoreTestBase extends Assertions { private[tox4j] val nodeCandidates = Seq( new DhtNode("tox.initramfs.io", "tox.initramfs.io", 33445, "3F0A45A268367C1BEA652F258C85F4A66DA76BCAA667A49E770BCC4917AB6A25"), new DhtNode("tox.verdict.gg", null, 33445, "1C5293AEF2114717547B39DA8EA6F1E331E5E358B35F9B6B5F19317911C5F976") ) @NotNull def randomBytes(length: Int): Array[Byte] = { val array = new Array[Byte](length) new Random().nextBytes(array) array } @NotNull def readablePublicKey(@NotNull id: Array[Byte]): String = { val str = new StringBuilder id foreach { c => str.append(f"$c%02X") } str.toString() } @NotNull def parsePublicKey(@NotNull id: String): Array[Byte] = { val publicKey = new Array[Byte](id.length / 2) publicKey.indices foreach { i => publicKey(i) = ((fromHexDigit(id.charAt(i * 2)) << 4) + fromHexDigit(id.charAt(i * 2 + 1))).toByte } publicKey } private def fromHexDigit(c: Char): Byte = { val digit = if (false) { 0 } else if ('0' to '9' contains c) { c - '0' } else if ('A' to 'F' contains c) { c - 'A' + 10 } else if ('a' to 'f' contains c) { c - 'a' + 10 } else { throw new IllegalArgumentException(s"Non-hex digit character: $c") } digit.toByte } @SuppressWarnings(Array("org.wartremover.warts.Equals")) private def hasConnection(ip: String, port: Int): Option[String] = { var socket: Socket = null try { socket = new Socket(InetAddress.getByName(ip), port) if (socket.getInputStream == null) { Some("Socket input stream is null") } else { None } } catch { case e: IOException => Some(s"A network connection can't be established to $ip:$port: ${e.getMessage}") } finally { if (socket != null) { socket.close() } } } def checkIPv4: Option[String] = { hasConnection("8.8.8.8", 53) } def checkIPv6: Option[String] = { hasConnection("2001:4860:4860::8888", 53) } protected[tox4j] def assumeIPv4(): Unit = { assume(checkIPv4.isEmpty) } protected[tox4j] def assumeIPv6(): Unit = { assume(checkIPv6.isEmpty) } }
Example 15
Source File: VideoGenerator.scala From jvm-toxcore-c with GNU General Public License v3.0
package im.tox.tox4j.av.callbacks.video import im.tox.tox4j.av.data.{ Height, Width } import org.scalatest.Assertions abstract class VideoGenerator { def width: Width def height: Height def length: Int def yuv(t: Int): (Array[Byte], Array[Byte], Array[Byte]) def resize(width: Width, height: Height): VideoGenerator final def size: Int = width.value * height.value protected final def w: Int = width.value protected final def h: Int = height.value } object VideoGenerator extends Assertions { private def resizeNearestNeighbour( pixels: Array[Byte], oldWidth: Int, oldHeight: Int, newWidth: Int, newHeight: Int ): Array[Byte] = { val result = Array.ofDim[Byte](newWidth * newHeight) val xRatio = oldWidth / newWidth.toDouble val yRatio = oldHeight / newHeight.toDouble for { yPos <- 0 until newHeight xPos <- 0 until newWidth } { val px = Math.floor(xPos * xRatio) val py = Math.floor(yPos * yRatio) result((yPos * newWidth) + xPos) = pixels(((py * oldWidth) + px).toInt) } result } @SuppressWarnings(Array("org.wartremover.warts.Equals")) def resizeNearestNeighbour(newWidth: Width, newHeight: Height, gen: VideoGenerator): VideoGenerator = { if (newWidth == gen.width && newHeight == gen.height) { gen } else { new VideoGenerator { override def toString: String = s"resizeNearestNeighbour($width, $height, $gen)" override def resize(width: Width, height: Height): VideoGenerator = gen.resize(width, height) override def yuv(t: Int): (Array[Byte], Array[Byte], Array[Byte]) = { val yuv = gen.yuv(t) ( resizeNearestNeighbour(yuv._1, gen.width.value, gen.height.value, width.value, height.value), resizeNearestNeighbour(yuv._2, gen.width.value / 2, gen.height.value / 2, width.value / 2, height.value / 2), resizeNearestNeighbour(yuv._3, gen.width.value / 2, gen.height.value / 2, width.value / 2, height.value / 2) ) } override def width: Width = newWidth override def height: Height = newHeight override def length: Int = gen.length } } } }
Example 16
Source File: DhtNodeSelector.scala From jvm-toxcore-c with GNU General Public License v3.0
package im.tox.tox4j import java.io.IOException import java.net.{ InetAddress, Socket } import com.typesafe.scalalogging.Logger import im.tox.tox4j.core.ToxCore import im.tox.tox4j.impl.jni.ToxCoreImplFactory import org.scalatest.Assertions import org.slf4j.LoggerFactory object DhtNodeSelector extends Assertions { private val logger = Logger(LoggerFactory.getLogger(this.getClass)) private var selectedNode: Option[DhtNode] = Some(ToxCoreTestBase.nodeCandidates(0)) @SuppressWarnings(Array("org.wartremover.warts.Equals")) private def tryConnect(node: DhtNode): Option[DhtNode] = { var socket: Socket = null try { socket = new Socket(InetAddress.getByName(node.ipv4), node.udpPort.value) assume(socket.getInputStream != null) Some(node) } catch { case e: IOException => logger.info(s"TCP connection failed (${e.getMessage})") None } finally { if (socket != null) { socket.close() } } } private def tryBootstrap( withTox: (Boolean, Boolean) => (ToxCore => Option[DhtNode]) => Option[DhtNode], node: DhtNode, udpEnabled: Boolean ): Option[DhtNode] = { val protocol = if (udpEnabled) "UDP" else "TCP" val port = if (udpEnabled) node.udpPort else node.tcpPort logger.info(s"Trying to bootstrap with ${node.ipv4}:$port using $protocol") withTox(true, udpEnabled) { tox => val status = new ConnectedListener if (!udpEnabled) { tox.addTcpRelay(node.ipv4, port, node.dhtId) } tox.bootstrap(node.ipv4, port, node.dhtId) // Try bootstrapping for 10 seconds. (0 to 10000 / tox.iterationInterval) find { _ => tox.iterate(status)(()) Thread.sleep(tox.iterationInterval) status.isConnected } match { case Some(time) => logger.info(s"Bootstrapped successfully after ${time * tox.iterationInterval}ms using $protocol") Some(node) case None => logger.info(s"Unable to bootstrap with $protocol") None } } } private def findNode(withTox: (Boolean, Boolean) => (ToxCore => Option[DhtNode]) => Option[DhtNode]): DhtNode = { DhtNodeSelector.selectedNode match { case Some(node) => node case None => logger.info("Looking for a working bootstrap node") DhtNodeSelector.selectedNode = ToxCoreTestBase.nodeCandidates find { node => logger.info(s"Trying to establish a TCP connection to ${node.ipv4}") (for { node <- tryConnect(node) node <- tryBootstrap(withTox, node, udpEnabled = true) node <- tryBootstrap(withTox, node, udpEnabled = false) } yield node).isDefined } assume(DhtNodeSelector.selectedNode.nonEmpty, "No viable nodes for bootstrap found; cannot test") DhtNodeSelector.selectedNode.get } } def node: DhtNode = findNode(ToxCoreImplFactory.withToxUnit[Option[DhtNode]]) }
Example 17
Source File: CheckedOrdering.scala From jvm-toxcore-c with GNU General Public License v3.0
package im.tox.tox4j.testing

import org.scalatest.Assertions

object CheckedOrdering extends Assertions {
  @SuppressWarnings(Array("org.wartremover.warts.Equals"))
  def apply[A](ord: Ordering[A]): Ordering[A] = {
    new Ordering[A] {
      override def compare(x: A, y: A): Int = {
        val result = ord.compare(x, y)
        if (result == 0) {
          assert(x == y)
        }
        result
      }
    }
  }
}
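A quick usage sketch (illustrative, using a standard-library ordering only): the wrapper delegates to the underlying Ordering and additionally asserts that a zero comparison implies equality.

// wrap an existing ordering; sorting behaves the same, but compare(x, y) == 0 now asserts x == y
val checkedInt: Ordering[Int] = CheckedOrdering(Ordering.Int)
assert(List(3, 1, 2).sorted(checkedInt) == List(1, 2, 3))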
Example 18
Source File: CheckedOrderingEq.scala From jvm-toxcore-c with GNU General Public License v3.0
package im.tox.tox4j.testing

import org.scalatest.Assertions

object CheckedOrderingEq extends Assertions {
  @SuppressWarnings(Array("org.wartremover.warts.Equals"))
  def apply[A <: AnyRef](ord: Ordering[A]): Ordering[A] = {
    new Ordering[A] {
      override def compare(x: A, y: A): Int = {
        val result = ord.compare(x, y)
        if (result == 0) {
          assert(x eq y)
        }
        result
      }
    }
  }
}
Example 19
Source File: GetDisjunction.scala From jvm-toxcore-c with GNU General Public License v3.0
package im.tox.tox4j.testing

import im.tox.core.error.CoreError
import im.tox.core.typesafe.{ -\/, \/, \/- }
import org.scalatest.Assertions

import scala.language.implicitConversions

final case class GetDisjunction[T] private (disjunction: CoreError \/ T) extends Assertions {
  def get: T = {
    disjunction match {
      case -\/(error)   => fail(error.toString)
      case \/-(success) => success
    }
  }
}

object GetDisjunction {
  @SuppressWarnings(Array("org.wartremover.warts.ImplicitConversion"))
  implicit def toGetDisjunction[T](disjunction: CoreError \/ T): GetDisjunction[T] = GetDisjunction(disjunction)
}
Example 20
Source File: ToxExceptionChecks.scala From jvm-toxcore-c with GNU General Public License v3.0
package im.tox.tox4j.testing

import im.tox.tox4j.exceptions.ToxException
import org.scalatest.Assertions

trait ToxExceptionChecks extends Assertions {
  @SuppressWarnings(Array("org.wartremover.warts.Equals"))
  protected def intercept[E <: Enum[E]](code: E)(f: => Unit) = {
    try {
      f
      fail(s"Expected exception with code ${code.name}")
    } catch {
      case e: ToxException[_] => assert(e.code eq code)
    }
  }
}
Example 21
Source File: BagelSuite.scala From iolap with Apache License 2.0
package org.apache.spark.bagel import org.scalatest.{BeforeAndAfter, Assertions} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.storage.StorageLevel class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable class TestMessage(val targetId: String) extends Message[String] with Serializable class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts { var sc: SparkContext = _ after { if (sc != null) { sc.stop() sc = null } } test("halting by voting") { sc = new SparkContext("local", "test") val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0)))) val msgs = sc.parallelize(Array[(String, TestMessage)]()) val numSupersteps = 5 val result = Bagel.run(sc, verts, msgs, sc.defaultParallelism) { (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) } for ((id, vert) <- result.collect) { assert(vert.age === numSupersteps) } } test("halting by message silence") { sc = new SparkContext("local", "test") val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0)))) val msgs = sc.parallelize(Array("a" -> new TestMessage("a"))) val numSupersteps = 5 val result = Bagel.run(sc, verts, msgs, sc.defaultParallelism) { (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => val msgsOut = msgs match { case Some(ms) if (superstep < numSupersteps - 1) => ms case _ => Array[TestMessage]() } (new TestVertex(self.active, self.age + 1), msgsOut) } for ((id, vert) <- result.collect) { assert(vert.age === numSupersteps) } } test("large number of iterations") { // This tests whether jobs with a large number of iterations finish in a reasonable time, // because non-memoized recursion in RDD or DAGScheduler used to cause them to hang failAfter(30 seconds) { sc = new SparkContext("local", "test") val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) val msgs = sc.parallelize(Array[(String, TestMessage)]()) val numSupersteps = 50 val result = Bagel.run(sc, verts, msgs, sc.defaultParallelism) { (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) } for ((id, vert) <- result.collect) { assert(vert.age === numSupersteps) } } } test("using non-default persistence level") { failAfter(10 seconds) { sc = new SparkContext("local", "test") val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) val msgs = sc.parallelize(Array[(String, TestMessage)]()) val numSupersteps = 20 val result = Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) { (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) } for ((id, vert) <- result.collect) { assert(vert.age === numSupersteps) } } } }
Example 22
Source File: SerializerPropertiesSuite.scala From drizzle-spark with Apache License 2.0
package org.apache.spark.serializer import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.util.Random import org.scalatest.Assertions import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.KryoTest.RegistratorWithoutAutoReset class SerializerPropertiesSuite extends SparkFunSuite { import SerializerPropertiesSuite._ test("JavaSerializer does not support relocation") { // Per a comment on the SPARK-4550 JIRA ticket, Java serialization appears to write out the // full class name the first time an object is written to an output stream, but subsequent // references to the class write a more compact identifier; this prevents relocation. val ser = new JavaSerializer(new SparkConf()) testSupportsRelocationOfSerializedObjects(ser, generateRandomItem) } test("KryoSerializer supports relocation when auto-reset is enabled") { val ser = new KryoSerializer(new SparkConf) assert(ser.newInstance().asInstanceOf[KryoSerializerInstance].getAutoReset()) testSupportsRelocationOfSerializedObjects(ser, generateRandomItem) } test("KryoSerializer does not support relocation when auto-reset is disabled") { val conf = new SparkConf().set("spark.kryo.registrator", classOf[RegistratorWithoutAutoReset].getName) val ser = new KryoSerializer(conf) assert(!ser.newInstance().asInstanceOf[KryoSerializerInstance].getAutoReset()) testSupportsRelocationOfSerializedObjects(ser, generateRandomItem) } } object SerializerPropertiesSuite extends Assertions { def generateRandomItem(rand: Random): Any = { val randomFunctions: Seq[() => Any] = Seq( () => rand.nextInt(), () => rand.nextString(rand.nextInt(10)), () => rand.nextDouble(), () => rand.nextBoolean(), () => (rand.nextInt(), rand.nextString(rand.nextInt(10))), () => MyCaseClass(rand.nextInt(), rand.nextString(rand.nextInt(10))), () => { val x = MyCaseClass(rand.nextInt(), rand.nextString(rand.nextInt(10))) (x, x) } ) randomFunctions(rand.nextInt(randomFunctions.size)).apply() } def testSupportsRelocationOfSerializedObjects( serializer: Serializer, generateRandomItem: Random => Any): Unit = { if (!serializer.supportsRelocationOfSerializedObjects) { return } val NUM_TRIALS = 5 val rand = new Random(42) for (_ <- 1 to NUM_TRIALS) { val items = { // Make sure that we have duplicate occurrences of the same object in the stream: val randomItems = Seq.fill(10)(generateRandomItem(rand)) randomItems ++ randomItems.take(5) } val baos = new ByteArrayOutputStream() val serStream = serializer.newInstance().serializeStream(baos) def serializeItem(item: Any): Array[Byte] = { val itemStartOffset = baos.toByteArray.length serStream.writeObject(item) serStream.flush() val itemEndOffset = baos.toByteArray.length baos.toByteArray.slice(itemStartOffset, itemEndOffset).clone() } val itemsAndSerializedItems: Seq[(Any, Array[Byte])] = { val serItems = items.map { item => (item, serializeItem(item)) } serStream.close() rand.shuffle(serItems) } val reorderedSerializedData: Array[Byte] = itemsAndSerializedItems.flatMap(_._2).toArray val deserializedItemsStream = serializer.newInstance().deserializeStream( new ByteArrayInputStream(reorderedSerializedData)) assert(deserializedItemsStream.asIterator.toSeq === itemsAndSerializedItems.map(_._1)) deserializedItemsStream.close() } } } private case class MyCaseClass(foo: Int, bar: String)
Example 23
Source File: BufferHolderSparkSubmitSuite.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions.codegen import org.scalatest.{Assertions, BeforeAndAfterEach, Matchers} import org.apache.spark.{SparkFunSuite, TestUtils} import org.apache.spark.deploy.SparkSubmitSuite import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.ResetSystemProperties // A test for growing the buffer holder to nearly 2GB. Due to the heap size limitation of the Spark // unit tests JVM, the actually test code is running as a submit job. class BufferHolderSparkSubmitSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach with ResetSystemProperties { test("SPARK-22222: Buffer holder should be able to allocate memory larger than 1GB") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val argsForSparkSubmit = Seq( "--class", BufferHolderSparkSubmitSuite.getClass.getName.stripSuffix("$"), "--name", "SPARK-22222", "--master", "local-cluster[1,1,4096]", "--driver-memory", "4g", "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", "--conf", "spark.driver.extraJavaOptions=-ea", unusedJar.toString) SparkSubmitSuite.runSparkSubmit(argsForSparkSubmit, "../..") } } object BufferHolderSparkSubmitSuite extends Assertions { def main(args: Array[String]): Unit = { val ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH val unsafeRow = new UnsafeRow(1000) val holder = new BufferHolder(unsafeRow) holder.reset() assert(intercept[IllegalArgumentException] { holder.grow(-1) }.getMessage.contains("because the size is negative")) // while to reuse a buffer may happen, this test checks whether the buffer can be grown holder.grow(ARRAY_MAX / 2) assert(unsafeRow.getSizeInBytes % 8 == 0) holder.grow(ARRAY_MAX / 2 + 7) assert(unsafeRow.getSizeInBytes % 8 == 0) holder.grow(Integer.MAX_VALUE / 2) assert(unsafeRow.getSizeInBytes % 8 == 0) holder.grow(ARRAY_MAX - holder.totalSize()) assert(unsafeRow.getSizeInBytes % 8 == 0) assert(intercept[IllegalArgumentException] { holder.grow(ARRAY_MAX + 1 - holder.totalSize()) }.getMessage.contains("because the size after growing")) } }
Example 24
Source File: WholeStageCodegenSparkSubmitSuite.scala From XSQL with Apache License 2.0
package org.apache.spark.sql.execution import org.scalatest.{Assertions, BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.TimeLimits import org.apache.spark.{SparkFunSuite, TestUtils} import org.apache.spark.deploy.SparkSubmitSuite import org.apache.spark.internal.Logging import org.apache.spark.sql.{LocalSparkSession, QueryTest, Row, SparkSession} import org.apache.spark.sql.functions.{array, col, count, lit} import org.apache.spark.sql.types.IntegerType import org.apache.spark.unsafe.Platform import org.apache.spark.util.ResetSystemProperties // Due to the need to set driver's extraJavaOptions, this test needs to use actual SparkSubmit. class WholeStageCodegenSparkSubmitSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach with ResetSystemProperties { test("Generated code on driver should not embed platform-specific constant") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) // HotSpot JVM specific: Set up a local cluster with the driver/executor using mismatched // settings of UseCompressedOops JVM option. val argsForSparkSubmit = Seq( "--class", WholeStageCodegenSparkSubmitSuite.getClass.getName.stripSuffix("$"), "--master", "local-cluster[1,1,1024]", "--driver-memory", "1g", "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", "--conf", "spark.driver.extraJavaOptions=-XX:-UseCompressedOops", "--conf", "spark.executor.extraJavaOptions=-XX:+UseCompressedOops", unusedJar.toString) SparkSubmitSuite.runSparkSubmit(argsForSparkSubmit, "../..") } } object WholeStageCodegenSparkSubmitSuite extends Assertions with Logging { var spark: SparkSession = _ def main(args: Array[String]): Unit = { TestUtils.configTestLog4j("INFO") spark = SparkSession.builder().getOrCreate() // Make sure the test is run where the driver and the executors uses different object layouts val driverArrayHeaderSize = Platform.BYTE_ARRAY_OFFSET val executorArrayHeaderSize = spark.sparkContext.range(0, 1).map(_ => Platform.BYTE_ARRAY_OFFSET).collect.head.toInt assert(driverArrayHeaderSize > executorArrayHeaderSize) val df = spark.range(71773).select((col("id") % lit(10)).cast(IntegerType) as "v") .groupBy(array(col("v"))).agg(count(col("*"))) val plan = df.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined) val expectedAnswer = Row(Array(0), 7178) :: Row(Array(1), 7178) :: Row(Array(2), 7178) :: Row(Array(3), 7177) :: Row(Array(4), 7177) :: Row(Array(5), 7177) :: Row(Array(6), 7177) :: Row(Array(7), 7177) :: Row(Array(8), 7177) :: Row(Array(9), 7177) :: Nil val result = df.collect QueryTest.sameRows(result.toSeq, expectedAnswer) match { case Some(errMsg) => fail(errMsg) case _ => } } }
Example 25
Source File: SparkContextInfoSuite.scala From sparkoscope with Apache License 2.0
package org.apache.spark import org.scalatest.Assertions import org.apache.spark.storage.StorageLevel class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext { test("getPersistentRDDs only returns RDDs that are marked as cached") { sc = new SparkContext("local", "test") assert(sc.getPersistentRDDs.isEmpty === true) val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(sc.getPersistentRDDs.isEmpty === true) rdd.cache() assert(sc.getPersistentRDDs.size === 1) assert(sc.getPersistentRDDs.values.head === rdd) } test("getPersistentRDDs returns an immutable map") { sc = new SparkContext("local", "test") val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() val myRdds = sc.getPersistentRDDs assert(myRdds.size === 1) assert(myRdds(0) === rdd1) assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY) // myRdds2 should have 2 RDDs, but myRdds should not change val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache() val myRdds2 = sc.getPersistentRDDs assert(myRdds2.size === 2) assert(myRdds2(0) === rdd1) assert(myRdds2(1) === rdd2) assert(myRdds2(0).getStorageLevel === StorageLevel.MEMORY_ONLY) assert(myRdds2(1).getStorageLevel === StorageLevel.MEMORY_ONLY) assert(myRdds.size === 1) assert(myRdds(0) === rdd1) assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY) } test("getRDDStorageInfo only reports on RDDs that actually persist data") { sc = new SparkContext("local", "test") val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() assert(sc.getRDDStorageInfo.size === 0) rdd.collect() assert(sc.getRDDStorageInfo.size === 1) assert(sc.getRDDStorageInfo.head.isCached) assert(sc.getRDDStorageInfo.head.memSize > 0) assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY) } test("call sites report correct locations") { sc = new SparkContext("local", "test") testPackage.runCallSiteTest(sc) } } package object testPackage extends Assertions { private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r def runCallSiteTest(sc: SparkContext) { val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) val rddCreationSite = rdd.getCreationSite val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd" val rddCreationLine = rddCreationSite match { case CALL_SITE_REGEX(func, file, line) => assert(func === "makeRDD") assert(file === "SparkContextInfoSuite.scala") line.toInt case _ => fail("Did not match expected call site format") } curCallSite match { case CALL_SITE_REGEX(func, file, line) => assert(func === "getCallSite") // this is correct because we called it from outside of Spark assert(file === "SparkContextInfoSuite.scala") assert(line.toInt === rddCreationLine.toInt + 2) case _ => fail("Did not match expected call site format") } } }
Example 26
Source File: SparkContextInfoSuite.scala From SparkCore with Apache License 2.0
package org.apache.spark import org.scalatest.{Assertions, FunSuite} import org.apache.spark.storage.StorageLevel class SparkContextInfoSuite extends FunSuite with LocalSparkContext { test("getPersistentRDDs only returns RDDs that are marked as cached") { sc = new SparkContext("local", "test") assert(sc.getPersistentRDDs.isEmpty === true) val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(sc.getPersistentRDDs.isEmpty === true) rdd.cache() assert(sc.getPersistentRDDs.size === 1) assert(sc.getPersistentRDDs.values.head === rdd) } test("getPersistentRDDs returns an immutable map") { sc = new SparkContext("local", "test") val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() val myRdds = sc.getPersistentRDDs assert(myRdds.size === 1) assert(myRdds(0) === rdd1) assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY) // myRdds2 should have 2 RDDs, but myRdds should not change val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache() val myRdds2 = sc.getPersistentRDDs assert(myRdds2.size === 2) assert(myRdds2(0) === rdd1) assert(myRdds2(1) === rdd2) assert(myRdds2(0).getStorageLevel === StorageLevel.MEMORY_ONLY) assert(myRdds2(1).getStorageLevel === StorageLevel.MEMORY_ONLY) assert(myRdds.size === 1) assert(myRdds(0) === rdd1) assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY) } test("getRDDStorageInfo only reports on RDDs that actually persist data") { sc = new SparkContext("local", "test") val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() assert(sc.getRDDStorageInfo.size === 0) rdd.collect() assert(sc.getRDDStorageInfo.size === 1) assert(sc.getRDDStorageInfo.head.isCached) assert(sc.getRDDStorageInfo.head.memSize > 0) assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY) } test("call sites report correct locations") { sc = new SparkContext("local", "test") testPackage.runCallSiteTest(sc) } } package object testPackage extends Assertions { private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r def runCallSiteTest(sc: SparkContext) { val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) val rddCreationSite = rdd.getCreationSite val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd" val rddCreationLine = rddCreationSite match { case CALL_SITE_REGEX(func, file, line) => { assert(func === "makeRDD") assert(file === "SparkContextInfoSuite.scala") line.toInt } case _ => fail("Did not match expected call site format") } curCallSite match { case CALL_SITE_REGEX(func, file, line) => { assert(func === "getCallSite") // this is correct because we called it from outside of Spark assert(file === "SparkContextInfoSuite.scala") assert(line.toInt === rddCreationLine.toInt + 2) } case _ => fail("Did not match expected call site format") } } }
Example 27
Source File: AnalyzerTest.scala From scala-commons with MIT License
package com.avsystem.commons
package analyzer

import org.scalactic.source.Position
import org.scalatest.Assertions

import scala.reflect.internal.util.BatchSourceFile
import scala.tools.nsc.plugins.Plugin
import scala.tools.nsc.{Global, Settings}

trait AnalyzerTest { this: Assertions =>
  val settings = new Settings
  settings.usejavacp.value = true
  settings.pluginOptions.value ++= List("AVSystemAnalyzer:+_")

  val compiler: Global = new Global(settings) { global =>
    override protected def loadRoughPluginsList(): List[Plugin] =
      new AnalyzerPlugin(global) :: super.loadRoughPluginsList()
  }

  def compile(source: String): Unit = {
    compiler.reporter.reset()
    val run = new compiler.Run
    run.compileSources(List(new BatchSourceFile("test.scala", source)))
  }

  def assertErrors(source: String)(implicit pos: Position): Unit = {
    compile(source)
    assert(compiler.reporter.hasErrors)
  }

  def assertErrors(errors: Int, source: String)(implicit pos: Position): Unit = {
    compile(source)
    assert(compiler.reporter.errorCount == errors)
  }

  def assertNoErrors(source: String)(implicit pos: Position): Unit = {
    compile(source)
    assert(!compiler.reporter.hasErrors)
  }
}
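A hedged sketch of how such a trait is typically used from a concrete suite; the suite name, the snippets under test, and the assumption that exactly one analyzer rule fires on the first snippet are illustrative, not taken from the original project:

import org.scalatest.FunSuite

class ExampleAnalyzerRuleTest extends FunSuite with AnalyzerTest {
  test("offending code is reported") {
    // which rule fires on a given snippet depends on the analyzer's configuration
    assertErrors(1, "object whatever { import java.util }")
  }

  test("clean code compiles without analyzer errors") {
    assertNoErrors("object whatever { def x: Int = 42 }")
  }
}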
Example 28
Source File: ResultAssertions.scala From wartremover-contrib with Apache License 2.0
package org.wartremover
package contrib.test

import org.scalatest.Assertions
import org.wartremover.test.WartTestTraverser

trait ResultAssertions extends Assertions {

  def assertEmpty(result: WartTestTraverser.Result) = {
    assertResult(List.empty, "result.errors")(result.errors)
    assertResult(List.empty, "result.warnings")(result.warnings)
  }

  def assertError(result: WartTestTraverser.Result)(message: String) = assertErrors(result)(message, 1)

  def assertErrors(result: WartTestTraverser.Result)(message: String, times: Int) = {
    assertResult(List.fill(times)(message), "result.errors")(result.errors.map(skipTraverserPrefix))
    assertResult(List.empty, "result.warnings")(result.warnings.map(skipTraverserPrefix))
  }

  def assertWarnings(result: WartTestTraverser.Result)(message: String, times: Int) = {
    assertResult(List.empty, "result.errors")(result.errors.map(skipTraverserPrefix))
    assertResult(List.fill(times)(message), "result.warnings")(result.warnings.map(skipTraverserPrefix))
  }

  private val messageFormat = """\[wartremover:\S+\] ([\s\S]+)""".r

  private def skipTraverserPrefix(msg: String) = msg match {
    case messageFormat(rest) => rest
    case s                   => s
  }
}
Example 29
Source File: ResponseAccessors.scala From ScalaWebTest with Apache License 2.0
package org.scalawebtest.core

import org.scalatest.Assertions

// Excerpt: the listing omits the enclosing trait that mixes in Assertions and
// defines `responseHeaders`; the trailing brace below closes that elided trait.
  def responseHeaderValue(name: String): String = responseHeaders.get(name) match {
    case Some(v) => v
    case None =>
      val headerNames = responseHeaders.keys.mkString("'", "', '", "'")
      fail(
        s"""The current web response for did not contain the expected response header with field-name: '$name'
            It contained the following header field-names: $headerNames""")
  }
}
Example 30
Source File: SparkContextInfoSuite.scala From multi-tenancy-spark with Apache License 2.0
package org.apache.spark import org.scalatest.Assertions import org.apache.spark.storage.StorageLevel class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext { test("getPersistentRDDs only returns RDDs that are marked as cached") { sc = new SparkContext("local", "test") assert(sc.getPersistentRDDs.isEmpty === true) val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(sc.getPersistentRDDs.isEmpty === true) rdd.cache() assert(sc.getPersistentRDDs.size === 1) assert(sc.getPersistentRDDs.values.head === rdd) } test("getPersistentRDDs returns an immutable map") { sc = new SparkContext("local", "test") val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() val myRdds = sc.getPersistentRDDs assert(myRdds.size === 1) assert(myRdds(0) === rdd1) assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY) // myRdds2 should have 2 RDDs, but myRdds should not change val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache() val myRdds2 = sc.getPersistentRDDs assert(myRdds2.size === 2) assert(myRdds2(0) === rdd1) assert(myRdds2(1) === rdd2) assert(myRdds2(0).getStorageLevel === StorageLevel.MEMORY_ONLY) assert(myRdds2(1).getStorageLevel === StorageLevel.MEMORY_ONLY) assert(myRdds.size === 1) assert(myRdds(0) === rdd1) assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY) } test("getRDDStorageInfo only reports on RDDs that actually persist data") { sc = new SparkContext("local", "test") val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() assert(sc.getRDDStorageInfo.size === 0) rdd.collect() assert(sc.getRDDStorageInfo.size === 1) assert(sc.getRDDStorageInfo.head.isCached) assert(sc.getRDDStorageInfo.head.memSize > 0) assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY) } test("call sites report correct locations") { sc = new SparkContext("local", "test") testPackage.runCallSiteTest(sc) } } package object testPackage extends Assertions { private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r def runCallSiteTest(sc: SparkContext) { val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) val rddCreationSite = rdd.getCreationSite val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd" val rddCreationLine = rddCreationSite match { case CALL_SITE_REGEX(func, file, line) => assert(func === "makeRDD") assert(file === "SparkContextInfoSuite.scala") line.toInt case _ => fail("Did not match expected call site format") } curCallSite match { case CALL_SITE_REGEX(func, file, line) => assert(func === "getCallSite") // this is correct because we called it from outside of Spark assert(file === "SparkContextInfoSuite.scala") assert(line.toInt === rddCreationLine.toInt + 2) case _ => fail("Did not match expected call site format") } } }
Example 31
Source File: SparkContextInfoSuite.scala From drizzle-spark with Apache License 2.0 | 5 votes |
package org.apache.spark

import org.scalatest.Assertions

import org.apache.spark.storage.StorageLevel

class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext {
  test("getPersistentRDDs only returns RDDs that are marked as cached") {
    sc = new SparkContext("local", "test")
    assert(sc.getPersistentRDDs.isEmpty === true)

    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    assert(sc.getPersistentRDDs.isEmpty === true)

    rdd.cache()
    assert(sc.getPersistentRDDs.size === 1)
    assert(sc.getPersistentRDDs.values.head === rdd)
  }

  test("getPersistentRDDs returns an immutable map") {
    sc = new SparkContext("local", "test")
    val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()

    val myRdds = sc.getPersistentRDDs
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)

    // myRdds2 should have 2 RDDs, but myRdds should not change
    val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache()
    val myRdds2 = sc.getPersistentRDDs

    assert(myRdds2.size === 2)
    assert(myRdds2(0) === rdd1)
    assert(myRdds2(1) === rdd2)
    assert(myRdds2(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds2(1).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("getRDDStorageInfo only reports on RDDs that actually persist data") {
    sc = new SparkContext("local", "test")
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    assert(sc.getRDDStorageInfo.size === 0)
    rdd.collect()
    assert(sc.getRDDStorageInfo.size === 1)
    assert(sc.getRDDStorageInfo.head.isCached)
    assert(sc.getRDDStorageInfo.head.memSize > 0)
    assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("call sites report correct locations") {
    sc = new SparkContext("local", "test")
    testPackage.runCallSiteTest(sc)
  }
}

package object testPackage extends Assertions {
  private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r

  def runCallSiteTest(sc: SparkContext) {
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    val rddCreationSite = rdd.getCreationSite
    val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd"

    val rddCreationLine = rddCreationSite match {
      case CALL_SITE_REGEX(func, file, line) =>
        assert(func === "makeRDD")
        assert(file === "SparkContextInfoSuite.scala")
        line.toInt
      case _ => fail("Did not match expected call site format")
    }

    curCallSite match {
      case CALL_SITE_REGEX(func, file, line) =>
        assert(func === "getCallSite") // this is correct because we called it from outside of Spark
        assert(file === "SparkContextInfoSuite.scala")
        assert(line.toInt === rddCreationLine.toInt + 2)
      case _ => fail("Did not match expected call site format")
    }
  }
}
Example 32
Source File: SparkContextInfoSuite.scala From iolap with Apache License 2.0 | 5 votes |
package org.apache.spark

import org.scalatest.Assertions

import org.apache.spark.storage.StorageLevel

class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext {
  test("getPersistentRDDs only returns RDDs that are marked as cached") {
    sc = new SparkContext("local", "test")
    assert(sc.getPersistentRDDs.isEmpty === true)

    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    assert(sc.getPersistentRDDs.isEmpty === true)

    rdd.cache()
    assert(sc.getPersistentRDDs.size === 1)
    assert(sc.getPersistentRDDs.values.head === rdd)
  }

  test("getPersistentRDDs returns an immutable map") {
    sc = new SparkContext("local", "test")
    val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()

    val myRdds = sc.getPersistentRDDs
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)

    // myRdds2 should have 2 RDDs, but myRdds should not change
    val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache()
    val myRdds2 = sc.getPersistentRDDs

    assert(myRdds2.size === 2)
    assert(myRdds2(0) === rdd1)
    assert(myRdds2(1) === rdd2)
    assert(myRdds2(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds2(1).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("getRDDStorageInfo only reports on RDDs that actually persist data") {
    sc = new SparkContext("local", "test")
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    assert(sc.getRDDStorageInfo.size === 0)
    rdd.collect()
    assert(sc.getRDDStorageInfo.size === 1)
    assert(sc.getRDDStorageInfo.head.isCached)
    assert(sc.getRDDStorageInfo.head.memSize > 0)
    assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("call sites report correct locations") {
    sc = new SparkContext("local", "test")
    testPackage.runCallSiteTest(sc)
  }
}

package object testPackage extends Assertions {
  private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r

  def runCallSiteTest(sc: SparkContext) {
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    val rddCreationSite = rdd.getCreationSite
    val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd"

    val rddCreationLine = rddCreationSite match {
      case CALL_SITE_REGEX(func, file, line) => {
        assert(func === "makeRDD")
        assert(file === "SparkContextInfoSuite.scala")
        line.toInt
      }
      case _ => fail("Did not match expected call site format")
    }

    curCallSite match {
      case CALL_SITE_REGEX(func, file, line) => {
        assert(func === "getCallSite") // this is correct because we called it from outside of Spark
        assert(file === "SparkContextInfoSuite.scala")
        assert(line.toInt === rddCreationLine.toInt + 2)
      }
      case _ => fail("Did not match expected call site format")
    }
  }
}
Example 33
Source File: MockHelpers.scala From guardrail with MIT License | 5 votes |
package helpers

import com.fasterxml.jackson.databind.ObjectMapper
import io.netty.handler.codec.http.EmptyHttpHeaders
import java.io.ByteArrayInputStream
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.util.Collections
import java.util.concurrent.CompletableFuture
import javax.ws.rs.container.AsyncResponse
import org.asynchttpclient.Response
import org.asynchttpclient.uri.Uri
import org.mockito.{ ArgumentMatchersSugar, MockitoSugar }
import org.scalatest.Assertions
import scala.reflect.ClassTag

object MockHelpers extends Assertions with MockitoSugar with ArgumentMatchersSugar {

  def mockAsyncResponse[T](future: CompletableFuture[T])(implicit cls: ClassTag[T]): AsyncResponse = {
    val asyncResponse = mock[AsyncResponse]

    when(asyncResponse.resume(any[T])) thenAnswer [AnyRef] { response =>
      response match {
        case t: Throwable => future.completeExceptionally(t)
        case other: T     => future.complete(other)
        case other =>
          fail(s"AsyncResponse.resume expected an object of type ${cls.runtimeClass.getName}, but got ${other.getClass.getName} instead")
      }
    }

    asyncResponse
  }

  def mockAHCResponse[T](uri: String, status: Int, maybeBody: Option[T] = None)(implicit mapper: ObjectMapper): Response = {
    val response = mock[Response]
    when(response.getUri) thenReturn Uri.create(uri)

    when(response.hasResponseStatus) thenReturn true
    when(response.getStatusCode) thenReturn status
    when(response.getStatusText) thenReturn "Some Status"

    when(response.hasResponseHeaders) thenReturn true
    when(response.getHeaders) thenReturn EmptyHttpHeaders.INSTANCE
    when(response.getHeader(any)) thenReturn null
    when(response.getHeaders(any)) thenReturn Collections.emptyList()

    maybeBody match {
      case None =>
        when(response.hasResponseBody) thenReturn true
      case Some(body) =>
        val responseBytes = mapper.writeValueAsBytes(body)
        val responseStr = new String(responseBytes, StandardCharsets.UTF_8)
        when(response.hasResponseBody) thenReturn true
        when(response.getResponseBody(any)) thenReturn responseStr
        when(response.getResponseBody) thenReturn responseStr
        when(response.getResponseBodyAsStream) thenReturn new ByteArrayInputStream(responseBytes)
        when(response.getResponseBodyAsByteBuffer) thenReturn ByteBuffer.wrap(responseBytes)
        when(response.getResponseBodyAsBytes) thenReturn responseBytes
    }

    response
  }
}
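As a rough usage sketch (not taken from the guardrail tests), the snippet below builds a mocked Response with mockAHCResponse and reads back the stubbed values; the plain ObjectMapper, the URI, and the body literal are assumptions made purely for illustration.

import com.fasterxml.jackson.databind.ObjectMapper
import helpers.MockHelpers

object MockAHCResponseDemo extends App {
  // A plain ObjectMapper is assumed here; the real tests may configure their own.
  implicit val mapper: ObjectMapper = new ObjectMapper()

  val response = MockHelpers.mockAHCResponse[String]("http://localhost/pets/1", 200, Some("fluffy"))

  // The stubbed accessors answer with the values wired up in mockAHCResponse.
  assert(response.getStatusCode == 200)
  assert(response.hasResponseBody)
  assert(response.getResponseBody == "\"fluffy\"") // Jackson serializes a bare String with quotes
}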
Example 34
Source File: ListTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.features.types

import com.salesforce.op.test.TestCommon
import org.apache.lucene.spatial3d.geom.{GeoPoint, PlanetModel}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class ListTest extends FlatSpec with TestCommon {

  Spec[DateTimeList] should "extend OPList[Long]" in {
    val myDateTimeList = new DateTimeList(List.empty[Long])
    myDateTimeList shouldBe a[FeatureType]
    myDateTimeList shouldBe a[OPCollection]
    myDateTimeList shouldBe a[OPList[_]]
    myDateTimeList shouldBe a[DateList]
  }

  it should "compare values correctly" in {
    new DateTimeList(List(456L, 13L)) shouldBe new DateTimeList(List(456L, 13L))
    new DateTimeList(List(13L, 456L)) should not be new DateTimeList(List(456L, 13L))
    FeatureTypeDefaults.DateTimeList should not be new DateTimeList(List(456L, 13L))
    FeatureTypeDefaults.DateTimeList shouldBe DateTimeList(List.empty[Long])

    List(12237834L, 4890489839L).toDateTimeList shouldBe a[DateTimeList]
  }
}
Example 35
Source File: AvroFieldTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.cli.gen

import com.salesforce.op.cli.gen.AvroField._
import com.salesforce.op.test.TestCommon
import org.apache.avro.Schema
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec}

import scala.collection.JavaConverters._
import scala.language.postfixOps

@RunWith(classOf[JUnitRunner])
class AvroFieldTest extends FlatSpec with TestCommon with Assertions {

  Spec[AvroField] should "do from" in {
    val types = List(
      Schema.Type.STRING,
      //  Schema.Type.BYTES, // somehow this avro type is not covered (yet)
      Schema.Type.INT,
      Schema.Type.LONG,
      Schema.Type.FLOAT,
      Schema.Type.DOUBLE,
      Schema.Type.BOOLEAN
    )
    val simpleSchemas = types map Schema.create

    val unions = List(
      Schema.createUnion((Schema.Type.NULL :: Schema.Type.INT :: Nil) map Schema.create asJava),
      Schema.createUnion((Schema.Type.INT :: Schema.Type.NULL :: Nil) map Schema.create asJava)
    )

    val enum = Schema.createEnum("Aliens", "undocumented", "outer",
      List("Edgar_the_Bug", "Boris_the_Animal", "Laura_Vasquez") asJava)

    val allSchemas = (enum :: unions) ++ simpleSchemas // NULL does not work

    val fields = allSchemas.zipWithIndex map {
      case (s, i) => new Schema.Field("x" + i, s, "Who", null: Object)
    }

    val expected = List(
      AEnum(fields(0), isNullable = false),
      AInt(fields(1), isNullable = true),
      AInt(fields(2), isNullable = true),
      AString(fields(3), isNullable = false),
      AInt(fields(4), isNullable = false),
      ALong(fields(5), isNullable = false),
      AFloat(fields(6), isNullable = false),
      ADouble(fields(7), isNullable = false),
      ABoolean(fields(8), isNullable = false)
    )

    an[IllegalArgumentException] should be thrownBy {
      val nullSchema = Schema.create(Schema.Type.NULL)
      val nullField = new Schema.Field("xxx", null, "Nobody", null: Object)
      AvroField from nullField
    }

    fields.size shouldBe expected.size

    for {
      (field, expected) <- fields zip expected
    } {
      val actual = AvroField from field
      actual shouldBe expected
    }
  }
}
Example 36
Source File: OpsTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.cli.gen

import java.io.File
import java.nio.file.Paths

import com.salesforce.op.cli.{AvroSchemaFromFile, CliParameters, GeneratorConfig}
import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec}

import scala.io.Source

@RunWith(classOf[JUnitRunner])
class OpsTest extends FlatSpec with TestCommon with Assertions {

  val tempFolder = new File(System.getProperty("java.io.tmpdir"))
  val projectFolder = new File(tempFolder, "cli_test")
  projectFolder.deleteOnExit()

  val testParams = CliParameters(
    location = tempFolder,
    projName = "cli_test",
    inputFile = Some(Paths.get("templates", "simple", "src", "main", "resources", "PassengerData.csv").toFile),
    response = Some("survived"),
    idField = Some("passengerId"),
    schemaSource = Some(
      AvroSchemaFromFile(Paths.get("..", "utils", "src", "main", "avro", "PassengerCSV.avsc").toFile)
    ),
    answersFile = Some(new File("passengers.answers")),
    overwrite = true).values

  Spec[Ops] should "generate project files" in {
    testParams match {
      case None =>
        fail("Could not create config, I wonder why")
      case Some(conf: GeneratorConfig) =>
        val ops = Ops(conf)
        ops.run()
        val buildFile = new File(projectFolder, "build.gradle")
        buildFile should exist

        val buildFileContent = Source.fromFile(buildFile).mkString

        buildFileContent should include("mainClassName = 'com.salesforce.app.cli_test'")

        val scalaSourcesFolder = Paths.get(projectFolder.toString, "src", "main", "scala", "com", "salesforce", "app")
        val featuresFile = Source.fromFile(new File(scalaSourcesFolder.toFile, "Features.scala")).getLines
        val heightLine = featuresFile.find(_ contains "description") map (_.trim)

        heightLine shouldBe Some(
          "val description = FB.Text[PassengerCSV].extract(_.getDescription.toText).asPredictor"
        )
    }
  }
}
Example 37
Source File: UserIOTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.cli.gen

import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec}

@RunWith(classOf[JUnitRunner])
class UserIOTest extends FlatSpec with TestCommon with Assertions {

  private case class Oracle(answers: String*) extends UserIO {
    private var i = -1
    var question = "---"

    override def readLine(q: String): Option[String] = {
      question = q
      i += 1
      if (i < answers.length) Some(answers(i)) else throw new IllegalStateException(s"Out of answers, q=$q")
    }
  }

  Spec[UserIO] should "do qna" in {
    // @see https://www.urbandictionary.com/define.php?term=aks
    def aksme(q: String, answers: String*): Option[String] = {
      Oracle(answers: _*).qna(q, _.length == 1, Map("2*3" -> "6", "3+2" -> "5"))
    }

    aksme("2+2", "11", "22", "?") shouldBe Some("?")
    aksme("2+2", "4", "5", "?") shouldBe Some("4")
    aksme("2+3", "44", "", "?") shouldBe Some("?")
    aksme("2*3", "4", "?") shouldBe Some("6")
    aksme("3+2", "4", "?") shouldBe Some("5")
  }

  it should "ask" in {
    // @see https://www.urbandictionary.com/define.php?term=aks
    def aksme[Int](q: String, opts: Map[Int, List[String]], answers: String*): (String, Int) = {
      val console = Oracle(answers: _*)
      val answer = console.ask(q, opts) getOrElse fail(s"A problem answering question $q")
      (console.question, answer)
    }

    an[IllegalStateException] should be thrownBy
      aksme("what is your name?", Map(1 -> List("one", "uno")), "11", "1", "?")

    aksme("what is your name?",
      Map(
        1 -> List("Nessuno", "Nobody"),
        2 -> List("Ishmael", "Gantenbein")),
      "5", "1", "?") shouldBe ("what is your name? [0] Nessuno [1] Ishmael: ", 2)
  }
}
Example 38
Source File: OpLDATest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.clustering.LDA
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class OpLDATest extends FlatSpec with TestSparkContext {

  val inputData = Seq(
    (0.0, Vectors.sparse(11, Array(0, 1, 2, 4, 5, 6, 7, 10), Array(1.0, 2.0, 6.0, 2.0, 3.0, 1.0, 1.0, 3.0))),
    (1.0, Vectors.sparse(11, Array(0, 1, 3, 4, 7, 10), Array(1.0, 3.0, 1.0, 3.0, 2.0, 1.0))),
    (2.0, Vectors.sparse(11, Array(0, 1, 2, 5, 6, 8, 9), Array(1.0, 4.0, 1.0, 4.0, 9.0, 1.0, 2.0))),
    (3.0, Vectors.sparse(11, Array(0, 1, 3, 6, 8, 9, 10), Array(2.0, 1.0, 3.0, 5.0, 2.0, 3.0, 9.0))),
    (4.0, Vectors.sparse(11, Array(0, 1, 2, 3, 4, 6, 9, 10), Array(3.0, 1.0, 1.0, 9.0, 3.0, 2.0, 1.0, 3.0))),
    (5.0, Vectors.sparse(11, Array(0, 1, 3, 4, 5, 6, 7, 8, 9), Array(4.0, 2.0, 3.0, 4.0, 5.0, 1.0, 1.0, 1.0, 4.0))),
    (6.0, Vectors.sparse(11, Array(0, 1, 3, 6, 8, 9, 10), Array(2.0, 1.0, 3.0, 5.0, 2.0, 2.0, 9.0))),
    (7.0, Vectors.sparse(11, Array(0, 1, 2, 3, 4, 5, 6, 9, 10), Array(1.0, 1.0, 1.0, 9.0, 2.0, 1.0, 2.0, 1.0, 3.0))),
    (8.0, Vectors.sparse(11, Array(0, 1, 3, 4, 5, 6, 7), Array(4.0, 4.0, 3.0, 4.0, 2.0, 1.0, 3.0))),
    (9.0, Vectors.sparse(11, Array(0, 1, 2, 4, 6, 8, 9, 10), Array(2.0, 8.0, 2.0, 3.0, 2.0, 2.0, 7.0, 2.0))),
    (10.0, Vectors.sparse(11, Array(0, 1, 2, 3, 5, 6, 9, 10), Array(1.0, 1.0, 1.0, 9.0, 2.0, 2.0, 3.0, 3.0))),
    (11.0, Vectors.sparse(11, Array(0, 1, 4, 5, 6, 7, 9), Array(4.0, 1.0, 4.0, 5.0, 1.0, 3.0, 1.0)))
  ).map(v => v._1.toReal -> v._2.toOPVector)

  lazy val (ds, f1, f2) = TestFeatureBuilder(inputData)

  lazy val inputDS = ds.persist()

  val seed = 1234567890L
  val k = 3
  val maxIter = 100

  lazy val expected = new LDA()
    .setFeaturesCol(f2.name)
    .setK(k)
    .setSeed(seed)
    .fit(inputDS)
    .transform(inputDS)
    .select("topicDistribution")
    .collect()
    .toSeq
    .map(_.getAs[Vector](0))

  Spec[OpLDA] should "convert document term vectors into topic vectors" in {
    val f2Vec = new OpLDA().setInput(f2).setK(k).setSeed(seed).setMaxIter(maxIter)
    val testTransformedData = f2Vec.fit(inputDS).transform(inputDS)
    val output = f2Vec.getOutput()
    val estimate = testTransformedData.collect(output)
    val mse = computeMeanSqError(estimate, expected)
    val expectedMse = 0.5
    withClue(s"Computed mse $mse (expected $expectedMse)") {
      mse should be < expectedMse
    }
  }

  it should "convert document term vectors into topic vectors (shortcut version)" in {
    val output = f2.lda(k = k, seed = seed, maxIter = maxIter)
    val f2Vec = output.originStage.asInstanceOf[OpLDA]
    val testTransformedData = f2Vec.fit(inputDS).transform(inputDS)
    val estimate = testTransformedData.collect(output)
    val mse = computeMeanSqError(estimate, expected)
    val expectedMse = 0.5
    withClue(s"Computed mse $mse (expected $expectedMse)") {
      mse should be < expectedMse
    }
  }

  private def computeMeanSqError(estimate: Seq[OPVector], expected: Seq[Vector]): Double = {
    val n = estimate.length.toDouble
    estimate.zip(expected).map { case (est, exp) => Vectors.sqdist(est.value, exp) }.sum / n
  }
}
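The computeMeanSqError helper above averages Vectors.sqdist over paired rows. The toy sketch below (illustration only, working directly with org.apache.spark.ml.linalg vectors rather than OPVector) performs the same computation on two tiny sequences.

import org.apache.spark.ml.linalg.{Vector, Vectors}

object MeanSqErrorDemo extends App {
  def meanSqError(estimate: Seq[Vector], expected: Seq[Vector]): Double =
    estimate.zip(expected).map { case (est, exp) => Vectors.sqdist(est, exp) }.sum / estimate.length

  val est = Seq(Vectors.dense(0.1, 0.9), Vectors.dense(0.5, 0.5))
  val exp = Seq(Vectors.dense(0.0, 1.0), Vectors.dense(0.5, 0.5))

  // Squared distances are 0.02 and 0.0, so the mean squared error is 0.01.
  println(meanSqError(est, exp))
}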
Example 39
Source File: OpWord2VecTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class OpWord2VecTest extends FlatSpec with TestSparkContext {

  val data = Seq(
    "I I I like like Spark".split(" "),
    "Hi I heard about Spark".split(" "),
    "I wish Java could use case classes".split(" "),
    "Logistic regression models are neat".split(" ")
  ).map(_.toSeq.toTextList)

  lazy val (inputData, f1) = TestFeatureBuilder(Seq(data.head))
  lazy val (testData, _) = TestFeatureBuilder(data.tail)

  lazy val expected = data.tail.zip(Seq(
    Vectors.dense(-0.029884086549282075, -0.055613189935684204, 0.04186216294765473).toOPVector,
    Vectors.dense(-0.0026281912411962234, -0.016138136386871338, 0.010740748473576136).toOPVector,
    Vectors.dense(0.0, 0.0, 0.0).toOPVector
  )).toArray

  Spec[OpWord2VecTest] should "convert array of strings into a vector" in {
    val f1Vec = new OpWord2Vec().setInput(f1).setMinCount(0).setVectorSize(3).setSeed(1234567890L)
    val output = f1Vec.getOutput()
    val testTransformedData = f1Vec.fit(inputData).transform(testData)
    testTransformedData.orderBy(f1.name).collect(f1, output) shouldBe expected
  }

  it should "convert array of strings into a vector (shortcut version)" in {
    val output = f1.word2vec(minCount = 0, vectorSize = 3)
    val f1Vec = output.originStage.asInstanceOf[OpWord2Vec].setSeed(1234567890L)
    val testTransformedData = f1Vec.fit(inputData).transform(testData)
    testTransformedData.orderBy(f1.name).collect(f1, output) shouldBe expected
  }
}
Example 40
Source File: IDFTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.feature.IDF
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.{Estimator, Transformer}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class IDFTest extends FlatSpec with TestSparkContext {

  val data = Seq(
    Vectors.sparse(4, Array(1, 3), Array(1.0, 2.0)),
    Vectors.dense(0.0, 1.0, 2.0, 3.0),
    Vectors.sparse(4, Array(1), Array(1.0))
  )

  lazy val (ds, f1) = TestFeatureBuilder(data.map(_.toOPVector))

  Spec[IDF] should "compute inverted document frequency" in {
    val idf = f1.idf()
    val model = idf.originStage.asInstanceOf[Estimator[_]].fit(ds)
    val transformedData = model.asInstanceOf[Transformer].transform(ds)
    val results = transformedData.select(idf.name).collect(idf)

    idf.name shouldBe idf.originStage.getOutputFeatureName

    val expectedIdf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
      math.log((data.length + 1.0) / (x + 1.0))
    })
    val expected = scaleDataWithIDF(data, expectedIdf)

    for {
      (res, exp) <- results.zip(expected)
      (x, y) <- res.value.toArray.zip(exp.toArray)
    } assert(math.abs(x - y) <= 1e-5)
  }

  it should "compute inverted document frequency when minDocFreq is 1" in {
    val idf = f1.idf(minDocFreq = 1)
    val model = idf.originStage.asInstanceOf[Estimator[_]].fit(ds)
    val transformedData = model.asInstanceOf[Transformer].transform(ds)
    val results = transformedData.select(idf.name).collect(idf)

    idf.name shouldBe idf.originStage.getOutputFeatureName

    val expectedIdf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
      if (x > 0) math.log((data.length + 1.0) / (x + 1.0)) else 0
    })
    val expected = scaleDataWithIDF(data, expectedIdf)

    for {
      (res, exp) <- results.zip(expected)
      (x, y) <- res.value.toArray.zip(exp.toArray)
    } assert(math.abs(x - y) <= 1e-5)
  }

  private def scaleDataWithIDF(dataSet: Seq[Vector], model: Vector): Seq[Vector] = {
    dataSet.map {
      case data: DenseVector =>
        val res = data.toArray.zip(model.toArray).map { case (x, y) => x * y }
        Vectors.dense(res)
      case data: SparseVector =>
        val res = data.indices.zip(data.values).map { case (id, value) => (id, value * model(id)) }
        Vectors.sparse(data.size, res)
    }
  }
}
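The expectedIdf array in this test encodes the smoothed IDF formula used by Spark ML, idf(t) = log((N + 1) / (df(t) + 1)) with N documents and document frequency df(t). The standalone sketch below (an illustration, not part of the suite) recomputes the four values for the three documents used here, reusing the document frequencies hard-coded in the test.

object IdfByHand extends App {
  // Three documents over four terms, matching the vectors in the test above.
  val numDocs = 3.0
  // Document frequencies per term index, as hard-coded in the test: Array(0, 3, 1, 2).
  val docFreq = Array(0, 3, 1, 2)

  // Smoothed IDF: log((N + 1) / (df + 1)).
  val idf = docFreq.map(df => math.log((numDocs + 1.0) / (df + 1.0)))

  // Prints roughly: 1.386, 0.000, 0.693, 0.288
  println(idf.map(v => f"$v%.3f").mkString(", "))
}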