org.junit.runner.RunWith Scala Examples
The following examples show how to use org.junit.runner.RunWith in Scala. In every example the annotation is applied to a ScalaTest suite so that it can be executed by JUnit tooling via ScalaTest's JUnitRunner.
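All of the examples share the same skeleton, sketched below. The spec name and assertion are placeholders, and the sketch assumes the same ScalaTest 2.x/3.0-era org.scalatest.junit.JUnitRunner that the examples import (in newer ScalaTest releases the runner has moved to the separate scalatest-plus-junit module as org.scalatestplus.junit.JUnitRunner).

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

// The annotation delegates test discovery and execution to ScalaTest's JUnitRunner,
// so this spec can be run by any JUnit 4 runner (IDE, Maven Surefire, Gradle, CI).
@RunWith(classOf[JUnitRunner])
class GreetingSpec extends WordSpec with Matchers {

  "A greeting" should {
    "start with the word Hello" in {
      "Hello, world".take(5) should be("Hello")
    }
  }
}

The examples that follow are variations on this skeleton, differing in the spec style used (WordSpec, FlatSpec, TestKit-based suites) and in the Flink or Sparta classes under test.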
Example 1
Source File: RegressITCase.scala From flink-tensorflow with Apache License 2.0
package org.apache.flink.contrib.tensorflow.ml import com.twitter.bijection.Conversion._ import org.apache.flink.api.common.functions.RichFlatMapFunction import org.apache.flink.api.scala._ import org.apache.flink.configuration.Configuration import org.apache.flink.contrib.tensorflow.ml.signatures.RegressionMethod._ import org.apache.flink.contrib.tensorflow.types.TensorInjections.{message2Tensor, messages2Tensor} import org.apache.flink.contrib.tensorflow.util.TestData._ import org.apache.flink.contrib.tensorflow.util.{FlinkTestBase, RegistrationUtils} import org.apache.flink.core.fs.Path import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.util.Collector import org.apache.flink.util.Preconditions.checkState import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpecLike} import org.tensorflow.Tensor import org.tensorflow.contrib.scala.Arrays._ import org.tensorflow.contrib.scala.Rank._ import org.tensorflow.contrib.scala._ import org.tensorflow.example.Example import resource._ @RunWith(classOf[JUnitRunner]) class RegressITCase extends WordSpecLike with Matchers with FlinkTestBase { override val parallelism = 1 type LabeledExample = (Example, Float) def examples(): Seq[LabeledExample] = { for (v <- Seq(0.0f -> 2.0f, 1.0f -> 2.5f, 2.0f -> 3.0f, 3.0f -> 3.5f)) yield (example("x" -> feature(v._1)), v._2) } "A RegressFunction" should { "process elements" in { val env = StreamExecutionEnvironment.getExecutionEnvironment RegistrationUtils.registerTypes(env.getConfig) val model = new HalfPlusTwo(new Path("../models/half_plus_two")) val outputs = env .fromCollection(examples()) .flatMap(new RichFlatMapFunction[LabeledExample, Float] { override def open(parameters: Configuration): Unit = model.open() override def close(): Unit = model.close() override def flatMap(value: (Example, Float), out: Collector[Float]): Unit = { for { x <- managed(Seq(value._1).toList.as[Tensor].taggedAs[ExampleTensor]) y <- model.regress_x_to_y(x) } { // cast as a 1D tensor to use the available conversion val o = y.taggedAs[TypedTensor[`1D`,Float]].as[Array[Float]] val actual = o(0) checkState(actual == value._2) out.collect(actual) } } }) .print() env.execute() } } }
Example 2
Source File: LoggedUserTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.serving.core.models import com.stratio.sparta.serving.core.models.dto.{LoggedUser, LoggedUserConstant} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpec} @RunWith(classOf[JUnitRunner]) class LoggedUserTest extends WordSpec with Matchers { val dummyGroupID = "66" "An input String" when { "containing a well-formed JSON" should { "be correctly transformed into a LoggedUser" in { val objectUser = LoggedUser("1234-qwerty", "user1", LoggedUserConstant.dummyMail, dummyGroupID, Seq.empty[String], Seq("admin")) val stringJson = """ {"id":"1234-qwerty", "attributes":[ {"cn":"user1"}, {"mail":"[email protected]"}, {"gidNumber":"66"}, {"groups":[]}, {"roles":["admin"]} ]}""" val parsedUser = LoggedUser.jsonToDto(stringJson) parsedUser shouldBe defined parsedUser.get should equal(objectUser) } } } "An input String" when { "has missing fields" should { "be correctly parsed " in { val stringSparta = """{"id":"sparta","attributes":[ |{"cn":"sparta"}, |{"mail":"[email protected]"}, |{"groups":["Developers"]}, |{"roles":[]}]}""".stripMargin val parsedUser = LoggedUser.jsonToDto(stringSparta) val objectUser = LoggedUser("sparta", "sparta", "[email protected]", "", Seq("Developers"), Seq.empty[String]) parsedUser shouldBe defined parsedUser.get should equal (objectUser) } } } "An input String" when { "is empty" should { "be transformed into None" in { val stringJson = "" val parsedUser = LoggedUser.jsonToDto(stringJson) parsedUser shouldBe None } } } "A user" when { "Oauth2 security is enabled" should { "be authorized only if one of its roles is contained inside allowedRoles" in { val objectUser = LoggedUser("1234-qwerty", "user1", LoggedUserConstant.dummyMail, dummyGroupID, Seq.empty[String], Seq("admin")) objectUser.isAuthorized(securityEnabled = true, allowedRoles = Seq("admin")) === true && objectUser.isAuthorized(securityEnabled = true, allowedRoles = Seq("OtherAdministratorRole", "dummyUser")) === false } } } "A user" when { "Oauth2 security is disabled" should { "always be authorized" in { val objectUser = LoggedUser("1234-qwerty", "user1", LoggedUserConstant.dummyMail, dummyGroupID, Seq.empty[String], Seq("admin")) objectUser.isAuthorized(securityEnabled = false, allowedRoles = LoggedUserConstant.allowedRoles) === true } } } }
Example 3
Source File: ErrorsModelTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.serving.core.models

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class ErrorsModelTest extends WordSpec with Matchers {

  val error = new ErrorModel("100", "Error 100", None, None)

  "ErrorModel" should {
    "toString method should return the number of the error and the error" in {
      val res = ErrorModel.toString(error)
      res should be("""{"i18nCode":"100","message":"Error 100"}""")
    }
    "toError method should return the number of the error and the error" in {
      val res = ErrorModel.toErrorModel(
        """
          |{
          | "i18nCode": "100",
          | "message": "Error 100"
          |}
        """.stripMargin)
      res should be(error)
    }
  }
}
Example 4
Source File: ServingExceptionTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.serving.core.exception

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class ServingExceptionTest extends WordSpec with Matchers {

  "A ServingException" should {
    "create an exception with message" in {
      ServingCoreException.create("message").getMessage should be("message")
    }
    "create an exception with message and a cause" in {
      val cause = new IllegalArgumentException("any exception")
      val exception = ServingCoreException.create("message", cause)
      exception.getMessage should be("message")
      exception.getCause should be theSameInstanceAs (cause)
    }
  }
}
Example 5
Source File: SparkContextFactoryTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.driver.test.factory import com.stratio.sparta.driver.factory.SparkContextFactory import com.stratio.sparta.serving.core.config.SpartaConfig import com.stratio.sparta.serving.core.helpers.PolicyHelper import com.typesafe.config.ConfigFactory import org.apache.spark.streaming.Duration import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{BeforeAndAfterAll, FlatSpec, _} @RunWith(classOf[JUnitRunner]) class SparkContextFactoryTest extends FlatSpec with ShouldMatchers with BeforeAndAfterAll { self: FlatSpec => override def afterAll { SparkContextFactory.destroySparkContext() } trait WithConfig { val config = SpartaConfig.initConfig("sparta.local") val wrongConfig = ConfigFactory.empty val seconds = 6 val batchDuraction = Duration(seconds) val specificConfig = Map("spark.driver.allowMultipleContexts" -> "true") ++ PolicyHelper.getSparkConfFromProps(config.get) } "SparkContextFactorySpec" should "fails when properties is missing" in new WithConfig { an[Exception] should be thrownBy SparkContextFactory.sparkStandAloneContextInstance( Map.empty[String, String], Seq()) } it should "create and reuse same context" in new WithConfig { val sc = SparkContextFactory.sparkStandAloneContextInstance(specificConfig, Seq()) val otherSc = SparkContextFactory.sparkStandAloneContextInstance(specificConfig, Seq()) sc should be equals (otherSc) SparkContextFactory.destroySparkContext() } it should "create and reuse same SparkSession" in new WithConfig { val sc = SparkContextFactory.sparkStandAloneContextInstance(specificConfig, Seq()) val sqc = SparkContextFactory.sparkSessionInstance sqc shouldNot be equals (null) val otherSqc = SparkContextFactory.sparkSessionInstance sqc should be equals (otherSqc) SparkContextFactory.destroySparkContext() } it should "create and reuse same SparkStreamingContext" in new WithConfig { val checkpointDir = "checkpoint/SparkContextFactorySpec" val sc = SparkContextFactory.sparkStandAloneContextInstance(specificConfig, Seq()) val ssc = SparkContextFactory.sparkStreamingInstance(batchDuraction, checkpointDir, None) ssc shouldNot be equals (None) val otherSsc = SparkContextFactory.sparkStreamingInstance(batchDuraction, checkpointDir, None) ssc should be equals (otherSsc) } }
Example 6
Source File: CubeMakerTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.driver.test.cube import java.sql.Timestamp import com.github.nscala_time.time.Imports._ import com.stratio.sparta.driver.step.{Cube, CubeOperations, Trigger} import com.stratio.sparta.driver.writer.WriterOptions import com.stratio.sparta.plugin.default.DefaultField import com.stratio.sparta.plugin.cube.field.datetime.DateTimeField import com.stratio.sparta.plugin.cube.operator.count.CountOperator import com.stratio.sparta.sdk.pipeline.aggregation.cube.{Dimension, DimensionValue, DimensionValuesTime, InputFields} import com.stratio.sparta.sdk.pipeline.schema.TypeOp import com.stratio.sparta.sdk.utils.AggregationTime import org.apache.spark.sql.Row import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType, TimestampType} import org.apache.spark.streaming.TestSuiteBase import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class CubeMakerTest extends TestSuiteBase { val PreserverOrder = false def getEventOutput(timestamp: Timestamp, millis: Long): Seq[Seq[(DimensionValuesTime, InputFields)]] = { val dimensionString = Dimension("dim1", "eventKey", "identity", new DefaultField) val dimensionTime = Dimension("minute", "minute", "minute", new DateTimeField) val dimensionValueString1 = DimensionValue(dimensionString, "value1") val dimensionValueString2 = dimensionValueString1.copy(value = "value2") val dimensionValueString3 = dimensionValueString1.copy(value = "value3") val dimensionValueTs = DimensionValue(dimensionTime, timestamp) val tsMap = Row(timestamp) val valuesMap1 = InputFields(Row("value1", timestamp), 1) val valuesMap2 = InputFields(Row("value2", timestamp), 1) val valuesMap3 = InputFields(Row("value3", timestamp), 1) Seq(Seq( (DimensionValuesTime("cubeName", Seq(dimensionValueString1, dimensionValueTs)), valuesMap1), (DimensionValuesTime("cubeName", Seq(dimensionValueString2, dimensionValueTs)), valuesMap2), (DimensionValuesTime("cubeName", Seq(dimensionValueString3, dimensionValueTs)), valuesMap3) )) } }
Example 7
Source File: RawStageTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.driver.test.stage import akka.actor.ActorSystem import akka.testkit.TestKit import com.stratio.sparta.driver.stage.{LogError, RawDataStage} import com.stratio.sparta.sdk.pipeline.autoCalculations.AutoCalculatedField import com.stratio.sparta.sdk.properties.JsoneyString import com.stratio.sparta.serving.core.models.policy.writer.{AutoCalculatedFieldModel, WriterModel} import com.stratio.sparta.serving.core.models.policy.{PolicyModel, RawDataModel} import org.junit.runner.RunWith import org.mockito.Mockito.when import org.scalatest.junit.JUnitRunner import org.scalatest.mock.MockitoSugar import org.scalatest.{FlatSpecLike, ShouldMatchers} @RunWith(classOf[JUnitRunner]) class RawStageTest extends TestKit(ActorSystem("RawStageTest")) with FlatSpecLike with ShouldMatchers with MockitoSugar { case class TestRawData(policy: PolicyModel) extends RawDataStage with LogError def mockPolicy: PolicyModel = { val policy = mock[PolicyModel] when(policy.id).thenReturn(Some("id")) policy } "rawDataStage" should "Generate a raw data" in { val field = "field" val timeField = "time" val tableName = Some("table") val outputs = Seq("output") val partitionBy = Some("field") val autocalculateFields = Seq(AutoCalculatedFieldModel()) val configuration = Map.empty[String, JsoneyString] val policy = mockPolicy val rawData = mock[RawDataModel] val writerModel = mock[WriterModel] when(policy.rawData).thenReturn(Some(rawData)) when(rawData.dataField).thenReturn(field) when(rawData.timeField).thenReturn(timeField) when(rawData.writer).thenReturn(writerModel) when(writerModel.tableName).thenReturn(tableName) when(writerModel.outputs).thenReturn(outputs) when(writerModel.partitionBy).thenReturn(partitionBy) when(writerModel.autoCalculatedFields).thenReturn(autocalculateFields) when(rawData.configuration).thenReturn(configuration) val result = TestRawData(policy).rawDataStage() result.timeField should be(timeField) result.dataField should be(field) result.writerOptions.tableName should be(tableName) result.writerOptions.partitionBy should be(partitionBy) result.configuration should be(configuration) result.writerOptions.outputs should be(outputs) } "rawDataStage" should "Fail with bad table name" in { val field = "field" val timeField = "time" val tableName = None val outputs = Seq("output") val partitionBy = Some("field") val configuration = Map.empty[String, JsoneyString] val policy = mockPolicy val rawData = mock[RawDataModel] val writerModel = mock[WriterModel] when(policy.rawData).thenReturn(Some(rawData)) when(rawData.dataField).thenReturn(field) when(rawData.timeField).thenReturn(timeField) when(rawData.writer).thenReturn(writerModel) when(writerModel.tableName).thenReturn(tableName) when(writerModel.outputs).thenReturn(outputs) when(writerModel.partitionBy).thenReturn(partitionBy) when(rawData.configuration).thenReturn(configuration) the[IllegalArgumentException] thrownBy { TestRawData(policy).rawDataStage() } should have message "Something gone wrong saving the raw data. Please re-check the policy." } }
Example 8
Source File: DimensionTypeTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.sdk.pipeline.aggregation.cube import java.io.{Serializable => JSerializable} import com.stratio.sparta.sdk.pipeline.schema.TypeOp import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpec} @RunWith(classOf[JUnitRunner]) class DimensionTypeTest extends WordSpec with Matchers { val prop = Map("hello" -> "bye") "DimensionType" should { "the return operations properties" in { val dimensionTypeTest = new DimensionTypeMock(prop) val result = dimensionTypeTest.operationProps result should be(prop) } "the return properties" in { val dimensionTypeTest = new DimensionTypeMock(prop) val result = dimensionTypeTest.properties result should be(prop) } "the return precisionValue" in { val dimensionTypeTest = new DimensionTypeMock(prop) val expected = (DimensionType.getIdentity(None, dimensionTypeTest.defaultTypeOperation), "hello") val result = dimensionTypeTest.precisionValue("", "hello") result should be(expected) } "the return precision" in { val dimensionTypeTest = new DimensionTypeMock(prop) val expected = (DimensionType.getIdentity(None, dimensionTypeTest.defaultTypeOperation)) val result = dimensionTypeTest.precision("") result should be(expected) } } "DimensionType object" should { "getIdentity must be " in { val identity = DimensionType.getIdentity(None, TypeOp.Int) identity.typeOp should be(TypeOp.Int) identity.id should be(DimensionType.IdentityName) val identity2 = DimensionType.getIdentity(Some(TypeOp.String), TypeOp.Int) identity2.typeOp should be(TypeOp.String) } "getIdentityField must be " in { val identity = DimensionType.getIdentityField(None, TypeOp.Int) identity.typeOp should be(TypeOp.Int) identity.id should be(DimensionType.IdentityFieldName) val identity2 = DimensionType.getIdentityField(Some(TypeOp.String), TypeOp.Int) identity2.typeOp should be(TypeOp.String) } "getTimestamp must be " in { val identity = DimensionType.getTimestamp(None, TypeOp.Int) identity.typeOp should be(TypeOp.Int) identity.id should be(DimensionType.TimestampName) val identity2 = DimensionType.getTimestamp(Some(TypeOp.String), TypeOp.Int) identity2.typeOp should be(TypeOp.String) } } }
Example 9
Source File: DimensionTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.sdk.pipeline.aggregation.cube import org.junit.runner.RunWith import org.scalatest._ import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class DimensionTest extends WordSpec with Matchers { "Dimension" should { val defaultDimensionType = new DimensionTypeMock(Map()) val dimension = Dimension("dim1", "eventKey", "identity", defaultDimensionType) val dimensionIdentity = Dimension("dim1", "identity", "identity", defaultDimensionType) val dimensionNotIdentity = Dimension("dim1", "key", "key", defaultDimensionType) "Return the associated identity precision name" in { val expected = "identity" val result = dimensionIdentity.getNamePrecision result should be(expected) } "Return the associated name precision name" in { val expected = "key" val result = dimensionNotIdentity.getNamePrecision result should be(expected) } "Return the associated precision name" in { val expected = "eventKey" val result = dimension.getNamePrecision result should be(expected) } "Compare function with other dimension must be less" in { val dimension2 = Dimension("dim2", "eventKey", "identity", defaultDimensionType) val expected = -1 val result = dimension.compare(dimension2) result should be(expected) } "Compare function with other dimension must be equal" in { val dimension2 = Dimension("dim1", "eventKey", "identity", defaultDimensionType) val expected = 0 val result = dimension.compare(dimension2) result should be(expected) } "Compare function with other dimension must be higher" in { val dimension2 = Dimension("dim0", "eventKey", "identity", defaultDimensionType) val expected = 1 val result = dimension.compare(dimension2) result should be(expected) } "classSuffix must be " in { val expected = "Field" val result = Dimension.FieldClassSuffix result should be(expected) } } }
Example 10
Source File: OutputTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.sdk.pipeline.output import com.stratio.sparta.sdk.pipeline.aggregation.cube.{Dimension, DimensionTypeMock, DimensionValue, DimensionValuesTime} import com.stratio.sparta.sdk.pipeline.transformation.OutputMock import org.apache.spark.sql.types._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpec} @RunWith(classOf[JUnitRunner]) class OutputTest extends WordSpec with Matchers { trait CommonValues { val timeDimension = "minute" val tableName = "table" val timestamp = 1L val defaultDimension = new DimensionTypeMock(Map()) val dimensionValuesT = DimensionValuesTime("testCube", Seq( DimensionValue(Dimension("dim1", "eventKey", "identity", defaultDimension), "value1"), DimensionValue(Dimension("dim2", "eventKey", "identity", defaultDimension), "value2"), DimensionValue(Dimension("minute", "eventKey", "identity", defaultDimension), 1L))) val dimensionValuesTFixed = DimensionValuesTime("testCube", Seq( DimensionValue(Dimension("dim1", "eventKey", "identity", defaultDimension), "value1"), DimensionValue(Dimension("minute", "eventKey", "identity", defaultDimension), 1L))) val outputName = "outputName" val output = new OutputMock(outputName, Map()) val outputOperation = new OutputMock(outputName, Map()) val outputProps = new OutputMock(outputName, Map()) } "Output" should { "Name must be " in new CommonValues { val expected = outputName val result = output.name result should be(expected) } "the spark geo field returned must be " in new CommonValues { val expected = StructField("field", ArrayType(DoubleType), false) val result = Output.defaultGeoField("field", false) result should be(expected) } "classSuffix must be " in { val expected = "Output" val result = Output.ClassSuffix result should be(expected) } } }
Example 11
Source File: InputTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.sdk.pipeline.input

import org.apache.spark.storage.StorageLevel
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class InputTest extends WordSpec with Matchers {

  "Input" should {
    val input = new InputMock(Map())
    val expected = StorageLevel.DISK_ONLY
    val result = input.storageLevel("DISK_ONLY")

    "Return the associated storageLevel" in {
      result should be(expected)
    }
  }

  "classSuffix must be " in {
    val expected = "Input"
    val result = Input.ClassSuffix
    result should be(expected)
  }
}
Example 12
Source File: ParserTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.sdk.pipeline.transformation import java.io.{Serializable => JSerializable} import org.apache.spark.sql.Row import org.apache.spark.sql.types.{StringType, StructField, StructType} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpec} @RunWith(classOf[JUnitRunner]) class ParserTest extends WordSpec with Matchers { "Parser" should { val parserTest = new ParserMock( 1, Some("input"), Seq("output"), StructType(Seq(StructField("some", StringType))), Map() ) "Order must be " in { val expected = 1 val result = parserTest.getOrder result should be(expected) } "Parse must be " in { val event = Row("value") val expected = Seq(event) val result = parserTest.parse(event) result should be(expected) } "checked fields not be contained in outputs must be " in { val keyMap = Map("field" -> "value") val expected = Map() val result = parserTest.checkFields(keyMap) result should be(expected) } "checked fields are contained in outputs must be " in { val keyMap = Map("output" -> "value") val expected = keyMap val result = parserTest.checkFields(keyMap) result should be(expected) } "classSuffix must be " in { val expected = "Parser" val result = Parser.ClassSuffix result should be(expected) } } }
Example 13
Source File: TypeConversionsTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.sdk.pipeline.schema

import com.stratio.sparta.sdk.pipeline.aggregation.cube.Precision
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class TypeConversionsTest extends WordSpec with Matchers {

  "TypeConversions" should {
    val typeConvesions = new TypeConversionsMock

    "typeOperation must be " in {
      val expected = TypeOp.Int
      val result = typeConvesions.defaultTypeOperation
      result should be(expected)
    }
    "operationProps must be " in {
      val expected = Map("typeOp" -> "string")
      val result = typeConvesions.operationProps
      result should be(expected)
    }
    "the operation type must be " in {
      val expected = Some(TypeOp.String)
      val result = typeConvesions.getTypeOperation
      result should be(expected)
    }
    "the detailed operation type must be " in {
      val expected = Some(TypeOp.String)
      val result = typeConvesions.getTypeOperation("string")
      result should be(expected)
    }
    "the precision type must be " in {
      val expected = Precision("precision", TypeOp.String, Map())
      val result = typeConvesions.getPrecision("precision", Some(TypeOp.String))
      result should be(expected)
    }
  }
}
Example 14
Source File: JsoneyStringTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.sdk.properties import org.json4s.jackson.JsonMethods._ import org.json4s.jackson.Serialization.write import org.json4s.{DefaultFormats, _} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpecLike} @RunWith(classOf[JUnitRunner]) class JsoneyStringTest extends WordSpecLike with Matchers { "A JsoneyString" should { "have toString equivalent to its internal string" in { assertResult("foo")(new JsoneyString("foo").toString) } "be deserialized if its JSON" in { implicit val json4sJacksonFormats = DefaultFormats + new JsoneyStringSerializer() val result = parse( """{ "foo": "bar" }""").extract[JsoneyString] assertResult(new JsoneyString( """{"foo":"bar"}"""))(result) } "be deserialized if it's a String" in { implicit val json4sJacksonFormats = DefaultFormats + new JsoneyStringSerializer() val result = parse("\"foo\"").extract[JsoneyString] assertResult(new JsoneyString("foo"))(result) } "be deserialized if it's an Int" in { implicit val json4sJacksonFormats = DefaultFormats + new JsoneyStringSerializer() val result = parse("1").extract[JsoneyString] assertResult(new JsoneyString("1"))(result) } "be serialized as JSON" in { implicit val json4sJacksonFormats = DefaultFormats + new JsoneyStringSerializer() var result = write(new JsoneyString("foo")) assertResult("\"foo\"")(result) result = write(new JsoneyString("{\"foo\":\"bar\"}")) assertResult("\"{\\\"foo\\\":\\\"bar\\\"}\"")(result) } "be deserialized if it's an JBool" in { implicit val json4sJacksonFormats = DefaultFormats + new JsoneyStringSerializer() val result = parse("true").extract[JsoneyString] assertResult(new JsoneyString("true"))(result) } "have toSeq equivalent to its internal string" in { assertResult(Seq("o"))(new JsoneyString("foo").toSeq) } } }
Example 15
Source File: SpartaClusterLauncherActorTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.serving.api.helpers

import com.stratio.sparta.serving.core.config.{SpartaConfigFactory, SpartaConfig}
import com.typesafe.config.ConfigFactory
import org.junit.runner.RunWith
import org.scalamock.scalatest._
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class SpartaClusterLauncherActorTest extends FlatSpec
  with MockFactory
  with ShouldMatchers
  with Matchers {

  it should "init SpartaConfig from a file with a configuration" in {
    val config = ConfigFactory.parseString(
      """
        |sparta {
        | testKey : "testValue"
        |}
      """.stripMargin)
    val spartaConfig = SpartaConfig.initConfig(node = "sparta", configFactory = SpartaConfigFactory(config))
    spartaConfig.get.getString("testKey") should be("testValue")
  }

  it should "init a config from a given config" in {
    val config = ConfigFactory.parseString(
      """
        |sparta {
        | testNode {
        |   testKey : "testValue"
        | }
        |}
      """.stripMargin)
    val spartaConfig = SpartaConfig.initConfig(node = "sparta", configFactory = SpartaConfigFactory(config))
    val testNodeConfig = SpartaConfig.initConfig("testNode", spartaConfig, SpartaConfigFactory(config))
    testNodeConfig.get.getString("testKey") should be("testValue")
  }
}
Example 16
Source File: ControllerActorTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.serving.api.actor import akka.actor.{ActorSystem, Props} import akka.testkit.{ImplicitSender, TestKit} import com.stratio.sparta.driver.service.StreamingContextService import com.stratio.sparta.serving.core.actor.{RequestActor, FragmentActor, StatusActor} import com.stratio.sparta.serving.core.config.SpartaConfig import com.stratio.sparta.serving.core.constants.AkkaConstant import org.apache.curator.framework.CuratorFramework import org.junit.runner.RunWith import org.scalamock.scalatest.MockFactory import org.scalatest._ import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class ControllerActorTest(_system: ActorSystem) extends TestKit(_system) with ImplicitSender with WordSpecLike with Matchers with BeforeAndAfterAll with MockFactory { SpartaConfig.initMainConfig() SpartaConfig.initApiConfig() val curatorFramework = mock[CuratorFramework] val statusActor = _system.actorOf(Props(new StatusActor(curatorFramework))) val executionActor = _system.actorOf(Props(new RequestActor(curatorFramework))) val streamingContextService = new StreamingContextService(curatorFramework) val fragmentActor = _system.actorOf(Props(new FragmentActor(curatorFramework))) val policyActor = _system.actorOf(Props(new PolicyActor(curatorFramework, statusActor))) val sparkStreamingContextActor = _system.actorOf( Props(new LauncherActor(streamingContextService, curatorFramework))) val pluginActor = _system.actorOf(Props(new PluginActor())) val configActor = _system.actorOf(Props(new ConfigActor())) def this() = this(ActorSystem("ControllerActorSpec", SpartaConfig.daemonicAkkaConfig)) implicit val actors = Map( AkkaConstant.StatusActorName -> statusActor, AkkaConstant.FragmentActorName -> fragmentActor, AkkaConstant.PolicyActorName -> policyActor, AkkaConstant.LauncherActorName -> sparkStreamingContextActor, AkkaConstant.PluginActorName -> pluginActor, AkkaConstant.ExecutionActorName -> executionActor, AkkaConstant.ConfigActorName -> configActor ) override def afterAll { TestKit.shutdownActorSystem(system) } "ControllerActor" should { "set up the controller actor that contains all sparta's routes without any error" in { _system.actorOf(Props(new ControllerActor(actors, curatorFramework))) } } }
Example 17
Source File: DriverActorTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.serving.api.actor import java.nio.file.{Files, Path} import akka.actor.{ActorSystem, Props} import akka.testkit.{DefaultTimeout, ImplicitSender, TestKit} import akka.util.Timeout import com.stratio.sparta.serving.api.actor.DriverActor.UploadDrivers import com.stratio.sparta.serving.core.config.{SpartaConfig, SpartaConfigFactory} import com.stratio.sparta.serving.core.models.SpartaSerializer import com.stratio.sparta.serving.core.models.files.{SpartaFile, SpartaFilesResponse} import com.typesafe.config.{Config, ConfigFactory} import org.junit.runner.RunWith import org.scalatest._ import org.scalatest.junit.JUnitRunner import org.scalatest.mock.MockitoSugar import spray.http.BodyPart import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.{Failure, Success} @RunWith(classOf[JUnitRunner]) class DriverActorTest extends TestKit(ActorSystem("PluginActorSpec")) with DefaultTimeout with ImplicitSender with WordSpecLike with Matchers with BeforeAndAfterAll with BeforeAndAfterEach with MockitoSugar with SpartaSerializer { val tempDir: Path = Files.createTempDirectory("test") tempDir.toFile.deleteOnExit() val localConfig: Config = ConfigFactory.parseString( s""" |sparta{ | api { | host = local | port= 7777 | } |} | |sparta.config.driverPackageLocation = "$tempDir" """.stripMargin) val fileList = Seq(BodyPart("reference.conf", "file")) override def beforeEach(): Unit = { SpartaConfig.initMainConfig(Option(localConfig), SpartaConfigFactory(localConfig)) SpartaConfig.initApiConfig() } override def afterAll: Unit = { shutdown() } override implicit val timeout: Timeout = Timeout(15 seconds) "DriverActor " must { "Not save files with wrong extension" in { val driverActor = system.actorOf(Props(new DriverActor())) driverActor ! UploadDrivers(fileList) expectMsgPF() { case SpartaFilesResponse(Success(f: Seq[SpartaFile])) => f.isEmpty shouldBe true } } "Not upload empty files" in { val driverActor = system.actorOf(Props(new DriverActor())) driverActor ! UploadDrivers(Seq.empty) expectMsgPF() { case SpartaFilesResponse(Failure(f)) => f.getMessage shouldBe "At least one file is expected" } } "Save a file" in { val driverActor = system.actorOf(Props(new DriverActor())) driverActor ! UploadDrivers(Seq(BodyPart("reference.conf", "file.jar"))) expectMsgPF() { case SpartaFilesResponse(Success(f: Seq[SpartaFile])) => f.head.fileName.endsWith("file.jar") shouldBe true } } } }
Example 18
Source File: PluginActorTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.serving.api.actor import java.nio.file.{Files, Path} import akka.actor.{ActorSystem, Props} import akka.testkit.{DefaultTimeout, ImplicitSender, TestKit} import akka.util.Timeout import com.stratio.sparta.serving.api.actor.PluginActor.{PluginResponse, UploadPlugins} import com.stratio.sparta.serving.api.constants.HttpConstant import com.stratio.sparta.serving.core.config.{SpartaConfig, SpartaConfigFactory} import com.stratio.sparta.serving.core.models.SpartaSerializer import com.stratio.sparta.serving.core.models.files.{SpartaFile, SpartaFilesResponse} import com.typesafe.config.{Config, ConfigFactory} import org.junit.runner.RunWith import org.scalatest._ import org.scalatest.junit.JUnitRunner import org.scalatest.mock.MockitoSugar import spray.http.BodyPart import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.{Failure, Success} @RunWith(classOf[JUnitRunner]) class PluginActorTest extends TestKit(ActorSystem("PluginActorSpec")) with DefaultTimeout with ImplicitSender with WordSpecLike with Matchers with BeforeAndAfterAll with BeforeAndAfterEach with MockitoSugar with SpartaSerializer { val tempDir: Path = Files.createTempDirectory("test") tempDir.toFile.deleteOnExit() val localConfig: Config = ConfigFactory.parseString( s""" |sparta{ | api { | host = local | port= 7777 | } |} | |sparta.config.pluginPackageLocation = "$tempDir" """.stripMargin) val fileList = Seq(BodyPart("reference.conf", "file")) override def beforeEach(): Unit = { SpartaConfig.initMainConfig(Option(localConfig), SpartaConfigFactory(localConfig)) SpartaConfig.initApiConfig() } override def afterAll: Unit = { shutdown() } override implicit val timeout: Timeout = Timeout(15 seconds) "PluginActor " must { "Not save files with wrong extension" in { val pluginActor = system.actorOf(Props(new PluginActor())) pluginActor ! UploadPlugins(fileList) expectMsgPF() { case SpartaFilesResponse(Success(f: Seq[SpartaFile])) => f.isEmpty shouldBe true } } "Not upload empty files" in { val pluginActor = system.actorOf(Props(new PluginActor())) pluginActor ! UploadPlugins(Seq.empty) expectMsgPF() { case SpartaFilesResponse(Failure(f)) => f.getMessage shouldBe "At least one file is expected" } } "Save a file" in { val pluginActor = system.actorOf(Props(new PluginActor())) pluginActor ! UploadPlugins(Seq(BodyPart("reference.conf", "file.jar"))) expectMsgPF() { case SpartaFilesResponse(Success(f: Seq[SpartaFile])) => f.head.fileName.endsWith("file.jar") shouldBe true } } } }
Example 19
Source File: CustomExceptionHandlerTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.serving.api.service.handler import akka.actor.ActorSystem import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpec} import spray.http.StatusCodes import spray.httpx.Json4sJacksonSupport import spray.routing.{Directives, HttpService, StandardRoute} import spray.testkit.ScalatestRouteTest import com.stratio.sparta.sdk.exception.MockException import com.stratio.sparta.serving.api.service.handler.CustomExceptionHandler._ import com.stratio.sparta.serving.core.exception.ServingCoreException import com.stratio.sparta.serving.core.models.{ErrorModel, SpartaSerializer} @RunWith(classOf[JUnitRunner]) class CustomExceptionHandlerTest extends WordSpec with Directives with ScalatestRouteTest with Matchers with Json4sJacksonSupport with HttpService with SpartaSerializer { def actorRefFactory: ActorSystem = system trait MyTestRoute { val exception: Throwable val route: StandardRoute = complete(throw exception) } def route(throwable: Throwable): StandardRoute = complete(throw throwable) "CustomExceptionHandler" should { "encapsulate a unknow error in an error model and response with a 500 code" in new MyTestRoute { val exception = new MockException Get() ~> sealRoute(route) ~> check { status should be(StatusCodes.InternalServerError) response.entity.asString should be(ErrorModel.toString(new ErrorModel("666", "unknown"))) } } "encapsulate a serving api error in an error model and response with a 400 code" in new MyTestRoute { val exception = ServingCoreException.create(ErrorModel.toString(new ErrorModel("333", "testing exception"))) Get() ~> sealRoute(route) ~> check { status should be(StatusCodes.NotFound) response.entity.asString should be(ErrorModel.toString(new ErrorModel("333", "testing exception"))) } } } }
Example 20
Source File: ConfigHttpServiceTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.serving.api.service.http import akka.actor.ActorRef import akka.testkit.TestProbe import com.stratio.sparta.serving.api.actor.ConfigActor import com.stratio.sparta.serving.api.actor.ConfigActor._ import com.stratio.sparta.serving.api.constants.HttpConstant import com.stratio.sparta.serving.core.config.{SpartaConfig, SpartaConfigFactory} import com.stratio.sparta.serving.core.constants.{AkkaConstant, AppConstant} import com.stratio.sparta.serving.core.models.dto.LoggedUserConstant import com.stratio.sparta.serving.core.models.frontend.FrontendConfiguration import org.junit.runner.RunWith import org.scalatest.WordSpec import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class ConfigHttpServiceTest extends WordSpec with ConfigHttpService with HttpServiceBaseTest{ val configActorTestProbe = TestProbe() val dummyUser = Some(LoggedUserConstant.AnonymousUser) override implicit val actors: Map[String, ActorRef] = Map( AkkaConstant.ConfigActorName -> configActorTestProbe.ref ) override val supervisor: ActorRef = testProbe.ref override def beforeEach(): Unit = { SpartaConfig.initMainConfig(Option(localConfig), SpartaConfigFactory(localConfig)) } protected def retrieveStringConfig(): FrontendConfiguration = FrontendConfiguration(AppConstant.DefaultFrontEndTimeout, Option(AppConstant.DefaultOauth2CookieName)) "ConfigHttpService.FindAll" should { "retrieve a FrontendConfiguration item" in { startAutopilot(ConfigResponse(retrieveStringConfig())) Get(s"/${HttpConstant.ConfigPath}") ~> routes(dummyUser) ~> check { testProbe.expectMsgType[ConfigActor.FindAll.type] responseAs[FrontendConfiguration] should equal(retrieveStringConfig()) } } } }
Example 21
Source File: AppStatusHttpServiceTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.serving.api.service.http

import akka.actor.ActorRef
import com.stratio.sparta.serving.api.constants.HttpConstant
import org.apache.curator.framework.CuratorFramework
import org.junit.runner.RunWith
import org.scalamock.scalatest.MockFactory
import org.scalatest.WordSpec
import org.scalatest.junit.JUnitRunner
import spray.http.StatusCodes

@RunWith(classOf[JUnitRunner])
class AppStatusHttpServiceTest extends WordSpec
  with AppStatusHttpService
  with HttpServiceBaseTest
  with MockFactory {

  override implicit val actors: Map[String, ActorRef] = Map()
  override val supervisor: ActorRef = testProbe.ref
  override val curatorInstance = mock[CuratorFramework]

  "AppStatusHttpService" should {
    "check the status of the server" in {
      Get(s"/${HttpConstant.AppStatus}") ~> routes() ~> check {
        status should be(StatusCodes.InternalServerError)
      }
    }
  }
}
Example 22
Source File: PluginsHttpServiceTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.serving.api.service.http import akka.actor.ActorRef import akka.testkit.TestProbe import com.stratio.sparta.serving.api.actor.PluginActor.{PluginResponse, UploadPlugins} import com.stratio.sparta.serving.api.constants.HttpConstant import com.stratio.sparta.serving.core.config.{SpartaConfig, SpartaConfigFactory} import com.stratio.sparta.serving.core.models.dto.LoggedUserConstant import com.stratio.sparta.serving.core.models.files.{SpartaFile, SpartaFilesResponse} import org.junit.runner.RunWith import org.scalatest.WordSpec import org.scalatest.junit.JUnitRunner import spray.http._ import scala.util.{Failure, Success} @RunWith(classOf[JUnitRunner]) class PluginsHttpServiceTest extends WordSpec with PluginsHttpService with HttpServiceBaseTest { override val supervisor: ActorRef = testProbe.ref val pluginTestProbe = TestProbe() val dummyUser = Some(LoggedUserConstant.AnonymousUser) override implicit val actors: Map[String, ActorRef] = Map.empty override def beforeEach(): Unit = { SpartaConfig.initMainConfig(Option(localConfig), SpartaConfigFactory(localConfig)) } "PluginsHttpService.upload" should { "Upload a file" in { val response = SpartaFilesResponse(Success(Seq(SpartaFile("", "", "", "")))) startAutopilot(response) Put(s"/${HttpConstant.PluginsPath}") ~> routes(dummyUser) ~> check { testProbe.expectMsgType[UploadPlugins] status should be(StatusCodes.OK) } } "Fail when service is not available" in { val response = SpartaFilesResponse(Failure(new IllegalArgumentException("Error"))) startAutopilot(response) Put(s"/${HttpConstant.PluginsPath}") ~> routes(dummyUser) ~> check { testProbe.expectMsgType[UploadPlugins] status should be(StatusCodes.InternalServerError) } } } }
Example 23
Source File: FileSystemOutputIT.scala From sparta with Apache License 2.0
package com.stratio.sparta.plugin.output.filesystem import java.io.File import com.stratio.sparta.plugin.TemporalSparkContext import com.stratio.sparta.plugin.output.fileSystem.FileSystemOutput import com.stratio.sparta.sdk.pipeline.output.{Output, OutputFormatEnum, SaveModeEnum} import org.apache.commons.io.FileUtils import org.apache.spark.sql._ import org.apache.spark.sql.types._ import org.junit.runner.RunWith import org.scalatest.Matchers import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class FileSystemOutputIT extends TemporalSparkContext with Matchers { val directory = getClass().getResource("/origin.txt") val parentFile = new File(directory.getPath).getParent val properties = Map(("path", parentFile + "/testRow"), ("outputFormat", "row")) val fields = StructType(StructField("name", StringType, false) :: StructField("age", IntegerType, false) :: StructField("year", IntegerType, true) :: Nil) val fsm = new FileSystemOutput("key", properties) "An object of type FileSystemOutput " should "have the same values as the properties Map" in { fsm.outputFormat should be(OutputFormatEnum.ROW) } private def dfGen(): DataFrame = { val sqlCtx = SparkSession.builder().config(sc.getConf).getOrCreate() val dataRDD = sc.parallelize(List(("user1", 23, 1993), ("user2", 26, 1990), ("user3", 21, 1995))) .map { case (name, age, year) => Row(name, age, year) } sqlCtx.createDataFrame(dataRDD, fields) } def fileExists(path: String): Boolean = new File(path).exists() "Given a DataFrame, a directory" should "be created with the data written inside" in { fsm.save(dfGen(), SaveModeEnum.Append, Map(Output.TableNameKey -> "test")) fileExists(fsm.path.get) should equal(true) } it should "exist with the given path and be deleted" in { if (fileExists(fsm.path.get)) FileUtils.deleteDirectory(new File(fsm.path.get)) fileExists(fsm.path.get) should equal(false) } val fsm2 = new FileSystemOutput("key", properties.updated("outputFormat", "json") .updated("path", parentFile + "/testJson")) "Given another DataFrame, a directory" should "be created with the data inside in JSON format" in { fsm2.outputFormat should be(OutputFormatEnum.JSON) fsm2.save(dfGen(), SaveModeEnum.Append, Map(Output.TableNameKey -> "test")) fileExists(fsm2.path.get) should equal(true) } it should "exist with the given path and be deleted" in { if (fileExists(s"${fsm2.path.get}/test")) FileUtils.deleteDirectory(new File(s"${fsm2.path.get}/test")) fileExists(s"${fsm2.path.get}/test") should equal(false) } }
Example 24
Source File: AvroOutputIT.scala From sparta with Apache License 2.0
package com.stratio.sparta.plugin.output.avro import java.sql.Timestamp import java.time.Instant import com.databricks.spark.avro._ import com.stratio.sparta.plugin.TemporalSparkContext import com.stratio.sparta.sdk.pipeline.output.{Output, SaveModeEnum} import org.apache.spark.sql.types._ import org.apache.spark.sql.{Row, SparkSession} import org.junit.runner.RunWith import org.scalatest._ import org.scalatest.junit.JUnitRunner import scala.reflect.io.File import scala.util.Random @RunWith(classOf[JUnitRunner]) class AvroOutputIT extends TemporalSparkContext with Matchers { trait CommonValues { val tmpPath: String = File.makeTemp().name val sparkSession = SparkSession.builder().config(sc.getConf).getOrCreate() val schema = StructType(Seq( StructField("name", StringType), StructField("age", IntegerType), StructField("minute", LongType) )) val data = sparkSession.createDataFrame(sc.parallelize(Seq( Row("Kevin", Random.nextInt, Timestamp.from(Instant.now).getTime), Row("Kira", Random.nextInt, Timestamp.from(Instant.now).getTime), Row("Ariadne", Random.nextInt, Timestamp.from(Instant.now).getTime) )), schema) } trait WithEventData extends CommonValues { val properties = Map("path" -> tmpPath) val output = new AvroOutput("avro-test", properties) } "AvroOutput" should "throw an exception when path is not present" in { an[Exception] should be thrownBy new AvroOutput("avro-test", Map.empty) } it should "throw an exception when empty path " in { an[Exception] should be thrownBy new AvroOutput("avro-test", Map("path" -> " ")) } it should "save a dataframe " in new WithEventData { output.save(data, SaveModeEnum.Append, Map(Output.TableNameKey -> "person")) val read = sparkSession.read.avro(s"$tmpPath/person") read.count should be(3) read should be eq data File(tmpPath).deleteRecursively File("spark-warehouse").deleteRecursively } }
Example 25
Source File: HttpOutputTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.plugin.output.http import com.stratio.sparta.plugin.TemporalSparkContext import com.stratio.sparta.sdk.pipeline.output.OutputFormatEnum import org.apache.spark.sql._ import org.apache.spark.sql.types._ import org.junit.runner.RunWith import org.scalatest.Matchers import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class HttpOutputTest extends TemporalSparkContext with Matchers { val properties = Map( "url" -> "https://httpbin.org/post", "delimiter" -> ",", "parameterName" -> "thisIsAKeyName", "readTimeOut" -> "5000", "outputFormat" -> "ROW", "postType" -> "body", "connTimeout" -> "6000" ) val fields = StructType(StructField("name", StringType, false) :: StructField("age", IntegerType, false) :: StructField("year", IntegerType, true) :: Nil) val OkHTTPResponse = 200 "An object of type RestOutput " should "have the same values as the properties Map" in { val rest = new HttpOutput("key", properties) rest.outputFormat should be(OutputFormatEnum.ROW) rest.readTimeout should be(5000) } it should "throw a NoSuchElementException" in { val properties2 = properties.updated("postType", "vooooooody") a[NoSuchElementException] should be thrownBy { new HttpOutput("keyName", properties2) } } private def dfGen(): DataFrame = { val sqlCtx = SparkSession.builder().config(sc.getConf).getOrCreate() val dataRDD = sc.parallelize(List(("user1", 23, 1993), ("user2", 26, 1990))).map { case (name, age, year) => Row(name, age, year) } sqlCtx.createDataFrame(dataRDD, fields) } val restMock1 = new HttpOutput("key", properties) "Given a DataFrame it" should "be parsed and send through a Raw data POST request" in { dfGen().collect().foreach(row => { assertResult(OkHTTPResponse)(restMock1.sendData(row.mkString(restMock1.delimiter)).code) }) } it should "return the same amount of responses as rows in the DataFrame" in { val size = dfGen().collect().map(row => restMock1.sendData(row.mkString(restMock1.delimiter)).code).size assertResult(dfGen().count())(size) } val restMock2 = new HttpOutput("key", properties.updated("postType", "parameter")) it should "be parsed and send as a POST request along with a parameter stated by properties.parameterKey " in { dfGen().collect().foreach(row => { assertResult(OkHTTPResponse)(restMock2.sendData(row.mkString(restMock2.delimiter)).code) }) } val restMock3 = new HttpOutput("key", properties.updated("outputFormat", "JSON")) "Given a DataFrame it" should "be sent as JSON through a Raw data POST request" in { dfGen().toJSON.collect().foreach(row => { assertResult(OkHTTPResponse)(restMock3.sendData(row).code) }) } val restMock4 = new HttpOutput("key", properties.updated("postType", "parameter").updated("format", "JSON")) it should "sent as a POST request along with a parameter stated by properties.parameterKey " in { dfGen().toJSON.collect().foreach(row => { assertResult(OkHTTPResponse)(restMock4.sendData(row).code) }) } }
Example 26
Source File: CassandraOutputTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.plugin.output.cassandra import java.io.{Serializable => JSerializable} import com.datastax.spark.connector.cql.CassandraConnector import com.stratio.sparta.sdk._ import com.stratio.sparta.sdk.properties.JsoneyString import org.apache.spark.sql.types.{StringType, StructField, StructType} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.mock.MockitoSugar import org.scalatest.{FlatSpec, Matchers} @RunWith(classOf[JUnitRunner]) class CassandraOutputTest extends FlatSpec with Matchers with MockitoSugar with AnswerSugar { val s = "sum" val properties = Map(("connectionHost", "127.0.0.1"), ("connectionPort", "9042")) "getSparkConfiguration" should "return a Seq with the configuration" in { val configuration = Map(("connectionHost", "127.0.0.1"), ("connectionPort", "9042")) val cass = CassandraOutput.getSparkConfiguration(configuration) cass should be(List(("spark.cassandra.connection.host", "127.0.0.1"), ("spark.cassandra.connection.port", "9042"))) } "getSparkConfiguration" should "return all cassandra-spark config" in { val config: Map[String, JSerializable] = Map( ("sparkProperties" -> JsoneyString( "[{\"sparkPropertyKey\":\"spark.cassandra.input.fetch.size_in_rows\",\"sparkPropertyValue\":\"2000\"}," + "{\"sparkPropertyKey\":\"spark.cassandra.input.split.size_in_mb\",\"sparkPropertyValue\":\"64\"}]")), ("anotherProperty" -> "true") ) val sparkConfig = CassandraOutput.getSparkConfiguration(config) sparkConfig.exists(_ == ("spark.cassandra.input.fetch.size_in_rows" -> "2000")) should be(true) sparkConfig.exists(_ == ("spark.cassandra.input.split.size_in_mb" -> "64")) should be(true) sparkConfig.exists(_ == ("anotherProperty" -> "true")) should be(false) } "getSparkConfiguration" should "not return cassandra-spark config" in { val config: Map[String, JSerializable] = Map( ("hadoopProperties" -> JsoneyString( "[{\"sparkPropertyKey\":\"spark.cassandra.input.fetch.size_in_rows\",\"sparkPropertyValue\":\"2000\"}," + "{\"sparkPropertyKey\":\"spark.cassandra.input.split.size_in_mb\",\"sparkPropertyValue\":\"64\"}]")), ("anotherProperty" -> "true") ) val sparkConfig = CassandraOutput.getSparkConfiguration(config) sparkConfig.exists(_ == ("spark.cassandra.input.fetch.size_in_rows" -> "2000")) should be(false) sparkConfig.exists(_ == ("spark.cassandra.input.split.size_in_mb" -> "64")) should be(false) sparkConfig.exists(_ == ("anotherProperty" -> "true")) should be(false) } }
Example 27
Source File: ElasticSearchOutputTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.plugin.output.elasticsearch

import com.stratio.sparta.sdk.properties.JsoneyString
import org.apache.spark.sql.types._
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ElasticSearchOutputTest extends FlatSpec with ShouldMatchers {

  trait BaseValues {
    final val localPort = 9200
    final val remotePort = 9300
    val output = getInstance()
    val outputMultipleNodes = new ElasticSearchOutput("ES-out",
      Map("nodes" -> new JsoneyString(
        s"""[{"node":"host-a","tcpPort":"$remotePort","httpPort":"$localPort"},{"node":"host-b",
           |"tcpPort":"9301","httpPort":"9201"}]""".stripMargin),
        "dateType" -> "long"))

    def getInstance(host: String = "localhost", httpPort: Int = localPort, tcpPort: Int = remotePort)
    : ElasticSearchOutput =
      new ElasticSearchOutput("ES-out",
        Map("nodes" -> new JsoneyString(s"""[{"node":"$host","httpPort":"$httpPort","tcpPort":"$tcpPort"}]"""),
          "clusterName" -> "elasticsearch"))
  }

  trait NodeValues extends BaseValues {
    val ipOutput = getInstance("127.0.0.1", localPort, remotePort)
    val ipv6Output = getInstance("0:0:0:0:0:0:0:1", localPort, remotePort)
    val remoteOutput = getInstance("dummy", localPort, remotePort)
  }

  trait TestingValues extends BaseValues {
    val indexNameType = "spartatable/sparta"
    val tableName = "spartaTable"
    val baseFields = Seq(StructField("string", StringType), StructField("int", IntegerType))
    val schema = StructType(baseFields)
    val extraFields = Seq(StructField("id", StringType, false), StructField("timestamp", LongType, false))
    val properties = Map("nodes" -> new JsoneyString(
      """[{"node":"localhost","httpPort":"9200","tcpPort":"9300"}]""".stripMargin),
      "dateType" -> "long",
      "clusterName" -> "elasticsearch")
    override val output = new ElasticSearchOutput("ES-out", properties)
    val dateField = StructField("timestamp", TimestampType, false)
    val expectedDateField = StructField("timestamp", LongType, false)
    val stringField = StructField("string", StringType)
    val expectedStringField = StructField("string", StringType)
  }

  trait SchemaValues extends BaseValues {
    val fields = Seq(
      StructField("long", LongType),
      StructField("double", DoubleType),
      StructField("decimal", DecimalType(10, 0)),
      StructField("int", IntegerType),
      StructField("boolean", BooleanType),
      StructField("date", DateType),
      StructField("timestamp", TimestampType),
      StructField("array", ArrayType(StringType)),
      StructField("map", MapType(StringType, IntegerType)),
      StructField("string", StringType),
      StructField("binary", BinaryType))
    val completeSchema = StructType(fields)
  }

  "ElasticSearchOutput" should "format properties" in new NodeValues with SchemaValues {
    output.httpNodes should be(Seq(("localhost", 9200)))
    outputMultipleNodes.httpNodes should be(Seq(("host-a", 9200), ("host-b", 9201)))
    output.clusterName should be("elasticsearch")
  }

  it should "parse correct index name type" in new TestingValues {
    output.indexNameType(tableName) should be(indexNameType)
  }

  it should "return a Seq of tuples (host,port) format" in new NodeValues {
    output.getHostPortConfs("nodes", "localhost", "9200", "node", "httpPort") should be(List(("localhost", 9200)))
    output.getHostPortConfs("nodes", "localhost", "9300", "node", "tcpPort") should be(List(("localhost", 9300)))
    outputMultipleNodes.getHostPortConfs("nodes", "localhost", "9200", "node", "httpPort") should be(List(
      ("host-a", 9200), ("host-b", 9201)))
    outputMultipleNodes.getHostPortConfs("nodes", "localhost", "9300", "node", "tcpPort") should be(List(
      ("host-a", 9300), ("host-b", 9301)))
  }
}
Example 28
Source File: CsvOutputIT.scala From sparta with Apache License 2.0
package com.stratio.sparta.plugin.output.csv import java.sql.Timestamp import java.time.Instant import com.databricks.spark.avro._ import com.stratio.sparta.plugin.TemporalSparkContext import com.stratio.sparta.sdk.pipeline.output.{Output, SaveModeEnum} import org.apache.spark.sql.types._ import org.apache.spark.sql.{Row, SparkSession} import org.junit.runner.RunWith import org.scalatest._ import org.scalatest.junit.JUnitRunner import scala.reflect.io.File import scala.util.Random @RunWith(classOf[JUnitRunner]) class CsvOutputIT extends TemporalSparkContext with Matchers { trait CommonValues { val tmpPath: String = File.makeTemp().name val sparkSession = SparkSession.builder().config(sc.getConf).getOrCreate() val schema = StructType(Seq( StructField("name", StringType), StructField("age", IntegerType), StructField("minute", LongType) )) val data = sparkSession.createDataFrame(sc.parallelize(Seq( Row("Kevin", Random.nextInt, Timestamp.from(Instant.now).getTime), Row("Kira", Random.nextInt, Timestamp.from(Instant.now).getTime), Row("Ariadne", Random.nextInt, Timestamp.from(Instant.now).getTime) )), schema) } trait WithEventData extends CommonValues { val properties = Map("path" -> tmpPath) val output = new CsvOutput("csv-test", properties) } "CsvOutput" should "throw an exception when path is not present" in { an[Exception] should be thrownBy new CsvOutput("csv-test", Map.empty) } it should "throw an exception when empty path " in { an[Exception] should be thrownBy new CsvOutput("csv-test", Map("path" -> " ")) } it should "save a dataframe " in new WithEventData { output.save(data, SaveModeEnum.Append, Map(Output.TableNameKey -> "person")) val read = sparkSession.read.csv(s"$tmpPath/person.csv") read.count should be(3) read should be eq data File(tmpPath).deleteRecursively File("spark-warehouse").deleteRecursively } }
Example 29
Source File: LastValueOperatorTest.scala From sparta with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.lastValue import java.util.Date import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator import org.apache.spark.sql.Row import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpec} @RunWith(classOf[JUnitRunner]) class LastValueOperatorTest extends WordSpec with Matchers { "LastValue operator" should { val initSchema = StructType(Seq( StructField("field1", IntegerType, false), StructField("field2", IntegerType, false), StructField("field3", IntegerType, false) )) val initSchemaFail = StructType(Seq( StructField("field2", IntegerType, false) )) "processMap must be " in { val inputField = new LastValueOperator("lastValue", initSchema, Map()) inputField.processMap(Row(1, 2)) should be(None) val inputFields2 = new LastValueOperator("lastValue", initSchemaFail, Map("inputField" -> "field1")) inputFields2.processMap(Row(1, 2)) should be(None) val inputFields3 = new LastValueOperator("lastValue", initSchema, Map("inputField" -> "field1")) inputFields3.processMap(Row(1, 2)) should be(Some(1)) val inputFields4 = new LastValueOperator("lastValue", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]")) inputFields4.processMap(Row(1, 2)) should be(Some(1L)) val inputFields5 = new LastValueOperator("lastValue", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]")) inputFields5.processMap(Row(1, 2)) should be(None) val inputFields6 = new LastValueOperator("lastValue", initSchema, Map("inputField" -> "field1", "filters" -> { "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," + "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]" })) inputFields6.processMap(Row(1, 2)) should be(None) } "processReduce must be " in { val inputFields = new LastValueOperator("lastValue", initSchema, Map()) inputFields.processReduce(Seq()) should be(None) val inputFields2 = new LastValueOperator("lastValue", initSchema, Map()) inputFields2.processReduce(Seq(Some(1), Some(2))) should be(Some(2)) val inputFields3 = new LastValueOperator("lastValue", initSchema, Map()) inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some("b")) } "associative process must be " in { val inputFields = new LastValueOperator("lastValue", initSchema, Map()) val resultInput = Seq((Operator.OldValuesKey, Some(1L)), (Operator.NewValuesKey, Some(1L)), (Operator.NewValuesKey, None)) inputFields.associativity(resultInput) should be(Some(1L)) val inputFields2 = new LastValueOperator("lastValue", initSchema, Map("typeOp" -> "int")) val resultInput2 = Seq((Operator.OldValuesKey, Some(1L)), (Operator.NewValuesKey, Some(1L))) inputFields2.associativity(resultInput2) should be(Some(1)) val inputFields3 = new LastValueOperator("lastValue", initSchema, Map("typeOp" -> null)) val resultInput3 = Seq((Operator.OldValuesKey, Some(1)), (Operator.NewValuesKey, Some(2))) inputFields3.associativity(resultInput3) should be(Some(2)) val inputFields4 = new LastValueOperator("lastValue", initSchema, Map()) val resultInput4 = Seq() inputFields4.associativity(resultInput4) should be(None) val inputFields5 = new LastValueOperator("lastValue", initSchema, Map()) val date = new Date() val resultInput5 = Seq((Operator.NewValuesKey, Some(date))) inputFields5.associativity(resultInput5) should be(Some(date)) } } }
Example 30
Source File: StddevOperatorTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.cube.operator.stddev

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class StddevOperatorTest extends WordSpec with Matchers {

  "Std dev operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new StddevOperator("stdev", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new StddevOperator("stdev", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new StddevOperator("stdev", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new StddevOperator("stdev", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new StddevOperator("stdev", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new StddevOperator("stdev", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new StddevOperator("stdev", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(2.8284271247461903))

      val inputFields3 = new StddevOperator("stdev", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(2.850438562747845))

      val inputFields4 = new StddevOperator("stdev", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new StddevOperator("stdev", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(
        Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some("2.850438562747845"))
    }

    "processReduce distinct must be " in {
      val inputFields = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(2.8284271247461903))

      val inputFields3 = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(2.850438562747845))

      val inputFields4 = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new StddevOperator("stdev", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(
        Seq(Some(1), Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some("2.850438562747845"))
    }
  }
}
Example 31
Source File: MedianOperatorTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.cube.operator.median import org.apache.spark.sql.Row import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpec} @RunWith(classOf[JUnitRunner]) class MedianOperatorTest extends WordSpec with Matchers { "Median operator" should { val initSchema = StructType(Seq( StructField("field1", IntegerType, false), StructField("field2", IntegerType, false), StructField("field3", IntegerType, false) )) val initSchemaFail = StructType(Seq( StructField("field2", IntegerType, false) )) "processMap must be " in { val inputField = new MedianOperator("median", initSchema, Map()) inputField.processMap(Row(1, 2)) should be(None) val inputFields2 = new MedianOperator("median", initSchemaFail, Map("inputField" -> "field1")) inputFields2.processMap(Row(1, 2)) should be(None) val inputFields3 = new MedianOperator("median", initSchema, Map("inputField" -> "field1")) inputFields3.processMap(Row(1, 2)) should be(Some(1)) val inputFields4 = new MedianOperator("median", initSchema, Map("inputField" -> "field1")) inputFields4.processMap(Row("1", 2)) should be(Some(1)) val inputFields6 = new MedianOperator("median", initSchema, Map("inputField" -> "field1")) inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5)) val inputFields7 = new MedianOperator("median", initSchema, Map("inputField" -> "field1")) inputFields7.processMap(Row(5L, 2)) should be(Some(5L)) val inputFields8 = new MedianOperator("median", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]")) inputFields8.processMap(Row(1, 2)) should be(Some(1L)) val inputFields9 = new MedianOperator("median", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]")) inputFields9.processMap(Row(1, 2)) should be(None) val inputFields10 = new MedianOperator("median", initSchema, Map("inputField" -> "field1", "filters" -> { "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," + "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]" })) inputFields10.processMap(Row(1, 2)) should be(None) } "processReduce must be " in { val inputFields = new MedianOperator("median", initSchema, Map()) inputFields.processReduce(Seq()) should be(Some(0d)) val inputFields2 = new MedianOperator("median", initSchema, Map()) inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(3d)) val inputFields3 = new MedianOperator("median", initSchema, Map()) inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(3)) val inputFields4 = new MedianOperator("median", initSchema, Map()) inputFields4.processReduce(Seq(None)) should be(Some(0d)) val inputFields5 = new MedianOperator("median", initSchema, Map("typeOp" -> "string")) inputFields5.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("3.0")) } "processReduce distinct must be " in { val inputFields = new MedianOperator("median", initSchema, Map("distinct" -> "true")) inputFields.processReduce(Seq()) should be(Some(0d)) val inputFields2 = new MedianOperator("median", initSchema, Map("distinct" -> "true")) inputFields2.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(2.5)) val inputFields3 = new MedianOperator("median", initSchema, Map("distinct" -> "true")) inputFields3.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), 
Some(6.5), Some(7.5))) should be(Some(3)) val inputFields4 = new MedianOperator("median", initSchema, Map("distinct" -> "true")) inputFields4.processReduce(Seq(None)) should be(Some(0d)) val inputFields5 = new MedianOperator("median", initSchema, Map("typeOp" -> "string", "distinct" -> "true")) inputFields5.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("2.5")) } } }
Example 32
Source File: ModeOperatorTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.cube.operator.mode import org.apache.spark.sql.Row import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpec} @RunWith(classOf[JUnitRunner]) class ModeOperatorTest extends WordSpec with Matchers { "Mode operator" should { val initSchema = StructType(Seq( StructField("field1", IntegerType, false), StructField("field2", IntegerType, false), StructField("field3", IntegerType, false) )) val initSchemaFail = StructType(Seq( StructField("field2", IntegerType, false) )) "processMap must be " in { val inputField = new ModeOperator("mode", initSchema, Map()) inputField.processMap(Row(1, 2)) should be(None) val inputFields2 = new ModeOperator("mode", initSchemaFail, Map("inputField" -> "field1")) inputFields2.processMap(Row(1, 2)) should be(None) val inputFields3 = new ModeOperator("mode", initSchema, Map("inputField" -> "field1")) inputFields3.processMap(Row(1, 2)) should be(Some(1)) val inputFields4 = new ModeOperator("mode", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]")) inputFields4.processMap(Row(1, 2)) should be(Some(1L)) val inputFields5 = new ModeOperator("mode", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]")) inputFields5.processMap(Row(1, 2)) should be(None) val inputFields6 = new ModeOperator("mode", initSchema, Map("inputField" -> "field1", "filters" -> { "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," + "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]" })) inputFields6.processMap(Row(1, 2)) should be(None) } "processReduce must be " in { val inputFields = new ModeOperator("mode", initSchema, Map()) inputFields.processReduce(Seq()) should be(Some(List())) val inputFields2 = new ModeOperator("mode", initSchema, Map()) inputFields2.processReduce(Seq(Some("hey"), Some("hey"), Some("hi"))) should be(Some(List("hey"))) val inputFields3 = new ModeOperator("mode", initSchema, Map()) inputFields3.processReduce(Seq(Some("1"), Some("1"), Some("4"))) should be(Some(List("1"))) val inputFields4 = new ModeOperator("mode", initSchema, Map()) inputFields4.processReduce(Seq( Some("1"), Some("1"), Some("4"), Some("4"), Some("4"), Some("4"))) should be(Some(List("4"))) val inputFields5 = new ModeOperator("mode", initSchema, Map()) inputFields5.processReduce(Seq( Some("1"), Some("1"), Some("2"), Some("2"), Some("4"), Some("4"))) should be(Some(List("1", "2", "4"))) val inputFields6 = new ModeOperator("mode", initSchema, Map()) inputFields6.processReduce(Seq( Some("1"), Some("1"), Some("2"), Some("2"), Some("4"), Some("4"), Some("5")) ) should be(Some(List("1", "2", "4"))) } } }
Example 33
Source File: RangeOperatorTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.cube.operator.range

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class RangeOperatorTest extends WordSpec with Matchers {

  "Range operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new RangeOperator("range", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new RangeOperator("range", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new RangeOperator("range", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new RangeOperator("range", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new RangeOperator("range", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new RangeOperator("range", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new RangeOperator("range", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(0))

      val inputFields3 = new RangeOperator("range", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(4))) should be(Some(3))

      val inputFields4 = new RangeOperator("range", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new RangeOperator("range", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("6.0"))
    }

    "processReduce distinct must be " in {
      val inputFields = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(0))

      val inputFields3 = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(4))) should be(Some(3))

      val inputFields4 = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new RangeOperator("range", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("6.0"))
    }
  }
}
Example 34
Source File: AccumulatorOperatorTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.cube.operator.accumulator import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator import org.apache.spark.sql.Row import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpec} @RunWith(classOf[JUnitRunner]) class AccumulatorOperatorTest extends WordSpec with Matchers { "Accumulator operator" should { val initSchema = StructType(Seq( StructField("field1", IntegerType, false), StructField("field2", IntegerType, false) )) val initSchemaFail = StructType(Seq( StructField("field2", IntegerType, false) )) "processMap must be " in { val inputField = new AccumulatorOperator("accumulator", initSchema, Map()) inputField.processMap(Row(1, 2)) should be(None) val inputFields2 = new AccumulatorOperator("accumulator", initSchemaFail, Map("inputField" -> "field1")) inputFields2.processMap(Row(1, 2)) should be(None) val inputFields3 = new AccumulatorOperator("accumulator", initSchema, Map("inputField" -> "field1")) inputFields3.processMap(Row(1, 2)) should be(Some(1)) val inputFields4 = new AccumulatorOperator("accumulator", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]")) inputFields4.processMap(Row(1, 2)) should be(Some(1L)) val inputFields5 = new AccumulatorOperator("accumulator", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":2}]")) inputFields5.processMap(Row(1, 2)) should be(None) val inputFields6 = new AccumulatorOperator("accumulator", initSchema, Map("inputField" -> "field1", "filters" -> { "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}," + "{\"field\":\"field2\", \"type\": \"<\", \"value\":2}]" })) inputFields6.processMap(Row(1, 2)) should be(None) } "processReduce must be " in { val inputFields = new AccumulatorOperator("accumulator", initSchema, Map()) inputFields.processReduce(Seq()) should be(Some(Seq())) val inputFields2 = new AccumulatorOperator("accumulator", initSchema, Map()) inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(Seq("1", "1"))) val inputFields3 = new AccumulatorOperator("accumulator", initSchema, Map()) inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some(Seq("a", "b"))) } "associative process must be " in { val inputFields = new AccumulatorOperator("accumulator", initSchema, Map()) val resultInput = Seq((Operator.OldValuesKey, Some(Seq(1L))), (Operator.NewValuesKey, Some(Seq(2L))), (Operator.NewValuesKey, None)) inputFields.associativity(resultInput) should be(Some(Seq("1", "2"))) val inputFields2 = new AccumulatorOperator("accumulator", initSchema, Map("typeOp" -> "arraydouble")) val resultInput2 = Seq((Operator.OldValuesKey, Some(Seq(1))), (Operator.NewValuesKey, Some(Seq(3)))) inputFields2.associativity(resultInput2) should be(Some(Seq(1d, 3d))) val inputFields3 = new AccumulatorOperator("accumulator", initSchema, Map("typeOp" -> null)) val resultInput3 = Seq((Operator.OldValuesKey, Some(Seq(1))), (Operator.NewValuesKey, Some(Seq(1)))) inputFields3.associativity(resultInput3) should be(Some(Seq("1", "1"))) } } }
Example 35
Source File: FirstValueOperatorTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.cube.operator.firstValue import java.util.Date import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator import org.apache.spark.sql.Row import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpec} @RunWith(classOf[JUnitRunner]) class FirstValueOperatorTest extends WordSpec with Matchers { "FirstValue operator" should { val initSchema = StructType(Seq( StructField("field1", IntegerType, false), StructField("field2", IntegerType, false), StructField("field3", IntegerType, false) )) val initSchemaFail = StructType(Seq( StructField("field2", IntegerType, false) )) "processMap must be " in { val inputField = new FirstValueOperator("firstValue", initSchema, Map()) inputField.processMap(Row(1, 2)) should be(None) val inputFields2 = new FirstValueOperator("firstValue", initSchemaFail, Map("inputField" -> "field1")) inputFields2.processMap(Row(1, 2)) should be(None) val inputFields3 = new FirstValueOperator("firstValue", initSchema, Map("inputField" -> "field1")) inputFields3.processMap(Row(1, 2)) should be(Some(1)) val inputFields4 = new FirstValueOperator("firstValue", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]")) inputFields4.processMap(Row(1, 2)) should be(Some(1L)) val inputFields5 = new FirstValueOperator("firstValue", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]")) inputFields5.processMap(Row(1, 2)) should be(None) val inputFields6 = new FirstValueOperator("firstValue", initSchema, Map("inputField" -> "field1", "filters" -> { "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," + "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]" })) inputFields6.processMap(Row(1, 2)) should be(None) } "processReduce must be " in { val inputFields = new FirstValueOperator("firstValue", initSchema, Map()) inputFields.processReduce(Seq()) should be(None) val inputFields2 = new FirstValueOperator("firstValue", initSchema, Map()) inputFields2.processReduce(Seq(Some(1), Some(2))) should be(Some(1)) val inputFields3 = new FirstValueOperator("firstValue", initSchema, Map()) inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some("a")) } "associative process must be " in { val inputFields = new FirstValueOperator("firstValue", initSchema, Map()) val resultInput = Seq((Operator.OldValuesKey, Some(1L)), (Operator.NewValuesKey, Some(1L)), (Operator.NewValuesKey, None)) inputFields.associativity(resultInput) should be(Some(1L)) val inputFields2 = new FirstValueOperator("firstValue", initSchema, Map("typeOp" -> "int")) val resultInput2 = Seq((Operator.OldValuesKey, Some(1L)), (Operator.NewValuesKey, Some(1L))) inputFields2.associativity(resultInput2) should be(Some(1)) val inputFields3 = new FirstValueOperator("firstValue", initSchema, Map("typeOp" -> null)) val resultInput3 = Seq((Operator.OldValuesKey, Some(1)), (Operator.NewValuesKey, Some(1)), (Operator.NewValuesKey, None)) inputFields3.associativity(resultInput3) should be(Some(1)) val inputFields4 = new FirstValueOperator("firstValue", initSchema, Map()) val resultInput4 = Seq() inputFields4.associativity(resultInput4) should be(None) val inputFields5 = new FirstValueOperator("firstValue", initSchema, Map()) val date = new Date() val resultInput5 = Seq((Operator.NewValuesKey, Some(date))) inputFields5.associativity(resultInput5) 
should be(Some(date)) } } }
Example 36
Source File: MeanAssociativeOperatorTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.cube.operator.mean import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator import org.apache.spark.sql.Row import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpec} @RunWith(classOf[JUnitRunner]) class MeanAssociativeOperatorTest extends WordSpec with Matchers { "Mean operator" should { val initSchema = StructType(Seq( StructField("field1", IntegerType, false), StructField("field2", IntegerType, false), StructField("field3", IntegerType, false) )) val initSchemaFail = StructType(Seq( StructField("field2", IntegerType, false) )) "processMap must be " in { val inputField = new MeanAssociativeOperator("avg", initSchema, Map()) inputField.processMap(Row(1, 2)) should be(None) val inputFields2 = new MeanAssociativeOperator("avg", initSchemaFail, Map("inputField" -> "field1")) inputFields2.processMap(Row(1, 2)) should be(None) val inputFields3 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1")) inputFields3.processMap(Row(1, 2)) should be(Some(1)) val inputFields4 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1")) inputFields4.processMap(Row("1", 2)) should be(Some(1)) val inputFields6 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1")) inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5)) val inputFields7 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1")) inputFields7.processMap(Row(5L, 2)) should be(Some(5L)) val inputFields8 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]")) inputFields8.processMap(Row(1, 2)) should be(Some(1L)) val inputFields9 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]")) inputFields9.processMap(Row(1, 2)) should be(None) val inputFields10 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1", "filters" -> { "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," + "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]" })) inputFields10.processMap(Row(1, 2)) should be(None) } "processReduce must be " in { val inputFields = new MeanAssociativeOperator("avg", initSchema, Map()) inputFields.processReduce(Seq()) should be(Some(List())) val inputFields2 = new MeanAssociativeOperator("avg", initSchema, Map()) inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be (Some(List(1.0, 1.0))) val inputFields3 = new MeanAssociativeOperator("avg", initSchema, Map()) inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), None)) should be(Some(List(1.0, 2.0, 3.0))) val inputFields4 = new MeanAssociativeOperator("avg", initSchema, Map()) inputFields4.processReduce(Seq(None)) should be(Some(List())) } "processReduce distinct must be " in { val inputFields = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true")) inputFields.processReduce(Seq()) should be(Some(List())) val inputFields2 = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true")) inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be(Some(List(1.0))) val inputFields3 = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true")) inputFields3.processReduce(Seq(Some(1), Some(3), Some(1), None)) should be(Some(List(1.0, 3.0))) val 
inputFields4 = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true")) inputFields4.processReduce(Seq(None)) should be(Some(List())) } "associative process must be " in { val inputFields = new MeanAssociativeOperator("avg", initSchema, Map()) val resultInput = Seq((Operator.OldValuesKey, Some(Map("count" -> 1d, "sum" -> 2d, "mean" -> 2d))), (Operator.NewValuesKey, None)) inputFields.associativity(resultInput) should be(Some(Map("count" -> 1.0, "sum" -> 2.0, "mean" -> 2.0))) val inputFields2 = new MeanAssociativeOperator("avg", initSchema, Map()) val resultInput2 = Seq((Operator.OldValuesKey, Some(Map("count" -> 1d, "sum" -> 2d, "mean" -> 2d))), (Operator.NewValuesKey, Some(Seq(1d)))) inputFields2.associativity(resultInput2) should be(Some(Map("sum" -> 3.0, "count" -> 2.0, "mean" -> 1.5))) } } }
Example 37
Source File: MeanOperatorTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.cube.operator.mean import org.apache.spark.sql.Row import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpec} @RunWith(classOf[JUnitRunner]) class MeanOperatorTest extends WordSpec with Matchers { "Mean operator" should { val initSchema = StructType(Seq( StructField("field1", IntegerType, false), StructField("field2", IntegerType, false), StructField("field3", IntegerType, false) )) val initSchemaFail = StructType(Seq( StructField("field2", IntegerType, false) )) "processMap must be " in { val inputField = new MeanOperator("avg", initSchema, Map()) inputField.processMap(Row(1, 2)) should be(None) val inputFields2 = new MeanOperator("avg", initSchemaFail, Map("inputField" -> "field1")) inputFields2.processMap(Row(1, 2)) should be(None) val inputFields3 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1")) inputFields3.processMap(Row(1, 2)) should be(Some(1)) val inputFields4 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1")) inputFields4.processMap(Row("1", 2)) should be(Some(1)) val inputFields6 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1")) inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5)) val inputFields7 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1")) inputFields7.processMap(Row(5L, 2)) should be(Some(5L)) val inputFields8 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]")) inputFields8.processMap(Row(1, 2)) should be(Some(1L)) val inputFields9 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]")) inputFields9.processMap(Row(1, 2)) should be(None) val inputFields10 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1", "filters" -> { "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," + "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]" })) inputFields10.processMap(Row(1, 2)) should be(None) } "processReduce must be " in { val inputFields = new MeanOperator("avg", initSchema, Map()) inputFields.processReduce(Seq()) should be(Some(0d)) val inputFields2 = new MeanOperator("avg", initSchema, Map()) inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be(Some(1)) val inputFields3 = new MeanOperator("avg", initSchema, Map()) inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), None)) should be(Some(2)) val inputFields4 = new MeanOperator("avg", initSchema, Map()) inputFields4.processReduce(Seq(None)) should be(Some(0d)) val inputFields5 = new MeanOperator("avg", initSchema, Map("typeOp" -> "string")) inputFields5.processReduce(Seq(Some(1), Some(1))) should be(Some("1.0")) } "processReduce distinct must be " in { val inputFields = new MeanOperator("avg", initSchema, Map("distinct" -> "true")) inputFields.processReduce(Seq()) should be(Some(0d)) val inputFields2 = new MeanOperator("avg", initSchema, Map("distinct" -> "true")) inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be(Some(1)) val inputFields3 = new MeanOperator("avg", initSchema, Map("distinct" -> "true")) inputFields3.processReduce(Seq(Some(1), Some(3), Some(1), None)) should be(Some(2)) val inputFields4 = new MeanOperator("avg", initSchema, Map("distinct" -> "true")) inputFields4.processReduce(Seq(None)) should be(Some(0d)) val inputFields5 = new 
MeanOperator("avg", initSchema, Map("typeOp" -> "string", "distinct" -> "true")) inputFields5.processReduce(Seq(Some(1), Some(1))) should be(Some("1.0")) } } }
Example 38
Source File: OperatorEntityCountTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.cube.operator.entityCount import java.io.{Serializable => JSerializable} import org.apache.spark.sql.Row import org.apache.spark.sql.types.{StringType, StructField, StructType} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpec} @RunWith(classOf[JUnitRunner]) class OperatorEntityCountTest extends WordSpec with Matchers { "EntityCount" should { val props = Map( "inputField" -> "inputField".asInstanceOf[JSerializable], "split" -> ",".asInstanceOf[JSerializable]) val schema = StructType(Seq(StructField("inputField", StringType))) val entityCount = new OperatorEntityCountMock("op1", schema, props) val inputFields = Row("hello,bye") "Return the associated precision name" in { val expected = Option(Seq("hello", "bye")) val result = entityCount.processMap(inputFields) result should be(expected) } "Return empty list" in { val expected = None val result = entityCount.processMap(Row()) result should be(expected) } } }
Example 39
Source File: EntityCountOperatorTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.cube.operator.entityCount import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator import org.apache.spark.sql.Row import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpec} @RunWith(classOf[JUnitRunner]) class EntityCountOperatorTest extends WordSpec with Matchers { "Entity Count Operator" should { val initSchema = StructType(Seq( StructField("field1", IntegerType, false), StructField("field2", IntegerType, false), StructField("field3", IntegerType, false) )) val initSchemaFail = StructType(Seq( StructField("field2", IntegerType, false) )) "processMap must be " in { val inputField = new EntityCountOperator("entityCount", initSchema, Map()) inputField.processMap(Row(1, 2)) should be(None) val inputFields2 = new EntityCountOperator("entityCount", initSchemaFail, Map("inputField" -> "field1")) inputFields2.processMap(Row(1, 2)) should be(None) val inputFields3 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1")) inputFields3.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo"))) val inputFields4 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "split" -> ",")) inputFields4.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo"))) val inputFields5 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "split" -> "-")) inputFields5.processMap(Row("hola-holo", 2)) should be(Some(Seq("hola", "holo"))) val inputFields6 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "split" -> ",")) inputFields6.processMap(Row("hola,holo adios", 2)) should be(Some(Seq("hola", "holo " + "adios"))) val inputFields7 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]")) inputFields7.processMap(Row("hola", 2)) should be(None) val inputFields8 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]", "split" -> " ")) inputFields8.processMap(Row("hola holo", 2)) should be(Some(Seq("hola", "holo"))) } "processReduce must be " in { val inputFields = new EntityCountOperator("entityCount", initSchema, Map()) inputFields.processReduce(Seq()) should be(Some(Seq())) val inputFields2 = new EntityCountOperator("entityCount", initSchema, Map()) inputFields2.processReduce(Seq(Some(Seq("hola", "holo")))) should be(Some(Seq("hola", "holo"))) val inputFields3 = new EntityCountOperator("entityCount", initSchema, Map()) inputFields3.processReduce(Seq(Some(Seq("hola", "holo", "hola")))) should be(Some(Seq("hola", "holo", "hola"))) } "associative process must be " in { val inputFields = new EntityCountOperator("entityCount", initSchema, Map()) val resultInput = Seq((Operator.OldValuesKey, Some(Map("hola" -> 1L, "holo" -> 1L))), (Operator.NewValuesKey, None)) inputFields.associativity(resultInput) should be(Some(Map("hola" -> 1L, "holo" -> 1L))) val inputFields2 = new EntityCountOperator("entityCount", initSchema, Map("typeOp" -> "int")) val resultInput2 = Seq((Operator.OldValuesKey, Some(Map("hola" -> 1L, "holo" -> 1L))), (Operator.NewValuesKey, Some(Seq("hola")))) inputFields2.associativity(resultInput2) should be(Some(Map())) val inputFields3 = new EntityCountOperator("entityCount", initSchema, Map("typeOp" -> 
null)) val resultInput3 = Seq((Operator.OldValuesKey, Some(Map("hola" -> 1L, "holo" -> 1L)))) inputFields3.associativity(resultInput3) should be(Some(Map("hola" -> 1L, "holo" -> 1L))) } } }
Example 40
Source File: SumOperatorTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.cube.operator.sum

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class SumOperatorTest extends WordSpec with Matchers {

  "Sum operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new SumOperator("sum", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new SumOperator("sum", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new SumOperator("sum", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new SumOperator("sum", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new SumOperator("sum", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new SumOperator("sum", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new SumOperator("sum", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(20d))

      val inputFields3 = new SumOperator("sum", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(20d))

      val inputFields4 = new SumOperator("sum", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))
    }

    "processReduce distinct must be " in {
      val inputFields = new SumOperator("sum", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new SumOperator("sum", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(1))) should be(Some(3d))
    }

    "associative process must be " in {
      val inputFields = new SumOperator("count", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(2d))

      val inputFields2 = new SumOperator("count", initSchema, Map("typeOp" -> "string"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields2.associativity(resultInput2) should be(Some("2.0"))
    }
  }
}
Example 41
Source File: FullTextOperatorTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.cube.operator.fullText import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator import org.apache.spark.sql.Row import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpec} @RunWith(classOf[JUnitRunner]) class FullTextOperatorTest extends WordSpec with Matchers { "FullText operator" should { val initSchema = StructType(Seq( StructField("field1", IntegerType, false), StructField("field2", IntegerType, false), StructField("field3", IntegerType, false) )) val initSchemaFail = StructType(Seq( StructField("field2", IntegerType, false) )) "processMap must be " in { val inputField = new FullTextOperator("fullText", initSchema, Map()) inputField.processMap(Row(1, 2)) should be(None) val inputFields2 = new FullTextOperator("fullText", initSchemaFail, Map("inputField" -> "field1")) inputFields2.processMap(Row(1, 2)) should be(None) val inputFields3 = new FullTextOperator("fullText", initSchema, Map("inputField" -> "field1")) inputFields3.processMap(Row(1, 2)) should be(Some(1)) val inputFields4 = new FullTextOperator("fullText", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]")) inputFields4.processMap(Row(1, 2)) should be(Some(1L)) val inputFields5 = new FullTextOperator("fullText", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]")) inputFields5.processMap(Row(1, 2)) should be(None) val inputFields6 = new FullTextOperator("fullText", initSchema, Map("inputField" -> "field1", "filters" -> { "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," + "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]" })) inputFields6.processMap(Row(1, 2)) should be(None) } "processReduce must be " in { val inputFields = new FullTextOperator("fullText", initSchema, Map()) inputFields.processReduce(Seq()) should be(Some("")) val inputFields2 = new FullTextOperator("fullText", initSchema, Map()) inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(s"1${Operator.SpaceSeparator}1")) val inputFields3 = new FullTextOperator("fullText", initSchema, Map()) inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some(s"a${Operator.SpaceSeparator}b")) } "associative process must be " in { val inputFields = new FullTextOperator("fullText", initSchema, Map()) val resultInput = Seq((Operator.OldValuesKey, Some(2)), (Operator.NewValuesKey, None)) inputFields.associativity(resultInput) should be(Some("2")) val inputFields2 = new FullTextOperator("fullText", initSchema, Map("typeOp" -> "arraystring")) val resultInput2 = Seq((Operator.OldValuesKey, Some(2)), (Operator.NewValuesKey, Some(1))) inputFields2.associativity(resultInput2) should be(Some(Seq(s"2${Operator.SpaceSeparator}1"))) val inputFields3 = new FullTextOperator("fullText", initSchema, Map("typeOp" -> null)) val resultInput3 = Seq((Operator.OldValuesKey, Some(2)), (Operator.OldValuesKey, Some(3))) inputFields3.associativity(resultInput3) should be(Some(s"2${Operator.SpaceSeparator}3")) } } }
Example 42
Source File: TotalEntityCountOperatorTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.cube.operator.totalEntityCount import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator import org.apache.spark.sql.Row import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpec} @RunWith(classOf[JUnitRunner]) class TotalEntityCountOperatorTest extends WordSpec with Matchers { "Entity Count Operator" should { val initSchema = StructType(Seq( StructField("field1", IntegerType, false), StructField("field2", IntegerType, false), StructField("field3", IntegerType, false) )) val initSchemaFail = StructType(Seq( StructField("field2", IntegerType, false) )) "processMap must be " in { val inputField = new TotalEntityCountOperator("totalEntityCount", initSchema, Map()) inputField.processMap(Row(1, 2)) should be(None) val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchemaFail, Map("inputField" -> "field1")) inputFields2.processMap(Row(1, 2)) should be(None) val inputFields3 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1")) inputFields3.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo"))) val inputFields4 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1", "split" -> ",")) inputFields4.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo"))) val inputFields5 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1", "split" -> "-")) inputFields5.processMap(Row("hola-holo", 2)) should be(Some(Seq("hola", "holo"))) val inputFields6 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1", "split" -> ",")) inputFields6.processMap(Row("hola,holo adios", 2)) should be(Some(Seq("hola", "holo " + "adios"))) val inputFields7 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]")) inputFields7.processMap(Row("hola", 2)) should be(None) val inputFields8 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]", "split" -> " ")) inputFields8.processMap(Row("hola holo", 2)) should be (Some(Seq("hola", "holo"))) } "processReduce must be " in { val inputFields = new TotalEntityCountOperator("totalEntityCount", initSchema, Map()) inputFields.processReduce(Seq()) should be(Some(0L)) val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map()) inputFields2.processReduce(Seq(Some(Seq("hola", "holo")))) should be(Some(2L)) val inputFields3 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map()) inputFields3.processReduce(Seq(Some(Seq("hola", "holo", "hola")))) should be(Some(3L)) val inputFields4 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map()) inputFields4.processReduce(Seq(None)) should be(Some(0L)) } "processReduce distinct must be " in { val inputFields = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("distinct" -> "true")) inputFields.processReduce(Seq()) should be(Some(0L)) val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("distinct" -> "true")) inputFields2.processReduce(Seq(Some(Seq("hola", "holo", "hola")))) should be(Some(2L)) } "associative process must be " in { val inputFields = new 
TotalEntityCountOperator("totalEntityCount", initSchema, Map()) val resultInput = Seq((Operator.OldValuesKey, Some(2)), (Operator.NewValuesKey, None)) inputFields.associativity(resultInput) should be(Some(2)) val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("typeOp" -> "int")) val resultInput2 = Seq((Operator.OldValuesKey, Some(2)), (Operator.NewValuesKey, Some(1))) inputFields2.associativity(resultInput2) should be(Some(3)) val inputFields3 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("typeOp" -> null)) val resultInput3 = Seq((Operator.OldValuesKey, Some(2))) inputFields3.associativity(resultInput3) should be(Some(2)) } } }
Example 43
Source File: HierarchyFieldTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.cube.field.hierarchy import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.prop.TableDrivenPropertyChecks import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Matchers, WordSpecLike} @RunWith(classOf[JUnitRunner]) class HierarchyFieldTest extends WordSpecLike with Matchers with BeforeAndAfter with BeforeAndAfterAll with TableDrivenPropertyChecks { var hbs: Option[HierarchyField] = _ before { hbs = Some(new HierarchyField()) } after { hbs = None } "A HierarchyDimension" should { "In default implementation, get 4 precisions for all precision sizes" in { val precisionLeftToRight = hbs.get.precisionValue(HierarchyField.LeftToRightName, "") val precisionRightToLeft = hbs.get.precisionValue(HierarchyField.RightToLeftName, "") val precisionLeftToRightWithWildCard = hbs.get.precisionValue(HierarchyField.LeftToRightWithWildCardName, "") val precisionRightToLeftWithWildCard = hbs.get.precisionValue(HierarchyField.RightToLeftWithWildCardName, "") precisionLeftToRight._1.id should be(HierarchyField.LeftToRightName) precisionRightToLeft._1.id should be(HierarchyField.RightToLeftName) precisionLeftToRightWithWildCard._1.id should be(HierarchyField.LeftToRightWithWildCardName) precisionRightToLeftWithWildCard._1.id should be(HierarchyField.RightToLeftWithWildCardName) } "In default implementation, every proposed combination should be ok" in { val data = Table( ("i", "o"), ("google.com", Seq("google.com", "*.com", "*")) ) forAll(data) { (i: String, o: Seq[String]) => val result = hbs.get.precisionValue(HierarchyField.LeftToRightWithWildCardName, i) assertResult(o)(result._2) } } "In reverse implementation, every proposed combination should be ok" in { hbs = Some(new HierarchyField()) val data = Table( ("i", "o"), ("com.stratio.sparta", Seq("com.stratio.sparta", "com.stratio.*", "com.*", "*")) ) forAll(data) { (i: String, o: Seq[String]) => val result = hbs.get.precisionValue(HierarchyField.RightToLeftWithWildCardName, i.asInstanceOf[Any]) assertResult(o)(result._2) } } "In reverse implementation without wildcards, every proposed combination should be ok" in { hbs = Some(new HierarchyField()) val data = Table( ("i", "o"), ("com.stratio.sparta", Seq("com.stratio.sparta", "com.stratio", "com", "*")) ) forAll(data) { (i: String, o: Seq[String]) => val result = hbs.get.precisionValue(HierarchyField.RightToLeftName, i.asInstanceOf[Any]) assertResult(o)(result._2) } } "In non-reverse implementation without wildcards, every proposed combination should be ok" in { hbs = Some(new HierarchyField()) val data = Table( ("i", "o"), ("google.com", Seq("google.com", "com", "*")) ) forAll(data) { (i: String, o: Seq[String]) => val result = hbs.get.precisionValue(HierarchyField.LeftToRightName, i.asInstanceOf[Any]) assertResult(o)(result._2) } } } }
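HierarchyFieldTest above combines the JUnit runner with ScalaTest's table-driven property checks: each Table row pairs an input with its expected output and forAll replays the assertion for every row. A stripped-down sketch of just that combination (the class name and table contents below are illustrative, not taken from Sparta):

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.prop.TableDrivenPropertyChecks
import org.scalatest.{Matchers, WordSpecLike}

@RunWith(classOf[JUnitRunner])
class TableDrivenRunWithSpec extends WordSpecLike with Matchers with TableDrivenPropertyChecks {

  "A table-driven check" should {
    "evaluate every row of the table" in {
      // Each row pairs an input with its expected result, as in HierarchyFieldTest.
      val data = Table(
        ("input", "expected"),
        ("a.b.c", 3),
        ("a.b", 2)
      )
      forAll(data) { (input: String, expected: Int) =>
        input.split('.').length should be(expected)
      }
    }
  }
}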
Example 44
Source File: DateTimeFieldTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.cube.field.datetime import java.io.{Serializable => JSerializable} import java.util.Date import com.stratio.sparta.sdk.pipeline.schema.TypeOp import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpecLike} @RunWith(classOf[JUnitRunner]) class DateTimeFieldTest extends WordSpecLike with Matchers { val dateTimeDimension = new DateTimeField(Map("second" -> "long", "minute" -> "date", "typeOp" -> "datetime")) "A DateTimeDimension" should { "In default implementation, get 6 dimensions for a specific time" in { val newDate = new Date() val precision5s = dateTimeDimension.precisionValue("5s", newDate.asInstanceOf[JSerializable]) val precision10s = dateTimeDimension.precisionValue("10s", newDate.asInstanceOf[JSerializable]) val precision15s = dateTimeDimension.precisionValue("15s", newDate.asInstanceOf[JSerializable]) val precisionSecond = dateTimeDimension.precisionValue("second", newDate.asInstanceOf[JSerializable]) val precisionMinute = dateTimeDimension.precisionValue("minute", newDate.asInstanceOf[JSerializable]) val precisionHour = dateTimeDimension.precisionValue("hour", newDate.asInstanceOf[JSerializable]) val precisionDay = dateTimeDimension.precisionValue("day", newDate.asInstanceOf[JSerializable]) val precisionMonth = dateTimeDimension.precisionValue("month", newDate.asInstanceOf[JSerializable]) val precisionYear = dateTimeDimension.precisionValue("year", newDate.asInstanceOf[JSerializable]) precision5s._1.id should be("5s") precision10s._1.id should be("10s") precision15s._1.id should be("15s") precisionSecond._1.id should be("second") precisionMinute._1.id should be("minute") precisionHour._1.id should be("hour") precisionDay._1.id should be("day") precisionMonth._1.id should be("month") precisionYear._1.id should be("year") } "Each precision dimension have their output type, second must be long, minute must be date, others datetime" in { dateTimeDimension.precision("5s").typeOp should be(TypeOp.DateTime) dateTimeDimension.precision("10s").typeOp should be(TypeOp.DateTime) dateTimeDimension.precision("15s").typeOp should be(TypeOp.DateTime) dateTimeDimension.precision("second").typeOp should be(TypeOp.Long) dateTimeDimension.precision("minute").typeOp should be(TypeOp.Date) dateTimeDimension.precision("day").typeOp should be(TypeOp.DateTime) dateTimeDimension.precision("month").typeOp should be(TypeOp.DateTime) dateTimeDimension.precision("year").typeOp should be(TypeOp.DateTime) dateTimeDimension.precision(DateTimeField.TimestampPrecision.id).typeOp should be(TypeOp.Timestamp) } } }
Example 45
Source File: DefaultFieldTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.cube.field.defaultField import com.stratio.sparta.plugin.default.DefaultField import com.stratio.sparta.sdk.pipeline.aggregation.cube.{DimensionType, Precision} import com.stratio.sparta.sdk.pipeline.schema.TypeOp import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpecLike} @RunWith(classOf[JUnitRunner]) class DefaultFieldTest extends WordSpecLike with Matchers { val defaultDimension: DefaultField = new DefaultField(Map("typeOp" -> "int")) "A DefaultDimension" should { "In default implementation, get one precisions for a specific time" in { val precision: (Precision, Any) = defaultDimension.precisionValue("", "1".asInstanceOf[Any]) precision._2 should be(1) precision._1.id should be(DimensionType.IdentityName) } "The precision must be int" in { defaultDimension.precision(DimensionType.IdentityName).typeOp should be(TypeOp.Int) } } }
Example 46
Source File: SocketInputTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.input.socket import java.io.{Serializable => JSerializable} import org.junit.runner.RunWith import org.scalatest._ import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class SocketInputTest extends WordSpec { "A SocketInput" should { "instantiate successfully with parameters" in { new SocketInput(Map("hostname" -> "localhost", "port" -> 9999).mapValues(_.asInstanceOf[JSerializable])) } "fail without parameters" in { intercept[IllegalStateException] { new SocketInput(Map()) } } "fail with bad port argument" in { intercept[IllegalStateException] { new SocketInput(Map("hostname" -> "localhost", "port" -> "BADPORT").mapValues(_.asInstanceOf[JSerializable])) } } } }
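SocketInputTest asserts failure cases with ScalaTest's intercept, which runs a block and succeeds only if it throws the expected exception type. A minimal sketch of that pattern, shown on a plain helper rather than on SocketInput (parsePort and the class name are illustrative only):

import org.junit.runner.RunWith
import org.scalatest.WordSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class InterceptPatternSpec extends WordSpec {

  // Illustrative helper standing in for an input that validates its configuration.
  private def parsePort(value: String): Int =
    try value.toInt
    catch { case _: NumberFormatException => throw new IllegalStateException(s"Invalid port: $value") }

  "intercept" should {
    "capture the expected exception when configuration is invalid" in {
      intercept[IllegalStateException] {
        parsePort("BADPORT")
      }
    }
    "not interfere with valid configuration" in {
      assert(parsePort("9999") == 9999)
    }
  }
}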
Example 47
Source File: TwitterJsonInputTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.input.twitter import java.io.{Serializable => JSerializable} import org.junit.runner.RunWith import org.scalatest._ import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TwitterJsonInputTest extends WordSpec { "A TwitterInput" should { "fail without parameters" in { intercept[IllegalStateException] { new TwitterJsonInput(Map()) } } "fail with bad arguments argument" in { intercept[IllegalStateException] { new TwitterJsonInput(Map("hostname" -> "localhost", "port" -> "BADPORT") .mapValues(_.asInstanceOf[JSerializable])) } } } }
Example 48
Source File: RabbitMQInputIT.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.input.rabbitmq import java.util.UUID import akka.pattern.{ask, gracefulStop} import com.github.sstone.amqp.Amqp._ import com.github.sstone.amqp.{Amqp, ChannelOwner, ConnectionOwner, Consumer} import com.rabbitmq.client.ConnectionFactory import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import scala.concurrent.Await @RunWith(classOf[JUnitRunner]) class RabbitMQInputIT extends RabbitIntegrationSpec { val queueName = s"$configQueueName-${this.getClass.getName}-${UUID.randomUUID().toString}" def initRabbitMQ(): Unit = { val connFactory = new ConnectionFactory() connFactory.setUri(RabbitConnectionURI) val conn = system.actorOf(ConnectionOwner.props(connFactory, RabbitTimeOut)) val producer = ConnectionOwner.createChildActor( conn, ChannelOwner.props(), timeout = RabbitTimeOut, name = Some("RabbitMQ.producer") ) val queue = QueueParameters( name = queueName, passive = false, exclusive = false, durable = true, autodelete = false ) Amqp.waitForConnection(system, conn, producer).await() val deleteQueueResult = producer ? DeleteQueue(queueName) Await.result(deleteQueueResult, RabbitTimeOut) val createQueueResult = producer ? DeclareQueue(queue) Await.result(createQueueResult, RabbitTimeOut) //Send some messages to the queue val results = for (register <- 1 to totalRegisters) yield producer ? Publish( exchange = "", key = queueName, body = register.toString.getBytes ) results.map(result => Await.result(result, RabbitTimeOut)) conn ! Close() Await.result(gracefulStop(conn, RabbitTimeOut), RabbitTimeOut * 2) Await.result(gracefulStop(consumer, RabbitTimeOut), RabbitTimeOut * 2) } "RabbitMQInput " should { "Read all the records" in { val props = Map( "hosts" -> hosts, "queueName" -> queueName) val input = new RabbitMQInput(props) val distributedStream = input.initStream(ssc.get, DefaultStorageLevel) val totalEvents = ssc.get.sparkContext.accumulator(0L, "Number of events received") // Fires each time the configured window has passed. distributedStream.foreachRDD(rdd => { if (!rdd.isEmpty()) { val count = rdd.count() // Do something with this message log.info(s"EVENTS COUNT : $count") totalEvents.add(count) } else log.info("RDD is empty") log.info(s"TOTAL EVENTS : $totalEvents") }) ssc.get.start() // Start the computation ssc.get.awaitTerminationOrTimeout(SparkTimeOut) // Wait for the computation to terminate totalEvents.value should ===(totalRegisters.toLong) } } }
Example 49
Source File: MessageHandlerTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.input.rabbitmq.handler

import com.rabbitmq.client.QueueingConsumer.Delivery
import com.stratio.sparta.plugin.input.rabbitmq.handler.MessageHandler.{ByteArrayMessageHandler, StringMessageHandler}
import org.junit.runner.RunWith
import org.mockito.Mockito._
import org.scalatest.junit.JUnitRunner
import org.scalatest.mock.MockitoSugar
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class MessageHandlerTest extends WordSpec with Matchers with MockitoSugar {

  val message = "This is the message for testing"

  "RabbitMQ MessageHandler Factory " should {

    "Get correct handler for string with a map " in {
      MessageHandler(Map(MessageHandler.KeyDeserializer -> "")) should matchPattern { case StringMessageHandler => }
      MessageHandler(Map.empty[String, String]) should matchPattern { case StringMessageHandler => }
      MessageHandler(Map(MessageHandler.KeyDeserializer -> "badInput")) should matchPattern { case StringMessageHandler => }
      MessageHandler(Map(MessageHandler.KeyDeserializer -> "arraybyte")) should matchPattern { case ByteArrayMessageHandler => }
    }

    "Get correct handler for string " in {
      val result = MessageHandler("string")
      result should matchPattern { case StringMessageHandler => }
    }

    "Get correct handler for arraybyte " in {
      val result = MessageHandler("arraybyte")
      result should matchPattern { case ByteArrayMessageHandler => }
    }

    "Get correct handler for empty input " in {
      val result = MessageHandler("")
      result should matchPattern { case StringMessageHandler => }
    }

    "Get correct handler for bad input " in {
      val result = MessageHandler("badInput")
      result should matchPattern { case StringMessageHandler => }
    }
  }

  "StringMessageHandler " should {
    "Handle strings" in {
      val delivery = mock[Delivery]
      when(delivery.getBody).thenReturn(message.getBytes)
      val result = StringMessageHandler.handler(delivery)
      verify(delivery, times(1)).getBody
      result.getString(0) shouldBe message
    }
  }

  "ByteArrayMessageHandler " should {
    "Handle bytes" in {
      val delivery = mock[Delivery]
      when(delivery.getBody).thenReturn(message.getBytes)
      val result = ByteArrayMessageHandler.handler(delivery)
      verify(delivery, times(1)).getBody
      result.getAs[Array[Byte]](0) shouldBe message.getBytes
    }
  }
}
Example 50
Source File: HostPortZkTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.input.kafka

import java.io.Serializable

import com.stratio.sparta.sdk.properties.JsoneyString
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class HostPortZkTest extends WordSpec with Matchers {

  class KafkaTestInput(val properties: Map[String, Serializable]) extends KafkaBase

  "getHostPortZk" should {

    "return a chain (zookeper:conection , host:port)" in {
      val conn = """[{"host": "localhost", "port": "2181"}]"""
      val props = Map("zookeeper.connect" -> JsoneyString(conn), "zookeeper.path" -> "")
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181"))
    }

    "return a chain (zookeper:conection , host:port, zookeeper.path:path)" in {
      val conn = """[{"host": "localhost", "port": "2181"}]"""
      val props = Map("zookeeper.connect" -> JsoneyString(conn), "zookeeper.path" -> "/test")
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181/test"))
    }

    "return a chain (zookeper:conection , host:port,host:port,host:port)" in {
      val conn =
        """[{"host": "localhost", "port": "2181"},{"host": "localhost", "port": "2181"},
          |{"host": "localhost", "port": "2181"}]""".stripMargin
      val props = Map("zookeeper.connect" -> JsoneyString(conn))
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181,localhost:2181,localhost:2181"))
    }

    "return a chain (zookeper:conection , host:port,host:port,host:port, zookeeper.path:path)" in {
      val conn =
        """[{"host": "localhost", "port": "2181"},{"host": "localhost", "port": "2181"},
          |{"host": "localhost", "port": "2181"}]""".stripMargin
      val props = Map("zookeeper.connect" -> JsoneyString(conn), "zookeeper.path" -> "/test")
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181,localhost:2181,localhost:2181/test"))
    }

    "return a chain with default port (zookeper:conection , host: defaultport)" in {
      val props = Map("foo" -> "var")
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181"))
    }

    "return a chain with default port (zookeper:conection , host: defaultport, zookeeper.path:path)" in {
      val props = Map("zookeeper.path" -> "/test")
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181/test"))
    }

    "return a chain with default host and default porty (zookeeper.connect: ," +
      "defaultHost: defaultport," +
      "zookeeper.path:path)" in {
      val props = Map("foo" -> "var")
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181"))
    }
  }
}
Example 51
Source File: MorphlinesParserTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.transformation.morphline

import java.io.Serializable

import com.stratio.sparta.sdk.pipeline.input.Input
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Matchers, WordSpecLike}

@RunWith(classOf[JUnitRunner])
class MorphlinesParserTest extends WordSpecLike with Matchers with BeforeAndAfter with BeforeAndAfterAll {

  val morphlineConfig = """
  id : test1
  importCommands : ["org.kitesdk.**"]
  commands: [
  {
      readJson {}
  }
  {
      extractJsonPaths {
          paths : {
              col1 : /col1
              col2 : /col2
          }
      }
  }
  {
    java {
      code : "return child.process(record);"
    }
  }
  {
      removeFields {
          blacklist:["literal:_attachment_body"]
      }
  }
  ]
  """
  val inputField = Some(Input.RawDataKey)
  val outputsFields = Seq("col1", "col2")
  val props: Map[String, Serializable] = Map("morphline" -> morphlineConfig)

  val schema = StructType(Seq(StructField("col1", StringType), StructField("col2", StringType)))

  val parser = new MorphlinesParser(1, inputField, outputsFields, schema, props)

  "A MorphlinesParser" should {

    "parse a simple json" in {
      // Note: the input value for col2 is "world" so that it matches the expected row below.
      val simpleJson =
        """{
          "col1":"hello",
          "col2":"world"
          }
        """
      val input = Row(simpleJson)
      val result = parser.parse(input)
      val expected = Seq(Row(simpleJson, "hello", "world"))

      result should be eq (expected)
    }

    "parse a simple json removing raw" in {
      val simpleJson =
        """{
          "col1":"hello",
          "col2":"world"
          }
        """
      val input = Row(simpleJson)
      val result = parser.parse(input)
      val expected = Seq(Row("hello", "world"))

      result should be eq (expected)
    }

    "exclude not configured fields" in {
      val simpleJson =
        """{
          "col1":"hello",
          "col2":"world",
          "col3":"!"
          }
        """
      val input = Row(simpleJson)
      val result = parser.parse(input)
      val expected = Seq(Row(simpleJson, "hello", "world"))

      result should be eq (expected)
    }
  }
}
Example 52
Source File: DateTimeParserTest.scala From sparta with Apache License 2.0 | 5 votes |
package com.stratio.sparta.plugin.transformation.datetime

import com.stratio.sparta.sdk.properties.JsoneyString
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpecLike}

@RunWith(classOf[JUnitRunner])
class DateTimeParserTest extends WordSpecLike with Matchers {

  val inputField = Some("ts")
  val outputsFields = Seq("ts")

  //scalastyle:off
  "A DateTimeParser" should {
    "parse unixMillis to string" in {
      val input = Row(1416330788000L)
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "unixMillis"))
          .parse(input)

      val expected = Seq(Row(1416330788000L, "1416330788000"))

      assertResult(result)(expected)
    }

    "parse unix to string" in {
      val input = Row(1416330788)
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "unix"))
          .parse(input)

      val expected = Seq(Row(1416330788, "1416330788000"))

      assertResult(result)(expected)
    }

    "parse unix to string removing raw" in {
      val input = Row(1416330788)
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema,
          Map("inputFormat" -> "unix", "removeInputField" -> JsoneyString.apply("true")))
          .parse(input)

      val expected = Seq(Row("1416330788000"))

      assertResult(result)(expected)
    }

    "not parse anything if the field does not match" in {
      val input = Row("1212")
      val schema = StructType(Seq(StructField("otherField", StringType)))

      an[IllegalStateException] should be thrownBy
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "unixMillis")).parse(input)
    }

    "not parse anything and generate a new Date" in {
      val input = Row("anything")
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "autoGenerated"))
          .parse(input)

      assertResult(result.head.size)(2)
    }

    "Auto generated if inputFormat does not exist" in {
      val input = Row("1416330788")
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result = new DateTimeParser(1, inputField, outputsFields, schema, Map()).parse(input)

      assertResult(result.head.size)(2)
    }

    "parse dateTime in hive format" in {
      val input = Row("2015-11-08 15:58:58")
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "hive"))
          .parse(input)

      val expected = Seq(Row("2015-11-08 15:58:58", "1446998338000"))

      assertResult(result)(expected)
    }
  }
}
Example 53
Source File: LongInputTests.scala From boson with Apache License 2.0 | 5 votes |
package io.zink.boson

import bsonLib.BsonObject
import io.netty.util.ResourceLeakDetector
import io.vertx.core.json.JsonObject
import io.zink.boson.bson.bsonImpl.BosonImpl
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner
import org.junit.Assert._

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.io.Source

@RunWith(classOf[JUnitRunner])
class LongInputTests extends FunSuite {

  ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.ADVANCED)

  val bufferedSource: Source = Source.fromURL(getClass.getResource("/jsonOutput.txt"))
  val finale: String = bufferedSource.getLines.toSeq.head
  bufferedSource.close

  val json: JsonObject = new JsonObject(finale)
  val bson: BsonObject = new BsonObject(json)

  test("extract top field") {
    val expression: String = ".Epoch"
    val boson: Boson = Boson.extractor(expression, (out: Int) => {
      assertTrue(3 == out)
    })
    val res = boson.go(bson.encode.getBytes)
    Await.result(res, Duration.Inf)
  }

  test("extract bottom field") {
    val expression: String = "SSLNLastName"
    val expected: String = "de Huanuco"
    val boson: Boson = Boson.extractor(expression, (out: String) => {
      assertTrue(expected.zip(out).forall(e => e._1.equals(e._2)))
    })
    val res = boson.go(bson.encode.getBytes)
    Await.result(res, Duration.Inf)
  }

  test("extract positions of an Array") {
    val expression: String = "Markets[3 to 5]"
    val mutableBuffer: ArrayBuffer[Array[Byte]] = ArrayBuffer()
    val boson: Boson = Boson.extractor(expression, (out: Array[Byte]) => {
      mutableBuffer += out
    })
    val res = boson.go(bson.encode.getBytes)
    Await.result(res, Duration.Inf)
    assertEquals(3, mutableBuffer.size)
  }

  test("extract further positions of an Array") {
    val expression: String = "Markets[50 to 55]"
    val mutableBuffer: ArrayBuffer[Array[Byte]] = ArrayBuffer()
    val boson: Boson = Boson.extractor(expression, (out: Array[Byte]) => {
      mutableBuffer += out
    })
    val res = boson.go(bson.encode.getBytes)
    Await.result(res, Duration.Inf)
    assertEquals(6, mutableBuffer.size)
  }

  test("size of all occurrences of Key") {
    val expression: String = "Price"
    val mutableBuffer: ArrayBuffer[Float] = ArrayBuffer()
    val boson: Boson = Boson.extractor(expression, (out: Float) => {
      mutableBuffer += out
    })
    val res = boson.go(bson.encode.getBytes)
    Await.result(res, Duration.Inf)
    assertEquals(195, mutableBuffer.size)
  }
}
Example 54
Source File: StorageTest.scala From mqttd with MIT License | 5 votes |
package plantae.citrus.mqtt.actors.session

import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class StorageTest extends FunSuite {

  test("persist test") {
    val storage = Storage("persist-test")
    Range(1, 2000).foreach(count => {
      storage.persist((count + " persist").getBytes, (count % 3).toShort, true, "topic" + count)
    })

    assert(
      !Range(1, 2000).exists(count => {
        storage.nextMessage match {
          case Some(message) =>
            storage.complete(message.packetId match {
              case Some(x) => Some(x)
              case None => None
            })
            println(new String(message.payload.toArray))
            count + " persist" != new String(message.payload.toArray)
          case None => true
        }
      })
    )
  }
}
Example 55
Source File: LogisticRegressionTest.scala From spark-cp with Apache License 2.0 | 5 votes |
package se.uu.farmbio.cp.alg

import org.apache.spark.SharedSparkContext
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import se.uu.farmbio.cp.ICP
import se.uu.farmbio.cp.TestUtils
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class LogisticRegressionTest extends FunSuite with SharedSparkContext {

  test("test performance") {
    val trainData = TestUtils.generateBinaryData(100, 11)
    val testData = TestUtils.generateBinaryData(30, 22)
    val (calibration, properTrain) = ICP.calibrationSplit(sc.parallelize(trainData), 16)
    val lr = new LogisticRegression(properTrain, 30)
    val model = ICP.trainClassifier(lr, numClasses = 2, calibration)
    assert(TestUtils.testPerformance(model, sc.parallelize(testData)))
  }
}
Example 56
Source File: SVMTest.scala From spark-cp with Apache License 2.0 | 5 votes |
package se.uu.farmbio.cp.alg

import org.apache.spark.SharedSparkContext
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import se.uu.farmbio.cp.ICP
import se.uu.farmbio.cp.TestUtils
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class SVMTest extends FunSuite with SharedSparkContext {

  test("test performance") {
    val trainData = TestUtils.generateBinaryData(100, 11)
    val testData = TestUtils.generateBinaryData(30, 22)
    val (calibration, properTrain) = ICP.calibrationSplit(sc.parallelize(trainData), 16)
    val svm = new SVM(properTrain, 30)
    val model = ICP.trainClassifier(svm, numClasses = 2, calibration)
    assert(TestUtils.testPerformance(model, sc.parallelize(testData)))
  }
}
Example 57
Source File: GBTTest.scala From spark-cp with Apache License 2.0 | 5 votes |
package se.uu.farmbio.cp.alg

import scala.util.Random

import org.apache.spark.SharedSparkContext
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import se.uu.farmbio.cp.ICP
import se.uu.farmbio.cp.TestUtils
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class GBTTest extends FunSuite with SharedSparkContext {

  Random.setSeed(11)

  test("test performance") {
    val trainData = TestUtils.generateBinaryData(100, 11)
    val testData = TestUtils.generateBinaryData(30, 22)
    val (calibration, properTrain) = ICP.calibrationSplit(sc.parallelize(trainData), 16)
    val gbt = new GBT(properTrain, 30)
    val model = ICP.trainClassifier(gbt, numClasses = 2, calibration)
    assert(TestUtils.testPerformance(model, sc.parallelize(testData)))
  }
}
Example 58
Source File: KafkaTestUtilsTest.scala From spark-testing-base with Apache License 2.0 | 5 votes |
package com.holdenkarau.spark.testing.kafka

import java.util.Properties

import scala.collection.JavaConversions._

import kafka.consumer.ConsumerConfig
import org.apache.spark.streaming.kafka.KafkaTestUtils
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterAll, FunSuite}

@RunWith(classOf[JUnitRunner])
class KafkaTestUtilsTest extends FunSuite with BeforeAndAfterAll {

  private var kafkaTestUtils: KafkaTestUtils = _

  override def beforeAll(): Unit = {
    kafkaTestUtils = new KafkaTestUtils
    kafkaTestUtils.setup()
  }

  override def afterAll(): Unit = if (kafkaTestUtils != null) {
    kafkaTestUtils.teardown()
    kafkaTestUtils = null
  }

  test("Kafka send and receive message") {
    val topic = "test-topic"
    val message = "HelloWorld!"
    kafkaTestUtils.createTopic(topic)
    kafkaTestUtils.sendMessages(topic, message.getBytes)

    val consumerProps = new Properties()
    consumerProps.put("zookeeper.connect", kafkaTestUtils.zkAddress)
    consumerProps.put("group.id", "test-group")
    consumerProps.put("flow-topic", topic)
    consumerProps.put("auto.offset.reset", "smallest")
    consumerProps.put("zookeeper.session.timeout.ms", "2000")
    consumerProps.put("zookeeper.connection.timeout.ms", "6000")
    consumerProps.put("zookeeper.sync.time.ms", "2000")
    consumerProps.put("auto.commit.interval.ms", "2000")

    val consumer = kafka.consumer.Consumer.createJavaConsumerConnector(new ConsumerConfig(consumerProps))

    try {
      val topicCountMap = Map(topic -> new Integer(1))
      val consumerMap = consumer.createMessageStreams(topicCountMap)
      val stream = consumerMap.get(topic).get(0)
      val it = stream.iterator()
      val mess = it.next
      assert(new String(mess.message().map(_.toChar)) === message)
    } finally {
      consumer.shutdown()
    }
  }
}
Example 59
Source File: TestBase.scala From open-korean-text with Apache License 2.0 | 5 votes |
package org.openkoreantext.processor

import java.util.logging.{Level, Logger}

import org.junit.runner.RunWith
import org.openkoreantext.processor.util.KoreanDictionaryProvider._
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

object TestBase {

  case class ParseTime(time: Long, chunk: String)

  def time[R](block: => R): Long = {
    val t0 = System.currentTimeMillis()
    block
    val t1 = System.currentTimeMillis()
    t1 - t0
  }

  def assertExamples(exampleFiles: String, log: Logger, f: (String => String)) {
    assert({
      val input = readFileByLineFromResources(exampleFiles)

      val (parseTimes, hasErrors) = input.foldLeft((List[ParseTime](), true)) {
        case ((l: List[ParseTime], output: Boolean), line: String) =>
          val s = line.split("\t")
          val (chunk, parse) = (s(0), if (s.length == 2) s(1) else "")

          val oldTokens = parse
          val t0 = System.currentTimeMillis()
          val newTokens = f(chunk)
          val t1 = System.currentTimeMillis()

          val oldParseMatches = oldTokens == newTokens

          if (!oldParseMatches) {
            System.err.println("Example set match error: %s \n - EXPECTED: %s\n - ACTUAL : %s".format(
              chunk, oldTokens, newTokens))
          }
          (ParseTime(t1 - t0, chunk) :: l, output && oldParseMatches)
      }

      val averageTime = parseTimes.map(_.time).sum.toDouble / parseTimes.size
      val maxItem = parseTimes.maxBy(_.time)

      log.log(Level.INFO, ("Parsed %d chunks. \n" +
        " Total time: %d ms \n" +
        " Average time: %.2f ms \n" +
        " Max time: %d ms, %s").format(
        parseTimes.size,
        parseTimes.map(_.time).sum,
        averageTime,
        maxItem.time,
        maxItem.chunk
      ))
      hasErrors
    }, "Some parses did not match the example set.")
  }
}

@RunWith(classOf[JUnitRunner])
abstract class TestBase extends FunSuite
Example 60
Source File: StreamingFormulaDemo1.scala From sscheck with Apache License 2.0 | 5 votes |
package es.ucm.fdi.sscheck.spark.demo

import org.junit.runner.RunWith
import org.specs2.runner.JUnitRunner
import org.specs2.ScalaCheck
import org.specs2.Specification
import org.specs2.matcher.ResultMatchers
import org.scalacheck.Arbitrary.arbitrary

import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.Duration
import org.apache.spark.streaming.dstream.DStream

import es.ucm.fdi.sscheck.spark.streaming.SharedStreamingContextBeforeAfterEach
import es.ucm.fdi.sscheck.prop.tl.{Formula, DStreamTLProperty}
import es.ucm.fdi.sscheck.prop.tl.Formula._
import es.ucm.fdi.sscheck.gen.{PDStreamGen, BatchGen}

@RunWith(classOf[JUnitRunner])
class StreamingFormulaDemo1
  extends Specification
  with DStreamTLProperty
  with ResultMatchers
  with ScalaCheck {

  // Spark configuration
  override def sparkMaster: String = "local[*]"
  override def batchDuration = Duration(150)
  override def defaultParallelism = 4

  def is =
    sequential ^ s2"""
    Simple demo Specs2 example for ScalaCheck properties with temporal formulas on Spark Streaming programs
      - where a simple property for DStream.count is a success ${countForallAlwaysProp(_.count)}
      - where a faulty implementation of the DStream.count is detected ${countForallAlwaysProp(faultyCount) must beFailing}
    """

  def faultyCount(ds: DStream[Double]): DStream[Long] =
    ds.count.transform(_.map(_ - 1))

  def countForallAlwaysProp(testSubject: DStream[Double] => DStream[Long]) = {
    type U = (RDD[Double], RDD[Long])
    val (inBatch, transBatch) = ((_: U)._1, (_: U)._2)
    val numBatches = 10
    val formula: Formula[U] = always { (u: U) =>
      transBatch(u).count === 1 and
        inBatch(u).count === transBatch(u).first
    } during numBatches

    val gen = BatchGen.always(BatchGen.ofNtoM(10, 50, arbitrary[Double]), numBatches)

    forAllDStream(
      gen)(
      testSubject)(
      formula)
  }.set(minTestsOk = 10).verbose
}
Example 61
Source File: StreamingFormulaDemo2.scala From sscheck with Apache License 2.0 | 5 votes |
package es.ucm.fdi.sscheck.spark.demo

import org.junit.runner.RunWith
import org.specs2.runner.JUnitRunner
import org.specs2.ScalaCheck
import org.specs2.Specification
import org.specs2.matcher.ResultMatchers
import org.scalacheck.Arbitrary.arbitrary
import org.scalacheck.Gen

import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.Duration
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.dstream.DStream._

import scalaz.syntax.std.boolean._

import es.ucm.fdi.sscheck.spark.streaming.SharedStreamingContextBeforeAfterEach
import es.ucm.fdi.sscheck.prop.tl.{Formula, DStreamTLProperty}
import es.ucm.fdi.sscheck.prop.tl.Formula._
import es.ucm.fdi.sscheck.gen.{PDStreamGen, BatchGen}
import es.ucm.fdi.sscheck.gen.BatchGenConversions._
import es.ucm.fdi.sscheck.gen.PDStreamGenConversions._
import es.ucm.fdi.sscheck.matcher.specs2.RDDMatchers._

@RunWith(classOf[JUnitRunner])
class StreamingFormulaDemo2
  extends Specification
  with DStreamTLProperty
  with ResultMatchers
  with ScalaCheck {

  // Spark configuration
  override def sparkMaster: String = "local[*]"
  override def batchDuration = Duration(300)
  override def defaultParallelism = 3
  override def enableCheckpointing = true

  def is =
    sequential ^ s2"""
    Check process to persistently detect and ban bad users
      - where a stateful implementation extracts the banned users correctly ${checkExtractBannedUsersList(listBannedUsers)}
      - where a trivial implementation ${checkExtractBannedUsersList(statelessListBannedUsers) must beFailing}
    """

  type UserId = Long

  def listBannedUsers(ds: DStream[(UserId, Boolean)]): DStream[UserId] =
    ds.updateStateByKey((flags: Seq[Boolean], maybeFlagged: Option[Unit]) =>
      maybeFlagged match {
        case Some(_) => maybeFlagged
        case None => flags.contains(false) option {()}
      }
    ).transform(_.keys)

  def statelessListBannedUsers(ds: DStream[(UserId, Boolean)]): DStream[UserId] =
    ds.map(_._1)

  def checkExtractBannedUsersList(testSubject: DStream[(UserId, Boolean)] => DStream[UserId]) = {
    val batchSize = 20
    val (headTimeout, tailTimeout, nestedTimeout) = (10, 10, 5)
    val (badId, ids) = (15L, Gen.choose(1L, 50L))
    val goodBatch = BatchGen.ofN(batchSize, ids.map((_, true)))
    val badBatch = goodBatch + BatchGen.ofN(1, (badId, false))
    val gen = BatchGen.until(goodBatch, badBatch, headTimeout) ++
      BatchGen.always(Gen.oneOf(goodBatch, badBatch), tailTimeout)

    type U = (RDD[(UserId, Boolean)], RDD[UserId])
    val (inBatch, outBatch) = ((_: U)._1, (_: U)._2)

    val formula = {
      val badInput = at(inBatch)(_ should existsRecord(_ == (badId, false)))
      val allGoodInputs = at(inBatch)(_ should foreachRecord(_._2 == true))
      val noIdBanned = at(outBatch)(_.isEmpty)
      val badIdBanned = at(outBatch)(_ should existsRecord(_ == badId))

      ((allGoodInputs and noIdBanned) until badIdBanned on headTimeout) and
        (always { badInput ==> (always(badIdBanned) during nestedTimeout) } during tailTimeout)
    }

    forAllDStream(
      gen)(
      testSubject)(
      formula)
  }.set(minTestsOk = 10).verbose
}
Example 62
Source File: SimpleStreamingFormulas.scala From sscheck with Apache License 2.0 | 5 votes |
package es.ucm.fdi.sscheck.spark.simple

import org.junit.runner.RunWith
import org.specs2.runner.JUnitRunner
import org.specs2.matcher.ResultMatchers
import org.scalacheck.Arbitrary.arbitrary
import org.scalacheck.Gen

import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.Duration
import org.apache.spark.streaming.dstream.DStream

import es.ucm.fdi.sscheck.spark.streaming.SharedStreamingContextBeforeAfterEach
import es.ucm.fdi.sscheck.prop.tl.{Formula, DStreamTLProperty}
import es.ucm.fdi.sscheck.prop.tl.Formula._
import es.ucm.fdi.sscheck.matcher.specs2.RDDMatchers._
import es.ucm.fdi.sscheck.gen.{PDStreamGen, BatchGen}
import es.ucm.fdi.sscheck.gen.PDStream
import es.ucm.fdi.sscheck.gen.Batch

@RunWith(classOf[JUnitRunner])
class SimpleStreamingFormulas
  extends org.specs2.Specification
  with DStreamTLProperty
  with org.specs2.ScalaCheck {

  // Spark configuration
  override def sparkMaster: String = "local[*]"
  override def batchDuration = Duration(50)
  override def defaultParallelism = 4

  def is =
    sequential ^ s2"""
    Simple demo Specs2 example for ScalaCheck properties with temporal formulas on Spark Streaming programs
      - Given a stream of integers
        When we filter out negative numbers
        Then we get only numbers greater or equal to zero $filterOutNegativeGetGeqZero
      - where time increments for each batch $timeIncreasesMonotonically
    """

  def filterOutNegativeGetGeqZero = {
    type U = (RDD[Int], RDD[Int])
    val numBatches = 10
    val gen = BatchGen.always(BatchGen.ofNtoM(10, 50, arbitrary[Int]), numBatches)
    val formula = always(nowTime[U] { (letter, time) =>
      val (_input, output) = letter
      output should foreachRecord { _ >= 0 }
    }) during numBatches

    forAllDStream(
      gen)(
      _.filter { x => !(x < 0) })(
      formula)
  }.set(minTestsOk = 50).verbose

  def timeIncreasesMonotonically = {
    type U = (RDD[Int], RDD[Int])
    val numBatches = 10
    val gen = BatchGen.always(BatchGen.ofNtoM(10, 50, arbitrary[Int]))
    val formula = always(nextTime[U] { (letter, time) =>
      nowTime[U] { (nextLetter, nextTime) =>
        time.millis <= nextTime.millis
      }
    }) during numBatches - 1

    forAllDStream(
      gen)(
      identity[DStream[Int]])(
      formula)
  }.set(minTestsOk = 10).verbose
}
Example 63
Source File: SharedStreamingContextBeforeAfterEachTest.scala From sscheck with Apache License 2.0 | 5 votes |
package es.ucm.fdi.sscheck.spark.streaming

import org.junit.runner.RunWith
import org.specs2.runner.JUnitRunner
import org.specs2.execute.Result

import org.apache.spark.streaming.Duration
import org.apache.spark.rdd.RDD

import scala.collection.mutable.Queue
import scala.concurrent.duration._

import org.slf4j.LoggerFactory

import es.ucm.fdi.sscheck.matcher.specs2.RDDMatchers._

// sbt "test-only es.ucm.fdi.sscheck.spark.streaming.SharedStreamingContextBeforeAfterEachTest"
@RunWith(classOf[JUnitRunner])
class SharedStreamingContextBeforeAfterEachTest
  extends org.specs2.Specification
  with org.specs2.matcher.MustThrownExpectations
  with org.specs2.matcher.ResultMatchers
  with SharedStreamingContextBeforeAfterEach {

  // cannot use private[this] due to https://issues.scala-lang.org/browse/SI-8087
  @transient private val logger = LoggerFactory.getLogger("SharedStreamingContextBeforeAfterEachTest")

  // Spark configuration
  override def sparkMaster: String = "local[5]"
  override def batchDuration = Duration(250)
  override def defaultParallelism = 3
  override def enableCheckpointing = false // as queueStream doesn't support checkpointing

  def is =
    sequential ^ s2"""
    Simple test for SharedStreamingContextBeforeAfterEach
      where a simple queueStream test must be successful $successfulSimpleQueueStreamTest
      where a simple queueStream test can also fail $failingSimpleQueueStreamTest
    """

  def successfulSimpleQueueStreamTest = simpleQueueStreamTest(expectedCount = 0)
  def failingSimpleQueueStreamTest = simpleQueueStreamTest(expectedCount = 1) must beFailing

  def simpleQueueStreamTest(expectedCount: Int): Result = {
    val record = "hola"
    val batches = Seq.fill(5)(Seq.fill(10)(record))
    val queue = new Queue[RDD[String]]
    queue ++= batches.map(batch => sc.parallelize(batch, numSlices = defaultParallelism))
    val inputDStream = ssc.queueStream(queue, oneAtATime = true)
    val sizesDStream = inputDStream.map(_.length)

    var batchCount = 0
    // NOTE wrapping assertions with a Result object is needed
    // to avoid the Spark Streaming runtime capturing the exceptions
    // from failing assertions
    var result: Result = ok
    inputDStream.foreachRDD { rdd =>
      batchCount += 1
      println(s"completed batch number $batchCount: ${rdd.collect.mkString(",")}")
      result = result and {
        rdd.filter(_ != record).count() === expectedCount
        rdd should existsRecord(_ == "hola")
      }
    }
    sizesDStream.foreachRDD { rdd =>
      result = result and {
        rdd should foreachRecord(record.length)(len => _ == len)
      }
    }

    // should only start the dstream after all the transformations and actions have been defined
    ssc.start()

    // wait for completion of batches.length batches
    StreamingContextUtils.awaitForNBatchesCompleted(batches.length, atMost = 10 seconds)(ssc)

    result
  }
}
Example 64
Source File: ScalaCheckStreamingTest.scala From sscheck with Apache License 2.0 | 5 votes |
package es.ucm.fdi.sscheck.spark.streaming

import org.junit.runner.RunWith
import org.specs2.runner.JUnitRunner
import org.specs2.ScalaCheck
import org.specs2.execute.{AsResult, Result}

import org.scalacheck.{Prop, Gen}
import org.scalacheck.Arbitrary.arbitrary

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Duration}
import org.apache.spark.streaming.dstream.DStream

import es.ucm.fdi.sscheck.prop.tl.Formula._
import es.ucm.fdi.sscheck.prop.tl.DStreamTLProperty
import es.ucm.fdi.sscheck.matcher.specs2.RDDMatchers._

@RunWith(classOf[JUnitRunner])
class ScalaCheckStreamingTest
  extends org.specs2.Specification
  with DStreamTLProperty
  with org.specs2.matcher.ResultMatchers
  with ScalaCheck {

  override def sparkMaster: String = "local[5]"
  override def batchDuration = Duration(350)
  override def defaultParallelism = 4

  def is =
    sequential ^ s2"""
    Simple properties for Spark Streaming
      - where the first property is a success $prop1
      - where a simple property for DStream.count is a success ${countProp(_.count)}
      - where a faulty implementation of the DStream.count is detected ${countProp(faultyCount) must beFailing}
    """

  def prop1 = {
    val batchSize = 30
    val numBatches = 10
    val dsgenSeqSeq1 = {
      val zeroSeqSeq = Gen.listOfN(numBatches, Gen.listOfN(batchSize, 0))
      val oneSeqSeq = Gen.listOfN(numBatches, Gen.listOfN(batchSize, 1))
      Gen.oneOf(zeroSeqSeq, oneSeqSeq)
    }
    type U = (RDD[Int], RDD[Int])

    forAllDStream[Int, Int](
      "inputDStream" |: dsgenSeqSeq1)(
      (inputDs: DStream[Int]) => {
        val transformedDs = inputDs.map(_ + 1)
        transformedDs
      })(always((u: U) => {
        val (inputBatch, transBatch) = u
        inputBatch.count === batchSize and
          inputBatch.count === transBatch.count and
          (inputBatch.intersection(transBatch).isEmpty should beTrue) and
          (inputBatch should foreachRecord(_ == 0) or
            (inputBatch should foreachRecord(_ == 1)))
      }) during numBatches
    )
  }.set(minTestsOk = 10).verbose

  def faultyCount(ds: DStream[Double]): DStream[Long] =
    ds.count.transform(_.map(_ - 1))

  def countProp(testSubject: DStream[Double] => DStream[Long]) = {
    type U = (RDD[Double], RDD[Long])
    val numBatches = 10

    forAllDStream[Double, Long](
      Gen.listOfN(numBatches, Gen.listOfN(30, arbitrary[Double])))(
      testSubject
    )(always((u: U) => {
      val (inputBatch, transBatch) = u
      transBatch.count === 1 and
        inputBatch.count === transBatch.first
    }) during numBatches
    )
  }.set(minTestsOk = 10).verbose
}
Example 65
Source File: ITSelectorSuite.scala From spark-infotheoretic-feature-selection with Apache License 2.0 | 5 votes |
package org.apache.spark.ml.feature

import org.apache.spark.sql.{DataFrame, SQLContext}
import org.junit.runner.RunWith
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.junit.JUnitRunner
import TestHelper._

// NOTE: the class header and the sqlContext setup were truncated in this listing;
// the declaration below is a minimal reconstruction so the remaining test reads as valid Scala.
@RunWith(classOf[JUnitRunner])
class ITSelectorSuite extends FunSuite with BeforeAndAfterAll {

  // Assumed to be initialized by the (truncated) beforeAll setup code.
  var sqlContext: SQLContext = _

  test("Run ITFS on nci data (nPart = 10, nfeat = 10)") {
    val df = readCSVData(sqlContext, "test_nci9_s3.csv")
    val cols = df.columns
    val pad = 2
    val allVectorsDense = true
    val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head,
      10, 10, allVectorsDense, pad)

    assertResult("443, 755, 1369, 1699, 3483, 5641, 6290, 7674, 9399, 9576") {
      model.selectedFeatures.mkString(", ")
    }
  }
}
Example 66
Source File: instagram_api_yaml.scala From play-swagger with MIT License | 5 votes |
package instagram.api.yaml

import de.zalando.play.controllers._
import org.scalacheck._
import org.scalacheck.Arbitrary._
import org.scalacheck.Prop._
import org.scalacheck.Test._
import org.specs2.mutable._
import play.api.test.Helpers._
import play.api.test._
import play.api.mvc.MultipartFormData.FilePart
import play.api.mvc._

import org.junit.runner.RunWith
import org.specs2.runner.JUnitRunner
import java.net.URLEncoder
import com.fasterxml.jackson.databind.ObjectMapper

import play.api.http.Writeable
import play.api.libs.Files.TemporaryFile
import play.api.test.Helpers.{status => requestStatusCode_}
import play.api.test.Helpers.{contentAsString => requestContentAsString_}
import play.api.test.Helpers.{contentType => requestContentType_}

import scala.math.BigInt
import scala.math.BigDecimal

import Generators._

@RunWith(classOf[JUnitRunner])
class Instagram_api_yamlSpec extends Specification {

  def toPath[T](value: T)(implicit binder: PathBindable[T]): String = Option(binder.unbind("", value)).getOrElse("")
  def toQuery[T](key: String, value: T)(implicit binder: QueryStringBindable[T]): String = Option(binder.unbind(key, value)).getOrElse("")
  def toHeader[T](value: T)(implicit binder: PathBindable[T]): String = Option(binder.unbind("", value)).getOrElse("")

  def checkResult(props: Prop) =
    Test.check(Test.Parameters.default, props).status match {
      case Failed(args, labels) =>
        val failureMsg = labels.mkString("\n") + " given args: " + args.map(_.arg).mkString("'", "', '", "'")
        failure(failureMsg)
      case Proved(_) | Exhausted | Passed => success
      case PropException(_, e, labels) =>
        val error = if (labels.isEmpty) e.getLocalizedMessage() else labels.mkString("\n")
        failure(error)
    }

  private def parserConstructor(mimeType: String) = PlayBodyParsing.jacksonMapper(mimeType)

  def parseResponseContent[T](mapper: ObjectMapper, content: String, mimeType: Option[String], expectedType: Class[T]) =
    mapper.readValue(content, expectedType)
}
Example 67
Source File: security_api_yaml.scala From play-swagger with MIT License | 5 votes |
package security.api.yaml

import de.zalando.play.controllers._
import org.scalacheck._
import org.scalacheck.Arbitrary._
import org.scalacheck.Prop._
import org.scalacheck.Test._
import org.specs2.mutable._
import play.api.test.Helpers._
import play.api.test._
import play.api.mvc.MultipartFormData.FilePart
import play.api.mvc._

import org.junit.runner.RunWith
import org.specs2.runner.JUnitRunner
import java.net.URLEncoder
import com.fasterxml.jackson.databind.ObjectMapper

import play.api.http.Writeable
import play.api.libs.Files.TemporaryFile
import play.api.test.Helpers.{status => requestStatusCode_}
import play.api.test.Helpers.{contentAsString => requestContentAsString_}
import play.api.test.Helpers.{contentType => requestContentType_}

import de.zalando.play.controllers.ArrayWrapper

import Generators._

@RunWith(classOf[JUnitRunner])
class Security_api_yamlSpec extends Specification {

  def toPath[T](value: T)(implicit binder: PathBindable[T]): String = Option(binder.unbind("", value)).getOrElse("")
  def toQuery[T](key: String, value: T)(implicit binder: QueryStringBindable[T]): String = Option(binder.unbind(key, value)).getOrElse("")
  def toHeader[T](value: T)(implicit binder: PathBindable[T]): String = Option(binder.unbind("", value)).getOrElse("")

  def checkResult(props: Prop) =
    Test.check(Test.Parameters.default, props).status match {
      case Failed(args, labels) =>
        val failureMsg = labels.mkString("\n") + " given args: " + args.map(_.arg).mkString("'", "', '", "'")
        failure(failureMsg)
      case Proved(_) | Exhausted | Passed => success
      case PropException(_, e, labels) =>
        val error = if (labels.isEmpty) e.getLocalizedMessage() else labels.mkString("\n")
        failure(error)
    }

  private def parserConstructor(mimeType: String) = PlayBodyParsing.jacksonMapper(mimeType)

  def parseResponseContent[T](mapper: ObjectMapper, content: String, mimeType: Option[String], expectedType: Class[T]) =
    mapper.readValue(content, expectedType)
}
Example 68
Source File: Downloader$Test.scala From mystem-scala with MIT License | 5 votes |
package ru.stachek66.tools

import java.io.File
import java.net.URL

import org.junit.runner.RunWith
import org.scalatest.{Ignore, FunSuite}
import org.scalatest.junit.JUnitRunner

@Ignore
class Downloader$Test extends FunSuite {

  test("downloading-something") {
    val hello = new File("hello-test.html")
    val mystem = new File("atmta.binary")

    Downloader.downloadBinaryFile(new URL("http://www.stachek66.ru/"), hello)

    Downloader.downloadBinaryFile(
      new URL("http://download.cdn.yandex.net/mystem/mystem-3.0-linux3.1-64bit.tar.gz"),
      mystem
    )

    Downloader.downloadBinaryFile(
      new URL("http://download.cdn.yandex.net/mystem/mystem-3.1-win-64bit.zip"),
      mystem
    )

    hello.delete
    mystem.delete
  }

  test("download-and-unpack") {
    val bin = new File("atmta.binary.tar.gz")
    val bin2 = new File("executable")

    Decompressor.select.unpack(
      Downloader.downloadBinaryFile(
        new URL("http://download.cdn.yandex.net/mystem/mystem-3.0-linux3.1-64bit.tar.gz"),
        bin),
      bin2
    )

    bin.delete
    bin2.delete
  }
}
Example 69
Source File: Zip$Test.scala From mystem-scala with MIT License | 5 votes |
package ru.stachek66.tools

import java.io.{File, FileInputStream}

import org.apache.commons.io.IOUtils
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner
import org.junit.runner.RunWith

class Zip$Test extends FunSuite {

  test("zip-test") {
    val src = new File("src/test/resources/test.txt")
    Zip.unpack(
      new File("src/test/resources/test.zip"),
      new File("src/test/resources/res.txt")) match {
      case f =>
        val content0 = IOUtils.toString(new FileInputStream(f))
        val content1 = IOUtils.toString(new FileInputStream(src))
        print(content0.trim + " vs " + content1.trim)
        assert(content0 === content1)
    }
  }
}
Example 70
Source File: TarGz$Test.scala From mystem-scala with MIT License | 5 votes |
package ru.stachek66.tools

import java.io.{File, FileInputStream}

import org.apache.commons.io.IOUtils
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

class TarGz$Test extends FunSuite {

  test("tgz-test") {
    val src = new File("src/test/resources/test.txt")
    TarGz.unpack(
      new File("src/test/resources/test.tar.gz"),
      new File("src/test/resources/res.txt")) match {
      case f =>
        val content0 = IOUtils.toString(new FileInputStream(f))
        val content1 = IOUtils.toString(new FileInputStream(src))
        print(content0.trim + " vs " + content1.trim)
        assert(content0 === content1)
    }
  }
}
Example 71
Source File: DaoServiceTest.scala From Scala-Design-Patterns-Second-Edition with MIT License | 5 votes |
package com.ivan.nikolov.scheduler.dao

import com.ivan.nikolov.scheduler.TestEnvironment
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class DaoServiceTest extends FlatSpec with Matchers with BeforeAndAfter with TestEnvironment {

  override val databaseService = new H2DatabaseService
  override val migrationService = new MigrationService
  override val daoService = new DaoServiceImpl

  before {
    // we run this here. Generally migrations will only
    // be dealing with data layout and we will be able to have
    // test classes that insert test data.
    migrationService.runMigrations()
  }

  after {
    migrationService.cleanupDatabase()
  }

  "readResultSet" should "properly iterate over a result set and apply a function to it." in {
    val connection = daoService.getConnection()
    try {
      val result = daoService.executeSelect(
        connection.prepareStatement(
          "SELECT name FROM people"
        )
      ) {
        case rs =>
          daoService.readResultSet(rs) {
            case row =>
              row.getString("name")
          }
      }
      result should have size(3)
      result should contain("Ivan")
      result should contain("Maria")
      result should contain("John")
    } finally {
      connection.close()
    }
  }
}
Example 72
Source File: JobConfigReaderServiceTest.scala From Scala-Design-Patterns-Second-Edition with MIT License | 5 votes |
package com.ivan.nikolov.scheduler.services

import com.ivan.nikolov.scheduler.TestEnvironment
import com.ivan.nikolov.scheduler.config.job.{Console, Daily, JobConfig, TimeOptions}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class JobConfigReaderServiceTest extends FlatSpec with Matchers with TestEnvironment {

  override val ioService: IOService = new IOService
  override val jobConfigReaderService: JobConfigReaderService = new JobConfigReaderService

  "readJobConfigs" should "read and parse configurations successfully." in {
    val result = jobConfigReaderService.readJobConfigs()
    result should have size(1)
    result should contain(
      JobConfig(
        "Test Command",
        "ping google.com -c 10",
        Console,
        Daily,
        TimeOptions(12, 10)
      )
    )
  }
}
Example 73
Source File: TimeOptionsTest.scala From Scala-Design-Patterns-Second-Edition with MIT License | 5 votes |
package com.ivan.nikolov.scheduler.config.job

import java.time.LocalDateTime

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class TimeOptionsTest extends FlatSpec with Matchers {

  "getInitialDelay" should "get the right initial delay for hourly less than an hour after now." in {
    val now = LocalDateTime.of(2018, 3, 20, 12, 43, 10)
    val later = now.plusMinutes(20)
    val timeOptions = TimeOptions(later.getHour, later.getMinute)
    val result = timeOptions.getInitialDelay(now, Hourly)
    result.toMinutes should equal(20)
  }

  it should "get the right initial delay for hourly more than an hour after now." in {
    val now = LocalDateTime.of(2018, 3, 20, 18, 51, 17)
    val later = now.plusHours(3)
    val timeOptions = TimeOptions(later.getHour, later.getMinute)
    val result = timeOptions.getInitialDelay(now, Hourly)
    result.toHours should equal(3)
  }

  it should "get the right initial delay for hourly less than an hour before now." in {
    val now = LocalDateTime.of(2018, 3, 20, 11, 18, 55)
    val earlier = now.minusMinutes(25)
    // because of the logic and it will fail otherwise.
    if (earlier.getDayOfWeek == now.getDayOfWeek) {
      val timeOptions = TimeOptions(earlier.getHour, earlier.getMinute)
      val result = timeOptions.getInitialDelay(now, Hourly)
      result.toMinutes should equal(35)
    }
  }

  it should "get the right initial delay for hourly more than an hour before now." in {
    val now = LocalDateTime.of(2018, 3, 20, 12, 43, 59)
    val earlier = now.minusHours(1).minusMinutes(25)
    // because of the logic and it will fail otherwise.
    if (earlier.getDayOfWeek == now.getDayOfWeek) {
      val timeOptions = TimeOptions(earlier.getHour, earlier.getMinute)
      val result = timeOptions.getInitialDelay(now, Hourly)
      result.toMinutes should equal(35)
    }
  }

  it should "get the right initial delay for daily before now." in {
    val now = LocalDateTime.of(2018, 3, 20, 14, 43, 10)
    val earlier = now.minusMinutes(25)
    // because of the logic and it will fail otherwise.
    if (earlier.getDayOfWeek == now.getDayOfWeek) {
      val timeOptions = TimeOptions(earlier.getHour, earlier.getMinute)
      val result = timeOptions.getInitialDelay(now, Daily)
      result.toMinutes should equal(24 * 60 - 25)
    }
  }

  it should "get the right initial delay for daily after now." in {
    val now = LocalDateTime.of(2018, 3, 20, 16, 21, 6)
    val later = now.plusMinutes(20)
    val timeOptions = TimeOptions(later.getHour, later.getMinute)
    val result = timeOptions.getInitialDelay(now, Daily)
    result.toMinutes should equal(20)
  }
}
Example 74
Source File: TraitATest.scala From Scala-Design-Patterns-Second-Edition with MIT License | 5 votes |
package com.ivan.nikolov.composition

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class TraitATest extends FlatSpec with Matchers with A {

  "hello" should "greet properly." in {
    hello() should equal("Hello, I am trait A!")
  }

  "pass" should "return the right string with the number." in {
    pass(10) should equal("Trait A said: 'You passed 10.'")
  }

  it should "be correct also for negative values." in {
    pass(-10) should equal("Trait A said: 'You passed -10.'")
  }
}
Example 75
Source File: TraitACaseScopeTest.scala From Scala-Design-Patterns-Second-Edition with MIT License | 5 votes |
package com.ivan.nikolov.composition

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class TraitACaseScopeTest extends FlatSpec with Matchers {

  "hello" should "greet properly." in new A {
    hello() should equal("Hello, I am trait A!")
  }

  "pass" should "return the right string with the number." in new A {
    pass(10) should equal("Trait A said: 'You passed 10.'")
  }

  it should "be correct also for negative values." in new A {
    pass(-10) should equal("Trait A said: 'You passed -10.'")
  }
}
Example 76
Source File: UserComponentTest.scala From Scala-Design-Patterns-Second-Edition with MIT License | 5 votes |
package com.ivan.nikolov.cake

import com.ivan.nikolov.cake.model.Person
import org.junit.runner.RunWith
import org.mockito.Mockito._
import org.scalatest.junit.JUnitRunner
import org.scalatest.mockito.MockitoSugar
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class UserComponentTest extends FlatSpec with Matchers with MockitoSugar with TestEnvironment {

  val className = "A"
  val emptyClassName = "B"
  val people = List(
    Person(1, "a", 10),
    Person(2, "b", 15),
    Person(3, "c", 20)
  )

  override val userService = new UserService

  when(dao.getPeopleInClass(className)).thenReturn(people)
  when(dao.getPeopleInClass(emptyClassName)).thenReturn(List())

  "getAverageAgeOfUsersInClass" should "properly calculate the average of all ages." in {
    userService.getAverageAgeOfUsersInClass(className) should equal(15.0)
  }

  it should "properly handle an empty result." in {
    userService.getAverageAgeOfUsersInClass(emptyClassName) should equal(0.0)
  }
}
Example 77
Source File: MetricsStatsReceiverTest.scala From finagle-metrics with MIT License | 5 votes |
package com.twitter.finagle.metrics

import com.twitter.finagle.metrics.MetricsStatsReceiver._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.FunSuite

@RunWith(classOf[JUnitRunner])
class MetricsStatsReceiverTest extends FunSuite {

  private[this] val receiver = new MetricsStatsReceiver()

  private[this] def readGauge(name: String): Option[Number] =
    Option(metrics.getGauges.get(name)) match {
      case Some(gauge) => Some(gauge.getValue.asInstanceOf[Float])
      case _ => None
    }

  private[this] def readCounter(name: String): Option[Number] =
    Option(metrics.getMeters.get(name)) match {
      case Some(counter) => Some(counter.getCount)
      case _ => None
    }

  private[this] def readStat(name: String): Option[Number] =
    Option(metrics.getHistograms.get(name)) match {
      case Some(stat) => Some(stat.getSnapshot.getValues.toSeq.sum)
      case _ => None
    }

  test("MetricsStatsReceiver should store and read gauge into the Codahale Metrics library") {
    val x = 1.5f
    receiver.addGauge("my_gauge")(x)
    assert(readGauge("my_gauge") === Some(x))
  }

  test("MetricsStatsReceiver should always assume the latest value of an already created gauge") {
    val gaugeName = "my_gauge2"
    val expectedValue = 8.8f

    receiver.addGauge(gaugeName)(2.2f)
    receiver.addGauge(gaugeName)(9.9f)
    receiver.addGauge(gaugeName)(expectedValue)

    assert(readGauge(gaugeName) === Some(expectedValue))
  }

  test("MetricsStatsReceiver should store and remove gauge into the Codahale Metrics Library") {
    val gaugeName = "temp-gauge"
    val expectedValue = 2.8f

    val tempGauge = receiver.addGauge(gaugeName)(expectedValue)
    assert(readGauge(gaugeName) === Some(expectedValue))

    tempGauge.remove()

    assert(readGauge(gaugeName) === None)
  }

  test("MetricsStatsReceiver should store and read stat into the Codahale Metrics library") {
    val x = 1
    val y = 3
    val z = 5

    val s = receiver.stat("my_stat")
    s.add(x)
    s.add(y)
    s.add(z)

    assert(readStat("my_stat") === Some(x + y + z))
  }

  test("MetricsStatsReceiver should store and read counter into the Codahale Metrics library") {
    val x = 2
    val y = 5
    val z = 8

    val c = receiver.counter("my_counter")
    c.incr(x)
    c.incr(y)
    c.incr(z)

    assert(readCounter("my_counter") === Some(x + y + z))
  }
}
Example 78
Source File: ReceiverWithoutOffsetIT.scala From datasource-receiver with Apache License 2.0 | 5 votes |
package org.apache.spark.streaming.datasource

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.streaming.datasource.models.{InputSentences, StopConditions}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ReceiverWithoutOffsetIT extends TemporalDataSuite {

  test("DataSource Receiver should read all the records on each batch without offset conditions") {
    sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    val rdd = sc.parallelize(registers)
    sqlContext.createDataFrame(rdd, schema).registerTempTable(tableName)
    ssc = new StreamingContext(sc, Seconds(1))
    val totalEvents = ssc.sparkContext.accumulator(0L, "Number of events received")
    val inputSentences = InputSentences(
      s"select * from $tableName",
      StopConditions(stopWhenEmpty = true, finishContextWhenEmpty = true),
      initialStatements = Seq.empty[String]
    )
    val distributedStream = DatasourceUtils.createStream(ssc, inputSentences, datasourceParams)

    distributedStream.start()
    distributedStream.foreachRDD(rdd => {
      val streamingEvents = rdd.count()
      log.info(s" EVENTS COUNT : \t $streamingEvents")
      totalEvents += streamingEvents
      log.info(s" TOTAL EVENTS : \t $totalEvents")
      if (!rdd.isEmpty())
        assert(streamingEvents === totalRegisters.toLong)
    })
    ssc.start()
    ssc.awaitTerminationOrTimeout(10000L)

    assert(totalEvents.value === totalRegisters.toLong * 10)
  }
}
Example 79
Source File: ReceiverNotStopContextIT.scala From datasource-receiver with Apache License 2.0 | 5 votes |
package org.apache.spark.streaming.datasource

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.streaming.datasource.models.{InputSentences, OffsetConditions, OffsetField}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ReceiverNotStopContextIT extends TemporalDataSuite {

  test("DataSource Receiver should read all the records in one streaming batch") {
    sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    val rdd = sc.parallelize(registers)
    sqlContext.createDataFrame(rdd, schema).registerTempTable(tableName)
    ssc = new StreamingContext(sc, Seconds(1))
    val totalEvents = ssc.sparkContext.accumulator(0L, "Number of events received")
    val inputSentences = InputSentences(
      s"select * from $tableName",
      OffsetConditions(OffsetField("idInt")),
      initialStatements = Seq.empty[String]
    )
    val distributedStream = DatasourceUtils.createStream(ssc, inputSentences, datasourceParams)

    distributedStream.start()
    distributedStream.foreachRDD(rdd => {
      totalEvents += rdd.count()
    })
    ssc.start()
    ssc.awaitTerminationOrTimeout(15000L)

    assert(totalEvents.value === totalRegisters.toLong)
  }
}
Example 80
Source File: ReceiverLimitedIT.scala From datasource-receiver with Apache License 2.0 | 5 votes |
package org.apache.spark.streaming.datasource

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.streaming.datasource.models.{InputSentences, OffsetConditions, OffsetField, StopConditions}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ReceiverLimitedIT extends TemporalDataSuite {

  test("DataSource Receiver should read the records limited on each batch") {
    sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    val rdd = sc.parallelize(registers)
    sqlContext.createDataFrame(rdd, schema).registerTempTable(tableName)
    ssc = new StreamingContext(sc, Seconds(1))
    val totalEvents = ssc.sparkContext.accumulator(0L, "Number of events received")
    val inputSentences = InputSentences(
      s"select * from $tableName",
      OffsetConditions(OffsetField("idInt"), limitRecords = 1000),
      StopConditions(stopWhenEmpty = true, finishContextWhenEmpty = true),
      initialStatements = Seq.empty[String]
    )
    val distributedStream = DatasourceUtils.createStream(ssc, inputSentences, datasourceParams)

    // Start up the receiver.
    distributedStream.start()

    // Fires each time the configured window has passed.
    distributedStream.foreachRDD(rdd => {
      totalEvents += rdd.count()
    })

    ssc.start() // Start the computation
    ssc.awaitTerminationOrTimeout(15000L) // Wait for the computation to terminate

    assert(totalEvents.value === totalRegisters.toLong)
  }
}
Example 81
Source File: ReceiverBasicIT.scala From datasource-receiver with Apache License 2.0 | 5 votes |
package org.apache.spark.streaming.datasource

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.streaming.datasource.models.{InputSentences, OffsetConditions, OffsetField, StopConditions}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ReceiverBasicIT extends TemporalDataSuite {

  test("DataSource Receiver should read all the records in one streaming batch") {
    sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    val rdd = sc.parallelize(registers)
    sqlContext.createDataFrame(rdd, schema).registerTempTable(tableName)
    ssc = new StreamingContext(sc, Seconds(1))
    val totalEvents = ssc.sparkContext.accumulator(0L, "Number of events received")
    val inputSentences = InputSentences(
      s"select * from $tableName",
      OffsetConditions(OffsetField("idInt")),
      StopConditions(stopWhenEmpty = true, finishContextWhenEmpty = true),
      initialStatements = Seq.empty[String]
    )
    val distributedStream = DatasourceUtils.createStream(ssc, inputSentences, datasourceParams)

    distributedStream.start()
    distributedStream.foreachRDD(rdd => {
      val streamingEvents = rdd.count()
      log.info(s" EVENTS COUNT : \t $streamingEvents")
      totalEvents += streamingEvents
      log.info(s" TOTAL EVENTS : \t $totalEvents")
      val streamingRegisters = rdd.collect()
      if (!rdd.isEmpty())
        assert(streamingRegisters === registers.reverse)
    })
    ssc.start()
    ssc.awaitTerminationOrTimeout(15000L)

    assert(totalEvents.value === totalRegisters.toLong)
  }
}
Example 82
Source File: GeohashTest.scala From sfseize with Apache License 2.0 | 5 votes |
package org.eichelberger.sfc.examples

import com.typesafe.scalalogging.slf4j.LazyLogging
import org.eichelberger.sfc.SpaceFillingCurve._
import org.eichelberger.sfc.study.composition.CompositionSampleData._
import org.eichelberger.sfc.utils.Timing
import org.eichelberger.sfc.{DefaultDimensions, Dimension}
import org.junit.runner.RunWith
import org.specs2.mutable.Specification
import org.specs2.runner.JUnitRunner

@RunWith(classOf[JUnitRunner])
class GeohashTest extends Specification with LazyLogging {

  val xCville = -78.488407
  val yCville = 38.038668

  "Geohash example" should {
    val geohash = new Geohash(35)

    "encode/decode round-trip for an interior point" >> {
      // encode
      val hash = geohash.pointToHash(Seq(xCville, yCville))
      hash must equalTo("dqb0muw")

      // decode
      val cell = geohash.hashToCell(hash)
      println(s"[Geohash example, Charlottesville] POINT($xCville $yCville) -> $hash -> $cell")
      cell(0).containsAny(xCville) must beTrue
      cell(1).containsAny(yCville) must beTrue
    }

    "encode/decode properly at the four corners and the center" >> {
      for (x <- Seq(-180.0, 0.0, 180.0); y <- Seq(-90.0, 0.0, 90.0)) {
        // encode
        val hash = geohash.pointToHash(Seq(x, y))

        // decode
        val cell = geohash.hashToCell(hash)
        println(s"[Geohash example, extrema] POINT($x $y) -> $hash -> $cell")
        cell(0).containsAny(x) must beTrue
        cell(1).containsAny(y) must beTrue
      }

      // degenerate test outcome
      1 must equalTo(1)
    }

    def getCvilleRanges(curve: Geohash): (OrdinalPair, OrdinalPair, Iterator[OrdinalPair]) = {
      val lonIdxRange = OrdinalPair(
        curve.children(0).asInstanceOf[Dimension[Double]].index(bboxCville._1),
        curve.children(1).asInstanceOf[Dimension[Double]].index(bboxCville._3)
      )
      val latIdxRange = OrdinalPair(
        curve.children(0).asInstanceOf[Dimension[Double]].index(bboxCville._2),
        curve.children(1).asInstanceOf[Dimension[Double]].index(bboxCville._4)
      )
      val query = Query(Seq(OrdinalRanges(lonIdxRange), OrdinalRanges(latIdxRange)))
      val cellQuery = Cell(Seq(
        DefaultDimensions.createDimension("x", bboxCville._1, bboxCville._3, 0),
        DefaultDimensions.createDimension("y", bboxCville._2, bboxCville._4, 0)
      ))
      (lonIdxRange, latIdxRange, curve.getRangesCoveringCell(cellQuery))
    }

    "generate valid selection indexes" >> {
      val (_, _, ranges) = getCvilleRanges(geohash)
      ranges.size must equalTo(90)
    }

    "report range efficiency" >> {
      def atPrecision(xBits: OrdinalNumber, yBits: OrdinalNumber): (Long, Long) = {
        val curve = new Geohash(xBits + yBits)
        val (lonRange, latRange, ranges) = getCvilleRanges(curve)
        (lonRange.size * latRange.size, ranges.size.toLong)
      }

      for (dimPrec <- 10 to 25) {
        val ((numCells, numRanges), ms) = Timing.time { () => atPrecision(dimPrec, dimPrec - 1) }
        println(s"[ranges across scales, Charlottesville] precision ($dimPrec, ${dimPrec - 1}) -> $numCells / $numRanges = ${numCells / numRanges} in $ms milliseconds")
      }

      1 must equalTo(1)
    }
  }
}
Example 83
Source File: LexicographicTest.scala From sfseize with Apache License 2.0 | 5 votes |
package org.eichelberger.sfc.utils import com.typesafe.scalalogging.slf4j.LazyLogging import org.eichelberger.sfc.SpaceFillingCurve.{OrdinalVector, ords2ordvec} import org.eichelberger.sfc.{DefaultDimensions, ZCurve} import org.junit.runner.RunWith import org.specs2.mutable.Specification import org.specs2.runner.JUnitRunner @RunWith(classOf[JUnitRunner]) class LexicographicTest extends Specification with LazyLogging { sequential "Lexicographical encoding" should { val precisions = new ords2ordvec(Seq(18L, 17L)).toOrdinalVector val sfc = ZCurve(precisions) val Longitude = DefaultDimensions.createLongitude(18L) val Latitude = DefaultDimensions.createLatitude(17L) "work for a known point" >> { val x = -78.488407 val y = 38.038668 val point = OrdinalVector(Longitude.index(x), Latitude.index(y)) val idx = sfc.index(point) val gh = sfc.lexEncodeIndex(idx) gh must equalTo("dqb0muw") } "be consistent round-trip" >> { val xs = (-180.0 to 180.0 by 33.3333).toSeq ++ Seq(180.0) val ys = (-90.0 to 90.0 by 33.3333).toSeq ++ Seq(90.0) for (x <- xs; y <- ys) { val ix = Longitude.index(x) val iy = Latitude.index(y) val point = OrdinalVector(ix, iy) val idx = sfc.index(point) val gh = sfc.lexEncodeIndex(idx) val idx2 = sfc.lexDecodeIndex(gh) idx2 must equalTo(idx) val point2 = sfc.inverseIndex(idx2) point2(0) must equalTo(ix) point2(1) must equalTo(iy) val rx = Longitude.inverseIndex(ix) val ry = Latitude.inverseIndex(iy) val sx = x.formatted("%8.3f") val sy = y.formatted("%8.3f") val sidx = idx.formatted("%20d") println(s"[LEXI ROUND-TRIP] POINT($sx $sy) -> $sidx = $gh -> ($rx, $ry)") } // degenerate 1 must equalTo(1) } } "multiple lexicographical encoders" should { "return different results for different base resolutions" >> { val x = -78.488407 val y = 38.038668 for (xBits <- 1 to 30; yBits <- xBits - 1 to xBits if yBits > 0) { val precisions = new ords2ordvec(Seq(xBits, yBits)).toOrdinalVector val sfc = ZCurve(precisions) val Longitude = DefaultDimensions.createLongitude(xBits) val Latitude = DefaultDimensions.createLatitude(yBits) val idx = sfc.index(OrdinalVector(Longitude.index(x), Latitude.index(y))) val gh = sfc.lexEncodeIndex(idx) val idx2 = sfc.lexDecodeIndex(gh) idx2 must equalTo(idx) println(s"[LEXI ACROSS RESOLUTIONS] mx $xBits + my $yBits = base ${sfc.alphabet.size}, idx $idx -> gh $gh -> $idx2") } // degenerate 1 must equalTo(1) } } }
Example 84
Source File: BitManipulationsTest.scala From sfseize with Apache License 2.0 | 5 votes |
package org.eichelberger.sfc.utils import com.typesafe.scalalogging.slf4j.LazyLogging import org.junit.runner.RunWith import org.specs2.mutable.Specification import org.specs2.runner.JUnitRunner import BitManipulations._ @RunWith(classOf[JUnitRunner]) class BitManipulationsTest extends Specification with LazyLogging { "static methods" should { "usedMask" >> { // single bits for (pos <- 0 to 62) { val v = 1L << pos.toLong val actual = usedMask(v) val expected = (1L << (pos + 1L)) - 1L println(s"[usedMask single bit] pos $pos, value $v, actual $actual, expected $expected") actual must equalTo(expected) } // full bit masks for (pos <- 0 to 62) { val expected = (1L << (pos.toLong + 1L)) - 1L val actual = usedMask(expected) println(s"[usedMask full bit masks] pos $pos, value $expected, actual $actual, expected $expected") actual must equalTo(expected) } usedMask(0) must equalTo(0) } "sharedBitPrefix" >> { sharedBitPrefix(2, 3) must equalTo(2) sharedBitPrefix(178, 161) must equalTo(160) } "common block extrema" >> { commonBlockMin(178, 161) must equalTo(160) commonBlockMax(178, 161) must equalTo(191) } } }
Example 85
Source File: CompositionParserTest.scala From sfseize with Apache License 2.0 | 5 votes |
package org.eichelberger.sfc.utils import com.typesafe.scalalogging.slf4j.LazyLogging import org.eichelberger.sfc.SpaceFillingCurve.SpaceFillingCurve import org.eichelberger.sfc.SpaceFillingCurve.SpaceFillingCurve import org.eichelberger.sfc._ import org.junit.runner.RunWith import org.specs2.mutable.Specification import org.specs2.runner.JUnitRunner @RunWith(classOf[JUnitRunner]) class CompositionParserTest extends Specification { sequential def parsableCurve(curve: SpaceFillingCurve): String = curve match { case c: ComposedCurve => c.delegate.name.charAt(0).toString + c.children.map { case d: Dimension[_] => d.precision case s: SubDimension[_] => s.precision case c: SpaceFillingCurve => parsableCurve(c) }.mkString("(", ", ", ")") case s => s.name.charAt(0).toString + s.precisions.toSeq.map(_.toString).mkString("(", ", ", ")") } def eval(curve: ComposedCurve): Boolean = { val toParse: String = parsableCurve(curve) val parsed: ComposedCurve = CompositionParser.buildWholeNumberCurve(toParse) val fromParse: String = parsableCurve(parsed) println(s"[CURVE PARSER]\n Input: $toParse\n Output: $fromParse") toParse == fromParse } "simple expressions" should { val R23 = new ComposedCurve( RowMajorCurve(2, 3), Seq( DefaultDimensions.createIdentityDimension(2), DefaultDimensions.createIdentityDimension(3) ) ) val H_2_R23 = new ComposedCurve( CompactHilbertCurve(2), Seq( DefaultDimensions.createIdentityDimension(2), R23 ) ) val Z_R23_2 = new ComposedCurve( ZCurve(2), Seq( R23, DefaultDimensions.createIdentityDimension(2) ) ) "parse correctly" >> { eval(R23) must beTrue eval(H_2_R23) must beTrue eval(Z_R23_2) must beTrue } } }
Example 86
Source File: LocalityEstimatorTest.scala From sfseize with Apache License 2.0 | 5 votes |
package org.eichelberger.sfc.utils import com.typesafe.scalalogging.slf4j.LazyLogging import org.eichelberger.sfc.{CompactHilbertCurve, RowMajorCurve, ZCurve} import org.junit.runner.RunWith import org.specs2.mutable.Specification import org.specs2.runner.JUnitRunner @RunWith(classOf[JUnitRunner]) class LocalityEstimatorTest extends Specification with LazyLogging { sequential "locality" should { "evaluate on square 2D curves" >> { (1 to 6).foreach { p => val locR = LocalityEstimator(RowMajorCurve(p, p)).locality println(s"[LOCALITY R($p, $p)] $locR") val locZ = LocalityEstimator(ZCurve(p, p)).locality println(s"[LOCALITY Z($p, $p)] $locZ") val locH = LocalityEstimator(CompactHilbertCurve(p, p)).locality println(s"[LOCALITY H($p, $p)] $locH") } 1 must beEqualTo(1) } "evaluate on non-square 2D curves" >> { (1 to 6).foreach { p => val locR = LocalityEstimator(RowMajorCurve(p << 1L, p)).locality println(s"[LOCALITY R(${p*2}, $p)] $locR") val locZ = LocalityEstimator(ZCurve(p << 1L, p)).locality println(s"[LOCALITY Z(${p*2}, $p)] $locZ") val locH = LocalityEstimator(CompactHilbertCurve(p << 1L, p)).locality println(s"[LOCALITY H(${p*2}, $p)] $locH") } 1 must beEqualTo(1) } } }
Example 87
Source File: RowMajorCurveTest.scala From sfseize with Apache License 2.0 | 5 votes |
package org.eichelberger.sfc import com.typesafe.scalalogging.slf4j.LazyLogging import org.eichelberger.sfc.CompactHilbertCurve.Mask import org.eichelberger.sfc.SpaceFillingCurve.{OrdinalVector, SpaceFillingCurve, _} import org.junit.runner.RunWith import org.specs2.mutable.Specification import org.specs2.runner.JUnitRunner @RunWith(classOf[JUnitRunner]) class RowMajorCurveTest extends Specification with GenericCurveValidation with LazyLogging { sequential def curveName = "RowmajorCurve" def createCurve(precisions: OrdinalNumber*): SpaceFillingCurve = RowMajorCurve(precisions.toOrdinalVector) "rowmajor space-filling curves" should { "satisfy the ordering constraints" >> { timeTestOrderings() must beTrue } "identify sub-ranges correctly" >> { val sfc = createCurve(3, 3) val query = Query(Seq(OrdinalRanges(OrdinalPair(1, 2)), OrdinalRanges(OrdinalPair(1, 3)))) val ranges = sfc.getRangesCoveringQuery(query).toList for (i <- 0 until ranges.size) { println(s"[rowmajor ranges: query $query] range $i = ${ranges(i)}") } ranges(0) must equalTo(OrdinalPair(9, 11)) ranges(1) must equalTo(OrdinalPair(17, 19)) } } }
Example 88
Source File: DefaultSaverITCase.scala From flink-tensorflow with Apache License 2.0 | 5 votes |
package org.apache.flink.contrib.tensorflow.io import org.apache.flink.contrib.tensorflow.models.savedmodel.DefaultSavedModelLoader import org.apache.flink.contrib.tensorflow.util.{FlinkTestBase, RegistrationUtils} import org.apache.flink.core.fs.Path import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpecLike} import org.tensorflow.{Session, Tensor} import scala.collection.JavaConverters._ @RunWith(classOf[JUnitRunner]) class DefaultSaverITCase extends WordSpecLike with Matchers with FlinkTestBase { override val parallelism = 1 "A DefaultSaver" should { "run the save op" in { val env = StreamExecutionEnvironment.getExecutionEnvironment RegistrationUtils.registerTypes(env.getConfig) val loader = new DefaultSavedModelLoader(new Path("../models/half_plus_two"), "serve") val bundle = loader.load() val saverDef = loader.metagraph.getSaverDef val saver = new DefaultSaver(saverDef) def getA = getVariable(bundle.session(), "a").floatValue() def setA(value: Float) = setVariable(bundle.session(), "a", Tensor.create(value)) val initialA = getA println("Initial value: " + initialA) setA(1.0f) val savePath = tempFolder.newFolder("model-0").getAbsolutePath val path = saver.save(bundle.session(), savePath) val savedA = getA savedA shouldBe (1.0f) println("Saved value: " + getA) setA(2.0f) val updatedA = getA updatedA shouldBe (2.0f) println("Updated value: " + updatedA) saver.restore(bundle.session(), path) val restoredA = getA restoredA shouldBe (savedA) println("Restored value: " + restoredA) } def getVariable(sess: Session, name: String): Tensor = { val result = sess.runner().fetch(name).run().asScala result.head } def setVariable(sess: Session, name: String, value: Tensor): Unit = { sess.runner() .addTarget(s"$name/Assign") .feed(s"$name/initial_value", value) .run() } } }
Example 89
Source File: ArraysTest.scala From flink-tensorflow with Apache License 2.0 | 5 votes |
package org.tensorflow.contrib.scala import com.twitter.bijection.Conversion._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Matchers, WordSpecLike} import org.tensorflow.contrib.scala.Arrays._ import org.tensorflow.contrib.scala.Rank._ import resource._ @RunWith(classOf[JUnitRunner]) class ArraysTest extends WordSpecLike with Matchers { "Arrays" when { "Array[Float]" should { "convert to Tensor[`1D`,Float]" in { val expected = Array(1f,2f,3f) managed(expected.as[TypedTensor[`1D`,Float]]).foreach { t => t.shape shouldEqual Array(expected.length) val actual = t.as[Array[Float]] actual shouldEqual expected } } } } }
Example 90
Source File: TestRenaming.scala From apalache with Apache License 2.0 | 5 votes |
package at.forsyte.apalache.tla.lir import at.forsyte.apalache.tla.lir.transformations.impl.TrackerWithListeners import at.forsyte.apalache.tla.lir.transformations.standard.Renaming import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{BeforeAndAfterEach, FunSuite} @RunWith(classOf[JUnitRunner]) class TestRenaming extends FunSuite with BeforeAndAfterEach with TestingPredefs { import at.forsyte.apalache.tla.lir.Builder._ private var renaming = new Renaming(TrackerWithListeners()) override protected def beforeEach(): Unit = { renaming = new Renaming(TrackerWithListeners()) } test("test renaming exists/forall") { val original = and( exists(n_x, n_S, gt(n_x, int(1))), forall(n_x, n_T, lt(n_x, int(42)))) /// val expected = and( exists(name("x_1"), n_S, gt(name("x_1"), int(1))), forall(name("x_2"), n_T, lt(name("x_2"), int(42)))) val renamed = renaming.renameBindingsUnique(original) assert(expected == renamed) } test("test renaming filter") { val original = cup( filter(name("x"), name("S"), eql(name("x"), int(1))), filter(name("x"), name("S"), eql(name("x"), int(2))) ) val expected = cup( filter(name("x_1"), name("S"), eql(name("x_1"), int(1))), filter(name("x_2"), name("S"), eql(name("x_2"), int(2)))) val renamed = renaming.renameBindingsUnique(original) assert(expected == renamed) } test( "Test renaming LET-IN" ) { // LET p(t) == \A x \in S . R(t,x) IN \E x \in S . p(x) val original = letIn( exists( n_x, n_S, appOp( name( "p" ), n_x ) ), declOp( "p", forall( n_x, n_S, appOp( name( "R" ), name( "t" ), n_x ) ), "t" ) ) val expected = letIn( exists( name( "x_2" ), n_S, appOp( name( "p_1" ), name( "x_2" ) ) ), declOp( "p_1", forall( name( "x_1" ), n_S, appOp( name( "R" ), name( "t_1" ), name( "x_1" ) ) ), "t_1" ) ) val actual = renaming( original ) assert(expected == actual) } test( "Test renaming multiple LET-IN" ) { // LET X == TRUE IN X /\ LET X == FALSE IN X val original = and( letIn( appOp( name( "X" ) ), declOp( "X", trueEx ) ), letIn( appOp( name( "X" ) ), declOp( "X", falseEx ) ) ) val expected = and( letIn( appOp( name( "X_1" ) ), declOp( "X_1", trueEx ) ), letIn( appOp( name( "X_2" ) ), declOp( "X_2", falseEx ) ) ) val actual = renaming( original ) assert(expected == actual) } }
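The apalache examples use ScalaTest's FunSuite style with the same JUnit bridge, org.scalatest.junit.JUnitRunner (the package where the runner lived up to ScalaTest 3.0; later releases moved it into a separate scalatestplus artifact). Reduced to its essentials, the pattern used by TestRenaming above is roughly the following sketch; the counter and the test body are placeholders rather than apalache code:

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterEach, FunSuite}

// FunSuite + JUnitRunner sketch: beforeEach resets shared state, mirroring how
// TestRenaming recreates its Renaming instance before every test.
@RunWith(classOf[JUnitRunner])
class MinimalFunSuiteTest extends FunSuite with BeforeAndAfterEach {
  private var counter = 0

  override protected def beforeEach(): Unit = {
    counter = 0
  }

  test("counter starts from zero in every test") {
    counter += 1
    assert(counter == 1)
  }
}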
Example 91
Source File: TestLirValues.scala From apalache with Apache License 2.0 | 5 votes |
package at.forsyte.apalache.tla.lir import at.forsyte.apalache.tla.lir.values._ import org.junit.runner.RunWith import org.scalatest.FunSuite import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestLirValues extends FunSuite { test("create booleans") { val b = TlaBool(false) assert(!b.value) } test("create int") { val i = TlaInt(1) assert(i.value == BigInt(1)) assert(i == TlaInt(1)) assert(i.isNatural) assert(TlaInt(0).isNatural) assert(!TlaInt(-1).isNatural) } test("create a string") { val s = TlaStr("hello") assert(s.value == "hello") } test("create a constant") { val c = new TlaConstDecl("x") assert("x" == c.name) } test("create a variable") { val c = new TlaVarDecl("x") assert("x" == c.name) } }
Example 92
Source File: TestAux.scala From apalache with Apache License 2.0 | 5 votes |
package at.forsyte.apalache.tla.lir import org.junit.runner.RunWith import org.scalatest.FunSuite import org.scalatest.junit.JUnitRunner @RunWith( classOf[JUnitRunner] ) class TestAux extends FunSuite with TestingPredefs { test( "Test aux::collectSegments" ){ val ar0Decl1 = TlaOperDecl( "X", List.empty, n_x ) val ar0Decl2 = TlaOperDecl( "Y", List.empty, n_y ) val ar0Decl3 = TlaOperDecl( "Z", List.empty, n_z ) val arGe0Decl1 = TlaOperDecl( "A", List( SimpleFormalParam( "t" ) ), n_a ) val arGe0Decl2 = TlaOperDecl( "B", List( SimpleFormalParam( "t" ) ), n_b ) val arGe0Decl3 = TlaOperDecl( "C", List( SimpleFormalParam( "t" ) ), n_c ) val pa1 = List( ar0Decl1 ) -> List( List( ar0Decl1 ) ) val pa2 = List( ar0Decl1, ar0Decl2 ) -> List( List( ar0Decl1, ar0Decl2 ) ) val pa3 = List( arGe0Decl1, ar0Decl1 ) -> List( List( arGe0Decl1 ), List( ar0Decl1 ) ) val pa4 = List( arGe0Decl1, arGe0Decl2 ) -> List( List( arGe0Decl1, arGe0Decl2 ) ) val pa5 = List( arGe0Decl1, arGe0Decl2, ar0Decl1, ar0Decl2, arGe0Decl3 ) -> List( List( arGe0Decl1, arGe0Decl2 ), List( ar0Decl1, ar0Decl2 ), List( arGe0Decl3 ) ) val expected = Seq( pa1, pa2, pa3, pa4, pa5 ) val cmp = expected map { case (k, v) => (v, aux.collectSegments( k )) } cmp foreach { case (ex, act) => assert( ex == act ) } } }
Example 93
Source File: TestTypeReduction.scala From apalache with Apache License 2.0 | 5 votes |
package at.forsyte.apalache.tla.types import at.forsyte.apalache.tla.lir.TestingPredefs import org.junit.runner.RunWith import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.junit.JUnitRunner @RunWith( classOf[JUnitRunner] ) class TestTypeReduction extends FunSuite with TestingPredefs with BeforeAndAfter { var gen = new SmtVarGenerator var tr = new TypeReduction( gen ) before { gen = new SmtVarGenerator tr = new TypeReduction( gen ) } test( "Test nesting" ) { val tau = FunT( IntT, SetT( IntT ) ) val m = Map.empty[TypeVar, SmtTypeVariable] val rr = tr( tau, m ) assert( rr.t == fun( int, set( int ) ) ) } test("Test tuples"){ val tau = SetT( FunT( TupT( IntT, StrT ), SetT( IntT ) ) ) val m = Map.empty[TypeVar, SmtTypeVariable] val rr = tr(tau, m) val idx = SmtIntVariable( 0 ) assert( rr.t == set( fun( tup( idx ), set( int ) ) ) ) assert( rr.phi.contains( hasIndex( idx, 0, int ) ) ) assert( rr.phi.contains( hasIndex( idx, 1, str ) ) ) } }
Example 94
Source File: TestSymbStateRewriterStr.scala From apalache with Apache License 2.0 | 5 votes |
package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.values.TlaStr import at.forsyte.apalache.tla.lir.{NameEx, ValEx} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterStr extends RewriterBase { test("SE-STR-CTOR: \"red\" -> $C$k") { val state = new SymbState(ValEx(TlaStr("red")), CellTheory(), arena, new Binding) val rewriter = create() val nextStateRed = rewriter.rewriteUntilDone(state) nextStateRed.ex match { case predEx@NameEx(name) => assert(CellTheory().hasConst(name)) assert(CellTheory() == state.theory) assert(solverContext.sat()) val redEqBlue = tla.eql(tla.str("blue"), tla.str("red")) val nextStateEq = rewriter.rewriteUntilDone(nextStateRed.setRex(redEqBlue)) rewriter.push() solverContext.assertGroundExpr(nextStateEq.ex) assert(!solverContext.sat()) rewriter.pop() solverContext.assertGroundExpr(tla.not(nextStateEq.ex)) assert(solverContext.sat()) case _ => fail("Unexpected rewriting result") } } }
Example 95
Source File: TestVCGenerator.scala From apalache with Apache License 2.0 | 5 votes |
package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.tla.imp.SanyImporter import at.forsyte.apalache.tla.imp.src.SourceStore import at.forsyte.apalache.tla.lir.transformations.impl.IdleTracker import at.forsyte.apalache.tla.lir.{TlaModule, TlaOperDecl} import org.junit.runner.RunWith import org.scalatest.FunSuite import org.scalatest.junit.JUnitRunner import scala.io.Source @RunWith(classOf[JUnitRunner]) class TestVCGenerator extends FunSuite { private def mkVCGen(): VCGenerator = { new VCGenerator(new IdleTracker) } test("simple invariant") { val text = """---- MODULE inv ---- |EXTENDS Integers |VARIABLE x |Inv == x > 0 |==================== """.stripMargin val mod = loadFromText("inv", text) val newMod = mkVCGen().gen(mod, "Inv") assertDecl(newMod, "VCInv$0", "x > 0") assertDecl(newMod, "VCNotInv$0", "¬(x > 0)") } test("conjunctive invariant") { val text = """---- MODULE inv ---- |EXTENDS Integers |VARIABLE x |Inv == x > 0 /\ x < 10 |==================== """.stripMargin val mod = loadFromText("inv", text) val newMod = mkVCGen().gen(mod, "Inv") assertDecl(newMod, "VCInv$0", "x > 0") assertDecl(newMod, "VCInv$1", "x < 10") assertDecl(newMod, "VCNotInv$0", "¬(x > 0)") assertDecl(newMod, "VCNotInv$1", "¬(x < 10)") } test("conjunction under universals") { val text = """---- MODULE inv ---- |EXTENDS Integers |VARIABLE x, S |Inv == \A z \in S: \A y \in S: y > 0 /\ y < 10 |==================== """.stripMargin val mod = loadFromText("inv", text) val newMod = mkVCGen().gen(mod, "Inv") assertDecl(newMod, "VCInv$0", """∀z ∈ S: (∀y ∈ S: (y > 0))""") assertDecl(newMod, "VCInv$1", """∀z ∈ S: (∀y ∈ S: (y < 10))""") assertDecl(newMod, "VCNotInv$0", """¬(∀z ∈ S: (∀y ∈ S: (y > 0)))""") assertDecl(newMod, "VCNotInv$1", """¬(∀z ∈ S: (∀y ∈ S: (y < 10)))""") } private def assertDecl(mod: TlaModule, name: String, expectedBodyText: String): Unit = { val vc = mod.declarations.find(_.name == name) assert(vc.nonEmpty, s"(VC $name not found)") assert(vc.get.isInstanceOf[TlaOperDecl]) assert(vc.get.asInstanceOf[TlaOperDecl].body.toString == expectedBodyText) } private def loadFromText(moduleName: String, text: String): TlaModule = { val locationStore = new SourceStore val (rootName, modules) = new SanyImporter(locationStore) .loadFromSource(moduleName, Source.fromString(text)) modules(moduleName) } }
Example 96
Source File: TestTypeInference.scala From apalache with Apache License 2.0 | 5 votes |
package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.tla.bmcmt.types.{Signatures, TypeInference} import at.forsyte.apalache.tla.lir.TestingPredefs import at.forsyte.apalache.tla.lir.convenience._ import org.junit.runner.RunWith import org.scalatest.FunSuite import org.scalatest.junit.JUnitRunner // TODO: remove? @RunWith( classOf[JUnitRunner] ) class TestTypeInference extends FunSuite with TestingPredefs { ignore( "Signatures" ) { val exs = List( tla.and( n_x, n_y ), tla.choose( n_x, n_S, n_p ), tla.enumSet( seq( 10 ) : _* ), tla.in( n_x, n_S ), tla.map( n_e, n_x, n_S ) ) val sigs = exs map Signatures.get exs zip sigs foreach { case (x, y) => println( s"${x} ... ${y}" ) } val funDef = tla.funDef( tla.plus( n_x, n_y ), n_x, n_S, n_y, n_T ) val sig = Signatures.get( funDef ) printsep() println( sig ) printsep() } ignore( "TypeInference" ) { val ex = tla.and( tla.primeEq( n_a, tla.choose( n_x, n_S, n_p ) ), tla.in( 2, n_S ) ) val r = TypeInference.theta( ex ) println( r ) } ignore( "Application" ) { val ex = tla.eql( tla.plus( tla.appFun( n_f, n_x ) , 2), 4 ) val ex2 = tla.and( tla.in( n_x, n_S ), tla.le( tla.plus( tla.mult( 2, n_x ), 5 ), 10 ), tla.primeEq( n_x, tla.appFun( n_f, n_x ) ) ) val r = TypeInference( ex ) } }
Example 97
Source File: TestSymbStateRewriterChoose.scala From apalache with Apache License 2.0 | 5 votes |
package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.tla.bmcmt.types.{AnnotationParser, FinSetT, IntT} import at.forsyte.apalache.tla.lir.TestingPredefs import at.forsyte.apalache.tla.lir.convenience.tla import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterChoose extends RewriterBase with TestingPredefs { test("""CHOOSE x \in {1, 2, 3}: x > 1""") { val ex = tla.choose(tla.name("x"), tla.enumSet(tla.int(1), tla.int(2), tla.int(3)), tla.gt(tla.name("x"), tla.int(1))) val state = new SymbState(ex, CellTheory(), arena, new Binding) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) assert(solverContext.sat()) def assertEq(i: Int): SymbState = { val ns = rewriter.rewriteUntilDone(nextState.setRex(tla.eql(nextState.ex, tla.int(i)))) solverContext.assertGroundExpr(ns.ex) ns } rewriter.push() assertEq(3) assert(solverContext.sat()) rewriter.pop() rewriter.push() assertEq(2) assert(solverContext.sat()) rewriter.pop() rewriter.push() val ns = assertEq(1) assertUnsatOrExplain(rewriter, ns) } test("""CHOOSE x \in {1}: x > 1""") { val ex = tla.choose(tla.name("x"), tla.enumSet(tla.int(1)), tla.gt(tla.name("x"), tla.int(1))) val state = new SymbState(ex, CellTheory(), arena, new Binding) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) // the buggy implementation of choose fails on a dynamically empty set assert(solverContext.sat()) // The semantics of choose does not restrict the outcome on the empty sets, // so we do not test for anything here. Our previous implementation of CHOOSE produced default values in this case, // but this happened to be error-prone and sometimes conflicting with other rules. So, no default values. } test("""CHOOSE x \in {}: x > 1""") { val ex = tla.choose(tla.name("x"), tla.withType(tla.enumSet(), AnnotationParser.toTla(FinSetT(IntT()))), tla.gt(tla.name("x"), tla.int(1))) val state = new SymbState(ex, CellTheory(), arena, new Binding) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) // the buggy implementation of choose fails on a dynamically empty set assert(solverContext.sat()) def assertEq(i: Int): SymbState = { val ns = rewriter.rewriteUntilDone(nextState.setRex(tla.eql(nextState.ex, tla.int(i)))) solverContext.assertGroundExpr(ns.ex) ns } // Actually, semantics of choose does not restrict the outcome on the empty sets. // But we know that our implementation would always return 0 in this case. val ns = assertEq(1) assertUnsatOrExplain(rewriter, ns) } }
Example 98
Source File: TestSymbStateRewriterFiniteSets.scala From apalache with Apache License 2.0 | 5 votes |
package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.tla.bmcmt.types._ import at.forsyte.apalache.tla.lir.{NameEx, TlaEx} import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.oper.TlaFunOper import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterFiniteSets extends RewriterBase { test("""Cardinality({1, 2, 3}) = 3""") { val set = tla.enumSet(1.to(3).map(tla.int) :_*) val card = tla.card(set) val state = new SymbState(card, CellTheory(), arena, new Binding) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) assert(solverContext.sat()) assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.int(3), nextState.ex))) } test("""Cardinality({1, 2, 2, 2, 3, 3}) = 3""") { val set = tla.enumSet(Seq(1, 2, 2, 2, 3, 3).map(tla.int) :_*) val card = tla.card(set) val state = new SymbState(card, CellTheory(), arena, new Binding) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) assert(solverContext.sat()) assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.int(3), nextState.ex))) } test("""Cardinality({1, 2, 3} \ {2}) = 2""") { def setminus(set: TlaEx, intVal: Int): TlaEx = { tla.filter(tla.name("t"), set, tla.not(tla.eql(tla.name("t"), tla.int(intVal)))) } val set = setminus(tla.enumSet(1.to(3).map(tla.int) :_*), 2) val card = tla.card(set) val state = new SymbState(card, CellTheory(), arena, new Binding) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) assert(solverContext.sat()) assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.int(2), nextState.ex))) } test("""IsFiniteSet({1, 2, 3}) = TRUE""") { val set = tla.enumSet(1.to(3).map(tla.int) :_*) val card = tla.isFin(set) val state = new SymbState(card, CellTheory(), arena, new Binding) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) assert(solverContext.sat()) assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.bool(true), nextState.ex))) } }
Example 99
Source File: TestUninterpretedConstOracle.scala From apalache with Apache License 2.0 | 5 votes |
package at.forsyte.apalache.tla.bmcmt.rules.aux import at.forsyte.apalache.tla.bmcmt.types.BoolT import at.forsyte.apalache.tla.bmcmt.{Binding, CellTheory, RewriterBase, SymbState} import at.forsyte.apalache.tla.lir.TestingPredefs import at.forsyte.apalache.tla.lir.convenience.tla import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestUninterpretedConstOracle extends RewriterBase with TestingPredefs { test("""Oracle.create""") { val rewriter = create() var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding) // introduce an oracle val (nextState, oracle) = UninterpretedConstOracle.create(rewriter, state, 6) assert(solverContext.sat()) } test("""Oracle.whenEqualTo""") { val rewriter = create() var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding) // introduce an oracle val (nextState, oracle) = UninterpretedConstOracle.create(rewriter, state, 6) assert(solverContext.sat()) rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 3)) assert(solverContext.sat()) rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 4)) assert(!solverContext.sat()) } test("""Oracle.evalPosition""") { val rewriter = create() var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding) // introduce an oracle val (nextState, oracle) = UninterpretedConstOracle.create(rewriter, state, 6) assert(solverContext.sat()) rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 3)) assert(solverContext.sat()) val position = oracle.evalPosition(rewriter.solverContext, nextState) assert(3 == position) } test("""Oracle.caseAssertions""") { val rewriter = create() var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding) state = state.updateArena(_.appendCell(BoolT())) val flag = state.arena.topCell // introduce an oracle val (nextState, oracle) = UninterpretedConstOracle.create(rewriter, state, 2) // assert flag == true iff oracle = 0 rewriter.solverContext.assertGroundExpr(oracle.caseAssertions(nextState, Seq(flag.toNameEx, tla.not(flag.toNameEx)))) // assert oracle = 1 rewriter.push() rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 1)) assert(solverContext.sat()) assert(solverContext.evalGroundExpr(flag.toNameEx) == tla.bool(false)) rewriter.pop() // assert oracle = 0 rewriter.push() rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 0)) assert(solverContext.sat()) assert(solverContext.evalGroundExpr(flag.toNameEx) == tla.bool(true)) rewriter.pop() } }
Example 100
Source File: TestPropositionalOracle.scala From apalache with Apache License 2.0 | 5 votes |
package at.forsyte.apalache.tla.bmcmt.rules.aux import at.forsyte.apalache.tla.bmcmt.types.BoolT import at.forsyte.apalache.tla.bmcmt.{Binding, CellTheory, RewriterBase, SymbState} import at.forsyte.apalache.tla.lir.TestingPredefs import at.forsyte.apalache.tla.lir.convenience.tla import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestPropositionalOracle extends RewriterBase with TestingPredefs { test("""Oracle.create""") { val rewriter = create() var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding) // introduce an oracle val (nextState, oracle) = PropositionalOracle.create(rewriter, state, 6) assert(solverContext.sat()) } test("""Oracle.whenEqualTo""") { val rewriter = create() var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding) // introduce an oracle val (nextState, oracle) = PropositionalOracle.create(rewriter, state, 6) assert(solverContext.sat()) rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 3)) assert(solverContext.sat()) rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 4)) assert(!solverContext.sat()) } test("""Oracle.evalPosition""") { val rewriter = create() var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding) // introduce an oracle val (nextState, oracle) = PropositionalOracle.create(rewriter, state, 6) assert(solverContext.sat()) rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 3)) assert(solverContext.sat()) val position = oracle.evalPosition(rewriter.solverContext, nextState) assert(3 == position) } test("""Oracle.caseAssertions""") { val rewriter = create() var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding) state = state.updateArena(_.appendCell(BoolT())) val flag = state.arena.topCell // introduce an oracle val (nextState, oracle) = PropositionalOracle.create(rewriter, state, 2) // assert flag == true iff oracle = 0 rewriter.solverContext.assertGroundExpr(oracle.caseAssertions(nextState, Seq(flag.toNameEx, tla.not(flag.toNameEx)))) // assert oracle = 1 rewriter.push() rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 1)) assert(solverContext.sat()) assert(solverContext.evalGroundExpr(flag.toNameEx) == tla.bool(false)) rewriter.pop() // assert oracle = 0 rewriter.push() rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 0)) assert(solverContext.sat()) assert(solverContext.evalGroundExpr(flag.toNameEx) == tla.bool(true)) rewriter.pop() } }
Example 101
Source File: TestSymbStateRewriterAction.scala From apalache with Apache License 2.0 | 5 votes |
package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.tla.bmcmt.SymbStateRewriter.Continue import at.forsyte.apalache.tla.bmcmt.types.IntT import at.forsyte.apalache.tla.lir.NameEx import at.forsyte.apalache.tla.lir.convenience._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterAction extends RewriterBase { test("""SE-PRIME: x' ~~> NameEx(x')""") { val rewriter = create() arena.appendCell(IntT()) // the type finder is strict about unassigned types, so let's create a cell for x' val state = new SymbState(tla.prime(NameEx("x")), CellTheory(), arena, Binding("x'" -> arena.topCell)) rewriter.rewriteOnce(state) match { case Continue(next) => assert(next.ex == NameEx("x'")) case _ => fail("Expected x to be renamed to x'") } } }
Example 102
Source File: TestArena.scala From apalache with Apache License 2.0 | 5 votes |
package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.tla.bmcmt.types.{BoolT, FinSetT, UnknownT} import org.junit.runner.RunWith import org.scalatest.FunSuite import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestArena extends FunSuite { test("create cells") { val solverContext = new Z3SolverContext() val emptyArena = Arena.create(solverContext) val arena = emptyArena.appendCell(UnknownT()) assert(emptyArena.cellCount + 1 == arena.cellCount) assert(UnknownT() == arena.topCell.cellType) val arena2 = arena.appendCell(BoolT()) assert(emptyArena.cellCount + 2 == arena2.cellCount) assert(BoolT() == arena2.topCell.cellType) } test("add 'has' edges") { val solverContext = new Z3SolverContext() val arena = Arena.create(solverContext).appendCell(FinSetT(UnknownT())) val set = arena.topCell val arena2 = arena.appendCell(BoolT()) val elem = arena2.topCell val arena3 = arena2.appendHas(set, elem) assert(List(elem) == arena3.getHas(set)) } test("BOOLEAN has FALSE and TRUE") { val solverContext = new Z3SolverContext() val arena = Arena.create(solverContext) val boolean = arena.cellBooleanSet() assert(List(arena.cellFalse(), arena.cellTrue()) == arena.getHas(arena.cellBooleanSet())) } }
Example 103
Source File: TestSymbStateRewriterExpand.scala From apalache with Apache License 2.0 | 5 votes |
package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.tla.bmcmt.types._ import at.forsyte.apalache.tla.lir.{NameEx, OperEx, TlaEx} import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.oper.BmcOper import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterExpand extends RewriterBase { test("""Expand(SUBSET {1, 2})""") { val baseset = tla.enumSet(tla.int(1), tla.int(2)) val expandPowset = OperEx(BmcOper.expand, tla.powSet(baseset)) val state = new SymbState(expandPowset, CellTheory(), arena, new Binding) val rewriter = create() var nextState = rewriter.rewriteUntilDone(state) val powCell = nextState.asCell // check equality val eq = tla.eql(nextState.ex, tla.enumSet(tla.withType(tla.enumSet(), AnnotationParser.toTla(FinSetT(IntT()))), tla.enumSet(tla.int(1)), tla.enumSet(tla.int(2)), tla.enumSet(tla.int(1), tla.int(2)))) assertTlaExAndRestore(rewriter, nextState.setRex(eq)) } test("""Expand([{1, 2, 3} -> {FALSE, TRUE}]) should fail""") { val domain = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) val codomain = tla.enumSet(tla.bool(false), tla.bool(true)) val funSet = OperEx(BmcOper.expand, tla.funSet(domain, codomain)) val state = new SymbState(funSet, CellTheory(), arena, new Binding) val rewriter = create() assertThrows[RewriterException](rewriter.rewriteUntilDone(state)) } // Constructing an explicit set of functions is, of course, expensive. But it should work for small values. // Left for the future... ignore("""Expand([{1, 2} -> {FALSE, TRUE}]) should work""") { val domain = tla.enumSet(tla.int(1), tla.int(2)) val codomain = tla.enumSet(tla.bool(false), tla.bool(true)) val funSet = OperEx(BmcOper.expand, tla.funSet(domain, codomain)) val state = new SymbState(funSet, CellTheory(), arena, new Binding) val rewriter = create() var nextState = rewriter.rewriteUntilDone(state) val funSetCell = nextState.asCell def mkFun(v1: Boolean, v2: Boolean): TlaEx = { val mapEx = tla.ite(tla.eql(NameEx("x"), tla.int(1)), tla.bool(v1), tla.bool(v2)) tla.funDef(mapEx, tla.name("x"), domain) } val expected = tla.enumSet(mkFun(false, false), mkFun(false, true), mkFun(true, false), mkFun(true, true)) assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(expected, funSetCell.toNameEx))) } }
Example 104
Source File: TestSourceStore.scala From apalache with Apache License 2.0 | 5 votes |
package at.forsyte.apalache.tla.imp.src import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.src.SourceRegion import org.junit.runner.RunWith import org.scalatest.FunSuite import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSourceStore extends FunSuite { test("basic add and find") { val store = new SourceStore() val ex = tla.int(1) val loc = SourceLocation("root", SourceRegion(1, 2, 3, 4)) store.addRec(ex, loc) val foundLoc = store.find(ex.ID) assert(loc == foundLoc.get) } test("recursive add and find") { val store = new SourceStore() val int1 = tla.int(1) val set = tla.enumSet(int1) val loc = SourceLocation("root", SourceRegion(1, 2, 3, 4)) store.addRec(set, loc) val foundLoc = store.find(set.ID) assert(loc == foundLoc.get) val foundLoc2 = store.find(int1.ID) assert(loc == foundLoc2.get) } test("locations are not overwritten") { val store = new SourceStore() val int1 = tla.int(1) val set = tla.enumSet(int1) val set2 = tla.enumSet(set) val loc1 = SourceLocation("tada", SourceRegion(100, 200, 300, 400)) store.addRec(int1, loc1) val loc2 = SourceLocation("root", SourceRegion(1, 2, 3, 4)) store.addRec(set2, loc2) assert(loc2 == store.find(set2.ID).get) assert(loc2 == store.find(set.ID).get) assert(loc1 == store.find(int1.ID).get) } }
Example 105
Source File: TestRegionTree.scala From apalache with Apache License 2.0 | 5 votes |
package at.forsyte.apalache.tla.imp.src import at.forsyte.apalache.tla.lir.src.{RegionTree, SourcePosition, SourceRegion} import org.junit.runner.RunWith import org.scalatest.FunSuite import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestRegionTree extends FunSuite { test("add") { val tree = new RegionTree() val region = SourceRegion(SourcePosition(1, 20), SourcePosition(3, 10)) tree.add(region) } test("add a subregion, then size") { val tree = new RegionTree() val reg1 = SourceRegion(SourcePosition(1, 20), SourcePosition(3, 10)) tree.add(reg1) assert(tree.size == 1) val reg2 = SourceRegion(SourcePosition(1, 20), SourcePosition(2, 5)) tree.add(reg2) assert(tree.size == 2) val reg3 = SourceRegion(SourcePosition(2, 10), SourcePosition(3, 10)) tree.add(reg3) assert(tree.size == 3) } test("add an overlapping subregion") { val tree = new RegionTree() val reg1 = SourceRegion(SourcePosition(1, 10), SourcePosition(3, 10)) tree.add(reg1) val reg2 = SourceRegion(SourcePosition(1, 20), SourcePosition(5, 20)) assertThrows[IllegalArgumentException] { tree.add(reg2) } } test("add a small region, then a larger region") { val tree = new RegionTree() val reg1 = SourceRegion(SourcePosition(2, 10), SourcePosition(3, 10)) tree.add(reg1) val reg2 = SourceRegion(SourcePosition(1, 1), SourcePosition(4, 1)) tree.add(reg2) } test("add a region twice") { val tree = new RegionTree() val reg1 = SourceRegion(SourcePosition(2, 10), SourcePosition(3, 10)) tree.add(reg1) val reg2 = SourceRegion(SourcePosition(2, 10), SourcePosition(3, 10)) tree.add(reg2) } test("add and find") { val tree = new RegionTree() val region = SourceRegion(SourcePosition(1, 20), SourcePosition(3, 10)) val idx = tree.add(region) val found = tree(idx) assert(found == region) } test("find non-existing index") { val tree = new RegionTree() val region = SourceRegion(SourcePosition(1, 20), SourcePosition(3, 10)) val idx = tree.add(region) assertThrows[IndexOutOfBoundsException] { tree(999) } } }
Example 106
Source File: TestConstAndDefRewriter.scala From apalache with Apache License 2.0 | 5 votes |
package at.forsyte.apalache.tla.pp import at.forsyte.apalache.tla.imp.SanyImporter import at.forsyte.apalache.tla.imp.src.SourceStore import at.forsyte.apalache.tla.lir.{SimpleFormalParam, TlaOperDecl} import at.forsyte.apalache.tla.lir.convenience._ import at.forsyte.apalache.tla.lir.transformations.impl.IdleTracker import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{BeforeAndAfterEach, FunSuite} import scala.io.Source @RunWith(classOf[JUnitRunner]) class TestConstAndDefRewriter extends FunSuite with BeforeAndAfterEach { test("override a constant") { val text = """---- MODULE const ---- |CONSTANT n |OVERRIDE_n == 10 |A == {n} |================================ """.stripMargin val (rootName, modules) = new SanyImporter(new SourceStore) .loadFromSource("const", Source.fromString(text)) val root = modules(rootName) val rewritten = new ConstAndDefRewriter(new IdleTracker())(root) assert(rewritten.constDeclarations.isEmpty) // no constants anymore assert(rewritten.operDeclarations.size == 2) val expected_n = TlaOperDecl("n", List(), tla.int(10)) assert(expected_n == rewritten.operDeclarations.head) val expected_A = TlaOperDecl("A", List(), tla.enumSet(tla.appOp(tla.name("n")))) assert(expected_A == rewritten.operDeclarations(1)) } // In TLA+, constants may be operators with multiple arguments. // We do not support that yet. test("override a constant with a unary operator") { val text = """---- MODULE const ---- |CONSTANT n |OVERRIDE_n(x) == x |A == {n} |================================ """.stripMargin val (rootName, modules) = new SanyImporter(new SourceStore) .loadFromSource("const", Source.fromString(text)) val root = modules(rootName) assertThrows[OverridingError](new ConstAndDefRewriter(new IdleTracker())(root)) } test("overriding a variable with an operator => error") { val text = """---- MODULE const ---- |VARIABLE n, m |OVERRIDE_n == m |A == {n} |================================ """.stripMargin val (rootName, modules) = new SanyImporter(new SourceStore) .loadFromSource("const", Source.fromString(text)) val root = modules(rootName) assertThrows[OverridingError](new ConstAndDefRewriter(new IdleTracker())(root)) } test("override an operator") { val text = """---- MODULE op ---- |BoolMin(S) == CHOOSE x \in S: \A y \in S: x => y |OVERRIDE_BoolMin(S) == CHOOSE x \in S: TRUE |================================ """.stripMargin val (rootName, modules) = new SanyImporter(new SourceStore) .loadFromSource("op", Source.fromString(text)) val root = modules(rootName) val rewritten = new ConstAndDefRewriter(new IdleTracker())(root) assert(rewritten.constDeclarations.isEmpty) assert(rewritten.operDeclarations.size == 1) val expected = TlaOperDecl("BoolMin", List(SimpleFormalParam("S")), tla.choose(tla.name("x"), tla.name("S"), tla.bool(true))) assert(expected == rewritten.operDeclarations.head) } test("override a unary operator with a binary operator") { val text = """---- MODULE op ---- |BoolMin(S) == CHOOSE x \in S: \A y \in S: x => y |OVERRIDE_BoolMin(S, T) == CHOOSE x \in S: x \in T |================================ """.stripMargin val (rootName, modules) = new SanyImporter(new SourceStore) .loadFromSource("op", Source.fromString(text)) val root = modules(rootName) assertThrows[OverridingError](new ConstAndDefRewriter(new IdleTracker())(root)) } }
Example 107
Source File: TestUniqueNameGenerator.scala From apalache with Apache License 2.0 | 5 votes |
package at.forsyte.apalache.tla.pp import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{BeforeAndAfterEach, FunSuite} @RunWith(classOf[JUnitRunner]) class TestUniqueNameGenerator extends FunSuite with BeforeAndAfterEach { test("first three") { val gen = new UniqueNameGenerator assert("t_1" == gen.newName()) assert("t_2" == gen.newName()) assert("t_3" == gen.newName()) } test("after 10000") { val gen = new UniqueNameGenerator for (i <- 1.to(10000)) { gen.newName() } assert("t_7pt" == gen.newName()) } }
Example 108
Source File: DumThroAwayTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.semantics.compiled.plugin.csv import com.eharmony.aloha.audit.impl.OptionAuditor import com.eharmony.aloha.factory.ModelFactory import com.eharmony.aloha.semantics.compiled.CompiledSemantics import com.eharmony.aloha.semantics.compiled.compiler.TwitterEvalCompiler import org.junit.Assert.assertEquals import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner import scala.concurrent.ExecutionContext.Implicits.global @RunWith(classOf[BlockJUnit4ClassRunner]) class DumThroAwayTest { @Test def test1() { val compiler = TwitterEvalCompiler() val plugin = CompiledSemanticsCsvPlugin(Map("profile.user_id" -> CsvTypes.withNameExtended("oi"))) val imports = Seq("com.eharmony.aloha.feature.BasicFunctions._", "scala.math._") val semantics = CompiledSemantics(compiler, plugin, imports) val factory = ModelFactory.defaultFactory(semantics, OptionAuditor[Double]()) val model = factory.fromResource("fizzbuzz.json").get val lineProducer = CsvLines(Map("profile.user_id" -> 0)) val examples = "" :: (-16 to 16 map { _.toString }).toList val lines = lineProducer(examples) val expected = Seq( (None, -1.0), (Some(-16), 16.0), (Some(-15), -6.0), (Some(-14), 14.0), (Some(-13), 13.0), (Some(-12), -2.0), (Some(-11), 11.0), (Some(-10), -4.0), (Some(-9), -2.0), (Some(-8), 8.0), (Some(-7), 7.0), (Some(-6), -2.0), (Some(-5), -4.0), (Some(-4), 4.0), (Some(-3), -2.0), (Some(-2), 2.0), (Some(-1), 1.0), (Some(0), -6.0), (Some(1), 1.0), (Some(2), 2.0), (Some(3), -2.0), (Some(4), 4.0), (Some(5), -4.0), (Some(6), -2.0), (Some(7), 7.0), (Some(8), 8.0), (Some(9), -2.0), (Some(10), -4.0), (Some(11), 11.0), (Some(12), -2.0), (Some(13), 13.0), (Some(14), 14.0), (Some(15), -6.0), (Some(16), 16.0) ) val results = lines.map { l => (l.oi("profile.user_id"), model(l)) }. map { case (optId, s) => (optId, s.get) } assertEquals(expected, results) } }
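The aloha examples that follow drop the ScalaTest and specs2 bridges and run under plain JUnit 4: @RunWith(classOf[BlockJUnit4ClassRunner]) combined with @Test methods and org.junit.Assert. BlockJUnit4ClassRunner is JUnit 4's built-in default runner, so the annotation mostly makes the wiring explicit rather than changing behavior. A minimal sketch with illustrative names:

import org.junit.Assert.assertEquals
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

// Plain JUnit 4: no ScalaTest or specs2 involved, just annotated test methods.
@RunWith(classOf[BlockJUnit4ClassRunner])
class MinimalJUnit4Test {
  @Test def additionWorks(): Unit =
    assertEquals(4, 2 + 2)
}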
Example 109
Source File: CsvColumnTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.dataset.csv.json import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps} import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner import spray.json._ import spray.json.DefaultJsonProtocol._ @RunWith(classOf[BlockJUnit4ClassRunner]) class CsvColumnTest { @Test def test1() { val examples = Seq( """{ "name": "long", "type": "long", "spec": "${long}" }""", """{ "name": "opt_double", "type": "double", "spec": "${opt_double}" }""", """{ "name": "syn_enum", "type": "enum", "spec": "${opt_string}", "values": [ "e1v1" ] }""", """{ "name": "enum", "type": "enum", "spec": "${string}", "enumClass": "com.eharmony.matching.notaloha.AnEnum" }""" ) val expected = Seq( CsvColumnWithDefault[Long]("long", "${long}"), CsvColumnWithDefault[Double]("opt_double", "${opt_double}"), SyntheticEnumCsvColumn("syn_enum", "${opt_string}", Seq("e1v1")), EnumCsvColumn("enum", "${string}", "com.eharmony.matching.notaloha.AnEnum") ) val act = examples.map { ex => CsvColumn.csvColumnSpecFormat.read(ex.parseJson) } assertEquals(expected, act) } @Test def testReqEnum() { val jsonTxt = """{ "name": "some_enum", | "type": "enum", | "spec": "${string}", | "enumClass": "com.eharmony.matching.notaloha.AnEnum" |}""".stripMargin val json = jsonTxt.parseJson val col = json.convertTo[CsvColumn] assertTrue(col.isInstanceOf[EnumCsvColumn]) } @Test def testOptEnum() { val jsonTxt = """{ "name": "some_enum", | "type": "enum", | "spec": "${string}", | "enumClass": "com.eharmony.matching.notaloha.AnEnum", | "defVal": "VALUE_2", | "optional": true |}""".stripMargin val json = jsonTxt.parseJson val col = json.convertTo[CsvColumn] assertTrue(col.isInstanceOf[OptionEnumCsvColumn[_]]) } @Test def testSizedByte(): Unit = testSizedCreation[Byte] @Test def testSizedChar(): Unit = testSizedCreation[Char] @Test def testSizedShort(): Unit = testSizedCreation[Short] @Test def testSizedInt(): Unit = testSizedCreation[Int] @Test def testSizedLong(): Unit = testSizedCreation[Long] @Test def testSizedFloat(): Unit = testSizedCreation[Float] @Test def testSizedDouble(): Unit = testSizedCreation[Double] @Test def testSizedString(): Unit = testSizedCreation[String] private def testSizedCreation[A: RefInfo: JsonFormat]: Unit = { val tpe = RefInfoOps.toString(RefInfo[A]).split("\\.").last.toLowerCase val name = tpe.replaceAll("[aeiou]", "") val spec = "${string}" val jsonTxt = s"""{ "name": "$name", "type": "$tpe", "size": 2, "spec": "$spec"}""" val col = jsonTxt.parseJson.convertTo[CsvColumn] val exp = SeqCsvColumnWithNoDefault[A](name, spec, 2) assertEquals(exp, col) } }
Example 110
Source File: OptionCsvColumnWithDefaultTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.dataset.csv.json import com.eharmony.aloha.semantics.compiled.CompiledSemantics import com.eharmony.aloha.semantics.compiled.compiler.TwitterEvalCompiler import com.eharmony.aloha.semantics.compiled.plugin.csv.{CompiledSemanticsCsvPlugin, CsvLines, CsvTypes} import com.eharmony.aloha.semantics.func.GenAggFunc import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner import spray.json.DefaultJsonProtocol.DoubleJsonFormat import scala.concurrent.ExecutionContext.Implicits.global

// The @Test methods of this class are not part of this excerpt; a minimal class header is
// reconstructed here (name taken from the source file) so the fragment parses as shown.
@RunWith(classOf[BlockJUnit4ClassRunner])
class OptionCsvColumnWithDefaultTest {
  import OptionCsvColumnWithDefaultTest._ // brings TypedColumnCol into scope

  private[this] def compileOptFn[A, C](s: CompiledSemantics[A], c: TypedColumnCol[C]): GenAggFunc[A, Option[C]] = {
    s.createFunction[Option[C]](c.wrappedSpec, Some(c.defVal))(c.refInfo).fold(
      errs => throw new RuntimeException(s"Problem compiling function:\n${errs.mkString("\n")}"),
      fn => c match {
        case col: OptionCsvColumnWithDefault[C] => fn.andThenGenAggFunc(_ orElse c.defVal)
        case _ => fn
      }
    )
  }
}

private[json] object OptionCsvColumnWithDefaultTest {
  type TypedColumnCol[A] = CsvColumn { type ColType = A }

  private[this] val features = Seq(
    "height_mm" -> CsvTypes.DoubleOptionType,
    "height_cm" -> CsvTypes.IntType
  )

  private[this] val missing = ""

  // Test height actual data: height_mm [TAB] height_cm
  val lines = CsvLines(indices = features.unzip._1.zipWithIndex.toMap)(
    "1800\t180",
    s"$missing\t165"
  )

  lazy val plugin = CompiledSemanticsCsvPlugin(features: _*)
  lazy val semantics = CompiledSemantics(TwitterEvalCompiler(), plugin, Nil)
}
Example 111
Source File: ComparisonsTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.feature import org.junit.runners.BlockJUnit4ClassRunner import org.junit.runner.RunWith import org.junit.Test import org.junit.Assert._ @RunWith(classOf[BlockJUnit4ClassRunner]) class ComparisonsTest { import ComparisonsTest._ @Test def test_gtLt_1_1(): Unit = assertFalse(gtLt2(1, 1)) @Test def test_gtLt_1_2(): Unit = assertFalse(gtLt2(1, 2)) @Test def test_gtLt_1_3(): Unit = assertTrue(gtLt2(1, 3)) @Test def test_gtLt_2_1(): Unit = assertFalse(gtLt2(2, 1)) @Test def test_gtLt_2_2(): Unit = assertFalse(gtLt2(2, 2)) @Test def test_gtLt_2_3(): Unit = assertFalse(gtLt2(2, 3)) @Test def test_gtLt_3_1(): Unit = assertFalse(gtLt2(3, 1)) @Test def test_gtLt_3_2(): Unit = assertFalse(gtLt2(3, 2)) @Test def test_gtLt_3_3(): Unit = assertFalse(gtLt2(3, 3)) @Test def test_gtLte_1_1(): Unit = assertFalse(gtLte2(1, 1)) @Test def test_gtLte_1_2(): Unit = assertTrue(gtLte2(1, 2)) @Test def test_gtLte_1_3(): Unit = assertTrue(gtLte2(1, 3)) @Test def test_gtLte_2_1(): Unit = assertFalse(gtLte2(2, 1)) @Test def test_gtLte_2_2(): Unit = assertFalse(gtLte2(2, 2)) @Test def test_gtLte_2_3(): Unit = assertFalse(gtLte2(2, 3)) @Test def test_gtLte_3_1(): Unit = assertFalse(gtLte2(3, 1)) @Test def test_gtLte_3_2(): Unit = assertFalse(gtLte2(3, 2)) @Test def test_gtLte_3_3(): Unit = assertFalse(gtLte2(3, 3)) @Test def test_gteLt_1_1(): Unit = assertFalse(gteLt2(1, 1)) @Test def test_gteLt_1_2(): Unit = assertFalse(gteLt2(1, 2)) @Test def test_gteLt_1_3(): Unit = assertTrue(gteLt2(1, 3)) @Test def test_gteLt_2_1(): Unit = assertFalse(gteLt2(2, 1)) @Test def test_gteLt_2_2(): Unit = assertFalse(gteLt2(2, 2)) @Test def test_gteLt_2_3(): Unit = assertTrue(gteLt2(2, 3)) @Test def test_gteLt_3_1(): Unit = assertFalse(gteLt2(3, 1)) @Test def test_gteLt_3_2(): Unit = assertFalse(gteLt2(3, 2)) @Test def test_gteLt_3_3(): Unit = assertFalse(gteLt2(3, 3)) @Test def test_gteLte_1_1(): Unit = assertFalse(gteLte2(1, 1)) @Test def test_gteLte_1_2(): Unit = assertTrue(gteLte2(1, 2)) @Test def test_gteLte_1_3(): Unit = assertTrue(gteLte2(1, 3)) @Test def test_gteLte_2_1(): Unit = assertFalse(gteLte2(2, 1)) @Test def test_gteLte_2_2(): Unit = assertTrue(gteLte2(2, 2)) @Test def test_gteLte_2_3(): Unit = assertTrue(gteLte2(2, 3)) @Test def test_gteLte_3_1(): Unit = assertFalse(gteLte2(3, 1)) @Test def test_gteLte_3_2(): Unit = assertFalse(gteLte2(3, 2)) @Test def test_gteLte_3_3(): Unit = assertFalse(gteLte2(3, 3)) } object ComparisonsTest { import Comparisons._ val gtLt2: (Int, Int) => Boolean = gtLt(2, _, _) val gtLte2: (Int, Int) => Boolean = gtLe(2, _, _) val gteLt2: (Int, Int) => Boolean = geLt(2, _, _) val gteLte2: (Int, Int) => Boolean = geLe(2, _, _) }
Example 112
Source File: BasicFunctionsTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.feature import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner import BasicFunctions._ object BasicFunctionsTest { type KVPair = Iterable[(String, Double)] val one: KVPair = Iterable(("", 1.0)) } @RunWith(classOf[BlockJUnit4ClassRunner]) class BasicFunctionsTest { import BasicFunctionsTest._ @Test def testByteToSeq(): Unit = testAnyValToSeq(1.toByte, one) @Test def testShortToSeq(): Unit = testAnyValToSeq(1.toShort, one) @Test def testIntToSeq(): Unit = testAnyValToSeq(1, one) @Test def testLongToSeq(): Unit = testAnyValToSeq(1L, one) @Test def testFloatToSeq(): Unit = testAnyValToSeq(1f, one) @Test def testDoubleToSeq(): Unit = testAnyValToSeq(1d, one) @Test def testOptByteToSeq(): Unit = assertEquals(one, Option(1.toByte).toKv) @Test def testOptShortToSeq(): Unit = assertEquals(one, Option(1.toShort).toKv) @Test def testOptIntToSeq(): Unit = assertEquals(one, Option(1).toKv) @Test def testOptLongToSeq(): Unit = assertEquals(one, Option(1L).toKv) @Test def testOptFloatToSeq(): Unit = assertEquals(one, Option(1f).toKv) @Test def testOptDoubleToSeq(): Unit = assertEquals(one, Option(1d).toKv) @Test def testNoneByteToSeq(): Unit = testNoneToSeq[Byte] @Test def testNoneShortToSeq(): Unit = testNoneToSeq[Short] @Test def testNoneIntToSeq(): Unit = testNoneToSeq[Int] @Test def testNoneLongToSeq(): Unit = testNoneToSeq[Long] @Test def testNoneFloatToSeq(): Unit = testNoneToSeq[Float] @Test def testNoneDoubleToSeq(): Unit = testNoneToSeq[Double] def testAnyValToSeq[A](a: A, kv: KVPair)(implicit f: A => KVPair): Unit = assertEquals(kv, f(a)) def testNoneToSeq[A](implicit f: A => Double): Unit = assertEquals(Nil, Option.empty[A].toKv) }
Example 113
Source File: SparsityTransformsTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.feature import org.junit.runners.BlockJUnit4ClassRunner import org.junit.runner.RunWith import org.junit.Test import org.junit.Assert._ import scala.util.Random @RunWith(classOf[BlockJUnit4ClassRunner]) class SparsityTransformsTest { import SparsityTransforms._ @Test def testSparifiedDensify() { implicit val r = new Random(0) (1 to 100) foreach { i => { val n = r.nextInt(100) val d = Seq.fill(r.nextInt(10000))(r.nextInt(1000)) val k = Seq.fill(n)(r.nextInt(1000)) val v = Seq.fill(n)(r.nextDouble()) val m = k.zip(v).toMap val f = m.get _ assertTrue(s" test $i iterable: ", parIterableInverseLaw(d, k, v)) assertTrue(s" test $i map: ", mapInverseLaw(d, m)) assertTrue(s" test $i function: ", fnInverseLaw(d, f)) }} } @Test def testDensifyPI() { val res = densifyPI(_3to6, Array(4, 6), Array(1, 2), 0) assertEquals(Vector(0, 1, 0, 2), res) // Show off the cool CBF / functor stuff. Vector because 3 to 6 is an IndexedSeq. assertEquals("scala.collection.immutable.Vector", res.getClass.getCanonicalName) } @Test def testDensifyPIwithEmptyKeys() { val res = densifyPI(_3to6, Seq.empty, Array(1, 2), 0) assertEquals(Vector.fill(4)(0), res) } @Test def testDensifyPIwithEmptyValues() { val res = densifyPI(_3to6, Array(4, 6), Seq.empty, 0) assertEquals(Vector.fill(4)(0), res) } @Test def testDensifyF() { val f = _map.get _ val res = densifyFn(_3to6, f, 0) assertEquals(Vector(0, 1, 0, 2), res) // Show off the cool CBF / functor stuff. Vector because 3 to 6 is an IndexedSeq. assertEquals("scala.collection.immutable.Vector", res.getClass.getCanonicalName) } @Test def testDensifyMap() { val res = densifyMap(_3to6, _map, 0) assertEquals(Vector(0, 1, 0, 2), res) // Show off the cool CBF / functor stuff. Vector because 3 to 6 is an IndexedSeq. assertEquals("scala.collection.immutable.Vector", res.getClass.getCanonicalName) } }
Example 114
Source File: FactoryImportedModelTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.factory import com.eharmony.aloha.audit.impl.OptionAuditor import com.eharmony.aloha.factory.ex.{AlohaFactoryException, RecursiveModelDefinitionException} import com.eharmony.aloha.semantics.NoSemantics import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner @RunWith(classOf[BlockJUnit4ClassRunner]) class FactoryImportedModelTest { private[this] val factory = ModelFactory.defaultFactory(NoSemantics[Any](), OptionAuditor[Int]()) @Test(expected = classOf[RecursiveModelDefinitionException]) def test1CycleDetected() { factory.fromResource("com/eharmony/aloha/factory/cycle1_A.json").get } @Test(expected = classOf[RecursiveModelDefinitionException]) def test2CycleDetected() { factory.fromResource("com/eharmony/aloha/factory/cycle2_A.json").get } @Test(expected = classOf[RecursiveModelDefinitionException]) def test3CycleDetected() { factory.fromResource("com/eharmony/aloha/factory/cycle3_A.json").get } @Test def test1LevelSuccessDefault() { val m = factory.fromResource("com/eharmony/aloha/factory/success_1_level_default.json").get assertEquals(Option(1), m(null)) } @Test def test1LevelSuccessVfs1() { val m = factory.fromResource("com/eharmony/aloha/factory/success_1_level_vfs1.json").get assertEquals(Option(3), m(null)) } @Test def test1LevelSuccessVfs2() { val m = factory.fromResource("com/eharmony/aloha/factory/success_1_level_vfs2.json").get assertEquals(Option(4), m(null)) } @Test def test1LevelSuccessFile() { val m = factory.fromResource("com/eharmony/aloha/factory/success_1_level_file.json").get assertEquals(Option(2), m(null)) } @Test def test1LevelAppropriateFailureWithDefaultProtocol() { try { factory.fromResource("com/eharmony/aloha/factory/bad_reference_default.json").get fail() } catch { case e: AlohaFactoryException => assertTrue(e.getMessage.startsWith("Couldn't resolve VFS2 file")) case e: Exception => fail() } } @Test def test1LevelAppropriateFailureWithVfs1() { try { factory.fromResource("com/eharmony/aloha/factory/bad_reference_vfs1.json").get fail() } catch { case e: AlohaFactoryException => assertTrue(e.getMessage, e.getMessage.startsWith("Couldn't resolve VFS1 file")) case e: Exception => fail() } } @Test def test1LevelAppropriateFailureWithVfs2() { try { factory.fromResource("com/eharmony/aloha/factory/bad_reference_vfs2.json").get fail() } catch { case e: AlohaFactoryException => assertTrue(e.getMessage.startsWith("Couldn't resolve VFS2 file")) case e: Exception => fail() } } @Test def test1LevelApproprivateFailureWithFile() { try { factory.fromResource("com/eharmony/aloha/factory/bad_reference_file.json").get fail() } catch { case e: AlohaFactoryException => assertTrue(e.getMessage.startsWith("Couldn't get JSON for file")) case e: Exception => fail() } } }
Example 115
Source File: FormatsTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.factory import com.eharmony.matching.notaloha.AnEnum import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner import spray.json.DefaultJsonProtocol.jsonFormat1 import spray.json._ object FormatsTest { case class GenEnumPossessor[E <: Enum[E]](value: E) case class EnumPossessor(value: AnEnum) implicit val AnEnumFormat = JavaJsonFormats.enumFormat(classOf[AnEnum]) implicit val EnumPossessorFormat: RootJsonFormat[EnumPossessor] = jsonFormat1(EnumPossessor) } @RunWith(classOf[BlockJUnit4ClassRunner]) class FormatsTest { import FormatsTest._ @Test(expected = classOf[DeserializationException]) def testEnumFormatValue1(): Unit = """{ "value": "VALUE_1" }""".parseJson.convertTo[EnumPossessor] @Test def testEnumFormatValue2(): Unit = assertEquals(EnumPossessor(AnEnum.VALUE_2), """{ "value": "VALUE_2" }""".parseJson.convertTo[EnumPossessor]) @Test def testEnumFormatValue3(): Unit = assertEquals(EnumPossessor(AnEnum.VALUE_3), """{ "value": "VALUE_3" }""".parseJson.convertTo[EnumPossessor]) @Test def testGenEnumFormatValue3(): Unit = { val clas = Class.forName(classOf[AnEnum].getName) implicit val ge = geFormat(clas) val v = """{ "value": "VALUE_3" }""".parseJson.convertTo(ge) assertEquals(GenEnumPossessor(AnEnum.VALUE_3), v) } def geFormat[E <: Enum[E]](clas: Class[_]): RootJsonFormat[GenEnumPossessor[E]] = { implicit val ef = Formats.enumFormat(clas.asInstanceOf[Class[E]]) jsonFormat1(GenEnumPossessor[E]) } }
Example 116
Source File: ErrorModelTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.models import com.eharmony.aloha.ModelSerializationTestHelper import com.eharmony.aloha.audit.impl.OptionAuditor import com.eharmony.aloha.audit.impl.tree.RootedTreeAuditor import com.eharmony.aloha.factory.ModelFactory import com.eharmony.aloha.id.ModelId import com.eharmony.aloha.semantics.NoSemantics import org.junit.Assert.{assertEquals, assertNotNull, assertTrue} import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner @RunWith(classOf[BlockJUnit4ClassRunner]) class ErrorModelTest extends ModelSerializationTestHelper { private val factory = ModelFactory.defaultFactory(NoSemantics[Unit](), OptionAuditor[Byte]()) @Test def test1() { val em = ErrorModel(ModelId(), Seq("There should be a valid user ID. Couldn't find one...", "blah blah"), RootedTreeAuditor.noUpperBound[Byte]()) val s = em(null) assertNotNull(s) assertTrue(s.value.isEmpty) } @Test def testEmptyErrors() { val json = """ |{ | "modelType": "Error", | "modelId": { "id": 0, "name": "" } |} """.stripMargin val m1 = factory.fromString(json) assertTrue(m1.isSuccess) val json2 = """ |{ | "modelType": "Error", | "modelId": { "id": 0, "name": "" }, | "errors": [] |} """.stripMargin val m2 = factory.fromString(json2) assertTrue(m2.isSuccess) } @Test def testSerialization(): Unit = { val m = ErrorModel(ModelId(2, "abc"), Seq("def", "ghi"), OptionAuditor[Byte]()) val m1 = serializeDeserializeRoundTrip(m) assertEquals(m, m1) } }
Example 117
Source File: BigModelParseTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.models.reg import com.eharmony.aloha.models.reg.json.RegressionModelJson import spray.json.pimpString import com.eharmony.aloha.io.StringReadable import org.junit.runners.BlockJUnit4ClassRunner import org.junit.runner.RunWith import org.junit.Test import org.junit.Assert._ import com.eharmony.aloha.util.{Logging, Timing} @RunWith(classOf[BlockJUnit4ClassRunner]) class BigModelParseTest extends RegressionModelJson with Timing with Logging { @Test def testBigJsonParsedToAstForRegModel() { val ((s, data), t) = time(getBigZippedData("/com/eharmony/aloha/models/reg/semi_cleaned_big_model.json.gz")) assertTrue(s"Should take less than 10 seconds to parse, took $t", t < 10) assertEquals("file lines", 184846, scala.io.Source.fromString(s).getLines().size) assertEquals("Features", 94, data.features.size) assertEquals("First order weights", 874, data.weights.size) assertEquals("Higher order weights", 30598, data.higherOrderFeatures.map(_.size).getOrElse(0)) assertEquals("spline size", 341, data.spline.map(_.knots.size).getOrElse(0)) debug("file lines: 184846, features: 94, first order weights: 874, higher order weights: 30598, spline size: 341") } private[this] def getBigZippedData(resourcePath: String) = { val s = StringReadable.gz.fromResource(resourcePath) (s, s.parseJson.convertTo[RegData]) } }
Example 118
Source File: PolynomialEvaluationAlgoTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.models.reg import com.eharmony.aloha.audit.impl.OptionAuditor import com.eharmony.aloha.factory.ModelFactory import com.eharmony.aloha.reflect.RefInfo import com.eharmony.aloha.semantics.Semantics import com.eharmony.aloha.semantics.func.{GenAggFunc, GenFunc, GeneratedAccessor} import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner @RunWith(classOf[BlockJUnit4ClassRunner]) class PolynomialEvaluationAlgoTest { private[this] val expected = (1 << 7) - 1 // 0111 1111 private[this] val tolerance = 1.0e-6 private[this] val accessor = (key: String) => (_:Map[String, String]).get(key).map(k => Seq((k, 1.0))).getOrElse(Nil) private[this] val semantics = new Semantics[Map[String, String]] { def close(): Unit = {} def refInfoA: RefInfo[Map[String, String]] = RefInfo[Map[String, String]] def accessorFunctionNames = Nil def createFunction[B: RefInfo](codeSpec: String, default: Option[B]): Either[Seq[String], GenAggFunc[Map[String, String], B]] = { val acc = GeneratedAccessor(codeSpec, accessor(codeSpec)) val f = GenFunc.f1(acc)(codeSpec, identity) Right(f.asInstanceOf[GenAggFunc[Map[String, String], B]]) } } private[this] val factory = ModelFactory.defaultFactory(semantics, OptionAuditor[Double]()) @Test def testManualPolyEval() { val x = IndexedSeq( Seq(("intercept", 1.0)), Seq(("female_country=1", 1.0)), Seq(("male_country=2", 1.0)), Seq(("user_gender=MALE", 1.0)), Seq(("cand_gender=FEMALE", 1.0)) ) val weightPaths = Map[Map[String, Int], Double]( Map("intercept" -> 0 ) -> (1 << 0), Map("female_country=1" -> 1 ) -> (1 << 1), Map("male_country=2" -> 2 ) -> (1 << 2), Map("user_gender=MALE" -> 3 ) -> (1 << 3), Map("female_country=1" -> 1, "user_gender=MALE" -> 3) -> (1 << 4), Map("female_country=1" -> 1, "cand_gender=FEMALE" -> 4) -> (1 << 5), Map("male_country=2" -> 2, "user_gender=MALE" -> 3) -> (1 << 6) ) val w = (PolynomialEvaluator.builder ++= weightPaths).result() val y = w at x assertEquals(expected, y, tolerance) assertEquals(weightPaths.values.sum, y, tolerance) } @Test def testJsonParsedPolyEval() { val jStr = """ |{ | "modelType": "Regression", | "modelId": { "id": 0, "name": "" }, | "features": { | "intercept": "intercept", | "female_country": "female_country", | "male_country": "male_country", | "user_gender": "user_gender", | "cand_gender": "cand_gender" | }, | "weights": { | "intercept": 1, | "female_country=1": 2, | "male_country=2": 4, | "user_gender=MALE": 8 | }, | "higherOrderFeatures": [ | { "features": { "female_country": ["female_country=1"], "user_gender": ["user_gender=MALE"] }, "wt": 16 }, | { "features": { "female_country": ["female_country=1"], "cand_gender": ["cand_gender=FEMALE"] }, "wt": 32 }, | { "features": { "male_country": ["male_country=2"], "user_gender": ["user_gender=MALE"] }, "wt": 64 } | ] |} """.stripMargin.trim val m = factory.fromString(jStr).get val x = Map( "intercept" -> "", "female_country" -> "=1", "male_country" -> "=2", "user_gender" -> "=MALE", "cand_gender" -> "=FEMALE" ) val score = m(x) assertEquals(expected, score.get, tolerance) } }
Example 119
Source File: ConstantModelTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.models import com.eharmony.aloha.ModelSerializationTestHelper import com.eharmony.aloha.audit.impl.OptionAuditor import com.eharmony.aloha.id.ModelId import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner @RunWith(classOf[BlockJUnit4ClassRunner]) class ConstantModelTest extends ModelSerializationTestHelper { @Test def testSerialization(): Unit = { val m = ConstantModel(Option(1), ModelId(2, "abc"), OptionAuditor[Int]()) val m1 = serializeDeserializeRoundTrip(m) assertEquals(m, m1) val m2 = ConstantModel(None: Option[String], ModelId(3, "abc"), OptionAuditor[String]()) val m3 = serializeDeserializeRoundTrip(m2) assertEquals(m2, m3) } }
Example 120
Source File: ConstantModelParserTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.models import com.eharmony.aloha.audit.impl.OptionAuditor import com.eharmony.aloha.factory.ModelFactory import com.eharmony.aloha.factory.ex.AlohaFactoryException import com.eharmony.aloha.semantics.NoSemantics import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner @RunWith(classOf[BlockJUnit4ClassRunner]) class ConstantModelParserTest { private val factory = ModelFactory.defaultFactory(NoSemantics[String](), OptionAuditor[Int]()) @Test def testValueOnly() { val js = """ |{ | "modelType": "Constant", | "modelId": {"id": 0, "name": ""}, | "value": 1 |} """.stripMargin val m = factory.fromString(js).get val s = m(null) assertEquals(Option(1), s) } @Test(expected = classOf[Exception]) def testNoOutputSpecified() { val js = """ |{ | "modelType": "Constant", | "modelId": {"id": 0, "name": ""} |} """.stripMargin val m = factory.fromString(js) m.get } @Test(expected = classOf[Exception]) def testNoModelIdSpecified() { val js = """ |{ | "modelType": "Constant", | "value": 1 |} """.stripMargin val m = factory.fromString(js) m.get } @Test(expected = classOf[AlohaFactoryException]) def testNothingSpecified() { val js = """ |{ |} """.stripMargin val m = factory.fromString(js) m.get } }
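The parser tests above rely on JUnit 4's expected-exception support: with @Test(expected = classOf[...]), BlockJUnit4ClassRunner marks the test as passing only if the method throws the declared exception type (or a subclass of it). A small illustrative sketch of the same mechanism:

import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

// Illustrative sketch: the test passes only because the declared
// exception type is actually thrown by the method body.
@RunWith(classOf[BlockJUnit4ClassRunner])
class ExpectedExceptionSketch {
  @Test(expected = classOf[ArithmeticException])
  def integerDivisionByZeroThrows(): Unit = {
    val zero = 0
    1 / zero
  }
}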
Example 121
Source File: ErrorModelParserTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.models import com.eharmony.aloha.audit.impl.OptionAuditor import com.eharmony.aloha.factory.ModelFactory import com.eharmony.aloha.semantics.NoSemantics import org.junit.Assert.assertTrue import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner @RunWith(classOf[BlockJUnit4ClassRunner]) class ErrorModelParserTest { private val factory = ModelFactory.defaultFactory(NoSemantics[String](), OptionAuditor[Int]()) @Test def testErrorsFieldMissing() { val js = """ |{ | "modelType": "Error", | "modelId": {"id":0, "name": ""} |} """.stripMargin val m = factory.fromString(js) assertTrue(m.isSuccess) } @Test def test0Errors() { val js = """ |{ | "modelType": "Error", | "modelId": {"id":0, "name": ""}, | "errors": [] |} """.stripMargin val m = factory.fromString(js) assertTrue(m.isSuccess) } @Test def test1Error() { val js = """ |{ | "modelType": "Error", | "modelId": {"id":0, "name": ""}, | "errors": [ | "error 1" | ] |} """.stripMargin val m = factory.fromString(js) assertTrue(m.isSuccess) } @Test def test2Errors() { val js = """ |{ | "modelType": "Error", | "modelId": {"id":0, "name": ""}, | "errors": [ | "error 1", | "error 2" | ] |} """.stripMargin val m = factory.fromString(js) assertTrue(m.isSuccess) } }
Example 122
Source File: PkgTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner @RunWith(classOf[BlockJUnit4ClassRunner]) class PkgTest { @Test def testPkgLocationOk(): Unit = { assertEquals("com.eharmony.aloha", pkgName) } @Test def testVersionFormatOk(): Unit = { val ok = """(\d+)\.(\d+)\.(\d+)(-(SNAPSHOT))?""".r version match { case ok(major, minor, fix, _, snapshot) => () case notOk => fail(s"Bad version format: $notOk") } } }
Example 123
Source File: HeaderTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.dataset.csv

import com.eharmony.aloha.dataset.RowCreatorBuilder
import com.eharmony.aloha.semantics.compiled.CompiledSemantics
import com.eharmony.aloha.semantics.compiled.compiler.TwitterEvalCompiler
import com.eharmony.aloha.semantics.compiled.plugin.csv._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner
import org.junit.Assert.assertEquals

import scala.concurrent.ExecutionContext.Implicits.global

// NOTE: the class declaration below is assumed; it is collapsed in this
// excerpt, along with the original file's @Test methods.
@RunWith(classOf[BlockJUnit4ClassRunner])
class HeaderTest {
  private def dsJson(encoding: String) =
    s"""
       |{
       |  "imports": [],
       |  "separator": ",",
       |  "nullValue": "null",
       |  "encoding": "$encoding",
       |  "features": [
       |    { "spec": "1 to 3", "type": "int", "size": 3, "name": "vec" },
       |    { "spec": "\\"some_string_value\\"", "type": "string", "name": "str" },
       |    { "spec": "4", "type": "double", "name": "doub" },
       |    { "spec": "true", "type": "boolean", "name": "bool" },
       |    { "spec": "com.eharmony.matching.notaloha.AnEnum.VALUE_2",
       |      "type": "enum",
       |      "enumClass": "com.eharmony.matching.notaloha.AnEnum",
       |      "name": "enum"
       |    }
       |  ]
       |}
     """.stripMargin

  private def csvRowCreator(encoding: String) = {
    val json = dsJson(encoding)
    val plugin = CompiledSemanticsCsvPlugin()
    val semantics = CompiledSemantics(TwitterEvalCompiler(classCacheDir = None), plugin, Nil)
    val sb = RowCreatorBuilder(semantics, List(CsvRowCreator.Producer[CsvLine]()))
    sb.fromString(json).get
  }

  // Since dsJson doesn't rely on any input data, this can be anything, including null.
  private val EmptyLine: CsvLineImpl = null
}
Example 124
Source File: CsvTypesTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.semantics.compiled.plugin.csv import org.junit.Test import org.junit.runners.BlockJUnit4ClassRunner import org.junit.runner.RunWith import org.junit.Assert._ @RunWith(classOf[BlockJUnit4ClassRunner]) class CsvTypesTest { @Test def testNumTypesCorrect() { assertEquals("Wrong number of types found in CsvTypes", 28, CsvTypes.values.size) } @Test def testTypeMethodCorrespondence() { val typeNames = CsvTypes.values.map(_.toString).toSet val methodNames = classOf[CsvLine].getDeclaredMethods.map(_.getName).toSet val typesWithoutMethods = typeNames -- methodNames val methodsWithoutTypes = methodNames -- typeNames assertEquals(s"The following types in CsvTypes seem to be missing methods in CsvLine: ${typesWithoutMethods.mkString("{", ", ", "}" )}", 0, typesWithoutMethods.size) assertEquals(s"The following methods in CsvLine don't seem to have associated types in CsvTypes: ${methodsWithoutTypes.mkString("{", ", ", "}" )}", 0, methodsWithoutTypes.size) } }
Example 125
Source File: CompiledSemanticsTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.semantics.compiled import java.{lang => jl} import com.eharmony.aloha.FileLocations import com.eharmony.aloha.reflect.RefInfo import com.eharmony.aloha.semantics.compiled.compiler.TwitterEvalCompiler import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner import scala.concurrent.ExecutionContext.Implicits.global import scala.language.implicitConversions @RunWith(classOf[BlockJUnit4ClassRunner]) class CompiledSemanticsTest { private[this] val compiler = TwitterEvalCompiler(classCacheDir = Option(FileLocations.testGeneratedClasses)) @Test def test0() { val s = CompiledSemantics(compiler, MapStringLongPlugin, Seq()) val f = s.createFunction[Int]("List(${five:-5L}).sum.toInt").right.get val x1 = Map("five" -> 1L) val x2 = Map.empty[String, Long] assertEquals(1, f(x1)) assertEquals(5, f(x2)) } @Test def test1() { val s = CompiledSemantics(compiler, MapStringLongPlugin, Seq()) val f = s.createFunction[Int]("List(${one}, ${two}, ${three}).sum.toInt", Option(Int.MinValue)).right.get val x1 = Map[String, Long]("one" -> 2, "two" -> 4, "three" -> 6) val y1 = f(x1) assertEquals(12, y1) } @Test def test2() { val s = CompiledSemantics(compiler, MapStringLongPlugin, Seq()) val f = s.createFunction[Double]("${user.inboundComm} / ${user.pageViews}.toDouble", Some(Double.NaN)).right.get val x1 = Map[String, Long]("user.inboundComm" -> 5, "user.pageViews" -> 10) val x2 = Map[String, Long]("user.inboundComm" -> 5) val y1 = f(x1) val y2 = f(x2) assertEquals(0.5, y1, 1.0e-6) assertEquals(Double.NaN, y2, 0) } @Test def test3() { val s = CompiledSemantics(compiler, MapStringLongPlugin, Seq()) val f = s.createFunction[Long]("new util.Random(0).nextLong").right.get val y1 = f(null) assertEquals(-4962768465676381896L, y1) } @Test def testNullDefaultOnExistingValue() { val s = CompiledSemantics(compiler, MapStringLongPlugin, Seq("com.eharmony.aloha.semantics.compiled.StaticFuncs._")) val f = s.createFunction[Long]("f(${one})").left.map(_.foreach(println)).right.get val y1 = f(Map("one" -> 1)) assertEquals(18, y1) } @Test def testNullDefaultOnNonMissingPrimitiveValue() { val s = CompiledSemantics(compiler, MapStringLongPlugin, Seq("com.eharmony.aloha.semantics.compiled.StaticFuncs._")) var errors: Seq[String] = Nil val f = s.createFunction[Long]("f(${missing:-null}.asInstanceOf[java.lang.Long])"). left.map(e => errors = e). right.get val y1 = f(Map("missing" -> 13)) assertEquals("Should process correctly when defaulting to null", 18, y1) assertEquals("No errors should appear", 0, errors.size) } private[this] object MapStringLongPlugin extends CompiledSemanticsPlugin[Map[String, Long]] { def refInfoA = RefInfo[Map[String, Long]] def accessorFunctionCode(spec: String) = { val required = Seq("user.inboundComm", "one", "two", "three") spec match { case s if required contains s => Right(RequiredAccessorCode(Seq("(_:Map[String, Long]).apply(\"" + spec + "\")"))) case _ => Right(OptionalAccessorCode(Seq("(_:Map[String, Long]).get(\"" + spec + "\")"))) } } } } object StaticFuncs { def f(a: jl.Long): Long = if (null == a) 13 else 18 implicit def doubletoJlDouble(d: Double): java.lang.Double = java.lang.Double.valueOf(d) }
Example 126
Source File: PostgresJsonMarshallerTest.scala From sundial with MIT License | 5 votes |
package dao.postgres.marshalling import com.fasterxml.jackson.databind.PropertyNamingStrategy.SNAKE_CASE import com.hbc.svc.sundial.v2.models.NotificationOptions import model.{EmailNotification, PagerdutyNotification, Team} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatestplus.play.PlaySpec import util.Json @RunWith(classOf[JUnitRunner]) class PostgresJsonMarshallerTest extends PlaySpec { private val postgresJsonMarshaller = new PostgresJsonMarshaller() private val objectMapper = Json.mapper() objectMapper.setPropertyNamingStrategy(SNAKE_CASE) // objectMapper.setVisibility(PropertyAccessor.FIELD,Visibility.ANY) "PostgresJsonMarshaller" should { "correctly deserialize a json string into Seq[Team]" in { val json = """ | [{ | "name" : "teamName", | "email" : "teamEmail", | "notify_action": "on_state_change_and_failures" | }] """.stripMargin val expectedTeams: Seq[Team] = Vector(Team("teamName", "teamEmail", "on_state_change_and_failures")) val actualTeams = postgresJsonMarshaller.toTeams(json) actualTeams must be(expectedTeams) } "correctly serialise a Seq[Team] in a json string" in { val expectedJson = """ | [{ | "name" : "teamName", | "email" : "teamEmail", | "notify_action": "on_state_change_and_failures" | }] """.stripMargin val expectedTeams: Seq[Team] = Vector(Team("teamName", "teamEmail", "on_state_change_and_failures")) val actualJson = postgresJsonMarshaller.toJson(expectedTeams) objectMapper.readTree(actualJson) must be( objectMapper.readTree(expectedJson)) } "correctly deserialize a json string into Seq[Notification]" in { val json = """ |[{"name":"name","email":"email","notify_action":"on_state_change_and_failures", "type": "email"},{"service_key":"service-key","api_url":"http://google.com", "type": "pagerduty","num_consecutive_failures":1}] """.stripMargin val notifications = Vector( EmailNotification( "name", "email", NotificationOptions.OnStateChangeAndFailures.toString), PagerdutyNotification("service-key", "http://google.com", 1) ) val actualNotifications = postgresJsonMarshaller.toNotifications(json) actualNotifications must be(notifications) } "correctly serialise a Seq[Notification] in a json string" in { val json = """ |[{"name":"name","email":"email","notify_action":"on_state_change_and_failures", "type": "email"},{"service_key":"service-key","api_url":"http://google.com", "type": "pagerduty","num_consecutive_failures":1}] """.stripMargin val notifications = Vector( EmailNotification( "name", "email", NotificationOptions.OnStateChangeAndFailures.toString), PagerdutyNotification("service-key", "http://google.com", 1) ) println(s"bla1: ${postgresJsonMarshaller.toJson(notifications)}") println(s"bla2: ${objectMapper.writeValueAsString(notifications)}") objectMapper.readTree(json) must be( objectMapper.readTree(postgresJsonMarshaller.toJson(notifications))) } } }
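As in most of the ScalaTest-based examples on this page, @RunWith(classOf[JUnitRunner]) is what lets a ScalaTest suite (here a PlaySpec) be discovered and executed by tooling that only knows how to run JUnit tests, such as Maven Surefire or an IDE's JUnit integration. A minimal illustrative sketch, assuming ScalaTest 2.x/3.0.x where JUnitRunner lives in org.scalatest.junit (as in the imports above):

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}

// Illustrative sketch: a plain ScalaTest suite made runnable by JUnit tooling.
@RunWith(classOf[JUnitRunner])
class MinimalJUnitRunnerSketch extends FlatSpec with Matchers {
  "a List" should "prepend with +:" in {
    (0 +: List(1, 2)) shouldBe List(0, 1, 2)
  }
}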
Example 127
Source File: CronScheduleSpec.scala From sundial with MIT License | 5 votes |
package model import java.text.ParseException import java.util.GregorianCalendar import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatestplus.play.PlaySpec @RunWith(classOf[JUnitRunner]) class CronScheduleSpec extends PlaySpec { "Cron scheduler" should { "successfully parse cron entry for 10pm every day" in { val cronSchedule = CronSchedule("0", "22", "*", "*", "?") val date = new GregorianCalendar(2015, 10, 5, 21, 0).getTime val expectedNextDate = new GregorianCalendar(2015, 10, 5, 22, 0).getTime val nextDate = cronSchedule.nextRunAfter(date) nextDate must be(expectedNextDate) } "Throw exception on creation if cron schedlue is invalid" in { intercept[ParseException] { CronSchedule("0", "22", "*", "*", "*") } } } }
Example 128
Source File: PluginTest.scala From marathon-vault-plugin with MIT License | 5 votes |
package com.avast.marathon.plugin.vault import java.util.concurrent.TimeUnit import com.bettercloud.vault.{Vault, VaultConfig} import org.junit.runner.RunWith import org.scalatest.{FlatSpec, Matchers} import org.scalatest.junit.JUnitRunner import scala.collection.JavaConverters._ import scala.concurrent.Await import scala.concurrent.duration.Duration @RunWith(classOf[JUnitRunner]) class PluginTest extends FlatSpec with Matchers { private lazy val marathonUrl = s"http://${System.getProperty("marathon.host")}:${System.getProperty("marathon.tcp.8080")}" private lazy val mesosSlaveUrl = s"http://${System.getProperty("mesos-slave.host")}:${System.getProperty("mesos-slave.tcp.5051")}" private lazy val vaultUrl = s"http://${System.getProperty("vault.host")}:${System.getProperty("vault.tcp.8200")}" it should "read existing shared secret" in { check("SECRETVAR", env => deployWithSecret("testappjson", env, "/test@testKey")) { envVarValue => envVarValue shouldBe "testValue" } } it should "read existing private secret" in { check("SECRETVAR", env => deployWithSecret("testappjson", env, "test@testKey")) { envVarValue => envVarValue shouldBe "privateTestValue" } } it should "read existing private secret from application in folder" in { check("SECRETVAR", env => deployWithSecret("folder/testappjson", env, "test@testKey")) { envVarValue => envVarValue shouldBe "privateTestFolderValue" } } it should "fail when using .. in secret" in { intercept[RuntimeException] { check("SECRETVAR", env => deployWithSecret("folder/testappjson", env, "test/../test@testKey"), java.time.Duration.ofSeconds(1)) { envVarValue => envVarValue shouldNot be("privateTestFolderValue") } } } private def deployWithSecret(appId: String, envVarName: String, secret: String): String = { val json = s"""{ "id": "$appId","cmd": "${EnvAppCmd.create(envVarName)}","env": {"$envVarName": {"secret": "pwd"}},"secrets": {"pwd": {"source": "$secret"}}}""" val marathonResponse = new MarathonClient(marathonUrl).put(appId, json) appId } private def check(envVarName: String, deployApp: String => String, timeout: java.time.Duration = java.time.Duration.ofSeconds(30))(verifier: String => Unit): Unit = { val client = new MarathonClient(marathonUrl) val eventStream = new MarathonEventStream(marathonUrl) val vaultConfig = new VaultConfig().address(vaultUrl).token("testroottoken").build() val vault = new Vault(vaultConfig) vault.logical().write("secret/shared/test", Map[String, AnyRef]("testKey" -> "testValue").asJava) vault.logical().write("secret/private/testappjson/test", Map[String, AnyRef]("testKey" -> "privateTestValue").asJava) vault.logical().write("secret/private/folder/testappjson/test", Map[String, AnyRef]("testKey" -> "privateTestFolderValue").asJava) val appId = deployApp(envVarName) val appCreatedFuture = eventStream.when(_.eventType.contains("deployment_success")) Await.result(appCreatedFuture, Duration.create(20, TimeUnit.SECONDS)) val agentClient = MesosAgentClient(mesosSlaveUrl) val state = agentClient.fetchState() try { val envVarValue = agentClient.waitForStdOutContentsMatch(envVarName, state.frameworks(0).executors(0), o => EnvAppCmd.extractEnvValue(envVarName, o), timeout) verifier(envVarValue) } finally { client.delete(appId) val appRemovedFuture = eventStream.when(_.eventType.contains("deployment_success")) Await.result(appRemovedFuture, Duration.create(20, TimeUnit.SECONDS)) eventStream.close() } } }
Example 129
Source File: MinMaxActorSpec.scala From coral with Apache License 2.0 | 5 votes |
package io.coral.actors.transform import akka.actor.{Actor, ActorSystem, Props} import akka.testkit.{TestProbe, ImplicitSender, TestActorRef, TestKit} import akka.util.Timeout import io.coral.actors.CoralActorFactory import io.coral.api.DefaultModule import org.json4s.JsonDSL._ import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike} import scala.concurrent.duration._ @RunWith(classOf[JUnitRunner]) class MinMaxActorSpec(_system: ActorSystem) extends TestKit(_system) with ImplicitSender with WordSpecLike with Matchers with BeforeAndAfterAll { implicit val timeout = Timeout(100.millis) implicit val formats = org.json4s.DefaultFormats implicit val injector = new DefaultModule(system.settings.config) def this() = this(ActorSystem("ZscoreActorSpec")) override def afterAll() { TestKit.shutdownActorSystem(system) } "A MinMaxActor" must { val createJson = parse( """{ "type": "minmax", "params": { "field": "field1", "min": 10.0, "max": 13.5 }}""" .stripMargin).asInstanceOf[JObject] implicit val injector = new DefaultModule(system.settings.config) val props = CoralActorFactory.getProps(createJson).get val threshold = TestActorRef[MinMaxActor](props) // subscribe the testprobe for emitting val probe = TestProbe() threshold.underlyingActor.emitTargets += probe.ref "Emit the minimum when lower than the min" in { val json = parse( """{"field1": 7 }""").asInstanceOf[JObject] threshold ! json probe.expectMsg(parse( """{ "field1": 10.0 }""")) } "Emit the maximum when higher than the max" in { val json = parse( """{"field1": 15.3 }""").asInstanceOf[JObject] threshold ! json probe.expectMsg(parse( """{"field1": 13.5 }""")) } "Emit the value itself when between the min and the max" in { val json = parse( """{"field1": 11.7 }""").asInstanceOf[JObject] threshold ! json probe.expectMsg(parse( """{"field1": 11.7 }""")) } "Emit object unchanged when key is not present in triggering json" in { val json = parse( """{"otherfield": 15.3 }""").asInstanceOf[JObject] threshold ! json probe.expectMsg(parse( """{"otherfield": 15.3 }""")) } } }
Example 130
Source File: ThresholdActorSpec.scala From coral with Apache License 2.0 | 5 votes |
package io.coral.actors.transform import io.coral.actors.CoralActorFactory import io.coral.api.DefaultModule import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import scala.concurrent.duration._ import akka.actor.ActorSystem import akka.testkit._ import akka.util.Timeout import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike} @RunWith(classOf[JUnitRunner]) class ThresholdActorSpec(_system: ActorSystem) extends TestKit(_system) with ImplicitSender with WordSpecLike with Matchers with BeforeAndAfterAll { implicit val timeout = Timeout(100.millis) def this() = this(ActorSystem("ThresholdActorSpec")) override def afterAll() { TestKit.shutdownActorSystem(system) } "A ThresholdActor" must { val createJson = parse( """{ "type": "threshold", "params": { "key": "key1", "threshold": 10.5 }}""" .stripMargin).asInstanceOf[JObject] implicit val injector = new DefaultModule(system.settings.config) // test invalid definition json as well !!! val props = CoralActorFactory.getProps(createJson).get val threshold = TestActorRef[ThresholdActor](props) // subscribe the testprobe for emitting val probe = TestProbe() threshold.underlyingActor.emitTargets += probe.ref "Emit when equal to the threshold" in { val json = parse( """{"key1": 10.5}""").asInstanceOf[JObject] threshold ! json probe.expectMsg(parse( """{ "key1": 10.5 }""")) } "Emit when higher than the threshold" in { val json = parse( """{"key1": 10.7}""").asInstanceOf[JObject] threshold ! json probe.expectMsg(parse( """{"key1": 10.7 }""")) } "Not emit when lower than the threshold" in { val json = parse( """{"key1": 10.4 }""").asInstanceOf[JObject] threshold ! json probe.expectNoMsg() } "Not emit when key is not present in triggering json" in { val json = parse( """{"key2": 10.7 }""").asInstanceOf[JObject] threshold ! json probe.expectNoMsg() } } }
Example 131
Source File: BootConfigSpec.scala From coral with Apache License 2.0 | 5 votes |
package io.coral.api import org.junit.runner.RunWith import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, WordSpecLike} import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class BootConfigSpec extends WordSpecLike with BeforeAndAfterAll with BeforeAndAfterEach { "A Boot program actor" should { "Properly process given command line arguments for api and akka ports" in { val commandLine = CommandLineConfig(apiPort = Some(1234), akkaPort = Some(5345)) val actual: CoralConfig = io.coral.api.Boot.getFinalConfig(commandLine) assert(actual.akka.remote.nettyTcpPort == 5345) assert(actual.coral.api.port == 1234) } "Properly process a given configuration file through the command line" in { val configPath = getClass().getResource("bootconfigspec.conf").getFile() val commandLine = CommandLineConfig(config = Some(configPath), apiPort = Some(4321)) val actual: CoralConfig = io.coral.api.Boot.getFinalConfig(commandLine) // Overriden in bootconfigspec.conf assert(actual.akka.remote.nettyTcpPort == 6347) // Overridden in command line parameter assert(actual.coral.api.port == 4321) // Not overriden in command line or bootconfigspec.conf assert(actual.coral.cassandra.port == 9042) } } }
Example 132
Source File: RuntimeStatisticsSpec.scala From coral with Apache License 2.0 | 5 votes |
package io.coral.api import io.coral.TestHelper import org.junit.runner.RunWith import org.scalatest.WordSpecLike import org.scalatest.junit.JUnitRunner import org.json4s._ import org.json4s.jackson.JsonMethods._ @RunWith(classOf[JUnitRunner]) class RuntimeStatisticsSpec extends WordSpecLike { "A RuntimeStatistics class" should { "Properly sum multiple statistics objects together" in { val counters1 = Map( (("actor1", "stat1") -> 100L), (("actor1", "stat2") -> 20L), (("actor1", "stat3") -> 15L)) val counters2 = Map( (("actor2", "stat1") -> 20L), (("actor2", "stat2") -> 30L), (("actor2", "stat3") -> 40L)) val counters3 = Map( (("actor2", "stat1") -> 20L), (("actor2", "stat2") -> 30L), (("actor2", "stat3") -> 40L), (("actor2", "stat4") -> 12L)) val stats1 = RuntimeStatistics(1, 2, 3, counters1) val stats2 = RuntimeStatistics(2, 3, 4, counters2) val stats3 = RuntimeStatistics(4, 5, 6, counters3) val actual = RuntimeStatistics.merge(List(stats1, stats2, stats3)) val expected = RuntimeStatistics(7, 10, 13, Map(("actor1", "stat1") -> 100, ("actor1", "stat2") -> 20, ("actor1", "stat3") -> 15, ("actor2", "stat1") -> 20, ("actor2", "stat2") -> 30, ("actor2", "stat3") -> 40, ("actor2", "stat4") -> 12)) assert(actual == expected) } "Create a JSON object from a RuntimeStatistics object" in { val input = RuntimeStatistics(1, 2, 3, Map((("actor1", "stat1") -> 10L), (("actor1", "stat2") -> 20L))) val expected = parse( s"""{ | "totalActors": 1, | "totalMessages": 2, | "totalExceptions": 3, | "counters": { | "total": { | "stat1": 10, | "stat2": 20 | }, "actor1": { | "stat1": 10, | "stat2": 20 | } | } |} """.stripMargin).asInstanceOf[JObject] val actual = RuntimeStatistics.toJson(input) assert(actual == expected) } "Create a RuntimeStatistics object from a JSON object" in { val input = parse( s"""{ | "totalActors": 1, | "totalMessages": 2, | "totalExceptions": 3, | "counters": { | "total": { | "stat1": 10, | "stat2": 20 | }, "actor1": { | "stat1": 10, | "stat2": 20 | } | } |} """.stripMargin).asInstanceOf[JObject] val actual = RuntimeStatistics.fromJson(input) val expected = RuntimeStatistics(1, 2, 3, Map((("actor1", "stat1") -> 10L), (("actor1", "stat2") -> 20L))) assert(actual == expected) } } }
Example 133
Source File: XmlScoverageReportParserSpec.scala From sonar-scala with GNU Lesser General Public License v3.0 | 5 votes |
package com.buransky.plugins.scoverage.xml import org.scalatest.{FlatSpec, Matchers} import org.scalatest.junit.JUnitRunner import org.junit.runner.RunWith import com.buransky.plugins.scoverage.ScoverageException @RunWith(classOf[JUnitRunner]) class XmlScoverageReportParserSpec extends FlatSpec with Matchers { behavior of "parse file path" it must "fail for null path" in { the[IllegalArgumentException] thrownBy XmlScoverageReportParser().parse(null.asInstanceOf[String], null) } it must "fail for empty path" in { the[IllegalArgumentException] thrownBy XmlScoverageReportParser().parse("", null) } it must "fail for not existing path" in { the[ScoverageException] thrownBy XmlScoverageReportParser().parse("/x/a/b/c/1/2/3/4.xml", null) } }
Example 134
Source File: PathUtilSpec.scala From sonar-scala with GNU Lesser General Public License v3.0 | 5 votes |
package com.buransky.plugins.scoverage.util import org.scalatest.{FlatSpec, Matchers} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class PathUtilSpec extends FlatSpec with Matchers { val osName = System.getProperty("os.name") val separator = System.getProperty("file.separator") behavior of s"splitPath for $osName" it should "ignore the empty path" in { PathUtil.splitPath("") should equal(List.empty[String]) } it should "ignore a separator at the beginning" in { PathUtil.splitPath(s"${separator}a") should equal(List("a")) } it should "work with separator in the middle" in { PathUtil.splitPath(s"a${separator}b") should equal(List("a", "b")) } it should "work with an OS dependent absolute path" in { if (osName.startsWith("Windows")) { PathUtil.splitPath("C:\\test\\2") should equal(List("test", "2")) } else { PathUtil.splitPath("/test/2") should equal(List("test", "2")) } } }
Example 135
Source File: BasicSimulation.scala From warp-core with MIT License | 5 votes |
package com.workday.warp.adapters import com.workday.warp.adapters.gatling.{GatlingJUnitRunner, WarpSimulation} import io.gatling.core.Predef._ import io.gatling.http.Predef._ import org.junit.runner.RunWith import io.gatling.core.structure.ScenarioBuilder import io.gatling.http.protocol.HttpProtocolBuilder @RunWith(classOf[GatlingJUnitRunner]) class BasicSimulation extends WarpSimulation { val httpConf: HttpProtocolBuilder = http .baseUrl("http://google.com") val scn: ScenarioBuilder = scenario("Positive Scenario") .exec( http("request_1").get("/") ) setUp(scn.inject(atOnceUsers(1)).protocols(httpConf)) }
Example 136
Source File: BotPluginTestKit.scala From sumobot with Apache License 2.0 | 5 votes |
package com.sumologic.sumobot.test.annotated import akka.actor.ActorSystem import akka.testkit.{TestKit, TestProbe} import com.sumologic.sumobot.core.model.{IncomingMessage, InstantMessageChannel, OutgoingMessage, UserSender} import org.junit.runner.RunWith import org.scalatest.concurrent.Eventually import org.scalatest.junit.JUnitRunner import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike} import slack.models.User import scala.concurrent.duration.{FiniteDuration, _} @RunWith(classOf[JUnitRunner]) abstract class BotPluginTestKit(actorSystem: ActorSystem) extends TestKit(actorSystem) with WordSpecLike with Eventually with Matchers with BeforeAndAfterAll { protected val outgoingMessageProbe = TestProbe() system.eventStream.subscribe(outgoingMessageProbe.ref, classOf[OutgoingMessage]) protected def confirmOutgoingMessage(test: OutgoingMessage => Unit, timeout: FiniteDuration = 1.second): Unit = { outgoingMessageProbe.expectMsgClass(timeout, classOf[OutgoingMessage]) match { case msg: OutgoingMessage => test(msg) } } protected def instantMessage(text: String, user: User = mockUser("123", "jshmoe")): IncomingMessage = { IncomingMessage(text, true, InstantMessageChannel("125", user), "1527239216000090", sentBy = UserSender(user)) } protected def mockUser(id: String, name: String): User = { User(id, name, None, None, None, None, None, None, None, None, None, None, None, None, None, None) } protected def send(message: IncomingMessage): Unit = { system.eventStream.publish(message) } override protected def afterAll(): Unit = { TestKit.shutdownActorSystem(system) } }
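BotPluginTestKit is an abstract base class, and because org.junit.runner.RunWith is itself marked @Inherited in JUnit 4, concrete suites that extend it run under JUnitRunner without repeating the annotation. An illustrative sketch of such a subclass (the class name and assertion are hypothetical, but they only use members shown in the base class above):

import akka.actor.ActorSystem

// Hypothetical concrete suite: the @RunWith(classOf[JUnitRunner]) annotation
// is inherited from BotPluginTestKit, so none is declared here.
class ExamplePluginTest extends BotPluginTestKit(ActorSystem("ExamplePluginTest")) {
  "the test kit" should {
    "build mock users with the requested name" in {
      mockUser("1", "jane").name shouldBe "jane"
    }
  }
}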
Example 137
Source File: ModelSerializabilityTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.models.vw.jni import com.eharmony.aloha.ModelSerializabilityTestBase import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner @RunWith(classOf[BlockJUnit4ClassRunner]) class ModelSerializabilityTest extends ModelSerializabilityTestBase( Seq(ModelSerializabilityTest.pkg), Seq( ".*Test.*", ".*\\$.*" ) ) object ModelSerializabilityTest { def pkg = getClass.getPackage.getName }
Example 138
Source File: ControlThrowable.scala From lacasa with BSD 3-Clause "New" or "Revised" License | 5 votes |
package lacasa.neg import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.junit.Test import lacasa.util._ @RunWith(classOf[JUnit4]) class ControlThrowableSpec { @Test def test1() { println(s"ControlThrowableSpec.test1") expectError("propagated") { """ class C { import scala.util.control.ControlThrowable import lacasa.Box._ def m(): Unit = { try { val x = 0 val y = x + 10 println(s"res: ${x + y}") } catch { case t: ControlThrowable => println("hello") uncheckedCatchControl } } } """ } } @Test def test2() { println(s"ControlThrowableSpec.test2") expectError("propagated") { """ class C { import scala.util.control.ControlThrowable def m(): Unit = { try { throw new ControlThrowable {} } catch { case t: Throwable => println("hello") } } } """ } } @Test def test3() { println(s"ControlThrowableSpec.test3") expectError("propagated") { """ class SpecialException(msg: String) extends RuntimeException class C { import scala.util.control.ControlThrowable def m(): Unit = { val res = try { 5 } catch { case s: SpecialException => println("a") case c: ControlThrowable => println("b") case t: Throwable => println("c") } } } """ } } }
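The lacasa examples in this and the following snippets pass classOf[JUnit4] instead of classOf[BlockJUnit4ClassRunner]. org.junit.runners.JUnit4 is a thin alias for whatever JUnit 4 currently uses as its default runner (it extends BlockJUnit4ClassRunner), so the two choices behave the same today; the alias is simply more future-proof. An illustrative sketch:

import org.junit.Assert.assertTrue
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4

// Illustrative sketch: JUnit4 aliases the default runner, so this class runs
// exactly as it would with @RunWith(classOf[BlockJUnit4ClassRunner]).
@RunWith(classOf[JUnit4])
class DefaultRunnerAliasSketch {
  @Test def comparisonHolds(): Unit = assertTrue(1 < 2)
}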
Example 139
Source File: CaptureSpec.scala From lacasa with BSD 3-Clause "New" or "Revised" License | 5 votes |
package lacasa.test.plugin.capture import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.junit.Test import lacasa.util._ @RunWith(classOf[JUnit4]) class CaptureSpec { @Test def test() { println(s"CaptureSpec.test") expectError("invalid reference to value acc") { """ import lacasa.{Box, Packed} import Box._ import scala.spores._ class Data { var name: String = _ } class Data2 { var num: Int = _ var dat: Data = _ } object Use { mkBox[Data] { packed => implicit val acc = packed.access val box: packed.box.type = packed.box box.open { _.name = "John" } mkBox[Data2] { packed2 => implicit val acc2 = packed2.access val box2: packed2.box.type = packed2.box box2.capture(box)((x, y) => x.dat = y)(spore { val localBox = box (packedData: Packed[Data2]) => implicit val accessData = packedData.access localBox.open { x => assert(false) } }) } } } """ } } }
Example 140
Source File: BoxOcap.scala From lacasa with BSD 3-Clause "New" or "Revised" License | 5 votes |
package lacasa.neg import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.junit.Test import lacasa.util._ @RunWith(classOf[JUnit4]) class BoxOcapSpec { @Test def test1() { println("neg.BoxOcapSpec.test1") expectError("NonOcap") { """ object Global { var state = "a" } class NonOcap { def doIt(): Unit = { Global.state = "b" } } class Data { var arr: Array[Int] = _ } class Test { import lacasa.Box._ import scala.spores._ def m(): Unit = { mkBox[Data] { packed => // ok, Data ocap implicit val acc = packed.access packed.box.open(spore { (d: Data) => d.arr = Array(0, 1, 2) // ok val obj = new NonOcap // not ok: cannot inst. non-ocap class }) } } } """ } } }
Example 141
Source File: Stack2.scala From lacasa with BSD 3-Clause "New" or "Revised" License | 5 votes |
package lacasa.neg import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.junit.Test import lacasa.util._ @RunWith(classOf[JUnit4]) class Stack2Spec { @Test def test1() { println(s"Stack2Spec.test1") expectError("confined") { """ class D { } class C { import scala.spores._ import lacasa.Box def m(): Unit = { Box.mkBox[D] { packed => val fun = () => { val acc = packed.access } } } } """ } } @Test def test2() { println(s"Stack2Spec.test2") expectError("propagated") { """ class D { var arr: Array[Int] = _ } class C { import scala.spores._ import lacasa.Box def m(): Unit = { try { Box.mkBox[D] { packed => val access = packed.access } } catch { case ct: scala.util.control.ControlThrowable => } } } """ } } }
Example 142
Source File: Stack1.scala From lacasa with BSD 3-Clause "New" or "Revised" License | 5 votes |
package lacasa.run import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 import scala.util.control.ControlThrowable class Message { var arr: Array[Int] = _ } @RunWith(classOf[JUnit4]) class Stack1Spec { import lacasa.Box._ @Test def test1(): Unit = { println(s"run.Stack1Spec.test1") try { mkBox[Message] { packed => implicit val access = packed.access packed.box open { msg => msg.arr = Array(1, 2, 3, 4) } } } catch { case ct: ControlThrowable => uncheckedCatchControl assert(true, "this should not fail!") } } }
Example 143
Source File: Control.scala From lacasa with BSD 3-Clause "New" or "Revised" License | 5 votes |
package lacasa.run import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 @RunWith(classOf[JUnit4]) class ControlSpec { import scala.util.control.ControlThrowable import lacasa.Box._ @Test def test1(): Unit = { println("run.ControlSpec.test1") val res = try { 5 } catch { case c: ControlThrowable => throw c case t: Throwable => println("hello") } assert(res == 5, "this should not fail") } }
Example 144
Source File: BoxSpec.scala From lacasa with BSD 3-Clause "New" or "Revised" License | 5 votes |
package lacasa.test import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 import lacasa.Box._ class DoesNotHaveNoArgCtor(val num: Int) { def incNum = new DoesNotHaveNoArgCtor(num + 1) } @RunWith(classOf[JUnit4]) class BoxSpec { @Test def testMkBoxFor1(): Unit = { try { mkBoxFor(new DoesNotHaveNoArgCtor(0)) { packed => implicit val access = packed.access val box: packed.box.type = packed.box box.open { dnh => assert(dnh.num == 0) } } } catch { case t: Throwable => } } }
Example 145
Source File: example1.scala From lacasa with BSD 3-Clause "New" or "Revised" License | 5 votes |
package lacasa.test.examples import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 import scala.concurrent.ExecutionContext import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.{Future, Promise, Await} import scala.concurrent.duration._ import scala.spores._ import lacasa.{System, Box, CanAccess, Actor, ActorRef, doNothing} import Box._ class Message1 { var arr: Array[Int] = _ } class Start { var next: ActorRef[Message1] = _ } class ActorA extends Actor[Any] { override def receive(b: Box[Any]) (implicit acc: CanAccess { type C = b.C }) { b.open(spore { x => x match { case s: Start => mkBox[Message1] { packed => implicit val access = packed.access packed.box open { msg => msg.arr = Array(1, 2, 3, 4) } s.next.send(packed.box) { doNothing.consume(packed.box) } } case other => // .. } }) } } class ActorB(p: Promise[String]) extends Actor[Message1] { override def receive(box: Box[Message1]) (implicit acc: CanAccess { type C = box.C }) { // Strings are Safe, and can therefore be extracted from the box. p.success(box.extract(_.arr.mkString(","))) } } @RunWith(classOf[JUnit4]) class Spec { @Test def test(): Unit = { // to check result val p: Promise[String] = Promise() val sys = System() val a = sys.actor[ActorA, Any] val b = sys.actor[Message1](new ActorB(p)) try { mkBox[Start] { packed => import packed.access val box: packed.box.type = packed.box box open { s => s.next = capture(b) // !!! captures `b` within `open` } a.send(box) { doNothing.consume(packed.box) } } } catch { case t: Throwable => val res = Await.result(p.future, 2.seconds) assert(res == "1,2,3,4") } } }
Example 146
Source File: CaptureSpec.scala From lacasa with BSD 3-Clause "New" or "Revised" License | 5 votes |
package lacasa.test.capture import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 import scala.spores._ import scala.spores.SporeConv._ import lacasa.{Box, Packed} import Box._ class Data { var name: String = _ } class Data2 { var num: Int = _ var dat: Data = _ } @RunWith(classOf[JUnit4]) class CaptureSpec { @Test def test(): Unit = { try { mkBox[Data] { packed => implicit val acc = packed.access val box: packed.box.type = packed.box box.open { _.name = "John" } mkBox[Data2] { packed2 => implicit val acc2 = packed2.access val box2: packed2.box.type = packed2.box box2.capture(box)(_.dat = _)(spore { (packedData: Packed[Data2]) => implicit val accessData = packedData.access packedData.box.open { d => assert(d.dat.name == "John") } }) } } } catch { case t: Throwable => } } }
Example 147
package lacasa.test.uniqueness import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 import scala.concurrent.ExecutionContext import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.{Future, Promise, Await} import scala.concurrent.duration._ import scala.spores._ import scala.spores.SporeConv._ import lacasa.{System, Box, CanAccess, Actor, ActorRef} import Box._ class C { var f: D = null //var count = 0 } class D { var g: C = null } sealed abstract class Msg final case class Start() extends Msg //final case class Repeat(obj: C) extends Msg class ActorA(next: ActorRef[C]) extends Actor[Msg] { def receive(msg: Box[Msg])(implicit access: CanAccess { type C = msg.C }): Unit = { // create box with externally-unique object mkBox[C] { packed => implicit val acc = packed.access val box: packed.box.type = packed.box // initialize object in box box.open(spore { obj => val d = new D d.g = obj obj.f = d }) next.send(box)(spore { () => }) } } } class ActorB(p: Promise[Boolean]) extends Actor[C] { def receive(msg: Box[C])(implicit access: CanAccess { type C = msg.C }): Unit = { msg.open(spore { x => val d = x.f // check that `d` refers back to `x` p.success(d.g == x) }) } } @RunWith(classOf[JUnit4]) class Spec { @Test def test(): Unit = { // to check result val p: Promise[Boolean] = Promise() val sys = System() val b = sys.actor[C](new ActorB(p)) val a = sys.actor[Msg](new ActorA(b)) try { mkBox[Start] { packed => import packed.access val box: packed.box.type = packed.box a.send(box)(spore { () => }) } } catch { case t: Throwable => val res = Await.result(p.future, 2.seconds) assert(res) } } }
Example 148
Source File: actor.scala From lacasa with BSD 3-Clause "New" or "Revised" License | 5 votes |
package lacasa.test import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 import scala.concurrent.ExecutionContext import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.{Future, Promise, Await} import scala.concurrent.duration._ import scala.spores._ import scala.spores.SporeConv._ import lacasa.{System, Box, CanAccess, Actor, ActorRef} import Box._ class NonSneaky { def process(a: Array[Int]): Unit = { for (i <- 0 until a.length) a(i) = a(i) + 1 } } class ActorA(next: ActorRef[C]) extends Actor[C] { def receive(msg: Box[C])(implicit access: CanAccess { type C = msg.C }): Unit = { msg.open(spore { (obj: C) => // OK: update array obj.arr(0) = 100 // OK: create instance of ocap class val ns = new NonSneaky ns.process(obj.arr) }) next.send(msg)(spore { () => }) } } class ActorB(p: Promise[String]) extends Actor[C] { def receive(msg: Box[C])(implicit access: CanAccess { type C = msg.C }): Unit = { msg.open(spore { x => p.success(x.arr.mkString(",")) }) } } class C { var arr: Array[Int] = _ } @RunWith(classOf[JUnit4]) class Spec { @Test def test(): Unit = { // to check result val p: Promise[String] = Promise() val sys = System() val b = sys.actor[C](new ActorB(p)) val a = sys.actor[C](new ActorA(b)) try { mkBox[C] { packed => import packed.access val box: packed.box.type = packed.box // initialize object in box with new array box.open(spore { obj => obj.arr = Array(1, 2, 3, 4) }) a.send(box)(spore { () => }) } } catch { case t: Throwable => val res = Await.result(p.future, 2.seconds) assert(res == "101,3,4,5") } } }
Example 149
Source File: ModelSerializabilityTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.models.h2o import com.eharmony.aloha.ModelSerializabilityTestBase import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner @RunWith(classOf[BlockJUnit4ClassRunner]) class ModelSerializabilityTest extends ModelSerializabilityTestBase( Seq(ModelSerializabilityTest.pkg), Seq( ".*Test.*", ".*\\$.*" ) ) object ModelSerializabilityTest { def pkg = getClass.getPackage.getName }
Example 150
Source File: CompilerTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.models.h2o.compiler import hex.genmodel.GenModel import hex.genmodel.easy.prediction.RegressionModelPrediction import hex.genmodel.easy.{RowData, EasyPredictModelWrapper} import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner @RunWith(classOf[BlockJUnit4ClassRunner]) class CompilerTest { @Test def testNoPackage(): Unit = { val compiler = new Compiler[GenModel] val genModel = compiler.fromResource("com/eharmony/aloha/models/h2o/glm_afa04e31_17ad_4ca6_9bd1_8ab80005ce38.java") val y: GenModel = genModel.get assertTrue(classOf[GenModel].isAssignableFrom(y.getClass)) val x = new RowData x.put("Sex", "F") x.put("Length", java.lang.Double.valueOf(0.0)) x.put("Diameter", java.lang.Double.valueOf(0.0)) x.put("Height", java.lang.Double.valueOf(0.0)) x.put("Whole weight", java.lang.Double.valueOf(0.0)) x.put("Shucked weight", java.lang.Double.valueOf(0.0)) x.put("Viscera weight", java.lang.Double.valueOf(0.0)) x.put("Shell weight", java.lang.Double.valueOf(0.0)) println(new EasyPredictModelWrapper(y).predictRegression(x).value) } @Test def testWithPackage(): Unit = { val compiler = new Compiler[GenModel]() val genModel = compiler.fromResource("com/eharmony/aloha/models/h2o/domain.glm_afa04e31_17ad_4ca6_9bd1_8ab80005ce37.java") val y: GenModel = genModel.get assertTrue(classOf[GenModel].isAssignableFrom(y.getClass)) } @Test def testDrfCompiles(): Unit = { val compiler = new Compiler[GenModel]() val modelTry = compiler.fromResource("com/eharmony/aloha/models/h2o/DRF_model_1463074092542_1.java") val model = new EasyPredictModelWrapper(modelTry.get) val complexPrediction = model.predict(new RowData) val pred = complexPrediction match { case r: RegressionModelPrediction => Option(r.value) case _ => None } assertEquals(Option(0.0), pred) } }
Example 151
Source File: VwSparseMultilabelPredictorTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.models.vw.jni.multilabel import java.io.{ByteArrayOutputStream, File, FileInputStream} import com.eharmony.aloha.ModelSerializationTestHelper import com.eharmony.aloha.io.sources.{Base64StringSource, ExternalSource, ModelSource} import org.apache.commons.codec.binary.Base64 import org.apache.commons.io.IOUtils import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner import vowpalWabbit.learner.{VWActionScoresLearner, VWLearners} @RunWith(classOf[BlockJUnit4ClassRunner]) class VwSparseMultilabelPredictorTest extends ModelSerializationTestHelper { import VwSparseMultilabelPredictorTest._ @Test def testSerializability(): Unit = { val predictor = getPredictor(getModelSource(), 3) val ds = serializeDeserializeRoundTrip(predictor) assertEquals(predictor, ds) assertEquals(predictor.vwParams(), ds.vwParams()) assertNotNull(ds.vwModel) } @Test def testVwParameters(): Unit = { val numLabelsInTrainingSet = 3 val predictor = getPredictor(getModelSource(), numLabelsInTrainingSet) predictor.vwParams() match { case Data(vwBinFilePath, ringSize) => checkVwBinFile(vwBinFilePath) checkVwRingSize(numLabelsInTrainingSet, ringSize.toInt) case ps => fail(s"Unexpected VW parameters format. Found string: $ps") } } } object VwSparseMultilabelPredictorTest { private val Data = """\s*-i\s+(\S+)\s+--ring_size\s+(\d+)\s+--testonly\s+--quiet""".r private def getModelSource(): ModelSource = { val f = File.createTempFile("i_dont", "care") f.deleteOnExit() val learner = VWLearners.create[VWActionScoresLearner](s"--quiet --csoaa_ldf mc --csoaa_rank -f ${f.getCanonicalPath}") learner.close() val baos = new ByteArrayOutputStream() IOUtils.copy(new FileInputStream(f), baos) val src = Base64StringSource(Base64.encodeBase64URLSafeString(baos.toByteArray)) ExternalSource(src.localVfs) } private def getPredictor(modelSrc: ModelSource, numLabelsInTrainingSet: Int) = VwSparseMultilabelPredictor[Any](modelSrc, Nil, Nil, numLabelsInTrainingSet) private def checkVwBinFile(vwBinFilePath: String): Unit = { val vwBinFile = new File(vwBinFilePath) assertTrue("VW binary file should have been written to disk", vwBinFile.exists()) vwBinFile.deleteOnExit() } private def checkVwRingSize(numLabelsInTrainingSet: Int, ringSize: Int): Unit = { assertEquals( "vw --ring_size parameter is incorrect:", numLabelsInTrainingSet + VwSparseMultilabelPredictor.AddlVwRingSize, ringSize.toInt ) } }
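VwSparseMultilabelPredictorTest above keeps its shared helpers (the Data regex, the model-source builders) in a companion object, a pattern several aloha tests on this page use. Under BlockJUnit4ClassRunner this matters because JUnit 4 creates a fresh instance of the test class for every @Test method, so instance fields are rebuilt per test while companion-object members are initialized once per JVM. An illustrative sketch of the pattern (names and fixture are hypothetical):

import org.junit.Assert.assertEquals
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

// Hypothetical sketch: each @Test method runs on a new SharedFixtureSketch
// instance, but the lookup table in the companion object is built only once.
@RunWith(classOf[BlockJUnit4ClassRunner])
class SharedFixtureSketch {
  @Test def containsOne(): Unit = assertEquals(Some("one"), SharedFixtureSketch.Lookup.get(1))
  @Test def containsTwo(): Unit = assertEquals(Some("two"), SharedFixtureSketch.Lookup.get(2))
}

object SharedFixtureSketch {
  // Imagine this being expensive to construct (a model, a parsed schema, ...).
  lazy val Lookup: Map[Int, String] = Map(1 -> "one", 2 -> "two")
}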
Example 152
Source File: Stack.scala From lacasa with BSD 3-Clause "New" or "Revised" License | 5 votes |
package lacasa.neg import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.junit.Test import lacasa.util._ @RunWith(classOf[JUnit4]) class StackSpec { @Test def test1() { println(s"StackSpec.test1") expectError("confined") { """ class D { } class C { import scala.spores._ import lacasa.Box var b: Box[D] = _ def m(): Unit = { Box.mkBox[D] { packed => b = packed.box // assign box to field } } } """ } } @Test def test2() { println(s"StackSpec.test2") expectError("confined") { """ class D { } class C { import scala.spores._ import lacasa.Box var b: lacasa.CanAccess = _ def m(): Unit = { Box.mkBox[D] { packed => b = packed.access // assign permission to field } } } """ } } @Test def test3() { println(s"StackSpec.test3") expectError("confined") { """ class D { } class C { import scala.spores._ import lacasa.Box var b: Any = _ def m(): Unit = { Box.mkBox[D] { packed => b = packed.access // assign permission to field } } } """ } } @Test def test4() { println(s"StackSpec.test4") expectError("confined") { """ class E(x: Any) {} class D { } class C { import scala.spores._ import lacasa.Box def m(): Unit = { Box.mkBox[D] { packed => new E(packed.box) } } } """ } } }
Example 153
Source File: ImplicitsTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.audit.impl.avro

import com.google.common.collect.Lists
import org.junit.Assert.assertEquals
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner
import scala.collection.JavaConverters.seqAsJavaListConverter
import com.eharmony.aloha.audit.impl.avro.Implicits.{RichFlatScore, RichScore}
import java.{lang => jl, util => ju}
import org.apache.avro.generic.GenericRecord

// NOTE: the class declaration and companion import below are assumed; they are
// collapsed in this excerpt, along with some of the original file's tests.
@RunWith(classOf[BlockJUnit4ClassRunner])
class ImplicitsTest {
  import ImplicitsTest._

  @Test def testAllFieldsAppear(): Unit = {
    val s = filledInScore
    assertEquals(s, s.toFlatScore.toScore)
  }

  @Test def testSameFieldsInGenericRecord(): Unit = {
    val s = filledInScore
    val s1 = s.asInstanceOf[GenericRecord]
    val s2 = s.toFlatScore.asInstanceOf[GenericRecord]
    testStuff(s1, s2, Map(
      "model" -> modelId,
      "value" -> value,
      "errorMsgs" -> errors,
      "missingVarNames" -> missing,
      "prob" -> prob
    ))
  }

  private[this] def testStuff(r1: GenericRecord, r2: GenericRecord, data: Map[String, Any]): Unit = {
    data.foreach { case (k, v) =>
      val v1 = r1.get(k)
      val v2 = r2.get(k)
      assertEquals(s"for r1('$k') = $v1. Expected $v", v, r1.get(k))
      assertEquals(s"for r2('$k') = $v2. Expected $v", v, r2.get(k))
    }
  }
}

object ImplicitsTest {
  private def filledInScore = new Score(modelId, value, subvalues, errors, missing, prob)
  private def modelId = new ModelId(5L, "five")
  private def value: jl.Double = 13d
  private def subvalues = Lists.newArrayList(scr(12L, 8))
  private def errors: ju.List[CharSequence] =
    Lists.newArrayList("one error", "two errors")
  private def missing: ju.List[CharSequence] =
    Lists.newArrayList("some feature", "another feature", "yet another feature")
  private def prob: jl.Float = 1f

  private lazy val score: Score =
    scr(1, 1,
      scr(2L, 2,
        scr(4f, 4),
        scr(5, 5)
      ),
      scr(3d, 3,
        scr(6d, 6),
        scr(7L, 7)
      )
    )

  private lazy val irregularTree: Score =
    scr(1, 1,
      scr(2L, 2),
      scr(3d, 3,
        scr(5d, 5),
        scr(6L, 6)
      ),
      scr(4d, 4,
        scr(7L, 7)
      )
    )

  private[this] def scr(value: Any, id: Long, children: Score*): Score = {
    new Score(
      new ModelId(id, ""),
      value,
      Lists.newArrayList(children.asJava),
      java.util.Collections.emptyList(),
      java.util.Collections.emptyList(),
      null
    )
  }
}
Example 154
Source File: FlatScoreTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.audit.impl.avro import com.google.common.collect.Lists import org.junit.Assert.assertEquals import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner import com.eharmony.aloha.audit.impl.avro.AvroScoreAuditorTest.serializeRoundTrip import scala.collection.JavaConverters.seqAsJavaListConverter import java.{util => ju} @RunWith(classOf[BlockJUnit4ClassRunner]) class FlatScoreTest { import FlatScoreTest.flatScore @Test def testSerializability(): Unit = { val serDeserFS = serializeRoundTrip(FlatScore.getClassSchema, flatScore).head // When comparing the records instead of the JSON strings, equality doesn't // hold because they are different types. flatScoreList is a SpecificRecord // and SpecificRecord checks if the other values is a SpecificRecord. assertEquals(flatScore.toString, serDeserFS.toString) } } object FlatScoreTest { private[this] def empty[A]: ju.List[A] = ju.Collections.emptyList[A] private[this] implicit def toArrayList[A, B](as: Seq[A])(implicit ev: A => B): ju.ArrayList[B] = Lists.newArrayList(as.map(ev).asJava) private[this] def fsd(value: Any, id: Long, children: Int*): FlatScoreDescendant = { new FlatScoreDescendant( new ModelId(id, ""), value, children, empty[CharSequence], empty[CharSequence], null ) } private[avro] lazy val flatScore: FlatScore = { new FlatScore(new ModelId(1L, ""), 1, Vector(0, 1), empty[CharSequence], empty[CharSequence], null, Seq( fsd(2L, 2, 2, 3), // 0 fsd(3d, 3, 4, 5), // 1 fsd(4f, 4), // 2 fsd(5, 5), // 3 fsd(6d, 6), // 4 fsd(7L, 7) // 5 ) ) } }
Example 155
Source File: StdAvroModelFactoryTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.factory.avro import com.eharmony.aloha.audit.impl.avro.Score import com.eharmony.aloha.factory.ModelFactory import com.eharmony.aloha.io.vfs.Vfs1 import com.eharmony.aloha.models.Model import org.apache.avro.Schema import org.apache.avro.generic.{GenericData, GenericRecord} import org.apache.commons.io.IOUtils import org.junit.Assert.assertEquals import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner import scala.util.Try private[this] def record = { val r = new GenericData.Record(TheSchema) r.put("req_str_1", "smart handsome stubborn") r } } object StdAvroModelFactoryTest { private lazy val TheSchema = { val is = getClass.getClassLoader.getResourceAsStream(SchemaUrlResource) try new Schema.Parser().parse(is) finally IOUtils.closeQuietly(is) } private val ExpectedResult = 7d private val SchemaUrlResource = "avro/class7.avpr" private val SchemaUrl = s"res:$SchemaUrlResource" private val SchemaFile = new java.io.File(getClass.getClassLoader.getResource(SchemaUrlResource).getFile) private val SchemaVfs1FileObject = org.apache.commons.vfs.VFS.getManager.resolveFile(SchemaUrl) private val SchemaVfs2FileObject = org.apache.commons.vfs2.VFS.getManager.resolveFile(SchemaUrl) private val Imports = Seq("com.eharmony.aloha.feature.BasicFunctions._", "scala.math._") private val ReturnType = "Double" private val ModelJson = """ |{ | "modelType": "Regression", | "modelId": { "id": 0, "name": "" }, | "features" : { | "my_attributes": "${req_str_1}.split(\"\\\\W+\").map(v => (s\"=$v\", 1.0))" | }, | "weights": { | "my_attributes=handsome": 1, | "my_attributes=smart": 2, | "my_attributes=stubborn": 4 | } |} """.stripMargin }
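The excerpt above contains only a private helper and the companion object; the class declaration and the actual test methods of StdAvroModelFactoryTest are not shown. Based on the imports and on the other aloha tests on this page, the omitted declaration very likely follows the usual pattern sketched below; the single test body is a hypothetical stand-in, not the project's real assertions.

import org.junit.Assert.assertNotNull
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

// Probable shape of the omitted declaration; the test body is a placeholder.
@RunWith(classOf[BlockJUnit4ClassRunner])
class StdAvroModelFactoryTest {
  import StdAvroModelFactoryTest._  // helpers defined in the companion shown above

  @Test def schemaResourceParses(): Unit = {
    assertNotNull(TheSchema)  // TheSchema is the lazy val from the companion object
  }
}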
Example 156
Source File: PrintProtosTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.cli.dataset import java.io.{ByteArrayOutputStream, IOException} import java.util.Arrays import com.eharmony.aloha.test.proto.Testing.{PhotoProto, UserProto} import com.eharmony.aloha.test.proto.Testing.GenderProto.{FEMALE, MALE} import com.google.protobuf.GeneratedMessage import org.apache.commons.codec.binary.Base64 import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner import org.junit.{Ignore, Test} @RunWith(classOf[BlockJUnit4ClassRunner]) @Ignore class PrintProtosTest { @Test def testPrintProtos(): Unit = { System.out.println(alan) System.out.println(kate) } @throws(classOf[IOException]) def alan: String = { val t = UserProto.newBuilder. setId(1). setName("Alan"). setGender(MALE). setBmi(23). addAllPhotos(Arrays.asList( PhotoProto.newBuilder. setId(1). setAspectRatio(1). setHeight(1). build, PhotoProto.newBuilder. setId(2). setAspectRatio(2). setHeight(2).build )).build b64(t) } def kate: String = { val t = UserProto.newBuilder. setId(1). setName("Kate"). setGender(FEMALE). addAllPhotos(Arrays.asList( PhotoProto.newBuilder. setId(3). setAspectRatio(3). setHeight(3). build )).build b64(t) } def b64[M <: GeneratedMessage](p: M): String = { val baos: ByteArrayOutputStream = new ByteArrayOutputStream p.writeTo(baos) new String(Base64.encodeBase64(baos.toByteArray)) } }
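PrintProtosTest pairs @RunWith with a class-level @Ignore, which keeps the suite compiling while JUnit skips (and reports as ignored) every test in it. A small sketch of the same idea, with invented names:

import org.junit.{Ignore, Test}
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

// The suite still compiles, but no @Test method runs while @Ignore is present.
@RunWith(classOf[BlockJUnit4ClassRunner])
@Ignore
class TemporarilyDisabledSuite {
  @Test def wouldOtherwiseRun(): Unit = {
    println("never executed while the class-level @Ignore is in place")
  }
}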
Example 157
Source File: ModelTypesTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.cli import com.eharmony.aloha.factory.ModelFactory import org.junit.Assert.assertEquals import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner @RunWith(classOf[BlockJUnit4ClassRunner]) class ModelTypesTest { @Test def testKnownModels(): Unit = { val expected = Seq( "BootstrapExploration", "CategoricalDistribution", "CloserTester", // A test model. "Constant", "DecisionTree", "DoubleToLong", "EpsilonGreedyExploration", "Error", "ErrorSwallowingModel", "H2o", "ModelDecisionTree", "Regression", "Segmentation", "SparseMultilabel", "VwJNI" ) val actual = ModelFactory.defaultFactory(null, null).parsers.map(_.modelType).sorted assertEquals(expected, actual) } }
Example 158
Source File: CliTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.cli import com.eharmony.aloha import com.eharmony.aloha.util.io.TestWithIoCapture import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner @RunWith(classOf[BlockJUnit4ClassRunner]) class CliTest extends TestWithIoCapture { @Test def testNoArgs(): Unit = { val res = run(Cli.main)(Array.empty) assertEquals("No arguments supplied. Supply one of: '--dataset', '--h2o', '--modelrunner', '--vw'.", res.err.contents.trim) } @Test def testBadFlag(): Unit = { val res = run(Cli.main)(Array("-BADFLAG")) assertEquals("'-BADFLAG' supplied. Supply one of: '--dataset', '--h2o', '--modelrunner', '--vw'.", res.err.contents.trim) } @Test def testVw(): Unit = { val res = run(Cli.main)(Array("--vw")) val expected = """ |Error: Missing option --spec |Error: Missing option --model |vw """.stripMargin + aloha.version + """ |Usage: vw [options] | | -s <value> | --spec <value> | spec is an Apache VFS URL to an aloha spec file. | -m <value> | --model <value> | model is an Apache VFS URL to a VW binary model. | --fs-type <value> | file system type: vfs1, vfs2, file. default = vfs2. | -n <value> | --name <value> | name of the model. | -i <value> | --id <value> | numeric id of the model. | --vw-args <value> | arguments to vw | --external | link to a binary VW model rather than embedding it inline in the aloha model. | --num-missing-thresh <value> | number of missing features to allow before returning a 'no-prediction'. | --note <value> | notes to add to the model. Can provide this many parameter times. | --spline-min <value> | min value for spline domain. (must additional provide spline-max and spline-knots). | --spline-max <value> | max value for spline domain. (must additional provide spline-min and spline-knots). | --spline-knots <value> | max value for spline domain. (must additional provide spline-min, spline-delta, and spline-knots). """.stripMargin assertEquals(expected.trim, res.err.contents.trim) } }
Example 159
Source File: RowCreatorProducerTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.dataset import java.lang.reflect.Modifier import com.eharmony.aloha import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator import org.junit.Assert._ import org.junit.{Ignore, Test} import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner import scala.collection.JavaConversions.asScalaSet import org.reflections.Reflections @RunWith(classOf[BlockJUnit4ClassRunner]) class RowCreatorProducerTest { import RowCreatorProducerTest._ private[this] def scanPkg = aloha.pkgName + ".dataset" @Test def testAllRowCreatorProducersHaveOnlyZeroArgConstructors() { val reflections = new Reflections(scanPkg) val specProdClasses = reflections.getSubTypesOf(classOf[RowCreatorProducer[_, _, _]]).toSet specProdClasses.foreach { clazz => val cons = clazz.getConstructors assertTrue(s"There should only be one constructor for ${clazz.getCanonicalName}. Found ${cons.length} constructors.", cons.length <= 1) cons.headOption.foreach { c => if (!(WhitelistedRowCreatorProducers contains clazz)) { val nParams = c.getParameterTypes.length assertEquals(s"The constructor for ${clazz.getCanonicalName} should take 0 arguments. It takes $nParams.", 0, nParams) } } } } // TODO: Report the above bug! @Ignore @Test def testAllRowCreatorProducersAreFinalClasses() { val reflections = new Reflections(scanPkg) val specProdClasses = reflections.getSubTypesOf(classOf[RowCreatorProducer[_, _, _]]).toSet specProdClasses.foreach { clazz => assertTrue(s"${clazz.getCanonicalName} needs to be declared final.", Modifier.isFinal(clazz.getModifiers)) } } } object RowCreatorProducerTest { private val WhitelistedRowCreatorProducers = Set[Class[_]]( classOf[VwMultilabelRowCreator.Producer[_, _]] ) }
Example 160
Source File: VwFeatureNormalizerTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.dataset.vw import org.junit.Assert._ import org.junit.{Test, Before} import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner @RunWith(classOf[BlockJUnit4ClassRunner]) class VwFeatureNormalizerTest { private[this] var normalizer: VwFeatureNormalizer = _ @Before def setup() { normalizer = new VwFeatureNormalizer } @Test def testBlank() { assertEquals("", normalizer("").toString) } @Test def testSimple() { val vwLine: String = "1 1| |A a b c" assertEquals("1 1| |A:0.57735 a b c", normalizer.apply(vwLine).toString) } @Test def testMultipleNamespaces() { val vwLine: String = "1 1| |A a b c |b 1=2 3=4" assertEquals("1 1| |A:0.57735 a b c |b:0.70711 1=2 3=4", normalizer.apply(vwLine).toString) } @Test def testWithWeights() { val vwLine: String = "1 1| |A a:0.987 b c:0.435" assertEquals("1 1| |A:0.67988 a:0.987 b c:0.435", normalizer.apply(vwLine).toString) } }
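The normalizer test rebuilds its fixture in a @Before method; with any JUnit 4 runner, @Before executes before each @Test and @After executes afterwards, so every test sees a fresh fixture. A short illustrative sketch (names are made up):

import org.junit.Assert.assertEquals
import org.junit.{After, Before, Test}
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

@RunWith(classOf[BlockJUnit4ClassRunner])
class FixtureLifecycleTest {
  private[this] var buffer: StringBuilder = _

  @Before def setup(): Unit = { buffer = new StringBuilder("seed") }   // fresh fixture per test
  @After def teardown(): Unit = { buffer = null }                      // cleanup after each test

  @Test def fixtureIsFreshForEachTest(): Unit = {
    buffer.append("-first")
    assertEquals("seed-first", buffer.toString)
  }

  @Test def previousTestDidNotLeak(): Unit = {
    assertEquals("seed", buffer.toString)
  }
}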
Example 161
Source File: VwContextualBanditRowCreatorProducerTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.dataset.vw.cb import com.eharmony.aloha.dataset.RowCreatorBuilder import com.eharmony.aloha.dataset.vw.VwParsingAndChainOfRespTest import com.eharmony.aloha.semantics.compiled.plugin.csv.CsvLine import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner @RunWith(classOf[BlockJUnit4ClassRunner]) class VwContextualBanditRowCreatorProducerTest { @Test def testAnyMissingDvFails(): Unit = { val semantics = VwParsingAndChainOfRespTest.semantics val sb = RowCreatorBuilder(semantics, List(new VwContextualBanditRowCreator.Producer[CsvLine])) val spec = sb.fromResource("com/eharmony/aloha/dataset/simpleCbSpec.json").get val lines = VwParsingAndChainOfRespTest.csvLines( "Alex,,,,,,,2,1,0", "Bill,,,,,,,2,1,", "Carl,,,,,,,2,,0", "Dale,,,,,,,,1,0" ) // TODO: Work on removing trailing and leading spaces. This is clearly not perfect. val expected = Seq( "2:1:0 |A name=Alex", "|A name=Bill", "|A name=Carl", "|A name=Dale" ) (lines zip expected).zipWithIndex.foreach { case ((x, exp), i) => val act = spec(x)._2.toString assertEquals(s"On test $i: ", exp, act) case d => fail(s"bad: $d") } } }
Example 162
Source File: VwCovariateProducerTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.dataset.vw import com.eharmony.aloha.FileLocations import com.eharmony.aloha.dataset.json.SparseSpec import com.eharmony.aloha.dataset.vw.VwCovariateProducerTest.{X, semantics} import com.eharmony.aloha.dataset.vw.json.VwJsonLike import com.eharmony.aloha.dataset.{CompilerFailureMessages, SparseCovariateProducer, SparseFeatureExtractorFunction} import com.eharmony.aloha.semantics.compiled.CompiledSemantics import com.eharmony.aloha.semantics.compiled.compiler.TwitterEvalCompiler import com.eharmony.aloha.semantics.compiled.plugin.csv.{CompiledSemanticsCsvPlugin, CsvLine, CsvTypes} import com.eharmony.aloha.semantics.func.GenAggFunc import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner import scala.concurrent.ExecutionContext.Implicits.global import scala.util.Success @RunWith(classOf[BlockJUnit4ClassRunner]) class VwCovariateProducerTest { @Test def testGetVwDataWith1Function() { val j = new VwJsonLike { val namespaces = None val normalizeFeatures = None val features = Vector(SparseSpec("i_plus_d", """List(("", ${i} + ${d}))""")) val imports = Nil } val (covariates, default, nss, normalizer) = X.getVwData(semantics, j) covariates match { case Success(SparseFeatureExtractorFunction(IndexedSeq(("i_plus_d", f)))) => assertTrue("Wrong covariate function", f.isInstanceOf[GenAggFunc[CsvLine, Iterable[(String, Double)]]]) case _ => fail("Wrong covariates.") } assertEquals(1, default.size) assertEquals(0, nss.size) assertEquals(None, normalizer) } @Test def testGetVwDataEverythingMissing() { val j = new VwJsonLike { val namespaces = None val normalizeFeatures = None val features = Vector.empty val imports = Nil } val (covariates, default, nss, normalizer) = X.getVwData(semantics, j) covariates match { case Success(SparseFeatureExtractorFunction(IndexedSeq())) => case _ => fail("Wrong covariates.") } assertEquals(0, default.size) assertEquals(0, nss.size) assertEquals(None, normalizer) } } private object VwCovariateProducerTest { object X extends VwCovariateProducer[CsvLine] with SparseCovariateProducer with CompilerFailureMessages { // To expose for testing. override def getVwData(semantics: CompiledSemantics[CsvLine], json: VwJsonLike) = super.getVwData(semantics, json) } lazy val semantics = { val compiler = TwitterEvalCompiler(classCacheDir = Option(FileLocations.testGeneratedClasses)) val plugin = CompiledSemanticsCsvPlugin( "i" -> CsvTypes.IntType, "d" -> CsvTypes.DoubleType ) CompiledSemantics[CsvLine](compiler, plugin, Nil) } }
Example 163
Source File: VwRowCreatorProducerTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.dataset.vw.unlabeled import com.eharmony.aloha.dataset.RowCreatorBuilder import scala.concurrent.ExecutionContext.Implicits.global import com.eharmony.aloha.FileLocations import com.eharmony.aloha.semantics.compiled.CompiledSemantics import com.eharmony.aloha.semantics.compiled.compiler.TwitterEvalCompiler import com.eharmony.aloha.semantics.compiled.plugin.csv.{CompiledSemanticsCsvPlugin, CsvLine} import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner @RunWith(classOf[BlockJUnit4ClassRunner]) class VwRowCreatorProducerTest { @Test def test1() { val p = CompiledSemanticsCsvPlugin() val sem = CompiledSemantics(TwitterEvalCompiler(classCacheDir = Option(FileLocations.testGeneratedClasses)), p, Nil) val sb = RowCreatorBuilder(sem, List(new VwRowCreator.Producer[CsvLine])) val json1 = """ |{ | "imports": [], | "features": [ { "name":"x", "spec":"Nil" } ] |} """.stripMargin.trim val xOpt = sb.fromString(json1) assertTrue(xOpt.isSuccess) val x = xOpt.get assertEquals(Seq(0), x.defaultNamespace) assertEquals(1, x.featuresFunction.features.size) assertEquals("x", x.featuresFunction.features.head._1) assertEquals(0, x.featuresFunction.features.head._2.accessors.size) assertEquals(0, x.featuresFunction.features.head._2.arity) assertTrue(x.namespaces.isEmpty) assertEquals(None, x.normalizer) } }
Example 164
Source File: VwLabelRowCreatorProducerTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.dataset.vw.labeled import com.eharmony.aloha.dataset.RowCreatorBuilder import com.eharmony.aloha.dataset.vw.VwParsingAndChainOfRespTest import com.eharmony.aloha.semantics.compiled.plugin.csv.CsvLine import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner @RunWith(classOf[BlockJUnit4ClassRunner]) class VwLabelRowCreatorProducerTest { @Test def testNonDefaultTagThatsMissingDoesntRemoveLabel() { val semantics = VwParsingAndChainOfRespTest.semantics val sb = RowCreatorBuilder(semantics, List(new VwLabelRowCreator.Producer[CsvLine])) val spec = sb.fromResource("com/eharmony/aloha/dataset/simpleSpecWithTag.json").get val lines = VwParsingAndChainOfRespTest.csvLines( "Alex,,1,,2,,,,,", "Bill,,2,,3,,,,,", "Carl,,0,,,,,,,", "Dale,,3,,1,,,,," ) val expected = Seq( "1 2|A name=Alex marriages=UNK", "2 3|A name=Bill marriages=UNK", "0 |A name=Carl marriages=UNK", "3 1|A name=Dale marriages=UNK" ) lines.zip(expected).foreach{ case(x, exp) => assertEquals( s"for ${x.line}: ", exp, spec(x)._2.toString ) } } @Test def testImportanceMissingRemovesLabel() { val semantics = VwParsingAndChainOfRespTest.semantics val sb = RowCreatorBuilder(semantics, List(new VwLabelRowCreator.Producer[CsvLine])) val spec = sb.fromResource("com/eharmony/aloha/dataset/simpleSpecWithImp.json").get val lines = VwParsingAndChainOfRespTest.csvLines( "Alex,,1,,2,,,,,", "Bill,,2,,3,,,,,", "Carl,,0,,,,,,,", "Dale,,3,,1,,,,," ) val expected = Seq( "1 2 1|A name=Alex marriages=UNK", "2 3 2|A name=Bill marriages=UNK", "|A name=Carl marriages=UNK", // Omitting the importance variable removes the entire label. "3 3|A name=Dale marriages=UNK" ) lines.zip(expected).foreach{ case(x, exp) => assertEquals( s"for ${x.line}: ", exp, spec(x)._2.toString ) } } @Test def testLabelMissingRemovesLabel() { val semantics = VwParsingAndChainOfRespTest.semantics val sb = RowCreatorBuilder(semantics, List(new VwLabelRowCreator.Producer[CsvLine])) val spec = sb.fromResource("com/eharmony/aloha/dataset/simpleSpec.json").get val lines = VwParsingAndChainOfRespTest.csvLines( "Alex,,1,,,,,,,", "Bill,,2,,,,,,,", "Carl,,,,,,,,," ) val expected = Seq( "1 1|A name=Alex marriages=UNK", "2 2|A name=Bill marriages=UNK", "|A name=Carl marriages=UNK" ) lines.zip(expected).foreach{ case(x, exp) => assertEquals(s"for ${x.line}: ", exp, spec(x)._2.toString) } } }
Example 165
Source File: VwLabelRowCreatorTest.scala From aloha with MIT License | 5 votes |
package com.eharmony.aloha.dataset.vw.labeled import com.eharmony.aloha.dataset.SparseFeatureExtractorFunction import com.eharmony.aloha.semantics.func.GenFunc.f0 import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner import scala.language.{postfixOps, implicitConversions} @RunWith(classOf[BlockJUnit4ClassRunner]) final class VwLabelRowCreatorTest { private[this] val lab = 3d private[this] val imp0 = 0d private[this] val imp1 = 1d private[this] val imp2 = 2d private[this] val emptyTag = "" private[this] val tag = "t" private[this] implicit def liftToOption[A](a: A): Option[A] = Option(a) private[this] def spec(lab: Option[Double] = None, imp: Option[Double] = None, tag: Option[String] = None): VwLabelRowCreator[Any] = { val fef = new SparseFeatureExtractorFunction[Any](Vector("f1" -> f0("Empty", _ => Nil))) VwLabelRowCreator(fef, 0 to 0 toList, Nil, None, f0("", _ => lab), f0("", _ => imp), f0("", _ => tag)) } private[this] def testLabelRemoval(spec: VwLabelRowCreator[Any], exp: String = ""): Unit = assertEquals(exp, spec(())._2.toString) // All of these should return empty label because the Label function returns a missing label. @Test def testS___() = testLabelRemoval(spec()) @Test def testS__e() = testLabelRemoval(spec(tag = emptyTag)) @Test def testS__t() = testLabelRemoval(spec(tag = tag)) @Test def testS_0_() = testLabelRemoval(spec(imp = imp0)) @Test def testS_0e() = testLabelRemoval(spec(imp = imp0, tag = emptyTag)) @Test def testS_0t() = testLabelRemoval(spec(imp = imp0, tag = tag)) @Test def testS_1_() = testLabelRemoval(spec(imp = imp1)) @Test def testS_1e() = testLabelRemoval(spec(imp = imp1, tag = emptyTag)) @Test def testS_1t() = testLabelRemoval(spec(imp = imp1, tag = tag)) @Test def testS_2_() = testLabelRemoval(spec(imp = imp2)) @Test def testS_2e() = testLabelRemoval(spec(imp = imp2, tag = emptyTag)) @Test def testS_2t() = testLabelRemoval(spec(imp = imp2, tag = tag)) // Importance not provided makes entire label vanish @Test def testS1_e() = testLabelRemoval(spec(lab = lab, tag = emptyTag)) @Test def testS1_t() = testLabelRemoval(spec(lab = lab, tag = tag)) // Importance of zero is given explicitly. @Test def testS10_() = testLabelRemoval(spec(lab = lab, imp = imp0), "3 0 |") @Test def testS10e() = testLabelRemoval(spec(lab = lab, imp = imp0, tag = emptyTag), "3 0 |") @Test def testS10t() = testLabelRemoval(spec(lab = lab, imp = imp0, tag = tag), "3 0 t|") // Importance of 1 is omitted. 
@Test def testS11_() = testLabelRemoval(spec(lab = lab, imp = imp1), "3 |") @Test def testS11e() = testLabelRemoval(spec(lab = lab, imp = imp1, tag = emptyTag), "3 |") @Test def testS11t() = testLabelRemoval(spec(lab = lab, imp = imp1, tag = tag), "3 t|") @Test def testS12_() = testLabelRemoval(spec(lab = lab, imp = imp2), "3 2 |") @Test def testS12e() = testLabelRemoval(spec(lab = lab, imp = imp2, tag = emptyTag), "3 2 |") @Test def testS12t() = testLabelRemoval(spec(lab = lab, imp = imp2, tag = tag), "3 2 t|") @Test def testStringLabel() { val spec = new VwLabelRowCreator( new SparseFeatureExtractorFunction(Vector("f1" -> f0("Empty", (_: Double) => Nil))), 0 to 0 toList, Nil, None, f0("", (s: Double) => Option(s)), // Label f0("", (_: Double) => Option(1d)), // Importance f0("", (_: Double) => None)) // Tag val values = Seq( -1.0 -> "-1", -0.99999999999999999 -> "-1", -0.9999999999999999 -> "-0.9999999999999999", -1.0E-16 -> "-0.0000000000000001", -1.0E-17 -> "-0.00000000000000001", -1.0E-18 -> "-0", 0.0 -> "0", 1.0E-18 -> "0", 1.0E-17 -> "0.00000000000000001", 1.0E-16 -> "0.0000000000000001", 0.9999999999999999 -> "0.9999999999999999", 0.99999999999999999 -> "1", 1.0 -> "1" ) values foreach { case(v, ex) => assertEquals(s"for line: $v", Option(ex), spec.stringLabel(v)) } } }
Example 166
Source File: DateMapToUnitCircleVectorizerTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature import com.salesforce.op.features.types._ import com.salesforce.op.stages.base.sequence.SequenceModel import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder} import com.salesforce.op.utils.spark.OpVectorMetadata import org.apache.spark.ml.{Estimator, Transformer} import org.apache.spark.ml.linalg.Vectors import com.salesforce.op.utils.spark.RichDataset._ import com.salesforce.op.utils.spark.RichMetadata._ import org.joda.time.{DateTime => JDateTime} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class DateMapToUnitCircleVectorizerTest extends OpEstimatorSpec[OPVector, SequenceModel[DateMap, OPVector], DateMapToUnitCircleVectorizer[DateMap]] with AttributeAsserts { val eps = 1E-4 val sampleDateTimes = Seq[JDateTime]( new JDateTime(2018, 2, 11, 0, 0, 0, 0), new JDateTime(2018, 11, 28, 6, 0, 0, 0), new JDateTime(2018, 2, 17, 12, 0, 0, 0), new JDateTime(2017, 4, 17, 18, 0, 0, 0), new JDateTime(1918, 2, 13, 3, 0, 0, 0) ) val (inputData, f1) = TestFeatureBuilder( sampleDateTimes.map(x => Map("a" -> x.getMillis, "b" -> x.getMillis).toDateMap) ) override val expectedResult: Seq[OPVector] = sampleDateTimes .map{ v => val rad = DateToUnitCircle.convertToRandians(Option(v.getMillis), TimePeriod.HourOfDay) (rad ++ rad).toOPVector } it should "work with its shortcut as a DateMap" in { val output = f1.toUnitCircle(TimePeriod.HourOfDay) val transformed = output.originStage.asInstanceOf[DateMapToUnitCircleVectorizer[DateMap]] .fit(inputData).transform(inputData) val field = transformed.schema(output.name) val actual = transformed.collect(output) assertNominal(field, Array.fill(actual.head.value.size)(false), actual) all (actual.zip(expectedResult).map(g => Vectors.sqdist(g._1.value, g._2.value))) should be < eps } it should "work with its shortcut as a DateTimeMap" in { val (inputDataDT, f1DT) = TestFeatureBuilder( sampleDateTimes.map(x => Map("a" -> x.getMillis, "b" -> x.getMillis).toDateTimeMap) ) val output = f1DT.toUnitCircle(TimePeriod.HourOfDay) val transformed = output.originStage.asInstanceOf[DateMapToUnitCircleVectorizer[DateMap]] .fit(inputData).transform(inputData) val field = transformed.schema(output.name) val actual = transformed.collect(output) assertNominal(field, Array.fill(actual.head.value.size)(false), actual) all (actual.zip(expectedResult).map(g => Vectors.sqdist(g._1.value, g._2.value))) should be < eps } it should "make the correct metadata" in { val fitted = estimator.fit(inputData) val meta = OpVectorMetadata(fitted.getOutputFeatureName, fitted.getMetadata()) meta.columns.length shouldBe 4 meta.columns.flatMap(_.grouping) shouldEqual Seq("a", "a", "b", "b") meta.columns.flatMap(_.descriptorValue) shouldEqual Seq("x_HourOfDay", "y_HourOfDay", "x_HourOfDay", "y_HourOfDay") } }
Example 167
Source File: OpIndexToStringNoFilterTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature import com.salesforce.op.features.types._ import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder} import com.salesforce.op.utils.spark.RichDataset._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class OpIndexToStringNoFilterTest extends OpTransformerSpec[Text, OpIndexToStringNoFilter] { val (inputData, indF) = TestFeatureBuilder(Seq(0.0, 2.0, 1.0, 0.0, 0.0, 1.0).map(_.toRealNN)) val labels = Array("a", "c") override val transformer: OpIndexToStringNoFilter = new OpIndexToStringNoFilter().setInput(indF).setLabels(labels) override val expectedResult: Seq[Text] = Array("a", OpIndexToStringNoFilter.unseenDefault, "c", "a", "a", "c").map(_.toText) it should "correctly deindex a numeric column using shortcut" in { val str2 = indF.deindexed(labels, handleInvalid = IndexToStringHandleInvalid.NoFilter) val strs2 = str2.originStage.asInstanceOf[OpIndexToStringNoFilter].transform(inputData).collect(str2) strs2 shouldBe expectedResult } }
Example 168
Source File: SetNGramSimilarityTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature import com.salesforce.op._ import com.salesforce.op.features.types._ import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder} import com.salesforce.op.utils.spark.RichDataset._ import org.apache.spark.ml.Transformer import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class SetNGramSimilarityTest extends OpTransformerSpec[RealNN, SetNGramSimilarity] { val (inputData, f1, f2) = TestFeatureBuilder( Seq( (Seq("Red", "Green"), Seq("Red")), (Seq("Red", "Green"), Seq("Yellow, Blue")), (Seq("Red", "Yellow"), Seq("Red", "Yellow")), (Seq[String](), Seq("Red", "Yellow")), (Seq[String](), Seq[String]()), (Seq[String](""), Seq[String]("asdf")), (Seq[String](""), Seq[String]("")), (Seq[String]("", ""), Seq[String]("", "")) ).map(v => v._1.toMultiPickList -> v._2.toMultiPickList) ) val expectedResult = Seq(0.3333333134651184, 0.09722214937210083, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0).toRealNN val catNGramSimilarity = f1.toNGramSimilarity(f2) val transformer = catNGramSimilarity.originStage.asInstanceOf[SetNGramSimilarity] it should "correctly compute char-n-gram similarity with nondefault ngram param" in { val cat5GramSimilarity = f1.toNGramSimilarity(f2, 5) val transformedDs = cat5GramSimilarity.originStage.asInstanceOf[Transformer].transform(inputData) val actualOutput = transformedDs.collect(cat5GramSimilarity) actualOutput shouldBe Seq(0.3333333432674408, 0.12361115217208862, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0).toRealNN } }
Example 169
Source File: RoundTransformerTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature import com.salesforce.op.features.types._ import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class RoundTransformerTest extends OpTransformerSpec[Integral, RoundTransformer[Real]] { val sample = Seq(Real(-1.3), Real(-4.9), Real.empty, Real(5.1), Real(-5.1), Real(0.1), Real(2.5), Real(0.4)) val (inputData, f1) = TestFeatureBuilder(sample) val transformer: RoundTransformer[Real] = new RoundTransformer[Real]().setInput(f1) val expectedResult: Seq[Integral] = Seq(Integral(-1), Integral(-5), Integral.empty, Integral(5), Integral(-5), Integral(0), Integral(3), Integral(0)) it should "have a working shortcut" in { val f2 = f1.round() f2.originStage.isInstanceOf[RoundTransformer[_]] shouldBe true } }
Example 170
Source File: OpStringIndexerNoFilterTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.unary.UnaryModel
import com.salesforce.op.stages.impl.feature.StringIndexerHandleInvalid.Skip
import com.salesforce.op.stages.sparkwrappers.generic.SwUnaryModel
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.feature.StringIndexerModel
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpStringIndexerNoFilterTest
  extends OpEstimatorSpec[RealNN, UnaryModel[Text, RealNN], OpStringIndexerNoFilter[Text]] {

  val txtData = Seq("a", "b", "c", "a", "a", "c").map(_.toText)
  val (inputData, txtF) = TestFeatureBuilder(txtData)
  override val expectedResult: Seq[RealNN] = Array(0.0, 2.0, 1.0, 0.0, 0.0, 1.0).map(_.toRealNN)
  override val estimator: OpStringIndexerNoFilter[Text] = new OpStringIndexerNoFilter[Text]().setInput(txtF)

  val txtDataNew = Seq("a", "b", "c", "a", "a", "c", "d", "e").map(_.toText)
  val (dsNew, txtFNew) = TestFeatureBuilder(txtDataNew)
  val expectedNew = Array(0.0, 2.0, 1.0, 0.0, 0.0, 1.0, 3.0, 3.0).map(_.toRealNN)

  it should "correctly index a text column (shortcut)" in {
    val indexed = txtF.indexed()
    val indices = indexed.originStage.asInstanceOf[OpStringIndexerNoFilter[_]]
      .fit(inputData).transform(inputData).collect(indexed)
    indices shouldBe expectedResult

    val indexed2 = txtF.indexed(handleInvalid = Skip)
    val indicesfit = indexed2.originStage.asInstanceOf[OpStringIndexer[_]].fit(inputData)
    val indices2 = indicesfit.transform(inputData).collect(indexed2)
    val indices3 = indicesfit.asInstanceOf[SwUnaryModel[Text, RealNN, StringIndexerModel]]
      .setInput(txtFNew).transform(dsNew).collect(indexed2)
    indices2 shouldBe expectedResult
    indices3 shouldBe expectedResult
  }

  it should "correctly deindex a numeric column" in {
    val indexed = txtF.indexed()
    val indices = indexed.originStage.asInstanceOf[OpStringIndexerNoFilter[_]].fit(inputData).transform(inputData)
    val deindexed = indexed.deindexed()
    val deindexedData = deindexed.originStage.asInstanceOf[OpIndexToStringNoFilter]
      .transform(indices).collect(deindexed)
    deindexedData shouldBe txtData
  }

  it should "assign new strings to the unseen string category" in {
    val indices = estimator.fit(inputData).setInput(txtFNew).transform(dsNew).collect(estimator.getOutput())
    indices shouldBe expectedNew
  }
}
Example 171
Source File: IDFTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature import com.salesforce.op._ import com.salesforce.op.features.types._ import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext} import com.salesforce.op.utils.spark.RichDataset._ import org.apache.spark.ml.feature.IDF import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.{Estimator, Transformer} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{Assertions, FlatSpec, Matchers} @RunWith(classOf[JUnitRunner]) class IDFTest extends FlatSpec with TestSparkContext { val data = Seq( Vectors.sparse(4, Array(1, 3), Array(1.0, 2.0)), Vectors.dense(0.0, 1.0, 2.0, 3.0), Vectors.sparse(4, Array(1), Array(1.0)) ) lazy val (ds, f1) = TestFeatureBuilder(data.map(_.toOPVector)) Spec[IDF] should "compute inverted document frequency" in { val idf = f1.idf() val model = idf.originStage.asInstanceOf[Estimator[_]].fit(ds) val transformedData = model.asInstanceOf[Transformer].transform(ds) val results = transformedData.select(idf.name).collect(idf) idf.name shouldBe idf.originStage.getOutputFeatureName val expectedIdf = Vectors.dense(Array(0, 3, 1, 2).map { x => math.log((data.length + 1.0) / (x + 1.0)) }) val expected = scaleDataWithIDF(data, expectedIdf) for { (res, exp) <- results.zip(expected) (x, y) <- res.value.toArray.zip(exp.toArray) } assert(math.abs(x - y) <= 1e-5) } it should "compute inverted document frequency when minDocFreq is 1" in { val idf = f1.idf(minDocFreq = 1) val model = idf.originStage.asInstanceOf[Estimator[_]].fit(ds) val transformedData = model.asInstanceOf[Transformer].transform(ds) val results = transformedData.select(idf.name).collect(idf) idf.name shouldBe idf.originStage.getOutputFeatureName val expectedIdf = Vectors.dense(Array(0, 3, 1, 2).map { x => if (x > 0) math.log((data.length + 1.0) / (x + 1.0)) else 0 }) val expected = scaleDataWithIDF(data, expectedIdf) for { (res, exp) <- results.zip(expected) (x, y) <- res.value.toArray.zip(exp.toArray) } assert(math.abs(x - y) <= 1e-5) } private def scaleDataWithIDF(dataSet: Seq[Vector], model: Vector): Seq[Vector] = { dataSet.map { case data: DenseVector => val res = data.toArray.zip(model.toArray).map { case (x, y) => x * y } Vectors.dense(res) case data: SparseVector => val res = data.indices.zip(data.values).map { case (id, value) => (id, value * model(id)) } Vectors.sparse(data.size, res) } } }
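IDFTest and most of the TransmogrifAI suites above combine @RunWith(classOf[JUnitRunner]) with a ScalaTest style trait, which is what lets JUnit tooling discover and run ScalaTest specs. Stripped of the Spark-specific fixtures, the pattern reduces to roughly the sketch below; the class name and assertions are illustrative only.

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}

// A ScalaTest spec made visible to JUnit through the JUnitRunner adapter.
@RunWith(classOf[JUnitRunner])
class MinimalFlatSpecTest extends FlatSpec with Matchers {
  "A Vector" should "report its size" in {
    Vector(1, 2, 3) should have size 3
  }

  it should "reverse correctly" in {
    Vector(1, 2, 3).reverse shouldBe Vector(3, 2, 1)
  }
}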
Example 172
Source File: OpCountVectorizerTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.test.TestOpVectorColumnType.{IndCol, IndVal}
import com.salesforce.op.test.{TestFeatureBuilder, TestOpVectorMetadataBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.OpVectorMetadata
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class OpCountVectorizerTest extends FlatSpec with TestSparkContext {

  val data = Seq[(Real, TextList)](
    (Real(0), Seq("a", "b", "c").toTextList),
    (Real(1), Seq("a", "b", "b", "b", "a", "c").toTextList)
  )
  lazy val (ds, f1, f2) = TestFeatureBuilder(data)
  lazy val expected = Array[(Real, OPVector)](
    (Real(0), Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 1.0, 1.0)).toOPVector),
    (Real(1), Vectors.sparse(3, Array(0, 1, 2), Array(3.0, 2.0, 1.0)).toOPVector)
  )
  val f2vec = new OpCountVectorizer().setInput(f2).setVocabSize(3).setMinDF(2)

  Spec[OpCountVectorizerTest] should "convert array of strings into count vector" in {
    val transformedData = f2vec.fit(ds).transform(ds)
    val output = f2vec.getOutput()
    transformedData.orderBy(f1.name).collect(f1, output) should contain theSameElementsInOrderAs expected
  }

  it should "return a fitted vectorizer with the correct parameters" in {
    val fitted = f2vec.fit(ds)
    val vectorMetadata = fitted.getMetadata()
    val expectedMeta = TestOpVectorMetadataBuilder(
      f2vec,
      f2 -> List(IndVal(Some("b")), IndVal(Some("a")), IndVal(Some("c")))
    )
    // cannot just do equals because fitting is nondeterministic
    OpVectorMetadata(f2vec.getOutputFeatureName, vectorMetadata).columns should contain theSameElementsAs expectedMeta.columns
  }

  it should "convert array of strings into count vector (shortcut version)" in {
    val output = f2.countVec(minDF = 2, vocabSize = 3)
    val f2vec = output.originStage.asInstanceOf[OpCountVectorizer]
    val transformedData = f2vec.fit(ds).transform(ds)
    transformedData.orderBy(f1.name).collect(f1, output) should contain theSameElementsInOrderAs expected
  }
}
Example 173
Source File: OpTextPivotVectorizerTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature import com.salesforce.op.features.types._ import com.salesforce.op.stages.base.sequence.SequenceModel import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder} import org.apache.spark.ml.linalg.Vectors import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class OpTextPivotVectorizerTest extends OpEstimatorSpec[OPVector, SequenceModel[Text, OPVector], OpTextPivotVectorizer[Text]] { lazy val (inputData, f1, f2) = TestFeatureBuilder("text1", "text2", Seq[(Text, Text)]( ("hello world".toText, "Hello world!".toText), ("hello world".toText, "What's up".toText), ("good evening".toText, "How are you doing, my friend?".toText), ("hello world".toText, "Not bad, my friend.".toText), (Text.empty, Text.empty) ) ) override val expectedResult: Seq[OPVector] = Seq( Vectors.sparse(8, Array(0, 4), Array(1.0, 1.0)), Vectors.sparse(8, Array(0, 6), Array(1.0, 1.0)), Vectors.sparse(8, Array(1, 5), Array(1.0, 1.0)), Vectors.sparse(8, Array(0, 6), Array(1.0, 1.0)), Vectors.sparse(8, Array(3, 7), Array(1.0, 1.0)) ).map(_.toOPVector) }
Example 174
Source File: ScalerTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.test.TestSparkContext
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ScalerTest extends FlatSpec with TestSparkContext {

  Spec[Scaler] should "error on invalid data" in {
    val error = intercept[IllegalArgumentException](
      Scaler.apply(scalingType = ScalingType.Linear, args = EmptyScalerArgs())
    )
    error.getMessage shouldBe s"Invalid combination of scaling type '${ScalingType.Linear}' " +
      s"and args type '${EmptyScalerArgs().getClass.getSimpleName}'"
  }

  it should "correctly construct a LinearScaler" in {
    val linearScaler = Scaler.apply(scalingType = ScalingType.Linear,
      args = LinearScalerArgs(slope = 1.0, intercept = 2.0))
    linearScaler shouldBe a[LinearScaler]
    linearScaler.scalingType shouldBe ScalingType.Linear
  }

  it should "correctly construct a LogScaler" in {
    val logScaler = Scaler.apply(scalingType = ScalingType.Logarithmic, args = EmptyScalerArgs())
    logScaler shouldBe a[LogScaler]
    logScaler.scalingType shouldBe ScalingType.Logarithmic
  }
}
Example 175
Source File: RoundDigitsTransformerTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature import com.salesforce.op.features.types._ import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class RoundDigitsTransformerTest extends OpTransformerSpec[Real, RoundDigitsTransformer[Real]] { val sample = Seq(Real(1.4231092), Real(4.3231), Real.empty, Real(-1.0), Real(2.03728181)) val (inputData, f1) = TestFeatureBuilder(sample) val transformer: RoundDigitsTransformer[Real] = new RoundDigitsTransformer[Real](2) .setInput(f1) val expectedResult: Seq[Real] = Seq(Real(1.42), Real(4.32), Real.empty, Real(-1.0), Real(2.04)) it should "have a working shortcut" in { val f2 = f1.round(4) f2.originStage.isInstanceOf[RoundDigitsTransformer[_]] shouldBe true } }
Example 176
Source File: LangDetectorTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature import com.salesforce.op.features.types._ import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder} import com.salesforce.op.utils.spark.RichDataset._ import com.salesforce.op.utils.text.Language import org.apache.spark.ml.Transformer import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class LangDetectorTest extends OpTransformerSpec[RealMap, LangDetector[Text]] { // scalastyle:off val (inputData, f1, f2, f3) = TestFeatureBuilder( Seq( ( "I've got a lovely bunch of coconuts".toText, "文化庁によりますと、世界文化遺産への登録を目指している、福岡県の「宗像・沖ノ島と関連遺産群」について、ユネスコの諮問機関は、8つの構成資産のうち、沖ノ島など4つについて、「世界遺産に登録することがふさわしい」とする勧告をまとめました。".toText, "Première détection d’une atmosphère autour d’une exoplanète de la taille de la Terre".toText ), ( "There they are, all standing in a row".toText, "地磁気発生の謎に迫る地球内部の環境、再現実験".toText, "Les deux commissions, créées respectivement en juin 2016 et janvier 2017".toText ), ( "Big ones, small ones, some as big as your head".toText, "大学レスリング界で「黒船」と呼ばれたカザフスタン出身の大型レスラーが、日本の男子グレコローマンスタイルの重量級強化のために一役買っている。山梨学院大をこの春卒業したオレッグ・ボルチン(24)。4月から新日本プロレスの親会社ブシロードに就職。自身も日本を拠点に、アマチュアレスリングで2020年東京五輪を目指す。".toText, "Il publie sa théorie de la relativité restreinte en 1905".toText ) ) ) // scalastyle:on val transformer = new LangDetector[Text]().setInput(f1) private val langMap = f1.detectLanguages() // English result val expectedResult: Seq[RealMap] = Seq( Map("en" -> 0.9999984360934321), Map("en" -> 0.9999900853228016), Map("en" -> 0.9999900116744931) ).map(_.toRealMap) it should "return empty RealMap when input text is empty" in { transformer.transformFn(Text.empty) shouldBe RealMap.empty } it should "detect Japanese language" in { assertDetectionResults( results = transformer.setInput(f2).transform(inputData).collect(transformer.getOutput()), expectedLanguage = Language.Japanese ) } it should "detect French language" in { assertDetectionResults( results = transformer.setInput(f3).transform(inputData).collect(transformer.getOutput()), expectedLanguage = Language.French ) } it should "has a working shortcut" in { val tokenized = f1.detectLanguages() assertDetectionResults( results = tokenized.originStage.asInstanceOf[Transformer].transform(inputData).collect(tokenized), expectedLanguage = Language.English ) } private def assertDetectionResults ( results: Array[RealMap], expectedLanguage: Language, confidence: Double = 0.99 ): Unit = results.foreach(res => { res.value.size shouldBe 1 res.value.contains(expectedLanguage.entryName) shouldBe true res.value(expectedLanguage.entryName) should be >= confidence }) }
Example 177
Source File: TextLenTransformerTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature import com.salesforce.op.features.types._ import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder, TestSparkContext} import com.salesforce.op.utils.spark.RichDataset._ import org.apache.spark.ml.linalg.Vectors import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TextLenTransformerTest extends OpTransformerSpec[OPVector, TextLenTransformer[_]] with TestSparkContext with AttributeAsserts { val (ds, f1, f2) = TestFeatureBuilder( Seq[(TextList, TextList)]( (TextList(Seq("A", "giraffe", "drinks", "by", "the", "watering", "hole")), TextList(Seq("A giraffe drinks by the watering hole"))), (TextList(Seq("A giraffe drinks by the watering hole")), TextList(Seq("Cheese"))), (TextList(Seq("Cheese", "cake")), TextList(Seq("A giraffe drinks by the watering hole"))), (TextList(Seq("Cheese")), TextList(Seq("Cheese"))), (TextList.empty, TextList(Seq("A giraffe drinks by the watering hole"))), (TextList.empty, TextList(Seq("Cheese", "tart"))), (TextList(Seq("A giraffe drinks by the watering hole")), TextList.empty), (TextList(Seq("Cheese")), TextList.empty), (TextList.empty, TextList.empty) ) ) // Variables for OpTransformer base tests val inputData = ds val transformer = new TextLenTransformer().setInput(f1, f2) val expectedResult = Seq( Array(31.0, 37.0), Array(37.0, 6.0), Array(10.0, 37.0), Array(6.0, 6.0), Array(0.0, 37.0), Array(0.0, 10.0), Array(37.0, 0.0), Array(6.0, 0.0), Array(0.0, 0.0) ).map(Vectors.dense(_).toOPVector) Spec[TextLenTransformer[_]] should "take an array of features as input and return a single vector feature" in { val vector = transformer.getOutput() vector.name shouldBe transformer.getOutputFeatureName vector.typeName shouldBe FeatureType.typeName[OPVector] vector.isResponse shouldBe false } it should "transform the data correctly" in { val transformed = transformer.transform(ds) val vector = transformer.getOutput() val result = transformed.collect(vector) result should contain theSameElementsAs expectedResult } }
Example 178
Source File: VectorsCombinerTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature import com.salesforce.op._ import com.salesforce.op.features.TransientFeature import com.salesforce.op.features.types.{Text, _} import com.salesforce.op.stages.base.sequence.SequenceModel import com.salesforce.op.test.{OpEstimatorSpec, PassengerSparkFixtureTest, TestFeatureBuilder} import com.salesforce.op.utils.spark.OpVectorMetadata import com.salesforce.op.utils.spark.RichMetadata._ import org.apache.spark.ml.attribute.MetadataHelper import org.apache.spark.ml.linalg.Vectors import org.apache.spark.sql.types.Metadata import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class VectorsCombinerTest extends OpEstimatorSpec[OPVector, SequenceModel[OPVector, OPVector], VectorsCombiner] with PassengerSparkFixtureTest { override def specName: String = classOf[VectorsCombiner].getSimpleName val (inputData, f1, f2) = TestFeatureBuilder(Seq( Vectors.sparse(4, Array(0, 3), Array(1.0, 1.0)).toOPVector -> Vectors.sparse(4, Array(0, 3), Array(2.0, 3.0)).toOPVector, Vectors.dense(Array(2.0, 3.0, 4.0)).toOPVector -> Vectors.dense(Array(12.0, 13.0, 14.0)).toOPVector, // Purposely added some very large sparse vectors to verify the efficiency Vectors.sparse(100000000, Array(1), Array(777.0)).toOPVector -> Vectors.sparse(500000000, Array(0), Array(888.0)).toOPVector )) val estimator = new VectorsCombiner().setInput(f1, f2) val expectedResult = Seq( Vectors.sparse(8, Array(0, 3, 4, 7), Array(1.0, 1.0, 2.0, 3.0)).toOPVector, Vectors.dense(Array(2.0, 3.0, 4.0, 12.0, 13.0, 14.0)).toOPVector, Vectors.sparse(600000000, Array(1, 100000000), Array(777.0, 888.0)).toOPVector ) it should "combine metadata correctly" in { val vector = Seq(height, description, stringMap).transmogrify() val inputs = vector.parents val outputData = new OpWorkflow().setReader(dataReader) .setResultFeatures(vector, inputs(0), inputs(1), inputs(2)) .train().score() val inputMetadata = OpVectorMetadata.flatten(vector.name, inputs.map(i => OpVectorMetadata(outputData.schema(i.name)))) OpVectorMetadata(outputData.schema(vector.name)).columns should contain theSameElementsAs inputMetadata.columns } it should "create metadata correctly" in { val descVect = description.map[Text] { t => Text(t.value match { case Some(text) => "this is dumb " + text case None => "some STUFF to tokenize" }) }.tokenize().tf(numTerms = 5) val vector = Seq(height, stringMap, descVect).transmogrify() val Seq(inputs1, inputs2, inputs3) = vector.parents val outputData = new OpWorkflow().setReader(dataReader) .setResultFeatures(vector, inputs1, inputs2, inputs3) .train().score() outputData.schema(inputs1.name).metadata.wrapped .get[Metadata](MetadataHelper.attributeKeys.ML_ATTR) .getLong(MetadataHelper.attributeKeys.NUM_ATTRIBUTES) shouldBe 5 val inputMetadata = OpVectorMetadata.flatten(vector.name, Array(TransientFeature(inputs1).toVectorMetaData(5, Option(inputs1.name)), OpVectorMetadata(outputData.schema(inputs2.name)), OpVectorMetadata(outputData.schema(inputs3.name)))) OpVectorMetadata(outputData.schema(vector.name)).columns should contain theSameElementsAs inputMetadata.columns } }
Example 179
Source File: OpStopWordsRemoverTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature import com.salesforce.op._ import com.salesforce.op.features.types._ import com.salesforce.op.utils.spark.RichDataset._ import com.salesforce.op.test.{SwTransformerSpec, TestFeatureBuilder} import org.apache.spark.ml.feature.StopWordsRemover import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class OpStopWordsRemoverTest extends SwTransformerSpec[TextList, StopWordsRemover, OpStopWordsRemover] { val data = Seq( "I AM groot", "Groot call me human", "or I will crush you" ).map(_.split(" ").toSeq.toTextList) val (inputData, textListFeature) = TestFeatureBuilder(data) val bigrams = textListFeature.removeStopWords() val transformer = bigrams.originStage.asInstanceOf[OpStopWordsRemover] val expectedResult = Seq(Seq("groot"), Seq("Groot", "call", "human"), Seq("crush")).map(_.toTextList) it should "allow case sensitivity" in { val noStopWords = textListFeature.removeStopWords(caseSensitive = true) val res = noStopWords.originStage.asInstanceOf[OpStopWordsRemover].transform(inputData) res.collect(noStopWords) shouldBe Seq( Seq("I", "AM", "groot"), Seq("Groot", "call", "human"), Seq("I", "crush")).map(_.toTextList) } it should "set custom stop words" in { val noStopWords = textListFeature.removeStopWords(stopWords = Array("Groot", "I")) val res = noStopWords.originStage.asInstanceOf[OpStopWordsRemover].transform(inputData) res.collect(noStopWords) shouldBe Seq( Seq("AM"), Seq("call", "me", "human"), Seq("or", "will", "crush", "you")).map(_.toTextList) } }
Example 180
Source File: TransmogrifierTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature import com.salesforce.op.features._ import com.salesforce.op.features.types._ import com.salesforce.op.test.TestOpVectorColumnType._ import com.salesforce.op.test.{PassengerSparkFixtureTest, TestOpVectorMetadataBuilder} import com.salesforce.op.utils.spark.RichDataset._ import com.salesforce.op.utils.spark.RichStructType._ import com.salesforce.op._ import org.junit.runner.RunWith import org.scalatest.FlatSpec import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TransmogrifierTest extends FlatSpec with PassengerSparkFixtureTest with AttributeAsserts { val inputFeatures = Array[OPFeature](heightNoWindow, weight, gender) Spec(Transmogrifier.getClass) should "return a single output feature of type vector with the correct name" in { val feature = inputFeatures.transmogrify() feature.name.contains("gender-heightNoWindow-weight_3-stagesApplied_OPVector") } it should "return a model when fitted" in { val feature = inputFeatures.transmogrify() val model = new OpWorkflow().setResultFeatures(feature).setReader(dataReader).train() model.getResultFeatures() should contain theSameElementsAs Array(feature) val name = model.getResultFeatures().map(_.name).head name.contains("gender-heightNoWindow-weight_3-stagesApplied_OPVector") } it should "correctly transform the data and store the feature names in metadata" in { val feature = inputFeatures.toSeq.transmogrify() val model = new OpWorkflow().setResultFeatures(feature).setReader(dataReader).train() val transformed = model.score(keepRawFeatures = true, keepIntermediateFeatures = true) val hist = feature.parents.flatMap { f => val h = f.history() h.originFeatures.map(o => o -> FeatureHistory(Seq(o), h.stages)) }.toMap transformed.schema.toOpVectorMetadata(feature.name) shouldEqual TestOpVectorMetadataBuilder.withOpNamesAndHist( feature.originStage, hist, (gender, "vecSet", List(IndCol(Some("OTHER")), IndCol(Some(TransmogrifierDefaults.NullString)))), (heightNoWindow, "vecReal", List(RootCol, IndColWithGroup(Some(TransmogrifierDefaults.NullString), heightNoWindow.name))), (weight, "vecReal", List(RootCol, IndColWithGroup(Some(TransmogrifierDefaults.NullString), weight.name))) ) transformed.schema.findFields("heightNoWindow-weight_1-stagesApplied_OPVector").nonEmpty shouldBe true val collected = transformed.collect(feature) collected.head.v.size shouldEqual 6 collected.map(_.v.toArray.toList).toSet shouldEqual Set( List(0.0, 1.0, 211.4, 1.0, 96.0, 1.0), List(1.0, 0.0, 172.0, 0.0, 78.0, 0.0), List(1.0, 0.0, 168.0, 0.0, 67.0, 0.0), List(1.0, 0.0, 363.0, 0.0, 172.0, 0.0), List(1.0, 0.0, 186.0, 0.0, 96.0, 0.0) ) val field = transformed.schema(feature.name) assertNominal(field, Array(false, true, false, true, false, true), collected) } }
Example 181
Source File: OPMapTransformerTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature import com.salesforce.op.UID import com.salesforce.op.features.types._ import com.salesforce.op.stages.base.unary.UnaryTransformer import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class OPMapTransformerTest extends OpTransformerSpec[IntegralMap, OPMapTransformer[Email, Integral, EmailMap, IntegralMap]] { lazy val (inputData, top) = TestFeatureBuilder("name", Seq( Map("p1" -> "[email protected]", "p2" -> "[email protected]").toEmailMap )) val transformer: OPMapTransformer[Email, Integral, EmailMap, IntegralMap] = new LengthMapTransformer().setInput(top) val expectedResult: Seq[IntegralMap] = Seq( Map("p1" -> 10L, "p2" -> 11L).toIntegralMap ) } class LengthTransformer extends UnaryTransformer[Email, Integral]( operationName = "lengthUnary", uid = UID[LengthTransformer] ) { override def transformFn: (Email => Integral) = (input: Email) => input.value.map(_.length).toIntegral } class LengthMapTransformer ( uid: String = UID[LengthMapTransformer], operationName: String = "lengthMap" ) extends OPMapTransformer[Email, Integral, EmailMap, IntegralMap]( uid = uid, operationName = operationName, transformer = new LengthTransformer )
Example 182
Source File: TimePeriodListTransformerTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature import com.salesforce.op.features.FeatureLike import com.salesforce.op.features.types._ import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder} import com.salesforce.op.utils.date.DateTimeUtils import com.salesforce.op.utils.spark.RichDataset._ import org.apache.spark.ml.Transformer import org.joda.time.{DateTime => JDateTime} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TimePeriodListTransformerTest extends OpTransformerSpec[OPVector, TimePeriodListTransformer[DateList]] { val dateList: DateList = Seq[Long]( new JDateTime(1879, 3, 14, 0, 0, DateTimeUtils.DefaultTimeZone).getMillis, new JDateTime(1955, 11, 12, 10, 4, DateTimeUtils.DefaultTimeZone).getMillis, new JDateTime(1999, 3, 8, 12, 0, DateTimeUtils.DefaultTimeZone).getMillis, new JDateTime(2019, 4, 30, 13, 0, DateTimeUtils.DefaultTimeZone).getMillis ).toDateList val (inputData, f1) = TestFeatureBuilder(Seq(dateList)) override val transformer: TimePeriodListTransformer[DateList] = new TimePeriodListTransformer(TimePeriod.DayOfMonth).setInput(f1) override val expectedResult: Seq[OPVector] = Seq(Seq(14, 12, 8, 30).map(_.toDouble).toVector.toOPVector) it should "transform with rich shortcuts" in { val dlist = List(new JDateTime(1879, 3, 14, 0, 0, DateTimeUtils.DefaultTimeZone).getMillis) val (inputData2, d1, d2) = TestFeatureBuilder( Seq[(DateList, DateTimeList)]((dlist.toDateList, dlist.toDateTimeList)) ) def assertFeature(feature: FeatureLike[OPVector], expected: Seq[OPVector]): Unit = { val transformed = feature.originStage.asInstanceOf[Transformer].transform(inputData2) val actual = transformed.collect(feature) actual shouldBe expected } assertFeature(d1.toTimePeriod(TimePeriod.DayOfMonth), Seq(Vector(14.0).toOPVector)) assertFeature(d2.toTimePeriod(TimePeriod.DayOfMonth), Seq(Vector(14.0).toOPVector)) } }
Example 183
Source File: FilterIntegralMapTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature import com.salesforce.op.features.types._ import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder} import com.salesforce.op.utils.spark.RichDataset._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class FilterIntegralMapTest extends OpTransformerSpec[IntegralMap, FilterMap[IntegralMap]] { val (inputData, f1Int) = TestFeatureBuilder[IntegralMap]( Seq( IntegralMap(Map("Arthur" -> 1, "Lancelot" -> 2, "Galahad" -> 3)), IntegralMap(Map("Lancelot" -> 2, "Galahad" -> 3, "Bedevere" -> 4)), IntegralMap(Map("Knight" -> 5)) ) ) val transformer = new FilterMap[IntegralMap]().setInput(f1Int) val expectedResult: Seq[IntegralMap] = Seq( IntegralMap(Map("Arthur" -> 1, "Lancelot" -> 2, "Galahad" -> 3)), IntegralMap(Map("Lancelot" -> 2, "Galahad" -> 3, "Bedevere" -> 4)), IntegralMap(Map("Knight" -> 5)) ) it should "filter by whitelisted keys" in { transformer.setWhiteListKeys(Array("Arthur", "Knight")) val filtered = transformer.transform(inputData).collect(transformer.getOutput()) val dataExpected = Array( IntegralMap(Map("Arthur" -> 1)), IntegralMap.empty, IntegralMap(Map("Knight" -> 5)) ) filtered should contain theSameElementsAs dataExpected } it should "filter by blacklisted keys" in { transformer.setInput(f1Int) .setWhiteListKeys(Array[String]()) .setBlackListKeys(Array("Arthur", "Knight")) val filtered = transformer.transform(inputData).collect(transformer.getOutput()) val dataExpected = Array( IntegralMap(Map("Lancelot" -> 2, "Galahad" -> 3)), IntegralMap(Map("Lancelot" -> 2, "Galahad" -> 3, "Bedevere" -> 4)), IntegralMap.empty ) filtered should contain theSameElementsAs dataExpected } }
Example 184
Source File: Base64VectorizerTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.OpWorkflow
import com.salesforce.op.features.FeatureLike
import com.salesforce.op.features.types._
import com.salesforce.op.test.TestSparkContext
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class Base64VectorizerTest extends FlatSpec with TestSparkContext with Base64TestData with AttributeAsserts {

  "Base64Vectorizer" should "vectorize random binary data" in {
    val vec = randomBase64.vectorize(topK = 10, minSupport = 0, cleanText = true, trackNulls = false)
    val result = new OpWorkflow().setResultFeatures(vec).transform(randomData)

    result.collect(vec) should contain theSameElementsInOrderAs
      OPVector(Vectors.dense(0.0, 0.0)) +:
        Array.fill(expectedRandom.length - 1)(OPVector(Vectors.dense(1.0, 0.0)))
  }

  it should "vectorize some real binary content" in {
    val vec = realBase64.vectorize(topK = 10, minSupport = 0, cleanText = true)
    assertVectorizer(vec, expectedMime)
  }

  it should "vectorize some real binary content with a type hint" in {
    val vec = realBase64.vectorize(topK = 10, minSupport = 0, cleanText = true, typeHint = Some("application/json"))
    assertVectorizer(vec, expectedMimeJson)
  }

  def assertVectorizer(vec: FeatureLike[OPVector], expected: Seq[Text]): Unit = {
    val result = new OpWorkflow().setResultFeatures(vec).transform(realData)
    val vectors = result.collect(vec)
    val schema = result.schema(vec.name)
    assertNominal(schema, Array.fill(vectors.head.value.size)(true), vectors)

    vectors.length shouldBe expected.length
    // TODO add a more robust check
  }
}
Example 185
Source File: FilterMultiPickListMapTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import com.salesforce.op.utils.spark.RichDataset._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class FilterMultiPickListMapTest extends OpTransformerSpec[MultiPickListMap, FilterMap[MultiPickListMap]] {

  val (inputData, f1Cat) = TestFeatureBuilder[MultiPickListMap](
    Seq(
      MultiPickListMap(Map(
        "Arthur" -> Set("King", "Briton"),
        "Lancelot" -> Set("Brave", "Knight"),
        "Galahad" -> Set("Pure", "Knight"))),
      MultiPickListMap(Map(
        "Lancelot" -> Set("Brave", "Knight"),
        "Galahad" -> Set("Pure", "Knight"),
        "Bedevere" -> Set("Wise", "Knight"))),
      MultiPickListMap(Map("Knight" -> Set("Ni", "Ekke Ekke Ekke Ekke Ptang Zoo Boing")))
    )
  )

  val transformer = new FilterMap[MultiPickListMap]().setInput(f1Cat)

  val expectedResult = Seq(
    MultiPickListMap(Map(
      "Arthur" -> Set("King", "Briton"),
      "Lancelot" -> Set("Brave", "Knight"),
      "Galahad" -> Set("Pure", "Knight"))),
    MultiPickListMap(Map(
      "Lancelot" -> Set("Brave", "Knight"),
      "Galahad" -> Set("Pure", "Knight"),
      "Bedevere" -> Set("Wise", "Knight"))),
    MultiPickListMap(Map("Knight" -> Set("Ni", "EkkeEkkeEkkeEkkePtangZooBoing")))
  )

  it should "filter whitelisted keys" in {
    transformer.setWhiteListKeys(Array("Arthur", "Knight"))
    val filtered = transformer.transform(inputData).collect(transformer.getOutput())

    val dataExpected = Array(
      MultiPickListMap(Map("Arthur" -> Set("King", "Briton"))),
      MultiPickListMap.empty,
      MultiPickListMap(Map("Knight" -> Set("Ni", "EkkeEkkeEkkeEkkePtangZooBoing")))
    )
    filtered should contain theSameElementsAs dataExpected
  }

  it should "filter blacklisted keys" in {
    transformer
      .setWhiteListKeys(Array[String]())
      .setBlackListKeys(Array("Arthur", "Knight"))
    val filtered = transformer.transform(inputData).collect(transformer.getOutput())

    val dataExpected = Array(
      MultiPickListMap(Map(
        "Lancelot" -> Set("Brave", "Knight"),
        "Galahad" -> Set("Pure", "Knight"))),
      MultiPickListMap(Map(
        "Lancelot" -> Set("Brave", "Knight"),
        "Galahad" -> Set("Pure", "Knight"),
        "Bedevere" -> Set("Wise", "Knight"))),
      MultiPickListMap.empty
    )
    filtered should contain theSameElementsAs dataExpected
  }

  it should "not clean map when flag set to false" in {
    transformer
      .setCleanText(false)
      .setCleanKeys(false)
      .setWhiteListKeys(Array("Arthur", "Knight"))
      .setBlackListKeys(Array())
    val filtered = transformer.transform(inputData).collect(transformer.getOutput())

    val dataExpected = Array(
      MultiPickListMap(Map("Arthur" -> Set("King", "Briton"))),
      MultiPickListMap.empty,
      MultiPickListMap(Map("Knight" -> Set("Ni", "Ekke Ekke Ekke Ekke Ptang Zoo Boing")))
    )
    filtered should contain theSameElementsAs dataExpected
  }

  it should "clean map when flag set to true" in {
    transformer
      .setCleanKeys(true)
      .setCleanText(true)
      .setWhiteListKeys(Array("Arthur", "Knight"))
      .setBlackListKeys(Array())
    val filtered = transformer.transform(inputData).collect(transformer.getOutput())

    val dataExpected = Array(
      MultiPickListMap(Map("Arthur" -> Set("King", "Briton"))),
      MultiPickListMap.empty,
      MultiPickListMap(Map("Knight" -> Set("Ni", "EkkeEkkeEkkeEkkePtangZooBoing")))
    )
    filtered should contain theSameElementsAs dataExpected
  }
}
Example 186
Source File: AliasTransformerTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.binary.BinaryLambdaTransformer
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.utils.tuples.RichTuple._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class AliasTransformerTest extends OpTransformerSpec[RealNN, AliasTransformer[RealNN]] {

  val sample = Seq((RealNN(1.0), RealNN(2.0)), (RealNN(4.0), RealNN(4.0)))
  val (inputData, f1, f2) = TestFeatureBuilder(sample)
  val transformer = new AliasTransformer(name = "feature").setInput(f1)
  val expectedResult: Seq[RealNN] = sample.map(_._1)

  it should "have a shortcut that changes feature name on a raw feature" in {
    val feature = f1.alias
    feature.name shouldBe "feature"
    feature.originStage shouldBe a[AliasTransformer[_]]
    val origin = feature.originStage.asInstanceOf[AliasTransformer[RealNN]]
    val transformed = origin.transform(inputData)
    transformed.collect(feature) shouldEqual expectedResult
  }

  it should "have a shortcut that changes feature name on a derived feature" in {
    val feature = (f1 / f2).alias
    feature.name shouldBe "feature"
    feature.originStage shouldBe a[DivideTransformer[_, _]]
    val origin = feature.originStage.asInstanceOf[DivideTransformer[_, _]]
    val transformed = origin.transform(inputData)
    transformed.columns should contain (feature.name)
    transformed.collect(feature) shouldEqual sample.map { case (v1, v2) => (v1.v -> v2.v).map(_ / _).toRealNN(0.0) }
  }

  it should "have a shortcut that changes feature name on a derived wrapped feature" in {
    val feature = f1.toIsotonicCalibrated(label = f2).alias
    feature.name shouldBe "feature"
    feature.originStage shouldBe a[AliasTransformer[_]]
  }
}
Example 187
Source File: PowerTransformerTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class PowerTransformerTest extends OpTransformerSpec[Real, PowerTransformer[Real]] {

  val sample = Seq(Real(-1.3), Real(-4.9), Real.empty, Real(5.1), Real(-5.1), Real(0.1), Real(2.5), Real(0.4))
  val (inputData, f1) = TestFeatureBuilder(sample)

  val transformer: PowerTransformer[Real] = new PowerTransformer[Real](3.0).setInput(f1)

  override val expectedResult: Seq[Real] =
    Seq(Some(-1.3), Some(-4.9), None, Some(5.1), Some(-5.1), Some(0.1), Some(2.5), Some(0.4))
      .map(_.map(v => math.pow(v, 3)).toReal)

  it should "have a working shortcut" in {
    val f2 = f1.power(4)
    f2.originStage.isInstanceOf[PowerTransformer[_]] shouldBe true
  }
}
Example 188
Source File: CeilTransformerTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class CeilTransformerTest extends OpTransformerSpec[Integral, CeilTransformer[Real]] {

  val sample = Seq(Real(-1.3), Real(-4.9), Real.empty, Real(5.1), Real(-5.1), Real(0.1), Real(2.5), Real(0.4))
  val (inputData, f1) = TestFeatureBuilder(sample)

  val transformer: CeilTransformer[Real] = new CeilTransformer[Real]().setInput(f1)

  override val expectedResult: Seq[Integral] =
    Seq(Integral(-1), Integral(-4), Integral.empty, Integral(6), Integral(-5), Integral(1), Integral(3), Integral(1))

  it should "have a working shortcut" in {
    val f2 = f1.ceil()
    f2.originStage.isInstanceOf[CeilTransformer[_]] shouldBe true
  }
}
Example 189
Source File: TimePeriodTransformerTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.FeatureLike
import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import com.salesforce.op.utils.date.DateTimeUtils
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.Transformer
import org.joda.time.{DateTime => JDateTime}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TimePeriodTransformerTest extends OpTransformerSpec[Integral, TimePeriodTransformer[Date]] {

  val (inputData, f1) = TestFeatureBuilder(Seq[Date](
    new JDateTime(1879, 3, 14, 0, 0, DateTimeUtils.DefaultTimeZone).getMillis.toDate,
    new JDateTime(1955, 11, 12, 10, 4, DateTimeUtils.DefaultTimeZone).getMillis.toDate,
    new JDateTime(1999, 3, 8, 12, 0, DateTimeUtils.DefaultTimeZone).getMillis.toDate,
    Date.empty,
    new JDateTime(2019, 4, 30, 13, 0, DateTimeUtils.DefaultTimeZone).getMillis.toDate
  ))

  override val transformer: TimePeriodTransformer[Date] =
    new TimePeriodTransformer(TimePeriod.DayOfMonth).setInput(f1)

  override val expectedResult: Seq[Integral] =
    Seq(Integral(14), Integral(12), Integral(8), Integral.empty, Integral(30))

  it should "correctly transform for all TimePeriod types" in {
    def assertFeature(feature: FeatureLike[Integral], expected: Seq[Integral]): Unit = {
      val transformed = feature.originStage.asInstanceOf[Transformer].transform(inputData)
      val actual = transformed.collect(feature)
      actual shouldBe expected
    }

    TimePeriod.values.foreach(tp => {
      val expected = tp match {
        case TimePeriod.DayOfMonth => Array(Integral(14), Integral(12), Integral(8), Integral.empty, Integral(30))
        case TimePeriod.DayOfWeek => Array(Integral(5), Integral(6), Integral(1), Integral.empty, Integral(2))
        case TimePeriod.DayOfYear => Array(Integral(73), Integral(316), Integral(67), Integral.empty, Integral(120))
        case TimePeriod.HourOfDay => Array(Integral(0), Integral(10), Integral(12), Integral.empty, Integral(13))
        case TimePeriod.MonthOfYear => Array(Integral(3), Integral(11), Integral(3), Integral.empty, Integral(4))
        case TimePeriod.WeekOfMonth => Array(Integral(3), Integral(2), Integral(2), Integral.empty, Integral(5))
        case TimePeriod.WeekOfYear => Array(Integral(11), Integral(46), Integral(11), Integral.empty, Integral(18))
        case _ => throw new Exception(s"Unexpected TimePeriod encountered, $tp")
      }
      withClue(s"Assertion failed for TimePeriod $tp: ") {
        assertFeature(f1.toTimePeriod(tp), expected)
      }
    })
  }
}
Example 190
Source File: DataSplitterTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.tuning

import com.salesforce.op.test.TestSparkContext
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.mllib.random.RandomRDDs
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class DataSplitterTest extends FlatSpec with TestSparkContext with SplitterSummaryAsserts {
  import spark.implicits._

  val seed = 1234L
  val dataCount = 1000
  val trainingLimitDefault = 1E6.toLong

  val data = RandomRDDs.normalVectorRDD(sc, 1000, 3, seed = seed)
    .map(v => (1.0, Vectors.dense(v.toArray), "A")).toDF()

  val dataSplitter = DataSplitter(seed = seed)

  Spec[DataSplitter] should "split the data in the appropriate proportion - 0.0" in {
    val (train, test) = dataSplitter.setReserveTestFraction(0.0).split(data)
    test.count() shouldBe 0
    train.count() shouldBe dataCount
  }

  it should "down-sample when the data count is above the default training limit" in {
    val numRows = trainingLimitDefault * 2
    val data = RandomRDDs.normalVectorRDD(sc, numRows, 3, seed = seed)
      .map(v => (1.0, Vectors.dense(v.toArray), "A")).toDF()
    dataSplitter.preValidationPrepare(data)

    val dataBalanced = dataSplitter.validationPrepare(data)
    // validationPrepare calls the data sample method that samples the data to a target ratio but there is an epsilon
    // to how precise this function is which is why we need to check around that epsilon
    val samplingErrorEpsilon = (0.1 * trainingLimitDefault).toLong

    dataBalanced.count() shouldBe trainingLimitDefault +- samplingErrorEpsilon
  }

  it should "set and get all data splitter params" in {
    val maxRows = dataCount / 2
    val downSampleFraction = maxRows / dataCount.toDouble

    val dataSplitter = DataSplitter()
      .setReserveTestFraction(0.0)
      .setSeed(seed)
      .setMaxTrainingSample(maxRows)
      .setDownSampleFraction(downSampleFraction)

    dataSplitter.getReserveTestFraction shouldBe 0.0
    dataSplitter.getDownSampleFraction shouldBe downSampleFraction
    dataSplitter.getSeed shouldBe seed
    dataSplitter.getMaxTrainingSample shouldBe maxRows
  }

  it should "split the data in the appropriate proportion - 0.2" in {
    val (train, test) = dataSplitter.setReserveTestFraction(0.2).split(data)
    math.abs(test.count() - 200) < 30 shouldBe true
    math.abs(train.count() - 800) < 30 shouldBe true
  }

  it should "split the data in the appropriate proportion - 0.6" in {
    val (train, test) = dataSplitter.setReserveTestFraction(0.6).split(data)
    math.abs(test.count() - 600) < 30 shouldBe true
    math.abs(train.count() - 400) < 30 shouldBe true
  }

  it should "keep the data unchanged when prepare is called" in {
    val dataCount = data.count()
    val summary = dataSplitter.preValidationPrepare(data)
    val train = dataSplitter.validationPrepare(data)
    val sampleF = trainingLimitDefault / dataCount.toDouble
    val downSampleFraction = math.min(sampleF, 1.0)
    train.collect().zip(data.collect()).foreach { case (a, b) => a shouldBe b }
    assertDataSplitterSummary(summary.summaryOpt) { s =>
      s shouldBe DataSplitterSummary(dataCount, downSampleFraction)
    }
  }
}
Example 191
Source File: RandomParamBuilderTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.selector

import com.salesforce.op.stages.impl.classification.{OpLogisticRegression, OpRandomForestClassifier, OpXGBoostClassifier}
import com.salesforce.op.test.TestSparkContext
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class RandomParamBuilderTest extends FlatSpec with TestSparkContext {

  private val lr = new OpLogisticRegression()
  private val rf = new OpRandomForestClassifier()
  private val xgb = new OpXGBoostClassifier()

  Spec[RandomParamBuilder] should "build a param grid of the desired length with one param variable" in {
    val min = 0.00001
    val max = 10
    val lrParams = new RandomParamBuilder()
      .uniform(lr.regParam, min, max)
      .build(5)
    lrParams.length shouldBe 5
    lrParams.foreach(_.toSeq.length shouldBe 1)
    lrParams.foreach(_.toSeq.foreach(p =>
      (p.value.asInstanceOf[Double] < max && p.value.asInstanceOf[Double] > min) shouldBe true))
    lrParams.foreach(_.toSeq.map(_.param).toSet shouldBe Set(lr.regParam))

    val lrParams2 = new RandomParamBuilder()
      .exponential(lr.regParam, min, max)
      .build(20)
    lrParams2.length shouldBe 20
    lrParams2.foreach(_.toSeq.length shouldBe 1)
    lrParams2.foreach(_.toSeq.foreach(p =>
      (p.value.asInstanceOf[Double] < max && p.value.asInstanceOf[Double] > min) shouldBe true))
    lrParams2.foreach(_.toSeq.map(_.param).toSet shouldBe Set(lr.regParam))
  }

  it should "build a param grid of the desired length with many param variables" in {
    val lrParams = new RandomParamBuilder()
      .exponential(lr.regParam, .000001, 10)
      .subset(lr.family, Seq("auto", "binomial", "multinomial"))
      .uniform(lr.maxIter, 2, 50)
      .build(23)
    lrParams.length shouldBe 23
    lrParams.foreach(_.toSeq.length shouldBe 3)
    lrParams.foreach(_.toSeq.map(_.param).toSet shouldBe Set(lr.regParam, lr.family, lr.maxIter))
  }

  it should "work for all param types" in {
    val xgbParams = new RandomParamBuilder()
      .subset(xgb.checkpointPath, Seq("a", "b")) // string
      .uniform(xgb.alpha, 0, 1) // double
      .uniform(xgb.missing, 0, 100) // float
      .uniform(xgb.checkpointInterval, 2, 5) // int
      .uniform(xgb.seed, 5, 1000) // long
      .uniform(xgb.useExternalMemory) // boolean
      .exponential(xgb.baseScore, 0.0001, 1) // double
      .exponential(xgb.missing, 0.000001F, 1) // float - overwrites first call
      .build(2)

    xgbParams.length shouldBe 2
    xgbParams.foreach(_.toSeq.length shouldBe 7)
    xgbParams.foreach(_.toSeq.map(_.param).toSet shouldBe Set(xgb.checkpointPath, xgb.alpha, xgb.missing,
      xgb.checkpointInterval, xgb.seed, xgb.useExternalMemory, xgb.baseScore))
  }

  it should "throw a requirement error if an improper min value is passed in for exponential scale" in {
    intercept[IllegalArgumentException](new RandomParamBuilder()
      .exponential(xgb.baseScore, 0, 1)).getMessage() shouldBe
      "requirement failed: Min value must be greater than zero for exponential distribution to work"
  }

  it should "throw a requirement error if min is not less than max" in {
    intercept[IllegalArgumentException](new RandomParamBuilder()
      .uniform(xgb.baseScore, 1, 0)).getMessage() shouldBe
      "requirement failed: Min must be less than max"
  }
}
Example 192
Source File: OpLinearSVCTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.classification

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.ml.classification.{LinearSVC, LinearSVCModel}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpLinearSVCTest extends OpEstimatorSpec[Prediction, OpPredictorWrapperModel[LinearSVCModel],
  OpPredictorWrapper[LinearSVC, LinearSVCModel]] with PredictionEquality {

  override def specName: String = Spec[OpLinearSVC]

  val (inputData, rawFeature1, feature2) = TestFeatureBuilder("label", "features",
    Seq[(RealNN, OPVector)](
      1.0.toRealNN -> Vectors.dense(12.0, 4.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.0, 0.3, 0.1).toOPVector,
      0.0.toRealNN -> Vectors.dense(1.0, 3.9, 4.3).toOPVector,
      1.0.toRealNN -> Vectors.dense(10.0, 1.3, 0.9).toOPVector,
      1.0.toRealNN -> Vectors.dense(15.0, 4.7, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.5, 0.9, 10.1).toOPVector,
      1.0.toRealNN -> Vectors.dense(11.5, 2.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.1, 3.3, 0.1).toOPVector
    )
  )

  val feature1 = rawFeature1.copy(isResponse = true)
  val estimator = new OpLinearSVC().setInput(feature1, feature2)

  val expectedResult = Seq(
    Prediction(1.0, Vectors.dense(Array(-1.33, 1.33))),
    Prediction(0.0, Vectors.dense(Array(1.04, -1.04))),
    Prediction(0.0, Vectors.dense(Array(2.69, -2.69))),
    Prediction(1.0, Vectors.dense(Array(-1.32, 1.32))),
    Prediction(1.0, Vectors.dense(Array(-2.11, 2.11))),
    Prediction(0.0, Vectors.dense(Array(4.41, -4.41))),
    Prediction(1.0, Vectors.dense(Array(-1.46, 1.46))),
    Prediction(0.0, Vectors.dense(Array(1.42, -1.42)))
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator
      .setRegParam(0.1)
      .setMaxIter(20)
      .setTol(1E-4)
    estimator.fit(inputData)

    estimator.predictor.getRegParam shouldBe 0.1
    estimator.predictor.getMaxIter shouldBe 20
    estimator.predictor.getTol shouldBe 1E-4
  }
}
Example 193
Source File: OpLogisticRegressionTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.classification

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpLogisticRegressionTest extends OpEstimatorSpec[Prediction, OpPredictorWrapperModel[LogisticRegressionModel],
  OpPredictorWrapper[LogisticRegression, LogisticRegressionModel]] with PredictionEquality {

  override def specName: String = Spec[OpLogisticRegression]

  val (inputData, rawFeature1, feature2) = TestFeatureBuilder("label", "features",
    Seq[(RealNN, OPVector)](
      1.0.toRealNN -> Vectors.dense(12.0, 4.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.0, 0.3, 0.1).toOPVector,
      0.0.toRealNN -> Vectors.dense(1.0, 3.9, 4.3).toOPVector,
      1.0.toRealNN -> Vectors.dense(10.0, 1.3, 0.9).toOPVector,
      1.0.toRealNN -> Vectors.dense(15.0, 4.7, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.5, 0.9, 10.1).toOPVector,
      1.0.toRealNN -> Vectors.dense(11.5, 2.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.1, 3.3, 0.1).toOPVector
    )
  )

  val feature1 = rawFeature1.copy(isResponse = true)
  val estimator = new OpLogisticRegression().setInput(feature1, feature2)

  val expectedResult = Seq(
    Prediction(1.0, Array(-20.88, 20.88), Array(0.0, 1.0)),
    Prediction(0.0, Array(16.70, -16.7), Array(1.0, 0.0)),
    Prediction(0.0, Array(22.2, -22.2), Array(1.0, 0.0)),
    Prediction(1.0, Array(-18.35, 18.35), Array(0.0, 1.0)),
    Prediction(1.0, Array(-31.46, 31.46), Array(0.0, 1.0)),
    Prediction(0.0, Array(24.67, -24.67), Array(1.0, 0.0)),
    Prediction(1.0, Array(-22.07, 22.07), Array(0.0, 1.0)),
    Prediction(0.0, Array(20.9, -20.9), Array(1.0, 0.0))
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator
      .setRegParam(0.1)
      .setElasticNetParam(0.1)
      .setMaxIter(20)
    estimator.fit(inputData)

    estimator.predictor.getRegParam shouldBe 0.1
    estimator.predictor.getElasticNetParam shouldBe 0.1
    estimator.predictor.getMaxIter shouldBe 20
  }
}
Example 194
Source File: OpXGBoostClassifierTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.classification

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import ml.dmlc.xgboost4j.scala.spark.{OpXGBoostQuietLogging, XGBoostClassificationModel, XGBoostClassifier}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpXGBoostClassifierTest extends OpEstimatorSpec[Prediction, OpPredictorWrapperModel[XGBoostClassificationModel],
  OpPredictorWrapper[XGBoostClassifier, XGBoostClassificationModel]] with PredictionEquality
  with OpXGBoostQuietLogging {

  override def specName: String = Spec[OpXGBoostClassifier]

  val rawData = Seq(
    1.0 -> Vectors.dense(12.0, 4.3, 1.3),
    0.0 -> Vectors.dense(0.0, 0.3, 0.1),
    0.0 -> Vectors.dense(1.0, 3.9, 4.3),
    1.0 -> Vectors.dense(10.0, 1.3, 0.9),
    1.0 -> Vectors.dense(15.0, 4.7, 1.3),
    0.0 -> Vectors.dense(0.5, 0.9, 10.1),
    1.0 -> Vectors.dense(11.5, 2.3, 1.3),
    0.0 -> Vectors.dense(0.1, 3.3, 0.1)
  ).map { case (l, v) => l.toRealNN -> v.toOPVector }

  val (inputData, label, features) = TestFeatureBuilder("label", "features", rawData)

  val estimator = new OpXGBoostClassifier().setInput(label.copy(isResponse = true), features)
  estimator.setSilent(1)

  val expectedResult = Seq(
    Prediction(1.0, Array(-0.6200000047683716, 0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)),
    Prediction(0.0, Array(-0.3799999952316284, 0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)),
    Prediction(0.0, Array(-0.3799999952316284, 0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)),
    Prediction(1.0, Array(-0.6200000047683716, 0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)),
    Prediction(1.0, Array(-0.6200000047683716, 0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)),
    Prediction(0.0, Array(-0.3799999952316284, 0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)),
    Prediction(1.0, Array(-0.6200000047683716, 0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)),
    Prediction(0.0, Array(-0.3799999952316284, 0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284))
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator.setAlpha(0.872).setEta(0.99912)
    estimator.fit(inputData)
    estimator.predictor.getAlpha shouldBe 0.872
    estimator.predictor.getEta shouldBe 0.99912
  }
}
Example 195
Source File: OpNaiveBayesTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.classification

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpNaiveBayesTest extends OpEstimatorSpec[Prediction, OpPredictorWrapperModel[NaiveBayesModel],
  OpPredictorWrapper[NaiveBayes, NaiveBayesModel]] with PredictionEquality {

  override def specName: String = Spec[OpNaiveBayes]

  val (inputData, rawFeature1, feature2) = TestFeatureBuilder("label", "features",
    Seq[(RealNN, OPVector)](
      1.0.toRealNN -> Vectors.dense(12.0, 4.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.0, 0.3, 0.1).toOPVector,
      0.0.toRealNN -> Vectors.dense(1.0, 3.9, 4.3).toOPVector,
      1.0.toRealNN -> Vectors.dense(10.0, 1.3, 0.9).toOPVector,
      1.0.toRealNN -> Vectors.dense(15.0, 4.7, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.5, 0.9, 10.1).toOPVector,
      1.0.toRealNN -> Vectors.dense(11.5, 2.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.1, 3.3, 0.1).toOPVector
    )
  )

  val feature1 = rawFeature1.copy(isResponse = true)
  val estimator = new OpNaiveBayes().setInput(feature1, feature2)

  val expectedResult = Seq(
    Prediction(1.0, Array(-34.41, -14.85), Array(0.0, 1.0)),
    Prediction(0.0, Array(-1.07, -1.42), Array(0.58, 0.41)),
    Prediction(0.0, Array(-9.70, -17.99), Array(1.0, 0.0)),
    Prediction(1.0, Array(-26.22, -8.33), Array(0.0, 1.0)),
    Prediction(1.0, Array(-41.93, -16.49), Array(0.0, 1.0)),
    Prediction(0.0, Array(-8.60, -27.31), Array(1.0, 0.0)),
    Prediction(1.0, Array(-31.07, -11.44), Array(0.0, 1.0)),
    Prediction(0.0, Array(-4.54, -6.32), Array(0.85, 0.14))
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator.setSmoothing(2)
    estimator.fit(inputData)
    estimator.predictor.getSmoothing shouldBe 2
  }
}
Example 196
Source File: OpMultilayerPerceptronClassifierTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.classification

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.ml.classification.{MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpMultilayerPerceptronClassifierTest extends OpEstimatorSpec[Prediction,
  OpPredictorWrapperModel[MultilayerPerceptronClassificationModel],
  OpPredictorWrapper[MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel]]
  with PredictionEquality {

  override def specName: String = Spec[OpMultilayerPerceptronClassifier]

  val (inputData, rawFeature1, feature2) = TestFeatureBuilder("label", "features",
    Seq[(RealNN, OPVector)](
      1.0.toRealNN -> Vectors.dense(12.0, 4.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.0, 0.3, 0.1).toOPVector,
      0.0.toRealNN -> Vectors.dense(1.0, 3.9, 4.3).toOPVector,
      1.0.toRealNN -> Vectors.dense(10.0, 1.3, 0.9).toOPVector,
      1.0.toRealNN -> Vectors.dense(15.0, 4.7, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.5, 0.9, 10.1).toOPVector,
      1.0.toRealNN -> Vectors.dense(11.5, 2.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.1, 3.3, 0.1).toOPVector
    )
  )

  val feature1 = rawFeature1.copy(isResponse = true)
  val estimator = new OpMultilayerPerceptronClassifier()
    .setInput(feature1, feature2)
    .setLayers(Array(3, 5, 4, 2))

  val expectedResult = Seq(
    Prediction(1.0, Array(-9.655814651428148, 9.202335441336952), Array(6.456683124562021E-9, 0.9999999935433168)),
    Prediction(0.0, Array(9.475612761543069, -10.617525149157993), Array(0.9999999981221492, 1.877850786773977E-9)),
    Prediction(0.0, Array(9.715293827870028, -10.885255922155942), Array(0.9999999988694366, 1.130563392364822E-9)),
    Prediction(1.0, Array(-9.66776357765489, 9.215079716735316), Array(6.299199338896916E-9, 0.9999999937008006)),
    Prediction(1.0, Array(-9.668041712561456, 9.215387575592239), Array(6.2955091287182745E-9, 0.9999999937044908)),
    Prediction(0.0, Array(9.692904797559496, -10.860273756796797), Array(0.9999999988145918, 1.1854083109077814E-9)),
    Prediction(1.0, Array(-9.667687253240183, 9.214995747770411), Array(6.300209139771467E-9, 0.9999999936997908)),
    Prediction(0.0, Array(9.703097414537668, -10.872171694864653), Array(0.9999999988404908, 1.1595091005698914E-9))
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator.setMaxIter(50).setBlockSize(2).setSeed(42)
    estimator.fit(inputData)
    estimator.predictor.getMaxIter shouldBe 50
    estimator.predictor.getBlockSize shouldBe 2
    estimator.predictor.getSeed shouldBe 42
  }
}
Example 197
Source File: OpDecisionTreeClassifierTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.classification

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpDecisionTreeClassifierTest extends OpEstimatorSpec[Prediction,
  OpPredictorWrapperModel[DecisionTreeClassificationModel],
  OpPredictorWrapper[DecisionTreeClassifier, DecisionTreeClassificationModel]] with PredictionEquality {

  override def specName: String = Spec[OpDecisionTreeClassifier]

  val (inputData, rawFeature1, feature2) = TestFeatureBuilder("label", "features",
    Seq[(RealNN, OPVector)](
      1.0.toRealNN -> Vectors.dense(12.0, 4.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.0, 0.3, 0.1).toOPVector,
      0.0.toRealNN -> Vectors.dense(1.0, 3.9, 4.3).toOPVector,
      1.0.toRealNN -> Vectors.dense(10.0, 1.3, 0.9).toOPVector,
      1.0.toRealNN -> Vectors.dense(15.0, 4.7, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.5, 0.9, 10.1).toOPVector,
      1.0.toRealNN -> Vectors.dense(11.5, 2.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.1, 3.3, 0.1).toOPVector
    )
  )

  val feature1 = rawFeature1.copy(isResponse = true)
  val estimator = new OpDecisionTreeClassifier().setInput(feature1, feature2)

  val expectedResult = Seq(
    Prediction(1.0, Array(0.0, 4.0), Array(0.0, 1.0)),
    Prediction(0.0, Array(4.0, 0.0), Array(1.0, 0.0)),
    Prediction(0.0, Array(4.0, 0.0), Array(1.0, 0.0)),
    Prediction(1.0, Array(0.0, 4.0), Array(0.0, 1.0)),
    Prediction(1.0, Array(0.0, 4.0), Array(0.0, 1.0)),
    Prediction(0.0, Array(4.0, 0.0), Array(1.0, 0.0)),
    Prediction(1.0, Array(0.0, 4.0), Array(0.0, 1.0)),
    Prediction(0.0, Array(4.0, 0.0), Array(1.0, 0.0))
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator
      .setMaxDepth(6)
      .setMaxBins(2)
      .setMinInstancesPerNode(2)
      .setMinInfoGain(0.1)
    estimator.fit(inputData)

    estimator.predictor.getMaxDepth shouldBe 6
    estimator.predictor.getMaxBins shouldBe 2
    estimator.predictor.getMinInstancesPerNode shouldBe 2
    estimator.predictor.getMinInfoGain shouldBe 0.1
  }
}
Example 198
Source File: OpRandomForestClassifierTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.classification

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpRandomForestClassifierTest extends OpEstimatorSpec[Prediction,
  OpPredictorWrapperModel[RandomForestClassificationModel],
  OpPredictorWrapper[RandomForestClassifier, RandomForestClassificationModel]] with PredictionEquality {

  override def specName: String = Spec[OpRandomForestClassifier]

  lazy val (inputData, rawLabelMulti, featuresMulti) =
    TestFeatureBuilder[RealNN, OPVector]("labelMulti", "featuresMulti",
      Seq(
        (1.0.toRealNN, Vectors.dense(12.0, 4.3, 1.3).toOPVector),
        (0.0.toRealNN, Vectors.dense(0.0, 0.3, 0.1).toOPVector),
        (2.0.toRealNN, Vectors.dense(1.0, 3.9, 4.3).toOPVector),
        (2.0.toRealNN, Vectors.dense(10.0, 1.3, 0.9).toOPVector),
        (1.0.toRealNN, Vectors.dense(15.0, 4.7, 1.3).toOPVector),
        (0.0.toRealNN, Vectors.dense(0.5, 0.9, 10.1).toOPVector),
        (1.0.toRealNN, Vectors.dense(11.5, 2.3, 1.3).toOPVector),
        (0.0.toRealNN, Vectors.dense(0.1, 3.3, 0.1).toOPVector),
        (2.0.toRealNN, Vectors.dense(1.0, 4.0, 4.5).toOPVector),
        (2.0.toRealNN, Vectors.dense(10.0, 1.5, 1.0).toOPVector)
      )
    )

  val labelMulti = rawLabelMulti.copy(isResponse = true)
  val estimator = new OpRandomForestClassifier().setInput(labelMulti, featuresMulti)

  val expectedResult = Seq(
    Prediction(1.0, Array(0.0, 17.0, 3.0), Array(0.0, 0.85, 0.15)),
    Prediction(0.0, Array(19.0, 0.0, 1.0), Array(0.95, 0.0, 0.05)),
    Prediction(2.0, Array(0.0, 1.0, 19.0), Array(0.0, 0.05, 0.95)),
    Prediction(2.0, Array(1.0, 2.0, 17.0), Array(0.05, 0.1, 0.85)),
    Prediction(1.0, Array(0.0, 17.0, 3.0), Array(0.0, 0.85, 0.15)),
    Prediction(0.0, Array(16.0, 0.0, 4.0), Array(0.8, 0.0, 0.2)),
    Prediction(1.0, Array(1.0, 17.0, 2.0), Array(0.05, 0.85, 0.1)),
    Prediction(0.0, Array(17.0, 0.0, 3.0), Array(0.85, 0.0, 0.15)),
    Prediction(2.0, Array(2.0, 1.0, 17.0), Array(0.1, 0.05, 0.85)),
    Prediction(2.0, Array(1.0, 2.0, 17.0), Array(0.05, 0.1, 0.85))
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator
      .setMaxDepth(10)
      .setImpurity(Impurity.Gini.sparkName)
      .setMaxBins(33)
      .setMinInstancesPerNode(2)
      .setMinInfoGain(0.2)
      .setSubsamplingRate(0.9)
      .setNumTrees(21)
      .setSeed(2L)
    estimator.fit(inputData)

    estimator.predictor.getMaxDepth shouldBe 10
    estimator.predictor.getMaxBins shouldBe 33
    estimator.predictor.getImpurity shouldBe Impurity.Gini.sparkName
    estimator.predictor.getMinInstancesPerNode shouldBe 2
    estimator.predictor.getMinInfoGain shouldBe 0.2
    estimator.predictor.getSubsamplingRate shouldBe 0.9
    estimator.predictor.getNumTrees shouldBe 21
    estimator.predictor.getSeed shouldBe 2L
  }
}
Example 199
Source File: OpGBTClassifierTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.classification

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpGBTClassifierTest extends OpEstimatorSpec[Prediction, OpPredictorWrapperModel[GBTClassificationModel],
  OpPredictorWrapper[GBTClassifier, GBTClassificationModel]] with PredictionEquality {

  override def specName: String = Spec[OpGBTClassifier]

  val (inputData, rawFeature1, feature2) = TestFeatureBuilder("label", "features",
    Seq[(RealNN, OPVector)](
      1.0.toRealNN -> Vectors.dense(12.0, 4.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.0, 0.3, 0.1).toOPVector,
      0.0.toRealNN -> Vectors.dense(1.0, 3.9, 4.3).toOPVector,
      1.0.toRealNN -> Vectors.dense(10.0, 1.3, 0.9).toOPVector,
      1.0.toRealNN -> Vectors.dense(15.0, 4.7, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.5, 0.9, 10.1).toOPVector,
      1.0.toRealNN -> Vectors.dense(11.5, 2.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.1, 3.3, 0.1).toOPVector
    )
  )

  val feature1 = rawFeature1.copy(isResponse = true)
  val estimator = new OpGBTClassifier().setInput(feature1, feature2)

  val expectedResult = Seq(
    Prediction(1.0, Array(-1.54, 1.54), Array(0.04, 0.95)),
    Prediction(0.0, Array(1.54, -1.54), Array(0.95, 0.04)),
    Prediction(0.0, Array(1.54, -1.54), Array(0.95, 0.04)),
    Prediction(1.0, Array(-1.54, 1.54), Array(0.04, 0.95)),
    Prediction(1.0, Array(-1.54, 1.54), Array(0.04, 0.95)),
    Prediction(0.0, Array(1.54, -1.54), Array(0.95, 0.04)),
    Prediction(1.0, Array(-1.54, 1.54), Array(0.04, 0.95)),
    Prediction(0.0, Array(1.54, -1.54), Array(0.95, 0.04))
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator
      .setMaxIter(10)
      .setMaxDepth(6)
      .setMaxBins(2)
      .setMinInstancesPerNode(2)
      .setMinInfoGain(0.1)
    estimator.fit(inputData)

    estimator.predictor.getMaxIter shouldBe 10
    estimator.predictor.getMaxDepth shouldBe 6
    estimator.predictor.getMaxBins shouldBe 2
    estimator.predictor.getMinInstancesPerNode shouldBe 2
    estimator.predictor.getMinInfoGain shouldBe 0.1
  }
}
Example 200
Source File: PredictionDeIndexerTest.scala From TransmogrifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
package com.salesforce.op.stages.impl.preparators

import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.unary.UnaryLambdaTransformer
import com.salesforce.op.stages.impl.feature.OpStringIndexerNoFilter
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class PredictionDeIndexerTest extends FlatSpec with TestSparkContext {

  val data = Seq(("a", 0.0), ("b", 1.0), ("c", 2.0)).map { case (txt, num) => (txt.toText, num.toRealNN) }
  val (ds, txtF, numF) = TestFeatureBuilder(data)
  val response = txtF.indexed()
  val indexedData = response.originStage.asInstanceOf[OpStringIndexerNoFilter[_]].fit(ds).transform(ds)

  val permutation = new UnaryLambdaTransformer[RealNN, RealNN](
    operationName = "modulo",
    transformFn = v => ((v.value.get + 1).toInt % 3).toRealNN
  ).setInput(response)
  val pred = permutation.getOutput()
  val permutedData = permutation.transform(indexedData)

  val expected = Array("b", "c", "a").map(_.toText)

  Spec[PredictionDeIndexer] should "deindex the feature correctly" in {
    val predDeIndexer = new PredictionDeIndexer().setInput(response, pred)
    val deIndexed = predDeIndexer.getOutput()

    val results = predDeIndexer.fit(permutedData).transform(permutedData).collect(deIndexed)
    results shouldBe expected
  }

  it should "throw a nice error when there is no metadata" in {
    val predDeIndexer = new PredictionDeIndexer().setInput(numF, pred)
    the[Error] thrownBy {
      predDeIndexer.fit(permutedData).transform(permutedData)
    } should have message
      s"The feature ${numF.name} does not contain any label/index mapping in its metadata"
  }
}
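Every example above relies on the same mechanism: annotating a ScalaTest spec with @RunWith(classOf[JUnitRunner]) so that JUnit-based tooling (Maven Surefire, Gradle, IDE JUnit runners) can discover and execute it. The following is a minimal sketch of that pattern stripped of any project-specific base classes; the package, class name, and assertion are hypothetical, and it assumes ScalaTest 3.0.x with JUnit on the test classpath.

package com.example.sketch

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}

// Hypothetical minimal spec: the @RunWith annotation delegates execution of this
// ScalaTest FlatSpec to JUnit via ScalaTest's JUnitRunner, which is the shared
// pattern in all of the examples in this document.
@RunWith(classOf[JUnitRunner])
class MinimalRunWithSpec extends FlatSpec with Matchers {
  "A minimal spec" should "run under the JUnit runner" in {
    Seq(1, 2, 3).sum shouldBe 6
  }
}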