Example 1
Source File: ScalaUDFSuite.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst.expressions

import java.util.Locale

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Tests evaluation, error reporting and code generation of ScalaUDF expressions.
 *
 * NOTE(review): this excerpt looks truncated (closing braces and some arguments are not
 * visible); the comments describe only what the visible lines do.
 */
class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

  // Wraps plain Scala closures in ScalaUDF over literal children and checks the results.
  test("basic") {
    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
    checkEvaluation(intUdf, 2)

    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
    checkEvaluation(stringUdf, "ax")

  // The closure dereferences its argument and the only child is a null string literal;
  // both evaluation paths below must fail with a SparkException carrying a clear message.
  test("better error message for NPE") {
    val udf = ScalaUDF(
      (s: String) => s.toLowerCase(Locale.ROOT),
      Literal.create(null, StringType) :: Nil)

    // Direct eval() path.
    val e1 = intercept[SparkException](udf.eval())
    assert(e1.getMessage.contains("Failed to execute user defined function"))

    // Path through the checkEvalutionWithUnsafeProjection helper.
    val e2 = intercept[SparkException] {
      checkEvalutionWithUnsafeProjection(udf, null)
    assert(e2.getMessage.contains("Failed to execute user defined function"))

  // Generates code for a UDF against a fresh CodegenContext; the follow-up assertion on the
  // context is not visible in this excerpt.
  test("SPARK-22695: ScalaUDF should not use global variables") {
    val ctx = new CodegenContext
    ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx)
Example 2
Source File: HashSetManager.scala (from BigDatalog, Apache License 2.0; 5 votes)

import edu.ucla.cs.wis.bigdatalog.spark.SchemaInfo
import org.apache.spark.TaskContext
import org.apache.spark.sql.types.{IntegerType, LongType}

/**
 * Chooses a hash-set implementation from a relation's schema.
 *
 * Key-type codes returned by `determineKeyType` and consumed by `create`:
 * 1 selects IntKeysHashSet, 2 selects LongKeysHashSet, anything else selects ObjectHashSet.
 */
object HashSetManager {
  // Single column: IntegerType -> 1, LongType -> 2, any other type -> 3.
  // Two columns: 2 when bytesPerKey == 8, otherwise 3 (how bytesPerKey is computed is not
  // visible in this excerpt). Any other arity -> 3.
  def determineKeyType(schemaInfo: SchemaInfo): Int = {
    schemaInfo.arity match {
      case 1 => {
        schemaInfo.schema(0).dataType match {
          case IntegerType => 1
          case LongType => 2
          case other => 3
      case 2 => {
        val bytesPerKey =
        if (bytesPerKey == 8) 2 else 3
      case other => 3

  // Instantiates the set matching the key-type code; only LongKeysHashSet receives schemaInfo.
  def create(schemaInfo: SchemaInfo): HashSet = {
    determineKeyType(schemaInfo) match {
      case 1 => new IntKeysHashSet()
      case 2 => new LongKeysHashSet(schemaInfo)
      case _ => new ObjectHashSet()
Example 3
Source File: HashSetRowIterator.scala (from BigDatalog, Apache License 2.0; 5 votes)

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * Iterator[InternalRow] backed by the iterator of an ObjectHashSet.
 *
 * NOTE(review): the bodies of hasNext/next are not visible in this excerpt.
 */
class ObjectHashSetRowIterator(set: ObjectHashSet) extends Iterator[InternalRow] {
  // Underlying set iterator, obtained once at construction.
  val rawIter = set.iterator()

  override final def hasNext(): Boolean = {

  override final def next(): InternalRow = {

/**
 * Iterator[InternalRow] over an IntKeysHashSet.
 *
 * Holds one UnsafeRow, BufferHolder and UnsafeRowWriter as fields; next() re-initializes the
 * writer and re-points the row on every call.
 *
 * NOTE(review): hasNext's body and the statement that writes the key between initialize and
 * pointTo are not visible in this excerpt.
 */
class IntKeysHashSetRowIterator(set: IntKeysHashSet) extends Iterator[InternalRow] {
  val rawIter = set.iterator()
  val uRow = new UnsafeRow()
  val bufferHolder = new BufferHolder()
  val rowWriter = new UnsafeRowWriter()

  override final def hasNext(): Boolean = {

  override final def next(): InternalRow = {
    // One field per row.
    rowWriter.initialize(bufferHolder, 1)

    uRow.pointTo(bufferHolder.buffer, 1, bufferHolder.totalSize())

/**
 * Iterator[InternalRow] over a LongKeysHashSet.
 *
 * Each stored long becomes one row whose field count is the set's schema arity: with two
 * fields the long is split into two ints, otherwise it is written as a single long field.
 *
 * NOTE(review): hasNext's body and the returned value of next() are not visible here.
 */
class LongKeysHashSetRowIterator(set: LongKeysHashSet) extends Iterator[InternalRow] {
  val rawIter = set.iterator()
  val numFields = set.schemaInfo.arity
  val uRow = new UnsafeRow()
  val bufferHolder = new BufferHolder()
  val rowWriter = new UnsafeRowWriter()

  override final def hasNext(): Boolean = {

  override final def next(): InternalRow = {
    rowWriter.initialize(bufferHolder, numFields)

    val value = rawIter.nextLong()
    if (numFields == 2) {
      // Upper 32 bits go to field 0, lower 32 bits to field 1.
      rowWriter.write(0, (value >> 32).toInt)
      rowWriter.write(1, value.toInt)
    } else {
      rowWriter.write(0, value)
    uRow.pointTo(bufferHolder.buffer, numFields, bufferHolder.totalSize())

/** Factory that picks the row iterator matching the concrete HashSet implementation. */
object HashSetRowIterator {
  // Dispatches on the runtime type of `set`; the UnsafeFixedWidthSet case is commented out.
  def create(set: HashSet): Iterator[InternalRow] = {
    set match {
      //case set: UnsafeFixedWidthSet => set.iterator().asScala
      case set: IntKeysHashSet => new IntKeysHashSetRowIterator(set)
      case set: LongKeysHashSet => new LongKeysHashSetRowIterator(set)
      case set: ObjectHashSet => new ObjectHashSetRowIterator(set)
Example 4
Source File: MockSourceProvider.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)
package org.apache.spark.sql.streaming.util

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Source
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * StreamSourceProvider used with the companion object's mock-source hook.
 *
 * NOTE(review): the body of createSource is not visible in this excerpt.
 */
class MockSourceProvider extends StreamSourceProvider {
  // Always reports the fixed name "dummySource" with the companion's fakeSchema; none of the
  // arguments are used by the visible code.
  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", MockSourceProvider.fakeSchema)

  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {

/** Holds the source-producing function and the fixed schema used by the provider class. */
object MockSourceProvider {
  // Function to generate sources. May provide multiple sources if the user implements such a
  // function.
  private var sourceProviderFunction: () => Source = _

  // Schema with a single IntegerType column named "a".
  final val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  // Installs a function that cycles through `source +: otherSources` in round-robin order
  // (index `i % sources.length`, incremented on each call), and clears the function again in
  // the `finally` clause.
  // NOTE(review): the body of the `try` (presumably running `f`) is not visible here.
  def withMockSources(source: Source, otherSources: Source*)(f: => Unit): Unit = {
    var i = 0
    val sources = source +: otherSources
    sourceProviderFunction = () => {
      val source = sources(i % sources.length)
      i += 1
    try {
    } finally {
      sourceProviderFunction = null
Example 5
Source File: BlockingSource.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)
package org.apache.spark.sql.streaming.util

import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.{SQLContext, _}
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * Stream source and sink provider that hands out stub Source and Sink implementations.
 *
 * NOTE(review): the body of getBatch is truncated in this excerpt, so whatever blocking
 * behaviour the class name refers to is not visible here.
 */
class BlockingSource extends StreamSourceProvider with StreamSinkProvider {

  // Schema with a single IntegerType column named "a".
  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  // Always reports the fixed name "dummySource" together with fakeSchema.
  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", fakeSchema)

  // Returns an anonymous Source with a constant schema and a constant offset of 0.
  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    new Source {
      override def schema: StructType = fakeSchema
      override def getOffset: Option[Offset] = Some(new LongOffset(0))
      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
        import spark.implicits._
      // Nothing to release.
      override def stop() {}

  // Returns a Sink whose addBatch ignores every batch.
  override def createSink(
      spark: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink {
      override def addBatch(batchId: Long, data: DataFrame): Unit = {}

object BlockingSource {
  // Mutable, initially-null latch shared through the companion object.
  // NOTE(review): presumably assigned by tests before the source is used; no reader or
  // writer of this field is visible in this excerpt, so confirm against callers.
  var latch: CountDownLatch = null
Example 6
Source File: GroupedIteratorSuite.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

/**
 * Tests for GroupedIterator over rows encoded with a RowEncoder.
 *
 * Visible expectations:
 *  - "basic": each key has exactly one field; the expected result pairs key 1 with input
 *    rows 0 and 1, and key 2 with input row 2.
 *  - "group by 2 columns": each key has two fields (an int and a long); four groups are
 *    expected, keyed (1, 2L), (1, 3L), (2, 1L) and (3, 2L).
 *  - "do nothing to the value iterator": only the number of groups (2) is asserted.
 *
 * NOTE(review): the GroupedIterator arguments and the result-building expressions are garbled
 * in this excerpt (e.g. `GroupedIterator(,`), so the grouping expressions cannot be confirmed
 * from the visible text.
 */
class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(,
      Seq(', schema.toAttributes)

    val result = {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) ->

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(,
      Seq(', ', schema.toAttributes)

    val result = {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1),

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(,
      Seq(', schema.toAttributes)

    assert(grouped.length == 2)
Example 7
Source File: resources.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)
package org.apache.spark.sql.execution.command


import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

/**
 * Runnable command that lists jars known to the SparkContext, optionally restricted to the
 * jar names given in `jars`.
 *
 * NOTE(review): the generator for `jarName` and the `else` branch are garbled or missing in
 * this excerpt.
 */
case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
  // Single non-nullable string column named "Results".
  override val output: Seq[Attribute] = {
    AttributeReference("Results", StringType, nullable = false)() :: Nil
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val jarList = sparkSession.sparkContext.listJars()
    if (jars.nonEmpty) {
      // Emits one row per (jarName, jarPath) pair where the listed path contains the name.
      for {
        jarName <- => new Path(f).getName)
        jarPath <- jarList if jarPath.contains(jarName)
      } yield Row(jarPath)
    } else {
Example 8
Source File: SimplifyConditionalSuite.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{IntegerType, NullType}

/**
 * Tests for the SimplifyConditionals optimizer rule on If and CaseWhen expressions.
 *
 * NOTE(review): the call heads of most assertions (and some expected values) are missing in
 * this excerpt, so only the input expressions are visible for several cases.
 */
class SimplifyConditionalSuite extends PlanTest with PredicateHelper {

  // Runs only the SimplifyConditionals rule, to a fixed point of at most 50 iterations.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil

  // Projects each expression (aliased as "out") over a OneRowRelation, optimizes the plan
  // built from e1 and compares it with the analyzed, unoptimized plan built from e2.
  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze)
    comparePlans(actual, correctAnswer)

  // (condition, value) fixtures: a literal-true condition, a non-foldable true condition,
  // a literal-false condition and a null condition of NullType.
  private val trueBranch = (TrueLiteral, Literal(5))
  private val normalBranch = (NonFoldableLiteral(true), Literal(10))
  private val unreachableBranch = (FalseLiteral, Literal(20))
  private val nullBranch = (Literal.create(null, NullType), Literal(30))

  // Inputs cover If with a true, a false and a null condition.
  test("simplify if") {
      If(TrueLiteral, Literal(10), Literal(20)),

      If(FalseLiteral, Literal(10), Literal(20)),

      If(Literal.create(null, NullType), Literal(10), Literal(20)),

  test("remove unreachable branches") {
    // i.e. removing branches whose conditions are always false
      CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
      CaseWhen(normalBranch :: Nil, None))

  test("remove entire CaseWhen if only the else branch is reachable") {
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),

      CaseWhen(unreachableBranch :: unreachableBranch :: Nil, None),
      Literal.create(null, IntegerType))

  test("remove entire CaseWhen if the first branch is always true") {
      CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),

    // Test branch elimination and simplification in combination
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
        :: Nil, None),

    // Make sure this doesn't trigger if there is a non-foldable branch before the true branch
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None),
      CaseWhen(normalBranch :: trueBranch :: Nil, None))

  // Expected result keeps the non-foldable branch and the true branch only.
  test("simplify CaseWhen, prune branches following a definite true") {
      CaseWhen(normalBranch :: unreachableBranch ::
        unreachableBranch :: nullBranch ::
        trueBranch :: normalBranch ::
      CaseWhen(normalBranch :: trueBranch :: Nil, None))
Example 9
Source File: RewriteDistinctAggregatesSuite.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Tests for the RewriteDistinctAggregates rule.
 *
 * Visible expectations:
 *  - checkRewrite accepts only a plan shaped Aggregate(Aggregate(Expand)) and fails otherwise
 *    with a message that includes the plan.
 *  - The two "single distinct group" tests compare the rule's output with its input via
 *    comparePlans.
 *  - The "multiple distinct groups" tests build aggregates with more than one countDistinct;
 *    their assertions are not visible in this excerpt.
 *
 * NOTE(review): testRelation's column list and several test bodies are cut off here.
 */
class RewriteDistinctAggregatesSuite extends PlanTest {
  // Case-insensitive resolution with group-by-ordinal switched off.
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, '

  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")

  test("single distinct group") {
    val input = testRelation
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
        countDistinct('b, 'c),
Example 10
Source File: ComplexDataSuite.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst.util

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow, SpecificInternalRow, UnsafeMapData, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

/**
 * Tests inequality of MapData values and the independence of copies made from rows and
 * arrays that hold data read out of an UnsafeRow.
 *
 * NOTE(review): in the copy tests the statement that mutates the source between the two
 * identical assertions is not visible in this excerpt.
 */
class ComplexDataSuite extends SparkFunSuite {
  // Shorthand for building a UTF8String from a Java string.
  def utf8(str: String): UTF8String = UTF8String.fromString(str)

  // Maps 1/3 and 2/4 have identical contents, yet the assertions require the resulting
  // MapData values to compare as not equal.
  test("inequality tests for MapData") {
    // test data
    val testMap1 = Map(utf8("key1") -> 1)
    val testMap2 = Map(utf8("key1") -> 1, utf8("key2") -> 2)
    val testMap3 = Map(utf8("key1") -> 1)
    val testMap4 = Map(utf8("key1") -> 1, utf8("key2") -> 2)

    // ArrayBasedMapData
    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
    assert(testArrayMap1 !== testArrayMap3)
    assert(testArrayMap2 !== testArrayMap4)

    // UnsafeMapData
    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    val row = new GenericInternalRow(1)
    // Stores the map in the single-column row and runs it through the unsafe projection;
    // the extraction of the UnsafeMapData from unsafeRow is not visible in this excerpt.
    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
      row.update(0, map)
      val unsafeRow = unsafeConverter.apply(row)
    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))

  test("GenericInternalRow.copy return a new instance that is independent from the old one") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))

    // The generic row holds the UTF8String obtained from the unsafe row.
    val genericRow = new GenericInternalRow(Array[Any](unsafeRow.getUTF8String(0)))
    val copiedGenericRow = genericRow.copy()
    assert(copiedGenericRow.getString(0) == "a")
    // The copied internal row should not be changed externally.
    assert(copiedGenericRow.getString(0) == "a")

  test("SpecificMutableRow.copy return a new instance that is independent from the old one") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))

    val mutableRow = new SpecificInternalRow(Seq(StringType))
    mutableRow(0) = unsafeRow.getUTF8String(0)
    val copiedMutableRow = mutableRow.copy()
    assert(copiedMutableRow.getString(0) == "a")
    // The copied internal row should not be changed externally.
    assert(copiedMutableRow.getString(0) == "a")

  test("GenericArrayData.copy return a new instance that is independent from the old one") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))

    val genericArray = new GenericArrayData(Array[Any](unsafeRow.getUTF8String(0)))
    val copiedGenericArray = genericArray.copy()
    assert(copiedGenericArray.getUTF8String(0).toString == "a")
    // The copied array data should not be changed externally.
    assert(copiedGenericArray.getUTF8String(0).toString == "a")

  // Same check one level deeper: an array whose only element is a row holding the string.
  test("copy on nested complex type") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))

    val arrayOfRow = new GenericArrayData(Array[Any](InternalRow(unsafeRow.getUTF8String(0))))
    val copied = arrayOfRow.copy()
    assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a")
    // The copied data should not be changed externally.
    assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a")
Example 11
Source File: RandomSuite.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst.expressions

import org.scalatest.Matchers._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, LongType}

/**
 * Tests that Rand and Randn produce the expected first value for fixed seeds, within a
 * tolerance of 0.001.
 *
 * NOTE(review): the call heads for the two null-seed checks are missing in this excerpt.
 */
class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("random") {
    // Int seed 30.
    checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- 0.001)
    checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001)

    // Null seeds, given as a LongType literal for Rand and an IntegerType literal for Randn.
      new Rand(Literal.create(null, LongType)), 0.8446490682263027 +- 0.001)
      new Randn(Literal.create(null, IntegerType)), 1.1164209726833079 +- 0.001)

  // Seed that needs a Long literal (it has the L suffix).
  test("SPARK-9127 codegen with long seed") {
    checkDoubleEvaluation(Rand(5419823303878592871L), 0.2304755080444375 +- 0.001)
    checkDoubleEvaluation(Randn(5419823303878592871L), -1.2824262718225607 +- 0.001)
Example 12
Source File: ObjectExpressionsSuite.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{IntegerType, ObjectType}

/**
 * Tests for object expressions: Invoke returning null, and serializers built by
 * ExpressionEncoder for arrays of structs, arrays and maps.
 *
 * NOTE(review): the call heads of the three checks in the MapObjects test are missing in
 * this excerpt; only their argument lists are visible.
 */
class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  // Invokes `_2` on a (Boolean, java.lang.Integer) tuple whose second element is null and
  // expects the expression to evaluate to null.
  test("SPARK-16622: The returned value of the called method in Invoke can be null") {
    val inputRow = InternalRow.fromSeq(Seq((false, null)))
    val cls = classOf[Tuple2[Boolean, java.lang.Integer]]
    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
    val invoke = Invoke(inputObject, "_2", IntegerType)
    checkEvaluationWithGeneratedMutableProjection(invoke, null, inputRow)

  test("MapObjects should make copies of unsafe-backed data") {
    // test UnsafeRow-backed data
    val structEncoder = ExpressionEncoder[Array[Tuple2[java.lang.Integer, java.lang.Integer]]]
    val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4))))
    val structExpected = new GenericArrayData(
      Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4))))
      structEncoder.serializer.head, structExpected, structInputRow)

    // test UnsafeArray-backed data
    val arrayEncoder = ExpressionEncoder[Array[Array[Int]]]
    val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4))))
    val arrayExpected = new GenericArrayData(
      Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4))))
      arrayEncoder.serializer.head, arrayExpected, arrayInputRow)

    // test UnsafeMap-backed data
    val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]]
    val mapInputRow = InternalRow.fromSeq(Seq(Array(
      Map(1 -> 100, 2 -> 200), Map(3 -> 300, 4 -> 400))))
    // Each expected map is given as parallel key and value arrays.
    val mapExpected = new GenericArrayData(Seq(
      new ArrayBasedMapData(
        new GenericArrayData(Array(1, 2)),
        new GenericArrayData(Array(100, 200))),
      new ArrayBasedMapData(
        new GenericArrayData(Array(3, 4)),
        new GenericArrayData(Array(300, 400)))))
      mapEncoder.serializer.head, mapExpected, mapInputRow)
Example 13
Source File: ExpressionEvalHelperSuite.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}

/**
 * Non-nullable leaf expression of IntegerType whose interpreted result is the constant 10
 * and whose generated code declares an extra local (`some_variable`) before assigning 10 to
 * the result variable.
 *
 * NOTE(review): the name suggests the generated code is deliberately faulty, presumably to
 * exercise the test helper; the string delimiters of the code template are missing in this
 * excerpt, so confirm against the full file.
 */
case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  // Interpreted path: ignores the input row.
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
        |int some_variable = 11;
        |int ${ev.value} = 10;
  override def dataType: DataType = IntegerType
Example 14
Source File: MiscFunctionsSuite.scala (from BigDatalog, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.codec.digest.DigestUtils

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType}

/**
 * Tests the Md5, Sha1, Sha2 and Crc32 expressions against known digests, null inputs and
 * (where shown) interpreted-vs-codegen consistency.
 *
 * NOTE(review): the expected value of the byte-array case is cut off in the md5, sha1 and
 * crc32 tests in this excerpt.
 */
class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("md5") {
    checkEvaluation(Md5(Literal("ABC".getBytes)), "902fbdd2b1df0c4f70b4a5d23525e932")
    checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
    checkEvaluation(Md5(Literal.create(null, BinaryType)), null)
    checkConsistencyBetweenInterpretedAndCodegen(Md5, BinaryType)

  test("sha1") {
    checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")
    checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
    checkEvaluation(Sha1(Literal.create(null, BinaryType)), null)
    // Empty input.
    checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709")
    checkConsistencyBetweenInterpretedAndCodegen(Sha1, BinaryType)

  // Expected digests come from commons-codec DigestUtils; a null input or a null bit length
  // is expected to yield null.
  test("sha2") {
    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC"))
    checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)),
      DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6)))
    // unsupported bit length
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null)
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null)
    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null)
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null)

  // The expected checksum is a Long.
  test("crc32") {
    checkEvaluation(Crc32(Literal("ABC".getBytes)), 2743272264L)
    checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
    checkEvaluation(Crc32(Literal.create(null, BinaryType)), null)
    checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
Example 15
Source File: QueryPlanSuite.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.types.IntegerType

/** Tests that QueryPlan.mapExpressions keeps the origin of the expressions it replaces. */
class QueryPlanSuite extends SparkFunSuite {

  test("origin remains the same after mapExpressions (SPARK-23823)") {
    // Expressions created from here on are tagged with position (0, 0).
    CurrentOrigin.setPosition(0, 0)
    val column = AttributeReference("column", IntegerType)(NamedExpression.newExprId)
    val query = plans.DslLogicalPlan(plans.table("table")).select(column)

    // Replace every expression with a fresh literal.
    val mappedQuery = query mapExpressions {
      case _: Expression => Literal(1)

    // The replacement must still report the (0, 0) origin.
    val mappedOrigin = mappedQuery.expressions.apply(0).origin
    assert(mappedOrigin == Origin.apply(Some(0), Some(0)))

Example 16
Source File: LogicalPlanSuite.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType

/**
 * Tests how transformUp/transformDown visit Project operators (including plans wrapped in
 * AnalysisBarrier) and how isStreaming propagates through a binary node.
 */
class LogicalPlanSuite extends SparkFunSuite {
  // Number of times `function` matched a Project; reset by each test before every transform.
  private var invocationCount = 0
  // Counts every Project it is applied to.
  // NOTE(review): the value returned for the matched plan is not visible in this excerpt.
  private val function: PartialFunction[LogicalPlan, LogicalPlan] = {
    case p: Project =>
      invocationCount += 1

  private val testRelation = LocalRelation()

  // One Project => one invocation, for both traversal directions.
  test("transformUp runs on operators") {
    invocationCount = 0
    val plan = Project(Nil, testRelation)
    plan transformUp function

    assert(invocationCount === 1)

    invocationCount = 0
    plan transformDown function
    assert(invocationCount === 1)

  // Two nested Projects => two invocations.
  test("transformUp runs on operators recursively") {
    invocationCount = 0
    val plan = Project(Nil, Project(Nil, testRelation))
    plan transformUp function

    assert(invocationCount === 2)

    invocationCount = 0
    plan transformDown function
    assert(invocationCount === 2)

  // Everything sits below the barrier => no invocation.
  test("transformUp skips all ready resolved plans wrapped in analysis barrier") {
    invocationCount = 0
    val plan = AnalysisBarrier(Project(Nil, Project(Nil, testRelation)))
    plan transformUp function

    assert(invocationCount === 0)

    invocationCount = 0
    plan transformDown function
    assert(invocationCount === 0)

  // Only the Project above the barrier is visited => one invocation.
  test("transformUp skips partially resolved plans wrapped in analysis barrier") {
    invocationCount = 0
    val plan1 = AnalysisBarrier(Project(Nil, testRelation))
    val plan2 = Project(Nil, plan1)
    plan2 transformUp function

    assert(invocationCount === 1)

    invocationCount = 0
    plan2 transformDown function
    assert(invocationCount === 1)

  // A binary node is expected to be streaming when at least one child is streaming.
  test("isStreaming") {
    val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
    val incrementalRelation = LocalRelation(
      Seq(AttributeReference("a", IntegerType, nullable = true)()), isStreaming = true)

    case class TestBinaryRelation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
      override def output: Seq[Attribute] = left.output ++ right.output

    // Preconditions on the fixtures themselves.
    require(relation.isStreaming === false)
    require(incrementalRelation.isStreaming === true)
    assert(TestBinaryRelation(relation, relation).isStreaming === false)
    assert(TestBinaryRelation(incrementalRelation, relation).isStreaming === true)
    assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true)
    assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming)
Example 17
Source File: StatsEstimationTestBase.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst.statsEstimation

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Base trait for stats-estimation tests: turns CBO on for the suite, restores the previous
 * setting afterwards, and offers small helpers for attributes and column sizes.
 *
 * NOTE(review): no super.beforeAll()/super.afterAll() call is visible in this excerpt.
 */
trait StatsEstimationTestBase extends SparkFunSuite {

  // CBO_ENABLED value captured in beforeAll so afterAll can restore it.
  var originalValue: Boolean = false

  override def beforeAll(): Unit = {
    // Enable stats estimation based on CBO.
    originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
    SQLConf.get.setConf(SQLConf.CBO_ENABLED, true)

  override def afterAll(): Unit = {
    SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)

  // Per-value size of a column: avgLen plus a fixed 12 bytes for strings, avgLen otherwise.
  def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match {
    // For UTF8String: base + offset + numBytes
    case StringType => colStat.avgLen + 8 + 4
    case _ => colStat.avgLen

  // Builds an IntegerType attribute reference with the given name.
  def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)()

/**
 * Leaf plan node whose statistics are supplied directly by the test.
 *
 * @param outputList     attributes reported as the node's output
 * @param rowCount       row count placed in the computed Statistics
 * @param attributeStats per-attribute column statistics placed in the computed Statistics
 * @param size           optional sizeInBytes; Int.MaxValue is used when absent
 */
case class StatsTestPlan(
    outputList: Seq[Attribute],
    rowCount: BigInt,
    attributeStats: AttributeMap[ColumnStat],
    size: Option[BigInt] = None) extends LeafNode {
  override def output: Seq[Attribute] = outputList
  override def computeStats(): Statistics = Statistics(
    // If sizeInBytes is useless in testing, we just use a fake value
    sizeInBytes = size.getOrElse(Int.MaxValue),
    rowCount = Some(rowCount),
    attributeStats = attributeStats)
Example 18
Source File: SubstituteUnresolvedOrdinals.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType

/**
 * Rule that rewrites integer literals used as ORDER BY / GROUP BY positions so that they are
 * marked as ordinals (`UnresolvedOrdinal`), each rewrite being gated by a flag on `conf`.
 *
 * NOTE(review): this listing looks truncated -- closing braces are absent, `newOrders` and
 * `newGroups` open a bare case block with no collection being mapped, and the first
 * `newGroups` case has no body. Confirm against the upstream file before relying on it.
 */
class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] {
  // True only for a Literal whose data type is IntegerType.
  private def isIntLiteral(e: Expression) = e match {
    case Literal(_, IntegerType) => true
    case _ => false

  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    // Sort nodes: only touched when orderByOrdinal is on and some sort key is an int literal.
    case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
      val newOrders = {
        // Swap the literal child for an UnresolvedOrdinal, keeping the origin of both the
        // literal and the enclosing SortOrder.
        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) =>
          val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
          withOrigin(order.origin)(order.copy(child = newOrdinal))
        case other => other
      withOrigin(s.origin)(s.copy(order = newOrders))

    // Aggregate nodes: only touched when groupByOrdinal is on and some grouping expression
    // is an int literal.
    case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) =>
      val newGroups = {
        // NOTE(review): the body of this case is missing; presumably it mirrors the Sort
        // branch -- verify.
        case ordinal @ Literal(index: Int, IntegerType) =>
        case other => other
      withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
Example 19
Source File: SparkExecuteStatementOperationSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.thriftserver

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, NullType, StringType, StructField, StructType}

/**
 * Checks the column descriptors that `SparkExecuteStatementOperation.getTableSchema` produces
 * for a given Spark schema.
 */
class SparkExecuteStatementOperationSuite extends SparkFunSuite {

  test("SPARK-17112 `select null` via JDBC triggers IllegalArgumentException in ThriftServer") {
    // Both fields are NullType; each must come back as Hive's NULL_TYPE.
    val field1 = StructField("NULL", NullType)
    val field2 = StructField("(IF(true, NULL, NULL))", NullType)
    val tableSchema = StructType(Seq(field1, field2))
    val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors()
    assert(columns.size() == 2)
    assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.NULL_TYPE)
    assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.NULL_TYPE)
  }

  test("SPARK-20146 Comment should be preserved") {
    // Only the first field carries a comment; the second is expected to report "".
    val field1 = StructField("column1", StringType).withComment("comment 1")
    val field2 = StructField("column2", IntegerType)
    val tableSchema = StructType(Seq(field1, field2))
    val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors()
    assert(columns.size() == 2)
    assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.STRING_TYPE)
    assert(columns.get(0).getComment() == "comment 1")
    assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.INT_TYPE)
    assert(columns.get(1).getComment() == "")
  }
}
Example 20
Source File: CustomSchemaTest.scala    From spark-sftp   with Apache License 2.0 5 votes vote down vote up
package com.springml.spark.sftp

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, _}
import org.scalatest.{BeforeAndAfterEach, FunSuite}

/**
 * Reads sample CSV and JSON resources through `DatasetRelation` with an explicit schema and
 * checks the resulting field count and field types.
 *
 * NOTE(review): this listing looks truncated -- closing parentheses/braces are absent and
 * several expressions are cut short (see the notes below). Confirm against the upstream file.
 */
class CustomSchemaTest extends FunSuite with BeforeAndAfterEach {
  // Assigned in beforeEach.
  var ss: SparkSession = _

  // Column name -> expected type for the CSV sample.
  val csvTypesMap = Map("ProposalId" -> IntegerType,
    "OpportunityId" -> StringType,
    "Clicks" -> LongType,
    "Impressions" -> LongType

  // Column name -> expected type for the JSON sample.
  val jsonTypesMap = Map("name" -> StringType,
    "age" -> IntegerType

  // Obtains (or reuses) a local SparkSession before every test.
  override def beforeEach() {
    ss = SparkSession.builder().master("local").appName("Custom Schema Test").getOrCreate()

  // Asserts that `field` has the type looked up in `typeMap`.
  // NOTE(review): the lookup key is missing from the `typeMap(` call; presumably the field's
  // name -- verify.
  private def validateTypes(field : StructField, typeMap : Map[String, DataType]) = {
    val expectedType = typeMap(
    assert(expectedType == field.dataType)

  // Declared to turn a name -> type map into an array of nullable StructFields.
  // NOTE(review): the receiver of the lambda on the `columns` line and the rest of the body
  // are missing; as shown, `columnStruct` is an empty array and nothing is returned.
  private def columnArray(typeMap : Map[String, DataType]) : Array[StructField] = {
    val columns = => new StructField(x._1, x._2, true))

    val columnStruct = Array[StructField] ()


  test ("Read CSV with custom schema") {
    val columnStruct = columnArray(csvTypesMap)
    val expectedSchema = StructType(columnStruct)

    val fileLocation = getClass.getResource("/sample.csv").getPath
    val dsr = DatasetRelation(fileLocation, "csv", "false", "true", ",", "\"", "\\", "false", null, expectedSchema, ss.sqlContext)
    // The scan result is not inspected below.
    val rdd = dsr.buildScan()

    assert(dsr.schema.fields.length == columnStruct.length)
    dsr.schema.fields.foreach(s => validateTypes(s, csvTypesMap))

  test ("Read Json with custom schema") {
    val columnStruct = columnArray(jsonTypesMap)
    val expectedSchema = StructType(columnStruct)

    val fileLocation = getClass.getResource("/people.json").getPath
    val dsr = DatasetRelation(fileLocation, "json", "false", "true", ",", "\"", "\\", "false", null, expectedSchema, ss.sqlContext)
    // The scan result is not inspected below.
    val rdd = dsr.buildScan()

    assert(dsr.schema.fields.length == columnStruct.length)
    dsr.schema.fields.foreach(s => validateTypes(s, jsonTypesMap))

Example 21
Source File: SQLOperationsSpec.scala    From pravda-ml   with Apache License 2.0 5 votes vote down vote up
package odkl.analysis.spark.util

import odkl.analysis.spark.TestEnv
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.IntegerType
import org.scalatest.{FlatSpec, Matchers}

/** Spec for the `collectAsList` aggregate provided by `SQLOperations`. */
class SQLOperationsSpec extends FlatSpec with TestEnv with Matchers with SQLOperations {

  import sqlc.implicits._

  "A CollectAsList" should "aggregate items" in {
    val df = sc.parallelize(Seq(
      "A" -> 10,
      "A" -> 11,
      "A" -> 12,
      "B" -> 20,
      "B" -> 21,
      "C" -> 30
    )).toDF("C1", "C2")
    val res = df.groupBy("C1").agg(collectAsList(IntegerType)(col("C2"))).collect()

    // Collected values of the row whose key column equals `key`; fails the test when the
    // group is absent.
    def collectedFor(key: String): Seq[Int] =
      res.find(_.getString(0) == key).getOrElse(fail(s"No row for '$key'")).getAs[Seq[Int]](1)

    collectedFor("A") should contain theSameElementsAs Seq(10, 11, 12)
    collectedFor("B") should contain theSameElementsAs Seq(20, 21)
    collectedFor("C") should contain theSameElementsAs Seq(30)
  }
}
Example 22
Source File: SemiJoinSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
/**
 * Runs the same left-semi-join scenario against three physical operators
 * (LeftSemiJoinHash, BroadcastLeftSemiJoinHash, LeftSemiJoinBNL).
 *
 * NOTE(review): this listing looks truncated -- closing braces are absent, the row
 * collections passed to `createDataFrame` have lost their opening expression, and
 * `extractJoinParts` plus the final `testLeftSemiJoin(...)` call are incomplete. Confirm
 * against the upstream file.
 */
class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {

  // Left input: columns a (integer) and b (double), including rows with nulls.
  private lazy val left = ctx.createDataFrame(
      Row(1, 2.0),
      Row(1, 2.0),
      Row(2, 1.0),
      Row(2, 1.0),
      Row(3, 3.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("a", IntegerType).add("b", DoubleType))

  // Right input: columns c (integer) and d (double), including rows with nulls.
  private lazy val right = ctx.createDataFrame(
      Row(2, 3.0),
      Row(2, 3.0),
      Row(3, 2.0),
      Row(4, 1.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("c", IntegerType).add("d", DoubleType))

  // Join condition: a = c AND b < d.
  private lazy val condition = {
    And((left.col("a") === right.col("c")).expr,
      LessThan(left.col("b").expr, right.col("d").expr))

  // Note: the input dataframes and expression must be evaluated lazily because
  // the SQLContext should be used only within a test to keep SQL tests stable
  // Registers one test per physical operator, each named after `testName`.
  private def testLeftSemiJoin(
      testName: String,
      leftRows: => DataFrame,
      rightRows: => DataFrame,
      condition: => Expression,
      expectedAnswer: Seq[Product]): Unit = {

    // Builds an inner Join over the two inputs with the given condition.
    // NOTE(review): the expression producing the declared Option result is missing.
    def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))

    // Each test below runs with the shuffle-partitions setting forced to "1".
    test(s"$testName using LeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          // NOTE(review): unbalanced ")" on the operator line -- a wrapping call appears
          // to have been lost.
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
              LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
            sortAnswers = true)

    test(s"$testName using BroadcastLeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
            sortAnswers = true)

    // The nested-loop variant takes the raw condition and does not go through
    // extractJoinParts.
    test(s"$testName using LeftSemiJoinBNL") {
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
          LeftSemiJoinBNL(left, right, Some(condition)),
          sortAnswers = true)
    // NOTE(review): remnants of a `testLeftSemiJoin(...)` invocation; the call head and the
    // remaining arguments are missing.
    "basic test",
      (2, 1.0),
      (2, 1.0)
Example 23
Source File: AttributeSetSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.IntegerType

/**
 * Tests for AttributeSet using attribute references that share an expression id but differ
 * in name casing.
 *
 * NOTE(review): this listing looks truncated -- closing braces are absent and at least one
 * test ("extracts all references references") contains no assertion. Confirm against the
 * upstream file.
 */
class AttributeSetSuite extends SparkFunSuite {

  // aUpper and aLower share ExprId(1) and differ only in name; fakeA reuses the name "a"
  // under a different id.
  val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
  val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
  val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
  val aSet = AttributeSet(aLower :: Nil)

  // bUpper and bLower share ExprId(2) and differ only in name.
  val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
  val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
  val bSet = AttributeSet(bUpper :: Nil)

  val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)

  // The references themselves compare unequal even though their ids match.
  test("sanity check") {
    assert(aUpper != aLower)
    assert(bUpper != bLower)
  // Membership follows the expression id: both casings of "a" are found, fakeA is not.
  test("checks by id not name") {
    assert(aSet.contains(aUpper) === true)
    assert(aSet.contains(aLower) === true)
    assert(aSet.contains(fakeA) === false)

    assert(aSet.contains(bUpper) === false)
    assert(aSet.contains(bLower) === false)

  test("++ preserves AttributeSet")  {
    assert((aSet ++ bSet).contains(aUpper) === true)
    assert((aSet ++ bSet).contains(aLower) === true)

  // NOTE(review): only the set construction is present here; the assertions on `addSet`
  // appear to be missing.
  test("extracts all references references") {
    val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil)

  // Two references with the same id collapse to a single element.
  test("dedups attributes") {
    assert(AttributeSet(aUpper :: aLower :: Nil).size === 1)

  test("subset") {
    assert(aSet.subsetOf(aAndBSet) === true)
    assert(aAndBSet.subsetOf(aSet) === false)

  test("equality") {
    assert(aSet != aAndBSet)
    assert(aAndBSet != aSet)
    assert(aSet != bSet)
    assert(bSet != aSet)

    // A set built from aUpper equals one built from aLower.
    assert(aSet == aSet)
    assert(aSet == AttributeSet(aUpper :: Nil))
Example 24
Source File: MiscFunctionsSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.codec.digest.DigestUtils

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType}

/**
 * Evaluation tests for the Md5, Sha1, Sha2 and Crc32 expressions.
 *
 * NOTE(review): this listing looks truncated -- closing braces are absent, and the
 * `checkEvaluation` calls on the six-byte array in the md5, sha1 and crc32 tests end with a
 * trailing comma and no expected value. Confirm against the upstream file.
 */
class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("md5") {
    checkEvaluation(Md5(Literal("ABC".getBytes)), "902fbdd2b1df0c4f70b4a5d23525e932")
    checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
    // Null input is expected to evaluate to null.
    checkEvaluation(Md5(Literal.create(null, BinaryType)), null)
    checkConsistencyBetweenInterpretedAndCodegen(Md5, BinaryType)

  test("sha1") {
    checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")
    checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
    checkEvaluation(Sha1(Literal.create(null, BinaryType)), null)
    // Empty input still has a defined digest.
    checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709")
    checkConsistencyBetweenInterpretedAndCodegen(Sha1, BinaryType)

  test("sha2") {
    // Expected digests are computed with commons-codec DigestUtils.
    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC"))
    checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)),
      DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6)))
    // unsupported bit length
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null)
    // A null on either argument is expected to evaluate to null.
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null)
    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null)
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null)

  test("crc32") {
    checkEvaluation(Crc32(Literal("ABC".getBytes)), 2743272264L)
    checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
    checkEvaluation(Crc32(Literal.create(null, BinaryType)), null)
    checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
Example 25
Source File: ArrangePostprocessor.scala    From DataQuality   with GNU Lesser General Public License v3.0 5 votes vote down vote up
package it.agilelab.bigdata.DataQuality.postprocessors

import com.typesafe.config.Config
import it.agilelab.bigdata.DataQuality.checks.CheckResult
import it.agilelab.bigdata.DataQuality.metrics.MetricResult
import it.agilelab.bigdata.DataQuality.sources.HdfsFile
import it.agilelab.bigdata.DataQuality.targets.HdfsTargetConfig
import it.agilelab.bigdata.DataQuality.utils
import it.agilelab.bigdata.DataQuality.utils.DQSettings
import{HdfsReader, HdfsWriter}
import org.apache.hadoop.fs.FileSystem
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, NumericType}
import org.apache.spark.sql.{Column, DataFrame, SQLContext}

import scala.collection.JavaConversions._

/**
 * Postprocessor configured from the "source", "saveTo" and "columnOrder" keys of `config`.
 * Each "columnOrder" entry becomes a ColumnSelector describing an optional cast and an
 * optional format for one column.
 *
 * NOTE(review): this listing looks truncated -- closing braces are absent and several
 * expressions are cut short (`target`, the filter predicate and `arrangeDF` in `process`,
 * the trailing argument list of `saveVirtualSource`). Confirm against the upstream file.
 */
final class ArrangePostprocessor(config: Config, settings: DQSettings)
    extends BasicPostprocessor(config, settings) {

  // One output column: its name plus an optional target type name, format string and
  // precision.
  private case class ColumnSelector(name: String, tipo: Option[String] = None, format: Option[String] = None, precision: Option[Integer] = None) {
    // Builds the Column expression for this selector against the implicit DataFrame.
    def toColumn()(implicit df: DataFrame): Column = {

      // Type-name lookup is case-insensitive; an unrecognised or missing name means no cast.
      val dataType: Option[NumericType with Product with Serializable] =
        tipo.getOrElse("").toUpperCase match {
          case "DOUBLE" => Some(DoubleType)
          case "INT"    => Some(IntegerType)
          case "LONG"   => Some(LongType)
          case _        => None

      import org.apache.spark.sql.functions.format_number
      import org.apache.spark.sql.functions.format_string

      // Combinations handled: cast only, cast + format string, cast + precision, precision
      // only, format string only. Anything else (including precision together with format)
      // falls through to the unmodified column.
      (dataType, precision, format) match {
        case (Some(dt), None, None) => df(name).cast(dt)
        case(Some(dt), None, Some(f)) => format_string(f, df(name).cast(dt)).alias(name)
        case (Some(dt), Some(p),None) => format_number(df(name).cast(dt), p).alias(name)
        case (None, Some(p), None) => format_number(df(name), p).alias(name)
        case (None, None, Some(f)) => format_string(f, df(name)).alias(name)
        case _ => df(name)

  private val vs = config.getString("source")
  // NOTE(review): the construction of the HdfsTargetConfig from `conf` is missing.
  private val target: HdfsTargetConfig = {
    val conf = config.getConfig("saveTo")

  // Parses "columnOrder": a plain string is a bare column name; a map entry carries either
  // a type name or a nested single-entry map of type name -> precision/format.
  // NOTE(review): the type arguments in the HashMap patterns are erased at runtime, so they
  // are not actually checked; `.head` on the Java maps presumably relies on the implicit
  // JavaConversions import -- verify.
  private val columns: Seq[ColumnSelector] =
    config.getAnyRefList("columnOrder").map {
      case x: String => ColumnSelector(x)
      case x: java.util.HashMap[_, String] => {
        val (name, v) = x.head.asInstanceOf[String Tuple2 _]

        v match {
          case v: String =>
            ColumnSelector(name, Option(v))
          case v: java.util.HashMap[String, _] => {
            val k = v.head._1
            val f = v.head._2

            // An Integer is taken as the precision, a String as the format.
            f match {
              case f: Integer =>
                ColumnSelector(name, Option(k), None, Option(f))
              case f: String =>
                ColumnSelector(name, Option(k), Option(f))

  // Loads the virtual source selected via `vs`, saves the arranged frame to `target`, and
  // returns an HdfsFile for that target.
  override def process(vsRef: Set[HdfsFile],
                       metRes: Seq[MetricResult],
                       chkRes: Seq[CheckResult])(
      implicit fs: FileSystem,
      sqlContext: SQLContext,
      settings: DQSettings): HdfsFile = {

    // NOTE(review): the left-hand side of the comparison is missing; `.head` throws if
    // nothing matches.
    val reqVS: HdfsFile = vsRef.filter(vr => == vs).head
    implicit val df: DataFrame = HdfsReader.load(reqVS, settings.ref_date).head

    // NOTE(review): the select expression is missing here.
    val arrangeDF = _*)

    HdfsWriter.saveVirtualSource(arrangeDF, target, settings.refDateString)(

    new HdfsFile(target)
Example 26
Source File: DataFrameNaFunctionsSpec.scala    From spark-spec   with MIT License 5 votes vote down vote up
package com.github.mrpowers.spark.spec.sql

import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.Row
import org.scalatest._
import com.github.mrpowers.spark.daria.sql.SparkSessionExt._

import com.github.mrpowers.spark.spec.SparkSessionTestWrapper


/**
 * Spec with example-style tests grouped under "#drop", "#fill" and "#replace".
 *
 * NOTE(review): this listing looks truncated -- closing parentheses/braces are absent, the
 * `createDataFrame(` calls have no arguments, both `actualDF` definitions have no
 * right-hand side, and the "#replace" group has no body. Confirm against the upstream file.
 */
class DataFrameNaFunctionsSpec extends FunSpec with SparkSessionTestWrapper with DatasetComparer {

  import spark.implicits._

  describe("#drop") {

    it("drops rows that contains null values") {

      // Four input rows; two of them contain at least one null.
      val sourceData = List(
        Row(1, null),
        Row(null, null),
        Row(3, 30),
        Row(10, 20)

      val sourceSchema = List(
        StructField("num1", IntegerType, true),
        StructField("num2", IntegerType, true)

      val sourceDF = spark.createDataFrame(

      val actualDF =

      // Expected output keeps only the two rows without nulls.
      val expectedData = List(
        Row(3, 30),
        Row(10, 20)

      val expectedSchema = List(
        StructField("num1", IntegerType, true),
        StructField("num2", IntegerType, true)

      val expectedDF = spark.createDataFrame(

      assertSmallDatasetEquality(actualDF, expectedDF)



  describe("#fill") {

    it("Returns a new DataFrame that replaces null or NaN values in numeric columns with value") {

      val sourceDF = spark.createDF(
          (1, null),
          (null, null),
          (3, 30),
          (10, 20)
        ), List(
          ("num1", IntegerType, true),
          ("num2", IntegerType, true)

      val actualDF =

      // Expected output has 77 wherever the source had null, and non-nullable columns.
      val expectedDF = spark.createDF(
          (1, 77),
          (77, 77),
          (3, 30),
          (10, 20)
        ), List(
          ("num1", IntegerType, false),
          ("num2", IntegerType, false)

      assertSmallDatasetEquality(actualDF, expectedDF)



  describe("#replace") {

Example 27
Source File: ReportContentTestFactory.scala    From seahorse-workflow-executor   with Apache License 2.0 5 votes vote down vote up
package io.deepsense.reportlib.model.factory

import io.deepsense.reportlib.model.{ReportType, ReportContent}
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}

/**
 * Mix-in exposing a ready-made ReportContent for tests.
 *
 * NOTE(review): this listing looks truncated -- the values of the Map entries and the
 * remaining ReportContent arguments are missing, as are the closing delimiters.
 */
trait ReportContentTestFactory {

  import ReportContentTestFactory._

  // Builds a ReportContent from a Map keyed by the categorical and continuous distribution
  // names defined on the companion object.
  def testReport: ReportContent = ReportContent(
    Map(ReportContentTestFactory.categoricalDistName ->
      ReportContentTestFactory.continuousDistName ->


/** Constants and a minimal report shared by tests that use ReportContentTestFactory. */
object ReportContentTestFactory extends ReportContentTestFactory {
  // Names under which the two sample distributions are registered.
  val continuousDistName: String = "continuousDistributionName"
  val categoricalDistName: String = "categoricalDistributionName"

  val reportName: String = "TestReportContentName"
  val reportType = ReportType.Empty

  // A report named "empty" of the Empty report type.
  val someReport: ReportContent = ReportContent("empty", ReportType.Empty)
}
Example 28
Source File: EstimatorModelWrapperIntegSpec.scala    From seahorse-workflow-executor   with Apache License 2.0 5 votes vote down vote up
package io.deepsense.deeplang.doperables.spark.wrappers.estimators

import io.deepsense.deeplang.DeeplangIntegTestSupport
import io.deepsense.deeplang.doperables.dataframe.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructType, StructField}

/**
 * Integration spec checking which prediction-column name ends up in the output schema after
 * fitting an estimator wrapper, both via actual transformation and via schema inference.
 *
 * NOTE(review): this listing looks truncated -- closing parentheses/braces are absent, the
 * third scenario has lost the start of its assertion, the fourth has an empty `transformer`
 * definition, and `createEstimatorAndFit` shows no returned value. Confirm against upstream.
 */
class EstimatorModelWrapperIntegSpec extends DeeplangIntegTestSupport {

  import io.deepsense.deeplang.doperables.spark.wrappers.estimators.EstimatorModelWrapperFixtures._

  // Three rows with a single non-nullable integer column "x".
  val inputDF = {
    val rowSeq = Seq(Row(1), Row(2), Row(3))
    val schema = StructType(Seq(StructField("x", IntegerType, nullable = false)))
    createDataFrame(rowSeq, schema)

  val estimatorPredictionParamValue = "estimatorPrediction"

  // Input schema plus a prediction column named after the estimator's parameter value.
  val expectedSchema = StructType(Seq(
    StructField("x", IntegerType, nullable = false),
    StructField(estimatorPredictionParamValue, IntegerType, nullable = false)

  val transformerPredictionParamValue = "modelPrediction"

  // Same shape, with the prediction column named after the transformer's own value.
  val expectedSchemaForTransformerParams = StructType(Seq(
    StructField("x", IntegerType, nullable = false),
    StructField(transformerPredictionParamValue, IntegerType, nullable = false)

  "EstimatorWrapper" should {
    "_fit() and transform() + transformSchema() with parameters inherited" in {

      val transformer = createEstimatorAndFit()

      // The transformed data and the inferred schema must agree on the expected schema.
      val transformOutputSchema =
        transformer._transform(executionContext, inputDF).sparkDataFrame.schema
      transformOutputSchema shouldBe expectedSchema

      val inferenceOutputSchema = transformer._transformSchema(inputDF.sparkDataFrame.schema)
      inferenceOutputSchema shouldBe Some(expectedSchema)

    "_fit() and transform() + transformSchema() with parameters overwritten" in {

      // Overriding the prediction column on the fitted transformer changes the output name.
      val transformer = createEstimatorAndFit().setPredictionColumn(transformerPredictionParamValue)

      val transformOutputSchema =
        transformer._transform(executionContext, inputDF).sparkDataFrame.schema
      transformOutputSchema shouldBe expectedSchemaForTransformerParams

      val inferenceOutputSchema = transformer._transformSchema(inputDF.sparkDataFrame.schema)
      inferenceOutputSchema shouldBe Some(expectedSchemaForTransformerParams)

    "_fit_infer().transformSchema() with parameters inherited" in {

      val estimatorWrapper = new SimpleSparkEstimatorWrapper()

        ._transformSchema(inputDF.sparkDataFrame.schema) shouldBe Some(expectedSchema)

    "_fit_infer().transformSchema() with parameters overwritten" in {

      val estimatorWrapper = new SimpleSparkEstimatorWrapper()
      val transformer =
      val transformerWithParams = transformer.setPredictionColumn(transformerPredictionParamValue)

      val outputSchema = transformerWithParams._transformSchema(inputDF.sparkDataFrame.schema)
      outputSchema shouldBe Some(expectedSchemaForTransformerParams)

  // Fits a fresh estimator wrapper on inputDF and checks that the resulting model wrapper
  // reports the estimator's prediction-column value.
  private def createEstimatorAndFit(): SimpleSparkModelWrapper = {

    val estimatorWrapper = new SimpleSparkEstimatorWrapper()

    val transformer =
      estimatorWrapper._fit(executionContext, inputDF).asInstanceOf[SimpleSparkModelWrapper]
    transformer.getPredictionColumn() shouldBe estimatorPredictionParamValue

Example 29
Source File: EstimatorModelWrapperFixtures.scala    From seahorse-workflow-executor   with Apache License 2.0 5 votes vote down vote up
package io.deepsense.deeplang.doperables.spark.wrappers.estimators

import scala.language.reflectiveCalls

import org.apache.spark.annotation.DeveloperApi
import{ParamMap, Param => SparkParam}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

import io.deepsense.deeplang.ExecutionContext
import io.deepsense.deeplang.doperables.serialization.SerializableSparkModel
import io.deepsense.deeplang.doperables.{SparkEstimatorWrapper, SparkModelWrapper}
import io.deepsense.deeplang.params.wrappers.spark.SingleColumnCreatorParamWrapper
import io.deepsense.deeplang.params.{Param, Params}
import io.deepsense.sparkutils.ML

/**
 * Minimal Spark model/estimator pair and their wrappers, used as test fixtures.
 *
 * NOTE(review): this listing looks truncated -- closing braces are absent and the
 * `predictionColumn` definition in HasPredictionColumn is cut short. `Report` is referenced
 * but no import for it is visible here. Confirm against the upstream file.
 */
object EstimatorModelWrapperFixtures {

  // Model whose transformation appends a constant column named by `predictionCol`.
  class SimpleSparkModel private[EstimatorModelWrapperFixtures]()
    extends ML.Model[SimpleSparkModel] {

    // The String argument is ignored.
    def this(x: String) = this()

    override val uid: String = "modelId"

    val predictionCol = new SparkParam[String](uid, "name", "description")

    def setPredictionCol(value: String): this.type = set(predictionCol, value)

    override def copy(extra: ParamMap): this.type = defaultCopy(extra)

    // Keeps every input column and adds the literal 1 under the configured column name.
    override def transformDF(dataset: DataFrame): DataFrame = {
      dataset.selectExpr("*", "1 as " + $(predictionCol))

    // Not implemented: `???` throws NotImplementedError if this is ever called.
    override def transformSchema(schema: StructType): StructType = ???

  // Estimator that "fits" by handing its prediction-column setting to a new model.
  class SimpleSparkEstimator extends ML.Estimator[SimpleSparkModel] {

    // The String argument is ignored.
    def this(x: String) = this()

    override val uid: String = "estimatorId"

    val predictionCol = new SparkParam[String](uid, "name", "description")

    // The dataset is not inspected.
    override def fitDF(dataset: DataFrame): SimpleSparkModel =
      new SimpleSparkModel().setPredictionCol($(predictionCol))

    override def copy(extra: ParamMap): ML.Estimator[SimpleSparkModel] = defaultCopy(extra)

    // Appends a non-nullable integer field named by `predictionCol`.
    override def transformSchema(schema: StructType): StructType = {
      schema.add(StructField($(predictionCol), IntegerType, nullable = false))

  // Shared "prediction column" parameter with default "abcdefg", plus getter and setter.
  trait HasPredictionColumn extends Params {
    // NOTE(review): the remaining constructor arguments of the wrapper are missing.
    val predictionColumn = new SingleColumnCreatorParamWrapper[
        ml.param.Params { val predictionCol: SparkParam[String] }](
      "prediction column",
    setDefault(predictionColumn, "abcdefg")

    def getPredictionColumn(): String = $(predictionColumn)
    def setPredictionColumn(value: String): this.type = set(predictionColumn, value)

  // Wrapper around SimpleSparkModel; report and loadModel are unimplemented (`???`).
  class SimpleSparkModelWrapper
    extends SparkModelWrapper[SimpleSparkModel, SimpleSparkEstimator]
    with HasPredictionColumn {

    override val params: Array[Param[_]] = Array(predictionColumn)
    override def report: Report = ???

    override protected def loadModel(
      ctx: ExecutionContext,
      path: String): SerializableSparkModel[SimpleSparkModel] = ???

  // Wrapper around SimpleSparkEstimator; report is unimplemented (`???`).
  class SimpleSparkEstimatorWrapper
    extends SparkEstimatorWrapper[SimpleSparkModel, SimpleSparkEstimator, SimpleSparkModelWrapper]
    with HasPredictionColumn {

    override val params: Array[Param[_]] = Array(predictionColumn)
    override def report: Report = ???
Example 30
Source File: DataFrameSplitterIntegSpec.scala    From seahorse-workflow-executor   with Apache License 2.0 5 votes vote down vote up
package io.deepsense.deeplang.doperations

import scala.collection.JavaConverters._

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.scalatest.Matchers
import org.scalatest.prop.GeneratorDrivenPropertyChecks

import io.deepsense.deeplang._
import io.deepsense.deeplang.doperables.dataframe.DataFrame

/**
 * Integration spec for the Split operation: one random-split scenario run over several
 * (ratio, seed) pairs and one conditional-split scenario.
 *
 * NOTE(review): this listing looks truncated -- closing braces are absent, the `new Split()`
 * configuration chains are cut short, and `createSchema` / `createData` have lost most of
 * their bodies. Confirm against the upstream file.
 */
class DataFrameSplitterIntegSpec
  extends DeeplangIntegTestSupport
  with GeneratorDrivenPropertyChecks
  with Matchers {

  "SplitDataFrame" should {
    "split randomly one df into two df in given range" in {

      // Range(1, 100) excludes the upper bound: values 1 to 99.
      val input = Range(1, 100)

      // (split ratio, seed) pairs, including the 0.0 and 1.0 extremes.
      val parameterPairs = List(
        (0.0, 0),
        (0.3, 1),
        (0.5, 2),
        (0.8, 3),
        (1.0, 4))

      for((splitRatio, seed) <- parameterPairs) {
        val rdd = createData(input)
        val df = executionContext.dataFrameBuilder.buildDataFrame(createSchema, rdd)
        val (df1, df2) = executeOperation(
          new Split()
                .setSeed(seed / 2)))(df)
        validateSplitProperties(df, df1, df2)

    "split conditionally one df into two df in given range" in {

      val input = Range(1, 100)

      // `condition` is the textual form of `predicate`; both select values above 20.
      val condition = "value > 20"
      val predicate: Int => Boolean = _ > 20

      // Expected contents: values matching the predicate, and the rest.
      val (expectedDF1, expectedDF2) =
        (input.filter(predicate), input.filter(!predicate(_)))

      val rdd = createData(input)
      val df = executionContext.dataFrameBuilder.buildDataFrame(createSchema, rdd)
      val (df1, df2) = executeOperation(
        new Split()
      df1.sparkDataFrame.collect().map(_.get(0)) should contain theSameElementsAs expectedDF1
      df2.sparkDataFrame.collect().map(_.get(0)) should contain theSameElementsAs expectedDF2
      validateSplitProperties(df, df1, df2)

  // Schema containing a non-nullable integer column "value".
  private def createSchema: StructType = {
      StructField("value", IntegerType, nullable = false)

  // NOTE(review): body missing; presumably wraps each Int in a Row and parallelizes -- verify.
  private def createData(data: Seq[Int]): RDD[Row] = {

  /**
   * Runs `operation` on a single input frame and returns its first and last results.
   *
   * @param context   execution context handed to the operation
   * @param operation operation to execute
   * @param dataFrame the only input passed to the operation
   * @return the first and the last element of the operation's result, each cast to DataFrame
   */
  private def executeOperation(context: ExecutionContext, operation: DOperation)
                              (dataFrame: DataFrame): (DataFrame, DataFrame) = {
    val operationResult = operation.executeUntyped(Vector[DOperable](dataFrame))(context)
    // The casts throw ClassCastException if the operation returns anything but DataFrames.
    val df1 = operationResult.head.asInstanceOf[DataFrame]
    val df2 = operationResult.last.asInstanceOf[DataFrame]
    (df1, df2)
  }

  /**
   * Asserts that two output frames form a partition of the input frame: they share no rows,
   * their row counts add up to the input's, and together they contain exactly the input's
   * distinct rows.
   *
   * @param inputDF   frame that was split
   * @param outputDF1 first part of the split
   * @param outputDF2 second part of the split
   */
  def validateSplitProperties(inputDF: DataFrame, outputDF1: DataFrame, outputDF2: DataFrame)
      : Unit = {
    val dfCount = inputDF.sparkDataFrame.count()
    val df1Count = outputDF1.sparkDataFrame.count()
    val df2Count = outputDF2.sparkDataFrame.count()
    val rowsDf = inputDF.sparkDataFrame.collectAsList().asScala
    val rowsDf1 = outputDF1.sparkDataFrame.collectAsList().asScala
    val rowsDf2 = outputDF2.sparkDataFrame.collectAsList().asScala

    // No row may appear in both halves.
    val intersect = rowsDf1.intersect(rowsDf2)
    intersect.size shouldBe 0
    // Nothing lost, nothing duplicated.
    (df1Count + df2Count) shouldBe dfCount
    rowsDf.toSet shouldBe rowsDf1.toSet.union(rowsDf2.toSet)
  }
Example 31
Source File: QuadTreeIndexedRelation.scala    From Simba   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.simba.index

import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{DoubleType, IntegerType}

import org.apache.spark.sql.simba.partitioner.QuadTreePartitioner
import org.apache.spark.sql.simba.spatial.Point

/**
 * Indexed relation that partitions the child's rows with QuadTreePartitioner and keeps both
 * the indexed partitions (`_indexedRDD`) and the partitioner's global tree (`global_index`).
 *
 * NOTE(review): this listing looks truncated -- closing braces are absent and several
 * expressions are cut short (the tail of `checkKeys`, the body of the null check, the key
 * mapping in `buildIndex`, the QuadTree/IPartition arguments, the `setName` receiver, and
 * the first argument in `newInstance`). `simbaSession` is not defined in the visible code.
 * Confirm against the upstream file.
 */
private[simba] case class QuadTreeIndexedRelation(output: Seq[Attribute], child: SparkPlan, table_name: Option[String],
                                                  column_keys: List[Attribute], index_name: String)(var _indexedRDD: IndexedRDD = null, var global_index: QuadTree = null)
  extends IndexedRelation with MultiInstanceRelation {
  // Returns false as soon as a key column is neither DoubleType nor IntegerType.
  // NOTE(review): the result for the all-valid case is missing.
  private def checkKeys: Boolean = {
    for (i <- column_keys.indices)
      if (!(column_keys(i).dataType.isInstanceOf[DoubleType] ||
        column_keys(i).dataType.isInstanceOf[IntegerType])) {
        return false

  // NOTE(review): the body of this check is missing; presumably it builds the index when
  // none was supplied -- verify.
  if (_indexedRDD == null) {

  // Builds the per-partition indexes and stores them, together with the global tree, in the
  // mutable fields.
  private[simba] def buildIndex(): Unit = {
    val numShufflePartitions = simbaSession.sessionState.simbaConf.indexPartitions
    val sampleRate = simbaSession.sessionState.simbaConf.sampleRate
    val tranferThreshold = simbaSession.sessionState.simbaConf.transferThreshold

    // Pairs every row with a Point whose coordinates are key expressions evaluated against
    // the row and read as doubles.
    val dataRDD = child.execute().map(row => {
      val now = =>
        BindReferences.bindReference(x, child.output).eval(row).asInstanceOf[Number].doubleValue()
      (new Point(now), row)

    val dimension = column_keys.length
    val (partitionedRDD, _, global_qtree) = QuadTreePartitioner(dataRDD, dimension,
      numShufflePartitions, sampleRate, tranferThreshold)

    // One IPartition per partition; an empty partition gets a null index.
    val indexed = partitionedRDD.mapPartitions { iter =>
      val data = iter.toArray
      val index: QuadTree =
        if (data.length > 0) QuadTree(
        else null
      Array(IPartition(, index)).iterator

    indexed.setName( => s"$name $index_name").getOrElse(child.toString))
    _indexedRDD = indexed
    global_index = global_qtree

  // NOTE(review): only `_indexedRDD` is forwarded here, so `global_index` takes its default
  // (null) on the new instance, whereas withOutput forwards both -- confirm this is intended.
  override def newInstance(): IndexedRelation = {
    new QuadTreeIndexedRelation(, child, table_name,
      column_keys, index_name)(_indexedRDD)

  // Same relation with a different output attribute list; both index fields are carried over.
  override def withOutput(new_output: Seq[Attribute]): IndexedRelation = {
    new QuadTreeIndexedRelation(new_output, child, table_name,
      column_keys, index_name)(_indexedRDD, global_index)
Example 32
Source File: RDDFixtures.scala    From spark-vector   with Apache License 2.0 5 votes vote down vote up
package com.actian.spark_vector

import java.util.Date

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ DateType, IntegerType, StringType, StructField, StructType }

import com.actian.spark_vector.test.util.StructTypeUtil.createSchema

/** Test fixtures that build small in-memory RDDs together with a matching schema. */
trait RDDFixtures {
  // poor man's fixture, for other approaches see:
  /**
   * Builds a two-row RDD of `(id, name)` records spread over two partitions.
   *
   * @param sc Spark context used to parallelize the data
   * @return the RDD and a schema with an integer `id` and a string `name` column
   */
  def createRecordRdd(sc: SparkContext): (RDD[Seq[Any]], StructType) = {
    val input = Seq(
      Seq(42, "a"),
      Seq(43, "b"))
    val inputRdd = sc.parallelize(input, 2)
    val inputSchema = createSchema("id" -> IntegerType, "name" -> StringType)
    (inputRdd, inputSchema)
  }

  /**
   * Builds a two-row RDD whose records carry an id, a name and two `java.util.Date` values.
   *
   * NOTE(review): each record holds four values while the schema declares three
   * columns (`id`, `name`, `date`) — confirm whether a fourth column is missing.
   *
   * @param sc Spark context used to parallelize the data
   * @return the RDD and its schema
   */
  def createRowRDD(sc: SparkContext): (RDD[Seq[Any]], StructType) = {
    val input = Seq(
      Seq[Any](42, "a", new Date(), new Date()),
      Seq[Any](43, "b", new Date(), new Date()))
    val inputRdd = sc.parallelize(input, 2)
    val inputSchema = createSchema("id" -> IntegerType, "name" -> StringType, "date" -> DateType)
    (inputRdd, inputSchema)
  }

  /**
   * Builds an RDD of `rowCount` identical rows, each holding the integers `1..columnCount`.
   *
   * @param sc          Spark context used to parallelize the data
   * @param columnCount number of integer columns in every row and in the schema
   * @param rowCount    number of rows to generate (defaults to 2)
   * @return the RDD and a schema of `columnCount` integer fields named `field_1 .. field_N`
   */
  def wideRDD(sc: SparkContext, columnCount: Int, rowCount: Int = 2): (RDD[Row], StructType) = {
    val data: Row = Row.fromSeq(1 to columnCount)

    // One schema field per value in `data`. Driving this range with `rowCount`
    // produced a schema whose width did not match the rows.
    val fields = (1 to columnCount).map { i =>
      StructField("field_" + i, IntegerType, true)
    }

    val inputSchema = StructType(fields.toSeq)

    // Every generated row is the same `data` instance.
    val input = Seq.fill(rowCount)(data)

    val inputRDD = sc.parallelize(input, 2)
    (inputRDD, inputSchema)
  }
}
Example 33
Source File: FunctionTest.scala    From sope   with Apache License 2.0 5 votes vote down vote up
package com.sope

import com.sope.model.{Class, Person, Student}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import com.sope.spark.sql._
import com.sope.TestContext.getSQlContext
import org.apache.spark.sql.types.{StringType, IntegerType}
import org.scalatest.{FlatSpec, Matchers}

/**
 * FlatSpec checks for the DataFrame helper transformations brought in by
 * `com.sope.spark.sql._` (applyDFTransformations, groupByAsList, castColumns,
 * updateKeys).
 *
 * NOTE(review): the fixture definitions and test bodies in this listing have
 * lost their closing parentheses/braces (and probably a trailing `.toDF`);
 * restore them before compiling.
 */
class FunctionTest extends FlatSpec with Matchers {

  private val sqlContext = getSQlContext

  import sqlContext.implicits._

  // Two Person records used by the function-transformation test.
  private val testSData = Seq(
    Person("Sherlock", "Holmes", "baker street", "[email protected]", "999999"),
    Person("John", "Watson", "east street", "[email protected]", "55555")

  // Eight Student records; four of them carry 10 as their last value, which
  // is what the group-by test below counts (presumably the `cls` column — verify).
  private val studentDF = Seq(
    Student("A", "B", 1, 10),
    Student("B", "C", 2, 10),
    Student("C", "E", 4, 9),
    Student("E", "F", 5, 9),
    Student("F", "G", 6, 10),
    Student("G", "H", 7, 10),
    Student("H", "I", 9, 8),
    Student("H", "I", 9, 7)

  // Three Class records used as the lookup side of the update-keys test.
  private val classDF = Seq(
    Class(1, 10, "Tenth"),
    Class(2, 9, "Ninth"),
    Class(3, 8, "Eighth")

  // Applies three column-level functions in sequence and expects a "name"
  // entry in the result.
  "Dataframe Function transformations" should "generate the transformations correctly" in {
    val nameUpperFunc = (df: DataFrame) => df.withColumn("first_name", upper(col("first_name")))
    val nameConcatFunc = (df: DataFrame) => df.withColumn("name", concat(col("first_name"), col("last_name")))
    val addressUpperFunc = (df: DataFrame) => df.withColumn("address", upper(col("address")))
    val transformed = testSData.applyDFTransformations(Seq(nameUpperFunc, nameConcatFunc, addressUpperFunc)) should contain("name")

  // Groups students by "cls", explodes the grouped list back into rows and
  // expects a grouped_count of 4 for cls = 10.
  "Group by as list Function Transformation" should  "generate the transformations correctly" in {
    val grouped = studentDF.groupByAsList(Seq("cls"))
        .withColumn("grouped_data", explode($"grouped_data"))
        .unstruct("grouped_data", keepStructColumn = false)
    grouped.filter("cls = 10").head.getAs[Long]("grouped_count") should be(4)

  // After casting IntegerType columns to StringType, four columns are
  // expected to report "StringType".
  "Cast Transformation" should  "generate the transformations correctly" in {
    val casted = studentDF.castColumns(IntegerType, StringType)
    casted.dtypes.count(_._2 == "StringType") should be(4)

  // Looks up a key from classDF (its "cls" column renamed to "class") and
  // expects cls_key = 1 for the student whose first_name is 'A'.
  "Update Keys Transformation" should  "generate the transformations correctly" in {
    val updatedWithKey = studentDF
      .updateKeys(Seq("cls"), classDF.renameColumns(Map("cls" -> "class")), "class", "key")
      .dropColumns(Seq("last_name", "roll_no"))
    updatedWithKey.filter("first_name = 'A'").head.getAs[Long]("cls_key") should be(1)

Example 34
Source File: SqlExtensionProviderSuite.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.sql

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, Literal, UnaryExpression}
import org.apache.spark.sql.types.{DataType, IntegerType}

import io.projectglow.GlowSuite

/**
 * Checks that the SQL functions `one_arg_test`, `two_arg_test`,
 * `var_args_test` and `optional_arg_test` can be called with the supported
 * argument counts (each call is expected to yield 1) and that other argument
 * counts raise `AnalysisException`.
 *
 * NOTE(review): the body of `beforeAll` and of several `intercept` blocks is
 * missing from this listing, as are all closing braces.
 */
class SqlExtensionProviderSuite extends GlowSuite {
  override def beforeAll(): Unit = {

  // Stable identifier so that `sess.implicits._` can be imported inside tests.
  private lazy val sess = spark
  test("one arg function") {
    import sess.implicits._
    assert(spark.range(1).selectExpr("one_arg_test(id)").as[Int].head() == 1)

    intercept[AnalysisException] {

    intercept[AnalysisException] {
      spark.range(1).selectExpr("one_arg_test(id, id)").collect()

  test("two arg function") {
    import sess.implicits._
    assert(spark.range(1).selectExpr("two_arg_test(id, id)").as[Int].head() == 1)

    intercept[AnalysisException] {

    intercept[AnalysisException] {
      spark.range(1).selectExpr("two_arg_test(id, id, id)").collect()

  // Accepts one, two or four arguments in the visible assertions.
  test("var args function") {
    import sess.implicits._
    assert(spark.range(1).selectExpr("var_args_test(id, id)").as[Int].head() == 1)
    assert(spark.range(1).selectExpr("var_args_test(id, id, id, id)").as[Int].head() == 1)
    assert(spark.range(1).selectExpr("var_args_test(id)").as[Int].head() == 1)

    intercept[AnalysisException] {

  // Accepts one or two arguments; three arguments must fail analysis.
  test("can call optional arg function") {
    import sess.implicits._
    assert(spark.range(1).selectExpr("optional_arg_test(id)").as[Int].head() == 1)
    assert(spark.range(1).selectExpr("optional_arg_test(id, id)").as[Int].head() == 1)

    intercept[AnalysisException] {

    intercept[AnalysisException] {
      spark.range(1).selectExpr("optional_arg_test(id, id, id)").collect()

/**
 * Base for the test expressions below: an integer-typed, nullable expression
 * that always evaluates to 1 and uses `CodegenFallback` instead of generated code.
 *
 * NOTE(review): the closing brace is missing in this listing.
 */
trait TestExpr extends Expression with CodegenFallback {
  override def dataType: DataType = IntegerType

  override def nullable: Boolean = true

  // The input row is ignored; every evaluation yields the constant 1.
  override def eval(input: InternalRow): Any = 1

// Test expressions covering the supported constructor shapes: one argument,
// two arguments, a required argument plus varargs, and a required argument
// plus an optional one.
// NOTE(review): the closing braces of VarArgsExpr and OptionalArgExpr are
// missing in this listing.
case class OneArgExpr(child: Expression) extends UnaryExpression with TestExpr
case class TwoArgExpr(left: Expression, right: Expression) extends BinaryExpression with TestExpr
case class VarArgsExpr(arg: Expression, varArgs: Seq[Expression]) extends TestExpr {
  // Children are the required argument followed by the variable arguments.
  override def children: Seq[Expression] = arg +: varArgs
case class OptionalArgExpr(required: Expression, optional: Expression) extends TestExpr {
  // Secondary constructor: when `optional` is omitted it defaults to Literal(1).
  def this(required: Expression) = this(required, Literal(1))
  override def children: Seq[Expression] = Seq(required, optional)
Example 35
Source File: AnnotationUtils.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.vcf

import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, StringType, StructField, StructType}

// Unified VCF annotation representation, used by SnpEff and VEP
/**
 * Delimiters and per-field schemas for annotation subfields.
 *
 * `allFieldsToSchema` merges the SnpEff, VEP and LOFTEE maps and falls back
 * to `StringType` for any field name that is not listed.
 *
 * NOTE(review): the closing parentheses of the three `Map(` literals and the
 * object's closing brace are missing in this listing.
 */
object AnnotationUtils {

  // Delimiter between annotation fields
  // (the Regex variant is the same character escaped for use in a regex)
  val annotationDelimiter = "|"
  val annotationDelimiterRegex = "\\|"

  // Fractional delimiter for struct subfields
  val structDelimiter = "/"
  val structDelimiterRegex = "\\/"

  // Delimiter for array subfields
  val arrayDelimiter = "&"

  // Struct subfield schemas
  // Each struct has two fields of the same type.
  private val rankTotalStruct = StructType(
    Seq(StructField("rank", IntegerType), StructField("total", IntegerType)))
  private val posLengthStruct = StructType(
    Seq(StructField("pos", IntegerType), StructField("length", IntegerType)))
  private val referenceVariantStruct = StructType(
    Seq(StructField("reference", StringType), StructField("variant", StringType)))

  // Special schemas for SnpEff subfields
  // Keys are subfield names; values are the Spark type used for that subfield.
  private val snpEffFieldsToSchema: Map[String, DataType] = Map(
    "Annotation" -> ArrayType(StringType),
    "Rank" -> rankTotalStruct,
    "cDNA_pos/cDNA_length" -> posLengthStruct,
    "CDS_pos/CDS_length" -> posLengthStruct,
    "AA_pos/AA_length" -> posLengthStruct,
    "Distance" -> IntegerType

  // Special schemas for VEP subfields
  private val vepFieldsToSchema: Map[String, DataType] = Map(
    "Consequence" -> ArrayType(StringType),
    "EXON" -> rankTotalStruct,
    "INTRON" -> rankTotalStruct,
    "cDNA_position" -> IntegerType,
    "CDS_position" -> IntegerType,
    "Protein_position" -> IntegerType,
    "Amino_acids" -> referenceVariantStruct,
    "Codons" -> referenceVariantStruct,
    "Existing_variation" -> ArrayType(StringType),
    "DISTANCE" -> IntegerType,
    "STRAND" -> IntegerType,
    "FLAGS" -> ArrayType(StringType)

  // Special schemas for LOFTEE (as VEP plugin) subfields
  private val lofteeFieldsToSchema: Map[String, DataType] = Map(
    "LoF_filter" -> ArrayType(StringType),
    "LoF_flags" -> ArrayType(StringType),
    "LoF_info" -> ArrayType(StringType)

  // Default string schema for annotation subfield
  // When a key occurs in more than one map, `++` keeps the right-most entry.
  val allFieldsToSchema: Map[String, DataType] =
    (snpEffFieldsToSchema ++ vepFieldsToSchema ++ lofteeFieldsToSchema).withDefaultValue(StringType)
Example 36
Source File: KCore.scala    From sona   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.sona.graph.kcore
import com.tencent.angel.sona.context.PSContext
import org.apache.spark.SparkContext
import com.tencent.angel.sona.graph.params._
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row}

/**
 * Transformer that derives an edge list from the input dataset, iterates a
 * parameter-server backed K-core computation until no messages remain, and
 * returns a DataFrame with a node-id column and a core-id column.
 *
 * NOTE(review): several expressions in this listing are truncated
 * (`val edges =$(`, `val maxId = =>`, `graph =, numMsgs, ...`,
 * `val retRDD ={case ...}`, the body of `transformSchema`) and closing braces
 * are missing; restore them from the upstream source before compiling.
 */
class KCore(override val uid: String) extends Transformer
  with HasSrcNodeIdCol with HasDstNodeIdCol with HasOutputNodeIdCol with HasOutputCoreIdCol
  with HasStorageLevel with HasPartitionNum with HasPSPartitionNum with HasUseBalancePartition {

  // Convenience constructor that generates a random uid prefixed with "KCore".
  def this() = this(Identifiable.randomUID("KCore"))

  /**
   * Runs the K-core iteration over the edges read from `dataset`.
   *
   * @param dataset input whose source/destination id columns are read as longs
   * @return a DataFrame built from `(node, core)` rows using `transformSchema`
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    // Edges are (src, dst) long pairs; self-loops are dropped.
    val edges =$(srcNodeIdCol), $(dstNodeIdCol)).rdd
      .map(row => (row.getLong(0), row.getLong(1)))
      .filter(e => e._1 != e._2)


    // maxId is one past the largest node id seen on either end of an edge.
    val maxId = => math.max(e._1, e._2)).max() + 1
    val minId = => math.min(e._1, e._2)).min()
    val nodes = edges.flatMap(e => Iterator(e._1, e._2))
    val numEdges = edges.count()

    println(s"minId=$minId maxId=$maxId numEdges=$numEdges level=${$(storageLevel)}")

    // Start PS and init the model
    println("start to run ps")

    val model = KCorePSModel.fromMinMax(minId, maxId, nodes, $(psPartitionNum), $(useBalancePartition))
    // Each edge is emitted in both directions before being grouped into one
    // KCoreGraphPartition per RDD partition.
    var graph = edges.flatMap(e => Iterator((e._1, e._2), (e._2, e._1)))
      .mapPartitionsWithIndex((index, edgeIter) =>
        Iterator(KCoreGraphPartition.apply(index, edgeIter)))

    // No-op action over every partition, which forces `graph` to be evaluated.
    // NOTE(review): `Unit` here is the companion object, not the unit value `()`.
    graph.foreachPartition(_ => Unit)

    var curIteration = 0
    var numMsgs = model.numMsgs()
    var prev = graph

    // Iterate until the model reports that no messages are left.
    do {
      curIteration += 1
      graph =, numMsgs, curIteration == 1))
      prev = graph
      numMsgs = model.numMsgs()
      println(s"curIteration=$curIteration numMsgs=$numMsgs")
    } while (numMsgs > 0)

    val retRDD ={case (nodes,cores) =>}
      .map(r => Row.fromSeq(Seq[Any](r._1, r._2)))

    dataset.sparkSession.createDataFrame(retRDD, transformSchema(dataset.schema))

  /**
   * Output schema: a non-nullable long node-id column and a non-nullable
   * integer core-id column, named by the corresponding params.
   */
  override def transformSchema(schema: StructType): StructType = {
      StructField(s"${$(outputNodeIdCol)}", LongType, nullable = false),
      StructField(s"${$(outputCoreIdCol)}", IntegerType, nullable = false)

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

Example 37
Source File: TestData.scala    From drunken-data-quality   with Apache License 2.0 5 votes vote down vote up
package de.frosner.ddq.testutils

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

/**
 * Helpers that build small single-type DataFrames for tests.
 *
 * NOTE(review): the `makeRDD(...)` arguments and part of `makeIntegerDf` are
 * truncated in this listing, and closing braces are missing.
 */
object TestData {

  // Builds a DataFrame with one non-nullable integer column named "column".
  def makeIntegerDf(spark: SparkSession, numbers: Seq[Int]): DataFrame =
      StructType(List(StructField("column", IntegerType, nullable = false)))

  // Builds a DataFrame with one nullable string column named "column".
  def makeNullableStringDf(spark: SparkSession, strings: Seq[String]): DataFrame =
    spark.createDataFrame(spark.sparkContext.makeRDD(, StructType(List(StructField("column", StringType, nullable = true))))

  /**
   * Builds a DataFrame of integer rows with columns named "column1",
   * "column2", ... — one per element of `row1`.
   *
   * @param row1 first row; its size determines the number of columns
   * @param rowN any further rows
   */
  def makeIntegersDf(spark: SparkSession, row1: Seq[Int], rowN: Seq[Int]*): DataFrame = {
    val rows = row1 :: rowN.toList
    // The column count is taken from the first row only.
    val numCols = row1.size
    val rdd = spark.sparkContext.makeRDD(*)))
    val schema = StructType((1 to numCols).map(idx => StructField("column" + idx, IntegerType, nullable = false)))
    spark.createDataFrame(rdd, schema)

Example 38
Source File: ReportContentTestFactory.scala    From seahorse   with Apache License 2.0 5 votes vote down vote up
package ai.deepsense.reportlib.model.factory

import ai.deepsense.reportlib.model.{ReportType, ReportContent}
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}

/**
 * Mix-in that provides a sample `ReportContent` for tests.
 *
 * NOTE(review): the arguments of `ReportContent(...)` are truncated in this
 * listing — only the two distribution-name map keys survive.
 */
trait ReportContentTestFactory {

  import ReportContentTestFactory._

  // Sample report whose map is keyed by the categorical and continuous
  // distribution names defined on the companion object.
  def testReport: ReportContent = ReportContent(
    Map(ReportContentTestFactory.categoricalDistName ->
      ReportContentTestFactory.continuousDistName ->


/**
 * Shared constants for report tests, plus an empty sample report.
 *
 * NOTE(review): the object's closing brace is missing in this listing.
 */
object ReportContentTestFactory extends ReportContentTestFactory {
  // Names used as distribution keys in the sample report.
  val continuousDistName = "continuousDistributionName"
  val categoricalDistName = "categoricalDistributionName"
  val reportName = "TestReportContentName"
  val reportType = ReportType.Empty

  // A report named "empty" of type ReportType.Empty.
  val someReport: ReportContent = ReportContent("empty", ReportType.Empty)
Example 39
Source File: IdentityTransformerSpec.scala    From modelmatrix   with Apache License 2.0 5 votes vote down vote up
package com.collective.modelmatrix.transform

import com.collective.modelmatrix.{ModelFeature, ModelMatrix, TestSparkContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.scalatest.FlatSpec

import scalaz.syntax.either._
import scalaz.{-\/, \/-}

/**
 * Checks `IdentityTransformer.validate` against an integer feature, a
 * feature reported as a missing column, and a string feature that the
 * Identity transform is expected to reject.
 *
 * NOTE(review): closing parentheses/braces are missing in this listing and
 * the `adv_site` values of the input rows are empty strings.
 */
class IdentityTransformerSpec extends FlatSpec with TestSparkContext {

  val sqlContext = ModelMatrix.sqlContext(sc)

  // Input schema: a string `adv_site` column and an integer `adv_id` column.
  val schema = StructType(Seq(
    StructField("adv_site", StringType),
    StructField("adv_id", IntegerType)

  val input = Seq(
      Row("", 1),
      Row("", 2),
      Row("", 1),
      Row("", 3)

  val isActive = true
  val withAllOther = true

  // Feature definitions under test, both using the Identity transform.
  val adSite = ModelFeature(isActive, "Ad", "ad_site", "adv_site", Identity)
  val adId = ModelFeature(isActive, "Ad", "ad_id", "adv_id", Identity)

  val df = sqlContext.createDataFrame(sc.parallelize(input), schema)
  // Fails fast if feature extraction itself reports an error.
  val transformer = new IdentityTransformer(Transformer.extractFeatures(df, Seq(adSite, adId)) match {
    case -\/(err) => sys.error(s"Can't extract features: $err")
    case \/-(suc) => suc

  "Identity Transformer" should "support integer typed model feature" in {
    val valid = transformer.validate(adId)
    assert(valid == TypedModelFeature(adId, IntegerType).right)

  it should "fail if feature column doesn't exists" in {
    val failed = transformer.validate(adSite.copy(feature = "adv_site"))
    assert(failed == FeatureTransformationError.FeatureColumnNotFound("adv_site").left)

  it should "fail if column type is not supported" in {
    val failed = transformer.validate(adSite)
    assert(failed == FeatureTransformationError.UnsupportedTransformDataType("ad_site", StringType, Identity).left)

Example 40
Source File: TransformerSpec.scala    From modelmatrix   with Apache License 2.0 5 votes vote down vote up
package com.collective.modelmatrix.transform

import com.collective.modelmatrix.{ModelFeature, ModelMatrix, TestSparkContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.scalatest.FlatSpec

/**
 * Checks that `Transformer.extractFeatures` reports one error per invalid
 * feature definition, in the order the features were supplied.
 *
 * NOTE(review): closing parentheses/braces are missing in this listing.
 */
class TransformerSpec extends FlatSpec with TestSparkContext {

  val sqlContext = ModelMatrix.sqlContext(sc)

  // Input schema: a string `adv_site` column and an integer `adv_id` column.
  val schema = StructType(Seq(
    StructField("adv_site", StringType),
    StructField("adv_id", IntegerType)

  val input = Seq(
    Row("", 1),
    Row("", 2),
    Row("", 1),
    Row("", 3)

  val isActive = true
  val withAllOther = true

  // Can't call 'day_of_week' with String
  val badFunctionType = ModelFeature(isActive, "advertisement", "f1", "day_of_week(adv_site, 'UTC')", Top(95.0, allOther = false))

  // Not enough parameters for 'concat'
  val wrongParametersCount = ModelFeature(isActive, "advertisement", "f2", "concat(adv_site)", Top(95.0, allOther = false))

  val df = sqlContext.createDataFrame(sc.parallelize(input), schema)

  "Transformer" should "report failed feature extraction" in {
    val features = Transformer.extractFeatures(df, Seq(badFunctionType, wrongParametersCount))
    // Only the error side is expected; a successful extraction aborts the test.
    val errors = features.fold(identity, _ => sys.error("Should not be here"))

    assert(errors.length == 2)
    assert(errors(0).feature == badFunctionType)
    assert(errors(1).feature == wrongParametersCount)

Example 41
Source File: PileupTestBase.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.tests.pileup

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.types.{IntegerType, ShortType, StringType, StructField, StructType}
import org.scalatest.{BeforeAndAfter, FunSuite}

/**
 * Shared base for pileup tests: resolves the test resource paths, declares
 * the expected pileup schema and, before each test, recreates the BAM and
 * CRAM tables and registers two string-conversion UDFs.
 *
 * NOTE(review): this listing is heavily truncated — the `spark.sql(s"""...`
 * openers of the CREATE TABLE statements, the wrapper around the schema
 * fields, part of `mapToString`, and all closing braces are missing.
 */
class PileupTestBase extends FunSuite
  with DataFrameSuiteBase
  with BeforeAndAfter
  with SharedSparkContext{

  // Sample identifier interpolated into the BAM/CRAM resource names below
  // (empty in this listing).
  val sampleId = ""
  val samResPath: String = getClass.getResource("/multichrom/mdbam/samtools.pileup").getPath
  val referencePath: String = getClass.getResource("/reference/Homo_sapiens_assembly18_chr1_chrM.small.fasta").getPath
  val bamPath: String = getClass.getResource(s"/multichrom/mdbam/${sampleId}.bam").getPath
  val cramPath : String = getClass.getResource(s"/multichrom/mdcram/${sampleId}.cram").getPath
  val tableName = "reads_bam"
  val tableNameCRAM = "reads_cram"

  // Six nullable columns: contig, position, reference, coverage, pileup, quality.
  val schema: StructType = StructType(
      StructField("contig", StringType, nullable = true),
      StructField("position", IntegerType, nullable = true),
      StructField("reference", StringType, nullable = true),
      StructField("coverage", ShortType, nullable = true),
      StructField("pileup", StringType, nullable = true),
      StructField("quality", StringType, nullable = true)
  // Per-test setup: configures Kryo registration, then drops and recreates
  // both tables.
  before {
    System.setProperty("spark.kryo.registrator", "org.biodatageeks.sequila.pileup.serializers.CustomKryoRegistrator")
      .conf.set("spark.sql.shuffle.partitions",1) //FIXME: In order to get orderBy in Samtools tests working - related to exchange partitions stage
    spark.sql(s"DROP TABLE IF EXISTS $tableName")
         |CREATE TABLE $tableName
         |USING org.biodatageeks.sequila.datasources.BAM.BAMDataSource
         |OPTIONS(path "$bamPath")

    spark.sql(s"DROP TABLE IF EXISTS $tableNameCRAM")
         |CREATE TABLE $tableNameCRAM
         |USING org.biodatageeks.sequila.datasources.BAM.CRAMDataSource
         |OPTIONS(path "$cramPath", refPath "$referencePath" )

    // Renders a Byte->Short map as text, with keys converted to characters
    // and " -> " replaced by ":".
    val mapToString = (map: Map[Byte, Short]) => {
      if (map == null)
          case (k, v) => k.toChar -> v}).mkString.replace(" -> ", ":")

    val byteToString = ((byte: Byte) => byte.toString)

    // Expose both helpers to SQL under the same names.
    spark.udf.register("mapToString", mapToString)
    spark.udf.register("byteToString", byteToString)

Example 42
Source File: CloudPartitionTest.scala    From cloud-integration   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.hadoop.fs.Path

import org.apache.spark.sql._
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * Abstract relation test that writes data into an explicitly laid out
 * `p1=<n>/p2=<s>` directory tree and reads it back with a schema extended by
 * the `p1` partition column.
 *
 * NOTE(review): this listing is truncated — the test-declaration call, the
 * write step inside the loop, the read call and all closing braces are missing.
 */
abstract class CloudPartitionTest extends AbstractCloudRelationTest {

import testImplicits._

    "Save sets of files in explicitly set up partition tree; read") {
    withTempPathDir("part-columns", None) { path =>
      // Four partition directories: p1 in {1, 2} crossed with p2 in {foo, bar}.
      for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) {
        val partitionDir = new Path(path, s"p1=$p1/p2=$p2")
        // Three rows per partition: (i, "val_i", p1).
        val df = sparkContext
          .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1))
          .toDF("a", "b", "p1")

        // each of these directories as its own success file; there is
        // none at the root
        resolveSuccessFile(partitionDir, true)

      // The data schema plus a nullable integer `p1` column.
      val dataSchemaWithPartition =
          dataSchema.fields :+ StructField("p1", IntegerType, nullable = true))

          "path" -> path.toString,
          "dataSchema" -> dataSchemaWithPartition.json)).format(dataSourceName)
Example 43
Source File: similarityFunctions.scala    From spark-stringmetric   with MIT License 5 votes vote down vote up
package com.github.mrpowers.spark.stringmetric.expressions

import com.github.mrpowers.spark.stringmetric.unsafe.UTF8StringFunctions
import org.apache.commons.text.similarity.CosineDistance
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{
import org.apache.spark.sql.types.{ DataType, IntegerType, StringType }

/**
 * Supplies the fully qualified name of `UTF8StringFunctions` as a string, for
 * interpolation into generated code.
 *
 * NOTE(review): the closing brace is missing in this listing.
 */
trait UTF8StringFunctionsHelper {
  val stringFuncs: String = "com.github.mrpowers.spark.stringmetric.unsafe.UTF8StringFunctions"

/**
 * Mix-in for binary expressions that take two strings and produce an integer.
 * The self-type restricts it to `BinaryExpression` subclasses.
 *
 * NOTE(review): the closing brace is missing in this listing.
 */
trait StringString2IntegerExpression
extends ImplicitCastInputTypes
with NullIntolerant
with UTF8StringFunctionsHelper { self: BinaryExpression =>
  override def dataType: DataType = IntegerType
  // Both inputs are declared as StringType.
  override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

  // Default evaluation returns -1; concrete expressions override this.
  protected override def nullSafeEval(left: Any, right: Any): Any = -1

/**
 * Binary expression named `hamming` that delegates to
 * `UTF8StringFunctions.hammingDistance` for both interpreted evaluation and
 * generated code.
 *
 * @param left  expression producing the first string
 * @param right expression producing the second string
 */
case class HammingDistance(left: Expression, right: Expression)
extends BinaryExpression with StringString2IntegerExpression {
  override def prettyName: String = "hamming"

  /**
   * Interpreted evaluation for non-null inputs.
   *
   * @param leftVal  evaluated value of `left`, expected to be a `UTF8String`
   * @param righValt evaluated value of `right`, expected to be a `UTF8String`
   *                 (parameter name kept as-is for source compatibility)
   * @return the result of `UTF8StringFunctions.hammingDistance`
   */
  override def nullSafeEval(leftVal: Any, righValt: Any): Any = {
    // Cast the evaluated arguments. The `left` / `right` members are the
    // child Expressions, and casting those to UTF8String would fail at runtime.
    val leftStr = leftVal.asInstanceOf[UTF8String]
    val rightStr = righValt.asInstanceOf[UTF8String]
    UTF8StringFunctions.hammingDistance(leftStr, rightStr)
  }

  /** Generates a direct call to the static `hammingDistance` helper. */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    defineCodeGen(ctx, ev, (s1, s2) => s"$stringFuncs.hammingDistance($s1, $s2)")
  }
}
Example 44
Source File: GroupedIteratorSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{LongType, StringType, IntegerType, StructType}

/**
 * Checks `GroupedIterator` on small, already ordered inputs: grouping by one
 * column, grouping by two columns, and counting groups without consuming the
 * per-group value iterators.
 *
 * NOTE(review): the first argument of each `GroupedIterator(` call, the
 * grouping expressions inside `Seq(...)`, the `val result =` expressions and
 * all closing braces are truncated in this listing.
 */
class GroupedIteratorSuite extends SparkFunSuite {

  // Groups three rows by one key column and expects two groups: key 1 with
  // the first two rows and key 2 with the last row.
  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(,
      Seq(', schema.toAttributes)

    val result = {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) ->

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)

  // Groups five rows by an (int, long) key and expects four groups in input order.
  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema)

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(,
      Seq(', ', schema.toAttributes)

    val result = {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1),

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)

  // Only counts the groups; the per-group iterators are never read.
  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(,
      Seq(', schema.toAttributes)

    assert(grouped.length == 2)
Example 45
Source File: ExpandSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Alias, Literal}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.IntegerType

/**
 * Plan tests for the `Expand` operator.
 *
 * NOTE(review): this listing is truncated — the right-hand side of
 * `val attributes`, the call that wraps the plan/expected-answer arguments in
 * `testExpand`, the bodies of the last two tests and all closing braces are
 * missing.
 */
class ExpandSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.localSeqToDataFrameHolder

  /**
   * Runs `Expand` with two projections over the integers 1..1000 and expects
   * each input value to appear once per projection, paired with gid 0 and 1.
   *
   * @param f wrapper applied to the child plan before it is handed to `Expand`
   */
  private def testExpand(f: SparkPlan => SparkPlan): Unit = {
    val input = (1 to 1000).map(Tuple1.apply)
    // Projection i keeps the input column as "id" and adds the literal i as "gid".
    val projections = Seq.tabulate(2) { i =>
      Alias(BoundReference(0, IntegerType, false), "id")() :: Alias(Literal(i), "gid")() :: Nil
    val attributes =
      plan => Expand(projections, attributes, f(plan)),
      input.flatMap(i => Seq.tabulate(2)(j => Row(i._1, j)))

  // Expand over an unsafe-row child must itself report unsafe rows.
  test("inheriting child row type") {
    val exprs = AttributeReference("a", IntegerType, false)() :: Nil
    val plan = Expand(Seq(exprs), exprs, ConvertToUnsafe(LocalTableScan(exprs, Seq.empty)))
    assert(plan.outputsUnsafeRows, "Expand should inherits the created row type from its child.")

  test("expanding UnsafeRows") {

  test("expanding SafeRows") {
Example 46
Source File: SemiJoinSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}

/**
 * Plan tests that run the same left-semi-join scenario through three
 * physical operators: LeftSemiJoinHash, BroadcastLeftSemiJoinHash and
 * LeftSemiJoinBNL.
 *
 * NOTE(review): this listing is truncated — the RDD construction inside both
 * `createDataFrame` calls, the tail of `extractJoinParts`, the expected-answer
 * arguments of `checkAnswer2`, the final `testLeftSemiJoin(` call and all
 * closing braces are missing; parentheses are unbalanced as a result.
 */
class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {

  // Left input with columns (a: Int, b: Double), including null values.
  private lazy val left = sqlContext.createDataFrame(
      Row(1, 2.0),
      Row(1, 2.0),
      Row(2, 1.0),
      Row(2, 1.0),
      Row(3, 3.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("a", IntegerType).add("b", DoubleType))

  // Right input with columns (c: Int, d: Double), including null values.
  private lazy val right = sqlContext.createDataFrame(
      Row(2, 3.0),
      Row(2, 3.0),
      Row(3, 2.0),
      Row(4, 1.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("c", IntegerType).add("d", DoubleType))

  // Join condition: a = c AND b < d.
  private lazy val condition = {
    And((left.col("a") === right.col("c")).expr,
      LessThan(left.col("b").expr, right.col("d").expr))

  // Note: the input dataframes and expression must be evaluated lazily because
  // the SQLContext should be used only within a test to keep SQL tests stable
  /**
   * Registers one test per physical semi-join operator for the given inputs.
   *
   * @param testName       prefix for the generated test names
   * @param leftRows       left input, evaluated lazily (by-name)
   * @param rightRows      right input, evaluated lazily (by-name)
   * @param condition      join condition, evaluated lazily (by-name)
   * @param expectedAnswer rows expected from the join
   */
  private def testLeftSemiJoin(
      testName: String,
      leftRows: => DataFrame,
      rightRows: => DataFrame,
      condition: => Expression,
      expectedAnswer: Seq[Product]): Unit = {

    // Builds a logical inner join over both inputs; the hash-based tests
    // below run only when this returns Some.
    def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))

    test(s"$testName using LeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
              LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
            sortAnswers = true)

    test(s"$testName using BroadcastLeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
            sortAnswers = true)

    // The nested-loop variant takes the whole condition and needs no equi-join keys.
    test(s"$testName using LeftSemiJoinBNL") {
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
          LeftSemiJoinBNL(left, right, Some(condition)),
          sortAnswers = true)

    "basic test",
      (2, 1.0),
      (2, 1.0)
Example 47
Source File: LocalNodeTest.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.local

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Base class for local-node tests: provides a SQLConf, reusable attribute
 * lists and a helper that resolves unresolved attributes against a node.
 *
 * NOTE(review): the body of `resolveExpressions` is truncated in this listing
 * (`val inputMap = { a => (, a) }.toMap { expression =>`, the lookup before
 * `sys.error`) and closing braces are missing.
 */
class LocalNodeTest extends SparkFunSuite {

  protected val conf: SQLConf = new SQLConf
  // Integer key/value attribute pair.
  protected val kvIntAttributes = Seq(
    AttributeReference("k", IntegerType)(),
    AttributeReference("v", IntegerType)())
  // (id1: Int, name: String) and (id2: Int, nickname: String) attribute lists.
  protected val joinNameAttributes = Seq(
    AttributeReference("id1", IntegerType)(),
    AttributeReference("name", StringType)())
  protected val joinNicknameAttributes = Seq(
    AttributeReference("id2", IntegerType)(),
    AttributeReference("nickname", StringType)())

  /**
   * Rewrites single-part `UnresolvedAttribute`s in `expressions`; an
   * attribute that cannot be resolved aborts with an "Invalid Test" error.
   *
   * @param expressions expressions to transform bottom-up
   * @param localNode   node supplying the input attributes — presumably via
   *                    its output; verify against the upstream source
   */
  protected def resolveExpressions(
      expressions: Seq[Expression],
      localNode: LocalNode): Seq[Expression] = {
    val inputMap = { a => (, a) }.toMap { expression =>
      expression.transformUp {
        case UnresolvedAttribute(Seq(u)) =>
            sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))

Example 48
Source File: ProjectNodeSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.local

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Checks that `ProjectNode` keeps only the selected columns (id and name) of
 * a three-column input.
 *
 * NOTE(review): `val expectedOutput = { case ... }` has lost its receiver,
 * the body of the "empty" test is missing, and closing braces are absent in
 * this listing.
 */
class ProjectNodeSuite extends LocalNodeTest {
  // Input attributes: (id: Int, age: Int, name: String).
  private val pieAttributes = Seq(
    AttributeReference("id", IntegerType)(),
    AttributeReference("age", IntegerType)(),
    AttributeReference("name", StringType)())

  /**
   * Projects columns 0 and 2 of `inputData` and compares the collected
   * `(Int, String)` pairs with the expected `(id, name)` pairs.
   *
   * @param inputData rows fed to the dummy input node; defaults to no rows
   */
  private def testProject(inputData: Array[(Int, Int, String)] = Array.empty): Unit = {
    val inputNode = new DummyNode(pieAttributes, inputData)
    val columns = Seq[NamedExpression](inputNode.output(0), inputNode.output(2))
    val projectNode = new ProjectNode(conf, columns, inputNode)
    val expectedOutput = { case (id, age, name) => (id, name) }
    val actualOutput = projectNode.collect().map { case row =>
      (row.getInt(0), row.getString(1))
    assert(actualOutput === expectedOutput)

  test("empty") {

  // 100 rows of the form (i, i + 1, "pie" + i).
  test("basic") {
    testProject((1 to 100).map { i => (i, i + 1, "pie" + i) }.toArray)

Example 49
Source File: AttributeSetSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.IntegerType

/**
 * Checks that `AttributeSet` identifies attributes by expression id rather
 * than by name, and exercises union, de-duplication, subset and equality.
 *
 * NOTE(review): closing braces are missing in this listing and the
 * "extracts all references references" test has lost its assertions.
 */
class AttributeSetSuite extends SparkFunSuite {

  // aUpper and aLower share ExprId(1) but differ in name; fakeA has the same
  // name as aLower but a different id.
  val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
  val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
  val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
  val aSet = AttributeSet(aLower :: Nil)

  // bUpper and bLower share ExprId(2) but differ in name.
  val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
  val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
  val bSet = AttributeSet(bUpper :: Nil)

  val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)

  // The attributes themselves compare unequal even though their ids match.
  test("sanity check") {
    assert(aUpper != aLower)
    assert(bUpper != bLower)

  // Membership follows the expression id: both spellings of A are found,
  // while the same-named attribute with another id is not.
  test("checks by id not name") {
    assert(aSet.contains(aUpper) === true)
    assert(aSet.contains(aLower) === true)
    assert(aSet.contains(fakeA) === false)

    assert(aSet.contains(bUpper) === false)
    assert(aSet.contains(bLower) === false)

  test("++ preserves AttributeSet")  {
    assert((aSet ++ bSet).contains(aUpper) === true)
    assert((aSet ++ bSet).contains(aLower) === true)

  test("extracts all references references") {
    val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil)

  // Two attributes with the same id collapse into one entry.
  test("dedups attributes") {
    assert(AttributeSet(aUpper :: aLower :: Nil).size === 1)

  test("subset") {
    assert(aSet.subsetOf(aAndBSet) === true)
    assert(aAndBSet.subsetOf(aSet) === false)

  // Equality is checked in both directions; a set built from aUpper equals
  // one built from aLower.
  test("equality") {
    assert(aSet != aAndBSet)
    assert(aAndBSet != aSet)
    assert(aSet != bSet)
    assert(bSet != aSet)

    assert(aSet == aSet)
    assert(aSet == AttributeSet(aUpper :: Nil))
Example 50
Source File: GroupedIteratorSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

/**
 * Tests for `GroupedIterator` over rows encoded with a `RowEncoder`.
 *
 * NOTE(review): in this copy every `GroupedIterator(...)` call has an empty first argument,
 * the grouping `Seq` contents are cut off, and the `result` blocks are incomplete. The code
 * is kept as found; restore it from the upstream file before compiling.
 */
class GroupedIteratorSuite extends SparkFunSuite {

  // Three rows with keys 1, 1, 2; the expected result groups them by the Int key.
  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    // NOTE(review): first argument and grouping expression are missing here.
    val grouped = GroupedIterator(,
      Seq(', schema.toAttributes)

    // Each group's key is expected to have exactly one field.
    val result = {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) ->

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)

  // Five rows; the expected result groups them by the (Int, Long) pair.
  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    // NOTE(review): first argument and grouping expressions are missing here.
    val grouped = GroupedIterator(,
      Seq(', ', schema.toAttributes)

    // Each group's key is expected to have exactly two fields.
    val result = {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1),

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)

  // Only the number of groups is checked; the per-group value iterators are not touched.
  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    // NOTE(review): first argument and grouping expression are missing here.
    val grouped = GroupedIterator(,
      Seq(', schema.toAttributes)

    assert(grouped.length == 2)
Example 51
Source File: NullValuesTest.scala    From spark-dynamodb   with Apache License 2.0 5 votes vote down vote up
package com.audienceproject.spark.dynamodb

import{AttributeDefinition, CreateTableRequest, KeySchemaElement, ProvisionedThroughput}
import com.audienceproject.spark.dynamodb.implicits._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

/**
 * Test around inserting rows whose nested struct (or its fields) is null.
 *
 * NOTE(review): this copy is truncated -- several parentheses are unclosed, the write step
 * is absent, and the `validationDs` line is incomplete. Restore from upstream before use.
 */
class NullValuesTest extends AbstractInMemoryTest {

    test("Insert nested StructType with null values") {
        // Create a table keyed on a string attribute "name" (HASH key), throughput 5/5.
        dynamoDB.createTable(new CreateTableRequest()
            .withAttributeDefinitions(new AttributeDefinition("name", "S"))
            .withKeySchema(new KeySchemaElement("name", "HASH"))
            .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L)))

        // Non-nullable "name" plus a nullable "info" struct with nullable age/address.
        val schema = StructType(
                StructField("name", StringType, nullable = false),
                StructField("info", StructType(
                        StructField("age", IntegerType, nullable = true),
                        StructField("address", StringType, nullable = true)
                ), nullable = true)

        // Three cases: populated struct, null struct, struct whose fields are all null.
        val rows = spark.sparkContext.parallelize(Seq(
            Row("one", Row(30, "Somewhere")),
            Row("two", null),
            Row("three", Row(null, null))

        val newItemsDs = spark.createDataFrame(rows, schema)


        // NOTE(review): the expression on the right-hand side is cut off in this copy.
        val validationDs ="NullTest")

Example 52
Source File: UDFTest.scala    From SparkGIS   with Apache License 2.0 5 votes vote down vote up
package org.betterers.spark.gis

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.{SQLContext, Row}
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.betterers.spark.gis.udf.Functions

/**
 * Tests for the geometry functions in `Functions`, called directly and through Spark SQL.
 *
 * NOTE(review): most block bodies (`before`, `after`, every `assertResult`) are empty or
 * unclosed in this copy; the code is kept as found.
 */
class UDFTest extends FunSuite with BeforeAndAfter {
  import Geometry.WGS84

  // One sample geometry per kind, plus `all` listing every one of them.
  val point = Geometry.point((2.0, 2.0))
  val multiPoint = Geometry.multiPoint((1.0, 1.0), (2.0, 2.0), (3.0, 3.0))
  var line = Geometry.line((11.0, 11.0), (12.0, 12.0))
  var multiLine = Geometry.multiLine(
    Seq((11.0, 1.0), (23.0, 23.0)),
    Seq((31.0, 3.0), (42.0, 42.0)))
  var polygon = Geometry.polygon((1.0, 1.0), (2.0, 2.0), (3.0, 1.0))
  var multiPolygon = Geometry.multiPolygon(
    Seq((1.0, 1.0), (2.0, 2.0), (3.0, 1.0)),
    Seq((1.1, 1.1), (2.0, 1.9), (2.5, 1.1))
  val collection = Geometry.collection(point, multiPoint, line)
  val all: Seq[Geometry] = Seq(point, multiPoint, line, multiLine, polygon, multiPolygon, collection)

  // Assigned in `before`.
  var sc: SparkContext = _
  var sql: SQLContext = _

  // Creates a local 4-thread SparkContext named "SparkGIS" and an SQLContext on top of it.
  before {
    sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("SparkGIS"))
    sql = new SQLContext(sc)

  // NOTE(review): teardown body is missing in this copy.
  after {

  // NOTE(review): only the expected values remain; the expressions being checked are missing.
  test("ST_Boundary") {
    // all.foreach(g => println(Functions.ST_Boundary(g).toString))

    assertResult(true) {
    assertResult(true) {
    assertResult("Some(MULTIPOINT ((11 11), (12 12)))") {
    assertResult(None) {
    assertResult("Some(LINEARRING (1 1, 2 2, 3 1, 1 1))") {
    assertResult(None) {
    assertResult(None) {

  // Expects 3 for every geometry in `all`; the checked expression is missing in this copy.
  test("ST_CoordDim") {
    all.foreach(g => {
      assertResult(3) {

  // Builds an (id, geo) schema and two JSON records, then queries ST_CoordDim through SQL.
  test("UDF in SQL") {
    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("geo", GeometryType.Instance)
    val jsons = Map(
      (1, "{\"type\":\"Point\",\"coordinates\":[1,1]}}"),
      (2, "{\"type\":\"LineString\",\"coordinates\":[[12,13],[15,20]]}}")
    val rdd = sc.parallelize(Seq(
      "{\"id\":1,\"geo\":" + jsons(1) + "}",
      "{\"id\":2,\"geo\":" + jsons(2) + "}"
    )) = "TEST"
    // NOTE(review): the DataFrame construction / table registration is cut off here.
    val df =
    assertResult(Array(3,3)) {
      sql.sql("SELECT ST_CoordDim(geo) FROM TEST").collect().map(_.get(0))
Example 53
Source File: TemporalDataSuite.scala    From datasource-receiver   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.datasource

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.datasource.config.ConfigParameters._
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.BeforeAndAfter

/**
 * Shared fixture for datasource streaming suites: Spark configuration, contexts, datasource
 * parameters, a two-column schema and generated rows.
 *
 * NOTE(review): several closing parentheses/braces and the stop calls inside `after` are
 * missing in this copy; the code is kept as found.
 */
private[datasource] trait TemporalDataSuite extends DatasourceSuite
  with BeforeAndAfter {

  // Uses local[*] only when no master has been set already.
  val conf = new SparkConf()
    .setIfMissing("spark.master", "local[*]")
  // Start as null; `after` resets them to null when they were set.
  var sc: SparkContext = null
  var ssc: StreamingContext = null
  val tableName = "tableName"
  // Keys come from ConfigParameters (imported by the enclosing file).
  val datasourceParams = Map(
    StopGracefully -> "true",
    StopSparkContext -> "false",
    StorageLevelKey -> "MEMORY_ONLY",
    RememberDuration -> "15s"
  // "id" holds a string and "idInt" an integer; both nullable.
  val schema = new StructType(Array(
    StructField("id", StringType, nullable = true),
    StructField("idInt", IntegerType, nullable = true)
  val totalRegisters = 10000
  // One Row(a.toString, a) per a in 1..totalRegisters, matching `schema`.
  val registers = for (a <- 1 to totalRegisters) yield Row(a.toString, a)

  // Clears both context references after each test.
  // NOTE(review): any stop() calls that preceded the null assignments are not visible here.
  after {
    if (ssc != null) {
      ssc = null
    if (sc != null) {
      sc = null
Example 54
Source File: RecoverPartitionsCustomTest.scala    From m3d-engine   with Apache License 2.0 5 votes vote down vote up

import com.adidas.utils.SparkSessionWrapper
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers, PrivateMethodTester}

import scala.collection.JavaConverters._

/**
 * Tests two private methods of `RecoverPartitionsCustom`, reached through ScalaTest's
 * `PrivateMethodTester`: `createParameterValue` and `generateAddPartitionStatements`.
 *
 * NOTE(review): closing braces/parentheses and the body of `afterAll` are missing in this
 * copy; the code is kept as found.
 */
class RecoverPartitionsCustomTest extends FunSuite
  with SparkSessionWrapper
  with PrivateMethodTester
  with Matchers
  with BeforeAndAfterAll{

  // A String value is expected back wrapped in single quotes.
  test("test conversion of String Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue("theValue")

    result should be("'theValue'")

  // A boxed Short is expected back as its plain digits, without quotes.
  test("test conversion of Short Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Short.valueOf("2"))

    result should be("2")

  // A boxed Integer is expected back as its plain digits, without quotes.
  test("test conversion of Integer Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Integer.valueOf("4"))

    result should be("4")

  // Passing null is expected to throw.
  test("test conversion of null Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(null)

  // Passing a Boolean is expected to throw.
  test("test conversion of not supported Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(false)

  // Expects one ADD PARTITION statement per input row, over (country, district).
  test("test HiveQL statements Generation") {
    // NOTE(review): the tableName argument is not visible here; the expected statements
    // use table "test" -- confirm against upstream.
    val customSparkRecoverPartitions = RecoverPartitionsCustom(
      targetPartitions = Seq("country","district")

    val rowsInput = Seq(
      Row(1, "portugal", "porto"),
      Row(2, "germany", "herzogenaurach"),
      Row(3, "portugal", "coimbra")

    val inputSchema = StructType(
        StructField("number", IntegerType, nullable = true),
        StructField("country", StringType, nullable = true),
        StructField("district", StringType, nullable = true)

    val expectedStatements: Seq[String] = Seq(
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='porto')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='germany',district='herzogenaurach')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='coimbra')"

    val testDataset: Dataset[Row] = spark.createDataset(rowsInput)(RowEncoder(inputSchema))

    // Despite its name, this handle refers to the private method generateAddPartitionStatements.
    val createParameterValue = PrivateMethod[Dataset[String]]('generateAddPartitionStatements)

    // NOTE(review): the declared type is Seq[String] while the private method is typed
    // Dataset[String]; a collecting step appears to have been stripped from this line.
    val producedStatements: Seq[String] = (customSparkRecoverPartitions invokePrivate createParameterValue(testDataset))

    // Compared as sets, so statement order does not matter.
    expectedStatements.sorted.toSet should equal(producedStatements.sorted.toSet)

  // NOTE(review): body missing in this copy.
  override def afterAll(): Unit = {

Example 55
Source File: SparkRecoverPartitionsCustomTest.scala    From m3d-engine   with Apache License 2.0 5 votes vote down vote up

import com.adidas.utils.SparkSessionWrapper
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers, PrivateMethodTester}

import scala.collection.JavaConverters._

/**
 * Tests two private methods of `SparkRecoverPartitionsCustom`, reached through ScalaTest's
 * `PrivateMethodTester`: `createParameterValue` and `generateAddPartitionStatements`.
 *
 * NOTE(review): closing braces/parentheses and the body of `afterAll` are missing in this
 * copy; the code is kept as found.
 */
class SparkRecoverPartitionsCustomTest extends FunSuite
  with SparkSessionWrapper
  with PrivateMethodTester
  with Matchers
  with BeforeAndAfterAll{

  // A String value is expected back wrapped in single quotes.
  test("test conversion of String Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue("theValue")

    result should be("'theValue'")

  // A boxed Short is expected back as its plain digits, without quotes.
  test("test conversion of Short Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Short.valueOf("2"))

    result should be("2")

  // A boxed Integer is expected back as its plain digits, without quotes.
  test("test conversion of Integer Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Integer.valueOf("4"))

    result should be("4")

  // Passing null is expected to throw.
  test("test conversion of null Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(null)

  // Passing a Boolean is expected to throw.
  test("test conversion of not supported Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(false)

  // Expects one ADD PARTITION statement per input row, over (country, district).
  test("test HiveQL statements Generation") {
    // NOTE(review): the tableName argument is not visible here; the expected statements
    // use table "test" -- confirm against upstream.
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(
      targetPartitions = Seq("country","district")

    val rowsInput = Seq(
      Row(1, "portugal", "porto"),
      Row(2, "germany", "herzogenaurach"),
      Row(3, "portugal", "coimbra")

    val inputSchema = StructType(
        StructField("number", IntegerType, nullable = true),
        StructField("country", StringType, nullable = true),
        StructField("district", StringType, nullable = true)

    val expectedStatements: Seq[String] = Seq(
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='porto')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='germany',district='herzogenaurach')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='coimbra')"

    val testDataset: Dataset[Row] = spark.createDataset(rowsInput)(RowEncoder(inputSchema))

    // Despite its name, this handle refers to the private method generateAddPartitionStatements.
    val createParameterValue = PrivateMethod[Dataset[String]]('generateAddPartitionStatements)

    // NOTE(review): the declared type is Seq[String] while the private method is typed
    // Dataset[String]; a collecting step appears to have been stripped from this line.
    val producedStatements: Seq[String] = (customSparkRecoverPartitions invokePrivate createParameterValue(testDataset))

    // Compared as sets, so statement order does not matter.
    expectedStatements.sorted.toSet should equal(producedStatements.sorted.toSet)

  // NOTE(review): body missing in this copy.
  override def afterAll(): Unit = {

Example 56
Source File: ReferenceTableTest.scala    From memsql-spark-connector   with Apache License 2.0 5 votes vote down vote up
package com.memsql.spark

import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.{DataFrame, SaveMode}

import scala.util.Try

/**
 * Integration tests that write to and read from a common table and a reference table
 * while a child aggregator is configured as the DML endpoint.
 *
 * NOTE(review): most method and `it` bodies are partly or wholly missing in this copy
 * (receiver expressions, write/read calls, closing braces); the code is kept as found.
 */
class ReferenceTableTest extends IntegrationSuiteBase {

  val childAggregatorHost = "localhost"
  val childAggregatorPort = "5508"

  val dbName                  = "testdb"
  val commonCollectionName    = "test_table"
  val referenceCollectionName = "reference_table"

  // NOTE(review): the receiver of `.set(...)` is missing in this copy.
  override def beforeEach(): Unit = {

    // Set child aggregator as a dmlEndpoint
      .set("spark.datasource.memsql.dmlEndpoints", s"${childAggregatorHost}:${childAggregatorPort}")

  // Builds a three-row DataFrame with one nullable Int column "id".
  // NOTE(review): the write to `tableName` is not visible here.
  def writeToTable(tableName: String): Unit = {
    val df = spark.createDF(
      List(4, 5, 6),
      List(("id", IntegerType, true))

  // NOTE(review): body missing in this copy.
  def readFromTable(tableName: String): DataFrame = {

  // Reads the table back and expects exactly three rows.
  def writeAndReadFromTable(tableName: String): Unit = {
    val dataFrame = readFromTable(tableName)
    val sqlRows   = dataFrame.collect();
    assert(sqlRows.length == 3)

  // Drops the table from dbName if it exists.
  def dropTable(tableName: String): Unit = executeQuery(s"drop table if exists $dbName.$tableName")

  describe("Success during write operations") {

    // NOTE(review): only the SQL string argument remains; the call around it is missing.
    it("to common table") {
        s"create table if not exists $dbName.$commonCollectionName (id INT NOT NULL, PRIMARY KEY (id))")

    // NOTE(review): only the SQL string argument remains; the call around it is missing.
    it("to reference table") {
        s"create reference table if not exists $dbName.$referenceCollectionName (id INT NOT NULL, PRIMARY KEY (id))")

  describe("Success during creating") {

    // NOTE(review): body missing in this copy.
    it("common table") {

  describe("Failure because of") {

    // With the database option set to "", the failure is expected to carry SQL code 1046.
    it("database name not specified") {
      spark.conf.set("spark.datasource.memsql.database", "")
      val df = spark.createDF(
        List(4, 5, 6),
        List(("id", IntegerType, true))
      // NOTE(review): the attempted operation inside Try is missing in this copy.
      val result = Try {
      assert(SQLHelper.isSQLExceptionWithCode(result.failed.get, List(1046)))
Example 57
Source File: OutputMetricsTest.scala    From memsql-spark-connector   with Apache License 2.0 5 votes vote down vote up
package com.memsql.spark

import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Checks the `recordsWritten` output metric reported by Spark tasks.
 *
 * NOTE(review): the write calls that should follow each DataFrame are missing in this copy,
 * as are several closing braces; the code is kept as found.
 */
class OutputMetricsTest extends IntegrationSuiteBase {
  it("records written") {
    // Running total, updated from the listener callback below.
    var outputWritten = 0L
    spark.sparkContext.addSparkListener(new SparkListener() {
      // Adds each finished task's recordsWritten to the running total.
      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
        val metrics = taskEnd.taskMetrics
        outputWritten += metrics.outputMetrics.recordsWritten

    // First case: numRows integers in one nullable Int column "id".
    val numRows = 100000
    val df1 = spark.createDF(
      List.range(0, numRows),
      List(("id", IntegerType, true))



    assert(outputWritten == numRows)
    // Reset the counter before the second case.
    outputWritten = 0

    // Second case: three string values, including an empty string and a null.
    val df2 = spark.createDF(
      List("st1", "", null),
      List(("st", StringType, true))


    assert(outputWritten == 3)
Source File: BinaryTypeBenchmark.scala    From memsql-spark-connector   with Apache License 2.0 5 votes vote down vote up
package com.memsql.spark

import java.sql.{Connection, DriverManager}
import java.util.Properties

import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
import com.memsql.spark.BatchInsertBenchmark.{df, executeQuery}
import org.apache.spark.sql.types.{BinaryType, IntegerType}
import org.apache.spark.sql.{SaveMode, SparkSession}

import scala.util.Random

// BinaryTypeBenchmark is written to profile the writing of BinaryType data with a CPU profiler.
// That profiler is available in the Ultimate edition of IntelliJ IDEA.
// NOTE(review): the original comment ended with "see for more details"; the link it pointed
// to is not present in this copy.
// NOTE(review): the builder calls, DataFrame contents and write statements are missing below;
// the code is kept as found.
object BinaryTypeBenchmark extends App {
  // NOTE(review): the property key for masterHost is an empty string here, unlike
  // "memsql.port" below -- it looks stripped; confirm against upstream.
  final val masterHost: String = sys.props.getOrElse("", "localhost")
  final val masterPort: String = sys.props.getOrElse("memsql.port", "5506")

  // Session configured with a single shuffle partition and the DDL endpoint/database above.
  val spark: SparkSession = SparkSession
    .config("spark.sql.shuffle.partitions", "1")
    .config("spark.driver.bindAddress", "localhost")
    .config("spark.datasource.memsql.ddlEndpoint", s"${masterHost}:${masterPort}")
    .config("spark.datasource.memsql.database", "testdb")

  // Prepares connection properties with user "root"; the rest of the body is missing.
  def jdbcConnection: Loan[Connection] = {
    val connProperties = new Properties()
    connProperties.put("user", "root")


  // Runs one SQL statement through a loaned statement object.
  // NOTE(review): the start of this expression is cut off in this copy.
  def executeQuery(sql: String): Unit = { => Loan(conn.createStatement).to(_.execute(sql)))

  // Recreate the test database from scratch.
  executeQuery("set global default_partitions_per_leaf = 2")
  executeQuery("drop database if exists testdb")
  executeQuery("create database testdb")

  // Maps a random int in [0, 256) onto the signed Byte range.
  def genRandomByte(): Byte = (Random.nextInt(256) - 128).toByte
  // NOTE(review): body missing in this copy.
  def genRandomRow(): Array[Byte] =

  // Two nullable columns: binary "data" and integer "id".
  val df = spark.createDF(
    List(("data", BinaryType, true), ("id", IntegerType, true))

  // Each section below prints the elapsed nanoseconds for one write path, labelled in the
  // message; the write statements themselves are not visible here.
  val start1 = System.nanoTime()

  println("Elapsed time: " + (System.nanoTime() - start1) + "ns [LoadData CSV]")

  val start2 = System.nanoTime()
    .option("tableKey.primary", "id")
    .option("onDuplicateKeySQL", "id = id")

  println("Elapsed time: " + (System.nanoTime() - start2) + "ns [BatchInsert]")

  val avroStart = System.nanoTime()
    .option(MemsqlOptions.LOAD_DATA_FORMAT, "Avro")
  println("Elapsed time: " + (System.nanoTime() - avroStart) + "ns [LoadData Avro] ")
Source File: PopulateHiveTable.scala    From spark_training   with Apache License 2.0 5 votes vote down vote up

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructType}

/**
 * Example that builds the same nested rows into a DataFrame twice: once with a schema read
 * from an empty Hive table, once with a schema defined in code.
 *
 * NOTE(review): the end of the session builder chain, the call that wraps the rows into an
 * RDD, and the closing braces are missing in this copy; the code is kept as found.
 */
object PopulateHiveTable {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder
      .config("spark.some.config.option", "config-value")
      .config("spark.sql.parquet.compression.codec", "gzip")

    // Table with an int, a string, and an array of (int, string) structs.
    spark.sql("create table IF NOT EXISTS nested_empty " +
      "( A int, " +
      "  B string, " +
      "  nested ARRAY<STRUCT< " +
      "     nested_C: int," +
      "     nested_D: string" +
      "  >>" +
      ") ")

    // Three rows, each carrying two nested rows.
    val rowRDD = spark.sparkContext.
        Row(1, "foo", Seq(Row(1, "barA"),Row(2, "bar"))),
        Row(2, "foo", Seq(Row(1, "barB"),Row(2, "bar"))),
        Row(3, "foo", Seq(Row(1, "barC"),Row(2, "bar")))))

    // `limit 0` returns no rows; only the resulting schema is used.
    val emptyDf = spark.sql("select * from nested_empty limit 0")

    val tableSchema = emptyDf.schema

    val populated1Df = spark.sqlContext.createDataFrame(rowRDD, tableSchema)


    populated1Df.collect().foreach(r => println(" emptySchemaExample:" + r))

    // The same shape as the table, declared programmatically.
    val nestedSchema = new StructType()
      .add("nested_C", IntegerType)
      .add("nested_D", StringType)

    val definedSchema = new StructType()
      .add("A", IntegerType)
      .add("B", StringType)
      .add("nested", ArrayType(nestedSchema))

    val populated2Df = spark.sqlContext.createDataFrame(rowRDD, definedSchema)

    // NOTE(review): this prints populated1Df under the "BuiltExample" label while
    // populated2Df is never read -- possibly intended to be populated2Df; confirm.
    populated1Df.collect().foreach(r => println(" BuiltExample:" + r))

Source File: NestedTableExample.scala    From spark_training   with Apache License 2.0 5 votes vote down vote up

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructType}
import org.apache.spark.sql.{Row, SparkSession}

/**
 * Example that builds the same nested rows into a DataFrame twice: once with a schema read
 * from an empty Hive table, once with a schema defined in code.
 *
 * NOTE(review): the end of the session builder chain, the call that wraps the rows into an
 * RDD, and the closing braces are missing in this copy; the code is kept as found.
 */
object NestedTableExample {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder
      .config("spark.some.config.option", "config-value")

    // Table with an int, a string, and an array of (int, string) structs.
    spark.sql("create table IF NOT EXISTS nested_empty " +
      "( A int, " +
      "  B string, " +
      "  nested ARRAY<STRUCT< " +
      "     nested_C: int," +
      "     nested_D: string" +
      "  >>" +
      ") ")

    // Three rows, each carrying two nested rows.
    val rowRDD = spark.sparkContext.
        Row(1, "foo", Seq(Row(1, "barA"),Row(2, "bar"))),
        Row(2, "foo", Seq(Row(1, "barB"),Row(2, "bar"))),
        Row(3, "foo", Seq(Row(1, "barC"),Row(2, "bar")))))

    // `limit 0` returns no rows; only the resulting schema is used.
    val emptyDf = spark.sql("select * from nested_empty limit 0")

    val tableSchema = emptyDf.schema

    val populated1Df = spark.sqlContext.createDataFrame(rowRDD, tableSchema)

    populated1Df.collect().foreach(r => println(" emptySchemaExample:" + r))

    // The same shape as the table, declared programmatically.
    val nestedSchema = new StructType()
      .add("nested_C", IntegerType)
      .add("nested_D", StringType)

    val definedSchema = new StructType()
      .add("A", IntegerType)
      .add("B", StringType)
      .add("nested", ArrayType(nestedSchema))

    val populated2Df = spark.sqlContext.createDataFrame(rowRDD, definedSchema)
    // NOTE(review): this prints populated1Df under the "BuiltExample" label while
    // populated2Df is never read -- possibly intended to be populated2Df; confirm.
    populated1Df.collect().foreach(r => println(" BuiltExample:" + r))

Source File: CarbonDataFrameExample.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.examples

import org.apache.spark.sql.{SaveMode, SparkSession}

import org.apache.carbondata.examples.util.ExampleUtils

/**
 * Example of writing a DataFrame to a CarbonData table and reading it back with an
 * explicit schema.
 *
 * NOTE(review): the write/read call chains are only partly present in this copy (receivers
 * and terminal calls are missing); the code is kept as found.
 */
object CarbonDataFrameExample {

  // Creates the session; the rest of this method's body is not visible in this copy.
  def main(args: Array[String]) {
    val spark = ExampleUtils.createSparkSession("CarbonDataFrameExample")

  // Runs the write / query / read-back sequence against table carbon_df_table.
  def exampleBody(spark : SparkSession): Unit = {
    // Writes Dataframe to CarbonData file:
    import spark.implicits._
    // 100 rows: c1 cycles through "a0".."a9", c2 is always "b", number is 1..100.
    val df = spark.sparkContext.parallelize(1 to 100)
      .map(x => ("a" + x % 10, "b", x))
      .toDF("c1", "c2", "number")

    // Saves dataframe to carbondata file
      .option("tableName", "carbon_df_table")
      .option("partitionColumns", "c1")  // a list of column names

    spark.sql(""" SELECT * FROM carbon_df_table """).show()

    spark.sql("SHOW PARTITIONS carbon_df_table").show()

    // Specify schema
    import org.apache.spark.sql.types.{StructType, StructField, StringType, IntegerType}
    val customSchema = StructType(Array(
      StructField("c1", StringType),
      StructField("c2", StringType),
      StructField("number", IntegerType)))

    // Reads carbondata to dataframe
    val carbondf =
      // .option("dbname", "db_name") the system will use "default" as dbname if not set this option
      .option("tableName", "carbon_df_table")

    // Dataframe operations
    // NOTE(review): the next line looks like two statements fused together (a printSchema()
    // call and a projection of c1 and number + 10); confirm against upstream.
    carbondf.printSchema()$"c1", $"number" + 10).show()
    carbondf.filter($"number" > 31).show()

    // Clean up the example table.
    spark.sql("DROP TABLE IF EXISTS carbon_df_table")
Source File: BlockingSource.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.streaming.util

import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.{SQLContext, _}
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * Streaming source and sink provider for tests; the companion object holds a
 * `CountDownLatch` reference.
 *
 * NOTE(review): closing braces and the bodies of `getBatch` and `createSink` are incomplete
 * in this copy, and no use of the latch is visible; the code is kept as found.
 */
class BlockingSource extends StreamSourceProvider with StreamSinkProvider {

  // Single-column schema ("a": Int) returned by every source created here.
  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  // Always reports the name "dummySource" with fakeSchema; the arguments are not read.
  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", fakeSchema)

  // Returns an anonymous Source whose offset is always LongOffset(0).
  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    new Source {
      override def schema: StructType = fakeSchema
      override def getOffset: Option[Offset] = Some(new LongOffset(0))
      // NOTE(review): only the implicits import remains of this method body.
      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
        import spark.implicits._
      override def stop() {}

  // Returns an anonymous Sink; the visible addBatch does nothing.
  override def createSink(
      spark: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink {
      override def addBatch(batchId: Long, data: DataFrame): Unit = {}

object BlockingSource {
  // Mutable latch reference, initially null; presumably set by the tests that use this
  // source -- verify against callers.
  var latch: CountDownLatch = null
Source File: StratifiedRepartitionSuite.scala    From mmlspark   with MIT License 5 votes vote down vote up
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.


import{TestObject, TransformerFuzzing}
import org.apache.spark.TaskContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

/**
 * Tests for `StratifiedRepartition`: after repartitioning, every label value should be
 * present on every partition.
 *
 * NOTE(review): several branches, transform calls and closing braces are missing in this
 * copy; the code is kept as found.
 */
class StratifiedRepartitionSuite extends TestBase with TransformerFuzzing[StratifiedRepartition] {

  import session.implicits._

  // Column names used throughout the suite.
  val values = "values"
  val colors = "colors"
  val const = "const"

  // Twelve rows: labels 0..3 with three colours each, and a constant third column of 2.
  lazy val input = Seq(
    (0, "Blue", 2),
    (0, "Red", 2),
    (0, "Green", 2),
    (1, "Purple", 2),
    (1, "Orange", 2),
    (1, "Indigo", 2),
    (2, "Violet", 2),
    (2, "Black", 2),
    (2, "White", 2),
    (3, "Gray", 2),
    (3, "Yellow", 2),
    (3, "Cerulean", 2)
  ).toDF(values, colors, const)

  test("Assert doing a stratified repartition will ensure all keys exist across all partitions") {
    val inputSchema = new StructType()
      .add(values, IntegerType).add(colors, StringType).add(const, IntegerType)
    val inputEnc = RowEncoder(inputSchema)
    val valuesFieldIndex = inputSchema.fieldIndex(values)
    val numPartitions = 3
    // Skew the data by partition id so that the label distribution differs per partition.
    val trainData = input.repartition(numPartitions).select(values, colors, const)
      .mapPartitions(iter => {
        val ctx = TaskContext.get
        val partId = ctx.partitionId
        // Remove all instances of 0 class on partition 1
        if (partId == 1) {
          iter.flatMap(row => {
            // NOTE(review): the result of this `if` branch is missing in this copy.
            if (row.getInt(valuesFieldIndex) <= 0)
            else Some(row)
        } else {
          // Add back at least 3 instances on other partitions
          val oneOfEachExample = List(Row(0, "Blue", 2), Row(1, "Purple", 2), Row(2, "Black", 2), Row(3, "Gray", 2))
    // Some debug to understand what data is on which partition
    trainData.foreachPartition { rows =>
      rows.foreach { row =>
        val ctx = TaskContext.get
        val partId = ctx.partitionId
        println(s"Row: $row partition id: $partId")
    // NOTE(review): the mode setting and transform call on this line's chain are not visible.
    val stratifiedInputData = new StratifiedRepartition().setLabelCol(values)
    // Assert stratified data contains all keys across all partitions, with extra count
    // for it to be evaluated
      .mapPartitions(iter => {
        // NOTE(review): the start of this expression is cut off in this copy.
        val actualLabels = => row.getInt(valuesFieldIndex))
        val expectedLabels = (0 to 3).toList
        if (actualLabels != expectedLabels)
          throw new Exception(s"Missing labels, actual: $actualLabels, expected: $expectedLabels")
    // This output is expected to have at least as many rows as the training data.
    val stratifiedMixedInputData = new StratifiedRepartition().setLabelCol(values)
    assert(stratifiedMixedInputData.count() >= trainData.count())
    // This output is expected to have exactly as many rows as the training data.
    val stratifiedOriginalInputData = new StratifiedRepartition().setLabelCol(values)
    assert(stratifiedOriginalInputData.count() == trainData.count())

  // Single test object: the `input` frame with mode SPConstants.Equal.
  def testObjects(): Seq[TestObject[StratifiedRepartition]] = List(new TestObject(
    new StratifiedRepartition().setLabelCol(values).setMode(SPConstants.Equal), input))

  // Reader taken from the StratifiedRepartition companion.
  def reader: MLReadable[_] = StratifiedRepartition
Source File: WholeStageCodegenSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

/**
 * Checks that various operators end up inside a `WholeStageCodegenExec` node in the
 * executed plan, and that the query results are still correct.
 *
 * NOTE(review): in this copy every `plan.find(p => ... &&` predicate is cut off after the
 * `isInstanceOf` check and the closing braces are missing; the code is kept as found.
 */
class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {

  // range(10) filtered to id = 1, projected as id + 1: expects the single row 2.
  test("range/filter should be combined") {
    val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
    val plan = df.queryExecution.executedPlan
    assert(df.collect() === Array(Row(2)))

  // Global max and avg over range(10): expects (9, 4.5).
  test("Aggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(df.collect() === Array(Row(9, 4.5)))

  // Count per id over range(3), ordered by id: each id appears once.
  test("Aggregate with grouping keys should be included in WholeStageCodegen") {
    val df = spark.range(3).groupBy("id").count().orderBy("id")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))

  // Joins range(10) with a broadcast three-row (k, v) frame on k === id.
  test("BroadcastHashJoin should be included in WholeStageCodegen") {
    val rdd = spark.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
    val schema = new StructType().add("k", IntegerType).add("v", StringType)
    val smallDF = spark.createDataFrame(rdd, schema)
    val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id"))
    assert(df.queryExecution.executedPlan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))

  // Sorts the descending range 3, 2, 1 into ascending order.
  test("Sort should be included in WholeStageCodegen") {
    val df = spark.range(3, 0, -1).toDF().sort(col("id"))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(df.collect() === Array(Row(1), Row(2), Row(3)))

  // Typed map from each id to its string form.
  test("MapElements should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = spark.range(10).map(_.toString)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(ds.collect() === 0.until(10).map(_.toString).toArray)

  // Typed filter keeping even ids.
  test("typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(ds.collect() === Array(0, 2, 4, 6, 8))

  // Two chained typed filters: ids divisible by both 2 and 3.
  test("back-to-back typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(ds.collect() === Array(0, 6))

  // NOTE(review): the grouping/aggregation step is missing from `ds` in this copy; the
  // expected values are the per-key sums of the input as doubles.
  test("simple typed UDAF should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()

    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
Source File: resources.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command


import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

/**
 * Runnable command that returns jar paths known to the SparkContext as rows with a single
 * non-nullable string column named "Results".
 *
 * @param jars jar names to look for; when non-empty, only paths from `listJars()` that contain
 *             one of these names are yielded. Defaults to an empty sequence.
 */
case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
  override val output: Seq[Attribute] = {
    AttributeReference("Results", StringType, nullable = false)() :: Nil
  // NOTE(review): the closing brace of `output` above is missing from this excerpt.
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val jarList = sparkSession.sparkContext.listJars()
    if (jars.nonEmpty) {
      for {
        // NOTE(review): the generator below is garbled — its source expression is missing;
        // presumably it maps each entry of `jars` to its file name via `new Path(f).getName`.
        jarName <- => new Path(f).getName)
        jarPath <- jarList if jarPath.contains(jarName)
      } yield Row(jarPath)
    // NOTE(review): the body of the else branch is cut off in this excerpt.
    } else {
Source File: SimplifyConditionalSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{IntegerType, NullType}

/**
 * Plan tests for the `SimplifyConditionals` optimizer rule, covering `If` and `CaseWhen`
 * expressions.
 *
 * NOTE(review): this excerpt is truncated — closing braces are missing throughout and the test
 * bodies lack the enclosing call (presumably `assertEquivalent(...)`) around their arguments.
 */
class SimplifyConditionalSuite extends PlanTest with PredicateHelper {

  // Rule executor that runs only SimplifyConditionals, in a FixedPoint(50) batch.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil

  // Wraps both expressions in a Project over OneRowRelation (aliased as "out"), optimizes the
  // plan built from `e1`, and compares it with the analyzed plan built from `e2`.
  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
    comparePlans(actual, correctAnswer)

  // (condition, value) branch fixtures used to build CaseWhen expressions below.
  private val trueBranch = (TrueLiteral, Literal(5))
  private val normalBranch = (NonFoldableLiteral(true), Literal(10))
  private val unreachableBranch = (FalseLiteral, Literal(20))
  private val nullBranch = (Literal.create(null, NullType), Literal(30))

  // If with a true, false, and null literal condition respectively.
  // NOTE(review): the expected results for each case are missing from this excerpt.
  test("simplify if") {
      If(TrueLiteral, Literal(10), Literal(20)),

      If(FalseLiteral, Literal(10), Literal(20)),

      If(Literal.create(null, NullType), Literal(10), Literal(20)),

  test("remove unreachable branches") {
    // i.e. removing branches whose conditions are always false
      CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
      CaseWhen(normalBranch :: Nil, None))

  // NOTE(review): the expected value for the first CaseWhen is missing from this excerpt; the
  // second pairs a CaseWhen with no else value against a null IntegerType literal.
  test("remove entire CaseWhen if only the else branch is reachable") {
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),

      CaseWhen(unreachableBranch :: unreachableBranch :: Nil, None),
      Literal.create(null, IntegerType))

  test("remove entire CaseWhen if the first branch is always true") {
      CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),

    // Test branch elimination and simplification in combination
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
        :: Nil, None),

    // Make sure this doesn't trigger if there is a non-foldable branch before the true branch
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None),
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None))
Source File: RewriteDistinctAggregatesSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Plan tests for `RewriteDistinctAggregates`.
 *
 * NOTE(review): this excerpt is truncated — closing braces are missing throughout, several
 * `groupBy` argument lists are cut off, and the later tests show no assertions.
 */
class RewriteDistinctAggregatesSuite extends PlanTest {
  // Case-insensitive analyzer backed by an in-memory catalog and an empty function registry.
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  // NOTE(review): the line below is cut off mid-argument (it ends with a dangling `'`); the
  // tests reference a column 'e, so a fifth attribute is presumably declared here.
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, '

  // Passes when the plan has the shape Aggregate(Aggregate(Expand)); fails with the plan text
  // otherwise.
  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")

  // The two tests below compare the input with the rule's output via comparePlans.
  // NOTE(review): in the first one `input` is shown as the bare relation; its aggregation
  // appears to be missing from this excerpt.
  test("single distinct group") {
    val input = testRelation
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)

  // NOTE(review): the remaining tests only build `input` in this excerpt; the rewrite and the
  // check (presumably checkRewrite) are not shown.
  test("single distinct group with non-partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
        countDistinct('b, 'c),
Example 68
Source File: AttributeSetSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.IntegerType

/**
 * Tests for `AttributeSet` membership, union, subset and equality.
 *
 * NOTE(review): closing braces of the tests and the class are missing from this excerpt.
 */
class AttributeSetSuite extends SparkFunSuite {

  // aUpper and aLower share ExprId(1) and differ only in name; fakeA reuses the name "a" but
  // carries ExprId(3).
  val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
  val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
  val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
  val aSet = AttributeSet(aLower :: Nil)

  // bUpper and bLower share ExprId(2) and differ only in name.
  val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
  val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
  val bSet = AttributeSet(bUpper :: Nil)

  val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)

  // The references themselves compare unequal when only the name differs.
  test("sanity check") {
    assert(aUpper != aLower)
    assert(bUpper != bLower)

  // Membership follows the ExprId: both "A" and "a" with ExprId(1) are found, fakeA is not.
  test("checks by id not name") {
    assert(aSet.contains(aUpper) === true)
    assert(aSet.contains(aLower) === true)
    assert(aSet.contains(fakeA) === false)

    assert(aSet.contains(bLower) === false)

  test("++ preserves AttributeSet")  {
    assert((aSet ++ bSet).contains(aUpper) === true)
    assert((aSet ++ bSet).contains(aLower) === true)

  // NOTE(review): only the construction of `addSet` is visible; the assertions on it appear to
  // be missing from this excerpt.
  test("extracts all references references") {
    val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil)

  // Two references with the same ExprId collapse to a single element.
  test("dedups attributes") {
    assert(AttributeSet(aUpper :: aLower :: Nil).size === 1)

  test("subset") {
    assert(aSet.subsetOf(aAndBSet) === true)
    assert(aAndBSet.subsetOf(aSet) === false)

  test("equality") {
    assert(aSet != aAndBSet)
    assert(aAndBSet != aSet)
    assert(aSet != bSet)
    assert(bSet != aSet)

    // aSet was built from aLower; a set built from aUpper (same ExprId) compares equal.
    assert(aSet == aSet)
    assert(aSet == AttributeSet(aUpper :: Nil))
Example 69
Source File: RandomSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.scalatest.Matchers._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, LongType}

/**
 * Checks `Rand` and `Randn` against fixed expected doubles for given seeds, each with a
 * tolerance of 0.001.
 *
 * NOTE(review): closing braces are missing from this excerpt.
 */
class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("random") {
    checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- 0.001)
    checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001)

    // Seeds given as null literals of LongType / IntegerType.
    // NOTE(review): the opening of the enclosing calls (presumably `checkDoubleEvaluation(`)
    // is missing for the two lines below.
      new Rand(Literal.create(null, LongType)), 0.8446490682263027 +- 0.001)
      new Randn(Literal.create(null, IntegerType)), 1.1164209726833079 +- 0.001)

  test("SPARK-9127 codegen with long seed") {
    checkDoubleEvaluation(Rand(5419823303878592871L), 0.2304755080444375 +- 0.001)
    checkDoubleEvaluation(Randn(5419823303878592871L), -1.2824262718225607 +- 0.001)
Source File: ObjectExpressionsSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{IntegerType, ObjectType}

/**
 * Tests for object expressions: `Invoke` on a method that returns null, and encoder
 * serializers applied to arrays of structs, arrays and maps.
 *
 * NOTE(review): closing braces are missing from this excerpt, and the three trailing-argument
 * lines in the second test lack the opening of their enclosing call.
 */
class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  // Invokes `_2` on a (false, null) tuple and expects the generated projection to yield null.
  test("SPARK-16622: The returned value of the called method in Invoke can be null") {
    val inputRow = InternalRow.fromSeq(Seq((false, null)))
    val cls = classOf[Tuple2[Boolean, java.lang.Integer]]
    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
    val invoke = Invoke(inputObject, "_2", IntegerType)
    checkEvaluationWithGeneratedMutableProjection(invoke, null, inputRow)

  test("MapObjects should make copies of unsafe-backed data") {
    // test UnsafeRow-backed data
    val structEncoder = ExpressionEncoder[Array[Tuple2[java.lang.Integer, java.lang.Integer]]]
    val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4))))
    val structExpected = new GenericArrayData(
      Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4))))
    // NOTE(review): enclosing call name missing (truncated excerpt).
      structEncoder.serializer.head, structExpected, structInputRow)

    // test UnsafeArray-backed data
    val arrayEncoder = ExpressionEncoder[Array[Array[Int]]]
    val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4))))
    val arrayExpected = new GenericArrayData(
      Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4))))
    // NOTE(review): enclosing call name missing (truncated excerpt).
      arrayEncoder.serializer.head, arrayExpected, arrayInputRow)

    // test UnsafeMap-backed data
    val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]]
    val mapInputRow = InternalRow.fromSeq(Seq(Array(
      Map(1 -> 100, 2 -> 200), Map(3 -> 300, 4 -> 400))))
    // Each expected map is given as parallel key and value arrays.
    val mapExpected = new GenericArrayData(Seq(
      new ArrayBasedMapData(
        new GenericArrayData(Array(1, 2)),
        new GenericArrayData(Array(100, 200))),
      new ArrayBasedMapData(
        new GenericArrayData(Array(3, 4)),
        new GenericArrayData(Array(300, 400)))))
    // NOTE(review): enclosing call name missing (truncated excerpt).
      mapEncoder.serializer.head, mapExpected, mapInputRow)
Source File: ExpressionEvalHelperSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}

/**
 * Non-nullable leaf expression of IntegerType whose interpreted `eval` always returns 10.
 *
 * NOTE(review): the generated-code template below is truncated (its string delimiters and the
 * closing braces are missing). The visible fragment declares `some_variable = 11` next to the
 * result variable set to 10; judging by the class name, the generated code is meant to be
 * faulty on purpose — confirm against the full source.
 */
case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
        |int some_variable = 11;
        |int ${ev.value} = 10;
  override def dataType: DataType = IntegerType
Example 72
Source File: ScalaUDFSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Tests for `ScalaUDF` evaluation and for the error reported when the user function throws.
 *
 * NOTE(review): closing braces are missing from this excerpt.
 */
class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

  // An Int => Int UDF and a String => String UDF applied to literal inputs.
  test("basic") {
    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
    checkEvaluation(intUdf, 2)

    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
    checkEvaluation(stringUdf, "ax")

  // Calls toLowerCase on a null string literal and expects a SparkException whose message
  // names the failing user defined function, from both eval() and the unsafe-projection path.
  test("better error message for NPE") {
    // NOTE(review): unlike the UDFs in "basic", no return DataType argument is visible between
    // the function and the children here — it may have been dropped from this excerpt.
    val udf = ScalaUDF(
      (s: String) => s.toLowerCase,
      Literal.create(null, StringType) :: Nil)

    val e1 = intercept[SparkException](udf.eval())
    assert(e1.getMessage.contains("Failed to execute user defined function"))

    val e2 = intercept[SparkException] {
      checkEvalutionWithUnsafeProjection(udf, null)
    assert(e2.getMessage.contains("Failed to execute user defined function"))

Source File: MapDataSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

/**
 * Asserts that map data built from equal Scala maps does not compare equal (`!==`), for both
 * `ArrayBasedMapData` and `UnsafeMapData`.
 *
 * NOTE(review): closing braces are missing from this excerpt.
 */
class MapDataSuite extends SparkFunSuite {

  test("inequality tests") {
    def u(str: String): UTF8String = UTF8String.fromString(str)

    // test data
    // testMap1/testMap3 and testMap2/testMap4 are pairs with identical contents.
    val testMap1 = Map(u("key1") -> 1)
    val testMap2 = Map(u("key1") -> 1, u("key2") -> 2)
    val testMap3 = Map(u("key1") -> 1)
    val testMap4 = Map(u("key1") -> 1, u("key2") -> 2)

    // ArrayBasedMapData
    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
    assert(testArrayMap1 !== testArrayMap3)
    assert(testArrayMap2 !== testArrayMap4)

    // UnsafeMapData
    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    val row = new GenericInternalRow(1)
    // Stores the map in slot 0 of the reusable `row` and runs the unsafe projection over it.
    // NOTE(review): the expression that extracts and returns the UnsafeMapData from `unsafeRow`
    // is missing from this excerpt.
    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
      row.update(0, map)
      val unsafeRow = unsafeConverter.apply(row)
    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
Source File: LogicalPlanSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType

/**
 * Tests for `LogicalPlan.resolveOperators` (counting how often a rule is invoked) and for
 * `isStreaming` on binary nodes.
 *
 * NOTE(review): closing braces are missing from this excerpt.
 */
class LogicalPlanSuite extends SparkFunSuite {
  // Counter bumped every time `function` matches a Project; each test resets it first.
  private var invocationCount = 0
  // NOTE(review): the value returned by the `case` (a LogicalPlan) is missing from this excerpt.
  private val function: PartialFunction[LogicalPlan, LogicalPlan] = {
    case p: Project =>
      invocationCount += 1

  private val testRelation = LocalRelation()

  test("resolveOperator runs on operators") {
    invocationCount = 0
    val plan = Project(Nil, testRelation)
    plan resolveOperators function

    assert(invocationCount === 1)

  // Two nested Projects, so the rule is expected to fire twice.
  test("resolveOperator runs on operators recursively") {
    invocationCount = 0
    val plan = Project(Nil, Project(Nil, testRelation))
    plan resolveOperators function

    assert(invocationCount === 2)

  // NOTE(review): as shown, this body matches the previous test yet expects 0 invocations; a
  // step marking the plan as already resolved appears to be missing from this excerpt.
  test("resolveOperator skips all ready resolved plans") {
    invocationCount = 0
    val plan = Project(Nil, Project(Nil, testRelation))
    plan resolveOperators function

    assert(invocationCount === 0)

  // NOTE(review): two Projects are built but only 1 invocation is expected; a step marking
  // `plan1` as already resolved appears to be missing from this excerpt.
  test("resolveOperator skips partially resolved plans") {
    invocationCount = 0
    val plan1 = Project(Nil, testRelation)
    val plan2 = Project(Nil, plan1)
    plan2 resolveOperators function

    assert(invocationCount === 1)

  test("isStreaming") {
    val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
    // Anonymous LocalRelation subclass that reports itself as streaming.
    val incrementalRelation = new LocalRelation(
      Seq(AttributeReference("a", IntegerType, nullable = true)())) {
      override def isStreaming(): Boolean = true

    // Minimal binary node whose output is the concatenation of both children's outputs.
    case class TestBinaryRelation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
      override def output: Seq[Attribute] = left.output ++ right.output

    // Preconditions on the fixtures, then: the binary node is streaming whenever at least one
    // child is.
    require(relation.isStreaming === false)
    require(incrementalRelation.isStreaming === true)
    assert(TestBinaryRelation(relation, relation).isStreaming === false)
    assert(TestBinaryRelation(incrementalRelation, relation).isStreaming === true)
    assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true)
    assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming)
Source File: AttributeSetSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.IntegerType

/**
 * Tests for `AttributeSet` membership, union, subset and equality.
 *
 * NOTE(review): closing braces of the tests and the class are missing from this excerpt.
 */
class AttributeSetSuite extends SparkFunSuite {

  // aUpper and aLower share ExprId(1) and differ only in name; fakeA reuses the name "a" but
  // carries ExprId(3).
  val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
  val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
  val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
  val aSet = AttributeSet(aLower :: Nil)

  // bUpper and bLower share ExprId(2) and differ only in name.
  val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
  val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
  val bSet = AttributeSet(bUpper :: Nil)

  val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)

  // The references themselves compare unequal when only the name differs.
  test("sanity check") {
    assert(aUpper != aLower)
    assert(bUpper != bLower)

  // Membership follows the ExprId: both "A" and "a" with ExprId(1) are found, fakeA is not.
  test("checks by id not name") {
    assert(aSet.contains(aUpper) === true)
    assert(aSet.contains(aLower) === true)
    assert(aSet.contains(fakeA) === false)

    assert(aSet.contains(bUpper) === false)
    assert(aSet.contains(bLower) === false)

  test("++ preserves AttributeSet")  {
    assert((aSet ++ bSet).contains(aUpper) === true)
    assert((aSet ++ bSet).contains(aLower) === true)

  // NOTE(review): only the construction of `addSet` is visible; the assertions on it appear to
  // be missing from this excerpt.
  test("extracts all references references") {
    val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil)

  // Two references with the same ExprId collapse to a single element.
  test("dedups attributes") {
    assert(AttributeSet(aUpper :: aLower :: Nil).size === 1)

  test("subset") {
    assert(aSet.subsetOf(aAndBSet) === true)
    assert(aAndBSet.subsetOf(aSet) === false)

  test("equality") {
    assert(aSet != aAndBSet)
    assert(aAndBSet != aSet)
    assert(aSet != bSet)
    assert(bSet != aSet)

    // aSet was built from aLower; a set built from aUpper (same ExprId) compares equal.
    assert(aSet == aSet)
    assert(aSet == AttributeSet(aUpper :: Nil))
Source File: EstimatorModelWrapperIntegSpec.scala    From seahorse   with Apache License 2.0 5 votes vote down vote up
package ai.deepsense.deeplang.doperables.spark.wrappers.estimators

import ai.deepsense.deeplang.DeeplangIntegTestSupport
import ai.deepsense.deeplang.doperables.dataframe.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructType, StructField}

/**
 * Integration spec for the estimator/model wrapper fixtures: fits a wrapper on a one-column
 * DataFrame and checks the schema produced by `_transform` and inferred by `_transformSchema`,
 * with the prediction column either inherited from the estimator or overwritten.
 *
 * NOTE(review): this excerpt is truncated — closing braces/parentheses are missing throughout
 * and some statements are cut off mid-expression.
 */
class EstimatorModelWrapperIntegSpec extends DeeplangIntegTestSupport {

  import ai.deepsense.deeplang.doperables.spark.wrappers.estimators.EstimatorModelWrapperFixtures._

  // Input: three rows with a single non-nullable integer column "x".
  val inputDF = {
    val rowSeq = Seq(Row(1), Row(2), Row(3))
    val schema = StructType(Seq(StructField("x", IntegerType, nullable = false)))
    createDataFrame(rowSeq, schema)

  val estimatorPredictionParamValue = "estimatorPrediction"

  // Input schema plus a non-nullable integer column named after the estimator's value.
  val expectedSchema = StructType(Seq(
    StructField("x", IntegerType, nullable = false),
    StructField(estimatorPredictionParamValue, IntegerType, nullable = false)

  val transformerPredictionParamValue = "modelPrediction"

  // Same shape, with the prediction column named after the value set on the transformer.
  val expectedSchemaForTransformerParams = StructType(Seq(
    StructField("x", IntegerType, nullable = false),
    StructField(transformerPredictionParamValue, IntegerType, nullable = false)

  "EstimatorWrapper" should {
    "_fit() and transform() + transformSchema() with parameters inherited" in {

      val transformer = createEstimatorAndFit()

      val transformOutputSchema =
        transformer._transform(executionContext, inputDF).sparkDataFrame.schema
      transformOutputSchema shouldBe expectedSchema

      val inferenceOutputSchema = transformer._transformSchema(inputDF.sparkDataFrame.schema)
      inferenceOutputSchema shouldBe Some(expectedSchema)

    "_fit() and transform() + transformSchema() with parameters overwritten" in {

      val transformer = createEstimatorAndFit().setPredictionColumn(transformerPredictionParamValue)

      val transformOutputSchema =
        transformer._transform(executionContext, inputDF).sparkDataFrame.schema
      transformOutputSchema shouldBe expectedSchemaForTransformerParams

      val inferenceOutputSchema = transformer._transformSchema(inputDF.sparkDataFrame.schema)
      inferenceOutputSchema shouldBe Some(expectedSchemaForTransformerParams)

    "_fit_infer().transformSchema() with parameters inherited" in {

      val estimatorWrapper = new SimpleSparkEstimatorWrapper()

      // NOTE(review): the receiver of the call chain below is missing from this excerpt.
        ._transformSchema(inputDF.sparkDataFrame.schema) shouldBe Some(expectedSchema)

    "_fit_infer().transformSchema() with parameters overwritten" in {

      val estimatorWrapper = new SimpleSparkEstimatorWrapper()
      // NOTE(review): the right-hand side of `transformer` is missing from this excerpt.
      val transformer =
      val transformerWithParams = transformer.setPredictionColumn(transformerPredictionParamValue)

      val outputSchema = transformerWithParams._transformSchema(inputDF.sparkDataFrame.schema)
      outputSchema shouldBe Some(expectedSchemaForTransformerParams)

  // Fits a fresh estimator wrapper on `inputDF` and checks the fitted model wrapper's
  // prediction column.
  // NOTE(review): no call setting the estimator's prediction column to
  // `estimatorPredictionParamValue` is visible, nor is the returned `transformer`; both appear
  // to be missing from this excerpt.
  private def createEstimatorAndFit(): SimpleSparkModelWrapper = {

    val estimatorWrapper = new SimpleSparkEstimatorWrapper()

    val transformer =
      estimatorWrapper._fit(executionContext, inputDF).asInstanceOf[SimpleSparkModelWrapper]
    transformer.getPredictionColumn() shouldBe estimatorPredictionParamValue

Example 77
Source File: EstimatorModelWrapperFixtures.scala    From seahorse   with Apache License 2.0 5 votes vote down vote up
package ai.deepsense.deeplang.doperables.spark.wrappers.estimators

import scala.language.reflectiveCalls

import org.apache.spark.annotation.DeveloperApi
import{ParamMap, Param => SparkParam}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

import ai.deepsense.deeplang.ExecutionContext
import ai.deepsense.deeplang.doperables.serialization.SerializableSparkModel
import ai.deepsense.deeplang.doperables.{SparkEstimatorWrapper, SparkModelWrapper}
import ai.deepsense.deeplang.params.wrappers.spark.SingleColumnCreatorParamWrapper
import ai.deepsense.deeplang.params.{Param, Params}
import ai.deepsense.sparkutils.ML

/**
 * Test fixtures: a minimal Spark ML model/estimator pair and their wrappers, all sharing a
 * single "prediction column" parameter.
 *
 * NOTE(review): this excerpt is truncated — closing braces are missing throughout and the
 * `predictionColumn` parameter definition is cut off.
 */
object EstimatorModelWrapperFixtures {

  // Model whose transform appends a constant column "1 as <predictionCol>" to the dataset.
  class SimpleSparkModel private[EstimatorModelWrapperFixtures]()
    extends ML.Model[SimpleSparkModel] {

    // Auxiliary constructor; the String argument is ignored.
    def this(x: String) = this()

    override val uid: String = "modelId"

    val predictionCol = new SparkParam[String](uid, "name", "description")

    def setPredictionCol(value: String): this.type = set(predictionCol, value)

    override def copy(extra: ParamMap): this.type = defaultCopy(extra)

    override def transformDF(dataset: DataFrame): DataFrame = {
      dataset.selectExpr("*", "1 as " + $(predictionCol))

    // Deliberately left unimplemented (`???`).
    override def transformSchema(schema: StructType): StructType = ???

  // Estimator whose fit ignores the dataset and returns a SimpleSparkModel configured with the
  // estimator's own predictionCol value.
  class SimpleSparkEstimator extends ML.Estimator[SimpleSparkModel] {

    // Auxiliary constructor; the String argument is ignored.
    def this(x: String) = this()

    override val uid: String = "estimatorId"

    val predictionCol = new SparkParam[String](uid, "name", "description")

    override def fitDF(dataset: DataFrame): SimpleSparkModel =
      new SimpleSparkModel().setPredictionCol($(predictionCol))

    override def copy(extra: ParamMap): ML.Estimator[SimpleSparkModel] = defaultCopy(extra)

    // Appends a non-nullable integer field named after predictionCol.
    override def transformSchema(schema: StructType): StructType = {
      schema.add(StructField($(predictionCol), IntegerType, nullable = false))

  // Mixin exposing the wrapper-side "prediction column" parameter with default "abcdefg".
  trait HasPredictionColumn extends Params {
    // NOTE(review): the remaining constructor arguments are cut off in this excerpt.
    val predictionColumn = new SingleColumnCreatorParamWrapper[
        ml.param.Params { val predictionCol: SparkParam[String] }](
      "prediction column",
    setDefault(predictionColumn, "abcdefg")

    def getPredictionColumn(): String = $(predictionColumn)
    def setPredictionColumn(value: String): this.type = set(predictionColumn, value)

  class SimpleSparkModelWrapper
    extends SparkModelWrapper[SimpleSparkModel, SimpleSparkEstimator]
    with HasPredictionColumn {

    override val params: Array[Param[_]] = Array(predictionColumn)
    // Deliberately left unimplemented (`???`).
    override def report(extended: Boolean = true): Report = ???

    // Deliberately left unimplemented (`???`).
    override protected def loadModel(
      ctx: ExecutionContext,
      path: String): SerializableSparkModel[SimpleSparkModel] = ???

  class SimpleSparkEstimatorWrapper
    extends SparkEstimatorWrapper[SimpleSparkModel, SimpleSparkEstimator, SimpleSparkModelWrapper]
    with HasPredictionColumn {

    override val params: Array[Param[_]] = Array(predictionColumn)
    // Deliberately left unimplemented (`???`).
    override def report(extended: Boolean = true): Report = ???
Example 78
Source File: DataFrameSplitterIntegSpec.scala    From seahorse   with Apache License 2.0 5 votes vote down vote up
package ai.deepsense.deeplang.doperations

import scala.collection.JavaConverters._

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.scalatest.Matchers
import org.scalatest.prop.GeneratorDrivenPropertyChecks

import ai.deepsense.deeplang._
import ai.deepsense.deeplang.doperables.dataframe.DataFrame

/**
 * Integration spec for the `Split` operation: splits a one-column integer DataFrame randomly
 * and by condition, then checks that the two outputs partition the input.
 *
 * NOTE(review): this excerpt is truncated — closing braces are missing throughout, the `Split`
 * configuration calls are cut off, and the bodies of `createSchema`/`createData` are incomplete.
 */
class DataFrameSplitterIntegSpec
  extends DeeplangIntegTestSupport
  with GeneratorDrivenPropertyChecks
  with Matchers {

  "SplitDataFrame" should {
    "split randomly one df into two df in given range" in {

      val input = Range(1, 100)

      // (split ratio, seed) pairs exercised in turn.
      val parameterPairs = List(
        (0.0, 0),
        (0.3, 1),
        (0.5, 2),
        (0.8, 3),
        (1.0, 4))

      for((splitRatio, seed) <- parameterPairs) {
        val rdd = createData(input)
        val df = executionContext.dataFrameBuilder.buildDataFrame(createSchema, rdd)
        // NOTE(review): the arguments between `new Split()` and `.setSeed(...)` are missing
        // from this excerpt, and the call shows a single argument although `executeOperation`
        // below takes (context, operation).
        val (df1, df2) = executeOperation(
          new Split()
                .setSeed(seed / 2)))(df)
        validateSplitProperties(df, df1, df2)

    "split conditionally one df into two df in given range" in {

      val input = Range(1, 100)

      // `condition` and `predicate` express the same test, as text and as a Scala function.
      val condition = "value > 20"
      val predicate: Int => Boolean = _ > 20

      val (expectedDF1, expectedDF2) =
        (input.filter(predicate), input.filter(!predicate(_)))

      val rdd = createData(input)
      val df = executionContext.dataFrameBuilder.buildDataFrame(createSchema, rdd)
      // NOTE(review): the rest of this call (Split configuration and the `(df)` argument list)
      // is missing from this excerpt.
      val (df1, df2) = executeOperation(
        new Split()
      df1.sparkDataFrame.collect().map(_.get(0)) should contain theSameElementsAs expectedDF1
      df2.sparkDataFrame.collect().map(_.get(0)) should contain theSameElementsAs expectedDF2
      validateSplitProperties(df, df1, df2)

  // Schema with one non-nullable integer field "value".
  // NOTE(review): the StructType wrapper around the field is missing from this excerpt.
  private def createSchema: StructType = {
      StructField("value", IntegerType, nullable = false)

  // NOTE(review): the body of createData is missing from this excerpt.
  private def createData(data: Seq[Int]): RDD[Row] = {

  // Runs the operation on a single input DataFrame and returns the first and last results as
  // the two halves of the split.
  private def executeOperation(context: ExecutionContext, operation: DOperation)
                              (dataFrame: DataFrame): (DataFrame, DataFrame) = {
    val operationResult = operation.executeUntyped(Vector[DOperable](dataFrame))(context)
    val df1 = operationResult.head.asInstanceOf[DataFrame]
    val df2 = operationResult.last.asInstanceOf[DataFrame]
    (df1, df2)

  // Checks that the outputs are disjoint, that their counts add up to the input count, and
  // that the union of their row sets equals the input row set.
  def validateSplitProperties(inputDF: DataFrame, outputDF1: DataFrame, outputDF2: DataFrame)
      : Unit = {
    val dfCount = inputDF.sparkDataFrame.count()
    val df1Count = outputDF1.sparkDataFrame.count()
    val df2Count = outputDF2.sparkDataFrame.count()
    val rowsDf = inputDF.sparkDataFrame.collectAsList().asScala
    val rowsDf1 = outputDF1.sparkDataFrame.collectAsList().asScala
    val rowsDf2 = outputDF2.sparkDataFrame.collectAsList().asScala
    val intersect = rowsDf1.intersect(rowsDf2)
    intersect.size shouldBe 0
    (df1Count + df2Count) shouldBe dfCount
    rowsDf.toSet shouldBe rowsDf1.toSet.union(rowsDf2.toSet)
Example 79
Source File: AttributeSetSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.IntegerType

/**
 * Tests for `AttributeSet` membership, union, subset and equality.
 *
 * NOTE(review): closing braces of the tests and the class are missing from this excerpt.
 */
class AttributeSetSuite extends SparkFunSuite {

  // aUpper and aLower share ExprId(1) and differ only in name; fakeA reuses the name "a" but
  // carries ExprId(3).
  val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
  val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
  val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
  val aSet = AttributeSet(aLower :: Nil)

  // bUpper and bLower share ExprId(2) and differ only in name.
  val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
  val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
  val bSet = AttributeSet(bUpper :: Nil)

  val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)

  // The references themselves compare unequal when only the name differs.
  test("sanity check") {
    assert(aUpper != aLower)
    assert(bUpper != bLower)

  // Membership follows the ExprId: both "A" and "a" with ExprId(1) are found, fakeA is not.
  test("checks by id not name") {
    assert(aSet.contains(aUpper) === true)
    assert(aSet.contains(aLower) === true)
    assert(aSet.contains(fakeA) === false)

    assert(aSet.contains(bUpper) === false)
    assert(aSet.contains(bLower) === false)

  test("++ preserves AttributeSet")  {
    assert((aSet ++ bSet).contains(aUpper) === true)
    assert((aSet ++ bSet).contains(aLower) === true)

  // NOTE(review): only the construction of `addSet` is visible; the assertions on it appear to
  // be missing from this excerpt.
  test("extracts all references references") {
    val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil)

  // Two references with the same ExprId collapse to a single element.
  test("dedups attributes") {
    assert(AttributeSet(aUpper :: aLower :: Nil).size === 1)

  test("subset") {
    assert(aSet.subsetOf(aAndBSet) === true)
    assert(aAndBSet.subsetOf(aSet) === false)

  test("equality") {
    assert(aSet != aAndBSet)
    assert(aAndBSet != aSet)
    assert(aSet != bSet)
    assert(bSet != aSet)

    // aSet was built from aLower; a set built from aUpper (same ExprId) compares equal.
    assert(aSet == aSet)
    assert(aSet == AttributeSet(aUpper :: Nil))
Source File: SchemaColumnRandom.scala    From data-faker   with MIT License 5 votes vote down vote up
package com.dunnhumby.datafaker.schema.table.columns

import java.sql.{Date, Timestamp}
import com.dunnhumby.datafaker.YamlParser.YamlParserProtocol
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{to_utc_timestamp, round, rand, from_unixtime, to_date}
import org.apache.spark.sql.types.{IntegerType, LongType}

/**
 * Marker trait for schema columns whose values are generated randomly.
 *
 * @tparam T the value type of the bounds (e.g. Int, Date, Timestamp, Boolean) used by the
 *           implementations in this file
 */
trait SchemaColumnRandom[T] extends SchemaColumn

/**
 * Factory overloads that pick the random-column implementation from the type of the bounds;
 * the name-only overload builds a random boolean column.
 *
 * NOTE(review): the closing brace of this object is missing from this excerpt.
 */
object SchemaColumnRandom {
  // Number of decimal places used when rounding Float / Double random columns.
  val FloatDP = 3
  val DoubleDP = 3

  def apply(name: String, min: Int, max: Int): SchemaColumn = SchemaColumnRandomNumeric(name, min, max)
  def apply(name: String, min: Long, max: Long): SchemaColumn = SchemaColumnRandomNumeric(name, min, max)
  def apply(name: String, min: Float, max: Float): SchemaColumn = SchemaColumnRandomNumeric(name, min, max)
  def apply(name: String, min: Double, max: Double): SchemaColumn = SchemaColumnRandomNumeric(name, min, max)
  def apply(name: String, min: Date, max: Date): SchemaColumn = SchemaColumnRandomDate(name, min, max)
  def apply(name: String, min: Timestamp, max: Timestamp): SchemaColumn = SchemaColumnRandomTimestamp(name, min, max)
  def apply(name: String): SchemaColumn = SchemaColumnRandomBoolean(name)

/**
 * Random numeric column computed as `rand() * (max - min) + min`, then rounded.
 *
 * Int and Long bounds round to 0 decimal places and cast to IntegerType / LongType; Float and
 * Double bounds round to `SchemaColumnRandom.FloatDP` / `DoubleDP` decimal places.
 *
 * NOTE(review): the closing braces are missing from this excerpt, so it cannot be seen whether
 * the match has a fallback case for other Numeric types.
 *
 * @param name column name
 * @param min  lower bound
 * @param max  upper bound
 */
private case class SchemaColumnRandomNumeric[T: Numeric](override val name: String, min: T, max: T) extends SchemaColumnRandom[T] {
  // `rowID` is not used by the visible body.
  override def column(rowID: Option[Column] = None): Column = {
    import Numeric.Implicits._

    // Dispatch on the runtime type of the bounds.
    (min, max) match {
      case (_: Int, _: Int) => round(rand() * (max - min) + min, 0).cast(IntegerType)
      case (_: Long, _: Long) => round(rand() * (max - min) + min, 0).cast(LongType)
      case (_: Float, _: Float) => round(rand() * (max - min) + min, SchemaColumnRandom.FloatDP)
      case (_: Double, _: Double) => round(rand() * (max - min) + min, SchemaColumnRandom.DoubleDP)

/**
 * Random timestamp column between `min` and `max`.
 *
 * NOTE(review): the closing braces are missing from this excerpt.
 *
 * @param name column name
 * @param min  lower bound
 * @param max  upper bound
 */
private case class SchemaColumnRandomTimestamp(override val name: String, min: Timestamp, max: Timestamp) extends SchemaColumnRandom[Timestamp] {
  // `rowID` is not used by the visible body.
  override def column(rowID: Option[Column] = None): Column = {
    // Timestamp.getTime is in milliseconds; integer division by 1000 gives whole seconds.
    val minTime = min.getTime / 1000
    val maxTime = max.getTime / 1000
    // Scale rand() into [minTime, maxTime) seconds, then convert through from_unixtime and
    // to_utc_timestamp with the "UTC" zone.
    to_utc_timestamp(from_unixtime(rand() * (maxTime - minTime) + minTime), "UTC")

/**
 * Random date column, implemented by truncating a random timestamp with `to_date`.
 *
 * @param name column name
 * @param min  earliest date of the range
 * @param max  latest date of the range
 */
private case class SchemaColumnRandomDate(override val name: String, min: Date, max: Date) extends SchemaColumnRandom[Date] {
  // The upper bound is pushed forward by 86400000 (the number of milliseconds in one day),
  // presumably so that `max` itself can still be produced after to_date - TODO confirm.
  val timestamp = SchemaColumnRandomTimestamp(name, new Timestamp(min.getTime), new Timestamp(max.getTime + 86400000))

  // `rowID` is not forwarded to the underlying timestamp column.
  override def column(rowID: Option[Column] = None): Column = to_date(timestamp.column())

/**
 * Random boolean column.
 *
 * @param name column name
 */
private case class SchemaColumnRandomBoolean(override val name: String) extends SchemaColumnRandom[Boolean] {
  // The column is the comparison `rand() < 0.5f`; `rowID` is not used.
  override def column(rowID: Option[Column] = None): Column = rand() < 0.5f

/** Default instance of the protocol trait so its implicits can be imported directly. */
object SchemaColumnRandomProtocol extends SchemaColumnRandomProtocol
/**
 * YAML protocol for random schema columns.
 * Provides an implicit YamlFormat that can read a SchemaColumnRandom; writing is not implemented.
 */
trait SchemaColumnRandomProtocol extends YamlParserProtocol {

  import net.jcazevedo.moultingyaml._

  implicit object SchemaColumnRandomFormat extends YamlFormat[SchemaColumnRandom[_]] {

    /**
     * Builds a random column from a YAML object.
     * The `name` and `data_type` fields are always required; `min` and `max` are required
     * for every non-boolean data type. A missing field or an unsupported data type is
     * reported through `deserializationError`.
     */
    override def read(yaml: YamlValue): SchemaColumnRandom[_] = {
      val fields = yaml.asYamlObject.fields
      val YamlString(name) = fields.getOrElse(YamlString("name"), deserializationError("name not set"))
      val YamlString(dataType) = fields.getOrElse(YamlString("data_type"), deserializationError(s"data_type not set for $name"))

      // NOTE(review): the body of the Boolean branch is not visible in this text.
      if (dataType == SchemaColumnDataType.Boolean) {
      else {
        val min = fields.getOrElse(YamlString("min"), deserializationError(s"min not set for $name"))
        val max = fields.getOrElse(YamlString("max"), deserializationError(s"max not set for $name"))

        // Convert min/max to the type named by data_type and build the matching column.
        dataType match {
          case SchemaColumnDataType.Int => SchemaColumnRandomNumeric(name, min.convertTo[Int], max.convertTo[Int])
          case SchemaColumnDataType.Long => SchemaColumnRandomNumeric(name, min.convertTo[Long], max.convertTo[Long])
          case SchemaColumnDataType.Float => SchemaColumnRandomNumeric(name, min.convertTo[Float], max.convertTo[Float])
          case SchemaColumnDataType.Double => SchemaColumnRandomNumeric(name, min.convertTo[Double], max.convertTo[Double])
          case SchemaColumnDataType.Date => SchemaColumnRandomDate(name, min.convertTo[Date], max.convertTo[Date])
          case SchemaColumnDataType.Timestamp => SchemaColumnRandomTimestamp(name, min.convertTo[Timestamp], max.convertTo[Timestamp])
          case _ => deserializationError(s"unsupported data_type: $dataType for ${SchemaColumnType.Random}")


    // Serialization is not implemented: calling write throws NotImplementedError via `???`.
    override def write(obj: SchemaColumnRandom[_]): YamlValue = ???


Source File: BlockingSource.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.streaming.util

import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.{SQLContext, _}
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * Provider that implements both StreamSourceProvider and StreamSinkProvider and always
 * reports a fixed single-column schema.
 * NOTE(review): the class name and the latch on the companion object suggest that
 * getBatch/addBatch wait on that latch, but the relevant statements are not visible
 * in this text - confirm against the original file.
 */
class BlockingSource extends StreamSourceProvider with StreamSinkProvider {

  // Schema reported by both sourceSchema and the created Source: one IntegerType column "a".
  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  // Ignores all arguments and returns the fixed name/schema pair.
  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", fakeSchema)

  // Creates an anonymous Source whose offset is always LongOffset(0).
  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    new Source {
      override def schema: StructType = fakeSchema
      override def getOffset: Option[Offset] = Some(new LongOffset(0))
      // NOTE(review): the remainder of getBatch is not visible in this text.
      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
        import spark.implicits._
      override def stop() {}

  // Creates an anonymous Sink; the visible addBatch body is empty.
  override def createSink(
      spark: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink {
      override def addBatch(batchId: Long, data: DataFrame): Unit = {}

/** Companion holding mutable state shared with users of BlockingSource. */
object BlockingSource {
  // Starts as null and must be assigned by the caller before use.
  // NOTE(review): how the latch is consumed is not visible in this text.
  var latch: CountDownLatch = null
Source File: GroupedIteratorSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

/**
 * Tests for GroupedIterator: grouping rows by one key column, by two key columns, and
 * counting groups without consuming the per-group value iterators.
 * NOTE(review): the GroupedIterator arguments and the mapping that produces `result`
 * are partially missing from this text.
 */
class GroupedIteratorSuite extends SparkFunSuite {

  // Three rows with keys 1, 1, 2 are expected to form two groups keyed by the Int column.
  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(,
      Seq(', schema.toAttributes)

    val result = {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) ->

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)

  // Grouping on an (Int, Long) key pair: five rows are expected to form four groups.
  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(,
      Seq(', ', schema.toAttributes)

    val result = {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1),

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)

  // Only the number of groups is checked; the value iterators are never read.
  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(,
      Seq(', schema.toAttributes)

    assert(grouped.length == 2)
Source File: WholeStageCodegenSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

/**
 * Checks that various physical operators end up inside a WholeStageCodegenExec node and
 * that the query still returns the expected rows.
 * NOTE(review): in most tests the second conjunct of the plan assertion (after `&&`) is
 * missing from this text, so the exact child operator being checked is not visible.
 */
class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {

  test("range/filter should be combined") {
    val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
    val plan = df.queryExecution.executedPlan
    assert(df.collect() === Array(Row(2)))

  test("Aggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(df.collect() === Array(Row(9, 4.5)))

  test("Aggregate with grouping keys should be included in WholeStageCodegen") {
    val df = spark.range(3).groupBy("id").count().orderBy("id")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))

  // Joins range(10) against a small broadcast DataFrame built from an explicit schema.
  test("BroadcastHashJoin should be included in WholeStageCodegen") {
    val rdd = spark.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
    val schema = new StructType().add("k", IntegerType).add("v", StringType)
    val smallDF = spark.createDataFrame(rdd, schema)
    val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id"))
    assert(df.queryExecution.executedPlan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))

  // The range is generated in descending order (3, 2, 1) and then sorted ascending.
  test("Sort should be included in WholeStageCodegen") {
    val df = spark.range(3, 0, -1).toDF().sort(col("id"))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(df.collect() === Array(Row(1), Row(2), Row(3)))

  test("MapElements should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = spark.range(10).map(_.toString)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(ds.collect() === 0.until(10).map(_.toString).toArray)

  test("typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(ds.collect() === Array(0, 2, 4, 6, 8))

  test("back-to-back typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(ds.collect() === Array(0, 6))

  // NOTE(review): the expected result holds per-key Double sums, but the aggregation step
  // that would produce them is not visible after toDS() - presumably lost in extraction.
  test("simple typed UDAF should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()

    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
Source File: resources.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command


import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

/**
 * Runnable command that returns jar paths known to the SparkContext, one Row per path,
 * in a single non-nullable string column named "Results".
 *
 * @param jars optional jar names used to filter the listed paths; empty by default
 * NOTE(review): the source of the `jarName` generator and the whole `else` branch are
 * missing from this text.
 */
case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
  override val output: Seq[Attribute] = {
    AttributeReference("Results", StringType, nullable = false)() :: Nil
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val jarList = sparkSession.sparkContext.listJars()
    if (jars.nonEmpty) {
      // Keep every listed path that contains the file-name part of a requested jar.
      for {
        jarName <- => new Path(f).getName)
        jarPath <- jarList if jarPath.contains(jarName)
      } yield Row(jarPath)
    } else {
Example 85
Source File: SimplifyConditionalSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{IntegerType, NullType}

/**
 * Tests for the SimplifyConditionals optimizer rule applied to If and CaseWhen expressions.
 * NOTE(review): the calls that wrap each expression pair inside the tests (presumably
 * assertEquivalent) and several expected values are missing from this text.
 */
class SimplifyConditionalSuite extends PlanTest with PredicateHelper {

  // Rule executor that runs only SimplifyConditionals, to a fixed point of at most 50 iterations.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil

  // Optimizes a projection of e1 and compares the plan with an analyzed projection of e2.
  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
    comparePlans(actual, correctAnswer)

  // (condition, value) pairs used as CaseWhen branches in the tests below.
  private val trueBranch = (TrueLiteral, Literal(5))
  private val normalBranch = (NonFoldableLiteral(true), Literal(10))
  private val unreachableBranch = (FalseLiteral, Literal(20))
  private val nullBranch = (Literal.create(null, NullType), Literal(30))

  test("simplify if") {
      If(TrueLiteral, Literal(10), Literal(20)),

      If(FalseLiteral, Literal(10), Literal(20)),

      If(Literal.create(null, NullType), Literal(10), Literal(20)),

  test("remove unreachable branches") {
    // i.e. removing branches whose conditions are always false
      CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
      CaseWhen(normalBranch :: Nil, None))

  test("remove entire CaseWhen if only the else branch is reachable") {
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),

      CaseWhen(unreachableBranch :: unreachableBranch :: Nil, None),
      Literal.create(null, IntegerType))

  test("remove entire CaseWhen if the first branch is always true") {
      CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),

    // Test branch elimination and simplification in combination
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
        :: Nil, None),

    // Make sure this doesn't trigger if there is a non-foldable branch before the true branch
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None),
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None))
Example 86
Source File: RewriteDistinctAggregatesSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Tests for RewriteDistinctAggregates: plans with a single distinct group are expected to
 * be left unchanged, while the remaining tests build inputs for the rewritten shape
 * recognized by checkRewrite.
 * NOTE(review): several aggregate lists and the trailing assertions of the later tests
 * are missing from this text.
 */
class RewriteDistinctAggregatesSuite extends PlanTest {
  // Case-insensitive analyzer over an empty in-memory catalog.
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, '

  // Passes only for an Aggregate over an Aggregate over an Expand; fails the test otherwise.
  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")

  test("single distinct group") {
    val input = testRelation
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)

  test("single distinct group with non-partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
        countDistinct('b, 'c),
Source File: SubstituteUnresolvedOrdinals.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.types.IntegerType

/**
 * Analyzer rule that replaces integer literals used as ORDER BY / GROUP BY positions with
 * UnresolvedOrdinal expressions, preserving the origin of the replaced nodes.
 * Each rewrite is gated by its configuration flag (orderByOrdinal / groupByOrdinal).
 * NOTE(review): the collection being mapped in `newOrders`/`newGroups` and the right-hand
 * side of the group-by ordinal case are missing from this text.
 */
class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] {
  // True only for literals whose data type is IntegerType.
  private def isIntLiteral(e: Expression) = e match {
    case Literal(_, IntegerType) => true
    case _ => false

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // Sort: rewrite only when enabled and at least one sort key is an integer literal.
    case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
      val newOrders = {
        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) =>
          val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
          withOrigin(order.origin)(order.copy(child = newOrdinal))
        case other => other
      withOrigin(s.origin)(s.copy(order = newOrders))

    // Aggregate: same gating, applied to the grouping expressions.
    case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) =>
      val newGroups = {
        case ordinal @ Literal(index: Int, IntegerType) =>
        case other => other
      withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
Source File: RandomSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.scalatest.Matchers._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, LongType}

/**
 * Tests for the Rand and Randn expressions: fixed seeds must evaluate to known values
 * within a tolerance of 0.001.
 * NOTE(review): the calls wrapping the two null-seed cases are missing from this text.
 */
class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("random") {
    checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- 0.001)
    checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001)

    // Seeds given as null literals of LongType and IntegerType.
      new Rand(Literal.create(null, LongType)), 0.8446490682263027 +- 0.001)
      new Randn(Literal.create(null, IntegerType)), 1.1164209726833079 +- 0.001)

  // Uses a seed that only fits in a Long.
  test("SPARK-9127 codegen with long seed") {
    checkDoubleEvaluation(Rand(5419823303878592871L), 0.2304755080444375 +- 0.001)
    checkDoubleEvaluation(Randn(5419823303878592871L), -1.2824262718225607 +- 0.001)
Example 89
Source File: ObjectExpressionsSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{IntegerType, ObjectType}

/**
 * Tests for object expressions: Invoke returning null, and serialization of arrays of
 * structs, arrays and maps.
 * NOTE(review): the check calls that consume each (serializer, expected, input) triple in
 * the second test are missing from this text.
 */
class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  // Invokes `_2` on a (false, null) tuple and expects a null result.
  test("SPARK-16622: The returned value of the called method in Invoke can be null") {
    val inputRow = InternalRow.fromSeq(Seq((false, null)))
    val cls = classOf[Tuple2[Boolean, java.lang.Integer]]
    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
    val invoke = Invoke(inputObject, "_2", IntegerType)
    checkEvaluationWithGeneratedMutableProjection(invoke, null, inputRow)

  test("MapObjects should make copies of unsafe-backed data") {
    // test UnsafeRow-backed data
    val structEncoder = ExpressionEncoder[Array[Tuple2[java.lang.Integer, java.lang.Integer]]]
    val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4))))
    val structExpected = new GenericArrayData(
      Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4))))
      structEncoder.serializer.head, structExpected, structInputRow)

    // test UnsafeArray-backed data
    val arrayEncoder = ExpressionEncoder[Array[Array[Int]]]
    val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4))))
    val arrayExpected = new GenericArrayData(
      Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4))))
      arrayEncoder.serializer.head, arrayExpected, arrayInputRow)

    // test UnsafeMap-backed data
    val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]]
    val mapInputRow = InternalRow.fromSeq(Seq(Array(
      Map(1 -> 100, 2 -> 200), Map(3 -> 300, 4 -> 400))))
    val mapExpected = new GenericArrayData(Seq(
      new ArrayBasedMapData(
        new GenericArrayData(Array(1, 2)),
        new GenericArrayData(Array(100, 200))),
      new ArrayBasedMapData(
        new GenericArrayData(Array(3, 4)),
        new GenericArrayData(Array(300, 400)))))
      mapEncoder.serializer.head, mapExpected, mapInputRow)
Source File: ExpressionEvalHelperSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}

/**
 * Non-nullable IntegerType leaf expression whose interpreted evaluation always returns 10.
 * NOTE(review): the class name suggests the generated code is deliberately inconsistent
 * with `eval`; the string delimiters around the generated code are missing from this
 * text, so the exact generated source cannot be confirmed here.
 */
case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
        |int some_variable = 11;
        |int ${ev.value} = 10;
  override def dataType: DataType = IntegerType
Source File: ScalaUDFSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Tests for ScalaUDF: basic evaluation of Int and String functions, and the error message
 * produced when the user function fails on a null input.
 * NOTE(review): the return-type argument of the second ScalaUDF and the closing brace of
 * the second intercept block are missing from this text.
 */
class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("basic") {
    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
    checkEvaluation(intUdf, 2)

    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
    checkEvaluation(stringUdf, "ax")

  // The UDF receives a null String literal, so calling toLowerCase on it fails; both the
  // interpreted path and the unsafe-projection path must surface a SparkException with
  // the user-facing message.
  test("better error message for NPE") {
    val udf = ScalaUDF(
      (s: String) => s.toLowerCase,
      Literal.create(null, StringType) :: Nil)

    val e1 = intercept[SparkException](udf.eval())
    assert(e1.getMessage.contains("Failed to execute user defined function"))

    val e2 = intercept[SparkException] {
      checkEvalutionWithUnsafeProjection(udf, null)
    assert(e2.getMessage.contains("Failed to execute user defined function"))

Source File: MapDataSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

/**
 * Asserts that map data values built from equal Scala maps do not compare as equal,
 * for both ArrayBasedMapData and UnsafeMapData.
 * NOTE(review): the last statement of toUnsafeMap (extracting the map from the unsafe
 * row) is missing from this text.
 */
class MapDataSuite extends SparkFunSuite {

  test("inequality tests") {
    def u(str: String): UTF8String = UTF8String.fromString(str)

    // test data
    val testMap1 = Map(u("key1") -> 1)
    val testMap2 = Map(u("key1") -> 1, u("key2") -> 2)
    val testMap3 = Map(u("key1") -> 1)
    val testMap4 = Map(u("key1") -> 1, u("key2") -> 2)

    // ArrayBasedMapData
    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
    assert(testArrayMap1 !== testArrayMap3)
    assert(testArrayMap2 !== testArrayMap4)

    // UnsafeMapData
    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    val row = new GenericInternalRow(1)
    // Writes the map into the single-field row and runs it through the unsafe projection.
    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
      row.update(0, map)
      val unsafeRow = unsafeConverter.apply(row)
    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
Source File: LogicalPlanSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType

/**
 * Tests for LogicalPlan.resolveOperators (counting how often a partial function is
 * invoked on Project nodes) and for isStreaming propagation through a binary node.
 * NOTE(review): the result expression of `function` and the statements that mark plans
 * as already analyzed in the two "skips" tests are missing from this text; the expected
 * counts of 0 and 1 there depend on them.
 */
class LogicalPlanSuite extends SparkFunSuite {
  // Number of times `function` matched a Project; reset at the start of each test.
  private var invocationCount = 0
  private val function: PartialFunction[LogicalPlan, LogicalPlan] = {
    case p: Project =>
      invocationCount += 1

  private val testRelation = LocalRelation()

  test("resolveOperator runs on operators") {
    invocationCount = 0
    val plan = Project(Nil, testRelation)
    plan resolveOperators function

    assert(invocationCount === 1)

  test("resolveOperator runs on operators recursively") {
    invocationCount = 0
    val plan = Project(Nil, Project(Nil, testRelation))
    plan resolveOperators function

    assert(invocationCount === 2)

  test("resolveOperator skips all ready resolved plans") {
    invocationCount = 0
    val plan = Project(Nil, Project(Nil, testRelation))
    plan resolveOperators function

    assert(invocationCount === 0)

  test("resolveOperator skips partially resolved plans") {
    invocationCount = 0
    val plan1 = Project(Nil, testRelation)
    val plan2 = Project(Nil, plan1)
    plan2 resolveOperators function

    assert(invocationCount === 1)

  // A binary node is expected to report streaming as soon as either child is streaming.
  test("isStreaming") {
    val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
    val incrementalRelation = new LocalRelation(
      Seq(AttributeReference("a", IntegerType, nullable = true)())) {
      override def isStreaming(): Boolean = true

    case class TestBinaryRelation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
      override def output: Seq[Attribute] = left.output ++ right.output

    require(relation.isStreaming === false)
    require(incrementalRelation.isStreaming === true)
    assert(TestBinaryRelation(relation, relation).isStreaming === false)
    assert(TestBinaryRelation(incrementalRelation, relation).isStreaming === true)
    assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true)
    assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming)
Source File: SubstituteUnresolvedOrdinals.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.types.IntegerType

/**
 * Analyzer rule that replaces integer literals used as ORDER BY / GROUP BY positions with
 * UnresolvedOrdinal expressions, preserving the origin of the replaced nodes.
 * Each rewrite is gated by its configuration flag (orderByOrdinal / groupByOrdinal).
 * NOTE(review): the collection being mapped in `newOrders`/`newGroups` and the right-hand
 * side of the group-by ordinal case are missing from this text.
 */
class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] {
  // True only for literals whose data type is IntegerType.
  private def isIntLiteral(e: Expression) = e match {
    case Literal(_, IntegerType) => true
    case _ => false

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // Sort: rewrite only when enabled and at least one sort key is an integer literal.
    case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
      val newOrders = {
        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) =>
          val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
          withOrigin(order.origin)(order.copy(child = newOrdinal))
        case other => other
      withOrigin(s.origin)(s.copy(order = newOrders))

    // Aggregate: same gating, applied to the grouping expressions.
    case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) =>
      val newGroups = {
        case ordinal @ Literal(index: Int, IntegerType) =>
        case other => other
      withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
Source File: MetastoreRelationSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * Tests for MetastoreRelation: copying/serializing a relation must not throw, and a CTAS
 * over a grouped temp view must produce a Hive serde table with the expected row.
 * NOTE(review): the statement following the second "No exception should be thrown"
 * comment (presumably the toJSON call named in the test title) is missing from this text.
 */
class MetastoreRelationSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
  test("makeCopy and toJSON should work") {
    val table = CatalogTable(
      identifier = TableIdentifier("test", Some("db")),
      tableType = CatalogTableType.VIEW,
      storage = CatalogStorageFormat.empty,
      schema = StructType(StructField("a", IntegerType, true) :: Nil))
    val relation = MetastoreRelation("db", "test")(table, null)

    // No exception should be thrown
    relation.makeCopy(Array("db", "test"))
    // No exception should be thrown

  test("SPARK-17409: Do Not Optimize Query in CTAS (Hive Serde Table) More Than Once") {
    withTable("bar") {
      withTempView("foo") {
        sql("select 0 as id").createOrReplaceTempView("foo")
        // If we optimize the query in CTAS more than once, the following saveAsTable will fail
        // with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])`
        sql("CREATE TABLE bar AS SELECT * FROM foo group by id")
        checkAnswer(spark.table("bar"), Row(0) :: Nil)
        val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar"))
        assert(tableMetadata.provider == Some("hive"), "the expected table is a Hive serde table")
Source File: SparkSQLExprMapperTest.scala    From morpheus   with Apache License 2.0 5 votes vote down vote up
package org.opencypher.morpheus.impl

import java.util.Collections

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.opencypher.morpheus.api.value.MorpheusElement._
import org.opencypher.morpheus.impl.ExprEval._
import org.opencypher.morpheus.impl.SparkSQLExprMapper._
import org.opencypher.morpheus.testing.fixture.SparkSessionFixture
import org.opencypher.okapi.api.types.CTInteger
import org.opencypher.okapi.api.value.CypherValue.CypherMap
import org.opencypher.okapi.relational.impl.table.RecordHeader
import org.opencypher.okapi.testing.BaseTestSuite

import scala.language.implicitConversions

/**
 * Tests the mapping of Cypher expressions (integer literals, ToId, PrefixId) to Spark SQL
 * expressions by evaluating them and comparing against the Morpheus id encoding.
 * NOTE(review): the arguments of `sparkSession.createDataFrame(` are missing from this text.
 */
class SparkSQLExprMapperTest extends BaseTestSuite with SparkSessionFixture {

  // Two integer-typed variables and the record header built from them.
  val vA: Var = Var("a")(CTInteger)
  val vB: Var = Var("b")(CTInteger)
  val header: RecordHeader = RecordHeader.from(vA, vB)

  // NOTE(review): this test has the same body as
  // "converts a CypherInteger to an ID and prefixes it" further down.
  it("converts prefix id expressions") {
    val id = 257L
    val prefix = 2.toByte
    val expr = PrefixId(ToId(IntegerLit(id)), prefix)
    expr.eval.asInstanceOf[Array[_]].toList should equal(prefix :: id.encodeAsMorpheusId.toList)

  it("converts a CypherInteger to an ID") {
    val id = 257L
    val expr = ToId(IntegerLit(id))
    expr.eval.asInstanceOf[Array[_]].toList should equal(id.encodeAsMorpheusId.toList)

  it("converts a CypherInteger to an ID and prefixes it") {
    val id = 257L
    val prefix = 2.toByte
    val expr = PrefixId(ToId(IntegerLit(id)), prefix)
    expr.eval.asInstanceOf[Array[_]].toList should equal(prefix :: id.encodeAsMorpheusId.toList)

  it("converts a CypherInteger literal") {
    val id = 257L
    val expr = IntegerLit(id)
    expr.eval.asInstanceOf[Long] should equal(id)

  // Converts an expression to a Spark Column against the suite-level DataFrame and no parameters.
  private def convert(expr: Expr, header: RecordHeader = header): Column = {
    expr.asSparkSQLExpr(header, df, CypherMap.empty)
  val df: DataFrame = sparkSession.createDataFrame(
  // Lets a (RecordHeader, T) tuple be used where a RecordHeader is expected.
  implicit def extractRecordHeaderFromResult[T](tuple: (RecordHeader, T)): RecordHeader = tuple._1

/** Test helper that adds an `eval` method to Cypher expressions. */
object ExprEval {

  implicit class ExprOps(val expr: Expr) extends AnyVal {

    /**
     * Converts the expression to a Spark SQL expression using an empty record header and
     * no parameters, then evaluates it against an empty InternalRow.
     * NOTE(review): the arguments of `spark.createDataFrame(` are missing from this text.
     */
    def eval(implicit spark: SparkSession): Any = {
      val df = spark.createDataFrame(

      expr.asSparkSQLExpr(RecordHeader.empty, df, CypherMap.empty).expr.eval(InternalRow.empty)
Example 97
Source File: YelpHelpers.scala    From morpheus   with Apache License 2.0 5 votes vote down vote up
package org.opencypher.morpheus.integration.yelp

import org.apache.spark.sql.types.{ArrayType, DateType, IntegerType, LongType}
import org.apache.spark.sql.{Column, DataFrame, SparkSession, functions}
import{sourceEndNodeKey, sourceStartNodeKey}
import org.opencypher.morpheus.impl.table.SparkTable._
import org.opencypher.morpheus.integration.yelp.YelpConstants._

/**
 * Helpers for loading, summarising and subsetting the Yelp JSON files
 * (business.json, review.json, user.json) found under an input path.
 *
 * NOTE(review): this excerpt is damaged -- closing braces are absent and the right-hand sides
 * of several vals have lost their leading call (only the trailing argument text remains).
 * Confirm every statement against the complete file.
 */
object YelpHelpers {

  // Bundle of the three data frames returned by loadYelpTables.
  case class YelpTables(
    userDf: DataFrame,
    businessDf: DataFrame,
    reviewDf: DataFrame

  // Builds the user, business and review data frames and returns them as YelpTables.
  def loadYelpTables(inputPath: String)(implicit spark: SparkSession): YelpTables = {
    import spark.implicits._

    // NOTE(review): the reader calls in front of the path strings are missing below.
    log("read business.json", 2)
    val rawBusinessDf ="$inputPath/business.json")
    log("read review.json", 2)
    val rawReviewDf ="$inputPath/review.json")
    log("read user.json", 2)
    val rawUserDf ="$inputPath/user.json")

    // NOTE(review): the select receivers are missing; the visible column lists alias the id
    // columns to sourceIdKey / sourceStartNodeKey / sourceEndNodeKey and cast date to DateType.
    val businessDf =$"business_id".as(sourceIdKey), $"business_id", $"name", $"address", $"city", $"state")
    val reviewDf =$"review_id".as(sourceIdKey), $"user_id".as(sourceStartNodeKey), $"business_id".as(sourceEndNodeKey), $"stars", $"date".cast(DateType))
    val userDf =
      // The elite column is split on commas and cast to an array of longs.
      functions.split($"elite", ",").cast(ArrayType(LongType)).as("elite"))

    YelpTables(userDf, businessDf, reviewDf)

  // Prints statistics derived from the business and review files.
  // NOTE(review): the aggregation and output steps of the chain are not visible here.
  def printYelpStats(inputPath: String)(implicit spark: SparkSession): Unit = {
    val rawBusinessDf ="$inputPath/business.json")
    val rawReviewDf ="$inputPath/review.json")

    import spark.implicits._$"city", $"state").distinct().show()
    // business_id is renamed to id so the join condition can tell both sides apart.
    rawBusinessDf.withColumnRenamed("business_id", "id")
      .join(rawReviewDf, $"id" === $"business_id")
      .groupBy($"city", $"state")
      .orderBy($"count".desc, $"state".asc)

  // Restricts businesses to the given city, then keeps only the reviews of those businesses
  // and the users who wrote those reviews (left-semi joins).
  // NOTE(review): nothing visible here uses outputPath; the write step appears to be missing.
  def extractYelpCitySubset(inputPath: String, outputPath: String, city: String)(implicit spark: SparkSession): Unit = {
    import spark.implicits._

    // Concatenates the named column with a literal.
    // NOTE(review): the literal is an empty string here, so the value is unchanged -- verify.
    def emailColumn(userId: String): Column = functions.concat($"$userId", functions.lit(""))

    val rawUserDf ="$inputPath/user.json")
    val rawReviewDf ="$inputPath/review.json")
    val rawBusinessDf ="$inputPath/business.json")

    val businessDf = rawBusinessDf.filter($"city" === city)
    // stars is re-created as an IntegerType column via a temporary rename.
    val reviewDf = rawReviewDf
      .join(businessDf, Seq("business_id"), "left_semi")
      .withColumn("user_email", emailColumn("user_id"))
      .withColumnRenamed("stars", "stars_tmp")
      .withColumn("stars", $"stars_tmp".cast(IntegerType))
    val userDf = rawUserDf
      .join(reviewDf, Seq("user_id"), "left_semi")
      .withColumn("email", emailColumn("user_id"))
    // One row per (user, friend): the friends string is split on comma-space and exploded.
    // NOTE(review): the s-interpolators in the final select interpolate nothing.
    val friendDf = userDf
      .select($"email".as("user1_email"), functions.explode(functions.split($"friends", ", ")).as("user2_id"))
      .withColumn("user2_email", emailColumn("user2_id"))
      .select(s"user1_email", s"user2_email")


  implicit class DataFrameOps(df: DataFrame) {
    // Rewrites idColumn so that every value is prefixed with the given literal.
    def prependIdColumn(idColumn: String, prefix: String): DataFrame =
      df.transformColumns(idColumn)(column => functions.concat(functions.lit(prefix), column).as(idColumn))
Example 98
Source File: Locus.scala    From hail   with MIT License 5 votes vote down vote up
package is.hail.variant

import is.hail.annotations.Annotation
import is.hail.check.Gen
import is.hail.expr.Parser
import is.hail.utils._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.json4s._

import scala.collection.JavaConverters._
import scala.language.implicitConversions

/**
 * Companion of the Locus case class: checked construction, parsing from strings, Spark
 * schema/row conversion, a generator and interval helpers.
 *
 * NOTE(review): closing braces are absent from this excerpt, and `sparkSchema` and the
 * array overload of `parseIntervals` have lost part of their bodies.
 */
object Locus {
  // The strings 1 to 22 followed by X, Y and MT.
  val simpleContigs: Seq[String] = (1 to 22).map(_.toString) ++ Seq("X", "Y", "MT")

  // Validates (contig, position) with rg.checkLocus before building the Locus.
  def apply(contig: String, position: Int, rg: ReferenceGenome): Locus = {
    rg.checkLocus(contig, position)
    Locus(contig, position)

  // With a reference genome the result is a checked Locus; without one it is a plain
  // Annotation of the two values.
  def annotation(contig: String, position: Int, rg: Option[ReferenceGenome]): Annotation = {
    rg match {
      case Some(ref) => Locus(contig, position, ref)
      case None => Annotation(contig, position)

  // Two non-nullable fields: contig (string) and position (integer).
  // NOTE(review): the expression that wraps these fields is missing from this excerpt.
  def sparkSchema: StructType =
      StructField("contig", StringType, nullable = false),
      StructField("position", IntegerType, nullable = false)))

  // Reads contig from column 0 and position from column 1.
  def fromRow(r: Row): Locus = {
    Locus(r.getAs[String](0), r.getInt(1))

  // Generates a contig with its length, then a position between 1 and that length.
  def gen(rg: ReferenceGenome): Gen[Locus] = for {
    (contig, length) <- Contig.gen(rg)
    pos <- Gen.choose(1, length)
  } yield Locus(contig, pos)

  // Parses contig:pos. Only the last colon-separated element is the position; everything
  // before it is re-joined with colons, so the contig itself may contain colons.
  def parse(str: String, rg: ReferenceGenome): Locus = {
    val elts = str.split(":")
    val size = elts.length
    if (size < 2)
      fatal(s"Invalid string for Locus. Expecting contig:pos -- found '$str'.")

    val contig = elts.take(size - 1).mkString(":")
    // NOTE(review): toInt throws NumberFormatException on a non-numeric position instead of
    // going through fatal like the check above -- confirm whether that is intended.
    Locus(contig, elts(size - 1).toInt, rg)

  def parseInterval(str: String, rg: ReferenceGenome, invalidMissing: Boolean = false): Interval =
    Parser.parseLocusInterval(str, rg, invalidMissing)

  // NOTE(review): the body of this overload is cut off; only its trailing arguments remain.
  def parseIntervals(arr: Array[String], rg: ReferenceGenome, invalidMissing: Boolean): Array[Interval] =, rg, invalidMissing))

  // Java-list overload: converts to an array and delegates to the array overload.
  def parseIntervals(arr: java.util.List[String], rg: ReferenceGenome, invalidMissing: Boolean = false): Array[Interval] = parseIntervals(arr.asScala.toArray, rg, invalidMissing)

  def makeInterval(contig: String, start: Int, end: Int, includesStart: Boolean, includesEnd: Boolean,
    rgBase: ReferenceGenome, invalidMissing: Boolean = false): Interval = {
    // NOTE(review): rgBase is already declared as ReferenceGenome, so this cast looks redundant.
    val rg = rgBase.asInstanceOf[ReferenceGenome]
    rg.toLocusInterval(Interval(Locus(contig, start), Locus(contig, end), includesStart, includesEnd), invalidMissing)

/**
 * A position on a contig.
 *
 * All region predicates delegate to the supplied reference genome; the class itself holds no
 * reference-genome state.
 *
 * @param contig   name of the contig
 * @param position position on that contig
 */
case class Locus(contig: String, position: Int) {
  /** This locus as a two-field Spark `Row` of (contig, position). */
  def toRow: Row = Row(contig, position)

  /** This locus as a JSON object with the fields `contig` and `position`. */
  def toJSON: JValue = JObject(
    ("contig", JString(contig)),
    ("position", JInt(position)))

  /**
   * Copies this locus, optionally replacing contig and/or position, after validating the
   * resulting pair with `rg.checkLocus`.
   */
  def copyChecked(rg: ReferenceGenome, contig: String = contig, position: Int = position): Locus = {
    rg.checkLocus(contig, position)
    Locus(contig, position)
  }

  /** True when the locus is autosomal or lies in the X or Y pseudo-autosomal region. */
  def isAutosomalOrPseudoAutosomal(rg: ReferenceGenome): Boolean = isAutosomal(rg) || inXPar(rg) || inYPar(rg)

  /** True when the contig is neither X, nor Y, nor mitochondrial according to `rg`. */
  def isAutosomal(rg: ReferenceGenome): Boolean = !(inX(rg) || inY(rg) || isMitochondrial(rg))

  def isMitochondrial(rg: ReferenceGenome): Boolean = rg.isMitochondrial(contig)

  def inXPar(rg: ReferenceGenome): Boolean = rg.inXPar(this)

  def inYPar(rg: ReferenceGenome): Boolean = rg.inYPar(this)

  /** On X but outside the X pseudo-autosomal region. */
  def inXNonPar(rg: ReferenceGenome): Boolean = inX(rg) && !inXPar(rg)

  /** On Y but outside the Y pseudo-autosomal region. */
  def inYNonPar(rg: ReferenceGenome): Boolean = inY(rg) && !inYPar(rg)

  private def inX(rg: ReferenceGenome): Boolean = rg.inX(contig)

  private def inY(rg: ReferenceGenome): Boolean = rg.inY(contig)

  /** Renders as contig:position. */
  override def toString: String = s"$contig:$position"
}
Example 99
Source File: DataFrameComparisonTest.scala    From spark-tools   with Apache License 2.0 5 votes vote down vote up
package io.univalence.sparktest

import io.univalence.schema.SchemaComparator.SchemaError
import org.apache.spark.SparkContext
import org.apache.spark.sql.{ Row, SparkSession }
import org.apache.spark.sql.types.{ IntegerType, StructField, StructType }
import org.scalatest.FunSuite

/**
 * Tests for comparing data frames (and a data frame with a Seq), covering equal content,
 * different content and different schemas.
 *
 * NOTE(review): closing braces and the comparison calls themselves are absent from this
 * excerpt -- the tests below show only their fixtures and the expected exception type.
 */
class DataFrameComparisonTest extends FunSuite with SparkTest {

  val sharedSparkSession: SparkSession = ss
  val sc: SparkContext                 = ss.sparkContext

  // TODO : unordered
  // Disabled: same values in a different row order.
  ignore("should assertEquals unordered between equal DF") {
    val dfUT       = Seq(1, 2, 3).toDF("id")
    val dfExpected = Seq(3, 2, 1).toDF("id")


  // TODO : unordered
  // Disabled: different values, so the comparison is expected to throw.
  ignore("should not assertEquals unordered between DF with different contents") {
    val dfUT       = Seq(1, 2, 3).toDF("id")
    val dfExpected = Seq(2, 1, 4).toDF("id")

    assertThrows[SparkTestError] {

  test("should assertEquals ordered between equal DF") {
    val dfUT       = Seq(1, 2, 3).toDF("id")
    val dfExpected = Seq(1, 2, 3).toDF("id")


  test("should not assertEquals ordered between DF with different contents") {
    val dfUT       = Seq(1, 2, 3).toDF("id")
    val dfExpected = Seq(1, 3, 4).toDF("id")

    assertThrows[SparkTestError] {

  // Same values but a different column name: a schema error is expected, not a content error.
  test("should not assertEquals between DF with different schema") {
    val dfUT       = Seq(1, 2, 3).toDF("id")
    val dfExpected = Seq(1, 2, 3).toDF("di")

    assertThrows[SchemaError] {

  test("assertEquals (DF & Seq) : a DF and a Seq with the same content are equal") {
    val seq = Seq(1, 2, 3)
    // NOTE(review): the row data argument of createDataFrame is missing; only the
    // single-column integer schema is visible.
    val df = ss.createDataFrame(
      StructType(List(StructField("number", IntegerType, nullable = true)))


  test("assertEquals (DF & Seq) : a DF and a Seq with different content are not equal") {
    val df    = Seq(1, 3, 3).toDF("number")
    val seqEx = Seq(1, 2, 3)

    assertThrows[SparkTestError] {

  // NOTE(review): the test name talks about equal data frames, but the two fixtures differ
  // and an exception is expected -- the name looks inconsistent with the body.
  test("should assertEquals ordered between equal DF with columns containing special character") {
    val dfUT       = Seq(1, 2, 3).toDF("id.a")
    val dfExpected = Seq(2, 1, 4).toDF("id.a")

    assertThrows[SparkTestError] {


Example 100
Source File: MetastoreRelationSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * Tests for MetastoreRelation: copying the relation and a CTAS regression (SPARK-17409).
 *
 * NOTE(review): closing braces are absent from this excerpt.
 */
class MetastoreRelationSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
  test("makeCopy and toJSON should work") {
    // A view db.test with a single integer column a and an empty storage format.
    val table = CatalogTable(
      identifier = TableIdentifier("test", Some("db")),
      tableType = CatalogTableType.VIEW,
      storage = CatalogStorageFormat.empty,
      schema = StructType(StructField("a", IntegerType, true) :: Nil))
    val relation = MetastoreRelation("db", "test")(table, null)

    // No exception should be thrown
    relation.makeCopy(Array("db", "test"))
    // No exception should be thrown
    // NOTE(review): the statement this comment refers to is missing; going by the test name
    // it is presumably the toJSON call -- verify.

  test("SPARK-17409: Do Not Optimize Query in CTAS (Hive Serde Table) More Than Once") {
    withTable("bar") {
      withTempView("foo") {
        sql("select 0 as id").createOrReplaceTempView("foo")
        // If we optimize the query in CTAS more than once, the following saveAsTable will fail
        // with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])`
        sql("CREATE TABLE bar AS SELECT * FROM foo group by id")
        checkAnswer(spark.table("bar"), Row(0) :: Nil)
        // The created table must be recorded with the hive provider.
        val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar"))
        assert(tableMetadata.provider == Some("hive"), "the expected table is a Hive serde table")
Example 101
Source File: SubstituteUnresolvedOrdinals.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.types.IntegerType

/**
 * Analyzer rule that rewrites integer literals used in Sort orders and Aggregate grouping
 * expressions, gated by conf.orderByOrdinal and conf.groupByOrdinal respectively.
 *
 * NOTE(review): closing braces are absent and the collection being mapped in both branches
 * has been lost (each `val ... = {` is followed directly by case clauses); the replacement
 * expression of the first grouping case is missing as well.
 */
class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] {
  // True only for a Literal whose data type is IntegerType.
  private def isIntLiteral(e: Expression) = e match {
    case Literal(_, IntegerType) => true
    case _ => false

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // Sort: only touched when ordinals are enabled and at least one order is an int literal.
    case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
      val newOrders = {
        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) =>
          // withOrigin keeps the origin of the node that is being replaced.
          val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
          withOrigin(order.origin)(order.copy(child = newOrdinal))
        case other => other
      withOrigin(s.origin)(s.copy(order = newOrders))

    // Aggregate: same gating, applied to the grouping expressions.
    case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) =>
      val newGroups = {
        case ordinal @ Literal(index: Int, IntegerType) =>
        case other => other
      withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
Example 102
Source File: ResolveTableValuedFunctions.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

      tvf("start" -> LongType, "end" -> LongType, "step" -> LongType,
          "numPartitions" -> IntegerType) {
        case Seq(start: Long, end: Long, step: Long, numPartitions: Int) =>
          Range(start, end, step, Some(numPartitions))

  // Resolves an UnresolvedTableValuedFunction once all of its arguments are resolved: the
  // function name is looked up in builtinFunctions and each registered (argList, resolver)
  // alternative is tried by implicitly casting the call arguments to argList.
  // NOTE(review): closing braces, both branches of the inner match, the receiver of the
  // argTypes expression and the call that consumes the error text are missing from this
  // excerpt -- confirm against the complete file.
  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
      builtinFunctions.get(u.functionName) match {
        case Some(tvf) =>
          val resolved = tvf.flatMap { case (argList, resolver) =>
            argList.implicitCast(u.functionArgs) match {
              case Some(casted) =>
              case _ =>
          // The first alternative that matched wins; otherwise an error text is built.
          resolved.headOption.getOrElse {
            val argTypes =", ")
              s"""error: table-valued function ${u.functionName} with alternatives:
                |${ => s" ($x)").mkString("\n")}
                |cannot be applied to: (${argTypes})""".stripMargin)
        // Unknown function name.
        case _ =>
          u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function")
Example 103
Source File: LogicalPlanSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType

/**
 * Tests for LogicalPlan.resolveOperators (counting how often a rule fires on Project nodes)
 * and for the isStreaming flag of composed plans.
 *
 * NOTE(review): closing braces are absent from this excerpt and some statements look elided.
 */
class LogicalPlanSuite extends SparkFunSuite {
  // Counter reset by every test and incremented each time the rule below matches a Project.
  private var invocationCount = 0
  // NOTE(review): only the counter increment is visible; the plan returned by the case
  // appears to be missing.
  private val function: PartialFunction[LogicalPlan, LogicalPlan] = {
    case p: Project =>
      invocationCount += 1

  private val testRelation = LocalRelation()

  test("resolveOperator runs on operators") {
    invocationCount = 0
    val plan = Project(Nil, testRelation)
    plan resolveOperators function

    assert(invocationCount === 1)

  // Two nested Project nodes: the rule is expected to fire once per node.
  test("resolveOperator runs on operators recursively") {
    invocationCount = 0
    val plan = Project(Nil, Project(Nil, testRelation))
    plan resolveOperators function

    assert(invocationCount === 2)

  // NOTE(review): the visible statements match the previous test, yet 0 invocations are
  // expected -- a step that marks the plan as resolved seems to be missing here.
  test("resolveOperator skips all ready resolved plans") {
    invocationCount = 0
    val plan = Project(Nil, Project(Nil, testRelation))
    plan resolveOperators function

    assert(invocationCount === 0)

  // NOTE(review): likewise, a step that marks plan1 as resolved seems to be missing, given
  // that only one invocation is expected for two Project nodes.
  test("resolveOperator skips partially resolved plans") {
    invocationCount = 0
    val plan1 = Project(Nil, testRelation)
    val plan2 = Project(Nil, plan1)
    plan2 resolveOperators function

    assert(invocationCount === 1)

  test("isStreaming") {
    val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
    // Same relation shape, but overridden to report itself as streaming.
    val incrementalRelation = new LocalRelation(
      Seq(AttributeReference("a", IntegerType, nullable = true)())) {
      override def isStreaming(): Boolean = true

    case class TestBinaryRelation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
      override def output: Seq[Attribute] = left.output ++ right.output

    // Preconditions on the fixtures, then: a binary node is expected to be streaming as
    // soon as at least one child is streaming.
    require(relation.isStreaming === false)
    require(incrementalRelation.isStreaming === true)
    assert(TestBinaryRelation(relation, relation).isStreaming === false)
    assert(TestBinaryRelation(incrementalRelation, relation).isStreaming === true)
    assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true)
    assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming)
Example 104
Source File: MapDataSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

/**
 * Asserts that map data instances built from equal contents still compare as unequal, for
 * both ArrayBasedMapData and UnsafeMapData.
 *
 * NOTE(review): closing braces are absent from this excerpt.
 */
class MapDataSuite extends SparkFunSuite {

  test("inequality tests") {
    def u(str: String): UTF8String = UTF8String.fromString(str)

    // test data
    // testMap1/testMap3 and testMap2/testMap4 are pairs with identical contents.
    val testMap1 = Map(u("key1") -> 1)
    val testMap2 = Map(u("key1") -> 1, u("key2") -> 2)
    val testMap3 = Map(u("key1") -> 1)
    val testMap4 = Map(u("key1") -> 1, u("key2") -> 2)

    // ArrayBasedMapData
    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
    assert(testArrayMap1 !== testArrayMap3)
    assert(testArrayMap2 !== testArrayMap4)

    // UnsafeMapData
    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    // One-column row that is reused for every conversion.
    val row = new GenericInternalRow(1)
    // NOTE(review): the expression that extracts the UnsafeMapData from unsafeRow is
    // missing; as shown the helper ends on a val definition.
    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
      row.update(0, map)
      val unsafeRow = unsafeConverter.apply(row)
    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
Example 105
Source File: ScalaUDFSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Tests for ScalaUDF: basic evaluation and the error reported when the user function fails.
 *
 * NOTE(review): closing braces are absent from this excerpt.
 */
class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("basic") {
    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
    checkEvaluation(intUdf, 2)

    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
    checkEvaluation(stringUdf, "ax")

  test("better error message for NPE") {
    // The function is applied to a null string literal, so calling a method on it fails.
    // NOTE(review): unlike the calls in the basic test, no data type argument is visible
    // here; it was probably lost from the excerpt.
    // NOTE(review): toLowerCase without an explicit Locale depends on the default locale.
    val udf = ScalaUDF(
      (s: String) => s.toLowerCase,
      Literal.create(null, StringType) :: Nil)

    // Interpreted evaluation must wrap the failure in a SparkException with this message.
    val e1 = intercept[SparkException](udf.eval())
    assert(e1.getMessage.contains("Failed to execute user defined function"))

    // The same message is expected on the unsafe-projection path.
    val e2 = intercept[SparkException] {
      checkEvalutionWithUnsafeProjection(udf, null)
    assert(e2.getMessage.contains("Failed to execute user defined function"))

Example 106
Source File: ExpressionEvalHelperSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}

/**
 * A non-nullable leaf expression of IntegerType whose interpreted result is 10.
 *
 * NOTE(review): the string delimiters around the generated code and all closing braces are
 * missing from this excerpt. The class name suggests the generated code is deliberately
 * faulty -- verify against the complete file.
 */
case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
        |int some_variable = 11;
        |int ${ev.value} = 10;
  override def dataType: DataType = IntegerType
Example 107
Source File: ObjectExpressionsSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{IntegerType, ObjectType}

/**
 * Tests for object expressions: Invoke returning null, and serializer output for arrays of
 * tuples, arrays and maps.
 *
 * NOTE(review): closing braces are absent, and the head of the check call before each
 * `...serializer.head, expected, inputRow)` line is missing from this excerpt.
 */
class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("SPARK-16622: The returned value of the called method in Invoke can be null") {
    // The input is a tuple whose second element is null; invoking _2 must yield null.
    val inputRow = InternalRow.fromSeq(Seq((false, null)))
    val cls = classOf[Tuple2[Boolean, java.lang.Integer]]
    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
    val invoke = Invoke(inputObject, "_2", IntegerType)
    checkEvaluationWithGeneratedMutableProjection(invoke, null, inputRow)

  test("MapObjects should make copies of unsafe-backed data") {
    // test UnsafeRow-backed data
    val structEncoder = ExpressionEncoder[Array[Tuple2[java.lang.Integer, java.lang.Integer]]]
    val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4))))
    val structExpected = new GenericArrayData(
      Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4))))
      structEncoder.serializer.head, structExpected, structInputRow)

    // test UnsafeArray-backed data
    val arrayEncoder = ExpressionEncoder[Array[Array[Int]]]
    val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4))))
    val arrayExpected = new GenericArrayData(
      Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4))))
      arrayEncoder.serializer.head, arrayExpected, arrayInputRow)

    // test UnsafeMap-backed data
    val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]]
    val mapInputRow = InternalRow.fromSeq(Seq(Array(
      Map(1 -> 100, 2 -> 200), Map(3 -> 300, 4 -> 400))))
    // Each expected map is given as a key array paired with a value array.
    val mapExpected = new GenericArrayData(Seq(
      new ArrayBasedMapData(
        new GenericArrayData(Array(1, 2)),
        new GenericArrayData(Array(100, 200))),
      new ArrayBasedMapData(
        new GenericArrayData(Array(3, 4)),
        new GenericArrayData(Array(300, 400)))))
      mapEncoder.serializer.head, mapExpected, mapInputRow)
Example 108
Source File: CallMethodViaReflectionSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Tests for CallMethodViaReflection: method lookup, type-check failures and invocation.
 *
 * NOTE(review): closing braces are absent from this excerpt; the start of the first
 * assertion in the input type checking test and both ends of createExpr are missing.
 */
class CallMethodViaReflectionSuite extends SparkFunSuite with ExpressionEvalHelper {

  import CallMethodViaReflection._

  // Get rid of the $ so we are getting the companion object's name.
  private val staticClassName = ReflectStaticClass.getClass.getName.stripSuffix("$")
  private val dynamicClassName = classOf[ReflectDynamicClass].getName

  test("findMethod via reflection for static methods") {
    assert(findMethod(staticClassName, "method1", Seq.empty).exists(_.getName == "method1"))
    assert(findMethod(staticClassName, "method2", Seq(IntegerType)).isDefined)
    assert(findMethod(staticClassName, "method3", Seq(IntegerType)).isDefined)
    assert(findMethod(staticClassName, "method4", Seq(IntegerType, StringType)).isDefined)

  test("findMethod for a JDK library") {
    assert(findMethod(classOf[java.util.UUID].getName, "randomUUID", Seq.empty).isDefined)

  // The three tests below inspect the message of the TypeCheckFailure that
  // checkInputDataTypes returns.
  test("class not found") {
    val ret = createExpr("some-random-class", "method").checkInputDataTypes()
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("not found") && errorMsg.contains("class"))

  test("method not found because name does not match") {
    val ret = createExpr(staticClassName, "notfoundmethod").checkInputDataTypes()
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("cannot find a static method"))

  test("method not found because there is no static method") {
    val ret = createExpr(dynamicClassName, "method1").checkInputDataTypes()
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("cannot find a static method"))

  test("input type checking") {
      Seq(Literal(staticClassName), Literal(1))).checkInputDataTypes().isFailure)
    assert(createExpr(staticClassName, "method1").checkInputDataTypes().isSuccess)

  test("invoking methods using acceptable types") {
    checkEvaluation(createExpr(staticClassName, "method1"), "m1")
    checkEvaluation(createExpr(staticClassName, "method2", 2), "m2")
    checkEvaluation(createExpr(staticClassName, "method3", 3), "m3")
    checkEvaluation(createExpr(staticClassName, "method4", 4, "four"), "m4four")

  // Builds the expression under test from string literals for the class and method names,
  // followed by further elements that are not visible here (presumably derived from args).
  private def createExpr(className: String, methodName: String, args: Any*) = {
      Literal.create(className, StringType) +:
      Literal.create(methodName, StringType) +:
Example 109
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.IntegerType

/**
 * Tests that AttributeSet identifies attributes by expression id rather than by name.
 *
 * NOTE(review): closing braces are absent from this excerpt.
 */
class AttributeSetSuite extends SparkFunSuite {

  // aUpper and aLower share expression id 1 and differ only in the case of the name;
  // fakeA has the same name as aLower but a different id.
  val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
  val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
  val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
  val aSet = AttributeSet(aLower :: Nil)

  // Same arrangement for b, with expression id 2.
  val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
  val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
  val bSet = AttributeSet(bUpper :: Nil)

  val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)

  // As plain references, the differently named attributes are not equal.
  test("sanity check") {
    assert(aUpper != aLower)
    assert(bUpper != bLower)

  test("checks by id not name") {
    assert(aSet.contains(aUpper) === true)
    assert(aSet.contains(aLower) === true)
    assert(aSet.contains(fakeA) === false)

    assert(aSet.contains(bUpper) === false)
    assert(aSet.contains(bLower) === false)

  test("++ preserves AttributeSet")  {
    assert((aSet ++ bSet).contains(aUpper) === true)
    assert((aSet ++ bSet).contains(aLower) === true)

  // NOTE(review): no assertion is visible in this test; the checks on addSet appear to be
  // missing from the excerpt.
  test("extracts all references references") {
    val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil)

  // Two references with the same expression id count as one element.
  test("dedups attributes") {
    assert(AttributeSet(aUpper :: aLower :: Nil).size === 1)

  test("subset") {
    assert(aSet.subsetOf(aAndBSet) === true)
    assert(aAndBSet.subsetOf(aSet) === false)

  test("equality") {
    assert(aSet != aAndBSet)
    assert(aAndBSet != aSet)
    assert(aSet != bSet)
    assert(bSet != aSet)

    // aSet was built from aLower; a set built from aUpper is expected to be equal to it.
    assert(aSet == aSet)
    assert(aSet == AttributeSet(aUpper :: Nil))
Example 110
Source File: RewriteDistinctAggregatesSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Tests for the RewriteDistinctAggregates rule. The tests with a single distinct group
 * compare the rule output with the unchanged input via comparePlans; checkRewrite accepts
 * only a plan shaped as Aggregate over Aggregate over Expand and fails otherwise.
 *
 * NOTE(review): this excerpt is damaged -- closing braces are absent, the testRelation
 * definition is cut off in the middle of its argument list, several groupBy calls have lost
 * arguments, and the later tests show only their input plan with no rewrite or check step.
 * Confirm every test against the complete file.
 */
class RewriteDistinctAggregatesSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, '

  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")

  test("single distinct group") {
    val input = testRelation
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)

  test("single distinct group with non-partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
        countDistinct('b, 'c),
Example 111
Source File: SimplifyConditionalSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{IntegerType, NullType}

/**
 * Tests for the SimplifyConditionals optimizer rule on If and CaseWhen expressions.
 *
 * NOTE(review): closing braces are absent from this excerpt, and inside the tests the call
 * heads and most expected values are missing -- only the input expressions remain.
 */
class SimplifyConditionalSuite extends PlanTest with PredicateHelper {

  // Optimizer that runs only SimplifyConditionals, up to a fixed point of 50 iterations.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil

  // Wraps both expressions in a one-row projection, optimizes only the plan built from e1
  // and compares it with the analyzed plan built from e2.
  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
    comparePlans(actual, correctAnswer)

  // (condition, value) pairs used as CaseWhen branches: a literal true condition, a
  // non-foldable true condition, a literal false condition and a null condition.
  private val trueBranch = (TrueLiteral, Literal(5))
  private val normalBranch = (NonFoldableLiteral(true), Literal(10))
  private val unreachableBranch = (FalseLiteral, Literal(20))
  private val nullBranch = (Literal.create(null, NullType), Literal(30))

  test("simplify if") {
      If(TrueLiteral, Literal(10), Literal(20)),

      If(FalseLiteral, Literal(10), Literal(20)),

      If(Literal.create(null, NullType), Literal(10), Literal(20)),

  test("remove unreachable branches") {
    // i.e. removing branches whose conditions are always false
    // In the visible pair the null-condition branch is dropped as well; only the normal
    // branch is kept.
      CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
      CaseWhen(normalBranch :: Nil, None))

  test("remove entire CaseWhen if only the else branch is reachable") {
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),

    // With no else value, the visible expected result is a null literal of IntegerType.
      CaseWhen(unreachableBranch :: unreachableBranch :: Nil, None),
      Literal.create(null, IntegerType))

  test("remove entire CaseWhen if the first branch is always true") {
      CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),

    // Test branch elimination and simplification in combination
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
        :: Nil, None),

    // Make sure this doesn't trigger if there is a non-foldable branch before the true branch
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None),
Example 112
Source File: resources.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command


import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

/**
 * Command that lists jars known to the Spark context, returning one row per jar path in a
 * single non-nullable string column named Results.
 *
 * NOTE(review): closing braces are absent, the source collection of the first generator is
 * missing, and the else branch is cut off -- confirm against the complete file.
 *
 * @param jars optional jar names to filter by; empty by default
 */
case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
  override val output: Seq[Attribute] = {
    AttributeReference("Results", StringType, nullable = false)() :: Nil
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val jarList = sparkSession.sparkContext.listJars()
    if (jars.nonEmpty) {
      // Each requested entry is reduced to its file name, and a listed jar is reported when
      // its path contains that name.
      for {
        jarName <- => new Path(f).getName)
        jarPath <- jarList if jarPath.contains(jarName)
      } yield Row(jarPath)
    } else {
Example 113
Source File: WholeStageCodegenSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {

  test("range/filter should be combined") {
    val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
    val plan = df.queryExecution.executedPlan
    assert(df.collect() === Array(Row(2)))

  test("Aggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(df.collect() === Array(Row(9, 4.5)))

  test("Aggregate with grouping keys should be included in WholeStageCodegen") {
    val df = spark.range(3).groupBy("id").count().orderBy("id")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))

  test("BroadcastHashJoin should be included in WholeStageCodegen") {
    val rdd = spark.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
    val schema = new StructType().add("k", IntegerType).add("v", StringType)
    val smallDF = spark.createDataFrame(rdd, schema)
    val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id"))
    assert(df.queryExecution.executedPlan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))

  test("Sort should be included in WholeStageCodegen") {
    val df = spark.range(3, 0, -1).toDF().sort(col("id"))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(df.collect() === Array(Row(1), Row(2), Row(3)))

  test("MapElements should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = spark.range(10).map(_.toString)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(ds.collect() === 0.until(10).map(_.toString).toArray)

  test("typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(ds.collect() === Array(0, 2, 4, 6, 8))

  test("back-to-back typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(ds.collect() === Array(0, 6))

  test("simple typed UDAF should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()

    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
    assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
Example 114
Source File: GroupedIteratorSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(,
      Seq(', schema.toAttributes)

    val result = {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) ->

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(,
      Seq(', ', schema.toAttributes)

    val result = {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1),

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(,
      Seq(', schema.toAttributes)

    assert(grouped.length == 2)
Example 115
Source File: CarbonDataFrameExample.scala    From CarbonDataLearning   with GNU General Public License v3.0 5 votes vote down vote up
package org.github.xubo245.carbonDataLearning.example

import org.apache.carbondata.examples.util.ExampleUtils
import org.apache.spark.sql.{SaveMode, SparkSession}

object CarbonDataFrameExample {

  def main(args: Array[String]) {
    val spark = ExampleUtils.createCarbonSession("CarbonDataFrameExample")

  def exampleBody(spark : SparkSession): Unit = {
    // Writes Dataframe to CarbonData file:
    import spark.implicits._
    val df = spark.sparkContext.parallelize(1 to 100)
      .map(x => ("a" + x % 10, "b", x))
      .toDF("c1", "c2", "number")

    // Saves dataframe to carbondata file
      .option("tableName", "carbon_df_table")
      .option("partitionColumns", "c1")  // a list of column names

    spark.sql(""" SELECT * FROM carbon_df_table """).show()

    spark.sql("SHOW PARTITIONS carbon_df_table").show()

    // Specify schema
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
    val customSchema = StructType(Array(
      StructField("c1", StringType),
      StructField("c2", StringType),
      StructField("number", IntegerType)))

    // Reads carbondata to dataframe
    val carbondf =
      // .option("dbname", "db_name") -- the system uses "default" as the dbname if this option is not set
      .option("tableName", "carbon_df_table")

      .option("tableName", "csv_df_table")
      .option("partitionColumns", "c1") // a list of column names
      //      .option("timestampFormat", "yyyy/MM/dd HH:mm:ss ZZ")

    // Reads carbondata to dataframe
    val carbondf2 =
      // .option("dbname", "db_name") -- the system uses "default" as the dbname if this option is not set
      .option("tableName", "csv_df_table")

      //      .option("timestampFormat", "yyyy/MM/dd HH:mm:ss ZZ")

    // Dataframe operations
    carbondf.printSchema()$"c1", $"number" + 10).show()
    carbondf.filter($"number" > 31).show()

    spark.sql("DROP TABLE IF EXISTS carbon_df_table")
Example 116
Source File: TestMetadataConstructor.scala    From spark-salesforce   with Apache License 2.0 5 votes vote down vote up
package com.springml.spark.salesforce.metadata

import org.apache.spark.sql.types.{StructType, StringType, IntegerType, LongType,
  FloatType, DateType, TimestampType, BooleanType, StructField}
import org.scalatest.FunSuite
import com.springml.spark.salesforce.Utils

class TestMetadataConstructor extends FunSuite {

  test("Test Metadata generation") {
    val columnNames = List("c1", "c2", "c3", "c4")
    val columnStruct = => StructField(colName, StringType, true))
    val schema = StructType(columnStruct)

    val schemaString = MetadataConstructor.generateMetaString(schema,"sampleDataSet", Utils.metadataConfig(null))
    assert(schemaString.length > 0)

  test("Test Metadata generation With Custom MetadataConfig") {
    val columnNames = List("c1", "c2", "c3", "c4")
    val intField = StructField("intCol", IntegerType, true)
    val longField = StructField("longCol", LongType, true)
    val floatField = StructField("floatCol", FloatType, true)
    val dateField = StructField("dateCol", DateType, true)
    val timestampField = StructField("timestampCol", TimestampType, true)
    val stringField = StructField("stringCol", StringType, true)
    val someTypeField = StructField("someTypeCol", BooleanType, true)

    val columnStruct = Array[StructField] (intField, longField, floatField, dateField, timestampField, stringField, someTypeField)

    val schema = StructType(columnStruct)

    var metadataConfig = Map("string" -> Map("wave_type" -> "Text"))
    metadataConfig += ("integer" -> Map("wave_type" -> "Numeric", "precision" -> "10", "scale" -> "0", "defaultValue" -> "100"))
    metadataConfig += ("float" -> Map("wave_type" -> "Numeric", "precision" -> "10", "scale" -> "2"))
    metadataConfig += ("long" -> Map("wave_type" -> "Numeric", "precision" -> "18", "scale" -> "0"))
    metadataConfig += ("date" -> Map("wave_type" -> "Date", "format" -> "yyyy/MM/dd"))
    metadataConfig += ("timestamp" -> Map("wave_type" -> "Date", "format" -> "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"))

    val schemaString = MetadataConstructor.generateMetaString(schema, "sampleDataSet", metadataConfig)
    assert(schemaString.length > 0)
Example 117
Source File: SparkScoreDoc.scala    From spark-lucenerdd   with Apache License 2.0 5 votes vote down vote up
package org.zouzias.spark.lucenerdd.models

import org.apache.lucene.document.Document
import org.apache.lucene.index.IndexableField
import{IndexSearcher, ScoreDoc}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.zouzias.spark.lucenerdd.models.SparkScoreDoc.inferNumericType
import org.zouzias.spark.lucenerdd.models.SparkScoreDoc.{DocIdField, ScoreField, ShardField}

import scala.collection.JavaConverters._

/** Tag describing the kind of value a field holds; sealed, so every variant is declared in this file. */
sealed trait FieldType extends Serializable
/** Tag that `inferNumericType` falls back to when a number matches none of its numeric cases. */
object TextType extends FieldType
/** Tag that `inferNumericType` returns for `java.lang.Integer` values. */
object IntType extends FieldType
/** Tag that `inferNumericType` returns for `java.lang.Double` values. */
object DoubleType extends FieldType
/** Tag that `inferNumericType` returns for `java.lang.Long` values. */
object LongType extends FieldType
/** Tag that `inferNumericType` returns for `java.lang.Float` values. */
object FloatType extends FieldType

  /**
   * Maps a boxed Java number to the [[FieldType]] tag matching its runtime class.
   *
   * @param num number to classify
   * @return `DoubleType`, `LongType`, `IntType` or `FloatType` for the corresponding
   *         `java.lang` wrapper class, and `TextType` for anything else
   */
  private def inferNumericType(num: Number): FieldType = {
    num match {
      case _: java.lang.Double => DoubleType
      case _: java.lang.Long => LongType
      case _: java.lang.Integer => IntType
      case _: java.lang.Float => FloatType
      // Fallback for every other Number implementation.
      case _ => TextType
Example 118
Source File: SummarizeSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries

import com.twosigma.flint.timeseries.row.Schema
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{ LongType, IntegerType, DoubleType }

class SummarizeSpec extends MultiPartitionSuite {

  override val defaultResourceDir: String = "/timeseries/summarize"

  it should "`summarize` correctly" in {
    val expectedSchema = Schema("volume_sum" -> DoubleType)
    val expectedResults = Array[Row](new GenericRowWithSchema(Array(0L, 7800.0), expectedSchema))

    def test(rdd: TimeSeriesRDD): Unit = {
      val results = rdd.summarize(Summarizers.sum("volume"))
      assert(results.schema == expectedSchema)
      assert(results.collect().deep == expectedResults.deep)

      val volumeRdd = fromCSV("Volume.csv", Schema("id" -> IntegerType, "volume" -> LongType))


  it should "`summarize` per key correctly" in {
    val expectedSchema = Schema("id" -> IntegerType, "volume_sum" -> DoubleType)
    val expectedResults = Array[Row](
      new GenericRowWithSchema(Array(0L, 7, 4100.0), expectedSchema),
      new GenericRowWithSchema(Array(0L, 3, 3700.0), expectedSchema)

    def test(rdd: TimeSeriesRDD): Unit = {
      val results = rdd.summarize(Summarizers.sum("volume"), Seq("id"))
      assert(results.schema == expectedSchema)
      assert(results.collect().sortBy(_.getAs[Int]("id")).deep == expectedResults.sortBy(_.getAs[Int]("id")).deep)

      val volumeTSRdd = fromCSV("Volume.csv", Schema("id" -> IntegerType, "volume" -> LongType))
Example 119
Source File: SummarizeCyclesSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries

import com.twosigma.flint.timeseries.row.Schema
import org.apache.spark.sql.types.{ DoubleType, IntegerType, LongType }

class SummarizeCyclesSpec extends MultiPartitionSuite with TimeSeriesTestData with TimeTypeSuite {

  override val defaultResourceDir: String = "/timeseries/summarizecycles"
  private val volumeSchema = Schema("id" -> IntegerType, "volume" -> LongType, "v2" -> DoubleType)
  private val volume2Schema = Schema("id" -> IntegerType, "volume" -> LongType)
  private val volumeWithGroupSchema = Schema(
    "id" -> IntegerType, "group" -> IntegerType, "volume" -> LongType, "v2" -> DoubleType

  "SummarizeCycles" should "pass `SummarizeSingleColumn` test." in {
    withAllTimeType {
      val resultTSRdd = fromCSV("SummarizeSingleColumn.results", Schema("volume_sum" -> DoubleType))

      def test(rdd: TimeSeriesRDD): Unit = {
        val summarizedVolumeTSRdd = rdd.summarizeCycles(Summarizers.sum("volume"))
        assertEquals(summarizedVolumeTSRdd, resultTSRdd)

      val volumeTSRdd = fromCSV("Volume.csv", volumeSchema)

  it should "pass `SummarizeSingleColumnPerKey` test, i.e. with additional a single key." in {
    withAllTimeType {
      val resultTSRdd = fromCSV(
        Schema("id" -> IntegerType, "volume_sum" -> DoubleType)

      def test(rdd: TimeSeriesRDD): Unit = {
        val summarizedVolumeTSRdd = rdd.summarizeCycles(Summarizers.sum("volume"), Seq("id"))
        assertEquals(summarizedVolumeTSRdd, resultTSRdd)

      val volumeTSRdd = fromCSV("Volume2.csv", volume2Schema)

  it should "pass `SummarizeSingleColumnPerSeqOfKeys` test, i.e. with additional a sequence of keys." in {
    withAllTimeType {
      val resultTSRdd = fromCSV(
        Schema("id" -> IntegerType, "group" -> IntegerType, "volume_sum" -> DoubleType)

      def test(rdd: TimeSeriesRDD): Unit = {
        val summarizedVolumeTSRdd = rdd.summarizeCycles(Summarizers.sum("volume"), Seq("id", "group"))
        assertEquals(summarizedVolumeTSRdd, resultTSRdd)

  it should "pass generated cycle data test" in {
    // TODO: The way cycleData works now doesn't support changing time type.
    val testData = cycleData1

    def sum(rdd: TimeSeriesRDD): TimeSeriesRDD = {
      rdd.summarizeCycles(Summarizers.compose(Summarizers.count(), Summarizers.sum("v1")))

Example 120
Source File: TimeSeriesRDDCacheSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries

import com.twosigma.flint.timeseries.row.Schema
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ DoubleType, IntegerType }
import org.scalatest.concurrent.Timeouts
import org.scalatest.tagobjects.Slow
import org.scalatest.time.{ Second, Span }

  "TimeSeriesRDD" should "correctly cache data" taggedAs Slow in {
    withResource("/timeseries/csv/Price.csv") { source =>
      val priceSchema = Schema("id" -> IntegerType, "price" -> DoubleType)
      val timeSeriesRdd = CSV.from(sqlContext, "file://" + source, sorted = true, schema = priceSchema)

      val slowTimeSeriesRdd = timeSeriesRdd.addColumns("new_column" -> DoubleType -> {
        row: Row =>
          row.getAs[Double]("price") + 1.0

      // run a dummy addColumns() to initialize TSRDD's internal state
      slowTimeSeriesRdd.addColumns("foo_column" -> DoubleType -> { _ => 1.0 })

      assert(slowTimeSeriesRdd.count() == 12)

      // this test succeeds only if all representations are correctly cached
      failAfter(Span(1, Second)) {
        assert(slowTimeSeriesRdd.toDF.collect().length == 12)
        assert(slowTimeSeriesRdd.orderedRdd.count() == 12)
        assert(slowTimeSeriesRdd.asInstanceOf[TimeSeriesRDDImpl].unsafeOrderedRdd.count == 12)
Example 121
Source File: CompositeSummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.summarize.summarizer

import com.twosigma.flint.timeseries.{ CSV, Summarizers, TimeSeriesRDD, TimeSeriesSuite }
import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import org.apache.spark.sql.types.{ DoubleType, IntegerType, StructType }

class CompositeSummarizerSpec extends SummarizerSuite {
  // Reuse mean summarizer data
  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/meansummarizer"

  lazy val init: Unit = {
    priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))

  "CompositeSummarizer" should "compute `mean` and `stddev` correctly" in {
    val result = priceTSRdd.summarize(
      Summarizers.compose(Summarizers.mean("price"), Summarizers.stddev("price"))
    val row = result.first()

    assert(row.getAs[Double]("price_mean") === 3.25)
    assert(row.getAs[Double]("price_stddev") === 1.8027756377319946)

  it should "throw exception for conflicting output columns" in {
    intercept[Exception] {
      priceTSRdd.summarize(Summarizers.compose(Summarizers.mean("price"), Summarizers.mean("price")))

  it should "handle conflicting output columns using prefix" in {
    val result = priceTSRdd.summarize(
      Summarizers.compose(Summarizers.mean("price"), Summarizers.mean("price").prefix("prefix"))

    val row = result.first()

    assert(row.getAs[Double]("price_mean") === 3.25)
    assert(row.getAs[Double]("prefix_price_mean") === 3.25)

  it should "handle null values" in {
    val inputWithNull = insertNullRows(priceTSRdd, "price")
    val row = inputWithNull.summarize(

    val count = priceTSRdd.count()
    assert(row.getAs[Long]("count") == 2 * count)
    assert(row.getAs[Long]("id_count") == 2 * count)
    assert(row.getAs[Long]("price_count") == count)
Example 122
Source File: MeanSummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.summarize.summarizer

import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import com.twosigma.flint.timeseries.{ Summarizers, TimeSeriesSuite }
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

class MeanSummarizerSpec extends SummarizerSuite {

  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/meansummarizer"

  "MeanSummarizer" should "compute `mean` correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    val result = priceTSRdd.summarize(Summarizers.mean("price")).first()
    assert(result.getAs[Double]("price_mean") === 3.25)

  it should "ignore null values" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
      insertNullRows(priceTSRdd, "price").summarize(Summarizers.mean("price"))

  it should "pass summarizer property test" in {
Example 123
Source File: ExtremeSummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.summarize.summarizer

import com.twosigma.flint.rdd.function.summarize.summarizer.Summarizer
import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.{ SummarizerFactory, SummarizerSuite }
import com.twosigma.flint.timeseries.{ CSV, Summarizers, TimeSeriesRDD, TimeSeriesSuite }
import org.apache.spark.sql.types.{ DataType, DoubleType, FloatType, IntegerType, LongType, StructType }
import java.util.Random

class ExtremeSummarizerSpec extends SummarizerSuite {

  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/meansummarizer"

  /**
   * Shared check for the extreme-value (max/min) summarizer tests.
   *
   * Adds a column named `inputColumn` of type `dataType`, filled by `randValue`, to the
   * series read from "Price.csv"; folds the collected column values with `reduceFn` to get
   * the expected extreme; then asserts that the summarizer's `outputColumn` has the same
   * data type and the same value.
   *
   * NOTE(review): the closing `)` of the `addColumns(...)` call and the closing brace of
   * this method are missing from this listing -- confirm against the original file.
   *
   * @param dataType     Spark SQL type of the generated input column
   * @param randValue    produces the value of the generated column for each row
   * @param summarizer   builds the summarizer under test from the input column name
   * @param reduceFn     binary function used to compute the expected extreme
   * @param inputColumn  name of the generated column that is summarized
   * @param outputColumn name of the column holding the summarizer's result
   * @tparam T           Scala type of the values read from both columns
   */
  private def test[T](
    dataType: DataType,
    randValue: Row => Any,
    summarizer: String => SummarizerFactory,
    reduceFn: (T, T) => T,
    inputColumn: String,
    outputColumn: String
  ): Unit = {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType)).addColumns(
      inputColumn -> dataType -> randValue

    // Pull the generated column back to the driver to compute the reference value.
    val data = priceTSRdd.collect().map{ row => row.getAs[T](inputColumn) }

    // Expected extreme: left fold of the collected values with the supplied reducer.
    val trueExtreme = data.reduceLeft[T]{ case (x, y) => reduceFn(x, y) }

    val result = priceTSRdd.summarize(summarizer(inputColumn))

    val extreme = result.first().getAs[T](outputColumn)
    val outputType = result.schema(outputColumn).dataType

    // The output column must keep the input column's type and hold the expected extreme.
    assert(outputType == dataType, s"$outputType")
    assert(trueExtreme === extreme, s"extreme: $extreme, trueExtreme: $trueExtreme, data: ${data.toSeq}")

    val rand = new Random()
    test[Double](DoubleType, { _: Row => rand.nextDouble() }, Summarizers.max, math.max, "x", "x_max")

  it should "compute long max correctly" in {
    val rand = new Random()
    test[Long](LongType, { _: Row => rand.nextLong() }, Summarizers.max, math.max, "x", "x_max")

  it should "compute float max correctly" in {
    val rand = new Random()
    test[Float](FloatType, { _: Row => rand.nextFloat() }, Summarizers.max, math.max, "x", "x_max")

  it should "compute int max correctly" in {
    val rand = new Random()
    test[Int](IntegerType, { _: Row => rand.nextInt() }, Summarizers.max, math.max, "x", "x_max")

  "MinSummarizer" should "compute double min correctly" in {
    val rand = new Random()
    test[Double](DoubleType, { _: Row => rand.nextDouble() }, Summarizers.min, math.min, "x", "x_min")

  it should "compute long min correctly" in {
    val rand = new Random()
    test[Long](LongType, { _: Row => rand.nextLong() }, Summarizers.min, math.min, "x", "x_min")

  it should "compute float min correctly" in {
    val rand = new Random()
    test[Float](FloatType, { _: Row => rand.nextFloat() }, Summarizers.min, math.min, "x", "x_min")

  it should "compute int min correctly" in {
    val rand = new Random()
    test[Int](IntegerType, { _: Row => rand.nextInt() }, Summarizers.min, math.min, "x", "x_min")

  it should "ignore null values" in {
    val input = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    val inputWithNull = insertNullRows(input, "price")

Example 124
Source File: GeometricMeanSummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.summarize.summarizer.subtractable

import com.twosigma.flint.timeseries.{ Summarizers, Windows }
import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

class GeometricMeanSummarizerSpec extends SummarizerSuite {

  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/geometricmeansummarizer"

  "GeometricMeanSummarizer" should "compute geometric mean correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema(
      "id" -> IntegerType,
      "price" -> DoubleType,
      "priceWithZero" -> DoubleType,
      "priceWithNegatives" -> DoubleType
    val results = priceTSRdd.summarize(Summarizers.geometricMean("price"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("price_geometricMean") === 2.621877636494)
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("price_geometricMean") === 2.667168275340)

  it should "compute geometric mean with a zero correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema(
      "id" -> IntegerType,
      "price" -> DoubleType,
      "priceWithZero" -> DoubleType,
      "priceWithNegatives" -> DoubleType
    var results = priceTSRdd.summarize(Summarizers.geometricMean("priceWithZero")).collect()
    assert(results.head.getAs[Double]("priceWithZero_geometricMean") === 0.0)

    // Test that the result is still computed correctly after a zero exits the window.
    results = priceTSRdd.coalesce(1).summarizeWindows(
      Windows.pastAbsoluteTime("50 ns"),
    assert(results.head.getAs[Double]("priceWithZero_geometricMean") === 0.0)
    assert(results.last.getAs[Double]("priceWithZero_geometricMean") === 5.220043408524)

  it should "compute geometric mean with negative values correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema(
      "id" -> IntegerType,
      "price" -> DoubleType,
      "priceWithZero" -> DoubleType,
      "priceWithNegatives" -> DoubleType
    val results = priceTSRdd.summarize(Summarizers.geometricMean("priceWithNegatives"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("priceWithNegatives_geometricMean")
      === -2.621877636494)
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("priceWithNegatives_geometricMean")
      === 2.667168275340)

Example 125
Source File: DotProductSummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.summarize.summarizer.subtractable

import com.twosigma.flint.timeseries.Summarizers
import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

class DotProductSummarizerSpec extends SummarizerSuite {

  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/dotproductsummarizer"

  "DotProductSummarizer" should "compute dot product correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    val results = priceTSRdd.summarize(Summarizers.dotProduct("price", "price"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("price_price_dotProduct") === 72.25)
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("price_price_dotProduct") === 90.25)

    summarizerPropertyTest(AllPropertiesAndSubtractable)(Summarizers.dotProduct("x1", "x2"))
Example 126
Source File: ProductSummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.summarize.summarizer.subtractable

import com.twosigma.flint.timeseries.{ Summarizers, Windows }
import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

class ProductSummarizerSpec extends SummarizerSuite {

  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/productsummarizer"

  "ProductSummarizer" should "compute product correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema(
      "id" -> IntegerType,
      "price" -> DoubleType,
      "priceWithZero" -> DoubleType,
      "priceWithNegatives" -> DoubleType
    val results = priceTSRdd.summarize(Summarizers.product("price"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("price_product") === 324.84375)
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("price_product") === 360.0)

  it should "compute product with a zero correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema(
      "id" -> IntegerType,
      "price" -> DoubleType,
      "priceWithZero" -> DoubleType,
      "priceWithNegatives" -> DoubleType
    var results = priceTSRdd.summarize(Summarizers.product("priceWithZero")).collect()
    assert(results.head.getAs[Double]("priceWithZero_product") === 0.0)

    // Test that the result is still computed correctly after a zero exits the window.
    results = priceTSRdd.coalesce(1).summarizeWindows(
      Windows.pastAbsoluteTime("50 ns"),
    assert(results.head.getAs[Double]("priceWithZero_product") === 0.0)
    assert(results.last.getAs[Double]("priceWithZero_product") === 742.5)

  it should "compute product with negative values correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema(
      "id" -> IntegerType,
      "price" -> DoubleType,
      "priceWithZero" -> DoubleType,
      "priceWithNegatives" -> DoubleType
    val results = priceTSRdd.summarize(Summarizers.product("priceWithNegatives"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("priceWithNegatives_product") === -324.84375)
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("priceWithNegatives_product") === 360.0)

Example 127
Source File: StandardizedMomentSummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.summarize.summarizer.subtractable

import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import com.twosigma.flint.timeseries.Summarizers
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

class StandardizedMomentSummarizerSpec extends SummarizerSuite {

  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/standardizedmomentsummarizer"

  "SkewnessSummarizer" should "compute skewness correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    val results = priceTSRdd.summarize(Summarizers.skewness("price"))
    assert(results.collect().head.getAs[Double]("price_skewness") === 0.0)

  it should "ignore null values" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
      insertNullRows(priceTSRdd, "price").summarize(Summarizers.skewness("price"))

  it should "pass summarizer property test" in {

  "KurtosisSummarizer" should "compute kurtosis correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    val results = priceTSRdd.summarize(Summarizers.kurtosis("price"))
    assert(results.collect().head.getAs[Double]("price_kurtosis") === -1.2167832167832167)

  it should "ignore null values" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
      insertNullRows(priceTSRdd, "price").summarize(Summarizers.kurtosis("price"))

Example 128
Source File: ZScoreSummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.summarize.summarizer.subtractable

import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import com.twosigma.flint.timeseries.Summarizers
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

class ZScoreSummarizerSpec extends SummarizerSuite {

  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/zscoresummarizer"

  "ZScoreSummarizer" should "compute in-sample `zScore` correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    val expectedSchema = Schema("price_zScore" -> DoubleType)
    val expectedResults = Array[Row](new GenericRowWithSchema(Array(0L, 1.5254255396193801), expectedSchema))
    val results = priceTSRdd.summarize(Summarizers.zScore("price", true))
    assert(results.schema == expectedSchema)
    assert(results.collect().deep == expectedResults.deep)

  it should "compute out-of-sample `zScore` correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    val expectedSchema = Schema("price_zScore" -> DoubleType)
    val expectedResults = Array[Row](new GenericRowWithSchema(Array(0L, 1.8090680674665818), expectedSchema))
    val results = priceTSRdd.summarize(Summarizers.zScore("price", false))
    assert(results.schema == expectedSchema)
    assert(results.collect().deep == expectedResults.deep)

  it should "ignore null values" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
      priceTSRdd.summarize(Summarizers.zScore("price", true)),
      insertNullRows(priceTSRdd, "price").summarize(Summarizers.zScore("price", true))

  it should "pass summarizer property test" in {
    summarizerPropertyTest(AllPropertiesAndSubtractable)(Summarizers.zScore("x2", false))
Example 129
Source File: StandardDeviationSummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.summarize.summarizer

import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import com.twosigma.flint.timeseries.Summarizers
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

/**
 * Spec for `Summarizers.stddev`.
 *
 * NOTE(review): closing delimiters are missing from this listing and the last test has no
 * visible body; confirm against the upstream file.
 */
class StandardDeviationSummarizerSpec extends SummarizerSuite {
  // Intentionally reuses the resource files of the mean summarizer (see the path below).
  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/meansummarizer"

  // Adds four columns derived from "price" (copy, negation, doubling, constant zero), then
  // checks the standard deviation of "price" itself. The derived columns are not read by
  // the visible assertion.
  "StandardDeviationSummarizer" should "compute `stddev` correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType)).addColumns(
      "price2" -> DoubleType -> { r: Row => r.getAs[Double]("price") },
      "price3" -> DoubleType -> { r: Row => -r.getAs[Double]("price") },
      "price4" -> DoubleType -> { r: Row => r.getAs[Double]("price") * 2 },
      "price5" -> DoubleType -> { r: Row => 0d }

    val result = priceTSRdd.summarize(Summarizers.stddev("price")).first()
    assert(result.getAs[Double]("price_stddev") === 1.802775638)

  // NOTE(review): only the null-injected summary is visible; the expression it is compared
  // with appears to have been dropped from this listing.
  it should "ignore null values" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
      insertNullRows(priceTSRdd, "price").summarize(Summarizers.stddev("price"))

  it should "pass summarizer property test" in {
Example 130
Source File: PredicateSummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.summarize.summarizer

import com.twosigma.flint.timeseries.{ Summarizers, TimeSeriesRDD }
import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

/**
 * Spec for summarizers restricted with `.where(predicate)(column)`.
 *
 * NOTE(review): closing delimiters are missing from this listing; confirm against the
 * upstream file.
 */
class PredicateSummarizerSpec extends SummarizerSuite {
  // Reuses the mean summarizer's resource files.
  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/meansummarizer"
  // Populated by `init` below; stays null until `init` is evaluated.
  var priceTSRdd: TimeSeriesRDD = _

  // Lazy so that the CSV is loaded at most once, on first reference.
  // NOTE(review): no reference to `init` is visible in the tests below, so as shown
  // `priceTSRdd` would still be null when it is used. Confirm the tests force `init`.
  private lazy val init = {
    priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))

  // Summarizing with a predicate on "id" must give the same mean and stddev as first
  // keeping only the rows with id == 3 and then summarizing without a predicate.
  "PredicateSummarizer" should "return the same results as filtering TSRDD first" in {
    val summarizer = Summarizers.compose(Summarizers.mean("price"), Summarizers.stddev("price"))

    val predicate: Int => Boolean = id => id == 3
    val resultWithPredicate = priceTSRdd.summarize(summarizer.where(predicate)("id")).first()

    val filteredTSRDD = priceTSRdd.keepRows {
      row: Row => row.getAs[Int]("id") == 3
    val filteredResults = filteredTSRDD.summarize(summarizer).first()

    assert(resultWithPredicate.getAs[Double]("price_mean") === filteredResults.getAs[Double]("price_mean"))
    assert(resultWithPredicate.getAs[Double]("price_stddev") === filteredResults.getAs[Double]("price_stddev"))

  // NOTE(review): only the predicate definition is visible; the property-test call that
  // uses it is missing from this listing.
  it should "pass summarizer property test" in {
    val predicate: Double => Boolean = num => num > 0
Example 131
Source File: VarianceSummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.summarize.summarizer

import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import com.twosigma.flint.timeseries.Summarizers
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

class VarianceSummarizerSpec extends SummarizerSuite {
  // It is by intention to reuse the files
  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/meansummarizer"

  // The title previously read "StandardDeviationSummarizer" / `stddev`, copied from the
  // stddev spec. This test exercises Summarizers.variance, so the old title made a failure
  // point at the wrong summarizer.
  //
  // Adds four columns derived from "price" (copy, negation, doubling, constant zero), then
  // checks the variance of "price" itself against the expected value.
  "VarianceSummarizer" should "compute `variance` correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType)).addColumns(
      "price2" -> DoubleType -> { r: Row => r.getAs[Double]("price") },
      "price3" -> DoubleType -> { r: Row => -r.getAs[Double]("price") },
      "price4" -> DoubleType -> { r: Row => r.getAs[Double]("price") * 2 },
      "price5" -> DoubleType -> { r: Row => 0d }
    )

    val result = priceTSRdd.summarize(Summarizers.variance("price")).first()
    assert(result.getAs[Double]("price_variance") === 3.250000000)
  }

  // Summarizes the variance of "price" after `insertNullRows` has been applied.
  // NOTE(review): only the null-injected summary is visible; the expression it is compared
  // with, and the closing delimiters, appear to have been dropped from this listing.
  it should "ignore null values" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
      insertNullRows(priceTSRdd, "price").summarize(Summarizers.variance("price"))

Example 132
Source File: CovarianceSummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.summarize.summarizer

import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import com.twosigma.flint.timeseries.{ Summarizers, TimeSeriesRDD, TimeSeriesSuite }
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

/**
 * Spec for `Summarizers.covariance`, computed per "id" over a price/forecast join.
 *
 * NOTE(review): closing delimiters are missing from this listing, and nothing visible
 * forces `init` before `input` is used; confirm against the upstream file.
 */
class CovarianceSummarizerSpec extends SummarizerSuite {

  // Intentionally reuses the resource files of the correlation summarizer.
  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/correlationsummarizer"

  // Fixtures filled in by `init`; null until it has been evaluated.
  private var priceTSRdd: TimeSeriesRDD = null
  private var forecastTSRdd: TimeSeriesRDD = null
  private var input: TimeSeriesRDD = null

  // Loads both CSVs, left-joins forecast onto price by "id", and adds four columns derived
  // from "price": a copy, its negation, its double, and a constant zero.
  private lazy val init: Unit = {
    priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    forecastTSRdd = fromCSV("Forecast.csv", Schema("id" -> IntegerType, "forecast" -> DoubleType))
    input = priceTSRdd.leftJoin(forecastTSRdd, key = Seq("id")).addColumns(
      "price2" -> DoubleType -> { r: Row => r.getAs[Double]("price") },
      "price3" -> DoubleType -> { r: Row => -r.getAs[Double]("price") },
      "price4" -> DoubleType -> { r: Row => r.getAs[Double]("price") * 2 },
      "price5" -> DoubleType -> { r: Row => 0d }

  // For each pairing, checks the per-id covariance for ids 7 and 3. The expected values for
  // the derived columns are the price2 values negated (price3), doubled (price4) and zero
  // (price5).
  "CovarianceSummarizer" should "`computeCovariance` correctly" in {
    var results = input.summarize(Summarizers.covariance("price", "price2"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("price_price2_covariance") === 3.368055556)
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("price_price2_covariance") === 2.534722222)

    results = input.summarize(Summarizers.covariance("price", "price3"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("price_price3_covariance") === -3.368055556)
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("price_price3_covariance") === -2.534722222)

    results = input.summarize(Summarizers.covariance("price", "price4"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("price_price4_covariance") === 6.736111111)
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("price_price4_covariance") === 5.069444444)

    results = input.summarize(Summarizers.covariance("price", "price5"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("price_price5_covariance") === 0d)
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("price_price5_covariance") === 0d)

    results = input.summarize(Summarizers.covariance("price", "forecast"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("price_forecast_covariance") === -0.190277778)
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("price_forecast_covariance") === -3.783333333)

  // Builds the covariance with and without injected null rows, both globally and per "id".
  // NOTE(review): the call that compares each pair is not visible in this listing.
  it should "ignore null values" in {
    val inputWithNull = insertNullRows(input, "price", "forecast")

      inputWithNull.summarize(Summarizers.covariance("price", "forecast")),
      input.summarize(Summarizers.covariance("price", "forecast"))

      inputWithNull.summarize(Summarizers.covariance("price", "forecast"), Seq("id")),
      input.summarize(Summarizers.covariance("price", "forecast"), Seq("id"))

  // Delegates to the suite's generic property test, using columns "x0" and "x3".
  it should "pass summarizer property test" in {
    summarizerPropertyTest(AllProperties)(Summarizers.covariance("x0", "x3"))
Example 133
Source File: SummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries

import com.twosigma.flint.timeseries.row.Schema
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

/**
 * Spec for naming a summarizer's output via `.prefix(...)`.
 *
 * NOTE(review): the listing stops after `result` is computed; the assertions on it and
 * the closing delimiters are not visible.
 */
class SummarizerSpec extends TimeSeriesSuite {

  // Loads Price.csv with explicit column names C1/C2, checks the schema round-trips, then
  // summarizes with a count summarizer prefixed with "alias".
  "SummarizerFactory" should "support alias." in {
    withResource("/timeseries/csv/Price.csv") { source =>
      val expectedSchema = Schema("C1" -> IntegerType, "C2" -> DoubleType)
      val timeseriesRdd = CSV.from(sqlContext, "file://" + source, sorted = true, schema = expectedSchema)
      assert(timeseriesRdd.schema == expectedSchema)
      val result: Row = timeseriesRdd.summarize(Summarizers.count().prefix("alias")).first()
Example 134
Source File: SummarizeIntervalsSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries

import com.twosigma.flint.timeseries.row.Schema
import org.apache.spark.sql.types.{ DoubleType, LongType, IntegerType }

/**
 * Spec for `TimeSeriesRDD.summarizeIntervals`: sums "volume" / "v2" over the intervals
 * given by a clock series, with no key, a single key, and a sequence of keys.
 *
 * NOTE(review): this listing has lost lines. Closing delimiters are missing, several
 * `fromCSV(` calls show only their schema argument, and the last `summarizeIntervals(`
 * call shows only its key argument. Confirm against the upstream file.
 */
class SummarizeIntervalsSpec extends MultiPartitionSuite with TimeSeriesTestData with TimeTypeSuite {

  override val defaultResourceDir: String = "/timeseries/summarizeintervals"

  // No key: sum of "volume" per clock interval.
  "SummarizeInterval" should "pass `SummarizeSingleColumn` test." in {
    withAllTimeType {
      val volumeTSRdd = fromCSV(
        "Volume.csv", Schema("id" -> IntegerType, "volume" -> LongType, "v2" -> DoubleType)

      val clockTSRdd = fromCSV("Clock.csv", Schema())
      val resultTSRdd = fromCSV("SummarizeSingleColumn.results", Schema("volume_sum" -> DoubleType))

      def test(rdd: TimeSeriesRDD): Unit = {
        val summarizedVolumeTSRdd = rdd.summarizeIntervals(clockTSRdd, Summarizers.sum("volume"))
  // Single key "id": sums of "volume" and of "v2" per interval and id, each compared with
  // its own expected result set.
  it should "pass `SummarizeSingleColumnPerKey` test, i.e. with additional a single key." in {
    withAllTimeType {
      val volumeTSRdd = fromCSV(
        "Volume.csv", Schema("id" -> IntegerType, "volume" -> LongType, "v2" -> DoubleType)

      val clockTSRdd = fromCSV("Clock.csv", Schema())
      val resultTSRdd = fromCSV(
        Schema("id" -> IntegerType, "volume_sum" -> DoubleType)

      val result2TSRdd = fromCSV(
        Schema("id" -> IntegerType, "v2_sum" -> DoubleType)

      def test(rdd: TimeSeriesRDD): Unit = {
        val summarizedVolumeTSRdd = rdd.summarizeIntervals(clockTSRdd, Summarizers.sum("volume"), Seq("id"))
        assertEquals(summarizedVolumeTSRdd, resultTSRdd)
        val summarizedV2TSRdd = rdd.summarizeIntervals(clockTSRdd, Summarizers.sum("v2"), Seq("id"))
        assertEquals(summarizedV2TSRdd, result2TSRdd)


  // Two keys ("id", "group").
  it should "pass `SummarizeSingleColumnPerSeqOfKeys` test, i.e. with additional a sequence of keys." in {
    withAllTimeType {
      val volumeTSRdd = fromCSV(
        Schema("id" -> IntegerType, "group" -> IntegerType, "volume" -> LongType, "v2" -> DoubleType)

      val clockTSRdd = fromCSV("Clock.csv", Schema())
      val resultTSRdd = fromCSV(
        Schema("id" -> IntegerType, "group" -> IntegerType, "volume_sum" -> DoubleType)

      def test(rdd: TimeSeriesRDD): Unit = {
        val summarizedVolumeTSRdd = rdd.summarizeIntervals(
          Seq("id", "group")
        assertEquals(summarizedVolumeTSRdd, resultTSRdd)

Example 135
Source File: MergeSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries

import com.twosigma.flint.timeseries.row.Schema
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

/**
 * Spec for `TimeSeriesRDD.merge`.
 *
 * NOTE(review): closing delimiters are missing from this listing and the second test is
 * cut off after the `merge` helper's signature; confirm against the upstream file.
 */
class MergeSpec extends MultiPartitionSuite with TimeSeriesTestData {
  override val defaultResourceDir: String = "/timeseries/merge"

  // Merges Price1.csv with Price2.csv and compares both schema and rows with Merge.results.
  // The check is run through `withPartitionStrategy(...)(DEFAULT)`.
  "Merge" should "pass `Merge` test." in {
    val resultsTSRdd = fromCSV("Merge.results", Schema("id" -> IntegerType, "price" -> DoubleType))

    def test(rdd1: TimeSeriesRDD, rdd2: TimeSeriesRDD): Unit = {
      val mergedTSRdd = rdd1.merge(rdd2)
      assert(resultsTSRdd.schema == mergedTSRdd.schema)
      assert(resultsTSRdd.collect().deep == mergedTSRdd.collect().deep)

      val priceTSRdd1 = fromCSV("Price1.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
      val priceTSRdd2 = fromCSV("Price2.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
      withPartitionStrategy(priceTSRdd1, priceTSRdd2)(DEFAULT)(test)

  // Uses the generated fixtures `cycleData1` / `cycleData2`, which are not defined in this
  // listing.
  it should "pass generated cycle data test" in {
    val testData1 = cycleData1
    val testData2 = cycleData2

    def merge(rdd1: TimeSeriesRDD, rdd2: TimeSeriesRDD): TimeSeriesRDD = {
Example 136
Source File: BasicDataSourceSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.datasource

import com.pingcap.tikv.exception.TiBatchWriteException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

/**
 * Select / append / overwrite tests against a two-column table `(i int, s varchar(128))`.
 *
 * `row1` and `row2` mirror the two rows inserted through JDBC below; `row3` and `row4`
 * are the rows written through the data source by the tests.
 *
 * NOTE(review): this listing has lost lines. The two `jdbcUpdate` calls sit directly in
 * the class body, the `if (!supportBatchWrite) {` guards have no body, the write chains
 * start at `.option(...)`, and closing delimiters are missing. Confirm against the
 * upstream file.
 */
class BasicDataSourceSuite extends BaseDataSourceTest("test_datasource_basic") {
  private val row1 = Row(null, "Hello")
  private val row2 = Row(2, "TiDB")
  private val row3 = Row(3, "Spark")
  private val row4 = Row(4, null)

  private val schema = StructType(
    List(StructField("i", IntegerType), StructField("s", StringType)))

    jdbcUpdate(s"create table $dbtable(i int, s varchar(128))")
    jdbcUpdate(s"insert into $dbtable values(null, 'Hello'), (2, 'TiDB')")

  // Reads back exactly the two rows inserted through JDBC.
  test("Test Select") {
    if (!supportBatchWrite) {

    testTiDBSelect(Seq(row1, row2))

  // After appending row3 and row4 the table is expected to hold all four rows.
  test("Test Write Append") {
    if (!supportBatchWrite) {

    val data: RDD[Row] = sc.makeRDD(List(row3, row4))
    val df = sqlContext.createDataFrame(data, schema)

      .option("database", database)
      .option("table", table)

    testTiDBSelect(Seq(row1, row2, row3, row4))

  // Overwrite mode is expected to be rejected with the message checked below.
  test("Test Write Overwrite") {
    if (!supportBatchWrite) {

    val data: RDD[Row] = sc.makeRDD(List(row3, row4))
    val df = sqlContext.createDataFrame(data, schema)

    val caught = intercept[TiBatchWriteException] {
        .option("database", database)
        .option("table", table)

        .equals("SaveMode: Overwrite is not supported. TiSpark only support SaveMode.Append."))

    try {
    } finally {
Example 137
Source File: UpperCaseColumnNameSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.datasource

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * Writes one row into a table whose column names are upper case (`O_ORDERKEY`,
 * `O_CUSTKEY`).
 *
 * NOTE(review): "uppser" in the table name below looks like a typo for "upper". It is a
 * runtime identifier, so it is left untouched here.
 * NOTE(review): this listing has lost lines. The SQL text has no enclosing call, the
 * write chain starts at `.option(...)`, and `afterAll` shows only `} finally {`. Confirm
 * against the upstream file.
 */
class UpperCaseColumnNameSuite
    extends BaseDataSourceTest("test_datasource_uppser_case_column_name") {

  private val row1 = Row(1, 2)

  private val schema = StructType(
    List(StructField("O_ORDERKEY", IntegerType), StructField("O_CUSTKEY", IntegerType)))

  override def beforeAll(): Unit = {

                  |CREATE TABLE $dbtable (O_ORDERKEY INTEGER NOT NULL,
                  |                       O_CUSTKEY INTEGER NOT NULL);

  test("Test insert upper case column name") {
    if (!supportBatchWrite) {

    val data: RDD[Row] = sc.makeRDD(List(row1))
    val df = sqlContext.createDataFrame(data, schema)
      .option("database", database)
      .option("table", table)

  override def afterAll(): Unit =
    } finally {
Example 138
Source File: CheckUnsupportedSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.datasource

import com.pingcap.tikv.exception.TiBatchWriteException
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

/**
 * Checks that writes to unsupported table shapes (a range-partitioned table, and tables
 * with a virtual or stored generated column) are rejected with `TiBatchWriteException`.
 *
 * NOTE(review): this listing has lost lines. `beforeAll` / `afterAll` have no visible
 * body, the `if (!supportBatchWrite) {` guards are empty, some `StructType(` calls lack
 * their `List(` opener, and only the last test still shows the expected message. Confirm
 * against the upstream file.
 */
class CheckUnsupportedSuite extends BaseDataSourceTest("test_datasource_check_unsupported") {

  override def beforeAll(): Unit =

  // Enables table partitioning for the session, creates a range-partitioned table, and
  // expects the batch write of row2/row3 to throw.
  test("Test write to partition table") {
    if (!supportBatchWrite) {


    tidbStmt.execute("set @@tidb_enable_table_partition = 1")

      s"create table $dbtable(i int, s varchar(128)) partition by range(i) (partition p0 values less than maxvalue)")
    jdbcUpdate(s"insert into $dbtable values(null, 'Hello')")

    val row1 = Row(null, "Hello")
    val row2 = Row(2, "TiDB")
    val row3 = Row(3, "Spark")

    val schema = StructType(List(StructField("i", IntegerType), StructField("s", StringType)))

      val caught = intercept[TiBatchWriteException] {
        tidbWrite(List(row2, row3), schema)
  // c3 is generated as (c1 + c2) without STORED; the written schema covers only i, c1, c2.
  test("Check Virtual Generated Column") {
    if (!supportBatchWrite) {

    jdbcUpdate(s"create table $dbtable(i INT, c1 INT, c2 INT,  c3 INT AS (c1 + c2))")

    val row1 = Row(1, 2, 3)
    val schema = StructType(
        StructField("i", IntegerType),
        StructField("c1", IntegerType),
        StructField("c2", IntegerType)))

    val caught = intercept[TiBatchWriteException] {
      tidbWrite(List(row1), schema)
  // Same as above, but c3 is declared STORED.
  test("Check Stored Generated Column") {
    if (!supportBatchWrite) {

    jdbcUpdate(s"create table $dbtable(i INT, c1 INT, c2 INT,  c3 INT AS (c1 + c2) STORED)")

    val row1 = Row(1, 2, 3)
    val schema = StructType(
        StructField("i", IntegerType),
        StructField("c1", IntegerType),
        StructField("c2", IntegerType)))
    val caught = intercept[TiBatchWriteException] {
      tidbWrite(List(row1), schema)
        .equals("tispark currently does not support write data to table with generated column!"))


  override def afterAll(): Unit =
    try {
Example 139
Source File: RegionSplitSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.datasource

import com.pingcap.tikv.TiBatchWriteUtils
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * Writes three single-column rows with `enableRegionSplit=true` and `regionSplitNum=3`,
 * then looks up the regions of the table's first index (first test) or of its records
 * (second test).
 *
 * NOTE(review): this listing has lost lines. The guards have no body, the CREATE TABLE
 * strings have no enclosing call, and no assertion on `regionsNum` is visible. Confirm
 * against the upstream file.
 */
class RegionSplitSuite extends BaseDataSourceTest("region_split_test") {
  private val row1 = Row(1)
  private val row2 = Row(2)
  private val row3 = Row(3)
  private val schema = StructType(List(StructField("a", IntegerType)))

  // Table with a unique index on `a`; regions are fetched for index 0.
  test("index region split test") {
    if (!supportBatchWrite) {

    // do not test this case on tidb which does not support split region
    if (!isEnableSplitRegion) {

      s"CREATE TABLE  $dbtable ( `a` int(11), unique index(a)) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin")

    val options = Some(Map("enableRegionSplit" -> "true", "regionSplitNum" -> "3"))

    tidbWrite(List(row1, row2, row3), schema, options)

    val tiTableInfo =
      ti.tiSession.getCatalog.getTable(dbPrefix + database, table)
    val regionsNum = TiBatchWriteUtils
      .getRegionByIndex(ti.tiSession, tiTableInfo, tiTableInfo.getIndices.get(0))
  // Table without an index; regions are fetched for the table's records.
  test("table region split test") {
    if (!supportBatchWrite) {

    // do not test this case on tidb which does not support split region
    if (!isEnableSplitRegion) {

      s"CREATE TABLE  $dbtable ( `a` int(11) DEFAULT NULL) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin")

    val options = Some(Map("enableRegionSplit" -> "true", "regionSplitNum" -> "3"))

    tidbWrite(List(row1, row2, row3), schema, options)

    val tiTableInfo =
      ti.tiSession.getCatalog.getTable(dbPrefix + database, table)
    val regionsNum = TiBatchWriteUtils.getRecordRegions(ti.tiSession, tiTableInfo).size()
Example 140
Source File: MissingParameterSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.datasource

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

/**
 * Checks that a write which sets only the `table` option fails with an
 * `IllegalArgumentException` saying the `database` option is required.
 *
 * NOTE(review): this listing has lost lines. The guard has no body, the write chain
 * starts at `.option("table", ...)`, the assertion shows only its `.equals(...)` tail,
 * and `afterAll` is cut off. Confirm against the upstream file.
 */
class MissingParameterSuite extends BaseDataSourceTest("test_datasource_missing_parameter") {
  private val row1 = Row(null, "Hello")

  private val schema = StructType(
    List(StructField("i", IntegerType), StructField("s", StringType)))

  test("Missing parameter: database") {
    if (!supportBatchWrite) {

    jdbcUpdate(s"create table $dbtable(i int, s varchar(128))")

    val caught = intercept[IllegalArgumentException] {
      val rows = row1 :: Nil
      val data: RDD[Row] = sc.makeRDD(rows)
      val df = sqlContext.createDataFrame(data, schema)
        .option("table", table)
        .equals("requirement failed: Option 'database' is required."))

  override def afterAll(): Unit =
    try {
Example 141
Source File: ShardRowIDBitsSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.datasource

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * Creates a table with `SHARD_ROW_ID_BITS = 4`, inserts one null row through JDBC, then
 * writes three more rows through the data source.
 *
 * NOTE(review): the guard has no body and no read-back or assertion is visible after the
 * write; the listing appears truncated. Confirm against the upstream file.
 */
class ShardRowIDBitsSuite extends BaseDataSourceTest("test_shard_row_id_bits") {
  private val row1 = Row(1)
  private val row2 = Row(2)
  private val row3 = Row(3)
  private val schema = StructType(List(StructField("a", IntegerType)))

  test("reading and writing a table with shard_row_id_bits") {
    if (!supportBatchWrite) {

    jdbcUpdate(s"CREATE TABLE  $dbtable ( `a` int(11))  SHARD_ROW_ID_BITS = 4")

    jdbcUpdate(s"insert into $dbtable values(null)")

    tidbWrite(List(row1, row2, row3), schema)
Example 142
Source File: OnlyOnePkSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.datasource

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * Append test against a table whose only column is its primary key.
 *
 * NOTE(review): this listing has lost lines. The guard has no body, the write chain
 * starts at `.option(...)`, and `afterAll` is cut off. Confirm against the upstream file.
 */
class OnlyOnePkSuite extends BaseDataSourceTest("test_datasource_only_one_pk") {
  private val row3 = Row(3)
  private val row4 = Row(4)

  private val schema = StructType(List(StructField("i", IntegerType)))

  override def beforeAll(): Unit = {

    jdbcUpdate(s"create table $dbtable(i int primary key)")

  // The table starts empty, so after the append it should hold exactly row3 and row4.
  test("Test Write Append") {
    if (!supportBatchWrite) {

    val data: RDD[Row] = sc.makeRDD(List(row3, row4))
    val df = sqlContext.createDataFrame(data, schema)

      .option("database", database)
      .option("table", table)

    testTiDBSelect(Seq(row3, row4))

  override def afterAll(): Unit =
    try {
Example 143
Source File: BatchWriteIssueSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark

import com.pingcap.tispark.datasource.BaseDataSourceTest
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

/**
 * Runs one null-value scenario (`doTestNullValues`) against three table definitions: a
 * composite unique index, a composite primary key, and a single-column primary key.
 *
 * NOTE(review): this listing has lost lines. `beforeAll` / `afterAll` have no body, the
 * guard in `doTestNullValues` is empty, the `StructType(` call lacks its `List(` opener,
 * the multi-line INSERT text has no enclosing call, and several assertions show only
 * their argument. Confirm against the upstream file.
 */
class BatchWriteIssueSuite extends BaseDataSourceTest("test_batchwrite_issue") {
  override def beforeAll(): Unit = {

  test("Combine unique index with null value test") {
    doTestNullValues(s"create table $dbtable(a int, b varchar(64), CONSTRAINT ab UNIQUE (a, b))")

  test("Combine primary key with null value test") {
    doTestNullValues(s"create table $dbtable(a int, b varchar(64), PRIMARY KEY (a, b))")

  test("PK is handler with null value test") {
    doTestNullValues(s"create table $dbtable(a int, b varchar(64), PRIMARY KEY (a))")

  override def afterAll(): Unit =
    try {
  // Shared scenario. After the table is created, a column `c` with default 'c33' is added
  // (and a temporary column is added and dropped). Four rows are inserted, with `c` null
  // for a=11 and a=21, defaulted for a=31 and explicit for a=41. Rows a=11 and a=31 are
  // then rewritten twice with the "replace" option, and column `c` is re-checked each time.
  private def doTestNullValues(createTableSQL: String): Unit = {
    if (!supportBatchWrite) {
    val schema = StructType(
        StructField("a", IntegerType),
        StructField("b", StringType),
        StructField("c", StringType)))

    val options = Some(Map("replace" -> "true"))

    jdbcUpdate(s"alter table $dbtable add column to_delete int")
    jdbcUpdate(s"alter table $dbtable add column c varchar(64) default 'c33'")
    jdbcUpdate(s"alter table $dbtable drop column to_delete")
                  |insert into $dbtable values(11, 'c12', null);
                  |insert into $dbtable values(21, 'c22', null);
                  |insert into $dbtable (a, b) values(31, 'c32');
                  |insert into $dbtable values(41, 'c42', 'c43');

    assert(queryTiDBViaJDBC(s"select c from $dbtable where a=11").head.head == null)
    assert(queryTiDBViaJDBC(s"select c from $dbtable where a=21").head.head == null)
      queryTiDBViaJDBC(s"select c from $dbtable where a=31").head.head.toString.equals("c33"))
      queryTiDBViaJDBC(s"select c from $dbtable where a=41").head.head.toString.equals("c43"))

      val row1 = Row(11, "c12", "c13")
      val row3 = Row(31, "c32", null)

      tidbWrite(List(row1, row3), schema, options)

        queryTiDBViaJDBC(s"select c from $dbtable where a=11").head.head.toString.equals("c13"))
      assert(queryTiDBViaJDBC(s"select c from $dbtable where a=21").head.head == null)
      assert(queryTiDBViaJDBC(s"select c from $dbtable where a=31").head.head == null)
        queryTiDBViaJDBC(s"select c from $dbtable where a=41").head.head.toString.equals("c43"))

      val row1 = Row(11, "c12", "c213")
      val row3 = Row(31, "c32", "tt")
      tidbWrite(List(row1, row3), schema, options)
        queryTiDBViaJDBC(s"select c from $dbtable where a=11").head.head.toString.equals("c213"))
      assert(queryTiDBViaJDBC(s"select c from $dbtable where a=21").head.head == null)
        queryTiDBViaJDBC(s"select c from $dbtable where a=31").head.head.toString.equals("tt"))
Example 144
Source File: LockTimeoutSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.ttl

import com.pingcap.tikv.TTLManager
import com.pingcap.tikv.exception.GrpcException
import com.pingcap.tispark.datasource.BaseDataSourceTest
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

/**
 * Expects a write that is delayed (option `sleepAfterPrewritePrimaryKey`) to fail with a
 * `GrpcException` whose message is "retry is exhausted." and whose cause starts with
 * "Txn commit primary key failed".
 *
 * NOTE(review): this listing has lost lines. The guard has no body, the background
 * thread is never visibly started, the write chain starts at `.option(...)`, the last
 * assertion shows only its message argument, and `afterAll` is cut off. Confirm against
 * the upstream file.
 */
class LockTimeoutSuite extends BaseDataSourceTest("test_lock_timeout") {
  private val row1 = Row(1, "Hello")

  private val schema = StructType(
    List(StructField("i", IntegerType), StructField("s", StringType)))

  override def beforeAll(): Unit = {
    jdbcUpdate(s"create table $dbtable(i int, s varchar(128))")

  test("Test Lock TTL Timeout") {
    if (!supportTTLUpdate) {

    // NOTE(review): `seconds` presumably means "milliseconds per second", which would make
    // the sleeps TTL + 10 s and TTL + 15 s; confirm the unit of MANAGED_LOCK_TTL.
    // `sleep1` is not used in the visible code.
    val seconds = 1000
    val sleep1 = TTLManager.MANAGED_LOCK_TTL + 10 * seconds
    val sleep2 = TTLManager.MANAGED_LOCK_TTL + 15 * seconds

    val data: RDD[Row] = sc.makeRDD(List(row1))
    val df = sqlContext.createDataFrame(data, schema)

    // Concurrent reader of the same table, run on its own thread.
    new Thread(new Runnable {
      override def run(): Unit = {
        queryTiDBViaJDBC(s"select * from $dbtable")

    val grpcException = intercept[GrpcException] {
        .option("database", database)
        .option("table", table)
        .option("sleepAfterPrewritePrimaryKey", sleep2)

    assert(grpcException.getMessage.equals("retry is exhausted."))
    assert(grpcException.getCause.getMessage.startsWith("Txn commit primary key failed"))
        "Key exception occurred and the reason is retryable: \"Txn(Mvcc(TxnLockNotFound"))

  override def afterAll(): Unit =
    try {
Example 145
Source File: Preprocess.scala    From Scala-Machine-Learning-Projects   with MIT License 5 votes vote down vote up
package com.packt.ScalaML.BitCoin

import java.io.{ BufferedWriter, File, FileWriter }
import org.apache.spark.sql.types.{ DoubleType, IntegerType, StructField, StructType }
import org.apache.spark.sql.{ DataFrame, Row, SparkSession }
import scala.collection.mutable.ListBuffer

object Preprocess {
  //how many of first rows are omitted
    val dropFirstCount: Int = 612000

    /**
     * Writes sliding-window training data derived from `data` to two text files.
     *
     * All rows are collected to the driver in their existing order and the first
     * `dropFirstCount` rows are discarded. Then, for every start index `i` such that a
     * full window plus one following row exists, one line is appended to each file:
     *  - `xFilename`: the `window` consecutive "Delta" values starting at `i`, comma-separated;
     *  - `yFilename`: the "label" value of the row immediately after that window.
     *
     * Both writers are closed in `finally` blocks, so buffered output is flushed and the
     * file handles are released even when reading a row throws. Previously the writers
     * were never closed.
     *
     * @param data      frame read through a Double "Delta" column and an Integer "label" column
     * @param window    number of consecutive rows that make up one feature line
     * @param xFilename path of the feature file to write
     * @param yFilename path of the label file to write
     */
    def rollingWindow(data: DataFrame, window: Int, xFilename: String, yFilename: String): Unit = {
      // The index produced by the former zipWithIndex() was never read, so the rows are
      // collected directly; collect() returns them in the same order.
      // TODO: confirm how many leading rows should be skipped (an earlier note said ~614K).
      val rows = data.rdd.collect().drop(dropFirstCount)

      val xWriter = new BufferedWriter(new FileWriter(new File(xFilename)))
      try {
        val yWriter = new BufferedWriter(new FileWriter(new File(yFilename)))
        try {
          // `until` yields an empty range when there are not enough rows for one window.
          for (start <- 0 until (rows.length - window)) {
            val features = rows.slice(start, start + window).map(_.getAs[Double]("Delta"))
            val label = rows(start + window).getAs[Integer]("label")
            xWriter.write(features.mkString(",") + "\n")
            yWriter.write(s"$label\n")
          }
        } finally {
          yWriter.close()
        }
      } finally {
        xWriter.close()
      }
    }
  /**
   * Entry point: loads the price CSV, adds a "Delta" column (Close - Open) and a binary
   * "label" column (1 when Close - Open > 0, otherwise 0), then writes rolling windows of
   * 22 rows through `rollingWindow`.
   *
   * `args` is not read; input and output locations are the hard-coded values below.
   *
   * NOTE(review): two statements in this listing are visibly incomplete (flagged inline)
   * and the closing delimiters are missing; confirm against the upstream file.
   */
  def main(args: Array[String]): Unit = {
    //todo modify these variables to match desirable files
    val priceDataFileName: String = "C:/Users/admin-karim/Desktop/bitstampUSD_1-min_data_2012-01-01_to_2017-10-20.csv/bitstampUSD_1-min_data_2012-01-01_to_2017-10-20.csv"
    val outputDataFilePath: String = "output/scala_test_x.csv"
    val outputLabelFilePath: String = "output/scala_test_y.csv"

    // NOTE(review): the chain below has no builder call before `.config(...)` and no
    // terminal call after `.appName(...)`; presumably `.builder()` / `.getOrCreate()` (and
    // possibly a master setting) were dropped. Confirm.
    val spark = SparkSession
      .config("spark.sql.warehouse.dir", "E:/Exp/")
      .appName("Bitcoin Preprocessing")

    // NOTE(review): the right-hand side starts in the middle of a call. Presumably it was
    // a reader-format call taking "com.databricks.spark.csv" as its argument; confirm.
    val data ="com.databricks.spark.csv").option("header", "true").load(priceDataFileName)
    // Prints (row count, column count) of the loaded frame.
    println((data.count(), data.columns.size))

    val dataWithDelta = data.withColumn("Delta", data("Close") - data("Open"))

    import org.apache.spark.sql.functions._
    import spark.sqlContext.implicits._

    val dataWithLabels = dataWithDelta.withColumn("label", when($"Close" - $"Open" > 0, 1).otherwise(0))
    rollingWindow(dataWithLabels, 22, outputDataFilePath, outputLabelFilePath)    
Example 146
Source File: OptimizeHiveMetadataOnlyQuerySuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.scalatest.BeforeAndAfter

import org.apache.spark.metrics.source.HiveCatalogMetrics
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_METADATA_ONLY
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * Checks how many partitions are fetched from the metastore (via
 * `HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED`) when the metadata-only optimizer is
 * enabled, using a table with partitions `part = 0..10`.
 *
 * NOTE(review): closing delimiters are missing from this listing, and the `Project(...)`
 * expression in the second test lacks its child plan argument; confirm against the
 * upstream file.
 */
class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleton
    with BeforeAndAfter with SQLTestUtils {

  import spark.implicits._

  // Creates the partitioned table and registers the eleven partitions 0 to 10.
  override def beforeAll(): Unit = {
    sql("CREATE TABLE metadata_only (id bigint, data string) PARTITIONED BY (part int)")
    (0 to 10).foreach(p => sql(s"ALTER TABLE metadata_only ADD PARTITION (part=$p)"))

  override protected def afterAll(): Unit = {
    try {
      sql("DROP TABLE IF EXISTS metadata_only")
  // A filter directly on the partition column: five partitions match (0 to 4) and the
  // fetch counter is expected to grow by exactly five.
  test("SPARK-23877: validate metadata-only query pushes filters to metastore") {
    withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") {
      val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount

      // verify the number of matching partitions
      assert(sql("SELECT DISTINCT part FROM metadata_only WHERE part < 5").collect().length === 5)

      // verify that the partition predicate was pushed down to the metastore
      assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount === 5)

  // A filter on the projected expression `part + 1 AS x`: the answer is x in 1..4, and the
  // fetch counter is expected to grow by all eleven partitions.
  test("SPARK-23877: filter on projected expression") {
    withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") {
      val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount

      // verify the matching partitions
      val partitions = spark.internalCreateDataFrame(Distinct(Filter(($"x" < 5).expr,
        Project(Seq(($"part" + 1).as("x").expr.asInstanceOf[NamedExpression]),
          .queryExecution.toRdd, StructType(Seq(StructField("x", IntegerType))))

      checkAnswer(partitions, Seq(1, 2, 3, 4).toDF("x"))

      // verify that the partition predicate was not pushed down to the metastore
      assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount == 11)
Example 147
Source File: SparkExecuteStatementOperationSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.thriftserver

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, NullType, StringType, StructField, StructType}

/**
 * Checks the column descriptors that `SparkExecuteStatementOperation.getTableSchema`
 * produces for a given Spark `StructType`.
 *
 * NOTE(review): closing braces are missing from this listing.
 */
class SparkExecuteStatementOperationSuite extends SparkFunSuite {
  // Two NullType fields must both map to Hive's NULL_TYPE.
  test("SPARK-17112 `select null` via JDBC triggers IllegalArgumentException in ThriftServer") {
    val field1 = StructField("NULL", NullType)
    val field2 = StructField("(IF(true, NULL, NULL))", NullType)
    val tableSchema = StructType(Seq(field1, field2))
    val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors()
    assert(columns.size() == 2)
    assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.NULL_TYPE)
    assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.NULL_TYPE)

  // The comment set on the first field must survive conversion; types map to STRING_TYPE
  // and INT_TYPE respectively.
  test("SPARK-20146 Comment should be preserved") {
    val field1 = StructField("column1", StringType).withComment("comment 1")
    val field2 = StructField("column2", IntegerType)
    val tableSchema = StructType(Seq(field1, field2))
    val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors()
    assert(columns.size() == 2)
    assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.STRING_TYPE)
    assert(columns.get(0).getComment() == "comment 1")
    assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.INT_TYPE)
Example 148
Source File: SubstituteUnresolvedOrdinals.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType

/**
 * Logical-plan rule that looks for integer literals used as sort keys or grouping expressions.
 * In the `Sort` case each such literal is replaced by `UnresolvedOrdinal(index)`, keeping the
 * origin of both the literal and the enclosing `SortOrder`.
 *
 * @param conf supplies the `orderByOrdinal` and `groupByOrdinal` switches that gate each case
 *
 * NOTE(review): this excerpt is missing closing braces, the collection that the partial
 * functions are applied to, and the right-hand side of the `Aggregate` literal case.
 */
class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] {
  // True only for a Literal whose data type is IntegerType.
  private def isIntLiteral(e: Expression) = e match {
    case Literal(_, IntegerType) => true
    case _ => false

  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    // ORDER BY: only when ordinals are enabled and at least one sort key is an int literal.
    case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
      val newOrders = {
        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) =>
          val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
          withOrigin(order.origin)(order.copy(child = newOrdinal))
        // Sort keys that are not int literals are left untouched.
        case other => other
      withOrigin(s.origin)(s.copy(order = newOrders))

    // GROUP BY: only when ordinals are enabled and at least one grouping expression is an
    // int literal.
    case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) =>
      val newGroups = {
        case ordinal @ Literal(index: Int, IntegerType) =>
        case other => other
Example 149
Source File: ResolveTableValuedFunctions.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import java.util.Locale

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

      tvf("start" -> LongType, "end" -> LongType, "step" -> LongType,
          "numPartitions" -> IntegerType) {
        case Seq(start: Long, end: Long, step: Long, numPartitions: Int) =>
          Range(start, end, step, Some(numPartitions))

  /**
   * Handles every `UnresolvedTableValuedFunction` whose arguments are all resolved: the
   * lower-cased function name is looked up in `builtinFunctions`, and when output aliases were
   * supplied the resolved function is wrapped in a `Project` carrying those aliases.
   *
   * NOTE(review): several statements and closing braces are missing from this excerpt, so the
   * bodies of some branches below cannot be seen.
   */
  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
      // The whole resolution is somewhat difficult to understand here due to too much abstractions.
      // We should probably rewrite the following at some point. Reynold was just here to improve
      // error messages and didn't have time to do a proper rewrite.
      val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
        case Some(tvf) =>

          // Error path: the message lists the function's alternatives next to the argument
          // types that were actually supplied.
          def failAnalysis(): Nothing = {
            val argTypes =", ")
              s"""error: table-valued function ${u.functionName} with alternatives:
                 |${ => s" ($x)").mkString("\n")}
                 |cannot be applied to: ($argTypes)""".stripMargin)

          // Try each (argument list, resolver) alternative after implicitly casting the
          // supplied arguments; the first alternative that yields a result is used.
          val resolved = tvf.flatMap { case (argList, resolver) =>
            argList.implicitCast(u.functionArgs) match {
              case Some(casted) =>
                try {
                } catch {
                  case e: AnalysisException =>
              case _ =>
          resolved.headOption.getOrElse {
        // No built-in function is registered under this name.
        case _ =>
          u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function")

      // If alias names assigned, add `Project` with the aliases
      if (u.outputNames.nonEmpty) {
        val outputAttrs = resolvedFunc.output
        // Checks if the number of the aliases is equal to expected one
        if (u.outputNames.size != outputAttrs.size) {
          u.failAnalysis(s"Number of given aliases does not match number of output columns. " +
            s"Function name: ${u.functionName}; number of aliases: " +
            s"${u.outputNames.size}; number of output columns: ${outputAttrs.size}.")
        val aliases = {
          case (attr, name) => Alias(attr, name)()
        Project(aliases, resolvedFunc)
Example 150
Source File: StatsEstimationTestBase.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.statsEstimation

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Shared base for statistics-estimation tests: `beforeAll` reads the current
 * `SQLConf.CBO_ENABLED` value and switches the flag on, and `afterAll` writes the saved value
 * back. Also provides small helpers for column sizes and test attributes.
 *
 * NOTE(review): closing braces are not visible in this excerpt.
 */
trait StatsEstimationTestBase extends SparkFunSuite {

  // Value of CBO_ENABLED captured in beforeAll so that afterAll can restore it.
  var originalValue: Boolean = false

  override def beforeAll(): Unit = {
    // Enable stats estimation based on CBO.
    originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
    SQLConf.get.setConf(SQLConf.CBO_ENABLED, true)

  // Puts back whatever CBO_ENABLED value beforeAll saved.
  override def afterAll(): Unit = {
    SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)

  /**
   * Size used for one value of `attribute`: the column's `avgLen`, or the data type's default
   * size when `avgLen` is absent. String columns get a further 8 + 4 added.
   */
  def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match {
    // For UTF8String: base + offset + numBytes
    case StringType => colStat.avgLen.getOrElse(attribute.dataType.defaultSize.toLong) + 8 + 4
    case _ => colStat.avgLen.getOrElse(attribute.dataType.defaultSize)

  /** Builds an `IntegerType` attribute reference named `colName`. */
  def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)()

/**
 * Leaf plan whose output attributes and statistics come straight from its constructor
 * arguments.
 *
 * @param outputList     attributes returned by `output`
 * @param rowCount       row count reported in the statistics
 * @param attributeStats per-attribute column statistics reported in the statistics
 * @param size           optional size in bytes; `Int.MaxValue` is reported when it is `None`
 */
case class StatsTestPlan(
    outputList: Seq[Attribute],
    rowCount: BigInt,
    attributeStats: AttributeMap[ColumnStat],
    size: Option[BigInt] = None) extends LeafNode {
  override def output: Seq[Attribute] = outputList
  override def computeStats(): Statistics = Statistics(
    // If sizeInBytes is useless in testing, we just use a fake value
    sizeInBytes = size.getOrElse(Int.MaxValue),
    rowCount = Some(rowCount),
    attributeStats = attributeStats)
Example 151
Source File: LogicalPlanSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType

/**
 * Tests for `LogicalPlan` tree transformations (`transformUp`, `transformDown`,
 * `transformExpressions`) and for how `isStreaming` is reported by a binary node.
 *
 * NOTE(review): closing braces and some statements are not visible in this excerpt.
 */
class LogicalPlanSuite extends SparkFunSuite {
  // Counts how many Project nodes the partial function below was applied to.
  private var invocationCount = 0
  private val function: PartialFunction[LogicalPlan, LogicalPlan] = {
    case p: Project =>
      invocationCount += 1

  private val testRelation = LocalRelation()

  // One Project in the plan: the function must fire once for each traversal direction.
  test("transformUp runs on operators") {
    invocationCount = 0
    val plan = Project(Nil, testRelation)
    plan transformUp function

    assert(invocationCount === 1)

    invocationCount = 0
    plan transformDown function
    assert(invocationCount === 1)

  // Two nested Projects: the function must fire twice for each traversal direction.
  test("transformUp runs on operators recursively") {
    invocationCount = 0
    val plan = Project(Nil, Project(Nil, testRelation))
    plan transformUp function

    assert(invocationCount === 2)

    invocationCount = 0
    plan transformDown function
    assert(invocationCount === 2)

  // A binary node is expected to be streaming as soon as either child is streaming.
  test("isStreaming") {
    val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
    val incrementalRelation = LocalRelation(
      Seq(AttributeReference("a", IntegerType, nullable = true)()), isStreaming = true)

    case class TestBinaryRelation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
      override def output: Seq[Attribute] = left.output ++ right.output

    // Preconditions on the two fixtures themselves.
    require(relation.isStreaming === false)
    require(incrementalRelation.isStreaming === true)
    assert(TestBinaryRelation(relation, relation).isStreaming === false)
    assert(TestBinaryRelation(incrementalRelation, relation).isStreaming === true)
    assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true)
    assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming)

  // The project list is a Stream; the transform bumps every int literal other than 1, so
  // the expected plan carries Literal(3) in place of Literal(2).
  test("transformExpressions works with a Stream") {
    val id1 = NamedExpression.newExprId
    val id2 = NamedExpression.newExprId
    val plan = Project(Stream(
      Alias(Literal(1), "a")(exprId = id1),
      Alias(Literal(2), "b")(exprId = id2)),
    val result = plan.transformExpressions {
      case Literal(v: Int, IntegerType) if v != 1 =>
        Literal(v + 1, IntegerType)
    val expected = Project(Stream(
      Alias(Literal(1), "a")(exprId = id1),
      Alias(Literal(3), "b")(exprId = id2)),
Example 152
Source File: QueryPlanSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.types.IntegerType

/**
 * Checks that expressions produced by `mapExpressions` keep the origin that was current when
 * the query was built.
 *
 * NOTE(review): closing braces are not visible in this excerpt.
 */
class QueryPlanSuite extends SparkFunSuite {

  test("origin remains the same after mapExpressions (SPARK-23823)") {
    // Fix the current origin at position (0, 0) before building the query.
    CurrentOrigin.setPosition(0, 0)
    val column = AttributeReference("column", IntegerType)(NamedExpression.newExprId)
    val query = plans.DslLogicalPlan(plans.table("table")).select(column)

    // Replace every expression with a fresh literal.
    val mappedQuery = query mapExpressions {
      case _: Expression => Literal(1)

    // The replacement expression must still report origin (0, 0).
    val mappedOrigin = mappedQuery.expressions.apply(0).origin
    assert(mappedOrigin == Origin.apply(Some(0), Some(0)))

Example 153
Source File: ScalaUDFSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import java.util.Locale

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Tests for the `ScalaUDF` expression: plain evaluation, the error message reported when the
 * wrapped function throws, and code generation against a fresh `CodegenContext`.
 *
 * NOTE(review): closing braces and some arguments are not visible in this excerpt.
 */
class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

  // An Int => Int and a String => String function evaluated over literal inputs.
  test("basic") {
    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil, true :: Nil)
    checkEvaluation(intUdf, 2)

    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil)
    checkEvaluation(stringUdf, "ax")

  // The function dereferences its argument and the input literal is null, so evaluation is
  // expected to fail with a SparkException carrying the user-defined-function message.
  test("better error message for NPE") {
    val udf = ScalaUDF(
      (s: String) => s.toLowerCase(Locale.ROOT),
      Literal.create(null, StringType) :: Nil,
      true :: Nil)

    val e1 = intercept[SparkException](udf.eval())
    assert(e1.getMessage.contains("Failed to execute user defined function"))

    val e2 = intercept[SparkException] {
      checkEvaluationWithUnsafeProjection(udf, null)
    assert(e2.getMessage.contains("Failed to execute user defined function"))

  // Generates code against a fresh context.
  // NOTE(review): the assertion that follows genCode is not visible in this excerpt.
  test("SPARK-22695: ScalaUDF should not use global variables") {
    val ctx = new CodegenContext
    ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil).genCode(ctx)
Example 154
Source File: ExpressionEvalHelperSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, IntegerType}

/**
 * Non-nullable leaf expression of type `IntegerType` whose interpreted `eval` always
 * returns 10. `doGenCode` returns a copy of `ev` with replaced code; the visible fragment of
 * that code declares `some_variable = 11` and declares `${ev.value}` as 10.
 *
 * NOTE(review): the string delimiters around the generated code and the closing braces are
 * missing from this excerpt. Presumably the generated code is deliberately inconsistent with
 * `eval` (hence "Bad") -- confirm against the suite that uses it.
 */
case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
        |int some_variable = 11;
        |int ${ev.value} = 10;
  override def dataType: DataType = IntegerType
Example 155
Source File: CanonicalizeSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.Range
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * Tests around expression canonicalization: semantic hashes of `In` expressions, and
 * `GetStructField` fixtures that differ only in the names they carry.
 *
 * NOTE(review): closing braces and some statements (including the assertions of the second
 * test) are not visible in this excerpt.
 */
class CanonicalizeSuite extends SparkFunSuite {

  test("SPARK-24276: IN expression with different order are semantically equal") {
    val range = Range(1, 1, 1, 1)
    val idAttr = range.output.head

    // in1 and in2 list the same literals in a different order; in3 has an extra literal.
    val in1 = In(idAttr, Seq(Literal(1), Literal(2)))
    val in2 = In(idAttr, Seq(Literal(2), Literal(1)))
    val in3 = In(idAttr, Seq(Literal(1), Literal(2), Literal(3)))

    assert(in1.canonicalized.semanticHash() == in2.canonicalized.semanticHash())
    assert(in1.canonicalized.semanticHash() != in3.canonicalized.semanticHash())


    // Same check with array-valued list entries: arrays1 and arrays2 hold the same two arrays
    // in swapped order, arrays3 holds a different second array.
    val arrays1 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))),
      CreateArray(Seq(Literal(2), Literal(1)))))
    val arrays2 = In(idAttr, Seq(CreateArray(Seq(Literal(2), Literal(1))),
      CreateArray(Seq(Literal(1), Literal(2)))))
    val arrays3 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))),
      CreateArray(Seq(Literal(3), Literal(1)))))

    assert(arrays1.canonicalized.semanticHash() == arrays2.canonicalized.semanticHash())
    assert(arrays1.canonicalized.semanticHash() != arrays3.canonicalized.semanticHash())


  test("SPARK-26402: accessing nested fields with different cases in case insensitive mode") {
    val expId = NamedExpression.newExprId
    val qualifier = Seq.empty[String]
    // Struct with a single field "a" that itself holds a single int field "b".
    val structType = StructType(
      StructField("a", StructType(StructField("b", IntegerType, false) :: Nil), false) :: Nil)

    // GetStructField with different names are semantically equal
    val fieldA1 = GetStructField(
      AttributeReference("data1", structType, false)(expId, qualifier),
      0, Some("a1"))
    val fieldA2 = GetStructField(
      AttributeReference("data2", structType, false)(expId, qualifier),
      0, Some("a2"))

    // Nested access fixtures.
    // NOTE(review): the inner GetStructField( opener appears to be missing from these lines.
    val fieldB1 = GetStructField(
        AttributeReference("data1", structType, false)(expId, qualifier),
        0, Some("a1")),
      0, Some("b1"))
    val fieldB2 = GetStructField(
        AttributeReference("data2", structType, false)(expId, qualifier),
        0, Some("a2")),
      0, Some("b2"))
Example 156
Source File: CallMethodViaReflectionSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Tests for `CallMethodViaReflection`: method lookup through `findMethod`, the failures
 * reported by `checkInputDataTypes`, and evaluation of calls with supported argument types.
 *
 * NOTE(review): closing braces and some statements (e.g. most of `createExpr`) are not visible
 * in this excerpt.
 */
class CallMethodViaReflectionSuite extends SparkFunSuite with ExpressionEvalHelper {

  import CallMethodViaReflection._

  // Get rid of the $ so we are getting the companion object's name.
  private val staticClassName = ReflectStaticClass.getClass.getName.stripSuffix("$")
  private val dynamicClassName = classOf[ReflectDynamicClass].getName

  // Lookups by class name, method name and argument types against the static test class.
  test("findMethod via reflection for static methods") {
    assert(findMethod(staticClassName, "method1", Seq.empty).exists(_.getName == "method1"))
    assert(findMethod(staticClassName, "method2", Seq(IntegerType)).isDefined)
    assert(findMethod(staticClassName, "method3", Seq(IntegerType)).isDefined)
    assert(findMethod(staticClassName, "method4", Seq(IntegerType, StringType)).isDefined)

  test("findMethod for a JDK library") {
    assert(findMethod(classOf[java.util.UUID].getName, "randomUUID", Seq.empty).isDefined)

  // Unknown class name: the type check must fail with a message mentioning the class.
  test("class not found") {
    val ret = createExpr("some-random-class", "method").checkInputDataTypes()
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("not found") && errorMsg.contains("class"))

  test("method not found because name does not match") {
    val ret = createExpr(staticClassName, "notfoundmethod").checkInputDataTypes()
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("cannot find a static method"))

  // Uses the dynamic test class, for which "method1" is expected not to resolve as static.
  test("method not found because there is no static method") {
    val ret = createExpr(dynamicClassName, "method1").checkInputDataTypes()
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("cannot find a static method"))

  // NOTE(review): the start of the first assertion is missing from this excerpt.
  test("input type checking") {
      Seq(Literal(staticClassName), Literal(1))).checkInputDataTypes().isFailure)
    assert(createExpr(staticClassName, "method1").checkInputDataTypes().isSuccess)

  // A Timestamp argument must be rejected with the supported-types message.
  test("unsupported type checking") {
    val ret = createExpr(staticClassName, "method1", new Timestamp(1)).checkInputDataTypes()
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("arguments from the third require boolean, byte, short"))

  test("invoking methods using acceptable types") {
    checkEvaluation(createExpr(staticClassName, "method1"), "m1")
    checkEvaluation(createExpr(staticClassName, "method2", 2), "m2")
    checkEvaluation(createExpr(staticClassName, "method3", 3), "m3")
    checkEvaluation(createExpr(staticClassName, "method4", 4, "four"), "m4four")

  // Builds the expression under test; the class name and method name are passed as the
  // leading string literals.
  private def createExpr(className: String, methodName: String, args: Any*) = {
      Literal.create(className, StringType) +:
      Literal.create(methodName, StringType) +:
Example 157
Source File: RandomSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.scalatest.Matchers._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, LongType}

/**
 * Checks `Rand` and `Randn` against fixed expected values (within a 0.001 tolerance) for
 * given seeds.
 *
 * NOTE(review): closing braces and the call names in front of the two null-seed checks are
 * not visible in this excerpt.
 */
class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("random") {
    // Seed 30.
    checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- 0.001)
    checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001)

      // Seed expressions that are null literals of LongType / IntegerType.
      new Rand(Literal.create(null, LongType)), 0.8446490682263027 +- 0.001)
      new Randn(Literal.create(null, IntegerType)), 1.1164209726833079 +- 0.001)

  // Seed given as a Long literal.
  test("SPARK-9127 codegen with long seed") {
    checkDoubleEvaluation(Rand(5419823303878592871L), 0.2304755080444375 +- 0.001)
    checkDoubleEvaluation(Randn(5419823303878592871L), -1.2824262718225607 +- 0.001)
Example 158
Source File: CodeGeneratorWithInterpretedFallbackSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import java.util.concurrent.ExecutionException

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator}
import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType

/**
 * Tests for object creation under the different `SQLConf.CODEGEN_FACTORY_MODE` settings
 * (`CODEGEN_ONLY`, `NO_CODEGEN`, `FALLBACK`).
 *
 * NOTE(review): closing braces, the bodies of the `FailedCodegenProjection` factory methods
 * and the assertions inside most `withSQLConf` blocks are not visible in this excerpt.
 */
class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanTestBase {

  // Factory whose code-generated path is fed the text "invalid code".
  object FailedCodegenProjection
      extends CodeGeneratorWithInterpretedFallback[Seq[Expression], UnsafeProjection] {

    override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = {
      val invalidCode = new CodeAndComment("invalid code", Map.empty)
      // We assume this compilation throws an exception

    override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = {

  // Creates an UnsafeProjection once with CODEGEN_ONLY and once with NO_CODEGEN.
  test("UnsafeProjection with codegen factory mode") {
    val input = Seq(BoundReference(0, IntegerType, nullable = true))
    val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString
    withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) {
      val obj = UnsafeProjection.createObject(input)

    val noCodegen = CodegenObjectFactoryMode.NO_CODEGEN.toString
    withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) {
      val obj = UnsafeProjection.createObject(input)

  // Uses the failing factory with the FALLBACK mode.
  test("fallback to the interpreter mode") {
    val input = Seq(BoundReference(0, IntegerType, nullable = true))
    val fallback = CodegenObjectFactoryMode.FALLBACK.toString
    withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallback) {
      val obj = FailedCodegenProjection.createObject(input)

  // With CODEGEN_ONLY an ExecutionException is expected whose message names the compile
  // failure.
  test("codegen failures in the CODEGEN_ONLY mode") {
    val errMsg = intercept[ExecutionException] {
      val input = Seq(BoundReference(0, IntegerType, nullable = true))
      val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString
      withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) {
    assert(errMsg.contains("failed to compile: org.codehaus.commons.compiler.CompileException:"))
Example 159
Source File: ResolveLambdaVariablesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.{ArrayType, IntegerType}

/**
 * Tests for the `ResolveLambdaVariables` rule, run through a one-batch `RuleExecutor`
 * (fixed point, at most 4 iterations).
 *
 * Fixtures: `values1` is an array-of-int attribute, `values2` an array of array of array of
 * int; `lvInt`, `lvHiddenInt` and `lvArray` are the resolved lambda variables the expected
 * plans are written with, and `lv` builds the unresolved variable used in the inputs.
 * `checkExpression(e1, e2)` runs the rule on the plan built from `e1` and compares the result
 * with the plan built from `e2`.
 *
 * Cases covered: an expression without lambdas is left as is; a simple lambda and a nested
 * lambda have their variables resolved; a non-lambda function argument is wrapped in a
 * `LambdaFunction` with `hidden = true`; argument names `x` and `X` used together, and a
 * three-argument lambda passed to `ArrayTransform`, must each fail with an
 * `AnalysisException`.
 *
 * NOTE(review): closing braces and parts of some definitions (`key`, `plan`, `lv`) are not
 * visible in this excerpt.
 */
class ResolveLambdaVariablesSuite extends PlanTest {
  import org.apache.spark.sql.catalyst.dsl.expressions._
  import org.apache.spark.sql.catalyst.dsl.plans._

  object Analyzer extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Resolution", FixedPoint(4), ResolveLambdaVariables(conf)) :: Nil

  private val key = '
  private val values1 = 'values1.array(IntegerType)
  private val values2 = 'values2.array(ArrayType(ArrayType(IntegerType)))
  private val data = LocalRelation(Seq(key, values1, values2))
  private val lvInt = NamedLambdaVariable("x", IntegerType, nullable = true)
  private val lvHiddenInt = NamedLambdaVariable("col0", IntegerType, nullable = true)
  private val lvArray = NamedLambdaVariable("x", ArrayType(IntegerType), nullable = true)

  private def plan(e: Expression): LogicalPlan ="res"))

  private def checkExpression(e1: Expression, e2: Expression): Unit = {
    comparePlans(Analyzer.execute(plan(e1)), plan(e2))

  private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(

  test("resolution - no op") {
    checkExpression(key, key)

  test("resolution - simple") {
    val in = ArrayTransform(values1, LambdaFunction(lv('x) + 1, lv('x) :: Nil))
    val out = ArrayTransform(values1, LambdaFunction(lvInt + 1, lvInt :: Nil))
    checkExpression(in, out)

  test("resolution - nested") {
    val in = ArrayTransform(values2, LambdaFunction(
      ArrayTransform(lv('x), LambdaFunction(lv('x) + 1, lv('x) :: Nil)), lv('x) :: Nil))
    val out = ArrayTransform(values2, LambdaFunction(
      ArrayTransform(lvArray, LambdaFunction(lvInt + 1, lvInt :: Nil)), lvArray :: Nil))
    checkExpression(in, out)

  test("resolution - hidden") {
    val in = ArrayTransform(values1, key)
    val out = ArrayTransform(values1, LambdaFunction(key, lvHiddenInt :: Nil, hidden = true))
    checkExpression(in, out)

  test("fail - name collisions") {
    val p = plan(ArrayTransform(values1,
      LambdaFunction(lv('x) + lv('X), lv('x) :: lv('X) :: Nil)))
    val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
    assert(msg.contains("arguments should not have names that are semantically the same"))

  test("fail - lambda arguments") {
    val p = plan(ArrayTransform(values1,
      LambdaFunction(lv('x) + lv('y) + lv('z), lv('x) :: lv('y) :: lv('z) :: Nil)))
    val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
    assert(msg.contains("does not match the number of arguments expected"))
Example 160
Source File: RewriteDistinctAggregatesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Tests for `RewriteDistinctAggregates`.
 *
 * `checkRewrite` accepts a plan only when it has the shape Aggregate over Aggregate over
 * Expand and fails the test otherwise. The two "single distinct group" tests compare the
 * rule's output with its input via `comparePlans`; the "multiple distinct groups" tests build
 * inputs with more than one `countDistinct`.
 *
 * NOTE(review): closing braces, the tail of `testRelation`, and the remaining statements of
 * several tests are not visible in this excerpt.
 */
class RewriteDistinctAggregatesSuite extends PlanTest {
  // Analyzer configuration used by this suite: case-insensitive, GROUP BY ordinals disabled.
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  // Typed null literals.
  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, '

  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")

  test("single distinct group") {
    val input = testRelation
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
        countDistinct('b, 'c),
Example 161
Source File: resources.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command


import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

/**
 * Runnable command that returns jar paths obtained from `sparkContext.listJars()` as rows
 * of a single string column named "Results".
 *
 * @param jars jars to look for. When non-empty, a row is produced for each listed jar path
 *             that contains one of the derived jar names; when empty the `else` branch runs.
 *
 * NOTE(review): the `else` branch, the source of `jarName`, and the closing braces are not
 * visible in this excerpt.
 */
case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
  // Single non-nullable string column.
  override val output: Seq[Attribute] = {
    AttributeReference("Results", StringType, nullable = false)() :: Nil
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val jarList = sparkSession.sparkContext.listJars()
    if (jars.nonEmpty) {
      for {
        jarName <- => new Path(f).getName)
        jarPath <- jarList if jarPath.contains(jarName)
      } yield Row(jarPath)
    } else {
Example 162
Source File: SameResultSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.IntegerType

/**
 * Tests whose names concern `sameResult` and canonicalized plans: file scans with reordered
 * filters, partial aggregates, and plans that differ only in attribute-name case.
 *
 * NOTE(review): closing braces, the statements that write and read the test data, the body of
 * `getFileSourceScanExec`, and every assertion are missing from this excerpt.
 */
class SameResultSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  // plan1 and plan2 are built from the same four predicates written in a different order.
  test("FileSourceScanExec: different orders of data filters and partition filters") {
    withTempPath { path =>
      val tmpDir = path.getCanonicalPath
        .selectExpr("id as a", "id + 1 as b", "id + 2 as c", "id + 3 as d")
        .partitionBy("a", "b")
      val df =
      // partition filters: a > 1 AND b < 9
      // data filters: c > 1 AND d < 9
      val plan1 = getFileSourceScanExec(df.where("a > 1 AND b < 9 AND c > 1 AND d < 9"))
      val plan2 = getFileSourceScanExec(df.where("b < 9 AND a > 1 AND d < 9 AND c > 1"))

  private def getFileSourceScanExec(df: DataFrame): FileSourceScanExec = {

  // Two identical sum aggregations and two identical sumDistinct aggregations.
  test("SPARK-20725: partial aggregate should behave correctly for sameResult") {
    val df1 = spark.range(10).agg(sum($"id"))
    val df2 = spark.range(10).agg(sum($"id"))

    val df3 = spark.range(10).agg(sumDistinct($"id"))
    val df4 = spark.range(10).agg(sumDistinct($"id"))

  // The same Project-over-LocalRelation plan built with upper-case and lower-case names.
  test("Canonicalized result is case-insensitive") {
    val a = AttributeReference("A", IntegerType)()
    val b = AttributeReference("B", IntegerType)()
    val planUppercase = Project(Seq(a), LocalRelation(a, b))

    val c = AttributeReference("a", IntegerType)()
    val d = AttributeReference("b", IntegerType)()
    val planLowercase = Project(Seq(c), LocalRelation(c, d))


Example 163
Source File: GroupedIteratorSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

/**
 * Tests for `GroupedIterator` over rows encoded with a `RowEncoder`.
 *
 * - "basic": rows with an int and a string column; each key has one field, and the expected
 *   result pairs key 1 with the first two input rows and key 2 with the third.
 * - "group by 2 columns": rows with int, long and string columns; each key has two fields,
 *   and the expected result lists four groups in input order.
 * - "do nothing to the value iterator": only the number of groups (2) is checked.
 *
 * NOTE(review): closing braces, the first argument of each `GroupedIterator(` call, the
 * grouping expressions and the expression that `result` is built from are not visible in
 * this excerpt.
 */
class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(,
      Seq(', schema.toAttributes)

    val result = {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) ->

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(,
      Seq(', ', schema.toAttributes)

    val result = {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1),

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(,
      Seq(', schema.toAttributes)

    assert(grouped.length == 2)
Example 164
Source File: WholeStageCodegenSparkSubmitSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.scalatest.{Assertions, BeforeAndAfterEach, Matchers}
import org.scalatest.concurrent.TimeLimits

import org.apache.spark.{SparkFunSuite, TestUtils}
import org.apache.spark.deploy.SparkSubmitSuite
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{LocalSparkSession, QueryTest, Row, SparkSession}
import org.apache.spark.sql.functions.{array, col, count, lit}
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.unsafe.Platform
import org.apache.spark.util.ResetSystemProperties

// Due to the need to set driver's extraJavaOptions, this test needs to use actual SparkSubmit.
// Runs the companion object's `main` through SparkSubmitSuite.runSparkSubmit on a
// local-cluster master, giving driver and executor opposite UseCompressedOops settings.
class WholeStageCodegenSparkSubmitSuite extends SparkFunSuite
  with Matchers
  with BeforeAndAfterEach
  with ResetSystemProperties {

  test("Generated code on driver should not embed platform-specific constant") {
    // Jar built from an empty class list.
    // NOTE(review): `unusedJar` is never referenced again in this listing; the argument list
    // below ends on a dangling comma, so its use was presumably truncated.
    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)

    // HotSpot JVM specific: Set up a local cluster with the driver/executor using mismatched
    // settings of UseCompressedOops JVM option.
    val argsForSparkSubmit = Seq(
      // Strip the trailing "$" so the class name refers to the object's static forwarder.
      "--class", WholeStageCodegenSparkSubmitSuite.getClass.getName.stripSuffix("$"),
      "--master", "local-cluster[1,1,1024]",
      "--driver-memory", "1g",
      "--conf", "spark.ui.enabled=false",
      // NOTE(review): empty --conf value — the original setting appears lost; confirm upstream.
      "--conf", "",
      "--conf", "spark.driver.extraJavaOptions=-XX:-UseCompressedOops",
      "--conf", "spark.executor.extraJavaOptions=-XX:+UseCompressedOops",
    SparkSubmitSuite.runSparkSubmit(argsForSparkSubmit, "../..")

// Application entry point: checks that driver and executor report different byte-array
// header offsets, then compares collected rows against a fixed expected answer.
object WholeStageCodegenSparkSubmitSuite extends Assertions with Logging {

  // Assigned in `main`; uninitialised before that.
  var spark: SparkSession = _

  def main(args: Array[String]): Unit = {

    spark = SparkSession.builder().getOrCreate()

    // Make sure the test is run where the driver and the executors uses different object layouts
    val driverArrayHeaderSize = Platform.BYTE_ARRAY_OFFSET
    // Evaluated inside a one-element RDD so the value comes from an executor JVM.
    val executorArrayHeaderSize =
      spark.sparkContext.range(0, 1).map(_ => Platform.BYTE_ARRAY_OFFSET).collect.head.toInt
    assert(driverArrayHeaderSize > executorArrayHeaderSize)

    // NOTE(review): the visible `df` only projects id % 10 as "v"; the expected rows below are
    // (array, count) pairs and `array`/`count` are imported, so a groupBy/agg step was
    // presumably truncated from this listing.
    val df = spark.range(71773).select((col("id") % lit(10)).cast(IntegerType) as "v")
    // NOTE(review): `plan` is not used in the visible code — confirm whether an assertion on
    // the executed plan was lost.
    val plan = df.queryExecution.executedPlan

    // 71773 = 7177 * 10 + 3, so residues 0..2 occur 7178 times and residues 3..9 occur 7177 times.
    val expectedAnswer =
      Row(Array(0), 7178) ::
        Row(Array(1), 7178) ::
        Row(Array(2), 7178) ::
        Row(Array(3), 7177) ::
        Row(Array(4), 7177) ::
        Row(Array(5), 7177) ::
        Row(Array(6), 7177) ::
        Row(Array(7), 7177) ::
        Row(Array(8), 7177) ::
        Row(Array(9), 7177) :: Nil
    val result = df.collect
    // sameRows yields Some(message) on mismatch; any other result is treated as success.
    QueryTest.sameRows(result.toSeq, expectedAnswer) match {
      case Some(errMsg) => fail(errMsg)
      case _ =>
Example 165
Source File: BlockingSource.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.streaming.util

import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.{SQLContext, _}
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Provider that hands out a stub streaming Source and Sink, both described by a fixed
// single-column schema.
// NOTE(review): several method bodies are truncated in this listing (missing closers and the
// body of getBatch); nothing visible here touches BlockingSource.latch.
class BlockingSource extends StreamSourceProvider with StreamSinkProvider {

  // Schema reported for every source: one IntegerType column named "a".
  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  // Always reports the name "dummySource" with `fakeSchema`; the arguments are not consulted.
  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", fakeSchema)

  // Returns an anonymous Source with schema `fakeSchema` and a constant offset of LongOffset(0).
  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    new Source {
      override def schema: StructType = fakeSchema
      override def getOffset: Option[Offset] = Some(new LongOffset(0))
      // NOTE(review): only the implicits import survives here; the batch construction is missing.
      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
        import spark.implicits._
      override def stop() {}

  // Returns a Sink whose addBatch has an empty body as shown.
  override def createSink(
      spark: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink {
      override def addBatch(batchId: Long, data: DataFrame): Unit = {}

// Companion object exposing a shared, mutable latch reference.
// NOTE(review): `latch` starts as null and this listing does not show who assigns or awaits
// it — confirm the intended usage at the call sites.
object BlockingSource {
  var latch: CountDownLatch = null
Example 166
Source File: MockSourceProvider.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.streaming.util

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Source
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// StreamSourceProvider used with the MockSourceProvider companion object.
// NOTE(review): the body of createSource and all closing braces are missing in this listing.
class MockSourceProvider extends StreamSourceProvider {
  // Always reports the name "dummySource" with the companion's `fakeSchema`.
  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", MockSourceProvider.fakeSchema)

  // NOTE(review): no body is visible; presumably delegates to the companion's provider
  // function — confirm upstream.
  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {

// Companion holding the source-producing function and the schema reported by the provider.
object MockSourceProvider {
  // Function to generate sources. May provide multiple sources if the user implements such a
  // function.
  private var sourceProviderFunction: () => Source = _

  // Single IntegerType column named "a".
  final val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  // Installs a provider function that cycles through `source +: otherSources` by call count,
  // and clears it again in the `finally` clause.
  // NOTE(review): the lambda's result expression and the `try` body (presumably the call to
  // `f`) are missing in this listing — restore from upstream before use.
  def withMockSources(source: Source, otherSources: Source*)(f: => Unit): Unit = {
    var i = 0
    val sources = source +: otherSources
    sourceProviderFunction = () => {
      // Wrap around once every supplied source has been handed out.
      val source = sources(i % sources.length)
      i += 1
    try {
    } finally {
      sourceProviderFunction = null
Example 167
Source File: dependenciesSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.TableDependencyCalculator
import org.apache.spark.sql.sources.{RelationKind, Table}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SQLContext}

// System table provider whose rows are derived from a dependents map built over the tables
// of the SQL context's catalog.
// NOTE(review): this listing is truncated — the receiver of `.map` in kindOf, the iteration
// over `dependents`, the produced Row and all closing braces are missing.
object DependenciesSystemTableProvider extends SystemTableProvider with LocalSpark {
  override def execute(): Seq[Row] = {
    val tables = getTables(sqlContext.catalog)
    val dependentsMap = buildDependentsMap(tables)

    // Resolves the relation kind name for a table; falls back to `Table` when
    // RelationKind.kindOf yields nothing.
    def kindOf(tableIdentifier: TableIdentifier): String =
        .map(plan => RelationKind.kindOf(plan).getOrElse(Table).name)

    dependentsMap.flatMap {
      case (tableIdent, dependents) =>
        val curKind = kindOf(tableIdent) { dependent =>
          val dependentKind = kindOf(dependent)

  // Schema is taken from the DependenciesSystemTable enumeration.
  override val schema: StructType = DependenciesSystemTable.schema

// Column definitions of the dependencies system table. Only the two schema-name columns are
// nullable; object names, object types and the dependency type are declared non-null.
object DependenciesSystemTable extends SchemaEnumeration {
  val baseSchemaName = Field("BASE_SCHEMA_NAME", StringType, nullable = true)
  val baseObjectName = Field("BASE_OBJECT_NAME", StringType, nullable = false)
  val baseObjectType = Field("BASE_OBJECT_TYPE", StringType, nullable = false)
  val dependentSchemaName = Field("DEPENDENT_SCHEMA_NAME", StringType, nullable = true)
  val dependentObjectName = Field("DEPENDENT_OBJECT_NAME", StringType, nullable = false)
  val dependentObjectType = Field("DEPENDENT_OBJECT_TYPE", StringType, nullable = false)
  val dependencyType = Field("DEPENDENCY_TYPE", IntegerType, nullable = false)

  // NOTE(review): not referenced anywhere in the visible listing — confirm it is still used.
  private[DependenciesSystemTable] val UnknownType = "UNKNOWN"
Example 168
Source File: partitionFunctionSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables

import org.apache.spark.sql.execution.tablefunctions.OutputFormatter
import org.apache.spark.sql.sources._
import org.apache.spark.sql.{DatasourceResolver, Row, SQLContext}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.util.GenericUtil._

  // Maps a partition function to its upper-case type label ("RANGE", "BLOCK" or "HASH").
  // NOTE(review): there is no default case; whether PartitionFunction is sealed (and the match
  // therefore exhaustive) is not visible here. The closing brace is missing in this listing.
  private def typeNameOf(f: PartitionFunction): String = f match {
    case _: RangePartitionFunction => "RANGE"
    case _: BlockPartitionFunction => "BLOCK"
    case _: HashPartitionFunction => "HASH"

// Column definitions of the partition-function system table. Identification columns are
// non-null; the remaining columns (boundaries, block size, partition counts) are nullable.
object PartitionFunctionSystemTable extends SchemaEnumeration {
  val id = Field("ID", StringType, nullable = false)
  val functionType = Field("TYPE", StringType, nullable = false)
  val columnName = Field("COLUMN_NAME", StringType, nullable = false)
  val columnType = Field("COLUMN_TYPE", StringType, nullable = false)
  val boundaries = Field("BOUNDARIES", StringType, nullable = true)
  val block = Field("BLOCK_SIZE", IntegerType, nullable = true)
  val partitions = Field("PARTITIONS", IntegerType, nullable = true)
  val minP = Field("MIN_PARTITIONS", IntegerType, nullable = true)
  val maxP = Field("MAX_PARTITIONS", IntegerType, nullable = true)
Example 169
Source File: HierarchyJoinBuilderUnitTests.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hierarchy

import org.apache.spark.sql.types.{IntegerType, Node}
import org.apache.spark.Logging
import org.apache.spark.sql.Row

// scalastyle:off magic.number
// Unit checks for HierarchyJoinBuilder's extractNodeFromRow and getOrd helpers.
// NOTE(review): this listing is badly truncated — log calls are fused onto neighbouring lines,
// several `should equal` bodies are empty and the closing braces are missing. Restore from the
// upstream file before relying on any statement below.
class HierarchyJoinBuilderUnitTests extends NodeUnitTestSpec with Logging {
  var jb = new HierarchyJoinBuilder[Row, Row, Long](null, null, null, null, null, null)"Testing function 'extractNodeFromRow'\n")

  // A row whose last element is a Node is expected to yield Some(node).
  val x = Node(List(1,2,3), IntegerType, List(1L,1L,2L))
  Some(x) should equal {
    jb.extractNodeFromRow(Row.fromSeq(Seq(1,2,3, x)))

  None should equal {

  None should equal {
  }"Testing function 'getOrd'\n")
   None should equal {
  // Pairs of (raw ordering value, expected result): integral values map to Some(Long), while
  // the String and Double entries map to None.
  val testValues = List((42L, Some(42L)), (13, Some(13L)), ("hello", None), (1234.56, None))
    testVal => {
      val jbWithOrd = new HierarchyJoinBuilder[Row, Row, Long](null, null, null, null,
        x => testVal._1
        , null)
    testVal._2 should equal {
// scalastyle:on magic.number 
Example 170
Source File: TimestampExpressionSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.{DateType, IntegerType}
import org.scalatest.FunSuite

// Evaluation checks for the AddSeconds expression.
class TimestampExpressionSuite extends FunSuite with ExpressionEvalHelper {

  test("add_seconds") {
    // scalastyle:off magic.number
    // Positive offset that rolls over into the next minute.
    checkEvaluation(AddSeconds(Literal(Timestamp.valueOf("2015-01-01 00:11:33")), Literal(28)),
      DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-01-01 00:12:01")))
    // Negative offset crossing a day boundary.
    checkEvaluation(AddSeconds(Literal(Timestamp.valueOf("2015-01-02 00:00:00")), Literal(-1)),
      DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-01-01 23:59:59")))
    // Negative offset crossing a year boundary.
    checkEvaluation(AddSeconds(Literal(Timestamp.valueOf("2015-01-01 00:00:00")), Literal(-1)),
      DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2014-12-31 23:59:59")))
    // A null operand on either side is expected to evaluate to null.
    checkEvaluation(AddSeconds(Literal(Timestamp.valueOf("2015-01-02 00:00:00")),
      Literal.create(null, IntegerType)), null)
    checkEvaluation(AddSeconds(Literal.create(null, DateType), Literal(1)), null)
    // NOTE(review): the expected-value argument and the closers of this call are missing in
    // this listing (the line ends on a dangling comma).
    checkEvaluation(AddSeconds(Literal.create(null, DateType), Literal.create(null, IntegerType)),
Example 171
Source File: ResolveCountDistinctStarSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.scalatest.FunSuite
import org.scalatest.Inside._
import org.scalatest.mock.MockitoSugar
import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count}

import scala.collection.mutable.ArrayBuffer

// Checks that a distinct count over `*` is resolved to the relation's concrete attributes.
// NOTE(review): this listing is truncated — the construction of `projection`, the argument of
// ResolveCountDistinctStar's application, the body of the `collect` case and the closing
// braces are missing.
class ResolveCountDistinctStarSuite extends FunSuite with MockitoSugar {
  // Relation with two columns, "age" (IntegerType) and "name" (StringType), backed by a
  // mocked SQLContext.
  val persons = new LogicalRelation(new BaseRelation {
    override def sqlContext: SQLContext = mock[SQLContext]
    override def schema: StructType = StructType(Seq(
      StructField("age", IntegerType),
      StructField("name", StringType)

  test("Count distinct star is resolved correctly") {
    val projection =
      AggregateExpression(Count(UnresolvedStar(None) :: Nil), Complete, true)))
    val stillNotCompletelyResolvedAggregate = SimpleAnalyzer.execute(projection)
    val resolvedAggregate = ResolveCountDistinctStar(SimpleAnalyzer)
    // Expect a single aliased distinct Count whose arguments cover exactly "name" and "age".
    inside(resolvedAggregate) {
      case Aggregate(Nil,
      ArrayBuffer(Alias(AggregateExpression(Count(expressions), Complete, true), _)), _) =>
        assert(expressions.collect {
          case a:AttributeReference =>
        }.toSet == Set("name", "age"))
Example 172
Source File: HttpStreamServerClientTest.scala    From spark-http-stream   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.streaming.http.HttpStreamClient
import org.junit.Assert
import org.junit.Test
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.sql.types.FloatType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.ByteType
import org.apache.spark.sql.execution.streaming.http.HttpStreamServer
import org.apache.spark.sql.execution.streaming.http.StreamPrinter
import org.apache.spark.sql.execution.streaming.http.HttpStreamServerSideException

// Round-trip test for HttpStreamServer / HttpStreamClient: topics are subscribed to, rows are
// sent, and each subscriber is expected to receive only rows sent after it subscribed.
// NOTE(review): this listing is truncated — the end of ROWS2, the session build call, the
// receiver of the `.addListener` chain, the second argument of the schema assertion, the body
// of the final `try` and all closing braces are missing.
class HttpStreamServerClientTest {
	// Rows carrying one value each of String, Int, Boolean, Float, Double, Long and Byte.
	val ROWS1 = Array(Row("hello1", 1, true, 0.1f, 0.1d, 1L, '1'.toByte),
		Row("hello2", 2, false, 0.2f, 0.2d, 2L, '2'.toByte),
		Row("hello3", 3, true, 0.3f, 0.3d, 3L, '3'.toByte));

	val ROWS2 = Array(Row("hello"),

	def testHttpStreamIO() {
		//starts a http server
		val kryoSerializer = new KryoSerializer(new SparkConf());
		val server = HttpStreamServer.start("/xxxx", 8080);

		val spark = SparkSession.builder.appName("testHttpTextSink").master("local[4]")
		spark.conf.set("spark.sql.streaming.checkpointLocation", "/tmp/");

		val sqlContext = spark.sqlContext;
		import spark.implicits._
		//add a local message buffer to server, with 2 topics registered
			.addListener(new StreamPrinter())
			.createTopic[(String, Int, Boolean, Float, Double, Long, Byte)]("topic-1")

		val client = HttpStreamClient.connect("http://localhost:8080/xxxx");
		//tests schema of topics
		val schema1 = client.fetchSchema("topic-1");
		Assert.assertArrayEquals(Array[Object](StringType, IntegerType, BooleanType, FloatType, DoubleType, LongType, ByteType),[Array[Object]]);

		val schema2 = client.fetchSchema("topic-2");

		//prepare to consume messages
		val sid1 = client.subscribe("topic-1")._1;
		val sid2 = client.subscribe("topic-2")._1;

		//produces some data
		client.sendRows("topic-1", 1, ROWS1);

		// Subscribed after ROWS1 was sent, so sid4 is expected to see nothing from topic-1.
		val sid4 = client.subscribe("topic-1")._1;
		val sid5 = client.subscribe("topic-2")._1;

		client.sendRows("topic-2", 1, ROWS2);

		//consumes data
		val fetched = client.fetchStream(sid1).map(_.originalRow);
		Assert.assertArrayEquals(ROWS1.asInstanceOf[Array[Object]], fetched.asInstanceOf[Array[Object]]);
		//it is empty now
		Assert.assertArrayEquals(Array[Object](), client.fetchStream(sid1).map(_.originalRow).asInstanceOf[Array[Object]]);
		Assert.assertArrayEquals(ROWS2.asInstanceOf[Array[Object]], client.fetchStream(sid2).map(_.originalRow).asInstanceOf[Array[Object]]);
		Assert.assertArrayEquals(Array[Object](), client.fetchStream(sid4).map(_.originalRow).asInstanceOf[Array[Object]]);
		Assert.assertArrayEquals(ROWS2.asInstanceOf[Array[Object]], client.fetchStream(sid5).map(_.originalRow).asInstanceOf[Array[Object]]);
		Assert.assertArrayEquals(Array[Object](), client.fetchStream(sid5).map(_.originalRow).asInstanceOf[Array[Object]]);

		// NOTE(review): as shown, the test would pass silently if no exception is thrown, and
		// `case e: Throwable` also matches fatal errors — confirm upstream whether a failing
		// call / Assert.fail inside `try` was lost in extraction.
		try {
			//exception should be thrown, because subscriber id is invalidated
		catch {
			case e: Throwable ⇒
				Assert.assertEquals(classOf[HttpStreamServerSideException], e.getClass);

Example 173
Source File: MultinomialLogisticRegressionParitySpec.scala    From mleap   with Apache License 2.0 5 votes vote down vote up

import{Matrices, Vectors}
import{Pipeline, Transformer}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}

// Parity fixture: a six-row dataset with three integer features and a three-class label,
// transformed by a pipeline that ends in a hand-built multinomial LogisticRegressionModel.
// NOTE(review): the VectorAssembler stage is truncated in this listing (its output column and
// closing parenthesis are missing).
class MultinomialLogisticRegressionParitySpec extends SparkParityBase {

  // Column-wise sample data; index i across the four sequences forms row i.
  val labels = Seq(0.0, 1.0, 2.0, 0.0, 1.0, 2.0)
  val ages = Seq(15, 30, 40, 50, 15, 80)
  val heights = Seq(175, 190, 155, 160, 170, 180)
  val weights = Seq(67, 100, 57, 56, 56, 88)

  val rows = spark.sparkContext.parallelize(Seq.tabulate(6) { i => Row(labels(i), ages(i), heights(i), weights(i)) })
  val schema = new StructType().add("label", DoubleType, nullable = false)
    .add("age", IntegerType, nullable = false)
    .add("height", IntegerType, nullable = false)
    .add("weight", IntegerType, nullable = false)

  override val dataset: DataFrame = spark.sqlContext.createDataFrame(rows, schema)

  // The model is constructed directly from fixed coefficients (3 x 3 matrix, 3 intercepts)
  // rather than being trained.
  override val sparkTransformer: Transformer = new Pipeline().setStages(Array(
    new VectorAssembler().
      setInputCols(Array("age", "height", "weight")).
    new LogisticRegressionModel(uid = "logr", 
      coefficientMatrix = Matrices.dense(3, 3, Array(-1.3920551604166562, -0.13119545493644366, 1.5232506153530998, 0.3129112131192873, -0.21959056436528473, -0.09332064875400257, -0.24696506013528507, 0.6122879917796569, -0.36532293164437174)),
      interceptVector = Vectors.dense(0.4965574044951358, -2.1486146169780063, 1.6520572124828703),
      numClasses = 3, isMultinomial = true))).fit(dataset)
Example 174
Source File: WrappersSpec.scala    From sparksql-scalapb   with Apache License 2.0 5 votes vote down vote up
package scalapb.spark

import com.example.protos.wrappers._
import org.apache.spark.sql.SparkSession
import scalapb.GeneratedMessageCompanion
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.Row

import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.must.Matchers

// Checks how protobuf primitive wrapper fields are converted to DataFrame columns, once with
// the primitive-wrapper implicits and once with the default implicits.
// NOTE(review): this listing is heavily truncated — the SparkSession builder chain, the
// message constructors inside `data`, the schema expressions being matched and most closers
// are missing. Restore from upstream before relying on the statements below.
class WrappersSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll {
  val spark: SparkSession = SparkSession
    .appName("ScalaPB Demo")

  import spark.implicits.StringToColumn

  // Two sample messages: one with both optional wrappers set, one with both absent.
  val data = Seq(
      intValue = Option(45),
      stringValue = Option("boo"),
      ints = Seq(17, 19, 25),
      strings = Seq("foo", "bar")
      intValue = None,
      stringValue = None,
      ints = Seq(17, 19, 25),
      strings = Seq("foo", "bar")

  // With primitive-wrapper implicits the expected rows hold plain values (or null when unset).
  "converting df with primitive wrappers" should "work with primitive implicits" in {
    import ProtoSQL.withPrimitiveWrappers.implicits._
    val df = ProtoSQL.withPrimitiveWrappers.createDataFrame(spark, data) must be(
        ArrayType(IntegerType, false),
        ArrayType(StringType, false)
    df.collect must contain theSameElementsAs (
        Row(45, "boo", Seq(17, 19, 25), Seq("foo", "bar")),
        Row(null, null, Seq(17, 19, 25), Seq("foo", "bar"))

  // With default implicits the expected schema wraps each value in a single-field struct
  // named "value".
  "converting df with primitive wrappers" should "work with default implicits" in {
    import ProtoSQL.implicits._
    val df = ProtoSQL.createDataFrame(spark, data) must be(
        StructType(Seq(StructField("value", IntegerType, true))),
        StructType(Seq(StructField("value", StringType, true))),
          StructType(Seq(StructField("value", IntegerType, true))),
          StructType(Seq(StructField("value", StringType, true))),
    df.collect must contain theSameElementsAs (
          Seq(Row(17), Row(19), Row(25)),
          Seq(Row("foo"), Row("bar"))
          Seq(Row(17), Row(19), Row(25)),
          Seq(Row("foo"), Row("bar"))
Example 175
Source File: KustoSourceTests.scala    From azure-kusto-spark   with Apache License 2.0 5 votes vote down vote up

import{KustoDataSourceUtils => KDSU}
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.junit.runner.RunWith
import org.scalamock.scalatest.MockFactory
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}

// Verifies that the Kusto data source honours a user-supplied column-type string when
// reporting its schema.
// NOTE(review): this listing is truncated — the SparkSession builder chains, the read/format
// call that starts the `.option` chain, the bodies of beforeAll/afterAll tails and the final
// assertion against `expected` are missing.
class KustoSourceTests extends FlatSpec with MockFactory with Matchers with BeforeAndAfterAll {
  // Optional log level taken from the "logLevel" system property.
  private val loggingLevel: Option[String] = Option(System.getProperty("logLevel"))
  if (loggingLevel.isDefined) KDSU.setLoggingLevel(loggingLevel.get)

  private val nofExecutors = 4
  private val spark: SparkSession = SparkSession.builder()

  // Assigned in beforeAll.
  private var sc: SparkContext = _
  private var sqlContext: SQLContext = _
  // Placeholder connection settings passed to the source as options.
  private val cluster: String = "KustoCluster"
  private val database: String = "KustoDatabase"
  private val query: String = "KustoTable"
  private val appId: String = "KustoSinkTestApplication"
  private val appKey: String = "KustoSinkTestKey"
  private val appAuthorityId: String = "KustoSinkAuthorityId"

  override def beforeAll(): Unit = {

    sc = spark.sparkContext
    sqlContext = spark.sqlContext

  override def afterAll(): Unit = {


  "KustoDataSource" should "recognize Kusto and get the correct schema" in {
    val spark: SparkSession = SparkSession.builder()

    // Column-type string handed to the source via KUSTO_CUSTOM_DATAFRAME_COLUMN_TYPES.
    val customSchema = "colA STRING, colB INT"

    val df = spark.sqlContext
      .option(KustoSourceOptions.KUSTO_CLUSTER, cluster)
      .option(KustoSourceOptions.KUSTO_DATABASE, database)
      .option(KustoSourceOptions.KUSTO_QUERY, query)
      .option(KustoSourceOptions.KUSTO_AAD_APP_ID, appId)
      .option(KustoSourceOptions.KUSTO_AAD_APP_SECRET, appKey)
      .option(KustoSourceOptions.KUSTO_AAD_AUTHORITY_ID, appAuthorityId)
      .option(KustoSourceOptions.KUSTO_CUSTOM_DATAFRAME_COLUMN_TYPES, customSchema)

    // Struct equivalent of `customSchema`, with both columns nullable.
    val expected = StructType(Array(StructField("colA", StringType, nullable = true),StructField("colB", IntegerType, nullable = true)))
Example 176
Source File: ShortestPaths.scala    From graphframes   with Apache License 2.0 5 votes vote down vote up
package org.graphframes.lib

import java.util

import scala.collection.JavaConverters._

import org.apache.spark.graphx.{lib => graphxlib}
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, MapType}

import org.graphframes.GraphFrame

  // Setter-style overload taking a java.util.ArrayList and returning this.type.
  // NOTE(review): the body and the enclosing class header are missing in this listing.
  def landmarks(value: util.ArrayList[Any]): this.type = {

  // NOTE(review): the body is truncated in this listing; only the tail of a call that
  // validates `lmarks` via check(lmarks, "landmarks") survives.
  def run(): DataFrame = {, check(lmarks, "landmarks"))

// Implementation helper: runs shortest paths to the given landmarks and converts the result
// back into a GraphFrame-based DataFrame with a per-vertex distance map column.
// NOTE(review): this listing is heavily truncated — the mapping over `landmarks`, the GraphX
// call, the bodies of both UDF variants, the final column selection and all closing braces
// are missing. Restore from upstream before relying on the statements below.
private object ShortestPaths {

  private def run(graph: GraphFrame, landmarks: Seq[Any]): DataFrame = {
    val idType = graph.vertices.schema(GraphFrame.ID).dataType
    // Lookup from the integral (Long) id used on the GraphX side back to the original landmark.
    val longIdToLandmark = => GraphXConversions.integralId(graph, l) -> l).toMap
    val gx =
      longIdToLandmark.keys.toSeq.sorted).mapVertices { case (_, m) => m.toSeq }
    val g = GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(DISTANCE_ID))
    // Two conversions: integral ids keep the Long keys, otherwise keys are translated back
    // through `longIdToLandmark`.
    val distanceCol: Column = if (graph.hasIntegralIdType) {
      // It seems there are no easy way to convert a sequence of pairs into a map
      val mapToLandmark = udf { distances: Seq[Row] => { case Row(k: Long, v: Int) =>
          k -> v
    } else {
      val func = new UDF1[Seq[Row], Map[Any, Int]] {
        override def call(t1: Seq[Row]): Map[Any, Int] = {
 { case Row(k: Long, v: Int) =>
              longIdToLandmark(k) -> v
      // The result type must be spelled out because the key type is only known at runtime.
      val mapToLandmark = udf(func, MapType(idType, IntegerType, false))
    val cols = :+ _*)

  // Name of the vertex column holding the distance data.
  private val DISTANCE_ID = "distances"

Example 177
Source File: LastValueOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.lastValue

import java.util.Date

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

// Behavioural spec for LastValueOperator: field extraction (processMap), reduction
// (processReduce) and merging of old/new values (associativity).
// NOTE(review): closer-only lines (`))`, `}`) are missing throughout this listing.
class LastValueOperatorTest extends WordSpec with Matchers {

  "LastValue operator" should {

    // Three non-nullable integer fields.
    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)

    // Schema that, as shown, lacks "field1"; used to exercise the missing-field path.
    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)

    "processMap must be " in {
      // No "inputField" configured: expects None.
      val inputField = new LastValueOperator("lastValue", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      // "field1" requested against a schema without it: expects None.
      val inputFields2 = new LastValueOperator("lastValue", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      // "field1" present: expects the first row value.
      val inputFields3 = new LastValueOperator("lastValue", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      // Filter field1 < 2 holds for the row, so the value is kept.
      val inputFields4 = new LastValueOperator("lastValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      // Filter field1 > "2" does not hold for the row: expects None.
      val inputFields5 = new LastValueOperator("lastValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      // Two filters; the row's field2 value of 2 is not < 2, and None is expected.
      val inputFields6 = new LastValueOperator("lastValue", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
      inputFields6.processMap(Row(1, 2)) should be(None)

    "processReduce must be " in {
      // Empty input: expects None.
      val inputFields = new LastValueOperator("lastValue", initSchema, Map())
      inputFields.processReduce(Seq()) should be(None)

      // The last element of the sequence is expected, for numbers and strings alike.
      val inputFields2 = new LastValueOperator("lastValue", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2))) should be(Some(2))

      val inputFields3 = new LastValueOperator("lastValue", initSchema, Map())
      inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some("b"))

    "associative process must be " in {
      // Old value, a new value and a trailing empty new value: expects Some(1L).
      val inputFields = new LastValueOperator("lastValue", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(1L))

      // Same values with "typeOp" -> "int": the expectation is written as Some(1).
      val inputFields2 = new LastValueOperator("lastValue", initSchema, Map("typeOp" -> "int"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)))
      inputFields2.associativity(resultInput2) should be(Some(1))

      // Null "typeOp": the new value (2) is expected over the old one (1).
      val inputFields3 = new LastValueOperator("lastValue", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(1)),
        (Operator.NewValuesKey, Some(2)))
      inputFields3.associativity(resultInput3) should be(Some(2))

      // Empty input: expects None.
      val inputFields4 = new LastValueOperator("lastValue", initSchema, Map())
      val resultInput4 = Seq()
      inputFields4.associativity(resultInput4) should be(None)

      // Non-numeric payload (a Date) is expected to pass through unchanged.
      val inputFields5 = new LastValueOperator("lastValue", initSchema, Map())
      val date = new Date()
      val resultInput5 = Seq((Operator.NewValuesKey, Some(date)))
      inputFields5.associativity(resultInput5) should be(Some(date))
Example 178
Source File: StddevOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.stddev

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

// Behavioural spec for StddevOperator.
// NOTE(review): closer-only lines (`))`, `}`) are missing throughout this listing.
class StddevOperatorTest extends WordSpec with Matchers {

  "Std dev operator" should {

    // Three non-nullable integer fields.
    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)

    // Schema that, as shown, lacks "field1"; used to exercise the missing-field path.
    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)

    // Exercises processMap for a missing configuration, a missing schema field, the supported
    // input value types (Int, String, Double, Long) and filter expressions.
    "processMap must be " in {
      // No "inputField" configured: expects None.
      val inputField = new StddevOperator("stdev", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      // "field1" requested against a schema without it: expects None.
      val inputFields2 = new StddevOperator("stdev", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      // Int input.
      val inputFields3 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      // String input. The assertion previously called `inputFields3` by copy-paste, leaving
      // `inputFields4` unused; it now exercises the operator built for this case, matching
      // the sibling MedianOperatorTest.
      val inputFields4 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      // Double input.
      val inputFields6 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      // Long input.
      val inputFields7 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      // Filter field1 < 2 holds for the row, so the value is kept.
      val inputFields8 = new StddevOperator("stdev", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      // Filter field1 > "2" does not hold for the row: expects None.
      val inputFields9 = new StddevOperator("stdev", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      // Two filters; the row's field2 value of 2 is not < 2, and None is expected.
      val inputFields10 = new StddevOperator("stdev", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    // Exercises processReduce on empty, integer, mixed numeric and None-only inputs, plus a
    // string-typed result via "typeOp" -> "string".
    // NOTE(review): two `should be` matchers below have lost their expected-value argument and
    // the last assertion has lost its receiver (`inputFields5.processReduce(` presumably) —
    // restore from upstream before compiling.
    "processReduce must be " in {
      // Empty input: expects Some(0d).
      val inputFields = new StddevOperator("stdev", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new StddevOperator("stdev", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be

      val inputFields3 = new StddevOperator("stdev", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be

      // A single None element: expects Some(0d).
      val inputFields4 = new StddevOperator("stdev", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      // With "typeOp" -> "string" the expected result is the string form of the deviation.
      val inputFields5 = new StddevOperator("stdev", initSchema, Map("typeOp" -> "string"))
        Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some("2.850438562747845"))

    // Same scenarios as the plain processReduce spec, with "distinct" -> "true" configured and
    // duplicate values present in the inputs.
    // NOTE(review): two `should be` matchers below have lost their expected-value argument and
    // the last assertion has lost its receiver (`inputFields5.processReduce(` presumably) —
    // restore from upstream before compiling.
    "processReduce distinct must be " in {
      // Empty input: expects Some(0d).
      val inputFields = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be

      val inputFields3 = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be

      // A single None element: expects Some(0d).
      val inputFields4 = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      // The duplicated 1 is present in the input, yet the expected string equals the value
      // asserted for the duplicate-free input in the non-distinct spec.
      val inputFields5 = new StddevOperator("stdev", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
        Seq(Some(1), Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some("2.850438562747845"))
Example 179
Source File: MedianOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.median

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

/**
 * Specification for [[MedianOperator]].
 *
 * Covers `processMap` (input-field extraction and JSON filters) and `processReduce`,
 * both with the default configuration and with the `distinct` option enabled.
 */
@RunWith(classOf[JUnitRunner])
class MedianOperatorTest extends WordSpec with Matchers {

  "Median operator" should {

    // Schema that contains the field the operator is configured to read ("field1").
    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, nullable = false),
      StructField("field2", IntegerType, nullable = false),
      StructField("field3", IntegerType, nullable = false)
    ))

    // Schema without "field1", used to check the behaviour when the input field is absent.
    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, nullable = false)
    ))

    "processMap must be " in {
      // No "inputField" configured: nothing is expected to be extracted.
      val inputField = new MedianOperator("median", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      // "inputField" is not part of the schema: nothing is expected to be extracted.
      val inputFields2 = new MedianOperator("median", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      // A numeric string is expected to be read as a number.
      val inputFields4 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      // Filter field1 < 2: the row is expected to pass.
      val inputFields8 = new MedianOperator("median", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      // Filter field1 > 2: the row is expected to be rejected.
      val inputFields9 = new MedianOperator("median", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      // Two filters (field1 < 2 and field2 < 2): the row is expected to be rejected.
      val inputFields10 = new MedianOperator("median", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      // Empty input is expected to reduce to Some(0d).
      val inputFields = new MedianOperator("median", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new MedianOperator("median", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(3d))

      val inputFields3 = new MedianOperator("median", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(3))

      // A sequence holding only None is expected to reduce to Some(0d).
      val inputFields4 = new MedianOperator("median", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      // With "typeOp" -> "string" the expected result is a string.
      val inputFields5 = new MedianOperator("median", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("3.0"))
    }

    "processReduce distinct must be " in {
      val inputFields = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      // Duplicates are expected to be dropped before the median is taken.
      val inputFields2 = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(2.5))

      val inputFields3 = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(3))

      val inputFields4 = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new MedianOperator("median", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("2.5"))
    }
  }
}
Example 180
Source File: ModeOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.mode

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

/**
 * Specification for [[ModeOperator]].
 *
 * Covers `processMap` (input-field extraction and JSON filters) and `processReduce`,
 * including inputs where several values share the highest frequency.
 */
@RunWith(classOf[JUnitRunner])
class ModeOperatorTest extends WordSpec with Matchers {

  "Mode operator" should {

    // Schema that contains the field the operator is configured to read ("field1").
    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, nullable = false),
      StructField("field2", IntegerType, nullable = false),
      StructField("field3", IntegerType, nullable = false)
    ))

    // Schema without "field1", used to check the behaviour when the input field is absent.
    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, nullable = false)
    ))

    "processMap must be " in {
      // No "inputField" configured: nothing is expected to be extracted.
      val inputField = new ModeOperator("mode", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      // "inputField" is not part of the schema: nothing is expected to be extracted.
      val inputFields2 = new ModeOperator("mode", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new ModeOperator("mode", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      // Filter field1 < 2: the row is expected to pass.
      val inputFields4 = new ModeOperator("mode", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      // Filter field1 > 2: the row is expected to be rejected.
      val inputFields5 = new ModeOperator("mode", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      // Two filters (field1 < 2 and field2 < 2): the row is expected to be rejected.
      val inputFields6 = new ModeOperator("mode", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      // Empty input is expected to yield an empty list of modes.
      val inputFields = new ModeOperator("mode", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(List()))

      val inputFields2 = new ModeOperator("mode", initSchema, Map())
      inputFields2.processReduce(Seq(Some("hey"), Some("hey"), Some("hi"))) should be(Some(List("hey")))

      val inputFields3 = new ModeOperator("mode", initSchema, Map())
      inputFields3.processReduce(Seq(Some("1"), Some("1"), Some("4"))) should be(Some(List("1")))

      val inputFields4 = new ModeOperator("mode", initSchema, Map())
      inputFields4.processReduce(Seq(
        Some("1"), Some("1"), Some("4"), Some("4"), Some("4"), Some("4"))) should be(Some(List("4")))

      // Three values tie for the highest frequency: all of them are expected back.
      val inputFields5 = new ModeOperator("mode", initSchema, Map())
      inputFields5.processReduce(Seq(
        Some("1"), Some("1"), Some("2"), Some("2"), Some("4"), Some("4"))) should be(Some(List("1", "2", "4")))

      // A less frequent extra value ("5") is expected not to change the tie.
      val inputFields6 = new ModeOperator("mode", initSchema, Map())
      inputFields6.processReduce(Seq(
        Some("1"), Some("1"), Some("2"), Some("2"), Some("4"), Some("4"), Some("5"))
      ) should be(Some(List("1", "2", "4")))
    }
  }
}
Example 181
Source File: RangeOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.range

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

/**
 * Specification for [[RangeOperator]].
 *
 * Covers `processMap` (input-field extraction and JSON filters) and `processReduce`,
 * both with the default configuration and with the `distinct` option enabled.
 */
@RunWith(classOf[JUnitRunner])
class RangeOperatorTest extends WordSpec with Matchers {

  "Range operator" should {

    // Schema that contains the field the operator is configured to read ("field1").
    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, nullable = false),
      StructField("field2", IntegerType, nullable = false),
      StructField("field3", IntegerType, nullable = false)
    ))

    // Schema without "field1", used to check the behaviour when the input field is absent.
    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, nullable = false)
    ))

    "processMap must be " in {
      // No "inputField" configured: nothing is expected to be extracted.
      val inputField = new RangeOperator("range", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      // "inputField" is not part of the schema: nothing is expected to be extracted.
      val inputFields2 = new RangeOperator("range", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      // A numeric string is expected to be read as a number. The assertion now runs on
      // inputFields4; it previously re-used inputFields3, leaving this instance untested.
      val inputFields4 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      // Filter field1 < 2: the row is expected to pass.
      val inputFields8 = new RangeOperator("range", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      // Filter field1 > 2: the row is expected to be rejected.
      val inputFields9 = new RangeOperator("range", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      // Two filters (field1 < 2 and field2 < 2): the row is expected to be rejected.
      val inputFields10 = new RangeOperator("range", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      // Empty input is expected to reduce to Some(0d).
      val inputFields = new RangeOperator("range", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new RangeOperator("range", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(0))

      val inputFields3 = new RangeOperator("range", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(4))) should be(Some(3))

      // A sequence holding only None is expected to reduce to Some(0d).
      val inputFields4 = new RangeOperator("range", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      // With "typeOp" -> "string" the expected result is a string.
      val inputFields5 = new RangeOperator("range", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("6.0"))
    }

    "processReduce distinct must be " in {
      val inputFields = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(0))

      val inputFields3 = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(4))) should be(Some(3))

      val inputFields4 = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new RangeOperator("range", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("6.0"))
    }
  }
}
Example 182
Source File: AccumulatorOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.accumulator

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

/**
 * Specification for [[AccumulatorOperator]].
 *
 * Covers `processMap` (input-field extraction and JSON filters), `processReduce`
 * and `associativity` (merging previously stored values with new ones).
 */
@RunWith(classOf[JUnitRunner])
class AccumulatorOperatorTest extends WordSpec with Matchers {

  "Accumulator operator" should {

    // Schema that contains the field the operator is configured to read ("field1").
    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, nullable = false),
      StructField("field2", IntegerType, nullable = false)
    ))

    // Schema without "field1", used to check the behaviour when the input field is absent.
    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, nullable = false)
    ))

    "processMap must be " in {
      // No "inputField" configured: nothing is expected to be extracted.
      val inputField = new AccumulatorOperator("accumulator", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      // "inputField" is not part of the schema: nothing is expected to be extracted.
      val inputFields2 = new AccumulatorOperator("accumulator", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new AccumulatorOperator("accumulator", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      // Filter field1 < 2: the row is expected to pass.
      val inputFields4 = new AccumulatorOperator("accumulator", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      // Filter field1 > 2: the row is expected to be rejected.
      val inputFields5 = new AccumulatorOperator("accumulator", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":2}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      // Two filters (field1 < 2 and field2 < 2): the row is expected to be rejected.
      val inputFields6 = new AccumulatorOperator("accumulator", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":2}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      // Empty input is expected to reduce to an empty sequence.
      val inputFields = new AccumulatorOperator("accumulator", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(Seq()))

      // Numeric inputs are expected to come back as strings.
      val inputFields2 = new AccumulatorOperator("accumulator", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(Seq("1", "1")))

      val inputFields3 = new AccumulatorOperator("accumulator", initSchema, Map())
      inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some(Seq("a", "b")))
    }

    "associative process must be " in {
      // Old and new values are expected to be concatenated; a None entry contributes nothing.
      val inputFields = new AccumulatorOperator("accumulator", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(Seq(1L))),
        (Operator.NewValuesKey, Some(Seq(2L))),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(Seq("1", "2")))

      // With "typeOp" -> "arraydouble" the expected elements are doubles.
      val inputFields2 = new AccumulatorOperator("accumulator", initSchema, Map("typeOp" -> "arraydouble"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(Seq(1))),
        (Operator.NewValuesKey, Some(Seq(3))))
      inputFields2.associativity(resultInput2) should be(Some(Seq(1d, 3d)))

      // A null "typeOp" is passed on purpose; the expected output matches the unconfigured case.
      val inputFields3 = new AccumulatorOperator("accumulator", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(Seq(1))),
        (Operator.NewValuesKey, Some(Seq(1))))
      inputFields3.associativity(resultInput3) should be(Some(Seq("1", "1")))
    }
  }
}
Example 183
Source File: FirstValueOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.firstValue

import java.util.Date

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

/**
 * Specification for [[FirstValueOperator]].
 *
 * Covers `processMap` (input-field extraction and JSON filters), `processReduce`
 * and `associativity` (merging previously stored values with new ones).
 */
@RunWith(classOf[JUnitRunner])
class FirstValueOperatorTest extends WordSpec with Matchers {

  "FirstValue operator" should {

    // Schema that contains the field the operator is configured to read ("field1").
    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, nullable = false),
      StructField("field2", IntegerType, nullable = false),
      StructField("field3", IntegerType, nullable = false)
    ))

    // Schema without "field1", used to check the behaviour when the input field is absent.
    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, nullable = false)
    ))

    "processMap must be " in {
      // No "inputField" configured: nothing is expected to be extracted.
      val inputField = new FirstValueOperator("firstValue", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      // "inputField" is not part of the schema: nothing is expected to be extracted.
      val inputFields2 = new FirstValueOperator("firstValue", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new FirstValueOperator("firstValue", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      // Filter field1 < 2: the row is expected to pass.
      val inputFields4 = new FirstValueOperator("firstValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      // Filter field1 > 2: the row is expected to be rejected.
      val inputFields5 = new FirstValueOperator("firstValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      // Two filters (field1 < 2 and field2 < 2): the row is expected to be rejected.
      val inputFields6 = new FirstValueOperator("firstValue", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      // Empty input is expected to reduce to None.
      val inputFields = new FirstValueOperator("firstValue", initSchema, Map())
      inputFields.processReduce(Seq()) should be(None)

      // The first element of the sequence is the expected result.
      val inputFields2 = new FirstValueOperator("firstValue", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2))) should be(Some(1))

      val inputFields3 = new FirstValueOperator("firstValue", initSchema, Map())
      inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some("a"))
    }

    "associative process must be " in {
      val inputFields = new FirstValueOperator("firstValue", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(1L))

      // With "typeOp" -> "int" the expected result is an Int.
      val inputFields2 = new FirstValueOperator("firstValue", initSchema, Map("typeOp" -> "int"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)))
      inputFields2.associativity(resultInput2) should be(Some(1))

      // A null "typeOp" is passed on purpose; the value is expected back unchanged.
      val inputFields3 = new FirstValueOperator("firstValue", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(1)),
        (Operator.NewValuesKey, Some(1)),
        (Operator.NewValuesKey, None))
      inputFields3.associativity(resultInput3) should be(Some(1))

      // No values at all: None is expected.
      val inputFields4 = new FirstValueOperator("firstValue", initSchema, Map())
      val resultInput4 = Seq()
      inputFields4.associativity(resultInput4) should be(None)

      // Only a new value, of a non-numeric type (Date): it is expected back as-is.
      val inputFields5 = new FirstValueOperator("firstValue", initSchema, Map())
      val date = new Date()
      val resultInput5 = Seq((Operator.NewValuesKey, Some(date)))
      inputFields5.associativity(resultInput5) should be(Some(date))
    }
  }
}
Example 184
Source File: MeanAssociativeOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.mean

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

/**
 * Specification for [[MeanAssociativeOperator]].
 *
 * Covers `processMap` (input-field extraction and JSON filters), `processReduce`
 * (with and without the `distinct` option) and `associativity`, which merges a stored
 * count/sum/mean map with newly reduced values.
 */
@RunWith(classOf[JUnitRunner])
class MeanAssociativeOperatorTest extends WordSpec with Matchers {

  "Mean operator" should {

    // Schema that contains the field the operator is configured to read ("field1").
    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, nullable = false),
      StructField("field2", IntegerType, nullable = false),
      StructField("field3", IntegerType, nullable = false)
    ))

    // Schema without "field1", used to check the behaviour when the input field is absent.
    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, nullable = false)
    ))

    "processMap must be " in {
      // No "inputField" configured: nothing is expected to be extracted.
      val inputField = new MeanAssociativeOperator("avg", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      // "inputField" is not part of the schema: nothing is expected to be extracted.
      val inputFields2 = new MeanAssociativeOperator("avg", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      // A numeric string is expected to be read as a number.
      val inputFields4 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      // Filter field1 < 2: the row is expected to pass.
      val inputFields8 = new MeanAssociativeOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      // Filter field1 > 2: the row is expected to be rejected.
      val inputFields9 = new MeanAssociativeOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      // Two filters (field1 < 2 and field2 < 2): the row is expected to be rejected.
      val inputFields10 = new MeanAssociativeOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      // Empty input is expected to reduce to an empty list.
      val inputFields = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(List()))

      // None entries are expected to be dropped; the rest come back as doubles.
      val inputFields2 = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be(Some(List(1.0, 1.0)))

      val inputFields3 = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), None)) should be(Some(List(1.0, 2.0, 3.0)))

      val inputFields4 = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(List()))
    }

    "processReduce distinct must be " in {
      val inputFields = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(List()))

      // Duplicates are expected to be removed.
      val inputFields2 = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be(Some(List(1.0)))

      val inputFields3 = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(3), Some(1), None)) should be(Some(List(1.0, 3.0)))

      val inputFields4 = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(List()))
    }

    "associative process must be " in {
      // Only an old aggregate and no new values: the aggregate is expected back unchanged.
      val inputFields = new MeanAssociativeOperator("avg", initSchema, Map())
      val resultInput =
        Seq((Operator.OldValuesKey, Some(Map("count" -> 1d, "sum" -> 2d, "mean" -> 2d))), (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(Map("count" -> 1.0, "sum" -> 2.0, "mean" -> 2.0)))

      // Old aggregate (count 1, sum 2) plus a new value 1.0: expected count 2, sum 3, mean 1.5.
      val inputFields2 = new MeanAssociativeOperator("avg", initSchema, Map())
      val resultInput2 = Seq((Operator.OldValuesKey, Some(Map("count" -> 1d, "sum" -> 2d, "mean" -> 2d))),
        (Operator.NewValuesKey, Some(Seq(1d))))
      inputFields2.associativity(resultInput2) should be(Some(Map("sum" -> 3.0, "count" -> 2.0, "mean" -> 1.5)))
    }
  }
}
Example 185
Source File: MeanOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.mean

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

/**
 * Specification for [[MeanOperator]].
 *
 * Covers `processMap` (input-field extraction and JSON filters) and `processReduce`,
 * both with the default configuration and with the `distinct` option enabled.
 */
@RunWith(classOf[JUnitRunner])
class MeanOperatorTest extends WordSpec with Matchers {

  "Mean operator" should {

    // Schema that contains the field the operator is configured to read ("field1").
    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, nullable = false),
      StructField("field2", IntegerType, nullable = false),
      StructField("field3", IntegerType, nullable = false)
    ))

    // Schema without "field1", used to check the behaviour when the input field is absent.
    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, nullable = false)
    ))

    "processMap must be " in {
      // No "inputField" configured: nothing is expected to be extracted.
      val inputField = new MeanOperator("avg", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      // "inputField" is not part of the schema: nothing is expected to be extracted.
      val inputFields2 = new MeanOperator("avg", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      // A numeric string is expected to be read as a number.
      val inputFields4 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      // Filter field1 < 2: the row is expected to pass.
      val inputFields8 = new MeanOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      // Filter field1 > 2: the row is expected to be rejected.
      val inputFields9 = new MeanOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      // Two filters (field1 < 2 and field2 < 2): the row is expected to be rejected.
      val inputFields10 = new MeanOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      // Empty input is expected to reduce to Some(0d).
      val inputFields = new MeanOperator("avg", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      // None entries are expected not to affect the mean.
      val inputFields2 = new MeanOperator("avg", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be(Some(1))

      val inputFields3 = new MeanOperator("avg", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), None)) should be(Some(2))

      // A sequence holding only None is expected to reduce to Some(0d).
      val inputFields4 = new MeanOperator("avg", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      // With "typeOp" -> "string" the expected result is a string.
      val inputFields5 = new MeanOperator("avg", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(Seq(Some(1), Some(1))) should be(Some("1.0"))
    }

    "processReduce distinct must be " in {
      val inputFields = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be(Some(1))

      // Duplicates are expected to be dropped: mean of (1, 3) rather than (1, 3, 1).
      val inputFields3 = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(3), Some(1), None)) should be(Some(2))

      val inputFields4 = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new MeanOperator("avg", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(Seq(Some(1), Some(1))) should be(Some("1.0"))
    }
  }
}
Example 186
Source File: EntityCountOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.entityCount

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

/**
 * Specification for [[EntityCountOperator]].
 *
 * Covers `processMap` (input-field extraction, the `split` option and JSON filters),
 * `processReduce` and `associativity` (merging stored counts with new values).
 */
@RunWith(classOf[JUnitRunner])
class EntityCountOperatorTest extends WordSpec with Matchers {

  "Entity Count Operator" should {

    // Schema that contains the field the operator is configured to read ("field1").
    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, nullable = false),
      StructField("field2", IntegerType, nullable = false),
      StructField("field3", IntegerType, nullable = false)
    ))

    // Schema without "field1", used to check the behaviour when the input field is absent.
    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, nullable = false)
    ))

    "processMap must be " in {
      // No "inputField" configured: nothing is expected to be extracted.
      val inputField = new EntityCountOperator("entityCount", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      // "inputField" is not part of the schema: nothing is expected to be extracted.
      val inputFields2 = new EntityCountOperator("entityCount", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      // Without a "split" option the whole string is expected as a single entity.
      val inputFields3 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo")))

      // The separator does not occur in the value: a single entity is expected.
      val inputFields4 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "split" -> ","))
      inputFields4.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo")))

      val inputFields5 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "split" -> "-"))
      inputFields5.processMap(Row("hola-holo", 2)) should be(Some(Seq("hola", "holo")))

      // Only the configured separator splits; the embedded space is expected to be kept.
      val inputFields6 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "split" -> ","))
      inputFields6.processMap(Row("hola,holo adios", 2)) should be(Some(Seq("hola", "holo " + "adios")))

      // Filter field1 != "hola": a row whose value is "hola" is expected to be rejected.
      val inputFields7 = new EntityCountOperator("entityCount", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]"))
      inputFields7.processMap(Row("hola", 2)) should be(None)

      // Same filter, but the raw value is "hola holo", so the row is expected to pass
      // and then be split on the space.
      val inputFields8 = new EntityCountOperator("entityCount", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]",
          "split" -> " "))
      inputFields8.processMap(Row("hola holo", 2)) should be(Some(Seq("hola", "holo")))
    }

    "processReduce must be " in {
      // Empty input is expected to reduce to an empty sequence.
      val inputFields = new EntityCountOperator("entityCount", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(Seq()))

      val inputFields2 = new EntityCountOperator("entityCount", initSchema, Map())
      inputFields2.processReduce(Seq(Some(Seq("hola", "holo")))) should be(Some(Seq("hola", "holo")))

      // Repeated entities are expected to be kept at this stage.
      val inputFields3 = new EntityCountOperator("entityCount", initSchema, Map())
      inputFields3.processReduce(Seq(Some(Seq("hola", "holo", "hola")))) should be(Some(Seq("hola", "holo", "hola")))
    }

    "associative process must be " in {
      // Only old counts and no new values: the counts are expected back unchanged.
      val inputFields = new EntityCountOperator("entityCount", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(Map("hola" -> 1L, "holo" -> 1L))),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(Map("hola" -> 1L, "holo" -> 1L)))

      // With "typeOp" -> "int" the expected result is an empty map.
      val inputFields2 = new EntityCountOperator("entityCount", initSchema, Map("typeOp" -> "int"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(Map("hola" -> 1L, "holo" -> 1L))),
        (Operator.NewValuesKey, Some(Seq("hola"))))
      inputFields2.associativity(resultInput2) should be(Some(Map()))

      // A null "typeOp" is passed on purpose; the old counts are expected back unchanged.
      val inputFields3 = new EntityCountOperator("entityCount", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(Map("hola" -> 1L, "holo" -> 1L))))
      inputFields3.associativity(resultInput3) should be(Some(Map("hola" -> 1L, "holo" -> 1L)))
    }
  }
}
Example 187
Source File: SumOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.sum

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

/**
 * Unit tests for `SumOperator`, covering the map phase (`processMap`), the reduce phase
 * (`processReduce`, with and without the "distinct" option) and `associativity`.
 */
class SumOperatorTest extends WordSpec with Matchers {

  "Sum operator" should {

    // Schema declaring every field the operators below are configured with.
    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    // Schema without "field1", used to exercise an input field that the schema does not declare.
    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      // No "inputField" configured: nothing is expected to be extracted.
      val inputField = new SumOperator("sum", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      // "inputField" names a field that initSchemaFail does not contain.
      val inputFields2 = new SumOperator("sum", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      // String input: assert on the operator declared for this case (previously the assertion
      // reused inputFields3 and left inputFields4 unused).
      val inputFields4 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      // Double and Long inputs.
      val inputFields6 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      // Filter "field1 < 2" accepts the row (field1 = 1).
      val inputFields8 = new SumOperator("sum", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      // Filter "field1 > 2" rejects the row.
      val inputFields9 = new SumOperator("sum", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      // Two filters: field1 (1) is below 2 but field2 (2) is not, and None is expected.
      val inputFields10 = new SumOperator("sum", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      // Empty input is expected to sum to 0.
      val inputFields = new SumOperator("sum", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new SumOperator("sum", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(20d))

      // Mixed Int and Double values.
      val inputFields3 = new SumOperator("sum", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(20d))

      // A lone None is expected to sum to 0.
      val inputFields4 = new SumOperator("sum", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))
    }

    "processReduce distinct must be " in {
      val inputFields = new SumOperator("sum", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      // With "distinct" the repeated 1 is expected to be counted once: 1 + 2 = 3.
      val inputFields2 = new SumOperator("sum", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(1))) should be(Some(3d))
    }

    "associative process must be " in {
      // Old value 1 plus new value 1; the trailing None entry is expected not to contribute.
      val inputFields = new SumOperator("count", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(2d))

      // With "typeOp" -> "string" the same sum is expected as the string "2.0".
      val inputFields2 = new SumOperator("count", initSchema, Map("typeOp" -> "string"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields2.associativity(resultInput2) should be(Some("2.0"))
    }
  }
}
Example 188
Source File: FullTextOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.fullText

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

/**
 * Unit tests for `FullTextOperator`, covering `processMap`, `processReduce` and `associativity`.
 *
 * NOTE(review): closing braces/parentheses are missing throughout this listing; the code is
 * left exactly as found.
 */
class FullTextOperatorTest extends WordSpec with Matchers {

  "FullText operator" should {

    // Schema declaring every field the operators below are configured with.
    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)

    // Schema without "field1", used to exercise an input field that the schema does not declare.
    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)

    "processMap must be " in {
      // No "inputField" configured: None is expected.
      val inputField = new FullTextOperator("fullText", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      // "inputField" names a field that initSchemaFail does not contain.
      val inputFields2 = new FullTextOperator("fullText", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new FullTextOperator("fullText", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      // Filter "field1 < 2" accepts the row (field1 = 1).
      val inputFields4 = new FullTextOperator("fullText", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      // Filter "field1 > 2" rejects the row.
      val inputFields5 = new FullTextOperator("fullText", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      // Two filters: field1 (1) is below 2 but field2 (2) is not, and None is expected.
      val inputFields6 = new FullTextOperator("fullText", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
      inputFields6.processMap(Row(1, 2)) should be(None)

    "processReduce must be " in {
      // Empty input is expected to reduce to the empty string.
      val inputFields = new FullTextOperator("fullText", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(""))

      // Values are expected to be joined with Operator.SpaceSeparator.
      val inputFields2 = new FullTextOperator("fullText", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(s"1${Operator.SpaceSeparator}1"))

      val inputFields3 = new FullTextOperator("fullText", initSchema, Map())
      inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some(s"a${Operator.SpaceSeparator}b"))

    "associative process must be " in {
      // Old value 2 and an empty new value: the string "2" is expected.
      val inputFields = new FullTextOperator("fullText", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(2)), (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some("2"))

      // With "typeOp" -> "arraystring" the joined text is expected wrapped in a Seq.
      val inputFields2 = new FullTextOperator("fullText", initSchema, Map("typeOp" -> "arraystring"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(2)),
        (Operator.NewValuesKey, Some(1)))
      inputFields2.associativity(resultInput2) should be(Some(Seq(s"2${Operator.SpaceSeparator}1")))

      // A null "typeOp" with two old values: both are expected joined with the separator.
      val inputFields3 = new FullTextOperator("fullText", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(2)), (Operator.OldValuesKey, Some(3)))
      inputFields3.associativity(resultInput3) should be(Some(s"2${Operator.SpaceSeparator}3"))
Example 189
Source File: TotalEntityCountOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.totalEntityCount

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

/**
 * Unit tests for `TotalEntityCountOperator`, covering the map phase (`processMap`), the reduce
 * phase (`processReduce`, with and without the "distinct" option) and `associativity`.
 */
class TotalEntityCountOperatorTest extends WordSpec with Matchers {

  "Entity Count Operator" should {

    // Schema declaring every field the operators below are configured with.
    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    // Schema without "field1", used to exercise an input field that the schema does not declare.
    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      // No "inputField" configured: None is expected.
      val inputField = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      // "inputField" names a field that initSchemaFail does not contain.
      val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      // Without a "split" option the whole value is expected as a single entity.
      val inputFields3 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo")))

      // The separator "," does not occur in the value, so it is expected to stay whole.
      val inputFields4 =
        new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1", "split" -> ","))
      inputFields4.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo")))

      val inputFields5 =
        new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1", "split" -> "-"))
      inputFields5.processMap(Row("hola-holo", 2)) should be(Some(Seq("hola", "holo")))

      // Only "," splits; the space inside "holo adios" is expected to be preserved.
      val inputFields6 =
        new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1", "split" -> ","))
      inputFields6.processMap(Row("hola,holo adios", 2)) should be(Some(Seq("hola", "holo " + "adios")))

      // Filter field1 != "hola" rejects a row whose field1 is exactly "hola".
      val inputFields7 = new TotalEntityCountOperator("totalEntityCount", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]"))
      inputFields7.processMap(Row("hola", 2)) should be(None)

      // Same filter with a row that passes it, then split on " ".
      // The matcher and its argument must stay on one line: with a line break after `be`,
      // Scala ends the statement there and the expected value is never compared.
      val inputFields8 = new TotalEntityCountOperator("totalEntityCount", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]",
          "split" -> " "))
      inputFields8.processMap(Row("hola holo", 2)) should be(Some(Seq("hola", "holo")))
    }

    "processReduce must be " in {
      // Empty input is expected to count 0 entities.
      val inputFields = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0L))

      val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputFields2.processReduce(Seq(Some(Seq("hola", "holo")))) should be(Some(2L))

      // Without "distinct" the repeated "hola" is expected to be counted twice.
      val inputFields3 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputFields3.processReduce(Seq(Some(Seq("hola", "holo", "hola")))) should be(Some(3L))

      // A lone None is expected to count 0 entities.
      val inputFields4 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0L))
    }

    "processReduce distinct must be " in {
      val inputFields = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0L))

      // With "distinct" the repeated "hola" is expected to be counted once.
      val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(Seq("hola", "holo", "hola")))) should be(Some(2L))
    }

    "associative process must be " in {
      // Old value 2 and an empty new value: 2 is expected.
      val inputFields = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(2)), (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(2))

      // Old value 2 plus new value 1 with "typeOp" -> "int": 3 is expected.
      val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("typeOp" -> "int"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(2)),
        (Operator.NewValuesKey, Some(1)))
      inputFields2.associativity(resultInput2) should be(Some(3))

      // A null "typeOp" with only an old value: 2 is expected.
      val inputFields3 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(2)))
      inputFields3.associativity(resultInput3) should be(Some(2))
    }
  }
}
Example 190
Source File: StatisticsTest.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.oap.statistics


import scala.collection.mutable.ArrayBuffer

import org.scalatest.BeforeAndAfterEach

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BaseOrdering
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
import org.apache.spark.sql.execution.datasources.oap.filecache.FiberCache
import org.apache.spark.sql.execution.datasources.oap.index.RangeInterval
import org.apache.spark.sql.execution.datasources.oap.utils.{NonNullKeyReader, NonNullKeyWriter}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.memory.MemoryBlock
import org.apache.spark.unsafe.types.UTF8String

/**
 * Base class for statistics tests: provides a row generator, a two-column schema with the
 * key reader/writer and ordering built from it, an output stream that is recreated before
 * each test, and a buffer of range intervals.
 *
 * NOTE(review): this listing is truncated; closing braces and several statements are missing
 * (see the notes below). The code is left exactly as found.
 */
abstract class StatisticsTest extends SparkFunSuite with BeforeAndAfterEach {

  // Builds a two-field row: the integer `i` and the UTF8 string "test#<i>".
  protected def rowGen(i: Int): InternalRow = InternalRow(i, UTF8String.fromString(s"test#$i"))

  // Schema matching rowGen: integer column "a" and string column "b".
  protected lazy val schema: StructType = StructType(StructField("a", IntegerType)
    :: StructField("b", StringType) :: Nil)
  protected lazy val nnkw: NonNullKeyWriter = new NonNullKeyWriter(schema)
  protected lazy val nnkr: NonNullKeyReader = new NonNullKeyReader(schema)
  protected lazy val ordering: BaseOrdering = GenerateOrdering.create(schema)
  // NOTE(review): the right-hand side of partialOrdering is missing in this listing.
  protected lazy val partialOrdering: BaseOrdering =
  // NOTE(review): ByteArrayOutputStream is not among the imports visible in this listing.
  protected var out: ByteArrayOutputStream = _

  protected var intervalArray: ArrayBuffer[RangeInterval] = new ArrayBuffer[RangeInterval]()

  // Gives every test a fresh output stream with an initial capacity of 8000 bytes.
  override def beforeEach(): Unit = {
    out = new ByteArrayOutputStream(8000)

  // NOTE(review): the body of afterEach is missing in this listing.
  override def afterEach(): Unit = {

  // Appends a RangeInterval built from the given bounds and inclusiveness flags to intervalArray.
  protected def generateInterval(
      start: InternalRow, end: InternalRow,
      startInclude: Boolean, endInclude: Boolean): Unit = {
    intervalArray.append(new RangeInterval(start, end, startInclude, endInclude))

  // Asserts that two rows compare equal with `==`, reporting both rows on failure.
  protected def checkInternalRow(row1: InternalRow, row2: InternalRow): Unit = {
    val res = row1 == row2 // it works..
    assert(res, s"row1: $row1 does not match $row2")

  // NOTE(review): only the first statement of this method is present in this listing.
  protected def wrapToFiberCache(out: ByteArrayOutputStream): FiberCache = {
    val bytes = out.toByteArray
Example 191
Source File: MLPipelineTrackerIT.scala    From spark-atlas-connector   with Apache License 2.0 5 votes vote down vote up

import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.scalatest.Matchers
import com.hortonworks.spark.atlas._
import com.hortonworks.spark.atlas.types._
import com.hortonworks.spark.atlas.TestUtils._

/**
 * Integration test that creates ML directory, pipeline and model entities through a
 * RestAtlasClient.
 *
 * NOTE(review): this listing is truncated; closing braces and some statements are missing,
 * and `Vectors`, `MinMaxScaler` and `Pipeline` are not among the visible imports. The code is
 * left exactly as found.
 */
class MLPipelineTrackerIT extends BaseResourceIT with Matchers with WithHiveSupport {
  private val atlasClient = new RestAtlasClient(atlasClientConf)

  // Cluster name read from the Atlas client configuration.
  def clusterName: String = atlasClientConf.get(AtlasClientConf.CLUSTER_NAME)

  // Builds the entity for a table named `tableName` in database "db1" with a
  // (user: String, age: Int) schema.
  def getTableEntity(tableName: String): SACAtlasEntityWithDependencies = {
    val dbDefinition = createDB("db1", "hdfs:///test/db/db1")
    val sd = createStorageFormat()
    val schema = new StructType()
      .add("user", StringType, false)
      .add("age", IntegerType, true)
    val tableDefinition = createTable("db1", s"$tableName", schema, sd)
    internal.sparkTableToEntity(tableDefinition, clusterName, Some(dbDefinition))

  // Enable it to run integrated test.
  it("pipeline and pipeline model") {
    val uri = "hdfs://"
    val pipelineDir = "tmp/pipeline"
    val modelDir = "tmp/model"

    // Directory entities for the pipeline and the model, created first in Atlas.
    val pipelineDirEntity = internal.mlDirectoryToEntity(uri, pipelineDir)
    val modelDirEntity = internal.mlDirectoryToEntity(uri, modelDir)

    atlasClient.createEntitiesWithDependencies(Seq(pipelineDirEntity, modelDirEntity))

    // Four-row DataFrame with columns id, features and label.
    val df = sparkSession.createDataFrame(Seq(
      (1, Vectors.dense(0.0, 1.0, 4.0), 1.0),
      (2, Vectors.dense(1.0, 0.0, 4.0), 2.0),
      (3, Vectors.dense(1.0, 0.0, 5.0), 3.0),
      (4, Vectors.dense(0.0, 0.0, 5.0), 4.0)
    )).toDF("id", "features", "label")

    // Single-stage pipeline containing only the scaler.
    val scaler = new MinMaxScaler()
    val pipeline = new Pipeline().setStages(Array(scaler))

    // NOTE(review): the right-hand side of `model` is missing in this listing.
    val model =


    val pipelineEntity = internal.mlPipelineToEntity(pipeline.uid, pipelineDirEntity)

    atlasClient.createEntitiesWithDependencies(Seq(pipelineDirEntity, pipelineEntity))

    val modelEntity = internal.mlModelToEntity(model.uid, modelDirEntity)

    atlasClient.createEntitiesWithDependencies(Seq(modelDirEntity, modelEntity))

    val tableEntities1 = getTableEntity("chris1")
    val tableEntities2 = getTableEntity("chris2")


Example 192
Source File: MLAtlasEntityUtilsSuite.scala    From spark-atlas-connector   with Apache License 2.0 5 votes vote down vote up
package com.hortonworks.spark.atlas.types


import org.apache.atlas.{AtlasClient, AtlasConstants}
import org.apache.atlas.model.instance.AtlasEntity
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.scalatest.{FunSuite, Matchers}
import com.hortonworks.spark.atlas.TestUtils._
import com.hortonworks.spark.atlas.{AtlasUtils, WithHiveSupport}

/**
 * Checks the attributes and dependencies of the ML directory, pipeline and model entities
 * produced by the `internal.ml*ToEntity` helpers.
 *
 * NOTE(review): this listing is truncated; closing braces and some statements are missing,
 * and `Vectors`, `MinMaxScaler`, `Pipeline`, `FileUtils` and `File` are not among the visible
 * imports. The code is left exactly as found.
 */
class MLAtlasEntityUtilsSuite extends FunSuite with Matchers with WithHiveSupport {

  // Builds the entity for a table named `tableName` in database "db1" with a
  // (user: String, age: Int) schema.
  // NOTE(review): the statement returning the entity is missing in this listing.
  def getTableEntity(tableName: String): AtlasEntity = {
    val dbDefinition = createDB("db1", "hdfs:///test/db/db1")
    val sd = createStorageFormat()
    val schema = new StructType()
      .add("user", StringType, false)
      .add("age", IntegerType, true)
    val tableDefinition = createTable("db1", s"$tableName", schema, sd)

    val tableEntities = internal.sparkTableToEntity(
      tableDefinition, AtlasConstants.DEFAULT_CLUSTER_NAME, Some(dbDefinition))
    val tableEntity = tableEntities.entity


  test("pipeline, pipeline model, fit and transform") {
    val uri = "/"
    val pipelineDir = "tmp/pipeline"
    val modelDir = "tmp/model"

    // Directory entities are expected to carry their uri/directory and have no dependencies.
    val pipelineDirEntity = internal.mlDirectoryToEntity(uri, pipelineDir)
    pipelineDirEntity.entity.getAttribute("uri") should be (uri)
    pipelineDirEntity.entity.getAttribute("directory") should be (pipelineDir)
    pipelineDirEntity.dependencies.length should be (0)

    val modelDirEntity = internal.mlDirectoryToEntity(uri, modelDir)
    modelDirEntity.entity.getAttribute("uri") should be (uri)
    modelDirEntity.entity.getAttribute("directory") should be (modelDir)
    modelDirEntity.dependencies.length should be (0)

    // Four-row DataFrame with columns id, features and label.
    val df = sparkSession.createDataFrame(Seq(
      (1, Vectors.dense(0.0, 1.0, 4.0), 1.0),
      (2, Vectors.dense(1.0, 0.0, 4.0), 2.0),
      (3, Vectors.dense(1.0, 0.0, 5.0), 3.0),
      (4, Vectors.dense(0.0, 0.0, 5.0), 4.0)
    )).toDF("id", "features", "label")

    // Single-stage pipeline containing only the scaler.
    val scaler = new MinMaxScaler()
    val pipeline = new Pipeline().setStages(Array(scaler))

    // NOTE(review): the right-hand side of `model` is missing in this listing.
    val model =


    // The pipeline entity is expected to be named after the pipeline uid, reference its
    // directory entity and list that directory as its only dependency.
    val pipelineEntity = internal.mlPipelineToEntity(pipeline.uid, pipelineDirEntity)
    pipelineEntity.entity.getTypeName should be (metadata.ML_PIPELINE_TYPE_STRING)
    // NOTE(review): the expected value of the next assertion is missing in this listing.
    pipelineEntity.entity.getAttribute(AtlasClient.REFERENCEABLE_ATTRIBUTE_NAME) should be (
    pipelineEntity.entity.getAttribute("name") should be (pipeline.uid)
    pipelineEntity.entity.getRelationshipAttribute("directory") should be (
      AtlasUtils.entityToReference(pipelineDirEntity.entity, useGuid = false))
    pipelineEntity.dependencies should be (Seq(pipelineDirEntity))

    // The model entity's name is expected to be the model uid with "pipeline" replaced by "model".
    val modelEntity = internal.mlModelToEntity(model.uid, modelDirEntity)
    val modelUid = model.uid.replaceAll("pipeline", "model")
    modelEntity.entity.getTypeName should be (metadata.ML_MODEL_TYPE_STRING)
    modelEntity.entity.getAttribute(AtlasClient.REFERENCEABLE_ATTRIBUTE_NAME) should be (modelUid)
    modelEntity.entity.getAttribute("name") should be (modelUid)
    modelEntity.entity.getRelationshipAttribute("directory") should be (
      AtlasUtils.entityToReference(modelDirEntity.entity, useGuid = false))

    modelEntity.dependencies should be (Seq(modelDirEntity))

    // Remove the local "tmp" directory at the end of the test.
    FileUtils.deleteDirectory(new File("tmp"))
Example 193
Source File: ExtraOperationsSuite.scala    From tensorframes   with Apache License 2.0 5 votes vote down vote up
package org.tensorframes

import org.apache.spark.sql.types.{DoubleType, IntegerType}
import org.scalatest.FunSuite
import org.tensorframes.impl.{ScalarDoubleType, ScalarIntType}

/**
 * Tests for `ExtraOperations.explainDetailed` and `analyze`: each test builds a small
 * DataFrame and checks the scalar data type and shape reported for its columns.
 *
 * NOTE(review): the closing braces of the tests and of the class are missing in this listing;
 * the code is left exactly as found.
 */
class ExtraOperationsSuite
    extends FunSuite with TensorFramesTestSparkContext with Logging {
  lazy val sql = sqlContext
  import ExtraOperations._
  import sql.implicits._
  import Shape.Unknown

  // Without analysis, a double column is expected to report an unknown leading dimension.
  test("simple test for doubles") {
    val df = Seq(Tuple1(0.0)).toDF("a")
    val di = ExtraOperations.explainDetailed(df)
    val Seq(c1) = di.cols
    val Some(s) = c1.stf
    assert(s.dataType === ScalarDoubleType)
    assert(s.shape === Shape(Unknown))
    logDebug(df.toString() + "->" + di.toString)

  // Same check for an integer column.
  test("simple test for integers") {
    val df = Seq(Tuple1(0)).toDF("a")
    val di = explainDetailed(df)
    val Seq(c1) = di.cols
    val Some(s) = c1.stf
    assert(s.dataType === ScalarIntType)
    assert(s.shape === Shape(Unknown))
    logDebug(df.toString() + "->" + di.toString)

  // Each level of Seq nesting is expected to add one unknown dimension.
  test("test for arrays") {
    val df = Seq((0.0, Seq(1.0), Seq(Seq(1.0)))).toDF("a", "b", "c")
    val di = explainDetailed(df)
    logDebug(df.toString() + "->" + di.toString)
    val Seq(c1, c2, c3) = di.cols
    val Some(s1) = c1.stf
    assert(s1.dataType === ScalarDoubleType)
    assert(s1.shape === Shape(Unknown))
    val Some(s2) = c2.stf
    assert(s2.dataType === ScalarDoubleType)
    assert(s2.shape === Shape(Unknown, Unknown))
    val Some(s3) = c3.stf
    assert(s3.dataType === ScalarDoubleType)
    assert(s3.shape === Shape(Unknown, Unknown, Unknown))

  // After analyze, the single-row DataFrame is expected to report a leading dimension of 1.
  test("simple analysis") {
    val df = Seq(Tuple1(0.0)).toDF("a")
    val df2 = analyze(df)
    val di = explainDetailed(df2)
    logDebug(df.toString() + "->" + di.toString)
    val Seq(c1) = di.cols
    val Some(s) = c1.stf
    assert(s.dataType === ScalarDoubleType)
    assert(s.shape === Shape(1)) // There is only one partition

  test("simple analysis with multiple partitions of different sizes") {
    val df = Seq.fill(10)(0.0).map(Tuple1.apply).toDF("a").repartition(3)
    val df2 = analyze(df)
    val di = explainDetailed(df2)
    logDebug(df.toString() + "->" + di.toString)
    val Seq(c1) = di.cols
    val Some(s) = c1.stf
    assert(s.dataType === ScalarDoubleType)
    // 10 rows are repartitioned into 3 partitions above; the leading dimension is expected
    // to be reported as unknown.
    assert(s.shape === Shape(Unknown))

  // Rows whose arrays differ in length: the inner dimension is expected to be unknown.
  test("simple analysis with variable sizes") {
    val df = Seq(
      (0.0, Seq(0.0)),
      (1.0, Seq(1.0, 1.0))).toDF("a", "b")
    val df2 = analyze(df)
    val di = explainDetailed(df2)
    logDebug(df.toString() + "->" + di.toString)
    val Seq(c1, c2) = di.cols
    val Some(s2) = c2.stf
    assert(s2.dataType === ScalarDoubleType)
    assert(s2.shape === Shape(2, Unknown)) // There is only one partition

  // Three rows with arrays of length 2: both dimensions are expected to be known.
  test("2nd order analysis") {
    val df = Seq(
      (0.0, Seq(0.0, 0.0)),
      (1.0, Seq(1.0, 1.0)),
      (2.0, Seq(2.0, 2.0))).toDF("a", "b")
    val df2 = analyze(df)
    val di = explainDetailed(df2)
    logDebug(df.toString() + "->" + di.toString)
    val Seq(c1, c2) = di.cols
    val Some(s2) = c2.stf
    assert(s2.dataType === ScalarDoubleType)
    assert(s2.shape === Shape(3, 2)) // There is only one partition
Example 194
Source File: SlicingSuite.scala    From tensorframes   with Apache License 2.0 5 votes vote down vote up
package org.tensorframes

import org.scalatest.FunSuite
import org.tensorframes.dsl.GraphScoping
import org.tensorframes.impl.DebugRowOps
import org.tensorframes.{dsl => tf}
import org.tensorframes.dsl.Implicits._

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, IntegerType}

/**
 * Test suite for slicing operations.
 *
 * NOTE(review): this listing breaks off inside the first test; the remainder of the suite is
 * not visible. The code is left exactly as found.
 */
class SlicingSuite
  extends FunSuite with TensorFramesTestSparkContext with Logging with GraphScoping {
  lazy val sql = sqlContext

  import Shape.Unknown

  val ops = new DebugRowOps

  test("2D - 1") {
    // Two rows of two doubles in a column named "x"; `x` is the block for that column.
    val df = make1(Seq(Seq(1.0, 2.0), Seq(3.0, 4.0)), "x")
    val x = df.block("x")
//    val y =
Example 195
package org.sparksamples

import java.awt.Font

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.jfree.chart.axis.CategoryLabelPositions

import scalax.chart.module.ChartFactories

    // Script fragment: counts rows per "rating" value and feeds the counts to a bar chart.
    // NOTE(review): this listing is truncated; the enclosing object/main, several right-hand
    // sides and the closing braces are missing. The code is left exactly as found.

    // Four nullable integer columns: user_id, movie_id, rating, timestamp.
    val customSchema = StructType(Array(
      StructField("user_id", IntegerType, true),
      StructField("movie_id", IntegerType, true),
      StructField("rating", IntegerType, true),
      StructField("timestamp", IntegerType, true)))

    val spConfig = (new SparkConf).setMaster("local").setAppName("SparkApp")
    // NOTE(review): the builder chain after SparkSession is missing in this listing.
    val spark = SparkSession

    // NOTE(review): the start of this expression is missing in this listing; what remains
    // sets a tab delimiter and applies customSchema.
    val rating_df ="com.databricks.spark.csv")
      .option("delimiter", "\t").schema(customSchema)

    // One row per rating value with its count, sorted by rating.
    val rating_df_count = rating_df.groupBy("rating").count().sort("rating")
    //val rating_df_count_sorted = rating_df_count.sort("count")
    val rating_df_count_collection = rating_df_count.collect()

    // NOTE(review): the type instantiated for `ds` is missing in this listing.
    val ds = new
    val mx = scala.collection.immutable.ListMap()

    // Adds one value per collected row: column 0 is the rating, column 1 its count.
    for( x <- 0 until rating_df_count_collection.length) {
      val occ = rating_df_count_collection(x)(0)
      val count = Integer.parseInt(rating_df_count_collection(x)(1).toString)
      ds.addValue(count,"UserAges", occ.toString)

    //val sorted =  ListMap(ratings_count.toSeq.sortBy(_._1):_*)
    //val ds = new
    //sorted.foreach{ case (k,v) => ds.addValue(v,"Rating Values", k)}

    val chart = ChartFactories.BarChart(ds)
    val font = new Font("Dialog", Font.PLAIN,5);

Example 196
package org.sparksamples

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

import scalax.chart.module.ChartFactories

/**
 * Counts ratings per user and feeds the per-user counts to a bar chart.
 *
 * NOTE(review): this listing is truncated; several right-hand sides and the closing braces are
 * missing. The code is left exactly as found.
 */
object UserRatingsChart {

  def main(args: Array[String]) {

    // Four nullable integer columns: user_id, movie_id, rating, timestamp.
    val customSchema = StructType(Array(
      StructField("user_id", IntegerType, true),
      StructField("movie_id", IntegerType, true),
      StructField("rating", IntegerType, true),
      StructField("timestamp", IntegerType, true)))

    val spConfig = (new SparkConf).setMaster("local").setAppName("SparkApp")
    // NOTE(review): the builder chain after SparkSession is missing in this listing.
    val spark = SparkSession

    // NOTE(review): the start of this expression is missing in this listing; what remains
    // sets a tab delimiter and applies customSchema.
    val rating_df ="com.databricks.spark.csv")
      .option("delimiter", "\t").schema(customSchema)

    // One row per user with the number of ratings, sorted by that count.
    val rating_nos_by_user = rating_df.groupBy("user_id").count().sort("count")
    // NOTE(review): the type instantiated for `ds` is missing in this listing.
    val ds = new
    val rating_nos_by_user_collect =rating_nos_by_user.collect()

    // Map of bin upper bounds to 0, filled by the loop below.
    var mx = Map(0 -> 0)
    val min = 1
    val max = 1000
    val bins = 100

    val step = (max/bins).toInt
    for (i <- step until (max + step) by step) {
      mx += (i -> 0);
    // Adds one value per collected row: column 0 is the user id, column 1 the rating count.
    for( x <- 0 until rating_nos_by_user_collect.length) {
      val user_id = Integer.parseInt(rating_nos_by_user_collect(x)(0).toString)
      val count = Integer.parseInt(rating_nos_by_user_collect(x)(1).toString)
      ds.addValue(count,"Ratings", user_id)
   // ------------------------------------------------------------------

    val chart = ChartFactories.BarChart(ds)
Example 197
Source File: UserData.scala    From Machine-Learning-with-Spark-Second-Edition   with MIT License 5 votes vote down vote up
package org.sparksamples.df
//import org.apache.spark.sql.SQLContext
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType};
/**
 * Prints the first record, the number of users and the number of distinct genders,
 * occupations and zip codes of a user DataFrame.
 *
 * NOTE(review): this listing is truncated; several right-hand sides and the closing braces are
 * missing. The code is left exactly as found.
 */
package object UserData {
  def main(args: Array[String]): Unit = {
    // Nullable columns: integer "no" plus string age, gender, occupation and zipCode.
    val customSchema = StructType(Array(
      StructField("no", IntegerType, true),
      StructField("age", StringType, true),
      StructField("gender", StringType, true),
      StructField("occupation", StringType, true),
      StructField("zipCode", StringType, true)));
    val spConfig = (new SparkConf).setMaster("local").setAppName("SparkApp")
    // NOTE(review): the builder chain after SparkSession is missing in this listing.
    val spark = SparkSession

    // NOTE(review): the start of this expression is missing in this listing; what remains
    // sets a "|" delimiter and applies customSchema.
    val user_df ="com.databricks.spark.csv")
      .option("delimiter", "|").schema(customSchema)
    val first = user_df.first()
    println("First Record : " + first)

    // groupBy(...).count() yields one row per distinct value; the outer count() counts those rows.
    val num_genders = user_df.groupBy("gender").count().count()
    val num_occupations = user_df.groupBy("occupation").count().count()
    val num_zipcodes = user_df.groupBy("zipCode").count().count()

    println("num_users : " + user_df.count())
    println("num_genders : "+ num_genders)
    println("num_occupations : "+ num_occupations)
    println("num_zipcodes: " + num_zipcodes)
    println("Distribution by Occupation")

Example 198
Source File: PrettifyTest.scala    From spark-testing-base   with Apache License 2.0 5 votes vote down vote up
package com.holdenkarau.spark.testing

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.scalacheck.Gen
import org.scalacheck.Prop._
import org.scalacheck.util.Pretty
import org.scalatest.FunSuite
import org.scalatest.exceptions.GeneratorDrivenPropertyCheckFailedException
import org.scalatest.prop.Checkers

/**
 * Checks the pretty-printed argument that a failing property check reports for generated
 * DataFrames, RDDs and Datasets.
 *
 * NOTE(review): this listing is truncated; closing braces and the bodies of the private
 * helpers are missing. The code is left exactly as found.
 */
class PrettifyTest extends FunSuite with SharedSparkContext with Checkers with Prettify {
  // Fixes the generated size at exactly 2, matching "size = 2" in the expected strings below.
  implicit val propertyCheckConfig = PropertyCheckConfig(minSize = 2, maxSize = 2)

  test("pretty output of DataFrame's check") {
    val schema = StructType(List(StructField("name", StringType), StructField("age", IntegerType)))
    val sqlContext = new SQLContext(sc)
    // Constant generators make the expected output deterministic.
    val nameGenerator = new Column("name", Gen.const("Holden Hanafy"))
    val ageGenerator = new Column("age", Gen.const(20))

    val dataframeGen = DataframeGenerator.arbitraryDataFrameWithCustomFields(sqlContext, schema)(nameGenerator, ageGenerator)

    val actual = runFailingCheck(dataframeGen.arbitrary)
    val expected =
      Some("arg0 = <DataFrame: schema = [name: string, age: int], size = 2, values = ([Holden Hanafy,20], [Holden Hanafy,20])>")
    assert(actual == expected)

  test("pretty output of RDD's check") {
    // Every generated element is the constant pair ("Holden Hanafy", 20).
    val rddGen = RDDGenerator.genRDD[(String, Int)](sc) {
      for {
        name <- Gen.const("Holden Hanafy")
        age <- Gen.const(20)
      } yield name -> age

    val actual = runFailingCheck(rddGen)
    val expected =
      Some("""arg0 = <RDD: size = 2, values = ((Holden Hanafy,20), (Holden Hanafy,20))>""")
    assert(actual == expected)

  test("pretty output of Dataset's check") {
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Every generated element is the constant pair ("Holden Hanafy", 20).
    val datasetGen = DatasetGenerator.genDataset[(String, Int)](sqlContext) {
      for {
        name <- Gen.const("Holden Hanafy")
        age <- Gen.const(20)
      } yield name -> age

    val actual = runFailingCheck(datasetGen)
    val expected =
      Some("""arg0 = <Dataset: schema = [_1: string, _2: int], size = 2, values = ((Holden Hanafy,20), (Holden Hanafy,20))>""")
    assert(actual == expected)

  // Builds a property that always fails for the given generator and intercepts the failure.
  // NOTE(review): the body of the intercept block and the returned value are missing in this
  // listing.
  private def runFailingCheck[T](genUnderTest: Gen[T])(implicit p: T => Pretty) = {
    val property = forAll(genUnderTest)(_ => false)
    val e = intercept[GeneratorDrivenPropertyCheckFailedException] {

  // NOTE(review): the body of this helper is missing in this listing.
  private def takeSecondToLastLine(msg: Option[String]) =

Example 199
Source File: MetastoreRelationSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Tests for MetastoreRelation copying and for CTAS into a Hive serde table.
// NOTE(review): this excerpt omits the closing braces and at least one statement; see the
// inline notes below.
class MetastoreRelationSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
  test("makeCopy and toJSON should work") {
    // Catalog entry `db.test` of type VIEW with empty storage and a single nullable
    // integer column `a`.
    val table = CatalogTable(
      identifier = TableIdentifier("test", Some("db")),
      tableType = CatalogTableType.VIEW,
      storage = CatalogStorageFormat.empty,
      schema = StructType(StructField("a", IntegerType, true) :: Nil))
    // The second parameter list receives the catalog table and a null second argument.
    val relation = MetastoreRelation("db", "test")(table, null)

    // No exception should be thrown
    relation.makeCopy(Array("db", "test"))
    // No exception should be thrown
    // NOTE(review): the statement this comment refers to is missing from this excerpt;
    // given the test name it is presumably a toJSON call -- confirm against the full source.

  test("SPARK-17409: Do Not Optimize Query in CTAS (Hive Serde Table) More Than Once") {
    // withTable / withTempView presumably drop `bar` and `foo` after the body runs --
    // confirm against SQLTestUtils.
    withTable("bar") {
      withTempView("foo") {
        sql("select 0 as id").createOrReplaceTempView("foo")
        // If we optimize the query in CTAS more than once, the following saveAsTable will fail
        // with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])`
        sql("CREATE TABLE bar AS SELECT * FROM foo group by id")
        // The created table must contain the single row produced by the source view.
        checkAnswer(spark.table("bar"), Row(0) :: Nil)
        val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar"))
        // The catalog must record the new table's provider as "hive".
        assert(tableMetadata.provider == Some("hive"), "the expected table is a Hive serde table")
Example 200
Source File: HiveClientSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.client

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.conf.HiveConf

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.types.IntegerType

// Tests partition filtering through a Hive client built with the direct-SQL setting disabled.
// NOTE(review): this excerpt omits closing braces and at least one line; see the inline
// notes below.
class HiveClientSuite extends SparkFunSuite {
  private val clientBuilder = new HiveClientBuilder

  // Name of the Hive configuration variable METASTORE_TRY_DIRECT_SQL.
  private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname

  test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") {
    val testPartitionCount = 5

    // Storage format with no location, input/output formats or serde, uncompressed,
    // and no extra properties.
    val storageFormat = CatalogStorageFormat(
      locationUri = None,
      inputFormat = None,
      outputFormat = None,
      serde = None,
      compressed = false,
      properties = Map.empty)

    // Build the client with the direct-SQL setting explicitly turned off.
    val hadoopConf = new Configuration()
    hadoopConf.setBoolean(tryDirectSqlKey, false)
    val client = clientBuilder.buildClient(HiveUtils.hiveExecutionVersion, hadoopConf)
    client.runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (part INT)")

    // One partition spec per value of `part` in 1..testPartitionCount.
    val partitions = (1 to testPartitionCount).map { part =>
      CatalogTablePartition(Map("part" -> part.toString), storageFormat)
      // NOTE(review): the line closing the map block and opening the call that takes the
      // arguments below is missing from this excerpt; presumably a createPartitions call
      // on `client` -- confirm against the full source.
      "default", "test", partitions, ignoreIfExists = false)

    // Filter on part = 3.
    val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"),
      Seq(EqualTo(AttributeReference("part", IntegerType)(), Literal(3))))

    // Despite the filter, every partition is expected back, as the test name states.
    assert(filteredPartitions.size == testPartitionCount)