org.scalatest.ShouldMatchers Scala Examples
The following examples show how to use org.scalatest.ShouldMatchers.
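Every suite below mixes ShouldMatchers into a ScalaTest style trait (FunSuite, FlatSpec, FlatSpecLike, WordSpec) to bring the infix "should" assertion DSL into scope. As a quick orientation, here is a minimal, self-contained sketch of that pattern; the suite name and values are invented for illustration, and it assumes a ScalaTest version in which org.scalatest.ShouldMatchers is still available (later releases renamed the trait to Matchers).

import org.scalatest.{FunSuite, ShouldMatchers}

// Minimal illustrative suite (hypothetical names), showing the `should`
// assertion DSL that the examples below rely on.
class GreetingSuite extends FunSuite with ShouldMatchers {

  test("string and numeric matchers") {
    val greeting = "hello world"
    greeting should startWith ("hello")   // prefix check
    greeting should include ("o w")       // substring check
    greeting.length should be (11)        // equality via be
    3.14 should be (3.1 +- 0.2)           // numeric tolerance, as used in Example 7
  }

  test("collection and option matchers") {
    val xs = List(1, 2, 3)
    xs should have size 3
    xs should contain (2)
    xs.headOption should be (Some(1))
  }
}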
Example 1
Source File: SQLRunnerSuite.scala From HANAVora-Extensions with Apache License 2.0
package com.sap.spark.cli

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}

import org.apache.spark.SparkContext
import org.apache.spark.sql.{GlobalSapSQLContext, SQLContext}
import org.scalatest.{BeforeAndAfterEach, FunSuite, ShouldMatchers}

    // good call
    val goodOpts = SQLRunner.parseOpts(List("a.sql", "b.sql", "-o", "output.csv"))
    goodOpts.sqlFiles should be(List("a.sql", "b.sql"))
    goodOpts.output should be(Some("output.csv"))

    // bad call
    val badOpts = SQLRunner.parseOpts(List())
    badOpts.sqlFiles should be(List())
    badOpts.output should be(None)

    // ugly call
    val uglyOpts = SQLRunner.parseOpts(List("a.sql", "-o", "output.csv", "b.sql"))
    uglyOpts.sqlFiles should be(List("a.sql", "b.sql"))
    uglyOpts.output should be(Some("output.csv"))
  }

  def runSQLTest(input: String, expectedOutput: String): Unit = {
    val inputStream: InputStream = new ByteArrayInputStream(input.getBytes())
    val outputStream = new ByteArrayOutputStream()

    SQLRunner.sql(inputStream, outputStream)

    val output = outputStream.toString
    output should be(expectedOutput)
  }

  test("can run dummy query") {
    val input = "SELECT 1;"
    val output = "1\n"
    runSQLTest(input, output)
  }

  test("can run multiple dummy queries") {
    val input =
      """
        |SELECT 1;SELECT 2;
        |SELECT 3;
      """.stripMargin
    val output = "1\n2\n3\n"
    runSQLTest(input, output)
  }

  test("can run a basic example with tables") {
    val input =
      """
        |SELECT * FROM DEMO_TABLE;
        |SELECT * FROM DEMO_TABLE LIMIT 1;
        |DROP TABLE DEMO_TABLE;
      """.stripMargin
    val output = "1,a\n2,b\n3,c\n1,a\n"
    runSQLTest(input, output)
  }

  test("can run an example with comments") {
    val input =
      """
        |SELECT * FROM DEMO_TABLE; -- this is the first query
        |SELECT * FROM DEMO_TABLE LIMIT 1;
        |-- now let's drop a table
        |DROP TABLE DEMO_TABLE;
      """.stripMargin
    val output = "1,a\n2,b\n3,c\n1,a\n"
    runSQLTest(input, output)
  }
}
Example 2
Source File: RawStageTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.driver.test.stage

import akka.actor.ActorSystem
import akka.testkit.TestKit
import com.stratio.sparta.driver.stage.{LogError, RawDataStage}
import com.stratio.sparta.sdk.pipeline.autoCalculations.AutoCalculatedField
import com.stratio.sparta.sdk.properties.JsoneyString
import com.stratio.sparta.serving.core.models.policy.writer.{AutoCalculatedFieldModel, WriterModel}
import com.stratio.sparta.serving.core.models.policy.{PolicyModel, RawDataModel}
import org.junit.runner.RunWith
import org.mockito.Mockito.when
import org.scalatest.junit.JUnitRunner
import org.scalatest.mock.MockitoSugar
import org.scalatest.{FlatSpecLike, ShouldMatchers}

@RunWith(classOf[JUnitRunner])
class RawStageTest extends TestKit(ActorSystem("RawStageTest"))
  with FlatSpecLike with ShouldMatchers with MockitoSugar {

  case class TestRawData(policy: PolicyModel) extends RawDataStage with LogError

  def mockPolicy: PolicyModel = {
    val policy = mock[PolicyModel]
    when(policy.id).thenReturn(Some("id"))
    policy
  }

  "rawDataStage" should "Generate a raw data" in {
    val field = "field"
    val timeField = "time"
    val tableName = Some("table")
    val outputs = Seq("output")
    val partitionBy = Some("field")
    val autocalculateFields = Seq(AutoCalculatedFieldModel())
    val configuration = Map.empty[String, JsoneyString]

    val policy = mockPolicy
    val rawData = mock[RawDataModel]
    val writerModel = mock[WriterModel]

    when(policy.rawData).thenReturn(Some(rawData))
    when(rawData.dataField).thenReturn(field)
    when(rawData.timeField).thenReturn(timeField)
    when(rawData.writer).thenReturn(writerModel)
    when(writerModel.tableName).thenReturn(tableName)
    when(writerModel.outputs).thenReturn(outputs)
    when(writerModel.partitionBy).thenReturn(partitionBy)
    when(writerModel.autoCalculatedFields).thenReturn(autocalculateFields)
    when(rawData.configuration).thenReturn(configuration)

    val result = TestRawData(policy).rawDataStage()

    result.timeField should be(timeField)
    result.dataField should be(field)
    result.writerOptions.tableName should be(tableName)
    result.writerOptions.partitionBy should be(partitionBy)
    result.configuration should be(configuration)
    result.writerOptions.outputs should be(outputs)
  }

  "rawDataStage" should "Fail with bad table name" in {
    val field = "field"
    val timeField = "time"
    val tableName = None
    val outputs = Seq("output")
    val partitionBy = Some("field")
    val configuration = Map.empty[String, JsoneyString]

    val policy = mockPolicy
    val rawData = mock[RawDataModel]
    val writerModel = mock[WriterModel]

    when(policy.rawData).thenReturn(Some(rawData))
    when(rawData.dataField).thenReturn(field)
    when(rawData.timeField).thenReturn(timeField)
    when(rawData.writer).thenReturn(writerModel)
    when(writerModel.tableName).thenReturn(tableName)
    when(writerModel.outputs).thenReturn(outputs)
    when(writerModel.partitionBy).thenReturn(partitionBy)
    when(rawData.configuration).thenReturn(configuration)

    the[IllegalArgumentException] thrownBy {
      TestRawData(policy).rawDataStage()
    } should have message "Something gone wrong saving the raw data. Please re-check the policy."
  }
}
Example 3
Source File: NettyBlockTransferSecuritySuite.scala From SparkCore with Apache License 2.0
package org.apache.spark.network.netty

import java.nio._
import java.util.concurrent.TimeUnit

import scala.concurrent.duration._
import scala.concurrent.{Await, Promise}
import scala.util.{Failure, Success, Try}

import org.apache.commons.io.IOUtils
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.network.{BlockDataManager, BlockTransferService}
import org.apache.spark.storage.{BlockId, ShuffleBlockId}
import org.apache.spark.{SecurityManager, SparkConf}
import org.mockito.Mockito._
import org.scalatest.mock.MockitoSugar
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMatchers}

class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers {

  test("security default off") {
    val conf = new SparkConf()
      .set("spark.app.id", "app-id")
    testConnection(conf, conf) match {
      case Success(_) => // expected
      case Failure(t) => fail(t)
    }
  }

  test("security on same password") {
    val conf = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "good")
      .set("spark.app.id", "app-id")
    testConnection(conf, conf) match {
      case Success(_) => // expected
      case Failure(t) => fail(t)
    }
  }

  test("security on mismatch password") {
    val conf0 = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "good")
      .set("spark.app.id", "app-id")
    val conf1 = conf0.clone.set("spark.authenticate.secret", "bad")
    testConnection(conf0, conf1) match {
      case Success(_) => fail("Should have failed")
      case Failure(t) => t.getMessage should include ("Mismatched response")
    }
  }

  test("security mismatch auth off on server") {
    val conf0 = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "good")
      .set("spark.app.id", "app-id")
    val conf1 = conf0.clone.set("spark.authenticate", "false")
    testConnection(conf0, conf1) match {
      case Success(_) => fail("Should have failed")
      case Failure(t) => // any funny error may occur, server will interpret SASL token as RPC
    }
  }

  test("security mismatch auth off on client") {
    val conf0 = new SparkConf()
      .set("spark.authenticate", "false")
      .set("spark.authenticate.secret", "good")
      .set("spark.app.id", "app-id")
    val conf1 = conf0.clone.set("spark.authenticate", "true")
    testConnection(conf0, conf1) match {
      case Success(_) => fail("Should have failed")
      case Failure(t) => t.getMessage should include ("Expected SaslMessage")
    }
  }

  private def fetchBlock(
      self: BlockTransferService,
      from: BlockTransferService,
      execId: String,
      blockId: BlockId): Try[ManagedBuffer] = {

    val promise = Promise[ManagedBuffer]()

    self.fetchBlocks(from.hostName, from.port, execId, Array(blockId.toString),
      new BlockFetchingListener {
        override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
          promise.failure(exception)
        }

        override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
          promise.success(data.retain())
        }
      })

    Await.ready(promise.future, FiniteDuration(1000, TimeUnit.MILLISECONDS))
    promise.future.value.get
  }
}
Example 4
Source File: AudienceAnalyticsSpec.scala From spark-hyperloglog with MIT License
package com.collective.analytics

import com.collective.analytics.schema.ImpressionLog
import org.apache.spark.sql.Row
import org.scalatest.{FlatSpec, ShouldMatchers}

class SparkAudienceAnalyticsSpec extends AudienceAnalyticsSpec with EmbeddedSparkContext {
  def builder: Vector[Row] => AudienceAnalytics = log =>
    new SparkAudienceAnalytics(
      new AggregateImpressionLog(sqlContext.createDataFrame(sc.parallelize(log), ImpressionLog.schema))
    )
}

class InMemoryAudienceAnalyticsSpec extends AudienceAnalyticsSpec {
  def builder: Vector[Row] => AudienceAnalytics = log =>
    new InMemoryAudienceAnalytics(log.map(ImpressionLog.parse))
}

abstract class AudienceAnalyticsSpec extends FlatSpec with ShouldMatchers with DataGenerator {

  def builder: Vector[Row] => AudienceAnalytics

  private val impressions =
    repeat(100, impressionRow("bmw", "forbes.com", 10L, 1L,
      Array("income:50000", "education:high-school", "interest:technology"))) ++
    repeat(100, impressionRow("bmw", "forbes.com", 5L, 2L,
      Array("income:50000", "education:college", "interest:auto"))) ++
    repeat(100, impressionRow("bmw", "auto.com", 7L, 0L,
      Array("income:100000", "education:high-school", "interest:auto"))) ++
    repeat(100, impressionRow("audi", "cnn.com", 2L, 0L,
      Array("income:50000", "interest:audi", "education:high-school")))

  //private val impressionLog = impressions.map(ImpressionLog.parse)

  private val analytics = builder(impressions)

  "InMemoryAudienceAnalytics" should "compute audience estimate" in {
    val bmwEstimate = analytics.audienceEstimate(Vector("bmw"))
    assert(bmwEstimate.cookiesHLL.size() == 3 * 100)
    assert(bmwEstimate.impressions == 22 * 100)
    assert(bmwEstimate.clicks == 3 * 100)

    val forbesEstimate = analytics.audienceEstimate(sites = Vector("forbes.com"))
    assert(forbesEstimate.cookiesHLL.size() == 2 * 100)
    assert(forbesEstimate.impressions == 15 * 100)
    assert(forbesEstimate.clicks == 3 * 100)
  }

  it should "compute segment estimate" in {
    val fiftyK = analytics.segmentsEstimate(Vector("income:50000"))
    assert(fiftyK.cookiesHLL.size() == 3 * 100)
    assert(fiftyK.impressions == 17 * 100)
    assert(fiftyK.clicks == 3 * 100)

    val highSchool = analytics.segmentsEstimate(Vector("education:high-school"))
    assert(highSchool.cookiesHLL.size() == 3 * 100)
    assert(highSchool.impressions == 19 * 100)
    assert(highSchool.clicks == 1 * 100)
  }

  it should "compute audience intersection" in {
    val bmwAudience = analytics.audienceEstimate(Vector("bmw"))
    val intersection = analytics.segmentsIntersection(bmwAudience).toMap
    assert(intersection.size == 7)
    assert(intersection("interest:audi") == 0)

    intersection("income:50000") should (be >= 180L and be <= 2020L)
  }
}
Example 5
Source File: FillSuite.scala From spark-timeseries with Apache License 2.0
package com.cloudera.sparkts

import scala.Double.NaN

import com.cloudera.sparkts.UnivariateTimeSeries._

import org.scalatest.{FunSuite, ShouldMatchers}

class FillSuite extends FunSuite with ShouldMatchers {
  ignore("nearest") {
    fillNearest(Array(1.0)) should be (Array(1.0))
    fillNearest(Array(1.0, 1.0, 2.0)) should be (Array(1.0, 1.0, 2.0))
    fillNearest(Array(1.0, NaN, NaN, 2.0)) should be (Array(1.0, 1.0, 2.0, 2.0))
    // round down to previous
    fillNearest(Array(1.0, NaN, 2.0)) should be (Array(1.0, 1.0, 2.0))
    fillNearest(Array(1.0, NaN, NaN, NaN, 2.0)) should be (Array(1.0, 1.0, 1.0, 2.0, 2.0))
    fillNearest(Array(1.0, NaN, 3.0, NaN, 2.0)) should be (Array(1.0, 1.0, 3.0, 3.0, 2.0))
  }

  test("previous") {
    fillPrevious(Array(1.0)) should be (Array(1.0))
    fillPrevious(Array(1.0, 1.0, 2.0)) should be (Array(1.0, 1.0, 2.0))
    fillPrevious(Array(1.0, NaN, 2.0)) should be (Array(1.0, 1.0, 2.0))
    fillPrevious(Array(1.0, NaN, NaN, 2.0)) should be (Array(1.0, 1.0, 1.0, 2.0))
    fillPrevious(Array(1.0, NaN, NaN, NaN, 2.0)) should be (Array(1.0, 1.0, 1.0, 1.0, 2.0))
    fillPrevious(Array(1.0, NaN, 3.0, NaN, 2.0)) should be (Array(1.0, 1.0, 3.0, 3.0, 2.0))
  }

  test("next") {
    fillNext(Array(1.0)) should be (Array(1.0))
    fillNext(Array(1.0, 1.0, 2.0)) should be (Array(1.0, 1.0, 2.0))
    fillNext(Array(1.0, NaN, 2.0)) should be (Array(1.0, 2.0, 2.0))
    fillNext(Array(1.0, NaN, NaN, 2.0)) should be (Array(1.0, 2.0, 2.0, 2.0))
    fillNext(Array(1.0, NaN, NaN, NaN, 2.0)) should be (Array(1.0, 2.0, 2.0, 2.0, 2.0))
    fillNext(Array(1.0, NaN, 3.0, NaN, 2.0)) should be (Array(1.0, 3.0, 3.0, 2.0, 2.0))
  }

  test("linear") {
    fillLinear(Array(1.0)) should be (Array(1.0))
    fillLinear(Array(1.0, 1.0, 2.0)) should be (Array(1.0, 1.0, 2.0))
    fillLinear(Array(1.0, NaN, 2.0)) should be (Array(1.0, 1.5, 2.0))
    fillLinear(Array(2.0, NaN, 1.0)) should be (Array(2.0, 1.5, 1.0))
    fillLinear(Array(1.0, NaN, NaN, 4.0)) should be (Array(1.0, 2.0, 3.0, 4.0))
    fillLinear(Array(1.0, NaN, NaN, NaN, 5.0)) should be (Array(1.0, 2.0, 3.0, 4.0, 5.0))
    fillLinear(Array(1.0, NaN, 3.0, NaN, 2.0)) should be (Array(1.0, 2.0, 3.0, 2.5, 2.0))
  }
}
Example 6
Source File: EWMASuite.scala From spark-timeseries with Apache License 2.0
package com.cloudera.sparkts.models

import org.apache.spark.mllib.linalg._

import org.scalatest.{FunSuite, ShouldMatchers}

class EWMASuite extends FunSuite with ShouldMatchers {
  test("adding time dependent effects") {
    val orig = new DenseVector((1 to 10).toArray.map(_.toDouble))

    val m1 = new EWMAModel(0.2)
    val smoothed1 = new DenseVector(Array.fill(10)(0.0))
    m1.addTimeDependentEffects(orig, smoothed1)

    smoothed1(0) should be (orig(0))
    smoothed1(1) should be (m1.smoothing * orig(1) + (1 - m1.smoothing) * smoothed1(0))
    round2Dec(smoothed1.toArray.last) should be (6.54)

    val m2 = new EWMAModel(0.6)
    val smoothed2 = new DenseVector(Array.fill(10)(0.0))
    m2.addTimeDependentEffects(orig, smoothed2)

    smoothed2(0) should be (orig(0))
    smoothed2(1) should be (m2.smoothing * orig(1) + (1 - m2.smoothing) * smoothed2(0))
    round2Dec(smoothed2.toArray.last) should be (9.33)
  }

  test("removing time dependent effects") {
    val smoothed = new DenseVector(Array(1.0, 1.2, 1.56, 2.05, 2.64, 3.31, 4.05, 4.84, 5.67, 6.54))

    val m1 = new EWMAModel(0.2)
    val orig1 = new DenseVector(Array.fill(10)(0.0))
    m1.removeTimeDependentEffects(smoothed, orig1)

    round2Dec(orig1(0)) should be (1.0)
    orig1.toArray.last.toInt should be(10)
  }

  test("fitting EWMA model") {
    // We reproduce the example in ch 7.1 from
    // https://www.otexts.org/fpp/7/1
    val oil = Array(446.7, 454.5, 455.7, 423.6, 456.3, 440.6, 425.3, 485.1,
      506.0, 526.8, 514.3, 494.2)
    val model = EWMA.fitModel(new DenseVector(oil))
    val truncatedSmoothing = (model.smoothing * 100.0).toInt
    truncatedSmoothing should be (89) // approximately 0.89
  }

  private def round2Dec(x: Double): Double = {
    (x * 100).round / 100.00
  }
}
Example 7
Source File: RegressionARIMASuite.scala From spark-timeseries with Apache License 2.0
package com.cloudera.sparkts.models

import breeze.linalg
import breeze.linalg.DenseMatrix

import org.scalatest.{FunSuite, ShouldMatchers}

class RegressionARIMASuite extends FunSuite with ShouldMatchers {
  test("Cochrane-Orcutt-Stock-Data") {
    val expenditure = Array(214.6, 217.7, 219.6, 227.2, 230.9, 233.3, 234.1, 232.3, 233.7, 236.5,
      238.7, 243.2, 249.4, 254.3, 260.9, 263.3, 265.6, 268.2, 270.4, 275.6)

    val stock = Array(159.3, 161.2, 162.8, 164.6, 165.9, 167.9, 168.3, 169.7, 170.5, 171.6,
      173.9, 176.1, 178.0, 179.1, 180.2, 181.2, 181.6, 182.5, 183.3, 184.3)

    val Y = linalg.DenseVector(expenditure)
    val regressors = new DenseMatrix[Double](stock.length, 1)
    regressors(::, 0) := linalg.DenseVector(stock)

    val regARIMA = RegressionARIMA.fitCochraneOrcutt(Y, regressors, 11)
    val beta = regARIMA.regressionCoeff
    val rho = regARIMA.arimaCoeff(0)

    rho should equal(0.8241 +- 0.001)
    beta(0) should equal(-235.4889 +- 0.1)
    beta(1) should equal(2.75306 +- 0.001)
  }
}
Example 8
Source File: DateTimeIndexUtilsSuite.scala From spark-timeseries with Apache License 2.0
package com.cloudera.sparkts

import java.time.{ZonedDateTime, ZoneId}

import com.cloudera.sparkts.DateTimeIndex._

import org.scalatest.{FunSuite, ShouldMatchers}

class DateTimeIndexUtilsSuite extends FunSuite with ShouldMatchers {
  val UTC = ZoneId.of("Z")

  test("non-overlapping sorted") {
    val index1: DateTimeIndex = uniform(dt("2015-04-10"), 5, new DayFrequency(2), UTC)
    val index2: DateTimeIndex = uniform(dt("2015-05-10"), 5, new DayFrequency(2), UTC)
    val index3: DateTimeIndex = irregular(Array(
      dt("2015-06-10"),
      dt("2015-06-13"),
      dt("2015-06-15"),
      dt("2015-06-20"),
      dt("2015-06-25")
    ), UTC)

    DateTimeIndexUtils.union(Array(index1, index2, index3), UTC) should be (
      hybrid(Array(index1, index2, index3)))
  }

  test("non-overlapping non-sorted") {
    val index1: DateTimeIndex = uniform(dt("2015-04-10"), 5, new DayFrequency(2), UTC)
    val index2: DateTimeIndex = uniform(dt("2015-05-10"), 5, new DayFrequency(2), UTC)
    val index3: DateTimeIndex = irregular(Array(
      dt("2015-06-10"),
      dt("2015-06-13"),
      dt("2015-06-15"),
      dt("2015-06-20"),
      dt("2015-06-25")
    ), UTC)

    DateTimeIndexUtils.union(Array(index3, index1, index2), UTC) should be (
      hybrid(Array(index1, index2, index3)))
  }

  test("overlapping uniform and irregular") {
    val index1: DateTimeIndex = uniform(dt("2015-04-10"), 5, new DayFrequency(2), UTC)
    val index2: DateTimeIndex = uniform(dt("2015-05-10"), 5, new DayFrequency(2), UTC)
    val index3: DateTimeIndex = irregular(Array(
      dt("2015-04-09"),
      dt("2015-04-11"),
      dt("2015-05-01"),
      dt("2015-05-10"),
      dt("2015-06-25")
    ), UTC)

    DateTimeIndexUtils.union(Array(index3, index1, index2), UTC) should be (
      hybrid(Array(
        irregular(Array(
          dt("2015-04-09"),
          dt("2015-04-10"),
          dt("2015-04-11")), UTC),
        uniform(dt("2015-04-12"), 4, new DayFrequency(2), UTC),
        irregular(Array(dt("2015-05-01"), dt("2015-05-10")), UTC),
        uniform(dt("2015-05-12"), 4, new DayFrequency(2), UTC),
        irregular(Array(dt("2015-06-25")), UTC)
      )))
  }

  def dt(dt: String, zone: ZoneId = UTC): ZonedDateTime = {
    val splits = dt.split("-").map(_.toInt)
    ZonedDateTime.of(splits(0), splits(1), splits(2), 0, 0, 0, 0, zone)
  }
}
Example 9
Source File: BinnerSpec.scala From modelmatrix with Apache License 2.0
package com.collective.modelmatrix.transform

import org.scalatest.{ShouldMatchers, FlatSpec}

class BinnerSpec extends FlatSpec with ShouldMatchers with Binner {

  "Binner" should "get from diff to original values" in {
    val diff = Seq(0.1, 0.21, 0.05, 0.5)
    assert(fromDiff(diff) == Seq(0.1, 0.31, 0.36, 0.86))
  }

  it should "get diff from original values" in {
    val values = Seq(0.1, 0.31, 0.37, 0.88)
    assert(toDiff(values) == Seq(0.1, 0.21, 0.06, 0.51))
  }

  it should "calculate perfect split of 10 bins" in {
    val x = (0 to 100).toArray.map(_.toDouble)
    val split = optimalSplit(x, 10, 0, 0)

    split.foreach { s =>
      s.count should be ((x.length / 10) +- 5)
    }

    assert(split.length == 10)
    assert(split.map(_.count).sum == x.length)
  }

  it should "calculate perfect split of 2 bins" in {
    val x = (0 to 100).toArray.map(_.toDouble)
    val split = optimalSplit(x, 2, 0, 0)

    split.foreach { s =>
      s.count should be ((x.length / 2) +- 5)
    }

    assert(split.length == 2)
    assert(split.map(_.count).sum == x.length)
  }

  it should "calculate perfect split of 3 bins" in {
    val x = (0 to 100).toArray.map(_.toDouble)
    val split = optimalSplit(x, 3, 0, 0)

    split.foreach { s =>
      s.count should be ((x.length / 3) +- 5)
    }

    assert(split.length == 3)
    assert(split.map(_.count).sum == x.length)
  }

  it should "calculate perfect split for highly skewed data" in {
    // R: x <- exp(rnorm(1000))
    // Heavy right skewed data
    val g = breeze.stats.distributions.Gaussian(0, 1)
    val skewed = g.sample(1000).map(d => math.exp(d)).toArray

    val split = optimalSplit(skewed, 10, 0, 0)

    split.foreach { s =>
      s.count should be((skewed.length / 10) +- 5)
    }

    assert(split.length == 10)
    assert(split.map(_.count).sum == skewed.length)
  }
}
Example 10
Source File: PlySuite.scala From spark-iqmulus with Apache License 2.0
package fr.ign.spark.iqmulus.ply

import org.scalatest.FunSuite
import org.scalatest.ShouldMatchers
import org.apache.spark.sql.types._

class PlySuite extends FunSuite with ShouldMatchers {

  val id = Array("fid" -> IntegerType, "pid" -> LongType)
  val xyz = Array("x" -> FloatType, "y" -> FloatType, "z" -> FloatType)
  val rgb = Array("r" -> ByteType, "g" -> ByteType, "b" -> ByteType)

  val files = Seq(
    ("trepied_xyz.ply", 5995, id ++ xyz) // ,
    // ("trepied_dim.ply", 5995, id ++ xyz ++ rgb),
    // ("trepied_dim2.ply", 5995, id ++ xyz ++ rgb),
    // ("213-232-7.ply", 71651, id ++ xyz ++ rgb)
  )

  val resources = "src/test/resources"

  files foreach {
    case (file, count, fields) =>
      if (new java.io.File(s"$resources/$file").exists) {
        test(s"$file should read the correct header metadata") {
          val header = PlyHeader.read(s"$resources/$file");
          header.section("vertex").count should equal(count)
        }
        test(s"$file should have the correct schema") {
          val header = PlyHeader.read(s"$resources/$file");
          header.section("vertex").schema should equal(StructType(fields map {
            case (name, dataType) => StructField(name, dataType, nullable = false)
          }))
        }
      }
  }
}
Example 11
Source File: UtilTests.scala From sparkplug with MIT License
package springnz.sparkplug

import springnz.sparkplug.util.SerializeUtils._

import org.scalatest.{ ShouldMatchers, WordSpec }

case class TestObject(a: String, b: Int, c: Vector[String], d: List[Int])

class UtilTests extends WordSpec with ShouldMatchers {

  "serialise utils" should {
    "serialise and deserialise a local object" in {
      val testObject = TestObject("hello", 42, Vector("test", "array"), List(42, 108))
      val byteArray = serialize(testObject)
      val inflatedObject = deserialize[TestObject](byteArray, this.getClass.getClassLoader)
      inflatedObject should equal(testObject)
    }

    "serialise and deserialise a local object with its class loader" in {
      val testObject = TestObject("hello", 42, Vector("test", "array"), List(42, 108))
      val byteArray = serialize(testObject)
      val inflatedObject = deserialize[TestObject](byteArray, TestObject.getClass.getClassLoader)
      inflatedObject should equal(testObject)
    }

    "serialise and deserialise a local object with default class loader" in {
      val testObject = TestObject("hello", 42, Vector("test", "array"), List(42, 108))
      val byteArray = serialize(testObject)
      val inflatedObject = deserialize[TestObject](byteArray)
      inflatedObject should equal(testObject)
    }
  }
}
Example 12
Source File: SecondaryPairDCFunctionsTest.scala From spark-flow with Apache License 2.0
package com.bloomberg.sparkflow.dc

import com.bloomberg.sparkflow._
import com.holdenkarau.spark.testing.SharedSparkContext
import org.scalatest.{ShouldMatchers, FunSuite}

class SecondaryPairDCFunctionsTest extends FunSuite with SharedSparkContext with ShouldMatchers {

  test("testRepartAndSort") {
    val input = parallelize(Seq(
      (("a",3), 0),
      (("b",2), 0),
      (("b",1), 0),
      (("b",3), 0),
      (("a",2), 0),
      (("a",1), 0)))

    val sortAndRepart = input.repartitionAndSecondarySortWithinPartitions(2)

    val result = sortAndRepart.mapPartitions(it => Iterator(it.toList))

    val expected = Seq(
      List(
        (("a",1), 0),
        (("a",2), 0),
        (("a",3), 0)),
      List(
        (("b",1), 0),
        (("b",2), 0),
        (("b",3), 0)))

    expected should contain theSameElementsAs result.getRDD(sc).collect()
  }
}