org.scalatest.ShouldMatchers Scala Examples

The following examples show how to use org.scalatest.ShouldMatchers. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
Example 1
Source File: SQLRunnerSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package com.sap.spark.cli

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}

import org.apache.spark.SparkContext
import org.apache.spark.sql.{GlobalSapSQLContext, SQLContext}
import org.scalatest.{BeforeAndAfterEach, FunSuite, ShouldMatchers}



    // good call
    val goodOpts =
      SQLRunner.parseOpts(List("a.sql", "b.sql", "-o", "output.csv"))

    goodOpts.sqlFiles should be(List("a.sql", "b.sql"))
    goodOpts.output should be(Some("output.csv"))

    // bad call
    val badOpts = SQLRunner.parseOpts(List())

    badOpts.sqlFiles should be(List())
    badOpts.output should be(None)

    // ugly call
    val uglyOpts =
      SQLRunner.parseOpts(List("a.sql", "-o", "output.csv", "b.sql"))

    uglyOpts.sqlFiles should be(List("a.sql", "b.sql"))
    uglyOpts.output should be(Some("output.csv"))
  }

  /**
   * Pipes `input` through `SQLRunner.sql` and checks what was written.
   *
   * The input string is turned into bytes with the no-argument `getBytes()` and
   * the captured bytes are turned back into a string with the no-argument
   * `toString`, so both conversions use the platform default charset.
   *
   * @param input          text handed to the runner as its input stream
   * @param expectedOutput exact text the runner must have written to its output stream
   */
  def runSQLTest(input: String, expectedOutput: String): Unit = {
    val sink = new ByteArrayOutputStream()
    val source: InputStream = new ByteArrayInputStream(input.getBytes())

    SQLRunner.sql(source, sink)

    sink.toString should be(expectedOutput)
  }

  // A single statement: the runner is expected to write "1" followed by a newline.
  test("can run dummy query") {
    runSQLTest("SELECT 1;", "1\n")
  }

  // Three statements spread over two lines; one output line is expected per statement.
  test("can run multiple dummy queries") {
    val expected = "1\n2\n3\n"
    val statements = """
        |SELECT 1;SELECT 2;
        |SELECT 3;
      """.stripMargin

    runSQLTest(statements, expected)
  }

  // Queries DEMO_TABLE twice (all rows, then LIMIT 1) and drops it; the DROP is
  // expected to contribute nothing to the output.
  // NOTE(review): the input never creates DEMO_TABLE; presumably it is set up
  // elsewhere in the suite - confirm.
  test("can run a basic example with tables") {
    val expected = "1,a\n2,b\n3,c\n1,a\n"
    val statements = """
                  |SELECT * FROM DEMO_TABLE;
                  |SELECT * FROM DEMO_TABLE LIMIT 1;
                  |DROP TABLE DEMO_TABLE;
                """.stripMargin

    runSQLTest(statements, expected)
  }

  // Same statements and expected output as the plain table example, but with a
  // trailing `--` comment and a comment-only line mixed into the input.
  // NOTE(review): the input never creates DEMO_TABLE; presumably it is set up
  // elsewhere in the suite - confirm.
  test("can run an example with comments") {
    val expected = "1,a\n2,b\n3,c\n1,a\n"
    val statements = """
                  |SELECT * FROM DEMO_TABLE; -- this is the first query
                  |SELECT * FROM DEMO_TABLE LIMIT 1;
                  |-- now let's drop a table
                  |DROP TABLE DEMO_TABLE;
                """.stripMargin

    runSQLTest(statements, expected)
  }
} 
Example 2
Source File: RawStageTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.driver.test.stage

import akka.actor.ActorSystem
import akka.testkit.TestKit
import com.stratio.sparta.driver.stage.{LogError, RawDataStage}
import com.stratio.sparta.sdk.pipeline.autoCalculations.AutoCalculatedField
import com.stratio.sparta.sdk.properties.JsoneyString
import com.stratio.sparta.serving.core.models.policy.writer.{AutoCalculatedFieldModel, WriterModel}
import com.stratio.sparta.serving.core.models.policy.{PolicyModel, RawDataModel}
import org.junit.runner.RunWith
import org.mockito.Mockito.when
import org.scalatest.junit.JUnitRunner
import org.scalatest.mock.MockitoSugar
import org.scalatest.{FlatSpecLike, ShouldMatchers}

/**
 * Tests `RawDataStage.rawDataStage()` against Mockito-mocked policy models.
 *
 * The suite extends `TestKit`, which creates an `ActorSystem`; that system is
 * shut down in `afterAll` so its threads do not outlive the suite.
 */
@RunWith(classOf[JUnitRunner])
class RawStageTest
  extends TestKit(ActorSystem("RawStageTest"))
    with FlatSpecLike
    with ShouldMatchers
    with MockitoSugar
    with org.scalatest.BeforeAndAfterAll {

  /** Minimal concrete stage under test: only supplies the policy. */
  case class TestRawData(policy: PolicyModel) extends RawDataStage with LogError

  /** Terminates the actor system created in the constructor once every test has run. */
  override protected def afterAll(): Unit = {
    try TestKit.shutdownActorSystem(system)
    finally super.afterAll()
  }

  /** Builds a mocked policy whose `id` is `Some("id")`. */
  def mockPolicy: PolicyModel = {
    val policy = mock[PolicyModel]
    when(policy.id).thenReturn(Some("id"))
    policy
  }

  "rawDataStage" should "Generate a raw data" in {
    val field = "field"
    val timeField = "time"
    val tableName = Some("table")
    val outputs = Seq("output")
    val partitionBy = Some("field")
    val autocalculateFields = Seq(AutoCalculatedFieldModel())
    val configuration = Map.empty[String, JsoneyString]

    val policy = mockPolicy
    val rawData = mock[RawDataModel]
    val writerModel = mock[WriterModel]

    when(policy.rawData).thenReturn(Some(rawData))
    when(rawData.dataField).thenReturn(field)
    when(rawData.timeField).thenReturn(timeField)
    when(rawData.writer).thenReturn(writerModel)
    when(writerModel.tableName).thenReturn(tableName)
    when(writerModel.outputs).thenReturn(outputs)
    when(writerModel.partitionBy).thenReturn(partitionBy)
    // Stubbed so the stage can read it; the result's auto-calculated fields are not asserted.
    when(writerModel.autoCalculatedFields).thenReturn(autocalculateFields)
    when(rawData.configuration).thenReturn(configuration)

    val result = TestRawData(policy).rawDataStage()

    result.timeField should be(timeField)
    result.dataField should be(field)
    result.writerOptions.tableName should be(tableName)
    result.writerOptions.partitionBy should be(partitionBy)
    result.configuration should be(configuration)
    result.writerOptions.outputs should be(outputs)
  }

  "rawDataStage" should "Fail with bad table name" in {
    val field = "field"
    val timeField = "time"
    // The only difference from the successful case that matters here: no table name.
    val tableName = None
    val outputs = Seq("output")
    val partitionBy = Some("field")
    val configuration = Map.empty[String, JsoneyString]

    val policy = mockPolicy
    val rawData = mock[RawDataModel]
    val writerModel = mock[WriterModel]

    when(policy.rawData).thenReturn(Some(rawData))
    when(rawData.dataField).thenReturn(field)
    when(rawData.timeField).thenReturn(timeField)
    when(rawData.writer).thenReturn(writerModel)
    when(writerModel.tableName).thenReturn(tableName)
    when(writerModel.outputs).thenReturn(outputs)
    when(writerModel.partitionBy).thenReturn(partitionBy)
    when(rawData.configuration).thenReturn(configuration)

    the[IllegalArgumentException] thrownBy {
      TestRawData(policy).rawDataStage()
    } should have message "Something gone wrong saving the raw data. Please re-check the policy."
  }

}
Example 3
Source File: NettyBlockTransferSecuritySuite.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.network.netty

import java.nio._
import java.util.concurrent.TimeUnit

import scala.concurrent.duration._
import scala.concurrent.{Await, Promise}
import scala.util.{Failure, Success, Try}

import org.apache.commons.io.IOUtils
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.network.{BlockDataManager, BlockTransferService}
import org.apache.spark.storage.{BlockId, ShuffleBlockId}
import org.apache.spark.{SecurityManager, SparkConf}
import org.mockito.Mockito._
import org.scalatest.mock.MockitoSugar
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMatchers}

/**
 * Exercises `testConnection` with pairs of `SparkConf`s whose authentication
 * settings either agree or disagree, and checks whether the connection attempt
 * succeeds or fails as each test name states.
 */
class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers {
  test("security default off") {
    val plainConf = new SparkConf().set("spark.app.id", "app-id")
    testConnection(plainConf, plainConf) match {
      case Failure(cause) => fail(cause)
      case Success(_) => // expected
    }
  }

  test("security on same password") {
    val securedConf = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "good")
      .set("spark.app.id", "app-id")
    testConnection(securedConf, securedConf) match {
      case Failure(cause) => fail(cause)
      case Success(_) => // expected
    }
  }

  test("security on mismatch password") {
    val baseConf = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "good")
      .set("spark.app.id", "app-id")
    // Same settings, different secret.
    val alteredConf = baseConf.clone.set("spark.authenticate.secret", "bad")
    testConnection(baseConf, alteredConf) match {
      case Failure(cause) => cause.getMessage should include ("Mismatched response")
      case Success(_) => fail("Should have failed")
    }
  }

  test("security mismatch auth off on server") {
    val baseConf = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "good")
      .set("spark.app.id", "app-id")
    val alteredConf = baseConf.clone.set("spark.authenticate", "false")
    testConnection(baseConf, alteredConf) match {
      case Failure(_) => // any funny error may occur, server will interpret SASL token as RPC
      case Success(_) => fail("Should have failed")
    }
  }

  test("security mismatch auth off on client") {
    val baseConf = new SparkConf()
      .set("spark.authenticate", "false")
      .set("spark.authenticate.secret", "good")
      .set("spark.app.id", "app-id")
    val alteredConf = baseConf.clone.set("spark.authenticate", "true")
    testConnection(baseConf, alteredConf) match {
      case Failure(cause) => cause.getMessage should include ("Expected SaslMessage")
      case Success(_) => fail("Should have failed")
    }
  }


  /**
   * Fetches a single block through `self` from the service `from` and waits for the outcome.
   *
   * The listener completes a promise with the retained buffer on success or with
   * the reported exception on failure. The wait is limited to 1000 ms;
   * `Await.ready` throws if the promise is still pending after that.
   *
   * @return the completed result of the fetch as a `Try`
   */
  private def fetchBlock(
      self: BlockTransferService,
      from: BlockTransferService,
      execId: String,
      blockId: BlockId): Try[ManagedBuffer] = {

    val outcome = Promise[ManagedBuffer]()

    val listener = new BlockFetchingListener {
      override def onBlockFetchSuccess(fetchedId: String, data: ManagedBuffer): Unit = {
        outcome.success(data.retain())
      }

      override def onBlockFetchFailure(failedId: String, exception: Throwable): Unit = {
        outcome.failure(exception)
      }
    }

    self.fetchBlocks(from.hostName, from.port, execId, Array(blockId.toString), listener)

    // Await.ready returns the (now completed) future itself, so `value` is defined here.
    val settled = Await.ready(outcome.future, 1000.milliseconds)
    settled.value.get
  }
}
Example 4
Source File: AudienceAnalyticsSpec.scala    From spark-hyperloglog   with MIT License 5 votes vote down vote up
package com.collective.analytics

import com.collective.analytics.schema.ImpressionLog
import org.apache.spark.sql.Row
import org.scalatest.{FlatSpec, ShouldMatchers}


/**
 * Runs the shared audience analytics spec against `SparkAudienceAnalytics`,
 * backed by a DataFrame built from the generated rows.
 */
class SparkAudienceAnalyticsSpec extends AudienceAnalyticsSpec with EmbeddedSparkContext {
  def builder: Vector[Row] => AudienceAnalytics = { rows =>
    val impressionFrame = sqlContext.createDataFrame(sc.parallelize(rows), ImpressionLog.schema)
    new SparkAudienceAnalytics(new AggregateImpressionLog(impressionFrame))
  }
}

/**
 * Runs the shared audience analytics spec against `InMemoryAudienceAnalytics`,
 * fed with rows parsed by `ImpressionLog.parse`.
 */
class InMemoryAudienceAnalyticsSpec extends AudienceAnalyticsSpec {
  def builder: Vector[Row] => AudienceAnalytics = { rows =>
    val parsed = rows.map(ImpressionLog.parse)
    new InMemoryAudienceAnalytics(parsed)
  }
}

/**
 * Shared behaviour spec for `AudienceAnalytics` implementations.
 *
 * Subclasses supply `builder`, which turns the generated impression rows into
 * the implementation under test.
 */
abstract class AudienceAnalyticsSpec extends FlatSpec with ShouldMatchers with DataGenerator {

  /** Creates the analytics implementation under test from raw impression rows. */
  def builder: Vector[Row] => AudienceAnalytics

  // Four groups of 100 rows each: three "bmw" groups and one "audi" group.
  // Lazy for the same reason as `analytics`: nothing is computed while the class is constructed.
  private lazy val impressions =
    repeat(100, impressionRow("bmw", "forbes.com", 10L, 1L, Array("income:50000", "education:high-school", "interest:technology"))) ++
    repeat(100, impressionRow("bmw", "forbes.com", 5L, 2L, Array("income:50000", "education:college", "interest:auto"))) ++
    repeat(100, impressionRow("bmw", "auto.com", 7L, 0L, Array("income:100000", "education:high-school", "interest:auto"))) ++
    repeat(100, impressionRow("audi", "cnn.com", 2L, 0L, Array("income:50000", "interest:audi", "education:high-school")))

  // `builder` is abstract and implemented by subclasses. Calling it from an eager
  // val would run it while this superclass is still being constructed, i.e. before
  // the subclass (and any traits it mixes in) have been initialised. Deferring to
  // first use avoids that.
  private lazy val analytics = builder(impressions)

  "InMemoryAudienceAnalytics" should "compute audience estimate" in {
    val bmwEstimate = analytics.audienceEstimate(Vector("bmw"))
    assert(bmwEstimate.cookiesHLL.size() == 3 * 100)
    assert(bmwEstimate.impressions == 22 * 100)
    assert(bmwEstimate.clicks == 3 * 100)

    val forbesEstimate = analytics.audienceEstimate(sites = Vector("forbes.com"))
    assert(forbesEstimate.cookiesHLL.size() == 2 * 100)
    assert(forbesEstimate.impressions == 15 * 100)
    assert(forbesEstimate.clicks == 3 * 100)
  }

  it should "compute segment estimate" in {
    val fiftyK = analytics.segmentsEstimate(Vector("income:50000"))
    assert(fiftyK.cookiesHLL.size() == 3 * 100)
    assert(fiftyK.impressions == 17 * 100)
    assert(fiftyK.clicks == 3 * 100)

    val highSchool = analytics.segmentsEstimate(Vector("education:high-school"))
    assert(highSchool.cookiesHLL.size() == 3 * 100)
    assert(highSchool.impressions == 19 * 100)
    assert(highSchool.clicks == 1 * 100)
  }

  it should "compute audience intersection" in {
    val bmwAudience = analytics.audienceEstimate(Vector("bmw"))
    val intersection = analytics.segmentsIntersection(bmwAudience).toMap

    assert(intersection.size == 7)
    assert(intersection("interest:audi") == 0)
    // Only a wide range is asserted here rather than an exact count.
    intersection("income:50000") should (be >= 180L and be <= 2020L)
  }

}
Example 5
Source File: FillSuite.scala    From spark-timeseries   with Apache License 2.0 5 votes vote down vote up
package com.cloudera.sparkts

import scala.Double.NaN

import com.cloudera.sparkts.UnivariateTimeSeries._

import org.scalatest.{FunSuite, ShouldMatchers}

/**
 * Checks the NaN-filling functions imported from `UnivariateTimeSeries`.
 * Each test lists (input, expected) pairs and verifies them in order.
 */
class FillSuite extends FunSuite with ShouldMatchers {
  // Registered with `ignore`, so this case is reported but not executed.
  ignore("nearest") {
    val cases = Seq(
      Array(1.0) -> Array(1.0),
      Array(1.0, 1.0, 2.0) -> Array(1.0, 1.0, 2.0),
      Array(1.0, NaN, NaN, 2.0) -> Array(1.0, 1.0, 2.0, 2.0),
      // round down to previous
      Array(1.0, NaN, 2.0) -> Array(1.0, 1.0, 2.0),
      Array(1.0, NaN, NaN, NaN, 2.0) -> Array(1.0, 1.0, 1.0, 2.0, 2.0),
      Array(1.0, NaN, 3.0, NaN, 2.0) -> Array(1.0, 1.0, 3.0, 3.0, 2.0)
    )
    for ((input, expected) <- cases) {
      fillNearest(input) should be (expected)
    }
  }

  test("previous") {
    val cases = Seq(
      Array(1.0) -> Array(1.0),
      Array(1.0, 1.0, 2.0) -> Array(1.0, 1.0, 2.0),
      Array(1.0, NaN, 2.0) -> Array(1.0, 1.0, 2.0),
      Array(1.0, NaN, NaN, 2.0) -> Array(1.0, 1.0, 1.0, 2.0),
      Array(1.0, NaN, NaN, NaN, 2.0) -> Array(1.0, 1.0, 1.0, 1.0, 2.0),
      Array(1.0, NaN, 3.0, NaN, 2.0) -> Array(1.0, 1.0, 3.0, 3.0, 2.0)
    )
    for ((input, expected) <- cases) {
      fillPrevious(input) should be (expected)
    }
  }

  test("next") {
    val cases = Seq(
      Array(1.0) -> Array(1.0),
      Array(1.0, 1.0, 2.0) -> Array(1.0, 1.0, 2.0),
      Array(1.0, NaN, 2.0) -> Array(1.0, 2.0, 2.0),
      Array(1.0, NaN, NaN, 2.0) -> Array(1.0, 2.0, 2.0, 2.0),
      Array(1.0, NaN, NaN, NaN, 2.0) -> Array(1.0, 2.0, 2.0, 2.0, 2.0),
      Array(1.0, NaN, 3.0, NaN, 2.0) -> Array(1.0, 3.0, 3.0, 2.0, 2.0)
    )
    for ((input, expected) <- cases) {
      fillNext(input) should be (expected)
    }
  }

  test("linear") {
    val cases = Seq(
      Array(1.0) -> Array(1.0),
      Array(1.0, 1.0, 2.0) -> Array(1.0, 1.0, 2.0),
      Array(1.0, NaN, 2.0) -> Array(1.0, 1.5, 2.0),
      Array(2.0, NaN, 1.0) -> Array(2.0, 1.5, 1.0),
      Array(1.0, NaN, NaN, 4.0) -> Array(1.0, 2.0, 3.0, 4.0),
      Array(1.0, NaN, NaN, NaN, 5.0) -> Array(1.0, 2.0, 3.0, 4.0, 5.0),
      Array(1.0, NaN, 3.0, NaN, 2.0) -> Array(1.0, 2.0, 3.0, 2.5, 2.0)
    )
    for ((input, expected) <- cases) {
      fillLinear(input) should be (expected)
    }
  }
}
Example 6
Source File: EWMASuite.scala    From spark-timeseries   with Apache License 2.0 5 votes vote down vote up
package com.cloudera.sparkts.models

import org.apache.spark.mllib.linalg._
import org.scalatest.{FunSuite, ShouldMatchers}

/** Tests for `EWMAModel` smoothing, its inverse, and `EWMA.fitModel`. */
class EWMASuite extends FunSuite with ShouldMatchers {
  test("adding time dependent effects") {
    // The series 1.0, 2.0, ..., 10.0
    val orig = new DenseVector(Array.tabulate(10)(i => (i + 1).toDouble))

    // (smoothing parameter, expected last smoothed value rounded to two decimals)
    for ((alpha, expectedLast) <- Seq(0.2 -> 6.54, 0.6 -> 9.33)) {
      val model = new EWMAModel(alpha)
      val smoothed = new DenseVector(new Array[Double](10))
      model.addTimeDependentEffects(orig, smoothed)

      smoothed(0) should be (orig(0))
      smoothed(1) should be (model.smoothing * orig(1) + (1 - model.smoothing) * smoothed(0))
      round2Dec(smoothed.toArray.last) should be (expectedLast)
    }
  }

  test("removing time dependent effects") {
    val smoothed = new DenseVector(Array(1.0, 1.2, 1.56, 2.05, 2.64, 3.31, 4.05, 4.84, 5.67, 6.54))

    val model = new EWMAModel(0.2)
    val recovered = new DenseVector(new Array[Double](10))
    model.removeTimeDependentEffects(smoothed, recovered)

    round2Dec(recovered(0)) should be (1.0)
    recovered.toArray.last.toInt should be(10)
  }

  test("fitting EWMA model") {
    // We reproduce the example in ch 7.1 from
    // https://www.otexts.org/fpp/7/1
    val oil = Array(446.7, 454.5, 455.7, 423.6, 456.3, 440.6, 425.3, 485.1, 506.0, 526.8,
      514.3, 494.2)
    val fitted = EWMA.fitModel(new DenseVector(oil))
    // Compare the smoothing parameter truncated to two decimals: approximately 0.89
    (fitted.smoothing * 100.0).toInt should be (89)
  }

  /** Rounds `x` to two decimal places. */
  private def round2Dec(x: Double): Double =
    math.round(x * 100) / 100.00
}
Example 7
Source File: RegressionARIMASuite.scala    From spark-timeseries   with Apache License 2.0 5 votes vote down vote up
package com.cloudera.sparkts.models

import breeze.linalg
import breeze.linalg.DenseMatrix
import org.scalatest.{FunSuite, ShouldMatchers}

/** Checks `RegressionARIMA.fitCochraneOrcutt` on a small expenditure/stock data set. */
class RegressionARIMASuite extends FunSuite with ShouldMatchers {

  test("Cochrane-Orcutt-Stock-Data") {
    val stock = Array(159.3, 161.2, 162.8, 164.6, 165.9, 167.9, 168.3, 169.7, 170.5, 171.6, 173.9,
      176.1, 178.0, 179.1, 180.2, 181.2, 181.6, 182.5, 183.3, 184.3)

    val expenditure = Array(214.6, 217.7, 219.6, 227.2, 230.9, 233.3, 234.1, 232.3, 233.7, 236.5,
      238.7, 243.2, 249.4, 254.3, 260.9, 263.3, 265.6, 268.2, 270.4, 275.6)

    // Single-column regressor matrix holding the stock series.
    val regressors = new DenseMatrix[Double](stock.length, 1)
    regressors(::, 0) := linalg.DenseVector(stock)

    val response = linalg.DenseVector(expenditure)
    val fitted = RegressionARIMA.fitCochraneOrcutt(response, regressors, 11)

    val coefficients = fitted.regressionCoeff
    fitted.arimaCoeff(0) should equal(0.8241 +- 0.001)
    coefficients(0) should equal(-235.4889 +- 0.1)
    coefficients(1) should equal(2.75306 +- 0.001)
  }
}
Example 8
Source File: DateTimeIndexUtilsSuite.scala    From spark-timeseries   with Apache License 2.0 5 votes vote down vote up
package com.cloudera.sparkts

import java.time.{ZonedDateTime, ZoneId}

import com.cloudera.sparkts.DateTimeIndex._
import org.scalatest.{FunSuite, ShouldMatchers}

/**
 * Tests `DateTimeIndexUtils.union` on combinations of uniform and irregular
 * indices, comparing the result with a hand-built hybrid index.
 */
class DateTimeIndexUtilsSuite extends FunSuite with ShouldMatchers {
  val UTC = ZoneId.of("Z")

  /** Uniform index with a two-day frequency starting at `start` (yyyy-MM-dd). */
  private def everyOtherDay(start: String, periods: Int): DateTimeIndex =
    uniform(dt(start), periods, new DayFrequency(2), UTC)

  /** Irregular index over the given yyyy-MM-dd dates. */
  private def irregularDays(days: String*): DateTimeIndex =
    irregular(days.map(d => dt(d)).toArray, UTC)

  test("non-overlapping sorted") {
    val april = everyOtherDay("2015-04-10", 5)
    val may = everyOtherDay("2015-05-10", 5)
    val june = irregularDays(
      "2015-06-10", "2015-06-13", "2015-06-15", "2015-06-20", "2015-06-25")

    DateTimeIndexUtils.union(Array(april, may, june), UTC) should be (
      hybrid(Array(april, may, june)))
  }

  test("non-overlapping non-sorted") {
    val april = everyOtherDay("2015-04-10", 5)
    val may = everyOtherDay("2015-05-10", 5)
    val june = irregularDays(
      "2015-06-10", "2015-06-13", "2015-06-15", "2015-06-20", "2015-06-25")

    // Inputs are passed out of chronological order; the expected hybrid is in order.
    DateTimeIndexUtils.union(Array(june, april, may), UTC) should be (
      hybrid(Array(april, may, june)))
  }

  test("overlapping uniform and irregular") {
    val april = everyOtherDay("2015-04-10", 5)
    val may = everyOtherDay("2015-05-10", 5)
    val scattered = irregularDays(
      "2015-04-09", "2015-04-11", "2015-05-01", "2015-05-10", "2015-06-25")

    val expected = hybrid(Array(
      irregularDays("2015-04-09", "2015-04-10", "2015-04-11"),
      everyOtherDay("2015-04-12", 4),
      irregularDays("2015-05-01", "2015-05-10"),
      everyOtherDay("2015-05-12", 4),
      irregularDays("2015-06-25")
    ))

    DateTimeIndexUtils.union(Array(scattered, april, may), UTC) should be (expected)
  }

  /**
   * Parses a `year-month-day` string into midnight of that day in `zone`.
   *
   * @param dt   date text whose first three dash-separated parts are integers
   * @param zone time zone of the result, UTC by default
   */
  def dt(dt: String, zone: ZoneId = UTC): ZonedDateTime = {
    val parts = dt.split("-").map(_.toInt)
    val (year, month, day) = (parts(0), parts(1), parts(2))
    ZonedDateTime.of(year, month, day, 0, 0, 0, 0, zone)
  }
}
Example 9
Source File: BinnerSpec.scala    From modelmatrix   with Apache License 2.0 5 votes vote down vote up
package com.collective.modelmatrix.transform

import org.scalatest.{ShouldMatchers, FlatSpec}

/** Behaviour spec for the `Binner` mix-in: diff conversion and `optimalSplit`. */
class BinnerSpec extends FlatSpec with ShouldMatchers with Binner {

  "Binner" should "get from diff to original values" in {
    val increments = Seq(0.1, 0.21, 0.05, 0.5)
    assert(fromDiff(increments) == Seq(0.1, 0.31, 0.36, 0.86))
  }

  it should "get diff from original values" in {
    val cumulative = Seq(0.1, 0.31, 0.37, 0.88)
    assert(toDiff(cumulative) == Seq(0.1, 0.21, 0.06, 0.51))
  }

  // The evenly spaced case has the same shape for every bin count, so one test
  // is registered per count (same names and order as writing them out by hand).
  Seq(10, 2, 3).foreach { bins =>
    it should s"calculate perfect split of $bins bins" in {
      // 0.0, 1.0, ..., 100.0
      val x = Array.tabulate(101)(_.toDouble)

      val split = optimalSplit(x, bins, 0, 0)
      split.foreach { bin =>
        bin.count should be ((x.length / bins) +- 5)
      }
      assert(split.length == bins)
      assert(split.map(_.count).sum == x.length)
    }
  }

  it should "calculate perfect split for highly skewed data" in {

    // R: x <- exp(rnorm(1000))

    // Heavy right skewed data
    val standardNormal = breeze.stats.distributions.Gaussian(0, 1)
    val skewed = standardNormal.sample(1000).map(math.exp).toArray

    val bins = 10
    val split = optimalSplit(skewed, bins, 0, 0)
    split.foreach { bin =>
      bin.count should be((skewed.length / bins) +- 5)
    }
    assert(split.length == bins)
    assert(split.map(_.count).sum == skewed.length)
  }

}
Example 10
Source File: PlySuite.scala    From spark-iqmulus   with Apache License 2.0 5 votes vote down vote up
package fr.ign.spark.iqmulus.ply

import org.scalatest.FunSuite
import org.scalatest.ShouldMatchers
import org.apache.spark.sql.types._

/**
 * Registers header tests for each listed PLY file that exists under the test
 * resources directory; files that are missing register no tests.
 */
class PlySuite extends FunSuite with ShouldMatchers {

  val id = Array("fid" -> IntegerType, "pid" -> LongType)
  val xyz = Array("x" -> FloatType, "y" -> FloatType, "z" -> FloatType)
  val rgb = Array("r" -> ByteType, "g" -> ByteType, "b" -> ByteType)

  // (file name, expected vertex count, expected vertex fields)
  val files = Seq(
    ("trepied_xyz.ply", 5995, id ++ xyz) // ,
  //   ("trepied_dim.ply", 5995, id ++ xyz ++ rgb),
  //   ("trepied_dim2.ply", 5995, id ++ xyz ++ rgb),
  //   ("213-232-7.ply", 71651, id ++ xyz ++ rgb)
  )

  val resources = "src/test/resources"

  for {
    (file, count, fields) <- files
    path = s"$resources/$file"
    if new java.io.File(path).exists
  } {
    test(s"$file should read the correct header metadata") {
      val vertex = PlyHeader.read(path).section("vertex")
      vertex.count should equal(count)
    }

    test(s"$file should have the correct schema") {
      val expectedSchema = StructType(fields map {
        case (name, dataType) => StructField(name, dataType, nullable = false)
      })
      PlyHeader.read(path).section("vertex").schema should equal(expectedSchema)
    }
  }
}
Example 11
Source File: UtilTests.scala    From sparkplug   with MIT License 5 votes vote down vote up
package springnz.sparkplug
import springnz.sparkplug.util.SerializeUtils._

import org.scalatest.{ ShouldMatchers, WordSpec }

/** Simple value used as the payload of the serialisation round-trip tests. */
case class TestObject(
  a: String,
  b: Int,
  c: Vector[String],
  d: List[Int])

/**
 * Round-trip tests for `serialize` / `deserialize`: a `TestObject` is
 * serialised and read back with different class-loader arguments.
 */
class UtilTests extends WordSpec with ShouldMatchers {

  /** Fresh payload for each test. */
  private def sample: TestObject =
    TestObject("hello", 42, Vector("test", "array"), List(42, 108))

  "serialise utils" should {
    "serialise and deserialise a local object" in {
      val original = sample
      // Class loader of this test class.
      val restored = deserialize[TestObject](serialize(original), getClass.getClassLoader)

      restored should equal(original)
    }

    "serialise and deserialise a local object with its class loader" in {
      val original = sample
      // Class loader of the payload's companion object.
      val restored = deserialize[TestObject](serialize(original), TestObject.getClass.getClassLoader)

      restored should equal(original)
    }

    "serialise and deserialise a local object with default class loader" in {
      val original = sample
      // No explicit class loader argument.
      val restored = deserialize[TestObject](serialize(original))

      restored should equal(original)
    }
  }
}
Example 12
Source File: SecondaryPairDCFunctionsTest.scala    From spark-flow   with Apache License 2.0 5 votes vote down vote up
package com.bloomberg.sparkflow.dc

import com.bloomberg.sparkflow._
import com.holdenkarau.spark.testing.SharedSparkContext
import org.scalatest.{ShouldMatchers, FunSuite}


/** Tests `repartitionAndSecondarySortWithinPartitions` on a small keyed data set. */
class SecondaryPairDCFunctionsTest extends FunSuite with SharedSparkContext with ShouldMatchers {

  test("testRepartAndSort") {
    // Composite (letter, number) keys, deliberately out of order.
    val unsorted = Seq(
      (("a", 3), 0),
      (("b", 2), 0),
      (("b", 1), 0),
      (("b", 3), 0),
      (("a", 2), 0),
      (("a", 1), 0))

    // Collapse every partition into one list so partition contents can be compared.
    val partitions = parallelize(unsorted)
      .repartitionAndSecondarySortWithinPartitions(2)
      .mapPartitions(rows => Iterator(rows.toList))

    val expected = Seq(
      List((("a", 1), 0), (("a", 2), 0), (("a", 3), 0)),
      List((("b", 1), 0), (("b", 2), 0), (("b", 3), 0)))

    expected should contain theSameElementsAs partitions.getRDD(sc).collect()

  }

}