org.scalatest.junit.JUnitRunner Scala Examples

The following examples show how to use org.scalatest.junit.JUnitRunner. Each example is taken from an open-source project; the source file, project, and license are noted above it.
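
All of the examples share the same basic pattern: a ScalaTest suite (usually a WordSpec or WordSpecLike mixed with Matchers) is annotated with @RunWith(classOf[JUnitRunner]) so that JUnit, and any JUnit-based build tooling, can discover and run it. The minimal sketch below illustrates that pattern; the ExampleSpec class and its assertion are placeholders and are not taken from any of the projects listed here.

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

// Illustrative suite (not from any project below): @RunWith delegates execution
// of this ScalaTest WordSpec to ScalaTest's JUnitRunner, so plain JUnit can run it.
@RunWith(classOf[JUnitRunner])
class ExampleSpec extends WordSpec with Matchers {

  "A suite annotated with JUnitRunner" should {
    "run its ScalaTest assertions under JUnit" in {
      (1 + 1) should be(2)
    }
  }
}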
Example 1
Source File: RegressITCase.scala    From flink-tensorflow   with Apache License 2.0
package org.apache.flink.contrib.tensorflow.ml

import com.twitter.bijection.Conversion._
import org.apache.flink.api.common.functions.RichFlatMapFunction
import org.apache.flink.api.scala._
import org.apache.flink.configuration.Configuration
import org.apache.flink.contrib.tensorflow.ml.signatures.RegressionMethod._
import org.apache.flink.contrib.tensorflow.types.TensorInjections.{message2Tensor, messages2Tensor}
import org.apache.flink.contrib.tensorflow.util.TestData._
import org.apache.flink.contrib.tensorflow.util.{FlinkTestBase, RegistrationUtils}
import org.apache.flink.core.fs.Path
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.util.Collector
import org.apache.flink.util.Preconditions.checkState
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpecLike}
import org.tensorflow.Tensor
import org.tensorflow.contrib.scala.Arrays._
import org.tensorflow.contrib.scala.Rank._
import org.tensorflow.contrib.scala._
import org.tensorflow.example.Example
import resource._

@RunWith(classOf[JUnitRunner])
class RegressITCase extends WordSpecLike
  with Matchers
  with FlinkTestBase {

  override val parallelism = 1

  type LabeledExample = (Example, Float)

  def examples(): Seq[LabeledExample] = {
    for (v <- Seq(0.0f -> 2.0f, 1.0f -> 2.5f, 2.0f -> 3.0f, 3.0f -> 3.5f))
      yield (example("x" -> feature(v._1)), v._2)
  }

  "A RegressFunction" should {
    "process elements" in {
      val env = StreamExecutionEnvironment.getExecutionEnvironment
      RegistrationUtils.registerTypes(env.getConfig)

      val model = new HalfPlusTwo(new Path("../models/half_plus_two"))

      val outputs = env
        .fromCollection(examples())
        .flatMap(new RichFlatMapFunction[LabeledExample, Float] {
          override def open(parameters: Configuration): Unit = model.open()
          override def close(): Unit = model.close()

          override def flatMap(value: (Example, Float), out: Collector[Float]): Unit = {
            for {
              x <- managed(Seq(value._1).toList.as[Tensor].taggedAs[ExampleTensor])
              y <- model.regress_x_to_y(x)
            } {
              // cast as a 1D tensor to use the available conversion
              val o = y.taggedAs[TypedTensor[`1D`,Float]].as[Array[Float]]
              val actual = o(0)
              checkState(actual == value._2)
              out.collect(actual)
            }
          }
        })
        .print()

      env.execute()
    }
  }
} 
Example 2
Source File: DriverActorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.serving.api.actor

import java.nio.file.{Files, Path}

import akka.actor.{ActorSystem, Props}
import akka.testkit.{DefaultTimeout, ImplicitSender, TestKit}
import akka.util.Timeout
import com.stratio.sparta.serving.api.actor.DriverActor.UploadDrivers
import com.stratio.sparta.serving.core.config.{SpartaConfig, SpartaConfigFactory}
import com.stratio.sparta.serving.core.models.SpartaSerializer
import com.stratio.sparta.serving.core.models.files.{SpartaFile, SpartaFilesResponse}
import com.typesafe.config.{Config, ConfigFactory}
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner
import org.scalatest.mock.MockitoSugar
import spray.http.BodyPart

import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.{Failure, Success}

@RunWith(classOf[JUnitRunner])
class DriverActorTest extends TestKit(ActorSystem("PluginActorSpec"))
  with DefaultTimeout
  with ImplicitSender
  with WordSpecLike
  with Matchers
  with BeforeAndAfterAll
  with BeforeAndAfterEach
  with MockitoSugar with SpartaSerializer {

  val tempDir: Path = Files.createTempDirectory("test")
  tempDir.toFile.deleteOnExit()

  val localConfig: Config = ConfigFactory.parseString(
    s"""
       |sparta{
       |   api {
       |     host = local
       |     port= 7777
       |   }
       |}
       |
       |sparta.config.driverPackageLocation = "$tempDir"
    """.stripMargin)

  val fileList = Seq(BodyPart("reference.conf", "file"))

  override def beforeEach(): Unit = {
    SpartaConfig.initMainConfig(Option(localConfig), SpartaConfigFactory(localConfig))
    SpartaConfig.initApiConfig()
  }

  override def afterAll: Unit = {
    shutdown()
  }

  override implicit val timeout: Timeout = Timeout(15 seconds)

  "DriverActor " must {

    "Not save files with wrong extension" in {
      val driverActor = system.actorOf(Props(new DriverActor()))
      driverActor ! UploadDrivers(fileList)
      expectMsgPF() {
        case SpartaFilesResponse(Success(f: Seq[SpartaFile])) => f.isEmpty shouldBe true
      }
    }
    "Not upload empty files" in {
      val driverActor = system.actorOf(Props(new DriverActor()))
      driverActor ! UploadDrivers(Seq.empty)
      expectMsgPF() {
        case SpartaFilesResponse(Failure(f)) => f.getMessage shouldBe "At least one file is expected"
      }
    }
    "Save a file" in {
      val driverActor = system.actorOf(Props(new DriverActor()))
      driverActor ! UploadDrivers(Seq(BodyPart("reference.conf", "file.jar")))
      expectMsgPF() {
        case SpartaFilesResponse(Success(f: Seq[SpartaFile])) => f.head.fileName.endsWith("file.jar") shouldBe true
      }
    }
  }
} 
Example 3
Source File: PluginActorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.serving.api.actor

import java.nio.file.{Files, Path}

import akka.actor.{ActorSystem, Props}
import akka.testkit.{DefaultTimeout, ImplicitSender, TestKit}
import akka.util.Timeout
import com.stratio.sparta.serving.api.actor.PluginActor.{PluginResponse, UploadPlugins}
import com.stratio.sparta.serving.api.constants.HttpConstant
import com.stratio.sparta.serving.core.config.{SpartaConfig, SpartaConfigFactory}
import com.stratio.sparta.serving.core.models.SpartaSerializer
import com.stratio.sparta.serving.core.models.files.{SpartaFile, SpartaFilesResponse}
import com.typesafe.config.{Config, ConfigFactory}
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner
import org.scalatest.mock.MockitoSugar
import spray.http.BodyPart

import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.{Failure, Success}

@RunWith(classOf[JUnitRunner])
class PluginActorTest extends TestKit(ActorSystem("PluginActorSpec"))
  with DefaultTimeout
  with ImplicitSender
  with WordSpecLike
  with Matchers
  with BeforeAndAfterAll
  with BeforeAndAfterEach
  with MockitoSugar with SpartaSerializer {

  val tempDir: Path = Files.createTempDirectory("test")
  tempDir.toFile.deleteOnExit()

  val localConfig: Config = ConfigFactory.parseString(
    s"""
       |sparta{
       |   api {
       |     host = local
       |     port= 7777
       |   }
       |}
       |
       |sparta.config.pluginPackageLocation = "$tempDir"
    """.stripMargin)


  val fileList = Seq(BodyPart("reference.conf", "file"))

  override def beforeEach(): Unit = {
    SpartaConfig.initMainConfig(Option(localConfig), SpartaConfigFactory(localConfig))
    SpartaConfig.initApiConfig()
  }

  override def afterAll: Unit = {
    shutdown()
  }

  override implicit val timeout: Timeout = Timeout(15 seconds)

  "PluginActor " must {

    "Not save files with wrong extension" in {
      val pluginActor = system.actorOf(Props(new PluginActor()))
      pluginActor ! UploadPlugins(fileList)
      expectMsgPF() {
        case SpartaFilesResponse(Success(f: Seq[SpartaFile])) => f.isEmpty shouldBe true
      }
    }
    "Not upload empty files" in {
      val pluginActor = system.actorOf(Props(new PluginActor()))
      pluginActor ! UploadPlugins(Seq.empty)
      expectMsgPF() {
        case SpartaFilesResponse(Failure(f)) => f.getMessage shouldBe "At least one file is expected"
      }
    }
    "Save a file" in {
      val pluginActor = system.actorOf(Props(new PluginActor()))
      pluginActor ! UploadPlugins(Seq(BodyPart("reference.conf", "file.jar")))
      expectMsgPF() {
        case SpartaFilesResponse(Success(f: Seq[SpartaFile])) => f.head.fileName.endsWith("file.jar") shouldBe true
      }
    }
  }

} 
Example 4
Source File: CustomExceptionHandlerTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.serving.api.service.handler

import akka.actor.ActorSystem
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}
import spray.http.StatusCodes
import spray.httpx.Json4sJacksonSupport
import spray.routing.{Directives, HttpService, StandardRoute}
import spray.testkit.ScalatestRouteTest
import com.stratio.sparta.sdk.exception.MockException
import com.stratio.sparta.serving.api.service.handler.CustomExceptionHandler._
import com.stratio.sparta.serving.core.exception.ServingCoreException
import com.stratio.sparta.serving.core.models.{ErrorModel, SpartaSerializer}

@RunWith(classOf[JUnitRunner])
class CustomExceptionHandlerTest extends WordSpec
with Directives with ScalatestRouteTest with Matchers
with Json4sJacksonSupport with HttpService with SpartaSerializer {

  def actorRefFactory: ActorSystem = system

  trait MyTestRoute {

    val exception: Throwable
    val route: StandardRoute = complete(throw exception)
  }

  def route(throwable: Throwable): StandardRoute = complete(throw throwable)

  "CustomExceptionHandler" should {
    "encapsulate a unknow error in an error model and response with a 500 code" in new MyTestRoute {
      val exception = new MockException
      Get() ~> sealRoute(route) ~> check {
        status should be(StatusCodes.InternalServerError)
        response.entity.asString should be(ErrorModel.toString(new ErrorModel("666", "unknown")))
      }
    }
    "encapsulate a serving api error in an error model and response with a 400 code" in new MyTestRoute {
      val exception = ServingCoreException.create(ErrorModel.toString(new ErrorModel("333", "testing exception")))
      Get() ~> sealRoute(route) ~> check {
        status should be(StatusCodes.NotFound)
        response.entity.asString should be(ErrorModel.toString(new ErrorModel("333", "testing exception")))
      }
    }
  }
} 
Example 5
Source File: ConfigHttpServiceTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.serving.api.service.http

import akka.actor.ActorRef
import akka.testkit.TestProbe
import com.stratio.sparta.serving.api.actor.ConfigActor
import com.stratio.sparta.serving.api.actor.ConfigActor._
import com.stratio.sparta.serving.api.constants.HttpConstant
import com.stratio.sparta.serving.core.config.{SpartaConfig, SpartaConfigFactory}
import com.stratio.sparta.serving.core.constants.{AkkaConstant, AppConstant}
import com.stratio.sparta.serving.core.models.dto.LoggedUserConstant
import com.stratio.sparta.serving.core.models.frontend.FrontendConfiguration
import org.junit.runner.RunWith
import org.scalatest.WordSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ConfigHttpServiceTest extends WordSpec
  with ConfigHttpService
  with HttpServiceBaseTest {

  val configActorTestProbe = TestProbe()

  val dummyUser = Some(LoggedUserConstant.AnonymousUser)

  override implicit val actors: Map[String, ActorRef] = Map(
    AkkaConstant.ConfigActorName -> configActorTestProbe.ref
  )

  override val supervisor: ActorRef = testProbe.ref

  override def beforeEach(): Unit = {
    SpartaConfig.initMainConfig(Option(localConfig), SpartaConfigFactory(localConfig))
  }

  protected def retrieveStringConfig(): FrontendConfiguration =
    FrontendConfiguration(AppConstant.DefaultFrontEndTimeout, Option(AppConstant.DefaultOauth2CookieName))

  "ConfigHttpService.FindAll" should {
    "retrieve a FrontendConfiguration item" in {
      startAutopilot(ConfigResponse(retrieveStringConfig()))
      Get(s"/${HttpConstant.ConfigPath}") ~> routes(dummyUser) ~> check {
        testProbe.expectMsgType[ConfigActor.FindAll.type]
        responseAs[FrontendConfiguration] should equal(retrieveStringConfig())
      }
    }
  }

} 
Example 6
Source File: AppStatusHttpServiceTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.serving.api.service.http

import akka.actor.ActorRef
import com.stratio.sparta.serving.api.constants.HttpConstant
import org.apache.curator.framework.CuratorFramework
import org.junit.runner.RunWith
import org.scalamock.scalatest.MockFactory
import org.scalatest.WordSpec
import org.scalatest.junit.JUnitRunner
import spray.http.StatusCodes

@RunWith(classOf[JUnitRunner])
class AppStatusHttpServiceTest extends WordSpec
  with AppStatusHttpService
  with HttpServiceBaseTest
  with MockFactory {

  override implicit val actors: Map[String, ActorRef] = Map()
  override val supervisor: ActorRef = testProbe.ref
  override val curatorInstance = mock[CuratorFramework]

  "AppStatusHttpService" should {
    "check the status of the server" in {
      Get(s"/${HttpConstant.AppStatus}") ~> routes() ~> check {
        status should be (StatusCodes.InternalServerError)
      }
    }
  }
} 
Example 7
Source File: PluginsHttpServiceTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.serving.api.service.http

import akka.actor.ActorRef
import akka.testkit.TestProbe
import com.stratio.sparta.serving.api.actor.PluginActor.{PluginResponse, UploadPlugins}
import com.stratio.sparta.serving.api.constants.HttpConstant
import com.stratio.sparta.serving.core.config.{SpartaConfig, SpartaConfigFactory}
import com.stratio.sparta.serving.core.models.dto.LoggedUserConstant
import com.stratio.sparta.serving.core.models.files.{SpartaFile, SpartaFilesResponse}
import org.junit.runner.RunWith
import org.scalatest.WordSpec
import org.scalatest.junit.JUnitRunner
import spray.http._

import scala.util.{Failure, Success}

@RunWith(classOf[JUnitRunner])
class PluginsHttpServiceTest extends WordSpec
  with PluginsHttpService
  with HttpServiceBaseTest {

  override val supervisor: ActorRef = testProbe.ref

  val pluginTestProbe = TestProbe()

  val dummyUser = Some(LoggedUserConstant.AnonymousUser)

  override implicit val actors: Map[String, ActorRef] = Map.empty

  override def beforeEach(): Unit = {
    SpartaConfig.initMainConfig(Option(localConfig), SpartaConfigFactory(localConfig))
  }

  "PluginsHttpService.upload" should {
    "Upload a file" in {
      val response = SpartaFilesResponse(Success(Seq(SpartaFile("", "", "", ""))))
      startAutopilot(response)
      Put(s"/${HttpConstant.PluginsPath}") ~> routes(dummyUser) ~> check {
        testProbe.expectMsgType[UploadPlugins]
        status should be(StatusCodes.OK)
      }
    }
    "Fail when service is not available" in {
      val response = SpartaFilesResponse(Failure(new IllegalArgumentException("Error")))
      startAutopilot(response)
      Put(s"/${HttpConstant.PluginsPath}") ~> routes(dummyUser) ~> check {
        testProbe.expectMsgType[UploadPlugins]
        status should be(StatusCodes.InternalServerError)
      }
    }
  }
} 
Example 8
Source File: FileSystemOutputIT.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.output.filesystem

import java.io.File

import com.stratio.sparta.plugin.TemporalSparkContext
import com.stratio.sparta.plugin.output.fileSystem.FileSystemOutput
import com.stratio.sparta.sdk.pipeline.output.{Output, OutputFormatEnum, SaveModeEnum}
import org.apache.commons.io.FileUtils
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.junit.runner.RunWith
import org.scalatest.Matchers
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class FileSystemOutputIT extends TemporalSparkContext with Matchers {

  val directory = getClass().getResource("/origin.txt")
  val parentFile = new File(directory.getPath).getParent
  val properties = Map(("path", parentFile + "/testRow"), ("outputFormat", "row"))
  val fields = StructType(StructField("name", StringType, false) ::
    StructField("age", IntegerType, false) ::
    StructField("year", IntegerType, true) :: Nil)
  val fsm = new FileSystemOutput("key", properties)


  "An object of type FileSystemOutput " should "have the same values as the properties Map" in {
    fsm.outputFormat should be(OutputFormatEnum.ROW)
  }

  
  private def dfGen(): DataFrame = {
    val sqlCtx = SparkSession.builder().config(sc.getConf).getOrCreate()
    val dataRDD = sc.parallelize(List(("user1", 23, 1993), ("user2", 26, 1990), ("user3", 21, 1995)))
      .map { case (name, age, year) => Row(name, age, year) }

    sqlCtx.createDataFrame(dataRDD, fields)
  }

  def fileExists(path: String): Boolean = new File(path).exists()

  "Given a DataFrame, a directory" should "be created with the data written inside" in {
    fsm.save(dfGen(), SaveModeEnum.Append, Map(Output.TableNameKey -> "test"))
    fileExists(fsm.path.get) should equal(true)
  }

  it should "exist with the given path and be deleted" in {
    if (fileExists(fsm.path.get))
      FileUtils.deleteDirectory(new File(fsm.path.get))
    fileExists(fsm.path.get) should equal(false)
  }

  val fsm2 = new FileSystemOutput("key", properties.updated("outputFormat", "json")
    .updated("path", parentFile + "/testJson"))

  "Given another DataFrame, a directory" should "be created with the data inside in JSON format" in {
    fsm2.outputFormat should be(OutputFormatEnum.JSON)
    fsm2.save(dfGen(), SaveModeEnum.Append, Map(Output.TableNameKey -> "test"))
    fileExists(fsm2.path.get) should equal(true)
  }

  it should "exist with the given path and be deleted" in {
    if (fileExists(s"${fsm2.path.get}/test"))
      FileUtils.deleteDirectory(new File(s"${fsm2.path.get}/test"))
    fileExists(s"${fsm2.path.get}/test") should equal(false)
  }
} 
Example 9
Source File: AvroOutputIT.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.output.avro

import java.sql.Timestamp
import java.time.Instant

import com.databricks.spark.avro._
import com.stratio.sparta.plugin.TemporalSparkContext
import com.stratio.sparta.sdk.pipeline.output.{Output, SaveModeEnum}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

import scala.reflect.io.File
import scala.util.Random


@RunWith(classOf[JUnitRunner])
class AvroOutputIT extends TemporalSparkContext with Matchers {

  trait CommonValues {
    val tmpPath: String = File.makeTemp().name
    val sparkSession = SparkSession.builder().config(sc.getConf).getOrCreate()
    val schema = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType),
      StructField("minute", LongType)
    ))

    val data =
      sparkSession.createDataFrame(sc.parallelize(Seq(
        Row("Kevin", Random.nextInt, Timestamp.from(Instant.now).getTime),
        Row("Kira", Random.nextInt, Timestamp.from(Instant.now).getTime),
        Row("Ariadne", Random.nextInt, Timestamp.from(Instant.now).getTime)
      )), schema)
  }

  trait WithEventData extends CommonValues {
    val properties = Map("path" -> tmpPath)
    val output = new AvroOutput("avro-test", properties)
  }


  "AvroOutput" should "throw an exception when path is not present" in {
    an[Exception] should be thrownBy new AvroOutput("avro-test", Map.empty)
  }

  it should "throw an exception when empty path " in {
    an[Exception] should be thrownBy new AvroOutput("avro-test", Map("path" -> "    "))
  }

  it should "save a dataframe " in new WithEventData {
    output.save(data, SaveModeEnum.Append, Map(Output.TableNameKey -> "person"))
    val read = sparkSession.read.avro(s"$tmpPath/person")
    read.count should be(3)
    read should be eq data
    File(tmpPath).deleteRecursively
    File("spark-warehouse").deleteRecursively
  }

} 
Example 10
Source File: HttpOutputTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.output.http

import com.stratio.sparta.plugin.TemporalSparkContext
import com.stratio.sparta.sdk.pipeline.output.OutputFormatEnum
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.junit.runner.RunWith
import org.scalatest.Matchers
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class HttpOutputTest extends TemporalSparkContext with Matchers {

  val properties = Map(
    "url" -> "https://httpbin.org/post",
    "delimiter" -> ",",
    "parameterName" -> "thisIsAKeyName",
    "readTimeOut" -> "5000",
    "outputFormat" -> "ROW",
    "postType" -> "body",
    "connTimeout" -> "6000"
  )

  val fields = StructType(StructField("name", StringType, false) ::
    StructField("age", IntegerType, false) ::
    StructField("year", IntegerType, true) :: Nil)
  val OkHTTPResponse = 200

  "An object of type RestOutput " should "have the same values as the properties Map" in {
    val rest = new HttpOutput("key", properties)

    rest.outputFormat should be(OutputFormatEnum.ROW)
    rest.readTimeout should be(5000)
  }
  it should "throw a NoSuchElementException" in {
    val properties2 = properties.updated("postType", "vooooooody")
    a[NoSuchElementException] should be thrownBy {
      new HttpOutput("keyName", properties2)
    }
  }

  
  private def dfGen(): DataFrame = {
    val sqlCtx = SparkSession.builder().config(sc.getConf).getOrCreate()
    val dataRDD = sc.parallelize(List(("user1", 23, 1993), ("user2", 26, 1990))).map { case (name, age, year) =>
      Row(name, age, year)
    }
    sqlCtx.createDataFrame(dataRDD, fields)
  }

  val restMock1 = new HttpOutput("key", properties)
  "Given a DataFrame it" should "be parsed and send through a Raw data POST request" in {

    dfGen().collect().foreach(row => {
      assertResult(OkHTTPResponse)(restMock1.sendData(row.mkString(restMock1.delimiter)).code)
    })
  }

  it should "return the same amount of responses as rows in the DataFrame" in {
    val size = dfGen().collect().map(row => restMock1.sendData(row.mkString(restMock1.delimiter)).code).size
    assertResult(dfGen().count())(size)
  }

  val restMock2 = new HttpOutput("key", properties.updated("postType", "parameter"))
  it should "be parsed and send as a POST request along with a parameter stated by properties.parameterKey " in {
    dfGen().collect().foreach(row => {
      assertResult(OkHTTPResponse)(restMock2.sendData(row.mkString(restMock2.delimiter)).code)
    })
  }

  val restMock3 = new HttpOutput("key", properties.updated("outputFormat", "JSON"))
  "Given a DataFrame it" should "be sent as JSON through a Raw data POST request" in {

    dfGen().toJSON.collect().foreach(row => {
      assertResult(OkHTTPResponse)(restMock3.sendData(row).code)
    })
  }

  val restMock4 = new HttpOutput("key", properties.updated("postType", "parameter").updated("format", "JSON"))
  it should "sent as a POST request along with a parameter stated by properties.parameterKey " in {

    dfGen().toJSON.collect().foreach(row => {
      assertResult(OkHTTPResponse)(restMock4.sendData(row).code)
    })
  }
} 
Example 11
Source File: CassandraOutputTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.output.cassandra

import java.io.{Serializable => JSerializable}

import com.datastax.spark.connector.cql.CassandraConnector
import com.stratio.sparta.sdk._
import com.stratio.sparta.sdk.properties.JsoneyString
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.mock.MockitoSugar
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class CassandraOutputTest extends FlatSpec with Matchers with MockitoSugar with AnswerSugar {

  val s = "sum"
  val properties = Map(("connectionHost", "127.0.0.1"), ("connectionPort", "9042"))

  "getSparkConfiguration" should "return a Seq with the configuration" in {
    val configuration = Map(("connectionHost", "127.0.0.1"), ("connectionPort", "9042"))
    val cass = CassandraOutput.getSparkConfiguration(configuration)

    cass should be(List(("spark.cassandra.connection.host", "127.0.0.1"), ("spark.cassandra.connection.port", "9042")))
  }

  "getSparkConfiguration" should "return all cassandra-spark config" in {
    val config: Map[String, JSerializable] = Map(
      ("sparkProperties" -> JsoneyString(
        "[{\"sparkPropertyKey\":\"spark.cassandra.input.fetch.size_in_rows\",\"sparkPropertyValue\":\"2000\"}," +
          "{\"sparkPropertyKey\":\"spark.cassandra.input.split.size_in_mb\",\"sparkPropertyValue\":\"64\"}]")),
      ("anotherProperty" -> "true")
    )

    val sparkConfig = CassandraOutput.getSparkConfiguration(config)

    sparkConfig.exists(_ == ("spark.cassandra.input.fetch.size_in_rows" -> "2000")) should be(true)
    sparkConfig.exists(_ == ("spark.cassandra.input.split.size_in_mb" -> "64")) should be(true)
    sparkConfig.exists(_ == ("anotherProperty" -> "true")) should be(false)
  }

  "getSparkConfiguration" should "not return cassandra-spark config" in {
    val config: Map[String, JSerializable] = Map(
      ("hadoopProperties" -> JsoneyString(
        "[{\"sparkPropertyKey\":\"spark.cassandra.input.fetch.size_in_rows\",\"sparkPropertyValue\":\"2000\"}," +
          "{\"sparkPropertyKey\":\"spark.cassandra.input.split.size_in_mb\",\"sparkPropertyValue\":\"64\"}]")),
      ("anotherProperty" -> "true")
    )

    val sparkConfig = CassandraOutput.getSparkConfiguration(config)

    sparkConfig.exists(_ == ("spark.cassandra.input.fetch.size_in_rows" -> "2000")) should be(false)
    sparkConfig.exists(_ == ("spark.cassandra.input.split.size_in_mb" -> "64")) should be(false)
    sparkConfig.exists(_ == ("anotherProperty" -> "true")) should be(false)
  }
} 
Example 12
Source File: ElasticSearchOutputTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.output.elasticsearch

import com.stratio.sparta.sdk.properties.JsoneyString
import org.apache.spark.sql.types._
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ElasticSearchOutputTest extends FlatSpec with ShouldMatchers {

  trait BaseValues {

    final val localPort = 9200
    final val remotePort = 9300
    val output = getInstance()
    val outputMultipleNodes = new ElasticSearchOutput("ES-out",
      Map("nodes" ->
        new JsoneyString(
          s"""[{"node":"host-a","tcpPort":"$remotePort","httpPort":"$localPort"},{"node":"host-b",
              |"tcpPort":"9301","httpPort":"9201"}]""".stripMargin),
        "dateType" -> "long"))

    def getInstance(host: String = "localhost", httpPort: Int = localPort, tcpPort: Int = remotePort)
    : ElasticSearchOutput =
      new ElasticSearchOutput("ES-out",
        Map("nodes" -> new JsoneyString( s"""[{"node":"$host","httpPort":"$httpPort","tcpPort":"$tcpPort"}]"""),
          "clusterName" -> "elasticsearch"))
  }

  trait NodeValues extends BaseValues {

    val ipOutput = getInstance("127.0.0.1", localPort, remotePort)
    val ipv6Output = getInstance("0:0:0:0:0:0:0:1", localPort, remotePort)
    val remoteOutput = getInstance("dummy", localPort, remotePort)
  }

  trait TestingValues extends BaseValues {

    val indexNameType = "spartatable/sparta"
    val tableName = "spartaTable"
    val baseFields = Seq(StructField("string", StringType), StructField("int", IntegerType))
    val schema = StructType(baseFields)
    val extraFields = Seq(StructField("id", StringType, false), StructField("timestamp", LongType, false))
    val properties = Map("nodes" -> new JsoneyString(
      """[{"node":"localhost","httpPort":"9200","tcpPort":"9300"}]""".stripMargin),
      "dateType" -> "long",
      "clusterName" -> "elasticsearch")
    override val output = new ElasticSearchOutput("ES-out", properties)
    val dateField = StructField("timestamp", TimestampType, false)
    val expectedDateField = StructField("timestamp", LongType, false)
    val stringField = StructField("string", StringType)
    val expectedStringField = StructField("string", StringType)
  }

  trait SchemaValues extends BaseValues {

    val fields = Seq(
      StructField("long", LongType),
      StructField("double", DoubleType),
      StructField("decimal", DecimalType(10, 0)),
      StructField("int", IntegerType),
      StructField("boolean", BooleanType),
      StructField("date", DateType),
      StructField("timestamp", TimestampType),
      StructField("array", ArrayType(StringType)),
      StructField("map", MapType(StringType, IntegerType)),
      StructField("string", StringType),
      StructField("binary", BinaryType))
    val completeSchema = StructType(fields)
  }

  "ElasticSearchOutput" should "format properties" in new NodeValues with SchemaValues {
    output.httpNodes should be(Seq(("localhost", 9200)))
    outputMultipleNodes.httpNodes should be(Seq(("host-a", 9200), ("host-b", 9201)))
    output.clusterName should be("elasticsearch")
  }

  it should "parse correct index name type" in new TestingValues {
    output.indexNameType(tableName) should be(indexNameType)
  }

  it should "return a Seq of tuples (host,port) format" in new NodeValues {

    output.getHostPortConfs("nodes", "localhost", "9200", "node", "httpPort") should be(List(("localhost", 9200)))
    output.getHostPortConfs("nodes", "localhost", "9300", "node", "tcpPort") should be(List(("localhost", 9300)))
    outputMultipleNodes.getHostPortConfs("nodes", "localhost", "9200", "node", "httpPort") should be(List(
      ("host-a", 9200), ("host-b", 9201)))
    outputMultipleNodes.getHostPortConfs("nodes", "localhost", "9300", "node", "tcpPort") should be(List(
      ("host-a", 9300), ("host-b", 9301)))
  }
} 
Example 13
Source File: CsvOutputIT.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.output.csv

import java.sql.Timestamp
import java.time.Instant

import com.databricks.spark.avro._
import com.stratio.sparta.plugin.TemporalSparkContext
import com.stratio.sparta.sdk.pipeline.output.{Output, SaveModeEnum}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

import scala.reflect.io.File
import scala.util.Random


@RunWith(classOf[JUnitRunner])
class CsvOutputIT extends TemporalSparkContext with Matchers {

  trait CommonValues {
    val tmpPath: String = File.makeTemp().name
    val sparkSession = SparkSession.builder().config(sc.getConf).getOrCreate()
    val schema = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType),
      StructField("minute", LongType)
    ))

    val data =
      sparkSession.createDataFrame(sc.parallelize(Seq(
        Row("Kevin", Random.nextInt, Timestamp.from(Instant.now).getTime),
        Row("Kira", Random.nextInt, Timestamp.from(Instant.now).getTime),
        Row("Ariadne", Random.nextInt, Timestamp.from(Instant.now).getTime)
      )), schema)
  }

  trait WithEventData extends CommonValues {
    val properties = Map("path" -> tmpPath)
    val output = new CsvOutput("csv-test", properties)
  }


  "CsvOutput" should "throw an exception when path is not present" in {
    an[Exception] should be thrownBy new CsvOutput("csv-test", Map.empty)
  }

  it should "throw an exception when empty path " in {
    an[Exception] should be thrownBy new CsvOutput("csv-test", Map("path" -> "    "))
  }

  it should "save a dataframe " in new WithEventData {
    output.save(data, SaveModeEnum.Append, Map(Output.TableNameKey -> "person"))
    val read = sparkSession.read.csv(s"$tmpPath/person.csv")
    read.count should be(3)
    read should be eq data
    File(tmpPath).deleteRecursively
    File("spark-warehouse").deleteRecursively
  }

} 
Example 14
Source File: LastValueOperatorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.lastValue

import java.util.Date

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class LastValueOperatorTest extends WordSpec with Matchers {

  "LastValue operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new LastValueOperator("lastValue", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new LastValueOperator("lastValue", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new LastValueOperator("lastValue", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new LastValueOperator("lastValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields5 = new LastValueOperator("lastValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      val inputFields6 = new LastValueOperator("lastValue", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new LastValueOperator("lastValue", initSchema, Map())
      inputFields.processReduce(Seq()) should be(None)

      val inputFields2 = new LastValueOperator("lastValue", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2))) should be(Some(2))

      val inputFields3 = new LastValueOperator("lastValue", initSchema, Map())
      inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some("b"))
    }

    "associative process must be " in {
      val inputFields = new LastValueOperator("lastValue", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(1L))

      val inputFields2 = new LastValueOperator("lastValue", initSchema, Map("typeOp" -> "int"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)))
      inputFields2.associativity(resultInput2) should be(Some(1))

      val inputFields3 = new LastValueOperator("lastValue", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(1)),
        (Operator.NewValuesKey, Some(2)))
      inputFields3.associativity(resultInput3) should be(Some(2))

      val inputFields4 = new LastValueOperator("lastValue", initSchema, Map())
      val resultInput4 = Seq()
      inputFields4.associativity(resultInput4) should be(None)

      val inputFields5 = new LastValueOperator("lastValue", initSchema, Map())
      val date = new Date()
      val resultInput5 = Seq((Operator.NewValuesKey, Some(date)))
      inputFields5.associativity(resultInput5) should be(Some(date))
    }
  }
} 
Example 15
Source File: StddevOperatorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.stddev

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class StddevOperatorTest extends WordSpec with Matchers {

  "Std dev operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new StddevOperator("stdev", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new StddevOperator("stdev", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new StddevOperator("stdev", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new StddevOperator("stdev", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new StddevOperator("stdev", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new StddevOperator("stdev", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new StddevOperator("stdev", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be
      (Some(2.8284271247461903))

      val inputFields3 = new StddevOperator("stdev", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be
      (Some(2.850438562747845))

      val inputFields4 = new StddevOperator("stdev", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new StddevOperator("stdev", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(
        Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some("2.850438562747845"))
    }

    "processReduce distinct must be " in {
      val inputFields = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be
      (Some(2.8284271247461903))

      val inputFields3 = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be
      (Some(2.850438562747845))

      val inputFields4 = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new StddevOperator("stdev", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(
        Seq(Some(1), Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some("2.850438562747845"))
    }
  }
} 
Example 16
Source File: MedianOperatorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.median

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class MedianOperatorTest extends WordSpec with Matchers {

  "Median operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new MedianOperator("median", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new MedianOperator("median", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new MedianOperator("median", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new MedianOperator("median", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new MedianOperator("median", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new MedianOperator("median", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new MedianOperator("median", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(3d))

      val inputFields3 = new MedianOperator("median", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(3))

      val inputFields4 = new MedianOperator("median", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new MedianOperator("median", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("3.0"))
    }

    "processReduce distinct must be " in {
      val inputFields = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(2.5))

      val inputFields3 = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(3))

      val inputFields4 = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new MedianOperator("median", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("2.5"))
    }
  }
} 
Example 17
Source File: ModeOperatorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.mode

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class ModeOperatorTest extends WordSpec with Matchers {

  "Mode operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new ModeOperator("mode", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new ModeOperator("mode", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new ModeOperator("mode", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new ModeOperator("mode", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields5 = new ModeOperator("mode", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      val inputFields6 = new ModeOperator("mode", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new ModeOperator("mode", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(List()))

      val inputFields2 = new ModeOperator("mode", initSchema, Map())
      inputFields2.processReduce(Seq(Some("hey"), Some("hey"), Some("hi"))) should be(Some(List("hey")))

      val inputFields3 = new ModeOperator("mode", initSchema, Map())
      inputFields3.processReduce(Seq(Some("1"), Some("1"), Some("4"))) should be(Some(List("1")))

      val inputFields4 = new ModeOperator("mode", initSchema, Map())
      inputFields4.processReduce(Seq(
        Some("1"), Some("1"), Some("4"), Some("4"), Some("4"), Some("4"))) should be(Some(List("4")))

      val inputFields5 = new ModeOperator("mode", initSchema, Map())
      inputFields5.processReduce(Seq(
        Some("1"), Some("1"), Some("2"), Some("2"), Some("4"), Some("4"))) should be(Some(List("1", "2", "4")))

      val inputFields6 = new ModeOperator("mode", initSchema, Map())
      inputFields6.processReduce(Seq(
        Some("1"), Some("1"), Some("2"), Some("2"), Some("4"), Some("4"), Some("5"))
      ) should be(Some(List("1", "2", "4")))
    }
  }
} 
Example 18
Source File: RangeOperatorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.range

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class RangeOperatorTest extends WordSpec with Matchers {

  "Range operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new RangeOperator("range", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new RangeOperator("range", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new RangeOperator("range", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new RangeOperator("range", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new RangeOperator("range", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new RangeOperator("range", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new RangeOperator("range", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(0))

      val inputFields3 = new RangeOperator("range", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(4))) should be(Some(3))

      val inputFields4 = new RangeOperator("range", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new RangeOperator("range", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("6.0"))
    }

    "processReduce distinct must be " in {
      val inputFields = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(0))

      val inputFields3 = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(4))) should be(Some(3))

      val inputFields4 = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new RangeOperator("range", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("6.0"))
    }
  }
} 
Example 19
Source File: AccumulatorOperatorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.accumulator

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class AccumulatorOperatorTest extends WordSpec with Matchers {

  "Accumulator operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new AccumulatorOperator("accumulator", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new AccumulatorOperator("accumulator", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new AccumulatorOperator("accumulator", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new AccumulatorOperator("accumulator", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields5 = new AccumulatorOperator("accumulator", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":2}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      val inputFields6 = new AccumulatorOperator("accumulator", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":2}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new AccumulatorOperator("accumulator", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(Seq()))

      val inputFields2 = new AccumulatorOperator("accumulator", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(Seq("1", "1")))

      val inputFields3 = new AccumulatorOperator("accumulator", initSchema, Map())
      inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some(Seq("a", "b")))
    }

    "associative process must be " in {
      val inputFields = new AccumulatorOperator("accumulator", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(Seq(1L))),
        (Operator.NewValuesKey, Some(Seq(2L))),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(Seq("1", "2")))

      val inputFields2 = new AccumulatorOperator("accumulator", initSchema, Map("typeOp" -> "arraydouble"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(Seq(1))),
        (Operator.NewValuesKey, Some(Seq(3))))
      inputFields2.associativity(resultInput2) should be(Some(Seq(1d, 3d)))

      val inputFields3 = new AccumulatorOperator("accumulator", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(Seq(1))),
        (Operator.NewValuesKey, Some(Seq(1))))
      inputFields3.associativity(resultInput3) should be(Some(Seq("1", "1")))
    }
  }
} 
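Many of the operator tests in this collection pass a "filters" property encoded as a JSON array of {"field", "type", "value"} objects. If the escaped string literals become hard to read, a small helper can assemble them; the snippet below is only an illustrative sketch and is not part of the Sparta SDK:

// Illustrative helper (not part of the Sparta SDK) for building the "filters" JSON
// strings used in the operator tests above.
def filterJson(field: String, op: String, value: String): String =
  s"""{"field":"$field", "type": "$op", "value":$value}"""

val filters = Seq(filterJson("field1", "<", "2"), filterJson("field2", "<", "2")).mkString("[", ",", "]")
// filters is equivalent to the escaped literal passed to AccumulatorOperator above.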
Example 20
Source File: FirstValueOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.firstValue

import java.util.Date

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class FirstValueOperatorTest extends WordSpec with Matchers {

  "FirstValue operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new FirstValueOperator("firstValue", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new FirstValueOperator("firstValue", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new FirstValueOperator("firstValue", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new FirstValueOperator("firstValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields5 = new FirstValueOperator("firstValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      val inputFields6 = new FirstValueOperator("firstValue", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new FirstValueOperator("firstValue", initSchema, Map())
      inputFields.processReduce(Seq()) should be(None)

      val inputFields2 = new FirstValueOperator("firstValue", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2))) should be(Some(1))

      val inputFields3 = new FirstValueOperator("firstValue", initSchema, Map())
      inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some("a"))
    }

    "associative process must be " in {
      val inputFields = new FirstValueOperator("firstValue", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(1L))

      val inputFields2 = new FirstValueOperator("firstValue", initSchema, Map("typeOp" -> "int"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)))
      inputFields2.associativity(resultInput2) should be(Some(1))

      val inputFields3 = new FirstValueOperator("firstValue", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(1)),
        (Operator.NewValuesKey, Some(1)),
        (Operator.NewValuesKey, None))
      inputFields3.associativity(resultInput3) should be(Some(1))

      val inputFields4 = new FirstValueOperator("firstValue", initSchema, Map())
      val resultInput4 = Seq()
      inputFields4.associativity(resultInput4) should be(None)

      val inputFields5 = new FirstValueOperator("firstValue", initSchema, Map())
      val date = new Date()
      val resultInput5 = Seq((Operator.NewValuesKey, Some(date)))
      inputFields5.associativity(resultInput5) should be(Some(date))
    }
  }
} 
Example 21
Source File: MeanAssociativeOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.mean

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class MeanAssociativeOperatorTest extends WordSpec with Matchers {

  "Mean operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new MeanAssociativeOperator("avg", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new MeanAssociativeOperator("avg", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new MeanAssociativeOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new MeanAssociativeOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new MeanAssociativeOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(List()))

      val inputFields2 = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be (Some(List(1.0, 1.0)))

      val inputFields3 = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), None)) should be(Some(List(1.0, 2.0, 3.0)))

      val inputFields4 = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(List()))
    }

    "processReduce distinct must be " in {
      val inputFields = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(List()))

      val inputFields2 = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be(Some(List(1.0)))

      val inputFields3 = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(3), Some(1), None)) should be(Some(List(1.0, 3.0)))

      val inputFields4 = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(List()))
    }

    "associative process must be " in {
      val inputFields = new MeanAssociativeOperator("avg", initSchema, Map())
      val resultInput =
        Seq((Operator.OldValuesKey, Some(Map("count" -> 1d, "sum" -> 2d, "mean" -> 2d))), (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(Map("count" -> 1.0, "sum" -> 2.0, "mean" -> 2.0)))

      val inputFields2 = new MeanAssociativeOperator("avg", initSchema, Map())
      val resultInput2 = Seq((Operator.OldValuesKey, Some(Map("count" -> 1d, "sum" -> 2d, "mean" -> 2d))),
        (Operator.NewValuesKey, Some(Seq(1d))))
      inputFields2.associativity(resultInput2) should be(Some(Map("sum" -> 3.0, "count" -> 2.0, "mean" -> 1.5)))
    }
  }
} 
Example 22
Source File: MeanOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.mean

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class MeanOperatorTest extends WordSpec with Matchers {

  "Mean operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new MeanOperator("avg", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new MeanOperator("avg", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new MeanOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new MeanOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new MeanOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new MeanOperator("avg", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new MeanOperator("avg", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be(Some(1))

      val inputFields3 = new MeanOperator("avg", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), None)) should be(Some(2))

      val inputFields4 = new MeanOperator("avg", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new MeanOperator("avg", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(Seq(Some(1), Some(1))) should be(Some("1.0"))
    }

    "processReduce distinct must be " in {
      val inputFields = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be(Some(1))

      val inputFields3 = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(3), Some(1), None)) should be(Some(2))

      val inputFields4 = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new MeanOperator("avg", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(Seq(Some(1), Some(1))) should be(Some("1.0"))
    }
  }
} 
Example 23
Source File: OperatorEntityCountTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.entityCount

import java.io.{Serializable => JSerializable}

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class OperatorEntityCountTest extends WordSpec with Matchers {

  "EntityCount" should {
    val props = Map(
      "inputField" -> "inputField".asInstanceOf[JSerializable],
      "split" -> ",".asInstanceOf[JSerializable])
    val schema = StructType(Seq(StructField("inputField", StringType)))
    val entityCount = new OperatorEntityCountMock("op1", schema, props)
    val inputFields = Row("hello,bye")

    "Return the associated precision name" in {
      val expected = Option(Seq("hello", "bye"))
      val result = entityCount.processMap(inputFields)
      result should be(expected)
    }

    "Return empty list" in {
      val expected = None
      val result = entityCount.processMap(Row())
      result should be(expected)
    }
  }
} 
Example 24
Source File: EntityCountOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.entityCount

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class EntityCountOperatorTest extends WordSpec with Matchers {

  "Entity Count Operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new EntityCountOperator("entityCount", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new EntityCountOperator("entityCount", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo")))

      val inputFields4 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "split" -> ","))
      inputFields4.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo")))

      val inputFields5 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "split" -> "-"))
      inputFields5.processMap(Row("hola-holo", 2)) should be(Some(Seq("hola", "holo")))

      val inputFields6 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "split" -> ","))
      inputFields6.processMap(Row("hola,holo adios", 2)) should be(Some(Seq("hola", "holo " + "adios")))

      val inputFields7 = new EntityCountOperator("entityCount", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]"))
      inputFields7.processMap(Row("hola", 2)) should be(None)

      val inputFields8 = new EntityCountOperator("entityCount", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]",
          "split" -> " "))
      inputFields8.processMap(Row("hola holo", 2)) should be(Some(Seq("hola", "holo")))
    }

    "processReduce must be " in {
      val inputFields = new EntityCountOperator("entityCount", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(Seq()))

      val inputFields2 = new EntityCountOperator("entityCount", initSchema, Map())
      inputFields2.processReduce(Seq(Some(Seq("hola", "holo")))) should be(Some(Seq("hola", "holo")))

      val inputFields3 = new EntityCountOperator("entityCount", initSchema, Map())
      inputFields3.processReduce(Seq(Some(Seq("hola", "holo", "hola")))) should be(Some(Seq("hola", "holo", "hola")))
    }

    "associative process must be " in {
      val inputFields = new EntityCountOperator("entityCount", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(Map("hola" -> 1L, "holo" -> 1L))),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(Map("hola" -> 1L, "holo" -> 1L)))

      val inputFields2 = new EntityCountOperator("entityCount", initSchema, Map("typeOp" -> "int"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(Map("hola" -> 1L, "holo" -> 1L))),
        (Operator.NewValuesKey, Some(Seq("hola"))))
      inputFields2.associativity(resultInput2) should be(Some(Map()))

      val inputFields3 = new EntityCountOperator("entityCount", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(Map("hola" -> 1L, "holo" -> 1L))))
      inputFields3.associativity(resultInput3) should be(Some(Map("hola" -> 1L, "holo" -> 1L)))
    }
  }
} 
Example 25
Source File: SumOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.sum

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class SumOperatorTest extends WordSpec with Matchers {

  "Sum operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new SumOperator("sum", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new SumOperator("sum", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new SumOperator("sum", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new SumOperator("sum", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new SumOperator("sum", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new SumOperator("sum", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new SumOperator("sum", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(20d))

      val inputFields3 = new SumOperator("sum", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(20d))

      val inputFields4 = new SumOperator("sum", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))
    }

    "processReduce distinct must be " in {
      val inputFields = new SumOperator("sum", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new SumOperator("sum", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(1))) should be(Some(3d))
    }

    "associative process must be " in {
      val inputFields = new SumOperator("count", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(2d))

      val inputFields2 = new SumOperator("count", initSchema, Map("typeOp" -> "string"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields2.associativity(resultInput2) should be(Some("2.0"))
    }
  }
} 
Example 26
Source File: FullTextOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.fullText

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class FullTextOperatorTest extends WordSpec with Matchers {

  "FullText operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new FullTextOperator("fullText", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new FullTextOperator("fullText", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new FullTextOperator("fullText", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new FullTextOperator("fullText", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields5 = new FullTextOperator("fullText", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      val inputFields6 = new FullTextOperator("fullText", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new FullTextOperator("fullText", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(""))

      val inputFields2 = new FullTextOperator("fullText", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(s"1${Operator.SpaceSeparator}1"))

      val inputFields3 = new FullTextOperator("fullText", initSchema, Map())
      inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some(s"a${Operator.SpaceSeparator}b"))
    }

    "associative process must be " in {
      val inputFields = new FullTextOperator("fullText", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(2)), (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some("2"))

      val inputFields2 = new FullTextOperator("fullText", initSchema, Map("typeOp" -> "arraystring"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(2)),
        (Operator.NewValuesKey, Some(1)))
      inputFields2.associativity(resultInput2) should be(Some(Seq(s"2${Operator.SpaceSeparator}1")))

      val inputFields3 = new FullTextOperator("fullText", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(2)), (Operator.OldValuesKey, Some(3)))
      inputFields3.associativity(resultInput3) should be(Some(s"2${Operator.SpaceSeparator}3"))
    }
  }
} 
Example 27
Source File: TotalEntityCountOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.totalEntityCount

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class TotalEntityCountOperatorTest extends WordSpec with Matchers {

  "Entity Count Operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo")))

      val inputFields4 =
        new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1", "split" -> ","))
      inputFields4.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo")))

      val inputFields5 =
        new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1", "split" -> "-"))
      inputFields5.processMap(Row("hola-holo", 2)) should be(Some(Seq("hola", "holo")))

      val inputFields6 =
        new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1", "split" -> ","))
      inputFields6.processMap(Row("hola,holo adios", 2)) should be(Some(Seq("hola", "holo " + "adios")))

      val inputFields7 = new TotalEntityCountOperator("totalEntityCount", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]"))
      inputFields7.processMap(Row("hola", 2)) should be(None)

      val inputFields8 = new TotalEntityCountOperator("totalEntityCount", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]",
          "split" -> " "))
      inputFields8.processMap(Row("hola holo", 2)) should be
      (Some(Seq("hola", "holo")))
    }

    "processReduce must be " in {
      val inputFields = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0L))

      val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputFields2.processReduce(Seq(Some(Seq("hola", "holo")))) should be(Some(2L))

      val inputFields3 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputFields3.processReduce(Seq(Some(Seq("hola", "holo", "hola")))) should be(Some(3L))

      val inputFields4 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0L))
    }

    "processReduce distinct must be " in {
      val inputFields = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0L))

      val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(Seq("hola", "holo", "hola")))) should be(Some(2L))
    }

    "associative process must be " in {
      val inputFields = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(2)), (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(2))

      val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("typeOp" -> "int"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(2)),
        (Operator.NewValuesKey, Some(1)))
      inputFields2.associativity(resultInput2) should be(Some(3))

      val inputFields3 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(2)))
      inputFields3.associativity(resultInput3) should be(Some(2))
    }
  }
} 
Example 28
Source File: HierarchyFieldTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.field.hierarchy

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.prop.TableDrivenPropertyChecks
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Matchers, WordSpecLike}

@RunWith(classOf[JUnitRunner])
class HierarchyFieldTest extends WordSpecLike
  with Matchers
  with BeforeAndAfter
  with BeforeAndAfterAll
  with TableDrivenPropertyChecks {

  var hbs: Option[HierarchyField] = _

  before {
    hbs = Some(new HierarchyField())
  }

  after {
    hbs = None
  }

  "A HierarchyDimension" should {
    "In default implementation, get 4 precisions for all precision sizes" in {
      val precisionLeftToRight = hbs.get.precisionValue(HierarchyField.LeftToRightName, "")
      val precisionRightToLeft = hbs.get.precisionValue(HierarchyField.RightToLeftName, "")
      val precisionLeftToRightWithWildCard = hbs.get.precisionValue(HierarchyField.LeftToRightWithWildCardName, "")
      val precisionRightToLeftWithWildCard = hbs.get.precisionValue(HierarchyField.RightToLeftWithWildCardName, "")

      precisionLeftToRight._1.id should be(HierarchyField.LeftToRightName)
      precisionRightToLeft._1.id should be(HierarchyField.RightToLeftName)
      precisionLeftToRightWithWildCard._1.id should be(HierarchyField.LeftToRightWithWildCardName)
      precisionRightToLeftWithWildCard._1.id should be(HierarchyField.RightToLeftWithWildCardName)
    }

    "In default implementation, every proposed combination should be ok" in {
      val data = Table(
        ("i", "o"),
        ("google.com", Seq("google.com", "*.com", "*"))
      )

      forAll(data) { (i: String, o: Seq[String]) =>
        val result = hbs.get.precisionValue(HierarchyField.LeftToRightWithWildCardName, i)
        assertResult(o)(result._2)
      }
    }
    "In reverse implementation, every proposed combination should be ok" in {
      hbs = Some(new HierarchyField())
      val data = Table(
        ("i", "o"),
        ("com.stratio.sparta", Seq("com.stratio.sparta", "com.stratio.*", "com.*", "*"))
      )

      forAll(data) { (i: String, o: Seq[String]) =>
        val result = hbs.get.precisionValue(HierarchyField.RightToLeftWithWildCardName, i.asInstanceOf[Any])
        assertResult(o)(result._2)
      }
    }
    "In reverse implementation without wildcards, every proposed combination should be ok" in {
      hbs = Some(new HierarchyField())
      val data = Table(
        ("i", "o"),
        ("com.stratio.sparta", Seq("com.stratio.sparta", "com.stratio", "com", "*"))
      )

      forAll(data) { (i: String, o: Seq[String]) =>
        val result = hbs.get.precisionValue(HierarchyField.RightToLeftName, i.asInstanceOf[Any])
        assertResult(o)(result._2)
      }
    }
    "In non-reverse implementation without wildcards, every proposed combination should be ok" in {
      hbs = Some(new HierarchyField())
      val data = Table(
        ("i", "o"),
        ("google.com", Seq("google.com", "com", "*"))
      )

      forAll(data) { (i: String, o: Seq[String]) =>
        val result = hbs.get.precisionValue(HierarchyField.LeftToRightName, i.asInstanceOf[Any])
        assertResult(o)(result._2)
      }
    }
  }
} 
Example 29
Source File: DateTimeFieldTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.field.datetime

import java.io.{Serializable => JSerializable}
import java.util.Date

import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpecLike}

@RunWith(classOf[JUnitRunner])
class DateTimeFieldTest extends WordSpecLike with Matchers {

  val dateTimeDimension = new DateTimeField(Map("second" -> "long", "minute" -> "date", "typeOp" -> "datetime"))

  "A DateTimeDimension" should {
    "In default implementation, get 6 dimensions for a specific time" in {
      val newDate = new Date()
      val precision5s =
        dateTimeDimension.precisionValue("5s", newDate.asInstanceOf[JSerializable])
      val precision10s =
        dateTimeDimension.precisionValue("10s", newDate.asInstanceOf[JSerializable])
      val precision15s =
        dateTimeDimension.precisionValue("15s", newDate.asInstanceOf[JSerializable])
      val precisionSecond =
        dateTimeDimension.precisionValue("second", newDate.asInstanceOf[JSerializable])
      val precisionMinute =
        dateTimeDimension.precisionValue("minute", newDate.asInstanceOf[JSerializable])
      val precisionHour =
        dateTimeDimension.precisionValue("hour", newDate.asInstanceOf[JSerializable])
      val precisionDay =
        dateTimeDimension.precisionValue("day", newDate.asInstanceOf[JSerializable])
      val precisionMonth =
        dateTimeDimension.precisionValue("month", newDate.asInstanceOf[JSerializable])
      val precisionYear =
        dateTimeDimension.precisionValue("year", newDate.asInstanceOf[JSerializable])

      precision5s._1.id should be("5s")
      precision10s._1.id should be("10s")
      precision15s._1.id should be("15s")
      precisionSecond._1.id should be("second")
      precisionMinute._1.id should be("minute")
      precisionHour._1.id should be("hour")
      precisionDay._1.id should be("day")
      precisionMonth._1.id should be("month")
      precisionYear._1.id should be("year")
    }

    "Each precision dimension have their output type, second must be long, minute must be date, others datetime" in {
      dateTimeDimension.precision("5s").typeOp should be(TypeOp.DateTime)
      dateTimeDimension.precision("10s").typeOp should be(TypeOp.DateTime)
      dateTimeDimension.precision("15s").typeOp should be(TypeOp.DateTime)
      dateTimeDimension.precision("second").typeOp should be(TypeOp.Long)
      dateTimeDimension.precision("minute").typeOp should be(TypeOp.Date)
      dateTimeDimension.precision("day").typeOp should be(TypeOp.DateTime)
      dateTimeDimension.precision("month").typeOp should be(TypeOp.DateTime)
      dateTimeDimension.precision("year").typeOp should be(TypeOp.DateTime)
      dateTimeDimension.precision(DateTimeField.TimestampPrecision.id).typeOp should be(TypeOp.Timestamp)
    }
  }
} 
Example 30
Source File: DefaultFieldTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.field.defaultField

import com.stratio.sparta.plugin.default.DefaultField
import com.stratio.sparta.sdk.pipeline.aggregation.cube.{DimensionType, Precision}
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpecLike}

@RunWith(classOf[JUnitRunner])
class DefaultFieldTest extends WordSpecLike with Matchers {

  val defaultDimension: DefaultField = new DefaultField(Map("typeOp" -> "int"))

  "A DefaultDimension" should {
    "In default implementation, get one precisions for a specific time" in {
      val precision: (Precision, Any) = defaultDimension.precisionValue("", "1".asInstanceOf[Any])

      precision._2 should be(1)

      precision._1.id should be(DimensionType.IdentityName)
    }

    "The precision must be int" in {
      defaultDimension.precision(DimensionType.IdentityName).typeOp should be(TypeOp.Int)
    }
  }
} 
Example 31
Source File: SocketInputTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.input.socket

import java.io.{Serializable => JSerializable}

import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class SocketInputTest extends WordSpec {

  "A SocketInput" should {
    "instantiate successfully with parameters" in {
      new SocketInput(Map("hostname" -> "localhost", "port" -> 9999).mapValues(_.asInstanceOf[JSerializable]))
    }
    "fail without parameters" in {
      intercept[IllegalStateException] {
        new SocketInput(Map())
      }
    }
    "fail with bad port argument" in {
      intercept[IllegalStateException] {
        new SocketInput(Map("hostname" -> "localhost", "port" -> "BADPORT").mapValues(_.asInstanceOf[JSerializable]))
      }
    }
  }
} 
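All of these specs carry @RunWith(classOf[JUnitRunner]), which lets JUnit 4 tooling (Maven Surefire, IDEs, JUnitCore) execute ScalaTest suites. A minimal sketch of running one of them programmatically, assuming junit is on the classpath and the spec class is importable:

import org.junit.runner.JUnitCore

object RunWithJUnit {
  def main(args: Array[String]): Unit = {
    // JUnitRunner adapts the ScalaTest spec to the JUnit 4 Runner API,
    // so JUnitCore can execute it like any other JUnit test class.
    // Assumes this object lives in (or imports) the spec's package.
    val result = JUnitCore.runClasses(classOf[SocketInputTest])
    println(s"tests run: ${result.getRunCount}, failures: ${result.getFailureCount}")
  }
}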
Example 32
Source File: TwitterJsonInputTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.input.twitter

import java.io.{Serializable => JSerializable}

import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TwitterJsonInputTest extends WordSpec {

  "A TwitterInput" should {

    "fail without parameters" in {
      intercept[IllegalStateException] {
        new TwitterJsonInput(Map())
      }
    }
    "fail with bad arguments argument" in {
      intercept[IllegalStateException] {
        new TwitterJsonInput(Map("hostname" -> "localhost", "port" -> "BADPORT")
          .mapValues(_.asInstanceOf[JSerializable]))
      }
    }
  }
} 
Example 33
Source File: RabbitMQInputIT.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.input.rabbitmq

import java.util.UUID

import akka.pattern.{ask, gracefulStop}
import com.github.sstone.amqp.Amqp._
import com.github.sstone.amqp.{Amqp, ChannelOwner, ConnectionOwner, Consumer}
import com.rabbitmq.client.ConnectionFactory
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


import scala.concurrent.Await

@RunWith(classOf[JUnitRunner])
class RabbitMQInputIT extends RabbitIntegrationSpec {

  val queueName = s"$configQueueName-${this.getClass.getName}-${UUID.randomUUID().toString}"

  def initRabbitMQ(): Unit = {
    val connFactory = new ConnectionFactory()
    connFactory.setUri(RabbitConnectionURI)
    val conn = system.actorOf(ConnectionOwner.props(connFactory, RabbitTimeOut))
    val producer = ConnectionOwner.createChildActor(
      conn,
      ChannelOwner.props(),
      timeout = RabbitTimeOut,
      name = Some("RabbitMQ.producer")
    )

    val queue = QueueParameters(
      name = queueName,
      passive = false,
      exclusive = false,
      durable = true,
      autodelete = false
    )

    Amqp.waitForConnection(system, conn, producer).await()

    val deleteQueueResult = producer ? DeleteQueue(queueName)
    Await.result(deleteQueueResult, RabbitTimeOut)
    val createQueueResult = producer ? DeclareQueue(queue)
    Await.result(createQueueResult, RabbitTimeOut)

    //Send some messages to the queue
    val results = for (register <- 1 to totalRegisters)
      yield producer ? Publish(
        exchange = "",
        key = queueName,
        body = register.toString.getBytes
      )
    results.map(result => Await.result(result, RabbitTimeOut))
    
    conn ! Close()
    Await.result(gracefulStop(conn, RabbitTimeOut), RabbitTimeOut * 2)
    Await.result(gracefulStop(consumer, RabbitTimeOut), RabbitTimeOut * 2)
  }


  "RabbitMQInput " should {

    "Read all the records" in {
      val props = Map(
        "hosts" -> hosts,
        "queueName" -> queueName)

      val input = new RabbitMQInput(props)
      val distributedStream = input.initStream(ssc.get, DefaultStorageLevel)
      val totalEvents = ssc.get.sparkContext.accumulator(0L, "Number of events received")

      // Fires each time the configured window has passed.
      distributedStream.foreachRDD(rdd => {
        if (!rdd.isEmpty()) {
          val count = rdd.count()
          // Do something with this message
          log.info(s"EVENTS COUNT : $count")
          totalEvents.add(count)
        } else log.info("RDD is empty")
        log.info(s"TOTAL EVENTS : $totalEvents")
      })

      ssc.get.start() // Start the computation
      ssc.get.awaitTerminationOrTimeout(SparkTimeOut) // Wait for the computation to terminate

      totalEvents.value should ===(totalRegisters.toLong)
    }
  }

} 
Example 34
Source File: MessageHandlerTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.input.rabbitmq.handler

import com.rabbitmq.client.QueueingConsumer.Delivery
import com.stratio.sparta.plugin.input.rabbitmq.handler.MessageHandler.{ByteArrayMessageHandler, StringMessageHandler}
import org.junit.runner.RunWith
import org.mockito.Mockito._
import org.scalatest.junit.JUnitRunner
import org.scalatest.mock.MockitoSugar
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class MessageHandlerTest extends WordSpec with Matchers with MockitoSugar {

  val message = "This is the message for testing"
  "RabbitMQ MessageHandler Factory " should {

    "Get correct handler for string with a map " in {
      MessageHandler(Map(MessageHandler.KeyDeserializer -> "")) should matchPattern { case StringMessageHandler => }
      MessageHandler(Map.empty[String, String]) should matchPattern { case StringMessageHandler => }
      MessageHandler(Map(MessageHandler.KeyDeserializer -> "badInput")) should
        matchPattern { case StringMessageHandler => }
      MessageHandler(Map(MessageHandler.KeyDeserializer -> "arraybyte")) should
        matchPattern { case ByteArrayMessageHandler => }
    }

    "Get correct handler for string " in {
      val result = MessageHandler("string")
      result should matchPattern { case StringMessageHandler => }
    }

    "Get correct handler for arraybyte " in {
      val result = MessageHandler("arraybyte")
      result should matchPattern { case ByteArrayMessageHandler => }
    }

    "Get correct handler for empty input " in {
      val result = MessageHandler("")
      result should matchPattern { case StringMessageHandler => }
    }

    "Get correct handler for bad input " in {
      val result = MessageHandler("badInput")
      result should matchPattern { case StringMessageHandler => }
    }
  }

  "StringMessageHandler " should {
    "Handle strings" in {
      val delivery = mock[Delivery]
      when(delivery.getBody).thenReturn(message.getBytes)
      val result = StringMessageHandler.handler(delivery)
      verify(delivery, times(1)).getBody
      result.getString(0) shouldBe message
    }
  }

  "ByteArrayMessageHandler " should {
    "Handle bytes" in {
      val delivery = mock[Delivery]
      when(delivery.getBody).thenReturn(message.getBytes)
      val result = ByteArrayMessageHandler.handler(delivery)
      verify(delivery, times(1)).getBody
      result.getAs[Array[Byte]](0) shouldBe message.getBytes
    }

  }

} 
Example 35
Source File: HostPortZkTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.input.kafka


import java.io.Serializable

import com.stratio.sparta.sdk.properties.JsoneyString
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class HostPortZkTest extends WordSpec with Matchers {

  class KafkaTestInput(val properties: Map[String, Serializable]) extends KafkaBase
  
  "getHostPortZk" should {

    "return a chain (zookeper:conection , host:port)" in {
      val conn = """[{"host": "localhost", "port": "2181"}]"""
      val props = Map("zookeeper.connect" -> JsoneyString(conn), "zookeeper.path" -> "")
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181"))
    }

    "return a chain (zookeper:conection , host:port, zookeeper.path:path)" in {
      val conn = """[{"host": "localhost", "port": "2181"}]"""
      val props = Map("zookeeper.connect" -> JsoneyString(conn), "zookeeper.path" -> "/test")
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181/test"))
    }

    "return a chain (zookeper:conection , host:port,host:port,host:port)" in {
      val conn =
        """[{"host": "localhost", "port": "2181"},{"host": "localhost", "port": "2181"},
          |{"host": "localhost", "port": "2181"}]""".stripMargin
      val props = Map("zookeeper.connect" -> JsoneyString(conn))
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181,localhost:2181,localhost:2181"))
    }

    "return a chain (zookeper:conection , host:port,host:port,host:port, zookeeper.path:path)" in {
      val conn =
        """[{"host": "localhost", "port": "2181"},{"host": "localhost", "port": "2181"},
          |{"host": "localhost", "port": "2181"}]""".stripMargin
      val props = Map("zookeeper.connect" -> JsoneyString(conn), "zookeeper.path" -> "/test")
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181,localhost:2181,localhost:2181/test"))
    }

    "return a chain with default port (zookeper:conection , host: defaultport)" in {

      val props = Map("foo" -> "var")
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181"))
    }

    "return a chain with default port (zookeper:conection , host: defaultport, zookeeper.path:path)" in {
      val props = Map("zookeeper.path" -> "/test")
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181/test"))
    }

    "return a chain with default host and default porty (zookeeper.connect: ," +
      "defaultHost: defaultport," +
      "zookeeper.path:path)" in {
      val props = Map("foo" -> "var")
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181"))
    }
  }
} 
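For reference, the zookeeper.connect values asserted above are just the host:port pairs joined with commas, with zookeeper.path appended when present. A rough sketch of that assembly (the helper is illustrative, not the actual KafkaBase implementation):

// Illustrative sketch (not the actual KafkaBase implementation) of how the asserted
// zookeeper.connect strings are assembled from host/port pairs and an optional path.
def zkConnect(hostPorts: Seq[(String, String)], path: String = ""): String =
  hostPorts.map { case (host, port) => s"$host:$port" }.mkString(",") + path

// zkConnect(Seq(("localhost", "2181"), ("localhost", "2181"), ("localhost", "2181")), "/test")
// yields "localhost:2181,localhost:2181,localhost:2181/test"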
Example 36
Source File: MorphlinesParserTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.transformation.morphline

import java.io.Serializable

import com.stratio.sparta.sdk.pipeline.input.Input
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Matchers, WordSpecLike}


@RunWith(classOf[JUnitRunner])
class MorphlinesParserTest extends WordSpecLike with Matchers with BeforeAndAfter with BeforeAndAfterAll {

  val morphlineConfig = """
          id : test1
          importCommands : ["org.kitesdk.**"]
          commands: [
          {
              readJson {},
          }
          {
              extractJsonPaths {
                  paths : {
                      col1 : /col1
                      col2 : /col2
                  }
              }
          }
          {
            java {
              code : "return child.process(record);"
            }
          }
          {
              removeFields {
                  blacklist:["literal:_attachment_body"]
              }
          }
          ]
                        """
  val inputField = Some(Input.RawDataKey)
  val outputsFields = Seq("col1", "col2")
  val props: Map[String, Serializable] = Map("morphline" -> morphlineConfig)

  val schema = StructType(Seq(StructField("col1", StringType), StructField("col2", StringType)))

  val parser = new MorphlinesParser(1, inputField, outputsFields, schema, props)

  "A MorphlinesParser" should {

    "parse a simple json" in {
      val simpleJson =
        """{
            "col1":"hello",
            "col2":"word"
            }
        """
      val input = Row(simpleJson)
      val result = parser.parse(input)

      val expected = Seq(Row(simpleJson, "hello", "world"))

      result should be eq(expected)
    }

    "parse a simple json removing raw" in {
      val simpleJson =
        """{
            "col1":"hello",
            "col2":"word"
            }
        """
      val input = Row(simpleJson)
      val result = parser.parse(input)

      val expected = Seq(Row("hello", "world"))

      result should be eq(expected)
    }

    "exclude not configured fields" in {
      val simpleJson =
        """{
            "col1":"hello",
            "col2":"word",
            "col3":"!"
            }
        """
      val input = Row(simpleJson)
      val result = parser.parse(input)

      val expected = Seq(Row(simpleJson, "hello", "world"))

      result should be eq(expected)
    }
  }
} 
Example 37
Source File: DateTimeParserTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.transformation.datetime

import com.stratio.sparta.sdk.properties.JsoneyString
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpecLike}

@RunWith(classOf[JUnitRunner])
class DateTimeParserTest extends WordSpecLike with Matchers {

  val inputField = Some("ts")
  val outputsFields = Seq("ts")

  //scalastyle:off
  "A DateTimeParser" should {
    "parse unixMillis to string" in {
      val input = Row(1416330788000L)
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "unixMillis"))
          .parse(input)

      val expected = Seq(Row(1416330788000L, "1416330788000"))

      assertResult(result)(expected)
    }

    "parse unix to string" in {
      val input = Row(1416330788)
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "unix"))
          .parse(input)

      val expected = Seq(Row(1416330788, "1416330788000"))

      assertResult(result)(expected)
    }

    "parse unix to string removing raw" in {
      val input = Row(1416330788)
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "unix",
          "removeInputField" -> JsoneyString.apply("true")))
          .parse(input)

      val expected = Seq(Row("1416330788000"))

      assertResult(result)(expected)
    }

    "not parse anything if the field does not match" in {
      val input = Row("1212")
      val schema = StructType(Seq(StructField("otherField", StringType)))

      an[IllegalStateException] should be thrownBy new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "unixMillis")).parse(input)
    }

    "not parse anything and generate a new Date" in {
      val input = Row("anything")
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "autoGenerated"))
          .parse(input)

      assertResult(result.head.size)(2)
    }

    "Auto generated if inputFormat does not exist" in {
      val input = Row("1416330788")
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map()).parse(input)

      assertResult(result.head.size)(2)
    }

    "parse dateTime in hive format" in {
      val input = Row("2015-11-08 15:58:58")
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "hive"))
          .parse(input)

      val expected = Seq(Row("2015-11-08 15:58:58", "1446998338000"))

      assertResult(result)(expected)
    }
  }
} 
Example 38
Source File: LongInputTests.scala    From boson   with Apache License 2.0 5 votes vote down vote up
package io.zink.boson

import bsonLib.BsonObject
import io.netty.util.ResourceLeakDetector
import io.vertx.core.json.JsonObject
import io.zink.boson.bson.bsonImpl.BosonImpl
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner
import org.junit.Assert._

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.io.Source


@RunWith(classOf[JUnitRunner])
class LongInputTests extends FunSuite {
  ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.ADVANCED)

  val bufferedSource: Source = Source.fromURL(getClass.getResource("/jsonOutput.txt"))
  val finale: String = bufferedSource.getLines.toSeq.head
  bufferedSource.close

  val json: JsonObject = new JsonObject(finale)
  val bson: BsonObject = new BsonObject(json)

  test("extract top field") {
    val expression: String = ".Epoch"
    val boson: Boson = Boson.extractor(expression, (out: Int) => {
      assertTrue(3 == out)
    })
    val res = boson.go(bson.encode.getBytes)
    Await.result(res, Duration.Inf)
  }

  test("extract bottom field") {
    val expression: String = "SSLNLastName"
    val expected: String = "de Huanuco"
    val boson: Boson = Boson.extractor(expression, (out: String) => {
      assertTrue(expected.zip(out).forall(e => e._1.equals(e._2)))
    })
    val res = boson.go(bson.encode.getBytes)
    Await.result(res, Duration.Inf)
  }

  test("extract positions of an Array") {
    val expression: String = "Markets[3 to 5]"
    val mutableBuffer: ArrayBuffer[Array[Byte]] = ArrayBuffer()
    val boson: Boson = Boson.extractor(expression, (out: Array[Byte]) => {
      mutableBuffer += out
    })
    val res = boson.go(bson.encode.getBytes)
    Await.result(res, Duration.Inf)
    assertEquals(3, mutableBuffer.size)
  }

  test("extract further positions of an Array") {
    val expression: String = "Markets[50 to 55]"
    val mutableBuffer: ArrayBuffer[Array[Byte]] = ArrayBuffer()
    val boson: Boson = Boson.extractor(expression, (out: Array[Byte]) => {
      mutableBuffer += out
    })
    val res = boson.go(bson.encode.getBytes)
    Await.result(res, Duration.Inf)
    assertEquals(6, mutableBuffer.size)
  }

  test("size of all occurrences of Key") {
    val expression: String = "Price"
    val mutableBuffer: ArrayBuffer[Float] = ArrayBuffer()
    val boson: Boson = Boson.extractor(expression, (out: Float) => {
      mutableBuffer += out
    })
    val res = boson.go(bson.encode.getBytes)
    Await.result(res, Duration.Inf)
    assertEquals(195, mutableBuffer.size)
  }

} 
Example 39
Source File: StorageTest.scala    From mqttd   with MIT License 5 votes vote down vote up
package plantae.citrus.mqtt.actors.session

import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class StorageTest extends FunSuite {
  test("persist test") {
    val storage = Storage("persist-test")
    Range(1, 2000).foreach(count => {
      storage.persist((count + " persist").getBytes, (count % 3).toShort, true, "topic" + count)
    })

    assert(
      !Range(1, 2000).exists(count => {
        storage.nextMessage match {
          case Some(message) =>
            storage.complete(message.packetId match {
              case Some(x) => Some(x)
              case None => None
            })
            println(new String(message.payload.toArray))
            count + " persist" != new String(message.payload.toArray)

          case None => true
        }
      })
    )
  }
} 
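A note on the example above: the inner pattern match on message.packetId only rebuilds the Option it already holds, so the acknowledgement can be written more directly. A minimal equivalent sketch against the same Storage API used in the test:

// drain one message and acknowledge it; behaviour identical to the Some/None match above
storage.nextMessage.foreach { message =>
  storage.complete(message.packetId)
  println(new String(message.payload.toArray))
}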
Example 40
Source File: LogisticRegressionTest.scala    From spark-cp   with Apache License 2.0 5 votes vote down vote up
package se.uu.farmbio.cp.alg

import org.apache.spark.SharedSparkContext
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import se.uu.farmbio.cp.ICP
import se.uu.farmbio.cp.TestUtils
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class LogisticRegressionTest extends FunSuite with SharedSparkContext {

  test("test performance") {
    val trainData = TestUtils.generateBinaryData(100, 11)
    val testData = TestUtils.generateBinaryData(30, 22)
    val (calibration, properTrain) = ICP.calibrationSplit(sc.parallelize(trainData), 16)  
    val lr = new LogisticRegression(properTrain, 30)
    val model = ICP.trainClassifier(lr, numClasses=2, calibration)
    assert(TestUtils.testPerformance(model, sc.parallelize(testData)))
  }

} 
Example 41
Source File: SVMTest.scala    From spark-cp   with Apache License 2.0 5 votes vote down vote up
package se.uu.farmbio.cp.alg

import org.apache.spark.SharedSparkContext
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import se.uu.farmbio.cp.ICP
import se.uu.farmbio.cp.TestUtils
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class SVMTest extends FunSuite with SharedSparkContext {

  test("test performance") {
    val trainData = TestUtils.generateBinaryData(100, 11)
    val testData = TestUtils.generateBinaryData(30, 22)
    val (calibration, properTrain) = ICP.calibrationSplit(sc.parallelize(trainData), 16)  
    val svm = new SVM(properTrain, 30)
    val model = ICP.trainClassifier(svm, numClasses=2, calibration)
    assert(TestUtils.testPerformance(model, sc.parallelize(testData)))
  }

} 
Example 42
Source File: GBTTest.scala    From spark-cp   with Apache License 2.0 5 votes vote down vote up
package se.uu.farmbio.cp.alg

import scala.util.Random
import org.apache.spark.SharedSparkContext
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import se.uu.farmbio.cp.ICP
import se.uu.farmbio.cp.TestUtils
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class GBTTest extends FunSuite with SharedSparkContext {
  
  Random.setSeed(11)

  test("test performance") {
    val trainData = TestUtils.generateBinaryData(100, 11)
    val testData = TestUtils.generateBinaryData(30, 22)
    val (calibration, properTrain) = ICP.calibrationSplit(sc.parallelize(trainData), 16)  
    val gbt = new GBT(properTrain, 30)
    val model = ICP.trainClassifier(gbt, numClasses=2, calibration)
    assert(TestUtils.testPerformance(model, sc.parallelize(testData)))
  }

} 
Example 43
Source File: KafkaTestUtilsTest.scala    From spark-testing-base   with Apache License 2.0 5 votes vote down vote up
package com.holdenkarau.spark.testing.kafka

import java.util.Properties

import scala.collection.JavaConversions._

import kafka.consumer.ConsumerConfig
import org.apache.spark.streaming.kafka.KafkaTestUtils
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterAll, FunSuite}

@RunWith(classOf[JUnitRunner])
class KafkaTestUtilsTest extends FunSuite with BeforeAndAfterAll {

  private var kafkaTestUtils: KafkaTestUtils = _

  override def beforeAll(): Unit = {
    kafkaTestUtils = new KafkaTestUtils
    kafkaTestUtils.setup()
  }

  override def afterAll(): Unit = if (kafkaTestUtils != null) {
    kafkaTestUtils.teardown()
    kafkaTestUtils = null
  }

  test("Kafka send and receive message") {
    val topic = "test-topic"
    val message = "HelloWorld!"
    kafkaTestUtils.createTopic(topic)
    kafkaTestUtils.sendMessages(topic, message.getBytes)

    val consumerProps = new Properties()
    consumerProps.put("zookeeper.connect", kafkaTestUtils.zkAddress)
    consumerProps.put("group.id", "test-group")
    consumerProps.put("flow-topic", topic)
    consumerProps.put("auto.offset.reset", "smallest")
    consumerProps.put("zookeeper.session.timeout.ms", "2000")
    consumerProps.put("zookeeper.connection.timeout.ms", "6000")
    consumerProps.put("zookeeper.sync.time.ms", "2000")
    consumerProps.put("auto.commit.interval.ms", "2000")

    val consumer = kafka.consumer.Consumer.createJavaConsumerConnector(new ConsumerConfig(consumerProps))

    try {
      val topicCountMap = Map(topic -> new Integer(1))
      val consumerMap = consumer.createMessageStreams(topicCountMap)
      val stream = consumerMap.get(topic).get(0)
      val it = stream.iterator()
      val mess = it.next
      assert(new String(mess.message().map(_.toChar)) === message)
    } finally {
      consumer.shutdown()
    }
  }

} 
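In the example above, the wildcard import of scala.collection.JavaConversions._ is what lets the Scala Map literal topicCountMap be passed to createMessageStreams, which expects a java.util.Map[String, Integer]. With the explicit converters that replaced JavaConversions, the same calls would read roughly as follows (a sketch against the same consumer API used above):

import scala.collection.JavaConverters._

val topicCountMap = Map(topic -> Integer.valueOf(1)).asJava
val consumerMap   = consumer.createMessageStreams(topicCountMap)
val stream        = consumerMap.get(topic).get(0)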
Example 44
Source File: TestBase.scala    From open-korean-text   with Apache License 2.0 5 votes vote down vote up
package org.openkoreantext.processor

import java.util.logging.{Level, Logger}

import org.junit.runner.RunWith
import org.openkoreantext.processor.util.KoreanDictionaryProvider._
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

object TestBase {

  case class ParseTime(time: Long, chunk: String)

  def time[R](block: => R): Long = {
    val t0 = System.currentTimeMillis()
    block
    val t1 = System.currentTimeMillis()
    t1 - t0
  }

  def assertExamples(exampleFiles: String, log: Logger, f: (String => String)) {
    assert({
      val input = readFileByLineFromResources(exampleFiles)

      val (parseTimes, hasErrors) = input.foldLeft((List[ParseTime](), true)) {
        case ((l: List[ParseTime], output: Boolean), line: String) =>
          val s = line.split("\t")
          val (chunk, parse) = (s(0), if (s.length == 2) s(1) else "")

          val oldTokens = parse
          val t0 = System.currentTimeMillis()
          val newTokens = f(chunk)
          val t1 = System.currentTimeMillis()

          val oldParseMatches = oldTokens == newTokens

          if (!oldParseMatches) {
            System.err.println("Example set match error: %s \n - EXPECTED: %s\n - ACTUAL  : %s".format(
              chunk, oldTokens, newTokens))
          }

          (ParseTime(t1 - t0, chunk) :: l, output && oldParseMatches)
      }

      val averageTime = parseTimes.map(_.time).sum.toDouble / parseTimes.size
      val maxItem = parseTimes.maxBy(_.time)

      log.log(Level.INFO, ("Parsed %d chunks. \n" +
          "       Total time: %d ms \n" +
          "       Average time: %.2f ms \n" +
          "       Max time: %d ms, %s").format(
            parseTimes.size,
            parseTimes.map(_.time).sum,
            averageTime,
            maxItem.time,
            maxItem.chunk
          ))
      hasErrors
    }, "Some parses did not match the example set.")
  }
}

@RunWith(classOf[JUnitRunner])
abstract class TestBase extends FunSuite 
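TestBase.assertExamples above reads a tab-separated resource file of chunk/expected-parse pairs, applies the supplied String => String function to each chunk, logs per-chunk timing statistics, and fails the assertion if any parse differs from the recorded one. A minimal sketch of a concrete suite built on it; the resource name and the normalizeAndTokenize helper are illustrative placeholders, not part of the original project:

package org.openkoreantext.processor

import java.util.logging.Logger

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ExampleSetRegressionTest extends TestBase {
  private val log = Logger.getLogger(getClass.getSimpleName)

  test("example set parses without regressions") {
    // "current_parsing.txt" is a hypothetical resource bundled with the tests
    TestBase.assertExamples("current_parsing.txt", log, chunk => normalizeAndTokenize(chunk))
  }

  // stand-in for whatever String => String transformation is under test
  private def normalizeAndTokenize(chunk: String): String = chunk.trim
}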
Example 45
Source File: ITSelectorSuite.scala    From spark-infotheoretic-feature-selection   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.sql.{DataFrame, SQLContext}
import org.junit.runner.RunWith
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.junit.JUnitRunner
import TestHelper._



  test("Run ITFS on nci data (nPart = 10, nfeat = 10)") {

    val df = readCSVData(sqlContext, "test_nci9_s3.csv")
    val cols = df.columns
    val pad = 2
    val allVectorsDense = true
    val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head, 
        10, 10, allVectorsDense, pad)

    assertResult("443, 755, 1369, 1699, 3483, 5641, 6290, 7674, 9399, 9576") {
      model.selectedFeatures.mkString(", ")
    }
  }
} 
Example 46
Source File: Downloader$Test.scala    From mystem-scala   with MIT License 5 votes vote down vote up
package ru.stachek66.tools

import java.io.File
import java.net.URL

import org.junit.runner.RunWith
import org.scalatest.{Ignore, FunSuite}
import org.scalatest.junit.JUnitRunner


@Ignore
class Downloader$Test extends FunSuite {

  test("downloading-something") {

    val hello = new File("hello-test.html")
    val mystem = new File("atmta.binary")

    Downloader.downloadBinaryFile(new URL("http://www.stachek66.ru/"), hello)

    Downloader.downloadBinaryFile(
      new URL("http://download.cdn.yandex.net/mystem/mystem-3.0-linux3.1-64bit.tar.gz"),
      mystem
    )

    Downloader.downloadBinaryFile(
      new URL("http://download.cdn.yandex.net/mystem/mystem-3.1-win-64bit.zip"),
      mystem
    )

    hello.delete
    mystem.delete
  }

  test("download-and-unpack") {
    val bin = new File("atmta.binary.tar.gz")
    val bin2 = new File("executable")

    Decompressor.select.unpack(
      Downloader.downloadBinaryFile(
        new URL("http://download.cdn.yandex.net/mystem/mystem-3.0-linux3.1-64bit.tar.gz"),
        bin),
      bin2
    )

    bin.delete
    bin2.delete
  }
} 
Example 47
Source File: Zip$Test.scala    From mystem-scala   with MIT License 5 votes vote down vote up
package ru.stachek66.tools

import java.io.{File, FileInputStream}

import org.apache.commons.io.IOUtils
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

import org.junit.runner.RunWith


class Zip$Test extends FunSuite {

  test("zip-test") {
    val src = new File("src/test/resources/test.txt")
    Zip.unpack(
      new File("src/test/resources/test.zip"),
      new File("src/test/resources/res.txt")) match {
      case f =>
        val content0 = IOUtils.toString(new FileInputStream(f))
        val content1 = IOUtils.toString(new FileInputStream(src))
        print(content0.trim + " vs " + content1.trim)
        assert(content0 === content1)
    }
  }

} 
Example 48
Source File: TarGz$Test.scala    From mystem-scala   with MIT License 5 votes vote down vote up
package ru.stachek66.tools

import java.io.{File, FileInputStream}

import org.apache.commons.io.IOUtils
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner


class TarGz$Test extends FunSuite {

  test("tgz-test") {
    val src = new File("src/test/resources/test.txt")
    TarGz.unpack(
      new File("src/test/resources/test.tar.gz"),
      new File("src/test/resources/res.txt")) match {
      case f =>
        val content0 = IOUtils.toString(new FileInputStream(f))
        val content1 = IOUtils.toString(new FileInputStream(src))
        print(content0.trim + " vs " + content1.trim)
        assert(content0 === content1)
    }
  }
} 
Example 49
Source File: DaoServiceTest.scala    From Scala-Design-Patterns-Second-Edition   with MIT License 5 votes vote down vote up
package com.ivan.nikolov.scheduler.dao

import com.ivan.nikolov.scheduler.TestEnvironment
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class DaoServiceTest extends FlatSpec with Matchers with BeforeAndAfter with TestEnvironment {

  override val databaseService = new H2DatabaseService
  override val migrationService = new MigrationService
  override val daoService = new DaoServiceImpl
  
  before {
    // Run the migrations here: they generally only deal with the data layout,
    // so individual test classes remain free to insert their own test data.
    migrationService.runMigrations()
  }
  
  after {
    migrationService.cleanupDatabase()
  }
  
  "readResultSet" should "properly iterate over a result set and apply a function to it." in {
    val connection = daoService.getConnection()
    try {
      val result = daoService.executeSelect(
        connection.prepareStatement(
          "SELECT name FROM people"
        )
      ) {
        case rs =>
          daoService.readResultSet(rs) {
            case row =>
              row.getString("name")
          }
      }
      result should have size(3)
      result should contain("Ivan")
      result should contain("Maria")
      result should contain("John")
    } finally {
      connection.close()
      
    }
  }
} 
Example 50
Source File: JobConfigReaderServiceTest.scala    From Scala-Design-Patterns-Second-Edition   with MIT License
package com.ivan.nikolov.scheduler.services

import com.ivan.nikolov.scheduler.TestEnvironment
import com.ivan.nikolov.scheduler.config.job.{Console, Daily, JobConfig, TimeOptions}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class JobConfigReaderServiceTest extends FlatSpec with Matchers with TestEnvironment {

  override val ioService: IOService = new IOService
  override val jobConfigReaderService: JobConfigReaderService = new JobConfigReaderService

  "readJobConfigs" should "read and parse configurations successfully." in {
    val result = jobConfigReaderService.readJobConfigs()
    result should have size(1)
    result should contain(
      JobConfig(
        "Test Command",
        "ping google.com -c 10",
        Console,
        Daily,
        TimeOptions(12, 10)
      )
    )
  }
} 
Example 51
Source File: TimeOptionsTest.scala    From Scala-Design-Patterns-Second-Edition   with MIT License 5 votes vote down vote up
package com.ivan.nikolov.scheduler.config.job

import java.time.LocalDateTime

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class TimeOptionsTest extends FlatSpec with Matchers {

  "getInitialDelay" should "get the right initial delay for hourly less than an hour after now." in {
    val now = LocalDateTime.of(2018, 3, 20, 12, 43, 10)
    val later = now.plusMinutes(20)
    val timeOptions = TimeOptions(later.getHour, later.getMinute)
    val result = timeOptions.getInitialDelay(now, Hourly)
    result.toMinutes should equal(20)
  }
  
  it should "get the right initial delay for hourly more than an hour after now." in {
    val now = LocalDateTime.of(2018, 3, 20, 18, 51, 17)
    val later = now.plusHours(3)
    val timeOptions = TimeOptions(later.getHour, later.getMinute)
    val result = timeOptions.getInitialDelay(now, Hourly)
    result.toHours should equal(3)
  }
  
  it should "get the right initial delay for hourly less than an hour before now." in {
    val now = LocalDateTime.of(2018, 3, 20, 11, 18, 55)
    val earlier = now.minusMinutes(25)
    // only check when the shifted time is still on the same day; otherwise the expected delay would differ and the assertion would fail.
    if (earlier.getDayOfWeek == now.getDayOfWeek) {
      val timeOptions = TimeOptions(earlier.getHour, earlier.getMinute)
      val result = timeOptions.getInitialDelay(now, Hourly)
      result.toMinutes should equal(35)
    }
  }
  
  it should "get the right initial delay for hourly more than an hour before now." in {
    val now = LocalDateTime.of(2018, 3, 20, 12, 43, 59)
    val earlier = now.minusHours(1).minusMinutes(25)
    // only check when the shifted time is still on the same day; otherwise the expected delay would differ and the assertion would fail.
    if (earlier.getDayOfWeek == now.getDayOfWeek) {
      val timeOptions = TimeOptions(earlier.getHour, earlier.getMinute)
      val result = timeOptions.getInitialDelay(now, Hourly)
      result.toMinutes should equal(35)
    }
  }
  
  it should "get the right initial delay for daily before now." in {
    val now = LocalDateTime.of(2018, 3, 20, 14, 43, 10)
    val earlier = now.minusMinutes(25)
    // only check when the shifted time is still on the same day; otherwise the expected delay would differ and the assertion would fail.
    if (earlier.getDayOfWeek == now.getDayOfWeek) {
      val timeOptions = TimeOptions(earlier.getHour, earlier.getMinute)
      val result = timeOptions.getInitialDelay(now, Daily)
      result.toMinutes should equal(24 * 60 - 25)
    }
  }
  
  it should "get the right initial delay for daily after now." in {
    val now = LocalDateTime.of(2018, 3, 20, 16, 21, 6)
    val later = now.plusMinutes(20)
    val timeOptions = TimeOptions(later.getHour, later.getMinute)
    val result = timeOptions.getInitialDelay(now, Daily)
    result.toMinutes should equal(20)
  }
} 
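The expectations above pin down the behaviour of TimeOptions.getInitialDelay: take the difference from now to today's occurrence of the configured HH:mm and, if that moment has already passed, roll forward by whole periods (one hour for Hourly, one day for Daily) until the delay becomes non-negative. A minimal reimplementation sketch consistent with these tests; the real classes live in com.ivan.nikolov.scheduler.config.job and may differ in detail:

import java.time.{Duration, LocalDateTime}

sealed trait Frequency { def period: Duration }
case object Hourly extends Frequency { val period: Duration = Duration.ofHours(1) }
case object Daily extends Frequency { val period: Duration = Duration.ofDays(1) }

case class TimeOptions(hours: Int, minutes: Int) {
  def getInitialDelay(now: LocalDateTime, frequency: Frequency): Duration = {
    val target = now.withHour(hours).withMinute(minutes)
    var delay = Duration.between(now, target)
    // configured time already passed today: roll forward period by period
    while (delay.isNegative) {
      delay = delay.plus(frequency.period)
    }
    delay
  }
}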
Example 52
Source File: TraitATest.scala    From Scala-Design-Patterns-Second-Edition   with MIT License 5 votes vote down vote up
package com.ivan.nikolov.composition

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class TraitATest extends FlatSpec with Matchers with A {

  "hello" should "greet properly." in {
    hello() should equal("Hello, I am trait A!")
  }
  
  "pass" should "return the right string with the number." in {
    pass(10) should equal("Trait A said: 'You passed 10.'")
  }
  
  it should "be correct also for negative values." in {
    pass(-10) should equal("Trait A said: 'You passed -10.'")
  }
} 
Example 53
Source File: TraitACaseScopeTest.scala    From Scala-Design-Patterns-Second-Edition   with MIT License 5 votes vote down vote up
package com.ivan.nikolov.composition

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class TraitACaseScopeTest extends FlatSpec with Matchers {
  "hello" should "greet properly." in new A {
    hello() should equal("Hello, I am trait A!")
  }

  "pass" should "return the right string with the number." in new A {
    pass(10) should equal("Trait A said: 'You passed 10.'")
  }

  it should "be correct also for negative values." in new A {
    pass(-10) should equal("Trait A said: 'You passed -10.'")
  }
} 
Example 54
Source File: UserComponentTest.scala    From Scala-Design-Patterns-Second-Edition   with MIT License 5 votes vote down vote up
package com.ivan.nikolov.cake

import com.ivan.nikolov.cake.model.Person
import org.junit.runner.RunWith
import org.mockito.Mockito._
import org.scalatest.junit.JUnitRunner
import org.scalatest.mockito.MockitoSugar
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class UserComponentTest extends FlatSpec with Matchers with MockitoSugar with TestEnvironment {
  val className = "A"
  val emptyClassName = "B"
  val people = List(
    Person(1, "a", 10),
    Person(2, "b", 15),
    Person(3, "c", 20)
  )
  
  override val userService = new UserService
  
  when(dao.getPeopleInClass(className)).thenReturn(people)
  when(dao.getPeopleInClass(emptyClassName)).thenReturn(List())
  
  "getAverageAgeOfUsersInClass" should "properly calculate the average of all ages." in {
    userService.getAverageAgeOfUsersInClass(className) should equal(15.0)
  }
  
  it should "properly handle an empty result." in {
    userService.getAverageAgeOfUsersInClass(emptyClassName) should equal(0.0)
  }
} 
Example 55
Source File: MetricsStatsReceiverTest.scala    From finagle-metrics   with MIT License 5 votes vote down vote up
package com.twitter.finagle.metrics

import com.twitter.finagle.metrics.MetricsStatsReceiver._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.FunSuite

@RunWith(classOf[JUnitRunner])
class MetricsStatsReceiverTest extends FunSuite {

  private[this] val receiver = new MetricsStatsReceiver()

  private[this] def readGauge(name: String): Option[Number] =
    Option(metrics.getGauges.get(name)) match {
      case Some(gauge) => Some(gauge.getValue.asInstanceOf[Float])
      case _ => None
    }

  private[this] def readCounter(name: String): Option[Number] =
    Option(metrics.getMeters.get(name)) match {
      case Some(counter) => Some(counter.getCount)
      case _ => None
    }

  private[this] def readStat(name: String): Option[Number] =
    Option(metrics.getHistograms.get(name)) match {
      case Some(stat) => Some(stat.getSnapshot.getValues.toSeq.sum)
      case _ => None
    }

  test("MetricsStatsReceiver should store and read gauge into the Codahale Metrics library") {
    val x = 1.5f
    receiver.addGauge("my_gauge")(x)

    assert(readGauge("my_gauge") === Some(x))
  }

  test("MetricsStatsReceiver should always assume the latest value of an already created gauge") {
    val gaugeName = "my_gauge2"
    val expectedValue = 8.8f

    receiver.addGauge(gaugeName)(2.2f)
    receiver.addGauge(gaugeName)(9.9f)
    receiver.addGauge(gaugeName)(expectedValue)

    assert(readGauge(gaugeName) === Some(expectedValue))
  }

  test("MetricsStatsReceiver should store and remove gauge into the Codahale Metrics Library") {
    val gaugeName = "temp-gauge"
    val expectedValue = 2.8f

    val tempGauge = receiver.addGauge(gaugeName)(expectedValue)
    assert(readGauge(gaugeName) === Some(expectedValue))

    tempGauge.remove()

    assert(readGauge(gaugeName) === None)
  }

  test("MetricsStatsReceiver should store and read stat into the Codahale Metrics library") {
    val x = 1
    val y = 3
    val z = 5

    val s = receiver.stat("my_stat")
    s.add(x)
    s.add(y)
    s.add(z)

    assert(readStat("my_stat") === Some(x + y + z))
  }

  test("MetricsStatsReceiver should store and read counter into the Codahale Metrics library") {
    val x = 2
    val y = 5
    val z = 8

    val c = receiver.counter("my_counter")
    c.incr(x)
    c.incr(y)
    c.incr(z)

    assert(readCounter("my_counter") === Some(x + y + z))
  }

} 
Example 56
Source File: ReceiverWithoutOffsetIT.scala    From datasource-receiver   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.datasource

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.streaming.datasource.models.{InputSentences, StopConditions}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ReceiverWithoutOffsetIT extends TemporalDataSuite {

  test("DataSource Receiver should read all the records on each batch without offset conditions") {
    sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    val rdd = sc.parallelize(registers)
    sqlContext.createDataFrame(rdd, schema).registerTempTable(tableName)
    ssc = new StreamingContext(sc, Seconds(1))
    val totalEvents = ssc.sparkContext.accumulator(0L, "Number of events received")
    val inputSentences = InputSentences(
      s"select * from $tableName",
      StopConditions(stopWhenEmpty = true, finishContextWhenEmpty = true),
      initialStatements = Seq.empty[String]
    )
    val distributedStream = DatasourceUtils.createStream(ssc, inputSentences, datasourceParams)

    distributedStream.start()
    distributedStream.foreachRDD(rdd => {
      val streamingEvents = rdd.count()
      log.info(s" EVENTS COUNT : \t $streamingEvents")
      totalEvents += streamingEvents
      log.info(s" TOTAL EVENTS : \t $totalEvents")
      if (!rdd.isEmpty())
        assert(streamingEvents === totalRegisters.toLong)
    })
    ssc.start()
    ssc.awaitTerminationOrTimeout(10000L)

    assert(totalEvents.value === totalRegisters.toLong * 10)
  }
} 
Example 57
Source File: ReceiverNotStopContextIT.scala    From datasource-receiver   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.datasource

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.streaming.datasource.models.{InputSentences, OffsetConditions, OffsetField}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ReceiverNotStopContextIT extends TemporalDataSuite {

  test("DataSource Receiver should read all the records in one streaming batch") {
    sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    val rdd = sc.parallelize(registers)
    sqlContext.createDataFrame(rdd, schema).registerTempTable(tableName)
    ssc = new StreamingContext(sc, Seconds(1))
    val totalEvents = ssc.sparkContext.accumulator(0L, "Number of events received")
    val inputSentences = InputSentences(
      s"select * from $tableName",
      OffsetConditions(OffsetField("idInt")),
      initialStatements = Seq.empty[String]
    )
    val distributedStream = DatasourceUtils.createStream(ssc, inputSentences, datasourceParams)

    distributedStream.start()
    distributedStream.foreachRDD(rdd => {
      totalEvents += rdd.count()
    })
    ssc.start()
    ssc.awaitTerminationOrTimeout(15000L)

    assert(totalEvents.value === totalRegisters.toLong)
  }
} 
Example 58
Source File: ReceiverLimitedIT.scala    From datasource-receiver   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.datasource

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.streaming.datasource.models.{InputSentences, OffsetConditions, OffsetField, StopConditions}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ReceiverLimitedIT extends TemporalDataSuite {

  test("DataSource Receiver should read the records limited on each batch") {
    sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    val rdd = sc.parallelize(registers)
    sqlContext.createDataFrame(rdd, schema).registerTempTable(tableName)

    ssc = new StreamingContext(sc, Seconds(1))
    val totalEvents = ssc.sparkContext.accumulator(0L, "Number of events received")
    val inputSentences = InputSentences(
      s"select * from $tableName",
      OffsetConditions(OffsetField("idInt"), limitRecords = 1000),
      StopConditions(stopWhenEmpty = true, finishContextWhenEmpty = true),
      initialStatements = Seq.empty[String]
    )
    val distributedStream = DatasourceUtils.createStream(ssc, inputSentences, datasourceParams)

    // Start up the receiver.
    distributedStream.start()

    // Fires each time the configured window has passed.
    distributedStream.foreachRDD(rdd => {
      totalEvents += rdd.count()
    })

    ssc.start() // Start the computation
    ssc.awaitTerminationOrTimeout(15000L) // Wait for the computation to terminate

    assert(totalEvents.value === totalRegisters.toLong)
  }
} 
Example 59
Source File: ReceiverBasicIT.scala    From datasource-receiver   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.datasource

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.streaming.datasource.models.{InputSentences, OffsetConditions, OffsetField, StopConditions}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class ReceiverBasicIT extends TemporalDataSuite {

  test ("DataSource Receiver should read all the records in one streaming batch") {
    sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    val rdd = sc.parallelize(registers)
    sqlContext.createDataFrame(rdd, schema).registerTempTable(tableName)
    ssc = new StreamingContext(sc, Seconds(1))
    val totalEvents = ssc.sparkContext.accumulator(0L, "Number of events received")
    val inputSentences = InputSentences(
      s"select * from $tableName",
      OffsetConditions(OffsetField("idInt")),
      StopConditions(stopWhenEmpty = true, finishContextWhenEmpty = true),
      initialStatements = Seq.empty[String]
    )
    val distributedStream = DatasourceUtils.createStream(ssc, inputSentences, datasourceParams)

    distributedStream.start()
    distributedStream.foreachRDD(rdd => {
      val streamingEvents = rdd.count()
      log.info(s" EVENTS COUNT : \t $streamingEvents")
      totalEvents += streamingEvents
      log.info(s" TOTAL EVENTS : \t $totalEvents")
      val streamingRegisters = rdd.collect()
      if (!rdd.isEmpty())
        assert(streamingRegisters === registers.reverse)
    })
    ssc.start()
    ssc.awaitTerminationOrTimeout(15000L)

    assert(totalEvents.value === totalRegisters.toLong)
  }
} 
Example 60
Source File: DefaultSaverITCase.scala    From flink-tensorflow   with Apache License 2.0 5 votes vote down vote up
package org.apache.flink.contrib.tensorflow.io

import org.apache.flink.contrib.tensorflow.models.savedmodel.DefaultSavedModelLoader
import org.apache.flink.contrib.tensorflow.util.{FlinkTestBase, RegistrationUtils}
import org.apache.flink.core.fs.Path
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpecLike}
import org.tensorflow.{Session, Tensor}

import scala.collection.JavaConverters._

@RunWith(classOf[JUnitRunner])
class DefaultSaverITCase extends WordSpecLike
  with Matchers
  with FlinkTestBase {

  override val parallelism = 1

  "A DefaultSaver" should {
    "run the save op" in {
      val env = StreamExecutionEnvironment.getExecutionEnvironment
      RegistrationUtils.registerTypes(env.getConfig)

      val loader = new DefaultSavedModelLoader(new Path("../models/half_plus_two"), "serve")
      val bundle = loader.load()
      val saverDef = loader.metagraph.getSaverDef
      val saver = new DefaultSaver(saverDef)

      def getA = getVariable(bundle.session(), "a").floatValue()
      def setA(value: Float) = setVariable(bundle.session(), "a", Tensor.create(value))

      val initialA = getA
      println("Initial value: " + initialA)

      setA(1.0f)
      val savePath = tempFolder.newFolder("model-0").getAbsolutePath
      val path = saver.save(bundle.session(), savePath)
      val savedA = getA
      savedA shouldBe (1.0f)
      println("Saved value: " + getA)

      setA(2.0f)
      val updatedA = getA
      updatedA shouldBe (2.0f)
      println("Updated value: " + updatedA)

      saver.restore(bundle.session(), path)
      val restoredA = getA
      restoredA shouldBe (savedA)
      println("Restored value: " + restoredA)
    }

    def getVariable(sess: Session, name: String): Tensor = {
      val result = sess.runner().fetch(name).run().asScala
      result.head
    }

    def setVariable(sess: Session, name: String, value: Tensor): Unit = {
      sess.runner()
        .addTarget(s"$name/Assign")
        .feed(s"$name/initial_value", value)
        .run()
    }
  }
} 
Example 61
Source File: ArraysTest.scala    From flink-tensorflow   with Apache License 2.0 5 votes vote down vote up
package org.tensorflow.contrib.scala

import com.twitter.bijection.Conversion._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpecLike}
import org.tensorflow.contrib.scala.Arrays._
import org.tensorflow.contrib.scala.Rank._
import resource._

@RunWith(classOf[JUnitRunner])
class ArraysTest extends WordSpecLike
  with Matchers {

  "Arrays" when {
    "Array[Float]" should {
      "convert to Tensor[`1D`,Float]" in {
        val expected = Array(1f,2f,3f)
        managed(expected.as[TypedTensor[`1D`,Float]]).foreach { t =>
          t.shape shouldEqual Array(expected.length)
          val actual = t.as[Array[Float]]
          actual shouldEqual expected
        }
      }
    }
  }
} 
Example 62
Source File: TestRenaming.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.lir

import at.forsyte.apalache.tla.lir.transformations.impl.TrackerWithListeners
import at.forsyte.apalache.tla.lir.transformations.standard.Renaming
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterEach, FunSuite}


@RunWith(classOf[JUnitRunner])
class TestRenaming extends FunSuite with BeforeAndAfterEach with TestingPredefs {
  import at.forsyte.apalache.tla.lir.Builder._

  private var renaming = new Renaming(TrackerWithListeners())

  override protected def beforeEach(): Unit = {
    renaming = new Renaming(TrackerWithListeners())
  }

  test("test renaming exists/forall") {
    val original =
        and(
          exists(n_x, n_S, gt(n_x, int(1))),
          forall(n_x, n_T, lt(n_x, int(42))))
    ///
    val expected =
      and(
        exists(name("x_1"), n_S, gt(name("x_1"), int(1))),
        forall(name("x_2"), n_T, lt(name("x_2"), int(42))))
    val renamed = renaming.renameBindingsUnique(original)
    assert(expected == renamed)
  }

  test("test renaming filter") {
    val original =
        cup(
          filter(name("x"), name("S"), eql(name("x"), int(1))),
          filter(name("x"), name("S"), eql(name("x"), int(2)))
        )
    val expected =
      cup(
        filter(name("x_1"), name("S"), eql(name("x_1"), int(1))),
        filter(name("x_2"), name("S"), eql(name("x_2"), int(2))))
    val renamed = renaming.renameBindingsUnique(original)
    assert(expected == renamed)
  }

  test( "Test renaming LET-IN" ) {
    // LET p(t) == \A x \in S . R(t,x) IN \E x \in S . p(x)
    val original =
      letIn(
        exists( n_x, n_S, appOp( name( "p" ), n_x ) ),
        declOp( "p", forall( n_x, n_S, appOp( name( "R" ), name( "t" ), n_x ) ), "t" )
      )

    val expected =
      letIn(
        exists( name( "x_2" ), n_S, appOp( name( "p_1" ), name( "x_2" ) ) ),
        declOp( "p_1", forall( name( "x_1" ), n_S, appOp( name( "R" ), name( "t_1" ), name( "x_1" ) ) ), "t_1" )
      )

    val actual = renaming( original )

    assert(expected == actual)
  }

  test( "Test renaming multiple LET-IN" ) {
    // LET X == TRUE IN X /\ LET X == FALSE IN X
    val original =
      and(
        letIn(
          appOp( name( "X" ) ),
          declOp( "X", trueEx )
        ),
        letIn(
          appOp( name( "X" ) ),
          declOp( "X", falseEx )
        )
      )

    val expected =
      and(
      letIn(
        appOp( name( "X_1" ) ),
        declOp( "X_1", trueEx )
      ),
      letIn(
        appOp( name( "X_2" ) ),
        declOp( "X_2", falseEx )
      )
    )

    val actual = renaming( original )

    assert(expected == actual)
  }

} 
Example 63
Source File: TestLirValues.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.lir

import at.forsyte.apalache.tla.lir.values._
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class TestLirValues extends FunSuite {
  test("create booleans") {
    val b = TlaBool(false)
    assert(!b.value)
  }

  test("create int") {
    val i = TlaInt(1)
    assert(i.value == BigInt(1))
    assert(i == TlaInt(1))
    assert(i.isNatural)
    assert(TlaInt(0).isNatural)
    assert(!TlaInt(-1).isNatural)
  }

  test("create a string") {
    val s = TlaStr("hello")
    assert(s.value == "hello")
  }


  test("create a constant") {
    val c = new TlaConstDecl("x")
    assert("x" == c.name)
  }

  test("create a variable") {
    val c = new TlaVarDecl("x")
    assert("x" == c.name)
  }
} 
Example 64
Source File: TestAux.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.lir

import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

@RunWith( classOf[JUnitRunner] )
class TestAux extends FunSuite with TestingPredefs {

  test( "Test aux::collectSegments" ){

    val ar0Decl1 = TlaOperDecl( "X", List.empty, n_x )
    val ar0Decl2 = TlaOperDecl( "Y", List.empty, n_y )
    val ar0Decl3 = TlaOperDecl( "Z", List.empty, n_z )

    val arGe0Decl1 = TlaOperDecl( "A", List( SimpleFormalParam( "t" ) ), n_a )
    val arGe0Decl2 = TlaOperDecl( "B", List( SimpleFormalParam( "t" ) ), n_b )
    val arGe0Decl3 = TlaOperDecl( "C", List( SimpleFormalParam( "t" ) ), n_c )

    val pa1 =
      List( ar0Decl1 ) ->
        List( List( ar0Decl1 ) )
    val pa2 =
      List( ar0Decl1, ar0Decl2 ) ->
        List( List( ar0Decl1, ar0Decl2 ) )
    val pa3 =
      List( arGe0Decl1, ar0Decl1 ) ->
        List( List( arGe0Decl1 ), List( ar0Decl1 ) )
    val pa4 =
      List( arGe0Decl1, arGe0Decl2 ) ->
        List( List( arGe0Decl1, arGe0Decl2 ) )
    val pa5 =
      List( arGe0Decl1, arGe0Decl2, ar0Decl1, ar0Decl2, arGe0Decl3 ) ->
        List( List( arGe0Decl1, arGe0Decl2 ), List( ar0Decl1, ar0Decl2 ), List( arGe0Decl3 ) )

    val expected = Seq(
      pa1, pa2, pa3, pa4, pa5
    )
    val cmp = expected map { case (k, v) =>
      (v, aux.collectSegments( k ))
    }
    cmp foreach { case (ex, act) =>
      assert( ex == act )
    }
  }
} 
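The pairs above fix the contract of aux.collectSegments: it splits a list of operator declarations into maximal runs of consecutive declarations that share the same arity class (parameterless versus taking at least one parameter). A small free-standing sketch of that grouping, written against plain Lists rather than apalache's own aux object:

// illustrative only; keyed on a caller-supplied "is nullary" predicate
def collectSegments[T](decls: List[T])(isNullary: T => Boolean): List[List[T]] =
  decls.foldRight(List.empty[List[T]]) {
    case (d, (head @ (h :: _)) :: tail) if isNullary(d) == isNullary(h) =>
      (d :: head) :: tail // same arity class: extend the current segment
    case (d, acc) =>
      List(d) :: acc // arity class changes, or first element: start a new segment
  }

// e.g. grouping the declarations above on "has an empty parameter list"
// reproduces the expected segments listed in the test.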
Example 65
Source File: TestTypeReduction.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.types

import at.forsyte.apalache.tla.lir.TestingPredefs
import org.junit.runner.RunWith
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.scalatest.junit.JUnitRunner

@RunWith( classOf[JUnitRunner] )
class TestTypeReduction extends FunSuite with TestingPredefs with BeforeAndAfter {

  var gen = new SmtVarGenerator
  var tr  = new TypeReduction( gen )

  before {
    gen = new SmtVarGenerator
    tr = new TypeReduction( gen )
  }

  test( "Test nesting" ) {
    val tau = FunT( IntT, SetT( IntT ) )
    val m = Map.empty[TypeVar, SmtTypeVariable]
    val rr = tr( tau, m )
    assert( rr.t == fun( int, set( int ) ) )
  }

  test("Test tuples"){
    val tau = SetT( FunT( TupT( IntT, StrT ), SetT( IntT ) ) )
    val m = Map.empty[TypeVar, SmtTypeVariable]
    val rr = tr(tau, m)
    val idx = SmtIntVariable( 0 )
    assert( rr.t == set( fun( tup( idx ), set( int ) ) ) )
    assert( rr.phi.contains( hasIndex( idx, 0, int ) ) )
    assert( rr.phi.contains( hasIndex( idx, 1, str ) ) )
  }
} 
Example 66
Source File: TestSymbStateRewriterStr.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt

import at.forsyte.apalache.tla.lir.convenience.tla
import at.forsyte.apalache.tla.lir.values.TlaStr
import at.forsyte.apalache.tla.lir.{NameEx, ValEx}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestSymbStateRewriterStr extends RewriterBase {
  test("SE-STR-CTOR: \"red\" -> $C$k") {
    val state = new SymbState(ValEx(TlaStr("red")),
      CellTheory(), arena, new Binding)
    val rewriter = create()
    val nextStateRed = rewriter.rewriteUntilDone(state)
    nextStateRed.ex match {
      case predEx@NameEx(name) =>
        assert(CellTheory().hasConst(name))
        assert(CellTheory() == state.theory)
        assert(solverContext.sat())
        val redEqBlue = tla.eql(tla.str("blue"), tla.str("red"))
        val nextStateEq = rewriter.rewriteUntilDone(nextStateRed.setRex(redEqBlue))
        rewriter.push()
        solverContext.assertGroundExpr(nextStateEq.ex)
        assert(!solverContext.sat())
        rewriter.pop()
        solverContext.assertGroundExpr(tla.not(nextStateEq.ex))
        assert(solverContext.sat())


      case _ =>
        fail("Unexpected rewriting result")
    }
  }
} 
Example 67
Source File: TestVCGenerator.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt

import at.forsyte.apalache.tla.imp.SanyImporter
import at.forsyte.apalache.tla.imp.src.SourceStore
import at.forsyte.apalache.tla.lir.transformations.impl.IdleTracker
import at.forsyte.apalache.tla.lir.{TlaModule, TlaOperDecl}
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

import scala.io.Source

@RunWith(classOf[JUnitRunner])
class TestVCGenerator extends FunSuite {
  private def mkVCGen(): VCGenerator = {
    new VCGenerator(new IdleTracker)
  }

  test("simple invariant") {
    val text =
      """---- MODULE inv ----
        |EXTENDS Integers
        |VARIABLE x
        |Inv == x > 0
        |====================
      """.stripMargin

    val mod = loadFromText("inv", text)
    val newMod = mkVCGen().gen(mod, "Inv")
    assertDecl(newMod, "VCInv$0", "x > 0")
    assertDecl(newMod, "VCNotInv$0", "¬(x > 0)")
  }

  test("conjunctive invariant") {
    val text =
      """---- MODULE inv ----
        |EXTENDS Integers
        |VARIABLE x
        |Inv == x > 0 /\ x < 10
        |====================
      """.stripMargin

    val mod = loadFromText("inv", text)
    val newMod = mkVCGen().gen(mod, "Inv")
    assertDecl(newMod, "VCInv$0", "x > 0")
    assertDecl(newMod, "VCInv$1", "x < 10")
    assertDecl(newMod, "VCNotInv$0", "¬(x > 0)")
    assertDecl(newMod, "VCNotInv$1", "¬(x < 10)")
  }

  test("conjunction under universals") {
    val text =
      """---- MODULE inv ----
        |EXTENDS Integers
        |VARIABLE x, S
        |Inv == \A z \in S: \A y \in S: y > 0 /\ y < 10
        |====================
      """.stripMargin

    val mod = loadFromText("inv", text)
    val newMod = mkVCGen().gen(mod, "Inv")
    assertDecl(newMod, "VCInv$0", """∀z ∈ S: (∀y ∈ S: (y > 0))""")
    assertDecl(newMod, "VCInv$1", """∀z ∈ S: (∀y ∈ S: (y < 10))""")
    assertDecl(newMod, "VCNotInv$0", """¬(∀z ∈ S: (∀y ∈ S: (y > 0)))""")
    assertDecl(newMod, "VCNotInv$1", """¬(∀z ∈ S: (∀y ∈ S: (y < 10)))""")
  }

  private def assertDecl(mod: TlaModule, name: String, expectedBodyText: String): Unit = {
    val vc = mod.declarations.find(_.name == name)
    assert(vc.nonEmpty, s"(VC $name not found)")
    assert(vc.get.isInstanceOf[TlaOperDecl])
    assert(vc.get.asInstanceOf[TlaOperDecl].body.toString == expectedBodyText)
  }

  private def loadFromText(moduleName: String, text: String): TlaModule = {
    val locationStore = new SourceStore
    val (rootName, modules) = new SanyImporter(locationStore)
      .loadFromSource(moduleName, Source.fromString(text))
    modules(moduleName)
  }
} 
Example 68
Source File: TestTypeInference.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt

import at.forsyte.apalache.tla.bmcmt.types.{Signatures, TypeInference}
import at.forsyte.apalache.tla.lir.TestingPredefs
import at.forsyte.apalache.tla.lir.convenience._
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

// TODO: remove?
@RunWith( classOf[JUnitRunner] )
class TestTypeInference extends FunSuite with TestingPredefs {

  ignore( "Signatures" ) {
    val exs = List(
      tla.and( n_x, n_y ),
      tla.choose( n_x, n_S, n_p ),
      tla.enumSet( seq( 10 ) : _* ),
      tla.in( n_x, n_S ),
      tla.map( n_e, n_x, n_S )
    )

    val sigs = exs map Signatures.get

    exs zip sigs foreach { case (x, y) => println( s"${x}  ...  ${y}" ) }

    val funDef = tla.funDef( tla.plus( n_x, n_y ), n_x, n_S, n_y, n_T )

    val sig = Signatures.get( funDef )

    printsep()
    println( sig )
    printsep()
  }

  ignore( "TypeInference" ) {
    val ex = tla.and( tla.primeEq( n_a, tla.choose( n_x, n_S, n_p ) ), tla.in( 2, n_S ) )

    val r = TypeInference.theta( ex )

    println( r )

  }

  ignore( "Application" ) {

    val ex = tla.eql( tla.plus(  tla.appFun( n_f, n_x ) , 2), 4 )
    val ex2 =
      tla.and(
        tla.in( n_x, n_S ),
        tla.le(
          tla.plus(
            tla.mult( 2, n_x ),
            5
          ),
          10
        ),
        tla.primeEq( n_x,
          tla.appFun(
            n_f,
            n_x
          )
        )
      )

    val r = TypeInference( ex )
  }
} 
Example 69
Source File: TestSymbStateRewriterChoose.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt

import at.forsyte.apalache.tla.bmcmt.types.{AnnotationParser, FinSetT, IntT}
import at.forsyte.apalache.tla.lir.TestingPredefs
import at.forsyte.apalache.tla.lir.convenience.tla
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestSymbStateRewriterChoose extends RewriterBase with TestingPredefs {
  test("""CHOOSE x \in {1, 2, 3}: x > 1""") {
    val ex = tla.choose(tla.name("x"),
      tla.enumSet(tla.int(1), tla.int(2), tla.int(3)),
      tla.gt(tla.name("x"), tla.int(1)))
    val state = new SymbState(ex, CellTheory(), arena, new Binding)
    val rewriter = create()
    val nextState = rewriter.rewriteUntilDone(state)
    assert(solverContext.sat())
    def assertEq(i: Int): SymbState = {
      val ns = rewriter.rewriteUntilDone(nextState.setRex(tla.eql(nextState.ex, tla.int(i))))
      solverContext.assertGroundExpr(ns.ex)
      ns
    }

    rewriter.push()
    assertEq(3)
    assert(solverContext.sat())
    rewriter.pop()
    rewriter.push()
    assertEq(2)
    assert(solverContext.sat())
    rewriter.pop()
    rewriter.push()
    val ns = assertEq(1)
    assertUnsatOrExplain(rewriter, ns)
  }

  test("""CHOOSE x \in {1}: x > 1""") {
    val ex = tla.choose(tla.name("x"),
      tla.enumSet(tla.int(1)),
      tla.gt(tla.name("x"), tla.int(1)))
    val state = new SymbState(ex, CellTheory(), arena, new Binding)
    val rewriter = create()
    val nextState = rewriter.rewriteUntilDone(state)
    // the buggy implementation of choose fails on a dynamically empty set
    assert(solverContext.sat())
    // The semantics of choose does not restrict the outcome on the empty sets,
    // so we do not test for anything here. Our previous implementation of CHOOSE produced default values in this case,
    // but this happened to be error-prone and sometimes conflicting with other rules. So, no default values.
  }

  test("""CHOOSE x \in {}: x > 1""") {
    val ex = tla.choose(tla.name("x"),
      tla.withType(tla.enumSet(), AnnotationParser.toTla(FinSetT(IntT()))),
      tla.gt(tla.name("x"), tla.int(1)))
    val state = new SymbState(ex, CellTheory(), arena, new Binding)
    val rewriter = create()
    val nextState = rewriter.rewriteUntilDone(state)
    // the buggy implementation of choose fails on a dynamically empty set
    assert(solverContext.sat())
    def assertEq(i: Int): SymbState = {
      val ns = rewriter.rewriteUntilDone(nextState.setRex(tla.eql(nextState.ex, tla.int(i))))
      solverContext.assertGroundExpr(ns.ex)
      ns
    }

    // Actually, semantics of choose does not restrict the outcome on the empty sets.
    // But we know that our implementation would always return 0 in this case.
    val ns = assertEq(1)
    assertUnsatOrExplain(rewriter, ns)
  }
} 
Example 70
Source File: TestSymbStateRewriterFiniteSets.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt

import at.forsyte.apalache.tla.bmcmt.types._
import at.forsyte.apalache.tla.lir.{NameEx, TlaEx}
import at.forsyte.apalache.tla.lir.convenience.tla
import at.forsyte.apalache.tla.lir.oper.TlaFunOper
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestSymbStateRewriterFiniteSets extends RewriterBase {
  test("""Cardinality({1, 2, 3}) = 3""") {
    val set = tla.enumSet(1.to(3).map(tla.int) :_*)
    val card = tla.card(set)
    val state = new SymbState(card, CellTheory(), arena, new Binding)
    val rewriter = create()
    val nextState = rewriter.rewriteUntilDone(state)
    assert(solverContext.sat())
    assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.int(3), nextState.ex)))
  }

  test("""Cardinality({1, 2, 2, 2, 3, 3}) = 3""") {
    val set = tla.enumSet(Seq(1, 2, 2, 2, 3, 3).map(tla.int) :_*)
    val card = tla.card(set)
    val state = new SymbState(card, CellTheory(), arena, new Binding)
    val rewriter = create()
    val nextState = rewriter.rewriteUntilDone(state)
    assert(solverContext.sat())
    assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.int(3), nextState.ex)))
  }

  test("""Cardinality({1, 2, 3} \ {2}) = 2""") {
    def setminus(set: TlaEx, intVal: Int): TlaEx = {
      tla.filter(tla.name("t"),
        set,
        tla.not(tla.eql(tla.name("t"), tla.int(intVal))))
    }

    val set = setminus(tla.enumSet(1.to(3).map(tla.int) :_*), 2)
    val card = tla.card(set)
    val state = new SymbState(card, CellTheory(), arena, new Binding)
    val rewriter = create()
    val nextState = rewriter.rewriteUntilDone(state)
    assert(solverContext.sat())
    assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.int(2), nextState.ex)))
  }

  test("""IsFiniteSet({1, 2, 3}) = TRUE""") {
    val set = tla.enumSet(1.to(3).map(tla.int) :_*)
    val card = tla.isFin(set)
    val state = new SymbState(card, CellTheory(), arena, new Binding)
    val rewriter = create()
    val nextState = rewriter.rewriteUntilDone(state)
    assert(solverContext.sat())
    assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.bool(true), nextState.ex)))
  }

} 
Example 71
Source File: TestUninterpretedConstOracle.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt.rules.aux

import at.forsyte.apalache.tla.bmcmt.types.BoolT
import at.forsyte.apalache.tla.bmcmt.{Binding, CellTheory, RewriterBase, SymbState}
import at.forsyte.apalache.tla.lir.TestingPredefs
import at.forsyte.apalache.tla.lir.convenience.tla
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestUninterpretedConstOracle extends RewriterBase with TestingPredefs {
  test("""Oracle.create""") {
    val rewriter = create()
    var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding)
    // introduce an oracle
    val (nextState, oracle) = UninterpretedConstOracle.create(rewriter, state, 6)
    assert(solverContext.sat())
  }

  test("""Oracle.whenEqualTo""") {
    val rewriter = create()
    var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding)
    // introduce an oracle
    val (nextState, oracle) = UninterpretedConstOracle.create(rewriter, state, 6)
    assert(solverContext.sat())
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 3))
    assert(solverContext.sat())
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 4))
    assert(!solverContext.sat())
  }

  test("""Oracle.evalPosition""") {
    val rewriter = create()
    var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding)
    // introduce an oracle
    val (nextState, oracle) = UninterpretedConstOracle.create(rewriter, state, 6)
    assert(solverContext.sat())
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 3))
    assert(solverContext.sat())
    val position = oracle.evalPosition(rewriter.solverContext, nextState)
    assert(3 == position)
  }

  test("""Oracle.caseAssertions""") {
    val rewriter = create()
    var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding)
    state = state.updateArena(_.appendCell(BoolT()))
    val flag = state.arena.topCell
    // introduce an oracle
    val (nextState, oracle) = UninterpretedConstOracle.create(rewriter, state, 2)
    // assert flag == true iff oracle = 0
    rewriter.solverContext.assertGroundExpr(oracle.caseAssertions(nextState, Seq(flag.toNameEx, tla.not(flag.toNameEx))))
    // assert oracle = 1
    rewriter.push()
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 1))
    assert(solverContext.sat())
    assert(solverContext.evalGroundExpr(flag.toNameEx) == tla.bool(false))
    rewriter.pop()
    // assert oracle = 0
    rewriter.push()
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 0))
    assert(solverContext.sat())
    assert(solverContext.evalGroundExpr(flag.toNameEx) == tla.bool(true))
    rewriter.pop()
  }
} 
Example 72
Source File: TestPropositionalOracle.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt.rules.aux

import at.forsyte.apalache.tla.bmcmt.types.BoolT
import at.forsyte.apalache.tla.bmcmt.{Binding, CellTheory, RewriterBase, SymbState}
import at.forsyte.apalache.tla.lir.TestingPredefs
import at.forsyte.apalache.tla.lir.convenience.tla
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestPropositionalOracle extends RewriterBase with TestingPredefs {
  test("""Oracle.create""") {
    val rewriter = create()
    var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding)
    // introduce an oracle
    val (nextState, oracle) = PropositionalOracle.create(rewriter, state, 6)
    assert(solverContext.sat())
  }

  test("""Oracle.whenEqualTo""") {
    val rewriter = create()
    var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding)
    // introduce an oracle
    val (nextState, oracle) = PropositionalOracle.create(rewriter, state, 6)
    assert(solverContext.sat())
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 3))
    assert(solverContext.sat())
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 4))
    assert(!solverContext.sat())
  }

  test("""Oracle.evalPosition""") {
    val rewriter = create()
    var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding)
    // introduce an oracle
    val (nextState, oracle) = PropositionalOracle.create(rewriter, state, 6)
    assert(solverContext.sat())
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 3))
    assert(solverContext.sat())
    val position = oracle.evalPosition(rewriter.solverContext, nextState)
    assert(3 == position)
  }

  test("""Oracle.caseAssertions""") {
    val rewriter = create()
    var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding)
    state = state.updateArena(_.appendCell(BoolT()))
    val flag = state.arena.topCell
    // introduce an oracle
    val (nextState, oracle) = PropositionalOracle.create(rewriter, state, 2)
    // assert flag == true iff oracle = 0
    rewriter.solverContext.assertGroundExpr(oracle.caseAssertions(nextState, Seq(flag.toNameEx, tla.not(flag.toNameEx))))
    // assert oracle = 1
    rewriter.push()
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 1))
    assert(solverContext.sat())
    assert(solverContext.evalGroundExpr(flag.toNameEx) == tla.bool(false))
    rewriter.pop()
    // assert oracle = 0
    rewriter.push()
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 0))
    assert(solverContext.sat())
    assert(solverContext.evalGroundExpr(flag.toNameEx) == tla.bool(true))
    rewriter.pop()
  }
} 
Example 73
Source File: TestSymbStateRewriterAction.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt

import at.forsyte.apalache.tla.bmcmt.SymbStateRewriter.Continue
import at.forsyte.apalache.tla.bmcmt.types.IntT
import at.forsyte.apalache.tla.lir.NameEx
import at.forsyte.apalache.tla.lir.convenience._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestSymbStateRewriterAction extends RewriterBase {
  test("""SE-PRIME: x' ~~> NameEx(x')""") {
    val rewriter = create()
    arena.appendCell(IntT()) // the type finder is strict about unassigned types, so let's create a cell for x'
    val state = new SymbState(tla.prime(NameEx("x")), CellTheory(), arena, Binding("x'" -> arena.topCell))
    rewriter.rewriteOnce(state) match {
      case Continue(next) =>
        assert(next.ex == NameEx("x'"))

      case _ =>
        fail("Expected x to be renamed to x'")
    }
  }
} 
Example 74
Source File: TestArena.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt

import at.forsyte.apalache.tla.bmcmt.types.{BoolT, FinSetT, UnknownT}
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestArena extends FunSuite {
  test("create cells") {
    val solverContext = new Z3SolverContext()
    val emptyArena = Arena.create(solverContext)
    val arena = emptyArena.appendCell(UnknownT())
    assert(emptyArena.cellCount + 1 == arena.cellCount)
    assert(UnknownT() == arena.topCell.cellType)
    val arena2 = arena.appendCell(BoolT())
    assert(emptyArena.cellCount + 2 == arena2.cellCount)
    assert(BoolT() == arena2.topCell.cellType)
  }

  test("add 'has' edges") {
    val solverContext = new Z3SolverContext()
    val arena = Arena.create(solverContext).appendCell(FinSetT(UnknownT()))
    val set = arena.topCell
    val arena2 = arena.appendCell(BoolT())
    val elem = arena2.topCell
    val arena3 = arena2.appendHas(set, elem)
    assert(List(elem) == arena3.getHas(set))
  }

  test("BOOLEAN has FALSE and TRUE") {
    val solverContext = new Z3SolverContext()
    val arena = Arena.create(solverContext)
    val boolean = arena.cellBooleanSet()
    assert(List(arena.cellFalse(), arena.cellTrue()) == arena.getHas(arena.cellBooleanSet()))
  }
} 
Example 75
Source File: TestSymbStateRewriterExpand.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt

import at.forsyte.apalache.tla.bmcmt.types._
import at.forsyte.apalache.tla.lir.{NameEx, OperEx, TlaEx}
import at.forsyte.apalache.tla.lir.convenience.tla
import at.forsyte.apalache.tla.lir.oper.BmcOper
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestSymbStateRewriterExpand extends RewriterBase {
  test("""Expand(SUBSET {1, 2})""") {
    val baseset = tla.enumSet(tla.int(1), tla.int(2))
    val expandPowset = OperEx(BmcOper.expand, tla.powSet(baseset))
    val state = new SymbState(expandPowset, CellTheory(), arena, new Binding)
    val rewriter = create()
    var nextState = rewriter.rewriteUntilDone(state)
    val powCell = nextState.asCell
    // check equality
    val eq = tla.eql(nextState.ex,
      tla.enumSet(tla.withType(tla.enumSet(), AnnotationParser.toTla(FinSetT(IntT()))),
        tla.enumSet(tla.int(1)),
        tla.enumSet(tla.int(2)),
        tla.enumSet(tla.int(1), tla.int(2))))
    assertTlaExAndRestore(rewriter, nextState.setRex(eq))
  }

  test("""Expand([{1, 2, 3} -> {FALSE, TRUE}]) should fail""") {
    val domain = tla.enumSet(tla.int(1), tla.int(2), tla.int(3))
    val codomain = tla.enumSet(tla.bool(false), tla.bool(true))
    val funSet = OperEx(BmcOper.expand, tla.funSet(domain, codomain))
    val state = new SymbState(funSet, CellTheory(), arena, new Binding)
    val rewriter = create()
    assertThrows[RewriterException](rewriter.rewriteUntilDone(state))
  }

  // Constructing an explicit set of functions is, of course, expensive. But it should work for small values.
  // Left for the future...
  ignore("""Expand([{1, 2} -> {FALSE, TRUE}]) should work""") {
    val domain = tla.enumSet(tla.int(1), tla.int(2))
    val codomain = tla.enumSet(tla.bool(false), tla.bool(true))
    val funSet = OperEx(BmcOper.expand, tla.funSet(domain, codomain))
    val state = new SymbState(funSet, CellTheory(), arena, new Binding)
    val rewriter = create()
    var nextState = rewriter.rewriteUntilDone(state)
    val funSetCell = nextState.asCell
    def mkFun(v1: Boolean, v2: Boolean): TlaEx = {
      val mapEx = tla.ite(tla.eql(NameEx("x"), tla.int(1)), tla.bool(v1), tla.bool(v2))
      tla.funDef(mapEx, tla.name("x"), domain)
    }
    val expected = tla.enumSet(mkFun(false, false), mkFun(false, true),
                               mkFun(true, false), mkFun(true, true))
    assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(expected, funSetCell.toNameEx)))
  }

} 
Example 76
Source File: TestSourceStore.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.imp.src

import at.forsyte.apalache.tla.lir.convenience.tla
import at.forsyte.apalache.tla.lir.src.SourceRegion
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class TestSourceStore extends FunSuite {
  test("basic add and find") {
    val store = new SourceStore()
    val ex = tla.int(1)
    val loc = SourceLocation("root", SourceRegion(1, 2, 3, 4))
    store.addRec(ex, loc)
    val foundLoc = store.find(ex.ID)
    assert(loc == foundLoc.get)
  }

  test("recursive add and find") {
    val store = new SourceStore()
    val int1 = tla.int(1)
    val set = tla.enumSet(int1)
    val loc = SourceLocation("root", SourceRegion(1, 2, 3, 4))
    store.addRec(set, loc)
    val foundLoc = store.find(set.ID)
    assert(loc == foundLoc.get)
    val foundLoc2 = store.find(int1.ID)
    assert(loc == foundLoc2.get)
  }

  test("locations are not overwritten") {
    val store = new SourceStore()
    val int1 = tla.int(1)
    val set = tla.enumSet(int1)
    val set2 = tla.enumSet(set)
    val loc1 = SourceLocation("tada", SourceRegion(100, 200, 300, 400))
    store.addRec(int1, loc1)
    val loc2 = SourceLocation("root", SourceRegion(1, 2, 3, 4))
    store.addRec(set2, loc2)
    assert(loc2 == store.find(set2.ID).get)
    assert(loc2 == store.find(set.ID).get)
    assert(loc1 == store.find(int1.ID).get)
  }
} 
Example 77
Source File: TestRegionTree.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.imp.src

import at.forsyte.apalache.tla.lir.src.{RegionTree, SourcePosition, SourceRegion}
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestRegionTree extends FunSuite {
  test("add") {
    val tree = new RegionTree()
    val region = SourceRegion(SourcePosition(1, 20), SourcePosition(3, 10))
    tree.add(region)
  }

  test("add a subregion, then size") {
    val tree = new RegionTree()
    val reg1 = SourceRegion(SourcePosition(1, 20), SourcePosition(3, 10))
    tree.add(reg1)
    assert(tree.size == 1)
    val reg2 = SourceRegion(SourcePosition(1, 20), SourcePosition(2, 5))
    tree.add(reg2)
    assert(tree.size == 2)
    val reg3 = SourceRegion(SourcePosition(2, 10), SourcePosition(3, 10))
    tree.add(reg3)
    assert(tree.size == 3)
  }

  test("add an overlapping subregion") {
    val tree = new RegionTree()
    val reg1 = SourceRegion(SourcePosition(1, 10), SourcePosition(3, 10))
    tree.add(reg1)
    val reg2 = SourceRegion(SourcePosition(1, 20), SourcePosition(5, 20))
    assertThrows[IllegalArgumentException] {
      tree.add(reg2)
    }
  }

  test("add a small region, then a larger region") {
    val tree = new RegionTree()
    val reg1 = SourceRegion(SourcePosition(2, 10), SourcePosition(3, 10))
    tree.add(reg1)
    val reg2 = SourceRegion(SourcePosition(1, 1), SourcePosition(4, 1))
    tree.add(reg2)
  }

  test("add a region twice") {
    val tree = new RegionTree()
    val reg1 = SourceRegion(SourcePosition(2, 10), SourcePosition(3, 10))
    tree.add(reg1)
    val reg2 = SourceRegion(SourcePosition(2, 10), SourcePosition(3, 10))
    tree.add(reg2)
  }

  test("add and find") {
    val tree = new RegionTree()
    val region = SourceRegion(SourcePosition(1, 20), SourcePosition(3, 10))
    val idx = tree.add(region)
    val found = tree(idx)
    assert(found == region)
  }

  test("find non-existing index") {
    val tree = new RegionTree()
    val region = SourceRegion(SourcePosition(1, 20), SourcePosition(3, 10))
    val idx = tree.add(region)
    assertThrows[IndexOutOfBoundsException] {
      tree(999)
    }
  }
} 
Example 78
Source File: TestConstAndDefRewriter.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.pp

import at.forsyte.apalache.tla.imp.SanyImporter
import at.forsyte.apalache.tla.imp.src.SourceStore
import at.forsyte.apalache.tla.lir.{SimpleFormalParam, TlaOperDecl}
import at.forsyte.apalache.tla.lir.convenience._
import at.forsyte.apalache.tla.lir.transformations.impl.IdleTracker
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterEach, FunSuite}

import scala.io.Source

@RunWith(classOf[JUnitRunner])
class TestConstAndDefRewriter extends FunSuite with BeforeAndAfterEach {
  test("override a constant") {
    val text =
      """---- MODULE const ----
        |CONSTANT n
        |OVERRIDE_n == 10
        |A == {n}
        |================================
      """.stripMargin

    val (rootName, modules) = new SanyImporter(new SourceStore)
      .loadFromSource("const", Source.fromString(text))
    val root = modules(rootName)
    val rewritten = new ConstAndDefRewriter(new IdleTracker())(root)
    assert(rewritten.constDeclarations.isEmpty) // no constants anymore
    assert(rewritten.operDeclarations.size == 2)
    val expected_n = TlaOperDecl("n", List(), tla.int(10))
    assert(expected_n == rewritten.operDeclarations.head)
    val expected_A = TlaOperDecl("A", List(), tla.enumSet(tla.appOp(tla.name("n"))))
    assert(expected_A == rewritten.operDeclarations(1))
  }

  // In TLA+, constants may be operators with multiple arguments.
  // We do not support that yet.
  test("override a constant with a unary operator") {
    val text =
      """---- MODULE const ----
        |CONSTANT n
        |OVERRIDE_n(x) == x
        |A == {n}
        |================================
      """.stripMargin

    val (rootName, modules) = new SanyImporter(new SourceStore)
      .loadFromSource("const", Source.fromString(text))
    val root = modules(rootName)
    assertThrows[OverridingError](new ConstAndDefRewriter(new IdleTracker())(root))
  }

  test("overriding a variable with an operator => error") {
    val text =
      """---- MODULE const ----
        |VARIABLE n, m
        |OVERRIDE_n == m
        |A == {n}
        |================================
      """.stripMargin

    val (rootName, modules) = new SanyImporter(new SourceStore)
      .loadFromSource("const", Source.fromString(text))
    val root = modules(rootName)
    assertThrows[OverridingError](new ConstAndDefRewriter(new IdleTracker())(root))
  }

  test("override an operator") {
    val text =
      """---- MODULE op ----
        |BoolMin(S) == CHOOSE x \in S: \A y \in S: x => y
        |OVERRIDE_BoolMin(S) == CHOOSE x \in S: TRUE
        |================================
      """.stripMargin

    val (rootName, modules) = new SanyImporter(new SourceStore)
      .loadFromSource("op", Source.fromString(text))
    val root = modules(rootName)
    val rewritten = new ConstAndDefRewriter(new IdleTracker())(root)
    assert(rewritten.constDeclarations.isEmpty)
    assert(rewritten.operDeclarations.size == 1)
    val expected = TlaOperDecl("BoolMin", List(SimpleFormalParam("S")),
      tla.choose(tla.name("x"), tla.name("S"), tla.bool(true)))
    assert(expected == rewritten.operDeclarations.head)
  }

  test("override a unary operator with a binary operator") {
    val text =
      """---- MODULE op ----
        |BoolMin(S) == CHOOSE x \in S: \A y \in S: x => y
        |OVERRIDE_BoolMin(S, T) == CHOOSE x \in S: x \in T
        |================================
      """.stripMargin

    val (rootName, modules) = new SanyImporter(new SourceStore)
      .loadFromSource("op", Source.fromString(text))
    val root = modules(rootName)
    assertThrows[OverridingError](new ConstAndDefRewriter(new IdleTracker())(root))
  }
} 
Example 79
Source File: TestUniqueNameGenerator.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.pp

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterEach, FunSuite}

@RunWith(classOf[JUnitRunner])
class TestUniqueNameGenerator extends FunSuite with BeforeAndAfterEach {
  test("first three") {
    val gen = new UniqueNameGenerator
    assert("t_1" == gen.newName())
    assert("t_2" == gen.newName())
    assert("t_3" == gen.newName())
  }

  test("after 10000") {
    val gen = new UniqueNameGenerator
    for (i <- 1.to(10000)) {
      gen.newName()
    }
    assert("t_7pt" == gen.newName())
  }
} 
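
The t_7pt expectation is easiest to read by treating the generator's counter as a base-36 number (digits 0-9 followed by a-z): after 10000 calls the next counter value is 10001, which renders as "7pt". A quick standalone check of that arithmetic (my inference about the naming scheme, not code from the suite):

java.lang.Long.toString(10001L, 36) // "7pt", hence the expected name "t_7pt"
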
Example 80
Source File: TestPassChainExecutor.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.infra.passes

import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.easymock.EasyMockSugar
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestPassChainExecutor extends FunSuite with EasyMockSugar {
  test("""2 passes, OK""") {
    val pass1 = mock[Pass]
    val pass2 = mock[Pass]
    expecting {
      pass1.name.andReturn("pass1").anyTimes()
      pass1.execute().andReturn(true)
      pass1.next().andReturn(Some(pass2))
      pass2.name.andReturn("pass2").anyTimes()
      pass2.execute().andReturn(true)
      pass2.next().andReturn(None)
    }
    // run the chain
    whenExecuting(pass1, pass2) {
      val options = new WriteablePassOptions()
      val executor = new PassChainExecutor(options, pass1)
      val result = executor.run()
      assert(result.isDefined)
      assert(result.contains(pass2))
    }
  }

  test("""2 passes, first fails""") {
    val pass1 = mock[Pass]
    val pass2 = mock[Pass]
    expecting {
      pass1.name.andReturn("pass1").anyTimes()
      pass1.execute().andReturn(false)
    }
    // run the chain
    whenExecuting(pass1, pass2) {
      val options = new WriteablePassOptions()
      val executor = new PassChainExecutor(options, pass1)
      val result = executor.run()
      assert(result.isEmpty)
    }
  }
} 
Example 81
Source File: DSLSpec.scala    From nd4s   with Apache License 2.0 5 votes vote down vote up
package org.nd4s

import org.junit.runner.RunWith
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}
import org.nd4s.Implicits._

@RunWith(classOf[JUnitRunner])
class DSLSpec extends FlatSpec with Matchers {

  "DSL" should "wrap and extend an INDArray" in {

    // This test just verifies that an INDArray gets wrapped with an implicit conversion

    val nd = Nd4j.create(Array[Float](1, 2), Array(2, 1))
    val nd1 = nd + 10L // + creates new array, += modifies in place

    nd.get(0) should equal(1)
    nd1.get(0) should equal(11)

    val nd2 = nd += 100
    nd2 should equal(nd)
    nd2.get(0) should equal(101)

    // Verify that we are working with regular old INDArray objects
    nd2 match {
      case i: INDArray => // do nothing
      case _ => fail("Expect our object to be an INDArray")
    }

  }

  "DSL" should "not prevent Map[Int,T] creation" in {
    Map(0->"hello") shouldBe a [Map[_,_]]
  }
} 
Example 82
Source File: ThresholdFinderSuite.scala    From spark-MDLP-discretization   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.mllib.feature.{BucketInfo, FeatureUtils, ThresholdFinder}
import org.apache.spark.sql.SQLContext
import org.junit.runner.RunWith
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.junit.JUnitRunner



@RunWith(classOf[JUnitRunner])
class ThresholdFinderSuite extends FunSuite {


  test("Test calcCriterion with even split hence low criterion value (and high entropy)") {

    val bucketInfo = new BucketInfo(Array(100L, 200L, 300L))
    val leftFreqs = Array(50L, 100L, 150L)
    val rightFreqs = Array(50L, 100L, 150L)

    assertResult((-0.030412853556075408, 1.4591479170272448, 300, 300)) {
      ThresholdFinder.calcCriterionValue(bucketInfo, leftFreqs, rightFreqs)
    }
  }

  test("Test calcCriterion with even split (and some at split) hence low criterion value (and high entropy)") {

    val bucketInfo = new BucketInfo(Array(100L, 200L, 300L))
    val leftFreqs = Array(40L, 100L, 140L)
    val rightFreqs = Array(50L, 90L, 150L)

    assertResult((0.05852316831964029,1.370380206618117,280,290)) {
      ThresholdFinder.calcCriterionValue(bucketInfo, leftFreqs, rightFreqs)
    }
  }

  test("Test calcCriterion with uneven split hence high criterion value (and low entropy)") {

    val bucketInfo = new BucketInfo(Array(100L, 200L, 300L))
    val leftFreqs = Array(100L, 10L, 250L)
    val rightFreqs = Array(0L, 190L, 50L)

    assertResult((0.5270800719912969, 0.9086741857687387, 360, 240)) {
      ThresholdFinder.calcCriterionValue(bucketInfo, leftFreqs, rightFreqs)
    }
  }

  test("Test calcCriterion with uneven split hence very high criterion value (and very low entropy)") {

    val bucketInfo = new BucketInfo(Array(100L, 200L, 300L))
    val leftFreqs = Array(100L, 200L, 0L)
    val rightFreqs = Array(0L, 0L, 300L)

    assertResult((0.9811176395006821, 0.45914791702724483, 300, 300)) {
      ThresholdFinder.calcCriterionValue(bucketInfo, leftFreqs, rightFreqs)
    }
  }

  test("Test calcCriterion with all data on one side (hence low criterion value)") {

    val bucketInfo = new BucketInfo(Array(100L, 200L, 300L))
    val leftFreqs = Array(0L, 0L, 0L)
    val rightFreqs = Array(100L, 200L, 300L)

    assertResult((-0.02311711397093918, 1.4591479170272448, 0, 600)) {
      ThresholdFinder.calcCriterionValue(bucketInfo, leftFreqs, rightFreqs)
    }
  }

  test("Test calcCriterion with most data on one side (hence low criterion value)") {

    val bucketInfo = new BucketInfo(Array(100L, 200L, 300L))
    val leftFreqs = Array(0L, 10L, 0L)
    val rightFreqs = Array(100L, 190L, 300L)

    assertResult((0.003721577231942788,1.4323219723298557,10,590)) {
      ThresholdFinder.calcCriterionValue(bucketInfo, leftFreqs, rightFreqs)
    }
  }
} 
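
For orientation, the 1.4591479170272448 asserted in the first and fifth tests is the Shannon entropy of the overall class distribution (100, 200, 300), i.e. probabilities (1/6, 1/3, 1/2). A standalone check of that figure (my arithmetic, not part of the suite):

def entropy(freqs: Seq[Long]): Double = {
  val n = freqs.sum.toDouble
  freqs.filter(_ > 0).map { f => val p = f / n; -p * math.log(p) / math.log(2) }.sum
}
entropy(Seq(100L, 200L, 300L)) // ≈ 1.4591479170272448
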
Example 83
Source File: MDLPDiscretizerHugeSuite.scala    From spark-MDLP-discretization   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.ml.feature.TestHelper._
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterAll, FunSuite}


// The suite's class declaration was truncated in this listing; a plausible reconstruction from the imports above:
@RunWith(classOf[JUnitRunner])
class MDLPDiscretizerHugeSuite extends FunSuite with BeforeAndAfterAll {
  test("Run MDLPD on all columns in serverX data (label = target2, maxBins = 50, maxByPart = 10000)") {
    val dataDf = readServerXData(sqlContext)
    val model = getDiscretizerModel(dataDf,
      Array("CPU1_TJ", "CPU2_TJ", "total_cfm", "rpm1"),
      "target4", maxBins = 50, maxByPart = 10000, stoppingCriterion = 0, minBinPercentage = 1)

    assertResult(
      """-Infinity, 337.55365, 363.06793, Infinity;
        |-Infinity, 329.35974, 330.47424, 331.16617, 331.54724, 332.8419, 333.82208, 334.7564, 335.65106, 336.6503, 337.26328, 337.8406, 339.16763, 339.81476, 341.1809, 341.81186, 343.64825, 355.91144, 357.8602, 361.57806, Infinity;
        |-Infinity, 0.0041902177, 0.0066683707, 0.00841628, 0.009734755, 0.011627266, 0.012141651, 0.012740928, 0.013055362, 0.013293093, 0.014488807, 0.014869433, 0.015116488, 0.015383363, 0.015662778, 0.015978532, 0.016246023, 0.016492717, 0.01686273, 0.017246526, 0.017485093, 0.017720722, 0.017845878, 0.018008012, 0.018357705, 0.018629191, 0.018964633, 0.019226547, 0.019445801, 0.01960973, 0.019857172, 0.020095222, 0.020373512, 0.020728927, 0.020977266, 0.02137091, 0.021543117, 0.02188059, 0.022238541, 0.02265025, 0.023091711, 0.023352059, 0.023588676, 0.023957964, 0.024230447, 0.024448851, 0.024822969, 0.025079254, 0.026178652, 0.027195029, Infinity;
        |-Infinity, 1500.0, 4500.0, 7500.0, Infinity""".stripMargin.replaceAll(System.lineSeparator(), "")) {
      model.splits.map(a => a.mkString(", ")).mkString(";")
    }
  }

} 
Example 84
Source File: MLeapModelConverterTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.local

import com.salesforce.op.test.TestCommon
import ml.combust.mleap.core.feature._
import ml.combust.mleap.core.types.ScalarShape
import org.apache.spark.ml.linalg.{DenseMatrix, Vectors}
import org.junit.runner.RunWith
import org.scalatest.PropSpec
import org.scalatest.junit.JUnitRunner
import org.scalatest.prop.PropertyChecks

@RunWith(classOf[JUnitRunner])
class MLeapModelConverterTest extends PropSpec with PropertyChecks with TestCommon {

  val mleapModels = Table("mleapModels",
    BinarizerModel(0.0, ScalarShape()),
    BucketedRandomProjectionLSHModel(Seq(), 0.0, 0),
    BucketizerModel(Array.empty),
    ChiSqSelectorModel(Seq(), 0),
    CoalesceModel(Seq()),
    CountVectorizerModel(Array.empty, false, 0.0),
    DCTModel(false, 0),
    ElementwiseProductModel(Vectors.zeros(0)),
    FeatureHasherModel(0, Seq(), Seq(), Seq()),
    HashingTermFrequencyModel(),
    IDFModel(Vectors.zeros(0)),
    ImputerModel(0.0, 0.0, ""),
    InteractionModel(Array(), Seq()),
    MathBinaryModel(BinaryOperation.Add),
    MathUnaryModel(UnaryOperation.Log),
    MaxAbsScalerModel(Vectors.zeros(0)),
    MinHashLSHModel(Seq(), 0),
    MinMaxScalerModel(Vectors.zeros(0), Vectors.zeros(0)),
    NGramModel(0),
    NormalizerModel(0.0, 0),
    OneHotEncoderModel(Array()),
    PcaModel(DenseMatrix.zeros(0, 0)),
    PolynomialExpansionModel(0, 0),
    RegexIndexerModel(Seq(), None),
    RegexTokenizerModel(".*".r),
    ReverseStringIndexerModel(Seq()),
    StandardScalerModel(Some(Vectors.dense(Array(1.0))), Some(Vectors.dense(Array(1.0)))),
    StopWordsRemoverModel(Seq(), false),
    StringIndexerModel(Seq()),
    StringMapModel(Map()),
    TokenizerModel(),
    VectorAssemblerModel(Seq()),
    VectorIndexerModel(0, Map()),
    VectorSlicerModel(Array(), Array(), 0),
    WordLengthFilterModel(),
    WordToVectorModel(Map("a" -> 1), Array(1))
  )

  property("convert mleap models to functions") {
    forAll(mleapModels) { m =>
      val fn = MLeapModelConverter.modelToFunction(m)
      fn shouldBe a[Function[_, _]]
    }
  }

  property("error on unsupported models") {
    the[RuntimeException] thrownBy MLeapModelConverter.modelToFunction(model = "not at model") should have message
      "Unsupported MLeap model: java.lang.String"
  }

} 
Example 85
Source File: SparkStageParamTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package org.apache.spark.ml

import com.salesforce.op.stages.SparkStageParam
import com.salesforce.op.test.TestSparkContext
import org.apache.spark.ml.feature.StandardScaler
import org.joda.time.DateTime
import org.json4s.JsonDSL._
import org.json4s._
import org.json4s.jackson.JsonMethods.{parse, _}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterEach, FlatSpec}


@RunWith(classOf[JUnitRunner])
class SparkStageParamTest extends FlatSpec with TestSparkContext with BeforeAndAfterEach {
  import SparkStageParam._

  var savePath: String = _
  var param: SparkStageParam[StandardScaler] = _
  var stage: StandardScaler = _

  override def beforeEach(): Unit = {
    super.beforeEach()
    savePath = tempDir + "/op-stage-param-test-" + DateTime.now().getMillis
    param = new SparkStageParam[StandardScaler](parent = "test", name = "test", doc = "none")
    // by setting both to be the same, we guarantee that at least one isn't the default value
    stage = new StandardScaler().setWithMean(true).setWithStd(false)
  }

  // easier if test both at the same time
  Spec[SparkStageParam[_]] should "encode and decode properly when is set" in {
    param.savePath = Option(savePath)
    val jsonOut = param.jsonEncode(Option(stage))
    val parsed = parse(jsonOut).asInstanceOf[JObject]
    val updated = parsed ~ ("path" -> savePath) // inject path for decoding

    updated shouldBe JObject(
      "className" -> JString(stage.getClass.getName),
      "uid" -> JString(stage.uid),
      "path" -> JString(savePath)
    )
    val updatedJson = compact(updated)

    param.jsonDecode(updatedJson) match {
      case None => fail("Failed to recover the stage")
      case Some(stageRecovered) =>
        stageRecovered shouldBe a[StandardScaler]
        stageRecovered.uid shouldBe stage.uid
        stageRecovered.getWithMean shouldBe stage.getWithMean
        stageRecovered.getWithStd shouldBe stage.getWithStd
    }
  }

  it should "except out when path is empty" in {
    intercept[RuntimeException](param.jsonEncode(Option(stage))).getMessage shouldBe
      s"Path must be set before Spark stage '${stage.uid}' can be saved"
  }

  it should "have empty path if stage is empty" in {
    param.savePath = Option(savePath)
    val jsonOut = param.jsonEncode(None)
    val parsed = parse(jsonOut)

    parsed shouldBe JObject("className" -> JString(NoClass), "uid" -> JString(NoUID))
    param.jsonDecode(jsonOut) shouldBe None
  }
} 
Example 86
Source File: UnaryEstimatorTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.base.unary

import com.salesforce.op.UID
import com.salesforce.op.features.Feature
import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.{DoubleType, MetadataBuilder, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class UnaryEstimatorTest extends OpEstimatorSpec[Real, UnaryModel[Real, Real], UnaryEstimator[Real, Real]] {

  
  val expectedResult = Seq(0.0, 0.8, 0.4, 0.2, 1.0).map(_.toReal)

}

class MinMaxNormEstimator(uid: String = UID[MinMaxNormEstimator])
  extends UnaryEstimator[Real, Real](operationName = "minMaxNorm", uid = uid) {

  def fitFn(dataset: Dataset[Real#Value]): UnaryModel[Real, Real] = {
    val grouped = dataset.groupBy()
    val maxVal = grouped.max().first().getDouble(0)
    val minVal = grouped.min().first().getDouble(0)
    new MinMaxNormEstimatorModel(min = minVal, max = maxVal, operationName = operationName, uid = uid)
  }
}

final class MinMaxNormEstimatorModel private[op](val min: Double, val max: Double, operationName: String, uid: String)
  extends UnaryModel[Real, Real](operationName = operationName, uid = uid) {
  def transformFn: Real => Real = _.v.map(v => (v - min) / (max - min)).toReal
} 
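
The expected result above is consistent with plain min-max scaling, (v - min) / (max - min), which is exactly what MinMaxNormEstimatorModel.transformFn computes. A standalone sketch with hypothetical inputs (the suite's actual input data is omitted from this listing):

def minMaxNorm(v: Double, min: Double, max: Double): Double = (v - min) / (max - min)
// Hypothetical inputs 1.0, 5.0, 3.0, 2.0, 6.0 (min = 1.0, max = 6.0) reproduce the expected sequence:
Seq(1.0, 5.0, 3.0, 2.0, 6.0).map(minMaxNorm(_, 1.0, 6.0)) // List(0.0, 0.8, 0.4, 0.2, 1.0)
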
Example 87
Source File: QuaternaryEstimatorTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.base.quaternary

import com.salesforce.op.UID
import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.sql.Dataset
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class QuaternaryEstimatorTest
  extends OpEstimatorSpec[Real,
    QuaternaryModel[Real, TextMap, BinaryMap, MultiPickList, Real],
    QuaternaryEstimator[Real, TextMap, BinaryMap, MultiPickList, Real]] {

  val (inputData, reals, textMap, booleanMap, binary) = TestFeatureBuilder(
    Seq(
      (Real.empty, TextMap(Map("a" -> "keen")), BinaryMap(Map("a" -> true)), MultiPickList(Set("a"))),
      (Real(15.0), TextMap(Map("b" -> "bok")), BinaryMap(Map("b" -> true)), MultiPickList(Set("b"))),
      (Real(23.0), TextMap(Map("c" -> "bar")), BinaryMap(Map("c" -> true)), MultiPickList(Set("c"))),
      (Real(40.0), TextMap(Map.empty), BinaryMap(Map("d" -> true)), MultiPickList(Set("d"))),
      (Real(65.0), TextMap(Map("e" -> "B")), BinaryMap(Map("e" -> true)), MultiPickList(Set("e")))
    )
  )

  val estimator = new FantasticFourEstimator().setInput(reals, textMap, booleanMap, binary)

  val expectedResult = Seq(Real.empty, Real(-31.6), Real(-23.6), Real.empty, Real(18.4))
}

class FantasticFourEstimator(uid: String = UID[FantasticFourEstimator])
  extends QuaternaryEstimator[Real, TextMap, BinaryMap, MultiPickList, Real](operationName = "fantasticFour", uid = uid)
    with FantasticFour  {

  // scalastyle:off line.size.limit
  def fitFn(dataset: Dataset[(Real#Value, TextMap#Value, BinaryMap#Value, MultiPickList#Value)]): QuaternaryModel[Real, TextMap, BinaryMap, MultiPickList, Real] = {
    import dataset.sparkSession.implicits._
    val topAge = dataset.map(_._1.getOrElse(0.0)).groupBy().max().first().getDouble(0)
    val mean = dataset.map { case (age, strMp, binMp, gndr) =>
      if (filterFN(age, strMp, binMp, gndr)) age.getOrElse(topAge) else topAge
    }.groupBy().mean().first().getDouble(0)

    new FantasticFourModel(mean = mean, operationName = operationName, uid = uid)
  }
  // scalastyle:on

}

final class FantasticFourModel private[op](val mean: Double, operationName: String, uid: String)
  extends QuaternaryModel[Real, TextMap, BinaryMap, MultiPickList, Real](operationName = operationName, uid = uid)
    with FantasticFour {

  def transformFn: (Real, TextMap, BinaryMap, MultiPickList) => Real = (age, strMp, binMp, gndr) => new Real(
    if (filterFN(age.v, strMp.v, binMp.v, gndr.v)) Some(age.v.get - mean) else None
  )

}

sealed trait FantasticFour {
  def filterFN(a: Real#Value, sm: TextMap#Value, bm: BinaryMap#Value, g: MultiPickList#Value): Boolean =
    a.nonEmpty && g.nonEmpty && sm.contains(g.head) && bm.contains(g.head)
} 
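
The expected results can be traced by hand from the fixture: rows 1 and 4 fail filterFN (empty age, respectively empty text map) and contribute topAge = 65.0 to the mean, while rows 2, 3 and 5 contribute their own ages, so mean = (65 + 15 + 23 + 65 + 65) / 5 = 46.6; transformFn then subtracts the mean for the matching rows and returns Real.empty for the rest. A quick check of that arithmetic:

// Values read off the test fixture above.
val topAge = 65.0
val contributions = Seq(topAge, 15.0, 23.0, topAge, 65.0) // rows 1 and 4 fall back to topAge
val mean = contributions.sum / contributions.size          // 46.6
Seq(15.0, 23.0, 65.0).map(_ - mean)                        // ≈ List(-31.6, -23.6, 18.4)
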
Example 88
Source File: TernaryEstimatorTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.base.ternary

import com.salesforce.op.UID
import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.sql.Dataset
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class TernaryEstimatorTest
  extends OpEstimatorSpec[Real,
    TernaryModel[MultiPickList, Binary, RealMap, Real],
    TernaryEstimator[MultiPickList, Binary, RealMap, Real]] {

  val (inputData, gender, numericMap, survived) = TestFeatureBuilder("gender", "numericMap", "survived",
    Seq(
      (MultiPickList.empty, RealMap(Map("teen" -> 1.0)), Binary(true)),
      (MultiPickList(Set("teen")), RealMap(Map("teen" -> 2.0)), Binary(false)),
      (MultiPickList(Set("teen")), RealMap(Map("teen" -> 3.0)), Binary(false)),
      (MultiPickList(Set("adult")), RealMap(Map("adult" -> 1.0)), Binary(false)),
      (MultiPickList(Set("senior")), RealMap(Map("senior" -> 1.0, "adult" -> 2.0)), Binary(false))
    )
  )

  val estimator = new TripleInteractionsEstimator().setInput(gender, survived, numericMap)

  val expectedResult = Seq(Real.empty, Real(0.25), Real(1.25), Real(-0.75), Real(-0.75))
}

class TripleInteractionsEstimator(uid: String = UID[TripleInteractionsEstimator])
  extends TernaryEstimator[MultiPickList, Binary, RealMap, Real](operationName = "tripleInteractions", uid = uid)
    with TripleInteractions {

  // scalastyle:off line.size.limit
  def fitFn(dataset: Dataset[(MultiPickList#Value, Binary#Value, RealMap#Value)]): TernaryModel[MultiPickList, Binary, RealMap, Real] = {
    import dataset.sparkSession.implicits._
    val mean = {
      dataset.map { case (gndr, srvvd, nmrcMp) =>
        if (survivedAndMatches(gndr, srvvd, nmrcMp)) nmrcMp(gndr.head) else 0.0
      }.filter(_ != 0.0).groupBy().mean().first().getDouble(0)
    }
    new TripleInteractionsModel(mean = mean, operationName = operationName, uid = uid)
  }
  // scalastyle:on

}

final class TripleInteractionsModel private[op](val mean: Double, operationName: String, uid: String)
  extends TernaryModel[MultiPickList, Binary, RealMap, Real](operationName = operationName, uid = uid)
    with TripleInteractions {

  def transformFn: (MultiPickList, Binary, RealMap) => Real = (g: MultiPickList, s: Binary, nm: RealMap) => new Real(
    if (!survivedAndMatches(g.value, s.value, nm.value)) None
    else Some(nm.value(g.value.head) - mean)
  )

}

sealed trait TripleInteractions {
  def survivedAndMatches(g: MultiPickList#Value, s: Binary#Value, nm: RealMap#Value): Boolean =
    !s.getOrElse(false) && g.nonEmpty && nm.contains(g.head)
} 
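
The same kind of back-of-the-envelope check works here: only rows where survivedAndMatches holds contribute to the mean, namely the map values 2.0, 3.0, 1.0 and 1.0 (rows 2 to 5), so mean = 7 / 4 = 1.75; subtracting it gives the expected 0.25, 1.25, -0.75 and -0.75, while the first row (survived = true, empty gender) maps to Real.empty. A quick check:

// Values read off the test fixture above.
val matching = Seq(2.0, 3.0, 1.0, 1.0)  // rows 2-5 pass survivedAndMatches
val mean = matching.sum / matching.size // 1.75
matching.map(_ - mean)                  // List(0.25, 1.25, -0.75, -0.75)
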
Example 89
Source File: SequenceEstimatorTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.base.sequence

import com.salesforce.op.UID
import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import com.salesforce.op.utils.spark.SequenceAggregators
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.Dataset
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class SequenceEstimatorTest
  extends OpEstimatorSpec[OPVector, SequenceModel[DateList, OPVector], SequenceEstimator[DateList, OPVector]] {

  val sample = Seq[(DateList, DateList, DateList)](
    (new DateList(1476726419000L, 1476726019000L),
      new DateList(1476726919000L),
      new DateList(1476726519000L)),
    (new DateList(1476725419000L, 1476726019000L),
      new DateList(1476726319000L, 1476726919000L),
      new DateList(1476726419000L)),
    (new DateList(1476727419000L),
      new DateList(1476728919000L),
      new DateList(1476726619000L, 1476726949000L))
  )
  val (inputData, clicks, opens, purchases) = TestFeatureBuilder("clicks", "opens", "purchases", sample)

  val estimator = new FractionOfResponsesEstimator().setInput(clicks, opens, purchases)

  val expectedResult = Seq(
    Vectors.dense(0.4, 0.25, 0.25).toOPVector,
    Vectors.dense(0.4, 0.5, 0.25).toOPVector,
    Vectors.dense(0.2, 0.25, 0.5).toOPVector
  )
}


class FractionOfResponsesEstimator(uid: String = UID[FractionOfResponsesEstimator])
  extends SequenceEstimator[DateList, OPVector](operationName = "fractionOfResponses", uid = uid) {
  def fitFn(dataset: Dataset[Seq[Seq[Long]]]): SequenceModel[DateList, OPVector] = {
    import dataset.sparkSession.implicits._
    val sizes = dataset.map(_.map(_.size))
    val size = getInputFeatures().length
    val counts = sizes.select(SequenceAggregators.SumNumSeq[Int](size = size).toColumn).first().map(_.toDouble)
    new FractionOfResponsesModel(counts = counts, operationName = operationName, uid = uid)
  }
}

final class FractionOfResponsesModel private[op]
(
  val counts: Seq[Double],
  operationName: String,
  uid: String
) extends SequenceModel[DateList, OPVector](operationName = operationName, uid = uid) {
  def transformFn: Seq[DateList] => OPVector = row => {
    val fractions = row.zip(counts).map { case (feature, count) => feature.value.size.toDouble / count }
    Vectors.dense(fractions.toArray).toOPVector
  }
} 
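
Here the fit step simply counts dates per feature across the whole dataset: clicks have 2 + 2 + 1 = 5 dates, opens 1 + 2 + 1 = 4, purchases 1 + 1 + 2 = 4, and each row's output vector is its own counts divided by those totals, which reproduces the expected vectors. A quick check:

// Counts read off the test fixture above.
val totals = Seq(5.0, 4.0, 4.0) // clicks, opens, purchases across all rows
val rows = Seq(Seq(2.0, 1.0, 1.0), Seq(2.0, 2.0, 1.0), Seq(1.0, 1.0, 2.0))
rows.map(r => r.zip(totals).map { case (c, t) => c / t })
// List(List(0.4, 0.25, 0.25), List(0.4, 0.5, 0.25), List(0.2, 0.25, 0.5))
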
Example 90
Source File: BinarySequenceTransformerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.base.sequence

import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class BinarySequenceTransformerTest
  extends OpTransformerSpec[MultiPickList, BinarySequenceTransformer[Real, Text, MultiPickList]] {

  val sample = Seq(
    (1.toReal, "one".toText, "two".toText),
    ((-1).toReal, "three".toText, "four".toText),
    (15.toReal, "five".toText, "six".toText),
    (1.111.toReal, "seven".toText, "".toText)
  )

  val (inputData, f1, f2, f3) = TestFeatureBuilder(sample)

  val transformer = new BinarySequenceLambdaTransformer[Real, Text, MultiPickList](
    operationName = "realToMultiPicklist", transformFn = new BinarySequenceTransformerTest.Fun
  ).setInput(f1, f2, f3)

  val expectedResult = Seq(
    Set("1.0", "one", "two"),
    Set("-1.0", "three", "four"),
    Set("15.0", "five", "six"),
    Set("1.111", "seven", "")
  ).map(_.toMultiPickList)
}

object BinarySequenceTransformerTest {

  class Fun extends Function2[Real, Seq[Text], MultiPickList] with Serializable {
    def apply(r: Real, texts: Seq[Text]): MultiPickList =
      MultiPickList(texts.map(_.value.get).toSet + r.value.get.toString)
  }

} 
Example 91
Source File: SequenceTransformerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.base.sequence

import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class SequenceTransformerTest extends OpTransformerSpec[MultiPickList, SequenceTransformer[Real, MultiPickList]] {

  val sample = Seq(
    1.toReal -> 1.toReal,
    (-1).toReal -> 1.toReal,
    15.toReal -> Real.empty,
    1.111.toReal -> 2.222.toReal
  )
  val (inputData, f1, f2) = TestFeatureBuilder(sample)

  val transformer = new SequenceLambdaTransformer[Real, MultiPickList](
    operationName = "realToMultiPicklist", transformFn = new SequenceTransformerTest.Fun
  ).setInput(f1, f2)

  val expectedResult = Seq(
    Set("1.0").toMultiPickList,
    Set("-1.0", "1.0").toMultiPickList,
    Set("15.0").toMultiPickList,
    Set("1.111", "2.222").toMultiPickList
  )

}

object SequenceTransformerTest {

  class Fun extends Function1[Seq[Real], MultiPickList] with Serializable {
    def apply(value: Seq[Real]): MultiPickList = MultiPickList(value.flatMap(_.v.map(_.toString)).toSet)
  }
} 
Example 92
Source File: BinarySequenceEstimatorTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.base.sequence

import com.salesforce.op.UID
import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import com.salesforce.op.utils.spark.SequenceAggregators
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.Dataset
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class BinarySequenceEstimatorTest
  extends OpEstimatorSpec[OPVector,
    BinarySequenceModel[Real, DateList, OPVector],
    BinarySequenceEstimator[Real, DateList, OPVector]] {

  val sample = Seq[(DateList, DateList, DateList, Real)](
    (new DateList(1476726419000L, 1476726019000L),
      new DateList(1476726919000L),
      new DateList(1476726519000L),
      Real(1.0)),
    (new DateList(1476725419000L, 1476726019000L),
      new DateList(1476726319000L, 1476726919000L),
      new DateList(1476726419000L),
      Real(0.5)),
    (new DateList(1476727419000L),
      new DateList(1476728919000L),
      new DateList(1476726619000L, 1476726949000L),
      Real(0.0))
  )
  val (inputData, clicks, opens, purchases, weights) =
    TestFeatureBuilder("clicks", "opens", "purchases", "weights", sample)

  val estimator = new WeightedFractionOfResponsesEstimator().setInput(weights, clicks, opens, purchases)

  val expectedResult = Seq(
    Vectors.dense(0.4, 0.5, Double.PositiveInfinity),
    Vectors.dense(0.4, 1.0, Double.PositiveInfinity),
    Vectors.dense(0.2, 0.5, Double.PositiveInfinity)
  ).map(_.toOPVector)
}


class WeightedFractionOfResponsesEstimator(uid: String = UID[WeightedFractionOfResponsesEstimator])
  extends BinarySequenceEstimator[Real, DateList, OPVector](operationName = "fractionOfResponses", uid = uid) {
  def fitFn(dataset: Dataset[(Real#Value, Seq[Seq[Long]])]): BinarySequenceModel[Real, DateList, OPVector] = {
    import dataset.sparkSession.implicits._
    val sizes = dataset.map(_._2.map(_.size))
    val weights = dataset.map(_._1.get).rdd.collect()
    val size = getInputFeatures().length
    val counts = sizes.select(SequenceAggregators.SumNumSeq[Int](size = size).toColumn).first()
    val weightedCounts = counts.zip(weights).map {
      case (c, w) => c.toDouble * w
    }
    new WeightedFractionOfResponsesModel(counts = weightedCounts, operationName = operationName, uid = uid)
  }
}

final class WeightedFractionOfResponsesModel private[op]
(
  val counts: Seq[Double],
  operationName: String,
  uid: String
) extends BinarySequenceModel[Real, DateList, OPVector](operationName = operationName, uid = uid) {
  def transformFn: (Real, Seq[DateList]) => OPVector = (w, dates) => {
    val fractions = dates.zip(counts).map { case (feature, count) => feature.value.size.toDouble / count }
    Vectors.dense(fractions.toArray).toOPVector
  }
} 
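
The Double.PositiveInfinity entries in the expected vectors are deliberate: the raw per-feature counts (5, 4, 4) are multiplied by the collected weights (1.0, 0.5, 0.0), so the purchases column ends up with weighted count 0.0 and every row's division by it yields infinity. A quick check for the first row:

// Counts and weights read off the test fixture above.
val weightedCounts = Seq(5.0 * 1.0, 4.0 * 0.5, 4.0 * 0.0) // Seq(5.0, 2.0, 0.0)
Seq(2.0, 1.0, 1.0).zip(weightedCounts).map { case (c, t) => c / t }
// List(0.4, 0.5, Infinity) -- dividing by the zero weighted count gives Double.PositiveInfinity
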
Example 93
Source File: BinaryEstimatorTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.base.binary

import com.salesforce.op.UID
import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.Dataset
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class BinaryEstimatorTest
  extends OpEstimatorSpec[OPVector, BinaryModel[Text, Text, OPVector], BinaryEstimator[Text, Text, OPVector]] {

  val (inputData, city, country) = TestFeatureBuilder("city", "country",
    Seq(
      (Text("San Francisco"), Text("USA")),
      (Text("Paris"), Text("France")),
      (Text("Austin"), Text("USA")),
      (Text("San Francisco"), Text("USA")),
      (Text("Paris"), Text("USA")),
      (Text("Puerto Arenas"), Text("Chile")),
      (Text("Iquitos"), Text(None))
    )
  )

  val estimator = new TestPivotEstimator().setInput(city, country)

  val expectedResult = Seq(
    Vectors.dense(1.0, 0.0),
    Vectors.dense(0.0, 1.0),
    Vectors.dense(0.0, 1.0),
    Vectors.dense(1.0, 0.0),
    Vectors.dense(0.0, 1.0),
    Vectors.dense(0.0, 1.0),
    Vectors.dense(0.0, 1.0)
  ).map(_.toOPVector)

}


class TestPivotEstimator(uid: String = UID[TestPivotEstimator])
  extends BinaryEstimator[Text, Text, OPVector](operationName = "pivot", uid = uid) {

  def fitFn(data: Dataset[(Text#Value, Text#Value)]): BinaryModel[Text, Text, OPVector] = {
    import data.sparkSession.implicits._
    val counts =
      data.map { case (cty, cntry) => Seq(cty, cntry).flatten.mkString(" ") -> 1 }
        .groupByKey(_._1).reduceGroups((a, b) => (a._1, a._2 + b._2)).map(_._2)

    val topValue = counts.collect().minBy(-_._2)._1
    new TestPivotModel(topValue = topValue, operationName = operationName, uid = uid)
  }
}
final class TestPivotModel private[op](val topValue: String, operationName: String, uid: String)
  extends BinaryModel[Text, Text, OPVector](operationName = operationName, uid = uid) {

  def transformFn: (Text, Text) => OPVector = (city: Text, country: Text) => {
    val cityCountry = Seq(city.value, country.value).flatten.mkString(" ")
    val vector = if (topValue == cityCountry) Vectors.dense(1, 0) else Vectors.dense(0, 1)
    vector.toOPVector
  }

} 
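
The pivot logic is easy to trace: the fit step concatenates each row's city and country, keeps the most frequent string, which in this fixture is "San Francisco USA" (it occurs twice), and the model then emits (1, 0) for rows matching that top value and (0, 1) otherwise, exactly the pattern in expectedResult (rows 1 and 4 match). A quick check of the frequency count:

// Concatenated "city country" strings from the fixture above; row 7 has no country.
val combined = Seq("San Francisco USA", "Paris France", "Austin USA",
  "San Francisco USA", "Paris USA", "Puerto Arenas Chile", "Iquitos")
combined.groupBy(identity).map { case (k, v) => k -> v.size }.maxBy(_._2) // ("San Francisco USA", 2)
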
Example 94
Source File: PredictionTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.features.types

import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class PredictionTest extends FlatSpec with TestCommon {
  import Prediction.Keys._

  Spec[Prediction] should "extend FeatureType" in {
    Prediction(1.0) shouldBe a[FeatureType]
    Prediction(1.0) shouldBe a[OPMap[_]]
    Prediction(1.0) shouldBe a[NumericMap]
    Prediction(1.0) shouldBe a[RealMap]
  }
  it should "error if prediction is missing" in {
    intercept[NonNullableEmptyException](new Prediction(null))
    intercept[NonNullableEmptyException](new Prediction(Map.empty))
    intercept[NonNullableEmptyException](Map.empty[String, Double].toPrediction)
    intercept[NonNullableEmptyException]((null: Map[String, Double]).toPrediction)
    assertPredictionError(new Prediction(Map("a" -> 1.0)))
    assertPredictionError(Map("a" -> 1.0, "b" -> 2.0).toPrediction)
    assertInvalidKeysError(new Prediction(Map(PredictionName -> 2.0, "a" -> 1.0)))
  }
  it should "compare values correctly" in {
    Prediction(1.0).equals(Prediction(1.0)) shouldBe true
    Prediction(1.0).equals(Prediction(0.0)) shouldBe false
    Prediction(1.0, Array(1.0), Array.empty[Double]).equals(Prediction(1.0)) shouldBe false
    Prediction(1.0, Array(1.0), Array(2.0, 3.0)).equals(Prediction(1.0, Array(1.0), Array(2.0, 3.0))) shouldBe true

    Map(PredictionName -> 5.0).toPrediction shouldBe a[Prediction]
  }
  it should "return prediction" in {
    Prediction(2.0).prediction shouldBe 2.0
  }
  it should "return raw prediction" in {
    Prediction(2.0).rawPrediction shouldBe Array()
    Prediction(1.0, Array(1.0, 2.0), Array.empty[Double]).rawPrediction shouldBe Array(1.0, 2.0)
    Prediction(1.0, (1 until 200).map(_.toDouble).toArray, Array.empty[Double]).rawPrediction shouldBe
      (1 until 200).map(_.toDouble).toArray
  }
  it should "return probability" in {
    Prediction(3.0).probability shouldBe Array()
    Prediction(1.0, Array.empty[Double], Array(1.0, 2.0)).probability shouldBe Array(1.0, 2.0)
    Prediction(1.0, Array.empty[Double], (1 until 200).map(_.toDouble).toArray).probability shouldBe
      (1 until 200).map(_.toDouble).toArray
  }
  it should "return score" in {
    Prediction(4.0).score shouldBe Array(4.0)
    Prediction(1.0, Array(2.0, 3.0), Array.empty[Double]).score shouldBe Array(1.0)
    Prediction(1.0, Array.empty[Double], Array(2.0, 3.0)).score shouldBe Array(2.0, 3.0)
  }
  it should "have a nice .toString method implementation" in {
    Prediction(4.0).toString shouldBe
      "Prediction(prediction = 4.0, rawPrediction = Array(), probability = Array())"
    Prediction(1.0, Array(2.0, 3.0), Array.empty[Double]).toString shouldBe
      "Prediction(prediction = 1.0, rawPrediction = Array(2.0, 3.0), probability = Array())"
    Prediction(1.0, Array.empty[Double], Array(2.0, 3.0)).toString shouldBe
      "Prediction(prediction = 1.0, rawPrediction = Array(), probability = Array(2.0, 3.0))"
  }

  private def assertPredictionError(f: => Unit) =
    intercept[NonNullableEmptyException](f).getMessage shouldBe
      s"Prediction cannot be empty: value map must contain '$PredictionName' key"

  private def assertInvalidKeysError(f: => Unit) =
    intercept[IllegalArgumentException](f).getMessage shouldBe
      s"requirement failed: value map must only contain valid keys: '$PredictionName' or " +
        s"starting with '$RawPredictionName' or '$ProbabilityName'"

} 
Example 95
Source File: JavaConversionTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.features.types

import java.util

import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class JavaConversionTest extends FlatSpec with TestCommon {

  Spec[JavaConversionTest] should "convert java Map to TextMap" in {
    type T = util.HashMap[String, String]
    null.asInstanceOf[T].toTextMap shouldEqual TextMap(Map())
    val j = new T()
    j.toTextMap shouldEqual TextMap(Map())
    j.put("A", "a")
    j.put("B", null)
    j.toTextMap shouldEqual TextMap(Map("A" -> "a", "B" -> null))
  }

  it should "convert java Map to MultiPickListMap" in {
    type T = util.HashMap[String, java.util.HashSet[String]]
    null.asInstanceOf[T].toMultiPickListMap shouldEqual MultiPickListMap(Map())
    val j = new T()
    j.toMultiPickListMap shouldEqual MultiPickListMap(Map())
    val h = new util.HashSet[String]()
    h.add("X")
    h.add("Y")
    j.put("test", h)
    j.put("test2", null)
    j.toMultiPickListMap shouldEqual MultiPickListMap(Map("test" -> Set("X", "Y"), "test2" -> Set()))
  }

  it should "convert java Map to IntegralMap" in {
    type T = util.HashMap[String, java.lang.Long]
    null.asInstanceOf[T].toIntegralMap shouldEqual IntegralMap(Map())
    val j = new T()
    j.toIntegralMap shouldEqual IntegralMap(Map())
    j.put("test", java.lang.Long.valueOf(17))
    j.put("test2", null)
    j.toIntegralMap.v("test") shouldEqual 17L
    j.toIntegralMap.v("test2") shouldEqual (null: java.lang.Long)
  }

  it should "convert java Map to RealMap" in {
    type T = util.HashMap[String, java.lang.Double]
    null.asInstanceOf[T].toRealMap shouldEqual RealMap(Map())
    val j = new T()
    j.toRealMap shouldEqual RealMap(Map())
    j.put("test", java.lang.Double.valueOf(17.5))
    j.put("test2", null)
    j.toRealMap.v("test") shouldEqual 17.5
    j.toRealMap.v("test2") shouldEqual (null: java.lang.Double)
  }

  it should "convert java Map to BinaryMap" in {
    type T = util.HashMap[String, java.lang.Boolean]
    null.asInstanceOf[T].toBinaryMap shouldEqual RealMap(Map())
    val j = new T()
    j.toBinaryMap shouldEqual RealMap(Map())
    j.put("test1", java.lang.Boolean.TRUE)
    j.put("test0", java.lang.Boolean.FALSE)
    j.put("test2", null)
    j.toBinaryMap.v("test1") shouldEqual true
    j.toBinaryMap.v("test0") shouldEqual false
    j.toBinaryMap.v("test2") shouldEqual (null: java.lang.Boolean)
  }

} 
Example 96
Source File: GeolocationTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.features.types

import com.salesforce.op.test.TestCommon
import org.apache.lucene.spatial3d.geom.{GeoPoint, PlanetModel}
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class GeolocationTest extends FlatSpec with TestCommon {
  val PaloAlto: (Double, Double) = (37.4419, -122.1430)

  Spec[Geolocation] should "extend OPList[Double]" in {
    val myGeolocation = new Geolocation(List.empty[Double])
    myGeolocation shouldBe a[FeatureType]
    myGeolocation shouldBe a[OPCollection]
    myGeolocation shouldBe a[OPList[_]]
  }

  it should "behave on missing data" in {
    val sut = new Geolocation(List.empty[Double])
    sut.lat.isNaN shouldBe true
    sut.lon.isNaN shouldBe true
    sut.accuracy shouldBe GeolocationAccuracy.Unknown
  }

  it should "not accept missing value" in {
    assertThrows[IllegalArgumentException](new Geolocation(List(PaloAlto._1)))
    assertThrows[IllegalArgumentException](new Geolocation(List(PaloAlto._1, PaloAlto._2)))
    assertThrows[IllegalArgumentException](new Geolocation((PaloAlto._1, PaloAlto._2, 123456.0)))
  }

  it should "compare values correctly" in {
    new Geolocation(List(32.399, 154.213, 6.0)).equals(new Geolocation(List(32.399, 154.213, 6.0))) shouldBe true
    new Geolocation(List(12.031, -23.44, 6.0)).equals(new Geolocation(List(32.399, 154.213, 6.0))) shouldBe false
    FeatureTypeDefaults.Geolocation.equals(new Geolocation(List(32.399, 154.213, 6.0))) shouldBe false
    FeatureTypeDefaults.Geolocation.equals(FeatureTypeDefaults.Geolocation) shouldBe true
    FeatureTypeDefaults.Geolocation.equals(Geolocation(List.empty[Double])) shouldBe true

    (35.123, -94.094, 5.0).toGeolocation shouldBe a[Geolocation]
  }

  it should "correctly generate a Lucene GeoPoint object" in {
    val myGeo = new Geolocation(List(32.399, 154.213, 6.0))
    myGeo.toGeoPoint shouldBe new GeoPoint(PlanetModel.WGS84, math.toRadians(myGeo.lat), math.toRadians(myGeo.lon))
  }

} 
Example 97
Source File: ListTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.features.types

import com.salesforce.op.test.TestCommon
import org.apache.lucene.spatial3d.geom.{GeoPoint, PlanetModel}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}


@RunWith(classOf[JUnitRunner])
class ListTest extends FlatSpec with TestCommon {

  
  Spec[DateTimeList] should "extend OPList[Long]" in {
    val myDateTimeList = new DateTimeList(List.empty[Long])
    myDateTimeList shouldBe a[FeatureType]
    myDateTimeList shouldBe a[OPCollection]
    myDateTimeList shouldBe a[OPList[_]]
    myDateTimeList shouldBe a[DateList]
  }
  it should "compare values correctly" in {
    new DateTimeList(List(456L, 13L)) shouldBe new DateTimeList(List(456L, 13L))
    new DateTimeList(List(13L, 456L)) should not be new DateTimeList(List(456L, 13L))
    FeatureTypeDefaults.DateTimeList should not be new DateTimeList(List(456L, 13L))
    FeatureTypeDefaults.DateTimeList shouldBe DateTimeList(List.empty[Long])

    List(12237834L, 4890489839L).toDateTimeList shouldBe a[DateTimeList]
  }


} 
Example 98
Source File: Base64Test.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.features.types

import java.nio.charset.Charset

import com.salesforce.op.test.TestCommon
import org.apache.commons.io.IOUtils
import org.junit.runner.RunWith
import org.scalatest.PropSpec
import org.scalatest.junit.JUnitRunner
import org.scalatest.prop.PropertyChecks

@RunWith(classOf[JUnitRunner])
class Base64Test extends PropSpec with PropertyChecks with TestCommon {

  property("handle empty") {
    forAll(None) {
      (v: Option[String]) =>
        Base64(v).asBytes shouldBe None
        Base64(v).asString shouldBe None
        Base64(v).asInputStream shouldBe None
    }
  }

  property("can show byte contents") {
    forAll {
      (b: Array[Byte]) =>
        val b64 = toBase64(b)
        (Base64(b64).asBytes map (_.toList)) shouldBe Some(b.toList)
    }
  }

  property("can show string contents") {
    forAll {
      (s: String) =>
        val b64 = toBase64(s.getBytes)
        Base64(b64).asString shouldBe Some(s)
    }
  }

  property("produce a stream") {
    forAll {
      (s: String) =>
        val b64 = toBase64(s.getBytes)
        Base64(b64).asInputStream.map(IOUtils.toString(_, Charset.defaultCharset())) shouldBe Some(s)
    }
  }

  property("produce a stream and map over it") {
    forAll {
      (s: String) =>
        val b64 = toBase64(s.getBytes)
        Base64(b64).mapInputStream(IOUtils.toString(_, Charset.defaultCharset())) shouldBe Some(s)
    }
  }

  def toBase64(b: Array[Byte]): String = new String(java.util.Base64.getEncoder.encode(b))
} 
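
A small aside on the toBase64 helper above: java.util.Base64 can produce the string in one step via encodeToString, which behaves the same here since Base64 output is plain ASCII. An equivalent sketch:

def toBase64(b: Array[Byte]): String = java.util.Base64.getEncoder.encodeToString(b)
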
Example 99
Source File: OPVectorTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.features.types

import com.salesforce.op.test.TestCommon
import com.salesforce.op.utils.spark.RichVector._
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OPVectorTest extends FlatSpec with TestCommon {

  val vectors = Seq(
    Vectors.sparse(4, Array(0, 3), Array(1.0, 1.0)).toOPVector,
    Vectors.dense(Array(2.0, 3.0, 4.0)).toOPVector,
    // Purposely added a very large sparse vector to verify the efficiency
    Vectors.sparse(100000000, Array(1), Array(777.0)).toOPVector
  )

  Spec[OPVector] should "be empty" in {
    val zero = Vectors.zeros(0)
    new OPVector(zero).isEmpty shouldBe true
    new OPVector(zero).nonEmpty shouldBe false
    zero.toOPVector shouldBe a[OPVector]
  }

  it should "error on size mismatch" in {
    val ones = Array.fill(vectors.size)(Vectors.sparse(1, Array(0), Array(1.0)).toOPVector)
    for {
      (v1, v2) <- vectors.zip(ones)
      res <- Seq(() => v1 + v2, () => v1 - v2, () => v1 dot v2)
    } intercept[IllegalArgumentException](res()).getMessage should {
      startWith("requirement failed: Vectors must") and include("same length")
    }
  }

  it should "compare values" in {
    val zero = Vectors.zeros(0)
    new OPVector(zero) shouldBe new OPVector(zero)
    new OPVector(zero).value shouldBe zero

    Vectors.dense(Array(1.0, 2.0)).toOPVector shouldBe Vectors.dense(Array(1.0, 2.0)).toOPVector
    Vectors.sparse(5, Array(3, 4), Array(1.0, 2.0)).toOPVector shouldBe
      Vectors.sparse(5, Array(3, 4), Array(1.0, 2.0)).toOPVector
    Vectors.dense(Array(1.0, 2.0)).toOPVector should not be Vectors.dense(Array(2.0, 2.0)).toOPVector
    new OPVector(Vectors.dense(Array(1.0, 2.0))) should not be Vectors.dense(Array(2.0, 2.0)).toOPVector
    OPVector.empty shouldBe new OPVector(zero)
  }

  it should "'+' add" in {
    for {(v1, v2) <- vectors.zip(vectors)} {
      (v1 + v2) shouldBe (v1.value + v2.value).toOPVector
    }
  }

  it should "'-' subtract" in {
    for {(v1, v2) <- vectors.zip(vectors)} {
      (v1 - v2) shouldBe (v1.value - v2.value).toOPVector
    }
  }

  it should "compute dot product" in {
    for {(v1, v2) <- vectors.zip(vectors)} {
      (v1 dot v2) shouldBe (v1.value dot v2.value)
    }
  }

  it should "combine" in {
    for {(v1, v2) <- vectors.zip(vectors)} {
      v1.combine(v2) shouldBe v1.value.combine(v2.value).toOPVector
      v1.combine(v2, v2, v1) shouldBe v1.value.combine(v2.value, v2.value, v1.value).toOPVector
    }
  }

} 
Example 100
Source File: URLTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.features.types

import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.PropSpec
import org.scalatest.junit.JUnitRunner
import org.scalatest.prop.PropertyChecks

@RunWith(classOf[JUnitRunner])
class URLTest extends PropSpec with PropertyChecks with TestCommon {

  val badOnes = Table("bad ones",
    None,
    Some(""),
    Some("protocol://domain.codomain"),
    Some("httpd://domain.codomain"),
    Some("http://domain."),
    Some("ftp://.codomain"),
    Some("https://.codomain"),
    Some("//domain.nambia"),
    Some("http://\u00ff\u0080\u007f\u0000.com") // scalastyle:off
  )

  val goodOnes = Table("good ones",
    "https://nothinghere.com?Eli=%E6%B8%87%40",
    "http://nothingthere.com?Chr=%E5%95%A9%E7%B1%85&Raj=%E7%B5%89%EC%AE%A1&Hir=%E5%B3%8F%E0%B4%A3",
    "ftp://my.red.book.com/amorcito.mio",
    "http://secret.gov?Cla=%E9%99%B9%E4%8A%93&Cha=%E3%95%98%EA%A3%A7&Eve=%EC%91%90%E8%87%B1",
    "ftp://nukes.mil?Lea=%E2%BC%84%EB%91%A3&Mur=%E2%83%BD%E1%92%83"
  )

  property("validate urls") {
    forAll(badOnes) {
      sample => URL(sample).isValid shouldBe false
    }
    forAll(goodOnes) {
      sample => URL(sample).isValid shouldBe true
    }
    forAll(goodOnes) {
      sample => URL(sample).isValid(protocols = Array("http")) shouldBe sample.startsWith("http:")
    }
  }

  property("extract domain") {
    val samples = Table("samples",
      "https://nothinghere.com?Eli=%E6%B8%87%40" -> "nothinghere.com",
      "http://nothingthere.com?Chr=%E5%85&Raj=%E7%B5%AE%A1&Hir=%8F%E0%B4%A3" -> "nothingthere.com",
      "ftp://my.red.book.com/amorcito.mio" -> "my.red.book.com",
      "http://secret.gov?Cla=%E9%99%B9%E4%8A%93&Cha=%E3&Eve=%EC%91%90%E8%87%B1" -> "secret.gov",
      "ftp://nukes.mil?Lea=%E2%BC%84%EB%91%A3&Mur=%E2%83%BD%E1%92%83" -> "nukes.mil"
    )

    URL(None).domain shouldBe None

    forAll(samples) {
      case (sample, expected) =>
        val url = URL(sample)
        val domain = url.domain
        domain shouldBe Some(expected)
    }
  }

  property("extract protocol") {
    val samples = Table("samples",
      "https://nothinghere.com?Eli=%E6%B8%87%40" -> "https",
      "http://nothingthere.com?Chr=%E5%85&Raj=%E7%B5%AE%A1&Hir=%8F%E0%B4%A3" -> "http",
      "ftp://my.red.book.com/amorcito.mio" -> "ftp"
    )

    URL(None).protocol shouldBe None

    forAll(samples) {
      case (sample, expected) =>
        val url = URL(sample)
        val protocol = url.protocol
        protocol shouldBe Some(expected)
    }
  }
} 
Example 101
Source File: FeatureSparkTypeTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.features

import com.salesforce.op.features.types.FeatureType
import com.salesforce.op.test.{TestCommon, TestSparkContext}
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.sql.types._
import org.junit.runner.RunWith
import org.scalatest.{Assertion, FlatSpec}
import org.scalatest.junit.JUnitRunner

import scala.reflect.runtime.universe._

@RunWith(classOf[JUnitRunner])
class FeatureSparkTypeTest extends FlatSpec with TestCommon {
  val primitiveTypes = Seq(
    (DoubleType, weakTypeTag[types.Real], DoubleType),
    (FloatType, weakTypeTag[types.Real], DoubleType),
    (LongType, weakTypeTag[types.Integral], LongType),
    (IntegerType, weakTypeTag[types.Integral], LongType),
    (ShortType, weakTypeTag[types.Integral], LongType),
    (ByteType, weakTypeTag[types.Integral], LongType),
    (DateType, weakTypeTag[types.Date], LongType),
    (TimestampType, weakTypeTag[types.DateTime], LongType),
    (StringType, weakTypeTag[types.Text], StringType),
    (BooleanType, weakTypeTag[types.Binary], BooleanType),
    (VectorType, weakTypeTag[types.OPVector], VectorType)
  )

  val nonNullable = Seq(
    (DoubleType, weakTypeTag[types.RealNN], DoubleType),
    (FloatType, weakTypeTag[types.RealNN], DoubleType)
  )

  private def mapType(v: DataType) = MapType(StringType, v, valueContainsNull = true)
  private def arrType(v: DataType) = ArrayType(v, containsNull = true)

  val collectionTypes = Seq(
    (arrType(LongType), weakTypeTag[types.DateList], arrType(LongType)),
    (arrType(DoubleType), weakTypeTag[types.Geolocation], arrType(DoubleType)),
    (arrType(StringType), weakTypeTag[types.TextList], arrType(StringType)),
    (mapType(StringType), weakTypeTag[types.TextMap], mapType(StringType)),
    (mapType(DoubleType), weakTypeTag[types.RealMap], mapType(DoubleType)),
    (mapType(LongType), weakTypeTag[types.IntegralMap], mapType(LongType)),
    (mapType(BooleanType), weakTypeTag[types.BinaryMap], mapType(BooleanType)),
    (mapType(arrType(StringType)), weakTypeTag[types.MultiPickListMap], mapType(arrType(StringType))),
    (mapType(arrType(DoubleType)), weakTypeTag[types.GeolocationMap], mapType(arrType(DoubleType)))
  )

  Spec(FeatureSparkTypes.getClass) should "assign appropriate feature type tags for valid types and vice versa" in {
    primitiveTypes.map(scala.Function.tupled(assertTypes()))
  }

  it should "assign appropriate feature type tags for valid non-nullable types and versa" in {
    nonNullable.map(scala.Function.tupled(assertTypes(isNullable = false)))
  }

  it should "assign appropriate feature type tags for collection types and versa" in {
    collectionTypes.map(scala.Function.tupled(assertTypes()))
  }

  it should "error for unsupported types" in {
    val error = intercept[IllegalArgumentException](FeatureSparkTypes.featureTypeTagOf(BinaryType, isNullable = false))
    error.getMessage shouldBe "Spark BinaryType is currently not supported"
  }

  it should "error for unknown types" in {
    val unknownType = NullType
    val error = intercept[IllegalArgumentException](FeatureSparkTypes.featureTypeTagOf(unknownType, isNullable = false))
    error.getMessage shouldBe s"No feature type tag mapping for Spark type $unknownType"
  }

  def assertTypes(
    isNullable: Boolean = true
  )(
    sparkType: DataType,
    featureType: WeakTypeTag[_ <: FeatureType],
    expectedSparkType: DataType
  ): Assertion = {
    FeatureSparkTypes.featureTypeTagOf(sparkType, isNullable) shouldBe featureType
    FeatureSparkTypes.sparkTypeOf(featureType) shouldBe expectedSparkType
  }

} 
Example 102
Source File: RichStructTypeTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.utils.spark

import com.salesforce.op.test.TestSparkContext
import org.apache.spark.sql.functions._
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

import scala.language.postfixOps


@RunWith(classOf[JUnitRunner])
class RichStructTypeTest extends FlatSpec with TestSparkContext {

  import com.salesforce.op.utils.spark.RichStructType._

  case class Human
  (
    name: String,
    age: Double,
    height: Double,
    heightIsNull: Double,
    isBlueEyed: Double,
    gender: Double,
    testFeatNegCor: Double
  )

  // scalastyle:off
  val humans = Seq(
    Human("alex",     32,  5.0,  0,  1,  1,  0),
    Human("alice",    32,  4.0,  1,  0,  0,  1),
    Human("bob",      32,  6.0,  1,  1,  1,  0),
    Human("charles",  32,  5.5,  0,  1,  1,  0),
    Human("diana",    32,  5.4,  1,  0,  0,  1),
    Human("max",      32,  5.4,  1,  0,  0,  1)
  )
  // scalastyle:on

  val humansDF = spark.createDataFrame(humans).select(col("*"), col("name").as("(name)_blarg_123"))
  val schema = humansDF.schema

  Spec[RichStructType] should "find schema fields by name (case insensitive)" in {
    schema.findFields("name").map(_.name) shouldBe Seq("name", "(name)_blarg_123")
    schema.findFields("blArg").map(_.name) shouldBe Seq("(name)_blarg_123")
  }

  it should "find schema fields by name (case sensitive)" in {
    schema.findFields("Name", ignoreCase = false) shouldBe Seq.empty
    schema.findFields("aGe", ignoreCase = false) shouldBe Seq.empty
    schema.findFields("age", ignoreCase = false).map(_.name) shouldBe Seq("age")
  }

  it should "fail on duplication" in {
    the[IllegalArgumentException] thrownBy schema.findField("a")
  }

  it should "throw an error if no such name" in {
    the[IllegalArgumentException] thrownBy schema.findField("???")
  }

} 
Example 103
Source File: TimeBasedAggregatorTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.aggregators

import com.salesforce.op.features.FeatureBuilder
import com.salesforce.op.features.types._
import com.salesforce.op.stages.FeatureGeneratorStage
import com.salesforce.op.test.TestCommon
import org.joda.time.Duration
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TimeBasedAggregatorTest extends FlatSpec with TestCommon {

  private val data = Seq(TimeBasedTest(100L, 1.0, "a", Map("a" -> "a")),
    TimeBasedTest(200L, 2.0, "b", Map("b" -> "b")),
    TimeBasedTest(300L, 3.0, "c", Map("c" -> "c")),
    TimeBasedTest(400L, 4.0, "d", Map("d" -> "d")),
    TimeBasedTest(500L, 5.0, "e", Map("e" -> "e")),
    TimeBasedTest(600L, 6.0, "f", Map("f" -> "f"))
  )

  private val timeExt = Option((d: TimeBasedTest) => d.time)

  Spec[LastAggregator[_]] should "return the most recent event" in {
    val feature = FeatureBuilder.Real[TimeBasedTest].extract(_.real.toRealNN)
      .aggregate(LastReal).asPredictor
    val aggregator = feature.originStage.asInstanceOf[FeatureGeneratorStage[TimeBasedTest, _]].featureAggregator
    val extracted = aggregator.extract(data, timeExt, CutOffTime.NoCutoff())
    extracted shouldBe Real(Some(6.0))
  }

  it should "return the most recent event within the time window" in {
    val feature = FeatureBuilder.Text[TimeBasedTest].extract(_.string.toText)
      .aggregate(LastText).asResponse
    val aggregator = feature.originStage.asInstanceOf[FeatureGeneratorStage[TimeBasedTest, _]].featureAggregator
    val extracted = aggregator.extract(data, timeExt, CutOffTime.UnixEpoch(300L),
      responseWindow = Option(new Duration(201L)))
    extracted shouldBe Text(Some("e"))
  }

  it should "return the feature type empty value when no events are passed in" in {
    val feature = FeatureBuilder.TextMap[TimeBasedTest].extract(_.map.toTextMap)
      .aggregate(LastTextMap).asPredictor
    val aggregator = feature.originStage.asInstanceOf[FeatureGeneratorStage[TimeBasedTest, _]].featureAggregator
    val extracted = aggregator.extract(Seq(), timeExt, CutOffTime.NoCutoff())
    extracted shouldBe TextMap.empty
  }

  Spec[FirstAggregator[_]] should "return the first event" in {
    val feature = FeatureBuilder.TextAreaMap[TimeBasedTest].extract(_.map.toTextAreaMap)
      .aggregate(FirstTextAreaMap).asResponse
    val aggregator = feature.originStage.asInstanceOf[FeatureGeneratorStage[TimeBasedTest, _]].featureAggregator
    val extracted = aggregator.extract(data, timeExt, CutOffTime.UnixEpoch(301L))
    extracted shouldBe TextAreaMap(Map("d" -> "d"))
  }

  it should "return the first event within the time window" in {
    val feature = FeatureBuilder.Currency[TimeBasedTest].extract(_.real.toCurrency)
      .aggregate(FirstCurrency).asPredictor
    val aggregator = feature.originStage.asInstanceOf[FeatureGeneratorStage[TimeBasedTest, _]].featureAggregator
    val extracted = aggregator.extract(data, timeExt, CutOffTime.UnixEpoch(400L),
      predictorWindow = Option(new Duration(201L)))
    extracted shouldBe Currency(Some(2.0))
  }

  it should "return the feature type empty value when no events are passed in" in {
    val feature = FeatureBuilder.State[TimeBasedTest].extract(_.string.toState)
      .aggregate(FirstState).asPredictor
    val aggregator = feature.originStage.asInstanceOf[FeatureGeneratorStage[TimeBasedTest, _]].featureAggregator
    val extracted = aggregator.extract(Seq(), timeExt, CutOffTime.NoCutoff())
    extracted shouldBe State.empty
  }
}

case class TimeBasedTest(time: Long, real: Double, string: String, map: Map[String, String]) 
Example 104
Source File: ExtendedMultisetTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.aggregators

import com.salesforce.op.aggregators.{ExtendedMultiset => SUT}
import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ExtendedMultisetTest extends FlatSpec with TestCommon {

  Spec[ExtendedMultiset] should "add" in {
    val sut1 = Map[String, Long]("a" -> 1, "b" -> 0, "c" -> 42)
    val sut2 = Map[String, Long]("d" -> 7, "b" -> 0, "c" -> 2)

    SUT.plus(sut1, sut2) shouldBe Map[String, Long]("a" -> 1, "c" -> 44, "d" -> 7)
    SUT.plus(SUT.zero, sut2) shouldBe sut2
    SUT.plus(sut1, SUT.zero) shouldBe sut1
  }

  it should "subtract" in {
    val sut1 = Map[String, Long]("a" -> 1, "b" -> 0, "c" -> 42)
    val sut2 = Map[String, Long]("d" -> 7, "b" -> 0, "c" -> 2)

    SUT.minus(sut1, sut2) shouldBe Map[String, Long]("a" -> 1, "c" -> 40, "d" -> -7)
    SUT.minus(sut1, SUT.zero) shouldBe Map[String, Long]("a" -> 1, "c" -> 42)
    SUT.minus(SUT.zero, sut2) shouldBe Map[String, Long]("d" -> -7, "c" -> -2)
  }
} 
Example 105
Source File: EventTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.aggregators

import com.salesforce.op.features.types._
import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class EventTest extends FlatSpec with TestCommon {

  Spec[Event[_]] should "compare" in {
    val sut1 = Event[Integral](123, Integral(42), isResponse = false)
    val sut2 = Event[Integral](321, Integral(666))
    (sut1 compare sut1) shouldBe 0
    (sut1 compare sut2) shouldBe -1
    (sut2 compare sut1) shouldBe 1
    (sut2 compare sut2) shouldBe 0
    sut2.isResponse shouldBe false
  }

} 
Example 106
Source File: SequentialMulticastServiceSpec.scala    From diffy   with GNU Affero General Public License v3.0 5 votes vote down vote up
package ai.diffy.proxy

import ai.diffy.ParentSpec
import com.twitter.finagle.Service
import com.twitter.util._
import org.junit.runner.RunWith
import org.mockito.Mockito._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class SequentialMulticastServiceSpec extends ParentSpec {

  describe("SequentialMulticastService"){
    val first, second = mock[Service[String, String]]
    val multicastHandler = new SequentialMulticastService(Seq(first, second))

    it("must not access second until first is done"){
      val firstResponse, secondResponse = new Promise[String]
      when(first("anyString")) thenReturn firstResponse
      when(second("anyString")) thenReturn secondResponse
      val result = multicastHandler("anyString")
      verify(first)("anyString")
      verifyZeroInteractions(second)
      firstResponse.setValue("first")
      verify(second)("anyString")
      secondResponse.setValue("second")
      Await.result(result) must be(Seq(Try("first"), Try("second")))
    }

    it("should call all services") {
      val request = "anyString"
      val services = Seq.fill(100)(mock[Service[String, Int]])
      val responses = Seq.fill(100)(new Promise[Int])
      val svcResp = services zip responses
      svcResp foreach { case (service, response) =>
        when(service(request)) thenReturn response
      }
      val sequentialMulticast = new SequentialMulticastService(services)
      val result = sequentialMulticast("anyString")
      def verifySequentialInteraction(s: Seq[((Service[String,Int], Promise[Int]), Int)]): Unit = s match {
        case Nil =>
        case Seq(((svc, resp), index), tail@_*)  => {
          verify(svc)(request)
          tail foreach { case ((subsequent, _), _) =>
            verifyZeroInteractions(subsequent)
          }
          resp.setValue(index)
          verifySequentialInteraction(tail)
        }
      }
      verifySequentialInteraction(svcResp.zipWithIndex)
      Await.result(result) must be((0 until 100).toSeq map {i => Try(i)})
    }
  }
} 
Example 107
Source File: ConditionDefinitionParserSpec.scala    From vamp   with Apache License 2.0 5 votes vote down vote up
package io.vamp.model.parser

import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ConditionDefinitionParserSpec extends FlatSpec with Matchers with ConditionDefinitionParser {

  "ConditionDefinition" should "parse" in {
    parse("user-agent == a and user-agent != b") shouldBe {
      And(UserAgent("a"), Negation(UserAgent("b")))
    }

    parse("user-agent == a and not user-agent == b") shouldBe {
      And(UserAgent("a"), Negation(UserAgent("b")))
    }
  }

  it should "resolve direct value" in {
    parse("<water ⇒ ice>") shouldBe Value("water ⇒ ice")
    parse("< water ⇒ ice>") shouldBe Value("water ⇒ ice")
    parse("< water ⇒ ice >") shouldBe Value("water ⇒ ice")
    parse(" <water ⇒ ice>") shouldBe Value("water ⇒ ice")
    parse("<water ⇒ ice> ") shouldBe Value("water ⇒ ice")
    parse(" <water ⇒ ice> ") shouldBe Value("water ⇒ ice")
    parse("  <  water  ⇒  ice  > ") shouldBe Value("water  ⇒  ice")
  }

  it should "resolve user agent" in {
    parse("User-Agent==Firefox") shouldBe UserAgent("Firefox")
    parse("User-Agent== Firefox ") shouldBe UserAgent("Firefox")
    parse("User-Agent == Firefox ") shouldBe UserAgent("Firefox")
    parse(" User-Agent == Firefox ") shouldBe UserAgent("Firefox")
    parse(" User-Agent ! Firefox ") shouldBe Negation(UserAgent("Firefox"))
    parse(" User-Agent != Firefox ") shouldBe Negation(UserAgent("Firefox"))
    parse("user.agent != Firefox") shouldBe Negation(UserAgent("Firefox"))
    parse("User-Agent is Firefox") shouldBe UserAgent("Firefox")
    parse("User-Agent not Firefox") shouldBe Negation(UserAgent("Firefox"))
    parse(" ( User-Agent == Firefox ) ") shouldBe UserAgent("Firefox")
  }

  it should "resolve host" in {
    parse("host == localhost") shouldBe Host("localhost")
    parse("host != localhost") shouldBe Negation(Host("localhost"))

    parse("host is localhost") shouldBe Host("localhost")
    parse("host not localhost") shouldBe Negation(Host("localhost"))
    parse("host misses localhost") shouldBe Negation(Host("localhost"))

    parse("host has localhost") shouldBe Host("localhost")
    parse("host contains localhost") shouldBe Host("localhost")

    parse("! host is localhost") shouldBe Negation(Host("localhost"))
    parse("not host is localhost") shouldBe Negation(Host("localhost"))
  }

  it should "resolve cookie" in {
    parse("has cookie vamp") shouldBe Cookie("vamp")
    parse("misses cookie vamp") shouldBe Negation(Cookie("vamp"))
    parse("contains cookie vamp") shouldBe Cookie("vamp")
  }

  it should "resolve header" in {
    parse("has header vamp") shouldBe Header("vamp")
    parse("misses header vamp") shouldBe Negation(Header("vamp"))
    parse("contains header vamp") shouldBe Header("vamp")
  }

  it should "resolve cookie contains" in {
    parse("cookie vamp has 12345") shouldBe CookieContains("vamp", "12345")
    parse("cookie vamp misses 12345") shouldBe Negation(CookieContains("vamp", "12345"))
    parse("cookie vamp contains 12345") shouldBe CookieContains("vamp", "12345")
  }

  it should "resolve header contains" in {
    parse("header vamp has 12345") shouldBe HeaderContains("vamp", "12345")
    parse("header vamp misses 12345") shouldBe Negation(HeaderContains("vamp", "12345"))
    parse("header vamp contains 12345") shouldBe HeaderContains("vamp", "12345")
  }
} 
Example 108
Source File: BooleanFlatterSpec.scala    From vamp   with Apache License 2.0 5 votes vote down vote up
package io.vamp.model.parser

import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class BooleanFlatterSpec extends FlatSpec with Matchers with BooleanFlatter with BooleanParser {

  "BooleanMapper" should "map 'and'" in {
    transform(parse("a and b")) shouldBe parse("a and b")
    transform(parse("!a and b")) shouldBe parse("!a and b")
    transform(parse("!a and b and !c")) shouldBe parse("!a and b and !c")

    transform(parse("true and true")) shouldBe parse("true")
    transform(parse("1 and 0")) shouldBe parse("false")
    transform(parse("F and F")) shouldBe parse("false")

    transform(parse("((a) and (b))")) shouldBe parse("a and b")
  }

  it should "map combined" in {
    transform(parse("a or b")) shouldBe parse("(a and !b) or (!a and b) or (a and b)")
    transform(parse("(a or b) and c")) shouldBe parse("(a and !b and c) or (!a and b and c) or (a and b and c)")
    transform(parse("(a or !b) and c")) shouldBe parse("(!a and !b and c) or (a and !b and c) or (a and b and c)")
  }

  it should "reduce" in {
    flatten(parse("a or a")) shouldBe parse("a")
    flatten(parse("a or !a")) shouldBe parse("true")
    flatten(parse("a and a")) shouldBe parse("a")
    flatten(parse("a and !a")) shouldBe parse("false")

    flatten(parse("a or b")) shouldBe parse("a or b")
    flatten(parse("a and b")) shouldBe parse("a and b")
    flatten(parse("a or !b")) shouldBe parse("a or !b")
    flatten(parse("(a or b) and c")) shouldBe parse("(a and c) or (b and c)")
    flatten(parse("(a or !b) and c")) shouldBe parse("(!b and c) or (a and c)")

    flatten(parse("true or true")) shouldBe parse("1")
    flatten(parse("1 or 0")) shouldBe parse("T")
    flatten(parse("F or F")) shouldBe parse("false")

    flatten(parse("a or true")) shouldBe parse("true")
    flatten(parse("a or false")) shouldBe parse("a")

    flatten(parse("a and true")) shouldBe parse("a")
    flatten(parse("a and false")) shouldBe parse("0")

    flatten(parse("a or true and b")) shouldBe parse("a or b")
    flatten(parse("a or false and b")) shouldBe parse("a")

    flatten(parse("a and true or b")) shouldBe parse("a or b")
    flatten(parse("a and false or b")) shouldBe parse("b")
  }

  private def transform(node: AstNode): AstNode = {
    val terms = map(node)
    if (terms.nonEmpty) {
      terms map {
        _.terms.reduce {
          (op1, op2) ⇒ And(op1, op2)
        }
      } reduce {
        (op1, op2) ⇒ Or(op1, op2)
      }
    }
    else False
  }
} 
Example 109
Source File: BooleanParserSpec.scala    From vamp   with Apache License 2.0 5 votes vote down vote up
package io.vamp.model.parser

import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class BooleanParserSpec extends FlatSpec with Matchers with BooleanParser {

  "BooleanParser" should "parse operands" in {
    parse("a") shouldBe Value("a")
    parse(" a") shouldBe Value("a")
    parse("a ") shouldBe Value("a")
    parse(" a ") shouldBe Value("a")
    parse("((a))") shouldBe Value("a")

    parse("TRUE") shouldBe True
    parse("t ") shouldBe True
    parse(" 1 ") shouldBe True

    parse("False") shouldBe False
    parse(" F") shouldBe False
    parse(" 0 ") shouldBe False
    parse("(false)") shouldBe False
    parse("((1))") shouldBe True
  }

  it should "parse 'not' expression" in {
    parse("not a") shouldBe Negation(Value("a"))
    parse("! a") shouldBe Negation(Value("a"))
    parse(" not a ") shouldBe Negation(Value("a"))
  }

  it should "parse 'and' expression" in {
    parse("a and b") shouldBe And(Value("a"), Value("b"))
    parse("a && b") shouldBe And(Value("a"), Value("b"))
    parse("a & b") shouldBe And(Value("a"), Value("b"))
    parse(" a & b ") shouldBe And(Value("a"), Value("b"))
    parse("a AND b And c") shouldBe And(And(Value("a"), Value("b")), Value("c"))

    parse("true and true") shouldBe And(True, True)
    parse("1 and 0") shouldBe And(True, False)
    parse("F && F") shouldBe And(False, False)

    parse("a and true") shouldBe And(Value("a"), True)
    parse("a & false") shouldBe And(Value("a"), False)
  }

  it should "parse 'or' expression" in {
    parse("a or b") shouldBe Or(Value("a"), Value("b"))
    parse("a || b") shouldBe Or(Value("a"), Value("b"))
    parse("a | b") shouldBe Or(Value("a"), Value("b"))
    parse(" a or  b ") shouldBe Or(Value("a"), Value("b"))
    parse("a Or b | c") shouldBe Or(Or(Value("a"), Value("b")), Value("c"))

    parse("true or true") shouldBe Or(True, True)
    parse("1 || 0") shouldBe Or(True, False)
    parse("F | F") shouldBe Or(False, False)

    parse("a or true") shouldBe Or(Value("a"), True)
    parse("a or false") shouldBe Or(Value("a"), False)
  }

  it should "parse parenthesis expression" in {
    parse("((a))") shouldBe Value("a")
    parse("(a and b)") shouldBe And(Value("a"), Value("b"))
    parse("(a and b) and c") shouldBe And(And(Value("a"), Value("b")), Value("c"))
    parse("a and (b and c)") shouldBe And(Value("a"), And(Value("b"), Value("c")))
  }

  it should "parse combined expression" in {
    parse("a and b or c") shouldBe Or(And(Value("a"), Value("b")), Value("c"))
    parse("a or b and c") shouldBe Or(Value("a"), And(Value("b"), Value("c")))
    parse("!(a and b) and c") shouldBe And(Negation(And(Value("a"), Value("b"))), Value("c"))
    parse("(a or b) and c") shouldBe And(Or(Value("a"), Value("b")), Value("c"))
    parse("a or (b and c)") shouldBe Or(Value("a"), And(Value("b"), Value("c")))
  }
} 
Example 110
Source File: PaginationSupportSpec.scala    From vamp   with Apache License 2.0 5 votes vote down vote up
package io.vamp.persistence

import io.vamp.common.akka.ExecutionContextProvider
import io.vamp.common.http.OffsetResponseEnvelope
import org.junit.runner.RunWith
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.junit.JUnitRunner
import org.scalatest.time.{ Millis, Seconds, Span }
import org.scalatest.{ FlatSpec, Matchers }

import scala.concurrent.{ ExecutionContext, Future }

@RunWith(classOf[JUnitRunner])
class PaginationSupportSpec extends FlatSpec with Matchers with PaginationSupport with ScalaFutures with ExecutionContextProvider {

  case class ResponseEnvelope(response: List[Int], total: Long, page: Int, perPage: Int) extends OffsetResponseEnvelope[Int]

  implicit def executionContext = ExecutionContext.global

  implicit val defaultPatience = PatienceConfig(timeout = Span(3, Seconds), interval = Span(100, Millis))

  "PaginationSupport" should "collect all from single page" in {
    val list = (1 to 5).toList

    val source = (page: Int, perPage: Int) ⇒ Future {
      ResponseEnvelope(list.take(perPage), list.size, page, perPage)
    }

    whenReady(consume(allPages(source, 5)))(_ shouldBe list)
  }

  it should "collect all from multiple pages" in {
    val list = (1 to 15).toList

    val source = (page: Int, perPage: Int) ⇒ Future {
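      // e.g. with perPage = 5, page = 2 yields list.slice(5, 10), i.e. elements 6..10 of the 15-element list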
      ResponseEnvelope(list.slice((page - 1) * perPage, page * perPage), list.size, page, perPage)
    }

    whenReady(consume(allPages(source, 5)))(_ shouldBe list)
  }

  it should "collect all from multiple pages without round total / per page" in {
    val list = (1 to 17).toList

    val source = (page: Int, perPage: Int) ⇒ Future {
      ResponseEnvelope(list.slice((page - 1) * perPage, page * perPage), list.size, page, perPage)
    }

    whenReady(consume(allPages(source, 5)))(_ shouldBe list)
  }
} 
Example 111
Source File: EmailSenderSpec.scala    From diffy   with GNU Affero General Public License v3.0 5 votes vote down vote up
package ai.diffy.util

import java.util.Date

import ai.diffy.ParentSpec
import com.twitter.logging.Logger
import com.twitter.util.Await
import org.junit.runner.RunWith
import org.mockito.Mockito._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class EmailSenderSpec extends ParentSpec {

  val log = mock[Logger]
  val sender = new EmailSender(log, _ => ())
  describe("EmailSender") {
    it("should not encouter any errors while trying to compose emails") {
      Await.result(
        sender(
          SimpleMessage(
            from = "Diffy <[email protected]>",
            to = "[email protected]",
            bcc = "[email protected]",
            subject = "Diffy Report at " + new Date,
            body = "just testing emails from mesos!"
          )
        )
      )
      verifyZeroInteractions(log)
    }
  }
} 
Example 112
Source File: ResourceMatcherSpec.scala    From diffy   with GNU Affero General Public License v3.0 5 votes vote down vote up
package ai.diffy.util

import ai.diffy.ParentSpec
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ResourceMatcherSpec extends ParentSpec {

  describe("The HTTP PathMatcher") {

    it("should support parameter placeholders") {
      new ResourceMatcher(List("/path1/:param1/path2/:param2" -> "p1"))
        .resourceName("/path1/param1/path2/param2") mustBe Some("p1")
    }

    it("should support wildcards, matching everything after the wildcard") {
      new ResourceMatcher(List("/path1path/param2" -> "p3"))
        .resourceName("/path1/param1/path/param2") mustBe Some("p3")
    }
  }
} 
Example 113
Source File: JsonLifterSpec.scala    From diffy   with GNU Affero General Public License v3.0 5 votes vote down vote up
package ai.diffy.lifter

import ai.diffy.ParentSpec
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class JsonLifterSpec extends ParentSpec {
  describe("JsonLifter"){
    it("should correctly lift maps when keys are invalid identifier prefixes") {
      JsonLifter.lift(JsonLifter.decode("""{"1":1}""")) mustBe a [Map[_, _]]
    }

    it("should correctly lift objects when keys are valid identifier prefixes") {
      JsonLifter.lift(JsonLifter.decode("""{"a":1}""")) mustBe a [FieldMap[_]]
    }
  }
} 
Example 114
Source File: StringLifterSpec.scala    From diffy   with GNU Affero General Public License v3.0 5 votes vote down vote up
package ai.diffy.lifter

import ai.diffy.ParentSpec
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class StringLifterSpec extends ParentSpec {
  describe("String") {
    val htmlString = "<html><head>as</head><body><p>it's an html!</p></body></html>"
    val jsonString = """{"a": "it's a json!" }"""
    val regularString = "hello world!"

    it("should be true") {
      StringLifter.htmlRegexPattern.findFirstIn(htmlString).isDefined must be (true)
    }

    it("must return a FieldMap when lifted (html)") {
      StringLifter.lift(htmlString) mustBe a [FieldMap[_]]
    }

    it("must return a FieldMap when lifted (json)") {
      StringLifter.lift(jsonString) mustBe a [FieldMap[_]]
    }

    it("must return the original string when lifted") {
      StringLifter.lift(regularString) must be ("hello world!")
    }
  }
} 
Example 115
Source File: HtmlLifterSpec.scala    From diffy   with GNU Affero General Public License v3.0 5 votes vote down vote up
package ai.diffy.lifter

import ai.diffy.ParentSpec
import ai.diffy.compare.{Difference, PrimitiveDifference}
import org.jsoup.Jsoup
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class HtmlLifterSpec extends ParentSpec {
  describe("HtmlLifter"){
    val simpleActualHtml = """<html><head><title>Sample HTML</title></head><body><div class="header"><h1 class="box">Hello World</h1></div><p>Lorem ipsum dolor sit amet.</p></body></html>"""
    val simpleExpectedHtml = """<html><head><title>Sample HTML</title></head><body><div class="header"><h1 class="round">Hello World</h1></div><p>Lorem ipsum dolor sit amet.</p></body></html>"""

    val simpleActualDoc = Jsoup.parse(simpleActualHtml)
    val simpleExpectedDoc = Jsoup.parse(simpleExpectedHtml)

    it("should return a FieldMap") {
      HtmlLifter.lift(simpleActualDoc) mustBe a [FieldMap[_]]
    }

    it("should return a Primitive Difference") {
      Difference(HtmlLifter.lift(simpleActualDoc), HtmlLifter.lift(simpleExpectedDoc)).flattened must be (FieldMap(Map("body.children.children.attributes.class.PrimitiveDifference" -> PrimitiveDifference("box","round"))))
    }
  }
} 
Example 116
Source File: DifferenceStatsMonitorSpec.scala    From diffy   with GNU Affero General Public License v3.0 5 votes vote down vote up
package ai.diffy.workflow

import ai.diffy.ParentSpec
import ai.diffy.analysis.{DifferenceCounter, EndpointMetadata, RawDifferenceCounter}
import com.twitter.finagle.stats.InMemoryStatsReceiver
import com.twitter.util.{Duration, Future, MockTimer, Time}
import org.junit.runner.RunWith
import org.mockito.Mockito._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class DifferenceStatsMonitorSpec extends ParentSpec {
  describe("DifferenceStatsMonitor"){
    val diffCounter = mock[DifferenceCounter]
    val metadata =
      new EndpointMetadata {
        override val differences = 0
        override val total = 0
      }

    val endpoints = Map("endpointName" -> metadata)
    when(diffCounter.endpoints) thenReturn Future.value(endpoints)

    val stats = new InMemoryStatsReceiver
    val timer = new MockTimer
    val monitor = new DifferenceStatsMonitor(RawDifferenceCounter(diffCounter), stats, timer)

    it("must add gauges after waiting a minute"){
      Time.withCurrentTimeFrozen { tc =>
        monitor.schedule()
        timer.tasks.size must be(1)
        stats.gauges.size must be(0)
        tc.advance(Duration.fromMinutes(1))
        timer.tick()
        timer.tasks.size must be(1)
        stats.gauges.size must be(2)
        stats.gauges.keySet map { _.takeRight(2) } must be(Set(Seq("endpointName", "total"), Seq("endpointName", "differences")))
      }
    }
  }
} 
Example 117
Source File: EventReaderSpec.scala    From vamp   with Apache License 2.0 5 votes vote down vote up
package io.vamp.model.reader

import java.time.OffsetDateTime

import io.vamp.model.event.{ Aggregator, TimeRange }
import io.vamp.model.notification._
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class EventReaderSpec extends FlatSpec with Matchers with ReaderSpec {

  "EventReader" should "read the event" in {
    EventReader.read(res("event/event1.yml")) should have(
      'tags(Set("server", "service")),
      'timestamp(OffsetDateTime.parse("2015-06-05T15:12:38.000Z")),
      'value(0),
      'type("metrics")
    )
  }

  it should "expand tags" in {
    EventReader.read(res("event/event2.yml")) should have(
      'tags(Set("server")),
      'value(Map("response" → Map("time" → 50))),
      'type("metrics")
    )
  }

  it should "fail on no tag" in {
    expectedError[MissingPathValueError]({
      EventReader.read(res("event/event3.yml"))
    })
  }

  it should "fail on empty tags" in {
    expectedError[NoTagEventError.type]({
      EventReader.read(res("event/event4.yml"))
    })
  }

  it should "fail on invalid timestamp" in {
    expectedError[EventTimestampError]({
      EventReader.read(res("event/event5.yml"))
    })
  }

  it should "parse no value" in {
    EventReader.read(res("event/event6.yml")) should have(
      'tags(Set("server")),
      'value(None)
    )
  }

  it should "fail on unsupported type" in {
    expectedError[EventTypeError]({
      EventReader.read(res("event/event7.yml"))
    })
  }

  "EventQueryReader" should "read the query" in {
    EventQueryReader.read(res("event/query1.yml")) should have(
      'tags(Set("server", "service")),
      'type(None),
      'timestamp(Some(TimeRange(None, None, Some("now() - 10m"), None))),
      'aggregator(Some(Aggregator(Aggregator.average, Some("response.time"))))
    )
  }

  it should "expand tags" in {
    EventQueryReader.read(res("event/query2.yml")) should have(
      'tags(Set("server")),
      'timestamp(None),
      'aggregator(None)
    )
  }

  it should "fail on invalid time range" in {
    expectedError[EventQueryTimeError.type]({
      EventQueryReader.read(res("event/query3.yml"))
    })
  }

  it should "fail on unsupported aggregator" in {
    expectedError[UnsupportedAggregatorError]({
      EventQueryReader.read(res("event/query4.yml"))
    })
  }

  it should "read the query type" in {
    EventQueryReader.read(res("event/query5.yml")) should have(
      'tags(Set("server", "service")),
      'type(Option("router")),
      'timestamp(None),
      'aggregator(None)
    )
  }
} 
Example 118
Source File: RabbitMQDistributedConsumerIT.scala    From spark-rabbitmq   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.rabbitmq

import java.util.UUID

import com.rabbitmq.client.QueueingConsumer.Delivery
import org.apache.spark.streaming.rabbitmq.distributed.RabbitMQDistributedKey
import org.apache.spark.streaming.rabbitmq.models.ExchangeAndRouting
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class RabbitMQDistributedConsumerIT extends TemporalDataSuite {

  override val queueName = s"$configQueueName-${this.getClass().getName()}-${UUID.randomUUID().toString}"

  override val exchangeName = s"$configExchangeName-${this.getClass().getName()}-${UUID.randomUUID().toString}"

    test("RabbitMQ Receiver should read all the records") {
      
      val rabbitMQParams = Map.empty[String, String]

      val rabbitMQConnection = Map(
        "hosts" -> hosts,
        "queueName" -> queueName,
        "exchangeName" -> exchangeName,
        "vHost" -> vHost,
        "userName" -> userName,
        "password" -> password
      )

      val distributedKey = Seq(
        RabbitMQDistributedKey(
          queueName,
          new ExchangeAndRouting(exchangeName, routingKey),
          rabbitMQConnection
        )
      )

      // Delivery is not Serializable by Spark, so convert it to a Map, Seq or other native classes instead
      import scala.collection.JavaConverters._
      val distributedStream = RabbitMQUtils.createDistributedStream[Map[String, Any]](
        ssc,
        distributedKey,
        rabbitMQParams,
        (rawMessage: Delivery) =>
          Map(
            "body" -> new Predef.String(rawMessage.getBody),
            "exchange" -> rawMessage.getEnvelope.getExchange,
            "routingKey" -> rawMessage.getEnvelope.getRoutingKey,
            "deliveryTag" -> rawMessage.getEnvelope.getDeliveryTag
          ) ++ {
            // Avoid a NullPointerException when there are no headers
            Option(rawMessage.getProperties.getHeaders) match {
              case Some(headers) => Map("headers" -> headers.asScala)
              case None => Map.empty[String, Any]
            }
          }
      )

      val totalEvents = ssc.sparkContext.longAccumulator("Number of events received")

      // Start up the receiver.
      distributedStream.start()

      // Fires each time the configured window has passed.
      distributedStream.foreachRDD(rdd => {
        if (!rdd.isEmpty()) {
          val count = rdd.count()
          // Do something with this message
          println(s"EVENTS COUNT : \t $count")
          totalEvents.add(count)
          //rdd.collect().foreach(event => print(s"${event.toString}, "))
        } else println("RDD is empty")
        println(s"TOTAL EVENTS : \t $totalEvents")
      })

      ssc.start() // Start the computation
      ssc.awaitTerminationOrTimeout(10000L) // Wait for the computation to terminate

      assert(totalEvents.value === totalRegisters.toLong)
    }
} 
Example 119
Source File: RabbitMQConsumerIT.scala    From spark-rabbitmq   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.rabbitmq

import java.util.UUID

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class RabbitMQConsumerIT extends TemporalDataSuite {

  override val queueName = s"$configQueueName-${this.getClass().getName()}-${UUID.randomUUID().toString}"

  override val exchangeName = s"$configExchangeName-${this.getClass().getName()}-${UUID.randomUUID().toString}"

  test("RabbitMQ Receiver should read all the records") {

    val receiverStream = RabbitMQUtils.createStream(ssc, Map(
      "hosts" -> hosts,
      "queueName" -> queueName,
      "exchangeName" -> exchangeName,
      "exchangeType" -> exchangeType,
      "vHost" -> vHost,
      "userName" -> userName,
      "password" -> password
    ))
    val totalEvents = ssc.sparkContext.longAccumulator("My Accumulator")

    // Start up the receiver.
    receiverStream.start()

    // Fires each time the configured window has passed.
    receiverStream.foreachRDD(rdd => {
      if (!rdd.isEmpty()) {
        val count = rdd.count()
        // Do something with this message
        println(s"EVENTS COUNT : \t $count")
        totalEvents.add(count)
        //rdd.collect().sortBy(event => event.toInt).foreach(event => print(s"$event, "))
      } else println("RDD is empty")
      println(s"TOTAL EVENTS : \t $totalEvents")
    })

    ssc.start() // Start the computation
    ssc.awaitTerminationOrTimeout(10000L) // Wait for the computation to terminate

    assert(totalEvents.value === totalRegisters.toLong)
  }
} 
Example 120
Source File: MongodbSchemaIT.scala    From Spark-MongoDB   with Apache License 2.0 5 votes vote down vote up
package com.stratio.datasource.mongodb.schema

import java.text.SimpleDateFormat
import java.util.Locale

import com.stratio.datasource.MongodbTestConstants
import com.stratio.datasource.mongodb.config.{MongodbConfig, MongodbConfigBuilder}
import com.stratio.datasource.mongodb.partitioner.MongodbPartitioner
import com.stratio.datasource.mongodb.rdd.MongodbRDD
import com.stratio.datasource.mongodb._
import org.apache.spark.sql.mongodb.{TemporaryTestSQLContext, TestSQLContext}
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, TimestampType}
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class MongodbSchemaIT extends FlatSpec
with Matchers
with MongoEmbedDatabase
with TestBsonData
with MongodbTestConstants {

  private val host: String = "localhost"
  private val collection: String = "testCol"
  private val readPreference = "secondaryPreferred"

  val testConfig = MongodbConfigBuilder()
    .set(MongodbConfig.Host,List(host + ":" + mongoPort))
    .set(MongodbConfig.Database,db)
    .set(MongodbConfig.Collection,collection)
    .set(MongodbConfig.SamplingRatio,1.0)
    .set(MongodbConfig.ReadPreference, readPreference)
    .build()

  val mongodbPartitioner = new MongodbPartitioner(testConfig)

  val mongodbRDD = new MongodbRDD(TemporaryTestSQLContext, testConfig, mongodbPartitioner)

  behavior of "A schema"

  it should "be inferred from rdd with primitives" + scalaBinaryVersion in {
    withEmbedMongoFixture(primitiveFieldAndType) { mongodProc =>
      val schema = MongodbSchema(mongodbRDD, 1.0).schema()

      schema.fields should have size 7
      schema.fieldNames should contain allOf("string", "integer", "long", "double", "boolean", "null")

      schema.printTreeString()
    }
  }

  it should "be inferred from rdd with complex fields" + scalaBinaryVersion in {
    withEmbedMongoFixture(complexFieldAndType1) { mongodProc =>
      val schema = MongodbSchema(mongodbRDD, 1.0).schema()

      schema.fields should have size 13

      schema.fields filter {
        case StructField(name, ArrayType(StringType, _), _, _) => Set("arrayOfNull", "arrayEmpty") contains name
        case _ => false
      } should have size 2

      schema.printTreeString()
    }
  }

  it should "resolve type conflicts between fields" + scalaBinaryVersion in {
    withEmbedMongoFixture(primitiveFieldValueTypeConflict) { mongodProc =>
      val schema = MongodbSchema(mongodbRDD, 1.0).schema()

      schema.fields should have size 7

      schema.printTreeString()
    }
  }

  it should "be inferred from rdd with more complex fields" + scalaBinaryVersion in {
    withEmbedMongoFixture(complexFieldAndType2) { mongodProc =>
      val schema = MongodbSchema(mongodbRDD, 1.0).schema()

      schema.fields should have size 5

      schema.printTreeString()
    }
  }

  it should "read java.util.Date fields as timestamptype" + scalaBinaryVersion in {
    val dfunc = (s: String) => new SimpleDateFormat("EEE MMM dd HH:mm:ss Z yyyy", Locale.ENGLISH).parse(s)
    import com.mongodb.casbah.Imports.DBObject
    val stringAndDate = List(DBObject("string" -> "this is a simple string.", "date" -> dfunc("Mon Aug 10 07:52:49 EDT 2015")))
    withEmbedMongoFixture(stringAndDate) { mongodProc =>
      val schema = MongodbSchema(mongodbRDD, 1.0).schema()

      schema.fields should have size 3
      schema.fields.filter(_.name == "date").head.dataType should equal(TimestampType)
      schema.printTreeString()
    }
  }
} 
Example 121
Source File: ArgumentTrackingOptionsTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase.cmdline

import com.sumologic.shellbase.CommonWordSpec
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ArgumentTrackingOptionsTest extends CommonWordSpec {
  "ArgumentTrackingOptions" should {
    "correctly convert arguments list" in {
      val sut = new ArgumentTrackingOptions
      val item = new CommandLineArgument("a", 1, false)
      sut.addArgument(item)
      sut.getArguments should contain(item)
    }
  }

} 
Example 122
Source File: CommandLineValidatorsTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase.cmdline

import com.sumologic.shellbase.CommonWordSpec
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class CommandLineValidatorsTest extends CommonWordSpec {
  "CommandLineValidators" should {
    "validate integers" in {
      CommandLineValidators.validInteger("1") should be(None)
      CommandLineValidators.validInteger("-1") should be(None)
      CommandLineValidators.validInteger("0") should be(None)
      CommandLineValidators.validInteger("asdf") should be('defined)
    }

    "validate positive integers" in {
      CommandLineValidators.validPositiveInteger("1") should be(None)
      CommandLineValidators.validPositiveInteger("-1") should be('defined)
      CommandLineValidators.validPositiveInteger("0") should be('defined)
      CommandLineValidators.validPositiveInteger("asdf") should be('defined)
    }

    "validate non-negative integers" in {
      CommandLineValidators.validNonNegativeInteger("1") should be(None)
      CommandLineValidators.validNonNegativeInteger("-1") should be('defined)
      CommandLineValidators.validNonNegativeInteger("0") should be(None)
      CommandLineValidators.validNonNegativeInteger("asdf") should be('defined)
    }
  }

} 
Example 123
Source File: RichCommandLineTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase.cmdline

import com.sumologic.shellbase.CommonWordSpec
import com.sumologic.shellbase.cmdline.RichCommandLine._
import org.apache.commons.cli.Options
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class RichCommandLineTest extends CommonWordSpec {

  "Command line options" should {
    "not accept options with the same short name" in {
      val options = new Options
      options += new CommandLineOption("s", "one", true, "same shit")
      the[IllegalArgumentException] thrownBy {
        options += new CommandLineOption("s", "two", true, "different day")
      }
    }

    "not accept options with the same long name" in {
      val options = new Options
      options += new CommandLineOption("x", "same", true, "same shit")
      the[IllegalArgumentException] thrownBy {
        options += new CommandLineOption("y", "same", true, "different day")
      }
    }
  }

  "RichCommandLine.get" should {
    "return a default value if no value was provided on the command line" in {

      val defaultValue = "blargh"
      val sut = new CommandLineOption("s", "test", true, "halp", Some(defaultValue))

      val options = new Options
      options += sut

      val cmdLine = Array[String]().parseCommandLine(options)
      cmdLine.get.get(sut).get should equal(defaultValue)
    }

    "return a provided value if one was provided despite a default" in {

      val defaultValue = "blargh"
      val sut = new CommandLineOption("s", "test", true, "halp", Some(defaultValue))

      val options = new Options
      options += sut

      val providedValue = "wtf?"
      val cmdLine = Array[String]("-s", providedValue).parseCommandLine(options)
      cmdLine.get.get(sut).get should equal(providedValue)
    }
  }

  "RichCommandLine.apply" should {
    "return a provided value" in {
      val sut = new CommandLineOption("a", "animal", false, "halp")

      val options = new Options
      options += sut

      val providedValue = "wombat"
      val cmdLine = Array[String]("-a", providedValue).parseCommandLine(options).get
      cmdLine(sut) should equal(providedValue)
    }

    "throw a NoSuchElementException for a missing command line parameter" in {
      val sut = new CommandLineOption("a", "animal", false, "halp")
      val anotherOption = new CommandLineOption("ml", "my-love", true, "here I am!")

      val options = new Options
      options += sut
      options += anotherOption

      val cmdLine = Array[String]("--my-love", "wombat").parseCommandLine(options).get
      a[NoSuchElementException] should be thrownBy {
        cmdLine(sut)
      }
    }
  }

} 
Example 124
Source File: RichScalaOptionTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase.cmdline

import com.sumologic.shellbase.{CommonWordSpec, ExitShellCommandException}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class RichScalaOptionTest extends CommonWordSpec {
  "RichScalaOption" should {
    "get the value when defined" in {
      new RichScalaOption(Some("hi")).
        getOrExitWithMessage("test message") should be("hi")
    }

    "throw a ExitShellCommandException when the value isn't defined" in {
      intercept[ExitShellCommandException] {
        new RichScalaOption[String](None).getOrExitWithMessage("test message")
      }
    }
  }
} 
Example 125
Source File: ShellHighlightsTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ShellHighlightsTest extends CommonWordSpec {
  "ShellHighlights" should {
    "not error for each color and contain original text" in {
      ShellHighlights.black("test") should include("test")
      ShellHighlights.blue("test") should include("test")
      ShellHighlights.cyan("test") should include("test")
      ShellHighlights.green("test") should include("test")
      ShellHighlights.magenta("test") should include("test")
      ShellHighlights.red("test") should include("test")
      ShellHighlights.white("test") should include("test")
      ShellHighlights.yellow("test") should include("test")
    }
  }
} 
Example 126
Source File: ShellFormattingTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ShellFormattingTest extends CommonWordSpec {
  "ShellFormatting" should {
    "not error for each format and contain original text" in {
      ShellFormatting.blink("test") should include("test")
      ShellFormatting.bold("test") should include("test")
      ShellFormatting.invisible("test") should include("test")
      ShellFormatting.reversed("test") should include("test")
      ShellFormatting.underlined("test") should include("test")
    }
  }
} 
Example 127
Source File: CouchbaseDataFrameSpec.scala    From couchbase-spark-connector   with Apache License 2.0 5 votes vote down vote up
package com.couchbase.spark.sql

import com.couchbase.spark.connection.CouchbaseConnection
import org.apache.avro.generic.GenericData.StringType
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode, SparkSession}
import org.apache.spark.sql.sources.EqualTo
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.{SparkConf, SparkContext}
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class CouchbaseDataFrameSpec extends FlatSpec with Matchers with BeforeAndAfterAll {

  private val master = "local[2]"
  private val appName = "cb-int-specs1"

  private var spark: SparkSession = null


  override def beforeAll(): Unit = {
    val conf = new SparkConf()
      .setMaster(master)
      .setAppName(appName)
      .set("spark.couchbase.nodes", "127.0.0.1")
      .set("com.couchbase.username", "Administrator")
      .set("com.couchbase.password", "password")
      .set("com.couchbase.bucket.default", "")
      .set("com.couchbase.bucket.travel-sample", "")
    spark = SparkSession.builder().config(conf).getOrCreate()

    loadData()
  }

  override def afterAll(): Unit = {
    CouchbaseConnection().stop()
    spark.stop()
  }

  def loadData(): Unit = {

  }

  "If two buckets are used and the bucket is specified the API" should
    "not fail" in {
    val ssc = spark.sqlContext
    ssc.read.couchbase(EqualTo("type", "airline"), Map("bucket" -> "travel-sample"))
  }

  "The DataFrame API" should "infer the schemas" in {
    val ssc = spark.sqlContext
    import com.couchbase.spark.sql._

    val airline = ssc.read.couchbase(EqualTo("type", "airline"), Map("bucket" -> "travel-sample"))
    val airport = ssc.read.couchbase(EqualTo("type", "airport"), Map("bucket" -> "travel-sample"))
    val route = ssc.read.couchbase(EqualTo("type", "route"), Map("bucket" -> "travel-sample"))
    val landmark = ssc.read.couchbase(EqualTo("type", "landmark"), Map("bucket" -> "travel-sample"))


    airline
      .limit(10)
      .write
      .mode(SaveMode.Overwrite)
      .couchbase(Map("bucket" -> "default"))

    // TODO: validate schemas which are inferred on a field and type basis

  }

  it should "write and ignore" in {
    val ssc = spark.sqlContext
    import com.couchbase.spark.sql._

    // create df, write it twice
    val data = ("Michael", 28, true)
    val df = ssc.createDataFrame(spark.sparkContext.parallelize(Seq(data)))

    df.write
      .mode(SaveMode.Ignore)
      .couchbase(options = Map("idField" -> "_1", "bucket" -> "default"))

    df.write
      .mode(SaveMode.Ignore)
      .couchbase(options = Map("idField" -> "_1", "bucket" -> "default"))
  }

  it should "filter based on a function" in {
    val ssc = spark.sqlContext
    import com.couchbase.spark.sql._

    val airlineBySubstrCountry: DataFrame = ssc.read.couchbase(
      EqualTo("'substr(country, 0, 6)'", "United"), Map("bucket" -> "travel-sample"))

    airlineBySubstrCountry.count() should equal(6797)
  }

} 
Example 128
Source File: BotPluginTestKit.scala    From sumobot   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.sumobot.test.annotated

import akka.actor.ActorSystem
import akka.testkit.{TestKit, TestProbe}
import com.sumologic.sumobot.core.model.{IncomingMessage, InstantMessageChannel, OutgoingMessage, UserSender}
import org.junit.runner.RunWith
import org.scalatest.concurrent.Eventually
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike}
import slack.models.User

import scala.concurrent.duration.{FiniteDuration, _}

@RunWith(classOf[JUnitRunner])
abstract class BotPluginTestKit(actorSystem: ActorSystem)
  extends TestKit(actorSystem)
    with WordSpecLike with Eventually with Matchers
    with BeforeAndAfterAll {

  protected val outgoingMessageProbe = TestProbe()
  system.eventStream.subscribe(outgoingMessageProbe.ref, classOf[OutgoingMessage])

  protected def confirmOutgoingMessage(test: OutgoingMessage => Unit, timeout: FiniteDuration = 1.second): Unit = {
    outgoingMessageProbe.expectMsgClass(timeout, classOf[OutgoingMessage]) match {
      case msg: OutgoingMessage =>
        test(msg)
    }
  }

  protected def instantMessage(text: String, user: User = mockUser("123", "jshmoe")): IncomingMessage = {
    IncomingMessage(text, true, InstantMessageChannel("125", user), "1527239216000090", sentBy = UserSender(user))
  }

  protected def mockUser(id: String, name: String): User = {
    User(id, name, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
  }

  protected def send(message: IncomingMessage): Unit = {
    system.eventStream.publish(message)
  }

  override protected def afterAll(): Unit = {
    TestKit.shutdownActorSystem(system)
  }
} 
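BotPluginTestKit above is abstract and is meant to be extended by concrete plugin tests. A minimal sketch of such a subclass — the plugin actor itself is left out (hypothetical), and the text field read from OutgoingMessage is an assumption; only the helpers send, instantMessage, mockUser and confirmOutgoingMessage come from the test kit shown above:

import akka.actor.ActorSystem
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class MyPluginTest extends BotPluginTestKit(ActorSystem("MyPluginTest")) {

  // The plugin actor under test would be started here, e.g.
  //   system.actorOf(Props(classOf[MyPlugin]))   // MyPlugin is hypothetical
  // Without it, the expectation below would simply time out.

  "MyPlugin" should {
    "reply to a direct message" in {
      send(instantMessage("hello", user = mockUser("42", "someone")))
      // Assumes OutgoingMessage exposes a text field.
      confirmOutgoingMessage(msg => msg.text should include("hello"))
    }
  }
}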
Example 129
Source File: PathUtilSpec.scala    From sonar-scala   with GNU Lesser General Public License v3.0 5 votes vote down vote up
package com.buransky.plugins.scoverage.util

import org.scalatest.{FlatSpec, Matchers}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class PathUtilSpec extends FlatSpec with Matchers {
  
  val osName = System.getProperty("os.name")
  val separator = System.getProperty("file.separator")
  
  behavior of s"splitPath for $osName"
  
  it should "ignore the empty path" in {
    PathUtil.splitPath("") should equal(List.empty[String])
  }

  it should "ignore a separator at the beginning" in {
    PathUtil.splitPath(s"${separator}a") should equal(List("a"))
  }

  it should "work with separator in the middle" in {
    PathUtil.splitPath(s"a${separator}b") should equal(List("a", "b"))
  }
  
  it should "work with an OS dependent absolute path" in {
    if (osName.startsWith("Windows")) {
      PathUtil.splitPath("C:\\test\\2") should equal(List("test", "2"))
    } else {
      PathUtil.splitPath("/test/2") should equal(List("test", "2"))
    }
  }
} 
Example 130
Source File: XmlScoverageReportParserSpec.scala    From sonar-scala   with GNU Lesser General Public License v3.0 5 votes vote down vote up
package com.buransky.plugins.scoverage.xml

import org.scalatest.{FlatSpec, Matchers}
import org.scalatest.junit.JUnitRunner
import org.junit.runner.RunWith
import com.buransky.plugins.scoverage.ScoverageException

@RunWith(classOf[JUnitRunner])
class XmlScoverageReportParserSpec extends FlatSpec with Matchers {
  behavior of "parse file path"

  it must "fail for null path" in {
    the[IllegalArgumentException] thrownBy XmlScoverageReportParser().parse(null.asInstanceOf[String], null)
  }

  it must "fail for empty path" in {
    the[IllegalArgumentException] thrownBy XmlScoverageReportParser().parse("", null)
  }

  it must "fail for not existing path" in {
    the[ScoverageException] thrownBy XmlScoverageReportParser().parse("/x/a/b/c/1/2/3/4.xml", null)
  }
} 
Example 131
Source File: RuntimeStatisticsSpec.scala    From coral   with Apache License 2.0 5 votes vote down vote up
package io.coral.api

import io.coral.TestHelper
import org.junit.runner.RunWith
import org.scalatest.WordSpecLike
import org.scalatest.junit.JUnitRunner
import org.json4s._
import org.json4s.jackson.JsonMethods._

@RunWith(classOf[JUnitRunner])
class RuntimeStatisticsSpec
	extends WordSpecLike {
	"A RuntimeStatistics class" should {
		"Properly sum multiple statistics objects together" in {
			val counters1 = Map(
				(("actor1", "stat1") -> 100L),
				(("actor1", "stat2") -> 20L),
				(("actor1", "stat3") -> 15L))
			val counters2 = Map(
				(("actor2", "stat1") -> 20L),
				(("actor2", "stat2") -> 30L),
				(("actor2", "stat3") -> 40L))
			val counters3 = Map(
				(("actor2", "stat1") -> 20L),
				(("actor2", "stat2") -> 30L),
				(("actor2", "stat3") -> 40L),
				(("actor2", "stat4") -> 12L))
			val stats1 = RuntimeStatistics(1, 2, 3, counters1)
			val stats2 = RuntimeStatistics(2, 3, 4, counters2)
			val stats3 = RuntimeStatistics(4, 5, 6, counters3)

			val actual = RuntimeStatistics.merge(List(stats1, stats2, stats3))

			val expected = RuntimeStatistics(7, 10, 13,
				Map(("actor1", "stat1") -> 100,
					("actor1", "stat2") -> 20,
					("actor1", "stat3") -> 15,
					("actor2", "stat1") -> 20,
					("actor2", "stat2") -> 30,
					("actor2", "stat3") -> 40,
					("actor2", "stat4") -> 12))

			assert(actual == expected)
		}

		"Create a JSON object from a RuntimeStatistics object" in {
			val input = RuntimeStatistics(1, 2, 3,
				Map((("actor1", "stat1") -> 10L),
					(("actor1", "stat2") -> 20L)))

			val expected = parse(
				s"""{
				   |  "totalActors": 1,
				   |  "totalMessages": 2,
				   |  "totalExceptions": 3,
				   |  "counters": {
				   |    "total": {
				   |      "stat1": 10,
				   |      "stat2": 20
				   |    }, "actor1": {
				   |      "stat1": 10,
				   |      "stat2": 20
				   |    }
				   |  }
				   |}
				 """.stripMargin).asInstanceOf[JObject]

			val actual = RuntimeStatistics.toJson(input)
			assert(actual == expected)
		}

		"Create a RuntimeStatistics object from a JSON object" in {
			val input = parse(
				s"""{
				   |  "totalActors": 1,
				   |  "totalMessages": 2,
				   |  "totalExceptions": 3,
				   |  "counters": {
				   |    "total": {
				   |      "stat1": 10,
				   |      "stat2": 20
				   |    }, "actor1": {
				   |      "stat1": 10,
				   |      "stat2": 20
				   |    }
				   |  }
				   |}
				 """.stripMargin).asInstanceOf[JObject]

			val actual = RuntimeStatistics.fromJson(input)

			val expected = RuntimeStatistics(1, 2, 3,
				Map((("actor1", "stat1") -> 10L),
					(("actor1", "stat2") -> 20L)))

			assert(actual == expected)
		}
	}
} 
Example 132
Source File: BootConfigSpec.scala    From coral   with Apache License 2.0 5 votes vote down vote up
package io.coral.api

import org.junit.runner.RunWith
import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, WordSpecLike}
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class BootConfigSpec
	extends WordSpecLike
	with BeforeAndAfterAll
	with BeforeAndAfterEach {
	"A Boot program actor" should {
		"Properly process given command line arguments for api and akka ports" in {
			val commandLine = CommandLineConfig(apiPort = Some(1234), akkaPort = Some(5345))
			val actual: CoralConfig = io.coral.api.Boot.getFinalConfig(commandLine)
			assert(actual.akka.remote.nettyTcpPort == 5345)
			assert(actual.coral.api.port == 1234)
		}

		"Properly process a given configuration file through the command line" in {
			val configPath = getClass().getResource("bootconfigspec.conf").getFile()
			val commandLine = CommandLineConfig(config = Some(configPath), apiPort = Some(4321))
			val actual: CoralConfig = io.coral.api.Boot.getFinalConfig(commandLine)
			// Overridden in bootconfigspec.conf
			assert(actual.akka.remote.nettyTcpPort == 6347)
			// Overridden in command line parameter
			assert(actual.coral.api.port == 4321)
			// Not overridden in command line or bootconfigspec.conf
			assert(actual.coral.cassandra.port == 9042)
		}
	}
} 
Example 133
Source File: ThresholdActorSpec.scala    From coral   with Apache License 2.0 5 votes vote down vote up
package io.coral.actors.transform

import io.coral.actors.CoralActorFactory
import io.coral.api.DefaultModule
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import scala.concurrent.duration._
import akka.actor.ActorSystem
import akka.testkit._
import akka.util.Timeout
import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike}

@RunWith(classOf[JUnitRunner])
class ThresholdActorSpec(_system: ActorSystem) extends TestKit(_system)
	with ImplicitSender
	with WordSpecLike
	with Matchers
	with BeforeAndAfterAll {
	implicit val timeout = Timeout(100.millis)
	def this() = this(ActorSystem("ThresholdActorSpec"))

	override def afterAll() {
		TestKit.shutdownActorSystem(system)
	}

	"A ThresholdActor" must {
		val createJson = parse(
			"""{ "type": "threshold", "params": { "key": "key1", "threshold": 10.5 }}"""
				.stripMargin).asInstanceOf[JObject]

		implicit val injector = new DefaultModule(system.settings.config)

		// test invalid definition json as well !!!
		val props = CoralActorFactory.getProps(createJson).get
		val threshold = TestActorRef[ThresholdActor](props)

		// subscribe the testprobe for emitting
		val probe = TestProbe()
		threshold.underlyingActor.emitTargets += probe.ref

		"Emit when equal to the threshold" in {
			val json = parse( """{"key1": 10.5}""").asInstanceOf[JObject]
			threshold ! json
			probe.expectMsg(parse( """{ "key1": 10.5 }"""))
		}

		"Emit when higher than the threshold" in {
			val json = parse( """{"key1": 10.7}""").asInstanceOf[JObject]
			threshold ! json
			probe.expectMsg(parse( """{"key1": 10.7 }"""))
		}

		"Not emit when lower than the threshold" in {
			val json = parse( """{"key1": 10.4 }""").asInstanceOf[JObject]
			threshold ! json
			probe.expectNoMsg()
		}

		"Not emit when key is not present in triggering json" in {
			val json = parse( """{"key2": 10.7 }""").asInstanceOf[JObject]
			threshold ! json
			probe.expectNoMsg()
		}
	}
} 
Example 134
Source File: MinMaxActorSpec.scala    From coral   with Apache License 2.0 5 votes vote down vote up
package io.coral.actors.transform

import akka.actor.{Actor, ActorSystem, Props}
import akka.testkit.{TestProbe, ImplicitSender, TestActorRef, TestKit}
import akka.util.Timeout
import io.coral.actors.CoralActorFactory
import io.coral.api.DefaultModule
import org.json4s.JsonDSL._
import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike}
import scala.concurrent.duration._

@RunWith(classOf[JUnitRunner])
class MinMaxActorSpec(_system: ActorSystem)
	extends TestKit(_system)
	with ImplicitSender
	with WordSpecLike
	with Matchers
	with BeforeAndAfterAll {
	implicit val timeout = Timeout(100.millis)
	implicit val formats = org.json4s.DefaultFormats
	implicit val injector = new DefaultModule(system.settings.config)
	def this() = this(ActorSystem("ZscoreActorSpec"))

	override def afterAll() {
		TestKit.shutdownActorSystem(system)
	}

	"A MinMaxActor" must {
		val createJson = parse(
			"""{ "type": "minmax", "params": { "field": "field1", "min": 10.0, "max": 13.5 }}"""
				.stripMargin).asInstanceOf[JObject]

		implicit val injector = new DefaultModule(system.settings.config)

		val props = CoralActorFactory.getProps(createJson).get
		val threshold = TestActorRef[MinMaxActor](props)

		// subscribe the testprobe for emitting
		val probe = TestProbe()
		threshold.underlyingActor.emitTargets += probe.ref

		"Emit the minimum when lower than the min" in {
			val json = parse( """{"field1": 7 }""").asInstanceOf[JObject]
			threshold ! json
			probe.expectMsg(parse( """{ "field1": 10.0 }"""))
		}

		"Emit the maximum when higher than the max" in {
			val json = parse( """{"field1": 15.3 }""").asInstanceOf[JObject]
			threshold ! json
			probe.expectMsg(parse( """{"field1": 13.5 }"""))
		}

		"Emit the value itself when between the min and the max" in {
			val json = parse( """{"field1": 11.7 }""").asInstanceOf[JObject]
			threshold ! json
			probe.expectMsg(parse( """{"field1": 11.7 }"""))
		}

		"Emit object unchanged when key is not present in triggering json" in {
			val json = parse( """{"otherfield": 15.3 }""").asInstanceOf[JObject]
			threshold ! json
			probe.expectMsg(parse( """{"otherfield": 15.3 }"""))
		}
	}
} 
Example 135
Source File: PluginTest.scala    From marathon-vault-plugin   with MIT License 5 votes vote down vote up
package com.avast.marathon.plugin.vault

import java.util.concurrent.TimeUnit

import com.bettercloud.vault.{Vault, VaultConfig}
import org.junit.runner.RunWith
import org.scalatest.{FlatSpec, Matchers}
import org.scalatest.junit.JUnitRunner

import scala.collection.JavaConverters._
import scala.concurrent.Await
import scala.concurrent.duration.Duration

@RunWith(classOf[JUnitRunner])
class PluginTest extends FlatSpec with Matchers {

  private lazy val marathonUrl = s"http://${System.getProperty("marathon.host")}:${System.getProperty("marathon.tcp.8080")}"
  private lazy val mesosSlaveUrl = s"http://${System.getProperty("mesos-slave.host")}:${System.getProperty("mesos-slave.tcp.5051")}"
  private lazy val vaultUrl = s"http://${System.getProperty("vault.host")}:${System.getProperty("vault.tcp.8200")}"

  it should "read existing shared secret" in {
    check("SECRETVAR", env => deployWithSecret("testappjson", env, "/test@testKey")) { envVarValue =>
      envVarValue shouldBe "testValue"
    }
  }

  it should "read existing private secret" in {
    check("SECRETVAR", env => deployWithSecret("testappjson", env, "test@testKey")) { envVarValue =>
      envVarValue shouldBe "privateTestValue"
    }
  }

  it should "read existing private secret from application in folder" in {
    check("SECRETVAR", env => deployWithSecret("folder/testappjson", env, "test@testKey")) { envVarValue =>
      envVarValue shouldBe "privateTestFolderValue"
    }
  }

  it should "fail when using .. in secret" in {
    intercept[RuntimeException] {
      check("SECRETVAR", env => deployWithSecret("folder/testappjson", env, "test/../test@testKey"), java.time.Duration.ofSeconds(1)) { envVarValue =>
        envVarValue shouldNot be("privateTestFolderValue")
      }
    }
  }

  private def deployWithSecret(appId: String, envVarName: String, secret: String): String = {
    val json = s"""{ "id": "$appId","cmd": "${EnvAppCmd.create(envVarName)}","env": {"$envVarName": {"secret": "pwd"}},"secrets": {"pwd": {"source": "$secret"}}}"""

    val marathonResponse = new MarathonClient(marathonUrl).put(appId, json)
    appId
  }

  private def check(envVarName: String, deployApp: String => String, timeout: java.time.Duration = java.time.Duration.ofSeconds(30))(verifier: String => Unit): Unit = {
    val client = new MarathonClient(marathonUrl)
    val eventStream = new MarathonEventStream(marathonUrl)

    val vaultConfig = new VaultConfig().address(vaultUrl).token("testroottoken").build()
    val vault = new Vault(vaultConfig)
    vault.logical().write("secret/shared/test", Map[String, AnyRef]("testKey" -> "testValue").asJava)
    vault.logical().write("secret/private/testappjson/test", Map[String, AnyRef]("testKey" -> "privateTestValue").asJava)
    vault.logical().write("secret/private/folder/testappjson/test", Map[String, AnyRef]("testKey" -> "privateTestFolderValue").asJava)

    val appId = deployApp(envVarName)
    val appCreatedFuture = eventStream.when(_.eventType.contains("deployment_success"))
    Await.result(appCreatedFuture, Duration.create(20, TimeUnit.SECONDS))

    val agentClient = MesosAgentClient(mesosSlaveUrl)
    val state = agentClient.fetchState()

    try {
      val envVarValue = agentClient.waitForStdOutContentsMatch(envVarName, state.frameworks(0).executors(0),
        o => EnvAppCmd.extractEnvValue(envVarName, o),
        timeout)
      verifier(envVarValue)
    } finally {
      client.delete(appId)
      val appRemovedFuture = eventStream.when(_.eventType.contains("deployment_success"))
      Await.result(appRemovedFuture, Duration.create(20, TimeUnit.SECONDS))
      eventStream.close()
    }
  }
} 
Example 136
Source File: CronScheduleSpec.scala    From sundial   with MIT License 5 votes vote down vote up
package model

import java.text.ParseException
import java.util.GregorianCalendar

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatestplus.play.PlaySpec

@RunWith(classOf[JUnitRunner])
class CronScheduleSpec extends PlaySpec {

  "Cron scheduler" should {

    "successfully parse cron entry for 10pm every day" in {
      val cronSchedule = CronSchedule("0", "22", "*", "*", "?")
      val date = new GregorianCalendar(2015, 10, 5, 21, 0).getTime
      val expectedNextDate = new GregorianCalendar(2015, 10, 5, 22, 0).getTime
      val nextDate = cronSchedule.nextRunAfter(date)
      nextDate must be(expectedNextDate)
    }

    "Throw exception on creation if cron schedlue is invalid" in {
      intercept[ParseException] {
        CronSchedule("0", "22", "*", "*", "*")
      }
    }
  }

} 
Example 137
Source File: PostgresJsonMarshallerTest.scala    From sundial   with MIT License 5 votes vote down vote up
package dao.postgres.marshalling

import com.fasterxml.jackson.databind.PropertyNamingStrategy.SNAKE_CASE
import com.hbc.svc.sundial.v2.models.NotificationOptions
import model.{EmailNotification, PagerdutyNotification, Team}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatestplus.play.PlaySpec
import util.Json

@RunWith(classOf[JUnitRunner])
class PostgresJsonMarshallerTest extends PlaySpec {

  private val postgresJsonMarshaller = new PostgresJsonMarshaller()

  private val objectMapper = Json.mapper()
  objectMapper.setPropertyNamingStrategy(SNAKE_CASE)
  //  objectMapper.setVisibility(PropertyAccessor.FIELD,Visibility.ANY)

  "PostgresJsonMarshaller" should {

    "correctly deserialize a json string into Seq[Team]" in {
      val json =
        """
          | [{
          |   "name" : "teamName",
          |   "email" : "teamEmail",
          |   "notify_action": "on_state_change_and_failures"
          | }]
        """.stripMargin

      val expectedTeams: Seq[Team] =
        Vector(Team("teamName", "teamEmail", "on_state_change_and_failures"))
      val actualTeams = postgresJsonMarshaller.toTeams(json)
      actualTeams must be(expectedTeams)
    }

    "correctly serialise a Seq[Team] in a json string" in {
      val expectedJson =
        """
          | [{
          |   "name" : "teamName",
          |   "email" : "teamEmail",
          |   "notify_action": "on_state_change_and_failures"
          | }]
        """.stripMargin
      val expectedTeams: Seq[Team] =
        Vector(Team("teamName", "teamEmail", "on_state_change_and_failures"))
      val actualJson = postgresJsonMarshaller.toJson(expectedTeams)
      objectMapper.readTree(actualJson) must be(
        objectMapper.readTree(expectedJson))
    }

    "correctly deserialize a json string into Seq[Notification]" in {

      val json =
        """
          |[{"name":"name","email":"email","notify_action":"on_state_change_and_failures", "type": "email"},{"service_key":"service-key","api_url":"http://google.com", "type": "pagerduty","num_consecutive_failures":1}]
        """.stripMargin

      val notifications = Vector(
        EmailNotification(
          "name",
          "email",
          NotificationOptions.OnStateChangeAndFailures.toString),
        PagerdutyNotification("service-key", "http://google.com", 1)
      )

      val actualNotifications = postgresJsonMarshaller.toNotifications(json)

      actualNotifications must be(notifications)

    }

    "correctly serialise a Seq[Notification] in a json string" in {

      val json =
        """
          |[{"name":"name","email":"email","notify_action":"on_state_change_and_failures", "type": "email"},{"service_key":"service-key","api_url":"http://google.com", "type": "pagerduty","num_consecutive_failures":1}]
        """.stripMargin

      val notifications = Vector(
        EmailNotification(
          "name",
          "email",
          NotificationOptions.OnStateChangeAndFailures.toString),
        PagerdutyNotification("service-key", "http://google.com", 1)
      )

      println(s"bla1: ${postgresJsonMarshaller.toJson(notifications)}")
      println(s"bla2: ${objectMapper.writeValueAsString(notifications)}")

      objectMapper.readTree(json) must be(
        objectMapper.readTree(postgresJsonMarshaller.toJson(notifications)))

    }

  }

} 
Example 138
Source File: HTMLTableTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase.table

import com.sumologic.shellbase.CommonWordSpec
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class HTMLTableTest extends CommonWordSpec {
  "HTMLTable" should {
    "create a table without triggering an exception" in {
      val table = new HTMLTable[(Int, Int, String)]()

      table.addColumn("Test 1", 6, _._1.toString)
      table.addColumn("Test 2", 6, _._2.toString)
      table.addColumn("Test 3", 7, _._3, rightAligned = true)

      val data = Seq[(Int, Int, String)](
        (1, 2, "hi"),
        (10, 20, "dude")
      )

      table.renderLines(data).foreach(println)

      table.bare = false

      table.renderLines(data).foreach(println)
    }
  }

} 
Example 139
Source File: SubdocMutationAccessorSpec.scala    From couchbase-spark-connector   with Apache License 2.0 5 votes vote down vote up
package com.couchbase.spark.connection

import com.couchbase.client.java.document.json.JsonObject
import com.couchbase.client.java.document.{Document, JsonDocument}
import org.apache.spark.SparkConf
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}


@RunWith(classOf[JUnitRunner])
class SubdocMutationAccessorSpec extends FlatSpec with Matchers {

  "A SubdocMutationAccessor" should "upsert a path into a doc" in {
    val sparkCfg = new SparkConf()
    sparkCfg.set("com.couchbase.username", "Administrator")
    sparkCfg.set("com.couchbase.password", "password")
    val cfg = CouchbaseConfig(sparkCfg)

    val bucket = CouchbaseConnection().bucket(cfg, "default")

    bucket.upsert(JsonDocument.create("doc", JsonObject.create()))
    bucket.upsert(JsonDocument.create("doc2", JsonObject.create()))

    val accessor = new SubdocMutationAccessor(cfg, Seq(
      SubdocUpsert("doc", "element", "value"),
      SubdocUpsert("doc2", "_", 5678),
      SubdocUpsert("doc2", "element2", 1234)
    ), null, None)

    accessor.compute().foreach(println)
  }

} 
Example 140
Source File: LazyIteratorSpec.scala    From couchbase-spark-connector   with Apache License 2.0 5 votes vote down vote up
package com.couchbase.spark.internal

import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class LazyIteratorSpec extends FlatSpec with Matchers {

  "A LazyIterator" should "not create the delegated Iterator in the constructor" in {
    var created = false
    val iter = LazyIterator {
      created = true
      Iterator(1, 2, 3)
    }

    created should equal (false)
    iter.toList should equal (1 :: 2 :: 3 :: Nil)
    created should equal (true)
  }

} 
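The spec above only pins down the laziness contract: the by-name block must not run until the iterator is actually consumed. An illustrative sketch of that kind of delegation (a stand-in, not the connector's actual LazyIterator implementation):

// Illustrative only: defers building the underlying iterator until it is first used.
class LazyIteratorSketch[T](make: => Iterator[T]) extends Iterator[T] {
  private lazy val underlying = make
  override def hasNext: Boolean = underlying.hasNext
  override def next(): T = underlying.next()
}

object LazyIteratorSketch {
  def apply[T](make: => Iterator[T]): LazyIteratorSketch[T] = new LazyIteratorSketch(make)
}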
Example 141
Source File: RDDFunctionsSpec.scala    From couchbase-spark-connector   with Apache License 2.0 5 votes vote down vote up
package com.couchbase.spark.integration

import com.couchbase.client.java.Bucket
import com.couchbase.client.java.document.JsonDocument
import com.couchbase.client.java.document.json.JsonObject
import com.couchbase.client.java.query.N1qlQuery
import com.couchbase.client.java.query.core.N1qlQueryExecutor
import com.couchbase.client.java.view.{SpatialViewQuery, ViewQuery}
import com.couchbase.spark.connection.{CouchbaseConfig, CouchbaseConnection}
import org.apache.spark.sql.SparkSession
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}
import com.couchbase.spark._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class RDDFunctionsSpec extends FlatSpec with Matchers with BeforeAndAfterAll {
  private val master = "local[2]"
  private val appName = "cb-int-specs2"
  private val bucketName = "default"

  private var sparkContext: SparkContext = null
  private var bucket: Bucket = null

  override def beforeAll() {
    val conf = new SparkConf()
      .setMaster(master)
      .setAppName(appName)
      .set("spark.couchbase.nodes", "127.0.0.1")
      .set("spark.couchbase.username", "Administrator")
      .set("spark.couchbase.password", "password")
      .set("com.couchbase.bucket." + bucketName, "")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    sparkContext = spark.sparkContext
    bucket = CouchbaseConnection().bucket(CouchbaseConfig(conf), bucketName)
  }

  override def afterAll(): Unit = {
    CouchbaseConnection().stop()
    sparkContext.stop()
  }

  "A RDD" should "be created as a transformation" in {
    bucket.upsert(JsonDocument.create("doc1", JsonObject.create().put("val", "doc1")))
    bucket.upsert(JsonDocument.create("doc2", JsonObject.create().put("val", "doc2")))
    bucket.upsert(JsonDocument.create("doc3", JsonObject.create().put("val", "doc3")))


    val result = sparkContext
      .parallelize(Seq("doc1", "doc2", "doc3"))
      .couchbaseGet[JsonDocument]()
      .collect()

    result should have size 3
    result.foreach { doc =>
      doc.content().getString("val") should equal (doc.id())
    }
  }

} 
Example 142
Source File: DeploymentValueResolverSpec.scala    From vamp   with Apache License 2.0 5 votes vote down vote up
package io.vamp.model.resolver

import io.vamp.common.{ Namespace, NamespaceProvider }
import io.vamp.model.artifact._
import io.vamp.model.notification.ModelNotificationProvider
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class DeploymentValueResolverSpec extends FlatSpec with Matchers with DeploymentValueResolver with ModelNotificationProvider with NamespaceProvider {

  implicit val namespace: Namespace = Namespace("default")

  "DeploymentTraitResolver" should "pass through environment variables for an empty cluster list" in {

    val environmentVariables = EnvironmentVariable("backend1.environment_variables.port", None, Some("5555")) ::
      EnvironmentVariable("backend1.environment_variables.timeout", None, Some(s"$$backend1.host")) :: Nil

    val filter = environmentVariables.map(_.name).toSet

    resolveEnvironmentVariables(deployment(environmentVariables), Nil).filter(ev ⇒ filter.contains(ev.name)) should equal(
      environmentVariables
    )
  }

  it should "pass through environment variables for non relevant clusters" in {

    val environmentVariables = EnvironmentVariable("backend1.environment_variables.port", None, Some("5555")) ::
      EnvironmentVariable("backend1.environment_variables.timeout", None, Some(s"$$backend1.host")) :: Nil

    val filter = environmentVariables.map(_.name).toSet

    resolveEnvironmentVariables(deployment(environmentVariables), DeploymentCluster("backend", Map(), Nil, Nil, None, None, None) :: Nil).filter(ev ⇒ filter.contains(ev.name)) should equal(
      environmentVariables
    )
  }

  it should "interpolate simple reference" in {

    val environmentVariables = EnvironmentVariable("backend.environment_variables.port", None, Some(s"$$frontend.constants.const1")) ::
      EnvironmentVariable("backend.environment_variables.timeout", None, Some(s"$${backend1.constants.const2}")) ::
      EnvironmentVariable("backend1.environment_variables.timeout", None, Some(s"$${frontend.constants.const1}")) :: Nil

    val filter = environmentVariables.map(_.name).toSet

    resolveEnvironmentVariables(deployment(environmentVariables), DeploymentCluster("backend", Map(), Nil, Nil, None, None, None) :: Nil).filter(ev ⇒ filter.contains(ev.name)) should equal(
      EnvironmentVariable("backend.environment_variables.port", None, Some(s"$$frontend.constants.const1"), interpolated = Some("9050")) ::
        EnvironmentVariable("backend.environment_variables.timeout", None, Some(s"$${backend1.constants.const2}"), interpolated = Some(s"$$backend1.host")) ::
        EnvironmentVariable("backend1.environment_variables.timeout", None, Some(s"$${frontend.constants.const1}"), interpolated = None) :: Nil
    )
  }

  it should "interpolate complex value" in {

    val environmentVariables = EnvironmentVariable("backend.environment_variables.url", None, Some("http://$backend1.host:$frontend.constants.const1/api/$$/$backend1.environment_variables.timeout")) ::
      EnvironmentVariable("backend1.environment_variables.timeout", None, Some("4000")) :: Nil

    val filter = environmentVariables.map(_.name).toSet

    resolveEnvironmentVariables(deployment(environmentVariables), DeploymentCluster("backend", Map(), Nil, Nil, None, None, None) :: Nil).filter(ev ⇒ filter.contains(ev.name)) should equal(
      EnvironmentVariable("backend.environment_variables.url", None, Some("http://$backend1.host:$frontend.constants.const1/api/$$/$backend1.environment_variables.timeout"), interpolated = Some("http://vamp.io:9050/api/$/4000")) ::
        EnvironmentVariable("backend1.environment_variables.timeout", None, Some("4000"), interpolated = None) :: Nil
    )
  }

  def deployment(environmentVariables: List[EnvironmentVariable]) = {
    val clusters = DeploymentCluster("backend1", Map(), Nil, Nil, None, None, None) :: DeploymentCluster("backend2", Map(), Nil, Nil, None, None, None) :: Nil
    val addition = EnvironmentVariable("frontend.constants.const1", None, Some("9050")) :: EnvironmentVariable("backend1.constants.const2", None, Some(s"$$backend1.host")) :: Nil
    val hosts = Host("backend1.hosts.host", Some("vamp.io")) :: Nil
    Deployment("", Map(), clusters, Nil, Nil, environmentVariables ++ addition, hosts)
  }
} 
Example 143
Source File: TimeScheduleSpec.scala    From vamp   with Apache License 2.0 5 votes vote down vote up
package io.vamp.model.workflow

import java.time.{ Duration, Period }

import io.vamp.model.artifact.TimeSchedule
import io.vamp.model.reader.ReaderSpec
import io.vamp.model.artifact.TimeSchedule.{ RepeatForever, RepeatPeriod }
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TimeScheduleSpec extends FlatSpec with Matchers with ReaderSpec {

  "TimeSchedule" should "read an empty period" in {
    TimeSchedule("") should have(
      'period(RepeatPeriod(None, None)),
      'repeat(RepeatForever),
      'start(None)
    )
  }

  it should "read days" in {
    TimeSchedule("P1Y2M3D") should have(
      'period(RepeatPeriod(Some(Period.parse("P1Y2M3D")), None)),
      'repeat(RepeatForever),
      'start(None)
    )
  }

  it should "read time" in {
    TimeSchedule("PT1H2M3S") should have(
      'period(RepeatPeriod(None, Some(Duration.parse("PT1H2M3S")))),
      'repeat(RepeatForever),
      'start(None)
    )
  }

  it should "read days and time" in {
    TimeSchedule("P1Y2M3DT1H2M3S") should have(
      'period(RepeatPeriod(Some(Period.parse("P1Y2M3D")), Some(Duration.parse("PT1H2M3S")))),
      'repeat(RepeatForever),
      'start(None)
    )
  }
} 
Example 144
Source File: ReaderSpec.scala    From vamp   with Apache License 2.0 5 votes vote down vote up
package io.vamp.model.reader

import io.vamp.model.notification.{ UnexpectedInnerElementError, UnexpectedTypeError, YamlParsingError }
import io.vamp.common.notification.NotificationErrorException
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{ FlatSpec, Matchers }

import scala.io.Source
import scala.reflect._
import YamlSourceReader._

trait ReaderSpec extends FlatSpec with Matchers {
  protected def res(path: String): String = Source.fromURL(getClass.getResource(path)).mkString

  protected def expectedError[A <: Any: ClassTag](f: ⇒ Any): A = {
    the[NotificationErrorException] thrownBy f match {
      case NotificationErrorException(error: A, _) ⇒ error
      case unexpected                              ⇒ throw new IllegalArgumentException(s"Expected ${classTag[A].runtimeClass}, actual ${unexpected.notification.getClass}", unexpected)
    }
  }
}

@RunWith(classOf[JUnitRunner])
class YamlReaderSpec extends ReaderSpec {

  "YamlReader" should "fail on invalid YAML" in {
    expectedError[YamlParsingError]({
      new YamlReader[Any] {
        override protected def parse(implicit source: YamlSourceReader): Any = None
      }.read(res("invalid1.yml"))
    }).message should startWith("Can't construct a resource for !ios")
  }

  it should "fail on invalid type" in {
    expectedError[UnexpectedTypeError]({
      new YamlReader[Any] {
        override protected def parse(implicit source: YamlSourceReader): Any = <<![Int]("integer")
      }.read(res("invalid2.yml"))
    }) should have(
      'path("integer"),
      'expected(classOf[Int]),
      'actual(classOf[String])
    )
  }

  it should "fail on unexpected inner element type" in {
    expectedError[UnexpectedInnerElementError]({
      new YamlReader[Any] {
        override protected def parse(implicit source: YamlSourceReader): Any = <<![String]("root" :: "nested" :: "next")
      }.read(res("invalid3.yml"))
    }) should have(
      'path("nested"),
      'found(classOf[String])
    )
  }
} 
Example 145
Source File: UnitValueSpec.scala    From vamp   with Apache License 2.0 5 votes vote down vote up
package io.vamp.model.reader

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{ FlatSpec, Matchers }

import scala.util.{ Failure, Success }

@RunWith(classOf[JUnitRunner])
class UnitValueSpec extends FlatSpec with Matchers {

  "Percentage" should "parse" in {

    UnitValue.of[Percentage]("0 %") shouldBe Success(Percentage(0))
    UnitValue.of[Percentage](" 50 %") shouldBe Success(Percentage(50))
    UnitValue.of[Percentage]("100%") shouldBe Success(Percentage(100))

    UnitValue.of[Percentage]("x") shouldBe a[Failure[_]]
    UnitValue.of[Percentage]("-50") shouldBe a[Failure[_]]
    UnitValue.of[Percentage]("1.5") shouldBe a[Failure[_]]
    UnitValue.of[Percentage]("100") shouldBe a[Failure[_]]
  }

  "MegaByte" should "parse" in {

    UnitValue.of[MegaByte]("128Ki ").map(_.normalized) shouldBe Success(MegaByte(128.0 * 1024 / 1000000).normalized)
    UnitValue.of[MegaByte]("128KB ").map(_.normalized) shouldBe Success(MegaByte(128.0 * 1000 / 1000000).normalized)

    UnitValue.of[MegaByte]("128mb") shouldBe Success(MegaByte(128))
    UnitValue.of[MegaByte](" 128mb ") shouldBe Success(MegaByte(128))
    UnitValue.of[MegaByte](" 128 mb ") shouldBe Success(MegaByte(128))
    UnitValue.of[MegaByte](" 128 Mi ") shouldBe Success(MegaByte(128 * 1.024))
    UnitValue.of[MegaByte](" 128 mi ") shouldBe Success(MegaByte(128 * 1.024))
    UnitValue.of[MegaByte](".1m") shouldBe Success(MegaByte(0.1))
    UnitValue.of[MegaByte]("10.1Mb") shouldBe Success(MegaByte(10.1))
    UnitValue.of[MegaByte]("64.MB") shouldBe Success(MegaByte(64))
    UnitValue.of[MegaByte](".1gb") shouldBe Success(MegaByte(100))
    UnitValue.of[MegaByte]("1GB") shouldBe Success(MegaByte(1000))
    UnitValue.of[MegaByte]("1.5G") shouldBe Success(MegaByte(1500))
    UnitValue.of[MegaByte](".1gB") shouldBe Success(MegaByte(100))

    UnitValue.of[MegaByte]("1") shouldBe a[Failure[_]]
    UnitValue.of[MegaByte]("-1") shouldBe a[Failure[_]]
    UnitValue.of[MegaByte]("1Tb") shouldBe a[Failure[_]]
    UnitValue.of[MegaByte](".") shouldBe a[Failure[_]]
  }

  "Quantity" should "parse" in {

    UnitValue.of[Quantity]("128") shouldBe Success(Quantity(128.0))
    UnitValue.of[Quantity]("-128.5") shouldBe Success(Quantity(-128.5))
    UnitValue.of[Quantity](" 1m ") shouldBe Success(Quantity(0.001))
    UnitValue.of[Quantity](" 0.1 ") shouldBe Success(Quantity(0.1))
    UnitValue.of[Quantity](".1") shouldBe Success(Quantity(0.1))
    UnitValue.of[Quantity]("-0.1 ") shouldBe Success(Quantity(-0.1))
    UnitValue.of[Quantity]("-.1 ") shouldBe Success(Quantity(-.1))
  }

  "Time" should "parse" in {
    // Test for second values
    UnitValue.of[Time]("1sec") shouldBe Success(Time(1))
    UnitValue.of[Time]("20s") shouldBe Success(Time(20))
    UnitValue.of[Time]("1second") shouldBe Success(Time(1))
    UnitValue.of[Time]("2seconds") shouldBe Success(Time(2))
    // Test for minute values
    UnitValue.of[Time]("20m") shouldBe Success(Time(20 * 60))
    UnitValue.of[Time]("1min") shouldBe Success(Time(60))
    UnitValue.of[Time]("1minute") shouldBe Success(Time(60))
    UnitValue.of[Time]("5minutes") shouldBe Success(Time(5 * 60))
    // Test for hourly values
    UnitValue.of[Time]("1h") shouldBe Success(Time(3600))
    UnitValue.of[Time]("2hrs") shouldBe Success(Time(2 * 3600))
    UnitValue.of[Time]("1hour") shouldBe Success(Time(3600))
    UnitValue.of[Time]("4hours") shouldBe Success(Time(4 * 3600))
  }
} 
Example 146
Source File: SlaReaderSpec.scala    From vamp   with Apache License 2.0 5 votes vote down vote up
package io.vamp.model.reader

import io.vamp.model.artifact._
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

import scala.concurrent.duration._
import scala.language.postfixOps

@RunWith(classOf[JUnitRunner])
class SlaReaderSpec extends FlatSpec with Matchers with ReaderSpec {

  "SlaReader" should "read the generic SLA" in {
    SlaReader.read(res("sla/sla1.yml")) should have(
      'name("red"),
      'type("response_time"),
      'parameters(Map("window" → Map("cooldown" → 600, "interval" → 600), "threshold" → Map("lower" → 100, "upper" → 1000))),
      'escalations(List(GenericEscalation("", Map(), "scale_nothing", Map("scale_by" → 1, "minimum" → 1, "maximum" → 4))))
    )
  }

  it should "read the response time sliding window SLA with generic escalations" in {
    SlaReader.read(res("sla/sla2.yml")) should have(
      'name("red"),
      'interval(600 seconds),
      'cooldown(600 seconds),
      'upper(1000 milliseconds),
      'lower(100 milliseconds),
      'escalations(List(GenericEscalation("", Map(), "scale_nothing", Map("scale_by" → 1, "minimum" → 1, "maximum" → 4))))
    )
  }

  it should "read the response time sliding window SLA with scale escalations" in {
    SlaReader.read(res("sla/sla3.yml")) should have(
      'name("red"),
      'interval(600 seconds),
      'cooldown(600 seconds),
      'upper(1000 milliseconds),
      'lower(100 milliseconds),
      'escalations(List(ToAllEscalation("", Map(), List(ScaleInstancesEscalation("", Map(), 1, 4, 1, None), ScaleCpuEscalation("", Map(), 1.0, 4.0, 1.0, None), ScaleMemoryEscalation("", Map(), 1024.0, 2048.5, 512.1, None)))))
    )
  }

  it should "read the SLA with a group escalation" in {
    SlaReader.read(res("sla/sla4.yml")) should have(
      'name("red"),
      'interval(600 seconds),
      'cooldown(600 seconds),
      'upper(1000 milliseconds),
      'lower(100 milliseconds),
      'escalations(List(ToAllEscalation("", Map(), List(EscalationReference("notify"), ToOneEscalation("", Map(), List(ScaleInstancesEscalation("", Map(), 1, 4, 1, None), ScaleCpuEscalation("", Map(), 1.0, 4.0, 1.0, None)))))))
    )
  }

  it should "read the SLA with a group escalation with expansion" in {
    SlaReader.read(res("sla/sla5.yml")) should have(
      'name("red"),
      'interval(600 seconds),
      'cooldown(600 seconds),
      'upper(1000 milliseconds),
      'lower(100 milliseconds),
      'escalations(List(ToAllEscalation("", Map(), List(EscalationReference("notify"), ToOneEscalation("", Map(), List(ScaleInstancesEscalation("", Map(), 1, 4, 1, None), ScaleCpuEscalation("", Map(), 1.0, 4.0, 1.0, None)))))))
    )
  }

  it should "read the SLA with a nested group escalation" in {
    SlaReader.read(res("sla/sla6.yml")) should have(
      'name("red"),
      'interval(600 seconds),
      'cooldown(600 seconds),
      'upper(1000 milliseconds),
      'lower(100 milliseconds),
      'escalations(List(ToAllEscalation("", Map(), List(EscalationReference("notify"), ToOneEscalation("", Map(), List(ScaleInstancesEscalation("", Map(), 1, 4, 1, None), ToAllEscalation("", Map(), List(ScaleCpuEscalation("", Map(), 1.0, 4.0, 1.0, None), EscalationReference("email")))))))))
    )
  }
} 
Example 147
Source File: ImportReaderSpec.scala    From vamp   with Apache License 2.0 5 votes vote down vote up
package io.vamp.model.reader

import io.vamp.model.notification.{ EmptyImportError, ImportDefinitionError }
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ImportReaderSpec extends ReaderSpec {

  "ImportReader" should "read the single import" in {
    ImportReader.read(res("template/template1.yml")) shouldBe
      Import(Map("name" → "template1"), List(ImportReference("templateA", "templates")))
  }

  it should "read the multiple imports" in {
    ImportReader.read(res("template/template2.yml")) shouldBe
      Import(Map("name" → "template2"), List(ImportReference("templateA", "templates"), ImportReference("templateB", "templates")))
  }

  it should "throw an error when reference is not a string" in {
    expectedError[ImportDefinitionError.type] {
      ImportReader.read(res("template/template3.yml"))
    }
  }

  it should "read the multiple imports with kind" in {
    ImportReader.read(res("template/template4.yml")) shouldBe
      Import(Map("name" → "template4"), List(ImportReference("templateA", "templates"), ImportReference("sava", "breeds")))
  }

  it should "throw an error when reference is empty" in {
    expectedError[EmptyImportError.type] {
      ImportReader.read(res("template/template5.yml"))
    }
  }

  it should "throw an error when reference is not correct (kind/name)" in {
    expectedError[ImportDefinitionError.type] {
      ImportReader.read(res("template/template6.yml"))
    }
  }
} 
Example 148
Source File: OpTransformerReaderWriterTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages

import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.unary.UnaryLambdaTransformer
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OpTransformerReaderWriterTest extends OpPipelineStageReaderWriterTest {

  override val hasOutputName = false

  val stage =
    new UnaryLambdaTransformer[Real, Real](
      operationName = "test",
      transformFn = new Lambdas.FncUnary,
      uid = "uid_1234"
    ).setInput(weight).setMetadata(meta)

  val expected = Array(21.2248.toReal, 8.2678.toReal, Real.empty, 9.6252.toReal, 11.8464.toReal, 8.2678.toReal)
} 
Example 149
Source File: InfiniteStreamTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.testkit

import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

import scala.language.postfixOps


@RunWith(classOf[JUnitRunner])
class InfiniteStreamTest extends FlatSpec with TestCommon {

  Spec[InfiniteStream[_]] should "map" in {
    var i = 0
    val src = new InfiniteStream[Int] {
      override def next: Int = {
        i += 1;
        i
      }
    }

    val sut = src map (5 +)

    while (i < 10) sut.next shouldBe (i + 5)
  }

} 
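The test exercises only next and map on InfiniteStream. A minimal sketch of an abstract stream with that contract (an illustration under that assumption, not TransmogrifAI's own class):

// Illustrative only: an endless source of values with a map combinator.
abstract class InfiniteStreamSketch[T] { self =>
  def next: T
  def map[U](f: T => U): InfiniteStreamSketch[U] = new InfiniteStreamSketch[U] {
    def next: U = f(self.next)
  }
}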
Example 150
Source File: RandomRealTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.testkit

import com.salesforce.op.features.types.{Currency, Percent, Real, RealNN}
import com.salesforce.op.test.TestCommon
import com.salesforce.op.testkit.RandomReal._
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class RandomRealTest extends FlatSpec with TestCommon {
  val numTries = 1000000

  
  // ignore should "cast to default data type" in {
  //  check(uniform(1.0, 2.0), probabilityOfEmpty = 0.5, range = (1.0, 2.0))
  // }

  Spec[RandomReal[Real]]  should "Give Normal distribution with mean 1 sigma 0.1, 10% nulls" in {
    val normalReals = normal[Real](1.0, 0.2)
    check(normalReals, probabilityOfEmpty = 0.1, range = (-2.0, 4.0))
  }

  Spec[RandomReal[Real]] should "Give Uniform distribution on 1..2, half nulls" in {
    check(uniform[Real](1.0, 2.0), probabilityOfEmpty = 0.5, range = (1.0, 2.0))
  }

  it should "Give Poisson distribution with mean 4, 20% nulls" in {
    check(poisson[Real](4.0), probabilityOfEmpty = 0.2, range = (0.0, 15.0))
  }

  it should "Give Exponential distribution with mean 1, 1% nulls" in {
    check(exponential[Real](1.0), probabilityOfEmpty = 0.01, range = (0.0, 15.0))
  }

  it should "Give Gamma distribution with mean 5, 0% nulls" in {
    check(gamma[Real](5.0), probabilityOfEmpty = 0.0, range = (0.0, 25.0))
  }

  it should "Give LogNormal distribution with mean 0.25, 20% nulls" in {
    check(logNormal[Real](0.25, 0.001), probabilityOfEmpty = 0.7, range = (0.1, 15.0))
  }

  it should "Weibull distribution (4.0, 5.0), 20% nulls" in {
    check(weibull[Real](4.0, 5.0), probabilityOfEmpty = 0.2, range = (0.0, 15.0))
  }

  Spec[RandomReal[RealNN]] should "give no nulls" in {
    check(normal[RealNN](1.0, 0.2), probabilityOfEmpty = 0.0, range = (-2.0, 4.0))
  }

  Spec[RandomReal[Currency]] should "distribute money normally" in {
    check(normal[Currency](1.0, 0.2), probabilityOfEmpty = 0.5, range = (-2.0, 4.0))
  }

  Spec[RandomReal[Percent]] should "distribute percentage evenly" in {
    check(uniform[Percent](1.0, 2.0), probabilityOfEmpty = 0.5, range = (0.0, 2.0))
  }

  private val rngSeed = 7688721

  private def check[T <: Real](
    src: RandomReal[T],
    probabilityOfEmpty: Double,
    range: (Double, Double)) = {
    val sut = src withProbabilityOfEmpty probabilityOfEmpty
    sut reset rngSeed

    val found = sut.next
    sut reset rngSeed
    val foundAfterReseed = sut.next
    if (foundAfterReseed != found) {
      sut.reset(rngSeed)
    }
    withClue(s"generator reset did not work for $sut") {
      foundAfterReseed shouldBe found
    }
    sut reset rngSeed

    val numberOfNulls = sut limit numTries count (_.isEmpty)

    val expectedNumberOfNulls = probabilityOfEmpty * numTries
    withClue(s"numNulls = $numberOfNulls, expected $expectedNumberOfNulls") {
      math.abs(numberOfNulls - expectedNumberOfNulls) < numTries / 100 shouldBe true
    }

    val numberOfOutliers = sut limit numTries count (xOpt => xOpt.value.exists(x => x < range._1 || x > range._2))

    numberOfOutliers should be < (numTries / 1000)

  }
} 
Example 151
Source File: RandomListTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.testkit

import java.text.SimpleDateFormat

import com.salesforce.op.features.types._
import com.salesforce.op.test.TestCommon
import com.salesforce.op.testkit.RandomList.{NormalGeolocation, UniformGeolocation}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec}

import scala.language.postfixOps


@RunWith(classOf[JUnitRunner])
class RandomListTest extends FlatSpec with TestCommon with Assertions {
  private val numTries = 10000
  private val rngSeed = 314159214142136L

  private def check[D, T <: OPList[D]](
    g: RandomList[D, T],
    minLen: Int, maxLen: Int,
    predicate: (D => Boolean) = (_: D) => true
  ) = {
    g reset rngSeed

    def segment = g limit numTries

    segment count (_.value.length < minLen) shouldBe 0
    segment count (_.value.length > maxLen) shouldBe 0
    segment foreach (list => list.value foreach { x =>
      predicate(x) shouldBe true
    })
  }

  private val df = new SimpleDateFormat("dd/MM/yy")

  Spec[Text, RandomList[String, TextList]] should "generate lists of strings" in {
    val sut = RandomList.ofTexts(RandomText.countries, 0, 4)
    check[String, TextList](sut, 0, 4, _.length > 0)

    (sut limit 7 map (_.value.toList)) shouldBe
      List(
        List("Madagascar", "Gondal", "Zephyria"),
        List("Holy Alliance"),
        List("North American Union"),
        List("Guatemala", "Estonia", "Kolechia"),
        List(),
        List("Myanmar", "Bhutan"),
        List("Equatorial Guinea")
      )
  }

  Spec[Date, RandomList[Long, DateList]] should "generate lists of dates" in {
    val dates = RandomIntegral.dates(df.parse("01/01/2017"), 1000, 1000000)
    val sut = RandomList.ofDates(dates, 11, 22)
    var d0 = 0L
    check[Long, DateList](sut, 11, 22, d => {
      val d1 = d0
      d0 = d
      d > d1
    })
  }

  Spec[DateTimeList, RandomList[Long, DateTimeList]] should "generate lists of datetimes" in {
    val datetimes = RandomIntegral.datetimes(df.parse("01/01/2017"), 1000, 1000000)
    val sut = RandomList.ofDateTimes(datetimes, 11, 22)
    var d0 = 0L
    check[Long, DateTimeList](sut, 11, 22, d => {
      val d1 = d0
      d0 = d
      d > d1
    })
  }

  Spec[UniformGeolocation] should "generate uniformly distributed geolocations" in {
    val sut = RandomList.ofGeolocations
    val segment = sut limit numTries
    segment foreach (_.value.length shouldBe 3)
  }

  Spec[NormalGeolocation] should "generate geolocations around given point" in {
    for {accuracy <- GeolocationAccuracy.values} {
      val geolocation = RandomList.ofGeolocationsNear(37.444136, 122.163160, accuracy)
      val segment = geolocation limit numTries
      segment foreach (_.value.length shouldBe 3)
    }
  }
} 
Example 152
Source File: ScalaStyleValidationTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class ScalaStyleValidationTest extends FlatSpec with Matchers with Assertions {
  import scala.Throwable
  private def +(x: Int, y: Int) = x + y
  private def -(x: Int, y: Int) = x - y
  private def *(x: Int, y: Int) = x * y
  private def /(x: Int, y: Int) = x / y
  private def +-(x: Int, y: Int) = x + (-y)
  private def xx_=(y: Int) = println(s"setting xx to $y")

  "bad names" should "never happen" in {
    "def _=abc = ???" shouldNot compile
    true shouldBe true
  }

  "non-ascii" should "not be allowed" in {
//    "def ⇒ = ???" shouldNot compile // it does not even compile as a string
  }

} 
Example 153
Source File: OpRegressionEvaluatorTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.evaluators

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.classification.OpLogisticRegression
import com.salesforce.op.stages.impl.regression.{OpLinearRegression, RegressionModelSelector}
import com.salesforce.op.stages.impl.selector.ModelSelectorNames.EstimatorType
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpRegressionEvaluatorTest extends FlatSpec with TestSparkContext {

  val (ds, rawLabel, features) = TestFeatureBuilder[RealNN, OPVector](
    Seq(
      (10.0, Vectors.dense(1.0, 4.3, 1.3)),
      (20.0, Vectors.dense(2.0, 0.3, 0.1)),
      (30.0, Vectors.dense(3.0, 3.9, 4.3)),
      (40.0, Vectors.dense(4.0, 1.3, 0.9)),
      (50.0, Vectors.dense(5.0, 4.7, 1.3)),
      (10.0, Vectors.dense(1.0, 4.3, 1.3)),
      (20.0, Vectors.dense(2.0, 0.3, 0.1)),
      (30.0, Vectors.dense(3.0, 3.9, 4.3)),
      (40.0, Vectors.dense(4.0, 1.3, 0.9)),
      (50.0, Vectors.dense(5.0, 4.7, 1.3))
    ).map(v => v._1.toRealNN -> v._2.toOPVector)
  )

  val label = rawLabel.copy(isResponse = true)

  val lr = new OpLogisticRegression()
  val lrParams = new ParamGridBuilder().addGrid(lr.regParam, Array(0.0)).build()

  val testEstimator = RegressionModelSelector.withTrainValidationSplit(dataSplitter = None, trainRatio = 0.5,
    modelsAndParameters = Seq(lr -> lrParams))
    .setInput(label, features)

  val prediction = testEstimator.getOutput()
  val testEvaluator = new OpRegressionEvaluator().setLabelCol(label).setPredictionCol(prediction)

  val testEstimator2 = new OpLinearRegression().setInput(label, features)

  val prediction2 = testEstimator2.getOutput()
  val testEvaluator2 = new OpRegressionEvaluator().setLabelCol(label).setPredictionCol(prediction2)


  Spec[OpRegressionEvaluator] should "copy" in {
    val testEvaluatorCopy = testEvaluator.copy(ParamMap())
    testEvaluatorCopy.uid shouldBe testEvaluator.uid
  }

  it should "evaluate the metrics from a model selector" in {
    val model = testEstimator.fit(ds)
    val transformedData = model.setInput(label, features).transform(ds)
    val metrics = testEvaluator.evaluateAll(transformedData).toMetadata()

    assert(metrics.getDouble(RegressionEvalMetrics.RootMeanSquaredError.toString) <= 1E-12, "rmse should be close to 0")
    assert(metrics.getDouble(RegressionEvalMetrics.MeanSquaredError.toString) <= 1E-24, "mse should be close to 0")
    assert(metrics.getDouble(RegressionEvalMetrics.R2.toString) == 1.0, "R2 should equal 1.0")
    assert(metrics.getDouble(RegressionEvalMetrics.MeanAbsoluteError.toString) <= 1E-12, "mae should be close to 0")
  }

  it should "evaluate the metrics from a single model" in {
    val model = testEstimator2.fit(ds)
    val transformedData = model.setInput(label, features).transform(ds)
    val metrics = testEvaluator2.evaluateAll(transformedData).toMetadata()

    assert(metrics.getDouble(RegressionEvalMetrics.RootMeanSquaredError.toString) <= 1E-12, "rmse should be close to 0")
    assert(metrics.getDouble(RegressionEvalMetrics.MeanSquaredError.toString) <= 1E-24, "mse should be close to 0")
    assert(metrics.getDouble(RegressionEvalMetrics.R2.toString) == 1.0, "R2 should equal 1.0")
    assert(metrics.getDouble(RegressionEvalMetrics.MeanAbsoluteError.toString) <= 1E-12, "mae should be close to 0")
  }
} 
Example 154
Source File: SummaryTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.filters

import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class SummaryTest extends FlatSpec with TestCommon {
  Spec[Summary] should "be correctly created from a sequence of features" in {
    val f1 = Left(Seq("a", "b", "c"))
    val f2 = Right(Seq(0.5, 1.0))
    val f1s = Summary(f1)
    val f2s = Summary(f2)
    f1s.min shouldBe 3
    f1s.max shouldBe 3
    f1s.sum shouldBe 3
    f1s.count shouldBe 1
    f2s.min shouldBe 0.5
    f2s.max shouldBe 1.0
    f2s.sum shouldBe 1.5
    f2s.count shouldBe 2
  }
} 
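The assertions above fully characterise the behaviour being tested: a text feature (Left) contributes its length as a single observation, while a numeric feature (Right) is summarised element-wise. A sketch that reproduces exactly those expectations (an illustration, not TransmogrifAI's Summary implementation):

// Illustrative only: mirrors the min/max/sum/count behaviour asserted in SummaryTest.
case class SummarySketch(min: Double, max: Double, sum: Double, count: Double)

object SummarySketch {
  def apply(values: Either[Seq[String], Seq[Double]]): SummarySketch = values match {
    case Left(strings) =>
      val n = strings.size.toDouble
      SummarySketch(n, n, n, 1.0) // the sequence length counts as one observation
    case Right(nums) if nums.nonEmpty =>
      SummarySketch(nums.min, nums.max, nums.sum, nums.size.toDouble)
    case Right(_) =>
      SummarySketch(0.0, 0.0, 0.0, 0.0)
  }
}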
Example 155
Source File: OpenNLPNameEntityTaggerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.utils.text

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.feature.NameEntityRecognizer
import com.salesforce.op.test.TestCommon
import com.salesforce.op.utils.text.NameEntityType._
import opennlp.tools.util.Span
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpenNLPNameEntityTaggerTest extends FlatSpec with TestCommon {

  val nerTagger = new OpenNLPNameEntityTagger()

  Spec[OpenNLPNameEntityTagger] should "return the consistent results as expected" in {
    val input = Seq(
      "Pierre Vinken, 61 years old, will join the board as a nonexecutive director Nov. 29.",
      "Rudolph Agnew, 55 years old and former chairman of Consolidated Gold Fields PLC, was named a director of this" +
        "a director of this British industrial conglomerate."
    )
    val tokens: Seq[TextList] = input.map(x => NameEntityRecognizer.Analyzer.analyze(x, Language.English).toTextList)
    val expectedOutputs = Seq(
      Map("Vinken" -> Set(Person), "Pierre" -> Set(Person)),
      Map("Agnew" -> Set(Person), "Rudolph" -> Set(Person))
    )
    tokens.zip(expectedOutputs).foreach { case (tokenInput, expected) =>
      nerTagger.tag(tokenInput.value, Language.English, Seq(NameEntityType.Person)).tokenTags shouldEqual expected
    }
  }

  it should "load all the existing name entity recognition models" in {
    val languageNameEntityPairs = Seq(
      (Language.English, NameEntityType.Date),
      (Language.English, NameEntityType.Location),
      (Language.English, NameEntityType.Money),
      (Language.English, NameEntityType.Organization),
      (Language.English, NameEntityType.Percentage),
      (Language.English, NameEntityType.Person),
      (Language.English, NameEntityType.Time),
      (Language.Spanish, NameEntityType.Location),
      (Language.Spanish, NameEntityType.Organization),
      (Language.Spanish, NameEntityType.Person),
      (Language.Spanish, NameEntityType.Misc),
      (Language.Dutch, NameEntityType.Location),
      (Language.Dutch, NameEntityType.Organization),
      (Language.Dutch, NameEntityType.Person),
      (Language.Dutch, NameEntityType.Misc)
    )
    languageNameEntityPairs.foreach { case (l, n) =>
      OpenNLPModels.getTokenNameFinderModel(l, n).isDefined shouldBe true
    }
  }

  it should "not get any model correctly if no such model exists" in {
    val languageNameEntityPairs = Seq(
      (Language.Unknown, NameEntityType.Other),
      (Language.Urdu, NameEntityType.Location)
    )
    languageNameEntityPairs.foreach { case (l, n) =>
      OpenNLPModels.getTokenNameFinderModel(l, n) shouldBe None
    }
  }

  // test the convertSpansToMap function
  it should "retrieve correct information from the output of name entity recognition model" in {
    val inputs = Seq(Array("ab", "xx", "yy", "zz", "ss", "dd", "cc") ->
      Seq(new Span(2, 4, "person"), new Span(3, 5, "location")), // interweaving entities
      Array("a", "b", "c", "d") -> Seq(new Span(3, 4, "location")), // end of sentence entity
      Array("a", "b", "c", "d") -> Seq(new Span(0, 2, "location")), // beginning of sentence entity
      Array("a", "b", "c", "d") -> Seq.empty
    )
    val expectedOutputs = Seq(
      Map("yy" -> Set(Person), "zz" -> Set(Person, Location), "ss" -> Set(Location)),
      Map("d" -> Set(Location)),
      Map("a" -> Set(Location), "b" -> Set(Location)),
      Map.empty[String, Set[String]]
    )

    inputs.zip(expectedOutputs).map { case (tokensInput, expected) =>
      val actual = nerTagger.convertSpansToMap(tokensInput._2, tokensInput._1)
      actual shouldEqual expected
    }
  }

} 
Example 156
Source File: OpenNLPSentenceSplitterTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.utils.text

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.feature.TextTokenizer
import com.salesforce.op.stages.impl.feature.TextTokenizer.TextTokenizerResult
import com.salesforce.op.test.TestCommon
import com.salesforce.op.utils.text.Language._
import opennlp.tools.sentdetect.SentenceModel
import opennlp.tools.tokenize.TokenizerModel
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpenNLPSentenceSplitterTest extends FlatSpec with TestCommon {

  val splitter = new OpenNLPSentenceSplitter()

  Spec[OpenNLPSentenceSplitter] should "split an English paragraph into sentences" in {
    val input =
      "Pierre Vinken, 61 years old, will join the board as a nonexecutive director Nov 29. " +
        "Mr Vinken is chairman of Elsevier N.V., the Dutch publishing group. Rudolph Agnew, 55 years old and " +
        "former chairman of Consolidated Gold Fields PLC, was named a director of this British industrial conglomerate."

    splitter.getSentences(input, language = English) shouldEqual Seq(
      "Pierre Vinken, 61 years old, will join the board as a nonexecutive director Nov 29.",
      "Mr Vinken is chairman of Elsevier N.V., the Dutch publishing group.",
      "Rudolph Agnew, 55 years old and former chairman of Consolidated Gold Fields PLC, " +
        "was named a director of this British industrial conglomerate."
    )

    TextTokenizer.tokenize(input.toText, sentenceSplitter = Option(splitter), defaultLanguage = English) shouldEqual
      TextTokenizerResult(English, Seq(
        Seq("pierr", "vinken", "61", "year", "old", "will", "join", "board",
          "nonexecut", "director", "nov", "29").toTextList,
        Seq("mr", "vinken", "chairman", "elsevi", "n.v", "dutch", "publish", "group").toTextList,
        Seq("rudolph", "agnew", "55", "year", "old", "former", "chairman", "consolid", "gold", "field", "plc",
          "name", "director", "british", "industri", "conglomer").toTextList))

    TextTokenizer.tokenize(input.toText, analyzer = new OpenNLPAnalyzer(), sentenceSplitter = Option(splitter),
      defaultLanguage = English) shouldEqual TextTokenizerResult(
      English, Seq(
        Seq("pierre", "vinken", ",", "61", "years", "old", ",", "will", "join", "the", "board", "as", "a",
          "nonexecutive", "director", "nov", "29", ".").toTextList,
        Seq("mr", "vinken", "is", "chairman", "of", "elsevier", "n", ".v.", ",", "the", "dutch", "publishing",
          "group", ".").toTextList,
        Seq("rudolph", "agnew", ",", "55", "years", "old", "and", "former", "chairman", "of", "consolidated",
          "gold", "fields", "plc", ",", "was", "named", "a", "director", "of", "this", "british", "industrial",
          "conglomerate", ".").toTextList))
  }

  it should "split a Portuguese text into sentences" in {
    // scalastyle:off
    val input = "Depois de Guimarães, o North Music Festival estaciona este ano no Porto. A partir de sexta-feira, " +
      "a Alfândega do Porto recebe a segunda edição deste festival de dois dias. No cartaz há nomes como os " +
      "portugueses Linda Martini e Mão Morta, mas também Guano Apes ou os DJ’s portugueses Rich e Mendes."

    splitter.getSentences(input, language = Portuguese) shouldEqual Seq(
      "Depois de Guimarães, o North Music Festival estaciona este ano no Porto.",
      "A partir de sexta-feira, a Alfândega do Porto recebe a segunda edição deste festival de dois dias.",
      "No cartaz há nomes como os portugueses Linda Martini e Mão Morta, mas também Guano Apes ou os DJ’s " +
        "portugueses Rich e Mendes."
    )
    // scalastyle:on
  }

  it should "load a sentence detection and tokenizer model for a language if they exist" in {
    val languages = Seq(Danish, Portuguese, English, Dutch, German, Sami)
    languages.map { language =>
      OpenNLPModels.getSentenceModel(language).exists(_.isInstanceOf[SentenceModel]) shouldBe true
      OpenNLPModels.getTokenizerModel(language).exists(_.isInstanceOf[TokenizerModel]) shouldBe true
    }
  }

  it should "load not a sentence detection and tokenizer model for a language if they do not exist" in {
    val languages = Seq(Japanese, Czech)
    languages.map { language =>
      OpenNLPModels.getSentenceModel(language) shouldEqual None
      OpenNLPModels.getTokenizerModel(language) shouldEqual None
    }
  }

  it should "return non-preprocessed input if no such a sentence detection model exist" in {
    // scalastyle:off
    val input = "ピエール・ヴィンケン(61歳)は、11月29日に臨時理事に就任します。" +
      "ヴィンケン氏は、オランダの出版グループであるエルゼビアN.V.の会長です。 " +
      "55歳のルドルフ・アグニュー(Rudolph Agnew、元コネチカットゴールドフィールドPLC)会長は、" +
      "この英国の産業大企業の取締役に任命されました。"
    // scalastyle:on
    splitter.getSentences(input, language = Language.Japanese) shouldEqual Seq(input)
  }
} 
Example 157
Source File: JobGroupUtilTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.utils.spark

import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner
import com.salesforce.op.test.{TestCommon, TestSparkContext}

@RunWith(classOf[JUnitRunner])
class JobGroupUtilTest extends FlatSpec with TestCommon with TestSparkContext {

  Spec(JobGroupUtil.getClass) should "be able to set a job group ID around a code block" in {
    JobGroupUtil.withJobGroup(OpStep.DataReadingAndFiltering) {
      spark.sparkContext.parallelize(Seq(1, 2, 3, 4, 5)).collect()
    }
    spark.sparkContext.statusTracker.getJobIdsForGroup("DataReadingAndFiltering") should not be empty
  }

  it should "reset the job group ID after a code block" in {
    JobGroupUtil.withJobGroup(OpStep.DataReadingAndFiltering) {
      spark.sparkContext.parallelize(Seq(1, 2, 3, 4, 5)).collect()
    }
    spark.sparkContext.parallelize(Seq(1, 2, 3, 4, 5)).collect()
    // Ensure that the last `.collect()` was not tagged with "DataReadingAndFiltering"
    spark.sparkContext.statusTracker.getJobIdsForGroup(null) should not be empty
  }
} 
Example 158
Source File: FeatureJsonHelperTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.features

import com.salesforce.op._
import com.salesforce.op.test.{PassengerFeaturesTest, TestCommon}
import org.json4s.MappingException
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class FeatureJsonHelperTest extends FlatSpec with PassengerFeaturesTest with TestCommon {

  trait DifferentParents {
    val feature = height + weight
    val stages = Map(feature.originStage.uid -> feature.originStage)
    val features = Map(height.uid -> height, weight.uid -> weight)
  }

  trait SameParents {
    val feature = height + height
    val stages = Map(feature.originStage.uid -> feature.originStage)
    val features = Map(height.uid -> height, height.uid -> height)
  }

  Spec(FeatureJsonHelper.getClass) should "serialize/deserialize a feature properly" in new DifferentParents {
    val json = feature.toJson()
    val parsedFeature = FeatureJsonHelper.fromJsonString(json, stages, features)
    if (parsedFeature.isFailure) fail(s"Failed to deserialize from json: $json", parsedFeature.failed.get)

    val res = parsedFeature.get
    res shouldBe a[Feature[_]]
    res.equals(feature) shouldBe true
    res.uid shouldBe feature.uid
    res.wtt.tpe =:= feature.wtt.tpe shouldBe true
  }

  it should "deserialize a set of parent features from one reference" in new SameParents {
    val json = feature.toJson()
    val parsedFeature = FeatureJsonHelper.fromJsonString(feature.toJson(), stages, features)
    if (parsedFeature.isFailure) fail(s"Failed to deserialize from json: $json", parsedFeature.failed.get)

    val res = parsedFeature.get
    res.equals(feature) shouldBe true
    res.wtt.tpe =:= feature.wtt.tpe shouldBe true
  }

  it should "fail to deserialize invalid json" in new DifferentParents {
    val res = FeatureJsonHelper.fromJsonString("{}", stages, features)
    res.isFailure shouldBe true
    res.failed.get shouldBe a[MappingException]
  }

  it should "fail when origin stage is not found" in new DifferentParents {
    val res = FeatureJsonHelper.fromJsonString(feature.toJson(), stages = Map.empty, features)
    res.isFailure shouldBe true
    res.failed.get shouldBe a[RuntimeException]
  }

  it should "fail when not all parents are found" in new DifferentParents {
    val res = FeatureJsonHelper.fromJsonString(feature.toJson(), stages, features = Map.empty)
    res.isFailure shouldBe true
    res.failed.get shouldBe a[RuntimeException]
  }


} 
Example 159
Source File: RandomIntegralTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.testkit

import java.text.SimpleDateFormat

import com.salesforce.op.features.types._
import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec}

import scala.language.postfixOps

@RunWith(classOf[JUnitRunner])
class RandomIntegralTest extends FlatSpec with TestCommon with Assertions {
  private val numTries = 10000
  private val rngSeed = 314159214142135L

  private def check[T <: Integral](
    g: RandomIntegral[T],
    predicate: Long => Boolean = _ => true
  ) = {
    g reset rngSeed

    def segment = g limit numTries

    val numberOfEmpties = segment count (_.isEmpty)

    val expectedNumberOfEmpties = g.probabilityOfEmpty * numTries

    withClue(s"numEmpties = $numberOfEmpties, expected $expectedNumberOfEmpties") {
      math.abs(numberOfEmpties - expectedNumberOfEmpties) < 2 * math.sqrt(numTries) shouldBe true
    }

    val maybeValues = segment filterNot (_.isEmpty) map (_.value)
    val values = maybeValues collect { case Some(s) => s }

    values foreach (x => predicate(x) shouldBe true)

    withClue(s"number of distinct values = ${values.size}, expected:") {
      math.abs(maybeValues.size - values.toSet.size) < maybeValues.size / 20
    }

  }

  private val df = new SimpleDateFormat("dd/MM/yy")

  Spec[RandomIntegral[Integral]] should "generate empties and distinct numbers" in {
    val sut0 = RandomIntegral.integrals
    val sut = sut0.withProbabilityOfEmpty(0.3)
    check(sut)
    sut.probabilityOfEmpty shouldBe 0.3
  }

  Spec[RandomIntegral[Integral]] should "generate empties and distinct numbers in some range" in {
    val sut0 = RandomIntegral.integrals(100, 200)
    val sut = sut0.withProbabilityOfEmpty(0.3)
    check(sut, i => i >= 100 && i < 200)
    sut.probabilityOfEmpty shouldBe 0.3
  }

  Spec[RandomIntegral[Date]] should "generate dates" in {
    val sut = RandomIntegral.dates(df.parse("01/01/2017"), 1000, 1000000)
    var d0 = 0L
    check(sut withProbabilityOfEmpty 0.01, d => {
      val d1 = d0
      d0 = d
      d0 > d1
    })
  }

  Spec[RandomIntegral[DateTime]] should "generate dates with times" in {
    val sut = RandomIntegral.datetimes(df.parse("08/24/2017"), 1000, 1000000)
    var d0 = 0L
    check(sut withProbabilityOfEmpty 0.001, d => {
      val d1 = d0
      d0 = d
      d0 > d1
    })
  }
} 
Example 160
Source File: OpTransformerSequenceReaderWriterTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages

import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.sequence.SequenceLambdaTransformer
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OpTransformerSequenceReaderWriterTest extends OpPipelineStageReaderWriterTest {
  override val expectedFeaturesLength = 1
  override val hasOutputName = false

  val stage =
    new SequenceLambdaTransformer[DateList, Real](
      operationName = "test",
      transformFn = new Lambdas.FncSequence,
      uid = "uid_1234"
    ).setInput(boarded).setMetadata(meta)

  val expected = Array(2942.toReal, 1471.toReal, 0.toReal, 1471.toReal, 1471.toReal, 1471.toReal)
} 
Example 161
Source File: OpTransformerBinarySequenceReaderWriterTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages

import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.sequence.BinarySequenceLambdaTransformer
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OpTransformerBinarySequenceReaderWriterTest extends OpPipelineStageReaderWriterTest {
  override val expectedFeaturesLength = 2
  override val hasOutputName = false

  val stage =
    new BinarySequenceLambdaTransformer[Real, DateList, Real](
      operationName = "test",
      transformFn = new Lambdas.FncBinarySequence
    ).setInput(weight, boarded).setMetadata(meta)

  val expected = Array(3114.toReal, 1538.toReal, 0.toReal, 1549.toReal, 1567.toReal, 1538.toReal)
} 
Example 162
Source File: OpRegressionModelTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.regression

import com.salesforce.op.features.types.{Prediction, RealNN}
import com.salesforce.op.stages.sparkwrappers.specific.SparkModelConverter.toOP
import com.salesforce.op.test._
import com.salesforce.op.testkit._
import ml.dmlc.xgboost4j.scala.spark.{OpXGBoost, OpXGBoostQuietLogging, XGBoostRegressor}
import org.apache.spark.ml.regression._
import org.apache.spark.sql.DataFrame
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpRegressionModelTest extends FlatSpec with TestSparkContext with OpXGBoostQuietLogging {

  private val label = RandomIntegral.integrals(0, 2).limit(1000)
    .map{ v => RealNN(v.value.map(_.toDouble).getOrElse(0.0)) }
  private val fv = RandomVector.binary(10, 0.3).limit(1000)

  private val data = label.zip(fv)

  private val (rawDF, labelF, featureV) = TestFeatureBuilder("label", "features", data)

  Spec[OpDecisionTreeRegressionModel] should "produce the same values as the spark version" in {
    val spk = new DecisionTreeRegressor()
      .setFeaturesCol(featureV.name)
      .setLabelCol(labelF.name)
      .fit(rawDF)

    val op = toOP(spk, spk.uid).setInput(labelF, featureV)
    compareOutputs(spk.transform(rawDF), op.transform(rawDF))
  }

  Spec[OpLinearRegressionModel] should "produce the same values as the spark version" in {
    val spk = new LinearRegression()
      .setFeaturesCol(featureV.name)
      .setLabelCol(labelF.name)
      .fit(rawDF)

    val op = toOP(spk, spk.uid).setInput(labelF, featureV)
    compareOutputs(spk.transform(rawDF), op.transform(rawDF))
  }

  Spec[OpGBTRegressionModel] should "produce the same values as the spark version" in {
    val spk = new GBTRegressor()
      .setFeaturesCol(featureV.name)
      .setLabelCol(labelF.name)
      .fit(rawDF)

    val op = toOP(spk, spk.uid).setInput(labelF, featureV)
    compareOutputs(spk.transform(rawDF), op.transform(rawDF))
  }

  Spec[OpRandomForestRegressionModel] should "produce the same values as the spark version" in {
    val spk = new RandomForestRegressor()
      .setFeaturesCol(featureV.name)
      .setLabelCol(labelF.name)
      .fit(rawDF)

    val op = toOP(spk, spk.uid).setInput(labelF, featureV)
    compareOutputs(spk.transform(rawDF), op.transform(rawDF))
  }

  Spec[OpGeneralizedLinearRegressionModel] should "produce the same values as the spark version" in {
    val spk = new GeneralizedLinearRegression()
      .setFeaturesCol(featureV.name)
      .setLabelCol(labelF.name)
      .fit(rawDF)

    val op = toOP(spk, spk.uid).setInput(labelF, featureV)
    compareOutputs(spk.transform(rawDF), op.transform(rawDF))
  }

  Spec[OpXGBoostRegressionModel] should "produce the same values as the spark version" in {
    val reg = new XGBoostRegressor()
    reg.set(reg.trackerConf, OpXGBoost.DefaultTrackerConf)
      .setFeaturesCol(featureV.name)
      .setLabelCol(labelF.name)
    val spk = reg.fit(rawDF)

    val op = toOP(spk, spk.uid).setInput(labelF, featureV)
    compareOutputs(spk.transform(rawDF), op.transform(rawDF))
  }

  def compareOutputs(df1: DataFrame, df2: DataFrame): Unit = {
    val sorted1 = df1.collect().sortBy(_.getAs[Double](2))
    val sorted2 = df2.collect().sortBy(_.getAs[Map[String, Double]](2)(Prediction.Keys.PredictionName))
    sorted1.zip(sorted2).foreach{ case (r1, r2) =>
      val map = r2.getAs[Map[String, Double]](2)
      r1.getAs[Double](2) shouldEqual map(Prediction.Keys.PredictionName)
    }
  }
} 
Example 163
Source File: OpXGBoostRegressorTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.regression

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test._
import ml.dmlc.xgboost4j.scala.spark.{OpXGBoostQuietLogging, XGBoostRegressionModel, XGBoostRegressor}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OpXGBoostRegressorTest extends OpEstimatorSpec[Prediction, OpPredictorWrapperModel[XGBoostRegressionModel],
  OpPredictorWrapper[XGBoostRegressor, XGBoostRegressionModel]]
  with PredictionEquality with OpXGBoostQuietLogging {

  override def specName: String = Spec[OpXGBoostRegressor]

  val rawData = Seq(
    (10.0, Vectors.dense(1.0, 4.3, 1.3)),
    (20.0, Vectors.dense(2.0, 0.3, 0.1)),
    (30.0, Vectors.dense(3.0, 3.9, 4.3)),
    (40.0, Vectors.dense(4.0, 1.3, 0.9)),
    (50.0, Vectors.dense(5.0, 4.7, 1.3))
  ).map { case (l, v) => l.toRealNN -> v.toOPVector }

  val (inputData, label, features) = TestFeatureBuilder("label", "features", rawData)

  val estimator = new OpXGBoostRegressor().setInput(label.copy(isResponse = true), features)
  estimator.setSilent(1)

  val expectedResult = Seq(
    Prediction(1.9250000715255737),
    Prediction(8.780000686645508),
    Prediction(8.780000686645508),
    Prediction(8.780000686645508),
    Prediction(8.780000686645508)
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator.setMaxDepth(18).setBaseScore(0.12345).setSkipDrop(0.6234)
    estimator.fit(inputData)
    estimator.predictor.getMaxDepth shouldBe 18
    estimator.predictor.getBaseScore shouldBe 0.12345
    estimator.predictor.getSkipDrop shouldBe 0.6234

  }
} 
Example 164
Source File: OpGBTRegressorTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.regression

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test._
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpGBTRegressorTest extends OpEstimatorSpec[Prediction, OpPredictorWrapperModel[GBTRegressionModel],
  OpPredictorWrapper[GBTRegressor, GBTRegressionModel]] with PredictionEquality {

  override def specName: String = Spec[OpGBTRegressor]

  val (inputData, rawLabel, features) = TestFeatureBuilder(
    Seq[(RealNN, OPVector)](
      (10.0.toRealNN, Vectors.dense(1.0, 4.3, 1.3).toOPVector),
      (20.0.toRealNN, Vectors.dense(2.0, 0.3, 0.1).toOPVector),
      (30.0.toRealNN, Vectors.dense(3.0, 3.9, 4.3).toOPVector),
      (40.0.toRealNN, Vectors.dense(4.0, 1.3, 0.9).toOPVector),
      (50.0.toRealNN, Vectors.dense(5.0, 4.7, 1.3).toOPVector)
    )
  )
  val label = rawLabel.copy(isResponse = true)
  val estimator = new OpGBTRegressor().setInput(label, features)

  val expectedResult = Seq(
    Prediction(10.0),
    Prediction(20.0),
    Prediction(30.0),
    Prediction(40.0),
    Prediction(50.0)
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator
      .setMaxIter(10)
      .setMaxDepth(6)
      .setMaxBins(2)
      .setMinInstancesPerNode(2)
      .setMinInfoGain(0.1)
    estimator.fit(inputData)

    estimator.predictor.getMaxIter shouldBe 10
    estimator.predictor.getMaxDepth shouldBe 6
    estimator.predictor.getMaxBins shouldBe 2
    estimator.predictor.getMinInstancesPerNode shouldBe 2
    estimator.predictor.getMinInfoGain shouldBe 0.1
  }
} 
Example 165
Source File: OpRandomForestRegressorTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.regression

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test._
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpRandomForestRegressorTest extends OpEstimatorSpec[Prediction,
  OpPredictorWrapperModel[RandomForestRegressionModel],
  OpPredictorWrapper[RandomForestRegressor, RandomForestRegressionModel]] with PredictionEquality {

  override def specName: String = Spec[OpRandomForestRegressor]

  val (inputData, rawLabel, features) = TestFeatureBuilder(
    Seq[(RealNN, OPVector)](
      (10.0.toRealNN, Vectors.dense(1.0, 4.3, 1.3).toOPVector),
      (20.0.toRealNN, Vectors.dense(2.0, 0.3, 0.1).toOPVector),
      (30.0.toRealNN, Vectors.dense(3.0, 3.9, 4.3).toOPVector),
      (40.0.toRealNN, Vectors.dense(4.0, 1.3, 0.9).toOPVector),
      (50.0.toRealNN, Vectors.dense(5.0, 4.7, 1.3).toOPVector)
    )
  )
  val label = rawLabel.copy(isResponse = true)
  val estimator = new OpRandomForestRegressor().setInput(label, features)

  val expectedResult = Seq(
    Prediction(20.0),
    Prediction(23.5),
    Prediction(31.5),
    Prediction(35.5),
    Prediction(37.0)
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator
      .setMaxDepth(7)
      .setMaxBins(3)
      .setMinInstancesPerNode(2)
      .setMinInfoGain(0.1)
      .setSeed(42L)
    estimator.fit(inputData)

    estimator.predictor.getMaxDepth shouldBe 7
    estimator.predictor.getMaxBins shouldBe 3
    estimator.predictor.getMinInstancesPerNode shouldBe 2
    estimator.predictor.getMinInfoGain shouldBe 0.1
    estimator.predictor.getSeed shouldBe 42L

  }
} 
Example 166
Source File: OpDecisionTreeRegressorTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.regression

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test._
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpDecisionTreeRegressorTest extends OpEstimatorSpec[Prediction,
  OpPredictorWrapperModel[DecisionTreeRegressionModel],
  OpPredictorWrapper[DecisionTreeRegressor, DecisionTreeRegressionModel]] with PredictionEquality {

  override def specName: String = Spec[OpDecisionTreeRegressor]

  val (inputData, rawLabel, features) = TestFeatureBuilder(
    Seq[(RealNN, OPVector)](
      (10.0.toRealNN, Vectors.dense(1.0, 4.3, 1.3).toOPVector),
      (20.0.toRealNN, Vectors.dense(2.0, 0.3, 0.1).toOPVector),
      (30.0.toRealNN, Vectors.dense(3.0, 3.9, 4.3).toOPVector),
      (40.0.toRealNN, Vectors.dense(4.0, 1.3, 0.9).toOPVector),
      (50.0.toRealNN, Vectors.dense(5.0, 4.7, 1.3).toOPVector)
    )
  )
  val label = rawLabel.copy(isResponse = true)
  val estimator = new OpDecisionTreeRegressor().setInput(label, features)

  val expectedResult = Seq(
    Prediction(10.0),
    Prediction(20.0),
    Prediction(30.0),
    Prediction(40.0),
    Prediction(50.0)
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator
      .setMaxDepth(6)
      .setMaxBins(2)
      .setMinInstancesPerNode(2)
      .setMinInfoGain(0.1)
    estimator.fit(inputData)

    estimator.predictor.getMaxDepth shouldBe 6
    estimator.predictor.getMaxBins shouldBe 2
    estimator.predictor.getMinInstancesPerNode shouldBe 2
    estimator.predictor.getMinInfoGain shouldBe 0.1

  }
} 
Example 167
Source File: OpLinearRegressionTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.regression

import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.binary.{BinaryEstimator, BinaryModel}
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test._
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpLinearRegressionTest extends OpEstimatorSpec[Prediction, OpPredictorWrapperModel[LinearRegressionModel],
  OpPredictorWrapper[LinearRegression, LinearRegressionModel]] with PredictionEquality {

  override def specName: String = Spec[OpLinearRegression]

  val (inputData, rawLabel, features) = TestFeatureBuilder(
    Seq[(RealNN, OPVector)](
      (10.0.toRealNN, Vectors.dense(1.0, 4.3, 1.3).toOPVector),
      (20.0.toRealNN, Vectors.dense(2.0, 0.3, 0.1).toOPVector),
      (30.0.toRealNN, Vectors.dense(3.0, 3.9, 4.3).toOPVector),
      (40.0.toRealNN, Vectors.dense(4.0, 1.3, 0.9).toOPVector),
      (50.0.toRealNN, Vectors.dense(5.0, 4.7, 1.3).toOPVector)
    )
  )
  val label = rawLabel.copy(isResponse = true)
  val estimator = new OpLinearRegression().setInput(label, features)

  val expectedResult = Seq(
    Prediction(10.0),
    Prediction(20.0),
    Prediction(30.0),
    Prediction(40.0),
    Prediction(50.0)
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator
      .setMaxIter(10)
      .setRegParam(0.1)
      .setFitIntercept(true)
      .setElasticNetParam(0.1)
      .setSolver("normal")
    estimator.fit(inputData)

    estimator.predictor.getMaxIter shouldBe 10
    estimator.predictor.getRegParam shouldBe 0.1
    estimator.predictor.getFitIntercept shouldBe true
    estimator.predictor.getElasticNetParam shouldBe 0.1
    estimator.predictor.getSolver shouldBe "normal"

  }
} 
Example 168
Source File: OpGeneralizedLinearRegressionTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.regression

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test._
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.{GeneralizedLinearRegression, GeneralizedLinearRegressionModel}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpGeneralizedLinearRegressionTest extends OpEstimatorSpec[Prediction,
  OpPredictorWrapperModel[GeneralizedLinearRegressionModel],
  OpPredictorWrapper[GeneralizedLinearRegression, GeneralizedLinearRegressionModel]] with PredictionEquality {

  override def specName: String = Spec[OpGeneralizedLinearRegression]

  val (inputData, rawLabel, features) = TestFeatureBuilder(
    Seq[(RealNN, OPVector)](
      (10.0.toRealNN, Vectors.dense(1.0, 4.3, 1.3).toOPVector),
      (20.0.toRealNN, Vectors.dense(2.0, 0.3, 0.1).toOPVector),
      (30.0.toRealNN, Vectors.dense(3.0, 3.9, 4.3).toOPVector),
      (40.0.toRealNN, Vectors.dense(4.0, 1.3, 0.9).toOPVector),
      (50.0.toRealNN, Vectors.dense(5.0, 4.7, 1.3).toOPVector)
    )
  )
  val label = rawLabel.copy(isResponse = true)
  val estimator = new OpGeneralizedLinearRegression().setInput(label, features)

  val expectedResult = Seq(
    Prediction(10.0, 9.99),
    Prediction(20.0, 19.99),
    Prediction(30.0, 29.99),
    Prediction(40.0, 40.0),
    Prediction(50.0, 50.0)
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator
      .setMaxIter(10)
      .setRegParam(0.1)
      .setFitIntercept(true)
      .setTol(1E-4)
      .setSolver("irls")
    estimator.fit(inputData)

    estimator.predictor.getMaxIter shouldBe 10
    estimator.predictor.getRegParam shouldBe 0.1
    estimator.predictor.getFitIntercept shouldBe true
    estimator.predictor.getTol shouldBe 1E-4
    estimator.predictor.getSolver shouldBe "irls"

  }
} 
Example 169
Source File: EchoCommandTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase.commands

import com.sumologic.shellbase.CommonWordSpec
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class EchoCommandTest extends CommonWordSpec {

  "EchoCommand" should {
    "work with 0 params" in {
      new EchoCommand().executeLine(List.empty) should be(true)
    }

    "work with multiple params" in {
      new EchoCommand().executeLine(List("a", "b", "c", "d", "e")) should be(true)
    }
  }

} 
Example 170
Source File: KillableSingleThreadTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase.interrupts

import com.sumologic.shellbase.CommonWordSpec
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

import scala.concurrent.Await
import scala.concurrent.duration.Duration

@RunWith(classOf[JUnitRunner])
class KillableSingleThreadTest extends CommonWordSpec {
  "KillableSingleThread" should {
    "provide the result of a thread" in {
      val sut = new KillableSingleThread(
        "hello"
      )

      sut.start()

      Await.result(sut.future, Duration.apply(10, "seconds")) should be("hello")
    }

    "let you kill a stuck thread" in {
      val sut = new KillableSingleThread(
        while (true) {
          Thread.sleep(500)
        }
      )

      sut.start()

      sut.kill(Duration.apply(500, "ms"))
    }
  }

} 
Example 171
Source File: ShellColorsTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ShellColorsTest extends CommonWordSpec {
  "ShellColors" should {
    "not error for each color and contain original text" in {
      ShellColors.black("test") should include("test")
      ShellColors.blue("test") should include("test")
      ShellColors.cyan("test") should include("test")
      ShellColors.green("test") should include("test")
      ShellColors.magenta("test") should include("test")
      ShellColors.red("test") should include("test")
      ShellColors.white("test") should include("test")
      ShellColors.yellow("test") should include("test")
    }
  }
} 
Example 172
Source File: ShellBannerTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ShellBannerTest extends CommonWordSpec {
  "ShellBanner" should {
    "load banners from resources" in {
      println("Eyeball it:")
      println(ShellBanner.Warning)

      ShellBanner.Warning.contains("_///////") should be(true)
    }
  }
} 
Example 173
Source File: ShellCommandAliasTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase

import com.sumologic.shellbase.cmdline.RichCommandLine._
import com.sumologic.shellbase.cmdline.{CommandLineOption, RichCommandLine}
import org.apache.commons.cli.{CommandLine, Options}
import org.junit.runner.RunWith
import org.mockito.Mockito._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ShellCommandAliasTest extends CommonWordSpec {
  "ShellCommandAlias" should {
    "execute original command" in {
      val cmd = new DummyCommand("almond")
      val subtree = new ShellCommandAlias(cmd, "better_almond", List())
      subtree.execute(null)

      cmd.executed should be(true)
    }

    "have same options" in {
      val option = new CommandLineOption("e", "example", true, "An example in test")
      val cmd = new ShellCommand("almond", "Just for testing") {

        def execute(cmdLine: CommandLine) = {
          cmdLine.get(option).isDefined
        }

        override def addOptions(opts: Options) {
          opts.addOption(option)
        }
      }
      val subtree = new ShellCommandAlias(cmd, "alias", List())

      subtree.parseOptions(List("--example", "a")).hasOption("example") should be(true)
      cmd.parseOptions(List("--example", "a")).hasOption("example") should be(true)
      val cmdLine = mock(classOf[CommandLine])
      when(cmdLine.hasOption("example")).thenReturn(true)
      when(cmdLine.getOptionValue("example")).thenReturn("return")
      subtree.execute(cmdLine)
    }
  }
} 
Example 174
Source File: ShellStringSupportTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase

import com.sumologic.shellbase.ShellStringSupport._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ShellStringSupportTest extends CommonWordSpec {
  "ShellStringSupport" should {
    "calculate the visible length of a string without escapes" in {
      val s = "a string"
      s.visibleLength should be(8)
    }

    "calculate the visible length of a string with escapes" in {
      val s = ShellColors.red("a red string")
      s.visibleLength should be(12)
    }

    "trim a string with escapes" in {
      val s = s"a ${ShellColors.red("red")} string"
      s.escapedTrim(6) should be(s"a ${ShellColors.red("red")} ")
    }
  }
} 
Example 175
Source File: ScriptRendererSpec.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase

import java.io.File

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ScriptRendererSpec extends CommonWordSpec {
  "ScriptRenderer" should {
    "convert the arguments to a map" in {
      val command = new ScriptRenderer(null, Array("key=value", "key1=value1"))
      val props = command.argsToMap
      props("key") should be("value")
      props("key1") should be("value1")
    }
    "get the lines of non velocity script" in {
      val parser = new ScriptRenderer(new File("src/test/resources/scripts/novelocity"), Array[String]())
      val lines: Seq[String] = parser.getLines
      lines should contain("do something")
      lines should contain("exit")
    }
    "get the lines of velocity script with keys replaced with values" in {
      val parser = new ScriptRenderer(new File("src/test/resources/scripts/velocity"), Array("key1=value1", "key2=value2"))
      val lines: Seq[String] = parser.getLines
      lines should contain("do something value1")
      lines should contain("do something value2")
      lines should contain("exit")
    }
  }
} 
Example 176
Source File: TimeFormatsTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase.timeutil

import com.sumologic.shellbase.CommonWordSpec
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TimeFormatsTest extends CommonWordSpec {
  import TimeFormats._

  "TimeFormats.parseTersePeriod" should {
    "return a time" when {
      "a time in millis is passed" in {
        parseTersePeriod("1234") should be (Some(1234))
      }

      "a terse period is passed" in {
        val seconds30 = 30 * 1000
        val minutes2 = seconds30 * 4
        parseTersePeriod("2m30s") should be (Some(minutes2 + seconds30))
      }
    }

    "return None" when {
      "null is passed" in {
        parseTersePeriod(null) should be (None)
      }

      "empty string is passed" in {
        parseTersePeriod("") should be (None)
      }

      "unparsable is passed" in {
        parseTersePeriod("humbug") should be (None)
      }
    }
  }

  "TimeFormats.formatAsTersePeriod" should {
    "return 0 when span is 0" in {
      formatAsTersePeriod(0) should be ("0")
    }

    "format as milliseconds when span is less than 1000" in {
      formatAsTersePeriod(100) should be ("100ms")
      formatAsTersePeriod(-100) should be ("-100ms")
      formatAsTersePeriod(1) should be ("1ms")
    }

    "format as seconds when span is 1000" in {
      formatAsTersePeriod(1000) should be ("1s")
    }

    "format properly for different values" in {
      formatAsTersePeriod(60*1000) should be ("1m")
      formatAsTersePeriod(60*60*1000) should be ("1h")
      formatAsTersePeriod(24*60*60*1000) should be ("1d")
      formatAsTersePeriod(7*24*60*60*1000) should be ("7d")

      val millis_3d6h5m4s10ms = 3*24*60*60*1000 + 6*60*60*1000 + 5*60*1000 + 4*1000 + 10
      formatAsTersePeriod(millis_3d6h5m4s10ms) should be ("3d6h5m4s")
    }

    "not fail when time span is large" in {
      formatAsTersePeriod(Long.MaxValue - 1) should be (errorString)
    }
  }
} 
Example 177
Source File: TimedBlockTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase.timeutil

import com.sumologic.shellbase.CommonWordSpec
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TimedBlockTest extends CommonWordSpec {

  "TimedBlock" should {
    "return the function result" in {
      TimedBlock("test") {
        "hello"
      } should be ("hello")
    }

    "bubble the exception" in {
      class VerySpecificException extends Exception
      intercept[VerySpecificException] {
        TimedBlock("test") {
          throw new VerySpecificException
        }
      }
    }

    "support other writers" in {
      var storageString = ""
      def recordString(str: String) = storageString += str
      TimedBlock("test", recordString) {
        "test"
      }

      storageString should not be ('empty)
    }
  }

} 
Example 178
Source File: TeeCommandTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase.commands

import java.nio.charset.Charset
import java.nio.file.{Files, Path}

import com.sumologic.shellbase.CommonWordSpec
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

import scala.collection.JavaConverters._
import scala.util.Random

@RunWith(classOf[JUnitRunner])
class TeeCommandTest extends CommonWordSpec {
  "TeeCommand" should {
    "execute a subcommand and propagate exit code" in {
      var calls = 0
      def callCheck(ret: Boolean)(input: String): Boolean = {
        input should be("hi")
        calls += 1
        ret
      }

      new TeeCommand(callCheck(true)).executeLine(List("`hi`", "-o", getTempFilePath().toString)) should be(true)
      calls should be(1)

      new TeeCommand(callCheck(false)).executeLine(List("`hi`", "-o", getTempFilePath().toString)) should be(false)
      calls should be(2)
    }

    "degrade nicely with malformatted input" in {
      new TeeCommand(_ => true).executeLine(List.empty) should be(false)
      new TeeCommand(_ => true).executeLine(List("test")) should be(false)
    }

    "write output to file, and support append mode" in {
      def printMessage(str: String): Boolean = {
        println(str)
        true
      }

      val tempFile = getTempFilePath()
      new TeeCommand(printMessage).executeLine(List("`hi mom`", "-o", tempFile.toString))
      // The first line is the debug line, so everything after is logged
      readTempFile(tempFile) should be(List("hi mom"))

      // We should overwrite since not in append mode
      new TeeCommand(printMessage).executeLine(List("`hi mom 2`", "-o", tempFile.toString))
      // The first line is the debug line, so everything after is logged
      readTempFile(tempFile) should be(List("hi mom 2"))

      // We have both 2 and 3 since in append mode
      new TeeCommand(printMessage).executeLine(List("`hi mom 3`", "-o", tempFile.toString, "-a"))
      // The first line is the debug line, so everything after is logged
      readTempFile(tempFile) should be(List("hi mom 2", "hi mom 3"))
    }


  }

  private def getTempFilePath(): Path = {
    Files.createTempFile("teecommand", ".tmp")
  }

  private def readTempFile(path: Path): List[String] = {
    Files.readAllLines(path, Charset.defaultCharset()).asScala.filterNot(_.startsWith("Running")).toList
  }

} 
Example 179
Source File: ExitCommandTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase.commands

import com.sumologic.shellbase.CommonWordSpec
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ExitCommandTest extends CommonWordSpec {
  "ExitCommand" should {
    "exit using the method given to it" in {
      var called = false
      def callCapture(given: Int): Unit = {
        called = true
        given should be(0)
      }

      new ExitCommand(callCapture).executeLine(List.empty) should be(true)

      called should be(true)
    }
  }

} 
Example 180
Source File: MinVarianceFilterMetadataTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.preparators

import com.salesforce.op.stages.impl.preparators.MinVarianceSummary.statisticsFromMetadata
import com.salesforce.op.test.TestSparkContext
import com.salesforce.op.utils.spark.RichMetadata._
import org.apache.spark.sql.types.Metadata
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class MinVarianceFilterMetadataTest extends FlatSpec with TestSparkContext {

  val summary = MinVarianceSummary(
    dropped = Seq("f1"),
    featuresStatistics = SummaryStatistics(3, 0.01, Seq(0.1, 0.2, 0.3), Seq(0.1, 0.2, 0.3),
      Seq(0.1, 0.2, 0.3), Seq(0.1, 0.2, 0.3)),
    names = Seq("f1", "f2", "f3")
  )

  Spec[MinVarianceSummary] should "convert to and from metadata correctly" in {
    val meta = summary.toMetadata()
    meta.isInstanceOf[Metadata] shouldBe true

    val retrieved = MinVarianceSummary.fromMetadata(meta)
    retrieved.isInstanceOf[MinVarianceSummary]

    retrieved.dropped should contain theSameElementsAs summary.dropped
    retrieved.featuresStatistics.count shouldBe summary.featuresStatistics.count
    retrieved.featuresStatistics.max should contain theSameElementsAs summary.featuresStatistics.max
    retrieved.featuresStatistics.min should contain theSameElementsAs summary.featuresStatistics.min
    retrieved.featuresStatistics.mean should contain theSameElementsAs summary.featuresStatistics.mean
    retrieved.featuresStatistics.variance should contain theSameElementsAs summary.featuresStatistics.variance
    retrieved.names should contain theSameElementsAs summary.names
  }

  it should "convert to and from JSON and give the same values" in {
    val meta = summary.toMetadata()
    val json = meta.wrapped.prettyJson
    val recovered = Metadata.fromJson(json).wrapped
    val dropped = recovered.getArray[String](MinVarianceNames.Dropped).toSeq
    val stats = statisticsFromMetadata(recovered.get[Metadata](MinVarianceNames.FeaturesStatistics))
    val names = recovered.getArray[String](MinVarianceNames.Names).toSeq

    dropped should contain theSameElementsAs summary.dropped
    stats.count shouldBe summary.featuresStatistics.count
    stats.max should contain theSameElementsAs summary.featuresStatistics.max
    stats.min should contain theSameElementsAs summary.featuresStatistics.min
    stats.mean should contain theSameElementsAs summary.featuresStatistics.mean
    stats.variance should contain theSameElementsAs summary.featuresStatistics.variance
    names should contain theSameElementsAs summary.names
  }

} 
Example 181
Source File: SleepCommandTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase.commands

import com.sumologic.shellbase.CommonWordSpec
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class SleepCommandTest extends CommonWordSpec {

  "SleepCommand" should {
    "let you sleep a bit" in {
      new SleepCommand().executeLine(List("30")) should be(true)
    }

    "inform you about missing arguments for sleep" in {
      new SleepCommand().executeLine(List.empty) should be(false)
    }

    "work in verbose mode" in {
      new SleepCommand().executeLine(List("-v", "30")) should be(true)
    }
  }

} 
Example 182
Source File: RunScriptCommandTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase.commands

import java.io.File

import com.sumologic.shellbase.CommonWordSpec
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class RunScriptCommandTest extends CommonWordSpec {
  "RunScriptCommand" should {
    "handle non-existent script files" in {
      val sut = new RunScriptCommand(List(scriptsDir), null, runCommand = _ => throw new Exception("Should not get here"))
      sut.executeLine(List.empty) should be (false)

      sut.executeLine(List("does_not_exist")) should be (false)
    }

    "run scripts with short names" in {
      import java.nio.file.Files
      import java.io.File

      def from(n: Int): Stream[Int] = n #:: from(n + 1)

      val nats: Seq[Int] = from(0)

      val tmpDir = new File(System.getProperty("java.io.tmpdir"))
      val tmpFileOpt = nats.take(100).map(n => new File(tmpDir, n.toString)).find(f => !f.exists())

      tmpFileOpt match {
        case Some(tmpFile) =>
          tmpFile.deleteOnExit()
          Files.write(tmpFile.toPath, "echo hello".getBytes())

          val sut = new RunScriptCommand(List(scriptsDir), "", runCommand = _ => true)
          sut.executeLine(List(tmpFile.getAbsolutePath)) should be(true)

        case None => fail("Can't create unique tmp file")
      }
    }

    // FIXME(chris, 2016-05-25): These tests do not pass as the path resolution ignores scriptDir.  I need to fix that
    // first, but it's a larger task than I want to do right this second.  Additionally, until we unify that code, we should
    // skip writing tests for auto-complete.

    "accept either scripts or the parent of scripts dir" ignore {
      val sut1 = new RunScriptCommand(List(scriptsDir), null, runCommand = inputShouldBeSimple(true))
      sut1.executeLine(List("simple")) should be (true)

      val sut2 = new RunScriptCommand(List(scriptsDir.getParentFile), null, runCommand = inputShouldBeSimple(true))
      sut2.executeLine(List("simple")) should be (true)
    }

    "return the status of runCommand" ignore {
      val sut1 = new RunScriptCommand(List(scriptsDir), null, runCommand = inputShouldBeSimple(false))
      sut1.executeLine(List("simple")) should be (false)

      val sut2 = new RunScriptCommand(List(scriptsDir), null, runCommand = inputShouldBeSimple(true))
      sut2.executeLine(List("simple")) should be (true)
    }
  }

  private def inputShouldBeSimple(ret: Boolean = true)(cmd: String): Boolean = {
    cmd should be ("hi")
    ret
  }

  private val scriptsDir = new File("src/test/resources/scripts")

} 
Example 183
Source File: TimeCommandTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase.commands

import com.sumologic.shellbase.CommonWordSpec
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TimeCommandTest extends CommonWordSpec {
  "TimeCommand" should {
    "execute a subcommand and propagate exit code" in {
      var calls = 0
      def callCheck(ret: Boolean)(input: String): Boolean = {
        input should be("hi")
        calls += 1
        ret
      }

      new TimeCommand(callCheck(true)).executeLine(List("`hi`")) should be(true)
      calls should be(1)

      new TimeCommand(callCheck(false)).executeLine(List("`hi`")) should be(false)
      calls should be(2)
    }

    "degrade nicely with malformatted input" in {
      new TimeCommand(_ => true).executeLine(List.empty) should be(false)
      new TimeCommand(_ => true).executeLine(List("test")) should be(false)
    }
  }

} 
Example 184
Source File: InMemoryShellNotificationManagerTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase.notifications

import com.sumologic.shellbase.CommonWordSpec
import org.junit.runner.RunWith
import org.mockito.Mockito._
import org.scalatest.BeforeAndAfterEach
import org.scalatest.junit.JUnitRunner
import org.scalatest.mock.MockitoSugar

@RunWith(classOf[JUnitRunner])
class InMemoryShellNotificationManagerTest extends CommonWordSpec with BeforeAndAfterEach with MockitoSugar {

  "InMemoryShellNotificationManager" should {
    "provide notification names" in {
      val sut = new InMemoryShellNotificationManager("", Seq(notification1, notification2))
      sut.notifierNames should be(Seq(firstName, secondName))
    }

    "know if a notification is enabled by default" in {
      val sut = new InMemoryShellNotificationManager("", Seq(notification1, notification2), enabledByDefault = false)
      sut.notificationEnabled(firstName) should be(false)
      sut.notificationEnabled(secondName) should be(false)
      sut.notificationEnabled("madeUp") should be(false)

      val sut2 = new InMemoryShellNotificationManager("", Seq(notification1, notification2), enabledByDefault = true)
      sut2.notificationEnabled(firstName) should be(true)
      sut2.notificationEnabled(secondName) should be(true)
      sut2.notificationEnabled("madeUp") should be(true)
    }

    "support enabling and disabling notifications" in {
      val sut = new InMemoryShellNotificationManager("", Seq(notification1, notification2))
      sut.notificationEnabled(firstName) should be(false)
      sut.notificationEnabled(secondName) should be(false)

      sut.enable(firstName)
      sut.notificationEnabled(firstName) should be(true)
      sut.notificationEnabled(secondName) should be(false)

      sut.enable(secondName)
      sut.notificationEnabled(firstName) should be(true)
      sut.notificationEnabled(secondName) should be(true)

      sut.disable(firstName)
      sut.notificationEnabled(firstName) should be(false)
      sut.notificationEnabled(secondName) should be(true)

      sut.disable(secondName)
      sut.notificationEnabled(firstName) should be(false)
      sut.notificationEnabled(secondName) should be(false)
    }

    "only notify enabled notifications" in {
      val notificationString = "test"
      val sut = new InMemoryShellNotificationManager("", Seq(notification1, notification2))

      sut.notify(notificationString)
      verify(notification1, times(0)).notify("", notificationString)
      verify(notification2, times(0)).notify("", notificationString)

      sut.enable(firstName)
      sut.notify(notificationString)
      verify(notification1, times(1)).notify("", notificationString)
      verify(notification2, times(0)).notify("", notificationString)

      sut.enable(secondName)
      sut.notify(notificationString)
      verify(notification1, times(2)).notify("", notificationString)
      verify(notification2, times(1)).notify("", notificationString)

      sut.disable(firstName)
      sut.notify(notificationString)
      verify(notification1, times(2)).notify("", notificationString)
      verify(notification2, times(2)).notify("", notificationString)

    }
  }

  private val firstName = "first"
  private val secondName = "second"

  private var notification1: ShellNotification = _
  private var notification2: ShellNotification = _

  override protected def beforeEach(): Unit = {
    notification1 = mock[ShellNotification]
    notification2 = mock[ShellNotification]

    when(notification1.name).thenReturn(firstName)
    when(notification2.name).thenReturn(secondName)
  }
} 
Example 185
Source File: NotificationCommandSetTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase.notifications

import com.sumologic.shellbase.CommonWordSpec
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class NotificationCommandSetTest extends CommonWordSpec {
  "Notification Command Set" should {
    "list notifications" in {
      val manager = new InMemoryShellNotificationManager("", Seq(createNotification("test")))
      val sut = new NotificationCommandSet(manager)
      sut.executeLine(List("list"))
    }

    "list notifications (even if empty)" in {
      val manager = new InMemoryShellNotificationManager("", Seq.empty)
      val sut = new NotificationCommandSet(manager)
      sut.executeLine(List("list"))
    }

    "let you toggle on/off all notifications at once" in {
      val manager = new InMemoryShellNotificationManager("", Seq(createNotification("1"), createNotification("2"), createNotification("3")))
      val sut = new NotificationCommandSet(manager)

      sut.executeLine(List("enable"))
      manager.notificationEnabled("1") should be(true)
      manager.notificationEnabled("2") should be(true)
      manager.notificationEnabled("3") should be(true)

      sut.executeLine(List("disable"))
      manager.notificationEnabled("1") should be(false)
      manager.notificationEnabled("2") should be(false)
      manager.notificationEnabled("3") should be(false)

      sut.executeLine(List("enable", "all"))
      manager.notificationEnabled("1") should be(true)
      manager.notificationEnabled("2") should be(true)
      manager.notificationEnabled("3") should be(true)

      sut.executeLine(List("disable", "all"))
      manager.notificationEnabled("1") should be(false)
      manager.notificationEnabled("2") should be(false)
      manager.notificationEnabled("3") should be(false)
    }

    "let you toggle on/off notifications individually/in a group" in {
      val manager = new InMemoryShellNotificationManager("", Seq(createNotification("1"), createNotification("2"), createNotification("3")))
      val sut = new NotificationCommandSet(manager)

      sut.executeLine(List("enable", "1"))
      manager.notificationEnabled("1") should be(true)
      manager.notificationEnabled("2") should be(false)
      manager.notificationEnabled("3") should be(false)

      sut.executeLine(List("disable", "1"))
      manager.notificationEnabled("1") should be(false)
      manager.notificationEnabled("2") should be(false)
      manager.notificationEnabled("3") should be(false)

      sut.executeLine(List("enable", "2,3"))
      manager.notificationEnabled("1") should be(false)
      manager.notificationEnabled("2") should be(true)
      manager.notificationEnabled("3") should be(true)

      sut.executeLine(List("disable", "1,3"))
      manager.notificationEnabled("1") should be(false)
      manager.notificationEnabled("2") should be(true)
      manager.notificationEnabled("3") should be(false)

    }
  }

  private def createNotification(n: String) = new ShellNotification {
    override def notify(title: String, message: String): Unit = ???

    override def name: String = n
  }
} 
Example 186
Source File: PostToSlackHelperTest.scala    From shellbase   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.shellbase.slack

import com.sumologic.shellbase.CommonWordSpec
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class PostToSlackHelperTest extends CommonWordSpec {
  // NOTE: Some of the test coverage for PostToSlackHelper is done by PostCommandToSlackTest

  "PostToSlackHelper" should {
    "skip posting if username is blacklisted" in {
      val sut = new PostToSlackHelper {
        override protected val slackState: SlackState = null
        override protected val username = "my_test"
        override protected val blacklistedUsernames = Set("my_test", "my_test_2")
      }

      sut.sendSlackMessageIfConfigured("")
    }

    "allow posting if username is not blacklist" in {
      val sut = new PostToSlackHelper {
        override protected val slackState: SlackState = null
        override protected val username = "abc"
        override protected val blacklistedUsernames = Set("my_test", "my_test_2")
      }

      intercept[NullPointerException] {
        sut.sendSlackMessageIfConfigured("")
      }
    }
  }

} 
Example 187
Source File: StackOverflowSuite.scala    From big-data-scala-spark   with MIT License 5 votes vote down vote up
package stackoverflow

import org.scalatest.{FunSuite, BeforeAndAfterAll}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import java.io.File

@RunWith(classOf[JUnitRunner])
class StackOverflowSuite extends FunSuite with BeforeAndAfterAll {


  lazy val testObject = new StackOverflow {
    override val langs =
      List(
        "JavaScript", "Java", "PHP", "Python", "C#", "C++", "Ruby", "CSS",
        "Objective-C", "Perl", "Scala", "Haskell", "MATLAB", "Clojure", "Groovy")
    override def langSpread = 50000
    override def kmeansKernels = 45
    override def kmeansEta: Double = 20.0D
    override def kmeansMaxIterations = 120
  }

  test("testObject can be instantiated") {
    val instantiatable = try {
      testObject
      true
    } catch {
      case _: Throwable => false
    }
    assert(instantiatable, "Can't instantiate a StackOverflow object")
  }


} 
Example 188
Source File: RandomSetTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.testkit

import com.salesforce.op.features.types._
import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec}

import scala.language.postfixOps


@RunWith(classOf[JUnitRunner])
class RandomSetTest extends FlatSpec with TestCommon with Assertions {
  private val numTries = 10000
  private val rngSeed = 314159214142136L

  private def check[D, T <: OPSet[D]](
    g: RandomSet[D, T],
    minLen: Int, maxLen: Int,
    predicate: (D => Boolean) = (_: D) => true
  ) = {
    g reset rngSeed

    def segment = g limit numTries

    segment count (_.value.size < minLen) shouldBe 0
    segment count (_.value.size > maxLen) shouldBe 0
    segment foreach (s => s.value foreach { x =>
      predicate(x) shouldBe true
    })
  }

  Spec[MultiPickList] should "generate multipicklists" in {
    val sut = RandomMultiPickList.of(RandomText.countries, maxLen = 5)

    check[String, MultiPickList](sut, 0, 5, _.nonEmpty)

    val expected = List(
      Set(),
      Set("Aldorria", "Palau", "Glubbdubdrib"),
      Set(),
      Set(),
      Set("Sweden", "Wuhu Islands", "Tuvalu")
    )

    // resetting with the same seed must reproduce exactly the same values
    {sut reset 42; sut limit 5 map (_.value)} shouldBe expected

    {sut reset 42; sut limit 5 map (_.value)} shouldBe expected
  }

} 
Example 189
Source File: RandomBinaryTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.testkit

import com.salesforce.op.features.types.Binary
import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

import scala.language.postfixOps


@RunWith(classOf[JUnitRunner])
class RandomBinaryTest extends FlatSpec with TestCommon {
  val numTries = 1000000
  val rngSeed = 12345

  private def truthWithProbability(probabilityOfTrue: Double) = {
    RandomBinary(probabilityOfTrue)
  }

  Spec[RandomBinary] should "generate empties, truths and falses" in {
    check(truthWithProbability(0.5) withProbabilityOfEmpty 0.5)
    check(truthWithProbability(0.3) withProbabilityOfEmpty 0.65)
    check(truthWithProbability(0.0) withProbabilityOfEmpty 0.1)
    check(truthWithProbability(1.0) withProbabilityOfEmpty 0.0)
  }

  private def check(g: RandomBinary) = {
    g reset rngSeed
    val numberOfEmpties = g limit numTries count (_.isEmpty)
    val expectedNumberOfEmpties = g.probabilityOfEmpty * numTries
    withClue(s"numEmpties = $numberOfEmpties, expected $expectedNumberOfEmpties") {
      math.abs(numberOfEmpties - expectedNumberOfEmpties) < numTries / 100 shouldBe true
    }

    val expectedNumberOfTruths = g.probabilityOfSuccess * (1 - g.probabilityOfEmpty) * numTries
    val numberOfTruths = g limit numTries count (Binary(true) ==)
    withClue(s"numTruths = $numberOfTruths, expected $expectedNumberOfTruths") {
      math.abs(numberOfTruths - expectedNumberOfTruths) < numTries / 100 shouldBe true
    }
  }
} 
Example 190
Source File: PointDStreamExtensionsSpec.scala    From reactiveinflux-spark   with Apache License 2.0 5 votes vote down vote up
package com.pygmalios.reactiveinflux.extensions

import com.holdenkarau.spark.testing.StreamingActionBase
import com.pygmalios.reactiveinflux.spark._
import com.pygmalios.reactiveinflux._
import org.apache.spark.streaming.dstream.DStream
import org.junit.runner.RunWith
import org.scalatest.BeforeAndAfterAll
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class PointDStreamExtensionsSpec extends StreamingActionBase
  with BeforeAndAfterAll {
  import PointRDDExtensionsSpec._

  override def beforeAll: Unit = {
    super.beforeAll
    withInflux(_.create())
  }

  override def afterAll: Unit = {
    withInflux(_.drop())
    super.afterAll
  }

  test("write single point to Influx") {
    val points = List(point1)

    // Execute
    runAction(Seq(points), (dstream: DStream[Point]) => dstream.saveToInflux())

    // Assert
    val result = withInflux(
      _.query(Query(s"SELECT * FROM $measurement1")).result.singleSeries)

    assert(result.rows.size == 1)

    val row = result.rows.head
    assert(row.time == point1.time)
    assert(row.values.size == 5)
  }
} 
Example 191
Source File: PointRDDExtensionsSpec.scala    From reactiveinflux-spark   with Apache License 2.0 5 votes vote down vote up
package com.pygmalios.reactiveinflux.extensions

import com.holdenkarau.spark.testing.SharedSparkContext
import com.pygmalios.reactiveinflux.Point.Measurement
import com.pygmalios.reactiveinflux._
import com.pygmalios.reactiveinflux.extensions.PointRDDExtensionsSpec._
import com.pygmalios.reactiveinflux.spark._
import com.pygmalios.reactiveinflux.spark.extensions.PointRDDExtensions
import org.joda.time.{DateTime, DateTimeZone}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfter, FlatSpec}

import scala.concurrent.duration._

@RunWith(classOf[JUnitRunner])
class PointRDDExtensionsSpec extends FlatSpec with SharedSparkContext
  with BeforeAndAfter {

  before {
    withInflux(_.create())
  }

  after {
    withInflux(_.drop())
  }

  behavior of "saveToInflux"

  it should "write single point to Influx" in {
    val points = List(point1)
    val rdd = sc.parallelize(points)

    // Execute
    rdd.saveToInflux()

    // Assert
    assert(PointRDDExtensions.totalBatchCount == 1)
    assert(PointRDDExtensions.totalPointCount == 1)
    val result = withInflux(
      _.query(Query(s"SELECT * FROM $measurement1"))
      .result
      .singleSeries)

    assert(result.rows.size == 1)

    val row = result.rows.head
    assert(row.time == point1.time)
    assert(row.values.size == 5)
  }

  it should "write 1000 points to Influx" in {
    val points = (1 to 1000).map { i =>
      Point(
        time = point1.time.plusMinutes(i),
        measurement = point1.measurement,
        tags = point1.tags,
        fields = point1.fields
      )
    }
    val rdd = sc.parallelize(points)

    // Execute
    rdd.saveToInflux()

    // Assert
    assert(PointRDDExtensions.totalBatchCount == 8)
    assert(PointRDDExtensions.totalPointCount == 1000)
    val result = withInflux(
      _.query(Query(s"SELECT * FROM $measurement1"))
        .result
        .singleSeries)

    assert(result.rows.size == 1000)
  }
}

object PointRDDExtensionsSpec {
  implicit val params: ReactiveInfluxDbName = ReactiveInfluxDbName("test")
  implicit val awaitAtMost: Duration = 1.second

  val measurement1: Measurement = "measurement1"
  val point1 = Point(
    time        = new DateTime(1983, 1, 10, 7, 43, 10, 3, DateTimeZone.UTC),
    measurement = measurement1,
    tags        = Map("tagKey1" -> "tagValue1", "tagKey2" -> "tagValue2"),
    fields      = Map("fieldKey1" -> StringFieldValue("fieldValue1"), "fieldKey2" -> BigDecimalFieldValue(10.7)))
} 
Example 192
Source File: SwingApiTest.scala    From Principles-of-Reactive-Programming   with GNU General Public License v3.0 5 votes vote down vote up
package suggestions



import scala.collection._
import scala.concurrent._
import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Try, Success, Failure}
import scala.swing.event.Event
import scala.swing.Reactions.Reaction
import rx.lang.scala._
import org.scalatest._
import gui._

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class SwingApiTest extends FunSuite {

  object swingApi extends SwingApi {
    class ValueChanged(val textField: TextField) extends Event

    object ValueChanged {
      def unapply(x: Event) = x match {
        case vc: ValueChanged => Some(vc.textField)
        case _ => None
      }
    }

    class ButtonClicked(val source: Button) extends Event

    object ButtonClicked {
      def unapply(x: Event) = x match {
        case bc: ButtonClicked => Some(bc.source)
        case _ => None
      }
    }

    class Component {
      private val subscriptions = mutable.Set[Reaction]()
      def subscribe(r: Reaction) {
        subscriptions add r
      }
      def unsubscribe(r: Reaction) {
        subscriptions remove r
      }
      def publish(e: Event) {
        for (r <- subscriptions) r(e)
      }
    }

    class TextField extends Component {
      private var _text = ""
      def text = _text
      def text_=(t: String) {
        _text = t
        publish(new ValueChanged(this))
      }
    }

    class Button extends Component {
      def click() {
        publish(new ButtonClicked(this))
      }
    }
  }

  import swingApi._
  
  test("SwingApi should emit text field values to the observable") {
    val textField = new swingApi.TextField
    val values = textField.textValues

    val observed = mutable.Buffer[String]()
    val sub = values subscribe {
      observed += _
    }

    // write some text now
    textField.text = "T"
    textField.text = "Tu"
    textField.text = "Tur"
    textField.text = "Turi"
    textField.text = "Turin"
    textField.text = "Turing"

    assert(observed == Seq("T", "Tu", "Tur", "Turi", "Turin", "Turing"), observed)
  }

} 
Example 193
Source File: WikipediaApiTest.scala    From Principles-of-Reactive-Programming   with GNU General Public License v3.0 5 votes vote down vote up
package suggestions



import language.postfixOps
import scala.concurrent._
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Try, Success, Failure}
import rx.lang.scala._
import org.scalatest._
import gui._

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class WikipediaApiTest extends FunSuite {

  object mockApi extends WikipediaApi {
    def wikipediaSuggestion(term: String) = Future {
      if (term.head.isLetter) {
        for (suffix <- List(" (Computer Scientist)", " (Footballer)")) yield term + suffix
      } else {
        List(term)
      }
    }
    def wikipediaPage(term: String) = Future {
      "Title: " + term
    }
  }

  import mockApi._

  test("WikipediaApi should make the stream valid using sanitized") {
    val notvalid = Observable.just("erik", "erik meijer", "martin")
    val valid = notvalid.sanitized

    var count = 0
    var completed = false

    val sub = valid.subscribe(
      term => {
        assert(term.forall(_ != ' '))
        count += 1
      },
      t => assert(false, s"stream error $t"),
      () => completed = true
    )
    assert(completed && count == 3, "completed: " + completed + ", event count: " + count)
  }
  test("WikipediaApi should correctly use concatRecovered") {
    val requests = Observable.just(1, 2, 3)
    val remoteComputation = (n: Int) => Observable.just(0 to n : _*)
    val responses = requests concatRecovered remoteComputation
    val sum = responses.foldLeft(0) { (acc, tn) =>
      tn match {
        case Success(n) => acc + n
        case Failure(t) => throw t
      }
    }
    var total = -1
    val sub = sum.subscribe {
      s => total = s
    }
    assert(total == (1 + 1 + 2 + 1 + 2 + 3), s"Sum: $total")
  }

} 
Example 194
Source File: QuickCheckSuite.scala    From Principles-of-Reactive-Programming   with GNU General Public License v3.0 5 votes vote down vote up
package quickcheck

import org.scalatest.FunSuite

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

import org.scalatest.prop.Checkers
import org.scalacheck.Arbitrary._
import org.scalacheck.Prop
import org.scalacheck.Prop._

import org.scalatest.exceptions.TestFailedException

object QuickCheckBinomialHeap extends QuickCheckHeap with BinomialHeap

@RunWith(classOf[JUnitRunner])
class QuickCheckSuite extends FunSuite with Checkers {
  def checkBogus(p: Prop) {
    var ok = false
    try {
      check(p)
    } catch {
      case e: TestFailedException =>
        ok = true
    }
    assert(ok, "A bogus heap should NOT satisfy all properties. Try to find the bug!")
  }

  test("Binomial heap satisfies properties.") {
    check(new QuickCheckHeap with BinomialHeap)
  }

  test("Bogus (1) binomial heap does not satisfy properties.") {
    checkBogus(new QuickCheckHeap with Bogus1BinomialHeap)
  }

  test("Bogus (2) binomial heap does not satisfy properties.") {
    checkBogus(new QuickCheckHeap with Bogus2BinomialHeap)
  }

  test("Bogus (3) binomial heap does not satisfy properties.") {
    checkBogus(new QuickCheckHeap with Bogus3BinomialHeap)
  }

  test("Bogus (4) binomial heap does not satisfy properties.") {
    checkBogus(new QuickCheckHeap with Bogus4BinomialHeap)
  }

  test("Bogus (5) binomial heap does not satisfy properties.") {
    checkBogus(new QuickCheckHeap with Bogus5BinomialHeap)
  }
} 
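For context, a hedged sketch of the kind of property such a QuickCheck-style suite exercises, using ScalaCheck's Properties/forAll DSL; the tiny sorted-list heap below is only a stand-in so the sketch compiles on its own, not the course's BinomialHeap:

import org.scalacheck.Prop.forAll
import org.scalacheck.Properties

// Stand-in heap (a sorted List) so this sketch is self-contained;
// the real suite mixes a concrete BinomialHeap into QuickCheckHeap instead.
object HeapPropertySketch extends Properties("HeapSketch") {
  type H = List[Int]
  def empty: H = Nil
  def insert(x: Int, h: H): H = (x :: h).sorted
  def findMin(h: H): Int = h.head

  property("min of two inserts") = forAll { (a: Int, b: Int) =>
    findMin(insert(b, insert(a, empty))) == math.min(a, b)
  }
}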
Example 195
Source File: SidechainNodeViewHolderTest.scala    From Sidechains-SDK   with MIT License 5 votes vote down vote up
package com.horizen.actors

import java.util.concurrent.TimeUnit

import akka.actor.{ActorRef, ActorSystem}
import akka.pattern.ask
import akka.testkit.TestKit
import akka.util.Timeout
import com.horizen.SidechainNodeViewHolder.ReceivableMessages.GetDataFromCurrentSidechainNodeView
import com.horizen.fixtures.SidechainNodeViewHolderFixture
import com.horizen.node.SidechainNodeView
import org.scalatest.{BeforeAndAfterAll, FunSuiteLike}

import scala.concurrent._
import scala.concurrent.duration._
import org.scalatest._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class SidechainNodeViewHolderTest extends Suites(
  new SidechainNodeViewHolderTest1,
  new SidechainNodeViewHolderTest2
)

@RunWith(classOf[JUnitRunner])
class SidechainNodeViewHolderTest1
  extends TestKit(ActorSystem("testsystem"))
  with FunSuiteLike
  with BeforeAndAfterAll
  with SidechainNodeViewHolderFixture
{

  implicit val timeout = Timeout(5, TimeUnit.SECONDS)

  override def afterAll: Unit = {
    //info("Actor system is shutting down...")
    TestKit.shutdownActorSystem(system)
  }

  test ("Test1") {
    def f(v: SidechainNodeView) = v
    val sidechainNodeViewHolderRef: ActorRef = getSidechainNodeViewHolderRef
    val nodeView = (sidechainNodeViewHolderRef ? GetDataFromCurrentSidechainNodeView(f))
      .mapTo[SidechainNodeView]

    assert(Await.result(nodeView, 5 seconds) != null)
  }

  test("Test2") {
  }

}

@RunWith(classOf[JUnitRunner])
class SidechainNodeViewHolderTest2
  extends TestKit(ActorSystem("testSystem"))
  with FeatureSpecLike
  with BeforeAndAfterAll
  with Matchers
  with SidechainNodeViewHolderFixture
{

  implicit val timeout = Timeout(5, TimeUnit.SECONDS)

  override def afterAll: Unit = {
    //info("Actor system is shutting down...")
    TestKit.shutdownActorSystem(system)
  }

  feature("Actor1") {
    scenario("Scenario 1"){
      system should not be(null)

      def f(v: SidechainNodeView) = v
      val sidechainNodeViewHolderRef: ActorRef = getSidechainNodeViewHolderRef
      val nodeView = (sidechainNodeViewHolderRef ? GetDataFromCurrentSidechainNodeView(f))
        .mapTo[SidechainNodeView]

      Await.result(nodeView, 5 seconds) should not be(null)

    }
  }
} 
Example 196
Source File: DockerHelperTest.scala    From MaRe   with Apache License 2.0 5 votes vote down vote up
package se.uu.it.mare

import scala.util.Properties

import org.junit.runner.RunWith
import org.scalatest.FunSuite

import com.github.dockerjava.core.DefaultDockerClientConfig
import com.github.dockerjava.core.DockerClientBuilder
import com.github.dockerjava.core.command.PullImageResultCallback
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class DockerHelperTest extends FunSuite {

  // Init Docker client
  private val configBuilder = DefaultDockerClientConfig.createDefaultConfigBuilder()
  if (Properties.envOrNone("DOCKER_HOST") != None) {
    configBuilder.withDockerHost(System.getenv("DOCKER_HOST"))
  }
  if (Properties.envOrNone("DOCKER_TLS_VERIFY") != None) {
    val tlsVerify = System.getenv("DOCKER_TLS_VERIFY") == "1"
    configBuilder.withDockerTlsVerify(tlsVerify)
  }
  if (Properties.envOrNone("DOCKER_CERT_PATH") != None) {
    configBuilder.withDockerCertPath(System.getenv("DOCKER_CERT_PATH"))
  }
  private val config = configBuilder.build
  private val dockerClient = DockerClientBuilder.getInstance(config).build

  test("Map-like Docker run, image not present") {

    // Remove image if present
    val localImgList = dockerClient.listImagesCmd
      .withImageNameFilter("busybox:1")
      .exec
    if (localImgList.size > 0) {
      dockerClient.removeImageCmd("busybox:1")
        .withForce(true)
        .exec
    }

    // Run docker
    val statusCode = DockerHelper.run(
      imageName = "busybox:1",
      command = "true",
      bindFiles = Seq(),
      volumeFiles = Seq(),
      forcePull = false)

    assert(statusCode == 0)

  }

  test("Map-like Docker run, image present") {

    // Pull image
    dockerClient.pullImageCmd("busybox:1")
      .exec(new PullImageResultCallback)
      .awaitSuccess()

    // Run docker
    val statusCode = DockerHelper.run(
      imageName = "busybox:1",
      command = "true",
      bindFiles = Seq(),
      volumeFiles = Seq(),
      forcePull = false)

    assert(statusCode == 0)

  }

  test("Map-like Docker run, force pull") {

    // Pull image
    dockerClient.pullImageCmd("busybox:1")
      .exec(new PullImageResultCallback)
      .awaitSuccess()

    // Run docker
    val statusCode = DockerHelper.run(
      imageName = "busybox:1",
      command = "true",
      bindFiles = Seq(),
      volumeFiles = Seq(),
      forcePull = true)

    assert(statusCode == 0)

  }

} 
Example 197
Source File: VirtualScreeningTest.scala    From MaRe   with Apache License 2.0 5 votes vote down vote up
package se.uu.it.mare

import java.io.File
import java.util.UUID

import scala.io.Source
import scala.util.Properties

import org.apache.spark.SharedSparkContext
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

private object SDFUtils {
  def parseIDsAndScores(sdf: String): Array[(String, String)] = {
    sdf.split("\\n\\$\\$\\$\\$\\n").map { mol =>
      val lines = mol.split("\\n")
      (lines(0), lines.last)
    }
  }
}

@RunWith(classOf[JUnitRunner])
class VirtualScreeningTest extends FunSuite with SharedSparkContext {

  private val tmpDir = new File(Properties.envOrElse("TMPDIR", "/tmp"))

  test("Virtual Screening") {

    sc.hadoopConfiguration.set("textinputformat.record.delimiter", "\n$$$$\n")
    val mols = sc.textFile(getClass.getResource("sdf/molecules.sdf").getPath)

    // Parallel execution with MaRe
    val hitsParallel = new MaRe(mols)
      .map(
        inputMountPoint = TextFile("/input.sdf", "\n$$$$\n"),
        outputMountPoint = TextFile("/output.sdf", "\n$$$$\n"),
        imageName = "mcapuccini/oe:latest",
        command = "fred -receptor /var/openeye/hiv1_protease.oeb " +
          "-hitlist_size 0 " +
          "-conftest none " +
          "-dock_resolution Low " +
          "-dbase /input.sdf " +
          "-docked_molecule_file /output.sdf")
      .reduce(
        inputMountPoint = TextFile("/input.sdf", "\n$$$$\n"),
        outputMountPoint = TextFile("/output.sdf", "\n$$$$\n"),
        imageName = "mcapuccini/sdsorter:latest",
        command = "sdsorter -reversesort='FRED Chemgauss4 score' " +
          "-keep-tag='FRED Chemgauss4 score' " +
          "-nbest=30 " +
          "/input.sdf " +
          "/output.sdf")
      .rdd.collect.mkString("\n$$$$\n")

    // Serial execution
    val inputFile = new File(getClass.getResource("sdf/molecules.sdf").getPath)
    val dockedFile = new File(tmpDir, "mare_test_" + UUID.randomUUID.toString)
    dockedFile.createNewFile
    dockedFile.deleteOnExit
    val outputFile = new File(tmpDir, "mare_test_" + UUID.randomUUID.toString)
    outputFile.createNewFile
    outputFile.deleteOnExit
    DockerHelper.run(
      imageName = "mcapuccini/oe:latest",
      command = "fred -receptor /var/openeye/hiv1_protease.oeb " +
        "-hitlist_size 0 " +
        "-conftest none " +
        "-dock_resolution Low " +
        "-dbase /input.sdf " +
        "-docked_molecule_file /docked.sdf",
      bindFiles = Seq(inputFile, dockedFile),
      volumeFiles = Seq(new File("/input.sdf"), new File("/docked.sdf")),
      forcePull = false)
    DockerHelper.run(
      imageName = "mcapuccini/sdsorter:latest",
      command = "sdsorter -reversesort='FRED Chemgauss4 score' " +
        "-keep-tag='FRED Chemgauss4 score' " +
        "-nbest=30 " +
        "/docked.sdf " +
        "/output.sdf",
      bindFiles = Seq(dockedFile, outputFile),
      volumeFiles = Seq(new File("/docked.sdf"), new File("/output.sdf")),
      forcePull = false)
    val hitsSerial = Source.fromFile(outputFile).mkString

    // Test
    val parallel = SDFUtils.parseIDsAndScores(hitsParallel)
    val serial = SDFUtils.parseIDsAndScores(hitsSerial)
    assert(parallel.deep == serial.deep)

  }

} 
Example 198
Source File: ValuesStoreTest.scala    From random-projections-at-berlinbuzzwords   with Apache License 2.0 5 votes vote down vote up
package com.stefansavev

import java.util.Random

import com.stefansavev.randomprojections.datarepr.dense.store._
import com.typesafe.scalalogging.StrictLogging
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class TestSingleByteEncodingSpec extends FlatSpec with Matchers {
  "Error after encoding double to float" should "be small" in {
    val minV = -1.0f
    val maxV = 2.0f
    val rnd = new Random(481861)
    for (i <- 0 until 100) {
      //we encode a float (which is 4 bytes) with a single byte
      //therefore the loss of precision
      val value = rnd.nextFloat() * 3.0f - 1.0f
      val enc = FloatToSingleByteEncoder.encodeValue(minV, maxV, value)
      val dec = FloatToSingleByteEncoder.decodeValue(minV, maxV, enc)
      val error = Math.abs(value - dec)
      error should be < (0.01)
    }
  }
}

@RunWith(classOf[JUnitRunner])
class TestValueStores extends FlatSpec with Matchers {

  case class BuilderTypeWithErrorPredicate(builderType: StoreBuilderType, pred: Double => Boolean)

  "ValueStore" should "return store the data with small error" in {

    val tests = List(
      BuilderTypeWithErrorPredicate(StoreBuilderAsDoubleType, error => (error <= 0.0)),
      BuilderTypeWithErrorPredicate(StoreBuilderAsBytesType, error => (error <= 0.01)),
      BuilderTypeWithErrorPredicate(StoreBuilderAsSingleByteType, error => (error <= 0.01))
    )

    for (test <- tests) {
      testBuilder(test)
    }

    def testBuilder(builderWithPred: BuilderTypeWithErrorPredicate): Unit = {
      val dataGenSettings = RandomBitStrings.RandomBitSettings(
        numGroups = 1000,
        numRowsPerGroup = 2,
        numCols = 256,
        per1sInPrototype = 0.5,
        perNoise = 0.2)

      val debug = false
      val randomBitStringsDataset = RandomBitStrings.genRandomData(58585, dataGenSettings, debug, true)
      val builder = builderWithPred.builderType.getBuilder(randomBitStringsDataset.numCols)

      def addValues(): Unit = {
        var i = 0
        while (i < randomBitStringsDataset.numRows) {
          val values = randomBitStringsDataset.getPointAsDenseVector(i)
          builder.addValues(values)
          i += 1
        }
      }

      addValues()

      val valueStore = builder.build()

      def verifyStoredValues(expected: Array[Double], stored: Array[Double]): Unit = {
        for (i <- 0 until expected.length) {
          val error = Math.abs(expected(i) - stored(i))
          val passed = builderWithPred.pred(error)
          passed should be (true)
        }
      }

      def testValues(): Unit = {
        var i = 0
        while (i < randomBitStringsDataset.numRows) {
          val values = randomBitStringsDataset.getPointAsDenseVector(i)
          val output = Array.ofDim[Double](randomBitStringsDataset.numCols)
          valueStore.fillRow(i, output, true)
          verifyStoredValues(values, output)
          i += 1
        }
      }
      testValues()
    }
  }
}


object Test extends StrictLogging {
  def main(args: Array[String]) {
    logger.info("hello")
  }
} 
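A back-of-the-envelope check of the 0.01 bound asserted in TestSingleByteEncodingSpec above, under the assumption (not taken from the source) that FloatToSingleByteEncoder quantizes the range uniformly into 256 levels: over [-1.0, 2.0] the step is 3.0 / 255 ≈ 0.0118, so the worst-case rounding error is about half a step, roughly 0.0059, comfortably under the asserted bound. A minimal sketch of that arithmetic:

// Hypothetical arithmetic only; assumes a uniform 8-bit quantizer, which may
// differ from FloatToSingleByteEncoder's actual encoding scheme.
object SingleByteErrorBoundSketch extends App {
  val minV = -1.0f
  val maxV = 2.0f
  val step = (maxV - minV) / 255.0f  // ≈ 0.01176
  val worstCaseError = step / 2.0f   // ≈ 0.00588
  println(s"step = $step, worst-case rounding error = $worstCaseError")
}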
Example 199
Source File: GloveUnitTest.scala    From random-projections-at-berlinbuzzwords   with Apache License 2.0 5 votes vote down vote up
package com.stefansavev.fuzzysearchtest

import java.io.StringReader

import com.stefansavev.TemporaryFolderFixture
import com.stefansavev.core.serialization.TupleSerializers._
import org.junit.runner.RunWith
import org.scalatest.{FunSuite, Matchers}
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class GloveUnitTest extends FunSuite with TemporaryFolderFixture with Matchers {

  def readResource(name: String): String = {
    val stream = getClass.getResourceAsStream(name)
    val lines = scala.io.Source.fromInputStream(stream).getLines
    lines.mkString("\n")
  }

  def parameterizedTest(inputTextFile: String, indexFile: String, numTrees: Int, expectedResultsName: String): Unit = {
    val expectedResults = readResource(expectedResultsName).trim
    val queryResults = GloveTest.run(inputTextFile, indexFile, numTrees).trim
    assertResult(expectedResults)(queryResults)
  }

  //manually download http://nlp.stanford.edu/data/glove.6B.zip and unzip into test/resources/glove
  //then enable the test
  ignore("test glove num trees 1") {
    val numTrees: Int = 1
    val inputTextFile: String = "src/test/resources/glove/glove.6B.100d.txt"
    val index = temporaryFolder.newFolder("index").getAbsolutePath
    val expectedResultsResourceName = "/glove/expected_results_num_trees_1.txt"
    parameterizedTest(inputTextFile, index, numTrees, expectedResultsResourceName)
  }

  ignore("test glove num trees 150") {
    val numTrees: Int = 150
    val inputTextFile: String = "src/test/resources/glove/glove.6B.100d.txt"
    val index = temporaryFolder.newFolder("index").getAbsolutePath
    val expectedResultsResourceName = "/glove/expected_results_num_trees_150.txt"
    parameterizedTest(inputTextFile, index, numTrees, expectedResultsResourceName)
  }
} 
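As an aside (not part of the original source), tests that depend on manually downloaded data can also be made self-cancelling with ScalaTest's assume, which cancels the test instead of failing it when a precondition does not hold; a minimal sketch, assuming the same resource path as in GloveUnitTest above:

import java.io.File

import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class GloveDataGuardSketch extends FunSuite {
  private val gloveFile = new File("src/test/resources/glove/glove.6B.100d.txt")

  test("glove-backed test runs only when the data has been downloaded") {
    // assume() cancels the test (rather than failing it) when the glove data is missing
    assume(gloveFile.exists, "glove.6B.100d.txt not found; download and unzip glove.6B.zip first")
    // ... invoke the parameterized index/query check from GloveUnitTest here ...
  }
}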
Example 200
Source File: String2IdTests.scala    From random-projections-at-berlinbuzzwords   with Apache License 2.0 5 votes vote down vote up
package com.stefansavev.core.string2id

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}


@RunWith(classOf[JUnitRunner])
class TestDynamicString2IdHasher extends FlatSpec with Matchers {
  val table = new String2UniqueIdTable()

  def makeName(i: Int): String = i + "#" + i

  for (i <- 0 until 100000) {
    table.addString(makeName(i))
  }

  "strings " should "be from 0 to 100000" in {
    for (i <- 0 until 100000) {
      val str = table.getStringById(i)
      val expectedStr = makeName(i)
      str should be(expectedStr)
    }
  }
}

@RunWith(classOf[JUnitRunner])
class String2IdHasherTester extends FlatSpec with Matchers {
  "some strings" should "not be added because of lack of space" in {
    val settings = StringIdHasherSettings(2, 50, 4)
    val h = new String2IdHasher(settings)
    val input = Array("a", "a", "b", "cd", "p", "q")
    for (word <- input) {
      val arr = word.toCharArray
      val index = h.getOrAddId(arr, 0, arr.length, true)
      if (word == "cd" || word == "p" || word == "q") {
        //cannot add
        index should be(-2)
      }
    }
  }
}